diff --git a/.config/1espt/PipelineAutobaseliningConfig.yml b/.config/1espt/PipelineAutobaseliningConfig.yml index daa9b73d5971a..183d52d5c1d44 100644 --- a/.config/1espt/PipelineAutobaseliningConfig.yml +++ b/.config/1espt/PipelineAutobaseliningConfig.yml @@ -5,15 +5,16 @@ pipelines: retail: source: credscan: - lastModifiedDate: 2024-10-24 + lastModifiedDate: 2024-10-25 policheck: - lastModifiedDate: 2024-10-24 + lastModifiedDate: 2024-10-25 eslint: - lastModifiedDate: 2024-10-24 + lastModifiedDate: 2024-10-25 psscriptanalyzer: - lastModifiedDate: 2024-10-24 + lastModifiedDate: 2024-10-25 armory: - lastModifiedDate: 2024-10-24 + lastModifiedDate: 2024-10-25 + usedNonDefaultBranch: true 1299: retail: source: @@ -25,6 +26,8 @@ pipelines: lastModifiedDate: 2024-10-25 armory: lastModifiedDate: 2024-10-25 + policheck: + lastModifiedDate: 2024-10-29 binary: credscan: lastModifiedDate: 2024-10-25 @@ -32,3 +35,43 @@ pipelines: lastModifiedDate: 2024-10-25 spotbugs: lastModifiedDate: 2024-10-25 + 1625: + retail: + source: + credscan: + lastModifiedDate: 2024-11-05 + policheck: + lastModifiedDate: 2024-11-05 + eslint: + lastModifiedDate: 2024-11-05 + psscriptanalyzer: + lastModifiedDate: 2024-11-05 + armory: + lastModifiedDate: 2024-11-05 + binary: + credscan: + lastModifiedDate: 2024-11-13 + binskim: + lastModifiedDate: 2024-11-13 + spotbugs: + lastModifiedDate: 2024-11-13 + 1626: + retail: + source: + credscan: + lastModifiedDate: 2024-11-13 + policheck: + lastModifiedDate: 2024-11-13 + eslint: + lastModifiedDate: 2024-11-13 + psscriptanalyzer: + lastModifiedDate: 2024-11-13 + armory: + lastModifiedDate: 2024-11-13 + binary: + credscan: + lastModifiedDate: 2024-11-13 + binskim: + lastModifiedDate: 2024-11-13 + spotbugs: + lastModifiedDate: 2024-11-13 diff --git a/.config/guardian/.gdnbaselines b/.config/guardian/.gdnbaselines new file mode 100644 index 0000000000000..a7ee2a4b69dda --- /dev/null +++ b/.config/guardian/.gdnbaselines @@ -0,0 +1,43 @@ +{ + "properties": { + "helpUri": "https://eng.ms/docs/microsoft-security/security/azure-security/cloudai-security-fundamentals-engineering/security-integration/guardian-wiki/microsoft-guardian/general/baselines" + }, + "version": "1.0.0", + "baselines": { + "default": { + "name": "default", + "createdDate": "2024-11-13 00:40:35Z", + "lastUpdatedDate": "2024-11-13 00:40:35Z" + } + }, + "results": { + "48f03e2797fc40ecea50f878a0268947c7e13db1b2fa51aa3981246844fc4c68": { + "signature": "48f03e2797fc40ecea50f878a0268947c7e13db1b2fa51aa3981246844fc4c68", + "alternativeSignatures": [], + "target": "ScanTelemetry_20241113003616898.json", + "line": 1, + "memberOf": [ + "default" + ], + "tool": "credscan", + "ruleId": "CSCAN-AZURE0130", + "createdDate": "2024-11-13 00:40:35Z", + "expirationDate": "2025-05-02 01:29:47Z", + "justification": "This error is baselined with an expiration date of 180 days from 2024-11-13 01:29:47Z" + }, + "9cb6eddb3f3e886ad06cae65f5886412ff0c5fb0b96d4e943e4efa237be617b1": { + "signature": "9cb6eddb3f3e886ad06cae65f5886412ff0c5fb0b96d4e943e4efa237be617b1", + "alternativeSignatures": [], + "target": "ScanTelemetry_20241113111547065.json", + "line": 1, + "memberOf": [ + "default" + ], + "tool": "credscan", + "ruleId": "CSCAN-AZURE0130", + "createdDate": "2024-11-13 11:20:17Z", + "expirationDate": "2025-05-02 11:55:15Z", + "justification": "This error is baselined with an expiration date of 180 days from 2024-11-13 11:55:15Z" + } + } +} \ No newline at end of file diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index d3b51c0681a20..d1dc717c2a9c9 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -15,6 +15,10 @@ on: schedule: - cron: '41 13 * * 0' +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + jobs: analyze: name: Analyze diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index ec834b07b2c78..64785574c7728 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -37,6 +37,9 @@ jobs: # Required workflow name: Python format runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + permissions: + contents: read + security-events: write steps: - uses: actions/checkout@v4 - name: Setup Python @@ -49,10 +52,15 @@ jobs: with: toolchain: stable components: rustfmt + - name: Update PATH + run: | + echo "$HOME/.local/bin" >> "$GITHUB_PATH" + - name: Install dependencies run: | - python -m pip install -r requirements-dev.txt - python -m pip install lintrunner lintrunner-adapters + set -e -x + python -m pip install --user -r requirements-dev.txt + python -m pip install --user lintrunner lintrunner-adapters lintrunner init - name: Run lintrunner on all files run: | @@ -81,8 +89,12 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@master + - name: Update PATH + run: | + echo "$HOME/.local/bin" >> "$GITHUB_PATH" + - name: Install ninja - run: python -m pip install --upgrade ninja + run: python -m pip install --user --upgrade ninja - name: Generate compile_commands.json run: | python tools/ci_build/build.py \ diff --git a/.github/workflows/pr_checks.yml b/.github/workflows/pr_checks.yml index af3f00c4e35ab..af890d88995be 100644 --- a/.github/workflows/pr_checks.yml +++ b/.github/workflows/pr_checks.yml @@ -41,12 +41,12 @@ jobs: - name: Install dependencies and run lintrunner on all files run: | - set -e python -m pip install --user -r requirements-dev.txt - python -m pip install --user lintrunner lintrunner-adapters + python -m pip install --user lintrunner lintrunner-adapters lintrunner init + set +e lintrunner f --all-files -v exit 0 - - uses: parkerbxyz/suggest-changes@v1 + - uses: parkerbxyz/suggest-changes@v2 with: comment: 'You can commit the suggested changes from lintrunner.' diff --git a/.github/workflows/publish-csharp-apidocs.yml b/.github/workflows/publish-csharp-apidocs.yml index c704adb263db4..7cca0969a168b 100644 --- a/.github/workflows/publish-csharp-apidocs.yml +++ b/.github/workflows/publish-csharp-apidocs.yml @@ -20,18 +20,17 @@ permissions: jobs: build: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] env: DOCFXVERSION: 2.62.2 steps: - uses: actions/checkout@v4 - - name: Setup .NET - uses: actions/setup-dotnet@v4 - with: - dotnet-version: 8.0.x - name: Install DocFX run: | dotnet tool update -g docfx + - name: Update PATH + run: | + Add-Content -Value "$env:USERPROFILE\.dotnet\tools" -Encoding utf8 -Path $env:GITHUB_PATH # NOTE: We need to restore Microsoft.ML.OnnxRuntime.csproj manually to set IncludeMobileTargets=false # docfx doesn't seem to be able to do that properly resulting in build errors - name: Restore dependencies @@ -50,10 +49,12 @@ jobs: - name: Log source commit run: git rev-parse --short HEAD > csharp/ApiDocs/csharp/source-version.txt - name: Move C# docs into site + shell: pwsh run: | - mkdir -p _site/docs/api - rm -rf _site/docs/api/csharp - mv csharp/ApiDocs/csharp _site/docs/api/csharp + New-Item -Path _site/docs/api -Force -ItemType "Directory" | Out-Null + $OutputDirectory="_site/docs/api/csharp" + if (Test-Path $OutputDirectory) { Remove-Item -Recurse -Force $OutputDirectory } + Move-Item -Path csharp\ApiDocs\csharp -Destination $OutputDirectory - name: Upload docs artifact uses: actions/upload-artifact@v4 with: diff --git a/.github/workflows/publish-python-apidocs.yml b/.github/workflows/publish-python-apidocs.yml index 2be9ad957c5cb..adc2346d1bf1b 100644 --- a/.github/workflows/publish-python-apidocs.yml +++ b/.github/workflows/publish-python-apidocs.yml @@ -32,10 +32,10 @@ jobs: sudo apt-get install graphviz - name: Install dependencies run: | - python3 -m pip install --upgrade pip + python3 -m pip install --user --upgrade pip cd docs/python - python3 -m pip install -r requirements.txt - python3 -m pip install --pre onnxruntime-training -f https://download.onnxruntime.ai/onnxruntime_nightly_cpu.html + python3 -m pip install --user -r requirements.txt + python3 -m pip install --user --pre onnxruntime-training -f https://download.onnxruntime.ai/onnxruntime_nightly_cpu.html python3 -m pip list - name: Generate Python docs with Sphinx run: | diff --git a/.pipelines/nuget_config/x64/packages.config b/.pipelines/nuget_config/x64/packages.config index 294bd926a34cb..b9932eb563b83 100644 --- a/.pipelines/nuget_config/x64/packages.config +++ b/.pipelines/nuget_config/x64/packages.config @@ -1,6 +1,6 @@  - + diff --git a/.pipelines/nuget_config/x86/packages.config b/.pipelines/nuget_config/x86/packages.config index 3528545dfb06e..37fe2d378b7fd 100644 --- a/.pipelines/nuget_config/x86/packages.config +++ b/.pipelines/nuget_config/x86/packages.config @@ -1,6 +1,6 @@  - + diff --git a/CPPLINT.cfg b/CPPLINT.cfg new file mode 100644 index 0000000000000..12c1c7be0d773 --- /dev/null +++ b/CPPLINT.cfg @@ -0,0 +1 @@ +filter=-whitespace diff --git a/README.md b/README.md index 8452e26a58d4d..f1817282b61a0 100644 --- a/README.md +++ b/README.md @@ -24,8 +24,8 @@ |System|Inference|Training| |---|---|---| -|Windows|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20CPU%20CI%20Pipeline?label=Windows+CPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=9)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20GPU%20CI%20Pipeline?label=Windows+GPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=10)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20GPU%20TensorRT%20CI%20Pipeline?label=Windows+GPU+TensorRT)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=47)|| -|Linux|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20CPU%20CI%20Pipeline?label=Linux+CPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=11)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20CPU%20Minimal%20Build%20E2E%20CI%20Pipeline?label=Linux+CPU+Minimal+Build)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=64)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20GPU%20CI%20Pipeline?label=Linux+GPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=12)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20GPU%20TensorRT%20CI%20Pipeline?label=Linux+GPU+TensorRT)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=45)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20OpenVINO%20CI%20Pipeline?label=Linux+OpenVINO)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=55)|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/orttraining-linux-ci-pipeline?label=Linux+CPU+Training)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=86)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/orttraining-linux-gpu-ci-pipeline?label=Linux+GPU+Training)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=84)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/orttraining/orttraining-ortmodule-distributed?label=Training+Distributed)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=148)| +|Windows|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20CPU%20CI%20Pipeline?label=Windows+CPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=9)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20GPU%20CUDA%20CI%20Pipeline?label=Windows+GPU+CUDA)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=218)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20GPU%20TensorRT%20CI%20Pipeline?label=Windows+GPU+TensorRT)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=47)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20GPU%20WebGPU%20CI%20Pipeline?label=Windows+GPU+WebGPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=228)|| +|Linux|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20CPU%20CI%20Pipeline?label=Linux+CPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=11)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20CPU%20Minimal%20Build%20E2E%20CI%20Pipeline?label=Linux+CPU+Minimal+Build)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=64)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20GPU%20CI%20Pipeline?label=Linux+GPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=12)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20GPU%20TensorRT%20CI%20Pipeline?label=Linux+GPU+TensorRT)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=45)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20OpenVINO%20CI%20Pipeline?label=Linux+OpenVINO)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=55)|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/orttraining-linux-ci-pipeline?label=Linux+CPU+Training)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=86)
[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/orttraining-linux-gpu-ci-pipeline?label=Linux+GPU+Training)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=84)| |Mac|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/MacOS%20CI%20Pipeline?label=MacOS+CPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=13)|| |Android|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Android%20CI%20Pipeline?label=Android)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=53)|| |iOS|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/iOS%20CI%20Pipeline?label=iOS)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=134)|| diff --git a/ThirdPartyNotices.txt b/ThirdPartyNotices.txt index 20142e734dfac..26084ab42ec1c 100644 --- a/ThirdPartyNotices.txt +++ b/ThirdPartyNotices.txt @@ -2108,261 +2108,6 @@ SOFTWARE. _____ -TVM Open Deep Learning Compiler Stack - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "{}" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright {yyyy} {name of copyright owner} - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - -CONTRIBUTORS - -TVM Contributors -================ -TVM adopts the Apache style model and governs by merit. We believe that it is important to create an inclusive community where everyone can use, -contribute to, and influence the direction of the project. We actively invite contributors who have earned the merit to be part of the development community. - -See the [community structure document](http://docs.tvm.ai/contribute/community.html) for the explanation of community structure and contribution guidelines. - -## Committers -- [Tianqi Chen](https://github.com/tqchen) (PMC) -- [Thierry Moreau](http://homes.cs.washington.edu/~moreau/) -- [Ziheng Jiang](https://github.com/ZihengJiang) -- [Haichen Shen](http://homes.cs.washington.edu/~haichen/) -- [Yizhi Liu](https://github.com/yzhliu) - -## Code Owners -- [Aditya Atluri](https://github.com/adityaatluri) ROCM -- [Leyuan Wang](https://github.com/Laurawly) TOPI -- [Yuwei Hu](https://github.com/Huyuwei) TOPI -- [Zhixun Tan](https://github.com/phisiart) OpenGL/WebGL backend -- [Nick Hynes](https://github.com/nhynes) SGX and secured computing -- [Lianmin Zheng](https://github.com/merrymercy) AutoTVM - -## Reviewers -- [Zhi Chen](https://github.com/zhiics) -- [Xiaoqiang Dan](https://github.com/xqdan) -- [Liangfu Chen](https://github.com/liangfu) -- [Masahiro Masuda](https://github.com/masahi) -- [Kazutaka Morita](https://github.com/kazum) -- [Tatsuya Nishiyama](https://github.com/nishi-t) -- [Pariksheet Pinjari](https://github.com/PariksheetPinjari909) -- [Jared Roesch](https://github.com/jroesch) -- [Siva](https://github.com/srkreddy1238) -- [Siju Samuel](https://github.com/siju-samuel) -- [Alex Weaver](https://github.com/alex-weaver) -- [Yao Wang](https://github.com/kevinthesun) -- [Jian Weng](https://github.com/were) -- [Eddie Yan](https://github.com/eqy) -- [Joshua Z. Zhang](https://github.com/zhreshold) - -## List of Contributors -- [Full List of Contributors](https://github.com/dmlc/tvm/graphs/contributors) - - To contributors: please add your name to the list. -- [Qiao Zhang](https://github.com/zhangqiaorjc) -- [Haolong Zhang](https://github.com/haolongzhangm) -- [Cody Hao Yu](https://github.com/comaniac) -- [Chris Nuernberger](https://github.com/cnuernber) - -_____ - FreeBSD: getopt.c file Copyright (c) 1987, 1993, 1994 diff --git a/cgmanifests/cgmanifest.json b/cgmanifests/cgmanifest.json index 1432193ac9080..46349f43923e2 100644 --- a/cgmanifests/cgmanifest.json +++ b/cgmanifests/cgmanifest.json @@ -1,578 +1,508 @@ { - "$schema": "https://json.schemastore.org/component-detection-manifest.json", - "Registrations": [ - { - "component": { - "type": "git", - "git": { - "commitHash": "215105818dfde3174fe799600bb0f3cae233d0bf", - "repositoryUrl": "https://github.com/abseil/abseil-cpp.git" - } - } - }, - { - "component": { - "Type": "maven", - "maven": { - "GroupId": "org.junit.platform", - "ArtifactId": "junit-platform-console-standalone", - "Version": "1.6.2" - }, - "DevelopmentDependency": true - } - }, - { - "component": { - "Type": "maven", - "maven": { - "GroupId": "com.google.protobuf", - "ArtifactId": "protobuf-java", - "Version": "3.21.7" - }, - "DevelopmentDependency": true - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "2379917985919ed3918dc12cad47f469f245be7a", - "repositoryUrl": "https://github.com/apache/tvm.git" - }, - "comments": "needed for TVM EP" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "cabe04d6d6b05356fa8f9741704924788f0dd762", - "repositoryUrl": "https://github.com/agauniyal/rang.git" - }, - "comments": "dependency from tvm" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "a3bcc6981d5dad3afb212689e2c7853d1b1ee45d", - "repositoryUrl": "https://github.com/NVIDIA/cutlass.git" - }, - "comments": "dependency from tvm" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "08f7c7e69f8ea61a0c4151359bc8023be8e9217b", - "repositoryUrl": "https://github.com/tlc-pack/libbacktrace.git" - }, - "comments": "dependency from tvm" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "36a91576edf633479c78649e050f18dd2ddc8103", - "repositoryUrl": "https://github.com/apache/incubator-tvm-vta.git" - }, - "comments": "dependency from tvm" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "111c9be5188f7350c2eac9ddaedd8cca3d7bf394", - "repositoryUrl": "https://github.com/kazuho/picojson.git" - }, - "comments": "dependency from tvm" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "b5e4186d7ab63458e79084842dced166be2ca5b5", - "repositoryUrl": "https://github.com/lammertb/libcrc.git" - }, - "comments": "dependency from tvm" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "e4a4c02764d37c9c3db0d64c4996651a3ef9513c", - "repositoryUrl": "https://github.com/dmlc/HalideIR.git" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "bee4d1dd8dc1ee4a1fd8fa6a96476c2f8b7492a3", - "repositoryUrl": "https://github.com/dmlc/dlpack.git" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "4d49691f1a9d944c3b0aa5e63f1db3cad1f941f8", - "repositoryUrl": "https://github.com/dmlc/dmlc-core.git" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "7de7e5d02bf687f971e7668963649728356e0c20", - "repositoryUrl": "https://github.com/intel/mkl-dnn.git" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "d860915b0198ddb96f93e9e97a789af156544dc6", - "repositoryUrl": "https://github.com/tensorflow/tensorflow.git" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "eddf9023206dc40974c26f589ee2ad63a4227a1e", - "repositoryUrl": "https://github.com/glennrp/libpng.git" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "217f52fb121ef92491e5d5f71394b07ce4ead1d0", - "repositoryUrl": "https://github.com/KjellKod/g3log.git" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "50893291621658f355bc5b4d450a8d06a563053d", - "repositoryUrl": "https://github.com/madler/zlib.git" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "d264a2603493fecda607c1d1cda87fedba77d36b", - "repositoryUrl": "https://github.com/Microsoft/CNTK.git" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "971e2e89d08deeae0139d3011d15646fdac13c92", - "repositoryUrl": "https://github.com/numpy/numpy.git" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "90537289a04ef5d572496240e2ac3a881be518d2", - "repositoryUrl": "https://github.com/pytorch/pytorch.git" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "b31f58de6fa8bbda5353b3c77d9be4914399724d", - "repositoryUrl": "https://github.com/pytorch/pytorch.git" - }, - "comments": "pytorch 1.6 used by onnxruntime training image" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "7389dbac82d362f296dc2746f10e43ffa1615660", - "repositoryUrl": "https://github.com/scikit-learn/scikit-learn.git" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "eeebdab16155d34ff8f5f42137da7df4d1c7eab0", - "repositoryUrl": "https://github.com/BVLC/caffe.git" - } - } - }, - { - "component": { - "Type": "other", - "Other": { - "Name": "LLVM", - "Version": "9.0.0", - "DownloadUrl": "https://releases.llvm.org/9.0.0/llvm-9.0.0.src.tar.xz" - } - } - }, - { - "component": { - "Type": "other", - "Other": { - "Name": "FreeBSD GetOpt", - "Version": "12.0.0", - "DownloadUrl": "https://svnweb.freebsd.org/base/release/12.0.0/lib/libc/stdlib/getopt.c?revision=341707&view=co" - } - } - }, - { - "component": { - "Type": "other", - "Other": { - "Name": "Boost", - "Version": "1.69.0", - "DownloadUrl": "https://boostorg.jfrog.io/artifactory/main/release/1.69.0/source/boost_1_69_0.tar.bz2" - } - } - }, - { - "component": { - "git": { - "commitHash": "02a2a458ac15912d7d87cc1171e811b0c5219ece", - "repositoryUrl": "https://github.com/grpc/grpc" - }, - "type": "git" - } - }, - { - "component": { - "git": { - "commitHash": "b29b21a81b32ec273f118f589f46d56ad3332420", - "repositoryUrl": "https://github.com/google/boringssl.git" - }, - "type": "git" - } - }, - { - "component": { - "git": { - "commitHash": "3be1924221e1326df520f8498d704a5c4c8d0cce", - "repositoryUrl": "https://github.com/c-ares/c-ares.git" - }, - "type": "git" - } - }, - { - "component": { - "git": { - "commitHash": "6599cac0965be8e5a835ab7a5684bbef033d5ad0", - "repositoryUrl": "https://github.com/llvm-mirror/libcxx.git" - }, - "type": "git" - } - }, - { - "component": { - "git": { - "commitHash": "9245d481eb3e890f708ff2d7dadf2a10c04748ba", - "repositoryUrl": "https://github.com/llvm-mirror/libcxxabi.git" - }, - "type": "git" - } - }, - { - "component": { - "git": { - "commitHash": "9ce4a77f61c134bbed28bfd5be5cd7dc0e80f5e3", - "repositoryUrl": "https://github.com/google/upb.git" - }, - "type": "git" - } - }, - { - "component": { - "type": "other", - "Other": { - "Name": "Go", - "Version": "1.12.6", - "DownloadUrl": "https://dl.google.com/go/go1.12.6.linux-amd64.tar.gz" - } - } - }, - { - "component": { - "Type": "other", - "Other": { - "Name": "OpenMPI", - "Version": "4.0.0", - "DownloadUrl": "https://download.open-mpi.org/release/open-mpi/v4.0/openmpi-4.0.0.tar.gz" - } - } - }, - { - "component": { - "Type": "other", - "Other": { - "Name": "OpenMPI", - "Version": "4.0.4", - "DownloadUrl": "https://download.open-mpi.org/release/open-mpi/v4.0/openmpi-4.0.4.tar.gz" - }, - "comments": "openmpi 4.0.4 used by onnxruntime training image" - } - }, - { - "component": { - "Type": "git", - "git": { - "commitHash": "7db3f9c741d3dfd8dda14ffb537ed251280d2025", - "repositoryUrl": "https://github.com/mpi4py/mpi4py" - }, - "comments": "mpi4py 3.0.3 used by onnxruntime training image" - } - }, - { - "component": { - "Type": "other", - "Other": { - "Name": "NCCL", - "Version": "2.4.8", - "DownloadUrl": "https://docs.nvidia.com/deeplearning/sdk/nccl-install-guide/index.html" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "67afac65ce64fd4dce1494f43e565e8fe34bdffb", - "repositoryUrl": "https://android.googlesource.com/platform/frameworks/ml" - }, - "comments": "used by onnxruntime" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "c30b7da2301202da5f9f0529966944f110e5d6e7", - "repositoryUrl": "https://github.com/openucx/ucx" - }, - "comments": "middleware between IB verbs and OpenMPI used by onnxruntime training image" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "63d1e08e64e7e09408eb63cd8dd7c65ad766f277", - "repositoryUrl": "https://github.com/nodejs/node" - }, - "comments": "For Nodejs binding" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "aead4d751c2101e23336aa73f2380df83e7a13f3", - "repositoryUrl": "https://github.com/pypa/manylinux" - }, - "comments": "For building our CI build docker image" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "c974557598645360fbabac71352b083117e3cc17", - "repositoryUrl": "https://gitlab.kitware.com/cmake/cmake" - }, - "comments": "CMake 3.24.3. For building our CI build docker image" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "1e5d33e9b9b8631b36f061103a30208b206fd03a", - "repositoryUrl": "https://github.com/python/cpython" - }, - "comments": "Python 3.9.1" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "6503f05dd59e26a9986bdea097b3da9b3546f45b", - "repositoryUrl": "https://github.com/python/cpython" - }, - "comments": "Python 3.8.7" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "13c94747c74437e594b7fc242ff7da668e81887c", - "repositoryUrl": "https://github.com/python/cpython" - }, - "comments": "Python 3.7.9" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "c0a9afe2ac1820409e6173bd1893ebee2cf50270", - "repositoryUrl": "https://github.com/python/cpython" - }, - "comments": "Python 3.6.12" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "426b022776672fdf3d71ddd98d89af341c88080f", - "repositoryUrl": "https://github.com/python/cpython" - }, - "comments": "Python 3.5.10" - } - }, - { - "component": { - "type": "pip", - "pip": { - "Name": "transformers", - "Version": "4.38.0" - }, - "comments": "Installed in the training docker image" - } - }, - { - "component": { - "type": "pip", - "pip": { - "Name": "msgpack", - "Version": "1.0.0" - }, - "comments": "Installed in the training docker image" - } - }, - { - "component": { - "type": "pip", - "pip": { - "Name": "tensorboardX", - "Version": "1.8" - }, - "comments": "Installed in the training docker image" - } - }, - { - "component": { - "type": "pip", - "pip": { - "Name": "tensorboard", - "Version": "2.3.0" - }, - "comments": "Installed in the training docker image" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "92cf3702fcfaadc84eb7bef59825a23e0cd84f56", - "repositoryUrl": "https://github.com/aappleby/smhasher" - }, - "comments": "MurmurHash3" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "b89da3c5a0aa18fb2c6163ad9984f81ab65b22e3", - "repositoryUrl": "https://github.com/mestevens/gtest-ios-framework" - }, - "comments": "gtest-ios-framework" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "277508879878e0a5b5b43599b1bea11f66eb3c6c", - "repositoryUrl": "https://github.com/dmlc/dlpack.git" - }, - "comments": "dlpack" - } - }, - { - "component": { - "Type": "other", - "Other": { - "Name": "SQLite3", - "Version": "3.22.0", - "DownloadUrl": "http://security.ubuntu.com/ubuntu/pool/main/s/sqlite3/libsqlite3-dev_3.22.0-1ubuntu0.4_amd64.deb" - } - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "9d0ef119d9fcb9139f831adc224857b791c81140", - "repositoryUrl": "https://github.com/dlfcn-win32/dlfcn-win32.git" - }, - "comments": "dlfcn-win32" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "6812205f18ca4ef54372e87e1a13ce4a859434df", - "repositoryUrl": "https://github.com/python-pillow/Pillow.git" - }, - "comments": "python-pillow. Implementation logic for anti-aliasing copied by Resize CPU kernel." - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "e7248b26a1ed53fa030c5c459f7ea095dfd276ac", - "repositoryUrl": "https://gitlab.com/libeigen/eigen.git" - } - } - } - ], - "Version": 1 + "$schema": "https://json.schemastore.org/component-detection-manifest.json", + "Registrations": [ + { + "component": { + "type": "git", + "git": { + "commitHash": "215105818dfde3174fe799600bb0f3cae233d0bf", + "repositoryUrl": "https://github.com/abseil/abseil-cpp.git" + } + } + }, + { + "component": { + "Type": "maven", + "maven": { + "GroupId": "org.junit.platform", + "ArtifactId": "junit-platform-console-standalone", + "Version": "1.6.2" + }, + "DevelopmentDependency": true + } + }, + { + "component": { + "Type": "maven", + "maven": { + "GroupId": "com.google.protobuf", + "ArtifactId": "protobuf-java", + "Version": "3.21.7" + }, + "DevelopmentDependency": true + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "e4a4c02764d37c9c3db0d64c4996651a3ef9513c", + "repositoryUrl": "https://github.com/dmlc/HalideIR.git" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "bee4d1dd8dc1ee4a1fd8fa6a96476c2f8b7492a3", + "repositoryUrl": "https://github.com/dmlc/dlpack.git" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "4d49691f1a9d944c3b0aa5e63f1db3cad1f941f8", + "repositoryUrl": "https://github.com/dmlc/dmlc-core.git" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "7de7e5d02bf687f971e7668963649728356e0c20", + "repositoryUrl": "https://github.com/intel/mkl-dnn.git" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "d860915b0198ddb96f93e9e97a789af156544dc6", + "repositoryUrl": "https://github.com/tensorflow/tensorflow.git" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "eddf9023206dc40974c26f589ee2ad63a4227a1e", + "repositoryUrl": "https://github.com/glennrp/libpng.git" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "217f52fb121ef92491e5d5f71394b07ce4ead1d0", + "repositoryUrl": "https://github.com/KjellKod/g3log.git" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "50893291621658f355bc5b4d450a8d06a563053d", + "repositoryUrl": "https://github.com/madler/zlib.git" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "d264a2603493fecda607c1d1cda87fedba77d36b", + "repositoryUrl": "https://github.com/Microsoft/CNTK.git" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "971e2e89d08deeae0139d3011d15646fdac13c92", + "repositoryUrl": "https://github.com/numpy/numpy.git" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "90537289a04ef5d572496240e2ac3a881be518d2", + "repositoryUrl": "https://github.com/pytorch/pytorch.git" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "b31f58de6fa8bbda5353b3c77d9be4914399724d", + "repositoryUrl": "https://github.com/pytorch/pytorch.git" + }, + "comments": "pytorch 1.6 used by onnxruntime training image" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "7389dbac82d362f296dc2746f10e43ffa1615660", + "repositoryUrl": "https://github.com/scikit-learn/scikit-learn.git" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "eeebdab16155d34ff8f5f42137da7df4d1c7eab0", + "repositoryUrl": "https://github.com/BVLC/caffe.git" + } + } + }, + { + "component": { + "Type": "other", + "Other": { + "Name": "LLVM", + "Version": "9.0.0", + "DownloadUrl": "https://releases.llvm.org/9.0.0/llvm-9.0.0.src.tar.xz" + } + } + }, + { + "component": { + "Type": "other", + "Other": { + "Name": "FreeBSD GetOpt", + "Version": "12.0.0", + "DownloadUrl": "https://svnweb.freebsd.org/base/release/12.0.0/lib/libc/stdlib/getopt.c?revision=341707&view=co" + } + } + }, + { + "component": { + "Type": "other", + "Other": { + "Name": "Boost", + "Version": "1.69.0", + "DownloadUrl": "https://boostorg.jfrog.io/artifactory/main/release/1.69.0/source/boost_1_69_0.tar.bz2" + } + } + }, + { + "component": { + "git": { + "commitHash": "02a2a458ac15912d7d87cc1171e811b0c5219ece", + "repositoryUrl": "https://github.com/grpc/grpc" + }, + "type": "git" + } + }, + { + "component": { + "git": { + "commitHash": "b29b21a81b32ec273f118f589f46d56ad3332420", + "repositoryUrl": "https://github.com/google/boringssl.git" + }, + "type": "git" + } + }, + { + "component": { + "git": { + "commitHash": "3be1924221e1326df520f8498d704a5c4c8d0cce", + "repositoryUrl": "https://github.com/c-ares/c-ares.git" + }, + "type": "git" + } + }, + { + "component": { + "git": { + "commitHash": "6599cac0965be8e5a835ab7a5684bbef033d5ad0", + "repositoryUrl": "https://github.com/llvm-mirror/libcxx.git" + }, + "type": "git" + } + }, + { + "component": { + "git": { + "commitHash": "9245d481eb3e890f708ff2d7dadf2a10c04748ba", + "repositoryUrl": "https://github.com/llvm-mirror/libcxxabi.git" + }, + "type": "git" + } + }, + { + "component": { + "git": { + "commitHash": "9ce4a77f61c134bbed28bfd5be5cd7dc0e80f5e3", + "repositoryUrl": "https://github.com/google/upb.git" + }, + "type": "git" + } + }, + { + "component": { + "type": "other", + "Other": { + "Name": "Go", + "Version": "1.12.6", + "DownloadUrl": "https://dl.google.com/go/go1.12.6.linux-amd64.tar.gz" + } + } + }, + { + "component": { + "Type": "other", + "Other": { + "Name": "OpenMPI", + "Version": "4.0.0", + "DownloadUrl": "https://download.open-mpi.org/release/open-mpi/v4.0/openmpi-4.0.0.tar.gz" + } + } + }, + { + "component": { + "Type": "other", + "Other": { + "Name": "OpenMPI", + "Version": "4.0.4", + "DownloadUrl": "https://download.open-mpi.org/release/open-mpi/v4.0/openmpi-4.0.4.tar.gz" + }, + "comments": "openmpi 4.0.4 used by onnxruntime training image" + } + }, + { + "component": { + "Type": "git", + "git": { + "commitHash": "7db3f9c741d3dfd8dda14ffb537ed251280d2025", + "repositoryUrl": "https://github.com/mpi4py/mpi4py" + }, + "comments": "mpi4py 3.0.3 used by onnxruntime training image" + } + }, + { + "component": { + "Type": "other", + "Other": { + "Name": "NCCL", + "Version": "2.4.8", + "DownloadUrl": "https://docs.nvidia.com/deeplearning/sdk/nccl-install-guide/index.html" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "67afac65ce64fd4dce1494f43e565e8fe34bdffb", + "repositoryUrl": "https://android.googlesource.com/platform/frameworks/ml" + }, + "comments": "used by onnxruntime" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "c30b7da2301202da5f9f0529966944f110e5d6e7", + "repositoryUrl": "https://github.com/openucx/ucx" + }, + "comments": "middleware between IB verbs and OpenMPI used by onnxruntime training image" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "63d1e08e64e7e09408eb63cd8dd7c65ad766f277", + "repositoryUrl": "https://github.com/nodejs/node" + }, + "comments": "For Nodejs binding" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "aead4d751c2101e23336aa73f2380df83e7a13f3", + "repositoryUrl": "https://github.com/pypa/manylinux" + }, + "comments": "For building our CI build docker image" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "c974557598645360fbabac71352b083117e3cc17", + "repositoryUrl": "https://gitlab.kitware.com/cmake/cmake" + }, + "comments": "CMake 3.24.3. For building our CI build docker image" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "1e5d33e9b9b8631b36f061103a30208b206fd03a", + "repositoryUrl": "https://github.com/python/cpython" + }, + "comments": "Python 3.9.1" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "6503f05dd59e26a9986bdea097b3da9b3546f45b", + "repositoryUrl": "https://github.com/python/cpython" + }, + "comments": "Python 3.8.7" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "13c94747c74437e594b7fc242ff7da668e81887c", + "repositoryUrl": "https://github.com/python/cpython" + }, + "comments": "Python 3.7.9" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "c0a9afe2ac1820409e6173bd1893ebee2cf50270", + "repositoryUrl": "https://github.com/python/cpython" + }, + "comments": "Python 3.6.12" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "426b022776672fdf3d71ddd98d89af341c88080f", + "repositoryUrl": "https://github.com/python/cpython" + }, + "comments": "Python 3.5.10" + } + }, + { + "component": { + "type": "pip", + "pip": { + "Name": "transformers", + "Version": "4.38.0" + }, + "comments": "Installed in the training docker image" + } + }, + { + "component": { + "type": "pip", + "pip": { + "Name": "msgpack", + "Version": "1.0.0" + }, + "comments": "Installed in the training docker image" + } + }, + { + "component": { + "type": "pip", + "pip": { + "Name": "tensorboardX", + "Version": "1.8" + }, + "comments": "Installed in the training docker image" + } + }, + { + "component": { + "type": "pip", + "pip": { + "Name": "tensorboard", + "Version": "2.3.0" + }, + "comments": "Installed in the training docker image" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "92cf3702fcfaadc84eb7bef59825a23e0cd84f56", + "repositoryUrl": "https://github.com/aappleby/smhasher" + }, + "comments": "MurmurHash3" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "b89da3c5a0aa18fb2c6163ad9984f81ab65b22e3", + "repositoryUrl": "https://github.com/mestevens/gtest-ios-framework" + }, + "comments": "gtest-ios-framework" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "277508879878e0a5b5b43599b1bea11f66eb3c6c", + "repositoryUrl": "https://github.com/dmlc/dlpack.git" + }, + "comments": "dlpack" + } + }, + { + "component": { + "Type": "other", + "Other": { + "Name": "SQLite3", + "Version": "3.22.0", + "DownloadUrl": "http://security.ubuntu.com/ubuntu/pool/main/s/sqlite3/libsqlite3-dev_3.22.0-1ubuntu0.4_amd64.deb" + } + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "9d0ef119d9fcb9139f831adc224857b791c81140", + "repositoryUrl": "https://github.com/dlfcn-win32/dlfcn-win32.git" + }, + "comments": "dlfcn-win32" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "6812205f18ca4ef54372e87e1a13ce4a859434df", + "repositoryUrl": "https://github.com/python-pillow/Pillow.git" + }, + "comments": "python-pillow. Implementation logic for anti-aliasing copied by Resize CPU kernel." + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "e7248b26a1ed53fa030c5c459f7ea095dfd276ac", + "repositoryUrl": "https://gitlab.com/libeigen/eigen.git" + } + } + } + ], + "Version": 1 } diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index c8236c7c529a6..07dff50f9a3bd 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -196,7 +196,7 @@ "component": { "type": "git", "git": { - "commitHash": "9f98e2ebe7507fe0774d06a44bbf4b0e82cc9ce7", + "commitHash": "bc0d2e35909b8456abe32f3b30a49bb0c125e8b7", "repositoryUrl": "https://github.com/onnx/onnx-tensorrt.git" }, "comments": "onnx_tensorrt" @@ -346,7 +346,7 @@ "component": { "type": "git", "git": { - "commitHash": "511eb80847afe6bded34ec491a38d5d78ba2d604", + "commitHash": "12a3b24c456cebd9fd11f23ac0164f78129b00c6", "repositoryUrl": "https://github.com/google/dawn.git" }, "comments": "dawn" diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 1070627d5e7da..d2fe7e7457983 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -86,7 +86,7 @@ option(onnxruntime_USE_CUDA "Build with CUDA support" OFF) # use. If you hit any problem with that, please do not report it to GTest. Turn OFF the following build option instead. cmake_dependent_option(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS "Build with CUDA unit tests" OFF "onnxruntime_USE_CUDA;onnxruntime_BUILD_UNIT_TESTS" OFF) -option(onnxruntime_USE_CUDA_NHWC_OPS "Build CUDA with NHWC op support" OFF) +cmake_dependent_option(onnxruntime_USE_CUDA_NHWC_OPS "Build CUDA with NHWC op support" ON "onnxruntime_USE_CUDA" OFF) option(onnxruntime_CUDA_MINIMAL "Build CUDA without any operations apart from memcpy ops. Usefuel for a very minial TRT build" OFF) option(onnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO "When building with CUDA support, generate device code line number information." OFF) option(onnxruntime_USE_OPENVINO "Build with OpenVINO support" OFF) @@ -102,7 +102,6 @@ option(onnxruntime_BUILD_CSHARP "Build C# library" OFF) option(onnxruntime_BUILD_OBJC "Build Objective-C library" OFF) option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to provide eigen_SOURCE_PATH if turn this on." OFF) option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF) -option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF) option(onnxruntime_USE_VSINPU "Build with VSINPU support" OFF) cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF) @@ -129,6 +128,10 @@ option(onnxruntime_DONT_VECTORIZE "Do not vectorize operations in Eigen" OFF) option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF) option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS "Dump debug information about node inputs and outputs when executing the model." OFF) cmake_dependent_option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS_ENABLE_DUMP_TO_SQLDB "Build dump debug information about node inputs and outputs with support for sql database." OFF "onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS" OFF) + +# When loading a delay loaded DLL, Windows searches the main EXE's folder first. +# In a Python process, it searches where python.exe lives, but it doesn't search the python package's installation folder. Therefore we cannot enable this flag when Python is enabled. +cmake_dependent_option(onnxruntime_ENABLE_DELAY_LOADING_WIN_DLLS "Delay load some of the dependent DLls that are part of the OS" ON "WIN32;NOT GDK_PLATFORM;NOT onnxruntime_ENABLE_PYTHON" OFF) option(onnxruntime_USE_DML "Build with DirectML support" OFF) option(onnxruntime_USE_MIGRAPHX "Build with AMDMIGraphX support" OFF) option(onnxruntime_USE_WINML "Build with WinML support" OFF) @@ -141,13 +144,15 @@ option(onnxruntime_USE_TELEMETRY "Build with Telemetry" OFF) cmake_dependent_option(onnxruntime_USE_MIMALLOC "Override new/delete and arena allocator with mimalloc" OFF "WIN32;NOT onnxruntime_USE_CUDA;NOT onnxruntime_USE_OPENVINO" OFF) option(onnxruntime_USE_CANN "Build with CANN support" OFF) option(onnxruntime_USE_ROCM "Build with AMD GPU support" OFF) -option(onnxruntime_USE_TVM "Build with TVM support" OFF) -option(onnxruntime_TVM_CUDA_RUNTIME "Build TVM with CUDA support" OFF) -option(onnxruntime_TVM_USE_LLVM "Build TVM with LLVM. Set customized path to llvm-config.exe here if need" OFF) -option(onnxruntime_TVM_USE_HASH "Build ipp-crypto library for support hash algorithm. It is defined for TVM only") option(onnxruntime_USE_XNNPACK "Build with XNNPACK support. Provides an alternative math library on ARM, WebAssembly and x86." OFF) option(onnxruntime_USE_WEBNN "Build with WebNN support. Enable hardware acceleration in web browsers." OFF) option(onnxruntime_USE_WEBGPU "Build with WebGPU support. Enable WebGPU via C/C++ interface." OFF) +option(onnxruntime_USE_EXTERNAL_DAWN "Build with treating Dawn as external dependency. Will not link Dawn at build time." OFF) +option(onnxruntime_CUSTOM_DAWN_SRC_PATH "Path to custom Dawn src dir.") +option(onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY "Build Dawn as a monolithic library" OFF) +# The following 2 options are only for Windows +option(onnxruntime_ENABLE_DAWN_BACKEND_VULKAN "Enable Vulkan backend for Dawn (on Windows)" OFF) +option(onnxruntime_ENABLE_DAWN_BACKEND_D3D12 "Enable D3D12 backend for Dawn (on Windows)" ON) # Options related to reducing the binary size produced by the build # XNNPACK EP requires the internal NHWC contrib ops to be available, so this option must be OFF when onnxruntime_USE_XNNPACK is ON @@ -252,6 +257,7 @@ cmake_dependent_option(MSVC_Z7_OVERRIDE "replacing /Zi and /ZI with /Z7 when usi option(onnxruntime_USE_AZURE "Build with azure inferencing support" OFF) option(onnxruntime_USE_LOCK_FREE_QUEUE "Build with lock-free task queue for threadpool." OFF) +option(onnxruntime_FORCE_GENERIC_ALGORITHMS "Disable optimized arch-specific algorithms. Use only for testing and debugging generic algorithms." OFF) # ENABLE_TRAINING includes all training functionality # The following 2 entry points @@ -901,11 +907,6 @@ if (onnxruntime_USE_SNPE) list(APPEND ONNXRUNTIME_PROVIDER_NAMES snpe) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_SNPE=1) endif() -if (onnxruntime_USE_TVM) - list(APPEND ORT_PROVIDER_FLAGS -DUSE_TVM=1) - list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_TVM=1) - list(APPEND ONNXRUNTIME_PROVIDER_NAMES tvm) -endif() if (onnxruntime_USE_WINML) list(APPEND ORT_PROVIDER_FLAGS -DUSE_WINML=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_WINML=1) @@ -958,6 +959,18 @@ if (onnxruntime_USE_WEBGPU) list(APPEND ORT_PROVIDER_FLAGS -DUSE_WEBGPU=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_WEBGPU=1) list(APPEND ONNXRUNTIME_PROVIDER_NAMES webgpu) + if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) + list(APPEND ORT_PROVIDER_FLAGS -DBUILD_DAWN_MONOLITHIC_LIBRARY=1) + endif() + if (onnxruntime_USE_EXTERNAL_DAWN) + list(APPEND ORT_PROVIDER_FLAGS -DUSE_EXTERNAL_DAWN=1) + endif() + if (onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) + list(APPEND ORT_PROVIDER_FLAGS -DDAWN_ENABLE_VULKAN=1) + endif() + if (onnxruntime_ENABLE_DAWN_BACKEND_D3D12) + list(APPEND ORT_PROVIDER_FLAGS -DDAWN_ENABLE_D3D12=1) + endif() endif() if (onnxruntime_USE_CANN) list(APPEND ORT_PROVIDER_FLAGS -DUSE_CANN=1) @@ -973,6 +986,10 @@ if (onnxruntime_USE_LOCK_FREE_QUEUE) add_compile_definitions(USE_LOCK_FREE_QUEUE) endif() +if (onnxruntime_FORCE_GENERIC_ALGORITHMS) + add_compile_definitions(FORCE_GENERIC_ALGORITHMS) +endif() + if (onnxruntime_ENABLE_LAZY_TENSOR) # To support LazyTensor, ORT needs to call Python function from C/C++. # so onnxruntime_ENABLE_PYTHON is required. @@ -1305,50 +1322,6 @@ if (onnxruntime_USE_DNNL) add_compile_definitions(DNNL_OPENMP) endif() -# TVM EP -if (onnxruntime_USE_TVM) - if (NOT TARGET tvm) - message(STATUS "Include TVM(*).") - include(tvm) - endif() - - # ipp-crypto - if (onnxruntime_TVM_USE_HASH) - message(STATUS "Include ipp-crypto(*).") - include(ipp-crypto) - endif() - - # TVM - if (onnxruntime_TVM_USE_LLVM) - set(USE_LLVM "${onnxruntime_TVM_USE_LLVM}" CACHE STRING "Path to LLVM for correct TVM build") - elseif(onnxruntime_USE_LLVM) - set(USE_LLVM ON CACHE BOOL "Only defined for TVM") - endif() - - if (onnxruntime_TVM_CUDA_RUNTIME) - set(USE_CUDA ON CACHE BOOL "Only defined for TVM" FORCE) - endif() - - # TODO(vvchernov): customized tvm logger is hidden due to the issue on TVM side (https://github.com/apache/tvm/issues/10139) - # add_compile_definitions(TVM_LOG_CUSTOMIZE=1) - # add_library(tvm_custom_logger STATIC ${ONNXRUNTIME_ROOT}/core/providers/tvm/custom_logging.cc) - - set(USE_OPENMP gnu CACHE STRING "Only defined for TVM") - add_subdirectory(${tvm_SOURCE_DIR} ${tvm_BINARY_DIR} EXCLUDE_FROM_ALL) - - set_target_properties(tvm PROPERTIES FOLDER ${tvm_SOURCE_DIR}) - # target_link_libraries(tvm PUBLIC tvm_custom_logger) - - set(TVM_INCLUDES ${tvm_SOURCE_DIR}/include - ${tvm_SOURCE_DIR}/3rdparty/dmlc-core/include - ${tvm_SOURCE_DIR}/3rdparty/dlpack/include - $) - - set(onnxruntime_tvm_libs onnxruntime_providers_tvm) - list(APPEND onnxruntime_EXTERNAL_LIBRARIES tvm) - list(APPEND onnxruntime_EXTERNAL_DEPENDENCIES tvm) -endif() - # onnxruntime-extensions if (onnxruntime_USE_EXTENSIONS) include(extensions) @@ -1359,7 +1332,7 @@ endif() #Adjust warning flags set_msvc_c_cpp_compiler_warning_level(4) -set(onnxruntime_DELAYLOAD_FLAGS "") +set(onnxruntime_DELAYLOAD_FLAGS ) include_directories( ${ONNXRUNTIME_INCLUDE_DIR} diff --git a/cmake/deps.txt b/cmake/deps.txt index 2aec0e35e1d7f..21f9ee1701c46 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -36,8 +36,8 @@ microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.z mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.16.1.zip;2eb9198bb352757d5ff13977cbe0634898e0837c -# Use the latest commit of 10.4-GA-ORT-DDS -onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/9f98e2ebe7507fe0774d06a44bbf4b0e82cc9ce7.zip;1d92137f424513bce20033ab4fb31cc0be8d1185 +# Use the latest commit of 10.6-GA-ORT-DDS +onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/bc0d2e35909b8456abe32f3b30a49bb0c125e8b7.zip;f233ae871ad82c023da62e5dd620639f00bc2d15 protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa protoc_win64;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip;b4521f7ada5b260380f94c4bd7f1b7684c76969a protoc_win32;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win32.zip;3688010318192c46ce73213cdfb6b3e5656da874 @@ -58,5 +58,5 @@ extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d839 composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.7.0.zip;d0753d8d5b39947ca0729d7773cb84653a129eb1 -dawn;https://github.com/google/dawn/archive/511eb80847afe6bded34ec491a38d5d78ba2d604.zip;c493f5aca5586f6634e25d0121c85df71189fb99 +dawn;https://github.com/google/dawn/archive/12a3b24c456cebd9fd11f23ac0164f78129b00c6.zip;ad428f6dc16f1336d584f7bad5714e1097dafc43 kleidiai;https://gitlab.arm.com/kleidi/kleidiai/-/archive/v0.2.0/kleidiai-v0.2.0.zip;B1E3173992FD91F20DB904AB77D6E901778C2681 diff --git a/cmake/external/dml.cmake b/cmake/external/dml.cmake index e03506de12728..3cfcdd4b04c62 100644 --- a/cmake/external/dml.cmake +++ b/cmake/external/dml.cmake @@ -41,7 +41,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML) set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config) set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config) get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE) - set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.15.2) + set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.15.4) # Restore nuget packages, which will pull down the DirectML redist package. add_custom_command( diff --git a/cmake/external/eigen.cmake b/cmake/external/eigen.cmake index 339cded091b29..95dd438702a18 100644 --- a/cmake/external/eigen.cmake +++ b/cmake/external/eigen.cmake @@ -15,6 +15,7 @@ else () eigen URL ${DEP_URL_eigen} URL_HASH SHA1=${DEP_SHA1_eigen} + PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/eigen/eigen-edge.patch ) endif() diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index a69d2649ad832..aeaaa7b51d595 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -615,17 +615,39 @@ if (onnxruntime_USE_COREML) endif() if (onnxruntime_USE_WEBGPU) - FetchContent_Declare( - dawn - URL ${DEP_URL_dawn} - URL_HASH SHA1=${DEP_SHA1_dawn} - PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn.patch - ) + if (onnxruntime_CUSTOM_DAWN_SRC_PATH) + # use the custom dawn source path if provided + # + # specified as: + # build.py --use_webgpu --cmake_extra_defines "onnxruntime_CUSTOM_DAWN_SRC_PATH=" + FetchContent_Declare( + dawn + SOURCE_DIR ${onnxruntime_CUSTOM_DAWN_SRC_PATH} + ) + else() + FetchContent_Declare( + dawn + URL ${DEP_URL_dawn} + URL_HASH SHA1=${DEP_SHA1_dawn} + # All previous patches are merged into the upstream dawn project. We don't need to apply any patches right now. + # if we need to apply patches in the future, we can uncomment the following line. + # PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn.patch + ) + endif() - # use dawn::dawn_native and dawn::dawn_proc instead of the monolithic dawn::webgpu_dawn to minimize binary size - set(DAWN_BUILD_MONOLITHIC_LIBRARY OFF CACHE BOOL "" FORCE) + if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) + set(DAWN_BUILD_MONOLITHIC_LIBRARY ON CACHE BOOL "" FORCE) + set(DAWN_ENABLE_INSTALL ON CACHE BOOL "" FORCE) + + if (onnxruntime_USE_EXTERNAL_DAWN) + message(FATAL_ERROR "onnxruntime_USE_EXTERNAL_DAWN and onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY cannot be enabled at the same time.") + endif() + else() + # use dawn::dawn_native and dawn::dawn_proc instead of the monolithic dawn::webgpu_dawn to minimize binary size + set(DAWN_BUILD_MONOLITHIC_LIBRARY OFF CACHE BOOL "" FORCE) + set(DAWN_ENABLE_INSTALL OFF CACHE BOOL "" FORCE) + endif() set(DAWN_BUILD_SAMPLES OFF CACHE BOOL "" FORCE) - set(DAWN_ENABLE_INSTALL OFF CACHE BOOL "" FORCE) set(DAWN_ENABLE_NULL OFF CACHE BOOL "" FORCE) set(DAWN_FETCH_DEPENDENCIES ON CACHE BOOL "" FORCE) @@ -654,13 +676,34 @@ if (onnxruntime_USE_WEBGPU) set(DAWN_USE_BUILT_DXC ON CACHE BOOL "" FORCE) set(TINT_BUILD_HLSL_WRITER ON CACHE BOOL "" FORCE) - # Vulkan may optionally be included in a Windows build. Exclude until we have an explicit use case that requires it. - set(DAWN_ENABLE_VULKAN OFF CACHE BOOL "" FORCE) + if ((NOT onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) AND (NOT onnxruntime_ENABLE_DAWN_BACKEND_D3D12)) + message(FATAL_ERROR "At least one of onnxruntime_ENABLE_DAWN_BACKEND_VULKAN or onnxruntime_ENABLE_DAWN_BACKEND_D3D12 must be enabled when using Dawn on Windows.") + endif() + if (onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) + set(DAWN_ENABLE_VULKAN ON CACHE BOOL "" FORCE) + set(TINT_BUILD_SPV_WRITER ON CACHE BOOL "" FORCE) + else() + set(DAWN_ENABLE_VULKAN OFF CACHE BOOL "" FORCE) + endif() + if (onnxruntime_ENABLE_DAWN_BACKEND_D3D12) + set(DAWN_ENABLE_D3D12 ON CACHE BOOL "" FORCE) + else() + set(DAWN_ENABLE_D3D12 OFF CACHE BOOL "" FORCE) + endif() + # We are currently always using the D3D12 backend. + set(DAWN_ENABLE_D3D11 OFF CACHE BOOL "" FORCE) endif() onnxruntime_fetchcontent_makeavailable(dawn) - list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_native dawn::dawn_proc) + if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::webgpu_dawn) + else() + if (NOT onnxruntime_USE_EXTERNAL_DAWN) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_native) + endif() + list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_proc) + endif() endif() set(onnxruntime_LINK_DIRS) diff --git a/cmake/external/tvm.cmake b/cmake/external/tvm.cmake deleted file mode 100644 index 93049c8b85853..0000000000000 --- a/cmake/external/tvm.cmake +++ /dev/null @@ -1,24 +0,0 @@ -if (onnxruntime_USE_TVM) - message(STATUS "onnxruntime_USE_TVM: Fetch tvm for TVM EP") - - FetchContent_Declare( - tvm - GIT_REPOSITORY https://github.com/apache/tvm.git - GIT_TAG 2379917985919ed3918dc12cad47f469f245be7a - ) - - FetchContent_GetProperties(tvm) - if(NOT tvm_POPULATED) - FetchContent_Populate(tvm) - if (WIN32) - execute_process( - COMMAND ${CMAKE_COMMAND} -E create_symlink ${tvm_BINARY_DIR}/${CMAKE_BUILD_TYPE} ${tvm_SOURCE_DIR}/build - ) - else() - file(CREATE_LINK ${tvm_BINARY_DIR} ${tvm_SOURCE_DIR}/build SYMBOLIC) - endif() - endif() - - set(tvm_INCLUDE_DIRS ${tvm_SOURCE_DIR}/include) - -endif() diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 9602e54f3bc2d..732c0511d400f 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -122,8 +122,12 @@ else() else() onnxruntime_add_shared_library(onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/generated_source.c ) endif() - if (onnxruntime_USE_CUDA) - set_property(TARGET onnxruntime APPEND_STRING PROPERTY LINK_FLAGS " -Xlinker -rpath=\\$ORIGIN") + if(NOT APPLE) + include(CheckLinkerFlag) + check_linker_flag(CXX "LINKER:-rpath=\$ORIGIN" LINKER_SUPPORT_RPATH) + if(LINKER_SUPPORT_RPATH) + target_link_options(onnxruntime PRIVATE "LINKER:-rpath=\$ORIGIN") + endif() endif() endif() @@ -139,17 +143,17 @@ target_compile_definitions(onnxruntime PRIVATE FILE_NAME=\"onnxruntime.dll\") if(UNIX) if (APPLE) - set(ONNXRUNTIME_SO_LINK_FLAG " -Xlinker -dead_strip") + target_link_options(onnxruntime PRIVATE "LINKER:-dead_strip") elseif(NOT ${CMAKE_SYSTEM_NAME} MATCHES "AIX") - set(ONNXRUNTIME_SO_LINK_FLAG " -Xlinker --version-script=${SYMBOL_FILE} -Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack") + target_link_options(onnxruntime PRIVATE "LINKER:--version-script=${SYMBOL_FILE}" "LINKER:--no-undefined" "LINKER:--gc-sections") endif() else() - set(ONNXRUNTIME_SO_LINK_FLAG " -DEF:${SYMBOL_FILE}") + target_link_options(onnxruntime PRIVATE "-DEF:${SYMBOL_FILE}") endif() -if (NOT WIN32) - if (APPLE OR ${CMAKE_SYSTEM_NAME} MATCHES "^iOS") - set(ONNXRUNTIME_SO_LINK_FLAG " -Wl,-exported_symbols_list,${SYMBOL_FILE}") + +if (APPLE OR ${CMAKE_SYSTEM_NAME} MATCHES "^iOS") + target_link_options(onnxruntime PRIVATE "LINKER:-exported_symbols_list,${SYMBOL_FILE}") if (${CMAKE_SYSTEM_NAME} STREQUAL "iOS") set_target_properties(onnxruntime PROPERTIES MACOSX_RPATH TRUE @@ -159,12 +163,10 @@ if (NOT WIN32) else() set_target_properties(onnxruntime PROPERTIES INSTALL_RPATH "@loader_path") endif() - elseif (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "AIX") - set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,-rpath='$ORIGIN'") - endif() endif() + if(CMAKE_SYSTEM_NAME STREQUAL "Android" AND onnxruntime_MINIMAL_BUILD) # target onnxruntime is a shared library, the dummy __cxa_demangle is only attach to it to avoid # affecting downstream ort library users with the behavior of dummy __cxa_demangle. So the dummy @@ -208,7 +210,6 @@ set(onnxruntime_INTERNAL_LIBRARIES ${PROVIDERS_NNAPI} ${PROVIDERS_QNN} ${PROVIDERS_SNPE} - ${PROVIDERS_TVM} ${PROVIDERS_RKNPU} ${PROVIDERS_VSINPU} ${PROVIDERS_XNNPACK} @@ -219,7 +220,6 @@ set(onnxruntime_INTERNAL_LIBRARIES ${onnxruntime_winml} onnxruntime_optimizer onnxruntime_providers - ${onnxruntime_tvm_libs} onnxruntime_lora onnxruntime_framework onnxruntime_graph @@ -248,7 +248,9 @@ target_link_libraries(onnxruntime PRIVATE ${onnxruntime_EXTERNAL_LIBRARIES} ) -set_property(TARGET onnxruntime APPEND_STRING PROPERTY LINK_FLAGS ${ONNXRUNTIME_SO_LINK_FLAG} ${onnxruntime_DELAYLOAD_FLAGS}) +if(WIN32) + target_link_options(onnxruntime PRIVATE ${onnxruntime_DELAYLOAD_FLAGS}) +endif() #See: https://cmake.org/cmake/help/latest/prop_tgt/SOVERSION.html if(NOT APPLE AND NOT WIN32) if(${CMAKE_SYSTEM_NAME} MATCHES "AIX") diff --git a/cmake/onnxruntime_codegen_tvm.cmake b/cmake/onnxruntime_codegen_tvm.cmake deleted file mode 100644 index 7b50d8f8603ae..0000000000000 --- a/cmake/onnxruntime_codegen_tvm.cmake +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -file(GLOB_RECURSE onnxruntime_codegen_common_srcs - "${ONNXRUNTIME_ROOT}/core/codegen/common/*.h" - "${ONNXRUNTIME_ROOT}/core/codegen/common/*.cc" -) - -file(GLOB_RECURSE onnxruntime_codegen_tvm_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/codegen/mti/*.h" - "${ONNXRUNTIME_ROOT}/core/codegen/mti/*.cc" - "${ONNXRUNTIME_ROOT}/core/codegen/passes/*.h" - "${ONNXRUNTIME_ROOT}/core/codegen/passes/*.cc" -) - -source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_codegen_common_srcs} ${onnxruntime_codegen_tvm_srcs}) - -#onnxruntime_codegen_tvm depends on onnxruntime framework -onnxruntime_add_static_library(onnxruntime_codegen_tvm ${onnxruntime_codegen_common_srcs} ${onnxruntime_codegen_tvm_srcs}) -set_target_properties(onnxruntime_codegen_tvm PROPERTIES FOLDER "ONNXRuntime") -target_include_directories(onnxruntime_codegen_tvm PRIVATE ${ONNXRUNTIME_ROOT} ${TVM_INCLUDES} ${MKLML_INCLUDE_DIR} ${eigen_INCLUDE_DIRS}) -onnxruntime_add_include_to_target(onnxruntime_codegen_tvm onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers safeint_interface Boost::mp11) -target_compile_options(onnxruntime_codegen_tvm PRIVATE ${DISABLED_WARNINGS_FOR_TVM}) -# need onnx to build to create headers that this project includes -add_dependencies(onnxruntime_codegen_tvm ${onnxruntime_EXTERNAL_DEPENDENCIES}) diff --git a/cmake/onnxruntime_csharp.cmake b/cmake/onnxruntime_csharp.cmake index 22c993d07f7f9..39533429e181c 100644 --- a/cmake/onnxruntime_csharp.cmake +++ b/cmake/onnxruntime_csharp.cmake @@ -30,10 +30,6 @@ if (onnxruntime_USE_NNAPI_BUILTIN) STRING(APPEND CSHARP_PREPROCESSOR_DEFINES "USE_NNAPI;") endif() -if (onnxruntime_USE_TVM) - STRING(APPEND CSHARP_PREPROCESSOR_DEFINES "USE_TVM,") -endif() - if (onnxruntime_USE_OPENVINO) STRING(APPEND CSHARP_PREPROCESSOR_DEFINES "USE_OPENVINO;") endif() diff --git a/cmake/onnxruntime_java.cmake b/cmake/onnxruntime_java.cmake index 765ebab111ac7..b15b9632e9e24 100644 --- a/cmake/onnxruntime_java.cmake +++ b/cmake/onnxruntime_java.cmake @@ -7,7 +7,7 @@ include(FindJava) find_package(Java REQUIRED) include(UseJava) -if (NOT CMAKE_SYSTEM_NAME STREQUAL "Android") +if (NOT ANDROID) find_package(JNI REQUIRED) endif() @@ -21,23 +21,28 @@ endif() set(GRADLE_EXECUTABLE "${JAVA_ROOT}/gradlew") +set(COMMON_GRADLE_ARGS --console=plain) +if(WIN32) + list(APPEND COMMON_GRADLE_ARGS -Dorg.gradle.daemon=false) +elseif (ANDROID) + # For Android build, we may run gradle multiple times in same build, + # sometimes gradle JVM will run out of memory if we keep the daemon running + # it is better to not keep a daemon running + list(APPEND COMMON_GRADLE_ARGS --no-daemon) +endif() + # Specify the Java source files file(GLOB_RECURSE onnxruntime4j_gradle_files "${JAVA_ROOT}/*.gradle") file(GLOB_RECURSE onnxruntime4j_src "${JAVA_ROOT}/src/main/java/ai/onnxruntime/*.java") set(JAVA_OUTPUT_JAR ${JAVA_ROOT}/build/libs/onnxruntime.jar) # this jar is solely used to signaling mechanism for dependency management in CMake # if any of the Java sources change, the jar (and generated headers) will be regenerated and the onnxruntime4j_jni target will be rebuilt -set(GRADLE_ARGS --console=plain clean jar -x test) -if(WIN32) - set(GRADLE_ARGS ${GRADLE_ARGS} -Dorg.gradle.daemon=false) -elseif (CMAKE_SYSTEM_NAME STREQUAL "Android") - # For Android build, we may run gradle multiple times in same build, - # sometimes gradle JVM will run out of memory if we keep the daemon running - # it is better to not keep a daemon running - set(GRADLE_ARGS ${GRADLE_ARGS} --no-daemon) -endif() +set(GRADLE_ARGS clean jar -x test) -add_custom_command(OUTPUT ${JAVA_OUTPUT_JAR} COMMAND ${GRADLE_EXECUTABLE} ${GRADLE_ARGS} WORKING_DIRECTORY ${JAVA_ROOT} DEPENDS ${onnxruntime4j_gradle_files} ${onnxruntime4j_src}) +add_custom_command(OUTPUT ${JAVA_OUTPUT_JAR} + COMMAND ${GRADLE_EXECUTABLE} ${COMMON_GRADLE_ARGS} ${GRADLE_ARGS} + WORKING_DIRECTORY ${JAVA_ROOT} + DEPENDS ${onnxruntime4j_gradle_files} ${onnxruntime4j_src}) add_custom_target(onnxruntime4j DEPENDS ${JAVA_OUTPUT_JAR}) set_source_files_properties(${JAVA_OUTPUT_JAR} PROPERTIES GENERATED TRUE) set_property(TARGET onnxruntime4j APPEND PROPERTY ADDITIONAL_CLEAN_FILES "${JAVA_OUTPUT_DIR}") @@ -62,7 +67,7 @@ target_link_libraries(onnxruntime4j_jni PUBLIC onnxruntime) set(JAVA_PACKAGE_OUTPUT_DIR ${JAVA_OUTPUT_DIR}/build) file(MAKE_DIRECTORY ${JAVA_PACKAGE_OUTPUT_DIR}) -if (CMAKE_SYSTEM_NAME STREQUAL "Android") +if (ANDROID) set(ANDROID_PACKAGE_OUTPUT_DIR ${JAVA_PACKAGE_OUTPUT_DIR}/android) file(MAKE_DIRECTORY ${ANDROID_PACKAGE_OUTPUT_DIR}) endif() @@ -88,7 +93,7 @@ if(APPLE) elseif(JNI_ARCH STREQUAL "arm64") set(JNI_ARCH aarch64) endif() -elseif (CMAKE_SYSTEM_NAME STREQUAL "Android") +elseif (ANDROID) set(JNI_ARCH ${ANDROID_ABI}) elseif (ARM64) set(JNI_ARCH aarch64) @@ -180,15 +185,7 @@ else() endif() # run the build process (this copies the results back into CMAKE_CURRENT_BINARY_DIR) -set(GRADLE_ARGS --console=plain cmakeBuild -DcmakeBuildDir=${CMAKE_CURRENT_BINARY_DIR}) -if(WIN32) - set(GRADLE_ARGS ${GRADLE_ARGS} -Dorg.gradle.daemon=false) -elseif (CMAKE_SYSTEM_NAME STREQUAL "Android") - # For Android build, we may run gradle multiple times in same build, - # sometimes gradle JVM will run out of memory if we keep the daemon running - # it is better to not keep a daemon running - set(GRADLE_ARGS ${GRADLE_ARGS} --no-daemon) -endif() +set(GRADLE_ARGS cmakeBuild -DcmakeBuildDir=${CMAKE_CURRENT_BINARY_DIR}) # Append relevant native build flags to gradle command set(GRADLE_ARGS ${GRADLE_ARGS} ${ORT_PROVIDER_FLAGS}) @@ -197,9 +194,11 @@ if (onnxruntime_ENABLE_TRAINING_APIS) endif() message(STATUS "GRADLE_ARGS: ${GRADLE_ARGS}") -add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${GRADLE_EXECUTABLE} ${GRADLE_ARGS} WORKING_DIRECTORY ${JAVA_ROOT}) +add_custom_command(TARGET onnxruntime4j_jni POST_BUILD + COMMAND ${GRADLE_EXECUTABLE} ${COMMON_GRADLE_ARGS} ${GRADLE_ARGS} + WORKING_DIRECTORY ${JAVA_ROOT}) -if (CMAKE_SYSTEM_NAME STREQUAL "Android") +if (ANDROID) set(ANDROID_PACKAGE_JNILIBS_DIR ${JAVA_OUTPUT_DIR}/android) set(ANDROID_PACKAGE_ABI_DIR ${ANDROID_PACKAGE_JNILIBS_DIR}/${ANDROID_ABI}) file(MAKE_DIRECTORY ${ANDROID_PACKAGE_JNILIBS_DIR}) @@ -214,6 +213,7 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Android") POST_BUILD COMMAND ${CMAKE_COMMAND} -E echo "Generating Android AAR package..." COMMAND ${GRADLE_EXECUTABLE} + ${COMMON_GRADLE_ARGS} build -b build-android.gradle -c settings-android.gradle -DjniLibsDir=${ANDROID_PACKAGE_JNILIBS_DIR} -DbuildDir=${ANDROID_PACKAGE_OUTPUT_DIR} @@ -237,6 +237,7 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Android") POST_BUILD COMMAND ${CMAKE_COMMAND} -E echo "Building and running Android test for Android AAR package..." COMMAND ${GRADLE_EXECUTABLE} + ${COMMON_GRADLE_ARGS} clean assembleDebug assembleDebugAndroidTest -DminSdkVer=${ANDROID_MIN_SDK} --stacktrace diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 9e5a51ca3bee8..f5d8bde0b0427 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -36,11 +36,13 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/qpostprocessor.cpp ${MLAS_SRC_DIR}/qlgavgpool.cpp ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp - ${MLAS_SRC_DIR}/sqnbitgemm.h - ${MLAS_SRC_DIR}/sqnbitgemm.cpp + ${MLAS_SRC_DIR}/qnbitgemm.h + ${MLAS_SRC_DIR}/qnbitgemm.cpp ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h ${MLAS_SRC_DIR}/flashattn.cpp ${MLAS_SRC_DIR}/cast.cpp + ${MLAS_SRC_DIR}/rotary_embedding.h + ${MLAS_SRC_DIR}/rotary_embedding.cpp ${MLAS_SRC_DIR}/qsoftmax.cpp ${MLAS_SRC_DIR}/qsoftmax_kernel_naive.cpp ) @@ -86,11 +88,15 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.h + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp - ${MLAS_SRC_DIR}/fp16_neon_common.cpp + ${MLAS_SRC_DIR}/cast_kernel_neon.cpp + ${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp ) set(mlas_platform_preprocess_srcs @@ -369,10 +375,12 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.h + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") @@ -390,7 +398,9 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/fp16_neon_common.cpp + ${MLAS_SRC_DIR}/cast_kernel_neon.cpp + ${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") @@ -400,7 +410,9 @@ else() set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") - set_source_files_properties(${MLAS_SRC_DIR}/fp16_neon_common.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) @@ -460,7 +472,6 @@ else() bool HasP10 = ((hwcap2 & PPC_FEATURE2_MMA) && (hwcap2 & PPC_FEATURE2_ARCH_3_1)); return 0; } - } #endif" HAS_P10_RUNTIME ) @@ -686,6 +697,13 @@ endif() if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH AND MLAS_SOURCE_IS_NOT_SET) file(GLOB_RECURSE mlas_platform_srcs "${MLAS_SRC_DIR}/scalar/*.cpp") + elseif (onnxruntime_FORCE_GENERIC_ALGORITHMS) + file(GLOB_RECURSE mlas_platform_srcs_generic + "${MLAS_SRC_DIR}/scalar/*.cpp") + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${mlas_platform_srcs_generic} + ) endif() target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) endif() diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 9666877cdc206..582491de9503d 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -101,9 +101,6 @@ endif() if(onnxruntime_USE_ROCM) set(PROVIDERS_ROCM onnxruntime_providers_rocm) endif() -if (onnxruntime_USE_TVM) - set(PROVIDERS_TVM onnxruntime_providers_tvm) -endif() if (onnxruntime_USE_XNNPACK) set(PROVIDERS_XNNPACK onnxruntime_providers_xnnpack) endif() @@ -194,10 +191,6 @@ if (onnxruntime_USE_ROCM) include(onnxruntime_providers_rocm.cmake) endif() -if (onnxruntime_USE_TVM) - include(onnxruntime_providers_tvm.cmake) -endif() - if (onnxruntime_USE_VSINPU) include(onnxruntime_providers_vsinpu.cmake) endif() diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 39ad530146b33..4f86717026118 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -224,8 +224,7 @@ include(cutlass) target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include) - target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} - PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) + target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found set_target_properties(${target} PROPERTIES LINKER_LANGUAGE CUDA) set_target_properties(${target} PROPERTIES FOLDER "ONNXRuntime") diff --git a/cmake/onnxruntime_providers_dml.cmake b/cmake/onnxruntime_providers_dml.cmake index 439be882dcc5e..3141aa85a1163 100644 --- a/cmake/onnxruntime_providers_dml.cmake +++ b/cmake/onnxruntime_providers_dml.cmake @@ -61,8 +61,9 @@ target_link_libraries(onnxruntime_providers_dml PRIVATE delayimp.lib) - if (NOT GDK_PLATFORM) - set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:DirectML.dll /DELAYLOAD:d3d12.dll /DELAYLOAD:dxgi.dll /DELAYLOAD:dxcore.dll /DELAYLOAD:api-ms-win-core-com-l1-1-0.dll /DELAYLOAD:shlwapi.dll /DELAYLOAD:oleaut32.dll /DELAYLOAD:ext-ms-win-dxcore-l1-*.dll /ignore:4199") + if (onnxruntime_ENABLE_DELAY_LOADING_WIN_DLLS AND NOT GDK_PLATFORM) + #NOTE: the flags are only applied to onnxruntime.dll and the PYD file in our python package. Our C/C++ unit tests do not use these flags. + list(APPEND onnxruntime_DELAYLOAD_FLAGS "/DELAYLOAD:DirectML.dll" "/DELAYLOAD:d3d12.dll" "/DELAYLOAD:dxgi.dll" "/DELAYLOAD:dxcore.dll" "/DELAYLOAD:api-ms-win-core-com-l1-1-0.dll" "/DELAYLOAD:shlwapi.dll" "/DELAYLOAD:oleaut32.dll" "/DELAYLOAD:ext-ms-win-dxcore-l1-*.dll" "/ignore:4199") endif() target_compile_definitions(onnxruntime_providers_dml diff --git a/cmake/onnxruntime_providers_openvino.cmake b/cmake/onnxruntime_providers_openvino.cmake index 5dcee285a5b13..f5fae8d169ccc 100644 --- a/cmake/onnxruntime_providers_openvino.cmake +++ b/cmake/onnxruntime_providers_openvino.cmake @@ -11,22 +11,22 @@ "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" ) - if (WIN32) - set(CMAKE_MAP_IMPORTED_CONFIG_RELWITHDEBINFO Release) - endif() - # Header paths find_package(OpenVINO REQUIRED COMPONENTS Runtime ONNX) - if(OpenVINO_VERSION VERSION_LESS 2024.0) - message(FATAL_ERROR "OpenVINO 2024.0 and newer are supported. Please, use latest OpenVINO release") + if(OpenVINO_VERSION VERSION_LESS 2024.4) + message(FATAL_ERROR "OpenVINO 2024.4 and newer are supported. Please, use latest OpenVINO release") endif() if(OpenVINO_VERSION VERSION_GREATER_EQUAL 2024.4) add_definitions(-DUSE_OVEP_NPU_MEMORY=1) endif() - if (WIN32) - unset(CMAKE_MAP_IMPORTED_CONFIG_RELWITHDEBINFO) + # If building RelWithDebInfo and OV package does not have that configuration map to Release + get_target_property(ov_rt_implib_rwdi openvino::runtime IMPORTED_IMPLIB_RELWITHDEBINFO) + if ((CMAKE_BUILD_TYPE STREQUAL RelWithDebInfo) AND NOT ov_rt_implib_rwdi) + set_target_properties(openvino::runtime PROPERTIES + MAP_IMPORTED_CONFIG_RELWITHDEBINFO Release + ) endif() list(APPEND OPENVINO_LIB_LIST openvino::frontend::onnx openvino::runtime ${PYTHON_LIBRARIES}) @@ -82,3 +82,8 @@ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) endif() + +set_target_properties(onnxruntime_providers_openvino PROPERTIES + MAP_IMPORTED_CONFIG_RELEASE RelWithDebInfo + MAP_IMPORTED_CONFIG_DEBUG RelWithDebInfo + ) \ No newline at end of file diff --git a/cmake/onnxruntime_providers_tvm.cmake b/cmake/onnxruntime_providers_tvm.cmake deleted file mode 100644 index 8fd50c70dd5d7..0000000000000 --- a/cmake/onnxruntime_providers_tvm.cmake +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - - add_definitions(-DUSE_TVM=1) - if (onnxruntime_TVM_USE_HASH) - add_definitions(-DUSE_TVM_HASH=1) - endif() - - if (onnxruntime_TVM_USE_HASH) - file (GLOB_RECURSE onnxruntime_providers_tvm_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/tvm/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/tvm/*.cc" - ) - else() - file (GLOB onnxruntime_providers_tvm_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/tvm/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/tvm/*.cc" - ) - endif() - - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_tvm_cc_srcs}) - onnxruntime_add_static_library(onnxruntime_providers_tvm ${onnxruntime_providers_tvm_cc_srcs}) - - if ( CMAKE_COMPILER_IS_GNUCC ) - target_compile_options(onnxruntime_providers_tvm PRIVATE -Wno-unused-parameter -Wno-missing-field-initializers) - endif() - - target_include_directories(onnxruntime_providers_tvm PRIVATE - ${TVM_INCLUDES} - ${PYTHON_INCLUDE_DIRS}) - onnxruntime_add_include_to_target(onnxruntime_providers_tvm onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface) - - add_dependencies(onnxruntime_providers_tvm ${onnxruntime_EXTERNAL_DEPENDENCIES}) - - if (onnxruntime_TVM_USE_HASH) - add_dependencies(onnxruntime_providers_tvm ippcp_s) - target_include_directories(onnxruntime_providers_tvm PRIVATE ${IPP_CRYPTO_INCLUDE_DIR}) - target_link_libraries(onnxruntime_providers_tvm PRIVATE ippcp_s) - endif() - - set_target_properties(onnxruntime_providers_tvm PROPERTIES FOLDER "ONNXRuntime") - set_target_properties(onnxruntime_providers_tvm PROPERTIES LINKER_LANGUAGE CXX) - - if (WIN32 AND MSVC) - # wd4100: identifier' : unreferenced formal parameter - # wd4127: conditional expression is constant - # wd4244: conversion from 'int' to 'char', possible loss of data - # TODO: 4244 should not be disabled - target_compile_options(onnxruntime_providers_tvm PRIVATE "/wd4100" "/wd4127" "/wd4244") - else() - target_compile_options(onnxruntime_providers_tvm PRIVATE "-Wno-error=type-limits") - endif() - target_compile_definitions(onnxruntime_providers_tvm PUBLIC DMLC_USE_LOGGING_LIBRARY=) - - install(FILES ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/tvm/tvm_provider_factory.h - DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/) - - if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers_tvm - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() \ No newline at end of file diff --git a/cmake/onnxruntime_providers_vitisai.cmake b/cmake/onnxruntime_providers_vitisai.cmake index 764cde9491da8..561a323533f48 100644 --- a/cmake/onnxruntime_providers_vitisai.cmake +++ b/cmake/onnxruntime_providers_vitisai.cmake @@ -12,6 +12,7 @@ file(GLOB onnxruntime_providers_vitisai_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/vitisai/*.cc" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/vitisai/include/vaip/*.h" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.cc" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.h" "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" diff --git a/cmake/onnxruntime_providers_webgpu.cmake b/cmake/onnxruntime_providers_webgpu.cmake index eb25c55ab23e0..fea5964f0dda9 100644 --- a/cmake/onnxruntime_providers_webgpu.cmake +++ b/cmake/onnxruntime_providers_webgpu.cmake @@ -22,6 +22,25 @@ onnxruntime_add_static_library(onnxruntime_providers_webgpu ${onnxruntime_providers_webgpu_cc_srcs}) onnxruntime_add_include_to_target(onnxruntime_providers_webgpu onnxruntime_common dawn::dawncpp_headers dawn::dawn_headers onnx onnx_proto flatbuffers::flatbuffers Boost::mp11 safeint_interface) - target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_native dawn::dawn_proc) + + if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) + target_link_libraries(onnxruntime_providers_webgpu dawn::webgpu_dawn) + + if (onnxruntime_ENABLE_DELAY_LOADING_WIN_DLLS) + list(APPEND onnxruntime_DELAYLOAD_FLAGS "/DELAYLOAD:webgpu_dawn.dll") + endif() + + # Copy webgpu_dawn.dll to the output directory + add_custom_command( + TARGET onnxruntime_providers_webgpu + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different "$" "$" + VERBATIM ) + else() + if (NOT onnxruntime_USE_EXTERNAL_DAWN) + target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_native) + endif() + target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_proc) + endif() set_target_properties(onnxruntime_providers_webgpu PROPERTIES FOLDER "ONNXRuntime") diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 0d038d210ea2b..5a87252b08573 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -110,17 +110,17 @@ if (onnxruntime_USE_NCCL) endif() if(APPLE) - set(ONNXRUNTIME_SO_LINK_FLAG "-Xlinker -exported_symbols_list -Xlinker ${ONNXRUNTIME_ROOT}/python/exported_symbols.lst") + target_link_options(onnxruntime_pybind11_state PRIVATE "LINKER:-exported_symbols_list,${ONNXRUNTIME_ROOT}/python/exported_symbols.lst") elseif(UNIX) if (onnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS) - set(ONNXRUNTIME_SO_LINK_FLAG "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/python/version_script_expose_onnx_protobuf.lds -Xlinker --gc-sections") + target_link_options(onnxruntime_pybind11_state PRIVATE "LINKER:--version-script=${ONNXRUNTIME_ROOT}/python/version_script_expose_onnx_protobuf.lds" "LINKER:--gc-sections") else() if (NOT CMAKE_SYSTEM_NAME MATCHES "AIX") - set(ONNXRUNTIME_SO_LINK_FLAG "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/python/version_script.lds -Xlinker --gc-sections") + target_link_options(onnxruntime_pybind11_state PRIVATE "LINKER:--version-script=${ONNXRUNTIME_ROOT}/python/version_script.lds" "LINKER:--gc-sections") endif() endif() else() - set(ONNXRUNTIME_SO_LINK_FLAG "-DEF:${ONNXRUNTIME_ROOT}/python/pybind.def") + target_link_options(onnxruntime_pybind11_state PRIVATE "-DEF:${ONNXRUNTIME_ROOT}/python/pybind.def") endif() if (onnxruntime_ENABLE_ATEN) @@ -169,8 +169,8 @@ endif() target_link_libraries(onnxruntime_pybind11_state PRIVATE onnxruntime_session ${onnxruntime_libs} - ${PROVIDERS_TVM} ${PROVIDERS_NNAPI} + ${PROVIDERS_VSINPU} ${PROVIDERS_XNNPACK} ${PROVIDERS_COREML} ${PROVIDERS_RKNPU} @@ -184,7 +184,6 @@ target_link_libraries(onnxruntime_pybind11_state PRIVATE onnxruntime_optimizer onnxruntime_providers onnxruntime_util - ${onnxruntime_tvm_libs} onnxruntime_lora onnxruntime_framework onnxruntime_util @@ -199,11 +198,11 @@ set(onnxruntime_pybind11_state_dependencies ${onnxruntime_EXTERNAL_DEPENDENCIES} ${pybind11_dep} ) -set_property(TARGET onnxruntime_pybind11_state APPEND_STRING PROPERTY LINK_FLAGS ${ONNXRUNTIME_SO_LINK_FLAG} ${onnxruntime_DELAYLOAD_FLAGS}) + add_dependencies(onnxruntime_pybind11_state ${onnxruntime_pybind11_state_dependencies}) if (MSVC) - set_target_properties(onnxruntime_pybind11_state PROPERTIES LINK_FLAGS "${ONNXRUNTIME_SO_LINK_FLAG}") + target_link_options(onnxruntime_pybind11_state PRIVATE ${onnxruntime_DELAYLOAD_FLAGS}) # if MSVC, pybind11 undefines _DEBUG in pybind11/detail/common.h, which causes the pragma in pyconfig.h # from the python installation to require the release version of the lib # e.g. from a python 3.10 install: @@ -220,14 +219,15 @@ if (MSVC) # Explicitly use the release version of the python library to make the project file consistent with this. target_link_libraries(onnxruntime_pybind11_state PRIVATE ${Python_LIBRARY_RELEASE}) elseif (APPLE) - set_target_properties(onnxruntime_pybind11_state PROPERTIES LINK_FLAGS "${ONNXRUNTIME_SO_LINK_FLAG} -Xlinker -undefined -Xlinker dynamic_lookup") + # The following flag no longer works + #target_link_options(onnxruntime_pybind11_state PRIVATE "LINKER:-undefined,dynamic_lookup") set_target_properties(onnxruntime_pybind11_state PROPERTIES INSTALL_RPATH "@loader_path" BUILD_WITH_INSTALL_RPATH TRUE INSTALL_RPATH_USE_LINK_PATH FALSE) else() if (NOT CMAKE_SYSTEM_NAME MATCHES "AIX") - set_property(TARGET onnxruntime_pybind11_state APPEND_STRING PROPERTY LINK_FLAGS " -Xlinker -rpath=\\$ORIGIN") + target_link_options(onnxruntime_pybind11_state PRIVATE "LINKER:-rpath=\$ORIGIN") endif() endif() @@ -238,8 +238,8 @@ if (onnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS) MATH(EXPR PROTOBUF_INDEX_NEXT "${PROTOBUF_INDEX} + 1") if (ONNX_INDEX GREATER_EQUAL 0 AND PROTOBUF_INDEX GREATER_EQUAL 0) # Expect protobuf to follow onnx due to dependence - list(INSERT onnxruntime_CUSTOM_EXTERNAL_LIBRARIES ${ONNX_INDEX} "-Wl,--no-as-needed") - list(INSERT onnxruntime_CUSTOM_EXTERNAL_LIBRARIES ${PROTOBUF_INDEX_NEXT} "-Wl,--as-needed") + list(INSERT onnxruntime_CUSTOM_EXTERNAL_LIBRARIES ${ONNX_INDEX} "LINKER:--no-as-needed") + list(INSERT onnxruntime_CUSTOM_EXTERNAL_LIBRARIES ${PROTOBUF_INDEX_NEXT} "LINKER:--as-needed") else() message(FATAL_ERROR "Required external libraries onnx and protobuf are not found in onnxruntime_EXTERNAL_LIBRARIES") endif() @@ -964,37 +964,6 @@ if (onnxruntime_USE_ROCM) ) endif() -if (onnxruntime_USE_TVM) - file(GLOB onnxruntime_python_providers_tvm_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/python/providers/tvm/*.py" - ) - add_custom_command( - TARGET onnxruntime_pybind11_state POST_BUILD - COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/providers - COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/providers/tvm - COMMAND ${CMAKE_COMMAND} -E copy - ${onnxruntime_python_providers_tvm_srcs} - $/onnxruntime/providers/tvm - COMMAND ${CMAKE_COMMAND} -E copy - $ - $/onnxruntime/capi/ - ) - - add_custom_command( - TARGET onnxruntime_pybind11_state POST_BUILD - WORKING_DIRECTORY ${tvm_SOURCE_DIR}/python - COMMAND ${Python_EXECUTABLE} setup.py bdist_wheel - ) - - add_custom_command( - TARGET onnxruntime_pybind11_state POST_BUILD - COMMAND ${Python_EXECUTABLE} - $/onnxruntime/providers/tvm/extend_python_file.py - --target_file $/onnxruntime/capi/_ld_preload.py - ) - -endif() - if (onnxruntime_USE_DML) if (NOT onnxruntime_USE_CUSTOM_DIRECTML) set(dml_shared_lib_path ${DML_PACKAGE_DIR}/bin/${onnxruntime_target_platform}-win/${DML_SHARED_LIB}) @@ -1050,4 +1019,13 @@ if (onnxruntime_USE_QNN) endif() endif() +if (onnxruntime_USE_VSINPU) + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + $ + $/onnxruntime/capi/ + ) +endif() + endif() diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 67e5a9c0aa08b..e822f0a3655fc 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -9,9 +9,6 @@ set(TEST_INC_DIR ${ONNXRUNTIME_ROOT}) if (onnxruntime_ENABLE_TRAINING) list(APPEND TEST_INC_DIR ${ORTTRAINING_ROOT}) endif() -if (onnxruntime_USE_TVM) - list(APPEND TEST_INC_DIR ${TVM_INCLUDES}) -endif() set(disabled_warnings) function(AddTest) @@ -67,7 +64,10 @@ function(AddTest) if(onnxruntime_USE_CUDA) #XXX: we should not need to do this. onnxruntime_test_all.exe should not have direct dependency on CUDA DLLs, # otherwise it will impact when CUDA DLLs can be unloaded. - target_link_libraries(${_UT_TARGET} PRIVATE CUDA::cudart cudnn_frontend) + target_link_libraries(${_UT_TARGET} PRIVATE CUDA::cudart) + if(NOT onnxruntime_CUDA_MINIMAL) + target_link_libraries(${_UT_TARGET} PRIVATE cudnn_frontend) + endif() endif() target_link_libraries(${_UT_TARGET} PRIVATE ${_UT_LIBS} GTest::gtest GTest::gmock ${onnxruntime_EXTERNAL_LIBRARIES}) endif() @@ -111,7 +111,6 @@ function(AddTest) endif() target_compile_options(${_UT_TARGET} PRIVATE ${disabled_warnings}) else() - target_compile_options(${_UT_TARGET} PRIVATE ${DISABLED_WARNINGS_FOR_TVM}) target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options -Wno-error=sign-compare>" "$<$>:-Wno-error=sign-compare>") if (${HAS_NOERROR}) @@ -523,6 +522,9 @@ set (onnxruntime_global_thread_pools_test_SRC ${ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR}/test_main.cc ${ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR}/test_inference.cc) +set (onnxruntime_webgpu_external_dawn_test_SRC + ${TEST_SRC_DIR}/webgpu/external_dawn/main.cc) + # tests from lowest level library up. # the order of libraries should be maintained, with higher libraries being added first in the list @@ -638,13 +640,11 @@ set(ONNXRUNTIME_TEST_LIBS ${PROVIDERS_ACL} ${PROVIDERS_ARMNN} ${PROVIDERS_COREML} - # ${PROVIDERS_TVM} ${PROVIDERS_XNNPACK} ${PROVIDERS_AZURE} onnxruntime_optimizer onnxruntime_providers onnxruntime_util - ${onnxruntime_tvm_libs} onnxruntime_lora onnxruntime_framework onnxruntime_util @@ -746,12 +746,6 @@ if(onnxruntime_USE_AZURE) list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_azure) endif() -if(WIN32) - if (onnxruntime_USE_TVM) - list(APPEND disabled_warnings ${DISABLED_WARNINGS_FOR_TVM}) - endif() -endif() - file(GLOB onnxruntime_test_framework_src CONFIGURE_DEPENDS ${onnxruntime_test_framework_src_patterns} ) @@ -852,9 +846,6 @@ if (onnxruntime_ENABLE_TRAINING_APIS) list(APPEND all_tests ${onnxruntime_test_training_api_src}) endif() -if (onnxruntime_USE_TVM) - list(APPEND all_tests ${onnxruntime_test_tvm_src}) -endif() if (onnxruntime_USE_OPENVINO) list(APPEND all_tests ${onnxruntime_test_openvino_src}) @@ -1086,15 +1077,6 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) COMMAND ${CMAKE_COMMAND} -E copy ${DNNL_DLL_PATH} $ ) endif() - if(WIN32) - if (onnxruntime_USE_TVM) - add_custom_command( - TARGET ${test_data_target} POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy $ $ - ) - endif() - endif() - if(WIN32) set(wide_get_opt_src_dir ${TEST_SRC_DIR}/win_getopt/wide) onnxruntime_add_static_library(win_getopt_wide ${wide_get_opt_src_dir}/getopt.cc ${wide_get_opt_src_dir}/include/getopt.h) @@ -1136,12 +1118,6 @@ if (NOT IOS) endif() set_target_properties(onnx_test_runner PROPERTIES FOLDER "ONNXRuntimeTest") - if (onnxruntime_USE_TVM) - if (WIN32) - target_link_options(onnx_test_runner PRIVATE "/STACK:4000000") - endif() - endif() - install(TARGETS onnx_test_runner ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} @@ -1295,11 +1271,6 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() set_target_properties(onnxruntime_perf_test PROPERTIES FOLDER "ONNXRuntimeTest") - if (onnxruntime_USE_TVM) - if (WIN32) - target_link_options(onnxruntime_perf_test PRIVATE "/STACK:4000000") - endif() - endif() endif() @@ -1884,4 +1855,13 @@ if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD endif() endif() +if (onnxruntime_USE_WEBGPU AND onnxruntime_USE_EXTERNAL_DAWN) + AddTest(TARGET onnxruntime_webgpu_external_dawn_test + SOURCES ${onnxruntime_webgpu_external_dawn_test_SRC} + LIBS dawn::dawn_native ${onnxruntime_test_providers_libs} + DEPENDS ${all_dependencies} + ) + onnxruntime_add_include_to_target(onnxruntime_webgpu_external_dawn_test dawn::dawncpp_headers dawn::dawn_headers) +endif() + include(onnxruntime_fuzz_test.cmake) diff --git a/cmake/patches/dawn/dawn.patch b/cmake/patches/dawn/dawn.patch deleted file mode 100644 index d696d386452e8..0000000000000 --- a/cmake/patches/dawn/dawn.patch +++ /dev/null @@ -1,66 +0,0 @@ -diff --git a/src/dawn/native/CMakeLists.txt b/src/dawn/native/CMakeLists.txt -index 9c0bd6fa4e..bf8a57aeac 100644 ---- a/src/dawn/native/CMakeLists.txt -+++ b/src/dawn/native/CMakeLists.txt -@@ -857,6 +857,11 @@ if (DAWN_ENABLE_SWIFTSHADER) - target_compile_definitions(dawn_native PRIVATE "DAWN_ENABLE_SWIFTSHADER") - endif() - -+if (IOS) -+ target_compile_options(dawn_native_objects PRIVATE -fno-objc-arc) -+ target_compile_options(dawn_native PRIVATE -fno-objc-arc) -+endif() -+ - if (DAWN_BUILD_MONOLITHIC_LIBRARY) - ############################################################################### - # Do the 'complete_lib' build. -diff --git a/src/dawn/native/Surface_metal.mm b/src/dawn/native/Surface_metal.mm -index ce55acbd43..baa4835362 100644 ---- a/src/dawn/native/Surface_metal.mm -+++ b/src/dawn/native/Surface_metal.mm -@@ -36,7 +36,13 @@ - namespace dawn::native { - - bool InheritsFromCAMetalLayer(void* obj) { -- id object = static_cast(obj); -+ id object = -+#if TARGET_OS_IOS -+ (__bridge id)obj; -+#else -+ static_cast(obj); -+#endif -+ - return [object isKindOfClass:[CAMetalLayer class]]; - } - -diff --git a/src/dawn/native/metal/SharedFenceMTL.mm b/src/dawn/native/metal/SharedFenceMTL.mm -index bde8bfea07..f2f6459e91 100644 ---- a/src/dawn/native/metal/SharedFenceMTL.mm -+++ b/src/dawn/native/metal/SharedFenceMTL.mm -@@ -40,7 +40,13 @@ ResultOrError> SharedFence::Create( - DAWN_INVALID_IF(descriptor->sharedEvent == nullptr, "MTLSharedEvent is missing."); - if (@available(macOS 10.14, iOS 12.0, *)) { - return AcquireRef(new SharedFence( -- device, label, static_cast>(descriptor->sharedEvent))); -+ device, label, -+#if TARGET_OS_IOS -+ (__bridge id)(descriptor->sharedEvent) -+#else -+ static_cast>(descriptor->sharedEvent) -+#endif -+ )); - } else { - return DAWN_INTERNAL_ERROR("MTLSharedEvent not supported."); - } -diff --git a/src/tint/api/BUILD.cmake b/src/tint/api/BUILD.cmake -index 0037d83276..6372c4ee77 100644 ---- a/src/tint/api/BUILD.cmake -+++ b/src/tint/api/BUILD.cmake -@@ -57,6 +57,7 @@ tint_target_add_dependencies(tint_api lib - tint_lang_wgsl_ast_transform - tint_lang_wgsl_common - tint_lang_wgsl_features -+ tint_lang_wgsl_inspector - tint_lang_wgsl_program - tint_lang_wgsl_sem - tint_lang_wgsl_writer_ir_to_program diff --git a/cmake/patches/eigen/eigen-edge.patch b/cmake/patches/eigen/eigen-edge.patch new file mode 100644 index 0000000000000..d8dc850b4bd55 --- /dev/null +++ b/cmake/patches/eigen/eigen-edge.patch @@ -0,0 +1,13 @@ +diff --git a/Eigen/src/Core/util/IndexedViewHelper.h b/Eigen/src/Core/util/IndexedViewHelper.h +index f85de305f..3dc2bb5e7 100644 +--- a/Eigen/src/Core/util/IndexedViewHelper.h ++++ b/Eigen/src/Core/util/IndexedViewHelper.h +@@ -178,7 +178,7 @@ namespace placeholders { + + EIGEN_DEPRECATED static const all_t all = Eigen::all; // PLEASE use Eigen::all instead of Eigen::placeholders::all + EIGEN_DEPRECATED static const last_t last = Eigen::last; // PLEASE use Eigen::last instead of Eigen::placeholders::last +- EIGEN_DEPRECATED static const end_t end = Eigen::lastp1; // PLEASE use Eigen::lastp1 instead of Eigen::placeholders::end ++ // EIGEN_DEPRECATED static const end_t end = Eigen::lastp1; // PLEASE use Eigen::lastp1 instead of Eigen::placeholders::end + } + + } // end namespace Eigen diff --git a/cmake/target_delayload.cmake b/cmake/target_delayload.cmake index 53f252a3e71ac..92273f5424233 100644 --- a/cmake/target_delayload.cmake +++ b/cmake/target_delayload.cmake @@ -6,9 +6,12 @@ function(target_delayload target_name) if(NOT MSVC) message(SEND_ERROR "Delayloading is only supported in MSVC") endif() - foreach(lib ${ARGN}) - target_link_options(${target_name} PRIVATE /DELAYLOAD:"${lib}") - endforeach() + if(onnxruntime_ENABLE_DELAY_LOADING_WIN_DLLS) + foreach(lib ${ARGN}) + target_link_options(${target_name} PRIVATE /DELAYLOAD:"${lib}") + endforeach() - target_link_libraries(${target_name} PRIVATE delayimp.lib) + target_link_libraries(${target_name} PRIVATE delayimp.lib) + endif() endfunction() + diff --git a/cmake/vcpkg.json b/cmake/vcpkg.json index 159b8654c1cb1..fcb2c7d5de89b 100644 --- a/cmake/vcpkg.json +++ b/cmake/vcpkg.json @@ -66,6 +66,12 @@ "platform": "windows" } ], + "overrides": [ + { + "name": "flatbuffers", + "version": "23.5.26" + } + ], "features": { "tests": { "description": "Build ONNXRuntime unit tests", diff --git a/csharp/ApiDocs/docfx.json b/csharp/ApiDocs/docfx.json index 0671d4aeb7d95..88a3283ad76e8 100644 --- a/csharp/ApiDocs/docfx.json +++ b/csharp/ApiDocs/docfx.json @@ -14,7 +14,7 @@ "disableDefaultFilter": false, "noRestore": true, "properties": { - "AllowUnsafeBlocks": true, + "AllowUnsafeBlocks": "true", "TargetFramework": "net8.0", "Nullable": "enable", "LangVersion": "8.0", diff --git a/csharp/OnnxRuntime.CSharp.proj b/csharp/OnnxRuntime.CSharp.proj index 95207d158affe..6779fd60bcd0a 100644 --- a/csharp/OnnxRuntime.CSharp.proj +++ b/csharp/OnnxRuntime.CSharp.proj @@ -64,13 +64,6 @@ CMake creates a target to this project - - - - - - - @@ -153,7 +146,7 @@ CMake creates a target to this project $(BaseTargets);$(MobileTargets) + + + true + true + true + + + true + true + true + true + + $(ProjectDir)..\..\.. + + + true + + + Microsoft.ML.OnnxRuntime Microsoft.ML.OnnxRuntime @@ -66,54 +93,31 @@ Commit: $(BUILD_SOURCEVERSION) Build: https://aiinfra.visualstudio.com/Lotus/_build/results?buildId=$(BUILD_BUILDID) + README.md + LICENSE.txt + + + true + + true + ..\..\OnnxRuntime.snk + + $(AllowedOutputExtensionsInPackageBuildOutputFolder);.pdb + AnyCPU;x86 default true - true - ..\..\OnnxRuntime.snk - - $(ProjectDir)..\..\.. - $(OnnxRuntimeRoot)\csharp x64 false false portable - - true - - - true - - - - - false - $(AllowedOutputExtensionsInPackageBuildOutputFolder);.pdb Debug;Release;RelWithDebInfo - - true - true - true - - - true - true - true - - - $(OnnxRuntimeCsharpRoot)\..\build\Linux - $(OnnxRuntimeBuildDirectory)\$(Configuration) - - - - $(OnnxRuntimeCsharpRoot)\..\build\Windows $(OnnxRuntimeBuildDirectory)\$(Configuration)\$(Configuration) - - - - $(OnnxRuntimeCsharpRoot)\..\build\MacOS + $(OnnxRuntimeBuildDirectory)\$(Configuration) - + $(OrtConstants);__MOBILE__ @@ -155,12 +148,12 @@ $(OrtConstants);__ANDROID__ - + $(OrtConstants);__IOS__ - - + + $(OrtConstants);__ENABLE_COREML__ @@ -178,128 +171,6 @@ $(DefineConstants);$(OrtConstants) - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + - - - + + + + + + + + + + + + + + + + + + - + diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index be157a0419fc0..d628b065ceaa7 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -1142,9 +1142,6 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca [DllImport(NativeLib.DllName, CharSet = CharSet.Ansi)] public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_MIGraphX(IntPtr /*(OrtSessionOptions*)*/ options, int device_id); - - [DllImport(NativeLib.DllName, CharSet = CharSet.Ansi)] - public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_Tvm(IntPtr /*(OrtSessionOptions*) */ options, byte[] /*(char char*)*/ settings); #endif /// /// Append a TensorRT EP instance (configured based on given provider options) to the native OrtSessionOptions instance @@ -1272,7 +1269,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// /// Append an execution provider instance to the native OrtSessionOptions instance. /// - /// 'SNPE' and 'XNNPACK' are currently supported as providerName values. + /// 'SNPE', 'XNNPACK' and 'CoreML' are currently supported as providerName values. /// /// The number of providerOptionsKeys must match the number of providerOptionsValues and equal numKeys. /// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs index 3acd84b3016de..bd450451a1265 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs @@ -146,27 +146,6 @@ public static SessionOptions MakeSessionOptionWithTensorrtProvider(OrtTensorRTPr } } - /// - /// A helper method to construct a SessionOptions object for TVM execution. - /// Use only if you have the onnxruntime package specific to this Execution Provider. - /// - /// settings string, comprises of comma separated key:value pairs. default is empty - /// A SessionsOptions() object configured for execution with TVM - public static SessionOptions MakeSessionOptionWithTvmProvider(String settings = "") - { - SessionOptions options = new SessionOptions(); - try - { - options.AppendExecutionProvider_Tvm(settings); - return options; - } - catch (Exception) - { - options.Dispose(); - throw; - } - } - /// /// A helper method to construct a SessionOptions object for ROCM execution. /// Use only if ROCM is installed and you have the onnxruntime package specific to this Execution Provider. @@ -397,20 +376,6 @@ public void AppendExecutionProvider_CoreML(CoreMLFlags coremlFlags = CoreMLFlags #endif } - /// - /// Use only if you have the onnxruntime package specific to this Execution Provider. - /// - /// string with TVM specific settings - public void AppendExecutionProvider_Tvm(string settings = "") - { -#if __MOBILE__ - throw new NotSupportedException("The TVM Execution Provider is not supported in this build"); -#else - var utf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(settings); - NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_Tvm(handle, utf8)); -#endif - } - private class ExecutionProviderAppender { private byte[] _utf8ProviderName; @@ -430,16 +395,10 @@ public IntPtr Appender(IntPtr handle, IntPtr[] optKeys, IntPtr[] optValues, UInt /// /// Append QNN, SNPE or XNNPACK execution provider /// - /// Execution provider to add. 'QNN', 'SNPE' or 'XNNPACK' are currently supported. + /// Execution provider to add. 'QNN', 'SNPE' 'XNNPACK', 'CoreML and 'AZURE are currently supported. /// Optional key/value pairs to specify execution provider options. public void AppendExecutionProvider(string providerName, Dictionary providerOptions = null) { - if (providerName != "SNPE" && providerName != "XNNPACK" && providerName != "QNN" && providerName != "AZURE") - { - throw new NotSupportedException( - "Only QNN, SNPE, XNNPACK and AZURE execution providers can be enabled by this method."); - } - if (providerOptions == null) { providerOptions = new Dictionary(); diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs index aa0e6ee62248a..17738da515134 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs @@ -146,10 +146,6 @@ public void TestSessionOptions() opt.AppendExecutionProvider_Nnapi(0); #endif -#if USE_TVM - opt.AppendExecutionProvider_Tvm("Vulkan -device=amd_apu"); -#endif - #if USE_OPENVINO opt.AppendExecutionProvider_OpenVINO(); #endif @@ -179,6 +175,12 @@ public void TestSessionOptions() ex = Assert.Throws(() => { opt.AppendExecutionProvider("QNN"); }); Assert.Contains("QNN execution provider is not supported in this build", ex.Message); #endif +#if USE_COREML + opt.AppendExecutionProvider("CoreML"); +#else + ex = Assert.Throws(() => { opt.AppendExecutionProvider("CoreML"); }); + Assert.Contains("CoreML execution provider is not supported in this build", ex.Message); +#endif opt.AppendExecutionProvider_CPU(1); } @@ -2041,7 +2043,7 @@ public SkipNonPackageTests() } // Test hangs on mobile. -#if !(ANDROID || IOS) +#if !(ANDROID || IOS) [Fact(DisplayName = "TestModelRunAsyncTask")] private async Task TestModelRunAsyncTask() { diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Microsoft.ML.OnnxRuntime.Tests.Common.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Microsoft.ML.OnnxRuntime.Tests.Common.csproj index 60d18ad31e811..07ca7fe7c64bf 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Microsoft.ML.OnnxRuntime.Tests.Common.csproj +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Microsoft.ML.OnnxRuntime.Tests.Common.csproj @@ -1,16 +1,19 @@  + + true + true + true + + $(ProjectDir)..\..\.. + netstandard2.0;net8.0 false - $(ProjectDir)..\.. AnyCPU bin\$(Configuration)\ - true - true - true - $(OnnxRuntimeCsharpRoot)\..\cmake\external\onnx + $(OnnxRuntimeRoot)\cmake\external\onnx 8981 @@ -22,30 +25,22 @@ ..\..\OnnxRuntime.snk Debug;Release;RelWithDebInfo + Microsoft.ML.OnnxRuntime.Tests Microsoft.ML.OnnxRuntime.Tests.Common - - - $(OnnxRuntimeCsharpRoot)\..\build\Linux - $(OnnxRuntimeBuildDirectory)\$(Configuration) - $(OnnxRuntimeBuildDirectory)\$(Configuration)\external\protobuf\cmake - $(ProtocDirectory)\protoc - - - - $(OnnxRuntimeCsharpRoot)\..\build\Windows - $(OnnxRuntimeBuildDirectory)\$(Configuration)\$(Configuration) $(OnnxRuntimeBuildDirectory)\$(Configuration)\external\protobuf\cmake\$(Configuration) $(ProtocDirectory)\protoc.exe + + $(OnnxRuntimeBuildDirectory)\$(Configuration)\external\protobuf\cmake + $(ProtocDirectory)\protoc + + - - $(OnnxRuntimeCsharpRoot)\..\build\MacOS - $(OnnxRuntimeBuildDirectory)\$(Configuration) $(OnnxRuntimeBuildDirectory)\$(Configuration)\external\protobuf\cmake $(ProtocDirectory)\protoc @@ -102,28 +97,6 @@ - - - - PreserveNewest - false - - - - PreserveNewest - false - - - - PreserveNewest - false - - - @@ -132,16 +105,20 @@ - + - + + - + + @@ -152,20 +129,20 @@ + - TestData\%(Filename)%(Extension) + TestData\%(Filename)%(Extension) - - TestData\overridable_initializer.onnx + + TestData\overridable_initializer.onnx - - TestData\capi_symbolic_dims.onnx + + TestData\capi_symbolic_dims.onnx - diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/NativeLibraryInclude.props b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/NativeLibraryInclude.props new file mode 100644 index 0000000000000..3daab21dbcbac --- /dev/null +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/NativeLibraryInclude.props @@ -0,0 +1,171 @@ + + + + + true + true + true + + + true + true + true + true + + + false + 1.20.0-dev-20241007 + + + + + + + + + + + + + + + + + + + + + + $(OnnxRuntimeRoot)\build\Windows + $(OnnxRuntimeBuildDirectory)\$(Configuration)\$(Configuration) + + + + $(OnnxRuntimeRoot)\build\Linux + $(OnnxRuntimeBuildDirectory)\$(Configuration) + + + + $(OnnxRuntimeRoot)\build\MacOS + $(OnnxRuntimeBuildDirectory)\$(Configuration) + + + + $(OnnxRuntimeRoot)\build\Android + $(OnnxRuntimeBuildDirectory)\$(Configuration) + + + + $(OnnxRuntimeRoot)\build\iOS + iPhoneSimulator + $(Platform.ToLower()) + $(OnnxRuntimeBuildDirectory)\$(Configuration)\$(Configuration)-$(PlatformLower) + + + + $(OnnxRuntimeRoot)\build\macOS + $(OnnxRuntimeBuildDirectory)\$(Configuration) + + + + + PreserveNewest + true + + + + + + PreserveNewest + false + + + + + + PreserveNewest + false + + + + + + libs\libonnxruntime.so + + + + + + libs\libonnxruntime.dylib + Dynamic + True + True + + + + + + libs\libonnxruntime.dylib + Dynamic + True + True + + + + + + + + + false + true + false + true + false + true + + + + + + + + + + + diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Tensors/TensorTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Tensors/TensorTests.cs index 27cde1dbe9ed8..46dd292e8514e 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Tensors/TensorTests.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Tensors/TensorTests.cs @@ -2180,10 +2180,13 @@ public void GetArrayString(TensorConstructor constructor) {22,23} } }"; + // remove \r so the newlines are just \n on all platforms + expected = expected.Replace("\r", ""); + var actual= tensor.GetArrayString().Replace("\r", ""); - Assert.Equal(expected, tensor.GetArrayString()); + Assert.Equal(expected, actual); - var expectedNoSpace = expected.Replace(Environment.NewLine, "").Replace(" ", ""); + var expectedNoSpace = expected.Replace("\n", "").Replace(" ", ""); Assert.Equal(expectedNoSpace, tensor.GetArrayString(false)); } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.MAUI/Microsoft.ML.OnnxRuntime.Tests.MAUI.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.MAUI/Microsoft.ML.OnnxRuntime.Tests.MAUI.csproj index 210a04d78f107..e07448daeea7f 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.MAUI/Microsoft.ML.OnnxRuntime.Tests.MAUI.csproj +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.MAUI/Microsoft.ML.OnnxRuntime.Tests.MAUI.csproj @@ -1,306 +1,125 @@  - - - true - true - true - true - $(ProjectDir)..\..\.. - - - - - net8.0-android;net8.0-ios;net8.0-maccatalyst - $(TargetFrameworks);net8.0-windows10.0.19041.0 - - - - - Exe - Microsoft.ML.OnnxRuntime.Tests.MAUI - true - true - enable - enable - true - - 8002 - - - $(DefineConstants);INCLUDE_FAILING_TESTS - $(DefineConstants);MODE_NON_INTERACTIVE_VISUAL - $(DefineConstants);MODE_XHARNESS - - - Microsoft.ML.OnnxRuntime.Tests.MAUI - - - ORT.CSharp.Tests.MAUI - - - 1.0 - 1 - - 15.0 - 13.1 - 30.0 - 10.0.17763.0 - 10.0.17763.0 - - true - ..\..\OnnxRuntime.snk - - - false - - - - - $(OnnxRuntimeRoot)\build\microsoft.ml.onnxruntime.1.18.1\runtimes - - true - - - - $(OnnxRuntimeRoot)\build\Windows - $(OnnxRuntimeBuildDirectory)\$(Configuration)\$(Configuration) - - $(PrebuiltRuntimesDir)\win-x64\native - - - $(OnnxRuntimeRoot)\build\Android - $(OnnxRuntimeBuildDirectory)\$(Configuration) - $(PrebuiltRuntimesDir)\android\native\onnxruntime.aar - - - $(OnnxRuntimeRoot)\build\iOS - iPhoneSimulator - $(Platform.ToLower()) - $(OnnxRuntimeBuildDirectory)\$(Configuration)\$(Configuration)-$(PlatformLower) - $(PrebuiltRuntimesDir)\ios\native\onnxruntime.xcframework - - - $(OnnxRuntimeRoot)\build\macOS - $(OnnxRuntimeBuildDirectory)\$(Configuration) - $(PrebuiltRuntimesDir)\ios\native\onnxruntime.xcframework - - - - - - PreserveNewest - true - - - - - PreserveNewest - true - - - - - PreserveNewest - false - - - PreserveNewest - false - - - PreserveNewest - false - - - PreserveNewest - false - - - PreserveNewest - false - - - PreserveNewest - false - - - - - - - libs\libonnxruntime.so - - - - - - - - - - libs\libonnxruntime.dylib - Dynamic - True - True - - - - - Framework - True - True - - - - - - - libs\libonnxruntime.dylib - Dynamic - True - True - - - - - Framework - True - True - - - - - - - false - true - false - true - false - true - - false - true - false - true - false - true - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - InferenceTest.cs - - - OrtIoBindingAllocationTest.cs - - - TensorTests.cs - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - <_VisualStudioTestRunnerFiles Include="@(PackagingOutputs)" Condition="$([System.String]::Copy('%(PackagingOutputs.FullPath)').Contains('xunit.runner.visualstudio'))" /> - - - + + $(ProjectDir)..\..\.. + + + + + + + net8.0-android;net8.0-ios;net8.0-maccatalyst + $(TargetFrameworks);net8.0-windows10.0.19041.0 + + + + + Exe + Microsoft.ML.OnnxRuntime.Tests.MAUI + true + true + enable + enable + true + + 8002 + + + $(DefineConstants);INCLUDE_FAILING_TESTS + $(DefineConstants);MODE_NON_INTERACTIVE_VISUAL + $(DefineConstants);MODE_XHARNESS + + + Microsoft.ML.OnnxRuntime.Tests.MAUI + + + ORT.CSharp.Tests.MAUI + + + 1.0 + 1 + + 15.0 + 13.1 + 30.0 + 10.0.17763.0 + 10.0.17763.0 + + true + ..\..\OnnxRuntime.snk + + + + + + + + + + + + + + + + + + + + + + + + InferenceTest.cs + + + OrtIoBindingAllocationTest.cs + + + TensorTests.cs + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + <_VisualStudioTestRunnerFiles + Include="@(PackagingOutputs)" + Condition="$([System.String]::Copy('%(PackagingOutputs.FullPath)').Contains('xunit.runner.visualstudio'))" /> + + + diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.MAUI/ReadMe.md b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.MAUI/ReadMe.md new file mode 100644 index 0000000000000..07cb5fe7c9b3d --- /dev/null +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.MAUI/ReadMe.md @@ -0,0 +1,9 @@ +The MAUI test project can be optionally used with a pre-built ONNX Runtime native nuget package (Microsoft.ML.OnnxRuntime). + +To do so, specify the `UsePrebuiltNativePackage` and `CurrentOnnxRuntimeVersion` properties when building the project. These can be set via the command-line or as environment variables. + +For example: + +```cmd +dotnet build csharp\test\Microsoft.ML.OnnxRuntime.Tests.MAUI\Microsoft.ML.OnnxRuntime.Tests.MAUI.csproj --property:UsePrebuiltNativePackage=true --property:CurrentOnnxRuntimeVersion=1.19.2 --source directory_containing_native_nuget_package --source https://api.nuget.org/v3/index.json +``` diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj index b822c999e4d39..a8abcd2b4aa1c 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj @@ -1,4 +1,9 @@  + + $(ProjectDir)..\..\.. + + + net8.0 @@ -6,9 +11,7 @@ $(ProjectDir)..\.. AnyCPU;x86 bin\$(Configuration)\ - true - true - true + $(OnnxSourceDirectory)\onnx default @@ -35,19 +38,19 @@ - $(OnnxRuntimeCsharpRoot)\..\build\Linux + $(OnnxRuntimeRoot)\build\Linux $(OnnxRuntimeBuildDirectory)\$(Configuration) - $(OnnxRuntimeCsharpRoot)\..\build\Windows + $(OnnxRuntimeRoot)\build\Windows $(OnnxRuntimeBuildDirectory)\$(Configuration)\$(Configuration) - $(OnnxRuntimeCsharpRoot)\..\build\MacOS + $(OnnxRuntimeRoot)\build\MacOS $(OnnxRuntimeBuildDirectory)\$(Configuration) @@ -58,15 +61,14 @@ PreserveNewest @@ -74,45 +76,39 @@ PreserveNewest false PreserveNewest false - - PreserveNewest - false - - + PreserveNewest false - - PreserveNewest - false - - + + PreserveNewest false - + + PreserveNewest false - + + PreserveNewest false + @@ -131,7 +127,7 @@ - + PreserveNewest false diff --git a/csharp/tools/MauiModelTester/Platforms/iOS/Info.plist b/csharp/tools/MauiModelTester/Platforms/iOS/Info.plist index 0004a4fdee5d5..fbb865624bbda 100644 --- a/csharp/tools/MauiModelTester/Platforms/iOS/Info.plist +++ b/csharp/tools/MauiModelTester/Platforms/iOS/Info.plist @@ -27,6 +27,6 @@ UIInterfaceOrientationLandscapeRight XSAppIconAssets - Assets.xcassets/appicon.appiconset + Assets.xcassets/onnxruntime_icon.appiconset diff --git a/dockerfiles/Dockerfile.cuda b/dockerfiles/Dockerfile.cuda index d2d656648f2e7..40f11dca623a7 100644 --- a/dockerfiles/Dockerfile.cuda +++ b/dockerfiles/Dockerfile.cuda @@ -48,7 +48,7 @@ RUN cd /code \ && python3 -m venv /code/env \ && . /code/env/bin/activate \ && pip install --upgrade psutil setuptools wheel packaging \ - && pip install -r tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/requirements.txt \ + && pip install -r /code/tools/ci_build/github/linux/python/requirements.txt \ && python /code/tools/ci_build/build.py --build_dir /code/build/Linux \ --allow_running_as_root --skip_submodule_sync \ --use_cuda --cuda_home /usr/local/cuda \ @@ -56,7 +56,6 @@ RUN cd /code \ --build_shared_lib --skip_tests \ --config Release --build_wheel --update --build --parallel \ --cmake_generator Ninja \ - --enable_cuda_nhwc_ops \ --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) "CMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}" onnxruntime_BUILD_UNIT_TESTS=OFF # Start second stage to copy the build artifacts diff --git a/dockerfiles/Dockerfile.migraphx b/dockerfiles/Dockerfile.migraphx index c5d998d503899..876a07e4ffaf6 100644 --- a/dockerfiles/Dockerfile.migraphx +++ b/dockerfiles/Dockerfile.migraphx @@ -10,7 +10,7 @@ FROM rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0 ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime ARG ONNXRUNTIME_BRANCH=main -ENV PATH /code/cmake-3.27.3-linux-x86_64/bin:${PATH} +ENV PATH=/code/cmake-3.27.3-linux-x86_64/bin:${PATH} RUN apt-get update &&\ apt-get install -y migraphx diff --git a/dockerfiles/Dockerfile.openvino b/dockerfiles/Dockerfile.openvino index 39e75a68a369f..d1ebdae3cbdd6 100644 --- a/dockerfiles/Dockerfile.openvino +++ b/dockerfiles/Dockerfile.openvino @@ -11,7 +11,7 @@ FROM openvino/ubuntu22_runtime:${OPENVINO_VERSION} AS builder ENV WORKDIR_PATH=/home/openvino WORKDIR $WORKDIR_PATH -ENV DEBIAN_FRONTEND noninteractive +ENV DEBIAN_FRONTEND=noninteractive ARG DEVICE=CPU ARG ONNXRUNTIME_REPO=https://github.com/microsoft/onnxruntime.git @@ -41,7 +41,7 @@ RUN tar cvf GPL_sources.tar.gz /sources # Deploy stage FROM openvino/ubuntu22_runtime:${OPENVINO_VERSION} -ENV DEBIAN_FRONTEND noninteractive +ENV DEBIAN_FRONTEND=noninteractive USER root COPY --from=builder /home/openvino/onnxruntime/build/Linux/Release/dist/*.whl ./ COPY --from=builder /GPL_sources.tar.gz ./ @@ -50,7 +50,7 @@ ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev RUN adduser --uid $BUILD_UID $BUILD_USER RUN usermod -a -G video,users ${BUILD_USER} -ENV WORKDIR_PATH /home/${BUILD_USER} +ENV WORKDIR_PATH=/home/${BUILD_USER} WORKDIR ${WORKDIR_PATH} USER ${BUILD_USER} diff --git a/dockerfiles/Dockerfile.rocm b/dockerfiles/Dockerfile.rocm index bef8d7a5f47d2..aca8c3feaff71 100644 --- a/dockerfiles/Dockerfile.rocm +++ b/dockerfiles/Dockerfile.rocm @@ -12,7 +12,7 @@ ARG ONNXRUNTIME_BRANCH=main WORKDIR /code -ENV PATH /code/cmake-3.27.3-linux-x86_64/bin:${PATH} +ENV PATH=/code/cmake-3.27.3-linux-x86_64/bin:${PATH} # Prepare onnxruntime repository & build onnxruntime RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ diff --git a/dockerfiles/Dockerfile.tensorrt b/dockerfiles/Dockerfile.tensorrt index ef51d41c5ff1b..24947df6308a6 100644 --- a/dockerfiles/Dockerfile.tensorrt +++ b/dockerfiles/Dockerfile.tensorrt @@ -17,7 +17,7 @@ RUN apt-get update &&\ RUN unattended-upgrade WORKDIR /code -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/code/cmake-3.27.3-linux-x86_64/bin:/opt/miniconda/bin:${PATH} +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/code/cmake-3.27.3-linux-x86_64/bin:/opt/miniconda/bin:${PATH} # Prepare onnxruntime repository & build onnxruntime with TensorRT RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ diff --git a/dockerfiles/Dockerfile.vitisai b/dockerfiles/Dockerfile.vitisai index e11ab70a61332..c6226155e01e3 100644 --- a/dockerfiles/Dockerfile.vitisai +++ b/dockerfiles/Dockerfile.vitisai @@ -22,8 +22,8 @@ RUN apt-get update && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* -ENV PATH /code/cmake-3.27.3-linux-x86_64/bin:$PATH -ENV LD_LIBRARY_PATH /opt/xilinx/xrt/lib:$LD_LIBRARY_PATH +ENV PATH=/code/cmake-3.27.3-linux-x86_64/bin:$PATH +ENV LD_LIBRARY_PATH=/opt/xilinx/xrt/lib:$LD_LIBRARY_PATH WORKDIR /code RUN . $VAI_ROOT/conda/etc/profile.d/conda.sh &&\ diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index b87532debe4bc..6ea3f93cdea12 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -1596,6 +1596,8 @@ This version of the operator has been available since version 1 of the 'com.micr
(Optional) Hardware architecture.
main_context : int
Usually each single EPContext associate with a graph partition.But for some case like QNN, it has single EPContext contains all partitions.In that case, the node with ep_cache_context should set main_context=1. Other nodes set main_context=0 and skip ep_cache_context.The path is relative to this Onnx file. Default is 1.
+
max_size : int
+
max size in the context. Usage depend on the EP.
notes : string
(Optional) Some notes for the model
onnx_model_filename : string
diff --git a/docs/How_To_Update_ONNX_Dev_Notes.md b/docs/How_To_Update_ONNX_Dev_Notes.md index 4d8a286bde66e..199e6671f6a1a 100644 --- a/docs/How_To_Update_ONNX_Dev_Notes.md +++ b/docs/How_To_Update_ONNX_Dev_Notes.md @@ -21,7 +21,7 @@ This file should be generated. See [cgmanifests/README](/cgmanifests/README.md) - [onnxruntime/test/python/requirements.txt](/onnxruntime/test/python/requirements.txt) - [tools/ci_build/github/linux/docker/scripts/requirements.txt](/tools/ci_build/github/linux/docker/scripts/requirements.txt) - [tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt](/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt) -- [tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/requirements.txt](/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/requirements.txt) +- [tools/ci_build/github/linux/python/requirements.txt](/tools/ci_build/github/linux/python/requirements.txt) - Run `git grep -rn "onnx==1" .` to find other locations and update this document if necessary. 1. If there is any change to `cmake/external/onnx/onnx/*.in.proto`, you need to regenerate OnnxMl.cs. diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index bd886abc98a89..eeb8ebb3ccefe 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -453,6 +453,7 @@ Do not modify directly.* |SVMClassifier|*in* X:**T1**
*out* Y:**T2**
*out* Z:**tensor(float)**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int64), tensor(string)| |SVMRegressor|*in* X:**T**
*out* Y:**tensor(float)**|1+|**T** = tensor(float)| |Scaler|*in* X:**T**
*out* Y:**tensor(float)**|1+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|TreeEnsemble|*in* X:**T**
*out* Y:**T**|5+|**T** = tensor(double), tensor(float)| |TreeEnsembleClassifier|*in* X:**T1**
*out* Y:**T2**
*out* Z:**tensor(float)**|3+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int64), tensor(string)| |||[1, 2]|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int64), tensor(string)| |TreeEnsembleRegressor|*in* X:**T**
*out* Y:**tensor(float)**|3+|**T** = tensor(double), tensor(float)| @@ -554,8 +555,12 @@ Do not modify directly.* |||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |Affine|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |And|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| -|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)| -|ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)| +|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||12|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)| +|ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||12|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)| |AveragePool|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| |||10|**T** = tensor(double), tensor(float), tensor(float16)| |||[7, 9]|**T** = tensor(double), tensor(float), tensor(float16)| @@ -921,6 +926,35 @@ Do not modify directly.* |WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*in* temperature:**T**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float), tensor(float16)| | | | | +|**Operator Domain:** *com.ms.internal.nhwc*|||| +|AveragePool|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(float), tensor(float16)| +|||10|**T** = tensor(float), tensor(float16)| +|||[7, 9]|**T** = tensor(float), tensor(float16)| +|BatchNormalization|*in* X:**T**
*in* scale:**T**
*in* B:**T**
*in* input_mean:**U**
*in* input_var:**U**
*out* Y:**T**
*out* running_mean:**U**
*out* running_var:**U**

or

*in* X:**T**
*in* scale:**T**
*in* B:**T**
*in* mean:**T**
*in* var:**T**
*out* Y:**T**
*out* mean:**T**
*out* var:**T**
*out* saved_mean:**T**
*out* saved_var:**T**

or

*in* X:**T**
*in* scale:**T1**
*in* B:**T1**
*in* input_mean:**T2**
*in* input_var:**T2**
*out* Y:**T**
*out* running_mean:**T2**
*out* running_var:**T2**|15+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(double), tensor(float), tensor(float16)| +|||14|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float), tensor(float16)| +|||[9, 13]|**T** = tensor(double), tensor(float), tensor(float16)| +|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| +|Conv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|11+|**T** = tensor(float), tensor(float16)| +|||[1, 10]|**T** = tensor(float), tensor(float16)| +|ConvTranspose|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|11+|**T** = tensor(float), tensor(float16)| +|||[1, 10]|**T** = tensor(float), tensor(float16)| +|DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| +|GlobalAveragePool|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| +|GlobalMaxPool|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| +|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float)
**T2** = tensor(float)| +|LRN|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|MaxPool|*in* X:**T**
*out* Y:**T**

or

*in* X:**T**
*out* Y:**T**
*out* Indices:**I**|12+|**I** = tensor(int64)
**T** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)| +|||11|**I** = tensor(int64)
**T** = tensor(float), tensor(float16)| +|||10|**I** = tensor(int64)
**T** = tensor(float), tensor(float16)| +|||[8, 9]|**I** = tensor(int64)
**T** = tensor(float), tensor(float16)| +|||[1, 7]|**T** = tensor(float), tensor(float16)| +|SpaceToDepth|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +| | +| | @@ -1053,11 +1087,13 @@ Do not modify directly.* |GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||12+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(float), tensor(float16)| +|GroupNorm||21+|**M** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)| |HardSigmoid|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float), tensor(float16)| |Hardmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float), tensor(float16)| |||11+|**T** = tensor(float), tensor(float16)| |||1+|**T** = tensor(float), tensor(float16)| -|Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|19+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|21+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||19+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||16+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||14+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -1156,7 +1192,8 @@ Do not modify directly.* |||12+|**T** = tensor(float), tensor(float16), tensor(int32)
**T1** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)| |||7+|**T** = tensor(float), tensor(float16)| |QLinearConv|*in* x:**T1**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T1**
*in* w:**T2**
*in* w_scale:**tensor(float)**
*in* w_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*in* B:**T4**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**T4** = tensor(int32)| -|QLinearMatMul|*in* a:**T1**
*in* a_scale:**TS**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**TS**
*in* b_zero_point:**T2**
*in* y_scale:**TS**
*in* y_zero_point:**T3**
*out* y:**T3**

or

*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| +|QLinearMatMul|*in* a:**T1**
*in* a_scale:**TS**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**TS**
*in* b_zero_point:**T2**
*in* y_scale:**TS**
*in* y_zero_point:**T3**
*out* y:**T3**

or

*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|21+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| +|||10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int4), tensor(int8), tensor(uint4), tensor(uint8)| |||19+|**T1** = tensor(float), tensor(float16), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |||13+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| diff --git a/docs/TVM_EP.md b/docs/TVM_EP.md deleted file mode 100644 index df59d5c05855c..0000000000000 --- a/docs/TVM_EP.md +++ /dev/null @@ -1,319 +0,0 @@ -# TVM Execution Provider - -## Contents - -- [Introduction](#introduction) -- [Build](#build-onnx-runtime-with-the-tvm-execution-provider) - - [Linux](#linux) - - [Windows](#windows) -- [Configuration options](#configuration-options) -- [Performance Tuning](#performance-tuning) - - [Using precompiled model](#using-precompiled-model) -- [Samples](#samples) -- [Known issues](#known-issues) - - -## Introduction - -TVM is an execution provider for ONNX Runtime that is built on top of Apache TVM. It enables ONNX Runtime users to leverage Apache TVM model optimizations. -TVM EP is currently in "Preview". It's been tested to work on a handful of models on Linux or Windows, but not on MacOS. - -## Build ONNX Runtime with the TVM Execution Provider - -### **Linux** -Install the minimal pre-requisites on Ubuntu/Debian like linux operating systems: -```bash -apt-get install -y python3 python3-dev python3-pip python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev llvm-12 -pip3 install numpy decorator attrs nasm -``` -Note: since ONNX Runtime with TVM EP is built with Intel ipp-crypto library there are new requirements. Compiler gcc (and g++) version should be equal to or higher than 8.2. nasm version should be 2.14.02 or higher. Problem with small nasm version can be seen [here](https://github.com/intel/ipp-crypto/issues/9) or [here](https://bugzilla.nasm.us/show_bug.cgi?id=3392205). For ubuntu LTS 18 `apt-get install nasm` is not enough due to it has version 2.13.02, see how to install from sources instruction [here](https://stackoverflow.com/questions/36144930/steps-to-install-nasm-offline-on-ubuntu). - -Also, the current implementation has `NVidia GPU` support for TVM EP. For now, you can use only `NVidia GPU` with CUDA Toolkit support. -To do this, make sure you have installed the NVidia driver and CUDA Toolkit. -More detailed instructions can be found on the [official page](https://developer.nvidia.com/cuda-toolkit). - -Clone this repo. -In order to build ONNXRT you will need to have CMake 3.18 or higher. In Ubuntu 20.04 you can use the following commands to install the latest version of CMake: - -```bash -sudo apt-get update -sudo apt-get install gpg wget - -wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | gpg --dearmor - | sudo tee /usr/share/keyrings/kitware-archive-keyring.gpg >/dev/null - -echo 'deb [signed-by=/usr/share/keyrings/kitware-archive-keyring.gpg] https://apt.kitware.com/ubuntu/ focal main' | sudo tee /etc/apt/sources.list.d/kitware.list >/dev/null -sudo apt-get update - -sudo rm /usr/share/keyrings/kitware-archive-keyring.gpg -sudo apt-get install kitware-archive-keyring - -sudo apt-get install cmake -``` - -Build ONNX Runtime (TVM x86): -```bash -./build.sh --config Release --enable_pybind --build_wheel --parallel --skip_tests --skip_onnx_tests --use_tvm -``` - -Build ONNX Runtime (TVM with CUDA support): -```bash -./build.sh --config Release --enable_pybind --build_wheel --parallel --skip_tests --skip_onnx_tests --use_tvm --tvm_cuda_runtime -``` - -This command builds both `TVM` and `onnxruntime-tvm`. It creates two wheel, one for each project. -Build the python API for ONNX Runtime instead of using the standard package. Instructions for this are given below. - -Package for TVM: -```bash -cd -python3 -m pip uninstall tvm -y -whl_path=$(find ./build//Release/_deps/tvm-src/python/dist -name "*.whl") -python3 -m pip install $whl_path -``` - -Package for TVM EP: -```bash -cd -python3 -m pip uninstall onnxruntime onnxruntime-tvm -y -whl_path=$(find ./build//Release/dist -name "*.whl") -python3 -m pip install $whl_path -``` - -Alternatively, you can set `PYTHONPATH` to tell python where to find the ONNXRT library and the TVM library. -```bash -export PYTHONPATH=/build//Release:${PYTHONPATH} -export PYTHONPATH=/build//Release/_deps/tvm-src/python:${PYTHONPATH} -``` - -### **Windows** -Install the minimal prerequisites on Windows: Git, CMake, Visual Studio, Python, LLVM -- Git: Download Git for Windows from [here](https://git-scm.com/download/win) and install it. Please make sure that the git.exe path is included in the environment variable. By default, it should be added. To check git after the installation use `git --version` in command line (cmd). -- CMake: use [the link](https://cmake.org/download/) to download and install CMake. msi-file is recommended for it. To verify CMake installation use `cmake --version` in cmd. -- Visual Studio: Download from [here](https://visualstudio.microsoft.com/ru/downloads/) and install Visual Studio 20** Community & Visual Studio Build Tools respectively. It is recommended not to change the default installation path. Chose "Desktop development with C++" workload and make sure that both options of “MSVC [contemporary version] C++ build tools” and “Windows 10 SDK” are selected. -- Python: Download Python 3.* from [here](https://www.python.org/downloads/windows/) and install it. Please have a check on the option of “Add Python to PATH”, so the installer will include the Python directory into the environment variable directly. To check python after the installation use `python` from cmd. The expected output is similar to the following: -```cmd -Python 3.10.5 (tags/v3.10.5:f377153, Jun 6 2022, 16:14:13) [MSC v.1929 64 bit (AMD64)] on win32 -Type "help", "copyright", "credits" or "license" for more information. ->>> -``` -Use `quit()` to exit from python interface. -- LLVM: the compiler is not necessary for pure ONNX Runtime installation but it is needed for TVM EP by default. -```cmd -git clone --depth 1 --branch release/11.x https://github.com/llvm/llvm-project.git -cmake -S llvm -B build -DLLVM_ENABLE_PROJECTS="clang;libcxx;libcxxabi" -DLLVM_TARGETS_TO_BUILD=X86 -Thost=x64 -DCMAKE_BUILD_TYPE=Release -G "Visual Studio 17 2022" -cmake --build ./build --config Release -``` -- Dependencies of ipp-crypto:
-1. install asm compiler (nasm) on windows by line: -```cmd -winget install nasm -i -``` -          -Add it to PATH (instruction for Windows GUI can be seen [here](https://www.computerhope.com/issues/ch000549.htm#dospath)) or by cmd: -```cmd -set PATH="%PATH%;C:\Program Files\NASM" -``` -          -or -```cmd -setx PATH "%PATH%;C:\Program Files\NASM" -``` -          -Check by `nasm --version` in prompt command line.
-       -2. install openssl on windows by msi-file from [here](https://slproweb.com/products/Win32OpenSSL.html) -Add path to directory (e.g. "C:\Program Files\OpenSSL-Win64\bin") with executable file to PATH (see instructions above).
-          -Check by `openssl version` in prompt command line.
-       -3. Correct build of ipp-crytpo requires specific environment variables for supported MSVC compiler. Long way to adjust the environment is to follow to instructions [here](https://docs.microsoft.com/en-us/cpp/build/building-on-the-command-line?view=msvc-170&viewFallbackFrom=vs-2017). Quick way is to use VS Developer command prompt where the environment have been already adjusted or add some paths to standard Windows command prompt: -```cmd -set INCLUDE=C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Tools\MSVC\14.32.31326\include;C:\Program Files (x86)\Windows Kits\10\include\10.0.22621.0\ucrt -``` -          -Take into account that MSVC and Kit versions are specific for Visual Studio built on the machine, specified values here are used as example. -
-
- -For using NVIDIA GPU (optional) CUDA and cuDNN should be installed. -- CUDA: Install CUDA by the [link](https://developer.nvidia.com/cuda-11.0-download-archive). -- cuDNN: download cuDNN installer from [here](https://developer.nvidia.com/rdp/cudnn-archive). Choose v8.* for corresponding CUDA v11.*, unzip it, and move cuDNN files as following: -1. [unzipped dir]\bin\ → C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.0\bin -2. [unzipped dir]\include\ → C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.0\include -3. [unzipped dir]\lib\ → C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.0\lib - -To verify the CUDA installation use `nvcc --version` in cmd. -
-
- -#### **Build ONNX Runtime with TVM Execution Provider from source (Python):** -- Use command line and clone sources from github: -```cmd -git clone --recursive https://github.com/Microsoft/onnxruntime -cd onnxruntime -``` -- CPU build: -``` -build.bat --config Release --enable_pybind --build_wheel --skip_tests --parallel --use_tvm --skip_onnx_tests --cmake_generator "Visual Studio 17 2022" --llvm_config /build/Release/bin/llvm-config.exe -``` -- GPU build: -``` -build.bat --config Release --enable_pybind --build_wheel --skip_tests --parallel --use_tvm --skip_onnx_tests --cmake_generator "Visual Studio 17 2022" --llvm_config /build/Release/bin/llvm-config.exe --use_cuda --cudnn_home “C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.*” --cuda_home “C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.*” -``` -In both cases (CPU, GPU) there are the following options for cmake generator: "Visual Studio 17 2022" and "Ninja". Also handshake mechanism can be switched on by `--use_tvm_hash` flag. At the latter case ipp-crypto library is built with dependencies, see details above. -- Install python wheel package for ONNX Runtime:
-Default path to the package is `/build/Windows/Release/Release/dist`. Note that it is different in comparison with path to the package on Linux. Before installation check names of wheel packages and use corresponding one. It can be looked like the following: -```cmd -python -m pip install .\onnxruntime\build\Windows\Release\Release\dist\onnxruntime_tvm-1.6.0-cp38-cp38-win_amd64.whl -``` -- Install python wheel package for TVM due to its python API is used inside TVM EP:
-It can be looked like the following: -```cmd -python -m pip install .\onnxruntime\build\Windows\Release\_deps\tvm-src\python\dist\tvm-0.9.dev1728+g3425ed846-cp39-cp39-win_amd64.whl -``` -- Verify result by python script. Note: python should not be launched from directory containing 'onnxruntime' directory for correct result: -```python -import onnxruntime -print(onnxruntime.__version__) -print(onnxruntime.get_device()) -print(onnxruntime.get_available_providers()) -``` -- Uninstall procedure: -```cmd -pip uninstall onnxruntime-tvm -``` - -#### **Build ONNX Runtime with TVM Execution Provider from source (C#):** -- Use command line and clone sources from github: -```cmd -git clone --recursive https://github.com/Microsoft/onnxruntime -cd onnxruntime -``` -- CPU build: - -Make sure you download [nuget.exe](https://docs.microsoft.com/en-us/nuget/install-nuget-client-tools#nugetexe-cli) and add path to it into `PATH` env. -``` -build.bat --config Release --build_nuget --skip_tests --parallel --use_tvm --skip_onnx_tests --cmake_generator "Visual Studio 17 2022" --llvm_config llvm-config.exe -``` -- Install C# nuget package for TVM EP. Default path to the package is `\build\Windows\Release\Release`. - - -## Configuration options -TVM Executor Provider can be configured with the following provider options: -1. Python -```python -po = [dict(executor=tvm_executor_type, - so_folder=folder_with_pretuned_files, - check_hash=check_hash, - hash_file_path=hash_file_path, - target=client_target, - target_host=client_target_host, - opt_level=client_opt_level, - freeze_weights=freeze, - to_nhwc=layout_transform, - tuning_type=tvm_optimizer_type, - tuning_file_path=client_tuning_logfile, - input_names = input_names_str, - input_shapes = input_shapes_str)] -tvm_session = onnxruntime.InferenceSession(model_path, providers=["TvmExecutionProvider"], provider_options=po) -``` - -2. C# - -Currently, only precompiled models are supported in C# (see the related section below). - -```CSharp -SessionOptions session_options = new SessionOptions{}; -string tvm_ep_options = - $"executor: {tvm_executor_type}, " + - $"so_folder: {folder_with_pretuned_files}, " + - $"check_hash: {check_hash}, " + - $"hash_file_path: {hash_file_path}, " + - $"target: {client_target}, " + - $"target_host: {client_target_host}, " + - $"opt_level: {client_opt_level}, " + - $"freeze_weights: {freeze}, " + - $"to_nhwc: {layout_transform}, " + - $"tuning_type: {tvm_optimizer_type}, " + - $"tuning_file_path: {client_tuning_logfile}, " + - $"input_names: {input_names_str}, " + - $"input_shapes: {input_shapes_str}"; - -session_options.AppendExecutionProvider_Tvm(tvm_ep_options); -using var tvm_session = new InferenceSession(modelFilePath, session_options); -``` -
- -- `executor` is executor type used by TVM. There is choice between two types: GraphExecutor and VirtualMachine which are corresponded to "graph" and "vm" tags. VirtualMachine is used by default. -- `so_folder` is path to folder with set of files (.ro-, .so/.dll-files and weights) obtained after model tuning. It uses these files for executor compilation instead of onnx-model. But the latter is still needed for ONNX Runtime. -- `check_hash` means that it is necessary to perform a HASH check for the model obtained in the `so_folder` parameter. It is `False` by default. -- `hash_file_path` is path to file that contains the pre-computed HASH for the ONNX model which result of tuning locates in the path passed by `so_folder` parameter. - If an empty string was passed as this value, then the file will be searched in the folder that was passed in the `so_folder` parameter. -- `target` and `target_host` are strings like in TVM (e.g. "llvm --mcpu=avx2"). When using accelerators, target may be something like `cuda` while target_host may be `llvm -mtriple=x86_64-linux-gnu` -- `opt_level` is TVM optimization level. It is 3 by default -- `freeze_weights` means that all model weights are kept on compilation stage otherwise they are downloaded each inference. True is recommended value for the best performance. It is true by default. -- `to_nhwc` switches on special model transformations, particularly data layout, which Octomizer is used. It allows to work correctly with tuning logs obtained from Octomizer. It is false by default. -- `tuning_type` defines the type of TVM tuning logs being used, and can be set to either `AutoTVM` (1st gen auto tuning logs) or `Ansor` (2nd gen auto tuning logs). By default this option is set to `AutoTVM`. -- `tuning_file_path` is path to AutoTVM or Ansor tuning file which gives specifications for given model and target for the best performance. (See below for more details). - -TVM supports models with fixed graph only. If your model has unknown dimensions in input shapes (excluding batch size) you must provide the shape using the `input_names` and `input_shapes` provider options. Below is an example of what must be passed to `provider_options`: -```python -input_names = "input_1 input_2" -input_shapes = "[1 3 224 224] [1 2]" -``` - -## Performance Tuning -TVM optimizes machine learning models through an automated tuning process that produces model variants specific to targeted hardware architectures. This process also generates 'tuning logs' that the TVM EP relies on to maximize model performance. These logs can be acquired for your model by either using TVM as described here: - -AutoTVM: -https://tvm.apache.org/docs/how_to/tune_with_autotvm/index.html - -Ansor (Autoscheduling): -https://tvm.apache.org/docs/how_to/tune_with_autoscheduler/index.html - -or by using logs generated through the OctoML platform (https://onnx.octoml.ai) using instructions [here](https://help.octoml.ai/en/articles/5814452-using-octoml-platform-logs-with-onnx-rt-tvm-ep) - -Using the TVM EP with TVM tuning logs also requires users to turn off ONNX Runtime preprocessing. To do this, the following `SessionOptions()` can be used: -``` -so = onnxruntime.SessionOptions() -so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL - -tvm_session = onnxruntime.InferenceSession(model_path, sess_options=so, providers=["TvmExecutionProvider"], provider_options=po) -``` - -### **Using precompiled model** -It is also possible to use a precompiled model. - -The compiled model can be obtained using the [OctoML platform](https://onnx.octoml.ai) -or compiled directly (see **Support precompiled model** section in -[Sample notebook for ResNet50 inference with TVM EP](https://github.com/microsoft/onnxruntime/blob/main/docs/python/notebooks/onnxruntime-tvm-tutorial.ipynb) -for more information on model compilation). - -In order to use the precompiled model, only need to pass two options: -* **executor** - `vm` (`VirtualMachine`) must be used as a value -(this functionality is not supported for `GraphExecutor`); -* **so_folder** - as a value, you must pass the path to the directory where -the files of the precompiled model are located. -* **check_hash** - (optional) if you want to check hash, you must pass `True` as the value. -* **hash_file_path** - (optional) by default, the file containing the hash for the tuned model will be searched in the directory that is passed in the `so_folder` parameter. - If you want to specify different location, then you must pass the path to the file that contains the desired hash as a value. - -You can read more about these options in section [Configuration options](#configuration-options) above. - - -## Samples -- [Sample notebook for ResNet50 inference with TVM EP](https://github.com/microsoft/onnxruntime/blob/main/docs/python/notebooks/onnxruntime-tvm-tutorial.ipynb) - -## Known issues -- At this moment, the TVM EP has only been verified on UNIX/Linux and Windows systems. -- Some compatibility issues have been found between ONNX and Google protobuf. `AttributeError: module 'google.protobuf.internal.containers' has no attribute 'MutableMapping'`. This usually occurss during `import onnx` in any python scripts for protobuf version >= 3.19.0 and ONNX version <= 1.8.1. To resolve the issue Google protobuf and ONNX can be reinstalled separately or together using: -``` -pip3 uninstall onnx -y -pip3 install onnx==1.10.1 -pip3 uninstall protobuf -y -pip3 install protobuf==3.19.1 -``` - -The following pair of ONNX and protobuf versions have been found to be compatible: -- 3.17.3 and 1.8.0 -- 3.19.1 and 1.10.1 diff --git a/docs/python/api_summary.rst b/docs/python/api_summary.rst index 092b42010a5c6..fb2850c547463 100644 --- a/docs/python/api_summary.rst +++ b/docs/python/api_summary.rst @@ -244,9 +244,36 @@ You can also bind inputs and outputs directly to a PyTorch tensor. ) session.run_with_iobinding(binding) - + You can also see code examples of this API in in the `ONNX Runtime inferences examples `_. +Some onnx data type (like TensorProto.BFLOAT16, TensorProto.FLOAT8E4M3FN and TensorProto.FLOAT8E5M2) are not supported by Numpy. You can directly bind input or output with Torch tensor of corresponding data type +(like torch.bfloat16, torch.float8_e4m3fn and torch.float8_e5m2) in GPU memory. + +.. code-block:: python + + x = torch.ones([3], dtype=torch.float8_e5m2, device='cuda:0') + y = torch.empty([3], dtype=torch.bfloat16, device='cuda:0') + + binding = session.io_binding() + binding.bind_input( + name='X', + device_type='cuda', + device_id=0, + element_type=TensorProto.FLOAT8E5M2, + shape=tuple(x.shape), + buffer_ptr=x.data_ptr(), + ) + binding.bind_output( + name='Y', + device_type='cuda', + device_id=0, + element_type=TensorProto.BFLOAT16, + shape=tuple(y.shape), + buffer_ptr=y.data_ptr(), + ) + session.run_with_iobinding(binding) + API Details =========== diff --git a/docs/python/notebooks/onnxruntime-tvm-tutorial.ipynb b/docs/python/notebooks/onnxruntime-tvm-tutorial.ipynb deleted file mode 100644 index 830495bdfb98d..0000000000000 --- a/docs/python/notebooks/onnxruntime-tvm-tutorial.ipynb +++ /dev/null @@ -1,657 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "72476497", - "metadata": {}, - "source": [ - "# ONNX Runtime: Tutorial for TVM execution provider\n", - "\n", - "This notebook shows a simple example for model inference with TVM EP.\n", - "\n", - "\n", - "#### Tutorial Roadmap:\n", - "1. Prerequistes\n", - "2. Accuracy check for TVM EP\n", - "3. Configuration options\n", - "4. Support precompiled model" - ] - }, - { - "cell_type": "markdown", - "id": "9345cbab", - "metadata": {}, - "source": [ - "## 1. Prerequistes\n", - "\n", - "Make sure that you have installed all the necessary dependencies described in the corresponding paragraph of the documentation.\n", - "\n", - "Also, make sure you have the `tvm` and `onnxruntime-tvm` packages in your pip environment. \n", - "\n", - "If you are using `PYTHONPATH` variable expansion, make sure it contains the following paths: `/onnxruntime/cmake/external/tvm_update/python` and `/onnxruntime/build/Linux/Release`." - ] - }, - { - "cell_type": "markdown", - "id": "da4ca21f", - "metadata": {}, - "source": [ - "### Common import\n", - "\n", - "These packages can be delivered from standard `pip`." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "0f072875", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import onnx\n", - "import tempfile\n", - "import numpy as np\n", - "from typing import List, AnyStr\n", - "from onnx import ModelProto, helper, checker, mapping" - ] - }, - { - "cell_type": "markdown", - "id": "118670aa", - "metadata": {}, - "source": [ - "### Specialized import\n", - "\n", - "It is better to collect these packages from source code in order to clearly understand what is available to you right now." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "a5502966", - "metadata": {}, - "outputs": [], - "source": [ - "import onnxruntime\n", - "\n", - "import tvm\n", - "import tvm.relay\n", - "import tvm.testing\n", - "import tvm.runtime\n", - "import tvm.runtime.vm\n", - "import tvm.relay.backend.vm\n", - "import tvm.contrib.download" - ] - }, - { - "cell_type": "markdown", - "id": "b7313183", - "metadata": {}, - "source": [ - "### Helper functions for working with ONNX ModelProto\n", - "\n", - "This set of helper functions allows you to recognize the meta information of the models. This information is needed for more versatile processing of ONNX models." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "7d0a36e8", - "metadata": {}, - "outputs": [], - "source": [ - "def get_onnx_input_names(model: ModelProto) -> List[AnyStr]:\n", - " inputs = [node.name for node in model.graph.input]\n", - " initializer = [node.name for node in model.graph.initializer]\n", - " inputs = list(set(inputs) - set(initializer))\n", - " return sorted(inputs)\n", - "\n", - "\n", - "def get_onnx_output_names(model: ModelProto) -> List[AnyStr]:\n", - " return [node.name for node in model.graph.output]\n", - "\n", - "\n", - "def get_onnx_input_types(model: ModelProto) -> List[np.dtype]:\n", - " input_names = get_onnx_input_names(model)\n", - " return [\n", - " mapping.TENSOR_TYPE_TO_NP_TYPE[node.type.tensor_type.elem_type]\n", - " for node in sorted(model.graph.input, key=lambda node: node.name) if node.name in input_names\n", - " ]\n", - "\n", - "\n", - "def get_onnx_input_shapes(model: ModelProto) -> List[List[int]]:\n", - " input_names = get_onnx_input_names(model)\n", - " return [\n", - " [dv.dim_value for dv in node.type.tensor_type.shape.dim]\n", - " for node in sorted(model.graph.input, key=lambda node: node.name) if node.name in input_names\n", - " ]\n", - "\n", - "\n", - "def get_random_model_inputs(model: ModelProto) -> List[np.ndarray]:\n", - " input_shapes = get_onnx_input_shapes(model)\n", - " input_types = get_onnx_input_types(model)\n", - " assert len(input_types) == len(input_shapes)\n", - " inputs = [np.random.uniform(size=shape).astype(dtype) for shape, dtype in zip(input_shapes, input_types)]\n", - " return inputs" - ] - }, - { - "cell_type": "markdown", - "id": "f0de1682", - "metadata": {}, - "source": [ - "### Wrapper helper functions for Inference\n", - "\n", - "Wrapper helper functions for running model inference using ONNX Runtime EP." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "258ce9e9", - "metadata": {}, - "outputs": [], - "source": [ - "def get_onnxruntime_output(model: ModelProto, inputs: List, provider_name: AnyStr) -> np.ndarray:\n", - " output_names = get_onnx_output_names(model)\n", - " input_names = get_onnx_input_names(model)\n", - " assert len(input_names) == len(inputs)\n", - " input_dict = {input_name: input_value for input_name, input_value in zip(input_names, inputs)}\n", - "\n", - " inference_session = onnxruntime.InferenceSession(model.SerializeToString(), providers=[provider_name])\n", - " output = inference_session.run(output_names, input_dict)\n", - "\n", - " # Unpack output if there's only a single value.\n", - " if len(output) == 1:\n", - " output = output[0]\n", - " return output\n", - "\n", - "\n", - "def get_cpu_onnxruntime_output(model: ModelProto, inputs: List) -> np.ndarray:\n", - " return get_onnxruntime_output(model, inputs, \"CPUExecutionProvider\")\n", - "\n", - "\n", - "def get_tvm_onnxruntime_output(model: ModelProto, inputs: List) -> np.ndarray:\n", - " return get_onnxruntime_output(model, inputs, \"TvmExecutionProvider\")" - ] - }, - { - "cell_type": "markdown", - "id": "cc17d3b2", - "metadata": {}, - "source": [ - "### Helper function for checking accuracy\n", - "\n", - "This function uses the TVM API to compare two output tensors. The tensor obtained using the `CPUExecutionProvider` is used as a reference.\n", - "\n", - "If a mismatch is found between tensors, an appropriate exception will be thrown." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "4e598907", - "metadata": {}, - "outputs": [], - "source": [ - "def verify_outputs(\n", - " lhs: List[np.ndarray],\n", - " rhs: List[np.ndarray],\n", - " rtol: float = 5e-5,\n", - " atol: float = 5e-5\n", - ") -> None:\n", - " for lhs_tensor, rhs_tensor in zip(lhs, rhs):\n", - " tvm.testing.assert_allclose(lhs_tensor, rhs_tensor, rtol=rtol, atol=atol)\n", - " assert lhs_tensor.dtype == rhs_tensor.dtype\n", - " print(\"Same output, congratulations!\")" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "f33a372b", - "metadata": {}, - "outputs": [], - "source": [ - "def verify_with_ort_with_inputs(\n", - " model,\n", - " inputs,\n", - " out_shape=None,\n", - " opset=None,\n", - " freeze_params=False,\n", - " dtype=\"float32\",\n", - " rtol=1e-5,\n", - " atol=1e-5,\n", - " opt_level=1,\n", - "):\n", - " if opset is not None:\n", - " model.opset_import[0].version = opset\n", - "\n", - " ort_out = get_cpu_onnxruntime_output(model, inputs)\n", - " tvm_out = get_tvm_onnxruntime_output(model, inputs)\n", - " verify_outputs(ort_out, tvm_out, rtol, atol)" - ] - }, - { - "cell_type": "markdown", - "id": "8c62b01a", - "metadata": {}, - "source": [ - "### Helper functions for download models\n", - "\n", - "These functions use the TVM API to download models from the ONNX Model Zoo." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "324c00e7", - "metadata": {}, - "outputs": [], - "source": [ - "BASE_MODEL_URL = \"https://github.com/onnx/models/raw/master/\"\n", - "MODEL_URL_COLLECTION = {\n", - " \"ResNet50-v1\": \"vision/classification/resnet/model/resnet50-v1-7.onnx\",\n", - " \"ResNet50-v2\": \"vision/classification/resnet/model/resnet50-v2-7.onnx\",\n", - " \"SqueezeNet-v1.1\": \"vision/classification/squeezenet/model/squeezenet1.1-7.onnx\",\n", - " \"SqueezeNet-v1.0\": \"vision/classification/squeezenet/model/squeezenet1.0-7.onnx\",\n", - " \"Inception-v1\": \"vision/classification/inception_and_googlenet/inception_v1/model/inception-v1-7.onnx\",\n", - " \"Inception-v2\": \"vision/classification/inception_and_googlenet/inception_v2/model/inception-v2-7.onnx\",\n", - "}\n", - "\n", - "\n", - "def get_model_url(model_name):\n", - " return BASE_MODEL_URL + MODEL_URL_COLLECTION[model_name]\n", - "\n", - "\n", - "def get_name_from_url(url):\n", - " return url[url.rfind(\"/\") + 1 :].strip()\n", - "\n", - "\n", - "def find_of_download(model_name):\n", - " model_url = get_model_url(model_name)\n", - " model_file_name = get_name_from_url(model_url)\n", - " return tvm.contrib.download.download_testdata(model_url, model_file_name, module=\"models\")" - ] - }, - { - "cell_type": "markdown", - "id": "90fb7c5c", - "metadata": {}, - "source": [ - "## 2. Accuracy check for TVM EP \n", - "\n", - "This section will check the accuracy. The check will be to compare the output tensors for `CPUExecutionProvider` and `TvmExecutionProvider`. See the description of `verify_with_ort_with_inputs` function used above.\n", - "\n", - "\n", - "### Check for simple architectures" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "c739ed5c", - "metadata": {}, - "outputs": [], - "source": [ - "def get_two_input_model(op_name: AnyStr) -> ModelProto:\n", - " dtype = \"float32\"\n", - " in_shape = [1, 2, 3, 3]\n", - " in_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)]\n", - " out_shape = in_shape\n", - " out_type = in_type\n", - "\n", - " layer = helper.make_node(op_name, [\"in1\", \"in2\"], [\"out\"])\n", - " graph = helper.make_graph(\n", - " [layer],\n", - " \"two_input_test\",\n", - " inputs=[\n", - " helper.make_tensor_value_info(\"in1\", in_type, in_shape),\n", - " helper.make_tensor_value_info(\"in2\", in_type, in_shape),\n", - " ],\n", - " outputs=[\n", - " helper.make_tensor_value_info(\n", - " \"out\", out_type, out_shape\n", - " )\n", - " ],\n", - " )\n", - " model = helper.make_model(graph, producer_name=\"two_input_test\")\n", - " checker.check_model(model, full_check=True)\n", - " return model" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "7048ee6d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Same output, congratulations!\n", - "****************** Success! ******************\n" - ] - } - ], - "source": [ - "onnx_model = get_two_input_model(\"Add\")\n", - "inputs = get_random_model_inputs(onnx_model)\n", - "verify_with_ort_with_inputs(onnx_model, inputs)\n", - "print(\"****************** Success! ******************\")" - ] - }, - { - "cell_type": "markdown", - "id": "52c880f4", - "metadata": {}, - "source": [ - "### Check for DNN architectures " - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "f5d465dc", - "metadata": {}, - "outputs": [], - "source": [ - "def get_onnx_model(model_name):\n", - " model_path = find_of_download(model_name)\n", - " onnx_model = onnx.load(model_path)\n", - " return onnx_model" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "68daac7e", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Same output, congratulations!\n", - "****************** Success! ******************\n" - ] - } - ], - "source": [ - "model_name = \"ResNet50-v1\"\n", - "\n", - "onnx_model = get_onnx_model(model_name)\n", - "inputs = get_random_model_inputs(onnx_model)\n", - "verify_with_ort_with_inputs(onnx_model, inputs)\n", - "print(\"****************** Success! ******************\")" - ] - }, - { - "cell_type": "markdown", - "id": "e27f64a2", - "metadata": {}, - "source": [ - "## 3. Configuration options\n", - "\n", - "This section shows how you can configure TVM EP using custom options. For more details on the options used, see the corresponding section of the documentation." - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "a053f59f", - "metadata": {}, - "outputs": [], - "source": [ - "provider_name = \"TvmExecutionProvider\"\n", - "provider_options = dict(\n", - " target=\"llvm -mtriple=x86_64-linux-gnu\",\n", - " target_host=\"llvm -mtriple=x86_64-linux-gnu\",\n", - " opt_level=3,\n", - " freeze_weights=True,\n", - " tuning_file_path=\"\",\n", - " tuning_type=\"Ansor\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "3f6e6f01", - "metadata": {}, - "outputs": [], - "source": [ - "model_name = \"ResNet50-v1\"\n", - "onnx_model = get_onnx_model(model_name)\n", - "input_dict = {\n", - " input_name: input_value for input_name, input_value in zip(\n", - " get_onnx_input_names(onnx_model),\n", - " get_random_model_inputs(onnx_model),\n", - " )\n", - "}\n", - "output_names = get_onnx_output_names(onnx_model)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "85ab83f2", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "****************** Output shape: (1, 1000) ******************\n" - ] - } - ], - "source": [ - "tvm_session = onnxruntime.InferenceSession(\n", - " onnx_model.SerializeToString(),\n", - " providers=[provider_name],\n", - " provider_options=[provider_options],\n", - ")\n", - "output = tvm_session.run(output_names, input_dict)[0]\n", - "print(f\"****************** Output shape: {output.shape} ******************\")" - ] - }, - { - "cell_type": "markdown", - "id": "b704374b", - "metadata": {}, - "source": [ - "## 4. Support precompiled model\n", - "\n", - "Wrapper functions that allow you to compile the model and save it in the desired format." - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "8150942b", - "metadata": {}, - "outputs": [], - "source": [ - "def compile_virtual_machine(model: onnx.ModelProto, target_str: AnyStr) -> tvm.runtime.vm.Executable:\n", - " ir_mod, params = tvm.relay.frontend.from_onnx(\n", - " model,\n", - " opset=model.opset_import[0].version,\n", - " freeze_params=True,\n", - " )\n", - " target = tvm.target.Target(target=target_str, host=target_str)\n", - " return tvm.relay.backend.vm.compile(ir_mod, target)\n", - "\n", - "\n", - "def serialize_virtual_machine(vm_exec: tvm.runtime.vm.Executable) -> AnyStr:\n", - " temp_directory = tempfile.mkdtemp()\n", - " path_consts = os.path.join(temp_directory, \"consts\")\n", - " vm_exec.move_late_bound_consts(path_consts, byte_limit=256)\n", - " lib_path = os.path.join(temp_directory, f\"model.so\")\n", - " code_path = os.path.join(temp_directory, f\"model.ro\")\n", - " code, lib = vm_exec.save()\n", - " lib.export_library(lib_path)\n", - " with open(code_path, \"wb\") as fo:\n", - " fo.write(code)\n", - " return temp_directory" - ] - }, - { - "cell_type": "markdown", - "id": "9cbb987e", - "metadata": {}, - "source": [ - "Preparation of the ONNX model." - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "febb9d72", - "metadata": {}, - "outputs": [], - "source": [ - "model_name = \"ResNet50-v1\"\n", - "onnx_model = get_onnx_model(model_name)\n", - "input_dict = {\n", - " input_name: input_value for input_name, input_value in zip(\n", - " get_onnx_input_names(onnx_model),\n", - " get_random_model_inputs(onnx_model),\n", - " )\n", - "}\n", - "output_names = get_onnx_output_names(onnx_model)" - ] - }, - { - "cell_type": "markdown", - "id": "b05b251a", - "metadata": {}, - "source": [ - "Compiling the ONNX model using `VirtualMachine` (TVM)." - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "b4b999ee", - "metadata": {}, - "outputs": [], - "source": [ - "compiled_vm_exec = compile_virtual_machine(onnx_model, target_str=\"llvm\")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "e3408c15", - "metadata": {}, - "outputs": [], - "source": [ - "so_folder = serialize_virtual_machine(compiled_vm_exec)" - ] - }, - { - "cell_type": "markdown", - "id": "311405e8", - "metadata": {}, - "source": [ - "Preparing `ProviderOptions` and launching `TVM EP` inference.\n", - "\n", - "In order to use the precompiled model, you only need to pass two options:\n", - "* **executor** - `vm` (`VirtualMachine`) must be used as a value (this functionality is not supported for `GraphExecutor`);\n", - "* **so_folder** - as a value, you must pass the path to the directory where the files of the precompiled model are located." - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "8927293c", - "metadata": {}, - "outputs": [], - "source": [ - "provider_name = \"TvmExecutionProvider\"\n", - "provider_options = dict(\n", - " executor=\"vm\",\n", - " so_folder=so_folder,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "d7532863", - "metadata": {}, - "outputs": [], - "source": [ - "tvm_session = onnxruntime.InferenceSession(\n", - " onnx_model.SerializeToString(),\n", - " providers=[provider_name],\n", - " provider_options=[provider_options],\n", - ")\n", - "tvm_output = tvm_session.run(output_names, input_dict)" - ] - }, - { - "cell_type": "markdown", - "id": "1c0b983e", - "metadata": {}, - "source": [ - "Let's make sure that the output values match those that can be obtained through `CPUExecutionProvider`:" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "c3de2299", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Same output, congratulations!\n" - ] - } - ], - "source": [ - "verify_outputs(\n", - " tvm_output[0],\n", - " get_cpu_onnxruntime_output(\n", - " onnx_model,\n", - " input_dict.values()\n", - " ),\n", - ")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.10" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/include/onnxruntime/core/framework/kernel_registry.h b/include/onnxruntime/core/framework/kernel_registry.h index 7b3d04ee66d9e..aaf533135429c 100644 --- a/include/onnxruntime/core/framework/kernel_registry.h +++ b/include/onnxruntime/core/framework/kernel_registry.h @@ -8,6 +8,9 @@ #include "core/framework/op_kernel.h" namespace onnxruntime { +namespace logging { +class Logger; +} using KernelCreateMap = std::multimap; using KernelDefHashes = std::vector>; @@ -33,6 +36,7 @@ class KernelRegistry { // Kernel matching uses the types from the node and the kernel_type_str_resolver. Status TryFindKernel(const Node& node, ProviderType exec_provider, const IKernelTypeStrResolver& kernel_type_str_resolver, + const logging::Logger& logger, const KernelCreateInfo** out) const; // map of type constraint name to required type @@ -42,6 +46,7 @@ class KernelRegistry { // Kernel matching uses the explicit type constraint name to required type map in type_constraints. Status TryFindKernel(const Node& node, ProviderType exec_provider, const TypeConstraintMap& type_constraints, + const logging::Logger& logger, const KernelCreateInfo** out) const; /** @@ -61,13 +66,15 @@ class KernelRegistry { std::string_view domain, int version, const KernelRegistry::TypeConstraintMap& type_constraints, + const logging::Logger& logger, const KernelCreateInfo** out) const; static bool HasImplementationOf(const KernelRegistry& r, const Node& node, ProviderType exec_provider, - const IKernelTypeStrResolver& kernel_type_str_resolver) { + const IKernelTypeStrResolver& kernel_type_str_resolver, + const logging::Logger& logger) { const KernelCreateInfo* info; - Status st = r.TryFindKernel(node, exec_provider, kernel_type_str_resolver, &info); + Status st = r.TryFindKernel(node, exec_provider, kernel_type_str_resolver, logger, &info); return st.IsOK(); } @@ -83,6 +90,7 @@ class KernelRegistry { Status TryFindKernelImpl(const Node& node, ProviderType exec_provider, const IKernelTypeStrResolver* kernel_type_str_resolver, const TypeConstraintMap* type_constraints, + const logging::Logger& logger, const KernelCreateInfo** out) const; // Check whether the types of inputs/outputs of the given node match the extra diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h index a17da2a19bb99..07625c38d8474 100644 --- a/include/onnxruntime/core/framework/op_kernel.h +++ b/include/onnxruntime/core/framework/op_kernel.h @@ -79,7 +79,6 @@ class OpKernel { // the allocator tied to the session if the kernel owns the pre-packed buffer or an // allocator shared between sessions if the pre-packed buffer is to be shared across sessions // (i.e.) the kernel does not own the buffer. - // @param save_prepacked_initializers: Set it to true if intend to save prepacked initializers to external data file. // @param is_packed: Set it to true if the kernel packed the tensor or to false // The kernel is responsible for keeping the packed data and related metadata if is_packed is true, // and the original initialized constant tensor will be released and not accessible anymore in @@ -89,7 +88,6 @@ class OpKernel { virtual Status PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/, - bool, /*save_prepacked_initializers*/ /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/) { is_packed = false; return Status::OK(); @@ -131,26 +129,6 @@ class OpKernel { return Status::OK(); } - // Override this function to get pre-packed tensors from this kernel. - // Only useful for models run on PC with CPU so ORT could load prepacked weights directly from - // ONNX data file with mmap and no need to do prepacking on fly to save a lot of heap memory. - // @param input_idx : The index of input we prepacked before and intend to get packed tensor back. - // Please refer to matmul_nbits kernel for a complete example. - virtual std::optional GetPrePackTensor(int /*input_idx*/) { - return std::nullopt; - } - - // Override this function to set pre-packed tensors to this kernel and restore prepacked weight buffer. - // Only useful for models run on PC with CPU so ORT could load prepacked weights directly from - // ONNX data file with mmap and no need to do prepacking on fly to save a lot of heap memory. - // Please refer to matmul_nbits kernel for a complete example. - // @param input_idx : The input index of the tensor in this kernel. - // @param pre_packed_tensor: The prepacked tensor read from onnx data file and use the prepacked tensor - // to restore prepacked weight buffer. - virtual Status SetPrePackTensor(int /*input_idx*/, const Tensor& /*pre_packed_tensor*/) { - return Status::OK(); - } - const OrtDevice GetDevice(OrtMemType mem_type) const; const OpKernelInfo& Info() const { return *op_kernel_info_; diff --git a/include/onnxruntime/core/framework/ortdevice.h b/include/onnxruntime/core/framework/ortdevice.h index f15543f22f21d..6f658ab65be20 100644 --- a/include/onnxruntime/core/framework/ortdevice.h +++ b/include/onnxruntime/core/framework/ortdevice.h @@ -17,6 +17,7 @@ struct OrtDevice { static const DeviceType GPU = 1; // Nvidia or AMD static const DeviceType FPGA = 2; static const DeviceType NPU = 3; // Ascend + static const DeviceType DML = 4; struct MemType { // Pre-defined memory types. diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 69af3c93d7a07..eb9581e8018d1 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -1148,11 +1148,6 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi void FinalizeFuseSubGraph(const IndexedSubGraph& sub_graph, Node& fused_node); #endif - // Since one constant initializer could be used by different kernels - // and prepacked differently, use an unordered_map to store prepacked - // initializer in format of <[initializer_name], <[node_name], [prepacked_initializer]>> - typedef std::unordered_map> PrePackedTensorProtoToSave; - #if !defined(ORT_MINIMAL_BUILD) /** Gets the GraphProto representation of this Graph. */ const ONNX_NAMESPACE::GraphProto& ToGraphProto(); @@ -1187,26 +1182,18 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi @param initializer_size_threshold initializers larger or equal to this threshold (in bytes) are saved in the external file. Initializer smaller than this threshold are included in the onnx file. @param align_info offset alignment info. - @param save_prepacked_constant_initializers whether to save prepacked initializer into external data file. - If set false to this boolean, prepacked initializer will not be saved into onnxruntime data file, - we keep constant initializer as it is. - @param pre_packed_initializers struct used to store all the prepacked initializers. @returns GraphProto serialization of the graph. */ ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_path, const std::filesystem::path& model_file_path, size_t initializer_size_threshold, - const OffsetAlignmentInfo& align_info, - bool save_prepacked_constant_initializers, - PrePackedTensorProtoToSave& pre_packed_initializers) const; + const OffsetAlignmentInfo& align_info) const; ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_path, const std::filesystem::path& model_file_path, size_t initializer_size_threshold) const { OffsetAlignmentInfo default_options; - PrePackedTensorProtoToSave pre_packed_initializers; - return ToGraphProtoWithExternalInitializers(external_file_path, model_file_path, initializer_size_threshold, default_options, - false, pre_packed_initializers); + return ToGraphProtoWithExternalInitializers(external_file_path, model_file_path, initializer_size_threshold, default_options); } /** Gets the ISchemaRegistry instances being used with this Graph. */ @@ -1521,18 +1508,6 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi private: void InitializeStateFromModelFileGraphProto(); - // Private method used to setup external initializer properly during model save, - // this external initializer could be oroginal initializer or prepacked initializer. - static void SetUpExternalInitializer(const Graph::OffsetAlignmentInfo& align_info, - size_t tensor_bytes_size, - int64_t& external_offset, - std::ofstream& external_stream, - gsl::span raw_data, - ONNX_NAMESPACE::TensorProto& output_proto, - const std::filesystem::path& external_file_path, - const ONNX_NAMESPACE::TensorProto& initializer, - bool is_prepacked); - // Add node with specified . Node& AddNode(const ONNX_NAMESPACE::NodeProto& node_proto, const ArgNameToTypeMap& name_to_type); diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h index 6cff153c336f0..31b0f22340510 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h +++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h @@ -53,6 +53,7 @@ InlinedVector> GenerateTransformers( TransformerLevel level, const SessionOptions& session_options, const IExecutionProvider& execution_provider /*required by constant folding*/, + const logging::Logger& logger, const InlinedHashSet& rules_and_transformers_to_disable = {}, concurrency::ThreadPool* intra_op_thread_pool = nullptr, std::unordered_map>* p_buffered_tensors = nullptr); @@ -84,6 +85,7 @@ InlinedVector> GenerateTransformersForMinimalB const SessionOptions& session_options, const SatApplyContextVariant& apply_context, const IExecutionProvider& cpu_execution_provider, + const logging::Logger& logger, const InlinedHashSet& rules_and_transformers_to_disable = {}, concurrency::ThreadPool* intra_op_thread_pool = nullptr, std::unordered_map>* p_buffered_tensors = nullptr); diff --git a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h index 7a6ba3afddce7..d035fd34bd072 100644 --- a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h +++ b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h @@ -31,16 +31,37 @@ enum COREMLFlags { // Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or later. COREML_FLAG_CREATE_MLPROGRAM = 0x010, - // Exclude ANE as sometimes this decrease performance // https://developer.apple.com/documentation/coreml/mlcomputeunits?language=objc // there are four compute units: // MLComputeUnitsCPUAndNeuralEngine|MLComputeUnitsCPUAndGPU|MLComputeUnitsCPUOnly|MLComputeUnitsAll + // different CU will have different performance and power consumption COREML_FLAG_USE_CPU_AND_GPU = 0x020, // Keep COREML_FLAG_LAST at the end of the enum definition // And assign the last COREMLFlag to it COREML_FLAG_LAST = COREML_FLAG_USE_CPU_AND_GPU, }; +// MLComputeUnits can be one of the following values: +// 'MLComputeUnitsCPUAndNeuralEngine|MLComputeUnitsCPUAndGPU|MLComputeUnitsCPUOnly|MLComputeUnitsAll' +// these values are intended to be used with Ort::SessionOptions::AppendExecutionProvider (C++ API) +// and SessionOptionsAppendExecutionProvider (C API). For the old API, use COREMLFlags instead. +static const char* const kCoremlProviderOption_MLComputeUnits = "MLComputeUnits"; +static const char* const kCoremlProviderOption_ModelFormat = "ModelFormat"; +// same as COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES +static const char* const kCoremlProviderOption_RequireStaticInputShapes = "RequireStaticInputShapes"; +static const char* const kCoremlProviderOption_EnableOnSubgraphs = "EnableOnSubgraphs"; +// provided by https://developer.apple.com/documentation/coreml/mloptimizationhints-swift.struct/specializationstrategy-swift.property +// Core ML segments the model’s compute graph and specializes each segment for the target compute device. +// This process can affect the model loading time and the prediction latency. +// Use this option to tailor the specialization strategy for your model. +static const char* const kCoremlProviderOption_SpecializationStrategy = "SpecializationStrategy"; +// Profile the Core ML MLComputePlan. +// This logs the hardware each operator is dispatched to and the estimated execution time. +// Intended for developer usage but provide useful diagnostic information if performance is not as expected. +static const char* const kCoremlProviderOption_ProfileComputePlan = "ProfileComputePlan"; +// please refer to https://developer.apple.com/documentation/coreml/mlmodelconfiguration/allowlowprecisionaccumulationongpu +static const char* const kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU = "AllowLowPrecisionAccumulationOnGPU"; + #ifdef __cplusplus extern "C" { #endif diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index b0c5d2329c428..a35d975ac8f1b 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -626,8 +626,13 @@ typedef struct OrtMIGraphXProviderOptions { } OrtMIGraphXProviderOptions; /** \brief OpenVINO Provider Options - * - * \see OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO + * \brief This Struct is frozen since ORT 1.13.0. Its maintained part of Legacy API for compatibility. + * \brief For latest OpenVINO Provider Options update to the ProviderOptions map. + * \brief Latest OpenVINO Provider Options are listed in the + * \htmlonly + *
onnxruntime document. + * \endhtmlonly + * \see OrtApi::SessionOptionsAppendExecutionProvider() */ typedef struct OrtOpenVINOProviderOptions { #ifdef __cplusplus @@ -645,7 +650,7 @@ typedef struct OrtOpenVINOProviderOptions { * Valid settings are one of: "CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU_FP16" */ const char* device_type; - unsigned char enable_npu_fast_compile; + unsigned char enable_npu_fast_compile; ///< 0 = disabled, nonzero = enabled const char* device_id; size_t num_of_threads; ///< 0 = Use default number of threads const char* cache_dir; // path is set to empty by default @@ -3662,6 +3667,9 @@ struct OrtApi { * execution provider (typically CPU EP). * - "0": Default. Disabled. QNN EP will handle quantization and dequantization of graph I/O. * - "1": Enabled. + * "enable_htp_spill_fill_buffer": Enable HTP spill fill buffer setting. The flag is used while generating context binary. + * - "0": Default. Disabled. + * - "1": Enabled. * * SNPE supported keys: * "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16", @@ -4607,6 +4615,8 @@ struct OrtApi { * \param[in] num_keys * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.17. */ ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_OpenVINO_V2, _In_ OrtSessionOptions* options, @@ -4624,6 +4634,8 @@ struct OrtApi { * \param[in] num_keys * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.18. */ ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessionOptions* options, @@ -4637,7 +4649,10 @@ struct OrtApi { * \param[in] mem_info OrtMemoryInfo instance * \param[in] count_or_bytes How many bytes is this scratch buffer * \param[out] out A pointer to the scrach buffer + * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.18. */ ORT_API2_STATUS(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out); @@ -4648,6 +4663,8 @@ struct OrtApi { * \param[out] out A pointer to OrtAllocator * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.18. */ ORT_API2_STATUS(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out); @@ -4669,6 +4686,8 @@ struct OrtApi { * \param[in] num_external_initializer_files Number of external files * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.18. */ ORT_API2_STATUS(AddExternalInitializersFromFilesInMemory, _In_ OrtSessionOptions* options, _In_reads_(num_external_initializer_files) const ORTCHAR_T* const* external_initializer_file_names, @@ -4691,6 +4710,8 @@ struct OrtApi { * OrtApi::ReleaseLoraAdapter. * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.20. */ ORT_API2_STATUS(CreateLoraAdapter, const ORTCHAR_T* adapter_file_path, _In_ OrtAllocator* allocator, _Outptr_ OrtLoraAdapter** out); @@ -4709,6 +4730,8 @@ struct OrtApi { * OrtApi::ReleaseLoraAdapter. * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.20. */ ORT_API2_STATUS(CreateLoraAdapterFromArray, _In_ const void* bytes, size_t num_bytes, _In_ OrtAllocator* allocator, _Outptr_ OrtLoraAdapter** out); @@ -4730,6 +4753,8 @@ struct OrtApi { * \param[in] adapter OrtLoraAdapter instance * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.20. */ ORT_API2_STATUS(RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptions* options, _In_ const OrtLoraAdapter* adapter); @@ -4748,6 +4773,8 @@ struct OrtApi { * \param[in] kv_len Number of elements in the keys and values arrays * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.20. */ ORT_API2_STATUS(SetEpDynamicOptions, _Inout_ OrtSession* sess, _In_reads_(kv_len) const char* const* keys, _In_reads_(kv_len) const char* const* values, _In_ size_t kv_len); diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 406ca3ea92559..f3e9758766d00 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -650,6 +650,9 @@ using AllocatedStringPtr = std::unique_ptr; * constructors to construct an instance of a Status object from exceptions. */ struct Status : detail::Base { + using Base = detail::Base; + using Base::Base; + explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used explicit Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API. explicit Status(const Exception&) noexcept; ///< Creates status instance out of exception @@ -728,6 +731,9 @@ struct Env : detail::Base { * */ struct CustomOpDomain : detail::Base { + using Base = detail::Base; + using Base::Base; + explicit CustomOpDomain(std::nullptr_t) {} ///< Create an empty CustomOpDomain object, must be assigned a valid one to be used /// \brief Wraps OrtApi::CreateCustomOpDomain @@ -963,8 +969,10 @@ struct SessionOptions : detail::SessionOptionsImpl { * */ struct ModelMetadata : detail::Base { - explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used - explicit ModelMetadata(OrtModelMetadata* p) : Base{p} {} ///< Used for interop with the C API + using Base = detail::Base; + using Base::Base; + + explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used /** \brief Returns a copy of the producer name. * @@ -1237,6 +1245,9 @@ using ConstTensorTypeAndShapeInfo = detail::TensorTypeAndShapeInfoImpl { + using Base = detail::TensorTypeAndShapeInfoImpl; + using Base::Base; + explicit TensorTypeAndShapeInfo(std::nullptr_t) {} ///< Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} ///< Used for interop with the C API ConstTensorTypeAndShapeInfo GetConst() const { return ConstTensorTypeAndShapeInfo{this->p_}; } @@ -1258,6 +1269,9 @@ using ConstSequenceTypeInfo = detail::SequenceTypeInfoImpl { + using Base = detail::SequenceTypeInfoImpl; + using Base::Base; + explicit SequenceTypeInfo(std::nullptr_t) {} ///< Create an empty SequenceTypeInfo object, must be assigned a valid one to be used explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : SequenceTypeInfoImpl{p} {} ///< Used for interop with the C API ConstSequenceTypeInfo GetConst() const { return ConstSequenceTypeInfo{this->p_}; } @@ -1293,6 +1307,9 @@ using ConstMapTypeInfo = detail::MapTypeInfoImpl { + using Base = detail::MapTypeInfoImpl; + using Base::Base; + explicit MapTypeInfo(std::nullptr_t) {} ///< Create an empty MapTypeInfo object, must be assigned a valid one to be used explicit MapTypeInfo(OrtMapTypeInfo* p) : MapTypeInfoImpl{p} {} ///< Used for interop with the C API ConstMapTypeInfo GetConst() const { return ConstMapTypeInfo{this->p_}; } @@ -1324,6 +1341,9 @@ using ConstTypeInfo = detail::TypeInfoImpl>; /// the information about contained sequence or map depending on the ONNXType. ///
struct TypeInfo : detail::TypeInfoImpl { + using Base = detail::TypeInfoImpl; + using Base::Base; + explicit TypeInfo(std::nullptr_t) {} ///< Create an empty TypeInfo object, must be assigned a valid one to be used explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl{p} {} ///< C API Interop @@ -1661,11 +1681,11 @@ using UnownedValue = detail::ValueImpl>; */ struct Value : detail::ValueImpl { using Base = detail::ValueImpl; + using Base::Base; using OrtSparseValuesParam = detail::OrtSparseValuesParam; using Shape = detail::Shape; - explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used - explicit Value(OrtValue* p) : Base{p} {} ///< Used for interop with the C API + explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used Value(Value&&) = default; Value& operator=(Value&&) = default; @@ -1941,6 +1961,10 @@ struct ArenaCfg : detail::Base { /// This struct provides life time management for custom op attribute ///
struct OpAttr : detail::Base { + using Base = detail::Base; + using Base::Base; + + explicit OpAttr(std::nullptr_t) {} OpAttr(const char* name, const void* data, int len, OrtOpAttrType type); }; @@ -2183,6 +2207,8 @@ using ConstKernelInfo = detail::KernelInfoImpl struct KernelInfo : detail::KernelInfoImpl { + using Base = detail::KernelInfoImpl; + using Base::Base; explicit KernelInfo(std::nullptr_t) {} ///< Create an empty instance to initialize later explicit KernelInfo(OrtKernelInfo* info); ///< Take ownership of the instance ConstKernelInfo GetConst() const { return ConstKernelInfo{this->p_}; } @@ -2192,6 +2218,9 @@ struct KernelInfo : detail::KernelInfoImpl { /// Create and own custom defined operation. /// struct Op : detail::Base { + using Base = detail::Base; + using Base::Base; + explicit Op(std::nullptr_t) {} ///< Create an empty Operator object, must be assigned a valid one to be used explicit Op(OrtOp*); ///< Take ownership of the OrtOp diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index c38da3e1c3e29..3aeb9412f350e 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -51,7 +51,7 @@ inline void ThrowOnError(const Status& st) { } } -inline Status::Status(OrtStatus* status) noexcept : Base{status} { +inline Status::Status(OrtStatus* status) noexcept : detail::Base{status} { } inline Status::Status(const std::exception& e) noexcept { @@ -1908,7 +1908,7 @@ inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std:: inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl{info} {} -inline Op::Op(OrtOp* p) : Base(p) {} +inline Op::Op(OrtOp* p) : detail::Base(p) {} inline Op Op::Create(const OrtKernelInfo* info, const char* op_name, const char* domain, int version, const char** type_constraint_names, diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 086919913cbea..8f1bc98ce7b49 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -246,12 +246,6 @@ static const char* const kOrtSessionOptionsDisableCPUEPFallback = "session.disab static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFileName = "session.optimized_model_external_initializers_file_name"; -// Use this config when save prepacked constant initializers to onnx external data file. -// Default is not save prepacked initializers to onnx data file. -// Sample usage: sess_options.add_session_config_entry('session.save_prepacked_constant_initializers', "1") -static const char* const kOrtSessionOptionsSavePrePackedConstantInitializers = - "session.save_prepacked_constant_initializers"; - // Use this config to control the minimum size of the initializer when externalizing it during serialization static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes = "session.optimized_model_external_initializers_min_size_in_bytes"; @@ -267,8 +261,8 @@ static const char* const kOrtSessionOptionEpContextEnable = "ep.context_enable"; static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_path"; // Flag to specify whether to dump the EP context into the Onnx model. -// "0": dump the EP context into separate file, keep the file name in the Onnx model. -// "1": dump the EP context into the Onnx model. (default). +// "0": dump the EP context into separate file, keep the file name in the Onnx model. (default). +// "1": dump the EP context into the Onnx model. static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode"; // Specify the EPContext node name prefix to make it unique diff --git a/java/build-android.gradle b/java/build-android.gradle index d5839f9f27869..9c4275b74f626 100644 --- a/java/build-android.gradle +++ b/java/build-android.gradle @@ -82,7 +82,7 @@ allprojects { } android { - compileSdkVersion 32 + compileSdkVersion 34 defaultConfig { minSdkVersion minSdkVer @@ -108,8 +108,8 @@ android { } compileOptions { - sourceCompatibility = JavaVersion.VERSION_1_8 - targetCompatibility = JavaVersion.VERSION_1_8 + sourceCompatibility = JavaVersion.VERSION_17 + targetCompatibility = JavaVersion.VERSION_17 } sourceSets { diff --git a/java/build.gradle b/java/build.gradle index 34ac93cce6f4e..845121dd17a48 100644 --- a/java/build.gradle +++ b/java/build.gradle @@ -50,8 +50,8 @@ mavenSettings { } java { - sourceCompatibility = JavaVersion.VERSION_1_8 - targetCompatibility = JavaVersion.VERSION_1_8 + sourceCompatibility = JavaVersion.VERSION_17 + targetCompatibility = JavaVersion.VERSION_17 } // This jar tasks serves as a CMAKE signaling diff --git a/java/gradle/wrapper/gradle-wrapper.properties b/java/gradle/wrapper/gradle-wrapper.properties index 4baf5a11d45a3..381baa9cef1ec 100644 --- a/java/gradle/wrapper/gradle-wrapper.properties +++ b/java/gradle/wrapper/gradle-wrapper.properties @@ -1,7 +1,7 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionSha256Sum=9631d53cf3e74bfa726893aee1f8994fee4e060c401335946dba2156f440f24c -distributionUrl=https\://services.gradle.org/distributions/gradle-8.6-bin.zip +distributionSha256Sum=544c35d6bd849ae8a5ed0bcea39ba677dc40f49df7d1835561582da2009b961d +distributionUrl=https\://services.gradle.org/distributions/gradle-8.7-bin.zip networkTimeout=10000 validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME diff --git a/java/gradlew.bat b/java/gradlew.bat index 93e3f59f135dd..25da30dbdeee9 100644 --- a/java/gradlew.bat +++ b/java/gradlew.bat @@ -43,11 +43,11 @@ set JAVA_EXE=java.exe %JAVA_EXE% -version >NUL 2>&1 if %ERRORLEVEL% equ 0 goto execute -echo. -echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. +echo. 1>&2 +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 goto fail @@ -57,11 +57,11 @@ set JAVA_EXE=%JAVA_HOME%/bin/java.exe if exist "%JAVA_EXE%" goto execute -echo. -echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. +echo. 1>&2 +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 goto fail diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index 7280f3c88e2e8..32dc9d9f84aaa 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -1323,6 +1323,18 @@ public void addQnn(Map providerOptions) throws OrtException { addExecutionProvider(qnnProviderName, providerOptions); } + /** + * Adds CoreML as an execution backend. + * + * @param providerOptions Configuration options for the CoreML backend. Refer to the CoreML + * execution provider's documentation. + * @throws OrtException If there was an error in native code. + */ + public void addCoreML(Map providerOptions) throws OrtException { + String CoreMLProviderName = "CoreML"; + addExecutionProvider(CoreMLProviderName, providerOptions); + } + private native void setExecutionMode(long apiHandle, long nativeHandle, int mode) throws OrtException; diff --git a/java/src/test/android/README.md b/java/src/test/android/README.md index b84021669c9fe..b086be3dc904c 100644 --- a/java/src/test/android/README.md +++ b/java/src/test/android/README.md @@ -29,6 +29,11 @@ Use the android's [build instructions](https://onnxruntime.ai/docs/build/android Please note that you may need to set the `--android_abi=x86_64` (the default option is `arm64-v8a`). This is because android instrumentation test is run on an android emulator which requires an abi of `x86_64`. +#### QNN Builds +We use two AndroidManifest.xml files to manage different runtime requirements for QNN support. In the [build configuration](app/build.gradle), we specify which manifest file to use based on the qnnVersion. +In the [QNN manifest](app/src/main/AndroidManifestQnn.xml), we include the declaration for libcdsprpc.so, which is required for devices using QNN and Qualcomm DSP capabilities. +For QNN builds, it is also necessary to set the `ADSP_LIBRARY_PATH` environment variable to the [native library directory](https://developer.android.com/reference/android/content/pm/ApplicationInfo#nativeLibraryDir) depending on the device. This ensures that any native libraries downloaded as dependencies such as QNN libraries are found by the application. This is conditionally added by using the BuildConfig field IS_QNN_BUILD set in the build.gradle file. + #### Build Output The build will generate two apks which is required to run the test application in `$YOUR_BUILD_DIR/java/androidtest/android/app/build/outputs/apk`: diff --git a/java/src/test/android/app/build.gradle b/java/src/test/android/app/build.gradle index 381de06cc09de..baf18e714d25c 100644 --- a/java/src/test/android/app/build.gradle +++ b/java/src/test/android/app/build.gradle @@ -4,18 +4,27 @@ plugins { } def minSdkVer = System.properties.get("minSdkVer")?:24 +def qnnVersion = System.properties['qnnVersion'] android { - compileSdkVersion 32 + compileSdkVersion 34 defaultConfig { applicationId "ai.onnxruntime.example.javavalidator" minSdkVersion minSdkVer - targetSdkVersion 32 + targetSdkVersion 34 versionCode 1 versionName "1.0" testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + + // Add BuildConfig field for qnnVersion + if (qnnVersion != null) { + buildConfigField "boolean", "IS_QNN_BUILD", "true" + } + else { + buildConfigField "boolean", "IS_QNN_BUILD", "false" + } } buildTypes { @@ -25,11 +34,29 @@ android { } } compileOptions { - sourceCompatibility JavaVersion.VERSION_1_8 - targetCompatibility JavaVersion.VERSION_1_8 + sourceCompatibility JavaVersion.VERSION_17 + targetCompatibility JavaVersion.VERSION_17 } kotlinOptions { - jvmTarget = '1.8' + jvmTarget = '17' + } + // Conditional packagingOptions for QNN builds only + if (qnnVersion != null) { + packagingOptions { + jniLibs { + useLegacyPackaging = true + } + // Dsp is used in older QC devices and not supported by ORT + // Gpu support isn't the target, we just want Npu support (Htp) + exclude 'lib/arm64-v8a/libQnnGpu.so' + exclude 'lib/arm64-v8a/libQnnDsp*.so' + } + + sourceSets { + main { + manifest.srcFile 'src/main/AndroidManifestQnn.xml' // Use QNN manifest + } + } } namespace 'ai.onnxruntime.example.javavalidator' } @@ -42,11 +69,20 @@ dependencies { implementation 'com.google.android.material:material:1.3.0' implementation 'androidx.constraintlayout:constraintlayout:2.0.4' testImplementation 'junit:junit:4.+' - androidTestImplementation 'androidx.test.ext:junit:1.1.3' - androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0' - implementation(name: "onnxruntime-android", ext: "aar") + androidTestImplementation "androidx.test.ext:junit:1.1.5" + androidTestImplementation "androidx.test.espresso:espresso-core:3.5.0" - androidTestImplementation 'androidx.test:runner:1.4.0' - androidTestImplementation 'androidx.test:rules:1.4.0' + androidTestImplementation "androidx.test:runner:1.5.2" + androidTestImplementation "androidx.test:rules:1.5.0" androidTestImplementation 'com.microsoft.appcenter:espresso-test-extension:1.4' + + // dependencies for onnxruntime-android-qnn + if (qnnVersion != null) { + implementation(name: "onnxruntime-android-qnn", ext: "aar") + implementation "com.qualcomm.qti:qnn-runtime:$qnnVersion" + } + else { + implementation(name: "onnxruntime-android", ext: "aar") + } + } diff --git a/java/src/test/android/app/src/androidTest/java/ai/onnxruntime/example/javavalidator/SimpleTest.kt b/java/src/test/android/app/src/androidTest/java/ai/onnxruntime/example/javavalidator/SimpleTest.kt index 166803ae263a5..5e6bee6cac9f4 100644 --- a/java/src/test/android/app/src/androidTest/java/ai/onnxruntime/example/javavalidator/SimpleTest.kt +++ b/java/src/test/android/app/src/androidTest/java/ai/onnxruntime/example/javavalidator/SimpleTest.kt @@ -38,13 +38,18 @@ class SimpleTest { @Test fun runSigmoidModelTest() { for (intraOpNumThreads in 1..4) { - runSigmoidModelTestImpl(intraOpNumThreads) + runSigmoidModelTestImpl(intraOpNumThreads, OrtProvider.CPU) } } @Test fun runSigmoidModelTestNNAPI() { - runSigmoidModelTestImpl(1, true) + runSigmoidModelTestImpl(1, OrtProvider.NNAPI) + } + + @Test + fun runSigmoidModelTestQNN() { + runSigmoidModelTestImpl(1, OrtProvider.QNN) } @Throws(IOException::class) @@ -54,22 +59,49 @@ class SimpleTest { } @Throws(OrtException::class, IOException::class) - fun runSigmoidModelTestImpl(intraOpNumThreads: Int, useNNAPI: Boolean = false) { - reportHelper.label("Start Running Test with intraOpNumThreads=$intraOpNumThreads, useNNAPI=$useNNAPI") + fun runSigmoidModelTestImpl(intraOpNumThreads: Int, executionProvider: OrtProvider) { + reportHelper.label("Start Running Test with intraOpNumThreads=$intraOpNumThreads, executionProvider=$executionProvider") Log.println(Log.INFO, TAG, "Testing with intraOpNumThreads=$intraOpNumThreads") - Log.println(Log.INFO, TAG, "Testing with useNNAPI=$useNNAPI") + Log.println(Log.INFO, TAG, "Testing with executionProvider=$executionProvider") + val env = OrtEnvironment.getEnvironment(OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE) env.use { val opts = SessionOptions() opts.setIntraOpNumThreads(intraOpNumThreads) - if (useNNAPI) { - if (OrtEnvironment.getAvailableProviders().contains(OrtProvider.NNAPI)) { - opts.addNnapi() - } else { - Log.println(Log.INFO, TAG, "NO NNAPI EP available, skip the test") - return + + when (executionProvider) { + + OrtProvider.NNAPI -> { + if (OrtEnvironment.getAvailableProviders().contains(OrtProvider.NNAPI)) { + opts.addNnapi() + } else { + Log.println(Log.INFO, TAG, "NO NNAPI EP available, skip the test") + return + } + } + + OrtProvider.QNN -> { + if (OrtEnvironment.getAvailableProviders().contains(OrtProvider.QNN)) { + // Since this is running in an Android environment, we use the .so library + val qnnLibrary = "libQnnHtp.so" + val providerOptions = Collections.singletonMap("backend_path", qnnLibrary) + opts.addQnn(providerOptions) + } else { + Log.println(Log.INFO, TAG, "NO QNN EP available, skip the test") + return + } + } + + OrtProvider.CPU -> { + // No additional configuration is needed for CPU + } + + else -> { + // Non exhaustive when statements on enum will be prohibited in future Gradle versions + Log.println(Log.INFO, TAG, "Skipping test as OrtProvider is not implemented") } } + opts.use { val session = env.createSession(readModel("sigmoid.ort"), opts) session.use { @@ -92,13 +124,15 @@ class SimpleTest { output.use { @Suppress("UNCHECKED_CAST") val rawOutput = output[0].value as Array> + // QNN EP will run the Sigmoid float32 op with fp16 precision + val precision = if (executionProvider == OrtProvider.QNN) 1e-3 else 1e-6 for (i in 0..2) { for (j in 0..3) { for (k in 0..4) { Assert.assertEquals( rawOutput[i][j][k], expected[i][j][k], - 1e-6.toFloat() + precision.toFloat() ) } } diff --git a/java/src/test/android/app/src/main/AndroidManifest.xml b/java/src/test/android/app/src/main/AndroidManifest.xml index 2938b7e8bf409..08a612ed79fd6 100644 --- a/java/src/test/android/app/src/main/AndroidManifest.xml +++ b/java/src/test/android/app/src/main/AndroidManifest.xml @@ -17,4 +17,4 @@ - \ No newline at end of file + diff --git a/java/src/test/android/app/src/main/AndroidManifestQnn.xml b/java/src/test/android/app/src/main/AndroidManifestQnn.xml new file mode 100644 index 0000000000000..c9416523a9c91 --- /dev/null +++ b/java/src/test/android/app/src/main/AndroidManifestQnn.xml @@ -0,0 +1,23 @@ + + + + + + + + + + + + + + + diff --git a/java/src/test/android/app/src/main/java/ai/onnxruntime/example/javavalidator/MainActivity.kt b/java/src/test/android/app/src/main/java/ai/onnxruntime/example/javavalidator/MainActivity.kt index 62e23c4b9b862..3b3a2d057b16e 100644 --- a/java/src/test/android/app/src/main/java/ai/onnxruntime/example/javavalidator/MainActivity.kt +++ b/java/src/test/android/app/src/main/java/ai/onnxruntime/example/javavalidator/MainActivity.kt @@ -1,11 +1,19 @@ package ai.onnxruntime.example.javavalidator import android.os.Bundle +import android.system.Os import androidx.appcompat.app.AppCompatActivity /*Empty activity app mainly used for testing*/ class MainActivity : AppCompatActivity() { override fun onCreate(savedInstanceState: Bundle?) { + if (BuildConfig.IS_QNN_BUILD) { + val adspLibraryPath = applicationContext.applicationInfo.nativeLibraryDir + // set the path variable to the native library directory + // so that any native libraries downloaded as dependencies + // (like qnn libs) are found + Os.setenv("ADSP_LIBRARY_PATH", adspLibraryPath, true) + } super.onCreate(savedInstanceState) } -} \ No newline at end of file +} diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index e11537492d3a7..15d89b536b39a 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -737,6 +737,7 @@ public void testCoreML() throws OrtException { runProvider(OrtProvider.CORE_ML); } + @Disabled("DirectML Java API hasn't been supported yet") @Test @EnabledIfSystemProperty(named = "USE_DML", matches = "1") public void testDirectML() throws OrtException { diff --git a/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java b/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java index 57c4eb3577fd0..fa0b6fd0ef9d9 100644 --- a/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java +++ b/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java @@ -27,6 +27,7 @@ import java.util.HashMap; import java.util.Map; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledIfSystemProperty; import org.junit.jupiter.api.condition.EnabledIfSystemProperty; public class ProviderOptionsTest { @@ -34,6 +35,7 @@ public class ProviderOptionsTest { @Test @EnabledIfSystemProperty(named = "USE_CUDA", matches = "1") + @DisabledIfSystemProperty(named = "NO_CUDA_TEST", matches = "1") public void testCUDAOptions() throws OrtException { // Test standard options OrtCUDAProviderOptions cudaOpts = new OrtCUDAProviderOptions(0); @@ -61,6 +63,7 @@ public void testCUDAOptions() throws OrtException { @Test @EnabledIfSystemProperty(named = "USE_TENSORRT", matches = "1") + @DisabledIfSystemProperty(named = "NO_CUDA_TEST", matches = "1") public void testTensorRT() throws OrtException { // Test standard options OrtTensorRTProviderOptions rtOpts = new OrtTensorRTProviderOptions(0); diff --git a/js/.eslintrc.js b/js/.eslintrc.js index bd1e9061355f5..462e417df1d66 100644 --- a/js/.eslintrc.js +++ b/js/.eslintrc.js @@ -198,19 +198,6 @@ module.exports = { '_OrtReleaseTensor', '_OrtRun', '_OrtRunWithBinding', - '_OrtTrainingCopyParametersFromBuffer', - '_OrtTrainingCopyParametersToBuffer', - '_OrtTrainingCreateSession', - '_OrtTrainingEvalStep', - '_OrtTrainingGetModelInputOutputCount', - '_OrtTrainingGetModelInputOutputName', - '_OrtTrainingGetParametersSize', - '_OrtTrainingLazyResetGrad', - '_OrtTrainingLoadCheckpoint', - '_OrtTrainingOptimizerStep', - '_OrtTrainingReleaseCheckpoint', - '_OrtTrainingReleaseSession', - '_OrtTrainingRunTrainStep', ], }, ], diff --git a/js/common/lib/backend.ts b/js/common/lib/backend.ts index e27e67622aa82..e63f9c6c9147f 100644 --- a/js/common/lib/backend.ts +++ b/js/common/lib/backend.ts @@ -3,7 +3,6 @@ import { InferenceSession } from './inference-session.js'; import { OnnxValue } from './onnx-value.js'; -import { TrainingSession } from './training-session.js'; /** * @ignore @@ -42,33 +41,6 @@ export interface InferenceSessionHandler extends SessionHandler { ): Promise; } -/** - * Represent a handler instance of a training inference session. - * - * @ignore - */ -export interface TrainingSessionHandler extends SessionHandler { - readonly evalInputNames: readonly string[]; - readonly evalOutputNames: readonly string[]; - - lazyResetGrad(): Promise; - runTrainStep( - feeds: SessionHandler.FeedsType, - fetches: SessionHandler.FetchesType, - options: InferenceSession.RunOptions, - ): Promise; - runOptimizerStep(options: InferenceSession.RunOptions): Promise; - runEvalStep( - feeds: SessionHandler.FeedsType, - fetches: SessionHandler.FetchesType, - options: InferenceSession.RunOptions, - ): Promise; - - getParametersSize(trainableOnly: boolean): Promise; - loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise; - getContiguousParameters(trainableOnly: boolean): Promise; -} - /** * Represent a backend that provides implementation of model inferencing. * @@ -84,14 +56,6 @@ export interface Backend { uriOrBuffer: string | Uint8Array, options?: InferenceSession.SessionOptions, ): Promise; - - createTrainingSessionHandler?( - checkpointStateUriOrBuffer: TrainingSession.UriOrBuffer, - trainModelUriOrBuffer: TrainingSession.UriOrBuffer, - evalModelUriOrBuffer: TrainingSession.UriOrBuffer, - optimizerModelUriOrBuffer: TrainingSession.UriOrBuffer, - options: InferenceSession.SessionOptions, - ): Promise; } export { registerBackend } from './backend-impl.js'; diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index 642a897a90d26..d6d9f7fa48790 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. import { env as envImpl } from './env-impl.js'; +import { TryGetGlobalType } from './type-helper.js'; export declare namespace Env { export type WasmPathPrefix = string; @@ -14,7 +15,6 @@ export declare namespace Env { * If not modified, the filename of the .wasm file is: * - `ort-wasm-simd-threaded.wasm` for default build * - `ort-wasm-simd-threaded.jsep.wasm` for JSEP build (with WebGPU and WebNN) - * - `ort-training-wasm-simd-threaded.wasm` for training build */ wasm?: URL | string; /** @@ -25,7 +25,6 @@ export declare namespace Env { * If not modified, the filename of the .mjs file is: * - `ort-wasm-simd-threaded.mjs` for default build * - `ort-wasm-simd-threaded.jsep.mjs` for JSEP build (with WebGPU and WebNN) - * - `ort-training-wasm-simd-threaded.mjs` for training build */ mjs?: URL | string; } @@ -46,17 +45,19 @@ export declare namespace Env { * * This setting is available only when WebAssembly SIMD feature is available in current context. * + * @defaultValue `true` + * * @deprecated This property is deprecated. Since SIMD is supported by all major JavaScript engines, non-SIMD * build is no longer provided. This property will be removed in future release. - * @defaultValue `true` */ simd?: boolean; /** * set or get a boolean value indicating whether to enable trace. * - * @deprecated Use `env.trace` instead. If `env.trace` is set, this property will be ignored. * @defaultValue `false` + * + * @deprecated Use `env.trace` instead. If `env.trace` is set, this property will be ignored. */ trace?: boolean; @@ -154,7 +155,7 @@ export declare namespace Env { /** * Set or get the profiling configuration. */ - profiling?: { + profiling: { /** * Set or get the profiling mode. * @@ -177,6 +178,9 @@ export declare namespace Env { * See {@link https://gpuweb.github.io/gpuweb/#dictdef-gpurequestadapteroptions} for more details. * * @defaultValue `undefined` + * + * @deprecated Create your own GPUAdapter, use it to create a GPUDevice instance and set {@link device} property if + * you want to use a specific power preference. */ powerPreference?: 'low-power' | 'high-performance'; /** @@ -188,6 +192,9 @@ export declare namespace Env { * See {@link https://gpuweb.github.io/gpuweb/#dictdef-gpurequestadapteroptions} for more details. * * @defaultValue `undefined` + * + * @deprecated Create your own GPUAdapter, use it to create a GPUDevice instance and set {@link device} property if + * you want to use a specific fallback option. */ forceFallbackAdapter?: boolean; /** @@ -200,22 +207,25 @@ export declare namespace Env { * value will be the GPU adapter that created by the underlying WebGPU backend. * * When use with TypeScript, the type of this property is `GPUAdapter` defined in "@webgpu/types". - * Use `const adapter = env.webgpu.adapter as GPUAdapter;` in TypeScript to access this property with correct type. * - * see comments on {@link Tensor.GpuBufferType} + * @deprecated It is no longer recommended to use this property. The latest WebGPU spec adds `GPUDevice.adapterInfo` + * (https://www.w3.org/TR/webgpu/#dom-gpudevice-adapterinfo), which allows to get the adapter information from the + * device. When it's available, there is no need to set/get the {@link adapter} property. */ - adapter: unknown; + adapter: TryGetGlobalType<'GPUAdapter'>; /** - * Get the device for WebGPU. - * - * This property is only available after the first WebGPU inference session is created. - * - * When use with TypeScript, the type of this property is `GPUDevice` defined in "@webgpu/types". - * Use `const device = env.webgpu.device as GPUDevice;` in TypeScript to access this property with correct type. + * Set or get the GPU device for WebGPU. * - * see comments on {@link Tensor.GpuBufferType} for more details about why not use types defined in "@webgpu/types". + * There are 3 valid scenarios of accessing this property: + * - Set a value before the first WebGPU inference session is created. The value will be used by the WebGPU backend + * to perform calculations. If the value is not a `GPUDevice` object, an error will be thrown. + * - Get the value before the first WebGPU inference session is created. This will try to create a new GPUDevice + * instance. Returns a `Promise` that resolves to a `GPUDevice` object. + * - Get the value after the first WebGPU inference session is created. Returns a resolved `Promise` to the + * `GPUDevice` object used by the WebGPU backend. */ - readonly device: unknown; + get device(): Promise>; + set device(value: TryGetGlobalType<'GPUDevice'>); /** * Set or get whether validate input content. * diff --git a/js/common/lib/index.ts b/js/common/lib/index.ts index 3ed56b3c2e812..d75e6a477258d 100644 --- a/js/common/lib/index.ts +++ b/js/common/lib/index.ts @@ -26,4 +26,3 @@ export * from './tensor-factory.js'; export * from './trace.js'; export * from './onnx-model.js'; export * from './onnx-value.js'; -export * from './training-session.js'; diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index 547db029471a2..e62c6579e8333 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -4,6 +4,7 @@ import { InferenceSession as InferenceSessionImpl } from './inference-session-impl.js'; import { OnnxModelOptions } from './onnx-model.js'; import { OnnxValue, OnnxValueDataLocation } from './onnx-value.js'; +import { TryGetGlobalType } from './type-helper.js'; /* eslint-disable @typescript-eslint/no-redeclare */ @@ -282,7 +283,7 @@ export declare namespace InferenceSession { extends WebNNExecutionProviderName, Omit, Required> { - context: unknown /* MLContext */; + context: TryGetGlobalType<'MLContext'>; } /** @@ -291,8 +292,8 @@ export declare namespace InferenceSession { * @see https://www.w3.org/TR/webnn/#dom-ml-createcontext-gpudevice */ export interface WebNNOptionsWebGpu extends WebNNExecutionProviderName { - context: unknown /* MLContext */; - gpuDevice: unknown /* GPUDevice */; + context: TryGetGlobalType<'MLContext'>; + gpuDevice: TryGetGlobalType<'GPUDevice'>; } /** diff --git a/js/common/lib/tensor.ts b/js/common/lib/tensor.ts index af918705b97e3..05553bd96662b 100644 --- a/js/common/lib/tensor.ts +++ b/js/common/lib/tensor.ts @@ -4,6 +4,7 @@ import { TensorFactory } from './tensor-factory.js'; import { Tensor as TensorImpl } from './tensor-impl.js'; import { TypedTensorUtils } from './tensor-utils.js'; +import { TryGetGlobalType } from './type-helper.js'; /* eslint-disable @typescript-eslint/no-redeclare */ @@ -131,24 +132,19 @@ export declare namespace Tensor { */ export type TextureDataTypes = 'float32'; + type GpuBufferTypeFallback = { size: number; mapState: 'unmapped' | 'pending' | 'mapped' }; /** * type alias for WebGPU buffer - * - * The reason why we don't use type "GPUBuffer" defined in webgpu.d.ts from @webgpu/types is because "@webgpu/types" - * requires "@types/dom-webcodecs" as peer dependency when using TypeScript < v5.1 and its version need to be chosen - * carefully according to the TypeScript version being used. This means so far there is not a way to keep every - * TypeScript version happy. It turns out that we will easily broke users on some TypeScript version. - * - * for more info see https://github.com/gpuweb/types/issues/127 */ - export type GpuBufferType = { size: number; mapState: 'unmapped' | 'pending' | 'mapped' }; + export type GpuBufferType = TryGetGlobalType<'GPUBuffer', GpuBufferTypeFallback>; + type MLTensorTypeFallback = { destroy(): void }; /** * type alias for WebNN MLTensor * * The specification for WebNN's MLTensor is currently in flux. */ - export type MLTensorType = unknown; + export type MLTensorType = TryGetGlobalType<'MLTensor', MLTensorTypeFallback>; /** * supported data types for constructing a tensor from a WebGPU buffer diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts deleted file mode 100644 index 21dbe5fe51bb9..0000000000000 --- a/js/common/lib/training-session-impl.ts +++ /dev/null @@ -1,273 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -import { resolveBackendAndExecutionProviders } from './backend-impl.js'; -import { SessionHandler, TrainingSessionHandler } from './backend.js'; -import { InferenceSession as InferenceSession } from './inference-session.js'; -import { OnnxValue } from './onnx-value.js'; -import { Tensor } from './tensor.js'; -import { TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions } from './training-session.js'; - -type SessionOptions = InferenceSession.SessionOptions; -type FeedsType = InferenceSession.FeedsType; -type FetchesType = InferenceSession.FetchesType; -type ReturnType = InferenceSession.ReturnType; -type RunOptions = InferenceSession.RunOptions; - -const noBackendErrMsg: string = - 'Training backend could not be resolved. ' + "Make sure you're using the correct configuration & WebAssembly files."; - -export class TrainingSession implements TrainingSessionInterface { - private constructor(handler: TrainingSessionHandler, hasOptimizerModel: boolean, hasEvalModel: boolean) { - this.handler = handler; - this.hasOptimizerModel = hasOptimizerModel; - this.hasEvalModel = hasEvalModel; - } - private handler: TrainingSessionHandler; - private hasOptimizerModel: boolean; - private hasEvalModel: boolean; - - get trainingInputNames(): readonly string[] { - return this.handler.inputNames; - } - get trainingOutputNames(): readonly string[] { - return this.handler.outputNames; - } - - get evalInputNames(): readonly string[] { - if (this.hasEvalModel) { - return this.handler.evalInputNames; - } else { - throw new Error('This training session has no evalModel loaded.'); - } - } - get evalOutputNames(): readonly string[] { - if (this.hasEvalModel) { - return this.handler.evalOutputNames; - } else { - throw new Error('This training session has no evalModel loaded.'); - } - } - - static async create( - trainingOptions: TrainingSessionCreateOptions, - sessionOptions?: SessionOptions, - ): Promise { - const evalModel: string | Uint8Array = trainingOptions.evalModel || ''; - const optimizerModel: string | Uint8Array = trainingOptions.optimizerModel || ''; - const options: SessionOptions = sessionOptions || {}; - - // resolve backend, update session options with validated EPs, and create session handler - const [backend, optionsWithValidatedEPs] = await resolveBackendAndExecutionProviders(options); - if (backend.createTrainingSessionHandler) { - const handler = await backend.createTrainingSessionHandler( - trainingOptions.checkpointState, - trainingOptions.trainModel, - evalModel, - optimizerModel, - optionsWithValidatedEPs, - ); - return new TrainingSession(handler, !!trainingOptions.optimizerModel, !!trainingOptions.evalModel); - } else { - throw new Error(noBackendErrMsg); - } - } - - /** - * Helper function for runTrainStep and future runStep methods that handles the type-narrowing conversion from - * the given parameters to SessionHandler.FetchesType and RunOptions. - * - * @param inputNames the feeds object is checked that they contain all input names in the provided list of input - * names. - * @param outputNames the fetches object is checked that their keys match up with valid names in the list of output - * names. - * @param feeds the required input - * @param arg1 narrowed & converted into the SessionHandler.FetchesType or RunOptions object - * @param arg2 optional RunOptions object. - * @returns - */ - typeNarrowingForRunStep( - inputNames: readonly string[], - outputNames: readonly string[], - feeds: FeedsType, - arg1?: FetchesType | RunOptions, - arg2?: RunOptions, - ): [SessionHandler.FetchesType, RunOptions] { - const fetches: { [name: string]: OnnxValue | null } = {}; - let options: RunOptions = {}; - // check inputs - if (typeof feeds !== 'object' || feeds === null || feeds instanceof Tensor || Array.isArray(feeds)) { - throw new TypeError( - "'feeds' must be an object that use input names as keys and OnnxValue as corresponding values.", - ); - } - - let isFetchesEmpty = true; - // determine which override is being used - if (typeof arg1 === 'object') { - if (arg1 === null) { - throw new TypeError('Unexpected argument[1]: cannot be null.'); - } - if (arg1 instanceof Tensor) { - throw new TypeError("'fetches' cannot be a Tensor"); - } - - if (Array.isArray(arg1)) { - if (arg1.length === 0) { - throw new TypeError("'fetches' cannot be an empty array."); - } - isFetchesEmpty = false; - // output names - for (const name of arg1) { - if (typeof name !== 'string') { - throw new TypeError("'fetches' must be a string array or an object."); - } - if (outputNames.indexOf(name) === -1) { - throw new RangeError(`'fetches' contains invalid output name: ${name}.`); - } - fetches[name] = null; - } - - if (typeof arg2 === 'object' && arg2 !== null) { - options = arg2; - } else if (typeof arg2 !== 'undefined') { - throw new TypeError("'options' must be an object."); - } - } else { - // decide whether arg1 is fetches or options - // if any output name is present and its value is valid OnnxValue, we consider it fetches - let isFetches = false; - const arg1Keys = Object.getOwnPropertyNames(arg1); - for (const name of outputNames) { - if (arg1Keys.indexOf(name) !== -1) { - const v = (arg1 as InferenceSession.NullableOnnxValueMapType)[name]; - if (v === null || v instanceof Tensor) { - isFetches = true; - isFetchesEmpty = false; - fetches[name] = v; - } - } - } - - if (isFetches) { - if (typeof arg2 === 'object' && arg2 !== null) { - options = arg2; - } else if (typeof arg2 !== 'undefined') { - throw new TypeError("'options' must be an object."); - } - } else { - options = arg1 as RunOptions; - } - } - } else if (typeof arg1 !== 'undefined') { - throw new TypeError("Unexpected argument[1]: must be 'fetches' or 'options'."); - } - - // check if all inputs are in feed - for (const name of inputNames) { - if (typeof feeds[name] === 'undefined') { - throw new Error(`input '${name}' is missing in 'feeds'.`); - } - } - - // if no fetches is specified, we use the full output names list - if (isFetchesEmpty) { - for (const name of outputNames) { - fetches[name] = null; - } - } - - return [fetches, options]; - } - - /** - * Helper method for runTrainStep and any other runStep methods. Takes the ReturnType result from the SessionHandler - * and changes it into a map of Tensors. - * - * @param results - * @returns - */ - convertHandlerReturnTypeToMapOfTensors(results: SessionHandler.ReturnType): ReturnType { - const returnValue: { [name: string]: OnnxValue } = {}; - for (const key in results) { - if (Object.hasOwnProperty.call(results, key)) { - const result = results[key]; - if (result instanceof Tensor) { - returnValue[key] = result; - } else { - returnValue[key] = new Tensor(result.type, result.data, result.dims); - } - } - } - return returnValue; - } - - async lazyResetGrad(): Promise { - await this.handler.lazyResetGrad(); - } - - runTrainStep(feeds: FeedsType, options?: RunOptions): Promise; - runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise; - async runTrainStep(feeds: FeedsType, arg1?: FetchesType | RunOptions, arg2?: RunOptions): Promise { - const [fetches, options] = this.typeNarrowingForRunStep( - this.trainingInputNames, - this.trainingOutputNames, - feeds, - arg1, - arg2, - ); - const results = await this.handler.runTrainStep(feeds, fetches, options); - return this.convertHandlerReturnTypeToMapOfTensors(results); - } - - async runOptimizerStep(options?: InferenceSession.RunOptions | undefined): Promise { - if (this.hasOptimizerModel) { - await this.handler.runOptimizerStep(options || {}); - } else { - throw new Error('This TrainingSession has no OptimizerModel loaded.'); - } - } - - runEvalStep(feeds: FeedsType, options?: RunOptions | undefined): Promise; - runEvalStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions | undefined): Promise; - async runEvalStep(feeds: FeedsType, arg1?: FetchesType | RunOptions, arg2?: RunOptions): Promise { - if (this.hasEvalModel) { - const [fetches, options] = this.typeNarrowingForRunStep( - this.evalInputNames, - this.evalOutputNames, - feeds, - arg1, - arg2, - ); - const results = await this.handler.runEvalStep(feeds, fetches, options); - return this.convertHandlerReturnTypeToMapOfTensors(results); - } else { - throw new Error('This TrainingSession has no EvalModel loaded.'); - } - } - - async getParametersSize(trainableOnly = true): Promise { - return this.handler.getParametersSize(trainableOnly); - } - - async loadParametersBuffer(array: Uint8Array, trainableOnly = true): Promise { - const paramsSize = await this.getParametersSize(trainableOnly); - // checking that the size of the Uint8Array is equivalent to the byte length of a Float32Array of the number - // of parameters - if (array.length !== 4 * paramsSize) { - throw new Error( - 'Size of the buffer passed into loadParametersBuffer must match the number of parameters in ' + - 'the model. Please use getParametersSize method to check.', - ); - } - return this.handler.loadParametersBuffer(array, trainableOnly); - } - - async getContiguousParameters(trainableOnly = true): Promise { - return this.handler.getContiguousParameters(trainableOnly); - } - - async release(): Promise { - return this.handler.dispose(); - } -} diff --git a/js/common/lib/training-session.ts b/js/common/lib/training-session.ts deleted file mode 100644 index 45dcafc46deb5..0000000000000 --- a/js/common/lib/training-session.ts +++ /dev/null @@ -1,206 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -import { InferenceSession } from './inference-session.js'; -import { OnnxValue } from './onnx-value.js'; -import { TrainingSession as TrainingSessionImpl } from './training-session-impl.js'; - -/* eslint-disable @typescript-eslint/no-redeclare */ - -export declare namespace TrainingSession { - /** - * Either URI file path (string) or Uint8Array containing model or checkpoint information. - */ - type UriOrBuffer = string | Uint8Array; -} - -/** - * Represent a runtime instance of an ONNX training session, - * which contains a model that can be trained, and, optionally, - * an eval and optimizer model. - */ -export interface TrainingSession { - // #region run() - - /** - * Lazily resets the gradients of all trainable parameters to zero. Should happen after the invocation of - * runOptimizerStep. - */ - lazyResetGrad(): Promise; - - /** - * Run TrainStep asynchronously with the given feeds and options. - * - * @param feeds - Representation of the model input. See type description of `InferenceSession.InputType` for - detail. - * @param options - Optional. A set of options that controls the behavior of model training. - * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding values. - */ - runTrainStep( - feeds: InferenceSession.FeedsType, - options?: InferenceSession.RunOptions, - ): Promise; - - /** - * Run a single train step with the given inputs and options. - * - * @param feeds - Representation of the model input. - * @param fetches - Representation of the model output. - * detail. - * @param options - Optional. A set of options that controls the behavior of model training. - * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding - values. - */ - runTrainStep( - feeds: InferenceSession.FeedsType, - fetches: InferenceSession.FetchesType, - options?: InferenceSession.RunOptions, - ): Promise; - - /** - * Runs a single optimizer step, which performs weight updates for the trainable parameters using the optimizer model. - * - * @param options - Optional. A set of options that controls the behavior of model optimizing. - */ - runOptimizerStep(options?: InferenceSession.RunOptions): Promise; - - /** - * Run a single eval step with the given inputs and options using the eval model. - * - * @param feeds - Representation of the model input. - * @param options - Optional. A set of options that controls the behavior of model eval step. - * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding - values. - */ - runEvalStep( - feeds: InferenceSession.FeedsType, - options?: InferenceSession.RunOptions, - ): Promise; - - /** - * Run a single eval step with the given inputs and options using the eval model. - * - * @param feeds - Representation of the model input. - * @param fetches - Representation of the model output. - * detail. - * @param options - Optional. A set of options that controls the behavior of model eval step. - * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding - values. - */ - runEvalStep( - feeds: InferenceSession.FeedsType, - fetches: InferenceSession.FetchesType, - options?: InferenceSession.RunOptions, - ): Promise; - - // #endregion - - // #region copy parameters - - /** - * Retrieves the size of all parameters for the training state. Calculates the total number of primitive (datatype of - * the parameters) elements of all the parameters in the training state. - * - * @param trainableOnly - When set to true, the size is calculated for trainable params only. Default value is true. - */ - getParametersSize(trainableOnly: boolean): Promise; - - /** - * Copies parameter values from the given buffer to the training state. Currently, only supporting models with - * parameters of type Float32. - * - * @param buffer - A Uint8Array representation of Float32 parameters. - * @param trainableOnly - True if trainable parameters only to be modified, false otherwise. Default value is true. - */ - loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise; - - /** - * Copies the model parameters to a contiguous buffer. Usually used in the context of Federated Learning. - * Currently, only supporting models with parameters of type Float32. - * - * @param trainableOnly - When set to true, only trainable parameters are copied. Trainable parameters are parameters - * for which requires_grad is set to true. Default value is true. - * @returns A promise that resolves to a Float32 OnnxValue of the requested parameters. - */ - getContiguousParameters(trainableOnly: boolean): Promise; - // #endregion - - // #region release() - - /** - * Release the inference session and the underlying resources. - */ - release(): Promise; - // #endregion - - // #region metadata - - /** - * Get input names of the loaded training model. - */ - readonly trainingInputNames: readonly string[]; - - /** - * Get output names of the loaded training model. - */ - readonly trainingOutputNames: readonly string[]; - - /** - * Get input names of the loaded eval model. Is an empty array if no eval model is loaded. - */ - readonly evalInputNames: readonly string[]; - - /** - * Get output names of the loaded eval model. Is an empty array if no eval model is loaded. - */ - readonly evalOutputNames: readonly string[]; - - // #endregion -} - -/** - * Represents the optional parameters that can be passed into the TrainingSessionFactory. - */ -export interface TrainingSessionCreateOptions { - /** - * URI or buffer for a .ckpt file that contains the checkpoint for the training model. - */ - checkpointState: TrainingSession.UriOrBuffer; - /** - * URI or buffer for the .onnx training file. - */ - trainModel: TrainingSession.UriOrBuffer; - /** - * Optional. URI or buffer for the .onnx optimizer model file. - */ - optimizerModel?: TrainingSession.UriOrBuffer; - /** - * Optional. URI or buffer for the .onnx eval model file. - */ - evalModel?: TrainingSession.UriOrBuffer; -} - -/** - * Defines method overload possibilities for creating a TrainingSession. - */ -export interface TrainingSessionFactory { - // #region create() - - /** - * Creates a new TrainingSession and asynchronously loads any models passed in through trainingOptions - * - * @param trainingOptions specify models and checkpoints to load into the Training Session - * @param sessionOptions specify configuration for training session behavior - * - * @returns Promise that resolves to a TrainingSession object - */ - create( - trainingOptions: TrainingSessionCreateOptions, - sessionOptions?: InferenceSession.SessionOptions, - ): Promise; - - // #endregion -} - -// eslint-disable-next-line @typescript-eslint/naming-convention -export const TrainingSession: TrainingSessionFactory = TrainingSessionImpl; diff --git a/js/common/lib/type-helper.ts b/js/common/lib/type-helper.ts new file mode 100644 index 0000000000000..845ba3018d443 --- /dev/null +++ b/js/common/lib/type-helper.ts @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +/** + * A helper type to get certain types if they are declared in global scope. + * + * For example, if you installed "@webgpu/types" as a dev dependency, then `TryGetTypeIfDeclared<'GPUDevice'>` will + * be type `GPUDevice`, otherwise it will be type `unknown`. + * + * + * We don't want to introduce "@webgpu/types" as a dependency of this package because: + * + * (1) For JavaScript users, it's not needed. For TypeScript users, they can install it as dev dependency themselves. + * + * (2) because "@webgpu/types" requires "@types/dom-webcodecs" as peer dependency when using TypeScript < v5.1 and its + * version need to be chosen carefully according to the TypeScript version being used. This means so far there is not a + * way to keep every TypeScript version happy. It turns out that we will easily broke users on some TypeScript version. + * + * for more info see https://github.com/gpuweb/types/issues/127 + * + * Update (2024-08-07): The reason (2) may be no longer valid. Most people should be using TypeScript >= 5.1 by now. + * However, we are still not sure whether introducing "@webgpu/types" as direct dependency is a good idea. We find this + * type helper is useful for TypeScript users. + * + * @ignore + */ +export type TryGetGlobalType = typeof globalThis extends { + [k in Name]: { prototype: infer T }; +} + ? T + : Fallback; diff --git a/js/common/typedoc.json b/js/common/typedoc.json index 088c7ba4053e6..f9c7e7b19db41 100644 --- a/js/common/typedoc.json +++ b/js/common/typedoc.json @@ -1,6 +1,7 @@ { "entryPoints": ["lib/index.ts"], "excludeInternal": true, + "intentionallyNotExported": ["TryGetGlobalType"], "name": "ONNX Runtime JavaScript API", "readme": "none", "cleanOutputDir": true diff --git a/js/node/CMakeLists.txt b/js/node/CMakeLists.txt index 1ce6d66881c3e..d79a82c572dc2 100644 --- a/js/node/CMakeLists.txt +++ b/js/node/CMakeLists.txt @@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.11) project (onnxruntime-node) -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 17) add_compile_definitions(NAPI_VERSION=${napi_build_version}) add_compile_definitions(ORT_API_MANUAL_INIT) @@ -34,6 +34,7 @@ include_directories(${CMAKE_SOURCE_DIR}/node_modules/node-addon-api) # optional providers option(USE_DML "Build with DirectML support" OFF) +option(USE_WEBGPU "Build with WebGPU support" OFF) option(USE_CUDA "Build with CUDA support" OFF) option(USE_TENSORRT "Build with TensorRT support" OFF) option(USE_COREML "Build with CoreML support" OFF) @@ -42,6 +43,9 @@ option(USE_QNN "Build with QNN support" OFF) if(USE_DML) add_compile_definitions(USE_DML=1) endif() +if(USE_WEBGPU) + add_compile_definitions(USE_WEBGPU=1) +endif() if(USE_CUDA) add_compile_definitions(USE_CUDA=1) endif() diff --git a/js/node/lib/backend.ts b/js/node/lib/backend.ts index 46f8b83b0c5c2..004a3c890a7e4 100644 --- a/js/node/lib/backend.ts +++ b/js/node/lib/backend.ts @@ -3,12 +3,14 @@ import { Backend, InferenceSession, InferenceSessionHandler, SessionHandler } from 'onnxruntime-common'; -import { Binding, binding } from './binding'; +import { Binding, binding, initOrt } from './binding'; class OnnxruntimeSessionHandler implements InferenceSessionHandler { #inferenceSession: Binding.InferenceSession; constructor(pathOrBuffer: string | Uint8Array, options: InferenceSession.SessionOptions) { + initOrt(); + this.#inferenceSession = new binding.InferenceSession(); if (typeof pathOrBuffer === 'string') { this.#inferenceSession.loadModel(pathOrBuffer, options); @@ -27,10 +29,12 @@ class OnnxruntimeSessionHandler implements InferenceSessionHandler { readonly outputNames: string[]; startProfiling(): void { - // TODO: implement profiling + // startProfiling is a no-op. + // + // if sessionOptions.enableProfiling is true, profiling will be enabled when the model is loaded. } endProfiling(): void { - // TODO: implement profiling + this.#inferenceSession.endProfiling(); } async run( diff --git a/js/node/lib/binding.ts b/js/node/lib/binding.ts index d6d592a1665b3..56203f5a5ca02 100644 --- a/js/node/lib/binding.ts +++ b/js/node/lib/binding.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import { InferenceSession, OnnxValue } from 'onnxruntime-common'; +import { InferenceSession, OnnxValue, Tensor, TensorConstructor, env } from 'onnxruntime-common'; type SessionOptions = InferenceSession.SessionOptions; type FeedsType = { @@ -28,6 +28,8 @@ export declare namespace Binding { run(feeds: FeedsType, fetches: FetchesType, options: RunOptions): ReturnType; + endProfiling(): void; + dispose(): void; } @@ -48,4 +50,35 @@ export const binding = // eslint-disable-next-line @typescript-eslint/naming-convention InferenceSession: Binding.InferenceSessionConstructor; listSupportedBackends: () => Binding.SupportedBackend[]; + initOrtOnce: (logLevel: number, tensorConstructor: TensorConstructor) => void; }; + +let ortInitialized = false; +export const initOrt = (): void => { + if (!ortInitialized) { + ortInitialized = true; + let logLevel = 2; + if (env.logLevel) { + switch (env.logLevel) { + case 'verbose': + logLevel = 0; + break; + case 'info': + logLevel = 1; + break; + case 'warning': + logLevel = 2; + break; + case 'error': + logLevel = 3; + break; + case 'fatal': + logLevel = 4; + break; + default: + throw new Error(`Unsupported log level: ${env.logLevel}`); + } + } + binding.initOrtOnce(logLevel, Tensor); + } +}; diff --git a/js/node/package-lock.json b/js/node/package-lock.json index 239c0b1ba557b..6d3c96e579a47 100644 --- a/js/node/package-lock.json +++ b/js/node/package-lock.json @@ -276,12 +276,12 @@ "dev": true }, "node_modules/axios": { - "version": "1.6.1", - "resolved": "https://registry.npmjs.org/axios/-/axios-1.6.1.tgz", - "integrity": "sha512-vfBmhDpKafglh0EldBEbVuoe7DyAavGSLWhuSm5ZSEKQnHhBf0xAAwybbNH1IkrJNGnS/VG4I5yxig1pCEXE4g==", + "version": "1.7.9", + "resolved": "https://registry.npmjs.org/axios/-/axios-1.7.9.tgz", + "integrity": "sha512-LhLcE7Hbiryz8oMDdDptSrWowmB4Bl6RCt6sIJKpRB4XtVf0iEgewX3au/pJqm+Py1kCASkb/FFKjxQaLtxJvw==", "dev": true, "dependencies": { - "follow-redirects": "^1.15.0", + "follow-redirects": "^1.15.6", "form-data": "^4.0.0", "proxy-from-env": "^1.1.0" } @@ -455,9 +455,9 @@ "dev": true }, "node_modules/cross-spawn": { - "version": "7.0.3", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", - "integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==", + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", "dependencies": { "path-key": "^3.1.0", "shebang-command": "^2.0.0", @@ -1581,12 +1581,12 @@ "dev": true }, "axios": { - "version": "1.6.1", - "resolved": "https://registry.npmjs.org/axios/-/axios-1.6.1.tgz", - "integrity": "sha512-vfBmhDpKafglh0EldBEbVuoe7DyAavGSLWhuSm5ZSEKQnHhBf0xAAwybbNH1IkrJNGnS/VG4I5yxig1pCEXE4g==", + "version": "1.7.9", + "resolved": "https://registry.npmjs.org/axios/-/axios-1.7.9.tgz", + "integrity": "sha512-LhLcE7Hbiryz8oMDdDptSrWowmB4Bl6RCt6sIJKpRB4XtVf0iEgewX3au/pJqm+Py1kCASkb/FFKjxQaLtxJvw==", "dev": true, "requires": { - "follow-redirects": "^1.15.0", + "follow-redirects": "^1.15.6", "form-data": "^4.0.0", "proxy-from-env": "^1.1.0" } @@ -1725,9 +1725,9 @@ "dev": true }, "cross-spawn": { - "version": "7.0.3", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", - "integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==", + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", "requires": { "path-key": "^3.1.0", "shebang-command": "^2.0.0", diff --git a/js/node/script/build.ts b/js/node/script/build.ts index 133d1a0d981a0..dcdcb93377b4c 100644 --- a/js/node/script/build.ts +++ b/js/node/script/build.ts @@ -29,6 +29,8 @@ const ONNXRUNTIME_GENERATOR = buildArgs['onnxruntime-generator']; const REBUILD = !!buildArgs.rebuild; // --use_dml const USE_DML = !!buildArgs.use_dml; +// --use_webgpu +const USE_WEBGPU = !!buildArgs.use_webgpu; // --use_cuda const USE_CUDA = !!buildArgs.use_cuda; // --use_tensorrt @@ -65,6 +67,9 @@ if (ONNXRUNTIME_GENERATOR && typeof ONNXRUNTIME_GENERATOR === 'string') { if (USE_DML) { args.push('--CDUSE_DML=ON'); } +if (USE_WEBGPU) { + args.push('--CDUSE_WEBGPU=ON'); +} if (USE_CUDA) { args.push('--CDUSE_CUDA=ON'); } diff --git a/js/node/script/install.js b/js/node/script/install.js index b15bc03840599..fef93f9169a2c 100644 --- a/js/node/script/install.js +++ b/js/node/script/install.js @@ -21,6 +21,7 @@ const os = require('os'); const fs = require('fs'); const path = require('path'); const tar = require('tar'); +const { execFileSync } = require('child_process'); const { Readable } = require('stream'); // commandline flag: @@ -58,10 +59,23 @@ if (NO_INSTALL || !shouldInstall) { // Step.2: Download the required binaries const artifactUrl = { - 11: `https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VERSION}/onnxruntime-linux-x64-gpu-${ - ORT_VERSION - }.tgz`, - 12: `https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VERSION}/onnxruntime-linux-x64-gpu-cuda12-${ + get 11() { + // TODO: support ORT Cuda v11 binaries + throw new Error(`CUDA 11 binaries are not supported by this script yet. + +To use ONNX Runtime Node.js binding with CUDA v11 support, please follow the manual steps: + +1. Use "--onnxruntime-node-install-cuda=skip" to skip the auto installation. +2. Navigate to https://aiinfra.visualstudio.com/PublicPackages/_artifacts/feed/onnxruntime-cuda-11 +3. Download the binaries for your platform and architecture +4. Extract the following binaries to "node_modules/onnxruntime-node/bin/napi-v3/linux/x64: + - libonnxruntime_providers_tensorrt.so + - libonnxruntime_providers_shared.so + - libonnxruntime.so.${ORT_VERSION} + - libonnxruntime_providers_cuda.so +`); + }, + 12: `https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VERSION}/onnxruntime-linux-x64-gpu-${ ORT_VERSION }.tgz`, }[INSTALL_CUDA_FLAG || tryGetCudaVersion()]; @@ -108,9 +122,27 @@ Use "--onnxruntime-node-install-cuda=skip" to skip the installation. You will st function tryGetCudaVersion() { // Should only return 11 or 12. - // TODO: try to get the CUDA version from the system ( `nvcc --version` ) + // try to get the CUDA version from the system ( `nvcc --version` ) + let ver = 12; + try { + const nvccVersion = execFileSync('nvcc', ['--version'], { encoding: 'utf8' }); + const match = nvccVersion.match(/release (\d+)/); + if (match) { + ver = parseInt(match[1]); + if (ver !== 11 && ver !== 12) { + throw new Error(`Unsupported CUDA version: ${ver}`); + } + } + } catch (e) { + if (e?.code === 'ENOENT') { + console.warn('`nvcc` not found. Assuming CUDA 12.'); + } else { + console.warn('Failed to detect CUDA version from `nvcc --version`:', e.message); + } + } - return 11; + // assume CUDA 12 if failed to detect + return ver; } function parseInstallCudaFlag() { diff --git a/js/node/src/inference_session_wrap.cc b/js/node/src/inference_session_wrap.cc index 057066507621b..23d859351f426 100644 --- a/js/node/src/inference_session_wrap.cc +++ b/js/node/src/inference_session_wrap.cc @@ -11,7 +11,12 @@ #include "tensor_helper.h" #include -Napi::FunctionReference InferenceSessionWrap::constructor; +Napi::FunctionReference InferenceSessionWrap::wrappedSessionConstructor; +Napi::FunctionReference InferenceSessionWrap::ortTensorConstructor; + +Napi::FunctionReference& InferenceSessionWrap::GetTensorConstructor() { + return InferenceSessionWrap::ortTensorConstructor; +} Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) { #if defined(USE_DML) && defined(_WIN32) @@ -23,28 +28,51 @@ Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) { Ort::Global::api_ == nullptr, env, "Failed to initialize ONNX Runtime API. It could happen when this nodejs binding was built with a higher version " "ONNX Runtime but now runs with a lower version ONNX Runtime DLL(or shared library)."); - auto ortEnv = new Ort::Env{ORT_LOGGING_LEVEL_WARNING, "onnxruntime-node"}; - env.SetInstanceData(ortEnv); + // initialize binding Napi::HandleScope scope(env); Napi::Function func = DefineClass( env, "InferenceSession", - {InstanceMethod("loadModel", &InferenceSessionWrap::LoadModel), InstanceMethod("run", &InferenceSessionWrap::Run), + {InstanceMethod("loadModel", &InferenceSessionWrap::LoadModel), + InstanceMethod("run", &InferenceSessionWrap::Run), InstanceMethod("dispose", &InferenceSessionWrap::Dispose), + InstanceMethod("endProfiling", &InferenceSessionWrap::EndProfiling), InstanceAccessor("inputNames", &InferenceSessionWrap::GetInputNames, nullptr, napi_default, nullptr), InstanceAccessor("outputNames", &InferenceSessionWrap::GetOutputNames, nullptr, napi_default, nullptr)}); - constructor = Napi::Persistent(func); - constructor.SuppressDestruct(); + wrappedSessionConstructor = Napi::Persistent(func); + wrappedSessionConstructor.SuppressDestruct(); exports.Set("InferenceSession", func); Napi::Function listSupportedBackends = Napi::Function::New(env, InferenceSessionWrap::ListSupportedBackends); exports.Set("listSupportedBackends", listSupportedBackends); + Napi::Function initOrtOnce = Napi::Function::New(env, InferenceSessionWrap::InitOrtOnce); + exports.Set("initOrtOnce", initOrtOnce); + return exports; } +Napi::Value InferenceSessionWrap::InitOrtOnce(const Napi::CallbackInfo& info) { + Napi::Env env = info.Env(); + Napi::HandleScope scope(env); + + int log_level = info[0].As().Int32Value(); + + Ort::Env* ortEnv = env.GetInstanceData(); + if (ortEnv == nullptr) { + ortEnv = new Ort::Env{OrtLoggingLevel(log_level), "onnxruntime-node"}; + env.SetInstanceData(ortEnv); + } + + Napi::Function tensorConstructor = info[1].As(); + ortTensorConstructor = Napi::Persistent(tensorConstructor); + ortTensorConstructor.SuppressDestruct(); + + return env.Undefined(); +} + InferenceSessionWrap::InferenceSessionWrap(const Napi::CallbackInfo& info) : Napi::ObjectWrap(info), initialized_(false), disposed_(false), session_(nullptr), defaultRunOptions_(nullptr) {} @@ -118,6 +146,12 @@ Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo& info) { ? typeInfo.GetTensorTypeAndShapeInfo().GetElementType() : ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); } + + // cache preferred output locations + ParsePreferredOutputLocations(info[argsLength - 1].As(), outputNames_, preferredOutputLocations_); + if (preferredOutputLocations_.size() > 0) { + ioBinding_ = std::make_unique(*session_); + } } catch (Napi::Error const& e) { throw e; } catch (std::exception const& e) { @@ -167,7 +201,8 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) { std::vector reuseOutput; size_t inputIndex = 0; size_t outputIndex = 0; - OrtMemoryInfo* memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault).release(); + Ort::MemoryInfo cpuMemoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + Ort::MemoryInfo gpuBufferMemoryInfo{"WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault}; try { for (auto& name : inputNames_) { @@ -175,7 +210,7 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) { inputIndex++; inputNames_cstr.push_back(name.c_str()); auto value = feed.Get(name); - inputValues.push_back(NapiValueToOrtValue(env, value, memory_info)); + inputValues.push_back(NapiValueToOrtValue(env, value, cpuMemoryInfo, gpuBufferMemoryInfo)); } } for (auto& name : outputNames_) { @@ -184,7 +219,7 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) { outputNames_cstr.push_back(name.c_str()); auto value = fetch.Get(name); reuseOutput.push_back(!value.IsNull()); - outputValues.emplace_back(value.IsNull() ? Ort::Value{nullptr} : NapiValueToOrtValue(env, value, memory_info)); + outputValues.emplace_back(value.IsNull() ? Ort::Value{nullptr} : NapiValueToOrtValue(env, value, cpuMemoryInfo, gpuBufferMemoryInfo)); } } @@ -193,19 +228,47 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) { runOptions = Ort::RunOptions{}; ParseRunOptions(info[2].As(), runOptions); } + if (preferredOutputLocations_.size() == 0) { + session_->Run(runOptions == nullptr ? *defaultRunOptions_.get() : runOptions, + inputIndex == 0 ? nullptr : &inputNames_cstr[0], inputIndex == 0 ? nullptr : &inputValues[0], + inputIndex, outputIndex == 0 ? nullptr : &outputNames_cstr[0], + outputIndex == 0 ? nullptr : &outputValues[0], outputIndex); - session_->Run(runOptions == nullptr ? *defaultRunOptions_.get() : runOptions, - inputIndex == 0 ? nullptr : &inputNames_cstr[0], inputIndex == 0 ? nullptr : &inputValues[0], - inputIndex, outputIndex == 0 ? nullptr : &outputNames_cstr[0], - outputIndex == 0 ? nullptr : &outputValues[0], outputIndex); + Napi::Object result = Napi::Object::New(env); - Napi::Object result = Napi::Object::New(env); + for (size_t i = 0; i < outputIndex; i++) { + result.Set(outputNames_[i], OrtValueToNapiValue(env, std::move(outputValues[i]))); + } + return scope.Escape(result); + } else { + // IO binding + ORT_NAPI_THROW_ERROR_IF(preferredOutputLocations_.size() != outputNames_.size(), env, + "Preferred output locations must have the same size as output names."); - for (size_t i = 0; i < outputIndex; i++) { - result.Set(outputNames_[i], OrtValueToNapiValue(env, outputValues[i])); - } + for (size_t i = 0; i < inputIndex; i++) { + ioBinding_->BindInput(inputNames_cstr[i], inputValues[i]); + } + for (size_t i = 0; i < outputIndex; i++) { + // TODO: support preallocated output tensor (outputValues[i]) + + if (preferredOutputLocations_[i] == DATA_LOCATION_GPU_BUFFER) { + ioBinding_->BindOutput(outputNames_cstr[i], gpuBufferMemoryInfo); + } else { + ioBinding_->BindOutput(outputNames_cstr[i], cpuMemoryInfo); + } + } + + session_->Run(runOptions == nullptr ? *defaultRunOptions_.get() : runOptions, *ioBinding_); + + auto outputs = ioBinding_->GetOutputValues(); + ORT_NAPI_THROW_ERROR_IF(outputs.size() != outputIndex, env, "Output count mismatch."); - return scope.Escape(result); + Napi::Object result = Napi::Object::New(env); + for (size_t i = 0; i < outputIndex; i++) { + result.Set(outputNames_[i], OrtValueToNapiValue(env, std::move(outputs[i]))); + } + return scope.Escape(result); + } } catch (Napi::Error const& e) { throw e; } catch (std::exception const& e) { @@ -218,6 +281,8 @@ Napi::Value InferenceSessionWrap::Dispose(const Napi::CallbackInfo& info) { ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized."); ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed."); + this->ioBinding_.reset(nullptr); + this->defaultRunOptions_.reset(nullptr); this->session_.reset(nullptr); @@ -225,6 +290,20 @@ Napi::Value InferenceSessionWrap::Dispose(const Napi::CallbackInfo& info) { return env.Undefined(); } +Napi::Value InferenceSessionWrap::EndProfiling(const Napi::CallbackInfo& info) { + Napi::Env env = info.Env(); + ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized."); + ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed."); + + Napi::EscapableHandleScope scope(env); + + Ort::AllocatorWithDefaultOptions allocator; + + auto filename = session_->EndProfilingAllocated(allocator); + Napi::String filenameValue = Napi::String::From(env, filename.get()); + return scope.Escape(filenameValue); +} + Napi::Value InferenceSessionWrap::ListSupportedBackends(const Napi::CallbackInfo& info) { Napi::Env env = info.Env(); Napi::EscapableHandleScope scope(env); @@ -242,6 +321,9 @@ Napi::Value InferenceSessionWrap::ListSupportedBackends(const Napi::CallbackInfo #ifdef USE_DML result.Set(result.Length(), createObject("dml", true)); #endif +#ifdef USE_WEBGPU + result.Set(result.Length(), createObject("webgpu", true)); +#endif #ifdef USE_CUDA result.Set(result.Length(), createObject("cuda", false)); #endif diff --git a/js/node/src/inference_session_wrap.h b/js/node/src/inference_session_wrap.h index effdd83e3aa02..0b3dd1178c807 100644 --- a/js/node/src/inference_session_wrap.h +++ b/js/node/src/inference_session_wrap.h @@ -12,9 +12,22 @@ class InferenceSessionWrap : public Napi::ObjectWrap { public: static Napi::Object Init(Napi::Env env, Napi::Object exports); + static Napi::FunctionReference& GetTensorConstructor(); + InferenceSessionWrap(const Napi::CallbackInfo& info); private: + /** + * [sync] initialize ONNX Runtime once. + * + * This function must be called before any other functions. + * + * @param arg0 a number specifying the log level. + * + * @returns undefined + */ + static Napi::Value InitOrtOnce(const Napi::CallbackInfo& info); + /** * [sync] list supported backend list * @returns array with objects { "name": "cpu", requirementsInstalled: true } @@ -63,10 +76,19 @@ class InferenceSessionWrap : public Napi::ObjectWrap { */ Napi::Value Dispose(const Napi::CallbackInfo& info); + /** + * [sync] end the profiling. + * @param nothing + * @returns nothing + * @throw nothing + */ + Napi::Value EndProfiling(const Napi::CallbackInfo& info); + // private members // persistent constructor - static Napi::FunctionReference constructor; + static Napi::FunctionReference wrappedSessionConstructor; + static Napi::FunctionReference ortTensorConstructor; // session objects bool initialized_; @@ -81,4 +103,8 @@ class InferenceSessionWrap : public Napi::ObjectWrap { std::vector outputNames_; std::vector outputTypes_; std::vector outputTensorElementDataTypes_; + + // preferred output locations + std::vector preferredOutputLocations_; + std::unique_ptr ioBinding_; }; diff --git a/js/node/src/session_options_helper.cc b/js/node/src/session_options_helper.cc index 0ed1ba08e6bf7..8c1d7ca06b8c3 100644 --- a/js/node/src/session_options_helper.cc +++ b/js/node/src/session_options_helper.cc @@ -6,15 +6,20 @@ #include #include +#include #include "common.h" #include "session_options_helper.h" +#include "tensor_helper.h" #ifdef USE_CUDA #include "core/providers/cuda/cuda_provider_options.h" #endif #ifdef USE_DML #include "core/providers/dml/dml_provider_factory.h" #endif +#ifdef USE_WEBGPU +#include "core/providers/webgpu/webgpu_provider_factory.h" +#endif #ifdef USE_TENSORRT #include "core/providers/tensorrt/tensorrt_provider_options.h" #endif @@ -36,7 +41,12 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess Napi::Value epValue = epList[i]; std::string name; int deviceId = 0; +#ifdef USE_COREML int coreMlFlags = 0; +#endif +#ifdef USE_WEBGPU + std::unordered_map webgpu_options; +#endif if (epValue.IsString()) { name = epValue.As().Utf8Value(); } else if (!epValue.IsObject() || epValue.IsNull() || !epValue.As().Has("name") || @@ -49,9 +59,23 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess if (obj.Has("deviceId")) { deviceId = obj.Get("deviceId").As(); } +#ifdef USE_COREML if (obj.Has("coreMlFlags")) { coreMlFlags = obj.Get("coreMlFlags").As(); } +#endif +#ifdef USE_WEBGPU + for (const auto& nameIter : obj.GetPropertyNames()) { + Napi::Value nameVar = nameIter.second; + std::string name = nameVar.As().Utf8Value(); + if (name != "name") { + Napi::Value valueVar = obj.Get(nameVar); + ORT_NAPI_THROW_TYPEERROR_IF(!valueVar.IsString(), epList.Env(), "Invalid argument: sessionOptions.executionProviders must be a string or an object with property 'name'."); + std::string value = valueVar.As().Utf8Value(); + webgpu_options[name] = value; + } + } +#endif } // CPU execution provider @@ -77,6 +101,10 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess } else if (name == "dml") { Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(sessionOptions, deviceId)); #endif +#ifdef USE_WEBGPU + } else if (name == "webgpu") { + sessionOptions.AppendExecutionProvider("WebGPU", webgpu_options); +#endif #ifdef USE_COREML } else if (name == "coreml") { Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(sessionOptions, coreMlFlags)); @@ -95,6 +123,22 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess } } +void IterateExtraOptions(const std::string& prefix, const Napi::Object& obj, Ort::SessionOptions& sessionOptions) { + for (const auto& kvp : obj) { + auto key = kvp.first.As().Utf8Value(); + Napi::Value value = kvp.second; + if (value.IsObject()) { + IterateExtraOptions(prefix + key + ".", value.As(), sessionOptions); + } else { + ORT_NAPI_THROW_TYPEERROR_IF(!value.IsString(), obj.Env(), + "Invalid argument: sessionOptions.extra value must be a string in Node.js binding."); + std::string entry = prefix + key; + auto val = value.As().Utf8Value(); + sessionOptions.AddConfigEntry(entry.c_str(), val.c_str()); + } + } +} + void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessionOptions) { // Execution provider if (options.Has("executionProviders")) { @@ -162,6 +206,28 @@ void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessio } } + // optimizedModelFilePath + if (options.Has("optimizedModelFilePath")) { + auto optimizedModelFilePathValue = options.Get("optimizedModelFilePath"); + ORT_NAPI_THROW_TYPEERROR_IF(!optimizedModelFilePathValue.IsString(), options.Env(), + "Invalid argument: sessionOptions.optimizedModelFilePath must be a string."); +#ifdef _WIN32 + auto str = optimizedModelFilePathValue.As().Utf16Value(); + std::filesystem::path optimizedModelFilePath{std::wstring{str.begin(), str.end()}}; +#else + std::filesystem::path optimizedModelFilePath{optimizedModelFilePathValue.As().Utf8Value()}; +#endif + sessionOptions.SetOptimizedModelFilePath(optimizedModelFilePath.c_str()); + } + + // extra + if (options.Has("extra")) { + auto extraValue = options.Get("extra"); + ORT_NAPI_THROW_TYPEERROR_IF(!extraValue.IsObject(), options.Env(), + "Invalid argument: sessionOptions.extra must be an object."); + IterateExtraOptions("", extraValue.As(), sessionOptions); + } + // execution mode if (options.Has("executionMode")) { auto executionModeValue = options.Get("executionMode"); @@ -195,4 +261,118 @@ void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessio sessionOptions.SetLogSeverityLevel(static_cast(logLevelNumber)); } + + // Profiling + if (options.Has("enableProfiling")) { + auto enableProfilingValue = options.Get("enableProfiling"); + ORT_NAPI_THROW_TYPEERROR_IF(!enableProfilingValue.IsBoolean(), options.Env(), + "Invalid argument: sessionOptions.enableProfiling must be a boolean value."); + + if (enableProfilingValue.As().Value()) { + ORT_NAPI_THROW_TYPEERROR_IF(!options.Has("profileFilePrefix"), options.Env(), + "Invalid argument: sessionOptions.profileFilePrefix is required" + " when sessionOptions.enableProfiling is set to true."); + auto profileFilePrefixValue = options.Get("profileFilePrefix"); + ORT_NAPI_THROW_TYPEERROR_IF(!profileFilePrefixValue.IsString(), options.Env(), + "Invalid argument: sessionOptions.profileFilePrefix must be a string." + " when sessionOptions.enableProfiling is set to true."); +#ifdef _WIN32 + auto str = profileFilePrefixValue.As().Utf16Value(); + std::basic_string profileFilePrefix = std::wstring{str.begin(), str.end()}; +#else + std::basic_string profileFilePrefix = profileFilePrefixValue.As().Utf8Value(); +#endif + sessionOptions.EnableProfiling(profileFilePrefix.c_str()); + } else { + sessionOptions.DisableProfiling(); + } + } + + // external data + if (options.Has("externalData")) { + auto externalDataValue = options.Get("externalData"); + ORT_NAPI_THROW_TYPEERROR_IF(!externalDataValue.IsArray(), options.Env(), + "Invalid argument: sessionOptions.externalData must be an array."); + auto externalData = externalDataValue.As(); + std::vector> paths; + std::vector buffs; + std::vector sizes; + + for (const auto& kvp : externalData) { + Napi::Value value = kvp.second; + ORT_NAPI_THROW_TYPEERROR_IF(!value.IsObject(), options.Env(), + "Invalid argument: sessionOptions.externalData value must be an object in Node.js binding."); + Napi::Object obj = value.As(); + ORT_NAPI_THROW_TYPEERROR_IF(!obj.Has("path") || !obj.Get("path").IsString(), options.Env(), + "Invalid argument: sessionOptions.externalData value must have a 'path' property of type string in Node.js binding."); +#ifdef _WIN32 + auto path = obj.Get("path").As().Utf16Value(); + paths.push_back(std::wstring{path.begin(), path.end()}); +#else + auto path = obj.Get("path").As().Utf8Value(); + paths.push_back(path); +#endif + ORT_NAPI_THROW_TYPEERROR_IF(!obj.Has("data") || + !obj.Get("data").IsBuffer() || + !(obj.Get("data").IsTypedArray() && obj.Get("data").As().TypedArrayType() == napi_uint8_array), + options.Env(), + "Invalid argument: sessionOptions.externalData value must have an 'data' property of type buffer or typed array in Node.js binding."); + + auto data = obj.Get("data"); + if (data.IsBuffer()) { + buffs.push_back(data.As>().Data()); + sizes.push_back(data.As>().Length()); + } else { + auto typedArray = data.As(); + buffs.push_back(reinterpret_cast(typedArray.ArrayBuffer().Data()) + typedArray.ByteOffset()); + sizes.push_back(typedArray.ByteLength()); + } + } + sessionOptions.AddExternalInitializersFromFilesInMemory(paths, buffs, sizes); + } +} + +void ParsePreferredOutputLocations(const Napi::Object options, const std::vector& outputNames, std::vector& preferredOutputLocations) { + if (options.Has("preferredOutputLocation")) { + auto polValue = options.Get("preferredOutputLocation"); + if (polValue.IsNull() || polValue.IsUndefined()) { + return; + } + if (polValue.IsString()) { + DataLocation location = ParseDataLocation(polValue.As().Utf8Value()); + ORT_NAPI_THROW_TYPEERROR_IF(location == DATA_LOCATION_NONE, options.Env(), + "Invalid argument: preferredOutputLocation must be an array or a valid string."); + + if (location == DATA_LOCATION_GPU_BUFFER || location == DATA_LOCATION_ML_TENSOR) { + preferredOutputLocations.resize(outputNames.size(), location); + } + } else if (polValue.IsObject()) { + preferredOutputLocations.resize(outputNames.size(), DATA_LOCATION_CPU); + + auto pol = polValue.As(); + for (const auto& nameIter : pol.GetPropertyNames()) { + Napi::Value nameVar = nameIter.second; + std::string name = nameVar.As().Utf8Value(); + // find the name in outputNames + auto it = std::find(outputNames.begin(), outputNames.end(), name); + ORT_NAPI_THROW_TYPEERROR_IF(it == outputNames.end(), options.Env(), + "Invalid argument: \"", name, "\" is not a valid output name."); + + Napi::Value value = pol.Get(nameVar); + DataLocation location = DATA_LOCATION_NONE; + ORT_NAPI_THROW_TYPEERROR_IF(!value.IsString() || (location = ParseDataLocation(value.As().Utf8Value())) == DATA_LOCATION_NONE, + options.Env(), + "Invalid argument: preferredOutputLocation[\"", name, "\"] must be a valid string."); + + size_t index = it - outputNames.begin(); + preferredOutputLocations[index] = location; + } + + if (std::all_of(preferredOutputLocations.begin(), preferredOutputLocations.end(), [](int loc) { return loc == DATA_LOCATION_CPU; })) { + preferredOutputLocations.clear(); + } + } else { + ORT_NAPI_THROW_TYPEERROR(options.Env(), "Invalid argument: preferredOutputLocation must be an array or a valid string."); + } + } } diff --git a/js/node/src/session_options_helper.h b/js/node/src/session_options_helper.h index c0a9ae0d683e9..c6338f28e03c6 100644 --- a/js/node/src/session_options_helper.h +++ b/js/node/src/session_options_helper.h @@ -11,3 +11,6 @@ struct SessionOptions; // parse a Javascript session options object and fill the native SessionOptions object. void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessionOptions); + +// parse a Javascript session options object and prepare the preferred output locations. +void ParsePreferredOutputLocations(const Napi::Object options, const std::vector& outputNames, std::vector& preferredOutputLocations); \ No newline at end of file diff --git a/js/node/src/tensor_helper.cc b/js/node/src/tensor_helper.cc index 54f1c5a09906e..27eb9b65c62d3 100644 --- a/js/node/src/tensor_helper.cc +++ b/js/node/src/tensor_helper.cc @@ -8,6 +8,7 @@ #include "common.h" #include "tensor_helper.h" +#include "inference_session_wrap.h" // make sure consistent with origin definition static_assert(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == 0, "definition not consistent with OnnxRuntime"); @@ -100,7 +101,7 @@ const std::unordered_map DATA_TYPE_NAME_ {"float32", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT}, {"uint8", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8}, {"int8", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8}, {"uint16", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16}, {"int16", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16}, {"int32", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32}, {"int64", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64}, {"string", ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING}, {"bool", ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL}, {"float16", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16}, {"float64", ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE}, {"uint32", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32}, {"uint64", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64}}; // currently only support tensor -Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* memory_info) { +Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* cpu_memory_info, OrtMemoryInfo* webgpu_memory_info) { ORT_NAPI_THROW_TYPEERROR_IF(!value.IsObject(), env, "Tensor must be an object."); // check 'dims' @@ -110,6 +111,7 @@ Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* auto dimsArray = dimsValue.As(); auto len = dimsArray.Length(); + size_t elementSize = 1; std::vector dims; if (len > 0) { dims.reserve(len); @@ -122,17 +124,26 @@ Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* "Tensor.dims[", i, "] is invalid: ", dimDouble); int64_t dim = static_cast(dimDouble); dims.push_back(dim); + elementSize *= dim; } } + // check 'location' + auto tensorLocationValue = tensorObject.Get("location"); + ORT_NAPI_THROW_TYPEERROR_IF(!tensorLocationValue.IsString(), env, "Tensor.location must be a string."); + DataLocation tensorLocation = ParseDataLocation(tensorLocationValue.As().Utf8Value()); + ORT_NAPI_THROW_RANGEERROR_IF(tensorLocation == DATA_LOCATION_NONE, env, "Tensor.location is not supported."); + // check 'data' and 'type' - auto tensorDataValue = tensorObject.Get("data"); auto tensorTypeValue = tensorObject.Get("type"); ORT_NAPI_THROW_TYPEERROR_IF(!tensorTypeValue.IsString(), env, "Tensor.type must be a string."); auto tensorTypeString = tensorTypeValue.As().Utf8Value(); if (tensorTypeString == "string") { + auto tensorDataValue = tensorObject.Get("data"); + + ORT_NAPI_THROW_TYPEERROR_IF(tensorLocation != DATA_LOCATION_CPU, env, "Tensor.location must be 'cpu' for string tensors."); ORT_NAPI_THROW_TYPEERROR_IF(!tensorDataValue.IsArray(), env, "Tensor.data must be an array for string tensors."); auto tensorDataArray = tensorDataValue.As(); @@ -162,29 +173,42 @@ Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* auto v = DATA_TYPE_NAME_TO_ID_MAP.find(tensorTypeString); ORT_NAPI_THROW_TYPEERROR_IF(v == DATA_TYPE_NAME_TO_ID_MAP.end(), env, "Tensor.type is not supported: ", tensorTypeString); - ONNXTensorElementDataType elemType = v->second; - ORT_NAPI_THROW_TYPEERROR_IF(!tensorDataValue.IsTypedArray(), env, - "Tensor.data must be a typed array for numeric tensor."); + if (tensorLocation == DATA_LOCATION_CPU) { + auto tensorDataValue = tensorObject.Get("data"); + ORT_NAPI_THROW_TYPEERROR_IF(!tensorDataValue.IsTypedArray(), env, + "Tensor.data must be a typed array for numeric tensor."); + + auto tensorDataTypedArray = tensorDataValue.As(); + auto typedArrayType = tensorDataValue.As().TypedArrayType(); + ORT_NAPI_THROW_TYPEERROR_IF(DATA_TYPE_TYPEDARRAY_MAP[elemType] != typedArrayType, env, + "Tensor.data must be a typed array (", DATA_TYPE_TYPEDARRAY_MAP[elemType], ") for ", + tensorTypeString, " tensors, but got typed array (", typedArrayType, ")."); - auto tensorDataTypedArray = tensorDataValue.As(); - auto typedArrayType = tensorDataValue.As().TypedArrayType(); - ORT_NAPI_THROW_TYPEERROR_IF(DATA_TYPE_TYPEDARRAY_MAP[elemType] != typedArrayType, env, - "Tensor.data must be a typed array (", DATA_TYPE_TYPEDARRAY_MAP[elemType], ") for ", - tensorTypeString, " tensors, but got typed array (", typedArrayType, ")."); + char* buffer = reinterpret_cast(tensorDataTypedArray.ArrayBuffer().Data()); + size_t bufferByteOffset = tensorDataTypedArray.ByteOffset(); + size_t bufferByteLength = tensorDataTypedArray.ByteLength(); + return Ort::Value::CreateTensor(cpu_memory_info, buffer + bufferByteOffset, bufferByteLength, + dims.empty() ? nullptr : &dims[0], dims.size(), elemType); + } else { + ORT_NAPI_THROW_TYPEERROR_IF(tensorLocation != DATA_LOCATION_GPU_BUFFER, env, "Tensor.location must be 'gpu-buffer' for IO binding."); - char* buffer = reinterpret_cast(tensorDataTypedArray.ArrayBuffer().Data()); - size_t bufferByteOffset = tensorDataTypedArray.ByteOffset(); - size_t bufferByteLength = tensorDataTypedArray.ByteLength(); - return Ort::Value::CreateTensor(memory_info, buffer + bufferByteOffset, bufferByteLength, - dims.empty() ? nullptr : &dims[0], dims.size(), elemType); + auto gpuBufferValue = tensorObject.Get("gpuBuffer"); + // nodejs: tensor.gpuBuffer is no longer a GPUBuffer in nodejs. we assume it is an external object (bind the OrtValue pointer). + ORT_NAPI_THROW_TYPEERROR_IF(!gpuBufferValue.IsExternal(), env, "Tensor.gpuBuffer must be an external object."); + Ort::Value dataValue(gpuBufferValue.As>().Data()); + void* gpuBuffer = dataValue.GetTensorMutableRawData(); + dataValue.release(); + + size_t dataByteLength = DATA_TYPE_ELEMENT_SIZE_MAP[elemType] * elementSize; + return Ort::Value::CreateTensor(webgpu_memory_info, gpuBuffer, dataByteLength, dims.empty() ? nullptr : &dims[0], dims.size(), elemType); + } } } -Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value& value) { +Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value&& value) { Napi::EscapableHandleScope scope(env); - auto returnValue = Napi::Object::New(env); auto typeInfo = value.GetTypeInfo(); auto onnxType = typeInfo.GetONNXType(); @@ -197,24 +221,26 @@ Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value& value) { // type auto typeCstr = DATA_TYPE_ID_TO_NAME_MAP[elemType]; ORT_NAPI_THROW_ERROR_IF(typeCstr == nullptr, env, "Tensor type (", elemType, ") is not supported."); - - returnValue.Set("type", Napi::String::New(env, typeCstr)); + auto type = Napi::String::New(env, typeCstr); // dims const size_t dimsCount = tensorTypeAndShapeInfo.GetDimensionsCount(); - std::vector dims; + std::vector dimsVector; if (dimsCount > 0) { - dims = tensorTypeAndShapeInfo.GetShape(); + dimsVector = tensorTypeAndShapeInfo.GetShape(); } - auto dimsArray = Napi::Array::New(env, dimsCount); + auto dims = Napi::Array::New(env, dimsCount); for (uint32_t i = 0; i < dimsCount; i++) { - dimsArray[i] = dims[i]; + dims[i] = dimsVector[i]; } - returnValue.Set("dims", dimsArray); + + // location + auto memoryInfo = value.GetTensorMemoryInfo(); + bool isGpuBuffer = memoryInfo.GetDeviceType() == OrtMemoryInfoDeviceType_GPU && + memoryInfo.GetAllocatorName() == "WebGPU_Buffer"; // size auto size = tensorTypeAndShapeInfo.GetElementCount(); - returnValue.Set("size", Napi::Number::From(env, size)); // data if (elemType == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { @@ -234,20 +260,48 @@ Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value& value) { i == size - 1 ? tempBufferLength - tempOffsets[i] : tempOffsets[i + 1] - tempOffsets[i]); } } - returnValue.Set("data", Napi::Value(env, stringArray)); + + // new Tensor("string", stringArray /* string[] */, dims /* number[] */) + return scope.Escape(InferenceSessionWrap::GetTensorConstructor().New({Napi::String::New(env, "string"), stringArray, dims})); } else { // number data - // TODO: optimize memory - auto arrayBuffer = Napi::ArrayBuffer::New(env, size * DATA_TYPE_ELEMENT_SIZE_MAP[elemType]); - if (size > 0) { - memcpy(arrayBuffer.Data(), value.GetTensorRawData(), size * DATA_TYPE_ELEMENT_SIZE_MAP[elemType]); + if (isGpuBuffer) { + // Tensor.fromGpuBuffer(buffer, options) + Napi::Function tensorFromGpuBuffer = InferenceSessionWrap::GetTensorConstructor().Value().Get("fromGpuBuffer").As(); + OrtValue* underlyingOrtValue = value.release(); + + auto options = Napi::Object::New(env); + options.Set("dataType", type); + options.Set("dims", dims); + options.Set("dispose", Napi::Function::New( + env, [](const Napi::CallbackInfo& info) { + Ort::GetApi().ReleaseValue(reinterpret_cast(info.Data())); + return info.Env().Undefined(); + }, + "dispose", underlyingOrtValue)); + options.Set("download", Napi::Function::New( + env, [](const Napi::CallbackInfo& info) { + NAPI_THROW("not implemented"); + }, + "download", underlyingOrtValue)); + + return scope.Escape(tensorFromGpuBuffer.Call({Napi::External::New(env, underlyingOrtValue), options})); + } else { + // TODO: optimize memory + auto arrayBuffer = Napi::ArrayBuffer::New(env, size * DATA_TYPE_ELEMENT_SIZE_MAP[elemType]); + if (size > 0) { + memcpy(arrayBuffer.Data(), value.GetTensorRawData(), size * DATA_TYPE_ELEMENT_SIZE_MAP[elemType]); + } + napi_value typedArrayData; + napi_status status = + napi_create_typedarray(env, DATA_TYPE_TYPEDARRAY_MAP[elemType], size, arrayBuffer, 0, &typedArrayData); + NAPI_THROW_IF_FAILED(env, status, Napi::Value); + + // new Tensor(type, typedArrayData, dims) + return scope.Escape(InferenceSessionWrap::GetTensorConstructor().New( + {type, + Napi::Value(env, typedArrayData), + dims})); } - napi_value typedArrayData; - napi_status status = - napi_create_typedarray(env, DATA_TYPE_TYPEDARRAY_MAP[elemType], size, arrayBuffer, 0, &typedArrayData); - NAPI_THROW_IF_FAILED(env, status, Napi::Value); - returnValue.Set("data", Napi::Value(env, typedArrayData)); } - - return scope.Escape(returnValue); } diff --git a/js/node/src/tensor_helper.h b/js/node/src/tensor_helper.h index 56b399ccc24ee..4a51e5240602a 100644 --- a/js/node/src/tensor_helper.h +++ b/js/node/src/tensor_helper.h @@ -9,7 +9,32 @@ #include "onnxruntime_cxx_api.h" // convert a Javascript OnnxValue object to an OrtValue object -Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* memory_info); +Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* cpu_memory_info, OrtMemoryInfo* webgpu_memory_info); // convert an OrtValue object to a Javascript OnnxValue object -Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value& value); +Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value&& value); + +enum DataLocation { + DATA_LOCATION_NONE = 0, + DATA_LOCATION_CPU = 1, + DATA_LOCATION_CPU_PINNED = 2, + DATA_LOCATION_TEXTURE = 3, + DATA_LOCATION_GPU_BUFFER = 4, + DATA_LOCATION_ML_TENSOR = 5 +}; + +inline DataLocation ParseDataLocation(const std::string& location) { + if (location == "cpu") { + return DATA_LOCATION_CPU; + } else if (location == "cpu-pinned") { + return DATA_LOCATION_CPU_PINNED; + } else if (location == "texture") { + return DATA_LOCATION_TEXTURE; + } else if (location == "gpu-buffer") { + return DATA_LOCATION_GPU_BUFFER; + } else if (location == "ml-tensor") { + return DATA_LOCATION_ML_TENSOR; + } else { + return DATA_LOCATION_NONE; + } +} diff --git a/js/node/tsconfig.json b/js/node/tsconfig.json index c154c3e148ed0..0401fb9609ad6 100644 --- a/js/node/tsconfig.json +++ b/js/node/tsconfig.json @@ -1,7 +1,8 @@ { "extends": "../tsconfig.json", "compilerOptions": { - "outDir": "dist" + "outDir": "dist", + "declaration": true }, "include": ["lib"] } diff --git a/js/package-lock.json b/js/package-lock.json index 594d0584ad80e..f4401c6e98c75 100644 --- a/js/package-lock.json +++ b/js/package-lock.json @@ -1573,9 +1573,9 @@ "dev": true }, "node_modules/cross-spawn": { - "version": "7.0.3", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", - "integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==", + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", "dev": true, "dependencies": { "path-key": "^3.1.0", @@ -5922,9 +5922,9 @@ "dev": true }, "cross-spawn": { - "version": "7.0.3", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", - "integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==", + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", "dev": true, "requires": { "path-key": "^3.1.0", diff --git a/js/react_native/android/build.gradle b/js/react_native/android/build.gradle index 825990eba0fb8..521866ff0f3e2 100644 --- a/js/react_native/android/build.gradle +++ b/js/react_native/android/build.gradle @@ -7,7 +7,7 @@ buildscript { } dependencies { - classpath 'com.android.tools.build:gradle:4.1.2' + classpath 'com.android.tools.build:gradle:7.4.2' // noinspection DifferentKotlinGradleVersion } } @@ -221,9 +221,8 @@ dependencies { api "com.facebook.react:react-native:" + REACT_NATIVE_VERSION api "org.mockito:mockito-core:2.28.2" - androidTestImplementation "androidx.test:runner:1.1.0" - androidTestImplementation "androidx.test:rules:1.1.0" - + androidTestImplementation "androidx.test:runner:1.5.2" + androidTestImplementation "androidx.test:rules:1.5.0" implementation "junit:junit:4.12" androidTestImplementation "com.linkedin.dexmaker:dexmaker-mockito-inline-extended:2.28.1" diff --git a/js/react_native/android/gradle.properties b/js/react_native/android/gradle.properties index 465b04d1f5813..8fe6e40d76911 100644 --- a/js/react_native/android/gradle.properties +++ b/js/react_native/android/gradle.properties @@ -4,7 +4,7 @@ # Specifies the JVM arguments used for the daemon process. # The setting is particularly useful for tweaking memory settings. # Default value: -Xmx1024m -XX:MaxPermSize=256m -# org.gradle.jvmargs=-Xmx2048m -XX:MaxPermSize=512m -XX:+HeapDumpOnOutOfMemoryError -Dfile.encoding=UTF-8 +org.gradle.jvmargs=-Xmx4096m -XX:+HeapDumpOnOutOfMemoryError -Dfile.encoding=UTF-8 # # When configured, Gradle will run in incubating parallel mode. # This option should only be used with decoupled projects. More details, visit diff --git a/js/react_native/android/gradle/wrapper/gradle-wrapper.jar b/js/react_native/android/gradle/wrapper/gradle-wrapper.jar index 62d4c053550b9..e6441136f3d4b 100644 Binary files a/js/react_native/android/gradle/wrapper/gradle-wrapper.jar and b/js/react_native/android/gradle/wrapper/gradle-wrapper.jar differ diff --git a/js/react_native/android/gradle/wrapper/gradle-wrapper.properties b/js/react_native/android/gradle/wrapper/gradle-wrapper.properties index 51d930a381f3a..381baa9cef1ec 100644 --- a/js/react_native/android/gradle/wrapper/gradle-wrapper.properties +++ b/js/react_native/android/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,8 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionSha256Sum=7faa7198769f872826c8ef4f1450f839ec27f0b4d5d1e51bade63667cbccd205 -distributionUrl=https\://services.gradle.org/distributions/gradle-6.8.3-bin.zip +distributionSha256Sum=544c35d6bd849ae8a5ed0bcea39ba677dc40f49df7d1835561582da2009b961d +distributionUrl=https\://services.gradle.org/distributions/gradle-8.7-bin.zip +networkTimeout=10000 +validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/js/react_native/android/gradlew b/js/react_native/android/gradlew index fbd7c515832da..1aa94a4269074 100755 --- a/js/react_native/android/gradlew +++ b/js/react_native/android/gradlew @@ -1,7 +1,7 @@ -#!/usr/bin/env sh +#!/bin/sh # -# Copyright 2015 the original author or authors. +# Copyright © 2015-2021 the original authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,67 +17,99 @@ # ############################################################################## -## -## Gradle start up script for UN*X -## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# ############################################################################## # Attempt to set APP_HOME + # Resolve links: $0 may be a link -PRG="$0" -# Need this for relative symlinks. -while [ -h "$PRG" ] ; do - ls=`ls -ld "$PRG"` - link=`expr "$ls" : '.*-> \(.*\)$'` - if expr "$link" : '/.*' > /dev/null; then - PRG="$link" - else - PRG=`dirname "$PRG"`"/$link" - fi +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac done -SAVED="`pwd`" -cd "`dirname \"$PRG\"`/" >/dev/null -APP_HOME="`pwd -P`" -cd "$SAVED" >/dev/null -APP_NAME="Gradle" -APP_BASE_NAME=`basename "$0"` - -# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' +# This is normally unused +# shellcheck disable=SC2034 +APP_BASE_NAME=${0##*/} +# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) +APP_HOME=$( cd "${APP_HOME:-./}" > /dev/null && pwd -P ) || exit # Use the maximum available, or set MAX_FD != -1 to use that value. -MAX_FD="maximum" +MAX_FD=maximum warn () { echo "$*" -} +} >&2 die () { echo echo "$*" echo exit 1 -} +} >&2 # OS specific support (must be 'true' or 'false'). cygwin=false msys=false darwin=false nonstop=false -case "`uname`" in - CYGWIN* ) - cygwin=true - ;; - Darwin* ) - darwin=true - ;; - MINGW* ) - msys=true - ;; - NONSTOP* ) - nonstop=true - ;; +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; esac CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar @@ -87,9 +119,9 @@ CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar if [ -n "$JAVA_HOME" ] ; then if [ -x "$JAVA_HOME/jre/sh/java" ] ; then # IBM's JDK on AIX uses strange locations for the executables - JAVACMD="$JAVA_HOME/jre/sh/java" + JAVACMD=$JAVA_HOME/jre/sh/java else - JAVACMD="$JAVA_HOME/bin/java" + JAVACMD=$JAVA_HOME/bin/java fi if [ ! -x "$JAVACMD" ] ; then die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME @@ -98,88 +130,120 @@ Please set the JAVA_HOME variable in your environment to match the location of your Java installation." fi else - JAVACMD="java" - which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + JAVACMD=java + if ! command -v java >/dev/null 2>&1 + then + die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. Please set the JAVA_HOME variable in your environment to match the location of your Java installation." + fi fi # Increase the maximum file descriptors if we can. -if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then - MAX_FD_LIMIT=`ulimit -H -n` - if [ $? -eq 0 ] ; then - if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then - MAX_FD="$MAX_FD_LIMIT" - fi - ulimit -n $MAX_FD - if [ $? -ne 0 ] ; then - warn "Could not set maximum file descriptor limit: $MAX_FD" - fi - else - warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" - fi +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac fi -# For Darwin, add options to specify how the application appears in the dock -if $darwin; then - GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" -fi +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. # For Cygwin or MSYS, switch paths to Windows format before running java -if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then - APP_HOME=`cygpath --path --mixed "$APP_HOME"` - CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` - - JAVACMD=`cygpath --unix "$JAVACMD"` - - # We build the pattern for arguments to be converted via cygpath - ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` - SEP="" - for dir in $ROOTDIRSRAW ; do - ROOTDIRS="$ROOTDIRS$SEP$dir" - SEP="|" - done - OURCYGPATTERN="(^($ROOTDIRS))" - # Add a user-defined pattern to the cygpath arguments - if [ "$GRADLE_CYGPATTERN" != "" ] ; then - OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" - fi +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + # Now convert the arguments - kludge to limit ourselves to /bin/sh - i=0 - for arg in "$@" ; do - CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` - CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option - - if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition - eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` - else - eval `echo args$i`="\"$arg\"" + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) fi - i=`expr $i + 1` + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg done - case $i in - 0) set -- ;; - 1) set -- "$args0" ;; - 2) set -- "$args0" "$args1" ;; - 3) set -- "$args0" "$args1" "$args2" ;; - 4) set -- "$args0" "$args1" "$args2" "$args3" ;; - 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; - 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; - 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; - 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; - 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; - esac fi -# Escape application args -save () { - for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done - echo " " -} -APP_ARGS=`save "$@"` -# Collect all arguments for the java command, following the shell quoting and substitution rules -eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Collect all arguments for the java command: +# * DEFAULT_JVM_OPTS, JAVA_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments, +# and any embedded shellness will be escaped. +# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be +# treated as '${Hostname}' itself on the command line. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' exec "$JAVACMD" "$@" diff --git a/js/react_native/android/gradlew.bat b/js/react_native/android/gradlew.bat index 5093609d512a9..25da30dbdeee9 100644 --- a/js/react_native/android/gradlew.bat +++ b/js/react_native/android/gradlew.bat @@ -14,7 +14,7 @@ @rem limitations under the License. @rem -@if "%DEBUG%" == "" @echo off +@if "%DEBUG%"=="" @echo off @rem ########################################################################## @rem @rem Gradle startup script for Windows @@ -25,7 +25,8 @@ if "%OS%"=="Windows_NT" setlocal set DIRNAME=%~dp0 -if "%DIRNAME%" == "" set DIRNAME=. +if "%DIRNAME%"=="" set DIRNAME=. +@rem This is normally unused set APP_BASE_NAME=%~n0 set APP_HOME=%DIRNAME% @@ -40,13 +41,13 @@ if defined JAVA_HOME goto findJavaFromJavaHome set JAVA_EXE=java.exe %JAVA_EXE% -version >NUL 2>&1 -if "%ERRORLEVEL%" == "0" goto init +if %ERRORLEVEL% equ 0 goto execute -echo. -echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. +echo. 1>&2 +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 goto fail @@ -54,31 +55,16 @@ goto fail set JAVA_HOME=%JAVA_HOME:"=% set JAVA_EXE=%JAVA_HOME%/bin/java.exe -if exist "%JAVA_EXE%" goto init +if exist "%JAVA_EXE%" goto execute -echo. -echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. +echo. 1>&2 +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 goto fail -:init -@rem Get command-line arguments, handling Windows variants - -if not "%OS%" == "Windows_NT" goto win9xME_args - -:win9xME_args -@rem Slurp the command line arguments. -set CMD_LINE_ARGS= -set _SKIP=2 - -:win9xME_args_slurp -if "x%~1" == "x" goto execute - -set CMD_LINE_ARGS=%* - :execute @rem Setup the command line @@ -86,17 +72,19 @@ set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar @rem Execute Gradle -"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* :end @rem End local scope for the variables with windows NT shell -if "%ERRORLEVEL%"=="0" goto mainEnd +if %ERRORLEVEL% equ 0 goto mainEnd :fail rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of rem the _cmd.exe /c_ return code! -if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 -exit /b 1 +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% :mainEnd if "%OS%"=="Windows_NT" endlocal diff --git a/js/react_native/e2e/android/app/build.gradle b/js/react_native/e2e/android/app/build.gradle index 8a84b0d5065a8..526259e3f8d8f 100644 --- a/js/react_native/e2e/android/app/build.gradle +++ b/js/react_native/e2e/android/app/build.gradle @@ -193,7 +193,7 @@ dependencies { implementation "com.facebook.react:react-native:+" // From node_modules implementation "androidx.swiperefreshlayout:swiperefreshlayout:1.0.0" - implementation 'androidx.test.ext:junit:1.1.3' + implementation 'androidx.test.ext:junit:1.1.5' debugImplementation("com.facebook.flipper:flipper:${FLIPPER_VERSION}") { exclude group:'com.facebook.fbjni' } @@ -213,9 +213,9 @@ dependencies { implementation jscFlavor } - androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0' - androidTestImplementation 'androidx.test:runner:1.4.0' - androidTestImplementation 'androidx.test:rules:1.4.0' + androidTestImplementation "androidx.test.espresso:espresso-core:3.5.0" + androidTestImplementation "androidx.test:runner:1.5.2" + androidTestImplementation "androidx.test:rules:1.5.0" implementation project(':onnxruntime-react-native') // specify ORT dependency here so it can be found in libs flatDir repository diff --git a/js/scripts/prepare-onnx-node-tests.ts b/js/scripts/prepare-onnx-node-tests.ts index 91aa63e9e6aff..02c33892d57d5 100644 --- a/js/scripts/prepare-onnx-node-tests.ts +++ b/js/scripts/prepare-onnx-node-tests.ts @@ -10,6 +10,8 @@ import * as path from 'path'; import { downloadZip, extractFile } from './utils'; const TEST_DATA_OPSET_VERSIONS = [ + ['opset21', '1.16.2'], + ['opset20', '1.15.0'], ['opset19', '1.14.0'], ['opset18', '1.13.1'], ['opset17', '1.12.1'], diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 1c140de448430..5c8748d75c2bc 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -21,11 +21,11 @@ Do not modify directly.* | Atan | ai.onnx(7+) | | | Atanh | ai.onnx(9+) | | | Attention | com.microsoft(1+) | need implementing mask and past/present | -| AveragePool | ai.onnx(7-9,10,11+); com.ms.internal.nhwc(7-9,10,11+) | need perf optimization; need implementing activation | +| AveragePool | ai.onnx(7-9,10,11-18,19+); com.ms.internal.nhwc(7-9,10,11-18,19+) | need perf optimization; need implementing activation | | BatchNormalization | ai.onnx(7-8,9-13,14,15+); com.ms.internal.nhwc(7-8,9-13,14,15+) | | | BiasAdd | com.microsoft(1+) | | | BiasSplitGelu | com.microsoft(1+) | | -| Cast | ai.onnx(6-8,9-12,13-18,19+) | | +| Cast | ai.onnx(6-8,9-12,13-18,19-20,21+) | | | Ceil | ai.onnx(6-12,13+) | | | Clip | ai.onnx(6-10,11,12,13+) | | | Concat | ai.onnx(1-3,4-10,11-12,13+) | | @@ -44,21 +44,23 @@ Do not modify directly.* | Exp | ai.onnx(6-12,13+) | | | Expand | ai.onnx(8-12,13+) | | | FastGelu | com.microsoft(1+) | | -| Flatten | ai.onnx(1-8,9-10,11-12,13+) | | +| Flatten | ai.onnx(1-8,9-10,11-12,13-20,21+) | | | Floor | ai.onnx(6-12,13+) | | | FusedConv | com.microsoft(1+) | | | Gather | ai.onnx(1-10,11-12,13+) | | | GatherBlockQuantized | com.microsoft(1+) | | | GatherElements | ai.onnx(11-12,13+) | | +| GatherND | ai.onnx(11,12,13+) | | | Gelu | ai.onnx(20+); com.microsoft(1+) | | | Gemm | ai.onnx(7-8,9-10,11-12,13+) | | | GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | | GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | | Greater | ai.onnx(7-8,9-12,13+) | | | GreaterOrEqual | ai.onnx(12-15,16+) | | +| GridSample | ai.onnx(16-19); com.ms.internal.nhwc(16-19) | | | GroupQueryAttention | com.microsoft(1+) | | | HardSigmoid | ai.onnx(6+) | | -| If | ai.onnx(1-10,11-12,13-18,19+) | | +| If | ai.onnx(1-10,11-12,13-18,19-20,21+) | | | InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | | | LayerNormalization | ai.onnx(1-16,17+) | | | LeakyRelu | ai.onnx(6-15,16+) | | @@ -74,7 +76,7 @@ Do not modify directly.* | MultiHeadAttention | com.microsoft(1+) | need implementing mask and past/present | | Neg | ai.onnx(6-12,13+) | | | Not | ai.onnx(1+) | | -| Pad | ai.onnx(2-10,11-12,13-17,18,19+) | | +| Pad | ai.onnx(2-10,11-12,13-17,18,19-20,21+) | | | Pow | ai.onnx(7-11,12,13-14,15+) | | | QuickGelu | com.microsoft(1+) | | | Range | ai.onnx(11+) | | @@ -83,9 +85,9 @@ Do not modify directly.* | ReduceL2 | ai.onnx(1-10,11-12,13-17,18+) | | | ReduceLogSum | ai.onnx(1-10,11-12,13-17,18+) | | | ReduceLogSumExp | ai.onnx(1-10,11-12,13-17,18+) | | -| ReduceMax | ai.onnx(1-10,11,12,13-17,18+) | | +| ReduceMax | ai.onnx(1-10,11,12,13-17,18-19,20+) | | | ReduceMean | ai.onnx(1-10,11-12,13-17,18+) | | -| ReduceMin | ai.onnx(1-10,11,12,13-17,18+) | | +| ReduceMin | ai.onnx(1-10,11,12,13-17,18-19,20+) | | | ReduceProd | ai.onnx(1-10,11-12,13-17,18+) | | | ReduceSum | ai.onnx(1-10,11-12,13+) | | | ReduceSumSquare | ai.onnx(1-10,11-12,13-17,18+) | | @@ -93,6 +95,7 @@ Do not modify directly.* | Reshape | ai.onnx(5-12,13,14-18,19-20,21+) | no GPU kernel | | Resize | ai.onnx(10,11-12,13-17,18,19+); com.ms.internal.nhwc(10,11-12,13-17,18,19+) | CoordinateTransformMode align_corners is not supported with downsampling | | RotaryEmbedding | com.microsoft(1+) | | +| ScatterND | ai.onnx(11-12,13-15,16-17,18+) | | | Shape | ai.onnx(1-12,13-14,15-18,19-20,21+) | no GPU kernel; an ORT warning is generated - need to fix | | Sigmoid | ai.onnx(6-12,13+) | | | SimplifiedLayerNormalization | ai.onnx(1+) | | @@ -104,12 +107,12 @@ Do not modify directly.* | Softmax | ai.onnx(1-10,11-12,13+) | | | Split | ai.onnx(1,2-10,11-12,13-17,18+) | | | Sqrt | ai.onnx(6-12,13+) | | -| Squeeze | ai.onnx(1-10,11-12,13+) | | +| Squeeze | ai.onnx(1-10,11-12,13-20,21+) | | | Sub | ai.onnx(7-12,13,14+) | | | Tan | ai.onnx(7+) | | | Tanh | ai.onnx(6-12,13+) | | | ThresholdedRelu | ai.onnx(10+) | | | Tile | ai.onnx(6-12,13+) | | -| Transpose | ai.onnx(1-12,13+) | need perf optimization | -| Unsqueeze | ai.onnx(1-10,11-12,13+) | | +| Transpose | ai.onnx(1-12,13-20,21+) | need perf optimization | +| Unsqueeze | ai.onnx(1-10,11-12,13-20,21+) | | | Where | ai.onnx(9-15,16+) | | diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index bf0f1dffb83ee..e0012e70a7dec 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -13,6 +13,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim |:------:|:------:|:------:|:-:|:-:|:------| | Abs | ai.onnx(7-12, 13+) | abs | ✓ | ✓ | | | Add | ai.onnx(7-12, 13, 14+) | add | ✓ | ✓ | | +| And | ai.onnx(7+) | logicalAnd | ✗ | ✓ | | | ArgMax | ai.onnx(7-10, 11, 12, 13+) | argMax | ✓ | ✓ | | | ArgMin | ai.onnx(7-10, 11, 12, 13+) | argMin | ✓ | ✓ | | | AveragePool | ai.onnx(7-9, 10, 11, 12-18, 19+) | averagePool2d | ✓ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'count_include_pad' value is 0 | @@ -24,9 +25,11 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | Conv | ai.onnx(7-10, 11+) | conv2d | ✓ | ✓ | Only supports 3-D or 4-D input and 'W' (weight) | | ConvTranspose | ai.onnx(7-10, 11+) | convTranspose2d | ✓ | ✓ | Only supports 3-D or 4-D input and 'W' (weight). WebNN CPU backend only supports default dilations and group | | Cos | ai.onnx(7+) | cos | ✓ | ✓ | | +| CumSum | ai.onnx(11-13, 14+) | cumulativeSum | ✓ | ✓ | 'axis' input should be a constant | | Div | ai.onnx(7-12, 13, 14+) | div | ✓ | ✓ | | | DequantizeLinear | ai.onnx(10-12, 13-18, 19-20, 21-22, 23+) | dequantizeLinear | ✗ | ✓ | | | Dropout | ai.onnx(7-9, 10-11, 12, 13-21, 22+) | identity | ✓ | ✓ | Only supports test mode | +| Einsum | ai.onnx(12+) | reshape, transpose, matmul, reduceSum, mul, triangular | ✓ | ✓ | | | Elu | ai.onnx(7+) | elu | ✓ | ✓ | WebNN CPU backend only supports 'alpha' value is 1.0 | | Equal | ai.onnx(7-10, 11-12, 13-18, 19+) | equal | ✓ | ✓ | | | Erf | ai.onnx(7-9, 10-12, 13+) | erf | ✓ | ✓ | | @@ -35,6 +38,8 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | Flatten | ai.onnx(7-8, 9-10, 11-12, 13-20, 21+) | reshape | ✓ | ✓ | | | Floor | ai.onnx(7-12, 13+) | floor | ✓ | ✓ | | | Gather | ai.onnx(7-10, 11-12, 13+) | gather | ✓ | ✓ | | +| GatherElements | ai.onnx(11-12, 13+) | gatherElements | ✗ | ✓ | | +| GatherND | ai.onnx(11, 12, 13+) | gatherND | ✓ | ✓ | Only supports 'batch_dims' == 0 | | Gelu | ai.onnx(20+) | gelu | ✓ | ✓ | | | Gemm | ai.onnx(7-8, 9-10, 11-12, 13+) | gemm | ✓ | ✓ | Only supports 1-D 'C' input | | GlobalAveragePool | ai.onnx(7+) | averagePool2d | ✓ | ✓ | Only supports 4-D input | @@ -53,6 +58,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | LessOrEqual | ai.onnx(12-15, 16+) | lesserOrEqual | ✓ | ✓ | | | Log | ai.onnx(7-12, 13+) | log | ✓ | ✓ | | | LpPool | ai.onnx(7-10, 11-17, 18+) | l2Pool2d | ✗ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'p' value is 2 | +| LRN | ai.onnx(7-12, 13+) | pad, averagePool2d, transpose, add, mul, pow, div | ✓ | ✓ | | | LSTM | ai.onnx(7-13, 14-21, 22+) | lstm | ✓ | ✓ | Only supports 'layout' == 0, 'input_forget' == 0. 'clip' is not supported. The activation functions in 'activations' must be one of 'Relu', 'Tanh', 'Sigmoid'. Forward and backward activations must be the same if bidirectional. 'sequence_lens' if present should be constant with values equal to the first dimension length of input 'X' | | MatMul | ai.onnx(7-8, 9-12, 13+) | matmul | ✓ | ✓ | | | Max | ai.onnx(7, 8-11, 12, 13+) | max | ✓ | ✓ | | @@ -60,7 +66,8 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | Min | ai.onnx(7, 8-11, 12, 13+) | min | ✓ | ✓ | | | Mul | ai.onnx(7-12, 13, 14+) | mul | ✓ | ✓ | | | Neg | ai.onnx(7-12, 13+) | neg | ✓ | ✓ | | -| Not | ai.onnx(7+) | logicalnot | ✓ | ✓ | | +| Not | ai.onnx(7+) | logicalNot | ✓ | ✓ | | +| Or | ai.onnx(7+) | logicalOr | ✗ | ✓ | | | Pad | ai.onnx(7-10, 11-12, 13-17, 18, 19-20, 21+) | pad | ✓ | ✓ | modes == 'wrap' is not supported | | Pow | ai.onnx(7-11, 12, 13-14, 15+) | pow | ✓ | ✓ | | | PRelu | ai.onnx(7-8, 9-15, 16+) | prelu | ✓ | ✓ | WebNN CPU backend restricts the last dimension of input and slope to be same (Chromium issue: https://issues.chromium.org/issues/335517470) | @@ -78,13 +85,17 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | ReduceSumSquare | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceSumSquare | ✓ | ✓ | Input 'axes' if present should be a constant | | Relu | ai.onnx(7-12, 13, 14+) | relu | ✓ | ✓ | | | Reshape | ai.onnx(7-12, 13, 14-18, 19-20, 21+) | reshape | ✓ | ✓ | Input 'shape' should be a constant, 0 dimension value in 'shape' is not supported | -| Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d | ✓ | ✓ | Only supports 4-D input, antialias == 0, coordinate_transformation_mode == 'half_pixel', exclude_outside == 0, keep_aspect_ratio_policy == 'stretch', 'linear' and 'nearest' modes, input 'scales' and 'sizes' if present must be a constant | +| Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d | ✓ | ✓ | Only supports 4-D input, antialias == 0, exclude_outside == 0, keep_aspect_ratio_policy == 'stretch', 'linear' and 'nearest' modes, input 'scales' and 'sizes' if present must be a constant | +| ScatterElements | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterElements | ✗ | ✓ | Only supports 'reduction' == 'none' | +| ScatterND | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterND | ✗ | ✓ | Only supports 'reduction' == 'none' | | Shape | ai.onnx(7-12, 13-14, 15-18, 19-20, 21+) | slice | ✓ | ✓ | | +| SimplifiedLayerNormalization | ai.onnx(1+) | pow + reduceMean + add + sqrt + div + mul | ✓ | ✓ | | | Sigmoid | ai.onnx(7-12, 13+) | sigmoid | ✓ | ✓ | | +| Sign | ai.onnx(9-12, 13+) | sign | ✓ | ✓ | | | Softplus | ai.onnx(7+) | softplus | ✓ | ✓ | | | Softsign | ai.onnx(7+) | softsign | ✓ | ✓ | | | Sin | ai.onnx(7+) | sin | ✓ | ✓ | | -| Slice | ai.onnx(7-9, 10, 11-12, 13+) | slice | ✓ | ✓ | Input 'starts', 'ends', 'axes', and 'steps' if present must be a constant, only supports 'steps' value 1 | +| Slice | ai.onnx(7-9, 10, 11-12, 13+) | slice, reverse | ✓ | ✓ | Input 'starts', 'ends', 'axes', and 'steps' if present must be a constant | | Softmax | ai.onnx(7-10, 11-12, 13+) | softmax | ✓ | ✓ | | | Split | ai.onnx(7-10, 11-12, 13-17, 18+) | split | ✓ | ✓ | Input 'split' if present should be a constant | | Sqrt | ai.onnx(7-12, 13+) | sqrt | ✓ | ✓ | | @@ -97,3 +108,4 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | Trilu | ai.onnx(14+) | triangular | ✓ | ✓ | Input 'k' (option 'diagonal' for WebNN) if present should be a constant | | Unsqueeze | ai.onnx(7-10, 11-12, 13-20, 21+) | reshape | ✓ | ✓ | | | Where | ai.onnx(7-8, 9-15, 16+) | where | ✓ | ✓ | | +| Xor | ai.onnx(7+) | logicalXor | ✗ | ✓ | | diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 50d83f5af26e0..a0010df4643a4 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -13,6 +13,7 @@ import { ProgramManager } from './webgpu/program-manager'; import { AdapterInfo, ComputeContext, + DeviceInfo, GpuArchitecture, GpuData, GpuVendor, @@ -134,6 +135,26 @@ class AdapterInfoImpl implements AdapterInfo { } } +class DeviceInfoImpl implements DeviceInfo { + readonly subgroupsSupported: boolean; + readonly subgroupsF16Supported: boolean; + readonly subgroupSizeRange?: readonly [number, number]; + + constructor(device: GPUDevice) { + this.subgroupsSupported = device.features.has('subgroups' as GPUFeatureName); + this.subgroupsF16Supported = device.features.has('subgroups' as GPUFeatureName); + // Currently subgroups feature is still experimental and size attributes are not in the WebGPU IDL, so we have to + // workaround the IDL type checks. + // TODO: clean this after subgroups feature is settled in IDL. + const deviceSubgroupsLimits = device.limits as { minSubgroupSize?: number; maxSubgroupSize?: number }; + if (!this.subgroupsSupported || !deviceSubgroupsLimits.minSubgroupSize || !deviceSubgroupsLimits.maxSubgroupSize) { + this.subgroupSizeRange = undefined; + } else { + this.subgroupSizeRange = [deviceSubgroupsLimits.minSubgroupSize, deviceSubgroupsLimits.maxSubgroupSize]; + } + } +} + /** * this class is designed to store status and being used as a singleton for JSEP. It will be passed to jsepInit() as * the first parameter so that it is stored for future use. @@ -141,6 +162,7 @@ class AdapterInfoImpl implements AdapterInfo { export class WebGpuBackend { adapterInfo: AdapterInfoImpl; device: GPUDevice; + deviceInfo: DeviceInfoImpl; /** * an instance of GpuDataManager to manage a GpuDataId -> GpuBuffer mapping */ @@ -243,16 +265,22 @@ export class WebGpuBackend { requiredFeatures, }; - if (adapter.features.has('chromium-experimental-timestamp-query-inside-passes')) { - requiredFeatures.push('chromium-experimental-timestamp-query-inside-passes' as GPUFeatureName); - } else if (adapter.features.has('timestamp-query')) { - requiredFeatures.push('timestamp-query'); + // Try requiring WebGPU features + const requireFeatureIfAvailable = (feature: GPUFeatureName) => + adapter.features.has(feature) && requiredFeatures.push(feature) && true; + // Try chromium-experimental-timestamp-query-inside-passes and fallback to timestamp-query + if (!requireFeatureIfAvailable('chromium-experimental-timestamp-query-inside-passes' as GPUFeatureName)) { + requireFeatureIfAvailable('timestamp-query'); } - if (adapter.features.has('shader-f16')) { - requiredFeatures.push('shader-f16'); + requireFeatureIfAvailable('shader-f16'); + // Try subgroups + if (requireFeatureIfAvailable('subgroups' as GPUFeatureName)) { + // If subgroups feature is available, also try subgroups-f16 + requireFeatureIfAvailable('subgroups-f16' as GPUFeatureName); } this.device = await adapter.requestDevice(deviceDescriptor); + this.deviceInfo = new DeviceInfoImpl(this.device); this.adapterInfo = new AdapterInfoImpl(adapter.info || (await adapter.requestAdapterInfo())); this.gpuDataManager = createGpuDataManager(this); this.programManager = new ProgramManager(this); diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts index 47304fdc64ae4..b302354c46eeb 100644 --- a/js/web/lib/wasm/jsep/backend-webnn.ts +++ b/js/web/lib/wasm/jsep/backend-webnn.ts @@ -32,6 +32,24 @@ const onnxDataTypeToWebnnDataType = new Map([ [DataType.bool, 'uint8'], ]); +type MLContextEntry = { + gpuDevice?: GPUDevice; + options?: MLContextOptions; + mlContext: MLContext; +}; + +const compareMLContextOptions = (a?: MLContextOptions, b?: MLContextOptions): boolean => { + if (a === b) { + return true; + } + if (a === undefined || b === undefined) { + return false; + } + const aKeys = Object.keys(a).sort() as Array; + const bKeys = Object.keys(b).sort() as Array; + return aKeys.length === bKeys.length && aKeys.every((key, index) => key === bKeys[index] && a[key] === b[key]); +}; + /** * WebNN backend implementation. This class is used to keep track of the MLTensors created by the backend and keep track * of the current MLContext being used by the sessions. @@ -49,6 +67,10 @@ export class WebNNBackend { * Maps from MLContext to session ids. */ private sessionIdsByMLContext = new Map>(); + /** + * Cache of MLContexts. + */ + private mlContextCache: MLContextEntry[] = []; /** * Current session id. */ @@ -69,6 +91,41 @@ export class WebNNBackend { this.activeSessionId = sessionId; } + public async createMLContext(optionsOrDevice?: MLContextOptions | GPUDevice): Promise { + if (optionsOrDevice instanceof GPUDevice) { + const mlContextIndex = this.mlContextCache.findIndex((entry) => entry.gpuDevice === optionsOrDevice); + if (mlContextIndex !== -1) { + return this.mlContextCache[mlContextIndex].mlContext; + } else { + const mlContext = await navigator.ml.createContext(optionsOrDevice); + this.mlContextCache.push({ gpuDevice: optionsOrDevice, mlContext }); + return mlContext; + } + } else if (optionsOrDevice === undefined) { + const mlContextIndex = this.mlContextCache.findIndex( + (entry) => entry.options === undefined && entry.gpuDevice === undefined, + ); + if (mlContextIndex !== -1) { + return this.mlContextCache[mlContextIndex].mlContext; + } else { + const mlContext = await navigator.ml.createContext(); + this.mlContextCache.push({ mlContext }); + return mlContext; + } + } + + const mlContextIndex = this.mlContextCache.findIndex((entry) => + compareMLContextOptions(entry.options, optionsOrDevice), + ); + if (mlContextIndex !== -1) { + return this.mlContextCache[mlContextIndex].mlContext; + } else { + const mlContext = await navigator.ml.createContext(optionsOrDevice); + this.mlContextCache.push({ options: optionsOrDevice, mlContext }); + return mlContext; + } + } + public get currentContext(): MLContext { const mlContext = this.getMLContext(this.currentSessionId); if (!mlContext) { @@ -99,6 +156,10 @@ export class WebNNBackend { sessionIds.delete(sessionId); if (sessionIds.size === 0) { this.sessionIdsByMLContext.delete(mlContext); + const mlContextIndex = this.mlContextCache.findIndex((entry) => entry.mlContext === mlContext); + if (mlContextIndex !== -1) { + this.mlContextCache.splice(mlContextIndex, 1); + } } } @@ -165,7 +226,7 @@ export class WebNNBackend { return id; } - // Register WebNN Constant operands from external data. + // Register a WebNN Constant operand from external data. public registerMLConstant( externalFilePath: string, dataOffset: number, diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index fddc061cd775a..48bd3ef2bc36f 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -11,7 +11,13 @@ import { WebGpuBackend } from './backend-webgpu'; import { LOG_DEBUG } from './log'; import { TensorView } from './tensor-view'; import { ShapeUtil } from './util'; -import { AdapterInfo, ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo } from './webgpu/types'; +import { + AdapterInfo, + ComputeContext, + ComputeContextInputsOutputsMapping, + DeviceInfo, + ProgramInfo, +} from './webgpu/types'; import { WebNNBackend } from './backend-webnn'; /* eslint-disable no-bitwise */ @@ -70,6 +76,7 @@ class TensorViewImpl implements TensorView { class ComputeContextImpl implements ComputeContext { readonly adapterInfo: AdapterInfo; + readonly deviceInfo: DeviceInfo; readonly opKernelContext: number; readonly inputs: readonly TensorView[]; readonly outputCount: number; @@ -87,6 +94,7 @@ class ComputeContextImpl implements ComputeContext { contextDataOffset: number, ) { this.adapterInfo = backend.adapterInfo; + this.deviceInfo = backend.deviceInfo; // extract context data const ptrSize = module.PTR_SIZE; @@ -112,18 +120,6 @@ class ComputeContextImpl implements ComputeContext { this.inputs = inputs; } - getMaxComputeWorkgroupSizes(): [number, number, number] { - return [ - this.backend.device.limits.maxComputeWorkgroupSizeX, - this.backend.device.limits.maxComputeWorkgroupSizeY, - this.backend.device.limits.maxComputeWorkgroupSizeZ, - ]; - } - - getMaxComputeWorkgroupStoragesize(): number { - return this.backend.device.limits.maxComputeWorkgroupStorageSize; - } - compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[] { // prepare inputs. inputs should always be valid data. const mappedInputs = diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index 1860870a1130b..1c6016500e7d3 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -191,8 +191,6 @@ class GpuDataManagerImpl implements GpuDataManager { // GPU Data ID => GPU Data ( storage buffer ) private storageCache: Map; - // pending buffers for uploading ( data is unmapped ) - private buffersForUploadingPending: GPUBuffer[]; // pending buffers for computing private buffersPending: GPUBuffer[]; @@ -212,7 +210,6 @@ class GpuDataManagerImpl implements GpuDataManager { this.storageCache = new Map(); this.freeBuffers = new Map(); this.freeUniformBuffers = new Map(); - this.buffersForUploadingPending = []; this.buffersPending = []; this.capturedPendingBuffers = new Map(); @@ -252,13 +249,12 @@ class GpuDataManagerImpl implements GpuDataManager { gpuBufferForUploading.unmap(); // GPU copy - const commandEncoder = this.backend.getCommandEncoder(); - this.backend.endComputePass(); + const commandEncoder = this.backend.device.createCommandEncoder(); commandEncoder.copyBufferToBuffer(gpuBufferForUploading, 0, gpuDataCache.gpuData.buffer, 0, size); + this.backend.device.queue.submit([commandEncoder.finish()]); + gpuBufferForUploading.destroy(); LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.upload(id=${id})`); - - this.buffersForUploadingPending.push(gpuBufferForUploading); } memcpy(sourceId: GpuDataId, destinationId: GpuDataId): void { @@ -395,12 +391,6 @@ class GpuDataManagerImpl implements GpuDataManager { } refreshPendingBuffers(): void { - for (const buffer of this.buffersForUploadingPending) { - // upload buffer is only useful in the session creation time. So we don't need to reuse them in session running. - buffer.destroy(); - } - this.buffersForUploadingPending = []; - if (this.buffersPending.length === 0) { return; } diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 09c786daa3fcd..6c7afbc7365bb 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -16,9 +16,11 @@ import { einsum, parseEinsumAttributes } from './ops/einsum'; import { expand } from './ops/expand'; import { fastGelu } from './ops/fast-gelu'; import { gather, parseGatherAttributes } from './ops/gather'; +import { gatherND, parseGatherNDAttributes } from './ops/gather-nd'; import { gatherBlockQuantized, parseGatherBlockQuantizedAttributes } from './ops/gather-block-quantized'; import { gatherElements, parseGatherElementsAttributes } from './ops/gather-elements'; import { gemm, parseGemmAttributes } from './ops/gemm'; +import { gridSample, parseGridSampleAttributes } from './ops/grid-sample'; import { groupQueryAttention } from './ops/group-query-attention'; import { instanceNorm } from './ops/instance-norm'; import { layerNorm } from './ops/layer-norm'; @@ -29,6 +31,7 @@ import { pad } from './ops/pad'; import * as pool from './ops/pool'; import { dequantizeLinear, parseDequantizeLinearAttributes } from './ops/quantize-linear'; import { range } from './ops/range'; +import { scatterND, parseScatterNDAttributes } from './ops/scatter-nd'; import { reduceL1, reduceL2, @@ -98,12 +101,14 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Gather', [gather, parseGatherAttributes]], ['GatherElements', [gatherElements, parseGatherElementsAttributes]], ['GatherBlockQuantized', [gatherBlockQuantized, parseGatherBlockQuantizedAttributes]], + ['GatherND', [gatherND, parseGatherNDAttributes]], ['Gelu', [unaryOps.gelu]], ['Gemm', [gemm, parseGemmAttributes]], ['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]], ['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]], ['Greater', [binaryOps.greater]], ['GreaterOrEqual', [binaryOps.greaterOrEqual]], + ['GridSample', [gridSample, parseGridSampleAttributes]], ['GroupQueryAttention', [groupQueryAttention]], ['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]], ['InstanceNormalization', [instanceNorm]], @@ -138,6 +143,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Relu', [unaryOps.relu]], ['Resize', [resize, parseResizeAttributes]], ['RotaryEmbedding', [rotaryEmbedding]], + ['ScatterND', [scatterND, parseScatterNDAttributes]], ['Sigmoid', [unaryOps.sigmoid]], ['Sin', [unaryOps.sin]], ['Sinh', [unaryOps.sinh]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 3ef5c943d5624..9e21a552b8466 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -143,7 +143,21 @@ const conv2dCommonSnippet = ( } return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`; - const sampleW = `${getWSnippet(innerElementSizeW)}`; + const sampleW = isChannelsLast + ? fitInner && fitBOuter + ? getWSnippet(innerElementSizeW) + : ` + let col = colIn * ${innerElementSizeW}; + if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) { + ${getWSnippet(innerElementSizeW)} + } + return ${typeSnippet(innerElementSizeW, dataType)}(0.0);` + : ` + let col = colIn * ${innerElementSizeW}; + if (row < uniforms.dim_inner && col < uniforms.dim_a_outer) { + ${getWSnippet(innerElementSizeW)} + } + return ${typeSnippet(innerElementSizeW, dataType)}(0.0);`; const resType = typeSnippet(innerElementSize, dataType); const aType = isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index 2a8756e435b8e..cb1f30ecdd1f4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -29,229 +29,27 @@ import { ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType, + getMaxComponents, } from '../common'; import { ConvTransposeAttributes } from '../conv-transpose'; -const createConvTranspose2DOpProgramShaderSource = ( - shaderHelper: ShaderHelper, - inputs: readonly TensorView[], - outputShape: readonly number[], - hasBias: boolean, - is1DimensionDispatch: boolean, - isVec4 = false, - dataType: string, - uniforms: UniformsArrayType, - isChannelsLast = false, -): string => { - const rowDim = isChannelsLast ? 1 : 2; - const colDim = isChannelsLast ? 2 : 3; - const channelDim = isChannelsLast ? 3 : 1; - const workPerThread = isVec4 ? 2 : 1; - - let declareFunctions = ` - fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? `vec4<${dataType}>` : dataType}) { - result[flatIndex] = ${isVec4 ? `vec4<${dataType}>` : dataType}(value); - }`; - if (hasBias) { - declareFunctions += ` - fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? `vec4<${dataType}>` : dataType} { - return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; - }`; - } - const components = isVec4 ? 4 : 1; - const w = inputVariable('W', inputs[1].dataType, inputs[1].dims.length, components); - const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims.length, components); - const inputVariables = [dy, w]; - if (hasBias) { - inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]].length, components)); - } - const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); - - const codeSnippet4 = `{ - let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / uniforms.result_shape[1]; - let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % uniforms.result_shape[1]; - let c = ${is1DimensionDispatch ? 'global_id.y' : 'workgroup_id.y'} * ${workPerThread}; - let d1: u32 = ${is1DimensionDispatch ? 'global_id.x' : 'workgroup_id.x'} * 4; - - let dyCorner = vec2(i32(r), i32(c)) - vec2(uniforms.pads); - - // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). - // ? = to be determined. : = across all values in that axis. - var dotProd: array, ${workPerThread}>; - for (var i = 0; i < ${workPerThread}; i++) { - dotProd[i] = vec4<${dataType}>(0.0); - } - for (var wR: u32 = 0; wR < uniforms.filter_dims[0]; wR = wR + 1) { - var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(uniforms.strides.x); - let wRPerm = uniforms.filter_dims[0] - 1 - wR; - if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[1]) || - fract(dyR) > 0.0 || wRPerm < 0) { - continue; - } - let idyR: u32 = u32(dyR); - - for (var wC: u32 = 0; wC < uniforms.filter_dims[1]; wC = wC + 1) { - let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); - let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); - let wCPerm = uniforms.filter_dims[1] - 1 - wC; - if (wCPerm < 0) { - continue; - } - var bDyCVal = true; - var bDyCVal2 = true; - if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[2]) || - fract(dyC) > 0.0) { - bDyCVal = false; - } - if (dyC2 < 0.0 || dyC2 >= ${dataType}(uniforms.Dy_shape[2]) || - fract(dyC2) > 0.0) { - bDyCVal2 = false; - } - - let idyC: u32 = u32(dyC); - let idyC2: u32 = u32(dyC2); - if (bDyCVal && bDyCVal2) { - let d2Length = uniforms.Dy_shape[3]; - for (var d2 :u32 = 0; d2 < d2Length; d2 = d2 + 4) { - let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; - let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; - let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')}; - let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; - - var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')}; - let tmpval = vec4<${dataType}>(dot(xValue, wValue0), - dot(xValue, wValue1), - dot(xValue, wValue2), - dot(xValue, wValue3)); - dotProd[0] = dotProd[0] + tmpval; - - xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')}; - - dotProd[1] = dotProd[1] + vec4<${dataType}>(dot(xValue, wValue0), - dot(xValue, wValue1), - dot(xValue, wValue2), - dot(xValue, wValue3)); - } - } else if (bDyCVal) { - let d2Length = uniforms.Dy_shape[${channelDim}]; - for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) { - let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; - let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; - let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')}; - let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; - - var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')}; - let tmpval = vec4<${dataType}>(dot(xValue, wValue0), - dot(xValue, wValue1), - dot(xValue, wValue2), - dot(xValue, wValue3)); - dotProd[0] = dotProd[0] + tmpval; - } - } else if (bDyCVal2) { - let d2Length = uniforms.Dy_shape[3]; - for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) { - let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; - let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; - let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')}; - let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; - - var xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')}; - let tmpval = vec4<${dataType}>(dot(xValue, wValue0), - dot(xValue, wValue1), - dot(xValue, wValue2), - dot(xValue, wValue3)); - dotProd[1] = dotProd[1] + tmpval; - } - } - } - } - - for (var i: u32 = 0; i < ${workPerThread}; i = i + 1) { - let value = dotProd[i] + ${hasBias ? 'bias[c+i]' : `vec4<${dataType}>(0.0)`}; - ${output.set('batch', 'r', 'c + i', 'd1', 'value')}; - } - }`; - const codeSnippet = ` - let outputIndices = ${output.offsetToIndices('global_idx')}; - let batch = ${output.indicesGet('outputIndices', 0)}; - let d1 = ${output.indicesGet('outputIndices', channelDim)}; - let r = ${output.indicesGet('outputIndices', rowDim)}; - let c = ${output.indicesGet('outputIndices', colDim)}; - let dyCorner = vec2(i32(r), i32(c)) - uniforms.pads; - let dyRCorner = dyCorner.x; - let dyCCorner = dyCorner.y; - let groupId = d1 / uniforms.output_channels_per_group; - let wOutChannel = d1 - groupId * uniforms.output_channels_per_group; - // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). - // ? = to be determined. : = across all values in that axis. - var dotProd = ${dataType}(0.0); - for (var wR: u32 = 0; wR < uniforms.effective_filter_dims.x; wR = wR + 1) { - if (wR % uniforms.dilations.x != 0) { - continue; - } - let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(uniforms.strides[0]); - let wRPerm = uniforms.filter_dims.x - 1 - wR / uniforms.dilations.x; - if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[${rowDim}]) || fract(dyR) > 0.0 || - wRPerm < 0) { - continue; - } - let idyR: u32 = u32(dyR); - - for (var wC: u32 = 0; wC < uniforms.effective_filter_dims.y; wC = wC + 1) { - if (wC % uniforms.dilations.y != 0) { - continue; - } - let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); - let wCPerm = uniforms.filter_dims.y - 1 - wC / uniforms.dilations.y; - if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[${colDim}]) || - fract(dyC) > 0.0 || wCPerm < 0) { - continue; - } - let idyC: u32 = u32(dyC); - var inputChannel = groupId * uniforms.input_channels_per_group; - for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + 1) { - let xValue = ${ - isChannelsLast - ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') - : dy.get('batch', 'inputChannel', 'idyR', 'idyC') - }; - let wValue = ${w.get('inputChannel', 'wOutChannel', 'u32(wRPerm)', 'u32(wCPerm)')}; - dotProd = dotProd + xValue * wValue; - inputChannel = inputChannel + 1; - } - } - } - let value = dotProd + ${hasBias ? 'bias[d1]' : `${dataType}(0.0)`}; - ${output.setByOffset('global_idx', 'value')}; - `; - - return ` - ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} - ${declareFunctions} - - ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}; - ${isVec4 ? codeSnippet4 : codeSnippet}}`; -}; - export const createConvTranspose2DProgramInfo = ( inputs: readonly TensorView[], attributes: ConvTransposeAttributes, squeezeOutputShapeFunction?: (shape: readonly number[]) => number[], ): ProgramInfo => { const hasBias = inputs.length > 2; - // const isChannelsLast = attributes.format === 'NHWC'; const outputShape = attributes.outputShape; - const outputSize = ShapeUtil.size(outputShape); - - // const inChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; - // TODO Enable isVec4 for performance - // Disabled due to weight matrix layout issue - // const isVec4 = attributes.group === 1 && isChannelsLast && inChannels % 4 === 0 && outChannels % 4 === 0; + const isChannelsLast = attributes.format === 'NHWC'; + const group = attributes.group; + const wShape = inputs[1].dims; + const inputChannelsPerGroup = wShape[2] / group; + const outputChannelsPerGroup = wShape[3]; + const components = isChannelsLast ? getMaxComponents(outputChannelsPerGroup) : 1; + const outputSize = ShapeUtil.size(outputShape) / components; const dispatch = [Math.ceil(outputSize / 64), 1, 1]; LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`); - const isChannelsLast = attributes.format === 'NHWC'; const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; const strides = [attributes.strides[0], attributes.strides[1]]; const filterDims = [attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]]; @@ -268,15 +66,9 @@ export const createConvTranspose2DProgramInfo = ( ]; const pads = [ effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2), - effectiveFilterDims[1] - 1 - Math.floor(attributes.pads[1] + attributes.pads[3]) / 2, + effectiveFilterDims[1] - 1 - Math.floor((attributes.pads[1] + attributes.pads[3]) / 2), ]; - const isVec4 = false; - const group = attributes.group; - const wShape = inputs[1].dims; - const inputChannelsPerGroup = wShape[0] / group; - const outputChannelsPerGroup = wShape[1]; - const programUniforms: ProgramUniform[] = [ { type: DataType.uint32, data: outputSize }, { type: DataType.uint32, data: strides }, @@ -294,7 +86,6 @@ export const createConvTranspose2DProgramInfo = ( } programUniforms.push(...createTensorShapeVariables(outputShape)); - const is1DimensionDispatch = dispatch[1] === 1 && dispatch[2] === 1; const getShaderSource = (shaderHelper: ShaderHelper) => { const uniforms: UniformsArrayType = [ { name: 'output_size', type: 'u32' }, @@ -307,21 +98,83 @@ export const createConvTranspose2DProgramInfo = ( { name: 'output_channels_per_group', type: 'u32' }, ]; const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - return `${createConvTranspose2DOpProgramShaderSource( - shaderHelper, - inputs, - outputShape, - hasBias, - is1DimensionDispatch, - isVec4, - dataType, - uniforms, - isChannelsLast, - )}`; + const rowDim = isChannelsLast ? 1 : 2; + const colDim = isChannelsLast ? 2 : 3; + const channelDim = isChannelsLast ? 3 : 1; + + const w = inputVariable('W', inputs[1].dataType, inputs[1].dims.length, components); + const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims.length); + const inputVariables = [dy, w]; + if (hasBias) { + inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]].length, components)); + } + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + + const codeSnippet = ` + let outputIndices = ${output.offsetToIndices(`global_idx * ${components}`)}; + let batch = ${output.indicesGet('outputIndices', 0)}; + let d1 = ${output.indicesGet('outputIndices', channelDim)}; + let r = ${output.indicesGet('outputIndices', rowDim)}; + let c = ${output.indicesGet('outputIndices', colDim)}; + let dyCorner = vec2(i32(r), i32(c)) - uniforms.pads; + let dyRCorner = dyCorner.x; + let dyCCorner = dyCorner.y; + let groupId = d1 / uniforms.output_channels_per_group; + let wOutChannel = d1 - groupId * uniforms.output_channels_per_group; + // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). + // ? = to be determined. : = across all values in that axis. + var dotProd = ${output.type.value}(0.0); + for (var wR: u32 = 0; wR < uniforms.effective_filter_dims.x; wR = wR + 1) { + if (wR % uniforms.dilations.x != 0) { + continue; + } + let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(uniforms.strides[0]); + let wRPerm = uniforms.filter_dims.x - 1 - wR / uniforms.dilations.x; + if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[${rowDim}]) || fract(dyR) > 0.0 || + wRPerm < 0) { + continue; + } + let idyR: u32 = u32(dyR); + + for (var wC: u32 = 0; wC < uniforms.effective_filter_dims.y; wC = wC + 1) { + if (wC % uniforms.dilations.y != 0) { + continue; + } + let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); + let wCPerm = uniforms.filter_dims.y - 1 - wC / uniforms.dilations.y; + if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[${colDim}]) || + fract(dyC) > 0.0 || wCPerm < 0) { + continue; + } + let idyC: u32 = u32(dyC); + var inputChannel = groupId * uniforms.input_channels_per_group; + for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + 1) { + let xValue = ${ + isChannelsLast + ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') + : dy.get('batch', 'inputChannel', 'idyR', 'idyC') + }; + let w_offset = ${w.indicesToOffset(`${w.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)}; + let wValue = ${w.getByOffset(`w_offset / ${components}`)}; + dotProd = dotProd + xValue * wValue; + inputChannel = inputChannel + 1; + } + } + } + let value = dotProd${hasBias ? ` + bias[d1 / ${components}]` : ''}; + ${output.setByOffset('global_idx', 'value')}; + `; + + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}; + ${codeSnippet}}`; }; + return { name: 'ConvTranspose2D', - shaderCache: { hint: `${attributes.cacheKey};`, inputDependencies }, + shaderCache: { hint: `${attributes.cacheKey};${components}`, inputDependencies }, getRunData: () => ({ dispatchGroup: { x: dispatch[0], y: dispatch[1], z: dispatch[2] }, outputs: [ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index f0287529ca08b..c6341f94cf191 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -25,7 +25,6 @@ import { ShapeUtil } from '../../../util'; import { ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../../types'; import { createTensorShapeVariables, - getBroadcastDims, IndicesHelper, inputVariable, internalVariable, @@ -40,6 +39,7 @@ import { getActivationSnippet, InternalActivationAttributes, } from '../fuse-utils'; +import { convertOutputBatchIndicesToInputBatchIndices } from '../matmul-shaders'; import { typeSnippet } from './activation_util'; @@ -373,42 +373,11 @@ const matMulReadWriteFnSource = ( hasBias: boolean, applyActivation: string, variables: IndicesHelper[], - batchShapes: Array, isChannelsLast = false, ): string => { - const [batchAShape, batchBShape, batchShape] = batchShapes; const [batchVariable, aVariable, bVariable, outputVariable] = variables; - const broadCastADims = getBroadcastDims(batchAShape, batchShape); - const broadCastBDims = getBroadcastDims(batchBShape, batchShape); const dataType = tensorTypeToWsglStorageType(variables[0].type.tensor); - const getAIndices = () => { - const aRank = aVariable.rank; - const batchRank = batchVariable.rank; - let resStr = `var aIndices: ${aVariable.type.indices};`; - for (let i = aRank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) { - resStr += `\naIndices[${i}] = ${batchRank > 1 ? `batchIndices[${j}]` : 'batchIndices'};`; - } - broadCastADims.forEach((i) => { - resStr += `\naIndices[${i}] = 0;`; - }); - resStr += `\naIndices[${aRank - 2}] = u32(row); - aIndices[${aRank - 1}] = u32(colIn);`; - return resStr; - }; - const getBIndices = () => { - const bRank = bVariable.rank; - const batchRank = batchVariable.rank; - let resStr = `var bIndices: ${bVariable.type.indices};`; - for (let i = bRank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) { - resStr += `\nbIndices[${i}] = ${batchRank > 1 ? `batchIndices[${j}]` : 'batchIndices'};`; - } - broadCastBDims.forEach((i) => { - resStr += `\nbIndices[${i}] = 0;`; - }); - resStr += `\nbIndices[${bRank - 2}] = u32(row); - bIndices[${bRank - 1}] = u32(colIn);`; - return resStr; - }; + const source = ` fn mm_readA(batch: i32, row: i32, colIn: i32, batchIndices: ${batchVariable.type.indices}) -> ${typeSnippet( component, @@ -418,7 +387,16 @@ const matMulReadWriteFnSource = ( let col = colIn * ${component}; if(row < uniforms.dim_a_outer && col < uniforms.dim_inner) { - ${getAIndices()} + var aIndices: ${aVariable.type.indices}; + ${convertOutputBatchIndicesToInputBatchIndices( + 'aIndices', + aVariable, + aVariable.rank - 2, + batchVariable.rank, + 'batchIndices', + )} + ${aVariable.indicesSet('aIndices', aVariable.rank - 2, 'u32(row)')} + ${aVariable.indicesSet('aIndices', aVariable.rank - 1, 'u32(colIn)')} value = ${aVariable.getByIndices('aIndices')}; } return value; @@ -432,7 +410,16 @@ const matMulReadWriteFnSource = ( let col = colIn * ${component}; if(row < uniforms.dim_inner && col < uniforms.dim_b_outer) { - ${getBIndices()} + var bIndices: ${bVariable.type.indices}; + ${convertOutputBatchIndicesToInputBatchIndices( + 'bIndices', + bVariable, + bVariable.rank - 2, + batchVariable.rank, + 'batchIndices', + )} + ${bVariable.indicesSet('bIndices', bVariable.rank - 2, 'u32(row)')} + ${bVariable.indicesSet('bIndices', bVariable.rank - 1, 'u32(colIn)')} value = ${bVariable.getByIndices('bIndices')}; } return value; @@ -532,7 +519,6 @@ export const createMatmulProgramInfo = ( hasBias, applyActivation, [batchDims, A, B, output], - [outerDimsA, outerDimsB, outerDims], isChannelsLast, ); return ` diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 793f26fe901e3..0b9173403cd7d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -195,7 +195,7 @@ export interface IndicesHelper { /** * whether the helper is for an input, an output or an internal variable. */ - readonly usage: 'input' | 'output' | 'internal'; + readonly usage: 'input' | 'output' | 'atomicOutput' | 'internal'; /** * the rank of the input or output. @@ -733,6 +733,20 @@ export const outputVariable = ( components: 1 | 2 | 3 | 4 = 1, ): IndicesHelper => createIndicesHelper(name, type, shapeOrRank, 'output', components); +/** + * Create a IndicesHelper for an atomic output. + * + * @param name - the name of the output. + * @param type - the tensor type of the output. + * @param shapeOrRank - the tensor shape or the rank of the output. + * @returns an IndicesHelper for the output. + */ +export const atomicOutputVariable = ( + name: string, + type: number, + shapeOrRank: number | readonly number[], +): IndicesHelper => createIndicesHelper(name, type, shapeOrRank, 'atomicOutput', 1); + /** * Create a IndicesHelper for an internal variable. * @@ -905,9 +919,8 @@ class ShaderHelperImpl implements ShaderHelper { } this.variables.push(variable); this.appendVariableUniforms(variable); - const access = variable.usage === 'input' ? 'read' : 'read_write'; - const storageType = variable.type.storage; + const storageType = variable.usage === 'atomicOutput' ? `atomic` : variable.type.storage; return `@group(0) @binding(${bindingIndex}) var ${variable.name}: array<${storageType}>;`; } @@ -996,27 +1009,3 @@ class ShaderHelperImpl implements ShaderHelper { export const createShaderHelper = (dispatchGroup: [number, number, number], limits: GPUSupportedLimits) => new ShaderHelperImpl(dispatchGroup, limits); - -/** - * This function comes from https://github.com/tensorflow/tfjs/blob/master/tfjs-core/src/ops/broadcast_util.ts#L18-L40 - * Returns the dimensions in the input shape that are broadcasted to - * produce the provided output shape. - * - * The returned dimensions are 0-indexed and sorted. An example: - * inShape = [4, 1, 3] - * outShape = [5, 4, 3, 3] - * result = [1]. Dimension 1 (2nd dimension of input) gets broadcasted 1 => 3. - */ -export const getBroadcastDims = (inShape: readonly number[], outShape: readonly number[]): number[] => { - const inRank = inShape.length; - const dims: number[] = []; - for (let i = 0; i < inRank; i++) { - const dim = inRank - 1 - i; - const a = inShape[dim] || 1; - const b = outShape[outShape.length - 1 - i] || 1; - if (b > 1 && a === 1) { - dims.unshift(dim); - } - } - return dims; -}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index 236f1b09a6c93..3e168ddedac86 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -4,7 +4,6 @@ import { TensorView } from '../../tensor-view'; import { ComputeContext } from '../types'; -import { createConv2DTransposeMatMulProgramInfo } from './3rd-party/conv_backprop_mm_webgpu'; import { createConvTranspose2DProgramInfo } from './3rd-party/conv_backprop_webgpu'; import { ConvAttributes } from './conv'; import { parseInternalActivationAttributes } from './fuse-utils'; @@ -227,41 +226,16 @@ const validateInputs = (inputs: readonly TensorView[], attributes: ConvTranspose } }; -// for transposing weight tensor from [C, M/group, KH, KW] to [KH, KW, M/group, C] -const weightTransposePerm = [2, 3, 1, 0]; - const convTranspose2d = ( context: ComputeContext, inputs: readonly TensorView[], attributes: ConvTransposeAttributes, + squeezeOutputShapeFunction?: (shape: readonly number[]) => number[], ): void => { - const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs); - const isChannelsLast = attributes.format === 'NHWC'; - const outputShape = adjustedAttributes.outputShape; - const outChannels = outputShape[isChannelsLast ? 3 : 1]; - const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; - // Switch to naive method when outChannels and inputChannels are very small. It's because that in this case it's - // not suitable for matmul version since matmul uses tile size 32x32 resulting the underlying execution unit - // utilization rate is very low. - if (adjustedAttributes.group !== 1 || (outChannels === 1 && inputChannels === 1)) { - context.compute(createConvTranspose2DProgramInfo(inputs, adjustedAttributes)); - return; - } - const outHeight = outputShape[isChannelsLast ? 1 : 2]; - const outWidth = outputShape[isChannelsLast ? 2 : 3]; - const weightHeight = inputs[1].dims[2]; - const weightWidth = inputs[1].dims[3]; - - const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels; - const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth; - const dimInner = weightHeight * weightWidth * inputChannels; - - const sequentialAccessByThreads = /* backend.adapterInfo.isIntel() */ true; - // STEP.1: transpose weight const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? - context.compute(createTransposeProgramInfo(inputs[1], weightTransposePerm), { + context.compute(createTransposeProgramInfo(inputs[1], [2, 3, 0, 1]), { inputs: [1], outputs: [attributes.wIsConst ? -2 : -1], })[0]; @@ -271,29 +245,12 @@ const convTranspose2d = ( // STEP.2: prepare reshaped inputs const convTransposeInputs = [inputs[0], transposedWeight]; - const hasBias = inputs.length === 3; - if (hasBias) { - if (!isChannelsLast && inputs[2].dims.length === 1) { - convTransposeInputs.push(inputs[2].reshape([inputs[2].dims[0], 1, 1])); - } else { - convTransposeInputs.push(inputs[2]); - } + if (inputs.length === 3) { + convTransposeInputs.push(inputs[2]); } - - // STEP.3: compute matmul - context.compute( - createConv2DTransposeMatMulProgramInfo( - convTransposeInputs, - adjustedAttributes, - outputShape, - dimAOuter, - dimBOuter, - dimInner, - hasBias, - sequentialAccessByThreads, - ), - { inputs: convTransposeInputs }, - ); + context.compute(createConvTranspose2DProgramInfo(convTransposeInputs, attributes, squeezeOutputShapeFunction), { + inputs: convTransposeInputs, + }); }; const convTranspose1d = (context: ComputeContext, attributes: ConvTransposeAttributes): void => { @@ -338,12 +295,9 @@ const convTranspose1d = (context: ComputeContext, attributes: ConvTransposeAttri { ...attributes, pads, strides, dilations, kernelShape }, inputs, ); - context.compute( - createConvTranspose2DProgramInfo(inputs, adjustedAttributes, (outputShape) => - isChannelLast - ? [outputShape[0], outputShape[2], outputShape[3]] - : [outputShape[0], outputShape[1], outputShape[3]], - ), + + convTranspose2d(context, inputs, adjustedAttributes, (outputShape) => + isChannelLast ? [outputShape[0], outputShape[2], outputShape[3]] : [outputShape[0], outputShape[1], outputShape[3]], ); }; @@ -352,6 +306,7 @@ export const convTranspose = (context: ComputeContext, attributes: ConvTranspose if (context.inputs[0].dims.length === 3) { convTranspose1d(context, attributes); } else { - convTranspose2d(context, context.inputs, attributes); + const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, context.inputs); + convTranspose2d(context, context.inputs, adjustedAttributes); } }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index de9f7bc8885ab..f9225baf66eea 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -11,7 +11,7 @@ import { computeConv3DInfo, createConv3DNaiveProgramInfo } from './3rd-party/con import { createMatmulProgramInfo } from './3rd-party/matmul_packed_webgpu'; import { createGroupedConvProgramInfo, createGroupedConvVectorizeProgramInfo } from './conv-grouped'; import { InternalActivationAttributes, parseInternalActivationAttributes } from './fuse-utils'; -import { createNaiveMatmulProgramInfo } from './matmul'; +import { createNaiveMatmulProgramInfo } from './matmul-shaders'; import { createTransposeProgramInfo } from './transpose'; export const calculateOutputShape = ( diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts index 4e2bfa9d89924..3691b5ecb602b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts @@ -48,11 +48,18 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => const shape = Array.from(inputs[1].getBigInt64Array(), Number); const outputShape: number[] = calculateOutputShape(inputShape, shape); const dataType = inputs[0].dataType; - const components = dataType === DataType.bool ? 4 : 1; + const isBoolOrScalar = dataType === DataType.bool || ShapeUtil.size(inputShape) === 1; + const iComponents = + dataType === DataType.bool ? 4 : inputShape.length > 0 && inputShape[inputShape.length - 1] % 4 === 0 ? 4 : 1; + const components = isBoolOrScalar + ? 4 + : outputShape.length > 0 && outputShape[outputShape.length - 1] % 4 === 0 + ? 4 + : 1; const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components); const getShaderSource = (shaderHelper: ShaderHelper) => { - const input = inputVariable('input', dataType, inputShape.length, components); + const input = inputVariable('input', dataType, inputShape.length, iComponents); const output = outputVariable('output', dataType, outputShape.length, components); let assignment: string; if (dataType === DataType.bool) { @@ -74,9 +81,10 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => }`; } else { assignment = ` - let outputIndices = ${output.offsetToIndices('global_idx')}; + let outputIndices = ${output.offsetToIndices(`global_idx * ${components}`)}; let inputOffset = ${input.broadcastedIndicesToOffset('outputIndices', output)}; - ${output.setByOffset('global_idx', input.getByOffset('inputOffset'))} + let data = ${output.type.value}(${input.getByOffset(`inputOffset / ${iComponents}`)}); + ${output.setByOffset('global_idx', 'data')} }`; } return ` @@ -92,7 +100,7 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => ]; return { name: 'Expand', - shaderCache: { hint: `${outputShape.length}`, inputDependencies: ['rank'] }, + shaderCache: { hint: `${outputShape.length};${iComponents}${components}`, inputDependencies: ['rank'] }, getShaderSource, getRunData: () => ({ outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather-nd.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather-nd.ts new file mode 100644 index 0000000000000..43b51f6e94a66 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather-nd.ts @@ -0,0 +1,179 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { AttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramUniform } from '../types'; + +import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from './common'; + +export interface GatherNDAttributes extends AttributeWithCacheKey { + readonly batchDims: number; +} + +const computeSliceOffsets = ( + context: ComputeContext, + indicesData: TensorView, + sizesFromSliceDimsData: number[], + batchDims: number, + inputDims: readonly number[], + numSlices: number, + numSlicesPerBatch: number, + inputBatchStride: number, + numSliceDims: number, +) => { + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: numSlices }, + { type: DataType.uint32, data: batchDims }, + { type: DataType.uint32, data: inputDims }, + { type: DataType.uint32, data: sizesFromSliceDimsData }, + { type: DataType.uint32, data: numSlicesPerBatch }, + { type: DataType.uint32, data: inputBatchStride }, + { type: DataType.uint32, data: numSliceDims }, + ]; + + const outputShape = [numSlices]; + programUniforms.push(...createTensorShapeVariables(indicesData.dims, outputShape)); + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const indices = inputVariable('indices_data', indicesData.dataType, indicesData.dims.length); + const output = outputVariable('input_slice_offsets_data', DataType.uint32, 1, 1); + const variables = [indices, output]; + const uniforms: UniformsArrayType = [ + { name: 'output_size', type: 'u32' }, + { name: 'batch_dims', type: 'u32' }, + { name: 'input_dims', type: 'u32', length: inputDims.length }, + { name: 'sizes_from_slice_dims_data', type: 'u32', length: sizesFromSliceDimsData.length }, + { name: 'num_slices_per_batch', type: 'u32' }, + { name: 'input_batch_stride', type: 'u32' }, + { name: 'num_slice_dims', type: 'u32' }, + ]; + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + let batch_idx = global_idx / uniforms.num_slices_per_batch; + let base_offset = batch_idx * uniforms.input_batch_stride; + + let slice_indices_base_offset = global_idx * uniforms.num_slice_dims; + var relative_slice_offset = 0; + for (var dim_idx = 0u; dim_idx < uniforms.num_slice_dims; dim_idx ++) { + var index = i32(indices_data[dim_idx + slice_indices_base_offset].x); + let input_dim_idx = uniforms.batch_dims + dim_idx; + if (index < 0) { + ${ + inputDims.length === 1 + ? 'index += i32(uniforms.input_dims);' + : 'index += i32(uniforms.input_dims[input_dim_idx]);' + } + } + ${ + sizesFromSliceDimsData.length === 1 + ? 'relative_slice_offset += index * i32(uniforms.sizes_from_slice_dims_data);' + : 'relative_slice_offset += index * i32(uniforms.sizes_from_slice_dims_data[dim_idx]);' + } + } + + input_slice_offsets_data[global_idx] = base_offset + u32(relative_slice_offset); + }`; + }; + + return context.compute( + { + name: 'computeSliceOffsets', + shaderCache: { hint: `${inputDims.length}_${sizesFromSliceDimsData.length}`, inputDependencies: ['rank'] }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: context.inputs[1].dataType }], + dispatchGroup: { x: Math.ceil(numSlices / 64) }, + programUniforms, + }), + getShaderSource, + }, + { inputs: [indicesData], outputs: [-1] }, + )[0]; +}; + +export const gatherND = (context: ComputeContext, attributes: GatherNDAttributes) => { + const inputs = context.inputs; + const inputShape = inputs[0].dims; + const inputType = inputs[0].dataType; + const indicesShape = inputs[1].dims; + const numSliceDims = indicesShape[indicesShape.length - 1]; + const numSlices = ShapeUtil.sizeToDimension(indicesShape, indicesShape.length - 1); + const sliceSize = ShapeUtil.sizeFromDimension(inputShape, attributes.batchDims + numSliceDims); + const numBatches = ShapeUtil.sizeToDimension(inputShape, attributes.batchDims); + const inputBatchStride = ShapeUtil.sizeFromDimension(inputShape, attributes.batchDims); + const numSlicesPerBatch = numSlices / numBatches; + const sizesFromSliceDims = new Array(numSliceDims); + let runningProduct = sliceSize; + for (let i = 0; i < numSliceDims; ++i) { + sizesFromSliceDims[numSliceDims - 1 - i] = runningProduct; + runningProduct *= inputShape[attributes.batchDims + numSliceDims - 1 - i]; + } + + const inputSliceOffsets = computeSliceOffsets( + context, + inputs[1], + sizesFromSliceDims, + attributes.batchDims, + inputShape, + numSlices, + numSlicesPerBatch, + inputBatchStride, + numSliceDims, + ); + + const lastIndicesDimension = attributes.batchDims + numSliceDims; + if (lastIndicesDimension > inputShape.length) { + throw new Error('last dimension of indices must not be larger than rank of input tensor'); + } + + const outputShape = indicesShape.slice(0, -1).concat(inputShape.slice(lastIndicesDimension)); + const outputSize = ShapeUtil.size(outputShape); + + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: outputSize }, + { type: DataType.uint32, data: sliceSize }, + ...createTensorShapeVariables(inputs[0].dims, inputSliceOffsets.dims, outputShape), + ]; + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const input = inputVariable('data', inputs[0].dataType, inputs[0].dims.length); + const indices = inputVariable('slice_offsets', DataType.uint32, inputSliceOffsets.dims.length); + + const output = outputVariable('output', inputs[0].dataType, outputShape.length); + return ` + ${shaderHelper + .registerUniform('output_size', 'u32') + .registerUniform('slice_size', 'u32') + .declareVariables(input, indices, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + let slice_offset = slice_offsets[global_idx / uniforms.slice_size]; + output[global_idx] = data[u32(slice_offset) + global_idx % uniforms.slice_size]; + }`; + }; + context.compute( + { + name: 'GatherND', + shaderCache: { hint: attributes.cacheKey, inputDependencies: ['rank', 'rank'] }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: inputType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms, + }), + getShaderSource, + }, + { inputs: [inputs[0], inputSliceOffsets] }, + ); +}; + +export const parseGatherNDAttributes = (attributes: Record): GatherNDAttributes => { + const batchDims = attributes.batch_dims as number; + return { + batchDims, + cacheKey: '', + }; +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts index 7f2469d95e1c1..09365f3b984b4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts @@ -55,9 +55,15 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt if (!outputShape) { throw new Error("Can't use gemm on the given tensors"); } + const tileSize = 16; + const numTileN = Math.ceil(N / tileSize); + const numTileM = Math.ceil(M / tileSize); + // TODO: Find the condition when to use the naive one. + const useShared = true; + const outputSize = ShapeUtil.size(outputShape); const programUniforms: ProgramUniform[] = [ - { type: DataType.uint32, data: outputSize }, + { type: DataType.uint32, data: useShared ? numTileN : outputSize }, { type: DataType.uint32, data: M }, { type: DataType.uint32, data: N }, { type: DataType.uint32, data: K }, @@ -130,6 +136,159 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt }`; }; + const getShaderSourceShared = (shaderHelper: ShaderHelper) => { + const a = inputVariable('a', inputs[0].dataType, inputs[0].dims); + const b = inputVariable('b', inputs[1].dataType, inputs[1].dims); + let c: IndicesHelper | null = null; + const variables = [a, b]; + if (inputs.length === 3) { + c = inputVariable('c', inputs[2].dataType, inputs[2].dims.length); + variables.push(c); + } + const output = outputVariable('output', inputs[0].dataType, outputShape.length); + variables.push(output); + const uniforms: UniformsArrayType = [ + { name: 'num_tile_n', type: 'u32' }, + { name: 'M', type: 'u32' }, + { name: 'N', type: 'u32' }, + { name: 'K', type: 'u32' }, + { name: 'alpha', type: 'f32' }, + { name: 'beta', type: 'f32' }, + ]; + + let calcResult = ''; + let fillWorkgroupMemory = ''; + if (attributes.transA && attributes.transB) { + fillWorkgroupMemory = ` + var col = tile_row_start + local_id.x; + var row = k_start + local_id.y; + if (col < uniforms.M && row < uniforms.K) { + tile_a[local_id.y][local_id.x] = a[row * uniforms.M + col]; + } else { + tile_a[local_id.y][local_id.x] = ${a.type.value}(0); + } + + col = k_start + local_id.x; + row = tile_col_start + local_id.y; + if (col < uniforms.K && row < uniforms.N) { + tile_b[local_id.y][local_id.x] = b[row * uniforms.K + col]; + } else { + tile_b[local_id.y][local_id.x] = ${b.type.value}(0); + } + `; + calcResult = `value += tile_a[k][local_id.y] * tile_b[local_id.x][k];`; + } else if (attributes.transA && !attributes.transB) { + fillWorkgroupMemory = ` + var col = tile_row_start + local_id.x; + var row = k_start + local_id.y; + if (col < uniforms.M && row < uniforms.K) { + tile_a[local_id.y][local_id.x] = a[row * uniforms.M + col]; + } else { + tile_a[local_id.y][local_id.x] = ${a.type.value}(0); + } + + col = tile_col_start + local_id.x; + row = k_start + local_id.y; + if (col < uniforms.N && row < uniforms.K) { + tile_b[local_id.y][local_id.x] = b[row * uniforms.N + col]; + } else { + tile_b[local_id.y][local_id.x] = ${b.type.value}(0); + } + `; + calcResult = `value += tile_a[k][local_id.y] * tile_b[k][local_id.x];`; + } else if (!attributes.transA && attributes.transB) { + fillWorkgroupMemory = ` + var col = k_start + local_id.x; + var row = tile_row_start + local_id.y; + if (col < uniforms.K && row < uniforms.M) { + tile_a[local_id.y][local_id.x] = a[row * uniforms.K + col]; + } else { + tile_a[local_id.y][local_id.x] = ${a.type.value}(0); + } + + col = k_start + local_id.x; + row = tile_col_start + local_id.y; + if (col < uniforms.K && row < uniforms.N) { + tile_b[local_id.y][local_id.x] = b[row * uniforms.K + col]; + } else { + tile_b[local_id.y][local_id.x] = ${b.type.value}(0); + } + `; + calcResult = `value += tile_a[local_id.y][k] * tile_b[local_id.x][k];`; + } else if (!attributes.transA && !attributes.transB) { + fillWorkgroupMemory = ` + var col = k_start + local_id.x; + var row = tile_row_start + local_id.y; + if (col < uniforms.K && row < uniforms.M) { + tile_a[local_id.y][local_id.x] = a[row * uniforms.K + col]; + } else { + tile_a[local_id.y][local_id.x] = ${a.type.value}(0); + } + + col = tile_col_start + local_id.x; + row = k_start + local_id.y; + if (col < uniforms.N && row < uniforms.K) { + tile_b[local_id.y][local_id.x] = b[row * uniforms.N + col]; + } else { + tile_b[local_id.y][local_id.x] = ${b.type.value}(0); + } + `; + calcResult = `value += tile_a[local_id.y][k] * tile_b[k][local_id.x];`; + } + + const calculateAlpha = attributes.alpha === 1 ? '' : 'value *= uniforms.alpha;'; + + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)} + var tile_a: array, ${tileSize}>; + var tile_b: array, ${tileSize}>; + ${shaderHelper.mainStart([tileSize, tileSize, 1])} + let tile_col_start = (workgroup_index % uniforms.num_tile_n) * ${tileSize}; + let tile_row_start = (workgroup_index / uniforms.num_tile_n) * ${tileSize}; + let num_tiles = (uniforms.K - 1) / ${tileSize} + 1; + var k_start = 0u; + var value = ${output.type.value}(0); + for (var t: u32 = 0u; t < num_tiles; t++) { + ${fillWorkgroupMemory} + k_start = k_start + ${tileSize}; + workgroupBarrier(); + + for (var k: u32 = 0u; k < ${tileSize}; k++) { + ${calcResult} + } + workgroupBarrier(); + } + + ${calculateAlpha} + let m = tile_row_start + local_id.y; + let n = tile_col_start + local_id.x; + ${(() => { + if (c != null) { + return `let cOffset = ${c.broadcastedIndicesToOffset('vec2(m, n)', output)}; value += ${ + output.type.value + }(uniforms.beta) * ${c.getByOffset('cOffset')};`; + } + return ''; + })()} + if (m < uniforms.M && n < uniforms.N) { + output[m * uniforms.N + n] = value; + } + }`; + }; + + if (useShared) { + return { + name: 'GemmShared', + shaderCache: { hint: `${attributes.cacheKey}`, inputDependencies }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + dispatchGroup: { x: numTileN * numTileM }, + programUniforms, + }), + getShaderSource: getShaderSourceShared, + }; + } + return { name: 'Gemm', shaderCache: { hint: `${attributes.cacheKey}`, inputDependencies }, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/grid-sample.ts b/js/web/lib/wasm/jsep/webgpu/ops/grid-sample.ts new file mode 100644 index 0000000000000..50c71472434ad --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/grid-sample.ts @@ -0,0 +1,279 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramInfo, ProgramUniform } from '../types'; + +import { createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper } from './common'; + +let [idxN, idxC, idxH, idxW] = [0, 1, 2, 3]; // NCHW +type Mode = 'bilinear' | 'nearest' | 'bicubic'; +type PaddingMode = 'zeros' | 'border' | 'reflection'; +type Format = 'NHWC' | 'NCHW'; +export interface GridSampeAttributes extends AttributeWithCacheKey { + alignCorners: number; + mode: Mode; + paddingMode: PaddingMode; + format: Format; +} + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (inputs[0].dims.length !== 4) { + throw new Error('only 4-D tensor is supported.'); + } + if (inputs[0].dims.length !== inputs[1].dims.length) { + throw new Error('input dimensions must be equal to grid dimensions'); + } + + if (inputs[0].dims.length - 2 !== inputs[1].dims[inputs[1].dims.length - 1]) { + throw new Error(`last dimension of grid must be equal to ${inputs[0].dims.length - 2}`); + } + + if (inputs[0].dims[0] !== inputs[1].dims[0]) { + throw new Error('grid batch size must match input batch size'); + } +}; + +const gsGetCubicCoeffs = ` + fn gs_get_cubic_coeffs(x: f32) -> vec4 { + let cubic_alpha = -0.75f; + let x_abs = abs(x); + var coeffs: vec4; + coeffs[0] = (((cubic_alpha * (x_abs + 1) - 5 * cubic_alpha) * (x_abs + 1) + 8 * cubic_alpha) * (x_abs + 1) - 4 * cubic_alpha); + coeffs[1] = (((cubic_alpha + 2) * x_abs - (cubic_alpha + 3)) * x_abs * x_abs + 1); + coeffs[2] = (((cubic_alpha + 2) * (1 - x_abs) - (cubic_alpha + 3)) * (1 - x_abs) * (1 - x_abs) + 1); + coeffs[3] = (((cubic_alpha * (2 - x_abs) - 5 * cubic_alpha) * (2 - x_abs) + 8 * cubic_alpha) * (2 - x_abs) - 4 * cubic_alpha); + return coeffs; + } +`; + +const gsBicubicInterpolate = (dataType: string): string => ` + fn gs_bicubic_interpolate(p: mat4x4<${dataType}>, x: f32, y: f32) -> ${dataType} { + var v: vec4; + var coeffs = gs_get_cubic_coeffs(x); + for (var i = 0; i < 4; i++) { + v[i] = coeffs[0] * p[i][0] + coeffs[1] * p[i][1] + coeffs[2] * p[i][2] + coeffs[3] * p[i][3]; + } + coeffs = gs_get_cubic_coeffs(y); + let pixel = ${dataType}(coeffs[0] * v[0] + coeffs[1] * v[1] + coeffs[2] * v[2] + coeffs[3] * v[3]); + return pixel; + } +`; + +const gsDenormalize = (attributes: GridSampeAttributes): string => ` + fn gs_denormalize(n: f32, length: i32) -> f32 { + ${ + attributes.alignCorners === 0 + ? ` + // alignCorners: false => [-1, 1] to [-0.5, length - 0.5] + return ((n + 1.0) * f32(length) - 1.0) / 2.0; + ` + : ` + // alignCorners: true => [-1, 1] to [0, length - 1] + return (n + 1.0) / 2.0 * (f32(length - 1)); + ` + } + } +`; + +const gsReflect = (attributes: GridSampeAttributes): string => ` + ${ + attributes.paddingMode === 'reflection' + ? ` + fn gs_reflect(x: i32, x_min: f32, x_max: f32) -> u32 { + var dx = 0.0; + var fx = f32(x); + let range = x_max - x_min; + if (fx < x_min) { + dx = x_min - fx; + let n = u32(dx / range); + let r = dx - f32(n) * range; + if (n % 2 == 0) { + fx = x_min + r; + } else { + fx = x_max - r; + } + } else if (fx > x_max) { + dx = fx - x_max; + let n = u32(dx / range); + let r = dx - f32(n) * range; + if (n % 2 == 0) { + fx = x_max - r; + } else { + fx = x_min + r; + } + } + return u32(fx); + }` + : '' + } +`; + +const pixelAtGrid = (input: IndicesHelper, dataType: string, attributes: GridSampeAttributes): string => + ` + fn pixel_at_grid(r: i32, c: i32, H: i32, W: i32, batch: u32, channel: u32, border: vec4) -> ${dataType} { + var pixel = ${dataType}(0); + var indices = vec4(0); + indices[${idxN}] = batch; + indices[${idxC}] = channel;` + + (() => { + switch (attributes.paddingMode) { + case 'zeros': + return ` + if (r >= 0 && r < H && c >=0 && c < W) { + indices[${idxH}] = u32(r); + indices[${idxW}] = u32(c); + } + `; + case 'border': + return ` + indices[${idxH}] = u32(clamp(r, 0, H - 1)); + indices[${idxW}] = u32(clamp(c, 0, W - 1)); + `; + case 'reflection': + return ` + indices[${idxH}] = gs_reflect(r, border[1], border[3]); + indices[${idxW}] = gs_reflect(c, border[0], border[2]); + `; + default: + throw new Error(`padding mode ${attributes.paddingMode} is not supported`); + } + })() + + ` + return ${input.getByIndices('indices')}; + } +`; + +const computePixel = (output: IndicesHelper, dataType: string, attributes: GridSampeAttributes): string => + (() => { + switch (attributes.mode) { + case 'nearest': + return ` + let result = pixel_at_grid(i32(round(y)), i32(round(x)), H_in, W_in, indices[${idxN}], indices[${idxC}], border); + `; + case 'bilinear': + return ` + let x1 = i32(floor(x)); + let y1 = i32(floor(y)); + let x2 = x1 + 1; + let y2 = y1 + 1; + + let p11 = pixel_at_grid(y1, x1, H_in, W_in, indices[${idxN}], indices[${idxC}], border); + let p12 = pixel_at_grid(y1, x2, H_in, W_in, indices[${idxN}], indices[${idxC}], border); + let p21 = pixel_at_grid(y2, x1, H_in, W_in, indices[${idxN}], indices[${idxC}], border); + let p22 = pixel_at_grid(y2, x2, H_in, W_in, indices[${idxN}], indices[${idxC}], border); + + let dx2 = ${dataType}(f32(x2) - x); + let dx1 = ${dataType}(x - f32(x1)); + let dy2 = ${dataType}(f32(y2) - y); + let dy1 = ${dataType}(y - f32(y1)); + let result = dy2 * (dx2 * p11 + dx1 * p12) + dy1 * (dx2 * p21 + dx1 * p22); + `; + case 'bicubic': + return ` + let x0 = i32(floor(x)) - 1; + let y0 = i32(floor(y)) - 1; + var p: mat4x4<${dataType}>; + for (var h = 0; h < 4; h++) { + for (var w = 0; w < 4; w++) { + p[h][w] = pixel_at_grid(h + y0, w + x0, H_in, W_in, indices[${idxN}], indices[${idxC}], border); + } + } + + let dx = x - f32(x0 + 1); + let dy = y - f32(y0 + 1); + let result = gs_bicubic_interpolate(p, dx, dy); + `; + default: + throw new Error(`mode ${attributes.mode} is not supported`); + } + })() + `${output.setByOffset('global_idx', 'result')}`; + +const createGridSampleProgramInfo = (inputs: readonly TensorView[], attributes: GridSampeAttributes): ProgramInfo => { + const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length); + // discard last dimension for using vec2 to access grid data + const gridShape = [inputs[1].dims[0], inputs[1].dims[1], inputs[1].dims[2]]; + const grid = inputVariable('grid', inputs[1].dataType, gridShape.length, 2); + let outputShape = [inputs[0].dims[0], inputs[0].dims[1], inputs[1].dims[1], inputs[1].dims[2]]; + if (attributes.format === 'NHWC') { + outputShape = [inputs[0].dims[0], inputs[1].dims[1], inputs[1].dims[2], inputs[0].dims[3]]; + [idxN, idxC, idxH, idxW] = [0, 3, 1, 2]; + } + const output = outputVariable('output', inputs[0].dataType, outputShape.length); + const dataType = x.type.value; + const outputSize = ShapeUtil.size(outputShape); + + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: outputSize }, + ...createTensorShapeVariables(inputs[0].dims, gridShape, outputShape), + ]; + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(x, grid, output)} + ${gsGetCubicCoeffs} + ${gsBicubicInterpolate(dataType)} + ${gsDenormalize(attributes)} + ${gsReflect(attributes)} + ${pixelAtGrid(x, dataType, attributes)} + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + let H_in = i32(uniforms.x_shape[${idxH}]); + let W_in = i32(uniforms.x_shape[${idxW}]); + + ${ + attributes.alignCorners === 0 + ? ` + let x_min = -0.5; + let x_max = f32(W_in) - 0.5; + let y_min = -0.5; + let y_max = f32(H_in) - 0.5; + ` + : ` + let x_min = 0.0; + let x_max = f32(W_in) - 1.0; + let y_min = 0.0; + let y_max = f32(H_in) - 1.0; + ` + }; + let border = vec4(x_min, y_min, x_max, y_max); + + let indices = ${output.offsetToIndices('global_idx')}; + var grid_indices = vec3(indices[${idxN}], indices[${idxH}], indices[${idxW}]); + let nxy = ${grid.getByIndices('grid_indices')}; + var x = gs_denormalize(f32(nxy[0]), W_in); + var y = gs_denormalize(f32(nxy[1]), H_in); + + ${computePixel(output, dataType, attributes)} + }`; + + return { + name: 'GridSample', + shaderCache: { hint: `${attributes.cacheKey}`, inputDependencies: ['type', 'type'] }, + getRunData: (inputs) => { + const outputSize = ShapeUtil.size(outputShape); + return { + outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms, + }; + }, + getShaderSource, + }; +}; + +export const gridSample = (context: ComputeContext, attributes: GridSampeAttributes): void => { + validateInputs(context.inputs); + context.compute(createGridSampleProgramInfo(context.inputs, attributes)); +}; + +export const parseGridSampleAttributes = (attributes: Record): GridSampeAttributes => + createAttributeWithCacheKey({ + alignCorners: attributes.align_corners as number, + mode: attributes.mode as Mode, + paddingMode: attributes.padding_mode as PaddingMode, + format: attributes.format as Format, + }); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index 859bd850862aa..a357d29667319 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -36,7 +36,10 @@ const computeChannelScaleShift = ( const f32Type = components === 1 ? 'f32' : `vec${components}f`; const wgType = components === 1 ? 'vec2f' : `mat2x${components}f`; const unitsOfWork = n * c; - + let workgroupSize = 64; + if (unitsOfWork === 1) { + workgroupSize = 256; + } const inputShape = [n, c, h / components]; const outputShape = [n, c, 2]; const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type']; @@ -49,7 +52,6 @@ const computeChannelScaleShift = ( const b = inputVariable('bias', bias.dataType, bias.dims); const output = outputVariable('output', DataType.float, 3, 2); const variables = [x, s, b, output]; - const workgroupSize = 64; return ` var workgroup_shared : array<${wgType}, ${workgroupSize}>; const workgroup_size = ${workgroupSize}u; @@ -91,7 +93,7 @@ const computeChannelScaleShift = ( { name: 'InstanceNormComputeChannelScaleShift', // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon. - shaderCache: { hint: `${components};${epsilon}`, inputDependencies }, + shaderCache: { hint: `${components};${epsilon};${workgroupSize}`, inputDependencies }, getRunData: () => ({ outputs: [{ dims: outputShape, dataType: DataType.float }], dispatchGroup: { x: unitsOfWork }, @@ -187,14 +189,21 @@ const createInstanceNormNHWCProgramInfo = ( const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; // 1. transpose x from NHWC to NCHW + let needTranspose = false; const transposedXPerm = [0, xShape.length - 1]; for (let i = 0; i < xShape.length - 2; i++) { + needTranspose = needTranspose || xShape[i + 1] !== 1; transposedXPerm.push(i + 1); } - const transposedX = context.compute(createTransposeProgramInfo(context.inputs[0], transposedXPerm), { - inputs: [context.inputs[0]], - outputs: [-1], - })[0]; + + needTranspose = needTranspose && xShape[xShape.length - 1] !== 1; + + const transposedX = needTranspose + ? context.compute(createTransposeProgramInfo(context.inputs[0], transposedXPerm), { + inputs: [context.inputs[0]], + outputs: [-1], + })[0] + : context.inputs[0].reshape(Array.from({ length: xShape.length }, (_, i) => xShape[transposedXPerm[i]])); // 2. compute channel scale and channel shift. const channelScaleShift = computeChannelScaleShift( context, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul-shaders.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul-shaders.ts new file mode 100644 index 0000000000000..e1f73f137e43e --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul-shaders.ts @@ -0,0 +1,191 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { ProgramInfo, ProgramUniform } from '../types'; + +import { + createTensorShapeVariables, + getElementAt, + getMaxComponents, + IndicesHelper, + inputVariable, + internalVariable, + outputVariable, + ShaderHelper, + tensorTypeToWsglStorageType, + UniformsArrayType, +} from './common'; +import { + appendActivationUniforms, + appendActivationUniformsData, + getActivationSnippet, + InternalActivationAttributes, +} from './fuse-utils'; + +// Helper that convert output batch indices to input batch indices using only the rank and +// the shape information in uniform +export const convertOutputBatchIndicesToInputBatchIndices = ( + targetIndicesName: string, + inputVariable: IndicesHelper, + inputBatchRank: number, + outputBatchRank: number, + batchIndicesName: string, +) => { + // Assume outputBatchRank >= inputBatchRank, the first outputBatchRank - inputBatchRank of + // outputBatchRank should be ignored. + const extendingInputRank = outputBatchRank - inputBatchRank; + return ` + ${Array.from({ length: inputBatchRank }) + .map( + (_, i) => ` + if (${getElementAt(inputVariable.shape, i, inputVariable.rank)} != 1) { + ${inputVariable.indicesSet(targetIndicesName, i, getElementAt(batchIndicesName, i + extendingInputRank, outputBatchRank))} + } else { + ${inputVariable.indicesSet(targetIndicesName, i, 0)} + }`, + ) + .join('')} +`; +}; + +export const createNaiveMatmulProgramInfo = ( + inputs: readonly TensorView[], + activationAttributes: InternalActivationAttributes, + outputShape: readonly number[], + reshapedOutputShape?: readonly number[], + isChannelsLast = false /* only used for conv2dByMatMul*/, + squeezeOutputShapeFunction?: (shape: readonly number[]) => number[], +): ProgramInfo => { + const aShape = inputs[0].dims; + const bShape = inputs[1].dims; + + const M = aShape[aShape.length - 2]; + const N = bShape[bShape.length - 1]; + const K = aShape[aShape.length - 1]; + const components = getMaxComponents(N); + const aComponents = getMaxComponents(K); + const outputNumber = getMaxComponents(M); + const outputSize = ShapeUtil.size(outputShape) / components / outputNumber; + const hasBias = inputs.length > 2; + const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); + const batchSize = ShapeUtil.size(outerDims); + const outputShapeInShader = [batchSize, M, N]; + + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: outputSize }, + { type: DataType.uint32, data: M }, + { type: DataType.uint32, data: N }, + { type: DataType.uint32, data: K }, + ]; + appendActivationUniformsData(activationAttributes, programUniforms); + programUniforms.push(...createTensorShapeVariables(outerDims, aShape, bShape)); + if (hasBias) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + } + programUniforms.push(...createTensorShapeVariables(outputShapeInShader)); + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const batchDims = internalVariable('batch_dims', inputs[0].dataType, outerDims.length); + const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents); + const b = inputVariable('b', inputs[1].dataType, bShape.length, components); + const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); + const baseType = tensorTypeToWsglStorageType(output.type.tensor); + const applyActivation = getActivationSnippet(activationAttributes, output.type.value, baseType); + const inputVariables = [a, b]; + let processBias = ''; + if (hasBias) { + const biasComponents = isChannelsLast ? components : 1; + inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); + processBias = `${ + isChannelsLast ? `value += bias[col / ${biasComponents}];` : `value += ${output.type.value}(bias[row + i]);` + }`; + } + + const uniforms: UniformsArrayType = [ + { name: 'output_size', type: 'u32' }, + { name: 'M', type: 'u32' }, + { name: 'N', type: 'u32' }, + { name: 'K', type: 'u32' }, + ]; + appendActivationUniforms(activationAttributes, uniforms); + + const calcResult = (): string => { + let calcStr = `var a_data: ${a.type.value};`; + for (let i = 0; i < aComponents; i++) { + calcStr += ` + let b_data${i} = b[(b_offset + (k + ${i}) * uniforms.N + col) / ${components}];`; + } + for (let i = 0; i < outputNumber; i++) { + calcStr += `a_data = a[(a_offset + (row + ${i}) * uniforms.K + k) / ${aComponents}];`; + + for (let j = 0; j < aComponents; j++) { + calcStr += ` + values[${i}] = fma(${b.type.value}(a_data${aComponents === 1 ? '' : `[${j}]`}), b_data${j}, values[${i}]);\n`; + } + } + return calcStr; + }; + + return ` + ${shaderHelper + .registerUniforms(uniforms) + .registerInternalVariables(batchDims) + .declareVariables(...inputVariables, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + let col = (global_idx % (uniforms.N / ${components})) * ${components}; + var index1 = global_idx / (uniforms.N / ${components}); + let stride1 = uniforms.M / ${outputNumber}; + let row = (index1 % stride1) * ${outputNumber}; + let batch = index1 / stride1; + + ${outputShape.length === 2 ? '' : `let batch_indices = ${batchDims.offsetToIndices('batch')};`} + + var a_indices: ${a.type.indices}; + ${convertOutputBatchIndicesToInputBatchIndices('a_indices', a, a.rank - 2, batchDims.rank, 'batch_indices')} + ${a.indicesSet('a_indices', a.rank - 2, 0)} + ${a.indicesSet('a_indices', a.rank - 1, 0)} + let a_offset = ${a.indicesToOffset('a_indices')}; + + var b_indices: ${b.type.indices}; + ${convertOutputBatchIndicesToInputBatchIndices('b_indices', b, b.rank - 2, batchDims.rank, 'batch_indices')} + ${b.indicesSet('b_indices', b.rank - 2, 0)} + ${b.indicesSet('b_indices', b.rank - 1, 0)} + let b_offset = ${b.indicesToOffset('b_indices')}; + var values: array<${output.type.value}, ${outputNumber}>; + for (var k: u32 = 0u; k < uniforms.K; k = k + ${aComponents}) { + ${calcResult()} + } + for (var i = 0u; i < ${outputNumber}u; i++) { + var value = values[i]; + ${processBias} + ${applyActivation} + let cur_indices = ${output.type.indices}(batch, row + i, col); + let offset = ${output.indicesToOffset('cur_indices')}; + ${output.setByOffset(`offset / ${components}`, 'value')}; + } + } + `; + }; + return { + name: 'MatMulNaive', + shaderCache: { + hint: `${activationAttributes.activation};${components};${aComponents};${outputNumber};${isChannelsLast}`, + inputDependencies: hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank'], + }, + getRunData: () => ({ + outputs: [ + { + dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, + dataType: inputs[0].dataType, + }, + ], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms, + }), + getShaderSource, + }; +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index 7605e67c972b9..46a358aacdad4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -1,184 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import { DataType } from '../../../wasm-common'; import { TensorView } from '../../tensor-view'; import { BroadcastUtil, ShapeUtil } from '../../util'; -import { ComputeContext, ProgramInfo, ProgramUniform } from '../types'; +import { ComputeContext } from '../types'; +import { createNaiveMatmulProgramInfo } from './matmul-shaders'; import { createMatmulProgramInfo } from './3rd-party/matmul_packed_webgpu'; -import { - createTensorShapeVariables, - getBroadcastDims, - getMaxComponents, - IndicesHelper, - inputVariable, - internalVariable, - outputVariable, - ShaderHelper, - tensorTypeToWsglStorageType, - UniformsArrayType, -} from './common'; -import { - appendActivationUniforms, - appendActivationUniformsData, - getActivationSnippet, - InternalActivationAttributes, -} from './fuse-utils'; - -export const createNaiveMatmulProgramInfo = ( - inputs: readonly TensorView[], - activationAttributes: InternalActivationAttributes, - outputShape: readonly number[], - reshapedOutputShape?: readonly number[], - isChannelsLast = false /* only used for conv2dByMatMul*/, - squeezeOutputShapeFunction?: (shape: readonly number[]) => number[], -): ProgramInfo => { - const aShape = inputs[0].dims; - const bShape = inputs[1].dims; - - const M = aShape[aShape.length - 2]; - const N = bShape[bShape.length - 1]; - const K = aShape[aShape.length - 1]; - const components = getMaxComponents(N); - const aComponents = getMaxComponents(K); - const outputNumber = getMaxComponents(M); - const outputSize = ShapeUtil.size(outputShape) / components / outputNumber; - const hasBias = inputs.length > 2; - const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); - const batchSize = ShapeUtil.size(outerDims); - const outputShapeInShader = [batchSize, M, N]; - - const programUniforms: ProgramUniform[] = [ - { type: DataType.uint32, data: outputSize }, - { type: DataType.uint32, data: M }, - { type: DataType.uint32, data: N }, - { type: DataType.uint32, data: K }, - ]; - appendActivationUniformsData(activationAttributes, programUniforms); - programUniforms.push(...createTensorShapeVariables(outerDims, aShape, bShape)); - if (hasBias) { - programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - } - programUniforms.push(...createTensorShapeVariables(outputShapeInShader)); - - const getShaderSource = (shaderHelper: ShaderHelper) => { - const batchDims = internalVariable('batch_dims', inputs[0].dataType, outerDims.length); - const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents); - const b = inputVariable('b', inputs[1].dataType, bShape.length, components); - const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); - const baseType = tensorTypeToWsglStorageType(output.type.tensor); - const applyActivation = getActivationSnippet(activationAttributes, output.type.value, baseType); - const inputVariables = [a, b]; - let processBias = ''; - if (hasBias) { - const biasComponents = isChannelsLast ? components : 1; - inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); - processBias = `${ - isChannelsLast ? `value += bias[col / ${biasComponents}];` : `value += ${output.type.value}(bias[row + i]);` - }`; - } - - const outerDimsA = aShape.slice(0, -2); - const outerDimsB = bShape.slice(0, -2); - const broadCastADims = getBroadcastDims(outerDimsA, outerDims); - const broadCastBDims = getBroadcastDims(outerDimsB, outerDims); - const uniforms: UniformsArrayType = [ - { name: 'output_size', type: 'u32' }, - { name: 'M', type: 'u32' }, - { name: 'N', type: 'u32' }, - { name: 'K', type: 'u32' }, - ]; - appendActivationUniforms(activationAttributes, uniforms); - - const getIndices = (variable: IndicesHelper, broadCastDims: number[]) => { - const rank = variable.rank; - const name = variable.name; - if (rank === 2) { - return `var ${name}_indices = ${variable.type.indices}(0u, 0u);`; - } - const batchRank = batchDims.rank; - let resStr = `var ${name}_indices: ${variable.type.indices};`; - for (let i = rank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) { - resStr += `\n${name}_indices[${i}] = ${batchRank > 1 ? `batch_indices[${j}]` : 'batch_indices'};`; - } - broadCastDims.forEach((i) => { - resStr += `\n${name}_indices[${i}] = 0;`; - }); - resStr += `${name}_indices[${rank - 2}] = 0u; - ${name}_indices[${rank - 1}] = 0u;`; - return resStr; - }; - - const calcResult = (): string => { - let calcStr = `var a_data: ${a.type.value};`; - for (let i = 0; i < aComponents; i++) { - calcStr += ` - let b_data${i} = b[(b_offset + (k + ${i}) * uniforms.N + col) / ${components}];`; - } - for (let i = 0; i < outputNumber; i++) { - calcStr += `a_data = a[(a_offset + (row + ${i}) * uniforms.K + k) / ${aComponents}];`; - - for (let j = 0; j < aComponents; j++) { - calcStr += ` - values[${i}] = fma(${b.type.value}(a_data${aComponents === 1 ? '' : `[${j}]`}), b_data${j}, values[${i}]);\n`; - } - } - return calcStr; - }; - - return ` - ${shaderHelper - .registerUniforms(uniforms) - .registerInternalVariables(batchDims) - .declareVariables(...inputVariables, output)} - ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} - let col = (global_idx % (uniforms.N / ${components})) * ${components}; - var index1 = global_idx / (uniforms.N / ${components}); - let stride1 = uniforms.M / ${outputNumber}; - let row = (index1 % stride1) * ${outputNumber}; - let batch = index1 / stride1; - - ${outputShape.length === 2 ? '' : `let batch_indices = ${batchDims.offsetToIndices('batch')};`} - ${getIndices(a, broadCastADims)} - let a_offset = ${a.indicesToOffset('a_indices')}; - ${getIndices(b, broadCastBDims)} - let b_offset = ${b.indicesToOffset('b_indices')}; - var values: array<${output.type.value}, ${outputNumber}>; - for (var k: u32 = 0u; k < uniforms.K; k = k + ${aComponents}) { - ${calcResult()} - } - for (var i = 0u; i < ${outputNumber}u; i++) { - var value = values[i]; - ${processBias} - ${applyActivation} - let cur_indices = ${output.type.indices}(batch, row + i, col); - let offset = ${output.indicesToOffset('cur_indices')}; - ${output.setByOffset(`offset / ${components}`, 'value')}; - } - } - `; - }; - return { - name: 'MatMulNaive', - shaderCache: { - hint: `${activationAttributes.activation};${components};${aComponents};${outputNumber};${isChannelsLast}`, - inputDependencies: hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank'], - }, - getRunData: () => ({ - outputs: [ - { - dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, - dataType: inputs[0].dataType, - }, - ], - dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, - programUniforms, - }), - getShaderSource, - }; -}; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 2) { @@ -201,6 +29,20 @@ export const matMul = (context: ComputeContext): void => { if (N < 8 && K < 8) { context.compute(createNaiveMatmulProgramInfo(context.inputs, { activation: '' }, outputShape)); } else { - context.compute(createMatmulProgramInfo(context.inputs, { activation: '' }, outputShape)); + const M = outputShape[outputShape.length - 2]; + const batchA = ShapeUtil.size(context.inputs[0].dims.slice(0, -2)); + const batchB = ShapeUtil.size(context.inputs[1].dims.slice(0, -2)); + if (batchA !== 1 && M === 1 && batchB === 1) { + // Optimization for batched vec-mat-mul + const reshapedA = context.inputs[0].reshape([1, batchA, K]); + const reshapedB = context.inputs[1].reshape([1, K, N]); + const matmulOutputShape = [1, batchA, N]; + const matmulInputs = [reshapedA, reshapedB]; + context.compute(createMatmulProgramInfo(matmulInputs, { activation: '' }, outputShape, matmulOutputShape), { + inputs: matmulInputs, + }); + } else { + context.compute(createMatmulProgramInfo(context.inputs, { activation: '' }, outputShape)); + } } }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts index bf64b04dde1e8..fe0c3712197c3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts @@ -4,7 +4,7 @@ import { DataType } from '../../../wasm-common'; import { TensorView } from '../../tensor-view'; import { ShapeUtil } from '../../util'; -import { ComputeContext, ProgramInfo, ProgramShaderCacheInfo } from '../types'; +import { ComputeContext, ProgramInfo } from '../types'; import { inputVariable, outputVariable, ShaderHelper } from './common'; import { createReduceAttributesFromInputs, ReduceAttributes } from './reduce'; @@ -119,7 +119,7 @@ const getAxesPermutation = (axes: number[], rank: number): number[] => { export const createReduceSharedProgramInfo = ( name: string, - shaderCache: ProgramShaderCacheInfo, + cacheKey: string, inputs: readonly TensorView[], reduceType: string, outputDataType: DataType, @@ -134,7 +134,11 @@ export const createReduceSharedProgramInfo = ( const input = inputVariable('_A', inputs[0].dataType, inputShape); const output = outputVariable('output', outputDataType, outputShape); - const workgroupSize = 32; + let workgroupSize = 64; + // If only one workgroup is dispatched, increase workgroupSize to improve parallelism. + if (outputSize === 1) { + workgroupSize = 256; + } const sharedMemorySnippet = ` var aBestValues : array; @@ -188,7 +192,8 @@ export const createReduceSharedProgramInfo = ( // One work group is responsible for only one element of output. return { name, - shaderCache, + // Note that in JSEP, WG size is not included in cache by default, but WebGPU EP it is. + shaderCache: { hint: `${cacheKey};${workgroupSize}`, inputDependencies: ['type'] }, getShaderSource, getRunData: () => ({ outputs: [{ dims: outputShape, dataType: outputDataType }], @@ -233,7 +238,7 @@ const reduceCommon = ( context.compute( createReduceSharedProgramInfo( name, - { hint: updatedAttributes.cacheKey, inputDependencies: ['type'] }, + updatedAttributes.cacheKey, [input], reduceType, context.inputs[0].dataType, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts b/js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts new file mode 100644 index 0000000000000..8c24232d63c0c --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts @@ -0,0 +1,177 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramInfo, ProgramUniform } from '../types'; + +import { + atomicOutputVariable, + createTensorShapeVariables, + inputVariable, + outputVariable, + ShaderHelper, +} from './common'; + +export interface ScatterNDAttributes extends AttributeWithCacheKey { + reduction: string; +} + +type ReductionType = 'i32' | 'u32' | 'f32'; + +const atomicReductionSnippet = (reduction: string, ptr: string, v: string, type: ReductionType) => { + if (reduction !== 'none' && type !== 'i32' && type !== 'u32' && type !== 'f32') { + throw new Error(`Input ${type} is not supported with reduction ${reduction}.`); + } + + const floatStart = `{ + var oldValue = 0; + loop { + let newValueF32 =`; + const floatEnd = `; + let newValue = bitcast(newValueF32); + let res = atomicCompareExchangeWeak(&${ptr}, oldValue, newValue); + if res.exchanged { + break; + } + oldValue = res.old_value; + } + }`; + + switch (reduction) { + case 'none': + return `${ptr}=${v};`; + case 'add': + if (type === 'i32' || type === 'u32') { + return `atomicAdd(&${ptr}, bitcast<${type}>(${v}));`; + } else { + // atomicAdd only supports uint/int type. For float, we use + // atomicCompareExchangeWeak to simulate. + return ` + ${floatStart}bitcast<${type}>(oldValue) + (${v})${floatEnd}`; + } + case 'max': + if (type === 'i32' || type === 'u32') { + return `atomicMax(&${ptr}, bitcast<${type}>(${v}));`; + } else { + // atomicMax only supports uint/int type. For float, we use + // atomicCompareExchangeWeak to simulate. + return ` + ${floatStart}max(bitcast(oldValue), (${v}))${floatEnd}`; + } + case 'min': + if (type === 'i32' || type === 'u32') { + return `atomicMin(&${ptr}, bitcast<${type}>(${v}));`; + } else { + // atomicMin only supports uint/int type. For float, we use + // atomicCompareExchangeWeak to simulate. + return `${floatStart}min(bitcast<${type}>(oldValue), (${v}))${floatEnd}`; + } + case 'mul': + // atomicMul is not supported, we use atomicCompareExchangeWeak to simulate. + return `${floatStart}(bitcast<${type}>(oldValue) * (${v}))${floatEnd}`; + + default: + throw new Error(`Reduction ${reduction} is not supported.`); + } +}; + +const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: ScatterNDAttributes): ProgramInfo => { + const inputShape = inputs[0].dims; + const indicesShape = inputs[1].dims; + const outputShape = inputShape; + // TODO: support bool with components 4. + const components = 1; + const outputSize = Math.ceil(ShapeUtil.size(indicesShape) / components); + const lastIndexDimension = indicesShape[indicesShape.length - 1]; + const numUpdatesElements = ShapeUtil.sizeFromDimension(inputShape, lastIndexDimension); + + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: outputSize }, + { type: DataType.uint32, data: lastIndexDimension }, + { type: DataType.uint32, data: numUpdatesElements }, + ...createTensorShapeVariables(inputs[1].dims, inputs[2].dims, outputShape), + ]; + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const indices = inputVariable('indices', inputs[1].dataType, inputs[1].dims.length); + const updates = inputVariable('updates', inputs[2].dataType, inputs[2].dims.length, components); + const output = + attributes.reduction !== 'none' && attributes.reduction !== '' + ? atomicOutputVariable('output', inputs[0].dataType, outputShape.length) + : outputVariable('output', inputs[0].dataType, outputShape.length, components); + + return ` + ${shaderHelper + .registerUniform('output_size', 'u32') + .registerUniform('last_index_dimension', 'u32') + .registerUniform('num_updates_elements', 'u32') + .declareVariables(indices, updates, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + var data_offset = 0u; + let indices_start = uniforms.last_index_dimension * global_idx; + let indices_end = indices_start + uniforms.last_index_dimension; + for (var i = indices_start; i < indices_end; i++) { + var index = i32(indices[i].x); + ${ + inputs[0].dims.length === 1 + ? ` + let element_count_dim = uniforms.output_strides; + let dim_value = uniforms.output_shape;` + : ` + let element_count_dim = uniforms.output_strides[i - indices_start]; + let dim_value = uniforms.output_shape[i - indices_start + uniforms.last_index_dimension];` + } + if (index >= 0) { + if (index >= i32(dim_value)) { + index = i32(dim_value - 1); + } + } else { + if (index < -i32(dim_value)) { + index = 0; + } else { + index += i32(dim_value); + } + } + data_offset += u32((u32(index) * element_count_dim)); + } + + for (var i = 0u; i < uniforms.num_updates_elements; i++) { + let value = updates[uniforms.num_updates_elements * global_idx + i]; + ${atomicReductionSnippet( + attributes.reduction, + 'output[data_offset + i]', + 'value', + output.type.value as ReductionType, + )} + } + + }`; + }; + return { + name: 'ScatterND', + shaderCache: { + hint: `${attributes.cacheKey}_${attributes.reduction}`, + inputDependencies: ['rank', 'rank'], + }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms, + }), + getShaderSource, + }; +}; + +export const parseScatterNDAttributes = (attributes: Record): ScatterNDAttributes => + createAttributeWithCacheKey({ reduction: attributes.reduction as string }); + +export const scatterND = (context: ComputeContext, attributes: ScatterNDAttributes): void => { + context.compute(createScatterNDProgramInfo(context.inputs, attributes), { + inputs: [context.inputs[1], context.inputs[2]], + outputs: [], + }); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts index fbab44e211946..7c62d1f7182a7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -35,7 +35,6 @@ const createSoftmaxProgramInfo = (context: ComputeContext, attributes: SoftmaxAt const input = context.inputs[0]; const inputShape = input.dims; const outputSize = ShapeUtil.size(inputShape); - const WG = 64; const inputRank = inputShape.length; const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank); const isTransposeRequired = axis < inputShape.length - 1; @@ -60,7 +59,11 @@ const createSoftmaxProgramInfo = (context: ComputeContext, attributes: SoftmaxAt const rows = outputSize / cols; const components = getMaxComponents(cols); const packedCols = cols / components; - + let WG = 64; + // If only one workgroup is dispatched, increase workgroupSize to improve parallelism. + if (rows === 1) { + WG = 256; + } const maxVector = (name: string, components: number) => { if (components === 4) { return `max(max(${name}.x, ${name}.y), max(${name}.z, ${name}.w))`; @@ -95,7 +98,7 @@ const createSoftmaxProgramInfo = (context: ComputeContext, attributes: SoftmaxAt result[index] = value; } ${shaderHelper.registerUniform('packedCols', 'i32').declareVariables(x, output)} - ${shaderHelper.mainStart()} + ${shaderHelper.mainStart(WG)} let gindex = i32(global_idx); let lindex = i32(local_idx); const wg = ${WG}; @@ -156,7 +159,8 @@ const createSoftmaxProgramInfo = (context: ComputeContext, attributes: SoftmaxAt const result = context.compute( { name: 'Softmax', - shaderCache: { hint: `${components}`, inputDependencies: ['type'] }, + // Note that in JSEP, WG size is not included in cache by default, but WebGPU EP it is. + shaderCache: { hint: `${components};${WG}`, inputDependencies: ['type'] }, getRunData: () => ({ outputs: [{ dims: transposedInputShape, dataType: transposedInput.dataType }], dispatchGroup: { x: rows }, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index 1fd99d085e0ed..5059645211aea 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -29,7 +29,9 @@ const permFunctionBody = (perm: number[], rank: number, input: IndicesHelper, ou let reverseFunc = `fn perm(i: ${output.type.indices}) -> ${input.type.indices} { var a: ${input.type.indices};`; for (let i = 0; i < rank; ++i) { - reverseFunc += input.indicesSet('a', perm[i], `i[${i}]`); + // input indices and output indices should always be larger or equal to 2, + // so indexer is always valid to be used on `a` and `i`. + reverseFunc += `a[${perm[i]}]=i[${i}];`; } return (reverseFunc += 'return a;}'); }; @@ -48,17 +50,61 @@ const squeezeShape = (shape: readonly number[], adjustedPerm: number[]): { newSh return { newShape, newPerm }; }; +const isTransposeReshape = (perm: number[], shape: readonly number[]) => { + // As long as the dims with values > 1 stay in the same order, it's a reshape. + // Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1). + let lastPermutedAxis = 0; + for (let i = 0; i < perm.length; ++i) { + if (shape[perm[i]] === 1) { + continue; + } + if (perm[i] < lastPermutedAxis) { + return false; + } + lastPermutedAxis = perm[i]; + } + return true; +}; + export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: number[]): ProgramInfo => { const inputDataType = inputTensor.dataType; const inputRank = inputTensor.dims.length; const perm = getAdjustedPerm(inputRank, permAttr); const outputShape = getOutputShape(inputTensor.dims, perm); + let newInputShape = inputTensor.dims; + let newOutputShape = outputShape; + const transposeAsReshape = inputRank < 2 || isTransposeReshape(perm, inputTensor.dims); + let getShaderSource; + if (transposeAsReshape) { + getShaderSource = (shaderHelper: ShaderHelper) => { + const input = inputVariable('input', inputDataType, newInputShape, 4); + const output = outputVariable('output', inputDataType, newOutputShape, 4); + return ` + ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + output[global_idx] = input[global_idx]; + }`; + }; + + return { + name: 'TransposeCopy', + shaderCache: { inputDependencies: ['type'] }, + getRunData: () => { + const outputSize = ShapeUtil.size(outputShape); + return { + outputs: [{ dims: outputShape, dataType: inputTensor.dataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* components */) }, + programUniforms: [{ type: DataType.uint32, data: Math.ceil(outputSize / 4) }], + }; + }, + getShaderSource, + }; + } const { newShape, newPerm } = squeezeShape(inputTensor.dims, perm); const channelsLast = ShapeUtil.areEqual(newPerm, [2, 3, 1]); const channelsFirst = ShapeUtil.areEqual(newPerm, [3, 1, 2]); - const useShared = (newShape.length === 2 && newPerm[0] > newPerm[1]) || channelsLast || channelsFirst; - let newInputShape = useShared ? newShape : inputTensor.dims; - let newOutputShape = outputShape; + const useShared = newShape.length === 2 || channelsLast || channelsFirst; if (useShared) { newInputShape = channelsLast ? [newShape[0], newShape[1] * newShape[2]] @@ -66,13 +112,11 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu ? [newShape[0] * newShape[1], newShape[2]] : newShape; newOutputShape = [newInputShape[1], newInputShape[0]]; - } - const input = inputVariable('a', inputDataType, newInputShape.length); - const output = outputVariable('output', inputDataType, newOutputShape.length); - const tileSize = 16; - let getShaderSource; - if (useShared) { - getShaderSource = (shaderHelper: ShaderHelper) => ` + const tileSize = 16; + getShaderSource = (shaderHelper: ShaderHelper) => { + const input = inputVariable('a', inputDataType, newInputShape.length); + const output = outputVariable('output', inputDataType, newOutputShape.length); + return ` ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} var tile : array, ${tileSize}>; ${shaderHelper.mainStart([tileSize, tileSize, 1])} @@ -92,8 +136,29 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu ${output.setByIndices(`${output.type.indices}(output_row, output_col)`, 'tile[local_id.x][local_id.y]')} } }`; - } else { - getShaderSource = (shaderHelper: ShaderHelper) => ` + }; + return { + name: 'TransposeShared', + shaderCache: { inputDependencies: ['type'] }, + getRunData: () => { + const outputSize = ShapeUtil.size(outputShape); + return { + outputs: [{ dims: outputShape, dataType: inputTensor.dataType }], + dispatchGroup: { x: Math.ceil(newOutputShape[1] / tileSize), y: Math.ceil(newOutputShape[0] / tileSize) }, + programUniforms: [ + { type: DataType.uint32, data: outputSize }, + ...createTensorShapeVariables(newInputShape, newOutputShape), + ], + }; + }, + getShaderSource, + }; + } + + getShaderSource = (shaderHelper: ShaderHelper) => { + const input = inputVariable('a', inputDataType, newInputShape.length); + const output = outputVariable('output', inputDataType, newOutputShape.length); + return ` ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} ${permFunctionBody(perm, inputRank, input, output)} @@ -106,17 +171,15 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu ${output.setByOffset('global_idx', input.getByIndices('aIndices'))} }`; - } + }; return { - name: useShared ? 'TransposeShared' : 'Transpose', + name: 'Transpose', shaderCache: { hint: `${permAttr}`, inputDependencies: ['rank'] }, getRunData: () => { const outputSize = ShapeUtil.size(outputShape); return { outputs: [{ dims: outputShape, dataType: inputTensor.dataType }], - dispatchGroup: useShared - ? { x: Math.ceil(newOutputShape[1] / tileSize), y: Math.ceil(newOutputShape[0] / tileSize) } - : { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, programUniforms: [ { type: DataType.uint32, data: outputSize }, ...createTensorShapeVariables(newInputShape, newOutputShape), diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index c5b8f579c3aae..2c5180c5db3ee 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -93,13 +93,23 @@ export class ProgramManager { build(programInfo: ProgramInfo, normalizedDispatchGroupSize: [number, number, number]): Artifact { TRACE_FUNC_BEGIN(programInfo.name); const device = this.backend.device; - const extensions: string[] = []; - if (device.features.has('shader-f16')) { - extensions.push('enable f16;'); - } + const enableDirectives: string[] = []; + + // Enable WGSL extensions based on available WebGPU features + const extensionsInfo: Array<{ feature: GPUFeatureName; extension: string }> = [ + { feature: 'shader-f16', extension: 'f16' }, + { feature: 'subgroups' as GPUFeatureName, extension: 'subgroups' }, + { feature: 'subgroups-f16' as GPUFeatureName, extension: 'subgroups_f16' }, + ]; + extensionsInfo.forEach((info) => { + if (device.features.has(info.feature)) { + enableDirectives.push(`enable ${info.extension};`); + } + }); + const shaderHelper = createShaderHelper(normalizedDispatchGroupSize, this.backend.device.limits); const userCode = programInfo.getShaderSource(shaderHelper); - const code = `${extensions.join('\n')}\n${shaderHelper.additionalImplementations}\n${userCode}`; + const code = `${enableDirectives.join('\n')}\n${shaderHelper.additionalImplementations}\n${userCode}`; const shaderModule = device.createShaderModule({ code, label: programInfo.name }); LOG_DEBUG('verbose', () => `[WebGPU] ${programInfo.name} shader code: ${code}`); diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index 3b3c55733c973..9321ac170d036 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -21,6 +21,11 @@ export interface AdapterInfo { isArchitecture: (architecture: GpuArchitecture) => boolean; isVendor: (vendor: GpuVendor) => boolean; } +export interface DeviceInfo { + readonly subgroupsSupported: boolean; + readonly subgroupsF16Supported: boolean; + readonly subgroupSizeRange?: readonly [number, number]; +} export interface GpuData { type: GpuDataType; @@ -160,6 +165,11 @@ export interface ComputeContext { */ readonly adapterInfo: AdapterInfo; + /** + * gpu device info + */ + readonly deviceInfo: DeviceInfo; + /** * stores the pointer to OpKernelContext */ @@ -187,8 +197,6 @@ export interface ComputeContext { compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[]; output(index: number, dims: readonly number[]): number; - getMaxComputeWorkgroupSizes(): [number, number, number]; - getMaxComputeWorkgroupStoragesize(): number; } export type TimestampQuery = 'none' | 'inside-passes' | 'at-passes'; diff --git a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts index 916dec4545af3..4932691bda65b 100644 --- a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts +++ b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts @@ -54,6 +54,33 @@ export interface TensorManager { let tensorGuid = 1; const createNewTensorId = (): TensorId => tensorGuid++; +/** + * Map from MLOperandDataType to size in bits. Using bits instead of bytes to avoid possible precision loss on int4 and uint4. + */ +const webnnDataTypeToSize = new Map([ + ['float32', 32], + ['float16', 16], + ['int32', 32], + ['uint32', 32], + ['int64', 64], + ['uint64', 64], + ['int8', 8], + ['uint8', 8], + ['int4', 4], + ['uint4', 4], +]); + +/** + * Calculate the byte length of a tensor with the given data type and shape. + */ +const calculateByteLength = (dataType: MLOperandDataType, shape: readonly number[]): number => { + const size = webnnDataTypeToSize.get(dataType); + if (!size) { + throw new Error('Unsupported data type.'); + } + return shape.length > 0 ? Math.ceil((shape.reduce((a, b) => a * b) * size) / 8) : 0; +}; + /** * TensorWrapper wraps an MLTensor and provides a way to track the last session that used it. */ @@ -92,6 +119,10 @@ class TensorWrapper { return this.tensorShape; } + public get byteLength(): number { + return calculateByteLength(this.dataType, this.tensorShape); + } + public destroy(): void { LOG_DEBUG('verbose', () => '[WebNN] TensorWrapper.destroy'); this.mlTensor.destroy(); @@ -111,7 +142,11 @@ class TensorWrapper { } public sameTypeAndShape(dataType: MLOperandDataType, shape: readonly number[]): boolean { - return this.dataType === dataType && this.tensorShape.every((v, i) => v === shape[i]); + return ( + this.dataType === dataType && + this.tensorShape.length === shape.length && + this.tensorShape.every((v, i) => v === shape[i]) + ); } } @@ -136,6 +171,7 @@ class TensorIdTracker { public releaseTensor(): void { if (this.tensorWrapper) { this.tensorManager.releaseTensor(this.tensorWrapper); + this.wrapper = undefined; } } @@ -149,6 +185,9 @@ class TensorIdTracker { return this.wrapper.tensor; } else { if (copyOld) { + if (this.wrapper.byteLength !== calculateByteLength(dataType, shape)) { + throw new Error('Unable to copy data to tensor with different size.'); + } this.activeUpload = new Uint8Array(await this.wrapper.read()); } this.tensorManager.releaseTensor(this.wrapper); @@ -156,7 +195,7 @@ class TensorIdTracker { } // eslint-disable-next-line no-bitwise - const usage = MLTensorUsage.READ | MLTensorUsage.WRITE; + const usage = typeof MLTensorUsage == 'undefined' ? undefined : MLTensorUsage.READ | MLTensorUsage.WRITE; this.wrapper = await this.tensorManager.getCachedTensor(dataType, shape, usage, true, true); if (copyOld && this.activeUpload) { @@ -169,8 +208,13 @@ class TensorIdTracker { public upload(data: Uint8Array): void { if (this.wrapper) { - this.wrapper.write(data); - return; + if (data.byteLength === this.wrapper.byteLength) { + this.wrapper.write(data); + return; + } else { + LOG_DEBUG('verbose', () => 'Data size does not match tensor size. Releasing tensor.'); + this.releaseTensor(); + } } if (this.activeUpload) { @@ -305,13 +349,14 @@ class TensorManagerImpl implements TensorManager { public async getCachedTensor( dataType: MLOperandDataType, shape: readonly number[], - usage: MLTensorUsageFlags, + usage: MLTensorUsageFlags | undefined, writable: boolean, readable: boolean, ): Promise { const sessionId = this.backend.currentSessionId; for (const [index, tensor] of this.freeTensors.entries()) { if (tensor.sameTypeAndShape(dataType, shape)) { + LOG_DEBUG('verbose', () => `[WebNN] Reusing tensor {dataType: ${dataType}, shape: ${shape}}`); const wrapper = this.freeTensors.splice(index, 1)[0]; wrapper.sessionId = sessionId; return wrapper; diff --git a/js/web/lib/wasm/jsep/webnn/webnn.d.ts b/js/web/lib/wasm/jsep/webnn/webnn.d.ts index 2620168738dac..c513b2ec2ed8b 100644 --- a/js/web/lib/wasm/jsep/webnn/webnn.d.ts +++ b/js/web/lib/wasm/jsep/webnn/webnn.d.ts @@ -36,8 +36,8 @@ interface MLOperandDescriptor { dimensions?: readonly number[]; } interface MLOperand { - dataType(): MLOperandDataType; - shape(): number[]; + dataType: MLOperandDataType; + shape: readonly number[]; } interface MLActivation {} type MLNamedOperands = Record; @@ -400,7 +400,8 @@ declare const MLTensorUsage: { }; interface MLTensorDescriptor extends MLOperandDescriptor { - usage: MLTensorUsageFlags; + /** @deprecated Use readable/writeable instead of usage */ + usage: MLTensorUsageFlags | undefined; importableToWebGPU?: boolean; readable?: boolean; writable?: boolean; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index eb74aa44b3a72..da8939cd0263a 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -291,9 +291,6 @@ export const createSession = async ( const providerName = typeof provider === 'string' ? provider : provider.name; if (providerName === 'webnn') { wasm.shouldTransferToMLTensor = false; - if (wasm.currentContext) { - throw new Error('WebNN execution provider is already set.'); - } if (typeof provider !== 'string') { const webnnOptions = provider as InferenceSession.WebNNExecutionProviderOption; const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context; @@ -303,12 +300,12 @@ export const createSession = async ( if (context) { wasm.currentContext = context as MLContext; } else if (gpuDevice) { - wasm.currentContext = await navigator.ml.createContext(gpuDevice); + wasm.currentContext = await wasm.jsepCreateMLContext!(gpuDevice); } else { - wasm.currentContext = await navigator.ml.createContext({ deviceType, powerPreference }); + wasm.currentContext = await wasm.jsepCreateMLContext!({ deviceType, powerPreference }); } } else { - wasm.currentContext = await navigator.ml.createContext(); + wasm.currentContext = await wasm.jsepCreateMLContext!(); } break; } @@ -490,7 +487,7 @@ export const prepareInputOutputTensor = ( } if (location === 'gpu-buffer') { - const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; + const gpuBuffer = tensor[2].gpuBuffer; dataByteLength = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(dataType), dims)!; const registerBuffer = wasm.jsepRegisterBuffer; diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index dff3ca74de5a4..ebeac5dc9e587 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -225,6 +225,30 @@ export declare namespace JSEP { * @returns the MLTensor ID for the external MLTensor. */ jsepRegisterMLTensor: (tensor: MLTensor, onnxDataType: DataType, dimensions: readonly number[]) => number; + + /** + * [exported from pre-jsep.js] Create an MLContext from a GPUDevice or MLContextOptions. + * @param optionsOrGpuDevice - specify the options or GPUDevice. + * @returns + */ + jsepCreateMLContext(optionsOrGpuDevice?: MLContextOptions | GPUDevice): Promise; + + /** + * [exported from pre-jsep.js] Register a WebNN Constant operand from external data. + * @param externalFilePath - specify the external file path. + * @param dataOffset - specify the external data offset. + * @param dataLength - specify the external data length. + * @param builder - specify the MLGraphBuilder used for constructing the Constant. + * @param desc - specify the MLOperandDescriptor of the Constant. + * @returns the WebNN Constant operand for the specified external data. + */ + jsepRegisterMLConstant( + externalFilePath: string, + dataOffset: number, + dataLength: number, + builder: MLGraphBuilder, + desc: MLOperandDescriptor, + ): MLOperand; } } diff --git a/js/web/package-lock.json b/js/web/package-lock.json index 894667ad58933..07c8f0bf3b940 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -861,9 +861,9 @@ } }, "node_modules/cross-spawn": { - "version": "6.0.5", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.5.tgz", - "integrity": "sha512-eTVLrBSt7fjbDygz805pMnstIs2VTBNkRm0qxZd+M7A5XDdxVRWO5MxGBXZhjY4cqLYLdtrGqRf8mBPmzwSpWQ==", + "version": "6.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.6.tgz", + "integrity": "sha512-VqCUuhcd1iB+dsv8gxPttb5iZh/D0iubSP21g36KXdEuf6I5JiioesUVjpCdHV9MZRUfVFlvwtIUyPfxo5trtw==", "dev": true, "dependencies": { "nice-try": "^1.0.4", @@ -4312,9 +4312,9 @@ } }, "cross-spawn": { - "version": "6.0.5", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.5.tgz", - "integrity": "sha512-eTVLrBSt7fjbDygz805pMnstIs2VTBNkRm0qxZd+M7A5XDdxVRWO5MxGBXZhjY4cqLYLdtrGqRf8mBPmzwSpWQ==", + "version": "6.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.6.tgz", + "integrity": "sha512-VqCUuhcd1iB+dsv8gxPttb5iZh/D0iubSP21g36KXdEuf6I5JiioesUVjpCdHV9MZRUfVFlvwtIUyPfxo5trtw==", "dev": true, "requires": { "nice-try": "^1.0.4", diff --git a/js/web/package.json b/js/web/package.json index 1ba06b3953748..181d6127f5455 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -78,25 +78,21 @@ "types": "./types.d.ts" }, "./all": { - "node": null, "import": "./dist/ort.all.bundle.min.mjs", "require": "./dist/ort.all.min.js", "types": "./types.d.ts" }, "./wasm": { - "node": null, - "import": "./dist/ort.wasm.min.mjs", + "import": "./dist/ort.wasm.bundle.min.mjs", "require": "./dist/ort.wasm.min.js", "types": "./types.d.ts" }, "./webgl": { - "node": null, "import": "./dist/ort.webgl.min.mjs", "require": "./dist/ort.webgl.min.js", "types": "./types.d.ts" }, "./webgpu": { - "node": null, "import": "./dist/ort.webgpu.bundle.min.mjs", "require": "./dist/ort.webgpu.min.js", "types": "./types.d.ts" diff --git a/js/web/script/build.ts b/js/web/script/build.ts index 408f9e00a5cbd..529e9d1065e69 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -591,14 +591,14 @@ async function main() { // ort[.min].[m]js await addAllWebBuildTasks({ outputName: 'ort', - define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_JSEP': 'true' }, + define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true' }, }); // ort.bundle.min.mjs await buildOrt({ isProduction: true, outputName: 'ort.bundle', format: 'esm', - define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_JSEP': 'true', 'BUILD_DEFS.DISABLE_DYNAMIC_IMPORT': 'true' }, + define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true', 'BUILD_DEFS.DISABLE_DYNAMIC_IMPORT': 'true' }, }); // ort.webgpu[.min].[m]js @@ -619,6 +619,13 @@ async function main() { outputName: 'ort.wasm', define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_JSEP': 'true', 'BUILD_DEFS.DISABLE_WEBGL': 'true' }, }); + // ort.wasm.bundle.min.mjs + await buildOrt({ + isProduction: true, + outputName: 'ort.wasm.bundle', + format: 'esm', + define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_JSEP': 'true', 'BUILD_DEFS.DISABLE_WEBGL': 'true' }, + }); // ort.webgl[.min].[m]js await addAllWebBuildTasks({ outputName: 'ort.webgl', diff --git a/js/web/test/data/ops/expand.jsonc b/js/web/test/data/ops/expand.jsonc index 613b4507b2b15..8fbe9339feb9b 100644 --- a/js/web/test/data/ops/expand.jsonc +++ b/js/web/test/data/ops/expand.jsonc @@ -134,6 +134,56 @@ "type": "float32" } ] + }, + { + "name": "Expand in components = 1, out components = 4", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [3, 2, 1], + "type": "float32" + }, + { + "data": [3, 1, 8], + "dims": [3], + "type": "int64" + } + ], + "outputs": [ + { + "data": [ + 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, + 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6 + ], + "dims": [3, 2, 8], + "type": "float32" + } + ] + }, + { + "name": "Expand in components = 4, out components = 4", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [1, 1, 2, 8], + "type": "float32" + }, + { + "data": [2, 1, 8], + "dims": [3], + "type": "int64" + } + ], + "outputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16 + ], + "dims": [1, 2, 2, 8], + "type": "float32" + } + ] } ] }, diff --git a/js/web/test/data/ops/gather-nd.jsonc b/js/web/test/data/ops/gather-nd.jsonc new file mode 100644 index 0000000000000..209c7d1f74087 --- /dev/null +++ b/js/web/test/data/ops/gather-nd.jsonc @@ -0,0 +1,147 @@ +[ + { + "name": "GatherND int32", + "operator": "GatherND", + "attributes": [], + "cases": [ + { + "name": "data[4] indices[]", + "inputs": [ + { + "data": [100, 101, 102, 777, 778, 779, 1000, 1001, 1002], + "dims": [9], + "type": "int32" + }, + { + "data": [0, 4, 8], + "dims": [3, 1], + "type": "int64" + } + ], + "outputs": [ + { + "data": [100, 778, 1002], + "dims": [3], + "type": "int32" + } + ] + } + ] + }, + { + "name": "GatherND float32", + "operator": "GatherND", + "attributes": [], + "cases": [ + { + "name": "data[4] indices[]", + "inputs": [ + { + "data": [100.1, 101.2, 102.3, 777.4, 778.5, 779.6, 1000.7, 1001.8, 1002.9], + "dims": [9], + "type": "float32" + }, + { + "data": [0, 4, 8], + "dims": [3, 1], + "type": "int64" + } + ], + "outputs": [ + { + "data": [100.0999984741211, 778.5, 1002.9000244140625], + "dims": [3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GatherND int32 [2 2 2], batch_dims", + "operator": "GatherND", + "attributes": [{ "name": "batch_dims", "data": 1, "type": "int" }], + "cases": [ + { + "name": "data[4] indices[]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7], + "dims": [2, 2, 2], + "type": "int32" + }, + { + "data": [1, 0], + "dims": [2, 1], + "type": "int64" + } + ], + "outputs": [ + { + "data": [2, 3, 4, 5], + "dims": [2, 2], + "type": "int32" + } + ] + } + ] + }, + { + "name": "GatherND float16", + "operator": "GatherND", + "attributes": [], + "cases": [ + { + "name": "data[4] indices[]", + "inputs": [ + { + "data": [100.1, 101.2, 102.3, 777.4, 778.5, 779.6, 1000.7, 1001.8, 1002.9], + "dims": [9], + "type": "float16" + }, + { + "data": [0, 4, 8], + "dims": [3, 1], + "type": "int64" + } + ], + "outputs": [ + { + "data": [100.0999984741211, 778.5, 1002.9000244140625], + "dims": [3], + "type": "float16" + } + ] + } + ] + }, + { + "name": "GatherND uint32 [2 2 2], batch_dims", + "operator": "GatherND", + "attributes": [{ "name": "batch_dims", "data": 1, "type": "int" }], + "cases": [ + { + "name": "data[4] indices[]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7], + "dims": [2, 2, 2], + "type": "uint32" + }, + { + "data": [1, 0], + "dims": [2, 1], + "type": "int64" + } + ], + "outputs": [ + { + "data": [2, 3, 4, 5], + "dims": [2, 2], + "type": "uint32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/matmul.jsonc b/js/web/test/data/ops/matmul.jsonc index 2c2cf509d7e3e..f5996db1aecb6 100644 --- a/js/web/test/data/ops/matmul.jsonc +++ b/js/web/test/data/ops/matmul.jsonc @@ -95,6 +95,56 @@ } ] }, + { + "name": "multiplies 3D tensors with M = 1", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 17, 18, 19, 20, 21, 22, 23, 24, 9, 10, 11, 12, 13, 14, 15, 16, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 1, 2, 3, 4, 5, 6, 7, 8 + ], + "dims": [6, 1, 8], + "type": "float32" + }, + { + "data": [13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], + "dims": [1, 8, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [478, 514, 550, 2270, 2434, 2598, 1374, 1474, 1574, 590, 634, 678, 1486, 1594, 1702, 478, 514, 550], + "dims": [6, 1, 3], + "type": "float32" + } + ] + }, + { + "name": "multiplies 4D tensors with M = 1", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 17, 18, 19, 20, 21, 22, 23, 24, 9, 10, 11, 12, 13, 14, 15, 16, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 1, 2, 3, 4, 5, 6, 7, 8 + ], + "dims": [2, 3, 1, 8], + "type": "float32" + }, + { + "data": [13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], + "dims": [1, 1, 8, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [478, 514, 550, 2270, 2434, 2598, 1374, 1474, 1574, 590, 634, 678, 1486, 1594, 1702, 478, 514, 550], + "dims": [2, 3, 1, 3], + "type": "float32" + } + ] + }, { "name": "multiplies 4D tensors", "inputs": [ @@ -313,6 +363,100 @@ "type": "float32" } ] + }, + { + "name": "same ranks different broadcast small 0", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7], + "dims": [1, 2, 2, 2], + "type": "float32" + }, + { + "data": [8, 9, 10, 11], + "dims": [2, 1, 2, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [9, 43, 77, 111, 11, 53, 95, 137], + "dims": [2, 2, 2, 1], + "type": "float32" + } + ] + }, + { + "name": "same ranks different broadcast small 1", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7], + "dims": [2, 1, 2, 2], + "type": "float32" + }, + { + "data": [8, 9, 10, 11], + "dims": [1, 2, 2, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [9, 43, 11, 53, 77, 111, 95, 137], + "dims": [2, 2, 2, 1], + "type": "float32" + } + ] + }, + { + "name": "same ranks different broadcast larger 0", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ], + "dims": [1, 2, 2, 8], + "type": "float32" + }, + { + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [2, 1, 8, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1036, 3308, 5580, 7852, 1260, 4044, 6828, 9612], + "dims": [2, 2, 2, 1], + "type": "float32" + } + ] + }, + { + "name": "same ranks different broadcast larger 1", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ], + "dims": [2, 1, 2, 8], + "type": "float32" + }, + { + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 2, 8, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1036, 3308, 1260, 4044, 5580, 7852, 6828, 9612], + "dims": [2, 2, 2, 1], + "type": "float32" + } + ] } ] } diff --git a/js/web/test/data/ops/scatternd.jsonc b/js/web/test/data/ops/scatternd.jsonc new file mode 100644 index 0000000000000..5135bb9e4d3a5 --- /dev/null +++ b/js/web/test/data/ops/scatternd.jsonc @@ -0,0 +1,472 @@ +[ + { + "name": "ScatterND int32", + "operator": "ScatterND", + "attributes": [], + "opset": { "domain": "", "version": 13 }, + "cases": [ + { + "name": "int32", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [8], + "type": "int32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [9, 10, 11, 12], + "dims": [1, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 11, 3, 10, 9, 6, 7, 12], + "dims": [8], + "type": "int32" + } + ] + }, + { + "name": "int32", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 + ], + "dims": [4, 4, 4], + "type": "int32" + }, + { + "data": [1, 2], + "dims": [2, 1], + "type": "int64" + }, + { + "data": [ + 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, + 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131 + ], + "dims": [2, 4, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, + 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, + 131, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64 + ], + "dims": [4, 4, 4], + "type": "int32" + } + ] + } + ] + }, + { + "name": "ScatterND float32", + "operator": "ScatterND", + "attributes": [], + "opset": { "domain": "", "version": 13 }, + "cases": [ + { + "name": "float32", + "inputs": [ + { + "data": [1.1, 2.2, 3.1, 4.5, 5.3, 6.1, 7.8, 8.9], + "dims": [8], + "type": "float32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [9.1, 10.2, 11.3, 12.5], + "dims": [1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1.1, 11.3, 3.1, 10.2, 9.1, 6.1, 7.8, 12.5], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ScatterND add int32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "add", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "int32", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [8], + "type": "int32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [9, 10, 11, 12], + "dims": [1, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 13, 3, 14, 14, 6, 7, 20], + "dims": [8], + "type": "int32" + } + ] + } + ] + }, + { + "name": "ScatterND add float32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "add", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "float32", + "inputs": [ + { + "data": [1.1, 2.2, 3.1, 4.5, 5.3, 6.1, 7.8, 8.9], + "dims": [8], + "type": "float32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [9.1, 10.2, 11.3, 12.5], + "dims": [1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 1.100000023841858, 13.5, 3.0999999046325684, 14.699999809265137, 14.40000057220459, 6.099999904632568, + 7.800000190734863, 21.399999618530273 + ], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ScatterND mul int32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "mul", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "int32", + "inputs": [ + { + "data": [11, 22, 31, 45, 53, 61, 78, 89], + "dims": [8], + "type": "int32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [91, 102, 113, 125], + "dims": [1, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [11, 2486, 31, 4590, 4823, 61, 78, 11125], + "dims": [8], + "type": "int32" + } + ] + } + ] + }, + { + "name": "ScatterND min int32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "min", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "int32", + "inputs": [ + { + "data": [11, 22, 31, 45, 53, 61, 78, 89], + "dims": [8], + "type": "int32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [91, 102, 113, 125], + "dims": [1, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [11, 22, 31, 45, 53, 61, 78, 89], + "dims": [8], + "type": "int32" + } + ] + } + ] + }, + { + "name": "ScatterND max int32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "max", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "int32", + "inputs": [ + { + "data": [11, 22, 31, 45, 53, 61, 78, 89], + "dims": [8], + "type": "int32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [91, 102, 113, 125], + "dims": [1, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [11, 113, 31, 102, 91, 61, 78, 125], + "dims": [8], + "type": "int32" + } + ] + } + ] + }, + { + "name": "ScatterND mul float32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "mul", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "float32", + "inputs": [ + { + "data": [1.1, 2.2, 3.1, 4.5, 5.3, 6.1, 7.8, 8.9], + "dims": [8], + "type": "float32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [9.1, 10.2, 11.3, 12.5], + "dims": [1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 1.100000023841858, 24.860000610351562, 3.0999999046325684, 45.89999771118164, 48.230003356933594, + 6.099999904632568, 7.800000190734863, 111.24999237060547 + ], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ScatterND min float32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "min", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "float32", + "inputs": [ + { + "data": [1.1, 2.2, 3.1, 4.5, 5.3, 6.1, 7.8, 8.9], + "dims": [8], + "type": "float32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [9.1, 10.2, 11.3, 12.5], + "dims": [1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 1.100000023841858, 2.200000047683716, 3.0999999046325684, 4.5, 5.300000190734863, 6.099999904632568, + 7.800000190734863, 8.899999618530273 + ], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ScatterND max float32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "max", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "float32", + "inputs": [ + { + "data": [1.1, 2.2, 3.1, 4.5, 5.3, 6.1, 7.8, 8.9], + "dims": [8], + "type": "float32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [9.1, 10.2, 11.3, 12.5], + "dims": [1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 1.100000023841858, 11.300000190734863, 3.0999999046325684, 10.199999809265137, 9.100000381469727, + 6.099999904632568, 7.800000190734863, 12.5 + ], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ScatterND float16", + "operator": "ScatterND", + "attributes": [], + "opset": { "domain": "", "version": 11 }, + "cases": [ + { + "name": "float16", + "inputs": [ + { + "data": [1.1, 2.2, 3.1, 4.5, 5.3, 6.1, 7.8, 8.9], + "dims": [8], + "type": "float16" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [9.1, 10.2, 11.3, 12.5], + "dims": [1, 4], + "type": "float16" + } + ], + "outputs": [ + { + "data": [1.1, 11.3, 3.1, 10.2, 9.1, 6.1, 7.8, 12.5], + "dims": [8], + "type": "float16" + } + ] + } + ] + }, + { + "name": "ScatterND mul uint32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "mul", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "uint32", + "inputs": [ + { + "data": [11, 22, 31, 45, 53, 61, 78, 89], + "dims": [8], + "type": "uint32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [91, 102, 113, 125], + "dims": [1, 4], + "type": "uint32" + } + ], + "outputs": [ + { + "data": [11, 2486, 31, 4590, 4823, 61, 78, 11125], + "dims": [8], + "type": "uint32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/transpose.jsonc b/js/web/test/data/ops/transpose.jsonc index a7265d6444118..d431ceb1712a5 100644 --- a/js/web/test/data/ops/transpose.jsonc +++ b/js/web/test/data/ops/transpose.jsonc @@ -263,6 +263,30 @@ } ] }, + { + "name": "Transpose as reshape - perms:[1, 0, 2, 4, 3]", + "operator": "Transpose", + "attributes": [{ "name": "perm", "data": [1, 0, 2, 4, 3], "type": "ints" }], + "cases": [ + { + "name": "T[3, 1, 2, 1, 4]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], + "dims": [3, 1, 2, 1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], + "dims": [1, 3, 2, 4, 1], + "type": "float32" + } + ] + } + ] + }, { "name": "Transpose - perms:[1, 0]", "operator": "Transpose", diff --git a/js/web/test/e2e/browser-test-wasm-binary-override.js b/js/web/test/e2e/browser-test-wasm-binary-override.js index 471c26f6990b5..27cce2ca06236 100644 --- a/js/web/test/e2e/browser-test-wasm-binary-override.js +++ b/js/web/test/e2e/browser-test-wasm-binary-override.js @@ -7,7 +7,7 @@ const documentUrl = document.currentScript.src; it('Browser E2E testing - WebAssembly backend', async function () { // preload .wasm file binary - const wasmUrl = new URL('./node_modules/onnxruntime-web/dist/ort-wasm-simd-threaded.wasm', documentUrl).href; + const wasmUrl = new URL('./node_modules/onnxruntime-web/dist/ort-wasm-simd-threaded.jsep.wasm', documentUrl).href; const response = await fetch(wasmUrl); // make sure the .wasm file is loaded successfully diff --git a/js/web/test/e2e/browser-test-wasm-path-override-filename-jsep.js b/js/web/test/e2e/browser-test-wasm-path-override-filename-jsep.js new file mode 100644 index 0000000000000..d325a5ca7187d --- /dev/null +++ b/js/web/test/e2e/browser-test-wasm-path-override-filename-jsep.js @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +'use strict'; + +it('Browser E2E testing - WebAssembly backend (path override filename)', async function () { + // check base URL port from test args + if (typeof __ort_arg_port === 'undefined') { + throw new Error('test flag --port= is required'); + } + const base = `http://localhost:${__ort_arg_port}/`; + + ort.env.wasm.wasmPaths = {}; + + if (typeof __ort_arg_files === 'string' && __ort_arg_files.includes('wasm')) { + const overrideWasmUrl = new URL('./test-wasm-path-override/jsep-renamed.wasm', base).href; + console.log(`ort.env.wasm.wasmPaths['wasm'] = ${JSON.stringify(overrideWasmUrl)};`); + ort.env.wasm.wasmPaths.wasm = overrideWasmUrl; + } + + if (typeof __ort_arg_files === 'string' && __ort_arg_files.includes('mjs')) { + const overrideMjsUrl = new URL('./test-wasm-path-override/jsep-renamed.mjs', base).href; + console.log(`ort.env.wasm.wasmPaths['mjs'] = ${JSON.stringify(overrideMjsUrl)};`); + ort.env.wasm.wasmPaths.mjs = overrideMjsUrl; + } + + await testFunction(ort, { executionProviders: ['wasm'] }); +}); diff --git a/js/web/test/e2e/run-data.js b/js/web/test/e2e/run-data.js index 04079b042bc23..dbc3ca0bd2460 100644 --- a/js/web/test/e2e/run-data.js +++ b/js/web/test/e2e/run-data.js @@ -14,7 +14,7 @@ const NODEJS_TEST_CASES = [ // [test_for_same_origin, test_for_cross_origin, main_js, ort_main_js, [test_args]] const BROWSER_TEST_CASES = [ // IIFE - [true, true, './browser-test-webgl.js', 'ort.min.js'], // webgl + [true, true, './browser-test-webgl.js', 'ort.all.min.js'], // webgl [true, true, './browser-test-webgl.js', 'ort.webgl.min.js'], // webgl [true, true, './browser-test-wasm.js', 'ort.wasm.min.js'], // wasm, ort.wasm [true, true, './browser-test-wasm-multi-session-create.js', 'ort.min.js'], // wasm, multi-session create @@ -24,7 +24,7 @@ const BROWSER_TEST_CASES = [ [true, true, './browser-test-wasm.js', 'ort.min.js', ['num_threads=1', 'proxy=1']], // wasm, 1 thread, proxy // ort.min.mjs - [true, true, './browser-test-webgl.js', 'ort.min.mjs'], // webgl + [true, true, './browser-test-webgl.js', 'ort.webgl.min.mjs'], // webgl [true, true, './browser-test-wasm.js', 'ort.min.mjs', ['num_threads=1']], // wasm, 1 thread [true, true, './browser-test-wasm.js', 'ort.min.mjs', ['num_threads=2']], // wasm, 2 threads [true, true, './browser-test-wasm.js', 'ort.min.mjs', ['num_threads=2', 'proxy=1']], // wasm, 2 threads, proxy @@ -41,22 +41,22 @@ const BROWSER_TEST_CASES = [ // path override: // wasm, path override filenames for both mjs and wasm, same origin - [true, false, './browser-test-wasm-path-override-filename.js', 'ort.min.js', ['port=9876', 'files=mjs,wasm']], + [true, false, './browser-test-wasm-path-override-filename-jsep.js', 'ort.min.js', ['port=9876', 'files=mjs,wasm']], [true, false, './browser-test-wasm-path-override-filename.js', 'ort.wasm.min.js', ['port=9876', 'files=mjs,wasm']], // wasm, path override filenames for both mjs and wasm, cross origin - [false, true, './browser-test-wasm-path-override-filename.js', 'ort.min.js', ['port=8081', 'files=mjs,wasm']], + [false, true, './browser-test-wasm-path-override-filename-jsep.js', 'ort.min.js', ['port=8081', 'files=mjs,wasm']], [false, true, './browser-test-wasm-path-override-filename.js', 'ort.wasm.min.js', ['port=8081', 'files=mjs,wasm']], // wasm, path override filename for wasm, same origin - [true, false, './browser-test-wasm-path-override-filename.js', 'ort.min.js', ['port=9876', 'files=wasm']], + [true, false, './browser-test-wasm-path-override-filename-jsep.js', 'ort.min.js', ['port=9876', 'files=wasm']], [true, false, './browser-test-wasm-path-override-filename.js', 'ort.wasm.min.js', ['port=9876', 'files=wasm']], // wasm, path override filename for wasm, cross origin - [false, true, './browser-test-wasm-path-override-filename.js', 'ort.min.js', ['port=8081', 'files=wasm']], + [false, true, './browser-test-wasm-path-override-filename-jsep.js', 'ort.min.js', ['port=8081', 'files=wasm']], [false, true, './browser-test-wasm-path-override-filename.js', 'ort.wasm.min.js', ['port=8081', 'files=wasm']], // wasm, path override filename for mjs, same origin - [true, false, './browser-test-wasm-path-override-filename.js', 'ort.min.js', ['port=9876', 'files=mjs']], + [true, false, './browser-test-wasm-path-override-filename-jsep.js', 'ort.min.js', ['port=9876', 'files=mjs']], [true, false, './browser-test-wasm-path-override-filename.js', 'ort.wasm.min.js', ['port=9876', 'files=mjs']], // wasm, path override filename for mjs, cross origin - [false, true, './browser-test-wasm-path-override-filename.js', 'ort.min.js', ['port=8081', 'files=mjs']], + [false, true, './browser-test-wasm-path-override-filename-jsep.js', 'ort.min.js', ['port=8081', 'files=mjs']], [false, true, './browser-test-wasm-path-override-filename.js', 'ort.wasm.min.js', ['port=8081', 'files=mjs']], // wasm, path override prefix, same origin [true, false, './browser-test-wasm-path-override-prefix.js', 'ort.min.js', ['port=9876']], diff --git a/js/web/test/e2e/run.js b/js/web/test/e2e/run.js index 93f9d4a144bf2..3361bbece64ed 100644 --- a/js/web/test/e2e/run.js +++ b/js/web/test/e2e/run.js @@ -146,6 +146,10 @@ function prepareWasmPathOverrideFiles() { fs.copyFileSync(`${sourceFile}.wasm`, path.join(folder, 'ort-wasm-simd-threaded.wasm')); fs.copyFileSync(`${sourceFile}.mjs`, path.join(folder, 'renamed.mjs')); fs.copyFileSync(`${sourceFile}.wasm`, path.join(folder, 'renamed.wasm')); + fs.copyFileSync(`${sourceFile}.jsep.mjs`, path.join(folder, 'ort-wasm-simd-threaded.jsep.mjs')); + fs.copyFileSync(`${sourceFile}.jsep.wasm`, path.join(folder, 'ort-wasm-simd-threaded.jsep.wasm')); + fs.copyFileSync(`${sourceFile}.jsep.mjs`, path.join(folder, 'jsep-renamed.mjs')); + fs.copyFileSync(`${sourceFile}.jsep.wasm`, path.join(folder, 'jsep-renamed.wasm')); } async function testAllNodejsCases() { diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index dcfc8ccc3928f..f179756967d49 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -570,14 +570,14 @@ "test_greater_equal_expanded", "test_greater_equal", "test_greater", - // // "test_gridsample_aligncorners_true", - // // "test_gridsample_bicubic", - // // "test_gridsample_bilinear", - // // "test_gridsample_border_padding", - // // "test_gridsample_nearest", - // // "test_gridsample_reflection_padding", - // // "test_gridsample_zeros_padding", - // // "test_gridsample", + "test_gridsample_aligncorners_true", + "test_gridsample_bicubic", + "test_gridsample_bilinear", + "test_gridsample_border_padding", + "test_gridsample_nearest", + "test_gridsample_reflection_padding", + "test_gridsample_zeros_padding", + "test_gridsample", // // "test_gru_batchwise", // // "test_gru_defaults", // // "test_gru_seq_length", @@ -1365,6 +1365,7 @@ "gather.jsonc", "gather-block-quantized.jsonc", "gather-elements.jsonc", + "gather-nd.jsonc", "gemm.jsonc", "global-average-pool.jsonc", "greater.jsonc", @@ -1396,6 +1397,7 @@ "pow-big-number.jsonc", "reshape.jsonc", "rotary-embedding.jsonc", + "scatternd.jsonc", "simplified-layer-norm.jsonc", "skip-layer-norm.jsonc", "skip-simplified-layer-norm.jsonc", @@ -1532,14 +1534,14 @@ "test_add_bcast", // "test_add_uint8", "test_add", - // "test_and_bcast3v1d", - // "test_and_bcast3v2d", - // "test_and_bcast4v2d", - // "test_and_bcast4v3d", - // "test_and_bcast4v4d", - // "test_and2d", - // "test_and3d", - // "test_and4d", + "test_and_bcast3v1d", + "test_and_bcast3v2d", + "test_and_bcast4v2d", + "test_and_bcast4v3d", + "test_and_bcast4v4d", + "test_and2d", + "test_and3d", + "test_and4d", "test_argmax_default_axis_example_select_last_index", "test_argmax_default_axis_example", "test_argmax_default_axis_random_select_last_index", @@ -1699,13 +1701,13 @@ "test_cos", // "test_cosh_example", // "test_cosh", - // "test_cumsum_1d_exclusive", - // "test_cumsum_1d_reverse_exclusive", - // "test_cumsum_1d_reverse", - // "test_cumsum_1d", - // "test_cumsum_2d_axis_0", - // "test_cumsum_2d_axis_1", - // "test_cumsum_2d_negative_axis", + "test_cumsum_1d_exclusive", + "test_cumsum_1d_reverse_exclusive", + "test_cumsum_1d_reverse", + "test_cumsum_1d", + "test_cumsum_2d_axis_0", + "test_cumsum_2d_axis_1", + "test_cumsum_2d_negative_axis", // "test_depthtospace_crd_mode_example", // "test_depthtospace_crd_mode", // "test_depthtospace_dcr_mode", @@ -1777,9 +1779,9 @@ "test_gather_elements_0", "test_gather_elements_1", "test_gather_elements_negative_indices", - // "test_gathernd_example_float32", - // "test_gathernd_example_int32_batch_dim1", - // "test_gathernd_example_int32", + "test_gathernd_example_float32", + "test_gathernd_example_int32_batch_dim1", + "test_gathernd_example_int32", "test_gemm_all_attributes", "test_gemm_alpha", "test_gemm_beta", @@ -2089,14 +2091,14 @@ // // "test_optional_get_element", // // "test_optional_has_element_empty", // // "test_optional_has_element", - // "test_or_bcast3v1d", - // "test_or_bcast3v2d", - // "test_or_bcast4v2d", - // "test_or_bcast4v3d", - // "test_or_bcast4v4d", - // "test_or2d", - // "test_or3d", - // "test_or4d", + "test_or_bcast3v1d", + "test_or_bcast3v2d", + "test_or_bcast4v2d", + "test_or_bcast4v3d", + "test_or_bcast4v4d", + "test_or2d", + "test_or3d", + "test_or4d", "test_pow_bcast_array", "test_pow_bcast_scalar", "test_pow_example", @@ -2254,15 +2256,15 @@ // // "test_round", // // "test_scan_sum", // // "test_scan9_sum", - // // "test_scatter_elements_with_axis", - // // "test_scatter_elements_with_duplicate_indices", - // // "test_scatter_elements_with_negative_indices", - // // "test_scatter_elements_without_axis", + "test_scatter_elements_with_axis", + "test_scatter_elements_with_duplicate_indices", + "test_scatter_elements_with_negative_indices", + "test_scatter_elements_without_axis", // // "test_scatter_with_axis", // // "test_scatter_without_axis", - // // "test_scatternd_add", - // // "test_scatternd_multiply", - // // "test_scatternd", + "test_scatternd_add", + "test_scatternd_multiply", + "test_scatternd", // // "test_sce_mean_3d_expanded", // // "test_sce_mean_3d_log_prob_expanded", // // "test_sce_mean_3d_log_prob", @@ -2352,7 +2354,7 @@ // "test_shrink_soft", "test_sigmoid_example", "test_sigmoid", - // "test_sign", + "test_sign", // "test_simple_rnn_batchwise", // "test_simple_rnn_defaults", // "test_simple_rnn_with_initial_bias", @@ -2362,14 +2364,14 @@ // "test_sinh", // // "test_size_example", // // "test_size", - // "test_slice_default_axes", - // "test_slice_default_steps", - // "test_slice_end_out_of_bounds", - // "test_slice_neg_steps", - // "test_slice_neg", - // "test_slice_negative_axes", - // "test_slice_start_out_of_bounds", - // "test_slice", + "test_slice_default_axes", + "test_slice_default_steps", + "test_slice_end_out_of_bounds", + "test_slice_neg_steps", + "test_slice_neg", + "test_slice_negative_axes", + "test_slice_start_out_of_bounds", + "test_slice", // "test_softmax_axis_0_expanded", "test_softmax_axis_0", // "test_softmax_axis_1_expanded", @@ -2550,16 +2552,16 @@ "test_unsqueeze", // "test_wrap_pad" // "test_upsample_nearest", - "test_where_example" + "test_where_example", // "test_where_long_example", - // "test_xor_bcast3v1d", - // "test_xor_bcast3v2d", - // "test_xor_bcast4v2d", - // "test_xor_bcast4v3d", - // "test_xor_bcast4v4d", - // "test_xor2d", - // "test_xor3d", - // "test_xor4d" + "test_xor_bcast3v1d", + "test_xor_bcast3v2d", + "test_xor_bcast4v2d", + "test_xor_bcast4v3d", + "test_xor_bcast4v4d", + "test_xor2d", + "test_xor3d", + "test_xor4d" ], "ops": [] } diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index c37c10c781400..5de39535a5c07 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -586,11 +586,11 @@ export class TensorResultValidator { } } -function createGpuTensorForInput(cpuTensor: ort.Tensor): ort.Tensor { +async function createGpuTensorForInput(cpuTensor: ort.Tensor): Promise { if (!isGpuBufferSupportedType(cpuTensor.type) || Array.isArray(cpuTensor.data)) { throw new Error(`createGpuTensorForInput can not work with ${cpuTensor.type} tensor`); } - const device = ort.env.webgpu.device as GPUDevice; + const device = await ort.env.webgpu.device; const gpuBuffer = device.createBuffer({ // eslint-disable-next-line no-bitwise usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE, @@ -612,14 +612,14 @@ function createGpuTensorForInput(cpuTensor: ort.Tensor): ort.Tensor { }); } -function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[]) { +async function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[]) { if (!isGpuBufferSupportedType(type)) { throw new Error(`createGpuTensorForOutput can not work with ${type} tensor`); } const size = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(type), dims)!; - const device = ort.env.webgpu.device as GPUDevice; + const device = await ort.env.webgpu.device; const gpuBuffer = device.createBuffer({ // eslint-disable-next-line no-bitwise usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE, @@ -661,7 +661,7 @@ async function createMLTensorForOutput(mlContext: MLContext, type: ort.Tensor.Ty shape: dims as number[], // Assign both shape and dimensions while transitioning to new API. dimensions: dims as number[], - usage: MLTensorUsage.READ, + usage: typeof MLTensorUsage == 'undefined' ? undefined : MLTensorUsage.READ, readable: true, }); @@ -686,7 +686,7 @@ async function createMLTensorForInput(mlContext: MLContext, cpuTensor: ort.Tenso shape: cpuTensor.dims as number[], // Assign both shape and dimensions while transitioning to new API. dimensions: cpuTensor.dims as number[], - usage: MLTensorUsage.WRITE, + usage: typeof MLTensorUsage == 'undefined' ? undefined : MLTensorUsage.WRITE, writable: true, }); mlContext.writeTensor(mlTensor, cpuTensor.data); @@ -725,7 +725,7 @@ export async function sessionRun(options: { if (options.ioBinding === 'ml-location' || options.ioBinding === 'ml-tensor') { feeds[name] = await createMLTensorForInput(options.mlContext!, feeds[name]); } else { - feeds[name] = createGpuTensorForInput(feeds[name]); + feeds[name] = await createGpuTensorForInput(feeds[name]); } } } @@ -742,7 +742,7 @@ export async function sessionRun(options: { if (options.ioBinding === 'ml-tensor') { fetches[name] = await createMLTensorForOutput(options.mlContext!, type, dims); } else { - fetches[name] = createGpuTensorForOutput(type, dims); + fetches[name] = await createGpuTensorForOutput(type, dims); } } } diff --git a/objectivec/error_utils.mm b/objectivec/error_utils.mm index 335cf8894d549..e8d4d5bb365c9 100644 --- a/objectivec/error_utils.mm +++ b/objectivec/error_utils.mm @@ -11,7 +11,7 @@ void ORTSaveCodeAndDescriptionToError(int code, const char* descriptionCstr, NSE if (!error) return; NSString* description = [NSString stringWithCString:descriptionCstr - encoding:NSASCIIStringEncoding]; + encoding:NSUTF8StringEncoding]; *error = [NSError errorWithDomain:kOrtErrorDomain code:code diff --git a/objectivec/include/ort_coreml_execution_provider.h b/objectivec/include/ort_coreml_execution_provider.h index d7d873f5eb0e0..41d15aa39453a 100644 --- a/objectivec/include/ort_coreml_execution_provider.h +++ b/objectivec/include/ort_coreml_execution_provider.h @@ -70,7 +70,22 @@ NS_ASSUME_NONNULL_BEGIN */ - (BOOL)appendCoreMLExecutionProviderWithOptions:(ORTCoreMLExecutionProviderOptions*)options error:(NSError**)error; - +/** + * Enables the CoreML execution provider in the session configuration options. + * It is appended to the execution provider list which is ordered by + * decreasing priority. + * + * @param provider_options The CoreML execution provider options in dict. + * available keys-values: more detail in core/providers/coreml/coreml_execution_provider.h + * kCoremlProviderOption_MLComputeUnits: one of "CPUAndNeuralEngine", "CPUAndGPU", "CPUOnly", "All" + * kCoremlProviderOption_ModelFormat: one of "MLProgram", "NeuralNetwork" + * kCoremlProviderOption_RequireStaticInputShapes: "1" or "0" + * kCoremlProviderOption_EnableOnSubgraphs: "1" or "0" + * @param error Optional error information set if an error occurs. + * @return Whether the provider was enabled successfully. + */ +- (BOOL)appendCoreMLExecutionProviderWithOptionsV2:(NSDictionary*)provider_options + error:(NSError**)error; @end NS_ASSUME_NONNULL_END diff --git a/objectivec/ort_coreml_execution_provider.mm b/objectivec/ort_coreml_execution_provider.mm index 6cb5026b93521..0c790a91fb8b9 100644 --- a/objectivec/ort_coreml_execution_provider.mm +++ b/objectivec/ort_coreml_execution_provider.mm @@ -43,6 +43,21 @@ - (BOOL)appendCoreMLExecutionProviderWithOptions:(ORTCoreMLExecutionProviderOpti #endif } +- (BOOL)appendCoreMLExecutionProviderWithOptionsV2:(NSDictionary*)provider_options + error:(NSError**)error { +#if ORT_OBJC_API_COREML_EP_AVAILABLE + try { + return [self appendExecutionProvider:@"CoreML" providerOptions:provider_options error:error]; + } + ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error); + +#else // !ORT_OBJC_API_COREML_EP_AVAILABLE + static_cast(provider_options); + ORTSaveCodeAndDescriptionToError(ORT_FAIL, "CoreML execution provider is not enabled.", error); + return NO; +#endif +} + @end NS_ASSUME_NONNULL_END diff --git a/objectivec/test/ort_session_test.mm b/objectivec/test/ort_session_test.mm index 508289f7bc748..409ee7e1584e2 100644 --- a/objectivec/test/ort_session_test.mm +++ b/objectivec/test/ort_session_test.mm @@ -223,6 +223,28 @@ - (void)testAppendCoreMLEP { ORTAssertNullableResultSuccessful(session, err); } +- (void)testAppendCoreMLEP_v2 { + NSError* err = nil; + ORTSessionOptions* sessionOptions = [ORTSessionTest makeSessionOptions]; + NSDictionary* provider_options = @{@"EnableOnSubgraphs" : @"1"}; // set an arbitrary option + + BOOL appendResult = [sessionOptions appendCoreMLExecutionProviderWithOptionsV2:provider_options + error:&err]; + + if (!ORTIsCoreMLExecutionProviderAvailable()) { + ORTAssertBoolResultUnsuccessful(appendResult, err); + return; + } + + ORTAssertBoolResultSuccessful(appendResult, err); + + ORTSession* session = [[ORTSession alloc] initWithEnv:self.ortEnv + modelPath:[ORTSessionTest getAddModelPath] + sessionOptions:sessionOptions + error:&err]; + ORTAssertNullableResultSuccessful(session, err); +} + - (void)testAppendXnnpackEP { NSError* err = nil; ORTSessionOptions* sessionOptions = [ORTSessionTest makeSessionOptions]; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index b15e865aa423c..ad14fb8258656 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -30,7 +30,6 @@ class Attention : public OpKernel, public AttentionCPUBase { Status Compute(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; @@ -102,7 +101,6 @@ bool Attention::IsPackWeightsSuccessful(int qkv_index, template Status Attention::PrePack(const Tensor& weights, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { /* The PrePack() massages the weights to speed up Compute(), there is an option to diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index dc9ba80af5ba4..87938f3728750 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -77,7 +77,7 @@ class AttentionCPUBase : public AttentionBase { // Convert mask from boolean (0/1) to float (mask_filter_value/0.0f). // Merge padding mask with causal mask, and broadcast to 3D (BxSxT). PrepareMask(mask_index_data, mask_index_dims, static_cast(mask_data), - causal, batch_size, sequence_length, past_sequence_length, mask_filter_value_); + causal, batch_size, sequence_length, kv_sequence_length, past_sequence_length, mask_filter_value_); DUMP_CPU_TENSOR("Mask3D", static_cast(mask_data), batch_size, sequence_length, total_sequence_length); } diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index 4d435f71cc195..37bb5664393c9 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -120,9 +120,10 @@ void PrepareMask(const int32_t* mask_index, bool causal, int batch_size, int sequence_length, + int kv_sequence_length, int past_sequence_length, float mask_filter_value) { - const int all_sequence_length = past_sequence_length + sequence_length; + const int all_sequence_length = past_sequence_length + kv_sequence_length; // mask_data has been filled with 0, and its shape is BxSxT T* p_mask = mask_data; diff --git a/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.cc index b2aaa9cb11beb..e6f65f92e14f4 100644 --- a/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.cc @@ -339,6 +339,7 @@ void DecoderMaskedMultiHeadAttention::ComputeAttentionProbsWithBeams( T* attention_probs_ptr = reinterpret_cast(attention_probs) + last_offset; math::Dot(head_size, q_vec, K + i * head_size, attention_probs_ptr, nullptr); + *attention_probs_ptr *= scale; // Apply the attention bias and mask if (attn_bias_data != nullptr) { *attention_probs_ptr += attn_bias_data[attn_bias_base_offset + past_sequence_length]; @@ -348,7 +349,6 @@ void DecoderMaskedMultiHeadAttention::ComputeAttentionProbsWithBeams( if (is_masked) { *attention_probs_ptr += mask_filter_value_; } - *attention_probs_ptr *= scale; } { @@ -362,6 +362,8 @@ void DecoderMaskedMultiHeadAttention::ComputeAttentionProbsWithBeams( const T* past_k_vec = past_key_data + beam_batch_offset + beam_offset + j * head_size; T* output = reinterpret_cast(attention_probs) + j + i * probs_matrix_size; math::Dot(head_size, q_vec, past_k_vec, output, nullptr); + + *output *= scale; // Apply the attention bias and mask if (attn_bias_data != nullptr) { *output += attn_bias_data[attn_bias_base_offset + j]; @@ -371,11 +373,11 @@ void DecoderMaskedMultiHeadAttention::ComputeAttentionProbsWithBeams( if (is_masked) { *output += mask_filter_value_; } - *output *= scale; } } // Append current key to present key (past_present_share_buffer_ is true) - memcpy(present_key_data + i * max_sequence_length * head_size, K + i * head_size, head_size * sizeof(T)); + memcpy(present_key_data + (i * max_sequence_length + past_sequence_length) * head_size, + K + i * head_size, head_size * sizeof(T)); } }); @@ -460,7 +462,7 @@ void DecoderMaskedMultiHeadAttention::ComputeVxAttentionScoreWithBeams( } } // Append current value to present value (past_present_share_buffer_ is true) - memcpy(present_value_data + i * max_sequence_length * v_head_size, + memcpy(present_value_data + (i * max_sequence_length + past_sequence_length) * v_head_size, V + i * v_head_size, v_head_size * sizeof(T)); } diff --git a/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.h b/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.h index 68d1b9751301d..d5167e8989669 100644 --- a/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.h +++ b/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.h @@ -33,7 +33,7 @@ class DecoderMaskedMultiHeadAttention final : public OpKernel, public AttentionC const Tensor* cache_indir, OpKernelContext* context, int beam_width, - Tensor* scaled_qk = nullptr) const; + Tensor* output_qk = nullptr) const; void ComputeAttentionProbsWithBeams(T* attention_probs, const T* Q, const T* K, @@ -50,7 +50,7 @@ class DecoderMaskedMultiHeadAttention final : public OpKernel, public AttentionC bool broadcast_attn_bias_dim_1, const int32_t* cache_indir_data, int beam_width, - T* scaled_qk_data = nullptr) const; + T* output_qk_data = nullptr) const; void ComputeVxAttentionScoreWithBeams(T* output, T* tmp_buffer, const T* attention_probs, diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index 0bdee151d2173..4cc5a4228dc8c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -11,18 +11,19 @@ namespace onnxruntime { namespace contrib { namespace group_query_attention_helper { -Status CheckInputs(const Tensor* query, - const Tensor* key, - const Tensor* value, - const Tensor* past_key, - const Tensor* past_value, - const Tensor* cos_cache, - const Tensor* sin_cache, +template +Status CheckInputs(const T* query, + const T* key, + const T* value, + const T* past_key, + const T* past_value, + const T* cos_cache, + const T* sin_cache, void* parameters, int num_heads, int kv_num_heads, - const Tensor* seqlens_k, - const Tensor* total_seqlen, + const T* seqlens_k, + const T* total_seqlen, float scale, float softcap) { // Note: Here S* is seqlen_past_kv_cache, S+ is seqlen_present_kv_cache @@ -265,18 +266,19 @@ Status CheckInputs(const Tensor* query, return Status::OK(); } -Status CheckInputs(const Tensor* query, - const Tensor* key, - const Tensor* value, - const Tensor* past_key, - const Tensor* past_value, - const Tensor* cos_cache, - const Tensor* sin_cache, +template +Status CheckInputs(const T* query, + const T* key, + const T* value, + const T* past_key, + const T* past_value, + const T* cos_cache, + const T* sin_cache, void* parameters, int num_heads, int kv_num_heads, - const Tensor* seqlens_k, - const Tensor* total_seqlen, + const T* seqlens_k, + const T* total_seqlen, float scale, float softcap, int max_threads_per_block) { diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc index cbfd2f0949363..9a6c2af022c91 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc @@ -4,6 +4,7 @@ #include "contrib_ops/cpu/bert/rotary_embedding.h" #include "contrib_ops/cpu/bert/rotary_embedding_helper.h" +#include "core/mlas/inc/mlas.h" #include "core/platform/threadpool.h" using onnxruntime::concurrency::ThreadPool; @@ -78,31 +79,12 @@ Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters paramete const T* cos_data = cos_cache + cache_offset; const T* sin_data = sin_cache + cache_offset; - int cache_idx = 0; - bool sign = false; - int j = 0; - for (int i = 0; i < rotary_emb_dim; i++) { - if (interleaved) { - cache_idx = (i / 2) % half_rotary_emb_dim; - sign = i & 1; - j = sign ? i - 1 : i + 1; // i - sign - } else { - cache_idx = i % half_rotary_emb_dim; - sign = (i >= half_rotary_emb_dim); - j = (i + half_rotary_emb_dim) % rotary_emb_dim; - } - float output_data_i = static_cast(input_data[i]) * static_cast(cos_data[cache_idx]); - float input_data_j = static_cast(input_data[j]); - float sin_data_cache_idx = static_cast(sin_data[cache_idx]); - if (sign) { - output_data_i += input_data_j * sin_data_cache_idx; - } else { - output_data_i -= input_data_j * sin_data_cache_idx; - } - output_data[i] = static_cast(output_data_i); - } - for (int i = rotary_emb_dim; i < head_size; i++) { - output_data[i] = input_data[i]; + MlasRotaryEmbedOneRow(input_data, sin_data, cos_data, rotary_emb_dim, interleaved, output_data); + + if (rotary_emb_dim < head_size) { + std::memcpy(output_data + rotary_emb_dim, + input_data + rotary_emb_dim, + (head_size - rotary_emb_dim) * sizeof(T)); } } }); diff --git a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc index 71a66ea368943..2c897f183164f 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc @@ -24,7 +24,6 @@ class QAttention : public OpKernel, public AttentionCPUBase { Status Compute(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, bool& /*out*/ is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; @@ -59,7 +58,6 @@ QAttention::QAttention(const OpKernelInfo& info) : OpKernel(info), AttentionC template Status QAttention::PrePack(const Tensor& weights, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { if (1 != input_idx) { diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc index 4148aae4b9a35..aa47f365c0005 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc @@ -13,7 +13,7 @@ class DynamicQuantizeLSTM : public OpKernel, public LSTMBase { DynamicQuantizeLSTM(const OpKernelInfo& info) : OpKernel(info), LSTMBase(info) {} Status PrePack(const Tensor& tensor, int input_idx, - AllocatorPtr alloc, bool save_prepacked_initializers, /*out*/ bool& is_packed, + AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, @@ -91,7 +91,6 @@ static void UseSharedPrePackedBuffersImpl(std::vector& prepacke } Status DynamicQuantizeLSTM::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index cee3dfc6b3f28..c3e43f897c509 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -32,24 +32,47 @@ constexpr size_t A = 0, bias = 5; }; -int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level_attr) { - const auto accuracy_level = std::clamp(accuracy_level_attr, - static_cast(CompMostAccurate), - static_cast(CompLeastAccurate)); - - // Find a supported accuracy level that is not less accurate than the one given. - // CompMostAccurate is always supported with the fallback implementation. - // Note: A higher numeric accuracy level value means lower accuracy, so the comparison order is reversed. - int64_t effective_accuracy_level = accuracy_level; - for (; effective_accuracy_level > CompMostAccurate; --effective_accuracy_level) { - const auto compute_type = static_cast(effective_accuracy_level); - if (MlasIsSQNBitGemmAvailable(nbits, block_size, compute_type)) { - break; - } - } +typedef enum { + Level0, /*!< input fp32, accumulator fp32 */ + Level1, /*!< input fp32, accumulator fp32 */ + Level2, /*!< input fp16, accumulator fp16 */ + Level3, /*!< input bf16, accumulator fp32 */ + Level4, /*!< input int8, accumulator int32 */ +} ACCURACY_LEVEL; + +// T: A data type. +template +MLAS_QNBIT_GEMM_COMPUTE_TYPE +GetComputeType(size_t nbits, size_t block_size, int64_t accuracy_level_attr) { + // For Fp32, only accuracy level 1 or 4 makes sense. + // non-ARM CPU converts Fp16 to Fp32. + // By converting Fp32 to Fp16, precision becomes worse. And due to the casting, + // there is no performance gain. + if (accuracy_level_attr == static_cast(Level4) && + MlasIsQNBitGemmAvailable(nbits, block_size, SQNBIT_CompInt8)) { + return SQNBIT_CompInt8; + } + + return SQNBIT_CompFp32; +} - return effective_accuracy_level; +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) +template <> +MLAS_QNBIT_GEMM_COMPUTE_TYPE +GetComputeType(size_t nbits, size_t block_size, int64_t accuracy_level_attr) { + // For Fp16, only accuracy level 2 or 4 makes sense. + // By converting Fp16 to Fp32, there is not precision increase, and the performance + // becomes worse. + if (accuracy_level_attr == static_cast(Level4) && + MlasIsQNBitGemmAvailable(nbits, block_size, HQNBIT_CompInt8)) { + return HQNBIT_CompInt8; + } + + // if HQNBIT_CompFp16 is not supported, will fallback to unpacked computation. + return HQNBIT_CompFp16; } +#endif // !MLAS_F16VEC_INTRINSICS_SUPPORTED || !MLAS_TARGET_ARM64 + } // namespace bool GetType(const NodeArg& node_arg, int32_t& type) { @@ -74,10 +97,9 @@ class MatMulNBits final : public OpKernel { N_{narrow(info.GetAttr("N"))}, block_size_{narrow(info.GetAttr("block_size"))}, nbits_{narrow(info.GetAttr("bits"))}, - accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr("accuracy_level"))}, has_g_idx_{info.GetInputCount() > InputIndex::g_idx && info.node().InputDefs()[InputIndex::g_idx]->Exists()}, has_bias_{info.GetInputCount() > InputIndex::bias && info.node().InputDefs()[InputIndex::bias]->Exists()}, - compute_type_{static_cast(accuracy_level_)} { + compute_type_{GetComputeType(nbits_, block_size_, info.GetAttr("accuracy_level"))} { const auto& node = info.node(); auto input_defs = node.InputDefs(); const NodeArg* zero_point_arg = @@ -98,36 +120,26 @@ class MatMulNBits final : public OpKernel { Status Compute(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; - void ConvertPrepackWeightIntoTensor(const onnxruntime::Tensor& tensor, int input_idx); - Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, /*out*/ bool& used_shared_buffers) override; - std::optional GetPrePackTensor(int /*input_idx*/) override; - - Status SetPrePackTensor(int input_idx, const Tensor& pre_packed_tensor) override; - private: const size_t K_; const size_t N_; const size_t block_size_; const size_t nbits_; - const int64_t accuracy_level_; const bool has_g_idx_; const bool has_bias_; - const MLAS_SQNBIT_GEMM_COMPUTE_TYPE compute_type_; + const MLAS_QNBIT_GEMM_COMPUTE_TYPE compute_type_; bool has_unquantized_zero_point_{false}; const bool column_wise_quant_{true}; IAllocatorUniquePtr packed_b_{}; size_t packed_b_size_{0}; IAllocatorUniquePtr scales_fp32_{}; IAllocatorUniquePtr bias_fp32_{}; - std::optional packed_tensor_{std::nullopt}; - MLDataType prepack_tensor_data_type_; bool has_zp_input_{false}; @@ -152,27 +164,11 @@ class MatMulNBits final : public OpKernel { Tensor* y, AllocatorPtr& allocator, concurrency::ThreadPool* thread_pool, - const MatMulComputeHelper& helper) const { - ORT_THROW("ComputeBPacked is not supported for T1 type."); - } + const MatMulComputeHelper& helper) const; }; -template -void MatMulNBits::ConvertPrepackWeightIntoTensor(const onnxruntime::Tensor& tensor, int input_idx) { - if (input_idx == InputIndex::B) { - prepack_tensor_data_type_ = tensor.DataType(); - } - - TensorShapeVector weights_dims = {static_cast((packed_b_size_ - 1) / prepack_tensor_data_type_->Size()) + 1}; - packed_tensor_ = Tensor(prepack_tensor_data_type_, - TensorShape(weights_dims), - packed_b_.get(), - OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)); -} - template Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { ORT_UNUSED_PARAMETER(prepacked_weights); @@ -181,43 +177,40 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All return Status::OK(); } - if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { + if (!MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { return Status::OK(); } if (input_idx == InputIndex::B) { - packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type_); + packed_b_size_ = MlasQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type_); if (packed_b_size_ == 0) { return Status::OK(); } auto qptr = tensor.DataRaw(); packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), nullptr, has_zp_input_, nullptr, nullptr); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), nullptr, has_zp_input_, nullptr, nullptr); is_packed = true; - } else if (compute_type_ == CompInt8) { + } else if (compute_type_ == SQNBIT_CompInt8) { #ifdef MLAS_TARGET_AMD64_IX86 if (input_idx == InputIndex::scales && packed_b_ != nullptr) { auto sptr = tensor.Data(); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr, - has_zp_input_, nullptr, nullptr); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr, + has_zp_input_, nullptr, nullptr); is_packed = false; } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { auto zptr = tensor.Data(); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr); is_packed = false; } #endif // MLAS_TARGET_AMD64_IX86 } - if (save_prepacked_initializers) { - ConvertPrepackWeightIntoTensor(tensor, input_idx); - } - return Status::OK(); } +#if !defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || !defined(MLAS_TARGET_ARM64) +// Non-ARM-with-fp16-intrinsics fall back fp16 to fp32. template <> Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { ORT_UNUSED_PARAMETER(prepacked_weights); @@ -239,64 +232,37 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*ou return Status::OK(); } - if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { + if (!MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { return Status::OK(); } if (input_idx == InputIndex::B) { - packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type_); + packed_b_size_ = MlasQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type_); if (packed_b_size_ == 0) { return Status::OK(); } auto qptr = tensor.DataRaw(); packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), - nullptr, has_zp_input_, nullptr, nullptr); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), + nullptr, has_zp_input_, nullptr, nullptr); is_packed = true; - } else if (compute_type_ == CompInt8) { + } else if (compute_type_ == SQNBIT_CompInt8) { #ifdef MLAS_TARGET_AMD64_IX86 if (input_idx == InputIndex::scales && packed_b_ != nullptr) { - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), - scales_fp32_.get(), has_zp_input_, nullptr, nullptr); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), + scales_fp32_.get(), has_zp_input_, nullptr, nullptr); is_packed = false; } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { auto zptr = tensor.Data(); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), - nullptr, has_zp_input_, zptr, nullptr); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), + nullptr, has_zp_input_, zptr, nullptr); is_packed = false; } #endif // MLAS_TARGET_AMD64_IX86 } - if (save_prepacked_initializers) { - ConvertPrepackWeightIntoTensor(tensor, input_idx); - } - - return Status::OK(); -} - -template -std::optional MatMulNBits::GetPrePackTensor(int input_idx) { - // For this kernel, prepack is performed on input_B, and possibly scales, zeros_points. - // During compute process, scales and zeros_points will keep as it is and only use prepacked - // buffer to replace input_B. - // Inorder to cope with this logic, we need to return latest prepacked buffer and only serialize - // the latest one. So, we need to always return packed_tensor_ here not only for input_B. - ORT_UNUSED_PARAMETER(input_idx); - return std::move(packed_tensor_); -} - -template -Status MatMulNBits::SetPrePackTensor(int input_idx, const Tensor& pre_packed_tensor) { - if (input_idx == 1) { - // pre_packed_tensor is constant initialized tensor and its lifecycle is managed by session_state, - // session_state will release memory from pre_packed_tensor. packed_b_ will not release memory so - // pass empty/default buffer deleter here. - // const_cast here is temporary, will fix in follow up PR. - packed_b_ = BufferUniquePtr(const_cast(pre_packed_tensor.DataRaw()), BufferDeleter()); - } - return Status::OK(); } +#endif // end !MLAS_F16VEC_INTRINSICS_SUPPORTED || !MLAS_TARGET_ARM64 template Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, @@ -311,20 +277,20 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& return Status::OK(); } -template <> -Status MatMulNBits::ComputeBPacked(const Tensor* a, - const Tensor* scales, - const Tensor* zero_points, - const Tensor* bias, - Tensor* y, - AllocatorPtr& allocator, - concurrency::ThreadPool* thread_pool, - const MatMulComputeHelper& helper) const { - const auto* a_data = a->Data(); - const auto* scales_data = scales->Data(); +template +Status MatMulNBits::ComputeBPacked(const Tensor* a, + const Tensor* scales, + const Tensor* zero_points, + const Tensor* bias, + Tensor* y, + AllocatorPtr& allocator, + concurrency::ThreadPool* thread_pool, + const MatMulComputeHelper& helper) const { + const auto* a_data = a->Data(); + const auto* scales_data = scales->Data(); const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); - const auto* bias_data = bias == nullptr ? nullptr : bias->Data(); - auto* y_data = y->MutableData(); + const auto* bias_data = bias == nullptr ? nullptr : bias->Data(); + auto* y_data = y->MutableData(); const size_t batch_count = helper.OutputOffsets().size(); const size_t M = static_cast(helper.M()); @@ -333,19 +299,19 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, const size_t lda = helper.Lda(false); IAllocatorUniquePtr workspace{}; - const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize( + const size_t workspace_size = MlasQNBitGemmBatchWorkspaceSize( M, N, K, batch_count, nbits_, block_size_, compute_type_); if (workspace_size > 0) { // Use reserve since no caching is needed workspace = IAllocator::MakeUniquePtr(allocator, workspace_size, true); } - InlinedVector data(batch_count); + InlinedVector> data(batch_count); for (size_t i = 0; i < batch_count; ++i) { data[i].A = a_data + helper.LeftOffsets()[i]; data[i].lda = lda; #ifdef MLAS_TARGET_AMD64_IX86 - if (compute_type_ == CompInt8) { + if (compute_type_ == SQNBIT_CompInt8) { data[i].QuantBDataWorkspace = packed_b_.get(); } #endif @@ -356,11 +322,12 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, data[i].C = y_data + helper.OutputOffsets()[i]; data[i].ldc = N; } - MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type_, data.data(), workspace.get(), - thread_pool); + MlasQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type_, data.data(), workspace.get(), + thread_pool); return Status::OK(); } +#if !defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || !defined(MLAS_TARGET_ARM64) template <> Status MatMulNBits::ComputeBPacked(const Tensor* a, const Tensor* scales, @@ -383,7 +350,7 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, const size_t lda = helper.Lda(false); IAllocatorUniquePtr workspace{}; - const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize( + const size_t workspace_size = MlasQNBitGemmBatchWorkspaceSize( M, N, K, batch_count, nbits_, block_size_, compute_type_); if (workspace_size > 0) { // Use reserve since no caching is needed @@ -417,12 +384,12 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, size_t c_size = static_cast(y->Shape().Size()); std::vector c_v(c_size); - InlinedVector data(batch_count); + InlinedVector> data(batch_count); for (size_t i = 0; i < batch_count; ++i) { data[i].A = tmp_a_data_ptr.get() + helper.LeftOffsets()[i]; data[i].lda = lda; #ifdef MLAS_TARGET_AMD64_IX86 - if (compute_type_ == CompInt8) { + if (compute_type_ == SQNBIT_CompInt8) { data[i].QuantBDataWorkspace = packed_b_.get(); } #endif @@ -433,11 +400,12 @@ Status MatMulNBits::ComputeBPacked(const Tensor* a, data[i].C = c_v.data() + helper.OutputOffsets()[i]; data[i].ldc = N; } - MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type_, data.data(), workspace.get(), - thread_pool); + MlasQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type_, data.data(), workspace.get(), + thread_pool); MlasConvertFloatToHalfBuffer(c_v.data(), y_data, c_size); return Status::OK(); } +#endif // end of !MLAS_F16VEC_INTRINSICS_SUPPORTED || !MLAS_TARGET_AMD64 template <> Status MatMulNBits::ComputeBUnpacked(const Tensor* a, @@ -573,9 +541,10 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, const size_t ldb = helper.Ldb(true); float* scales_ptr = nullptr; + IAllocatorUniquePtr temp_scales; if (!scales_fp32_) { auto scales_size = static_cast(scales->Shape().Size()); - auto temp_scales = IAllocator::MakeUniquePtr(allocator, scales_size, true); + temp_scales = IAllocator::MakeUniquePtr(allocator, scales_size, true); MlasConvertHalfToFloatBuffer(scales_data, temp_scales.get(), scales_size); scales_ptr = temp_scales.get(); } else { @@ -656,8 +625,9 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, if (bias) { float* bias_ptr = nullptr; const size_t bias_size = static_cast(bias->Shape().Size()); + IAllocatorUniquePtr bias_temp; if (!bias_fp32_) { - auto bias_temp = IAllocator::MakeUniquePtr(allocator, bias_size, true); + bias_temp = IAllocator::MakeUniquePtr(allocator, bias_size, true); MlasConvertHalfToFloatBuffer(bias->Data(), bias_temp.get(), bias_size); bias_ptr = bias_temp.get(); } else { @@ -710,11 +680,11 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { // clang-format on if (has_single_b_matrix && - packed_b_) { // Assume that MlasSQNBitGemmBatch() always requires packed B. - // If this changes, i.e., if MlasIsSQNBitGemmAvailable() can return true while - // MlasSQNBitGemmPackQuantBDataSize() returns 0, we can consider calling MlasSQNBitGemmBatch() + packed_b_) { // Assume that MlasQNBitGemmBatch() always requires packed B. + // If this changes, i.e., if MlasIsQNBitGemmAvailable() can return true while + // MlasQNBitGemmPackQuantBDataSize() returns 0, we can consider calling MlasQNBitGemmBatch() // with B directly too. - if (MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { + if (MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { return ComputeBPacked(a, scales, zero_points, bias, y, allocator, thread_pool, helper); } } diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc index c9ee9e2cb760d..d5b8961cf8c5a 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc @@ -46,24 +46,13 @@ void ComputeJob( const T* gamma_data, const T* beta_data, const T* bias_data, - IAllocatorUniquePtr& skip_float_uptr, - IAllocatorUniquePtr& gamma_float_uptr, - IAllocatorUniquePtr& beta_float_uptr, - IAllocatorUniquePtr& bias_float_uptr, ptrdiff_t task_idx, int hidden_size, int64_t skip_size, float epsilon, bool simplified, T* output_data, - T* skip_input_bias_add_output_data, - AllocatorPtr alloc) { - ORT_UNUSED_PARAMETER(skip_float_uptr); // only used in MLFloat16 overload - ORT_UNUSED_PARAMETER(gamma_float_uptr); // only used in MLFloat16 overload - ORT_UNUSED_PARAMETER(beta_float_uptr); // only used in MLFloat16 overload - ORT_UNUSED_PARAMETER(bias_float_uptr); // only used in MLFloat16 overload - ORT_UNUSED_PARAMETER(alloc); - + T* skip_input_bias_add_output_data) { auto offset = task_idx * hidden_size; const T* p_input = input_data + offset; const T* p_skip = skip_data + (offset % skip_size); @@ -107,101 +96,6 @@ void ComputeJob( } } -void ComputeJob( - const MLFloat16* input_data, - const MLFloat16* skip_data, - const MLFloat16* gamma_data, - const MLFloat16* beta_data, - const MLFloat16* bias_data, - IAllocatorUniquePtr& skip_float_uptr, - IAllocatorUniquePtr& gamma_float_uptr, - IAllocatorUniquePtr& beta_float_uptr, - IAllocatorUniquePtr& bias_float_uptr, - ptrdiff_t task_idx, - int hidden_size, - int64_t skip_size, - float epsilon, - bool simplified, - MLFloat16* output_data, - MLFloat16* skip_input_bias_add_output_data, - AllocatorPtr alloc) { - auto offset = task_idx * hidden_size; - const MLFloat16* p_input = input_data + offset; - const MLFloat16* p_skip = skip_data + (offset % skip_size); - MLFloat16* p_output = output_data + offset; - MLFloat16* p_skip_input_bias_add_output = skip_input_bias_add_output_data == nullptr ? nullptr : skip_input_bias_add_output_data + offset; - - float mean(0.0f); - float mean_square(0.0f); - const size_t num_elems = static_cast(hidden_size); - - IAllocatorUniquePtr input_float_uptr = IAllocator::MakeUniquePtr(alloc, num_elems); - MlasConvertHalfToFloatBuffer(p_input, input_float_uptr.get(), num_elems); - - if (!skip_float_uptr) { - skip_float_uptr = IAllocator::MakeUniquePtr(alloc, num_elems); - MlasConvertHalfToFloatBuffer(p_skip, skip_float_uptr.get(), num_elems); - } - - if (bias_data && !bias_float_uptr) { - bias_float_uptr = IAllocator::MakeUniquePtr(alloc, num_elems); - MlasConvertHalfToFloatBuffer(bias_data, bias_float_uptr.get(), num_elems); - } - - IAllocatorUniquePtr output_float_uptr = IAllocator::MakeUniquePtr(alloc, num_elems); - float* output_float_ptr = output_float_uptr.get(); - - const float* input_float_ptr = input_float_uptr.get(); - const float* skip_float_ptr = skip_float_uptr.get(); - const float* bias_float_ptr = bias_float_uptr.get(); - for (size_t h = 0; h < num_elems; h++) { - float val = input_float_ptr[h] + skip_float_ptr[h]; - - if (bias_float_uptr) { - val += bias_float_ptr[h]; - } - - output_float_ptr[h] = val; - mean += val; - mean_square += val * val; - } - - if (nullptr != p_skip_input_bias_add_output) { - MlasConvertFloatToHalfBuffer(output_float_ptr, p_skip_input_bias_add_output, num_elems); - } - - mean = mean / hidden_size; - if (simplified) { - mean_square = sqrt(mean_square / hidden_size + epsilon); - } else { - mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon); - } - - if (!gamma_float_uptr) { - gamma_float_uptr = std::move(input_float_uptr); // overwrite input with gamma values, since they have the same size - MlasConvertHalfToFloatBuffer(gamma_data, gamma_float_uptr.get(), num_elems); - } - - if (beta_data && !beta_float_uptr) { - beta_float_uptr = IAllocator::MakeUniquePtr(alloc, num_elems); - MlasConvertHalfToFloatBuffer(beta_data, beta_float_uptr.get(), num_elems); - } - - const float* gamma_float_ptr = gamma_float_uptr.get(); - const float* beta_float_ptr = beta_float_uptr.get(); - for (size_t h = 0; h < num_elems; h++) { - if (simplified) { - output_float_ptr[h] = output_float_ptr[h] / mean_square * gamma_float_ptr[h]; - } else if (nullptr == beta_float_uptr) { - output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h]; - } else { - output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h] + beta_float_ptr[h]; - } - } - - MlasConvertFloatToHalfBuffer(output_float_ptr, p_output, num_elems); -} - void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, IAllocatorUniquePtr& dest, bool& is_packed) { if (tensor.GetElementType() == utils::ToTensorProtoElementType()) { auto tensor_data_ptr = tensor.Data(); @@ -218,7 +112,12 @@ void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, I template SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) - : OpKernel(op_kernel_info), skip_fp32_(nullptr), gamma_fp32_(nullptr), beta_fp32_(nullptr), bias_fp32_(nullptr) { + : OpKernel(op_kernel_info), + prepacked_skip_fp32_size_(0), + prepacked_skip_fp32_data_(nullptr), + prepacked_gamma_fp32_data_(nullptr), + prepacked_beta_fp32_data_(nullptr), + prepacked_bias_fp32_data_(nullptr) { ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK()); ORT_ENFORCE(epsilon_ >= 0); } @@ -226,10 +125,10 @@ SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) template Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { const Tensor* input = p_ctx->Input(0); - const Tensor* skip = p_ctx->Input(1); - const Tensor* gamma = p_ctx->Input(2); - const Tensor* beta = p_ctx->Input(3); - const Tensor* bias = p_ctx->Input(4); + const Tensor* skip = prepacked_skip_fp32_data_ ? nullptr : p_ctx->Input(1); + const Tensor* gamma = prepacked_gamma_fp32_data_ ? nullptr : p_ctx->Input(2); + const Tensor* beta = simplified ? nullptr : (prepacked_beta_fp32_data_ ? nullptr : p_ctx->Input(3)); + const Tensor* bias = prepacked_bias_fp32_data_ ? nullptr : p_ctx->Input(simplified ? 3 : 4); Tensor* output = p_ctx->Output(0, input->Shape()); // For inferencing, we support one more optional output which is the sum of the input and skip tensors Tensor* skip_input_bias_add_output = p_ctx->Output(3, input->Shape()); @@ -238,19 +137,21 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { size_t input_dims_size = input_dims.size(); int hidden_size = static_cast(input_dims[input_dims_size - 1]); - ORT_RETURN_IF_ERROR(onnxruntime::contrib::skip_layer_norm_helper::CheckInputs(input, - skip, - gamma, - beta, - bias, - hidden_size, - input_dims_size)); + ORT_RETURN_IF_ERROR(skip_layer_norm_helper::CheckPotentiallyPrepackedInputs(input, + skip, + gamma, + beta, + bias, + hidden_size, + input_dims_size, + prepacked_skip_fp32_data_ != nullptr, + prepacked_gamma_fp32_data_ != nullptr)); int64_t task_count = input->Shape().SizeToDimension(input_dims_size - 1); const T* input_data = input->Data(); - const T* skip_data = skip->Data(); - const T* gamma_data = gamma->Data(); + const T* skip_data = skip == nullptr ? nullptr : skip->Data(); + const T* gamma_data = gamma == nullptr ? nullptr : gamma->Data(); const T* beta_data = beta == nullptr ? nullptr : beta->Data(); const T* bias_data = bias == nullptr ? nullptr : bias->Data(); @@ -258,39 +159,118 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { // For inferencing, we support one more optional output which is the sum of the input and skip tensors T* skip_input_bias_add_output_data = skip_input_bias_add_output == nullptr ? nullptr : skip_input_bias_add_output->MutableData(); + const int64_t skip_size = skip ? skip->Shape().Size() : prepacked_skip_fp32_size_; + + if constexpr (std::is_same_v) { + const size_t total_data_size = static_cast(input->Shape().Size()); + + AllocatorPtr alloc; + ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc)); + + IAllocatorUniquePtr input_fp32; + IAllocatorUniquePtr output_fp32; + IAllocatorUniquePtr skip_input_bias_add_output_fp32; + IAllocatorUniquePtr skip_fp32; + IAllocatorUniquePtr gamma_fp32; + IAllocatorUniquePtr beta_fp32; + IAllocatorUniquePtr bias_fp32; + + const float* input_data_f = nullptr; + const float* skip_data_f = nullptr; + const float* gamma_data_f = nullptr; + const float* beta_data_f = nullptr; + const float* bias_data_f = nullptr; + float* output_data_f = nullptr; + float* skip_input_bias_add_output_data_f = nullptr; + + const size_t num_elems = static_cast(hidden_size); + + input_fp32 = IAllocator::MakeUniquePtr(alloc, total_data_size); + MlasConvertHalfToFloatBuffer(input_data, input_fp32.get(), total_data_size); + input_data_f = input_fp32.get(); + + output_fp32 = IAllocator::MakeUniquePtr(alloc, total_data_size); + output_data_f = output_fp32.get(); + + skip_input_bias_add_output_fp32 = IAllocator::MakeUniquePtr(alloc, total_data_size); + skip_input_bias_add_output_data_f = skip_input_bias_add_output_fp32.get(); + + if (skip_data) { + skip_fp32 = IAllocator::MakeUniquePtr(alloc, static_cast(skip_size)); + MlasConvertHalfToFloatBuffer(skip_data, skip_fp32.get(), static_cast(skip_size)); + skip_data_f = skip_fp32.get(); + } else if (prepacked_skip_fp32_data_) { + skip_data_f = prepacked_skip_fp32_data_.get(); + } - const int64_t& skip_size = skip->Shape().Size(); + if (gamma_data) { + gamma_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); + MlasConvertHalfToFloatBuffer(gamma_data, gamma_fp32.get(), num_elems); + gamma_data_f = gamma_fp32.get(); + } else if (prepacked_gamma_fp32_data_) { + gamma_data_f = prepacked_gamma_fp32_data_.get(); + } - AllocatorPtr alloc; - ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc)); + if (beta_data) { + beta_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); + MlasConvertHalfToFloatBuffer(beta_data, beta_fp32.get(), num_elems); + beta_data_f = beta_fp32.get(); + } else if (prepacked_beta_fp32_data_) { + beta_data_f = prepacked_beta_fp32_data_.get(); + } - concurrency::ThreadPool::TryBatchParallelFor( - p_ctx->GetOperatorThreadPool(), static_cast(task_count), - [&](ptrdiff_t task_idx) { - ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, skip_fp32_, gamma_fp32_, beta_fp32_, - bias_fp32_, task_idx, hidden_size, skip_size, epsilon_, simplified, output_data, - skip_input_bias_add_output_data, alloc); - }, - 0); + if (bias_data) { + bias_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); + MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems); + bias_data_f = bias_fp32.get(); + } else if (prepacked_bias_fp32_data_) { + bias_data_f = prepacked_bias_fp32_data_.get(); + } + + concurrency::ThreadPool::TryBatchParallelFor( + p_ctx->GetOperatorThreadPool(), static_cast(task_count), + [&](ptrdiff_t task_idx) { + ComputeJob(input_data_f, skip_data_f, gamma_data_f, beta_data_f, bias_data_f, task_idx, hidden_size, skip_size, + epsilon_, simplified, output_data_f, skip_input_bias_add_output_data_f); + }, + 0); + MlasConvertFloatToHalfBuffer(output_data_f, output_data, total_data_size); + if (skip_input_bias_add_output_data != nullptr) + MlasConvertFloatToHalfBuffer(skip_input_bias_add_output_data_f, skip_input_bias_add_output_data, total_data_size); + } else { + concurrency::ThreadPool::TryBatchParallelFor( + p_ctx->GetOperatorThreadPool(), static_cast(task_count), + [&](ptrdiff_t task_idx) { + ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, task_idx, hidden_size, skip_size, + epsilon_, simplified, output_data, skip_input_bias_add_output_data); + }, + 0); + } return Status::OK(); } template Status SkipLayerNorm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, bool& is_packed, PrePackedWeights* prepacked_weights) { ORT_UNUSED_PARAMETER(prepacked_weights); - is_packed = false; if (input_idx == 1) { // skip - ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, skip_fp32_, is_packed); + prepacked_skip_fp32_size_ = tensor.Shape().Size(); + ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_skip_fp32_data_, is_packed); } else if (input_idx == 2) { // gamma - ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, gamma_fp32_, is_packed); - } else if (input_idx == 3) { // beta - ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, beta_fp32_, is_packed); + ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_gamma_fp32_data_, is_packed); + } else if (input_idx == 3) { + if constexpr (simplified) { + // bias + ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed); + } else { + // beta + ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_beta_fp32_data_, is_packed); + } } else if (input_idx == 4) { // bias - ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, bias_fp32_, is_packed); + ORT_ENFORCE(!simplified, "SkipSimplifiedLayerNormalization should only has 4 inputs (input, skip, gamma, and beta). Got 5."); + ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed); } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h index d904c14857437..4a350fdcc2220 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h @@ -16,15 +16,16 @@ class SkipLayerNorm final : public OpKernel { SkipLayerNorm(const OpKernelInfo& op_kernel_info); Status Compute(OpKernelContext* p_op_kernel_context) const override; - Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool save_prepacked_initializers, + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, PrePackedWeights* prepacked_weights) override; private: float epsilon_; - mutable IAllocatorUniquePtr skip_fp32_; - mutable IAllocatorUniquePtr gamma_fp32_; - mutable IAllocatorUniquePtr beta_fp32_; - mutable IAllocatorUniquePtr bias_fp32_; + int64_t prepacked_skip_fp32_size_; + IAllocatorUniquePtr prepacked_skip_fp32_data_; + IAllocatorUniquePtr prepacked_gamma_fp32_data_; + IAllocatorUniquePtr prepacked_beta_fp32_data_; + IAllocatorUniquePtr prepacked_bias_fp32_data_; }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm_helper.h b/onnxruntime/contrib_ops/cpu/skip_layer_norm_helper.h index 6271f822287e6..4c901f5650dbd 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm_helper.h +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm_helper.h @@ -11,14 +11,10 @@ namespace onnxruntime { namespace contrib { namespace skip_layer_norm_helper { +namespace { + template -Status CheckInputs(const T* input, - const T* skip, - const T* gamma, - const T* beta, - const T* bias, - int hidden_size_check, - size_t input_dims_size_check) { +Status CheckSkip(const T* input, const T* skip, size_t input_dims_size_check) { const auto& input_dims_check = input->Shape().GetDims(); const auto& skip_dims_check = skip->Shape().GetDims(); size_t skip_dims_size_check = skip_dims_check.size(); @@ -33,49 +29,150 @@ Status CheckInputs(const T* input, "skip is expected to have same shape as input or, a batch size of 1 or no batch size when input has 3 dimensions"); } - if (input_dims_size_check != 3 && input_dims_size_check != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "input is expected to have 3 or 2 dimensions, got ", input_dims_size_check); - } - if (skip_dims_check[skip_dims_size_check - 1] != input_dims_check[input_dims_size_check - 1] || skip_dims_check[skip_dims_size_check - 2] != input_dims_check[input_dims_size_check - 2]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "last two dimensions of skip needs to be same as input"); } + return Status::OK(); +} + +template +Status CheckGamma(const T* gamma, int hidden_size_check) { const auto& gamma_dims = gamma->Shape().GetDims(); + if (gamma_dims.size() != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "gamma is expected to have 1 dimension, got ", gamma_dims.size()); } + if (gamma_dims[0] != hidden_size_check) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Last dimension of gamma and input does not match"); } + return Status::OK(); +} + +template +Status CheckBeta(const T* beta, int hidden_size_check) { if (nullptr != beta) { const auto& beta_dims = beta->Shape().GetDims(); + if (beta_dims.size() != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "beta is expected to have 1 dimension, got ", beta_dims.size()); } + if (beta_dims[0] != hidden_size_check) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Last dimension of beta and input does not match"); } } + return Status::OK(); +} + +template +Status CheckBias(const T* bias, int hidden_size_check) { if (nullptr != bias) { const auto& bias_dims = bias->Shape().GetDims(); + if (bias_dims.size() != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "bias is expected to have 1 dimension, got ", bias_dims.size()); } + if (bias_dims[0] != hidden_size_check) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Last dimension of bias and input does not match"); } } + + return Status::OK(); +} + +} // anonymous namespace + +template +Status CheckInputs(const T* input, + const T* skip, + const T* gamma, + const T* beta, + const T* bias, + int hidden_size_check, + size_t input_dims_size_check) { + if (input_dims_size_check != 3 && input_dims_size_check != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "input is expected to have 3 or 2 dimensions, got ", input_dims_size_check); + } + + auto status = CheckSkip(input, skip, input_dims_size_check); + if (status != Status::OK()) { + return status; + } + + status = CheckGamma(gamma, hidden_size_check); + if (status != Status::OK()) { + return status; + } + + status = CheckBeta(beta, hidden_size_check); + if (status != Status::OK()) { + return status; + } + + status = CheckBias(bias, hidden_size_check); + if (status != Status::OK()) { + return status; + } + + return Status::OK(); +} + +template +Status CheckPotentiallyPrepackedInputs(const T* input, + const T* skip, + const T* gamma, + const T* beta, + const T* bias, + int hidden_size_check, + size_t input_dims_size_check, + bool prepacked_skip, + bool prepacked_gamma) { + if (input_dims_size_check != 3 && input_dims_size_check != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "input is expected to have 3 or 2 dimensions, got ", input_dims_size_check); + } + + if (nullptr != skip) { + auto status = CheckSkip(input, skip, input_dims_size_check); + if (status != Status::OK()) { + return status; + } + } else if (!prepacked_skip) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "skip is expected but not provided"); + } + + if (nullptr != gamma) { + auto status = CheckGamma(gamma, hidden_size_check); + if (status != Status::OK()) { + return status; + } + } else if (!prepacked_gamma) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "gamma is expected but not provided"); + } + + auto status = CheckBeta(beta, hidden_size_check); + if (status != Status::OK()) { + return status; + } + + status = CheckBias(bias, hidden_size_check); + if (status != Status::OK()) { + return status; + } + return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h index 8f5cdc97f27e5..b67d003eaceeb 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h @@ -258,7 +258,8 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches current_length, cpu_state.sequences, parameters->max_length, - decoder_subgraph_.has_decoder_masked_attention_)); + decoder_subgraph_.has_decoder_masked_attention_, + this->cuda_device_prop_ != nullptr)); if (decoder_subgraph_.past_present_share_buffer_) { decoder_fetches.reserve(static_cast(decoder_subgraph_.GetFirstPresentOutputIndex()) + @@ -302,17 +303,24 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches auto cur_len = std::to_string(current_length); dumper->Print("***CurrentLength", cur_len, true); - for (int i = 0; i <= decoder_subgraph_.GetFirstPastInputIndex(); i++) { + for (int i = 0; i < decoder_subgraph_.GetFirstPastInputIndex(); i++) { dumper->Print("decoder_feeds", i, true); dumper->Print("", decoder_feeds[i]); } - auto offset = decoder_subgraph_.GetFirstPastInputIndex() + 4 * decoder_subgraph_.num_layers; - dumper->Print("past_sequence_length", offset, true); - dumper->Print("", decoder_feeds[offset]); - dumper->Print("beam_width", offset + 1, true); - dumper->Print("", decoder_feeds[offset + 1]); - dumper->Print("cache_redir", offset + 2, true); - dumper->Print("", decoder_feeds[offset + 2]); + for (int i = 0; i < decoder_subgraph_.num_layers; i++) { + int self_key_idx = decoder_subgraph_.GetFirstPastInputIndex() + 2 * i; + int self_value_idx = self_key_idx + 1; + dumper->Print("past_key_self", i, true); + dumper->Print("", decoder_feeds[self_key_idx]); + dumper->Print("past_value_self", i + 1, true); + dumper->Print("", decoder_feeds[self_value_idx]); + int cross_key_idx = decoder_subgraph_.GetFirstPastInputIndex() + 2 * decoder_subgraph_.num_layers + 2 * i; + int cross_value_idx = cross_key_idx + 1; + dumper->Print("past_key_cross", i, true); + dumper->Print("", decoder_feeds[cross_key_idx]); + dumper->Print("past_value_cross", i, true); + dumper->Print("", decoder_feeds[cross_value_idx]); + } #endif #ifdef DEBUG_NODE_INPUTS_OUTPUTS diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 30bf3aa0a1212..8145fbd4a4123 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -100,6 +100,7 @@ struct ISequences { virtual gsl::span GetCurrentDeviceSequences() const = 0; // Get all current beam_index sequences in one continuous block (to pass to CUDA) virtual gsl::span GetNextDeviceSequences() = 0; // Get all next beam_index sequences in one continuous block (to pass to CUDA) virtual int GetSequenceLength() const = 0; + virtual int GetMaxLength() const = 0; }; struct ILogitsProcessorList { diff --git a/onnxruntime/contrib_ops/cpu/transformers/sequences.cc b/onnxruntime/contrib_ops/cpu/transformers/sequences.cc index 723c271897a78..ecad146da6777 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sequences.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/sequences.cc @@ -36,6 +36,10 @@ int Sequences::GetSequenceLength() const { return current_length_; } +int Sequences::GetMaxLength() const { + return max_length_; +} + #ifdef DEBUG_GENERATION void Sequences::PrintSequences(const IConsoleDumper* dumper) const { for (int i = 0; i < batch_beam_size_; i++) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/sequences.h b/onnxruntime/contrib_ops/cpu/transformers/sequences.h index 440a07e14a6cc..7dd1f28d270c7 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sequences.h +++ b/onnxruntime/contrib_ops/cpu/transformers/sequences.h @@ -25,6 +25,9 @@ class Sequences : public ISequences { // Returns current sequence length. int GetSequenceLength() const override; + // Returns max sequence length. + int GetMaxLength() const override; + #ifdef DEBUG_GENERATION // Print the sequences to StdOut in debug mode void PrintSequences(const IConsoleDumper* dumper) const; diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc index d675ba742e03b..7757435990a65 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc @@ -31,6 +31,7 @@ Subgraph::Subgraph( allocator_(nullptr), is_output_float16_(false) { num_implicit_inputs = static_cast(node.ImplicitInputDefs().size()); + used_implicit_inputs = std::vector(num_implicit_inputs, true); auto& subgraph_inputs = subgraph.GetInputs(); auto& subgraph_outputs = subgraph.GetOutputs(); @@ -73,8 +74,18 @@ Status Subgraph::Setup(const SessionState& session_state, // The position_ids, attention_mask, past_0, ... are created by this operator so the name doesn't matter. feed_names.insert(feed_names.end(), subgraph_input_names.begin(), subgraph_input_names.end()); - for (auto& entry : node.ImplicitInputDefs()) { - feed_names.push_back(entry->Name()); + const auto& subgraph_map = subgraph_session_state.GetOrtValueNameIdxMap(); + + const auto& implicit_input_defs = node.ImplicitInputDefs(); + for (size_t i = 0, end = num_implicit_inputs; i < end; ++i) { + const auto* entry = implicit_input_defs[i]; + int idx; + if (subgraph_map.GetIdx(entry->Name(), idx).IsOK()) { + feed_names.push_back(entry->Name()); + } else { + --num_implicit_inputs; + used_implicit_inputs[i] = false; + } } InlinedVector feed_locations; diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h index bde591626bb83..8ec9c9cbdc20f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h @@ -31,6 +31,7 @@ class Subgraph { const GraphViewer& subgraph; // The subgraph int num_implicit_inputs; + std::vector used_implicit_inputs; int num_subgraph_inputs; // Same as subgraph_input_names.size(), keep it for convenience. int num_subgraph_outputs; // Same as subgraph_output_names.size() diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc index 9037e58aaf31f..f4e7173c917c1 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc @@ -156,7 +156,8 @@ Status T5DecoderSubgraph::CreateInitialFeeds( int cur_len, transformers::Sequences& sequences, int past_present_share_buffer_max_seq_len, - bool need_cache_indir) { + bool need_cache_indir, + bool use_cuda) { ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds"); // Allocate subgraph inputs from same device as inputs of encoder subgraph. @@ -171,8 +172,9 @@ Status T5DecoderSubgraph::CreateInitialFeeds( Tensor::InitOrtValue(DataTypeImpl::GetType(), input_ids_shape, allocator, input_ids); int32_t* input_ids_data = input_ids.GetMutable()->MutableData(); AllocatorPtr buffer_allocator = std::make_shared(); - size_t total_size = static_cast(static_cast(cur_len) * batch_beam_size * sizeof(int)); - auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size, false, stream); + size_t total_size = static_cast(cur_len) * static_cast(batch_beam_size); + size_t total_size_bytes = total_size * sizeof(int); + auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size_bytes, false, stream); int* seq_copy_ptr = seq_copy.get(); if (!use_sequence_as_input_ids_) { @@ -182,19 +184,35 @@ Status T5DecoderSubgraph::CreateInitialFeeds( stream, DeviceCopyDirection::hostToDevice)); } else { - for (int i = 0; i < batch_beam_size; i++) { - gsl::span sequence = sequences.GetSequence(i); - const int32_t* sequence_data = sequence.data(); - long long seq_index = (long long)i * cur_len; - memcpy(seq_copy_ptr + seq_index, sequence_data, total_size); + if (use_cuda) { + auto sequences_buffer = sequences.GetCurrentDeviceSequences(); + for (int i = 0; i < batch_beam_size; i++) { + size_t batch_beam_stride = static_cast(i) * static_cast(sequences.GetMaxLength()); + int seq_size = sequences.GetSequenceLength(); + gsl::span sequence = sequences_buffer.subspan(batch_beam_stride, seq_size); + gsl::span temp_input(input_ids_data + static_cast(i) * seq_size, seq_size); + ORT_RETURN_IF_ERROR(device_copy_int32_func( + temp_input, + sequence, + stream, + DeviceCopyDirection::deviceToDevice)); + } + } else { + const size_t cur_len_bytes = cur_len * sizeof(int); + for (int i = 0; i < batch_beam_size; i++) { + gsl::span sequence = sequences.GetSequence(i); + const int32_t* sequence_data = sequence.data(); + ptrdiff_t seq_index = static_cast(i) * cur_len; + memcpy(seq_copy_ptr + seq_index, sequence_data, cur_len_bytes); + } + gsl::span temp_input(input_ids_data, total_size); + gsl::span temp_sequence(seq_copy_ptr, total_size); + ORT_RETURN_IF_ERROR(device_copy_int32_func( + temp_input, + temp_sequence, + stream, + DeviceCopyDirection::hostToDevice)); } - gsl::span temp_input(input_ids_data, total_size); - gsl::span temp_sequence(seq_copy_ptr, total_size); - ORT_RETURN_IF_ERROR(device_copy_int32_func( - temp_input, - temp_sequence, - stream, - DeviceCopyDirection::hostToDevice)); } // The ordering is the same as used in Setup. @@ -230,7 +248,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds( num_beam, allocator, expanded_hidden_states, - true, + false, 0 /*max_sequence_length*/)); } else { ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream, @@ -238,7 +256,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds( num_beam, allocator, expanded_hidden_states, - true, + false, 0 /*max_sequence_length*/)); } decoder_feeds.push_back(expanded_hidden_states); @@ -281,8 +299,11 @@ Status T5DecoderSubgraph::CreateInitialFeeds( } // Pass through implicit inputs. - for (const auto* entry : implicit_inputs) { - decoder_feeds.push_back(*entry); + for (size_t i = 0; i < implicit_inputs.size(); ++i) { + const auto* entry = implicit_inputs[i]; + if (used_implicit_inputs[i]) { + decoder_feeds.push_back(*entry); + } } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h index 83dae49c7dcbd..a72ce37a93aba 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h @@ -48,7 +48,8 @@ class T5DecoderSubgraph : public Subgraph { int cur_len, transformers::Sequences& sequences, int past_present_share_buffer_max_seq_len = -1, - bool need_cache_indir = false); + bool need_cache_indir = false, + bool use_cuda = false); Status Validate(const std::vector& subgraph_inputs, const std::vector& subgraph_outputs) override; diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc index 51473c0c931b9..d59db4afac2c2 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc @@ -145,8 +145,11 @@ Status T5EncoderSubgraph::CreateInitialFeeds( pinned_allocator, location)); - for (const auto* entry : implicit_inputs) { - feeds.push_back(*entry); + for (size_t i = 0; i < implicit_inputs.size(); ++i) { + const auto* entry = implicit_inputs[i]; + if (used_implicit_inputs[i]) { + feeds.push_back(*entry); + } } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index 8edae863ff44e..e4c1659c0fb2c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -298,6 +298,9 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio if (params.attention_bias != nullptr) { qk = add_vec(qk, reinterpret_cast(params.attention_bias)[attn_bias_offset + tlength]); } + if (params.mask != nullptr && params.mask[bi_total_seq_length + params.past_sequence_length] == 0) { + qk += params.mask_filter_value; + } qk_max = qk; qk_smem[tlength] = qk; } @@ -534,7 +537,7 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio if (params.out_qk != nullptr) { // store cross qk before softmax, out_qk has shape [B(batchxbeam), #Head, 1, total_sequence_length] - float* target = ((float*)params.out_qk) + ((int64_t)bhi * tlength); + float* target = (reinterpret_cast(params.out_qk)) + (static_cast(bhi) * (sum_tlength + 1)); for (int ti = tidx; ti <= sum_tlength; ti += THREADS_PER_BLOCK) { target[ti] = (float)(qk_smem[ti]); } diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h index e961bab399326..d46d9597a758f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h @@ -98,7 +98,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi for (int m = 0; m < size<1>(tOgO); ++m) { const int row = get<0>(tOcO(0, m, 0)); if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { - gLSE(row) = INFINITY; + gLSE(row) = std::numeric_limits::infinity(); } } return; @@ -499,7 +499,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons for (int m = 0; m < size<1>(tOgOaccum); ++m) { const int row = get<0>(tOcO(0, m, 0)); if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { - gLSEaccum(row) = Split ? -INFINITY : INFINITY; + gLSEaccum(row) = Split ? -std::numeric_limits::infinity() : std::numeric_limits::infinity(); } } return; @@ -1061,7 +1061,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { for (int l = 0; l < kNLsePerThread; ++l) { const int row = l * kRowsPerLoadLSE + tidx / kBlockM; const int col = tidx % kBlockM; - ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY; + ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -std::numeric_limits::infinity(); if (row < kMaxSplits) { sLSE[row][col] = lse; } @@ -1082,7 +1082,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { for (int l = 0; l < kNLsePerThread; ++l) { const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; const int col = tidx / kRowsPerLoadTranspose; - lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY; + lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -std::numeric_limits::infinity(); // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } } @@ -1094,7 +1094,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { } MaxOp max_op; lse_max = Allreduce::run(lse_max, max_op); - lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf + lse_max = lse_max == -std::numeric_limits::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf float lse_sum = expf(lse_accum(0) - lse_max); #pragma unroll for (int l = 1; l < kNLsePerThread; ++l) { @@ -1104,7 +1104,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { lse_sum = Allreduce::run(lse_sum, sum_op); // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. - ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max; + ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? std::numeric_limits::infinity() : logf(lse_sum) + lse_max; // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/mask.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/mask.h index 0998155eba635..71434002f8df1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/mask.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/mask.h @@ -4,6 +4,7 @@ #pragma once +#include #include namespace onnxruntime { @@ -28,7 +29,7 @@ __forceinline__ __device__ void apply_mask(Tensor& tensor, const // Without the "make_coord" we get wrong results #pragma unroll for (int mi = 0; mi < size<0>(tensor); ++mi) { - tensor(mi, make_coord(j, nj)) = -INFINITY; + tensor(mi, make_coord(j, nj)) = -std::numeric_limits::infinity(); } } } @@ -59,7 +60,7 @@ __forceinline__ __device__ void apply_mask_local(Tensor& tensor, for (int j = 0; j < size<1, 0>(tensor); ++j) { const int col_idx = col_idx_base + j; if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + tensor(make_coord(i, mi), make_coord(j, nj)) = -std::numeric_limits::infinity(); } } } @@ -96,7 +97,7 @@ __forceinline__ __device__ void apply_mask_causal_w_idx( #pragma unroll for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { - tensor(mi, ni) = -INFINITY; + tensor(mi, ni) = -std::numeric_limits::infinity(); } } // if (cute::thread0()) { @@ -151,7 +152,7 @@ struct Mask { } if constexpr (!Is_even_MN) { if (col_idx >= max_seqlen_k) { - tensor(mi, make_coord(j, nj)) = -INFINITY; + tensor(mi, make_coord(j, nj)) = -std::numeric_limits::infinity(); } } } @@ -181,18 +182,18 @@ struct Mask { } if constexpr (Causal_mask) { if (col_idx >= col_idx_limit_right) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + tensor(make_coord(i, mi), make_coord(j, nj)) = -std::numeric_limits::infinity(); } } if constexpr (Is_local) { if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + tensor(make_coord(i, mi), make_coord(j, nj)) = -std::numeric_limits::infinity(); } } if constexpr (!Causal_mask && !Is_local && !Is_even_MN) { // Causal and Local already handles MN masking if (col_idx >= max_seqlen_k) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + tensor(make_coord(i, mi), make_coord(j, nj)) = -std::numeric_limits::infinity(); } } } diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h index 7e0095cb39bd9..7fe506e01a9b9 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include @@ -71,7 +72,9 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor& tenso // If max is -inf, then all elements must have been -inf (possibly due to masking). // We don't want (-inf - (-inf)) since that would give NaN. // If we don't have float around M_LOG2E the multiplication is done in fp64. - const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); + const float max_scaled = max(mi) == -std::numeric_limits::infinity() + ? 0.f + : max(mi) * (Scale_max ? scale : float(M_LOG2E)); #pragma unroll for (int ni = 0; ni < size<1>(tensor); ++ni) { // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - @@ -99,7 +102,7 @@ __forceinline__ __device__ void max_scale_exp2_sum(Tensor& ten max(mi) = Allreduce<4>::run(max(mi), max_op); // If max is -inf, then all elements must have been -inf (possibly due to masking). // We don't want (-inf - (-inf)) since that would give NaN. - const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + const float max_scaled = max(mi) == -std::numeric_limits::infinity() ? 0.f : max(mi) * scale; sum(mi) = 0; #pragma unroll for (int ni = 0; ni < size<1>(tensor); ++ni) { @@ -143,7 +146,7 @@ struct Softmax { for (int mi = 0; mi < size(row_max); ++mi) { float scores_max_cur = !Check_inf ? row_max(mi) - : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); + : (row_max(mi) == -std::numeric_limits::infinity() ? 0.0f : row_max(mi)); float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); row_sum(mi) *= scores_scale; #pragma unroll @@ -169,7 +172,9 @@ struct Softmax { for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { float sum = smooth_softmax ? row_sum(mi) + expf(-row_max(mi) * softmax_scale) : row_sum(mi); float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; - lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); + lse(mi) = (sum == 0.f || sum != sum) + ? (Split ? -std::numeric_limits::infinity() : std::numeric_limits::infinity()) + : row_max(mi) * softmax_scale + __logf(sum); float scale = inv_sum; #pragma unroll for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_kernel.h index 5be69ea0af55c..bd54b404420e5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_kernel.h +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_kernel.h @@ -825,7 +825,7 @@ inline __device__ void lean_compute_attn_impl_ver3(const Params& params, const i const int row = l * kRowsPerLoadLSE + tidx / kBlockM; const int col = tidx % kBlockM; // We skip the first row = 0, as we already populated it in shared memory. - ElementAccum lse = (row > 0 && row < total_splits && col < params.b * params.h * (index_t)params.seqlen_q - row_offset_lseaccum) ? gLSEaccumRead(row, col) : -INFINITY; + ElementAccum lse = (row > 0 && row < total_splits && col < params.b * params.h * (index_t)params.seqlen_q - row_offset_lseaccum) ? gLSEaccumRead(row, col) : -std::numeric_limits::infinity(); if (row > 0 && row < kMaxSplits) { sLSE(row, col) = lse; @@ -857,7 +857,7 @@ inline __device__ void lean_compute_attn_impl_ver3(const Params& params, const i for (int l = 0; l < kNLsePerThread; ++l) { const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; const int col = tidx / kRowsPerLoadTranspose; - lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE(row, col) : -INFINITY; + lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE(row, col) : -std::numeric_limits::infinity(); #if defined(DEBUG_LEAN_ATTENTION) if (threadIdx.x == 0 && blockIdx.z == tracing_block) { @@ -874,7 +874,7 @@ inline __device__ void lean_compute_attn_impl_ver3(const Params& params, const i } MaxOp max_op; lse_max = Allreduce::run(lse_max, max_op); - lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf + lse_max = lse_max == -std::numeric_limits::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf float lse_sum = expf(lse_accum(0) - lse_max); #pragma unroll for (int l = 1; l < kNLsePerThread; ++l) { @@ -884,7 +884,9 @@ inline __device__ void lean_compute_attn_impl_ver3(const Params& params, const i lse_sum = Allreduce::run(lse_sum, sum_op); // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. - ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max; + ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) + ? std::numeric_limits::infinity() + : logf(lse_sum) + lse_max; // if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; } // Store the scales exp(lse - lse_logsum) in shared memory. #pragma unroll diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/mask.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/mask.h index d63c80b012de6..2d33418d69667 100644 --- a/onnxruntime/contrib_ops/cuda/bert/lean_attention/mask.h +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/mask.h @@ -3,7 +3,7 @@ ******************************************************************************/ #pragma once - +#include #include namespace onnxruntime { @@ -28,7 +28,7 @@ __forceinline__ __device__ void apply_mask(Tensor& tensor, const // Without the "make_coord" we get wrong results #pragma unroll for (int mi = 0; mi < size<0>(tensor); ++mi) { - tensor(mi, make_coord(j, nj)) = -INFINITY; + tensor(mi, make_coord(j, nj)) = -std::numeric_limits::infinity(); } } } @@ -59,7 +59,7 @@ __forceinline__ __device__ void apply_mask_local(Tensor& tensor, for (int j = 0; j < size<1, 0>(tensor); ++j) { const int col_idx = col_idx_base + j; if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + tensor(make_coord(i, mi), make_coord(j, nj)) = -std::numeric_limits::infinity(); } } } @@ -96,7 +96,7 @@ __forceinline__ __device__ void apply_mask_causal_w_idx( #pragma unroll for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { - tensor(mi, ni) = -INFINITY; + tensor(mi, ni) = -std::numeric_limits::infinity(); } } // if (cute::thread0()) { @@ -152,7 +152,7 @@ struct Mask { } if constexpr (!Is_even_MN) { if (col_idx >= max_seqlen_k) { - tensor(mi, make_coord(j, nj)) = -INFINITY; + tensor(mi, make_coord(j, nj)) = -std::numeric_limits::infinity(); } } } @@ -182,18 +182,18 @@ struct Mask { } if constexpr (Causal_mask) { if (col_idx >= col_idx_limit_right) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + tensor(make_coord(i, mi), make_coord(j, nj)) = -std::numeric_limits::infinity(); } } if constexpr (Is_local) { if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + tensor(make_coord(i, mi), make_coord(j, nj)) = -std::numeric_limits::infinity(); } } if constexpr (!Causal_mask && !Is_local && !Is_even_MN) { // Causal and Local already handles MN masking if (col_idx >= max_seqlen_k) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + tensor(make_coord(i, mi), make_coord(j, nj)) = -std::numeric_limits::infinity(); } } } diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/softmax.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/softmax.h index ad66389848e6e..0b6ffb3f1985a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/lean_attention/softmax.h +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/softmax.h @@ -3,7 +3,7 @@ ******************************************************************************/ #pragma once - +#include #include #include @@ -72,7 +72,9 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor& tenso // If max is -inf, then all elements must have been -inf (possibly due to masking). // We don't want (-inf - (-inf)) since that would give NaN. // If we don't have float around M_LOG2E the multiplication is done in fp64. - const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); + const float max_scaled = max(mi) == -std::numeric_limits::infinity() + ? 0.f + : max(mi) * (Scale_max ? scale : float(M_LOG2E)); #pragma unroll for (int ni = 0; ni < size<1>(tensor); ++ni) { // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - @@ -107,7 +109,7 @@ __forceinline__ __device__ void max_scale_exp2_sum(Tensor& ten max(mi) = Allreduce<4>::run(max(mi), max_op); // If max is -inf, then all elements must have been -inf (possibly due to masking). // We don't want (-inf - (-inf)) since that would give NaN. - const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + const float max_scaled = max(mi) == -std::numeric_limits::infinity() ? 0.f : max(mi) * scale; sum(mi) = 0; #pragma unroll for (int ni = 0; ni < size<1>(tensor); ++ni) { @@ -151,7 +153,7 @@ struct Softmax { for (int mi = 0; mi < size(row_max); ++mi) { float scores_max_cur = !Check_inf ? row_max(mi) - : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); + : (row_max(mi) == -std::numeric_limits::infinity() ? 0.0f : row_max(mi)); float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); row_sum(mi) *= scores_scale; #pragma unroll @@ -181,7 +183,9 @@ struct Softmax { // printf("sum: %f, inv_sum: %f\n", sum, inv_sum); // printf("mi %d row_max %f softmax_scale %f\n", mi, row_max(mi), softmax_scale); // } - lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); + lse(mi) = (sum == 0.f || sum != sum) + ? (Split ? -std::numeric_limits::infinity() : std::numeric_limits::infinity()) + : row_max(mi) * softmax_scale + __logf(sum); float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout; #pragma unroll for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { diff --git a/onnxruntime/contrib_ops/cuda/bert/ngram_repeat_block_impl.cu b/onnxruntime/contrib_ops/cuda/bert/ngram_repeat_block_impl.cu index 8a04ede231a27..ab809d12a89ad 100644 --- a/onnxruntime/contrib_ops/cuda/bert/ngram_repeat_block_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/ngram_repeat_block_impl.cu @@ -6,7 +6,7 @@ Licensed under the MIT License. /* Kernel implementation for blocking repeated n-grams. */ - +#include #include "core/providers/cuda/cu_inc/common.cuh" #include "contrib_ops/cuda/bert/ngram_repeat_block_impl.h" @@ -48,7 +48,7 @@ __global__ void banRepeatedTokens(const int64_t* __restrict__ tokens, } if (is_banned == true) { auto token_to_be_banned = tokens_shm[col + no_repeat_ngram_size - 1]; - lprobs[lprob_start + token_to_be_banned] = -INFINITY; + lprobs[lprob_start + token_to_be_banned] = -std::numeric_limits::infinity(); } } diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc index d190ed389f3e9..dea5391c7629b 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc @@ -95,7 +95,6 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) { } Status GroupNorm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr /*alloc*/, - bool /*save_prepacked_initializers*/, bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { is_packed = false; diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h index 4505c066baedb..b408b3c1ee79b 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h @@ -17,7 +17,6 @@ class GroupNorm final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, bool& is_packed, PrePackedWeights* prepacked_weights) override; private: diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc index aa2c8755f6536..3e93a527877c5 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc @@ -99,7 +99,6 @@ Status QOrderedAttention::PutIntoMergedBias(const Tensor& tensor, AllocatorPtr a } Status QOrderedAttention::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/) { is_packed = false; diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.h b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.h index 529fd00307d66..9d4e563c1feab 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.h +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.h @@ -20,7 +20,6 @@ class QOrderedAttention final : public CudaKernel, public AttentionBase { public: Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc index 351e36b884540..a64f628f245e6 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc @@ -51,7 +51,6 @@ QOrderedMatMul::QOrderedMatMul(const OpKernelInfo& info) : CudaKernel(info) { } Status QOrderedMatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /* prepacked_weights */) { is_packed = false; diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.h b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.h index d1cef99779e09..dcb6cc6374be1 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.h +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.h @@ -18,7 +18,6 @@ class QOrderedMatMul final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu b/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu index 5ac10f6321e63..44be2ef2375ee 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu @@ -60,7 +60,7 @@ struct TopK { __device__ __forceinline__ void Init() { for (int i = 0; i < max_k; i++) { key[i] = -1; - value[i] = NumericLimits::Min(); + value[i] = NumericLimits::Lowest(); } } }; diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index e047bd948434d..4e65336665bf7 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -1264,16 +1264,14 @@ Status UpdateDecoderFeeds( CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_ids_data, beam_next_tokens.data(), beam_next_tokens.size_bytes(), cudaMemcpyHostToDevice, cuda_stream)); } else { - for (int i = 0; i < batch_beam_size; i++) { - gsl::span sequence = sequences.GetSequence(i); - const int32_t* sequence_data = sequence.data(); - CUDA_RETURN_IF_ERROR( - cudaMemcpyAsync(input_ids_data + static_cast(i) * current_length, - sequence_data, - current_length * sizeof(int32_t), - cudaMemcpyHostToDevice, - cuda_stream)); - } + // We expect sequences to point directly to device memory + int max_length = sequences.GetMaxLength(); + auto sequences_buffer = sequences.GetCurrentDeviceSequences(); + CUDA_RETURN_IF_ERROR( + cudaMemcpy2DAsync(input_ids_data, current_length * sizeof(int32_t), + sequences_buffer.data(), max_length * sizeof(int32_t), + current_length * sizeof(int32_t), batch_beam_size, + cudaMemcpyDeviceToDevice, cuda_stream)); } next_inputs[0] = input_ids; diff --git a/onnxruntime/contrib_ops/cuda/transformers/greedy_search_top_one.cu b/onnxruntime/contrib_ops/cuda/transformers/greedy_search_top_one.cu index 68a2e16482af9..b2969194ff400 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/greedy_search_top_one.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/greedy_search_top_one.cu @@ -5,6 +5,7 @@ #include + #include "core/providers/cuda/shared_inc/cuda_utils.h" #include "core/providers/cuda/cu_inc/common.cuh" @@ -19,7 +20,10 @@ struct TopOne { int32_t key; T value; - __device__ __host__ __forceinline__ TopOne(int32_t key = -1, T value = NumericLimits::Min()) : key(key), value(value) { + __device__ __host__ __forceinline__ TopOne() : key(-1), value(NumericLimits::Lowest()) { + } + + __device__ __host__ __forceinline__ TopOne(int32_t key, T value) : key(key), value(value) { } __device__ __forceinline__ void Reduce(int32_t k, T v) { diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc new file mode 100644 index 0000000000000..86dc959cf2e83 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -0,0 +1,459 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/webgpu/bert/attention.h" + +#include "contrib_ops/cpu/bert/multihead_attention_helper.h" +#include "contrib_ops/webgpu/bert/multihead_attention.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +using namespace onnxruntime::webgpu; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::contrib::multihead_attention_helper; + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +Status TransferBSDToBNSHProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("qkv_input", ShaderUsage::UseUniform); + const auto& qkv_output = shader.AddOutput("qkv_output", ShaderUsage::UseUniform | ShaderUsage::UseOffsetToIndices); + + if (has_bias_) { + shader.AddInput("bias", ShaderUsage::UseUniform); + } + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size") + << "let output_indices = " << qkv_output.OffsetToIndices("global_idx") << ";\n" + << "let input_offset_idx = output_indices[0] * uniforms.batch_offset + output_indices[1] *" + << " uniforms.head_offset + output_indices[2] * uniforms.sequence_offset + output_indices[3];\n"; + if (has_bias_) { + shader.MainFunctionBody() << "let bias_offset_idx = (input_offset_idx % uniforms.sequence_offset) + uniforms.bias_offset;\n"; + } + shader.MainFunctionBody() << "qkv_output[global_idx] = qkv_input[input_offset_idx]"; + if (has_bias_) { + shader.MainFunctionBody() << " + bias[bias_offset_idx];\n"; + } else { + shader.MainFunctionBody() << ";\n"; + } + + return Status::OK(); +} + +Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length, + int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor) { + ORT_ENFORCE(input_tensor->Shape().GetDims().size() == 3); + ORT_ENFORCE(output_tensor->Shape().GetDims().size() == 4); + + uint32_t data_size = SafeInt(output_tensor->Shape().Size()); + const int batch_offset = num_heads * sequence_length * head_size; + const int sequence_offset = num_heads * head_size; + const int head_offset = head_size; + bool has_bias = bias != nullptr; + + TransferBSDToBNSHProgram program{has_bias}; + program.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({{data_size}, + {static_cast(batch_offset)}, + {static_cast(sequence_offset)}, + {static_cast(head_offset)}, + {static_cast(bias_offset)}}); + + if (has_bias) { + program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank}); + } + + return context.RunProgram(program); +}; + +void InitVarStub(std::ostringstream& ss, const Tensor* seqlen_k, bool is_first_prompt) { + if (seqlen_k != nullptr) { + ss << "total_sequence_length = u32(seqlen_k[batch_idx]) + 1;\n"; + ss << "var past_sequence_length: u32 = " << (is_first_prompt ? "0" : "total_sequence_length - sequence_length") << ";\n"; + } else { + ss << "let past_sequence_length = uniforms.past_sequence_length;\n"; + } +} + +Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + if (feed_past_key_) { + shader.AddInput("past_key", ShaderUsage::UseUniform); + } + if (has_attention_bias_) { + shader.AddInput("attention_bias", ShaderUsage::UseUniform); + } + if (seqlen_k_ != nullptr) { + shader.AddInput("seqlen_k", ShaderUsage::UseUniform); + } + shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + if (has_present_key_) { + shader.AddOutput("present_key", ShaderUsage::UseUniform); + } + + shader.AdditionalImplementation() << "var tileQ: array;\n" + << "var tileK: array;\n" + << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32")) << ";\n"; + shader.MainFunctionBody() << "// x holds the N and y holds the M\n" + << "let m = workgroup_id.y * TILE_SIZE;\n" + << "let n = workgroup_id.x * TILE_SIZE;\n" + << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" + << "let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;\n" + << "let sequence_length = uniforms.M;\n" + << "var total_sequence_length = uniforms.N;\n"; + std::ostringstream oss; + InitVarStub(oss, seqlen_k_, is_first_prompt_); + shader.MainFunctionBody() << oss.str(); + shader.MainFunctionBody() << "let kOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.kv_sequence_length * uniforms.K;\n"; + if (has_present_key_) { + shader.MainFunctionBody() << "let presentKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.present_sequence_length * uniforms.K;\n"; + } + + shader.MainFunctionBody() << "var value = f32_val_t(0);\n" + "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" + " if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n" + " tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n" + " }\n" + " if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n" + " var idx = TILE_SIZE * local_id.y + local_id.x;\n"; + + if ((feed_past_key_ && has_present_key_) || past_present_share_buffer_) { + shader.MainFunctionBody() << " if (n + local_id.y < past_sequence_length) {\n" + << " let pastKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.past_sequence_length * uniforms.K;\n" + << " tileK[idx] = " << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" + << " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n" + << " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n" + << " }\n"; + } else { + shader.MainFunctionBody() << " if (n + local_id.y < uniforms.kv_sequence_length) {\n" + " tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" + " }\n"; + } + + if (has_present_key_) { + if (past_present_share_buffer_) { + shader.MainFunctionBody() << " if (n + local_id.y >= past_sequence_length && n + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n"; + } else { + shader.MainFunctionBody() << " if (n + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n"; + } + shader.MainFunctionBody() << " present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n" + << " }\n"; + } + + shader.MainFunctionBody() << " }\n" + << " workgroupBarrier();\n" + << " for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {\n" + << " value += f32_val_t(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n"; + + shader.MainFunctionBody() << "if (global_id.y < uniforms.M && global_id.x < total_sequence_length) {\n" + << " let headOffset = workgroup_id.z * uniforms.M * uniforms.N;\n" + << " let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n" + << " var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value")) << ";\n"; + + shader.MainFunctionBody() << " output[outputIdx] = output_value_t(sum * uniforms.alpha)"; + if (has_attention_bias_) { + shader.MainFunctionBody() << " + attention_bias[outputIdx]"; + } + shader.MainFunctionBody() << ";\n" + << "}\n"; + + return Status::OK(); +} + +Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int output_count, const Tensor* Q, + const Tensor* K, const Tensor* past_key, const Tensor* attention_bias, Tensor* probs, Tensor* present_key, + WebgpuAttentionParameters& parameters, int past_sequence_length, int total_sequence_length, + const Tensor* seqlen_k) { + const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size_)) + : parameters.scale_; + + const bool feed_past_key = present_key != nullptr && past_key != nullptr && past_key->SizeInBytes() > 0 && !parameters.past_present_share_buffer_; + const bool has_present_key = output_count > 1 && past_key; + const bool has_attention_bias = attention_bias != nullptr; + constexpr int tile_size = 12; + const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1); + + AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size, + components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; + program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, + {K, ProgramTensorMetadataDependency::TypeAndRank, components}}); + if (feed_past_key) { + program.AddInput({past_key, ProgramTensorMetadataDependency::TypeAndRank, components}); + } + if (has_attention_bias) { + program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank}); + } + if (seqlen_k != nullptr) { + program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}); + } + program.AddOutputs({{probs, ProgramTensorMetadataDependency::Rank}}); + if (has_present_key) { + program.AddOutput({present_key, ProgramTensorMetadataDependency::Rank, components}); + } + + const uint32_t vectorized_head_size = (parameters.head_size_ + components - 1) / components; + program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size, + (parameters.sequence_length_ + tile_size - 1) / tile_size, + parameters.batch_size_ * parameters.num_heads_) + .SetWorkgroupSize(tile_size, tile_size) + .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr, components, parameters.is_first_prompt_) + .AddUniformVariables({{static_cast(parameters.sequence_length_)}, + {static_cast(vectorized_head_size)}, + {static_cast(total_sequence_length)}, + {static_cast(parameters.num_heads_)}, + {static_cast(parameters.head_size_)}, + {static_cast(alpha)}, + {static_cast(past_sequence_length)}, + {static_cast(parameters.kv_sequence_length_)}, + {static_cast(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)}, + {static_cast(parameters.n_reps)}}) + .SetOverridableConstants({{static_cast(tile_size)}}); + + return context.RunProgram(program); +} + +Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { + if (seqlen_k_) { + shader.AddInput("seqlen_k", ShaderUsage::UseUniform); + } + shader.AddOutput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AdditionalImplementation() << "var thread_max: array;\n" + << "var thread_sum: array;\n" + << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32")) << ";\n"; + shader.MainFunctionBody() << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" + << "let sequence_length = uniforms.sequence_length;\n" + << "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n"; + std::ostringstream oss; + InitVarStub(oss, seqlen_k_, is_first_prompt_); + shader.MainFunctionBody() << oss.str() + << "let local_offset = local_idx * uniforms.elements_per_thread;\n" + << "let offset = (global_idx / " << work_group_size_ << ") * uniforms.total_sequence_length_comp + local_offset;\n" + << "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_id.y + 1" : "uniforms.total_sequence_length_comp") << ";\n" + << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n" + << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" + << " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n" + << "}\n" + << "thread_max[local_idx] = " << (components_ == 4 ? "max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" : (components_ == 2 ? "max(thread_max_vector.x, thread_max_vector.y)" : "thread_max_vector")) << ";\n" + << "workgroupBarrier();\n" + << "var max_value = f32(-3.402823e+38f);\n" + << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" + << " max_value = max(thread_max[i], max_value);\n" + << "}\n" + << "var sum_vector = f32_val_t(0);\n" + << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" + << " sum_vector += exp(f32_val_t(x[offset + i]) - max_value);\n" + << "}\n" + << "thread_sum[local_idx] = " << (components_ == 4 ? "sum_vector.x + sum_vector.y + sum_vector.z + sum_vector.w" : (components_ == 2 ? "sum_vector.x + sum_vector.y" : "sum_vector")) << ";\n" + << "workgroupBarrier();\n" + << "var sum: f32 = 0;\n" + << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" + << " sum += thread_sum[i]\n;" + << "}\n" + << "if (sum == 0) {\n" + << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" + << " x[offset + i] = x_value_t(x_element_t(1.0)/x_element_t(seq_causal_length));\n" + << " }\n" + << "} else {\n" + << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" + << " var f32input = f32_val_t(x[offset + i]);\n" + << " x[offset + i] = x_value_t(exp(f32input - max_value) / sum);\n" + << " }\n" + << "}\n"; + if (seqlen_k_) { + shader.MainFunctionBody() << "for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length_comp; total_seq_id++) {\n" + << " x[offset + total_seq_id] = x_value_t(x_element_t(0));\n" + << "}\n"; + } + + return Status::OK(); +} + +Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tensor* probs, int32_t batch_size, int32_t num_heads, int32_t past_sequence_length, int32_t sequence_length, int32_t total_sequence_length, + const Tensor* seqlen_k, bool is_first_prompt) { + const int components = seqlen_k != nullptr ? 1 : (total_sequence_length % 4 == 0 ? 4 : (total_sequence_length % 2 == 0 ? 2 : 1)); + int work_group_size = 64; + const int total_sequence_length_comp = (total_sequence_length + components - 1) / components; + if (total_sequence_length_comp < work_group_size) { + work_group_size = 32; + } + const int elementsPerThread = (total_sequence_length_comp + work_group_size - 1) / work_group_size; + + InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, is_first_prompt, seqlen_k}; + if (seqlen_k != nullptr) { + program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}); + } + program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}}) + .CacheHint(work_group_size, is_first_prompt) + .SetDispatchGroupSize(1, sequence_length, batch_size * num_heads) + .SetWorkgroupSize(work_group_size) + .AddUniformVariables({{static_cast(batch_size)}, + {static_cast(num_heads)}, + {static_cast(past_sequence_length)}, + {static_cast(sequence_length)}, + {static_cast(total_sequence_length_comp)}, + {static_cast(elementsPerThread)}}); + + return context.RunProgram(program); +} + +Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("probs", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddInput("v", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + if (feed_past_value_) { + shader.AddInput("past_value", ShaderUsage::UseUniform); + } + if (seqlen_k_) { + shader.AddInput("seqlen_k", ShaderUsage::UseUniform); + } + shader.AddOutput("output", ShaderUsage::UseUniform); + if (has_present_value_) { + shader.AddOutput("present_value", ShaderUsage::UseUniform); + } + + shader.AdditionalImplementation() << "var tileQ: array;\n" + << "var tileK: array;\n"; + shader.MainFunctionBody() << "let head_idx = workgroup_id.z % uniforms.num_heads;\n" + << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" + << "let m = global_id.y;\n" + << "let n = global_id.x;\n" + << "let offsetA = workgroup_id.z * (uniforms.M * uniforms.K) + m * uniforms.K;\n" + << "let sequence_length = uniforms.M;\n" + << "var total_sequence_length = uniforms.K;\n"; + std::ostringstream oss; + InitVarStub(oss, seqlen_k_, is_first_prompt_); + shader.MainFunctionBody() << oss.str(); + shader.MainFunctionBody() << "let vOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.kv_sequence_length + n;\n"; + if (has_present_value_) { + shader.MainFunctionBody() << "let presentValueOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.present_sequence_length + n;\n"; + } + + shader.MainFunctionBody() << "var value = probs_element_t(0);\n" + << "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" + << " if (m < uniforms.M && w + local_id.x < uniforms.K) {\n" + << " tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];\n" + << " }\n" + << " if (n < uniforms.N && w + local_id.y < uniforms.K) {\n" + << " var idx = TILE_SIZE * local_id.y + local_id.x;\n"; + + if ((feed_past_value_ && has_present_value_) || past_present_share_buffer_) { + shader.MainFunctionBody() << " if (w + local_id.y < past_sequence_length) {\n" + << " let pastValueOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.past_sequence_length + n;\n" + << " tileK[idx] = " << (past_present_share_buffer_ ? "present_value" : "past_value") << "[pastValueOffset + (w + local_id.y) * uniforms.N];\n" + << " } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n" + << " tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N];\n" + << " }\n"; + } else { + shader.MainFunctionBody() << " if (w + local_id.y < uniforms.kv_sequence_length) {\n" + << " tileK[idx] = v[vOffset + (w + local_id.y) * uniforms.N];\n" + << " }\n"; + } + + if (has_present_value_) { + if (past_present_share_buffer_) { + shader.MainFunctionBody() << " if (w + local_id.y >= past_sequence_length && w + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n"; + } else { + shader.MainFunctionBody() << " if (w + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n"; + } + shader.MainFunctionBody() << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n" + << " }\n"; + } + + shader.MainFunctionBody() << " }\n" + << " workgroupBarrier();\n" + << " for (var k: u32 = 0u; k < TILE_SIZE && w+k < total_sequence_length; k++) {\n" + << " value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n"; + + shader.MainFunctionBody() << "// we need to transpose output from BNSH_v to BSND_v\n" + << "if (m < uniforms.M && n < uniforms.N) {\n" + << " let outputIdx = batch_idx * uniforms.M * uniforms.v_hidden_size + " + << " m * uniforms.v_hidden_size + head_idx * uniforms.N + n;\n" + << " output[outputIdx] = value;\n" + << "}\n"; + + return Status::OK(); +} + +Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int output_count, + const Tensor* probs, + const Tensor* V, + const Tensor* past_value, + Tensor* output, + Tensor* present_value, + WebgpuAttentionParameters& parameters, + int past_sequence_length, + int total_sequence_length, + const Tensor* seqlen_k) { + const bool feed_past_value = present_value != nullptr && past_value != nullptr && past_value->SizeInBytes() > 0 && !parameters.past_present_share_buffer_; + const bool has_present_value = output_count > 1 && past_value != nullptr; + constexpr int tile_size = 12; + + VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; + program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank}, + {V, ProgramTensorMetadataDependency::TypeAndRank}}); + if (feed_past_value) { + program.AddInput({past_value, ProgramTensorMetadataDependency::TypeAndRank}); + } + if (seqlen_k != nullptr) { + program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}); + } + program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}}); + if (has_present_value) { + program.AddOutput({present_value, ProgramTensorMetadataDependency::TypeAndRank}); + } + + program.SetDispatchGroupSize((parameters.v_head_size_ + tile_size - 1) / tile_size, + (parameters.sequence_length_ + tile_size - 1) / tile_size, + parameters.batch_size_ * parameters.num_heads_) + .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_value, has_present_value, seqlen_k != nullptr, parameters.is_first_prompt_) + .SetWorkgroupSize(tile_size, tile_size) + .AddUniformVariables({{static_cast(parameters.sequence_length_)}, + {static_cast(total_sequence_length)}, + {static_cast(parameters.v_head_size_)}, + {static_cast(parameters.num_heads_)}, + {static_cast(parameters.head_size_)}, + {static_cast(parameters.v_hidden_size_ * parameters.n_reps)}, + {static_cast(past_sequence_length)}, + {static_cast(parameters.kv_sequence_length_)}, + {static_cast(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)}, + {static_cast(parameters.n_reps)}}) + .SetOverridableConstants({{static_cast(tile_size)}}); + + return context.RunProgram(program); +} + +Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, + const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, + WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) { + const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)}); + const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0; + const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length_; + + const TensorShapeVector probs_dims({parameters.batch_size_, parameters.num_heads_, + parameters.sequence_length_, total_sequence_length}); + const TensorShape probs_shape(probs_dims); + Tensor probs = context.CreateGPUTensor(Q->DataType(), probs_shape); + ORT_RETURN_IF_ERROR(ComputeAttentionProbs(context, output_count, Q, K, past_key, attention_bias, &probs, present_key, + parameters, past_sequence_length, total_sequence_length, seqlen_k)); + + ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs, + parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_)); + + ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value, + parameters, past_sequence_length, total_sequence_length, seqlen_k)); + + return Status::OK(); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h new file mode 100644 index 0000000000000..03279fffbc3ef --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -0,0 +1,123 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "contrib_ops/webgpu/bert/attention_common.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class TransferBSDToBNSHProgram final : public Program { + public: + TransferBSDToBNSHProgram(bool has_bias) : Program{"TransferBSDToBNSH"}, has_bias_(has_bias) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32}, + {"batch_offset", ProgramUniformVariableDataType::Uint32}, + {"sequence_offset", ProgramUniformVariableDataType::Uint32}, + {"head_offset", ProgramUniformVariableDataType::Uint32}, + {"bias_offset", ProgramUniformVariableDataType::Uint32}); + + private: + bool has_bias_; +}; + +class AttentionProbsProgram final : public Program { + public: + AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key, + bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) + : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"head_size", ProgramUniformVariableDataType::Uint32}, + {"alpha", ProgramUniformVariableDataType::Float32}, + {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"n_reps", ProgramUniformVariableDataType::Uint32}); + + WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); + + private: + bool feed_past_key_; + bool has_present_key_; + bool has_attention_bias_; + int tile_size_; + int components_; + int n_reps_; + const Tensor* seqlen_k_; + bool past_present_share_buffer_; + bool is_first_prompt_; +}; + +class InPlaceSoftmaxProgram final : public Program { + public: + InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, bool is_first_prompt, const Tensor* seqlen_k = nullptr) + : Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k), is_first_prompt_(is_first_prompt) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"batch_size", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"sequence_length", ProgramUniformVariableDataType::Uint32}, + {"total_sequence_length_comp", ProgramUniformVariableDataType::Uint32}, + {"elements_per_thread", ProgramUniformVariableDataType::Uint32}); + + private: + int work_group_size_; + int components_; + const Tensor* seqlen_k_; + bool is_first_prompt_; +}; + +class VxAttentionScoreProgram final : public Program { + public: + VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) + : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"head_size", ProgramUniformVariableDataType::Uint32}, + {"v_hidden_size", ProgramUniformVariableDataType::Uint32}, + {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"n_reps", ProgramUniformVariableDataType::Uint32}); + + WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); + + private: + bool feed_past_value_; + bool has_present_value_; + int tile_size_; + int n_reps_; + const Tensor* seqlen_k_; + bool past_present_share_buffer_; + bool is_first_prompt_; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h new file mode 100644 index 0000000000000..be80ade8b87d0 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h @@ -0,0 +1,130 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "contrib_ops/webgpu/bert/attention_common.h" + +#include "contrib_ops/cpu/bert/attention_common.h" +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +struct WebgpuAttentionParameters { + explicit WebgpuAttentionParameters(AttentionParameters parameters) : is_gqa_(false), + batch_size_(parameters.batch_size), + sequence_length_(parameters.sequence_length), + kv_sequence_length_(parameters.kv_sequence_length), + past_sequence_length_(parameters.past_sequence_length), + total_sequence_length_(parameters.total_sequence_length), + max_sequence_length_(parameters.max_sequence_length), + input_hidden_size_(parameters.input_hidden_size), + hidden_size_(parameters.hidden_size), + head_size_(parameters.head_size), + v_hidden_size_(parameters.v_hidden_size), + v_head_size_(parameters.v_head_size), + num_heads_(parameters.num_heads), + is_unidirectional_(parameters.is_unidirectional), + past_present_share_buffer_(parameters.past_present_share_buffer), + do_rotary_(parameters.do_rotary), + broadcast_attn_bias_dim_0_(parameters.broadcast_attn_bias_dim_0), + broadcast_attn_bias_dim_1_(parameters.broadcast_attn_bias_dim_1), + mask_filter_value_(parameters.mask_filter_value), + scale_(parameters.scale), + mask_type_(parameters.mask_type), + qkv_format_(parameters.qkv_format) { + } + + explicit WebgpuAttentionParameters(onnxruntime::contrib::GroupQueryAttentionParameters parameters) : is_gqa_(true), + batch_size_(parameters.batch_size), + sequence_length_(parameters.sequence_length), + kv_sequence_length_(parameters.sequence_length), + past_sequence_length_(parameters.seqlen_past_kv_cache), + total_sequence_length_(parameters.total_sequence_length), + hidden_size_(parameters.hidden_size), + head_size_(parameters.head_size), + v_hidden_size_(parameters.kv_hidden_size), + v_head_size_(parameters.kv_hidden_size / parameters.kv_num_heads), + num_heads_(parameters.num_heads), + do_rotary_(parameters.do_rotary), + scale_(parameters.scale), + seqlen_past_kv_cache_(parameters.seqlen_past_kv_cache), + seqlen_present_kv_cache_(parameters.seqlen_present_kv_cache), + kv_hidden_size_(parameters.kv_hidden_size), + kv_num_heads_(parameters.kv_num_heads), + num_splits_(parameters.num_splits), + rotary_dim_(parameters.rotary_dim), + is_packed_qkv_(parameters.is_packed_qkv), + is_subsequent_prompt_(parameters.is_subsequent_prompt), + is_first_prompt_(parameters.is_first_prompt), + rotary_interleaved_(parameters.rotary_interleaved), + use_smooth_softmax_(parameters.use_smooth_softmax), + softcap_(parameters.softcap), + zeros_count_(parameters.zeros_count), + zero_ptr_(parameters.zero_ptr), + n_reps(parameters.num_heads / parameters.kv_num_heads), + qkv_format_(parameters.qkv_format) { + } + + bool is_gqa_; + int batch_size_ = 0; + int sequence_length_ = 0; + int kv_sequence_length_ = 0; // input sequence length of K or V + int past_sequence_length_ = 0; // sequence length in past state of K or V + int total_sequence_length_ = 0; // total sequence length of K or V + int max_sequence_length_ = 0; // max sequence length from 4D mask + int input_hidden_size_ = 0; // first dimension of weights for input projection + int hidden_size_ = 0; // hidden size of Q or K + int head_size_ = 0; // hidden size per head of Q or K + int v_hidden_size_ = 0; // hidden size of V + int v_head_size_ = 0; // hidden size per head of V + int num_heads_ = 0; + int rotary_embedding_ = 0; + bool is_unidirectional_ = false; + bool past_present_share_buffer_ = false; + bool do_rotary_ = false; + bool broadcast_attn_bias_dim_0_ = false; + bool broadcast_attn_bias_dim_1_ = false; + float mask_filter_value_ = -10000.0f; + float scale_ = 0.0f; + bool use_tf32_ = false; + ; + // The following members are in onnxruntime::contrib::GroupQueryAttentionParameters + // and not in onnxruntime::contrib::AttentionParameters + int seqlen_past_kv_cache_ = 0; // sequence length of past kv tensor + int seqlen_present_kv_cache_ = 0; // sequence length of present kv tensor + int kv_hidden_size_ = 0; + int kv_num_heads_ = 0; + int num_splits_ = 0; // number of splits for splitkv + int rotary_dim_ = 0; // rotary embedding dimension + int local_window_size_ = 0; + bool kv_share_buffer_ = false; + bool is_packed_qkv_ = false; + bool is_subsequent_prompt_ = false; // indicates whether we have past context and seqlen > 1 + bool is_first_prompt_ = false; // indicates whether this is first decoding step + bool rotary_interleaved_ = false; + bool use_smooth_softmax_ = false; + float softcap_ = 0.0; + int zeros_count_ = 0; + ; + int* zero_ptr_ = nullptr; + // Computed values + int n_reps = 1; + AttentionMaskType mask_type_ = MASK_NONE; + AttentionQkvFormat qkv_format_ = UNKNOWN; +}; + +Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length, + int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor); + +Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, + const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, + WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc new file mode 100644 index 0000000000000..a5cae7e7f6747 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/math/unary_elementwise_ops.h" +#include "contrib_ops/webgpu/bert/fast_gelu.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + FastGelu, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + FastGelu); + +Status FastGeluProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& y = shader.AddOutput("y", ShaderUsage::UseUniform); + + shader.AdditionalImplementation() << TanhImpl; + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size") + << " var a = " << x.GetByOffset("global_idx") << ";\n"; + if (Inputs().size() > 1) { + const auto& bias = shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride); + if (bias_components_ == 1) { + shader.MainFunctionBody() << " let bias_offset = global_idx * 4;\n" + " a += x_value_t(" + << bias.GetByOffset("bias_offset % uniforms.bias_shape") << ", " + << bias.GetByOffset("(bias_offset + 1) % uniforms.bias_shape") << ", " + << bias.GetByOffset("(bias_offset + 2) % uniforms.bias_shape") << ", " + << bias.GetByOffset("(bias_offset + 3) % uniforms.bias_shape") << ");\n"; + } else { + shader.MainFunctionBody() << " a += " << bias.GetByOffset("global_idx % uniforms.bias_shape") + ";\n"; + } + } + shader.MainFunctionBody() << y.SetByOffset("global_idx", onnxruntime::webgpu::FastGeluExpr); + + return Status::OK(); +} + +Status FastGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const auto* input = context.Input(0); + const auto* bias = context.Input(1); + auto* output = context.Output(0, input->Shape()); + + uint32_t data_size = gsl::narrow(output->Shape().Size()); + if (data_size == 0) { + return Status::OK(); + } + + const auto vec_size = (data_size + 3) / 4; + uint32_t bias_size = 0; + int bias_components = 1; + + if (bias != nullptr) { + bias_size = gsl::narrow(bias->Shape().Size()); + if (bias_size % 4 == 0) { + bias_components = 4; + bias_size = bias_size / 4; + } + } + + FastGeluProgram program{bias_components}; + program.AddInput({input, ProgramTensorMetadataDependency::Type, {vec_size}, 4}) + .AddOutput({output, ProgramTensorMetadataDependency::None, {vec_size}, 4}) + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariable({vec_size}); + + if (bias != nullptr) { + program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, {bias_size}, bias_components}); + } + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h new file mode 100644 index 0000000000000..fa40d52bf301f --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +class FastGeluProgram final : public Program { + public: + FastGeluProgram(int bias_components) : Program{"FastGelu"}, bias_components_{bias_components} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + + private: + int bias_components_; +}; + +class FastGelu final : public WebGpuKernel { + public: + FastGelu(const OpKernelInfo& info) : WebGpuKernel(info) {} + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc new file mode 100644 index 0000000000000..31c8af9b4f922 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/bert/group_query_attention_helper.h" +#include "contrib_ops/webgpu/bert/attention_common.h" +#include "contrib_ops/webgpu/bert/group_query_attention.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +#include "core/providers/webgpu/webgpu_supported_types.h" + +using namespace onnxruntime::webgpu; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::contrib::group_query_attention_helper; + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + GroupQueryAttention, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()) + .MayInplace(3, 1) + .MayInplace(4, 2) + .InputMemoryType(OrtMemTypeCPUInput, 6), + GroupQueryAttention); + +Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const Tensor* query = context.Input(0); + const Tensor* key = context.Input(1); + const Tensor* value = context.Input(2); + const Tensor* past_key = context.Input(3); + const Tensor* past_value = context.Input(4); + const Tensor* seqlen_k = context.Input(5); + const Tensor* total_seqlen_tensor = context.Input(6); + const Tensor* cos_cache = context.Input(7); + const Tensor* sin_cache = context.Input(8); + + GroupQueryAttentionParameters params; + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, + key, + value, + past_key, + past_value, + cos_cache, + sin_cache, + ¶ms, + num_heads_, + kv_num_heads_, + seqlen_k, + total_seqlen_tensor, + scale_, + softcap_)); + WebgpuAttentionParameters parameters(params); + if (parameters.is_packed_qkv_) { + ORT_NOT_IMPLEMENTED("Packed QKV of shape (B, L, N, 3, H) not implemented for webgpu-ep."); + } + TensorShapeVector output_shape(3); + output_shape[0] = static_cast(parameters.batch_size_); + output_shape[1] = static_cast(parameters.sequence_length_); + output_shape[2] = static_cast(parameters.hidden_size_); + Tensor* output = context.Output(0, output_shape); + std::vector present_dims{ + parameters.batch_size_, + kv_num_heads_, + parameters.seqlen_present_kv_cache_, + parameters.head_size_}; + std::vector present_kv_shape(present_dims); + Tensor* present_key = context.Output(1, present_kv_shape); + Tensor* present_value = context.Output(2, present_kv_shape); + parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key != nullptr && past_value != nullptr && past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw(); + + TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_, + parameters.sequence_length_, parameters.head_size_}); + TensorShape q_new_shape(q_new_dims); + Tensor Q = context.CreateGPUTensor(query->DataType(), q_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH( + context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, nullptr, 0, &Q)); + if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format + return ApplyAttention(&Q, key, value, nullptr, past_key, past_value, output, present_key, + present_value, parameters, context, seqlen_k); + } + + TensorShapeVector k_new_dims({parameters.batch_size_, parameters.kv_num_heads_, + parameters.kv_sequence_length_, parameters.head_size_}); + TensorShape k_new_shape(k_new_dims); + Tensor K = context.CreateGPUTensor(key->DataType(), k_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, + parameters.head_size_, key, nullptr, 0, &K)); + + TensorShapeVector v_new_dims({parameters.batch_size_, parameters.kv_num_heads_, + parameters.kv_sequence_length_, parameters.v_head_size_}); + TensorShape v_new_shape(v_new_dims); + Tensor V = context.CreateGPUTensor(value->DataType(), v_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, + parameters.v_head_size_, value, nullptr, 0, &V)); + return ApplyAttention(&Q, &K, &V, nullptr, past_key, past_value, output, present_key, + present_value, parameters, context, seqlen_k); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h new file mode 100644 index 0000000000000..04969dc778927 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class GroupQueryAttention final : public WebGpuKernel { + public: + GroupQueryAttention(const OpKernelInfo& info) : WebGpuKernel(info) { + int64_t num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + num_heads_ = static_cast(num_heads); + + int64_t kv_num_heads = 0; + ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0); + kv_num_heads_ = static_cast(kv_num_heads); + + scale_ = info.GetAttrOrDefault("scale", 0.0f); + softcap_ = info.GetAttrOrDefault("softcap", 0.0f); + + do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; + rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; + + use_smooth_softmax_ = info.GetAttrOrDefault("smooth_softmax", 0) == 1; + + local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); + } + + int num_heads_; // number of attention heads of Q + int kv_num_heads_; // number of attention heads of K or V + float scale_; // the scaling factor applied before softmax + float softcap_; + bool do_rotary_; // whether or not to use rotary embeddings + bool rotary_interleaved_; + int local_window_size_; + + bool use_smooth_softmax_; + Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/layer_norm.cc b/onnxruntime/contrib_ops/webgpu/bert/layer_norm.cc new file mode 100644 index 0000000000000..8997e8698d96d --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/layer_norm.cc @@ -0,0 +1,36 @@ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/nn/layer_norm.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + LayerNormalization, + kOnnxDomain, + 1, + 16, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + onnxruntime::webgpu::LayerNorm); + +ONNX_OPERATOR_KERNEL_EX( + SimplifiedLayerNormalization, + kOnnxDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + onnxruntime::webgpu::LayerNorm); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc new file mode 100644 index 0000000000000..424556c66bd9d --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/bert/multihead_attention_helper.h" +#include "contrib_ops/webgpu/bert/attention_common.h" +#include "contrib_ops/webgpu/bert/multihead_attention.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +#include "core/providers/webgpu/webgpu_supported_types.h" + +using namespace onnxruntime::webgpu; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::contrib::multihead_attention_helper; + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + MultiHeadAttention, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + MultiHeadAttention); + +MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) + : WebGpuKernel(info), AttentionBase(info, false) { + ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support webgpu kernel"); +} + +Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const Tensor* query = context.Input(0); + const Tensor* key = context.Input(1); + const Tensor* value = context.Input(2); + const Tensor* bias = context.Input(3); + const Tensor* key_padding_mask = context.Input(4); + const Tensor* attention_bias = context.Input(5); + const Tensor* past_key = context.Input(6); + const Tensor* past_value = context.Input(7); + + if (query->Shape().GetDims().size() == 5) { + ORT_NOT_IMPLEMENTED("Packed QKV of shape (B, L, N, 3, H) not implemented for webgpu"); + } + if (key != nullptr && key->Shape().GetDims().size() == 5) { + ORT_NOT_IMPLEMENTED("Packed KV not implemented for webgpu"); + } + if (key_padding_mask) { + ORT_NOT_IMPLEMENTED("input `key_padding_mask` not implemented for webgpu"); + } + + AttentionParameters params; + ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, key, value, + bias, key_padding_mask, attention_bias, past_key, past_value, nullptr, ¶ms, + num_heads_, mask_filter_value_, scale_, is_unidirectional_, false, kMultiHeadAttention, + context.DeviceLimits().maxComputeInvocationsPerWorkgroup)); + WebgpuAttentionParameters parameters(params); + TensorShapeVector output_shape(3); + output_shape[0] = static_cast(parameters.batch_size_); + output_shape[1] = static_cast(parameters.sequence_length_); + output_shape[2] = static_cast(parameters.v_hidden_size_); + Tensor* output = context.Output(0, output_shape); + + // If optional outputs aren't needed, present_key and present_value will be null + std::vector present_dims{ + parameters.batch_size_, + parameters.num_heads_, + parameters.total_sequence_length_, + parameters.head_size_, + }; + TensorShape present_shape(present_dims); + Tensor* present_key = context.Output(1, present_shape); + Tensor* present_value = context.Output(2, present_shape); + + TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_, + parameters.sequence_length_, parameters.head_size_}); + TensorShape q_new_shape(q_new_dims); + Tensor Q = context.CreateGPUTensor(query->DataType(), q_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH( + context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, bias, 0, &Q)); + + if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format + return ApplyAttention(&Q, key, value, attention_bias, past_key, past_value, output, present_key, + present_value, parameters, context); + } + + TensorShapeVector k_new_dims({parameters.batch_size_, parameters.num_heads_, + parameters.kv_sequence_length_, parameters.head_size_}); + TensorShape k_new_shape(k_new_dims); + Tensor K = context.CreateGPUTensor(key->DataType(), k_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.kv_sequence_length_, + parameters.head_size_, key, bias, parameters.hidden_size_, &K)); + + TensorShapeVector v_new_dims({parameters.batch_size_, parameters.num_heads_, + parameters.kv_sequence_length_, parameters.v_head_size_}); + TensorShape v_new_shape(v_new_dims); + Tensor V = context.CreateGPUTensor(value->DataType(), v_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.kv_sequence_length_, + parameters.v_head_size_, value, bias, 2 * parameters.hidden_size_, &V)); + + // Compute the attention score and apply the score to V + return ApplyAttention(&Q, &K, &V, attention_bias, past_key, past_value, output, present_key, + present_value, parameters, context); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h new file mode 100644 index 0000000000000..d983236422c9e --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "contrib_ops/webgpu/bert/attention.h" + +#include "contrib_ops/cpu/bert/attention_base.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class MultiHeadAttention final : public WebGpuKernel, public AttentionBase { + public: + MultiHeadAttention(const OpKernelInfo& info); + Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc new file mode 100644 index 0000000000000..bc8b7493fc916 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc @@ -0,0 +1,119 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "contrib_ops/webgpu/bert/rotary_embedding.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + RotaryEmbedding, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()) + .TypeConstraint("M", DataTypeImpl::GetTensorType()), + RotaryEmbedding); + +Status RotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddInput("input", ShaderUsage::UseUniform); + const auto& position_ids = shader.AddInput("position_ids", ShaderUsage::UseUniform); + const auto& cos_cache = shader.AddInput("cos_cache", ShaderUsage::UseUniform); + const auto& sin_cache = shader.AddInput("sin_cache", ShaderUsage::UseUniform); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); + // TODO: remove output_indices. + const auto& output_indices = shader.AddIndices("output_indices", false); + const auto interleaved_str = interleaved_ ? "true" : "false"; + shader.MainFunctionBody() << " let half_rotary_emb_dim = uniforms.cos_cache_shape[1];\n" + " let bsnh = global_idx / uniforms.global_stride % uniforms.global_shape;\n" + " let size = uniforms.global_shape[0] * uniforms.global_stride[0];\n" + " if (global_idx >= size) { return; }\n" + " if (bsnh[3] < half_rotary_emb_dim) {\n" + << " let position_ids_idx = " << position_ids.BroadcastedIndicesToOffset("bsnh.xy", output_indices) << ";\n" + << " let position_id = u32(" << position_ids.GetByOffset("position_ids_idx") << ") + select(0, bsnh[1], position_ids_idx == 0);\n" + << " let i = dot(bsnh, uniforms.input_output_stride) + select(0, bsnh[3], " << interleaved_str << ");\n" + << " let j = i + select(half_rotary_emb_dim, 1, " << interleaved_str << ");\n" + << " let re = " << input.GetByOffset("i") << " * " << cos_cache.GetByIndices("vec2(position_id, bsnh[3])") << " - " << input.GetByOffset("j") << " * " << sin_cache.GetByIndices("vec2(position_id, bsnh[3])") << ";\n" + << " " << output.SetByOffset("i", "re") << "\n" + << " let im = " << input.GetByOffset("i") << " * " << sin_cache.GetByIndices("vec2(position_id, bsnh[3])") << " + " << input.GetByOffset("j") + " * " << cos_cache.GetByIndices("vec2(position_id, bsnh[3])") << ";\n" + << " " << output.SetByOffset("j", "im") << "\n" + << " } else { \n" + " let k = dot(bsnh, uniforms.input_output_stride) + half_rotary_emb_dim;\n" + << " " << output.SetByOffset("k", input.GetByOffset("k")) << "\n" + << " }"; + + return Status::OK(); +} + +RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : WebGpuKernel(info) { + scale_ = info.GetAttrOrDefault("scale", 1.0); + rotary_embedding_dim_ = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0)); + num_heads_ = static_cast(info.GetAttrOrDefault("num_heads", 0)); + interleaved_ = (info.GetAttrOrDefault("interleaved", 0) == 1); + is_packed_batching_ = (info.GetAttrOrDefault("is_packed_batching", 0) == 1); +} + +Status RotaryEmbedding::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const auto* input = context.Input(0); + const auto input_shape = input->Shape(); + const auto* position_ids = context.Input(1); + const auto* cos_cache = context.Input(2); + const auto* sin_cache = context.Input(3); + auto* output = context.Output(0, input_shape); + + const auto batch_size = gsl::narrow(input->Shape()[0]); + const auto batch_stride = gsl::narrow(input_shape.SizeFromDimension(1)); + const auto sequence_length = gsl::narrow(input_shape[input_shape.NumDimensions() - 2]); + const auto hidden_size = batch_stride / sequence_length; + const auto half_rotary_embedding_dim = gsl::narrow(cos_cache->Shape()[1]); + const auto head_size = rotary_embedding_dim_ == 0 ? half_rotary_embedding_dim * 2 : hidden_size / num_heads_; + + // Rotary embeddings will be calculated in a pair-wise fashion. In accordance, use the shape + // [batch size, sequence length, num of heads, num of pairs to rotate + num of dims to copy] + // to unfold the global index in shader. + const TensorShape global_shape({batch_size, + sequence_length, + hidden_size / head_size, + head_size - half_rotary_embedding_dim}); + + const auto rank = global_shape.NumDimensions(); + std::vector global_dims(rank); + std::vector global_strides(rank); + for (size_t j = 0; j < rank; ++j) { + global_dims[j] = gsl::narrow(global_shape[j]); + global_strides[j] = gsl::narrow(global_shape.SizeFromDimension(j + 1)); + } + + const auto output_size = gsl::narrow(global_shape.Size()); + RotaryEmbeddingProgram program{interleaved_}; + const auto input_output_strides = + input_shape.NumDimensions() == 3 + ? std::vector({batch_stride, hidden_size, head_size, 1}) + : (input_shape.NumDimensions() == 4 + ? std::vector({batch_stride, head_size, sequence_length * head_size, 1}) + : std::vector({})); + + program + .CacheHint(interleaved_) + .AddInputs({{input, ProgramTensorMetadataDependency::Rank}, + {position_ids, ProgramTensorMetadataDependency::Rank}, + {cos_cache, ProgramTensorMetadataDependency::Rank}, + {sin_cache, ProgramTensorMetadataDependency::Rank}}) + .AddOutput({output, ProgramTensorMetadataDependency::None}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({{scale_}, + {gsl::make_span(global_dims)}, + {gsl::make_span(global_strides)}, + {gsl::make_span(input_output_strides)}}) + .AddIndices(TensorShape{1, 1}); + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h new file mode 100644 index 0000000000000..0d73b89fb62df --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +class RotaryEmbeddingProgram final : public Program { + public: + RotaryEmbeddingProgram(bool interleaved) : Program{"RotaryEmbedding"}, interleaved_{interleaved} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"scale", ProgramUniformVariableDataType::Float32}, + {"global_shape", ProgramUniformVariableDataType::Uint32}, + {"global_stride", ProgramUniformVariableDataType::Uint32}, + {"input_output_stride", ProgramUniformVariableDataType::Uint32}); + + private: + const bool interleaved_; +}; + +class RotaryEmbedding final : public WebGpuKernel { + public: + RotaryEmbedding(const OpKernelInfo& info); + Status ComputeInternal(ComputeContext& context) const override; + + private: + float scale_; + int num_heads_; + int rotary_embedding_dim_; + bool interleaved_; + bool is_packed_batching_; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc new file mode 100644 index 0000000000000..fe541f58d34ec --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc @@ -0,0 +1,177 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "contrib_ops/webgpu/bert/skip_layer_norm.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +static uint32_t GetMaxComponents(int size) { + if (size % 4 == 0) { + return 4; + } else if (size % 2 == 0) { + return 2; + } + return 1; +} + +static std::string SumVector(std::string x, int components) { + switch (components) { + case 1: + return x; + case 2: + return "(" + x + ".x + " + x + ".y" + ")"; + case 4: + return "(" + x + ".x + " + x + ".y + " + x + ".w + " + x + ".z" + ")"; + default: + ORT_THROW("Unsupported number of components: ", components); + } +} + +Status SkipLayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.AddInput("skip", ShaderUsage::UseUniform); + shader.AddInput("gamma", ShaderUsage::UseUniform); + if (hasBeta_) { + shader.AddInput("beta", ShaderUsage::UseUniform); + } + if (hasBias_) { + shader.AddInput("bias", ShaderUsage::UseUniform); + } + shader.AddOutput("output", ShaderUsage::UseUniform); + if (has_input_skip_bias_sum_) { + shader.AddOutput("input_skip_bias_sum", ShaderUsage::UseUniform); + } + + int components = x.NumComponents(); + + std::string bias = (hasBias_) ? " + bias[offset1d + i] " : ""; + std::string simpl1 = (simplified_) ? "" : "- mean * mean "; + std::string simpl2 = (simplified_) ? "" : "- element_t(mean) "; + std::string beta = (hasBeta_) ? " + beta[offset1d + i] " : ""; + std::string input_skip_bias_sum = (has_input_skip_bias_sum_) ? "input_skip_bias_sum[offset + i] = value;\n" : ""; + + shader.AdditionalImplementation() + << "alias element_t = " << (is_fp16_ ? "f16;\n" : "f32;\n") + << "alias f32_val_t = " << (components == 4 ? "vec4" : (components == 2 ? "vec2" : "f32")) << ";\n" + << "var sum_shared : array;\n" + << "var sum_squared_shared : array;\n"; + + shader.MainFunctionBody() + << "let ix = local_idx;\n" + << "let iy = global_idx / workgroup_size_x;\n" + << "let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components;\n" + << "var stride = hidden_size_vectorized / workgroup_size_x;\n" + << "let offset = ix * stride + iy * hidden_size_vectorized;\n" + << "let offset1d = stride * ix;\n" + << "if (ix == workgroup_size_x - 1) {\n" + << " stride = hidden_size_vectorized - stride * ix;\n" + << "}\n" + << "for (var i: u32 = 0; i < stride; i++) {\n" + << " let skip_value = skip[offset + i];\n" + << " let input_value = x[offset + i];\n" + << " let value = input_value + skip_value" << bias << ";\n" + << " output[offset + i] = value;\n" + << input_skip_bias_sum + << " let f32_value = f32_val_t(value);\n" + << " sum_shared[ix] += f32_value;\n" + << " sum_squared_shared[ix] += f32_value * f32_value;\n" + << "}\n" + << "workgroupBarrier();\n" + << "var reduce_size : u32 = workgroup_size_x;\n" + << "for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n" + << " reduce_size = curr_size + (reduce_size & 1);\n" + << " if (ix < curr_size) {\n" + << " sum_shared[ix] += sum_shared[ix + reduce_size];\n" + << " sum_squared_shared[ix] += sum_squared_shared[ix + reduce_size];\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n" + << "let sum = sum_shared[0];\n" + << "let square_sum = sum_squared_shared[0];\n" + << "let mean = " << SumVector("sum", components) << " / f32(uniforms.hidden_size);\n" + << "let inv_std_dev = inverseSqrt(" << SumVector("square_sum", components) << " / f32(uniforms.hidden_size) " << simpl1 << "+ uniforms.epsilon);\n" + << "for (var i: u32 = 0; i < stride; i++) {\n" + << " output[offset + i] = (output[offset + i] " << simpl2 << ") * element_t(inv_std_dev) * gamma[offset1d + i]" << beta << ";\n" + << "};\n"; + + return Status::OK(); +} + +template +Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const Tensor* x = context.Input(0); + const Tensor* skip = context.Input(1); + const Tensor* gamma = context.Input(2); + // optional + const Tensor* beta = context.Input(3); + const Tensor* bias = context.Input(4); + + const auto x_shape = x->Shape(); + + auto* output = context.Output(0, x_shape); + auto* input_skip_bias_sum = context.Output(3, x_shape); + + size_t data_size = x_shape.Size(); + if (data_size == 0) { + return Status::OK(); + } + + const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + const uint32_t hidden_size = gsl::narrow(x_shape[x_shape.NumDimensions() - 1]); + const int components = GetMaxComponents(hidden_size); + const bool has_input_skip_bias_sum = input_skip_bias_sum != nullptr; + + SkipLayerNormProgram program{beta != nullptr, bias != nullptr, epsilon_, hidden_size, has_input_skip_bias_sum, is_fp16, simplified}; + program + .CacheHint(simplified, has_input_skip_bias_sum) + .AddInputs({{x, ProgramTensorMetadataDependency::Type, components}}) + .AddInputs({{skip, ProgramTensorMetadataDependency::Type, components}}) + .AddInputs({{gamma, ProgramTensorMetadataDependency::Type, components}}) + .AddOutputs({{output, ProgramTensorMetadataDependency::None, components}}) + .SetDispatchGroupSize(gsl::narrow(ceil(1.0 * data_size / hidden_size))) + .AddUniformVariables({ + {static_cast(components)}, + }) + .AddUniformVariables({ + {static_cast(hidden_size)}, + }) + .AddUniformVariables({ + {static_cast(epsilon_)}, + }); + + if (beta != nullptr) { + program.AddInput({beta, ProgramTensorMetadataDependency::Type, components}); + } + if (bias != nullptr) { + program.AddInput({bias, ProgramTensorMetadataDependency::Type, components}); + } + if (has_input_skip_bias_sum) { + program.AddOutputs({{input_skip_bias_sum, ProgramTensorMetadataDependency::None, components}}); + } + return context.RunProgram(program); +} + +ONNX_OPERATOR_KERNEL_EX( + SkipLayerNormalization, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + SkipLayerNorm); + +ONNX_OPERATOR_KERNEL_EX( + SkipSimplifiedLayerNormalization, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + SkipLayerNorm); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h new file mode 100644 index 0000000000000..03de1a4b568b9 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +class SkipLayerNormProgram final : public Program { + public: + SkipLayerNormProgram(bool hasBeta, bool hasBias, float epsilon, uint32_t hidden_size, bool has_input_skip_bias_sum, bool is_fp16, bool simplified) : Program{"SkipLayerNorm"} { + epsilon_ = epsilon; + hasBeta_ = hasBeta; + hasBias_ = hasBias; + epsilon_ = epsilon; + hidden_size_ = hidden_size; + has_input_skip_bias_sum_ = has_input_skip_bias_sum; + simplified_ = simplified; + is_fp16_ = is_fp16; + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"components", ProgramUniformVariableDataType::Uint32}, + {"hidden_size", ProgramUniformVariableDataType::Uint32}, + {"epsilon", ProgramUniformVariableDataType::Float32}); + + private: + bool hasBeta_; + bool hasBias_; + float epsilon_; + uint32_t hidden_size_; + bool has_input_skip_bias_sum_; + bool is_fp16_; + bool simplified_; +}; + +template +class SkipLayerNorm final : public WebGpuKernel { + public: + SkipLayerNorm(const OpKernelInfo& info) : WebGpuKernel(info) { + info.GetAttrOrDefault("epsilon", &epsilon_, 1e-05f); + } + + Status ComputeInternal(ComputeContext& context) const override; + + protected: + std::string cache_hint; + + private: + float epsilon_; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc new file mode 100644 index 0000000000000..be18f820e2747 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -0,0 +1,542 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "contrib_ops/webgpu/quantization/matmul_nbits.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +namespace { +// Put it to a common place? +uint32_t GetMaxComponents(uint32_t size) { + // we cannot use vec3 type since it has alignment of 16 bytes + if (size % 4 == 0) { + return 4; + } else if (size % 2 == 0) { + return 2; + } + + return 1; +} + +std::string QuantizedDataType(int components) { + switch (components) { + case 1: + return "array"; + case 2: + return "mat4x2"; + case 4: + return "mat2x4"; + default: + return "array"; + } +} + +constexpr unsigned int kMinSequenceLengthForPrefillOptimization = 16; +} // namespace + +ONNX_OPERATOR_KERNEL_EX( + MatMulNBits, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", WebGpuSupportedFloatTypes()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), + MatMulNBits); + +Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform); + const auto& y = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias); + + if (use_block32_) { + const uint32_t workgroup_size = WorkgroupSizeX() * WorkgroupSizeY(); + const uint32_t tile_size = WorkgroupSizeX() * components_b_ * 8; // each uint32 has 8 data. + const uint32_t a_length_per_tile = tile_size / a.NumComponents(); + constexpr uint32_t block_size = 32; + const uint32_t blocks_per_tile = tile_size / block_size; + shader.AdditionalImplementation() << "var sub_a: array;\n" + << "var inter_results: array, " << WorkgroupSizeY() << ">;\n"; + std::string offset = "workgroup_idx * " + std::to_string(WorkgroupSizeY()); + shader.MainFunctionBody() << " let output_indices = " << y.OffsetToIndices(offset) << ";\n" + << " let col = output_indices[2];\n" + " let row = output_indices[1];\n" + " let batch = output_indices[0];\n" + " let n_blocks_per_col = uniforms.input_b_shape[1];\n" + << " let num_tiles = (n_blocks_per_col - 1) / " << blocks_per_tile << " + 1;\n" + // Loop over shared dimension. + << " for (var tile: u32 = 0; tile < num_tiles; tile += 1) {\n" + << " let a_col_start = tile * " << a_length_per_tile << ";\n" + << " // load one tile A data into shared memory.\n" + << " for (var a_offset = local_idx; a_offset < " << a_length_per_tile << "; a_offset += " << workgroup_size << ") {\n" + << " let a_col = a_col_start + a_offset;\n" + " if (a_col < uniforms.input_a_shape[2]) {\n" + << " sub_a[a_offset] = " << a.GetByIndices("input_a_indices_t(batch, row, a_col)") << ";\n" + << " } else {\n" + " sub_a[a_offset] = input_a_value_t(0);\n" + " }\n" + " }\n" + " workgroupBarrier();\n" + // Each thread processes one block. + " let b_row = col + local_id.y;\n" + << " let block = tile * " << blocks_per_tile << " + local_id.x;\n"; + if (has_zero_points_) { + const auto& zero_points = shader.AddInput("zero_points", ShaderUsage::UseUniform); + shader.MainFunctionBody() << " let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;\n" + " let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block >> 0x1u);\n" + " let zero_point_word_index = zero_point_byte_count >> 0x2u;\n" + " let zero_point_byte_offset = zero_point_byte_count & 0x3u;\n" + " let zero_point_nibble_offset: u32 = block & 0x1u;\n" + " let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);\n" + << " let zero_point_word = " << zero_points.GetByOffset("zero_point_word_index") << " >> zero_point_bits_offset;\n" + << " let zero_point = output_element_t((zero_point_word) & 0xFu);\n"; + } else { + // The default zero point is 8 for unsigned 4-bit quantization. + shader.MainFunctionBody() << " let zero_point = output_element_t(8.0);\n"; + } + shader.MainFunctionBody() << " var scale = output_element_t(0);\n" + " var b_data = input_b_value_t(0);\n" + << " if (block < n_blocks_per_col) {\n" + << " scale = " << scales.GetByOffset("b_row * n_blocks_per_col + block") << ";\n" + << " b_data = " << b.GetByIndices("input_b_indices_t(b_row, block, 0)") << ";\n" + << " }\n" + << " var word_offset = local_id.x * " << block_size / a.NumComponents() << ";\n" + << " for (var i: u32 = 0; i < " << components_b_ << "; i++) {\n"; + switch (a.NumComponents()) { + case 1: + shader.MainFunctionBody() << " let a_data0 = vec4(sub_a[word_offset], sub_a[word_offset + 1], sub_a[word_offset + 2], sub_a[word_offset + 3]);\n" + " let a_data1 = vec4(sub_a[word_offset + 4], sub_a[word_offset + 5], sub_a[word_offset + 6], sub_a[word_offset + 7]);\n"; + break; + case 2: + shader.MainFunctionBody() << " let a_data0 = vec4(sub_a[word_offset], sub_a[word_offset + 1]);\n" + " let a_data1 = vec4(sub_a[word_offset + 2], sub_a[word_offset + 3]);\n"; + break; + case 4: + shader.MainFunctionBody() << " let a_data0 = sub_a[word_offset];\n" + " let a_data1 = sub_a[word_offset + 1];\n"; + break; + default: + break; + } + shader.MainFunctionBody() << " let b_value = b_data"; + if (components_b_ > 1) { + shader.MainFunctionBody() << "[i]"; + } + shader.MainFunctionBody() << ";\n" + " let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);\n" + " let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);\n" + " let b_quantized_values = mat2x4(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));\n" + " let b_dequantized_values = (b_quantized_values - mat2x4("; + for (int i = 0; i < 8; i++) { + shader.MainFunctionBody() << "zero_point"; + if (i < 7) { + shader.MainFunctionBody() << ", "; + } + } + shader.MainFunctionBody() << ")) * scale;\n" + " inter_results[local_id.y][local_id.x] += dot(a_data0, b_dequantized_values[0]) + dot(a_data1, b_dequantized_values[1]);\n" + << " word_offset += " << 8 / a.NumComponents() << ";\n" + << " }\n" + " workgroupBarrier();\n" + " }\n" + << " if (local_idx < " << WorkgroupSizeY() << ") {\n" + << " var output_value = output_value_t(0);\n" + << " for (var b = 0u; b < " << WorkgroupSizeX() << "; b++) {\n" + << " output_value += inter_results[local_idx][b];\n" + " }\n" + " if (col + local_idx < uniforms.output_shape[2]) {\n" + << " " << y.SetByIndices("output_indices_t(batch, row, col + local_idx)", "output_value") << ";\n" + << " }\n" + " }\n"; + } else { + const std::string quantized_data_type = QuantizedDataType(a.NumComponents()); + const int output_element_number = y.NumComponents() * gsl::narrow(output_number_); + + const uint32_t shared_memory_size = output_number_ * WORKGROUP_SIZE; + std::string offset = "workgroup_idx * " + std::to_string(output_number_); + shader.AdditionalImplementation() << "var workgroup_shared : array;\n"; + shader.MainFunctionBody() << " let output_indices = " << y.OffsetToIndices(offset) << ";\n" + << " let col = output_indices[2];\n" + " let row = output_indices[1];\n" + " let batch = output_indices[0];\n" + " let n_blocks_per_col = uniforms.input_b_shape[1];\n" + " let blob_size = uniforms.input_b_shape[2];\n" + " for (var block = local_id.x; block < n_blocks_per_col; block += workgroup_size_x) {\n" + << " var word_offset = block * uniforms.block_size / " << a.NumComponents() << ";\n"; + + // prepare scale and zero point + shader.MainFunctionBody() << " var col_index = col * " << y.NumComponents() << ";\n"; + if (has_zero_points_) { + const auto& zero_points = shader.AddInput("zero_points", ShaderUsage::UseUniform); + shader.MainFunctionBody() << " let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;\n" + " var zero_point_byte_count: u32;\n" + " var zero_point_word_index: u32;\n" + " var zero_point_byte_offset: u32;\n" + " let zero_point_nibble_offset: u32 = block & 0x1u;\n" + " var zero_point_bits_offset: u32;\n" + " var zero_point_word: u32;\n"; + for (int c = 0; c < output_element_number; c++) { + shader.MainFunctionBody() << " let scale" << c << " = " << scales.GetByOffset("col_index * n_blocks_per_col + block") << ";\n" + << " zero_point_byte_count = col_index * zero_point_bytes_per_col + (block >> 0x1u);\n" + " zero_point_word_index = zero_point_byte_count >> 0x2u;\n" + " zero_point_byte_offset = zero_point_byte_count & 0x3u;\n" + " zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);\n" + << " zero_point_word = " << zero_points.GetByOffset("zero_point_word_index") << " >> zero_point_bits_offset;\n" + << " let zero_point" << c << " = output_element_t((zero_point_word) & 0xFu);\n" + << " col_index += 1;\n"; + } + } else { + shader.MainFunctionBody() << " let zero_point = output_element_t(8.0);\n"; + for (int c = 0; c < output_element_number; c++) { + shader.MainFunctionBody() << " let scale" << c << " = " << scales.GetByOffset("col_index * n_blocks_per_col + block") << ";\n" + << " col_index += 1;\n"; + } + } + + shader.MainFunctionBody() << " for (var word: u32 = 0; word < blob_size; word += 1) {\n"; + + // prepare b data + shader.MainFunctionBody() << " col_index = col * " << y.NumComponents() << ";\n"; + for (int c = 0; c < output_element_number; c++) { + shader.MainFunctionBody() << " let b" << c << "_data = " << b.GetByIndices("input_b_indices_t(col_index, block, word)") << ";\n" + << " col_index += 1;\n"; + } + shader.MainFunctionBody() << " var b_value : u32;\n" + " let b_mask : u32 = 0x0F0F0F0Fu;\n" + " var b_value_lower : vec4;\n" + " var b_value_upper : vec4;\n" + << " var b_quantized_values : " << quantized_data_type << ";\n" + << " var b_dequantized_values : " << quantized_data_type << ";\n"; + + shader.MainFunctionBody() << " for (var i: u32 = 0; i < " << components_b_ << "; i++) {\n"; + + // process one word + shader.MainFunctionBody() << " var input_offset = " << a.IndicesToOffset("input_a_indices_t(batch, row, word_offset)") << ";\n" + << " var a_data: " << quantized_data_type << ";\n" + << " for (var j: u32 = 0; j < " << (8 / a.NumComponents()) << "; j++) {\n" + << " if (word_offset + j < uniforms.input_a_shape[2]) {\n" + << " a_data[j] = " << a.GetByOffset("input_offset") << ";\n" + << " input_offset++;\n" + " } else {\n" + " a_data[j] = input_a_value_t(0);\n" + " }\n" + " }\n"; + for (int c = 0; c < output_element_number; c++) { + shader.MainFunctionBody() << " b_value = b" << c << "_data"; + if (components_b_ > 1) { + shader.MainFunctionBody() << "[i]"; + } + shader.MainFunctionBody() << ";\n" + " b_value_lower = unpack4xU8(b_value & b_mask);\n" + " b_value_upper = unpack4xU8((b_value >> 4) & b_mask);\n" + << " b_quantized_values = " << quantized_data_type << "(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));\n" + << " b_dequantized_values = "; + if (a.NumComponents() == 1) { + if (has_zero_points_) { + shader.MainFunctionBody() << quantized_data_type << "((b_quantized_values[0] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[1] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[2] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[3] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[4] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[5] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[6] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[7] - zero_point" << c << ") * scale" << c << ");\n"; + } else { + shader.MainFunctionBody() << quantized_data_type << "((b_quantized_values[0] - zero_point) * scale" << c << ", " + << "(b_quantized_values[1] - zero_point) * scale" << c << "," + << "(b_quantized_values[2] - zero_point) * scale" << c << "," + << "(b_quantized_values[3] - zero_point) * scale" << c << "," + << "(b_quantized_values[4] - zero_point) * scale" << c << "," + << "(b_quantized_values[5] - zero_point) * scale" << c << "," + << "(b_quantized_values[6] - zero_point) * scale" << c << "," + << "(b_quantized_values[7] - zero_point) * scale" << c << ");\n"; + } + } else { + shader.MainFunctionBody() << "(b_quantized_values - " << quantized_data_type << "("; + for (int i = 0; i < 8; i++) { + if (has_zero_points_) { + shader.MainFunctionBody() << "zero_point" << c; + } else { + shader.MainFunctionBody() << "zero_point"; + } + if (i < 7) { + shader.MainFunctionBody() << ", "; + } + } + shader.MainFunctionBody() << ")) * scale" << c << ";\n"; + } + + shader.MainFunctionBody() << " workgroup_shared[local_id.x * " << output_number_ << " + " << c / y.NumComponents() << "]"; + if (y.NumComponents() > 1) { + shader.MainFunctionBody() << "[" << c % y.NumComponents() << "]"; + } + shader.MainFunctionBody() << " += "; + if (a.NumComponents() == 1) { + shader.MainFunctionBody() << "a_data[0] * b_dequantized_values[0] + " + "a_data[1] * b_dequantized_values[1] + " + "a_data[2] * b_dequantized_values[2] + " + "a_data[3] * b_dequantized_values[3] + " + "a_data[4] * b_dequantized_values[4] + " + "a_data[5] * b_dequantized_values[5] + " + "a_data[6] * b_dequantized_values[6] + " + "a_data[7] * b_dequantized_values[7];\n"; + } else if (a.NumComponents() == 2) { + shader.MainFunctionBody() << "dot(a_data[0], b_dequantized_values[0]) + " + "dot(a_data[1], b_dequantized_values[1]) + " + "dot(a_data[2], b_dequantized_values[2]) + " + "dot(a_data[3], b_dequantized_values[3]);\n"; + } else if (a.NumComponents() == 4) { + shader.MainFunctionBody() << "dot(a_data[0], b_dequantized_values[0]) + " + "dot(a_data[1], b_dequantized_values[1]);\n"; + } + } + + shader.MainFunctionBody() << " word_offset += " << 8 / a.NumComponents() << ";\n" + << " }\n" + " }\n" + " }\n" + " workgroupBarrier();\n" + << " if (local_id.x < " << output_number_ << ") {\n" + << " var output_value = output_value_t(0);\n" + " var workgroup_shared_offset = local_id.x;\n" + << " let blocks_num = min(" << shared_memory_size << ", n_blocks_per_col);\n" + << " for (var b = 0u; b < blocks_num; b++) {\n" + " output_value += workgroup_shared[workgroup_shared_offset];\n" + << " workgroup_shared_offset += " << output_number_ << ";\n" + << " }\n" + << " " << y.SetByIndices("output_indices_t(batch, row, col + local_id.x)", "output_value") << "\n" + << " }\n"; + } + + return Status::OK(); +} + +Status MatMulNBitsProgramPrefill::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + shader.AddInput("scales", ShaderUsage::UseUniform); + shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias); + // This shader uses uniforms with the M,N,K convention from traditional matrix multiplicatiion + // M is the number of rows in A and M rows in the output. + // N is the number of columns in B and N columns in the output. + // K is the hidden/shared dimension number of columns in A and K rows in B. + // Note in matmulnbits, B matrix is already transposed, however the following remains true + // for the shader below M describes A, N describes B and K is the hidden/shared dimension. + // K4/K8 are simply K divided by 4 or 8 respectively. + shader.AdditionalImplementation() << R"INIT_SECTION( +// Matrix dimensions and quantization parameters +const TILE_SIZE : u32 = 16u; +const VALUES_PER_VEC4 : u32 = 4u; +const QUANTIZATION_BLOCK_SIZE : u32 = 32; +// We want INNER_DIMENSION_ITEMS_PER_CYCLE to be the number of lanes in an EU/SM, +// so we use BLOCKS_PER_CYCLE as 2u, or process weights 2 blocks at a time. +// This uses all 16 lanes on 12th gen intel chips. +const BLOCKS_PER_CYCLE : u32 = 2u; +const INNER_DIMENSION_ITEMS_PER_CYCLE : u32 = 16u; // (QUANTIZATION_BLOCK_SIZE/VALUES_PER_VEC4)*BLOCKS_PER_CYCLE +const VECTORIZED_QUANTIZATION_BLOCK_SIZE: u32 = 8u; // QUANTIZATION_BLOCK_SIZE / VALUES_PER_VEC4; + +//Shared memory +var tile_A : array, TILE_SIZE>; +var tile_B : array, TILE_SIZE>; +var tile_O : array, TILE_SIZE>; + +fn loadA(slot: u32, a_global : u32, step_idx : u32, parallel_id : u32) +{ + if (a_global >= uniforms.M) { + return; + } + let local_A = input_a[a_global*uniforms.K4+step_idx*INNER_DIMENSION_ITEMS_PER_CYCLE+parallel_id]; + tile_A[slot][parallel_id] = local_A; +} + +fn getBScale(slot: u32, b_global : u32, vec_step_idx : u32, scale_idx: u32) -> output_value_t +{ + // Since scales are output_value_t holding 1 for every 32 values, vec_step_idx jumps over 64 weights at + // a time or 2 scales at every step. + let scale_offset = vec_step_idx*2; + let idx = u32(b_global*(uniforms.K/QUANTIZATION_BLOCK_SIZE)+scale_offset); + return scales[idx+scale_idx]; +} + +fn loadB(slot: u32, b_global : u32, vec_step_idx : u32, parallel_id : u32) +{ + if (b_global >= uniforms.N) { + return; + } + let scale = getBScale(slot, b_global, vec_step_idx, u32(parallel_id/VECTORIZED_QUANTIZATION_BLOCK_SIZE)); + let idx:u32 = parallel_id; + if (idx % 2 == 0) + { + // Weights are u32 holding 8 values each, each step (vec_step_idx) jumps over 64 weights at a time. + // Therefore the weight_offset begin for the current step would be vec_step_idx * 64 if weight + // elements were holding one element each. For the case of each element holding 8 values, begin + // would become vec_step_idx * 64/8 or vec_step_idx * 8. + var weight_offset:u32 = (vec_step_idx*8)+ u32(idx/2); + let b_value = input_b[b_global*uniforms.K8+weight_offset]; + let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu); + let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu); + tile_B[slot][idx].x = (output_value_t(b_value_lower[0]) - 8.0) * scale; + tile_B[slot][idx].y = (output_value_t(b_value_upper[0]) - 8.0) * scale; + tile_B[slot][idx].z = (output_value_t(b_value_lower[1]) - 8.0) * scale; + tile_B[slot][idx].w = (output_value_t(b_value_upper[1]) - 8.0) * scale; + tile_B[slot][idx+1].x = (output_value_t(b_value_lower[2]) - 8.0)* scale; + tile_B[slot][idx+1].y = (output_value_t(b_value_upper[2]) - 8.0)* scale; + tile_B[slot][idx+1].z = (output_value_t(b_value_lower[3]) - 8.0)* scale; + tile_B[slot][idx+1].w = (output_value_t(b_value_upper[3]) - 8.0)* scale; + } +} + +fn computeDotProduct(slot_a: u32, slot_b:u32) -> output_value_t +{ + var sum:output_value_t = 0; + for (var idx:u32 = 0 ; idx < INNER_DIMENSION_ITEMS_PER_CYCLE; idx++) + { + sum += dot(tile_A[slot_a][idx], tile_B[slot_b][idx]); + } + return sum; +} +)INIT_SECTION"; + + shader.MainFunctionBody() << R"MAIN_FN( + // Indexing with idx,idy instead of using a 2d dispatch of TILE_SIZE, TILE_SIZE + // appears to give a performance win on Intel Gen12LP architecture. + // This is likley because of locality of memory access, idy below in this approach + // is the same as subgroup_id or lane id, while idx is the wave_id. + // The work distribution therefore keeps memory accesses close together in + // a single wave in this approach of indexing. + let idx = u32(local_idx / TILE_SIZE); + let idy = u32(local_idx % TILE_SIZE); + let a_global_base = workgroup_id.x * TILE_SIZE; + let b_global_base = workgroup_id.y * TILE_SIZE; + let step_count:u32 = u32(uniforms.K/(BLOCKS_PER_CYCLE*QUANTIZATION_BLOCK_SIZE)); + for (var vec_step:u32 = 0; vec_step < step_count; vec_step++) + { + workgroupBarrier(); + loadA(idx, a_global_base+idx, vec_step, idy); + loadB(idx, b_global_base+idx, vec_step, idy); + workgroupBarrier(); + let result = computeDotProduct(idx, idy); + tile_O[idx][idy]+=result; + } + workgroupBarrier(); + if (a_global_base+idx < uniforms.M && b_global_base+idy < uniforms.N) { + output[(a_global_base+idx) * uniforms.N + b_global_base + idy] = tile_O[idx][idy]; + } +)MAIN_FN"; + return Status::OK(); +} + +Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const Tensor* a = context.Input(0); + const Tensor* b = context.Input(1); + const Tensor* scales = context.Input(2); + const Tensor* zero_points = context.Input(3); + const Tensor* g_idx = context.Input(4); + const Tensor* bias = context.Input(5); + + ORT_ENFORCE(g_idx == nullptr, "group_idx as input is not supported yet."); + ORT_ENFORCE(bias == nullptr, "bias as input is not supported yet."); + + MatMulComputeHelper helper; + TensorShape b_shape({N_, K_}); + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); + auto* y = context.Output(0, helper.OutputShape()); + const uint32_t data_size = gsl::narrow(y->Shape().Size()); + if (data_size == 0) { + return Status::OK(); + } + + const uint32_t batch_count = gsl::narrow(helper.OutputOffsets().size()); + const uint32_t M = gsl::narrow(helper.M()); + const uint32_t N = gsl::narrow(helper.N()); + const uint32_t K = gsl::narrow(helper.K()); + const uint32_t block_size = gsl::narrow(block_size_); + constexpr uint32_t nbits = 4; + + const uint32_t n_blocks_per_col = (K + block_size - 1) / block_size; + const uint32_t blob_size = (block_size / 8) * nbits; + const uint32_t blob_size_in_words = blob_size / 4; + const uint32_t components_a = GetMaxComponents(K); + const uint32_t components_b = GetMaxComponents(blob_size_in_words); + uint32_t components = GetMaxComponents(N); + + // Use block32 for Intel Gen12LP architecture. + const bool use_block32 = context.AdapterInfo().vendor == std::string_view{"intel"} && + context.AdapterInfo().architecture == std::string_view{"gen-12lp"} && + block_size == 32; + const bool has_zero_points = zero_points != nullptr; + + if (use_block32 && batch_count == 1 && + components_a == 4 && components_b == 4 && + !has_zero_points && M >= kMinSequenceLengthForPrefillOptimization) { + MatMulNBitsProgramPrefill program; + constexpr int32_t tile_size = 16; + // subgroup_size here controls how many elements of the hidden dimension we load in a cycle. + // MatMulNBitsProgramPrefill does not use any of the subgroup wgsl instructions. The subgroup + // size just helps with optimal lane usage in the shader. + constexpr int32_t subgroup_size = 16; + program.SetWorkgroupSize(tile_size * subgroup_size); + program.SetDispatchGroupSize((M + tile_size - 1) / tile_size, + (N + tile_size - 1) / tile_size, + 1); + program + .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(4)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(4)}, + {scales, ProgramTensorMetadataDependency::None}}) + .AddUniformVariables({{static_cast(M)}, + {static_cast(N)}, + {static_cast(K)}, + {static_cast(K / 4)}, + {static_cast(K / 8)}}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}); + return context.RunProgram(program); + } else { + // TODO: Support output_number > 1. Some cases are failed when output_number > 1. + // const uint32_t output_number = M > 1 && (N / components) % 2 == 0 ? 2 : 1; + constexpr uint32_t output_number = 1; + MatMulNBitsProgram program{output_number, gsl::narrow(components_b), has_zero_points, use_block32}; + + if (use_block32) { + components = 1; + constexpr uint32_t workgroup_size = 128; + const uint32_t workgroup_y = N % 8 == 0 ? 8 : N % 4 == 0 ? 4 + : 1; + const uint32_t workgroup_x = workgroup_size / workgroup_y; + program.SetWorkgroupSize(workgroup_x, workgroup_y, 1); + program.SetDispatchGroupSize(data_size / components / workgroup_y); + } else { + program.SetDispatchGroupSize(data_size / components / output_number); + } + + TensorShape reshaped_a_shape{batch_count, M, K / components_a}; + TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b}; + TensorShape reshaped_y_shape{batch_count, M, N / components}; + + program + .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, gsl::narrow(components_a)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, gsl::narrow(components_b * 4 /** b will be accessed as uint32 which includs 4 uint8. So here we need to multiply 4.*/)}, + {scales, ProgramTensorMetadataDependency::None}}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow(components)}) + .AddUniformVariable({block_size}) + .CacheHint(std::to_string(output_number)); + if (has_zero_points) { + program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); + } + return context.RunProgram(program); + } +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h new file mode 100644 index 0000000000000..5f785c03f6a5e --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class MatMulNBitsProgram final : public Program { + public: + MatMulNBitsProgram(uint32_t output_number, int components_b, bool has_zero_points, bool use_block32) : Program{"MatMulNBits"}, + output_number_{output_number}, + components_b_{components_b}, + has_zero_points_{has_zero_points}, + use_block32_{use_block32} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"block_size", ProgramUniformVariableDataType::Uint32}); + + private: + uint32_t output_number_; + int components_b_; + bool has_zero_points_; + bool use_block32_; +}; + +class MatMulNBitsProgramPrefill final : public Program { + public: + MatMulNBitsProgramPrefill() : Program{"MatMulNBitsPrefill"} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"M", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"K4", ProgramUniformVariableDataType::Uint32}, + {"K8", ProgramUniformVariableDataType::Uint32}); +}; + +class MatMulNBits final : public WebGpuKernel { + public: + MatMulNBits(const OpKernelInfo& info) : WebGpuKernel(info) { + K_ = info.GetAttr("K"); + N_ = info.GetAttr("N"); + block_size_ = info.GetAttr("block_size"); + int64_t bits = info.GetAttr("bits"); + ORT_ENFORCE(bits == 4, + "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); + } + + Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 8ed1372cd0e62..2e7ed5a16a2f0 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -9,6 +9,24 @@ namespace onnxruntime { namespace contrib { namespace webgpu { +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Attention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention); +// LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipSimplifiedLayerNormalization); + template <> KernelCreateInfo BuildKernelCreateInfo() { KernelCreateInfo info; @@ -18,7 +36,22 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing - }; + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo}; for (auto& function_table_entry : function_table) { KernelCreateInfo info = function_table_entry(); diff --git a/onnxruntime/core/codegen/common/common.cc b/onnxruntime/core/codegen/common/common.cc deleted file mode 100644 index 818b919e99ef2..0000000000000 --- a/onnxruntime/core/codegen/common/common.cc +++ /dev/null @@ -1,284 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/common/common.h" - -#include "core/framework/tensorprotoutils.h" -#include "core/common/inlined_containers.h" -#include "core/graph/graph.h" -#include "core/graph/schema_registry.h" -#include -#include - -namespace onnxruntime { - -NodeKey GetKey(const onnxruntime::Node* node) { - ORT_ENFORCE(nullptr != node); - ORT_ENFORCE(node->OutputDefs().size() > 0); - return node->OutputDefs()[0]->Name(); -} - -NodeKey GetKey(const onnxruntime::Node& node) { - ORT_ENFORCE(node.OutputDefs().size() > 0); - return node.OutputDefs()[0]->Name(); -} - -NodeKey GetKey(const onnxruntime::NodeArg* def) { - // NodeArg's name is unique. - ORT_ENFORCE(nullptr != def); - return def->Name(); -} - -bool IsRecurrentNode(const onnxruntime::Node& node) { - auto op_type = node.OpType(); - return (op_type == "LSTM" || op_type == "RNN" || op_type == "GRU" || - op_type == "Scan" || op_type == "Loop"); -} - -bool IsAliasNode(const onnxruntime::Node& node) { - auto op_type = node.OpType(); - if (op_type == "Transpose") { - // Treat Transpose (1,N) -> (N,1) as Alias - const auto shape = node.OutputDefs()[0]->Shape(); - if (shape != nullptr && shape->dim_size() == 2) { - for (int i = 0; i < 2; ++i) { - if (shape->dim(i).has_dim_value() && shape->dim(i).dim_value() == 1) { - return true; - } - } - } - return false; - } - - return (op_type == "Flatten" || op_type == "Identity" || op_type == "Reshape" || - op_type == "Squeeze" || op_type == "Unsqueeze"); -} - -std::string NormalizeCppName(const std::string& name) { - std::string normalized_name = name; - for (char c : {'.', ' ', '+', '-', '*', '/', '\\', '='}) - std::replace(normalized_name.begin(), normalized_name.end(), c, '_'); - return normalized_name; -} - -std::string NormalizeNodeArgName(const NodeArg* def) { - return NormalizeCppName(def->Name()); -} - -bool IsFusedNode(const Node& node) { - if (node.NodeType() == Node::Type::Fused) { - return true; - } - return false; -} - -// A unified API to get Subgraph -const Graph* GetSubgraph(const Node& node) { - if (node.NodeType() == Node::Type::Fused) { - return &(node.GetFunctionBody()->Body()); - } else if (node.OpType() == "Scan") { - return node.GetGraphAttribute("body"); - } - // return nullptr implying no subgraph - return nullptr; -} - -bool HasLoop(const Node& node) { - auto op_type = node.OpType(); - if (op_type == "LSTM" || - op_type == "GRU" || - op_type == "RNN" || - op_type == "Scan") { - return true; - } - return false; -} - -// Return the corresponding input node for the NodeArg of the given node -const onnxruntime::Node* GetInputNode(const Node& node, const NodeArg* def) { - const auto& input_name = def->Name(); - const onnxruntime::Node* input_node = nullptr; - // search input node set to see if input_name is in their outputs (weights are not from node) - for (auto iter = node.InputNodesBegin(); iter != node.InputNodesEnd(); ++iter) { - const onnxruntime::Node& p = *iter; - bool found = false; - ORT_THROW_IF_ERROR(p.ForEachWithIndex( - p.OutputDefs(), - [&found, &input_name](const onnxruntime::NodeArg& out_def, size_t) { - if (input_name == out_def.Name()) { - found = true; - } - return Status::OK(); - })); - if (found) - input_node = &p; - } - return input_node; -} - -// create capacity from subgraph -std::unique_ptr ToCapacity(const onnxruntime::GraphViewer& graph, - int fused_count, - std::unique_ptr& subgraph) { - auto meta_def = std::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>(); - meta_def->name = "Fuse" + std::to_string(fused_count); - meta_def->domain = "Fuse"; - - std::set node_indices(subgraph->nodes.begin(), subgraph->nodes.end()); - - const auto& start_node_index = subgraph->nodes.front(); - const auto& start_node = *graph.GetNode(start_node_index); - const auto& end_node_index = subgraph->nodes.back(); - const auto& end_node = *graph.GetNode(end_node_index); - meta_def->name += start_node.OpType() + std::to_string(start_node_index); - meta_def->name += "_With" + std::to_string(subgraph->nodes.size()) + "Nodes_"; - meta_def->name += end_node.OpType() + std::to_string(end_node_index); - - InlinedHashSet real_output_names; - real_output_names.reserve(graph.GetOutputs().size()); - for (const auto* def : graph.GetOutputs()) { - real_output_names.insert(def->Name()); - } - - for (const auto& node_index : subgraph->nodes) { - const auto& node = *graph.GetNode(node_index); - auto process_input_fn = - [&meta_def, &node, &node_indices](const onnxruntime::NodeArg& def, size_t) { - const onnxruntime::Node* input_node = GetInputNode(node, &def); - bool input_from_subgraph = (input_node && node_indices.count(input_node->Index())); - if (!input_from_subgraph) { - // input is from weights or outside of graph - meta_def->inputs.push_back(def.Name()); - } - return Status::OK(); - }; - // handle current graph's inputs - ORT_THROW_IF_ERROR(node.ForEachWithIndex(node.InputDefs(), process_input_fn)); - // nodes' implicit inputs also need to be collected. They need to - // be promoted to being explicit inputs for everything to work. - ORT_THROW_IF_ERROR(node.ForEachWithIndex(node.ImplicitInputDefs(), process_input_fn)); - - // Handle outouts - // two cases are considered as outputs - // 1. Output NodeArg is not used by any Node - // 2. Output NodeArg is used by at least one Node out of this subgraph. - // Note a NodeArg can be used by Nodes in and out of the subgraph at the same time. - // 3. Output NodeArg is one of real outputs of an Ort graph. - - auto InsertOutputToSubgraph = [&meta_def](const NodeArg* def) { - if (std::find(meta_def->outputs.begin(), meta_def->outputs.end(), def->Name()) == - meta_def->outputs.end()) { - meta_def->outputs.push_back(def->Name()); - } - }; - - InlinedHashSet input_names_from_the_output_node; - - for (auto o_iter = node.OutputEdgesBegin(); o_iter != node.OutputEdgesEnd(); ++o_iter) { - const auto& p = *o_iter; - const Node& out_node = p.GetNode(); - - // preprocess for the case 1 - ORT_THROW_IF_ERROR(out_node.ForEachWithIndex( - out_node.InputDefs(), - [&input_names_from_the_output_node](const onnxruntime::NodeArg& in_def, size_t) { - input_names_from_the_output_node.insert(in_def.Name()); - return Status::OK(); - })); - - // handle the case 2 - if (node_indices.count(out_node.Index()) == 0) { - const NodeArg* def = node.OutputDefs()[p.GetSrcArgIndex()]; - InsertOutputToSubgraph(def); - } - } - - // handle case 1 and 3 - ORT_THROW_IF_ERROR(node.ForEachWithIndex( - node.OutputDefs(), - [&](const onnxruntime::NodeArg& def, size_t) { - if (input_names_from_the_output_node.count(def.Name()) == 0 || - real_output_names.count(def.Name()) > 0) { - InsertOutputToSubgraph(&def); - } - return Status::OK(); - })); - } - - // Handle subgraph's initializers - const auto& all_initializers = graph.GetAllInitializedTensors(); - for (const auto& node_index : subgraph->nodes) { - const auto& node = *graph.GetNode(node_index); - // check whether it is an immediate nested subgraph - auto immediate_nested_subgraph = GetSubgraph(node); - // If so, copy the immediate nested subgraph's initializers to meta_def->inputs. - // Note we don't need recursion here, since Ort did recursion for us by handling subgraph early than the current graph. - // Therefore, the all inner nested subgraph's initializers should be already in the immediate nested subgraph's inputs. - if (nullptr != immediate_nested_subgraph) { - for (auto& n : immediate_nested_subgraph->Nodes()) { - auto add_input_fn = - [&meta_def, &all_initializers](const onnxruntime::NodeArg& def, size_t) { - auto iter = all_initializers.find(def.Name()); - if (iter != all_initializers.end()) { - meta_def->inputs.push_back(def.Name()); - } - return Status::OK(); - }; - ORT_THROW_IF_ERROR(n.ForEachWithIndex(n.InputDefs(), add_input_fn)); - ORT_THROW_IF_ERROR(n.ForEachWithIndex(n.ImplicitInputDefs(), add_input_fn)); - } - } - } - - meta_def->since_version = 1; - meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; - std::unique_ptr finished_subgraph(subgraph.release()); - finished_subgraph->SetMetaDef(std::move(meta_def)); - return std::make_unique(std::move(finished_subgraph)); -} - -int64_t ShapeRank(const NodeArg* def) { - ORT_ENFORCE_DEBUG(nullptr != def); - return gsl::narrow_cast(def->Shape()->dim_size()); -} - -bool ShapeHasValue(const NodeArg* def, int i) { - ORT_ENFORCE_DEBUG(nullptr != def); - ORT_ENFORCE_DEBUG(i >= 0); - ORT_ENFORCE_DEBUG(i < def->Shape()->dim_size()); - return utils::HasDimValue(def->Shape()->dim(i)); -} - -bool ShapeHasSymbol(const NodeArg* def, int i) { - ORT_ENFORCE_DEBUG(nullptr != def); - ORT_ENFORCE_DEBUG(i >= 0); - ORT_ENFORCE_DEBUG(i < def->Shape()->dim_size()); - return utils::HasDimParam(def->Shape()->dim(i)); -} - -int64_t ShapeValue(const NodeArg* def, int i) { - ORT_ENFORCE_DEBUG(ShapeHasValue(def, i)); - return def->Shape()->dim(i).dim_value(); -} - -const std::string& ShapeSymbol(const NodeArg* def, int i) { - ORT_ENFORCE_DEBUG(ShapeHasSymbol(def, i)); - return def->Shape()->dim(i).dim_param(); -} - -ONNX_NAMESPACE::TensorProto_DataType TensorProtoDataType(const NodeArg* def) { - ORT_ENFORCE_DEBUG(nullptr != def); - return static_cast(def->TypeAsProto()->tensor_type().elem_type()); -} - -// Convert GraphNodes to internal NodePtrs without check lifetime. -// Please use it only locally when GraphNodes still exist -InlinedVector ConvertGraphNodesToNodePtrs(const ConstGraphNodes& graph_nodes) { - InlinedVector nodes; - for (auto& node : graph_nodes) { - nodes.push_back(&node); - } - return nodes; -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/common/common.h b/onnxruntime/core/codegen/common/common.h deleted file mode 100644 index 81b74daf6f711..0000000000000 --- a/onnxruntime/core/codegen/common/common.h +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/common/inlined_containers.h" -#include "core/framework/compute_capability.h" -#include "core/framework/tensor.h" -#include "core/graph/graph_nodes.h" -#include "core/graph/graph_viewer.h" - -#ifndef NDEBUG -#define ORT_ENFORCE_DEBUG(...) ORT_ENFORCE(__VA_ARGS__) -#else -#define ORT_ENFORCE_DEBUG(...) -#endif // !NDEBUG - -// DYN_PROMOTE is a simplified llvm::dyn_cast, which does not need RTTI -// DYN_PROMOTE is faster than dynamic_cast and also has smaller binary size -// Please use DYN_PROMOTE in a critical path. -#define DYN_PROMOTE(BASE) \ - template \ - inline const ToType* Promote(const BASE* base) { \ - if (ToType::IsType(base)) \ - return static_cast(base); \ - return nullptr; \ - } \ - \ - template \ - inline ToType* Promote(BASE* base) { \ - if (ToType::IsType(base)) \ - return static_cast(base); \ - return nullptr; \ - } \ - \ - template \ - inline ToType* Promote(const std::unique_ptr& base) { \ - if (ToType::IsType(base.get())) \ - return static_cast(base); \ - return nullptr; \ - } \ - \ - template \ - inline ToType* Promote(const std::shared_ptr& base) { \ - if (ToType::IsType(base.get())) \ - return static_cast(base); \ - return nullptr; \ - } - -// DYN_PROMOTE_BASE is a macro inserted in the base class to support DYN_PROMOTE -// TYPE_ID is required for DYN_PROMOTE and TYPE_ID is a enum class -// TYPE_ID_VAR is a corresponding variable name for in the base class -#define DYN_PROMOTE_BASE(BASE, TYPE_ID, TYPE_ID_VAR) \ - inline const TYPE_ID TypeID() const { \ - return TYPE_ID_VAR; \ - } \ - \ - static inline bool IsType(const BASE*) { \ - return true; \ - } - -// DYN_PROMOTE_DERIVED is a macro inserted in a derived class to support DYN_PROMOTE -// TYPE_ID is required for DYN_PROMOTE and TYPE_ID is a enum class -// TYPE_ID_VALUE is corresponding TYPE_ID::value of a derived class. -#define DYN_PROMOTE_DERIVED(BASE, TYPE_ID, TYPE_ID_VALUE) \ - static inline bool IsType(const BASE* base) { \ - ORT_ENFORCE_DEBUG(nullptr != base); \ - return base->TypeID() == TYPE_ID::TYPE_ID_VALUE; \ - } - -// DYNAMIC_PROMOTE is a dynamic_cast needing RTTI -// DYNAMIC_PROMOTE is usually slower than than DYN_PROMOTE. -// Please use DYNAMIC_PROMOTE in a non-critical path. -#define DYNAMIC_PROMOTE(BASE) \ - template \ - inline const X* Promote(const BASE* base) { \ - auto derived = dynamic_cast(base); \ - ORT_ENFORCE(nullptr != derived); \ - return derived; \ - } \ - \ - template \ - inline X* Promote(BASE* base) { \ - auto derived = dynamic_cast(base); \ - ORT_ENFORCE(nullptr != derived); \ - return derived; \ - } \ - \ - template \ - inline X* Promote(const std::unique_ptr& base) { \ - auto derived = dynamic_cast(base.get()); \ - ORT_ENFORCE(nullptr != derived); \ - return derived; \ - } \ - \ - template \ - inline X* Promote(const std::shared_ptr& base) { \ - auto derived = dynamic_cast(base.get()); \ - ORT_ENFORCE(nullptr != derived); \ - return derived; \ - } - -namespace onnxruntime { - -// Nodekey is used as a key for maps -using NodeKey = std::string; - -NodeKey GetKey(const onnxruntime::Node* node); -NodeKey GetKey(const onnxruntime::Node& node); -NodeKey GetKey(const onnxruntime::NodeArg* def); - -bool IsRecurrentNode(const onnxruntime::Node& node); - -bool IsAliasNode(const onnxruntime::Node& node); - -// Helper function that creates ComputeCapability for subgraphs -std::unique_ptr ToCapacity(const onnxruntime::GraphViewer& graph, - int fused_count, - std::unique_ptr& subgraph); - -bool IsFusedNode(const Node& node); - -bool HasLoop(const Node& node); - -const Graph* GetSubgraph(const Node& node); - -std::string NormalizeCppName(const std::string& name); - -std::string NormalizeNodeArgName(const NodeArg* def); - -// Return the corresponding input node for the NodeArg of the given node -const onnxruntime::Node* GetInputNode(const Node& node, const NodeArg* def); - -int64_t ShapeRank(const NodeArg* def); - -bool ShapeHasValue(const NodeArg* def, int i); - -bool ShapeHasSymbol(const NodeArg* def, int i); - -int64_t ShapeValue(const NodeArg* def, int i); - -const std::string& ShapeSymbol(const NodeArg* def, int i); - -ONNX_NAMESPACE::TensorProto_DataType TensorProtoDataType(const NodeArg* def); - -// Convert ConstGraphNodes to internal NodePtrs without check lifetime. -// Please use it only locally when GraphNodes still exist -InlinedVector ConvertGraphNodesToNodePtrs(const ConstGraphNodes& graph_nodes); - -enum : int { - Dimension_Unknown = -1, -}; - -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/common/creator.h b/onnxruntime/core/codegen/common/creator.h deleted file mode 100644 index b31a12db4875b..0000000000000 --- a/onnxruntime/core/codegen/common/creator.h +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/codegen/common/dispatcher.h" - -// TODO rename this file to creator_base -namespace onnxruntime { -namespace codegen { - -// It is a base class for TVM Op IR builder, weight layout builder, TVM scheduler -// CreatorBase is a template class of compiler pass -// for 1) TVM IR builder -// 2) Weight layout transformer -// 3) TVM Scheduler, etc. -// CreatorBase is similor to OpXXCreate in llvm IR builder - -template -class CreatorBase { - public: - CreatorBase(const std::string& name) - : name_(name) {} - - virtual ~CreatorBase() = default; - - virtual RETURN_TYPE Evaluate(INPUT_TYPE, - NODE_TYPE, - CONTEXT_TYPE, - OUTPUT_TYPE) = 0; - - const std::string& Name() const { - return name_; - } - - protected: - std::string name_; - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CreatorBase); -}; - -// macro to stringize -#define STRINGIZE_NX(OP) #OP -#define STRINGIZE(OP) STRINGIZE_NX(OP) - -// macro returns class name -#define CREATOR_CLASS(OP, POSTFIX) \ - OP##POSTFIX - -// macro returns class name as string -#define CREATOR_STRING(OP, POSTFIX) \ - STRINGIZE(CREATOR_CLASS(OP, POSTFIX)) - -// macro returns class constructor name -#define CREATOR_CLASS_FUNC(OP, POSTFIX) \ - OP##POSTFIX() - -// macro declares a creator class inheriting the template class CreatorBase -// with corresponding template parameters -#define DECLARE_CREATOR_CLASS(OP, POSTFIX, INPUT, NODE, CONTEXT, OUTPUT, RETURN) \ - class CREATOR_CLASS(OP, POSTFIX) : public onnxruntime::codegen::CreatorBase { \ - public: \ - CREATOR_CLASS_FUNC(OP, POSTFIX) : CreatorBase(CREATOR_STRING(OP, POSTFIX)) {} \ - RETURN Evaluate(INPUT, \ - NODE, \ - CONTEXT, \ - OUTPUT) override; \ - \ - private: \ - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CREATOR_CLASS(OP, POSTFIX)); \ - }; - -} // namespace codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/common/dispatcher.h b/onnxruntime/core/codegen/common/dispatcher.h deleted file mode 100644 index 80a854a06977c..0000000000000 --- a/onnxruntime/core/codegen/common/dispatcher.h +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/common/common.h" -#include -#include -#include - -namespace onnxruntime { -namespace codegen { - -// DispatcherBase is a customized unordered_map -// that provides all codegen-related functionality -// including 1) dispatching a pass -// 2) dump corresponding name -// DispatcherBase may or may not keep ownership, -// depending on the template parameter, CONTENT_TYPE. -// Note DispatcherBase has a protected destructor - -template -class DispatcherBase { - public: - DispatcherBase(const std::string& name) - : name_(name) {} - - const std::string& Name() const { - return name_; - } - - bool Contains(const std::string& name) const { - return contents_.count(name) > 0; - } - - void ForEach(std::function - func) { - for (auto& p : contents_) { - func(p.first, p.second); - } - } - - bool Register(const std::string& name, - CONTENT_TYPE op) { - if (!Contains(name)) { - contents_.emplace(name, op); - return true; - } - return false; - } - - CONTENT_TYPE Get(const std::string& key) const { - auto iter = contents_.find(key); - if (iter != contents_.end()) { - return iter->second; - } - return nullptr; - } - - const std::unordered_map GetContents() const { - return contents_; - } - - std::unordered_map GetMutableContents() { - return contents_; - } - - protected: - std::string name_; - std::unordered_map contents_; - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DispatcherBase); - ~DispatcherBase() = default; -}; - -} // namespace codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/common/dump_array.h b/onnxruntime/core/codegen/common/dump_array.h deleted file mode 100644 index 8e51cd36d0087..0000000000000 --- a/onnxruntime/core/codegen/common/dump_array.h +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include -#include -#include -#include - -namespace onnxruntime { - -template -void DumpArrayRecursive(const T1* data, int64_t& data_offset, const std::vector& shape, int idx) { - int dim = static_cast(shape.size()); - if (dim == 0) { - std::cout << "[]\n"; - return; - } - - assert(idx < dim); - int sz = shape[idx]; - - std::cout << "["; - if (idx < dim - 1) { - for (auto i = 0; i < sz; ++i) { - DumpArrayRecursive(data, data_offset, shape, idx + 1); - if (i < sz - 1) { - std::cout << ","; - // print multiple newlines after ',' when necessary - for (int j = idx + 1; j < dim; j++) - std::cout << "\n"; - // print leading spaces before "[" when necessary - for (int j = 0; j < idx + 1; ++j) - std::cout << " "; - } - } - } else { - for (auto i = 0; i < sz; ++i) { - if (std::is_same::value || std::is_same::value) - std::cout << std::setw(3) << static_cast(*(data + data_offset)); - else - std::cout << std::setw(12) << std::setprecision(8) << *(data + data_offset); - data_offset++; - if (i < sz - 1) - std::cout << ","; - } - } - std::cout << "]"; -} - -// A helper function to dump multidimensional arrays in a way similar to numpy -template -void DumpArray(const std::string& tag, const T1* data, const std::vector& shape) { - std::cout << tag << "\n"; - int64_t data_offset = 0; - DumpArrayRecursive(data, data_offset, shape, 0); - assert(data_offset == TotalSize(shape)); - std::cout << std::endl; -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/common/handle.h b/onnxruntime/core/codegen/common/handle.h deleted file mode 100644 index 7caad27dcbe01..0000000000000 --- a/onnxruntime/core/codegen/common/handle.h +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/codegen/common/target_info.h" -#include -#include - -namespace onnxruntime { -namespace codegen { - -using DomainVersionLookupFunc = std::function; - -struct CodeGenHandle { - CodeGenTarget* codegen_target; - DomainVersionLookupFunc domain_version_lookup_func = - // by default, always uses the latest opset implemented - [](const std::string&) { return INT_MAX; }; -}; - -} // namespace codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/common/op_macro.h b/onnxruntime/core/codegen/common/op_macro.h deleted file mode 100644 index 04305c4aa47b0..0000000000000 --- a/onnxruntime/core/codegen/common/op_macro.h +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -namespace onnxruntime { - -#define LIST_BINARY_OPS() \ - BINARY_OP(Add) \ - BINARY_OP(Div) \ - BINARY_OP(Mul) \ - BINARY_OP(PRelu) \ - BINARY_OP(Sub) - -#define LIST_BINARY_CMP_OPS() \ - BINARY_CMP_OP(Equal) \ - BINARY_CMP_OP(Greater) \ - BINARY_CMP_OP(Less) - -#define LIST_POOL_OPS() \ - POOL_OP(MaxPool) \ - POOL_OP(AveragePool) \ - POOL_OP(GlobalMaxPool) \ - POOL_OP(GlobalAveragePool) - -#define LIST_REDUCE_OPS() \ - REDUCE_INDEXED_OP(ArgMax) \ - REDUCE_INDEXED_OP(ArgMin) \ - REDUCE_OP(ReduceL1) \ - REDUCE_OP(ReduceL2) \ - REDUCE_OP(ReduceLogSum) \ - REDUCE_OP(ReduceLogSumExp) \ - REDUCE_OP(ReduceMax) \ - REDUCE_OP(ReduceMean) \ - REDUCE_OP(ReduceMin) \ - REDUCE_OP(ReduceProd) \ - REDUCE_OP(ReduceSum) \ - REDUCE_OP(ReduceSumSquare) - -#define LIST_UNARY_OPS() \ - UNARY_OP(Abs) \ - UNARY_OP(Affine) \ - UNARY_OP(Ceil) \ - UNARY_OP(Elu) \ - UNARY_OP(Exp) \ - UNARY_OP(Floor) \ - UNARY_OP(HardSigmoid) \ - UNARY_OP(LeakyRelu) \ - UNARY_OP(Log) \ - UNARY_OP(Neg) \ - UNARY_OP(ParametricSoftplus) \ - UNARY_OP(Reciprocal) \ - UNARY_OP(Relu) \ - UNARY_OP(ScaledTanh) \ - UNARY_OP(Selu) \ - UNARY_OP(Sigmoid) \ - UNARY_OP(Softplus) \ - UNARY_OP(Softsign) \ - UNARY_OP(Sqrt) \ - UNARY_OP(Tanh) \ - UNARY_OP(ThresholdedRelu) - -#define LIST_VARIADIC_OPS() \ - VARIADIC_OP(Max) \ - VARIADIC_OP(Min) \ - VARIADIC_OP(Sum) - -#define LIST_ALL_GENERIC_OPS() \ - LIST_BINARY_OPS() \ - LIST_BINARY_CMP_OPS() \ - LIST_REDUCE_OPS() \ - LIST_POOL_OPS() \ - LIST_UNARY_OPS() \ - LIST_VARIADIC_OPS() \ - ADD_OP_ITEM(Cast) \ - ADD_OP_ITEM(Clip) \ - ADD_OP_ITEM(Concat) \ - ADD_OP_ITEM(Conv) \ - ADD_OP_ITEM(Crop) \ - ADD_OP_ITEM(Dropout) \ - ADD_OP_ITEM(Expand) \ - ADD_OP_ITEM(Flatten) \ - ADD_OP_ITEM(Gather) \ - ADD_OP_ITEM(GatherElements) \ - ADD_OP_ITEM(Gemm) \ - ADD_OP_ITEM(Identity) \ - ADD_OP_ITEM(LogSoftmax) \ - ADD_OP_ITEM(LSTM) \ - ADD_OP_ITEM(MatMul) \ - ADD_OP_ITEM(MatMulInteger) \ - ADD_OP_ITEM(Pad) \ - ADD_OP_ITEM(Reshape) \ - ADD_OP_ITEM(Shape) \ - ADD_OP_ITEM(Slice) \ - ADD_OP_ITEM(Softmax) \ - ADD_OP_ITEM(Split) \ - ADD_OP_ITEM(Squeeze) \ - ADD_OP_ITEM(Transpose) \ - ADD_OP_ITEM(Unsqueeze) \ - ADD_OP_ITEM(Where) - -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/common/profile.h b/onnxruntime/core/codegen/common/profile.h deleted file mode 100644 index 31c9e764320d0..0000000000000 --- a/onnxruntime/core/codegen/common/profile.h +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -// uncomment this line or use -DCODEGEN_ENABLE_PROFILER in compiler options to enable profiler events in codegen -// #define CODEGEN_ENABLE_PROFILER - -#ifdef CODEGEN_ENABLE_PROFILER -#include "core/common/profiler.h" - -namespace onnxruntime { - -class ProfilerEvent { - public: - ProfilerEvent(const std::string& name) : name_(name) { - ts_ = profiling::Profiler::Instance().StartTime(); - } - - ~ProfilerEvent() { - profiling::Profiler::Instance().EndTimeAndRecordEvent(profiling::EventCategory::NODE_EVENT, name_, ts_); - } - - private: - TimePoint ts_; - const std::string name_; -}; - -} // namespace onnxruntime - -#define CODEGEN_PROFILER_EVENT(name) onnxruntime::ProfilerEvent profiler_event(name) - -#else - -#define CODEGEN_PROFILER_EVENT(name) - -#endif diff --git a/onnxruntime/core/codegen/common/registry.h b/onnxruntime/core/codegen/common/registry.h deleted file mode 100644 index c1642e76e2120..0000000000000 --- a/onnxruntime/core/codegen/common/registry.h +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/common/common.h" -#include -#include -#include - -namespace onnxruntime { -namespace codegen { - -// RegistryBase is a customized unordered_map -// that keep ownership of passes, -// including 1) IR builder passes -// 2) Weight layout transformer passes -// 3) Scheduler passses, etc. - -template -class RegistryBase { - public: - RegistryBase() = default; - - virtual ~RegistryBase() = default; - - bool Contains(const std::string& name) const { - return contents_.count(name) > 0; - } - - CONTENT_TYPE* Get(const std::string& name) const { - if (contents_.find(name) != contents_.end()) - return contents_.at(name).get(); - return nullptr; - } - - CONTENT_TYPE* RegisterOrGet( - const std::string& name, - std::unique_ptr&& ptr) { - if (!Contains(name)) - contents_.emplace(name, std::move(ptr)); - return Get(name); - } - - CONTENT_TYPE* RegisterOrGet( - std::unique_ptr&& ptr) { - return RegisterOrGet(ptr->Name(), std::move(ptr)); - } - - bool Register( - const std::string& name, - std::unique_ptr&& ptr) { - if (!Contains(name)) { - contents_.emplace(name, std::move(ptr)); - return true; - } - return false; - } - - bool Register( - std::unique_ptr&& ptr) { - return Register(ptr->Name(), std::move(ptr)); - } - - protected: - std::unordered_map> contents_; - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(RegistryBase); -}; - -// Put common Registry Management utilities if these is any - -} // namespace codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/common/settings.cc b/onnxruntime/core/codegen/common/settings.cc deleted file mode 100644 index 529cb654f922c..0000000000000 --- a/onnxruntime/core/codegen/common/settings.cc +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/common/settings.h" - -#include "core/common/logging/logging.h" -#include -#include - -namespace onnxruntime { -namespace codegen { - -CodeGenSettings& CodeGenSettings::Instance() { - static CodeGenSettings settings; - return settings; -} - -CodeGenSettings::CodeGenSettings() {} - -void CodeGenSettings::InsertOptions(const std::map& options) { - for (const auto& option : options) { - const auto& key = option.first; - const auto& value = option.second; - - auto iter = options_.find(key); - // found existing ones - if (iter != options_.end()) { - if (iter->second != value) { - LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "CodeGenSettings: option" - << key << " is overridded from: " - << iter->second << " to: " << value; - iter->second = value; - } - } else { - options_.insert(std::make_pair(key, value)); - } - } -} - -void CodeGenSettings::DumpOptions() const { - std::ostringstream stream; - stream << "CodeGenSettings: dump all options" << std::endl; - for (const auto& option : options_) { - stream << " " << option.first << " = " << option.second << std::endl; - } - LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << stream.str(); -} - -std::string CodeGenSettings::GetOptionValue(const std::string& key) const { - const auto& iter = options_.find(key); - if (iter == options_.end()) { - LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << "CodeGenSettings::GetOptionValue: unrecognized option" << key; - return ""; - } - return iter->second; -} - -bool CodeGenSettings::HasOption(const std::string& key) const { - return options_.count(key) > 0; -} - -bool CodeGenSettings::OptionMatches(const std::string& key, const std::string& value) const { - if (!HasOption(key)) - return false; - -#ifdef _WIN32 - return 0 == _stricmp(options_.at(key).c_str(), value.c_str()); -#else - return 0 == strcasecmp(options_.at(key).c_str(), value.c_str()); -#endif -} - -void CodeGenSettings::Clear() { - options_.clear(); -} - -} // namespace codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/common/settings.h b/onnxruntime/core/codegen/common/settings.h deleted file mode 100644 index e327b0e207cc2..0000000000000 --- a/onnxruntime/core/codegen/common/settings.h +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace codegen { - -// use log level warning as default to make sure logs are outputted -#define CODEGEN_SETTINGS_LOG_LEVEL WARNING - -// This stores codegen settings to control dumps, execution preference, etc. -// CodeGenSettings could come from command line options or environment variables -// Or could come from a static variables in source code -class CodeGenSettings { - public: - // generic built-in options - constexpr static const char* kDumpAllOptions = "dump_all_options"; - constexpr static const char* kCodeGenDumpModule = "codegen_dump_module"; // dump tvm module - constexpr static const char* kCodeGenDumpLower = "codegen_dump_lower"; // dump lowered func - constexpr static const char* kCodeGenDumpSchedule = "codegen_dump_schedule"; // dump scheduler - - void InsertOptions(const std::map& options); - void DumpOptions() const; - std::string GetOptionValue(const std::string& key) const; - bool HasOption(const std::string& key) const; - bool OptionMatches(const std::string& key, const std::string& value) const; - void Clear(); - static CodeGenSettings& Instance(); - - private: - CodeGenSettings(); - - std::map options_; -}; - -} // namespace codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/common/target_info.h b/onnxruntime/core/codegen/common/target_info.h deleted file mode 100644 index da063545f0a1e..0000000000000 --- a/onnxruntime/core/codegen/common/target_info.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { - -// CodeGenTarget holds meta info for backend code generation -// and will be lowered to a target of corresponding backend -// code generation, e.g. TVM's Target. -class CodeGenTarget { - public: - CodeGenTarget() {} - CodeGenTarget(const std::string& target_name) - : target_name_(target_name) {} - - virtual int NaturalVectorWidth(int /*bits*/) const { - return 1; - } - - const std::string& GetTargetName() const { - return target_name_; - } - - virtual ~CodeGenTarget() = default; - - private: - std::string target_name_{"unknown"}; // default name is unknown -}; - -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/common/utils.cc b/onnxruntime/core/codegen/common/utils.cc deleted file mode 100644 index f4140a411bddf..0000000000000 --- a/onnxruntime/core/codegen/common/utils.cc +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/common/utils.h" -#include "core/common/cpuid_info.h" -#include "core/common/safeint.h" - -#include -#include - -namespace onnxruntime { - -std::unique_ptr GetEnv(const char* var) { - char* val = nullptr; -#if _MSC_VER - size_t len; - - if (_dupenv_s(&val, &len, var)) { - // Something went wrong, just return nullptr. - return nullptr; - } -#else - val = getenv(var); -#endif // _MSC_VER - - if (val == nullptr) { - return nullptr; - } - - // On windows, we will have to explicitly free val. Instead of returning val - // to its caller and make distinguish between windows and linux, we return - // a unique_ptr, and it will be destroyed automatically after the caller - // completes. - size_t len_val = strnlen(val, onnxruntime::kMaxStrLen) + 1; - auto p = std::make_unique(len_val); - // use explicit loop to get ride of VC's warning on unsafe copy - for (size_t i = 0; i < len_val; ++i) { - p[i] = val[i]; - } - return p; -} - -bool IsEnvVarDefined(const char* var) { - auto val = GetEnv(var); - return val != nullptr; -} - -int64_t TotalSize(const std::vector& shape) { - SafeInt total = 1; - for (auto s : shape) { - total *= s; - } - return total; -} - -// Return the strides for the input shape, i.e. the number of -// elements contained by a single element of current dimension. -// For example, for shape[3][4][5][6], strides will be -// [4*5*6, 5*6, 6, 1], i.e. [120, 30, 6, 1] -void GetStrides(const int64_t* shape, int ndim, std::vector& strides) { - strides.resize(ndim); - strides[ndim - 1] = 1; - for (int64_t i = ndim - 2; i >= 0; i--) { - strides[i] = strides[i + 1] * shape[i + 1]; - } -} - -// Common utils to get target option -TargetFeature GetTargetInfo(const codegen::CodeGenSettings& settings) { - TargetFeature feature; - - std::string target_str = ""; - - bool isAVX = false; - bool isAVX2 = false; - bool isAVX512 = false; - if (target_str == "avx") { - isAVX = true; - } else if (target_str == "avx2") { - isAVX = true; - isAVX2 = true; - } else if (target_str == "avx512") { - isAVX = true; - isAVX2 = true; - isAVX512 = true; - } else { - isAVX = CPUIDInfo::GetCPUIDInfo().HasAVX(); - isAVX2 = CPUIDInfo::GetCPUIDInfo().HasAVX2(); - isAVX512 = CPUIDInfo::GetCPUIDInfo().HasAVX512Skylake(); - } - - feature.hasAVX = isAVX; - feature.hasAVX2 = isAVX2; - feature.hasAVX512 = isAVX512; - - return feature; -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/common/utils.h b/onnxruntime/core/codegen/common/utils.h deleted file mode 100644 index ef06b5b72dc2c..0000000000000 --- a/onnxruntime/core/codegen/common/utils.h +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/common/common.h" -#include -#include -#include - -namespace onnxruntime { - -// Holding utility functions that are not tied to TVM and ORT - -std::unique_ptr GetEnv(const char* var); - -// Check if an environment variable is set -bool IsEnvVarDefined(const char* var); - -int64_t TotalSize(const std::vector& shape); - -void GetStrides(const int64_t* shape, int ndim, std::vector& strides); - -struct TargetFeature { - bool hasAVX; - bool hasAVX2; - bool hasAVX512; -}; - -TargetFeature GetTargetInfo(const codegen::CodeGenSettings& setttings); - -// GCD (Greatest Common Divisor) -template -T GCD(T a, T b) { - ORT_ENFORCE(a >= 0); - ORT_ENFORCE(b >= 0); - if (a < b) std::swap(a, b); - if (b == 0) return a; - while (a % b != 0) { - a = a % b; - std::swap(a, b); - } - return b; -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/common.h b/onnxruntime/core/codegen/mti/common.h deleted file mode 100644 index d71e740b9284a..0000000000000 --- a/onnxruntime/core/codegen/mti/common.h +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include - -#define MTI_ASSERT(condition) \ - if (!(condition)) { \ - std::string error_msg = "Not satisfied: " #condition \ - ": line " + \ - std::to_string(__LINE__) + \ - " in file " + std::string(__FILE__) + "\n"; \ - throw std::runtime_error(error_msg); \ - } diff --git a/onnxruntime/core/codegen/mti/debug/tvm_print.cc b/onnxruntime/core/codegen/mti/debug/tvm_print.cc deleted file mode 100644 index 0491636032b47..0000000000000 --- a/onnxruntime/core/codegen/mti/debug/tvm_print.cc +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/debug/tvm_print.h" - -#include "core/codegen/common/utils.h" -#include "core/codegen/common/dump_array.h" -#include "core/codegen/mti/common.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -TVM_REGISTER_GLOBAL("tvm.contrib.onnxruntime.print") - .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* /*ret*/) { - DLTensor* X = args[0]; - DLTensor* Y = args[1]; - - DLDataType dtype = X->dtype; - std::vector shape; - int64_t total_size = 1; - for (int i = 0; i < X->ndim; ++i) { - shape.push_back(X->shape[i]); - total_size *= X->shape[i]; - } - - // pass X to Y - memcpy(static_cast(Y->data) + Y->byte_offset, - static_cast(X->data) + X->byte_offset, - total_size * dtype.bits / 8); - - if (tvm::runtime::TypeMatch(dtype, kDLFloat, 32)) { - float* data = reinterpret_cast(static_cast(X->data) + X->byte_offset); - DumpArray("float tensor:", data, shape); - } else if (tvm::runtime::TypeMatch(dtype, kDLInt, 8)) { - int8_t* data = reinterpret_cast(static_cast(X->data) + X->byte_offset); - DumpArray("int8 tensor:", data, shape); - } else if (tvm::runtime::TypeMatch(dtype, kDLInt, 16)) { - int16_t* data = reinterpret_cast(static_cast(X->data) + X->byte_offset); - DumpArray("int16 tensor:", data, shape); - } else if (tvm::runtime::TypeMatch(dtype, kDLInt, 32)) { - int32_t* data = reinterpret_cast(static_cast(X->data) + X->byte_offset); - DumpArray("int32 tensor:", data, shape); - } else if (tvm::runtime::TypeMatch(dtype, kDLUInt, 8)) { - uint8_t* data = reinterpret_cast(static_cast(X->data) + X->byte_offset); - DumpArray("uint8 tensor:", data, shape); - } else if (tvm::runtime::TypeMatch(dtype, kDLUInt, 16)) { - uint16_t* data = reinterpret_cast(static_cast(X->data) + X->byte_offset); - DumpArray("uint16 tensor:", data, shape); - } else if (tvm::runtime::TypeMatch(dtype, kDLUInt, 32)) { - uint32_t* data = reinterpret_cast(static_cast(X->data) + X->byte_offset); - DumpArray("uint32 tensor:", data, shape); - } else { - MTI_ASSERT(0 && "not implemented!"); - } - }); - -tvm::Array -PrintTVMTensorExtern(const tvm::Tensor& X, - const std::string& name) { - return topi::detail::make_extern( - {X->shape}, - {X->dtype}, - {X}, - [&](tvm::Array ins, tvm::Array outs) { - return topi::detail::call_packed({tvm::Expr("tvm.contrib.onnxruntime.print"), - topi::detail::pack_buffer(ins[0]), - topi::detail::pack_buffer(outs[0])}); - }, - name + "_print", "", {}); -} - -tvm::Tensor PrintImmutable(const tvm::Tensor& X) { - auto outputs = PrintTVMTensorExtern(X, X->op->name + "_print"); - return outputs[0]; -} - -void Print(tvm::Tensor& X) { - X = PrintImmutable(X); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/debug/tvm_print.h b/onnxruntime/core/codegen/mti/debug/tvm_print.h deleted file mode 100644 index 91a334785a2a4..0000000000000 --- a/onnxruntime/core/codegen/mti/debug/tvm_print.h +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Array PrintTVMTensorExtern( - const tvm::Tensor& X, - const std::string& name = "PrintTVM2DTensorExtern"); - -tvm::Tensor PrintImmutable(const tvm::Tensor& X); - -void Print(tvm::Tensor& X); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/binary_ops.cc b/onnxruntime/core/codegen/mti/math/binary_ops.cc deleted file mode 100644 index f3048799458f4..0000000000000 --- a/onnxruntime/core/codegen/mti/math/binary_ops.cc +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/math/binary_ops.h" - -#include "core/codegen/mti/math/unary_ops.h" -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/tensor/cast_ops.h" -#include - -// Using namespace topi for override operator +-*/ -using namespace topi; - -namespace onnxruntime { -namespace tvm_codegen { - -#define TVM_BINARY_OP1(op, expr) \ - tvm::Tensor op(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name) { \ - return Rename(expr, name); \ - } \ - tvm::Tensor op(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name) { \ - return Rename(expr, name); \ - } - -#define TVM_BINARY_OP(op, expr) \ - TVM_BINARY_OP1(op, expr) \ - tvm::Tensor op(const tvm::Expr& lhs, const tvm::Tensor& rhs, const std::string& name) { \ - return Rename(expr, name); \ - } - -TVM_BINARY_OP(Add, lhs + rhs); -TVM_BINARY_OP(Div, lhs / rhs); -TVM_BINARY_OP(Max, maximum(lhs, rhs)); -TVM_BINARY_OP(Min, minimum(lhs, rhs)); -TVM_BINARY_OP(Mul, lhs* rhs); -TVM_BINARY_OP1(PRelu, Relu(lhs) - rhs * Relu(0 - lhs)); -TVM_BINARY_OP(Sub, lhs - rhs); - -tvm::Tensor Equal(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name) { - return topi::equal(lhs, rhs, name); -} -tvm::Tensor Equal(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name) { - return topi::equal(lhs, rhs, name); -} -tvm::Tensor Equal(const tvm::Expr& lhs, const tvm::Tensor& rhs, const std::string& name) { - return topi::equal(lhs, rhs, name); -} - -tvm::Tensor Greater(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name) { - return topi::greater(lhs, rhs, name); -} -tvm::Tensor Greater(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name) { - return topi::greater(lhs, rhs, name); -} -tvm::Tensor Greater(const tvm::Expr& lhs, const tvm::Tensor& rhs, const std::string& name) { - return topi::greater(lhs, rhs, name); -} - -tvm::Tensor Less(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name) { - return topi::less(lhs, rhs, name); -} -tvm::Tensor Less(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name) { - return topi::less(lhs, rhs, name); -} -tvm::Tensor Less(const tvm::Expr& lhs, const tvm::Tensor& rhs, const std::string& name) { - return topi::less(lhs, rhs, name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/binary_ops.h b/onnxruntime/core/codegen/mti/math/binary_ops.h deleted file mode 100644 index dd51ce5e7917d..0000000000000 --- a/onnxruntime/core/codegen/mti/math/binary_ops.h +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Add(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name = "add"); -tvm::Tensor Add(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name = "add"); -tvm::Tensor Add(const tvm::Expr& lhs, const tvm::Tensor& rhs, const std::string& name = "add"); -tvm::Tensor Div(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name = "div"); -tvm::Tensor Div(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name = "div"); -tvm::Tensor Div(const tvm::Expr& lhs, const tvm::Tensor& rhs, const std::string& name = "div"); -tvm::Tensor Equal(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name = "equal"); -tvm::Tensor Equal(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name = "equal"); -tvm::Tensor Equal(const tvm::Expr& lhs, const tvm::Tensor& rhs, const std::string& name = "equal"); -tvm::Tensor Greater(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name = "greater"); -tvm::Tensor Greater(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name = "greater"); -tvm::Tensor Greater(const tvm::Expr& lhs, const tvm::Tensor& rhs, const std::string& name = "greater"); -tvm::Tensor Less(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name = "less"); -tvm::Tensor Less(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name = "less"); -tvm::Tensor Less(const tvm::Expr& lhs, const tvm::Tensor& rhs, const std::string& name = "less"); -tvm::Tensor Max(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name = "max"); -tvm::Tensor Max(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name = "max"); -tvm::Tensor Max(const tvm::Expr& lhs, const tvm::Tensor& rhs, const std::string& name = "max"); -tvm::Tensor Min(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name = "min"); -tvm::Tensor Min(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name = "min"); -tvm::Tensor Min(const tvm::Expr& lhs, const tvm::Tensor& rhs, const std::string& name = "min"); -tvm::Tensor Mul(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name = "mul"); -tvm::Tensor Mul(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name = "mul"); -tvm::Tensor Mul(const tvm::Expr& lhs, const tvm::Tensor& rhs, const std::string& name = "mul"); -tvm::Tensor PRelu(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name = "prelu"); -tvm::Tensor PRelu(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name = "prelu"); -tvm::Tensor Sub(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name = "sub"); -tvm::Tensor Sub(const tvm::Tensor& lhs, const tvm::Expr& rhs, const std::string& name = "sub"); -tvm::Tensor Sub(const tvm::Expr& lhs, const tvm::Tensor& rhs, const std::string& name = "sub"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/gemm.cc b/onnxruntime/core/codegen/mti/math/gemm.cc deleted file mode 100644 index 7a79513ccaa97..0000000000000 --- a/onnxruntime/core/codegen/mti/math/gemm.cc +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/math/gemm.h" - -#include "core/codegen/mti/math/matmul_ops.h" -#include "core/codegen/mti/mti_tvm_utils.h" -#include - -// Using namespace topi for override operator +-*/ -using namespace topi; - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Gemm(const tvm::Tensor& A, const tvm::Tensor& B, const tvm::Tensor& C, - bool trans_A, bool trans_B, float alpha, float beta, - const std::string& name) { - auto A_dot_B = MatMul2D(A, B, trans_A, trans_B, name + "_matmul2d"); - tvm::Expr alphaExpr = tvm::make_const(A->dtype, alpha); - if (beta != 0) { - tvm::Expr betaExpr = tvm::make_const(A->dtype, beta); - return Rename(alphaExpr * A_dot_B + (betaExpr * C), name); - } else { - return Rename(alphaExpr * A_dot_B, name); - } -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/gemm.h b/onnxruntime/core/codegen/mti/math/gemm.h deleted file mode 100644 index 3bb205c13fdc9..0000000000000 --- a/onnxruntime/core/codegen/mti/math/gemm.h +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Gemm(const tvm::Tensor& p_A, const tvm::Tensor& p_B, const tvm::Tensor& p_C, - bool trans_A, bool trans_B, float alpha, float beta, - const std::string& name = "gemm"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/logsoftmax.cc b/onnxruntime/core/codegen/mti/math/logsoftmax.cc deleted file mode 100644 index cd8c2edae6959..0000000000000 --- a/onnxruntime/core/codegen/mti/math/logsoftmax.cc +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/math/logsoftmax.h" - -#include "core/codegen/mti/tensor/reshape_ops.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor LogSoftmax(const tvm::Tensor& input, int64_t axis, const std::string& name) { - tvm::Tensor flatten_t = Flatten(input, axis, "logsoftmax_flatten"); - return Reshape(topi::nn::log_softmax(flatten_t, name), input->shape, "logsoftmax_reshape"); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/logsoftmax.h b/onnxruntime/core/codegen/mti/math/logsoftmax.h deleted file mode 100644 index 606a32806434b..0000000000000 --- a/onnxruntime/core/codegen/mti/math/logsoftmax.h +++ /dev/null @@ -1,11 +0,0 @@ -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor LogSoftmax(const tvm::Tensor& input, int64_t axis, const std::string& name = "logsoftmax"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/matmul_ops.cc b/onnxruntime/core/codegen/mti/math/matmul_ops.cc deleted file mode 100644 index 6ecf2f69a9c25..0000000000000 --- a/onnxruntime/core/codegen/mti/math/matmul_ops.cc +++ /dev/null @@ -1,161 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/math/matmul_ops.h" - -#include "core/codegen/mti/common.h" -#include "core/codegen/mti/mti_tvm_utils.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor MatMul2D(const tvm::Tensor& A, const tvm::Tensor& B, bool trans_a, bool trans_b, const std::string& name) { - return topi::matmul(A, B, trans_a, trans_b, name); -} - -/* - * Generic Matrix Multiplication - * - * If both arguments are 2-D, they are multiplied like conventional matrices. - * - * If either argument is N-D and N > 2, it is treated as a stack of matrices residing in the last two indexes and broadcast accordingly. - * - * If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its dimensions. - * After matrix multiplication the prepended 1 is removed. - * - * If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions. - * After matrix multiplication the appended 1 is removed. - */ -tvm::Tensor MatMul(const tvm::Tensor& A, const tvm::Tensor& B, const std::string& name) { - int64_t a_rank = static_cast(A->shape.size()); - int64_t b_rank = static_cast(B->shape.size()); - const auto& A_shape = A->shape; - const auto& B_shape = B->shape; - if (a_rank == 2 && b_rank == 2) { - // 2-D X 2-D - return MatMul2D(A, B); - } else if (a_rank == 1 && b_rank == 1) { - // 1-D X 1-D - auto k = tvm::reduce_axis(tvm::Range(0, A_shape[0]), "k"); - - return tvm::compute( - {}, - [&](const tvm::Array& /*indices*/) { - return tvm::sum(A[k] * B[k], {k}); - }, - name); - } else if (a_rank == 1) { - // 1-D X n-D - auto k = tvm::reduce_axis(tvm::Range(0, A_shape[0]), "k"); - - auto l = [&](const tvm::Array& indices) { - auto ndims = indices.size(); - MTI_ASSERT(ndims >= 1); - tvm::Array b_indices; - for (size_t bi = 0; bi < ndims - 1; ++bi) { - b_indices.push_back(indices[bi]); - } - b_indices.push_back(k); - b_indices.push_back(indices[ndims - 1]); - return tvm::sum(A({k}) * B(b_indices), {k}); - }; - return tvm::compute(ConcatShapes(SliceShapeToDimension(B_shape, -2), SliceShapeFromDimension(B_shape, -1)), l, name); - } else if (b_rank == 1) { - // n-D X 1-D - auto k = tvm::reduce_axis(tvm::Range(0, B_shape[0]), "k"); - - auto l = [&](const tvm::Array& indices) { - tvm::Array a_indices(indices.begin(), indices.end()); - a_indices.push_back(k); - return tvm::sum(A(a_indices) * B({k}), {k}); - }; - return tvm::compute(SliceShapeToDimension(A->shape, -1), l, name); - } else { - // n-D X m-D - MTI_ASSERT(a_rank >= 2 && b_rank >= 2); - auto k = tvm::reduce_axis(tvm::Range(0, A_shape[a_rank - 1]), "k"); - - auto l = [&](const tvm::Array& indices) { - auto ndims = static_cast(indices.size()); - MTI_ASSERT(ndims > 2); - tvm::Array a_indices, b_indices; - - // handle broadcasting - int i = 0, a_idx = 0, b_idx = 0; - bool a_greater = a_rank > b_rank; - for (; i < std::abs(a_rank - b_rank); ++i) { - if (a_greater) { - a_indices.push_back(indices[i]); - a_idx++; - } else { - b_indices.push_back(indices[i]); - b_idx++; - } - } - for (; i < ndims - 2; ++i, ++a_idx, ++b_idx) { - auto tp = indices[i].type(); - if (IsOne(A_shape, a_idx)) { - a_indices.push_back(tvm::make_zero(tp)); - b_indices.push_back(indices[i]); - } else if (IsOne(B_shape, b_idx)) { - b_indices.push_back(tvm::make_zero(tp)); - a_indices.push_back(indices[i]); - } else { - a_indices.push_back(indices[i]); - b_indices.push_back(indices[i]); - } - } - - MTI_ASSERT(a_idx == a_rank - 2 && b_idx == b_rank - 2); - a_indices.push_back(indices[ndims - 2]); - a_indices.push_back(k); - - b_indices.push_back(k); - b_indices.push_back(indices[ndims - 1]); - - return tvm::sum(A(a_indices) * B(b_indices), {k}); - }; - - return tvm::compute(ComputeMatMulShape(A_shape, B_shape), l, name); - } -} - -tvm::Array -ComputeMatMulShape( - const tvm::Array& A_shape, - const tvm::Array& B_shape, - bool trans_a, - bool trans_b) { - auto a_rank = A_shape.size(); - auto b_rank = B_shape.size(); - tvm::Array output_shape; - int64_t output_rank = std::max(a_rank, b_rank); - MTI_ASSERT(a_rank > 0 && b_rank > 0); - if (a_rank == 1 && b_rank == 1) { - MTI_ASSERT(!trans_a && !trans_b); - // reduction, output shape is empty - } else if (a_rank == 1) { - MTI_ASSERT(!trans_a && !trans_b); - output_shape = SliceShapeToDimension(B_shape, b_rank - 2); - output_shape.push_back(B_shape[b_rank - 1]); - } else if (b_rank == 1) { - MTI_ASSERT(!trans_a && !trans_b); - output_shape = SliceShapeToDimension(A_shape, a_rank - 1); - } else { - for (int64_t i = 0; i < output_rank - 2; i++) { - tvm::Expr broadcasted_dim = tvm::make_const(HalideIR::Int(32), 1); - bool broadcasted = - BroadcastDim(A_shape, i, output_rank, broadcasted_dim) && - BroadcastDim(B_shape, i, output_rank, broadcasted_dim); - MTI_ASSERT(broadcasted); - output_shape.push_back(broadcasted_dim); - } - output_shape.push_back(A_shape[a_rank - (trans_a ? 1 : 2)]); - output_shape.push_back(B_shape[b_rank - (trans_b ? 2 : 1)]); - } - return output_shape; -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/matmul_ops.h b/onnxruntime/core/codegen/mti/math/matmul_ops.h deleted file mode 100644 index ab9986132d34a..0000000000000 --- a/onnxruntime/core/codegen/mti/math/matmul_ops.h +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Array -ComputeMatMulShape( - const tvm::Array& A_shape, - const tvm::Array& B_shape, - bool trans_a = false, - bool trans_b = false); - -tvm::Tensor MatMul2D(const tvm::Tensor& A, const tvm::Tensor& B, bool trans_a = false, bool trans_b = false, const std::string& name = "matmul2d"); - -tvm::Tensor MatMul(const tvm::Tensor& A, const tvm::Tensor& B, const std::string& name = "matmul"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/reduce_ops.cc b/onnxruntime/core/codegen/mti/math/reduce_ops.cc deleted file mode 100644 index 7d179e2b04316..0000000000000 --- a/onnxruntime/core/codegen/mti/math/reduce_ops.cc +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/math/reduce_ops.h" - -#include "core/codegen/mti/math/binary_ops.h" -#include "core/codegen/mti/math/unary_ops.h" -#include "core/codegen/mti/mti_tvm_utils.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor ArgMax(const tvm::Tensor& X, int64_t axis, bool keep_dims, const std::string& name) { - return Rename(topi::argmax(X, ToTvmArrayInt({axis}), keep_dims), name); -} - -tvm::Tensor ArgMin(const tvm::Tensor& X, int64_t axis, bool keep_dims, const std::string& name) { - return Rename(topi::argmin(X, ToTvmArrayInt({axis}), keep_dims), name); -} - -tvm::Tensor ReduceL1(const tvm::Tensor& X, const std::vector& axes, bool keep_dims, const std::string& name) { - return ReduceSum(Abs(X), axes, keep_dims, name); -} - -tvm::Tensor ReduceL2(const tvm::Tensor& X, const std::vector& axes, bool keep_dims, const std::string& name) { - return Sqrt(ReduceSumSquare(X, axes, keep_dims), name); -} - -tvm::Tensor ReduceLogSum(const tvm::Tensor& X, const std::vector& axes, bool keep_dims, const std::string& name) { - return Log(ReduceSum(X, axes, keep_dims), name); -} - -tvm::Tensor ReduceLogSumExp(const tvm::Tensor& X, const std::vector& axes, bool keep_dims, const std::string& name) { - tvm::Tensor reduce_max = ReduceMax(X, axes, true); - tvm::Tensor exp_delta = Exp(Sub(X, reduce_max)); - tvm::Tensor reduce_max_keep_dims = ReduceMax(X, axes, keep_dims); - return Add(ReduceLogSum(exp_delta, axes, keep_dims), reduce_max_keep_dims, name); -} - -tvm::Tensor ReduceMax(const tvm::Tensor& X, const std::vector& axes, bool keep_dims, const std::string& name) { - return Rename(topi::max(X, ToTvmArrayInt(axes), keep_dims), name); -} - -tvm::Tensor ReduceMean(const tvm::Tensor& X, const std::vector& axes, bool keep_dims, const std::string& name) { - tvm::Tensor reduce_sum = ReduceSum(X, axes, keep_dims); - tvm::Expr count = tvm::make_const(reduce_sum->dtype, 1.0f); - if (axes.empty()) { - for (const auto& dim : X->shape) - count = count * dim; - } else { - for (int64_t axis : axes) { - int64_t i = HandleNegativeAxis(axis, X->shape.size()); - count = count * X->shape[i]; - } - } - return tvm::compute( - reduce_sum->shape, - [&](const tvm::Array& i) { - return reduce_sum(i) / count; - }, - name); -} - -tvm::Tensor ReduceMin(const tvm::Tensor& X, const std::vector& axes, bool keep_dims, const std::string& name) { - return Rename(topi::min(X, ToTvmArrayInt(axes), keep_dims), name); -} - -tvm::Tensor ReduceProd(const tvm::Tensor& X, const std::vector& axes, bool keep_dims, const std::string& name) { - auto prod = [](tvm::Expr source, tvm::Array rdom) { - tvm::Var x("x", source.type()), y("y", source.type()); - tvm::Expr Rename_element = tvm::make_const(source.type(), 1.0f); - tvm::ir::CommReducer combiner = - tvm::ir::CommReducerNode::make({x}, {y}, {x * y}, {Rename_element}); - return tvm::ir::Reduce::make(combiner, {source}, rdom, tvm::make_const(tvm::Bool(1), true), 0); - }; - - return Rename(topi::CommReduce(X, ToTvmArrayInt(axes), prod, keep_dims, true), name); -} - -tvm::Tensor ReduceSum(const tvm::Tensor& X, const std::vector& axes, bool keep_dims, const std::string& name) { - return Rename(topi::sum(X, ToTvmArrayInt(axes), keep_dims), name); -} - -tvm::Tensor ReduceSumSquare(const tvm::Tensor& X, const std::vector& axes, bool keep_dims, const std::string& name) { - return Rename(topi::sum(Mul(X, X), ToTvmArrayInt(axes), keep_dims), name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/reduce_ops.h b/onnxruntime/core/codegen/mti/math/reduce_ops.h deleted file mode 100644 index f782df5e6515f..0000000000000 --- a/onnxruntime/core/codegen/mti/math/reduce_ops.h +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor ArgMax(const tvm::Tensor& X, - int64_t axis, - bool keep_dims, - const std::string& name = "argmax"); - -tvm::Tensor ArgMin(const tvm::Tensor& X, - int64_t axis, - bool keep_dims, - const std::string& name = "argmin"); - -tvm::Tensor ReduceL1(const tvm::Tensor& X, - const std::vector& axes, - bool keep_dims, - const std::string& name = "reduce_l1"); - -tvm::Tensor ReduceL2(const tvm::Tensor& X, - const std::vector& axes, - bool keep_dims, - const std::string& name = "reduce_l2"); - -tvm::Tensor ReduceLogSum(const tvm::Tensor& X, - const std::vector& axes, - bool keep_dims, - const std::string& name = "reduce_log_sum"); - -tvm::Tensor ReduceLogSumExp(const tvm::Tensor& X, - const std::vector& axes, - bool keep_dims, - const std::string& name = "argmareduce_log_sum_exp"); - -tvm::Tensor ReduceMax(const tvm::Tensor& X, - const std::vector& axes, - bool keep_dims, - const std::string& name = "reduce_max"); - -tvm::Tensor ReduceMean(const tvm::Tensor& X, - const std::vector& axes, - bool keep_dims, - const std::string& name = "reduce_mean"); - -tvm::Tensor ReduceMin(const tvm::Tensor& X, - const std::vector& axes, - bool keep_dims, - const std::string& name = "reduce_min"); - -tvm::Tensor ReduceProd(const tvm::Tensor& X, - const std::vector& axes, - bool keep_dims, - const std::string& name = "reduce_prod"); - -tvm::Tensor ReduceSum(const tvm::Tensor& X, - const std::vector& axes, - bool keep_dims, - const std::string& name = "reduce_sum"); - -tvm::Tensor ReduceSumSquare(const tvm::Tensor& X, - const std::vector& axes, - bool keep_dims, - const std::string& name = "reduce_sum_square"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/softmax.cc b/onnxruntime/core/codegen/mti/math/softmax.cc deleted file mode 100644 index d7404137bb873..0000000000000 --- a/onnxruntime/core/codegen/mti/math/softmax.cc +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/math/softmax.h" - -#include "core/codegen/mti/tensor/reshape_ops.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Softmax(const tvm::Tensor& input, int64_t axis, const std::string& name) { - tvm::Tensor flatten_t = Flatten(input, axis, "softmax_flatten"); - return Reshape(topi::nn::softmax(flatten_t, 1, name), input->shape, "softmax_reshape"); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/softmax.h b/onnxruntime/core/codegen/mti/math/softmax.h deleted file mode 100644 index fb16fbaeb56a2..0000000000000 --- a/onnxruntime/core/codegen/mti/math/softmax.h +++ /dev/null @@ -1,11 +0,0 @@ -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Softmax(const tvm::Tensor& input, int64_t axis, const std::string& name = "softmax"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/unary_ops.cc b/onnxruntime/core/codegen/mti/math/unary_ops.cc deleted file mode 100644 index ae732ea33e670..0000000000000 --- a/onnxruntime/core/codegen/mti/math/unary_ops.cc +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/math/unary_ops.h" - -#include "core/codegen/common/settings.h" -#include "core/codegen/mti/mti_tvm_utils.h" -#include -#include -#include -#include - -// Using namespace topi for override operator +-*/ -using namespace topi; - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Abs(const tvm::Tensor& X, const std::string& name) { - return abs(X, name); -} - -tvm::Tensor Affine(const tvm::Tensor& X, float alpha, float beta, const std::string& name) { - tvm::Expr alphaExpr = tvm::make_const(X->dtype, alpha); - tvm::Expr betaExpr = tvm::make_const(X->dtype, beta); - return Rename(alphaExpr * X + betaExpr, name); -} - -tvm::Tensor Ceil(const tvm::Tensor& X, const std::string& name) { - return topi::ceil(X, name); -} - -tvm::Tensor Clip(const tvm::Tensor& X, tvm::Expr min_value, tvm::Expr max_value, const std::string& name) { - auto Y = tvm::compute( - X->shape, - [&](const tvm::Array& indices) { - return tvm::min(tvm::max(X(indices), min_value), max_value); - }, - name); - return Y; -} - -tvm::Tensor Elu(const tvm::Tensor& X, float alpha, const std::string& name) { - tvm::Expr alphaExpr = tvm::make_const(X->dtype, alpha); - return Rename(Relu(X) - alphaExpr * Relu(1 - Exp(X)), name); -} - -tvm::Tensor Exp(const tvm::Tensor& X, const std::string& name) { - return tvm::compute( - X->shape, - [&](const tvm::Array& indices) { - return tvm::exp(X(indices)); - }, - name); -} - -tvm::Tensor Floor(const tvm::Tensor& X, const std::string& name) { - return topi::floor(X, name); -} - -tvm::Tensor HardSigmoid(const tvm::Tensor& X, float alpha, float beta, const std::string& name) { - tvm::Expr alphaExpr = tvm::make_const(X->dtype, alpha); - tvm::Expr betaExpr = tvm::make_const(X->dtype, beta); - return maximum(0, minimum(1, alphaExpr * X + betaExpr), name); -} - -tvm::Tensor LeakyRelu(const tvm::Tensor& X, float alpha, const std::string& name) { - tvm::Expr alphaExpr = tvm::make_const(X->dtype, alpha); - return Rename(Relu(X) - alphaExpr * Relu(0 - X), name); -} - -tvm::Tensor Log(const tvm::Tensor& X, const std::string& name) { - return tvm::compute( - X->shape, - [&](const tvm::Array& indices) { - return tvm::log(X(indices)); - }, - name); -} - -tvm::Tensor Neg(const tvm::Tensor& X, const std::string& name) { - return negative(X, name); -} - -tvm::Tensor ParametricSoftplus(const tvm::Tensor& X, float alpha, float beta, const std::string& name) { - tvm::Expr alphaExpr = tvm::make_const(X->dtype, alpha); - tvm::Expr betaExpr = tvm::make_const(X->dtype, beta); - return Rename(alphaExpr * Softplus(betaExpr * X), name); -} - -tvm::Tensor Reciprocal(const tvm::Tensor& X, const std::string& name) { - return Rename(1 / X, name); -} - -tvm::Tensor Relu(const tvm::Tensor& X, const std::string& name) { - return maximum(X, 0, name); -} - -tvm::Tensor ScaledTanh(const tvm::Tensor& X, float alpha, float beta, const std::string& name) { - tvm::Expr alphaExpr = tvm::make_const(X->dtype, alpha); - tvm::Expr betaExpr = tvm::make_const(X->dtype, beta); - return Rename(alphaExpr * Tanh(betaExpr * X), name); -} - -tvm::Tensor Selu(const tvm::Tensor& X, float alpha, float gamma, const std::string& name) { - tvm::Expr alphaExpr = tvm::make_const(X->dtype, alpha); - tvm::Expr gammaExpr = tvm::make_const(X->dtype, gamma); - return Rename(gammaExpr * (-alphaExpr * Relu(1 - Exp(X)) + Relu(X)), name); -} - -tvm::Tensor Sigmoid(const tvm::Tensor& X, const std::string& name) { - return tvm::compute( - X->shape, - [&](const tvm::Array& indices) { - return tvm::ir::Select::make(X(indices) > 0, - 1 / (1 + tvm::exp(-X(indices))), - tvm::exp(X(indices)) / (tvm::exp(X(indices)) + 1)); - }, - name); -} - -tvm::Tensor SignNoZero(const tvm::Tensor& X, const std::string& name) { - return Rename(greater_equal(X, 0) * 2 - 1, name); -} - -tvm::Tensor Softplus(const tvm::Tensor& X, const std::string& name) { - return Rename(Log(1 + Exp(Neg(Abs(X)))) + Relu(X), name); -} - -tvm::Tensor Softsign(const tvm::Tensor& X, const std::string& name) { - return Rename(X / (1 + Abs(X)), name); -} - -tvm::Tensor Sqrt(const tvm::Tensor& X, const std::string& name) { - return sqrt(X, name); -} - -tvm::Tensor Tanh(const tvm::Tensor& X, const std::string& name) { - return tvm::compute( - X->shape, - [&](const tvm::Array& indices) { - return tvm::ir::Select::make(X(indices) < 0, - (tvm::exp(2 * X(indices)) - 1) / (tvm::exp(2 * X(indices)) + 1), - (1 - tvm::exp(-2 * X(indices))) / (1 + tvm::exp(-2 * X(indices)))); - }, - name); -} - -tvm::Tensor ThresholdedRelu(const tvm::Tensor& X, float alpha, const std::string& name) { - tvm::Expr alphaExpr = tvm::make_const(X->dtype, alpha); - return topi::where(greater(X, alphaExpr), X, topi::full_like(X, tvm::make_zero(X->dtype)), name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/math/unary_ops.h b/onnxruntime/core/codegen/mti/math/unary_ops.h deleted file mode 100644 index aeb336262e547..0000000000000 --- a/onnxruntime/core/codegen/mti/math/unary_ops.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Abs(const tvm::Tensor& X, const std::string& name = "abs"); -tvm::Tensor Affine(const tvm::Tensor& X, float alpha, float beta, const std::string& name = "affine"); -tvm::Tensor Ceil(const tvm::Tensor& X, const std::string& name = "ceil"); -tvm::Tensor Clip(const tvm::Tensor& X, tvm::Expr min_value, tvm::Expr max_value, const std::string& name = "clip"); -tvm::Tensor Elu(const tvm::Tensor& X, float alpha, const std::string& name = "elu"); -tvm::Tensor Exp(const tvm::Tensor& X, const std::string& name = "exp"); -tvm::Tensor Floor(const tvm::Tensor& X, const std::string& name = "floor"); -tvm::Tensor HardSigmoid(const tvm::Tensor& X, float alpha, float beta, const std::string& name = "hard_sigmoid"); -tvm::Tensor LeakyRelu(const tvm::Tensor& X, float alpha, const std::string& name = "leaky_relu"); -tvm::Tensor Log(const tvm::Tensor& X, const std::string& name = "log"); -tvm::Tensor Neg(const tvm::Tensor& X, const std::string& name = "neg"); -tvm::Tensor ParametricSoftplus(const tvm::Tensor& X, float alpha, float beta, const std::string& name = "parametric_softplus"); -tvm::Tensor Reciprocal(const tvm::Tensor& X, const std::string& name = "reciprocal"); -tvm::Tensor Relu(const tvm::Tensor& X, const std::string& name = "relu"); -tvm::Tensor ScaledTanh(const tvm::Tensor& X, float alpha, float beta, const std::string& name = "scaled_tanh"); -tvm::Tensor Selu(const tvm::Tensor& X, float alpha, float gamma, const std::string& name = "selu"); -tvm::Tensor Sigmoid(const tvm::Tensor& X, const std::string& name = "sigmoid"); -tvm::Tensor SignNoZero(const tvm::Tensor& X, const std::string& name = "sign_no_zero"); -tvm::Tensor Softplus(const tvm::Tensor& X, const std::string& name = "softplus"); -tvm::Tensor Softsign(const tvm::Tensor& X, const std::string& name = "softsign"); -tvm::Tensor Sqrt(const tvm::Tensor& X, const std::string& name = "sqrt"); -tvm::Tensor Tanh(const tvm::Tensor& X, const std::string& name = "tanh"); -tvm::Tensor ThresholdedRelu(const tvm::Tensor& X, float alpha, const std::string& name = "thresholded_relu"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/mti_tvm_utils.cc b/onnxruntime/core/codegen/mti/mti_tvm_utils.cc deleted file mode 100644 index 8e73629c05614..0000000000000 --- a/onnxruntime/core/codegen/mti/mti_tvm_utils.cc +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/mti_tvm_utils.h" - -#include "core/codegen/common/settings.h" -#include "core/codegen/mti/tensor/reshape_ops.h" -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Array ToTvmArray(gsl::span shape) { - tvm::Array arr; - for (size_t i = 0; i < shape.size(); ++i) { - arr.push_back(tvm::Expr(static_cast(shape[i]))); - } - return arr; -} - -tvm::Array ToTvmArrayInt(gsl::span shape) { - tvm::Array arr; - for (size_t i = 0; i < shape.size(); ++i) { - arr.push_back(shape[i]); - } - return arr; -} - -tvm::Expr SizeToDimension(const tvm::Array& shape, int64_t axis) { - tvm::Expr size(1); - auto rank = shape.size(); - if (static_cast(axis) != rank) { - axis = HandleNegativeAxis(axis, rank); - } - for (size_t d = 0; d < std::min(rank, static_cast(axis)); ++d) - size = tvm::ir::Simplify(size * shape[d]); - return size; -} - -tvm::Expr SizeFromDimension(const tvm::Array& shape, int64_t axis) { - tvm::Expr size(1); - auto rank = shape.size(); - if (static_cast(axis) != rank) { - axis = HandleNegativeAxis(axis, rank); - } - for (size_t d = static_cast(axis); d < rank; ++d) - size = tvm::ir::Simplify(size * shape[d]); - return size; -} - -tvm::Expr RoundUp(tvm::Expr value, tvm::Expr alignment) { - return tvm::ir::Simplify((value + alignment - 1) / alignment * alignment); -} - -tvm::Array ConcatShapes( - const tvm::Array& shape1, - const tvm::Array& shape2) { - tvm::Array result; - for (size_t i = 0; i < shape1.size(); i++) - result.push_back(shape1[i]); - for (size_t i = 0; i < shape2.size(); i++) - result.push_back(shape2[i]); - return result; -} - -tvm::Tensor Rename(tvm::Tensor X, const std::string& name) { - const_cast(X->op->name) = name; - return X; -} - -tvm::Array SliceShape(const tvm::Array& shape, const std::vector& axes) { - tvm::Array new_shape; - for (auto axis : axes) { - CHECK(axis < static_cast(shape.size())); - new_shape.push_back(shape[axis]); - } - return new_shape; -} - -tvm::Array SliceShapeFromDimension(const tvm::Array& shape, int64_t axis) { - int64_t rank = static_cast(shape.size()); - axis = HandleNegativeAxis(axis, rank); - std::vector axes; - for (auto i = axis; i < rank; ++i) - axes.push_back(i); - return SliceShape(shape, axes); -} - -tvm::Array SliceShapeToDimension(const tvm::Array& shape, int64_t axis) { - int64_t rank = static_cast(shape.size()); - axis = HandleNegativeAxis(axis, rank); - std::vector axes; - for (auto i = 0; i < axis; ++i) - axes.push_back(i); - return SliceShape(shape, axes); -} - -bool IsOne(const tvm::Array& shape, int64_t axis) { - int64_t rank = static_cast(shape.size()); - axis = HandleNegativeAxis(axis, rank); - const auto& dim = shape[axis]; - auto* p = tvm::as_const_int(dim); - return p != nullptr && *p == 1; -} - -tvm::Tensor Promote(const tvm::Expr& expr, const tvm::Array& shape, const std::string& name) { - return tvm::compute( - shape, - [&](const tvm::Array&) { - return expr; - }, - name); -} - -void DumpTVMModuleToFile(const std::string& filename, tvm::runtime::Module& module) { - const codegen::CodeGenSettings& settings = codegen::CodeGenSettings::Instance(); - if (!settings.HasOption(codegen::CodeGenSettings::kCodeGenDumpModule)) - return; - - // ISSUE: note that all option values are converted to lower case. It doesn't cause - // any issue currently, because all supported formats (i.e. file exts) are of lower case. - // Just keep in mind that we might have issue if somehow we started to support dump - // formats with upper case, although it's quite unlikely. - std::string format = settings.GetOptionValue(codegen::CodeGenSettings::kCodeGenDumpModule); - std::string module_filename = filename + "." + format; - module->SaveToFile(module_filename, format); -} - -tvm::Tensor MakeZeroTensor(const tvm::Array& shape, - HalideIR::Type type, - const std::string& name) { - auto l = [&](const tvm::Array& /*indices*/) { - return tvm::make_zero(type); - }; - return tvm::compute(shape, l, name); -} - -bool BroadcastDim(const tvm::Array& shape, size_t i, size_t output_rank, tvm::Expr& dim) { - if (i >= output_rank - shape.size()) { - auto new_dim = shape[shape.size() - output_rank + i]; - if (tvm::ir::Equal(new_dim, dim)) - return true; - - const int64_t* p_new = tvm::as_const_int(new_dim); - if (p_new != nullptr && *p_new == 1) { - return true; - } else { - const int64_t* p_old = tvm::as_const_int(dim); - if (p_old != nullptr && *p_old == 1) { - dim = new_dim; - return true; - } - } - return false; - } - // auto broadcast to outer dims - return true; -} - -tvm::Array MakeInputsForExtern(const tvm::Array& inputs, const std::string& name) { - // note that currently TVM StorageFlatten creates strides like max(symbolic_dim, 1) - // which is not zero when checking symbolic_dim - max(symbolic_dim, 1) - // then triggers error like: Trying to bind compact buffer to strided one - // here's a workaround to reshape inputs to avoid that - tvm::Array fixed_inputs; - for (size_t idx_input = 0; idx_input < inputs.size(); ++idx_input) { - const auto& input = inputs[idx_input]; - tvm::Array fixed_shape; - if (input->shape.size() > 0) { - // stride compute does not use dim 0, so directly push to fixed_shape - fixed_shape.push_back(input->shape[0]); - bool need_fix = false; - for (size_t idx_dim = 1; idx_dim < input->shape.size(); ++idx_dim) { - const auto& dim = input->shape[idx_dim]; - if (tvm::as_const_int(dim) == nullptr) { - fixed_shape.push_back(tvm::max(dim, tvm::make_const(HalideIR::Int(32), 1))); - need_fix = true; - } else { - fixed_shape.push_back(dim); - } - } - if (need_fix) { - fixed_inputs.push_back(tvm_codegen::Reshape(input, fixed_shape, name + "_" + std::to_string(idx_input))); - continue; - } - } - // no fix needed - fixed_inputs.push_back(input); - } - return fixed_inputs; -} - -// Make sure idx is clamped in the range of [-bound, bound - 1] -tvm::Expr ClampIndex(const tvm::Expr& idx, const tvm::Expr& bound) { - // when idx >= 0, we take tvm::max(..., 0), because (idx < 0) is 0 - // when idx < 0, we take bound + tvm::max(...), because tvm::max(idx, 0) is 0 - return tvm::max(tvm::min(idx, bound - 1), 0) + - (idx < 0) * (bound + tvm::max(idx, -bound)); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/mti_tvm_utils.h b/onnxruntime/core/codegen/mti/mti_tvm_utils.h deleted file mode 100644 index c2a14106c1686..0000000000000 --- a/onnxruntime/core/codegen/mti/mti_tvm_utils.h +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include -#include -#include "core/codegen/mti/common.h" - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Array ToTvmArray(gsl::span shape); - -tvm::Array ToTvmArrayInt(gsl::span shape); - -// Helper function to compute sub shape size to axis (not included) -tvm::Expr SizeToDimension(const tvm::Array& shape, int64_t axis); - -// Helper function to compute sub shape size from axis (included) -tvm::Expr SizeFromDimension(const tvm::Array& shape, int64_t axis); - -// Helper function to align -tvm::Expr RoundUp(tvm::Expr value, tvm::Expr alignment); - -tvm::Array ConcatShapes( - const tvm::Array& shape1, - const tvm::Array& shape2); - -// Helper function to rename tvm::Tensor -tvm::Tensor Rename(tvm::Tensor X, const std::string& name); - -// Helper function to slice TVM shape -tvm::Array SliceShape(const tvm::Array& shape, const std::vector& axes); - -// Helper function to slice TVM shape from axis (inclusive). -// Basically, this function returns the shape of [axis, shape.size()-1] -tvm::Array SliceShapeFromDimension(const tvm::Array& shape, int64_t axis); - -// this function returns the shape of [0, axis-1] -tvm::Array SliceShapeToDimension(const tvm::Array& shape, int64_t axis); - -// check if dimension is 1 -bool IsOne(const tvm::Array& shape, int64_t axis); - -// Helper function to convert tvm::Expr to tvm::Tensor -tvm::Tensor Promote(const tvm::Expr& expr, - const tvm::Array& shape, - const std::string& name = "PromoteExpr"); - -tvm::Tensor MakeZeroTensor(const tvm::Array& shape, HalideIR::Type type, const std::string& name); - -void DumpTVMModuleToFile(const std::string& filename, tvm::runtime::Module& module); - -bool BroadcastDim(const tvm::Array& shape, size_t i, size_t output_rank, tvm::Expr& dim); - -inline int64_t HandleNegativeAxis(int64_t axis, int64_t rank) { - MTI_ASSERT(axis >= -rank && axis <= rank - 1); - return axis = axis < 0 ? (axis + rank) : axis; -} - -// Make sure idx is clamped in the range of [-bound, bound - 1] -tvm::Expr ClampIndex(const tvm::Expr& idx, const tvm::Expr& bound); - -// Helper function to workaround tvm ExternOp issue when input has symbolic dimensions -tvm::Array MakeInputsForExtern(const tvm::Array& inputs, const std::string& name = "make_inputs_for_extern"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/nn/conv_ops.cc b/onnxruntime/core/codegen/mti/nn/conv_ops.cc deleted file mode 100644 index e2d4acc8843ad..0000000000000 --- a/onnxruntime/core/codegen/mti/nn/conv_ops.cc +++ /dev/null @@ -1,193 +0,0 @@ -#include "core/codegen/mti/nn/conv_ops.h" - -#include "core/codegen/mti/math/matmul_ops.h" -#include "core/codegen/mti/tensor/pad_ops.h" -#include "core/codegen/mti/tensor/reshape_ops.h" -#include "core/codegen/mti/tensor/transpose.h" - -namespace onnxruntime { -namespace tvm_codegen { - -static tvm::Tensor PadTensor1D(const tvm::Tensor& input, - const tvm::Array& padding, - size_t width_axis, - const std::string& name) { - auto pad_left = padding[0]; - auto pad_right = padding[1]; - - tvm::Array pad_before(std::vector(input->shape.size(), 0)); - pad_before.Set(width_axis, pad_left); - tvm::Array pad_after(std::vector(input->shape.size(), 0)); - pad_after.Set(width_axis, pad_right); - - const int64_t* padding_w0 = tvm::as_const_int(pad_left); - const int64_t* padding_w1 = tvm::as_const_int(pad_right); - - const bool do_pad = ((padding_w0 != nullptr && *padding_w0) || - (padding_w1 != nullptr && *padding_w1)); - - return do_pad ? Pad(input, pad_before, pad_after, - 0, "constant", name + "_input_padded") - : input; -} - -tvm::Tensor Conv1D(const tvm::Tensor& input, - const tvm::Tensor& filter, - const tvm::Array& out_shape, - const tvm::Array& stride, - const tvm::Array& padding, - const std::string& name) { - size_t channel_axis = 1; - size_t width_axis = 2; - - auto stride_width = stride[width_axis - 2]; - - auto input_padded = PadTensor1D(input, padding, width_axis, name); - auto rc = tvm::reduce_axis((tvm::Range(0, filter->shape[1])), "rc"); - auto rx = tvm::reduce_axis((tvm::Range(0, filter->shape[2])), "rx"); - - return tvm::compute( - out_shape, - [&](const tvm::Array& output) { - tvm::Array indices; - for (const tvm::Var& var : output) { - indices.push_back(var); - } - indices.Set(channel_axis, rc); - indices.Set(width_axis, output[width_axis] * stride_width + rx); - - return tvm::sum(input_padded(indices) * filter({output[1], rc, rx}), - {rc, rx}); - }, - name); -} - -tvm::Tensor Conv2D(const tvm::Tensor& input, - const tvm::Tensor& filter, - const tvm::Array& output_shape, - const tvm::Array& stride, - const tvm::Array& padding, - const std::string& name) { - return Conv2D_native(input, filter, output_shape, stride, padding); -} - -static tvm::Tensor PadTensor2D(const tvm::Tensor& input, - const tvm::Array& padding, - size_t height_axis, - size_t width_axis, - const std::string& name) { - auto pad_top = padding[0]; - auto pad_left = padding[1]; - auto pad_bottom = padding[2]; - auto pad_right = padding[3]; - - tvm::Array pad_before(std::vector(input->shape.size(), 0)); - pad_before.Set(height_axis, pad_top); - pad_before.Set(width_axis, pad_left); - - tvm::Array pad_after(std::vector(input->shape.size(), 0)); - pad_after.Set(height_axis, pad_bottom); - pad_after.Set(width_axis, pad_right); - - const int64_t* padding_h0 = tvm::as_const_int(pad_top); - const int64_t* padding_w0 = tvm::as_const_int(pad_left); - const int64_t* padding_h1 = tvm::as_const_int(pad_bottom); - const int64_t* padding_w1 = tvm::as_const_int(pad_right); - - const bool do_pad = ((padding_h0 != nullptr && *padding_h0) || - (padding_w0 != nullptr && *padding_w0)) || - ((padding_h1 != nullptr && *padding_h1) || - (padding_w1 != nullptr && *padding_w1)); - - return do_pad ? Pad(input, pad_before, pad_after, - 0, "constant", name + "_input_padded") - : input; -} - -tvm::Tensor Conv2D_native(const tvm::Tensor& input, - const tvm::Tensor& filter, - const tvm::Array& out_shape, - const tvm::Array& stride, - const tvm::Array& padding, - const std::string& name) { - size_t channel_axis = 1; - size_t height_axis = 2; - size_t width_axis = 3; - - auto stride_height = stride[height_axis - 2]; - auto stride_width = stride[width_axis - 2]; - - auto input_padded = PadTensor2D(input, padding, height_axis, width_axis, name); - - auto rc = tvm::reduce_axis((tvm::Range(0, filter->shape[1])), "rc"); - auto ry = tvm::reduce_axis((tvm::Range(0, filter->shape[2])), "ry"); - auto rx = tvm::reduce_axis((tvm::Range(0, filter->shape[3])), "rx"); - - return tvm::compute( - out_shape, - [&](const tvm::Array& output) { - tvm::Array indices; - for (const tvm::Var& var : output) { - indices.push_back(var); - } - indices.Set(channel_axis, rc); - indices.Set(height_axis, output[height_axis] * stride_height + ry); - indices.Set(width_axis, output[width_axis] * stride_width + rx); - - return tvm::sum(input_padded(indices) * filter({output[1], rc, ry, rx}), - {rc, ry, rx}); - }, - name); -} - -tvm::Tensor Conv2D_gemm(const tvm::Tensor& input, - const tvm::Tensor& filter, - const tvm::Array& out_shape, - const tvm::Array& stride, - const tvm::Array& padding, - const std::string& name) { - size_t height_axis = 2; - size_t width_axis = 3; - - auto stride_height = stride[height_axis - 2]; - auto stride_width = stride[width_axis - 2]; - - auto input_padded = PadTensor2D(input, padding, height_axis, width_axis, name); - - tvm::Array img_col_tmp(std::vector(6, 0)); - img_col_tmp.Set(0, out_shape[0]); - img_col_tmp.Set(1, out_shape[2]); - img_col_tmp.Set(2, out_shape[3]); - img_col_tmp.Set(3, filter->shape[1]); - img_col_tmp.Set(4, filter->shape[2]); - img_col_tmp.Set(5, filter->shape[3]); - - auto img_col = tvm::compute( - img_col_tmp, - [&](const tvm::Array& output) { - tvm::Array indices; - indices.push_back(output[0]); - indices.push_back(output[3]); - indices.push_back(output[1] * stride_height + output[4]); - indices.push_back(output[2] * stride_width + output[5]); - return input_padded(indices); - }, - name); - - tvm::Array input_col_shape(std::vector(2, 0)); - input_col_shape.Set(0, img_col_tmp[1] * img_col_tmp[2]); - input_col_shape.Set(1, img_col_tmp[3] * img_col_tmp[4] * img_col_tmp[5]); - auto input_col = Reshape(img_col, input_col_shape); - - tvm::Array filter_row_shape(std::vector(2, 0)); - filter_row_shape.Set(0, filter->shape[0]); - filter_row_shape.Set(1, filter->shape[1] * filter->shape[2] * filter->shape[3]); - auto filter_row = Reshape(filter, filter_row_shape, name); - - auto Y = MatMul2D(input_col, filter_row, false, true, name); - auto Y_T = Transpose(Y, /*axes=*/{}, name); - return Reshape(Y_T, out_shape, name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/nn/conv_ops.h b/onnxruntime/core/codegen/mti/nn/conv_ops.h deleted file mode 100644 index 1396c216865a7..0000000000000 --- a/onnxruntime/core/codegen/mti/nn/conv_ops.h +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Conv1D(const tvm::Tensor& input, - const tvm::Tensor& filter, - const tvm::Array& output_shape, - const tvm::Array& stride, - const tvm::Array& padding, - const std::string& name = "conv1d"); - -tvm::Tensor Conv2D(const tvm::Tensor& input, - const tvm::Tensor& filter, - const tvm::Array& output_shape, - const tvm::Array& stride, - const tvm::Array& padding, - const std::string& name = "conv2d"); - -tvm::Tensor Conv2D_native(const tvm::Tensor& input, - const tvm::Tensor& filter, - const tvm::Array& output_shape, - const tvm::Array& stride, - const tvm::Array& padding, - const std::string& name = "conv2d_native"); - -tvm::Tensor Conv2D_gemm(const tvm::Tensor& input, - const tvm::Tensor& filter, - const tvm::Array& output_shape, - const tvm::Array& stride, - const tvm::Array& padding, - const std::string& name = "conv2d_gemm"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/nn/lstm.cc b/onnxruntime/core/codegen/mti/nn/lstm.cc deleted file mode 100644 index 1148b0924e869..0000000000000 --- a/onnxruntime/core/codegen/mti/nn/lstm.cc +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/nn/lstm.h" - -#include "core/codegen/mti/math/binary_ops.h" -#include "core/codegen/mti/math/unary_ops.h" -#include "core/codegen/mti/math/matmul_ops.h" -#include "core/codegen/mti/math/reduce_ops.h" -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/tensor/reshape_ops.h" -#include "core/codegen/mti/tensor/split.h" - -namespace onnxruntime { -namespace tvm_codegen { - -/* -`X` - input tensor -`i` - input gate -`o` - output gate -`f` - forget gate -`c` - cell gate -`t` - time step (t-1 means previous time step) - -`W[iofc]` - W parameter weight matrix for input, output, forget, and cell gates -`R[iofc]` - R recurrence weight matrix for input, output, forget, and cell gates -`Wb[iofc]` - W bias vectors for input, output, forget, and cell gates -`Rb[iofc]` - R bias vectors for input, output, forget, and cell gates -`P[iof]` - P peephole weight vector for input, output, and forget gates -`WB[iofc]` - W parameter weight matrix for backward input, output, forget, and cell gates -`RB[iofc]` - R recurrence weight matrix for backward input, output, forget, and cell gates -`WBb[iofc]` - W bias vectors for backward input, output, forget, and cell gates -`RBb[iofc]` - R bias vectors for backward input, output, forget, and cell gates -`PB[iof]` - P peephole weight vector for backward input, output, and forget gates - -`H` - Hidden state -`num_directions` - 2 if direction == bidirectional else 1 - -Equations (Default: f=Sigmoid, g=Tanh, h=Tanh): - it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) - ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf) - ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc) - Ct = ft (.) Ct-1 + it (.) ct - ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo) - Ht = ot (.) h(Ct) -*/ - -void LSTM_cell( - const LSTMAttributes& lstm_attrs, - const tvm::Tensor& X, - const tvm::Tensor& W, - const tvm::Tensor& R, - const tvm::Tensor& B, - bool has_B, - const tvm::Tensor& prev_H, - const tvm::Tensor& prev_C, - const tvm::Tensor& P, - bool has_P, - tvm::Tensor& Y_h, - tvm::Tensor& Y_c) { - // Input projection: Xt*(W[iofc]^T) for forward direction or Xt*(WB[iofc]^T) for reverse direction - // (batch_size, input_size) * trans(4 * hidden_size, input_size) => (batch_size, 4 * hidden_size) - tvm::Tensor input_proj = MatMul2D(X, W, /*trans_a*/ false, /*trans_b*/ true); - - // Hidden projection: Ht-1*(R[iofc]^T) for forward direction or Ht-1*(RB[iofc]^T) for reverse direction - // (batch_size, hidden_size) * trans(4 * hidden_size, hidden_size) => (batch_size, 4 * hidden_size) - tvm::Tensor hidden_proj = MatMul2D(prev_H, R, /*trans_a*/ false, /*trans_b*/ true); - - // (batch_size, 4 * hidden_size) - tvm::Tensor sum_proj = Add(input_proj, hidden_proj); - - // Concatenation of [Wb[iofc], Rb[iofc]] or [WBb[iofc], RBb[iofc]] - if (has_B) { - // (8 * hidden_size) -> (2, 4 * hidden_size) -> (1, 4 * hidden_size), should be done in const folding - tvm::Tensor reduce_B = - ReduceSum(Reshape(B, {2, 4 * static_cast(lstm_attrs.hidden_size)}), {0}, /*keep_dims*/ true); - // (batch_size, 4 * hidden_size) via broadcasting reduce_B - sum_proj = Add(sum_proj, reduce_B); - } - - std::vector iofc_sum_split_sizes(4, lstm_attrs.hidden_size); - // Split sum_proj into iofc, where each gate proj is of (batch_size, hidden_size) - tvm::Array iofc_sum_projs = Split(sum_proj, ToTvmArray(iofc_sum_split_sizes), /*axis*/ 1); - MTI_ASSERT(iofc_sum_projs.size() == 4); - tvm::Tensor i_proj = iofc_sum_projs[0], - o_proj = iofc_sum_projs[1], - f_proj = iofc_sum_projs[2], - c_proj = iofc_sum_projs[3]; - - tvm::Tensor P_i, P_o, P_f; - if (has_P) { - std::vector iof_p_split_sizes(3, lstm_attrs.hidden_size); - // Split P into P_i, P_o, P_f, in const pre-processing (P_i, P_f might be merged?) - // where each P_[iof] has the shape of (hidden_size) - tvm::Array iof_P_projs = Split(P, ToTvmArray(iof_p_split_sizes), /*axis*/ 0); - MTI_ASSERT(iof_P_projs.size() == 3); - P_i = iof_P_projs[0], - P_o = iof_P_projs[1], - P_f = iof_P_projs[2]; - - // (batch_size, hidden_size) via broadcasting P_[if] - i_proj = Add(i_proj, Mul(P_i, prev_C)); - f_proj = Add(f_proj, Mul(P_f, prev_C)); - } - - // TODO: handle more general cases for activations f, h, g and activation_alpha and - // activation_beta. We may consider to move some code such as ActivationInfo from deep_cpu_lstm - // into a common header file, because the code can be used here. - - // Note that by default f = Sigmoid, g = Tanh, h = Tanh - - // it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) - // shape: (batch_size, hidden_size) - tvm::Tensor i_t = Sigmoid(i_proj); - // ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf) - // shape: (batch_size, hidden_size) - tvm::Tensor f_t = Sigmoid(f_proj); - // ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc) - // shape: (batch_size, hidden_size) - tvm::Tensor c_t = Tanh(c_proj); - - // Ct = ft (.) Ct-1 + it (.) ct - // shape: (batch_size, hidden_size) - Y_c = Add(Mul(f_t, prev_C), Mul(i_t, c_t), Y_c->op->name); - - // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo) - // shape: (batch_size, hidden_size) - if (has_P) { - o_proj = Add(o_proj, Mul(P_o, Y_c)); - } - // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo) - // shape: (batch_size, hidden_size) - o_proj = Sigmoid(o_proj); - // Ht = ot (.) h(Ct) - // shape: (batch_size, hidden_size) - Y_h = Mul(o_proj, Tanh(Y_c), Y_h->op->name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/nn/lstm.h b/onnxruntime/core/codegen/mti/nn/lstm.h deleted file mode 100644 index 851fa880c4427..0000000000000 --- a/onnxruntime/core/codegen/mti/nn/lstm.h +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -// A bubble now. But don't remove it -// TODO: refactor the LSTMcell building to a tvm function -// and move it here - -namespace onnxruntime { -namespace tvm_codegen { - -struct LSTMAttributes { - LSTMAttributes(int64_t hidden_size_p) : hidden_size(hidden_size_p) {} - int64_t hidden_size; -}; - -void LSTM_cell( - const LSTMAttributes& lstm_attrs, - const tvm::Tensor& X, - const tvm::Tensor& W, - const tvm::Tensor& R, - const tvm::Tensor& B, - bool has_B, - const tvm::Tensor& prev_H, - const tvm::Tensor& prev_C, - const tvm::Tensor& P, - bool has_P, - tvm::Tensor& Y_h, - tvm::Tensor& Y_c); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/nn/pool_ops.cc b/onnxruntime/core/codegen/mti/nn/pool_ops.cc deleted file mode 100644 index 868a14748cabc..0000000000000 --- a/onnxruntime/core/codegen/mti/nn/pool_ops.cc +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/nn/pool_ops.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/mlas/inc/mlas.h" -#include "core/providers/cpu/nn/pool_attributes.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -// TODO: topi only support 2d-pool, MaxPool1d and MaxPool3d will need to be added if necessary. -// only support version < 8 for topi doesn't come with implementation to output index tensor -tvm::Tensor MaxPool(const tvm::Tensor& input, - const PoolAttributes& pool_attrs, - const tvm::Array& /*output_shape*/, - const std::string& /*name*/) { - return topi::nn::pool(input, - ToTvmArray(pool_attrs.kernel_shape), - ToTvmArray(pool_attrs.strides), - ToTvmArray(pool_attrs.pads), - /*pool_type*/ topi::nn::kMaxPool, - /*ceil_mode*/ false, - /*layout*/ pool_attrs.storage_order == 0 ? "NCWH" : "NCHW", - pool_attrs.count_include_pad); -} - -tvm::Tensor AveragePool(const tvm::Tensor& input, - const PoolAttributes& pool_attrs, - const tvm::Array& /*output_shape*/, - const std::string& /*name*/) { - return topi::nn::pool(input, - ToTvmArray(pool_attrs.kernel_shape), - ToTvmArray(pool_attrs.strides), - ToTvmArray(pool_attrs.pads), - /*pool_type*/ topi::nn::kAvgPool, - /*ceil_mode*/ false, - /*layout*/ "NCHW", - pool_attrs.count_include_pad); -} - -tvm::Tensor GlobalMaxPool(const tvm::Tensor& input, - const PoolAttributes& /*pool_attrs*/, - const tvm::Array& /*output_shape*/, - const std::string& /*name*/) { - return topi::nn::global_pool(input, - /*pool_type*/ topi::nn::kMaxPool, - /*layout*/ "NCHW"); -} - -tvm::Tensor GlobalAveragePool(const tvm::Tensor& input, - const PoolAttributes& /*pool_attrs*/, - const tvm::Array& /*output_shape*/, - const std::string& /*name*/) { - return topi::nn::global_pool(input, - /*pool_type*/ topi::nn::kAvgPool, - /*layout*/ "NCHW"); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/nn/pool_ops.h b/onnxruntime/core/codegen/mti/nn/pool_ops.h deleted file mode 100644 index d381f9ddff859..0000000000000 --- a/onnxruntime/core/codegen/mti/nn/pool_ops.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { - -// Forward declaration -struct PoolAttributes; - -namespace tvm_codegen { - -tvm::Tensor MaxPool(const tvm::Tensor& input, - const PoolAttributes& pool_attrs, - const tvm::Array& output_shape, - const std::string& name = "max_pool"); - -tvm::Tensor AveragePool(const tvm::Tensor& input, - const PoolAttributes& pool_attrs, - const tvm::Array& output_shape, - const std::string& name = "average_pool"); - -tvm::Tensor GlobalMaxPool(const tvm::Tensor& input, - const PoolAttributes& pool_attrs, - const tvm::Array& output_shape, - const std::string& name = "global_max_pool"); - -tvm::Tensor GlobalAveragePool(const tvm::Tensor& input, - const PoolAttributes& pool_attrs, - const tvm::Array& output_shape, - const std::string& name = "global_average_pool"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/cast_ops.cc b/onnxruntime/core/codegen/mti/tensor/cast_ops.cc deleted file mode 100644 index a8fc86488d82b..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/cast_ops.cc +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/tensor/cast_ops.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Cast(const tvm::Tensor& X, tvm::Type type, const std::string& name) { - return topi::cast(X, type, name); -} - -// handle cases where bool is reprented as uint8 (e.g. in ONNX). -tvm::Tensor CastToUInt8Bool(const tvm::Tensor& X, const std::string& name) { - return tvm::compute( - X->shape, - [&](const tvm::Array& indices) { - auto val = X(indices); - // A special cast from float16 to bool, first cast up to float32, - // to workaround a float16 bug in many TVM backends. - // Intel Skylake is one of them. https://github.com/dmlc/tvm/issues/2959 - // TODO: remove it, after TVM is fixed - if (X->dtype == HalideIR::Float(16)) - val = tvm::cast(HalideIR::Float(32), val); - return tvm::ir::Select::make(topi::equal(val, tvm::make_zero(val.type())), - tvm::make_zero(HalideIR::UInt(8)), - tvm::make_const(HalideIR::UInt(8), 1)); - }, - name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/cast_ops.h b/onnxruntime/core/codegen/mti/tensor/cast_ops.h deleted file mode 100644 index 02f6f9cb1fde7..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/cast_ops.h +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Cast(const tvm::Tensor& X, tvm::Type type, const std::string& name = "cast"); -tvm::Tensor CastToUInt8Bool(const tvm::Tensor& X, const std::string& name = "cast_uint8_bool"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/concat_ops.cc b/onnxruntime/core/codegen/mti/tensor/concat_ops.cc deleted file mode 100644 index 3394d5b7e00a2..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/concat_ops.cc +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/tensor/concat_ops.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Concat(const tvm::Array& inputs, - int64_t axis, - const std::string& name) { - return ConcatSafe(inputs, axis, name); -} - -// Note topi's implementation requires control flow within iterations to avoid out-of-bound access. -// Therefore, MTI implements a ConcatSafe that does not have out-of-bound access, -// and does not requires control or predicate. -tvm::Tensor ConcatSafe(const tvm::Array& inputs, - int64_t axis, - const std::string& name) { - axis = HandleNegativeAxis(axis, gsl::narrow(inputs[0]->shape.size())); - MTI_ASSERT(axis < gsl::narrow(inputs[0]->shape.size()) && "axis out of bounds"); - - tvm::Array axis_sizes; - for (auto t : inputs) { - axis_sizes.push_back(t->shape[axis]); - } - - tvm::Expr join_size = axis_sizes[0]; - for (size_t i = 1; i < axis_sizes.size(); ++i) { - join_size += axis_sizes[i]; - } - join_size = tvm::ir::Simplify(join_size); - tvm::Array out_shape; - for (size_t i = 0; i < inputs[0]->shape.size(); ++i) { - out_shape.push_back(i == gsl::narrow(axis) ? join_size : inputs[0]->shape[i]); - } - - return tvm::compute( - out_shape, [&](const tvm::Array& ovars) { - tvm::Array indices; - - // preset - tvm::Expr min = 0; - tvm::Expr extent = axis_sizes[0]; - tvm::Expr offset = 0; - tvm::Expr ret; - - // input i = 0 - for (size_t j = 0; j < ovars.size(); ++j) { - if (j == gsl::narrow(axis)) { - tvm::Expr ivar = ovars[j]; - indices.push_back(tvm::max(tvm::min(ivar, min + extent - 1), min)); - } else { - indices.push_back(ovars[j]); - } - } - ret = inputs[0](indices); - - for (size_t i = 1; i < inputs.size(); ++i) { - offset += extent; - tvm::Expr min = 0; - extent = axis_sizes[i]; - auto j = gsl::narrow(axis); - tvm::Expr ivar = ovars[j] - offset; - indices.Set(j, tvm::max(tvm::min(ivar, min + extent - 1), min)); - - ret = tvm::ir::Select::make(ivar >= 0, - inputs[i](indices), - ret); - } - - return ret; - }, - name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/concat_ops.h b/onnxruntime/core/codegen/mti/tensor/concat_ops.h deleted file mode 100644 index 153afebb44615..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/concat_ops.h +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Concat(const tvm::Array& inputs, int64_t axis, const std::string& name = "concat"); -tvm::Tensor ConcatSafe(const tvm::Array& inputs, int64_t axis, const std::string& name = "concat_safe"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/crop.cc b/onnxruntime/core/codegen/mti/tensor/crop.cc deleted file mode 100644 index 3fe569100df12..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/crop.cc +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/tensor/crop.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Crop(const tvm::Tensor& t, - const tvm::Array& border, - const tvm::Array& scale, - const std::string& name) { - MTI_ASSERT(t->shape.size() == 4); - tvm::Expr N = t->shape[0]; - tvm::Expr C = t->shape[1]; - tvm::Expr H = t->shape[2]; - tvm::Expr W = t->shape[3]; - - MTI_ASSERT(border.size() == 4); - tvm::Expr leftBorder = border[0]; - tvm::Expr topBorder = border[1]; - tvm::Expr rightBorder = border[2]; - tvm::Expr bottomBorder = border[3]; - - tvm::Expr bottomLimit = H - bottomBorder; - tvm::Expr rightLimit = W - rightBorder; - - if (!scale.empty()) { - CHECK_EQ(scale.size(), 2); - bottomLimit = topBorder + scale[0]; - rightLimit = leftBorder + scale[1]; - } - - tvm::Array output_shape; - output_shape.push_back(tvm::ir::Simplify(N)); - output_shape.push_back(tvm::ir::Simplify(C)); - output_shape.push_back(tvm::ir::Simplify(bottomLimit - topBorder)); - output_shape.push_back(tvm::ir::Simplify(rightLimit - leftBorder)); - - auto l = [&](const tvm::Array& ovars) { - tvm::Array indices; - - indices.push_back(tvm::min(ovars[0], output_shape[0] - 1)); - indices.push_back(tvm::min(ovars[1], output_shape[1] - 1)); - indices.push_back(tvm::min(topBorder + ovars[2], topBorder + output_shape[2] - 1)); - indices.push_back(tvm::min(leftBorder + ovars[3], leftBorder + output_shape[3] - 1)); - - return t(indices); - }; - - return tvm::compute(output_shape, l, name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/crop.h b/onnxruntime/core/codegen/mti/tensor/crop.h deleted file mode 100644 index ffb6a05c70504..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/crop.h +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Crop(const tvm::Tensor& t, - const tvm::Array& border, - const tvm::Array& scale = {}, - const std::string& name = "crop"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/expand.cc b/onnxruntime/core/codegen/mti/tensor/expand.cc deleted file mode 100644 index cdac4f56e1f9f..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/expand.cc +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/tensor/expand.h" -#include "core/codegen/mti/common.h" - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Expand(const tvm::Tensor& X, const tvm::Array& new_shape, const std::string& name) { - MTI_ASSERT(new_shape.size() >= X->shape.size()); - return tvm::compute( - new_shape, - [&](const tvm::Array& out_indices) { - tvm::Array indices; - size_t broadcasted_rank = new_shape.size() - X->shape.size(); - for (size_t d = broadcasted_rank; d < new_shape.size(); ++d) { - if (tvm::is_const_int(X->shape[d - broadcasted_rank], 1)) { - indices.push_back(tvm::make_zero(HalideIR::Int(32))); - } else { - indices.push_back(out_indices[d]); - } - } - return X(indices); - }, - name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/expand.h b/onnxruntime/core/codegen/mti/tensor/expand.h deleted file mode 100644 index d66d41aeb0194..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/expand.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Expand(const tvm::Tensor& X, const tvm::Array& new_shape, const std::string& name = "expand"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/gather.cc b/onnxruntime/core/codegen/mti/tensor/gather.cc deleted file mode 100644 index 152b3981f1623..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/gather.cc +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/tensor/gather.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Gather(const tvm::Tensor& t, - int64_t axis, - const tvm::Tensor& indices, - const std::string& name) { - // handle negative axis - axis = HandleNegativeAxis(axis, gsl::narrow(t->shape.size())); - size_t axis_t = gsl::narrow(axis); - - tvm::Array output_shape; - for (size_t i = 0; i < axis_t; ++i) - output_shape.push_back(t->shape[i]); - - for (size_t i = 0; i < indices->shape.size(); ++i) - output_shape.push_back(indices->shape[i]); - - for (size_t i = axis_t + 1; i < t->shape.size(); ++i) - output_shape.push_back(t->shape[i]); - - tvm::Expr idx_upper_bound = t->shape[axis_t]; - auto l = [&](const tvm::Array& ovars) { - tvm::Array ivars; - for (size_t i = 0; i < t->shape.size(); ++i) { - if (i < axis_t) { - ivars.push_back(ovars[i]); - } else if (i == axis_t) { - tvm::Array idx_vars; - for (size_t d = 0; d < indices->shape.size(); ++d) - idx_vars.push_back(ovars[axis_t + d]); - // make sure idx is clamped in the range of [-idx_upper_bound, idx_upper_bound - 1] - tvm::Expr real_idx = tvm_codegen::ClampIndex(indices(idx_vars), idx_upper_bound); - ivars.push_back(tvm::cast(tvm::Int(32), real_idx)); // tvm indices must be Int32 - } else { - ivars.push_back(ovars[i - 1 + indices->shape.size()]); - } - } - return t(ivars); - }; - - return tvm::compute(output_shape, l, name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/gather.h b/onnxruntime/core/codegen/mti/tensor/gather.h deleted file mode 100644 index a44bf3e4127d5..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/gather.h +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Gather(const tvm::Tensor& t, - int64_t axis, - const tvm::Tensor& indices, - const std::string& name = "gather"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/gather_elements.cc b/onnxruntime/core/codegen/mti/tensor/gather_elements.cc deleted file mode 100644 index 12d2983335890..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/gather_elements.cc +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/tensor/gather_elements.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor GatherElements(const tvm::Tensor& t, - int64_t axis, - const tvm::Tensor& indices, - const std::string& name) { - tvm::Array output_shape; - int64_t indices_rank = static_cast(indices->shape.size()); - // output shape is the same as indices - for (int64_t i = 0; i < indices_rank; ++i) - output_shape.push_back(indices->shape[i]); - - tvm::Expr idx_upper_bound = t->shape[axis]; - auto l = [&](const tvm::Array& ovars) { - tvm::Array ivars; - for (int64_t i = 0; i < indices_rank; i++) { - if (i == axis) { - tvm::Array idx_vars; - for (int64_t j = 0; j < indices_rank; j++) - idx_vars.push_back(ovars[j]); - // make sure idx is clamped in the range of [-idx_upper_bound, idx_upper_bound - 1] - tvm::Expr real_idx = tvm_codegen::ClampIndex(indices(idx_vars), idx_upper_bound); - // tvm idx must be of Int(32) - ivars.push_back(tvm::cast(tvm::Int(32), real_idx)); - } else { - ivars.push_back(ovars[i]); - } - } - return t(ivars); - }; - - return tvm::compute(output_shape, l, name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/gather_elements.h b/onnxruntime/core/codegen/mti/tensor/gather_elements.h deleted file mode 100644 index 650086f0f2e87..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/gather_elements.h +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor GatherElements(const tvm::Tensor& t, - int64_t axis, - const tvm::Tensor& indices, - const std::string& name = "gather_elements"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/pad_ops.cc b/onnxruntime/core/codegen/mti/tensor/pad_ops.cc deleted file mode 100644 index 2f688290d109e..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/pad_ops.cc +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/tensor/pad_ops.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -// Note topi::pad does not support modes {edge, reflect} -// Therefore, MTI implements a generic Pad -tvm::Tensor Pad(const tvm::Tensor& t, - const tvm::Array& pad_before, - const tvm::Array& pad_after, - float pad_value, - const std::string& mode, - const std::string& name) { - MTI_ASSERT(pad_before.size() >= 1); - MTI_ASSERT(pad_before.size() == pad_after.size()); - MTI_ASSERT(pad_before.size() == t->shape.size()); - - tvm::Array output_shape; - for (size_t i = 0; i < t->shape.size(); ++i) { - output_shape.push_back( - tvm::ir::Simplify(t->shape[i] + pad_before[i] + pad_after[i])); - } - - auto l = [&](const tvm::Array& ovars) { - tvm::Array conds; - tvm::Array indices; - tvm::Array coords; - - for (size_t i = 0; i < t->shape.size(); ++i) { - tvm::Expr ivar = ovars[i] - pad_before[i]; - tvm::Expr min = 0; - tvm::Expr extent = t->shape[i]; - - conds.push_back(ivar < min); - conds.push_back(ivar >= min + extent); - indices.push_back(tvm::max(tvm::min(ivar, min + extent - 1), min)); - - if (mode == "reflect") { - // calculate indices for reflect mode - tvm::Expr limit = extent - 1; - tvm::Expr coord = ivar - min; - // Avoid mod zero when tensor shape has 1, - // e.g. input shape is [1, 3, 3] instead of [3, 3] - auto* p_limit = tvm::as_const_int(limit); - if (p_limit != nullptr && *p_limit != 0) - coord = (coord + 2 * limit) % (2 * limit); // avoid negative value - coord = coord - limit; - coord = tvm::abs(coord); - coord = limit - coord; - coord = coord + min; - coords.push_back(coord); - } - } - - if (mode == "reflect") { - return tvm::ir::Select::make(topi::detail::Map(conds, tvm::ir::Or::make), - t(coords), t(indices)); - } else if (mode == "constant") { - return tvm::ir::Select::make(topi::detail::Map(conds, tvm::ir::Or::make), - tvm::make_const(t->dtype, pad_value), t(indices)); - } - - // default mode is edge - return t(indices); - }; - - return tvm::compute(output_shape, l, name); -} - -tvm::Tensor Pad(const tvm::Tensor& t, - const tvm::Array& output_shape, - const tvm::Expr& pad_value, - const std::string& name) { - MTI_ASSERT(t->dtype == pad_value.type()); - - auto l = [&](const tvm::Array& ovars) { - tvm::Array conds; - tvm::Array indices; - - for (size_t i = 0; i < t->shape.size(); ++i) { - tvm::Expr ivar = ovars[i]; - tvm::Expr min = 0; - tvm::Expr extent = t->shape[i]; - - conds.push_back(ivar < min); - conds.push_back(ivar >= min + extent); - indices.push_back(tvm::max(tvm::min(ivar, min + extent - 1), min)); - } - - return tvm::ir::Select::make(topi::detail::Map(conds, tvm::ir::Or::make), - pad_value, t(indices)); - }; - - return tvm::compute(output_shape, l, name); -} - -tvm::Tensor PadLastDim(const tvm::Tensor& t, - const int32_t align_size, - const tvm::Expr& pad_value, - const std::string& name) { - auto input_shape = t->shape; - tvm::Array out_shape; - size_t input_shape_rank = input_shape.size(); - for (size_t i = 0; i < input_shape_rank - 1; ++i) { - out_shape.push_back(input_shape[i]); - } - out_shape.push_back( - (input_shape[input_shape_rank - 1] + align_size - 1) / - align_size * align_size); - - return Pad(t, out_shape, pad_value, name + "_pad"); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/pad_ops.h b/onnxruntime/core/codegen/mti/tensor/pad_ops.h deleted file mode 100644 index 6e8e350d71e97..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/pad_ops.h +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -// ONNX Pad semantics -tvm::Tensor Pad(const tvm::Tensor& t, - const tvm::Array& pad_before, - const tvm::Array& pad_after, - float pad_value = 0.0f, - const std::string& mode = "constant", - const std::string& name = "pad"); - -// Other common Pad interfaces -// Pad for a given shape -tvm::Tensor Pad(const tvm::Tensor& t, - const tvm::Array& output_shape, - const tvm::Expr& pad_value, - const std::string& name = "pad"); - -// Pad for the last dim only. -// This is widely used for weight layout to guard alignment -tvm::Tensor PadLastDim(const tvm::Tensor& t, - const int32_t align_size, - const tvm::Expr& pad_value, - const std::string& name = "pad_last_dim"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/reshape_ops.cc b/onnxruntime/core/codegen/mti/tensor/reshape_ops.cc deleted file mode 100644 index 817fb32c2837a..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/reshape_ops.cc +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/tensor/reshape_ops.h" - -#include "core/codegen/mti/common.h" -#include "core/codegen/mti/mti_tvm_utils.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Flatten(const tvm::Tensor& X, int64_t axis, const std::string& name) { - const auto& input_shape = X->shape; - return Reshape(X, {SizeToDimension(input_shape, axis), SizeFromDimension(input_shape, axis)}, name); -} - -tvm::Tensor Identity(const tvm::Tensor& X, const std::string& name) { - return Reshape(X, X->shape, name); -} - -tvm::Tensor Reshape(const tvm::Tensor& X, const tvm::Array& new_shape, const std::string& name) { - if (new_shape.size() > 0) { - auto X_dim = SizeToDimension(X->shape, X->shape.size()); - auto new_dim = SizeToDimension(new_shape, new_shape.size()); - auto* pX_dim = tvm::as_const_int(X_dim); - auto* pNew_dim = tvm::as_const_int(new_dim); - - if (pX_dim != nullptr && pNew_dim != nullptr) { - MTI_ASSERT(*pX_dim == *pNew_dim); - } - return topi::reshape(X, new_shape, name); - } else { - // generate empty dim tensor with origial input data value - tvm::Array tmp_shape; - tmp_shape.push_back(1); - auto tmp_tensor = topi::reshape(X, tmp_shape); - return tvm::compute( - new_shape, - [&](const tvm::Array&) { - return tmp_tensor[0]; - }, - name); - } -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/reshape_ops.h b/onnxruntime/core/codegen/mti/tensor/reshape_ops.h deleted file mode 100644 index e23d62e4c57b0..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/reshape_ops.h +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Flatten(const tvm::Tensor& X, int64_t axis, const std::string& name = "flatten"); -tvm::Tensor Identity(const tvm::Tensor& X, const std::string& name = "identity"); -tvm::Tensor Reshape(const tvm::Tensor& X, const tvm::Array& new_shape, const std::string& name = "reshape"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/shape_op.cc b/onnxruntime/core/codegen/mti/tensor/shape_op.cc deleted file mode 100644 index b51bd67a8b2dc..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/shape_op.cc +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/tensor/shape_op.h" - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Shape(const tvm::Tensor& X, const std::string& name) { - int ndim = static_cast(X->shape.size()); - tvm::Array out_shape{ndim}; - return tvm::compute( - out_shape, [&](const tvm::Array& indices) { - auto idx = indices[0]; - tvm::Expr ret = 0; - for (int i = 0; i < ndim; ++i) { - ret = tvm::ir::Select::make(idx == i, X->shape[i], ret); - } - return tvm::cast(HalideIR::Int(64), ret); - }, - name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/slice.cc b/onnxruntime/core/codegen/mti/tensor/slice.cc deleted file mode 100644 index 6cbab43584d4b..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/slice.cc +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/tensor/slice.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include -#include -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -// local constexpr for INT_MAX -constexpr int64_t max_range = INT_MAX; - -tvm::Expr position(const tvm::Expr& dim, const tvm::Integer& offset, bool allow_out_of_bound = false) { - if (offset->value >= max_range) { - return allow_out_of_bound ? dim : dim - 1; - } else if (offset->value <= -max_range) { - return tvm::make_const(HalideIR::Int(32), allow_out_of_bound ? -1 : 0); - } else { - if (offset->value >= 0) { - return tvm::ir::Simplify(tvm::ir::Min::make(offset, dim + (allow_out_of_bound ? 0 : -1))); - } else { - return tvm::ir::Simplify(dim + tvm::ir::Max::make(offset, -dim + (allow_out_of_bound ? -1 : 0))); - } - } -} - -tvm::Tensor Slice(const tvm::Tensor& X, - const std::vector& starts, - const std::vector& ends, - const std::vector& axes1, - const std::vector& steps, - const std::string& name) { - MTI_ASSERT(starts.size() == ends.size()); - MTI_ASSERT(starts.size() == axes1.size()); - MTI_ASSERT(starts.size() == steps.size()); - - std::vector axes; - for (const auto& i : axes1) { - axes.push_back(HandleNegativeAxis(i, X->shape.size())); - } - - tvm::Array output_shape; - bool empty = false; - for (int64_t i = 0; i < gsl::narrow(X->shape.size()); ++i) { - auto axes_iter = std::find(axes.begin(), axes.end(), i); - if (axes_iter != axes.end()) { - auto axis = axes_iter - axes.begin(); - tvm::Expr start = position(X->shape[i], starts[axis]); - tvm::Expr end = position(X->shape[i], ends[axis], /*allow_out_of_bound*/ true); - auto dim = tvm::ir::Simplify((end - start + tvm::Integer(steps[axis] + (steps[axis] < 0 ? 1 : -1))) / tvm::Integer(steps[axis])); - auto int_dim = tvm::as_const_int(dim); - if (int_dim && *int_dim <= 0) { - output_shape.push_back(0); - empty = true; - } else { - output_shape.push_back(dim); - } - } else { - output_shape.push_back(X->shape[i]); - } - } - - if (empty) { - return MakeZeroTensor(output_shape, X->dtype, name); - } - - return tvm::compute( - output_shape, - [&](const tvm::Array& ovars) { - tvm::Array ivars; - for (size_t i = 0; i < X->shape.size(); ++i) { - auto axes_iter = std::find(axes.begin(), axes.end(), i); - if (axes_iter != axes.end()) { - auto axis = axes_iter - axes.begin(); - ivars.push_back(tvm::ir::Simplify(ovars[i] * tvm::Integer(steps[axis]) + position(X->shape[i], starts[axis]))); - } else { - ivars.push_back(ovars[i]); - } - } - return X(ivars); - }, - name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/slice.h b/onnxruntime/core/codegen/mti/tensor/slice.h deleted file mode 100644 index ac5c9437791f6..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/slice.h +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Slice(const tvm::Tensor& X, - const std::vector& starts, - const std::vector& ends, - const std::vector& axes, - const std::vector& steps, - const std::string& name = "slice"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/split.cc b/onnxruntime/core/codegen/mti/tensor/split.cc deleted file mode 100644 index 6ee366314858f..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/split.cc +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/tensor/split.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -// Similar to numpy, topi::split takes split indices rather than the -// sizes of the splits. Thus we implement our own. -tvm::Array Split(const tvm::Tensor& X, - const tvm::Array& split_sizes, - int64_t axis, - const std::string& name) { - MTI_ASSERT(axis < gsl::narrow(X->shape.size())); - size_t axis_t = gsl::narrow(axis); - - tvm::Array> output_shapes; - int num_splits = gsl::narrow(split_sizes.size()); - for (auto& s : split_sizes) { - tvm::Array shape; - for (size_t i = 0; i < axis_t; i++) { - shape.push_back(X->shape[i]); - } - shape.push_back(s); - for (size_t i = axis_t + 1; i < X->shape.size(); i++) { - shape.push_back(X->shape[i]); - } - output_shapes.push_back(shape); - } - - tvm::Array res; - int idx = 0; - for (int i_split = 0; i_split < num_splits; ++i_split) { - tvm::Expr s = split_sizes[i_split]; - auto l = [&](const tvm::Array& indices) { - tvm::Array new_indices; - for (size_t i = 0; i < axis_t; i++) { - new_indices.push_back(indices[i]); - } - new_indices.push_back(indices[axis_t] + idx); - for (size_t i = axis_t + 1; i < X->shape.size(); i++) { - new_indices.push_back(indices[i]); - } - MTI_ASSERT(topi::detail::IsConstInt(s)); - MTI_ASSERT(new_indices.size() == X->shape.size()); - int size = topi::detail::GetConstInt(s); - idx += size; - return X(new_indices); - }; - res.push_back(tvm::compute(output_shapes[i_split], l, name)); - } - - MTI_ASSERT(topi::detail::IsConstInt(X->shape[axis_t])); - int size_of_splitted_axis = static_cast(topi::detail::GetConstInt(X->shape[axis_t])); - MTI_ASSERT(idx == size_of_splitted_axis); - return res; -} - -tvm::Array SplitWithIndices(const tvm::Tensor& X, - const tvm::Array& split_sizes, - int64_t axis, - const std::string& name) { - return topi::split(X, split_sizes, gsl::narrow(axis), name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/split.h b/onnxruntime/core/codegen/mti/tensor/split.h deleted file mode 100644 index bcb9c47d936dd..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/split.h +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -// ONNX Split semantics -tvm::Array Split(const tvm::Tensor& X, - const tvm::Array& split_sizes, - int64_t axis, - const std::string& name = "split"); - -// Another common Split interface -// Split with chunck indices -tvm::Array SplitWithIndices(const tvm::Tensor& X, - const tvm::Array& split_sizes, - int64_t axis, - const std::string& name = "split_with_indices"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/tile.cc b/onnxruntime/core/codegen/mti/tensor/tile.cc deleted file mode 100644 index 2fef86adcbaea..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/tile.cc +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/tensor/tile.h" -#include "core/codegen/mti/mti_tvm_utils.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Tile(const tvm::Tensor& t, - const std::vector& repeats, - const std::string& name) { - MTI_ASSERT(repeats.size() == t->shape.size()); - tvm::Array output_shape; - - bool repeats_zero = false; - for (size_t i = 0; i < t->shape.size(); ++i) { - if (repeats[i] == 0) - repeats_zero = true; - output_shape.push_back(t->shape[i] * gsl::narrow(repeats[i])); - } - - auto l = [&](const tvm::Array& ovars) { - if (repeats_zero) - return tvm::make_zero(t->dtype); - - tvm::Array ivars; - for (size_t i = 0; i < t->shape.size(); ++i) { - tvm::Expr ovar = ovars[i]; - ivars.push_back(ovar % t->shape[i]); - } - return t(ivars); - }; - - return tvm::compute(output_shape, l, name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/tile.h b/onnxruntime/core/codegen/mti/tensor/tile.h deleted file mode 100644 index 7ce331fb5ea95..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/tile.h +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Tile(const tvm::Tensor& t, - const std::vector& repeats, - const std::string& name = "tile"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/transpose.cc b/onnxruntime/core/codegen/mti/tensor/transpose.cc deleted file mode 100644 index 873ff8d7f1708..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/transpose.cc +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/tensor/transpose.h" - -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Transpose(const tvm::Tensor& X, const tvm::Array& axes, const std::string& name) { - return topi::transpose(X, axes, name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/transpose.h b/onnxruntime/core/codegen/mti/tensor/transpose.h deleted file mode 100644 index a2a98fedf1e79..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/transpose.h +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Transpose(const tvm::Tensor& X, - const tvm::Array& axes, - const std::string& name = "transpose"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/where.cc b/onnxruntime/core/codegen/mti/tensor/where.cc deleted file mode 100644 index 2bdac3cae7ef5..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/where.cc +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/tensor/where.h" - -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Where(const tvm::Tensor& B, - const tvm::Tensor& X, - const tvm::Tensor& Y, - const std::string& name) { - size_t rank = std::max(std::max(B->shape.size(), X->shape.size()), Y->shape.size()); - tvm::Array output_shape; - for (size_t i = 0; i < rank; ++i) { - tvm::Expr dim = tvm::make_const(HalideIR::Int(32), 1); - bool broadcasted = - BroadcastDim(B->shape, i, rank, dim) && - BroadcastDim(X->shape, i, rank, dim) && - BroadcastDim(Y->shape, i, rank, dim); - MTI_ASSERT(broadcasted); - output_shape.push_back(dim); - } - - return topi::where(topi::broadcast_to(B, output_shape), - topi::broadcast_to(X, output_shape), - topi::broadcast_to(Y, output_shape), - name); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/where.h b/onnxruntime/core/codegen/mti/tensor/where.h deleted file mode 100644 index 68c5288eb3580..0000000000000 --- a/onnxruntime/core/codegen/mti/tensor/where.h +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Where(const tvm::Tensor& B, - const tvm::Tensor& X, - const tvm::Tensor& Y, - const std::string& name = "where"); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/all_ops.h b/onnxruntime/core/codegen/passes/op_ir_creator/all_ops.h deleted file mode 100644 index 1463e50bd72fb..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/all_ops.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/codegen/passes/utils/codegen_context.h" -#include "core/codegen/common/op_macro.h" -#include "core/codegen/passes/op_ir_creator/tvm_op_creator.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// This macro declares a TVM IR builder -// based on ORT OP type with postfix DefaultTVM -#define DECLARE_GENERIC_OP_IR_CREATOR_CLASS(OP) \ - DECLARE_OP_IR_CREATOR_CLASS(OP, DefaultTVM) - -// This macro returns a TVM IR builder class name -// based ORT OP type with postfix DefaultTVM -#define GENERIC_OP_IR_CREATOR_CLASS(OP) \ - CREATOR_CLASS(OP, DefaultTVM##IRCreator) - -#define GENERIC_OP_IR_CREATOR_STRING(OP) \ - STRINGIZE(GENERIC_OP_IR_CREATOR_CLASS(OP)) - -// define all ops for DefaultTVM -#define ADD_OP_ITEM(OP) DECLARE_GENERIC_OP_IR_CREATOR_CLASS(OP) -#define BINARY_OP(OP) ADD_OP_ITEM(OP) -#define BINARY_CMP_OP(OP) ADD_OP_ITEM(OP) -#define POOL_OP(OP) ADD_OP_ITEM(OP) -#define UNARY_OP(OP) ADD_OP_ITEM(OP) -#define VARIADIC_OP(OP) ADD_OP_ITEM(OP) -#define REDUCE_INDEXED_OP(OP) ADD_OP_ITEM(OP) -#define REDUCE_OP(OP) ADD_OP_ITEM(OP) - -LIST_ALL_GENERIC_OPS() - -#undef ADD_OP_ITEM -#undef BINARY_OP -#undef BINARY_CMP_OP -#undef POOL_OP -#undef REDUCE_OP -#undef REDUCE_INDEXED_OP -#undef UNARY_OP -#undef VARIADIC_OP - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/math/binary_ops.cc b/onnxruntime/core/codegen/passes/op_ir_creator/math/binary_ops.cc deleted file mode 100644 index 9452146621ac7..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/math/binary_ops.cc +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/common/op_macro.h" -#include "core/codegen/mti/math/binary_ops.h" -#include "core/codegen/mti/tensor/cast_ops.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// helper local macro defines Evaluate of BINARY_OP OpIRCreators -#define BINARY_OP(name) \ - Status GENERIC_OP_IR_CREATOR_CLASS(name)::Evaluate( \ - const tvm::Array& inputs, \ - const Node& node, \ - CodeGenContext&, \ - tvm::Array& outputs) { \ - tvm::Tensor Y = name(inputs[0], inputs[1], node.Name()); \ - outputs.push_back(Y); \ - return Status::OK(); \ - } - -LIST_BINARY_OPS() - -#undef BINARY_OP - -// helper local macro defines Evaluate of BINARY_CMP_OP OpIRCreators -#define BINARY_CMP_OP(name) \ - Status GENERIC_OP_IR_CREATOR_CLASS(name)::Evaluate( \ - const tvm::Array& inputs, \ - const Node& node, \ - CodeGenContext&, \ - tvm::Array& outputs) { \ - tvm::Tensor Y = Cast(name(inputs[0], inputs[1], node.Name()), HalideIR::UInt(8), "cast_bool_" #name); \ - outputs.push_back(Y); \ - return Status::OK(); \ - } - -LIST_BINARY_CMP_OPS() - -#undef BINARY_CMP_OP - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/math/clip.cc b/onnxruntime/core/codegen/passes/op_ir_creator/math/clip.cc deleted file mode 100644 index bb33e6e70accf..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/math/clip.cc +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/math/unary_ops.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Clip OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Clip)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper info(&ctx); - - int version = ctx_codegen.GetCodeGenHandle()->domain_version_lookup_func(node.Domain()); - tvm::Expr min_value, max_value; - if (version < 11) { - float max_v, min_v; - info.GetAttrOrDefault("min", &min_v, std::numeric_limits::lowest()); - info.GetAttrOrDefault("max", &max_v, std::numeric_limits::max()); - min_value = tvm::make_const(tvm::Float(32), min_v); - max_value = tvm::make_const(tvm::Float(32), max_v); - } else { - // for op_version >= 11, max and min are optional inputs - min_value = tvm::make_const(tvm::Float(32), std::numeric_limits::lowest()); - max_value = tvm::make_const(tvm::Float(32), std::numeric_limits::max()); - auto num_inputs = inputs.size(); - if (num_inputs >= 2 && inputs[1].defined()) { - min_value = inputs[1](); - } - if (num_inputs == 3 && inputs[2].defined()) { - max_value = inputs[2](); - } - } - - tvm::Tensor Y = Clip(inputs[0], min_value, max_value, node.Name() + "_Clip"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/math/gemm.cc b/onnxruntime/core/codegen/passes/op_ir_creator/math/gemm.cc deleted file mode 100644 index 64f995076e1bb..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/math/gemm.cc +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/math/gemm.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Gemm OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Gemm)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& /*ctx_codegen*/, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper attrs(&ctx); - - tvm::Tensor A = inputs[0]; - tvm::Tensor B = inputs[1]; - tvm::Tensor C = inputs[2]; - - int64_t trans_A, trans_B; - ORT_RETURN_IF_ERROR(attrs.GetAttr("transA", &trans_A)); - ORT_RETURN_IF_ERROR(attrs.GetAttr("transB", &trans_B)); - - float alpha, beta; - ORT_ENFORCE(attrs.GetAttr("alpha", &alpha).IsOK()); - ORT_ENFORCE(attrs.GetAttr("beta", &beta).IsOK()); - - tvm::Tensor Y = Gemm(A, B, C, trans_A != 0, trans_B != 0, alpha, beta, node.Name() + "_Gemm"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/math/logsoftmax.cc b/onnxruntime/core/codegen/passes/op_ir_creator/math/logsoftmax.cc deleted file mode 100644 index cb09518bf63d1..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/math/logsoftmax.cc +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/math/logsoftmax.h" -#include "core/framework/op_kernel_info.h" -#include "core/providers/common.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of LogSoftmax OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(LogSoftmax)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper info(&ctx); - - int64_t axis_i64; - ORT_RETURN_IF_ERROR(info.GetAttr("axis", &axis_i64)); - axis_i64 = HandleNegativeAxis(axis_i64, gsl::narrow_cast(inputs[0]->shape.size())); - - tvm::Tensor Y = LogSoftmax(inputs[0], axis_i64, node.Name() + "_LogSoftmax"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/math/matmul.cc b/onnxruntime/core/codegen/passes/op_ir_creator/math/matmul.cc deleted file mode 100644 index ab1ac237bfa5d..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/math/matmul.cc +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/math/matmul_ops.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of MatMul OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(MatMul)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext&, - tvm::Array& outputs) { - tvm::Tensor Y = MatMul(inputs[0], inputs[1], node.Name() + "_MatMul"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/math/quantize/matmul_integer.cc b/onnxruntime/core/codegen/passes/op_ir_creator/math/quantize/matmul_integer.cc deleted file mode 100644 index 6f66b1f1a2afb..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/math/quantize/matmul_integer.cc +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/math/binary_ops.h" -#include "core/codegen/mti/math/matmul_ops.h" -#include "core/codegen/mti/tensor/cast_ops.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of MatMulInteger OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(MatMulInteger)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - const auto& A = inputs[0]; - const auto& B = inputs[1]; - auto& name = node.Name(); - - // A generic path, cast to int32 - // Support skipped trailing inputs - auto A_Int32 = (node.InputDefs().size() >= 3 && node.InputDefs()[2]->Exists()) - ? Sub(Cast(A, HalideIR::Int(32)), Cast(inputs[2], HalideIR::Int(32))) - : Cast(A, HalideIR::Int(32)); - auto B_Int32 = (node.InputDefs().size() >= 4 && node.InputDefs()[3]->Exists()) - ? Sub(Cast(B, HalideIR::Int(32)), Cast(inputs[3], HalideIR::Int(32))) - : Cast(B, HalideIR::Int(32)); - tvm::Tensor Y = MatMul(A_Int32, B_Int32, name + "_MatMulInteger"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/math/reduce_ops.cc b/onnxruntime/core/codegen/passes/op_ir_creator/math/reduce_ops.cc deleted file mode 100644 index f29a3f3e7cdf7..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/math/reduce_ops.cc +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/common/op_macro.h" -#include "core/codegen/mti/math/reduce_ops.h" -#include "core/codegen/mti/tensor/cast_ops.h" -#include "core/codegen/mti/tensor/reshape_ops.h" -#include "core/framework/op_kernel_info.h" -#include "core/providers/common.h" - -namespace onnxruntime { -namespace tvm_codegen { - -using ReduceIndexedFunc = tvm::Tensor (*)(const tvm::Tensor& X, int64_t axis, bool keep_dims, const std::string& name); -using ReduceFunc = tvm::Tensor (*)(const tvm::Tensor& X, const std::vector& axes, bool keep_dims, const std::string& name); - -// helper class for for REDUCE_INDEXED_OP -class FuncReduceIndexed { - public: - FuncReduceIndexed(const Node& node, ReduceIndexedFunc func, const std::string& name) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper info(&ctx); - axis_ = info.GetAttrOrDefault("axis", 0); - int64_t keepdims_i = 1; - ORT_ENFORCE(info.GetAttr("keepdims", &keepdims_i).IsOK()); - keep_dims_ = (keepdims_i == 1); - func_ = func; - name_ = name; - } - - tvm::Tensor operator()(const tvm::Tensor& X) const { - auto axis = HandleNegativeAxis(axis_, gsl::narrow_cast(X->shape.size())); - tvm::Tensor index32 = func_(X, axis, keep_dims_, name_); - return Cast(index32, tvm::Int(64)); - } - - private: - int64_t axis_; - bool keep_dims_; - ReduceIndexedFunc func_; - std::string name_; -}; - -// helper class for REDUCE_OP -class FuncReduce { - public: - FuncReduce(const Node& node, ReduceFunc func, const std::string& name) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper info(&ctx); - axes_ = info.GetAttrsOrDefault("axes"); - int64_t keepdims_i = 1; - ORT_ENFORCE(info.GetAttr("keepdims", &keepdims_i).IsOK()); - keep_dims_ = (keepdims_i == 1); - func_ = func; - name_ = name; - } - - tvm::Tensor operator()(const tvm::Tensor& X) const { - std::vector axes; - for (auto i : axes_) - axes.push_back(HandleNegativeAxis(i, gsl::narrow_cast(X->shape.size()))); - - return func_(X, axes, keep_dims_, name_); - } - - private: - std::vector axes_; - bool keep_dims_; - ReduceFunc func_; - std::string name_; -}; - -// helper macro defines Evaluate of REDUCE_OP OpIRCreators -#define REDUCE_OP(name) \ - Status GENERIC_OP_IR_CREATOR_CLASS(name)::Evaluate( \ - const tvm::Array& inputs, \ - const Node& node, \ - CodeGenContext&, \ - tvm::Array& outputs) { \ - tvm::Tensor Y; \ - if (ShapeRank(node.OutputDefs()[0]) == 0) { \ - tvm::Tensor temp = FuncReduce(node, &name, #name)(inputs[0]); \ - Y = Reshape(temp, {}); \ - } else { \ - Y = FuncReduce(node, &name, #name)(inputs[0]); \ - } \ - outputs.push_back(Y); \ - return Status::OK(); \ - } - -// helper macro defines Evaluate of REDUCE_INDEXED_OP OpIRCreators -#define REDUCE_INDEXED_OP(name) \ - Status GENERIC_OP_IR_CREATOR_CLASS(name)::Evaluate( \ - const tvm::Array& inputs, \ - const Node& node, \ - CodeGenContext&, \ - tvm::Array& outputs) { \ - tvm::Tensor Y = FuncReduceIndexed(node, &name, #name)(inputs[0]); \ - outputs.push_back(Y); \ - return Status::OK(); \ - } - -LIST_REDUCE_OPS() - -#undef REDUCE_OP -#undef REDUCE_INDEXED_OP - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/math/softmax.cc b/onnxruntime/core/codegen/passes/op_ir_creator/math/softmax.cc deleted file mode 100644 index 7b13de5a94e48..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/math/softmax.cc +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/math/softmax.h" -#include "core/framework/op_kernel_info.h" -#include "core/providers/common.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Softmax OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Softmax)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper info(&ctx); - - int64_t axis_i64; - ORT_RETURN_IF_ERROR(info.GetAttr("axis", &axis_i64)); - - axis_i64 = HandleNegativeAxis(axis_i64, gsl::narrow_cast(inputs[0]->shape.size())); - tvm::Tensor Y = Softmax(inputs[0], axis_i64, node.Name() + "_Softmax"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/math/unary_funcs.h b/onnxruntime/core/codegen/passes/op_ir_creator/math/unary_funcs.h deleted file mode 100644 index 29e6519af0ef1..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/math/unary_funcs.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { -// helper class for unary_ops with alpha -class FuncWithAlpha { - public: - FuncWithAlpha(const Node& node) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper attrs(&ctx); - ORT_ENFORCE(attrs.GetAttr("alpha", &alpha_).IsOK()); - } - - protected: - float alpha_; -}; - -// helper class for unary_ops with alpha and beta -class FuncWithAlphaBeta { - public: - FuncWithAlphaBeta(const Node& node) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper attrs(&ctx); - ORT_ENFORCE(attrs.GetAttr("alpha", &alpha_).IsOK()); - ORT_ENFORCE(attrs.GetAttr("beta", &beta_).IsOK()); - } - - protected: - float alpha_; - float beta_; -}; - -// helper class for unary_ops with alpha and gamma -class FuncWithAlphaGamma { - public: - FuncWithAlphaGamma(const Node& node) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper attrs(&ctx); - ORT_ENFORCE(attrs.GetAttr("alpha", &alpha_).IsOK()); - ORT_ENFORCE(attrs.GetAttr("gamma", &gamma_).IsOK()); - } - - protected: - float alpha_; - float gamma_; -}; -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/math/unary_ops.cc b/onnxruntime/core/codegen/passes/op_ir_creator/math/unary_ops.cc deleted file mode 100644 index 0407c0a06abf6..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/math/unary_ops.cc +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/common/op_macro.h" -#include "core/codegen/mti/math/unary_ops.h" -#include "core/codegen/passes/op_ir_creator/math/unary_funcs.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// helper macro declares unary_ops helper class without attribute -#define FuncClass(name) \ - class Func##name { \ - public: \ - Func##name(const Node&) {} \ - tvm::Tensor operator()(const tvm::Tensor& X) const { \ - return name(X); \ - } \ - } - -// helper macro declares unary_ops helper class with alpha -#define FuncClassAlpha(name) \ - class Func##name : public FuncWithAlpha { \ - public: \ - Func##name(const Node& node) : FuncWithAlpha(node) {} \ - tvm::Tensor operator()(const tvm::Tensor& X) const { \ - return name(X, alpha_); \ - } \ - } - -// helper macro declares unary_ops helper class with alpha and beta -#define FuncClassAlphaBeta(name) \ - class Func##name : public FuncWithAlphaBeta { \ - public: \ - Func##name(const Node& node) : FuncWithAlphaBeta(node) {} \ - tvm::Tensor operator()(const tvm::Tensor& X) const { \ - return name(X, alpha_, beta_); \ - } \ - } - -// helper macro declares unary_ops helper class with alpha and gamma -#define FuncClassAlphaGamma(name) \ - class Func##name : public FuncWithAlphaGamma { \ - public: \ - Func##name(const Node& node) : FuncWithAlphaGamma(node) {} \ - tvm::Tensor operator()(const tvm::Tensor& X) const { \ - return name(X, alpha_, gamma_); \ - } \ - } - -FuncClass(Abs); -FuncClassAlphaBeta(Affine); -FuncClass(Ceil); -FuncClassAlpha(Elu); -FuncClass(Exp); -FuncClass(Floor); -FuncClassAlphaBeta(HardSigmoid); -FuncClassAlpha(LeakyRelu); -FuncClass(Log); -FuncClass(Neg); -FuncClassAlphaBeta(ParametricSoftplus); -FuncClass(Reciprocal); -FuncClass(Relu); -FuncClassAlphaBeta(ScaledTanh); -FuncClassAlphaGamma(Selu); -FuncClass(Sigmoid); -FuncClass(Softplus); -FuncClass(Softsign); -FuncClass(Sqrt); -FuncClass(Tanh); -FuncClassAlpha(ThresholdedRelu); - -// helper macro defines Evaluate of UNARY_OP OpIRCreators -#define UNARY_OP(name) \ - Status GENERIC_OP_IR_CREATOR_CLASS(name)::Evaluate( \ - const tvm::Array& inputs, \ - const Node& node, \ - CodeGenContext&, \ - tvm::Array& outputs) { \ - tvm::Tensor Y = Func##name(node)(inputs[0]); \ - outputs.push_back(Y); \ - return Status::OK(); \ - } - -// helper local macros to replace some calls in LIST_UNARY_OPS -LIST_UNARY_OPS() - -#undef UNARY_OP - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/math/variadic_ops.cc b/onnxruntime/core/codegen/passes/op_ir_creator/math/variadic_ops.cc deleted file mode 100644 index 9559a713c2876..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/math/variadic_ops.cc +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/math/binary_ops.h" -#include "core/codegen/mti/tensor/reshape_ops.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -tvm::Tensor Sum(const tvm::Tensor& lhs, const tvm::Tensor& rhs, const std::string& name) { - return Add(lhs, rhs, name); -} - -// helper local macro defines Evaluate of BINARY_OP OpIRCreators -#define VARIADIC_OP(name) \ - Status GENERIC_OP_IR_CREATOR_CLASS(name)::Evaluate( \ - const tvm::Array& inputs, \ - const Node& node, \ - CodeGenContext&, \ - tvm::Array& outputs) { \ - tvm::Tensor Y = Identity(inputs[0], node.Name() + "0"); \ - for (size_t i = 1; i < inputs.size(); ++i) \ - Y = name(Y, inputs[i], node.Name() + std::to_string(i)); \ - outputs.push_back(Y); \ - return Status::OK(); \ - } - -LIST_VARIADIC_OPS() - -#undef VARIADIC_OP - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/nn/conv.cc b/onnxruntime/core/codegen/passes/op_ir_creator/nn/conv.cc deleted file mode 100644 index 19545d1554405..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/nn/conv.cc +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/nn/conv_ops.h" -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/tensor/concat_ops.h" -#include "core/codegen/mti/tensor/split.h" -#include "core/codegen/passes/utils/ort_tvm_utils.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -Status GENERIC_OP_IR_CREATOR_CLASS(Conv)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper info(&ctx); - - // Attributes - int64_t group; - std::string auto_pad; - std::vector kernel_shape, strides, dilations, pads; - - info.GetAttrOrDefault("group", &group, 1); - info.GetAttrOrDefault("auto_pad", &auto_pad, "NOTSET"); - - ORT_THROW_IF_ERROR(info.GetAttrs("kernel_shape", kernel_shape)); - ORT_ENFORCE(kernel_shape.size() <= 2, "Only support 1D/2D convolution currently!"); - ORT_THROW_IF_ERROR(info.GetAttrs("strides", strides)); - - dilations = info.GetAttrs("dilations", dilations).IsOK() ? dilations : std::vector(kernel_shape.size(), 1); - ORT_ENFORCE(dilations == std::vector(kernel_shape.size(), 1), "Only support dilation is 1 currently"); - - pads = info.GetAttrs("pads", pads).IsOK() ? pads : std::vector(kernel_shape.size() * 2, 0); - - // auto_pad - if (auto_pad != "NOTSET") { - auto rank = inputs[0]->shape.size() - 2; - ORT_ENFORCE(rank > 0); - for (uint64_t i = 0; i < rank; i++) { - if (auto_pad == "VALID") { - pads[i] = 0; - pads[i + rank] = 0; - } else if (auto_pad == "SAME_UPPER" || auto_pad == "SAME_LOWER") { - // TODO: handle symbolic dim - ORT_ENFORCE(ShapeHasValue(node.InputDefs()[0], 2 + i)); - - int64_t input_dim_value = ShapeValue(node.InputDefs()[0], 2 + i); - int64_t output_dim_value = (input_dim_value + strides[i] - 1) / strides[i]; - int64_t pad_needed = (output_dim_value - 1) * strides[i] + kernel_shape[i] - input_dim_value; - - pads[i] = auto_pad == "SAME_LOWER" ? (pad_needed + 1) / 2 : pad_needed / 2; - pads[i + rank] = pad_needed - pads[i]; - } else { - ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unknown auto_pad value"); - } - } - } - - // Inputs - tvm::Tensor X = inputs[0]; - tvm::Tensor W = inputs[1]; - // Outputs - tvm::Tensor Y; - tvm::Array Y_shape = ShapeToTvmArray(node.OutputDefs()[0], ctx_codegen); - - // 1-D convolution - if (kernel_shape.size() == 1) { - Y = Conv1D(X, W, Y_shape, ToTvmArray(strides), ToTvmArray(pads), node.Name() + "_Conv1D"); - } - // 2-D convolution - else if (kernel_shape.size() == 2) { - if (group == 1) { - Y = Conv2D(X, W, Y_shape, ToTvmArray(strides), ToTvmArray(pads), node.Name() + "_Conv2D"); - } else { - int64_t channel_out = ShapeValue(node.InputDefs()[1], 0); - int64_t channel_in = ShapeValue(node.InputDefs()[1], 1); - ORT_ENFORCE(channel_out % group == 0); - - int64_t cout_group = channel_out / group; - Y_shape.Set(1, Y_shape[1] / gsl::narrow_cast(group)); - - tvm::Array split_index0; - tvm::Array split_index1; - - for (int i = 1; i < group; i++) { - split_index0.push_back(i * channel_in); - split_index1.push_back(i * cout_group); - } - - auto input_groups = SplitWithIndices(X, split_index0, 1); - auto weight_groups = SplitWithIndices(W, split_index1, 0); - - // FIXME: This will trigger a llvm buffer overflow when group is too large - // TODO: fix this change it to batched gemm/conv - tvm::Array output_tensors; - for (int i = 0; i < group; i++) { - auto output_tensor = Conv2D(input_groups[i], - weight_groups[i], - Y_shape, - ToTvmArray(strides), - ToTvmArray(pads), - node.Name() + "_Conv2D"); - output_tensors.push_back(output_tensor); - } - Y = Concat(output_tensors, 1); - } - } - - // Add bias if provided - // Support skipped trailing inputs - if (node.InputDefs().size() > 2 && node.InputDefs()[2]->Exists()) { - tvm::Tensor B = inputs[2]; - Y = tvm::compute( - Y_shape, - [&](const tvm::Array& indices) { - return Y(indices) + B(indices[1]); - }); - } - - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/nn/lstm.cc b/onnxruntime/core/codegen/passes/op_ir_creator/nn/lstm.cc deleted file mode 100644 index 88170bb56dd2d..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/nn/lstm.cc +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/nn/lstm.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// In the cell computation, we don't have the "direction" dimension and sequence dimension, -// which have been processed outside of the cell. -// Here we implement an LTSM cell. -// For those args (inputs/outputs) of hidden states we put AFTER regular args (inputs/outputs) -// with a pre-defined order -// In a LSTM, the order is H and then C. -// Ouputs of LSTM is Y_h and then Y_c -Status GENERIC_OP_IR_CREATOR_CLASS(LSTM)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper attrs(&ctx); - - std::string direction_attr; - ORT_RETURN_IF_ERROR(attrs.GetAttr("direction", &direction_attr)); - int64_t hidden_size; - ORT_RETURN_IF_ERROR(attrs.GetAttr("hidden_size", &hidden_size)); - - // input tensor with shape [seq_length, batch_size, input_size] - const tvm::Tensor& X = inputs[0]; // input tensor with shape [seq_length, batch_size, input_size] - const tvm::Tensor& W = inputs[1]; // weights tensor with shape [4*hidden_size, input_size] - const tvm::Tensor& R = inputs[2]; // recurrence tensor with shape [4*hidden_size, hidden_size] - const tvm::Tensor& B = inputs[3]; // optional bias tensor with shape [8*hidden_size] - bool has_B = node.InputDefs()[3]->Exists(); - - // Unsupported the 4th inputs - // optional tensor specifying sequence lengths in a batch, shape: [batch_size] - // const tvm::Tensor* seq_len = inputs[4] ? &inputs[4]->tensor : nullptr; - - const tvm::Tensor& prev_H = inputs[5]; // optional initial H, shape: [batch_size, hidden_size] - const tvm::Tensor& prev_C = inputs[6]; // optional initial C, shape: [batch_size, hidden_size] - - const tvm::Tensor& P = inputs[7]; // optional peepholes tensor with shape [3*hidde_size] - bool has_P = node.InputDefs()[7]->Exists(); - - tvm::Tensor Y_h; // shape: [batch_size, hidden_size] - tvm::Tensor Y_c; // shape: [batch_size, hidden_size] - LSTMAttributes lstm_attrs(hidden_size); - LSTM_cell(lstm_attrs, X, W, R, B, has_B, prev_H, prev_C, P, has_P, Y_h, Y_c); - - // Since we only generate lstm cell, lstm's states need to be always outputs, - // regardless whethere they are skipped or not. - // The skipped trailing outputs need to be handled by Execution - outputs.push_back(Y_h); - outputs.push_back(Y_c); - - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/nn/pool_ops.cc b/onnxruntime/core/codegen/passes/op_ir_creator/nn/pool_ops.cc deleted file mode 100644 index 84d3b7c1e0f79..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/nn/pool_ops.cc +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/nn/pool_ops.h" -#include "core/framework/op_kernel_info.h" -#include "core/providers/cpu/nn/pool_attributes.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// A local macro to create Pool Ops - -// helper macro defines Evaluate of of POOL_OP OpIRCreators -#define POOL_OP(name) \ - Status GENERIC_OP_IR_CREATOR_CLASS(name)::Evaluate( \ - const tvm::Array& inputs, \ - const Node& node, \ - CodeGenContext& ctx_codegen, \ - tvm::Array& outputs) { \ - ORT_RETURN_IF_NOT(outputs.size() == 1, "multiple outputs are not supported yet!"); \ - ProtoHelperNodeContext ctx(node); \ - OpNodeProtoHelper info(&ctx); \ - int version = ctx_codegen.GetCodeGenHandle()->domain_version_lookup_func(node.Domain()); \ - PoolAttributes pool_attrs(info, #name, version); \ - for (auto n : pool_attrs.dilations) { \ - ORT_RETURN_IF_NOT(n <= 1, "dilations are not supported yet!"); \ - } \ - if (pool_attrs.global_pooling) { \ - if (inputs[0]->shape.size() != 4) { \ - ORT_NOT_IMPLEMENTED(gsl::narrow_cast(inputs[0]->shape.size()) - 2, "d global pooling is not implementated"); \ - } \ - } else { \ - if (pool_attrs.kernel_shape.size() != 2) { \ - ORT_NOT_IMPLEMENTED(pool_attrs.kernel_shape.size(), "d pooling is not implementated"); \ - } \ - } \ - tvm::Array dummy_output_shape; \ - tvm::Tensor Y = name(inputs[0], pool_attrs, dummy_output_shape); \ - outputs.push_back(Y); \ - return Status::OK(); \ - } - -LIST_POOL_OPS() - -#undef POOL_OP - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/cast.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/cast.cc deleted file mode 100644 index bd324fd359edf..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/cast.cc +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/tensor/cast_ops.h" -#include "core/codegen/passes/utils/ort_tvm_utils.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Cast OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Cast)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext&, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper attrs(&ctx); - - int64_t to; - ORT_RETURN_IF_ERROR(attrs.GetAttr("to", &to)); - auto to_type_proto = gsl::narrow_cast(to); - - tvm::Tensor X = inputs[0]; - tvm::Tensor Y; - if (to_type_proto == ONNX_NAMESPACE::TensorProto_DataType_BOOL) { - // special case for bool as ONNX bool is uint8, while in tvm it's uint1 - Y = CastToUInt8Bool(X, node.Name() + "_Cast"); - } else { - Y = Cast(X, ToTvmType(to_type_proto), node.Name() + "_Cast"); - } - - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/concat.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/concat.cc deleted file mode 100644 index 418296889419e..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/concat.cc +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/tensor/concat_ops.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Concat OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Concat)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext&, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper info(&ctx); - - int64_t axis; - ORT_RETURN_IF_ERROR(info.GetAttr("axis", &axis)); - - tvm::Tensor Y = Concat(inputs, axis, node.Name() + "_Concat"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/crop.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/crop.cc deleted file mode 100644 index 3b6a9a76f0723..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/crop.cc +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/tensor/crop.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Crop OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Crop)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext&, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper attrs(&ctx); - - if (inputs[0]->shape.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input is expected to have four dimensions corresponding to [N,C,H,W]"); - } - - std::vector border; - std::vector scale; - - ORT_ENFORCE(attrs.GetAttrs("border", border).IsOK()); - // scale is optional and status is false when omit - bool is_ok = attrs.GetAttrs("scale", scale).IsOK(); - ORT_UNUSED_PARAMETER(is_ok); - - if (border.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Attribute border needs to be specified with four border elements"); - } - - tvm::Tensor Y = Crop(inputs[0], ToTvmArray(border), ToTvmArray(scale), node.Name() + "_Crop"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/expand.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/expand.cc deleted file mode 100644 index 0f0e0cf0987b3..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/expand.cc +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/tensor/expand.h" -#include "core/codegen/passes/utils/ort_tvm_utils.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Expand OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Expand)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - tvm::Tensor Y = Expand(inputs[0], ShapeToTvmArray(node.OutputDefs()[0], ctx_codegen), node.Name() + "_Expand"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/gather.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/gather.cc deleted file mode 100644 index 3a5d801b6839f..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/gather.cc +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/tensor/gather.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Gather OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Gather)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext&, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper attrs(&ctx); - - int64_t axis; - ORT_ENFORCE(attrs.GetAttr("axis", &axis).IsOK()); - - tvm::Tensor Y = Gather(inputs[0], axis, inputs[1], node.Name() + "_Gather"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/gather_elements.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/gather_elements.cc deleted file mode 100644 index 0b71506cceed3..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/gather_elements.cc +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/tensor/gather_elements.h" -#include "core/framework/op_kernel_info.h" -#include "core/providers/common.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of GatherElements OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(GatherElements)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext&, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper attrs(&ctx); - - int64_t axis; - ORT_ENFORCE(attrs.GetAttr("axis", &axis).IsOK()); - axis = HandleNegativeAxis(axis, gsl::narrow_cast(inputs[0]->shape.size())); - - tvm::Tensor Y = GatherElements(inputs[0], axis, inputs[1], node.Name() + "_GatherElements"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/pad.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/pad.cc deleted file mode 100644 index e9e20e8a43998..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/pad.cc +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/tensor/pad_ops.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Pad OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Pad)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext&, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper attrs(&ctx); - - std::string mode; - std::vector pads; - float value; - - ORT_THROW_IF_ERROR(attrs.GetAttr("mode", &mode)); - ORT_THROW_IF_ERROR(attrs.GetAttrs("pads", pads)); - ORT_THROW_IF_ERROR(attrs.GetAttr("value", &value)); - - if (mode != "constant" && mode != "edge" && mode != "reflect") - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Pad: Unsupported padding mode!"); - - if (pads.size() != 2 * inputs[0]->shape.size()) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Pad: pads rank does not match inputs rank!"); - - std::vector pad_before, pad_after; - size_t offset = pads.size() / 2; - for (size_t i = 0; i < offset; i++) { - pad_before.push_back(pads[i]); - pad_after.push_back(pads[i + offset]); - } - - tvm::Tensor Y = Pad(inputs[0], ToTvmArray(pad_before), ToTvmArray(pad_after), value, mode, node.Name() + "_Pad"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/reshape_ops.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/reshape_ops.cc deleted file mode 100644 index a83f598bc8ad1..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/reshape_ops.cc +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/tensor/reshape_ops.h" -#include "core/codegen/passes/utils/ort_tvm_utils.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Dropout OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Dropout)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - tvm::Tensor Y = Identity(inputs[0]); - outputs.push_back(Y); - - // optional mask - // Support skipped trailing outputs - if (node.OutputDefs().size() > 1 && node.OutputDefs()[1]->Exists()) { - // A fake mask with all ones - auto l = [&](const tvm::Array& /*indices*/) { - return tvm::make_const(tvm::UInt(8), 1); - }; - tvm::Tensor mask = tvm::compute(inputs[0]->shape, l, "mask"); - outputs.push_back(mask); - } - - return Status::OK(); -} - -// Evaluate of Flatten OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Flatten)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext&, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper attrs(&ctx); - - int64_t axis; - ORT_RETURN_IF_ERROR(attrs.GetAttr("axis", &axis)); - - tvm::Tensor Y = Flatten(inputs[0], axis, node.Name() + "_Flatten"); - outputs.push_back(Y); - return Status::OK(); -} - -// Evaluate of Identity OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Identity)::Evaluate( - const tvm::Array& inputs, - const Node&, - CodeGenContext&, - tvm::Array& outputs) { - tvm::Tensor Y = Identity(inputs[0]); - outputs.push_back(Y); - return Status::OK(); -} - -// Evaluate of Reshape OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Reshape)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - tvm::Tensor Y = Reshape(inputs[0], ShapeToTvmArray(node.OutputDefs()[0], ctx_codegen), node.Name() + "_Reshape"); - outputs.push_back(Y); - return Status::OK(); -} - -// Evaluate of Squeeze OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Squeeze)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - tvm::Tensor Y = Reshape(inputs[0], ShapeToTvmArray(node.OutputDefs()[0], ctx_codegen), node.Name() + "_Squeeze"); - outputs.push_back(Y); - return Status::OK(); -} - -// Evaluate of Unsqueeze OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Unsqueeze)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - tvm::Tensor Y = Reshape(inputs[0], ShapeToTvmArray(node.OutputDefs()[0], ctx_codegen), node.Name() + "_Unsqueeze"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/shape_op.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/shape_op.cc deleted file mode 100644 index 84761ecac1397..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/shape_op.cc +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/tensor/shape_op.h" -#include "core/codegen/passes/utils/ort_tvm_utils.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Expand OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Shape)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - tvm::Tensor Y = Shape(inputs[0], node.Name() + "_Expand"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/slice.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/slice.cc deleted file mode 100644 index 6a016580c41e4..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/slice.cc +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" -#include "core/codegen/passes/utils/ort_tvm_utils.h" -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/tensor/slice.h" -#include "core/framework/op_kernel_info.h" -#include "core/framework/tensorprotoutils.h" - -#include - -namespace onnxruntime { -namespace tvm_codegen { - -Status SliceCommon(const tvm::Array& inputs, - const Node& node, - tvm::Array& outputs, - const std::vector& starts, - const std::vector& ends, - const std::vector& axes1, - const std::vector& steps1) { - ORT_RETURN_IF_NOT(nullptr != node.InputDefs()[0], "nullptr == node.InputDefs()[0]"); - - std::vector axes; - if (axes1.size() == 0) { - for (size_t i = 0; i < starts.size(); ++i) { - axes.push_back(gsl::narrow_cast(i)); - } - } else { - axes = axes1; - } - - std::vector steps; - if (steps1.size() == 0) { - steps.resize(starts.size(), 1); - } else { - steps = steps1; - } - - tvm::Tensor Y = Slice(inputs[0], starts, ends, axes, steps, node.Name() + "_Slice"); - outputs.push_back(Y); - return Status::OK(); -} - -// Evaluate of Slice OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Slice)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper info(&ctx); - - // NOTE that in opset 10, Slice has changed starts/ends/axes from attribute to input - // which may lead to dynamic output shape. - int version = ctx_codegen.GetCodeGenHandle()->domain_version_lookup_func(node.Domain()); - ORT_RETURN_IF_NOT(version <= 9, "Dynamic Slice is not supported yet"); - - std::vector starts, ends, steps; - ORT_RETURN_IF_ERROR(info.GetAttrs("starts", starts)); - ORT_RETURN_IF_ERROR(info.GetAttrs("ends", ends)); - ORT_RETURN_IF_NOT(starts.size() == ends.size(), "starts.size() != ends.size()"); - - auto axes = info.GetAttrsOrDefault("axes"); - - return SliceCommon(inputs, node, outputs, starts, ends, axes, steps); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/split.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/split.cc deleted file mode 100644 index ec52d98b5bf96..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/split.cc +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/tensor/split.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Split OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Split)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext&, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper info(&ctx); - - int64_t axis; - ORT_RETURN_IF_ERROR(info.GetAttr("axis", &axis)); - axis = HandleNegativeAxis(axis, gsl::narrow_cast(inputs[0]->shape.size())); - std::vector split_sizes; - - int64_t split_size_sum = 0; - if (info.GetAttrs("split", split_sizes).IsOK()) { - // optional - split_size_sum = std::accumulate(split_sizes.cbegin(), split_sizes.cend(), 0LL); - ORT_RETURN_IF_NOT(std::all_of(split_sizes.cbegin(), split_sizes.cend(), [](int64_t value) { return value > 0; }), - "Invalid value in 'split' attribute. All values must be > 0"); - - // check split sizes - for (size_t i = 0; i < node.OutputDefs().size(); ++i) { - ORT_RETURN_IF_NOT(split_sizes[i] == ShapeValue(node.OutputDefs()[i], gsl::narrow(axis)), - "split_sizes[i] != ShapeValue(node.OutputDefs()[i], axis)"); - } - - } else { - for (size_t i = 0; i < node.OutputDefs().size(); ++i) { - split_sizes.push_back(ShapeValue(node.OutputDefs()[i], gsl::narrow(axis))); - split_size_sum += split_sizes[i]; - } - } - - // check total size - if (ShapeHasValue(node.InputDefs()[0], axis)) { - int64_t input_axis_dim = ShapeValue(node.InputDefs()[0], axis); - if (split_size_sum != input_axis_dim) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "Cannot split using values in 'split' attribute. Axis=", axis, - " Dim being splitted=", input_axis_dim, - " Sum of sizes in 'split' (must equal size of selected axis) was ", split_size_sum); - } - } - - tvm::Array output_tensors = Split(inputs[0], ToTvmArray(split_sizes), axis, node.Name() + "_Split"); - for (size_t i = 0; i < node.OutputDefs().size(); ++i) { - outputs.push_back(output_tensors[i]); - } - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/transpose.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/transpose.cc deleted file mode 100644 index 43999ebd1f465..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/transpose.cc +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/tensor/transpose.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Transpose OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Transpose)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext&, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper attrs(&ctx); - - size_t input_0_shape_rank = inputs[0]->shape.size(); - std::vector permute; - bool is_ok = attrs.GetAttrs("perm", permute).IsOK(); - if (permute.size() != 0 && permute.size() != input_0_shape_rank) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Transpose: Incorrect permute size"); - - std::vector default_permute; - const std::vector* perm; - // either we don't have perm attribute or the perm attribute is empty - bool use_default_perm = !is_ok || permute.size() == 0; - if (use_default_perm) { - default_permute.resize(input_0_shape_rank); - for (size_t i = 0; i < input_0_shape_rank; ++i) { - default_permute[i] = gsl::narrow(input_0_shape_rank - 1 - i); - } - perm = &default_permute; - } else { - perm = &permute; - } - - tvm::Tensor Y = Transpose(inputs[0], ToTvmArrayInt(*perm), node.Name() + "_Transpose"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/where.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tensor/where.cc deleted file mode 100644 index 9d6df7c1c430d..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tensor/where.cc +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/all_ops.h" - -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/mti/tensor/where.h" -#include "core/framework/op_kernel_info.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Evaluate of Transpose OpIRCreator -Status GENERIC_OP_IR_CREATOR_CLASS(Where)::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext&, - tvm::Array& outputs) { - ProtoHelperNodeContext ctx(node); - OpNodeProtoHelper attrs(&ctx); - - tvm::Tensor Y = Where(inputs[0], inputs[1], inputs[2], node.Name() + "_Where"); - outputs.push_back(Y); - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tvm_ir_builder.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tvm_ir_builder.cc deleted file mode 100644 index 7889e2add755e..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tvm_ir_builder.cc +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/tvm_ir_builder.h" - -#include "core/codegen/common/op_macro.h" -#include "core/codegen/passes/op_ir_creator/all_ops.h" -#include "core/common/common.h" - -namespace onnxruntime { -namespace tvm_codegen { - -TVMIRBuilder::TVMIRBuilder(const std::string& name) - : name_(name) {} - -const std::string& TVMIRBuilder::Name() const { - return name_; -} - -void TVMIRBuilder::InsertDispatcher(std::unique_ptr&& ptr) { - dispatchers_.push_back(std::move(ptr)); -} - -void TVMIRBuilder::ClearAllDispatchers() { - dispatchers_.clear(); -} - -void TVMIRBuilder::DumpAllOpIRCreators() const { - int count = 0; - for (auto& d : dispatchers_) { - std::cout << "************ TVM OpIRDispatcher " - << count << " : " - << d->Name() - << " ************" << std::endl; - - d->ForEach([](const std::string& key, OpIRCreator* builder) { - std::cout << "Key " << key - << ", Creator " << builder->Name() << std::endl; - }); - - ++count; - } -} - -// Evaluate finds ONE proper OpIRCreator and build the corresponding OpIR -// If a TVMIRBuilder has more than one OpIRCreator for an ORT Op, -// the first one will be used. -// Please adjust registration order and dispatcher in TVMIRBuilder -// to make sure the proper OpIRCreator is called. -Status TVMIRBuilder::Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx_codegen, - tvm::Array& outputs) { - OpIRCreator* candidate = nullptr; - for (auto& d : dispatchers_) { - candidate = d->Find(node); - if (nullptr != candidate) - break; - } - - if (nullptr == candidate) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Not implemented: ", node.OpType()); - } - - ORT_RETURN_IF_ERROR(candidate->Evaluate(inputs, node, ctx_codegen, outputs)); - - return Status::OK(); -} - -// BEGIN: Generic IR creator classes -#define ADD_OP_ITEM(name) \ - op_ir_registry->Register(std::make_unique()); - -#define BINARY_OP(name) ADD_OP_ITEM(name) -#define BINARY_CMP_OP(name) ADD_OP_ITEM(name) -#define POOL_OP(name) ADD_OP_ITEM(name) -#define REDUCE_OP(name) ADD_OP_ITEM(name) -#define REDUCE_INDEXED_OP(name) ADD_OP_ITEM(name) -#define UNARY_OP(name) ADD_OP_ITEM(name) -#define VARIADIC_OP(name) ADD_OP_ITEM(name) - -void RegisterAllGenericOpIRCreators(OpIRRegistry* op_ir_registry) { - LIST_ALL_GENERIC_OPS(); -} - -#undef ADD_OP_ITEM -#undef BINARY_OP -#undef BINARY_CMP_OP -#undef POOL_OP -#undef REDUCE_OP -#undef REDUCE_INDEXED_OP -#undef UNARY_OP -#undef VARIADIC_OP - -// BEGIN: Plugin Generic IR creator classes -#define ADD_OP_ITEM(name) \ - dispatcher->Register(#name, registry->Get(GENERIC_OP_IR_CREATOR_STRING(name))); - -#define BINARY_OP(name) ADD_OP_ITEM(name) -#define BINARY_CMP_OP(name) ADD_OP_ITEM(name) -#define POOL_OP(name) ADD_OP_ITEM(name) -#define REDUCE_OP(name) ADD_OP_ITEM(name) -#define REDUCE_INDEXED_OP(name) ADD_OP_ITEM(name) -#define UNARY_OP(name) ADD_OP_ITEM(name) -#define VARIADIC_OP(name) ADD_OP_ITEM(name) - -void RegisterGenericOrtOpTypeDispatcher(const std::shared_ptr& builder, - const OpIRRegistry* registry) { - auto dispatcher = std::make_unique("GenericOrtOpTypeOpIRCreators"); - LIST_ALL_GENERIC_OPS() - builder->InsertDispatcher(std::move(dispatcher)); -} - -#undef ADD_OP_ITEM -#undef BINARY_OP -#undef BINARY_CMP_OP -#undef POOL_OP -#undef REDUCE_OP -#undef REDUCE_INDEXED_OP -#undef UNARY_OP -// END: Generic IR creators classes - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tvm_ir_builder.h b/onnxruntime/core/codegen/passes/op_ir_creator/tvm_ir_builder.h deleted file mode 100644 index c80056e619d6d..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tvm_ir_builder.h +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/codegen/passes/utils/codegen_context.h" -#include "core/codegen/passes/op_ir_creator/tvm_op_creator.h" -#include "core/common/common.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// TVMIRBuilder contains all applicable TVM OpIRCreators -// OpIRCreators are stored in multiple dispatchers -// that check different conditions of an ORT Node. - -// If an ORT Node satisfies more than one OpIRCreators, -// the first dispatched pass will be applied. - -class TVMIRBuilder { - public: - TVMIRBuilder(const std::string& name); - ~TVMIRBuilder() = default; - - // A debug dumps all existing in this TVMIRBuilders - void DumpAllOpIRCreators() const; - - // Evaluates an OpIRCreator that first satisfies condtions of all dispatchers - Status Evaluate( - const tvm::Array& inputs, - const Node& node, - CodeGenContext& ctx, - tvm::Array& outputs); - - // Inserts a dispatcher and move its ownership to this TVMIRBuilder - void InsertDispatcher(std::unique_ptr&& ptr); - - // Clears all dispatchers in this TVMIRBuilder - void ClearAllDispatchers(); - - // Dumps the name of this TVMIRBuilder - const std::string& Name() const; - - private: - std::vector> dispatchers_; - std::string name_; - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TVMIRBuilder); -}; - -// Utility function to register all builtin generic OpIRCreators into an OpIRRegistry. -// It creates instances of all generic OpIRCreators -// and registers them to op_ir_registry -void RegisterAllGenericOpIRCreators(OpIRRegistry* op_ir_registry); - -// Utility function to bind all builtin generic OpIRCreators to a TVMIRBuilder. -// It creates an instance of a Dispatcher that contains all generic OpIRCreators created above -// and uses OrtOpType to dispatch OpIRCreators. -// Then, it registers the created Dispatcher to a TVMIRBuilder, builder. -void RegisterGenericOrtOpTypeDispatcher(const std::shared_ptr& builder, - const OpIRRegistry* registry); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tvm_op_creator.cc b/onnxruntime/core/codegen/passes/op_ir_creator/tvm_op_creator.cc deleted file mode 100644 index 992272753f5a4..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tvm_op_creator.cc +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/op_ir_creator/tvm_op_creator.h" - -#include "core/codegen/common/common.h" -#include "core/codegen/common/dispatcher.h" -#include "core/codegen/passes/utils/codegen_context.h" - -namespace onnxruntime { -namespace codegen { -// Explicit instantiation for OpIRCreator -template class CreatorBase&, - const Node&, - tvm_codegen::CodeGenContext&, - tvm::Array&, - Status>; - -// Explicit instantiation for OpIRCreators' dispatcher -template class DispatcherBase; - -} // namespace codegen - -namespace tvm_codegen { - -// One dispatcher is based on ORT OpType -OpIRCreator* OP_IR_DISPATCHER_CLASS(OpType)::Find(const Node& node) { - return DispatcherBase::Get(node.OpType()); -} - -// Another dispatcher is based ORT NodeArg name (GetKey) -OpIRCreator* OP_IR_DISPATCHER_CLASS(NodeName)::Find(const Node& node) { - return DispatcherBase::Get(GetKey(&node)); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/op_ir_creator/tvm_op_creator.h b/onnxruntime/core/codegen/passes/op_ir_creator/tvm_op_creator.h deleted file mode 100644 index e29c4a9f20767..0000000000000 --- a/onnxruntime/core/codegen/passes/op_ir_creator/tvm_op_creator.h +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/codegen/common/creator.h" -#include "core/codegen/common/dispatcher.h" -#include "core/codegen/common/registry.h" -#include "core/graph/graph.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -class CodeGenContext; - -// OpIRCreator lowers an Ort Node to its corresponding TVM IRs -using OpIRCreator = codegen::CreatorBase< - const tvm::Array&, - const Node&, - CodeGenContext&, - tvm::Array&, - Status>; - -// OpIRDispatcher is the base dispatcher for TVM IR Builder -// It checks whether an Ort Node satisfying a criteria (in Find) -// and dispatches a corresponding OpIRCreator. -class OpIRDispatcher : public codegen::DispatcherBase { - public: - OpIRDispatcher(const std::string& name) - : DispatcherBase(name) {} - - virtual ~OpIRDispatcher() = default; - - virtual OpIRCreator* Find(const Node&) = 0; - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpIRDispatcher); -}; - -// Macro returns an OpIRCreators' dispatcher's name -#define OP_IR_DISPATCHER_CLASS(OP) \ - TVM##OP##IRCreator - -// Macro declares an OpIRCreators' dispatcher -#define DECLARE_OP_IR_DISPATCHER_CLASS(OP) \ - class OP_IR_DISPATCHER_CLASS(OP) : public OpIRDispatcher { \ - public: \ - TVM##OP##IRCreator(const std::string& name) \ - : OpIRDispatcher(name) {} \ - ~TVM##OP##IRCreator() = default; \ - OpIRCreator* Find(const Node&) override; \ - \ - private: \ - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OP_IR_DISPATCHER_CLASS(OP)); \ - }; - -// Declare two common dispatchers for TVM Op IR builders -// One dispatcher is based on Ort OpType -DECLARE_OP_IR_DISPATCHER_CLASS(OpType) -// Another dispatcher is based Ort NodeArg name -DECLARE_OP_IR_DISPATCHER_CLASS(NodeName) - -// OpIRCreator Registry is a registry holds all OpIRCreators -using OpIRRegistry = codegen::RegistryBase; - -// Macro declares an OpIRCreator -#define DECLARE_OP_IR_CREATOR_CLASS(OP, PREFIX) \ - DECLARE_CREATOR_CLASS(OP, PREFIX##IRCreator, \ - const tvm::Array&, \ - const Node&, \ - tvm_codegen::CodeGenContext&, \ - tvm::Array&, \ - Status) - -// Macro returns an OpIRCreator's name with prefix -#define OP_IR_CREATOR_CLASS_EX(OP, PREFIX, ARCH) \ - CREATOR_CLASS(OP, PREFIX##ARCH##IRCreator) - -// Macro declares an OpIRCreator with prefix and arch -#define DECLARE_OP_IR_CREATOR_CLASS_EX(OP, PREFIX, ARCH) \ - DECLARE_OP_IR_CREATOR_CLASS(OP, PREFIX##ARCH) - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/scheduler/all_schedules.h b/onnxruntime/core/codegen/passes/scheduler/all_schedules.h deleted file mode 100644 index fe4be90f9fc84..0000000000000 --- a/onnxruntime/core/codegen/passes/scheduler/all_schedules.h +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/codegen/passes/scheduler/tvm_scheduler.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// AlwaysRoot is for debug purpose -DECLARE_TVM_SCHEDULER_CLASS(AlwaysRoot, GenericTVMRule) -// Create schedule for TVM Rule -DECLARE_TVM_SCHEDULER_CLASS(Extern, GenericTVMRule) -DECLARE_TVM_SCHEDULER_CLASS(Reduce, GenericTVMRule) - -// Crete scheduler for ORT OpType, Softmax -DECLARE_TVM_SCHEDULER_CLASS(Softmax, GenericOrtOpType) - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/scheduler/ort_type_schedule.cc b/onnxruntime/core/codegen/passes/scheduler/ort_type_schedule.cc deleted file mode 100644 index 59f492d164b14..0000000000000 --- a/onnxruntime/core/codegen/passes/scheduler/ort_type_schedule.cc +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/scheduler/all_schedules.h" - -#include "core/codegen/passes/scheduler/schedule_utils.h" - -namespace onnxruntime { -namespace tvm_codegen { - -bool TVM_SCHEDULER_CLASS(Softmax, GenericOrtOpType)::Evaluate( - const tvm::Tensor& tensor, - const Node*, - CodeGenContext&, - ScheduleContext& ctx_sched) { - // compute root the exp since it is reused more than once - auto& tensor_exp = tensor->op->InputTensors()[0]; - return InsertRootSchedule(tensor_exp, ctx_sched); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/scheduler/schedule_utils.cc b/onnxruntime/core/codegen/passes/scheduler/schedule_utils.cc deleted file mode 100644 index 76c2ad509c401..0000000000000 --- a/onnxruntime/core/codegen/passes/scheduler/schedule_utils.cc +++ /dev/null @@ -1,178 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/common/utils.h" -#include "core/codegen/passes/scheduler/schedule_utils.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// Check the schedule of tensor -// If it has no compute_root, Insert compute_root to tensor, and record it to ctx.scheduled_tensors -bool InsertRootSchedule( - const tvm::Tensor& tensor, - ScheduleContext& ctx) { - auto it = ctx.scheduled_tensors.find(tensor->op.get()); - if (it != ctx.scheduled_tensors.end()) { - if (it->second == ScheduleType::ScheduleClosure || - it->second == ScheduleType::ScheduleRoot) { - return false; - } - it->second = ScheduleType::ScheduleRoot; - } else { - ctx.scheduled_tensors.insert(std::make_pair(tensor->op.get(), ScheduleType::ScheduleRoot)); - } - ctx.schedule[tensor->op].compute_root(); - return true; -} - -// Check the schedule of tensor -// If it is not labeled as closure, lable it. -bool InsertClosure(const tvm::Tensor& tensor, - ScheduleContext& ctx) { - auto it = ctx.scheduled_tensors.find(tensor->op.get()); - if (it != ctx.scheduled_tensors.end()) { - if (it->second == ScheduleType::ScheduleClosure) - return false; - it->second = ScheduleType::ScheduleClosure; - } else { - ctx.scheduled_tensors.insert(std::make_pair(tensor->op.get(), ScheduleType::ScheduleClosure)); - } - return true; -} - -// Combination of InsertRootSchedule and InsertClosure -bool InsertRootScheduleAndClosure( - const tvm::Tensor& tensor, - ScheduleContext& ctx) { - auto it = ctx.scheduled_tensors.find(tensor->op.get()); - if (it != ctx.scheduled_tensors.end()) { - if (it->second == ScheduleType::ScheduleClosure) { - return false; - } - it->second = ScheduleType::ScheduleClosure; - } else { - ctx.scheduled_tensors.insert(std::make_pair(tensor->op.get(), ScheduleType::ScheduleClosure)); - } - ctx.schedule[tensor->op].compute_root(); - return true; -} - -// Check precondition for vectorize schedule -bool ShouldTryVectorization( - const tvm::Tensor& tensor, - ScheduleContext& ctx) { - auto it = ctx.scheduled_tensors.find(tensor->op.get()); - if (it != ctx.scheduled_tensors.end()) { - if (it->second > ScheduleType::ScheduleInline) { - return false; - } - } - return true; -} - -// Check the schedule of tensor -// If it is not scheduled, try to vectorize it. -// Note TryVectorization has to use with compute_root. -// Therefore, there is a safety check of tensor's schedule -bool TryVectorization( - const tvm::Tensor& tensor, - int64_t natural_vector_size, - ScheduleContext& ctx) { - if (!ShouldTryVectorization(tensor, ctx)) - return false; - - auto shape = tensor->shape; - auto rank = shape.size(); - if (rank < 1) { - return false; - } - const int64_t* tail_dim = as_const_int(shape[rank - 1]); - - if (nullptr != tail_dim) { - auto extern_op = tensor->op.as(); - if (nullptr != extern_op) { - return false; - } - - auto compute_op = tensor->op.as(); - - if (nullptr != compute_op) { - auto axis = compute_op->axis; - tvm::IterVar x = axis[rank - 1]; - if ((*tail_dim) > natural_vector_size) { - if ((*tail_dim) % natural_vector_size != 0) { - natural_vector_size = GCD(natural_vector_size, (*tail_dim)); - } - - if (natural_vector_size > 1) { - tvm::IterVar xi, xo; - ctx.schedule[tensor->op].split(x, static_cast(natural_vector_size), &xo, &xi); - ctx.schedule[tensor->op].vectorize(xi); - return true; - } - } else if (*tail_dim > 0) { - // don't vectorize if dim is 0 - ctx.schedule[tensor->op].vectorize(x); - return true; - } - } - } - return false; -} - -// Check the schedule of tensor -// If it is not scheduled, try to add compute_inline on it. -// Note TryInlineSchedule cannot be used with compute_root. -// Therefore, there is a safety check of tensor's schedule. -bool TryInlineSchedule( - const tvm::Tensor& tensor, - ScheduleContext& ctx) { - auto it = ctx.scheduled_tensors.find(tensor->op.get()); - if (it != ctx.scheduled_tensors.end()) { - if ((int)it->second < (int)ScheduleType::ScheduleInline) { - ctx.schedule[tensor->op].compute_inline(); - it->second = ScheduleType::ScheduleInline; - return true; - } else { - return false; - } - } - ctx.schedule[tensor->op].compute_inline(); - ctx.scheduled_tensors.insert(std::make_pair(tensor->op.get(), ScheduleType::ScheduleInline)); - return true; -} - -// Check the schedule of tensor's inputs, and call InsertRootSchedule for each of them -bool InputRootSchedule( - const tvm::Tensor& tensor, - ScheduleContext& ctx) { - bool status = false; - for (auto& t : tensor->op->InputTensors()) { - if (t->op->InputTensors().size() > 0) { - bool status_root = InsertRootSchedule(t, ctx); - status = status || status_root; - } - } - return status; -} - -// Check the schedule of tensor's inputs, -// and call InsertRootSchedule and TryVectorization for each of them -bool InputRootScheduleWithVectorization( - const tvm::Tensor& tensor, - int64_t natural_vector_size, - ScheduleContext& ctx) { - bool status = false; - for (auto& t : tensor->op->InputTensors()) { - if (t->op->InputTensors().size() > 0) { - bool status_vec = TryVectorization(t, natural_vector_size, ctx); - bool status_root = InsertRootSchedule(t, ctx); - status = status || status_root || status_vec; - } - } - return status; -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/scheduler/schedule_utils.h b/onnxruntime/core/codegen/passes/scheduler/schedule_utils.h deleted file mode 100644 index 4a0781f94d385..0000000000000 --- a/onnxruntime/core/codegen/passes/scheduler/schedule_utils.h +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include - -namespace onnxruntime { -namespace tvm_codegen { - -// Check the schedule of tensor -// If it has no compute_root, Insert compute_root to tensor, -// and record it to ctx.scheduled_tensors -bool InsertRootSchedule( - const tvm::Tensor& tensor, - ScheduleContext& ctx); - -// Check the schedule of tensor -// If it is not labeled as closure, lable it. -bool InsertClosure( - const tvm::Tensor& tensor, - ScheduleContext& ctx); - -// Combination of InsertRootSchedule and InsertClosure -bool InsertRootScheduleAndClosure( - const tvm::Tensor& tensor, - ScheduleContext& ctx); - -// Check precondition for vectorize schedule -bool ShouldTryVectorization( - const tvm::Tensor& tensor, - ScheduleContext& ctx); - -// Check the schedule of tensor -// If it is not scheduled, try to vectorize it. -// Note TryVectorization has to use with compute_root. -// Therefore, there is a safety check of tensor's schedule -bool TryVectorization( - const tvm::Tensor& tensor, - int64_t natural_vector_size, - ScheduleContext& ctx); - -// Check the schedule of tensor -// If it is not scheduled, try to add compute_inline on it. -// Note TryInlineSchedule cannot be used with compute_root. -// Therefore, there is a safety check of tensor's schedule. -bool TryInlineSchedule( - const tvm::Tensor& tensor, - ScheduleContext& ctx); - -// Check the schedule of tensor's inputs, -// and call InsertRootSchedule for each of them -bool InputRootSchedule( - const tvm::Tensor& tensor, - ScheduleContext& ctx); - -// Check the schedule of tensor's inputs, -// and call InsertRootSchedule and TryVectorization for each of them -bool InputRootScheduleWithVectorization( - const tvm::Tensor& tensor, - int64_t natural_vector_size, - ScheduleContext& ctx); - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/scheduler/tvm_rule_schedule.cc b/onnxruntime/core/codegen/passes/scheduler/tvm_rule_schedule.cc deleted file mode 100644 index 33162deddc983..0000000000000 --- a/onnxruntime/core/codegen/passes/scheduler/tvm_rule_schedule.cc +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/scheduler/all_schedules.h" - -#include "core/codegen/passes/scheduler/schedule_utils.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// This is for debug -bool TVM_SCHEDULER_CLASS(AlwaysRoot, GenericTVMRule)::Evaluate( - const tvm::Tensor& tensor, - const Node*, - CodeGenContext&, - ScheduleContext& ctx_sched) { - return InsertRootSchedule(tensor, ctx_sched); -} - -// For External tvm::Tensor -bool TVM_SCHEDULER_CLASS(Extern, GenericTVMRule)::Evaluate( - const tvm::Tensor& tensor, - const Node*, - CodeGenContext&, - ScheduleContext& ctx_sched) { - bool status = InsertRootScheduleAndClosure(tensor, ctx_sched); - bool status_input = InputRootSchedule(tensor, ctx_sched); - return status || status_input; -} - -// For Reduce Compute tvm::Tensor -bool TVM_SCHEDULER_CLASS(Reduce, GenericTVMRule)::Evaluate( - const tvm::Tensor& tensor, - const Node*, - CodeGenContext&, - ScheduleContext& ctx_sched) { - return InsertRootScheduleAndClosure(tensor, ctx_sched); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/scheduler/tvm_schedule_builder.cc b/onnxruntime/core/codegen/passes/scheduler/tvm_schedule_builder.cc deleted file mode 100644 index 2c8250198fa5f..0000000000000 --- a/onnxruntime/core/codegen/passes/scheduler/tvm_schedule_builder.cc +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/scheduler/tvm_schedule_builder.h" - -#include "core/codegen/common/op_macro.h" -#include "core/codegen/common/settings.h" -#include "core/common/common.h" -#include "core/common/logging/logging.h" - -namespace onnxruntime { -namespace tvm_codegen { - -TVMScheduleBuilder::TVMScheduleBuilder(const std::string& name) - : name_(name) { -} - -const std::string& TVMScheduleBuilder::Name() const { - return name_; -} - -void TVMScheduleBuilder::InsertDispatcher(std::unique_ptr&& ptr) { - dispatchers_.push_back(std::move(ptr)); -} - -void TVMScheduleBuilder::ClearDispatcher() { - dispatchers_.clear(); -} - -void TVMScheduleBuilder::DumpAllSchedulers() const { - std::ostringstream stream; - int count = 0; - stream << "[CODEGEN_DUMP_SCHEDULE]" << std::endl; - for (auto& d : dispatchers_) { - stream << "************ TVM Scheduler Dispatcher " - << count << " : " - << d->Name() - << " ************" << std::endl; - - d->ForEach([&stream](const std::string& key, Scheduler* op) { - stream << "Key " << key - << ", Creator " << op->Name() << std::endl; - }); - - ++count; - } - - LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << stream.str(); -} - -Status TVMScheduleBuilder::Evaluate( - const tvm::Tensor& tensor, - const Node* node, - CodeGenContext& ctx_codegen, - ScheduleContext& sched) { - Scheduler* candidate = nullptr; - - for (auto& d : dispatchers_) { - candidate = d->Find(tensor, node, ctx_codegen); - if (nullptr != candidate) - break; - } - - bool enable_dump_schedule = codegen::CodeGenSettings::Instance().HasOption(codegen::CodeGenSettings::kCodeGenDumpSchedule); - - if (nullptr == candidate) { - if (nullptr != node) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Not implemented: ", node->OpType()); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Not implemented an internal tvm::Tensor: ", tensor->op->name); - } - - bool status = candidate->Evaluate(tensor, node, ctx_codegen, sched); - - if (enable_dump_schedule) { - std::ostringstream stream; - if (nullptr != node) { - stream << std::endl; - stream << "[CODEGEN_DUMP_SCHEDULE] " - << "Schedule Node: " << node->Name() << std::endl; - } else { - stream << std::endl; - } - - if (status) { - stream << "[CODEGEN_DUMP_SCHEDULE] " - << "Schedule tvm::Tesnor " - << tensor->op->name - << " with " - << candidate->Name() << std::endl; - } else { - stream << "[CODEGEN_DUMP_SCHEDULE] " - << "Schedule tvm::Tesnor " - << tensor->op->name - << " is suppressed " << std::endl; - } - - LOGS_DEFAULT(CODEGEN_SETTINGS_LOG_LEVEL) << stream.str(); - } - - return Status::OK(); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/scheduler/tvm_schedule_builder.h b/onnxruntime/core/codegen/passes/scheduler/tvm_schedule_builder.h deleted file mode 100644 index 9f0a1b3ef45c2..0000000000000 --- a/onnxruntime/core/codegen/passes/scheduler/tvm_schedule_builder.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/codegen/passes/scheduler/tvm_scheduler.h" -#include "core/common/common.h" - -namespace onnxruntime { -namespace tvm_codegen { - -// TVMScheduleBuilder contains all applicable TVM scheduler passes. -// Scheduler passes are stored in multiple dispatchers -// that check different conditions of a tvm::Tensor. - -// If a tvm::Tensor satisfies more than one TVM scheduler passes, -// the first dispatched pass will be applied. - -class TVMScheduleBuilder { - public: - // TODO: add more parameter in consructor to support different target - TVMScheduleBuilder(const std::string& name); - ~TVMScheduleBuilder() = default; - - void DumpAllSchedulers() const; - - Status Evaluate( - const tvm::Tensor& tensor, - const Node* node, - CodeGenContext& ctx, - ScheduleContext& sched); - - void InsertDispatcher(std::unique_ptr&& ptr); - void ClearDispatcher(); - - const std::string& Name() const; - - private: - std::vector> dispatchers_; - std::string name_; - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TVMScheduleBuilder); -}; - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/scheduler/tvm_scheduler.cc b/onnxruntime/core/codegen/passes/scheduler/tvm_scheduler.cc deleted file mode 100644 index 071200a234e33..0000000000000 --- a/onnxruntime/core/codegen/passes/scheduler/tvm_scheduler.cc +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/scheduler/tvm_scheduler.h" - -#include "core/codegen/common/common.h" -#include "core/codegen/common/dispatcher.h" -#include "core/codegen/passes/utils/codegen_context.h" - -namespace onnxruntime { -namespace codegen { -// explicit instantiation -template class CreatorBase; - -template class DispatcherBase; - -} // namespace codegen - -namespace tvm_codegen { - -static const std::string TMVOpRuleKey_Extern("TVMOpRule_Extern"); -static const std::string TMVOpRuleKey_ComputeReduce("TVMOpRule_ComputeReduce"); -static const std::string TMVOpRuleKey_ComputeRegular("TVMOpRule_ComputeRegular"); -static const std::string TMVOpRuleKey_AlwaysRoot("TMVOpRuleKey_AlwaysRoot"); -static const std::string TMVOpRuleKey_NoRule("TVMOpRule_NoRule"); - -const std::string& GetTVMOpRule(TVMOpRuleType rule) { - if (rule == TVMOpRuleType::Extern) { - return TMVOpRuleKey_Extern; - } else if (rule == TVMOpRuleType::ComputeReduce) { - return TMVOpRuleKey_ComputeReduce; - } else if (rule == TVMOpRuleType::AlwaysRoot) { - return TMVOpRuleKey_AlwaysRoot; - } - return TMVOpRuleKey_NoRule; -} - -const std::string& GetTVMOpRule(const tvm::Tensor& tensor) { - auto extern_op = tensor->op.as(); - - if (nullptr != extern_op) { - return TMVOpRuleKey_Extern; - } - - auto compute_op = tensor->op.as(); - if (nullptr != compute_op) { - if (compute_op->reduce_axis.size() > 0) { - return TMVOpRuleKey_ComputeReduce; - } - } - - return TMVOpRuleKey_NoRule; -} - -Scheduler* SCHEDULE_DISPATCHER_CLASS(OrtOpType):: - Find(const tvm::Tensor&, const Node* node, tvm_codegen::CodeGenContext&) { - if (nullptr == node) - return nullptr; - return DispatcherBase::Get(node->OpType()); -} - -Scheduler* SCHEDULE_DISPATCHER_CLASS(TVMOpRule):: - Find(const tvm::Tensor& tensor, const Node*, tvm_codegen::CodeGenContext&) { - return DispatcherBase::Get(GetTVMOpRule(tensor)); -} - -Scheduler* SCHEDULE_DISPATCHER_CLASS(OrtOpName):: - Find(const tvm::Tensor&, const Node* node, tvm_codegen::CodeGenContext&) { - if (nullptr == node) - return nullptr; - return DispatcherBase::Get(GetKey(node)); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/scheduler/tvm_scheduler.h b/onnxruntime/core/codegen/passes/scheduler/tvm_scheduler.h deleted file mode 100644 index d022497c77f7e..0000000000000 --- a/onnxruntime/core/codegen/passes/scheduler/tvm_scheduler.h +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/common/common.h" -#include "core/codegen/common/creator.h" -#include "core/codegen/common/registry.h" -#include "core/codegen/passes/utils/codegen_context.h" -#include "core/graph/graph.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -// These are current generic TVMOpRule we used. -enum class TVMOpRuleType : int { - Extern = 0, - ComputeReduce = 1, - ComputeRegular = 2, - AlwaysRoot = 3, // for debug - NoRule, -}; - -const std::string& GetTVMOpRule(const tvm::Tensor& tensor); -const std::string& GetTVMOpRule(TVMOpRuleType rule); - -// These are current generic ScheduleType in tvm_codegen -enum class ScheduleType : int { - ScheduleNone = 0, - ScheduleInline = 1, - ScheduleAt = 2, - ScheduleRoot = 3, - ScheduleClosure = 4, -}; - -// Data struct to bundle tvm::Schedule and scheduled tensor -struct ScheduleContext { - ScheduleContext(const tvm::Array& ops) { - schedule = tvm::create_schedule(ops); - } - tvm::Schedule schedule; - std::map scheduled_tensors; -}; - -// Scheduler inserts a tvm::Schedule content to a tvm::Tensor -using Scheduler = codegen::CreatorBase< - const tvm::Tensor&, - const Node*, - tvm_codegen::CodeGenContext&, - ScheduleContext&, - bool>; - -// TVMScheduleDispatcher is the base dispatcher for TVM Schedule Builder -// It checks whether a pair of {tvm::Tensor, Ort Node} satisfying a criteria (in Find) -// and dispatches a corresponding Scheduler. -class TVMScheduleDispatcher : public codegen::DispatcherBase { - public: - TVMScheduleDispatcher(const std::string& name) - : DispatcherBase(name) {} - - virtual ~TVMScheduleDispatcher() = default; - - virtual Scheduler* Find(const tvm::Tensor&, - const Node*, - tvm_codegen::CodeGenContext&) = 0; - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TVMScheduleDispatcher); -}; - -// Macro returns an Schedulers' dispatcher's name -#define SCHEDULE_DISPATCHER_CLASS(TYPE) \ - TVM##TYPE##Schedulers - -// Macro declares an Schedulers' dispatcher -#define DECLARE_SCHEDULE_DISPATCHER_CLASS(TYPE) \ - class SCHEDULE_DISPATCHER_CLASS(TYPE) : public tvm_codegen::TVMScheduleDispatcher { \ - public: \ - TVM##TYPE##Schedulers(const std::string& name) \ - : TVMScheduleDispatcher(name) {} \ - ~TVM##TYPE##Schedulers() = default; \ - tvm_codegen::Scheduler* Find(const tvm::Tensor&, \ - const Node*, \ - tvm_codegen::CodeGenContext&) override; \ - \ - private: \ - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TVM##TYPE##Schedulers); \ - }; - -// Common dispatchers are listed here -// For a special pattern, it can be created later. -// One dispatcher is based on Ort OpType -DECLARE_SCHEDULE_DISPATCHER_CLASS(OrtOpType) -// One dispatcher is based on TVMOpRule -DECLARE_SCHEDULE_DISPATCHER_CLASS(TVMOpRule) -// One dispatcher is based Ort NodeArg name -DECLARE_SCHEDULE_DISPATCHER_CLASS(OrtOpName) - -// Scheduler Registry is a registry holds all Schedulers -using TVMScheduleRegistry = codegen::RegistryBase; - -// Macro declares TVM scheduler class -#define DECLARE_TVM_SCHEDULER_CLASS(OP, PRETFIX) \ - DECLARE_CREATOR_CLASS(OP, PRETFIX##Scheduler, \ - const tvm::Tensor&, \ - const Node*, \ - tvm_codegen::CodeGenContext&, \ - tvm_codegen::ScheduleContext&, \ - bool) - -// Macro returns TVM scheduler's name with prefix -#define TVM_SCHEDULER_CLASS(OP, PREFIX) \ - CREATOR_CLASS(OP, PREFIX##Scheduler) - -// Macro returns TVM scheduler's name as string -#define TVM_SCHEDULER_STRING(OP, PREFIX) \ - STRINGIZE(TVM_SCHEDULER_CLASS(OP, PREFIX)) - -// Macro returns TVM scheduler's name with prefix and arch -#define TVM_SCHEDULER_CLASS_EX(OP, PREFIX, ARCH) \ - CREATOR_CLASS(OP, PREFIX##ARCH##Scheduler) - -// Macro declares TVM scheduler class with prefix and arch -#define DECLARE_TVM_SCHEDULER_CLASS_EX(OP, PREFIX, ARCH) \ - DECLARE_TVM_SCHEDULER_CLASS(OP, PREFIX##ARCH) - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/utils/codegen_context.cc b/onnxruntime/core/codegen/passes/utils/codegen_context.cc deleted file mode 100644 index 2f1a59b4a92eb..0000000000000 --- a/onnxruntime/core/codegen/passes/utils/codegen_context.cc +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/utils/codegen_context.h" - -#include "core/codegen/common/common.h" - -namespace onnxruntime { -namespace tvm_codegen { - -CodeGenContext::CodeGenContext( - const codegen::CodeGenHandle* handle) - : handle_(handle), unname_symbol_counter_(0) {} - -tvm::Var CodeGenContext::GetOrCreateDynamicDim(const std::string& name) { - if (dynamic_dims_.count(name) == 0) - dynamic_dims_.emplace(name, tvm::Var(name)); - - return dynamic_dims_.at(name); -} - -std::string CodeGenContext::CreateUnnamedSymbol() { - return "unnamed_" + std::to_string(unname_symbol_counter_++); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/utils/codegen_context.h b/onnxruntime/core/codegen/passes/utils/codegen_context.h deleted file mode 100644 index 641552bd3b2e8..0000000000000 --- a/onnxruntime/core/codegen/passes/utils/codegen_context.h +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/codegen/common/handle.h" -#include "core/codegen/common/common.h" -#include "core/common/common.h" -#include "core/framework/data_types.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -// CodeGenContext is a data structure involving across passes -// Compiler developers can use it to store meta data -// to support fine-grained control of code generation -class CodeGenContext { - public: - CodeGenContext(const codegen::CodeGenHandle* handle); - - virtual ~CodeGenContext() = default; - - // returns tvm::Var for the dynamic dim - tvm::Var GetOrCreateDynamicDim(const std::string& name); - - const codegen::CodeGenHandle* GetCodeGenHandle() const { - return handle_; - } - - std::string CreateUnnamedSymbol(); - - protected: - std::unordered_map dynamic_dims_; - - const codegen::CodeGenHandle* handle_; - - int unname_symbol_counter_; -}; - -// Add Promote for CodeGenContext -DYNAMIC_PROMOTE(CodeGenContext) - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/utils/ort_tvm_utils.cc b/onnxruntime/core/codegen/passes/utils/ort_tvm_utils.cc deleted file mode 100644 index 55892974aa33f..0000000000000 --- a/onnxruntime/core/codegen/passes/utils/ort_tvm_utils.cc +++ /dev/null @@ -1,194 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/utils/ort_tvm_utils.h" - -#include "core/codegen/common/profile.h" -#include "core/codegen/passes/utils/codegen_context.h" -#include "core/framework/tensorprotoutils.h" -#include "core/providers/common.h" -#include - -#include - -namespace onnxruntime { -namespace tvm_codegen { - -#define RETURN_DLDATATYPE_IF_MATCH(type_enum, type, type_code) \ - case type_enum: \ - return {type_code, sizeof(type) * 8, 1}; \ - break; - -// DLDataType: {DLDataTypeCode, bits, lanes} -DLDataType ToTvmDLDataType(MLDataType ml_type) { - if (ml_type->IsTensorType()) { - ml_type = ml_type->AsTensorType()->GetElementType(); - } - auto prim_type = ml_type->AsPrimitiveDataType(); - if (prim_type == nullptr) { - ORT_NOT_IMPLEMENTED("converting MLDataType ", ml_type, " to tvm DLDataType is not implemented"); - } - - switch (prim_type->GetDataType()) { - RETURN_DLDATATYPE_IF_MATCH(ONNX_NAMESPACE::TensorProto_DataType_INT8, int8_t, kDLInt); - RETURN_DLDATATYPE_IF_MATCH(ONNX_NAMESPACE::TensorProto_DataType_UINT8, uint8_t, kDLUInt); - RETURN_DLDATATYPE_IF_MATCH(ONNX_NAMESPACE::TensorProto_DataType_INT16, int16_t, kDLInt); - RETURN_DLDATATYPE_IF_MATCH(ONNX_NAMESPACE::TensorProto_DataType_UINT16, uint16_t, kDLUInt); - RETURN_DLDATATYPE_IF_MATCH(ONNX_NAMESPACE::TensorProto_DataType_INT32, int32_t, kDLInt); - RETURN_DLDATATYPE_IF_MATCH(ONNX_NAMESPACE::TensorProto_DataType_UINT32, uint32_t, kDLUInt); - RETURN_DLDATATYPE_IF_MATCH(ONNX_NAMESPACE::TensorProto_DataType_INT64, int64_t, kDLInt); - RETURN_DLDATATYPE_IF_MATCH(ONNX_NAMESPACE::TensorProto_DataType_UINT64, uint64_t, kDLUInt); - RETURN_DLDATATYPE_IF_MATCH(ONNX_NAMESPACE::TensorProto_DataType_BOOL, bool, kDLUInt); - - RETURN_DLDATATYPE_IF_MATCH(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, float, kDLFloat); - RETURN_DLDATATYPE_IF_MATCH(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE, double, kDLFloat); - RETURN_DLDATATYPE_IF_MATCH(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, MLFloat16, kDLFloat); - default: - ORT_NOT_IMPLEMENTED("converting MLDataType ", ml_type, " to tvm DLDataType is not implemented"); - } -} - -tvm::Type ToTvmType(ONNX_NAMESPACE::TensorProto_DataType proto_type) { - switch (proto_type) { - // Note that bool is uint1 in tvm, but uint8 in ONNX, so it always require special handling - // case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - // return tvm::UInt(1); /*break;*/ - case ONNX_NAMESPACE::TensorProto_DataType_INT16: - return tvm::Int(16); /*break;*/ - case ONNX_NAMESPACE::TensorProto_DataType_INT32: - return tvm::Int(32); /*break;*/ - case ONNX_NAMESPACE::TensorProto_DataType_INT64: - return tvm::Int(64); /*break;*/ - case ONNX_NAMESPACE::TensorProto_DataType_UINT8: - return tvm::UInt(8); /*break;*/ - case ONNX_NAMESPACE::TensorProto_DataType_UINT16: - return tvm::UInt(16); /*break;*/ - case ONNX_NAMESPACE::TensorProto_DataType_UINT32: - return tvm::UInt(32); /*break;*/ - case ONNX_NAMESPACE::TensorProto_DataType_UINT64: - return tvm::UInt(64); /*break;*/ - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - return tvm::Float(32); /*break;*/ - case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: - return tvm::Float(64); /*break;*/ - case ONNX_NAMESPACE::TensorProto_DataType_INT8: - return tvm::Int(8); /*break;*/ - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - return tvm::Float(16); /*break;*/ - case ONNX_NAMESPACE::TensorProto_DataType_STRING: - ORT_THROW("Casting to and from strings is not supported yet."); /*break;*/ - case ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED: - ORT_THROW("Cast op must have 'to' argument of type DataType"); /*break;*/ - default: - ORT_THROW("Unexpected 'to' argument value: ", proto_type); - } -} - -tvm::Array ShapeToTvmArray(const NodeArg* def, CodeGenContext& ctx) { - ORT_ENFORCE(nullptr != def); - const ONNX_NAMESPACE::TensorShapeProto* shape_proto = def->Shape(); - ORT_ENFORCE(nullptr != shape_proto); - - tvm::Array arr; - for (int i = 0; i < shape_proto->dim_size(); ++i) { - arr.push_back(ShapeDimToTvmDim(shape_proto->dim(i), ctx)); - } - return arr; -} - -tvm::Expr ShapeDimToTvmDim(const ONNX_NAMESPACE::TensorShapeProto_Dimension& dim, CodeGenContext& ctx) { - if (utils::HasDimParam(dim)) { - return ctx.GetOrCreateDynamicDim(dim.dim_param()); - } else if (utils::HasDimValue(dim)) { - return tvm::Expr(gsl::narrow_cast(dim.dim_value())); - } - return ctx.GetOrCreateDynamicDim(ctx.CreateUnnamedSymbol()); -} - -#ifdef CODEGEN_ENABLE_PROFILER -struct event_in_bracket_and_id { - bool in_bracket; - size_t id; -}; -std::unordered_map g_codegen_profiler_event_ids; -std::vector> g_codegen_profiler_events(1024); - -TVM_REGISTER_GLOBAL("tvm.contrib.onnxruntime.profile_event") - .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* ret) { - DLTensor* X = args[0]; - DLTensor* Y = args[1]; - size_t event_id = args[2]; - bool is_begin = args[3]; - if (!is_begin) { - DCHECK(event_id < g_codegen_profiler_event_ids.size()); - profiling::Profiler::Instance().EndTimeAndRecordEvent( - profiling::EventCategory::NODE_EVENT, - g_codegen_profiler_events[event_id].first, - g_codegen_profiler_events[event_id].second); - } - - { - CODEGEN_PROFILER_EVENT("profile_stub"); - int64_t elem_count = 1; - for (int i = 0; i < X->ndim; ++i) { - elem_count *= X->shape[i]; - } - // there's overhead in this copy, so put begin after copy and end before copy - memcpy(static_cast(Y->data) + Y->byte_offset, - static_cast(X->data) + X->byte_offset, - elem_count * X->dtype.bits / 8); - } - - if (is_begin) { - DCHECK(g_codegen_profiler_events.size() > event_id); - DCHECK(!g_codegen_profiler_events[event_id].first.empty()); - DCHECK(g_codegen_profiler_event_ids[g_codegen_profiler_events[event_id].first].id == event_id); - g_codegen_profiler_events[event_id].second = - profiling::Profiler::Instance().StartTime(); - } - }); - -tvm::Tensor ProfileBegin(tvm::Tensor X, const std::string& event_name) { - size_t event_id; - if (g_codegen_profiler_event_ids.count(event_name) == 0) { - event_id = g_codegen_profiler_event_ids.size(); - ORT_ENFORCE(event_id < g_codegen_profiler_events.size()); - } else { - ORT_ENFORCE(!g_codegen_profiler_event_ids[event_name].in_bracket); - event_id = g_codegen_profiler_event_ids[event_name].id; - } - g_codegen_profiler_event_ids[event_name] = {true, event_id}; - g_codegen_profiler_events[event_id].first = event_name; - return topi::detail::make_extern( - {X->shape}, {X->dtype}, {X}, - [&](tvm::Array ins, tvm::Array outs) { - return topi::detail::call_packed({tvm::Expr("tvm.contrib.onnxruntime.profile_event"), - topi::detail::pack_buffer(ins[0]), - topi::detail::pack_buffer(outs[0]), - gsl::narrow(event_id), - true}); - }, - event_name + "_begin", "", {})[0]; -} - -tvm::Tensor ProfileEnd(tvm::Tensor X, const std::string& event_name) { - ORT_ENFORCE(g_codegen_profiler_event_ids.at(event_name).in_bracket); - g_codegen_profiler_event_ids.at(event_name).in_bracket = false; - size_t event_id = g_codegen_profiler_event_ids.at(event_name).id; - ORT_ENFORCE(event_id < g_codegen_profiler_events.size()); - ORT_ENFORCE(g_codegen_profiler_events[event_id].first == event_name); - return topi::detail::make_extern( - {X->shape}, {X->dtype}, {X}, - [&](tvm::Array ins, tvm::Array outs) { - return topi::detail::call_packed({tvm::Expr("tvm.contrib.onnxruntime.profile_event"), - topi::detail::pack_buffer(ins[0]), - topi::detail::pack_buffer(outs[0]), - gsl::narrow(event_id), - false}); - }, - event_name + "_end", "", {})[0]; -} -#endif - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/utils/ort_tvm_utils.h b/onnxruntime/core/codegen/passes/utils/ort_tvm_utils.h deleted file mode 100644 index f13e91a2d5cea..0000000000000 --- a/onnxruntime/core/codegen/passes/utils/ort_tvm_utils.h +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/codegen/common/common.h" -#include "core/framework/data_types.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -class CodeGenContext; - -// Helper function that converts a onnxruntime MLDataType to TVM DLDataType -DLDataType ToTvmDLDataType(MLDataType ml_type); - -tvm::Type ToTvmType(ONNX_NAMESPACE::TensorProto_DataType proto_type); - -tvm::Array ShapeToTvmArray(const NodeArg* def, CodeGenContext& ctx); - -tvm::Expr ShapeDimToTvmDim(const ONNX_NAMESPACE::TensorShapeProto_Dimension& dim, CodeGenContext& ctx); - -#ifdef CODEGEN_ENABLE_PROFILER -// Helper functions to inspect into lowered function -tvm::Tensor ProfileBegin(tvm::Tensor X, const std::string& event_name); - -tvm::Tensor ProfileEnd(tvm::Tensor X, const std::string& event_name); -#endif - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/weight_layout/tiling_2d.cc b/onnxruntime/core/codegen/passes/weight_layout/tiling_2d.cc deleted file mode 100644 index c65132f6d4bca..0000000000000 --- a/onnxruntime/core/codegen/passes/weight_layout/tiling_2d.cc +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/weight_layout/tiling_2d.h" - -#include "core/codegen/passes/utils/codegen_context.h" - -namespace onnxruntime { -namespace tvm_codegen { - -constexpr auto local_name_prefix = "tiling_2d_"; -constexpr int num_bits = 8; - -const std::string WeightLayoutTiling2D::GetKey( - ONNX_NAMESPACE::TensorProto_DataType proto_type, - int vector_width) { - return WeightLayout::GetKey( - local_name_prefix + std::to_string(vector_width), - proto_type, 2, 0.0f); -} - -WeightLayoutTiling2D::WeightLayoutTiling2D( - ONNX_NAMESPACE::TensorProto_DataType proto_type, - int vector_width) - : WeightLayout( - local_name_prefix + std::to_string(vector_width), - proto_type, 2, 0.0f), - vector_width_(vector_width) {} - -CoordTransFunc WeightLayoutTiling2D::ToActual(const tvm::Tensor& /*X*/) const { - return [&](const tvm::Array& nominal_coord) { - ORT_ENFORCE(nominal_coord.size() == 2); - const auto& y = nominal_coord[0]; - const auto& x = nominal_coord[1]; - return tvm::Array{ - x, - y}; - }; -} - -CoordTransFunc WeightLayoutTiling2D::ToNominal(const tvm::Tensor& X) const { - return [&](const tvm::Array& actual_coord) { - ORT_ENFORCE(actual_coord.size() == 2); - ORT_ENFORCE(X->dtype == HalideIR::type_of() || - X->dtype == HalideIR::type_of()); - - int tile_row = (sizeof(int32_t) * num_bits) / X->dtype.bits(); - int tile_col = ((vector_width_ * num_bits) / X->dtype.bits()) / tile_row; - - const auto& x = actual_coord[0]; - const auto& y = actual_coord[1]; - - const int block_dimy = tile_row; - const int block_dimx = tile_col; - - const auto& y0 = y % block_dimy; - const auto& y1 = (y / block_dimy) % block_dimx; - const auto& y2 = y / block_dimy / block_dimx; - - const auto& x0 = x % block_dimx; - const auto& x1 = x / block_dimx; - - return tvm::Array{ - y0 + y2 * block_dimx * block_dimy + x0 * block_dimy, - y1 + x1 * block_dimx}; - }; -} - -tvm::Array WeightLayoutTiling2D::ToActualShape(const tvm::Tensor& X) const { - ORT_ENFORCE(X->dtype == HalideIR::type_of() || - X->dtype == HalideIR::type_of()); - - auto pad_row = tvm::make_const(tvm::Int(32), (vector_width_ * num_bits) / X->dtype.bits()); - auto pad_col = tvm::make_const(tvm::Int(32), vector_width_ / sizeof(int32_t)); - - auto new_shape0 = ((X->shape[1] + pad_col - 1) / pad_col) * pad_col; - auto new_shape1 = ((X->shape[0] + pad_row - 1) / pad_row) * pad_row; - - tvm::Array - new_shape = { - new_shape0, - new_shape1}; - return new_shape; -} - -std::vector WeightLayoutTiling2D::ToActualShape(const Tensor* X) const { - ORT_ENFORCE(X != nullptr); - ORT_ENFORCE(X->Shape().GetDims().size() == 2); - - int pad_row = vector_width_ / X->DataType()->Size(); - int pad_col = vector_width_ / sizeof(int32_t); - - auto old_shape = X->Shape().GetDims(); - auto new_shape0 = (old_shape[1] + pad_col - 1) / pad_col * pad_col; - auto new_shape1 = ((old_shape[0] + pad_row - 1) / pad_row) * pad_row; - - std::vector new_shape = { - new_shape0, - new_shape1}; - - return new_shape; -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/weight_layout/tiling_2d.h b/onnxruntime/core/codegen/passes/weight_layout/tiling_2d.h deleted file mode 100644 index 64334a069f94f..0000000000000 --- a/onnxruntime/core/codegen/passes/weight_layout/tiling_2d.h +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/codegen/passes/weight_layout/weight_layout.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -/* - * \class! WeightLayoutTiling2D - * \breif! Transform 2D weight to 4D by tiling both dimension, - * this layout is used for tensorization. - * [M, N] => [M/Tx, N/Ty, Tx, Ty] - */ - -class WeightLayoutTiling2D : public WeightLayout { - public: - static const std::string GetKey(ONNX_NAMESPACE::TensorProto_DataType proto_type, - int vector_width); - - public: - WeightLayoutTiling2D(ONNX_NAMESPACE::TensorProto_DataType proto_type, - int vector_width); - - ~WeightLayoutTiling2D() = default; - - CoordTransFunc ToNominal(const tvm::Tensor& X) const override; - CoordTransFunc ToActual(const tvm::Tensor& X) const override; - tvm::Array ToActualShape(const tvm::Tensor& X) const override; - std::vector ToActualShape(const Tensor* X) const override; - - private: - int vector_width_; - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WeightLayoutTiling2D); -}; - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/weight_layout/transpose_2d.cc b/onnxruntime/core/codegen/passes/weight_layout/transpose_2d.cc deleted file mode 100644 index ea8597f7dd89d..0000000000000 --- a/onnxruntime/core/codegen/passes/weight_layout/transpose_2d.cc +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/weight_layout/transpose_2d.h" - -#include "core/codegen/passes/utils/codegen_context.h" - -namespace onnxruntime { -namespace tvm_codegen { - -constexpr auto local_layout_name = "transpose_2d"; - -const std::string WeightLayoutTranspose2D::GetKey( - ONNX_NAMESPACE::TensorProto_DataType proto_type) { - return WeightLayout::GetKey(local_layout_name, proto_type, 2, 0.0f); -} - -WeightLayoutTranspose2D::WeightLayoutTranspose2D( - ONNX_NAMESPACE::TensorProto_DataType proto_type) - : WeightLayout(local_layout_name, proto_type, 2, 0.0f) {} - -CoordTransFunc WeightLayoutTranspose2D::ToActual(const tvm::Tensor& /*X*/) const { - return [&](const tvm::Array& nominal_coord) { - ORT_ENFORCE(nominal_coord.size() == 2); - const auto& y = nominal_coord[0]; - const auto& x = nominal_coord[1]; - return tvm::Array{ - x, - y}; - }; -} - -CoordTransFunc WeightLayoutTranspose2D::ToNominal(const tvm::Tensor& /*X*/) const { - return [&](const tvm::Array& actual_coord) { - ORT_ENFORCE(actual_coord.size() == 2); - const auto& y = actual_coord[0]; - const auto& x = actual_coord[1]; - return tvm::Array{ - x, - y}; - }; -} - -tvm::Array WeightLayoutTranspose2D::ToActualShape(const tvm::Tensor& X) const { - tvm::Array new_shape = { - X->shape[1], - X->shape[0]}; - return new_shape; -} - -std::vector WeightLayoutTranspose2D::ToActualShape(const Tensor* X) const { - ORT_ENFORCE(X != nullptr); - ORT_ENFORCE(X->Shape().GetDims().size() == 2); - auto old_shape = X->Shape().GetDims(); - - std::vector new_shape = { - old_shape[1], - old_shape[0]}; - - return new_shape; -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/weight_layout/transpose_2d.h b/onnxruntime/core/codegen/passes/weight_layout/transpose_2d.h deleted file mode 100644 index 65babaaec8dac..0000000000000 --- a/onnxruntime/core/codegen/passes/weight_layout/transpose_2d.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/codegen/passes/weight_layout/weight_layout.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -// WeightLayoutTranspose2D for transposing a 2D weight -// [W, H] => [H, W] -class WeightLayoutTranspose2D : public WeightLayout { - public: - static const std::string GetKey(ONNX_NAMESPACE::TensorProto_DataType proto_type); - - public: - WeightLayoutTranspose2D(ONNX_NAMESPACE::TensorProto_DataType proto_type); - - ~WeightLayoutTranspose2D() = default; - - CoordTransFunc ToNominal(const tvm::Tensor& X) const override; - CoordTransFunc ToActual(const tvm::Tensor& X) const override; - tvm::Array ToActualShape(const tvm::Tensor& X) const override; - std::vector ToActualShape(const Tensor* X) const override; - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WeightLayoutTranspose2D); -}; - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/weight_layout/vertical_stripes_2d.cc b/onnxruntime/core/codegen/passes/weight_layout/vertical_stripes_2d.cc deleted file mode 100644 index b1ddb791a3b3d..0000000000000 --- a/onnxruntime/core/codegen/passes/weight_layout/vertical_stripes_2d.cc +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/weight_layout/vertical_stripes_2d.h" - -#include "core/codegen/passes/utils/codegen_context.h" - -namespace onnxruntime { -namespace tvm_codegen { - -constexpr auto local_name_prefix = "vertical_stripe_2d_"; - -const std::string WeightLayoutVerticalStripe2D::GetKey( - ONNX_NAMESPACE::TensorProto_DataType proto_type, - int stripe_width) { - return WeightLayout::GetKey( - local_name_prefix + std::to_string(stripe_width), - proto_type, 2, 0.0f); -} - -WeightLayoutVerticalStripe2D::WeightLayoutVerticalStripe2D( - ONNX_NAMESPACE::TensorProto_DataType proto_type, - int stripe_width) - : WeightLayout( - local_name_prefix + std::to_string(stripe_width), - proto_type, 2, 0.0f), - stripe_width_(stripe_width) { -} - -CoordTransFunc WeightLayoutVerticalStripe2D::ToActual(const tvm::Tensor& /*X*/) const { - return [&](const tvm::Array& nominal_coord) { - ORT_ENFORCE(nominal_coord.size() == 2); - const auto& y = nominal_coord[0]; - const auto& x = nominal_coord[1]; - return tvm::Array{ - x / stripe_width_, - y, - x % stripe_width_}; - }; -} - -CoordTransFunc WeightLayoutVerticalStripe2D::ToNominal(const tvm::Tensor& /*X*/) const { - return [&](const tvm::Array& actual_coord) { - ORT_ENFORCE(actual_coord.size() == 3); - const auto& z = actual_coord[0]; - const auto& y = actual_coord[1]; - const auto& x = actual_coord[2]; - return tvm::Array{ - y, - x + stripe_width_ * z}; - }; -} - -tvm::Array WeightLayoutVerticalStripe2D::ToActualShape(const tvm::Tensor& X) const { - tvm::Array new_shape = { - (X->shape[1] + stripe_width_ - 1) / stripe_width_, - X->shape[0], - stripe_width_}; - return new_shape; -} - -std::vector WeightLayoutVerticalStripe2D::ToActualShape(const Tensor* X) const { - ORT_ENFORCE(X != nullptr); - auto old_shape = X->Shape().GetDims(); - - ORT_ENFORCE(old_shape.size() == 2); - - std::vector new_shape = { - (old_shape[1] + stripe_width_ - 1) / stripe_width_, - old_shape[0], - stripe_width_}; - - return new_shape; -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/weight_layout/vertical_stripes_2d.h b/onnxruntime/core/codegen/passes/weight_layout/vertical_stripes_2d.h deleted file mode 100644 index b9b65025dc014..0000000000000 --- a/onnxruntime/core/codegen/passes/weight_layout/vertical_stripes_2d.h +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/codegen/common/common.h" -#include "core/codegen/passes/weight_layout/weight_layout.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -// WeightLayoutVerticalStripe2D for making a 2D weight to 3D, by tiling the lowest (verteical) dimension -// [W, H] => [H/stripe, W, stripe] -class WeightLayoutVerticalStripe2D : public WeightLayout { - public: - static const std::string GetKey( - ONNX_NAMESPACE::TensorProto_DataType proto_type, - int stripe_width); - - public: - WeightLayoutVerticalStripe2D( - ONNX_NAMESPACE::TensorProto_DataType proto_type, - int stripe_width); - - ~WeightLayoutVerticalStripe2D() = default; - - virtual CoordTransFunc ToNominal(const tvm::Tensor& X) const override; - virtual CoordTransFunc ToActual(const tvm::Tensor& X) const override; - tvm::Array ToActualShape(const tvm::Tensor& X) const override; - std::vector ToActualShape(const Tensor* X) const override; - - private: - int stripe_width_; - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WeightLayoutVerticalStripe2D); -}; - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/weight_layout/weight_layout.cc b/onnxruntime/core/codegen/passes/weight_layout/weight_layout.cc deleted file mode 100644 index ab3e647fd284a..0000000000000 --- a/onnxruntime/core/codegen/passes/weight_layout/weight_layout.cc +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/codegen/passes/weight_layout/weight_layout.h" - -#include "core/codegen/common/common.h" -#include "core/codegen/common/utils.h" -#include "core/codegen/mti/mti_tvm_utils.h" -#include "core/codegen/passes/utils/ort_tvm_utils.h" - -namespace onnxruntime { -namespace tvm_codegen { - -static tvm::Tensor CreateTVMPlaceholder( - const std::string& name, - HalideIR::Type type, - int dim) { - tvm::Array shape; - if (dim > 0) { - for (int i = 0; i < dim; ++i) { - shape.push_back(tvm::Var(name + "_v" + std::to_string(i))); - } - } else { - shape.push_back(1); - } - return tvm::placeholder(shape, type, name + "_placeholder"); -} - -const std::string WeightLayout::GetKey( - const std::string& name, - ONNX_NAMESPACE::TensorProto_DataType proto_type, - int input_dim, - float pad_zero) { - std::ostringstream key; - key << name << "_type_" << static_cast(proto_type); - key << "_dim_" << input_dim; - key << "_pad_zero_" << pad_zero; - return NormalizeCppName(key.str()); -} - -WeightLayout::WeightLayout( - const std::string& name, - ONNX_NAMESPACE::TensorProto_DataType proto_type, - int input_dim, - float pad_zero) - : name_(GetKey(name, proto_type, input_dim, pad_zero)), - proto_type_(proto_type), - input_dim_(input_dim), - pad_zero_(pad_zero) {} - -const std::string& WeightLayout::Name() const { - return name_; -} - -void WeightLayout::CreateLayoutMarshallingTVMOp(tvm::Array& inputs, - tvm::Array& outputs) const { - HalideIR::Type halide_type = ToTvmType(proto_type_); - - tvm::Tensor placeholder = CreateTVMPlaceholder(name_, halide_type, input_dim_); - inputs.push_back(placeholder); - - tvm::Array new_shape = ToActualShape(placeholder); - CoordTransFunc new_coord_to_old_coord_func = ToNominal(placeholder); - tvm::Expr pad_zero_expr = tvm::make_const(halide_type, pad_zero_); - - tvm::Tensor output = tvm::compute( - new_shape, - [&](const tvm::Array& output_coord) { - tvm::Array output_coord1; - for (const auto& coord : output_coord) - output_coord1.push_back(coord); - auto input_coord = new_coord_to_old_coord_func(output_coord1); - ORT_ENFORCE(input_coord.size() == placeholder->shape.size()); - - if (input_coord.size() > 0) { - auto in_range = (input_coord[0] >= 0) && (input_coord[0] < placeholder->shape[0]); - for (size_t dim = 1; dim < input_coord.size(); ++dim) - in_range = in_range && (input_coord[dim] >= 0) && (input_coord[dim] < placeholder->shape[dim]); - - return tvm::if_then_else(in_range, placeholder(input_coord), pad_zero_expr); - } else { - // scalar - return placeholder(input_coord); - } - }); - - outputs.push_back(output); -} - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/passes/weight_layout/weight_layout.h b/onnxruntime/core/codegen/passes/weight_layout/weight_layout.h deleted file mode 100644 index 1b45a38e7e24e..0000000000000 --- a/onnxruntime/core/codegen/passes/weight_layout/weight_layout.h +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/codegen/common/common.h" -#include "core/codegen/common/registry.h" -#include "core/common/common.h" -#include "core/framework/tensor.h" -#include - -namespace onnxruntime { -namespace tvm_codegen { - -using CoordTransFunc = std::function(const tvm::Array&)>; - -// WeightLayout is data layout transformer for weight/initializer -class WeightLayout { - public: - // Static function to return unique string as a key - static const std::string GetKey( - const std::string& name, - ONNX_NAMESPACE::TensorProto_DataType proto_type, - int input_dim, - float pad_zero); - - public: - WeightLayout( - const std::string& name, - ONNX_NAMESPACE::TensorProto_DataType proto_type, - int input_dim, - float pad_zero); - - virtual ~WeightLayout() = default; - - // Return a CoordTransFunc from actual (transformed) coordinate to normial (original) coordinate - virtual CoordTransFunc ToNominal(const tvm::Tensor& X) const = 0; - - // Return a CoordTransFunc from normial (original) coordinate to actual (transformed) coordinate - virtual CoordTransFunc ToActual(const tvm::Tensor& X) const = 0; - - // Return actual (transformed) shape in tvm::Array (tvm_codegen) - virtual tvm::Array ToActualShape(const tvm::Tensor& X) const = 0; - - // Return actual (transformed) shape in vector (ort) - virtual std::vector ToActualShape(const Tensor* X) const = 0; - - // Create Layout Marshalling op in outputs - void CreateLayoutMarshallingTVMOp(tvm::Array& inputs, - tvm::Array& outputs) const; - - // Layout name - const std::string& Name() const; - - protected: - std::string name_; - ONNX_NAMESPACE::TensorProto_DataType proto_type_; - int input_dim_; - float pad_zero_; - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WeightLayout); -}; - -// Weight Layout Registry is a registry holds all WeightLayout -using WeightLayoutRegistry = codegen::RegistryBase; - -} // namespace tvm_codegen -} // namespace onnxruntime diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 5dca4cf6c165b..ecd3960107926 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -138,7 +138,8 @@ class PlannerImpl { const SubgraphsKernelCreateInfoMaps& subgraphs_kernel_create_info_maps, const InlinedHashMap& outer_scope_node_arg_to_location_map, const OrtValueNameIdxMap& ort_value_name_idx_map, - const ISequentialPlannerContext& context, SequentialExecutionPlan& plan) + const ISequentialPlannerContext& context, SequentialExecutionPlan& plan, + const logging::Logger& logger) : context_(&context), plan_(plan), parent_node_(parent_node), @@ -148,14 +149,15 @@ class PlannerImpl { kernel_create_info_map_(kernel_create_info_map), subgraphs_kernel_create_info_maps_(subgraphs_kernel_create_info_maps), outer_scope_node_arg_to_location_map_(outer_scope_node_arg_to_location_map), - ort_value_name_idx_map_(ort_value_name_idx_map) {} + ort_value_name_idx_map_(ort_value_name_idx_map), + logger_(logger) { + } Status CreatePlan( #ifdef ORT_ENABLE_STREAM const IStreamCommandHandleRegistry& stream_handle_registry, #endif - const PathString& partition_config_file, - const logging::Logger& logger); + const PathString& partition_config_file); private: gsl::not_null context_; @@ -183,6 +185,12 @@ class PlannerImpl { InlinedHashMap> dependence_graph_; InlinedHashMap value_node_map_; + // logger_ is not currently used in a minimal build +#if defined(ORT_MINIMAL_BUILD) && !defined(ORT_EXTENDED_MINIMAL_BUILD) + [[maybe_unused]] +#endif + const logging::Logger& logger_; + // OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation: struct OrtValueInfo { const onnxruntime::NodeArg* p_def_site; // the (unique) NodeArg corresponding to the MLValue @@ -213,6 +221,7 @@ class PlannerImpl { FreeBufferInfo(OrtValueIndex ort_value, size_t dealloc_point) : ml_value(ort_value), deallocate_point(dealloc_point) {} }; + // freelist_ : a list of ml-values whose buffers are free to be reused, sorted by when // they became free (more recently freed earlier in the list). std::list freelist_; @@ -225,7 +234,8 @@ class PlannerImpl { } int& UseCount(OrtValueIndex n) { - ORT_ENFORCE(n >= 0 && static_cast(n) < ort_value_info_.size(), "invalid value index: ", n, " against size ", ort_value_info_.size()); + ORT_ENFORCE(n >= 0 && static_cast(n) < ort_value_info_.size(), + "invalid value index: ", n, " against size ", ort_value_info_.size()); return ort_value_info_[n].usecount; } int& UseCount(const OrtValueName& name) { return UseCount(Index(name)); } @@ -335,9 +345,9 @@ class PlannerImpl { // we cannot. const Node* producer_node = graph.GetProducerNode(p_input_arg->Name()); if (producer_node && HasExternalOutputs(*producer_node)) { - LOGS_DEFAULT(VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node " - << producer_node->Name() << " which has external outputs. " - << "Be cautious the reuse MUST be a read-only usage."; + LOGS(logger_, VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node " + << producer_node->Name() << " which has external outputs. " + << "Be cautious the reuse MUST be a read-only usage."; } #endif *reusable_input = Index(p_input_arg->Name()); @@ -361,9 +371,9 @@ class PlannerImpl { // we cannot. const Node* producer_node = graph.GetProducerNode(p_input_arg->Name()); if (producer_node && HasExternalOutputs(*producer_node)) { - LOGS_DEFAULT(VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node " - << producer_node->Name() << " which has external outputs. " - << "Be cautious the reuse MUST be a read-only usage."; + LOGS(logger_, VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node " + << producer_node->Name() << " which has external outputs. " + << "Be cautious the reuse MUST be a read-only usage."; } #endif *reusable_input = Index(p_input_arg->Name()); @@ -397,8 +407,8 @@ class PlannerImpl { } } else { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - LOGS_DEFAULT(VERBOSE) << "Node " << node.Name() << " cannot reuse input buffer for node " - << producer_node->Name() << " as it has external outputs"; + LOGS(logger_, VERBOSE) << "Node " << node.Name() << " cannot reuse input buffer for node " + << producer_node->Name() << " as it has external outputs"; #endif } } @@ -448,8 +458,8 @@ class PlannerImpl { return true; } else { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - LOGS_DEFAULT(VERBOSE) << "Node " << node.Name() << " cannot reuse strided output buffer for node " - << producer_node->Name() << " as it has external outputs."; + LOGS(logger_, VERBOSE) << "Node " << node.Name() << " cannot reuse strided output buffer for node " + << producer_node->Name() << " as it has external outputs."; #endif } } @@ -1198,9 +1208,9 @@ class PlannerImpl { // Otherwise, we cannot reuse the buffer. const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name()); if (producer_node && HasExternalOutputs(*producer_node)) { - LOGS_DEFAULT(VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node " - << producer_node->Name() << " which has external outputs is reused. " - << "Be cautious the reuse MUST be a read-only usage."; + LOGS(logger_, VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node " + << producer_node->Name() << " which has external outputs is reused. " + << "Be cautious the reuse MUST be a read-only usage."; } #endif @@ -1241,9 +1251,9 @@ class PlannerImpl { // Otherwise, we cannot reuse the buffer. const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name()); if (producer_node && HasExternalOutputs(*producer_node)) { - LOGS_DEFAULT(VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node " - << producer_node->Name() << " which has external outputs is reused. " - << "Be cautious the reuse MUST be a read-only usage."; + LOGS(logger_, VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node " + << producer_node->Name() << " which has external outputs is reused. " + << "Be cautious the reuse MUST be a read-only usage."; } #endif @@ -1290,8 +1300,8 @@ class PlannerImpl { } } else { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - LOGS_DEFAULT(VERBOSE) << "Node " << node->Name() << " cannot reuse input buffer for node " - << producer_node->Name() << " as it has external outputs"; + LOGS(logger_, VERBOSE) << "Node " << node->Name() << " cannot reuse input buffer for node " + << producer_node->Name() << " as it has external outputs"; #endif } } @@ -1869,8 +1879,7 @@ class PlannerImpl { } #ifndef ORT_ENABLE_STREAM - void PartitionIntoStreams(const logging::Logger& /*logger*/, - const ExecutionProviders& /*execution_providers*/, + void PartitionIntoStreams(const ExecutionProviders& /*execution_providers*/, const PathString& /*partition_config_file*/) { if (graph_viewer_.NumberOfNodes() > 0) { stream_nodes_.push_back({}); @@ -1915,11 +1924,11 @@ class PlannerImpl { #else - void - PartitionIntoStreams(const logging::Logger& logger, const ExecutionProviders& execution_providers, - const PathString& partition_config_file) { - auto partitioner = IGraphPartitioner::CreateGraphPartitioner(logger, partition_config_file); - auto status = partitioner->PartitionGraph(graph_viewer_, execution_providers, stream_nodes_, context_->GetExecutionOrder()); + void PartitionIntoStreams(const ExecutionProviders& execution_providers, + const PathString& partition_config_file) { + auto partitioner = IGraphPartitioner::CreateGraphPartitioner(logger_, partition_config_file); + auto status = partitioner->PartitionGraph(graph_viewer_, execution_providers, stream_nodes_, + context_->GetExecutionOrder()); ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); plan_.node_stream_map_.resize(SafeInt(graph_viewer_.MaxNodeIndex()) + 1); for (size_t i = 0; i < stream_nodes_.size(); ++i) { @@ -2282,10 +2291,9 @@ Status PlannerImpl::CreatePlan( #ifdef ORT_ENABLE_STREAM const IStreamCommandHandleRegistry& stream_handle_registry, #endif - const PathString& partition_config_file, - const logging::Logger& logger) { + const PathString& partition_config_file) { // 1. partition graph into streams - PartitionIntoStreams(logger, execution_providers_, this->parent_node_ ? PathString{} : partition_config_file); + PartitionIntoStreams(execution_providers_, parent_node_ ? PathString{} : partition_config_file); // 2. initialize the plan based on stream partition result int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1; @@ -2354,14 +2362,13 @@ Status SequentialPlanner::CreatePlan( PlannerImpl planner(parent_node, graph_viewer, outer_scope_node_args, providers, kernel_create_info_map, subgraphs_kernel_create_info_maps, outer_scope_node_arg_to_location_map, - ort_value_name_idx_map, context, *plan); + ort_value_name_idx_map, context, *plan, logger); return planner.CreatePlan( #ifdef ORT_ENABLE_STREAM stream_handle_registry, #endif - partition_config_file, - logger); + partition_config_file); } #ifdef ORT_ENABLE_STREAM diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index b6dc8ad56f257..26b98b0a04d24 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -139,13 +139,16 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA *out = new OrtMemoryInfo(onnxruntime::CPU, type, OrtDevice(), id1, mem_type1); } else if (strcmp(name1, onnxruntime::CUDA) == 0 || strcmp(name1, onnxruntime::OpenVINO_GPU) == 0 || - strcmp(name1, onnxruntime::DML) == 0 || strcmp(name1, onnxruntime::HIP) == 0 || strcmp(name1, onnxruntime::WEBGPU_BUFFER) == 0 || strcmp(name1, onnxruntime::WEBNN_TENSOR) == 0) { *out = new OrtMemoryInfo( name1, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, mem_type1); + } else if (strcmp(name1, onnxruntime::DML) == 0) { + *out = new OrtMemoryInfo( + name1, type, OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, + mem_type1); } else if (strcmp(name1, onnxruntime::OpenVINO_RT_NPU) == 0) { *out = new OrtMemoryInfo( name1, type, OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, diff --git a/onnxruntime/core/framework/allocator_utils.cc b/onnxruntime/core/framework/allocator_utils.cc index 797b6e1606f97..edf965d3835b5 100644 --- a/onnxruntime/core/framework/allocator_utils.cc +++ b/onnxruntime/core/framework/allocator_utils.cc @@ -77,7 +77,7 @@ AllocatorPtr CreateAllocator(const AllocatorCreationInfo& info) { } } -bool ShouldCpuAllocatorUseArena([[maybe_unused]] bool is_arena_requested) { +bool DoesCpuAllocatorSupportArenaUsage() { #if defined(USE_JEMALLOC) || defined(USE_MIMALLOC) // We use these allocators instead of the arena. return false; @@ -89,7 +89,7 @@ bool ShouldCpuAllocatorUseArena([[maybe_unused]] bool is_arena_requested) { if constexpr (sizeof(void*) == 4) { return false; } else { - return is_arena_requested; + return true; } #endif } diff --git a/onnxruntime/core/framework/allocator_utils.h b/onnxruntime/core/framework/allocator_utils.h index 4035a0cc349e4..bef0b7057a7f8 100644 --- a/onnxruntime/core/framework/allocator_utils.h +++ b/onnxruntime/core/framework/allocator_utils.h @@ -43,8 +43,8 @@ struct AllocatorCreationInfo { AllocatorPtr CreateAllocator(const AllocatorCreationInfo& info); /** - * Gets whether a CPU allocator should use an arena or not. + * Gets whether a CPU allocator supports arena usage. */ -bool ShouldCpuAllocatorUseArena(bool is_arena_requested); +bool DoesCpuAllocatorSupportArenaUsage(); } // namespace onnxruntime diff --git a/onnxruntime/core/framework/fallback_cpu_capability.cc b/onnxruntime/core/framework/fallback_cpu_capability.cc index ef68b88187e08..1eb7420b44d2c 100644 --- a/onnxruntime/core/framework/fallback_cpu_capability.cc +++ b/onnxruntime/core/framework/fallback_cpu_capability.cc @@ -41,7 +41,8 @@ static bool IsSmallInitializer(const onnxruntime::GraphViewer& graph, const Node std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, const IExecutionProvider::IKernelLookup& kernel_lookup, - gsl::span tentative_nodes) { + gsl::span tentative_nodes, + const logging::Logger& logger) { // automatic conversion from const std::vector& const auto& ordered_nodes = graph.GetNodesInTopologicalOrder(); InlinedVector node_id_to_order_map(graph.MaxNodeIndex()); @@ -83,7 +84,7 @@ std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewe auto consumer_nodes = graph.GetConsumerNodes(node_arg.Name()); for (auto& consumer_node : consumer_nodes) { candidates.push(consumer_node->Index()); - LOGS_DEFAULT(INFO) << "Candidate for fallback CPU execution: " << consumer_node->Name(); + LOGS(logger, INFO) << "Candidate for fallback CPU execution: " << consumer_node->Name(); } } return Status::OK(); @@ -159,9 +160,9 @@ std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewe if (place_in_cpu) { cpu_nodes.insert(cur); - LOGS_DEFAULT(INFO) << "ORT optimization- Force fallback to CPU execution for node: " << node->Name() - << " because the CPU execution path is deemed faster than overhead involved with execution on other EPs " - << " capable of executing this node"; + LOGS(logger, INFO) << "ORT optimization- Force fallback to CPU execution for node: " << node->Name() + << " because the CPU execution path is deemed faster than overhead involved with execution " + "on other EPs capable of executing this node"; for (auto* output : node->OutputDefs()) { cpu_output_args.insert(output); } diff --git a/onnxruntime/core/framework/fallback_cpu_capability.h b/onnxruntime/core/framework/fallback_cpu_capability.h index c5bcd22888b7c..bca75adbfd5a7 100644 --- a/onnxruntime/core/framework/fallback_cpu_capability.h +++ b/onnxruntime/core/framework/fallback_cpu_capability.h @@ -9,6 +9,9 @@ #include "core/graph/graph_viewer.h" namespace onnxruntime { +namespace logging { +class Logger; +} /** Returns a list of nodes that are preferred on CPU. @@ -19,6 +22,7 @@ namespace onnxruntime { */ std::unordered_set GetCpuPreferredNodes(const GraphViewer& graph, const IExecutionProvider::IKernelLookup& kernel_lookup, - gsl::span tentative_nodes); + gsl::span tentative_nodes, + const logging::Logger& logger); } // namespace onnxruntime diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 4f745b74abce7..406fc1b15effc 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -149,13 +149,13 @@ auto get_capabilities = [](const IExecutionProvider& ep, }; } // namespace -static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { +static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const logging::Logger& logger) { auto& current_ep = params.current_ep.get(); const auto& ep_type = current_ep.Type(); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) if (current_ep.GetPreferredLayout() == DataLayout::NHWC && !params.transform_layout.get()) { - LOGS_DEFAULT(WARNING) << ep_type << " cannot be used with this model due to its ONNX opset not being supported by " + LOGS(logger, WARNING) << ep_type << " cannot be used with this model due to its ONNX opset not being supported by " "the layout transformer."; return Status::OK(); } @@ -165,7 +165,8 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type); const KernelLookup kernel_lookup{ep_type, kernel_registries_for_ep, - kernel_registry_mgr.GetKernelTypeStrResolver()}; + kernel_registry_mgr.GetKernelTypeStrResolver(), + logger}; auto& graph = params.graph.get(); auto& capabilities = params.capabilities.get(); @@ -248,13 +249,15 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer, const KernelRegistryManager& kernel_registry_mgr, const IExecutionProvider& current_ep, + const logging::Logger& logger, std::vector>& capabilities) { const auto& ep_type = current_ep.Type(); const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type); const KernelLookup kernel_lookup{ep_type, kernel_registries_for_ep, - kernel_registry_mgr.GetKernelTypeStrResolver()}; + kernel_registry_mgr.GetKernelTypeStrResolver(), + logger}; // TODO: Provide EP with a capability to look inside the functions. capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup); @@ -359,7 +362,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, GraphPartitioner::Mode mode, int& fused_node_unique_id, const layout_transformation::TransformLayoutFunction& transform_layout_fn, - const layout_transformation::DebugGraphFn& debug_graph_fn) { + const layout_transformation::DebugGraphFn& debug_graph_fn, + const logging::Logger& logger) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. // doing it here saves all providers checking for this in GetCapability if (graph.NumberOfNodes() == 0) { @@ -373,7 +377,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, // we pass through the FuncManager from the top level graph ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr, fused_kernel_registry, current_ep, mode, fused_node_unique_id, - transform_layout_fn, debug_graph_fn)); + transform_layout_fn, debug_graph_fn, logger)); } } @@ -398,7 +402,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, std::cref(transform_layout_fn), std::cref(debug_graph_fn)}; - ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params)); + ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger)); if (capabilities.empty()) { return Status::OK(); } @@ -425,7 +429,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, Node* n = PlaceNode(graph, *capability->sub_graph, fusion_style, type, mode, fused_node_unique_id); if (n != nullptr) { // searching in kernel registries, if no kernel registered for the fused_node, use compile approach - if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *n, type)) { + if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *n, type, logger)) { nodes_to_compile.push_back(n); capabilities_to_compile.push_back(std::move(capability)); } else { @@ -559,6 +563,7 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) { static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_providers, const KernelRegistryManager& kernel_registry_mgr, Graph& graph, + const logging::Logger& logger, InlinedHashSet& not_inlined, size_t& inlined_count) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. @@ -574,6 +579,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers, kernel_registry_mgr, *subgraph, + logger, not_inlined, inlined_count)); } @@ -597,7 +603,8 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide InlinedHashSet claimed_by_ep; for (const auto& ep : execution_providers) { std::vector> capabilities; - ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, capabilities)); + ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, logger, + capabilities)); for (auto& capability : capabilities) { const auto& nodes = capability->sub_graph->nodes; if (nodes.size() == 1) { @@ -674,7 +681,7 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers context_cache_path, "' exist already."); } - Model ep_context_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + Model ep_context_model(graph.Name(), false, graph.GetModel().MetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), graph.DomainToVersionMap(), {}, logger); auto& ep_graph = ep_context_model.MainGraph(); ep_graph.SetDescription(graph.Description()); @@ -727,7 +734,8 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, GraphPartitioner::Mode mode, const ExecutionProviders& execution_providers, - KernelRegistryManager& kernel_registry_manager) { + KernelRegistryManager& kernel_registry_manager, + const logging::Logger& logger) { bool modified_graph = false; auto& graph = partition_params.graph.get(); @@ -742,7 +750,8 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(graph, func_mgr, kernel_registry_manager, fused_kernel_registry, *ep, mode, fused_node_unique_id, transform_layout_function, - partition_params.debug_graph_fn)); + partition_params.debug_graph_fn, + logger)); } // expand any nodes that have an ONNX function definition but no matching ORT kernel. @@ -762,7 +771,8 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_params, KernelRegistryManager& kernel_registry_mgr, - IExecutionProvider& current_ep) { + IExecutionProvider& current_ep, + const logging::Logger& logger) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. // doing it here saves all providers checking for this in GetCapability auto& graph = partition_params.graph.get(); @@ -776,7 +786,8 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param auto& subgraph = *entry.second; PartitionParams subgraph_partition_params = partition_params; subgraph_partition_params.graph = std::ref(subgraph); - ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(subgraph_partition_params, kernel_registry_mgr, current_ep)); + ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(subgraph_partition_params, kernel_registry_mgr, current_ep, + logger)); } } @@ -795,7 +806,7 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param }; // clang-format on - ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params)); + ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger)); if (capabilities.empty()) { return Status::OK(); } @@ -876,10 +887,11 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param // Simplified partitioning where custom EPs may produce compiled nodes. static Status PartitionOrtFormatModel(const PartitionParams& partition_params, const ExecutionProviders& execution_providers, - KernelRegistryManager& kernel_registry_manager) { + KernelRegistryManager& kernel_registry_manager, + const logging::Logger& logger) { // process full graph with each EP for (const auto& ep : execution_providers) { - ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(partition_params, kernel_registry_manager, *ep)); + ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(partition_params, kernel_registry_manager, *ep, logger)); } return Status::OK(); @@ -906,6 +918,7 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model, ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers, kernel_registry_manager, graph, + logger, not_inlined, inlined_count)); @@ -977,8 +990,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, if (mode == Mode::kNormal || mode == Mode::kAssignOnly) { #if !defined(ORT_MINIMAL_BUILD) - ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, - providers_, kernel_registry_mgr_)); + ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_, logger)); bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); @@ -991,8 +1003,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX models are not supported in this build."); #endif //! defined(ORT_MINIMAL_BUILD) } else { - ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(partition_params, - providers_, kernel_registry_mgr_)); + ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(partition_params, providers_, kernel_registry_mgr_, logger)); } #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/kernel_lookup.h b/onnxruntime/core/framework/kernel_lookup.h index 0dd17d2f4a624..fac43bad0fefb 100644 --- a/onnxruntime/core/framework/kernel_lookup.h +++ b/onnxruntime/core/framework/kernel_lookup.h @@ -21,17 +21,19 @@ class KernelLookup final : public IExecutionProvider::IKernelLookup { public: KernelLookup(ProviderType provider_type, gsl::span> kernel_registries, - const IKernelTypeStrResolver& kernel_type_str_resolver) + const IKernelTypeStrResolver& kernel_type_str_resolver, + const logging::Logger& logger) : provider_type_{provider_type}, kernel_registries_{kernel_registries}, - kernel_type_str_resolver_{kernel_type_str_resolver} { + kernel_type_str_resolver_{kernel_type_str_resolver}, + logger_{logger} { ORT_ENFORCE(!provider_type_.empty(), "provider_type must be specified."); } const KernelCreateInfo* LookUpKernel(const Node& node) const override { const KernelCreateInfo* kernel_create_info{}; for (const auto& registry : kernel_registries_) { - const auto lookup_status = registry->TryFindKernel(node, provider_type_, kernel_type_str_resolver_, + const auto lookup_status = registry->TryFindKernel(node, provider_type_, kernel_type_str_resolver_, logger_, &kernel_create_info); if (lookup_status.IsOK() && kernel_create_info != nullptr) { return kernel_create_info; @@ -45,6 +47,7 @@ class KernelLookup final : public IExecutionProvider::IKernelLookup { ProviderType provider_type_; const gsl::span> kernel_registries_; const IKernelTypeStrResolver& kernel_type_str_resolver_; + const logging::Logger& logger_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/kernel_registry.cc b/onnxruntime/core/framework/kernel_registry.cc index cbbe0f86b8b7e..8602a3b4004ff 100644 --- a/onnxruntime/core/framework/kernel_registry.cc +++ b/onnxruntime/core/framework/kernel_registry.cc @@ -183,6 +183,7 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node, ProviderType exec_provider, const IKernelTypeStrResolver* kernel_type_str_resolver, const TypeConstraintMap* type_constraints, + const logging::Logger& logger, const KernelCreateInfo** out) const { const auto& node_provider = node.GetExecutionProviderType(); const auto& expected_provider = (node_provider.empty() ? exec_provider : node_provider); @@ -215,7 +216,7 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node, std::ostream_iterator(oss, "\n")); oss << ")"; - VLOGS_DEFAULT(2) << "TryFindKernel failed, Reason: " << oss.str(); + VLOGS(logger, 2) << "TryFindKernel failed, Reason: " << oss.str(); return Status(common::ONNXRUNTIME, common::FAIL, oss.str()); } @@ -224,14 +225,16 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node, Status KernelRegistry::TryFindKernel(const Node& node, ProviderType exec_provider, const IKernelTypeStrResolver& kernel_type_str_resolver, + const logging::Logger& logger, const KernelCreateInfo** out) const { - return TryFindKernelImpl(node, exec_provider, &kernel_type_str_resolver, nullptr, out); + return TryFindKernelImpl(node, exec_provider, &kernel_type_str_resolver, nullptr, logger, out); } Status KernelRegistry::TryFindKernel(const Node& node, ProviderType exec_provider, const TypeConstraintMap& type_constraints, + const logging::Logger& logger, const KernelCreateInfo** out) const { - return TryFindKernelImpl(node, exec_provider, nullptr, &type_constraints, out); + return TryFindKernelImpl(node, exec_provider, nullptr, &type_constraints, logger, out); } static bool KernelDefCompatible(int version, const KernelDef& kernel_def, @@ -261,6 +264,7 @@ Status KernelRegistry::TryFindKernel(ProviderType exec_provider, std::string_view domain, int version, const KernelRegistry::TypeConstraintMap& type_constraints, + const logging::Logger& logger, const KernelCreateInfo** out) const { auto range = kernel_creator_fn_map_.equal_range(GetMapKey(op_type, domain, exec_provider)); if (out) *out = nullptr; @@ -289,7 +293,7 @@ Status KernelRegistry::TryFindKernel(ProviderType exec_provider, std::ostream_iterator(oss, "\n")); oss << ")"; - VLOGS_DEFAULT(2) << "TryFindKernel failed, Reason: " << oss.str(); + VLOGS(logger, 2) << "TryFindKernel failed, Reason: " << oss.str(); return Status(common::ONNXRUNTIME, common::FAIL, oss.str()); } diff --git a/onnxruntime/core/framework/kernel_registry_manager.cc b/onnxruntime/core/framework/kernel_registry_manager.cc index f8ccdb8fb0238..721353854a474 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.cc +++ b/onnxruntime/core/framework/kernel_registry_manager.cc @@ -57,7 +57,7 @@ void KernelRegistryManager::RegisterKernelRegistry(std::shared_ptrTryFindKernel(node, std::string(), GetKernelTypeStrResolver(), kernel_create_info); + status = registry->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), logger, kernel_create_info); if (status.IsOK()) { return status; } @@ -95,7 +95,7 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node, } if (p != nullptr) { - status = p->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), kernel_create_info); + status = p->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), logger, kernel_create_info); if (status.IsOK()) { return status; } @@ -104,10 +104,14 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node, return Status(ONNXRUNTIME, NOT_IMPLEMENTED, create_error_message("Failed to find kernel for ")); } -bool KernelRegistryManager::HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type) { +bool KernelRegistryManager::HasImplementationOf(const KernelRegistryManager& r, + const Node& node, + const std::string& provider_type, + const logging::Logger& logger) { const auto kernel_registries = r.GetKernelRegistriesByProviderType(provider_type); return std::any_of(kernel_registries.begin(), kernel_registries.end(), [&](const KernelRegistry* kernel_registry) { - return KernelRegistry::HasImplementationOf(*kernel_registry, node, provider_type, r.GetKernelTypeStrResolver()); + return KernelRegistry::HasImplementationOf(*kernel_registry, node, provider_type, r.GetKernelTypeStrResolver(), + logger); }); } diff --git a/onnxruntime/core/framework/kernel_registry_manager.h b/onnxruntime/core/framework/kernel_registry_manager.h index 1da73208cb536..72f0ed3c6268a 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.h +++ b/onnxruntime/core/framework/kernel_registry_manager.h @@ -67,13 +67,14 @@ class KernelRegistryManager { // This function assumes the node is already assigned to an execution provider // Don't call this function before graph partition is done - Status SearchKernelRegistry(const Node& node, + Status SearchKernelRegistry(const Node& node, const logging::Logger& logger, /*out*/ const KernelCreateInfo** kernel_create_info) const; /** * Whether this node can be run on this provider */ - static bool HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type); + static bool HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type, + const logging::Logger& logger); Status CreateKernel(const Node& node, const IExecutionProvider& execution_provider, diff --git a/onnxruntime/core/framework/sequential_executor.cc b/onnxruntime/core/framework/sequential_executor.cc index 6ea12c7f3336b..2185b8332b9cf 100644 --- a/onnxruntime/core/framework/sequential_executor.cc +++ b/onnxruntime/core/framework/sequential_executor.cc @@ -68,7 +68,7 @@ static void CalculateTotalOutputSizes(OpKernelContextInternal* op_kernel_context int output_count = op_kernel_context->OutputCount(); for (auto i = 0; i < output_count; i++) { const OrtValue* p_output = op_kernel_context->GetOutputMLValue(i); - if (p_output != nullptr && p_output->IsTensor()) { + if (p_output != nullptr && p_output->IsTensor() && p_output->IsAllocated()) { const auto& tensor = p_output->Get(); size_t tensor_size = tensor.SizeInBytes(); #if defined(TRACE_EXECUTION) @@ -104,7 +104,7 @@ static void CalculateTotalInputSizes(const OpKernelContextInternal* op_kernel_co const int input_count = op_kernel_context->InputCount(); for (auto i = 0; i < input_count; i++) { const OrtValue* p_input = op_kernel_context->GetInputMLValue(i); - if (p_input != nullptr && p_input->IsTensor()) { + if (p_input != nullptr && p_input->IsTensor() && p_input->IsAllocated()) { const OpKernelInfo& op_kernel_info = p_op_kernel->Info(); const Tensor* p_tensor = nullptr; bool is_param = op_kernel_info.TryGetConstantInput(i, &p_tensor); diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index 18405231750ba..8d4db36106f28 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -83,11 +83,6 @@ struct SessionOptions { // enable profiling for this session. bool enable_profiling = false; - // save pre-packed constant external initializers instead of original initializers to onnxruntime data file. - // Only useful for models run on PC with CPU so ORT could load prepacked weights directly from - // ONNX data file with mmap and no need to do prepacking on fly to save a lot of heap memory. - bool save_prepacked_constant_initializers = false; - // Non empty filepath enables serialization of the transformed optimized model to the specified filepath. // // Set session config value for ORT_SESSION_OPTIONS_CONFIG_SAVE_MODEL_FORMAT to 'ORT' or 'ONNX' to explicitly @@ -196,7 +191,6 @@ inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_ << " execution_mode:" << session_options.execution_mode << " execution_order:" << session_options.execution_order << " enable_profiling:" << session_options.enable_profiling - << " save_prepacked_constant_initializers:" << session_options.save_prepacked_constant_initializers << " optimized_model_filepath:" << ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(session_options.optimized_model_filepath) << " enable_mem_pattern:" << session_options.enable_mem_pattern << " enable_mem_reuse:" << session_options.enable_mem_reuse diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 943db091b341f..0ac2271ba09f1 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -14,7 +14,6 @@ #include "core/framework/op_kernel.h" #include "core/framework/ort_value_pattern_planner.h" #include "core/framework/session_state_utils.h" -#include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" #include "core/providers/cpu/controlflow/utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" @@ -179,7 +178,7 @@ Status SessionState::PopulateKernelCreateInfo(const KernelRegistryManager& kerne bool saving_ort_format) { for (auto& node : graph_.Nodes()) { const KernelCreateInfo* kci = nullptr; - auto status = kernel_registry_manager.SearchKernelRegistry(node, &kci); + auto status = kernel_registry_manager.SearchKernelRegistry(node, logger_, &kci); if (!status.IsOK() && saving_ort_format) { // if we didn't find the kernel and are saving to ORT format an EP that compiles nodes is enabled. // in that case we assigned the node to that EP but do not compile it into a fused node. @@ -188,7 +187,7 @@ Status SessionState::PopulateKernelCreateInfo(const KernelRegistryManager& kerne // at runtime when the model is loaded in a minimal build, the compiling EP will replace this node if possible. // if that's not possible for some reason we can fallback to the CPU EP implementation. node.SetExecutionProviderType(kCpuExecutionProvider); - status = kernel_registry_manager.SearchKernelRegistry(node, &kci); + status = kernel_registry_manager.SearchKernelRegistry(node, logger_, &kci); } ORT_RETURN_IF_ERROR(status); @@ -398,18 +397,12 @@ static std::string GenerateKeyForPrepackedWeightsMap(const std::string& op_type, } Status SessionState::PrepackConstantInitializedTensors(InlinedHashMap& constant_initializers_use_count, - const std::unordered_map& initializers_to_share_map, - bool save_prepacked_constant_initializers, - PrePackInitializers& pre_packed_initializers) { - auto prepacked_constant_weights = [this, &constant_initializers_use_count, &initializers_to_share_map, - save_prepacked_constant_initializers, &pre_packed_initializers]( + const std::unordered_map& initializers_to_share_map) { + auto prepacked_constant_weights = [this, &constant_initializers_use_count, &initializers_to_share_map]( bool should_cache_prepacked_weights_for_shared_initializers) -> Status { - std::unordered_map pre_packed_kernel_input_map; for (auto& node : GetGraphViewer().Nodes()) { auto kernel = GetMutableKernel(node.Index()); - auto kernel_name = kernel->Info().node().Name(); int input_idx = 0; - bool is_kernel_prepacked = false; for (auto& input_def : node.InputDefs()) { if (input_def->Exists()) { const std::string& input_name = input_def->Name(); @@ -421,27 +414,16 @@ Status SessionState::PrepackConstantInitializedTensors(InlinedHashMapGetOrtValueNameIdxMap().GetIdx(input_name, ort_value_idx).IsOK()) { std::unordered_map& constant_initialized_tensors = st->constant_initialized_tensors_; - if (constant_initialized_tensors.count(ort_value_idx) && !is_kernel_prepacked) { + if (constant_initialized_tensors.count(ort_value_idx)) { bool is_packed = false; const Tensor& const_initialized_tensor = constant_initialized_tensors[ort_value_idx].Get(); auto iter = initializers_to_share_map.find(input_name); bool is_shared_initializer = (iter != initializers_to_share_map.end()); - // found pre-packed constant initializers from data file, no need to do pre-packing again - // apply pre-packed tensor to kernel so kernel can use it directly - if (pre_packed_initializers.pre_packed_initializer_names_read_from_file.count(input_name) != 0) { - is_packed = true; - - // kernel like Matmul_nbits will call prepack multiple times with input_B and possibly scales/zero_points. - // If prepacked weights already read from ONNX data file (this happens we ORT reads data file with prepacked - // weights serialized), only need to set prepacked weights once to kernel. - is_kernel_prepacked = true; - ORT_THROW_IF_ERROR(kernel->SetPrePackTensor(input_idx, const_initialized_tensor)); - } // Caching pre-packed weights is limited to shared initializers associated with the CPU EP for now - else if (is_shared_initializer && should_cache_prepacked_weights_for_shared_initializers && - node.GetExecutionProviderType() == kCpuExecutionProvider) { // caching of pre-packed weights' turned ON + if (is_shared_initializer && should_cache_prepacked_weights_for_shared_initializers && + node.GetExecutionProviderType() == kCpuExecutionProvider) { // caching of pre-packed weights' turned ON AllocatorPtr allocator_for_caching = prepacked_weights_container_->GetOrCreateAllocator(CPU); ORT_ENFORCE(allocator_for_caching.get() != nullptr); @@ -453,7 +435,7 @@ Status SessionState::PrepackConstantInitializedTensors(InlinedHashMapPrePack(const_initialized_tensor, input_idx, allocator_for_caching, - save_prepacked_constant_initializers, is_packed, + is_packed, &weights_to_be_filled_in)); if (is_packed) { @@ -500,50 +482,18 @@ Status SessionState::PrepackConstantInitializedTensors(InlinedHashMapInfo().GetDevice(OrtMemType::OrtMemTypeDefault)); ORT_RETURN_IF_ERROR(kernel->PrePack(const_initialized_tensor, input_idx, session_cpu_alloc, // use allocator tied to this session - save_prepacked_constant_initializers, is_packed, nullptr // no caching required )); } if (is_packed) { - // if intended to save prepacked initializers, get prepacked tensors from kernel and save in hashmap, - // will save to data file later - if (save_prepacked_constant_initializers) { - auto tensor = kernel->GetPrePackTensor(input_idx); - - if (tensor != std::nullopt) { - // save prepacked initializers per initializer and kernel since one initializer could - // be used by multiple kernels - pre_packed_initializers.pre_packed_initializers_to_save[input_name][kernel_name] = std::move(tensor.value()); - - pre_packed_kernel_input_map[kernel_name] = input_name; - } - } - ++number_of_prepacks_counter_; - // if constant_initialized_tensor is already pre-packed, don't need to remove it - if (pre_packed_initializers.pre_packed_initializer_names_read_from_file.count(input_name) == 0 && - constant_initializers_use_count.count(input_name) && --constant_initializers_use_count[input_name] == 0) { + if (constant_initializers_use_count.count(input_name) && --constant_initializers_use_count[input_name] == 0) { // release the constant initialized tensor st->initialized_tensors_.erase(ort_value_idx); constant_initialized_tensors.erase(ort_value_idx); } - } else { - // handle prepack for matmul_nbits, it will prepack several times but set is_packed - // to false for scales and zero_points, we keep scales and zero_points as it is only - // update packed_tensor to input_B. - // TODO: this logic works with matmul_nbits kernel but if other kernels also call prepack - // multiple times and use different initializers to store prepacked weights, this piece of logic - // might introduce bug and need a per kernel strategy to update prepacked weights. - if (save_prepacked_constant_initializers && pre_packed_kernel_input_map.count(kernel_name)) { - auto tensor = kernel->GetPrePackTensor(input_idx); - - if (tensor != std::nullopt) { - auto existing_input_name = pre_packed_kernel_input_map[kernel_name]; - pre_packed_initializers.pre_packed_initializers_to_save[existing_input_name][kernel_name] = std::move(tensor.value()); - } - } } } // stop searching in 2 cases: @@ -1226,7 +1176,6 @@ static Status VerifyEachNodeIsAssignedToAnEp(const Graph& graph, const logging:: Status SessionState::FinalizeSessionState(const std::basic_string& graph_location, const KernelRegistryManager& kernel_registry_manager, - PrePackInitializers& pre_packed_initializers, bool remove_initializers, bool saving_ort_format) { // recursively create the subgraph session state instances and populate the kernel create info in them. @@ -1240,7 +1189,7 @@ Status SessionState::FinalizeSessionState(const std::basic_string constant_initializers_use_count; ComputeConstantInitializerUseCount(graph_, constant_initializers_use_count); return FinalizeSessionStateImpl(graph_location, kernel_registry_manager, nullptr, sess_options_, - remove_initializers, constant_initializers_use_count, pre_packed_initializers); + remove_initializers, constant_initializers_use_count); } static Status Index(const OrtValueNameIdxMap& ort_value_name_idx_map, @@ -1374,7 +1323,6 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string& constant_initializers_use_count, - PrePackInitializers& pre_packed_initializers, const InlinedHashMap& outer_scope_node_arg_to_location_map, bool graph_info_already_created) { if (!graph_info_already_created) { @@ -1474,8 +1422,6 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string> - typedef std::unordered_map> PrePackedTensorsToSave; - PrePackedTensorsToSave pre_packed_initializers_to_save; - - // This set is used during model load with prepacked initializer serialized in external data file. - // ORT reads prepacked initializers and store their name into this set so we could skip PrePack - // process later to save heap memory. Prepacked tensor itself is saved in session state's constant_initialized_tensors_. - typedef std::unordered_set PrePackedTensorNamesReadFromFile; - PrePackedTensorNamesReadFromFile pre_packed_initializer_names_read_from_file; - }; - Status FinalizeSessionState(const std::basic_string& graph_loc, const KernelRegistryManager& kernel_registry_manager, - PrePackInitializers& pre_packed_initializers, bool remove_initializers = true, bool saving_ort_format = false); @@ -338,15 +321,6 @@ class SessionState { return parent_; } - Status FinalizeSessionState(const std::basic_string& graph_loc, - const KernelRegistryManager& kernel_registry_manager, - bool remove_initializers = true, - bool saving_ort_format = false) { - PrePackInitializers pre_packed_initializers; - return FinalizeSessionState(graph_loc, kernel_registry_manager, pre_packed_initializers, - remove_initializers, saving_ort_format); - } - // Clear all removable attributes if they exists. // The function logs the list of removable attributes for every node. void PruneRemovableAttributes(); @@ -406,13 +380,9 @@ class SessionState { /** * Prepack the constant initialized tensors for better performance. * The original constant initialized tensors will be removed to save memory. - * For model with prepacked initializer serialized into ONNX data file, - * PrePack will be skipped to save memory. */ Status PrepackConstantInitializedTensors(InlinedHashMap& constant_initializers_use_count, - const std::unordered_map& initializers_to_share_map, - bool save_prepacked_constant_initializers, - PrePackInitializers& pre_packed_initializers); + const std::unordered_map& initializers_to_share_map); SessionState* GetMutableSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name); @@ -430,7 +400,6 @@ class SessionState { const SessionOptions& session_options, bool remove_initializers, InlinedHashMap& constant_initializers_use_count, - PrePackInitializers& pre_packed_initializers, const InlinedHashMap& outer_scope_node_arg_to_location_map = {}, bool graph_info_already_created = false); diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 3424f40e79c01..2c74805c57dce 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -21,6 +21,7 @@ #include "core/framework/ort_value_pattern_planner.h" #include "core/framework/ort_value_name_idx_map.h" #include "core/framework/sequential_execution_plan.h" +#include "core/framework/session_state.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" #include "core/framework/bfc_arena.h" @@ -71,7 +72,6 @@ static inline common::Status ExtDataTensorProtoToTensor(const Env& env, const std::basic_string& proto_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, Tensor& tensor, OrtCallback& ext_data_deleter, - SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile& pre_packed_initializers_name_set, Tensor* buffered_tensor = nullptr) { ORT_ENFORCE(utils::HasExternalData(tensor_proto)); @@ -79,7 +79,7 @@ static inline common::Status ExtDataTensorProtoToTensor(const Env& env, SafeInt ext_data_len = 0; ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(env, proto_path.c_str(), tensor_proto, ext_data_buf, ext_data_len, ext_data_deleter, - &pre_packed_initializers_name_set, buffered_tensor)); + buffered_tensor)); // NB: creating a do-nothing allocator per tensor is wasteful; can perhaps be // avoided if the Tensor class implements the do-nothing behavior when given a @@ -100,7 +100,6 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st const AllocatorPtr& alloc, const AllocatorPtr& default_cpu_alloc, OrtValue& ort_value, const DataTransferManager& data_transfer_mgr, const ExternalDataLoaderManager& external_data_loader_mgr, - SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile& pre_packed_initializers_name_set, bool use_device_allocator_for_initializers = false, Tensor* buffered_tensor = nullptr) { if (bool(alloc) == (m != nullptr)) { @@ -140,7 +139,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st // TensorProtoToTensor it would copy the data, causing unnecessary overhead OrtCallback ext_data_deleter; ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_tensor, - ext_data_deleter, pre_packed_initializers_name_set, buffered_tensor)); + ext_data_deleter, buffered_tensor)); ExtDataValueDeleter deleter{ext_data_deleter, p_tensor.get()}; MLDataType ml_tensor_type = DataTypeImpl::GetType(); @@ -164,7 +163,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st OrtCallback ext_data_deleter; std::optional scoped_ort_callback_invoker; ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_deserialize_tensor, - ext_data_deleter, pre_packed_initializers_name_set, buffered_tensor)); + ext_data_deleter, buffered_tensor)); scoped_ort_callback_invoker = ScopedOrtCallbackInvoker(ext_data_deleter); // TODO!! Need a temp buffer allocator for non-escape buffers that maybe too big for stack allocation. @@ -273,8 +272,7 @@ common::Status SaveInitializedTensors( const ExecutionPlanBase& exec_plan, const SessionOptions& session_options, const MemoryProfileFunction& memory_profile_func, - std::unordered_map>& buffered_tensors, - SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile& pre_packed_initializers_name_set) { + std::unordered_map>& buffered_tensors) { LOGS(logger, INFO) << "Saving initialized tensors."; ORT_ENFORCE(ort_value_name_idx_map.MaxIdx() > -1, "OrtValue indexes should have been populated."); @@ -403,7 +401,6 @@ common::Status SaveInitializedTensors( Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, (m.has_value()) ? &*m : nullptr, alloc, default_cpu_alloc, ort_value, data_transfer_mgr, external_data_loader_mgr, - pre_packed_initializers_name_set, use_device_allocator_for_initializers, p_tensor); if (!st.IsOK()) { std::ostringstream oss; diff --git a/onnxruntime/core/framework/session_state_utils.h b/onnxruntime/core/framework/session_state_utils.h index 4de501b6f7429..af27f5caba0f4 100644 --- a/onnxruntime/core/framework/session_state_utils.h +++ b/onnxruntime/core/framework/session_state_utils.h @@ -12,7 +12,6 @@ #include "core/framework/tensor.h" #include "core/framework/tensor_allocator.h" #include "core/framework/session_options.h" -#include "core/framework/session_state.h" #include "core/framework/sequential_execution_plan.h" #include "core/platform/path_lib.h" @@ -51,8 +50,7 @@ common::Status SaveInitializedTensors( const ExecutionPlanBase& exec_plan, const SessionOptions& session_options, const MemoryProfileFunction& memory_profile_func, - std::unordered_map>& buffered_tensors, - SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile& pre_packed_initializers_name_set); + std::unordered_map>& buffered_tensors); common::Status AllocateTensor( const onnxruntime::MemBuffer* m, diff --git a/onnxruntime/core/framework/tensor_external_data_info.cc b/onnxruntime/core/framework/tensor_external_data_info.cc index bcd04effe2bd4..93146e66d9f24 100644 --- a/onnxruntime/core/framework/tensor_external_data_info.cc +++ b/onnxruntime/core/framework/tensor_external_data_info.cc @@ -40,8 +40,6 @@ Status ExternalDataInfo::Create(const RepeatedPtrField& return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "parsing ", stringmap.value(), " failed"); } else if (stringmap.key() == "checksum" && !stringmap.value().empty()) { out->checksum_ = stringmap.value(); - } else if (stringmap.key() == "prepacked" && !stringmap.value().empty()) { - out->prepacked_ = stringmap.value() == "1"; } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "model format error!"); } diff --git a/onnxruntime/core/framework/tensor_external_data_info.h b/onnxruntime/core/framework/tensor_external_data_info.h index c2490f5cc5bc2..afc8fda6c3037 100644 --- a/onnxruntime/core/framework/tensor_external_data_info.h +++ b/onnxruntime/core/framework/tensor_external_data_info.h @@ -23,8 +23,6 @@ class ExternalDataInfo { const std::string& GetChecksum() const { return checksum_; } - bool GetPrePacked() const noexcept { return prepacked_; } - // If the value of 'offset' or 'length' field is larger the max value of ssize_t, this function will treat it as a // wrong value and return FAIL. static common::Status Create( @@ -38,6 +36,5 @@ class ExternalDataInfo { // 0 means the whole file size_t length_ = 0; std::string checksum_; - bool prepacked_ = false; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 0c69ee11f62bc..2af9f95ad059e 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -230,12 +230,11 @@ Status TensorProtoToOrtValueImpl(const Env& env, const std::filesystem::path& mo namespace utils { -static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, - const std::filesystem::path& tensor_proto_dir, - std::basic_string& external_file_path, - onnxruntime::FileOffsetType& file_offset, - SafeInt& tensor_byte_size, - bool& pre_packed) { +Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& tensor_proto_dir, + std::basic_string& external_file_path, + onnxruntime::FileOffsetType& file_offset, + SafeInt& tensor_byte_size) { ORT_RETURN_IF_NOT(onnxruntime::utils::HasExternalData(tensor_proto), "Tensor does not have external data to read from."); @@ -245,8 +244,6 @@ static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_prot std::unique_ptr external_data_info; ORT_RETURN_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info)); - pre_packed = external_data_info->GetPrePacked(); - const auto& location = external_data_info->GetRelPath(); external_file_path = location == onnxruntime::utils::kTensorProtoMemoryAddressTag ? std::filesystem::path(location) @@ -268,11 +265,6 @@ void SetRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto, std::str tensor_proto.set_raw_data(std::move(param)); } -Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, const std::filesystem::path& tensor_proto_dir, std::basic_string& external_file_path, onnxruntime::FileOffsetType& file_offset, SafeInt& tensor_byte_size) { - bool pre_packed = false; - return GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_file_path, file_offset, tensor_byte_size, pre_packed); -} - void ConvertRawDataInTensorProto(TensorProto* tensor) { size_t element_size = 1; char* bytes = NULL; @@ -996,7 +988,7 @@ static Status GetFileContent(const Env& env, const std::filesystem::path& file_p Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& model_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, void*& ext_data_buf, SafeInt& ext_data_len, OrtCallback& ext_data_deleter, - SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile* pre_packed_initializers_name_set, Tensor* buffered_tensor) { + Tensor* buffered_tensor) { ORT_ENFORCE(utils::HasExternalData(tensor_proto)); std::basic_string tensor_proto_dir; if (!model_path.empty()) { @@ -1005,13 +997,8 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo std::basic_string external_data_file_path; FileOffsetType file_offset; SafeInt raw_data_safe_len = 0; - bool pre_packed = false; ORT_RETURN_IF_ERROR( - GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_data_file_path, file_offset, raw_data_safe_len, pre_packed)); - - if (pre_packed && pre_packed_initializers_name_set != nullptr) { - (*pre_packed_initializers_name_set).insert(tensor_proto.name()); - } + GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_data_file_path, file_offset, raw_data_safe_len)); if (external_data_file_path == onnxruntime::utils::kTensorProtoMemoryAddressTag) { // the value in location is the memory address of the data @@ -1121,7 +1108,7 @@ Status TensorProtoToTensor(const Env& env, const std::filesystem::path& model_pa OrtCallback& d = deleter_for_file_data.d; if (utils::HasExternalData(tensor_proto)) { - ORT_RETURN_IF_ERROR(GetExtDataFromTensorProto(env, model_path, tensor_proto, raw_data, raw_data_len, d, nullptr)); + ORT_RETURN_IF_ERROR(GetExtDataFromTensorProto(env, model_path, tensor_proto, raw_data, raw_data_len, d)); } else if (utils::HasRawData(tensor_proto)) { raw_data = const_cast(tensor_proto.raw_data().data()); // TODO The line above has const-correctness issues. Below is a possible fix which copies the tensor_proto data diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index 770132f8e95fc..262f7adaca1cb 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -17,19 +17,26 @@ #include "core/framework/external_data_loader.h" #include "core/framework/ort_value.h" #include "core/framework/mem_buffer.h" -#include "core/framework/session_state.h" #include "core/framework/tensor_external_data_info.h" #include "core/graph/onnx_protobuf.h" #include "core/platform/env.h" namespace onnxruntime { namespace utils { +/** + * This function is used to get the external data info from the given tensor proto. + * @param tensor_proto given initializer tensor + * @param tensor_proto_dir directory of the tensor proto file + * @param external_file_path output external file path + * @param file_offset output tensor offset + * @param tensor_byte_size output tensor byte size + * @returns Status::OK() if the function is executed successfully + */ Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, const std::filesystem::path& tensor_proto_dir, std::basic_string& external_file_path, onnxruntime::FileOffsetType& file_offset, SafeInt& tensor_byte_size); - /** * This function is used to convert the endianess of Tensor data. * Mostly, will be used in big endian system to support the model file @@ -165,7 +172,6 @@ common::Status GetExtDataFromTensorProto(const Env& env, const std::filesystem:: const ONNX_NAMESPACE::TensorProto& tensor_proto, void*& ext_data_buf, SafeInt& ext_data_len, OrtCallback& ext_data_deleter, - SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile* pre_packed_initializers_name_set, Tensor* buffered_tensor = nullptr); // Given a tensor proto with external data obtain a tensor using the specified custom external data loader. diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 5402345447706..ff664c2c76703 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -57,7 +57,6 @@ void DestroyStrings(void* p_data, int64_t elements) { bool ProviderIsCpuBased(const std::string& provider_type) { return provider_type == onnxruntime::kCpuExecutionProvider || provider_type == onnxruntime::kDnnlExecutionProvider || - provider_type == onnxruntime::kTvmExecutionProvider || provider_type == onnxruntime::kVitisAIExecutionProvider || provider_type == onnxruntime::kOpenVINOExecutionProvider || provider_type == onnxruntime::kNnapiExecutionProvider || @@ -1064,11 +1063,5 @@ bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index return false; } -std::string GetPrepackedInitializerName(const std::string& initializer_name, const std::string& node_name) { - const std::string seperator = ":"; - - return initializer_name + seperator + node_name; -} - } // namespace utils } // namespace onnxruntime diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index db38ef1675595..afdb5a2cb27f5 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -234,8 +234,6 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType); -std::string GetPrepackedInitializerName(const std::string& initializer_name, const std::string& node_name); - #ifdef ENABLE_TRAINING common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context); #endif diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 0a261d8f731f2..f2a2a52f8334f 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -908,7 +908,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(9, "cache_indirection", - // This input is useful for CUDA EP only. "A buffer of shape [batch_size, beam_width, max_output_length] where an `[i, j, k]` entry specifies " "which beam the `k`-th token came from for the `j`-th beam for batch `i` in the current iteration", "M", diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 09a4a77780916..c7a0793c4748f 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3335,6 +3335,11 @@ void RegisterContribSchemas() { AttributeProto::STRING, OPTIONAL_VALUE) .Attr("notes", "(Optional) Some notes for the model", AttributeProto::STRING, OPTIONAL_VALUE) + .Attr( + "max_size", + "max size in the context. Usage depend on the EP.", + AttributeProto::INT, + static_cast(0)) .AllowUncheckedAttributes() .Input( 0, diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index 6f1f1c831d191..5a3cd86b04492 100644 --- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ -9,7 +9,7 @@ #include "core/graph/constants.h" #include "core/graph/contrib_ops/contrib_defs.h" #include "core/graph/contrib_ops/shape_inference_functions.h" -#include "onnx/onnx-ml.pb.h" // ? +#include "core/graph/onnx_protobuf.h" // Suppress a warning: global initializer calls a non-constexpr function 'symbol' which is from // ONNX_OPERATOR_SET_SCHEMA_EX macro and only happens in debug build @@ -23,7 +23,7 @@ void convTransposeShapeInference(InferenceContext& ctx); void convPoolShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, bool use_dilation, bool require_kernel_shape, int input1Idx, int input2Idx); namespace defs::math::utils { - void MatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, int input2Idx); +void MatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, int input2Idx); } } // namespace ONNX_NAMESPACE @@ -822,10 +822,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( } } - if (all_lengths_known) { - output_shape->mutable_dim(axis)->set_dim_value(total_length); - } - })); + if (all_lengths_known) { + output_shape->mutable_dim(axis)->set_dim_value(total_length); + } + })); ONNX_MS_OPERATOR_SET_SCHEMA(QLinearWhere, 1, OpSchema() .SetDoc("Return elements, either from X or Y, depending on condition.") @@ -955,7 +955,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( AttributeProto::INT, static_cast(0)) .Attr("do_rotary", "Whether to use rotary position embedding. Default value is 0.", AttributeProto::INT, OPTIONAL_VALUE) - .Attr("past_present_share_buffer", "Corresponding past and present are same tensor, its shape is " + .Attr("past_present_share_buffer", + "Corresponding past and present are same tensor, its shape is " "(2, batch_size, num_heads, max_sequence_length, head_size)", AttributeProto::INT, OPTIONAL_VALUE) .Attr("mask_filter_value", diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 3f50841f50913..e8a5855b36496 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -4084,75 +4084,10 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProto() const { return result; } -void Graph::SetUpExternalInitializer(const Graph::OffsetAlignmentInfo& align_info, - size_t tensor_bytes_size, - int64_t& external_offset, - std::ofstream& external_stream, - gsl::span raw_data, - ONNX_NAMESPACE::TensorProto& output_proto, - const std::filesystem::path& external_file_path, - const ONNX_NAMESPACE::TensorProto& initializer, - bool is_prepacked) { - // update external_offset for alignment - // need to do padding before write actual tensor data as we do offset alignment at the begin of - // large tensors (offset need to be page aligned and alloction granularity aligned) like below: - // \242\2557\256\023.\031&0000000000000000\332)k+\253\246\342\246(&\006!\347\232\374\236\325\026\032+\36XXXX - // |<---small tensor---->|<---padding--->|<------------------large tensor----------------------------->| - if (align_info.align_offset && static_cast(tensor_bytes_size) > align_info.align_threshold) { - // Align to the larger of the page size or the allocation granularity - int64_t alignment_factor = std::max(static_cast(4096), align_info.allocation_granularity); - // Align to the next page or alloc granularity boundary - int64_t new_external_offset = static_cast( - std::floor((external_offset + alignment_factor - 1) / alignment_factor)) * - alignment_factor; - - // padding tensor with zeros for alignment - InlinedVector paddings; - size_t padding_size = SafeInt(new_external_offset - external_offset); - paddings.reserve(padding_size); - for (size_t index = 0; index != padding_size; ++index) { - paddings.push_back(0x0); - } - external_stream.write(reinterpret_cast(paddings.data()), padding_size); - - external_offset = new_external_offset; - } - - external_stream.write(reinterpret_cast(raw_data.data()), tensor_bytes_size); - - output_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL); - ONNX_NAMESPACE::StringStringEntryProto* location = output_proto.add_external_data(); - location->set_key("location"); - location->set_value(ToUTF8String(external_file_path.native())); - ONNX_NAMESPACE::StringStringEntryProto* offset = output_proto.add_external_data(); - offset->set_key("offset"); - offset->set_value(std::to_string(external_offset)); - ONNX_NAMESPACE::StringStringEntryProto* length = output_proto.add_external_data(); - length->set_key("length"); - length->set_value(std::to_string(tensor_bytes_size)); - - if (is_prepacked) { - ONNX_NAMESPACE::StringStringEntryProto* pre_packed = output_proto.add_external_data(); - pre_packed->set_key("prepacked"); - pre_packed->set_value("1"); - } - - output_proto.set_name(initializer.name()); - output_proto.set_data_type(initializer.data_type()); - for (int i = 0; i != initializer.dims_size(); ++i) { - output_proto.add_dims(initializer.dims(i)); - } - output_proto.set_doc_string(initializer.doc_string()); - - external_offset += tensor_bytes_size; -} - ONNX_NAMESPACE::GraphProto Graph::ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_path, const std::filesystem::path& model_file_path, size_t initializer_size_threshold, - const OffsetAlignmentInfo& align_info, - bool save_prepacked_constant_initializers, - PrePackedTensorProtoToSave& pre_packed_initializers) const { + const OffsetAlignmentInfo& align_info) const { GraphProto result; ToGraphProtoInternal(result); ORT_ENFORCE(external_file_path.is_relative()); @@ -4171,34 +4106,6 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProtoWithExternalInitializers(const std #endif for (const auto& initializer : graph_proto_->initializer()) { - bool use_pre_packed_initializer = false; - InlinedVector pre_packed_initializers_tensor_proto; - // If this initializer has been prepacked, saved prepacked external initializer instead of original one. - // Since one initializer could be used by multiple kernels and been prepacked differently, - // Save each prepacked initializers seperately, chagne the initializer name to [initializer_name]:[node_name] - // to avoid conflict. Change the node input name accordingly. - // IT could potentially make the ONNX data file larger since we store multiple prepacked initializers into disk - // but this could be rare case. - if (save_prepacked_constant_initializers && pre_packed_initializers.count(initializer.name())) { - for (const auto& item : pre_packed_initializers[initializer.name()]) { - auto& node_name = item.first; - std::string prepacked_initializer_name = utils::GetPrepackedInitializerName(initializer.name(), node_name); - pre_packed_initializers_tensor_proto.push_back(item.second); - use_pre_packed_initializer = true; - - for (auto& node : *result.mutable_node()) { - if (node.name() == node_name) { - int input_index = 0; - for (const auto& input : node.input()) { - if (input == initializer.name()) { - node.set_input(input_index, prepacked_initializer_name); - } - input_index += 1; - } - } - } - } - } #if !defined(DISABLE_SPARSE_TENSORS) if (sparse_end != sparse_tensor_names_.find(initializer.name())) { // Sparse tensors are added to the ONNX file. @@ -4207,39 +4114,61 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProtoWithExternalInitializers(const std ORT_ENFORCE(status.IsOK(), "Failed to convert dense initializer to sparse"); } else { #endif - if (use_pre_packed_initializer) { - for (const auto& pre_packed_initializer : pre_packed_initializers_tensor_proto) { - // Dense tensors larger than the threshold are added to the external file. - TensorProto* output_proto = result.add_initializer(); - std::vector raw_data; - size_t tensor_bytes_size = 0; - - ORT_THROW_IF_ERROR(utils::UnpackInitializerData(pre_packed_initializer, model_path, raw_data)); - tensor_bytes_size = raw_data.size(); - if (tensor_bytes_size < initializer_size_threshold) { - *output_proto = pre_packed_initializer; - continue; - } + // Dense tensors larger than the threshold are added to the external file. + TensorProto* output_proto = result.add_initializer(); + + std::vector raw_data; + ORT_THROW_IF_ERROR(utils::UnpackInitializerData(initializer, model_path, raw_data)); + size_t tensor_bytes_size = raw_data.size(); + if (tensor_bytes_size < initializer_size_threshold) { + *output_proto = initializer; + continue; + } - SetUpExternalInitializer(align_info, tensor_bytes_size, external_offset, external_stream, - raw_data, *output_proto, external_file_path, pre_packed_initializer, true); - } - } else { - // Dense tensors larger than the threshold are added to the external file. - TensorProto* output_proto = result.add_initializer(); - std::vector raw_data; - size_t tensor_bytes_size = 0; - - ORT_THROW_IF_ERROR(utils::UnpackInitializerData(initializer, model_path, raw_data)); - tensor_bytes_size = raw_data.size(); - if (tensor_bytes_size < initializer_size_threshold) { - *output_proto = initializer; - continue; + // update external_offset for alignment + // need to do padding before write actual tensor data as we do offset alignment at the begin of + // large tensors (offset need to be page aligned and alloction granularity aligned) like below: + // \242\2557\256\023.\031&0000000000000000\332)k+\253\246\342\246(&\006!\347\232\374\236\325\026\032+\36XXXX + // |<---small tensor---->|<---padding--->|<------------------large tensor----------------------------->| + if (align_info.align_offset && static_cast(tensor_bytes_size) > align_info.align_threshold) { + // Align to the larger of the page size or the allocation granularity + int64_t alignment_factor = std::max(static_cast(4096), align_info.allocation_granularity); + // Align to the next page or alloc granularity boundary + int64_t new_external_offset = static_cast( + std::floor((external_offset + alignment_factor - 1) / alignment_factor)) * + alignment_factor; + + // padding tensor with zeros for alignment + for (int64_t index = external_offset; index != new_external_offset; ++index) { + external_stream << '0'; } - SetUpExternalInitializer(align_info, tensor_bytes_size, external_offset, external_stream, - raw_data, *output_proto, external_file_path, initializer, false); + external_offset = new_external_offset; } + + for (size_t index = 0; index != tensor_bytes_size; ++index) { + external_stream << raw_data[index]; + } + + output_proto->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL); + ONNX_NAMESPACE::StringStringEntryProto* location = output_proto->add_external_data(); + location->set_key("location"); + location->set_value(ToUTF8String(external_file_path.native())); + ONNX_NAMESPACE::StringStringEntryProto* offset = output_proto->add_external_data(); + offset->set_key("offset"); + offset->set_value(std::to_string(external_offset)); + ONNX_NAMESPACE::StringStringEntryProto* length = output_proto->add_external_data(); + length->set_key("length"); + length->set_value(std::to_string(tensor_bytes_size)); + + output_proto->set_name(initializer.name()); + output_proto->set_data_type(initializer.data_type()); + for (int i = 0; i != initializer.dims_size(); ++i) { + output_proto->add_dims(initializer.dims(i)); + } + output_proto->set_doc_string(initializer.doc_string()); + + external_offset += tensor_bytes_size; #if !defined(DISABLE_SPARSE_TENSORS) } #endif diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index ad1ec9c8dedb3..1bae63b510563 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -384,17 +384,13 @@ ModelProto Model::ToProto() const { ModelProto Model::ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_name, const std::filesystem::path& file_path, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info, - bool save_prepacked_constant_initializers, - Graph::PrePackedTensorProtoToSave& pre_packed_initializers) const { + const Graph::OffsetAlignmentInfo& align_info) const { ModelProto result(model_proto_); const auto& graph = *graph_; *(result.mutable_graph()) = graph.ToGraphProtoWithExternalInitializers(external_file_name, file_path, initializer_size_threshold, - align_info, - save_prepacked_constant_initializers, - pre_packed_initializers); + align_info); return result; } @@ -612,9 +608,7 @@ static Status SaveModelWithExternalInitializers(Model& model, const T& file_path, const std::filesystem::path& external_file_name, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info, - bool save_prepacked_constant_initializers, - Graph::PrePackedTensorProtoToSave& pre_packed_initializers) { + const Graph::OffsetAlignmentInfo& align_info) { int fd = 0; Status status = Env::Default().FileOpenWr(file_path, fd); ORT_RETURN_IF_ERROR(status); @@ -622,8 +616,7 @@ static Status SaveModelWithExternalInitializers(Model& model, ORT_TRY { status = Model::SaveWithExternalInitializers(model, fd, file_path, external_file_name, initializer_size_threshold, - align_info, save_prepacked_constant_initializers, - pre_packed_initializers); + align_info); } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { @@ -654,12 +647,9 @@ Status Model::Load(const PathString& file_path, std::shared_ptr& p_model, Status Model::SaveWithExternalInitializers(Model& model, const std::filesystem::path& file_path, const std::filesystem::path& external_file_name, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info, - bool save_prepacked_constant_initializers, - Graph::PrePackedTensorProtoToSave& pre_packed_initializers) { + const Graph::OffsetAlignmentInfo& align_info) { return SaveModelWithExternalInitializers(model, file_path, external_file_name, initializer_size_threshold, - align_info, save_prepacked_constant_initializers, - pre_packed_initializers); + align_info); } Status Model::LoadFromBytes(int count, const void* p_bytes, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) { @@ -776,9 +766,7 @@ Status Model::SaveWithExternalInitializers(Model& model, const std::filesystem::path& file_path, const std::filesystem::path& external_file_name, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info, - bool save_prepacked_constant_initializers, - Graph::PrePackedTensorProtoToSave& pre_packed_initializers) { + const Graph::OffsetAlignmentInfo& align_info) { if (fd < 0) { return Status(ONNXRUNTIME, INVALID_ARGUMENT, " is less than 0."); } @@ -787,8 +775,7 @@ Status Model::SaveWithExternalInitializers(Model& model, auto model_proto = model.ToGraphProtoWithExternalInitializers(external_file_name, file_path, initializer_size_threshold, - align_info, save_prepacked_constant_initializers, - pre_packed_initializers); + align_info); google::protobuf::io::FileOutputStream output(fd); const bool result = model_proto.SerializeToZeroCopyStream(&output) && output.Flush(); if (result) { diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 38d9044ff9d31..9bcec6f78ca08 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -191,17 +191,13 @@ class Model { ONNX_NAMESPACE::ModelProto ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_name, const std::filesystem::path& file_path, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info, - bool save_prepacked_constant_initializers, - Graph::PrePackedTensorProtoToSave& pre_packed_initializers) const; + const Graph::OffsetAlignmentInfo& align_info) const; ONNX_NAMESPACE::ModelProto ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_name, const std::filesystem::path& file_path, size_t initializer_size_threshold) const { Graph::OffsetAlignmentInfo default_align_info; - Graph::PrePackedTensorProtoToSave pre_packed_initializers; - return ToGraphProtoWithExternalInitializers(external_file_name, file_path, initializer_size_threshold, default_align_info, - false, pre_packed_initializers); + return ToGraphProtoWithExternalInitializers(external_file_name, file_path, initializer_size_threshold, default_align_info); } static common::Status Save(Model& model, const PathString& file_path); @@ -214,18 +210,14 @@ class Model { const std::filesystem::path& file_path, const std::filesystem::path& external_file_path, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info, - bool save_prepacked_constant_initializers, - Graph::PrePackedTensorProtoToSave& pre_packed_initializers); + const Graph::OffsetAlignmentInfo& align_info); static common::Status SaveWithExternalInitializers(Model& model, const std::filesystem::path& file_path, const std::filesystem::path& external_file_path, size_t initializer_size_threshold) { Graph::OffsetAlignmentInfo default_align_info; - Graph::PrePackedTensorProtoToSave pre_packed_initializers; - return SaveWithExternalInitializers(model, file_path, external_file_path, initializer_size_threshold, default_align_info, - false, pre_packed_initializers); + return SaveWithExternalInitializers(model, file_path, external_file_path, initializer_size_threshold, default_align_info); } static common::Status SaveWithExternalInitializers(Model& model, @@ -233,9 +225,7 @@ class Model { const std::filesystem::path& file_path, const std::filesystem::path& external_file_path, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info, - bool save_prepacked_constant_initializers, - Graph::PrePackedTensorProtoToSave& pre_packed_initializers); + const Graph::OffsetAlignmentInfo& align_info); static common::Status SaveWithExternalInitializers(Model& model, int fd, @@ -243,9 +233,7 @@ class Model { const std::filesystem::path& external_file_path, size_t initializer_size_threshold) { Graph::OffsetAlignmentInfo default_align_info; - Graph::PrePackedTensorProtoToSave pre_packed_initializers; - return SaveWithExternalInitializers(model, fd, file_path, external_file_path, initializer_size_threshold, default_align_info, - false, pre_packed_initializers); + return SaveWithExternalInitializers(model, fd, file_path, external_file_path, initializer_size_threshold, default_align_info); } static common::Status Load(std::istream& model_istream, ONNX_NAMESPACE::ModelProto* p_model_proto); diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 1eb9d87a9546a..ce7c209a795f7 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1450,6 +1450,29 @@ MLAS_FP16* Destination, size_t Count ); +/** + * @brief rotary embedding for one hidden state vector + * + * @tparam T: data type of input, sin, cos and output. Currently only float32/16 are supported. + * @param input: input tensor, of shape [dim] + * @param sin: sin tensor, of shape [dim/2] + * @param cos: cos tensor, of shape [dim/2] + * @param dim: dimension of rotary embedding + * @param interleaved: whether the real part and imaginary parts are interleaved + * @param output: output tensor, of shape [dim] + */ +template +void +MLASCALL +MlasRotaryEmbedOneRow( + const T* input, + const T* sin, + const T* cos, + size_t dim, + bool interleaved, + T* output +); + /** * @brief Whether current CPU supports FP16 acceleration. */ diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 232bf2261ef4c..9608644a22523 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -27,51 +27,50 @@ Module Name: * @brief Define compute types of block quantization, in order of decreasing accuracy. */ typedef enum { - CompUndef = 0, /*!< undef */ - CompFp32, /*!< input fp32, accumulator fp32 */ - CompFp16, /*!< input fp16, accumulator fp16 */ - CompBf16, /*!< input bf16, accumulator fp32 */ - CompInt8, /*!< input int8, accumulator int32 */ - - // special values that should be the first and last actual values - - CompMostAccurate = CompUndef, - CompLeastAccurate = CompInt8, -} MLAS_SQNBIT_GEMM_COMPUTE_TYPE; + SQNBIT_CompFp32, /*!< input fp32, accumulator fp32 */ + HQNBIT_CompFp16, /*!< input fp16, accumulator fp16 */ + BHQNBIT_CompBf16, /*!< input bf16, accumulator fp32 */ + SQNBIT_CompInt8, /*!< input int8, accumulator int32, input fp32 */ + HQNBIT_CompInt8, /*!< input int8, accumulator int32, input fp16 */ +} MLAS_QNBIT_GEMM_COMPUTE_TYPE; /** * @brief Data parameters for float/n-bit quantized int GEMM routine. + * + * @tparam T data type of input A */ -struct MLAS_SQNBIT_GEMM_DATA_PARAMS { - const float* A = nullptr; ///< address of A (float32 matrix) +template +struct MLAS_QNBIT_GEMM_DATA_PARAMS { + const T* A = nullptr; ///< address of A (float32/16 matrix) size_t lda = 0; ///< leading dimension of A const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values) const std::byte* PackedQuantBData = nullptr; /// address of packed quantized B data - const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block + const T* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block - const float* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block - const float* Bias = nullptr; ///< optional address of Bias, vector size N - float* C = nullptr; ///< address of result matrix + const T* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block + const T* Bias = nullptr; ///< optional address of Bias, vector size N + T* C = nullptr; ///< address of result matrix size_t ldc = 0; ///< leading dimension of C ///< optional post processing to apply to result matrix - MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; + MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; }; /** * @brief Batched GEMM: C = A * B + Bias - * A must be a float32 matrix + * A must be a float32/16 matrix * B must be a quantized and packed n-bit int matrix * - * Call MlasIsSQNBitGemmAvailable() with the same parameters to determine whether this function may be called. + * Call MlasIsQNBitGemmAvailable() with the same parameters to determine whether this function may be called. * - * Call MlasSQNBitGemmPackQuantBDataSize() with the same parameters to determine whether - * MLAS_SQNBIT_GEMM_DATA_PARAMS::QuantBData in `DataParams` should point to a buffer packed with - * MlasSQNBitGemmPackQuantBData(). + * Call MlasQNBitGemmPackQuantBDataSize() with the same parameters to determine whether + * MLAS_QNBIT_GEMM_DATA_PARAMS::QuantBData in `DataParams` should point to a buffer packed with + * MlasQNBitGemmPackQuantBData(). * - * Call MlasSQNBitGemmBatchWorkspaceSize() with the same parameters to determine whether `Workspace` should + * Call MlasQNBitGemmBatchWorkspaceSize() with the same parameters to determine whether `Workspace` should * point to an intermediate workspace buffer. * + * @tparam T data type of input A * @param[in] M row size of matrix A and C * @param[in] N column size of matrix B and C * @param[in] K column size of matrix A and row size of matrix B @@ -81,36 +80,37 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS { * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) * @param[inout] DataParams An array (size BatchN) of parameter blocks * @param[in] Workspace Address of intermediate workspace buffer. - If MlasSQNBitGemmBatchWorkspaceSize() returns a non-zero value, this must be a + If MlasQNBitGemmBatchWorkspaceSize() returns a non-zero value, this must be a buffer with at least that many bytes. Otherwise, it may be nullptr. * @param[in] ThreadPool optional thread pool to use */ +template void MLASCALL -MlasSQNBitGemmBatch( +MlasQNBitGemmBatch( size_t M, size_t N, size_t K, size_t BatchN, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, MLAS_THREADPOOL* ThreadPool = nullptr ); /** - * @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform. + * @brief Determines whether a float32/16 quantized n-bit int GEMM implementation is available on the current platform. * * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ bool MLASCALL -MlasIsSQNBitGemmAvailable( +MlasIsQNBitGemmAvailable( size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); /** @@ -126,22 +126,22 @@ MlasIsSQNBitGemmAvailable( * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ size_t MLASCALL -MlasSQNBitGemmBatchWorkspaceSize( +MlasQNBitGemmBatchWorkspaceSize( size_t M, size_t N, size_t K, size_t BatchN, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); /** * @brief Gets the size in bytes of the packed quantized B data. - * If non-zero, the quantized B data must first be packed by calling MlasSQNBitGemmPackQuantBData() with a buffer of - * this size, and then that packed quantized B data buffer must be passed to MlasSQNBitGemmBatch(). - * If zero, MlasSQNBitGemmPackQuantBData() must not be called and the quantized B data must be directly passed to - * MlasSQNBitGemmBatch(). + * If non-zero, the quantized B data must first be packed by calling MlasQNBitGemmPackQuantBData() with a buffer of + * this size, and then that packed quantized B data buffer must be passed to MlasQNBitGemmBatch(). + * If zero, MlasQNBitGemmPackQuantBData() must not be called and the quantized B data must be directly passed to + * MlasQNBitGemmBatch(). * * @param[in] N column size of matrix B and C * @param[in] K column size of matrix A and row size of matrix B @@ -150,12 +150,12 @@ MlasSQNBitGemmBatchWorkspaceSize( * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ size_t MLASCALL -MlasSQNBitGemmPackQuantBDataSize( +MlasQNBitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); /** @@ -186,12 +186,12 @@ MlasSQNBitGemmPackQuantBDataSize( * @param[in] ThreadPool thread pool to use (no parallel if nullptr) */ void MLASCALL -MlasSQNBitGemmPackQuantBData( +MlasQNBitGemmPackQuantBData( size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const void* QuantBData, void* PackedQuantBDataAndOrBlkSum, const void* QuantBScale, diff --git a/onnxruntime/core/mlas/lib/fp16_neon_common.cpp b/onnxruntime/core/mlas/lib/cast_kernel_neon.cpp similarity index 99% rename from onnxruntime/core/mlas/lib/fp16_neon_common.cpp rename to onnxruntime/core/mlas/lib/cast_kernel_neon.cpp index 29734c2277667..8a385c9c61751 100644 --- a/onnxruntime/core/mlas/lib/fp16_neon_common.cpp +++ b/onnxruntime/core/mlas/lib/cast_kernel_neon.cpp @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - fp16_neon_common.cpp + cast_kernel_neon.cpp Abstract: diff --git a/onnxruntime/core/mlas/lib/fp16_common.h b/onnxruntime/core/mlas/lib/fp16_common.h index 30b66cdb2ea78..f4c49905ebbd7 100644 --- a/onnxruntime/core/mlas/lib/fp16_common.h +++ b/onnxruntime/core/mlas/lib/fp16_common.h @@ -64,6 +64,15 @@ MLAS_FORCEINLINE MLAS_FLOAT16X4 MlasLoadFloat16x4(const _mlas_fp16_* Buffer) { return vreinterpret_f16_u16(vld1_u16(Buffer)); } +template +MLAS_FORCEINLINE +MLAS_FLOAT16X4 +MlasLoadLaneFloat16x4(const _mlas_fp16_* Buffer, MLAS_FLOAT16X4 vec) { + return vreinterpret_f16_u16( + vld1_lane_u16(Buffer, vreinterpret_u16_f16(vec), lane) + ); +} + MLAS_FORCEINLINE MLAS_FLOAT16X4 MlasLoadPartialFloat16x4(const _mlas_fp16_* Buffer, size_t len) @@ -95,6 +104,14 @@ MlasStoreFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector) vst1_u16(Buffer, vreinterpret_u16_f16(Vector)); } +template +MLAS_FORCEINLINE +void +MlasStoreLaneFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector) +{ + vst1_lane_u16(Buffer, vreinterpret_u16_f16(Vector), lane); +} + MLAS_FORCEINLINE void MlasStorePartialFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector, size_t len) diff --git a/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp b/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp new file mode 100644 index 0000000000000..69e37d2b916d1 --- /dev/null +++ b/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp @@ -0,0 +1,898 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + hqnbitgemm_kernel_neon_fp16.cpp + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for ARM NEON specific to + MLAS_QNBIT_GEMM_COMPUTE_TYPE HQNBIT_CompFp16. + +--*/ + +#include + +#include +#include +#include + +#include "fp16_common.h" +#include "qnbitgemm.h" +#include "qnbitgemm_kernel_neon.h" + +namespace sqnbitgemm_neon +{ +MLAS_FORCEINLINE void +Transpose8x8(uint8x8_t& v0, uint8x8_t& v1, uint8x8_t& v2, uint8x8_t& v3, + uint8x8_t& v4, uint8x8_t& v5, uint8x8_t& v6, uint8x8_t& v7) +{ + // v0: | B00 B10 | B20 B30 | B40 B50 | B60 B70 | B80 B90 | Ba0 Bb0 | Bc0 Bd0 | Be0 Bf0 | + // v1: | B01 B11 | B21 B31 | B41 B51 | B61 B71 | B81 B91 | Ba1 Bb1 | Bc1 Bd1 | Be1 Bf1 | + // v2: | B02 B12 | B22 B32 | B42 B52 | B62 B72 | B82 B92 | Ba2 Bb2 | Bc2 Bd2 | Be2 Bf2 | + // v3: | B03 B13 | B23 B33 | B43 B53 | B63 B73 | B83 B93 | Ba3 Bb3 | Bc3 Bd3 | Be3 Bf3 | + // v4: | B04 B14 | B24 B34 | B44 B54 | B64 B74 | B84 B94 | Ba4 Bb4 | Bc4 Bd4 | Be4 Bf4 | + // v5: | B05 B15 | B25 B35 | B45 B55 | B65 B75 | B85 B95 | Ba5 Bb5 | Bc5 Bd5 | Be5 Bf5 | + // v6: | B06 B16 | B26 B36 | B46 B56 | B66 B76 | B86 B96 | Ba6 Bb6 | Bc6 Bd6 | Be6 Bf6 | + // v7: | B07 B17 | B27 B37 | B47 B57 | B67 B77 | B87 B97 | Ba7 Bb7 | Bc7 Bd7 | Be7 Bf7 | + + uint8x8x2_t a0 = vtrn_u8(v0, v1); + uint8x8x2_t a1 = vtrn_u8(v2, v3); + uint8x8x2_t a2 = vtrn_u8(v4, v5); + uint8x8x2_t a3 = vtrn_u8(v6, v7); + + // a0[0]: | B00 B10 | B01 B11 | B40 B50 | B41 B51 | B80 B90 | B81 B91 | Bc0 Bd0 | Bc1 Bd1 | + // a0[1]: | B20 B30 | B21 B31 | B60 B70 | B61 B71 | Ba0 Bb0 | Ba1 Bb1 | Be0 Bf0 | Be1 Bf1 | + // a1[0]: | B02 B12 | B03 B13 | B42 B52 | B43 B53 | B82 B92 | B83 B93 | Bc2 Bd2 | Bc3 Bd3 | + // a1[1]: | B22 B32 | B23 B33 | B62 B72 | B63 B73 | Ba2 Bb2 | Ba3 Bb3 | Be2 Bf2 | Be3 Bf3 | + // a2[0]: | B04 B14 | B05 B15 | B44 B54 | B45 B55 | B84 B94 | B85 B95 | Bc4 Bd4 | Bc5 Bd5 | + // a2[1]: | B24 B34 | B25 B35 | B64 B74 | B65 B75 | Ba4 Bb4 | Ba5 Bb5 | Be4 Bf4 | Be5 Bf5 | + // a3[0]: | B06 B16 | B07 B17 | B46 B56 | B47 B57 | B86 B96 | B87 B97 | Bc6 Bd6 | Bc7 Bd7 | + // a3[1]: | B26 B36 | B27 B37 | B66 B76 | B67 B77 | Ba6 Bb6 | Ba7 Bb7 | Be6 Bf6 | Be7 Bf7 | + + uint16x4x2_t b0 = vtrn_u16(vreinterpret_u16_u8(a0.val[0]), vreinterpret_u16_u8(a1.val[0])); + uint16x4x2_t b1 = vtrn_u16(vreinterpret_u16_u8(a0.val[1]), vreinterpret_u16_u8(a1.val[1])); + uint16x4x2_t b2 = vtrn_u16(vreinterpret_u16_u8(a2.val[0]), vreinterpret_u16_u8(a3.val[0])); + uint16x4x2_t b3 = vtrn_u16(vreinterpret_u16_u8(a2.val[1]), vreinterpret_u16_u8(a3.val[1])); + + // b0[0]: | B00 B10 | B01 B11 | B02 B12 | B03 B13 | B80 B90 | B81 B91 | B82 B92 | B83 B93 | + // b0[1]: | B40 B50 | B41 B51 | B42 B52 | B43 B53 | Bc0 Bd0 | Bc1 Bd1 | Bc2 Bd2 | Bc3 Bd3 | + // b1[0]: | B20 B30 | B21 B31 | B22 B32 | B23 B33 | Ba0 Bb0 | Ba1 Bb1 | Ba2 Bb2 | Ba3 Bb3 | + // b1[1]: | B60 B70 | B61 B71 | B62 B72 | B63 B73 | Be0 Bf0 | Be1 Bf1 | Be2 Bf2 | Be3 Bf3 | + // b2[0]: | B04 B14 | B05 B15 | B06 B16 | B07 B17 | B84 B94 | B85 B95 | B86 B96 | B87 B97 | + // b2[1]: | B44 B54 | B45 B55 | B46 B56 | B47 B57 | Bc4 Bd4 | Bc5 Bd5 | Bc6 Bd6 | Bc7 Bd7 | + // b3[0]: | B24 B34 | B25 B35 | B26 B36 | B27 B37 | Ba4 Bb4 | Ba5 Bb5 | Ba6 Bb6 | Ba7 Bb7 | + // b3[1]: | B64 B74 | B65 B75 | B66 B76 | B67 B77 | Be4 Bf4 | Be5 Bf5 | Be6 Bf6 | Be7 Bf7 | + + uint32x2x2_t c0 = vtrn_u32(vreinterpret_u32_u16(b0.val[0]), vreinterpret_u32_u16(b2.val[0])); + uint32x2x2_t c1 = vtrn_u32(vreinterpret_u32_u16(b0.val[1]), vreinterpret_u32_u16(b2.val[1])); + uint32x2x2_t c2 = vtrn_u32(vreinterpret_u32_u16(b1.val[0]), vreinterpret_u32_u16(b3.val[0])); + uint32x2x2_t c3 = vtrn_u32(vreinterpret_u32_u16(b1.val[1]), vreinterpret_u32_u16(b3.val[1])); + + // c0[0]: | B00 B10 | B01 B11 | B02 B12 | B03 B13 | B04 B14 | B05 B15 | B06 B16 | B07 B17 | + // c0[1]: | B80 B90 | B81 B91 | B92 B92 | B83 B93 | B84 B94 | B85 B95 | B86 B96 | B87 B97 | + // c1[0]: | B40 B50 | B41 B51 | B42 B52 | B43 B53 | B44 B54 | B45 B55 | B46 B56 | B47 B57 | + // c1[1]: | Bc0 Bd0 | Bc1 Bd1 | Bc2 Bd2 | Bc3 Bd3 | Bc4 Bd4 | Bc5 Bd5 | Bc6 Bd6 | Bc7 Bd7 | + // c2[0]: | B20 B30 | B21 B31 | B22 B32 | B23 B33 | B24 B34 | B25 B35 | B26 B36 | B27 B37 | + // c2[1]: | Ba0 Bb0 | Ba1 Bb1 | Ba2 Bb2 | Ba3 Bb3 | Ba4 Bb4 | Ba5 Bb5 | Ba6 Bb6 | Ba7 Bb7 | + // c3[0]: | B60 B70 | B61 B71 | B62 B72 | B63 B73 | B64 B74 | B65 B75 | B66 B76 | B67 B77 | + // c3[1]: | Be0 Bf0 | Be1 Bf1 | Be2 Bf2 | Be3 Bf3 | Be4 Bf4 | Be5 Bf5 | Be6 Bf6 | Be7 Bf7 | + + v0 = vreinterpret_u8_u32(c0.val[0]); + v1 = vreinterpret_u8_u32(c2.val[0]); + v2 = vreinterpret_u8_u32(c1.val[0]); + v3 = vreinterpret_u8_u32(c3.val[0]); + v4 = vreinterpret_u8_u32(c0.val[1]); + v5 = vreinterpret_u8_u32(c2.val[1]); + v6 = vreinterpret_u8_u32(c1.val[1]); + v7 = vreinterpret_u8_u32(c3.val[1]); +} + +MLAS_FORCEINLINE void +Transpose4x8(float16x8_t& v0, float16x8_t& v1, float16x8_t& v2, float16x8_t& v3) +{ + // |v00|v01|v02|v03|v04|v05|v06|v07| + // |v10|v11|v12|v13|v14|v15|v16|v17| + // |v20|v21|v22|v23|v24|v25|v26|v27| + // |v30|v31|v32|v33|v34|v35|v36|v37| + // => + // |v00|v10|v20|v30|v04|v14|v24|v34| + // |v01|v11|v21|v31|v05|v15|v25|v35| + // |v02|v12|v22|v32|v06|v16|v26|v36| + // |v03|v13|v23|v33|v07|v17|v27|v37| + float16x8x2_t t01 = vtrnq_f16(v0, v1); + float16x8x2_t t23 = vtrnq_f16(v2, v3); + + v0 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0]))); + v1 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1]))); + v2 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0]))); + v3 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1]))); +} + +MLAS_FORCEINLINE void +Transpose4x4(float16x4_t& v0, float16x4_t& v1, float16x4_t& v2, float16x4_t& v3) +{ + float16x4x2_t t01 = vtrn_f16(v0, v1); + float16x4x2_t t23 = vtrn_f16(v2, v3); + + v0 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0]))); + v1 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1]))); + v2 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0]))); + v3 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1]))); +} + +void +HQ4BitGemmPackQuantBData_CompFp16( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +) +{ + MLAS_UNREFERENCED_PARAMETER(ComputeType); + constexpr size_t nbits = 4; + constexpr size_t k_blk_dim = 16; + constexpr size_t n_blk_dim = 8; + assert(BlkLen > 0 && BlkLen % k_blk_dim == 0); + + const size_t k_blk_num = MlasDivRoundup(K, k_blk_dim); + const size_t n_blk_num = MlasDivRoundup(N, n_blk_dim); + constexpr size_t k_blk_bytes = MlasQNBitBlkDataSizeInBytes(nbits, k_blk_dim); + const size_t iterations = k_blk_num * n_blk_num; // one iteration per block + const size_t ld = MlasDivRoundup(K, BlkLen) * MlasQNBitBlkDataSizeInBytes(nbits, BlkLen); + + // + // For blocks 16_K * 8_N, transpose bytes in 8x8 blocks like this: + // src B_k_n: + // | B00 B10 | B20 B30 | B40 B50 | B60 B70 | B80 B90 | Ba0 Bb0 | Bc0 Bd0 | Be0 Bf0 | + // | B01 B11 | B21 B31 | B41 B51 | B61 B71 | B81 B91 | Ba1 Bb1 | Bc1 Bd1 | Be1 Bf1 | + // | B02 B12 | B22 B32 | B42 B52 | B62 B72 | B82 B92 | Ba2 Bb2 | Bc2 Bd2 | Be2 Bf2 | + // | B03 B13 | B23 B33 | B43 B53 | B63 B73 | B83 B93 | Ba3 Bb3 | Bc3 Bd3 | Be3 Bf3 | + // | B04 B14 | B24 B34 | B44 B54 | B64 B74 | B84 B94 | Ba4 Bb4 | Bc4 Bd4 | Be4 Bf4 | + // | B05 B15 | B25 B35 | B45 B55 | B65 B75 | B85 B95 | Ba5 Bb5 | Bc5 Bd5 | Be5 Bf5 | + // | B06 B16 | B26 B36 | B46 B56 | B66 B76 | B86 B96 | Ba6 Bb6 | Bc6 Bd6 | Be6 Bf6 | + // | B07 B17 | B27 B37 | B47 B57 | B67 B77 | B87 B97 | Ba7 Bb7 | Bc7 Bd7 | Be7 Bf7 | + // => dst: + // | B00 B10 | B01 B11 | B02 B12 | B03 B13 | B04 B14 | B05 B15 | B06 B16 | B07 B17 | + // | B20 B30 | B21 B31 | B22 B32 | B23 B33 | B24 B34 | B25 B35 | B26 B36 | B27 B37 | + // | B40 B50 | B41 B51 | B42 B52 | B43 B53 | B44 B54 | B45 B55 | B46 B56 | B47 B57 | + // | B60 B70 | B61 B71 | B62 B72 | B63 B73 | B64 B74 | B65 B75 | B66 B76 | B67 B77 | + // | B80 B90 | B81 B91 | B92 B92 | B83 B93 | B84 B94 | B85 B95 | B86 B96 | B87 B97 | + // | Ba0 Bb0 | Ba1 Bb1 | Ba2 Bb2 | Ba3 Bb3 | Ba4 Bb4 | Ba5 Bb5 | Ba6 Bb6 | Ba7 Bb7 | + // | Bc0 Bd0 | Bc1 Bd1 | Bc2 Bd2 | Bc3 Bd3 | Bc4 Bd4 | Bc5 Bd5 | Bc6 Bd6 | Bc7 Bd7 | + // | Be0 Bf0 | Be1 Bf1 | Be2 Bf2 | Be3 Bf3 | Be4 Bf4 | Be5 Bf5 | Be6 Bf6 | Be7 Bf7 | + // + + // + // For blocks < 8_N: + // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | + // => + // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + // + + MlasTrySimpleParallel( + ThreadPool, iterations, + [&](ptrdiff_t tid) { + const size_t n_blk = tid / k_blk_num; + const size_t k_blk = tid % k_blk_num; + size_t n = n_blk * n_blk_dim; + const size_t src_offset = n * ld + k_blk * k_blk_bytes; + + if (n + n_blk_dim <= N) { + const size_t dst_offset = n * ld + k_blk * k_blk_bytes * n_blk_dim; + const uint8_t* src = reinterpret_cast(QuantBDataBegin) + src_offset; + uint8_t* dst = reinterpret_cast(PackedQuantBDataBegin) + dst_offset; + + uint8x8_t v0 = vld1_u8(src); + uint8x8_t v1 = vld1_u8(src + ld); + uint8x8_t v2 = vld1_u8(src + 2*ld); + uint8x8_t v3 = vld1_u8(src + 3*ld); + uint8x8_t v4 = vld1_u8(src + 4*ld); + uint8x8_t v5 = vld1_u8(src + 5*ld); + uint8x8_t v6 = vld1_u8(src + 6*ld); + uint8x8_t v7 = vld1_u8(src + 7*ld); + + Transpose8x8(v0, v1, v2, v3, v4, v5, v6, v7); + + vst1_u8(dst, v0); + vst1_u8(dst + 8, v1); + vst1_u8(dst + 16, v2); + vst1_u8(dst + 24, v3); + vst1_u8(dst + 32, v4); + vst1_u8(dst + 40, v5); + vst1_u8(dst + 48, v6); + vst1_u8(dst + 56, v7); + } else { + const uint8_t* src = reinterpret_cast(QuantBDataBegin) + src_offset; + uint8_t* dst = reinterpret_cast(PackedQuantBDataBegin) + src_offset; + + for (; n < N; ++n, src += ld, dst += ld) { + uint8x8_t v0 = vld1_u8(src); + uint8x8_t v_even = vand_u8(v0, vdup_n_u8(0x0F)); + uint8x8_t v_odd = vshr_n_u8(v0, 4); + uint8x8x2_t v1 = vzip_u8(v_even, v_odd); + uint8x8_t v2 = vorr_u8(v1.val[0], vshl_n_u8(v1.val[1], 4)); + vst1_u8(dst, v2); + } + } + } + ); +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 8 && K == 16), void> +HQ4BitBlkDequantBKernel( + const std::uint8_t* src_ptr, + const float16x8_t& scale, + const float16x8_t& neg_scaled_zp, + _mlas_fp16_* dst_ptr +) { + const uint8x8_t low_mask = vdup_n_u8(0x0F); + + uint8x8_t b01 = vld1_u8(src_ptr); + uint8x8_t b23 = vld1_u8(src_ptr + 8); + uint8x8_t b45 = vld1_u8(src_ptr + 16); + uint8x8_t b67 = vld1_u8(src_ptr + 24); + uint8x8_t b89 = vld1_u8(src_ptr + 32); + uint8x8_t bab = vld1_u8(src_ptr + 40); + uint8x8_t bcd = vld1_u8(src_ptr + 48); + uint8x8_t bef = vld1_u8(src_ptr + 56); + + float16x8_t b0 = vcvtq_f16_u16(vshll_n_u8(vand_u8(b01, low_mask), 0)); + float16x8_t b1 = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(b01, 4), 0)); + float16x8_t b2 = vcvtq_f16_u16(vshll_n_u8(vand_u8(b23, low_mask), 0)); + float16x8_t b3 = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(b23, 4), 0)); + float16x8_t b4 = vcvtq_f16_u16(vshll_n_u8(vand_u8(b45, low_mask), 0)); + float16x8_t b5 = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(b45, 4), 0)); + float16x8_t b6 = vcvtq_f16_u16(vshll_n_u8(vand_u8(b67, low_mask), 0)); + float16x8_t b7 = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(b67, 4), 0)); + float16x8_t b8 = vcvtq_f16_u16(vshll_n_u8(vand_u8(b89, low_mask), 0)); + float16x8_t b9 = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(b89, 4), 0)); + float16x8_t ba = vcvtq_f16_u16(vshll_n_u8(vand_u8(bab, low_mask), 0)); + float16x8_t bb = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(bab, 4), 0)); + float16x8_t bc = vcvtq_f16_u16(vshll_n_u8(vand_u8(bcd, low_mask), 0)); + float16x8_t bd = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(bcd, 4), 0)); + float16x8_t be = vcvtq_f16_u16(vshll_n_u8(vand_u8(bef, low_mask), 0)); + float16x8_t bf = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(bef, 4), 0)); + + float16x8_t c0 = vfmaq_f16(neg_scaled_zp, b0, scale); + float16x8_t c1 = vfmaq_f16(neg_scaled_zp, b1, scale); + float16x8_t c2 = vfmaq_f16(neg_scaled_zp, b2, scale); + float16x8_t c3 = vfmaq_f16(neg_scaled_zp, b3, scale); + float16x8_t c4 = vfmaq_f16(neg_scaled_zp, b4, scale); + float16x8_t c5 = vfmaq_f16(neg_scaled_zp, b5, scale); + float16x8_t c6 = vfmaq_f16(neg_scaled_zp, b6, scale); + float16x8_t c7 = vfmaq_f16(neg_scaled_zp, b7, scale); + float16x8_t c8 = vfmaq_f16(neg_scaled_zp, b8, scale); + float16x8_t c9 = vfmaq_f16(neg_scaled_zp, b9, scale); + float16x8_t ca = vfmaq_f16(neg_scaled_zp, ba, scale); + float16x8_t cb = vfmaq_f16(neg_scaled_zp, bb, scale); + float16x8_t cc = vfmaq_f16(neg_scaled_zp, bc, scale); + float16x8_t cd = vfmaq_f16(neg_scaled_zp, bd, scale); + float16x8_t ce = vfmaq_f16(neg_scaled_zp, be, scale); + float16x8_t cf = vfmaq_f16(neg_scaled_zp, bf, scale); + + MlasStoreFloat16x8(dst_ptr, c0); + MlasStoreFloat16x8(dst_ptr + 8, c1); + MlasStoreFloat16x8(dst_ptr + 16, c2); + MlasStoreFloat16x8(dst_ptr + 24, c3); + MlasStoreFloat16x8(dst_ptr + 32, c4); + MlasStoreFloat16x8(dst_ptr + 40, c5); + MlasStoreFloat16x8(dst_ptr + 48, c6); + MlasStoreFloat16x8(dst_ptr + 56, c7); + MlasStoreFloat16x8(dst_ptr + 64, c8); + MlasStoreFloat16x8(dst_ptr + 72, c9); + MlasStoreFloat16x8(dst_ptr + 80, ca); + MlasStoreFloat16x8(dst_ptr + 88, cb); + MlasStoreFloat16x8(dst_ptr + 96, cc); + MlasStoreFloat16x8(dst_ptr + 104, cd); + MlasStoreFloat16x8(dst_ptr + 112, ce); + MlasStoreFloat16x8(dst_ptr + 120, cf); +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 1 && K == 16), void> +HQ4BitBlkDequantBKernel( + const std::uint8_t* src_ptr, + const float16x8_t& scale, + const float16x8_t& neg_scaled_zp, + _mlas_fp16_* dst_ptr +) { + const uint8x8_t low_mask = vdup_n_u8(0x0F); + + uint8x8_t v0 = vld1_u8(src_ptr); + + float16x8_t f_low = vcvtq_f16_u16(vshll_n_u8(vand_u8(v0, low_mask), 0)); + float16x8_t f_high = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(v0, 4), 0)); + + float16x8_t c0 = vfmaq_f16(neg_scaled_zp, f_low, scale); + float16x8_t c1 = vfmaq_f16(neg_scaled_zp, f_high, scale); + + MlasStoreFloat16x8(dst_ptr, c0); + MlasStoreFloat16x8(dst_ptr + 8, c1); +} + +void +HQ4BitBlkDequantBForHgemm_CompFp16( + size_t BlkLen, + MLAS_FP16* FpData, + const std::byte* QuantBData, + const MLAS_FP16* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t K, + size_t BlockCountK +) { + MLAS_UNREFERENCED_PARAMETER(K); + constexpr size_t nbits = 4; + constexpr size_t kk_blk_dim = 16; + constexpr size_t n_blk_dim = 8; + assert(BlkLen > 0 && BlkLen % kk_blk_dim == 0); + + const size_t kk_blk_num = BlockCountK * BlkLen / kk_blk_dim; + constexpr size_t kk_blk_bytes = MlasQNBitBlkDataSizeInBytes(nbits, kk_blk_dim); + const size_t kk_n_src_bytes = kk_blk_bytes * n_blk_dim; + const size_t kk_n_dst_size = kk_blk_dim * n_blk_dim; + const size_t ld_blk_src = kk_blk_num * kk_n_src_bytes; + const size_t ld_blk_dst = BlkLen * BlockCountK * n_blk_dim; + const size_t ld_blk_scale = BlockCountK * n_blk_dim; + const size_t ld_zp = (BlockCountK + 1) / 2; + const size_t ld_blk_zp = ld_zp * n_blk_dim; + const float16x8_t zp_mid_point_vec = MlasBroadcastFloat16x8(MLAS_FP16(8.0f).val); + const bool has_zp = QuantBZeroPoint != nullptr; + + size_t n = 0; + for (; n + n_blk_dim <= CountN; n += n_blk_dim) { + const auto* scales_ptr = reinterpret_cast(QuantBScale); + const std::uint8_t* zero_points_ptr = reinterpret_cast(QuantBZeroPoint); + const std::uint8_t* src_ptr = reinterpret_cast(QuantBData); + auto* dst_ptr = reinterpret_cast<_mlas_fp16_*>(FpData); + + for (size_t k_blk_i = 0; k_blk_i < BlockCountK; ++k_blk_i) { + // prepare scales and zero_points for the block + _mlas_fp16_ scales[n_blk_dim]; + uint16_t zero_points[n_blk_dim]; + float16x8_t scale_vec; + float16x8_t neg_scaled_zp_vec; + + UnrolledLoop([&](int nn){ + scales[nn] = scales_ptr[nn * BlockCountK]; + }); + scale_vec = MlasLoadFloat16x8(scales); + + if (has_zp) { + UnrolledLoop([&](int nn){ + uint8_t zp = zero_points_ptr[nn * ld_zp]; + zp = (k_blk_i & 1) ? (zp >> 4) : (zp & 0x0F); + zero_points[nn] = static_cast(zp); + }); + uint16x8_t zp_u16_vec = vld1q_u16(zero_points); + neg_scaled_zp_vec = vcvtq_f16_u16(zp_u16_vec); + } else { + neg_scaled_zp_vec = zp_mid_point_vec; + } + neg_scaled_zp_vec = vnegq_f16(vmulq_f16(scale_vec, neg_scaled_zp_vec)); + + for (size_t kk = 0; kk < BlkLen; kk += kk_blk_dim) { + HQ4BitBlkDequantBKernel<8, 16>(src_ptr, scale_vec, neg_scaled_zp_vec, dst_ptr); + + src_ptr += kk_n_src_bytes; + dst_ptr += kk_n_dst_size; + } + + ++scales_ptr; + if (has_zp) { + zero_points_ptr += k_blk_i & 1; + } + } + + QuantBData += ld_blk_src; + FpData += ld_blk_dst; + QuantBScale += ld_blk_scale; + QuantBZeroPoint = has_zp ? QuantBZeroPoint + ld_blk_zp : nullptr; + } + + // remaining N + for (; n < CountN; ++n) { + const auto* scales_ptr = reinterpret_cast(QuantBScale); + const std::uint8_t* zero_points_ptr = reinterpret_cast(QuantBZeroPoint); + for (size_t k_blk_i = 0; k_blk_i < BlockCountK; ++k_blk_i) { + const auto scale = scales_ptr[0]; + float16x8_t scale_vec = MlasBroadcastFloat16x8(scale); + float16x8_t neg_scaled_zp_vec; + + if (has_zp) { + uint8_t zero_point = static_cast(zero_points_ptr[0]); + zero_point = (k_blk_i & 1) ? (zero_point >> 4) : (zero_point & 0x0F); + uint16x8_t zp_u16_vec = vdupq_n_u16(static_cast(zero_point)); + neg_scaled_zp_vec = vcvtq_f16_u16(zp_u16_vec); + } else { + neg_scaled_zp_vec = zp_mid_point_vec; + } + neg_scaled_zp_vec = vnegq_f16(vmulq_f16(scale_vec, neg_scaled_zp_vec)); + + for (size_t kk = 0; kk < BlkLen; kk += kk_blk_dim) { + HQ4BitBlkDequantBKernel<1, 16>( + reinterpret_cast(QuantBData), scale_vec, neg_scaled_zp_vec, + reinterpret_cast<_mlas_fp16_*>(FpData) + ); + + QuantBData += kk_blk_bytes; + FpData += kk_blk_dim; + } + + ++scales_ptr; + if (has_zp) { + zero_points_ptr += k_blk_i & 1; + } + } + + QuantBScale += BlockCountK; + if (has_zp) { + QuantBZeroPoint += ld_zp; + } + } +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 8), float16x8_t> +PrepareAccumulator(const _mlas_fp16_* Bias) +{ + if (Bias) { + return MlasLoadFloat16x8(Bias); + } else { + return MlasZeroFloat16x8(); + } +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 4), float16x4_t> +PrepareAccumulator(const _mlas_fp16_* Bias) +{ + if (Bias) { + return MlasLoadFloat16x4(Bias); + } else { + return MlasZeroFloat16x4(); + } +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<((N == 2 || N == 1)), float16x4_t> +PrepareAccumulator(const _mlas_fp16_* Bias) +{ + float16x4_t v = MlasZeroFloat16x4(); + + if (Bias) { + v = MlasLoadLaneFloat16x4<0>(Bias, v); + if constexpr (N == 2) { + v = MlasLoadLaneFloat16x4<1>(Bias + 1, v); + } + return v; + } else { + return v; + } +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 8 && M == 1 && K == 8), float16x8_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x8_t accumulator +) { + MLAS_UNREFERENCED_PARAMETER(ldb); + float16x8_t a0 = MlasLoadFloat16x8(A); + float16x8_t b0 = MlasLoadFloat16x8(B); + float16x8_t b1 = MlasLoadFloat16x8(B + 8); + float16x8_t b2 = MlasLoadFloat16x8(B + 16); + float16x8_t b3 = MlasLoadFloat16x8(B + 24); + float16x8_t b4 = MlasLoadFloat16x8(B + 32); + float16x8_t b5 = MlasLoadFloat16x8(B + 40); + float16x8_t b6 = MlasLoadFloat16x8(B + 48); + float16x8_t b7 = MlasLoadFloat16x8(B + 56); + + // This version uses less instructions, but introduces dependency path between instructions. + // Must pair it with loop unrolling to alleviate dependency path penalty. + float16x8_t c0 = vfmaq_laneq_f16(accumulator, b0, a0, 0); + c0 = vfmaq_laneq_f16(c0, b1, a0, 1); + c0 = vfmaq_laneq_f16(c0, b2, a0, 2); + c0 = vfmaq_laneq_f16(c0, b3, a0, 3); + c0 = vfmaq_laneq_f16(c0, b4, a0, 4); + c0 = vfmaq_laneq_f16(c0, b5, a0, 5); + c0 = vfmaq_laneq_f16(c0, b6, a0, 6); + c0 = vfmaq_laneq_f16(c0, b7, a0, 7); + + return c0; +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 8 && M == 1 && K == 4), float16x8_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x8_t accumulator +) { + MLAS_UNREFERENCED_PARAMETER(ldb); + float16x4_t a0 = MlasLoadFloat16x4(A); + float16x8_t b0 = MlasLoadFloat16x8(B); + float16x8_t b1 = MlasLoadFloat16x8(B + 8); + float16x8_t b2 = MlasLoadFloat16x8(B + 16); + float16x8_t b3 = MlasLoadFloat16x8(B + 24); + + float16x8_t c0 = vfmaq_lane_f16(accumulator, b0, a0, 0); + c0 = vfmaq_lane_f16(c0, b1, a0, 1); + c0 = vfmaq_lane_f16(c0, b2, a0, 2); + c0 = vfmaq_lane_f16(c0, b3, a0, 3); + + return c0; +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 8 && M == 1 && (K == 2 || K == 1)), float16x8_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x8_t accumulator +) { + MLAS_UNREFERENCED_PARAMETER(ldb); + float16x4_t a0 = MlasZeroFloat16x4(); + a0 = MlasLoadLaneFloat16x4<0>(A, a0); + if constexpr (K == 2) a0 = MlasLoadLaneFloat16x4<1>(A + 1, a0); + float16x8_t b0 = MlasLoadFloat16x8(B), b1; + if constexpr (K == 2) b1 = MlasLoadFloat16x8(B + 8); + + float16x8_t c0 = vfmaq_lane_f16(accumulator, b0, a0, 0), c01; + if constexpr (K == 2) c01 = vfmaq_lane_f16(c0, b1, a0, 1); + + if constexpr (K == 1) + return c0; + else + return c01; +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<((N > 0 && N <= 4) && M == 1 && K == 8), float16x4_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x4_t accumulator +) { + float16x8_t a0 = MlasLoadFloat16x8(A); + + float16x8_t b0, b1, b2, b3; + b0 = MlasLoadFloat16x8(B); + if constexpr (N > 1) b1 = MlasLoadFloat16x8(B + ldb); + if constexpr (N > 2) b2 = MlasLoadFloat16x8(B + ldb * 2); + if constexpr (N > 3) b3 = MlasLoadFloat16x8(B + ldb * 3); + + float16x8_t c00, c01, c02, c03; + c00 = vmulq_f16(b0, a0); + if constexpr (N > 1) + c01 = vmulq_f16(b1, a0); + else + c01 = MlasZeroFloat16x8(); + if constexpr (N > 2) + c02 = vmulq_f16(b2, a0); + else + c02 = MlasZeroFloat16x8(); + if constexpr (N > 3) + c03 = vmulq_f16(b3, a0); + else + c03 = MlasZeroFloat16x8(); + + Transpose4x8(c00, c01, c02, c03); + + float16x8_t c_low_high = vaddq_f16(vaddq_f16(c00, c01), vaddq_f16(c02, c03)); + float16x4_t c_low = vget_low_f16(c_low_high); + float16x4_t c_high = vget_high_f16(c_low_high); + float16x4_t c = vadd_f16(c_low, c_high); + + return vadd_f16(c, accumulator); +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<((N > 0 && N <= 4) && M == 1 && (K == 4)), float16x4_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x4_t accumulator +) { + float16x4_t a0 = MlasLoadFloat16x4(A); + float16x4_t b0, b1, b2, b3; + b0 = MlasLoadFloat16x4(B); + if constexpr (N > 1) b1 = MlasLoadFloat16x4(B + ldb); + if constexpr (N > 2) b2 = MlasLoadFloat16x4(B + ldb * 2); + if constexpr (N > 3) b3 = MlasLoadFloat16x4(B + ldb * 3); + + float16x4_t c00, c01, c02, c03; + c00 = vmul_f16(b0, a0); + if constexpr (N > 1) + c01 = vmul_f16(b1, a0); + else + c01 = MlasZeroFloat16x4(); + if constexpr (N > 2) + c02 = vmul_f16(b2, a0); + else + c02 = MlasZeroFloat16x4(); + if constexpr (N > 3) + c03 = vmul_f16(b3, a0); + else + c03 = MlasZeroFloat16x4(); + + Transpose4x4(c00, c01, c02, c03); + + float16x4_t c = vadd_f16(vadd_f16(c00, c01), vadd_f16(c02, c03)); + return vadd_f16(c, accumulator); +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<((N > 0 && N <= 4) && M == 1 && (K > 0 && K < 4)), float16x4_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x4_t accumulator +) { + float16x4_t a0 = MlasZeroFloat16x4(); + float16x4_t b0 = MlasZeroFloat16x4(), b1, b2, b3; + if constexpr (N > 1) b1 = MlasZeroFloat16x4(); + if constexpr (N > 2) b2 = MlasZeroFloat16x4(); + if constexpr (N > 3) b3 = MlasZeroFloat16x4(); + + a0 = MlasLoadLaneFloat16x4<0>(A, a0); + b0 = MlasLoadLaneFloat16x4<0>(B, b0); + if constexpr (N > 1) b1 = MlasLoadLaneFloat16x4<0>(B + ldb, b1); + if constexpr (N > 2) b2 = MlasLoadLaneFloat16x4<0>(B + ldb * 2, b2); + if constexpr (N > 3) b3 = MlasLoadLaneFloat16x4<0>(B + ldb * 3, b3); + + if constexpr (K >= 2) { + a0 = MlasLoadLaneFloat16x4<1>(A + 1, a0); + b0 = MlasLoadLaneFloat16x4<1>(B + 1, b0); + if constexpr (N > 1) b1 = MlasLoadLaneFloat16x4<1>(B + 1 + ldb, b1); + if constexpr (N > 2) b2 = MlasLoadLaneFloat16x4<1>(B + 1 + ldb * 2, b2); + if constexpr (N > 3) b3 = MlasLoadLaneFloat16x4<1>(B + 1 + ldb * 3, b3); + } + + if constexpr (K >= 3) { + a0 = MlasLoadLaneFloat16x4<2>(A + 2, a0); + b0 = MlasLoadLaneFloat16x4<2>(B + 2, b0); + if constexpr (N > 1) b1 = MlasLoadLaneFloat16x4<2>(B + 2 + ldb, b1); + if constexpr (N > 2) b2 = MlasLoadLaneFloat16x4<2>(B + 2 + ldb * 2, b2); + if constexpr (N > 3) b3 = MlasLoadLaneFloat16x4<2>(B + 2 + ldb * 3, b3); + } + + float16x4_t c00, c01, c02, c03; + c00 = vmul_f16(b0, a0); + if constexpr (N > 1) + c01 = vmul_f16(b1, a0); + else + c01 = MlasZeroFloat16x4(); + if constexpr (N > 2) + c02 = vmul_f16(b2, a0); + else + c02 = MlasZeroFloat16x4(); + if constexpr (N > 3) + c03 = vmul_f16(b3, a0); + else + c03 = MlasZeroFloat16x4(); + + Transpose4x4(c00, c01, c02, c03); + + float16x4_t c = vadd_f16(vadd_f16(c00, c01), vadd_f16(c02, c03)); + return vadd_f16(c, accumulator); +} + +template +typename std::enable_if_t<((CountN >= 1 && CountN <= 16 && ((CountN - 1) & CountN) == 0) && (CountM == 1 || CountM == 2)), void> +HQ4BitGemmKernel_CompFp16_Kernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const _mlas_fp16_* Bias, + _mlas_fp16_* C, + size_t K, + size_t lda, + size_t ldb, + size_t ldc +) { + using RegisterType = typename std::conditional_t<(CountN < 8), float16x4_t, float16x8_t>; + + RegisterType accu00, accu01, accu10, accu11; + constexpr size_t b_step = CountN >= 8 ? 8 : 1; + constexpr size_t N = CountN == 16 ? 8 : CountN; + + if constexpr (CountM == 2) { + accu00 = accu10 = PrepareAccumulator(Bias); + } else { + accu00 = PrepareAccumulator(Bias); + } + if constexpr (CountN == 16) { + if constexpr (CountM == 2) { + accu01 = accu11 = PrepareAccumulator(Bias ? Bias + 8 : nullptr); + } else { + accu01 = PrepareAccumulator(Bias ? Bias + 8 : nullptr); + } + } + + size_t k = 0; + for (; k + 8 <= K; k += 8, A += 8, B += b_step * 8) { + accu00 = HQ4BitGemmMicroKernel(A, B, ldb, accu00); + if constexpr (CountN == 16) { + accu01 = HQ4BitGemmMicroKernel(A, B + b_step * ldb, ldb, accu01); + } + if constexpr (CountM == 2) { + accu10 = HQ4BitGemmMicroKernel(A + lda, B, ldb, accu10); + if constexpr (CountN == 16) { + accu11 = HQ4BitGemmMicroKernel(A + lda, B + b_step * ldb, ldb, accu11); + } + } + } + + if (K & 4) { + accu00 = HQ4BitGemmMicroKernel(A, B, ldb, accu00); + if constexpr (CountN == 16) { + accu01 = HQ4BitGemmMicroKernel(A, B + b_step * ldb, ldb, accu01); + } + if constexpr (CountM == 2) { + accu10 = HQ4BitGemmMicroKernel(A + lda, B, ldb, accu10); + if constexpr (CountN == 16) { + accu11 = HQ4BitGemmMicroKernel(A + lda, B + b_step * ldb, ldb, accu11); + } + } + k += 4, A += 4, B += b_step * 4; + } + + if (K & 2) { + accu00 = HQ4BitGemmMicroKernel(A, B, ldb, accu00); + if constexpr (CountN == 16) { + accu01 = HQ4BitGemmMicroKernel(A, B + b_step * ldb, ldb, accu01); + } + if constexpr (CountM == 2) { + accu10 = HQ4BitGemmMicroKernel(A + lda, B, ldb, accu10); + if constexpr (CountN == 16) { + accu11 = HQ4BitGemmMicroKernel(A + lda, B + b_step * ldb, ldb, accu11); + } + } + k += 2, A += 2, B += b_step * 2; + } + + if (k < K) { + accu00 = HQ4BitGemmMicroKernel(A, B, ldb, accu00); + if constexpr (CountN == 16) { + accu01 = HQ4BitGemmMicroKernel(A, B + b_step * ldb, ldb, accu01); + } + if constexpr (CountM == 2) { + accu10 = HQ4BitGemmMicroKernel(A + lda, B, ldb, accu10); + if constexpr (CountN == 16) { + accu11 = HQ4BitGemmMicroKernel(A + lda, B + b_step * ldb, ldb, accu11); + } + } + } + + if constexpr (CountN >= 8) { + MlasStoreFloat16x8(C, accu00); + if constexpr (CountN == 16) { + MlasStoreFloat16x8(C + 8, accu01); + } + } else if constexpr (CountN == 4) { + MlasStoreFloat16x4(C, accu00); + } else { + MlasStoreLaneFloat16x4<0>(C, accu00); + if constexpr (CountN == 2) { + MlasStoreLaneFloat16x4<1>(C + 1, accu00); + } + } + + if constexpr (CountM == 2) { + if constexpr (CountN >= 8) { + MlasStoreFloat16x8(C + ldc, accu10); + if constexpr (CountN == 16) { + MlasStoreFloat16x8(C + ldc + 8, accu11); + } + } else if constexpr (CountN == 4) { + MlasStoreFloat16x4(C + ldc, accu10); + } else { + MlasStoreLaneFloat16x4<0>(C + ldc, accu10); + if constexpr (CountN == 2) { + MlasStoreLaneFloat16x4<1>(C + ldc + 1, accu10); + } + } + } +} + +void +HQ4BitGemmKernel_CompFp16( + const MLAS_FP16* A, + const MLAS_FP16* B, + const MLAS_FP16* Bias, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t K, + size_t lda, + size_t ldb, + size_t ldc +) { + assert(CountM <= 2); + + // 2M_16N is the balance between loop unrolling and register spill. + // More unroll will trigger register spill. + // Less unroll will increase micro kernel dependency path penalty. + // TODO: dequant 16N as continuous segments. Current version dequants 8N. + const auto* a = reinterpret_cast(A); + const auto* b = reinterpret_cast(B); + const auto* bias = reinterpret_cast(Bias); + auto* c = reinterpret_cast<_mlas_fp16_*>(C); + + for (; CountN >= 16; CountN -= 16) { + if (CountM == 2) { + HQ4BitGemmKernel_CompFp16_Kernel<16, 2>(a, b, bias, c, K, lda, ldb, ldc); + } else { + HQ4BitGemmKernel_CompFp16_Kernel<16, 1>(a, b, bias, c, K, lda, ldb, ldc); + } + b += 16 * ldb, c += 16; + if (bias) bias += 16; + } + + if (CountN & 8) { + if (CountM == 2) { + HQ4BitGemmKernel_CompFp16_Kernel<8, 2>(a, b, bias, c, K, lda, ldb, ldc); + } else { + HQ4BitGemmKernel_CompFp16_Kernel<8, 1>(a, b, bias, c, K, lda, ldb, ldc); + } + b += 8 * ldb, c += 8; + if (bias) bias += 8; + } + + if (CountN & 4) { + if (CountM == 2) { + HQ4BitGemmKernel_CompFp16_Kernel<4, 2>(a, b, bias, c, K, lda, ldb, ldc); + } else { + HQ4BitGemmKernel_CompFp16_Kernel<4, 1>(a, b, bias, c, K, lda, ldb, ldc); + } + b += 4 * ldb, c += 4; + if (bias) bias += 4; + } + + if (CountN & 2) { + if (CountM == 2) { + HQ4BitGemmKernel_CompFp16_Kernel<2, 2>(a, b, bias, c, K, lda, ldb, ldc); + } else { + HQ4BitGemmKernel_CompFp16_Kernel<2, 1>(a, b, bias, c, K, lda, ldb, ldc); + } + b += 2 * ldb, c += 2; + if (bias) bias += 2; + } + + if (CountN & 1) { + if (CountM == 2) { + HQ4BitGemmKernel_CompFp16_Kernel<1, 2>(a, b, bias, c, K, lda, ldb, ldc); + } else { + HQ4BitGemmKernel_CompFp16_Kernel<1, 1>(a, b, bias, c, K, lda, ldb, ldc); + } + } +} +} // namespace sqnbitgemm_neon diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 14b861f01f6f7..785be8ec4e3c8 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -359,6 +359,22 @@ size_t bool ZeroMode ); +#ifdef FORCE_GENERIC_ALGORITHMS +typedef +size_t +(MLASCALL MLAS_GEMM_FLOAT_KERNEL_GENERIC)( + const float* A, + const float* B, + float* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + float alpha + ); +#endif + #else #if defined(__aarch64__) && defined(__linux__) @@ -756,6 +772,10 @@ extern "C" { #if defined(MLAS_TARGET_AMD64_IX86) MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelSse; MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelAvx; +#ifdef FORCE_GENERIC_ALGORITHMS + MLAS_GEMM_FLOAT_KERNEL_GENERIC MlasSgemmKernelZero; + MLAS_GEMM_FLOAT_KERNEL_GENERIC MlasSgemmKernelAdd; +#endif #if defined(MLAS_TARGET_AMD64) MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelFma3; MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelAvx512F; @@ -1046,17 +1066,24 @@ extern const MLAS_FPQ4GEMM_DISPATCH MlasFpQ4GemmDispatchAvx512; // Float/quantized n-bit integer matrix/matrix multiply dispatch structure. // -struct MLAS_SQNBIT_GEMM_DISPATCH; +struct MLAS_QNBIT_GEMM_DISPATCH; -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon; +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon; -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2; +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2; -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni; +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni; -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512; +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512; + +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; + +// +// Rotary embedding dispatch structure. +// +struct MLAS_ROPE_DISPATCH; +extern const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon; -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; // // Quantized depthwise convolution kernels. @@ -1213,13 +1240,15 @@ struct MLAS_PLATFORM { const MLAS_FPQ4GEMM_DISPATCH* FpQ4GemmDispatch{nullptr}; const MLAS_Q8Q4GEMM_DISPATCH* Q8Q4GemmDispatch{nullptr}; - const MLAS_SQNBIT_GEMM_DISPATCH* SQNBitGemmDispatch{nullptr}; + const MLAS_QNBIT_GEMM_DISPATCH* QNBitGemmDispatch{nullptr}; MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel; MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel; - MLAS_QUANTIZE_SOFTMAX_I8_KERNEL *QuantizeSoftmaxI8Kernel; - MLAS_QUANTIZE_SOFTMAX_U8_KERNEL *QuantizeSoftmaxU8Kernel; + const MLAS_ROPE_DISPATCH* RopeDispatch{nullptr}; + + MLAS_QUANTIZE_SOFTMAX_I8_KERNEL* QuantizeSoftmaxI8Kernel; + MLAS_QUANTIZE_SOFTMAX_U8_KERNEL* QuantizeSoftmaxU8Kernel; }; inline diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index c0ae041701646..374a9da77d2c4 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -287,7 +287,11 @@ Return Value: this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel; this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel; #ifndef __APPLE__ +#ifndef FORCE_GENERIC_ALGORITHMS this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelSse; +#else // FORCE_GENERIC_ALGORITHMS + this->CastF16ToF32Kernel = nullptr; +#endif // FORCE_GENERIC_ALGORITHMS #endif // __APPLE__ this->NchwcBlockSize = 8; @@ -309,8 +313,11 @@ Return Value: // // Check if the processor supports SSE 4.1 instructions. // - +#ifndef FORCE_GENERIC_ALGORITHMS if ((Cpuid1[2] & 0x80000) != 0) { +#else // FORCE_GENERIC_ALGORITHMS + if (false) { +#endif // FORCE_GENERIC_ALGORITHMS this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchSse41; } @@ -320,7 +327,11 @@ Return Value: // Check if the processor supports the AVX and OSXSAVE features. // +#ifndef FORCE_GENERIC_ALGORITHMS if ((Cpuid1[2] & 0x18000000) == 0x18000000) { +#else // FORCE_GENERIC_ALGORITHMS + if (false) { +#endif // FORCE_GENERIC_ALGORITHMS // // Check if the operating system supports saving SSE and AVX states. @@ -388,7 +399,7 @@ Return Value: this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernelAvx2; this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernelAvx2; this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelFma3; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2; + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2; this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx2; this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2; @@ -420,7 +431,7 @@ Return Value: this->GemmU8S8Kernel = MlasGemmU8S8KernelAvxVnni; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvxVnni; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvxVnni; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2vnni; + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2vnni; } #if !defined(ORT_MINIMAL_BUILD) @@ -461,7 +472,7 @@ Return Value: this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx512Core; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Core; this->FpQ4GemmDispatch = &MlasFpQ4GemmDispatchAvx512; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512; + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512; this->QuantizeSoftmaxI8Kernel = MlasQuantizeSoftmaxI8KernelAvx512; this->QuantizeSoftmaxU8Kernel = MlasQuantizeSoftmaxU8KernelAvx512; @@ -476,7 +487,7 @@ Return Value: this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Vnni; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Vnni; this->Q8Q4GemmDispatch = &MlasQ8Q4GemmDispatchAvx512vnni; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512vnni; + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512vnni; } } } @@ -536,6 +547,8 @@ Return Value: this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchNeon; this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon; this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon; + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; + this->RopeDispatch = &MlasRopeDispatchNeon; // // Check if the processor supports ASIMD dot product instructions. @@ -565,9 +578,6 @@ Return Value: this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchSdot; this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchDot; this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot; - - // MlasSQNBitGemmDispatchNeon has a dependency on dot product instructions - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; } #if defined(__linux__) diff --git a/onnxruntime/core/mlas/lib/qgemm.h b/onnxruntime/core/mlas/lib/qgemm.h index 1ef5b5f7411f0..bcd878efa681b 100644 --- a/onnxruntime/core/mlas/lib/qgemm.h +++ b/onnxruntime/core/mlas/lib/qgemm.h @@ -867,7 +867,8 @@ MlasGemmQuantGetDispatch( { const MLAS_GEMM_QUANT_DISPATCH* GemmQuantDispatch = &MlasGemmQuantDispatchDefault; -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_LARCH64) +#if !defined(FORCE_GENERIC_ALGORITHMS) +#if defined(MLAS_TARGET_AMD64_IX86) if (AIsSigned) { GemmQuantDispatch = BIsSigned ? GetMlasPlatform().GemmS8S8Dispatch : GetMlasPlatform().GemmS8U8Dispatch; @@ -895,7 +896,13 @@ MlasGemmQuantGetDispatch( if (GetMlasPlatform().GemmU8X8Dispatch == &MlasGemm8X8DispatchPOWER10) { GemmQuantDispatch = GetMlasPlatform().GemmU8X8Dispatch; } +#elif defined(MLAS_TARGET_LARCH64) + if (!AIsSigned) { + GemmQuantDispatch = + BIsSigned ? GetMlasPlatform().GemmU8S8Dispatch : GetMlasPlatform().GemmU8U8Dispatch; + } #endif +#endif // !defined(FORCE_GENERIC_ALGORITHMS) if (nullptr == GemmQuantDispatch) { std::stringstream ss; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp similarity index 62% rename from onnxruntime/core/mlas/lib/sqnbitgemm.cpp rename to onnxruntime/core/mlas/lib/qnbitgemm.cpp index a45494ef2e04f..f064a8e1d6a78 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -6,16 +6,16 @@ Licensed under the MIT License. Module Name: - sqnbitgemm.cpp + qnbitgemm.cpp Abstract: This module implements the float/quantized n-bit integer matrix - multiplication hardware agnostic entrypoint, MlasSQNBitGemmBatch, + multiplication hardware agnostic entrypoint, MlasQNBitGemmBatch, as well as some SQNBitGemm-related query functions. --*/ -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_q8_block.h" #include @@ -23,35 +23,40 @@ Module Name: namespace { -enum SQNBitGemmVariant { +enum QNBitGemmVariant { SQNBitGemmVariantInvalid = -1, // Valid variants SQNBitGemmVariant_BitWidth4_CompFp32 = 0, SQNBitGemmVariant_BitWidth4_CompInt8, + HQNBitGemmVariant_BitWidth4_CompFp16, + HQNBitGemmVariant_BitWidth4_CompInt8, // End of valid variants - // Keep this element last and ensure that its value is the number of valid SQNBitGemmVariant values. + // Keep this element last and ensure that its value is the number of valid QNBitGemmVariant values. // Its value is used as an array size. SQNBitGemmVariantCount, }; -SQNBitGemmVariant -GetSQNBitGemmVariant( +QNBitGemmVariant +GetQNBitGemmVariant( size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { if (BlkBitWidth == 4 && (BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) { - if (ComputeType == CompFp32 || - ComputeType == CompUndef) { // treat CompUndef (undefined) as CompFp32 + if (ComputeType == SQNBIT_CompFp32) { return SQNBitGemmVariant_BitWidth4_CompFp32; - } else if (ComputeType == CompInt8) { + } else if (ComputeType == HQNBIT_CompFp16) { + return HQNBitGemmVariant_BitWidth4_CompFp16; + } else if (ComputeType == SQNBIT_CompInt8) { return SQNBitGemmVariant_BitWidth4_CompInt8; + } else if (ComputeType == HQNBIT_CompInt8) { + return HQNBitGemmVariant_BitWidth4_CompInt8; } } @@ -61,23 +66,28 @@ GetSQNBitGemmVariant( } // namespace bool MLASCALL -MlasIsSQNBitGemmAvailable( +MlasIsQNBitGemmAvailable( size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; if (Dispatch == nullptr) { return false; } - const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); + const auto Variant = GetQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); switch (Variant) { case SQNBitGemmVariant_BitWidth4_CompFp32: { return Dispatch->SQ4BitGemmM1Kernel_CompFp32 != nullptr && - Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr; + Dispatch->SQ4BitBlkDequantBForSgemm_CompFp32 != nullptr; + } + case HQNBitGemmVariant_BitWidth4_CompFp16: { + return Dispatch->HQ4BitGemmPackQuantBData != nullptr && + Dispatch->HQ4BitGemmKernel_CompFp16 != nullptr && + Dispatch->HQ4BitBlkDequantBForHgemm_CompFp16 != nullptr; } case SQNBitGemmVariant_BitWidth4_CompInt8: { // SQ4BitGemmKernel_BlkSum_CompInt8 return @@ -94,80 +104,80 @@ namespace { size_t -SQNBitGemmPerGemmWorkspaceSize( +QNBitGemmPerGemmWorkspaceSize( size_t M, size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; if (Dispatch == nullptr) { return 0; } - if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPerGemmWorkspaceSize != nullptr) { - return Dispatch->SQ4BitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, ComputeType); + if (BlkBitWidth == 4 && Dispatch->Q4BitGemmPerGemmWorkspaceSize != nullptr) { + return Dispatch->Q4BitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, ComputeType); } return 0; } size_t -SQNBitGemmPerGemmWorkspaceAlignment( +QNBitGemmPerGemmWorkspaceAlignment( size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; if (Dispatch == nullptr) { return 1; } - if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPerGemmWorkspaceAlignment != nullptr) { - return Dispatch->SQ4BitGemmPerGemmWorkspaceAlignment(BlkLen, ComputeType); + if (BlkBitWidth == 4 && Dispatch->Q4BitGemmPerGemmWorkspaceAlignment != nullptr) { + return Dispatch->Q4BitGemmPerGemmWorkspaceAlignment(BlkLen, ComputeType); } return 1; } size_t -SQNBitGemmPerGemmWorkspaceStride( +QNBitGemmPerGemmWorkspaceStride( size_t M, size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto Size = SQNBitGemmPerGemmWorkspaceSize(M, N, K, BlkBitWidth, BlkLen, ComputeType); - const auto Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); + const auto Size = QNBitGemmPerGemmWorkspaceSize(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const auto Alignment = QNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); return MlasDivRoundup(Size, Alignment) * Alignment; } } // namespace size_t MLASCALL -MlasSQNBitGemmBatchWorkspaceSize( +MlasQNBitGemmBatchWorkspaceSize( size_t M, size_t N, size_t K, size_t BatchN, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const size_t PerGemmWorkspaceStride = QNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType); if (PerGemmWorkspaceStride == 0) { return 0; } - const size_t Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); + const size_t Alignment = QNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); const size_t WorkspaceSize = BatchN * PerGemmWorkspaceStride; @@ -175,21 +185,21 @@ MlasSQNBitGemmBatchWorkspaceSize( } size_t MLASCALL -MlasSQNBitGemmPackQuantBDataSize( +MlasQNBitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; if (Dispatch == nullptr) { return 0; } - if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPackQuantBDataSize != nullptr) { - return Dispatch->SQ4BitGemmPackQuantBDataSize( + if (BlkBitWidth == 4 && Dispatch->Q4BitGemmPackQuantBDataSize != nullptr) { + return Dispatch->Q4BitGemmPackQuantBDataSize( N, K, BlkLen, ComputeType ); } @@ -213,12 +223,12 @@ struct PerGemmQuantAWorkspace { }; void MLASCALL -MlasSQNBitGemmPackQuantBData( +MlasQNBitGemmPackQuantBData( size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const void* QuantBData, void* PackedQuantBDataAndOrBlkSumWorkspace, const void* QuantBScale, @@ -227,15 +237,15 @@ MlasSQNBitGemmPackQuantBData( MLAS_THREADPOOL* ThreadPool ) { - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; if (Dispatch == nullptr) { return; } if (BlkBitWidth == 4) { - if (ComputeType == CompInt8 && Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + if (ComputeType == SQNBIT_CompInt8 && Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); + PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum( N, K, @@ -248,6 +258,16 @@ MlasSQNBitGemmPackQuantBData( packed_quant_b, ThreadPool ); + } else if (ComputeType == HQNBIT_CompFp16 && Dispatch->HQ4BitGemmPackQuantBData != nullptr) { + Dispatch->HQ4BitGemmPackQuantBData( + N, + K, + BlkLen, + ComputeType, + static_cast(QuantBData), + static_cast(PackedQuantBDataAndOrBlkSumWorkspace), + ThreadPool + ); } else if (Dispatch->SQ4BitGemmPackQuantBData != nullptr) { // TODO: these assertions are true if called from matmul_nbits kernel but not from mlas tests. //assert(QuantBScale == nullptr); @@ -295,22 +315,11 @@ AddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t } } -typedef void(SQNBitGemmFn)( - size_t BlkLen, - size_t K, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, - void* PerGemmWorkspace, - size_t RangeStartM, - size_t RangeCountM, - size_t RangeStartN, - size_t RangeCountN -); - void SQ4BitGemm_CompFp32( const size_t BlkLen, const size_t K, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, + const MLAS_QNBIT_GEMM_DATA_PARAMS* const DataParams, void* const PerGemmWorkspace, const size_t RangeStartM, const size_t RangeCountM, @@ -355,7 +364,7 @@ SQ4BitGemm_CompFp32( float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompFp32( + GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmM1Kernel_CompFp32( BlkLen, a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias ); @@ -393,7 +402,7 @@ SQ4BitGemm_CompFp32( float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - GetMlasPlatform().SQNBitGemmDispatch->Q4BitBlkDequantBForSgemm_CompFp32( + GetMlasPlatform().QNBitGemmDispatch->SQ4BitBlkDequantBForSgemm_CompFp32( BlkLen, dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks ); @@ -425,11 +434,84 @@ SQ4BitGemm_CompFp32( } } +void +HQ4BitGemm_CompFp16( + const size_t BlkLen, + const size_t K, + const MLAS_QNBIT_GEMM_DATA_PARAMS* const DataParams, + void* const PerGemmWorkspace, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) +{ + constexpr size_t BlkBitWidth = 4; + MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspace); + + const size_t lda = DataParams->lda; + const size_t ldc = DataParams->ldc; + const size_t k_blk_num = MlasDivRoundup(K, BlkLen); + const size_t qldb = k_blk_num * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t ldb = k_blk_num * BlkLen; + const size_t k_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blk_num); + + const MLAS_FP16* A = DataParams->A + RangeStartM * lda; + MLAS_FP16* C = DataParams->C + RangeStartM * ldc + RangeStartN; + const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * qldb; + const MLAS_FP16* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blk_num; + const std::byte* QuantBZeroPoint = + (DataParams->QuantBZeroPoint == nullptr) + ? nullptr + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_zp_bytes; + const MLAS_FP16* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias; + + // 32N is the sweet spot of cache utilization. It is machine dependent though. + constexpr size_t StrideM = 2; + constexpr size_t StrideN = 32; + + // TODO(fajin): move allocation up to the op. + size_t bufsize = ldb * StrideN * sizeof(MLAS_FP16); + MlasThreadedBufAlloc(bufsize); + auto* dequant_b = reinterpret_cast(ThreadedBufHolder.get()); + + for (size_t n = 0, countN; n < RangeCountN; n += countN) { + countN = std::min(StrideN, RangeCountN - n); + GetMlasPlatform().QNBitGemmDispatch->HQ4BitBlkDequantBForHgemm_CompFp16( + BlkLen, dequant_b, QuantBData, QuantBScale, QuantBZeroPoint, countN, K, k_blk_num + ); + + const MLAS_FP16* a = A; + MLAS_FP16* c = C; + for (size_t m = 0, countM; m < RangeCountM; m += countM) { + countM = std::min(StrideM, RangeCountM - m); + GetMlasPlatform().QNBitGemmDispatch->HQ4BitGemmKernel_CompFp16( + a, dequant_b, Bias, c, countM, countN, K, lda, ldb, ldc + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM + m, RangeStartN + n, countM, countN, ldc + ); + } + + a += countM * lda; + c += countM * ldc; + } + + QuantBData += countN * qldb; + QuantBScale += countN * k_blk_num; + QuantBZeroPoint = QuantBZeroPoint ? QuantBZeroPoint + countN * k_zp_bytes : nullptr; + Bias = Bias ? Bias + countN : nullptr; + C += countN; + } +} + void SQ4BitGemm_CompInt8( const size_t BlkLen, const size_t K, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, + const MLAS_QNBIT_GEMM_DATA_PARAMS* const DataParams, void* const PerGemmWorkspace, const size_t RangeStartM, const size_t RangeCountM, @@ -500,10 +582,10 @@ SQ4BitGemm_CompInt8( float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8 != nullptr) { + if (GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_CompInt8 != nullptr) { size_t RowsRemaining = RangeCountM; while (RowsRemaining > 0) { - const auto RowsHandled = GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( + const auto RowsHandled = GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( BlkLen, a_row, b_col, b_col_scale, b_col_zp, c_blk, RowsRemaining, CountN, K, k_blks, ldc, bias ); @@ -522,10 +604,10 @@ SQ4BitGemm_CompInt8( } } #ifdef MLAS_TARGET_AMD64_IX86 - else if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) + else if (GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) { const float* b_blk_sum = QuantBBlkSum + n * k_blks; - GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8( + GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8( BlkLen, QuantA, QuantAScale, @@ -554,26 +636,29 @@ SQ4BitGemm_CompInt8( } } -typedef void(InitializeWorkspaceFn)( +template +void +InitializeWorkspace_CompInt8( size_t M, size_t N, size_t K, size_t BatchN, size_t BlkLen, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, size_t PerGemmWorkspaceStride, MLAS_THREADPOOL* ThreadPool ); +template <> void -InitializeWorkspace_CompInt8( +InitializeWorkspace_CompInt8( size_t M, size_t N, size_t K, size_t BatchN, size_t BlkLen, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, size_t PerGemmWorkspaceStride, MLAS_THREADPOOL* ThreadPool @@ -581,8 +666,8 @@ InitializeWorkspace_CompInt8( { MLAS_UNREFERENCED_PARAMETER(N); - const auto QuantizeARow = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARow_CompInt8; - const auto QuantizeARow2 = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARowComputeBlkSum_CompInt8; + const auto QuantizeARow = GetMlasPlatform().QNBitGemmDispatch->QuantizeARow_CompInt8; + const auto QuantizeARow2 = GetMlasPlatform().QNBitGemmDispatch->QuantizeARowComputeBlkSum_CompInt8; const size_t BlockCountK = MlasDivRoundup(K, BlkLen); const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); @@ -622,61 +707,153 @@ InitializeWorkspace_CompInt8( } } -struct Operations { - InitializeWorkspaceFn* InitializeWorkspace = nullptr; - SQNBitGemmFn* SQNBitGemm = nullptr; -}; +template <> +void +InitializeWorkspace_CompInt8( + size_t M, + size_t N, + size_t K, + size_t BatchN, + size_t BlkLen, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + size_t PerGemmWorkspaceStride, + MLAS_THREADPOOL* ThreadPool +) { + MLAS_UNREFERENCED_PARAMETER(M); + MLAS_UNREFERENCED_PARAMETER(N); + MLAS_UNREFERENCED_PARAMETER(K); + MLAS_UNREFERENCED_PARAMETER(BatchN); + MLAS_UNREFERENCED_PARAMETER(BlkLen); + MLAS_UNREFERENCED_PARAMETER(DataParams); + MLAS_UNREFERENCED_PARAMETER(Workspace); + MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspaceStride); + MLAS_UNREFERENCED_PARAMETER(ThreadPool); +} + +template +using InitializeWorkspaceFn = std::function* DataParams, + void* Workspace, + size_t PerGemmWorkspaceStride, + MLAS_THREADPOOL* ThreadPool +)>; -constexpr auto OperationMap = []() { - std::array ops; +template +InitializeWorkspaceFn +GetInitializeWorkspace(QNBitGemmVariant variant); - ops[SQNBitGemmVariant_BitWidth4_CompFp32].SQNBitGemm = SQ4BitGemm_CompFp32; +template <> +InitializeWorkspaceFn +GetInitializeWorkspace(QNBitGemmVariant variant) +{ + switch (variant) { + case SQNBitGemmVariant_BitWidth4_CompInt8: + return InitializeWorkspace_CompInt8; + default: + return nullptr; + } +} + +template <> +InitializeWorkspaceFn +GetInitializeWorkspace(QNBitGemmVariant variant) +{ + switch (variant) { + case HQNBitGemmVariant_BitWidth4_CompInt8: + return InitializeWorkspace_CompInt8; + default: + return nullptr; + } +} + +template +using QNBitGemmFn = std::function* const DataParams, + void* const PerGemmWorkspace, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +)>; - ops[SQNBitGemmVariant_BitWidth4_CompInt8].InitializeWorkspace = InitializeWorkspace_CompInt8; - ops[SQNBitGemmVariant_BitWidth4_CompInt8].SQNBitGemm = SQ4BitGemm_CompInt8; +template +QNBitGemmFn +GetQNBitGemm(QNBitGemmVariant variant); - return ops; -}(); +template <> +QNBitGemmFn +GetQNBitGemm(QNBitGemmVariant variant) +{ + switch (variant) { + case SQNBitGemmVariant_BitWidth4_CompFp32: + return SQ4BitGemm_CompFp32; + case SQNBitGemmVariant_BitWidth4_CompInt8: + return SQ4BitGemm_CompInt8; + default: + return nullptr; + } +} + +template <> +QNBitGemmFn +GetQNBitGemm(QNBitGemmVariant variant) +{ + switch (variant) { + case HQNBitGemmVariant_BitWidth4_CompFp16: + return HQ4BitGemm_CompFp16; + default: + return nullptr; + } +} } // namespace +template void MLASCALL -MlasSQNBitGemmBatch( +MlasQNBitGemmBatch( const size_t M, const size_t N, const size_t K, const size_t BatchN, const size_t BlkBitWidth, const size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, MLAS_THREADPOOL* ThreadPool ) { - const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); + const auto Variant = GetQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); assert(Variant != SQNBitGemmVariantInvalid); // // Ensure `Workspace` has correct alignment. // if (Workspace != nullptr) { - const size_t Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); + const size_t Alignment = QNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); const uintptr_t WorkspaceAddress = reinterpret_cast(Workspace); Workspace = reinterpret_cast( (WorkspaceAddress + Alignment - 1) & (~(Alignment - 1)) ); } - const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const size_t PerGemmWorkspaceStride = QNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType); - if (const auto InitializeWorkspaceOperation = OperationMap[Variant].InitializeWorkspace; + if (const auto InitializeWorkspaceOperation = GetInitializeWorkspace(Variant); InitializeWorkspaceOperation != nullptr) { InitializeWorkspaceOperation( M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool ); } - const auto ComputeOperation = OperationMap[Variant].SQNBitGemm; + const auto ComputeOperation = GetQNBitGemm(Variant); const size_t BlockCountK = MlasDivRoundup(K, BlkLen); @@ -685,11 +862,11 @@ MlasSQNBitGemmBatch( const auto* Data = &DataParams[gemm_i]; void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; - if (ComputeType == CompInt8 && GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); - const_cast(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; - const_cast(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; - const_cast(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + if (ComputeType == SQNBIT_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; + const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); } else { @@ -756,11 +933,11 @@ MlasSQNBitGemmBatch( void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; - if (ComputeType == CompInt8 && GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); - const_cast(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; - const_cast(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; - const_cast(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + if (ComputeType == SQNBIT_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; + const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); @@ -769,3 +946,33 @@ MlasSQNBitGemmBatch( } }); } + +template +void MLASCALL +MlasQNBitGemmBatch( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const size_t BlkBitWidth, + const size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + MLAS_THREADPOOL* ThreadPool +); + +template +void MLASCALL +MlasQNBitGemmBatch( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const size_t BlkBitWidth, + const size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + MLAS_THREADPOOL* ThreadPool +); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/qnbitgemm.h similarity index 71% rename from onnxruntime/core/mlas/lib/sqnbitgemm.h rename to onnxruntime/core/mlas/lib/qnbitgemm.h index 2da336ca2f0ec..eb3d0b44ae3de 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/qnbitgemm.h @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - sqnbitgemm.h + qnbitgemm.h Abstract: @@ -46,24 +46,25 @@ MlasAlignAddress(void* addr, const size_t alignment) return addr; } +template struct PackedQuantBDataStruct { PackedQuantBDataStruct(void* PackedQuantBWorkspace, size_t N, size_t BlockCountK, size_t BlkLen) : QuantBWorkspace_(PackedQuantBWorkspace), N_(N), BlockCountK_(BlockCountK), BlkLen_(BlkLen) { - // TODO: duplicate code from SQ4BitGemmPackQuantBDataSize + // TODO: duplicate code from Q4BitGemmPackQuantBDataSize constexpr size_t BlkBitWidth = 4; const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); + size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(T); // _mm256_load_si256 requires alignment on a 32-byte boundary PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 32); - QuantBBlkSum = (float*)(PackedQuantBData + PackedQuantBDataSize); - QuantBBlkSum = (float*)MlasAlignAddress(QuantBBlkSum, MlasQNBitQuantBBlkSumAlignment()); - PackedQuantBScale = (float*)((std::byte*)QuantBBlkSum + BlkSumSize); + QuantBBlkSum = (T*)(PackedQuantBData + PackedQuantBDataSize); + QuantBBlkSum = (T*)MlasAlignAddress(QuantBBlkSum, MlasQNBitQuantBBlkSumAlignment()); + PackedQuantBScale = (T*)((std::byte*)QuantBBlkSum + BlkSumSize); } std::byte* PackedQuantBData; - float* PackedQuantBScale; - float* QuantBBlkSum; + T* PackedQuantBScale; + T* QuantBBlkSum; void* QuantBWorkspace_; size_t N_, BlockCountK_, BlkLen_; @@ -84,44 +85,45 @@ MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount) // Kernel dispatch structure. // -struct MLAS_SQNBIT_GEMM_DISPATCH { +struct MLAS_QNBIT_GEMM_DISPATCH { // // Quantized B data packing function prototypes. // - /** Gets size of packed quantized B data containing 4-bit integers. See MlasSQNBitGemmPackQuantBDataSize(). */ - typedef size_t(SQ4BitGemmPackQuantBDataSize_Fn)( + /** Gets size of packed quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBDataSize(). */ + typedef size_t(Q4BitGemmPackQuantBDataSize_Fn)( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); - SQ4BitGemmPackQuantBDataSize_Fn* SQ4BitGemmPackQuantBDataSize = nullptr; + Q4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr; - /** Packs quantized B data containing 4-bit integers. See MlasSQNBitGemmPackQuantBData(). */ - typedef void(SQ4BitGemmPackQuantBData_Fn)( + /** Packs quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBData(). */ + typedef void(Q4BitGemmPackQuantBData_Fn)( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, std::byte* PackedQuantBDataBegin, MLAS_THREADPOOL* ThreadPool ); - SQ4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; + Q4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; + Q4BitGemmPackQuantBData_Fn* HQ4BitGemmPackQuantBData = nullptr; typedef void(SQ4BitGemmPackQuantBDataAndSumBlk_Fn)( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, bool has_zp_input, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& packed_quant_b, MLAS_THREADPOOL* ThreadPool ); @@ -141,15 +143,15 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { * @param[in] BlkLen number of quantized values per block * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ - typedef size_t(SQ4BitGemmPerGemmWorkspaceSize_Fn)( + typedef size_t(Q4BitGemmPerGemmWorkspaceSize_Fn)( size_t M, size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); - SQ4BitGemmPerGemmWorkspaceSize_Fn* SQ4BitGemmPerGemmWorkspaceSize = nullptr; + Q4BitGemmPerGemmWorkspaceSize_Fn* Q4BitGemmPerGemmWorkspaceSize = nullptr; /** * @brief Gets the required byte alignment of the per-GEMM intermediate workspace. @@ -157,15 +159,15 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { * @param[in] BlkLen number of quantized values per block * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ - typedef size_t(SQ4BitGemmPerGemmWorkspaceAlignment_Fn)( + typedef size_t(Q4BitGemmPerGemmWorkspaceAlignment_Fn)( size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); - SQ4BitGemmPerGemmWorkspaceAlignment_Fn* SQ4BitGemmPerGemmWorkspaceAlignment = nullptr; + Q4BitGemmPerGemmWorkspaceAlignment_Fn* Q4BitGemmPerGemmWorkspaceAlignment = nullptr; // - // CompFp32 kernel function prototypes. + // SQNBIT_CompFp32 kernel function prototypes. // /** @@ -228,10 +230,41 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { size_t BlockStrideQuantB ); - Q4BitBlkDequantBForSgemm_CompFp32_Fn* Q4BitBlkDequantBForSgemm_CompFp32 = nullptr; + Q4BitBlkDequantBForSgemm_CompFp32_Fn* SQ4BitBlkDequantBForSgemm_CompFp32 = nullptr; + + /** + * @brief Dequantize B into the format expected by the Sgemm kernel. + * B is a quantized 4-bit integer matrix that is block quantized and column major. + * This is equivalent to dequantizing B and then running MlasSgemmCopyPackB. + * + * @param BlkLen Number of values in a block. + * @param[out] FpData Supplies the output buffer for the dequantized B float data. + * It should have enough space for + * (CountN + 16 - 1) / 16 * 16 * (CountK + BlkLen - 1) / BlkLen * BlkLen + * elements. Only the first (CountN + 16 - 1) / 16 * 16 * CountK elements are + * useful, but the kernel implementation can be simplified with the extra space. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param CountN Number of columns of B. + * @param CountK Number of rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + */ + typedef void(Q4BitBlkDequantBForSgemm_CompFp16_Fn)( + size_t BlkLen, + MLAS_FP16* FpData, + const std::byte* QuantBData, + const MLAS_FP16* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB + ); + + Q4BitBlkDequantBForSgemm_CompFp16_Fn* HQ4BitBlkDequantBForHgemm_CompFp16 = nullptr; // - // CompInt8 kernel function prototypes. + // SQNBIT_CompInt8 kernel function prototypes. // /** @@ -337,4 +370,35 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { float* AScaledGroupSum // scale_k * Sum_blklen(a_i) ); QuantizeARowComputeBlkSum_CompInt8_Fn* QuantizeARowComputeBlkSum_CompInt8 = nullptr; + + /** + * @brief Multiply fp16 matrix A rows with fp16 matrix B columns. + * Results are written to fp16 matrix C. + * If bias is provided, the bias are added to the result. + * + * @param A first row of the A matrix segment. Row major. + * @param B first column of the B matrix segment. Column major. + * @param Bias the bias at the target column. Optional. + * @param[out] C first element of the output matrix segment. Row major. + * @param CountM the number of rows of A chunk. + * @param CountN the number of columns of B chunk. + * @param K the number of columns of A matrix and rows of B matrix. + * @param lda the leading dimension of A. + * @param ldb the leading dimension of B. + * @param ldc the leading dimension of C. + */ + typedef void(HQ4BitGemmKernel_CompFp16_Fn)( + const MLAS_FP16* A, + const MLAS_FP16* B, + const MLAS_FP16* Bias, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t K, + size_t lda, + size_t ldb, + size_t ldc + ); + + HQ4BitGemmKernel_CompFp16_Fn* HQ4BitGemmKernel_CompFp16 = nullptr; }; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp similarity index 74% rename from onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp rename to onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp index 3f32cc6c5312d..d05de64e68ec8 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - sqnbitgemm_kernel_neon.cpp + qnbitgemm_kernel_neon.cpp Abstract: @@ -19,8 +19,8 @@ Module Name: #include -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_neon.h" +#include "qnbitgemm.h" +#include "qnbitgemm_kernel_neon.h" #include "sqnbitgemm_q8_block.h" namespace sqnbitgemm_neon @@ -34,11 +34,11 @@ namespace // size_t -SQ4BitGemmPackQuantBDataSize( +Q4BitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType @@ -55,7 +55,7 @@ SQ4BitGemmPackQuantBData( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, std::byte* PackedQuantBDataBegin, MLAS_THREADPOOL* ThreadPool @@ -69,7 +69,7 @@ SQ4BitGemmPackQuantBData( const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); const size_t Iterations = N * BlockCountK; // one iteration per block - const size_t SubBlkLen = (ComputeType == CompInt8) + const size_t SubBlkLen = (ComputeType == SQNBIT_CompInt8) ? ((BlkLen == 16) ? 16 : 32) : 16; @@ -126,18 +126,18 @@ SQ4BitGemmPackQuantBData( // size_t -SQ4BitGemmPerGemmWorkspaceSize( +Q4BitGemmPerGemmWorkspaceSize( size_t M, size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { MLAS_UNREFERENCED_PARAMETER(N); switch (ComputeType) { - case CompInt8: { + case SQNBIT_CompInt8: { // workspace buffer is used for block quantization of A to int8 const size_t BlockCountK = MlasDivRoundup(K, BlkLen); const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); @@ -150,15 +150,15 @@ SQ4BitGemmPerGemmWorkspaceSize( } size_t -SQ4BitGemmPerGemmWorkspaceAlignment( +Q4BitGemmPerGemmWorkspaceAlignment( size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { MLAS_UNREFERENCED_PARAMETER(BlkLen); switch (ComputeType) { - case CompInt8: { + case SQNBIT_CompInt8: { return Q8BlkAlignment(); } default: { @@ -175,20 +175,27 @@ SQ4BitGemmPerGemmWorkspaceAlignment( // Kernel dispatch structure definition. // -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { + MLAS_QNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = sqnbitgemm_neon::SQ4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = sqnbitgemm_neon::Q4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = sqnbitgemm_neon::SQ4BitGemmPackQuantBData; - d.SQ4BitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::SQ4BitGemmPerGemmWorkspaceAlignment; + d.Q4BitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::Q4BitGemmPerGemmWorkspaceSize; + d.Q4BitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::Q4BitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = sqnbitgemm_neon::SQ4BitGemmM1Kernel_CompFp32; - d.Q4BitBlkDequantBForSgemm_CompFp32 = sqnbitgemm_neon::Q4BitBlkDequantBForSgemm_CompFp32; - - d.SQ4BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_CompInt8; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = sqnbitgemm_neon::SQ4BitBlkDequantBForSgemm_CompFp32; + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot()) { + d.SQ4BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_CompInt8; + } d.QuantizeARow_CompInt8 = sqnbitgemm_neon::QuantizeARow_CompInt8; +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + d.HQ4BitGemmPackQuantBData = sqnbitgemm_neon::HQ4BitGemmPackQuantBData_CompFp16; + d.HQ4BitBlkDequantBForHgemm_CompFp16 = sqnbitgemm_neon::HQ4BitBlkDequantBForHgemm_CompFp16; + d.HQ4BitGemmKernel_CompFp16 = sqnbitgemm_neon::HQ4BitGemmKernel_CompFp16; +#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED && MLAS_TARGET_ARM64 + return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h similarity index 69% rename from onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h rename to onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h index ef9345d7ac484..ccadd24ac1991 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h +++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - sqnbitgemm_kernel_neon.h + qnbitgemm_kernel_neon.h Abstract: @@ -30,13 +30,13 @@ namespace sqnbitgemm_neon // // Function declarations for SQNBitGemm ARM NEON kernel entry points. -// Refer to the prototypes in sqnbitgemm.h for documentation. +// Refer to the prototypes in qnbitgemm.h for documentation. // These are declared here so they can be used to initialize the -// MLAS_SQNBIT_GEMM_DISPATCH structure and also be implemented in separate +// MLAS_QNBIT_GEMM_DISPATCH structure and also be implemented in separate // files. // -// CompFp32 declarations +// SQNBIT_CompFp32 declarations void SQ4BitGemmM1Kernel_CompFp32( @@ -53,7 +53,7 @@ SQ4BitGemmM1Kernel_CompFp32( ); void -Q4BitBlkDequantBForSgemm_CompFp32( +SQ4BitBlkDequantBForSgemm_CompFp32( size_t BlkLen, float* FpData, const std::byte* QuantBData, @@ -64,7 +64,48 @@ Q4BitBlkDequantBForSgemm_CompFp32( size_t BlockCountK ); -// CompInt8 declarations +// HQNBIT_CompFp16 declarations +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) +void +HQ4BitGemmPackQuantBData_CompFp16( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +); + +void +HQ4BitBlkDequantBForHgemm_CompFp16( + size_t BlkLen, + MLAS_FP16* FpData, + const std::byte* QuantBData, + const MLAS_FP16* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t K, + size_t BlockCountK +); + +void +HQ4BitGemmKernel_CompFp16( + const MLAS_FP16* A, + const MLAS_FP16* B, + const MLAS_FP16* Bias, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t K, + size_t lda, + size_t ldb, + size_t ldc +); + +#endif // !(defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)) + +// SQNBIT_CompInt8 declarations void QuantizeARow_CompInt8( diff --git a/onnxruntime/core/mlas/lib/rotary_embedding.cpp b/onnxruntime/core/mlas/lib/rotary_embedding.cpp new file mode 100644 index 0000000000000..1f8f7b240694c --- /dev/null +++ b/onnxruntime/core/mlas/lib/rotary_embedding.cpp @@ -0,0 +1,101 @@ +/*++ + +Copyright (c) Intel Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding.cpp + +Abstract: + + This module implements rotary embedding kernels for fp32/16. + +--*/ + +#include "rotary_embedding.h" + +namespace { + +template +void +MLASCALL +MlasRotaryEmbedOneRow_FallBack( + const T* input_data, + const T* sin_data, + const T* cos_data, + size_t rotary_emb_dim, + bool interleaved, + T* output_data +) { + const size_t half_rotary_emb_dim = rotary_emb_dim / 2; + size_t cache_idx = 0; + bool sign = false; + size_t j = 0; + for (size_t i = 0; i < rotary_emb_dim; i++) { + if (interleaved) { + cache_idx = (i / 2) % half_rotary_emb_dim; + sign = i & 1; + j = sign ? i - 1 : i + 1; // i - sign + } else { + cache_idx = i % half_rotary_emb_dim; + sign = (i >= half_rotary_emb_dim); + j = (i + half_rotary_emb_dim) % rotary_emb_dim; + } + float output_data_i = static_cast(input_data[i]) * static_cast(cos_data[cache_idx]); + float input_data_j = static_cast(input_data[j]); + float sin_data_cache_idx = static_cast(sin_data[cache_idx]); + if (sign) { + output_data_i += input_data_j * sin_data_cache_idx; + } else { + output_data_i -= input_data_j * sin_data_cache_idx; + } + output_data[i] = static_cast(output_data_i); + } +} + +} // namespace + + +template <> +void +MLASCALL +MlasRotaryEmbedOneRow( + const float* input, + const float* sin, + const float* cos, + size_t dim, + bool interleaved, + float* output +) { + const auto* dispatch = GetMlasPlatform().RopeDispatch; + + if (dispatch == nullptr || dispatch->SRope == nullptr) { + MlasRotaryEmbedOneRow_FallBack(input, sin, cos, dim, interleaved, output); + return; + } + + dispatch->SRope(input, sin, cos, dim, interleaved, output); +} + +template <> +void +MLASCALL +MlasRotaryEmbedOneRow( + const MLAS_FP16* input, + const MLAS_FP16* sin, + const MLAS_FP16* cos, + size_t dim, + bool interleaved, + MLAS_FP16* output +) { + const auto* dispatch = GetMlasPlatform().RopeDispatch; + + if (dispatch == nullptr || dispatch->HRope == nullptr) { + MlasRotaryEmbedOneRow_FallBack(input, sin, cos, dim, interleaved, output); + return; + } + + dispatch->HRope(input, sin, cos, dim, interleaved, output); +} diff --git a/onnxruntime/core/mlas/lib/rotary_embedding.h b/onnxruntime/core/mlas/lib/rotary_embedding.h new file mode 100644 index 0000000000000..352dddccf1025 --- /dev/null +++ b/onnxruntime/core/mlas/lib/rotary_embedding.h @@ -0,0 +1,46 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding.h + +Abstract: + + This module includes kernel function prototypes and helper functions for + implementing rotary embedding. + +--*/ + +#pragma once + +#include "mlasi.h" + +struct MLAS_ROPE_DISPATCH { + // rotary embedding kernel for fp32 + typedef void(SRope_Fn)( + const float* input, + const float* sin, + const float* cos, + size_t dim, + bool interleaved, + float* output + ); + + SRope_Fn* SRope = nullptr; + + // rotary embedding kernel for fp16 + typedef void(HRope_Fn)( + const MLAS_FP16* input, + const MLAS_FP16* sin, + const MLAS_FP16* cos, + size_t dim, + bool interleaved, + MLAS_FP16* output + ); + + HRope_Fn* HRope = nullptr; +}; diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.cpp b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.cpp new file mode 100644 index 0000000000000..e59a95cd9ee4e --- /dev/null +++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.cpp @@ -0,0 +1,32 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding_kernel_neon.cpp + +Abstract: + + This module implements the rotary embedding kernels for ARM NEON. + +--*/ + +#include "rotary_embedding.h" +#include "rotary_embedding_kernel_neon.h" + +// +// Kernel dispatch structure definition. +// +const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon = []() { + MLAS_ROPE_DISPATCH d; + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + if (MlasFp16AccelerationSupported()) { + d.HRope = rope_neon::RopeKernel_Fp16; + } +#endif + return d; +}(); diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.h b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.h new file mode 100644 index 0000000000000..8153f65650f7d --- /dev/null +++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.h @@ -0,0 +1,37 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding_kernel_neon.h + +Abstract: + + This module includes function declarations and common helper functions for + rotary embedding on ARM cpu. + +--*/ + +#pragma once + +#include + +#include "mlasi.h" + +namespace rope_neon { + +// Rotary embedding kernel for fp16. Embed one hidden state vector. +void +RopeKernel_Fp16( + const MLAS_FP16* input, + const MLAS_FP16* sin, + const MLAS_FP16* cos, + size_t dim, + bool interleaved, + MLAS_FP16* output +); + +} // namespace rope_neon diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp new file mode 100644 index 0000000000000..3e2eb8fee0e6e --- /dev/null +++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp @@ -0,0 +1,253 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding_kernel_neon_fp16.cpp + +Abstract: + + This module implements the fp16 rotary embedding kernels for ARM NEON. + +--*/ + +#include +#include + +#include "fp16_common.h" +#include "rotary_embedding.h" +#include "rotary_embedding_kernel_neon.h" + +namespace rope_neon { + +namespace { + +template +void +RopeKernel_Fp16_Impl( + const _mlas_fp16_* input, + const _mlas_fp16_* sin, + const _mlas_fp16_* cos, + size_t dim, + _mlas_fp16_* output +); + +template <> +void +RopeKernel_Fp16_Impl( + const _mlas_fp16_* input, + const _mlas_fp16_* sin, + const _mlas_fp16_* cos, + size_t dim, + _mlas_fp16_* output +) { + const size_t half_dim = dim >> 1; + size_t i = 0, j = half_dim; + for (; i + 7 < half_dim; i += 8, j += 8) { + float16x8_t real = MlasLoadFloat16x8(input + i); + float16x8_t imag = MlasLoadFloat16x8(input + j); + float16x8_t sin_val = MlasLoadFloat16x8(sin + i); + float16x8_t cos_val = MlasLoadFloat16x8(cos + i); + float16x8_t real_out = vfmsq_f16(vmulq_f16(real, cos_val), imag, sin_val); + float16x8_t imag_out = vfmaq_f16(vmulq_f16(real, sin_val), imag, cos_val); + MlasStoreFloat16x8(output + i, real_out); + MlasStoreFloat16x8(output + j, imag_out); + } + for (; i + 3 < half_dim; i += 4, j += 4) { + float16x4_t real = MlasLoadFloat16x4(input + i); + float16x4_t imag = MlasLoadFloat16x4(input + j); + float16x4_t sin_val = MlasLoadFloat16x4(sin + i); + float16x4_t cos_val = MlasLoadFloat16x4(cos + i); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreFloat16x4(output + i, real_out); + MlasStoreFloat16x4(output + j, imag_out); + } + if (half_dim - i == 3) { + float16x4_t real = MlasZeroFloat16x4(); + float16x4_t imag = MlasZeroFloat16x4(); + float16x4_t sin_val = MlasZeroFloat16x4(); + float16x4_t cos_val = MlasZeroFloat16x4(); + real = MlasLoadLaneFloat16x4<0>(input + i, real); + real = MlasLoadLaneFloat16x4<1>(input + i + 1, real); + real = MlasLoadLaneFloat16x4<2>(input + i + 2, real); + imag = MlasLoadLaneFloat16x4<0>(input + j, imag); + imag = MlasLoadLaneFloat16x4<1>(input + j + 1, imag); + imag = MlasLoadLaneFloat16x4<2>(input + j + 2, imag); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); + sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val); + sin_val = MlasLoadLaneFloat16x4<2>(sin + i + 2, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val); + cos_val = MlasLoadLaneFloat16x4<2>(cos + i + 2, cos_val); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreLaneFloat16x4<0>(output + i, real_out); + MlasStoreLaneFloat16x4<1>(output + i + 1, real_out); + MlasStoreLaneFloat16x4<2>(output + i + 2, real_out); + MlasStoreLaneFloat16x4<0>(output + j, imag_out); + MlasStoreLaneFloat16x4<1>(output + j + 1, imag_out); + MlasStoreLaneFloat16x4<2>(output + j + 2, imag_out); + } else if (half_dim - i == 2) { + float16x4_t real = MlasZeroFloat16x4(); + float16x4_t imag = MlasZeroFloat16x4(); + float16x4_t sin_val = MlasZeroFloat16x4(); + float16x4_t cos_val = MlasZeroFloat16x4(); + real = MlasLoadLaneFloat16x4<0>(input + i, real); + real = MlasLoadLaneFloat16x4<1>(input + i + 1, real); + imag = MlasLoadLaneFloat16x4<0>(input + j, imag); + imag = MlasLoadLaneFloat16x4<1>(input + j + 1, imag); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); + sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreLaneFloat16x4<0>(output + i, real_out); + MlasStoreLaneFloat16x4<1>(output + i + 1, real_out); + MlasStoreLaneFloat16x4<0>(output + j, imag_out); + MlasStoreLaneFloat16x4<1>(output + j + 1, imag_out); + } else if (half_dim - i == 1) { + float16x4_t real = MlasZeroFloat16x4(); + float16x4_t imag = MlasZeroFloat16x4(); + float16x4_t sin_val = MlasZeroFloat16x4(); + float16x4_t cos_val = MlasZeroFloat16x4(); + real = MlasLoadLaneFloat16x4<0>(input + i, real); + imag = MlasLoadLaneFloat16x4<0>(input + j, imag); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreLaneFloat16x4<0>(output + i, real_out); + MlasStoreLaneFloat16x4<0>(output + j, imag_out); + } +} + +template <> +void +RopeKernel_Fp16_Impl( + const _mlas_fp16_* input, + const _mlas_fp16_* sin, + const _mlas_fp16_* cos, + size_t dim, + _mlas_fp16_* output +) { + size_t i = 0; + for (; i + 15 < dim; i += 16) { + float16x8_t x0 = MlasLoadFloat16x8(input + i); + float16x8_t x1 = MlasLoadFloat16x8(input + i + 8); + float16x8_t real = vuzp1q_f16(x0, x1); + float16x8_t imag = vuzp2q_f16(x0, x1); + float16x8_t sin_val = MlasLoadFloat16x8(sin + i); + float16x8_t cos_val = MlasLoadFloat16x8(cos + i); + float16x8_t real_out = vfmsq_f16(vmulq_f16(real, cos_val), imag, sin_val); + float16x8_t imag_out = vfmaq_f16(vmulq_f16(real, sin_val), imag, cos_val); + float16x8_t y0 = vzip1q_f16(real_out, imag_out); + float16x8_t y1 = vzip2q_f16(real_out, imag_out); + MlasStoreFloat16x8(output + i, y0); + MlasStoreFloat16x8(output + i + 8, y1); + } + for (; i + 7 < dim; i += 8) { + float16x4_t x0 = MlasLoadFloat16x4(input + i); + float16x4_t x1 = MlasLoadFloat16x4(input + i + 4); + float16x4_t real = vuzp1_f16(x0, x1); + float16x4_t imag = vuzp2_f16(x0, x1); + float16x4_t sin_val = MlasLoadFloat16x4(sin + i); + float16x4_t cos_val = MlasLoadFloat16x4(cos + i); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + float16x4_t y0 = vzip1_f16(real_out, imag_out); + float16x4_t y1 = vzip2_f16(real_out, imag_out); + MlasStoreFloat16x4(output + i, y0); + MlasStoreFloat16x4(output + i + 4, y1); + } + if (dim - i == 6) { + float16x4_t real = MlasZeroFloat16x4(); + float16x4_t imag = MlasZeroFloat16x4(); + float16x4_t sin_val = MlasZeroFloat16x4(); + float16x4_t cos_val = MlasZeroFloat16x4(); + real = MlasLoadLaneFloat16x4<0>(input + i, real); + imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag); + real = MlasLoadLaneFloat16x4<1>(input + i + 2, real); + imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag); + real = MlasLoadLaneFloat16x4<2>(input + i + 4, real); + imag = MlasLoadLaneFloat16x4<2>(input + i + 5, imag); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); + sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val); + sin_val = MlasLoadLaneFloat16x4<2>(sin + i + 2, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val); + cos_val = MlasLoadLaneFloat16x4<2>(cos + i + 2, cos_val); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreLaneFloat16x4<0>(output + i, real_out); + MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out); + MlasStoreLaneFloat16x4<1>(output + i + 2, real_out); + MlasStoreLaneFloat16x4<1>(output + i + 3, imag_out); + MlasStoreLaneFloat16x4<2>(output + i + 4, real_out); + MlasStoreLaneFloat16x4<2>(output + i + 5, imag_out); + } else if (dim - i == 4) { + float16x4_t real = MlasZeroFloat16x4(); + float16x4_t imag = MlasZeroFloat16x4(); + float16x4_t sin_val = MlasZeroFloat16x4(); + float16x4_t cos_val = MlasZeroFloat16x4(); + real = MlasLoadLaneFloat16x4<0>(input + i, real); + imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag); + real = MlasLoadLaneFloat16x4<1>(input + i + 2, real); + imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); + sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreLaneFloat16x4<0>(output + i, real_out); + MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out); + MlasStoreLaneFloat16x4<1>(output + i + 2, real_out); + MlasStoreLaneFloat16x4<1>(output + i + 3, imag_out); + } else if (dim - i == 2) { + float16x4_t real = MlasZeroFloat16x4(); + float16x4_t imag = MlasZeroFloat16x4(); + float16x4_t sin_val = MlasZeroFloat16x4(); + float16x4_t cos_val = MlasZeroFloat16x4(); + real = MlasLoadLaneFloat16x4<0>(input + i, real); + imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreLaneFloat16x4<0>(output + i, real_out); + MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out); + } +} + +} // namespace + +void +RopeKernel_Fp16( + const MLAS_FP16* input, + const MLAS_FP16* sin, + const MLAS_FP16* cos, + size_t dim, + bool interleaved, + MLAS_FP16* output +) { + // real part and imaginary part must be paired + assert(dim % 2 == 0); + + const auto* input_impl = reinterpret_cast(input); + const auto* sin_impl = reinterpret_cast(sin); + const auto* cos_impl = reinterpret_cast(cos); + auto* output_impl = reinterpret_cast<_mlas_fp16_*>(output); + + if (interleaved) { + RopeKernel_Fp16_Impl(input_impl, sin_impl, cos_impl, dim, output_impl); + } else { + RopeKernel_Fp16_Impl(input_impl, sin_impl, cos_impl, dim, output_impl); + } +} + +} // namespace rope_neon diff --git a/onnxruntime/core/mlas/lib/scalar/SgemmKernelScalar.cpp b/onnxruntime/core/mlas/lib/scalar/SgemmKernelScalar.cpp index 62729256dac23..cbec5d89bbac7 100644 --- a/onnxruntime/core/mlas/lib/scalar/SgemmKernelScalar.cpp +++ b/onnxruntime/core/mlas/lib/scalar/SgemmKernelScalar.cpp @@ -83,6 +83,8 @@ Return Value: #endif + int countb = 0; + do { float BElements00; @@ -116,6 +118,7 @@ Return Value: // const float* a = A; + const float* b = B; size_t k = CountK; while (k >= 2) { @@ -128,10 +131,10 @@ Return Value: Row1AElements1 = a[lda + 1]; } - BElements00 = B[0]; - BElements01 = B[1]; - BElements02 = B[2]; - BElements03 = B[3]; + BElements00 = b[0]; + BElements01 = b[1]; + BElements02 = b[2]; + BElements03 = b[3]; Row0Block00 = Row0Block00 + BElements00 * Row0AElements0; Row0Block01 = Row0Block01 + BElements01 * Row0AElements0; Row0Block02 = Row0Block02 + BElements02 * Row0AElements0; @@ -144,10 +147,10 @@ Return Value: Row1Block03 = Row1Block03 + BElements03 * Row1AElements0; } - BElements00 = B[4]; - BElements01 = B[5]; - BElements02 = B[6]; - BElements03 = B[7]; + BElements00 = b[16]; + BElements01 = b[17]; + BElements02 = b[18]; + BElements03 = b[19]; Row0Block00 = Row0Block00 + BElements00 * Row0AElements1; Row0Block01 = Row0Block01 + BElements01 * Row0AElements1; Row0Block02 = Row0Block02 + BElements02 * Row0AElements1; @@ -161,7 +164,7 @@ Return Value: } a += 2; - B += 8; + b += 32; k -= 2; } @@ -173,10 +176,10 @@ Return Value: Row1AElements0 = a[lda]; } - BElements00 = B[0]; - BElements01 = B[1]; - BElements02 = B[2]; - BElements03 = B[3]; + BElements00 = b[0]; + BElements01 = b[1]; + BElements02 = b[2]; + BElements03 = b[3]; Row0Block00 = Row0Block00 + BElements00 * Row0AElements0; Row0Block01 = Row0Block01 + BElements01 * Row0AElements0; Row0Block02 = Row0Block02 + BElements02 * Row0AElements0; @@ -188,8 +191,6 @@ Return Value: Row1Block02 = Row1Block02 + BElements02 * Row1AElements0; Row1Block03 = Row1Block03 + BElements03 * Row1AElements0; } - - B += 4; } // @@ -295,9 +296,14 @@ Return Value: break; } + B += 4; C += 4; CountN -= 4; + countb = (countb + 1) % 4; + if (countb == 0) { + B += CountK * 16 - 16; + } } while (CountN > 0); return ProcessTwoRows ? 2 : 1; diff --git a/onnxruntime/core/mlas/lib/sgemm.cpp b/onnxruntime/core/mlas/lib/sgemm.cpp index 4d7a1ceb4eee7..f8b25fb42caf3 100644 --- a/onnxruntime/core/mlas/lib/sgemm.cpp +++ b/onnxruntime/core/mlas/lib/sgemm.cpp @@ -1061,7 +1061,7 @@ Return Value: size_t RowsHandled; -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) +#if (defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64)) && !defined(FORCE_GENERIC_ALGORITHMS) RowsHandled = GetMlasPlatform().GemmFloatKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode); #else if (ZeroMode) { @@ -1158,6 +1158,7 @@ Return Value: if (M == 1 && TransA == CblasNoTrans && alpha == 1.0f && (beta == 0.0f || beta == 1.0f)) { +#if !defined(FORCE_GENERIC_ALGORITHMS) #if defined(MLAS_TARGET_AMD64) MLAS_SGEMM_KERNEL_M1_ROUTINE* SgemmKernelM1Routine; @@ -1181,6 +1182,7 @@ Return Value: } #endif +#endif // !defined(FORCE_GENERIC_ALGORITHMS) } @@ -1193,7 +1195,7 @@ Return Value: if (N == 1 && ldb == 1 && ldc == 1 && alpha == 1.0f && (beta == 0.0f || beta == 1.0f)) { -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) && !defined(FORCE_GENERIC_ALGORITHMS) MLAS_SGEMM_KERNEL_M1_ROUTINE* SgemmKernelM1Routine; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index baaa4ba1a3b1f..81615da46aa2e 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -19,7 +19,7 @@ Module Name: #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_int8.h" #include "sqnbitgemm_kernel_avx2_int8_blklen16.h" @@ -1306,12 +1306,12 @@ SQ4BitGemmPackQuantBDataAndBlkSum( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, bool has_zp_input, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& packed_quant_b, MLAS_THREADPOOL* ThreadPool ) { @@ -1319,9 +1319,9 @@ SQ4BitGemmPackQuantBDataAndBlkSum( const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - // TODO: always use SubBlkLen = 64 in CompInt8 + // TODO: always use SubBlkLen = 64 in SQNBIT_CompInt8 size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); - if (BlkLen == 32 && ComputeType == CompInt8) { + if (BlkLen == 32 && ComputeType == SQNBIT_CompInt8) { SubBlkLen = 64; } PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); @@ -1330,18 +1330,18 @@ SQ4BitGemmPackQuantBDataAndBlkSum( // // Kernel dispatch structure definition. // -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { + MLAS_QNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = Q4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.Q4BitGemmPerGemmWorkspaceSize = Q4BitGemmPerGemmWorkspaceSize; + d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; @@ -1349,18 +1349,18 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { return d; }(); -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { + MLAS_QNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = Q4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.Q4BitGemmPerGemmWorkspaceSize = Q4BitGemmPerGemmWorkspaceSize; + d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h index 80d67806ea6e8..445ead329acf8 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h index af6f52090adcb..5dab8091ce760 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h index 174ebc580904c..d4b89bd9bad2d 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" template @@ -117,7 +117,7 @@ accumulate_blklen64_r1c1blk1_avx2( __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) } #endif } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index 13bd369a065bb..b4e25d4e4040a 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -19,7 +19,7 @@ Module Name: #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_int8.h" #include "sqnbitgemm_kernel_avx512_int8_blklen16.h" @@ -28,7 +28,7 @@ Module Name: #include "sqnbitgemm_kernel_avx512_int8_blklen128.h" // -// CompFp32 kernel implementation. +// SQNBIT_CompFp32 kernel implementation. // #include "sqnbitgemm_kernel_avx_common_fp32.h" @@ -151,7 +151,7 @@ SQ4BitGemmM1Kernel_CompFp32_avx512( } // -// CompInt8 kernel implementation. +// SQNBIT_CompInt8 kernel implementation. // MLAS_FORCEINLINE @@ -332,12 +332,12 @@ SQ4BitGemmPackQuantBDataAndBlkSum512( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, bool has_zp_input, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& packed_quant_b, MLAS_THREADPOOL* ThreadPool ) { @@ -346,24 +346,24 @@ SQ4BitGemmPackQuantBDataAndBlkSum512( const size_t BlockCountK = MlasDivRoundup(K, BlkLen); size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); - if (ComputeType == CompInt8) { + if (ComputeType == SQNBIT_CompInt8) { SubBlkLen = 128; } PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); } -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { + MLAS_QNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = Q4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512; - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.Q4BitGemmPerGemmWorkspaceSize = Q4BitGemmPerGemmWorkspaceSize; + d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h index 7d9dc36854621..8f1ea6676b788 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" @@ -81,7 +81,7 @@ accumulate_blklen32_r2c1blk2_avx2( _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av01_32_epi8, bv1_32_epi8) ); const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); - + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); @@ -143,7 +143,7 @@ accumulate_blklen32_r2c1blk2_avx2( // const __m256i bv1 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 16, 17,...30, 31, 48, 49,...,62, 63 __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 16, 17,...30, 31, 48, 49,...,62, 63 - //__m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); + //__m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); //// This (the second line below) saves one _mm256_extracti128_si256 against using _mm256_set_m128i. ////__m256i bv1_32_epi8 = _mm256_set_m128i(_mm256_extracti128_si256(bv1, 1), _mm256_extracti128_si256(bv0, 1)); @@ -184,7 +184,7 @@ accumulate_blklen32_r2c1blk1_avx2( const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); - + const int8_t zp = get_zp(true, QuantBZeroPointPtr); const __m256i bzp = _mm256_set1_epi8(zp); bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp); @@ -435,7 +435,7 @@ Q4Int8Gemm2x4BlkLen32Avx2( } } -template +template void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( const std::byte* QuantA, const std::byte* QuantBData, @@ -877,7 +877,7 @@ MLAS_FORCEINLINE QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, C + multipleRows * ldc + multipleCols, remainingRows, - remainingCols, + remainingCols, BlockCountK, Bias ? Bias + multipleCols : nullptr, lda, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h index 60a887345d0e0..d79554c34c108 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx512_int8_blklen64.h" diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h index bb14babd6c2b1..03064886caf24 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx2_int8_blklen16.h" #include "sqnbitgemm_kernel_avx512_int8_blklen32.h" diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h index e9df6b952bd27..3b1096ac05ba7 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx2_int8_blklen32.h" #include "sqnbitgemm_kernel_avx512_int8_blklen64.h" diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h index 2a65ac4af0c1d..72ce28d834199 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" static MLAS_FORCEINLINE __m256 diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index 6a5c01162c51b..a4468bb906bbc 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -19,7 +19,7 @@ Module Name: #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_fp32.h" #include "sqnbitgemm_kernel_avx_common_int8.h" @@ -314,12 +314,12 @@ SQ4BitGemmPackQuantBDataAndBlkSum512vnni( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, bool has_zp_input, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& packed_quant_b, MLAS_THREADPOOL* ThreadPool ) { @@ -328,7 +328,7 @@ SQ4BitGemmPackQuantBDataAndBlkSum512vnni( const size_t BlockCountK = MlasDivRoundup(K, BlkLen); size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); - if (ComputeType == CompInt8) { + if (ComputeType == SQNBIT_CompInt8) { SubBlkLen = 128; } PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); @@ -337,18 +337,18 @@ SQ4BitGemmPackQuantBDataAndBlkSum512vnni( // // Kernel dispatch structure definition. // -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { + MLAS_QNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = Q4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512vnni; - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.Q4BitGemmPerGemmWorkspaceSize = Q4BitGemmPerGemmWorkspaceSize; + d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index 177f5518bb891..b0367b7fb9a15 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -1,5 +1,5 @@ #pragma once -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_q8_block.h" // @@ -7,16 +7,16 @@ // static size_t -SQ4BitGemmPackQuantBDataSize( +Q4BitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { constexpr size_t BlkBitWidth = 4; const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - if (ComputeType == CompInt8) { + if (ComputeType == SQNBIT_CompInt8) { size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); const size_t ScaleSize = N * BlockCountK * sizeof(float); size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); @@ -39,7 +39,7 @@ SQ4BitGemmPackQuantBData( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, const std::byte* QuantBDataBegin, std::byte* PackedQuantBDataBegin, MLAS_THREADPOOL* ThreadPool @@ -304,7 +304,7 @@ PackQuantBDataAndBlkSum( const float* QuantBScaleBegin, bool has_zp_input, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& packed_quant_b, MLAS_THREADPOOL* ThreadPool ) { @@ -326,18 +326,18 @@ PackQuantBDataAndBlkSum( // static size_t -SQ4BitGemmPerGemmWorkspaceSize( +Q4BitGemmPerGemmWorkspaceSize( size_t M, size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { MLAS_UNREFERENCED_PARAMETER(N); switch(ComputeType) { - case CompInt8: { + case SQNBIT_CompInt8: { // workspace buffer is used for block quantization of A to int8 const size_t BlockCountK = MlasDivRoundup(K, BlkLen); // QuantData + Scale + BlkSum @@ -351,15 +351,15 @@ SQ4BitGemmPerGemmWorkspaceSize( } static size_t -SQ4BitGemmPerGemmWorkspaceAlignment( +Q4BitGemmPerGemmWorkspaceAlignment( size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { MLAS_UNREFERENCED_PARAMETER(BlkLen); switch (ComputeType) { - case CompInt8: { + case SQNBIT_CompInt8: { return Q8BlkAlignment(); } default: { diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_fp32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_fp32.h index 5cd380e591098..d15cfc782e125 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_fp32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_fp32.h @@ -1,5 +1,5 @@ #pragma once -#include "sqnbitgemm.h" +#include "qnbitgemm.h" template MLAS_FORCEINLINE diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h index 895ce6cd091c2..2e96082968866 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_q8_block.h" diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp index 12ddc42506e98..31a499b8243af 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp @@ -13,7 +13,7 @@ Module Name: This module implements the float/quantized n-bit integer matrix multiplication kernels for ARM NEON specific to input type T1 as float32 and - MLAS_SQNBIT_GEMM_COMPUTE_TYPE CompFp32. + MLAS_QNBIT_GEMM_COMPUTE_TYPE SQNBIT_CompFp32. --*/ @@ -21,8 +21,8 @@ Module Name: #include -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_neon.h" +#include "qnbitgemm.h" +#include "qnbitgemm_kernel_neon.h" namespace sqnbitgemm_neon { @@ -31,7 +31,7 @@ namespace { // -// CompFp32 kernel implementation. +// SQNBIT_CompFp32 kernel implementation. // MLAS_FORCEINLINE void @@ -608,7 +608,7 @@ Q4BitBlkDequantBForSgemm_CompFp32_Impl( } // namespace void -Q4BitBlkDequantBForSgemm_CompFp32( +SQ4BitBlkDequantBForSgemm_CompFp32( size_t BlkLen, float* FpData, const std::byte* QuantBData, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp index 0d62ea37b7e26..73beb06a3cfad 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp @@ -13,7 +13,7 @@ Module Name: This module implements the float/quantized n-bit integer matrix multiplication kernels for ARM NEON specific to input type T1 as float32 and - MLAS_SQNBIT_GEMM_COMPUTE_TYPE CompInt8. + MLAS_QNBIT_GEMM_COMPUTE_TYPE SQNBIT_CompInt8. --*/ @@ -21,15 +21,15 @@ Module Name: #include -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_neon.h" +#include "qnbitgemm.h" +#include "qnbitgemm_kernel_neon.h" #include "sqnbitgemm_q8_block.h" namespace sqnbitgemm_neon { // -// CompInt8 kernel implementation. +// SQNBIT_CompInt8 kernel implementation. // namespace diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h index 45c3963365e6b..941b884d0b9d2 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" template diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h index e9c3812bde899..ed78dfa67042d 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" diff --git a/onnxruntime/core/optimizer/attention_fusion_helper.h b/onnxruntime/core/optimizer/attention_fusion_helper.h index 267a82b72670c..935114c40d1a7 100644 --- a/onnxruntime/core/optimizer/attention_fusion_helper.h +++ b/onnxruntime/core/optimizer/attention_fusion_helper.h @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include "onnx/defs/shape_inference.h" #include "onnx/defs/tensor_proto_util.h" #include "core/framework/tensorprotoutils.h" @@ -767,7 +768,8 @@ bool MatchInputMaskSubgraph(const Graph& graph, const Node& layer_norm, const No } // check where has X=-Infinity - if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(where.InputDefs()[1]), -INFINITY, true)) { + if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(where.InputDefs()[1]), + -std::numeric_limits::infinity(), true)) { DEBUG_LOG("where const not matched."); return false; } diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index 1466de51d0b99..e755b4bfa6364 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -227,11 +227,12 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, #if !defined(DISABLE_SPARSE_TENSORS) // Create execution frame for executing constant nodes. OptimizerExecutionFrame::Info info({node}, constant_inputs, graph.ModelPath(), execution_provider_, - is_sparse_initializer_check); + is_sparse_initializer_check, logger); #else // Create execution frame for executing constant nodes. - OptimizerExecutionFrame::Info info({node}, constant_inputs, graph.ModelPath(), execution_provider_, - [](std::string const&) { return false; }); + OptimizerExecutionFrame::Info info( + {node}, constant_inputs, graph.ModelPath(), execution_provider_, [](const std::string&) { return false; }, + logger); #endif std::vector fetch_mlvalue_idxs; diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index f769d31092d19..ba2b87b5aa0ca 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -63,6 +63,7 @@ #ifdef MLAS_TARGET_AMD64_IX86 #include "core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.h" #endif +#include "core/optimizer/qdq_transformer/bias_quantization.h" #include "core/optimizer/qdq_transformer/clip_quantizelinear.h" #include "core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.h" #include "core/optimizer/qdq_transformer/qdq_propagation.h" @@ -189,6 +190,7 @@ InlinedVector> GenerateTransformers( TransformerLevel level, const SessionOptions& session_options, const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/ + const logging::Logger& logger, const InlinedHashSet& rules_and_transformers_to_disable, [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool, std::unordered_map>* p_buffered_tensors) { @@ -243,6 +245,7 @@ InlinedVector> GenerateTransformers( if (!disable_quant_qdq) { transformers.emplace_back(std::make_unique()); + transformers.emplace_back(std::make_unique()); // EnsureUniqueDQForNodeUnit is actually a required graph transformation. The unique DQ per QDQ node unit input // condition that it ensures is important for the partitioning that happens after Level1 optimizers are run. @@ -402,7 +405,8 @@ InlinedVector> GenerateTransformers( } auto cpu_registry = cpu_execution_provider.GetKernelRegistry(); - auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry)); + auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry), + logger); if (nhwc_transformer->IsActive()) { transformers.emplace_back(std::move(nhwc_transformer)); } @@ -435,6 +439,7 @@ InlinedVector> GenerateTransformersForMinimalB const SessionOptions& session_options, const SatApplyContextVariant& apply_context, const IExecutionProvider& cpu_execution_provider, + const logging::Logger& logger, const InlinedHashSet& rules_and_transformers_to_disable, [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool, std::unordered_map>* p_buffered_tensors) { @@ -488,7 +493,8 @@ InlinedVector> GenerateTransformersForMinimalB #ifndef DISABLE_CONTRIB_OPS AllocatorPtr cpu_allocator = std::make_shared(); auto cpu_registry = cpu_execution_provider.GetKernelRegistry(); - auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry)); + auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry), + logger); if (nhwc_transformer->IsActive()) { transformers.emplace_back(std::move(nhwc_transformer)); } diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index 67ebc22dab41d..b1665c7172549 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -84,7 +84,9 @@ static bool NodeNeedsInputCastToFp32(const onnxruntime::Node& node) { // going to a node that will need a Cast. // // Return true if all the fp16 inputs and outputs are connected to nodes that will be cast to fp32. -static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry) { +static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::Graph& graph, + const KernelRegistry& cpu_kernel_registry, + const logging::Logger& logger) { // we can check if it's an isolated fp16 node // if node has input coming from other nodes (only consuming graph inputs or initializers if it doesn't), // does not have a subgraph (would have to alter subgraph inputs if we cast the input to this node), @@ -211,7 +213,7 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime:: const KernelCreateInfo* kernel_create_info{}; const auto lookup_status = cpu_kernel_registry.TryFindKernel( kCpuExecutionProvider, node.OpType(), node.Domain(), - node.SinceVersion(), type_constraint_map, &kernel_create_info); + node.SinceVersion(), type_constraint_map, logger, &kernel_create_info); if (lookup_status.IsOK() && kernel_create_info != nullptr) { return true; } @@ -220,9 +222,10 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime:: return false; } -static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry) { +static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry, + const logging::Logger& logger) { for (auto& node : graph.Nodes()) { - if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry)) { + if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry, logger)) { // unassign the node so that NeedInsertCast will return true for it, forcing it to fp32 node.SetExecutionProviderType(""); } @@ -319,7 +322,8 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { return dst_bit_length <= src_bit_length; } - if ((*src_type == "tensor(float16)" && *dst_type == "tensor(bfloat16)") || (*src_type == "tensor(bfloat16)" && *dst_type == "tensor(float16)")) { + if ((*src_type == "tensor(float16)" && *dst_type == "tensor(bfloat16)") || + (*src_type == "tensor(bfloat16)" && *dst_type == "tensor(float16)")) { return true; } @@ -453,7 +457,7 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { if (force_cpu_fp32_) - ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph, *cpu_kernel_registries_)); + ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph, *cpu_kernel_registries_, logger)); GraphViewer graph_viewer(graph); auto& order = graph_viewer.GetNodesInTopologicalOrder(); diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc index 7953cde6686c0..56f7d28cd5b77 100644 --- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc +++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc @@ -177,7 +177,11 @@ Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvid for (size_t i = 2; i < node->Inputs().size(); i++) { auto constant = api_graph->GetConstant(node->Inputs()[i]); if (constant != nullptr && constant->Data().size() > 0) { - input_perms.push_back(&input_perm); + // Starting from opset version 18, the 'scales' and 'sizes' can be any length up to the input rank. + // However, our current implementation only supports the transposition of 4D tensors. + if (constant->NumElements() == 4) { + input_perms.push_back(&input_perm); + } } else { // TODO: Fix inconsistency. We should Transpose the non-const inputs so that the result of our changes // is consistent - all layout specific inputs are in NHWC format when we're done. diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc index e944522c9c338..6b76dc626fba0 100644 --- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc @@ -107,6 +107,22 @@ bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, cons return false; } + // Checks the first input of MatMul has 2 dimensions. + // The test for the second input is done in method Apply as it accesses the constant. + if (node.InputDefs()[0] == nullptr) { + // This should never happen but just in case. + return false; + } + auto shape_a = node.InputDefs()[0]->Shape(); + if (shape_a == nullptr) { + // We cannot shape the rank. It is better to avoid fusing. + return false; + } + if (shape_a->dim_size() != 2) { + // Gemm only supports 2D tensors. + return false; + } + // First output from BN is required. Others are optional. If any optional outputs exist we can't fuse. const auto& output_defs = batch_norm_node->OutputDefs(); if (output_defs.size() > 1) { @@ -165,6 +181,7 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& bias_tensor->dims_size() != 1 || mean_tensor->dims_size() != 1 || var_tensor->dims_size() != 1 || + matmul_b_tensor->dims_size() != 2 || scale_tensor->dims(0) != matmul_b_tensor->dims(1) || bias_tensor->dims(0) != matmul_b_tensor->dims(1) || mean_tensor->dims(0) != matmul_b_tensor->dims(1) || diff --git a/onnxruntime/core/optimizer/matmul_integer_to_float.cc b/onnxruntime/core/optimizer/matmul_integer_to_float.cc index 4fee1a6ce224e..b619efb2f751e 100644 --- a/onnxruntime/core/optimizer/matmul_integer_to_float.cc +++ b/onnxruntime/core/optimizer/matmul_integer_to_float.cc @@ -49,6 +49,49 @@ bool HasElementDataType(const NodeArg& node_arg, int32_t data_type) { return data_type == actual_data_type; } +// Return total mnumber of Elements. +static uint64_t NumElements(const TensorShapeProto* tensor_shape) { + if (nullptr == tensor_shape || tensor_shape->dim_size() < 1) { + return 0; + } + uint64_t num_elements = 1; + + for (int i = 0; i < tensor_shape->dim_size(); i++) { + num_elements *= tensor_shape->dim(i).dim_value(); + } + return num_elements; +} + +bool CheckMatMulLargeTensors(const Node& matmulinteger_node, const Node& cast_node) { + const auto a_def = matmulinteger_node.InputDefs()[0]; + const auto b_def = matmulinteger_node.InputDefs()[1]; + const int a_dim_size = a_def->Shape()->dim_size(); + const int b_dim_size = b_def->Shape()->dim_size(); + uint64_t a_num_elements = NumElements(a_def->Shape()); + uint64_t b_num_elements = NumElements(b_def->Shape()); + + if (a_dim_size != b_dim_size) { + bool a_is_broadcasted = a_dim_size < b_dim_size; + if (a_is_broadcasted) { + for (int i = 0; i < b_dim_size - a_dim_size; i++) { + a_num_elements *= b_def->Shape()->dim(i).dim_value(); + } + } else { + for (int i = 0; i < a_dim_size - b_dim_size; i++) { + b_num_elements *= a_def->Shape()->dim(i).dim_value(); + } + } + } + + int output_data_type = HasElementDataType(*cast_node.OutputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) ? 2 : 4; + uint64_t total_bytes = (a_num_elements + b_num_elements) * output_data_type; + + if (total_bytes > UINT32_MAX) { + return true; + } + return false; +} + /** MatMulIntegerToFloatFusion will fuse subgraph like below into MatMulIntegerToFloat: @@ -114,6 +157,17 @@ Status MatMulIntegerToFloatFusion::ApplyImpl(Graph& graph, bool& modified, int g continue; } + const Node* p_dynamicquantize_node = graph_utils::FirstParentByType(*p_matmulinteger_node, "DynamicQuantizeLinear"); + + // Check MatMulInteger Nodes' input is coming from DynamicQuantizeLinear + // For larger tensors DynamicQuantizeLinear -> MatMulInteger is used to be resource efficient + // And we have better MatMulInteger Metacommand coverage in DML + if (is_dml_ep && p_dynamicquantize_node) { + if (CheckMatMulLargeTensors(matmulinteger_node, cast_node)) { + continue; + } + } + // Find bias node Node* p_add_node = nullptr; if (optimizer_utils::CheckOutputEdges(graph, mul_node, 1)) { diff --git a/onnxruntime/core/optimizer/nhwc_transformer.cc b/onnxruntime/core/optimizer/nhwc_transformer.cc index ee79fa620374e..cd654991c92d5 100644 --- a/onnxruntime/core/optimizer/nhwc_transformer.cc +++ b/onnxruntime/core/optimizer/nhwc_transformer.cc @@ -44,7 +44,9 @@ NhwcConvLookup( return &(iter->second); } -NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr cpu_kernel_registry) noexcept +NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, + std::shared_ptr cpu_kernel_registry, + const logging::Logger& logger) noexcept : GraphTransformer("NhwcTransformer"), cpu_allocator_(std::move(cpu_allocator)) { if (!cpu_kernel_registry) { // This is a CPU op nodes optimizer, not useful if cpu EP is not available. @@ -64,7 +66,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel( kCpuExecutionProvider, qconv_int8.op_type_, qconv_int8.domain_, - qconv_int8.version_, qconv_int8.type_constraints_, &kernel_create_info); + qconv_int8.version_, qconv_int8.type_constraints_, logger, &kernel_create_info); if (status.IsOK() && kernel_create_info != nullptr) { kernel_create_info = nullptr; conv_table_.emplace( @@ -83,7 +85,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel( kCpuExecutionProvider, qconv_uint8.op_type_, qconv_uint8.domain_, - qconv_uint8.version_, qconv_uint8.type_constraints_, &kernel_create_info); + qconv_uint8.version_, qconv_uint8.type_constraints_, logger, &kernel_create_info); if (status.IsOK() && kernel_create_info != nullptr) { kernel_create_info = nullptr; conv_table_.emplace( @@ -103,7 +105,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel( kCpuExecutionProvider, nhwc_conv_fp16.op_type_, nhwc_conv_fp16.domain_, - nhwc_conv_fp16.version_, nhwc_conv_fp16.type_constraints_, &kernel_create_info); + nhwc_conv_fp16.version_, nhwc_conv_fp16.type_constraints_, logger, &kernel_create_info); if (status.IsOK() && kernel_create_info != nullptr) { kernel_create_info = nullptr; conv_table_.emplace( @@ -123,7 +125,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel( kCpuExecutionProvider, nhwc_maxpool_fp16.op_type_, nhwc_maxpool_fp16.domain_, - nhwc_maxpool_fp16.version_, nhwc_maxpool_fp16.type_constraints_, &kernel_create_info); + nhwc_maxpool_fp16.version_, nhwc_maxpool_fp16.type_constraints_, logger, &kernel_create_info); if (status.IsOK() && kernel_create_info != nullptr) { kernel_create_info = nullptr; conv_table_.emplace( @@ -140,7 +142,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel( kCpuExecutionProvider, nhwc_avgpool_fp16.op_type_, nhwc_avgpool_fp16.domain_, - nhwc_avgpool_fp16.version_, nhwc_avgpool_fp16.type_constraints_, &kernel_create_info); + nhwc_avgpool_fp16.version_, nhwc_avgpool_fp16.type_constraints_, logger, &kernel_create_info); if (status.IsOK() && kernel_create_info != nullptr) { kernel_create_info = nullptr; conv_table_.emplace( @@ -157,7 +159,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel( kCpuExecutionProvider, nhwc_gavgpool_fp16.op_type_, nhwc_gavgpool_fp16.domain_, - nhwc_gavgpool_fp16.version_, nhwc_gavgpool_fp16.type_constraints_, &kernel_create_info); + nhwc_gavgpool_fp16.version_, nhwc_gavgpool_fp16.type_constraints_, logger, &kernel_create_info); if (status.IsOK() && kernel_create_info != nullptr) { kernel_create_info = nullptr; conv_table_.emplace( diff --git a/onnxruntime/core/optimizer/nhwc_transformer.h b/onnxruntime/core/optimizer/nhwc_transformer.h index 000732060b889..c65f851fdab9d 100644 --- a/onnxruntime/core/optimizer/nhwc_transformer.h +++ b/onnxruntime/core/optimizer/nhwc_transformer.h @@ -75,7 +75,8 @@ and inserts nodes to transpose tensors as needed. class NhwcTransformer : public GraphTransformer { private: public: - explicit NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr cpu_kernel_registry) noexcept; + explicit NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr cpu_kernel_registry, + const logging::Logger& logger) noexcept; /** * @brief Usually called right after constructor, it shows whether diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.cc b/onnxruntime/core/optimizer/optimizer_execution_frame.cc index ed7d5feb2beb3..b2e8e491c361c 100644 --- a/onnxruntime/core/optimizer/optimizer_execution_frame.cc +++ b/onnxruntime/core/optimizer/optimizer_execution_frame.cc @@ -32,9 +32,11 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, const InitializedTensorSet& initialized_tensor_set, const std::filesystem::path& model_path, const IExecutionProvider& execution_provider, - const std::function& is_sparse_initializer_func) + const std::function& is_sparse_initializer_func, + const logging::Logger& logger) : execution_provider_(execution_provider), - is_sparse_initializer_func_(is_sparse_initializer_func) { + is_sparse_initializer_func_(is_sparse_initializer_func), + logger_(logger) { allocator_ptr_ = std::make_shared(); ORT_ENFORCE(allocator_ptr_, "Failed to get allocator for optimizer"); @@ -79,9 +81,11 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, const std::unordered_map& initialized_tensor_set, const std::filesystem::path& /* model_path */, const IExecutionProvider& execution_provider, - const std::function& is_sparse_initializer_func) + const std::function& is_sparse_initializer_func, + const logging::Logger& logger) : execution_provider_(execution_provider), - is_sparse_initializer_func_(is_sparse_initializer_func) { + is_sparse_initializer_func_(is_sparse_initializer_func), + logger_(logger) { allocator_ptr_ = std::make_shared(); ORT_ENFORCE(allocator_ptr_, "Failed to get allocator for optimizer"); @@ -117,7 +121,7 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, Status OptimizerExecutionFrame::Info::TryFindKernel(const Node* node, const KernelCreateInfo** out) const { std::shared_ptr kernel_registry = execution_provider_.GetKernelRegistry(); const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{}; - return kernel_registry->TryFindKernel(*node, execution_provider_.Type(), kernel_type_str_resolver, out); + return kernel_registry->TryFindKernel(*node, execution_provider_.Type(), kernel_type_str_resolver, logger_, out); } static Status TryCreateKernel(const Node& node, @@ -128,10 +132,11 @@ static Status TryCreateKernel(const Node& node, FuncManager& funcs_mgr, const DataTransferManager& data_transfer_mgr, const ConfigOptions& config_options, + const logging::Logger& logger, /*out*/ std::unique_ptr& op_kernel) { const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{}; const KernelCreateInfo* kernel_create_info = nullptr; - ORT_RETURN_IF_ERROR(kernel_registry.TryFindKernel(node, execution_provider.Type(), kernel_type_str_resolver, + ORT_RETURN_IF_ERROR(kernel_registry.TryFindKernel(node, execution_provider.Type(), kernel_type_str_resolver, logger, &kernel_create_info)); static const AllocatorMap dummy_allocators; @@ -154,7 +159,7 @@ OptimizerExecutionFrame::Info::CreateKernel(const Node* node, const ConfigOption std::shared_ptr kernel_registry = execution_provider_.GetKernelRegistry(); FuncManager func; auto status = TryCreateKernel(*node, *kernel_registry, execution_provider_, initializers_, - ort_value_name_idx_map_, func, data_transfer_mgr_, config_options, + ort_value_name_idx_map_, func, data_transfer_mgr_, config_options, logger_, op_kernel); // Kernel found in the CPU kernel registry diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.h b/onnxruntime/core/optimizer/optimizer_execution_frame.h index b0f7f461661b5..24a23312feba9 100644 --- a/onnxruntime/core/optimizer/optimizer_execution_frame.h +++ b/onnxruntime/core/optimizer/optimizer_execution_frame.h @@ -27,13 +27,15 @@ class OptimizerExecutionFrame final : public IExecutionFrame { const InitializedTensorSet& initialized_tensor_set, const std::filesystem::path& model_path, const IExecutionProvider& execution_provider, - const std::function& is_sparse_initializer_func); + const std::function& is_sparse_initializer_func, + const logging::Logger& logger); Info(const std::vector& nodes, const std::unordered_map& initialized_tensor_set, const std::filesystem::path& model_path, const IExecutionProvider& execution_provider, - const std::function& is_sparse_initializer_func); + const std::function& is_sparse_initializer_func, + const logging::Logger& logger); ~Info() = default; @@ -76,6 +78,7 @@ class OptimizerExecutionFrame final : public IExecutionFrame { std::unique_ptr node_index_info_; const IExecutionProvider& execution_provider_; const std::function& is_sparse_initializer_func_; + const logging::Logger& logger_; ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Info); }; diff --git a/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc b/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc index 18e462c04dff3..5538aa54801cc 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc @@ -36,7 +36,7 @@ static inline bool MatchesOpSinceVersion( return std::find(versions.begin(), versions.end(), node.SinceVersion()) != versions.end(); } -static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) { +static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph, const logging::Logger& logger) { constexpr size_t w_idx = 1; constexpr size_t w_zp_idx = 9; constexpr size_t r_idx = 2; @@ -60,7 +60,7 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) { if (!graph_utils::NodeArgIsConstant(graph, *input_defs[r_idx]) || !graph.GetInitializedTensor(input_defs[r_idx]->Name(), r_tensor_proto) || r_tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8) { - LOGS_DEFAULT(WARNING) << "Unable transforming DynamicQuantizeLSTM operator," + LOGS(logger, WARNING) << "Unable transforming DynamicQuantizeLSTM operator," << " cannot locate recurrence tensor of const int8 type," << " int8 overflow might impact precision !"; return false; @@ -86,7 +86,7 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) { if (!graph_utils::NodeArgIsConstant(graph, *input_defs[r_zp_idx]) || !graph.GetInitializedTensor(input_defs[r_zp_idx]->Name(), r_zp_tensor_proto) || r_zp_tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8) { - LOGS_DEFAULT(WARNING) << "Unable transforming DynamicQuantizeLSTM operator," + LOGS(logger, WARNING) << "Unable transforming DynamicQuantizeLSTM operator," << " unable to locate recurrence tensor or its zero point value," << " int8 overflow might impact precision !"; return false; @@ -171,7 +171,7 @@ Status Avx2WeightS8ToU8Transformer::ApplyImpl(Graph& graph, bool& modified, int if (graph_utils::IsSupportedOptypeVersionAndDomain( op_node, "DynamicQuantizeLSTM", {1}, kMSDomain)) { // This one has two set of quantized arguments - modified |= TryConvertDynamicQuantizeLSTM(op_node, graph); + modified |= TryConvertDynamicQuantizeLSTM(op_node, graph, logger); continue; // go on to next operator node } diff --git a/onnxruntime/core/optimizer/qdq_transformer/bias_quantization.cc b/onnxruntime/core/optimizer/qdq_transformer/bias_quantization.cc new file mode 100644 index 0000000000000..9e9665e14ede4 --- /dev/null +++ b/onnxruntime/core/optimizer/qdq_transformer/bias_quantization.cc @@ -0,0 +1,149 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/qdq_transformer/bias_quantization.h" + +#include "core/common/common.h" +#include "core/graph/graph_utils.h" +#include "core/graph/graph_viewer.h" +#include "core/optimizer/utils.h" +#include "core/optimizer/qdq_transformer/qdq_util.h" + +namespace onnxruntime { + +Status BiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { + const GraphViewer graph_viewer{graph}; + const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); + for (const auto node_idx : node_indices) { + auto* node_ptr = graph.GetNode(node_idx); + if (!node_ptr) { + continue; + } + + Node& node = *node_ptr; + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + const auto& input_defs = node.InputDefs(); + + // It's Conv/Gemm node with an initializer bias. + if ((node.OpType() != "Conv" && node.OpType() != "Gemm") || input_defs.size() < 3 || !input_defs[2]->Exists() || + !graph_utils::IsInitializer(graph, input_defs[2]->Name(), true)) { + continue; + } + + auto bias_shape = input_defs[2]->Shape(); + if (!bias_shape || bias_shape->dim_size() != 1) { + continue; + } + int64_t bias_size = bias_shape->dim(0).dim_value(); + + // input_0 and input_1 are outputs of DequantizeLinear nodes. + const Node* parent_node_0 = graph.GetProducerNode(input_defs[0]->Name()); + const Node* parent_node_1 = graph.GetProducerNode(input_defs[1]->Name()); + if (!parent_node_0 || !parent_node_1 || parent_node_0->OpType() != QDQ::DQOpName || + parent_node_1->OpType() != QDQ::DQOpName) { + continue; + } + + Node& dq_0 = *graph.GetNode(parent_node_0->Index()); + Node& dq_1 = *graph.GetNode(parent_node_1->Index()); + + // Currently we require input_0 is per-tensor scale. + if (!optimizer_utils::IsScalar(*dq_0.InputDefs()[1])) { + continue; + } + + // For input_1, it's either per-tensor scale or per-channel scale on specific axis (0 for Conv and 1 for Gemm). + bool is_per_tensor_scale = true; + if (!optimizer_utils::IsScalar(*dq_1.InputDefs()[1])) { + is_per_tensor_scale = false; + auto weight_scale_shape = dq_1.InputDefs()[1]->Shape(); + if (!weight_scale_shape || weight_scale_shape->dim_size() != 1 || !weight_scale_shape->dim(0).has_dim_value() || + weight_scale_shape->dim(0).dim_value() != bias_size) { + continue; + } + + const auto& dq_attrs = dq_1.GetAttributes(); + if (dq_attrs.find("block_size") != dq_attrs.end()) { + continue; + } + + int64_t axis = 1; + if (dq_attrs.find("axis") != dq_attrs.end()) { + axis = dq_attrs.at("axis").i(); + } + + int64_t expected_axis = 0; + if (node.OpType() == "Gemm") { + int64_t transB = 0; + if (const auto& attr = node.GetAttributes().find("transB"); attr != node.GetAttributes().end()) { + transB = attr->second.i(); + } + expected_axis = transB == 0 ? 1 : 0; + } + + if (axis != expected_axis) { + continue; + } + } + + // Bias is quantized to int32. + ONNX_NAMESPACE::TypeProto int32_type_proto; + int32_type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT32); + auto scale_type = dq_1.InputDefs()[1]->TypeAsProto(); // Maybe per-tensor (scalar) or per-channel (1D) scale. + ONNX_NAMESPACE::TypeProto bias_dq_type; + bias_dq_type.mutable_tensor_type()->set_elem_type(scale_type->tensor_type().elem_type()); + bias_dq_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(bias_size); + + // scale = input_scale_0 * input_scale_1. + NodeArg& scale_node_arg = + graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_scale"), scale_type); + Node& mul_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_scale"), "Mul", "Scale node", + {dq_0.MutableInputDefs()[1], dq_1.MutableInputDefs()[1]}, {&scale_node_arg}, nullptr, + node.Domain()); + + // fp_bias / scale. + NodeArg& bias_div_node_arg = + graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_div"), &bias_dq_type); + Node& div_node = + graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_div"), "Div", "Bias div node", + {node.MutableInputDefs()[2], &scale_node_arg}, {&bias_div_node_arg}, nullptr, node.Domain()); + graph.AddEdge(mul_node.Index(), div_node.Index(), 0, 1); + + // Round(fp_bias / scale). + NodeArg& bias_div_round_node_arg = + graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_div_round"), &bias_dq_type); + Node& round_node = + graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_div_round"), "Round", "Bias div round node", + {&bias_div_node_arg}, {&bias_div_round_node_arg}, nullptr, node.Domain()); + graph.AddEdge(div_node.Index(), round_node.Index(), 0, 0); + + // Cast(round(fp_bias / scale)) to int32. + NodeArg& bias_int32_node_arg = + graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_int32"), &int32_type_proto); + Node& cast_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_int32"), "Cast", "Bias int32 node", + {&bias_div_round_node_arg}, {&bias_int32_node_arg}, nullptr, node.Domain()); + cast_node.AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_INT32)); + graph.AddEdge(round_node.Index(), cast_node.Index(), 0, 0); + + // Bias DQ node produces output to Conv/Gemm node's input_2, with scale = input_scale_0 * input_scale_1, zp = 0. + NodeArg& bias_dq_node_arg = + graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_dq"), &bias_dq_type); + Node& dq_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_dq"), QDQ::DQOpName, "Bias DQ node", + {&bias_int32_node_arg, &scale_node_arg}, {&bias_dq_node_arg}, nullptr, node.Domain()); + if (!is_per_tensor_scale) { + dq_node.AddAttribute("axis", static_cast(0)); + } + + graph.AddEdge(cast_node.Index(), dq_node.Index(), 0, 0); + graph.AddEdge(mul_node.Index(), dq_node.Index(), 0, 1); + node.MutableInputDefs()[2] = &bias_dq_node_arg; + graph.AddEdge(dq_node.Index(), node.Index(), 0, 2); + + modified = true; + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/bias_quantization.h b/onnxruntime/core/optimizer/qdq_transformer/bias_quantization.h new file mode 100644 index 0000000000000..0297def260fd9 --- /dev/null +++ b/onnxruntime/core/optimizer/qdq_transformer/bias_quantization.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +/** + * @class BiasQuantization + * + * Some quantized models do not have Gemm/Conv's bias quantized. This optimization adds a subgraph to quantize the bias + * with scale = scale_input_0 * scale_input_1 and zero_point = 0. + * + * Normally the ConstantFolding optimizer would fold the bias initializer into an int32_t initializer, which is consumed + * by a DequantizeLinear node. + */ +class BiasQuantization : public GraphTransformer { + public: + BiasQuantization() noexcept : GraphTransformer("BiasQuantization") {} + + private: + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index 2738c3ab02799..2f98711771f1b 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -236,7 +236,7 @@ void ConvQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool is_ #if !defined(ORT_MINIMAL_BUILD) // TODO: Enable 16-bit types in selector when QLinearConv supports 16-bit. - std::vector providers = {kCpuExecutionProvider, kDmlExecutionProvider}; + std::vector providers = {kCpuExecutionProvider, kDmlExecutionProvider, kAclExecutionProvider}; std::unique_ptr selector = std::make_unique(is_int8_allowed, false, false, diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index d2240b5d50194..81305f7effa16 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -291,7 +291,8 @@ SelectorManager::SelectorManager() { InitializeSelectorsMap(); } -std::vector SelectorManager::GetQDQSelections(const GraphViewer& graph_viewer) const { +std::vector SelectorManager::GetQDQSelections(const GraphViewer& graph_viewer, + const logging::Logger& logger) const { std::vector qdq_selections; for (auto index : graph_viewer.GetNodesInTopologicalOrder()) { const auto* node = graph_viewer.GetNode(index); @@ -313,7 +314,7 @@ std::vector SelectorManager::GetQDQSelections(const GraphViewer& grap const auto& versions = op_versions_and_selector.op_versions_map.find(node->OpType())->second; if (!versions.empty()) { if (std::find(versions.cbegin(), versions.cend(), node->SinceVersion()) == versions.cend()) { - LOGS_DEFAULT(VERBOSE) << "Op version is not supported for" << node->OpType(); + LOGS(logger, VERBOSE) << "Op version is not supported for" << node->OpType(); continue; } } @@ -329,7 +330,7 @@ std::vector SelectorManager::GetQDQSelections(const GraphViewer& grap } std::pair>, std::unordered_map> -GetAllNodeUnits(const GraphViewer& graph_viewer) { +GetAllNodeUnits(const GraphViewer& graph_viewer, const logging::Logger& logger) { std::vector> node_unit_holder; std::unordered_map node_unit_map; @@ -342,7 +343,7 @@ GetAllNodeUnits(const GraphViewer& graph_viewer) { // Get QDQ NodeUnits first QDQ::SelectorManager selector_mgr; - const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer); + const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer, logger); for (const auto& qdq_selection : qdq_selections) { auto qdq_unit = std::make_unique(graph_viewer, qdq_selection); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h index f388206551172..ccc1844e3e985 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h @@ -15,7 +15,9 @@ #endif namespace onnxruntime { - +namespace logging { +class Logger; +} class GraphViewer; class Node; @@ -65,7 +67,7 @@ class SelectorManager { // Methods that finds and returns a vector of QDQ::NodeGroup in a given graph // Can be used in QDQ support in different EPs - std::vector GetQDQSelections(const GraphViewer& graph_viewer) const; + std::vector GetQDQSelections(const GraphViewer& graph_viewer, const logging::Logger& logger) const; private: Selectors qdq_selectors_; @@ -88,7 +90,7 @@ class SelectorManager { // We currently have a bit of a mess with generic things like this to get all the node units being in the optimizer // library whereas it should be able to be used by an EP with no dependency on optimizers. std::pair>, std::unordered_map> -GetAllNodeUnits(const GraphViewer& graph_viewer); +GetAllNodeUnits(const GraphViewer& graph_viewer, const logging::Logger& logger); } // namespace QDQ } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc index 6a5a85ce0ff31..8c0136c495403 100644 --- a/onnxruntime/core/optimizer/transformer_memcpy.cc +++ b/onnxruntime/core/optimizer/transformer_memcpy.cc @@ -17,13 +17,22 @@ class TransformerMemcpyImpl { TransformerMemcpyImpl(onnxruntime::Graph& graph, const std::string& provider) : graph_(graph), provider_(provider) {} - bool ModifyGraph(const KernelRegistryManager& schema_registries, const logging::Logger& logger, int& copy_node_counter); + bool ModifyGraph(const KernelRegistryManager& schema_registries, + const logging::Logger& logger, + int& copy_node_counter); private: - void ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, InitializedTensorSet& initializers_consumed); - void BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries); + void ProcessDefs(onnxruntime::Node& node, + const KernelRegistryManager& kernel_registries, + InitializedTensorSet& initializers_consumed, + const logging::Logger& logger); + void BuildDefsMapping(const onnxruntime::NodeArg* arg, + const KernelRegistryManager& kernel_registries, + const logging::Logger& logger); void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger); - bool ProcessInitializers(const KernelRegistryManager& kernel_registries, const InitializedTensorSet& initializers_consumed); + bool ProcessInitializers(const KernelRegistryManager& kernel_registries, + const InitializedTensorSet& initializers_consumed, + const logging::Logger& logger); private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TransformerMemcpyImpl); @@ -130,21 +139,21 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi // find defs that require copy for (auto& node : graph_.Nodes()) { // as we process the defs, collect all the initializers consumed at the current graph level - ProcessDefs(node, kernel_registries, initializers_consumed); + ProcessDefs(node, kernel_registries, initializers_consumed, logger); } // for initializers shared by different providers, create dups - if (ProcessInitializers(kernel_registries, initializers_consumed)) + if (ProcessInitializers(kernel_registries, initializers_consumed, logger)) modified = true; for (auto arg : graph_.GetInputs()) - BuildDefsMapping(arg, kernel_registries); + BuildDefsMapping(arg, kernel_registries, logger); for (auto arg : non_provider_input_defs_) - BuildDefsMapping(arg, kernel_registries); + BuildDefsMapping(arg, kernel_registries, logger); for (auto arg : non_provider_output_defs_) - BuildDefsMapping(arg, kernel_registries); + BuildDefsMapping(arg, kernel_registries, logger); for (auto arg : graph_.GetInputs()) // For inputs we need to create a copy node only when the input is connected to both provider @@ -202,8 +211,10 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi return modified; } -void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, - InitializedTensorSet& initializers_consumed) { +void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, + const KernelRegistryManager& kernel_registries, + InitializedTensorSet& initializers_consumed, + const logging::Logger& logger) { auto node_provider_type = node.GetExecutionProviderType(); if ((node_provider_type == provider_) || (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) || @@ -211,7 +222,7 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg provider_nodes_.insert(&node); // note KernelCreateInfo might be nullptr for custom kernel const KernelCreateInfo* kci = nullptr; - ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(node, &kci)); + ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(node, logger, &kci)); bool is_implicit_input = false; auto process_inputs = @@ -278,7 +289,9 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg } // for non_provider defs, collect the nodes that expect it is provider tensor as input/output. -void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries) { +void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, + const KernelRegistryManager& kernel_registries, + const logging::Logger& logger) { for (auto& it : graph_.Nodes()) { if (it.OpType() == "MemcpyFromHost" || it.OpType() == "MemcpyToHost") continue; auto input_it = @@ -296,7 +309,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, co (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) || (node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_)) { const KernelCreateInfo* kci = nullptr; - ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(it, &kci)); + ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(it, logger, &kci)); if (arg_input_index != -1) { if (!kci || !utils::IsInputOnCpu(it, kci, arg_input_index)) provider_input_nodes_[arg].insert(&it); } @@ -351,7 +364,9 @@ static const onnxruntime::NodeArg* FindNodeArg(const NodeArgSetType& def_set, co // We duplicate any initializer that is used by both provider nodes and non-provider nodes // to ensure that provider nodes and non-provider nodes don't share initializers, as they // need to stay in different memory locations. -bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& kernel_registries, const InitializedTensorSet& initializers_consumed) { +bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& kernel_registries, + const InitializedTensorSet& initializers_consumed, + const logging::Logger& logger) { std::map replacements; for (const auto& pair : initializers_consumed) { const auto& name = pair.first; @@ -383,7 +398,7 @@ bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& ker auto dup_replacements = replacements; const KernelCreateInfo* kci = nullptr; - auto status = kernel_registries.SearchKernelRegistry(*p_node, &kci); + auto status = kernel_registries.SearchKernelRegistry(*p_node, logger, &kci); ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); if (kci == nullptr) continue; if (kci->kernel_def == nullptr) continue; diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index 470838d36ec1c..10cb6eb97bdd6 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -1653,14 +1654,14 @@ static bool HandleSplit(HandlerArgs& args) { constexpr HandlerInfo split_handler = {&FirstInput, &HandleSplit}; -static bool HandleConcat(HandlerArgs& args) { +bool HandleConcat(HandlerArgs& args) { return HandleSimpleNodeWithAxis(args); } constexpr HandlerInfo concat_handler = {&AllInputs, &HandleConcat}; // Handles Softmax, Hardmax, and LogSoftmax -static bool HandleSoftHardMax(HandlerArgs& args) { +bool HandleSoftHardMax(HandlerArgs& args) { if (args.ctx.opset >= 13) { return HandleSimpleNodeWithAxis(args, /*default_axis*/ -1); } diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h index 0095ead75f0c8..f65bd6aa82fbb 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h @@ -71,6 +71,9 @@ bool HandleSimpleNodeBroadcast(HandlerArgs& args); // Transposes all inputs and all outputs. Updates axis attribute. bool HandleSimpleNodeWithAxis(HandlerArgs& args, std::optional default_axis = std::nullopt); +bool HandleConcat(HandlerArgs& args); +bool HandleSoftHardMax(HandlerArgs& args); + // base handlers that are used by extended handlers. add from transpose_optimizer.cc as needed. bool HandleReduceOps(HandlerArgs& args); bool HandleResize([[maybe_unused]] HandlerArgs& args); diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc index 8eaac3d34c3af..824ab20a84668 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc @@ -34,10 +34,6 @@ static bool EPAwareHandleResize(HandlerArgs& args) { constexpr HandlerInfo ep_aware_resize_handler = {&FirstInput, &EPAwareHandleResize}; -static bool HandleQLinearConcat(HandlerArgs& args) { - return HandleSimpleNodeWithAxis(args); -} - std::vector QLinearConcatInputs(OptimizerCtx& ctx, api::NodeRef& node) { (void)ctx; std::vector indices; @@ -48,11 +44,7 @@ std::vector QLinearConcatInputs(OptimizerCtx& ctx, api::NodeRef& node) { return indices; } -constexpr HandlerInfo q_linear_concat_handler = {&QLinearConcatInputs, &HandleQLinearConcat}; - -static bool HandleQLinearBinaryOp(HandlerArgs& args) { - return HandleSimpleNodeBroadcast(args); -} +constexpr HandlerInfo q_linear_concat_handler = {&QLinearConcatInputs, &HandleConcat}; std::vector QLinearBinaryOpInputs(OptimizerCtx&, api::NodeRef&) { // Inputs are: [A, A_scale, A_zero_point, B, B_scale, B_zero_point, C_scale, C_zero_point], @@ -60,7 +52,7 @@ std::vector QLinearBinaryOpInputs(OptimizerCtx&, api::NodeRef&) { return {0, 3}; } -constexpr HandlerInfo q_linear_binary_op_handler = {&QLinearBinaryOpInputs, &HandleQLinearBinaryOp}; +constexpr HandlerInfo q_linear_binary_op_handler = {&QLinearBinaryOpInputs, &HandleSimpleNodeBroadcast}; static bool HandleQLinearPoolOp(HandlerArgs& args) { // Swap between channel first/last variants. Only works for applicable values of perm. @@ -129,6 +121,7 @@ constexpr HandlerInfo max_pool_op_handler = {&FirstInput, &HandleMaxPool}; constexpr HandlerInfo node_1_inp_handler = {&FirstInput, &HandleSimpleNode}; constexpr HandlerInfo reduce_op_handler = {&FirstInput, &HandleReduceOps}; +constexpr HandlerInfo soft_hard_max_handler = {&FirstInput, &HandleSoftHardMax}; constexpr HandlerInfo contrib_quantize_dequantize_linear_handler = {&FirstInput, &HandleContribQuantizeDequantizeLinear}; @@ -148,6 +141,7 @@ const HandlerMap& OrtExtendedHandlers() { {"com.microsoft.QLinearMul", q_linear_binary_op_handler}, {"com.microsoft.QLinearReduceMean", reduce_op_handler}, {"com.microsoft.QLinearSigmoid", node_1_inp_handler}, + {"com.microsoft.QLinearSoftmax", soft_hard_max_handler}, }; return map; diff --git a/onnxruntime/core/platform/windows/hardware_core_enumerator.cc b/onnxruntime/core/platform/windows/hardware_core_enumerator.cc index bf3b53afbd7d3..7464ab4c57d01 100644 --- a/onnxruntime/core/platform/windows/hardware_core_enumerator.cc +++ b/onnxruntime/core/platform/windows/hardware_core_enumerator.cc @@ -1,7 +1,8 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #include "hardware_core_enumerator.h" +#include "core/platform/windows/env.h" #include #include #include @@ -83,6 +84,38 @@ uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() { // # of physical cores = # of P cores + # of E Cores + # of Soc Cores. // # of logical cores = # of P cores x 2 (if hyper threading is enabled) + # of E cores + # of Soc Cores. auto cores = GetCoreInfo(); +#if !defined(_M_ARM64EC) && !defined(_M_ARM64) && !defined(__aarch64__) + const int kVendorID_Intel[3] = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI" + bool isIntelSpecifiedPlatform = false; + const int kVendorID_IntelSpecifiedPlatformIDs[3] = { + // ExtendedModel, ExtendedFamily, Family Code, and Model Number + 0xa06a, // MTL + 0xc065, // ARL-H + 0xb065 // ARL-U + }; + + int regs_leaf0[4]; + int regs_leaf1[4]; + __cpuid(regs_leaf0, 0); + __cpuid(regs_leaf1, 0x1); + + auto isIntel = (kVendorID_Intel[0] == regs_leaf0[1]) && (kVendorID_Intel[1] == regs_leaf0[2]) && (kVendorID_Intel[2] == regs_leaf0[3]); + + for (int intelSpecifiedPlatform : kVendorID_IntelSpecifiedPlatformIDs) { + if ((regs_leaf1[0] >> 4) == intelSpecifiedPlatform) { + isIntelSpecifiedPlatform = true; + } + } + + if (isIntel) { + if (isIntelSpecifiedPlatform) { + // We want to exclude cores without an LLC + return cores.LLCCores; + } else { + return cores.PhysicalCores; + } + } +#endif return cores.LLCCores; } diff --git a/onnxruntime/core/platform/windows/logging/etw_sink.cc b/onnxruntime/core/platform/windows/logging/etw_sink.cc index bf73a538ea42f..950ac247a2046 100644 --- a/onnxruntime/core/platform/windows/logging/etw_sink.cc +++ b/onnxruntime/core/platform/windows/logging/etw_sink.cc @@ -158,8 +158,11 @@ void EtwRegistrationManager::LazyInitialize() { initialization_status_ = InitializationStatus::Initializing; etw_status_ = ::TraceLoggingRegisterEx(etw_provider_handle, ORT_TL_EtwEnableCallback, nullptr); if (FAILED(etw_status_)) { + // Registration can fail when running under Low Integrity process, and should be non-fatal initialization_status_ = InitializationStatus::Failed; - ORT_THROW("ETW registration failed. Logging will be broken: " + std::to_string(etw_status_)); + // Injection of ETW logger can happen very early if ETW provider was already listening. + // Don't use LOGS_DEFAULT here or can get "Attempt to use DefaultLogger but none has been registered" + std::cerr << "Error in ETW registration: " << std::to_string(etw_status_) << std::endl; } initialization_status_ = InitializationStatus::Initialized; } @@ -176,7 +179,9 @@ void EtwRegistrationManager::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, std::lock_guard lock(callbacks_mutex_); for (const auto& callback : callbacks_) { - (*callback)(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); + if (callback != nullptr) { + (*callback)(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); + } } } diff --git a/onnxruntime/core/platform/windows/stacktrace.cc b/onnxruntime/core/platform/windows/stacktrace.cc index 3401507ae911f..cc23d70c0f11f 100644 --- a/onnxruntime/core/platform/windows/stacktrace.cc +++ b/onnxruntime/core/platform/windows/stacktrace.cc @@ -30,7 +30,6 @@ class CaptureStackTrace { // Get the stack trace. Currently only enabled for a DEBUG build as we require the DbgHelp library. std::vector GetStackTrace() { #ifndef NDEBUG -// TVM need to run with shared CRT, so won't work with debug helper now #if (defined __cpp_lib_stacktrace) && !(defined _OPSCHEMA_LIB_) && !(defined _GAMING_XBOX) && !(defined ONNXRUNTIME_ENABLE_MEMLEAK_CHECK) return detail::CaptureStackTrace().Trace(); #else diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index a799ed743ef52..f954baf3eabae 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -1288,15 +1288,15 @@ CANNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewe const KernelCreateInfo* cann_kernel_def = kernel_lookup.LookUpKernel(node); if (cann_kernel_def == nullptr) { - LOGS_DEFAULT(INFO) << "CANN kernel not found in registries for Op type: " << node.OpType() - << " node name: " << node.Name(); + LOGS(*GetLogger(), INFO) << "CANN kernel not found in registries for Op type: " << node.OpType() + << " node name: " << node.Name(); continue; } candidates.push_back(node.Index()); } - auto cpu_nodes = GetCpuPreferredNodes(graph_viewer, kernel_lookup, candidates); + auto cpu_nodes = GetCpuPreferredNodes(graph_viewer, kernel_lookup, candidates, *GetLogger()); for (auto& node_index : candidates) { if (cpu_nodes.count(node_index) > 0) continue; diff --git a/onnxruntime/core/providers/coreml/builders/helper.cc b/onnxruntime/core/providers/coreml/builders/helper.cc index e1f148fa93e23..38ac629331749 100644 --- a/onnxruntime/core/providers/coreml/builders/helper.cc +++ b/onnxruntime/core/providers/coreml/builders/helper.cc @@ -24,11 +24,12 @@ namespace coreml { OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer, int32_t coreml_version, - uint32_t coreml_flags) { + bool only_allow_static_input_shapes, + bool create_mlprogram) { return OpBuilderInputParams{graph_viewer, coreml_version, - (coreml_flags & COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES) != 0, - (coreml_flags & COREML_FLAG_CREATE_MLPROGRAM) != 0}; + only_allow_static_input_shapes, + create_mlprogram}; } const IOpBuilder* GetOpBuilder(const Node& node) { @@ -133,13 +134,13 @@ bool CheckIsConstantInitializer(const NodeArg& node_arg, const GraphViewer& grap return true; } -bool HasNeuralEngine(const logging::Logger& logger) { +bool HasNeuralEngine() { bool has_neural_engine = false; #ifdef __APPLE__ struct utsname system_info; uname(&system_info); - LOGS(logger, VERBOSE) << "Current Apple hardware info: " << system_info.machine; + LOGS_DEFAULT(VERBOSE) << "Current Apple hardware info: " << system_info.machine; #if TARGET_OS_IPHONE // utsname.machine has device identifier. For example, identifier for iPhone Xs is "iPhone11,2". @@ -163,7 +164,7 @@ bool HasNeuralEngine(const logging::Logger& logger) { #else // In this case, we are running the EP on non-apple platform, which means we are running the model // conversion with CoreML EP enabled, for this we always assume the target system has Neural Engine - LOGS(logger, INFO) << "HasNeuralEngine running on non-Apple hardware. " + LOGS_DEFAULT(INFO) << "HasNeuralEngine running on non-Apple hardware. " "Returning true to enable model conversion and local testing of CoreML EP implementation. " "No CoreML model will be compiled or run."; has_neural_engine = true; diff --git a/onnxruntime/core/providers/coreml/builders/helper.h b/onnxruntime/core/providers/coreml/builders/helper.h index 0acaa0dd8a4a3..ae7f3bdbc31a9 100644 --- a/onnxruntime/core/providers/coreml/builders/helper.h +++ b/onnxruntime/core/providers/coreml/builders/helper.h @@ -25,7 +25,8 @@ namespace coreml { OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer, int32_t coreml_version, - uint32_t coreml_flags); + bool only_allow_static_input_shapes, + bool create_mlprogram); const IOpBuilder* GetOpBuilder(const Node& node); @@ -45,7 +46,7 @@ bool CheckIsConstantInitializer(const NodeArg& node_arg, const GraphViewer& grap // CoreML is more efficient running using Apple Neural Engine // This is to detect if the current system has Apple Neural Engine -bool HasNeuralEngine(const logging::Logger& logger); +bool HasNeuralEngine(); } // namespace coreml } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc index 5389eb5ab7e95..4481a5172966b 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc @@ -40,6 +40,25 @@ void ActivationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, con } namespace { + +template +void HandlePReluWeight(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger, + std::vector& alpha_values) { + // add slope initializer as alpha weight + const auto& slope_tensor = *model_builder.GetConstantInitializer(node.InputDefs()[1]->Name()); + Initializer unpacked_tensor(slope_tensor); + const auto alpha_v = unpacked_tensor.DataAsSpan(); + + if (alpha_v.size() == 1) { + // expand to number of channels + std::vector x_shape; + GetShape(*node.InputDefs()[0], x_shape, logger); + alpha_values.resize(x_shape[x_shape.size() - 3], alpha_v[0]); + } else { + alpha_values.assign(alpha_v.begin(), alpha_v.end()); + } +} + Status AddPReluWeight(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger, COREML_SPEC::ActivationPReLU& prelu) { @@ -84,6 +103,7 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.activation std::string_view coreml_op_type; bool add_alpha = false; + bool add_gelu_mode = false; if (op_type == "Sigmoid") { coreml_op_type = "sigmoid"; } else if (op_type == "Tanh") { @@ -93,6 +113,12 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } else if (op_type == "LeakyRelu") { coreml_op_type = "leaky_relu"; add_alpha = true; + } else if (op_type == "Gelu") { + coreml_op_type = "gelu"; + add_gelu_mode = true; + } else if (op_type == "PRelu") { + coreml_op_type = "prelu"; + add_alpha = true; } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); @@ -102,16 +128,39 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, AddOperationInput(*op, "x", node.InputDefs()[0]->Name()); if (add_alpha) { - NodeAttrHelper helper(node); - const auto alpha = helper.Get("alpha", 0.01f); - auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", alpha)); + + if ("PRelu" == op_type) { + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + std::vector alpha_values; + HandlePReluWeight(model_builder, node, logger, alpha_values); + AddOperationInput(*op, "alpha", model_builder.AddConstant(op->type(), "alpha", alpha_values)); + } else { + std::vector alpha_values; + HandlePReluWeight(model_builder, node, logger, alpha_values); + AddOperationInput(*op, "alpha", model_builder.AddConstant(op->type(), "alpha", alpha_values)); + } } else { - AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", MLFloat16(alpha))); + NodeAttrHelper helper(node); + const auto alpha = helper.Get("alpha", 0.01f); + + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", alpha)); + } else { + AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", MLFloat16(alpha))); + } } } + if (add_gelu_mode) { + NodeAttrHelper helper(node); + std::string approximate = helper.Get("approximate", std::string("none")); + if (approximate == "tanh") { + approximate = "TANH_APPROXIMATION"; + } else if (approximate == "none") { + approximate = "EXACT"; + } + AddOperationInput(*op, "mode", model_builder.AddScalarConstant(op->type(), "mode", std::string(approximate))); + } AddOperationOutput(*op, *node.OutputDefs()[0]); @@ -213,17 +262,11 @@ bool ActivationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInp const logging::Logger& logger) const { const auto& op_type = node.OpType(); -#if defined(COREML_ENABLE_MLPROGRAM) - if (input_params.create_mlprogram) { - if (op_type == "PRelu") { // TODO: ML Program supports this so should be easy to enable - return false; - } - } else -#endif // (COREML_ENABLE_MLPROGRAM) - { - if (op_type == "PRelu") { - return IsPReluOpSupported(node, input_params, logger); - } + if (op_type == "Gelu" && !input_params.create_mlprogram) { + return false; + } + if (op_type == "PRelu") { + return IsPReluOpSupported(node, input_params, logger); } return true; @@ -245,6 +288,7 @@ void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistration "Relu", "PRelu", "LeakyRelu", + "Gelu", }; op_registrations.builders.push_back(std::make_unique()); diff --git a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc index bc8b2d1a3505d..6169090a36014 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc @@ -3,6 +3,7 @@ #include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" @@ -15,6 +16,9 @@ class ArgMaxOpBuilder : public BaseOpBuilder { bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + + public: + bool SupportsMLProgram() const override { return true; } }; Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -24,41 +28,60 @@ Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& graph_viewer = model_builder.GetGraphViewer(); NodeAttrHelper helper(node); - const auto axis = helper.Get("axis", 0); - const auto keepdims = helper.Get("keepdims", 1); + const int64_t axis = helper.Get("axis", 0); + const int64_t keepdims = helper.Get("keepdims", 1); const bool removedim = keepdims != 1; - auto* coreml_argmax = layer->mutable_argmax(); - coreml_argmax->set_axis(axis); - coreml_argmax->set_removedim(removedim); - - // There are two cases here: - // 1. Special Case (ArgMax-Cast(from int64 to int32)), we fuse the Argmax's output/Cast's input - // (We still have this special case here because CoreML model does not have Cast) - // 2. Otherwise, we add Argmax layer normally - if (node.GetOutputEdgesCount() == 1) { - auto it = node.OutputEdgesBegin(); - const auto* next_node_in_partition = graph_viewer.GetNode(it->GetNode().Index()); - // If Argmax's successive node is a Cast from int64 to int32 output - // The 'cast to' type is checked when determining operator support (see CastOpBuilder::IsOpSupportedImpl()) - // so we omit the check here - if (next_node_in_partition != nullptr && next_node_in_partition->OpType() == "Cast") { - // Skip the cast's input/argmax's output - *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); - *layer->mutable_output()->Add() = next_node_in_partition->OutputDefs()[0]->Name(); - model_builder.AddLayer(std::move(layer)); - return Status::OK(); +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.reduction + + std::unique_ptr op = model_builder.CreateOperation(node, "reduce_argmax"); + AddOperationInput(*op, "x", node.InputDefs()[0]->Name()); + AddOperationInput(*op, "axis", model_builder.AddScalarConstant(op->type(), "axis", axis)); + AddOperationInput(*op, "keep_dims", model_builder.AddScalarConstant(op->type(), "keep_dims", bool(keepdims))); + + int32_t output_datatype = ONNX_NAMESPACE::TensorProto_DataType_INT32; + // the output of ArgMax must be int32 + AddOperationOutput(*op, *node.OutputDefs()[0], output_datatype); + model_builder.AddOperation(std::move(op)); + } else +#endif // (COREML_ENABLE_MLPROGRAM) + { + auto* coreml_argmax = layer->mutable_argmax(); + coreml_argmax->set_axis(axis); + coreml_argmax->set_removedim(removedim); + + // There are two cases here: + // 1. Special Case (ArgMax-Cast(from int64 to int32)), we fuse the Argmax's output/Cast's input + // (We still have this special case here because CoreML model does not have Cast) + // 2. Otherwise, we add Argmax layer normally + if (node.GetOutputEdgesCount() == 1) { + auto it = node.OutputEdgesBegin(); + const auto* next_node_in_partition = graph_viewer.GetNode(it->GetNode().Index()); + // If Argmax's successive node is a Cast from int64 to int32 output + // The 'cast to' type is checked when determining operator support (see CastOpBuilder::IsOpSupportedImpl()) + // so we omit the check here + if (next_node_in_partition != nullptr && next_node_in_partition->OpType() == "Cast") { + // Skip the cast's input/argmax's output + *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); + *layer->mutable_output()->Add() = next_node_in_partition->OutputDefs()[0]->Name(); + model_builder.AddLayer(std::move(layer)); + return Status::OK(); + } } - } - *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); - model_builder.AddLayer(std::move(layer)); + model_builder.AddLayer(std::move(layer)); + } return Status::OK(); } -bool ArgMaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, +bool ArgMaxOpBuilder::IsOpSupportedImpl(const Node& node, + [[maybe_unused]] const OpBuilderInputParams& input_params, const logging::Logger& logger) const { // Attribute `select_last_index` of ArgMax op is not supported NodeAttrHelper helper(node); @@ -68,6 +91,12 @@ bool ArgMaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPa return false; } +#if defined(COREML_ENABLE_MLPROGRAM) + if (input_params.create_mlprogram) { + return true; + } +#endif + // If there are multiple downstream nodes and cast (toint32) is one of them // not supported, exit here // Otherwise, for general multiple downstream nodes, supported diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc index f185a80de3cbf..2817f34bc64f2 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc @@ -13,15 +13,6 @@ using namespace CoreML::Specification; namespace onnxruntime { namespace coreml { -// Once all ops are supportted FP16, we can remove it. Before that, we keep a set of ops to -// filter suppported ones. -static std::set Float16Ops = { - "Add", "Mul", "Sub", "Div", "Pow", "Sqrt", "Reciprocal", - "Sigmoid", "Tanh", "Relu", "LeakyRelu", "Concat", "GridSample", "GlobalAveragePool", - "Clip", "DepthToSpace", "Resize", "Slice", "Conv", - "ConvTranspose", "GlobalMaxPool", "Gemm", "MatMul", - "AveragePool", "MaxPool", "Reshape", "Split", "Transpose"}; - namespace { // TODO, move this to shared_library bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node& node, @@ -65,20 +56,27 @@ bool BaseOpBuilder::IsOpSupported(const Node& node, const OpBuilderInputParams& } if (!HasSupportedOpSet(node, logger)) { + LOGS(logger, VERBOSE) << "Operator [" << node.OpType() << "] does not support this opset"; return false; } if (!HasSupportedInputs(node, input_params, logger)) { + LOGS(logger, VERBOSE) << "Operator [" << node.OpType() << "] has unsupported inputs"; return false; } // We do not support external initializers for now const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); if (HasExternalInitializer(initializers, node, logger)) { + LOGS(logger, VERBOSE) << "Operator [" << node.OpType() << "] has external initializers"; return false; } - return IsOpSupportedImpl(node, input_params, logger); + if (!IsOpSupportedImpl(node, input_params, logger)) { + LOGS(logger, VERBOSE) << "Operator [" << node.OpType() << "] is not supported by the impl"; + return false; + } + return true; } bool BaseOpBuilder::HasSupportedInputs(const Node& node, const OpBuilderInputParams& input_params, @@ -115,13 +113,10 @@ bool BaseOpBuilder::IsInputDtypeSupport(const Node& node, size_t idx, return true; } -// only support MLProgram for FP16 -#if defined(COREML_ENABLE_MLPROGRAM) - if (input_params.create_mlprogram && input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 && - Float16Ops.count(node.OpType())) { + // only MLProgram support FP16 + if (input_params.create_mlprogram && input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { return true; } -#endif LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not currently supported"; return false; diff --git a/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc index 8da58f659acf1..442194cb31cbc 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc @@ -10,6 +10,10 @@ #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" +#ifdef __APPLE__ +#include +#endif + namespace onnxruntime { namespace coreml { @@ -24,6 +28,9 @@ class BatchNormalizationOpBuilder : public BaseOpBuilder { // BatchNormalization opset 6- has unsupported attributes int GetMinSupportedOpSet(const Node& /* node */) const override { return 7; } + + public: + bool SupportsMLProgram() const override { return true; } }; void BatchNormalizationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { @@ -50,21 +57,46 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu const auto eps = helper.Get("epsilon", 1e-5f); const auto channels = scale_tensor.dims()[0]; - auto* coreml_batch_norm = layer->mutable_batchnorm(); - coreml_batch_norm->set_channels(channels); - coreml_batch_norm->set_epsilon(eps); - coreml_batch_norm->set_computemeanvar(false); - coreml_batch_norm->set_instancenormalization(false); - - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_gamma(), scale_tensor)); // scale - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_beta(), bias_tensor)); // B - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_mean(), mean_tensor)); // mean - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_variance(), var_tensor)); // var - - *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); - - model_builder.AddLayer(std::move(layer)); +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.normalization.batch_norm + + std::unique_ptr op = model_builder.CreateOperation(node, "batch_norm"); + AddOperationInput(*op, "x", input_defs[0]->Name()); + AddOperationInput(*op, "mean", model_builder.AddConstant(op->type(), input_defs[3]->Name() + "mean", mean_tensor)); + AddOperationInput(*op, "variance", model_builder.AddConstant(op->type(), input_defs[4]->Name() + "variance", var_tensor)); + AddOperationInput(*op, "gamma", model_builder.AddConstant(op->type(), input_defs[1]->Name(), scale_tensor)); + AddOperationInput(*op, "beta", model_builder.AddConstant(op->type(), input_defs[2]->Name(), bias_tensor)); + auto input_dtype = input_defs[0]->TypeAsProto()->tensor_type().elem_type(); + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + MLFloat16 epsilon_fp16(eps); + AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", epsilon_fp16)); + } else { + AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", eps)); + } + + AddOperationOutput(*op, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(op)); + } else +#endif // (COREML_ENABLE_MLPROGRAM) + { + auto* coreml_batch_norm = layer->mutable_batchnorm(); + coreml_batch_norm->set_channels(channels); + coreml_batch_norm->set_epsilon(eps); + coreml_batch_norm->set_computemeanvar(false); + coreml_batch_norm->set_instancenormalization(false); + + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_gamma(), scale_tensor)); // scale + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_beta(), bias_tensor)); // B + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_mean(), mean_tensor)); // mean + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_variance(), var_tensor)); // var + + *layer->mutable_input()->Add() = input_defs[0]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + + model_builder.AddLayer(std::move(layer)); + } return Status::OK(); } @@ -119,6 +151,15 @@ bool BatchNormalizationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBu return false; } +#if defined(TARGET_OS_IOS) && defined(TARGET_CPU_X86_64) && TARGET_OS_IOS && TARGET_CPU_X86_64 + // To Pass IOS pipeline https://dev.azure.com/onnxruntime/onnxruntime/_build?definitionId=134&_a=summary + auto input_dtype = input_defs[0]->TypeAsProto()->tensor_type().elem_type(); + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 && input_params.coreml_version < 7) { + LOGS(logger, VERBOSE) << "float16 input is not supported on the iOS x86_64 simulator" + << " due to CoreML producing invalid output."; + return false; + } +#endif return true; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc index 8aa2dbae2531c..0482620b269a4 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc @@ -6,6 +6,7 @@ #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/shape_utils.h" #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" @@ -55,6 +56,64 @@ bool CheckIfBothInputShapesMatch(const Node& node, const logging::Logger& logger } } // namespace +#if defined(COREML_ENABLE_MLPROGRAM) +static std::vector InferOutputShape(const std::vector& a, const std::vector& b) { + std::vector output_shape; + int64_t i_a = 0, j_b = 0; + if (a.size() >= b.size()) { + output_shape = a; + j_b -= a.size() - b.size(); + } else { + output_shape = b; + i_a -= b.size() - a.size(); + } + + for (size_t i = 0; i < output_shape.size(); i++, i_a++, j_b++) { + const int64_t a_dim = (i_a >= 0) ? a[i_a] : 1; + const int64_t b_dim = (j_b >= 0) ? b[j_b] : 1; + if (a_dim == -1 || b_dim == -1) { + output_shape[i] = -1; + } else { + output_shape[i] = std::max(a_dim, b_dim); + } + } + return output_shape; +} + +// Add variadic inputs to the model builder +// in onnx spec, some node allows variadic inputs, such as max(x, y, z, ...) +// while in coreml, maximum op only allows two inputs maximum(x, y) +// the conversion is doing the following: +// max(x, y, z, ...) -> max(max(x, y), z, ...) +static void AddVariadicInputs(std::unique_ptr* op, + ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) { + using namespace CoreML::Specification::MILSpec; + const auto& input_defs(node.InputDefs()); + std::string_view layer_input_name_x = model_builder.GetUniqueName(node, "variadic"); + auto input_dtype = input_defs[0]->TypeAsProto()->tensor_type().elem_type(); + const int32_t elem_type = static_cast(input_dtype); + std::vector x0_shape, x1_shape; + GetShape(*input_defs[0], x0_shape, logger); + GetShape(*input_defs[1], x1_shape, logger); + x0_shape = InferOutputShape(x0_shape, x1_shape); + std::unique_ptr op_prev = std::move(*op); + for (size_t i = 2; i < input_defs.size(); i++) { + AddIntermediateOperationOutput(*op_prev, layer_input_name_x, elem_type, x0_shape); + std::unique_ptr op_cur = model_builder.CreateOperation(node, op_prev->type()); + AddOperationInput(*op_cur, "x", layer_input_name_x); + AddOperationInput(*op_cur, "y", input_defs[i]->Name()); + model_builder.AddOperation(std::move(op_prev)); + op_prev = std::move(op_cur); + layer_input_name_x = model_builder.GetUniqueName(node, "variadic"); + GetShape(*input_defs[i], x1_shape, logger); + x0_shape = InferOutputShape(x0_shape, x1_shape); + } + *op = std::move(op_prev); +} +#endif + Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { const auto& op_type(node.OpType()); @@ -70,6 +129,8 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const coreml_op_type = "add"; } else if (op_type == "Mul") { coreml_op_type = "mul"; + } else if (op_type == "Max") { + coreml_op_type = "maximum"; } else if (op_type == "Sub") { coreml_op_type = "sub"; } else if (op_type == "Div") { @@ -86,8 +147,11 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const std::unique_ptr op = model_builder.CreateOperation(node, coreml_op_type); AddOperationInput(*op, "x", input_defs[0]->Name()); AddOperationInput(*op, "y", input_defs[1]->Name()); + if (input_defs.size() > 2) { + // "max" node may have variadic inputs + AddVariadicInputs(&op, model_builder, node, logger); + } AddOperationOutput(*op, *node.OutputDefs()[0]); - model_builder.AddOperation(std::move(op)); } else #endif // defined (COREML_ENABLE_MLPROGRAM) @@ -157,6 +221,10 @@ bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderIn return false; } + if (node.OpType() == "Max" && !input_params.create_mlprogram) { + return false; + } + return true; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc index fc8879abbefb0..7c7363d4c81ad 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc @@ -4,6 +4,7 @@ #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" @@ -18,14 +19,62 @@ class CastOpBuilder : public BaseOpBuilder { bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + + public: + bool SupportsMLProgram() const override { return true; } }; -Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& /* model_builder */, - const Node& /* node */, - const logging::Logger& /* logger */) const { - // This is a special handling case for ArgMax Op, where argmax is followed by a cast to int32 type. - // The ArgMax is fused with the Cast node and produces an int32 output. - // Cast node is not provided in CoreML model, so we're skipping adding the Cast node here. +Status CastOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder& model_builder, + [[maybe_unused]] const Node& node, + [[maybe_unused]] const logging::Logger& logger) const { +// This is a special handling case for ArgMax Op, where argmax is followed by a cast to int32 type. +// The ArgMax is fused with the Cast node and produces an int32 output. +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.elementwise_unary.cast + + NodeAttrHelper helper(node); + auto cast_to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto::UNDEFINED); + std::string to_dtype = ""; + if (cast_to_type == ONNX_NAMESPACE::TensorProto::INT32 || cast_to_type == ONNX_NAMESPACE::TensorProto::INT64) { + to_dtype = "int32"; + // CoreML doesn't support int64, while ONNX uses int64 for indices and as well as data values. + // We convert the data inputs/outputs between int64 and int32 when calling onnxruntime::coreml::Model::Predict, + // and when adding int64 initializers to the CoreML model. + // CoreML operators can only produce int32 and not int64 values. + // Due to that there should be no actual int64 values inside the CoreML model and we can infer any + // ONNX_NAMESPACE::TensorProto::INT64 values to be int32. + cast_to_type = ONNX_NAMESPACE::TensorProto::INT32; + } else if (cast_to_type == ONNX_NAMESPACE::TensorProto::FLOAT) { + to_dtype = "fp32"; + } else if (cast_to_type == ONNX_NAMESPACE::TensorProto::FLOAT16) { + to_dtype = "fp16"; + } else if (cast_to_type == ONNX_NAMESPACE::TensorProto::BOOL) { + to_dtype = "bool"; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported cast type: ", cast_to_type); + } + + std::string_view op_type = "cast"; + auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (((input_dtype == ONNX_NAMESPACE::TensorProto_DataType_INT64 || + input_dtype == ONNX_NAMESPACE::TensorProto_DataType_INT32) && + to_dtype == "int32") || + cast_to_type == input_dtype) { + op_type = "identity"; + } + + std::unique_ptr op = model_builder.CreateOperation(node, op_type); + AddOperationInput(*op, "x", node.InputDefs()[0]->Name()); + if (op_type == "cast") { + AddOperationInput(*op, "dtype", model_builder.AddScalarConstant(op->type(), "dtype", std::string(to_dtype))); + } + AddOperationOutput(*op, *node.OutputDefs()[0], cast_to_type); + model_builder.AddOperation(std::move(op)); + } +#endif + return Status::OK(); } @@ -36,6 +85,10 @@ bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara return false; } + if (input_params.create_mlprogram) { + return true; + } + const auto& prec_node = node.InputEdgesBegin()->GetNode(); /*Cast node is only aimed for supporting argmax and we are only handling the case where an argmax @@ -67,14 +120,39 @@ bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara return true; } -bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, +bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, [[maybe_unused]] const OpBuilderInputParams& input_params, const logging::Logger& logger) const { // We only check the type of input 0 const auto& input = *node.InputDefs()[0]; + const auto& output = *node.OutputDefs()[0]; - int32_t input_type; - if (!GetType(input, input_type, logger)) + int32_t input_type, output_type; + if (!GetType(input, input_type, logger)) { return false; + } + if (!GetType(output, output_type, logger)) { + return false; + } + +#if defined(COREML_ENABLE_MLPROGRAM) + if (input_params.create_mlprogram) { + if ((input_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 || + input_type == ONNX_NAMESPACE::TensorProto_DataType_INT64 || + input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT || + input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) && + (output_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 || + output_type == ONNX_NAMESPACE::TensorProto_DataType_INT64 || + output_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT || + output_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)) { + return true; + } else { + LOGS(logger, VERBOSE) << "[" << node.OpType() + << "] Input type: [" << input_type + << "] is not supported."; + return false; + } + } +#endif // only support int64 coming from ArgMax (check for ArgMax is done in IsOpSupportedImpl()) if (input_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) { diff --git a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc index bc9e2f10296ed..f7046c213a8cb 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc @@ -98,26 +98,24 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const bool min_max_attribs = node.SinceVersion() < 11; std::string_view min_name; if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - min_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "min", min) - : node.InputDefs()[1]->Name(); + min_name = (min_max_attribs || !has_min) ? model_builder.AddScalarConstant(clip_op.type(), "min", min) + : node.InputDefs()[1]->Name(); } else { - min_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "min", MLFloat16(min)) - : node.InputDefs()[1]->Name(); + min_name = (min_max_attribs || !has_min) ? model_builder.AddScalarConstant(clip_op.type(), "min", MLFloat16(min)) + : node.InputDefs()[1]->Name(); } AddOperationInput(clip_op, "alpha", min_name); - if (has_max) { - std::string_view max_name; - if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - max_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "max", max) - : node.InputDefs()[2]->Name(); - } else { - max_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "max", MLFloat16(max)) - : node.InputDefs()[2]->Name(); - } - AddOperationInput(clip_op, "beta", max_name); + std::string_view max_name; + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + max_name = (min_max_attribs || !has_max) ? model_builder.AddScalarConstant(clip_op.type(), "max", max) + : node.InputDefs()[2]->Name(); + } else { + max_name = (min_max_attribs || !has_max) ? model_builder.AddScalarConstant(clip_op.type(), "max", MLFloat16(max)) + : node.InputDefs()[2]->Name(); } + AddOperationInput(clip_op, "beta", max_name); } } @@ -200,7 +198,9 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, bool ClipOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { float min, max; - return GetClipMinMax(input_params.graph_viewer, node, min, max, logger); + bool ret = GetClipMinMax(input_params.graph_viewer, node, min, max, logger); + // what does it mean if min == max? + return ret && (min != max); } void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/coreml/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/normalization_op_builder.cc new file mode 100644 index 0000000000000..b4dc8d1647ad0 --- /dev/null +++ b/onnxruntime/core/providers/coreml/builders/impl/normalization_op_builder.cc @@ -0,0 +1,277 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/coreml/builders/helper.h" +#include "core/optimizer/initializer.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/builders/model_builder.h" +#include "core/providers/coreml/builders/op_builder_factory.h" +#include "core/providers/coreml/shape_utils.h" +#include "core/providers/shared/utils/utils.h" +#include + +namespace onnxruntime { +namespace coreml { + +class NormalizationOpBuilder : public BaseOpBuilder { + void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; + + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override; + Status AddGroupNormToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const; + + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; + int GetMinSupportedOpSet(const Node& /* node */) const override { return 1; } + + public: + bool SupportsMLProgram() const override { return true; } +}; + +void NormalizationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + // skip everything except input0 for Normalization + const auto& input_defs = node.InputDefs(); + model_builder.AddInitializerToSkip(input_defs[1]->Name()); // scale + if (input_defs.size() > 2) { + model_builder.AddInitializerToSkip(input_defs[2]->Name()); // B + } +} + +Status NormalizationOpBuilder::AddToModelBuilderImpl( + [[maybe_unused]] ModelBuilder& model_builder, + [[maybe_unused]] const Node& node, + [[maybe_unused]] const logging::Logger& logger) const { + if (node.OpType() == "GroupNormalization") { + return AddGroupNormToModelBuilderImpl(model_builder, node, logger); + } +#if defined(COREML_ENABLE_MLPROGRAM) + const auto& input_defs = node.InputDefs(); + NodeAttrHelper helper(node); + const auto& scale_tensor = *model_builder.GetConstantInitializer(input_defs[1]->Name()); + + const auto eps = helper.Get("epsilon", 1e-5f); + + std::vector input_shape; + // GetShape will never fail as we have already verified the input shape in IsOpSupportedImpl + GetShape(*input_defs[0], input_shape, logger); + + const auto rank = input_shape.size(); + auto axis = static_cast(HandleNegativeAxis(helper.Get("axis", 1), rank)); + + std::vector axes(rank - axis); + std::iota(axes.begin(), axes.end(), axis); + auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + std::string_view layer_input_name_x = node.InputDefs()[0]->Name(); + std::string_view op_name = (node.OpType() == "InstanceNormalization") ? "instance_norm" : "layer_norm"; + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.normalization.layer_norm + + std::unique_ptr op = model_builder.CreateOperation(node, op_name); + AddOperationInput(*op, "x", layer_input_name_x); + if (op_name == "layer_norm") { + AddOperationInput(*op, "axes", model_builder.AddConstant(op->type(), input_defs[0]->Name() + "axes", axes)); + } + AddOperationInput(*op, "gamma", model_builder.AddConstant(op->type(), input_defs[1]->Name() + "gamma", scale_tensor)); + if (input_defs.size() > 2) { + const auto& bias_tensor = *model_builder.GetConstantInitializer(input_defs[2]->Name()); + AddOperationInput(*op, "beta", model_builder.AddConstant(op->type(), input_defs[2]->Name() + "beta", bias_tensor)); + } + + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + MLFloat16 epsilon_fp16(eps); + AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", epsilon_fp16)); + } else { + AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", eps)); + } + + AddOperationOutput(*op, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(op)); + } +#endif // (COREML_ENABLE_MLPROGRAM) + + return Status::OK(); +} + +Status NormalizationOpBuilder::AddGroupNormToModelBuilderImpl( + [[maybe_unused]] ModelBuilder& model_builder, + [[maybe_unused]] const Node& node, + [[maybe_unused]] const logging::Logger& logger) const { +#if defined(COREML_ENABLE_MLPROGRAM) + const auto& input_defs = node.InputDefs(); + NodeAttrHelper helper(node); + // Coreml hasn't supported GroupNorm yet. + // we decompose GroupNorm to sub ops and levrage LayerNorm to implement GroupNorm. + // groupnorm --> reshape [b, num_groups, c // (num_groups), h, w] --> layer_norm --> reshape [b, c, h, w]->mul(scale)->add(bias) + + // scale and bias is required for group-norm by the onnx spec + const auto& scale_tensor = *model_builder.GetConstantInitializer(input_defs[1]->Name()); + const auto& bias_tensor = *model_builder.GetConstantInitializer(input_defs[2]->Name()); + + const auto eps = helper.Get("epsilon", 1e-5f); + int64_t num_groups = helper.Get("num_groups", 1); // GroupNorm + + std::vector input_shape; + GetShape(*input_defs[0], input_shape, logger); + + const auto input_size = input_shape.size(); + int64_t axis = 2; + std::vector axes(input_size + 1 - axis); // Group add one more dim + std::iota(axes.begin(), axes.end(), axis); + auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + int64_t channel_dims = input_shape[1]; + + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + std::string_view layer_input_name_x = node.InputDefs()[0]->Name(); + const int32_t elem_type = static_cast(input_dtype); + + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.normalization.layer_norm + // https://github.com/apple/coremltools/blob/9827d424b3c5b5fbb6ddc8891a000d87a188c84f/coremltools/converters/mil/frontend/torch/ops.py#L1354 + // reshape to [b, num_groups, c // (num_groups), h, w] + auto reshape1 = model_builder.CreateOperation(node, "reshape", "pre"); + std::vector shape1 = input_shape; + shape1.insert(shape1.begin() + 1, num_groups); + shape1[2] = input_shape[1] / num_groups; + std::vector shape_scale_bias(input_shape.size(), 1); + shape_scale_bias[1] = channel_dims; + AddOperationInput(*reshape1, "x", node.InputDefs()[0]->Name()); + AddOperationInput(*reshape1, "shape", model_builder.AddConstant(reshape1->type(), "shape1", shape1)); + layer_input_name_x = model_builder.GetUniqueName(node, "ln_reshape1_"); + AddIntermediateOperationOutput(*reshape1, layer_input_name_x, elem_type, shape1); + + std::unique_ptr layer_norm = model_builder.CreateOperation(node, "layer_norm"); + AddOperationInput(*layer_norm, "x", layer_input_name_x); + AddOperationInput(*layer_norm, "axes", model_builder.AddConstant(layer_norm->type(), "axes", axes)); + + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + MLFloat16 epsilon_fp16(eps); + AddOperationInput(*layer_norm, "epsilon", model_builder.AddScalarConstant(layer_norm->type(), "epsilon", epsilon_fp16)); + } else { + AddOperationInput(*layer_norm, "epsilon", model_builder.AddScalarConstant(layer_norm->type(), "epsilon", eps)); + } + + const auto& ln_output_name = model_builder.GetUniqueName(node, "ln_output_"); + AddIntermediateOperationOutput(*layer_norm, ln_output_name, elem_type, shape1); + + auto reshape2 = model_builder.CreateOperation(node, "reshape", "post"); + AddOperationInput(*reshape2, "x", ln_output_name); + AddOperationInput(*reshape2, "shape", model_builder.AddConstant(reshape2->type(), "shape2", input_shape)); + + const auto& reshape2_output_name = model_builder.GetUniqueName(node, "gn_reshape_output_"); + AddIntermediateOperationOutput(*reshape2, reshape2_output_name, elem_type, input_shape); + + auto mul = model_builder.CreateOperation(node, "mul", "post_mul"); + AddOperationInput(*mul, "x", reshape2_output_name); + AddOperationInput(*mul, "y", model_builder.AddConstant(mul->type(), "mul1", scale_tensor, shape_scale_bias)); + const auto& mul_output_name = model_builder.GetUniqueName(node, "mul_output_"); + AddIntermediateOperationOutput(*mul, mul_output_name, elem_type, input_shape); + + auto add = model_builder.CreateOperation(node, "add", "post_add"); + AddOperationInput(*add, "x", mul_output_name); + AddOperationInput(*add, "y", model_builder.AddConstant(add->type(), "add1", bias_tensor, shape_scale_bias)); + AddOperationOutput(*add, *node.OutputDefs()[0]); + + model_builder.AddOperation(std::move(reshape1)); + model_builder.AddOperation(std::move(layer_norm)); + model_builder.AddOperation(std::move(reshape2)); + model_builder.AddOperation(std::move(mul)); + model_builder.AddOperation(std::move(add)); + } +#endif // (COREML_ENABLE_MLPROGRAM) + return Status::OK(); +} + +bool NormalizationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const { + // LayerNormalization may have three output in the training mode, but we only support the inference mode + // for InstanceNormalization and GroupNormalization, they only have one output, so this check will always return true + if (node.OutputDefs().size() != 1) { + LOGS(logger, VERBOSE) << "Your onnx model (with LayerNormalization) may be in training mode," + << " please export it for inferencing."; + return false; + } + const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + return false; + } + + // groupnorm and layernorm has attribute "stash_type", while InstanceNormalization doesn't have this attribute + // Type of Mean and InvStdDev. This also specifies stage one’s computation precision. + // if stash_type is 1, this operator casts all input variables to 32-bit float, + // perform the computation, and finally cast Normalized back to the original type of X + // coreml didn't have a similiar attribute to stash_type, for now, we support the default value + if (node.OpType() != "InstanceNormalization") { + NodeAttrHelper helper(node); + const auto stash_type = helper.Get("stash_type", 1); + if (stash_type != 1) { + LOGS(logger, VERBOSE) << "stash_type != 1 is not supported"; + return false; + } + } + + const auto& scale_name = input_defs[1]->Name(); + const auto* scale_tensor = input_params.graph_viewer.GetConstantInitializer(scale_name); + if (!scale_tensor) { + LOGS(logger, VERBOSE) << "Scale must be a constant initializer"; + return false; + } + + if (input_defs.size() > 2) { + const auto& b_name = input_defs[2]->Name(); + const auto& b_tensor = input_params.graph_viewer.GetConstantInitializer(b_name); + if (!b_tensor) { + LOGS(logger, VERBOSE) << "Bias must be a constant initializer"; + return false; + } + } + + return true; +} + +bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const { + if (!input_params.create_mlprogram) { + return false; + } + // We only check the type of input 0,1,2 + const auto& input_0 = *node.InputDefs()[0]; + const auto& input_1 = *node.InputDefs()[1]; + const auto& input_2 = node.InputDefs().size() > 2 ? *node.InputDefs()[2] : input_0; + int32_t input_type_0, input_type_1, input_type_2; + if (!GetType(input_0, input_type_0, logger)) { + return false; + } + if (!GetType(input_1, input_type_1, logger)) { + return false; + } + if (!GetType(input_2, input_type_2, logger)) { + return false; + } + if (input_type_0 != input_type_1 || input_type_0 != input_type_2) { + LOGS(logger, VERBOSE) << "Input types of LayerNorm must be the same"; + return false; + } + + if (input_type_0 != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && + input_type_0 != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + LOGS(logger, VERBOSE) << "Input types of LayerNorm must be float or float16"; + return false; + } + return true; +} + +void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace coreml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc index 5651b9cc5793e..d533b867bd454 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc @@ -5,10 +5,15 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" +#ifdef __APPLE__ +#include +#endif + namespace onnxruntime { namespace coreml { @@ -20,6 +25,7 @@ class ReductionOpBuilder : public BaseOpBuilder { bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + bool SupportsMLProgram() const override { return true; } }; namespace { @@ -48,13 +54,12 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co const logging::Logger& /* logger */) const { const auto& op_type(node.OpType()); const auto& input_defs(node.InputDefs()); - const auto& initializers(model_builder.GetInitializerTensors()); std::vector axes; NodeAttrHelper helper(node); if (input_defs.size() > 1 && input_defs[1]->Exists()) { - auto& axes_tensor = *initializers.at(input_defs[1]->Name()); + auto& axes_tensor = *model_builder.GetConstantInitializer(input_defs[1]->Name()); Initializer axes_initializer(axes_tensor); int64_t* data = axes_initializer.data(); int64_t size = axes_initializer.size(); @@ -66,28 +71,76 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co const bool keepdims = helper.Get("keepdims", 1) != 0; const bool noop_with_empty_axes = helper.Get("noop_with_empty_axes", 0) != 0; +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + + std::string_view coreml_op_type; + if (noop_with_empty_axes && axes.size() == 0) { + coreml_op_type = "identity"; + } else if (op_type == "ReduceSum") { + coreml_op_type = "reduce_sum"; + } else if (op_type == "ReduceMean") { + coreml_op_type = "reduce_mean"; + } else if (op_type == "ReduceMax") { + coreml_op_type = "reduce_max"; + } else if (op_type == "ReduceMin") { + coreml_op_type = "reduce_min"; + } else if (op_type == "ReduceProd") { + coreml_op_type = "reduce_prod"; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "ReductionOpBuilder::AddToModelBuilderImpl, unexpected op: ", op_type); + } + std::unique_ptr op = model_builder.CreateOperation(node, coreml_op_type); + AddOperationInput(*op, "x", input_defs[0]->Name()); + if (coreml_op_type != "identity") { + if (axes.size() > 0) { + AddOperationInput(*op, "axes", model_builder.AddConstant(op->type(), "axes", axes)); + } + AddOperationInput(*op, "keep_dims", model_builder.AddScalarConstant(op->type(), "keep_dims", keepdims)); + } + AddOperationOutput(*op, *node.OutputDefs()[0]); + + model_builder.AddOperation(std::move(op)); + } else +#endif // (COREML_ENABLE_MLPROGRAM) + { + std::unique_ptr layer = model_builder.CreateNNLayer(node); + + if (op_type == "ReduceSum") { + AddReductionParams(layer->mutable_reducesum(), axes, keepdims, noop_with_empty_axes); + } else if (op_type == "ReduceMean") { + AddReductionParams(layer->mutable_reducemean(), axes, keepdims, noop_with_empty_axes); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "ReductionOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); + } - std::unique_ptr layer = model_builder.CreateNNLayer(node); + *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); - if (op_type == "ReduceSum") { - AddReductionParams(layer->mutable_reducesum(), axes, keepdims, noop_with_empty_axes); - } else if (op_type == "ReduceMean") { - AddReductionParams(layer->mutable_reducemean(), axes, keepdims, noop_with_empty_axes); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "ReductionOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); + model_builder.AddLayer(std::move(layer)); } - - *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); - - model_builder.AddLayer(std::move(layer)); return Status::OK(); } bool ReductionOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + if (!input_params.create_mlprogram && + (node.OpType() == "ReduceMax" || node.OpType() == "ReduceMin" || node.OpType() == "ReduceProd")) { + return false; + } + +#if defined(TARGET_OS_IOS) && defined(TARGET_CPU_X86_64) && TARGET_OS_IOS && TARGET_CPU_X86_64 + // skip ReductionOpTest.ReduceSum_half_bert because reduce_sum will output all zeros + int32_t input_type; + GetType(*input_defs[0], input_type, logger); + if (node.OpType() == "ReduceSum" && input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + return false; + } +#endif NodeAttrHelper helper(node); @@ -99,18 +152,16 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInpu if (input_defs.size() > 1 && input_defs[1]->Exists()) { // 'axes' is optional input in new opsets const auto& axes_name = input_defs[1]->Name(); - const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); - if (!Contains(initializers, axes_name)) { + const auto* axes = input_params.graph_viewer.GetConstantInitializer(axes_name); + if (!axes) { LOGS(logger, VERBOSE) << "Axes of reduction must be a constant initializer"; return false; } - empty_axes = initializers.at(axes_name)->int64_data_size() == 0; + empty_axes = axes->int64_data_size() == 0; } - - if (empty_axes && noop_with_empty_axes) { - // TODO: When we add ML Program support we should enable this as it makes the node an Identity op - LOGS(logger, VERBOSE) << "CoreML doesn't support noop on empty axes for reduction layers" << std::endl; + if (empty_axes && noop_with_empty_axes && !input_params.create_mlprogram) { + LOGS(logger, VERBOSE) << "NeuralNetwork doesn't support noop on empty axes for reduction layers"; return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc index a86e3d9538d87..243f949bdd48e 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc @@ -2,7 +2,9 @@ // Licensed under the MIT License. #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" +#include "core/providers/coreml/shape_utils.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" // for NodeAttrHelper @@ -14,28 +16,132 @@ class ShapeOpBuilder : public BaseOpBuilder { bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; + bool SupportsMLProgram() const override { return true; } }; Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /*logger*/) const { - auto layer = model_builder.CreateNNLayer(node); - layer->mutable_getshape(); - *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); - model_builder.AddLayer(std::move(layer)); + const auto& input_defs = node.InputDefs(); + +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + NodeAttrHelper node_attr_helper{node}; + int64_t size = -1; + int64_t num_dims = 0; + int64_t start = node_attr_helper.Get("start", 0); + // If the input shape is not available, size is -1 and start is 0 + if (input_defs[0]->Shape()) { + num_dims = input_defs[0]->Shape()->dim_size(); + start = HandleNegativeAxis(start, num_dims); + if (node_attr_helper.HasAttr("end")) { + int64_t end = HandleNegativeAxis(node_attr_helper.Get("end", -1), num_dims); + size = end - start; + } + } + + int32_t output_datatype = ONNX_NAMESPACE::TensorProto_DataType_INT32; + std::unique_ptr op = model_builder.CreateOperation(node, "shape"); + AddOperationInput(*op, "x", input_defs[0]->Name()); + if (size != -1 || start != 0) { + std::string_view layer_input_name_x = model_builder.GetUniqueName(node, "slice_by_size"); + std::vector x0_shape{num_dims}; + AddIntermediateOperationOutput(*op, layer_input_name_x, output_datatype, x0_shape); + model_builder.AddOperation(std::move(op)); + + auto slice_op = model_builder.CreateOperation(node, "slice_by_size"); + AddOperationInput(*slice_op, "x", layer_input_name_x); + std::vector starts = {start}; + std::vector sizes = {size}; + AddOperationInput(*slice_op, "begin", model_builder.AddConstant(slice_op->type(), "begin", starts)); + AddOperationInput(*slice_op, "size", model_builder.AddConstant(slice_op->type(), "size", sizes)); + AddOperationOutput(*slice_op, *node.OutputDefs()[0], output_datatype); + model_builder.AddOperation(std::move(slice_op)); + } else { + AddOperationOutput(*op, *node.OutputDefs()[0], output_datatype); + model_builder.AddOperation(std::move(op)); + } + } else // NOLINT +#endif + { + auto layer = model_builder.CreateNNLayer(node); + layer->mutable_getshape(); + *layer->mutable_input()->Add() = input_defs[0]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + model_builder.AddLayer(std::move(layer)); + } return Status::OK(); } -bool ShapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, +bool ShapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { + const auto* tensor_shape = node.InputDefs()[0]->Shape(); + NodeAttrHelper node_attr_helper{node}; - if (node_attr_helper.Get("start", 0) != 0) { - LOGS(logger, VERBOSE) << "Shape does not support 'start' attribute with value other than 0"; + if (!input_params.create_mlprogram) { + if (node_attr_helper.HasAttr("end")) { + LOGS(logger, VERBOSE) << "Shape does not support 'end' attribute"; + return false; + } + + if (node_attr_helper.Get("start", 0) != 0) { + LOGS(logger, VERBOSE) << "Shape does not support 'start' attribute with value other than 0"; + return false; + } + } else { + int64_t end = node_attr_helper.HasAttr("end") + ? node_attr_helper.Get("end", -1) + : std::numeric_limits::max(); + int64_t start = node_attr_helper.Get("start", 0); + // no need to slice if start is 0 and end is max + if (end == std::numeric_limits::max() && start == 0) { + } else if (tensor_shape == nullptr) { + LOGS(logger, VERBOSE) << "Shape does not support slicing when tensor_shape is not available"; + return false; + } + int64_t dim_size = tensor_shape->dim_size(); + int64_t size = node_attr_helper.HasAttr("end") + ? HandleNegativeAxis(node_attr_helper.Get("end", -1), dim_size) + : dim_size; + start = HandleNegativeAxis(start, dim_size); + size = size - start; + if (size == 0) { + LOGS(logger, VERBOSE) << "Shape does not support slicing when size is 0"; + return false; + } + } + + return true; +} + +bool ShapeOpBuilder::HasSupportedInputsImpl(const Node& node, + [[maybe_unused]] const OpBuilderInputParams& input_params, + const logging::Logger& logger) const { + // We only check the type of input 0 + const auto& input = *node.InputDefs()[0]; + + int32_t input_type; + if (!GetType(input, input_type, logger)) { return false; } - if (node_attr_helper.HasAttr("end")) { - LOGS(logger, VERBOSE) << "Shape does not support 'end' attribute"; + if (input_params.create_mlprogram) { + if ((input_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 || + input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT || + input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)) { + return true; + } else { + LOGS(logger, VERBOSE) << "[" << node.OpType() + << "] Input type: [" << input_type + << "] is not supported."; + return false; + } + } else if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + LOGS(logger, VERBOSE) << "[" << node.OpType() + << "] Input type: [" << input_type + << "] is not supported."; return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc index d6584124c6aba..c6e331feed326 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc @@ -4,6 +4,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/providers/common.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" @@ -18,6 +19,7 @@ class SoftmaxOpBuilder : public BaseOpBuilder { bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + bool SupportsMLProgram() const override { return true; } }; Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -33,55 +35,100 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, NodeAttrHelper helper(node); int32_t axis_default_value = (node.SinceVersion() < 13) ? 1 : -1; const auto axis = helper.Get("axis", axis_default_value); - const auto axis_nonnegative = HandleNegativeAxis(axis, data_shape.size()); - - if (node.SinceVersion() >= 13 || (data_shape.size() == 2)) { - auto* coreml_softmaxnd = layer->mutable_softmaxnd(); - coreml_softmaxnd->set_axis(axis); - *layer->mutable_input()->Add() = input_name; - *layer->mutable_output()->Add() = output_name; - model_builder.AddLayer(std::move(layer)); - } else { - // note: if opsets < 13, onnx Softmax coerces the input shape to be 2D based on axis. - // we need to manually reshape to 2D and apply SoftmaxND to axis -1 to achieve equivalent results for CoreML. - TensorShape input_shape(data_shape); - const auto size_to_dimension = input_shape.SizeToDimension(axis_nonnegative); - const auto size_from_dimension = input_shape.SizeFromDimension(axis_nonnegative); - - TensorShapeVector target_shape; - target_shape.push_back(size_to_dimension); - target_shape.push_back(size_from_dimension); - - const auto reshape1_output_name = model_builder.GetUniqueName(node, "reshape1_output"); - { // Add reshape layer - auto reshape_layer = model_builder.CreateNNLayer(node, "_Softmax_reshape1"); - *reshape_layer->mutable_reshapestatic()->mutable_targetshape() = {target_shape.cbegin(), target_shape.cend()}; - *reshape_layer->mutable_input()->Add() = input_name; - *reshape_layer->mutable_output()->Add() = reshape1_output_name; - model_builder.AddLayer(std::move(reshape_layer)); + auto axis_nonnegative = HandleNegativeAxis(axis, data_shape.size()); + +#if defined(COREML_ENABLE_MLPROGRAM) + // CoreML's softmax match onnx's softmax behavior since opset 13. + // For opset < 13, we need to reshape to 2D and set axis to -1 to simulate onnx softmax behavior. + // [B,D,...](onnx softmax opset 12, axis=1)->[B,D*...](CoreML softmax, axis=-1)->[B,D,...](reshape back) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + const int32_t elem_type = static_cast(input_dtype); + + std::string_view layer_input_name_x = node.InputDefs()[0]->Name(); + const bool need_reshape = node.SinceVersion() < 13 && axis_nonnegative != static_cast(data_shape.size()) - 1; + std::vector target_shape; + if (need_reshape) { + // reshape to 2D to simulate onnx softmax behavior + auto reshape1 = model_builder.CreateOperation(node, "reshape", "pre"); + TensorShape input_shape(data_shape); + target_shape.push_back(input_shape.SizeToDimension(axis_nonnegative)); + target_shape.push_back(input_shape.SizeFromDimension(axis_nonnegative)); + axis_nonnegative = 1; + AddOperationInput(*reshape1, "x", layer_input_name_x); + AddOperationInput(*reshape1, "shape", model_builder.AddConstant(reshape1->type(), "shape1", target_shape)); + layer_input_name_x = model_builder.GetUniqueName(node, "ln_reshape1_"); + AddIntermediateOperationOutput(*reshape1, layer_input_name_x, elem_type, target_shape); + model_builder.AddOperation(std::move(reshape1)); } - const auto softmax_output_name = model_builder.GetUniqueName(node, "softmax_output"); - { + std::unique_ptr op = model_builder.CreateOperation(node, "softmax"); + AddOperationInput(*op, "x", layer_input_name_x); + AddOperationInput(*op, "axis", model_builder.AddScalarConstant(op->type(), "axis", axis_nonnegative)); + if (!need_reshape) { + AddOperationOutput(*op, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(op)); + } else { + std::string_view ln_output_name = model_builder.GetUniqueName(node, "ln_reshape1_"); + AddIntermediateOperationOutput(*op, ln_output_name, elem_type, target_shape); + model_builder.AddOperation(std::move(op)); + auto reshape2 = model_builder.CreateOperation(node, "reshape", "post"); + AddOperationInput(*reshape2, "x", ln_output_name); + AddOperationInput(*reshape2, "shape", model_builder.AddConstant(reshape2->type(), "shape2", data_shape)); + AddOperationOutput(*reshape2, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(reshape2)); + } + } else // NOLINT +#endif + { + if (node.SinceVersion() >= 13 || (data_shape.size() == 2)) { auto* coreml_softmaxnd = layer->mutable_softmaxnd(); - coreml_softmaxnd->set_axis(-1); - *layer->mutable_input()->Add() = reshape1_output_name; - *layer->mutable_output()->Add() = softmax_output_name; + coreml_softmaxnd->set_axis(axis); + *layer->mutable_input()->Add() = input_name; + *layer->mutable_output()->Add() = output_name; model_builder.AddLayer(std::move(layer)); - } - { - // Add reshape back layer - auto reshape_layer = model_builder.CreateNNLayer(node, "_Softmax_reshape2"); - *reshape_layer->mutable_reshapestatic()->mutable_targetshape() = {data_shape.cbegin(), data_shape.cend()}; - *reshape_layer->mutable_input()->Add() = softmax_output_name; - *reshape_layer->mutable_output()->Add() = output_name; - model_builder.AddLayer(std::move(reshape_layer)); + } else { + // note: if opsets < 13, onnx Softmax coerces the input shape to be 2D based on axis. + // we need to manually reshape to 2D and apply SoftmaxND to axis -1 to achieve equivalent results for CoreML. + TensorShape input_shape(data_shape); + const auto size_to_dimension = input_shape.SizeToDimension(axis_nonnegative); + const auto size_from_dimension = input_shape.SizeFromDimension(axis_nonnegative); + + TensorShapeVector target_shape; + target_shape.push_back(size_to_dimension); + target_shape.push_back(size_from_dimension); + + const auto reshape1_output_name = model_builder.GetUniqueName(node, "reshape1_output"); + { // Add reshape layer + auto reshape_layer = model_builder.CreateNNLayer(node, "_Softmax_reshape1"); + *reshape_layer->mutable_reshapestatic()->mutable_targetshape() = {target_shape.cbegin(), target_shape.cend()}; + *reshape_layer->mutable_input()->Add() = input_name; + *reshape_layer->mutable_output()->Add() = reshape1_output_name; + model_builder.AddLayer(std::move(reshape_layer)); + } + const auto softmax_output_name = model_builder.GetUniqueName(node, "softmax_output"); + { + auto* coreml_softmaxnd = layer->mutable_softmaxnd(); + coreml_softmaxnd->set_axis(-1); + *layer->mutable_input()->Add() = reshape1_output_name; + *layer->mutable_output()->Add() = softmax_output_name; + model_builder.AddLayer(std::move(layer)); + } + { + // Add reshape back layer + auto reshape_layer = model_builder.CreateNNLayer(node, "_Softmax_reshape2"); + *reshape_layer->mutable_reshapestatic()->mutable_targetshape() = {data_shape.cbegin(), data_shape.cend()}; + *reshape_layer->mutable_input()->Add() = softmax_output_name; + *reshape_layer->mutable_output()->Add() = output_name; + model_builder.AddLayer(std::move(reshape_layer)); + } } } return Status::OK(); } -bool SoftmaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /* input_params */, +bool SoftmaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); std::vector input_shape; diff --git a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc index dbd0f48576f8b..6372f3136123b 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc @@ -51,8 +51,8 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, auto calculate_remainder_and_chunk_size = [&](int32_t num_outputs) { // note: checked in IsOpSupportedImpl that ensures the dim value at splitting axis exists auto split_dim_size = data_shape[HandleNegativeAxis(axis, data_shape.size())]; - uint64_t chunk_size = (split_dim_size + num_outputs - 1) / num_outputs; - uint64_t remainder = split_dim_size % chunk_size; + int64_t chunk_size = (split_dim_size + num_outputs - 1) / num_outputs; + int64_t remainder = split_dim_size % chunk_size; return std::make_tuple(remainder, chunk_size); }; @@ -106,20 +106,20 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // if "split" is explicitly provided as an input // const auto& split_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); Initializer unpacked_tensor(*model_builder.GetConstantInitializer(input_defs[1]->Name())); - auto split_span = unpacked_tensor.DataAsSpan(); + auto split_span = unpacked_tensor.DataAsSpan(); for (const auto& split_size : split_span) { coreml_splitnd->add_splitsizes(split_size); } } else if (node.SinceVersion() < 18) { - uint64_t num_outputs = narrow(node.OutputDefs().size()); + int64_t num_outputs = narrow(node.OutputDefs().size()); coreml_splitnd->set_numsplits(num_outputs); } else { // note: for opset 18+ 'num_outputs' is a required attribute - uint64_t num_outputs = narrow(helper.GetInt64("num_outputs").value()); + int64_t num_outputs = narrow(helper.GetInt64("num_outputs").value()); auto [remainder, chunk_size] = calculate_remainder_and_chunk_size(static_cast(num_outputs)); if (remainder) { // uneven - auto split_sizes = InlinedVector(num_outputs, chunk_size); + auto split_sizes = InlinedVector(num_outputs, chunk_size); split_sizes.back() = remainder; for (size_t i = 0; i < split_sizes.size(); i++) { coreml_splitnd->add_splitsizes(split_sizes[i]); @@ -162,7 +162,7 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar } const auto split_shape = *input_defs[1]->Shape(); - if (split_shape.dim_size() < 2) { + if (split_shape.dim(0).dim_value() < 2) { LOGS(logger, VERBOSE) << "CoreML Split must produce at least 2 outputs."; return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc index e9cc1c2dbf638..a1b3a18265c70 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc @@ -5,10 +5,17 @@ #include "core/framework/tensorprotoutils.h" #include "core/providers/common.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" +#include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" #include "core/optimizer/initializer.h" +#include "core/providers/cpu/tensor/unsqueeze.h" + +#ifdef __APPLE__ +#include +#endif namespace onnxruntime { namespace coreml { @@ -21,16 +28,16 @@ class SqueezeOpBuilder : public BaseOpBuilder { bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + bool SupportsMLProgram() const override { return true; } }; namespace { -Status GetAxes(ModelBuilder& model_builder, const Node& node, std::vector& axes) { +void GetAxes(ModelBuilder& model_builder, const Node& node, TensorShapeVector& axes) { // Squeeze opset 13 use input as axes if (node.SinceVersion() > 12) { // If axes is not provided, return an empty axes as default to squeeze all if (node.InputDefs().size() > 1) { - const auto& initializers(model_builder.GetInitializerTensors()); - const auto& axes_tensor = *initializers.at(node.InputDefs()[1]->Name()); + const auto& axes_tensor = *model_builder.GetConstantInitializer(node.InputDefs()[1]->Name()); Initializer unpacked_tensor(axes_tensor); auto raw_axes = unpacked_tensor.DataAsSpan(); const auto size = SafeInt(axes_tensor.dims()[0]); @@ -39,10 +46,9 @@ Status GetAxes(ModelBuilder& model_builder, const Node& node, std::vector()); + auto axes_attr = helper.Get("axes", std::vector()); + axes.assign(axes_attr.begin(), axes_attr.end()); } - - return Status::OK(); } } // namespace @@ -52,40 +58,103 @@ void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const } } +#if defined(COREML_ENABLE_MLPROGRAM) +void HandleX86ArchUnsqueezeScalarInput(ModelBuilder& model_builder, + const Node& node, const logging::Logger& logger) { + const auto& input_defs(node.InputDefs()); + TensorShapeVector axes; + GetAxes(model_builder, node, axes); + + std::vector input_shape; + GetShape(*input_defs[0], input_shape, logger); + auto op = model_builder.CreateOperation(node, "reshape"); + AddOperationInput(*op, "x", input_defs[0]->Name()); + TensorShapeVector output_shape = UnsqueezeBase::ComputeOutputShape(TensorShape(input_shape), axes); + AddOperationInput(*op, "shape", model_builder.AddConstant(op->type(), "shape", AsSpan(output_shape))); + AddOperationOutput(*op, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(op)); +} +#endif + Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& /* logger */) const { + [[maybe_unused]] const logging::Logger& logger) const { std::unique_ptr layer = model_builder.CreateNNLayer(node); - auto* coreml_squeeze = layer->mutable_squeeze(); - std::vector axes; - ORT_RETURN_IF_ERROR(GetAxes(model_builder, node, axes)); - if (axes.empty()) { - coreml_squeeze->set_squeezeall(true); - } else { - *coreml_squeeze->mutable_axes() = {axes.cbegin(), axes.cend()}; - coreml_squeeze->set_squeezeall(false); - } + TensorShapeVector axes; + GetAxes(model_builder, node, axes); +#if defined(COREML_ENABLE_MLPROGRAM) + const auto& input_defs(node.InputDefs()); + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + +#if defined(TARGET_CPU_X86_64) && TARGET_CPU_X86_64 + // expand_dims has limited requirements for static shape, however, X86_64 has a bug that it can't handle scalar input + if (node.OpType() == "Unsqueeze" && input_defs[0]->Shape()->dim_size() < 2) { + HandleX86ArchUnsqueezeScalarInput(model_builder, node, logger); + return Status::OK(); + } +#endif + std::string_view coreml_op_type = node.OpType() == "Squeeze" ? "squeeze" : "expand_dims"; + std::unique_ptr op = model_builder.CreateOperation(node, coreml_op_type); + AddOperationInput(*op, "x", input_defs[0]->Name()); + + if (!axes.empty()) { + // coreml supports negative axes + AddOperationInput(*op, "axes", model_builder.AddConstant(op->type(), "axes", AsSpan(axes))); + } + AddOperationOutput(*op, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(op)); + } else // NOLINT +#endif + { + if (axes.empty()) { + coreml_squeeze->set_squeezeall(true); + } else { + *coreml_squeeze->mutable_axes() = {axes.cbegin(), axes.cend()}; + coreml_squeeze->set_squeezeall(false); + } - *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); - model_builder.AddLayer(std::move(layer)); + model_builder.AddLayer(std::move(layer)); + } return Status::OK(); } bool SqueezeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, - const logging::Logger& /*logger*/) const { + const logging::Logger& logger) const { // Squeeze opset 13 uses input 1 as axes, if we have input 1 then it needs to be an initializer - const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); - if (node.SinceVersion() > 12 && node.InputDefs().size() > 1) { - const auto& axes_name = node.InputDefs()[1]->Name(); - if (!Contains(initializers, axes_name)) { - LOGS_DEFAULT(VERBOSE) << "Input axes of Squeeze must be known"; + const auto& input_defs = node.InputDefs(); + if (node.SinceVersion() > 12 && input_defs.size() > 1) { + const auto& axes_name = input_defs[1]->Name(); + if (!input_params.graph_viewer.GetConstantInitializer(axes_name)) { + LOGS(logger, VERBOSE) << "Input axes must be known"; return false; } } + if (node.OpType() == "Unsqueeze") { + if (!input_params.create_mlprogram) { + return false; + } + + int64_t num_of_new_dims = 0; + if (node.SinceVersion() > 12) { + num_of_new_dims = node.InputDefs()[1]->Shape()->dim(0).dim_value(); + } else { + NodeAttrHelper helper(node); + auto axes = helper.Get("axes", std::vector()); + num_of_new_dims = static_cast(axes.size()); + } + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger) || input_shape.size() + num_of_new_dims > 5) { + LOGS(logger, VERBOSE) << "Unsqueeze to output shape with > 5 dimensions is not supported"; + return false; + } + } return true; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc index a6580920343c4..bc3cad004aec1 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc @@ -16,6 +16,8 @@ class UnaryOpBuilder : public BaseOpBuilder { Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; bool SupportsMLProgram() const override { return true; } + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; }; Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, @@ -32,6 +34,10 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const coreml_op_type = "sqrt"; } else if (op_type == "Reciprocal") { coreml_op_type = "inverse"; + } else if (op_type == "Erf") { + coreml_op_type = "erf"; + } else if (op_type == "Round") { + coreml_op_type = "round"; } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "UnaryOpBuilder::AddToModelBuilderImpl, unexpected op: ", op_type); @@ -74,6 +80,14 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const return Status::OK(); } +bool UnaryOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& /*logger*/) const { + if (!input_params.create_mlprogram && (node.OpType() == "Erf" || node.OpType() == "Round")) { + return false; + } + return true; +} + void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.cc b/onnxruntime/core/providers/coreml/builders/model_builder.cc index 50faebf06875d..6486942199df7 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/model_builder.cc @@ -8,12 +8,14 @@ #include "core/platform/env.h" #include "core/providers/common.h" #include "core/providers/coreml/builders/model_builder.h" +#include "core/providers/coreml/coreml_execution_provider.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/coreml_provider_factory.h" #include "core/providers/coreml/model/host_utils.h" #include "core/providers/coreml/shape_utils.h" +#include "core/optimizer/initializer.h" #if defined(COREML_ENABLE_MLPROGRAM) // includes from coremltools-src in _deps @@ -400,14 +402,14 @@ std::string GetModelOutputPath(bool create_ml_program) { } // namespace ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, - int32_t coreml_version, uint32_t coreml_flags, + int32_t coreml_version, const CoreMLOptions& coreml_options, std::vector&& onnx_input_names, std::vector&& onnx_output_names) : graph_viewer_(graph_viewer), logger_(logger), coreml_version_(coreml_version), - coreml_flags_(coreml_flags), - create_ml_program_((coreml_flags_ & COREML_FLAG_CREATE_MLPROGRAM) != 0), + coreml_options_(coreml_options), + create_ml_program_(coreml_options.CreateMLProgram()), model_output_path_(GetModelOutputPath(create_ml_program_)), onnx_input_names_(std::move(onnx_input_names)), onnx_output_names_(std::move(onnx_output_names)), @@ -987,7 +989,7 @@ Status ModelBuilder::LoadModel(std::unique_ptr& model) { get_sanitized_io_info(std::move(input_output_info_)), std::move(scalar_outputs_), std::move(int64_outputs_), - logger_, coreml_flags_); + logger_, coreml_options_); } else #endif { @@ -997,19 +999,61 @@ Status ModelBuilder::LoadModel(std::unique_ptr& model) { std::move(input_output_info_), std::move(scalar_outputs_), std::move(int64_outputs_), - logger_, coreml_flags_); + logger_, coreml_options_); } return model->LoadModel(); // load using CoreML API, including compilation } +#if defined(COREML_ENABLE_MLPROGRAM) +std::string_view ModelBuilder::AddConstant(std::string_view op_type, std::string_view value_type, + const ONNX_NAMESPACE::TensorProto& tensor, + std::optional> shape) { + const auto data_type = tensor.data_type(); + Initializer unpacked_tensor(tensor); + std::string_view ret; + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape ? shape : tensor.dims()); + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape ? shape : tensor.dims()); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape ? shape : tensor.dims()); + break; + // case ONNX_NAMESPACE::TensorProto_DataType_INT32: + // ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape?shape:tensor.dims()); + // break; + // case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: + // ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape?shape:tensor.dims()); + // break; + // case ONNX_NAMESPACE::TensorProto_DataType_INT8: + // ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape?shape:tensor.dims()); + // break; + // case ONNX_NAMESPACE::TensorProto_DataType_UINT8: + // ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape?shape:tensor.dims()); + // break; + // case ONNX_NAMESPACE::TensorProto_DataType_UINT32: + // ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape?shape:tensor.dims()); + // break; + // case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + // ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan(), shape?shape:tensor.dims()); + // break; + default: + ORT_THROW("AddConstant: Unsupported data type: ", data_type); + } + + return ret; +} +#endif // static Status ModelBuilder::Build(const GraphViewer& graph_viewer, const logging::Logger& logger, - int32_t coreml_version, uint32_t coreml_flags, + int32_t coreml_version, const CoreMLOptions& coreml_options, std::vector&& onnx_input_names, std::vector&& onnx_output_names, std::unique_ptr& model) { - ModelBuilder builder(graph_viewer, logger, coreml_version, coreml_flags, + ModelBuilder builder(graph_viewer, logger, coreml_version, coreml_options, std::move(onnx_input_names), std::move(onnx_output_names)); ORT_RETURN_IF_ERROR(builder.CreateModel()); diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.h b/onnxruntime/core/providers/coreml/builders/model_builder.h index b3dfec29872a2..e19597cf0dc2e 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.h +++ b/onnxruntime/core/providers/coreml/builders/model_builder.h @@ -7,6 +7,7 @@ #include "core/graph/graph_viewer.h" #include "core/providers/coreml/builders/coreml_spec.h" #include "core/providers/coreml/model/model.h" +#include "core/providers/coreml/coreml_options.h" #if defined(COREML_ENABLE_MLPROGRAM) // coremltools classes @@ -29,14 +30,14 @@ class IOpBuilder; class ModelBuilder { private: ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, - int32_t coreml_version, uint32_t coreml_flags, + int32_t coreml_version, const CoreMLOptions& coreml_options, std::vector&& onnx_input_names, std::vector&& onnx_output_names); public: // Create the CoreML model, serialize to disk, load and compile using the CoreML API and return in `model` static Status Build(const GraphViewer& graph_viewer, const logging::Logger& logger, - int32_t coreml_version, uint32_t coreml_flags, + int32_t coreml_version, const CoreMLOptions& coreml_options, std::vector&& onnx_input_names, std::vector&& onnx_output_names, std::unique_ptr& model); @@ -129,6 +130,12 @@ class ModelBuilder { return AddConstant(op_type, value_type, gsl::span(value), shape); } + // helper to convert a initializer to a constant + // by default, shape is inferred from the tensor.dims(), but can be provided to override if needed + std::string_view AddConstant(std::string_view op_type, std::string_view value_type, + const ONNX_NAMESPACE::TensorProto& tensor, + std::optional> shape = std::nullopt); + /// /// Add a scalar value as a 'const' operation. See AddConstant for details. /// @@ -210,7 +217,7 @@ class ModelBuilder { const GraphViewer& graph_viewer_; const logging::Logger& logger_; const int32_t coreml_version_; - const uint32_t coreml_flags_; + CoreMLOptions coreml_options_; const bool create_ml_program_; // ML Program (CoreML5, iOS 15+, macOS 12+) or NeuralNetwork (old) const std::string model_output_path_; // create_ml_program_ ? dir for mlpackage : filename for mlmodel diff --git a/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc b/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc index b0006b24e7d75..6e7df20a06097 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc @@ -21,15 +21,19 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateActivationOpBuilder("Relu", op_registrations); CreateActivationOpBuilder("PRelu", op_registrations); CreateActivationOpBuilder("LeakyRelu", op_registrations); + CreateActivationOpBuilder("Gelu", op_registrations); // Unary ops + CreateUnaryOpBuilder("Erf", op_registrations); CreateUnaryOpBuilder("Reciprocal", op_registrations); + CreateUnaryOpBuilder("Round", op_registrations); CreateUnaryOpBuilder("Sqrt", op_registrations); // Binary elementwise ops CreateBinaryOpBuilder("Add", op_registrations); CreateBinaryOpBuilder("Div", op_registrations); CreateBinaryOpBuilder("Mul", op_registrations); + CreateBinaryOpBuilder("Max", op_registrations); CreateBinaryOpBuilder("Pow", op_registrations); CreateBinaryOpBuilder("Sub", op_registrations); @@ -41,10 +45,18 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { // Reduction ops CreateReductionOpBuilder("ReduceMean", op_registrations); + CreateReductionOpBuilder("ReduceMin", op_registrations); + CreateReductionOpBuilder("ReduceMax", op_registrations); + CreateReductionOpBuilder("ReduceProd", op_registrations); CreateReductionOpBuilder("ReduceSum", op_registrations); - CreateArgMaxOpBuilder("ArgMax", op_registrations); + // Normalization ops CreateBatchNormalizationOpBuilder("BatchNormalization", op_registrations); + CreateNormalizationOpBuilder("GroupNormalization", op_registrations); + CreateNormalizationOpBuilder("InstanceNormalization", op_registrations); + CreateNormalizationOpBuilder("LayerNormalization", op_registrations); + + CreateArgMaxOpBuilder("ArgMax", op_registrations); CreateCastOpBuilder("Cast", op_registrations); CreateClipOpBuilder("Clip", op_registrations); CreateConcatOpBuilder("Concat", op_registrations); @@ -66,6 +78,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateSoftmaxOpBuilder("Softmax", op_registrations); CreateSqueezeOpBuilder("Squeeze", op_registrations); CreateTransposeOpBuilder("Transpose", op_registrations); + CreateSqueezeOpBuilder("Unsqueeze", op_registrations); return op_registrations; } diff --git a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h index 1990fb6400ce1..9b51b53d73e9e 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h @@ -19,6 +19,7 @@ const std::unordered_map& GetOpBuilders(); void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateArgMaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateBatchNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc index f7afbb2f98bd8..5a2867e5524e4 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc @@ -23,35 +23,14 @@ namespace onnxruntime { constexpr const char* COREML = "CoreML"; -CoreMLExecutionProvider::CoreMLExecutionProvider(uint32_t coreml_flags) +CoreMLExecutionProvider::CoreMLExecutionProvider(const CoreMLOptions& options) : IExecutionProvider{onnxruntime::kCoreMLExecutionProvider}, - coreml_flags_(coreml_flags), + coreml_options_(options), coreml_version_(coreml::util::CoreMLVersion()) { LOGS_DEFAULT(VERBOSE) << "CoreML version: " << coreml_version_; if (coreml_version_ < MINIMUM_COREML_VERSION) { - LOGS_DEFAULT(ERROR) << "CoreML EP is not supported on this platform."; + ORT_THROW("CoreML EP is not supported on this platform."); } - - // check if only one flag is set - if ((coreml_flags & COREML_FLAG_USE_CPU_ONLY) && (coreml_flags & COREML_FLAG_USE_CPU_AND_GPU)) { - // multiple device options selected - ORT_THROW( - "Multiple device options selected, you should use at most one of the following options:" - "COREML_FLAG_USE_CPU_ONLY or COREML_FLAG_USE_CPU_AND_GPU or not set"); - } - -#if defined(COREML_ENABLE_MLPROGRAM) - if (coreml_version_ < MINIMUM_COREML_MLPROGRAM_VERSION && - (coreml_flags_ & COREML_FLAG_CREATE_MLPROGRAM) != 0) { - LOGS_DEFAULT(WARNING) << "ML Program is not supported on this OS version. Falling back to NeuralNetwork."; - coreml_flags_ ^= COREML_FLAG_CREATE_MLPROGRAM; - } -#else - if ((coreml_flags_ & COREML_FLAG_CREATE_MLPROGRAM) != 0) { - LOGS_DEFAULT(WARNING) << "ML Program is not supported in this build. Falling back to NeuralNetwork."; - coreml_flags_ ^= COREML_FLAG_CREATE_MLPROGRAM; - } -#endif } CoreMLExecutionProvider::~CoreMLExecutionProvider() {} @@ -61,26 +40,17 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie const IKernelLookup& /*kernel_lookup*/) const { std::vector> result; - if (coreml_version_ < MINIMUM_COREML_VERSION) { - return result; - } - const auto& logger = *GetLogger(); // We do not run CoreML EP on subgraph, instead we cover this in the control flow nodes // TODO investigate whether we want to support subgraph using CoreML EP. May simply require processing the // implicit inputs of the control flow node that contains the subgraph as inputs to the CoreML model we generate. - if (graph_viewer.IsSubgraph() && !(coreml_flags_ & COREML_FLAG_ENABLE_ON_SUBGRAPH)) { - return result; - } - - const bool has_neural_engine = coreml::HasNeuralEngine(logger); - if ((coreml_flags_ & COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) && !has_neural_engine) { - LOGS(logger, WARNING) << "The current system does not have Apple Neural Engine. CoreML EP will not be used."; + if (graph_viewer.IsSubgraph() && !coreml_options_.EnableOnSubgraph()) { return result; } - const auto builder_params = coreml::MakeOpBuilderParams(graph_viewer, coreml_version_, coreml_flags_); + const auto builder_params = coreml::MakeOpBuilderParams(graph_viewer, coreml_version_, + coreml_options_.RequireStaticShape(), coreml_options_.CreateMLProgram()); const auto supported_nodes = coreml::GetSupportedNodes(graph_viewer, builder_params, logger); const auto gen_metadef_name = @@ -143,7 +113,7 @@ common::Status CoreMLExecutionProvider::Compile(const std::vector onnx_output_names = get_names(fused_node.OutputDefs()); const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); - ORT_RETURN_IF_ERROR(coreml::ModelBuilder::Build(graph_viewer, *GetLogger(), coreml_version_, coreml_flags_, + ORT_RETURN_IF_ERROR(coreml::ModelBuilder::Build(graph_viewer, *GetLogger(), coreml_version_, coreml_options_, std::move(onnx_input_names), std::move(onnx_output_names), coreml_model)); } diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.h b/onnxruntime/core/providers/coreml/coreml_execution_provider.h index 24a001280eef5..650d81a4fecf7 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.h +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.h @@ -3,7 +3,7 @@ #pragma once -#include "core/common/inlined_containers.h" +#include "core/providers/coreml/coreml_options.h" #include "core/framework/execution_provider.h" #include "core/framework/model_metadef_id_generator.h" @@ -14,7 +14,7 @@ class Model; class CoreMLExecutionProvider : public IExecutionProvider { public: - CoreMLExecutionProvider(uint32_t coreml_flags); + CoreMLExecutionProvider(const CoreMLOptions& options); virtual ~CoreMLExecutionProvider(); std::vector> @@ -29,7 +29,7 @@ class CoreMLExecutionProvider : public IExecutionProvider { private: // The bit flags which define bool options for COREML EP, bits are defined as // COREMLFlags in include/onnxruntime/core/providers/coreml/coreml_provider_factory.h - uint32_t coreml_flags_; + CoreMLOptions coreml_options_; const int32_t coreml_version_; ModelMetadefIdGenerator metadef_id_generator_; diff --git a/onnxruntime/core/providers/coreml/coreml_options.cc b/onnxruntime/core/providers/coreml/coreml_options.cc new file mode 100644 index 0000000000000..4ec780208e528 --- /dev/null +++ b/onnxruntime/core/providers/coreml/coreml_options.cc @@ -0,0 +1,109 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/coreml/coreml_execution_provider.h" +#include "core/providers/coreml/coreml_provider_factory.h" // defines flags +#include "core/providers/coreml/model/host_utils.h" +#include "core/providers/coreml/builders/helper.h" + +namespace onnxruntime { + +CoreMLOptions::CoreMLOptions(uint32_t coreml_flags) { + // validate the flags and populate the members. should be moving code from ctor to here + require_static_shape_ = (coreml_flags & COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES) != 0; + create_mlprogram_ = (coreml_flags & COREML_FLAG_CREATE_MLPROGRAM) != 0; + enable_on_subgraph_ = (coreml_flags & COREML_FLAG_ENABLE_ON_SUBGRAPH) != 0; + +#if defined(COREML_ENABLE_MLPROGRAM) + if (coreml::util::CoreMLVersion() < MINIMUM_COREML_MLPROGRAM_VERSION && create_mlprogram_ != 0) { + LOGS_DEFAULT(WARNING) << "ML Program is not supported on this OS version. Falling back to NeuralNetwork."; + create_mlprogram_ = false; + } +#else + if (create_mlprogram_ != 0) { + LOGS_DEFAULT(WARNING) << "ML Program is not supported in this build. Falling back to NeuralNetwork."; + create_mlprogram_ = false; + } +#endif + + compute_units_ = 0; // 0 for all + + if (coreml_flags & COREML_FLAG_USE_CPU_ONLY) { + compute_units_ |= COREML_FLAG_USE_CPU_ONLY; + } + if (coreml_flags & COREML_FLAG_USE_CPU_AND_GPU) { + compute_units_ |= COREML_FLAG_USE_CPU_AND_GPU; + } + if (coreml_flags & COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) { + compute_units_ |= COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE; + } + + // assure only one device option is selected + if (compute_units_ & (compute_units_ - 1)) { + // multiple device options selected + ORT_THROW( + "Multiple device options selected, you should use at most one of the following options:" + "[COREML_FLAG_USE_CPU_ONLY, COREML_FLAG_USE_CPU_AND_GPU, COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE]"); + } + + const bool has_neural_engine = coreml::HasNeuralEngine(); + if (ComputeUnits(COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) && !has_neural_engine) { + ORT_THROW("The current system does not have Apple Neural Engine."); + } +} + +void CoreMLOptions::ValidateAndParseProviderOption(const ProviderOptions& options) { + const std::unordered_map available_computeunits_options = { + {"CPUAndNeuralEngine", COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE}, + {"CPUAndGPU", COREML_FLAG_USE_CPU_AND_GPU}, + {"CPUOnly", COREML_FLAG_USE_CPU_ONLY}, + {"ALL", COREML_FLAG_USE_NONE}, + }; + const std::unordered_map available_modelformat_options = { + {"MLProgram", COREML_FLAG_CREATE_MLPROGRAM}, + {"NeuralNetwork", COREML_FLAG_USE_NONE}, + }; + const std::unordered_set valid_options = { + kCoremlProviderOption_MLComputeUnits, + kCoremlProviderOption_ModelFormat, + kCoremlProviderOption_RequireStaticInputShapes, + kCoremlProviderOption_EnableOnSubgraphs, + kCoremlProviderOption_SpecializationStrategy, + kCoremlProviderOption_ProfileComputePlan, + kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU, + }; + // Validate the options + for (const auto& option : options) { + if (valid_options.find(option.first) == valid_options.end()) { + ORT_THROW("Unknown option: ", option.first); + } + if (kCoremlProviderOption_MLComputeUnits == option.first) { + if (available_computeunits_options.find(option.second) == available_computeunits_options.end()) { + ORT_THROW("Invalid value for option `", option.first, "`: ", option.second); + } else { + compute_units_ = available_computeunits_options.at(option.second); + } + } else if (kCoremlProviderOption_ModelFormat == option.first) { + if (available_modelformat_options.find(option.second) == available_modelformat_options.end()) { + ORT_THROW("Invalid value for option ", option.first, ": ", option.second); + } else { + create_mlprogram_ = available_modelformat_options.at(option.second) & COREML_FLAG_CREATE_MLPROGRAM; + } + } else if (kCoremlProviderOption_RequireStaticInputShapes == option.first) { + require_static_shape_ = option.second == "1"; + } else if (kCoremlProviderOption_EnableOnSubgraphs == option.first) { + enable_on_subgraph_ = option.second == "1"; + } else if (kCoremlProviderOption_SpecializationStrategy == option.first) { + if (option.second != "Default" && option.second != "FastPrediction") { + ORT_THROW("Invalid value for option ", option.first, ": ", option.second, + ". Valid values are Default and FastPrediction."); + } + strategy_ = option.second; + } else if (kCoremlProviderOption_ProfileComputePlan == option.first) { + profile_compute_plan_ = option.second == "1"; + } else if (kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU == option.first) { + allow_low_precision_accumulation_on_gpu_ = option.second == "1"; + } + } +} +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/coreml_options.h b/onnxruntime/core/providers/coreml/coreml_options.h new file mode 100644 index 0000000000000..fd05c96927bd1 --- /dev/null +++ b/onnxruntime/core/providers/coreml/coreml_options.h @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/inlined_containers.h" +#include "core/framework/execution_provider.h" + +namespace onnxruntime { + +class CoreMLOptions { + private: + bool require_static_shape_{false}; + bool create_mlprogram_{false}; + bool enable_on_subgraph_{false}; + uint32_t compute_units_{0}; + std::string strategy_; + bool profile_compute_plan_{false}; + bool allow_low_precision_accumulation_on_gpu_{false}; + + public: + explicit CoreMLOptions(uint32_t coreml_flags); + + CoreMLOptions(const ProviderOptions& options) { + ValidateAndParseProviderOption(options); + } + bool RequireStaticShape() const { return require_static_shape_; } + bool CreateMLProgram() const { return create_mlprogram_; } + bool EnableOnSubgraph() const { return enable_on_subgraph_; } + uint32_t ComputeUnits(uint32_t specific_flag = 0xffffffff) const { return compute_units_ & specific_flag; } + bool AllowLowPrecisionAccumulationOnGPU() const { return allow_low_precision_accumulation_on_gpu_; } + bool UseStrategy(std::string_view strategy) const { return strategy_ == strategy; } + bool ProfileComputePlan() const { return profile_compute_plan_ && create_mlprogram_; } + + private: + void ValidateAndParseProviderOption(const ProviderOptions& options); +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/coreml_provider_factory.cc b/onnxruntime/core/providers/coreml/coreml_provider_factory.cc index fcdf37c446ce7..bc8702d3290f6 100644 --- a/onnxruntime/core/providers/coreml/coreml_provider_factory.cc +++ b/onnxruntime/core/providers/coreml/coreml_provider_factory.cc @@ -9,21 +9,28 @@ using namespace onnxruntime; namespace onnxruntime { + struct CoreMLProviderFactory : IExecutionProviderFactory { - CoreMLProviderFactory(uint32_t coreml_flags) - : coreml_flags_(coreml_flags) {} + CoreMLProviderFactory(const CoreMLOptions& options) + : options_(options) {} ~CoreMLProviderFactory() override {} std::unique_ptr CreateProvider() override; - uint32_t coreml_flags_; + CoreMLOptions options_; }; std::unique_ptr CoreMLProviderFactory::CreateProvider() { - return std::make_unique(coreml_flags_); + return std::make_unique(options_); } std::shared_ptr CoreMLProviderFactoryCreator::Create(uint32_t coreml_flags) { - return std::make_shared(coreml_flags); + CoreMLOptions coreml_options(coreml_flags); + return std::make_shared(coreml_options); +} + +std::shared_ptr CoreMLProviderFactoryCreator::Create(const ProviderOptions& options) { + CoreMLOptions coreml_options(options); + return std::make_shared(coreml_options); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/coreml_provider_factory_creator.h b/onnxruntime/core/providers/coreml/coreml_provider_factory_creator.h index ba701724c4da9..93ec2af50698d 100644 --- a/onnxruntime/core/providers/coreml/coreml_provider_factory_creator.h +++ b/onnxruntime/core/providers/coreml/coreml_provider_factory_creator.h @@ -5,10 +5,12 @@ #include +#include "core/framework/provider_options.h" #include "core/providers/providers.h" namespace onnxruntime { struct CoreMLProviderFactoryCreator { static std::shared_ptr Create(uint32_t coreml_flags); + static std::shared_ptr Create(const ProviderOptions& options); }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/model/host_utils.h b/onnxruntime/core/providers/coreml/model/host_utils.h index a9991ccb945ce..145c64e5320d3 100644 --- a/onnxruntime/core/providers/coreml/model/host_utils.h +++ b/onnxruntime/core/providers/coreml/model/host_utils.h @@ -26,6 +26,8 @@ // - iOS 16 ops // 8 : iOS 17, macOS 14, tvOS 17, watchOS 10 (Core ML 7) // - iOS 17 ops +// 9 : iOS 18, macOS 15, tvOS 18, watchOS 11 (Core ML 8) +// - iOS 18 ops // // **NOTE** We use the Core ML version not the spec version. // @@ -39,6 +41,7 @@ #define API_AVAILABLE_COREML5 API_AVAILABLE(macos(12), ios(15)) #define API_AVAILABLE_COREML6 API_AVAILABLE(macos(13), ios(16)) #define API_AVAILABLE_COREML7 API_AVAILABLE(macos(14), ios(17)) +#define API_AVAILABLE_COREML8 API_AVAILABLE(macos(15), ios(18)) // @available is used in implementation code // Base required OS to run CoreML Specification Version 4 (Core ML 3) @@ -47,6 +50,7 @@ #define HAS_COREML5_OR_LATER @available(macOS 12, iOS 15, *) #define HAS_COREML6_OR_LATER @available(macOS 13, iOS 16, *) #define HAS_COREML7_OR_LATER @available(macOS 14, iOS 17, *) +#define HAS_COREML8_OR_LATER @available(macOS 15, iOS 18, *) #endif diff --git a/onnxruntime/core/providers/coreml/model/host_utils.mm b/onnxruntime/core/providers/coreml/model/host_utils.mm index 70052f50ae1c2..4239121a42c97 100644 --- a/onnxruntime/core/providers/coreml/model/host_utils.mm +++ b/onnxruntime/core/providers/coreml/model/host_utils.mm @@ -16,6 +16,8 @@ bool HasRequiredBaseOS() { } int32_t CoreMLVersion() { + if (HAS_COREML8_OR_LATER) + return 8; if (HAS_COREML7_OR_LATER) return 7; if (HAS_COREML6_OR_LATER) diff --git a/onnxruntime/core/providers/coreml/model/model.h b/onnxruntime/core/providers/coreml/model/model.h index 7fdd6b25bc7db..84b7d741b4714 100644 --- a/onnxruntime/core/providers/coreml/model/model.h +++ b/onnxruntime/core/providers/coreml/model/model.h @@ -18,6 +18,7 @@ #endif namespace onnxruntime { +class CoreMLOptions; namespace coreml { class Execution; @@ -53,7 +54,7 @@ class Model { std::unordered_map&& input_output_info, std::unordered_set&& scalar_outputs, std::unordered_set&& int64_outputs, - const logging::Logger& logger, uint32_t coreml_flags); + const logging::Logger& logger, const CoreMLOptions& coreml_options); ~Model(); ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Model); diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index ff32c52f942b2..755dbfbd6e68c 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -25,6 +25,7 @@ #include "core/providers/coreml/model/host_utils.h" #include "core/providers/coreml/model/objc_str_utils.h" #include "core/providers/coreml/shape_utils.h" +#include "core/providers/coreml/coreml_options.h" // force the linker to create a dependency on the CoreML framework so that in MAUI usage we don't need // to manually do this @@ -300,6 +301,53 @@ Status GetMLMultiArrayCopyInfo(const MLMultiArray* _Nonnull array, return Status::OK(); } +// since __clang_major__ >= 15, MLComputePlan is introduced in +// We are actually ensure the MacOS/IOS version and Xcode version is greater than `macOS 14.4, iOS 17.4`. +// The macro API_AVAILABLE should also be fine. +// Otherwise, the compiler will complain `MLComputePlan` is not defined. +// we define __clang_analyzer__ here is for bypass static analysis +void ProfileComputePlan(NSURL* compileUrl, MLModelConfiguration* config) { +#if defined(__APPLE__) && defined(__clang__) && __clang_major__ >= 15 && !defined(__clang_analyzer__) + if (@available(macOS 14.4, iOS 17.4, *)) { + [MLComputePlan loadContentsOfURL:compileUrl + configuration:config + completionHandler:^(MLComputePlan* _Nullable computePlan, NSError* _Nullable error) { + if (!computePlan) { + NSLog(@"Error loading compute plan: %@", error); + // Handle error. + return; + } + MLModelStructureProgram* program = computePlan.modelStructure.program; + if (!program) { + NSLog(@"Error loading program from compute plan., this is not a mlprogram model"); + return; + } + + MLModelStructureProgramFunction* mainFunction = program.functions[@"main"]; + if (!mainFunction) { + NSLog(@"Error loading main function from program"); + return; + } + + NSArray* operations = mainFunction.block.operations; + NSLog(@"Number of operations, 'const' node is included. : %lu", operations.count); + for (MLModelStructureProgramOperation* operation in operations) { + // Get the compute device usage for the operation. + MLComputePlanDeviceUsage* computeDeviceUsage = [computePlan computeDeviceUsageForMLProgramOperation:operation]; + id preferredDevice = computeDeviceUsage.preferredComputeDevice; + // Get the estimated cost of executing the operation. + MLComputePlanCost* estimatedCost = [computePlan estimatedCostOfMLProgramOperation:operation]; + if (![operation.operatorName isEqualToString:@"const"]) { + NSLog(@"Operation: %@, Device Usage: %@, Estimated Cost: %f", operation.operatorName, preferredDevice, estimatedCost.weight); + } + } + }]; + } else { + NSLog(@"iOS 17.4+/macOS 14.4+ or later is required to use the compute plan API"); + } +#endif +} + // Internal Execution class // This class is part of the model class and handles the calls into CoreML. Specifically, it performs // 1. Compile the model by given path for execution @@ -307,7 +355,7 @@ Status GetMLMultiArrayCopyInfo(const MLMultiArray* _Nonnull array, // 3. The compiled model will be removed in dealloc or removed using cleanup function class Execution { public: - Execution(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags); + Execution(const std::string& path, const logging::Logger& logger, const CoreMLOptions& coreml_options); ~Execution(); Status LoadModel(); @@ -320,13 +368,13 @@ Status Predict(const std::unordered_map& inputs, NSString* coreml_model_path_{nil}; NSString* compiled_model_path_{nil}; const logging::Logger& logger_; - uint32_t coreml_flags_{0}; + CoreMLOptions coreml_options_; MLModel* model_{nil}; }; -Execution::Execution(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags) +Execution::Execution(const std::string& path, const logging::Logger& logger, const CoreMLOptions& coreml_options) : logger_(logger), - coreml_flags_(coreml_flags) { + coreml_options_(coreml_options) { @autoreleasepool { coreml_model_path_ = util::Utf8StringToNSString(path.c_str()); } @@ -395,15 +443,41 @@ Status Predict(const std::unordered_map& inputs, compiled_model_path_ = [compileUrl path]; MLModelConfiguration* config = [[MLModelConfiguration alloc] init]; - - if (coreml_flags_ & COREML_FLAG_USE_CPU_ONLY) { + uint32_t coreml_compute_unit = coreml_options_.ComputeUnits(); + if (coreml_compute_unit & COREML_FLAG_USE_CPU_ONLY) { config.computeUnits = MLComputeUnitsCPUOnly; - } else if (coreml_flags_ & COREML_FLAG_USE_CPU_AND_GPU) { + } else if (coreml_compute_unit & COREML_FLAG_USE_CPU_AND_GPU) { config.computeUnits = MLComputeUnitsCPUAndGPU; + } else if (coreml_compute_unit & COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) { + config.computeUnits = MLComputeUnitsCPUAndNeuralEngine; // Apple Neural Engine } else { config.computeUnits = MLComputeUnitsAll; } + if (coreml_options_.AllowLowPrecisionAccumulationOnGPU()) { + config.allowLowPrecisionAccumulationOnGPU = YES; + } + +// Set the specialization strategy to FastPrediction for macOS 10.15+ +// since __clang_major__ >= 15, optimizationHints is introduced in +// Same as above comments for why we are checking __clang_major__. +// we define __clang_analyzer__ here is for bypass static analysis +#if defined(__APPLE__) && defined(__clang__) && __clang_major__ >= 15 && !defined(__clang_analyzer__) + if (HAS_COREML8_OR_LATER) { + MLOptimizationHints* optimizationHints = [[MLOptimizationHints alloc] init]; + if (coreml_options_.UseStrategy("FastPrediction")) { + optimizationHints.specializationStrategy = MLSpecializationStrategyFastPrediction; + config.optimizationHints = optimizationHints; + } else if (coreml_options_.UseStrategy("Default")) { + optimizationHints.specializationStrategy = MLSpecializationStrategyDefault; + config.optimizationHints = optimizationHints; + } + } +#endif + if (coreml_options_.ProfileComputePlan()) { + ProfileComputePlan(compileUrl, config); + } + model_ = [MLModel modelWithContentsOfURL:compileUrl configuration:config error:&error]; if (error != nil || model_ == nil) { @@ -522,8 +596,8 @@ Status Predict(const std::unordered_map& inputs, std::unordered_set&& scalar_outputs, std::unordered_set&& int64_outputs, const logging::Logger& logger, - uint32_t coreml_flags) - : execution_(std::make_unique(path, logger, coreml_flags)), + const CoreMLOptions& coreml_options) + : execution_(std::make_unique(path, logger, coreml_options)), model_input_names_(std::move(model_input_names)), model_output_names_(std::move(model_output_names)), input_output_info_(std::move(input_output_info)), diff --git a/onnxruntime/core/providers/coreml/model/model_stub.cc b/onnxruntime/core/providers/coreml/model/model_stub.cc index c6f2e7401ea1e..e9036e2fc7e1a 100644 --- a/onnxruntime/core/providers/coreml/model/model_stub.cc +++ b/onnxruntime/core/providers/coreml/model/model_stub.cc @@ -4,6 +4,7 @@ #include "core/providers/coreml/model/model.h" namespace onnxruntime { +class CoreMLOptions; namespace coreml { class Execution {}; @@ -15,7 +16,7 @@ Model::Model(const std::string& /*path*/, std::unordered_set&& scalar_outputs, std::unordered_set&& int64_outputs, const logging::Logger& /*logger*/, - uint32_t /*coreml_flags*/) + const CoreMLOptions& /*coreml_flags*/) : execution_(std::make_unique()), model_input_names_(std::move(model_input_names)), model_output_names_(std::move(model_output_names)), diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index d57c33ae965b1..0499a15e1df0a 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -31,8 +31,7 @@ CPUExecutionProvider::CPUExecutionProvider(const CPUExecutionProviderInfo& info) : IExecutionProvider{onnxruntime::kCpuExecutionProvider}, info_{info} {} std::vector CPUExecutionProvider::CreatePreferredAllocators() { - const bool is_arena_requested = info_.create_arena; - const bool create_arena = ShouldCpuAllocatorUseArena(is_arena_requested); + const bool create_arena = DoesCpuAllocatorSupportArenaUsage() ? info_.create_arena : false; AllocatorCreationInfo device_info{[](int) { return std::make_unique(); }, DEFAULT_CPU_ALLOCATOR_DEVICE_ID, create_arena}; @@ -2926,6 +2925,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, int32_t, TreeEnsembleClassifier); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, float, TreeEnsembleRegressor); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, double, TreeEnsembleRegressor); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 5, float, TreeEnsemble); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 5, double, TreeEnsemble); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, float_string, LabelEncoder); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, string_float, LabelEncoder); @@ -3044,6 +3045,10 @@ Status RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) { TreeEnsembleRegressor)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc index 0a1a3a5995872..37db095e92570 100644 --- a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc +++ b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc @@ -51,7 +51,6 @@ class FusedConvFp16 final : public OpKernel { Status Compute(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, @@ -102,7 +101,6 @@ class FusedConvFp16 final : public OpKernel { }; Status FusedConvFp16::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/math/gemm.cc b/onnxruntime/core/providers/cpu/math/gemm.cc index dbc7becdf2397..5406dd1a40446 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.cc +++ b/onnxruntime/core/providers/cpu/math/gemm.cc @@ -248,7 +248,6 @@ template void Gemm::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE template Status Gemm::PrePack(const Tensor& /* tensor */, int /* input_idx */, AllocatorPtr /*alloc_for_caching*/, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weight_for_caching*/) { is_packed = false; @@ -257,7 +256,7 @@ Status Gemm::PrePack(const Tensor& /* tensor */, int /* input_idx */, Allocat template <> Status Gemm::PrePack(const Tensor& tensor, int input_idx, - AllocatorPtr alloc, bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, + AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/math/gemm.h b/onnxruntime/core/providers/cpu/math/gemm.h index 92f05a7921f8b..953949732560d 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.h +++ b/onnxruntime/core/providers/cpu/math/gemm.h @@ -21,7 +21,6 @@ class Gemm : protected GemmBase, public OpKernel { Status Compute(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index 8f2c2c53b188b..2c6d23e4de908 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -173,7 +173,6 @@ bool GemmPackBBfloat16(AllocatorPtr& alloc, #endif Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/math/matmul.h b/onnxruntime/core/providers/cpu/math/matmul.h index 0bb0e6c2ef596..b9bbe36583879 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.h +++ b/onnxruntime/core/providers/cpu/math/matmul.h @@ -37,7 +37,6 @@ class MatMul final : public OpKernel { } Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/providers/cpu/ml/ml_common.h b/onnxruntime/core/providers/cpu/ml/ml_common.h index 2f4ebeabe043e..3359b2a69fe83 100644 --- a/onnxruntime/core/providers/cpu/ml/ml_common.h +++ b/onnxruntime/core/providers/cpu/ml/ml_common.h @@ -20,44 +20,48 @@ enum class OUTPUT_MODE { ALL_SCORES }; -enum NODE_MODE : uint8_t { - LEAF = 1, - BRANCH_LEQ = 2, - BRANCH_LT = 4, - BRANCH_GTE = 6, - BRANCH_GT = 8, - BRANCH_EQ = 10, - BRANCH_NEQ = 12 +enum NODE_MODE_ONNX : uint8_t { + BRANCH_LEQ = 0, + BRANCH_LT = 1, + BRANCH_GTE = 2, + BRANCH_GT = 3, + BRANCH_EQ = 4, + BRANCH_NEQ = 5, + BRANCH_MEMBER = 6, + LEAF = 7, }; -static inline NODE_MODE MakeTreeNodeMode(const std::string& input) { +static inline NODE_MODE_ONNX MakeTreeNodeMode(const std::string& input) { if (input == "BRANCH_LEQ") { - return NODE_MODE::BRANCH_LEQ; + return NODE_MODE_ONNX::BRANCH_LEQ; } if (input == "LEAF") { - return NODE_MODE::LEAF; + return NODE_MODE_ONNX::LEAF; } if (input == "BRANCH_LT") { - return NODE_MODE::BRANCH_LT; + return NODE_MODE_ONNX::BRANCH_LT; } if (input == "BRANCH_GTE") { - return NODE_MODE::BRANCH_GTE; + return NODE_MODE_ONNX::BRANCH_GTE; } if (input == "BRANCH_GT") { - return NODE_MODE::BRANCH_GT; + return NODE_MODE_ONNX::BRANCH_GT; } if (input == "BRANCH_EQ") { - return NODE_MODE::BRANCH_EQ; + return NODE_MODE_ONNX::BRANCH_EQ; } - return NODE_MODE::BRANCH_NEQ; + if (input == "BRANCH_MEMBER") { + return NODE_MODE_ONNX::BRANCH_MEMBER; + } + return NODE_MODE_ONNX::BRANCH_NEQ; } -enum class POST_EVAL_TRANSFORM { - NONE, - LOGISTIC, - SOFTMAX, - SOFTMAX_ZERO, - PROBIT +enum class POST_EVAL_TRANSFORM : int64_t { + NONE = 0, + LOGISTIC = 1, + SOFTMAX = 2, + SOFTMAX_ZERO = 3, + PROBIT = 4 }; static inline POST_EVAL_TRANSFORM MakeTransform(const std::string& input) { @@ -76,11 +80,11 @@ static inline POST_EVAL_TRANSFORM MakeTransform(const std::string& input) { return POST_EVAL_TRANSFORM::PROBIT; } -enum class AGGREGATE_FUNCTION { - AVERAGE, - SUM, - MIN, - MAX +enum class AGGREGATE_FUNCTION : int64_t { + AVERAGE = 0, + SUM = 1, + MIN = 2, + MAX = 3 }; static inline AGGREGATE_FUNCTION MakeAggregateFunction(const std::string& input) { diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble.cc b/onnxruntime/core/providers/cpu/ml/tree_ensemble.cc new file mode 100644 index 0000000000000..3ff501d96b72d --- /dev/null +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble.cc @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cpu/ml/tree_ensemble.h" +#include "core/providers/cpu/ml/tree_ensemble_helper.h" +#include "core/common/inlined_containers_fwd.h" + +namespace onnxruntime { +namespace ml { + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + TreeEnsemble, + 5, + float, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()).MayInplace(0, 0), + TreeEnsemble); + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + TreeEnsemble, + 5, + double, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()).MayInplace(0, 0), + TreeEnsemble); + +template +TreeEnsemble::TreeEnsemble(const OpKernelInfo& info) : OpKernel(info) { + if constexpr (std::is_same::value) { + p_tree_ensemble_ = std::make_unique>(); + } else { + p_tree_ensemble_ = std::make_unique>(); + } + ORT_THROW_IF_ERROR(p_tree_ensemble_->Init(info)); +} + +template +Status TreeEnsemble::GetRemovableAttributes(InlinedVector& removable_attributes) const { + InlinedVector names{ + "leaf_targetids", "leaf_weights", "membership_values", "nodes_falseleafs", + "nodes_falsenodeids", "nodes_featureids", "nodes_hitrates", "nodes_missing_value_tracks_true", + "nodes_modes", "nodes_splits", "nodes_trueleafs", "nodes_truenodeids"}; + removable_attributes.swap(names); + return Status::OK(); +} + +template +common::Status TreeEnsemble::Compute(OpKernelContext* context) const { + const auto* X = context->Input(0); + if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); + if (X->Shape().NumDimensions() == 0) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Input shape needs to be at least a single dimension."); + } + int64_t N = X->Shape().NumDimensions() == 1 ? 1 : X->Shape()[0]; + Tensor* Y = context->Output(0, {N, p_tree_ensemble_->get_target_or_class_count()}); + return p_tree_ensemble_->compute(context, X, Y, NULL); +} + +} // namespace ml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble.h new file mode 100644 index 0000000000000..697aae045a7e3 --- /dev/null +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "tree_ensemble_common.h" + +namespace onnxruntime { +namespace ml { +template +class TreeEnsemble final : public OpKernel { + typedef T InputType; // input type + typedef float OutputType; // output type + public: + explicit TreeEnsemble(const OpKernelInfo& info); + common::Status Compute(OpKernelContext* context) const override; + Status GetRemovableAttributes(InlinedVector& removable_attributes) const override; + + private: + // Pointer on one instance of + // detail::TreeEnsembleCommonV5 + // where ThresholdType is defined after accessing the attributes. + std::unique_ptr p_tree_ensemble_; +}; +} // namespace ml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h index b031a6f0cefa3..bf3fd37d10f5c 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h @@ -78,6 +78,40 @@ union PtrOrWeight { } weight_data; }; +enum NODE_MODE_ORT : uint8_t { + LEAF = 1, + BRANCH_LEQ = 2, + BRANCH_LT = 4, + BRANCH_GTE = 6, + BRANCH_GT = 8, + BRANCH_EQ = 10, + BRANCH_NEQ = 12, + BRANCH_MEMBER = 14, +}; + +inline NODE_MODE_ORT Convert_NODE_MODE_ONNX_to_ORT(NODE_MODE_ONNX node_mode) { + switch (node_mode) { + case NODE_MODE_ONNX::LEAF: + return NODE_MODE_ORT::LEAF; + case NODE_MODE_ONNX::BRANCH_LEQ: + return NODE_MODE_ORT::BRANCH_LEQ; + case NODE_MODE_ONNX::BRANCH_LT: + return NODE_MODE_ORT::BRANCH_LT; + case NODE_MODE_ONNX::BRANCH_GTE: + return NODE_MODE_ORT::BRANCH_GTE; + case NODE_MODE_ONNX::BRANCH_GT: + return NODE_MODE_ORT::BRANCH_GT; + case NODE_MODE_ONNX::BRANCH_EQ: + return NODE_MODE_ORT::BRANCH_EQ; + case NODE_MODE_ONNX::BRANCH_NEQ: + return NODE_MODE_ORT::BRANCH_NEQ; + case NODE_MODE_ONNX::BRANCH_MEMBER: + return NODE_MODE_ORT::BRANCH_MEMBER; + default: + ORT_THROW("Unexpected value for node_mode"); + }; +} + template struct TreeNodeElement { int feature_id; @@ -98,10 +132,10 @@ struct TreeNodeElement { // weight in array `TreeEnsembleCommon::weights_`. If the number of targets or classes is one, the weight is also // stored in `value_or_unique_weight`. PtrOrWeight truenode_or_weight; - uint8_t flags; + NODE_MODE_ORT flags; - inline NODE_MODE mode() const { return NODE_MODE(flags & 0xF); } - inline bool is_not_leaf() const { return !(flags & NODE_MODE::LEAF); } + inline NODE_MODE_ORT mode() const { return NODE_MODE_ORT(flags & 0xF); } + inline bool is_not_leaf() const { return !(flags & NODE_MODE_ORT::LEAF); } inline bool is_missing_track_true() const { return flags & MissingTrack::kTrue; } }; diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h new file mode 100644 index 0000000000000..d2d1ba9863ac7 --- /dev/null +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h @@ -0,0 +1,321 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/inlined_containers.h" +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "ml_common.h" +#include "tree_ensemble_helper.h" +#include + +namespace onnxruntime { +namespace ml { +namespace detail { + +inline bool _isnan_(float x) { return std::isnan(x); } +inline bool _isnan_(double x) { return std::isnan(x); } +inline bool _isnan_(int64_t) { return false; } +inline bool _isnan_(int32_t) { return false; } + +template +struct TreeEnsembleAttributesV3 { + TreeEnsembleAttributesV3() {} + TreeEnsembleAttributesV3(const OpKernelInfo& info, bool classifier) { +#if !defined(ORT_MINIMAL_BUILD) + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "base_values_as_tensor", base_values_as_tensor)); + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_hitrates_as_tensor", nodes_hitrates_as_tensor)); + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_values_as_tensor", nodes_values_as_tensor)); + if (classifier) { + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "class_weights_as_tensor", target_class_weights_as_tensor)); + } else { + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "target_weights_as_tensor", target_class_weights_as_tensor)); + } +#endif + + aggregate_function = info.GetAttrOrDefault("aggregate_function", "SUM"); + base_values = info.GetAttrsOrDefault("base_values"); + nodes_falsenodeids = info.GetAttrsOrDefault("nodes_falsenodeids"); + nodes_featureids = info.GetAttrsOrDefault("nodes_featureids"); + nodes_missing_value_tracks_true = info.GetAttrsOrDefault("nodes_missing_value_tracks_true"); + + std::vector nodes_modes_string = info.GetAttrsOrDefault("nodes_modes"); + nodes_modes.reserve(nodes_modes_string.size()); + for (auto s : nodes_modes_string) { + nodes_modes.emplace_back(MakeTreeNodeMode(s)); + } + + nodes_nodeids = info.GetAttrsOrDefault("nodes_nodeids"); + nodes_treeids = info.GetAttrsOrDefault("nodes_treeids"); + nodes_truenodeids = info.GetAttrsOrDefault("nodes_truenodeids"); + nodes_values = info.GetAttrsOrDefault("nodes_values"); + post_transform = info.GetAttrOrDefault("post_transform", "NONE"); + + if (classifier) { + target_class_ids = info.GetAttrsOrDefault("class_ids"); + target_class_nodeids = info.GetAttrsOrDefault("class_nodeids"); + target_class_treeids = info.GetAttrsOrDefault("class_treeids"); + target_class_weights = info.GetAttrsOrDefault("class_weights"); + classlabels_strings = info.GetAttrsOrDefault("classlabels_strings"); + classlabels_int64s = info.GetAttrsOrDefault("classlabels_int64s"); + n_targets_or_classes = classlabels_strings.empty() ? classlabels_int64s.size() + : classlabels_strings.size(); + } else { + n_targets_or_classes = info.GetAttrOrDefault("n_targets", 0); + target_class_ids = info.GetAttrsOrDefault("target_ids"); + target_class_nodeids = info.GetAttrsOrDefault("target_nodeids"); + target_class_treeids = info.GetAttrsOrDefault("target_treeids"); + target_class_weights = info.GetAttrsOrDefault("target_weights"); + + ORT_ENFORCE(n_targets_or_classes > 0); + ORT_ENFORCE(nodes_falsenodeids.size() == nodes_featureids.size()); + ORT_ENFORCE(nodes_falsenodeids.size() == nodes_modes_string.size()); + ORT_ENFORCE(nodes_falsenodeids.size() == nodes_nodeids.size()); + ORT_ENFORCE(nodes_falsenodeids.size() == nodes_treeids.size()); + ORT_ENFORCE(nodes_falsenodeids.size() == nodes_truenodeids.size()); + ORT_ENFORCE(nodes_falsenodeids.size() == nodes_values.size() || + nodes_falsenodeids.size() == nodes_values_as_tensor.size()); + ORT_ENFORCE(target_class_ids.size() == target_class_nodeids.size()); + ORT_ENFORCE(target_class_ids.size() == target_class_treeids.size()); + ORT_ENFORCE(target_class_weights.empty() || target_class_ids.size() == target_class_weights.size()); + ORT_ENFORCE(base_values.empty() || base_values_as_tensor.empty()); + ORT_ENFORCE(nodes_hitrates.empty() || nodes_hitrates_as_tensor.empty()); + ORT_ENFORCE(nodes_values.empty() || nodes_values_as_tensor.empty()); + ORT_ENFORCE(target_class_weights.empty() || target_class_weights_as_tensor.empty()); + ORT_ENFORCE(nodes_modes_string.size() < std::numeric_limits::max()); + } + } + + std::string aggregate_function; + std::vector base_values; + std::vector base_values_as_tensor; + int64_t n_targets_or_classes; + std::vector nodes_falsenodeids; + std::vector nodes_featureids; + std::vector nodes_hitrates; + std::vector nodes_hitrates_as_tensor; + std::vector nodes_missing_value_tracks_true; + std::vector nodes_modes; + std::vector nodes_nodeids; + std::vector nodes_treeids; + std::vector nodes_truenodeids; + std::vector nodes_values; + std::vector nodes_values_as_tensor; + std::string post_transform; + std::vector target_class_ids; + std::vector target_class_nodeids; + std::vector target_class_treeids; + std::vector target_class_weights; + std::vector target_class_weights_as_tensor; + std::vector classlabels_strings; + std::vector classlabels_int64s; + std::vector class_labels; +}; + +template +struct TreeEnsembleAttributesV5 { + TreeEnsembleAttributesV5() {} + TreeEnsembleAttributesV5(const OpKernelInfo& info) { +#if !defined(ORT_MINIMAL_BUILD) + std::vector nodes_modes_i; + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "leaf_weights", leaf_weights)); + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "membership_values", membership_values)); + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_hitrates", nodes_hitrates)); + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_modes", nodes_modes_i)); + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_splits", nodes_splits)); + nodes_modes.reserve(nodes_modes.size()); + for (auto i : nodes_modes_i) { + nodes_modes.push_back(static_cast(i)); + } +#else + // GetVectorAttrsOrDefault is not part of the minimal build. + // As a result, TreeEnsemble v5 cannot be available in this build. + ORT_THROW("TreeEnsemble(ai.onnx.ml==5) is not supported with the minimal build."); +#endif + + aggregate_function = info.GetAttrOrDefault("aggregate_function", 1); + leaf_targetids = info.GetAttrsOrDefault("leaf_targetids"); + n_targets = info.GetAttrOrDefault("n_targets", 0); + nodes_falseleafs = info.GetAttrsOrDefault("nodes_falseleafs"); + nodes_falsenodeids = info.GetAttrsOrDefault("nodes_falsenodeids"); + nodes_featureids = info.GetAttrsOrDefault("nodes_featureids"); + nodes_missing_value_tracks_true = info.GetAttrsOrDefault("nodes_missing_value_tracks_true"); + nodes_trueleafs = info.GetAttrsOrDefault("nodes_trueleafs"); + nodes_truenodeids = info.GetAttrsOrDefault("nodes_truenodeids"); + post_transform = info.GetAttrOrDefault("post_transform", 0); + tree_roots = info.GetAttrsOrDefault("tree_roots"); + } + + void convert_to_v3(TreeEnsembleAttributesV3& output) const { + // Doing all transformations to get the old format. + output.n_targets_or_classes = n_targets; + output.aggregate_function = aggregateFunctionToString(); + output.post_transform = postTransformToString(); + std::vector> membership_values_by_id; + getMembershipValuesById(membership_values_by_id); + transformInputAllTrees(output, membership_values_by_id); + } + + int64_t aggregate_function; + std::vector leaf_targetids; + std::vector leaf_weights; + std::vector membership_values; + int64_t n_targets; + std::vector nodes_falseleafs; + std::vector nodes_falsenodeids; + std::vector nodes_featureids; + std::vector nodes_hitrates; + std::vector nodes_missing_value_tracks_true; + std::vector nodes_modes; + std::vector nodes_splits; + std::vector nodes_trueleafs; + std::vector nodes_truenodeids; + int64_t post_transform; + std::vector tree_roots; + + private: + // `membership_values` are seperated by NAN for different nodes + // It is more convenient to preserve the values for each node in a vector + // The vector would be empty for nodes that are not `BRANCH_MEMBER` + void getMembershipValuesById(std::vector>& membership_values_by_id) const { + membership_values_by_id.clear(); + membership_values_by_id.reserve(nodes_modes.size()); + + size_t curr_id = 0; + for (const auto node_mode : nodes_modes) { + membership_values_by_id.emplace_back(); + if (node_mode != NODE_MODE_ONNX::BRANCH_MEMBER) { + continue; + } + + while (curr_id < membership_values.size() && !_isnan_(membership_values[curr_id])) { + membership_values_by_id.back().push_back(membership_values[curr_id++]); + } + curr_id++; + } + } + + std::string aggregateFunctionToString() const { + switch (aggregate_function) { + case static_cast(AGGREGATE_FUNCTION::AVERAGE): + return "AVERAGE"; + case static_cast(AGGREGATE_FUNCTION::SUM): + return "SUM"; + case static_cast(AGGREGATE_FUNCTION::MIN): + return "MIN"; + case static_cast(AGGREGATE_FUNCTION::MAX): + return "MAX"; + default: + ORT_THROW("Unknown value for aggregate_function."); + } + } + + std::string postTransformToString() const { + switch (post_transform) { + case static_cast(POST_EVAL_TRANSFORM::NONE): + return "NONE"; + case static_cast(POST_EVAL_TRANSFORM::SOFTMAX): + return "SOFTMAX"; + case static_cast(POST_EVAL_TRANSFORM::LOGISTIC): + return "LOGISTIC"; + case static_cast(POST_EVAL_TRANSFORM::SOFTMAX_ZERO): + return "SOFTMAX_ZERO"; + case static_cast(POST_EVAL_TRANSFORM::PROBIT): + return "PROBIT"; + default: + ORT_THROW("Unknown value for post_transform."); + } + } + + int64_t transformInputOneTree( + const size_t curr_id, const int64_t curr_treeid, const int64_t curr_nodeid, const size_t curr_membership_value_id, + const bool is_leaf, std::vector>& membership_values_by_id, + TreeEnsembleAttributesV3& output) const { + output.nodes_nodeids.push_back(curr_nodeid); + output.nodes_treeids.push_back(curr_treeid); + + if (is_leaf) { + output.nodes_modes.push_back(NODE_MODE_ONNX::LEAF); + output.target_class_ids.push_back(leaf_targetids[curr_id]); + output.target_class_nodeids.push_back(curr_nodeid); + output.target_class_treeids.push_back(curr_treeid); + output.target_class_weights_as_tensor.push_back(leaf_weights[curr_id]); + + // the below are irrelevant for a `LEAF` + output.nodes_featureids.push_back(0); + output.nodes_truenodeids.push_back(0); + output.nodes_falsenodeids.push_back(0); + output.nodes_values_as_tensor.push_back(0); + if (!nodes_hitrates.empty()) { + output.nodes_hitrates.push_back(0); + } + if (!nodes_missing_value_tracks_true.empty()) { + output.nodes_missing_value_tracks_true.push_back(0); + } + + return curr_nodeid; + } + + output.nodes_featureids.push_back(nodes_featureids[curr_id]); + if (!nodes_hitrates.empty()) { + output.nodes_hitrates_as_tensor.push_back(nodes_hitrates[curr_id]); + } + if (!nodes_missing_value_tracks_true.empty()) { + output.nodes_missing_value_tracks_true.push_back(nodes_missing_value_tracks_true[curr_id]); + } + + // unroll `BRANCH_MEMBER` to a chain of `BRANCH_EQ` + if (nodes_modes[curr_id] == NODE_MODE_ONNX::BRANCH_MEMBER) { + output.nodes_modes.push_back(NODE_MODE_ONNX::BRANCH_EQ); + output.nodes_values_as_tensor.push_back(membership_values_by_id[curr_id][curr_membership_value_id]); + } else { + output.nodes_modes.push_back(nodes_modes[curr_id]); + output.nodes_values_as_tensor.push_back(nodes_splits[curr_id]); + } + + size_t falsenodeid_id = output.nodes_falsenodeids.size(); + output.nodes_falsenodeids.push_back(0); // change after pushing truenode subtree + + int64_t true_nodeid = curr_nodeid + 1; + output.nodes_truenodeids.push_back(true_nodeid); + true_nodeid = transformInputOneTree(onnxruntime::narrow(nodes_truenodeids[curr_id]), + curr_treeid, true_nodeid, 0U, nodes_trueleafs[curr_id] != 0, + membership_values_by_id, output); + + int64_t false_nodeid = true_nodeid + 1; + output.nodes_falsenodeids[falsenodeid_id] = false_nodeid; + + // if node is `BRANCH_MEMBER` we are unrolling the `membership_values` for that node + // therefore if the value is not the last, the `falsenode_id` must be pointing to the "same" node with a different membership value + // so in that case we are only moving the pointer for `membership_values` + // + // otherwise, the `falsenode_id` is pointing to the real falsenode subtree + if (nodes_modes[curr_id] == NODE_MODE_ONNX::BRANCH_MEMBER && + curr_membership_value_id + 1 < membership_values_by_id[curr_id].size()) { + false_nodeid = transformInputOneTree(curr_id, curr_treeid, false_nodeid, curr_membership_value_id + 1, false, + membership_values_by_id, output); + } else { + false_nodeid = transformInputOneTree(onnxruntime::narrow(nodes_falsenodeids[curr_id]), + curr_treeid, false_nodeid, 0U, nodes_falseleafs[curr_id] != 0, + membership_values_by_id, output); + } + return false_nodeid; + } + + void transformInputAllTrees(TreeEnsembleAttributesV3& output, + std::vector>& membership_values_by_id) const { + int64_t curr_treeid = 0; + for (const int64_t& tree_root : tree_roots) { + size_t tree_root_size_t = onnxruntime::narrow(tree_root); + transformInputOneTree(tree_root_size_t, curr_treeid, 0, 0U, + nodes_falsenodeids[tree_root_size_t] == nodes_truenodeids[tree_root_size_t], + membership_values_by_id, output); + curr_treeid++; + } + } +}; + +} // namespace detail +} // namespace ml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h index 94f79518ae8da..10d4db0e0e3b0 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h @@ -3,15 +3,21 @@ #pragma once -#include "tree_ensemble_aggregator.h" #include #include "core/platform/threadpool.h" #include "tree_ensemble_helper.h" +#include "tree_ensemble_attribute.h" +#include "tree_ensemble_aggregator.h" namespace onnxruntime { namespace ml { namespace detail { +/** + * These attributes are the kernel attributes. They are different from the onnx operator attributes + * to improve the computation efficiency. The initialization consists in moving the onnx attributes + * into the kernel attributes. + */ class TreeEnsembleCommonAttributes { public: int64_t get_target_or_class_count() const { return this->n_targets_or_classes_; } @@ -57,27 +63,7 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes { Status Init(int parallel_tree, int parallel_tree_N, int parallel_N, - const std::string& aggregate_function, - const std::vector& base_values, - const std::vector& base_values_as_tensor, - int64_t n_targets_or_classes, - const std::vector& nodes_falsenodeids, - const std::vector& nodes_featureids, - const std::vector& nodes_hitrates, - const std::vector& nodes_hitrates_as_tensor, - const std::vector& nodes_missing_value_tracks_true, - const std::vector& nodes_modes, - const std::vector& nodes_nodeids, - const std::vector& nodes_treeids, - const std::vector& nodes_truenodeids, - const std::vector& nodes_values, - const std::vector& nodes_values_as_tensor, - const std::string& post_transform, - const std::vector& target_class_ids, - const std::vector& target_class_nodeids, - const std::vector& target_class_treeids, - const std::vector& target_class_weights, - const std::vector& target_class_weights_as_tensor); + const TreeEnsembleAttributesV3& attributes); protected: TreeNodeElement* ProcessTreeNodeLeave(TreeNodeElement* root, @@ -87,49 +73,52 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes { void ComputeAgg(concurrency::ThreadPool* ttp, const Tensor* X, Tensor* Y, Tensor* label, const AGG& agg) const; private: - size_t AddNodes(const size_t i, const InlinedVector& cmodes, const InlinedVector& truenode_ids, - const InlinedVector& falsenode_ids, const std::vector& nodes_featureids, - const std::vector& nodes_values_as_tensor, const std::vector& node_values, - const std::vector& nodes_missing_value_tracks_true, std::vector& updated_mapping, - int64_t tree_id, const InlinedVector& node_tree_ids); + bool CheckIfSubtreesAreEqual(const size_t left_id, const size_t right_id, const int64_t tree_id, const InlinedVector& cmodes, + const InlinedVector& truenode_ids, const InlinedVector& falsenode_ids, gsl::span nodes_featureids, + gsl::span nodes_values_as_tensor, gsl::span node_values, + gsl::span target_class_weights, gsl::span target_class_weights_as_tensor, + const InlinedVector& node_tree_ids, InlinedVector> indices); + size_t AddNodes(const size_t i, const InlinedVector& cmodes, const InlinedVector& truenode_ids, + const InlinedVector& falsenode_ids, gsl::span nodes_featureids, + gsl::span nodes_values_as_tensor, gsl::span node_values, + gsl::span nodes_missing_value_tracks_true, std::vector& updated_mapping, + int64_t tree_id, const InlinedVector& node_tree_ids, gsl::span target_class_weights, + gsl::span target_class_weights_as_tensor, InlinedVector>& indices); }; +// Below is simple implementation of `bit_cast` as it is supported from c++20 and the current supported version is c++17 +// Remove it when that is not the case +template +std::enable_if_t< + sizeof(To) == sizeof(From) && + std::is_trivially_copyable_v && + std::is_trivially_copyable_v, + To> + // constexpr support needs compiler magic + static bit_cast(const From& src) noexcept { + static_assert(std::is_trivially_constructible_v, + "This implementation additionally requires " + "destination type to be trivially constructible"); + + To dst; + std::memcpy(&dst, &src, sizeof(To)); + return dst; +} + +template +std::conditional_t bit_cast_int(T val) { + if constexpr (sizeof(T) == sizeof(uint32_t)) { + return bit_cast(val); + } else if constexpr (sizeof(T) == sizeof(uint64_t)) { + return bit_cast(val); + } + static_assert(sizeof(T) == sizeof(uint32_t) || sizeof(T) == sizeof(uint64_t)); +} + template Status TreeEnsembleCommon::Init(const OpKernelInfo& info) { - std::vector base_values_as_tensor, nodes_hitrates_as_tensor, - nodes_values_as_tensor, target_weights_as_tensor; -#if !defined(ORT_MINIMAL_BUILD) - ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "base_values_as_tensor", base_values_as_tensor)); - ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_hitrates_as_tensor", nodes_hitrates_as_tensor)); - ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_values_as_tensor", nodes_values_as_tensor)); - ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "target_weights_as_tensor", target_weights_as_tensor)); -#endif - - return Init( - 80, - 128, - 50, - info.GetAttrOrDefault("aggregate_function", "SUM"), - info.GetAttrsOrDefault("base_values"), - base_values_as_tensor, - info.GetAttrOrDefault("n_targets", 0), - info.GetAttrsOrDefault("nodes_falsenodeids"), - info.GetAttrsOrDefault("nodes_featureids"), - info.GetAttrsOrDefault("nodes_hitrates"), - nodes_hitrates_as_tensor, - info.GetAttrsOrDefault("nodes_missing_value_tracks_true"), - info.GetAttrsOrDefault("nodes_modes"), - info.GetAttrsOrDefault("nodes_nodeids"), - info.GetAttrsOrDefault("nodes_treeids"), - info.GetAttrsOrDefault("nodes_truenodeids"), - info.GetAttrsOrDefault("nodes_values"), - nodes_values_as_tensor, - info.GetAttrOrDefault("post_transform", "NONE"), - info.GetAttrsOrDefault("target_ids"), - info.GetAttrsOrDefault("target_nodeids"), - info.GetAttrsOrDefault("target_treeids"), - info.GetAttrsOrDefault("target_weights"), - target_weights_as_tensor); + TreeEnsembleAttributesV3 attributes(info, false); + return Init(80, 128, 50, attributes); } template @@ -137,72 +126,35 @@ Status TreeEnsembleCommon::Init( int parallel_tree, int parallel_tree_N, int parallel_N, - const std::string& aggregate_function, - const std::vector& base_values, - const std::vector& base_values_as_tensor, - int64_t n_targets_or_classes, - const std::vector& nodes_falsenodeids, - const std::vector& nodes_featureids, - const std::vector& nodes_hitrates, - const std::vector& nodes_hitrates_as_tensor, - const std::vector& nodes_missing_value_tracks_true, - const std::vector& nodes_modes, - const std::vector& nodes_nodeids, - const std::vector& nodes_treeids, - const std::vector& nodes_truenodeids, - const std::vector& nodes_values, - const std::vector& nodes_values_as_tensor, - const std::string& post_transform, - const std::vector& target_class_ids, - const std::vector& target_class_nodeids, - const std::vector& target_class_treeids, - const std::vector& target_class_weights, - const std::vector& target_class_weights_as_tensor) { + const TreeEnsembleAttributesV3& attributes) { parallel_tree_ = parallel_tree; parallel_tree_N_ = parallel_tree_N; parallel_N_ = parallel_N; - ORT_ENFORCE(n_targets_or_classes > 0); - ORT_ENFORCE(nodes_falsenodeids.size() == nodes_featureids.size()); - ORT_ENFORCE(nodes_falsenodeids.size() == nodes_modes.size()); - ORT_ENFORCE(nodes_falsenodeids.size() == nodes_nodeids.size()); - ORT_ENFORCE(nodes_falsenodeids.size() == nodes_treeids.size()); - ORT_ENFORCE(nodes_falsenodeids.size() == nodes_truenodeids.size()); - ORT_ENFORCE(nodes_falsenodeids.size() == nodes_values.size() || - nodes_falsenodeids.size() == nodes_values_as_tensor.size()); - ORT_ENFORCE(target_class_ids.size() == target_class_nodeids.size()); - ORT_ENFORCE(target_class_ids.size() == target_class_treeids.size()); - ORT_ENFORCE(target_class_weights.empty() || target_class_ids.size() == target_class_weights.size()); - ORT_ENFORCE(base_values.empty() || base_values_as_tensor.empty()); - ORT_ENFORCE(nodes_hitrates.empty() || nodes_hitrates_as_tensor.empty()); - ORT_ENFORCE(nodes_values.empty() || nodes_values_as_tensor.empty()); - ORT_ENFORCE(target_class_weights.empty() || target_class_weights_as_tensor.empty()); - - aggregate_function_ = MakeAggregateFunction(aggregate_function); - post_transform_ = MakeTransform(post_transform); - if (!base_values_as_tensor.empty()) { - ORT_ENFORCE(base_values.empty()); - base_values_ = base_values_as_tensor; + aggregate_function_ = MakeAggregateFunction(attributes.aggregate_function); + post_transform_ = MakeTransform(attributes.post_transform); + if (!attributes.base_values_as_tensor.empty()) { + ORT_ENFORCE(attributes.base_values.empty()); + base_values_ = attributes.base_values_as_tensor; } else { - base_values_.reserve(base_values.size()); - for (size_t i = 0, limit = base_values.size(); i < limit; ++i) { - base_values_.push_back(static_cast(base_values[i])); + base_values_.reserve(attributes.base_values.size()); + for (size_t i = 0, limit = attributes.base_values.size(); i < limit; ++i) { + base_values_.push_back(static_cast(attributes.base_values[i])); } } - n_targets_or_classes_ = n_targets_or_classes; + n_targets_or_classes_ = attributes.n_targets_or_classes; max_tree_depth_ = 1000; - ORT_ENFORCE(nodes_modes.size() < std::numeric_limits::max()); // Additional members size_t limit; uint32_t i; - InlinedVector cmodes; - cmodes.reserve(nodes_modes.size()); + InlinedVector cmodes; + cmodes.reserve(attributes.nodes_modes.size()); same_mode_ = true; int fpos = -1; - for (i = 0, limit = nodes_modes.size(); i < limit; ++i) { - cmodes.push_back(MakeTreeNodeMode(nodes_modes[i])); - if (cmodes[i] == NODE_MODE::LEAF) continue; + for (i = 0, limit = attributes.nodes_modes.size(); i < limit; ++i) { + cmodes.push_back(attributes.nodes_modes[i]); + if (cmodes[i] == NODE_MODE_ONNX::LEAF) continue; if (fpos == -1) { fpos = static_cast(i); continue; @@ -210,7 +162,7 @@ Status TreeEnsembleCommon::Init( if (cmodes[i] != cmodes[fpos]) same_mode_ = false; } - n_nodes_ = nodes_treeids.size(); + n_nodes_ = attributes.nodes_treeids.size(); limit = static_cast(n_nodes_); InlinedVector node_tree_ids; node_tree_ids.reserve(limit); @@ -227,7 +179,7 @@ Status TreeEnsembleCommon::Init( // Build node_tree_ids and node_tree_ids_map and truenode_ids and falsenode_ids for (i = 0; i < limit; ++i) { - TreeNodeElementId node_tree_id{static_cast(nodes_treeids[i]), static_cast(nodes_nodeids[i])}; + TreeNodeElementId node_tree_id{static_cast(attributes.nodes_treeids[i]), static_cast(attributes.nodes_nodeids[i])}; auto p = node_tree_ids_map.insert(std::pair(node_tree_id, i)); if (!p.second) { ORT_THROW("Node ", node_tree_id.node_id, " in tree ", node_tree_id.tree_id, " is already there."); @@ -237,13 +189,13 @@ Status TreeEnsembleCommon::Init( TreeNodeElementId coor; for (i = 0; i < limit; ++i) { - if (cmodes[i] == NODE_MODE::LEAF) { + if (cmodes[i] == NODE_MODE_ONNX::LEAF) { truenode_ids.push_back(0); falsenode_ids.push_back(0); } else { TreeNodeElementId& node_tree_id = node_tree_ids[i]; coor.tree_id = node_tree_id.tree_id; - coor.node_id = static_cast(nodes_truenodeids[i]); + coor.node_id = static_cast(attributes.nodes_truenodeids[i]); ORT_ENFORCE((coor.node_id >= 0 && coor.node_id < n_nodes_)); auto found = node_tree_ids_map.find(coor); @@ -255,7 +207,7 @@ Status TreeEnsembleCommon::Init( } truenode_ids.emplace_back(found->second); - coor.node_id = static_cast(nodes_falsenodeids[i]); + coor.node_id = static_cast(attributes.nodes_falsenodeids[i]); ORT_ENFORCE((coor.node_id >= 0 && coor.node_id < n_nodes_)); found = node_tree_ids_map.find(coor); if (found == node_tree_ids_map.end()) { @@ -270,41 +222,38 @@ Status TreeEnsembleCommon::Init( } } + // Sort targets + InlinedVector> indices; + indices.reserve(attributes.target_class_nodeids.size()); + for (i = 0, limit = attributes.target_class_nodeids.size(); i < limit; i++) { + indices.emplace_back( + TreeNodeElementId{attributes.target_class_treeids[i], attributes.target_class_nodeids[i]}, i); + } + + std::sort(indices.begin(), indices.end()); + // Let's construct nodes_ such that the false branch is always the next element in nodes_. // updated_mapping will translates the old position of each node to the new node position in nodes_. - std::vector updated_mapping(nodes_treeids.size(), 0); + std::vector updated_mapping(attributes.nodes_treeids.size(), 0); int64_t previous_tree_id = -1; for (i = 0; i < n_nodes_; ++i) { if (previous_tree_id == -1 || (previous_tree_id != node_tree_ids[i].tree_id)) { // New tree. int64_t tree_id = node_tree_ids[i].tree_id; size_t root_position = - AddNodes(i, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, nodes_values, - nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids); + AddNodes(i, cmodes, truenode_ids, falsenode_ids, attributes.nodes_featureids, attributes.nodes_values_as_tensor, attributes.nodes_values, + attributes.nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids, + attributes.target_class_weights, attributes.target_class_weights_as_tensor, indices); roots_.push_back(&nodes_[root_position]); previous_tree_id = tree_id; } } - n_trees_ = roots_.size(); - if (((int64_t)nodes_.size()) != n_nodes_) { - ORT_THROW("Number of nodes in nodes_ (", nodes_.size(), ") is different from n_nodes (", n_nodes_, ")."); - } - - // Sort targets - InlinedVector> indices; - indices.reserve(target_class_nodeids.size()); - for (i = 0, limit = target_class_nodeids.size(); i < limit; i++) { - indices.emplace_back( - std::pair(TreeNodeElementId{target_class_treeids[i], target_class_nodeids[i]}, i)); - } - - std::sort(indices.begin(), indices.end()); TreeNodeElementId ind; SparseValue w; size_t indi; - for (indi = 0, limit = target_class_nodeids.size(); indi < limit; ++indi) { + for (indi = 0, limit = attributes.target_class_nodeids.size(); indi < limit; ++indi) { ind = indices[indi].first; i = indices[indi].second; auto found = node_tree_ids_map.find(ind); @@ -319,9 +268,10 @@ Status TreeEnsembleCommon::Init( // ORT_THROW("Node ", ind.tree_id, "-", ind.node_id, " is not a leaf."); continue; } - w.i = target_class_ids[i]; - w.value = target_class_weights_as_tensor.empty() ? static_cast(target_class_weights[i]) - : target_class_weights_as_tensor[i]; + w.i = attributes.target_class_ids[i]; + w.value = attributes.target_class_weights_as_tensor.empty() + ? static_cast(attributes.target_class_weights[i]) + : attributes.target_class_weights_as_tensor[i]; if (leaf.truenode_or_weight.weight_data.n_weights == 0) { leaf.truenode_or_weight.weight_data.weight = static_cast(weights_.size()); leaf.value_or_unique_weight = w.value; @@ -331,7 +281,7 @@ Status TreeEnsembleCommon::Init( } has_missing_tracks_ = false; - for (auto itm = nodes_missing_value_tracks_true.begin(); itm != nodes_missing_value_tracks_true.end(); ++itm) { + for (auto itm = attributes.nodes_missing_value_tracks_true.begin(); itm != attributes.nodes_missing_value_tracks_true.end(); ++itm) { if (*itm) { has_missing_tracks_ = true; break; @@ -341,13 +291,58 @@ Status TreeEnsembleCommon::Init( return Status::OK(); } +template +bool TreeEnsembleCommon::CheckIfSubtreesAreEqual( + const size_t left_id, const size_t right_id, const int64_t tree_id, const InlinedVector& cmodes, + const InlinedVector& truenode_ids, const InlinedVector& falsenode_ids, gsl::span nodes_featureids, + gsl::span nodes_values_as_tensor, gsl::span node_values, + gsl::span target_class_weights, gsl::span target_class_weights_as_tensor, + const InlinedVector& node_tree_ids, InlinedVector> indices) { + // Leaves have values set at 0 + if (cmodes[left_id] != cmodes[right_id] || nodes_featureids[left_id] != nodes_featureids[right_id] || + (!nodes_values_as_tensor.empty() && nodes_values_as_tensor[left_id] != nodes_values_as_tensor[right_id]) || + (nodes_values_as_tensor.empty() && node_values[left_id] != node_values[right_id])) { + return false; + } + + if (cmodes[left_id] == NODE_MODE_ONNX::LEAF) { + const auto left_target_node = std::lower_bound(indices.begin(), indices.end(), std::make_pair(node_tree_ids[left_id], uint32_t(0)))->second; + const auto right_target_node = std::lower_bound(indices.begin(), indices.end(), std::make_pair(node_tree_ids[right_id], uint32_t(0)))->second; + + if (target_class_weights_as_tensor.empty()) { + return target_class_weights[left_target_node] == target_class_weights[right_target_node]; + } else { + return target_class_weights_as_tensor[left_target_node] == target_class_weights_as_tensor[right_target_node]; + } + } + + return CheckIfSubtreesAreEqual(falsenode_ids[left_id], falsenode_ids[right_id], tree_id, cmodes, truenode_ids, falsenode_ids, nodes_featureids, + nodes_values_as_tensor, node_values, target_class_weights, target_class_weights_as_tensor, node_tree_ids, indices) && + CheckIfSubtreesAreEqual(truenode_ids[left_id], truenode_ids[right_id], tree_id, cmodes, truenode_ids, falsenode_ids, nodes_featureids, + nodes_values_as_tensor, node_values, target_class_weights, target_class_weights_as_tensor, node_tree_ids, indices); +} + +inline void UpdateThreshold(double val, double& mask) { + uint64_t new_mask = bit_cast(mask) | (1ll << (static_cast(val) - 1)); + mask = bit_cast(new_mask); +} + +inline void UpdateThreshold(float val, float& mask) { + uint32_t new_mask = bit_cast(mask) | (1 << (static_cast(val) - 1)); + mask = bit_cast(new_mask); +} + +#define BITCOUNT(T) int64_t(sizeof(T) * 8) +#define CANMASK(v, T) (v >= 1 && v <= BITCOUNT(T)) && v == std::floor(v) + template size_t TreeEnsembleCommon::AddNodes( - const size_t i, const InlinedVector& cmodes, const InlinedVector& truenode_ids, - const InlinedVector& falsenode_ids, const std::vector& nodes_featureids, - const std::vector& nodes_values_as_tensor, const std::vector& node_values, - const std::vector& nodes_missing_value_tracks_true, std::vector& updated_mapping, int64_t tree_id, - const InlinedVector& node_tree_ids) { + const size_t i, const InlinedVector& cmodes, const InlinedVector& truenode_ids, + const InlinedVector& falsenode_ids, gsl::span nodes_featureids, + gsl::span nodes_values_as_tensor, gsl::span node_values, + gsl::span nodes_missing_value_tracks_true, std::vector& updated_mapping, int64_t tree_id, + const InlinedVector& node_tree_ids, gsl::span target_class_weights, + gsl::span target_class_weights_as_tensor, InlinedVector>& indices) { // Validate this index maps to the same tree_id as the one we should be building. if (node_tree_ids[i].tree_id != tree_id) { ORT_THROW("Tree id mismatch. Expected ", tree_id, " but got ", node_tree_ids[i].tree_id, " at position ", i); @@ -364,28 +359,59 @@ size_t TreeEnsembleCommon::AddNodes( updated_mapping[i] = node_pos; TreeNodeElement node; - node.flags = static_cast(cmodes[i]); + node.flags = Convert_NODE_MODE_ONNX_to_ORT(cmodes[i]); node.feature_id = static_cast(nodes_featureids[i]); if (node.feature_id > max_feature_id_) { max_feature_id_ = node.feature_id; } - node.value_or_unique_weight = - nodes_values_as_tensor.empty() ? static_cast(node_values[i]) : nodes_values_as_tensor[i]; + + node.value_or_unique_weight = 0; + const ThresholdType node_threshold = nodes_values_as_tensor.empty() ? static_cast(node_values[i]) : nodes_values_as_tensor[i]; + if (node.flags == NODE_MODE_ORT::BRANCH_EQ && CANMASK(node_threshold, ThresholdType)) { + UpdateThreshold(node_threshold, node.value_or_unique_weight); + node.flags = NODE_MODE_ORT::BRANCH_MEMBER; + } else { + node.value_or_unique_weight = node_threshold; + } + if (i < static_cast(nodes_missing_value_tracks_true.size()) && nodes_missing_value_tracks_true[i] == 1) { - node.flags |= static_cast(MissingTrack::kTrue); + node.flags = static_cast(static_cast(node.flags) | static_cast(MissingTrack::kTrue)); } nodes_.push_back(std::move(node)); if (nodes_[node_pos].is_not_leaf()) { + size_t falsenode_id = falsenode_ids[i]; + + // Categoricals are represented as a chain of `EQ` nodes where the subtree for the true child is identical for all nodes in the chain + // Below we are folding together these nodes into one of mode `BRANCH_MEMBER` + // The threshold of this node should be interpreted as a bitmask showing which categoricals values were found in the chain + // Afterwards, when looking whether a feature is included we can do an `and` with the mask of the node + // and the one of the feature (the mask has only one bit set on the place for its value) + // Beware that if a category is bigger than the threshold type, the node stays as `EQ` and no combination is done + if (nodes_[node_pos].flags == NODE_MODE_ORT::BRANCH_MEMBER) { + ThresholdType falsenode_threshold = nodes_values_as_tensor.empty() ? static_cast(node_values[falsenode_id]) : nodes_values_as_tensor[falsenode_id]; + + while (cmodes[falsenode_id] == NODE_MODE_ONNX::BRANCH_EQ && nodes_[node_pos].feature_id == nodes_featureids[falsenode_id] && + CANMASK(falsenode_threshold, ThresholdType) && + CheckIfSubtreesAreEqual(truenode_ids[i], truenode_ids[falsenode_id], tree_id, cmodes, truenode_ids, falsenode_ids, + nodes_featureids, nodes_values_as_tensor, node_values, target_class_weights, target_class_weights_as_tensor, node_tree_ids, indices)) { + UpdateThreshold(falsenode_threshold, nodes_[node_pos].value_or_unique_weight); + falsenode_id = falsenode_ids[falsenode_id]; + falsenode_threshold = nodes_values_as_tensor.empty() ? static_cast(node_values[falsenode_id]) : nodes_values_as_tensor[falsenode_id]; + } + } + size_t false_branch = - AddNodes(falsenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, - node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids); + AddNodes(falsenode_id, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, + node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids, + target_class_weights, target_class_weights_as_tensor, indices); if (false_branch != node_pos + 1) { ORT_THROW("False node must always be the next node, but it isn't at index ", node_pos, " with flags ", static_cast(nodes_[node_pos].flags)); } size_t true_branch = AddNodes(truenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, - node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids); + node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids, + target_class_weights, target_class_weights_as_tensor, indices); // We don't need to store the false branch pointer since we know it is always in the immediate next entry in nodes_. // nodes_[node_pos].falsenode_inc_or_n_weights.ptr = &nodes_[false_branch]; nodes_[node_pos].truenode_or_weight.ptr = &nodes_[true_branch]; @@ -684,10 +710,12 @@ void TreeEnsembleCommon::ComputeAgg(concur } \ } -inline bool _isnan_(float x) { return std::isnan(x); } -inline bool _isnan_(double x) { return std::isnan(x); } -inline bool _isnan_(int64_t) { return false; } -inline bool _isnan_(int32_t) { return false; } +// Check whether the feature value is set true in the mask +template +inline bool SetMembershipCheck(T1 val, T2 mask) { + const int64_t val_as_int = static_cast(val); + return CANMASK(val, T2) && (((1ll << (val_as_int - 1)) & bit_cast_int(mask)) != 0); +} template TreeNodeElement* @@ -696,7 +724,7 @@ TreeEnsembleCommon::ProcessTreeNodeLeave( InputType val; if (same_mode_) { switch (root->mode()) { - case NODE_MODE::BRANCH_LEQ: + case NODE_MODE_ORT::BRANCH_LEQ: if (has_missing_tracks_) { while (root->is_not_leaf()) { val = x_data[root->feature_id]; @@ -711,22 +739,36 @@ TreeEnsembleCommon::ProcessTreeNodeLeave( } } break; - case NODE_MODE::BRANCH_LT: + case NODE_MODE_ORT::BRANCH_LT: TREE_FIND_VALUE(<) break; - case NODE_MODE::BRANCH_GTE: + case NODE_MODE_ORT::BRANCH_GTE: TREE_FIND_VALUE(>=) break; - case NODE_MODE::BRANCH_GT: + case NODE_MODE_ORT::BRANCH_GT: TREE_FIND_VALUE(>) break; - case NODE_MODE::BRANCH_EQ: + case NODE_MODE_ORT::BRANCH_EQ: TREE_FIND_VALUE(==) break; - case NODE_MODE::BRANCH_NEQ: + case NODE_MODE_ORT::BRANCH_NEQ: TREE_FIND_VALUE(!=) break; - case NODE_MODE::LEAF: + case NODE_MODE_ORT::BRANCH_MEMBER: + if (has_missing_tracks_) { + while (root->is_not_leaf()) { + val = x_data[root->feature_id]; + root = (SetMembershipCheck(val, root->value_or_unique_weight) || (root->is_missing_track_true() && _isnan_(val))) + ? root->truenode_or_weight.ptr + : root + 1; + } + } else { + while (root->is_not_leaf()) { + val = x_data[root->feature_id]; + root = SetMembershipCheck(val, root->value_or_unique_weight) ? root->truenode_or_weight.ptr : root + 1; + } + } + case NODE_MODE_ORT::LEAF: break; } } else { // Different rules to compare to node thresholds. @@ -735,31 +777,36 @@ TreeEnsembleCommon::ProcessTreeNodeLeave( val = x_data[root->feature_id]; threshold = root->value_or_unique_weight; switch (root->mode()) { - case NODE_MODE::BRANCH_LEQ: + case NODE_MODE_ORT::BRANCH_LEQ: root = val <= threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr : root + 1; break; - case NODE_MODE::BRANCH_LT: + case NODE_MODE_ORT::BRANCH_LT: root = val < threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr : root + 1; break; - case NODE_MODE::BRANCH_GTE: + case NODE_MODE_ORT::BRANCH_GTE: root = val >= threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr : root + 1; break; - case NODE_MODE::BRANCH_GT: + case NODE_MODE_ORT::BRANCH_GT: root = val > threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr : root + 1; break; - case NODE_MODE::BRANCH_EQ: + case NODE_MODE_ORT::BRANCH_EQ: root = val == threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr : root + 1; break; - case NODE_MODE::BRANCH_NEQ: + case NODE_MODE_ORT::BRANCH_NEQ: root = val != threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr : root + 1; break; - case NODE_MODE::LEAF: + case NODE_MODE_ORT::BRANCH_MEMBER: + root = (SetMembershipCheck(val, root->value_or_unique_weight) || (root->is_missing_track_true() && _isnan_(val))) + ? root->truenode_or_weight.ptr + : root + 1; + break; + case NODE_MODE_ORT::LEAF: return root; } } @@ -786,67 +833,13 @@ class TreeEnsembleCommonClassifier : public TreeEnsembleCommon& base_values, - const std::vector& base_values_as_tensor, - const std::vector& nodes_falsenodeids, - const std::vector& nodes_featureids, - const std::vector& nodes_hitrates, - const std::vector& nodes_hitrates_as_tensor, - const std::vector& nodes_missing_value_tracks_true, - const std::vector& nodes_modes, - const std::vector& nodes_nodeids, - const std::vector& nodes_treeids, - const std::vector& nodes_truenodeids, - const std::vector& nodes_values, - const std::vector& nodes_values_as_tensor, - const std::string& post_transform, - const std::vector& class_ids, - const std::vector& class_nodeids, - const std::vector& class_treeids, - const std::vector& class_weights, - const std::vector& class_weights_as_tensor, - const std::vector& classlabels_strings, - const std::vector& classlabels_int64s); + const TreeEnsembleAttributesV3& attributes); }; template Status TreeEnsembleCommonClassifier::Init(const OpKernelInfo& info) { - std::vector base_values_as_tensor, nodes_hitrates_as_tensor, - nodes_values_as_tensor, class_weights_as_tensor; -#if !defined(ORT_MINIMAL_BUILD) - ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "base_values_as_tensor", base_values_as_tensor)); - ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_hitrates_as_tensor", nodes_hitrates_as_tensor)); - ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_values_as_tensor", nodes_values_as_tensor)); - ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "class_weights_as_tensor", class_weights_as_tensor)); -#endif - - return Init( - 80, - 128, - 50, - info.GetAttrOrDefault("aggregate_function", "SUM"), - info.GetAttrsOrDefault("base_values"), - base_values_as_tensor, - info.GetAttrsOrDefault("nodes_falsenodeids"), - info.GetAttrsOrDefault("nodes_featureids"), - info.GetAttrsOrDefault("nodes_hitrates"), - nodes_hitrates_as_tensor, - info.GetAttrsOrDefault("nodes_missing_value_tracks_true"), - info.GetAttrsOrDefault("nodes_modes"), - info.GetAttrsOrDefault("nodes_nodeids"), - info.GetAttrsOrDefault("nodes_treeids"), - info.GetAttrsOrDefault("nodes_truenodeids"), - info.GetAttrsOrDefault("nodes_values"), - nodes_values_as_tensor, - info.GetAttrOrDefault("post_transform", "NONE"), - info.GetAttrsOrDefault("class_ids"), - info.GetAttrsOrDefault("class_nodeids"), - info.GetAttrsOrDefault("class_treeids"), - info.GetAttrsOrDefault("class_weights"), - class_weights_as_tensor, - info.GetAttrsOrDefault("classlabels_strings"), - info.GetAttrsOrDefault("classlabels_int64s")); + TreeEnsembleAttributesV3 attributes(info, true); + return Init(80, 128, 50, attributes); } template @@ -854,65 +847,20 @@ Status TreeEnsembleCommonClassifier::Init( int parallel_tree, int parallel_tree_N, int parallel_N, - const std::string& aggregate_function, - const std::vector& base_values, - const std::vector& base_values_as_tensor, - const std::vector& nodes_falsenodeids, - const std::vector& nodes_featureids, - const std::vector& nodes_hitrates, - const std::vector& nodes_hitrates_as_tensor, - const std::vector& nodes_missing_value_tracks_true, - const std::vector& nodes_modes, - const std::vector& nodes_nodeids, - const std::vector& nodes_treeids, - const std::vector& nodes_truenodeids, - const std::vector& nodes_values, - const std::vector& nodes_values_as_tensor, - const std::string& post_transform, - const std::vector& class_ids, - const std::vector& class_nodeids, - const std::vector& class_treeids, - const std::vector& class_weights, - const std::vector& class_weights_as_tensor, - const std::vector& classlabels_strings, - const std::vector& classlabels_int64s) { - auto status = TreeEnsembleCommon::Init( - parallel_tree, - parallel_tree_N, - parallel_N, - aggregate_function, - base_values, - base_values_as_tensor, - classlabels_strings.empty() ? classlabels_int64s.size() - : classlabels_strings.size(), - nodes_falsenodeids, - nodes_featureids, - nodes_hitrates, - nodes_hitrates_as_tensor, - nodes_missing_value_tracks_true, - nodes_modes, - nodes_nodeids, - nodes_treeids, - nodes_truenodeids, - nodes_values, - nodes_values_as_tensor, - post_transform, - class_ids, - class_nodeids, - class_treeids, - class_weights, - class_weights_as_tensor); + const TreeEnsembleAttributesV3& attributes) { + auto status = TreeEnsembleCommon::Init(parallel_tree, parallel_tree_N, parallel_N, attributes); ORT_RETURN_IF_ERROR(status); - classlabels_strings_ = classlabels_strings; - classlabels_int64s_ = classlabels_int64s; + classlabels_strings_ = attributes.classlabels_strings; + classlabels_int64s_ = attributes.classlabels_int64s; InlinedHashSet weights_classes; - weights_classes.reserve(class_ids.size()); + weights_classes.reserve(attributes.target_class_ids.size()); weights_are_all_positive_ = true; - for (size_t i = 0, end = class_ids.size(); i < end; ++i) { - weights_classes.insert(class_ids[i]); - if (weights_are_all_positive_ && (!class_weights.empty() ? class_weights[i] : class_weights_as_tensor[i]) < 0) + for (size_t i = 0, end = attributes.target_class_ids.size(); i < end; ++i) { + weights_classes.insert(attributes.target_class_ids[i]); + if (weights_are_all_positive_ && (!attributes.target_class_weights.empty() ? attributes.target_class_weights[i] + : attributes.target_class_weights_as_tensor[i]) < 0) weights_are_all_positive_ = false; } binary_case_ = this->n_targets_or_classes_ == 2 && weights_classes.size() == 1; @@ -957,6 +905,43 @@ Status TreeEnsembleCommonClassifier::compu return Status::OK(); } +template +class TreeEnsembleCommonV5 : public TreeEnsembleCommon { + public: + virtual Status Init(const OpKernelInfo& info); + + Status Init(int parallel_tree, + int parallel_tree_N, + int parallel_N, + const TreeEnsembleAttributesV5& attributes); +}; + +template +Status TreeEnsembleCommonV5::Init(const OpKernelInfo& info) { + TreeEnsembleAttributesV5 attributes(info); + return Init(80, 128, 50, attributes); +} + +template +Status TreeEnsembleCommonV5::Init( + int parallel_tree, + int parallel_tree_N, + int parallel_N, + const TreeEnsembleAttributesV5& attributes) { + TreeEnsembleAttributesV3 attributes_v3; + attributes.convert_to_v3(attributes_v3); + + attributes_v3.base_values.clear(); + attributes_v3.base_values_as_tensor.clear(); + attributes_v3.nodes_hitrates.clear(); + attributes_v3.nodes_values.clear(); + attributes_v3.target_class_weights.clear(); + + auto status = TreeEnsembleCommon::Init(parallel_tree, parallel_tree_N, parallel_N, attributes_v3); + ORT_RETURN_IF_ERROR(status); + return Status::OK(); +} + } // namespace detail } // namespace ml } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.cc b/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.cc index e2981da3a6f25..399dfd56b93c6 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.cc +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.cc @@ -5,63 +5,53 @@ #include "core/providers/cpu/ml/tree_ensemble_helper.h" #include "core/common/common.h" +#include "core/common/safeint.h" #include "onnx/defs/tensor_proto_util.h" +#include "core/framework/tensorprotoutils.h" using namespace ::onnxruntime::common; using namespace std; namespace onnxruntime { namespace ml { -Status GetNumberOfElementsAttrsOrDefault(const OpKernelInfo& info, const std::string& name, - ONNX_NAMESPACE::TensorProto_DataType proto_type, - size_t& n_elements, ONNX_NAMESPACE::TensorProto& proto) { - auto status = info.GetAttr(name, &proto); - if (!status.IsOK()) { - // Attribute is missing, n_elements is set to 0. - n_elements = 0; - return Status::OK(); - } - auto n_dims = proto.dims_size(); - if (n_dims == 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attribute:'", name, "' is specified but is empty."); - } - ORT_ENFORCE(n_dims == 1, "Attribute '", name, "' must be a vector."); - ORT_ENFORCE(proto.data_type() == proto_type, - "Unexpected type (", proto.data_type(), "(for attribute '", name, "'."); - - n_elements = onnxruntime::narrow(proto.dims()[0]); - ORT_ENFORCE(n_elements > 0, "Attribute '", name, "' has one dimension but is empty."); - return Status::OK(); -} +template +Status GetAnyVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector& data) { + ONNX_NAMESPACE::TensorProto proto; + auto result = info.GetAttr(name, &proto); -template -Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, - ONNX_NAMESPACE::TensorProto_DataType proto_type, std::vector& data) { - if (proto_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE) { - ORT_ENFORCE((std::is_same::value)); - } else if (proto_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) { - ORT_ENFORCE((std::is_same::value)); - } else { - ORT_NOT_IMPLEMENTED("GetVectorAttrsOrDefault not implemented for type ", proto_type); + SafeInt n_elements(1); + for (auto dim : proto.dims()) { + n_elements *= dim; } - ONNX_NAMESPACE::TensorProto proto; - size_t n_elements; - data.clear(); - ORT_THROW_IF_ERROR(GetNumberOfElementsAttrsOrDefault(info, name, proto_type, n_elements, proto)); - if (n_elements == 0) { + if (proto.dims().empty()) { return Status::OK(); } - data = ONNX_NAMESPACE::ParseData(&proto); + + const SafeInt tensor_size(n_elements); + data.clear(); + data.resize(tensor_size); + + result = utils::UnpackTensor(proto, std::filesystem::path(), data.data(), tensor_size); + ORT_ENFORCE(result.IsOK(), "TreeEnsemble could not unpack tensor attribute ", name); + return Status::OK(); } Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector& data) { - return GetVectorAttrsOrDefault(info, name, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE, data); + return GetAnyVectorAttrsOrDefault(info, name, data); } Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector& data) { - return GetVectorAttrsOrDefault(info, name, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, data); + return GetAnyVectorAttrsOrDefault(info, name, data); +} + +Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector& data) { + return GetAnyVectorAttrsOrDefault(info, name, data); +} + +Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector& data) { + return GetAnyVectorAttrsOrDefault(info, name, data); } } // namespace ml diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.h index 33172c343a88e..ba23f1ad28ec1 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.h @@ -13,6 +13,8 @@ namespace ml { Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector& data); Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector& data); +Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector& data); +Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector& data); } // namespace ml } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose.cc b/onnxruntime/core/providers/cpu/nn/conv_transpose.cc index 2c7afddf38070..f0c1b0b409831 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose.cc @@ -38,7 +38,6 @@ ONNX_CPU_OPERATOR_KERNEL( template Status ConvTranspose::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/ ) { @@ -48,7 +47,6 @@ Status ConvTranspose::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, Al template <> Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose.h b/onnxruntime/core/providers/cpu/nn/conv_transpose.h index d03b5566e334f..c82cd5ad49d7e 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose.h +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose.h @@ -28,7 +28,6 @@ class ConvTranspose : public OpKernel { ConvTranspose(const OpKernelInfo& info) : OpKernel(info), conv_transpose_attrs_(info) {} Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc index fe2bf1035bb65..24a5dcab225c4 100644 --- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc +++ b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc @@ -229,7 +229,6 @@ Status LayerNormImpl::Compute(OpKernelContext* p_ctx) const { } Status LayerNormImpl::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, bool& is_packed, PrePackedWeights* prepacked_weights) { ORT_UNUSED_PARAMETER(prepacked_weights); diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h index abce87d03c14b..f8b528b398cba 100644 --- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h +++ b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h @@ -15,7 +15,7 @@ class LayerNormImpl : public OpKernel { LayerNormImpl(const OpKernelInfo& op_kernel_info, bool simplified = false, bool contrib_op = false); Status Compute(OpKernelContext* p_op_kernel_context) const override; - Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool save_prepacked_initializers, + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, PrePackedWeights* prepacked_weights) override; // This method was created so that it can be called directly from `test/onnx/microbenchmark/layer_normalization.cc`. diff --git a/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h b/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h index 8a8ce27990069..e26eae19b8fd4 100644 --- a/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h +++ b/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h @@ -14,7 +14,6 @@ class MatMulIntegerBase : public OpKernel { MatMulIntegerBase(const OpKernelInfo& info) : OpKernel(info) {} Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc index 736cde24591ff..7797cbe678bd4 100644 --- a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc +++ b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc @@ -25,7 +25,6 @@ class QLinearConv : public OpKernel { Status Compute(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; @@ -361,7 +360,6 @@ REGISTER_QLINEARCONV_INT8_KERNEL(kMSDomain, 1); template Status QLinearConv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc index 7afd00eacef89..b78c5236e6fab 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc @@ -284,7 +284,6 @@ bool DeepCpuGruOp::TryPackRecurrentWeights(const Tensor& weights, AllocatorPtr& } Status DeepCpuGruOp::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, bool& is_packed, PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h index 914077b2f2c15..5a6dd97c7c3f2 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h @@ -62,7 +62,6 @@ class DeepCpuGruOp final : public OpKernel { private: Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; @@ -198,4 +197,4 @@ class UniDirectionalGru { }; } // namespace detail -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc index e4082e5d7634a..09bbf6c4c79e6 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc @@ -225,9 +225,7 @@ static void UseSharedPrePackedBuffersImpl(std::vector& prepacke } Status DeepCpuLstmOp::PrePack(const Tensor& tensor, int input_idx, - AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, - /*out*/ bool& is_packed, + AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h index ff8ab9abf0eed..9c4c12954022a 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h @@ -19,7 +19,6 @@ class DeepCpuLstmOp final : public OpKernel, public LSTMBase { DeepCpuLstmOp(const OpKernelInfo& info) : OpKernel(info), LSTMBase(info) {} Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/providers/cpu/tensor/unsqueeze.h b/onnxruntime/core/providers/cpu/tensor/unsqueeze.h index 4b31e3a82f2d0..6960f8838ffde 100644 --- a/onnxruntime/core/providers/cpu/tensor/unsqueeze.h +++ b/onnxruntime/core/providers/cpu/tensor/unsqueeze.h @@ -20,15 +20,6 @@ class UnsqueezeBase { }; Status PrepareCompute(OpKernelContext* context, Prepare& p) const; - - protected: - UnsqueezeBase(const OpKernelInfo& info) { - size_t num_inputs = info.GetInputCount(); - if (num_inputs == 1) { // axes must be a valid attribute - ORT_ENFORCE(info.GetAttrs("axes", axes_).IsOK(), "Missing/Invalid 'axes' attribute value"); - } - } - static TensorShapeVector ComputeOutputShape( const TensorShape& input_shape, const TensorShapeVector& axes) { @@ -59,6 +50,14 @@ class UnsqueezeBase { return output_shape; } + protected: + UnsqueezeBase(const OpKernelInfo& info) { + size_t num_inputs = info.GetInputCount(); + if (num_inputs == 1) { // axes must be a valid attribute + ORT_ENFORCE(info.GetAttrs("axes", axes_).IsOK(), "Missing/Invalid 'axes' attribute value"); + } + } + TensorShapeVector axes_; }; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index d3f01c1f7adc1..d4013a7dc3d57 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -50,7 +50,6 @@ class Memcpy final : public OpKernel { ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr."); Tensor* Y = ctx->Output(0, X->Shape()); ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor."); - // do we support async copy? // The cudaMemCpyAsync will handle the pinned memory and non-pinned memory, // so we don't need the check here. auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, Y->Location().device); @@ -964,6 +963,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Dropout); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, Einsum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMin); + // OpSet 13 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Pow); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, Add); @@ -1200,6 +1206,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMin); + // OpSet 14 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, CumSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Relu); @@ -1641,6 +1654,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1823,9 +1839,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { 19, IsInf)>, // opset 11 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1917,6 +1930,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // OpSet 13 BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2151,6 +2171,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // OpSet 14 BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2567,6 +2594,32 @@ static bool CastNeedFallbackToCPU(const onnxruntime::Node& node) { return false; } +static bool ArgMaxOrArgMinNeedFallbackToCPU(const onnxruntime::Node& node) { + // Opset 12 introduced the attribute "select_last_index" + if (node.SinceVersion() >= 12) { + const auto& node_attributes = node.GetAttributes(); + + for (auto& attr : node_attributes) { + auto& attr_name = attr.first; + auto& attr_value = attr.second; + + // CuDNN doesn't support picking the last index in case of encountering + // duplicate max values. + // CuDNN's API doc doesn't mention what happens in case duplicates are encountered, + // but based on testing, the results seem to indicate a "stable" implementation + // (i.e.) relative ordering is preserved which is the expected behavior when the + // attribute takes on the default value (most common use-case for this operator). + if ("select_last_index" == attr_name) { + if (attr_value.i() != 0) { + return true; + } + } + } + } + + return false; +} + std::unique_ptr CUDAExecutionProvider::GetDataTransfer() const { return std::make_unique(); } @@ -2616,6 +2669,9 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, } else if ("ConvTranspose" == node.OpType()) { not_supported = ConvTransposeNeedFallbackToCPU(node, logger, graph, IsNHWCPreferred()); force_inside = !not_supported; + } else if ("ArgMax" == node.OpType() || "ArgMin" == node.OpType()) { + not_supported = ArgMaxOrArgMinNeedFallbackToCPU(node); + force_inside = !not_supported; } else if ("Cast" == node.OpType()) { not_supported = CastNeedFallbackToCPU(node); // cast is not compute heavy, and may be placed outside @@ -2637,7 +2693,7 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, // For CUDA EP, exclude the subgraph that is preferred to be placed in CPU // These are usually shape related computation subgraphs // Following logic can be extended for other EPs - auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes); + auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger); std::vector> result; for (auto& node_index : candidates) { if (cpu_nodes.count(node_index) > 0) diff --git a/onnxruntime/core/providers/cuda/cudnn_fe_call.cc b/onnxruntime/core/providers/cuda/cudnn_fe_call.cc index 640025c248187..7cd320a26d973 100644 --- a/onnxruntime/core/providers/cuda/cudnn_fe_call.cc +++ b/onnxruntime/core/providers/cuda/cudnn_fe_call.cc @@ -4,7 +4,7 @@ #include "core/providers/cuda/shared_inc/cudnn_fe_call.h" #include "core/providers/shared_library/provider_api.h" #include -#if !defined(__CUDACC__) +#if !defined(__CUDACC__) && !defined(USE_CUDA_MINIMAL) #include #endif #ifdef _WIN32 @@ -22,7 +22,7 @@ const char* CudaErrString(ERRTYPE) { ORT_NOT_IMPLEMENTED(); } -#if !defined(__CUDACC__) +#if !defined(__CUDACC__) && !defined(USE_CUDA_MINIMAL) #define CASE_ENUM_TO_STR_CUDNN_FE(x) \ case cudnn_frontend::error_code_t::x: \ return #x diff --git a/onnxruntime/core/providers/cuda/gpu_data_transfer.cc b/onnxruntime/core/providers/cuda/gpu_data_transfer.cc index 71610634577ca..4dafbda409cd3 100644 --- a/onnxruntime/core/providers/cuda/gpu_data_transfer.cc +++ b/onnxruntime/core/providers/cuda/gpu_data_transfer.cc @@ -7,10 +7,6 @@ #include "cuda_common.h" namespace onnxruntime { -GPUDataTransfer::GPUDataTransfer() {} - -GPUDataTransfer::~GPUDataTransfer() {} - bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::CUDA_PINNED || dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::CUDA_PINNED; @@ -30,19 +26,25 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const // Copy only if the two addresses are different. if (dst_data != src_data) { CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice)); + // For device memory to device memory copy, no host-side synchronization is performed by cudaMemcpy. + // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); } } else { // copy from other CPU memory to GPU, this is blocking CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyHostToDevice)); - CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); + if (src_device.MemType() != OrtDevice::MemType::CUDA_PINNED) { + // For cudaMemcpy from pageable host memory to device memory, DMA to final destination may not have completed. + // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); + } } } else if (src_device.Type() == OrtDevice::GPU) { // copying from GPU to CPU memory, this is blocking CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToHost)); - CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); } else { // copying between cpu memory + ORT_ENFORCE(dst_data != src_data); memcpy(dst_data, src_data, bytes); } @@ -59,7 +61,7 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, if (dst_device.Type() == OrtDevice::GPU) { if (src_device.Type() == OrtDevice::CPU) { - // copy from pinned memory to GPU, this is non-blocking + // copy from pinned or non-pinned CPU memory to GPU CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, static_cast(stream.GetHandle()))); } else if (src_device.Type() == OrtDevice::GPU) { // copying between GPU, this is non-blocking @@ -69,7 +71,7 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, } } else if (src_device.Type() == OrtDevice::GPU) { if (dst_device.Type() == OrtDevice::CPU) { - // copying from GPU to pinned memory, this is non-blocking + // copy from GPU to pinned or non-pinned CPU memory. CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToHost, static_cast(stream.GetHandle()))); } } else { @@ -77,6 +79,8 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, // sync the stream first to make sure the data arrived CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(static_cast(stream.GetHandle()))); } + + ORT_ENFORCE(dst_data != src_data); memcpy(dst_data, src_data, bytes); } diff --git a/onnxruntime/core/providers/cuda/gpu_data_transfer.h b/onnxruntime/core/providers/cuda/gpu_data_transfer.h index 68846e68079f3..11e21e91936fc 100644 --- a/onnxruntime/core/providers/cuda/gpu_data_transfer.h +++ b/onnxruntime/core/providers/cuda/gpu_data_transfer.h @@ -10,8 +10,8 @@ namespace onnxruntime { class GPUDataTransfer : public IDataTransfer { public: - GPUDataTransfer(); - ~GPUDataTransfer(); + GPUDataTransfer() = default; + ~GPUDataTransfer() = default; bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; diff --git a/onnxruntime/core/providers/cuda/math/topk_impl.cuh b/onnxruntime/core/providers/cuda/math/topk_impl.cuh index cbde6da457fdb..112566e54bbba 100644 --- a/onnxruntime/core/providers/cuda/math/topk_impl.cuh +++ b/onnxruntime/core/providers/cuda/math/topk_impl.cuh @@ -412,7 +412,7 @@ Status TopKImpl(const CudaKernel* kernel, bool use_deterministic_compute, if (aligned_dimension <= GridDim::maxThreadsPerBlock) { BitonicTopK<<), stream>>>( input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, aligned_K, largest, sorted, dimension, - aligned_dimension, NumericLimits::Min(), NumericLimits::Max()); + aligned_dimension, NumericLimits::Lowest(), NumericLimits::Max()); } else if (K <= BT * 16 || 0 == sorted) { if (use_deterministic_compute) { static std::once_flag log_warning; @@ -425,19 +425,19 @@ Status TopKImpl(const CudaKernel* kernel, bool use_deterministic_compute, if (BT * 2 >= K || 0 == sorted) { RadixTopK<<>>( input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, - NumericLimits::Min(), NumericLimits::Max()); + NumericLimits::Lowest(), NumericLimits::Max()); } else if (BT * 4 >= K) { RadixTopK<<>>( input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, - NumericLimits::Min(), NumericLimits::Max()); + NumericLimits::Lowest(), NumericLimits::Max()); } else if (BT * 8 >= K) { RadixTopK<<>>( input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, - NumericLimits::Min(), NumericLimits::Max()); + NumericLimits::Lowest(), NumericLimits::Max()); } else { RadixTopK<<>>( input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, - NumericLimits::Min(), NumericLimits::Max()); + NumericLimits::Lowest(), NumericLimits::Max()); } } else { auto input_key_buffer = kernel->GetScratchBuffer(dimension, ort_stream); diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index 45a1d3bbc0414..3129f519da2e5 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -52,7 +52,6 @@ REGISTER_KERNEL_TYPED(MLFloat16, kMSInternalNHWCDomain, true) // First input (in this case X) is in case NHWC == true also in NHWC format, the other inputs in NCHW template Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { is_packed = false; // only layout of weight input is adjusted via PrePack diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index 6294566af3cb9..e4047a6af272e 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -219,7 +219,6 @@ class Conv : public CudaKernel { } Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, bool& is_packed, PrePackedWeights* prepacked_weights) override; Status ComputeInternal(OpKernelContext* context) const override; diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index 9c9a83460daeb..2972ae999adc4 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -45,8 +45,7 @@ REGISTER_KERNEL_TYPED(MLFloat16, kMSInternalNHWCDomain, true) // First input (in this case X) is in case NHWC == true also in NHWC format, the other inputs in NCHW template -Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, bool& is_packed, +Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) { is_packed = false; // only layout of weight input is adjusted via PrePack diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.h b/onnxruntime/core/providers/cuda/nn/conv_transpose.h index f23c2b94501f2..1a6957164d22f 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.h +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.h @@ -22,7 +22,6 @@ class ConvTranspose : public CudaKernel { ConvTranspose(const OpKernelInfo& info) : CudaKernel(info), conv_transpose_attrs_(info) {}; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) override; Status ComputeInternal(OpKernelContext* context) const override; Status DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const; diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc index 860bea67dc719..4f8e6605ce151 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc @@ -16,17 +16,17 @@ using namespace onnxruntime::common; namespace onnxruntime { namespace cuda { -#define REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, end) \ +#define REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, begin, end) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ name, \ kOnnxDomain, \ - 1, end, \ + begin, end, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ name); -#define REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, version) \ +#define REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, version) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ name, \ kOnnxDomain, \ @@ -37,8 +37,13 @@ namespace cuda { name); #define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \ - REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, last) \ - REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, cur) + REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, last) \ + REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, cur) + +#define REGISTER_KERNEL_ARGMIN_OR_ARGMAX(name, T) \ + REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, 11) \ + REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 12, 12) \ + REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, 13) // TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored. template @@ -829,14 +834,13 @@ template std::unique_ptr ReduceCompute class ArgMax final : public ReduceKernel { public: - ArgMax(const OpKernelInfo& info) : ReduceKernel(info) {} + ArgMax(const OpKernelInfo& info) : ReduceKernel(info) { + // The following is just a safety check. + // The logic in ArgMaxOrArgMinNeedFallbackToCPU() makes sure to not assign ArgMax + // nodes with select_last_index == 1 to the CUDA EP. + int64_t select_last_index = 0; + if (info.GetAttr("select_last_index", &select_last_index).IsOK()) { + ORT_ENFORCE(select_last_index == 0, "select_last_index as 1 is not supported on CUDA"); + } + } Status ComputeInternal(OpKernelContext* ctx) const override { return ComputeImpl(ctx, CUDNN_REDUCE_TENSOR_MAX); @@ -98,7 +106,15 @@ class ArgMax final : public ReduceKernel { template class ArgMin final : public ReduceKernel { public: - ArgMin(const OpKernelInfo& info) : ReduceKernel(info) {} + ArgMin(const OpKernelInfo& info) : ReduceKernel(info) { + // The following is just a safety check. + // The logic in ArgMaxOrArgMinNeedFallbackToCPU() makes sure to not assign ArgMin + // nodes with select_last_index == 1 to the CUDA EP. + int64_t select_last_index = 0; + if (info.GetAttr("select_last_index", &select_last_index).IsOK()) { + ORT_ENFORCE(select_last_index == 0, "select_last_index as 1 is not supported on CUDA"); + } + } Status ComputeInternal(OpKernelContext* ctx) const override { return ComputeImpl(ctx, CUDNN_REDUCE_TENSOR_MIN); diff --git a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h index ed642754af3ba..f9433642f0857 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h +++ b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include "core/framework/float16.h" @@ -120,7 +121,7 @@ constexpr int kNumBitsPerBitmaskElement = std::numeric_limits struct NumericLimits { - __inline__ __host__ __device__ static T Min() { + __inline__ __host__ __device__ static T Lowest() { return std::numeric_limits::lowest(); } __inline__ __host__ __device__ static T Max() { @@ -128,43 +129,18 @@ struct NumericLimits { } }; -template <> -struct NumericLimits { - __inline__ __host__ __device__ static half Min() { - return -65504.0; - } - __inline__ __host__ __device__ static half Max() { - return 65504.0; - } -}; - template <> struct NumericLimits { - __inline__ __host__ __device__ static half Min() { - return -65504.0; - } - __inline__ __host__ __device__ static half Max() { - return 65504.0; + __inline__ __host__ __device__ static half Lowest() { + return -65504.0f; } -}; -template <> -struct NumericLimits { - __inline__ __host__ __device__ static float Min() { - return -INFINITY; - } - __inline__ __host__ __device__ static float Max() { - return INFINITY; - } -}; - -template <> -struct NumericLimits { - __inline__ __host__ __device__ static double Min() { - return -HUGE_VAL; - } - __inline__ __host__ __device__ static double Max() { - return HUGE_VAL; + __inline__ __host__ __device__ static half Max() { +#ifdef CUDART_MAX_NORMAL_FP16 // defined in cuda 12.3 or later + return CUDART_MAX_NORMAL_FP16; +#else + return 65504.0f; +#endif } }; diff --git a/onnxruntime/core/providers/cuda/shared_inc/cudnn_fe_call.h b/onnxruntime/core/providers/cuda/shared_inc/cudnn_fe_call.h index a51d84a7efa59..2ce7bc0bf51fd 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/cudnn_fe_call.h +++ b/onnxruntime/core/providers/cuda/shared_inc/cudnn_fe_call.h @@ -5,7 +5,7 @@ #include "core/common/common.h" #include "core/providers/cuda/cuda_pch.h" #include "core/providers/cuda/shared_inc/cuda_call.h" -#if !defined(__CUDACC__) +#if !defined(__CUDACC__) && !defined(USE_CUDA_MINIMAL) #include #endif namespace onnxruntime { @@ -14,10 +14,12 @@ namespace onnxruntime { // Error handling // ----------------------------------------------------------------------- +#ifndef USE_CUDA_MINIMAL #define CUDNN_FE_CALL(expr) (CudaCall((cudnn_frontend::error_t)(expr), #expr, "CUDNN_FE", \ cudnn_frontend::error_code_t::OK, "", __FILE__, __LINE__)) #define CUDNN_FE_CALL_THROW(expr) (CudaCall((cudnn_frontend::error_t)(expr), #expr, "CUDNN_FE", \ cudnn_frontend::error_code_t::OK, "", __FILE__, __LINE__)) +#endif } // namespace onnxruntime diff --git a/onnxruntime/core/providers/dml/CPPLINT.cfg b/onnxruntime/core/providers/dml/CPPLINT.cfg deleted file mode 100644 index 02d14c65cc861..0000000000000 --- a/onnxruntime/core/providers/dml/CPPLINT.cfg +++ /dev/null @@ -1 +0,0 @@ -filter=-whitespace/braces,-whitespace/parens,-whitespace/line_length,-whitespace/indent,-whitespace/newline diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp index 801cceb3bd99f..334a40b979bda 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp @@ -41,23 +41,19 @@ namespace Dml D3D12_HEAP_FLAGS heapFlags, D3D12_RESOURCE_FLAGS resourceFlags, D3D12_RESOURCE_STATES initialState, - std::unique_ptr&& subAllocator - ) + std::unique_ptr&& subAllocator) : onnxruntime::IAllocator( - OrtMemoryInfo( - "DML", - OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0) - ) - ), - m_device(device), - m_heapProperties(heapProps), - m_heapFlags(heapFlags), - m_resourceFlags(resourceFlags), - m_initialState(initialState), - m_context(context), - m_subAllocator(std::move(subAllocator)) - { + OrtMemoryInfo( + "DML", + OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0))), + m_device(device), + m_heapProperties(heapProps), + m_heapFlags(heapFlags), + m_resourceFlags(resourceFlags), + m_initialState(initialState), + m_context(context), + m_subAllocator(std::move(subAllocator)) { } /*static*/ gsl::index BucketizedBufferAllocator::GetBucketIndexFromSize(uint64_t size) @@ -186,7 +182,7 @@ namespace Dml } else { - if (!m_closed) + if (!m_context->IsClosed()) { // Free the underlying allocation once queued work has completed. #ifdef _GAMING_XBOX diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h index 65bc9b7f69316..16283d5b19c9c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h @@ -46,11 +46,6 @@ namespace Dml void SetDefaultRoundingMode(AllocatorRoundingMode roundingMode); - void Close() - { - m_closed = true; - } - public: // onnxruntime::IAllocator void* Alloc(size_t size, AllocatorRoundingMode roundingMode); void* Alloc(size_t size) final; @@ -88,7 +83,6 @@ namespace Dml std::vector m_pool; size_t m_currentAllocationId = 0; uint64_t m_currentResourceId = 0; - bool m_closed = false; // Unless specifically requested, allocation sizes are not rounded to enable pooling // until SetDefaultRoundingMode is called. This should be done at completion of session diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.cpp index 67faf333d21e1..988324bab1174 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.cpp @@ -55,7 +55,7 @@ namespace Dml // for example, an allocation from BucketizedBufferAllocator attempts to queue a reference // to its underlying D3D resource when freed. Furthermore, these references are unnecessary // since Close() already blocks for scheduled GPU work before clearing m_queuedReferences. - if (!m_clearingQueue) + if (!m_closing) { QueuedReference queuedReference = {GetLastFenceValue(), object}; @@ -70,15 +70,15 @@ namespace Dml } } - void CommandQueue::WaitForSignalAndClearQueue() + void CommandQueue::Close() { // Wait for flushed work: - assert(!m_clearingQueue); - m_clearingQueue = true; + assert(!m_closing); + m_closing = true; GpuEvent event = GetCurrentCompletionEvent(); event.WaitForSignal(m_cpuSyncSpinningEnabled); m_queuedReferences.clear(); - m_clearingQueue = false; + m_closing = false; } void CommandQueue::ReleaseCompletedReferences() diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.h index 9a4728d5845d4..71d5eb173cfec 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.h @@ -44,7 +44,7 @@ namespace Dml } #endif - void WaitForSignalAndClearQueue(); + void Close(); void ReleaseCompletedReferences(); private: @@ -61,7 +61,7 @@ namespace Dml ComPtr m_fence; uint64_t m_lastFenceValue = 0; - bool m_clearingQueue = false; + bool m_closing = false; bool m_cpuSyncSpinningEnabled = false; }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h index b22f0b2853e5d..f07b9540ff3fd 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h @@ -24,7 +24,7 @@ namespace Dml OrtMemoryInfo( "DML", OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0) + OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0) )) { m_device = onnxruntime::DMLProviderFactoryCreator::CreateD3D12Device(device_id, false); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp index 35a2c451a49a5..9f95818501dac 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp @@ -62,7 +62,8 @@ namespace Dml const auto kernel_type_str_resolver = onnxruntime::OpSchemaKernelTypeStrResolver{}; const auto kernel_lookup = onnxruntime::KernelLookup{provider_type, gsl::make_span(®istry, 1), - kernel_type_str_resolver}; + kernel_type_str_resolver, + logger}; std::vector> compiledPartitionInfos; std::vector additionalSplittingNodes; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp index 6318b0d5e2865..b9b90d6bc17bd 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp @@ -54,7 +54,8 @@ namespace Dml const auto kernelLookup = onnxruntime::KernelLookup( providerType, gsl::make_span(®istry, 1), - kernelTypeStrResolver); + kernelTypeStrResolver, + logger); onnxruntime::GraphViewer graphViewer(graph); const auto& nodeTopologyList = graphViewer.GetNodesInTopologicalOrder(); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.cpp index ececf13fc8cdf..5dc1213bd76f0 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.cpp @@ -11,10 +11,13 @@ namespace Dml ID3D12Device* d3d12Device, IDMLDevice* dmlDevice, ID3D12CommandQueue* queue, - bool cpuSyncSpinningEnabled) + bool cpuSyncSpinningEnabled, + bool keepOpen + ) : m_queue(std::make_shared(queue, cpuSyncSpinningEnabled)) , m_dmlRecorder(d3d12Device, dmlDevice, m_queue) , m_cpuSyncSpinningEnabled(cpuSyncSpinningEnabled) + , m_keepOpen(keepOpen) { ORT_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(m_d3dDevice.GetAddressOf()))); } @@ -33,6 +36,8 @@ namespace Dml D3D12_RESOURCE_STATES srcState, uint64_t byteCount) { + assert(!m_closed); + SetCommandRecorder(&m_dmlRecorder); std::vector barriers; @@ -79,6 +84,8 @@ namespace Dml _Out_ uint64_t* completionValue ) { + assert(!m_closed); + SetCommandRecorder(&m_dmlRecorder); m_dmlRecorder.ExecuteCommandList(commandList, fence, completionValue); } @@ -88,6 +95,7 @@ namespace Dml const DML_BINDING_DESC& persistentResourceBinding, const DML_BINDING_DESC& inputArrayBinding) { + assert(!m_closed); SetCommandRecorder(&m_dmlRecorder); m_dmlRecorder.InitializeOperator(op, persistentResourceBinding, inputArrayBinding); @@ -99,6 +107,7 @@ namespace Dml gsl::span inputBindings, gsl::span outputBindings) { + assert(!m_closed); SetCommandRecorder(&m_dmlRecorder); m_dmlRecorder.ExecuteOperator(op, persistentResourceBinding, inputBindings, outputBindings); @@ -106,6 +115,7 @@ namespace Dml void ExecutionContext::AddUAVBarrier() { + assert(!m_closed); SetCommandRecorder(&m_dmlRecorder); m_dmlRecorder.AddUAVBarrier(); @@ -113,6 +123,7 @@ namespace Dml void ExecutionContext::ResourceBarrier(gsl::span barriers) { + assert(!m_closed); SetCommandRecorder(&m_dmlRecorder); m_dmlRecorder.ResourceBarrier(barriers); @@ -120,6 +131,7 @@ namespace Dml void ExecutionContext::GetCommandListForRecordingAndInvalidateState(ID3D12GraphicsCommandList** commandList) { + assert(!m_closed); SetCommandRecorder(&m_dmlRecorder); // Ensure the descriptor heap is reset to D3D as something external may change it before recording @@ -130,6 +142,8 @@ namespace Dml void ExecutionContext::SetCommandRecorder(ICommandRecorder* newRecorder) { + assert(!m_closed); + // If changing which recorder is the current one, we need to flush the old one first. This is to ensure correct // ordering of operations on the command queue. if (m_currentRecorder != newRecorder) @@ -146,6 +160,8 @@ namespace Dml void ExecutionContext::Flush() { + assert(!m_closed); + if (!m_currentRecorder || !m_currentRecorder->HasUnsubmittedWork()) { // Nothing to flush @@ -164,21 +180,34 @@ namespace Dml void ExecutionContext::QueueReference(IUnknown* object) { + assert(!m_closed); // If something has been recorded into a command list but not submitted yet, it means that the *next* fence // value is the one to signal completion. bool waitForUnsubmittedWork = (m_currentRecorder != nullptr); m_queue->QueueReference(object, waitForUnsubmittedWork); } - void ExecutionContext::WaitForSignalAndClearQueue() + void ExecutionContext::Close() { + assert(!m_closed); + // Discard unflushed work and clear queued references. This prevents the circular reference: // Kernel --> ProviderImpl --> Context --> QueuedRefs --> Kernel - m_queue->WaitForSignalAndClearQueue(); + m_queue->Close(); + + // Keep the execution context open when requested, e.g. when used through the python API where there's a single context + // and single command queue + if (!m_keepOpen) + { + m_currentRecorder = nullptr; + m_closed = true; + } } GpuEvent ExecutionContext::GetCurrentCompletionEvent() { + assert(!m_closed); + GpuEvent event = m_queue->GetCurrentCompletionEvent(); // If something has been recorded into a command list but not submitted yet, it means that the *next* fence @@ -194,11 +223,13 @@ namespace Dml void ExecutionContext::ReleaseCompletedReferences() { + assert(!m_closed); m_queue->ReleaseCompletedReferences(); } D3D12_COMMAND_LIST_TYPE ExecutionContext::GetCommandListTypeForQueue() const { + assert(!m_closed); return m_queue->GetType(); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h index 71aa26f4a0148..e7a6fa3d07296 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h @@ -23,13 +23,14 @@ namespace Dml ID3D12Device* d3d12Device, IDMLDevice* dmlDevice, ID3D12CommandQueue* queue, - bool cpuSyncSpinningEnabled); + bool cpuSyncSpinningEnabled, + bool keepOpen); void SetAllocator(std::weak_ptr allocator); // Waits for flushed work, discards unflushed work, and discards associated references to - // prevent circular references. - void WaitForSignalAndClearQueue(); + // prevent circular references. Must be the last call on the object before destruction. + void Close(); // Queues a CopyBufferRegion (see ID3D12GraphicsCommandList::CopyBufferRegion) for execution. Transition // barriers are automatically inserted to transition the source and destination resources to COPY_SOURCE and @@ -86,6 +87,7 @@ namespace Dml D3D12_COMMAND_LIST_TYPE GetCommandListTypeForQueue() const; bool CpuSyncSpinningEnabled() const { return m_cpuSyncSpinningEnabled; } + bool IsClosed() const { return m_closed; } private: Microsoft::WRL::ComPtr m_d3dDevice; @@ -101,6 +103,10 @@ namespace Dml bool m_closed = false; bool m_cpuSyncSpinningEnabled = false; + + // The python API has a global state used for I/O binding where the execution context is shared between session, + // so we don't want to close the context when one of the sessions is destroyed + bool m_keepOpen = false; }; } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 6b0faaad43175..826f48b5f7a68 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -74,7 +74,7 @@ namespace Dml bool enableGraphCapture, bool enableSyncSpinning, bool disableMemoryArena) : - IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)) + IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0)) { D3D12_COMMAND_LIST_TYPE queueType = executionContext->GetCommandListTypeForQueue(); if (queueType != D3D12_COMMAND_LIST_TYPE_DIRECT && queueType != D3D12_COMMAND_LIST_TYPE_COMPUTE) @@ -95,7 +95,7 @@ namespace Dml const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup) const { #ifdef ENABLE_GRAPH_COMPILATION - return m_impl->GetCapability(graph, kernel_lookup); + return m_impl->GetCapability(graph, kernel_lookup, *GetLogger()); #else return onnxruntime::IExecutionProvider::GetCapability(graph, kernel_lookup); #endif @@ -106,26 +106,7 @@ namespace Dml // Release the cached command list references before closing the context m_capturedGraphs.clear(); - // Close the allocator before clearing the command queue to stop it from - // appending resources to it in an attempt to keep them alive. - if (m_allocator) - { - m_allocator->Close(); - } - - // Destroy the allocators. We are closing the execution provider, so from now on the - // only thing it will be used for is doing copies via the DataTransfer, which doesn't - // require allocating any memory. - // TODO: Move the copy functions over to ExecutionContext so that we are able to cleanly - // destroy ExecutionProviderImpl, and instead have the DataTransfer keep the context alive. - m_allocator = nullptr; - m_cpuInputAllocator = nullptr; - - // Wait for all pending commands to be done executing and empty the command queue. This will - // Force all kernels and resources in flight to get destroyed and, from this point forward, - // ExecutionProviderImpl will only be used to execute transfer between resources that are - // already existing via the DataTransfer; - m_context->WaitForSignalAndClearQueue(); + m_context->Close(); } void ExecutionProviderImpl::WaitForOutstandingWork() @@ -895,7 +876,8 @@ namespace Dml std::vector> ExecutionProviderImpl::GetCapability( const onnxruntime::GraphViewer& graph, - const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup) const + const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, + const onnxruntime::logging::Logger& logger) const { uint32_t deviceDataTypeMask = GetSupportedDeviceDataTypeMask(); // Each bit corresponds to each DML_TENSOR_DATA_TYPE. @@ -919,7 +901,7 @@ namespace Dml } // Get the list of nodes that should stay on the CPU - auto cpuPreferredNodes = GetCpuPreferredNodes(graph, kernel_lookup, tentativeNodes); + auto cpuPreferredNodes = GetCpuPreferredNodes(graph, kernel_lookup, tentativeNodes, logger); for (size_t nodeIndex : toplogicalOrder) { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index c20969250fe84..e7d859c5764de 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -88,7 +88,8 @@ namespace Dml std::vector> GetCapability( const onnxruntime::GraphViewer& graph, - const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup + const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, + const onnxruntime::logging::Logger& logger ) const; uint32_t GetSupportedDeviceDataTypeMask() const; @@ -242,8 +243,8 @@ namespace Dml bool CanCopy(const OrtDevice& srcDevice, const OrtDevice& dstDevice) const final { - return (srcDevice.Type() == OrtDevice::GPU) || - (dstDevice.Type() == OrtDevice::GPU); + return (srcDevice.Type() == OrtDevice::DML) || + (dstDevice.Type() == OrtDevice::DML); } private: diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/CPPLINT.cfg b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/CPPLINT.cfg index bf14c49304415..7e6be3c6874d5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/CPPLINT.cfg +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/CPPLINT.cfg @@ -1 +1 @@ -filter=-whitespace/comments,-readability/todo,-whitespace/end_of_line,-runtime/indentation_namespace +filter=-readability/todo diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp index 02fb72b5a073a..389bdee3a365b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp @@ -28,9 +28,56 @@ class DmlOperatorCast : public DmlOperator castDesc.InputTensor = inputDescs.data(); castDesc.OutputTensor = outputDescs.data(); - DML_OPERATOR_DESC opDesc = { DML_OPERATOR_CAST, &castDesc }; + if (kernelInfo.GetOutputEdgeDescription(0).tensorDataType == static_cast(ONNX_NAMESPACE::TensorProto_DataType_BOOL)) + { + DML_OPERATOR_DESC dmlCastDesc = { DML_OPERATOR_CAST, &castDesc }; - SetDmlOperatorDesc(opDesc, kernelInfo); + DML_ELEMENT_WISE_CLIP1_OPERATOR_DESC clipDesc = {}; + clipDesc.InputTensor = outputDescs.data(); + clipDesc.OutputTensor = outputDescs.data(); + clipDesc.Min.UInt8 = 0; + clipDesc.Max.UInt8 = 1; + + DML_OPERATOR_DESC dmlClipDesc = { DML_OPERATOR_ELEMENT_WISE_CLIP1, &clipDesc }; + + std::vector opDescs = { &dmlCastDesc, &dmlClipDesc }; + + DML_INPUT_GRAPH_EDGE_DESC inputToCastEdge = {}; + inputToCastEdge.GraphInputIndex = 0; + inputToCastEdge.ToNodeIndex = 0; + inputToCastEdge.ToNodeInputIndex = 0; + + DML_INTERMEDIATE_GRAPH_EDGE_DESC castToClipEdge = {}; + castToClipEdge.FromNodeIndex = 0; + castToClipEdge.FromNodeOutputIndex = 0; + castToClipEdge.ToNodeIndex = 1; + castToClipEdge.ToNodeInputIndex = 0; + + DML_OUTPUT_GRAPH_EDGE_DESC clipToOutputEdge = {}; + clipToOutputEdge.FromNodeIndex = 1; + clipToOutputEdge.FromNodeOutputIndex = 0; + clipToOutputEdge.GraphOutputIndex = 0; + + MLOperatorGraphDesc operatorGraphDesc = {}; + operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); + operatorGraphDesc.nodes = opDescs.data(); + + operatorGraphDesc.inputEdgeCount = 1; + operatorGraphDesc.inputEdges = &inputToCastEdge; + + operatorGraphDesc.intermediateEdgeCount = 1; + operatorGraphDesc.intermediateEdges = &castToClipEdge; + + operatorGraphDesc.outputEdgeCount = 1; + operatorGraphDesc.outputEdges = &clipToOutputEdge; + + SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo); + } + else + { + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_CAST, &castDesc }; + SetDmlOperatorDesc(opDesc, kernelInfo); + } } void Compute(const MLOperatorKernelContext& kernelContext) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRecurrentNeuralNetwork.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRecurrentNeuralNetwork.cpp index 88b827f61f0c9..ad7c77510d988 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRecurrentNeuralNetwork.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRecurrentNeuralNetwork.cpp @@ -127,51 +127,51 @@ class DmlOperatorRecurrentBase: public DmlOperator, public RecurrentHelper DML_OPERATOR_DESC& desc = descs[i]; ActivationOperatorDescUnion& activationDesc = m_activationDescs[i]; desc.Desc = &activationDesc; - - if (activationName == AttrValue::ActivationRelu) + + if (CompareActivationName(activationName, AttrValue::ActivationRelu)) { desc.Type = DML_OPERATOR_ACTIVATION_RELU; - } - else if (activationName == AttrValue::ActivationLeakyRelu) + } + else if (CompareActivationName(activationName, AttrValue::ActivationLeakyRelu)) { desc.Type = DML_OPERATOR_ACTIVATION_LEAKY_RELU; activationDesc.leakyRelu.Alpha = NextAlpha(desc.Type); } - else if (activationName == AttrValue::ActivationThresholdedRelu) + else if (CompareActivationName(activationName, AttrValue::ActivationThresholdedRelu)) { desc.Type = DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU; activationDesc.thresholdedRelu.Alpha = NextAlpha(desc.Type); - } - else if (activationName == AttrValue::ActivationTanh) + } + else if (CompareActivationName(activationName, AttrValue::ActivationTanh)) { desc.Type = DML_OPERATOR_ACTIVATION_TANH; - } - else if (activationName == AttrValue::ActivationScaledTanh) + } + else if (CompareActivationName(activationName, AttrValue::ActivationScaledTanh)) { desc.Type = DML_OPERATOR_ACTIVATION_SCALED_TANH; activationDesc.scaledTanh.Alpha = NextAlpha(desc.Type); activationDesc.scaledTanh.Beta = NextBeta(desc.Type); - } - else if (activationName == AttrValue::ActivationSigmoid) + } + else if (CompareActivationName(activationName, AttrValue::ActivationSigmoid)) { desc.Type = DML_OPERATOR_ACTIVATION_SIGMOID; - } - else if (activationName == AttrValue::ActivationSigmoidHard) + } + else if (CompareActivationName(activationName, AttrValue::ActivationSigmoidHard)) { desc.Type = DML_OPERATOR_ACTIVATION_HARD_SIGMOID; activationDesc.hardSigmoid.Alpha = NextAlpha(desc.Type); activationDesc.hardSigmoid.Beta = NextBeta(desc.Type); - } - else if (activationName == AttrValue::ActivationElu) + } + else if (CompareActivationName(activationName, AttrValue::ActivationElu)) { desc.Type = DML_OPERATOR_ACTIVATION_ELU; activationDesc.elu.Alpha = NextAlpha(desc.Type); - } - else if (activationName == AttrValue::ActivationSoftsign) + } + else if (CompareActivationName(activationName, AttrValue::ActivationSoftsign)) { desc.Type = DML_OPERATOR_ACTIVATION_SOFTSIGN; - } - else if (activationName == AttrValue::ActivationSoftplus) + } + else if (CompareActivationName(activationName, AttrValue::ActivationSoftplus)) { desc.Type = DML_OPERATOR_ACTIVATION_SOFTPLUS; } @@ -182,6 +182,12 @@ class DmlOperatorRecurrentBase: public DmlOperator, public RecurrentHelper } } + bool CompareActivationName(std::string_view activationName, std::string_view attrValue) + { + auto comparer = [](char a, char b) {return std::tolower(a) == std::tolower(b);}; + return std::equal(activationName.begin(), activationName.end(), attrValue.begin(), attrValue.end(), comparer); + } + void Compute(const MLOperatorKernelContext& kernelContext) override { // Assume that enough GPU work has been queued up after the RNN operator that it is worth diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index ceed388bb0a6f..b0b37d01370bc 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -854,6 +854,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_COPY(14, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(16, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(19, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, + {REG_INFO_COPY(21, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, @@ -1157,6 +1158,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_MS( 1, QLinearAdd, typeNameListDefault, supportedTypeListInteger8, DmlGraphSupport::Supported)}, {REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)}, {REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)}, + {REG_INFO( 21, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)}, {REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, MatMulIntegerToFloat, typeNameListThree, supportedTypeListMatMulIntegerToFloat, DmlGraphSupport::Supported)}, {REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)}, @@ -1170,6 +1172,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_MS( 1, BiasAdd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, QuickGelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, GroupNorm, typeNameListGroupNorm, supportedTypeListGroupNorm, DmlGraphSupport::Supported)}, + {REG_INFO( 21, GroupNorm, typeNameListGroupNorm, supportedTypeListGroupNorm, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, MatMulNBits, typeNameListTwo, supportedTypeListMatMulNBits, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryMatMulNBits)}, // Operators that need to alias an input with an output diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/PooledUploadHeap.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/PooledUploadHeap.cpp index 375ee87bd42f1..a4d284df43a72 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/PooledUploadHeap.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/PooledUploadHeap.cpp @@ -125,8 +125,8 @@ namespace Dml // No chunks were able to accommodate the allocation - create a new chunk and return that instead - // At least double the capacity of the pool - const size_t newChunkSize = std::max({ m_totalCapacity, c_minChunkSize, sizeInBytes }); + // At least double the capacity of the pool, limit to c_maxChunkSize so DX12 does not reject size + const size_t newChunkSize = std::min(std::max({ m_totalCapacity, c_minChunkSize, sizeInBytes }), c_maxChunkSize); m_chunks.push_back(CreateChunk(m_device.Get(), newChunkSize)); m_totalCapacity += newChunkSize; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/PooledUploadHeap.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/PooledUploadHeap.h index 1202ae9243921..0315b087519ba 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/PooledUploadHeap.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/PooledUploadHeap.h @@ -32,6 +32,7 @@ namespace Dml private: static constexpr size_t c_minChunkSize = 1024 * 1024; // 1MB static constexpr size_t c_allocationAlignment = 512; // In bytes; as per D3D12 requirement for buffers + static constexpr size_t c_maxChunkSize = 0xFFFF0000; // ~4 GiB limitation for DX12 CPU-visible resource // A suballoction from a chunk struct Allocation diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index c1ea69ab35374..c52e26dd321ab 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1648,6 +1648,7 @@ using ShapeInferenceHelper_BatchNormalization15 = BatchNormalizationHelper; using ShapeInferenceHelper_LRN = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_MeanVarianceNormalization = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_GroupNorm = GetOutputShapeAsInputShapeHelper; +using ShapeInferenceHelper_GroupNorm21 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_LayerNormalization = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_LayerNormalization17 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_SkipLayerNormalization = SkipLayerNormHelper; @@ -1839,6 +1840,7 @@ using ShapeInferenceHelper_Identity13 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Identity14 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Identity16 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Identity19 = GetOutputShapeAsInputShapeHelper; +using ShapeInferenceHelper_Identity21 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_MatMul = MatMulHelper; using ShapeInferenceHelper_MatMulInteger = MatMulHelper; using ShapeInferenceHelper_MatMulIntegerToFloat = MatMulHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index c2a6d57fca0a9..b4d402a1d9e77 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -452,6 +452,9 @@ namespace OperatorHelper static const int sc_sinceVer_Flatten = 21; static const int sc_sinceVer_Pad = 21; static const int sc_sinceVer_Transpose = 21; + static const int sc_sinceVer_Identity = 21; + static const int sc_sinceVer_QLinearMatMul = 21; + static const int sc_sinceVer_GroupNorm = 21; } namespace MsftOperatorSet1 diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index 89decfef6fef6..e8fe235fc1d46 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -86,11 +86,11 @@ std::unique_ptr DMLProviderFactory::CreateProvider() { // First, check if an I/O binding API that was used before this session or another session has already created a queue if (FAILED(d3d12_device->GetPrivateData(dml_execution_context_guid, &execution_context_ptr_size, execution_context.GetAddressOf()))) { - execution_context = wil::MakeOrThrow(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), true); + execution_context = wil::MakeOrThrow(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), true, true); ORT_THROW_IF_FAILED(d3d12_device->SetPrivateDataInterface(dml_execution_context_guid, execution_context.Get())); } } else { - execution_context = wil::MakeOrThrow(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), cpu_sync_spinning_enabled_); + execution_context = wil::MakeOrThrow(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), cpu_sync_spinning_enabled_, false); } auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), execution_context.Get(), metacommands_enabled_, graph_capture_enabled_, cpu_sync_spinning_enabled_, disable_memory_arena_); diff --git a/onnxruntime/core/providers/get_execution_providers.cc b/onnxruntime/core/providers/get_execution_providers.cc index d2a72c3a38b03..7d8c5525667b9 100644 --- a/onnxruntime/core/providers/get_execution_providers.cc +++ b/onnxruntime/core/providers/get_execution_providers.cc @@ -66,14 +66,6 @@ constexpr ProviderInfo kProvidersInPriorityOrder[] = true, #else false, -#endif - }, - { - kTvmExecutionProvider, -#ifdef USE_TVM - true, -#else - false, #endif }, { diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 1ff33f6d7b410..c1a8b373bed84 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -121,7 +121,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, Not) class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 8, Cast); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 12, Cast); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 18, Cast); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Cast); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, 20, Cast); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Cast); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 10, Clip); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, Clip); @@ -139,7 +140,8 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, ReduceMax); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, ReduceMax); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceMax); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 19, ReduceMax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 20, ReduceMax); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceMean); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceMean); @@ -150,7 +152,8 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, ReduceMin); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, ReduceMin); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceMin); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 19, ReduceMin); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 20, ReduceMin); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceProd); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceProd); @@ -233,17 +236,20 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Res class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Squeeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Squeeze); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Squeeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 20, Squeeze); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Squeeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Unsqueeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Unsqueeze); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Unsqueeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 20, Unsqueeze); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Unsqueeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 15, Where); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, Where); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Transpose); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Transpose); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 20, Transpose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Transpose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, DepthToSpace); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, DepthToSpace); @@ -273,10 +279,12 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 9, AveragePool); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, AveragePool); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 18, AveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, AveragePool); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 7, 9, AveragePool); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 10, 10, AveragePool); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 18, AveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 19, AveragePool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalAveragePool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalAveragePool); @@ -333,6 +341,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gat class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, GatherElements); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherElements); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, GatherND); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, GatherND); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherND); + class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 9, Slice); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Slice); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Slice); @@ -341,7 +353,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Sli class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 8, Flatten); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Flatten); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Flatten); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Flatten); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 20, Flatten); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Flatten); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Tile); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Tile); @@ -358,12 +371,14 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Pad); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Pad); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Pad); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, 20, Pad); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, Pad); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, If); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, If); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 18, If); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, If); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, 20, If); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, If); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, BatchNormalization); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 13, BatchNormalization); @@ -389,6 +404,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, int8_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, int32_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, 19, GridSample); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 16, 19, GridSample); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ScatterND); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 15, ScatterND); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, 17, ScatterND); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ScatterND); + std::unique_ptr RegisterKernels() { auto kernel_registry = std::make_unique(); @@ -439,7 +461,8 @@ std::unique_ptr RegisterKernels() { KERNEL_CREATE_INFO_VERSIONED(6, 8, Cast), KERNEL_CREATE_INFO_VERSIONED(9, 12, Cast), KERNEL_CREATE_INFO_VERSIONED(13, 18, Cast), - KERNEL_CREATE_INFO(19, Cast), + KERNEL_CREATE_INFO_VERSIONED(19, 20, Cast), + KERNEL_CREATE_INFO(21, Cast), // activations KERNEL_CREATE_INFO_VERSIONED(6, 10, Clip), @@ -501,12 +524,14 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -515,13 +540,15 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -561,7 +588,8 @@ std::unique_ptr RegisterKernels() { KERNEL_CREATE_INFO(16, Where), BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -591,10 +619,12 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -641,6 +671,10 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -660,7 +694,8 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -677,12 +712,14 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -706,6 +743,13 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { @@ -774,7 +818,7 @@ std::vector> JsExecutionProvider::GetCapabili candidates.push_back(node.Index()); tenative_candidates.push_back(node.Index()); } - auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates); + auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates, *GetLogger()); std::vector> result; for (auto& node_index : candidates) { if (cpu_nodes.count(node_index) > 0) { diff --git a/onnxruntime/core/providers/js/operators/cast.cc b/onnxruntime/core/providers/js/operators/cast.cc index 9b6ac6d7e253b..f499d0627e032 100644 --- a/onnxruntime/core/providers/js/operators/cast.cc +++ b/onnxruntime/core/providers/js/operators/cast.cc @@ -49,10 +49,19 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( .TypeConstraint("T1", CastOpTypeConstraints()) .TypeConstraint("T2", CastOpTypeConstraints()), Cast); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Cast, + kOnnxDomain, + 19, 20, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); ONNX_OPERATOR_KERNEL_EX( Cast, kOnnxDomain, - 19, + 21, kJsExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T1", CastOpTypeConstraints()) diff --git a/onnxruntime/core/providers/js/operators/conv.h b/onnxruntime/core/providers/js/operators/conv.h index 276b600cf40d2..b04df44954295 100644 --- a/onnxruntime/core/providers/js/operators/conv.h +++ b/onnxruntime/core/providers/js/operators/conv.h @@ -78,7 +78,6 @@ class ConvBase : public JsKernel { } Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /* prepacked_weights */) override { is_packed = false; diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h index baa93f825a203..5ff52e8fda4fa 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.h +++ b/onnxruntime/core/providers/js/operators/conv_transpose.h @@ -126,10 +126,8 @@ class ConvTranspose : public JsKernel { } Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /* prepacked_weights */) override { - ORT_UNUSED_PARAMETER(save_prepacked_initializers); is_packed = false; if (input_idx == 1) { diff --git a/onnxruntime/core/providers/js/operators/flatten.cc b/onnxruntime/core/providers/js/operators/flatten.cc index 1aacae819e304..44a67cb15d958 100644 --- a/onnxruntime/core/providers/js/operators/flatten.cc +++ b/onnxruntime/core/providers/js/operators/flatten.cc @@ -36,10 +36,20 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( .TypeConstraint("T", JsepSupportedFloatTypes()), Flatten); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Flatten, + kOnnxDomain, + 13, 20, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", JsepSupportedFloatTypes()), + Flatten); + ONNX_OPERATOR_KERNEL_EX( Flatten, kOnnxDomain, - 13, + 21, kJsExecutionProvider, (*KernelDefBuilder::Create()) .Alias(0, 0) diff --git a/onnxruntime/core/providers/js/operators/gather_nd.cc b/onnxruntime/core/providers/js/operators/gather_nd.cc new file mode 100644 index 0000000000000..ee69100cc658e --- /dev/null +++ b/onnxruntime/core/providers/js/operators/gather_nd.cc @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" +#include "core/providers/js/js_data_types.h" +#include "gather_nd.h" + +namespace onnxruntime { +namespace js { + +ONNX_OPERATOR_KERNEL_EX( + GatherND, + kOnnxDomain, + 13, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()), + GatherND); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + GatherND, + kOnnxDomain, + 12, + 12, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()), + GatherND); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + GatherND, + kOnnxDomain, + 11, + 11, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()), + GatherND); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/gather_nd.h b/onnxruntime/core/providers/js/operators/gather_nd.h new file mode 100644 index 0000000000000..cdf7a52630dad --- /dev/null +++ b/onnxruntime/core/providers/js/operators/gather_nd.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +class GatherND : public JsKernel { + public: + GatherND(const OpKernelInfo& info) : JsKernel(info) { + int64_t batchDims = info.GetAttrOrDefault("batch_dims", 0); + + JSEP_INIT_KERNEL_ATTRIBUTE(GatherND, ({ + "batch_dims" : Number($1), + }), + static_cast(batchDims)); + } +}; + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/grid_sample.cc b/onnxruntime/core/providers/js/operators/grid_sample.cc new file mode 100644 index 0000000000000..84eb7df6c5bbe --- /dev/null +++ b/onnxruntime/core/providers/js/operators/grid_sample.cc @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "grid_sample.h" + +namespace onnxruntime { +namespace js { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + GridSample, + kMSInternalNHWCDomain, + 16, 19, + kJsExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", JsepSupportedDataTypes()) + .TypeConstraint("T2", JsepSupportedFloatTypes()), + GridSample); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + GridSample, + kOnnxDomain, + 16, 19, + kJsExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", JsepSupportedDataTypes()) + .TypeConstraint("T2", JsepSupportedFloatTypes()), + GridSample); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/grid_sample.h b/onnxruntime/core/providers/js/operators/grid_sample.h new file mode 100644 index 0000000000000..352decf33dc20 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/grid_sample.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +template +class GridSample : public JsKernel { + public: + GridSample(const OpKernelInfo& info) : JsKernel(info) { + int64_t align_corners = info.GetAttrOrDefault("align_corners", 0); + std::string mode = info.GetAttrOrDefault("mode", "linear"); + std::string padding_mode = info.GetAttrOrDefault("padding_mode", "zeros"); + int64_t channels_last = is_channels_last ? 1 : info.GetAttrOrDefault("channels_last", 0); + + JSEP_INIT_KERNEL_ATTRIBUTE(GridSample, ({ + "align_corners" : $1, + "mode" : UTF8ToString($2), + "padding_mode" : UTF8ToString($3), + "format" : $4 ? "NHWC" : "NCHW" + }), + static_cast(align_corners), mode.c_str(), + padding_mode.c_str(), static_cast(channels_last)); + } +}; + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/if.cc b/onnxruntime/core/providers/js/operators/if.cc index ef072bb1635dd..368d1b5101bdb 100644 --- a/onnxruntime/core/providers/js/operators/if.cc +++ b/onnxruntime/core/providers/js/operators/if.cc @@ -44,9 +44,21 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, If); // opset-19 supports float8 +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 19, 20, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + // Support sequence/optional tensors when all JSEP infra + // (including tests runner) supports it + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + If); + ONNX_OPERATOR_KERNEL_EX(If, kOnnxDomain, - 19, + 21, kJsExecutionProvider, (*KernelDefBuilder::Create()) .InputMemoryType(OrtMemTypeCPUInput, 0) // 'cond' needs to be on CPU diff --git a/onnxruntime/core/providers/js/operators/pad.cc b/onnxruntime/core/providers/js/operators/pad.cc index 83fee35481aa6..556fdf419212f 100644 --- a/onnxruntime/core/providers/js/operators/pad.cc +++ b/onnxruntime/core/providers/js/operators/pad.cc @@ -56,10 +56,23 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( .InputMemoryType(OrtMemTypeCPU, 3), Pad); -ONNX_OPERATOR_KERNEL_EX( +ONNX_OPERATOR_VERSIONED_KERNEL_EX( Pad, kOnnxDomain, 19, + 20, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()) + .InputMemoryType(OrtMemTypeCPU, 1) + .InputMemoryType(OrtMemTypeCPU, 2) + .InputMemoryType(OrtMemTypeCPU, 3), + Pad); + +ONNX_OPERATOR_KERNEL_EX( + Pad, + kOnnxDomain, + 21, kJsExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", JsepSupportedFloatTypes()) diff --git a/onnxruntime/core/providers/js/operators/pool.cc b/onnxruntime/core/providers/js/operators/pool.cc index 7df1e483f52a1..50efafac7d3e6 100644 --- a/onnxruntime/core/providers/js/operators/pool.cc +++ b/onnxruntime/core/providers/js/operators/pool.cc @@ -55,8 +55,10 @@ POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, AveragePool, 7, 9) POOLING_KERNEL_VERSIONED(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 7, 9) POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, AveragePool, 10, 10) POOLING_KERNEL_VERSIONED(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 10, 10) -POOLING_KERNEL(AveragePool, kOnnxDomain, false, AveragePool, 11) -POOLING_KERNEL(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 11) +POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, AveragePool, 11, 18) +POOLING_KERNEL_VERSIONED(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 11, 18) +POOLING_KERNEL(AveragePool, kOnnxDomain, false, AveragePool, 19) +POOLING_KERNEL(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 19) POOLING_KERNEL(GlobalAveragePool, kOnnxDomain, false, AveragePool, 1) POOLING_KERNEL(GlobalAveragePool, kMSInternalNHWCDomain, true, AveragePool, 1) diff --git a/onnxruntime/core/providers/js/operators/reduce.cc b/onnxruntime/core/providers/js/operators/reduce.cc index 2679cfed86124..98c329c1d9377 100644 --- a/onnxruntime/core/providers/js/operators/reduce.cc +++ b/onnxruntime/core/providers/js/operators/reduce.cc @@ -20,6 +20,16 @@ namespace js { // a new opset version update applies to Reduce* operators, we may need to add another macro like // REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL_WITH_AXIS_IN_INPUT to set input memory type. // i.e. we cannot use REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL to version 18 when the opset version is increased. +#define REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL_WITH_AXIS_IN_INPUT(ReduceOp, sinceVersion, endVersion) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + ReduceOp, \ + kOnnxDomain, \ + sinceVersion, endVersion, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", JsepSupportedFloatTypes()) \ + .InputMemoryType(OrtMemTypeCPU, 1), \ + ReduceOp); #define REGISTER_REDUCE_ELEMENTWISE_KERNEL(ReduceOp, sinceVersion) \ ONNX_OPERATOR_KERNEL_EX( \ @@ -41,13 +51,15 @@ REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 1, 10); REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 11, 11); REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 12, 12); REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 13, 17); -REGISTER_REDUCE_ELEMENTWISE_KERNEL(ReduceMax, 18); +REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL_WITH_AXIS_IN_INPUT(ReduceMax, 18, 19); +REGISTER_REDUCE_ELEMENTWISE_KERNEL(ReduceMax, 20); REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 1, 10); REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 11, 11); REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 12, 12); REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 13, 17); -REGISTER_REDUCE_ELEMENTWISE_KERNEL(ReduceMin, 18); +REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL_WITH_AXIS_IN_INPUT(ReduceMin, 18, 19); +REGISTER_REDUCE_ELEMENTWISE_KERNEL(ReduceMin, 20); REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceProd, 1, 10); REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceProd, 11, 12); diff --git a/onnxruntime/core/providers/js/operators/scatter_nd.cc b/onnxruntime/core/providers/js/operators/scatter_nd.cc new file mode 100644 index 0000000000000..e9edb7f58fe5e --- /dev/null +++ b/onnxruntime/core/providers/js/operators/scatter_nd.cc @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" +#include "core/providers/js/js_data_types.h" +#include "scatter_nd.h" + +namespace onnxruntime { +namespace js { + +ONNX_OPERATOR_KERNEL_EX( + ScatterND, + kOnnxDomain, + 18, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()), + ScatterND); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + ScatterND, + kOnnxDomain, + 16, + 17, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()), + ScatterND); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + ScatterND, + kOnnxDomain, + 13, + 15, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()), + ScatterND); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + ScatterND, + kOnnxDomain, + 11, + 12, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()), + ScatterND); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/scatter_nd.h b/onnxruntime/core/providers/js/operators/scatter_nd.h new file mode 100644 index 0000000000000..8c81c62d71fe7 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/scatter_nd.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" +#include "core/framework/data_transfer_manager.h" + +namespace onnxruntime { +namespace js { + +enum class ScatterNDReduction : int { + None = 0, + Add = 1, + Mul = 2, + Min = 3, + Max = 4, +}; + +class ScatterND : public JsKernel { + public: + ScatterND(const OpKernelInfo& info) : JsKernel(info) { + std::string reduction = info.GetAttrOrDefault("reduction", "none"); + if (reduction == "add") { + reduction_ = ScatterNDReduction::Add; + } else if (reduction == "mul") { + reduction_ = ScatterNDReduction::Mul; + } else if (reduction == "min") { + reduction_ = ScatterNDReduction::Min; + } else if (reduction == "max") { + reduction_ = ScatterNDReduction::Max; + } else if (reduction == "none") { + LOGS_DEFAULT(WARNING) << "ScatterND with reduction=='none' only guarantees " + << "to be correct if indices are not duplicated."; + } else { + ORT_THROW("Reduction '", reduction, "' is not supported on webgpu when opset <= 13."); + } + + JSEP_INIT_KERNEL_ATTRIBUTE(ScatterND, ({ + "reduction" : UTF8ToString($1), + }), + reduction.c_str()); + } + + Status Compute(OpKernelContext* context) const override { + const Tensor* X = context->Input(0); + if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); + const TensorShape& X_shape = X->Shape(); + + Tensor* Y = context->Output(0, X_shape); + const void* source = X->DataRaw(); + void* target = Y->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (target != source) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*X, *Y)); + } + return ComputeInternal(context); + } + + private: + ScatterNDReduction reduction_{ScatterNDReduction::None}; +}; + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/squeeze.cc b/onnxruntime/core/providers/js/operators/squeeze.cc index e858ade348cd4..521d0103d373f 100644 --- a/onnxruntime/core/providers/js/operators/squeeze.cc +++ b/onnxruntime/core/providers/js/operators/squeeze.cc @@ -10,7 +10,7 @@ namespace js { ONNX_OPERATOR_KERNEL_EX( Squeeze, kOnnxDomain, - 13, + 21, kJsExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", JsepSupportedDataTypes()) @@ -19,6 +19,17 @@ ONNX_OPERATOR_KERNEL_EX( .InputMemoryType(OrtMemTypeCPU, 1), Squeeze); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Squeeze, + kOnnxDomain, + 13, 20, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Squeeze); + ONNX_OPERATOR_VERSIONED_KERNEL_EX( Squeeze, kOnnxDomain, diff --git a/onnxruntime/core/providers/js/operators/transpose.cc b/onnxruntime/core/providers/js/operators/transpose.cc index 332bd35f2434c..136879b93b37f 100644 --- a/onnxruntime/core/providers/js/operators/transpose.cc +++ b/onnxruntime/core/providers/js/operators/transpose.cc @@ -15,10 +15,19 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( .TypeConstraint("T", JsepSupportedDataTypes()), Transpose); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Transpose, + kOnnxDomain, + 13, 20, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()), + Transpose); + ONNX_OPERATOR_KERNEL_EX( Transpose, kOnnxDomain, - 13, + 21, kJsExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", JsepSupportedDataTypes()), diff --git a/onnxruntime/core/providers/js/operators/unsqueeze.cc b/onnxruntime/core/providers/js/operators/unsqueeze.cc index 1485e800e5e76..898deb827cccb 100644 --- a/onnxruntime/core/providers/js/operators/unsqueeze.cc +++ b/onnxruntime/core/providers/js/operators/unsqueeze.cc @@ -10,7 +10,7 @@ namespace js { ONNX_OPERATOR_KERNEL_EX( Unsqueeze, kOnnxDomain, - 13, + 21, kJsExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", JsepSupportedDataTypes()) @@ -19,6 +19,17 @@ ONNX_OPERATOR_KERNEL_EX( .InputMemoryType(OrtMemTypeCPU, 1), Unsqueeze); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Unsqueeze, + kOnnxDomain, + 13, 20, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Unsqueeze); + ONNX_OPERATOR_VERSIONED_KERNEL_EX( Unsqueeze, kOnnxDomain, diff --git a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc index 51625b83b8f61..77c5e18a5878e 100644 --- a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc +++ b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc @@ -2,12 +2,16 @@ // Licensed under the MIT License. #include "core/providers/shared_library/provider_api.h" -#include "gpu_data_transfer.h" -#include "migraphx_call.h" +#include "core/providers/migraphx/gpu_data_transfer.h" +#include "core/providers/migraphx/migraphx_call.h" + +// If you make change below, please also update onnxruntime/core/providers/rocm/gpu_data_transfer.cc namespace onnxruntime { + bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { - return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HIP_PINNED || dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::HIP_PINNED; + return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HIP_PINNED || + dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::HIP_PINNED; } common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { @@ -23,17 +27,24 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const if (src_device.Type() == OrtDevice::GPU) { // Copy only if the two addresses are different. if (dst_data != src_data) { - HIP_CALL_THROW(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToDevice)); + HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToDevice)); + // Follow core/providers/cuda/gpu_data_transfer.cc to synchronize the default stream here. + HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); } } else { // copy from other CPU memory to GPU, this is blocking - HIP_CALL_THROW(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice)); + HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice)); + if (src_device.MemType() != OrtDevice::MemType::HIP_PINNED) { + // Follow core/providers/cuda/gpu_data_transfer.cc to synchronize the default stream here. + HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); + } } } else if (src_device.Type() == OrtDevice::GPU) { // copying from GPU to CPU memory, this is blocking - HIP_CALL_THROW(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToHost)); + HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToHost)); } else { // copying between cpu memory + ORT_ENFORCE(dst_data != src_data); memcpy(dst_data, src_data, bytes); } @@ -49,23 +60,28 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, auto& dst_device = dst.Location().device; if (dst_device.Type() == OrtDevice::GPU) { - if (src_device.Type() == OrtDevice::CPU && src_device.MemType() == OrtDevice::MemType::HIP_PINNED) { - // copy from pinned memory to GPU, this is non-blocking - HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast(stream.GetHandle()))); + if (src_device.Type() == OrtDevice::CPU) { + // If source are not pinned, the memory copy will be performed synchronously. + // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously. + HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast(stream.GetHandle()))); } else if (src_device.Type() == OrtDevice::GPU) { // copying between GPU, this is non-blocking HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, static_cast(stream.GetHandle()))); - } else { - // copy from other CPU memory to GPU, this is blocking - HIP_CALL_THROW(hipMemcpyWithStream(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast(stream.GetHandle()))); } } else if (src_device.Type() == OrtDevice::GPU) { - HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast(stream.GetHandle()))); + // If dest are not pinned, the memory copy will be performed synchronously. + // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously. + HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast(stream.GetHandle()))); } else { - // copying between cpu memory + if (src_device.MemType() == OrtDevice::MemType::CUDA_PINNED) { + // sync the stream first to make sure the data arrived + HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(stream.GetHandle()))); + } + ORT_ENFORCE(dst_data != src_data); memcpy(dst_data, src_data, bytes); } return Status::OK(); } + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_call.h b/onnxruntime/core/providers/migraphx/migraphx_call.h index f6a95cebf34b5..6d514e01aea96 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_call.h +++ b/onnxruntime/core/providers/migraphx/migraphx_call.h @@ -3,6 +3,7 @@ #pragma once #include "migraphx_inc.h" +#include "core/common/common.h" namespace onnxruntime { @@ -16,5 +17,6 @@ std::conditional_t RocmCall( #define HIP_CALL(expr) (RocmCall((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__)) #define HIP_CALL_THROW(expr) (RocmCall((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__)) +#define HIP_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(HIP_CALL(expr)) } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 9d651129fbd4d..9017b36a0f087 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -49,6 +49,8 @@ class Memcpy final : public OpKernel { const IDataTransfer* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, Y->Location().device); if (!gpu_data_transfer) return Status(common::ONNXRUNTIME, common::EP_FAIL, "gpu data transfer is missing in Migraphx EP."); + // CopyTensorAsync could handle both pinned memory and non-pinned CPU memory. + // For non-pinned CPU memory, the copy is synchronous. return gpu_data_transfer->CopyTensorAsync(*X, *Y, *(ctx->GetComputeStream())); } }; @@ -800,6 +802,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "ATen", "AveragePool", "BatchNormalization", + "BiasGelu", "Cast", "Ceil", "Celu", @@ -824,12 +827,14 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "Exp", "Expand", "EyeLike", + "FastGelu", "Flatten", "Floor", "GRU", "Gather", "GatherElements", "GatherND", + "Gelu", "Gemm", "GlobalAveragePool", "GlobalMaxPool", @@ -1152,7 +1157,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (!no_input_shape) { if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { - LOGS_DEFAULT(INFO) << "No Input shapes detected quantizing model"; + LOGS_DEFAULT(INFO) << "No input shapes detected quantizing model"; prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options); // Read in the calibration data and map it to an migraphx paramater map for the calibration ops @@ -1293,7 +1298,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& // re-compile the program if (!input_shape_match) { if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { - LOGS_DEFAULT(VERBOSE) << "No Input shapes mismatch detected. Recompiling" << std::endl; + LOGS_DEFAULT(VERBOSE) << "Input shape mismatch detected. Recompiling" << std::endl; #ifndef ENABLE_TRAINING_CORE #if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 2) cmp_options.set_external_data_path(model_path_.has_parent_path() ? model_path_.parent_path().string() : std::filesystem::current_path().string()); diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h index 03a7c1607e3ad..85b0aff87a436 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h +++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h @@ -6,8 +6,6 @@ #include "migraphx_inc.h" #include "migraphx_call.h" -#define HIP_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(HIP_CALL(expr)) - namespace onnxruntime { void WaitMIGraphXNotificationOnDevice(Stream& stream, synchronize::Notification& notification); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc index 12416ea0c121b..e4bee6f959a01 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc @@ -32,8 +32,16 @@ namespace nnapi { ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const NnApi& nnapi_handle, gsl::span nnapi_target_devices, - TargetDeviceOption target_device_option) - : nnapi_(nnapi_handle), graph_viewer_(graph_viewer), nnapi_model_{std::make_unique(nnapi_handle)}, shaper_{graph_viewer}, nnapi_target_devices_(nnapi_target_devices), target_device_option_(target_device_option), nnapi_effective_feature_level_(GetNNAPIEffectiveFeatureLevel(nnapi_handle, nnapi_target_devices_)) { + TargetDeviceOption target_device_option, + const logging::Logger& logger) + : nnapi_(nnapi_handle), + graph_viewer_(graph_viewer), + nnapi_model_{std::make_unique(nnapi_handle)}, + shaper_{graph_viewer}, + nnapi_target_devices_(nnapi_target_devices), + target_device_option_(target_device_option), + nnapi_effective_feature_level_(GetNNAPIEffectiveFeatureLevel(nnapi_handle, nnapi_target_devices_)), + logger_(logger) { nnapi_model_->nnapi_effective_feature_level_ = nnapi_effective_feature_level_; } @@ -136,7 +144,7 @@ const NodeUnit& ModelBuilder::GetNodeUnit(const Node* node) const { } void ModelBuilder::PreprocessNodeUnits() { - std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_); + std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_, logger_); } // Help to get all quantized operators' input and the NodeUnit(s) using the input diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h index b2118150dd304..4db335afa98b0 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h @@ -14,7 +14,9 @@ struct NnApi; namespace onnxruntime { - +namespace logging { +class Logger; +} class GraphViewer; enum class DataLayout; class NodeUnit; @@ -31,7 +33,8 @@ class ModelBuilder { using Shape = Shaper::Shape; ModelBuilder(const GraphViewer& graph_viewer, const NnApi& nnapi_handle, - gsl::span nnapi_target_devices, TargetDeviceOption target_device_option); + gsl::span nnapi_target_devices, TargetDeviceOption target_device_option, + const logging::Logger& logger); common::Status Compile(std::unique_ptr& model); @@ -173,6 +176,9 @@ class ModelBuilder { // <1,1> <1,2> <1,3> InlinedVector> operations_recorder_; #endif + + const logging::Logger& logger_; + // Convert the ONNX model to ANeuralNetworksModel common::Status Prepare(); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc index fca52396a190c..f92c9592742d5 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc @@ -81,6 +81,7 @@ NnapiExecutionProvider::~NnapiExecutionProvider() {} std::vector> NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/) const { + const auto& logger = *GetLogger(); std::vector> result; // TODO: Task 812756: NNAPI EP, add support for subgraph (If and Loop operators) @@ -101,7 +102,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view return ORT_NNAPI_MAX_SUPPORTED_API_LEVEL; #endif }(); - LOGS_DEFAULT(VERBOSE) << "Effective NNAPI feature level: " << android_feature_level; + LOGS(logger, VERBOSE) << "Effective NNAPI feature level: " << android_feature_level; const nnapi::OpSupportCheckParams params{ android_feature_level, @@ -109,7 +110,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view }; if (params.android_feature_level < ORT_NNAPI_MIN_API_LEVEL) { - LOGS_DEFAULT(WARNING) << "All ops will fallback to CPU EP, because system NNAPI feature level [" + LOGS(logger, WARNING) << "All ops will fallback to CPU EP, because system NNAPI feature level [" << params.android_feature_level << "] is lower than minimal supported NNAPI API feature level [" << ORT_NNAPI_MIN_API_LEVEL @@ -121,7 +122,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); // This holds the result of whether a NodeUnit is supported or not, // to prevent nodes in a NodeUnit to be checked for multiple times @@ -150,7 +151,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view node_unit_supported_result[node_unit] = supported; } - LOGS_DEFAULT(VERBOSE) << "Node supported: [" << supported + LOGS(logger, VERBOSE) << "Node supported: [" << supported << "] Operator type: [" << node.OpType() << "] index: [" << node.Index() << "] name: [" << node.Name() @@ -224,9 +225,9 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view // If the graph is partitioned in multiple subgraphs, and this may impact performance, // we want to give users a summary message at warning level. if (num_of_partitions > 1) { - LOGS_DEFAULT(WARNING) << summary_msg; + LOGS(logger, WARNING) << summary_msg; } else { - LOGS_DEFAULT(INFO) << summary_msg; + LOGS(logger, INFO) << summary_msg; } return result; @@ -273,11 +274,13 @@ static Status GetOutputBuffer(Ort::KernelContext& context, common::Status NnapiExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { using namespace android::nn::wrapper; + const auto& logger = *GetLogger(); + for (const auto& fused_node_and_graph : fused_nodes_and_graphs) { Node& fused_node = fused_node_and_graph.fused_node; const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); - nnapi::ModelBuilder builder(graph_viewer, *nnapi_handle_, nnapi_target_devices_, target_device_option_); + nnapi::ModelBuilder builder(graph_viewer, *nnapi_handle_, nnapi_target_devices_, target_device_option_, logger); builder.SetUseNCHW(nnapi_flags_ & NNAPI_FLAG_USE_NCHW); builder.SetUseFp16(nnapi_flags_ & NNAPI_FLAG_USE_FP16); diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 4fca4037301fb..a0bcf953938d9 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -25,6 +25,11 @@ GlobalContext& BackendManager::GetGlobalContext() { return global_context_; } +ov::CompiledModel& BackendManager::GetOVCompiledModel() { + ov::CompiledModel& ov_ptr = concrete_backend_->GetOVCompiledModel(); + return (ov_ptr); +} + BackendManager::BackendManager(const GlobalContext& global_context, const onnxruntime::Node& fused_node, const onnxruntime::GraphViewer& subgraph, @@ -35,7 +40,7 @@ BackendManager::BackendManager(const GlobalContext& global_context, openvino_sdk_version_ = std::to_string(global_context_.OpenVINO_Version.at(0)) + "." + std::to_string(global_context_.OpenVINO_Version.at(1)); if (ep_ctx_handle_.CheckForOVEPCtxNode(subgraph, openvino_sdk_version_)) { - if (ep_ctx_handle_.ImportBlobFromEPCtxModel(subgraph) != Status::OK()) + if (ep_ctx_handle_.ImportBlobFromEPCtxModel(subgraph, global_context_.ep_context_embed_mode) != Status::OK()) ORT_THROW("Import blob from model failed"); } @@ -65,7 +70,10 @@ BackendManager::BackendManager(const GlobalContext& global_context, i++; } subgraph_context_.subgraph_name = fused_node.Name(); - auto model_proto = GetModelProtoFromFusedNode(fused_node, subgraph, logger); + std::unique_ptr model_proto; + if (!ep_ctx_handle_.IsValidOVEPCtxGraph()) { + model_proto = GetModelProtoFromFusedNode(fused_node, subgraph, logger); + } std::string device_type = openvino_ep::BackendManager::GetGlobalContext().device_type; if (ModelHasSymbolicInputDims(subgraph)) { diff --git a/onnxruntime/core/providers/openvino/backend_manager.h b/onnxruntime/core/providers/openvino/backend_manager.h index b9ff7a72372b3..5ec462afd9d01 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.h +++ b/onnxruntime/core/providers/openvino/backend_manager.h @@ -30,6 +30,7 @@ class BackendManager { GlobalContext& GetGlobalContext(); Status ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph, const logging::Logger& logger); + ov::CompiledModel& GetOVCompiledModel(); private: std::unique_ptr GetModelProtoFromFusedNode( diff --git a/onnxruntime/core/providers/openvino/backend_utils.cc b/onnxruntime/core/providers/openvino/backend_utils.cc index f772b9c3b0478..b97736f2e124d 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.cc +++ b/onnxruntime/core/providers/openvino/backend_utils.cc @@ -39,7 +39,7 @@ struct static_cast_int64 { int64_t operator()(const T1& x) const { return static_cast(x); } }; -std::shared_ptr +std::shared_ptr CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext& global_context, std::map>& const_outputs_map) { if (IsCILogEnabled()) { @@ -47,13 +47,13 @@ CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext } const std::string model = model_proto.SerializeAsString(); try { - auto cnn_network = global_context.ie_core.ReadModel(model, global_context.onnx_model_path_name); + auto ov_model = global_context.ie_core.ReadModel(model, global_context.onnx_model_path_name); // Check for Constant Folding - if (!global_context.is_wholly_supported_graph) { + if ((global_context.device_type != "NPU") && !global_context.is_wholly_supported_graph) { ov::pass::ConstantFolding pass_const_obj; - pass_const_obj.run_on_model(cnn_network); - auto& results = const_cast(cnn_network.get()->get_results()); + pass_const_obj.run_on_model(ov_model); + auto& results = const_cast(ov_model.get()->get_results()); size_t index = results.size() - 1; for (auto it = results.rbegin(); it != results.rend(); ++it) { @@ -67,12 +67,12 @@ CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext } #ifndef NDEBUG if (IsDebugEnabled()) { - std::string name = cnn_network->get_friendly_name(); + std::string name = ov_model->get_friendly_name(); ov::pass::Serialize serializer(name + ".xml", name + ".bin"); - serializer.run_on_model(cnn_network); + serializer.run_on_model(ov_model); } #endif - return cnn_network; + return ov_model; } catch (std::string const& msg) { ORT_THROW(msg); } diff --git a/onnxruntime/core/providers/openvino/backend_utils.h b/onnxruntime/core/providers/openvino/backend_utils.h index 9e65770da7d23..9d58e1ca73abb 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.h +++ b/onnxruntime/core/providers/openvino/backend_utils.h @@ -60,7 +60,7 @@ void FillInputBlob(OVTensorPtr inputBlob, size_t batch_slice_idx, void FillOutputBlob(OVTensorPtr outputBlob, Ort::UnownedValue& output_tensor, size_t batch_slice_idx); -std::shared_ptr +std::shared_ptr CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext& global_context, std::map>& const_outputs_map); diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 8a1844544328c..435ca83ff69d4 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -48,6 +48,16 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr // Set the inference_num_threads property of the CPU SetNumThreads(device_config); + auto npuw_status = + std::any_of(device_config.begin(), device_config.end(), [&](const std::pair& pair) { + return (pair.first.find("NPU_USE_NPUW") != std::string::npos) && (pair.second.is()) && + (pair.second.as() == "YES"); + }); + + if (npuw_status) { + LOGS_DEFAULT(INFO) << log_tag << "NPUW Enabled during compilation"; + } + try { std::string dev_prec = global_context.device_type + "_" + global_context_.precision_str; @@ -81,7 +91,6 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr device_config, global_context_.ep_context_embed_mode, subgraph_context_.subgraph_name); - ie_cnn_network_ = exe_network_.Get().get_runtime_model(); } else if (global_context_.export_ep_ctx_blob && hw_target.find("NPU") != std::string::npos && !global_context_.has_external_weights) { @@ -106,22 +115,22 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr device_config, subgraph_context_.subgraph_name); } else { // For all other types use ov::Model Type - ie_cnn_network_ = CreateOVModel(*model_proto, global_context_, const_outputs_map_); + auto ov_model = CreateOVModel(*model_proto, global_context_, const_outputs_map_); exe_network_ = global_context_.ie_core.CompileModel( - ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); + ov_model, hw_target, device_config, subgraph_context_.subgraph_name); } #endif } else { // Full graph is not supported - ie_cnn_network_ = CreateOVModel(*model_proto, global_context_, const_outputs_map_); + auto ov_model = CreateOVModel(*model_proto, global_context_, const_outputs_map_); exe_network_ = global_context_.ie_core.CompileModel( - ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); + ov_model, hw_target, device_config, subgraph_context_.subgraph_name); } LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } catch (const char* msg) { ORT_THROW(msg); } - - inferRequestsQueue_ = std::unique_ptr(new InferRequestsQueue(exe_network_, 1)); + int num_infer_req = (global_context_.num_of_threads > 0) ? global_context_.num_of_threads : 1; + inferRequestsQueue_ = std::unique_ptr(new InferRequestsQueue(exe_network_, num_infer_req)); } bool BasicBackend::ValidateSubgraph(std::map>& const_outputs_map) { @@ -145,8 +154,8 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { device_config.emplace(ov::hint::inference_precision("f32")); } if (global_context_.precision_str.find("ACCURACY") != std::string::npos && - global_context_.device_type == "GPU") { - if (global_context_.OpenVINO_Version.at(0) >= 2024 && global_context_.OpenVINO_Version.at(1) >= 1) { + global_context_.device_type.find("GPU") != std::string::npos) { + if (global_context_.OpenVINO_Version.at(0) >= 2024) { device_config.emplace(ov::hint::inference_precision(ov::element::undefined)); device_config.emplace(ov::hint::execution_mode(ov::hint::ExecutionMode::ACCURACY)); } else { @@ -174,7 +183,7 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { device_property = std::make_pair("NPU_COMPILER_TYPE", env_npu_compiler_type); } device_config.emplace(ov::device::properties("NPU", device_property)); -#if (OPENVINO_VERSION_MAJOR >= 2024) && (OPENVINO_VERSION_MINOR > 3) +#if (((OPENVINO_VERSION_MAJOR == 2024) && (OPENVINO_VERSION_MINOR > 3)) || (OPENVINO_VERSION_MAJOR > 2024)) if (global_context_.export_ep_ctx_blob) { global_context_.ie_core.Get().set_property("NPU", ov::intel_npu::bypass_umd_caching(true)); } @@ -184,6 +193,33 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { if (!global_context_.load_config.empty()) { const std::map& target_config = global_context_.load_config; + if (global_context_.device_type.find("NPU") != std::string::npos) { + auto npuw_config = target_config.at("NPU"); + + // Check if "NPU_USE_NPUW" exists and is set to "YES" + auto npu_use_npuw_it = npuw_config.find("NPU_USE_NPUW"); + if (npu_use_npuw_it != npuw_config.end() && + npu_use_npuw_it->second.is() && + npu_use_npuw_it->second.as() == "YES") { + // Only add NPUW-related keys if NPU_USE_NPUW is "YES" + for (const auto& [key, value] : npuw_config) { + if (key.find("NPUW") != std::string::npos) { + if (!value.is()) { + LOGS_DEFAULT(ERROR) << "Invalid value type for key: " << key; + continue; + } + device_config[key] = value; + } + } + } else { + // Check if there are any "NPUW" keys and log a warning + if (std::any_of(npuw_config.begin(), npuw_config.end(), + [&](const auto& pair) { return pair.first.find("NPUW") != std::string::npos; })) { + LOGS_DEFAULT(WARNING) << "Skipping NPUW-related configurations as NPU_USE_NPUW is not set to 'YES'."; + } + } + } + // Parse device types like "AUTO:CPU,GPU" and extract individual devices auto parse_individual_devices = [&](const std::string& device_type) -> std::vector { std::vector devices; @@ -213,6 +249,9 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { auto set_target_properties = [&](const std::string& device, const ov::AnyMap& config_options, const std::vector& supported_properties) { for (const auto& [key, value] : config_options) { + if (key.find("NPUW") != std::string::npos) { + continue; + } if (is_supported_and_mutable(key, supported_properties)) { global_context_.ie_core.Get().set_property(device, ov::AnyMap{{key, value}}); } else { @@ -378,7 +417,7 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque if ((it == ort_ov_tensor_map.end()) || (it != ort_ov_tensor_map.end() && (it->second.ort_ptr != tensor.GetTensorRawData()))) { ov_tensor_data_t ov_tensor_data; - auto input = graph_input_info.at(input_idx); + const auto& input = graph_input_info.at(input_idx); ov_tensor_data.tensor_ptr = std::make_shared(input.get_element_type(), input.get_shape(), const_cast(tensor.GetTensorRawData())); @@ -663,7 +702,6 @@ void BasicBackend::Infer(OrtKernelContext* ctx) { // Requesting for an idle infer_request from a pool of infer_requests_ OVInferRequestPtr infer_request; infer_request = inferRequestsQueue_->getIdleRequest(); - #ifdef IO_BUFFER_ENABLED if ((global_context_.device_type.find("GPU") != std::string::npos) && (global_context_.context != nullptr) && global_context_.is_wholly_supported_graph) { diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h index 12502a1d83c5d..3fcf6e4384d52 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.h +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h @@ -58,7 +58,6 @@ class BasicBackend : public IBackend { GlobalContext& global_context_; SubGraphContext subgraph_context_; mutable std::mutex compute_lock_; - std::shared_ptr ie_cnn_network_; OVExeNetwork exe_network_; std::map> const_outputs_map_; std::unique_ptr inferRequestsQueue_; diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index a2f4b236213cc..4f970bc7bc287 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -18,7 +18,7 @@ struct GlobalContext { bool is_wholly_supported_graph = false; bool enable_opencl_throttling = false; bool disable_dynamic_shapes = false; - bool ep_context_embed_mode = true; + bool ep_context_embed_mode = false; bool export_ep_ctx_blob = false; bool enable_qdq_optimizer = false; bool disable_cpu_fallback = false; diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc index ee9486a62ea37..6d159db3b390d 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc @@ -21,7 +21,8 @@ Status EPCtxHandler::ExportEPCtxModel(const GraphViewer& graph_viewer, const bool& ep_context_embed_mode, std::string&& model_blob_str, const std::string& openvino_sdk_version) const { - auto model_build = graph_viewer.CreateModel(logger); + auto& metadata = graph_viewer.GetGraph().GetModel().MetaData(); + auto model_build = graph_viewer.CreateModel(logger, metadata); auto& graph_build = model_build->MainGraph(); // Get graph inputs and outputs @@ -94,17 +95,29 @@ Status EPCtxHandler::ExportEPCtxModel(const GraphViewer& graph_viewer, return Status::OK(); } -Status EPCtxHandler::ImportBlobFromEPCtxModel(const GraphViewer& graph_viewer) { +Status EPCtxHandler::ImportBlobFromEPCtxModel(const GraphViewer& graph_viewer, bool& ep_context_embed_mode) { auto node = graph_viewer.GetNode(0); auto& attrs = node->GetAttributes(); ORT_ENFORCE(attrs.count(EP_CACHE_CONTEXT) > 0); - model_stream_ = std::make_shared(attrs.at(EP_CACHE_CONTEXT).s()); + + ep_cache_context_attribute_ = &attrs.at(EP_CACHE_CONTEXT); + + ep_context_embed_mode = static_cast(attrs.at(EMBED_MODE).i()); LOGS_DEFAULT(VERBOSE) << "[OpenVINO EP] Read blob from EPContext Node"; is_valid_ep_ctx_graph_ = true; return Status::OK(); } +const std::string& EPCtxHandler::GetModelBlobStream() const { + static std::string empty; + if (ep_cache_context_attribute_ != nullptr) { + return ep_cache_context_attribute_->s(); + } else { + return empty; + } +} + bool EPCtxHandler::CheckForOVEPCtxNode(const GraphViewer& graph_viewer, std::string openvino_sdk_version) const { for (int i = 0; i < graph_viewer.MaxNodeIndex(); ++i) { auto node = graph_viewer.GetNode(i); diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h index c631d011d02b1..caab33b7db775 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h @@ -23,21 +23,21 @@ static const char SOURCE[] = "source"; class EPCtxHandler { public: EPCtxHandler() = default; - EPCtxHandler(const EPCtxHandler&) = default; + EPCtxHandler(const EPCtxHandler&) = delete; Status ExportEPCtxModel(const GraphViewer& graph_viewer, const std::string& graph_name, const logging::Logger& logger, const bool& ep_context_embed_mode, std::string&& model_blob_str, const std::string& openvino_sdk_version) const; - Status ImportBlobFromEPCtxModel(const GraphViewer& graph_viewer); + Status ImportBlobFromEPCtxModel(const GraphViewer& graph_viewer, bool& ep_context_embed_mode); bool CheckForOVEPCtxNode(const GraphViewer& graph_viewer, std::string openvino_sdk_version) const; bool IsValidOVEPCtxGraph() const { return is_valid_ep_ctx_graph_; } - [[nodiscard]] const std::shared_ptr GetModelBlobStream() const { return model_stream_; } + const std::string& GetModelBlobStream() const; private: bool is_valid_ep_ctx_graph_{false}; - std::shared_ptr model_stream_; + const onnx::AttributeProto* ep_cache_context_attribute_; }; } // namespace openvino_ep diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 19a634818a442..72a188108adef 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -2,13 +2,16 @@ // Licensed under the MIT License #include #include - +#include +#include +#include #include "core/providers/shared_library/provider_api.h" #include "core/providers/openvino/openvino_execution_provider.h" #include "core/providers/openvino/contexts.h" #include "core/providers/openvino/backend_manager.h" #include "core/providers/openvino/onnx_ctx_model_helper.h" #include "core/providers/openvino/ov_versions/capability.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "openvino/core/version.hpp" #ifdef USE_OVEP_NPU_MEMORY #include "core/providers/openvino/ov_allocator.h" @@ -150,7 +153,7 @@ common::Status OpenVINOExecutionProvider::Compile( graph_body_viewer, *GetLogger(), ep_ctx_handle_); - + backend_manager_ = backend_manager; compute_info.create_state_func = [backend_manager](ComputeContext* context, FunctionState* state) { OpenVINOEPFunctionState* p = new OpenVINOEPFunctionState(); @@ -186,16 +189,57 @@ common::Status OpenVINOExecutionProvider::Compile( #ifdef USE_OVEP_NPU_MEMORY std::vector OpenVINOExecutionProvider::CreatePreferredAllocators() { - AllocatorCreationInfo npu_allocator_info{ - [this](OrtDevice::DeviceId device_id) { - return std::make_unique(global_context_->ie_core.Get(), OrtDevice::NPU, device_id, OpenVINO_RT_NPU); - }, - 0, - }; - - // fill in allocator - return std::vector{CreateAllocator(npu_allocator_info)}; + if (global_context_->device_type.find("NPU") != std::string::npos) { + AllocatorCreationInfo npu_allocator_info{ + [this](OrtDevice::DeviceId device_id) { + return std::make_unique( + global_context_->ie_core.Get(), + OrtDevice::NPU, + device_id, + OpenVINO_RT_NPU); + }, + 0, + }; + + // fill in allocator + return std::vector{CreateAllocator(npu_allocator_info)}; + } else { + return std::vector{}; + } } #endif +common::Status OpenVINOExecutionProvider::SetEpDynamicOptions(gsl::span keys, + gsl::span values) { + std::string workload_type = ""; + // Ensure the number of keys and values match + if (keys.size() != values.size()) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Mismatched keys and values sizes."); + } + + for (size_t i = 0; i < keys.size(); ++i) { + std::string key = keys[i]; + std::string value = values[i]; + + if (key == kOrtEpDynamicOptionsWorkloadType) { + if (value == "Efficient") { + workload_type = "EFFICIENT"; + } else if (value == "Default") { + workload_type = "DEFAULT"; + } else { + LOGS_DEFAULT(WARNING) << "Unknown workload_type - ignoring " << key << "/" << value; + LOGS_DEFAULT(WARNING) << "Supported types are 'Efficient' and 'Default' \n"; + } + if (workload_type != "") { + LOGS_DEFAULT(INFO) << "SetEpDynamicOptions - modifying: " << key << "/" << value; + ov::CompiledModel& ov_compiled_model = backend_manager_->GetOVCompiledModel(); + ov_compiled_model.set_property(ov::workload_type(workload_type)); + } + } else { + // Handle unknown options + LOGS_DEFAULT(WARNING) << "Unknown key/value pair - ignoring " << key << "/" << value; + } + } + return Status::OK(); +} } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.h b/onnxruntime/core/providers/openvino/openvino_execution_provider.h index 7d9da65ea7e07..d5c22a4e2a9e4 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.h +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.h @@ -90,7 +90,7 @@ struct OpenVINOExecutionProviderInfo { bool export_ep_ctx_blob_{false}; bool enable_qdq_optimizer_{false}; bool disable_cpu_fallback_{false}; - bool so_epctx_embed_mode_{true}; + bool so_epctx_embed_mode_{false}; OpenVINOExecutionProviderInfo() = delete; @@ -159,7 +159,7 @@ struct OpenVINOExecutionProviderInfo { device_type_ = std::move(dev_type); } else if (dev_type.find("HETERO") == 0 || dev_type.find("MULTI") == 0 || dev_type.find("AUTO") == 0) { std::vector devices = parseDevices(dev_type, available_devices); - device_type_ = dev_type; + device_type_ = std::move(dev_type); } else { ORT_THROW("Invalid device string: " + dev_type); } @@ -188,6 +188,9 @@ class OpenVINOExecutionProvider : public IExecutionProvider { Status Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) override; + Status SetEpDynamicOptions(gsl::span /*keys*/, + gsl::span /*values*/) override; + const void* GetExecutionHandle() const noexcept override { return nullptr; } @@ -196,6 +199,7 @@ class OpenVINOExecutionProvider : public IExecutionProvider { #endif private: std::unique_ptr global_context_; + std::shared_ptr backend_manager_; openvino_ep::EPCtxHandler ep_ctx_handle_{}; }; diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index b46106db3c232..5855cb594a08e 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -7,6 +7,7 @@ #include "core/providers/openvino/openvino_provider_factory.h" #include "core/providers/openvino/openvino_execution_provider.h" #include "core/providers/openvino/openvino_provider_factory_creator.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "nlohmann/json.hpp" namespace onnxruntime { @@ -50,13 +51,13 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory { }; std::unique_ptr OpenVINOProviderFactory::CreateProvider() { - bool so_disable_cpu_fallback = config_options_.GetConfigOrDefault("session.disable_cpu_ep_fallback", "0") == "1"; - bool so_export_ep_ctx_blob = config_options_.GetConfigOrDefault("ep.context_enable", "0") == "1"; - bool so_epctx_embed_mode = config_options_.GetConfigOrDefault("ep.context_embed_mode", "1") == "1"; - std::string so_cache_path = config_options_.GetConfigOrDefault("ep.context_file_path", "").c_str(); + bool so_disable_cpu_fallback = config_options_.GetConfigOrDefault(kOrtSessionOptionsDisableCPUEPFallback, "0") == "1"; + bool so_export_ep_ctx_blob = config_options_.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; + bool so_epctx_embed_mode = config_options_.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "0") == "1"; + std::string so_cache_path = config_options_.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "").c_str(); if (so_export_ep_ctx_blob && !so_cache_path.empty()) { - cache_dir_ = so_cache_path; + cache_dir_ = std::move(so_cache_path); auto file_path = std::filesystem::path(cache_dir_); // ep_context_file_path_ file extension must be .onnx if (file_path.extension().generic_string() == ".onnx") { @@ -247,7 +248,7 @@ struct OpenVINO_Provider : Provider { LOGS_DEFAULT(WARNING) << "Unsupported JSON value type for key: " << inner_key << ". Skipping key."; } } - target_map[key] = inner_map; + target_map[key] = std::move(inner_map); } } catch (const nlohmann::json::parse_error& e) { // Handle syntax errors in JSON diff --git a/onnxruntime/core/providers/openvino/ov_allocator.cc b/onnxruntime/core/providers/openvino/ov_allocator.cc index 6700244b754d8..0e5ff8ff98efb 100644 --- a/onnxruntime/core/providers/openvino/ov_allocator.cc +++ b/onnxruntime/core/providers/openvino/ov_allocator.cc @@ -39,7 +39,6 @@ void* OVRTAllocator::Alloc(size_t size) { } catch (const ov::Exception& e) { ORT_THROW(std::string("Alloc failed: ") + e.what()); } - return nullptr; } void OVRTAllocator::Free(void* p) { diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 7e8681d304abf..12ab7ecede031 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -109,7 +109,7 @@ OVExeNetwork OVCore::CompileModel(const std::string& onnx_model, } } -OVExeNetwork OVCore::ImportModel(std::shared_ptr model_stream, +OVExeNetwork OVCore::ImportModel(const std::string& model_string, std::string hw_target, const ov::AnyMap& device_config, bool embed_mode, @@ -117,10 +117,10 @@ OVExeNetwork OVCore::ImportModel(std::shared_ptr model_strea try { ov::CompiledModel obj; if (embed_mode) { - obj = oe.import_model(*model_stream, hw_target, device_config); + std::istringstream model_stream(model_string); + obj = oe.import_model(model_stream, hw_target, device_config); } else { - std::string blob_file_path = (*model_stream).str(); - std::ifstream modelStream(blob_file_path, std::ios_base::binary | std::ios_base::in); + std::ifstream modelStream(model_string, std::ios_base::binary | std::ios_base::in); obj = oe.import_model(modelStream, hw_target, {}); diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index f4da4ea3e3244..c3417003f8e1f 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -54,7 +54,7 @@ class OVCore { ov::AnyMap& device_config, const std::string& name); // OV Interface for Import model Stream - OVExeNetwork ImportModel(std::shared_ptr model_stream, + OVExeNetwork ImportModel(const std::string& model_string, std::string hw_target, const ov::AnyMap& device_config, bool embed_mode, diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index 0d7ac64d86e68..3e780f74145ae 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -35,16 +35,14 @@ GetCapability::GetCapability(const GraphViewer& graph_viewer_param, device_type_ = "CPU"; if (enable_qdq_optimizer) npu_qdq_optimizer_enabled = true; } -#if OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 1 - data_ops_ = new DataOps(graph_viewer_, V_2024_1, device_type_, npu_qdq_optimizer_enabled); -#elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 2 - data_ops_ = new DataOps(graph_viewer_, V_2024_2, device_type_, npu_qdq_optimizer_enabled); -#elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 3 - data_ops_ = new DataOps(graph_viewer_, V_2024_3, device_type_, npu_qdq_optimizer_enabled); -#elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 4 +#if OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 4 data_ops_ = new DataOps(graph_viewer_, V_2024_4, device_type_, npu_qdq_optimizer_enabled); +#elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 5 + data_ops_ = new DataOps(graph_viewer_, V_2024_5, device_type_, npu_qdq_optimizer_enabled); +#elif OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR == 0 + data_ops_ = new DataOps(graph_viewer_, V_2025_0, device_type_, npu_qdq_optimizer_enabled); #else - data_ops_ = new DataOps(graph_viewer_, V_2024_4, device_type_, npu_qdq_optimizer_enabled); + data_ops_ = new DataOps(graph_viewer_, V_2025_0, device_type_, npu_qdq_optimizer_enabled); #endif } diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index e8f6ae0a43734..f118f057ac11e 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ -118,6 +118,7 @@ std::vector supported_op_mode = { {"CumSum", V_2022_1, {"CPU", "GPU"}}, {"DepthToSpace", V_2020_4, {"CPU", "GPU"}}, {"DequantizeLinear", V_2021_4, {"CPU", "GPU"}}, + {"DequantizeLinear", V_2024_4, {"NPU"}}, {"Div", V_2020_4, {"CPU", "GPU"}}, {"Dropout", V_2020_4, {"CPU", "GPU"}}, {"Elu", V_2020_4, {"CPU", "GPU"}}, @@ -254,6 +255,8 @@ void DataOps::populate_types_supported() { std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); supported_types_initializer_.insert( std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); + supported_types_initializer_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16)); supported_types_initializer_.insert( std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16)); supported_types_initializer_.insert( @@ -262,6 +265,10 @@ void DataOps::populate_types_supported() { std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); supported_types_initializer_.insert( std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); + supported_types_initializer_.insert( + std::make_pair(V_2024_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT4)); + supported_types_initializer_.insert( + std::make_pair(V_2024_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4)); supported_types_npu_.insert( std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); @@ -285,6 +292,10 @@ void DataOps::populate_types_supported() { std::make_pair(V_2024_3, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FN)); supported_types_npu_.insert( std::make_pair(V_2024_3, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FNUZ)); + supported_types_npu_.insert( + std::make_pair(V_2024_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT4)); + supported_types_npu_.insert( + std::make_pair(V_2024_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4)); supported_types_cpu_.insert( std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); @@ -304,6 +315,10 @@ void DataOps::populate_types_supported() { std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); supported_types_cpu_.insert( std::make_pair(V_2022_2, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); + supported_types_cpu_.insert( + std::make_pair(V_2024_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT4)); + supported_types_cpu_.insert( + std::make_pair(V_2024_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4)); supported_types_gpu_.insert( std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); @@ -319,6 +334,10 @@ void DataOps::populate_types_supported() { std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); supported_types_gpu_.insert( std::make_pair(V_2022_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); + supported_types_gpu_.insert( + std::make_pair(V_2024_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT4)); + supported_types_gpu_.insert( + std::make_pair(V_2024_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4)); } void DataOps::populate_op_mode_supported() { @@ -336,6 +355,7 @@ void DataOps::populate_op_mode_supported() { no_dimension_supported_.push_back({"Floor", V_2020_4, {"All"}}); no_dimension_supported_.push_back({"Gather", V_2020_4, {"All"}}); no_dimension_supported_.push_back({"Identity", V_2023_0, {"All"}}); + no_dimension_supported_.push_back({"If", V_2022_3, {"CPU", "GPU"}}); no_dimension_supported_.push_back({"Less", V_2022_1, {"CPU"}}); no_dimension_supported_.push_back({"Loop", V_2021_4, {"All"}}); no_dimension_supported_.push_back({"Min", V_2020_4, {"All"}}); @@ -368,7 +388,7 @@ void DataOps::populate_op_mode_supported() { // populate unsupportedmode_t { - UnsupportedOpMode obj = {{V_2024_1, V_2024_2, V_2024_3, V_2024_4}, + UnsupportedOpMode obj = {{V_2024_1, V_2024_2, V_2024_3, V_2024_4, V_2024_5, V_2025_0}, [this](const Node* node, const InitializedTensorSet&) { // If the Input of ReduceMax op is UINT8, it is rejected (Due to output mismatch) for (size_t i = 0; i < node->InputDefs().size(); i++) { @@ -383,7 +403,8 @@ void DataOps::populate_op_mode_supported() { op_list_.insert({"ReduceMax", obj}); } { - UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, V_2024_3, V_2024_4}, + UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, + V_2024_3, V_2024_4, V_2024_5, V_2025_0}, [this](const Node* node, const InitializedTensorSet&) { const auto& input_arg = node->InputDefs()[1]; auto shape = input_arg->Shape(); @@ -400,7 +421,8 @@ void DataOps::populate_op_mode_supported() { op_list_.insert({"Reshape", obj}); } { - UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, V_2024_3, V_2024_4}, + UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, + V_2024_3, V_2024_4, V_2024_5, V_2025_0}, [this](const Node* node, const InitializedTensorSet&) { // If the operator is unsqueeze // If axes is an input, then we cannot produce a static graph. @@ -415,7 +437,8 @@ void DataOps::populate_op_mode_supported() { op_list_.insert({"Unsqueeze", obj}); } { - UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, V_2024_3, V_2024_4}, + UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, V_2024_3, V_2024_4, V_2024_5, + V_2025_0}, [this](const Node* node, const InitializedTensorSet&) { // check for attributes auto& upsample_attr = node->GetAttributes(); diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.h b/onnxruntime/core/providers/openvino/ov_versions/data_ops.h index 5cd4c8658fb77..07fa36f355d55 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.h +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.h @@ -31,7 +31,9 @@ enum versionNum { V_2024_1, V_2024_2, V_2024_3, - V_2024_4 + V_2024_4, + V_2024_5, + V_2025_0 }; using VersionNum = enum versionNum; diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc index f1df1abf4c49a..387aaf9985b4c 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc @@ -30,6 +30,10 @@ constexpr std::string_view DuplicateDQ = "/duplicated"; constexpr ONNX_NAMESPACE::TensorProto_DataType DT_UINT16 = ONNX_NAMESPACE::TensorProto_DataType_UINT16; constexpr ONNX_NAMESPACE::TensorProto_DataType DT_INT16 = ONNX_NAMESPACE::TensorProto_DataType_INT16; +constexpr ONNX_NAMESPACE::TensorProto_DataType DT_UINT8 = ONNX_NAMESPACE::TensorProto_DataType_UINT8; +constexpr ONNX_NAMESPACE::TensorProto_DataType DT_INT8 = ONNX_NAMESPACE::TensorProto_DataType_INT8; +constexpr ONNX_NAMESPACE::TensorProto_DataType DT_UINT4 = ONNX_NAMESPACE::TensorProto_DataType_UINT4; +constexpr ONNX_NAMESPACE::TensorProto_DataType DT_INT4 = ONNX_NAMESPACE::TensorProto_DataType_INT4; // Return the data type of the qdq node. // Check output type of Q and input type of DQ to determine it as zero_point is an optional input and may not exist @@ -218,7 +222,7 @@ static bool DQFeedsASupportedOp(const Node* dq_node) { } else { return true; } - } else if (op_type == "Add") { + } else if (op_type == "Add" && !(GetQDQDataType(dq_node) == DT_UINT16 || GetQDQDataType(dq_node) == DT_INT16)) { // Add => keeps all DQs return true; } @@ -687,7 +691,7 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, // Get all the NodeUnits in the graph_viewer std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(&src_graph); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(&src_graph, logger); std::unordered_set seen_node_units; const auto& node_indices = src_graph.GetNodesInTopologicalOrder(); diff --git a/onnxruntime/core/providers/provider_factory_creators.h b/onnxruntime/core/providers/provider_factory_creators.h index 41e418d9eb97f..1c62c1a7a8d0b 100644 --- a/onnxruntime/core/providers/provider_factory_creators.h +++ b/onnxruntime/core/providers/provider_factory_creators.h @@ -78,10 +78,6 @@ #include "core/providers/tensorrt/tensorrt_provider_factory_creator.h" #endif -#if defined(USE_TVM) -#include "core/providers/tvm/tvm_provider_factory_creator.h" -#endif - #if defined(USE_VITISAI) #include "core/providers/vitisai/vitisai_provider_factory_creator.h" #endif diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index 57ae8c354abb7..79674fd706151 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -87,7 +87,8 @@ Status CreateNodeArgs(const std::vector& names, Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, const onnxruntime::PathString& ctx_onnx_model_path, QnnBackendManager* qnn_backend_manager, - QnnModelLookupTable& qnn_models) { + QnnModelLookupTable& qnn_models, + int64_t max_spill_fill_size) { ORT_RETURN_IF_NOT(EPCONTEXT_OP == main_context_node.OpType(), "Should only filter in the EPContext node."); NodeAttrHelper node_helper(main_context_node); bool is_embed_mode = node_helper.Get(EMBED_MODE, true); @@ -96,7 +97,8 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast(context_binary.c_str()), static_cast(context_binary.length()), main_context_node.Name(), - qnn_models); + qnn_models, + max_spill_fill_size); } std::filesystem::path folder_path = std::filesystem::path(ctx_onnx_model_path).parent_path(); @@ -145,17 +147,46 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(), static_cast(buffer_size), main_context_node.Name(), - qnn_models); + qnn_models, + max_spill_fill_size); +} + +Status TryGetMaxSpillFillSize(const std::vector& fused_nodes_and_graphs, + uint32_t total_context_size, + int64_t& max_spill_fill_size, + std::vector& main_context_pos_list) { + max_spill_fill_size = 0; + int max_size_index = 0; + for (uint32_t i = 0; i < total_context_size; ++i) { + auto index = main_context_pos_list[i]; + const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[index].filtered_graph); + ORT_RETURN_IF(main_ctx_graph_viewer.NumberOfNodes() != 1, "One filtered graph should has only one EPContext node!"); + const auto& ep_context_node = main_ctx_graph_viewer.Nodes().begin(); + NodeAttrHelper node_helper(*ep_context_node); + int64_t max_size = node_helper.Get(MAX_SIZE, static_cast(0)); + if (max_size > max_spill_fill_size) { + max_spill_fill_size = max_size; + max_size_index = i; + } + } + if (0 != max_size_index) { + int tmp_index = main_context_pos_list[0]; + main_context_pos_list[0] = main_context_pos_list[max_size_index]; + main_context_pos_list[max_size_index] = tmp_index; + } + + return Status::OK(); } Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::PathString& ctx_onnx_model_path, QnnBackendManager* qnn_backend_manager, QnnModelLookupTable& qnn_models, - const logging::Logger& logger) { + const logging::Logger& logger, + int64_t max_spill_fill_size) { ORT_RETURN_IF(graph_viewer.NumberOfNodes() != 1, "One filtered graph should has only one EPContext node!"); Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager, - qnn_models); + qnn_models, max_spill_fill_size); // This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model if (!status.IsOK()) { @@ -196,6 +227,7 @@ Status CreateEPContextNodes(Model* model, const QnnModelLookupTable& qnn_models, const onnxruntime::PathString& context_cache_path, bool qnn_context_embed_mode, + uint64_t max_spill_fill_buffer_size, const logging::Logger& logger) { auto& graph = model->MainGraph(); @@ -238,6 +270,7 @@ Status CreateEPContextNodes(Model* model, } of_stream.write(reinterpret_cast(buffer), buffer_size); ep_node.AddAttribute(EP_CACHE_CONTEXT, context_cache_name); + ep_node.AddAttribute(MAX_SIZE, static_cast(max_spill_fill_buffer_size)); } } else { ep_node.AddAttribute(MAIN_CONTEXT, static_cast(0)); diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h index f308a7456d46c..92c5391b40f09 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h @@ -28,6 +28,7 @@ static const std::string EP_CACHE_CONTEXT = "ep_cache_context"; static const std::string EP_SDK_VER = "ep_sdk_version"; static const std::string PARTITION_NAME = "partition_name"; static const std::string SOURCE = "source"; +static const std::string MAX_SIZE = "max_size"; bool GraphHasEpContextNode(const onnxruntime::GraphViewer& graph_viewer); @@ -49,13 +50,20 @@ bool ValidateContextCacheFilePath(bool is_qnn_ctx_model, Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, const onnxruntime::PathString& ctx_onnx_model_path, QnnBackendManager* qnn_backend_manager, - QnnModelLookupTable& qnn_models); + QnnModelLookupTable& qnn_models, + int64_t max_spill_fill_size); + +Status TryGetMaxSpillFillSize(const std::vector& fused_nodes_and_graphs, + uint32_t total_context_size, + int64_t& max_spill_fill_size, + std::vector& main_context_pos_list); Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::PathString& ctx_onnx_model_path, QnnBackendManager* qnn_backend_manager, QnnModelLookupTable& qnn_models, - const logging::Logger& logger); + const logging::Logger& logger, + int64_t max_spill_fill_size); Status CreateEPContextNodes(Model* model, unsigned char* buffer, @@ -65,6 +73,7 @@ Status CreateEPContextNodes(Model* model, const std::unordered_map>& qnn_models, const onnxruntime::PathString& context_cache_path, bool qnn_context_embed_mode, + uint64_t max_spill_fill_buffer_size, const logging::Logger& logger); } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc index dd5c6a5a79cdb..6ef17b40d274b 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc @@ -83,6 +83,7 @@ OpBuilderRegistrations::OpBuilderRegistrations() { CreateReduceOpBuilder("ReduceMin", *this); CreateReduceOpBuilder("ReduceProd", *this); CreateReduceOpBuilder("ReduceSum", *this); + CreateReduceOpBuilder("ReduceL2", *this); } { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc index d089235ceaa02..d1a0e88686f39 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc @@ -87,10 +87,10 @@ Status LayerNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[BIAS_IDX], logger, input_names)); } -#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 17) +#if QNN_API_VERSION_MAJOR == 2 && QNN_API_VERSION_MINOR >= 17 && QNN_API_VERSION_MINOR <= 20 if (!has_bias_input && IsNpuBackend(qnn_model_wrapper.GetQnnBackendType())) { - // Bias is implicit. QNN SDK 2.24+ (QNN API version 2.17+) has a validation bug for implicit bias inputs, - // so provide an explicit bias of all 0 (quantized int32). + // Bias is implicit. QNN SDK 2.24 to 2.27 (QNN API version 2.17 to 2.20) has a validation bug for + // implicit bias inputs, so provide an explicit bias of all 0 (quantized int32). TensorInfo x_input_info = {}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[X_IDX], x_input_info)); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/reduce_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/reduce_op_builder.cc index 2aefe5f6b8e71..77bc58bd6f833 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/reduce_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/reduce_op_builder.cc @@ -6,15 +6,15 @@ #include #include +#include "core/common/safeint.h" +#include "onnx/defs/data_type_utils.h" #include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" #include "core/framework/endian_utils.h" -#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" #include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/common/safeint.h" -#include "onnx/defs/data_type_utils.h" - -#include "base_op_builder.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_utils.h" namespace onnxruntime { namespace qnn { @@ -25,6 +25,7 @@ enum ReduceOpType { REDUCE_OP_TYPE_MEAN, REDUCE_OP_TYPE_PROD, REDUCE_OP_TYPE_SUM, + REDUCE_OP_TYPE_L2, REDUCE_OP_TYPE_COUNT, REDUCE_OP_TYPE_UNKNOWN, @@ -41,6 +42,8 @@ ReduceOpType GetReduceOpType(const std::string& op_type) { return REDUCE_OP_TYPE_PROD; } else if (op_type == "ReduceSum") { return REDUCE_OP_TYPE_SUM; + } else if (op_type == "ReduceL2") { + return REDUCE_OP_TYPE_L2; } else { return REDUCE_OP_TYPE_UNKNOWN; } @@ -51,21 +54,16 @@ class ReduceOpBuilder : public BaseOpBuilder { ReduceOpBuilder() : BaseOpBuilder("ReduceOpBuilder") {} ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ReduceOpBuilder); - Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, + Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger) const override final ORT_MUST_USE_RESULT; protected: - Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - const logging::Logger& logger, + Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger, std::vector& input_names, bool do_op_validation = false) const override ORT_MUST_USE_RESULT; - Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - std::vector&& input_names, - const logging::Logger& logger, + Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, + std::vector&& input_names, const logging::Logger& logger, bool do_op_validation) const override ORT_MUST_USE_RESULT; private: @@ -84,7 +82,8 @@ const std::array ReduceOpBuilder::opset_with_axes_as_ 18, // ReduceMin 18, // ReduceMean 18, // ReduceProd - 13 // ReduceSum + 13, // ReduceSum + 18, // ReduceL2 }; Status ReduceOpBuilder::GetAxesSet(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, @@ -175,8 +174,7 @@ Status ReduceOpBuilder::GetAxesSet(QnnModelWrapper& qnn_model_wrapper, const Nod return Status::OK(); } -Status ReduceOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, +Status ReduceOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger) const { ReduceOpType reduce_op_type = GetReduceOpType(node_unit.OpType()); if (reduce_op_type == ReduceOpType::REDUCE_OP_TYPE_UNKNOWN) { @@ -188,13 +186,17 @@ Status ReduceOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: ReduceProd operator not supported by HTP backend."); } + // ReduceL2 is composed by Mul->ReduceSum->Sqrt, it's not easy to set the quantization parameters for the activation + // tensors between, so we don't support ReduceL2 with quantized input for now. + if (reduce_op_type == ReduceOpType::REDUCE_OP_TYPE_L2 && node_unit.Inputs()[0].quant_param.has_value()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: ReduceL2 operator does not support quantized input for now."); + } + return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true); } -Status ReduceOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - const logging::Logger& logger, - std::vector& input_names, +Status ReduceOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, + const logging::Logger& logger, std::vector& input_names, bool do_op_validation) const { ORT_UNUSED_PARAMETER(do_op_validation); @@ -207,11 +209,9 @@ Status ReduceOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } -Status ReduceOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, +Status ReduceOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, std::vector&& input_names, - const logging::Logger& logger, - bool do_op_validation) const { + const logging::Logger& logger, bool do_op_validation) const { NodeAttrHelper node_attr_helper(node_unit); std::vector param_tensor_names; @@ -229,8 +229,8 @@ Status ReduceOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w std::transform(axes_set.begin(), axes_set.end(), axes_data.begin(), [](AxesOnnxIntType item) { return SafeInt(item); }); - QnnParamWrapper axes_param(node_unit.Index(), node_unit.Name(), QNN_OP_REDUCE_MAX_PARAM_AXES, - std::move(axes_shape), std::move(axes_data)); + QnnParamWrapper axes_param(node_unit.Index(), node_unit.Name(), QNN_OP_REDUCE_MAX_PARAM_AXES, std::move(axes_shape), + std::move(axes_data)); param_tensor_names.push_back(axes_param.GetParamTensorName()); qnn_model_wrapper.AddParamWrapper(std::move(axes_param)); @@ -245,10 +245,57 @@ Status ReduceOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w param_tensor_names.push_back(keep_dims_param.GetParamTensorName()); qnn_model_wrapper.AddParamWrapper(std::move(keep_dims_param)); - ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, - std::move(input_names), - std::move(param_tensor_names), - logger, do_op_validation, GetQnnOpType(node_unit.OpType()))); + if (node_unit.OpType() == "ReduceL2") { + // If ReduceL2, QNN doesn't have a single Op for it, we need to add a + // ElementWiseMultiply->ReduceSum->ElementWiseSquareRoot node sequence. + const auto& input = node_unit.Inputs()[0]; + const auto& output = node_unit.Outputs()[0]; + std::vector input_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input.node_arg, input_shape), "Cannot get input shape."); + std::vector output_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(output.node_arg, output_shape), "Cannot get output shape."); + ORT_ENFORCE(!input.quant_param.has_value(), "Input tensor must not be quantized."); + const auto* type_proto = output.node_arg.TypeAsProto(); + Qnn_DataType_t qnn_data_type = QNN_DATATYPE_FLOAT_32; + ORT_RETURN_IF_ERROR(utils::GetQnnDataType(false, type_proto, qnn_data_type)); + const std::string input_name = input_names[0]; + + // Step 1: y_pow2 = x * x, using ElementWiseMultiply instead of ElementWisePower so we don't need to add a new + // initializer tensor for the power value. The performance difference is negligible. + const std::string pow2_name = input_name + "_ort_qnn_ep_pow2"; + QnnTensorWrapper pow2_tensorwrapper(pow2_name, QNN_TENSOR_TYPE_NATIVE, qnn_data_type, QnnQuantParamsWrapper(), + std::move(input_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(pow2_tensorwrapper)), "AddTensorWrapper failed"); + ORT_RETURN_IF_NOT( + qnn_model_wrapper.CreateQnnNode(pow2_name, QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_ELEMENT_WISE_MULTIPLY, + {input_name, input_name}, {pow2_name}, {}, do_op_validation), + "CreateQnnNode failed"); + + // Step 2: y_pow2_sum = ReduceSum(y_pow2) + const std::string reduce_name = input_name + "_ort_qnn_ep_pow2_sum"; + QnnTensorWrapper reduce_tensorwrapper(reduce_name, QNN_TENSOR_TYPE_NATIVE, qnn_data_type, QnnQuantParamsWrapper(), + std::vector(output_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reduce_tensorwrapper)), "AddTensorWrapper failed"); + ORT_RETURN_IF_NOT( + qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_REDUCE_SUM, + {pow2_name}, {reduce_name}, std::move(param_tensor_names), do_op_validation), + "CreateQnnNode failed"); + + // Step 3: y = Sqrt(y_pow2_sum) + Qnn_TensorType_t output_tensor_type = + qnn_model_wrapper.IsGraphOutput(output.node_arg.Name()) ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; + QnnTensorWrapper sqrt_tensorwrapper(output.node_arg.Name(), output_tensor_type, qnn_data_type, + QnnQuantParamsWrapper(), std::move(output_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(sqrt_tensorwrapper)), "AddTensorWrapper failed"); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(input_name + "_ort_qnn_ep_pow2_sum_sqrt", + QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_ELEMENT_WISE_SQUARE_ROOT, + {reduce_name}, {output.node_arg.Name()}, {}, do_op_validation), + "CreateQnnNode failed"); + } else { + ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), + std::move(param_tensor_names), logger, do_op_validation, + GetQnnOpType(node_unit.OpType()))); + } return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 34dcbd1d77fca..3af646c3ce13a 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -8,12 +8,14 @@ #include #include "QnnOpDef.h" #include "HTP/QnnHtpPerfInfrastructure.h" +#include "HTP/QnnHtpSystemContext.h" #include "CPU/QnnCpuCommon.h" // TODO: not exist for Windows yet // #include "GPU/QnnGpuCommon.h" #include "DSP/QnnDspCommon.h" #include "HTP/QnnHtpCommon.h" #include "HTP/QnnHtpContext.h" +#include "Saver/QnnSaver.h" #include #include "core/framework/endian_utils.h" #include "core/common/logging/capture.h" @@ -531,11 +533,11 @@ Status QnnBackendManager::CreateContext() { } QnnContext_Config_t context_config_weight_sharing = QNN_CONTEXT_CONFIG_INIT; - QnnHtpContext_CustomConfig_t customConfig; - customConfig.option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED; - customConfig.weightSharingEnabled = enable_htp_weight_sharing_; + QnnHtpContext_CustomConfig_t custom_config; + custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED; + custom_config.weightSharingEnabled = enable_htp_weight_sharing_; context_config_weight_sharing.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM; - context_config_weight_sharing.customConfig = &customConfig; + context_config_weight_sharing.customConfig = &custom_config; QnnContext_Config_t context_priority_config = QNN_CONTEXT_CONFIG_INIT; ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, context_priority_config)); @@ -614,9 +616,78 @@ std::unique_ptr QnnBackendManager::GetContextBinaryBuffer(uint6 return context_buffer; } +Status QnnBackendManager::GetMaxSpillFillBufferSize(unsigned char* buffer, + uint64_t buffer_length, + uint64_t& max_spill_fill_buffer_size) { + max_spill_fill_buffer_size = 0; + // spill fill starts from 2.28 +#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 21) + bool result = nullptr == qnn_sys_interface_.systemContextCreate || + nullptr == qnn_sys_interface_.systemContextGetBinaryInfo || + nullptr == qnn_sys_interface_.systemContextFree; + ORT_RETURN_IF(result, "Failed to get valid function pointer."); + + QnnSystemContext_Handle_t sys_ctx_handle = nullptr; + auto rt = qnn_sys_interface_.systemContextCreate(&sys_ctx_handle); + ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create system handle."); + + const QnnSystemContext_BinaryInfo_t* binary_info = nullptr; + Qnn_ContextBinarySize_t binary_info_size{0}; + rt = qnn_sys_interface_.systemContextGetBinaryInfo(sys_ctx_handle, + static_cast(buffer), + buffer_length, + &binary_info, + &binary_info_size); + ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to get context binary info."); + + // binary_info life cycle is here + // Binary info to graph info + // retrieve Qnn graph info from binary info + ORT_RETURN_IF(nullptr == binary_info, "Qnn cached binary info is nullptr."); + uint32_t graph_count = 0; + QnnSystemContext_GraphInfo_t* graphs_info = nullptr; + if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) { + graph_count = binary_info->contextBinaryInfoV3.numGraphs; + graphs_info = binary_info->contextBinaryInfoV3.graphs; + } else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) { + graph_count = binary_info->contextBinaryInfoV2.numGraphs; + graphs_info = binary_info->contextBinaryInfoV2.graphs; + } else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) { + graph_count = binary_info->contextBinaryInfoV1.numGraphs; + graphs_info = binary_info->contextBinaryInfoV1.graphs; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported context binary info version."); + } + + for (uint32_t i = 0; i < graph_count; ++i) { + if (graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_3) { + auto htp_graph_info = reinterpret_cast(graphs_info[i].graphInfoV3.graphBlobInfo); + if (htp_graph_info->version == QNN_SYSTEM_CONTEXT_HTP_GRAPH_INFO_BLOB_VERSION_V1) { + auto spill_fill_buffer_size = htp_graph_info->contextBinaryGraphBlobInfoV1.spillFillBufferSize; + max_spill_fill_buffer_size = spill_fill_buffer_size > max_spill_fill_buffer_size ? spill_fill_buffer_size : max_spill_fill_buffer_size; + } else { + LOGS(*logger_, VERBOSE) << "Unknown context binary graph info blob version."; + } + } else if (graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_2 || + graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_1) { + LOGS(*logger_, VERBOSE) << "Skip retrieve spill file buffer size, it is not supported with graph info v1 & v2."; + } else { + LOGS(*logger_, VERBOSE) << "Unknown context binary graph info version."; + } + } +#else + ORT_UNUSED_PARAMETER(buffer); + ORT_UNUSED_PARAMETER(buffer_length); +#endif + + LOGS(*logger_, VERBOSE) << "Get max spill fill buffer size completed."; + return Status::OK(); +} + Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, std::string node_name, - QnnModelLookupTable& qnn_models) { + QnnModelLookupTable& qnn_models, + int64_t max_spill_fill_size) { bool result = nullptr == qnn_sys_interface_.systemContextCreate || nullptr == qnn_sys_interface_.systemContextGetBinaryInfo || nullptr == qnn_sys_interface_.systemContextFree; @@ -637,28 +708,60 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t // binary_info life cycle is here // Binary info to graph info - // retrieve Qnn graph infor from binary info + // retrieve Qnn graph info from binary info ORT_RETURN_IF(nullptr == binary_info, "Qnn cached binary info is nullptr."); uint32_t graph_count = 0; QnnSystemContext_GraphInfo_t* graphs_info = nullptr; if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) { graph_count = binary_info->contextBinaryInfoV1.numGraphs; graphs_info = binary_info->contextBinaryInfoV1.graphs; - } else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) { + } +#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 15) // starts from 2.22 + else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) { graph_count = binary_info->contextBinaryInfoV2.numGraphs; graphs_info = binary_info->contextBinaryInfoV2.graphs; } +#endif +#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 21) // starts from 2.28 + else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) { + graph_count = binary_info->contextBinaryInfoV3.numGraphs; + graphs_info = binary_info->contextBinaryInfoV3.graphs; + } +#endif + else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported context binary info version."); + } ORT_RETURN_IF(graph_count < 1 || graphs_info == nullptr, "Failed to get graph info from Qnn cached context."); LOGS(*logger_, VERBOSE) << "Graph count from QNN context: " << graph_count; - ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary, - "Invalid function pointer for contextCreateFromBinary."); - QnnContext_Config_t qnn_context_config = QNN_CONTEXT_CONFIG_INIT; ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, qnn_context_config)); - const QnnContext_Config_t* context_configs[] = {&qnn_context_config, nullptr}; + // Register spill fill buffer for multi context + QnnContext_Config_t spill_fill_config = QNN_CONTEXT_CONFIG_INIT; + + // The spill fill buffer is available since 2.28, API version starts from 2.21 +#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 21) + QnnHtpContext_CustomConfig_t custom_config; + custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_REGISTER_MULTI_CONTEXTS; + QnnHtpContext_GroupRegistration_t group_info; + size_t current_contexts_size = GetQnnContextSize(); + // set to 0x0 (new group) if this is the first context, otherwise point to the first context handle + // note that we already move the context with max spill fill size to the beginning of the list + group_info.firstGroupHandle = (max_spill_fill_size > 0 && current_contexts_size > 0) ? GetQnnContext(0) : 0x0; + group_info.maxSpillFillBuffer = max_spill_fill_size; // Max spill-fill buffer across contexts. Must be >0 + custom_config.groupRegistration = group_info; + spill_fill_config.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM; + spill_fill_config.customConfig = &custom_config; +#endif + QnnContext_Config_t* spill_fill_config_pointer = max_spill_fill_size > 0 ? &spill_fill_config : nullptr; + LOGS(*logger_, VERBOSE) << "Max spill fill buffer size:" << max_spill_fill_size; + + const QnnContext_Config_t* context_configs[] = {&qnn_context_config, spill_fill_config_pointer, nullptr}; + + ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary, + "Invalid function pointer for contextCreateFromBinary."); Qnn_ContextHandle_t context = nullptr; rt = qnn_interface_.contextCreateFromBinary(backend_handle_, device_handle_, @@ -667,7 +770,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, &context, profile_backend_handle_); - ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary."); + ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary. Error code: ", rt); contexts_.push_back(context); if (1 == graph_count) { // in case the EPContext node is generated from script @@ -693,7 +796,11 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t return Status::OK(); } -Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_from_cached_context) { +// need to load system lib if load from Qnn context binary +// or generate Qnn context binary is enabled -- to get the max spill fill buffer size +Status QnnBackendManager::SetupBackend(const logging::Logger& logger, + bool load_from_cached_context, + bool need_load_system_lib) { std::lock_guard lock(logger_mutex_); if (backend_setup_completed_) { LOGS(logger, VERBOSE) << "Backend setup already!"; @@ -708,7 +815,7 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_ LOGS(logger, VERBOSE) << "LoadBackend succeed."; - if (load_from_cached_context) { + if (load_from_cached_context || need_load_system_lib) { ORT_RETURN_IF_ERROR(LoadQnnSystemLib()); } @@ -927,20 +1034,6 @@ Status QnnBackendManager::SetRpcControlLatency(uint32_t htp_power_config_client_ return Status::OK(); } -void QnnBackendManager::Split(std::vector& split_string, - const std::string& tokenized_string, - const char separator) { - split_string.clear(); - std::istringstream tokenized_string_stream(tokenized_string); - while (!tokenized_string_stream.eof()) { - std::string value; - getline(tokenized_string_stream, value, separator); - if (!value.empty()) { - split_string.push_back(value); - } - } -} - Status QnnBackendManager::DestroyHTPPowerConfigID(uint32_t htp_power_config_id) { QnnDevice_Infrastructure_t qnn_device_infra = nullptr; auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra); @@ -1035,7 +1128,14 @@ Status QnnBackendManager::ExtractBackendProfilingInfo() { const QnnProfile_EventId_t* profile_events{nullptr}; uint32_t num_events{0}; Qnn_ErrorHandle_t result = qnn_interface_.profileGetEvents(profile_backend_handle_, &profile_events, &num_events); - ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != result, "Failed to get profile events. Error: ", QnnErrorHandleToString(result)); + if (!qnn_saver_path_.empty()) { // Using QNN Saver backend + // QNN SDK 2.28.2 returns QNN_SAVER_ERROR_DUMMY_RETVALUE, but previous QNN versions return QNN_PROFILE_NO_ERROR. + // We accept both values. + ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != result && QNN_SAVER_ERROR_DUMMY_RETVALUE != result, + "Failed to get profile events. Error: ", QnnErrorHandleToString(result)); + } else { + ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != result, "Failed to get profile events. Error: ", QnnErrorHandleToString(result)); + } if (num_events > 0) { LOGS(*logger_, VERBOSE) << "profile_events: " << profile_events << " num_events: " << num_events; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 43007d4a5c244..b145f2a2cd724 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -93,9 +93,10 @@ class QnnBackendManager { Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, std::string node_name, - std::unordered_map>& qnn_models); + std::unordered_map>& qnn_models, + int64_t max_spill_fill_size); - Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context); + Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context, bool need_load_system_lib); Status CreateHtpPowerCfgId(uint32_t deviceId, uint32_t coreId, uint32_t& htp_power_config_id); @@ -112,6 +113,10 @@ class QnnBackendManager { return contexts_[index]; } + size_t GetQnnContextSize() { + return contexts_.size(); + } + const Qnn_BackendHandle_t& GetQnnBackendHandle() { return backend_handle_; } const Qnn_ProfileHandle_t& GetQnnProfileHandle() { return profile_backend_handle_; } @@ -145,8 +150,6 @@ class QnnBackendManager { void ReleaseResources(); - void Split(std::vector& split_string, const std::string& tokenized_string, const char separator); - Status ExtractBackendProfilingInfo(); Status ExtractProfilingSubEvents(QnnProfile_EventId_t profile_event_id, std::ofstream& outfile, bool backendSupportsExtendedEventData, bool tracelogging_provider_ep_enabled); @@ -163,6 +166,10 @@ class QnnBackendManager { Status DestroyHTPPowerConfigID(uint32_t htp_power_config_id); + Status GetMaxSpillFillBufferSize(unsigned char* buffer, + uint64_t buffer_length, + uint64_t& max_spill_fill_buffer_size); + private: void* LoadLib(const char* file_name, int flags, std::string& error_msg); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index dc797fef2d42a..4f73e4c532ed4 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -104,7 +104,7 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, // valid throughout the lifetime of the ModelBuilder std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); // This name must be same with the EPContext node name const auto& graph_name = fused_node.Name(); @@ -321,29 +321,57 @@ Status QnnModel::DeserializeGraphInfoFromBinaryInfo(const QnnSystemContext_Graph std::vector output_tensor_wrappers; std::string graph_name; + Qnn_Tensor_t* input_tensors = nullptr; + Qnn_Tensor_t* output_tensors = nullptr; + uint32_t graph_input_num = 0; + uint32_t graph_output_num = 0; if (qnn_sys_ctx_graph_info.version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_1) { graph_name.assign(qnn_sys_ctx_graph_info.graphInfoV1.graphName); - auto graph_input_num = qnn_sys_ctx_graph_info.graphInfoV1.numGraphInputs; - auto graph_output_num = qnn_sys_ctx_graph_info.graphInfoV1.numGraphOutputs; - ORT_RETURN_IF(nullptr == qnn_sys_ctx_graph_info.graphInfoV1.graphInputs, "Graph from cached context doesn't have any inputs."); - ORT_RETURN_IF(nullptr == qnn_sys_ctx_graph_info.graphInfoV1.graphOutputs, "Graph from cached context doesn't have any outputs."); - - // Copy graph input - Qnn_Tensor_t* input_tensors = qnn_sys_ctx_graph_info.graphInfoV1.graphInputs; - for (size_t i = 0; i < graph_input_num; ++i) { - QnnTensorWrapper tensorwrapper; - ORT_RETURN_IF_ERROR(tensorwrapper.Init(input_tensors[i])); - input_tensor_wrappers.push_back(std::move(tensorwrapper)); - } + graph_input_num = qnn_sys_ctx_graph_info.graphInfoV1.numGraphInputs; + graph_output_num = qnn_sys_ctx_graph_info.graphInfoV1.numGraphOutputs; - // Copy graph output - Qnn_Tensor_t* output_tensors = qnn_sys_ctx_graph_info.graphInfoV1.graphOutputs; - for (size_t i = 0; i < graph_output_num; ++i) { - QnnTensorWrapper tensorwrapper; - ORT_RETURN_IF_ERROR(tensorwrapper.Init(output_tensors[i])); - output_tensor_wrappers.push_back(std::move(tensorwrapper)); - } + input_tensors = qnn_sys_ctx_graph_info.graphInfoV1.graphInputs; + output_tensors = qnn_sys_ctx_graph_info.graphInfoV1.graphOutputs; + } +#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 18) // start from 2.25 + else if (qnn_sys_ctx_graph_info.version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_2) { + graph_name.assign(qnn_sys_ctx_graph_info.graphInfoV2.graphName); + graph_input_num = qnn_sys_ctx_graph_info.graphInfoV2.numGraphInputs; + graph_output_num = qnn_sys_ctx_graph_info.graphInfoV2.numGraphOutputs; + + input_tensors = qnn_sys_ctx_graph_info.graphInfoV2.graphInputs; + output_tensors = qnn_sys_ctx_graph_info.graphInfoV2.graphOutputs; + } +#endif +#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 21) // start from 2.28 + else if (qnn_sys_ctx_graph_info.version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_3) { + graph_name.assign(qnn_sys_ctx_graph_info.graphInfoV3.graphName); + graph_input_num = qnn_sys_ctx_graph_info.graphInfoV3.numGraphInputs; + graph_output_num = qnn_sys_ctx_graph_info.graphInfoV3.numGraphOutputs; + + input_tensors = qnn_sys_ctx_graph_info.graphInfoV3.graphInputs; + output_tensors = qnn_sys_ctx_graph_info.graphInfoV3.graphOutputs; } +#endif + else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported context graph info version."); + } + ORT_RETURN_IF(nullptr == input_tensors, "Graph from cached context doesn't have any inputs."); + ORT_RETURN_IF(nullptr == output_tensors, "Graph from cached context doesn't have any outputs."); + + // Copy graph input + for (size_t i = 0; i < graph_input_num; ++i) { + QnnTensorWrapper tensorwrapper; + ORT_RETURN_IF_ERROR(tensorwrapper.Init(input_tensors[i])); + input_tensor_wrappers.push_back(std::move(tensorwrapper)); + } + // Copy graph output + for (size_t i = 0; i < graph_output_num; ++i) { + QnnTensorWrapper tensorwrapper; + ORT_RETURN_IF_ERROR(tensorwrapper.Init(output_tensors[i])); + output_tensor_wrappers.push_back(std::move(tensorwrapper)); + } + Qnn_GraphHandle_t graph; auto qnn_interface = qnn_backend_manager_->GetQnnInterface(); auto rt = qnn_interface.graphRetrieve(context, graph_name.c_str(), &graph); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 6735528bebbf9..27e195dea73d2 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -204,7 +204,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio LOGS_DEFAULT(VERBOSE) << "Context cache enable: " << context_cache_enabled_; std::string embed_mode = session_options->config_options.GetConfigOrDefault( - kOrtSessionOptionEpContextEmbedMode, "1"); + kOrtSessionOptionEpContextEmbedMode, "0"); if ("1" == embed_mode) { qnn_context_embed_mode_ = true; } else if ("0" == embed_mode) { @@ -363,20 +363,24 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_fp16_precision: " << enable_HTP_FP16_precision_; } + bool enable_htp_weight_sharing = false; static const std::string QNN_HTP_WEIGHT_SHARING_ENABLED = "enable_htp_weight_sharing"; auto htp_weight_sharing_enabled_pos = provider_options_map.find(QNN_HTP_WEIGHT_SHARING_ENABLED); if (htp_weight_sharing_enabled_pos != provider_options_map.end()) { if ("1" == htp_weight_sharing_enabled_pos->second) { - enable_htp_weight_sharing_ = true; + enable_htp_weight_sharing = true; } else if ("0" == htp_weight_sharing_enabled_pos->second) { - enable_htp_weight_sharing_ = false; + enable_htp_weight_sharing = false; } else { - LOGS_DEFAULT(VERBOSE) << "Invalid enable_htp_weight_sharing: " << enable_htp_weight_sharing_ + LOGS_DEFAULT(VERBOSE) << "Invalid enable_htp_weight_sharing: " << enable_htp_weight_sharing << " only 0 or 1 allowed. Set to 0."; } - LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_weight_sharing: " << enable_htp_weight_sharing_; + LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_weight_sharing: " << enable_htp_weight_sharing; } + // Add this option because this feature requires QnnSystem lib and it's no supported for Windows x86_64 platform + enable_spill_fill_buffer_ = ParseBoolOption("enable_htp_spill_fill_buffer", false, provider_options_map); + model_settings_.offload_graph_io_quantization = ParseBoolOption("offload_graph_io_quantization", false, provider_options_map); @@ -396,7 +400,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio device_id_, htp_arch, soc_model, - enable_htp_weight_sharing_); + enable_htp_weight_sharing); #ifdef _WIN32 auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance(); @@ -686,7 +690,8 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer // It will load the QnnSystem lib if is_qnn_ctx_model=true, and // delay the Qnn context creation to Compile() using the cached context binary - auto rt = qnn_backend_manager_->SetupBackend(logger, is_qnn_ctx_model); + // or generate context cache enable, need to use use QnnSystem lib to parse the binary to get the max spill fill buffer size + auto rt = qnn_backend_manager_->SetupBackend(logger, is_qnn_ctx_model, context_cache_enabled_ && enable_spill_fill_buffer_); if (Status::OK() != rt) { LOGS(logger, ERROR) << "QNN SetupBackend failed " << rt.ErrorMessage(); return result; @@ -713,7 +718,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); // remove is_qnn_ctx_model related code const auto supported_nodes = GetSupportedNodes(graph_viewer, node_unit_map, @@ -934,6 +939,16 @@ Status QNNExecutionProvider::Compile(const std::vector& fused std::vector main_context_pos_list; ORT_RETURN_IF_ERROR(qnn::GetMainContextNode(fused_nodes_and_graphs, main_context_pos_list)); + uint32_t total_context_size = SafeInt(main_context_pos_list.size()); + + int64_t max_spill_fill_size = 0; + + // Adjust the main_context_pos_list, move the one with max spill fill buffer to the beginning + // HTP spill fill buffer only works for multiple QNN contexts generated after QNN v2.28 + if (total_context_size > 1) { + ORT_RETURN_IF_ERROR(qnn::TryGetMaxSpillFillSize(fused_nodes_and_graphs, total_context_size, + max_spill_fill_size, main_context_pos_list)); + } for (auto main_context_pos : main_context_pos_list) { const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[main_context_pos].filtered_graph); @@ -942,7 +957,8 @@ Status QNNExecutionProvider::Compile(const std::vector& fused context_cache_path, qnn_backend_manager_.get(), qnn_models, - logger)); + logger, + max_spill_fill_size)); } for (auto fused_node_and_graph : fused_nodes_and_graphs) { @@ -984,6 +1000,13 @@ Status QNNExecutionProvider::Compile(const std::vector& fused // All partitioned graph share single QNN context, included in the same context binary uint64_t buffer_size(0); auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size); + // Get max spill fill buffer size + uint64_t max_spill_fill_buffer_size = 0; + if (enable_spill_fill_buffer_) { + ORT_RETURN_IF_ERROR(qnn_backend_manager_->GetMaxSpillFillBufferSize(context_buffer.get(), + buffer_size, + max_spill_fill_buffer_size)); + } qnn_ep_context_model_ = std::make_unique("qnn_ep_context_model", false, logger); ORT_RETURN_IF_ERROR(qnn::CreateEPContextNodes(qnn_ep_context_model_.get(), context_buffer.get(), @@ -993,6 +1016,7 @@ Status QNNExecutionProvider::Compile(const std::vector& fused qnn_models_, context_cache_path, qnn_context_embed_mode_, + max_spill_fill_buffer_size, logger)); } return Status::OK(); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 35c061de6132c..a0577e8fd87f2 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -141,7 +141,6 @@ class QNNExecutionProvider : public IExecutionProvider { std::string context_node_name_prefix_ = ""; bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session. bool qnn_context_embed_mode_ = true; - bool enable_htp_weight_sharing_ = false; int32_t vtcm_size_in_mb_ = 0; std::unique_ptr qnn_ep_context_model_; ModelMetadefIdGenerator metadef_id_generator_; @@ -150,6 +149,7 @@ class QNNExecutionProvider : public IExecutionProvider { uint32_t default_rpc_control_latency_ = 0; bool enable_HTP_FP16_precision_ = true; bool share_ep_contexts_ = false; + bool enable_spill_fill_buffer_ = false; #ifdef _WIN32 onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_ = nullptr; #endif diff --git a/onnxruntime/core/providers/rocm/gpu_data_transfer.cc b/onnxruntime/core/providers/rocm/gpu_data_transfer.cc index 635a25480b646..281a6f35a2808 100644 --- a/onnxruntime/core/providers/rocm/gpu_data_transfer.cc +++ b/onnxruntime/core/providers/rocm/gpu_data_transfer.cc @@ -6,10 +6,8 @@ #include "core/providers/rocm/gpu_data_transfer.h" #include "core/providers/rocm/rocm_common.h" +// If you make change below, please also update onnxruntime/core/providers/migraphx/gpu_data_transfer.cc namespace onnxruntime { -GPUDataTransfer::GPUDataTransfer() {} - -GPUDataTransfer::~GPUDataTransfer() {} bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HIP_PINNED || @@ -30,19 +28,23 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const // Copy only if the two addresses are different. if (dst_data != src_data) { HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToDevice)); + // Follow core/providers/cuda/gpu_data_transfer.cc to synchronize the default stream here. HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); } } else { // copy from other CPU memory to GPU, this is blocking HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice)); - HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); + if (src_device.MemType() != OrtDevice::MemType::HIP_PINNED) { + // Follow core/providers/cuda/gpu_data_transfer.cc to synchronize the default stream here. + HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); + } } } else if (src_device.Type() == OrtDevice::GPU) { // copying from GPU to CPU memory, this is blocking HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToHost)); - HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); } else { // copying between cpu memory + ORT_ENFORCE(dst_data != src_data); memcpy(dst_data, src_data, bytes); } @@ -59,7 +61,8 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, if (dst_device.Type() == OrtDevice::GPU) { if (src_device.Type() == OrtDevice::CPU) { - // copy from pinned memory to GPU, this is non-blocking + // If source are not pinned, the memory copy will be performed synchronously. + // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously. HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast(stream.GetHandle()))); } else if (src_device.Type() == OrtDevice::GPU) { // copying between GPU, this is non-blocking @@ -68,15 +71,15 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, } } } else if (src_device.Type() == OrtDevice::GPU) { - if (dst_device.Type() == OrtDevice::CPU) { - // copying from GPU to pinned memory, this is non-blocking - HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast(stream.GetHandle()))); - } + // If dest are not pinned, the memory copy will be performed synchronously. + // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously. + HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast(stream.GetHandle()))); } else { if (src_device.MemType() == OrtDevice::MemType::CUDA_PINNED) { // sync the stream first to make sure the data arrived HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(stream.GetHandle()))); } + ORT_ENFORCE(dst_data != src_data); memcpy(dst_data, src_data, bytes); } diff --git a/onnxruntime/core/providers/rocm/gpu_data_transfer.h b/onnxruntime/core/providers/rocm/gpu_data_transfer.h index 3d297bdce4a93..3d35ed52fff5c 100644 --- a/onnxruntime/core/providers/rocm/gpu_data_transfer.h +++ b/onnxruntime/core/providers/rocm/gpu_data_transfer.h @@ -10,8 +10,8 @@ namespace onnxruntime { class GPUDataTransfer : public IDataTransfer { public: - GPUDataTransfer(); - ~GPUDataTransfer(); + GPUDataTransfer() = default; + ~GPUDataTransfer() = default; bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; diff --git a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc index 1340c49c38ded..d8b7e26d17b65 100644 --- a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc @@ -16,17 +16,17 @@ using namespace onnxruntime::common; namespace onnxruntime { namespace rocm { -#define REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, end) \ +#define REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, begin, end) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ name, \ kOnnxDomain, \ - 1, end, \ + begin, end, \ T, \ kRocmExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ name); -#define REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, version) \ +#define REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, version) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ name, \ kOnnxDomain, \ @@ -37,8 +37,13 @@ namespace rocm { name); #define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \ - REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, last) \ - REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, cur) + REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, last) \ + REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, cur) + +#define REGISTER_KERNEL_ARGMIN_OR_ARGMAX(name, T) \ + REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, 11) \ + REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 12, 12) \ + REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, 13) // TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored. template @@ -830,14 +835,13 @@ template std::unique_ptr ReduceCompute(ncclResult_t retCode, const char* template void RocmCall(ncclResult_t retCode, const char* exprString, const char* libName, ncclResult_t successCode, const char* msg, const char* file, const int line); #endif -#ifdef USE_HIPBLASLT -template Status RocmCall(hipblasStatus_t retCode, const char* exprString, const char* libName, hipblasStatus_t successCode, const char* msg, const char* file, const int line); -template void RocmCall(hipblasStatus_t retCode, const char* exprString, const char* libName, hipblasStatus_t successCode, const char* msg, const char* file, const int line); -#endif - } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 02a21c033e988..0a427b146dcaa 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -41,10 +41,9 @@ class Memcpy final : public OpKernel { ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr."); Tensor* Y = ctx->Output(0, X->Shape()); ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor."); - // do we support async copy? - // The rocmMemCpyAsync will handle the pinned memory and non-pinned memory, - // so we don't need the check here. auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, Y->Location().device); + // CopyTensorAsync could handle both pinned memory and non-pinned CPU memory. + // For non-pinned CPU memory, the copy is synchronous. ORT_RETURN_IF_ERROR(gpu_data_transfer->CopyTensorAsync(*X, *Y, *ctx->GetComputeStream())); return Status::OK(); } else { @@ -927,6 +926,12 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, Dropout); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, Einsum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, double, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, float, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, double, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMin); // OpSet 13 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Pow); @@ -1164,6 +1169,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ArgMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ArgMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMin); + // OpSet 14 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, CumSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, float, Relu); @@ -1604,6 +1616,10 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1786,9 +1802,6 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { 19, IsInf)>, // opset 11 - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1880,6 +1893,13 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + // OpSet 13 BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2113,6 +2133,12 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, // OpSet 14 BuildKernelCreateInfo, @@ -2388,6 +2414,26 @@ static bool CastNeedFallbackToCPU(const onnxruntime::Node& node) { return false; } +static bool ArgMaxOrArgMinNeedFallbackToCPU(const onnxruntime::Node& node) { + // Opset 12 introduced the attribute "select_last_index" + if (node.SinceVersion() >= 12) { + const auto& node_attributes = node.GetAttributes(); + + for (auto& attr : node_attributes) { + auto& attr_name = attr.first; + auto& attr_value = attr.second; + + // It is not supported to pick the last index in case of encountering duplicate max values. + if ("select_last_index" == attr_name) { + if (attr_value.i() != 0) { + return true; + } + } + } + } + + return false; +} std::unique_ptr ROCMExecutionProvider::GetDataTransfer() const { return std::make_unique(); } @@ -2426,6 +2472,9 @@ ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, "GRU" == node.OpType()) { not_supported = true; force_inside = !not_supported; + } else if ("ArgMax" == node.OpType() || "ArgMin" == node.OpType()) { + not_supported = ArgMaxOrArgMinNeedFallbackToCPU(node); + force_inside = !not_supported; } else if ("Cast" == node.OpType()) { not_supported = CastNeedFallbackToCPU(node); // cast is not compute heavy, and may be placed outside @@ -2444,7 +2493,7 @@ ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, // For ROCM EP, exclude the subgraph that is preferred to be placed in CPU // These are usually shape related computation subgraphs // Following logic can be extended for other EPs - auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes); + auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger); std::vector> result; for (auto& node_index : candidates) { if (cpu_nodes.count(node_index) > 0) diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h b/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h index 6554ed977cef6..486ce5bfb731a 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h +++ b/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h @@ -37,26 +37,26 @@ enum ActivationType { }; template -constexpr hipblasltDatatype_t HipBlasDataTypeFor(); +constexpr hipDataType HipBlasDataTypeFor(); template <> -constexpr hipblasltDatatype_t HipBlasDataTypeFor() { - return HIPBLASLT_R_32F; +constexpr hipDataType HipBlasDataTypeFor() { + return HIP_R_32F; } template <> -constexpr hipblasltDatatype_t HipBlasDataTypeFor() { - return HIPBLASLT_R_16F; +constexpr hipDataType HipBlasDataTypeFor() { + return HIP_R_16F; } template <> -constexpr hipblasltDatatype_t HipBlasDataTypeFor() { - return HIPBLASLT_R_16B; +constexpr hipDataType HipBlasDataTypeFor() { + return HIP_R_16BF; } template <> -constexpr hipblasltDatatype_t HipBlasDataTypeFor() { - return HIPBLASLT_R_64F; +constexpr hipDataType HipBlasDataTypeFor() { + return HIP_R_64F; } template @@ -108,7 +108,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp hipblasOperation_t trans_a = MapBlasOpToHipBlasLt(); hipblasOperation_t trans_b = MapBlasOpToHipBlasLt(); - hipblasltDatatype_t in_out_datatype = HipBlasDataTypeFor(); + hipDataType in_out_datatype = HipBlasDataTypeFor(); std::vector heuristic_result; HIPBLASLT_CALL_THROW(hipblaslt_ext::getAllAlgos(handle, @@ -119,7 +119,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp in_out_datatype, in_out_datatype, in_out_datatype, - HIPBLASLT_COMPUTE_F32, + HIPBLAS_COMPUTE_32F, heuristic_result)); HIPBLASLT_CALL_THROW(hipblasLtDestroy(handle)); @@ -161,7 +161,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, row_a, col_a, lda)); HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, row_b, col_b, ldb)); HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, row_c, col_c, ldc)); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLASLT_COMPUTE_F32, HIPBLASLT_R_32F)); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F)); int batch = GetBatchCountFromParams(params); if (batch > 1) { diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index b84825236a453..45f81ed22b7f7 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -294,7 +294,8 @@ std::unique_ptr CreateGPUDataTransfer(); std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, const IExecutionProvider::IKernelLookup& kernel_lookup, - gsl::span tentative_nodes); + gsl::span tentative_nodes, + const logging::Logger& logger); std::string GetEnvironmentVar(const std::string& var_name); @@ -371,8 +372,8 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { namespace QDQ { inline std::pair>, std::unordered_map> -GetAllNodeUnits(const GraphViewer* graph_viewer) { - return g_host->QDQ__GetAllNodeUnits(graph_viewer); +GetAllNodeUnits(const GraphViewer* graph_viewer, const logging::Logger& logger) { + return g_host->QDQ__GetAllNodeUnits(graph_viewer, logger); } } // namespace QDQ diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index d3b12f9728135..aa8c367d25d51 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -369,8 +369,9 @@ std::string GetEnvironmentVar(const std::string& var_name) { std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, const IExecutionProvider::IKernelLookup& kernel_lookup, - gsl::span tentative_nodes) { - return g_host->GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes); + gsl::span tentative_nodes, + const logging::Logger& logger) { + return g_host->GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger); } namespace profiling { diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 3efc715fc3037..d182d0b9173bd 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -202,7 +202,8 @@ struct ProviderHost { virtual std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, const IExecutionProvider::IKernelLookup& kernel_lookup, - gsl::span tentative_nodes) = 0; + gsl::span tentative_nodes, + const logging::Logger& logger) = 0; virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ bool* p_data, size_t expected_size) = 0; virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ float* p_data, size_t expected_size) = 0; @@ -389,6 +390,7 @@ struct ProviderHost { virtual void AttributeProto__set_name(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) = 0; virtual void AttributeProto__set_type(ONNX_NAMESPACE::AttributeProto* p, ONNX_NAMESPACE::AttributeProto_AttributeType value) = 0; virtual ONNX_NAMESPACE::TensorProto* AttributeProto__add_tensors(ONNX_NAMESPACE::AttributeProto* p) = 0; + virtual std::string* AttributeProto__release_s(ONNX_NAMESPACE::AttributeProto* p) = 0; // GraphProto virtual std::unique_ptr GraphProto__construct() = 0; @@ -890,7 +892,7 @@ struct ProviderHost { virtual std::unique_ptr NodeUnit__OutputEdgesEnd(const NodeUnit* p) = 0; virtual std::pair>, std::unordered_map> - QDQ__GetAllNodeUnits(const GraphViewer* graph_viewer) = 0; + QDQ__GetAllNodeUnits(const GraphViewer* graph_viewer, const logging::Logger& logger) = 0; // Model virtual std::unique_ptr Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path, @@ -960,7 +962,7 @@ struct ProviderHost { // GraphViewer virtual void GraphViewer__operator_delete(GraphViewer* p) = 0; - virtual std::unique_ptr GraphViewer__CreateModel(const GraphViewer* p, const logging::Logger& logger) = 0; + virtual std::unique_ptr GraphViewer__CreateModel(const GraphViewer* p, const logging::Logger& logger, const ModelMetaData&) = 0; virtual const std::string& GraphViewer__Name(const GraphViewer* p) noexcept = 0; virtual const std::filesystem::path& GraphViewer__ModelPath(const GraphViewer* p) noexcept = 0; @@ -996,6 +998,7 @@ struct ProviderHost { bool include_outer_scope_args, int execution_order) noexcept = 0; virtual const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const = 0; + virtual IOnnxRuntimeOpSchemaCollectionPtr GraphViewer__GetSchemaRegistry(const GraphViewer* p) const = 0; // OpKernel virtual const Node& OpKernel__Node(const OpKernel* p) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index b9e0951a740a2..54249f0864cd7 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -122,6 +122,7 @@ struct AttributeProto final { void set_name(const ::std::string& value) { return g_host->AttributeProto__set_name(this, value); } void set_type(AttributeProto_AttributeType value) { return g_host->AttributeProto__set_type(this, value); } TensorProto* add_tensors() { return g_host->AttributeProto__add_tensors(this); } + std::string* release_s() { return g_host->AttributeProto__release_s(this); } typedef AttributeProto_AttributeType AttributeType; static constexpr AttributeType UNDEFINED = AttributeProto_AttributeType_UNDEFINED; @@ -1022,11 +1023,13 @@ struct Graph final { PROVIDER_DISALLOW_ALL(Graph) }; +using ModelMetaData = std::unordered_map; + class GraphViewer final { public: static void operator delete(void* p) { g_host->GraphViewer__operator_delete(reinterpret_cast(p)); } - std::unique_ptr CreateModel(const logging::Logger& logger) const { return g_host->GraphViewer__CreateModel(this, logger); } + std::unique_ptr CreateModel(const logging::Logger& logger, const ModelMetaData& metadata = ModelMetaData()) const { return g_host->GraphViewer__CreateModel(this, logger, metadata); } const std::string& Name() const noexcept { return g_host->GraphViewer__Name(this); } const std::filesystem::path& ModelPath() const noexcept { return g_host->GraphViewer__ModelPath(this); } @@ -1068,6 +1071,7 @@ class GraphViewer final { g_host->GraphViewer__ToProto(this, graph_proto, include_initializers, include_outer_scope_args, execution_order); } const Node* GetProducerNode(const std::string& node_arg_name) const { return g_host->GraphViewer__GetProducerNode(this, node_arg_name); } + IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const { return g_host->GraphViewer__GetSchemaRegistry(this); } GraphViewer() = delete; GraphViewer(const GraphViewer&) = delete; diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc index ef45d6c85d6a9..fbccd7d4a286b 100644 --- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc @@ -128,7 +128,8 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, // Serialize modelproto to string auto new_graph_viewer = graph_build.CreateGraphViewer(); - auto model = new_graph_viewer->CreateModel(*logger); + auto& metadata = graph_viewer.GetGraph().GetModel().MetaData(); + auto model = new_graph_viewer->CreateModel(*logger, metadata); auto model_proto = model->ToProto(); new_graph_viewer->ToProto(*model_proto->mutable_graph(), true, true); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 4da40823ba4e9..1b432dad44263 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1725,6 +1725,12 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv runtime_ = std::unique_ptr(nvinfer1::createInferRuntime(GetTensorrtLogger(detailed_build_log_))); } + trt_version_ = getInferLibVersion(); + CUDA_CALL_THROW(cudaRuntimeGetVersion(&cuda_version_)); + + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] TensorRT version is " << trt_version_; + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] CUDA version is " << cuda_version_; + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] TensorRT provider options: " << "device_id: " << device_id_ << ", trt_max_partition_iterations: " << max_partition_iterations_ @@ -1948,7 +1954,7 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph // Find inputs and outputs of the subgraph std::unique_ptr sub_graph = onnxruntime::IndexedSubGraph::Create(); - std::unordered_map fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add; + std::unordered_map original_inputs, fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add; std::unordered_set erased; int input_order = 0; int output_order = 0; @@ -2040,12 +2046,25 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph fused_outputs.insert(fused_outputs_to_add.begin(), fused_outputs_to_add.end()); fused_outputs.insert(graph_outputs_to_add.begin(), graph_outputs_to_add.end()); - // Sort inputs and outputs by the order they were added std::multimap inputs, outputs; + + // Get the input order of the original graph + int order = 0; + for (const auto* input : graph.GetInputs()) { + original_inputs[input] = order++; + } + + // input order needs to be consistent with original graph's input order for (auto it = fused_inputs.begin(), end = fused_inputs.end(); it != end; ++it) { - inputs.insert(std::pair(it->second, it->first)); + const auto& iter = original_inputs.find(it->first); + if (iter != original_inputs.end()) { + inputs.insert(std::pair(iter->second, iter->first)); + } else { + inputs.insert(std::pair(it->second, it->first)); + } } + // Sort outputs by the order they were added for (auto it = fused_outputs.begin(), end = fused_outputs.end(); it != end; ++it) { outputs.insert(std::pair(it->second, it->first)); } @@ -2449,23 +2468,43 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, // So, simply return the ComputeCapability here. if (graph.NumberOfNodes() == 1 && GraphHasCtxNode(graph)) { SubGraph_t supported_node_vector = {{0}, true}; - std::unique_ptr sub_graph = GetSubGraph(supported_node_vector, graph, TRTGenerateId(graph), 0); + std::unique_ptr sub_graph = GetSubGraph(supported_node_vector, graph, TRTGenerateId(graph, std::to_string(trt_version_), std::to_string(cuda_version_)), 0); result.push_back(ComputeCapability::Create(std::move(sub_graph))); return result; } // Generate unique kernel name for TRT graph - HashValue model_hash = TRTGenerateId(graph); + HashValue model_hash = TRTGenerateId(graph, std::to_string(trt_version_), std::to_string(cuda_version_)); // Get supported node list from TensorRT parser const int number_of_ort_nodes = graph.NumberOfNodes(); std::vector nodes_vector(number_of_ort_nodes); std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); - std::vector filtered_nodes_vector; + std::set exclude_ops_set; + + /* + * There is a known performance issue with the DDS ops (NonMaxSuppression, NonZero and RoiAlign) in TRT 10. + * TRT EP automatically excludes DDS ops from running on TRT. + */ + if (trt_version_ >= 100000 && trt_version_ < 110000) { + exclude_ops_set.insert("NonMaxSuppression"); + exclude_ops_set.insert("NonZero"); + exclude_ops_set.insert("RoiAlign"); + LOGS_DEFAULT(VERBOSE) << "There is a known performance issue with the DDS ops (NonMaxSuppression, NonZero and RoiAlign) in TRT 10. TRT EP automatically excludes DDS ops from running on TRT, if applicable"; + } + + SubGraphCollection_t parser_nodes_vector, supported_nodes_vector; const std::vector& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/); + bool new_subgraph = true; + + /* Iterate all the nodes and exclude the node if: + * 1. It's a control flow op and its subgraph(s) is not fully TRT eligible. + * 2. It's a DDS op. + */ for (const auto& index : nodes_vector) { const auto& node = graph.GetNode(node_index[index]); + bool supported_node = true; /* If current node is control flow op, we take different approach based on following four cases: * @@ -2477,29 +2516,43 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, * For cases 2, 3, 4, even though the control flow op is not assigned to TRT, any portion of its subgraphs that can run in TRT will be still fused and assigned to TRT EP. */ if (control_flow_op_set_.find(node->OpType()) != control_flow_op_set_.end()) { - auto sub_graphs = node->GetSubgraphs(); - if (sub_graphs.size() != 0) { - bool all_subgraphs_are_supported = true; - for (auto sub_graph : sub_graphs) { - // TRT EP should consider the empty subgraph is fully supported by TRT. - if (sub_graph->CreateGraphViewer()->NumberOfNodes() == 0) { - continue; - } - if (!AllNodesAssignedToSpecificEP(*(sub_graph->CreateGraphViewer()), kTensorrtExecutionProvider)) { - all_subgraphs_are_supported = false; - break; + auto supported_control_flow_op = [&](const Node* node) { + auto sub_graphs = node->GetSubgraphs(); + if (sub_graphs.size() != 0) { + for (auto sub_graph : sub_graphs) { + // TRT EP should consider the empty subgraph is fully supported by TRT. + if (sub_graph->CreateGraphViewer()->NumberOfNodes() == 0) { + continue; + } + if (!AllNodesAssignedToSpecificEP(*(sub_graph->CreateGraphViewer()), kTensorrtExecutionProvider)) { + // if not all its subgraphs are supported, we need to exclude this control flow op + return false; + } } } - if (!all_subgraphs_are_supported) { - // if not all its subgraphs are supported, we need to exclude this control flow op - continue; - } + return true; + }; + supported_node = supported_control_flow_op(node); + } + + // Exclude any ops, if applicable + if (exclude_ops_set.find(node->OpType()) != exclude_ops_set.end()) { + supported_node = false; + } + + if (supported_node) { + if (new_subgraph) { + parser_nodes_vector.emplace_back(); + // Mark all new graphs as "UnKnown" which will later be parsed by TRT parser + parser_nodes_vector.back().second = false; + new_subgraph = false; } + parser_nodes_vector.back().first.emplace_back(index); + } else { + new_subgraph = true; } - filtered_nodes_vector.push_back(index); } - SubGraphCollection_t supported_nodes_vector, parser_nodes_vector = {{filtered_nodes_vector, false}}; bool early_termination = false; supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination); if (early_termination) { diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index c057d48de4070..d3e0b0fba8891 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -329,6 +329,11 @@ class TensorrtExecutionProvider : public IExecutionProvider { bool cuda_graph_enable_ = false; std::string cache_prefix_; bool engine_hw_compatible_ = false; + std::string op_types_to_exclude_; + + // The format is as for TENSORRT_VERSION: (MAJOR * 100 + MINOR) * 100 + PATCH + int32_t trt_version_; + int32_t cuda_version_; // The OrtAllocator object will be get during ep compute time // and should be kept for the lifetime of TRT EP object. diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h index 95abcd1bad2b8..5a7b135fd92cd 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h @@ -520,7 +520,7 @@ void RemoveCachesByType(const std::string& root, std::string file_extension) { * compiled kernels, so the name must be unique and deterministic across models and sessions. * */ -HashValue TRTGenerateId(const GraphViewer& graph_viewer) { +HashValue TRTGenerateId(const GraphViewer& graph_viewer, std::string trt_version, std::string cuda_version) { HashValue model_hash = 0; // find the top level graph @@ -583,12 +583,11 @@ HashValue TRTGenerateId(const GraphViewer& graph_viewer) { #endif #ifdef CUDA_VERSION - hash_str(std::to_string(CUDA_VERSION)); + hash_str(cuda_version); #endif #if defined(NV_TENSORRT_MAJOR) && defined(NV_TENSORRT_MINOR) - std::string TRT_VERSION = std::to_string(NV_TENSORRT_MAJOR) + "." + std::to_string(NV_TENSORRT_MINOR); - hash_str(TRT_VERSION); + hash_str(trt_version); #endif model_hash = hash[0] | (uint64_t(hash[1]) << 32); diff --git a/onnxruntime/core/providers/tvm/custom_logging.cc b/onnxruntime/core/providers/tvm/custom_logging.cc deleted file mode 100644 index 1cabe81f8e87e..0000000000000 --- a/onnxruntime/core/providers/tvm/custom_logging.cc +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. -// -// Enable custom logging - this will cause TVM to use a custom implementation -// of tvm::runtime::detail::LogMessage. We use this to change the absolute -// file path to relative file path. - -#include -#include -#include -#include -#include -#include - -// TODO(agladyshev): Make conditional choice of sep for Windows and UNIX -std::string GetFileName(const std::string& file_path, char sep = '/') { - return {std::next(file_path.begin(), file_path.find_last_of(sep) + 1), - file_path.end()}; -} - -std::string GetTimedLogMessage(const std::string& file, int lineno, const std::string& message) { - std::stringstream sstream; - std::string file_name = GetFileName(file); - std::time_t t = std::time(nullptr); - sstream << "[" -#ifdef _WIN32 -// TODO(vvchernov): use #include instead of and localtime_s() approach for WIN32 -#pragma warning(disable : 4996) // _CRT_SECURE_NO_WARNINGS -#endif - << std::put_time(std::localtime(&t), "%H:%M:%S") -#ifdef _WIN32 -#pragma warning(default : 4996) -#endif - << "][TVM] " - << file_name << ":" << lineno << ": " + message; - return sstream.str(); -} - -namespace tvm { -namespace runtime { -namespace detail { -void LogFatalImpl(const std::string& file, int lineno, const std::string& message) { - throw std::runtime_error(GetTimedLogMessage(file, lineno, message)); -} - -void LogMessageImpl(const std::string& file, int lineno, const std::string& message) { - std::cerr << GetTimedLogMessage(file, lineno, message) << std::endl; -} - -} // namespace detail -} // namespace runtime -} // namespace tvm diff --git a/onnxruntime/core/providers/tvm/hash_alg/hasher.cc b/onnxruntime/core/providers/tvm/hash_alg/hasher.cc deleted file mode 100644 index bb62b41c7aa85..0000000000000 --- a/onnxruntime/core/providers/tvm/hash_alg/hasher.cc +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/common/common.h" - -#include "hasher.h" // NOLINT(build/include_subdir) -#include "hasher_impl.h" // NOLINT(build/include_subdir) - -namespace onnxruntime { -namespace tvm { - -Hasher::Hasher(const std::string& hash_type) { - hasher_ = getHasherImpl(hash_type); -} - -std::string Hasher::hash(const char* src, size_t size) const { - return hasher_->hash(src, size); -} - -std::shared_ptr Hasher::getHasherImpl(const std::string& hash_type) { - if (hash_type == "sha256") { - return std::make_shared(); - } else { - ORT_NOT_IMPLEMENTED("Hasher was not implemented for hash type: ", hash_type); - } - return nullptr; -} - -} // namespace tvm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tvm/hash_alg/hasher.h b/onnxruntime/core/providers/tvm/hash_alg/hasher.h deleted file mode 100644 index 7b2f50def2e36..0000000000000 --- a/onnxruntime/core/providers/tvm/hash_alg/hasher.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifndef ONNXRUNTIME_CORE_PROVIDERS_TVM_HASH_ALG_HASHER_H_ -#define ONNXRUNTIME_CORE_PROVIDERS_TVM_HASH_ALG_HASHER_H_ - -#include -#include - -namespace onnxruntime { -namespace tvm { -class HasherImpl; - -class Hasher { - public: - Hasher() = delete; - explicit Hasher(const std::string& hash_type); - virtual ~Hasher() = default; - - std::string hash(const char* src, size_t size) const; - - private: - std::shared_ptr getHasherImpl(const std::string& hash_type); - - private: - std::shared_ptr hasher_; -}; - -} // namespace tvm -} // namespace onnxruntime - -#endif // ONNXRUNTIME_CORE_PROVIDERS_TVM_HASH_ALG_HASHER_H_ diff --git a/onnxruntime/core/providers/tvm/hash_alg/hasher_impl.cc b/onnxruntime/core/providers/tvm/hash_alg/hasher_impl.cc deleted file mode 100644 index 20aef66f3046a..0000000000000 --- a/onnxruntime/core/providers/tvm/hash_alg/hasher_impl.cc +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "hasher_impl.h" // NOLINT(build/include_subdir) - -namespace onnxruntime { -namespace tvm { - -std::string HasherSHA256Impl::hash(const char* src, size_t size) const { - return hexdigest(src, size); -} - -void HasherSHA256Impl::digest(const Ipp8u* src, int size, Ipp8u* dst) { - IppStatus status = ippStsNoErr; - const IppsHashMethod* hashMethod = ippsHashMethod_SHA256(); - status = ippsHashMessage_rmf(src, size, dst, hashMethod); - if (ippStsNoErr != status) { - ORT_THROW("Can't get SHA-256..."); - } -} - -std::string HasherSHA256Impl::digest(const char* src, size_t size) { - const int digest_size_byte = IPP_SHA256_DIGEST_BITSIZE / 8; - auto dst = std::unique_ptr(new char[digest_size_byte]); - digest(reinterpret_cast(src), static_cast(size), reinterpret_cast(dst.get())); - return std::string(dst.get(), digest_size_byte); -} - -std::string HasherSHA256Impl::hexdigest(const char* src, size_t size) { - std::string byte_digest = digest(src, size); - std::stringstream ss; - for (char c : byte_digest) { - ss << std::hex << std::setw(2) << std::setfill('0') << (0xff & c); - } - return ss.str(); -} - -} // namespace tvm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tvm/hash_alg/hasher_impl.h b/onnxruntime/core/providers/tvm/hash_alg/hasher_impl.h deleted file mode 100644 index 6c285dd0c78f3..0000000000000 --- a/onnxruntime/core/providers/tvm/hash_alg/hasher_impl.h +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifndef ONNXRUNTIME_CORE_PROVIDERS_TVM_HASH_ALG_HASHER_IMPL_H_ -#define ONNXRUNTIME_CORE_PROVIDERS_TVM_HASH_ALG_HASHER_IMPL_H_ - -#include -#include -#include -#include -#include - -#include "core/common/common.h" - -namespace onnxruntime { -namespace tvm { - -class HasherImpl { - public: - HasherImpl() = default; - virtual ~HasherImpl() = default; - - virtual std::string hash(const char* src, size_t size) const = 0; -}; - -class HasherSHA256Impl : public HasherImpl { - public: - HasherSHA256Impl() = default; - virtual ~HasherSHA256Impl() = default; - - std::string hash(const char* src, size_t size) const final; - - private: - static void digest(const Ipp8u* src, int size, Ipp8u* dst); - static std::string digest(const char* src, size_t size); - static std::string hexdigest(const char* src, size_t size); -}; - -} // namespace tvm -} // namespace onnxruntime - -#endif // ONNXRUNTIME_CORE_PROVIDERS_TVM_HASH_ALG_HASHER_IMPL_H_ diff --git a/onnxruntime/core/providers/tvm/symbols.txt b/onnxruntime/core/providers/tvm/symbols.txt deleted file mode 100644 index 8d903acd9ea76..0000000000000 --- a/onnxruntime/core/providers/tvm/symbols.txt +++ /dev/null @@ -1 +0,0 @@ -OrtSessionOptionsAppendExecutionProvider_Tvm diff --git a/onnxruntime/core/providers/tvm/tvm_allocator.cc b/onnxruntime/core/providers/tvm/tvm_allocator.cc deleted file mode 100644 index 4b68f6432e8cc..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_allocator.cc +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include - -#include "tvm_allocator.h" -#include "core/framework/session_state.h" -#include "xpu_data_transfer.h" - -namespace onnxruntime { -namespace tvm { - -void* TVMAllocator::Alloc(size_t size) { - void* p = nullptr; - if (size > 0) { - DLDataType dl_type{kDLInt, 8, 1}; - int err = TVMDeviceAllocDataSpace(ctx, size, ::tvm::runtime::kAllocAlignment, dl_type, reinterpret_cast(&p)); - CHECK_EQ(err, 0); - return p; - } - return p; -} - -void TVMAllocator::Free(void* p) { - TVMDeviceFreeDataSpace(ctx, p); -} - -} // namespace tvm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tvm/tvm_allocator.h b/onnxruntime/core/providers/tvm/tvm_allocator.h deleted file mode 100644 index f3ba544b8ac46..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_allocator.h +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifndef TVM_ALLOCATOR -#define TVM_ALLOCATOR - -#include "core/framework/allocator.h" -#include "tvm_common.h" - -namespace onnxruntime { -namespace tvm { - -#define TVM_ALLOC_ALIGN 128 - -class TVMAllocator : public IAllocator { - public: - TVMAllocator() : TVMAllocator(OrtMemoryInfo("TVM", - OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, 0), - 0, - OrtMemTypeDefault)) {} - explicit TVMAllocator(const OrtMemoryInfo& info) - : IAllocator(info) { - switch (info.device.Type()) { - case OrtDevice::CPU: - ctx = {kDLCPU, info.device.Id()}; - break; - case OrtDevice::GPU: - ctx = {kDLVulkan, info.device.Id()}; - break; - default: - ORT_NOT_IMPLEMENTED("Unsupported device"); - break; - } - } - - virtual void* Alloc(size_t size) override; - virtual void Free(void* p) override; - DLDevice ctx; -}; - -} // namespace tvm -} // namespace onnxruntime - -#endif // TVM_ALLOCATOR diff --git a/onnxruntime/core/providers/tvm/tvm_api.cc b/onnxruntime/core/providers/tvm/tvm_api.cc deleted file mode 100644 index e9a7d002e77c8..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_api.cc +++ /dev/null @@ -1,303 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifdef _WIN32 -#include -#else -#include // glob(), globfree() -#endif -#include // memset() -#include -#include -#include - -#include -#include -#include - -#include "core/common/common.h" -#include - -#include "tvm_api.h" - -namespace onnxruntime { -namespace tvm { - -using TvmIntArray = ::tvm::Array<::tvm::Integer>; -using TvmPackedFunc = ::tvm::PackedFunc; -namespace tvm_rt = ::tvm::runtime; -namespace tvm_rt_vm = tvm_rt::vm; - -TvmModule TVMCompile(const TvmEPOptions& options, - const std::string& onnx_txt, - const std::string& model_path, - int opset, - const TVMTensorShapes& input_shapes) { - ::tvm::Array shapes; - for (size_t i = 0; i < input_shapes.size(); ++i) { - TvmIntArray shape; - for (auto& dim : input_shapes[i]) { - shape.push_back(::tvm::Integer(dim)); - } - shapes.push_back(shape); - } - - const TvmPackedFunc* compile = tvm_rt::Registry::Get("tvm_onnx_import_and_compile"); - ORT_ENFORCE(compile != nullptr, "Unable to retrieve 'tvm_onnx_import_and_compile'."); - TvmModule mod = (*compile)(TVMByteArray{onnx_txt.data(), onnx_txt.size()}, - model_path, - options.executor, - options.target, - options.target_host, - options.opt_level, - opset, - options.freeze_weights, - shapes, - options.to_nhwc, - options.tuning_file_path, - options.tuning_type); - ORT_ENFORCE(mod.get() != nullptr, "Compiled TVM Module is nullptr!"); - return mod; -} - -std::vector glob(const std::string& dir, const std::string& extension) { - std::vector filenames; -#ifdef _WIN32 - std::string pattern = dir + "/*." + extension; - WIN32_FIND_DATA fd; - HANDLE hFind = ::FindFirstFile(pattern.c_str(), &fd); - if (hFind != INVALID_HANDLE_VALUE) { - do { - if (!(fd.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY)) { - filenames.push_back( - dir + - ToUTF8String(PathString{k_preferred_path_separator}) + - fd.cFileName); - } - } while (::FindNextFile(hFind, &fd)); - ::FindClose(hFind); - } -#else - glob_t glob_result; - memset(&glob_result, 0, sizeof(glob_result)); - - std::string pattern = dir + "/*." + extension; - int return_value = glob(pattern.c_str(), GLOB_TILDE, NULL, &glob_result); - ORT_ENFORCE(return_value == 0, "No results of glob for pattern: " + pattern); - - for (size_t i = 0; i < glob_result.gl_pathc; ++i) { - filenames.push_back(std::string(glob_result.gl_pathv[i])); - } - globfree(&glob_result); -#endif - return filenames; -} - -std::string filter_lib_paths(const std::vector& lib_paths, const std::string& lib_ext) { - std::string lib_path; - size_t counter = 0; - for (const auto& path : lib_paths) { - if (path.find("libtvm_runtime." + lib_ext) != std::string::npos || - path.find("liboctomized_model." + lib_ext) != std::string::npos) { - ++counter; - } else { - lib_path = path; - } - } - ORT_ENFORCE((lib_paths.size() - counter) == 1, "It should be only one shared library for model after filtering"); - - return lib_path; -} - -static std::unordered_map str2dev_type = { - {"llvm", 1}, - {"stackvm", 1}, - {"cpu", 1}, - {"c", 1}, - {"hybrid", 1}, - {"composite", 1}, - {"cuda", 2}, - {"nvptx", 2}, - {"cl", 4}, - {"opencl", 4}, - {"sdaccel", 4}, - {"aocl", 5}, - {"aocl_sw_emu", 5}, - {"vulkan", 7}, - {"metal", 8}, - {"vpi", 9}, - {"rocm", 10}, - {"ext_dev", 12}, - {"hexagon", 14}, - {"webgpu", 15}}; - -TvmModule TVMSoCompile(const TvmEPOptions& options) { - const std::string& dir = options.so_folder; -#ifdef _WIN32 - std::string lib_ext = "dll"; -#else - std::string lib_ext = "so"; -#endif - const std::string lib_path = filter_lib_paths(glob(dir, lib_ext), lib_ext); - const std::string consts_path = dir + - ToUTF8String(PathString{k_preferred_path_separator}) + - "consts"; - const auto& ro_paths = glob(dir, "ro"); - ORT_ENFORCE(ro_paths.size() == 1, "It should be only one ro file in folder: " + dir); - const std::string vm_exec_code_path = ro_paths[0]; - - TvmModule lib = TvmModule::LoadFromFile(lib_path); - - std::ifstream code(vm_exec_code_path, std::ios::binary); - std::stringstream ss; - ss << code.rdbuf(); - - auto exec_mod = tvm_rt_vm::Executable::Load(ss.str(), lib); - const tvm_rt_vm::Executable* tmp = exec_mod.as(); - auto exec = tvm_rt::GetObjectPtr(const_cast(tmp)); - exec->LoadLateBoundConstantsFromFile(consts_path); - - auto vm = tvm_rt::make_object(); - vm->LoadExecutable(exec); - - size_t pos = options.target.find(" "); - const std::string dev_type_str = options.target.substr(0, pos); - ORT_ENFORCE(!dev_type_str.empty(), "Device was not found in target string"); - uint64_t dev_type = str2dev_type[dev_type_str]; - const uint64_t cpu_type = str2dev_type["cpu"]; - // Initialize the VM for the specified device. If the device is not a CPU, - // We'll need to add a CPU context to drive it. - int arity; - if (dev_type == cpu_type) { - arity = 3; - } else { - arity = 6; - } - uint64_t alloc_type = uint64_t(tvm_rt_vm::AllocatorType::kPooled); - // TODO(vchernov): multiple devices using and using device with specified id are not supported - // Always use the first device of the specified type. - uint64_t device_id = 0; - std::vector init_vals(arity); - std::vector codes(arity); - tvm_rt::TVMArgsSetter setter(init_vals.data(), codes.data()); - setter(0, dev_type); - setter(1, device_id); - setter(2, alloc_type); - // Also initialize a CPU device context. - if (dev_type != cpu_type) { - setter(3, cpu_type); - setter(4, device_id); - setter(5, alloc_type); - } - tvm_rt::TVMRetValue rv; - // Call the packed func with the init arguments. - vm->GetFunction("init", nullptr).CallPacked(tvm_rt::TVMArgs(init_vals.data(), codes.data(), arity), &rv); - - return TvmModule(vm); -} - -void TVMSetInputs(TvmModule& mod, - std::vector& inds, - std::vector& inputs) { - TvmPackedFunc set_input = mod.GetFunction("set_input", false); - TvmPackedFunc set_input_zero_copy = mod.GetFunction("set_input_zero_copy", false); - for (size_t i = 0; i < inds.size(); ++i) { - if (reinterpret_cast(inputs[i].data) % tvm_rt::kAllocAlignment == 0) { - set_input_zero_copy(inds[i], &inputs[i]); - } else { - set_input(inds[i], &inputs[i]); - } - } -} - -void TVM_VM_SetInputs(TvmModule& mod, - std::vector& inds, - std::vector& inputs) { - size_t num_total_args = inputs.size() + 1; - std::vector tvm_values(num_total_args); - std::vector tvm_type_codes(num_total_args); - ::tvm::runtime::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data()); - const std::string func_name = "main"; - setter(0, func_name.c_str()); - for (size_t k = 0; k < num_total_args - 1; ++k) { - setter(inds[k] + 1, &inputs[k]); - } - - TvmPackedFunc set_input = mod.GetFunction("set_input", false); - ::tvm::runtime::TVMRetValue rv; - set_input.CallPacked(::tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), gsl::narrow_cast(num_total_args)), &rv); -} - -void TVMSetOutputsZeroCopy(TvmModule& mod, - std::vector& outputs) { - TvmPackedFunc set_output = mod.GetFunction("set_output_zero_copy", false); - for (size_t i = 0; i < outputs.size(); ++i) { - set_output(i, &outputs[i]); - } -} - -void TVM_VM_SetOutputsZeroCopy(TvmModule& mod, - std::vector& outputs) { - size_t num_total_args = outputs.size() + 1; - std::vector tvm_values(num_total_args); - std::vector tvm_type_codes(num_total_args); - tvm_rt::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data()); - const std::string func_name = "main"; - setter(0, func_name.c_str()); - for (size_t k = 0; k < num_total_args - 1; ++k) { - setter(k + 1, &outputs[k]); - } - - TvmPackedFunc set_output = mod.GetFunction("set_outputs", false); - tvm_rt::TVMRetValue rv; - set_output.CallPacked(tvm_rt::TVMArgs(tvm_values.data(), tvm_type_codes.data(), gsl::narrow_cast(num_total_args)), &rv); -} - -void TVMGetOutputs(TvmModule& mod, - std::vector& outputs) { - TvmPackedFunc get_output = mod.GetFunction("get_output", false); - for (size_t i = 0; i < outputs.size(); ++i) { - get_output(i, &outputs[i]); - } -} - -void TVM_VM_GetOutputs(TvmModule& mod, - std::vector& outputs) { - TvmPackedFunc get_output = mod.GetFunction("get_output", false); - for (size_t i = 0; i < outputs.size(); ++i) { - // TODO(vvchernov): think about improvement of memory management - tvm_rt::NDArray output_array = get_output(i); - output_array.CopyTo(&outputs[i]); - } -} - -void TVMGetOutputShapes(TvmModule& mod, - TVMTensorShapes& output_shapes) { - size_t size = output_shapes.size(); - TvmPackedFunc get_output = mod.GetFunction("get_output", false); - for (size_t i = 0; i < size; ++i) { - tvm_rt::NDArray output_array = get_output(i); - tvm_rt::ShapeTuple shape_tuple = output_array.Shape(); - size_t dims_num = shape_tuple.size(); - TensorShapeVector dims; - for (size_t j = 0; j < dims_num; ++j) { - dims.push_back(int64_t(shape_tuple[j])); - } - output_shapes[i] = dims; - } -} - -void TVMRun(TvmModule& mod) { - TvmPackedFunc run = mod.GetFunction("run", false); - ORT_ENFORCE(run != nullptr, "Unable to retrieve graph executor run."); - run(); -} - -void TVM_VM_Run(TvmModule& mod) { - TvmPackedFunc run = mod.GetFunction("invoke", false); - ORT_ENFORCE(run != nullptr, "Unable to retrieve virtual machine invoke."); - run("main"); -} - -} // namespace tvm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tvm/tvm_api.h b/onnxruntime/core/providers/tvm/tvm_api.h deleted file mode 100644 index bbf05f4fc06d9..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_api.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifndef TVM_API_H -#define TVM_API_H - -#include -#include - -#include "tvm_common.h" -#include "tvm_defaults.h" -#include "tvm_ep_options.h" - -namespace onnxruntime { -namespace tvm { - -TvmModule TVMCompile(const TvmEPOptions& options, - const std::string& onnx_txt, - const std::string& model_path, - int opset, - const TVMTensorShapes& input_shapes); -TvmModule TVMSoCompile(const TvmEPOptions& options); - -void TVMSetInputs(TvmModule& mod, std::vector& inds, std::vector& inputs); -void TVM_VM_SetInputs(TvmModule& mod, std::vector& inds, std::vector& inputs); -void TVMSetOutputsZeroCopy(TvmModule& mod, std::vector& outputs); -void TVM_VM_SetOutputsZeroCopy(TvmModule& mod, std::vector& outputs); -void TVMGetOutputs(TvmModule& mod, std::vector& outputs); -void TVM_VM_GetOutputs(TvmModule& mod, std::vector& outputs); -void TVMGetOutputShapes(TvmModule& mod, - TVMTensorShapes& output_shapes); -void TVMRun(TvmModule& mod); -void TVM_VM_Run(TvmModule& mod); - -} // namespace tvm -} // namespace onnxruntime - -#endif // TVM_API_H diff --git a/onnxruntime/core/providers/tvm/tvm_common.h b/onnxruntime/core/providers/tvm/tvm_common.h deleted file mode 100644 index 68e3b6496328a..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_common.h +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifndef TVM_COMMON_H -#define TVM_COMMON_H - -#include -#include - -#include -#include -#include - -namespace onnxruntime { -namespace tvm { - -using TvmModule = ::tvm::runtime::Module; - -} // namespace tvm -} // namespace onnxruntime - -#endif // TVM_COMMON_H diff --git a/onnxruntime/core/providers/tvm/tvm_compiler.cc b/onnxruntime/core/providers/tvm/tvm_compiler.cc deleted file mode 100644 index 8f4e7e7de9a36..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_compiler.cc +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include - -#include "tvm_compiler.h" -#include "tvm_api.h" - -namespace onnxruntime { -namespace tvm { - -auto TVMCompilerBase::operator()(const TvmEPOptions& options, - const TVMTensorShapes& input_shapes) -> ModulePtr { - if (mod_) { - return mod_; - } - - mod_ = std::make_shared(); - this->compileTVMModule(options, input_shapes); - - return mod_; -} - -TVMCompiler::TVMCompiler(std::string&& onnx_model_str, - const std::string& model_path, - int opset) : onnx_model_str_(std::move(onnx_model_str)), - model_path_(model_path), - opset_(opset) { -} - -void TVMCompiler::compileTVMModule(const TvmEPOptions& options, - const TVMTensorShapes& input_shapes) { - *mod_ = tvm::TVMCompile(options, - onnx_model_str_, - model_path_, - opset_, - input_shapes); - - onnx_model_str_.clear(); -} - -void TVMSoCompiler::compileTVMModule(const TvmEPOptions& options, - [[maybe_unused]] const TVMTensorShapes& input_shapes) { - *mod_ = tvm::TVMSoCompile(options); -} - -} // namespace tvm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tvm/tvm_compiler.h b/onnxruntime/core/providers/tvm/tvm_compiler.h deleted file mode 100644 index bfc73d67aa07f..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_compiler.h +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifndef TVM_COMPILER_H -#define TVM_COMPILER_H - -#include -#include - -#include "tvm_common.h" -#include "tvm_ep_options.h" - -namespace onnxruntime { -namespace tvm { - -class TVMCompilerBase { - public: - using ModulePtr = std::shared_ptr; - - TVMCompilerBase() = default; - virtual ~TVMCompilerBase() = default; - - ModulePtr operator()(const TvmEPOptions& options, - const TVMTensorShapes& input_shapes); - - virtual void compileTVMModule(const TvmEPOptions& options, - const TVMTensorShapes& input_shapes) = 0; - - protected: - ModulePtr mod_; -}; - -class TVMCompiler : public TVMCompilerBase { - public: - TVMCompiler() = delete; - ~TVMCompiler() = default; - - TVMCompiler(std::string&& onnx_model_str, - const std::string& model_path, - int opset); - - void compileTVMModule(const TvmEPOptions& options, - const TVMTensorShapes& input_shapes) final; - - private: - std::string onnx_model_str_; - std::string model_path_; - int opset_; -}; - -class TVMSoCompiler : public TVMCompilerBase { - public: - TVMSoCompiler() = default; - ~TVMSoCompiler() = default; - - void compileTVMModule(const TvmEPOptions& options, - const TVMTensorShapes& input_shapes) final; -}; - -} // namespace tvm -} // namespace onnxruntime - -#endif // TVM_COMPILER_H diff --git a/onnxruntime/core/providers/tvm/tvm_defaults.h b/onnxruntime/core/providers/tvm/tvm_defaults.h deleted file mode 100644 index 197d1f363c50d..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_defaults.h +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifndef ONNXRUNTIME_CORE_PROVIDERS_TVM_TVM_DEFAULTS_H_ -#define ONNXRUNTIME_CORE_PROVIDERS_TVM_TVM_DEFAULTS_H_ - -#include - -namespace onnxruntime { -namespace tvm { - -namespace env_vars { -static const std::string kDumpSubgraphs = "ORT_TVM_DUMP_SUBGRAPHS"; -} // namespace env_vars - -constexpr const char* default_executor_type = "vm"; -constexpr const char* vm_executor_type = "vm"; -constexpr const char* graph_executor_type = "graph"; - -constexpr const char* default_target_str = "llvm"; -constexpr const char* llvm_target_str = "llvm"; - -constexpr const char* cpu_target_str = "cpu"; -constexpr const char* gpu_target_str = "gpu"; - -constexpr const char* default_tuning_type = "AutoTVM"; -constexpr const char* autotvm_tuning_type = "AutoTVM"; -constexpr const char* ansor_tuning_type = "Ansor"; - -constexpr const unsigned int default_opt_level = 3; - -} // namespace tvm -} // namespace onnxruntime - -#endif // ONNXRUNTIME_CORE_PROVIDERS_TVM_TVM_DEFAULTS_H_ diff --git a/onnxruntime/core/providers/tvm/tvm_ep_options.cc b/onnxruntime/core/providers/tvm/tvm_ep_options.cc deleted file mode 100644 index 70e99833cd78b..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_ep_options.cc +++ /dev/null @@ -1,273 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include -#include - -#include "core/common/common.h" -#include "core/common/cpuid_info.h" -#include "core/framework/provider_options_utils.h" - -#include "tvm_ep_options.h" - -namespace onnxruntime { -namespace tvm { - -namespace provider_option_names { -constexpr const char* kExecutor = "executor"; -constexpr const char* kSoFolder = "so_folder"; -constexpr const char* kCheckHash = "check_hash"; -constexpr const char* kHashFilePath = "hash_file_path"; -constexpr const char* kTarget = "target"; -constexpr const char* kTargetHost = "target_host"; -constexpr const char* kOptLevel = "opt_level"; -constexpr const char* kFreezeWeights = "freeze_weights"; -constexpr const char* kSetOutputZeroCopy = "set_output_zero_copy"; -constexpr const char* kToNHWC = "to_nhwc"; -constexpr const char* kTuningFilePath = "tuning_file_path"; -constexpr const char* kTuningType = "tuning_type"; -constexpr const char* kInputNames = "input_names"; -constexpr const char* kInputShapes = "input_shapes"; - -static const std::unordered_set valid_keys{ - std::string{kExecutor}, - std::string{kSoFolder}, - std::string{kCheckHash}, - std::string{kHashFilePath}, - std::string{kTarget}, - std::string{kTargetHost}, - std::string{kOptLevel}, - std::string{kFreezeWeights}, - std::string{kSetOutputZeroCopy}, - std::string{kToNHWC}, - std::string{kTuningFilePath}, - std::string{kTuningType}, - std::string{kInputNames}, - std::string{kInputShapes}}; - -} // namespace provider_option_names - -size_t split(const std::string& src, std::vector& dst, char ch) { - dst.clear(); - - size_t pos = src.find(ch); - size_t initialPos = 0; - while (pos != std::string::npos) { - dst.push_back(src.substr(initialPos, pos - initialPos)); - initialPos = pos + 1; - - pos = src.find(ch, initialPos); - } - dst.push_back(src.substr(initialPos, std::min(pos, src.size()) - initialPos + 1)); - - return dst.size(); -} - -TvmEPOptions TvmEPOptionsHelper::FromOptionsString(const char* opt_str) { - std::string settings{opt_str}; - ProviderOptions options; - if (!settings.empty()) { - const std::string& str = settings; - - // tokenize settings - std::regex reg("\\s*,\\s*"); - std::sregex_token_iterator iter(str.begin(), str.end(), reg, -1); - std::sregex_token_iterator iter_end; - std::vector pairs(iter, iter_end); - - ORT_ENFORCE(pairs.size() > 0); - - for (const auto& pair : pairs) { - auto pos_colon = pair.find(':'); - ORT_ENFORCE(pos_colon != std::string::npos, "Invalid key value pair."); - std::string key = pair.substr(0, pos_colon); - std::string value = pair.substr(pos_colon + 1); - - // trim leading and trailing spaces from key/value - key = whitespace_trimming(key); - value = whitespace_trimming(value); - - // Check keys of obtained options - if (tvm::provider_option_names::valid_keys.count(key) == 0) { - ORT_NOT_IMPLEMENTED("TvmOptions: unknown option (", key, ")"); - } - - options[key] = value; - } - } - - return TvmEPOptionsHelper::FromProviderOptions(options); -} - -std::string TvmEPOptionsHelper::whitespace_trimming(const std::string& str) { - const std::string WHITESPACE = " \n\r\t\f\v"; - size_t start = str.find_first_not_of(WHITESPACE); - if (start == std::string::npos) { - return ""; - } else { - size_t end = str.find_last_not_of(WHITESPACE); - ORT_ENFORCE(end != std::string::npos); - return str.substr(start, end + 1); - } -} - -TvmEPOptions TvmEPOptionsHelper::FromProviderOptions(const ProviderOptions& pr_options) { - TvmEPOptions options{}; - - ORT_THROW_IF_ERROR( - ProviderOptionsParser{} - .AddAssignmentToReference(tvm::provider_option_names::kExecutor, options.executor) - .AddAssignmentToReference(tvm::provider_option_names::kSoFolder, options.so_folder) - .AddAssignmentToReference(tvm::provider_option_names::kCheckHash, options.check_hash) - .AddAssignmentToReference(tvm::provider_option_names::kHashFilePath, options.hash_file_path) - .AddAssignmentToReference(tvm::provider_option_names::kTarget, options.target) - .AddAssignmentToReference(tvm::provider_option_names::kTargetHost, options.target_host) - .AddAssignmentToReference(tvm::provider_option_names::kOptLevel, options.opt_level) - .AddAssignmentToReference(tvm::provider_option_names::kFreezeWeights, options.freeze_weights) - .AddAssignmentToReference(tvm::provider_option_names::kSetOutputZeroCopy, options.set_output_zero_copy) - .AddAssignmentToReference(tvm::provider_option_names::kToNHWC, options.to_nhwc) - .AddAssignmentToReference(tvm::provider_option_names::kTuningFilePath, options.tuning_file_path) - .AddAssignmentToReference(tvm::provider_option_names::kTuningType, options.tuning_type) - .AddAssignmentToReference(tvm::provider_option_names::kInputNames, options.input_names_str) - .AddAssignmentToReference(tvm::provider_option_names::kInputShapes, options.input_shapes_str) - .Parse(pr_options)); - - optionsPostprocess(options); - - return options; -} - -void TvmEPOptionsHelper::optionsPostprocess(TvmEPOptions& options) { - setInputShapes(options); - targetPostprocess(options.target); - targetHostPostprocess(options.target, options.target_host); - optLevelPostprocess(options.opt_level); -} - -bool TvmEPOptionsHelper::checkCPUTarget(const std::string& target) { - bool check = target.find("llvm") != std::string::npos; - return check; -} - -bool TvmEPOptionsHelper::checkGPUTarget(const std::string& target) { - bool check = (target.find("cuda") != std::string::npos || - target.find("opencl") != std::string::npos || - target.find("metal") != std::string::npos || - target.find("vulkan") != std::string::npos); - return check; -} - -void TvmEPOptionsHelper::setInputShapes(TvmEPOptions& options) { - if (options.input_names_str.empty() && options.input_shapes_str.empty()) - return; - ORT_ENFORCE(!options.input_names_str.empty() && !options.input_shapes_str.empty(), - "Both provider options \"input_names\" and \"input_shapes\" should be empty or full"); - - std::vector name_set; - std::string trimmed_names = whitespace_trimming(options.input_names_str); - size_t inp_tensors_num = split(trimmed_names, name_set, ' '); - ORT_ENFORCE(inp_tensors_num, "There is no any input tensor names!"); - - std::string trimmed_shapes = whitespace_trimming(options.input_shapes_str); - size_t end_pos = trimmed_shapes.find_last_of(']'); - ORT_ENFORCE(end_pos != std::string::npos, "Invalid string for input shapes. Symbol ] is not found"); - ORT_ENFORCE(end_pos == (trimmed_shapes.size() - 1), - "Invalid string for input shapes. Symbol ] should be last after whitespace trimming"); - - std::vector shape_set; - split(trimmed_shapes, shape_set, ']'); - shape_set.pop_back(); - ORT_ENFORCE(shape_set.size() == inp_tensors_num, - "Number of shapes is not the same as number of input tensor names"); - - for (size_t i = 0; i < inp_tensors_num; ++i) { - size_t pos = shape_set[i].find('['); - ORT_ENFORCE(pos != std::string::npos, "There is no symbol [ as pair for ]"); - std::string numbers = shape_set[i].substr(pos + 1); - std::vector number_set; - ORT_ENFORCE(split(numbers, number_set, ' '), "There is no any number between [ and ] symbols"); - - TensorShapeVector dims; - for (const auto& number : number_set) { - dims.push_back(std::stoi(number)); - } - - options.input_shapes[name_set[i]] = dims; - } -} - -void TvmEPOptionsHelper::targetPostprocess(std::string& target) { - if (target == tvm::cpu_target_str || - target == tvm::llvm_target_str) { - ProcessCPUTarget(target); - } else if (target == tvm::gpu_target_str) { - ProcessGPUTarget(); - } else if (target.empty()) { - ORT_NOT_IMPLEMENTED("target option is empty!"); - } else { - // TODO(vvchernov): extend mechanism of auto-definition of target - // target is gotten from option set up by client - } -} - -void TvmEPOptionsHelper::ProcessCPUTarget(std::string& target) { - const auto& cpu_id_info = CPUIDInfo::GetCPUIDInfo(); - // auto detect from CPU ID - if (cpu_id_info.HasAVX512Skylake()) { - target = tvm::cpu_targets::LLVM_TARGET_SKYLAKE_AVX512; - } else if (cpu_id_info.HasAVX512f()) { - target = tvm::cpu_targets::LLVM_TARGET_AVX512; - } else if (cpu_id_info.HasAVX2()) { - target = tvm::cpu_targets::LLVM_TARGET_AVX2; - } else if (cpu_id_info.HasAVX()) { - target = tvm::cpu_targets::LLVM_TARGET_AVX; - } else { - // TODO(vvchernov): extend mechanism of auto-definition of cpu target - target = tvm::llvm_target_str; - } -} - -void TvmEPOptionsHelper::ProcessGPUTarget() { - ORT_NOT_IMPLEMENTED("GPU target auto-defenition is not implemented now!"); -} - -void TvmEPOptionsHelper::targetHostPostprocess(const std::string& target, std::string& target_host) { - if ((target_host == tvm::cpu_target_str || - target_host == tvm::llvm_target_str) && - target_host != target) { - target_host = target; - } else if (target_host.empty()) { - target_host = target; - } else { - // TODO(vvchernov): extend mechanism of auto-definition of target host - // target host is gotten from option set up by client - } -} - -void TvmEPOptionsHelper::optLevelPostprocess(unsigned int& opt_level) { - if (opt_level < 1) { - opt_level = tvm::default_opt_level; - } -} - -std::ostream& operator<<(std::ostream& out, const TvmEPOptions& options) { - out << "TVM EP options:\n" - << "executor type: " << options.executor << "\n" - << "so_folder: " << options.so_folder << "\n" - << "check_hash: " << options.check_hash << "\n" - << "hash_file_path: " << options.hash_file_path << "\n" - << "target: " << options.target << "\n" - << "target_host: " << options.target_host << "\n" - << "opt level: " << options.opt_level << "\n" - << "freeze weights: " << options.freeze_weights << "\n" - << "set_output_zero_copy: " << options.set_output_zero_copy << "\n" - << "tuning file path: " << options.tuning_file_path << "\n" - << "tuning type: " << options.tuning_type << "\n" - << "convert layout to NHWC: " << options.to_nhwc << "\n" - << "input tensor names: " << options.input_names_str << "\n" - << "input tensor shapes: " << options.input_shapes_str; - return out; -} - -} // namespace tvm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tvm/tvm_ep_options.h b/onnxruntime/core/providers/tvm/tvm_ep_options.h deleted file mode 100644 index 0f2db30a3b304..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_ep_options.h +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifndef TVM_EXECUTION_PROVIDER_OPTIONS_H -#define TVM_EXECUTION_PROVIDER_OPTIONS_H - -#include -#include -#include -#include - -#include "core/framework/provider_options.h" -#include "core/framework/tensor_shape.h" - -#include "tvm_defaults.h" - -namespace onnxruntime { - -namespace tvm { -namespace cpu_targets { -// TODO(vvchernov): avx and avx512 need more careful differentiation for target -const std::string LLVM_TARGET_AVX = "llvm -mcpu=corei7-avx"; -const std::string LLVM_TARGET_AVX2 = "llvm -mcpu=core-avx2"; -const std::string LLVM_TARGET_SKYLAKE_AVX512 = "llvm -mcpu=skylake-avx512"; -const std::string LLVM_TARGET_AVX512 = "llvm -mcpu=skylake-avx512"; -} // namespace cpu_targets - -using TVMTensorShapes = std::vector; -using TVMInputShapes = std::unordered_map; -using InputsInfoMap = std::unordered_map; - -// Information needed to construct an TVM execution provider. -struct TvmEPOptions { - std::string executor{tvm::default_executor_type}; - std::string so_folder{""}; - bool check_hash = false; - std::string hash_file_path{""}; - std::string target{tvm::default_target_str}; - std::string target_host{tvm::default_target_str}; - unsigned int opt_level{tvm::default_opt_level}; - bool freeze_weights = true; - bool to_nhwc = false; - bool set_output_zero_copy = true; - std::string tuning_file_path{""}; - std::string tuning_type{tvm::default_tuning_type}; - std::string input_names_str{""}; - std::string input_shapes_str{""}; - TVMInputShapes input_shapes{}; - TVMTensorShapes output_shapes{}; -}; - -std::ostream& operator<<(std::ostream& out, const TvmEPOptions& options); - -class TvmEPOptionsHelper { - public: - static TvmEPOptions FromOptionsString(const char* options); - static TvmEPOptions FromProviderOptions(const ProviderOptions& options); - static std::string whitespace_trimming(const std::string& str); - - static bool checkCPUTarget(const std::string& target); - static bool checkGPUTarget(const std::string& target); - - private: - static void optionsPostprocess(TvmEPOptions& options); - static void setInputShapes(TvmEPOptions& options); - static void targetPostprocess(std::string& target); - static void ProcessCPUTarget(std::string& target); - static void ProcessGPUTarget(); - static void targetHostPostprocess(const std::string& target, std::string& target_host); - static void optLevelPostprocess(unsigned int& opt_level); -}; - -} // namespace tvm -} // namespace onnxruntime - -#endif // TVM_EXECUTION_PROVIDER_OPTIONS_H diff --git a/onnxruntime/core/providers/tvm/tvm_execution_provider.cc b/onnxruntime/core/providers/tvm/tvm_execution_provider.cc deleted file mode 100644 index 61ee8f899dbf1..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_execution_provider.cc +++ /dev/null @@ -1,304 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include -#include -#include - -#include "core/common/common.h" -#include "core/framework/execution_provider.h" -#include "core/framework/tensorprotoutils.h" -#include "core/framework/kernel_registry.h" -#include "core/framework/compute_capability.h" -#include "core/graph/graph_proto_serializer.h" -#include "core/platform/env.h" -#include "core/graph/model.h" - -#include "tvm_execution_provider.h" -#include "xpu_data_transfer.h" -#include "tvm_allocator.h" -#include "tvm_utils.h" -#include "tvm_api.h" - -using namespace ONNX_NAMESPACE; - -namespace onnxruntime { -namespace tvm { - -// Information to construct kernel function state. -struct TVMFuncState { - AllocateFunc allocate_func = nullptr; - DestroyFunc release_func = nullptr; - AllocatorHandle allocator = nullptr; - std::shared_ptr compiler = nullptr; -}; - -TvmExecutionProvider::TvmExecutionProvider(const TvmEPOptions& options) - : IExecutionProvider{kTvmExecutionProvider}, - options_{options} { - AllocatorCreationInfo default_memory_info = {[](int) { - return std::make_unique(); - }, - 0, false}; - // Get environment variables - const Env& env_instance = Env::Default(); - - const std::string dump_subgraphs_env = env_instance.GetEnvironmentVar(env_vars::kDumpSubgraphs); - if (!dump_subgraphs_env.empty()) { - dump_subgraphs_ = std::stoi(dump_subgraphs_env) != 0; - } -} - -std::vector TvmExecutionProvider::CreatePreferredAllocators() { - AllocatorCreationInfo default_memory_info = {[](int) { - return std::make_unique(); - }, - 0, false}; - return std::vector{CreateAllocator(default_memory_info)}; // TODO(leca): REVIEW: will CPU EP also use this? -} - -TvmExecutionProvider::~TvmExecutionProvider() {} - -std::vector> -TvmExecutionProvider::GetCapability(const GraphViewer& graph_viewer, - const IKernelLookup& /*kernel_lookup*/) const { - std::vector> result; - if (graph_viewer.IsSubgraph()) { - return result; - } - - const auto& init_tensors = graph_viewer.GetAllInitializedTensors(); - - std::unordered_set required_initializers; - const std::vector& sorted_nodes = graph_viewer.GetNodesInTopologicalOrder(); - std::unique_ptr sub_graph = std::make_unique(); - for (auto& node_idx : sorted_nodes) { - graph_viewer.GetNode(node_idx)->ForEachDef([&required_initializers, &init_tensors](const NodeArg& node_arg, bool is_input) { - if(is_input && init_tensors.count(node_arg.Name())) { - required_initializers.insert(node_arg.Name()); - } }, true); - } - - auto meta_def = std::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>(); - meta_def->name = "TVMStandalone"; - meta_def->domain = "StandaloneTest"; - std::vector inputs; - std::vector outputs; - - for (auto& nodeArgPtr : graph_viewer.GetInputs()) { - inputs.push_back(nodeArgPtr->Name()); - } - - for (auto& name : required_initializers) { - inputs.push_back(name); - } - - for (auto& nodeArgPtr : graph_viewer.GetOutputs()) { - outputs.push_back(nodeArgPtr->Name()); - } - meta_def->inputs = inputs; - meta_def->outputs = outputs; - meta_def->since_version = 1; - meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; - sub_graph->SetMetaDef(std::move(meta_def)); - sub_graph->nodes = sorted_nodes; - result.push_back( - std::make_unique(std::move(sub_graph))); - return result; -} - -common::Status TvmExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, - std::vector& node_compute_funcs) { - printOptions(); - for (auto& fused_node_graph : fused_nodes_and_graphs) { - const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph; - const Node& fused_node = fused_node_graph.fused_node; - const std::string func_name = fused_node.Name(); - Model model(graph_body_viewer.Name(), true, ModelMetaData(), PathString(), - IOnnxRuntimeOpSchemaRegistryList(), graph_body_viewer.DomainToVersionMap(), - std::vector(), *GetLogger()); - ONNX_NAMESPACE::ModelProto model_proto = model.ToProto(); - // TVM EP is using static lib approach, so invoke serializer directly. - GraphViewerToProto(graph_body_viewer, *model_proto.mutable_graph(), true, true); - auto opset = model_proto.add_opset_import(); - opset->set_domain(kOnnxDomain); - opset->set_version(graph_body_viewer.DomainToVersionMap().at(kOnnxDomain)); - - std::string onnx_model_str; - model_proto.SerializeToString(&onnx_model_str); - compilers_[func_name] = std::make_shared(std::move(onnx_model_str), - ToUTF8String(fused_node.ModelPath().ToPathString()), - int(opset->version())); - InputsInfoMap all_input_shapes; - auto mod = compileModel(func_name, graph_body_viewer, all_input_shapes); - - std::vector output_tensors; - prepareOutputTensors(mod, output_tensors, graph_body_viewer.GetOutputs().size()); - - runners_[func_name] = std::make_shared(options_, mod, all_input_shapes, output_tensors); - - if (dump_subgraphs_) { - std::fstream dump("/tmp/" + func_name + ".onnx", - std::ios::out | std::ios::trunc | std::ios::binary); - model_proto.SerializeToOstream(&dump); - } - - // TODO(vvchernov): implement ops checking and mechanism of gracefully passing the responsibility to other EPs - // if the checking fails due to unsupported op(s) - NodeComputeInfo compute_info = prepareComputeInfo(func_name); - - node_compute_funcs.push_back(compute_info); - } - return Status::OK(); -} - -std::unique_ptr TvmExecutionProvider::GetDataTransfer() const { - // TODO(vvchernov): target or target host? - if (TvmEPOptionsHelper::checkGPUTarget(options_.target)) { - return std::make_unique(); - } else if (TvmEPOptionsHelper::checkCPUTarget(options_.target)) { - return std::make_unique(); - } else { - ORT_NOT_IMPLEMENTED("TVM GetDataTransfer is not implemented for target ", options_.target); - } -} - -void TvmExecutionProvider::printOptions() { - LOGS(*GetLogger(), INFO) << options_; -} - -std::shared_ptr TvmExecutionProvider::compileModel(const std::string& func_name, - const GraphViewer& graph_viewer, - InputsInfoMap& all_input_shapes) { - all_input_shapes.clear(); - - TVMTensorShapes input_shapes; - if (options_.freeze_weights) { - setInputShapesForFreezedNN(graph_viewer, input_shapes, all_input_shapes); - } else { - setInputShapesForUnfreezedNN(graph_viewer, input_shapes, all_input_shapes); - } - - std::shared_ptr mod = compilers_[func_name]->operator()(options_, input_shapes); - - return mod; -} - -void TvmExecutionProvider::setInputShapesForFreezedNN(const GraphViewer& graph_viewer, - TVMTensorShapes& input_shapes, - InputsInfoMap& all_input_shapes) { - const std::vector& all_nodes = graph_viewer.GetInputsIncludingInitializers(); - - size_t indx = 0; - for (const auto* node : all_nodes) { - if (!graph_viewer.IsInitializedTensor(node->Name())) { - TensorShapeVector shape = getInputShape(node); - all_input_shapes[indx++] = shape; - input_shapes.emplace_back(shape); - } - } -} - -void TvmExecutionProvider::setInputShapesForUnfreezedNN(const GraphViewer& graph_viewer, - TVMTensorShapes& input_shapes, - InputsInfoMap& all_input_shapes) { - const std::vector& all_nodes = graph_viewer.GetInputsIncludingInitializers(); - - size_t indx = 0; - for (const auto* node : all_nodes) { - TensorShapeVector shape = getInputShape(node); - all_input_shapes[indx++] = shape; - if (!graph_viewer.IsInitializedTensor(node->Name())) { - input_shapes.emplace_back(shape); - } - } -} - -TensorShapeVector TvmExecutionProvider::getInputShape(const NodeArg* node) { - TensorShapeVector shape; - const auto& node_name = node->Name(); - if (!options_.input_shapes.empty() && - options_.input_shapes.count(node_name)) { - shape = options_.input_shapes[node_name]; - } else { - shape = convertTensorShape(*node->Shape()); - } - - return shape; -} - -TensorShapeVector TvmExecutionProvider::convertTensorShape(const TensorShapeProto& shape_proto) { - TensorShape ort_shape = utils::GetTensorShapeFromTensorShapeProto(shape_proto); - size_t dims = ort_shape.NumDimensions(); - - TensorShapeVector shape(dims); - for (size_t j = 0; j < dims; ++j) { - int64_t dim = int64_t(ort_shape[j]); - ORT_ENFORCE(dim > 0, "Input dimension is not positive value (dim = " + std::to_string(dim) + "). " + - "Please use provider options to setup input_names and input_shapes"); - shape[j] = dim; - } - - return shape; -} - -void TvmExecutionProvider::prepareOutputTensors(const std::shared_ptr& mod, - std::vector& output_tensors, - size_t num) { - ORT_ENFORCE(mod != nullptr, "TVM module is not compiled"); - output_tensors.clear(); - options_.output_shapes.clear(); - options_.output_shapes.resize(num); - - if (options_.executor != "vm") { - TVMGetOutputShapes(*mod, options_.output_shapes); - } - - for (auto& output_shape : options_.output_shapes) { - DLTensor t; - // Draft for tensor, correct data is defined during inference - t.strides = nullptr; - t.byte_offset = 0; - t.data = nullptr; - if (options_.executor == "vm") { - t.ndim = 0; - t.shape = nullptr; - } else { - t.ndim = output_shape.size(); - t.shape = output_shape.data(); - } - - output_tensors.push_back(t); - } -} - -NodeComputeInfo TvmExecutionProvider::prepareComputeInfo(const std::string& func_name) { - NodeComputeInfo compute_info; - compute_info.create_state_func = std::bind(&TvmExecutionProvider::createStateFunc, - this, - std::placeholders::_1, - std::placeholders::_2); - - compute_info.release_state_func = [](FunctionState state) { - if (state) - delete static_cast(state); - }; - - compute_info.compute_func = *runners_[func_name].get(); - - return compute_info; -} - -int TvmExecutionProvider::createStateFunc(ComputeContext* context, FunctionState* state) { - auto* state_ptr = new TVMFuncState(); - *state_ptr = {context->allocate_func, - context->release_func, - context->allocator_handle, - compilers_[context->node_name]}; - // TODO(vvchernov): Who and when release state? - *state = state_ptr; - return 0; -} - -} // namespace tvm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tvm/tvm_execution_provider.h b/onnxruntime/core/providers/tvm/tvm_execution_provider.h deleted file mode 100644 index baa46c593fa07..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_execution_provider.h +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifndef TVM_EXECUTION_PROVIDER_H -#define TVM_EXECUTION_PROVIDER_H - -#include -#include -#include -#include - -#include "core/common/logging/logging.h" -#include "core/framework/execution_provider.h" -#include - -#include "tvm_compiler.h" -#include "tvm_runner.h" - -namespace onnxruntime { -class Graph; -class NodeArg; -namespace tvm { - -class TvmExecutionProvider : public IExecutionProvider { - using Compiler = TVMCompilerBase; - using Compilers = std::unordered_map>; - using Runner = TVMRunner; - using Runners = std::unordered_map>; - - public: - explicit TvmExecutionProvider(const TvmEPOptions& options); - virtual ~TvmExecutionProvider(); - - std::vector> - GetCapability(const onnxruntime::GraphViewer& graph, - const IKernelLookup& /*kernel_lookup*/) const override; - - common::Status Compile(const std::vector& fused_nodes_and_graphs, - std::vector& node_compute_funcs) override; - std::unique_ptr GetDataTransfer() const override; - std::vector CreatePreferredAllocators() override; - - private: - void printOptions(); - std::shared_ptr compileModel(const std::string& func_name, - const GraphViewer& graph_viewer, - InputsInfoMap& inputs_info); // NOLINT - void setInputShapesForFreezedNN(const GraphViewer& graph_viewer, - TVMTensorShapes& input_shapes, // NOLINT - InputsInfoMap& all_input_shapes); // NOLINT - void setInputShapesForUnfreezedNN(const GraphViewer& graph_viewer, - TVMTensorShapes& input_shapes, // NOLINT - InputsInfoMap& all_input_shapes); // NOLINT - TensorShapeVector getInputShape(const NodeArg* node); - TensorShapeVector convertTensorShape(const ONNX_NAMESPACE::TensorShapeProto& shape_proto); - void prepareOutputTensors(const std::shared_ptr& mod, - std::vector& output_tensors, size_t num); // NOLINT - NodeComputeInfo prepareComputeInfo(const std::string& func_name); - int createStateFunc(ComputeContext*, FunctionState*); - - private: - TvmEPOptions options_; - Compilers compilers_; - Runners runners_; - bool dump_subgraphs_ = false; -}; - -} // namespace tvm -} // namespace onnxruntime - -#endif // TVM_EXECUTION_PROVIDER_H diff --git a/onnxruntime/core/providers/tvm/tvm_provider_factory.cc b/onnxruntime/core/providers/tvm/tvm_provider_factory.cc deleted file mode 100644 index d83fd8ee4d1cb..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_provider_factory.cc +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include -#include - -#include "core/providers/tvm/tvm_provider_factory.h" -#include "core/session/abi_session_options_impl.h" - -#include "tvm_execution_provider.h" -#include "tvm_provider_factory_creator.h" -#include "tvm_so_execution_provider.h" // NOLINT(build/include_subdir) - -namespace onnxruntime { - -struct TvmProviderFactory : IExecutionProviderFactory { - TvmProviderFactory(const tvm::TvmEPOptions& options) : options_{options} {} - ~TvmProviderFactory() = default; - - std::unique_ptr CreateProvider() override { - std::unique_ptr provider = nullptr; - if (options_.so_folder != "") { - ORT_ENFORCE(options_.executor == "vm", - "Only virtual machine module is compiled from shared lib and dependences!"); - provider = std::move(std::make_unique(options_)); - } else { - provider = std::move(std::make_unique(options_)); - } - - return provider; - } - - private: - tvm::TvmEPOptions options_; -}; - -std::shared_ptr TVMProviderFactoryCreator::Create(const char* opt_str) { - tvm::TvmEPOptions options = tvm::TvmEPOptionsHelper::FromOptionsString(opt_str); - return std::make_shared(options); -} - -std::shared_ptr TVMProviderFactoryCreator::Create(const tvm::TvmEPOptions& options) { - return std::make_shared(options); -} -} // namespace onnxruntime - -ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tvm, - _In_ OrtSessionOptions* options, - _In_ const char* opt_str) { - onnxruntime::tvm::TvmEPOptions tvm_options = onnxruntime::tvm::TvmEPOptionsHelper::FromOptionsString(opt_str); - options->provider_factories.push_back(onnxruntime::TVMProviderFactoryCreator::Create(tvm_options)); - return nullptr; -} diff --git a/onnxruntime/core/providers/tvm/tvm_provider_factory_creator.h b/onnxruntime/core/providers/tvm/tvm_provider_factory_creator.h deleted file mode 100644 index 2d7e06b5b7c59..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_provider_factory_creator.h +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -#include "core/providers/providers.h" - -namespace onnxruntime { -namespace tvm { -struct TvmEPOptions; -} - -struct TVMProviderFactoryCreator { - static std::shared_ptr Create(const tvm::TvmEPOptions& options); - static std::shared_ptr Create(const char* params); -}; -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tvm/tvm_runner.cc b/onnxruntime/core/providers/tvm/tvm_runner.cc deleted file mode 100644 index 5dda8f5bf9c3e..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_runner.cc +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/graph/model.h" -#include "core/framework/tensorprotoutils.h" - -#include "tvm_runner.h" - -using namespace ONNX_NAMESPACE; -namespace onnxruntime { -namespace tvm { - -TVMRunner::TVMRunner(const TvmEPOptions& options, - const std::shared_ptr& mod, - const InputsInfoMap& inputs_info, - const std::vector& output_tensors) { - runner_ = getTVMRunnerImpl(mod, options, inputs_info, output_tensors); -} - -common::Status TVMRunner::operator()(FunctionState state, const OrtApi* /*api*/, OrtKernelContext* context) { - Ort::KernelContext ctx(context); - return runner_->run(ctx); -} - -} // namespace tvm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tvm/tvm_runner.h b/onnxruntime/core/providers/tvm/tvm_runner.h deleted file mode 100644 index 4b7349ee3405e..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_runner.h +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifndef TVM_RUNNER_H -#define TVM_RUNNER_H - -#include -#include - -#include "tvm_runner_impl.h" - -namespace onnxruntime { -namespace tvm { - -class TVMRunner { - public: - TVMRunner() = delete; - virtual ~TVMRunner() = default; - - TVMRunner(const TvmEPOptions& options, - const std::shared_ptr& mod, - const InputsInfoMap& inputs_info, - const std::vector& output_tensor); - - common::Status operator()(FunctionState state, const OrtApi* api, OrtKernelContext* context); - - private: - std::shared_ptr runner_; -}; - -} // namespace tvm -} // namespace onnxruntime - -#endif // TVM_TVM_RUNNER_H diff --git a/onnxruntime/core/providers/tvm/tvm_runner_impl.cc b/onnxruntime/core/providers/tvm/tvm_runner_impl.cc deleted file mode 100644 index c88de2652f14b..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_runner_impl.cc +++ /dev/null @@ -1,175 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/framework/tensorprotoutils.h" - -#include "tvm_runner_impl.h" -#include "tvm_utils.h" -#include "tvm_api.h" - -namespace onnxruntime { -namespace tvm { - -/* ------------------------------------ RunnerImplFactory ----------------------------- */ - -std::shared_ptr getTVMRunnerImpl(const std::shared_ptr& mod, - const TvmEPOptions& options, - const InputsInfoMap& inputs_info, - const std::vector output_tensors) { - const std::string& name = options.executor; - if (name == "graph") { - return std::make_shared(mod, inputs_info, options.output_shapes, - output_tensors, options.set_output_zero_copy); - } else if (name == "vm") { - return std::make_shared(mod, inputs_info, options.output_shapes, - output_tensors, options.set_output_zero_copy); - } - return nullptr; -} - -/* ------------------------------------ RunnerImpl ------------------------------------ */ - -RunnerImpl::RunnerImpl(const std::shared_ptr& mod, - const InputsInfoMap& inputs_info, - const TVMTensorShapes output_shapes, - const std::vector output_tensors, - bool set_output_zero_copy) : mod_(mod), - inputs_info_(inputs_info), - output_shapes_(output_shapes), - output_tensors_(output_tensors), - set_output_zero_copy_(set_output_zero_copy) { -} - -void RunnerImpl::convert_input_tensors2dl_tensors(Ort::KernelContext& context, - std::vector& dst, - std::vector& dst_inds) { - size_t num = inputs_info_.size(); - dst.reserve(num); - dst_inds.reserve(num); - for (auto& info : inputs_info_) { - // TODO(vvchernov): decomposition declaration only available with -std=c++1z or -std=gnu++1z - auto& i = info.first; - auto& shape = info.second; - - auto input_tensor = context.GetInput(i); - ORT_ENFORCE(input_tensor.IsTensor()); - - auto ort_device_type = input_tensor.GetTensorMemoryInfo().GetDeviceType(); - const auto tensor_type = input_tensor.GetTensorTypeAndShapeInfo().GetElementType(); - - DLTensor t; - t.device = GetDLDevice(ort_device_type); - t.dtype = GetDataType(tensor_type); - t.strides = nullptr; - t.byte_offset = 0; - t.data = const_cast(input_tensor.GetTensorRawData()); - t.ndim = shape.size(); - t.shape = shape.data(); - dst.emplace_back(t); - dst_inds.push_back(i); - } -} - -void RunnerImpl::add_device_type_data2output_tensors(Ort::KernelContext& context) { - size_t num_outputs = output_tensors_.size(); - for (auto i = 0u; i < num_outputs; i++) { - // setup output tensor property - auto output_tensor = context.GetOutput(i, - output_shapes_[i].data(), - output_shapes_[i].size()); - ORT_ENFORCE(output_tensor.IsTensor()); - - output_tensors_[i].device = - GetDLDevice(output_tensor.GetTensorMemoryInfo().GetDeviceType()); - output_tensors_[i].dtype = - GetDataType(output_tensor.GetTensorTypeAndShapeInfo().GetElementType()); - output_tensors_[i].data = output_tensor.GetTensorMutableRawData(); - } -} - -/* ------------------------------------ GERunnerImpl ------------------------------------ */ - -GERunnerImpl::GERunnerImpl(const std::shared_ptr& mod, - const InputsInfoMap& inputs_info, - const TVMTensorShapes output_shapes, - const std::vector output_tensors, - bool set_output_zero_copy) : RunnerImpl(mod, inputs_info, output_shapes, output_tensors, set_output_zero_copy) { -} - -void GERunnerImpl::set_input(Ort::KernelContext& context) { - std::vector inds; - std::vector dl_tensors_inputs; - convert_input_tensors2dl_tensors(context, dl_tensors_inputs, inds); - - tvm::TVMSetInputs(*mod_, inds, dl_tensors_inputs); -} - -void GERunnerImpl::connect_output_tensors2ort(Ort::KernelContext& context) { - add_device_type_data2output_tensors(context); -} - -void GERunnerImpl::set_output_zero_copy() { - tvm::TVMSetOutputsZeroCopy(*mod_, output_tensors_); -} - -void GERunnerImpl::run() { - tvm::TVMRun(*mod_); -} - -void GERunnerImpl::get_outputs() { - tvm::TVMGetOutputs(*mod_, output_tensors_); -} - -/* ------------------------------------ VMRunnerImpl ------------------------------------ */ - -VMRunnerImpl::VMRunnerImpl(const std::shared_ptr& mod, - const InputsInfoMap& inputs_info, - const TVMTensorShapes output_shapes, - const std::vector output_tensors, - bool set_output_zero_copy) : RunnerImpl(mod, inputs_info, output_shapes, output_tensors, set_output_zero_copy) { -} - -void VMRunnerImpl::set_input(Ort::KernelContext& context) { - std::vector inds; - std::vector dl_tensors_inputs; - convert_input_tensors2dl_tensors(context, dl_tensors_inputs, inds); - - tvm::TVM_VM_SetInputs(*mod_, inds, dl_tensors_inputs); -} - -void VMRunnerImpl::connect_output_tensors2ort(Ort::KernelContext& context) { - // TODO(vvchernov): try to find more flexible solution - if (!probe_infer_) { - infer_once_to_get_output_shapes(); - } - - add_device_type_data2output_tensors(context); -} - -void VMRunnerImpl::set_output_zero_copy() { - tvm::TVM_VM_SetOutputsZeroCopy(*mod_, output_tensors_); -} - -void VMRunnerImpl::run() { - tvm::TVM_VM_Run(*mod_); -} - -void VMRunnerImpl::get_outputs() { - tvm::TVM_VM_GetOutputs(*mod_, output_tensors_); -} - -void VMRunnerImpl::infer_once_to_get_output_shapes() { - run(); - size_t num_outputs = output_tensors_.size(); - // TODO(vvchernov): check it - output_shapes_.resize(num_outputs); - tvm::TVMGetOutputShapes(*mod_, output_shapes_); - for (size_t i = 0; i < num_outputs; ++i) { - output_tensors_[i].ndim = output_shapes_[i].size(); - output_tensors_[i].shape = output_shapes_[i].data(); - } - probe_infer_ = true; -} - -} // namespace tvm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tvm/tvm_runner_impl.h b/onnxruntime/core/providers/tvm/tvm_runner_impl.h deleted file mode 100644 index 8c325303673b6..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_runner_impl.h +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifndef TVM_RUNNER_IMPL_H -#define TVM_RUNNER_IMPL_H - -#include -#include -#include - -#include "core/framework/func_api.h" -#include "core/session/onnxruntime_cxx_api.h" - -#include "tvm_common.h" -#include "tvm_ep_options.h" - -namespace onnxruntime { -namespace tvm { - -class RunnerImpl { - public: - RunnerImpl() = delete; - RunnerImpl(const std::shared_ptr& mod, - const InputsInfoMap& inputs_info, - const TVMTensorShapes output_shapes, - const std::vector tensors_outputs, - bool set_output_zero_copy); - virtual ~RunnerImpl() = default; - - virtual common::Status run(Ort::KernelContext& context) { - common::Status res; - if (set_output_zero_copy_) { - res = run_without_output_copying(context); - } else { - res = run_with_output_copying(context); - } - return res; - } - - virtual common::Status run_without_output_copying(Ort::KernelContext& context) { - set_input(context); - connect_output_tensors2ort(context); - set_output_zero_copy(); - run(); - - return Status::OK(); - } - - virtual common::Status run_with_output_copying(Ort::KernelContext& context) { - set_input(context); - connect_output_tensors2ort(context); - run(); - get_outputs(); - - return Status::OK(); - } - - virtual void set_input(Ort::KernelContext& context) = 0; - virtual void connect_output_tensors2ort(Ort::KernelContext& context) = 0; - virtual void set_output_zero_copy() = 0; - virtual void run() = 0; - virtual void get_outputs() = 0; - - protected: - void convert_input_tensors2dl_tensors(Ort::KernelContext& context, - std::vector& dst, - std::vector& dst_inds); - void add_device_type_data2output_tensors(Ort::KernelContext& context); - - protected: - std::shared_ptr mod_; - InputsInfoMap inputs_info_; - TVMTensorShapes output_shapes_; - std::vector output_tensors_; - bool set_output_zero_copy_; -}; - -class GERunnerImpl : public RunnerImpl { - public: - GERunnerImpl() = delete; - GERunnerImpl(const std::shared_ptr& mod, - const InputsInfoMap& inputs_info, - const TVMTensorShapes output_shapes, - const std::vector tensors_outputs, - bool set_output_zero_copy); - virtual ~GERunnerImpl() = default; - - void set_input(Ort::KernelContext& context) final; - void connect_output_tensors2ort(Ort::KernelContext& context) final; - void set_output_zero_copy() final; - void run() final; - void get_outputs() final; -}; - -class VMRunnerImpl : public RunnerImpl { - public: - VMRunnerImpl() = delete; - VMRunnerImpl(const std::shared_ptr& mod, - const InputsInfoMap& inputs_info, - const TVMTensorShapes output_shapes, - const std::vector tensors_outputs, - bool set_output_zero_copy); - virtual ~VMRunnerImpl() = default; - - void set_input(Ort::KernelContext& context) final; - void connect_output_tensors2ort(Ort::KernelContext& context) final; - void set_output_zero_copy() final; - void run() final; - void get_outputs() final; - - private: - void infer_once_to_get_output_shapes(); - - private: - bool probe_infer_ = false; -}; - -std::shared_ptr getTVMRunnerImpl(const std::shared_ptr& mod, - const TvmEPOptions& options, - const InputsInfoMap& inputs_info, - const std::vector output_tensors); - -} // namespace tvm -} // namespace onnxruntime - -#endif // TVM_TVM_RUNNER_IMPL_H diff --git a/onnxruntime/core/providers/tvm/tvm_so_execution_provider.cc b/onnxruntime/core/providers/tvm/tvm_so_execution_provider.cc deleted file mode 100644 index 029f25d6f292a..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_so_execution_provider.cc +++ /dev/null @@ -1,284 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include -#include -#include - -#include "core/framework/execution_provider.h" -#include "core/framework/tensorprotoutils.h" -#include "core/framework/kernel_registry.h" -#include "core/framework/compute_capability.h" -#include "core/platform/env.h" -#include "core/graph/model.h" - -#include "tvm_so_execution_provider.h" // NOLINT(build/include_subdir) -#include "xpu_data_transfer.h" // NOLINT(build/include_subdir) -#include "tvm_allocator.h" // NOLINT(build/include_subdir) -#include "tvm_utils.h" // NOLINT(build/include_subdir) -#include "tvm_api.h" // NOLINT(build/include_subdir) -#ifdef USE_TVM_HASH -#include "hash_alg/hasher.h" // NOLINT(build/include_subdir) -#endif - -using ONNX_NAMESPACE::TensorShapeProto; - -namespace onnxruntime { -namespace tvm { - -// Information to construct kernel function state. -struct TVMFuncState { - AllocateFunc allocate_func = nullptr; - DestroyFunc release_func = nullptr; - AllocatorHandle allocator = nullptr; - std::shared_ptr compiler = nullptr; -}; - -TvmSoExecutionProvider::TvmSoExecutionProvider(const TvmEPOptions& options) - : IExecutionProvider{kTvmExecutionProvider}, - options_{options} { - // Get environment variables - const Env& env_instance = Env::Default(); - - const std::string dump_subgraphs_env = env_instance.GetEnvironmentVar(env_vars::kDumpSubgraphs); - ORT_ENFORCE(dump_subgraphs_env.empty(), "TVM EP processing shared lib does not support subgraphs"); -} - -std::vector TvmSoExecutionProvider::CreatePreferredAllocators() { - AllocatorCreationInfo default_memory_info = {[](int) { - return std::make_unique(); - }, - 0, false}; - return std::vector{CreateAllocator(default_memory_info)}; -} - -TvmSoExecutionProvider::~TvmSoExecutionProvider() {} - -std::vector> -TvmSoExecutionProvider::GetCapability(const GraphViewer& graph_viewer, - const IKernelLookup& /*kernel_lookup*/) const { - std::vector> result; - if (graph_viewer.IsSubgraph()) { - return result; - } - - const auto& init_tensors = graph_viewer.GetAllInitializedTensors(); - - std::unordered_set required_initializers; - const std::vector& sorted_nodes = graph_viewer.GetNodesInTopologicalOrder(); - std::unique_ptr sub_graph = std::make_unique(); - for (auto& node_idx : sorted_nodes) { - graph_viewer.GetNode(node_idx)->ForEachDef([&required_initializers, &init_tensors](const NodeArg& node_arg, bool is_input) { - if (is_input && init_tensors.count(node_arg.Name())) { - required_initializers.insert(node_arg.Name()); - } }, true); - } - - auto meta_def = std::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>(); - meta_def->name = "TVMStandalone"; - meta_def->domain = "StandaloneTest"; - std::vector inputs; - std::vector outputs; - - for (auto& nodeArgPtr : graph_viewer.GetInputs()) { - inputs.push_back(nodeArgPtr->Name()); - } - - for (auto& name : required_initializers) { - inputs.push_back(name); - } - - for (auto& nodeArgPtr : graph_viewer.GetOutputs()) { - outputs.push_back(nodeArgPtr->Name()); - } - meta_def->inputs = inputs; - meta_def->outputs = outputs; - meta_def->since_version = 1; - meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; - sub_graph->SetMetaDef(std::move(meta_def)); - sub_graph->nodes = sorted_nodes; - result.push_back( - std::make_unique(std::move(sub_graph))); - return result; -} - -common::Status TvmSoExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, - std::vector& node_compute_funcs) { - printOptions(); - for (auto& fused_node_graph : fused_nodes_and_graphs) { - const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph; - const Node& fused_node = fused_node_graph.fused_node; -#ifdef USE_TVM_HASH - if (options_.check_hash) { - ORT_ENFORCE(checkHash(ToUTF8String(fused_node.ModelPath().ToPathString())), - "Hash check shows that used tuning files were not obtained for the given onnx-model"); - } -#endif - const std::string func_name = fused_node.Name(); - - compilers_[func_name] = std::make_shared(); - InputsInfoMap all_input_shapes; - auto mod = compileModel(func_name, graph_body_viewer, all_input_shapes); - - std::vector output_tensors(graph_body_viewer.GetOutputs().size()); - prepareOutputTensors(output_tensors); - - runners_[func_name] = std::make_shared(options_, mod, all_input_shapes, output_tensors); - - // TODO(vvchernov): implement ops checking and mechanism of gracefully passing the responsibility to other EPs - // if the checking fails due to unsupported op(s) - NodeComputeInfo compute_info = prepareComputeInfo(func_name); - - node_compute_funcs.push_back(compute_info); - } - return Status::OK(); -} - -std::unique_ptr TvmSoExecutionProvider::GetDataTransfer() const { - // TODO(vvchernov): target or target host? - if (TvmEPOptionsHelper::checkGPUTarget(options_.target)) { - return std::make_unique(); - } else if (TvmEPOptionsHelper::checkCPUTarget(options_.target)) { - return std::make_unique(); - } else { - ORT_NOT_IMPLEMENTED("TVM GetDataTransfer is not implemented for target ", options_.target); - } -} - -void TvmSoExecutionProvider::printOptions() { - LOGS(*GetLogger(), INFO) << options_; -} - -#ifdef USE_TVM_HASH -bool TvmSoExecutionProvider::checkHash(const std::string& onnx_path) const { - auto hasher = Hasher("sha256"); - std::string onnx_str = readFromFile(onnx_path); - std::string onnx_hash = hasher.hash(onnx_str.c_str(), onnx_str.size()); - onnx_str.clear(); - std::string hash; - if (options_.hash_file_path.empty()) { - // TODO(vvchernov): align hash file name with OctoML team - hash = readFromFile(options_.so_folder + "/hash.txt"); - } else { - hash = readFromFile(options_.hash_file_path); - } - return onnx_hash == hash; -} -#endif - -std::shared_ptr TvmSoExecutionProvider::compileModel(const std::string& func_name, - const GraphViewer& graph_viewer, - InputsInfoMap& all_input_shapes) { - all_input_shapes.clear(); - - TVMTensorShapes input_shapes; - if (options_.freeze_weights) { - setInputShapesForFreezedNN(graph_viewer, input_shapes, all_input_shapes); - } else { - setInputShapesForUnfreezedNN(graph_viewer, input_shapes, all_input_shapes); - } - - std::shared_ptr mod = compilers_[func_name]->operator()(options_, input_shapes); - - return mod; -} - -void TvmSoExecutionProvider::setInputShapesForFreezedNN(const GraphViewer& graph_viewer, - TVMTensorShapes& input_shapes, - InputsInfoMap& all_input_shapes) { - const std::vector& all_nodes = graph_viewer.GetInputsIncludingInitializers(); - - size_t indx = 0; - for (const auto* node : all_nodes) { - if (!graph_viewer.IsInitializedTensor(node->Name())) { - TensorShapeVector shape = getInputShape(node); - all_input_shapes[indx++] = shape; - input_shapes.emplace_back(shape); - } - } -} - -void TvmSoExecutionProvider::setInputShapesForUnfreezedNN(const GraphViewer& graph_viewer, - TVMTensorShapes& input_shapes, - InputsInfoMap& all_input_shapes) { - const std::vector& all_nodes = graph_viewer.GetInputsIncludingInitializers(); - - size_t indx = 0; - for (const auto* node : all_nodes) { - TensorShapeVector shape = getInputShape(node); - all_input_shapes[indx++] = shape; - if (!graph_viewer.IsInitializedTensor(node->Name())) { - input_shapes.emplace_back(shape); - } - } -} - -TensorShapeVector TvmSoExecutionProvider::getInputShape(const NodeArg* node) { - TensorShapeVector shape; - const auto& node_name = node->Name(); - if (!options_.input_shapes.empty() && - options_.input_shapes.count(node_name)) { - shape = options_.input_shapes[node_name]; - } else { - shape = convertTensorShape(*node->Shape()); - } - - return shape; -} - -TensorShapeVector TvmSoExecutionProvider::convertTensorShape(const TensorShapeProto& shape_proto) { - TensorShape ort_shape = utils::GetTensorShapeFromTensorShapeProto(shape_proto); - size_t dims = ort_shape.NumDimensions(); - - TensorShapeVector shape(dims); - for (size_t j = 0; j < dims; ++j) { - int64_t dim = int64_t(ort_shape[j]); - ORT_ENFORCE(dim > 0, "Input dimension is not positive value (dim = " + std::to_string(dim) + "). " + - "Please use provider options to setup input_names and input_shapes"); - shape[j] = dim; - } - - return shape; -} - -void TvmSoExecutionProvider::prepareOutputTensors(std::vector& output_tensors) { - for (DLTensor& t : output_tensors) { - // Draft for tensor, correct data is defined during inference - t.strides = nullptr; - t.byte_offset = 0; - t.data = nullptr; - t.ndim = 0; - t.shape = nullptr; - } -} - -NodeComputeInfo TvmSoExecutionProvider::prepareComputeInfo(const std::string& func_name) { - NodeComputeInfo compute_info; - compute_info.create_state_func = std::bind(&TvmSoExecutionProvider::createStateFunc, - this, - std::placeholders::_1, - std::placeholders::_2); - - compute_info.release_state_func = [](FunctionState state) { - if (state) - delete static_cast(state); - }; - - compute_info.compute_func = *runners_[func_name].get(); - - return compute_info; -} - -int TvmSoExecutionProvider::createStateFunc(ComputeContext* context, FunctionState* state) { - auto* state_ptr = new TVMFuncState(); - *state_ptr = {context->allocate_func, - context->release_func, - context->allocator_handle, - compilers_[context->node_name]}; - // TODO(vvchernov): Who and when release state? - *state = state_ptr; - return 0; -} - -} // namespace tvm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tvm/tvm_so_execution_provider.h b/onnxruntime/core/providers/tvm/tvm_so_execution_provider.h deleted file mode 100644 index d3840f46b5b55..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_so_execution_provider.h +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifndef ONNXRUNTIME_CORE_PROVIDERS_TVM_TVM_SO_EXECUTION_PROVIDER_H_ -#define ONNXRUNTIME_CORE_PROVIDERS_TVM_TVM_SO_EXECUTION_PROVIDER_H_ - -#include -#include -#include -#include - -#include "core/common/logging/logging.h" -#include "core/framework/execution_provider.h" -#include - -#include "tvm_compiler.h" // NOLINT(build/include_subdir) -#include "tvm_runner.h" // NOLINT(build/include_subdir) - -namespace onnxruntime { -class Graph; -class NodeArg; -namespace tvm { - -class TvmSoExecutionProvider : public IExecutionProvider { - using Compiler = TVMCompilerBase; - using Compilers = std::unordered_map>; - using Runner = TVMRunner; - using Runners = std::unordered_map>; - - public: - explicit TvmSoExecutionProvider(const TvmEPOptions& options); - virtual ~TvmSoExecutionProvider(); - - std::vector> - GetCapability(const onnxruntime::GraphViewer& graph, - const IKernelLookup& /*kernel_lookup*/) const override; - - common::Status Compile(const std::vector& fused_nodes_and_graphs, - std::vector& node_compute_funcs) override; - std::unique_ptr GetDataTransfer() const override; - std::vector CreatePreferredAllocators() override; - - private: - void printOptions(); -#ifdef USE_TVM_HASH - bool checkHash(const std::string& onnx_path) const; -#endif - std::shared_ptr compileModel(const std::string& func_name, - const GraphViewer& graph_viewer, - InputsInfoMap& inputs_info); // NOLINT - void setInputShapesForFreezedNN(const GraphViewer& graph_viewer, - TVMTensorShapes& input_shapes, // NOLINT - InputsInfoMap& all_input_shapes); // NOLINT - void setInputShapesForUnfreezedNN(const GraphViewer& graph_viewer, - TVMTensorShapes& input_shapes, // NOLINT - InputsInfoMap& all_input_shapes); // NOLINT - TensorShapeVector getInputShape(const NodeArg* node); - TensorShapeVector convertTensorShape(const ONNX_NAMESPACE::TensorShapeProto& shape_proto); - void prepareOutputTensors(std::vector& output_tensors); // NOLINT - NodeComputeInfo prepareComputeInfo(const std::string& func_name); - int createStateFunc(ComputeContext*, FunctionState*); - - private: - TvmEPOptions options_; - Compilers compilers_; - Runners runners_; -}; - -} // namespace tvm -} // namespace onnxruntime - -#endif // ONNXRUNTIME_CORE_PROVIDERS_TVM_TVM_SO_EXECUTION_PROVIDER_H_ diff --git a/onnxruntime/core/providers/tvm/tvm_utils.cc b/onnxruntime/core/providers/tvm/tvm_utils.cc deleted file mode 100644 index e0a5b566835c8..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_utils.cc +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifndef TVM_UTILS_H -#define TVM_UTILS_H - -#include -#include - -#include "tvm_utils.h" // NOLINT(build/include_subdir) - -namespace onnxruntime { -namespace tvm { - -std::string readFromFile(const std::string& file_path) { - std::string str; - - std::ifstream t(file_path); - t.seekg(0, std::ios::end); - str.reserve(t.tellg()); - t.seekg(0, std::ios::beg); - - str.assign((std::istreambuf_iterator(t)), - std::istreambuf_iterator()); - return str; -} - -} // namespace tvm -} // namespace onnxruntime - -#endif // TVM_UTILS_H diff --git a/onnxruntime/core/providers/tvm/tvm_utils.h b/onnxruntime/core/providers/tvm/tvm_utils.h deleted file mode 100644 index de77368c715b9..0000000000000 --- a/onnxruntime/core/providers/tvm/tvm_utils.h +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifndef TVM_UTILS_H -#define TVM_UTILS_H - -#include - -#include "tvm_common.h" - -#include "core/session/onnxruntime_cxx_api.h" -#include "core/framework/ortdevice.h" -#include "core/common/common.h" - -namespace onnxruntime { -namespace tvm { - -inline DLDataType GetDataType(ONNXTensorElementDataType type) { - switch (type) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: - return {kDLUInt, 8, 1}; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: - return {kDLInt, 8, 1}; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: - return {kDLUInt, 16, 1}; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: - return {kDLInt, 16, 1}; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: - return {kDLUInt, 32, 1}; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: - return {kDLInt, 32, 1}; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: - return {kDLUInt, 64, 1}; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: - return {kDLInt, 64, 1}; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: - return {kDLFloat, 16, 1}; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: - return {kDLFloat, 32, 1}; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: - return {kDLFloat, 64, 1}; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: - return {kDLUInt, 1, 1}; - default: - ORT_NOT_IMPLEMENTED("Unsupported data type"); - } -} - -inline DLDevice GetDLDevice(OrtMemoryInfoDeviceType device_type) { - DLDevice context; - switch (device_type) { - case OrtDevice::CPU: - context = {kDLCPU, 0}; - break; - case OrtDevice::GPU: - context = {kDLVulkan, 0}; - break; - default: - ORT_NOT_IMPLEMENTED("Unsupported device"); - break; - } - return context; -} - -std::string readFromFile(const std::string& file_path); - -} // namespace tvm -} // namespace onnxruntime - -#endif // TVM_UTILS_H diff --git a/onnxruntime/core/providers/tvm/xpu_data_transfer.cc b/onnxruntime/core/providers/tvm/xpu_data_transfer.cc deleted file mode 100644 index 4011dee7b7b7f..0000000000000 --- a/onnxruntime/core/providers/tvm/xpu_data_transfer.cc +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/framework/tensor.h" - -#include "xpu_data_transfer.h" -#include "tvm_utils.h" - -namespace onnxruntime { -namespace tvm { - -XPUDataTransfer::XPUDataTransfer() { -} - -XPUDataTransfer::~XPUDataTransfer() { -} - -bool XPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { - return (src_device.Type() == OrtDevice::CPU && dst_device.Type() == OrtDevice::CPU) || - (src_device.Type() == OrtDevice::GPU || dst_device.Type() == OrtDevice::GPU); -} - -common::Status XPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { - size_t bytes = src.SizeInBytes(); - const void* src_data = src.DataRaw(); - void* dst_data = dst.MutableDataRaw(); - const auto src_device_type = src.Location().device.Type(); - const auto dst_device_type = dst.Location().device.Type(); - - if ((src_device_type == OrtDevice::CPU) && (dst_device_type == OrtDevice::CPU)) { - if (src_data == dst_data) { - // no need copying as both pointers are referring to same piece of memory. - return Status::OK(); - } - memcpy(dst_data, src_data, bytes); - } else { - DLTensor tvm_src, tvm_dst; - DLDataType dl_type{kDLInt, 8, 1}; - std::vector shape{int64_t(bytes)}; - // Construct source DLTensor - tvm_src.device = GetDLDevice(static_cast(src_device_type)); - tvm_src.dtype = dl_type; - tvm_src.strides = nullptr; - tvm_src.byte_offset = 0; - tvm_src.data = const_cast(src_data); - tvm_src.ndim = 1; - tvm_src.shape = shape.data(); - // Construct destination DLTensor - tvm_dst.device = GetDLDevice(static_cast(dst_device_type)); - tvm_dst.dtype = dl_type; - tvm_dst.strides = nullptr; - tvm_dst.byte_offset = 0; - tvm_dst.data = dst_data; - tvm_dst.ndim = 1; - tvm_dst.shape = shape.data(); - // Copying from src to dst - TVMDeviceCopyDataFromTo(&tvm_src, &tvm_dst, nullptr); - } - return Status::OK(); -} - -DLDevice XPUDataTransfer::get_context(const OrtDevice& device) const { - return GetDLDevice(static_cast(device.Type())); -} - -bool TvmCPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { - return src_device.Type() == OrtDevice::CPU && dst_device.Type() == OrtDevice::CPU; -} - -common::Status TvmCPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { - const void* src_data = src.DataRaw(); - void* dst_data = dst.MutableDataRaw(); - if (src_data == dst_data) { - // no need copying as both pointers are referring to same piece of memory. - return Status::OK(); - } - // Copying only happens between two same size tensors. - ORT_ENFORCE(src.SizeInBytes() == dst.SizeInBytes()); - memcpy(dst_data, src_data, src.SizeInBytes()); - return Status::OK(); -} - -} // namespace tvm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/tvm/xpu_data_transfer.h b/onnxruntime/core/providers/tvm/xpu_data_transfer.h deleted file mode 100644 index a2cf55b241bb1..0000000000000 --- a/onnxruntime/core/providers/tvm/xpu_data_transfer.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifndef XPU_DATA_TRANSFER -#define XPU_DATA_TRANSFER - -#include "core/framework/data_transfer.h" -#include "tvm_common.h" - -namespace onnxruntime { -namespace tvm { - -class XPUDataTransfer : public IDataTransfer { - public: - XPUDataTransfer(); - ~XPUDataTransfer(); - - bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; - - // Dumpen MSVC warning about not fully overriding - using IDataTransfer::CopyTensor; - common::Status CopyTensor(const Tensor& src, Tensor& dst) const override; - DLDevice get_context(const OrtDevice& device) const; -}; - -class TvmCPUDataTransfer : public IDataTransfer { - public: - TvmCPUDataTransfer() = default; - // Dampen MSVC warning about not fully overriding CopyTensor - using IDataTransfer::CopyTensor; - bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; - common::Status CopyTensor(const Tensor& src, Tensor& dst) const override; -}; - -} // namespace tvm -} // namespace onnxruntime - -#endif // XPU_DATA_TRANSFER diff --git a/onnxruntime/core/providers/vitisai/imp/attr_proto.cc b/onnxruntime/core/providers/vitisai/imp/attr_proto.cc index a9275b24ce91f..2b9ddf8ad147f 100644 --- a/onnxruntime/core/providers/vitisai/imp/attr_proto.cc +++ b/onnxruntime/core/providers/vitisai/imp/attr_proto.cc @@ -104,4 +104,8 @@ std::vector attr_proto_get_strings(const ONNX_NAMESPACE::AttributeP } return ret; } +std::string* attr_proto_release_string(ONNX_NAMESPACE::AttributeProto* attr) { + vai_assert(attr->type() == ONNX_NAMESPACE::AttributeProto_AttributeType_STRING, attr->name()); + return attr->release_s(); +} } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/imp/attr_proto.h b/onnxruntime/core/providers/vitisai/imp/attr_proto.h index bb2883512037b..08d980ec94c14 100644 --- a/onnxruntime/core/providers/vitisai/imp/attr_proto.h +++ b/onnxruntime/core/providers/vitisai/imp/attr_proto.h @@ -23,5 +23,6 @@ const ONNX_NAMESPACE::TensorProto& attr_proto_get_tensor(const ONNX_NAMESPACE::A gsl::span attr_proto_get_ints(const ONNX_NAMESPACE::AttributeProto& attr); gsl::span attr_proto_get_floats(const ONNX_NAMESPACE::AttributeProto& attr); std::vector attr_proto_get_strings(const ONNX_NAMESPACE::AttributeProto& attr); +std::string* attr_proto_release_string(ONNX_NAMESPACE::AttributeProto* attr); } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index 772e778dd5ed4..51dc79c569589 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -444,6 +444,18 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { } }; the_global_api.node_arg_external_location = vaip::node_arg_external_location; + the_global_api.model_to_proto = [](onnxruntime::Model& model) { return model.ToProto().release(); }; + the_global_api.model_proto_serialize_as_string = [](ONNX_NAMESPACE::ModelProto& model_proto) { + return vaip_core::DllSafe(model_proto.SerializeAsString()); + }; + the_global_api.model_proto_delete = [](ONNX_NAMESPACE::ModelProto* p) { delete p; }; + the_global_api.attr_proto_release_string = [](ONNX_NAMESPACE::AttributeProto* attr) -> vaip_core::DllSafe { + auto pstr = vaip::attr_proto_release_string(attr); + std::string local_str = std::move(*pstr); + pstr = nullptr; + return vaip_core::DllSafe(std::move(local_str)); + }; + if (!s_library_vitisaiep.vaip_get_version) { return reinterpret_cast(&(the_global_api.host_)); } else { diff --git a/onnxruntime/core/providers/vitisai/include/vaip/custom_op.h b/onnxruntime/core/providers/vitisai/include/vaip/custom_op.h index 5d020e00ff5b7..64cf52ec0a404 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/custom_op.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/custom_op.h @@ -25,18 +25,18 @@ class ExecutionProvider { virtual DllSafe> get_meta_def_nodes() const = 0; virtual DllSafe> get_meta_def_constant_initializer() const = 0; + virtual bool get_meta_def_fallback_CPU() const { return false; }; virtual std::unique_ptr compile() const = 0; public: - inline void set_fused_node(const onnxruntime::Node* fused_node) { - fused_node_ = fused_node; - } - inline const onnxruntime::Node* get_fused_node() const { - return fused_node_; - } + inline void set_fused_node(const onnxruntime::Node* fused_node) { fused_node_ = fused_node; } + inline const onnxruntime::Node* get_fused_node() const { return fused_node_; } + inline void set_model(onnxruntime::Model* model) { model_ = model; } + inline onnxruntime::Model* get_model() const { return model_; } private: const onnxruntime::Node* fused_node_ = nullptr; + onnxruntime::Model* model_ = nullptr; }; class CustomOp { diff --git a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h index 74482d8e9ee0e..7628e45d2b933 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h @@ -20,6 +20,7 @@ struct NodeAttributes; namespace ONNX_NAMESPACE { struct AttributeProto; struct TensorProto; +struct ModelProto; #ifndef USE_VITISAI enum TensorProto_DataType : int { TensorProto_DataType_UNDEFINED = 0, @@ -70,6 +71,7 @@ enum AttributeProto_AttributeType : int { namespace vaip_core { class GraphHolder; using ONNX_NAMESPACE::AttributeProto; +using ONNX_NAMESPACE::ModelProto; using ONNX_NAMESPACE::TensorProto; using onnxruntime::Graph; using onnxruntime::GraphViewer; diff --git a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h index bbe8b6e6e4934..9425c08dceebc 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h @@ -13,7 +13,7 @@ struct OrtApi; namespace vaip_core { -#define VAIP_ORT_API_MAJOR (10u) +#define VAIP_ORT_API_MAJOR (12u) #define VAIP_ORT_API_MINOR (0u) #define VAIP_ORT_API_PATCH (0u) struct OrtApiForVaip { @@ -231,6 +231,10 @@ struct OrtApiForVaip { gsl::span inputs); // [92] int (*node_arg_external_location)(const Graph& graph, const NodeArg& node_arg, std::string& file, size_t& offset, size_t& size, size_t& checksum); // [93] void (*session_option_configuration)(void* mmap, void* session_option, void (*push)(void* mmap, const char* name, const char* value)); // [94] + ModelProto* (*model_to_proto)(Model& model); // [95] + DllSafe (*model_proto_serialize_as_string)(ModelProto& model_proto); // [96] + void (*model_proto_delete)(ModelProto* p); // [97] + DllSafe (*attr_proto_release_string)(AttributeProto* attr); // [98] }; #ifndef USE_VITISAI diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index 633847e6f163b..023a954c83d70 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -76,7 +76,17 @@ common::Status VitisAIExecutionProvider::Compile(const std::vectorexecution_providers_)[index]->set_fused_node(&fused_node_graph.fused_node.get()); + auto& ep = (**this->execution_providers_)[index]; + ep->set_fused_node(&fused_node_graph.fused_node.get()); + if (ep->get_meta_def_fallback_CPU()) { + auto& subgraph = fused_node_graph.filtered_graph.get(); + auto& logger = logging::LoggingManager::DefaultLogger(); + auto model_proto = subgraph.CreateModel(logger)->ToProto(); + subgraph.ToProto(*model_proto->mutable_graph(), true, true); + auto local_registries = IOnnxRuntimeOpSchemaRegistryList{subgraph.GetSchemaRegistry()}; + auto model = Model::Create(std::move(*model_proto), subgraph.ModelPath(), &local_registries, logger); + ep->set_model(model.release()); + } compute_info.create_state_func = [this, index](ComputeContext* context, FunctionState* state) { auto* p = (**this->execution_providers_)[index]->compile().release(); *state = p; diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h index 07085cd248d06..77dede6035b4c 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h @@ -50,10 +50,9 @@ class VitisAIExecutionProvider : public IExecutionProvider { ProviderOptions info_; std::vector custom_op_domains_; std::shared_ptr registry_; - std::set vitisai_optypes_; // EP context related. bool ep_ctx_enabled_ = false; - bool ep_ctx_embed_mode_ = true; + bool ep_ctx_embed_mode_ = false; std::string ep_ctx_model_path_cfg_{""}; mutable PathString ep_ctx_model_file_loc_{}; // It might need to be called before loading diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/conv_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/conv_op_builder.h index 3ed432c2efa1c..5278efdb4a400 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/conv_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/conv_op_builder.h @@ -112,7 +112,7 @@ class ConvOpBuilder : public BaseOpBuilder { } } } else { - auto pads = helper.Get("pads", std::vector{0U, 0U}); + auto pads = helper.Get("pads", std::vector{0U, 0U, 0U, 0U}); if (group != 1 && group != weight_tensor->GetShape()[OChannel_idx]) { if (is_1d_conv) { op = graph_ep->GetGraph() diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/elementwise_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/elementwise_op_builder.h index 4c10ba01b1c2e..7da1e6e674601 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/elementwise_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/elementwise_op_builder.h @@ -65,6 +65,12 @@ ELEMENTWISE_OP_BUILDER(Floor, Floor); ELEMENTWISE_OP_BUILDER(Log, Log); ELEMENTWISE_OP_BUILDER(Sin, Sin); ELEMENTWISE_OP_BUILDER(HardSwish, HardSwish); +ELEMENTWISE_OP_BUILDER(Neg, Neg); +ELEMENTWISE_OP_BUILDER(Not, LogicalNot); +ELEMENTWISE_OP_BUILDER(Ceil, Ceil); +ELEMENTWISE_OP_BUILDER(Round, Round); +ELEMENTWISE_OP_BUILDER(Min, Minimum); +ELEMENTWISE_OP_BUILDER(Max, Maximum); class PowOpBuilder : public BaseOpBuilder { bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/pad_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/pad_op_builder.h new file mode 100644 index 0000000000000..19cbe4e7f3e48 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/pad_op_builder.h @@ -0,0 +1,191 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include +#include +#include "core/optimizer/initializer.h" +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { + +typedef tim::vx::ops::PadV2::pad_mode_type PadMode; + +class PadOpBuilder : public BaseOpBuilder { + public: + int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 11; } + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + NodeAttrHelper helper(*node); + const auto mode = helper.Get("mode", "constant"); + auto input_defs = node->InputDefs(); + size_t num_inputs = input_defs.size(); + auto input_shape = vsi::npu::util::GetTensorShape(*input_defs[0]); + int32_t rank = input_shape.NumDimensions(); + const auto& initializers = graph_viewer.GetAllInitializedTensors(); + + if (mode == "wrap") { + LOGS_DEFAULT(WARNING) << "`wrap` mode Pad is not currently supported for now."; + return false; + } + if (mode == "constant") { + if (num_inputs > 2 && input_defs[2]->Exists()) { + // only support if `constant_value` input is a constant initializer + if (!Contains(initializers, input_defs[2]->Name())) { + LOGS_DEFAULT(WARNING) << "constant_value must be a constant initializer."; + return false; + } + } + } + // only support if `pads` input is known and does not contain negative values + { + const auto* pads_initializer = graph_viewer.GetConstantInitializer(input_defs[1]->Name()); + if (!pads_initializer) { + LOGS_DEFAULT(WARNING) << "pads must be a constant initializer"; + return false; + } + + Initializer unpacked_tensor(*pads_initializer); + auto tensor_data = unpacked_tensor.DataAsSpan(); + for (size_t i = 0; i < unpacked_tensor.size(); i++) { + if (tensor_data[i] < 0) { + LOGS_DEFAULT(WARNING) << "Negative pad value is not supported: pads[" + << i << "] = " << tensor_data[i]; + return false; + } + } + } + return true; + } + + bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const override { + for (size_t i = 0; i < node_unit.Inputs().size(); ++i) { + const auto& iodef = node_unit.Inputs()[i]; + if (0 == i) { + if (!util::IsTypeSupported(&iodef.node_arg) || + (*iodef.node_arg.Type() == "tensor(int64)") || + (*iodef.node_arg.Type() == "tensor(bool)")) { + LOGS_DEFAULT(WARNING) << "Unspport tensor data type:" << *iodef.node_arg.Type(); + return false; + } + } else if (1 == i) { + if (!Contains(initializers, iodef.node_arg.Name())) { + LOGS_DEFAULT(WARNING) << "pads must be a constant initializer."; + return false; + } + } else if (2 == i) { + if (iodef.node_arg.Exists() && !Contains(initializers, iodef.node_arg.Name())) { + LOGS_DEFAULT(WARNING) << "constant_value must be a constant initializer."; + return false; + } + } else if (i == 3) { + if (!Contains(initializers, iodef.node_arg.Name())) { + LOGS_DEFAULT(WARNING) << "axes must be a constant initializer.."; + return false; + } + } + } + return true; + } + + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Pad Op."; + NodeAttrHelper helper(node_unit); + const auto mode = helper.Get("mode", "constant"); + auto input_defs = node_unit.Inputs(); + PadMode pad_mode = PadMode::PAD_MODE_CONSTANT; + float const_val = 0.0f; + std::vector axes_tensor_data; + int32_t input_rank = inputs[0]->GetShape().size(); + + if (mode == "constant") { + pad_mode = PadMode::PAD_MODE_CONSTANT; + } else if (mode == "reflect") { + pad_mode = PadMode::PAD_MODE_REFLECT; + } else if (mode == "edge") { + pad_mode = PadMode::PAD_MODE_EDGE; + } else { + LOGS_DEFAULT(WARNING) << "`wrap` mode Pad is not currently supported for now."; + return false; + } + + // `pads` input + std::vector onnx_pads(inputs[1]->GetSpec().GetElementNum()); + inputs[1]->CopyDataFromTensor(onnx_pads.data()); + + // `constant_value` input + if (inputs.size() > 2 && pad_mode == PadMode::PAD_MODE_CONSTANT) { + if (input_defs[2].node_arg.Exists()) { + inputs[2]->CopyDataFromTensor(&const_val); + } + } + // `axes` input + if (inputs.size() > 3) { + // optional input axes is provided, use axes initializer data + std::vector axes_tensor(inputs[3]->GetSpec().GetElementNum()); + inputs[3]->CopyDataFromTensor(axes_tensor.data()); + std::transform( + axes_tensor.begin(), axes_tensor.end(), std::back_inserter(axes_tensor_data), + [input_rank](int64_t axis) { return HandleNegativeAxis(axis, input_rank); }); + } else { + // if not provided, make a default axes as [0, 1, ..., input_rank - 1] + std::vector default_axes(input_rank); + std::iota(std::begin(default_axes), std::end(default_axes), 0); + axes_tensor_data = std::move(default_axes); + } + + int64_t num_axes = axes_tensor_data.size(); + std::vector front_size(input_rank, 0); + std::vector back_size(input_rank, 0); + + int64_t axes_index = 0; + for (int64_t axes : axes_tensor_data) { + front_size[axes] = onnx_pads[axes_index]; + back_size[axes] = onnx_pads[axes_index + num_axes]; + axes_index++; + } + + std::reverse(front_size.begin(), front_size.end()); + std::reverse(back_size.begin(), back_size.end()); + + auto op = graph_ep->GetGraph()->CreateOperation( + front_size, back_size, const_val, pad_mode); + op->BindInput(inputs[0]).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/split_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/split_op_builder.h new file mode 100644 index 0000000000000..e08416bda70d4 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/split_op_builder.h @@ -0,0 +1,190 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include +#include +#include "core/optimizer/initializer.h" +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { + +class SplitOpBuilder : public BaseOpBuilder { + public: + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + NodeAttrHelper helper(*node); + auto axis = helper.Get("axis", 0); + auto input_defs = node->InputDefs(); + size_t num_inputs = input_defs.size(); + size_t num_outputs = node->OutputDefs().size(); + auto input_shape = vsi::npu::util::GetTensorShape(*input_defs[0]); + int32_t rank = input_shape.NumDimensions(); + std::vector splits_list; + bool split_provided = false; + if (axis >= rank || axis < -rank) { + LOGS_DEFAULT(WARNING) << "Axis is invalid in Split. Axis(" << axis + << ") is out of rank[" << -rank << "," << rank - 1 << "]"; + return false; + } + axis = HandleNegativeAxis(axis, rank); + const auto split_dims_at_axis = input_shape.GetDims()[axis]; + if (num_inputs > 1 && input_defs[1]->Exists()) { + // if optional input `split` is provided + const auto* splits = graph_viewer.GetConstantInitializer(input_defs[1]->Name()); + if (!splits) { + LOGS_DEFAULT(WARNING) << "Optional input 'split' must be a constant initializer if provided."; + return false; + } + Initializer unpacked_tensor(*splits); + auto split_sizes_ = unpacked_tensor.DataAsSpan(); + splits_list.assign(split_sizes_.begin(), split_sizes_.end()); + split_provided = true; + } + if (num_inputs == 1) { + // opset1,2,11 split as attribute + if (helper.HasAttr("split")) { + auto split_sizes_ = *helper.GetInt64s("split"); + splits_list.assign(split_sizes_.begin(), split_sizes_.end()); + split_provided = true; + } else if (node->SinceVersion() >= 18) { + const auto outputs_count = helper.GetInt64("num_outputs"); + if (!outputs_count.has_value()) { + LOGS_DEFAULT(WARNING) << "No 'num_outputs' provided. For split 18+, num_outputs is a required attribute."; + return false; + } + if (outputs_count.value() != static_cast(num_outputs) || + outputs_count.value() > split_dims_at_axis) { + LOGS_DEFAULT(WARNING) << "Invalid num_outputs provided.\n. The value should be smaller or equal to the size " + "of dimension being split. num_outputs: " + << outputs_count.value(); + return false; + } + } + } + if (!split_provided) { + // populate split sizes based on num_outputs so existing code can be utilized + int32_t size = narrow(std::ceil(float(split_dims_at_axis) / num_outputs)); + int32_t remainder = split_dims_at_axis % size; + std::vector split_sizes_(num_outputs, size); + if (remainder) { + split_sizes_.back() = remainder; + } + splits_list.assign(split_sizes_.begin(), split_sizes_.end()); + } + + uint32_t sum_of_splits = std::accumulate(splits_list.begin(), splits_list.end(), SafeInt(0)); + if (sum_of_splits != split_dims_at_axis) { + LOGS_DEFAULT(WARNING) << "Sum of the 'split' input values must equal to the dim value at 'axis' specified. " + << "dim value at 'axis' specified: " + << split_dims_at_axis + << ", sum of 'split' input values: " + << sum_of_splits; + return false; + } + if (!std::all_of(splits_list.begin(), splits_list.end(), [](int64_t value) { return value >= 0; })) { + LOGS_DEFAULT(WARNING) << "Invalid value in 'split' attribute. All values must be > 0"; + return false; + } + auto average_split = sum_of_splits / num_outputs; + if (!std::all_of(splits_list.begin(), splits_list.end(), [average_split](int64_t value) { return value == average_split; })) { + // TO DO, remove this check after driver supports it. + LOGS_DEFAULT(WARNING) << "Uneven splits are not currently supported for now."; + return false; + } + + return true; + } + + bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const override { + for (size_t i = 0; i < node_unit.Inputs().size(); ++i) { + const auto& iodef = node_unit.Inputs()[i]; + if (0 == i) { + if (!util::IsTypeSupported(&iodef.node_arg) || + (*iodef.node_arg.Type() == "tensor(int64)") || + (*iodef.node_arg.Type() == "tensor(bool)")) { + LOGS_DEFAULT(WARNING) << "Unsupport tensor data type:" << *iodef.node_arg.Type(); + return false; + } + } else if (!Contains(initializers, iodef.node_arg.Name())) { + LOGS_DEFAULT(WARNING) << "Optional input 'split' must be a constant initializer if provided."; + return false; + } + } + return true; + } + + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Split Op."; + NodeAttrHelper helper(node_unit); + auto axis = helper.Get("axis", 0); + axis = util::ReverseAxis(axis, inputs[0]->GetShape().size()); + const auto split_dims_at_axis = inputs[0]->GetShape()[axis]; + auto num_outputs = outputs.size(); + // transform splite vector to timvx slice + std::vector onnx_split; + if (inputs.size() > 1) { + std::vector split_sizes_(inputs[1]->GetSpec().GetElementNum()); + inputs[1]->CopyDataFromTensor(split_sizes_.data()); + onnx_split.assign(split_sizes_.begin(), split_sizes_.end()); + } + if (inputs.size() == 1) { + if (helper.HasAttr("split")) { + auto split_sizes_ = *helper.GetInt64s("split"); + onnx_split.assign(split_sizes_.begin(), split_sizes_.end()); + } + if (node_unit.SinceVersion() >= 18 || !helper.HasAttr("split")) { + // populate split sizes based on num_outputs so existing code can be utilized + int32_t size = narrow(std::ceil(float(split_dims_at_axis) / num_outputs)); + int32_t remainder = split_dims_at_axis % size; + std::vector split_sizes_(num_outputs, size); + if (remainder) { + split_sizes_.back() = remainder; + } + onnx_split.assign(split_sizes_.begin(), split_sizes_.end()); + } + } + std::vector slices(onnx_split.begin(), onnx_split.end()); + std::reverse(slices.begin(), slices.end()); + + auto op = graph_ep->GetGraph()->CreateOperation( + axis, slices); + op->BindInput(inputs[0]).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/op_builder_factory.h b/onnxruntime/core/providers/vsinpu/builders/op_builder_factory.h index dc0969429b8ff..fcf9479a6058b 100644 --- a/onnxruntime/core/providers/vsinpu/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/vsinpu/builders/op_builder_factory.h @@ -53,6 +53,8 @@ #include "impl/cast_op_builder.h" #include "impl/dropout_op_builder.h" #include "impl/slice_op_builder.h" +#include "impl/split_op_builder.h" +#include "impl/pad_op_builder.h" namespace onnxruntime { namespace vsi { namespace npu { @@ -110,7 +112,15 @@ static const std::map reg = { REGISTER_OP_BUILDER("Resize", ResizeOpBuilder), REGISTER_OP_BUILDER("Cast", CastOpBuilder), REGISTER_OP_BUILDER("Dropout", DropoutOpBuilder), - REGISTER_OP_BUILDER("Slice", SliceOpBuilder) + REGISTER_OP_BUILDER("Slice", SliceOpBuilder), + REGISTER_OP_BUILDER("Split", SplitOpBuilder), + REGISTER_OP_BUILDER("Neg", NegOpBuilder), + REGISTER_OP_BUILDER("Not", NotOpBuilder), + REGISTER_OP_BUILDER("Ceil", CeilOpBuilder), + REGISTER_OP_BUILDER("Round", RoundOpBuilder), + REGISTER_OP_BUILDER("Min", MinOpBuilder), + REGISTER_OP_BUILDER("Max", MaxOpBuilder), + REGISTER_OP_BUILDER("Pad", PadOpBuilder) #undef REGISTER_OP_BUILDER }; diff --git a/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch b/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch index 45de47f3e5128..95a4e4650e9fe 100644 --- a/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch +++ b/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch @@ -1,8 +1,8 @@ diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake -index c02ac2096d..2bc51298f0 100644 +index 10c307b3b9..a52bf71c4d 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake -@@ -361,7 +361,7 @@ else() +@@ -370,7 +370,7 @@ else() ) set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") @@ -12,11 +12,11 @@ index c02ac2096d..2bc51298f0 100644 ${mlas_platform_srcs} ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h -index e46105324a..414c46a1ce 100644 +index 28ae64c4d5..0c77e0ca78 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h -@@ -82,6 +82,9 @@ Abstract: - +@@ -83,6 +83,9 @@ Abstract: + #if (!defined(_MSC_VER)) || (_MSC_VER >= 1930) #if defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_ARM64EC) +#if !defined(USE_VSINPU) @@ -25,51 +25,51 @@ index e46105324a..414c46a1ce 100644 #if !defined(__APPLE__) // Had to temporary disable fp16 under APPLE ARM64, as compiling // the source files require a hardware specific compilation flag. -@@ -90,6 +93,7 @@ Abstract: - +@@ -91,6 +94,7 @@ Abstract: + #define MLAS_F16VEC_INTRINSICS_SUPPORTED - + +#endif // #endif // #endif // ARM64 #endif // Visual Studio 16 or earlier does not support fp16 intrinsic -@@ -1635,6 +1639,7 @@ MlasHalfGemmConvertPackB( +@@ -1644,6 +1648,7 @@ MlasHalfGemmConvertPackB( ); - + #if defined(__aarch64__) && defined(__linux__) +#if !defined(USE_VSINPU) /** * @brief Whether current CPU supports Bfloat16(bf16) acceleration. */ -@@ -1746,6 +1751,7 @@ MlasSBGemmPackBSize(size_t N, size_t K); +@@ -1755,6 +1760,7 @@ MlasSBGemmPackBSize(size_t N, size_t K); void MLASCALL MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB); #endif +#endif - + /** * @brief Indirect Depthwise convolution for fp16 diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h -index 4239e2ecae..3df7e5573d 100644 +index 0533a5e49b..c18bf7f90d 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h -@@ -361,6 +361,7 @@ size_t +@@ -377,6 +377,7 @@ size_t #else - + #if defined(__aarch64__) && defined(__linux__) +#if !defined(USE_VSINPU) typedef size_t(MLASCALL MLAS_SBGEMM_FLOAT_KERNEL)( const float* A, const bfloat16_t* B, -@@ -373,6 +374,7 @@ typedef size_t(MLASCALL MLAS_SBGEMM_FLOAT_KERNEL)( +@@ -389,6 +390,7 @@ typedef size_t(MLASCALL MLAS_SBGEMM_FLOAT_KERNEL)( const float* Bias ); #endif +#endif - + typedef size_t -@@ -763,8 +765,10 @@ extern "C" { +@@ -796,8 +798,10 @@ extern "C" { MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelZero; MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelAdd; #if defined(__aarch64__) && defined(__linux__) @@ -80,39 +80,25 @@ index 4239e2ecae..3df7e5573d 100644 #endif MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelZero; MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelAdd; -@@ -899,8 +903,10 @@ extern "C" { +@@ -946,8 +950,10 @@ extern "C" { #define MLAS_QGEMM_THREAD_COMPLEXITY 65536 - + #if defined(__aarch64__) && defined(__linux__) +#if !defined(USE_VSINPU) #define MLAS_SBGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) #endif +#endif - + // // Single-threaded single precision matrix/matrix multiply operation. -@@ -2570,4 +2576,3 @@ MlasPackInt4Elements(uint8_t* Output, UnpackedType ValueLow, UnpackedType ValueH - static_assert(std::is_same_v || std::is_same_v); - *Output = static_cast(((ValueHigh & 0xF) << 4) | (ValueLow & 0xF)); - } -- diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp -index ed437f20f7..8c9d0a75fd 100644 +index b3c9461293..424c3b0441 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp -@@ -20,7 +20,7 @@ Abstract: - #include - #include - --#if defined(MLAS_TARGET_POWER) -+#if defined(MLAS_TARGET_POWER) - #if defined(__linux__) - #include - #elif defined(_AIX) -@@ -536,7 +536,7 @@ Return Value: - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; +@@ -574,7 +574,7 @@ Return Value: + this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot; } - + -#if defined(__linux__) +#if defined(__linux__) && !defined(USE_VSINPU) // @@ -124,12 +110,12 @@ index de7fd72fad..4f75dbd6fa 100644 +++ b/onnxruntime/core/mlas/lib/sbgemm.h @@ -31,6 +31,7 @@ Abstract: --*/ - + #if defined(__aarch64__) && defined(__linux__) +#if !defined(USE_VSINPU) - + #pragma once - + @@ -396,4 +397,5 @@ MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t Bat } ); @@ -137,11 +123,11 @@ index de7fd72fad..4f75dbd6fa 100644 +#endif #endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc -index 6a71283f9d..d8bd348854 100644 +index 2c6d23e4de..61aaacdfd6 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc -@@ -132,7 +132,7 @@ Status MatMul::Compute(OpKernelContext* ctx) const { - +@@ -133,7 +133,7 @@ Status MatMul::Compute(OpKernelContext* ctx) const { + return Status::OK(); } -#if defined(__aarch64__) && defined(__linux__) @@ -149,7 +135,7 @@ index 6a71283f9d..d8bd348854 100644 bool GemmPackBBfloat16(AllocatorPtr& alloc, const Tensor& tensor_b, bool trans_b, -@@ -180,6 +180,7 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ Alloc +@@ -181,6 +181,7 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ Alloc if (input_idx == 1) { size_t packed_b_size; #if defined(__aarch64__) && defined(__linux__) @@ -157,7 +143,7 @@ index 6a71283f9d..d8bd348854 100644 size_t dim1 = 0; size_t dim2 = 0; TensorShape b_shape = tensor.Shape(); -@@ -192,6 +193,7 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ Alloc +@@ -193,6 +194,7 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ Alloc if (use_fastmath_mode_ && (trans_b_attr_ == 0) && ((dim1 * dim2) >= kFastMathModeKernelsizeThreshold)) { is_packed = GemmPackBBfloat16(alloc, tensor, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_); } else @@ -165,7 +151,7 @@ index 6a71283f9d..d8bd348854 100644 #endif { is_packed = GemmPackBFp32(alloc, tensor, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_); -@@ -257,6 +259,7 @@ Status MatMul::Compute(OpKernelContext* ctx) const { +@@ -259,6 +261,7 @@ Status MatMul::Compute(OpKernelContext* ctx) const { const size_t lda = helper.Lda(trans_a); const size_t ldb = helper.Ldb(trans_b); #if defined(__aarch64__) && defined(__linux__) @@ -173,7 +159,7 @@ index 6a71283f9d..d8bd348854 100644 if (use_fastmath_mode_ && !trans_b && ((N * K) >= kFastMathModeKernelsizeThreshold)) { std::vector data(max_len); for (size_t i = 0; i < max_len; i++) { -@@ -273,6 +276,7 @@ Status MatMul::Compute(OpKernelContext* ctx) const { +@@ -275,6 +278,7 @@ Status MatMul::Compute(OpKernelContext* ctx) const { } MlasSBGemmBatch(M, N, K, max_len, data.data(), thread_pool); } else @@ -187,7 +173,7 @@ index b9bbe36583..2f570502d2 100644 +++ b/onnxruntime/core/providers/cpu/math/matmul.h @@ -31,8 +31,10 @@ class MatMul final : public OpKernel { trans_batch_b_ = trans_batch_b_attr != 0; - + #if defined(__aarch64__) && defined(__linux__) +#if !defined(USE_VSINPU) auto config_ops = info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16); @@ -195,10 +181,10 @@ index b9bbe36583..2f570502d2 100644 +#endif #endif } - + @@ -57,12 +59,14 @@ class MatMul final : public OpKernel { bool trans_batch_b_; - + #if defined(__aarch64__) && defined(__linux__) +#if !defined(USE_VSINPU) // fastmath mode state @@ -209,7 +195,7 @@ index b9bbe36583..2f570502d2 100644 #endif +#endif }; - + } // namespace onnxruntime diff --git a/onnxruntime/test/mlas/unittest/test_sbgemm.cpp b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp index f85fe97776..6039b7fa9e 100644 @@ -217,12 +203,12 @@ index f85fe97776..6039b7fa9e 100644 +++ b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp @@ -16,6 +16,7 @@ Abstract: --*/ - + #if defined(__aarch64__) && defined(__linux__) +#if !defined(USE_VSINPU) - + #include "test_sbgemm.h" - + @@ -138,4 +139,5 @@ static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_exe } return SBGemmRegistLongExecute() > 0; @@ -235,15 +221,15 @@ index 13701e2e3d..7e432f53c2 100644 +++ b/onnxruntime/test/mlas/unittest/test_sbgemm.h @@ -16,6 +16,7 @@ Abstract: --*/ - + #if defined(__aarch64__) && defined(__linux__) +#if !defined(USE_VSINPU) - + #pragma once - + @@ -278,4 +279,5 @@ class MlasSBGemmTest : public MlasTestBase { } }; - + +#endif #endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc index bbf8255ac2940..db8a87d9eaf24 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc +++ b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc @@ -34,7 +34,8 @@ namespace onnxruntime { namespace vsi { namespace npu { -GraphEP::GraphEP(const onnxruntime::GraphViewer& graph_viewer) : graph_viewer_(graph_viewer) { +GraphEP::GraphEP(const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger) + : graph_viewer_(graph_viewer), logger_(logger) { Prepare(); context_ = tim::vx::Context::Create(); graph_ = context_->CreateGraph(); @@ -42,7 +43,7 @@ GraphEP::GraphEP(const onnxruntime::GraphViewer& graph_viewer) : graph_viewer_(g } bool GraphEP::Prepare() { - std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_); + std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_, logger_); for (const auto& node_unit : node_unit_holder_) { auto quant_op_type = util::GetQuantizedOpType(*node_unit); diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h index 49344770d060e..5bb332fad0177 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h +++ b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h @@ -51,7 +51,7 @@ struct NodeIOInfo { class GraphEP { public: - explicit GraphEP(const GraphViewer& graph_viewer); + explicit GraphEP(const GraphViewer& graph_viewer, const logging::Logger& logger); ~GraphEP() {} bool Prepare(); @@ -104,6 +104,7 @@ class GraphEP { // In the form of {input_name, [NodeUnit(s) using the input]} std::unordered_map> all_quantized_op_inputs_; const GraphViewer& graph_viewer_; + const logging::Logger& logger_; // Holder for the NodeUnits in the graph, this will guarantee the NodeUnits is // valid throughout the lifetime of the ModelBuilder diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc index 669c702544de8..7da7cc6cb63ba 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc +++ b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc @@ -62,6 +62,7 @@ VSINPUExecutionProvider::~VSINPUExecutionProvider() {} std::vector> VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/) const { + const auto& logger = *GetLogger(); std::vector> result; if (graph_viewer.IsSubgraph()) { @@ -82,7 +83,7 @@ VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie // Get all the NodeUnits in the graph_viewer std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); // This holds the result of whether a NodeUnit is supported or not, // to prevent nodes in a NodeUnit to be checked for multiple times @@ -174,7 +175,8 @@ VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie } Status ComputeStateFunc(vsi::npu::GraphEP* graph_ep, - OrtKernelContext* context) { + OrtKernelContext* context, + const logging::Logger& logger) { Ort::KernelContext ctx(context); size_t num_in = ctx.GetInputCount(); const size_t num_inputs = graph_ep->GetGraphInputs().size(); @@ -192,7 +194,7 @@ Status ComputeStateFunc(vsi::npu::GraphEP* graph_ep, } if (!graph_ep->GetGraph()->Run()) { - LOGS_DEFAULT(ERROR) << "Failed to run graph."; + LOGS(logger, ERROR) << "Failed to run graph."; } for (size_t i = 0; i < ctx.GetOutputCount(); i++) { auto timvx_tensor = graph_ep->GetGraphOutputs()[i]->tensor; @@ -207,12 +209,14 @@ Status ComputeStateFunc(vsi::npu::GraphEP* graph_ep, Status VSINPUExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { + const auto& logger = *GetLogger(); + for (const auto& fused_node_graph : fused_nodes_and_graphs) { const GraphViewer& graph_viewer = fused_node_graph.filtered_graph; - std::shared_ptr graph_ep = std::make_shared(graph_viewer); + std::shared_ptr graph_ep = std::make_shared(graph_viewer, logger); for (auto tensor : graph_viewer.GetInputsIncludingInitializers()) { - LOGS_DEFAULT(VERBOSE) << "subgraph input init:" << vsi::npu::util::PrintNode(*tensor) << "#" + LOGS(logger, VERBOSE) << "subgraph input init:" << vsi::npu::util::PrintNode(*tensor) << "#" << graph_viewer.IsInitializedTensor(tensor->Name()); auto input = std::make_shared(); input->name = tensor->Name(); @@ -220,7 +224,7 @@ Status VSINPUExecutionProvider::Compile(const std::vector& fu graph_ep->GetGraphInputs().push_back(input); } for (auto tensor : graph_viewer.GetOutputs()) { - LOGS_DEFAULT(VERBOSE) << "subgraph output:" << vsi::npu::util::PrintNode(*tensor); + LOGS(logger, VERBOSE) << "subgraph output:" << vsi::npu::util::PrintNode(*tensor); auto output = std::make_shared(); output->name = tensor->Name(); output->is_initializer = false; @@ -236,16 +240,16 @@ Status VSINPUExecutionProvider::Compile(const std::vector& fu if (node != &node_unit.GetNode()) { continue; } - LOGS_DEFAULT(VERBOSE) << "Adding node: [" << node->OpType() << "]"; + LOGS(logger, VERBOSE) << "Adding node: [" << node->OpType() << "]"; vsi::npu::SupportedBuiltinOps().at(node->OpType())->BuildOp(graph_ep.get(), graph_viewer, node_unit); } - LOGS_DEFAULT(INFO) << "Verifying graph"; + LOGS(logger, INFO) << "Verifying graph"; graph_ep->GetCompiled() = graph_ep->GetGraph()->Compile(); if (!graph_ep->GetCompiled()) { - LOGS_DEFAULT(ERROR) << "Failed to verify graph."; + LOGS(logger, ERROR) << "Failed to verify graph."; } else { - LOGS_DEFAULT(INFO) << "Graph has been verified successfully."; + LOGS(logger, INFO) << "Graph has been verified successfully."; } NodeComputeInfo compute_info; @@ -259,7 +263,7 @@ Status VSINPUExecutionProvider::Compile(const std::vector& fu [graph_ep, this](FunctionState /*state*/, const OrtApi* /* api */, OrtKernelContext* context) { std::lock_guard lock(this->GetMutex()); - Status res = ComputeStateFunc(graph_ep.get(), context); + Status res = ComputeStateFunc(graph_ep.get(), context, *GetLogger()); return res; }; diff --git a/onnxruntime/core/providers/webgpu/allocator.cc b/onnxruntime/core/providers/webgpu/allocator.cc new file mode 100644 index 0000000000000..8e27acdc285d4 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/allocator.cc @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include "core/framework/session_state.h" +#include "core/providers/webgpu/allocator.h" +#include "core/providers/webgpu/webgpu_context.h" + +namespace onnxruntime { +namespace webgpu { + +void* GpuBufferAllocator::Alloc(size_t size) { + if (size == 0) { + return nullptr; + } + + auto buffer = context_.BufferManager().Create(size); + + stats_.num_allocs++; + return buffer; +} + +void GpuBufferAllocator::Free(void* p) { + if (p != nullptr) { + context_.BufferManager().Release(static_cast(p)); + stats_.num_allocs--; + } +} + +void GpuBufferAllocator::GetStats(AllocatorStats* stats) { + *stats = stats_; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/allocator.h b/onnxruntime/core/providers/webgpu/allocator.h new file mode 100644 index 0000000000000..51ca65a8b4822 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/allocator.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/allocator.h" +#include "core/framework/ortdevice.h" + +namespace onnxruntime { +namespace webgpu { + +class WebGpuContext; + +class GpuBufferAllocator : public IAllocator { + public: + GpuBufferAllocator(const WebGpuContext& context) + : IAllocator( + OrtMemoryInfo(WEBGPU_BUFFER, OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), + 0, OrtMemTypeDefault)), + context_{context} { + } + + virtual void* Alloc(size_t size) override; + virtual void Free(void* p) override; + void GetStats(AllocatorStats* stats) override; + + private: + AllocatorStats stats_; + const WebGpuContext& context_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.cc b/onnxruntime/core/providers/webgpu/buffer_manager.cc new file mode 100644 index 0000000000000..45eb123943de9 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/buffer_manager.cc @@ -0,0 +1,361 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/buffer_manager.h" +#include "core/providers/webgpu/webgpu_context.h" + +namespace onnxruntime { +namespace webgpu { + +constexpr size_t NormalizeBufferSize(size_t size) { + return (size + 15) / 16 * 16; +} + +class DisabledCacheManager : public IBufferCacheManager { + size_t CalculateBufferSize(size_t request_size) override { + return NormalizeBufferSize(request_size); + } + + WGPUBuffer TryAcquireCachedBuffer(size_t /*buffer_size*/) override { + // always return empty buffer + return nullptr; + } + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { + // no-op + } + void ReleaseBuffer(WGPUBuffer buffer) override { + wgpuBufferRelease(buffer); + } + + void OnRefresh() override { + // no-op + } +}; + +class LazyReleaseCacheManager : public IBufferCacheManager { + size_t CalculateBufferSize(size_t request_size) override { + return NormalizeBufferSize(request_size); + } + + WGPUBuffer TryAcquireCachedBuffer(size_t /*buffer_size*/) override { + return nullptr; + } + + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { + // no-op + } + + void ReleaseBuffer(WGPUBuffer buffer) override { + pending_buffers_.emplace_back(buffer); + } + + void OnRefresh() override { + for (auto& buffer : pending_buffers_) { + wgpuBufferRelease(buffer); + } + pending_buffers_.clear(); + } + + std::vector pending_buffers_; +}; + +class SimpleCacheManager : public IBufferCacheManager { + size_t CalculateBufferSize(size_t request_size) override { + return NormalizeBufferSize(request_size); + } + + WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) override { + auto it = buffers_.find(buffer_size); + if (it != buffers_.end() && !it->second.empty()) { + auto buffer = it->second.back(); + it->second.pop_back(); + return buffer; + } + + return nullptr; + } + + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { + // no-op + } + + void ReleaseBuffer(WGPUBuffer buffer) override { + pending_buffers_.emplace_back(buffer); + } + + void OnRefresh() override { + for (auto& buffer : pending_buffers_) { + buffers_[wgpuBufferGetSize(buffer)].push_back(buffer); + } + pending_buffers_.clear(); + } + + std::map> buffers_; + std::vector pending_buffers_; +}; + +// TODO: maybe use different bucket size for storage and uniform buffers? +constexpr std::initializer_list> BUCKET_DEFAULT_LIMIT_TABLE = { + {64, 250}, + {128, 200}, + {256, 200}, + {512, 200}, + {2048, 230}, + {4096, 200}, + {8192, 50}, + {16384, 50}, + {32768, 50}, + {65536, 50}, + {131072, 50}, + {262144, 50}, + {524288, 50}, + {1048576, 50}, + {2097152, 30}, + {4194304, 20}, + {8388608, 10}, + {12582912, 10}, + {16777216, 10}, + {26214400, 15}, + {33554432, 22}, + {44236800, 2}, + {58982400, 6}, + // we don't want to cache the bucket sizes below but not caching them + // results in some major performance hits for models like sd-turbo. + {67108864, 6}, + {134217728, 6}, + {167772160, 6}, +}; + +class BucketCacheManager : public IBufferCacheManager { + public: + BucketCacheManager() : buckets_limit_{BUCKET_DEFAULT_LIMIT_TABLE} { + Initialize(); + } + BucketCacheManager(std::unordered_map&& buckets_limit) : buckets_limit_{buckets_limit} { + Initialize(); + } + + size_t CalculateBufferSize(size_t request_size) override { + // binary serch size + auto it = std::lower_bound(buckets_keys_.begin(), buckets_keys_.end(), request_size); + if (it == buckets_keys_.end()) { + return NormalizeBufferSize(request_size); + } else { + return *it; + } + } + + WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) override { + auto it = buckets_.find(buffer_size); + if (it != buckets_.end() && !it->second.empty()) { + auto buffer = it->second.back(); + it->second.pop_back(); + return buffer; + } + return nullptr; + } + + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { + // no-op + } + + void ReleaseBuffer(WGPUBuffer buffer) override { + pending_buffers_.emplace_back(buffer); + } + + void OnRefresh() override { + // TODO: consider graph capture. currently not supported + + for (auto& buffer : pending_buffers_) { + auto buffer_size = wgpuBufferGetSize(buffer); + + auto it = buckets_.find(buffer_size); + if (it != buckets_.end() && it->second.size() < buckets_limit_[buffer_size]) { + it->second.push_back(buffer); + } else { + wgpuBufferRelease(buffer); + } + } + + pending_buffers_.clear(); + } + + protected: + void Initialize() { + buckets_keys_.reserve(buckets_limit_.size()); + buckets_.reserve(buckets_limit_.size()); + for (const auto& pair : buckets_limit_) { + buckets_keys_.push_back(pair.first); + buckets_.emplace(pair.first, std::vector()); + } + std::sort(buckets_keys_.begin(), buckets_keys_.end()); + +#ifndef NDEBUG // if debug build + ORT_ENFORCE(std::all_of(buckets_keys_.begin(), buckets_keys_.end(), [](size_t size) { return size % 16 == 0; }), + "Bucket sizes must be multiples of 16."); + + for (size_t i = 1; i < buckets_keys_.size(); ++i) { + ORT_ENFORCE(buckets_keys_[i] > buckets_keys_[i - 1], "Bucket sizes must be in increasing order."); + } +#endif + } + std::unordered_map buckets_limit_; + std::unordered_map> buckets_; + std::vector pending_buffers_; + std::vector buckets_keys_; +}; + +std::unique_ptr CreateBufferCacheManager(BufferCacheMode cache_mode) { + switch (cache_mode) { + case BufferCacheMode::Disabled: + return std::make_unique(); + case BufferCacheMode::LazyRelease: + return std::make_unique(); + case BufferCacheMode::Simple: + return std::make_unique(); + case BufferCacheMode::Bucket: + return std::make_unique(); + default: + ORT_NOT_IMPLEMENTED("Unsupported buffer cache mode"); + } +} + +std::ostream& operator<<(std::ostream& os, BufferCacheMode mode) { + switch (mode) { + case BufferCacheMode::Disabled: + os << "Disabled"; + break; + case BufferCacheMode::LazyRelease: + os << "LazyRelease"; + break; + case BufferCacheMode::Simple: + os << "Simple"; + break; + case BufferCacheMode::Bucket: + os << "Bucket"; + break; + default: + os << "Unknown(" << static_cast(mode) << ")"; + } + return os; +} + +BufferManager::BufferManager(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode) + : context_{context}, + storage_cache_{CreateBufferCacheManager(storage_buffer_cache_mode)}, + uniform_cache_{CreateBufferCacheManager(uniform_buffer_cache_mode)}, + query_resolve_cache_{CreateBufferCacheManager(query_resolve_buffer_cache_mode)}, + default_cache_{CreateBufferCacheManager(BufferCacheMode::Disabled)} { +} + +void BufferManager::Upload(void* src, WGPUBuffer dst, size_t size) { + auto buffer_size = NormalizeBufferSize(size); + + wgpu::BufferDescriptor desc{}; + desc.size = buffer_size; + desc.usage = wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite; + desc.mappedAtCreation = true; + + auto staging_buffer = context_.Device().CreateBuffer(&desc); + auto mapped_data = staging_buffer.GetMappedRange(); + memcpy(mapped_data, src, size); + staging_buffer.Unmap(); + + auto& command_encoder = context_.GetCommandEncoder(); + context_.EndComputePass(); + command_encoder.CopyBufferToBuffer(staging_buffer, 0, dst, 0, buffer_size); + pending_staging_buffers_.push_back(staging_buffer); +} + +void BufferManager::MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) { + ORT_ENFORCE(src != dst, "Source and destination buffers must be different."); + + auto buffer_size = NormalizeBufferSize(size); + ORT_ENFORCE(buffer_size <= wgpuBufferGetSize(src) && buffer_size <= wgpuBufferGetSize(dst), + "Source and destination buffers must have enough space for the copy operation. src_size=", + wgpuBufferGetSize(src), ", dst_size=", wgpuBufferGetSize(dst), ", copy_size=", buffer_size, "."); + + auto& command_encoder = context_.GetCommandEncoder(); + context_.EndComputePass(); + command_encoder.CopyBufferToBuffer(src, 0, dst, 0, buffer_size); +} + +WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) { + auto& cache = GetCacheManager(static_cast(usage)); + auto buffer_size = cache.CalculateBufferSize(size); + + auto buffer = cache.TryAcquireCachedBuffer(buffer_size); + if (buffer) { + return buffer; + } + + // cache miss, create a new buffer + wgpu::BufferDescriptor desc{}; + desc.size = buffer_size; + desc.usage = usage; + // desc.label = std::to_string(xx++).c_str(); + buffer = context_.Device().CreateBuffer(&desc).MoveToCHandle(); + + ORT_ENFORCE(buffer, "Failed to create GPU buffer: size=", buffer_size, ", usage=", uint64_t(usage), "."); + + cache.RegisterBuffer(buffer, size); + return buffer; +} + +void BufferManager::Release(WGPUBuffer buffer) { + GetCacheManager(buffer).ReleaseBuffer(buffer); +} + +void BufferManager::Download(WGPUBuffer src, void* dst, size_t size) { + auto buffer_size = NormalizeBufferSize(size); + + wgpu::BufferDescriptor desc{}; + desc.size = buffer_size; + desc.usage = wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead; + + auto staging_buffer = context_.Device().CreateBuffer(&desc); + auto& command_encoder = context_.GetCommandEncoder(); + context_.EndComputePass(); + command_encoder.CopyBufferToBuffer(src, 0, staging_buffer, 0, buffer_size); + context_.Flush(); + + // TODO: revise wait in whole project + + ORT_ENFORCE(context_.Wait(staging_buffer.MapAsync(wgpu::MapMode::Read, 0, buffer_size, wgpu::CallbackMode::WaitAnyOnly, [](wgpu::MapAsyncStatus status, const char* message) { + ORT_ENFORCE(status == wgpu::MapAsyncStatus::Success, "Failed to download data from buffer: ", message); + })) == Status::OK()); + + auto mapped_data = staging_buffer.GetConstMappedRange(); + memcpy(dst, mapped_data, size); +} + +void BufferManager::RefreshPendingBuffers() { + pending_staging_buffers_.clear(); + storage_cache_->OnRefresh(); + uniform_cache_->OnRefresh(); + query_resolve_cache_->OnRefresh(); + default_cache_->OnRefresh(); +} + +IBufferCacheManager& BufferManager::GetCacheManager(WGPUBufferUsage usage) const { + if (usage & WGPUBufferUsage_Storage) { + return *storage_cache_; + } else if (usage & WGPUBufferUsage_Uniform) { + return *uniform_cache_; + } else if (usage & WGPUBufferUsage_QueryResolve) { + return *query_resolve_cache_; + } else { + return *default_cache_; + } +} + +IBufferCacheManager& BufferManager::GetCacheManager(WGPUBuffer buffer) const { + return GetCacheManager(wgpuBufferGetUsage(buffer)); +} + +std::unique_ptr BufferManagerFactory::Create(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode) { + return std::make_unique(context, storage_buffer_cache_mode, uniform_buffer_cache_mode, query_resolve_buffer_cache_mode); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.h b/onnxruntime/core/providers/webgpu/buffer_manager.h new file mode 100644 index 0000000000000..00febfbc29f1b --- /dev/null +++ b/onnxruntime/core/providers/webgpu/buffer_manager.h @@ -0,0 +1,95 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "core/framework/execution_provider.h" + +namespace onnxruntime { +namespace webgpu { + +class WebGpuContext; + +enum class BufferCacheMode { + Disabled, + LazyRelease, + Simple, + Bucket +}; +std::ostream& operator<<(std::ostream& os, BufferCacheMode mode); + +// +// IBufferCacheManager is an interface for buffer cache management. +// +// By implementing this interface, we can have different buffer cache management strategies. +// Currently, we have 3 strategies: +// - Disabled: no cache. always allocate a new buffer and release it immediately after use. +// - LazyRelease: no cache. the difference from Disabled is that it delays the release of buffers until the next refresh. +// - Simple: a simple cache that always keeps buffers. when a buffer is requested, it tries to find a buffer in the cache. +// - Bucket: a cache that keeps buffers in different buckets based on the buffer size, with a maximum number of buffers in each bucket. +// +class IBufferCacheManager { + public: + virtual ~IBufferCacheManager() = default; + + // calculate actual buffer size to allocate based on the requested size. + virtual size_t CalculateBufferSize(size_t request_size) = 0; + + // return a buffer if available in cache. otherwise empty. + virtual WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) = 0; + + // register a newly created buffer + virtual void RegisterBuffer(WGPUBuffer buffer, size_t request_size) = 0; + + // release a buffer + virtual void ReleaseBuffer(WGPUBuffer buffer) = 0; + + // when a stream refresh is requested + virtual void OnRefresh() = 0; +}; + +// +// BufferManager manages operations on buffers. +// +class BufferManager { + public: + BufferManager(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode); + + void Upload(void* src, WGPUBuffer dst, size_t size); + void MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size); + WGPUBuffer Create(size_t size, wgpu::BufferUsage usage = wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst); + void Release(WGPUBuffer buffer); + void Download(WGPUBuffer src, void* dst, size_t size); + void RefreshPendingBuffers(); + + private: + IBufferCacheManager& GetCacheManager(WGPUBufferUsage usage) const; + IBufferCacheManager& GetCacheManager(WGPUBuffer buffer) const; + + WebGpuContext& context_; + std::unique_ptr storage_cache_; + std::unique_ptr uniform_cache_; + std::unique_ptr query_resolve_cache_; + std::unique_ptr default_cache_; + + std::vector pending_staging_buffers_; +}; + +class BufferManagerFactory { + public: + static std::unique_ptr Create(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode); + + private: + BufferManagerFactory() {} +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/compute_context.cc b/onnxruntime/core/providers/webgpu/compute_context.cc new file mode 100644 index 0000000000000..ce4f3e49611e2 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/compute_context.cc @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/op_kernel.h" + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/webgpu_context.h" + +namespace onnxruntime { +namespace webgpu { +ComputeContext::ComputeContext(OpKernelContext& kernel_context) + : webgpu_context_{WebGpuContextFactory::GetContext(kernel_context.GetDeviceId())}, + kernel_context_{kernel_context} { +} + +void ComputeContext::PushErrorScope() { + if (webgpu_context_.ValidationMode() >= ValidationMode::Basic) { + webgpu_context_.Device().PushErrorScope(wgpu::ErrorFilter::Validation); + } +} + +Status ComputeContext::PopErrorScope() { + Status status{}; + + if (webgpu_context_.ValidationMode() >= ValidationMode::Basic) { + ORT_RETURN_IF_ERROR(webgpu_context_.Wait( + webgpu_context_.Device().PopErrorScope( + wgpu::CallbackMode::WaitAnyOnly, [](wgpu::PopErrorScopeStatus pop_status, wgpu::ErrorType error_type, char const* message, Status* status) { + ORT_ENFORCE(pop_status == wgpu::PopErrorScopeStatus::Success, "Instance dropped."); + if (error_type == wgpu::ErrorType::NoError) { + return; + } + *status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "WebGPU validation failed. ", message); + }, + &status))); + } + return status; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h new file mode 100644 index 0000000000000..b7ea8a58e232b --- /dev/null +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -0,0 +1,136 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include + +#include "core/framework/execution_provider.h" + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_context.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { + +class Tensor; + +namespace webgpu { + +class WebGpuContext; + +class ComputeContext { + public: + ComputeContext(OpKernelContext& kernel_context); + + virtual ~ComputeContext() = default; + + // + // Get various information from the context. + // + + inline const wgpu::AdapterInfo& AdapterInfo() const { + return webgpu_context_.AdapterInfo(); + } + inline const wgpu::Limits& DeviceLimits() const { + return webgpu_context_.DeviceLimits(); + } + + // + // Get the kernel context. + // + inline OpKernelContext& KernelContext() { + return kernel_context_; + } + + // + // Get the logger. + // + inline const logging::Logger& Logger() const { + return kernel_context_.Logger(); + } + + // + // Get input tensor. + // + template + inline const T* Input(int index) const { + return kernel_context_.Input(index); + } + + // + // Get input count. + // + inline int InputCount() const { + return kernel_context_.InputCount(); + } + + // + // Set output tensor. + // + template + inline Tensor* Output(int index, TensorShapeType&& shape) { + return kernel_context_.Output(index, std::forward(shape)); + } + + // + // Get output count. + // + inline int OutputCount() const { + return kernel_context_.OutputCount(); + } + + // + // Create CPU tensor. + // + template + Tensor CreateCPUTensor(MLDataType data_type, TensorShapeType&& shape) { + AllocatorPtr allocator; + ORT_THROW_IF_ERROR(kernel_context_.GetTempSpaceCPUAllocator(&allocator)); + return {data_type, std::forward(shape), allocator}; + } + + // + // Create GPU tensor. + // + template + Tensor CreateGPUTensor(MLDataType data_type, TensorShapeType&& shape) { + AllocatorPtr allocator; + ORT_THROW_IF_ERROR(kernel_context_.GetTempSpaceAllocator(&allocator)); + return {data_type, std::forward(shape), allocator}; + } + + // + // Run a compute shader program. + // + inline Status RunProgram(const ProgramBase& program) { + return webgpu_context_.Run(*this, program); + } + + // + // Push error scope. + // + // This is useful only when "skip_validation" is not set. + // + void PushErrorScope(); + + // + // Pop error scope. + // + // This is useful only when "skip_validation" is not set. + // + Status PopErrorScope(); + + protected: + WebGpuContext& webgpu_context_; + OpKernelContext& kernel_context_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/data_transfer.cc b/onnxruntime/core/providers/webgpu/data_transfer.cc new file mode 100644 index 0000000000000..615ae11175782 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/data_transfer.cc @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include "core/providers/webgpu/data_transfer.h" +#include "core/providers/webgpu/webgpu_context.h" + +namespace onnxruntime { +namespace webgpu { + +bool DataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { + return (dst_device.Type() == OrtDevice::GPU && src_device.Type() == OrtDevice::CPU) || + (dst_device.Type() == OrtDevice::GPU && src_device.Type() == OrtDevice::GPU) || + (dst_device.Type() == OrtDevice::CPU && src_device.Type() == OrtDevice::GPU); +} + +common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { + size_t bytes = src.SizeInBytes(); + if (bytes > 0) { + void const* src_data = src.DataRaw(); + void* dst_data = dst.MutableDataRaw(); + + auto& src_device = src.Location().device; + auto& dst_device = dst.Location().device; + + if (dst_device.Type() == OrtDevice::GPU) { + if (src_device.Type() == OrtDevice::GPU) { + // copy from GPU to GPU + context_.BufferManager().MemCpy(static_cast(const_cast(src_data)), + static_cast(dst_data), bytes); + } else { + // copy from CPU to GPU + context_.BufferManager().Upload(const_cast(src_data), static_cast(dst_data), bytes); + } + } else /* if (src_device.Type() == OrtDevice::GPU) */ { + // copy from GPU to CPU + context_.BufferManager().Download(static_cast(const_cast(src_data)), dst_data, bytes); + } + } + + return Status::OK(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/data_transfer.h b/onnxruntime/core/providers/webgpu/data_transfer.h new file mode 100644 index 0000000000000..f9949576aa60b --- /dev/null +++ b/onnxruntime/core/providers/webgpu/data_transfer.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/data_transfer.h" +#include "core/framework/execution_provider.h" + +namespace onnxruntime { +namespace webgpu { + +class WebGpuContext; + +class DataTransfer : public IDataTransfer { + public: + DataTransfer(const WebGpuContext& context) : context_{context} {}; + ~DataTransfer() {}; + + bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; + + common::Status CopyTensor(const Tensor& src, Tensor& dst) const override; + + private: + const WebGpuContext& context_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/generator/range.cc b/onnxruntime/core/providers/webgpu/generator/range.cc new file mode 100644 index 0000000000000..ee7c67ec24185 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/generator/range.cc @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/generator/range.h" +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace webgpu { + +template +Status Range::ComputeInternal(ComputeContext& context) const { + T start = context.Input(0)->Data()[0]; + T limit = context.Input(1)->Data()[0]; + T delta = context.Input(2)->Data()[0]; + + GSL_SUPPRESS(io.2) // Ignore warning about potential overflow in (limit - start) + int64_t n = static_cast(ceil((1.0 * (limit - start)) / delta)); + if (n <= 0) { + n = 0; + } + auto* output_tensor = context.Output(0, TensorShape{n}); + if (n == 0) { + return Status::OK(); + } + + uint32_t output_size = gsl::narrow(n); + RangeProgram program{}; + program.AddOutput({output_tensor, ProgramTensorMetadataDependency::Type}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + output_size, + *reinterpret_cast(&start), + *reinterpret_cast(&delta), + }); + + return context.RunProgram(program); +} + +Status RangeProgram::GenerateShaderCode(ShaderHelper& sh) const { + const auto& output = sh.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + + sh.MainFunctionBody() << sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << " let value = bitcast(uniforms.start) + output_value_t(global_idx) * bitcast(uniforms.delta);\n" + << output.SetByOffset("global_idx", "value"); + + return Status(); +} + +#define WEBGPU_RANGE_KERNEL(TYPE) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Range, \ + kOnnxDomain, \ + 11, \ + TYPE, \ + kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPU, 0) \ + .InputMemoryType(OrtMemTypeCPU, 1) \ + .InputMemoryType(OrtMemTypeCPU, 2), \ + Range); + +WEBGPU_RANGE_KERNEL(float) +WEBGPU_RANGE_KERNEL(int32_t) + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/generator/range.h b/onnxruntime/core/providers/webgpu/generator/range.h new file mode 100644 index 0000000000000..2f5812bb460ad --- /dev/null +++ b/onnxruntime/core/providers/webgpu/generator/range.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +template +class Range : public WebGpuKernel { + public: + explicit Range(const OpKernelInfo& info) : WebGpuKernel(info) {} + + Status ComputeInternal(ComputeContext& context) const override; +}; + +class RangeProgram : public Program { + public: + RangeProgram() : Program{"Range"} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}, + {"start", ProgramUniformVariableDataType::Uint32}, + {"delta", ProgramUniformVariableDataType::Uint32}); +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc new file mode 100644 index 0000000000000..7f7a5707afa0a --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc @@ -0,0 +1,310 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/webgpu/math/binary_elementwise_ops.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { +Status BinaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& c = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"); + + // check whether can use element-wise mode. + // If either A or B is scalar, or A and B have the same shape, element-wise mode can be used. + // In element-wise mode, no indices calculation is needed. + if (is_lhs_scalar_ || is_rhs_scalar_ || !is_broadcast_) { + // get A data + if (is_lhs_scalar_) { + shader.MainFunctionBody() << "let a = input_a_value_t(" << a.GetByOffset("0") << ".x);\n"; + } else { + shader.MainFunctionBody() << "let a = " << a.GetByOffset("global_idx") << ";\n"; + } + + // get B data + if (is_rhs_scalar_) { + shader.MainFunctionBody() << "let b = input_b_value_t(" << b.GetByOffset("0") << ".x);\n"; + } else { + shader.MainFunctionBody() << "let b = " << b.GetByOffset("global_idx") << ";\n"; + } + } else { + const auto& c_indices = shader.AddIndices("bcast_indices"); + // check whether can use vectorize mode. + // If either last dimension of A or B is divisible by 4, or the shared dimension is divisible by 4, vectorize mode + // can be enabled. + // In vectorize mode, the source data of A and B will be loaded only once to calculate 4 output values. + // Use indices helpers to calculate the offset of A and B. + if (vectorize_) { + const auto& a_indices = shader.AddIndices("a_indices"); + const auto& b_indices = shader.AddIndices("b_indices"); + + shader.MainFunctionBody() << "let outputIndices = " << c_indices.OffsetToIndices("global_idx * 4") << ";\n" + << "let offset_a = " << a_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "let offset_b = " << b_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"; + // get A data + if (a.NumComponents() == 4) { + shader.MainFunctionBody() << "let a = " << a.GetByOffset("offset_a / 4") << ";\n"; + } else { + shader.MainFunctionBody() << "let a = input_a_value_t(" << a.GetByOffset("offset_a") << ");\n"; + } + + // get B data + if (b.NumComponents() == 4) { + shader.MainFunctionBody() << "let b = " << b.GetByOffset("offset_b / 4") << ";\n"; + } else { + shader.MainFunctionBody() << "let b = input_b_value_t(" << b.GetByOffset("offset_b") << ");\n"; + } + } else { + // In broadcast mode, each element of the vec4 value of A and B will be loaded separately to calculate the output value. + shader.MainFunctionBody() << "var outputIndices = " << c_indices.OffsetToIndices("global_idx * 4") << ";\n" + << "let offset_a0 = " << a.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "let offset_b0 = " << b.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "outputIndices = " << c_indices.OffsetToIndices("global_idx * 4 + 1") << ";\n" + << "let offset_a1 = " << a.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "let offset_b1 = " << b.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "outputIndices = " << c_indices.OffsetToIndices("global_idx * 4 + 2") << ";\n" + << "let offset_a2 = " << a.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "let offset_b2 = " << b.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "outputIndices = " << c_indices.OffsetToIndices("global_idx * 4 + 3") << ";\n" + << "let offset_a3 = " << a.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "let offset_b3 = " << b.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"; + + // get A data + shader.MainFunctionBody() << "let a = vec4(" + << a.GetByOffset("offset_a0") << ", " + << a.GetByOffset("offset_a1") << ", " + << a.GetByOffset("offset_a2") << ", " + << a.GetByOffset("offset_a3") << ");\n"; + // get B data + shader.MainFunctionBody() << "let b = vec4(" + << b.GetByOffset("offset_b0") << ", " + << b.GetByOffset("offset_b1") << ", " + << b.GetByOffset("offset_b2") << ", " + << b.GetByOffset("offset_b3") << ");\n"; + } + } + + shader.MainFunctionBody() << c.SetByOffset("global_idx", expression_); + return Status::OK(); +} + +Status BinaryElementwise::ComputeInternal(ComputeContext& context) const { + auto lhs_tensor = context.Input(0); + auto rhs_tensor = context.Input(1); + const auto& lhs_shape = lhs_tensor->Shape(); + const auto& rhs_shape = rhs_tensor->Shape(); + + TensorShape output_shape; + ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), lhs_shape, rhs_shape, output_shape)); + auto output_tensor = context.Output(0, output_shape); + int64_t size = output_shape.Size(); + if (size == 0) { + return Status::OK(); + } + + bool is_broadcast = lhs_shape != rhs_shape; + bool is_lhs_scalar = lhs_shape.IsScalar(); + bool is_rhs_scalar = rhs_shape.IsScalar(); + + bool vectorize = is_lhs_scalar || is_rhs_scalar || !is_broadcast; + bool a_last_dim_divisible_by_4 = false; + bool b_last_dim_divisible_by_4 = false; + bool shared_dimension_divisible_by_4 = false; + size_t num_shared_dimension = 0; + if (!vectorize) { + // check whether vectorize can be enabled + a_last_dim_divisible_by_4 = lhs_shape.NumDimensions() > 0 && lhs_shape[lhs_shape.NumDimensions() - 1] % 4 == 0; + b_last_dim_divisible_by_4 = rhs_shape.NumDimensions() > 0 && rhs_shape[rhs_shape.NumDimensions() - 1] % 4 == 0; + if (a_last_dim_divisible_by_4 || b_last_dim_divisible_by_4) { + vectorize = true; + } else { + size_t shared_dimension = 1; + for (size_t i = 1; i < output_shape.NumDimensions(); i++) { + size_t dimA = lhs_shape.NumDimensions() >= i ? lhs_shape[lhs_shape.NumDimensions() - i] : 1; + size_t dimB = rhs_shape.NumDimensions() >= i ? rhs_shape[rhs_shape.NumDimensions() - i] : 1; + if (dimA == dimB) { + shared_dimension *= dimA; + num_shared_dimension++; + } else { + break; + } + } + if (shared_dimension % 4 == 0) { + shared_dimension_divisible_by_4 = true; + vectorize = true; + } + } + } + + uint32_t vec_size = gsl::narrow((size + 3) / 4); + BinaryElementwiseProgram program{kernel_name_, + expression_, + is_broadcast, + is_lhs_scalar, + is_rhs_scalar, + vectorize}; + program + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + {static_cast(vec_size)}, + }) + .AddOutput({output_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}); + + if (is_lhs_scalar || is_rhs_scalar || !is_broadcast) { + // Mode Element-wise + // cache hint: "E{is_a_scalar}{is_b_scalar}" + program + .AddInputs({{lhs_tensor, ProgramTensorMetadataDependency::Type, {is_lhs_scalar ? 1 : vec_size}, 4}, + {rhs_tensor, ProgramTensorMetadataDependency::Type, {is_rhs_scalar ? 1 : vec_size}, 4}}) + .CacheHint("E" + std::to_string(is_lhs_scalar) + std::to_string(is_rhs_scalar)); + } else if (vectorize) { + // reshape the dims to merge the shared dimension if available + bool need_reshape = shared_dimension_divisible_by_4 && num_shared_dimension > 1; + TensorShape reshaped_lhs_shape = need_reshape ? lhs_shape.Slice(0, lhs_shape.NumDimensions() - num_shared_dimension + 1) + : lhs_shape; + TensorShape reshaped_rhs_shape = need_reshape ? rhs_shape.Slice(0, rhs_shape.NumDimensions() - num_shared_dimension + 1) + : rhs_shape; + TensorShape reshaped_output_shape = need_reshape ? output_shape.Slice(0, output_shape.NumDimensions() - num_shared_dimension + 1) + : output_shape; + if (need_reshape) { + reshaped_lhs_shape[reshaped_lhs_shape.NumDimensions() - 1] = lhs_shape.SizeFromDimension(lhs_shape.NumDimensions() - num_shared_dimension); + reshaped_rhs_shape[reshaped_rhs_shape.NumDimensions() - 1] = rhs_shape.SizeFromDimension(rhs_shape.NumDimensions() - num_shared_dimension); + reshaped_output_shape[reshaped_output_shape.NumDimensions() - 1] = output_shape.SizeFromDimension(output_shape.NumDimensions() - num_shared_dimension); + } + + if (shared_dimension_divisible_by_4 || a_last_dim_divisible_by_4) { + program.AddInput({lhs_tensor, ProgramTensorMetadataDependency::Type, {(lhs_shape.Size() + 3) / 4}, 4}); + } else { + program.AddInput({lhs_tensor, ProgramTensorMetadataDependency::Type}); + } + if (shared_dimension_divisible_by_4 || b_last_dim_divisible_by_4) { + program.AddInput({rhs_tensor, ProgramTensorMetadataDependency::Type, {(rhs_shape.Size() + 3) / 4}, 4}); + } else { + program.AddInput({rhs_tensor, ProgramTensorMetadataDependency::Type}); + } + // Mode Vectorize broadcast + // cache hint: "V{a_rank};{b_rank};{output_rank}" + program + .AddIndices(reshaped_output_shape) + .AddIndices(reshaped_lhs_shape) + .AddIndices(reshaped_rhs_shape) + .CacheHint("V" + absl::StrJoin({reshaped_lhs_shape.NumDimensions(), + reshaped_rhs_shape.NumDimensions(), + reshaped_output_shape.NumDimensions()}, + ";")); + } else { + // Mode Broadcast + // cache hint: "B" + program + .AddInputs({{lhs_tensor, ProgramTensorMetadataDependency::TypeAndRank}, + {rhs_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddIndices(output_tensor->Shape()) + .CacheHint("B"); + } + + return context.RunProgram(program); +} + +#define WEBGPU_BINARY_IMPL(OP_TYPE, ...) \ + class OP_TYPE final : public BinaryElementwise { \ + public: \ + OP_TYPE(const OpKernelInfo& info) : BinaryElementwise{info, #OP_TYPE, __VA_ARGS__} {} \ + }; + +#define WEBGPU_BINARY_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION, \ + kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE), \ + KERNEL_CLASS); + +#define WEBGPU_BINARY_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION_FROM, VERSION_TO, \ + kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE), \ + KERNEL_CLASS); + +#define WEBGPU_BINARY_KERNEL_2(OP_TYPE, VERSION, KERNEL_CLASS, TYPE, TYPE1) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION, \ + kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", TYPE) \ + .TypeConstraint("T1", TYPE1), \ + KERNEL_CLASS); + +#define WEBGPU_BINARY_VERSIONED_KERNEL_2(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE, TYPE1) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION_FROM, VERSION_TO, \ + kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", TYPE) \ + .TypeConstraint("T1", TYPE1), \ + KERNEL_CLASS); + +WEBGPU_BINARY_IMPL(Add, "a + b") +WEBGPU_BINARY_VERSIONED_KERNEL(Add, 7, 12, Add, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Add, 13, 13, Add, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Add, 14, Add, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Div, "a / b") +WEBGPU_BINARY_VERSIONED_KERNEL(Div, 7, 12, Div, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Div, 13, 13, Div, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Div, 14, Div, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Mul, "a * b") +WEBGPU_BINARY_VERSIONED_KERNEL(Mul, 7, 12, Mul, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Mul, 13, 13, Mul, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Mul, 14, Mul, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Sub, "a - b") +WEBGPU_BINARY_VERSIONED_KERNEL(Sub, 7, 12, Sub, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Sub, 13, 13, Sub, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Sub, 14, Sub, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Pow, "output_value_t(pow(vec4(a), vec4(b)))") +WEBGPU_BINARY_VERSIONED_KERNEL(Pow, 7, 11, Pow, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL_2(Pow, 12, 12, Pow, WebGpuSupportedNumberTypes(), WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL_2(Pow, 13, 14, Pow, WebGpuSupportedNumberTypes(), WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL_2(Pow, 15, Pow, WebGpuSupportedNumberTypes(), WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Equal, "vec4(a == b)") +WEBGPU_BINARY_VERSIONED_KERNEL(Equal, 7, 10, Equal, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Equal, 11, 12, Equal, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Equal, 13, 18, Equal, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Equal, 19, Equal, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Greater, "vec4(a > b)") +WEBGPU_BINARY_VERSIONED_KERNEL(Greater, 7, 8, Greater, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Greater, 9, 12, Greater, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Greater, 13, Greater, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(Less, "vec4(a < b)") +WEBGPU_BINARY_VERSIONED_KERNEL(Less, 7, 8, Less, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_VERSIONED_KERNEL(Less, 9, 12, Less, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(Less, 13, Less, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(GreaterOrEqual, "vec4(a >= b)") +WEBGPU_BINARY_VERSIONED_KERNEL(GreaterOrEqual, 12, 15, GreaterOrEqual, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(GreaterOrEqual, 16, GreaterOrEqual, WebGpuSupportedNumberTypes()) + +WEBGPU_BINARY_IMPL(LessOrEqual, "vec4(a <= b)") +WEBGPU_BINARY_VERSIONED_KERNEL(LessOrEqual, 12, 15, LessOrEqual, WebGpuSupportedNumberTypes()) +WEBGPU_BINARY_KERNEL(LessOrEqual, 16, LessOrEqual, WebGpuSupportedNumberTypes()) + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h new file mode 100644 index 0000000000000..84cbcdf3244d8 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class BinaryElementwiseProgram final : public Program { + public: + BinaryElementwiseProgram(const std::string& kernel_name, + const std::string& expression, + const bool is_broadcast, + const bool is_lhs_scalar, + const bool is_rhs_scalar, + const bool vectorize) : Program{kernel_name}, + expression_{expression}, + is_broadcast_{is_broadcast}, + is_lhs_scalar_{is_lhs_scalar}, + is_rhs_scalar_{is_rhs_scalar}, + vectorize_{vectorize} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + + private: + std::string expression_; + bool is_broadcast_; + bool is_lhs_scalar_; + bool is_rhs_scalar_; + bool vectorize_; +}; + +class BinaryElementwise : public WebGpuKernel { + public: + BinaryElementwise(const OpKernelInfo& info, + const std::string& kernel_name, + const std::string& expression) : WebGpuKernel{info}, + kernel_name_{kernel_name}, + expression_{expression} {} + + protected: + Status ComputeInternal(ComputeContext& context) const final; + + private: + std::string kernel_name_; + std::string expression_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc new file mode 100644 index 0000000000000..8dcf63671092b --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -0,0 +1,308 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include "core/providers/webgpu/math/unary_elementwise_ops.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { +Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | additional_usage_); + const auto& output = shader.AddOutput("y", ShaderUsage::UseUniform); + shader.AdditionalImplementation() << additional_impl_; + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size") + << " let a = " << input.GetByOffset("global_idx") << ";\n " + << output.SetByOffset("global_idx", expression_); + + return Status::OK(); +} + +Status UnaryElementwise::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + auto* output_tensor = context.Output(0, input_tensor->Shape()); + int64_t size = input_tensor->Shape().Size(); + if (size == 0) { + return Status::OK(); + } + uint32_t vec_size = gsl::narrow((size + 3) / 4); + UnaryElementwiseProgram program{kernel_name_, expression_, additional_impl_, additional_usage_}; + program + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::None, {vec_size}, 4}}) + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + {static_cast(vec_size)}, + }); + if (!cache_hint.empty()) { + program.CacheHint(cache_hint); + } + ORT_RETURN_IF_ERROR(ConfigureProgram(context, program)); + return context.RunProgram(program); +} + +#define WEBGPU_ELEMENTWISE_IMPL(OP_TYPE, ...) \ + class OP_TYPE final : public UnaryElementwise { \ + public: \ + OP_TYPE(const OpKernelInfo& info) : UnaryElementwise{info, #OP_TYPE, __VA_ARGS__} {} \ + }; + +#define WEBGPU_ELEMENTWISE_KERNEL(OP_TYPE_AND_CLASS_NAME, VERSION, TYPE) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE_AND_CLASS_NAME, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE), \ + OP_TYPE_AND_CLASS_NAME); + +#define WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE_AND_CLASS_NAME, VERSION_FROM, VERSION_TO, TYPE) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE_AND_CLASS_NAME, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE), \ + OP_TYPE_AND_CLASS_NAME); + +#define WEBGPU_ELEMENTWISE_BOOLEAN_KERNEL(OP_TYPE_AND_CLASS_NAME, VERSION) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE_AND_CLASS_NAME, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + OP_TYPE_AND_CLASS_NAME); + +// +// math +// + +WEBGPU_ELEMENTWISE_IMPL(Abs, "abs(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Abs, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Abs, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Neg, "-a") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Neg, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Neg, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Floor, "floor(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Floor, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Floor, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Ceil, "ceil(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Ceil, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Ceil, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Reciprocal, "1.0/a") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Reciprocal, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Reciprocal, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Sqrt, "sqrt(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Sqrt, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Sqrt, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Exp, "exp(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Exp, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Erf, "erf_v(a)", ErfImpl, ShaderUsage::UseValueTypeAlias) +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Erf, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Log, "log(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Log, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Sigmoid, "1.0 / (1.0 + exp(-a))") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Sigmoid, 13, WebGpuSupportedFloatTypes()) + +class HardSigmoid final : public UnaryElementwise { + public: + HardSigmoid(const OpKernelInfo& info) + : UnaryElementwise{info, "HardSigmoid", "hard_sigmoid_v(a)", HardSigmoidImpl, ShaderUsage::UseElementTypeAlias} { + // attr[0] is alpha, attr[1] is beta + info.GetAttrOrDefault("alpha", attr, 0.2f); + info.GetAttrOrDefault("beta", attr + 1, 0.5f); + } + + Status ConfigureProgram(const ComputeContext& /*context*/, UnaryElementwiseProgram& program) const override { + program.AddUniformVariables({gsl::make_span(attr, 2)}); + return Status::OK(); + } + + protected: + float attr[2]; +}; + +WEBGPU_ELEMENTWISE_KERNEL(HardSigmoid, 6, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Sin, "sin(a)") +WEBGPU_ELEMENTWISE_KERNEL(Sin, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Cos, "cos(a)") +WEBGPU_ELEMENTWISE_KERNEL(Cos, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Tan, "tan(a)") +WEBGPU_ELEMENTWISE_KERNEL(Tan, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Asin, "asin(a)") +WEBGPU_ELEMENTWISE_KERNEL(Asin, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Acos, "acos(a)") +WEBGPU_ELEMENTWISE_KERNEL(Acos, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Atan, "atan(a)") +WEBGPU_ELEMENTWISE_KERNEL(Atan, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Sinh, "sinh(a)") +WEBGPU_ELEMENTWISE_KERNEL(Sinh, 9, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Cosh, "cosh(a)") +WEBGPU_ELEMENTWISE_KERNEL(Cosh, 9, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Tanh, "tanh_v(a)", TanhImpl, ShaderUsage::UseValueTypeAlias) +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Tanh, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Asinh, "asinh(a)") +WEBGPU_ELEMENTWISE_KERNEL(Asinh, 9, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Acosh, "acosh(a)") +WEBGPU_ELEMENTWISE_KERNEL(Acosh, 9, WebGpuSupportedFloatTypes()) + +#if __APPLE__ +// Metal returns 0 for values >= 1.0. +// Need custom impl to return +inf for 1.0 (by dividing 1 by 0), and NaN for > 1.0 (by dividing 0 by 0) +WEBGPU_ELEMENTWISE_IMPL(Atanh, + "select(" + " select(x_value_t(1.0), x_value_t(0.0), a > x_value_t(1.0)) / x_value_t(0.0)," + " atanh(a)," + " a < x_value_t(1.0))", + "", + ShaderUsage::UseValueTypeAlias) +#else +WEBGPU_ELEMENTWISE_IMPL(Atanh, "atanh(a)") +#endif +WEBGPU_ELEMENTWISE_KERNEL(Atanh, 9, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Not, "!a") +WEBGPU_ELEMENTWISE_BOOLEAN_KERNEL(Not, 1) + +// No longer support Clip < opset 11 (where min and max are attributes) +// +// Use template class for "Clip" because the implementation is significantly different between float16 and float32 +template +class Clip final : public UnaryElementwise { + public: + Clip(const OpKernelInfo& info) + : UnaryElementwise{info, + "Clip", + std::is_same_v ? ClipF16Impl : ClipImpl, + "", ShaderUsage::UseElementTypeAlias} {} + + Status ConfigureProgram(const ComputeContext& context, UnaryElementwiseProgram& program) const override { + const auto* clip_min_tensor = context.Input(1); + const auto* clip_max_tensor = context.Input(2); + + const T attr[] = {clip_min_tensor ? clip_min_tensor->Data()[0] + : std::numeric_limits::lowest(), + clip_max_tensor ? clip_max_tensor->Data()[0] + : std::numeric_limits::max()}; + if constexpr (std::is_same_v) { + // F16: stores span as a single float + float encoded_value = *reinterpret_cast(attr); + program.AddUniformVariable({encoded_value}); + } else { + static_assert(sizeof(T) == sizeof(float), "T must be f32, i32 or u32"); + // stores span as-is + program.AddUniformVariable({gsl::make_span(attr, 2)}); + } + return Status::OK(); + } + + // uniforms.attr is a f32 value. It is encoded as a float for 2 f16 values. + // bitcast>(uniforms.attr)[0] is clip_min, bitcast>(uniforms.attr)[1] is clip_max + constexpr static const char ClipF16Impl[] = "clamp(a, vec4(bitcast>(uniforms.attr)[0]), vec4(bitcast>(uniforms.attr)[1]))"; + + // the size of element of uniforms.attr should be the same as x_element_t. use bitcast to convert between them + // uniforms.attr[0] is clip_min, uniforms.attr[1] is clip_max + constexpr static const char ClipImpl[] = "clamp(a, vec4(bitcast(uniforms.attr[0])), vec4(bitcast(uniforms.attr[1])))"; +}; +#define WEBGPU_CLIP_KERNEL(TYPE) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(Clip, kOnnxDomain, 11, 11, TYPE, kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPU, 1) \ + .InputMemoryType(OrtMemTypeCPU, 2), \ + Clip) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(Clip, kOnnxDomain, 12, 12, TYPE, kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPU, 1) \ + .InputMemoryType(OrtMemTypeCPU, 2), \ + Clip) \ + ONNX_OPERATOR_TYPED_KERNEL_EX(Clip, kOnnxDomain, 13, TYPE, kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPU, 1) \ + .InputMemoryType(OrtMemTypeCPU, 2), \ + Clip); +WEBGPU_CLIP_KERNEL(float) +WEBGPU_CLIP_KERNEL(MLFloat16) + +// +// activation +// + +class LinearUnit : public UnaryElementwise { + public: + LinearUnit(const OpKernelInfo& info, + const std::string& kernel_name, + const std::string& expression, + const std::string& additional_impl, + float default_alpha) + : UnaryElementwise{info, kernel_name, expression, additional_impl, ShaderUsage::UseElementTypeAlias} { + info.GetAttrOrDefault("alpha", &alpha_, default_alpha); + } + + Status ConfigureProgram(const ComputeContext& /*context*/, UnaryElementwiseProgram& program) const override { + program.AddUniformVariables({alpha_}); + return Status::OK(); + } + + protected: + float alpha_; +}; + +#define WEBGPU_LU_IMPL(OP_TYPE, ...) \ + class OP_TYPE final : public LinearUnit { \ + public: \ + OP_TYPE(const OpKernelInfo& info) : LinearUnit{info, #OP_TYPE, __VA_ARGS__} {} \ + }; + +WEBGPU_LU_IMPL(Elu, "elu_v(a)", EluImpl, 1.0) +WEBGPU_ELEMENTWISE_KERNEL(Elu, 6, WebGpuSupportedFloatTypes()) + +class Gelu : public UnaryElementwise { + public: + Gelu(const OpKernelInfo& info) + : UnaryElementwise{info, + "Gelu", + info.GetAttrOrDefault("approximate", "none") == "tanh" ? FastGeluExpr : GeluExpr, + info.GetAttrOrDefault("approximate", "none") == "tanh" ? TanhImpl : ErfImpl, + ShaderUsage::UseValueTypeAlias} { + cache_hint = info.GetAttrOrDefault("approximate", "none"); + } +}; + +WEBGPU_ELEMENTWISE_KERNEL(Gelu, 20, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Relu, "select(x_value_t(0), a, a > x_value_t(0))", "", ShaderUsage::UseValueTypeAlias) +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Relu, 14, WebGpuSupportedFloatTypes()) + +WEBGPU_LU_IMPL(LeakyRelu, "select(x_element_t(uniforms.attr) * a, a, a >= vec4(0))", "", 0.01f) +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(LeakyRelu, 6, 15, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(LeakyRelu, 16, WebGpuSupportedFloatTypes()) + +WEBGPU_LU_IMPL(ThresholdedRelu, "select(vec4(0), a, a > vec4(uniforms.attr))", "", 1.0f) +WEBGPU_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, WebGpuSupportedFloatTypes()) + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h new file mode 100644 index 0000000000000..70fa81d21f95d --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class UnaryElementwiseProgram final : public Program { + public: + UnaryElementwiseProgram(const std::string& kernel_name, std::string_view expression, std::string_view additional_impl, ShaderUsage usage) + : Program{kernel_name}, expression_{expression}, additional_impl_{additional_impl}, additional_usage_{usage} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"vec_size", ProgramUniformVariableDataType::Uint32}, // output size + {"attr", ProgramUniformVariableDataType::Float32}); // float type attribute(s) + // TODO: add u32/i32 attribute(s) if needed + + private: + std::string_view expression_; + std::string_view additional_impl_; + ShaderUsage additional_usage_; +}; + +// TODO: after upgrading to C++20, use consteval to make a compile-time constructor so that it will be safe to switch +// the std::string to std::string_view. This will avoid the cost of copying the string. + +class UnaryElementwise : public WebGpuKernel { + public: + UnaryElementwise(const OpKernelInfo& info, + const std::string& kernel_name, + const std::string& expression, + const std::string& additional_impl = "", + ShaderUsage usage = ShaderUsage::None) : WebGpuKernel{info}, + kernel_name_{kernel_name}, + expression_{expression}, + additional_impl_{additional_impl}, + additional_usage_{usage} {} + + protected: + std::string cache_hint; + + Status ComputeInternal(ComputeContext& context) const final; + virtual Status ConfigureProgram(const ComputeContext& /*context*/, UnaryElementwiseProgram& program) const { + program.AddUniformVariables({{}}); // empty for attribute(s) + return Status::OK(); + } + + private: + std::string kernel_name_; + std::string expression_; + std::string additional_impl_; + ShaderUsage additional_usage_; +}; + +constexpr const char ErfImpl[] = R"( +const r0 = 0.3275911; +const r1 = 0.254829592; +const r2 = -0.284496736; +const r3 = 1.421413741; +const r4 = -1.453152027; +const r5 = 1.061405429; + +fn erf_v(v: x_value_t) -> x_value_t { + let absv = abs(v); + let x = 1.0 / (1.0 + r0 * absv); + return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv)); +} +)"; + +constexpr const char HardSigmoidImpl[] = R"( +fn hard_sigmoid_v(v: vec4) -> vec4 { + let alpha = x_element_t(uniforms.attr[0]); + let beta_v = vec4(uniforms.attr[1]); + return max(vec4(0.0), + min(vec4(1.0), alpha * v + beta_v)); +} +)"; + +// built-in function tanh() does not work with large input (f32 88.7 or f16 11.09) +// https://github.com/gpuweb/gpuweb/issues/4458 +constexpr const char TanhImpl[] = R"( +fn tanh_v(a: x_value_t) -> x_value_t { + let expr = exp(-2 * abs(a)); + return sign(a) * (1 - expr) / (1 + expr); +} +)"; + +constexpr const char EluImpl[] = R"( +fn elu(a: x_element_t) -> x_element_t { + let alpha = x_element_t(uniforms.attr); + return select((exp(a) - 1.0) * alpha, a, a >= 0.0); +} + +fn elu_v(v: vec4) -> vec4 { + return vec4(elu(v.x), elu(v.y), elu(v.z), elu(v.w)); +} +)"; + +// default GELU expression, depending on ErfImpl +constexpr const char GeluExpr[] = "0.5 * a * (1.0 + erf_v(a * 0.7071067811865475))"; + +// fast GELU expression, depending on TanhImpl +constexpr const char FastGeluExpr[] = "a * (0.5 + 0.5 * tanh_v(a * (0.035677408136300125 * a * a + 0.7978845608028654)))"; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc new file mode 100644 index 0000000000000..1ee771e945820 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc @@ -0,0 +1,155 @@ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/nn/layer_norm.h" + +namespace onnxruntime { +namespace webgpu { + +static int GetMaxComponents(int64_t size) { + if (size % 4 == 0) { + return 4; + } else if (size % 2 == 0) { + return 2; + } + return 1; +} + +static size_t NormalizeAxis(int64_t axis, size_t tensor_rank) { + int64_t rank = static_cast(tensor_rank); + if (axis < -rank && axis >= rank) { + ORT_THROW("invalid axis: ", axis); + } + return gsl::narrow(axis < 0 ? axis + rank : axis); +} + +static std::string SumVector(std::string x, int components) { + switch (components) { + case 1: + return x; + case 2: + return "(" + x + ".x + " + x + ".y" + ")"; + case 4: + return "(" + x + ".x + " + x + ".y + " + x + ".w + " + x + ".z" + ")"; + default: + ORT_THROW("Unsupported number of components: ", components); + } +} + +Status LayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.AddInput("scale", ShaderUsage::UseUniform); + if (has_bias_) { + shader.AddInput("bias", ShaderUsage::UseUniform); + } + shader.AddOutput("output", ShaderUsage::UseUniform); + + int components = x.NumComponents(); + std::string bias = (has_bias_) ? " + bias[j]" : ""; + std::string simpl1 = (simplified_) ? "" : " - mean * mean"; + std::string simpl2 = (simplified_) ? "" : " - mean"; + + shader.AdditionalImplementation() << "alias element_t = " << (is_fp16_ ? "f16;\n" : "f32;\n") + << "alias f32_val_t = " << (components == 4 ? "vec4" : (components == 2 ? "vec2" : "f32")) << ";\n"; + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.norm_count") + << "let offset = global_idx * uniforms.norm_size_vectorized;\n" + << "var mean_vector = f32_val_t(0);\n" + << "var mean_square_vector = f32_val_t(0);\n" + << "for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) {\n" + << " let value = f32_val_t(x[h + offset]);\n" + << " mean_vector += value;\n" + << " mean_square_vector += value * value;\n" + << "}\n" + << "let mean = " << SumVector("mean_vector", components) << " / f32(uniforms.norm_size);\n" + << "let inv_std_dev = inverseSqrt(" << SumVector("mean_square_vector", components) << " / f32(uniforms.norm_size)" << simpl1 << " + uniforms.epsilon);\n" + << "for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) {\n" + << " let f32input = f32_val_t(x[j + offset]);\n" + << " let f32scale = f32_val_t(scale[j]);\n" + << " output[j + offset] = x_value_t((f32input" << simpl2 << ") * inv_std_dev * f32scale)" << bias << ";\n" + << "}\n"; + + return Status::OK(); +} + +template +Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const auto* x = context.Input(0); + const auto* scale = context.Input(1); + const auto* bias = context.Input(2); + + const auto x_shape = x->Shape(); + + auto* output = context.Output(0, x_shape); + + size_t data_size = x_shape.Size(); + if (data_size == 0) { + return Status::OK(); + } + + const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + + const size_t axis = NormalizeAxis(axis_, x_shape.NumDimensions()); + const uint32_t norm_count = gsl::narrow(x_shape.SizeToDimension(axis)); + const int64_t norm_size = x_shape.SizeFromDimension(axis); + const int components = GetMaxComponents(norm_size); + const uint32_t norm_size_vectorized = gsl::narrow((norm_size + components - 1) / components); + + const auto scale_size = scale->Shape().Size(); + const auto bias_size = (bias) ? bias->Shape().Size() : 0; + if (scale_size != norm_size || (bias && bias_size != norm_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Size of X.shape()[axis:] == ", norm_size, + ". Size of scale and bias (if provided) must match this. Got scale size of ", + scale_size, " and bias size of ", bias_size); + } + + LayerNormProgram program{bias != nullptr, is_fp16, simplified}; + + program + .CacheHint(simplified) + .AddInputs({{x, ProgramTensorMetadataDependency::Type, components}}) + .AddInputs({{scale, ProgramTensorMetadataDependency::Type, components}}) + .AddOutputs({{output, ProgramTensorMetadataDependency::None, components}}) + .SetDispatchGroupSize((norm_count + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + {static_cast(norm_count)}, + }) + .AddUniformVariables({ + {static_cast(norm_size)}, + }) + .AddUniformVariables({ + {static_cast(norm_size_vectorized)}, + }) + .AddUniformVariables({ + {static_cast(epsilon_)}, + }); + + if (bias != nullptr) { + program.AddInput({bias, ProgramTensorMetadataDependency::Type, components}); + } + return context.RunProgram(program); +} + +ONNX_OPERATOR_KERNEL_EX( + LayerNormalization, + kOnnxDomain, + 17, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + LayerNorm); + +ONNX_OPERATOR_KERNEL_EX( + SimplifiedLayerNormalization, + kOnnxDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + LayerNorm); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.h b/onnxruntime/core/providers/webgpu/nn/layer_norm.h new file mode 100644 index 0000000000000..17a9edbf4dd01 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.h @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +class LayerNormProgram final : public Program { + public: + LayerNormProgram(bool has_bias, + bool is_fp16, + bool simplified) : Program{"LayerNorm"}, + has_bias_{has_bias}, + is_fp16_{is_fp16}, + simplified_{simplified} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"norm_count", ProgramUniformVariableDataType::Uint32}, + {"norm_size", ProgramUniformVariableDataType::Uint32}, + {"norm_size_vectorized", ProgramUniformVariableDataType::Uint32}, + {"epsilon", ProgramUniformVariableDataType::Float32}); + + private: + bool has_bias_; + bool is_fp16_; + bool simplified_; +}; + +template +class LayerNorm final : public WebGpuKernel { + public: + LayerNorm(const OpKernelInfo& info) : WebGpuKernel(info) { + info.GetAttrOrDefault("axis", &axis_, -1); + info.GetAttrOrDefault("epsilon", &epsilon_, 1e-05f); + info.GetAttrOrDefault("stash_type", &stash_type_, 1); + } + + Status ComputeInternal(ComputeContext& context) const override; + + protected: + std::string cache_hint; + + private: + int64_t axis_; + float epsilon_; + int64_t stash_type_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc new file mode 100644 index 0000000000000..d1d4c242c4697 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -0,0 +1,347 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include "core/session/onnxruntime_c_api.h" + +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +ProgramUniformVariableValue::ProgramUniformVariableValue() + : length{0}, data_type{} {} // representing an empty uniform variable + +ProgramUniformVariableValue::ProgramUniformVariableValue(float value) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Float32, &value, sizeof(float)) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(uint32_t value) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Uint32, &value, sizeof(uint32_t)) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(int32_t value) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Int32, &value, sizeof(int32_t)) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(MLFloat16 value) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Float16, &value, sizeof(MLFloat16)) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(gsl::span values) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Float32, values.data(), sizeof(float), values.size()) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(gsl::span values) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Uint32, values.data(), sizeof(uint32_t), values.size()) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(gsl::span values) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Int32, values.data(), sizeof(int32_t), values.size()) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(gsl::span values) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Float16, values.data(), sizeof(MLFloat16), values.size()) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(ProgramUniformVariableDataType data_type, + const void* ptr, + size_t element_byte_size, + size_t length /* = 1 */) + : length{length}, data_type{data_type} { + ORT_ENFORCE(length > 0, "number of element of uniform variable must be greater than 0"); + + data.resize(length * element_byte_size); + memcpy(data.data(), ptr, length * element_byte_size); +} + +std::ostream& operator<<(std::ostream& os, ProgramUniformVariableDataType type) { + os << ProgramUniformVariableDataTypeName[std::underlying_type::type(type)]; + return os; +} + +std::ostream& operator<<(std::ostream& os, ProgramConstantDataType type) { + os << ProgramConstantDataTypeName[std::underlying_type::type(type)]; + return os; +} + +std::ostream& operator<<(std::ostream& os, ProgramTensorMetadataDependency dep) { + bool first = true; + if ((dep & ProgramTensorMetadataDependency::Type) == ProgramTensorMetadataDependency::Type) { + os << "Type"; + first = false; + } + if ((dep & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank) { + if (!first) os << "|"; + os << "Rank"; + first = false; + } + if ((dep & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape) { + if (!first) os << "|"; + os << "Shape"; + first = false; + } + if (first) { + os << "None"; + } + + return os; +} + +#ifndef NDEBUG +constexpr std::string_view ProgramVariableDataTypeName[] = { + "f32", // Float32 + "f32x2", // Float32x2 + "f32x4", // Float32x4 + "f16", // Float16 + "f16x2", // Float16x2 + "f16x4", // Float16x4 + "i32", // Int32 + "i32x2", // Int32x2 + "i32x4", // Int32x4 + "u32", // Uint32 + "u32x2", // Uint32x2 + "u32x4", // Uint32x4 + "i64", // Int64 + "u64", // Uint64 + "boolx4", // Boolx4 + "u8x4", // Uint8x4 + "u8x8", // Uint8x8 + "u8x16", // Uint8x16 +}; +std::ostream& operator<<(std::ostream& os, ProgramVariableDataType type) { + os << ProgramVariableDataTypeName[std::underlying_type::type(type)]; + return os; +} +#endif + +int NumberOfComponents(ProgramVariableDataType type) { + switch (type) { + case ProgramVariableDataType::Float32: + case ProgramVariableDataType::Int32: + case ProgramVariableDataType::Uint32: + case ProgramVariableDataType::Int64: + case ProgramVariableDataType::Uint64: + case ProgramVariableDataType::Float16: + return 1; + case ProgramVariableDataType::Float32x2: + case ProgramVariableDataType::Int32x2: + case ProgramVariableDataType::Uint32x2: + case ProgramVariableDataType::Float16x2: + return 2; + case ProgramVariableDataType::Float32x4: + case ProgramVariableDataType::Int32x4: + case ProgramVariableDataType::Uint32x4: + case ProgramVariableDataType::Float16x4: + case ProgramVariableDataType::Boolx4: + case ProgramVariableDataType::Uint8x4: + return 4; + case ProgramVariableDataType::Uint8x8: + return 8; + case ProgramVariableDataType::Uint8x16: + return 16; + default: + return -1; + } +} + +ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int component /* = 1 */) { + if (component == 1) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return ProgramVariableDataType::Float32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + return ProgramVariableDataType::Float16; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return ProgramVariableDataType::Int32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + return ProgramVariableDataType::Uint32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + return ProgramVariableDataType::Int64; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + return ProgramVariableDataType::Uint64; + default: + return ProgramVariableDataType::InvalidType; + } + } else if (component == 2) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return ProgramVariableDataType::Float32x2; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + return ProgramVariableDataType::Float16x2; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return ProgramVariableDataType::Int32x2; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + return ProgramVariableDataType::Uint32x2; + default: + return ProgramVariableDataType::InvalidType; + } + } else if (component == 4) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return ProgramVariableDataType::Uint8x4; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return ProgramVariableDataType::Float32x4; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + return ProgramVariableDataType::Float16x4; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return ProgramVariableDataType::Int32x4; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + return ProgramVariableDataType::Uint32x4; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + return ProgramVariableDataType::Boolx4; + default: + return ProgramVariableDataType::InvalidType; + } + } else if (component == 8) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return ProgramVariableDataType::Uint8x8; + default: + return ProgramVariableDataType::InvalidType; + } + } else if (component == 16) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return ProgramVariableDataType::Uint8x16; + default: + return ProgramVariableDataType::InvalidType; + } + } else { + return ProgramVariableDataType::InvalidType; + } +} + +namespace { +TensorShape GetReducedShape(const TensorShape& shape, int component /* > 1 */) { + ORT_ENFORCE(shape.NumDimensions() > 0 && shape.GetDims()[shape.NumDimensions() - 1] % component == 0, + "Cannot reduce shape ", shape.ToString(), " by component=", component); + TensorShape reduced_shape = shape; + reduced_shape[reduced_shape.NumDimensions() - 1] /= component; + return reduced_shape; +} +} // namespace + +ProgramInput::ProgramInput(const Tensor* tensor) : ProgramInput{tensor, ProgramTensorMetadataDependency::TypeAndRank} {} + +ProgramInput::ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, int component) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + use_override_shape{component > 1}, + override_shape{} { + if (use_override_shape) { + override_shape = GetReducedShape(tensor->Shape(), component); + } +} + +ProgramInput::ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + use_override_shape{true}, + override_shape{override_shape} {} + +ProgramOutput::ProgramOutput(Tensor* tensor) + : ProgramOutput{tensor, ProgramTensorMetadataDependency::None} {} + +ProgramOutput::ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, int component) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + use_override_shape{component > 1}, + override_shape{} { + if (use_override_shape) { + override_shape = GetReducedShape(tensor->Shape(), component); + } +} + +ProgramOutput::ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + use_override_shape{true}, + override_shape{override_shape} {} + +ProgramBase::ProgramBase(std::string_view name, ProgramMetadata&& metadata) + : name_{name}, + metadata_{metadata}, + dispatch_group_size_x_{0}, + dispatch_group_size_y_{0}, + dispatch_group_size_z_{0}, + workgroup_size_x_{0}, + workgroup_size_y_{0}, + workgroup_size_z_{0} { +} + +ProgramBase& ProgramBase::AddInput(ProgramInput&& input) { + inputs_.emplace_back(input); + return *this; +} + +ProgramBase& ProgramBase::AddInputs(std::initializer_list inputs) { + inputs_.insert(inputs_.end(), inputs.begin(), inputs.end()); + return *this; +} + +ProgramBase& ProgramBase::AddOutput(ProgramOutput&& output) { + outputs_.emplace_back(output); + return *this; +} + +ProgramBase& ProgramBase::AddOutputs(std::initializer_list outputs) { + outputs_.insert(outputs_.end(), outputs.begin(), outputs.end()); + return *this; +} + +ProgramBase& ProgramBase::AddIndices(const TensorShape& shape) { + indices_.emplace_back(shape); + return *this; +} + +ProgramBase& ProgramBase::AddIndices(TensorShape&& shape) { + indices_.emplace_back(shape); + return *this; +} + +ProgramBase& ProgramBase::SetDispatchGroupSize(uint32_t x) { + return SetDispatchGroupSize(x, 1, 1); +} + +ProgramBase& ProgramBase::SetDispatchGroupSize(uint32_t x, uint32_t y) { + return SetDispatchGroupSize(x, y, 1); +} + +ProgramBase& ProgramBase::SetDispatchGroupSize(uint32_t x, uint32_t y, uint32_t z) { + dispatch_group_size_x_ = x; + dispatch_group_size_y_ = y; + dispatch_group_size_z_ = z; + return *this; +} + +ProgramBase& ProgramBase::SetWorkgroupSize(uint32_t x) { + return SetWorkgroupSize(x, 1, 1); +} + +ProgramBase& ProgramBase::SetWorkgroupSize(uint32_t x, uint32_t y) { + return SetWorkgroupSize(x, y, 1); +} + +ProgramBase& ProgramBase::SetWorkgroupSize(uint32_t x, uint32_t y, uint32_t z) { + workgroup_size_x_ = x; + workgroup_size_y_ = y; + workgroup_size_z_ = z; + return *this; +} + +ProgramBase& ProgramBase::AddUniformVariable(ProgramUniformVariableValue&& variable) { + variables_.emplace_back(variable); + return *this; +} + +ProgramBase& ProgramBase::AddUniformVariables(std::initializer_list variables) { + variables_.insert(variables_.end(), variables.begin(), variables.end()); + return *this; +} + +ProgramBase& ProgramBase::SetOverridableConstants(std::initializer_list overridable_constants) { + overridable_constants_.insert(overridable_constants_.end(), overridable_constants.begin(), overridable_constants.end()); + return *this; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h new file mode 100644 index 0000000000000..1562ec158b40a --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program.h @@ -0,0 +1,605 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include + +#include "core/common/common.h" +#include "core/common/safeint.h" +#include "core/framework/tensor.h" + +namespace onnxruntime { +namespace webgpu { +class ShaderHelper; +class ComputeContext; +class WebGpuContext; + +// data type of uniform variable +enum class ProgramUniformVariableDataType { + Float32, + Float16, + Uint32, + Int32, +}; +std::ostream& operator<<(std::ostream& os, ProgramUniformVariableDataType); + +constexpr size_t ProgramUniformVariableDataTypeSize[] = {sizeof(float), sizeof(uint16_t), sizeof(uint32_t), sizeof(int32_t)}; + +constexpr std::string_view ProgramUniformVariableDataTypeName[] = {"f32", "f16", "u32", "i32"}; + +// represents a runtime value of a uniform variable +struct ProgramUniformVariableValue { + ProgramUniformVariableValue(); // representing an empty uniform variable + ProgramUniformVariableValue(float value); + ProgramUniformVariableValue(uint32_t value); + ProgramUniformVariableValue(int32_t value); + ProgramUniformVariableValue(MLFloat16 value); + ProgramUniformVariableValue(gsl::span values); + ProgramUniformVariableValue(gsl::span values); + ProgramUniformVariableValue(gsl::span values); + ProgramUniformVariableValue(gsl::span values); + + size_t length; + ProgramUniformVariableDataType data_type; + std::vector data; + + private: + ProgramUniformVariableValue(ProgramUniformVariableDataType data_type, const void* ptr, size_t element_byte_size, size_t length = 1); +}; + +// represents a uniform variable definition +struct ProgramUniformVariableDefinition { + constexpr ProgramUniformVariableDefinition(std::string_view name, ProgramUniformVariableDataType data_type) + : name{name}, data_type{data_type} {} + + std::string_view name; + ProgramUniformVariableDataType data_type; +}; + +// data type of constant +enum class ProgramConstantDataType { + Float32, + Float16, + Uint32, + Int32, + Bool +}; +std::ostream& operator<<(std::ostream& os, ProgramConstantDataType); + +constexpr std::string_view ProgramConstantDataTypeName[] = {"f32", "f16", "u32", "i32", "bool"}; + +// represents a constant in a program +struct ProgramConstant { + constexpr ProgramConstant(std::string_view name, float value) : name{name}, type{ProgramConstantDataType::Float32}, f32{value} {} + constexpr ProgramConstant(std::string_view name, uint32_t value) : name{name}, type{ProgramConstantDataType::Uint32}, u32{value} {} + constexpr ProgramConstant(std::string_view name, int32_t value) : name{name}, type{ProgramConstantDataType::Int32}, i32{value} {} + constexpr ProgramConstant(std::string_view name, MLFloat16 value) : name{name}, type{ProgramConstantDataType::Float16}, f16{value} {} + constexpr ProgramConstant(std::string_view name, bool value) : name{name}, type{ProgramConstantDataType::Bool}, boolean{value} {} + + std::string_view name; + ProgramConstantDataType type; + union { + float f32; + uint32_t u32; + int32_t i32; + MLFloat16 f16; + bool boolean; + }; +}; + +// represents a runtime value of an overridable constant +struct ProgramOverridableConstantValue { + constexpr ProgramOverridableConstantValue() : type{}, u32{}, has_value{false} {} // representing not overriding + constexpr ProgramOverridableConstantValue(float value) : type{ProgramConstantDataType::Float32}, f32{value}, has_value{true} {} + constexpr ProgramOverridableConstantValue(uint32_t value) : type{ProgramConstantDataType::Uint32}, u32{value}, has_value{true} {} + constexpr ProgramOverridableConstantValue(int32_t value) : type{ProgramConstantDataType::Int32}, i32{value}, has_value{true} {} + constexpr ProgramOverridableConstantValue(MLFloat16 value) : type{ProgramConstantDataType::Float16}, f16{value}, has_value{true} {} + constexpr ProgramOverridableConstantValue(bool value) : type{ProgramConstantDataType::Bool}, boolean{value}, has_value{true} {} + + ProgramConstantDataType type; + union { + float f32; + uint32_t u32; + int32_t i32; + MLFloat16 f16; + bool boolean; + }; + bool has_value; +}; + +// represents an overridable constant definition. may or may not have a default value. +struct ProgramOverridableConstantDefinition { + constexpr ProgramOverridableConstantDefinition(std::string_view name, ProgramConstantDataType type) + : name{name}, type{type}, u32{}, has_default_value{false} {} + constexpr ProgramOverridableConstantDefinition(std::string_view name, float value) + : name{name}, type{ProgramConstantDataType::Float32}, f32{value}, has_default_value{true} {} + constexpr ProgramOverridableConstantDefinition(std::string_view name, uint32_t value) + : name{name}, type{ProgramConstantDataType::Uint32}, u32{value}, has_default_value{true} {} + constexpr ProgramOverridableConstantDefinition(std::string_view name, int32_t value) + : name{name}, type{ProgramConstantDataType::Int32}, i32{value}, has_default_value{true} {} + constexpr ProgramOverridableConstantDefinition(std::string_view name, MLFloat16 value) + : name{name}, type{ProgramConstantDataType::Float16}, f16{value}, has_default_value{true} {} + constexpr ProgramOverridableConstantDefinition(std::string_view name, bool value) + : name{name}, type{ProgramConstantDataType::Bool}, boolean{value}, has_default_value{true} {} + + std::string_view name; + ProgramConstantDataType type; + union { + float f32; + uint32_t u32; + int32_t i32; + MLFloat16 f16; + bool boolean; + }; + bool has_default_value; +}; + +// represents whether the program shader depends on the type, rank, or shape of an input/output tensor +enum class ProgramTensorMetadataDependency : int { + None = 0, + Type = 1, + Rank = 2, + Shape = 4, + TypeAndRank = Type | Rank, + TypeAndShape = Type | Shape, +}; +std::ostream& operator<<(std::ostream& os, ProgramTensorMetadataDependency); + +inline ProgramTensorMetadataDependency operator|(ProgramTensorMetadataDependency a, ProgramTensorMetadataDependency b) { + return (ProgramTensorMetadataDependency)((int&)a | (int&)b); +} +inline ProgramTensorMetadataDependency operator&(ProgramTensorMetadataDependency a, ProgramTensorMetadataDependency b) { + return (ProgramTensorMetadataDependency)((int&)a & (int&)b); +} +inline ProgramTensorMetadataDependency& operator|=(ProgramTensorMetadataDependency& a, ProgramTensorMetadataDependency b) { + return (ProgramTensorMetadataDependency&)((int&)a |= (int&)b); +} +inline ProgramTensorMetadataDependency& operator&=(ProgramTensorMetadataDependency& a, ProgramTensorMetadataDependency b) { + return (ProgramTensorMetadataDependency&)((int&)a &= (int&)b); +} + +constexpr SafeInt WORKGROUP_SIZE = 64; + +// data type of variable +// +// this is not a full list of all possible data types in shader programs. +// it only includes what are used in WebGPU EP. +enum class ProgramVariableDataType { + InvalidType = -1, + Float32, + Float32x2, + Float32x4, + Float16, + Float16x2, + Float16x4, + Int32, + Int32x2, + Int32x4, + Uint32, + Uint32x2, + Uint32x4, + Int64, + Uint64, + Boolx4, + Uint8x4, + Uint8x8, + Uint8x16 +}; +#ifndef NDEBUG +std::ostream& operator<<(std::ostream& os, ProgramVariableDataType); +#endif + +int NumberOfComponents(ProgramVariableDataType type); + +ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int component = 1); + +struct ProgramInput { + ProgramInput(const Tensor* tensor); + ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, int component = 1); + ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component); + + const Tensor* tensor; + ProgramTensorMetadataDependency dependency; + ProgramVariableDataType var_type; + bool use_override_shape; + TensorShape override_shape; +}; + +struct ProgramOutput { + ProgramOutput(Tensor* tensor); + ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, int component = 1); + ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component); + + Tensor* tensor; + ProgramTensorMetadataDependency dependency; + ProgramVariableDataType var_type; + bool use_override_shape; + TensorShape override_shape; +}; + +enum class ValidationMode { + Disabled = 0, + WGPUOnly, + Basic, + Full +}; + +namespace details { +class ProgramWrapper; +} + +struct ProgramMetadata { + gsl::span constants; + gsl::span overridable_constants; + gsl::span uniform_variables; +}; + +class ProgramBase { + public: + // + // chain-style methods for setting properties + // + + // set the cache hint for the program + template + ProgramBase& CacheHint(T&&... hints) { + cache_hint_ = absl::StrJoin(std::forward_as_tuple(std::forward(hints)...), "|"); + return *this; + } + + // add a program input + ProgramBase& AddInput(ProgramInput&& input); + // add multiple program inputs + ProgramBase& AddInputs(std::initializer_list inputs); + // add a program output + ProgramBase& AddOutput(ProgramOutput&& output); + // add multiple program outputs + ProgramBase& AddOutputs(std::initializer_list outputs); + // add a program variable for indices + ProgramBase& AddIndices(const TensorShape& shape); + // add a program variable for indices + ProgramBase& AddIndices(TensorShape&& shape); + + // set the size of dispatch groups. Y and Z are 1 if not specified. + ProgramBase& SetDispatchGroupSize(uint32_t x); + // set the size of dispatch groups. Z is 1 if not specified. + ProgramBase& SetDispatchGroupSize(uint32_t x, uint32_t y); + // set the size of dispatch groups. + ProgramBase& SetDispatchGroupSize(uint32_t x, uint32_t y, uint32_t z); + + // set the size of a workgroup grid. Y and Z are 1 if not specified. + ProgramBase& SetWorkgroupSize(uint32_t x); + // set the size of a workgroup grid. Z is 1 if not specified. + ProgramBase& SetWorkgroupSize(uint32_t x, uint32_t y); + // set the size of a workgroup grid. + ProgramBase& SetWorkgroupSize(uint32_t x, uint32_t y, uint32_t z); + + // add a uniform variable. + // + // the specified uniform variable should match the uniform definition in the class, + // specified by macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES. + ProgramBase& AddUniformVariable(ProgramUniformVariableValue&& variable); + // add multiple uniform variables. + // + // the specified uniform variables should match the uniform definition in the class, + // specified by macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES. + ProgramBase& AddUniformVariables(std::initializer_list variables); + + // set the overridable constants + // + // the specified overridable constants should match the overridable constant definition in the class, + // specified by macro WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS. + ProgramBase& SetOverridableConstants(std::initializer_list overridable_constants); + + // + // shader code generation + // + + virtual Status GenerateShaderCode(ShaderHelper& shader) const = 0; + + // + // Properties Getters + // + + inline const std::string& Name() const { return name_; } + inline const ProgramMetadata& Metadata() const { return metadata_; } + inline const std::string& CacheHint() const { return cache_hint_; } + inline const std::vector& Inputs() const { return inputs_; } + inline const std::vector& Outputs() const { return outputs_; } + inline const std::vector& Indices() const { return indices_; } + inline uint32_t DispatchGroupSizeX() const { return dispatch_group_size_x_; } + inline uint32_t DispatchGroupSizeY() const { return dispatch_group_size_y_; } + inline uint32_t DispatchGroupSizeZ() const { return dispatch_group_size_z_; } + inline uint32_t WorkgroupSizeX() const { return workgroup_size_x_; } + inline uint32_t WorkgroupSizeY() const { return workgroup_size_y_; } + inline uint32_t WorkgroupSizeZ() const { return workgroup_size_z_; } + inline const std::vector& UniformVariables() const { return variables_; } + inline const std::vector& OverridableConstants() const { return overridable_constants_; } + + protected: + virtual ~ProgramBase() = default; + + private: + // Make the constructor private to prevent direct instantiation or inheritance from this class + // Use the Program template class as base class to create a new program class + explicit ProgramBase(std::string_view name, ProgramMetadata&& metadata); + + std::string name_; + ProgramMetadata metadata_; + + std::string cache_hint_; + std::vector inputs_; + std::vector outputs_; + std::vector indices_; + + uint32_t dispatch_group_size_x_; + uint32_t dispatch_group_size_y_; + uint32_t dispatch_group_size_z_; + + uint32_t workgroup_size_x_; + uint32_t workgroup_size_y_; + uint32_t workgroup_size_z_; + + std::vector variables_; + std::vector overridable_constants_; + + friend class details::ProgramWrapper; +}; + +namespace details { +// class ProgramWrapper is for accessing private constructor of ProgramBase. +// only ProgramWrapper can access the constructor of ProgramBase because ProgramWrapper is the only friend class of +// ProgramBase. This design is used to prevent direct instantiation or inheritance from ProgramBase. +class ProgramWrapper : public ProgramBase { + protected: + template + ProgramWrapper(Args&&... args) : ProgramBase{std::forward(args)...} {} +}; + +#if defined(ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK) +#error "macro ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK is already defined" +#endif + +#define ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(identifier, element_type) \ + private: \ + template \ + static auto test_has_##identifier(int)->decltype(U::identifier, std::true_type{}); /* checks if member exists */ \ + template \ + static auto test_has_##identifier(...)->std::false_type; \ + \ + template ::value && /* - is a const std::array */ \ + std::is_const_v && /* - has "const" modifier */ \ + !std::is_member_pointer_v>> /* - is static */ \ + static auto test_has_##identifier##_with_correct_type(int)->std::true_type; \ + template \ + static auto test_has_##identifier##_with_correct_type(...)->std::false_type; \ + \ + public: \ + static constexpr bool has_##identifier = decltype(test_has_##identifier(0))::value; \ + static constexpr bool has_##identifier##_with_correct_type = decltype(test_has_##identifier##_with_correct_type(0))::value + +// the following template class checks whether the type is a const std::array +template +struct is_const_std_array : std::false_type {}; +template +struct is_const_std_array> : std::true_type {}; + +// the following template class checks whether certain static members exist in the derived class (SFINAE) +template +class DerivedProgramClassTypeCheck { + ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(constants, ProgramConstant); + ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(overridable_constants, ProgramOverridableConstantDefinition); + ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(uniform_variables, ProgramUniformVariableDefinition); +}; + +// compile-time tests for the type check +// +// TODO: move this to test folder +namespace test { + +template +class TestTypeCheck { + ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(a, int); +}; + +struct TestClass_Empty {}; +static_assert(!TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotArray_0 { + int b; +}; +static_assert(!TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotArray_1 { + int a; +}; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotArray_2 { + const int a; +}; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotStdArray_0 { + const int a[2]; +}; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotStdArray_1 { + static constexpr int a[] = {0}; +}; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotStdArray_2 { + static int a[]; +}; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotStdArray_3 { + static const int a[]; +}; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_StdArray_0 { + std::array a = {1}; +}; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_StdArray_1 { + static constexpr std::array a = {1, 2}; +}; +static_assert(TestTypeCheck::has_a); +static_assert(TestTypeCheck::has_a_with_correct_type); + +struct TestClass_StdArray_2 { + static const std::array a; +}; +static_assert(TestTypeCheck::has_a); +static_assert(TestTypeCheck::has_a_with_correct_type); + +struct TestClass_StdArray_3 { + static constexpr const std::array a = {1, 2, 3, 4}; +}; +static_assert(TestTypeCheck::has_a); +static_assert(TestTypeCheck::has_a_with_correct_type); + +struct TestClass_StdArray_4 { + static std::array a; +}; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +} // namespace test + +#undef ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK + +} // namespace details + +template +class Program : public details::ProgramWrapper { + public: + template + Program(Args&&... args) : details::ProgramWrapper{std::forward(args)..., GetMetadata()} {} + + static ProgramMetadata GetMetadata() { + ProgramMetadata metadata; + if constexpr (details::DerivedProgramClassTypeCheck::has_constants) { + constexpr const ProgramConstant* ptr = T::constants.data(); + constexpr size_t len = T::constants.size(); + + static_assert(details::DerivedProgramClassTypeCheck::has_constants_with_correct_type, + "Derived class of \"Program\" has member \"constants\" but its type is incorrect. " + "Please use macro WEBGPU_PROGRAM_DEFINE_CONSTANTS() or WEBGPU_PROGRAM_EXTEND_CONSTANTS() to declare constants."); + + metadata.constants = {ptr, len}; + } else { + metadata.constants = {}; + } + + if constexpr (details::DerivedProgramClassTypeCheck::has_overridable_constants) { + constexpr const ProgramOverridableConstantDefinition* ptr = T::overridable_constants.data(); + constexpr size_t len = T::overridable_constants.size(); + + static_assert(details::DerivedProgramClassTypeCheck::has_overridable_constants_with_correct_type, + "Derived class of \"Program\" has member \"overridable_constants\" but its type is incorrect. " + "Please use macro WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS() or WEBGPU_PROGRAM_EXTEND_OVERRIDABLE_CONSTANTS() to declare overridable constants."); + + metadata.overridable_constants = {ptr, len}; + } else { + metadata.overridable_constants = {}; + } + + if constexpr (details::DerivedProgramClassTypeCheck::has_uniform_variables) { + constexpr const ProgramUniformVariableDefinition* ptr = T::uniform_variables.data(); + constexpr size_t len = T::uniform_variables.size(); + + static_assert(details::DerivedProgramClassTypeCheck::has_uniform_variables_with_correct_type, + "Derived class of \"Program\" has member \"uniform_variables\" but its type is incorrect. " + "Please use macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES() or WEBGPU_PROGRAM_EXTEND_UNIFORM_VARIABLES() to declare uniform variables."); + + metadata.uniform_variables = {ptr, len}; + } else { + metadata.uniform_variables = {}; + } + + return metadata; + } +}; + +namespace details { +// helper function to convert a C-style array to std::array +// +// This is basically the same as std::to_array in C++20. +// +template +constexpr auto _to_std_array_impl(T (&arr)[N], std::index_sequence) -> std::array, N> { + return {{arr[Idx]...}}; +} + +template +constexpr auto _to_std_array(T (&arr)[N]) -> std::array, N> { + return _to_std_array_impl(arr, std::make_index_sequence{}); +} + +// helper function to concatenate a std::array and a C-style array to a std::array +// +template +constexpr std::array, L + R> _concat2_impl(const std::array& lhs, + T (&rhs)[R], + std::index_sequence, + std::index_sequence) { + return {{lhs[IdxL]..., rhs[IdxR]...}}; +} + +template +constexpr std::array, L + R> _concat2(const std::array& lhs, T (&rhs)[R]) { + return _concat2_impl(lhs, rhs, std::make_index_sequence{}, std::make_index_sequence{}); +} + +} // namespace details +#define WEBGPU_PROGRAM_DEFINE_(identifier, T, ...) \ + static constexpr const T identifier##_own[] = {__VA_ARGS__}; \ + static constexpr const auto identifier = \ + onnxruntime::webgpu::details::_to_std_array(identifier##_own) + +#define WEBGPU_PROGRAM_EXTEND_(identifier, T, BASE, ...) \ + static constexpr const T identifier##_own[] = {__VA_ARGS__}; \ + static constexpr const auto identifier = \ + onnxruntime::webgpu::details::_concat2(BASE::identifier, identifier##_own) + +#define WEBGPU_PROGRAM_DEFINE_CONSTANTS(...) \ + WEBGPU_PROGRAM_DEFINE_(constants, onnxruntime::webgpu::ProgramConstant, __VA_ARGS__) + +#define WEBGPU_PROGRAM_EXTEND_CONSTANTS(BASE, ...) \ + WEBGPU_PROGRAM_EXTEND_(constants, onnxruntime::webgpu::ProgramConstant, BASE, __VA_ARGS__) + +#define WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS(...) \ + WEBGPU_PROGRAM_DEFINE_(overridable_constants, onnxruntime::webgpu::ProgramOverridableConstantDefinition, __VA_ARGS__) + +#define WEBGPU_PROGRAM_EXTEND_OVERRIDABLE_CONSTANTS(BASE, ...) \ + WEBGPU_PROGRAM_EXTEND_(overridable_constants, onnxruntime::webgpu::ProgramOverridableConstantDefinition, BASE, __VA_ARGS__) + +#define WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(...) \ + WEBGPU_PROGRAM_DEFINE_(uniform_variables, onnxruntime::webgpu::ProgramUniformVariableDefinition, __VA_ARGS__) + +#define WEBGPU_PROGRAM_EXTEND_UNIFORM_VARIABLES(BASE, ...) \ + WEBGPU_PROGRAM_EXTEND_(uniform_variables, onnxruntime::webgpu::ProgramUniformVariableDefinition, BASE, __VA_ARGS__) + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program_cache_key.cc b/onnxruntime/core/providers/webgpu/program_cache_key.cc new file mode 100644 index 0000000000000..a5c21563dbfcd --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program_cache_key.cc @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/program_cache_key.h" + +#include "core/providers/webgpu/string_macros.h" + +namespace onnxruntime { +namespace webgpu { + +// macro "D" - append to the ostream only in debug build +#ifndef NDEBUG // if debug build +#define D(str) << str +#else +#define D(str) +#endif + +namespace { +// append the info of an input or output to the cachekey +void AppendTensorInfo(std::ostream& ss, const Tensor& tensor, ProgramVariableDataType var_type, ProgramTensorMetadataDependency dependency, + bool& first) { + if (first) { + first = false; + } else { + ss << '|'; + } + + if ((dependency & ProgramTensorMetadataDependency::Type) == ProgramTensorMetadataDependency::Type) { +#ifndef NDEBUG // if debug build + ss << var_type; +#else + ss << static_cast(var_type); +#endif + ss << ';'; + } + + if ((dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape) { + ss D("Dims=") << tensor.Shape().ToString(); + } else if ((dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank) { + ss D("Rank=") << tensor.Shape().NumDimensions(); + } +} +} // namespace + +std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_dispatch) { + SS(ss, kStringInitialSizeCacheKey); + + // final key format: + // =[]:::: + // + // = ||... + // = ,, + // = + // = ||... + // = + // = ||... + // = ; + ss << program.Name(); + + // append custom cache hint if any + if (auto& hint = program.CacheHint(); !hint.empty()) { + ss << '[' D("CacheHint=") << hint << ']'; + } + + // append workgroup size if overridden + if (auto x = program.WorkgroupSizeX(), y = program.WorkgroupSizeY(), z = program.WorkgroupSizeZ(); + x != 0 || y != 0 || z != 0) { + ss << ":" D("WorkgroupSize="); + // only append non-zero values. zero values are considered as use default + if (x > 0) { + ss << x; + } + ss << ","; + if (y > 0) { + ss << y; + } + ss << ","; + if (z > 0) { + ss << z; + } + } + + ss << ":" D("DispatchDim=") << (is_1d_dispatch ? "1" : "3"); + ss << ":" D("UniformSizes="); + bool first = true; + for (const auto& uniform : program.UniformVariables()) { + if (first) { + first = false; + } else { + ss << "|"; + } + if (uniform.length > 0) { + ss << uniform.length; + } + } + + ss << ":" D("Inputs="); + first = true; + for (const auto& input : program.Inputs()) { + AppendTensorInfo(ss, *input.tensor, input.var_type, input.dependency, first); + } + + ss << ":" D("Outputs="); + first = true; + for (const auto& output : program.Outputs()) { + AppendTensorInfo(ss, *output.tensor, output.var_type, output.dependency, first); + } + + return SS_GET(ss); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/codegen/mti/tensor/shape_op.h b/onnxruntime/core/providers/webgpu/program_cache_key.h similarity index 51% rename from onnxruntime/core/codegen/mti/tensor/shape_op.h rename to onnxruntime/core/providers/webgpu/program_cache_key.h index 67ee2de50eca9..22ba19ebd0f25 100644 --- a/onnxruntime/core/codegen/mti/tensor/shape_op.h +++ b/onnxruntime/core/providers/webgpu/program_cache_key.h @@ -2,13 +2,15 @@ // Licensed under the MIT License. #pragma once + #include -#include + +#include "core/providers/webgpu/program.h" namespace onnxruntime { -namespace tvm_codegen { +namespace webgpu { -tvm::Tensor Shape(const tvm::Tensor& X, const std::string& name = "shape"); +std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_dispatch); -} // namespace tvm_codegen +} // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program_manager.cc b/onnxruntime/core/providers/webgpu/program_manager.cc new file mode 100644 index 0000000000000..109bac34d6503 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program_manager.cc @@ -0,0 +1,183 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/common/common.h" + +#include "core/common/common.h" +#include "core/common/logging/logging.h" + +#include "core/providers/webgpu/program_manager.h" +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace webgpu { + +ProgramArtifact::ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline, std::vector&& shape_uniform_ranks) + : name{program.Name()}, + compute_pipeline{compute_pipeline}, + shape_uniform_ranks{shape_uniform_ranks} {} + +Status ProgramManager::NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint32_t& z) const { + ORT_RETURN_IF(x == 0 || y == 0 || z == 0, "Invalid dispatch group size (", x, ", ", y, ", ", z, ")"); + + auto limit_per_dimension = limits_.maxComputeWorkgroupsPerDimension; + if (x > limit_per_dimension || y > limit_per_dimension || z > limit_per_dimension) { + auto size = static_cast(x) * static_cast(y) * static_cast(z); + uint32_t dispatch_avg = gsl::narrow(std::ceil(std::sqrt(size))); + if (dispatch_avg > limit_per_dimension) { + dispatch_avg = gsl::narrow(std::ceil(std::cbrt(size))); + ORT_RETURN_IF(dispatch_avg > limit_per_dimension, "The dispatch group size exceeds WebGPU maximum."); + x = y = z = dispatch_avg; + } else { + x = y = dispatch_avg; + z = 1; + } + } + return Status::OK(); +} + +Status ProgramManager::Build(const ProgramBase& program, + const ProgramMetadata& program_metadata, +#ifndef NDEBUG // if debug build + const std::string& program_key, +#endif + uint32_t normalized_dispatch_x, + uint32_t normalized_dispatch_y, + uint32_t normalized_dispatch_z, + wgpu::ComputePipeline& compute_pipeline, + std::vector& shape_uniform_ranks) const { + ShaderHelper shader_helper{program, + program_metadata, + device_, + limits_, + normalized_dispatch_x, + normalized_dispatch_y, + normalized_dispatch_z}; + ORT_RETURN_IF_ERROR(shader_helper.Init()); + + ORT_RETURN_IF_ERROR(program.GenerateShaderCode(shader_helper)); + + ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForInputs()); + ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForOutputs()); + ORT_RETURN_IF_ERROR(shader_helper.ValidateIndices()); + + // code is a large std::string that contains the final shader code + std::string code; + ORT_RETURN_IF_ERROR(shader_helper.GenerateSourceCode(code, shape_uniform_ranks)); + + LOGS_DEFAULT(VERBOSE) << "\n=== WebGPU Shader code [" << program.Name() +#ifndef NDEBUG // if debug build + << ", Key=\"" << program_key << "\"" +#endif + << "] Start ===\n\n" + << code + << "\n=== WebGPU Shader code [" << program.Name() +#ifndef NDEBUG // if debug build + << ", Key=\"" << program_key << "\"" +#endif + << "] End ===\n"; + + wgpu::ShaderModuleWGSLDescriptor wgsl_descriptor{}; + wgsl_descriptor.code = code.c_str(); + + wgpu::ShaderModuleDescriptor descriptor{}; + descriptor.nextInChain = &wgsl_descriptor; + + auto shader_module = device_.CreateShaderModule(&descriptor); + + // TODO: a new cache hierarchy for constants. + // + // Explaination: + // Currently, we use Uniforms for dynamic data. This helps to reduce the number of program artifacts. + // + // "dynamic data" here means the data the determined at runtime, such as the shape of the input tensor. + // + // However, some programs may not necessarily depend on dynamic data. For example, "Clip" may depend on the value of "min" and "max". + // We are using uniforms for the value of "min" and "max" in the current implementation, but usually "min" and "max" are determined + // earlier because they are either from Attributes or from the initializers of the model. + // + // Questions: + // - can we use one instance of ShaderModule to create multiple ComputePipeline? + // - is there any benefit to do so compared to the current implementation? + // + + // process overridable constants if available + size_t constant_count = program.OverridableConstants().size(); + + // making a copy of the constant names is required because they are stored as std::string_view in the program + // metadata. A value of std::string_view is not guaranteed to be a C-stlye string (null-terminated) and hence + // cannot be used directly in the WebGPU API (which expects a const char*). + std::vector constant_names; + constant_names.reserve(constant_count); + std::vector constant_entries; + constant_entries.reserve(constant_count); + for (size_t i = 0; i < constant_count; ++i) { + const auto& constant_override = program.OverridableConstants()[i]; + const auto& constant_def = program_metadata.overridable_constants[i]; + + if (constant_override.has_value) { + double value = 0; + switch (constant_override.type) { + case ProgramConstantDataType::Bool: + value = constant_override.boolean ? 1 : 0; + break; + case ProgramConstantDataType::Float16: + // convert f16(MLFloat16) -> f32(float) -> f64(double) + // because the value of a constant must be a double in WebGPU API, it is expensive to use f16 overridable constants. + value = constant_override.f16.ToFloat(); + break; + case ProgramConstantDataType::Float32: + value = constant_override.f32; + break; + case ProgramConstantDataType::Int32: + value = constant_override.i32; + break; + case ProgramConstantDataType::Uint32: + value = constant_override.u32; + break; + } + + const auto& name_string = constant_names.emplace_back(constant_def.name); + wgpu::ConstantEntry entry{}; + entry.key = name_string.c_str(); + entry.value = value; + constant_entries.push_back(std::move(entry)); + } + } + + wgpu::ProgrammableStageDescriptor compute_stage{}; + compute_stage.module = shader_module; + compute_stage.entryPoint = "main"; + if (!constant_entries.empty()) { + compute_stage.constants = constant_entries.data(); + compute_stage.constantCount = constant_entries.size(); + } + + wgpu::ComputePipelineDescriptor pipeline_descriptor{}; + pipeline_descriptor.compute = compute_stage; +#ifndef NDEBUG // if debug build + pipeline_descriptor.label = program.Name().c_str(); +#endif + + compute_pipeline = device_.CreateComputePipeline(&pipeline_descriptor); + + return Status(); +} + +const ProgramArtifact* ProgramManager::Get(const std::string& key) const { + auto result = programs_.find(key); + if (result != programs_.end()) { + return &result->second; + } + + return nullptr; +} + +const ProgramArtifact* ProgramManager::Set(const std::string& key, ProgramArtifact&& program) { + return &(programs_.emplace(key, std::move(program)).first->second); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program_manager.h b/onnxruntime/core/providers/webgpu/program_manager.h new file mode 100644 index 0000000000000..eded1cfa17970 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program_manager.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include + +#include + +#include "core/common/common.h" + +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +class Tensor; + +namespace webgpu { + +class ProgramArtifact { + public: + ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline, std::vector&& shape_uniform_ranks); + + const std::string name; + const wgpu::ComputePipeline compute_pipeline; + const std::vector shape_uniform_ranks; + + ProgramArtifact(ProgramArtifact&&) = default; + ProgramArtifact& operator=(ProgramArtifact&&) = delete; + + private: + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ProgramArtifact); +}; + +class ProgramManager { + public: + ProgramManager(const wgpu::Device& device, const wgpu::Limits& limits) : device_(device), limits_(limits) {} + + Status NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint32_t& z) const; + + Status Build(const ProgramBase& program, + const ProgramMetadata& metadata, +#ifndef NDEBUG // if debug build + const std::string& program_key, +#endif + uint32_t normalized_dispatch_x, + uint32_t normalized_dispatch_y, + uint32_t normalized_dispatch_z, + wgpu::ComputePipeline& compute_pipeline, + std::vector& shape_uniform_ranks) const; + const ProgramArtifact* Get(const std::string& key) const; + const ProgramArtifact* Set(const std::string& key, ProgramArtifact&& program); + + private: + std::unordered_map programs_; + const wgpu::Device& device_; + const wgpu::Limits& limits_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc new file mode 100644 index 0000000000000..5685494556248 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -0,0 +1,530 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include + +#include "core/session/onnxruntime_c_api.h" + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/string_utils.h" +#include "core/providers/webgpu/string_macros.h" + +namespace onnxruntime { +namespace webgpu { + +ShaderHelper::ShaderHelper(const ProgramBase& program, + const ProgramMetadata& program_metadata, + const wgpu::Device& device, + const wgpu::Limits& limits, + uint32_t dispatch_group_size_x, + uint32_t dispatch_group_size_y, + uint32_t dispatch_group_size_z) + : device_{device}, + limits_{limits}, + dispatch_group_size_x_{dispatch_group_size_x}, + dispatch_group_size_y_{dispatch_group_size_y}, + dispatch_group_size_z_{dispatch_group_size_z}, + program_{program}, + program_metadata_{program_metadata}, + additional_implementation_ss_{&additional_implementation_}, + body_ss_{&body_} {} + +Status ShaderHelper::Init() { + // dispatch group size is normalized so no need to validate it here + + // validate workgroup size + auto workgroup_size_x = program_.WorkgroupSizeX(); + auto workgroup_size_y = program_.WorkgroupSizeY(); + auto workgroup_size_z = program_.WorkgroupSizeZ(); + + ORT_RETURN_IF_NOT(workgroup_size_x <= limits_.maxComputeWorkgroupSizeX && + workgroup_size_y <= limits_.maxComputeWorkgroupSizeY && + workgroup_size_z <= limits_.maxComputeWorkgroupSizeZ, + "Workgroup size exceeds the maximum allowed size [", + limits_.maxComputeWorkgroupSizeX, ", ", + limits_.maxComputeWorkgroupSizeY, ", ", + limits_.maxComputeWorkgroupSizeZ, "]"); + + ORT_RETURN_IF_NOT(workgroup_size_x * workgroup_size_y * workgroup_size_z <= limits_.maxComputeInvocationsPerWorkgroup, + "Workgroup size exceeds the maximum allowed invocations ", limits_.maxComputeInvocationsPerWorkgroup); + + // init body string stream + bool is_1d_dispatch = dispatch_group_size_y_ == 1 && dispatch_group_size_z_ == 1; + body_.reserve(4096); + additional_implementation_.reserve(1024); + + // append header for main function so it is ready for user to append main function body + body_ss_ << "@compute @workgroup_size(workgroup_size_x, workgroup_size_y, workgroup_size_z)\n" + "fn main(@builtin(global_invocation_id) global_id : vec3,\n" + " @builtin(workgroup_id) workgroup_id : vec3,\n" + " @builtin(local_invocation_index) local_idx : u32,\n" + " @builtin(local_invocation_id) local_id : vec3"; + if (!is_1d_dispatch) { + body_ss_ << ",\n" + " @builtin(num_workgroups) num_workgroups : vec3"; + } + body_ss_ << ") {\n"; + if (is_1d_dispatch) { + body_ss_ << " let global_idx = global_id.x;\n" + " let workgroup_idx = workgroup_id.x;\n"; + } else { + body_ss_ << " let workgroup_idx = workgroup_id.z * num_workgroups[0] * num_workgroups[1] + workgroup_id.y * num_workgroups[0] + workgroup_id.x;\n" + " let global_idx = workgroup_idx * (workgroup_size_x * workgroup_size_y * workgroup_size_z) + local_idx;\n"; + } + + return Status::OK(); +} + +const ShaderVariableHelper& ShaderHelper::AddInput(const std::string& name, ShaderUsage usage) { + const size_t input_index = input_vars_.size(); + ORT_ENFORCE(input_index < program_.Inputs().size(), + "Too many inputs in the program (", program_.Inputs().size(), ")"); + + const auto& dims = program_.Inputs()[input_index].use_override_shape ? program_.Inputs()[input_index].override_shape + : program_.Inputs()[input_index].tensor->Shape(); + return AddVariableImpl(true, name, usage, dims); +} + +const ShaderVariableHelper& ShaderHelper::AddOutput(const std::string& name, ShaderUsage usage) { + const size_t output_index = output_vars_.size(); + ORT_ENFORCE(output_index < program_.Outputs().size(), + "Too many outputs in the program (", program_.Outputs().size(), ")"); + + const auto& dims = program_.Outputs()[output_index].use_override_shape ? program_.Outputs()[output_index].override_shape + : program_.Outputs()[output_index].tensor->Shape(); + return AddVariableImpl(false, name, usage, dims); +} + +const ShaderIndicesHelper& ShaderHelper::AddIndices(const std::string& name, bool use_uniform) { + const size_t indices_index = indices_vars_.size(); + return *indices_vars_.emplace_back( + std::make_unique(name, + ProgramVariableDataType::InvalidType, + use_uniform ? ShaderUsage::UseUniform : ShaderUsage::None, + program_.Indices()[indices_index])); +} + +#ifndef NDEBUG // if debug build +namespace { +// Validate if the tensor element type matches the program variable data type +Status ValidateVariableDataType(int32_t element_type, ProgramVariableDataType var_type) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Float32 || + var_type == ProgramVariableDataType::Float32x2 || + var_type == ProgramVariableDataType::Float32x4, + "Unexpected program variable type ", int(var_type), " for float32 tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Float16 || + var_type == ProgramVariableDataType::Float16x2 || + var_type == ProgramVariableDataType::Float16x4, + "Unexpected program variable type ", int(var_type), " for float16 tensor"); + + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Int32 || + var_type == ProgramVariableDataType::Int32x2 || + var_type == ProgramVariableDataType::Int32x4, + "Unexpected program variable type ", int(var_type), " for int32 tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Uint32 || + var_type == ProgramVariableDataType::Uint32x2 || + var_type == ProgramVariableDataType::Uint32x4, + "Unexpected program variable type ", int(var_type), " for uint32 tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Int64, + "Unexpected program variable type ", int(var_type), " for int64 tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Uint64, + "Unexpected program variable type ", int(var_type), " for uint64 tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Boolx4, + "Unexpected program variable type ", int(var_type), " for bool tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Uint8x4 || + var_type == ProgramVariableDataType::Uint8x8 || + var_type == ProgramVariableDataType::Uint8x16, + "Unexpected program variable type ", int(var_type), " for uint8 tensor"); + break; + default: + ORT_RETURN_IF(true, "Unsupported data type: ", element_type); + // todo: add int4/uint4 + } + return Status::OK(); +} + +// Validate if the number of components and override shape match the original shape +Status ValidateVariableShape(const TensorShape& origin_shape, + bool use_override_shape, + const TensorShape& override_shape, + int num_components) { + if (use_override_shape) { + // if override shape specified, assert override_size == ceil( origin_size / 4 ) + ORT_RETURN_IF_NOT((origin_shape.Size() + num_components - 1) / num_components == override_shape.Size(), + "Tensor original shape ", origin_shape, " cannot reshape to ", override_shape, " with component number ", num_components); + } + + return Status::OK(); +} + +// Validate if the dependency and variable usage match +Status ValidateVariableDependency(ProgramTensorMetadataDependency dependency, ShaderUsage usage, bool is_input) { + bool dependency_rank = (dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank; + bool dependency_shape = (dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape; + bool dependency_type = (dependency & ProgramTensorMetadataDependency::Type) == ProgramTensorMetadataDependency::Type; + + // if dependency is already set for shape, it is no need to set for rank. + ORT_RETURN_IF(dependency_rank && dependency_shape, + "Dependency cannot set for both \"Rank\" and \"Shape\"."); + + // if dependency is set for shape, it's already part of the shader cache. no need to use uniform. + ORT_RETURN_IF(dependency_shape && (usage & ShaderUsage::UseUniform), + "Dependency is set for \"Shape\", using uniform for shape is not allowed."); + + // for input variable, check is more strict. + // this is because usually output shape is determined by the existing information, which is already part of the shader cache. + if (is_input) { + // if dependency is not set for type, should not use type alias for element and value. + // storage type is always used. so setting not depending on type is at user's own risk. + ORT_RETURN_IF(!dependency_type && (usage & (ShaderUsage::UseElementTypeAlias | ShaderUsage::UseValueTypeAlias)), + "Input dependency is not set for \"Type\", but type alias for element type or value type is used."); + + // if dependency is not set for rank and shape, the shader should not use shape and stride. + ORT_RETURN_IF(!dependency_rank && !dependency_shape && (usage & ShaderUsage::UseShapeAndStride), + "Input dependency is set for neither \"Rank\" nor \"Shape\", but variable shape and stride is used."); + } + + return Status::OK(); +} +} // namespace + +Status ShaderHelper::ValidateVariable(const ProgramInput& input, const ShaderVariableHelper& var) const { + ORT_RETURN_IF_ERROR(ValidateVariableDataType(input.tensor->GetElementType(), var.type_)); + ORT_RETURN_IF_ERROR(ValidateVariableShape(input.tensor->Shape(), + input.use_override_shape, + input.use_override_shape ? input.override_shape : input.tensor->Shape(), + var.num_components_)); + ORT_RETURN_IF_ERROR(ValidateVariableDependency(input.dependency, var.usage_, true)); + + return Status::OK(); +} +Status ShaderHelper::ValidateVariable(const ProgramOutput& output, const ShaderVariableHelper& var) const { + ORT_RETURN_IF_ERROR(ValidateVariableDataType(output.tensor->GetElementType(), var.type_)); + ORT_RETURN_IF_ERROR(ValidateVariableShape(output.tensor->Shape(), + output.use_override_shape, + output.use_override_shape ? output.override_shape : output.tensor->Shape(), + var.num_components_)); + ORT_RETURN_IF_ERROR(ValidateVariableDependency(output.dependency, var.usage_, false)); + + return Status::OK(); +} + +#endif // NDEBUG + +const ShaderVariableHelper& ShaderHelper::AddVariableImpl(bool is_input, + const std::string& name, + ShaderUsage usage, + const TensorShape& dims) { + ORT_ENFORCE(input_vars_.size() + output_vars_.size() < limits_.maxStorageBuffersPerShaderStage, + "Too many storage buffers in shader. Max is ", limits_.maxStorageBuffersPerShaderStage); + + ProgramVariableDataType type = ProgramVariableDataType::InvalidType; + auto& vars = is_input ? input_vars_ : output_vars_; + + if (is_input) { + const auto& input = program_.Inputs()[vars.size()]; + type = input.var_type; + } else { + const auto& output = program_.Outputs()[vars.size()]; + type = output.var_type; + } + + const auto& var = vars.emplace_back(std::make_unique(name, type, usage, dims)); + return *var; +} + +Status ShaderHelper::ValidateShapeForInputs() const { + // Validate input as dependencies of shape_uniforms + ORT_RETURN_IF_NOT(input_vars_.size() == program_.Inputs().size(), + "Mismatched input variable count. Shader: ", input_vars_.size(), ", Program: ", program_.Inputs().size()); + for (size_t i = 0; i < input_vars_.size(); i++) { +#ifndef NDEBUG // if debug build + // Validate input shape + ORT_RETURN_IF_ERROR(ValidateVariable(program_.Inputs()[i], *input_vars_[i])); +#endif + + // check input dependencies with actual usages. + auto usage = input_vars_[i]->usage_; + auto dependency = program_.Inputs()[i].dependency; + bool use_rank = (dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank; + bool use_shape = (dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape; + + if (usage & ShaderUsage::UseShapeAndStride) { + if (usage & ShaderUsage::UseUniform) { + ORT_RETURN_IF_NOT((use_rank || input_vars_[i]->rank_ < 2) && !use_shape, + "When UseUniform is set in variable usage, the corresponding program input should depend on rank but not shape."); + } else { + ORT_RETURN_IF_NOT(use_shape, + "When UseUniform is not set in variable usage, the corresponding program input should depend on shape."); + // If you want neither hard-coded shape nor shape uniform, use a flattened shape (rank=1). + // This will not generate any shape variables in the shader, can you can only use offset to set/get values. + } + } + } + return Status::OK(); +} + +Status ShaderHelper::ValidateShapeForOutputs() const { + // Validate output as dependencies of shape_uniforms + ORT_RETURN_IF_NOT(output_vars_.size() == program_.Outputs().size(), + "Mismatched output variable count. Shader: ", output_vars_.size(), ", Program: ", program_.Outputs().size()); + + for (size_t i = 0; i < output_vars_.size(); i++) { +#ifndef NDEBUG // if debug build + // Validate output shape + ORT_RETURN_IF_ERROR(ValidateVariable(program_.Outputs()[i], *output_vars_[i])); +#endif + + // check output dependencies with actual usages. + auto usage = output_vars_[i]->usage_; + auto dependency = program_.Outputs()[i].dependency; + bool use_shape = (dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape; + + if (usage & ShaderUsage::UseShapeAndStride) { + if (usage & ShaderUsage::UseUniform) { + // output tensor shape check is looser than input tensor shape check, because output shape is always calculated so it is not + // necessarily a part of the cache key. + ORT_RETURN_IF_NOT(!use_shape, + "When UseUniform is set in variable usage, the corresponding program output should not depend on shape."); + } else { + ORT_RETURN_IF_NOT(use_shape, + "When UseUniform is not set in variable usage, the corresponding program output should depend on shape."); + } + } + } + return Status::OK(); +} + +Status ShaderHelper::ValidateIndices() const { + ORT_RETURN_IF_NOT(indices_vars_.size() == program_.Indices().size(), + "Mismatched indices variable count. Shader: ", indices_vars_.size(), ", Program: ", program_.Indices().size()); + + return Status::OK(); +} + +Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& shape_uniform_ranks) const { + SS(ss, kStringInitialSizeShaderSourceCode); + + // + // Section feature enabling + // + if (std::any_of(program_.Inputs().begin(), + program_.Inputs().end(), + [](const ProgramInput& input) { + return input.tensor->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + }) || + std::any_of(program_.Outputs().begin(), + program_.Outputs().end(), + [](const ProgramOutput& output) { + return output.tensor->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + })) { + ORT_RETURN_IF_NOT(device_.HasFeature(wgpu::FeatureName::ShaderF16), "Program ", program_.Name(), " requires f16 but the device does not support it."); + ss << "enable f16;\n"; + if (device_.HasFeature(wgpu::FeatureName::SubgroupsF16)) { + ss << "enable subgroups_f16;\n"; + } + } + if (device_.HasFeature(wgpu::FeatureName::Subgroups)) { + ss << "enable subgroups;\n"; + } + + // + // Section constants + // + ss << "const workgroup_size_x: u32 = " << (program_.WorkgroupSizeX() == 0 ? uint32_t(WORKGROUP_SIZE) : program_.WorkgroupSizeX()) + << ";\nconst workgroup_size_y: u32 = " << (program_.WorkgroupSizeY() == 0 ? uint32_t(1) : program_.WorkgroupSizeY()) + << ";\nconst workgroup_size_z: u32 = " << (program_.WorkgroupSizeZ() == 0 ? uint32_t(1) : program_.WorkgroupSizeZ()) + << ";\n"; + + for (const auto& constant : program_metadata_.constants) { + ss << "const " << constant.name << ": " << constant.type << " = "; + WriteConstantValue(ss, constant); + ss << ";\n"; + } + + size_t override_constant_count = program_metadata_.overridable_constants.size(); + for (size_t i = 0; i < override_constant_count; ++i) { + // size and type are previously checked to match + const auto& constant_def = program_metadata_.overridable_constants[i]; + const auto& constant_override = program_.OverridableConstants()[i]; + + ss << "override " << constant_def.name << ": " << constant_def.type << " = "; + if (constant_override.has_value) { + WriteConstantValue(ss, constant_override); + } else { + WriteConstantValue(ss, constant_def); + } + ss << ";\n"; + } + + // + // Input/output variables + // + size_t variable_count = 0; + for (const auto& input : input_vars_) { + ss << "@group(0) @binding(" << variable_count++ << ") var " << input->name_ << ": array<" << input->StorageType() << ">;\n"; + } + for (const auto& output : output_vars_) { + ss << "@group(0) @binding(" << variable_count++ << ") var " << output->name_ << ": array<" << output->StorageType() << ">;\n"; + } + + // + // uniform variables + // + + // store shape uniform ranks in shape_uniform_ranks + bool use_any_shape_uniform = false; + ORT_ENFORCE(shape_uniform_ranks.size() == 0); + shape_uniform_ranks.reserve(input_vars_.size() + output_vars_.size() + indices_vars_.size()); + + for (const auto& input : input_vars_) { + bool use_uniform = (input->usage_ & ShaderUsage::UseUniform) && + (input->usage_ & ShaderUsage::UseShapeAndStride) && + input->rank_ > 0; + use_any_shape_uniform |= use_uniform; + shape_uniform_ranks.push_back(use_uniform ? input->rank_ : 0); + } + for (const auto& output : output_vars_) { + bool use_uniform = (output->usage_ & ShaderUsage::UseUniform) && + (output->usage_ & ShaderUsage::UseShapeAndStride) && + output->rank_ > 0; + use_any_shape_uniform |= use_uniform; + shape_uniform_ranks.push_back(use_uniform ? output->rank_ : 0); + } + for (const auto& indices : indices_vars_) { + bool use_uniform = (indices->usage_ & ShaderUsage::UseUniform) && + (indices->usage_ & ShaderUsage::UseShapeAndStride) && + indices->rank_ > 0; + use_any_shape_uniform |= use_uniform; + shape_uniform_ranks.push_back(use_uniform ? indices->rank_ : 0); + } + + if (use_any_shape_uniform || std::any_of(program_.UniformVariables().cbegin(), + program_.UniformVariables().cend(), + [](const ProgramUniformVariableValue& x) { return x.length > 0; })) { + bool first = true; + ss << "struct Uniforms {"; + + // lambda append_uniform is used to append one uniform variable to the uniform struct + auto append_uniform = [&ss, &first](std::string_view name, ProgramUniformVariableDataType data_type, size_t length) { + if (length == 0) { + return; + } + + if (first) { + first = false; + } else { + ss << ","; + } + + auto alignment = (data_type == ProgramUniformVariableDataType::Float16 && length > 4) ? "@align(16) " : ""; + ss << "\n " << alignment << name << ": "; + if (length > 4) { + if (data_type == ProgramUniformVariableDataType::Float16) { + size_t array_size = (length + 7) / 8; + ss << "array, " << array_size << ">"; + } else { + size_t array_size = (length + 3) / 4; + ss << "array, " << array_size << ">"; + } + } else if (length > 1) { + ss << "vec" << length << "<" << data_type << ">"; + } else { + ss << data_type; + } + }; + + for (const auto& input : input_vars_) { + const size_t rank = input->rank_; + if (rank > 0 && (input->usage_ & ShaderUsage::UseUniform) && (input->usage_ & ShaderUsage::UseShapeAndStride)) { + std::string shape = input->name_ + "_shape"; + std::string stride = input->name_ + "_stride"; + append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); + append_uniform(stride, ProgramUniformVariableDataType::Uint32, rank - 1); + } + } + + for (const auto& output : output_vars_) { + const size_t rank = output->rank_; + if (rank > 0 && (output->usage_ & ShaderUsage::UseUniform) && (output->usage_ & ShaderUsage::UseShapeAndStride)) { + std::string shape = output->name_ + "_shape"; + std::string stride = output->name_ + "_stride"; + append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); + append_uniform(stride, ProgramUniformVariableDataType::Uint32, rank - 1); + } + } + + for (const auto& indices : indices_vars_) { + const size_t rank = indices->rank_; + if (rank > 0 && (indices->usage_ & ShaderUsage::UseUniform) && (indices->usage_ & ShaderUsage::UseShapeAndStride)) { + std::string shape = indices->name_ + "_shape"; + std::string stride = indices->name_ + "_stride"; + append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); + append_uniform(stride, ProgramUniformVariableDataType::Uint32, rank - 1); + } + } + + for (size_t i = 0; i < program_.UniformVariables().size(); i++) { + const auto& uniform_def = program_metadata_.uniform_variables[i]; + const auto& uniform_value = program_.UniformVariables()[i]; + append_uniform(uniform_def.name, uniform_def.data_type, uniform_value.length); + } + + ss << "\n};\n" + "@group(0) @binding(" + << variable_count << ") var uniforms: Uniforms;\n"; + } + + // + // Indices helper + // + ss << "\n"; + for (const auto& var : input_vars_) { + var->Impl(ss); + } + for (const auto& var : output_vars_) { + var->Impl(ss); + } + for (const auto& var : indices_vars_) { + var->Impl(ss); + } + ss << "\n"; + + // + // Additional Implementation + // + ss << additional_implementation_; + + // + // Main Function Body + // + ss << body_; + ss << "\n" + "}\n"; + + code = SS_GET(ss); + return Status::OK(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_helper.h b/onnxruntime/core/providers/webgpu/shader_helper.h new file mode 100644 index 0000000000000..a4b96edc63c74 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/shader_helper.h @@ -0,0 +1,180 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "core/framework/tensor_shape.h" + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_variable.h" +#include "core/providers/webgpu/string_utils.h" + +namespace onnxruntime { +namespace webgpu { + +class ShaderHelper final { + // The content of a shader code is composed of the following parts: + // + // ** + // ** section: feature sets definition + // ** + // // this sections enable features like "enable f16;". need to be defined at the beginning of the shader. + // + // ** + // ** section: constants and overridable constants + // ** + // // this section defines constants and overridable constants. + // - constants are defined as "const a:f32 = 1.0;". It's hard coded in the shader. + // - overridable constants are defined as "override a:f32 = 1.0;" (may override or not) + // or "override b:u32;" (must override) + // the value can be overriden by pipeline creation config. + // + // ** + // ** section: inputs and outputs + // ** + // // this section defines input and output variables. + // user can call shader_helper.AddVariable() to add input and output variables. + // + // ** + // ** section: uniforms + // ** + // // this section defines uniform type and variables. + // + // ** + // ** section: indices helper generated utility functions + // ** + // // this section defines utility functions to calculate indices. + // + // ** + // ** section: additional implementation + // ** + // // this section contains additional implementation provided by the user. + // user can call shader_helper.AppendImplementation() to append additional implementation. + // + // ** + // ** section: main function + // ** + // // this section contains the main function of the shader. + // user can call shader_helper.MainFunctionBody() to set the main function body. + // + + public: + ShaderHelper(const ProgramBase& program, + const ProgramMetadata& program_metadata, + const wgpu::Device& device, + const wgpu::Limits& limits, + uint32_t dispatch_group_size_x, + uint32_t dispatch_group_size_y, + uint32_t dispatch_group_size_z); + + Status Init(); + + // Add an input variable to the shader. + // + // depending on the usage of the variable, additional code may be generated. + const ShaderVariableHelper& AddInput(const std::string& name, + ShaderUsage usage = ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseUniform); + + // Add an output variable to the shader. + // + // depending on the usage of the variable, additional code may be generated. + const ShaderVariableHelper& AddOutput(const std::string& name, + ShaderUsage usage = ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseUniform); + + // Add an indices variable to the shader. + const ShaderIndicesHelper& AddIndices(const std::string& name, bool use_uniform = true); + + // Get the string stream for additional implementation code to the shader. + inline OStringStream& AdditionalImplementation() { + return additional_implementation_ss_; + } + + // Get the string stream for the main function body of the shader. + inline OStringStream& MainFunctionBody() { + return body_ss_; + } + + std::string GuardAgainstOutOfBoundsWorkgroupSizes(std::string_view size) const { + return MakeStringWithClassicLocale(" if (global_idx >= ", size, ") { return; }\n"); + } + + private: + template // ConstantType is one of {ProgramConstant, ProgramOverridableConstantValue, ProgramOverridableConstantDefinition} + void WriteConstantValue(std::ostream& ss, const ConstantType& constant) const { + switch (constant.type) { + case ProgramConstantDataType::Float16: + ss << constant.f16.ToFloat(); + break; + case ProgramConstantDataType::Float32: + ss << constant.f32; + break; + case ProgramConstantDataType::Int32: + ss << constant.i32; + break; + case ProgramConstantDataType::Uint32: + ss << constant.u32; + break; + case ProgramConstantDataType::Bool: + ss << (constant.boolean ? "true" : "false"); + break; + default: + ORT_THROW("Invalid constant type", constant.type); + } + } + + const ShaderVariableHelper& AddVariableImpl(bool is_input, + const std::string& name, + ShaderUsage usage, + const TensorShape& dims); + +#ifndef NDEBUG // if debug build + Status ValidateVariable(const ProgramInput& input, const ShaderVariableHelper& var) const; + Status ValidateVariable(const ProgramOutput& output, const ShaderVariableHelper& var) const; +#endif + + Status ValidateShapeForInputs() const; + Status ValidateShapeForOutputs() const; + Status ValidateIndices() const; + + // Generate source code. + // + // This function: + // - performs validation if neccessary, + // - appends the ranks for variables to the shape_uniform_ranks. + // (The rank value is zero if no uniform is needed for the variable.) + // - generates the final source code. + // + // \param code The generated full WGSL source code. + // \param shape_uniform_ranks The ranks for variables that need a uniform for the shape. + // + Status GenerateSourceCode(std::string& code, std::vector& shape_uniform_ranks) const; + friend class ProgramManager; + + const wgpu::Device& device_; + const wgpu::Limits& limits_; + uint32_t dispatch_group_size_x_; + uint32_t dispatch_group_size_y_; + uint32_t dispatch_group_size_z_; + + const ProgramBase& program_; + const ProgramMetadata& program_metadata_; + + std::vector> input_vars_; + std::vector> output_vars_; + std::vector> indices_vars_; + std::string additional_implementation_; + OStringStream additional_implementation_ss_; + std::string body_; + OStringStream body_ss_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc new file mode 100644 index 0000000000000..15020b801c97d --- /dev/null +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -0,0 +1,329 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include + +#include "core/providers/webgpu/shader_variable.h" + +#include "core/providers/webgpu/string_macros.h" + +namespace onnxruntime { +namespace webgpu { + +namespace { +constexpr static const std::string_view STORAGE_TYPE_ARRAY[] = { + "f32", // Float32 + "vec2", // Float32x2 + "vec4", // Float32x4 + "f16", // Float16 + "vec2", // Float16x2 + "vec4", // Float16x4 + "i32", // Int32 + "vec2", // Int32x2 + "vec4", // Int32x4 + "u32", // Uint32 + "vec2", // Uint32x2 + "vec4", // Uint32x4 + "vec2", // Int64 + "vec2", // Uint64 + "u32", // Boolx4 + "u32", // Uint8x4 + "vec2", // Uint8x8 + "vec4", // Uint8x16 +}; +constexpr static const auto STORAGE_TYPE = details::_to_std_array(STORAGE_TYPE_ARRAY); + +constexpr static const std::string_view VALUE_TYPE_ARRAY[] = { + "f32", // Float32 + "vec2", // Float32x2 + "vec4", // Float32x4 + "f16", // Float16 + "vec2", // Float16x2 + "vec4", // Float16x4 + "i32", // Int32 + "vec2", // Int32x2 + "vec4", // Int32x4 + "u32", // Uint32 + "vec2", // Uint32x2 + "vec4", // Uint32x4 + "i32", // Int64 (trancated to i32) + "u32", // Uint64 (trancated to u32) + "vec4", // Boolx4 + "u32", // Uint8x4 (u32 as 4 elements of uint8) + "vec2", // Uint8x8 (vec2 as 2x4 elements of uint8) + "vec4", // Uint8x16 (vec4 as 4x4 elements of uint8) +}; +constexpr static const auto VALUE_TYPE = details::_to_std_array(VALUE_TYPE_ARRAY); + +constexpr static const std::string_view ELEMENT_TYPE_ARRAY[] = { + "f32", // Float32 + "f32", // Float32x2 + "f32", // Float32x4 + "f16", // Float16 + "f16", // Float16x2 + "f16", // Float16x4 + "i32", // Int32 + "i32", // Int32x2 + "i32", // Int32x4 + "u32", // Uint32 + "u32", // Uint32x2 + "u32", // Uint32x4 + "i32", // Int64 + "u32", // Uint64 + "bool", // Boolx4 + "u32", // Uint8x4 + "u32", // Uint8x8 + "u32", // Uint8x16 +}; +constexpr static const auto ELEMENT_TYPE = details::_to_std_array(ELEMENT_TYPE_ARRAY); + +inline std::string GetIndicesType(int rank) { + return rank < 2 ? "u32" + : (rank <= 4 ? MakeStringWithClassicLocale("vec", rank, "") + : MakeStringWithClassicLocale("array")); +} + +} // namespace + +ShaderIndicesHelper::ShaderIndicesHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims) + : name_(name), + type_(type), + num_components_{NumberOfComponents(type)}, + rank_{gsl::narrow(dims.NumDimensions())}, + dims_{dims}, + usage_(usage), + indices_type_{GetIndicesType(rank_)}, + value_type_alias_{name_ + "_value_t"}, + element_type_alias_{name_ + "_element_t"}, + indices_type_alias_{name_ + "_indices_t"} {} + +ShaderVariableHelper::ShaderVariableHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims) + : ShaderIndicesHelper{name, type, usage, dims} { + ORT_ENFORCE(type_ != ProgramVariableDataType::InvalidType, "Invalid type for variable ", name_); + ORT_ENFORCE(num_components_ > 0, "Invalid number of components for variable ", name_); +} + +void ShaderIndicesHelper::Impl(std::ostream& ss) const { + // Start generating code + + const std::string shape = (usage_ & ShaderUsage::UseUniform) ? "uniforms." + name_ + "_shape" : name_ + "_shape"; + const std::string stride = (usage_ & ShaderUsage::UseUniform) ? "uniforms." + name_ + "_stride" : name_ + "_stride"; + + // Types + if (usage_ & ShaderUsage::UseValueTypeAlias) { + SS_APPEND(ss, "alias ", value_type_alias_, " = ", VALUE_TYPE[static_cast(type_)], ";\n"); + } + if (usage_ & ShaderUsage::UseIndicesTypeAlias) { + SS_APPEND(ss, "alias ", indices_type_alias_, " = ", indices_type_, ";\n"); + } + if (usage_ & ShaderUsage::UseElementTypeAlias) { + SS_APPEND(ss, "alias ", element_type_alias_, " = ", ELEMENT_TYPE[static_cast(type_)], ";\n"); + } + + // Need shape and strides when (not use uniform) and (use shape and stride is enabled) + if (!(usage_ & ShaderUsage::UseUniform) && (usage_ & ShaderUsage::UseShapeAndStride) && rank_ > 0) { + SS_APPEND(ss, "const ", shape, " = ", IndicesType(), "("); + + bool first = true; + for (auto dim : dims_.GetDims()) { + if (!first) { + ss << ","; + } + + ss << dim; + first = false; + } + ss << ");\n"; + + if (rank_ > 1) { + SS_APPEND(ss, "const ", stride, " = ", GetIndicesType(rank_ - 1), "("); + first = true; + for (int i = 1; i < rank_; i++) { + if (!first) { + ss << ","; + } + ss << dims_.SizeFromDimension(i); + first = false; + } + ss << ");\n"; + } + } + + // Implementation of "fn o2i_{name}" + if (usage_ & ShaderUsage::UseOffsetToIndices) { + if (rank_ >= 2) { + SS_APPEND(ss, "fn o2i_", name_, "(offset : u32)->", IndicesType(), " {\n"); + SS_APPEND(ss, " var indices: ", IndicesType(), ";\n"); + SS_APPEND(ss, " var current = offset;\n"); + for (int i = 0; i < rank_ - 1; i++) { + auto current_stride = GetElementAt(stride, i, rank_ - 1); + SS_APPEND(ss, " let dim", i, " = current / ", current_stride, ";\n"); + SS_APPEND(ss, " let rest", i, " = current % ", current_stride, ";\n"); + SS_APPEND(ss, " indices[", i, "] = dim", i, ";\n"); + SS_APPEND(ss, " current = rest", i, ";\n"); + } + SS_APPEND(ss, " indices[", rank_ - 1, "] = current;\n"); + SS_APPEND(ss, " return indices;\n"); + SS_APPEND(ss, "}\n"); + } + } + + // Implementation of "fn i2o_{name}" + if (usage_ & ShaderUsage::UseIndicesToOffset) { + if (rank_ >= 2) { + SS_APPEND(ss, "fn i2o_", name_, "(indices : ", IndicesType(), ")->u32 {\n"); + SS_APPEND(ss, " return "); + for (int i = 0; i < rank_ - 1; i++) { + SS_APPEND(ss, "indices[", i, "] * ", GetElementAt(stride, i, rank_ - 1), " + "); + } + SS_APPEND(ss, "indices[", rank_ - 1, "];\n"); + SS_APPEND(ss, "}\n"); + } + } + + // Implementation of "fn {res_name}_bi2o_{name}" + if (usage_ & ShaderUsage::UseBroadcastedIndicesToOffset) { + if (rank_ > 0) { + for (const auto& broadcasted_result_ptr : broadcasted_to_) { + const auto& broadcasted_result = *broadcasted_result_ptr; + SS_APPEND(ss, "fn ", broadcasted_result.name_, "_bi2o_", name_, "(indices : ", broadcasted_result.indices_type_, ")->u32 {\n"); + if (rank_ == 1) { + SS_APPEND(ss, " return ", broadcasted_result.IndicesGet("indices", broadcasted_result.rank_ - 1), " % ", shape, ";\n"); + } else { + SS_APPEND(ss, " return "); + for (int i = 0; i < rank_ - 1; i++) { + auto idx = broadcasted_result.IndicesGet("indices", i + broadcasted_result.rank_ - rank_); + std::string current_stride = rank_ == 2 ? stride : GetElementAt(stride, i, rank_ - 1); + SS_APPEND(ss, current_stride, " * (", idx, " % ", IndicesGet(shape, i), ") + "); + } + SS_APPEND(ss, broadcasted_result.IndicesGet("indices", broadcasted_result.rank_ - 1), " % ", IndicesGet(shape, rank_ - 1), ";\n"); + } + SS_APPEND(ss, "}\n"); + } + } + } +} + +void ShaderVariableHelper::Impl(std::ostream& ss) const { + ShaderIndicesHelper::Impl(ss); + + // Implementation of "fn set_{name}" + if (usage_ & ShaderUsage::UseSet) { + if (rank_ >= 2) { + SS_APPEND(ss, "fn set_", name_, "(d0: u32"); + for (int i = 1; i < rank_; i++) { + SS_APPEND(ss, ", d", i, ": u32"); + } + SS_APPEND(ss, ", value: ", ValueType(), ") {\n"); + SS_APPEND(ss, " set_", name_, "_by_indices(d0"); + for (int i = 1; i < rank_; i++) { + SS_APPEND(ss, ", d", i); + } + SS_APPEND(ss, ", value);\n"); + SS_APPEND(ss, "}\n"); + } + } + + // Implementation of "fn set_{name}_by_indices" + if (usage_ & ShaderUsage::UseSetByIndices) { + if (rank_ >= 2) { + SS_APPEND(ss, "fn set_", name_, "_by_indices(indices: ", IndicesType(), ", value: ", ValueType(), ") {\n"); + SS_APPEND(ss, " ", SetByOffset("i2o_" + name_ + "(indices)", "value"), "\n"); + SS_APPEND(ss, "}\n"); + } + } + + // Implementation of "fn get_{name}" + if (usage_ & ShaderUsage::UseGet) { + if (rank_ >= 2) { + SS_APPEND(ss, "fn get_", name_, "(d0: u32"); + for (int i = 1; i < rank_; i++) { + SS_APPEND(ss, ", d", i, ": u32"); + } + SS_APPEND(ss, ")->", ValueType(), " {\n"); + SS_APPEND(ss, " return get_", name_, "_by_indices(d0"); + for (int i = 1; i < rank_; i++) { + SS_APPEND(ss, ", d", i); + } + SS_APPEND(ss, ");\n"); + SS_APPEND(ss, "}\n"); + } + } + + // Implementation of "fn get_{name}_by_indices" + if (usage_ & ShaderUsage::UseGetByIndices) { + if (rank_ >= 2) { + SS_APPEND(ss, "fn get_", name_, "_by_indices(indices: ", IndicesType(), ")->", ValueType(), " {\n"); + SS_APPEND(ss, " return ", GetByOffset("i2o_" + name_ + "(indices)"), ";\n"); + SS_APPEND(ss, "}\n"); + } + } +} + +std::string ShaderVariableHelper::GetByOffsetImpl(std::string_view offset) const { + SS(ss, kStringInitialSizeGetByOffsetImpl); + + switch (type_) { + case onnxruntime::webgpu::ProgramVariableDataType::InvalidType: + ORT_THROW("Invalid type"); + break; + case onnxruntime::webgpu::ProgramVariableDataType::Int64: + case onnxruntime::webgpu::ProgramVariableDataType::Uint64: + ss << ElementType() << "(" << name_ << "[" << offset << "].x)"; + break; + case onnxruntime::webgpu::ProgramVariableDataType::Boolx4: + ss << "vec4(bool(" + << name_ << "[" << offset << "] & 0xFFu), bool(" + << name_ << "[" << offset << "] & 0xFF00u), bool(" + << name_ << "[" << offset << "] & 0xFF0000u), bool(" + << name_ << "[" << offset << "] & 0xFF000000u))"; + break; + default: + ss << name_ << "[" << offset << "]"; + } + + return SS_GET(ss); +} + +std::string ShaderVariableHelper::SetByOffsetImpl(std::string_view offset, std::string_view value) const { + SS(ss, kStringInitialSizeSetByOffsetImpl); + + switch (type_) { + case onnxruntime::webgpu::ProgramVariableDataType::InvalidType: + ORT_THROW("Invalid type"); + break; + case onnxruntime::webgpu::ProgramVariableDataType::Int64: + ss << name_ << "[" << offset << "]=vec2(u32(" << value << "), select(0u, 0xFFFFFFFFu, " << value << " < 0));"; + break; + case onnxruntime::webgpu::ProgramVariableDataType::Uint64: + ss << name_ << "[" << offset << "]=vec2(u32(" << value << "), 0u);"; + break; + case onnxruntime::webgpu::ProgramVariableDataType::Boolx4: + ss << name_ << "[" << offset << "]=dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(" << value << "));"; + break; + default: + ss << name_ << "[" << offset << "]=" << value << ";"; + } + + return SS_GET(ss); +} + +std::string_view ShaderVariableHelper::StorageType() const { + return STORAGE_TYPE[static_cast(type_)]; +} + +std::string_view ShaderVariableHelper::ValueType() const { + return (usage_ & ShaderUsage::UseValueTypeAlias) ? value_type_alias_ : VALUE_TYPE[static_cast(type_)]; +} + +std::string_view ShaderVariableHelper::ElementType() const { + return (usage_ & ShaderUsage::UseElementTypeAlias) ? element_type_alias_ : ELEMENT_TYPE[static_cast(type_)]; +} + +std::string_view ShaderIndicesHelper::IndicesType() const { + return (usage_ & ShaderUsage::UseIndicesTypeAlias) ? indices_type_alias_ : indices_type_; +} +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h new file mode 100644 index 0000000000000..4c87bc9158890 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -0,0 +1,340 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "core/framework/tensor_shape.h" + +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +template || std::is_same_v>> +std::string GetElementAt(std::string_view var, const TIdx& idx, TRank rank, bool is_f16 = false) { + // "std::string::rfind(str, 0) == 0" is equivalent to "std::string::starts_with(str)" before C++20. + if (var.rfind("uniforms.", 0) == 0) { + if (rank > 4) { + if constexpr (std::is_integral_v) { + if (is_f16) { + return MakeStringWithClassicLocale(var, "[", idx / 8, "][", (idx % 8) / 4, "][", (idx % 8) % 4, "]"); + } else { + return MakeStringWithClassicLocale(var, "[", idx / 4, "][", idx % 4, "]"); + } + } else { + if (is_f16) { + return MakeStringWithClassicLocale(var, "[(", idx, ") / 8][(", idx, ") % 8 / 4][(", idx, ") % 8 % 4]"); + } else { + return MakeStringWithClassicLocale(var, "[(", idx, ") / 4][(", idx, ") % 4]"); + } + } + } + } + + return rank > 1 ? MakeStringWithClassicLocale(var, "[", idx, "]") : std::string{var}; +} + +struct ShaderUsage { + enum : uint32_t { + None = 0, // no usage. this means no additional implementation code will be generated. + UseIndicesTypeAlias = 1, // use type alias "{name}_indices_t" for indices (eg. u32, vec2, vec3, vec4, ...) + UseValueTypeAlias = 2, // use type alias "{name}_value_t" for value (eg. f32, vecT, vec4, ...) + UseElementTypeAlias = 4, // use type alias "{name}_element_t" for element (eg. f32, bool, ...) + UseShapeAndStride = 16, // use shape and stride for the variable + UseOffsetToIndices = 32, // use implementation of fn o2i_{name} + UseIndicesToOffset = 64, // use implementation of fn i2o_{name} + UseBroadcastedIndicesToOffset = 128, // use implementation of fn {broadcasted_result_name}_bi2o_{name} + UseSet = 256, // use implementation of fn set_{name} + UseSetByIndices = 512, // use implementation of fn set_{name}_by_indices + UseGet = 1024, // use implementation of fn get_{name} + UseGetByIndices = 2048, // use implementation of fn get_{name}_by_indices + UseUniform = 32768, // use uniform for shape and stride + } usage; + + ShaderUsage(decltype(usage) usage) : usage{usage} {} + ShaderUsage(uint32_t usage) : usage{usage} {} + + explicit operator bool() { + return usage != None; + } +}; + +// A helper class to make it easier to generate shader code related to indices calculation. +class ShaderIndicesHelper { + public: + ShaderIndicesHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims); + + ShaderIndicesHelper(ShaderIndicesHelper&&) = default; + ShaderIndicesHelper& operator=(ShaderIndicesHelper&&) = default; + + // get the number of components of the variable. + inline int NumComponents() const { return num_components_; } + + // get the rank of the indices. + inline int Rank() const; + + // create a WGSL expression ({varname}_indices_t) for getting indices from offset. + // \param offset: a WGSL expression (u32) representing the offset. + inline std::string OffsetToIndices(std::string_view offset_expr) const; + + // create a WGSL expression (u32) for getting offset from indices. + // \param indices: a WGSL expression ({varname}_indices_t) representing the indices. + inline std::string IndicesToOffset(std::string_view indices_expr) const; + + // create a WGSL expression (u32) for getting original offset from broadcasted indices. + // \param indices: a WGSL expression ({broadcasted_result_varname}_indices_t) representing the broadcasted indices. + // \param broadcasted_result: the broadcasted result variable. + inline std::string BroadcastedIndicesToOffset(std::string_view indices_expr, const ShaderIndicesHelper& broadcasted_result) const; + + // create a WGSL expression ({varname}_indices_t) as an indices literal + // \param init: a list of indices values. + template + inline std::string Indices(TIndices&&... indices_args) const; + + // create a WGSL statement for setting value of the specified dimension of the indices. + // \param indices_var: name of the indices variable ({varname}_indices_t). + // \param idx: the index (i32|u32) of the dimension to set. + // \param value: the value (u32) to set. + template + inline std::string IndicesSet(std::string_view indices_var, const TIdx& idx_expr, const TVal& value) const; + + // create a WGSL expression (u32) for getting value of the specified dimension of the indices. + // \param indices_var: name of the indices variable ({varname}_indices_t). + // \param idx: the index (i32|u32) of the dimension to get. + template + inline std::string IndicesGet(std::string_view indices_var, const TIdx& idx_expr) const; + + protected: + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderIndicesHelper); + + void Impl(std::ostream& ss) const; + + std::string_view IndicesType() const; + + std::string name_; + ProgramVariableDataType type_; // for variable + int num_components_; // for variable + int rank_; + TensorShape dims_; + + mutable ShaderUsage usage_; + // the pointers stored here are owned by the ShaderHelper instance that also owns this ShaderIndicesHelper instance. + // these instances are kept valid during the lifetime of the ShaderHelper instance. + mutable std::set broadcasted_to_; + + // unlike storage/element/value type, indices type is not a string view to a constant string. so we need to store it. + std::string indices_type_; + + // the alias for the types + std::string value_type_alias_; + std::string element_type_alias_; + std::string indices_type_alias_; + + friend class ShaderHelper; +}; + +// A helper class to make it easier to generate shader code related to a variable setting/getting and its indices calculation. +class ShaderVariableHelper : public ShaderIndicesHelper { + public: + ShaderVariableHelper(std::string_view name, ProgramVariableDataType type, ShaderUsage usage, const TensorShape& dims); + + ShaderVariableHelper(ShaderVariableHelper&&) = default; + ShaderVariableHelper& operator=(ShaderVariableHelper&&) = default; + + // create a WGSL statement for setting data at the given indices. + // \param args: a list of indices values (u32) followed by a value ({varname}_value_t). + template + inline std::string Set(TIndicesAndValue&&... args) const; + + // create a WGSL statement for setting data at the given indices. + // \param indices_var: name of the indices variable ({varname}_indices_t). + // \param value: the value ({varname}_value_t) to set. + inline std::string SetByIndices(std::string_view indices_var, std::string_view value) const; + + // create a WGSL statement for setting data at the given offset. + // \param offset: a WGSL expression (u32) representing the offset. + // \param value: the value ({varname}_value_t) to set. + template + inline std::string SetByOffset(TOffset&& offset, TValue&& value) const; + + // create a WGSL expression ({varname}_value_t) for getting data at the given indices. + // \param indices: a list of indices values (u32). + template + inline std::string Get(TIndices&&... indices) const; + + // create a WGSL expression ({varname}_value_t) for getting data at the given indices. + // \param indices_var: name of the indices variable ({varname}_indices_t). + inline std::string GetByIndices(std::string_view indices_var) const; + + // create a WGSL expression ({varname}_value_t) for getting data at the given offset. + // \param offset: a WGSL expression (u32) representing the offset. + template + inline std::string GetByOffset(TOffset&& offset) const; + + private: + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderVariableHelper); + + void Impl(std::ostream& ss) const; + + std::string GetByOffsetImpl(std::string_view offset) const; + std::string SetByOffsetImpl(std::string_view offset, std::string_view value) const; + std::string_view StorageType() const; + std::string_view ValueType() const; + std::string_view ElementType() const; + + friend class ShaderHelper; +}; + +inline ShaderUsage operator|(ShaderUsage a, ShaderUsage b) { + return (uint32_t)a.usage | (uint32_t)b.usage; +} +inline ShaderUsage operator&(ShaderUsage a, ShaderUsage b) { + return (uint32_t)a.usage & (uint32_t)b.usage; +} +inline ShaderUsage& operator|=(ShaderUsage& a, ShaderUsage b) { + (uint32_t&)a.usage |= (uint32_t)b.usage; + return a; +} +inline ShaderUsage& operator&=(ShaderUsage& a, ShaderUsage b) { + (uint32_t&)a.usage &= (uint32_t)b.usage; + return a; +} + +namespace detail { +template >> +std::string pass_as_string(T&& v) { + return std::to_string(std::forward(v)); +} +template +std::string_view pass_as_string(std::string_view sv) { + return sv; +} +template +std::string pass_as_string(T&& v) { + return std::forward(v); +} +} // namespace detail + +inline int ShaderIndicesHelper::Rank() const { + // getting the rank means the information is exposed to the shader. So we consider it as a usage of shape and stride. + usage_ |= ShaderUsage::UseShapeAndStride; + return rank_; +} + +inline std::string ShaderIndicesHelper::OffsetToIndices(std::string_view offset_expr) const { + usage_ |= ShaderUsage::UseOffsetToIndices | ShaderUsage::UseShapeAndStride; + return rank_ < 2 ? std::string{offset_expr} + : MakeStringWithClassicLocale("o2i_", name_, '(', offset_expr, ')'); +} + +inline std::string ShaderIndicesHelper::IndicesToOffset(std::string_view indices_expr) const { + usage_ |= ShaderUsage::UseIndicesToOffset | ShaderUsage::UseShapeAndStride; + return rank_ < 2 ? std::string{indices_expr} + : MakeStringWithClassicLocale("i2o_", name_, '(', indices_expr, ')'); +} + +inline std::string ShaderIndicesHelper::BroadcastedIndicesToOffset(std::string_view indices_expr, const ShaderIndicesHelper& broadcasted_result) const { + ORT_ENFORCE(broadcasted_result.num_components_ == -1 || + num_components_ == -1 || + broadcasted_result.num_components_ == num_components_, + "number of components should be the same for 2 variables to calculate"); + usage_ |= ShaderUsage::UseBroadcastedIndicesToOffset | ShaderUsage::UseShapeAndStride; + broadcasted_to_.insert(&broadcasted_result); + return rank_ == 0 + ? "0" + : MakeStringWithClassicLocale(broadcasted_result.name_, "_bi2o_", name_, '(', indices_expr, ')'); +} + +template +inline std::string ShaderIndicesHelper::Indices(TIndices&&... indices_args) const { + usage_ |= ShaderUsage::UseShapeAndStride; + return rank_ == 0 + ? "0" + : MakeStringWithClassicLocale(IndicesType(), "(", + absl::StrJoin(std::forward_as_tuple(std::forward(indices_args)...), ", "), + ')'); +} + +template +inline std::string ShaderIndicesHelper::IndicesSet(std::string_view indices_var, const TIdx& idx_expr, const TVal& value) const { + usage_ |= ShaderUsage::UseShapeAndStride; + return rank_ < 2 ? MakeStringWithClassicLocale(indices_var, '=', value, ';') + : MakeStringWithClassicLocale(GetElementAt(indices_var, idx_expr, rank_), '=', value, ';'); +} + +template +inline std::string ShaderIndicesHelper::IndicesGet(std::string_view indices_var, const TIdx& idx_expr) const { + usage_ |= ShaderUsage::UseShapeAndStride; + return rank_ < 2 ? std::string{indices_var} + : GetElementAt(indices_var, idx_expr, rank_); +} + +template +inline std::string ShaderVariableHelper::SetByOffset(TOffset&& offset, TValue&& value) const { + return SetByOffsetImpl(detail::pass_as_string(offset), detail::pass_as_string(value)); +} + +template +inline std::string ShaderVariableHelper::Set(TIndicesAndValue&&... args) const { + usage_ |= ShaderUsage::UseShapeAndStride; + ORT_ENFORCE(sizeof...(TIndicesAndValue) == rank_ + 1, "Number of arguments should be ", rank_ + 1, "(rank + 1)"); + if constexpr (sizeof...(TIndicesAndValue) == 1) { + return SetByOffset("0", std::forward(args)...); + } else if constexpr (sizeof...(TIndicesAndValue) == 2) { + return SetByOffset(std::forward(args)...); + } else { + usage_ |= ShaderUsage::UseSet | ShaderUsage::UseSetByIndices | ShaderUsage::UseIndicesToOffset; + return MakeStringWithClassicLocale("set_", name_, '(', + absl::StrJoin(std::forward_as_tuple(std::forward(args)...), ", "), + ");"); + } +} + +inline std::string ShaderVariableHelper::SetByIndices(std::string_view indices_var, std::string_view value) const { + usage_ |= ShaderUsage::UseShapeAndStride; + if (rank_ < 2) { + return SetByOffset(indices_var, value); + } else { + usage_ |= ShaderUsage::UseSetByIndices | ShaderUsage::UseIndicesToOffset; + return MakeStringWithClassicLocale("set_", name_, "_by_indices(", indices_var, ", ", value, ");"); + } +} + +template +inline std::string ShaderVariableHelper::GetByOffset(TOffset&& offset) const { + return GetByOffsetImpl(detail::pass_as_string(offset)); +} + +template +inline std::string ShaderVariableHelper::Get(TIndices&&... indices) const { + usage_ |= ShaderUsage::UseShapeAndStride; + ORT_ENFORCE(sizeof...(TIndices) == rank_, "Number of arguments should be ", rank_, "(rank)"); + if constexpr (sizeof...(TIndices) == 0) { + return GetByOffset("0"); + } else if constexpr (sizeof...(TIndices) == 1) { + return GetByOffset(std::forward(indices)...); + } else { + usage_ |= ShaderUsage::UseGet | ShaderUsage::UseGetByIndices | ShaderUsage::UseIndicesToOffset; + return MakeStringWithClassicLocale("get_", name_, '(', + absl::StrJoin(std::forward_as_tuple(std::forward(indices)...), ", "), + ')'); + } +} + +inline std::string ShaderVariableHelper::GetByIndices(std::string_view indices_var) const { + usage_ |= ShaderUsage::UseShapeAndStride; + if (rank_ < 2) { + return GetByOffset(indices_var); + } else { + usage_ |= ShaderUsage::UseGetByIndices | ShaderUsage::UseIndicesToOffset; + return MakeStringWithClassicLocale("get_", name_, "_by_indices(", indices_var, ")"); + } +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/string_macros.h b/onnxruntime/core/providers/webgpu/string_macros.h new file mode 100644 index 0000000000000..7821d9c49a171 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/string_macros.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/string_utils.h" + +// macro "SS" - declare an ostream variable and its string buffer +#define SS(ss, reserve_size) \ + std::string ss##_str; \ + ss##_str.reserve(reserve_size); \ + ::onnxruntime::webgpu::OStringStream ss(&ss##_str) + +// macro "SS_GET" - get the string from the ostream +#define SS_GET(ss) ss##_str + +// macro "SS_APPEND" - use function call style to append to the ostream +#define SS_APPEND(ss, ...) ::onnxruntime::webgpu::detail::OStringStreamAppend(ss, __VA_ARGS__) diff --git a/onnxruntime/core/providers/webgpu/string_utils.h b/onnxruntime/core/providers/webgpu/string_utils.h new file mode 100644 index 0000000000000..e6d7097ad6182 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/string_utils.h @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/make_string.h" +#include + +namespace onnxruntime { +namespace webgpu { + +constexpr const size_t kStringInitialSizeSetByOffsetImpl = 128; +constexpr const size_t kStringInitialSizeGetByOffsetImpl = 128; +constexpr const size_t kStringInitialSizeShaderSourceCode = 2048; +#ifndef NDEBUG +constexpr const size_t kStringInitialSizeCacheKey = 512; +#else +constexpr const size_t kStringInitialSizeCacheKey = 256; +#endif + +using OStringStream = absl::strings_internal::OStringStream; + +namespace detail { +inline void OStringStreamAppendImpl(std::ostream& /*ss*/) noexcept { +} + +template +inline void OStringStreamAppendImpl(std::ostream& ss, const T& t) noexcept { + ss << t; +} + +template +inline void OStringStreamAppendImpl(std::ostream& ss, const T& t, const Args&... args) noexcept { + OStringStreamAppendImpl(ss, t); + OStringStreamAppendImpl(ss, args...); +} + +template +inline void OStringStreamAppend(std::ostream& ss, const Args&... args) { + return OStringStreamAppendImpl(ss, ::onnxruntime::detail::if_char_array_make_ptr_t(args)...); +} + +} // namespace detail + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.cc b/onnxruntime/core/providers/webgpu/tensor/cast.cc new file mode 100644 index 0000000000000..8b5bede34e6d0 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/cast.cc @@ -0,0 +1,117 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/providers/webgpu/tensor/cast.h" + +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace webgpu { + +namespace { +const std::vector& CastOpTypeConstraints() { + // currently support boolean, integer and float types that explicitly allowed in WGSL: + // https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section + // + static std::vector types{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}; + return types; +} +} // namespace + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Cast, + kOnnxDomain, + 6, 8, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Cast, + kOnnxDomain, + 9, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Cast, + kOnnxDomain, + 13, 18, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); +ONNX_OPERATOR_KERNEL_EX( + Cast, + kOnnxDomain, + 19, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); + +Status Cast::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + auto* output_tensor = context.Output(0, input_tensor->Shape()); + int64_t size = input_tensor->Shape().Size(); + if (size == 0) { + return Status::OK(); + } + uint32_t vec_size = gsl::narrow((size + 3) / 4); + + CastProgram program{to_}; + program + .AddInput({input_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}) + .AddOutput({output_tensor, ProgramTensorMetadataDependency::None, {vec_size}, 4}) + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + {static_cast(vec_size)}, + }) + .CacheHint(std::to_string(to_)); + return context.RunProgram(program); +} + +Status CastProgram::GenerateShaderCode(ShaderHelper& sh) const { + const auto& input = sh.AddInput("x", ShaderUsage::UseUniform); + const auto& output = sh.AddOutput("y", ShaderUsage::UseUniform); + std::string expression; + switch (to_) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + expression = "vec4(a)"; + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + expression = "vec4(a)"; + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + expression = "vec4(a)"; + break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: + expression = "vec4(a)"; + break; + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + expression = "vec4(a)"; + break; + default: + ORT_NOT_IMPLEMENTED("Cast to type ", to_, " is not supported."); + } + sh.MainFunctionBody() << sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size") + << " let a = " << input.GetByOffset("global_idx") << ";\n " + << output.SetByOffset("global_idx", expression); + + return Status::OK(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.h b/onnxruntime/core/providers/webgpu/tensor/cast.h new file mode 100644 index 0000000000000..ef5c4d5d0dabe --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/cast.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +class CastProgram final : public Program { + public: + CastProgram(int32_t to) : Program{"Cast"}, to_{to} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + + private: + int32_t to_; +}; + +class Cast final : public WebGpuKernel { + public: + Cast(const OpKernelInfo& info) : WebGpuKernel(info) { + int64_t to; + Status status = info.GetAttr("to", &to); + ORT_ENFORCE(status.IsOK(), "Attribute to is not set."); + to_ = gsl::narrow(to); + + // ignore attribute 'saturate' as float8 is not supported in WebGPU + } + + Status ComputeInternal(ComputeContext& context) const override; + + private: + int32_t to_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc new file mode 100644 index 0000000000000..c708f24dcc330 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -0,0 +1,144 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "core/providers/webgpu/tensor/concat.h" + +#include "core/common/inlined_containers.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/shader_variable.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +#define WEBGPU_CONCAT_VERSIONED_KERNEL(start, end) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + Concat, \ + kOnnxDomain, \ + start, \ + end, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", WebGpuSupportedNumberTypes()), \ + Concat); + +#define WEBGPU_CONCAT_KERNEL(version) \ + ONNX_OPERATOR_KERNEL_EX( \ + Concat, \ + kOnnxDomain, \ + version, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", WebGpuSupportedNumberTypes()), \ + Concat); + +WEBGPU_CONCAT_VERSIONED_KERNEL(1, 3) +WEBGPU_CONCAT_VERSIONED_KERNEL(4, 10) +WEBGPU_CONCAT_VERSIONED_KERNEL(11, 12) +WEBGPU_CONCAT_KERNEL(13) + +void AppendCalCulateInputIndexFunction(std::ostream& os, size_t input_count) { + os << "fn calculate_input_index(index: u32) -> u32 {\n" + << " for (var i = 0u; i < " << input_count << "; i = i + 1u) {\n" + << " if (index < " << GetElementAt("uniforms.size_in_concat_axis", "i", input_count) << ") {\n" + << " return i;\n" + << " }\n" + << " }\n" + << " return " << input_count << ";\n" + << "}\n"; +} + +void AppendAssignOutputDataFunction(std::ostream& os, gsl::span inputs, const ShaderVariableHelper& output) { + os << "fn assign_output_data(global_idx: u32, input_index: u32, indices: output_indices_t) {\n"; + for (size_t i = 0; i < inputs.size(); ++i) { + if (i == 0) { + os << " if (input_index == 0u) {\n"; + } else if (i == inputs.size() - 1) { + os << " } else {\n"; + } else { + os << " } else if (input_index == " << i << "u) {\n"; + } + os << " " << output.SetByOffset("global_idx", inputs[i]->GetByIndices("indices")) << ";\n"; + } + os << " }\n" + "}\n"; +} + +Status ConcatProgram::GenerateShaderCode(ShaderHelper& shader) const { + size_t input_count = Inputs().size(); + std::vector inputs; + inputs.reserve(input_count); + for (size_t i = 0; i < input_count; ++i) { + inputs.push_back(&shader.AddInput("input_" + std::to_string(i), ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias)); + } + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + + // add implementation of fn calculate_input_index + AppendCalCulateInputIndexFunction(shader.AdditionalImplementation(), input_count); + // add implementation of fn assign_output_data + AppendAssignOutputDataFunction(shader.AdditionalImplementation(), inputs, output); + const std::string size_in_concat_axis = GetElementAt("uniforms.size_in_concat_axis", "input_index - 1", input_count); + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << " var indices = " << output.OffsetToIndices("global_idx") << ";\n" + << " let indices_axis = " << output.IndicesGet("indices", axis_) << ";\n" + << " let input_index = calculate_input_index(indices_axis);\n" + << " if (input_index != 0u) {\n" + << " " << output.IndicesSet("indices", axis_, "indices_axis - " + size_in_concat_axis) << ";\n" + << " }\n" + " assign_output_data(global_idx, input_index, indices);\n"; + return Status::OK(); +} + +Status Concat::ComputeInternal(ComputeContext& context) const { + int input_count = context.InputCount(); + InlinedTensorsVector input_tensors; + input_tensors.reserve(input_count); + for (int i = 0; i < input_count; ++i) { + input_tensors.push_back(context.Input(i)); + } + + Prepare prepare; + ORT_RETURN_IF_ERROR(PrepareForCompute(&context.KernelContext(), input_tensors, prepare)); + if (prepare.output_num_elements == 0) { + return Status::OK(); + } + + uint32_t output_size = gsl::narrow_cast(prepare.output_tensor->Shape().Size()); + + ConcatProgram program{prepare.axis}; + + std::vector sizes_in_concat_axis; + sizes_in_concat_axis.reserve(input_count); + uint32_t sum = 0; + for (int i = 0; i < input_count; ++i) { + const auto& input = prepare.inputs[i]; + if (input.tensor->Shape().Size() == 0) { + continue; + } + program.AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank}); + + auto axis_size = input.tensor->Shape()[prepare.axis]; + sum += static_cast(axis_size); + sizes_in_concat_axis.push_back(sum); + } + + size_t non_empty_input_count = sizes_in_concat_axis.size(); + + if (non_empty_input_count + 1 > context.DeviceLimits().maxStorageBuffersPerShaderStage) { + // TODO: support when input_count + 1 > maxStorageBuffersPerShaderStage, by raising the limit or run the program in multiple passes. + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "The number of storage buffer (input=", + input_count, ", output=1) exceeds the limit (", + context.DeviceLimits().maxStorageBuffersPerShaderStage, ") of the device."); + } + + program.CacheHint(absl::StrJoin(std::make_tuple(non_empty_input_count, prepare.axis), ",")) + .AddOutputs({prepare.output_tensor}) + .SetDispatchGroupSize((prepare.output_num_elements + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({gsl::span(sizes_in_concat_axis.data(), sizes_in_concat_axis.size()), + output_size}); + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.h b/onnxruntime/core/providers/webgpu/tensor/concat.h new file mode 100644 index 0000000000000..0f6e6dd327e33 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/concat.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/cpu/tensor/concatbase.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class ConcatProgram final : public Program { + public: + ConcatProgram(size_t axis) : Program{"Concat"}, axis_{axis} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"size_in_concat_axis", ProgramUniformVariableDataType::Uint32}, + {"output_size", ProgramUniformVariableDataType::Uint32}); + + private: + size_t axis_; +}; + +class Concat final : public WebGpuKernel, public ConcatBase { + public: + Concat(const OpKernelInfo& info) : WebGpuKernel(info), ConcatBase(info) { + } + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc new file mode 100644 index 0000000000000..809616660aa9e --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" + +#include "core/providers/webgpu/tensor/expand.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +Status ExpandProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size"); + if (input.NumComponents() != output.NumComponents()) { + const auto& output_indices = shader.AddIndices("output_indices"); + shader.MainFunctionBody() << " let output_indices = " << output_indices.OffsetToIndices("global_idx * 4") << ";\n" + << " let input_offset = " << input.BroadcastedIndicesToOffset("output_indices", output_indices) << ";\n " + << " let value = vec4(" << input.GetByOffset("input_offset") << ");\n" + << output.SetByOffset("global_idx", "value"); + } else { + shader.MainFunctionBody() << " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" + << " let input_offset = " << input.BroadcastedIndicesToOffset("output_indices", output) << ";\n " + << output.SetByOffset("global_idx", input.GetByOffset("input_offset")); + } + return Status::OK(); +} + +Status Expand::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + const auto* input_shape_tensor = context.Input(1); + + auto output_dims = input_shape_tensor->DataAsSpan(); + TensorShape output_shape{}; + TensorShape input_shape = input_tensor->Shape(); + ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), input_shape, output_dims, output_shape)); + + auto* output_tensor = context.Output(0, output_shape); + const int components_i = input_shape.IsScalar() ? 1 : input_shape[input_shape.NumDimensions() - 1] % 4 == 0 ? 4 + : 1; + const int components_o = output_shape.IsScalar() ? 1 : output_shape[output_shape.NumDimensions() - 1] % 4 == 0 ? 4 + : 1; + uint32_t data_size = gsl::narrow(output_shape.Size() / components_o); + + ExpandProgram program{}; + program + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, components_i}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, components_o}}) + .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + {data_size}, + }); + if (components_i != components_o) { + program.AddIndices(output_shape); + } + return context.RunProgram(program); +} + +#define WEBGPU_EXPAND_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE).InputMemoryType(OrtMemTypeCPU, 1), \ + KERNEL_CLASS); + +#define WEBGPU_EXPAND_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE).InputMemoryType(OrtMemTypeCPU, 1), \ + KERNEL_CLASS); + +WEBGPU_EXPAND_VERSIONED_KERNEL(Expand, 8, 12, Expand, WebGpuSupportedNumberTypes()) +WEBGPU_EXPAND_KERNEL(Expand, 13, Expand, WebGpuSupportedNumberTypes()) + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.h b/onnxruntime/core/providers/webgpu/tensor/expand.h new file mode 100644 index 0000000000000..046520b479257 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/expand.h @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +class ExpandProgram final : public Program { + public: + ExpandProgram() : Program{"Expand"} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32}); +}; + +class Expand final : public WebGpuKernel { + public: + Expand(const OpKernelInfo& info) : WebGpuKernel(info) {} + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/flatten.cc b/onnxruntime/core/providers/webgpu/tensor/flatten.cc new file mode 100644 index 0000000000000..11ded865b6be2 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/flatten.cc @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/tensor/flatten.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Flatten, + kOnnxDomain, + 1, 8, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .InputMemoryType(OrtMemTypeCPU, 1), + Flatten); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Flatten, + kOnnxDomain, + 9, 10, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .InputMemoryType(OrtMemTypeCPU, 1), + Flatten); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Flatten, + kOnnxDomain, + 11, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .InputMemoryType(OrtMemTypeCPU, 1), + Flatten); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Flatten, + kOnnxDomain, + 13, 20, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .InputMemoryType(OrtMemTypeCPU, 1), + Flatten); + +ONNX_OPERATOR_KERNEL_EX( + Flatten, + kOnnxDomain, + 21, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .InputMemoryType(OrtMemTypeCPU, 1), + Flatten); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/flatten.h b/onnxruntime/core/providers/webgpu/tensor/flatten.h new file mode 100644 index 0000000000000..5fc49a844b404 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/flatten.h @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/nn/flatten.h" +#include "core/framework/data_transfer_manager.h" + +namespace onnxruntime { +namespace webgpu { + +class Flatten final : public OpKernel { + public: + explicit Flatten(const OpKernelInfo& info) : OpKernel{info} { + axis_ = info.GetAttrOrDefault("axis", 1); + } + + Status Compute(OpKernelContext* context) const override { + const Tensor* input_tensor = context->Input(0); + const TensorShape& input_shape = input_tensor->Shape(); + int64_t input_rank = input_shape.NumDimensions(); + + // Handle negative axis + int64_t axis = axis_; + if (axis < 0) { + axis += input_rank; + } + + if (axis > input_rank) { + return Status(common::ONNXRUNTIME, common::FAIL, "Invalid value for axis, must be less than or equal to input_rank"); + } + + int64_t first_dim = 1; + for (int64_t i = 0; i < axis; i++) { + first_dim *= input_shape[i]; + } + + int64_t second_dim = 1; + for (int64_t i = axis; i < input_rank; i++) { + second_dim *= input_shape[i]; + } + + TensorShape output_shape({first_dim, second_dim}); + Tensor* output_tensor = context->Output(0, output_shape); + + const void* source = input_tensor->DataRaw(); + void* target = output_tensor->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (target != source) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*input_tensor, *output_tensor)); + } + + return Status::OK(); + } + + private: + int64_t axis_; +}; + +} // namespace webgpu +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/tensor/gather.cc b/onnxruntime/core/providers/webgpu/tensor/gather.cc new file mode 100644 index 0000000000000..9f6e5f2420d86 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/gather.cc @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/tensor/gather.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +Status GatherProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& data = shader.AddInput("data", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const auto& indices = shader.AddInput("input_indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size") + << " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" + << " var indices_indices = input_indices_indices_t(0);\n"; + for (int i = 0; i < indices.Rank(); i++) { + shader.MainFunctionBody() << " " << indices.IndicesSet("indices_indices", i, output.IndicesGet("output_indices", axis_ + i)) << ";\n"; + } + shader.MainFunctionBody() << " var idx = " << indices.GetByIndices("indices_indices") << ";\n" + << " if (idx < 0) {\n" + << " idx = idx + input_indices_value_t(" << data.IndicesGet("uniforms.data_shape", axis_) << ");\n" + << " }\n" + << " var data_indices : data_indices_t;\n"; + for (int i = 0, j = 0; i < data.Rank(); i++) { + if (static_cast(i) == axis_) { + shader.MainFunctionBody() << " " << data.IndicesSet("data_indices", i, "u32(idx)") << ";\n"; + j += indices.Rank(); + } else { + shader.MainFunctionBody() << " " << data.IndicesSet("data_indices", i, output.IndicesGet("output_indices", j)) << ";\n"; + j++; + } + } + + shader.MainFunctionBody() << " " << output.SetByOffset("global_idx", data.GetByIndices("data_indices")); + + return Status::OK(); +} + +Status Gather::ComputeInternal(ComputeContext& context) const { + Prepare p; + ORT_RETURN_IF_ERROR(PrepareForCompute(&context.KernelContext(), p)); + uint32_t data_size = gsl::narrow(p.output_tensor->Shape().Size()); + if (data_size == 0) { + return Status::OK(); + } + + uint32_t axis = static_cast(p.axis); + GatherProgram program{axis}; + program + .AddInputs({{p.input_tensor, ProgramTensorMetadataDependency::TypeAndRank}, + {p.indices_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutput({p.output_tensor, ProgramTensorMetadataDependency::Rank}) + .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .CacheHint(std::to_string(axis)) + .AddUniformVariables({{data_size}}); + return context.RunProgram(program); +} + +#define WEBGPU_GATHER_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE).TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), \ + KERNEL_CLASS); + +#define WEBGPU_GATHER_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE).TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), \ + KERNEL_CLASS); + +WEBGPU_GATHER_VERSIONED_KERNEL(Gather, 1, 10, Gather, WebGpuSupportedNumberTypes()) +WEBGPU_GATHER_VERSIONED_KERNEL(Gather, 11, 12, Gather, WebGpuSupportedNumberTypes()) +WEBGPU_GATHER_KERNEL(Gather, 13, Gather, WebGpuSupportedNumberTypes()) + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/gather.h b/onnxruntime/core/providers/webgpu/tensor/gather.h new file mode 100644 index 0000000000000..bebe13519ce43 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/gather.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/cpu/tensor/gatherbase.h" + +namespace onnxruntime { +namespace webgpu { + +class GatherProgram final : public Program { + public: + GatherProgram(const uint32_t axis) : Program{"Gather"}, axis_{axis} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32}); + + private: + uint32_t axis_; +}; + +class Gather final : public WebGpuKernel, public GatherBase { + public: + Gather(const OpKernelInfo& info) : WebGpuKernel(info), GatherBase(info) {} + + protected: + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/reshape.cc b/onnxruntime/core/providers/webgpu/tensor/reshape.cc new file mode 100644 index 0000000000000..9ede015a0c99c --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/reshape.cc @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/tensor/reshape.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + Reshape, + kOnnxDomain, + 21, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Reshape, + kOnnxDomain, + 19, 20, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Reshape, + kOnnxDomain, + 14, 18, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Reshape, + kOnnxDomain, + 13, 13, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Reshape, + kOnnxDomain, + 5, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/reshape.h b/onnxruntime/core/providers/webgpu/tensor/reshape.h new file mode 100644 index 0000000000000..4629598d068f7 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/reshape.h @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/op_kernel.h" +#include "core/framework/data_transfer_manager.h" +#include "core/providers/cpu/tensor/reshape_helper.h" + +namespace onnxruntime { +namespace webgpu { + +class Reshape final : public OpKernel { + public: + Reshape(const OpKernelInfo& info) + : OpKernel{info}, + allow_zero_(info.GetAttrOrDefault("allowzero", static_cast(0)) == 1) { + } + + Status Compute(OpKernelContext* context) const override { + // Copy the second input tensor into the shape vector + const Tensor* shapeTensor = context->Input(1); + if (shapeTensor == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); + if (shapeTensor->Shape().NumDimensions() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "A shape tensor must be a vector tensor, got ", shapeTensor->Shape().NumDimensions(), " dimensions"); + } + auto data_span = shapeTensor->template DataAsSpan(); + TensorShapeVector shape(data_span.begin(), data_span.end()); + const Tensor* X = context->Input(0); + if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); + const TensorShape& X_shape = X->Shape(); + + ReshapeHelper helper(X_shape, shape, allow_zero_); + + Tensor* Y = context->Output(0, TensorShape(shape)); + const void* source = X->DataRaw(); + void* target = Y->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (target != source) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*X, *Y)); + } + + return Status::OK(); + } + + private: + bool allow_zero_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/shape_op.cc b/onnxruntime/core/providers/webgpu/tensor/shape_op.cc new file mode 100644 index 0000000000000..b211d48dab1c9 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/shape_op.cc @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/cpu/tensor/shape_op.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Shape, + kOnnxDomain, + 1, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Shape, + kOnnxDomain, + 13, 14, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Shape, + kOnnxDomain, + 15, 18, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Shape, + kOnnxDomain, + 19, 20, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Shape, + kOnnxDomain, + 21, 22, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +ONNX_OPERATOR_KERNEL_EX( + Shape, + kOnnxDomain, + 23, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/squeeze.cc b/onnxruntime/core/providers/webgpu/tensor/squeeze.cc new file mode 100644 index 0000000000000..136a1ba9776a0 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/squeeze.cc @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/tensor/squeeze.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + Squeeze, + kOnnxDomain, + 13, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("axes", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Squeeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Squeeze, + kOnnxDomain, + 11, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .Alias(0, 0), + Squeeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Squeeze, + kOnnxDomain, + 1, 10, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .Alias(0, 0), + Squeeze); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/squeeze.h b/onnxruntime/core/providers/webgpu/tensor/squeeze.h new file mode 100644 index 0000000000000..bc80cb86d5e8e --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/squeeze.h @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/tensor/squeeze.h" +#include "core/framework/data_transfer_manager.h" + +namespace onnxruntime { +namespace webgpu { + +class Squeeze final : public OpKernel, public SqueezeBase { + public: + explicit Squeeze(const OpKernelInfo& info) : OpKernel{info}, SqueezeBase(info) {} + + Status Compute(OpKernelContext* context) const override { + const Tensor* X = context->Input(0); + if (X == nullptr) { + return Status(common::ONNXRUNTIME, common::FAIL, "Input tensor is not set"); + } + const TensorShape& X_shape = X->Shape(); + + TensorShapeVector axes; + size_t num_inputs = context->InputCount(); + if (num_inputs == 2) { // axes is an input + const Tensor* axes_tensor = context->Input(1); + ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, + "An axes tensor must be a vector tensor."); + auto nDims = static_cast(axes_tensor->Shape()[0]); + const auto* data = axes_tensor->Data(); + axes.assign(data, data + nDims); + } else { + axes.assign(axes_.begin(), axes_.end()); + } + + TensorShapeVector output_shape = ComputeOutputShape(X_shape, axes); + Tensor* Y = context->Output(0, TensorShape(output_shape)); + const void* source = X->DataRaw(); + void* target = Y->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (target != source) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*X, *Y)); + } + + return Status::OK(); + } +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/tile.cc b/onnxruntime/core/providers/webgpu/tensor/tile.cc new file mode 100644 index 0000000000000..841c36724df30 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/tile.cc @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/inlined_containers.h" +#include "core/providers/webgpu/tensor/tile.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Tile, + kOnnxDomain, + 6, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + Tile); + +ONNX_OPERATOR_KERNEL_EX( + Tile, + kOnnxDomain, + 13, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + Tile); + +Status TileProgram::GenerateShaderCode(ShaderHelper& shader) const { + const ShaderVariableHelper& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << "let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" + << "var input_indices: input_indices_t;\n"; + for (auto i = 0; i < input.Rank(); i++) { + std::string input_dim_i = absl::StrCat("input_dim_", i); + std::string input_dim_value = absl::StrCat("input_dim_", i, "_value"); + shader.MainFunctionBody() << "let " << input_dim_i << " = " << input.IndicesGet("uniforms.input_shape", i) << ";\n" + << "let " << input_dim_value << " = " << output.IndicesGet("output_indices", i) << " % " << input_dim_i << ";\n" + << input.IndicesSet("input_indices", i, input_dim_value) << ";\n"; + } + + shader.MainFunctionBody() << output.SetByOffset("global_idx", input.GetByIndices("input_indices")); + + return Status::OK(); +} + +Status Tile::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + const TensorShape& input_shape = input_tensor->Shape(); + size_t input_rank = input_shape.NumDimensions(); + + const auto* repeats_tensor = context.Input(1); + const auto* repeats_data = repeats_tensor->Data(); + std::vector repeats; + + for (size_t i = 0; i < static_cast(repeats_tensor->Shape().Size()); i++) { + repeats.push_back(static_cast(repeats_data[i])); + } + + auto output_dims = input_shape.AsShapeVector(); + for (size_t axis = 0; axis < input_rank; axis++) { + output_dims[axis] *= repeats[axis]; + } + + TensorShape output_shape(output_dims); + auto* output_tensor = context.Output(0, output_shape); + int64_t output_size = output_tensor->Shape().Size(); + + if (output_size == 0) { + return Status::OK(); + } + + TileProgram program{}; + program + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutputs({output_tensor}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({{static_cast(output_size)}, + {repeats}}); + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/tensor/tile.h b/onnxruntime/core/providers/webgpu/tensor/tile.h new file mode 100644 index 0000000000000..9b6ab420b3252 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/tile.h @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class TileProgram final : public Program { + public: + TileProgram() : Program{"Tile"} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}, + {"repeats", ProgramUniformVariableDataType::Uint32}); +}; + +class Tile final : public WebGpuKernel { + public: + Tile(const OpKernelInfo& info) : WebGpuKernel(info) {} + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc new file mode 100644 index 0000000000000..c40ec43dd0009 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -0,0 +1,152 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/inlined_containers.h" +#include "core/providers/webgpu/tensor/transpose.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/providers/webgpu/shader_variable.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Transpose, + kOnnxDomain, + 1, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Transpose); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Transpose, + kOnnxDomain, + 13, 20, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Transpose); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Transpose, + kOnnxDomain, + 21, 22, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Transpose); + +ONNX_OPERATOR_KERNEL_EX( + Transpose, + kOnnxDomain, + 23, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Transpose); + +auto SqueezeShape(const gsl::span& shape, const gsl::span& adjusted_perm, InlinedVector& new_shape, InlinedVector& new_perm) { + for (size_t i = 0; i < shape.size(); ++i) { + if (shape[i] != 1) { + new_shape.push_back(shape[i]); + } + if (shape[adjusted_perm[i]] != 1) { + new_perm.push_back(adjusted_perm[i]); + } + } +}; + +Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + + if (use_shared_) { + shader.AdditionalImplementation() << "var tile : array, tile_size>;\n"; + shader.MainFunctionBody() << " let stride = (uniforms.output_shape[1] - 1) / tile_size + 1;\n" + " let workgroup_id_x = workgroup_idx % stride;\n" + " let workgroup_id_y = workgroup_idx / stride;\n" + " let input_col = workgroup_id_y * tile_size + local_id.x;\n" + " let input_row = workgroup_id_x * tile_size + local_id.y;\n" + " if (input_row < uniforms.a_shape[0] && input_col < uniforms.a_shape[1]) {\n" + << " tile[local_id.y][local_id.x] = " << input.GetByIndices("a_indices_t(input_row, input_col)") << ";\n" + << " }\n" + " workgroupBarrier();\n" + " let output_col = workgroup_id_x * tile_size + local_id.x;\n" + " let output_row = workgroup_id_y * tile_size + local_id.y;\n" + " if (output_row < uniforms.output_shape[0] && output_col < uniforms.output_shape[1]) {\n" + << " " << output.SetByIndices("output_indices_t(output_row, output_col)", "tile[local_id.x][local_id.y]") << "\n" + << " }"; + } else { + shader.AdditionalImplementation() << "fn perm(i: output_indices_t)->a_indices_t {\n" + " var a: a_indices_t;\n"; + for (size_t i = 0; i < perm_.size(); ++i) { + shader.AdditionalImplementation() << " a[" << perm_[i] << "] = i[" << i << "];\n"; + } + shader.AdditionalImplementation() << " return a;\n" + "}\n"; + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << " let indices = " << output.OffsetToIndices("global_idx") + << ";\n" + " let x_indices = perm(indices);\n" + " " + << output.SetByOffset("global_idx", input.GetByIndices("x_indices")); + } + return Status::OK(); +} + +Status Transpose::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + const TensorShape& input_shape = input_tensor->Shape(); + int32_t rank = gsl::narrow_cast(input_shape.NumDimensions()); + + TensorShapeVector output_dims(rank); + InlinedVector default_perm(rank); + const InlinedVector* p_perm = nullptr; + ORT_RETURN_IF_ERROR(ComputeOutputShape(*input_tensor, output_dims, default_perm, p_perm)); + TensorShape output_shape(output_dims); + auto* output_tensor = context.Output(0, output_shape); + + InlinedVector new_shape{}; + InlinedVector new_perm{}; + SqueezeShape(input_shape.GetDims(), *p_perm, new_shape, new_perm); + const bool channels_last = new_perm == InlinedVector({2, 3, 1}); + const bool channels_first = new_perm == InlinedVector({3, 1, 2}); + const bool use_shared = (new_shape.size() == 2 && new_perm[0] > new_perm[1]) || channels_last || channels_first; + auto new_input_shape = input_shape; + TensorShape new_output_shape(output_dims); + if (use_shared) { + new_input_shape = channels_last + ? TensorShape({new_shape[0], new_shape[1] * new_shape[2]}) + : channels_first + ? TensorShape({new_shape[0] * new_shape[1], new_shape[2]}) + : new_shape; + new_output_shape = TensorShape({new_input_shape[1], new_input_shape[0]}); + } + + uint32_t output_size = gsl::narrow_cast(input_tensor->Shape().Size()); + TransposeProgram program{*p_perm, use_shared}; + if (use_shared) { + program.SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1); + } + + program + .CacheHint(absl::StrJoin(*p_perm, "-")) + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::None, new_output_shape, 1}}) + .SetDispatchGroupSize(static_cast((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE), + static_cast(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE))) + .AddUniformVariables({ + {static_cast(output_size)}, + }); + + use_shared ? program.SetDispatchGroupSize(static_cast((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE), + static_cast(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE))) + : program.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.h b/onnxruntime/core/providers/webgpu/tensor/transpose.h new file mode 100644 index 0000000000000..7cf5c1fe0865d --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/cpu/tensor/transpose.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class Transpose final : public WebGpuKernel, public TransposeBase { + public: + Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} { + } + Status ComputeInternal(ComputeContext& context) const override; + constexpr static uint32_t TILE_SIZE = 16; +}; + +class TransposeProgram final : public Program { + public: + TransposeProgram(const gsl::span& permutations, bool use_shared) + : Program{"Transpose"}, perm_(permutations.begin(), permutations.end()), use_shared_(use_shared) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}); + WEBGPU_PROGRAM_DEFINE_CONSTANTS({"tile_size", Transpose::TILE_SIZE}); + + private: + InlinedVector perm_; + const bool use_shared_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc b/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc new file mode 100644 index 0000000000000..4bcef4fd79296 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/tensor/unsqueeze.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + Unsqueeze, + kOnnxDomain, + 13, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("axes", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Unsqueeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Unsqueeze, + kOnnxDomain, + 11, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .Alias(0, 0), + Unsqueeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Unsqueeze, + kOnnxDomain, + 1, 10, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .Alias(0, 0), + Unsqueeze); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/unsqueeze.h b/onnxruntime/core/providers/webgpu/tensor/unsqueeze.h new file mode 100644 index 0000000000000..0ae9d50f6d4e7 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/unsqueeze.h @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/tensor/unsqueeze.h" +#include "core/framework/data_transfer_manager.h" + +namespace onnxruntime { +namespace webgpu { + +class Unsqueeze final : public OpKernel, public UnsqueezeBase { + public: + explicit Unsqueeze(const OpKernelInfo& info) : OpKernel{info}, UnsqueezeBase(info) {} + + Status Compute(OpKernelContext* context) const override { + const Tensor* X = context->Input(0); + if (X == nullptr) { + return Status(common::ONNXRUNTIME, common::FAIL, "Input tensor is not set"); + } + const TensorShape& X_shape = X->Shape(); + + TensorShapeVector axes; + size_t num_inputs = context->InputCount(); + if (num_inputs == 2) { // axes is an input + const Tensor* axes_tensor = context->Input(1); + ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 0 || + axes_tensor->Shape().NumDimensions() == 1, + "An axes tensor must be a scalar or a vector tensor."); + auto data_span = axes_tensor->template DataAsSpan(); + axes.assign(data_span.begin(), data_span.end()); + } else { + axes.assign(axes_.begin(), axes_.end()); + } + + TensorShapeVector output_shape = ComputeOutputShape(X_shape, axes); + Tensor* Y = context->Output(0, TensorShape(output_shape)); + const void* source = X->DataRaw(); + void* target = Y->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (target != source) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*X, *Y)); + } + + return Status::OK(); + } +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/where.cc b/onnxruntime/core/providers/webgpu/tensor/where.cc new file mode 100644 index 0000000000000..524dd07d5b710 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/where.cc @@ -0,0 +1,188 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/inlined_containers.h" +#include "core/providers/webgpu/tensor/where.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace webgpu { + +// Compute where operator output shape based upon three way broad-casting. +Status ComputeOutputShape(const TensorShape& cond_shape, + const TensorShape& x_shape, const TensorShape& y_shape, TensorShape& output_shape) { + size_t cond_rank = cond_shape.NumDimensions(); + size_t x_rank = x_shape.NumDimensions(); + size_t y_rank = y_shape.NumDimensions(); + size_t output_rank = std::max(std::max(cond_rank, x_rank), y_rank); + + std::vector output_dims(output_rank, 0); + for (size_t i = 0; i < output_rank; ++i) { + int64_t cond_dim = 1; + if (i < cond_rank) + cond_dim = cond_shape[cond_rank - 1 - i]; + + int64_t x_dim = 1; + if (i < x_rank) + x_dim = x_shape[x_rank - 1 - i]; + + int64_t y_dim = 1; + if (i < y_rank) + y_dim = y_shape[y_rank - 1 - i]; + + int64_t output_dim = std::max({cond_dim, x_dim, y_dim}); + // special case to handle a dim of 0 which can be broadcast with a 1 + if (output_dim == 1) + output_dim = std::min({cond_dim, x_dim, y_dim}); + + const auto node_name = "Where"; + if (cond_dim != output_dim && cond_dim != 1) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": condition operand cannot broadcast on dim ", cond_rank - 1 - i, + " Condition Shape: ", cond_shape.ToString(), ", X Shape: ", x_shape.ToString(), ", Y Shape: ", y_shape.ToString()); + if (x_dim != output_dim && x_dim != 1) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": X operand cannot broadcast on dim ", x_rank - 1 - i, + " Condition Shape: ", cond_shape.ToString(), ", X Shape: ", x_shape.ToString(), ", Y Shape: ", y_shape.ToString()); + if (y_dim != output_dim && y_dim != 1) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": Y operand cannot broadcast on dim ", y_rank - 1 - i, + " Condition Shape: ", cond_shape.ToString(), ", X Shape: ", x_shape.ToString(), ", Y Shape: ", y_shape.ToString()); + output_dims[output_rank - 1 - i] = output_dim; + } + + output_shape = TensorShape(output_dims); + return Status::OK(); +} + +Status WhereProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& c_input = shader.AddInput("c_data", ShaderUsage::UseUniform); + const auto& a_input = shader.AddInput("a_data", ShaderUsage::UseUniform); + const auto& b_input = shader.AddInput("b_data", ShaderUsage::UseUniform); + const auto& output = shader.AddOutput("output_data", ShaderUsage::UseUniform); + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"); + + const auto expression = [](std::string_view a, std::string_view b, std::string_view c) -> auto { + return absl::StrCat("select(", b, ", ", a, ", ", c, ")"); + }; + + if (!is_broadcast_) { + shader.MainFunctionBody() << output.SetByOffset( + "global_idx", + expression(a_input.GetByOffset("global_idx"), b_input.GetByOffset("global_idx"), c_input.GetByOffset("global_idx"))); + + } else { + const auto& c_indices = shader.AddIndices("c_indices"); + const auto& a_indices = shader.AddIndices("a_indices"); + const auto& b_indices = shader.AddIndices("b_indices"); + const auto& output_indices = shader.AddIndices("output_indices"); + + const auto single_assignment = + [&expression, &shader, &output_indices, &a_indices, &b_indices, &c_indices]( + std::string_view rest_str, const std::string& x, std::string_view type_cast = "") + -> void { + const std::string a_expression = "a_data[index_a" + x + "][component_a" + x + "]"; + const std::string b_expression = "b_data[index_b" + x + "][component_b" + x + "]"; + const std::string c_expression = "bool(c_data[index_c" + x + "] & (0xffu << (component_c" + x + " * 8)))"; + + shader.MainFunctionBody() << "let output_indices" << x << " = " << output_indices.OffsetToIndices("global_idx * 4 + " + x) << ";\n" + << "let offset_a" << x << " = " << a_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) << ";\n" + << "let offset_b" << x << " = " << b_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) << ";\n" + << "let offset_c" << x << " = " << c_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) << ";\n" + << "let index_a" << x << " = offset_a" << x << " / 4;\n" + << "let index_b" << x << " = offset_b" << x << " / 4;\n" + << "let index_c" << x << " = offset_c" << x << " / 4;\n" + << "let component_a" << x << " = offset_a" << x << " % 4;\n" + << "let component_b" << x << " = offset_b" << x << " % 4;\n" + << "let component_c" << x << " = offset_c" << x << " % 4;\n" + << rest_str << "[" << x << "] = " << type_cast << "(" << expression(a_expression, b_expression, c_expression) << ");\n"; + }; + + if (Outputs()[0].tensor->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_BOOL) { + shader.MainFunctionBody() << "var data = vec4(0);\n"; + single_assignment("data", "0", "u32"); + single_assignment("data", "1", "u32"); + single_assignment("data", "2", "u32"); + single_assignment("data", "3", "u32"); + shader.MainFunctionBody() << "output_data[global_idx] = dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(data));\n"; + } else { + single_assignment("output_data[global_idx]", "0"); + single_assignment("output_data[global_idx]", "1"); + single_assignment("output_data[global_idx]", "2"); + single_assignment("output_data[global_idx]", "3"); + } + } + + return Status::OK(); +} + +Status Where::ComputeInternal(ComputeContext& context) const { + const auto* cond_tensor = context.Input(0); + const auto* x_tensor = context.Input(1); + const auto* y_tensor = context.Input(2); + const auto& cond_shape = cond_tensor->Shape(); + const auto& x_shape = x_tensor->Shape(); + const auto& y_shape = y_tensor->Shape(); + + TensorShape output_shape; + ORT_RETURN_IF_ERROR(ComputeOutputShape(cond_shape, x_shape, y_shape, output_shape)); + auto* output_tensor = context.Output(0, output_shape); + constexpr int component = 4; + uint32_t vec_size = gsl::narrow_cast((output_shape.Size() + 3) / component); + const auto is_broadcast = !(x_shape == y_shape && + y_shape == cond_shape); + WhereProgram program{is_broadcast}; + program + .CacheHint(is_broadcast) + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddInputs({{cond_tensor, ProgramTensorMetadataDependency::Rank, {(cond_shape.Size() + 3) / 4}, 4}, + {x_tensor, ProgramTensorMetadataDependency::Rank, {(x_shape.Size() + 3) / 4}, 4}, + {y_tensor, ProgramTensorMetadataDependency::Rank, {(y_shape.Size() + 3) / 4}, 4}}) + .AddOutput({output_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}) + .AddUniformVariables({ + {static_cast(vec_size)}, + }); + if (is_broadcast) { + program + .AddIndices(cond_shape) + .AddIndices(x_shape) + .AddIndices(y_shape) + .AddIndices(output_tensor->Shape()); + } + return context.RunProgram(program); +} + +namespace { +const std::vector& WhereOpTypeConstraints() { + // currently support boolean, integer and float types that explicitly allowed in WGSL: + // https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section + // + static std::vector types{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}; + return types; +} +} // namespace + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Where, + kOnnxDomain, + 9, 15, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WhereOpTypeConstraints()), + Where); + +ONNX_OPERATOR_KERNEL_EX( + Where, + kOnnxDomain, + 16, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WhereOpTypeConstraints()), + Where); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/where.h b/onnxruntime/core/providers/webgpu/tensor/where.h new file mode 100644 index 0000000000000..e46b24e9ba2e5 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/where.h @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/cpu/tensor/transpose.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class WhereProgram final : public Program { + public: + WhereProgram(bool is_broadcast) : Program{"Where"}, is_broadcast_{is_broadcast} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + + private: + const bool is_broadcast_; +}; + +class Where final : public WebGpuKernel { + public: + Where(const OpKernelInfo& info) : WebGpuKernel{info} { + } + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc new file mode 100644 index 0000000000000..d66c2a79d28a8 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -0,0 +1,689 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include "dawn/dawn_proc.h" +#if !defined(USE_EXTERNAL_DAWN) +#include "dawn/native/DawnNative.h" +#endif + +#include "core/common/common.h" + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/webgpu_context.h" +#include "core/providers/webgpu/buffer_manager.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/program_cache_key.h" +#include "core/providers/webgpu/program_manager.h" +#include "core/providers/webgpu/string_macros.h" + +namespace onnxruntime { +namespace webgpu { + +void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info, const void* dawn_proc_table) { + std::call_once(init_flag_, [this, &webgpu_ep_info, dawn_proc_table]() { + // Initialization.Step.1 - Create wgpu::Instance + if (instance_ == nullptr) { + const DawnProcTable* dawn_procs = reinterpret_cast(dawn_proc_table); +#if defined(BUILD_DAWN_MONOLITHIC_LIBRARY) + ORT_ENFORCE(dawn_procs == nullptr, "setting DawnProcTable is not allowed when dynamically linked to webgpu_dawn."); +#else +#if !defined(USE_EXTERNAL_DAWN) + if (dawn_procs == nullptr) { + dawn_procs = &dawn::native::GetProcs(); + } +#else + ORT_ENFORCE(dawn_procs != nullptr, "DawnProcTable must be provided."); +#endif + dawnProcSetProcs(dawn_procs); +#endif + + wgpu::InstanceDescriptor instance_desc{}; + instance_desc.features.timedWaitAnyEnable = true; + instance_ = wgpu::CreateInstance(&instance_desc); + + ORT_ENFORCE(instance_ != nullptr, "Failed to create wgpu::Instance."); + } + + // Initialization.Step.2 - Create wgpu::Adapter + if (adapter_ == nullptr) { + wgpu::RequestAdapterOptions req_adapter_options = {}; + wgpu::DawnTogglesDescriptor adapter_toggles_desc = {}; + req_adapter_options.nextInChain = &adapter_toggles_desc; + req_adapter_options.backendType = static_cast(webgpu_ep_info.backend_type); + req_adapter_options.powerPreference = wgpu::PowerPreference::HighPerformance; + + auto enabled_adapter_toggles = GetEnabledAdapterToggles(); + adapter_toggles_desc.enabledToggleCount = enabled_adapter_toggles.size(); + adapter_toggles_desc.enabledToggles = enabled_adapter_toggles.data(); + + ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(instance_.RequestAdapter( + &req_adapter_options, + wgpu::CallbackMode::WaitAnyOnly, + [](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, wgpu::StringView message, wgpu::Adapter* ptr) { + ORT_ENFORCE(status == wgpu::RequestAdapterStatus::Success, "Failed to get a WebGPU adapter: ", std::string_view{message}); + *ptr = adapter; + }, + &adapter_), + UINT64_MAX)); + ORT_ENFORCE(adapter_ != nullptr, "Failed to get a WebGPU adapter."); + } + + // Initialization.Step.3 - Create wgpu::Device + if (device_ == nullptr) { + wgpu::DeviceDescriptor device_desc = {}; + wgpu::DawnTogglesDescriptor device_toggles_desc = {}; + device_desc.nextInChain = &device_toggles_desc; + + auto enabled_device_toggles = GetEnabledDeviceToggles(); + device_toggles_desc.enabledToggleCount = enabled_device_toggles.size(); + device_toggles_desc.enabledToggles = enabled_device_toggles.data(); + + auto disabled_device_toggles = GetDisabledDeviceToggles(); + device_toggles_desc.disabledToggleCount = disabled_device_toggles.size(); + device_toggles_desc.disabledToggles = disabled_device_toggles.data(); + + std::vector required_features = GetAvailableRequiredFeatures(adapter_); + if (required_features.size() > 0) { + device_desc.requiredFeatures = required_features.data(); + device_desc.requiredFeatureCount = required_features.size(); + } + wgpu::RequiredLimits required_limits = GetRequiredLimits(adapter_); + device_desc.requiredLimits = &required_limits; + + // TODO: revise temporary error handling + device_desc.SetUncapturedErrorCallback([](const wgpu::Device& /*device*/, wgpu::ErrorType type, const char* message) { + LOGS_DEFAULT(ERROR) << "WebGPU device error(" << int(type) << "): " << message; + }); + // TODO: revise temporary device lost handling + device_desc.SetDeviceLostCallback(wgpu::CallbackMode::AllowSpontaneous, [](const wgpu::Device& /*device*/, wgpu::DeviceLostReason reason, const char* message) { + // cannot use ORT logger because it may be already destroyed + std::cerr << "WebGPU device lost (" << int(reason) << "): " << message; + }); + + ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(adapter_.RequestDevice( + &device_desc, + wgpu::CallbackMode::WaitAnyOnly, + [](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message, wgpu::Device* ptr) { + ORT_ENFORCE(status == wgpu::RequestDeviceStatus::Success, "Failed to get a WebGPU device: ", std::string_view{message}); + *ptr = device; + }, + &device_), + UINT64_MAX)); + ORT_ENFORCE(device_ != nullptr, "Failed to get a WebGPU device."); + } + + // cache adapter info + ORT_ENFORCE(Adapter().GetInfo(&adapter_info_)); + // cache device limits + wgpu::SupportedLimits device_supported_limits; + ORT_ENFORCE(Device().GetLimits(&device_supported_limits)); + device_limits_ = device_supported_limits.limits; + + // create buffer manager + buffer_mgr_ = BufferManagerFactory::Create(*this, webgpu_ep_info.storage_buffer_cache_mode, webgpu_ep_info.uniform_buffer_cache_mode, webgpu_ep_info.query_resolve_buffer_cache_mode); + + // create program manager + program_mgr_ = std::make_unique(Device(), DeviceLimits()); + + // set query type + if (device_.HasFeature(wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses)) { + query_type_ = TimestampQueryType::InsidePasses; + } else if (device_.HasFeature(wgpu::FeatureName::TimestampQuery)) { + query_type_ = TimestampQueryType::AtPasses; + } else { + query_type_ = TimestampQueryType::None; + } + }); +} + +Status WebGpuContext::Wait(wgpu::Future f) { + auto status = instance_.WaitAny(f, UINT64_MAX); + if (status == wgpu::WaitStatus::Success) { + return Status::OK(); + } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait for the operation:", uint32_t(status)); +} + +Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { + const auto& inputs = program.Inputs(); + const auto& outputs = program.Outputs(); + + if (outputs.size() == 0) { + return Status::OK(); + } + + if (ValidationMode() >= ValidationMode::Basic) { + ORT_ENFORCE(std::all_of(inputs.begin(), inputs.end(), [](const ProgramInput& input) { + const auto* tensor = input.tensor; + return tensor != nullptr && + tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault && + tensor->Location().device.Type() == OrtDevice::GPU && + !strcmp(tensor->Location().name, WEBGPU_BUFFER); + }), + "All inputs must be tensors on WebGPU buffers."); + + ORT_ENFORCE(std::all_of(outputs.begin(), outputs.end(), [](const ProgramOutput& output) { + const auto* tensor = output.tensor; + return tensor != nullptr && + tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault && + tensor->Location().device.Type() == OrtDevice::GPU && + !strcmp(tensor->Location().name, WEBGPU_BUFFER); + }), + "All outputs must be tensors on WebGPU buffers."); + } + + const ProgramMetadata& metadata = program.Metadata(); + + // validate program metadata + if (ValidationMode() >= ValidationMode::Basic) { + const auto& [constants, overridable_constants, uniform_variables] = metadata; + + // check overridable constants + ORT_RETURN_IF(program.OverridableConstants().size() != overridable_constants.size(), + "Size of overridable constants mismatch in program \"", program.Name(), + "\", Expected: ", overridable_constants.size(), + ", Actual: ", program.OverridableConstants().size()); + + if (ValidationMode() >= ValidationMode::Full) { + size_t num_overridable_constants = program.OverridableConstants().size(); + for (size_t i = 0; i < num_overridable_constants; ++i) { + const auto& override_value = program.OverridableConstants()[i]; + const auto& definition = overridable_constants[i]; + ORT_RETURN_IF(override_value.has_value && override_value.type != definition.type, + "Overridable override_value[", i, "] (", definition.name, ") data type mismatch in program \"", program.Name(), + "\", Expected: ", definition.type, + ", Actual: ", override_value.type); + ORT_RETURN_IF(!override_value.has_value && !definition.has_default_value, + "Overridable override_value[", i, "] (", definition.name, ") no override_value specified in program \"", program.Name(), + "\""); + } + } + + // check uniform variables + ORT_RETURN_IF(program.UniformVariables().size() != uniform_variables.size(), + "Size of uniform_value variables mismatch in program \"", program.Name(), + "\", Expected: ", uniform_variables.size(), + ", Actual: ", program.UniformVariables().size()); + + if (ValidationMode() >= ValidationMode::Full) { + size_t num_uniform_variables = program.UniformVariables().size(); + for (size_t i = 0; i < num_uniform_variables; ++i) { + const auto& uniform_value = program.UniformVariables()[i]; + const auto& definition = uniform_variables[i]; + ORT_RETURN_IF(uniform_value.length > 0 && uniform_value.data_type != definition.data_type, + "Uniform variable[", i, "] (", definition.name, ") data type mismatch in program \"", program.Name(), + "\", Expected: ", definition.data_type, + ", Actual: ", uniform_value.data_type); + } + } + } + + uint32_t x = program.DispatchGroupSizeX(); + uint32_t y = program.DispatchGroupSizeY(); + uint32_t z = program.DispatchGroupSizeZ(); + ORT_RETURN_IF_ERROR(program_mgr_->NormalizeDispatchGroupSize(x, y, z)); + + bool is_1d_dispatch = (y == 1 && z == 1); + + auto key = CalculateProgramCacheKey(program, is_1d_dispatch); + + if (is_profiling_) { + PendingKernelInfo pending_kernel_info(context.KernelContext().GetNodeName(), + program.Name(), + key, + inputs, + outputs); + pending_kernels_.emplace_back(std::move(pending_kernel_info)); + } + + LOGS(context.Logger(), INFO) << "Starting program \"" << key << "\" (" << x << ", " << y << ", " << z << ")"; + + const auto* program_artifact = program_mgr_->Get(key); + if (program_artifact == nullptr) { + wgpu::ComputePipeline compute_pipeline; + std::vector shape_uniform_ranks; + auto status = program_mgr_->Build(program, + metadata, +#ifndef NDEBUG // if debug build + key, +#endif + x, + y, + z, + compute_pipeline, + shape_uniform_ranks); + ORT_RETURN_IF_ERROR(status); + program_artifact = program_mgr_->Set(key, ProgramArtifact{program, + std::move(compute_pipeline), + std::move(shape_uniform_ranks)}); +#ifndef NDEBUG // if debug build + ORT_ENFORCE(program_artifact != nullptr, "Program artifact should not be nullptr."); +#endif + } + + // prepare shape uniforms for shader variables (if any) and user defined uniforms + std::vector shape_uniforms; + shape_uniforms.reserve(program_artifact->shape_uniform_ranks.size() * 2); + if (ValidationMode() >= ValidationMode::Basic) { + ORT_RETURN_IF_NOT(program_artifact->shape_uniform_ranks.size() == inputs.size() + outputs.size() + program.Indices().size(), + "Invalid program artifact: variable size (", program_artifact->shape_uniform_ranks.size(), + ") does not match current program (input: ", inputs.size(), + ", output: ", outputs.size(), + ", indices: ", program.Indices().size(), ")"); + } + + auto append_shape_uniforms = [&shape_uniforms, program_artifact](size_t i, const TensorShape& shape) { + if (program_artifact->shape_uniform_ranks[i] > 0) { + size_t expected_rank = static_cast(program_artifact->shape_uniform_ranks[i]); + ORT_RETURN_IF(expected_rank != shape.NumDimensions(), + "Invalid program artifact: variable[", i, "] rank mismatch. Expected: ", expected_rank, + ", Actual: ", shape.NumDimensions()); + + std::vector dims(expected_rank); + std::vector stride(expected_rank - 1); + for (size_t j = 0; j < expected_rank; ++j) { + dims[j] = gsl::narrow(shape[j]); + if (j < expected_rank - 1) { + stride[j] = gsl::narrow(shape.SizeFromDimension(j + 1)); + } + } + + shape_uniforms.emplace_back(gsl::make_span(dims)); + if (expected_rank > 1) { + shape_uniforms.emplace_back(gsl::make_span(stride)); + } + } + return Status::OK(); + }; + + for (size_t i = 0; i < inputs.size(); i++) { + ORT_RETURN_IF_ERROR(append_shape_uniforms(i, + inputs[i].use_override_shape ? inputs[i].override_shape : inputs[i].tensor->Shape())); + } + for (size_t i = 0; i < outputs.size(); i++) { + ORT_RETURN_IF_ERROR(append_shape_uniforms(i + inputs.size(), + outputs[i].use_override_shape ? outputs[i].override_shape : outputs[i].tensor->Shape())); + } + for (size_t i = 0; i < program.Indices().size(); i++) { + ORT_RETURN_IF_ERROR(append_shape_uniforms(i + inputs.size() + outputs.size(), program.Indices()[i])); + } + + const size_t uniform_count = shape_uniforms.size() + program.UniformVariables().size(); + size_t current_offset = 0; + std::vector> uniform_and_offsets; + uniform_and_offsets.reserve(uniform_count); + for (size_t i = 0; i < uniform_count; i++) { + const auto& uniform = i < shape_uniforms.size() ? shape_uniforms[i] + : program.UniformVariables()[i - shape_uniforms.size()]; + size_t length = uniform.length; + if (length == 0) { // skip zero-length uniform + continue; + } + + bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16; + + size_t element_size = ProgramUniformVariableDataTypeSize[static_cast(uniform.data_type)]; + // https://www.w3.org/TR/WGSL/#alignof + size_t base_alignment = is_f16 + ? (length > 4 ? 16 : length > 2 ? 8 + : length * element_size) + : (length > 2 ? 16 : length * element_size); + size_t struct_size = is_f16 && length <= 4 ? length * element_size : 16; + + current_offset = (current_offset + base_alignment - 1) / base_alignment * base_alignment; + uniform_and_offsets.emplace_back(uniform, current_offset); + + // For non-float16 type, when length > 4, the uniform variable is of type array,N>, where + // N = ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * SizeOf(vec4). + // For float16 type, when length > 4, the uniform variable is of type array,N>, where + // N = ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte length is N * SizeOf(mat2x4). + size_t element_per_struct = is_f16 ? 8 : 4; + current_offset += + length > 4 ? (length + element_per_struct - 1) / element_per_struct * struct_size : length * element_size; + } + + // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set + // max_alignment_of_field to 16 since the underlying buffer has been rounded up to 16. + constexpr size_t max_alignment_of_field = 16; + const size_t uniform_buffer_total_size = (current_offset + max_alignment_of_field - 1) / max_alignment_of_field * max_alignment_of_field; + + WGPUBuffer uniform_buffer = nullptr; + if (uniform_buffer_total_size > 0) { + std::vector uniform_data_buffer(uniform_buffer_total_size); + + for (auto const& [uniform, offset] : uniform_and_offsets) { + memcpy(uniform_data_buffer.data() + offset, uniform.data.data(), uniform.data.size()); + } + + uniform_buffer = buffer_mgr_->Create(uniform_buffer_total_size, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform); + device_.GetQueue().WriteBuffer(uniform_buffer, 0, uniform_data_buffer.data(), uniform_buffer_total_size); + } + + const auto& compute_pass_encoder = GetComputePassEncoder(); + + WriteTimestamp(num_pending_dispatches_ * 2); + + uint32_t entry_index = 0; + std::vector bind_group_entries; + for (const auto& input : inputs) { + bind_group_entries.push_back({nullptr, entry_index++, reinterpret_cast(const_cast(input.tensor->DataRaw()))}); + } + for (const auto& output : outputs) { + bind_group_entries.push_back({nullptr, entry_index++, reinterpret_cast(output.tensor->MutableDataRaw())}); + } + if (uniform_buffer) { + bind_group_entries.push_back({nullptr, entry_index++, uniform_buffer}); + } + + wgpu::BindGroupDescriptor bind_group_desc{}; + bind_group_desc.layout = program_artifact->compute_pipeline.GetBindGroupLayout(0); + bind_group_desc.entryCount = bind_group_entries.size(); + bind_group_desc.entries = bind_group_entries.data(); + bind_group_desc.label = program_artifact->name.c_str(); + + auto bind_group = Device().CreateBindGroup(&bind_group_desc); + + // TODO support graph capture + + compute_pass_encoder.SetPipeline(program_artifact->compute_pipeline); + compute_pass_encoder.SetBindGroup(0, bind_group); + compute_pass_encoder.DispatchWorkgroups(x, y, z); + + if (uniform_buffer) { + buffer_mgr_->Release(uniform_buffer); + } + + WriteTimestamp(num_pending_dispatches_ * 2 + 1); + + ++num_pending_dispatches_; + + if (num_pending_dispatches_ >= max_num_pending_dispatches_ || + (is_profiling_ && query_type_ == TimestampQueryType::AtPasses)) { + EndComputePass(); + } + if (num_pending_dispatches_ >= max_num_pending_dispatches_) { + Flush(); + num_pending_dispatches_ = 0; + } + + return Status::OK(); +} + +std::vector WebGpuContext::GetEnabledAdapterToggles() const { + // See the description of all the toggles in toggles.cpp + // "use_dxc" for Shader Model 6+ features (e.g. float16) + // "allow_unsafe_apis" for chromium experimental features + constexpr const char* toggles[] = { + "use_dxc", + "allow_unsafe_apis", + }; + return std::vector(std::begin(toggles), std::end(toggles)); +} + +std::vector WebGpuContext::GetEnabledDeviceToggles() const { + // Enable / disable other toggles that may affect the performance. + // Other toggles that may be useful: "dump_shaders", "disable_symbol_renaming" + constexpr const char* toggles[] = { + "skip_validation", // only use "skip_validation" when ValidationMode is set to "Disabled" + "disable_robustness", + "d3d_disable_ieee_strictness", + }; + return std::vector(ValidationMode() >= ValidationMode::WGPUOnly + ? std::begin(toggles) + 1 + : std::begin(toggles), + std::end(toggles)); +} + +std::vector WebGpuContext::GetDisabledDeviceToggles() const { + constexpr const char* toggles[] = { + "lazy_clear_resource_on_first_use", + "timestamp_quantization", + }; + return std::vector(std::begin(toggles), std::end(toggles)); +} + +std::vector WebGpuContext::GetAvailableRequiredFeatures(const wgpu::Adapter& adapter) const { + std::vector required_features; + constexpr wgpu::FeatureName features[]{ + wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses, + wgpu::FeatureName::TimestampQuery, + wgpu::FeatureName::ShaderF16, + wgpu::FeatureName::Subgroups, + wgpu::FeatureName::SubgroupsF16}; + for (auto feature : features) { + if (adapter.HasFeature(feature)) { + required_features.push_back(feature); + } + } + return required_features; +} + +wgpu::RequiredLimits WebGpuContext::GetRequiredLimits(const wgpu::Adapter& adapter) const { + wgpu::RequiredLimits required_limits{}; + wgpu::SupportedLimits adapter_limits; + ORT_ENFORCE(adapter.GetLimits(&adapter_limits)); + + required_limits.limits.maxBindGroups = adapter_limits.limits.maxBindGroups; + required_limits.limits.maxComputeWorkgroupStorageSize = adapter_limits.limits.maxComputeWorkgroupStorageSize; + required_limits.limits.maxComputeWorkgroupsPerDimension = adapter_limits.limits.maxComputeWorkgroupsPerDimension; + required_limits.limits.maxStorageBufferBindingSize = adapter_limits.limits.maxStorageBufferBindingSize; + required_limits.limits.maxBufferSize = adapter_limits.limits.maxBufferSize; + required_limits.limits.maxComputeInvocationsPerWorkgroup = adapter_limits.limits.maxComputeInvocationsPerWorkgroup; + required_limits.limits.maxComputeWorkgroupSizeX = adapter_limits.limits.maxComputeWorkgroupSizeX; + required_limits.limits.maxComputeWorkgroupSizeY = adapter_limits.limits.maxComputeWorkgroupSizeY; + required_limits.limits.maxComputeWorkgroupSizeZ = adapter_limits.limits.maxComputeWorkgroupSizeZ; + + return required_limits; +} + +void WebGpuContext::WriteTimestamp(uint32_t query_index) { + if (!is_profiling_ || query_type_ != TimestampQueryType::InsidePasses) { + return; + } + + const auto& compute_pass_encoder = GetComputePassEncoder(); + compute_pass_encoder.WriteTimestamp(query_set_, query_index); +} + +void WebGpuContext::StartProfiling() { + if (query_type_ == TimestampQueryType::None) { + return; + } + + is_profiling_ = true; + + const uint32_t query_count = max_num_pending_dispatches_ * 2; + + if (!query_set_) { + // Create query set + wgpu::QuerySetDescriptor querySetDescriptor; + querySetDescriptor.count = query_count; + querySetDescriptor.type = wgpu::QueryType::Timestamp; + query_set_ = device_.CreateQuerySet(&querySetDescriptor); + } + + if (!query_resolve_buffer_) { + // Create resolve buffer + wgpu::BufferDescriptor bufferDescriptor; + bufferDescriptor.size = query_count * sizeof(uint64_t); + bufferDescriptor.usage = wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc | + wgpu::BufferUsage::CopyDst; + query_resolve_buffer_ = device_.CreateBuffer(&bufferDescriptor); + } +} + +void WebGpuContext::CollectProfilingData(profiling::Events& events) { + if (!pending_queries_.empty()) { + for (const auto& pending_query : pending_queries_) { + const auto& pending_kernels = pending_query.kernels; + const auto& query_read_buffer = pending_query.query_buffer; + + ORT_ENFORCE(Wait(query_read_buffer.MapAsync(wgpu::MapMode::Read, + 0, + query_read_buffer.GetSize(), + wgpu::CallbackMode::WaitAnyOnly, + [](wgpu::MapAsyncStatus status, const char* message) { + ORT_ENFORCE(status == wgpu::MapAsyncStatus::Success, "Failed to download data from buffer: ", message); + })) == Status::OK()); + auto mapped_data = static_cast(query_read_buffer.GetConstMappedRange()); + + for (size_t i = 0; i < pending_kernels.size(); i++) { + const PendingKernelInfo& pending_kernel_info = pending_kernels[i]; + const auto& inputs = pending_kernel_info.inputs; + const auto& outputs = pending_kernel_info.outputs; + + SS(shapes, 128); + for (size_t s = 0; s < inputs.size(); s++) { + const auto& input = inputs[s]; + shapes << "inputs[" << s << "] = " << input.override_shape.ToString() << " "; + } + for (size_t s = 0; s < outputs.size(); s++) { + const auto& output = outputs[s]; + shapes << "outputs[" << s << "] = " << output.override_shape.ToString() << " "; + } + + if (gpu_timestamp_offset_ == 0) { + gpu_timestamp_offset_ = mapped_data[i * 2]; + // TODO: apply CPU-GPU time offset so that timestamps are aligned + } + uint64_t start_time = mapped_data[i * 2] - gpu_timestamp_offset_; + uint64_t end_time = mapped_data[i * 2 + 1] - gpu_timestamp_offset_; + + const std::unordered_map& event_args = { + {"shapes", SS_GET(shapes)}, + {"cache_key", pending_kernel_info.cache_key}, + }; + + profiling::EventRecord event(profiling::API_EVENT, + -1, + -1, + pending_kernel_info.name, + static_cast(std::round(start_time / 1000.0)), + static_cast(std::round((end_time - start_time) / 1000.0)), + event_args); + events.emplace_back(std::move(event)); + } + + query_read_buffer.Unmap(); + query_read_buffer.Destroy(); + } + + pending_queries_.clear(); + } + + is_profiling_ = false; +} + +void WebGpuContext::EndProfiling(TimePoint /* tp */, profiling::Events& events, profiling::Events& cached_events) { + // This function is called when no active inference is ongoing. + ORT_ENFORCE(!is_profiling_, "Profiling is ongoing in an inference run."); + + if (query_type_ != TimestampQueryType::None) { + // No pending kernels or queries should be present at this point. They should have been collected in CollectProfilingData. + ORT_ENFORCE(pending_kernels_.empty() && pending_queries_.empty(), "Pending kernels or queries are not empty."); + + events.insert(events.end(), + std::make_move_iterator(cached_events.begin()), + std::make_move_iterator(cached_events.end())); + + cached_events.clear(); + } else { + LOGS_DEFAULT(WARNING) << "TimestampQuery is not supported in this device."; + } +} + +void WebGpuContext::Flush() { + if (!current_command_encoder_) { + return; + } + + EndComputePass(); + + if (is_profiling_ && num_pending_dispatches_ > 0) { + uint32_t query_count = num_pending_dispatches_ * 2; + current_command_encoder_.ResolveQuerySet( + query_set_, + 0, + query_count, + query_resolve_buffer_, + 0); + + wgpu::BufferDescriptor bufferDescriptor; + bufferDescriptor.size = query_count * sizeof(uint64_t); + bufferDescriptor.usage = wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst; + wgpu::Buffer query_read_buffer = device_.CreateBuffer(&bufferDescriptor); + + current_command_encoder_.CopyBufferToBuffer( + query_resolve_buffer_, + 0, + query_read_buffer, + 0, + query_count * sizeof(uint64_t)); + + pending_queries_.emplace_back(std::move(pending_kernels_), query_read_buffer); + pending_kernels_.clear(); + } + + auto command_buffer = current_command_encoder_.Finish(); + Device().GetQueue().Submit(1, &command_buffer); + BufferManager().RefreshPendingBuffers(); + current_command_encoder_ = nullptr; + num_pending_dispatches_ = 0; +} + +std::unordered_map> WebGpuContextFactory::contexts_; +std::mutex WebGpuContextFactory::mutex_; + +WebGpuContext& WebGpuContextFactory::CreateContext(int context_id, + WGPUInstance instance, + WGPUAdapter adapter, + WGPUDevice device, + ValidationMode validation_mode) { + if (context_id == 0) { + // context ID is preserved for the default context. User cannot use context ID 0 as a custom context. + ORT_ENFORCE(instance == nullptr && adapter == nullptr && device == nullptr, + "WebGPU EP default context (contextId=0) must not have custom WebGPU instance, adapter or device."); + } else { + // for context ID > 0, user must provide custom WebGPU instance, adapter and device. + ORT_ENFORCE(instance != nullptr && adapter != nullptr && device != nullptr, + "WebGPU EP custom context (contextId>0) must have custom WebGPU instance, adapter and device."); + } + + std::lock_guard lock(mutex_); + + auto it = contexts_.find(context_id); + if (it == contexts_.end()) { + GSL_SUPPRESS(r.11) + auto context = std::unique_ptr(new WebGpuContext(instance, adapter, device, validation_mode)); + it = contexts_.emplace(context_id, std::move(context)).first; + } else if (context_id != 0) { + ORT_ENFORCE(it->second->instance_.Get() == instance && it->second->adapter_.Get() == adapter && it->second->device_.Get() == device, + "WebGPU EP context ID ", context_id, " is already created with different WebGPU instance, adapter or device."); + } + return *it->second; +} + +WebGpuContext& WebGpuContextFactory::GetContext(int context_id) { + std::lock_guard lock(mutex_); + + auto it = contexts_.find(context_id); + ORT_ENFORCE(it != contexts_.end(), "WebGPU EP context ID ", context_id, " is not found."); + + return *it->second; +} + +void WebGpuContextFactory::Cleanup() { + std::lock_guard lock(mutex_); + contexts_.clear(); +} + +void CleanupWebGpuContexts() { + WebGpuContextFactory::Cleanup(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h new file mode 100644 index 0000000000000..be05b06523b9c --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -0,0 +1,189 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include + +#include + +#include "core/common/common.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/buffer_manager.h" +#include "core/providers/webgpu/program_manager.h" + +namespace onnxruntime { +class Tensor; + +namespace webgpu { +class WebGpuContext; +class ComputeContext; +class ProgramBase; + +class WebGpuContextFactory { + public: + static WebGpuContext& CreateContext(int context_id, + WGPUInstance instance, + WGPUAdapter adapter, + WGPUDevice device, + ValidationMode validation_mode); + static WebGpuContext& GetContext(int context_id); + + static void Cleanup(); + + private: + WebGpuContextFactory() {} + + static std::unordered_map> contexts_; + static std::mutex mutex_; +}; + +// Class WebGpuContext includes all necessary resources for the context. +class WebGpuContext final { + public: + void Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info, const void* dawn_proc_table); + + Status Wait(wgpu::Future f); + + const wgpu::Adapter& Adapter() const { return adapter_; } + const wgpu::Device& Device() const { return device_; } + + const wgpu::AdapterInfo& AdapterInfo() const { return adapter_info_; } + const wgpu::Limits& DeviceLimits() const { return device_limits_; } + + const wgpu::CommandEncoder& GetCommandEncoder() { + if (!current_command_encoder_) { + current_command_encoder_ = device_.CreateCommandEncoder(); + } + return current_command_encoder_; + } + + const wgpu::ComputePassEncoder& GetComputePassEncoder() { + if (!current_compute_pass_encoder_) { + auto& command_encoder = GetCommandEncoder(); + + wgpu::ComputePassDescriptor compute_pass_desc{}; + + if (is_profiling_ && query_type_ == TimestampQueryType::AtPasses) { + wgpu::ComputePassTimestampWrites timestampWrites = { + query_set_, num_pending_dispatches_ * 2, num_pending_dispatches_ * 2 + 1}; + compute_pass_desc.timestampWrites = ×tampWrites; + } + + current_compute_pass_encoder_ = command_encoder.BeginComputePass(&compute_pass_desc); + } + return current_compute_pass_encoder_; + } + + void EndComputePass() { + if (current_compute_pass_encoder_) { + current_compute_pass_encoder_.End(); + current_compute_pass_encoder_ = nullptr; + } + } + + void Flush(); + + webgpu::BufferManager& BufferManager() const { return *buffer_mgr_; } + + inline webgpu::ValidationMode ValidationMode() const { + return validation_mode_; + } + + void StartProfiling(); + void CollectProfilingData(profiling::Events& events); + void EndProfiling(TimePoint, profiling::Events& events, profiling::Events& cached_events); + + Status Run(ComputeContext& context, const ProgramBase& program); + + private: + enum class TimestampQueryType { + None = 0, + InsidePasses, + AtPasses + }; + + WebGpuContext(WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device, webgpu::ValidationMode validation_mode) + : instance_{instance}, adapter_{adapter}, device_{device}, validation_mode_{validation_mode}, query_type_{TimestampQueryType::None} {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuContext); + + std::vector GetEnabledAdapterToggles() const; + std::vector GetEnabledDeviceToggles() const; + std::vector GetDisabledDeviceToggles() const; + std::vector GetAvailableRequiredFeatures(const wgpu::Adapter& adapter) const; + wgpu::RequiredLimits GetRequiredLimits(const wgpu::Adapter& adapter) const; + void WriteTimestamp(uint32_t query_index); + + struct PendingKernelInfo { + PendingKernelInfo(std::string_view kernel_name, + std::string_view program_name, + std::string_view cache_key, + const std::vector& inputs, + const std::vector& outputs) + : name{absl::StrJoin({kernel_name, program_name}, "_")}, cache_key{cache_key}, inputs{inputs}, outputs{outputs} {} + + PendingKernelInfo(PendingKernelInfo&&) = default; + PendingKernelInfo& operator=(PendingKernelInfo&&) = default; + ORT_DISALLOW_COPY_AND_ASSIGNMENT(PendingKernelInfo); + + std::string name; + std::string cache_key; + std::vector inputs; + std::vector outputs; + }; + + struct PendingQueryInfo { + PendingQueryInfo(std::vector&& kernels, wgpu::Buffer query_buffer) + : kernels{std::move(kernels)}, query_buffer{query_buffer} {} + + PendingQueryInfo(PendingQueryInfo&&) = default; + PendingQueryInfo& operator=(PendingQueryInfo&&) = default; + ORT_DISALLOW_COPY_AND_ASSIGNMENT(PendingQueryInfo); + + std::vector kernels; + wgpu::Buffer query_buffer; + }; + + friend class WebGpuContextFactory; + + std::once_flag init_flag_; + + wgpu::Instance instance_; + wgpu::Adapter adapter_; + wgpu::Device device_; + + webgpu::ValidationMode validation_mode_; + + wgpu::AdapterInfo adapter_info_; + wgpu::Limits device_limits_; + + wgpu::CommandEncoder current_command_encoder_; + wgpu::ComputePassEncoder current_compute_pass_encoder_; + + std::unique_ptr buffer_mgr_; + std::unique_ptr program_mgr_; + + uint32_t num_pending_dispatches_ = 0; + const uint32_t max_num_pending_dispatches_ = 16; + + // profiling + TimestampQueryType query_type_; + wgpu::QuerySet query_set_; + wgpu::Buffer query_resolve_buffer_; + + // info of kernels pending submission for a single batch + std::vector pending_kernels_; + // info of queries pending appending to profiling events + std::vector pending_queries_; + + uint64_t gpu_timestamp_offset_ = 0; + bool is_profiling_ = false; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 00ebdd5583d2e..66209adf6f1a9 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -3,6 +3,9 @@ #include "core/providers/webgpu/webgpu_execution_provider.h" +#ifdef __EMSCRIPTEN__ +#include +#endif #include #include #include @@ -13,6 +16,7 @@ #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" #endif +#include "allocator.h" #include "core/framework/compute_capability.h" #include "core/framework/data_transfer_manager.h" #include "core/framework/fallback_cpu_capability.h" @@ -20,6 +24,10 @@ #include "core/graph/function_utils.h" #include "core/graph/indexed_sub_graph.h" +#include "core/providers/webgpu/webgpu_context.h" +#include "core/providers/webgpu/data_transfer.h" +#include "core/providers/webgpu/webgpu_profiler.h" + namespace onnxruntime { namespace webgpu { @@ -65,6 +73,330 @@ ONNX_OPERATOR_KERNEL_EX( .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), Memcpy); +#define KERNEL_CREATE_INFO_VERSIONED(Start, End, Op) \ + BuildKernelCreateInfo< \ + ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, Start, End, Op)> + +#define KERNEL_CREATE_INFO(Start, Op) \ + BuildKernelCreateInfo< \ + ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, Start, Op)> + +#define KERNEL_CREATE_INFO_TYPED(Start, type, Op) \ + BuildKernelCreateInfo< \ + ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, Start, type, Op)> + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Abs); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Abs); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Neg); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Neg); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Floor); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Floor); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Ceil); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Ceil); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Reciprocal); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Reciprocal); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Sqrt); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Sqrt); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Exp); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Exp); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 12, Erf); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Erf); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Sigmoid); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Sigmoid); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, HardSigmoid); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Log); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Log); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Sin); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Cos); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Tan); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Asin); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Acos); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Atan); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, Sinh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, Cosh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, Asinh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, Acosh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, Atanh); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Tanh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Tanh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, Not); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 8, Cast); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 12, Cast); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, Cast); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Cast); + +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, float, Clip); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, float, Clip); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, float, Clip); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, Clip); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, Clip); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, MLFloat16, Clip); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, Elu); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Relu); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Relu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Relu); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 15, LeakyRelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, LeakyRelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, ThresholdedRelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 20, Gelu); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceMax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceMax); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceMean); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceMean); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceMean); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceMean); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceMin); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceMin); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceProd); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceProd); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceProd); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceProd); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceSum); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceSum); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, ReduceSum); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceL1); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceL1); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceL1); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceL1); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceL2); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceL2); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceL2); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceL2); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceLogSum); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceLogSum); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceSumSquare); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceSumSquare); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceLogSumExp); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceLogSumExp); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 12, Add); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Add); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Add); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 12, Sub); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Sub); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Sub); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 12, Mul); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Mul); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Mul); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 12, Div); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Div); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Div); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 11, Pow); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, Pow); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 14, Pow); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, Pow); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 10, Equal); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Equal); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, Equal); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Equal); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, Greater); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 12, Greater); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Greater); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 15, GreaterOrEqual); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, Less); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 12, Less); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Less); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 15, LessOrEqual); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, LessOrEqual); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Shape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 14, Shape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, 18, Shape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, Shape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Shape); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, Shape); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 5, 12, Reshape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Reshape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, 18, Reshape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, Reshape); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, Reshape); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Squeeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Squeeze); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Squeeze); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Unsqueeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Unsqueeze); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Unsqueeze); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 15, Where); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, Where); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Transpose); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 20, Transpose); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, DepthToSpace); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, DepthToSpace); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, 12, DepthToSpace); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 13, DepthToSpace); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, Conv); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, 10, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, Conv); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ConvTranspose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, ConvTranspose); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, 10, ConvTranspose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 7, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 8, 9, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 10, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, MaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, 7, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 8, 9, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 10, 10, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 9, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 10, AveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 7, 9, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 10, 10, AveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, GlobalAveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, GlobalAveragePool); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, GlobalMaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, GlobalMaxPool); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, Gemm); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 10, Gemm); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Gemm); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Gemm); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, MatMul); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, MatMul); + +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, float, ArgMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, float, ArgMin); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Softmax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Softmax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Softmax); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 3, Concat); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 4, 10, Concat); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Concat); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Concat); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 1, Split); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 2, 10, Split); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Split); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Split); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, Split); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 8, 12, Expand); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Expand); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 10, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, 18, Resize); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 10, 10, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 18, 18, Resize); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 19, Resize); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Gather); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Gather); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Gather); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, GatherElements); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 9, Slice); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 10, Slice); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Slice); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Slice); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 8, Flatten); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 10, Flatten); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Flatten); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 20, Flatten); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, Flatten); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Tile); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Tile); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 17, LayerNormalization); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, InstanceNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 6, InstanceNormalization); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, float, Range); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, int32_t, Range); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, float, Einsum); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 2, 10, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, 18, Pad); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Pad); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, If); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, If); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, If); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, If); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 13, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, 14, BatchNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 7, 8, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 9, 13, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 14, 14, BatchNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 15, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 13, CumSum); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, CumSum); + +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 12, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 12, int8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 12, int32_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, int32_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, int8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, int32_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, int8_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, int32_t, DequantizeLinear); + std::unique_ptr RegisterKernels() { auto kernel_registry = std::make_unique(); @@ -72,6 +404,322 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, // default entry to avoid the list becoming empty after ops-reducing BuildKernelCreateInfo, BuildKernelCreateInfo, + + // element-wise operators + // unary - math + KERNEL_CREATE_INFO_VERSIONED(6, 12, Abs), + KERNEL_CREATE_INFO(13, Abs), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Neg), + KERNEL_CREATE_INFO(13, Neg), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Floor), + KERNEL_CREATE_INFO(13, Floor), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Ceil), + KERNEL_CREATE_INFO(13, Ceil), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Reciprocal), + KERNEL_CREATE_INFO(13, Reciprocal), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Sqrt), + KERNEL_CREATE_INFO(13, Sqrt), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Exp), + KERNEL_CREATE_INFO(13, Exp), + KERNEL_CREATE_INFO_VERSIONED(9, 12, Erf), + KERNEL_CREATE_INFO(13, Erf), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Sigmoid), + KERNEL_CREATE_INFO(13, Sigmoid), + KERNEL_CREATE_INFO(6, HardSigmoid), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Log), + KERNEL_CREATE_INFO(13, Log), + + KERNEL_CREATE_INFO(7, Sin), + KERNEL_CREATE_INFO(7, Cos), + KERNEL_CREATE_INFO(7, Tan), + KERNEL_CREATE_INFO(7, Asin), + KERNEL_CREATE_INFO(7, Acos), + KERNEL_CREATE_INFO(7, Atan), + KERNEL_CREATE_INFO(9, Sinh), + KERNEL_CREATE_INFO(9, Cosh), + KERNEL_CREATE_INFO(9, Asinh), + KERNEL_CREATE_INFO(9, Acosh), + KERNEL_CREATE_INFO(9, Atanh), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Tanh), + KERNEL_CREATE_INFO(13, Tanh), + KERNEL_CREATE_INFO(1, Not), + + KERNEL_CREATE_INFO_VERSIONED(6, 8, Cast), + KERNEL_CREATE_INFO_VERSIONED(9, 12, Cast), + KERNEL_CREATE_INFO_VERSIONED(13, 18, Cast), + KERNEL_CREATE_INFO(19, Cast), + + // // activations + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + KERNEL_CREATE_INFO(6, Elu), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Relu), + KERNEL_CREATE_INFO_VERSIONED(13, 13, Relu), + KERNEL_CREATE_INFO(14, Relu), + KERNEL_CREATE_INFO_VERSIONED(6, 15, LeakyRelu), + KERNEL_CREATE_INFO(16, LeakyRelu), + KERNEL_CREATE_INFO(10, ThresholdedRelu), + KERNEL_CREATE_INFO(20, Gelu), + + // // binary - math + KERNEL_CREATE_INFO_VERSIONED(7, 12, Add), + KERNEL_CREATE_INFO_VERSIONED(13, 13, Add), + KERNEL_CREATE_INFO(14, Add), + KERNEL_CREATE_INFO_VERSIONED(7, 12, Sub), + KERNEL_CREATE_INFO_VERSIONED(13, 13, Sub), + KERNEL_CREATE_INFO(14, Sub), + KERNEL_CREATE_INFO_VERSIONED(7, 12, Mul), + KERNEL_CREATE_INFO_VERSIONED(13, 13, Mul), + KERNEL_CREATE_INFO(14, Mul), + KERNEL_CREATE_INFO_VERSIONED(7, 12, Div), + KERNEL_CREATE_INFO_VERSIONED(13, 13, Div), + KERNEL_CREATE_INFO(14, Div), + KERNEL_CREATE_INFO_VERSIONED(7, 11, Pow), + KERNEL_CREATE_INFO_VERSIONED(12, 12, Pow), + KERNEL_CREATE_INFO_VERSIONED(13, 14, Pow), + KERNEL_CREATE_INFO(15, Pow), + KERNEL_CREATE_INFO_VERSIONED(7, 10, Equal), + KERNEL_CREATE_INFO_VERSIONED(11, 12, Equal), + KERNEL_CREATE_INFO_VERSIONED(13, 18, Equal), + KERNEL_CREATE_INFO(19, Equal), + KERNEL_CREATE_INFO_VERSIONED(7, 8, Greater), + KERNEL_CREATE_INFO_VERSIONED(9, 12, Greater), + KERNEL_CREATE_INFO(13, Greater), + KERNEL_CREATE_INFO_VERSIONED(12, 15, GreaterOrEqual), + KERNEL_CREATE_INFO(16, GreaterOrEqual), + KERNEL_CREATE_INFO_VERSIONED(7, 8, Less), + KERNEL_CREATE_INFO_VERSIONED(9, 12, Less), + KERNEL_CREATE_INFO(13, Less), + KERNEL_CREATE_INFO_VERSIONED(12, 15, LessOrEqual), + KERNEL_CREATE_INFO(16, LessOrEqual), + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + KERNEL_CREATE_INFO_VERSIONED(9, 15, Where), + KERNEL_CREATE_INFO(16, Where), + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { @@ -93,8 +741,77 @@ std::unique_ptr RegisterKernels() { using namespace webgpu; -WebGpuExecutionProvider::WebGpuExecutionProvider() - : IExecutionProvider{kWebGpuExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)} {} +WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, + WebGpuContext& context, + WebGpuExecutionProviderInfo&& info) + : IExecutionProvider{kWebGpuExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)}, + context_id_{context_id}, + context_{context}, + preferred_data_layout_{info.data_layout}, + force_cpu_node_names_{std::move(info.force_cpu_node_names)}, + enable_graph_capture_{info.enable_graph_capture} { +} + +std::vector WebGpuExecutionProvider::CreatePreferredAllocators() { + AllocatorCreationInfo gpuBufferAllocatorCreationInfo([&](int) { + return std::make_unique(context_); + }, + 0, false); + return std::vector{CreateAllocator(gpuBufferAllocatorCreationInfo)}; +} + +std::vector> WebGpuExecutionProvider::GetCapability( + const onnxruntime::GraphViewer& graph, + const IKernelLookup& kernel_lookup) const { + InlinedVector candidates; + // `tenative_candidates` is a subset of `candidates`. + InlinedVector tenative_candidates; + for (auto& node_index : graph.GetNodesInTopologicalOrder()) { + const auto* p_node = graph.GetNode(node_index); + if (p_node == nullptr) + continue; + + const auto& node = *p_node; + if (!node.GetExecutionProviderType().empty()) { + // If the node was added by layout transformer, do not move it to CPU + if (node.GetExecutionProviderType() == kWebGpuExecutionProvider) { + candidates.push_back(node.Index()); + } + continue; + } + + const KernelCreateInfo* webgpu_kernel_def = kernel_lookup.LookUpKernel(node); + // none of the provided registries has a webgpu kernel for this node + if (webgpu_kernel_def == nullptr) { + LOGS(*GetLogger(), INFO) << "webgpu kernel not found in registries for Op type: " + << node.OpType() << " node name: " << node.Name(); + continue; + } + + // TODO: currently this lookup is O(N). If the list becomes large we should optimize this. + if (std::find(force_cpu_node_names_.cbegin(), + force_cpu_node_names_.cend(), + node.Name()) != force_cpu_node_names_.cend()) { + LOGS(*GetLogger(), INFO) << "Force CPU execution for node: " << node.Name(); + continue; + } + candidates.push_back(node.Index()); + tenative_candidates.push_back(node.Index()); + } + + auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates, *GetLogger()); + std::vector> result; + for (auto& node_index : candidates) { + if (cpu_nodes.count(node_index) > 0) { + continue; + } + + auto sub_graph = std::make_unique(); + sub_graph->nodes.push_back(node_index); + result.emplace_back(std::make_unique(std::move(sub_graph))); + } + return result; +} std::shared_ptr WebGpuExecutionProvider::GetKernelRegistry() const { static std::shared_ptr registry = webgpu::RegisterKernels(); @@ -102,7 +819,68 @@ std::shared_ptr WebGpuExecutionProvider::GetKernelRegistry() con return registry; } +std::unique_ptr WebGpuExecutionProvider::GetDataTransfer() const { + return std::make_unique(context_); +} + WebGpuExecutionProvider::~WebGpuExecutionProvider() { } +std::unique_ptr WebGpuExecutionProvider::GetProfiler() { + auto profiler = std::make_unique(context_); + profiler_ = profiler.get(); + return profiler; +} + +Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { + if (profiler_->Enabled()) { + context_.StartProfiling(); + } + + if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { + ORT_NOT_IMPLEMENTED("graph capture not implemented"); + } + return Status::OK(); +} + +Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxruntime::RunOptions& /*run_options*/) { + if (IsGraphCaptureEnabled() && !IsGraphCaptured(0)) { + if (IsGraphCaptureAllowed()) { + ORT_NOT_IMPLEMENTED("graph capture not implemented"); + // is_graph_captured_ = true; + } else { + IncrementRegularRunCountBeforeGraphCapture(); + } + } + + context_.Flush(); + + if (profiler_->Enabled()) { + context_.CollectProfilingData(profiler_->Events()); + } + + return Status::OK(); +} + +bool WebGpuExecutionProvider::IsGraphCaptureEnabled() const { + return enable_graph_capture_; +} + +bool WebGpuExecutionProvider::IsGraphCaptured(int) const { + return is_graph_captured_; +} + +Status WebGpuExecutionProvider::ReplayGraph(int) { + ORT_ENFORCE(IsGraphCaptured(0)); + ORT_ENFORCE(false); + return Status::OK(); +} + +bool WebGpuExecutionProvider::IsGraphCaptureAllowed() const { + return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; +} + +void WebGpuExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { + ++regular_run_count_before_graph_capture_; +} } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index 537ecb9301f67..f9c43c6bfd7d0 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -9,6 +9,7 @@ #include "core/graph/constants.h" #include "core/providers/providers.h" +struct pthreadpool; namespace onnxruntime { namespace webgpu { @@ -16,22 +17,80 @@ namespace webgpu { template KernelCreateInfo BuildKernelCreateInfo(); +class WebGpuContext; +enum class BufferCacheMode; +class WebGpuProfiler; } // namespace webgpu +struct WebGpuExecutionProviderInfo { + WebGpuExecutionProviderInfo(DataLayout data_layout, bool enable_graph_capture) + : data_layout{data_layout}, + enable_graph_capture{enable_graph_capture}, + backend_type{}, + storage_buffer_cache_mode{}, + uniform_buffer_cache_mode{}, + query_resolve_buffer_cache_mode{}, + default_buffer_cache_mode{} {} + WebGpuExecutionProviderInfo(WebGpuExecutionProviderInfo&&) = default; + WebGpuExecutionProviderInfo& operator=(WebGpuExecutionProviderInfo&&) = default; + ORT_DISALLOW_COPY_AND_ASSIGNMENT(WebGpuExecutionProviderInfo); + + DataLayout data_layout; + bool enable_graph_capture; + int backend_type; + webgpu::BufferCacheMode storage_buffer_cache_mode; + webgpu::BufferCacheMode uniform_buffer_cache_mode; + webgpu::BufferCacheMode query_resolve_buffer_cache_mode; + webgpu::BufferCacheMode default_buffer_cache_mode; + std::vector force_cpu_node_names; +}; + class WebGpuExecutionProvider : public IExecutionProvider { public: - WebGpuExecutionProvider(); + WebGpuExecutionProvider(int context_id, webgpu::WebGpuContext& context, WebGpuExecutionProviderInfo&& info); ~WebGpuExecutionProvider() override; + std::vector> GetCapability( + const onnxruntime::GraphViewer& graph_viewer, + const IKernelLookup& /*kernel_lookup*/) const override; + std::shared_ptr GetKernelRegistry() const override; + std::unique_ptr GetDataTransfer() const override; - DataLayout GetPreferredLayout() const override { return DataLayout::NHWC; } + DataLayout GetPreferredLayout() const override { return preferred_data_layout_; } FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; } // WebGPU EP disallow concurrent run because actual implementation (eg. WebGPU backend) relies on global states to // work, and concurrent run with async function may mess up the states and cause undefined behavior. bool ConcurrentRunSupported() const override { return false; } + + std::vector CreatePreferredAllocators() override; + + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; + + // WebGPU EP reuses the Device ID as the key to get the WebGpuContext instance. + int GetDeviceId() const override { return context_id_; } + + std::unique_ptr GetProfiler() override; + + bool IsGraphCaptureEnabled() const override; + bool IsGraphCaptured(int graph_annotation_id) const override; + Status ReplayGraph(int graph_annotation_id) override; + + private: + bool IsGraphCaptureAllowed() const; + void IncrementRegularRunCountBeforeGraphCapture(); + int context_id_; + webgpu::WebGpuContext& context_; + webgpu::WebGpuProfiler* profiler_ = nullptr; + DataLayout preferred_data_layout_; + std::vector force_cpu_node_names_; + bool enable_graph_capture_ = false; + bool is_graph_captured_ = false; + int regular_run_count_before_graph_capture_ = 0; + const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.h b/onnxruntime/core/providers/webgpu/webgpu_kernel.h new file mode 100644 index 0000000000000..d7682e751d9e4 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.h @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/compute_context.h" + +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +// ----------------------------------------------------------------------- +// Base class for WebGPU kernels +// ----------------------------------------------------------------------- +class WebGpuKernel : public OpKernel { + public: + explicit WebGpuKernel(const OpKernelInfo& info) + : OpKernel(info) { + } + + Status Compute(OpKernelContext* p_op_kernel_context) const override { + ComputeContext context{*p_op_kernel_context}; + + context.PushErrorScope(); + Status s = ComputeInternal(context); + ORT_RETURN_IF_ERROR(context.PopErrorScope()); + + return s; + } + + virtual Status ComputeInternal(ComputeContext& context) const = 0; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_profiler.cc b/onnxruntime/core/providers/webgpu/webgpu_profiler.cc new file mode 100644 index 0000000000000..ce973987e593a --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_profiler.cc @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/webgpu_profiler.h" +#include "core/providers/webgpu/webgpu_context.h" + +namespace onnxruntime { +namespace webgpu { + +WebGpuProfiler::WebGpuProfiler(WebGpuContext& context) : context_{context} {} + +bool WebGpuProfiler::StartProfiling(TimePoint) { + enabled_ = true; + return true; +} + +void WebGpuProfiler::EndProfiling(TimePoint tp, onnxruntime::profiling::Events& events) { + context_.EndProfiling(tp, events, events_); + enabled_ = false; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_profiler.h b/onnxruntime/core/providers/webgpu/webgpu_profiler.h new file mode 100644 index 0000000000000..d826d295a3842 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_profiler.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/profiler_common.h" + +namespace onnxruntime { + +namespace webgpu { +class WebGpuContext; + +class WebGpuProfiler final : public onnxruntime::profiling::EpProfiler { + public: + WebGpuProfiler(WebGpuContext& context); + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuProfiler); + ~WebGpuProfiler() {} + bool StartProfiling(TimePoint) override; + void EndProfiling(TimePoint, onnxruntime::profiling::Events&) override; + void Start(uint64_t) override { + } + void Stop(uint64_t) override { + } + inline bool Enabled() const { return enabled_; } + inline onnxruntime::profiling::Events& Events() { return events_; } + + private: + WebGpuContext& context_; + bool enabled_{false}; + onnxruntime::profiling::Events events_; // cached GPU events +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index 1a1f1a438c750..6cfe9aac0b0e9 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -4,21 +4,214 @@ #include #include "core/framework/error_code_helper.h" -#include "core/providers/webgpu/webgpu_provider_factory_creator.h" +#include "core/providers/webgpu/buffer_manager.h" #include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/webgpu_provider_factory_creator.h" +#include "core/providers/webgpu/webgpu_context.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/ort_apis.h" + +#include "core/providers/webgpu/webgpu_provider_options.h" +using namespace onnxruntime::webgpu::options; namespace onnxruntime { struct WebGpuProviderFactory : IExecutionProviderFactory { - WebGpuProviderFactory() {} + WebGpuProviderFactory(int context_id, webgpu::WebGpuContext& context, WebGpuExecutionProviderInfo&& webgpu_ep_info) + : context_id_{context_id}, context_{context}, info_{std::move(webgpu_ep_info)} { + } std::unique_ptr CreateProvider() override { - return std::make_unique(); + return std::make_unique(context_id_, context_, std::move(info_)); } + + private: + int context_id_; + webgpu::WebGpuContext& context_; + WebGpuExecutionProviderInfo info_; }; -std::shared_ptr WebGpuProviderFactoryCreator::Create(const ConfigOptions&) { - return std::make_shared(); +std::shared_ptr WebGpuProviderFactoryCreator::Create(const ConfigOptions& config_options) { + // + // STEP.1 - prepare WebGpuExecutionProviderInfo + // + WebGpuExecutionProviderInfo webgpu_ep_info{ + // preferred layout is NHWC by default + DataLayout::NHWC, + // graph capture feature is disabled by default + false, + }; + + std::string preferred_layout_str; + if (config_options.TryGetConfigEntry(kPreferredLayout, preferred_layout_str)) { + if (preferred_layout_str == kPreferredLayout_NHWC) { + webgpu_ep_info.data_layout = DataLayout::NHWC; + } else if (preferred_layout_str == kPreferredLayout_NCHW) { + webgpu_ep_info.data_layout = DataLayout::NCHW; + } else { + ORT_THROW("Invalid preferred layout: ", preferred_layout_str); + } + } + LOGS_DEFAULT(VERBOSE) << "WebGPU EP preferred layout: " << int(webgpu_ep_info.data_layout) << " (parsed from \"" + << preferred_layout_str << "\")"; + + std::string enable_graph_capture_str; + if (config_options.TryGetConfigEntry(kEnableGraphCapture, enable_graph_capture_str)) { + if (enable_graph_capture_str == kEnableGraphCapture_ON) { + webgpu_ep_info.enable_graph_capture = true; + } else if (enable_graph_capture_str == kEnableGraphCapture_OFF) { + webgpu_ep_info.enable_graph_capture = false; + } else { + ORT_THROW("Invalid enable graph capture: ", enable_graph_capture_str); + } + } + LOGS_DEFAULT(VERBOSE) << "WebGPU EP graph capture enable: " << webgpu_ep_info.enable_graph_capture; + + std::string backend_type_str; + if (config_options.TryGetConfigEntry(kDawnBackendType, backend_type_str)) { +#ifdef _WIN32 + // Setup Windows default backend type based on the build configuration +#if defined(onnxruntime_ENABLE_DAWN_BACKEND_D3D12) + webgpu_ep_info.backend_type = static_cast(WGPUBackendType_D3D12); +#elif defined(onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) + webgpu_ep_info.backend_type = static_cast(WGPUBackendType_Vulkan); +#endif +#endif + if (backend_type_str == kDawnBackendType_D3D12) { + webgpu_ep_info.backend_type = static_cast(WGPUBackendType_D3D12); + } else if (backend_type_str == kDawnBackendType_Vulkan) { + webgpu_ep_info.backend_type = static_cast(WGPUBackendType_Vulkan); + } else { + ORT_THROW("Invalid Dawn backend type: ", backend_type_str); + } + } + LOGS_DEFAULT(VERBOSE) << "WebGPU EP Dawn backend type: " << webgpu_ep_info.backend_type; + + auto parse_buffer_cache_mode = [&config_options](const std::string& config_entry_str, + webgpu::BufferCacheMode default_value) -> webgpu::BufferCacheMode { + std::string buffer_cache_mode_str; + if (config_options.TryGetConfigEntry(config_entry_str, buffer_cache_mode_str)) { + if (buffer_cache_mode_str == kBufferCacheMode_Disabled) { + return webgpu::BufferCacheMode::Disabled; + } else if (buffer_cache_mode_str == kBufferCacheMode_LazyRelease) { + return webgpu::BufferCacheMode::LazyRelease; + } else if (buffer_cache_mode_str == kBufferCacheMode_Simple) { + return webgpu::BufferCacheMode::Simple; + } else if (buffer_cache_mode_str == kBufferCacheMode_Bucket) { + return webgpu::BufferCacheMode::Bucket; + } else { + ORT_THROW("Invalid buffer cache mode: ", config_entry_str); + } + } else { + return default_value; + } + }; + + webgpu_ep_info.storage_buffer_cache_mode = parse_buffer_cache_mode(kStorageBufferCacheMode, webgpu::BufferCacheMode::Bucket); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP storage buffer cache mode: " << webgpu_ep_info.storage_buffer_cache_mode; + + webgpu_ep_info.uniform_buffer_cache_mode = parse_buffer_cache_mode(kUniformBufferCacheMode, webgpu::BufferCacheMode::Simple); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP uniform buffer cache mode: " << webgpu_ep_info.uniform_buffer_cache_mode; + + webgpu_ep_info.query_resolve_buffer_cache_mode = parse_buffer_cache_mode(kQueryResolveBufferCacheMode, webgpu::BufferCacheMode::Disabled); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP query resolve buffer cache mode: " << webgpu_ep_info.query_resolve_buffer_cache_mode; + + webgpu_ep_info.default_buffer_cache_mode = parse_buffer_cache_mode(kDefaultBufferCacheMode, webgpu::BufferCacheMode::Disabled); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP default buffer cache mode: " << webgpu_ep_info.default_buffer_cache_mode; + + webgpu::ValidationMode validation_mode = +#ifndef NDEBUG + webgpu::ValidationMode::Full // for debug build, enable full validation by default +#else + webgpu::ValidationMode::Basic // for release build, enable basic validation by default +#endif // !NDEBUG + ; + std::string validation_mode_str; + if (config_options.TryGetConfigEntry(kValidationMode, validation_mode_str)) { + if (validation_mode_str == kValidationMode_Disabled) { + validation_mode = webgpu::ValidationMode::Disabled; + } else if (validation_mode_str == kValidationMode_wgpuOnly) { + validation_mode = webgpu::ValidationMode::WGPUOnly; + } else if (validation_mode_str == kValidationMode_basic) { + validation_mode = webgpu::ValidationMode::Basic; + } else if (validation_mode_str == kValidationMode_full) { + validation_mode = webgpu::ValidationMode::Full; + } else { + ORT_THROW("Invalid validation mode: ", validation_mode_str); + } + } + + // parse force CPU node names + // The force CPU node names are separated by EOL (\n or \r\n) in the config entry. + // each line is a node name that will be forced to run on CPU. + std::string force_cpu_node_names_str; + if (config_options.TryGetConfigEntry(kForceCpuNodeNames, force_cpu_node_names_str)) { + std::vector force_cpu_node_names; + + // split the string by EOL (\n or \r\n) + std::istringstream ss(force_cpu_node_names_str); + std::string line; + while (std::getline(ss, line)) { + // skip empty lines + if (line.empty()) { + continue; + } + + force_cpu_node_names.push_back(line); + } + + webgpu_ep_info.force_cpu_node_names = std::move(force_cpu_node_names); + } + + // + // STEP.2 - prepare WebGpuContext + // + int context_id = 0; + std::string context_id_str; + if (config_options.TryGetConfigEntry(kDeviceId, context_id_str)) { + ORT_ENFORCE(std::errc{} == + std::from_chars(context_id_str.data(), context_id_str.data() + context_id_str.size(), context_id).ec); + } + + size_t webgpu_instance = 0; + std::string webgpu_instance_str; + if (config_options.TryGetConfigEntry(kWebGpuInstance, webgpu_instance_str)) { + static_assert(sizeof(WGPUInstance) == sizeof(size_t), "WGPUInstance size mismatch"); + ORT_ENFORCE(std::errc{} == + std::from_chars(webgpu_instance_str.data(), webgpu_instance_str.data() + webgpu_instance_str.size(), webgpu_instance).ec); + } + + size_t webgpu_adapter = 0; + std::string webgpu_adapter_str; + if (config_options.TryGetConfigEntry(kWebGpuAdapter, webgpu_adapter_str)) { + static_assert(sizeof(WGPUAdapter) == sizeof(size_t), "WGPUAdapter size mismatch"); + ORT_ENFORCE(std::errc{} == + std::from_chars(webgpu_adapter_str.data(), webgpu_adapter_str.data() + webgpu_adapter_str.size(), webgpu_adapter).ec); + } + + size_t webgpu_device = 0; + std::string webgpu_device_str; + if (config_options.TryGetConfigEntry(kWebGpuDevice, webgpu_device_str)) { + static_assert(sizeof(WGPUDevice) == sizeof(size_t), "WGPUDevice size mismatch"); + ORT_ENFORCE(std::errc{} == + std::from_chars(webgpu_device_str.data(), webgpu_device_str.data() + webgpu_device_str.size(), webgpu_device).ec); + } + + size_t dawn_proc_table = 0; + std::string dawn_proc_table_str; + if (config_options.TryGetConfigEntry(kDawnProcTable, dawn_proc_table_str)) { + ORT_ENFORCE(std::errc{} == + std::from_chars(dawn_proc_table_str.data(), dawn_proc_table_str.data() + dawn_proc_table_str.size(), dawn_proc_table).ec); + } + + auto& context = webgpu::WebGpuContextFactory::CreateContext(context_id, + reinterpret_cast(webgpu_instance), + reinterpret_cast(webgpu_adapter), + reinterpret_cast(webgpu_device), + validation_mode); + context.Initialize(webgpu_ep_info, reinterpret_cast(dawn_proc_table)); + + return std::make_shared(context_id, context, std::move(webgpu_ep_info)); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h b/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h index 6257a85d45760..e0030a3ec2a11 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h @@ -8,6 +8,8 @@ #include "core/framework/provider_options.h" #include "core/providers/providers.h" +#include "core/providers/webgpu/webgpu_provider_options.h" + namespace onnxruntime { struct ConfigOptions; diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h new file mode 100644 index 0000000000000..12bb4b32e6a35 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { +namespace webgpu { +namespace options { + +// The following are the options that can be set in the WebGPU provider options. + +constexpr const char* kPreferredLayout = "WebGPU:preferredLayout"; +constexpr const char* kEnableGraphCapture = "WebGPU:enableGraphCapture"; + +constexpr const char* kDawnProcTable = "WebGPU:dawnProcTable"; + +constexpr const char* kDawnBackendType = "WebGPU:dawnBackendType"; + +constexpr const char* kDeviceId = "WebGPU:deviceId"; +constexpr const char* kWebGpuInstance = "WebGPU:webgpuInstance"; +constexpr const char* kWebGpuAdapter = "WebGPU:webgpuAdapter"; +constexpr const char* kWebGpuDevice = "WebGPU:webgpuDevice"; + +constexpr const char* kStorageBufferCacheMode = "WebGPU:storageBufferCacheMode"; +constexpr const char* kUniformBufferCacheMode = "WebGPU:uniformBufferCacheMode"; +constexpr const char* kQueryResolveBufferCacheMode = "WebGPU:queryResolveBufferCacheMode"; +constexpr const char* kDefaultBufferCacheMode = "WebGPU:defaultBufferCacheMode"; + +constexpr const char* kValidationMode = "WebGPU:validationMode"; + +constexpr const char* kForceCpuNodeNames = "WebGPU:forceCpuNodeNames"; + +// The following are the possible values for the provider options. + +constexpr const char* kDawnBackendType_D3D12 = "D3D12"; +constexpr const char* kDawnBackendType_Vulkan = "Vulkan"; + +constexpr const char* kPreferredLayout_NCHW = "NCHW"; +constexpr const char* kPreferredLayout_NHWC = "NHWC"; + +constexpr const char* kEnableGraphCapture_ON = "1"; +constexpr const char* kEnableGraphCapture_OFF = "0"; + +constexpr const char* kBufferCacheMode_Disabled = "disabled"; +constexpr const char* kBufferCacheMode_LazyRelease = "lazyRelease"; +constexpr const char* kBufferCacheMode_Simple = "simple"; +constexpr const char* kBufferCacheMode_Bucket = "bucket"; + +constexpr const char* kValidationMode_Disabled = "disabled"; +constexpr const char* kValidationMode_wgpuOnly = "wgpuOnly"; +constexpr const char* kValidationMode_basic = "basic"; +constexpr const char* kValidationMode_full = "full"; + +} // namespace options +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_supported_types.h b/onnxruntime/core/providers/webgpu/webgpu_supported_types.h new file mode 100644 index 0000000000000..ff66cd535399e --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_supported_types.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/cpu/tensor/shape_op.h" + +namespace onnxruntime { +namespace webgpu { + +using SupportedNumberTypes = + TypeList< + float, + MLFloat16, + int32_t, + uint32_t>; + +using SupportedFloats = + TypeList< + float, + MLFloat16>; + +inline const std::vector& WebGpuSupportedNumberTypes() { + static const std::vector supportedDataTypes = BuildKernelDefConstraintsFromTypeList(); + return supportedDataTypes; +} + +inline const std::vector& WebGpuSupportedFloatTypes() { + static const std::vector supportedDataTypes = BuildKernelDefConstraintsFromTypeList(); + return supportedDataTypes; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index 4b39e03ffc788..45a87960126cd 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -69,7 +69,8 @@ bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const We } } -bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger) { +bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, + const logging::Logger& logger, bool allow_empty_input) { const auto& node_arg_name = node_arg.Name(); const auto* shape_proto = node_arg.Shape(); // Optional tensors can be indicated by an empty name, just ignore it. @@ -89,6 +90,10 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n << "use sessionOptions.FreeDimensionOverrides to set a fixed shape: " << node_arg_name; return false; } + if (dim.dim_value() == 0 && !allow_empty_input) { + LOGS(logger, VERBOSE) << "The shape of [" << node_arg_name << "] has 0 dimension which is not supported by WebNN"; + return false; + } } return true; @@ -100,18 +105,6 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v const emscripten::val& wnn_limits, const logging::Logger& logger) { std::vector> supported_node_groups; - - for (const auto* input : graph_viewer.GetInputs()) { - if (!IsTensorShapeSupported(*input, "graph", logger)) { - return supported_node_groups; - } - } - for (const auto* output : graph_viewer.GetOutputs()) { - if (!IsTensorShapeSupported(*output, "graph", logger)) { - return supported_node_groups; - } - } - std::vector supported_node_group; const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); @@ -121,7 +114,6 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v bool supported = false; // Firstly check if platform supports the WebNN op. if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) { - LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() << "] is supported by browser"; supported = IsNodeSupported(*node, graph_viewer, device_type, wnn_limits, logger); } @@ -186,14 +178,31 @@ bool IsDataTypeSupportedByOp(const std::string& onnx_op_type, if (!GetWebNNOpType(onnx_op_type, webnn_op_type)) return false; + return IsDataTypeSupportedByWebNNOp(onnx_op_type, webnn_op_type, onnx_data_type, wnn_limits, + webnn_input_output_name, onnx_input_output_name, logger); +} + +bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type, + const std::string& webnn_op_type, + const int32_t onnx_data_type, + const emscripten::val& wnn_limits, + const std::string& webnn_input_output_name, + const std::string& onnx_input_output_name, + const logging::Logger& logger) { + if (wnn_limits[webnn_op_type].isUndefined()) { + LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] WebNN op [" << webnn_op_type << "] is not supported for now"; + return false; + } + if (wnn_limits[webnn_op_type][webnn_input_output_name].isUndefined()) { + LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] WebNN op [" << webnn_op_type << "] doesn't have parameter [" + << webnn_input_output_name << "]"; + return false; + } if (!IsSupportedDataType(onnx_data_type, wnn_limits[webnn_op_type][webnn_input_output_name]["dataTypes"])) { - LOGS(logger, VERBOSE) << "[" << onnx_op_type - << "] " << onnx_input_output_name - << " type: [" << onnx_data_type - << "] is not supported for now"; + LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] " << onnx_input_output_name << "'s data type: [" + << onnx_data_type << "] is not supported by WebNN op [" << webnn_op_type << "] for now"; return false; } - return true; } @@ -270,5 +279,67 @@ bool IsMLTensorSupported() { return is_supported; } +// Convert int8 to uint4/int4 (stored as uint8) +uint8_t PackInt8ToUint8AsNibble(int8_t value, const int32_t& data_type) { + uint8_t result = 0; + if (data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { + if (value < 0 || value > 15) { + ORT_THROW("Value cannot be safely converted to uint4."); + } + result |= (static_cast(value) << 4); + } else { + if (value < -8 || value > 7) { + ORT_THROW("Value cannot be safely converted to int4."); + } + result |= (value << 4); + } + + return result; +} + +// Convert float32 to float16 (stored as uint16) +uint16_t PackFloat32ToUint16AsFloat16(float value) { + uint32_t float32_bits; + + // Safely copy the float bits into an integer + std::memcpy(&float32_bits, &value, sizeof(float)); + + // Extract the sign, exponent, and mantissa from the float32 bits + uint32_t sign = (float32_bits >> 31) & 0x1; + uint32_t exponent = (float32_bits >> 23) & 0xFF; + uint32_t mantissa = float32_bits & 0x7FFFFF; + + // Shift the sign for float16 + uint16_t sign_float16 = sign << 15; + + // Handle special cases: Infinity and NaN + if (exponent == 255) { + return sign_float16 | (0x1F << 10) | (mantissa ? 0x200 : 0); + } + // Handle zero and subnormal numbers in float32 + if (exponent == 0) { + return sign_float16 | (mantissa >> 13); + } + + // Adjust the exponent for float16 (subtract bias difference: 127 - 15 = 112) + int exponent_float16 = exponent - 112; + + // Handle exponent overflow (larger than float16 can represent) + if (exponent_float16 >= 0x1F) { + return sign_float16 | (0x1F << 10); + } + // Handle exponent underflow (smaller than float16 can represent) + if (exponent_float16 <= 0) { + mantissa = (mantissa | 0x800000) >> (1 - exponent_float16); + return sign_float16 | (mantissa >> 13); + } + + // Adjust the mantissa by shifting it to fit float16 format (round to nearest even) + uint16_t mantissa_float16 = (mantissa + 0x1000) >> 13; + + // Combine sign, exponent, and mantissa into the final float16 representation + return sign_float16 | (exponent_float16 << 10) | mantissa_float16; +} + } // namespace webnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index aa3613551d8e1..a06f46f1bdf0a 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -4,6 +4,7 @@ #pragma once +#include #include #include "core/common/inlined_containers.h" #include @@ -81,7 +82,7 @@ inline std::string GetTensorName(const ConstPointerContainer index) ? std::string(input_defs[index]->Name()) : ""; } -inline std::vector GetVecUint32FromVecInt64(const std::vector& int64_vec) { +inline std::vector GetVecUint32FromVecInt64(gsl::span int64_vec) { std::vector uint32_vec; uint32_vec.reserve(int64_vec.size()); std::transform(int64_vec.begin(), int64_vec.end(), @@ -180,7 +181,8 @@ inline bool IsEmptyTensor(const InitializedTensorSet& initializers, const std::s return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; }); } -bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger); +bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, + const logging::Logger& logger, bool allow_empty_input = false); // Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP. std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, @@ -191,6 +193,7 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v static const InlinedHashMap op_map = { {"Abs", "abs"}, {"Add", "add"}, + {"And", "logicalAnd"}, {"ArgMax", "argMax"}, {"ArgMin", "argMin"}, {"AveragePool", "averagePool2d"}, @@ -203,10 +206,12 @@ static const InlinedHashMap op_map = { {"ConvInteger", "conv2dInteger"}, {"ConvTranspose", "convTranspose2d"}, {"Cos", "cos"}, + {"CumSum", "cumulativeSum"}, {"Div", "div"}, {"DequantizeLinear", "dequantizeLinear"}, {"Dropout", "identity"}, {"DynamicQuantizeLinear", "dynamicQuantizeLinear"}, + {"Einsum", "matmul"}, {"Elu", "elu"}, {"Equal", "equal"}, {"Erf", "erf"}, @@ -215,6 +220,8 @@ static const InlinedHashMap op_map = { {"Flatten", "reshape"}, {"Floor", "floor"}, {"Gather", "gather"}, + {"GatherElements", "gatherElements"}, + {"GatherND", "gatherND"}, {"Gelu", "gelu"}, {"Gemm", "gemm"}, {"GlobalAveragePool", "averagePool2d"}, @@ -222,7 +229,7 @@ static const InlinedHashMap op_map = { {"GlobalLpPool", "l2Pool2d"}, {"Greater", "greater"}, {"GreaterOrEqual", "greaterOrEqual"}, - {"Gru", "gru"}, + {"GRU", "gru"}, {"HardSigmoid", "hardSigmoid"}, {"HardSwish", "hardSwish"}, {"Identity", "identity"}, @@ -234,6 +241,7 @@ static const InlinedHashMap op_map = { {"Log", "log"}, {"LpPool", "l2Pool2d"}, {"LSTM", "lstm"}, + {"LRN", "averagePool2d"}, {"MatMul", "matmul"}, {"MatMulInteger", "matmulInteger"}, {"Max", "max"}, @@ -242,6 +250,7 @@ static const InlinedHashMap op_map = { {"Mul", "mul"}, {"Neg", "neg"}, {"Not", "logicalNot"}, + {"Or", "logicalOr"}, {"Pad", "pad"}, {"Pow", "pow"}, {"PRelu", "prelu"}, @@ -260,8 +269,12 @@ static const InlinedHashMap op_map = { {"Relu", "relu"}, {"Reshape", "reshape"}, {"Resize", "resample2d"}, + {"ScatterElements", "scatterElements"}, + {"ScatterND", "scatterND"}, {"Shape", "slice"}, {"Sigmoid", "sigmoid"}, + {"Sign", "sign"}, + {"SimplifiedLayerNormalization", "layerNormalization"}, {"Softplus", "softplus"}, {"Softsign", "softsign"}, {"Sin", "sin"}, @@ -278,6 +291,7 @@ static const InlinedHashMap op_map = { {"Trilu", "triangular"}, {"Unsqueeze", "reshape"}, {"Where", "where"}, + {"Xor", "logicalXor"}, }; inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder, @@ -326,6 +340,13 @@ bool IsDataTypeSupportedByOp(const std::string& onnx_op_type, const std::string& webnn_input_output_name, const std::string& onnx_input_output_name, const logging::Logger& logger); +bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type, + const std::string& webnn_op_type, + const int32_t onnx_data_type, + const emscripten::val& wnn_limits, + const std::string& webnn_input_output_name, + const std::string& onnx_input_output_name, + const logging::Logger& logger); bool GetBidirectionalBroadcastShape(std::vector& shape_a, std::vector& shape_b, @@ -335,5 +356,8 @@ bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type); bool IsMLTensorSupported(); +uint8_t PackInt8ToUint8AsNibble(int8_t value, const int32_t& data_type); +uint16_t PackFloat32ToUint16AsFloat16(float value); + } // namespace webnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc index 1e641017f36b6..290d16a48dd83 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc @@ -21,8 +21,6 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& model_builder.GetOpSupportLimits(), logger), "Unsupported operator ", node.OpType()); ORT_RETURN_IF_ERROR(AddToModelBuilderImpl(model_builder, node, logger)); - LOGS(logger, VERBOSE) << "Operator name: [" << node.Name() - << "] type: [" << node.OpType() << "] was added"; return Status::OK(); } @@ -31,7 +29,7 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const emscripten::val& wnn_limits, const logging::Logger& logger) const { - if (!HasSupportedInputs(node, wnn_limits, logger)) + if (!HasSupportedInputs(initializers, node, wnn_limits, logger)) return false; if (!HasSupportedOutputs(node, wnn_limits, logger)) @@ -43,19 +41,19 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons return IsOpSupportedImpl(initializers, node, device_type, logger); } -bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, +bool BaseOpBuilder::HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]"); for (const auto* input : node.InputDefs()) { - if (!IsTensorShapeSupported(*input, node_name, logger)) { + if (!IsTensorShapeSupported(*input, node_name, logger, allow_empty_tensor_as_input_)) { return false; } } - return HasSupportedInputsImpl(node, wnn_limits, logger); + return HasSupportedInputsImpl(initializers, node, wnn_limits, logger); } -bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, +bool BaseOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { // We only check the type of input 0 by default, specific op builder can override this. diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h index a632876dab2b9..0a4367a71add4 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h @@ -22,6 +22,9 @@ class BaseOpBuilder : public IOpBuilder { const logging::Logger& logger) const override final ORT_MUST_USE_RESULT; protected: + explicit BaseOpBuilder(bool allow_empty_tensor_as_input = false) + : allow_empty_tensor_as_input_(allow_empty_tensor_as_input) { + } virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const ORT_MUST_USE_RESULT = 0; @@ -37,7 +40,7 @@ class BaseOpBuilder : public IOpBuilder { return true; } - virtual bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + virtual bool HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; virtual bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; @@ -53,8 +56,10 @@ class BaseOpBuilder : public IOpBuilder { private: bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; - bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; + bool HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; + + const bool allow_empty_tensor_as_input_; // Some operators can handle ignoring an empty tensor as input. }; } // namespace webnn diff --git a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc index af82a01b14de5..e14507e8f5aea 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc @@ -22,8 +22,8 @@ class BinaryOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -86,8 +86,8 @@ bool BinaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return true; } -bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool BinaryOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; diff --git a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc index 70ebe18c85b86..4b2f04bed0eb1 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc @@ -21,8 +21,8 @@ class CastOpBuilder : public BaseOpBuilder { // Operator support related. private: - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -86,8 +86,8 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } // Operator support related. -bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool CastOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input_type; diff --git a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc index 48dd6f3beb020..bac528300e077 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc @@ -21,8 +21,8 @@ class ConcatOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override ORT_MUST_USE_RESULT; // Operator support related. - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -42,7 +42,6 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::vector inputs; for (const auto* input : input_defs) { - LOGS(logger, VERBOSE) << "input name " << input->Name(); inputs.push_back(model_builder.GetOperand(input->Name())); } @@ -56,8 +55,8 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -bool ConcatOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool ConcatOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 329db75316e82..81e688ea4f8ea 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -29,8 +29,8 @@ class ConvOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { @@ -311,12 +311,12 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (input_defs.size() >= 3) { x_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); } else { - x_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8); + x_zero_point = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0); } if (input_defs.size() >= 4) { w_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name()); } else { - w_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8); + w_zero_point = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0); } output = model_builder.GetBuilder().call("conv2dInteger", input, x_zero_point, filter, w_zero_point, options); @@ -397,8 +397,8 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } -bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool ConvOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; // input data type diff --git a/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc new file mode 100644 index 0000000000000..be30c5520d62e --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/framework/tensorprotoutils.h" +#include "core/optimizer/initializer.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class CumSumOpBuilder : public BaseOpBuilder { + // Add operator related. + + public: + void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; + + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; +}; + +// Add operator related. + +void CumSumOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + // Skip axis. + model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); +} + +Status CumSumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape"); + const auto input_rank = input_shape.size(); + + const auto& initializers = model_builder.GetInitializerTensors(); + const std::string axis_name = GetTensorName(input_defs, 1); + const auto axis_tensor = *initializers.at(axis_name); + emscripten::val axis = emscripten::val::undefined(); + ORT_RETURN_IF_NOT(ReadScalarTensorData(axis_tensor, axis, logger), "Cannot get axis value"); + int64_t webnn_axis = HandleNegativeAxis(axis.as(), input_rank); + + NodeAttrHelper helper(node); + const auto exclusive = helper.Get("exclusive", 0); + const auto reverse = helper.Get("reverse", 0); + + emscripten::val options = emscripten::val::object(); + options.set("exclusive", exclusive == 1); + options.set("reversed", reverse == 1); + options.set("label", node.Name()); + + emscripten::val output = emscripten::val::object(); + output = model_builder.GetBuilder().call("cumulativeSum", input, gsl::narrow(webnn_axis), + options); + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. +bool CumSumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, + const Node& node, + WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + const std::string axis_name = GetTensorName(input_defs, 1); + // Inputs contain optional 'axis' input. + if (!Contains(initializers, axis_name)) { + LOGS(logger, VERBOSE) << "The axis must be a constant initializer."; + return false; + } + + return true; +} + +void CreateCumSumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc index 5434194a214ac..9bb930c63b009 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc @@ -59,22 +59,14 @@ Status DropoutOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::vector mask_shape; ORT_RETURN_IF_NOT(GetShape(*output_defs[1], mask_shape, logger), "Cannot get mask output's shape"); std::vector dims = GetVecUint32FromVecInt64(mask_shape); - - emscripten::val desc = emscripten::val::object(); - desc.set("dataType", "uint8"); - desc.set("dimensions", emscripten::val::array(dims)); - desc.set("shape", emscripten::val::array(dims)); - const auto num_elements = narrow(Product(mask_shape)); - emscripten::val ones_buffer = emscripten::val::global("Uint8Array").new_(num_elements); - ones_buffer.call("fill", 1); - - emscripten::val mask_output = model_builder.GetBuilder().call("constant", desc, ones_buffer); + emscripten::val one_constant = model_builder.CreateOrGetConstant( + ONNX_NAMESPACE::TensorProto_DataType_BOOL, 1, dims); emscripten::val options = emscripten::val::object(); options.set("label", output_defs[1]->Name() + "_identity"); // Add additional identity op in case the mask is the output of a WebNN graph, // beacuse WebNN does not support a constant operand as output. - mask_output = model_builder.GetBuilder().call("identity", mask_output, options); + emscripten::val mask_output = model_builder.GetBuilder().call("identity", one_constant, options); model_builder.AddOperand(output_defs[1]->Name(), std::move(mask_output)); } return Status::OK(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc new file mode 100644 index 0000000000000..ef713f48b8135 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc @@ -0,0 +1,793 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/framework/tensorprotoutils.h" +#include "core/optimizer/initializer.h" +#include "core/providers/common.h" +#include "core/providers/cpu/tensor/reshape_helper.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class EinsumOpBuilder : public BaseOpBuilder { + // Add operator related. + + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; +}; + +// Helper functions, thanks for DML EP's OperatorHelper. +enum class RecognizedOperatorType { + None, + Identity, + ReduceSum, + Transpose, + Diagonal, + Multiply, + Pairwise, + Total, +}; + +struct RecognizedOperatorInfo { + RecognizedOperatorType recognized_operator_type; + std::initializer_list component_ranks; + std::initializer_list label_indices; +}; + +struct Component { + uint32_t label_index_begin; + uint32_t label_index_end; + + uint32_t GetDimensionCount() const noexcept { + return label_index_end - label_index_begin; + } + gsl::span GetLabels(gsl::span labels) const { + return labels.subspan(label_index_begin, label_index_end - label_index_begin); + } +}; + +bool ParseEquationComponents(const Node& node, + const std::string_view equation, + std::vector& label_indices, + std::vector& components, + std::vector& output_dimensions, + uint32_t& num_labels, + const logging::Logger& logger) { + // Parse an equation like 'ij,jk->ik' into components {ij, jk, ik} mapping letters to + // numeric indices {(0,1}, {1,2}, {0,2}}. The last component is the output. + // Read first to last character in equation, looking for letters, commas, and one arrow. + // The ellipsis is not supported. + std::map label_maps; + std::set repeated_labels; + + num_labels = 0; + Component current_component = {}; + bool at_output = false; + bool end_flag = false; + + for (const char* it = equation.data(); !end_flag; ++it) { + // std::string.data() promises the end of the string is '\0' + char ch = *it; + + if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z')) { + const auto [i, inserted] = label_maps.insert({ch, num_labels}); + if (inserted) { + if (at_output) { + LOGS(logger, VERBOSE) << "Found label in equation output not matching any label from inputs."; + return false; + } + ++num_labels; + } else if (!at_output) { + repeated_labels.insert(ch); + } + label_indices.push_back(i->second); + } else if (ch == ' ') { + continue; + } else { + current_component.label_index_end = static_cast(label_indices.size()); + components.push_back(current_component); + current_component.label_index_begin = current_component.label_index_end; + + switch (ch) { + case ',': + break; + + case '-': + ++it; + if (*it != '>') { + LOGS(logger, VERBOSE) << "Expected '->' for output."; + return false; + } + if (at_output) { + LOGS(logger, VERBOSE) << "Only one output arrow '->' is valid."; + return false; + } + at_output = true; + break; + + case '.': + // Ellipsis is unsupported + LOGS(logger, VERBOSE) << "Ellipsis is unsupported."; + return false; + + case '\0': + end_flag = true; + break; // End of string. + + default: + LOGS(logger, VERBOSE) << "Unsupported character in equation string."; + return false; + } + } + } + + // If no explicit output was given, generate an implicit output by ordering all the + // labels in alphabetic order (by ASCII value consistent with numpy, so Z < a). + // Exclude any labels that occurred more than once, as these cancel out. + if (!at_output) { + for (auto i : label_maps) { + if (repeated_labels.count(i.first) == 0) { + label_indices.push_back(i.second); + } + } + + current_component.label_index_end = static_cast(label_indices.size()); + components.push_back(current_component); + } + return true; +} + +// For two inputs A,B and one output C +Status PairwiseOperandProcess(ModelBuilder& model_builder, + const Node& node, + const std::vector& label_indices, + const std::vector& components, + const std::vector& output_dimensions, + uint32_t num_labels, + emscripten::val& output, + const logging::Logger& logger) { + auto input_a_labels = components[0].GetLabels(label_indices); + auto input_b_labels = components[1].GetLabels(label_indices); + auto output_labels = components[2].GetLabels(label_indices); + + /* + Step 1. Transpose and Reshape + + (0/1,0/1,0/1) means dim i whether appears in (A,B,C) + For new A, it has three segments [...a_1..., a_2, a_3], a_1 has multiple dims, a_2 and a_3 only have one dim respectively + For new B, it has three segments [...b_1..., b_2, b_3], b_1 has multiple dims, b_2 and b_3 only have one dim respectively + a_1 and b_1 are batch dims, and [a_2,a_3], [b_2,b_3] are for matmul + + case (1,0,0) and (0,1,0): reduce, here we treat it as batch dimension, and reduceSum at the end. + add additional dim for B/A + case (1,1,1): batch dimension, put it in the front. + case (1,0,1): gemm dim for A, put it in a_2 + case (0,1,1): gemm dim for B, put it in b_3 + case (1,1,0): summation dim / gemm dim for both A and B, put it in a_3 and b_2 + + Notes: + # of (1,1,0) maybe > 1, flatten / reshape a_3 and b_2 + # of (1,1,0) maybe = 0, add one additional dim for a_3 and b_2 + */ + + // The index in input/output of the dim index + std::map input_a_axes_map, input_b_axes_map, output_axes_map; + + for (uint32_t i = 0; i <= num_labels + 1; ++i) { + input_a_axes_map[i] = input_b_axes_map[i] = output_axes_map[i] = -1; + } + int32_t index = 0; + for (auto axis : input_a_labels) { + input_a_axes_map[axis] = index++; + } + index = 0; + for (auto axis : input_b_labels) { + input_b_axes_map[axis] = index++; + } + index = 0; + for (auto axis : output_labels) { + output_axes_map[axis] = index++; + } + + // Inputs Reshape + // a_0 = [a_1,a_2,a_3], b_0 = [b_1,b_2,b_3] + std::vector a_0, a_1, a_2, a_3, b_0, b_1, b_2, b_3; + uint32_t a_idx = input_a_labels.size(); + uint32_t b_idx = input_b_labels.size(); + bool a_flag = false; // whether a_2 has element + bool b_flag = false; // whether b_3 has element + + for (uint32_t i = 0; i < num_labels; ++i) { + if (input_a_axes_map[i] != -1) { + if (input_b_axes_map[i] != -1) { + if (output_axes_map[i] != -1) { + // The index in input/output of the dim index + a_1.push_back(i); + b_1.push_back(i); + } else { + // (1,1,0) push back in the middle for b and end for a + a_3.push_back(i); + b_2.push_back(i); + } + } else { + // (1,0,x) push back in the middle for a. If more than one, push back in the front for a, b. + if (a_flag) { + a_1.push_back(i); + b_1.push_back(i); + input_b_axes_map[i] = b_idx++; + } else { + a_2.push_back(i); + a_flag = true; + } + } + } else { + // (0,1,x) push back in the end for b. If more than one, push back in the front for a, b. + if (input_b_axes_map[i] != -1) { + if (b_flag) { + a_1.push_back(i); + b_1.push_back(i); + input_a_axes_map[i] = a_idx++; + } else { + b_3.push_back(i); + b_flag = true; + } + } + } + } + + // Matrix multiplication can be formatted in (...,i,j) * (...,j,k) ==> (...,i,k) + // Even inner and outer product can be reformatted as this. + // Inner product (1,i) * (i,1) ==> (1,1) + // Outer product (i,1) * (1,j) ==> (i,j) + // i.e., in our expression, (a_2,a_3) * (b_2,b_3) ==> (a_2,b_3) + + if (!a_flag) { + // Lack of a_2 element, add a new a_2, whose dim value = 1 + a_2.push_back(num_labels + 1); + input_a_axes_map[num_labels + 1] = a_idx++; + } + if (!b_flag) { + // Lack of b_3 element, add a new b_3, whose dim value = 1 + b_3.push_back(num_labels + 2); + input_b_axes_map[num_labels + 2] = b_idx++; + b_idx++; + } + + if (a_3.empty()) { + // Lack of a_3 and b_2 elements, add a new a_3 for A and a new b_2 for B, whose dim value = 1 + a_3.push_back(num_labels); + b_2.push_back(num_labels); + input_a_axes_map[num_labels] = a_idx; + input_b_axes_map[num_labels] = b_idx; + } + + a_0 = a_1; + b_0 = b_1; + a_0.insert(a_0.end(), a_2.begin(), a_2.end()); + a_0.insert(a_0.end(), a_3.begin(), a_3.end()); + b_0.insert(b_0.end(), b_2.begin(), b_2.end()); + b_0.insert(b_0.end(), b_3.begin(), b_3.end()); + + std::vector permutation_a, permutation_b; + for (uint32_t i = 0; i < a_0.size(); ++i) { + permutation_a.push_back(static_cast(input_a_axes_map[a_0[i]])); + permutation_b.push_back(static_cast(input_b_axes_map[b_0[i]])); + } + + const auto& input_defs = node.InputDefs(); + emscripten::val input_a = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val input_b = model_builder.GetOperand(input_defs[1]->Name()); + std::vector new_a_shape, new_b_shape; + if (input_a_labels.size() < a_0.size()) { + std::vector input_a_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_a_shape, logger), "Cannot get shape"); + std::transform(input_a_shape.begin(), input_a_shape.end(), std::back_inserter(new_a_shape), + [](int64_t i) { return static_cast(i); }); + for (uint32_t i = 0; i < a_0.size() - input_a_labels.size(); ++i) { + new_a_shape.push_back(SafeInt(1)); + } + + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name() + "_reshape"); + input_a = model_builder.GetBuilder().call("reshape", + input_a, + emscripten::val::array(new_a_shape), + options); + } + if (input_b_labels.size() < b_0.size()) { + std::vector input_b_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[1], input_b_shape, logger), "Cannot get shape"); + std::transform(input_b_shape.begin(), input_b_shape.end(), std::back_inserter(new_b_shape), + [](int64_t i) { return static_cast(i); }); + for (uint32_t i = 0; i < b_0.size() - input_b_labels.size(); ++i) { + new_b_shape.push_back(SafeInt(1)); + } + + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name() + "_reshape"); + input_b = model_builder.GetBuilder().call("reshape", + input_b, + emscripten::val::array(new_b_shape), + options); + } + + // Inputs Transpose + std::vector sequence(permutation_a.size()); + std::iota(sequence.begin(), sequence.end(), 0); + if (permutation_a != sequence) { + emscripten::val options = emscripten::val::object(); + options.set("permutation", emscripten::val::array(permutation_a)); + options.set("label", node.Name() + "_transpose"); + input_a = model_builder.GetBuilder().call("transpose", input_a, options); + } + if (permutation_b != sequence) { + emscripten::val options = emscripten::val::object(); + options.set("permutation", emscripten::val::array(permutation_b)); + options.set("label", node.Name() + "_transpose"); + input_b = model_builder.GetBuilder().call("transpose", input_b, options); + } + + // Input Reshape: if the number of (1,1,0) > 1, flatten the b_2 and a_3 dims. + if (a_3.size() > 1) { + if (new_a_shape.empty()) { + std::vector input_a_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_a_shape, logger), "Cannot get shape"); + std::transform(input_a_shape.begin(), input_a_shape.end(), std::back_inserter(new_a_shape), + [](int64_t i) { return static_cast(i); }); + } + if (new_b_shape.empty()) { + std::vector input_b_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[1], input_b_shape, logger), "Cannot get shape"); + std::transform(input_b_shape.begin(), input_b_shape.end(), std::back_inserter(new_b_shape), + [](int64_t i) { return static_cast(i); }); + } + std::vector new_new_a_shape, new_new_b_shape; + uint32_t a_dim = 1, b_dim = 1; + for (auto idx : a_1) { + new_new_a_shape.push_back(new_a_shape[idx]); + } + for (auto idx : a_2) { + new_new_a_shape.push_back(new_a_shape[idx]); + } + for (auto idx : a_3) { + a_dim *= new_a_shape[idx]; + } + new_new_a_shape.push_back(a_dim); + for (auto idx : b_1) { + new_new_b_shape.push_back(new_b_shape[idx]); + } + for (auto idx : b_2) { + b_dim *= new_b_shape[idx]; + } + new_new_b_shape.push_back(b_dim); + for (auto idx : b_3) { + new_new_b_shape.push_back(new_b_shape[idx]); + } + + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name() + "_reshape"); + input_a = model_builder.GetBuilder().call("reshape", + input_a, + emscripten::val::array(new_new_a_shape), + options); + input_b = model_builder.GetBuilder().call("reshape", + input_b, + emscripten::val::array(new_b_shape), + options); + } + + // Step 2. Matmul + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name() + "_matmul"); + output = model_builder.GetBuilder().call("matmul", input_a, input_b, options); + std::vector output_indices = a_1; + output_indices.push_back(a_2.back()); + output_indices.push_back(b_3.back()); + + /* + Step 3. Output Transpose: + Use the following fast permutation calculation algorithm + to calculate the permutation of transpose. + sequence x[] -> sequence y[] : permutation p[] + x[s[i]] = i, y[t[i]] = i, p[t[i]] = s[i] + output_indices is x and target_output_indices is y + */ + std::vector target_output_indices(output_labels.begin(), output_labels.end()); + + // map output dim labels to 0 ~ n-1 + std::vector output_indices_sorted(output_indices.begin(), output_indices.end()); + std::map mapping; + std::sort(output_indices_sorted.begin(), output_indices_sorted.end()); + for (size_t i = 0; i < output_indices_sorted.size(); i++) { + mapping[output_indices_sorted[i]] = i; + } + + for (size_t i = 0; i < output_indices.size(); i++) { + output_indices[i] = mapping[output_indices[i]]; + if (i < target_output_indices.size()) { + target_output_indices[i] = mapping[target_output_indices[i]]; + } + } + + uint32_t pad = target_output_indices.size(); + std::vector s(output_indices.size(), -1); + std::vector t(output_indices.size(), -1); + std::vector p(output_indices.size(), 0); + for (uint32_t i = 0; i < output_indices.size(); ++i) { + s[output_indices[i]] = i; + if (i < target_output_indices.size()) { + t[target_output_indices[i]] = i; + } + } + for (uint32_t i = 0; i < output_indices.size(); ++i) { + if (t[i] == -1) { + t[i] = pad++; + } + p[static_cast(t[i])] = static_cast(s[i]); + } + + std::vector sequence_o(output_indices.size()); + std::iota(sequence_o.begin(), sequence_o.end(), 0); + if (p != sequence_o) { + emscripten::val options = emscripten::val::object(); + options.set("permutation", emscripten::val::array(p)); + options.set("label", node.Name() + "_transpose"); + output = model_builder.GetBuilder().call("transpose", output, options); + } + + // Step 4. Output ReduceSum + if (output_labels.size() < output_indices.size()) { + std::vector axes_data; + for (uint32_t i = output_labels.size(); i < output_indices.size(); ++i) { + axes_data.push_back(SafeInt(i)); + } + emscripten::val options_reduce = emscripten::val::object(); + options_reduce.set("axes", emscripten::val::array(axes_data)); + options_reduce.set("label", node.Name() + "_reduceSum"); + output = model_builder.GetBuilder().call("reduceSum", output, options_reduce); + } + return Status::OK(); +} + +RecognizedOperatorType DetermineRecognizedOperatorType(const std::vector& label_indices, + const std::vector& components, + const std::vector& output_dimensions) { + if (components.empty()) return RecognizedOperatorType::None; + + auto equals = [](gsl::span a, gsl::span b) { + return std::equal(a.begin(), a.end(), b.begin(), b.end()); + }; + + std::array component_ranks; + if (components.size() > component_ranks.size()) { + // So far, not support for more than two inputs and one output. + return RecognizedOperatorType::None; + } else if (components.size() == 2) { // one input + auto input_labels = components[0].GetLabels(label_indices); + auto output_labels = components[1].GetLabels(label_indices); + if (input_labels.size() == output_labels.size()) { + if (equals(input_labels, output_labels)) { + // Identity: input labels = output labels + return RecognizedOperatorType::Identity; + } else { + return RecognizedOperatorType::Transpose; + } + } else if (input_labels.size() == input_labels.back() + 1) { + // ReduceSum: There is no repeated character in input. + return RecognizedOperatorType::ReduceSum; + } else if (input_labels.size() == input_labels.back() + 2) { + // Diagonal: One repeated character in input, ii->i / iij->ij / iijk -> ijk. + return RecognizedOperatorType::Diagonal; + } else { + return RecognizedOperatorType::None; + } + } else if (components.size() == 3) { // two inputs + auto input_A_labels = components[0].GetLabels(label_indices); + auto input_B_labels = components[1].GetLabels(label_indices); + auto output_labels = components[2].GetLabels(label_indices); + if (equals(input_A_labels, output_labels) && equals(input_B_labels, output_labels)) { // element-wise product + return RecognizedOperatorType::Multiply; + } + } + + return RecognizedOperatorType::Pairwise; +} + +// Add operator related. + +Status EinsumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + emscripten::val output = emscripten::val::object(); + + NodeAttrHelper helper(node); + const auto equation = helper.Get("equation", std::string(" ")); + + std::vector label_indices; + std::vector components; + std::vector output_dimensions; + uint32_t num_labels; + ORT_RETURN_IF_NOT(ParseEquationComponents(node, equation, label_indices, components, output_dimensions, + num_labels, logger), + "Error parsing equation components."); + + RecognizedOperatorType recognized_operator_type = DetermineRecognizedOperatorType(label_indices, components, + output_dimensions); + + switch (recognized_operator_type) { + case RecognizedOperatorType::Multiply: { + emscripten::val a = model_builder.GetOperand(node.InputDefs()[0]->Name()); + emscripten::val b = model_builder.GetOperand(node.InputDefs()[1]->Name()); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name() + "_mul"); + output = model_builder.GetBuilder().call("mul", a, b, options); + } break; + case RecognizedOperatorType::ReduceSum: { + auto kept_axes = components.back().GetLabels(label_indices); + std::vector reduced_axes; + uint32_t kept_axes_mask = 0; + for (auto axis : kept_axes) { + kept_axes_mask |= (1 << axis); + } + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + for (uint32_t axis = 0, axis_count = static_cast(input_shape.size()); axis < axis_count; ++axis) { + if (~kept_axes_mask & (1 << axis)) { + reduced_axes.push_back(axis); + } + } + + emscripten::val input = model_builder.GetOperand(node.InputDefs()[0]->Name()); + emscripten::val options = emscripten::val::object(); + options.set("keepDimensions", false); + options.set("axes", emscripten::val::array(reduced_axes)); + options.set("label", node.Name() + "_reduceSum"); + + output = model_builder.GetBuilder().call("reduceSum", input, options); + + // transpose output + std::vector output_labels_sorted(kept_axes.begin(), kept_axes.end()); + std::map mapping; + std::sort(output_labels_sorted.begin(), output_labels_sorted.end()); + + auto equals = [](std::vector a, gsl::span b) { + return std::equal(a.begin(), a.end(), b.begin(), b.end()); + }; + if (equals(output_labels_sorted, kept_axes)) { + break; + } + + for (size_t i = 0; i < output_labels_sorted.size(); i++) { + mapping[output_labels_sorted[i]] = i; + } + std::vector permutation; + for (auto idx : kept_axes) { + permutation.push_back(mapping[idx]); + } + emscripten::val options_transpose = emscripten::val::object(); + options.set("permutation", emscripten::val::array(permutation)); + options.set("label", node.Name() + "_transpose"); + output = model_builder.GetBuilder().call("transpose", output, options); + } break; + case RecognizedOperatorType::Diagonal: { + emscripten::val input = model_builder.GetOperand(node.InputDefs()[0]->Name()); + auto input_labels = components[0].GetLabels(label_indices); + auto output_labels = components[1].GetLabels(label_indices); + uint32_t diagonal_idx_1, diagonal_idx_2; + uint32_t permutation_idx = 0; + for (uint32_t idx = 0; idx < input_labels.size(); idx++) { + if (idx != input_labels[idx]) { + diagonal_idx_1 = input_labels[idx]; + diagonal_idx_2 = idx; + break; + } + } + + // tranpose input + std::vector permutation(input_labels.size()); + for (uint32_t idx = 0; idx < input_labels.size(); idx++) { + if (idx != diagonal_idx_1 && idx != diagonal_idx_2) { + permutation[permutation_idx++] = idx; + } + } + permutation[permutation_idx++] = diagonal_idx_1; + permutation[permutation_idx] = diagonal_idx_2; + + emscripten::val options = emscripten::val::object(); + options.set("permutation", emscripten::val::array(permutation)); + options.set("label", node.Name() + "_transpose"); + output = model_builder.GetBuilder().call("transpose", input, options); + + // triu + tril = diagonal + emscripten::val options_trilu = emscripten::val::object(); + options_trilu.set("label", node.Name() + "_triangular"); + output = model_builder.GetBuilder().call("triangular", output, options_trilu); // triu + options_trilu.set("upper", false); + output = model_builder.GetBuilder().call("triangular", output, options_trilu); // tril + + // reducesum to achieve the diagonal values + std::vector input_shape; + std::vector reduced_axes; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + if (input_shape[diagonal_idx_1] > input_shape[diagonal_idx_2]) { + reduced_axes.push_back(input_labels.size() - 2); + } else { + reduced_axes.push_back(input_labels.size() - 1); + } + emscripten::val options_reduce = emscripten::val::object(); + options_reduce.set("keepDimensions", false); + options_reduce.set("axes", emscripten::val::array(reduced_axes)); + options_reduce.set("label", node.Name() + "_reduceSum"); + output = model_builder.GetBuilder().call("reduceSum", output, options_reduce); // triu + + // transpose output + std::vector target_output_indices(output_labels.begin(), output_labels.end()); + std::vector output_indices(permutation.begin(), permutation.end() - 1); + + // Use the fast permutation calculation algorithm mentioned above + std::vector s(output_indices.size(), -1); + std::vector t(output_indices.size(), -1); + std::vector p(output_indices.size(), 0); + for (uint32_t i = 0; i < output_indices.size(); ++i) { + s[output_indices[i]] = i; + t[target_output_indices[i]] = i; + } + for (uint32_t i = 0; i < output_indices.size(); ++i) { + p[static_cast(t[i])] = static_cast(s[i]); + } + + std::vector sequence_o(output_indices.size()); + std::iota(sequence_o.begin(), sequence_o.end(), 0); + if (p != sequence_o) { + emscripten::val options_transpose = emscripten::val::object(); + options.set("permutation", emscripten::val::array(p)); + options.set("label", node.Name() + "_transpose"); + output = model_builder.GetBuilder().call("transpose", output, options); + } + } break; + + case RecognizedOperatorType::Transpose: { + emscripten::val input = model_builder.GetOperand(node.InputDefs()[0]->Name()); + assert(components.front().GetDimensionCount() == components.back().GetDimensionCount()); + // Remap transposed strides using the component labels from input to output. + auto output_labels = components.back().GetLabels(label_indices); + + std::vector permutation{output_labels.begin(), output_labels.end()}; + emscripten::val options = emscripten::val::object(); + options.set("permutation", emscripten::val::array(permutation)); + options.set("label", node.Name() + "_transpose"); + output = model_builder.GetBuilder().call("transpose", input, options); + } break; + + case RecognizedOperatorType::Identity: { + emscripten::val input = model_builder.GetOperand(node.InputDefs()[0]->Name()); + output = input; + } break; + + case RecognizedOperatorType::Pairwise: { + ORT_RETURN_IF_ERROR(PairwiseOperandProcess(model_builder, node, label_indices, components, + output_dimensions, num_labels, output, logger)); + } break; + + default: + break; + } + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +bool EinsumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, + const Node& node, + const WebnnDeviceType device_type, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + + if (input_defs.size() > 2) { + // TODO: Support more than two inputs. + LOGS(logger, VERBOSE) << "EinSum only supports up to two inputs."; + return false; + } + + NodeAttrHelper helper(node); + const auto equation = helper.Get("equation", std::string(" ")); + std::vector label_indices; + std::vector components; + std::vector output_dimensions; + uint32_t num_labels; + + if (!ParseEquationComponents(node, equation, label_indices, components, + output_dimensions, num_labels, logger)) { + LOGS(logger, VERBOSE) << "EinSum input equation is illegal."; + return false; + } + + if (static_cast(input_defs.size()) + 1 != components.size()) { + LOGS(logger, VERBOSE) << "EinSum input tensor count is inconsistent with the equation component count."; + return false; + } + + RecognizedOperatorType recognized_operator_type = DetermineRecognizedOperatorType(label_indices, components, + output_dimensions); + if (recognized_operator_type == RecognizedOperatorType::None) { + LOGS(logger, VERBOSE) << "The equation is not supported in Einsum."; + return false; + } + + return true; +} + +bool EinsumOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + + const auto& op_type = node.OpType(); + int32_t input0_type; + int32_t input1_type; + bool has_input1 = input_defs.size() > 1 && input_defs[1]->Exists(); + + if (!GetType(*input_defs[0], input0_type, logger) || + (has_input1 && !GetType(*input_defs[1], input1_type, logger))) { + return false; + } + + if (has_input1 && input0_type != input1_type) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input data types should be the same."; + return false; + } + + NodeAttrHelper helper(node); + const auto equation = helper.Get("equation", std::string(" ")); + std::vector label_indices; + std::vector components; + std::vector output_dimensions; + uint32_t num_labels; + + if (!ParseEquationComponents(node, equation, label_indices, + components, output_dimensions, num_labels, logger)) { + LOGS(logger, VERBOSE) << "EinSum input equation is illegal."; + return false; + } + + RecognizedOperatorType recognized_operator_type = DetermineRecognizedOperatorType(label_indices, components, + output_dimensions); + + if (recognized_operator_type == RecognizedOperatorType::None) { + LOGS(logger, VERBOSE) << "The equation is not supported in Einsum."; + return false; + } else if (recognized_operator_type == RecognizedOperatorType::Pairwise) { + // Map to WebNN's gemm or matmul + return IsDataTypeSupportedByWebNNOp(op_type, "matmul", input0_type, wnn_limits, "a", "inputs", logger); + } else if (recognized_operator_type == RecognizedOperatorType::ReduceSum) { + return IsDataTypeSupportedByWebNNOp(op_type, "reduceSum", input0_type, wnn_limits, "input", "inputs", logger); + } else { + return IsDataTypeSupportedByWebNNOp(op_type, "identity", input0_type, wnn_limits, "input", "inputs", logger); + } +} + +void CreateEinsumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc index 5e99551fe6e7d..f5e1f59602c5d 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc @@ -88,6 +88,10 @@ bool ExpandOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers LOGS(logger, VERBOSE) << "Cannot get shape."; return false; } + if (std::any_of(new_shape.begin(), new_shape.end(), [](int64_t dimension) { return dimension == 0; })) { + LOGS(logger, VERBOSE) << "WebNN expand does not support new shape with 0 dimension."; + return false; + } std::vector input_shape; if (!GetShape(*input_defs[0], input_shape, logger)) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc new file mode 100644 index 0000000000000..cb7b7de74e121 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class GatherElementsOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; +}; + +// Add operator related. + +Status GatherElementsOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + emscripten::val data = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name()); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + const size_t rank = input_shape.size(); + NodeAttrHelper helper(node); + const uint32_t axis = static_cast(HandleNegativeAxis(helper.Get("axis", 0), rank)); + options.set("axis", axis); + + emscripten::val output = model_builder.GetBuilder().call("gatherElements", data, indices, options); + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +bool GatherElementsOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& data = *node.InputDefs()[0]; + const auto& indices = *node.InputDefs()[1]; + const auto& op_type = node.OpType(); + + int32_t data_type; + int32_t indices_type; + if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) { + return false; + } + + return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); +} + +void CreateGatherElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc new file mode 100644 index 0000000000000..002a1a6a63026 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class GatherNDOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; +}; + +// Add operator related. + +Status GatherNDOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + emscripten::val data = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name()); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + emscripten::val output = model_builder.GetBuilder().call("gatherND", data, indices, options); + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +bool GatherNDOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + NodeAttrHelper helper(node); + if (helper.Get("batch_dims", 0) != 0) { + LOGS(logger, VERBOSE) << "GatherND: WebNN only supports batch_dims 0 (default)"; + return false; + } + + return true; +} + +bool GatherNDOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { + const auto& data = *node.InputDefs()[0]; + const auto& indices = *node.InputDefs()[1]; + const auto& op_type = node.OpType(); + + int32_t data_type; + int32_t indices_type; + if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) { + return false; + } + + return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); +} + +void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc index ae9fe3e3f3bd1..88d22f103cadc 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc @@ -22,8 +22,8 @@ class GatherOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -69,8 +69,8 @@ bool GatherOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ return true; } -bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool GatherOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const auto& op_type = node.OpType(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index 1477530ce1894..5f4e6de8fda98 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -25,8 +25,8 @@ class GemmOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -113,12 +113,12 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (input_defs.size() >= 3) { a_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); } else { - a_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8); + a_zero_point = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0); } if (input_defs.size() >= 4) { b_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name()); } else { - b_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8); + b_zero_point = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0); } output = model_builder.GetBuilder().call("matmulInteger", a, @@ -215,8 +215,8 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializer return true; } -bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool GemmOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; // A data type diff --git a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc index c92fe7366d494..b240e30d38b22 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc @@ -26,8 +26,10 @@ class GruOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; + bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; }; void GruOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { @@ -185,44 +187,68 @@ bool GruOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, c return true; } -bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool GruOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); - int32_t input0_type = 0; // input data type - int32_t input1_type = 0; // weight data type - int32_t input2_type = 0; // recurrentWeight data type - int32_t input3_type = 0; // bias data type - int32_t input4_type = 0; // recurrentBias data type - int32_t input5_type = 0; // initialHiddenState data type - bool has_input3 = input_defs.size() > 3 && input_defs[3]->Exists(); - bool has_input4 = input_defs.size() > 4 && input_defs[4]->Exists(); - bool has_input5 = input_defs.size() > 5 && input_defs[5]->Exists(); - - if (!GetType(*input_defs[0], input0_type, logger) || - !GetType(*input_defs[1], input1_type, logger) || - !GetType(*input_defs[2], input2_type, logger) || - (has_input3 && !GetType(*input_defs[3], input3_type, logger)) || - (has_input4 && !GetType(*input_defs[4], input4_type, logger)) || - (has_input5 && !GetType(*input_defs[5], input5_type, logger))) { + int32_t input_X_type = 0; // input data type + int32_t input_W_type = 0; // weight data type + int32_t input_R_type = 0; // recurrent weight data type + int32_t input_B_type = 0; // bias data type + int32_t input_initial_h_type = 0; // initial hidden state data type + bool has_input_B = input_defs.size() > 3 && input_defs[3]->Exists(); + bool has_input_initial_h = input_defs.size() > 5 && input_defs[5]->Exists(); + + if (!GetType(*input_defs[0], input_X_type, logger) || + !GetType(*input_defs[1], input_W_type, logger) || + !GetType(*input_defs[2], input_R_type, logger) || + (has_input_B && !GetType(*input_defs[3], input_B_type, logger)) || + // input_defs[4] refers to sequence_lens and is a fixed data type of int32. + (has_input_initial_h && !GetType(*input_defs[5], input_initial_h_type, logger))) { return false; } - InlinedVector input_types = {input0_type, input1_type, input2_type}; - if (has_input3) { - input_types.push_back(input3_type); + InlinedVector input_types = {input_X_type, input_W_type, input_R_type}; + if (has_input_B) { + input_types.push_back(input_B_type); } - if (has_input4) { - input_types.push_back(input4_type); - } - if (has_input5) { - input_types.push_back(input5_type); + if (has_input_initial_h) { + input_types.push_back(input_initial_h_type); } if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); + return IsDataTypeSupportedByOp(op_type, input_X_type, wnn_limits, "input", "X", logger); +} + +bool GruOpBuilder::HasSupportedOutputsImpl(const Node& node, + const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& output_defs = node.OutputDefs(); + const auto& op_type = node.OpType(); + int32_t Y_type = 0; + int32_t Y_h_type = 0; + bool has_Y = output_defs.size() > 0 && output_defs[0]->Exists(); + bool has_Y_h = output_defs.size() > 1 && output_defs[1]->Exists(); + + bool Y_supported = has_Y && GetType(*output_defs[0], Y_type, logger); + bool Y_h_supported = has_Y_h && GetType(*output_defs[1], Y_h_type, logger); + + if (Y_supported && !Y_h_supported) { + return IsDataTypeSupportedByOp(op_type, Y_type, wnn_limits, "outputs", "Y", logger); + } else if (!Y_supported && Y_h_supported) { + return IsDataTypeSupportedByOp(op_type, Y_h_type, wnn_limits, "outputs", "Y_h", logger); + } else if (Y_supported && Y_h_supported) { + if (Y_type != Y_h_type) { + LOGS(logger, VERBOSE) << "[GRU] Output data types must be the same."; + return false; + } + return IsDataTypeSupportedByOp(op_type, Y_type, wnn_limits, "outputs", "Y", logger); + } else { + LOGS(logger, VERBOSE) << "[GRU] No output found."; + return false; + } } void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc index ea7f70b4598e6..91910f55f37c7 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc @@ -21,8 +21,8 @@ class LogicalOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -33,28 +33,20 @@ Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons const auto& op_type = node.OpType(); emscripten::val input0 = model_builder.GetOperand(input_defs[0]->Name()); emscripten::val input1 = emscripten::val::undefined(); - if (input_defs.size() > 1) { - input1 = model_builder.GetOperand(input_defs[1]->Name()); - } emscripten::val output = emscripten::val::object(); emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); - if (op_type == "Equal") { - output = model_builder.GetBuilder().call("equal", input0, input1, options); - } else if (op_type == "Greater") { - output = model_builder.GetBuilder().call("greater", input0, input1, options); - } else if (op_type == "GreaterOrEqual") { - output = model_builder.GetBuilder().call("greaterOrEqual", input0, input1, options); - } else if (op_type == "Less") { - output = model_builder.GetBuilder().call("lesser", input0, input1, options); - } else if (op_type == "LessOrEqual") { - output = model_builder.GetBuilder().call("lesserOrEqual", input0, input1, options); - } else if (op_type == "Not") { - output = model_builder.GetBuilder().call("logicalNot", input0, options); + + std::string webnn_op_type; + ORT_RETURN_IF_NOT(GetWebNNOpType(op_type, webnn_op_type), "Cannot get WebNN op type"); + + if (input_defs.size() == 1) { + // Not + output = model_builder.GetBuilder().call(webnn_op_type.c_str(), input0, options); } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "LogicalOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); + input1 = model_builder.GetOperand(input_defs[1]->Name()); + output = model_builder.GetBuilder().call(webnn_op_type.c_str(), input0, input1, options); } model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); @@ -68,16 +60,19 @@ bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali const auto& name = node.Name(); const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); - if (input_defs.size() < 2 && op_type != "Not") { - LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 2 inputs, actual: " - << input_defs.size(); + + size_t expected_input_count = (op_type == "Not") ? 1 : 2; + if (input_defs.size() != expected_input_count) { + LOGS(logger, VERBOSE) << op_type << " [" << name << "] expected input count: " + << expected_input_count << ", actual: " << input_defs.size(); return false; } + return true; } -bool LogicalOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool LogicalOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; @@ -105,12 +100,15 @@ void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& static std::vector op_types = { + "And", "Equal", "Greater", "GreaterOrEqual", "Less", "LessOrEqual", "Not", + "Or", + "Xor", }; op_registrations.builders.push_back(std::make_unique()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc new file mode 100644 index 0000000000000..19f6d6aff8f97 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class LRNOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; +}; + +Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + int32_t input_data_type; + ORT_RETURN_IF_NOT(GetType(*input_defs[0], input_data_type, logger), "Cannot get input type"); + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + const auto node_name = node.Name(); + emscripten::val wnn_builder = model_builder.GetBuilder(); + + NodeAttrHelper helper(node); + const float alpha = helper.Get("alpha", 0.0001f); + const float beta = helper.Get("beta", 0.75f); + const float bias = helper.Get("bias", 1.0f); + const uint32_t size = helper.Get("size", 1); + + // Prepare WebNN constants for alpha, beta, bias attributes. + // Assume T is float, because input_data_type has been limited to float32 and float16 in 'hasSupportedInitsImpl'. + emscripten::val alpha_constant = model_builder.CreateOrGetConstant(input_data_type, alpha); + emscripten::val beta_constant = model_builder.CreateOrGetConstant(input_data_type, beta); + emscripten::val bias_constant = model_builder.CreateOrGetConstant(input_data_type, bias); + emscripten::val pow1_constant = model_builder.CreateOrGetConstant(input_data_type, 2); + + /** + WebNN doesn't support LRN. So decompose it into a series of ops: + X --> Pow --> (Transpose)--> Pad --> AveragePool--> (Transpose) --> Mul --> Add --> Pow --> Div + ^ ^ ^ ^ ^ ^ ^ ^ + | | | | | | | | + Y:2 (0,2,3,1) Kernel:(1,size) (0,3,1,2) B:alpha B:bias B:beta A:input + */ + // + // pow(input, 2) + emscripten::val label_options = emscripten::val::object(); + label_options.set("label", node_name + "_pow1"); + emscripten::val pow1_output = wnn_builder.call("pow", input, pow1_constant, label_options); + + // transpose(pow1_output, permutation=[0, 2, 3, 1]) + // LRN is one of NHWC layout sensitive ops. When preferred layout is NCHW, move dimension 1 to dimension 3 (rightmost). + if (model_builder.GetPreferredLayout() == DataLayout::NCHW) { + std::vector perm{0, 2, 3, 1}; + emscripten::val transpose_options = emscripten::val::object(); + transpose_options.set("label", node_name + "_transpose_rightmost"); + transpose_options.set("permutation", emscripten::val::array(perm)); + pow1_output = + wnn_builder.call("transpose", pow1_output, transpose_options); + } + + // pad(pow1_output, beginning_padding = {0, 0, 0, leading_padding}, ending_padding = {0, 0, 0, trailing_padding}) + // Adding a Pad before averagePool2d and calling AveragePool with pads as 0's. + const uint32_t leading_padding = floor((size - 1) / 2); + const uint32_t trailing_padding = ceil((size - 1) / 2); + std::vector beginning_padding{0, 0, 0, leading_padding}; + std::vector ending_padding{0, 0, 0, trailing_padding}; + emscripten::val pad_options = emscripten::val::object(); + pad_options.set("label", node_name + "_pad"); + emscripten::val pad_output = + wnn_builder.call("pad", pow1_output, emscripten::val::array(beginning_padding), + emscripten::val::array(ending_padding), pad_options); + + // averagePool2d(pad_output, pool_options) + const std::vector kernel_shape = {1, size}; + emscripten::val pool_options = emscripten::val::object(); + pool_options.set("label", node_name + "_averagePool2d"); + pool_options.set("windowDimensions", emscripten::val::array(kernel_shape)); + emscripten::val pool_output = wnn_builder.call("averagePool2d", pad_output, pool_options); + + // transpose(pool_output, permutation=[0, 3, 1, 2]) + // Move dimension 3 back to dimension 1. + if (model_builder.GetPreferredLayout() == DataLayout::NCHW) { + std::vector perm{0, 3, 1, 2}; + emscripten::val transpose_options = emscripten::val::object(); + transpose_options.set("label", node_name + "_transpose_inverse"); + transpose_options.set("permutation", emscripten::val::array(perm)); + pool_output = + wnn_builder.call("transpose", pool_output, transpose_options); + } + + // mul(pool_output, alpha_constant) + label_options.set("label", node_name + "_mul"); + emscripten::val mul_output = + wnn_builder.call("mul", pool_output, alpha_constant, label_options); + + // add(mul_output, bias_constant) + label_options.set("label", node_name + "_add"); + emscripten::val add_output = wnn_builder.call("add", mul_output, bias_constant, label_options); + + // pow(add_output, beta_constant) + label_options.set("label", node_name + "_pow2"); + emscripten::val pow2_output = wnn_builder.call("pow", add_output, beta_constant, label_options); + + // div(input, pow2_output) + label_options.set("label", node_name + "_div"); + emscripten::val div_output = wnn_builder.call("div", input, pow2_output, label_options); + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(div_output)); + return Status::OK(); +} + +// Operator support related. +bool LRNOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, + const Node& node, + const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + const auto input_size = input_shape.size(); + if (input_size != 4) { + LOGS(logger, VERBOSE) << "LRN only supports 4D input shape, input is " + << input_size << "D shape"; + return false; + } + + return true; +} + +void CreateLRNOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc index 6213b039fb2f9..33ba22ac3fb5b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc @@ -25,8 +25,8 @@ class LstmOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -198,8 +198,8 @@ bool LstmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } -bool LstmOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool LstmOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type = 0; // input data type diff --git a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc index e111ca412c6e9..40f94186e9ed6 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc @@ -22,8 +22,8 @@ class MaxMinOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -87,8 +87,8 @@ bool MaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ return true; } -bool MaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool MaxMinOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index a3c6b8fdcea9b..50e49884bdfa9 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -25,8 +25,8 @@ class NormalizationOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -72,7 +72,8 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder } NodeAttrHelper helper(node); - options.set("epsilon", helper.Get("epsilon", 1e-05f)); + const auto epsilon = helper.Get("epsilon", 1e-05f); + options.set("epsilon", epsilon); emscripten::val output = emscripten::val::undefined(); if (op_type == "BatchNormalization") { @@ -84,14 +85,59 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder } output = model_builder.GetBuilder().call("batchNormalization", input, mean, variance, options); - } else if (op_type == "LayerNormalization") { + } else if (op_type == "LayerNormalization" || op_type == "SimplifiedLayerNormalization") { int64_t axis = helper.Get("axis", -1); axis = HandleNegativeAxis(axis, rank); std::vector axes(rank - SafeInt(axis)); std::iota(axes.begin(), axes.end(), axis); - options.set("axes", emscripten::val::array(axes)); - output = model_builder.GetBuilder().call("layerNormalization", input, options); + if (op_type == "LayerNormalization") { + options.set("axes", emscripten::val::array(axes)); + output = model_builder.GetBuilder().call("layerNormalization", input, options); + } else { // SimplifiedLayerNormalization + /** + WebNN doesn't support SimplifiedLayerNormalization. So decompose it into a series of ops: + X --> Pow --> ReduceMean --> Add --> Sqrt --> Div -> Mul + ^ ^ ^ ^ ^ + | | | | | + Y:2 axis B:epsilon A:X A:scale + */ + + int32_t input_type; + ORT_RETURN_IF_NOT(GetType(*input_defs[0], input_type, logger), "Cannot get input type"); + emscripten::val common_options = emscripten::val::object(); + + // Pow + emscripten::val pow_constant = model_builder.CreateOrGetConstant(input_type, 2); + common_options.set("label", node.Name() + "_pow"); + emscripten::val pow = + model_builder.GetBuilder().call("pow", input, pow_constant, common_options); + + // ReduceMean + emscripten::val reduce_options = emscripten::val::object(); + reduce_options.set("axes", emscripten::val::array(axes)); + reduce_options.set("keepDimensions", true); + reduce_options.set("label", node.Name() + "_reduceMean"); + emscripten::val reduce_mean = model_builder.GetBuilder().call("reduceMean", pow, reduce_options); + + // Add + emscripten::val add_constant = model_builder.CreateOrGetConstant(input_type, epsilon); + common_options.set("label", node.Name() + "_add"); + emscripten::val add = + model_builder.GetBuilder().call("add", reduce_mean, add_constant, common_options); + + // Sqrt + common_options.set("label", node.Name() + "_sqrt"); + emscripten::val sqrt = model_builder.GetBuilder().call("sqrt", add, common_options); + + // Div + common_options.set("label", node.Name() + "_div"); + emscripten::val div = model_builder.GetBuilder().call("div", input, sqrt, common_options); + + // Mul + common_options.set("label", node.Name() + "_mul"); + output = model_builder.GetBuilder().call("mul", scale, div, common_options); + } } else if (op_type == "InstanceNormalization") { // WebNN spec only supports 4D input for instanceNormalization. // Supports 3D input by prepending 1 size dimension. @@ -182,7 +228,8 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initi return true; } -bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, +bool NormalizationOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -229,6 +276,7 @@ void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrat "BatchNormalization", "InstanceNormalization", "LayerNormalization", + "SimplifiedLayerNormalization", }; op_registrations.builders.push_back(std::make_unique()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc index 13dee667f6fd9..b71507a871bf6 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc @@ -22,8 +22,8 @@ class QDQOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override ORT_MUST_USE_RESULT; // Operator support related. - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; Status QDQOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -35,85 +35,91 @@ Status QDQOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::vector input_shape; std::vector scale_shape; + std::vector zero_point_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape"); ORT_RETURN_IF_NOT(GetShape(*input_defs[1], scale_shape, logger), "Cannot get scale shape"); int32_t input_type = 0; int32_t output_type = 0; int32_t zero_point_type = 0; + bool has_zero_point = false; ORT_RETURN_IF_NOT(GetType(*input_defs[0], input_type, logger), "Cannot get input data type"); ORT_RETURN_IF_NOT(GetType(*output_defs[0], output_type, logger), "Cannot get output data type"); + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); emscripten::val scale = model_builder.GetOperand(input_defs[1]->Name()); - emscripten::val zero_point = emscripten::val::null(); + if (input_defs.size() == 3 && input_defs[2]->Exists()) { zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); + has_zero_point = true; } else { // DequantizeLinear: x_zero_point's data type equals to input data type // QuantizeLinear: x_zero_point's data type equals to output data type zero_point_type = op_type == "DequantizeLinear" ? input_type : output_type; - zero_point = model_builder.GetZeroConstant(zero_point_type); } - emscripten::val output; + const auto input_rank = input_shape.size(); NodeAttrHelper helper(node); - int32_t axis = helper.Get("axis", 1); int32_t block_size = helper.Get("block_size", 0); - // axis is valid for input shape greater than 1D. - if (input_shape.size() > 1) { - axis = static_cast(HandleNegativeAxis(axis, input_shape.size())); + int32_t axis = helper.Get("axis", 1); + if (axis < 0) { + axis = SafeInt(HandleNegativeAxis(axis, input_rank)); } - // Insert ones before and after the axis dimension for broadcasting of 1D scale tensor. - if (1 == scale_shape.size() && 1 < input_shape.size()) { - std::vector target_shape{static_cast(input_shape[axis])}; + + // For per-axis quantization/dequantization and axis is not equal to input_rank - 1, + // we need to reshape the scale and zero_point tensors to make them broadcastable with the input tensor. + if (scale_shape.size() == 1 && input_rank > 1 && + block_size == 0 && axis != static_cast(input_rank - 1)) { + // Insert ones before and after the axis dimension for broadcasting of scale tensor. + std::vector target_shape{SafeInt(input_shape[axis])}; target_shape.insert(target_shape.begin(), axis, 1); - target_shape.insert(target_shape.end(), input_shape.size() - axis - 1, 1); + target_shape.insert(target_shape.end(), input_rank - axis - 1, 1); + // zero_point has the same shape as the scale tensor. + zero_point_shape = target_shape; emscripten::val reshape_scale_options = emscripten::val::object(); reshape_scale_options.set("label", node.Name() + "_reshape_scale"); scale = model_builder.GetBuilder().call("reshape", scale, emscripten::val::array(target_shape), reshape_scale_options); - emscripten::val reshape_zero_point_options = emscripten::val::object(); - reshape_zero_point_options.set("label", node.Name() + "_reshape_zero_point"); - zero_point = model_builder.GetBuilder().call("reshape", - zero_point, - emscripten::val::array(target_shape), - reshape_zero_point_options); - } - // If block_size is specified, we need to expand the scale and zero_point tensors. - if (block_size > 1) { - emscripten::val concat_scale_inputs = emscripten::val::array(); - emscripten::val concat_zero_point_inputs = emscripten::val::array(); - for (int i = 0; i < block_size; i++) { - concat_scale_inputs.call("push", scale); - concat_zero_point_inputs.call("push", zero_point); + if (has_zero_point) { + // Reshape the zero_point tensor too. + emscripten::val reshape_zero_point_options = emscripten::val::object(); + reshape_zero_point_options.set("label", node.Name() + "_reshape_zero_point"); + zero_point = model_builder.GetBuilder().call("reshape", + zero_point, + emscripten::val::array(target_shape), + reshape_zero_point_options); } + } - emscripten::val concat_scale_options = emscripten::val::object(); - concat_scale_options.set("label", node.Name() + "_concat_scale"); - scale = model_builder.GetBuilder().call("concat", concat_scale_inputs, axis, concat_scale_options); - - emscripten::val concat_zero_point_options = emscripten::val::object(); - concat_zero_point_options.set("label", node.Name() + "_concat_zero_point"); - zero_point = model_builder.GetBuilder().call( - "concat", concat_zero_point_inputs, axis, concat_zero_point_options); + // If zero_point is not provided, create a zero constant with the same shape as the scale tensor. + if (!has_zero_point) { + if (zero_point_shape.empty()) { + // zero_point has the same shape as the scale tensor. + zero_point_shape = GetVecUint32FromVecInt64(scale_shape); + } + // Create a zero constant with the same shape as the scale tensor. + // The zero value has been pre-processed in the CreateOrGetConstant function, + // so the type of T is not relevant here. + zero_point = model_builder.CreateOrGetConstant(zero_point_type, 0, zero_point_shape); } emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); std::string webnn_op_type; ORT_RETURN_IF_NOT(GetWebNNOpType(op_type, webnn_op_type), "Cannot get WebNN op type"); - output = model_builder.GetBuilder().call(webnn_op_type.c_str(), input, scale, zero_point, options); + emscripten::val output = + model_builder.GetBuilder().call(webnn_op_type.c_str(), input, scale, zero_point, options); model_builder.AddOperand(output_defs[0]->Name(), std::move(output)); return Status::OK(); } -bool QDQOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool QDQOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type = 0; // input data type diff --git a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc index 3442afbc2b3cd..00f8cff25ccf5 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc @@ -21,6 +21,8 @@ namespace webnn { class ResizeOpBuilder : public BaseOpBuilder { // Add operator related. public: + // Allow roi and scales potentially being empty inputs that are ignored during processing. + ResizeOpBuilder() : BaseOpBuilder(/*allow empty inputs*/ true) {} void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; private: @@ -267,15 +269,9 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return false; } - // coordinate_transformation_mode - // Spec issue for supporting more coordinate transformation modes: - // https://github.com/webmachinelearning/webnn/issues/270 - const std::string coordinate_transformation_mode = helper.Get("coordinate_transformation_mode", "half_pixel"); - if (coordinate_transformation_mode != "half_pixel") { - LOGS(logger, VERBOSE) << "Resize does not support coordinate_transformation_mode: " - << coordinate_transformation_mode; - return false; - } + // Ignore coordinate_transformation_mode because WebNN only supports half_pixel mode. + // TODO: Validate coordinate_transformation_mode. Related spec issue for supporting attribute coordinate + // transformation modes: https://github.com/webmachinelearning/webnn/issues/270 // exclude_outside const auto exclude_outside = helper.Get("exclude_outside", 0); diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc new file mode 100644 index 0000000000000..8c70525835059 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class ScatterElementsOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; +}; + +// Add operator related. + +Status ScatterElementsOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + emscripten::val data = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name()); + emscripten::val updates = model_builder.GetOperand(input_defs[2]->Name()); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + const size_t rank = input_shape.size(); + NodeAttrHelper helper(node); + const uint32_t axis = static_cast(HandleNegativeAxis(helper.Get("axis", 0), rank)); + options.set("axis", axis); + + emscripten::val output = + model_builder.GetBuilder().call("scatterElements", data, indices, updates, options); + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +bool ScatterElementsOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + NodeAttrHelper helper(node); + if (helper.Get("reduction", "none") != "none") { + LOGS(logger, VERBOSE) << "ScatterElements: WebNN only supports reduction type none (default)"; + return false; + } + + return true; +} + +bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& data = *node.InputDefs()[0]; + const auto& indices = *node.InputDefs()[1]; + const auto& updates = *node.InputDefs()[2]; + const auto& op_type = node.OpType(); + + int32_t data_type; + int32_t indices_type; + int32_t updates_type; + if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger) || + !GetType(updates, updates_type, logger)) { + return false; + } + + if (data_type != updates_type) { + return false; + } + + return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); +} + +void CreateScatterElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc new file mode 100644 index 0000000000000..8089b9706886f --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class ScatterNDOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; +}; + +// Add operator related. + +Status ScatterNDOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + emscripten::val data = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name()); + emscripten::val updates = model_builder.GetOperand(input_defs[2]->Name()); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + emscripten::val output = + model_builder.GetBuilder().call("scatterND", data, indices, updates, options); + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +bool ScatterNDOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + NodeAttrHelper helper(node); + if (helper.Get("reduction", "none") != "none") { + LOGS(logger, VERBOSE) << "ScatterND: WebNN only supports reduction type none (default)"; + return false; + } + + return true; +} + +bool ScatterNDOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& data = *node.InputDefs()[0]; + const auto& indices = *node.InputDefs()[1]; + const auto& updates = *node.InputDefs()[2]; + const auto& op_type = node.OpType(); + + int32_t data_type; + int32_t indices_type; + int32_t updates_type; + if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger) || + !GetType(updates, updates_type, logger)) { + return false; + } + + if (data_type != updates_type) { + return false; + } + + return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); +} + +void CreateScatterNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc index 3f0d633ac888b..41c66038c2694 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc @@ -27,6 +27,8 @@ class SliceOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override ORT_MUST_USE_RESULT; bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; // TODO: Support Slice opset < 10, which uses attributes for starts and ends. int GetMinSupportedOpSet(const Node& /* node */) const override { return 10; } }; @@ -40,8 +42,7 @@ void SliceOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const No } } -Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, - const Node& node, +Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); std::vector input_shape; @@ -49,9 +50,7 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, auto rank = input_shape.size(); NodeAttrHelper helper(node); - emscripten::val inputs = model_builder.GetOperand(input_defs[0]->Name()); - std::vector starts(rank); - std::vector sizes(rank); + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); // Copy the data from the starts/ends/axes/steps initializers. std::vector input_starts; @@ -75,8 +74,7 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& initializers(model_builder.GetInitializerTensors()); const auto& tensor = *initializers.at(input_name); if (!ReadIntArrayFrom1DTensor(tensor, data, logger)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "Data type for starts and ends inputs is not supported in this build."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Data type for starts and ends inputs is not supported in this build."); } return Status::OK(); @@ -88,28 +86,55 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, ORT_RETURN_IF_ERROR( SliceOp::PrepareForComputeHelper(input_starts, input_ends, input_axes, input_steps, compute_metadata)); - std::transform(compute_metadata.starts_.cbegin(), compute_metadata.starts_.cend(), - starts.begin(), - [](int64_t i) { return SafeInt(i); }); - std::transform(compute_metadata.ends_.cbegin(), compute_metadata.ends_.cend(), compute_metadata.starts_.cbegin(), - sizes.begin(), - [](int64_t i, int64_t j) { return SafeInt(i - j); }); + // Check if reverse op is needed. + std::vector reverse_axes; + emscripten::val reverse_output = input; + for (size_t i = 0; i < rank; ++i) { + if (compute_metadata.steps_[i] < 0) { + reverse_axes.push_back(SafeInt(i)); + compute_metadata.steps_[i] = -compute_metadata.steps_[i]; + compute_metadata.starts_[i] = input_shape[i] - 1 - compute_metadata.starts_[i]; + compute_metadata.ends_[i] = input_shape[i] - 1 - compute_metadata.ends_[i]; + } + } + if (!reverse_axes.empty()) { + emscripten::val reverse_options = emscripten::val::object(); + reverse_options.set("axes", emscripten::val::array(reverse_axes)); + reverse_options.set("label", node.Name() + "_reverse"); + reverse_output = model_builder.GetBuilder().call("reverse", input, reverse_options); + } - emscripten::val options = emscripten::val::object(); - options.set("label", node.Name()); - emscripten::val output = model_builder.GetBuilder().call("slice", inputs, - emscripten::val::array(starts), - emscripten::val::array(sizes), - options); + // Check if slice op is needed. + bool is_slice_required = false; + for (size_t i = 0; i < rank; ++i) { + if (compute_metadata.steps_[i] != 1 || compute_metadata.starts_[i] != 0 || + compute_metadata.ends_[i] != input_shape[i]) { + is_slice_required = true; + break; + } + } + + emscripten::val output = reverse_output; + if (is_slice_required) { + std::vector starts = GetVecUint32FromVecInt64(compute_metadata.starts_); + std::vector steps = GetVecUint32FromVecInt64(compute_metadata.steps_); + std::vector sizes(rank); + std::transform(compute_metadata.ends_.cbegin(), compute_metadata.ends_.cend(), compute_metadata.starts_.cbegin(), + sizes.begin(), [](int64_t i, int64_t j) { return SafeInt(i - j); }); + + emscripten::val options = emscripten::val::object(); + options.set("strides", emscripten::val::array(steps)); + options.set("label", node.Name()); + output = model_builder.GetBuilder().call("slice", reverse_output, emscripten::val::array(starts), + emscripten::val::array(sizes), options); + } model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); } -bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, - const Node& node, - const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { +bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& name = node.Name(); const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); @@ -129,39 +154,37 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, // Optional tensors (axes, steps) can be indicated by an empty name, just ignore it. const std::string input_name = GetTensorName(input_defs, i); if (!input_name.empty() && !Contains(initializers, input_name)) { - LOGS(logger, VERBOSE) << "Input [" << input_name << "] of " << op_type - << " [" << name << "] must be known as initializer"; + LOGS(logger, VERBOSE) << "Input [" << input_name << "] of " << op_type << " [" << name + << "] must be known as initializer"; return false; } } - if (input_defs.size() == 5) { // Check steps. - const auto& steps_tensor = *initializers.at(input_defs[4]->Name()); - std::vector unpacked_tensor; - auto status = onnxruntime::utils::UnpackInitializerData(steps_tensor, unpacked_tensor); - if (!status.IsOK()) { - LOGS(logger, ERROR) << "Error while unpacking steps_tensor: " << status.ErrorMessage(); + return true; +} + +bool SliceOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& input = *input_defs[0]; + const auto& op_type = node.OpType(); + int32_t input_type; + if (!GetType(input, input_type, logger)) + return false; + + // If there is step < 0, check data type support of reverse. + if (input_defs.size() > 4 && input_defs[4]->Exists()) { + std::vector steps; + if (!ReadIntArrayFrom1DTensor(*initializers.at(input_defs[4]->Name()), steps, logger)) return false; - } - const auto data_type = steps_tensor.data_type(); - // WebNN doesn't support steps other than 1. - if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT64) { - if (!std::all_of(reinterpret_cast(unpacked_tensor.data()), - reinterpret_cast(unpacked_tensor.data() + unpacked_tensor.size()), - [](int64_t i) { return i == 1; })) { - return false; - } - } else if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT32) { - if (!std::all_of(reinterpret_cast(unpacked_tensor.data()), - reinterpret_cast(unpacked_tensor.data()) + - unpacked_tensor.size() / sizeof(int32_t), - [](int32_t i) { return i == 1; })) { + if (std::any_of(steps.begin(), steps.end(), [](int64_t step) { return step < 0; })) { + if (!IsDataTypeSupportedByWebNNOp(op_type, "reverse", input_type, wnn_limits, "input", "data", logger)) { return false; } } } - return true; + return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger); } void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc index 4c59b694d690a..db10720f72762 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc @@ -28,6 +28,8 @@ class SplitOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; }; // Add operator related. @@ -73,8 +75,8 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // Check that the splits evenly divide. if (split_count > 0 && splits.empty() && input_shape[axis] % split_count != 0) { // Divide inputs into variable size outputs: - splits.insert(splits.end(), split_count - 1, gsl::narrow(input_shape[axis]) / split_count); - splits.insert(splits.end(), gsl::narrow(input_shape[axis]) % split_count); + splits.insert(splits.end(), split_count - 1, narrow(input_shape[axis]) / split_count); + splits.insert(splits.end(), narrow(input_shape[axis]) % split_count); } if (splits.empty()) { @@ -163,6 +165,23 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } +bool SplitOpBuilder::HasSupportedOutputsImpl(const Node& node, + const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& output_defs = node.OutputDefs(); + const auto& op_type = node.OpType(); + int32_t output_type = 0; + + if (GetType(*output_defs[0], output_type, logger)) { + // Chromium has changed the output name of split from 'output' to 'outputs', + // to avoid breaking the existing API, we need to check both names. + std::string wnn_output_name = wnn_limits["split"]["output"].isUndefined() ? "outputs" : "output"; + return IsDataTypeSupportedByOp(op_type, output_type, wnn_limits, wnn_output_name, "outputs", logger); + } + + return false; +} + void CreateSplitOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc index 4b6cf312074ba..c7b3129c0c85b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc @@ -18,8 +18,8 @@ class TernaryOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -46,8 +46,8 @@ Status TernaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons return Status::OK(); } -bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool TernaryOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; // condition data type diff --git a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc index 8e64e98445f03..91af452c64efd 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc @@ -51,6 +51,8 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const output = model_builder.GetBuilder().call("neg", input, options); } else if (op_type == "Reciprocal") { output = model_builder.GetBuilder().call("reciprocal", input, options); + } else if (op_type == "Sign") { + output = model_builder.GetBuilder().call("sign", input, options); } else if (op_type == "Sin") { output = model_builder.GetBuilder().call("sin", input, options); } else if (op_type == "Sqrt") { @@ -82,6 +84,7 @@ void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op "Log", "Neg", "Reciprocal", + "Sign", "Sin", "Sqrt", "Tan", diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 84f8cc4b14665..e8f116d390199 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -88,11 +88,15 @@ Status ModelBuilder::RegisterInitializers() { for (const auto& pair : GetInitializerTensors()) { const auto& tensor = *pair.second; const auto& name = tensor.name(); - // Optional tensors can be indicated by an empty name, just ignore it. - if (name.empty() || Contains(skipped_initializers_, name)) + const auto& shape = tensor.dims(); + + // Ignore the following tensors: + // 1. Empty tensors: optional tensors can be indicated by an empty name. + // 2. Tensors in skipped_initializers_: These are tensors that are not used as WebNN Constants. + // Note: Scalar tensors are excluded because ONNX Runtime will optimize same scalar initializers into one. + if (name.empty() || (Contains(skipped_initializers_, name) && !shape.empty())) continue; - const auto& shape = tensor.dims(); std::vector dims; // When the shape is empty, it is scalar initializer that dims = {}; std::transform(shape.cbegin(), shape.cend(), @@ -380,62 +384,6 @@ void ModelBuilder::AddOperand(const std::string& name, const emscripten::val& op wnn_operands_.insert(std::make_pair(name, operand)); } -// Get the zero scalar constant. -// Workaround for builer.constant(value, type) method since it has not been implemented now. -// https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-constant-value-type -// BTW, the spec is discussing if the builer.constant(value, type) should be dropped at -// https://github.com/webmachinelearning/webnn/issues/475. Fix me according to the spec decision. -const emscripten::val& ModelBuilder::GetZeroConstant(const int32_t& data_type) { - std::string name = "webnn_zero_constant_" + std::to_string(data_type); - // If the operand does not exist, create it. - if (wnn_operands_.find(name) == wnn_operands_.end()) { - emscripten::val desc = emscripten::val::object(); - emscripten::val dims = emscripten::val::array(); - desc.set("dimensions", dims); - desc.set("shape", dims); - emscripten::val zero_buffer = emscripten::val::undefined(); - if (!SetWebnnDataType(desc, data_type)) { - ORT_THROW("Unsupported data type: " + std::to_string(data_type)); - } - - switch (data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_INT4: - case ONNX_NAMESPACE::TensorProto_DataType_UINT4: - case ONNX_NAMESPACE::TensorProto_DataType_UINT8: - zero_buffer = emscripten::val::global("Uint8Array").new_(1); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT8: - zero_buffer = emscripten::val::global("Int8Array").new_(1); - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - zero_buffer = emscripten::val::global("Uint16Array").new_(1); - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - zero_buffer = emscripten::val::global("Float32Array").new_(1); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT32: - zero_buffer = emscripten::val::global("Int32Array").new_(1); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT64: - zero_buffer = emscripten::val::global("BigInt64Array").new_(1); - break; - case ONNX_NAMESPACE::TensorProto_DataType_UINT32: - zero_buffer = emscripten::val::global("Uint32Array").new_(1); - break; - case ONNX_NAMESPACE::TensorProto_DataType_UINT64: - zero_buffer = emscripten::val::global("BigUint64Array").new_(1); - break; - default: - break; - } - - emscripten::val zero_constant = wnn_builder_.call("constant", desc, zero_buffer); - wnn_operands_.insert(std::make_pair(name, zero_constant)); - } - return wnn_operands_.at(name); -} - void ModelBuilder::AddInitializerToSkip(const std::string& tensor_name) { skipped_initializers_.insert(tensor_name); } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index 13937933a0a9c..0fc2fa20670c7 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -11,6 +11,7 @@ #include "core/framework/execution_provider.h" #include "core/providers/webnn/builders/helper.h" +#include #include #include @@ -38,7 +39,11 @@ class ModelBuilder { const emscripten::val& GetOpSupportLimits() const { return wnn_limits_; } void AddOperand(const std::string& name, const emscripten::val& operand); - const emscripten::val& GetZeroConstant(const int32_t& data_type); + + template + const emscripten::val& CreateOrGetConstant(const int32_t& data_type, T value, + const std::vector& shape = {}); + // Use the buffers to persist WebNN allocated data like transposed weight. // It ensures the validity during inference session. std::vector> mem_persist_buffers_; @@ -98,5 +103,120 @@ class ModelBuilder { static const IOpBuilder* GetOpBuilder(const Node& node); }; +// Create or retrieve one of the following: +// - A WebNN constant MLOperand filled with the specified value, data type, and shape. +// - A WebNN scalar constant MLOperand with the specified value and data type. +// For scalar constant, it is workaround for builer.constant(type, value) method since +// it has not been implemented now. +// https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-constant-type-value +// +// This function enforces a mapping between the data_type and the value types: +// - TensorProto_DataType_INT4 <-> int8_t +// - TensorProto_DataType_UINT4 <-> int8_t +// - TensorProto_DataType_BOOL <-> bool +// - TensorProto_DataType_UINT8 <-> uint8_t +// - TensorProto_DataType_INT8 <-> int8_t +// - TensorProto_DataType_FLOAT16 <-> float +// - TensorProto_DataType_FLOAT <-> float +// - TensorProto_DataType_INT32 <-> int32_t +// - TensorProto_DataType_INT64 <-> int64_t +// - TensorProto_DataType_UINT32 <-> uint32_t +// - TensorProto_DataType_UINT64 <-> uint64_t +template +const emscripten::val& ModelBuilder::CreateOrGetConstant(const int32_t& data_type, T value, + const std::vector& shape) { + std::string name = "webnn_constant_" + std::to_string(data_type) + "_" + std::to_string(value); + emscripten::val dims = emscripten::val::array(); + if (!shape.empty()) { + dims = emscripten::val::array(shape); + std::ostringstream name_stream; + name_stream << name; + for (const auto& dim : shape) { + name_stream << "_" << dim; + } + name = name_stream.str(); + } + + // If the operand does not exist, create it. + if (wnn_operands_.find(name) == wnn_operands_.end()) { + emscripten::val desc = emscripten::val::object(); + desc.set("shape", dims); + desc.set("dimensions", dims); + emscripten::val buffer = emscripten::val::undefined(); + if (!SetWebnnDataType(desc, data_type)) { + ORT_THROW("Unsupported data type: " + std::to_string(data_type)); + } + auto num_elements = Product(shape); + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_INT4: + case ONNX_NAMESPACE::TensorProto_DataType_UINT4: + // For WebNN int4 and uint4 tensors are stored in Uint8Array, + // so we need to adjust the number of elements. + num_elements = (num_elements + 1) / 2; + buffer = emscripten::val::global("Uint8Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val(PackInt8ToUint8AsNibble(value, data_type))); + } + break; + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: + buffer = emscripten::val::global("Uint8Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val(value)); + } + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + buffer = emscripten::val::global("Int8Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val(value)); + } + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + buffer = emscripten::val::global("Uint16Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val(PackFloat32ToUint16AsFloat16(value))); + } + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + buffer = emscripten::val::global("Float32Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val(value)); + } + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + buffer = emscripten::val::global("Int32Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val(value)); + } + break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: + buffer = emscripten::val::global("Uint32Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val(value)); + } + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + buffer = emscripten::val::global("BigInt64Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val::global("BigInt")(value)); + } + break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: + buffer = emscripten::val::global("BigUint64Array").new_(num_elements); + if (value) { + buffer.call("fill", emscripten::val::global("BigInt")(value)); + } + break; + default: + break; + } + + const emscripten::val constant = wnn_builder_.call("constant", desc, buffer); + wnn_operands_.insert(std::make_pair(name, constant)); + } + + return wnn_operands_.at(name); +} + } // namespace webnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 8baa4790247ec..6d1c572128b93 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -26,6 +26,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateUnaryOpBuilder("Log", op_registrations); CreateUnaryOpBuilder("Neg", op_registrations); CreateUnaryOpBuilder("Reciprocal", op_registrations); + CreateUnaryOpBuilder("Sign", op_registrations); CreateUnaryOpBuilder("Sin", op_registrations); CreateUnaryOpBuilder("Sqrt", op_registrations); CreateUnaryOpBuilder("Tan", op_registrations); @@ -80,6 +81,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateConcatOpBuilder("Concat", op_registrations); } + { // CumSum + CreateCumSumOpBuilder("CumSum", op_registrations); + } + { // Dropout CreateDropoutOpBuilder("Dropout", op_registrations); } @@ -90,6 +95,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateDynamicQuantizeLinearOpBuilder("DynamicQuantizeLinear", op_registrations); } + { // Einsum + CreateEinsumOpBuilder("Einsum", op_registrations); + } + { // Expand CreateExpandOpBuilder("Expand", op_registrations); } @@ -98,6 +107,14 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateGatherOpBuilder("Gather", op_registrations); } + { // GatherElements + CreateGatherElementsOpBuilder("GatherElements", op_registrations); + } + + { // GatherND + CreateGatherNDOpBuilder("GatherND", op_registrations); + } + { // Flatten CreateFlattenOpBuilder("Flatten", op_registrations); } @@ -113,12 +130,19 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { } { // Logical + CreateLogicalOpBuilder("And", op_registrations); CreateLogicalOpBuilder("Equal", op_registrations); CreateLogicalOpBuilder("Greater", op_registrations); CreateLogicalOpBuilder("GreaterOrEqual", op_registrations); CreateLogicalOpBuilder("Less", op_registrations); CreateLogicalOpBuilder("LessOrEqual", op_registrations); CreateLogicalOpBuilder("Not", op_registrations); + CreateLogicalOpBuilder("Or", op_registrations); + CreateLogicalOpBuilder("Xor", op_registrations); + } + + { // LRN + CreateLRNOpBuilder("LRN", op_registrations); } { // LSTM @@ -134,6 +158,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateNormalizationOpBuilder("BatchNormalization", op_registrations); CreateNormalizationOpBuilder("InstanceNormalization", op_registrations); CreateNormalizationOpBuilder("LayerNormalization", op_registrations); + CreateNormalizationOpBuilder("SimplifiedLayerNormalization", op_registrations); } { // Pad @@ -170,6 +195,14 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateResizeOpBuilder("Resize", op_registrations); } + { // ScatterElements + CreateScatterElementsOpBuilder("ScatterElements", op_registrations); + } + + { // ScatterND + CreateScatterNDOpBuilder("ScatterND", op_registrations); + } + { // Shape CreateShapeOpBuilder("Shape", op_registrations); } diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h index 990be04d42107..22bd6cd0cfa9f 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h @@ -26,14 +26,19 @@ void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_ void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateCumSumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateDropoutOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateDynamicQuantizeLinearOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateEinsumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateGatherElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateLRNOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateLstmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); @@ -43,6 +48,8 @@ void CreateQDQOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_r void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateScatterElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateScatterNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateShapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); diff --git a/onnxruntime/core/providers/xnnpack/detail/utils.cc b/onnxruntime/core/providers/xnnpack/detail/utils.cc index 4eef14dddecd3..2adf8339b4b66 100644 --- a/onnxruntime/core/providers/xnnpack/detail/utils.cc +++ b/onnxruntime/core/providers/xnnpack/detail/utils.cc @@ -5,6 +5,7 @@ #include #include #include +#include #include "core/common/common.h" #include "core/common/safeint.h" @@ -239,8 +240,8 @@ std::unique_ptr FuseActivation(const NodeUnit& node_un def.attributes = node_unit.GetNode().GetAttributes(); // use infinity as the default as that's what xnnpack uses if min/max are not set - float min = -INFINITY; - float max = INFINITY; + float min = -std::numeric_limits::infinity(); + float max = std::numeric_limits::infinity(); const auto& activation_type = activation.OpType(); if (activation_type == "Clip") { diff --git a/onnxruntime/core/providers/xnnpack/math/gemm.cc b/onnxruntime/core/providers/xnnpack/math/gemm.cc index 68b55030c7363..a3ff3b585ae45 100644 --- a/onnxruntime/core/providers/xnnpack/math/gemm.cc +++ b/onnxruntime/core/providers/xnnpack/math/gemm.cc @@ -2,6 +2,9 @@ // Licensed under the MIT License. #include "gemm.h" + +#include + #include "core/framework/transpose_helper.h" #include "core/providers/utils.h" #include "core/providers/xnnpack/xnnpack_init.h" @@ -117,7 +120,6 @@ Gemm::Gemm(const OpKernelInfo& info) : GemmBase(info), XnnpackKernel(info, /*ena } Status Gemm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights*) { is_packed = false; @@ -141,8 +143,8 @@ Status Gemm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr, auto weights_cache = GetWeightsCache(); xnn_status status = xnn_status::xnn_status_uninitialized; struct xnn_operator* p = nullptr; - float foutput_min = clip_min_max_ ? clip_min_max_->first : -INFINITY; - float foutput_max = clip_min_max_ ? clip_min_max_->second : INFINITY; + float foutput_min = clip_min_max_ ? clip_min_max_->first : -std::numeric_limits::infinity(); + float foutput_max = clip_min_max_ ? clip_min_max_->second : std::numeric_limits::infinity(); if (op_compute_type_ == OpComputeType::op_compute_type_fp32) { const float* bias_data = nullptr; if (C_matrix_exists_) { diff --git a/onnxruntime/core/providers/xnnpack/math/gemm.h b/onnxruntime/core/providers/xnnpack/math/gemm.h index d632eef015f9a..954aab0698b9c 100644 --- a/onnxruntime/core/providers/xnnpack/math/gemm.h +++ b/onnxruntime/core/providers/xnnpack/math/gemm.h @@ -23,7 +23,6 @@ class Gemm : protected GemmBase, public XnnpackKernel { static bool IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& graph); Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/providers/xnnpack/math/matmul.cc b/onnxruntime/core/providers/xnnpack/math/matmul.cc index 71a11cb05d9af..f574238195ffd 100644 --- a/onnxruntime/core/providers/xnnpack/math/matmul.cc +++ b/onnxruntime/core/providers/xnnpack/math/matmul.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "matmul.h" +#include #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/xnnpack/xnnpack_init.h" @@ -78,7 +79,6 @@ MatMul::MatMul(const OpKernelInfo& info) : XnnpackKernel(info, /*enable_caches*/ } Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*Not used*/) { is_packed = false; @@ -110,8 +110,8 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, xnn_weights_cache_t weight_cache = nullptr; #endif - float foutput_min = -INFINITY; - float foutput_max = INFINITY; + float foutput_min = -std::numeric_limits::infinity(); + float foutput_max = std::numeric_limits::infinity(); if (op_type_ == OpComputeType::op_compute_type_fp32) { status = xnn_create_fully_connected_nc_f32( shape_broadcast[0], // size_t input_channels, diff --git a/onnxruntime/core/providers/xnnpack/math/matmul.h b/onnxruntime/core/providers/xnnpack/math/matmul.h index 31a8c36ad418b..188cc73189af5 100644 --- a/onnxruntime/core/providers/xnnpack/math/matmul.h +++ b/onnxruntime/core/providers/xnnpack/math/matmul.h @@ -23,7 +23,6 @@ class MatMul : public XnnpackKernel { // Required for checking XNNpack restrictions on ORT side static bool IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& graph); Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/providers/xnnpack/nn/average_pool.cc b/onnxruntime/core/providers/xnnpack/nn/average_pool.cc index 1c8ed556e90d7..1fc941d9f52f6 100644 --- a/onnxruntime/core/providers/xnnpack/nn/average_pool.cc +++ b/onnxruntime/core/providers/xnnpack/nn/average_pool.cc @@ -33,8 +33,8 @@ Status CreateXnnpackKernel(const PoolAttributes& pool_attrs, if (pool_attrs.auto_pad == AutoPadType::SAME_UPPER) { flags |= XNN_FLAG_TENSORFLOW_SAME_PADDING; } - float foutput_min = clip_min_max ? clip_min_max->first : -INFINITY; - float foutput_max = clip_min_max ? clip_min_max->second : INFINITY; + float foutput_min = clip_min_max ? clip_min_max->first : -std::numeric_limits::infinity(); + float foutput_max = clip_min_max ? clip_min_max->second : std::numeric_limits::infinity(); xnn_status status = xnn_status_unsupported_parameter; if (avgpool_type == OpComputeType::op_compute_type_fp32) { status = xnn_create_average_pooling2d_nhwc_f32(input_padding_top, input_padding_right, diff --git a/onnxruntime/core/providers/xnnpack/nn/conv.cc b/onnxruntime/core/providers/xnnpack/nn/conv.cc index f2e697df475da..4e6b308e28ae5 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv.cc @@ -18,7 +18,6 @@ namespace xnnpack { // use PrePack to handle the weight layout change as that's not a simple NCHW -> NHWC transpose Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/) { is_packed = false; diff --git a/onnxruntime/core/providers/xnnpack/nn/conv.h b/onnxruntime/core/providers/xnnpack/nn/conv.h index 762b68c8bd49a..3630aae208d49 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv.h +++ b/onnxruntime/core/providers/xnnpack/nn/conv.h @@ -19,7 +19,6 @@ class Conv : public ConvBase { // use PrePack to handle the weight layout change as that's not a simple NCHW -> NHWC transpose Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; }; diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_base.cc b/onnxruntime/core/providers/xnnpack/nn/conv_base.cc index e0723c0e7690e..458e6000c8d70 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_base.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv_base.cc @@ -54,8 +54,8 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, xnn_status status = xnn_status::xnn_status_uninitialized; p = nullptr; - float foutput_min = clip_min_max ? clip_min_max->first : -INFINITY; - float foutput_max = clip_min_max ? clip_min_max->second : INFINITY; + float foutput_min = clip_min_max ? clip_min_max->first : -std::numeric_limits::infinity(); + float foutput_max = clip_min_max ? clip_min_max->second : std::numeric_limits::infinity(); // with the following IC and OC number, we can cover depthwise and regular conv at the same time // the equation 'IC (group_input_channels) == C ' set up when group_count==1 (regular convolution) // and OC (group_output_channels) follows the same rule. diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc index 5729565b2feb9..b6930a5fc92d1 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc @@ -15,7 +15,6 @@ namespace xnnpack { // use PrePack to handle the weight layout change as that's not a simple NCHW -> NHWC transpose Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/) { is_packed = false; diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.h b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.h index 0313515d10fa1..866b9b6b98365 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.h +++ b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.h @@ -18,7 +18,6 @@ class ConvTranspose : public ConvBase { // use PrePack to handle the weight layout change as that's not a simple NCHW -> NHWC transpose Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; }; diff --git a/onnxruntime/core/providers/xnnpack/nn/max_pool.cc b/onnxruntime/core/providers/xnnpack/nn/max_pool.cc index 6742e51e55082..c828ae9400174 100644 --- a/onnxruntime/core/providers/xnnpack/nn/max_pool.cc +++ b/onnxruntime/core/providers/xnnpack/nn/max_pool.cc @@ -3,6 +3,8 @@ #include "max_pool.h" +#include + #include "core/graph/graph.h" #include "core/providers/utils.h" #include "core/providers/xnnpack/xnnpack_init.h" @@ -168,8 +170,8 @@ MaxPool::MaxPool(const OpKernelInfo& info) auto input_dtype = X_arg.TypeAsProto()->tensor_type().elem_type(); xnn_status status = xnn_status_invalid_state; struct xnn_operator* p = nullptr; - float foutput_min = clip_min_max_ ? clip_min_max_->first : -INFINITY; - float foutput_max = clip_min_max_ ? clip_min_max_->second : INFINITY; + float foutput_min = clip_min_max_ ? clip_min_max_->first : -std::numeric_limits::infinity(); + float foutput_max = clip_min_max_ ? clip_min_max_->second : std::numeric_limits::infinity(); if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { maxpool_type_ = OpComputeType::op_compute_type_fp32; status = xnn_create_max_pooling2d_nhwc_f32(input_padding_top, input_padding_right, diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc index 12e567e7080b3..ee4e7be0f1f49 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc @@ -258,6 +258,7 @@ static void AddComputeCapabilityForEachNodeInNodeUnit( std::vector> XnnpackExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/) const { + const auto& logger = *GetLogger(); std::vector> capabilities; std::shared_ptr registry = GetKernelRegistry(); @@ -268,7 +269,7 @@ std::vector> XnnpackExecutionProvider::GetCap // Get all the NodeUnits in the GraphViewer so we can check if something is in a QDQ node group std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph, logger); // This holds the result of whether a NodeUnit is supported or not, // to prevent nodes in a NodeUnit being checked for multiple times diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 5f929d3760a95..48213e3e3894a 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -117,8 +117,9 @@ Status Environment::CreateAndRegisterAllocator(const OrtMemoryInfo& mem_info, co } // determine if arena should be used - const bool is_arena_requested = mem_info.alloc_type == OrtArenaAllocator; - const bool create_arena = ShouldCpuAllocatorUseArena(is_arena_requested); + const bool create_arena = DoesCpuAllocatorSupportArenaUsage() + ? (mem_info.alloc_type == OrtArenaAllocator) + : false; AllocatorPtr allocator_ptr; // create appropriate DeviceAllocatorRegistrationInfo and allocator based on create_arena diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index e6aafaa1f2283..a60ee500a9898 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -370,86 +370,12 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, // a monotonically increasing session id for use in telemetry session_id_ = global_session_id_.fetch_add(1); -#ifdef _WIN32 - std::lock_guard lock(active_sessions_mutex_); - active_sessions_[global_session_id_++] = this; - - // Register callback for ETW capture state (rundown) for Microsoft.ML.ONNXRuntime provider - callback_ML_ORT_provider_ = onnxruntime::WindowsTelemetry::EtwInternalCallback( - [this](LPCGUID SourceId, - ULONG IsEnabled, - UCHAR Level, - ULONGLONG MatchAnyKeyword, - ULONGLONG MatchAllKeyword, - PEVENT_FILTER_DESCRIPTOR FilterData, - PVOID CallbackContext) { - (void)SourceId; - (void)Level; - (void)MatchAnyKeyword; - (void)MatchAllKeyword; - (void)FilterData; - (void)CallbackContext; - - // Check if this callback is for capturing state - if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) && - ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) { - LogAllSessions(); - } - }); - WindowsTelemetry::RegisterInternalCallback(callback_ML_ORT_provider_); - - // Register callback for ETW start / stop so that LOGS tracing can be adjusted dynamically after session start - auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance(); - callback_ETWSink_provider_ = onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback( - [&etwRegistrationManager, this](LPCGUID SourceId, - ULONG IsEnabled, - UCHAR Level, - ULONGLONG MatchAnyKeyword, - ULONGLONG MatchAllKeyword, - PEVENT_FILTER_DESCRIPTOR FilterData, - PVOID CallbackContext) { - (void)SourceId; - (void)Level; - (void)MatchAnyKeyword; - (void)MatchAllKeyword; - (void)FilterData; - (void)CallbackContext; - - if (logging_manager_ != nullptr) { - auto ortETWSeverity = etwRegistrationManager.MapLevelToSeverity(); - - if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)) != 0 && - IsEnabled == EVENT_CONTROL_CODE_ENABLE_PROVIDER) { - LOGS(*session_logger_, VERBOSE) << "Adding ETW Sink to logger with severity level: " << (ULONG)ortETWSeverity; - logging_manager_->AddSinkOfType( - onnxruntime::logging::SinkType::EtwSink, - []() -> std::unique_ptr { return std::make_unique(); }, - ortETWSeverity); - onnxruntime::logging::LoggingManager::GetDefaultInstance()->AddSinkOfType( - onnxruntime::logging::SinkType::EtwSink, - []() -> std::unique_ptr { return std::make_unique(); }, - ortETWSeverity); - LOGS(*session_logger_, INFO) << "Done Adding ETW Sink to logger with severity level: " << (ULONG)ortETWSeverity; - } - if (IsEnabled == EVENT_CONTROL_CODE_DISABLE_PROVIDER) { - LOGS(*session_logger_, INFO) << "Removing ETW Sink from logger"; - logging_manager_->RemoveSink(onnxruntime::logging::SinkType::EtwSink); - LOGS(*session_logger_, VERBOSE) << "Done Removing ETW Sink from logger"; - } - } - }); - - // Register callback for ETW capture state (rundown) - etwRegistrationManager.RegisterInternalCallback(callback_ETWSink_provider_); - -#endif - SetLoggingManager(session_options, session_env); // The call to InitLogger depends on the final state of session_options_. Hence it should be invoked // after the invocation of FinalizeSessionOptions. InitLogger(logging_manager_); // this sets session_logger_ so that it can be used for logging after this point. - TraceSessionOptions(session_options, false); + TraceSessionOptions(session_options, false, *session_logger_); #if !defined(ORT_MINIMAL_BUILD) // Update the number of steps for the graph transformer manager using the "finalized" session options @@ -575,14 +501,97 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, } telemetry_ = {}; + +#ifdef _WIN32 + std::lock_guard lock(active_sessions_mutex_); + active_sessions_[session_id_] = this; + + // Register callback for ETW capture state (rundown) for Microsoft.ML.ONNXRuntime provider + callback_ML_ORT_provider_ = onnxruntime::WindowsTelemetry::EtwInternalCallback( + [](LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + (void)SourceId; + (void)Level; + (void)MatchAnyKeyword; + (void)MatchAllKeyword; + (void)FilterData; + (void)CallbackContext; + ORT_UNUSED_PARAMETER(SourceId); + ORT_UNUSED_PARAMETER(Level); + ORT_UNUSED_PARAMETER(MatchAnyKeyword); + ORT_UNUSED_PARAMETER(MatchAllKeyword); + ORT_UNUSED_PARAMETER(FilterData); + ORT_UNUSED_PARAMETER(CallbackContext); + + // Check if this callback is for capturing state + if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) && + ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) { + InferenceSession::LogAllSessions(); + } + }); + WindowsTelemetry::RegisterInternalCallback(callback_ML_ORT_provider_); + + // Register callback for ETW start / stop so that LOGS tracing can be adjusted dynamically after session start + auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance(); + callback_ETWSink_provider_ = onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback( + [&etwRegistrationManager, this](LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + ORT_UNUSED_PARAMETER(SourceId); + ORT_UNUSED_PARAMETER(Level); + ORT_UNUSED_PARAMETER(MatchAnyKeyword); + ORT_UNUSED_PARAMETER(MatchAllKeyword); + ORT_UNUSED_PARAMETER(FilterData); + ORT_UNUSED_PARAMETER(CallbackContext); + + if (logging_manager_ != nullptr) { + auto ortETWSeverity = etwRegistrationManager.MapLevelToSeverity(); + + if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)) != 0 && + IsEnabled == EVENT_CONTROL_CODE_ENABLE_PROVIDER) { + LOGS(*session_logger_, VERBOSE) << "Adding ETW Sink to logger with severity level: " << (ULONG)ortETWSeverity; + logging_manager_->AddSinkOfType( + onnxruntime::logging::SinkType::EtwSink, + []() -> std::unique_ptr { return std::make_unique(); }, + ortETWSeverity); + onnxruntime::logging::LoggingManager::GetDefaultInstance()->AddSinkOfType( + onnxruntime::logging::SinkType::EtwSink, + []() -> std::unique_ptr { return std::make_unique(); }, + ortETWSeverity); + LOGS(*session_logger_, INFO) << "Done Adding ETW Sink to logger with severity level: " << (ULONG)ortETWSeverity; + } + if (IsEnabled == EVENT_CONTROL_CODE_DISABLE_PROVIDER) { + LOGS(*session_logger_, INFO) << "Removing ETW Sink from logger"; + logging_manager_->RemoveSink(onnxruntime::logging::SinkType::EtwSink); + LOGS(*session_logger_, VERBOSE) << "Done Removing ETW Sink from logger"; + } + } + }); + + // Register callback for ETW capture state (rundown) + etwRegistrationManager.RegisterInternalCallback(callback_ETWSink_provider_); + +#endif } -void InferenceSession::TraceSessionOptions(const SessionOptions& session_options, bool captureState) { +void InferenceSession::TraceSessionOptions(const SessionOptions& session_options, bool captureState, const logging::Logger& logger) { ORT_UNUSED_PARAMETER(captureState); // Otherwise Linux build error - LOGS(*session_logger_, INFO) << session_options; + LOGS(logger, INFO) << session_options; #ifdef _WIN32 + std::string optimized_model_filepath = ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(session_options.optimized_model_filepath); + std::string profile_file_prefix = ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(session_options.profile_file_prefix); + TraceLoggingWrite(telemetry_provider_handle, "SessionOptions", TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)), @@ -590,11 +599,11 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options TraceLoggingUInt8(static_cast(session_options.execution_mode), "execution_mode"), TraceLoggingUInt8(static_cast(session_options.execution_order), "execution_order"), TraceLoggingBoolean(session_options.enable_profiling, "enable_profiling"), - TraceLoggingString(ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(session_options.optimized_model_filepath).c_str(), "optimized_model_filepath"), + TraceLoggingString(optimized_model_filepath.c_str(), "optimized_model_filepath"), TraceLoggingBoolean(session_options.enable_mem_pattern, "enable_mem_pattern"), TraceLoggingBoolean(session_options.enable_mem_reuse, "enable_mem_reuse"), TraceLoggingBoolean(session_options.enable_cpu_mem_arena, "enable_cpu_mem_arena"), - TraceLoggingString(ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(session_options.profile_file_prefix).c_str(), "profile_file_prefix"), + TraceLoggingString(profile_file_prefix.c_str(), "profile_file_prefix"), TraceLoggingString(session_options.session_logid.c_str(), "session_logid"), TraceLoggingInt8(static_cast(session_options.session_log_severity_level), "session_log_severity_level"), TraceLoggingInt8(static_cast(session_options.session_log_verbosity_level), "session_log_verbosity_level"), @@ -726,10 +735,14 @@ InferenceSession::~InferenceSession() { // Unregister the session and ETW callbacks #ifdef _WIN32 std::lock_guard lock(active_sessions_mutex_); - WindowsTelemetry::UnregisterInternalCallback(callback_ML_ORT_provider_); - logging::EtwRegistrationManager::Instance().UnregisterInternalCallback(callback_ETWSink_provider_); + if (callback_ML_ORT_provider_ != nullptr) { + WindowsTelemetry::UnregisterInternalCallback(callback_ML_ORT_provider_); + } + if (callback_ETWSink_provider_ != nullptr) { + logging::EtwRegistrationManager::Instance().UnregisterInternalCallback(callback_ETWSink_provider_); + } #endif - active_sessions_.erase(global_session_id_); + active_sessions_.erase(session_id_); #ifdef ONNXRUNTIME_ENABLE_INSTRUMENT if (session_activity_started_) @@ -1631,7 +1644,7 @@ Status ApplyOrtFormatModelRuntimeOptimizations( level <= static_cast(session_options.graph_optimization_level); ++level) { const auto transformers = optimizer_utils::GenerateTransformersForMinimalBuild( - static_cast(level), session_options, SatRuntimeOptimizationLoadContext{}, cpu_ep, + static_cast(level), session_options, SatRuntimeOptimizationLoadContext{}, cpu_ep, logger, optimizers_to_disable, intra_op_thread_pool, p_buffered_tensors); for (const auto& transformer : transformers) { @@ -1653,6 +1666,23 @@ static void ResolveMemoryPatternFlags(SessionState& session_state) { } } } + +// This function is called when the session is being initialized. +// For now, this function only checks for invalid combination of DML EP with other EPs. +// TODO: extend this function to check for other invalid combinations of EPs. +common::Status InferenceSession::HasInvalidCombinationOfExecutionProviders() const { + // DML EP is only allowed with CPU EP + bool has_dml_ep = execution_providers_.Get(kDmlExecutionProvider) != nullptr; + if (has_dml_ep) { + const auto& ep_list = execution_providers_.GetIds(); + for (const auto& ep : ep_list) { + if (ep == kDmlExecutionProvider || ep == kCpuExecutionProvider) continue; + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "DML EP can be used with only CPU EP."); + } + } + return Status::OK(); +} + #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(push) // VC++ reports: "Releasing unheld lock 'l' in function 'onnxruntime::InferenceSession::Initialize'". But I don't see anything wrong. @@ -1710,6 +1740,11 @@ common::Status InferenceSession::Initialize() { execution_providers_.SetCpuProviderWasImplicitlyAdded(true); } + // Check for the presence of an invalid combination of execution providers in the session + // For e.g. we don't support DML EP and other GPU EPs to be present in the same session + // This check is placed here because it serves as a common place for all language bindings. + ORT_RETURN_IF_ERROR_SESSIONID_(HasInvalidCombinationOfExecutionProviders()); + // re-acquire mutex std::lock_guard l(session_mutex_); @@ -1805,7 +1840,8 @@ common::Status InferenceSession::Initialize() { ORT_RETURN_IF_ERROR_SESSIONID_(AddPredefinedTransformers(graph_transformer_mgr_, session_options_.graph_optimization_level, minimal_build_optimization_handling, - record_runtime_optimization_produced_op_schema)); + record_runtime_optimization_produced_op_schema, + *session_logger_)); #ifdef USE_DML const IExecutionProvider* dmlExecutionProvider = execution_providers_.Get(kDmlExecutionProvider); @@ -2027,11 +2063,9 @@ common::Status InferenceSession::Initialize() { #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) } - SessionState::PrePackInitializers pre_packed_initializers; ORT_RETURN_IF_ERROR_SESSIONID_( session_state_->FinalizeSessionState(model_location_, kernel_registry_manager_, // need to keep the initializers if saving the optimized model - pre_packed_initializers, !saving_model, saving_ort_format)); @@ -2067,47 +2101,11 @@ common::Status InferenceSession::Initialize() { kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes, "1024")); Graph::OffsetAlignmentInfo align_info; align_info.align_offset = true; - bool save_prepacked_constant_initializers = - session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsSavePrePackedConstantInitializers, "0") == "1" ? true : false; - Graph::PrePackedTensorProtoToSave pre_packed_initializers_tensor_proto; - if (save_prepacked_constant_initializers) { - LOGS(*session_logger_, WARNING) << "Serialize prepacked initializers option has been turn on." - << "Use this option only when run model inference on PC with CPU." - << "Make sure to save and load model in same device as prepack is device specific." - << "Note: this feature in only work with ONNX model format." - << "Process of use this option is like below:" - << "1. Optimize model with external data file with save_prepacked_constant_initializers on:" - << " sample: sess_options.add_session_config_entry('session.save_prepacked_constant_initializers', ' 1 ')" - << " With save_prepacked_constant_initializers option, prepacked initializer will be serialized into data file." - << "2. Load optimized model and external data file in same device, no prepack is need." - << "3. Run inference with optimized model."; - - if (fbs::utils::IsOrtFormatModel(session_options_.optimized_model_filepath)) { - ORT_RETURN_IF_ERROR_SESSIONID_( - ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "Unable to serialize prepacked external constant initializer for ORT format model." - "Please use ONNX format model with save_prepacked_constant_initializers.")); - } - - // convert pre_packed_initializers to tensorproto format and save to external data file - for (const auto& name_item_pair : pre_packed_initializers.pre_packed_initializers_to_save) { - auto initializer_name = name_item_pair.first; - - for (const auto& kernel_name_initializer_item_pair : name_item_pair.second) { - auto kernel_name = kernel_name_initializer_item_pair.first; - auto prepacked_initializer_name = utils::GetPrepackedInitializerName(initializer_name, kernel_name); - - pre_packed_initializers_tensor_proto[initializer_name][kernel_name] = utils::TensorToTensorProto(kernel_name_initializer_item_pair.second, prepacked_initializer_name); - } - } - } ORT_RETURN_IF_ERROR_SESSIONID_(Model::SaveWithExternalInitializers(*model_, session_options_.optimized_model_filepath, optimized_model_external_initializers_file_name, optimized_model_external_initializers_min_size_in_bytes, - align_info, - save_prepacked_constant_initializers, - pre_packed_initializers_tensor_proto)); + align_info)); } } } @@ -2115,7 +2113,7 @@ common::Status InferenceSession::Initialize() { std::vector tuning_results; bool found_tuning_results = false; ORT_RETURN_IF_ERROR_SESSIONID_(inference_session_utils::ParseTuningResultsFromModelMetadata( - model_metadata_, tuning_results, found_tuning_results)); + model_metadata_, tuning_results, found_tuning_results, *session_logger_)); if (found_tuning_results) { ORT_RETURN_IF_ERROR_SESSIONID_(SetTuningResults(tuning_results, /*error_on_invalid*/ false, /*auto_enable*/ true)); } @@ -3236,7 +3234,8 @@ common::Status InferenceSession::AddPredefinedTransformers( GraphTransformerManager& transformer_manager, TransformerLevel graph_optimization_level, MinimalBuildOptimizationHandling minimal_build_optimization_handling, - RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn) const { + RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn, + const logging::Logger& logger) const { const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); for (int i = static_cast(TransformerLevel::Level1); i <= static_cast(TransformerLevel::MaxLevel); i++) { TransformerLevel level = static_cast(i); @@ -3248,7 +3247,7 @@ common::Status InferenceSession::AddPredefinedTransformers( minimal_build_optimization_handling == MinimalBuildOptimizationHandling::ApplyFullBuildOptimizations; if (use_full_build_optimizations) { - return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, + return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, logger, optimizers_to_disable_, GetIntraOpThreadPoolToUse(), session_state_->GetMutableBufferedTensors()); @@ -3260,6 +3259,7 @@ common::Status InferenceSession::AddPredefinedTransformers( record_runtime_optimization_produced_op_schema_fn}} : SatApplyContextVariant{SatDirectApplicationContext{}}; return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep, + logger, optimizers_to_disable_, GetIntraOpThreadPoolToUse(), session_state_->GetMutableBufferedTensors()); @@ -3313,14 +3313,21 @@ void InferenceSession::LogAllSessions() { for (const auto& session_pair : active_sessions_) { InferenceSession* session = session_pair.second; - onnxruntime::Graph& graph = model_->MainGraph(); - bool model_has_fp16_inputs = ModelHasFP16Inputs(graph); - env.GetTelemetryProvider().LogSessionCreation( - session_id_, model_->IrVersion(), model_->ProducerName(), model_->ProducerVersion(), model_->Domain(), - graph.DomainToVersionMap(), graph.Name(), model_->MetaData(), - telemetry_.event_name_, execution_providers_.GetIds(), model_has_fp16_inputs, true); + if (!session) { + continue; + } + + auto model = session->model_; + if (nullptr != model) { + onnxruntime::Graph& graph = model->MainGraph(); + bool model_has_fp16_inputs = ModelHasFP16Inputs(graph); + env.GetTelemetryProvider().LogSessionCreation( + session->session_id_, model->IrVersion(), model->ProducerName(), model->ProducerVersion(), model->Domain(), + graph.DomainToVersionMap(), graph.Name(), model->MetaData(), + session->telemetry_.event_name_, session->execution_providers_.GetIds(), model_has_fp16_inputs, true); + } - TraceSessionOptions(session->session_options_, true); + InferenceSession::TraceSessionOptions(session->session_options_, true, *session->session_logger_); } } #endif diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 424248da793f1..e28ff75345785 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -620,7 +620,7 @@ class InferenceSession { const Environment& session_env); void ConstructorCommon(const SessionOptions& session_options, const Environment& session_env); - + [[nodiscard]] common::Status HasInvalidCombinationOfExecutionProviders() const; [[nodiscard]] common::Status SaveModelMetadata(const onnxruntime::Model& model); #if !defined(ORT_MINIMAL_BUILD) @@ -663,7 +663,7 @@ class InferenceSession { void InitLogger(logging::LoggingManager* logging_manager); - void TraceSessionOptions(const SessionOptions& session_options, bool captureState); + static void TraceSessionOptions(const SessionOptions& session_options, bool captureState, const logging::Logger& logger); [[nodiscard]] common::Status CheckShapes(const std::string& input_name, const TensorShape& input_shape, const TensorShape& expected_shape, const char* input_output_moniker) const; @@ -690,8 +690,9 @@ class InferenceSession { * If we encounter an invalid request, we return an error * back to the user. */ - [[nodiscard]] common::Status ValidateAndParseShrinkArenaString(const std::string& ort_device_list, - /*out*/ InlinedVector& arenas_to_shrink) const; + [[nodiscard]] common::Status ValidateAndParseShrinkArenaString( + const std::string& ort_device_list, + /*out*/ InlinedVector& arenas_to_shrink) const; /* * Performs the shrinkage of arenas requested to be shrunk by the user @@ -700,7 +701,7 @@ class InferenceSession { void ShrinkMemoryArenas(gsl::span arenas_to_shrink); #ifdef _WIN32 - void LogAllSessions(); + static void LogAllSessions(); #endif #if !defined(ORT_MINIMAL_BUILD) @@ -708,7 +709,8 @@ class InferenceSession { GraphTransformerManager& transformer_manager, TransformerLevel graph_optimization_level, MinimalBuildOptimizationHandling minimal_build_optimization_handling, - RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn) const; + RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn, + const logging::Logger& logger) const; common::Status TransformGraph(onnxruntime::Graph& graph, bool saving_model_in_ort_format); diff --git a/onnxruntime/core/session/inference_session_utils.cc b/onnxruntime/core/session/inference_session_utils.cc index 3436eebda3819..8b9de0c604441 100644 --- a/onnxruntime/core/session/inference_session_utils.cc +++ b/onnxruntime/core/session/inference_session_utils.cc @@ -236,7 +236,8 @@ Status JsonConfigParser::ParseRunOptionsFromModelProto(RunOptions& /*run_options Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& metadata, std::vector& results, - bool& key_found) { + bool& key_found, + const logging::Logger& logger) { results.clear(); key_found = false; auto it = metadata.custom_metadata_map.find(kTuningResultsKeys); @@ -245,7 +246,7 @@ Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& met } key_found = true; - LOGS_DEFAULT(INFO) << "Found tuning results in the model file to be used while loading the model"; + LOGS(logger, INFO) << "Found tuning results in the model file to be used while loading the model"; Status status; ORT_TRY { diff --git a/onnxruntime/core/session/inference_session_utils.h b/onnxruntime/core/session/inference_session_utils.h index a0bcdb9013bf0..f297d928f8a0d 100644 --- a/onnxruntime/core/session/inference_session_utils.h +++ b/onnxruntime/core/session/inference_session_utils.h @@ -19,7 +19,9 @@ using json = nlohmann::json; #endif namespace onnxruntime { - +namespace logging { +class Logger; +} namespace inference_session_utils { // need this value to be accessible in all builds in order to report error for attempted usage in a minimal build @@ -60,7 +62,8 @@ class JsonConfigParser { Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& metadata, /*out*/ std::vector& results, - /*out*/ bool& key_found); + /*out*/ bool& key_found, + const logging::Logger& logger); #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/lora_adapters.cc b/onnxruntime/core/session/lora_adapters.cc index a095027a1d417..466edce187a56 100644 --- a/onnxruntime/core/session/lora_adapters.cc +++ b/onnxruntime/core/session/lora_adapters.cc @@ -4,9 +4,10 @@ #include "core/session/lora_adapters.h" #include "lora/adapter_format_utils.h" +#include + #include "core/framework/data_transfer.h" #include "core/framework/error_code_helper.h" -#include "core/framework/execution_provider.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/allocator_adapters.h" #include "core/session/ort_apis.h" @@ -15,15 +16,6 @@ #include "core/providers/cuda/cuda_provider_factory.h" #endif -#ifdef USE_DML -#include "core/session/abi_session_options_impl.h" -#include "core/providers/dml/dml_provider_factory_creator.h" -#include "core/providers/dml/dml_provider_factory.h" -#endif - -#include -#include - namespace onnxruntime { #ifdef USE_CUDA @@ -58,56 +50,28 @@ void LoraAdapter::MemoryMap(const std::filesystem::path& file_path) { InitializeParamsValues(); } -namespace { -struct DataTransfer { - std::unique_ptr ep; +static std::unique_ptr GetDataTransfer(const OrtMemoryInfo& mem_info) { std::unique_ptr data_transfer; - Status CopyTensor(const Tensor& src, Tensor& dst) const { - return data_transfer->CopyTensor(src, dst); - } - Status Sync() const { -#if USE_DML - return ep->Sync(); -#else - return Status::OK(); -#endif - } -}; -} // namespace -static Status GetDataTransfer(const OrtMemoryInfo& mem_info, [[maybe_unused]] DataTransfer& dt) { - ORT_RETURN_IF(strcmp(mem_info.name, onnxruntime::CPU) == 0, "Expecting on device allocator for LoraAdapter"); + if (strcmp(mem_info.name, onnxruntime::CPU) == 0) { + return data_transfer; + } - Status status; if (strcmp(mem_info.name, onnxruntime::CUDA) == 0) { #ifdef USE_CUDA auto* cuda_provider_info = TryGetProviderInfo_CUDA(); if (cuda_provider_info != nullptr) { - dt.data_transfer = cuda_provider_info->CreateGPUDataTransfer(); - } else { - status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA provider could not be loaded"); + data_transfer = cuda_provider_info->CreateGPUDataTransfer(); } -#else - status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA provider is not enabled in this build"); -#endif - } else if (strcmp(mem_info.name, onnxruntime::DML) == 0) { -#ifdef USE_DML - auto ep_factory = onnxruntime::DMLProviderFactoryCreator::Create(ConfigOptions{}, 0, false, false, false); - dt.ep = ep_factory->CreateProvider(); - dt.data_transfer = dt.ep->GetDataTransfer(); -#else - status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "DML provider is not enabled in this build"); #endif - } else { - status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported device allocator"); } - return status; + return data_transfer; } static Status CreateOrtValueOnDevice(const OrtValue& ort_value_mapped, const AllocatorPtr& device_allocator, - const DataTransfer& data_transfer, + const IDataTransfer& data_transfer, OrtValue& out) { OrtValue result; const auto& src = ort_value_mapped.Get(); @@ -123,9 +87,12 @@ void LoraAdapter::InitializeParamsValues() { ORT_THROW("Adapter is not loaded yet."); } - DataTransfer data_transfer; + std::unique_ptr data_transfer; if (device_allocator_) { - ORT_THROW_IF_ERROR(GetDataTransfer(device_allocator_->Info(), data_transfer)); + data_transfer = GetDataTransfer(device_allocator_->Info()); + if (data_transfer == nullptr) { + ORT_THROW("Data transfer is not available for the specified device allocator, it also must not be a CPU allocator"); + } } const auto* params = adapter_->parameters(); @@ -133,12 +100,12 @@ void LoraAdapter::InitializeParamsValues() { std::unordered_map params_values; params_values.reserve(params->size()); // Re-work in two separate loops due to compiler issues - if (device_allocator_) { + if (data_transfer) { for (const auto* param : *params) { auto [name, ort_value] = adapters::utils::CreateOrtValueOverLoraParameter(*param); OrtValue ort_value_ondevice; ORT_THROW_IF_ERROR(CreateOrtValueOnDevice(ort_value, device_allocator_, - data_transfer, ort_value_ondevice)); + *data_transfer, ort_value_ondevice)); Param lora_param(std::move(ort_value), std::move(ort_value_ondevice)); params_values.emplace(std::move(name), std::move(lora_param)); } @@ -150,10 +117,6 @@ void LoraAdapter::InitializeParamsValues() { } } - if (device_allocator_) { - ORT_THROW_IF_ERROR(data_transfer.Sync()); - } - params_values_.swap(params_values); } diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 109445c877786..ca6950af0227a 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2803,12 +2803,15 @@ static constexpr OrtApi ort_api_1_to_21 = { &OrtApis::KernelInfoGetAllocator, &OrtApis::AddExternalInitializersFromFilesInMemory, // End of Version 18 - DO NOT MODIFY ABOVE (see above text for more information) + // End of Version 19 - DO NOT MODIFY ABOVE (see above text for more information) + &OrtApis::CreateLoraAdapter, &OrtApis::CreateLoraAdapterFromArray, &OrtApis::ReleaseLoraAdapter, &OrtApis::RunOptionsAddActiveLoraAdapter, &OrtApis::SetEpDynamicOptions, + // End of Version 20 - DO NOT MODIFY ABOVE (see above text for more information) }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. @@ -2840,6 +2843,8 @@ static_assert(offsetof(OrtApi, GetBuildInfoString) / sizeof(void*) == 254, "Size static_assert(offsetof(OrtApi, KernelContext_GetResource) / sizeof(void*) == 265, "Size of version 16 API cannot change"); static_assert(offsetof(OrtApi, SessionOptionsAppendExecutionProvider_OpenVINO_V2) / sizeof(void*) == 275, "Size of version 17 API cannot change"); static_assert(offsetof(OrtApi, AddExternalInitializersFromFilesInMemory) / sizeof(void*) == 279, "Size of version 18 API cannot change"); +// no additions in version 19 +static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Size of version 20 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: static_assert(std::string_view(ORT_VERSION) == "1.21.0", diff --git a/onnxruntime/core/session/ort_env.cc b/onnxruntime/core/session/ort_env.cc index ef84875df18a3..335ebbf203e7c 100644 --- a/onnxruntime/core/session/ort_env.cc +++ b/onnxruntime/core/session/ort_env.cc @@ -17,6 +17,14 @@ using namespace onnxruntime; using namespace onnxruntime::logging; +#ifdef USE_WEBGPU +namespace onnxruntime { +namespace webgpu { +void CleanupWebGpuContexts(); +} // namespace webgpu +} // namespace onnxruntime +#endif + std::unique_ptr OrtEnv::p_instance_; int OrtEnv::ref_count_ = 0; std::mutex OrtEnv::m_; @@ -26,6 +34,10 @@ OrtEnv::OrtEnv(std::unique_ptr value1) } OrtEnv::~OrtEnv() { +#ifdef USE_WEBGPU + webgpu::CleanupWebGpuContexts(); +#endif + // We don't support any shared providers in the minimal build yet #if !defined(ORT_MINIMAL_BUILD) UnloadSharedProviders(); diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 2c4bffa4fb79f..1444c1976d447 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -279,8 +279,9 @@ struct ProviderHostImpl : ProviderHost { std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, const IExecutionProvider::IKernelLookup& kernel_lookup, - gsl::span tentative_nodes) override { - return onnxruntime::GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes); + gsl::span tentative_nodes, + const logging::Logger& logger) override { + return onnxruntime::GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger); } Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ bool* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } @@ -496,6 +497,7 @@ struct ProviderHostImpl : ProviderHost { void AttributeProto__set_name(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) override { return p->set_name(value); } void AttributeProto__set_type(ONNX_NAMESPACE::AttributeProto* p, ONNX_NAMESPACE::AttributeProto_AttributeType value) override { return p->set_type(value); } ONNX_NAMESPACE::TensorProto* AttributeProto__add_tensors(ONNX_NAMESPACE::AttributeProto* p) override { return p->add_tensors(); } + std::string* AttributeProto__release_s(ONNX_NAMESPACE::AttributeProto* p) override { return p->release_s(); } // GraphProto (wrapped) std::unique_ptr GraphProto__construct() override { return std::make_unique(); } @@ -1057,8 +1059,8 @@ struct ProviderHostImpl : ProviderHost { } std::pair>, std::unordered_map> - QDQ__GetAllNodeUnits(const GraphViewer* graph_viewer) override { - return QDQ::GetAllNodeUnits(*graph_viewer); + QDQ__GetAllNodeUnits(const GraphViewer* graph_viewer, const logging::Logger& logger) override { + return QDQ::GetAllNodeUnits(*graph_viewer, logger); } // Model (wrapped) @@ -1156,8 +1158,8 @@ struct ProviderHostImpl : ProviderHost { // GraphViewer (wrapped) void GraphViewer__operator_delete(GraphViewer* p) override { delete p; } - std::unique_ptr GraphViewer__CreateModel(const GraphViewer* graph_viewer, const logging::Logger& logger) override { - return std::make_unique(graph_viewer->Name(), true, ModelMetaData(), PathString(), + std::unique_ptr GraphViewer__CreateModel(const GraphViewer* graph_viewer, const logging::Logger& logger, const ModelMetaData& metadata = ModelMetaData()) override { + return std::make_unique(graph_viewer->Name(), true, metadata, PathString(), #if !defined(ORT_MINIMAL_BUILD) IOnnxRuntimeOpSchemaRegistryList({graph_viewer->GetSchemaRegistry()}), graph_viewer->DomainToVersionMap(), #else @@ -1212,6 +1214,7 @@ struct ProviderHostImpl : ProviderHost { GraphViewerToProto(*p, graph_proto, include_initializers, include_outer_scope_args, static_cast(execution_order)); } const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const override { return p->GetProducerNode(node_arg_name); } + IOnnxRuntimeOpSchemaCollectionPtr GraphViewer__GetSchemaRegistry(const GraphViewer* p) const override { return p->GetSchemaRegistry(); } // OpKernel (direct) const Node& OpKernel__Node(const OpKernel* p) override { return p->OpKernel::Node(); } @@ -2240,7 +2243,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, new_tensorrt_options.trt_ep_context_file_path = (context_cache_path.size() == 0) ? nullptr : context_cache_path.c_str(); LOGS_DEFAULT(VERBOSE) << "User specified context cache path: " << context_cache_path; - embed_mode = (options->value).config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "1"); + embed_mode = (options->value).config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "0"); if ("1" == embed_mode) { new_tensorrt_options.trt_ep_context_embed_mode = 1; } else if ("0" == embed_mode) { diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 8c512c561ea8c..7fb518cdc05ca 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -155,11 +155,21 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, status = create_not_supported_status(); #endif } else if (strcmp(provider_name, "VitisAI") == 0) { +#ifdef USE_VITISAI status = OrtApis::SessionOptionsAppendExecutionProvider_VitisAI(options, provider_options_keys, provider_options_values, num_keys); +#else + status = create_not_supported_status(); +#endif + } else if (strcmp(provider_name, "CoreML") == 0) { +#if defined(USE_COREML) + options->provider_factories.push_back(CoreMLProviderFactoryCreator::Create(provider_options)); +#else + status = create_not_supported_status(); +#endif } else { ORT_UNUSED_PARAMETER(options); status = OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, - "Unknown provider name. Currently supported values are 'OPENVINO', 'SNPE', 'XNNPACK', 'QNN', 'WEBNN' and 'AZURE'"); + "Unknown provider name. Currently supported values are 'OPENVINO', 'SNPE', 'XNNPACK', 'QNN', 'WEBNN' ,'CoreML', and 'AZURE'"); } return status; @@ -205,15 +215,6 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Nnapi, } #endif -#ifndef USE_TVM -ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tvm, - _In_ OrtSessionOptions* options, _In_ const char* settings) { - ORT_UNUSED_PARAMETER(options); - ORT_UNUSED_PARAMETER(settings); - return CreateNotEnabledStatus("Tvm"); -} -#endif - #ifdef __cplusplus } #endif diff --git a/onnxruntime/core/session/standalone_op_invoker.cc b/onnxruntime/core/session/standalone_op_invoker.cc index 9cbf01946e92b..2706448d831cc 100644 --- a/onnxruntime/core/session/standalone_op_invoker.cc +++ b/onnxruntime/core/session/standalone_op_invoker.cc @@ -314,7 +314,8 @@ class StandAloneKernelContext : public OpKernelContext { AllocatorPtr allocator_; }; // StandAloneKernelContext -onnxruntime::Status CreateOpAttr(const char* name, const void* data, int len, OrtOpAttrType type, OrtOpAttr** op_attr) { +onnxruntime::Status CreateOpAttr(const char* name, const void* data, int len, OrtOpAttrType type, + OrtOpAttr** op_attr) { auto attr = std::make_unique(); onnxruntime::Status status = onnxruntime::Status::OK(); attr->set_name(std::string{name}); @@ -410,7 +411,9 @@ onnxruntime::Status CreateOp(_In_ const OrtKernelInfo* info, node_ptr->SetSinceVersion(version); - auto status = kernel_registry->TryFindKernel(*node_ptr, ep->Type(), type_constraint_map, &kernel_create_info); + auto status = kernel_registry->TryFindKernel(*node_ptr, ep->Type(), type_constraint_map, + logging::LoggingManager::DefaultLogger(), // no other logger available + &kernel_create_info); ORT_RETURN_IF_ERROR(status); auto& kernel_def = kernel_create_info->kernel_def; diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index f4f10dc4b4b97..d05fba192820a 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -26,6 +26,8 @@ def get_ort_device_type(device_type: str, device_index) -> C.OrtDevice: return C.OrtDevice.cpu() elif device_type == "dml": return C.OrtDevice.dml() + elif device_type == "webgpu": + return C.OrtDevice.webgpu() elif device_type == "ort": return C.get_ort_device(device_index).device_type() else: diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index ebb1a54facbeb..92396bb09bd4c 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -226,7 +226,7 @@ AllocatorPtr GetDmlAllocator(OrtDevice::DeviceId id) { auto dml_device = onnxruntime::DMLProviderFactoryCreator::CreateDMLDevice(d3d12_device.Get()); ORT_THROW_IF_FAILED(d3d12_device->SetPrivateDataInterface(dml_device_guid, dml_device.Get())); - context = wil::MakeOrThrow(d3d12_device.Get(), dml_device.Get(), cmd_queue.Get(), true); + context = wil::MakeOrThrow(d3d12_device.Get(), dml_device.Get(), cmd_queue.Get(), true, true); ORT_THROW_IF_FAILED(d3d12_device->SetPrivateDataInterface(dml_execution_context_guid, context.Get())); } @@ -291,7 +291,7 @@ void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes) { const std::unordered_map* GetDmlToHostMemCpyFunction() { static std::unordered_map map{ - {OrtDevice::GPU, DmlToCpuMemCpy}}; + {OrtDevice::DML, DmlToCpuMemCpy}}; return ↦ } diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc index 18785cd607eaa..6a57fc5f900ae 100644 --- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc @@ -96,16 +96,22 @@ void addOrtValueMethods(pybind11::module& m) { // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in CUDA CreateGenericMLValue(nullptr, GetRocmAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToRocmMemCpy); -#elif USE_DML - // InputDeflist is null because OrtValue creation is not tied to a specific model - // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) - // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in DML - CreateGenericMLValue( - nullptr, GetDmlAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToDmlMemCpy); #else - throw std::runtime_error( - "Can't allocate memory on the CUDA device using this package of OnnxRuntime. " - "Please use the CUDA package of OnnxRuntime to use this feature."); + throw std::runtime_error( + "Can't allocate memory on the CUDA device using this package of OnnxRuntime. " + "Please use the CUDA package of OnnxRuntime to use this feature."); +#endif + } else if (device.Type() == OrtDevice::DML) { +#if USE_DML + // InputDeflist is null because OrtValue creation is not tied to a specific model + // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) + // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in DML + CreateGenericMLValue( + nullptr, GetDmlAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToDmlMemCpy); +#else + throw std::runtime_error( + "Can't allocate memory on the CUDA device using this package of OnnxRuntime. " + "Please use the CUDA package of OnnxRuntime to use this feature."); #endif } else if (device.Type() == OrtDevice::NPU) { #ifdef USE_CANN @@ -116,9 +122,9 @@ void addOrtValueMethods(pybind11::module& m) { CreateGenericMLValue(nullptr, GetCannAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToCannMemCpy); #else - throw std::runtime_error( - "Can't allocate memory on the CANN device using this package of OnnxRuntime. " - "Please use the CANN package of OnnxRuntime to use this feature."); + throw std::runtime_error( + "Can't allocate memory on the CANN device using this package of OnnxRuntime. " + "Please use the CANN package of OnnxRuntime to use this feature."); #endif } else { throw std::runtime_error("Unsupported device: Cannot place the OrtValue on this device"); @@ -160,19 +166,24 @@ void addOrtValueMethods(pybind11::module& m) { } onnxruntime::python::CopyDataToTensor( - py_values, - values_type, - *(ml_value->GetMutable()), - CpuToRocmMemCpy); -#elif USE_DML + py_values, + values_type, + *(ml_value->GetMutable()), + CpuToRocmMemCpy); +#else + throw std::runtime_error( + "Unsupported GPU device: Cannot find the supported GPU device."); +#endif + } else if (device.Type() == OrtDevice::DML) { +#if USE_DML onnxruntime::python::CopyDataToTensor( - py_values, - values_type, - *(ml_value->GetMutable()), - CpuToDmlMemCpy); + py_values, + values_type, + *(ml_value->GetMutable()), + CpuToDmlMemCpy); #else - throw std::runtime_error( - "Unsupported GPU device: Cannot find the supported GPU device."); + throw std::runtime_error( + "Unsupported GPU device: Cannot find the supported GPU device."); #endif } else { throw std::runtime_error("Unsupported device: Cannot update the OrtValue on this device"); diff --git a/onnxruntime/python/onnxruntime_pybind_schema.cc b/onnxruntime/python/onnxruntime_pybind_schema.cc index 1319e8f6fe959..958da26f4faf0 100644 --- a/onnxruntime/python/onnxruntime_pybind_schema.cc +++ b/onnxruntime/python/onnxruntime_pybind_schema.cc @@ -69,11 +69,14 @@ void addGlobalSchemaFunctions(pybind11::module& m) { #ifdef USE_NNAPI onnxruntime::NnapiProviderFactoryCreator::Create(0, std::optional()), #endif +#ifdef USE_VSINPU + onnxruntime::VSINPUProviderFactoryCreator::Create(), +#endif #ifdef USE_RKNPU onnxruntime::RknpuProviderFactoryCreator::Create(), #endif #ifdef USE_COREML - onnxruntime::CoreMLProviderFactoryCreator::Create(0), + onnxruntime::CoreMLProviderFactoryCreator::Create(ProviderOptions{}), #endif #ifdef USE_XNNPACK onnxruntime::XnnpackProviderFactoryCreator::Create(ProviderOptions{}, nullptr), diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 7af659851e4f8..9d544c0cee9ed 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -288,11 +288,9 @@ const char* GetDeviceName(const OrtDevice& device) { case OrtDevice::CPU: return CPU; case OrtDevice::GPU: -#ifdef USE_DML - return DML; -#else return CUDA; -#endif + case OrtDevice::DML: + return DML; case OrtDevice::FPGA: return "FPGA"; case OrtDevice::NPU: @@ -1127,16 +1125,6 @@ std::unique_ptr CreateExecutionProviderInstance( LOGS_DEFAULT(WARNING) << "Failed to create " << type << ". Please refer https://onnxruntime.ai/docs/execution-providers/OpenVINO-ExecutionProvider.html#requirements to ensure all dependencies are met."; } } -#endif - } else if (type == kTvmExecutionProvider) { -#if USE_TVM - onnxruntime::tvm::TvmEPOptions info{}; - const auto it = provider_options_map.find(type); - if (it != provider_options_map.end()) { - info = onnxruntime::tvm::TvmEPOptionsHelper::FromProviderOptions(it->second); - } - - return onnxruntime::TVMProviderFactoryCreator::Create(info)->CreateProvider(); #endif } else if (type == kVitisAIExecutionProvider) { #ifdef USE_VITISAI @@ -1192,6 +1180,10 @@ std::unique_ptr CreateExecutionProviderInstance( const auto partitioning_stop_ops_list = session_options.config_options.GetConfigEntry( kOrtSessionOptionsConfigNnapiEpPartitioningStopOps); return onnxruntime::NnapiProviderFactoryCreator::Create(0, partitioning_stop_ops_list)->CreateProvider(); +#endif + } else if (type == kVSINPUExecutionProvider) { +#ifdef USE_VSINPU + return onnxruntime::VSINPUProviderFactoryCreator::Create()->CreateProvider(); #endif } else if (type == kRknpuExecutionProvider) { #ifdef USE_RKNPU @@ -1224,6 +1216,9 @@ std::unique_ptr CreateExecutionProviderInstance( if (flags_str.find("COREML_FLAG_CREATE_MLPROGRAM") != std::string::npos) { coreml_flags |= COREMLFlags::COREML_FLAG_CREATE_MLPROGRAM; } + } else { + // read from provider_options + return onnxruntime::CoreMLProviderFactoryCreator::Create(options)->CreateProvider(); } } @@ -1579,7 +1574,8 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra .def_static("cann", []() { return OrtDevice::NPU; }) .def_static("fpga", []() { return OrtDevice::FPGA; }) .def_static("npu", []() { return OrtDevice::NPU; }) - .def_static("dml", []() { return OrtDevice::GPU; }) + .def_static("dml", []() { return OrtDevice::DML; }) + .def_static("webgpu", []() { return OrtDevice::GPU; }) .def_static("default_memory", []() { return OrtDevice::MemType::DEFAULT; }); py::class_ ort_arena_cfg_binding(m, "OrtArenaCfg"); diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index 225931533615d..995341b0f8dc0 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -24,7 +24,7 @@ struct OrtStatus { char msg[1]; // a null-terminated string }; -#define BACKEND_DEVICE BACKEND_PROC BACKEND_DNNL BACKEND_OPENVINO BACKEND_TVM BACKEND_OPENBLAS BACKEND_MIGRAPHX BACKEND_ACL BACKEND_ARMNN BACKEND_DML BACKEND_CANN +#define BACKEND_DEVICE BACKEND_PROC BACKEND_DNNL BACKEND_OPENVINO BACKEND_OPENBLAS BACKEND_MIGRAPHX BACKEND_ACL BACKEND_ARMNN BACKEND_DML BACKEND_CANN BACKEND_WEBGPU #include "core/session/onnxruntime_cxx_api.h" #include "core/providers/providers.h" #include "core/providers/provider_factory_creators.h" @@ -75,12 +75,6 @@ struct OrtStatus { #define BACKEND_OPENVINO "" #endif -#ifdef USE_TVM -#define BACKEND_TVM "-TVM" -#else -#define BACKEND_TVM "" -#endif - #if USE_OPENBLAS #define BACKEND_OPENBLAS "-OPENBLAS" #else @@ -111,6 +105,12 @@ struct OrtStatus { #define BACKEND_CANN "" #endif +#if USE_WEBGPU +#define BACKEND_WEBGPU "-WEBGPU" +#else +#define BACKEND_WEBGPU "" +#endif + #ifdef USE_CUDA #include "core/providers/cuda/cuda_provider_factory.h" #include "core/providers/cuda/cuda_execution_provider_info.h" @@ -135,9 +135,6 @@ extern std::string openvino_device_type; } } // namespace onnxruntime #endif -#ifdef USE_TVM -#include "core/providers/tvm/tvm_ep_options.h" -#endif #ifdef USE_ACL #include "core/providers/acl/acl_provider_factory.h" #endif @@ -438,15 +435,12 @@ std::shared_ptr CreateExecutionProviderFactory_MIGrap std::shared_ptr CreateExecutionProviderFactory_MIGraphX(int device_id); std::shared_ptr CreateExecutionProviderFactory_Cuda(const OrtCUDAProviderOptions* params); std::shared_ptr CreateExecutionProviderFactory_Dnnl(const OrtDnnlProviderOptions* params); -#ifdef USE_TVM -std::shared_ptr CreateExecutionProviderFactory_Tvm(const tvm::TvmEPOptions& info); -std::shared_ptr CreateExecutionProviderFactory_Tvm(const char* params); -#endif std::shared_ptr CreateExecutionProviderFactory_ACL(bool enable_fast_math); std::shared_ptr CreateExecutionProviderFactory_ArmNN(int use_arena); std::shared_ptr CreateExecutionProviderFactory_DML(int device_id); std::shared_ptr CreateExecutionProviderFactory_Nnapi( uint32_t flags, const optional& partitioning_stop_ops_list); +std::shared_ptr CreateExecutionProviderFactory_VSINPU(); std::shared_ptr CreateExecutionProviderFactory_Rknpu(); std::shared_ptr CreateExecutionProviderFactory_CoreML(uint32_t flags); constexpr const char* kDefaultExecutionProviderEntry = "GetProvider"; diff --git a/onnxruntime/python/providers/tvm/__init__.py b/onnxruntime/python/providers/tvm/__init__.py deleted file mode 100644 index 4bcbc0bfef586..0000000000000 --- a/onnxruntime/python/providers/tvm/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -JIT interface implementing packed functions that -import and compile frontend models -""" -from .ort import ANSOR_TYPE, AUTO_TVM_TYPE, onnx_compile # noqa: F401 diff --git a/onnxruntime/python/providers/tvm/extend_python_file.py b/onnxruntime/python/providers/tvm/extend_python_file.py deleted file mode 100644 index 65902619f8150..0000000000000 --- a/onnxruntime/python/providers/tvm/extend_python_file.py +++ /dev/null @@ -1,54 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import argparse -import textwrap - - -def rewrite_target_file(target): - with open(target, "a") as f: - f.write( - textwrap.dedent( - """ - import warnings - - try: - # This import is necessary in order to delegate the loading of libtvm.so to TVM. - import tvm - except ImportError as e: - warnings.warn( - f"WARNING: Failed to import TVM, libtvm.so was not loaded. More details: {e}" - ) - try: - # Working between the C++ and Python parts in TVM EP is done using the PackedFunc and - # Registry classes. In order to use a Python function in C++ code, it must be registered in - # the global table of functions. Registration is carried out through the JIT interface, - # so it is necessary to call special functions for registration. - # To do this, we need to make the following import. - import onnxruntime.providers.tvm - except ImportError as e: - warnings.warn( - f"WARNING: Failed to register python functions to work with TVM EP. More details: {e}" - ) - """ - ) - ) - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--target_file", - type=str, - required=True, - help="Path to the file to be expanded.", - ) - args = parser.parse_args() - rewrite_target_file(args.target_file) - - -if __name__ == "__main__": - main() diff --git a/onnxruntime/python/providers/tvm/ort.py b/onnxruntime/python/providers/tvm/ort.py deleted file mode 100644 index be6d23f39c532..0000000000000 --- a/onnxruntime/python/providers/tvm/ort.py +++ /dev/null @@ -1,140 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import collections -import copy -import logging -import os - -import onnx -import tvm -from tvm import auto_scheduler, autotvm, relay -from tvm.contrib import graph_executor -from tvm.relay import vm - -log = logging.getLogger("tvm_ep") - -ANSOR_TYPE = "Ansor" -AUTO_TVM_TYPE = "AutoTVM" - - -@tvm.register_func("tvm_onnx_import_and_compile") -def onnx_compile( - model_string, - model_path, - executor, - target, - target_host, - opt_level, - opset, - freeze_params, - input_shapes, - nhwc=False, - tuning_logfile="", - tuning_type=AUTO_TVM_TYPE, -): - def get_tvm_executor(irmod, executor, target, params): - if executor == "vm": - log.info("Build TVM virtual machine") - lib = vm.compile( - copy.deepcopy(irmod), - target, - params=params, - ) - elif executor == "graph": - log.info("Build TVM graph executor") - lib = relay.build(irmod, target=target, params=params) - else: - log.error(f'Executor type {executor} is unsupported. Only "vm" and "graph" types are supported') - return None - return lib - - model = onnx.load_model_from_string(bytes(model_string)) - if model_path: - base_dir = os.path.dirname(os.path.abspath(model_path)) - onnx.load_external_data_for_model(model, base_dir) - - # Collect only feed input names from all input names - all_input_names = [node.name for node in model.graph.input] - all_initializer = [node.name for node in model.graph.initializer] - net_feed_input_names = list(set(all_input_names) - set(all_initializer)) - - # Match names and input shapes - all_input_mapping = [(name, shape) for (name, shape) in zip(all_input_names, input_shapes)] - # Using an ordereddict maintains input ordering. - shape_dict = collections.OrderedDict(all_input_mapping) - # Get only feed input pairs - feed_shape_dict = {} - for name in net_feed_input_names: - feed_shape_dict[name] = shape_dict[name] - - irmod, params = relay.frontend.from_onnx(model, feed_shape_dict, opset=opset, freeze_params=freeze_params) - irmod = relay.transform.DynamicToStatic()(irmod) - - # Tuning file can be set by client through ep options - if not tuning_logfile: - tuning_logfile = os.getenv("AUTOTVM_TUNING_LOG") - lib = None - tvm_target = tvm.target.Target(target, host=target_host) - if tuning_logfile: - if tuning_type == ANSOR_TYPE: - desired_layouts = { - "nn.conv2d": ["NHWC", "default"], - "nn.conv2d_transpose": ["NHWC", "default"], - "nn.upsampling": ["NHWC", "default"], - "vision.roi_align": ["NHWC", "default"], - } - log.info("Use tuning file from %s: %s", ANSOR_TYPE, tuning_logfile) - with auto_scheduler.ApplyHistoryBest(tuning_logfile): # noqa: SIM117 - with tvm.transform.PassContext( - opt_level=opt_level, - config={ - "relay.backend.use_auto_scheduler": True, - "relay.FuseOps.max_depth": 30, - }, - ): - if nhwc: - seq = tvm.transform.Sequential( - [ - relay.transform.InferType(), - relay.transform.ConvertLayout(desired_layouts), - relay.transform.EliminateCommonSubexpr(), - relay.transform.FoldConstant(), - ] - ) - irmod = seq(irmod) - lib = get_tvm_executor(irmod, executor, tvm_target, params) - elif tuning_type == AUTO_TVM_TYPE: - with relay.build_config(opt_level=opt_level): - log.info("Use tuning file from %s: %s", AUTO_TVM_TYPE, tuning_logfile) - with autotvm.apply_history_best(tuning_logfile): - lib = get_tvm_executor(irmod, executor, tvm_target, params) - else: - log.error( - f"Tuning log type {tuning_type} is unsupported. " - f"Only {ANSOR_TYPE} and {AUTO_TVM_TYPE} types are supported" - ) - return None - else: - with tvm.transform.PassContext(opt_level=opt_level): - lib = get_tvm_executor(irmod, executor, tvm_target, params) - - if lib is None: - return None - - ctx = tvm.device(target, 0) - if executor == "vm": - m = tvm.runtime.vm.VirtualMachine(lib, ctx) - elif executor == "graph": - m = graph_executor.GraphModule(lib["default"](ctx)) - else: - print( - f"ERROR: Executor type {executor} is unsupported. ", - 'Only "vm" and "graph" types are supported', - ) - return None - - return m.module diff --git a/onnxruntime/python/tools/quantization/__init__.py b/onnxruntime/python/tools/quantization/__init__.py index 9d397499d45a4..712e15a6a1ca9 100644 --- a/onnxruntime/python/tools/quantization/__init__.py +++ b/onnxruntime/python/tools/quantization/__init__.py @@ -10,6 +10,7 @@ from .quantize import DynamicQuantConfig # noqa: F401 from .quantize import QuantizationMode # noqa: F401 from .quantize import StaticQuantConfig # noqa: F401 +from .quantize import get_qdq_config # noqa: F401 from .quantize import quantize # noqa: F401 from .quantize import quantize_dynamic # noqa: F401 from .quantize import quantize_static # noqa: F401 diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py index b20af5137d206..6235db3234d49 100644 --- a/onnxruntime/python/tools/quantization/base_quantizer.py +++ b/onnxruntime/python/tools/quantization/base_quantizer.py @@ -19,9 +19,10 @@ from .calibrate import TensorData from .onnx_model import ONNXModel from .quant_utils import ( + DEQUANT_OP_NAME, ONNX_TYPE_TO_NP_TYPE, + QUANT_OP_NAME, TENSOR_NAME_QUANT_SUFFIX, - QuantType, find_by_name, model_has_infer_metadata, normalize_axis, @@ -40,18 +41,26 @@ def __init__(self, **data: Dict[str, Any]): for k, v in data.items(): if not isinstance(k, str): raise TypeError(f"Keys must be strings not {type(k)} for k={k!r}.") - if not isinstance(v, (int, str, np.ndarray)): + if k != "axis" and not isinstance(v, (int, str, np.ndarray)): raise TypeError(f"Values must be numpy arrays, int, float, str not {type(v)} for k={k!r}.") + if k == "axis" and not isinstance(v, int) and v is not None: + raise TypeError(f"Axis value must be an int or None, not {type(v)}.") if k == "scale" and v.dtype not in (np.float32, np.float16): raise ValueError(f"scale must a float32 or float16 numpy element but is {v.dtype} for k={k!r}") self.data[k] = v + def get(self, key, default_value=None): + return self.data.get(key, default_value) + def __iter__(self): yield from self.data def __getitem__(self, key): return self.data[key] + def __setitem__(self, key, value): + self.data[key] = value + def __len__(self): return len(self.data) @@ -88,9 +97,10 @@ def __init__( self.force_quantize_no_input_check = ( "ForceQuantizeNoInputCheck" in self.extra_options and self.extra_options["ForceQuantizeNoInputCheck"] ) - self.is_weight_symmetric = self.extra_options.get( - "WeightSymmetric", weight_qType in (QuantType.QInt8, QuantType.QInt16, QuantType.QFLOAT8E4M3FN) - ) + + # If user does not explicitly set "WeightSymmetric", then the weight's quantization type determines + # the symmetry (i.e., signed integer types will use symmetric quantization). See `def is_weight_symmetric()` + self._is_weight_symmetric: bool | None = self.extra_options.get("WeightSymmetric", None) self.is_activation_symmetric = self.extra_options.get("ActivationSymmetric", False) self.min_real_range = self.extra_options.get("MinimumRealRange") @@ -131,6 +141,16 @@ def __init__( self.tensor_quant_override_qtypes = self.tensor_quant_overrides.get_quant_types() + def is_weight_symmetric(self, weight_quant_type: onnx.TensorProto.DataType) -> bool: + if self._is_weight_symmetric is not None: + return self._is_weight_symmetric # Return value explicitly set by user. + return weight_quant_type in ( + onnx.TensorProto.INT4, + onnx.TensorProto.INT8, + onnx.TensorProto.INT16, + onnx.TensorProto.FLOAT8E4M3FN, + ) + def quantize_model(self): raise NotImplementedError @@ -160,6 +180,9 @@ def should_quantize_node(self, node): if node.op_type not in self.op_types_to_quantize: return False + if node.op_type in (DEQUANT_OP_NAME, QUANT_OP_NAME): + return False + if self.nodes_to_exclude is not None and node.name in self.nodes_to_exclude: return False @@ -230,9 +253,19 @@ def quantize_bias_static_impl(self, bias_name, input_scale, weight_scale, beta=1 # TODO: This formula should be explained including why the scale is not estimated for the bias as well. bias_scale = input_scale * weight_scale * beta - quantized_data = (np.asarray(bias_data) / bias_scale).round() - quantized_data = np.clip(quantized_data, np.iinfo(np.int32).min, np.iinfo(np.int32).max) - quantized_data = quantized_data.astype(np.int32) + # Quantize by dividing by bias_scale + quantized_data = np.asarray(bias_data, dtype=np.float64) / np.asarray(bias_scale, dtype=np.float64) + quantized_data = quantized_data.round() + + # Clip quantized data to the range of a int32 + int32_min = np.float64(np.iinfo(np.int32).min) + int32_max = np.float64(np.iinfo(np.int32).max) + if np.any(quantized_data < int32_min) or np.any(quantized_data > int32_max): + logging.warning( + f"Quantized bias `{bias_name}` exceeds the range of a int32. The bias scale is too small." + ) + + quantized_data = np.clip(quantized_data, int32_min, int32_max).astype(np.int32) # update bias initializer bias_np_data = np.asarray(quantized_data, dtype=np.int32).reshape(bias_initializer.dims) @@ -282,6 +315,7 @@ def quantize_initializer_impl(self, weight, qType, reduce_range=False, keep_floa If keep_float_weight is False, quantize the weight, or don't quantize the weight. :return: quantized weight name, zero point name, scale name """ + # TODO(adrianlizarraga): This function is now only used by onnx_quantizer.py, so move it there. q_weight_name = weight.name + TENSOR_NAME_QUANT_SUFFIX zp_name = weight.name + "_zero_point" scale_name = weight.name + "_scale" @@ -303,10 +337,11 @@ def quantize_initializer_impl(self, weight, qType, reduce_range=False, keep_floa assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}" else: - _, _, zero_point, scale, q_weight_data = quantize_data( + symmetric = self.is_weight_symmetric(qType) if qType == self.weight_qType else self.is_activation_symmetric + zero_point, scale, q_weight_data = quantize_data( weight_data.flatten(), qType, - quant_overrides.get("symmetric", self.is_weight_symmetric), + quant_overrides.get("symmetric", symmetric), reduce_range=quant_overrides.get("reduce_range", self.reduce_range and reduce_range), min_real_range=self.min_real_range, rmin_override=quant_overrides.get("rmin"), @@ -371,6 +406,7 @@ def quantize_weight_per_channel_impl( reduce_range=True, keep_float_weight=False, ): + # TODO(adrianlizarraga): This function is now only used by onnx_quantizer.py, so move it there. initializer = find_by_name(weight_name, self.model.initializer()) if initializer is None: raise ValueError("{} is not an initializer", weight_name) @@ -409,13 +445,7 @@ def quantize_weight_per_channel_impl( if "quant_type" in quant_overrides_for_channels[0]: weight_qType = quant_overrides_for_channels[0]["quant_type"].tensor_type # noqa: N806 - symmetric = quant_overrides_for_channels[0].get( - "symmetric", - ( - self.is_weight_symmetric - or weight_qType in (onnx.TensorProto.INT8, onnx.TensorProto.FLOAT8E4M3FN, onnx.TensorProto.INT4) - ), - ) + symmetric = quant_overrides_for_channels[0].get("symmetric", self.is_weight_symmetric(weight_qType)) reduce_range = quant_overrides_for_channels[0].get("reduce_range", self.reduce_range and reduce_range) zero_point_list = [] scale_list = [] @@ -444,7 +474,7 @@ def quantize_weight_per_channel_impl( ), f"Unexpected type {type(quantized_per_channel_data)}" else: - _, _, zero_point, scale, quantized_per_channel_data = quantize_data( + zero_point, scale, quantized_per_channel_data = quantize_data( per_channel_data.flatten(), weight_qType, symmetric, @@ -529,4 +559,6 @@ def adjust_tensor_ranges(self): self.tensors_range[node.input[0]] = td # Adjust Softmax to range from 0.0 to 1.0 elif node.op_type == "Softmax": + if not self.should_quantize_node(node): + continue self.tensors_range[node.output[0]] = TensorData(lowest=np.float32(0.0), highest=np.float32(1.0)) diff --git a/onnxruntime/python/tools/quantization/onnx_model.py b/onnxruntime/python/tools/quantization/onnx_model.py index 174bf5fd1509c..43105550139de 100644 --- a/onnxruntime/python/tools/quantization/onnx_model.py +++ b/onnxruntime/python/tools/quantization/onnx_model.py @@ -296,6 +296,26 @@ def get_largest_node_name_suffix(self, node_name_prefix): return suffix + def get_largest_initializer_name_suffix(self, initializer_name_prefix): + """ + Gets the largest initializer name integer suffix for all initializer names that begin + with `initializer_name_prefix`. This can be used to create unique initializer names. + + Example: for initializer names 'my_weight_0' and 'my_weight_3', this method returns 3 if + `initializer_name_prefix` is 'my_weight_'. + """ + suffix = -1 + + for initializer in self.model.graph.initializer: + if initializer.name.startswith(initializer_name_prefix): + try: + index = int(initializer.name[len(initializer_name_prefix) :]) + suffix = max(index, suffix) + except ValueError: + continue + + return suffix + def find_nodes_by_initializer(self, graph, initializer): """ Find all nodes with given initializer as an input. diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index e1e4a4f724fdc..424f9b7e180a3 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -942,7 +942,7 @@ def _dequantize_value(self, value_name): self.model.model.producer_name == "onnx-quantizer" and scale_init is not None ): # axis is not specified so scale_init must be a scalar. - assert onnx.numpy_helper.to_array(scale_init).size == 1 + assert scale_init is None or onnx.numpy_helper.to_array(scale_init).size == 1 dqlinear_name = value_name + "_DequantizeLinear" dqlinear_node = self.model.find_node_by_name(dqlinear_name, self.new_nodes, self.model.graph()) diff --git a/onnxruntime/python/tools/quantization/operators/pad.py b/onnxruntime/python/tools/quantization/operators/pad.py index 5f3c1231e62d6..b3e9ddb5e6278 100644 --- a/onnxruntime/python/tools/quantization/operators/pad.py +++ b/onnxruntime/python/tools/quantization/operators/pad.py @@ -1,3 +1,12 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from __future__ import annotations + +from typing import Any + +import numpy as np import onnx from ..quant_utils import ( @@ -8,6 +17,7 @@ quantize_nparray, ) from .base_operator import QuantOperatorBase +from .qdq_base_operator import QDQOperatorBase class QPad(QuantOperatorBase): @@ -98,3 +108,65 @@ def quantize(self): node.input[0] = quantized_input_value.q_name node.output[0] = quantized_output_value.q_name self.quantizer.new_nodes += [node] + + +class QDQPad(QDQOperatorBase): + def __init__(self, onnx_quantizer, onnx_node): + super().__init__(onnx_quantizer, onnx_node) + + def _get_pad_const_val(self, attrs_dict: dict[str, Any]) -> np.ndarray | None: + """ + Returns the Pad's constant padding value. Returns `None` if the padding value is + not constant (i.e., comes from a dynamic input). + """ + const_val = None + onnx_tensor_type = self.quantizer.model.get_tensor_type(self.node.input[0]) + if onnx_tensor_type is None: + return None + + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(onnx_tensor_type.elem_type) + if self.quantizer.opset_version < 11: + const_val = np.array(attrs_dict.get("value", 0), dtype=np_dtype) + elif len(self.node.input) >= 3 and self.node.input[2]: + const_val = self.quantizer.model.get_constant_value(self.node.input[2]) + else: + const_val = np.array(0, dtype=np_dtype) + + return const_val + + def _should_quantize_output_same_as_input(self) -> bool: + """ + Returns true if Pad's output should use the same quantization parameters as input[0] + """ + attrs_dict = {} + for attribute in self.node.attribute: + kv = attribute_to_kwarg(attribute) + attrs_dict.update(kv) + + pad_mode = attrs_dict.get("mode", b"constant") + if pad_mode in (b"reflect", b"edge", b"wrap"): + # These modes pad the output with a value that already exists in the input. + # So, we can quantize the output the same as the input. + return True + + # For 'constant' mode, if padding with 0, we can also quantize the output the same as the input + # because our quantization floating-point range always includes 0. + if pad_mode == b"constant": + pad_val = self._get_pad_const_val(attrs_dict) + if pad_val is not None and pad_val.dtype in (np.float32, np.float16): + return float(pad_val.item()) == 0 + + return False + + def quantize(self): + assert self.node.op_type == "Pad" + + for input_name in self.node.input: + if input_name: + self.quantizer.quantize_activation_tensor(input_name) + + if not self.disable_qdq_for_node_output: + if self._should_quantize_output_same_as_input(): + self.quantizer.quantize_output_same_as_input(self.node.output[0], self.node.input[0], self.node.name) + else: + self.quantizer.quantize_activation_tensor(self.node.output[0]) diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py index b71f332252850..5552a4451c542 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -20,6 +20,7 @@ from .calibrate import TensorData from .quant_utils import ( DEQUANT_OP_NAME, + ONNX_TYPE_TO_NP_TYPE, QUANT_OP_NAME, QuantizedValue, QuantizedValueType, @@ -30,12 +31,14 @@ add_quant_input_suffix, add_quant_output_suffix, add_quant_suffix, + compute_data_quant_params, compute_scale_zp, compute_scale_zp_float8, find_by_name, get_qmin_qmax_for_qType, ms_domain, normalize_axis, + quantize_onnx_initializer, tensor_proto_to_array, ) from .registry import CreateQDQQuantizer @@ -86,6 +89,18 @@ class QDQTensorQuantParams: converted: QuantizationParams | None # Converted type consumed by some (or all/none) consumer nodes. converted_recv_nodes: set[str] | None # The name of nodes that consume the converted type. + def get_for_consumer(self, consumer_node_name) -> QuantizationParams: + if self.converted is None: # Quantized value is not converted, return original + return self.original + + if self.converted_recv_nodes is None: # All consumers receive the converted value + return self.converted + + # Check if consumer node name is in the list of nodes that + # receive the converted quantization value. If not, return the original value generated + # by the tensor's producer. + return self.converted if (consumer_node_name in self.converted_recv_nodes) else self.original + # Holds scale and zero_point initializer TensorProtos. @dataclass @@ -153,8 +168,8 @@ def __init__( op_types_to_quantize, extra_options, ) - self.tensors_to_quantize = {} - self.bias_to_quantize = {} + self.tensors_to_quantize: dict[str, QDQTensorQuantInfo] = {} + self.bias_to_quantize: dict[str, QDQBiasQuantInfo] = {} self.nodes_to_remove = [] @@ -180,7 +195,11 @@ def __init__( # The default behavior is that multiple nodes can share a QDQ pair as their inputs. # In TRT, QDQ pair can`t be shared between nodes, so it will create dedicated QDQ pairs for each node. self.dedicated_qdq_pair = extra_options.get("DedicatedQDQPair", False) - self.tensor_to_its_receiving_nodes = {} + self.tensor_to_its_receiving_nodes: dict[str, list[onnx.NodeProto]] = {} + + # Maps a tensor to the DequantizeLinear node (in the original input model) that outputs the tensor. + # Populated for input models with some pre-quantized weights (typically via a different tool). + self.tensor_to_producing_dq: dict[str, onnx.NodeProto] = {} # Let user set channel axis for specific op type and it's effective only when per channel quantization is supported and per_channel is True. self.qdq_op_type_per_channel_support_to_axis = extra_options.get("QDQOpTypePerChannelSupportToAxis", {}) @@ -191,6 +210,9 @@ def __init__( # Used in the QDQRemovableActivation class. self.qdq_keep_removable_activations = extra_options.get("QDQKeepRemovableActivations", False) + # Let user disable adjustment of weight scales for bias inputs that are quantized to int32. + self.qdq_disable_weight_adjust_for_int32_bias = extra_options.get("QDQDisableWeightAdjustForInt32Bias", False) + # The ONNX spec did not support 16-bit Q/DQ ops before opset 21. # So, may have to override the Q/DQ op domain to 'com.microsoft' if the activation or weight types # are 16-bit or 4-bit integers. @@ -213,6 +235,7 @@ def __init__( self.qdq_op_domain = ms_domain self.quantization_params = self.calc_graph_quant_params() + self.initializer_quant_params: dict[str, QuantizationParams] = {} # Map of all original value names to quantized value names self.quantized_value_map = {} @@ -328,6 +351,18 @@ def quantize_weight_tensor_per_channel(self, tensor_name, axis): else: logging.warning(f"only support per-channel quantization on weight. Tensor: {tensor_name} is not quantized.") + def _dup_initializer(self, initializer: onnx.TensorProto) -> onnx.TensorProto: + """ + Duplicates an existing initializer and adds it to the model. Returns the new initializer. + """ + name_suffix: int = self.model.get_largest_initializer_name_suffix(initializer.name) + 1 + new_initializer_name = f"{initializer.name}{name_suffix}" + new_initializer = onnx.TensorProto() + new_initializer.CopyFrom(initializer) + new_initializer.name = new_initializer_name + self.model.add_initializer(new_initializer) + return new_initializer + def quantize_bias_tensor(self, node_name, bias_name, input_name, weight_name, beta=1.0): """ Adds a bias tensor to the list of bias tensors to quantize. Called by op quantizers that @@ -353,15 +388,160 @@ def quantize_bias_tensor(self, node_name, bias_name, input_name, weight_name, be self.quantize_weight_tensor(bias_name) return - weight = find_by_name(bias_name, self.model.initializer()) - if weight is not None: - if weight.data_type in (onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16): - if bias_name not in self.bias_to_quantize: - self.bias_to_quantize[bias_name] = QDQBiasQuantInfo(node_name, input_name, weight_name, beta) - else: - logging.warning(f"Bias {bias_name} has already been marked for quantization") - else: - logging.warning(f"Expected {bias_name} to be a weight") + bias_initializer = find_by_name(bias_name, self.model.initializer()) + if bias_initializer is None: + logging.warning(f"Expected bias '{bias_name}' to be an initializer") + return + + if bias_initializer.data_type not in (onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16): + logging.info(f"Expected bias '{bias_name}' to be an floating-point initializer") + return + + actual_bias_name = bias_name + if bias_name in self.bias_to_quantize: + # This bias input is consumed by two different nodes. We need to duplicate the bias so that + # each node has its own bias input. This is necessary because the bias's scale is computed + # from the node's other input scales. + new_bias_initializer = self._dup_initializer(bias_initializer) + actual_bias_name = new_bias_initializer.name + + # Replace this node's bias input + self.model.replace_input_of_nodes(bias_name, actual_bias_name, {node_name}) + logging.info(f"Created a copy of bias input '{bias_name}' called '{actual_bias_name}'") + + # Add this to our list of biases to quantize. + self.bias_to_quantize[actual_bias_name] = QDQBiasQuantInfo(node_name, input_name, weight_name, beta) + + def _adjust_weight_scale_for_int32_bias( + self, + input_scale: np.ndarray, + weight_scale: np.ndarray, + weight_name: str, + bias_tp: onnx.TensorProto, + is_per_channel: bool, + ) -> tuple[bool, np.ndarray | None]: + """ + Checks if the bias scale (input_scale * weight_scale) that we intend to use is too small. + A bias scale that is too small leads to quantized bias values that fall outside the range of a int32 and have to + be clipped, which decreases accuracy. If this function detects such a scenario, the weight_scale value will be + increased to prevent this from happening. + + Although the adjustment method and amount differs, the idea to adjust the weight's scale came from the following + reference: + https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/optimize/quantization_utils.cc#L252 + + :param input_scale: The input's scale. + :param weight_scale: The weight scale to potentially adjust. + :param weight_name: The weight initializer's name. Used for logging. + :param bias_tp: The bias ONNX initializer. + :param is_per_channel: True if the bias and weight are quantized per-channel. + :return: A tuple with a bool indicating if the weight's scale was adjusted and the new weight scale. + """ + if not weight_scale.size: + return False, None + + bias_float_data = tensor_proto_to_array(bias_tp) + + int32_info = np.iinfo(np.int32) + multiplicative_epsilon = 1.0001 + qrange = np.array(int32_info.max, dtype=np.float64) - np.array(int32_info.min + 1, dtype=np.float64) + weight_scale_dtype = weight_scale.dtype + updated_an_elem = False + + if not is_per_channel: + rmin = np.minimum(bias_float_data.min(), np.array(0, dtype=np.float64)) + rmax = np.maximum(bias_float_data.max(), np.array(0, dtype=np.float64)) + absmax = np.maximum(np.abs(rmin), np.abs(rmax)) + bias_smallest_valid_scale = multiplicative_epsilon * (2.0 * absmax) / qrange + + input_scale_fp64 = np.array(input_scale.item(), dtype=np.float64) + weight_scale_fp64 = np.array(weight_scale.item(), dtype=np.float64) + bias_candidate_scale = input_scale_fp64 * weight_scale_fp64 + + if (bias_candidate_scale < bias_smallest_valid_scale) and (bias_candidate_scale > 0.0): + # The candidate bias scale would be too small, so increase the weight_scale by the necessary ratio. + ratio = bias_smallest_valid_scale / bias_candidate_scale + logging.info( + f"Increasing scale for weight `{weight_name}` by the ratio {ratio} to " + f"ensure bias input `{bias_tp.name}` has a valid scale." + ) + new_scale = weight_scale_fp64 * ratio + weight_scale = new_scale.astype(weight_scale_dtype) + updated_an_elem = True + elif weight_scale.shape and len(weight_scale.shape) == 1: + # per-channel case + num_elems = weight_scale.shape[0] + + for i in range(num_elems): + bias_rmax = np.abs(bias_float_data[i]) + bias_smallest_valid_scale = multiplicative_epsilon * (2.0 * bias_rmax) / qrange + + input_scale_fp64 = np.array(input_scale.item(), dtype=np.float64) + weight_scale_fp64 = np.array(weight_scale[i].item(), dtype=np.float64) + bias_candidate_scale = input_scale_fp64 * weight_scale_fp64 + if (bias_candidate_scale < bias_smallest_valid_scale) and (bias_candidate_scale > 0.0): + # The candidate bias scale would be too small, so increase the weight_scale by the necessary ratio. + ratio = bias_smallest_valid_scale / bias_candidate_scale + logging.info( + f"Increased scale[{i}] for weight `{weight_name}` by ratio {ratio} " + f"to ensure bias input `{bias_tp.name}` has a valid scale." + ) + new_scale = weight_scale_fp64 * ratio + weight_scale[i] = new_scale.astype(weight_scale_dtype) + updated_an_elem = True + + return updated_an_elem, weight_scale + + def _adjust_weight_quant_params_for_bias_tensors(self): + """ + Iterates through all bias inputs that should be quantized to int32. If the intended + bias scale (equal to input_scale * weight_scale) is too small, this function will increase + the associated weight's scale to ensure the bias does not overflow the int32 range when quantized. + """ + + if self.qdq_disable_weight_adjust_for_int32_bias: + # User passed an extra_option to disable this adjustment. + return + + for bias_name, bias_info in self.bias_to_quantize.items(): + if ( + bias_info.input_name not in self.quantization_params + or bias_info.input_name not in self.tensors_to_quantize + or bias_info.weight_name not in self.initializer_quant_params + ): + continue + + # Get the associated input's scale. + input_qparams = self.quantization_params[bias_info.input_name].get_for_consumer(bias_info.node_name) + input_info = self.tensors_to_quantize[bias_info.input_name] + input_scale = np.asarray( + input_qparams["scale"], dtype=onnx.helper.tensor_dtype_to_np_dtype(input_info.data_type) + ) + + weight_quant_params = self.initializer_quant_params[bias_info.weight_name] + weight_quant_type = weight_quant_params["quant_type"] + if weight_quant_type not in (onnx.TensorProto.INT8, onnx.TensorProto.INT16): + continue + + weight_zero_point: np.ndarray = weight_quant_params["zero_point"] + if weight_zero_point.any(): + # Skip if zero_point(s) are not all zero (i.e., symmetric quant) + continue + + weight_scale: np.ndarray = weight_quant_params["scale"] + is_per_channel = weight_quant_params.get("axis", None) is not None + + # Get adjusted weight scales. + did_update_weight_scale, new_weight_scale = self._adjust_weight_scale_for_int32_bias( + input_scale, + weight_scale, + bias_info.weight_name, + find_by_name(bias_name, self.model.initializer()), + is_per_channel, + ) + + if did_update_weight_scale: + weight_quant_params["scale"] = new_weight_scale def remove_node(self, node): self.nodes_to_remove.append(node) @@ -379,7 +559,12 @@ def quantize_model(self): if tensor_name not in self.tensor_to_its_receiving_nodes: self.tensor_to_its_receiving_nodes[tensor_name] = [] self.tensor_to_its_receiving_nodes[tensor_name].append(node) + if node.op_type == DEQUANT_OP_NAME: + for tensor_name in node.output: + self.tensor_to_producing_dq[tensor_name] = node + self.initializer_quant_params = self._calc_initializer_quant_params() + self._adjust_weight_quant_params_for_bias_tensors() self._quantize_normal_tensors() self._quantize_sharing_param_tensors() if self.quantize_bias: @@ -475,38 +660,26 @@ def _create_qdq_nodes( ) self.model.add_nodes([qlinear_node, dequant_node]) - def _add_qdq_pair_for_initializer(self, weight_proto, tensor_type, axis=None): + def _add_qdq_nodes_for_initializer(self, weight_proto: onnx.TensorProto): + """ + Adds Q/DQ nodes for an initializer. If `self.add_qdq_pair_to_weight` is true, creates + the sequence (weight_f32 -> Q -> DQ -> ). Otherwise, this function quantizes the initializer + and adds the sequence (weight_quant -> DQ ->). + """ weight_name = weight_proto.name - if axis is not None: - if self.opset_version < 13: - raise ValueError("Per-Channel support with QDQ format requires onnx opset version 13 or above.") - - qtype = self.weight_qType if tensor_type is QDQQuantTensorType.WEIGHT else self.activation_qType - if qtype == onnx.onnx_pb.TensorProto.UINT8: - qtype = onnx_proto.TensorProto.INT8 - - q_weight_name, zp_name, scale_name = self.quantize_weight_per_channel( - weight_name, - # Quantization type is forced to be TensorProto.INT8. - # when the expected value would be (see below) - # self.weight_qType if tensor_type is QDQQuantTensorType.WEIGHT else self.activation_qType. - # QLinearConv expects to have a unique value for all channels. - # This code does not enforce that but it is necessarily the case when the - # quantization is symmetric (as for INT8). - qtype, - axis, - keep_float_weight=self.add_qdq_pair_to_weight, - ) - else: - q_weight_name, zp_name, scale_name = self.quantize_initializer( - weight_proto, - self.weight_qType if tensor_type is QDQQuantTensorType.WEIGHT else self.activation_qType, - keep_float_weight=self.add_qdq_pair_to_weight, - ) + if weight_name in self.quantized_value_map: + return + quant_params: QuantizationParams = self.initializer_quant_params[weight_name] + axis: int = quant_params.get("axis") + scale_zp_initializers = self._make_scale_zp_initializers(weight_name, quant_params) + q_weight_name: str | None = None weight_dequant_output = add_dequant_output_suffix(weight_name) self.model.replace_input_of_all_nodes(weight_name, weight_dequant_output) + if self.add_qdq_pair_to_weight: + # Don't actually quantize the weight. Instead, keep floating-point weight and create the node + # sequence (weight_f32 -> Q -> DQ -> weight_dequant) weight_quant_output = add_quant_output_suffix(weight_name) self._create_qdq_nodes( @@ -516,14 +689,26 @@ def _add_qdq_pair_for_initializer(self, weight_proto, tensor_type, axis=None): weight_quant_output, weight_dequant_output, add_dequant_suffix(weight_name), - scale_name, - zp_name, + scale_zp_initializers.scale.name, + scale_zp_initializers.zero_point.name, axis, ) else: + # Quantize the weight and create the node sequence: + # (weight_quantized -> DQ -> weight_dequant) + quant_weight = quantize_onnx_initializer( + weight_proto, + quant_params["quant_type"], + quant_params["zero_point"], + quant_params["scale"], + axis, + ) + self.model.add_initializer(quant_weight) + + q_weight_name = quant_weight.name dequant_node = onnx.helper.make_node( DEQUANT_OP_NAME, - [q_weight_name, scale_name, zp_name], + [quant_weight.name, scale_zp_initializers.scale.name, scale_zp_initializers.zero_point.name], [weight_dequant_output], add_dequant_suffix(weight_name), axis=axis, @@ -531,6 +716,17 @@ def _add_qdq_pair_for_initializer(self, weight_proto, tensor_type, axis=None): ) self.model.add_node(dequant_node) + # Log entry for this quantized weight + quantized_value = QuantizedValue( + weight_name, + q_weight_name, + scale_zp_initializers.scale.name, + scale_zp_initializers.zero_point.name, + QuantizedValueType.Initializer, + axis=axis, + ) + self.quantized_value_map[weight_name] = QDQTensorQuantizedValue(quantized_value, None, None) + def _add_qdq_pair_for_activation(self, tensor_name, scale_name, zp_name, data_type=None): if ( self.dedicated_qdq_pair @@ -767,8 +963,16 @@ def _quantize_normal_tensors(self): # Quantize the input initializer = find_by_name(tensor_name, self.model.initializer()) if initializer: - self._add_qdq_pair_for_initializer(initializer, tensor_info.tensor_type, tensor_info.axis) + self._add_qdq_nodes_for_initializer(initializer) else: + # Check if this tensor is already a dequantized value. If so, skip it. + # This happens if the original input model already has some pre-quantized weights + # generated by a different tool. + # Ex: (quantized_weight -> DequantizeLinear -> this_tensor) + if tensor_name in self.tensor_to_producing_dq: + del self.tensors_to_quantize[tensor_name] + continue + tensor_qparam_initializers = self._make_tensor_scale_zp_initializers(tensor_name) if not tensor_qparam_initializers: raise ValueError( @@ -820,6 +1024,12 @@ def _quantize_sharing_param_tensors(self): if self.is_input_a_initializer(tensor_name): raise ValueError("Quantization parameter shared mode is not supported for weight yet") + if tensor_name in self.tensor_to_producing_dq: + raise ValueError( + f"Quantization parameter sharing is invalid for tensor {tensor_name} " + "because it has already been quantized" + ) + # Need to check if this tensor's quant_type is converted for some consumers. # If so, create new scale/zp initializers for these consumers. converted_qparam_inits = None @@ -909,45 +1119,6 @@ def _quantize_bias_tensors(self): def is_tensor_quantized(self, tensor_name: str): return tensor_name in self.tensors_to_quantize or tensor_name in self.bias_to_quantize - def quantize_initializer( - self, - weight: onnx.TensorProto, - qType: onnx.TensorProto.DataType, - reduce_range: bool = False, - keep_float_weight: bool = False, - ) -> tuple[str, str, str]: - """ - :param weight: TensorProto initializer - :param qType: type to quantize to - :param keep_float_weight: Whether to quantize the weight. In some cases, we only want to qunatize scale and zero point. - If keep_float_weight is False, quantize the weight, or don't quantize the weight. - :return: quantized weight name, zero point name, scale name - """ - # Find if this input is already quantized - if weight.name in self.quantized_value_map: - quantized_value = self.quantized_value_map[weight.name].original - return ( - quantized_value.q_name, - quantized_value.zp_name, - quantized_value.scale_name, - ) - - q_weight_name, zp_name, scale_name = self.quantize_initializer_impl( - weight, qType, reduce_range, keep_float_weight - ) - - # Log entry for this quantized weight - quantized_value = QuantizedValue( - weight.name, - q_weight_name, - scale_name, - zp_name, - QuantizedValueType.Initializer, - None, - ) - self.quantized_value_map[weight.name] = QDQTensorQuantizedValue(quantized_value, None, None) - return q_weight_name, zp_name, scale_name - def is_tensor_per_channel( self, tensor_name: str, @@ -997,37 +1168,29 @@ def is_tensor_per_channel( return True, axis - def quantize_weight_per_channel( - self, - weight_name: str, - weight_qType: onnx.TensorProto.DataType, - channel_axis: int, - reduce_range: bool = True, - keep_float_weight: bool = False, - ) -> tuple[str, str, str]: - # Find if this input is already quantized - if weight_name in self.quantized_value_map: - quantized_value = self.quantized_value_map[weight_name].original - return ( - quantized_value.q_name, - quantized_value.zp_name, - quantized_value.scale_name, - ) + def _get_tensor_quantization_scale(self, tensor_name: str, consumer_node_name: str) -> np.ndarray | None: + """ + Returns the quantization scale of a tensor that is consumed by the given node. + :parameter tensor_name: The name of the tensor. + :parameter consumer_node_name: The name of the node that consumes the tensor as input. Necessary in case + the quantization type of the tensor was converted. + Refer: QDQQuantizer::_add_qdq_ops_for_converted_activation. + :returns: The quantization scale or None. + """ + initializers = self.model.initializer() + scale_initializer: onnx.TensorProto | None = None - q_weight_name, zp_name, scale_name = self.quantize_weight_per_channel_impl( - weight_name, weight_qType, channel_axis, reduce_range, keep_float_weight - ) - quantized_value = QuantizedValue( - weight_name, - q_weight_name, - scale_name, - zp_name, - QuantizedValueType.Initializer, - None, - ) - self.quantized_value_map[weight_name] = QDQTensorQuantizedValue(quantized_value, None, None) + if tensor_name in self.quantized_value_map: + # Tensor was quantized by this tool, so get scale from initializer created by this tool run. + scale_name = self.quantized_value_map[tensor_name].get_for_consumer(consumer_node_name).scale_name + scale_initializer = find_by_name(scale_name, initializers) + else: + # Tensor was already quantized in original model, so get scale from DQ node that outputs the tensor. + dq_node = self.tensor_to_producing_dq.get(tensor_name, None) + if dq_node: + scale_initializer = find_by_name(dq_node.input[1], initializers) - return q_weight_name, zp_name, scale_name + return tensor_proto_to_array(scale_initializer) if scale_initializer is not None else None def quantize_bias_static(self, bias_name: str, bias_info: QDQBiasQuantInfo) -> str: """ @@ -1038,17 +1201,21 @@ def quantize_bias_static(self, bias_name: str, bias_info: QDQBiasQuantInfo) -> s if bias_name in self.quantized_value_map: return self.quantized_value_map[bias_name].original.q_name - # get scale for weight - weight_scale_name = self.quantized_value_map[bias_info.weight_name].original.scale_name - weight_initializer = find_by_name(weight_scale_name, self.model.initializer()) - weight_scale = tensor_proto_to_array(weight_initializer) + # get scale for weight. + weight_scale = self._get_tensor_quantization_scale(bias_info.weight_name, bias_info.node_name) + if weight_scale is None: + raise ValueError( + f"Unable to get valid quantization scale for weight input '{bias_info.weight_name}' " + f"when quantizing bias '{bias_name}' to int32." + ) - # get scale for input - input_scale_name = ( - self.quantized_value_map[bias_info.input_name].get_for_consumer(bias_info.node_name).scale_name - ) - inputscale_initializer = find_by_name(input_scale_name, self.model.initializer()) - input_scale = tensor_proto_to_array(inputscale_initializer) + # get scale for input. + input_scale = self._get_tensor_quantization_scale(bias_info.input_name, bias_info.node_name) + if input_scale is None: + raise ValueError( + f"Unable to get valid quantization scale for input '{bias_info.input_name}' " + f"when quantizing bias '{bias_name}' to int32." + ) ( quantized_bias_name, @@ -1074,7 +1241,7 @@ def quantize_bias_static(self, bias_name: str, bias_info: QDQBiasQuantInfo) -> s return quantized_bias_name def _make_scale_zp_initializers( - self, param_name: str, params: QuantizationParams, init_name_suffix: str = "" + self, param_name: str, quant_params: QuantizationParams, init_name_suffix: str = "" ) -> QDQScaleZpInitializers: """ Creates and returns scale and zero-point initializers for the given quantization params. The initializers are @@ -1082,31 +1249,31 @@ def _make_scale_zp_initializers( - {param_name}_zero_point{init_name_suffix} - {param_name}_scale{init_name_suffix} """ - zero_point_values = np.array([params["zero_point"]]) - if not hasattr(params["scale"], "dtype") or params["scale"].dtype not in (np.float32, np.float16): - raise ValueError(f"Unexpected type {type(params['scale'])} and param_name={param_name!r}") - scale_values = np.array([params["scale"]]) - assert scale_values.dtype != np.float64 - zero_point_type = params.data.get("quant_type", self.activation_qType) - - zero_point_shape = [] + zero_point = quant_params["zero_point"] + scale = quant_params["scale"] + zero_point_type = quant_params["quant_type"] + axis: int | None = quant_params.get("axis") + assert (axis is not None and len(scale.shape) == 1) or ( + axis is None and len(scale.shape) == 0 + ), "Wrong scale/zp shapes" + assert len(scale.shape) == len(zero_point.shape), "Scale and zero-point must have the same rank" + zero_point_name = param_name + "_zero_point" + init_name_suffix - scale_shape = [] scale_name = param_name + "_scale" + init_name_suffix # Add initializers to model init_zp = onnx.helper.make_tensor( - zero_point_name, zero_point_type, zero_point_shape, zero_point_values.ravel().tolist() + zero_point_name, zero_point_type, zero_point.shape, zero_point.ravel().tolist() ) self.model.add_initializer(init_zp) - if scale_values.dtype == np.float32: + if scale.dtype == np.float32: scale_type = onnx_proto.TensorProto.FLOAT - elif scale_values.dtype == np.float16: + elif scale.dtype == np.float16: scale_type = onnx_proto.TensorProto.FLOAT16 else: - raise ValueError(f"Unexpected dtype={scale_values.dtype} for param_name={param_name!r}") - init_scale = onnx.helper.make_tensor(scale_name, scale_type, scale_shape, scale_values.reshape((-1,)).tolist()) + raise ValueError(f"Unexpected dtype={scale.dtype} for param_name={param_name!r}") + init_scale = onnx.helper.make_tensor(scale_name, scale_type, scale.shape, scale.ravel().tolist()) self.model.add_initializer(init_scale) return QDQScaleZpInitializers(init_scale, init_zp) @@ -1155,7 +1322,7 @@ def calc_quant_params(self, tensor_data: TensorData, quant_overrides: dict[str, qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range=reduce_range, symmetric=symmetric) zero, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, self.min_real_range) - return QuantizationParams(zero_point=zero, scale=scale, quant_type=quant_type) + return QuantizationParams(zero_point=zero.squeeze(), scale=scale.squeeze(), quant_type=quant_type) def calc_graph_quant_params(self) -> dict[str, QDQTensorQuantParams]: """ @@ -1185,3 +1352,127 @@ def calc_graph_quant_params(self) -> dict[str, QDQTensorQuantParams]: quantization_params[tensor_name] = QDQTensorQuantParams(original, converted, converted_recv_nodes) return quantization_params + + def _calc_initializer_quant_params(self) -> dict[str, QuantizationParams]: + """ + Returns quantization parameters (scale/zero_point/quant_type) for all initializers. + """ + + quantization_params: dict[str, QuantizationParams] = {} + for tensor_name, tensor_info in self.tensors_to_quantize.items(): + initializer = find_by_name(tensor_name, self.model.initializer()) + if not initializer: + continue + + initializer_data = tensor_proto_to_array(initializer) + initializer_rank = len(initializer_data.shape) + + # initializers for elementwise ops use the quant_type for activations. + is_weight = tensor_info.tensor_type is QDQQuantTensorType.WEIGHT + quant_type = self.weight_qType if is_weight else self.activation_qType + + # Try to get scale/zp directly from user's overrides and avoid computation. + if self.tensor_quant_overrides.overrides_scale_zp(tensor_name): + overrides = self.tensor_quant_overrides[tensor_name] + if "quant_type" in overrides[0]: + quant_type = overrides[0]["quant_type"].tensor_type + + zp_dtype = ONNX_TYPE_TO_NP_TYPE[quant_type] + is_per_channel = "axis" in overrides[0] + if not is_per_channel: + quantization_params[tensor_name] = QuantizationParams( + zero_point=np.array(overrides[0]["zero_point"], dtype=zp_dtype), + scale=np.array(overrides[0]["scale"], initializer_data.dtype), + quant_type=quant_type, + ) + else: + zero_points_list = [] + scales_list = [] + for chan_overrides in overrides: + zero_points_list.append(np.array(chan_overrides["zero_point"], zp_dtype)) + scales_list.append(np.array(chan_overrides["scale"], dtype=initializer_data.dtype)) + + channel_axis = overrides[0]["axis"] + is_axis_valid, norm_channel_axis = normalize_axis(channel_axis, initializer_rank) + if not is_axis_valid: + raise ValueError( + f"Weight {initializer.name} has a per-channel axis with value {channel_axis} that is " + f"out-of-bounds for rank {initializer_rank}" + ) + + quantization_params[tensor_name] = QuantizationParams( + zero_point=np.array(zero_points_list), + scale=np.array(scales_list), + quant_type=quant_type, + axis=norm_channel_axis, + ) + + continue + + # Compute scale/zp normally. User's overrides may still override parameters + # used to compute the scale/zp (e.g., rmin, rmax, symmetric, etc.) + overrides = self.tensor_quant_overrides.get(tensor_name, [{}]) + if "quant_type" in overrides[0]: + quant_type = overrides[0]["quant_type"].tensor_type + + channel_axis = overrides[0].get("axis", tensor_info.axis) + is_per_channel = channel_axis is not None + + # Note: always quantize per-channel initializers as symmetric because QLinear* ops require the + # same zero-point in every channel, which is necessarily the case for symmetric quantization. + is_symmetric_default = is_per_channel or ( + self.is_weight_symmetric(quant_type) if is_weight else self.is_activation_symmetric + ) + is_symmetric = overrides[0].get("symmetric", is_symmetric_default) + reduce_range = overrides[0].get("reduce_range", self.reduce_range) + zero_point: np.ndarray | None = None + scale: np.ndarray | None = None + + if not is_per_channel: + zero_point, scale = compute_data_quant_params( + initializer_data.flatten(), + quant_type, + is_symmetric, + reduce_range=reduce_range, + min_real_range=self.min_real_range, + rmin_override=overrides[0].get("rmin"), + rmax_override=overrides[0].get("rmax"), + ) + else: + is_axis_valid, norm_channel_axis = normalize_axis(channel_axis, initializer_rank) + if not is_axis_valid: + raise ValueError( + f"Weight {initializer.name} has a per-channel axis with value {channel_axis} that is " + f"out-of-bounds for rank {initializer_rank}" + ) + + channel_axis = norm_channel_axis + channel_count = initializer_data.shape[channel_axis] + zero_points_list = [] + scales_list = [] + for i in range(channel_count): + per_channel_data = initializer_data.take(i, channel_axis) + channel_overrides = overrides[i] if overrides and i < len(overrides) else {} + channel_zero_point, channel_scale = compute_data_quant_params( + per_channel_data.ravel(), + quant_type, + is_symmetric, + reduce_range=reduce_range, + min_real_range=self.min_real_range, + rmin_override=channel_overrides.get("rmin"), + rmax_override=channel_overrides.get("rmax"), + ) + zero_points_list.append(channel_zero_point) + scales_list.append(channel_scale) + + zero_point = np.asarray(zero_points_list) + scale = np.asarray(scales_list) + + quantization_params[tensor_name] = QuantizationParams( + zero_point=zero_point, + scale=scale, + quant_type=quant_type, + axis=channel_axis, + ) + + return quantization_params diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py index 9228ad33130f2..2bf675745d093 100644 --- a/onnxruntime/python/tools/quantization/quant_utils.py +++ b/onnxruntime/python/tools/quantization/quant_utils.py @@ -33,6 +33,12 @@ int4 = None uint4 = None +try: + from onnx.reference.op_run import to_array_extended +except ImportError: + # old version of onnx. + to_array_extended = None + __producer__ = "onnx.quantize" __version__ = "0.1.0" @@ -43,6 +49,7 @@ DEQUANT_OP_NAME = "DequantizeLinear" DEQUANT_OUTPUT_SUFFIX = "_DequantizeLinear_Output" TENSOR_NAME_QUANT_SUFFIX = "_quantized" +MODEL_SIZE_THRESHOLD = 2147483648 # Quant model should use external data if >= 2GB FLOAT8_DISTRIBUTIONS = {} @@ -156,7 +163,9 @@ def from_string(format): } ONNX_INT_TYPE_SYMMETRIC_RANGE = { + onnx_proto.TensorProto.UINT8: (numpy.array(0, dtype=numpy.uint8), numpy.array(254, dtype=numpy.uint8)), onnx_proto.TensorProto.INT8: (numpy.array(-127, dtype=numpy.int8), numpy.array(127, dtype=numpy.int8)), + onnx_proto.TensorProto.UINT16: (numpy.array(0, dtype=numpy.uint16), numpy.array(65534, dtype=numpy.uint16)), onnx_proto.TensorProto.INT16: (numpy.array(-32767, dtype=numpy.int16), numpy.array(32767, dtype=numpy.int16)), } @@ -229,7 +238,7 @@ def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None): # which matches the python reference ONNX implementation of QuantizeLinear. # This data can be packed into 4-bit elements by using pack_bytes_to_4bit(). dtype = ONNX_TYPE_TO_NP_TYPE[qType] - (qmin, qmax) = get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=True) + qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=False) cliplow = max(qmin, low) if low is not None else qmin cliphigh = min(qmax, high) if high is not None else qmax @@ -269,7 +278,7 @@ def compute_scale_zp(rmin, rmax, qmin, qmax, symmetric=False, min_real_range=Non # Ensure a minimum float-point range if specified. if min_real_range is not None: - rmax = max(rmax, rmin + min_real_range) + rmax = max(rmax, rmin + numpy.asarray(min_real_range, dtype=rmin.dtype)) if symmetric: absmax = numpy.maximum(numpy.abs(rmin), numpy.abs(rmax)) @@ -338,13 +347,75 @@ def compute_scale_zp_float8(element_type, std): return [zero, scale] +def compute_data_quant_params( + data: numpy.ndarray, + quant_type: onnx.TensorProto.DataType, + symmetric: bool, + reduce_range: bool = False, + min_real_range: float | None = None, + rmin_override: float | None = None, + rmax_override: float | None = None, +) -> tuple[numpy.ndarray, numpy.ndarray]: + """ + Returns the zero_point and scale for the given data. + + :param data: The data for which to compute quantization parameters. + :param quant_type: The quantization data type. + :param symmetric: whether symmetric quantization is used or not. + :parameter reduce_range: True if the quantization range should be reduced. Defaults to False. + :parameter min_real_range: Minimum floating-point range (i.e., rmax - rmin) to enforce. Defaults to None. + :parameter rmin_override: The value of rmin to use if not None. Otherwise, uses min(data). + :parameter rmax_override: The value of rmax to use if not None. Otherwise, uses max(data). + :return: zero point and scale + """ + if not isinstance(data, numpy.ndarray): + raise TypeError(f"Weight must be given as an array not {type(data)}.") + if rmin_override is not None: + rmin = rmin_override + else: + rmin = data.min() if len(data) else 0.0 + + if rmax_override is not None: + rmax = rmax_override + else: + rmax = data.max() if len(data) else 0.0 + + rmin = numpy.array(rmin, dtype=data.dtype) + rmax = numpy.array(rmax, dtype=data.dtype) + scale = numpy.array(1.0, dtype=data.dtype) + + if quant_type == TensorProto.FLOAT8E4M3FN: + if reduce_range: + raise RuntimeError("Unsupported option reduce_range=True for float 8.") + std = numpy.std(data) + zero_point, scale = compute_scale_zp_float8(quant_type, std) + return _check_type(zero_point, scale, zero_point_index=0) + + if quant_type in ( + TensorProto.INT8, + TensorProto.UINT8, + TensorProto.INT16, + TensorProto.UINT16, + TensorProto.INT4, + TensorProto.UINT4, + ): + qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range, symmetric=symmetric) + if len(data): + zero_point, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, min_real_range) + else: + zero_point = numpy.array(0, dtype=qmin.dtype) + return _check_type(zero_point, scale, zero_point_index=0) + + raise ValueError(f"Unexpected value for quant_type={quant_type}.") + + def quantize_data( data, qType, symmetric, reduce_range=False, min_real_range=None, rmin_override=None, rmax_override=None -): +) -> tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]: """ :param data: data to quantize - :param qType: data type to quantize to. Supported types UINT8 and INT8 - :param symmetric: whether symmetric quantization is used or not. This is applied to INT8. + :param qType: data type to quantize to. + :param symmetric: whether symmetric quantization is used or not. :parameter reduce_range: True if the quantization range should be reduced. Defaults to False. :parameter min_real_range: Minimum floating-point range (i.e., rmax - rmin) to enforce. Defaults to None. :parameter rmin_override: The value of rmin to use if not None. Otherwise, uses min(data). @@ -366,28 +437,16 @@ def quantize_data( - *S*: scale - *z*: zero point """ - if not isinstance(data, numpy.ndarray): - raise TypeError(f"Weight must be given as an array not {type(data)}.") - if rmin_override is not None: - rmin = rmin_override - else: - rmin = data.min() if len(data) else 0.0 - - if rmax_override is not None: - rmax = rmax_override - else: - rmax = data.max() if len(data) else 0.0 - - rmin = numpy.array(rmin, dtype=data.dtype) - rmax = numpy.array(rmax, dtype=data.dtype) - zero_point = 0 - scale = numpy.array(1.0, dtype=data.dtype) - + zero_point, scale = compute_data_quant_params( + data, + qType, + symmetric, + reduce_range, + min_real_range, + rmin_override, + rmax_override, + ) if qType == TensorProto.FLOAT8E4M3FN: - if reduce_range: - raise RuntimeError("Unsupported option reduce_range=True for float 8.") - std = numpy.std(data) - zero_point, scale = compute_scale_zp_float8(qType, std) quantized_data = quantize_nparray(qType, data, scale, zero_point) if any((quantized_data.astype(numpy.uint8).ravel() & 127) == 127): np_data = numpy.asarray(data) @@ -395,7 +454,7 @@ def quantize_data( f"One of the quantized value is NaN data in [{np_data.min()}, {np_data.max()}], " f"quantized_data in [{quantized_data.min()}, {quantized_data.max()}]." ) - return _check_type(rmin, rmax, zero_point, scale, quantized_data, zero_point_index=2) + return zero_point, scale, quantized_data if qType in ( TensorProto.INT8, @@ -405,15 +464,91 @@ def quantize_data( TensorProto.INT4, TensorProto.UINT4, ): - if len(data): - qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range, symmetric=symmetric) - zero_point, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, min_real_range) quantized_data = quantize_nparray(qType, data, scale, zero_point) - return _check_type(rmin, rmax, zero_point, scale, quantized_data, zero_point_index=2) + return zero_point, scale, quantized_data raise ValueError(f"Unexpected value for qType={qType}.") +def quantize_onnx_initializer( + weight: onnx.TensorProto, + quant_type: onnx.TensorProto.DataType, + zero_point: numpy.ndarray, + scale: numpy.ndarray, + axis: int | None = None, + quant_weight_name: str | None = None, +) -> onnx.TensorProto: + """ + Returns a quantized version of the given ONNX initializer. + + :param weight: The ONNX initializer to quantize. + :param quant_type: The final quantized data type. + :param zero_point: The zero-point value to use for quantization. + :param scale: The scale value to use for quantization. + :param axis: The quantization axis if quantizing per-channel. Defaults to None. + :param quant_weight_name: The name of the quantized initializer. + If not specified, the quantized name is generated. + :return: The quantized ONNX initializer. + """ + weight_data = tensor_proto_to_array(weight) + q_weight_data: numpy.ndarray | None = None + + if axis is None: # Per-tensor quantization + q_weight_data = quantize_nparray(quant_type, weight_data.ravel(), scale, zero_point) + else: # Per-channel quantization + channel_count = weight_data.shape[axis] + channel_dims = list(weight_data.shape) # deep copy + channel_dims[axis] = 1 # only one per channel for reshape + quantized_channel_data_list = [] + + for i in range(channel_count): + channel_data = weight_data.take(i, axis) + channel_scale = scale[i] + channel_zero_point = zero_point[i] + quantized_channel_data = quantize_nparray( + quant_type, channel_data.ravel(), channel_scale, channel_zero_point + ) + quantized_channel_data_list.append(numpy.asarray(quantized_channel_data).reshape(channel_dims)) + + q_weight_data = numpy.concatenate(quantized_channel_data_list, axis) + + q_weight_name = quant_weight_name if quant_weight_name else f"{weight.name}{TENSOR_NAME_QUANT_SUFFIX}" + + if quant_type == onnx.TensorProto.FLOAT8E4M3FN: + q_weight_initializer = onnx.TensorProto() + q_weight_initializer.data_type = quant_type + q_weight_initializer.dims.extend(weight.dims) + q_weight_initializer.name = q_weight_name + # Do not remove .flatten().copy() numpy is not clear about data persistence. + q_weight_initializer.raw_data = q_weight_data.flatten().copy().tobytes() + if to_array_extended is not None: + # This test should not be needed but it helped catch some issues + # with data persistence and tobytes. + check = to_array_extended(q_weight_initializer) + if check.shape != weight_data.shape or check.tobytes() != q_weight_data.tobytes(): + raise RuntimeError( + f"The initializer of shape {weight_data.shape} could not be created, expecting " + f"{q_weight_data.tobytes()[:10]}, got {check.tobytes()[:10]} and shape={weight.shape}" + f"\nraw={str(q_weight_initializer)[:200]}." + ) + elif quant_type in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4): + if q_weight_data.dtype not in (numpy.int8, numpy.uint8): + raise RuntimeError(f"Quantized weights for {q_weight_name} must be 8-bit before packing as 4-bit values.") + + # We do not use onnx.helper.pack_float32_to_4bit() due to performance. + # This can be the difference between a large model taking 30 minutes to quantize vs 5 minutes. + packed_data = bytes(pack_bytes_to_4bit(q_weight_data.tobytes())) + + # We only use onnx.helper.make_tensor with raw data due to bug: https://github.com/onnx/onnx/pull/6161 + q_weight_initializer = onnx.helper.make_tensor(q_weight_name, quant_type, weight.dims, packed_data, raw=True) + else: + quant_np_dtype = onnx.helper.tensor_dtype_to_np_dtype(quant_type) + q_weight_data = numpy.asarray(q_weight_data, dtype=quant_np_dtype).reshape(weight.dims) + q_weight_initializer = onnx.numpy_helper.from_array(q_weight_data, q_weight_name) + + return q_weight_initializer + + def get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=False): # noqa: N802 """ Return qmin and qmax, the minimum and maximum value representable by the given qType diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py index 745344dc01fcb..4ffd8b9872982 100644 --- a/onnxruntime/python/tools/quantization/quantize.py +++ b/onnxruntime/python/tools/quantization/quantize.py @@ -3,10 +3,13 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +from __future__ import annotations + +import copy import logging import tempfile from pathlib import Path -from typing import Union +from typing import Any, Callable import onnx @@ -14,6 +17,7 @@ from .onnx_quantizer import ONNXQuantizer from .qdq_quantizer import QDQQuantizer from .quant_utils import ( + MODEL_SIZE_THRESHOLD, QuantFormat, QuantizationMode, QuantType, @@ -22,6 +26,7 @@ save_and_reload_model_with_shape_infer, ) from .registry import IntegerOpsRegistry, QDQRegistry, QLinearOpsRegistry +from .tensor_quant_overrides import TensorQuantOverridesHelper class QuantConfig: @@ -192,6 +197,9 @@ def __init__( removed if activations are asymmetrically quantized. Keeping these activations is necessary if optimizations or EP transformations will later remove QuantizeLinear/DequantizeLinear operators from the model. + QDQDisableWeightAdjustForInt32Bias = True/False: + Default is False. If true, QDQ quantizer will not adjust the weight's scale when the bias + has a scale (input_scale * weight_scale) that is too small. execution_provider : A enum indicates the Execution Provider such as: CPU, TRT, NNAPI, SNE, etc. Raises: ValueError: Raise ValueError if execution provider is unknown @@ -213,6 +221,167 @@ def __init__( self.extra_options = extra_options or {} +def get_qdq_config( + model_input: str | Path | onnx.ModelProto, + calibration_data_reader: CalibrationDataReader, + calibrate_method=CalibrationMethod.MinMax, + calibrate_args: dict[str, Any] | None = None, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QInt8, + activation_symmetric: bool = False, + weight_symmetric: bool | None = None, + per_channel: bool = False, + reduce_range: bool = False, + keep_removable_activations: bool = False, + min_real_range: float | None = None, + tensor_quant_overrides: dict[str, list[dict[str, Any]]] | None = None, + nodes_to_exclude: list[str] | Callable[[onnx.ModelProto, onnx.NodeProto], bool] | None = None, + extra_options: dict | None = None, +) -> StaticQuantConfig: + """ + Returns a configuration suitable that quantizes the entire model to integer precision. + + Params: + model_input: Path to the input model file or ModelProto. + calibration_data_reader: Calibration data reader. + calibrate_methode: The calibration method. Defaults to MinMax. + activation_type: The default activation quantization type. Defaults to QUInt8. + weight_type: The default weight quantization type. Defaults to QInt8. + activation_symmetric: True if activations should be quantized symmetrically (i.e, rmax == -rmin) by default. + Defaults to false. For int8 and int16, this results in zero-point values of 0. For uint8 and uint16, + the zero-point values are 127 and 32,767, respectively. + weight_symmetric: True if weights should be quantized symmetrically (i.e., rmax == -rmin) by default. + Defaults to None. If set to None, weight_symmetric is assumed true if a weight's quant type is a signed int. + per_channel: Global option that determines if a fixed set of operator types should be quantized per-channel. + Defaults to false. Alternatively, use the tensor-level `tensor_quant_overrides` to select individual operators + and their quantization axes. + reduce_range: quantize weights with 1 less bit of precision (e.g., 7 bits for QInt8). Defaults to false. + May improve the accuracy for some models running on non-VNNI machine, especially for per-channel mode. + keep_removable_activations: Defaults to false. If true, "removable" activations (e.g., Clip or Relu) will not + be removed, and will be explicitly represented in the QDQ model. If false, these activations + are automatically removed if activations are asymmetrically quantized. Keeping these activations + is necessary if optimizations or EP transformations will later remove + QuantizeLinear/DequantizeLinear operators from the model. + min_real_range: Default is None. If set to a floating-point value, the calculation of the quantization parameters + (i.e., scale and zero point) will enforce a minimum range between rmin and rmax. If (rmax - rmin) + is less than the specified minimum range, rmax will be set to rmin + min_real_range. + tensor_quant_overrides: tensor-level quantization overrides. Defaults to None. + The key is a tensor name and the value is a list of dictionaries. For per-tensor quantization, the list + contains a single dictionary. For per-channel quantization, the list contains either a dictionary for + each channel in the tensor or a single dictionary that is assumed to apply to all channels. An 'axis' + key must be present in the first dictionary for per-channel quantization. + + Each dictionary contains optional overrides with the following keys and values. + 'quant_type' = QuantType : The tensor's quantization data type. + 'axis' = Int : The per-channel axis. Must be present for per-channel weights. + 'scale' = Float : The scale value to use. Must also specify `zero_point` if set. + 'zero_point' = Int : The zero-point value to use. Must also specify `scale` is set. + 'symmetric' = Bool : If the tensor should use symmetric quantization. Invalid if also + set `scale` or `zero_point`. + 'reduce_range' = Bool : If the quantization range should be reduced. Invalid if also + set `scale` or `zero_point`. Only valid for initializers. + 'rmax' = Float : Override the maximum real tensor value in calibration data. + Invalid if also set `scale` or `zero_point`. + 'rmin' = Float : Override the minimum real tensor value in calibration data. + Invalid if also set `scale` or `zero_point`. + 'convert' = Dict : A nested dictionary with the same keys for an activation + tensor that should be converted to another quantization type. + 'convert["recv_nodes"] = Set : Set of node names that consume the converted activation, + other nodes get the original type. If not specified, + assume all consumer nodes get the converted type. + nodes_to_exclude: List of nodes names to exclude from quantization. Alternatively, can provide a function that + accepts an onnx.ModelProto and onnx.NodeProto as arguments and returns true if the give onnx.NodeProto + should be excluded from quantization. + extra_options: Additional options specified as string key/value pairs. Refer to the documentation for + `quantize_static` for valid keys and values. + + Returns: + A StaticQuantConfig object + """ + q16_types = {QuantType.QInt16, QuantType.QUInt16} + q4_types = {QuantType.QInt4, QuantType.QUInt4} + op_types_to_exclude = {"Cast", "DequantizeLinear", "QuantizeLinear"} + + model = ( + model_input + if isinstance(model_input, onnx.ModelProto) + else onnx.load_model(model_input, load_external_data=False) + ) + + op_types = set() + model_has_external_data = False + overrides_helper = TensorQuantOverridesHelper( + copy.deepcopy(tensor_quant_overrides) if tensor_quant_overrides else {} + ) + + # check if the model has external data. + for initializer in model.graph.initializer: + if onnx.external_data_helper.uses_external_data(initializer): + model_has_external_data = True + + final_nodes_to_exclude = [] + if nodes_to_exclude is not None and isinstance(nodes_to_exclude, list): + final_nodes_to_exclude.extend(nodes_to_exclude) + + # Iterate through nodes to get all operator types in the model and + # call user's function to filter out nodes from quantization. + for node in model.graph.node: + op_types.add(node.op_type) + if nodes_to_exclude is not None and callable(nodes_to_exclude): + if nodes_to_exclude(model, node): + final_nodes_to_exclude.append(node.name) + + final_extra_options = { + "MinimumRealRange": min_real_range, + "QDQKeepRemovableActivations": keep_removable_activations, + "ActivationSymmetric": activation_symmetric, + "WeightSymmetric": weight_symmetric, + "ForceQuantizeNoInputCheck": True, + "TensorQuantOverrides": overrides_helper.get_dict(), + } + + # Pass along known calibration options + if calibrate_args: + calib_extra_options_keys = [ + ("symmetric", "CalibTensorRangeSymmetric"), + ("moving_average", "CalibMovingAverage"), + ("averaging_constant", "CalibMovingAverageConstant"), + ("max_intermediate_outputs", "CalibMaxIntermediateOutputs"), + ("percentile", "CalibPercentile"), + ] + calib_extra_options = { + key: calibrate_args.get(name) for (name, key) in calib_extra_options_keys if name in calibrate_args + } + final_extra_options.update(calib_extra_options) + + # ONNX opset < 21 does not support 16-bit quantization, so must use 'com.microsoft' domain + # on Q/DQ operators if using 16-bit or 4-bit quantization. + onnx_opset = next(x for x in model.opset_import if x.domain == "" or x.domain == "ai.onnx") + if onnx_opset.version < 21: + opset21_types = q16_types.union(q4_types) + overrides_have_opset21_types = any(t in opset21_types for t in overrides_helper.get_quant_types()) + if activation_type in opset21_types or weight_type in opset21_types or overrides_have_opset21_types: + final_extra_options["UseQDQContribOps"] = True + + # Allow user's extra_options to override our final_extra_options. + if extra_options: + final_extra_options.update(extra_options) + + return StaticQuantConfig( + calibration_data_reader, + calibrate_method=calibrate_method, + quant_format=QuantFormat.QDQ, + activation_type=activation_type, + weight_type=weight_type, + op_types_to_quantize=list(op_types.difference(op_types_to_exclude)), + nodes_to_exclude=final_nodes_to_exclude, + per_channel=per_channel, + reduce_range=reduce_range, + use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD), + extra_options=final_extra_options, + ) + + class DynamicQuantConfig(QuantConfig): def __init__( self, @@ -290,8 +459,8 @@ def check_static_quant_arguments(quant_format: QuantFormat, activation_type: Qua def quantize_static( - model_input: Union[str, Path, onnx.ModelProto], - model_output: Union[str, Path], + model_input: str | Path | onnx.ModelProto, + model_output: str | Path, calibration_data_reader: CalibrationDataReader, quant_format=QuantFormat.QDQ, op_types_to_quantize=None, @@ -438,6 +607,9 @@ def quantize_static( removed if activations are asymmetrically quantized. Keeping these activations is necessary if optimizations or EP transformations will later remove QuantizeLinear/DequantizeLinear operators from the model. + QDQDisableWeightAdjustForInt32Bias = True/False: + Default is False. If true, QDQ quantizer will not adjust the weight's scale when the bias + has a scale (input_scale * weight_scale) that is too small. """ if activation_type == QuantType.QFLOAT8E4M3FN or weight_type == QuantType.QFLOAT8E4M3FN: if calibrate_method != CalibrationMethod.Distribution: @@ -473,6 +645,7 @@ def quantize_static( ("CalibMovingAverage", "moving_average"), ("CalibMovingAverageConstant", "averaging_constant"), ("CalibMaxIntermediateOutputs", "max_intermediate_outputs"), + ("CalibPercentile", "percentile"), ] calib_extra_options = { key: extra_options.get(name) for (name, key) in calib_extra_options_keys if name in extra_options @@ -590,8 +763,8 @@ def inc_dataloader(): def quantize_dynamic( - model_input: Union[str, Path, onnx.ModelProto], - model_output: Union[str, Path], + model_input: str | Path | onnx.ModelProto, + model_output: str | Path, op_types_to_quantize=None, per_channel=False, reduce_range=False, @@ -690,8 +863,8 @@ def quantize_dynamic( def quantize( - model_input: Union[str, Path, onnx.ModelProto], - model_output: Union[str, Path], + model_input: str | Path | onnx.ModelProto, + model_output: str | Path, quant_config: QuantConfig, ): """Quantize a model with QuantConfig. diff --git a/onnxruntime/python/tools/quantization/registry.py b/onnxruntime/python/tools/quantization/registry.py index 160b056e1de17..fbeae39c39d21 100644 --- a/onnxruntime/python/tools/quantization/registry.py +++ b/onnxruntime/python/tools/quantization/registry.py @@ -14,7 +14,7 @@ from .operators.matmul import MatMulInteger, QDQMatMul, QLinearMatMul from .operators.maxpool import QDQMaxPool, QMaxPool from .operators.norm import QDQNormalization -from .operators.pad import QPad +from .operators.pad import QDQPad, QPad from .operators.pooling import QLinearPool from .operators.qdq_base_operator import QDQOperatorBase from .operators.resize import QDQResize, QResize @@ -76,6 +76,8 @@ "Resize": QDQResize, "MaxPool": QDQMaxPool, "AveragePool": QDQDirect8BitOp, + "Slice": QDQDirect8BitOp, + "Pad": QDQPad, "MatMul": QDQMatMul, "Split": QDQSplit, "Gather": QDQGather, diff --git a/onnxruntime/python/tools/quantization/tensor_quant_overrides.py b/onnxruntime/python/tools/quantization/tensor_quant_overrides.py index 219d929d22fce..fbd0cc17f5d81 100644 --- a/onnxruntime/python/tools/quantization/tensor_quant_overrides.py +++ b/onnxruntime/python/tools/quantization/tensor_quant_overrides.py @@ -78,6 +78,10 @@ def has_per_channel_overrides(self, tensor_name: str) -> bool: overrides_list = self.overrides.get(tensor_name) return overrides_list and "axis" in overrides_list[0] + def overrides_scale_zp(self, tensor_name: str) -> bool: + overrides_list = self.overrides.get(tensor_name) + return overrides_list and ("scale" in overrides_list[0]) and ("zero_point" in overrides_list[0]) + def get_per_tensor_overrides( self, tensor_name: str, diff --git a/onnxruntime/python/tools/tensorrt/perf/build/build_image.py b/onnxruntime/python/tools/tensorrt/perf/build/build_image.py index 763d160fa56b5..3ebc33c02592d 100644 --- a/onnxruntime/python/tools/tensorrt/perf/build/build_image.py +++ b/onnxruntime/python/tools/tensorrt/perf/build/build_image.py @@ -17,8 +17,8 @@ TRT_DOCKER_FILES = { "8.6.cuda_11_8_cudnn_8": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6", "8.6.cuda_12_3_cudnn_9": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_3_tensorrt8_6", - "10.4.cuda_11_8_cudnn_8": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10", - "10.4.cuda_12_5_cudnn_9": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10", + "10.5.cuda_11_8_cudnn_8": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10", + "10.5.cuda_12_5_cudnn_9": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10", "BIN": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin", } diff --git a/onnxruntime/python/tools/transformers/benchmark.py b/onnxruntime/python/tools/transformers/benchmark.py index 26f8987c76623..55ce8d752a9d6 100644 --- a/onnxruntime/python/tools/transformers/benchmark.py +++ b/onnxruntime/python/tools/transformers/benchmark.py @@ -348,7 +348,7 @@ def run_pytorch( else: tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir) - max_input_size = tokenizer.max_model_input_sizes.get(model_name, 1024) + max_input_size = tokenizer.model_max_length logger.debug(f"Model {model}") logger.debug(f"Number of parameters {model.num_parameters()}") @@ -500,7 +500,7 @@ def run_tensorflow( tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir) - max_input_size = tokenizer.max_model_input_sizes.get(model_name, 1024) + max_input_size = tokenizer.model_max_length for batch_size in batch_sizes: if batch_size <= 0: diff --git a/onnxruntime/python/tools/transformers/bert_test_data.py b/onnxruntime/python/tools/transformers/bert_test_data.py index 167fc8697ce06..ccf2497d61342 100644 --- a/onnxruntime/python/tools/transformers/bert_test_data.py +++ b/onnxruntime/python/tools/transformers/bert_test_data.py @@ -250,6 +250,7 @@ def generate_test_data( average_sequence_length: int, random_sequence_length: bool, mask_type: int, + dictionary_size: int = 10000, ): """Create given number of input data for testing @@ -270,7 +271,6 @@ def generate_test_data( List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictionary with input name as key and a tensor as value """ - dictionary_size = 10000 all_inputs = fake_test_data( batch_size, sequence_length, diff --git a/onnxruntime/python/tools/transformers/compare_bert_results.py b/onnxruntime/python/tools/transformers/compare_bert_results.py index 0c5125e74c8a4..03bcc20d9a5de 100644 --- a/onnxruntime/python/tools/transformers/compare_bert_results.py +++ b/onnxruntime/python/tools/transformers/compare_bert_results.py @@ -85,6 +85,7 @@ def run_test( segment_ids_name, input_mask_name, mask_type, + dictionary_size: int = 1024, ): # Try deduce input names from optimized model. input_ids, segment_ids, input_mask = get_bert_inputs( @@ -105,6 +106,7 @@ def run_test( average_sequence_length, True, # random sequence length mask_type, + dictionary_size=dictionary_size, ) baseline_results, baseline_latency, output_names = run_model( diff --git a/onnxruntime/python/tools/transformers/dev_benchmark.cmd b/onnxruntime/python/tools/transformers/dev_benchmark.cmd index 82137de3c0f3b..4bef58621e8c0 100644 --- a/onnxruntime/python/tools/transformers/dev_benchmark.cmd +++ b/onnxruntime/python/tools/transformers/dev_benchmark.cmd @@ -3,9 +3,7 @@ REM Run benchmark in Windows for developing purpose. For official benchmark, please use run_benchmark.sh. REM Settings are different from run_benchmark.sh: no cli, batch and sequence, input counts, average over 100, no fp16, less models etc. -REM Please install PyTorch (see https://pytorch.org/) before running this benchmark. Like the following: -REM GPU: conda install pytorch torchvision cudatoolkit=10.1 -c pytorch -REM CPU: conda install pytorch torchvision cpuonly -c pytorch +REM Please install PyTorch (see https://pytorch.org/) before running this benchmark. REM When use_package=true, you need not copy other files to run benchmarks except this sh file. REM Otherwise, it will use python script (*.py) files in this directory. @@ -21,12 +19,12 @@ set run_torchscript=false REM Devices to test. REM Attention: You cannot run both CPU and GPU at the same time: gpu need onnxruntime-gpu, and CPU need onnxruntime. -set run_gpu_fp32=false -set run_gpu_fp16=false -set run_cpu_fp32=true -set run_cpu_int8=true +set run_gpu_fp32=true +set run_gpu_fp16=true +set run_cpu_fp32=false +set run_cpu_int8=false -set average_over=100 +set average_over=1000 REM Enable optimizer (use script instead of OnnxRuntime for graph optimization) set use_optimizer=true @@ -36,7 +34,7 @@ set sequence_length=8 128 REM Number of inputs (input_ids, token_type_ids, attention_mask) for ONNX model. REM Note that different input count might lead to different performance -set input_counts=1 +set input_counts=3 REM Pretrained transformers models can be a subset of: bert-base-cased roberta-base gpt2 distilgpt2 distilbert-base-uncased set models_to_test=bert-base-cased @@ -57,7 +55,6 @@ if %run_cpu_int8% == true if %run_gpu_fp32% == true echo cannot test cpu and gpu if %run_cpu_int8% == true if %run_gpu_fp16% == true echo cannot test cpu and gpu at same time & goto :EOF if %run_install% == true ( - pip uninstall --yes ort_nightly pip uninstall --yes onnxruntime pip uninstall --yes onnxruntime-gpu if %run_cpu_fp32% == true ( @@ -70,7 +67,6 @@ if %run_install% == true ( ) ) - pip install --upgrade onnxconverter_common pip install --upgrade transformers ) diff --git a/onnxruntime/python/tools/transformers/float16.py b/onnxruntime/python/tools/transformers/float16.py index 2398bb9d6031b..74adc951c4aa3 100644 --- a/onnxruntime/python/tools/transformers/float16.py +++ b/onnxruntime/python/tools/transformers/float16.py @@ -132,6 +132,7 @@ def make_value_info_from_tensor(tensor): "Scaler", "TreeEnsembleClassifier", "TreeEnsembleRegressor", + "TreeEnsemble", "ZipMap", "NonMaxSuppression", "TopK", diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index a9ff623fb6967..030708783bb61 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -42,26 +42,26 @@ def get_first_mask(self): assert len(self.mask_indice) > 0 return next(iter(self.mask_indice)) - def process_mask(self, input: str) -> str: + def process_mask(self, mask_2d: str) -> Optional[str]: if self.mask_format == AttentionMaskFormat.NoMask: return None - if input in self.mask_indice: - return self.mask_indice[input] + if mask_2d in self.mask_indice: + return self.mask_indice[mask_2d] # Add cast to convert int64 to int32 - if self.model.find_graph_input(input): - casted, input_name = self.utils.cast_graph_input_to_int32(input) + if self.model.find_graph_input(mask_2d): + casted, input_name = self.utils.cast_graph_input_to_int32(mask_2d) else: - input_name, cast_node = self.utils.cast_input_to_int32(input) + input_name, _cast_node = self.utils.cast_input_to_int32(mask_2d) casted = True if casted: - self.mask_casted[input] = input_name + self.mask_casted[mask_2d] = input_name # Attention supports int32 attention mask (2D) since 1.4.0 if self.mask_format == AttentionMaskFormat.AttentionMask: - self.mask_indice[input] = input_name + self.mask_indice[mask_2d] = input_name return input_name # Add a mask processing node to convert attention mask to mask index (1D) @@ -97,7 +97,7 @@ def process_mask(self, input: str) -> str: self.model.add_node(mask_index_node) - self.mask_indice[input] = output_name + self.mask_indice[mask_2d] = output_name return output_name @@ -173,17 +173,20 @@ def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int] Tuple[int, int]: num_heads and hidden_size """ # we assume that reshape fusion has done, so the shape is a tensor like [0, 0, num_heads, head_size] - q_shape = self.model.get_initializer(reshape_q.input[1]) - if q_shape is None: + q_shape_value = self.model.get_constant_value(reshape_q.input[1]) + if q_shape_value is None: concat = self.model.get_parent(reshape_q, 1) if concat is not None and concat.op_type == "Concat": return self.get_num_heads_and_hidden_size_from_concat(concat) - logger.debug(f"{reshape_q.input[1]} is not initializer.") + logger.debug("%s is not initializer.", reshape_q.input[1]) return self.num_heads, self.hidden_size # Fall back to user specified value - q_shape_value = NumpyHelper.to_array(q_shape) - if len(q_shape_value) != 4 or (q_shape_value[2] <= 0 or q_shape_value[3] <= 0): - logger.debug(f"q_shape_value={q_shape_value}. Expected value are like [0, 0, num_heads, head_size].") + if ( + (not isinstance(q_shape_value, np.ndarray)) + or len(q_shape_value) != 4 + or (q_shape_value[2] <= 0 or q_shape_value[3] <= 0) + ): + logger.debug("q_shape_value=%s. Expected value are like [0, 0, num_heads, head_size].", q_shape_value) return self.num_heads, self.hidden_size # Fall back to user specified value num_heads = q_shape_value[2] @@ -192,13 +195,15 @@ def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int] if self.num_heads > 0 and num_heads != self.num_heads: if self.num_heads_warning: - logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.") + logger.warning( + "--num_heads is %d. Detected value is %d. Using detected value.", self.num_heads, num_heads + ) self.num_heads_warning = False # Do not show the warning more than once if self.hidden_size > 0 and hidden_size != self.hidden_size: if self.hidden_size_warning: logger.warning( - f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value." + "--hidden_size is %d. Detected value is %d. Using detected value.", self.hidden_size, hidden_size ) self.hidden_size_warning = False # Do not show the warning more than once @@ -216,11 +221,11 @@ def get_add_qk_str(self, add_qk: NodeProto): input_1_shape = self.shape_infer.get_edge_shape(add_qk.input[1]) if input_0_shape is None or input_1_shape is None: - logger.debug(f"one of the inputs of {add_qk} is None") + logger.debug("one of the inputs of %s is None", add_qk) return None if input_0_shape != input_1_shape: - logger.debug(f"the shape of two inputs of {add_qk} is not same") + logger.debug("the shape of two inputs of %s is not same", add_qk) return None return add_qk.input[1] @@ -305,55 +310,6 @@ def concat_kv(self, past_k: str, past_v: str) -> str: return kv_output_name - def reshape_kv(self, past_k: str, past_v: str) -> (str, str): - """Reshape past_k and past_v from 4D to 3D to use as inputs for multihead attention node. - - Args: - past_k (str): name of past K value of shape 4D - past_v (str): name of past V value of shape 4D - - Returns: - k_3d (str): name of past K value of shape 3D - v_3d (str): name of past V value of shape 3D - """ - # Reshape past_k and past_v from (B,N,P,H) to (B,P,N*H) - # B = batch size, N = num heads, P = past seq len, H = head size - - # Create initializer for reshaping past_k and past_v - new_dims_name = "kv_4d_to_3d" - new_dims = self.model.get_initializer(new_dims_name) - if new_dims is None: - new_dims = numpy_helper.from_array( - np.array([0, -1, self.model.hidden_size], dtype="int64"), name=new_dims_name - ) - self.model.add_initializer(new_dims, self.this_graph_name) - - reshape_k_name = self.model.create_node_name("Reshape") - reshape_v_name = self.model.create_node_name("Reshape") - k_3d_name = (past_k + "_3d").replace(".", "_") - v_3d_name = (past_v + "_3d").replace(".", "_") - - k_3d = helper.make_node( - "Reshape", - inputs=[past_k, new_dims_name], - outputs=[k_3d_name], - name=reshape_k_name, - ) - v_3d = helper.make_node( - "Reshape", - inputs=[past_v, new_dims_name], - outputs=[v_3d_name], - name=reshape_v_name, - ) - - # Add reshape nodes to graph - self.nodes_to_add.append(k_3d) - self.nodes_to_add.append(v_3d) - self.node_name_to_graph_name[reshape_k_name] = self.this_graph_name - self.node_name_to_graph_name[reshape_v_name] = self.this_graph_name - - return k_3d_name, v_3d_name - def split_kv(self, present_k_name: str, present_v_name: str, kv_node: str): """Split kv_node containing present KV values into separate present K and present V values. @@ -476,8 +432,7 @@ def create_packed_qkv_matmul_node( q_add: NodeProto, k_add: Union[NodeProto, None], v_add: Union[NodeProto, None], - num_heads: int, - ) -> Union[NodeProto, None]: + ) -> Tuple[NodeProto, NodeProto, NodeProto]: """Create packed QKV MatMul node before MultiHeadAttention node. This is for the scenario where an Attention node should be created but cannot be created because past_key and past_value are separate inputs and not one concatenated input. @@ -489,10 +444,11 @@ def create_packed_qkv_matmul_node( q_add (NodeProto): name of Add from Q path k_add (NodeProto): name of Add from K path v_add (NodeProto): name of Add from V path - num_heads (int): number of heads Returns: - Union[NodeProto, None]: the node created or None if failed. + q_output (NodeProto): Slice node for Q + k_output (NodeProto): Slice node for K + v_output (NodeProto): Slice node for V """ matmul_node_name = self.model.create_node_name("MatMul") @@ -611,6 +567,7 @@ def create_packed_qkv_matmul_node( self.nodes_to_add.extend(qkv_nodes) return q_output, k_output, v_output + # This function is used in child classes for bart or conformer model. def create_multihead_attention_node( self, q_matmul: NodeProto, @@ -659,7 +616,7 @@ def create_multihead_attention_node( assert num_heads > 0 if hidden_size > 0 and (hidden_size % num_heads) != 0: - logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}") + logger.debug("input hidden size %d is not a multiple of num of heads %d", hidden_size, num_heads) return None graph_input_names = set([node.name for node in self.model.graph().input]) @@ -669,17 +626,22 @@ def create_multihead_attention_node( mha_inputs = [] if packed_qkv: q_slice, k_slice, v_slice = self.create_packed_qkv_matmul_node( - q_matmul, k_matmul, v_matmul, q_add, k_add, v_add, num_heads + q_matmul, + k_matmul, + v_matmul, + q_add, + k_add, + v_add, ) mha_inputs.extend([q_slice.output[0], k_slice.output[0], v_slice.output[0]]) - elif type(k_matmul) is NodeProto and type(v_matmul) is NodeProto: + elif isinstance(k_matmul, NodeProto) and isinstance(v_matmul, NodeProto): if self.disable_multi_head_attention_bias: mha_inputs.extend([q_add.output[0], k_matmul.output[0], v_add.output[0]]) else: mha_inputs.extend([q_matmul.output[0], k_matmul.output[0], v_matmul.output[0]]) elif ( - type(k_matmul) == str # noqa: E721 - and type(v_matmul) == str # noqa: E721 + isinstance(k_matmul, str) + and isinstance(v_matmul, str) and k_matmul in graph_input_names and v_matmul in graph_input_names ): @@ -724,7 +686,7 @@ def create_multihead_attention_node( def create_attention_node( self, - mask_index: str, + mask_index: Optional[str], q_matmul: NodeProto, k_matmul: NodeProto, v_matmul: NodeProto, @@ -733,7 +695,7 @@ def create_attention_node( v_add: NodeProto, num_heads: int, hidden_size: int, - input: str, + first_input: str, output: str, add_qk_str: str = "", past_k: str = "", @@ -746,7 +708,7 @@ def create_attention_node( """Create an Attention node. Args: - mask_index (str): mask input + mask_index (str | None): mask input q_matmul (NodeProto): MatMul node in fully connection for Q k_matmul (NodeProto): MatMul node in fully connection for K v_matmul (NodeProto): MatMul node in fully connection for V @@ -755,7 +717,7 @@ def create_attention_node( v_add (NodeProto): Add bias node in fully connection for V num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning. hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning. - input (str): input name + first_input (str): first input name output (str): output name add_qk_str (str): name of Add node after Q x K' past_k (str): name of input for past K value @@ -771,7 +733,7 @@ def create_attention_node( assert num_heads > 0 if hidden_size > 0 and (hidden_size % num_heads) != 0: - logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}") + logger.debug("input hidden size %d is not a multiple of num of heads %d", hidden_size, num_heads) return None has_bias = True @@ -813,8 +775,10 @@ def create_attention_node( if hidden_size > 0 and hidden_size != qw_in_size: logger.warning( - f"Input hidden size ({hidden_size}) is not same as weight matrix dimension of q,k,v ({qw_in_size}). " - "Please provide a correct input hidden size or pass in 0" + "Input hidden size (%d) is not same as weight matrix dimension of q,k,v (%d). " + "Please provide a correct input hidden size or pass in 0", + hidden_size, + qw_in_size, ) is_qkv_diff_dims = False @@ -836,6 +800,8 @@ def create_attention_node( qkv_weight = np.stack((qw, kw, vw), axis=1) qkv_weight_dim = 3 * qw_out_size + qkv_bias_dim = 0 + qkv_bias: Optional[np.ndarray] = None if has_bias: qb = NumpyHelper.to_array(q_bias) kb = NumpyHelper.to_array(k_bias) @@ -861,7 +827,7 @@ def create_attention_node( self.add_initializer( name=attention_node_name + "_qkv_weight", data_type=q_weight.data_type, - dims=[qw_in_size, qkv_weight_dim], + dims=[qw_in_size, int(qkv_weight_dim)], vals=qkv_weight, ) @@ -869,7 +835,7 @@ def create_attention_node( self.add_initializer( name=attention_node_name + "_qkv_bias", data_type=q_bias.data_type, - dims=[qkv_bias_dim], + dims=[int(qkv_bias_dim)], vals=qkv_bias, ) @@ -897,7 +863,7 @@ def create_attention_node( ) else: attention_inputs = [ - input, + first_input, attention_node_name + "_qkv_weight", attention_node_name + "_qkv_bias" if has_bias else "", ] @@ -911,7 +877,7 @@ def create_attention_node( past_kv = self.concat_kv(past_k, past_v) attention_inputs.append(past_kv) - if add_qk_str is not None: + if add_qk_str: mask_output_name = self.reshape_add_qk(add_qk_str) # Add attention mask to attention node @@ -951,9 +917,10 @@ def create_attention_node( return attention_node - def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + def fuse(self, node, input_name_to_nodes, output_name_to_node): # Sometimes we can not fuse skiplayernormalization since the add before layernorm has an output that used by nodes outside skiplayernorm # Conceptually we treat add before layernorm as skiplayernorm node since they share the same pattern + normalize_node = node start_node = normalize_node if normalize_node.op_type == "LayerNormalization": add_before_layernorm = self.model.match_parent(normalize_node, "Add", 0) @@ -982,25 +949,24 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): return other_inputs = [] - for _i, input in enumerate(start_node.input): - if input not in output_name_to_node: + for _i, node_input in enumerate(start_node.input): + if node_input not in output_name_to_node: continue - if input == qkv_nodes[0].output[0]: + if node_input == qkv_nodes[0].output[0]: continue - other_inputs.append(input) + other_inputs.append(node_input) if len(other_inputs) != 1: return root_input = other_inputs[0] - """ - Match flaubert Mask - | - Mul --> LayerNormalization --> Attention --> MatMul --> Add - | | - | | - +--------------------------------------------------------- - """ + + # Match flaubert Mask + # | + # Mul --> LayerNormalization --> Attention --> MatMul --> Add + # | | + # | | + # +--------------------------------------------------------- mul_before_layernorm = self.model.match_parent(start_node, "Mul", 0) if mul_before_layernorm is not None: mul_children = input_name_to_nodes[mul_before_layernorm.output[0]] @@ -1020,19 +986,15 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): if child.op_type == "LayerNormalization": root_input = child.output[0] - """ - When Add before the LayerNormalization produces an output - that is consumed by some other nodes other than the LayerNormalization itself, - fused SkipLayerNormalization will have several outputs. - In this case we need to pick the one used in Attention - - For example, this is the case for ViT - - SkipLayerNormalization --> Attention --> MatMul --> Add --> SkipLayerNormalization - | | - | | - +---------------------------------------------------------------------+ - """ + # When Add before the LayerNormalization produces an output + # that is consumed by some other nodes other than the LayerNormalization itself, + # fused SkipLayerNormalization will have several outputs. + # In this case we need to pick the one used in Attention + # For example, this is the case for ViT + # SkipLayerNormalization --> Attention --> MatMul --> Add --> SkipLayerNormalization + # | | + # | | + # +---------------------------------------------------------------------+ parent_node = output_name_to_node[root_input] if parent_node.op_type == "SkipLayerNormalization" and len(parent_node.output) == 4: root_input = parent_node.output[0] @@ -1051,12 +1013,14 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): is_distill = False is_distill_add = False is_no_mask_attention = False + is_sdpa = False qk_paths = { "path1": (["Softmax", "Add", "Div", "MatMul"], [0, 0, None, 0]), "path2": (["Softmax", "Add", "Mul", "MatMul"], [0, 0, None, 0]), "path3": (["Softmax", "Where", "MatMul", "Div"], [0, 0, 2, 0]), "path4": (["Softmax", "Add", "Where", "MatMul"], [0, 0, 0, 2]), "path5": (["Softmax", "Div", "MatMul"], [0, 0, 0]), + "sdpa": (["Softmax", "Add", "MatMul", "Mul", "Sqrt"], [0, 0, None, 0, 1]), } qk_nodes = None @@ -1066,10 +1030,12 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): continue if k == "path3": is_distill = True - if k == "path4": + elif k == "path4": is_distill_add = True - if k == "path5": + elif k == "path5": is_no_mask_attention = True + elif k == "sdpa": + is_sdpa = True break if qk_nodes is None: @@ -1079,19 +1045,23 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): add_qk = None matmul_qk = None where_qk = None + after_q = None if is_distill: (_, where_qk, matmul_qk, _) = qk_nodes elif is_distill_add: (_, add_qk, where_qk, matmul_qk) = qk_nodes elif is_no_mask_attention: (_, _, matmul_qk) = qk_nodes + elif is_sdpa: + (_, add_qk, matmul_qk, after_q, _) = qk_nodes else: (_, add_qk, _, matmul_qk) = qk_nodes - q_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, None]) + after_q = after_q or matmul_qk + q_nodes = self.model.match_parent_path(after_q, ["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, None]) if q_nodes is None: q_nodes = self.model.match_parent_path( - matmul_qk, + after_q, ["Div", "Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, 0, None], ) @@ -1102,7 +1072,17 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): add_q = q_nodes[-2] matmul_q = q_nodes[-1] - k_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None]) + after_k = matmul_qk + if is_sdpa: + mul_k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Sqrt"], [1, None]) + if mul_k_nodes is None: + logger.debug("fuse_attention: failed to match mul sqrt q path") + return + (after_k, _) = mul_k_nodes + + k_nodes = self.model.match_parent_path( + after_k, ["Transpose", "Reshape", "Add", "MatMul"], [0 if is_sdpa else 1, 0, 0, None] + ) if k_nodes is None: k_nodes = self.model.match_parent_path( matmul_qk, @@ -1117,7 +1097,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # Note that Cast might be removed by OnnxRuntime so we match two patterns here. mask_nodes = None - add_qk_str = None + add_qk_str = "" if is_distill: _, mask_nodes, _ = self.model.match_parent_paths( where_qk, @@ -1140,7 +1120,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): if add_qk is not None: add_qk_str = self.get_add_qk_str(add_qk) if add_qk_str is None: - logger.debug(f"fuse_attention: failed to verify shape inference of {add_qk}") + logger.debug("fuse_attention: failed to verify shape inference of %s", add_qk) return elif is_no_mask_attention: pass @@ -1148,11 +1128,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): _, mask_nodes, _ = self.model.match_parent_paths( add_qk, [ - ( - ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], - [None, 0, 1, 0, 0], - ), + (["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0, 0]), (["Mul", "Sub", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0]), + # The following two patterns are for SDPA. + (["Where", "Cast", "Sub", "Expand", "Unsqueeze", "Unsqueeze"], [None, 0, 0, 1, 0, 0]), + (["Where", "Cast", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], [None, 0, 0, 1, 0, 0, 0]), ], output_name_to_node, ) @@ -1160,10 +1140,17 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): logger.debug("fuse_attention: failed to match mask path") return - if not is_no_mask_attention and len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul": + if not is_no_mask_attention and len(mask_nodes) > 1: _, mul_val = self.model.get_constant_input(mask_nodes[0]) - if mul_val != -10000: - self.mask_filter_value = mul_val + # The mask value shall be a float scalar (usually is the lowest float value). + if ( + (mul_val is None) + or not (isinstance(mul_val, np.ndarray) and mul_val.size == 1) + or (float(mul_val) >= 0) + ): + return + if float(mul_val) != -10000: + self.mask_filter_value = float(mul_val) if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_k.input[0] == root_input: mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) if not is_no_mask_attention else None @@ -1181,19 +1168,20 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads # the input_hidden_size represents the input hidden size, this is used as needed but hidden sizes for Q, K are extracted appropriately new_node = self.create_attention_node( - mask_index, - matmul_q, - matmul_k, - matmul_v, - add_q, - add_k, - add_v, - q_num_heads, - q_hidden_size, - root_input, - attention_last_node.output[0], - add_qk_str, + mask_index=mask_index, + q_matmul=matmul_q, + k_matmul=matmul_k, + v_matmul=matmul_v, + q_add=add_q, + k_add=add_k, + v_add=add_v, + num_heads=q_num_heads, + hidden_size=q_hidden_size, + first_input=root_input, + output=attention_last_node.output[0], + add_qk_str=add_qk_str, ) + if new_node is None: return @@ -1208,7 +1196,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): name="shape_modified_tensor" + unique_index, data_type=TensorProto.INT64, dims=[4], - vals=np.int64([0, 0, q_num_heads, int(q_hidden_size / q_num_heads)]), + vals=[0, 0, q_num_heads, int(q_hidden_size / q_num_heads)], raw=False, ) diff --git a/onnxruntime/python/tools/transformers/fusion_attention_clip.py b/onnxruntime/python/tools/transformers/fusion_attention_clip.py index b027957fcc725..16e2c36bfd092 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention_clip.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_clip.py @@ -239,9 +239,9 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): v_add=add_v, num_heads=num_heads, hidden_size=hidden_size, - input=root_input, + first_input=root_input, output=attention_last_node.output[0], - add_qk_str=None, + add_qk_str="", scale=None, causal=(add_mask is not None), ) diff --git a/onnxruntime/python/tools/transformers/fusion_bart_attention.py b/onnxruntime/python/tools/transformers/fusion_bart_attention.py index ebecc1db24792..8c334b83abfeb 100644 --- a/onnxruntime/python/tools/transformers/fusion_bart_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_bart_attention.py @@ -564,15 +564,15 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # value whereas attention supports concatenated past key and past value. new_node = ( self.create_multihead_attention_node( - matmul_q, - matmul_k if decoder_cross_attention or decoder_attention_with_past else past_k, - matmul_v if decoder_cross_attention or decoder_attention_with_past else past_v, - add_q, - add_k if decoder_cross_attention or decoder_attention_with_past else None, - add_v if decoder_cross_attention or decoder_attention_with_past else None, - num_heads, - hidden_size, - attention_last_node.output[0], + q_matmul=matmul_q, + k_matmul=matmul_k if decoder_cross_attention or decoder_attention_with_past else past_k, + v_matmul=matmul_v if decoder_cross_attention or decoder_attention_with_past else past_v, + q_add=add_q, + k_add=add_k if decoder_cross_attention or decoder_attention_with_past else None, + v_add=add_v if decoder_cross_attention or decoder_attention_with_past else None, + num_heads=num_heads, + hidden_size=hidden_size, + output=attention_last_node.output[0], past_k=past_k if decoder_attention_with_past else "", past_v=past_v if decoder_attention_with_past else "", present_k=present_k, @@ -586,19 +586,20 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # Temporarily set multihead attention flag to false use_multi_head_attention_ground_truth = self.use_multi_head_attention self.use_multi_head_attention = False + add_qk_str = mask_index if decoder_attention and mask_index else "" new_node = self.create_attention_node( - None, - matmul_q, - matmul_k, - matmul_v, - add_q, - add_k, - add_v, - num_heads, - hidden_size, - root_input, - attention_last_node.output[0], - add_qk_str=mask_index if decoder_attention else None, + mask_index=None, + q_matmul=matmul_q, + k_matmul=matmul_k, + v_matmul=matmul_v, + q_add=add_q, + k_add=add_k, + v_add=add_v, + num_heads=num_heads, + hidden_size=hidden_size, + first_input=root_input, + output=attention_last_node.output[0], + add_qk_str=add_qk_str, past_k=past_k, past_v=past_v, present_k=present_k, diff --git a/onnxruntime/python/tools/transformers/fusion_conformer_attention.py b/onnxruntime/python/tools/transformers/fusion_conformer_attention.py index 6bc681c57444e..f29d0a0ac9441 100644 --- a/onnxruntime/python/tools/transformers/fusion_conformer_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_conformer_attention.py @@ -102,15 +102,15 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): return new_node = self.create_multihead_attention_node( - matmul_q, - matmul_k, - matmul_v, - add_q, - add_k, - add_v, - num_heads, - hidden_size, - attention_last_node.output[0], + q_matmul=matmul_q, + k_matmul=matmul_k, + v_matmul=matmul_v, + q_add=add_q, + k_add=add_k, + v_add=add_v, + num_heads=num_heads, + hidden_size=hidden_size, + output=attention_last_node.output[0], add_qk=add_qk.input[1], past_k=past_k, past_v=past_v, diff --git a/onnxruntime/python/tools/transformers/huggingface_models.py b/onnxruntime/python/tools/transformers/huggingface_models.py index dcfe4a28ad9af..4cd878a4656a7 100644 --- a/onnxruntime/python/tools/transformers/huggingface_models.py +++ b/onnxruntime/python/tools/transformers/huggingface_models.py @@ -13,155 +13,62 @@ "AutoModelForCausalLM", ] -# List of pretrained models: https://huggingface.co/transformers/pretrained_models.html # Pretrained model name to a tuple of input names, opset_version, use_external_data_format, optimization model type +# Some models like GPT, T5, Bart etc has its own convert_to_onnx.py in models sub-directory, and they are excluded here. MODELS = { # BERT - "bert-base-uncased": ( - ["input_ids", "attention_mask", "token_type_ids"], - 12, - False, - "bert", - ), - "bert-large-uncased": ( - ["input_ids", "attention_mask", "token_type_ids"], - 12, - False, - "bert", - ), - "bert-base-cased": ( - ["input_ids", "attention_mask", "token_type_ids"], - 12, - False, - "bert", - ), - # "bert-large-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), - # "bert-base-multilingual-uncased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), - # "bert-base-multilingual-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), - # "bert-base-chinese": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), - # "bert-base-german-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), - # "bert-large-uncased-whole-word-masking": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), - # "bert-large-cased-whole-word-masking": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), - # "bert-large-uncased-whole-word-masking-finetuned-squad": (["input_ids", "attention_mask", - # "token_type_ids"], 12, False, "bert"), - # "bert-large-cased-whole-word-masking-finetuned-squad": (["input_ids", "attention_mask", - # "token_type_ids"], 12, False, "bert"), - # "bert-base-cased-finetuned-mrpc": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), - # "bert-base-german-dbmdz-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), - # "bert-base-german-dbmdz-uncased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), - # todo: more models to add - # GPT (no past state) - "openai-gpt": (["input_ids"], 11, False, "gpt2"), - # GPT-2 (no past state, use benchmark_gpt2.py for past_key_values) - "gpt2": (["input_ids"], 11, False, "gpt2"), - "gpt2-medium": (["input_ids"], 11, False, "gpt2"), - "gpt2-large": (["input_ids"], 11, True, "gpt2"), - "gpt2-xl": (["input_ids"], 11, True, "gpt2"), - "distilgpt2": (["input_ids"], 11, False, "gpt2"), - # Transformer-XL (Models uses Einsum, which need opset version 12 or later.) - "transfo-xl-wt103": (["input_ids", "mems"], 12, False, "bert"), + "bert-base-cased": (["input_ids", "attention_mask", "token_type_ids"], 16, False, "bert"), + "bert-large-cased": (["input_ids", "attention_mask", "token_type_ids"], 16, False, "bert"), + # Transformer-XL (Models uses Einsum, which need opset version 16 or later.) + "transfo-xl-wt103": (["input_ids", "mems"], 16, False, "bert"), # XLNet - "xlnet-base-cased": (["input_ids"], 12, False, "bert"), - "xlnet-large-cased": (["input_ids"], 12, False, "bert"), + "xlnet-base-cased": (["input_ids"], 16, False, "bert"), + "xlnet-large-cased": (["input_ids"], 16, False, "bert"), # XLM - "xlm-mlm-en-2048": (["input_ids"], 11, True, "bert"), - "xlm-mlm-ende-1024": (["input_ids"], 11, False, "bert"), - "xlm-mlm-enfr-1024": (["input_ids"], 11, False, "bert"), + "xlm-mlm-en-2048": (["input_ids"], 16, True, "bert"), + "xlm-mlm-ende-1024": (["input_ids"], 16, False, "bert"), + "xlm-mlm-enfr-1024": (["input_ids"], 16, False, "bert"), # RoBERTa - "roberta-base": (["input_ids", "attention_mask"], 12, False, "bert"), - "roberta-large": (["input_ids", "attention_mask"], 12, False, "bert"), - "roberta-large-mnli": (["input_ids", "attention_mask"], 12, False, "bert"), - "deepset/roberta-base-squad2": (["input_ids", "attention_mask"], 11, False, "bert"), - "distilroberta-base": (["input_ids", "attention_mask"], 12, False, "bert"), + "roberta-base": (["input_ids", "attention_mask"], 16, False, "bert"), + "roberta-large": (["input_ids", "attention_mask"], 16, False, "bert"), + "roberta-large-mnli": (["input_ids", "attention_mask"], 16, False, "bert"), + "deepset/roberta-base-squad2": (["input_ids", "attention_mask"], 16, False, "bert"), + "distilroberta-base": (["input_ids", "attention_mask"], 16, False, "bert"), # DistilBERT - "distilbert-base-uncased": (["input_ids", "attention_mask"], 11, False, "bert"), - "distilbert-base-uncased-distilled-squad": ( - ["input_ids", "attention_mask"], - 11, - False, - "bert", - ), + "distilbert-base-uncased": (["input_ids", "attention_mask"], 16, False, "bert"), + "distilbert-base-uncased-distilled-squad": (["input_ids", "attention_mask"], 16, False, "bert"), # CTRL - "ctrl": (["input_ids"], 11, True, "bert"), + "ctrl": (["input_ids"], 16, True, "bert"), # CamemBERT - "camembert-base": (["input_ids"], 11, False, "bert"), + "camembert-base": (["input_ids"], 16, False, "bert"), # ALBERT - "albert-base-v1": (["input_ids"], 12, False, "bert"), - "albert-large-v1": (["input_ids"], 12, False, "bert"), - "albert-xlarge-v1": (["input_ids"], 12, True, "bert"), - # "albert-xxlarge-v1": (["input_ids"], 12, True, "bert"), - "albert-base-v2": (["input_ids"], 12, False, "bert"), - "albert-large-v2": (["input_ids"], 12, False, "bert"), - "albert-xlarge-v2": (["input_ids"], 12, True, "bert"), - # "albert-xxlarge-v2": (["input_ids"], 12, True, "bert"), - # T5 (use benchmark_t5.py instead) - # "t5-small": (["input_ids", "decoder_input_ids"], 12, False, "bert"), - # "t5-base": (["input_ids", "decoder_input_ids"], 12, False, "bert"), - # "t5-large": (["input_ids", "decoder_input_ids"], 12, True, "bert"), - # "t5-3b": (["input_ids", "decoder_input_ids"], 12, True, "bert"), - # "t5-11b": (["input_ids", "decoder_input_ids"], 12, True, "bert"), - # "valhalla/t5-small-qa-qg-hl": (["input_ids"], 12, True, "bert"), + "albert-base-v1": (["input_ids"], 16, False, "bert"), + "albert-large-v1": (["input_ids"], 16, False, "bert"), + "albert-xlarge-v1": (["input_ids"], 16, True, "bert"), + # "albert-xxlarge-v1": (["input_ids"], 16, True, "bert"), + "albert-base-v2": (["input_ids"], 16, False, "bert"), + "albert-large-v2": (["input_ids"], 16, False, "bert"), + "albert-xlarge-v2": (["input_ids"], 16, True, "bert"), + # "albert-xxlarge-v2": (["input_ids"], 16, True, "bert"), # XLM-RoBERTa - "xlm-roberta-base": (["input_ids"], 11, False, "bert"), - "xlm-roberta-large": (["input_ids"], 11, True, "bert"), + "xlm-roberta-base": (["input_ids"], 16, False, "bert"), + "xlm-roberta-large": (["input_ids"], 16, True, "bert"), # FlauBERT - "flaubert/flaubert_small_cased": (["input_ids"], 11, False, "bert"), - # "flaubert/flaubert_base_uncased": (["input_ids"], 11, False, "bert"), - "flaubert/flaubert_base_cased": (["input_ids"], 11, False, "bert"), - # "flaubert/flaubert_large_cased": (["input_ids"], 11, False, "bert"), - # Bart - "facebook/bart-large": (["input_ids", "attention_mask"], 11, False, "bart"), - "facebook/bart-base": (["input_ids", "attention_mask"], 11, False, "bart"), - "facebook/bart-large-mnli": (["input_ids", "attention_mask"], 11, False, "bart"), - "facebook/bart-large-cnn": (["input_ids", "attention_mask"], 11, False, "bart"), - # DialoGPT - "microsoft/DialoGPT-small": (["input_ids"], 11, False, "gpt2"), - "microsoft/DialoGPT-medium": (["input_ids"], 11, False, "gpt2"), - # "microsoft/DialoGPT-large": (["input_ids"], 11, True, "gpt2"), - # Reformer - # "google/reformer-enwik8": (["input_ids"], 11, False, "bert"), - # "google/reformer-crime-and-punishment": (["input_ids"], 11, False, "bert"), - # MarianMT - # "Helsinki-NLP/opus-mt-ROMANCE-en": (["input_ids"], 12, False, "bert"), - # Longformer (use benchmark_longformer.py instead) - # "allenai/longformer-base-4096": (["input_ids"], 12, False, "bert"), - # "allenai/longformer-large-4096": (["input_ids"], 12, False, "bert"), - # MBart - "facebook/mbart-large-cc25": (["input_ids"], 11, True, "bert"), - "facebook/mbart-large-en-ro": (["input_ids"], 11, True, "bert"), - # "Helsinki-NLP/opus-mt-ROMANCE-en": (["input_ids"], 12, False, "bert"), - # # Longformer - # "allenai/longformer-base-4096": (["input_ids"], 12, False, "bert"), - # "allenai/longformer-large-4096": (["input_ids"], 12, True, "bert"), - # "funnel-transformer/small": (["input_ids"], 12, False, "bert"), - # "funnel-transformer/small-base": (["input_ids"], 12, False, "bert"), - # "funnel-transformer/medium": (["input_ids"], 12, False, "bert"), - # "funnel-transformer/medium-base": (["input_ids"], 12, False, "bert"), - # "funnel-transformer/intermediate": (["input_ids"], 12, False, "bert"), - # "funnel-transformer/intermediate-base": (["input_ids"], 12, False, "bert"), - # "funnel-transformer/large": (["input_ids"], 12, True, "bert"), - # "funnel-transformer/large-base": (["input_ids"], 12, True, "bert"), - # "funnel-transformer/xlarge": (["input_ids"], 12, True, "bert"), - # "funnel-transformer/xlarge-base": (["input_ids"], 12, True, "bert"), + "flaubert/flaubert_small_cased": (["input_ids"], 16, False, "bert"), + "flaubert/flaubert_base_cased": (["input_ids"], 16, False, "bert"), + # "flaubert/flaubert_large_cased": (["input_ids"], 16, False, "bert"), # Layoutlm - "microsoft/layoutlm-base-uncased": (["input_ids"], 11, False, "bert"), - "microsoft/layoutlm-large-uncased": (["input_ids"], 11, False, "bert"), + "microsoft/layoutlm-base-uncased": (["input_ids"], 16, False, "bert"), + "microsoft/layoutlm-large-uncased": (["input_ids"], 16, False, "bert"), # Squeezebert - "squeezebert/squeezebert-uncased": (["input_ids"], 11, False, "bert"), - "squeezebert/squeezebert-mnli": (["input_ids"], 11, False, "bert"), - "squeezebert/squeezebert-mnli-headless": (["input_ids"], 11, False, "bert"), - "unc-nlp/lxmert-base-uncased": ( - ["input_ids", "visual_feats", "visual_pos"], - 11, - False, - "bert", - ), - # "google/pegasus-xsum": (["input_ids"], 11, False, "bert"), - # "google/pegasus-large": (["input_ids"], 11, False, "bert"), + "squeezebert/squeezebert-uncased": (["input_ids"], 16, False, "bert"), + "squeezebert/squeezebert-mnli": (["input_ids"], 16, False, "bert"), + "squeezebert/squeezebert-mnli-headless": (["input_ids"], 16, False, "bert"), + "unc-nlp/lxmert-base-uncased": (["input_ids", "visual_feats", "visual_pos"], 16, False, "bert"), # ViT - "google/vit-base-patch16-224": (["pixel_values"], 12, False, "vit"), + "google/vit-base-patch16-224": (["pixel_values"], 16, False, "vit"), # Swin - "microsoft/swin-base-patch4-window7-224": (["pixel_values"], 12, False, "swin"), - "microsoft/swin-small-patch4-window7-224": (["pixel_values"], 12, False, "swin"), - "microsoft/swin-tiny-patch4-window7-224": (["pixel_values"], 12, False, "swin"), + "microsoft/swin-base-patch4-window7-224": (["pixel_values"], 16, False, "swin"), + "microsoft/swin-small-patch4-window7-224": (["pixel_values"], 16, False, "swin"), + "microsoft/swin-tiny-patch4-window7-224": (["pixel_values"], 16, False, "swin"), } diff --git a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh index e6da988f5c0df..9e97867657ab9 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh +++ b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh @@ -191,7 +191,6 @@ build_onnxruntime_gpu_for_profiling() { --build_wheel --skip_tests \ --cmake_generator Ninja \ --compile_no_warning_as_error \ - --enable_cuda_nhwc_ops \ --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=$CUDA_ARCH \ --cmake_extra_defines onnxruntime_ENABLE_NVTX_PROFILE=ON \ --enable_cuda_line_info diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py index 3879e25386d53..0708d57f040f8 100755 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py @@ -51,6 +51,10 @@ def example_prompts(): return prompts, negative_prompt +def warmup_prompts(): + return "warm up", "bad" + + def measure_gpu_memory(monitor_type, func, start_memory=None): return measure_memory(is_gpu=True, func=func, monitor_type=monitor_type, start_memory=start_memory) @@ -136,7 +140,14 @@ def run_ort_pipeline( prompts, negative_prompt = example_prompts() def warmup(): - pipe("warm up", height, width, num_inference_steps=steps, num_images_per_prompt=batch_size) + prompt, negative = warmup_prompts() + pipe( + prompt=[prompt] * batch_size, + height=height, + width=width, + num_inference_steps=steps, + negative_prompt=[negative] * batch_size, + ) # Run warm up, and measure GPU memory of two runs # cuDNN/MIOpen The first run has algo search so it might need more memory) @@ -149,22 +160,20 @@ def warmup(): for i, prompt in enumerate(prompts): if i >= num_prompts: break - for j in range(batch_count): - inference_start = time.time() - images = pipe( - [prompt] * batch_size, - height, - width, - num_inference_steps=steps, - negative_prompt=[negative_prompt] * batch_size, - guidance_scale=7.5, - ).images - inference_end = time.time() - latency = inference_end - inference_start - latency_list.append(latency) - print(f"Inference took {latency:.3f} seconds") - for k, image in enumerate(images): - image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg") + inference_start = time.time() + images = pipe( + prompt=[prompt] * batch_size, + height=height, + width=width, + num_inference_steps=steps, + negative_prompt=[negative_prompt] * batch_size, + ).images + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"Inference took {latency:.3f} seconds") + for k, image in enumerate(images): + image.save(f"{image_filename_prefix}_{i}_{k}.jpg") from onnxruntime import __version__ as ort_version @@ -200,7 +209,14 @@ def run_torch_pipeline( # total 2 runs of warm up, and measure GPU memory for CUDA EP def warmup(): - pipe("warm up", height, width, num_inference_steps=steps, num_images_per_prompt=batch_size) + prompt, negative = warmup_prompts() + pipe( + prompt=[prompt] * batch_size, + height=height, + width=width, + num_inference_steps=steps, + negative_prompt=[negative] * batch_size, + ) # Run warm up, and measure GPU memory of two runs (The first run has cuDNN algo search so it might need more memory) first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) @@ -215,25 +231,23 @@ def warmup(): if i >= num_prompts: break torch.cuda.synchronize() - for j in range(batch_count): - inference_start = time.time() - images = pipe( - prompt=[prompt] * batch_size, - height=height, - width=width, - num_inference_steps=steps, - guidance_scale=7.5, - negative_prompt=[negative_prompt] * batch_size, - generator=None, # torch.Generator - ).images + inference_start = time.time() + images = pipe( + prompt=[prompt] * batch_size, + height=height, + width=width, + num_inference_steps=steps, + negative_prompt=[negative_prompt] * batch_size, + generator=None, # torch.Generator + ).images - torch.cuda.synchronize() - inference_end = time.time() - latency = inference_end - inference_start - latency_list.append(latency) - print(f"Inference took {latency:.3f} seconds") - for k, image in enumerate(images): - image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg") + torch.cuda.synchronize() + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"Inference took {latency:.3f} seconds") + for k, image in enumerate(images): + image.save(f"{image_filename_prefix}_{i}_{k}.jpg") return { "engine": "torch", @@ -306,6 +320,7 @@ def get_optimum_ort_pipeline( directory: str, provider="CUDAExecutionProvider", disable_safety_checker: bool = True, + use_io_binding: bool = False, ): from optimum.onnxruntime import ORTStableDiffusionPipeline, ORTStableDiffusionXLPipeline @@ -321,7 +336,7 @@ def get_optimum_ort_pipeline( pipeline = ORTStableDiffusionPipeline.from_pretrained( directory, provider=provider, - use_io_binding=False, # Not supported by Optimum version 1.17.1 at the time of verification. + use_io_binding=use_io_binding, ) elif "xl" in model_name: pipeline = ORTStableDiffusionXLPipeline.from_pretrained( @@ -337,7 +352,7 @@ def get_optimum_ort_pipeline( model_name, export=True, provider=provider, - use_io_binding=False, # Not supported by Optimum version 1.17.1 at the time of verification. + use_io_binding=use_io_binding, ) pipeline.save_pretrained(directory) @@ -359,15 +374,33 @@ def run_optimum_ort_pipeline( batch_count, start_memory, memory_monitor_type, + use_num_images_per_prompt=False, ): from optimum.onnxruntime import ORTStableDiffusionPipeline, ORTStableDiffusionXLPipeline assert isinstance(pipe, (ORTStableDiffusionPipeline, ORTStableDiffusionXLPipeline)) - prompts = example_prompts() + prompts, negative_prompt = example_prompts() def warmup(): - pipe("warm up", height, width, num_inference_steps=steps, num_images_per_prompt=batch_size) + prompt, negative = warmup_prompts() + if use_num_images_per_prompt: + pipe( + prompt=prompt, + height=height, + width=width, + num_inference_steps=steps, + negative_prompt=negative, + num_images_per_prompt=batch_count, + ) + else: + pipe( + prompt=[prompt] * batch_size, + height=height, + width=width, + num_inference_steps=steps, + negative_prompt=[negative] * batch_size, + ) # Run warm up, and measure GPU memory of two runs. # The first run has algo search for cuDNN/MIOpen, so it might need more memory. @@ -380,23 +413,30 @@ def warmup(): for i, prompt in enumerate(prompts): if i >= num_prompts: break - for j in range(batch_count): - inference_start = time.time() + inference_start = time.time() + if use_num_images_per_prompt: images = pipe( - prompt, - height, - width, + prompt=prompt, + height=height, + width=width, num_inference_steps=steps, - negative_prompt=None, - guidance_scale=0.0, # 7.5 + negative_prompt=negative_prompt, num_images_per_prompt=batch_size, ).images - inference_end = time.time() - latency = inference_end - inference_start - latency_list.append(latency) - print(f"Inference took {latency:.3f} seconds") - for k, image in enumerate(images): - image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg") + else: + images = pipe( + prompt=[prompt] * batch_size, + height=height, + width=width, + num_inference_steps=steps, + negative_prompt=[negative_prompt] * batch_size, + ).images + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"Inference took {latency:.3f} seconds") + for k, image in enumerate(images): + image.save(f"{image_filename_prefix}_{i}_{k}.jpg") from onnxruntime import __version__ as ort_version @@ -429,9 +469,12 @@ def run_optimum_ort( batch_count: int, start_memory, memory_monitor_type, + use_io_binding: bool = False, ): load_start = time.time() - pipe = get_optimum_ort_pipeline(model_name, directory, provider, disable_safety_checker) + pipe = get_optimum_ort_pipeline( + model_name, directory, provider, disable_safety_checker, use_io_binding=use_io_binding + ) load_end = time.time() print(f"Model loading took {load_end - load_start} seconds") @@ -530,9 +573,8 @@ def run_ort_trt_static( pipeline.load_resources(height, width, batch_size) def warmup(): - pipeline.run( - ["warm up"] * batch_size, ["negative"] * batch_size, height, width, denoising_steps=steps, warmup=True - ) + prompt, negative = warmup_prompts() + pipeline.run([prompt] * batch_size, [negative] * batch_size, height, width, denoising_steps=steps) # Run warm up, and measure GPU memory of two runs # The first run has algo search so it might need more memory @@ -548,24 +590,23 @@ def warmup(): for i, prompt in enumerate(prompts): if i >= num_prompts: break - for j in range(batch_count): - inference_start = time.time() - # Use warmup mode here since non-warmup mode will save image to disk. - images, pipeline_time = pipeline.run( - [prompt] * batch_size, - [negative_prompt] * batch_size, - height, - width, - denoising_steps=steps, - guidance=7.5, - seed=123, - ) - inference_end = time.time() - latency = inference_end - inference_start - latency_list.append(latency) - print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time}") - for k, image in enumerate(images): - image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg") + inference_start = time.time() + # Use warmup mode here since non-warmup mode will save image to disk. + images, pipeline_time = pipeline.run( + [prompt] * batch_size, + [negative_prompt] * batch_size, + height, + width, + denoising_steps=steps, + guidance=7.5, + seed=123, + ) + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time}") + for k, image in enumerate(images): + image.save(f"{image_filename_prefix}_{i}_{k}.jpg") pipeline.teardown() @@ -671,9 +712,8 @@ def run_tensorrt_static( pipeline.load_resources(height, width, batch_size) def warmup(): - pipeline.run( - ["warm up"] * batch_size, ["negative"] * batch_size, height, width, denoising_steps=steps, warmup=True - ) + prompt, negative = warmup_prompts() + pipeline.run([prompt] * batch_size, [negative] * batch_size, height, width, denoising_steps=steps) # Run warm up, and measure GPU memory of two runs # The first run has algo search so it might need more memory @@ -689,24 +729,22 @@ def warmup(): for i, prompt in enumerate(prompts): if i >= num_prompts: break - for j in range(batch_count): - inference_start = time.time() - # Use warmup mode here since non-warmup mode will save image to disk. - images, pipeline_time = pipeline.run( - [prompt] * batch_size, - [negative_prompt] * batch_size, - height, - width, - denoising_steps=steps, - guidance=7.5, - seed=123, - ) - inference_end = time.time() - latency = inference_end - inference_start - latency_list.append(latency) - print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time}") - for k, image in enumerate(images): - image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg") + inference_start = time.time() + # Use warmup mode here since non-warmup mode will save image to disk. + images, pipeline_time = pipeline.run( + [prompt] * batch_size, + [negative_prompt] * batch_size, + height, + width, + denoising_steps=steps, + seed=123, + ) + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time}") + for k, image in enumerate(images): + image.save(f"{image_filename_prefix}_{i}_{k}.jpg") pipeline.teardown() @@ -828,7 +866,8 @@ def run_sd_xl_inference(prompt, negative_prompt, seed=None): ) def warmup(): - run_sd_xl_inference(["warm up"] * batch_size, ["negative"] * batch_size) + prompt, negative = warmup_prompts() + run_sd_xl_inference([prompt] * batch_size, [negative] * batch_size) # Run warm up, and measure GPU memory of two runs # The first run has algo search so it might need more memory @@ -845,20 +884,15 @@ def warmup(): for i, prompt in enumerate(prompts): if i >= num_prompts: break - for j in range(batch_count): - inference_start = time.time() - # Use warmup mode here since non-warmup mode will save image to disk. - if nvtx_profile: - cudart.cudaProfilerStart() - images, pipeline_time = run_sd_xl_inference([prompt] * batch_size, [negative_prompt] * batch_size, seed=123) - if nvtx_profile: - cudart.cudaProfilerStop() - inference_end = time.time() - latency = inference_end - inference_start - latency_list.append(latency) - print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time}") - for k, image in enumerate(images): - image.save(f"{image_filename_prefix}_{i}_{j}_{k}.png") + inference_start = time.time() + # Use warmup mode here since non-warmup mode will save image to disk. + images, pipeline_time = run_sd_xl_inference([prompt] * batch_size, [negative_prompt] * batch_size, seed=123) + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time}") + for k, image in enumerate(images): + image.save(f"{image_filename_prefix}_{i}_{k}.png") pipeline.teardown() @@ -911,8 +945,6 @@ def run_ort_trt_xl( opt_batch_size=batch_size, ) - from cuda import cudart - assert batch_size <= max_batch_size pipeline.load_resources(height, width, batch_size) @@ -929,7 +961,8 @@ def run_sd_xl_inference(prompt, negative_prompt, seed=None): ) def warmup(): - run_sd_xl_inference(["warm up"] * batch_size, ["negative"] * batch_size) + prompt, negative = warmup_prompts() + run_sd_xl_inference([prompt] * batch_size, [negative] * batch_size) # Run warm up, and measure GPU memory of two runs # The first run has algo search so it might need more memory @@ -946,22 +979,17 @@ def warmup(): for i, prompt in enumerate(prompts): if i >= num_prompts: break - for j in range(batch_count): - inference_start = time.time() - # Use warmup mode here since non-warmup mode will save image to disk. - if nvtx_profile: - cudart.cudaProfilerStart() - images, pipeline_time = run_sd_xl_inference([prompt] * batch_size, [negative_prompt] * batch_size, seed=123) - if nvtx_profile: - cudart.cudaProfilerStop() - inference_end = time.time() - latency = inference_end - inference_start - latency_list.append(latency) - print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time}") - for k, image in enumerate(images): - filename = f"{image_filename_prefix}_{i}_{j}_{k}.png" - image.save(filename) - print("Image saved to", filename) + inference_start = time.time() + # Use warmup mode here since non-warmup mode will save image to disk. + images, pipeline_time = run_sd_xl_inference([prompt] * batch_size, [negative_prompt] * batch_size, seed=123) + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time}") + for k, image in enumerate(images): + filename = f"{image_filename_prefix}_{i}_{k}.png" + image.save(filename) + print("Image saved to", filename) pipeline.teardown() @@ -1137,6 +1165,14 @@ def parse_arguments(): ) parser.set_defaults(use_xformers=False) + parser.add_argument( + "--use_io_binding", + required=False, + action="store_true", + help="Use I/O Binding for Optimum.", + ) + parser.set_defaults(use_io_binding=False) + parser.add_argument( "-b", "--batch_size", @@ -1176,8 +1212,8 @@ def parse_arguments(): "--num_prompts", required=False, type=int, - default=1, - help="Number of prompts. Default is 1.", + default=10, + help="Number of prompts. Default is 10.", ) parser.add_argument( @@ -1312,6 +1348,7 @@ def main(): batch_count=args.batch_count, start_memory=start_memory, memory_monitor_type=memory_monitor_type, + use_io_binding=args.use_io_binding, ) elif args.engine == "onnxruntime": assert args.pipeline and os.path.isdir( diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt index 8ff5990b7815a..5bdd422a11750 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt @@ -10,6 +10,10 @@ packaging protobuf==3.20.3 psutil sympy +nvtx==0.2.5 +torchvision==0.15.2 +tensorrt==8.5.1.7 +mediapipe controlnet_aux==0.0.9 # The following are for SDXL optimum==1.20.0 diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/test/requirements.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/test/requirements.txt index e51ffb395c643..1938f59208ae7 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/test/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/test/requirements.txt @@ -2,3 +2,4 @@ git+https://github.com/openai/CLIP.git open_clip_torch sentence_transformers pillow +numpy==1.22.2 diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt index 979f872ac4c5e..408b5b6c3a728 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt @@ -7,7 +7,7 @@ soundfile librosa optimum<=1.21.2 onnxruntime-extensions>=0.9.0 -onnx==1.16.1 +onnx==1.17.0 protobuf==3.20.2 numpy==1.23.3 psutil diff --git a/onnxruntime/python/tools/transformers/onnx_exporter.py b/onnxruntime/python/tools/transformers/onnx_exporter.py index 3967a7875f3a7..c3ccde50dac85 100644 --- a/onnxruntime/python/tools/transformers/onnx_exporter.py +++ b/onnxruntime/python/tools/transformers/onnx_exporter.py @@ -392,11 +392,13 @@ def validate_and_optimize_onnx( False, output_names, ) - if optimize_info == OptimizerInfo.NOOPT: + if optimize_info.name == OptimizerInfo.NOOPT.name: return onnx_model_path, is_valid_onnx_model, config.vocab_size if ( - optimize_info == OptimizerInfo.BYSCRIPT or precision == Precision.FLOAT16 or precision == Precision.INT8 + optimize_info.name == OptimizerInfo.BYSCRIPT.name + or precision == Precision.FLOAT16 + or precision == Precision.INT8 ): # Use script (optimizer.py) to optimize optimized_model_path = get_onnx_file_path( onnx_dir, @@ -439,7 +441,7 @@ def validate_and_optimize_onnx( QuantizeHelper.quantize_onnx_model(onnx_model_path, onnx_model_path, use_external_data_format) logger.info(f"Finished quantizing model: {onnx_model_path}") - if optimize_info == OptimizerInfo.BYORT: # Use OnnxRuntime to optimize + if optimize_info.name == OptimizerInfo.BYORT.name: # Use OnnxRuntime to optimize if is_valid_onnx_model: ort_model_path = add_filename_suffix(onnx_model_path, "_ort") optimize_onnx_model_by_ort( @@ -492,7 +494,7 @@ def export_onnx_model_from_pt( example_inputs = image_processor(data, return_tensors="pt") else: tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir) - max_input_size = tokenizer.max_model_input_sizes.get(model_name, 1024) + max_input_size = tokenizer.model_max_length example_inputs = tokenizer.encode_plus("This is a sample input", return_tensors="pt") example_inputs = filter_inputs(example_inputs, input_names) @@ -596,7 +598,7 @@ def export_onnx_model_from_tf( # Fix "Using pad_token, but it is not set yet" error. if tokenizer.pad_token is None: tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - max_input_size = tokenizer.max_model_input_sizes.get(model_name, 1024) + max_input_size = tokenizer.model_max_length config, model = load_tf_model(model_name, model_class, cache_dir, config_modifier) model.resize_token_embeddings(len(tokenizer)) diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py b/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py index c781a91c9e493..efcd92129597a 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py @@ -178,18 +178,17 @@ def fuse_attention(self): mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) logger.debug("Create an Attention node.") attention_node = self.attention_fusion.create_attention_node( - mask_index, - matmul_q, - matmul_k, - matmul_v, - add_q, - add_k, - add_v, - self.num_heads, - self.hidden_size, - parent.output[0], - reshape_qkv.output[0], - None, + mask_index=mask_index, + q_matmul=matmul_q, + k_matmul=matmul_k, + v_matmul=matmul_v, + q_add=add_q, + k_add=add_k, + v_add=add_v, + num_heads=self.num_heads, + hidden_size=self.hidden_size, + first_input=parent.output[0], + output=reshape_qkv.output[0], ) if attention_node is None: continue diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py b/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py index b7891223e1dc2..a89b6c9e9395d 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py @@ -480,18 +480,17 @@ def fuse_attention(self): # For tf models, q and v are flipped. attention_node = self.attention_fusion.create_attention_node( - mask_index, - matmul_k, - matmul_q, - matmul_v, - add_k, - add_q, - add_v, - self.num_heads, - self.hidden_size, - parent.output[0], - qkv_nodes[2].output[0], - None, + mask_index=mask_index, + q_matmul=matmul_k, + k_matmul=matmul_q, + v_matmul=matmul_v, + q_add=add_k, + k_add=add_q, + v_add=add_v, + num_heads=self.num_heads, + hidden_size=self.hidden_size, + first_input=parent.output[0], + output=qkv_nodes[2].output[0], ) if attention_node is None: continue diff --git a/onnxruntime/python/tools/transformers/run_benchmark.sh b/onnxruntime/python/tools/transformers/run_benchmark.sh index ddc8b781a5b52..25997f40d348f 100755 --- a/onnxruntime/python/tools/transformers/run_benchmark.sh +++ b/onnxruntime/python/tools/transformers/run_benchmark.sh @@ -5,10 +5,7 @@ # license information. # -------------------------------------------------------------------------- # This measures the performance of OnnxRuntime, PyTorch and TorchScript on transformer models. -# Please install PyTorch (see https://pytorch.org/) before running this benchmark. Like the following: -# GPU: conda install pytorch torchvision cudatoolkit=11.0 -c pytorch -# CPU: conda install pytorch torchvision cpuonly -c pytorch -# To use torch2, please install the nightly PyTorch by replacing pytorch with pytorch-nightly. +# Please install PyTorch (see https://pytorch.org/) before running this benchmark. # When use_package=true, you need not copy other files to run benchmarks except this sh file. # Otherwise, it will use python script (*.py) files in this directory. @@ -60,7 +57,6 @@ sequence_lengths="8 16 32 64 128 256 512 1024" # Here we only test one input (input_ids) for fair comparison with PyTorch. input_counts=1 -# Pretrained transformers models can be a subset of: bert-base-cased roberta-base gpt2 distilgpt2 distilbert-base-uncased models_to_test="bert-base-cased roberta-base distilbert-base-uncased" # If you have multiple GPUs, you can choose one GPU for test. Here is an example to use the second GPU: @@ -99,7 +95,7 @@ if [ "$run_install" = true ] ; then else pip install onnxruntime-gpu fi - pip install --upgrade onnx coloredlogs packaging psutil py3nvml onnxconverter_common numpy transformers sympy + pip install --upgrade onnx coloredlogs packaging psutil py3nvml numpy transformers sympy fi if [ "$use_package" = true ] ; then diff --git a/onnxruntime/test/common/cuda_op_test_utils.h b/onnxruntime/test/common/cuda_op_test_utils.h index 6f3e460628566..d3e069237217e 100644 --- a/onnxruntime/test/common/cuda_op_test_utils.h +++ b/onnxruntime/test/common/cuda_op_test_utils.h @@ -5,6 +5,11 @@ #include "test/util/include/default_providers.h" +#define SKIP_CUDA_TEST_WITH_DML \ + if (DefaultCudaExecutionProvider() == nullptr) { \ + GTEST_SKIP() << "CUDA Tests are not supported while DML is enabled"; \ + } + namespace onnxruntime { namespace test { @@ -13,6 +18,10 @@ namespace test { int GetCudaArchitecture(); inline bool HasCudaEnvironment(int min_cuda_architecture) { + if (DefaultCudaExecutionProvider() == nullptr) { + return false; + } + if (DefaultCudaExecutionProvider().get() == nullptr) { return false; } diff --git a/onnxruntime/test/common/tensor_op_test_utils.h b/onnxruntime/test/common/tensor_op_test_utils.h index e0891c7ced63e..acb520f894569 100644 --- a/onnxruntime/test/common/tensor_op_test_utils.h +++ b/onnxruntime/test/common/tensor_op_test_utils.h @@ -194,6 +194,24 @@ inline void CheckTensor(const Tensor& expected_tensor, const Tensor& output_tens } } +template +std::vector GetTypedArray(std::vector inputs) { + static_assert(std::is_same::value || std::is_same::value || + std::is_same::value || std::is_integral_v, + "Only float, double, MLFloat16, and integral types are supported."); + if constexpr (std::is_same::value) { + return inputs; + } else if constexpr (std::is_integral_v || std::is_same::value) { + std::vector result(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + result[i] = static_cast(inputs[i]); + } + return result; + } else { + return ToFloat16(inputs); + } +} + class ParallelRandomValueGenerator { public: using RandomEngine = std::default_random_engine; diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc index 5f94d30112f0e..8c69e2d9810b8 100644 --- a/onnxruntime/test/contrib_ops/beam_search_test.cc +++ b/onnxruntime/test/contrib_ops/beam_search_test.cc @@ -7,6 +7,8 @@ #include #include "core/session/onnxruntime_cxx_api.h" #include "test/common/cuda_op_test_utils.h" +#include "test/providers/model_tester.h" +#include "test/util/include/current_test_name.h" #ifdef USE_CUDA #include "core/providers/cuda/cuda_provider_options.h" @@ -73,6 +75,9 @@ TEST(BeamSearchTest, GptBeamSearchFp32) { const char* const output_names[] = {"sequences"}; Ort::SessionOptions session_options; +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif #ifdef USE_CUDA OrtCUDAProviderOptionsV2 cuda_options; cuda_options.use_tf32 = false; @@ -166,6 +171,9 @@ TEST(BeamSearchTest, GptBeamSearchFp16) { bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); if (enable_cuda || enable_rocm) { Ort::SessionOptions session_options; +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif #ifdef USE_CUDA OrtCUDAProviderOptionsV2 cuda_options; cuda_options.use_tf32 = false; @@ -388,5 +396,47 @@ TEST(BeamSearchTest, GptBeamSearchFp16_VocabPadded) { } } +TEST(BeamSearchTest, DummyT5) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif + ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5.onnx")); + tester.ConfigEp(DefaultCpuExecutionProvider()); + tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7}); + tester.AddOutput("sequences", {1, 3, 10}, {2, 16, 6, 14, 1, 15, 6, 14, 1, 15, 2, 3, 4, 15, 6, 14, 1, 15, 6, 14, 2, 16, 6, 14, 1, 15, 6, 14, 1, 14}); +#ifdef USE_CUDA + tester.ConfigEp(DefaultCudaExecutionProvider()); +#endif + tester.RunWithConfig(); +} + +TEST(BeamSearchTest, DummyT5WithOuterScopeInitializers) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif + ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_outer_scope_initializers.onnx")); + tester.ConfigEp(DefaultCpuExecutionProvider()); + tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7}); + tester.AddOutput("sequences", {1, 3, 10}, {2, 16, 6, 14, 1, 15, 6, 14, 1, 15, 2, 3, 4, 15, 6, 14, 1, 15, 6, 14, 2, 16, 6, 14, 1, 15, 6, 14, 1, 14}); +#ifdef USE_CUDA + tester.ConfigEp(DefaultCudaExecutionProvider()); +#endif + tester.RunWithConfig(); +} + +TEST(BeamSearchTest, DummyT5WithSequenceInputIds) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif + ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_sequence_input_ids.onnx")); + tester.ConfigEp(DefaultCpuExecutionProvider()); + tester.AddInput("encoder_input_ids", {1, 5}, {16, 17, 1, 0, 8}); + tester.AddOutput("sequences", {1, 3, 10}, {2, 19, 18, 3, 8, 8, 8, 8, 8, 8, 2, 19, 18, 3, 10, 19, 18, 3, 8, 8, 2, 19, 18, 15, 13, 13, 13, 13, 13, 13}); +#ifdef USE_CUDA + tester.ConfigEp(DefaultCudaExecutionProvider()); +#endif + tester.RunWithConfig(); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/bias_dropout_op_test.cc b/onnxruntime/test/contrib_ops/bias_dropout_op_test.cc index 027d4b3fff1b0..297629b015796 100644 --- a/onnxruntime/test/contrib_ops/bias_dropout_op_test.cc +++ b/onnxruntime/test/contrib_ops/bias_dropout_op_test.cc @@ -181,6 +181,9 @@ void RunBiasDropoutTest(const bool use_mask, const std::vector& input_s t.SetCustomOutputVerifier(output_verifier); std::vector> t_eps; #ifdef USE_CUDA + if (DefaultCudaExecutionProvider() == nullptr) { + return; + } t_eps.emplace_back(DefaultCudaExecutionProvider()); #elif USE_ROCM t_eps.emplace_back(DefaultRocmExecutionProvider()); diff --git a/onnxruntime/test/contrib_ops/bitmask_dropout_op_test.cc b/onnxruntime/test/contrib_ops/bitmask_dropout_op_test.cc index 7ca4e1004066c..26b0e3a4dd7a9 100644 --- a/onnxruntime/test/contrib_ops/bitmask_dropout_op_test.cc +++ b/onnxruntime/test/contrib_ops/bitmask_dropout_op_test.cc @@ -61,7 +61,9 @@ void RunTestForInference(const std::vector& input_dims, bool has_ratio std::vector> test_eps; #ifdef USE_CUDA - test_eps.emplace_back(DefaultCudaExecutionProvider()); + if (DefaultCudaExecutionProvider() != nullptr) { + test_eps.emplace_back(DefaultCudaExecutionProvider()); + } #elif USE_ROCM test_eps.emplace_back(DefaultRocmExecutionProvider()); #endif @@ -122,6 +124,9 @@ void RunTestForTraining(const std::vector& input_dims) { std::vector> dropout_eps; #ifdef USE_CUDA + if (DefaultCudaExecutionProvider() == nullptr) { + return; + } dropout_eps.emplace_back(DefaultCudaExecutionProvider()); #elif USE_ROCM dropout_eps.emplace_back(DefaultRocmExecutionProvider()); diff --git a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc index 17c9e8592f64e..208545eacf224 100644 --- a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc @@ -15,23 +15,20 @@ namespace onnxruntime { namespace test { -// This op is currently only supported on CUDA- so test it only for CUDA -#ifdef USE_CUDA - template static std::vector CreateOnes(int size) { std::vector f; f.reserve(size); for (int i = 0; i < size; ++i) { - f.push_back(T(1)); + f.push_back(T(1.0f)); } return f; } template -static std::vector CreateValues(int size, int val) { +static std::vector CreateValues(int size, float val) { std::vector f; f.reserve(size); @@ -72,39 +69,25 @@ static std::vector CreateRandom(int size) { return f; } -// QKV template -static std::vector QKV(std::vector& input, std::vector& weights, std::vector& bias, - int batch_size, int sequence_length, int hidden_size); +float ToFloat(T val); template <> -std::vector QKV(std::vector& input, std::vector& weights, std::vector& bias, - int batch_size, int sequence_length, int hidden_size) { - std::vector qkv; - qkv.resize(batch_size * sequence_length * 3 * hidden_size, 0); - - for (int b = 0; b < batch_size; ++b) { - for (int i = 0; i < sequence_length; ++i) { - for (int j = 0; j < 3 * hidden_size; ++j) { - float sum = 0; - - for (int k = 0; k < hidden_size; ++k) { - sum += input[b * sequence_length * hidden_size + i * hidden_size + k] * weights[k * 3 * hidden_size + j]; - } - - qkv[b * sequence_length * 3 * hidden_size + i * 3 * hidden_size + j] = sum + bias[j]; - } - } - } - - return qkv; +constexpr float ToFloat(float val) { + return val; } template <> -std::vector QKV(std::vector& input, std::vector& weights, std::vector& bias, - int batch_size, int sequence_length, int hidden_size) { - std::vector qkv; - qkv.resize(batch_size * sequence_length * 3 * hidden_size, static_cast(0.f)); +float ToFloat(MLFloat16 val) { + return val.ToFloat(); +} + +// QKV +template +static std::vector QKV(std::vector& input, std::vector& weights, std::vector& bias, + int batch_size, int sequence_length, int hidden_size) { + std::vector qkv; + qkv.resize(batch_size * sequence_length * 3 * hidden_size, static_cast(0.f)); for (int b = 0; b < batch_size; ++b) { for (int i = 0; i < sequence_length; ++i) { @@ -112,10 +95,11 @@ std::vector QKV(std::vector& input, std::vector float sum = 0; for (int k = 0; k < hidden_size; ++k) { - sum += input[b * sequence_length * hidden_size + i * hidden_size + k].ToFloat() * weights[k * 3 * hidden_size + j].ToFloat(); + sum += ToFloat(input[b * sequence_length * hidden_size + i * hidden_size + k]) * + ToFloat(weights[k * 3 * hidden_size + j]); } - qkv[b * sequence_length * 3 * hidden_size + i * 3 * hidden_size + j] = static_cast(sum + bias[j].ToFloat()); + qkv[b * sequence_length * 3 * hidden_size + i * 3 * hidden_size + j] = static_cast(sum + ToFloat(bias[j])); } } } @@ -180,15 +164,17 @@ void CheckEquality(T* data_1, T* data_2, int batch_size, int num_heads, int num_ // Reorder 'K' from [B, N, S, H] to [B, N, H/x, S, x] where x = (sizeof(T) / 16); // Copy 'V' over as is template -static std::vector ReorderKVCache(std::vector& unordered_k_cache, +static std::vector ReorderKVCache(const std::vector& unordered_k_cache, int batch_size, int num_heads, int sequence_length, - int head_size, int max_sequence_length) { + int head_size, int max_sequence_length, bool merge_past_kv = true) { std::vector ordered(unordered_k_cache.size(), T{0.f}); // Copy V over - size_t v_start = unordered_k_cache.size() / 2; - for (size_t i = v_start; i < unordered_k_cache.size(); ++i) { - ordered[i] = unordered_k_cache[i]; + if (merge_past_kv) { + size_t v_start = unordered_k_cache.size() / 2; + for (size_t i = v_start; i < unordered_k_cache.size(); ++i) { + ordered[i] = unordered_k_cache[i]; + } } // Now let us re-order K and copy it over to the final buffer @@ -203,7 +189,8 @@ static std::vector ReorderKVCache(std::vector& unordered_k_cache, (h * max_sequence_length * head_size); int input_base_offset = base_offset + (s * head_size) + (c * num_inner_elements); - int output_base_offset = base_offset + (c * max_sequence_length * num_inner_elements) + (s * num_inner_elements); + int output_base_offset = base_offset + (c * max_sequence_length * num_inner_elements) + + (s * num_inner_elements); for (int e = 0; e < num_inner_elements; ++e) { ordered[output_base_offset + e] = unordered_k_cache[input_base_offset + e]; @@ -224,7 +211,7 @@ static std::vector MergeReorderedKVCacheWithK(std::vector& ordered_k_cache T* k, int batch_size, int num_heads, int past_sequence_length, int max_sequence_length, - int head_size) { + int head_size, bool merge_past_kv = true) { std::vector merged = ordered_k_cache; int total_seq_length = past_sequence_length + 1; @@ -249,10 +236,11 @@ static std::vector MergeReorderedKVCacheWithK(std::vector& ordered_k_cache input_value = ordered_k_cache[input_offset]; } else { int hidden_size = num_heads * head_size; - int input_offset = (b * 3 * hidden_size) + - (n * num_chunks * chunk_size) + - (c * chunk_size) + - h; + int input_offset = merge_past_kv ? ((b * 3 * hidden_size) + + (n * num_chunks * chunk_size) + + (c * chunk_size) + + h) + : ((b * hidden_size) + n * head_size + c * chunk_size + h); input_value = k[input_offset]; } @@ -272,7 +260,7 @@ static std::vector MergeReorderedKVCacheWithK(std::vector& ordered_k_cache return merged; } -// GIven a pointer to the 'V' component of the past cache, we will merge it +// Given a pointer to the 'V' component of the past cache, we will merge it // with current 'V' in-place template static void MergeReorderedKVCacheWithV(T* v_cache, @@ -299,7 +287,8 @@ static void MergeReorderedKVCacheWithV(T* v_cache, template static std::pair, std::vector> MergePastKWithPresentKAndTranspose(T* past_k, T* present_k, int num_batch, int num_heads, - int past_sequence_length, int max_sequence_length, + int past_sequence_length, + int max_sequence_length, int head_size) { int total_seq_length = (past_sequence_length + 1); std::vector merged_k(num_batch * num_heads * total_seq_length * head_size, T{0.f}); @@ -312,16 +301,18 @@ static std::pair, std::vector> MergePastKWithPresentKAndTransp T input_value{0.f}; if (s < past_sequence_length) { - int input_offset = b * num_heads * max_sequence_length * head_size + (n * max_sequence_length * head_size) + (s * head_size) + h; + int input_offset = b * num_heads * max_sequence_length * head_size + + (n * max_sequence_length * head_size) + (s * head_size) + h; input_value = past_k[input_offset]; } else { int hidden_size = num_heads * head_size; - // Offset by 3* hidden_size because QKV data contains Q, K, and V per batch + // Offset by 3 * hidden_size because QKV data contains Q, K, and V per batch int input_offset = (b * 3 * hidden_size) + (n * head_size) + h; input_value = present_k[input_offset]; } - int output_offset = b * num_heads * total_seq_length * head_size + (n * total_seq_length * head_size) + (s * head_size) + h; + int output_offset = b * num_heads * total_seq_length * head_size + + (n * total_seq_length * head_size) + (s * head_size) + h; merged_k[output_offset] = input_value; } @@ -383,15 +374,11 @@ void ValidateReorderedMergedKWithK(T* k, T* k_cache, int batch_size, int num_hea // QK_Transpose template std::vector QK_Transpose(T* q_matrix, T* k_transpose_matrix, - int batch_size, int num_heads, int total_sequence_length, int head_size); - -template <> -std::vector QK_Transpose(float* q_matrix, float* k_transpose_matrix, - int batch_size, int num_heads, int total_sequence_length, int head_size) { + int batch_size, int num_heads, int total_sequence_length, int head_size) { int hidden_size = num_heads * head_size; - std::vector qk_transpose; - qk_transpose.resize(batch_size * num_heads * total_sequence_length, 0); + std::vector qk_transpose; + qk_transpose.resize(batch_size * num_heads * total_sequence_length, static_cast(0.f)); for (int b = 0; b < batch_size; ++b) { for (int n = 0; n < num_heads; ++n) { @@ -409,50 +396,12 @@ std::vector QK_Transpose(float* q_matrix, float* k_transpose_matrix, for (int j = 0; j < total_sequence_length; ++j) { float sum = 0; for (int k = 0; k < head_size; ++k) { - sum += (q_matrix[input_1_base_offset + i * head_size + k] * - k_transpose_matrix[input_2_base_offset + k * total_sequence_length + j]); + sum += (ToFloat(q_matrix[input_1_base_offset + i * head_size + k]) * + ToFloat(k_transpose_matrix[input_2_base_offset + k * total_sequence_length + j])); } float scale = 1 / sqrt(static_cast(head_size)); - qk_transpose[output_base_offset + i * total_sequence_length + j] = scale * sum; - } - } - } - } - - return qk_transpose; -} - -template <> -std::vector QK_Transpose(MLFloat16* q_matrix, MLFloat16* k_transpose_matrix, - int batch_size, int num_heads, int total_sequence_length, int head_size) { - int hidden_size = num_heads * head_size; - - std::vector qk_transpose; - qk_transpose.resize(batch_size * num_heads * total_sequence_length, MLFloat16(0.f)); - - for (int b = 0; b < batch_size; ++b) { - for (int n = 0; n < num_heads; ++n) { - int input_1_base_offset = (b * 3 * hidden_size) + - (n * head_size); - - int input_2_base_offset = (b * num_heads * total_sequence_length * head_size) + - (n * total_sequence_length * head_size); - - int output_base_offset = (b * num_heads * total_sequence_length) + - (n * total_sequence_length); - - // sequence_length == 1 - for (int i = 0; i < 1; ++i) { - for (int j = 0; j < total_sequence_length; ++j) { - float sum = 0; - for (int k = 0; k < head_size; ++k) { - sum += (q_matrix[input_1_base_offset + i * head_size + k].ToFloat() * - k_transpose_matrix[input_2_base_offset + k * total_sequence_length + j].ToFloat()); - } - - float scale = 1 / sqrt(static_cast(head_size)); - qk_transpose[output_base_offset + i * total_sequence_length + j] = MLFloat16(scale * sum); + qk_transpose[output_base_offset + i * total_sequence_length + j] = static_cast(scale * sum); } } } @@ -464,26 +413,23 @@ std::vector QK_Transpose(MLFloat16* q_matrix, MLFloat16* k_transpose_ // Softmax_QK_Transpose template std::vector Softmax_QK_Transpose(T* qk_transpose_matrix, int batch_size, int num_heads, - int sequence_length, int total_sequence_length, int head_size); - -template <> -std::vector Softmax_QK_Transpose(float* qk_transpose_matrix, int batch_size, int num_heads, - int sequence_length, int total_sequence_length, int /*head_size*/) { + int sequence_length, int total_sequence_length) { if (sequence_length != 1) { throw std::runtime_error("Not supported"); } - std::vector softmax_qk_transpose; - softmax_qk_transpose.resize(batch_size * num_heads * sequence_length * total_sequence_length, 0); + std::vector softmax_qk_transpose; + softmax_qk_transpose.resize(static_cast(batch_size) * num_heads * sequence_length * total_sequence_length, + static_cast(0.f)); for (int b = 0; b < batch_size; ++b) { for (int n = 0; n < num_heads; ++n) { int base_offset = (b * num_heads * sequence_length * total_sequence_length) + (n * sequence_length * total_sequence_length); - float max = std::numeric_limits::min(); + float max = std::numeric_limits::lowest(); for (int s = 0; s < total_sequence_length; ++s) { - auto val = qk_transpose_matrix[base_offset + s]; + auto val = ToFloat(qk_transpose_matrix[base_offset + s]); if (val > max) { max = val; } @@ -491,52 +437,13 @@ std::vector Softmax_QK_Transpose(float* qk_transpose_matrix, int batch_si float denom = 0; for (int s = 0; s < total_sequence_length; ++s) { - auto val = qk_transpose_matrix[base_offset + s]; + auto val = ToFloat(qk_transpose_matrix[base_offset + s]); denom += std::exp(val - max); } for (int s = 0; s < total_sequence_length; ++s) { - auto val = qk_transpose_matrix[base_offset + s]; - softmax_qk_transpose[base_offset + s] = std::exp(val - max) / (denom + (float)0.000001); - } - } - } - - return softmax_qk_transpose; -} - -template <> -std::vector Softmax_QK_Transpose(MLFloat16* qk_transpose_matrix, int batch_size, int num_heads, - int sequence_length, int total_sequence_length, int /*head_size*/) { - if (sequence_length != 1) { - throw std::runtime_error("Not supported"); - } - - std::vector softmax_qk_transpose; - softmax_qk_transpose.resize(batch_size * num_heads * sequence_length * total_sequence_length, MLFloat16(0.f)); - - for (int b = 0; b < batch_size; ++b) { - for (int n = 0; n < num_heads; ++n) { - int base_offset = (b * num_heads * sequence_length * total_sequence_length) + - (n * sequence_length * total_sequence_length); - - float max = std::numeric_limits::min(); - for (int s = 0; s < total_sequence_length; ++s) { - auto val = qk_transpose_matrix[base_offset + s].ToFloat(); - if (val > max) { - max = val; - } - } - - float denom = 0; - for (int s = 0; s < total_sequence_length; ++s) { - auto val = qk_transpose_matrix[base_offset + s].ToFloat(); - denom += std::exp(val - max); - } - - for (int s = 0; s < total_sequence_length; ++s) { - auto val = qk_transpose_matrix[base_offset + s].ToFloat(); - softmax_qk_transpose[base_offset + s] = MLFloat16(std::exp(val - max) / (denom + (float)0.000001)); + auto val = ToFloat(qk_transpose_matrix[base_offset + s]); + softmax_qk_transpose[base_offset + s] = static_cast(std::exp(val - max) / (denom + (float)0.000001)); } } } @@ -550,19 +457,13 @@ std::vector Softmax_QK_Transpose_V(T* softmax_qk_transpose_matrix, T* v_matrix, int batch_size, int num_heads, int sequence_length, int total_sequence_length, int max_sequence_length, - int head_size); -template <> -std::vector Softmax_QK_Transpose_V(float* softmax_qk_transpose_matrix, - float* v_matrix, - int batch_size, int num_heads, int sequence_length, - int total_sequence_length, int max_sequence_length, - int head_size) { + int head_size) { if (sequence_length != 1) { throw std::runtime_error("Not supported"); } - std::vector output; - output.resize(batch_size * sequence_length * num_heads * head_size, 0); + std::vector output; + output.resize(batch_size * sequence_length * num_heads * head_size, static_cast(0.f)); for (int b = 0; b < batch_size; ++b) { for (int n = 0; n < num_heads; ++n) { @@ -580,11 +481,11 @@ std::vector Softmax_QK_Transpose_V(float* softmax_qk_transpose_matrix, float sum = 0; for (int k = 0; k < total_sequence_length; ++k) { - sum += (softmax_qk_transpose_matrix[input_1_base_offset + i * total_sequence_length + k] * - v_matrix[input_2_base_offset + k * head_size + j]); + sum += (ToFloat(softmax_qk_transpose_matrix[input_1_base_offset + i * total_sequence_length + k]) * + ToFloat(v_matrix[input_2_base_offset + k * head_size + j])); } - output[output_base_offset + i * head_size + j] = sum; + output[output_base_offset + i * head_size + j] = static_cast(sum); } } } @@ -593,48 +494,11 @@ std::vector Softmax_QK_Transpose_V(float* softmax_qk_transpose_matrix, return output; } -template <> -std::vector Softmax_QK_Transpose_V(MLFloat16* softmax_qk_transpose_matrix, - MLFloat16* v_matrix, - int batch_size, int num_heads, int sequence_length, - int total_sequence_length, int max_sequence_length, - int head_size) { - if (sequence_length != 1) { - throw std::runtime_error("Not supported"); - } - - std::vector output; - output.resize(batch_size * sequence_length * num_heads * head_size, MLFloat16(0.f)); - - for (int b = 0; b < batch_size; ++b) { - for (int n = 0; n < num_heads; ++n) { - int input_1_base_offset = (b * num_heads * sequence_length * total_sequence_length) + - (n * sequence_length * total_sequence_length); - - int input_2_base_offset = (b * num_heads * max_sequence_length * head_size) + - (n * max_sequence_length * head_size); - - int output_base_offset = (b * num_heads * sequence_length * head_size) + - (n * sequence_length * head_size); - - for (int i = 0; i < sequence_length; ++i) { - for (int j = 0; j < head_size; ++j) { - float sum = 0; - - for (int k = 0; k < total_sequence_length; ++k) { - sum += (softmax_qk_transpose_matrix[input_1_base_offset + i * total_sequence_length + k].ToFloat() * - v_matrix[input_2_base_offset + k * head_size + j].ToFloat()); - } - - output[output_base_offset + i * head_size + j] = MLFloat16(sum); - } - } - } - } +// Currently we only support CUDA for DecoderMaskedSelfAttention +#ifdef USE_CUDA - return output; -} -TEST(DecoderMaskedSelfAttentionTest, Test_fp32) { +template +static void TestDecoderMaskedSelfAttention() { // The kernel is only supported on CC 5.3 or higher GPUs if (NeedSkipIfCudaArchLowerThan(530)) { return; @@ -661,19 +525,19 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp32) { }; constexpr int sequence_length = 1; - constexpr int number_of_heads = 12; + constexpr int num_heads = 12; for (MyTestCase test_case : test_cases) { int batch_size = test_case.batch_size; int past_sequence_length = test_case.past_sequence_length; int hidden_size = test_case.hidden_size; - int head_size = (hidden_size / number_of_heads); + int head_size = (hidden_size / num_heads); int total_sequence_length = sequence_length + past_sequence_length; - int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length + int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain); - tester.AddAttribute("num_heads", static_cast(number_of_heads)); + tester.AddAttribute("num_heads", static_cast(num_heads)); tester.AddAttribute("past_present_share_buffer", static_cast(1)); std::vector input_dims = {batch_size, sequence_length, hidden_size}; @@ -681,38 +545,38 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp32) { std::vector bias_dims = {3 * hidden_size}; std::vector output_dims = {batch_size, sequence_length, hidden_size}; - auto input = CreateRandom(batch_size * sequence_length * hidden_size); - tester.AddInput("input", input_dims, input); + auto input = CreateRandom(batch_size * sequence_length * hidden_size); + tester.AddInput("input", input_dims, input); - auto weight = CreateRandom(hidden_size * 3 * hidden_size); - tester.AddInput("weight", weights_dims, weight); + auto weight = CreateRandom(hidden_size * 3 * hidden_size); + tester.AddInput("weight", weights_dims, weight); - auto bias = CreateRandom(3 * hidden_size); - tester.AddInput("bias", bias_dims, bias); + auto bias = CreateRandom(3 * hidden_size); + tester.AddInput("bias", bias_dims, bias); // Mask tester.AddOptionalInputEdge(); // Past - std::vector past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size}; - int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size; + std::vector past_dims = {2, batch_size, num_heads, max_sequence_length, head_size}; + int past_present_size = 2 * batch_size * num_heads * max_sequence_length * head_size; - auto kv_cache = CreateRandom(past_present_size); + auto kv_cache = CreateRandom(past_present_size); - auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, - number_of_heads, past_sequence_length, head_size, max_sequence_length); + auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, + num_heads, past_sequence_length, head_size, max_sequence_length); // Validate if reordering went well - by transposing and checking equality - int chunk_size = 16 / sizeof(float); + int chunk_size = 16 / sizeof(T); int num_chunks = head_size / chunk_size; - auto transposed = Transpose(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size); - CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks, - max_sequence_length, past_sequence_length, chunk_size); + auto transposed = Transpose(kv_cache.data(), batch_size, num_heads, num_chunks, max_sequence_length, chunk_size); + CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, num_heads, num_chunks, + max_sequence_length, past_sequence_length, chunk_size); - tester.AddInput("past", past_dims, reordered_kv_cache); + tester.AddInput("past", past_dims, reordered_kv_cache); // Rel - tester.AddOptionalInputEdge(); + tester.AddOptionalInputEdge(); // Past sequence length std::vector arr_past_sequence_len(1, past_sequence_length); @@ -722,41 +586,44 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp32) { auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size); auto* qkv_matrix = qkv.data(); - auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, - max_sequence_length, head_size); + auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, num_heads, + past_sequence_length, max_sequence_length, head_size); auto k_merged = pair.first; auto k_transpose = pair.second; - auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, number_of_heads, - total_sequence_length, head_size); + auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, num_heads, + total_sequence_length, head_size); - auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, number_of_heads, - sequence_length, total_sequence_length, head_size); + auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, num_heads, + sequence_length, total_sequence_length); - auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); + auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, + num_heads, past_sequence_length, max_sequence_length, head_size); // Validate our test logic // We want to validate if our merged "unordered" K is the same as // the merged "ordered" K so that the QKT we do in our test code // is equivalent to the QKT we do in the kernel - ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size); + ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, num_heads, total_sequence_length, + max_sequence_length, head_size); - MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); + MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, + num_heads, past_sequence_length, max_sequence_length, head_size); - auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), - batch_size, number_of_heads, - sequence_length, total_sequence_length, - max_sequence_length, head_size); + auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), + batch_size, num_heads, sequence_length, total_sequence_length, + max_sequence_length, head_size); // Output(s) - tester.AddOutput("output", input_dims, output); - tester.AddOutput("present", past_dims, present); + tester.AddOutput("output", input_dims, output); + tester.AddOutput("present", past_dims, present); - tester.SetOutputTolerance(0.001f, 0.001f); + if (std::is_same::value) { + tester.SetOutputTolerance(0.005f); + } else { + tester.SetOutputTolerance(0.001f, 0.001f); + } // Run - Regular kernel execution path { @@ -778,150 +645,292 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp32) { } } -TEST(DecoderMaskedSelfAttentionTest, Test_fp16) { - // The kernel is only supported on CC 5.3 or higher GPUs - if (NeedSkipIfCudaArchLowerThan(530)) { - return; - } - - // Buckets for test data: - // batch_size: 1, >=2 - // past_sequence_length 0, 1~30, 31~2046, >=2047 (so that total_sequence_length: 1, 2-31, 32~2047, >=2048) - // head_size: 32, 64, 128 - struct MyTestCase { - int batch_size; - int past_sequence_length; - int hidden_size; - } test_cases[] = { - {1, 0, 768}, - {1, 1, 768}, - {3, 30, 384}, - {8, 31, 1536}, - {4, 256, 384}, - {3, 1024, 768}, - {2, 2046, 1536}, - {1, 2047, 384}, - {2, 3000, 768}, - }; - - constexpr int sequence_length = 1; - constexpr int number_of_heads = 12; - - for (MyTestCase test_case : test_cases) { - int batch_size = test_case.batch_size; - int past_sequence_length = test_case.past_sequence_length; - int hidden_size = test_case.hidden_size; +#endif // USE_CUDA - int head_size = (hidden_size / number_of_heads); - int total_sequence_length = sequence_length + past_sequence_length; - int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length - - OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain); - tester.AddAttribute("num_heads", static_cast(number_of_heads)); - tester.AddAttribute("past_present_share_buffer", static_cast(1)); +template +static std::vector CalculateOutputQK(const std::vector& q, const std::vector& k, + const std::vector& mask_index, const std::vector& attention_bias, + int batch_size, int num_heads, + int sequence_length, int max_sequence_length, int head_size) { + // q (B, 1, NH), k (B, N, L(M), H) -> qk (B, N, 1, L) + // mask_index (B, L), (optional) attention_bias (1, 1, 1, L) + float scale = 1 / sqrt(static_cast(head_size)); + std::vector output_qk; + output_qk.resize(static_cast(batch_size) * num_heads * sequence_length, static_cast(0.f)); + for (int b = 0; b < batch_size; ++b) { + for (int n = 0; n < num_heads; ++n) { + for (int s = 0; s < sequence_length; ++s) { + float mask_value = (mask_index[b * sequence_length + s] == 0) ? -10000.f : 0.f; + float bias_value = (attention_bias.empty()) ? 0.f : ToFloat(attention_bias[s]); + float sum = 0; + for (int h = 0; h < head_size; ++h) { + sum += ToFloat(q[b * num_heads * head_size + n * head_size + h]) * + ToFloat(k[b * num_heads * max_sequence_length * head_size + + n * max_sequence_length * head_size + s * head_size + h]); + } - std::vector input_dims = {batch_size, sequence_length, hidden_size}; - std::vector weights_dims = {hidden_size, 3 * hidden_size}; - std::vector bias_dims = {3 * hidden_size}; - std::vector output_dims = {batch_size, sequence_length, hidden_size}; + output_qk[b * num_heads * sequence_length + n * sequence_length + s] = + static_cast(scale * sum + mask_value + bias_value); + } + } + } - auto input = CreateRandom(batch_size * sequence_length * hidden_size); - tester.AddInput("input", input_dims, input); + return output_qk; +} - auto weight = CreateRandom(hidden_size * 3 * hidden_size); - tester.AddInput("weight", weights_dims, weight); +template +static std::vector CalculateOutput(const std::vector& softmax, const std::vector& v, int batch_size, + int num_heads, int sequence_length, int max_sequence_length, int head_size) { + // softmax (B, N, 1, L) v (B, N, L(M), H) -> output (B, N, 1, H) + std::vector output; + output.resize(static_cast(batch_size) * num_heads * head_size, static_cast(0.f)); + for (int b = 0; b < batch_size; ++b) { + for (int n = 0; n < num_heads; ++n) { + for (int h = 0; h < head_size; ++h) { + float sum = 0; + for (int s = 0; s < sequence_length; ++s) { + sum += ToFloat(softmax[b * num_heads * sequence_length + n * sequence_length + s]) * + ToFloat(v[b * num_heads * max_sequence_length * head_size + + n * max_sequence_length * head_size + s * head_size + h]); + } - auto bias = CreateRandom(3 * hidden_size); - tester.AddInput("bias", bias_dims, bias); + output[b * num_heads * head_size + n * head_size + h] = static_cast(sum); + } + } + } - // Mask - tester.AddOptionalInputEdge(); + return output; +} - // Past - std::vector past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size}; - int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size; +template +static std::vector MergePast(const std::vector& past, const std::vector& current, int batch_size, + int num_heads, int past_seq_len, int max_seq_len, int head_size) { + // past (B, N, S(M), H), current (B, 1, NH) -> merged (B, N, S+1(M), H) + std::vector merged = past; + for (int b = 0; b < batch_size; ++b) { + for (int n = 0; n < num_heads; ++n) { + for (int h = 0; h < head_size; ++h) { + merged[b * num_heads * max_seq_len * head_size + n * max_seq_len * head_size + past_seq_len * head_size + h] = + current[b * num_heads * head_size + n * head_size + h]; + } + } + } - auto kv_cache = CreateRandom(past_present_size); + return merged; +} - auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, - number_of_heads, past_sequence_length, head_size, max_sequence_length); +template +static std::vector ReorderKVByCacheIndirection(const std::vector& key_or_value, + const int32_t* cache_indirection, + int batch_size, int beam_width, int max_sequence_length, + int num_heads, int head_size, int past_sequence_length) { + std::vector reordered = key_or_value; - // Validate if reordering went well - by transposing and checking equality - int chunk_size = 16 / sizeof(MLFloat16); - int num_chunks = head_size / chunk_size; - auto transposed = Transpose(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size); - CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks, - max_sequence_length, past_sequence_length, chunk_size); + for (int b = 0; b < batch_size; ++b) { + int beam_batch_index = b / beam_width; + const int* beam_indices = cache_indirection + b * max_sequence_length; + for (int n = 0; n < num_heads; ++n) { + for (int s = 0; s < past_sequence_length; ++s) { + int beam_offset = beam_indices[s] * num_heads * max_sequence_length * head_size; + int beam_batch_offset = (beam_batch_index * beam_width * num_heads + n) * max_sequence_length * head_size; + for (int h = 0; h < head_size; ++h) { + reordered[b * num_heads * max_sequence_length * head_size + + n * max_sequence_length * head_size + s * head_size + h] = + key_or_value[beam_offset + beam_batch_offset + s * head_size + h]; + } + } + } + } - tester.AddInput("past", past_dims, reordered_kv_cache); + return reordered; +} - // Rel - tester.AddOptionalInputEdge(); +template +static void TestDecoderMaskedMultiHeadAttention(bool is_cross_attn = true, bool use_cuda = true) { + int batch_size = 8; + int past_sequence_length = 2; + int kv_sequence_length = 16; + int head_size = 32; + int num_heads = 12; + int beam_width = 4; + int hidden_size = head_size * num_heads; + + OpTester tester("DecoderMaskedMultiHeadAttention", 1, onnxruntime::kMSDomain); + FixedPatternValueGenerator generator{}; + RandomValueGenerator random{123}; + + // Attributes + tester.AddAttribute("num_heads", static_cast(num_heads)); + tester.AddAttribute("past_present_share_buffer", static_cast(!is_cross_attn)); + // Output scaled Q * K^T by default for cross-attention + tester.AddAttribute("output_qk", static_cast(is_cross_attn)); + + // Inputs and outputs + auto query = CreateRandom(batch_size * 1 * hidden_size); + tester.AddInput("query", {batch_size, 1, hidden_size}, query); + + if (is_cross_attn) { + auto key = CreateRandom(batch_size * num_heads * kv_sequence_length * head_size); + std::vector reordered_key; + if (use_cuda) { + reordered_key = ReorderKVCache(key, batch_size, num_heads, + kv_sequence_length, head_size, kv_sequence_length, false); + } + auto value = CreateRandom(batch_size * num_heads * kv_sequence_length * head_size); + tester.AddInput("key", {batch_size, num_heads, kv_sequence_length, head_size}, (use_cuda ? reordered_key : key)); + tester.AddInput("value", {batch_size, num_heads, kv_sequence_length, head_size}, + CreateRandom(batch_size * num_heads * kv_sequence_length * head_size)); + + const std::vector mask_index_dims = {batch_size, kv_sequence_length}; + auto mask_index = generator.Discrete(mask_index_dims, AsSpan({0, 1})); + tester.AddInput("mask_index", {batch_size, kv_sequence_length}, mask_index); + + // Calculate Softmax(Q * K^T + (Optional) mask) * V + std::vector empty_attention_bias; + auto output_qk = CalculateOutputQK(query, key, mask_index, empty_attention_bias, batch_size, num_heads, + kv_sequence_length, kv_sequence_length, head_size); + std::vector output_qk_float(output_qk.size()); + for (size_t i = 0; i < output_qk.size(); ++i) { + output_qk_float[i] = static_cast(output_qk[i]); + } + auto softmax = Softmax_QK_Transpose(output_qk.data(), batch_size, num_heads, 1, kv_sequence_length); + auto output = CalculateOutput(softmax, value, batch_size, num_heads, + kv_sequence_length, kv_sequence_length, head_size); + + tester.AddOutput("output", {batch_size, 1, hidden_size}, output); + tester.AddOptionalOutputEdge(); // optional present_key + tester.AddOptionalOutputEdge(); // optional present_value + tester.AddOutput("qk", {batch_size, num_heads, 1, kv_sequence_length}, output_qk_float); + } else { + int max_sequence_length = past_sequence_length + 10; + int total_sequence_length = past_sequence_length + 1; + + auto key = CreateRandom(batch_size * hidden_size); + auto value = CreateRandom(batch_size * hidden_size); + tester.AddInput("key", {batch_size, 1, hidden_size}, key); + tester.AddInput("value", {batch_size, 1, hidden_size}, value); + + const std::vector mask_index_dims = {batch_size, total_sequence_length}; + auto mask_index = generator.Discrete(mask_index_dims, AsSpan({0, 1})); + tester.AddInput("mask_index", {batch_size, total_sequence_length}, mask_index); + std::vector attention_bias_dims = {1, 1, 1, total_sequence_length}; + auto attention_bias_float = random.Gaussian(attention_bias_dims, 0.0f, 0.3f); + std::vector attention_bias(attention_bias_float.size()); + for (size_t i = 0; i < attention_bias.size(); ++i) { + attention_bias[i] = static_cast(attention_bias_float[i]); + } + tester.AddInput("attention_bias", {1, 1, 1, total_sequence_length}, attention_bias); - // Past sequence length - std::vector arr_past_sequence_len(1, past_sequence_length); - tester.AddInput("past_sequence_length", {1}, arr_past_sequence_len); + auto past_key = CreateRandom(batch_size * num_heads * max_sequence_length * head_size); + auto past_value = CreateRandom(batch_size * num_heads * max_sequence_length * head_size); - // QKV MatMul - auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size); - auto* qkv_matrix = qkv.data(); + std::vector reordered_past_key; // For CUDA, we need to reorder past key + if (use_cuda) { + reordered_past_key = ReorderKVCache(past_key, batch_size, num_heads, + past_sequence_length, head_size, max_sequence_length, false); + } - auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, - max_sequence_length, head_size); + tester.AddInput("past_key", {batch_size, num_heads, max_sequence_length, head_size}, + (use_cuda ? reordered_past_key : past_key)); + tester.AddInput("past_value", {batch_size, num_heads, max_sequence_length, head_size}, past_value); + + // merge past key and value with current key and value + auto merged_key = MergePast(past_key, key, batch_size, num_heads, + past_sequence_length, max_sequence_length, head_size); + std::vector merged_reordered_key; + if (use_cuda) { + merged_reordered_key = MergeReorderedKVCacheWithK(reordered_past_key, key.data(), batch_size, num_heads, + past_sequence_length, max_sequence_length, head_size, false); + } + auto merged_value = MergePast(past_value, value, batch_size, num_heads, + past_sequence_length, max_sequence_length, head_size); + + tester.AddInput("past_sequence_length", {1}, {past_sequence_length}); + + std::vector mod_merged_key, mod_merged_value; + if (beam_width > 1) { + tester.AddInput("beam_width", {1}, {beam_width}); + + const std::vector cache_indir_dims = {batch_size, beam_width, max_sequence_length}; + auto value_candidates = ValueRange(beam_width); + auto cache_indir = generator.Discrete(cache_indir_dims, value_candidates); + tester.AddInput("cache_indirection", cache_indir_dims, cache_indir); + + // Modify merged_key and merged_value according to cache_indirection + mod_merged_key = ReorderKVByCacheIndirection(merged_key, cache_indir.data(), + batch_size, beam_width, max_sequence_length, + num_heads, head_size, past_sequence_length); + mod_merged_value = ReorderKVByCacheIndirection(merged_value, cache_indir.data(), + batch_size, beam_width, max_sequence_length, + num_heads, head_size, past_sequence_length); + } - auto k_merged = pair.first; - auto k_transpose = pair.second; + // Calculate Softmax(Q * K^T + (Optional) mask) * V + auto output_qk = CalculateOutputQK(query, (beam_width > 1 ? mod_merged_key : merged_key), + mask_index, attention_bias, + batch_size, num_heads, total_sequence_length, max_sequence_length, head_size); + auto softmax = Softmax_QK_Transpose(output_qk.data(), batch_size, num_heads, 1, total_sequence_length); + auto output = CalculateOutput(softmax, (beam_width > 1 ? mod_merged_value : merged_value), + batch_size, num_heads, total_sequence_length, max_sequence_length, head_size); + + tester.AddOutput("output", {batch_size, 1, hidden_size}, output); + tester.AddOutput("present_key", {batch_size, num_heads, max_sequence_length, head_size}, + (use_cuda ? merged_reordered_key : merged_key)); + tester.AddOutput("present_value", {batch_size, num_heads, max_sequence_length, head_size}, merged_value); + } - auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, number_of_heads, - total_sequence_length, head_size); + if (std::is_same::value) { + tester.SetOutputTolerance(0.02f); + } else { + tester.SetOutputTolerance(0.0001f, 0.0001f); + } - auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, number_of_heads, - sequence_length, total_sequence_length, head_size); + { + std::vector> execution_providers; + if (use_cuda) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } else { + execution_providers.push_back(DefaultCpuExecutionProvider()); + } + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} - auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); +#ifdef USE_CUDA - // Validate our test logic - // We want to validate if our merged "unordered" K is the same as - // the merged "ordered" K so that the QKT we do in our test code - // is equivalent to the QKT we do in the kernel - ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size); +TEST(DecoderMaskedSelfAttentionTest, Test_fp32) { + TestDecoderMaskedSelfAttention(); +} - MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); +TEST(DecoderMaskedSelfAttentionTest, Test_fp16) { + TestDecoderMaskedSelfAttention(); +} - auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), - batch_size, number_of_heads, - sequence_length, total_sequence_length, - max_sequence_length, head_size); +TEST(DecoderMaskedMultiHeadAttentionTest, cuda_cross_attn_fp32) { + TestDecoderMaskedMultiHeadAttention(); +} - // Output(s) - tester.AddOutput("output", input_dims, output); - tester.AddOutput("present", past_dims, present); +TEST(DecoderMaskedMultiHeadAttentionTest, cuda_cross_attn_fp16) { + TestDecoderMaskedMultiHeadAttention(); +} - tester.SetOutputTolerance(0.005f); +TEST(DecoderMaskedMultiHeadAttentionTest, cuda_self_attn_fp32) { + TestDecoderMaskedMultiHeadAttention(/* is_cross_attn = */ false); +} - // Run - Regular kernel execution path - { - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } +TEST(DecoderMaskedMultiHeadAttentionTest, cuda_self_attn_fp16) { + TestDecoderMaskedMultiHeadAttention(/* is_cross_attn = */ false); +} - // Test alternate kernel path of loading more KV data "in flight" - { - ScopedEnvironmentVariables scoped_env_vars{ - EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}}; +#endif - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } - } +TEST(DecoderMaskedMultiHeadAttentionTest, cpu_cross_attn_fp32) { + TestDecoderMaskedMultiHeadAttention(/* is_cross_attn = */ true, /* use_cuda = */ false); } -#endif +TEST(DecoderMaskedMultiHeadAttentionTest, cpu_self_attn_fp32) { + TestDecoderMaskedMultiHeadAttention(/* is_cross_attn = */ false, /* use_cuda = */ false); +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/fastgelu_op_test.cc b/onnxruntime/test/contrib_ops/fastgelu_op_test.cc index 5cf749dc4c97c..a7d751f4472fc 100644 --- a/onnxruntime/test/contrib_ops/fastgelu_op_test.cc +++ b/onnxruntime/test/contrib_ops/fastgelu_op_test.cc @@ -41,7 +41,7 @@ const std::vector GetExpectedResult(const std::vector& input_data, return ComputeGelu(add_bias_data); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) static void RunFastGeluGpuTest(const std::vector& input_data, const std::vector& bias_data, const std::vector& output_data, const std::vector& input_dims, const std::vector& bias_dims, const std::vector& output_dims, @@ -75,6 +75,8 @@ static void RunFastGeluGpuTest(const std::vector& input_data, const std:: execution_providers.push_back(DefaultCudaExecutionProvider()); #elif USE_ROCM execution_providers.push_back(DefaultRocmExecutionProvider()); +#elif USE_WEBGPU + execution_providers.push_back(DefaultWebGpuExecutionProvider()); #endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } @@ -142,7 +144,7 @@ static void RunFastGeluTest( std::vector input_dims = {batch_size, sequence_length, hidden_size}; std::vector bias_dims = {hidden_size}; std::vector output_dims = input_dims; -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, has_bias); #endif RunFastGeluCpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, has_bias); @@ -245,8 +247,8 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat32) { RunFastGeluTest(input_data, bias_data, batch_size, sequence_length, hidden_size); } -// CUDA and ROCm only for Float16 and BFloat16 type. -#if defined(USE_CUDA) || defined(USE_ROCM) +// CUDA, ROCm and WebGPU only for Float16 type. +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) TEST(FastGeluTest, FastGeluWithBiasFloat16_2) { int batch_size = 1; int sequence_length = 2; @@ -381,7 +383,10 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat16_8) { RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, false, true); } +#endif +// CUDA and ROCm only for BFloat16 type. +#if defined(USE_CUDA) || defined(USE_ROCM) TEST(FastGeluTest, FastGeluWithBias_BFloat16) { #ifdef USE_CUDA int min_cuda_architecture = 530; diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc index 9ecaa16a2ab24..52e67bf0616d1 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc @@ -120,7 +120,7 @@ TEST(LayerNormTest, LayerNorm_Scale_Float16Input) { // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, - kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); + kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); } TEST(LayerNormTest, LayerNorm_Scale_Float16ScaleOutput) { @@ -134,7 +134,7 @@ TEST(LayerNormTest, LayerNorm_Scale_Float16ScaleOutput) { // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, - kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); + kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); } TEST(LayerNormTest, LayerNorm_Scale_Float16InputScaleOutput) { @@ -162,7 +162,7 @@ TEST(LayerNormTest, LayerNorm_Scale_Float16InputScaleOutput_Initializers) { // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, - kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); + kNnapiExecutionProvider, kQnnExecutionProvider}); } TEST(LayerNormTest, LayerNorm_Scale_Bias) { @@ -192,7 +192,7 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16Input) { // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider, - kOpenVINOExecutionProvider, kNnapiExecutionProvider, kCoreMLExecutionProvider}); + kOpenVINOExecutionProvider, kNnapiExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); } TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16ScaleBiasOutput) { @@ -207,24 +207,35 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16ScaleBiasOutput) { // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, - kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); + kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); } TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput) { - OpTester test("LayerNormalization"); - test.AddAttribute("epsilon", 1e-05f); - - std::vector dims{1, 3, 2}; - test.AddInput("x", dims, ToFloat16({1.2416f, 0.946123f, 13.1685f, 0.36423f, 21.145f, 0.03941f})); - test.AddInput("gamma", {2}, ToFloat16({-0.6953f, 5.1824f})); - test.AddInput("bias", {2}, ToFloat16({0.6435f, -0.3964f})); - test.AddOutput("output", dims, ToFloat16({-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f})); - // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes - test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, - kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); + auto run_test = [](bool is_initializer) { + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{1, 3, 2}; + test.AddInput("x", dims, ToFloat16({1.2416f, 0.946123f, 13.1685f, 0.36423f, 21.145f, 0.03941f})); + test.AddInput("gamma", {2}, ToFloat16({-0.6953f, 5.1824f}), is_initializer); + test.AddInput("bias", {2}, ToFloat16({0.6435f, -0.3964f}), is_initializer); + test.AddOutput("output", dims, ToFloat16({-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f})); + // TRT, DNNL, OpenVINO and NNAPI don't support this combination of datatypes + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, + kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); + }; + run_test(false); + run_test(true); } +template +class LayerNormTest : public ::testing::Test { +}; + +using LayerNormTestTypes = ::testing::Types; +TYPED_TEST_SUITE(LayerNormTest, LayerNormTestTypes); + TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput_Initializers) { OpTester test("LayerNormalization"); test.AddAttribute("epsilon", 1e-05f); @@ -237,19 +248,41 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput_Initializer // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, - kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); + kNnapiExecutionProvider, kQnnExecutionProvider}); } // LayerNormalization became an ONNX operator in opset 17. It uses the same implementation so this is a sanity check. -TEST(LayerNormTest, LayerNorm17_float) { - OpTester test("LayerNormalization", 17); - test.AddAttribute("epsilon", 1e-05f); +TYPED_TEST(LayerNormTest, LayerNorm17_opset) { + auto run_test = [](bool is_initializer) { + OpTester test("LayerNormalization", 17); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{1, 2, 3}; + test.AddInput("x", dims, GetTypedArray({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f})); + test.AddInput("gamma", {3}, GetTypedArray({1.0f, 1.0f, 1.0f}), is_initializer); + test.AddOutput("output", dims, GetTypedArray({-1.2247f, 0.0f, 1.2247f, -1.2247f, 0.0f, 1.2247f})); + if (std::is_same::value) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); + // coreml EP requires weight and bias to be initializers + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, + kNnapiExecutionProvider, kQnnExecutionProvider}, + nullptr, &execution_providers); + } else { + test.Run(); + } + }; + // Execution provider entry invalid. + // when other EPs support layer-norm fp16, this test should be updated to include them. + if (std::is_same::value) { +#if !defined(COREML_ENABLE_MLPROGRAM) + return; +#endif + } - std::vector dims{1, 2, 3}; - test.AddInput("x", dims, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); - test.AddInput("gamma", {3}, {1.0f, 1.0f, 1.0f}); - test.AddOutput("output", dims, {-1.2247f, 0.0f, 1.2247f, -1.2247f, 0.0f, 1.2247f}); - test.Run(); + run_test(false); + run_test(true); } TEST(LayerNormTest, LayerNorm17_double) { diff --git a/onnxruntime/test/contrib_ops/layer_norm_test.cc b/onnxruntime/test/contrib_ops/layer_norm_test.cc index 438a1100ca95c..b414a98c4e756 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_test.cc @@ -2,11 +2,12 @@ // Licensed under the MIT License. #include "test/providers/compare_provider_test_utils.h" +#include "test/util/include/default_providers.h" namespace onnxruntime { namespace test { -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) || defined(USE_WEBGPU) constexpr auto k_epsilon_default = 1e-5f; constexpr auto k_random_data_min = -10.0f; constexpr auto k_random_data_max = 10.0f; @@ -65,8 +66,8 @@ static void TestLayerNorm(const std::vector& x_dims, std::vector Y_data = FillZeros(n_x_m_dims); test.AddOutput("output", n_x_m_dims, Y_data); -#ifndef USE_DML - // DML doesn't support more than one output for these ops yet +#if !defined(USE_DML) && !defined(USE_WEBGPU) + // DML and WebGPU don't support more than one output for these ops yet const std::vector& stats_dims = keep_dims ? n_and_ones_dims : n_dims; std::vector mean_data = FillZeros(stats_dims); std::vector var_data = FillZeros(stats_dims); @@ -79,11 +80,19 @@ static void TestLayerNorm(const std::vector& x_dims, #endif #ifdef USE_CUDA - test.CompareWithCPU(kCudaExecutionProvider); + if (DefaultCudaExecutionProvider() != nullptr) { + test.CompareWithCPU(kCudaExecutionProvider); + } #elif USE_ROCM test.CompareWithCPU(kRocmExecutionProvider); -#elif USE_DML - test.CompareWithCPU(kDmlExecutionProvider); +#elif USE_WEBGPU + test.CompareWithCPU(kWebGpuExecutionProvider); +#endif + +#ifdef USE_DML + if (DefaultDmlExecutionProvider() != nullptr) { + test.CompareWithCPU(kDmlExecutionProvider); + } #endif } diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 8138829b057f2..6dedce24e7e07 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -82,6 +82,7 @@ struct TestOptions { bool has_bias{false}; std::optional output_abs_error{}; + std::optional output_rel_error{}; }; std::ostream& operator<<(std::ostream& os, const TestOptions& opts) { @@ -253,6 +254,10 @@ void RunTest(const TestOptions& opts, test.SetOutputAbsErr("Y", *opts.output_abs_error); } + if (opts.output_rel_error.has_value()) { + test.SetOutputRelErr("Y", *opts.output_rel_error); + } + if (!explicit_eps.empty()) { test.ConfigEps(std::move(explicit_eps)); } @@ -271,10 +276,10 @@ void TestMatMulNBitsTyped() { if (base_opts.accuracy_level == 4) { base_opts.output_abs_error = 0.1f; - } else { - if constexpr (std::is_same::value) { - base_opts.output_abs_error = 0.01f; - } + base_opts.output_rel_error = 0.02f; + } else if constexpr (std::is_same::value) { + base_opts.output_abs_error = 0.055f; + base_opts.output_rel_error = 0.02f; } { @@ -288,7 +293,7 @@ void TestMatMulNBitsTyped() { RunTest(opts); } -#if !defined(USE_DML) +#if !defined(USE_DML) && !defined(USE_WEBGPU) { TestOptions opts = base_opts; opts.has_g_idx = true; @@ -319,7 +324,7 @@ void TestMatMulNBitsTyped() { opts.has_zero_point = true, opts.zp_is_4bit = false; RunTest(opts); } -#endif // !defined(USE_DML) +#endif // !defined(USE_DML) && !defined(USE_WEBGPU) } TEST(MatMulNBits, Float32_Accuracy0) { @@ -387,48 +392,48 @@ TEST(MatMulNBits, Float32_Accuracy4) { TestMatMulNBitsTyped(); } -#ifdef MLAS_TARGET_AMD64_IX86 +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_ARM64) #if !defined(USE_DML) // Actual and expected difference is over 0.01 with DmlExecutionProvider. // Skip the tests instead of raising the tolerance to make is pass. +TEST(MatMulNBits, Float16_Accuracy2) { + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); +} + TEST(MatMulNBits, Float16_Accuracy0) { TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); } -TEST(MatMulNBits, Float16_Accuracy1) { - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); -} - TEST(MatMulNBits, Float16_Accuracy4) { TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); @@ -458,7 +463,7 @@ TEST(MatMulNBits, Float16_Accuracy4) { #endif #endif -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) || defined(USE_WEBGPU) namespace { // Legacy test function. @@ -485,13 +490,20 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura std::vector> execution_providers; if (use_float16) { #ifdef USE_CUDA - execution_providers.push_back(DefaultCudaExecutionProvider()); + if (DefaultCudaExecutionProvider() != nullptr) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } #endif #ifdef USE_ROCM execution_providers.push_back(DefaultRocmExecutionProvider()); #endif #ifdef USE_DML - execution_providers.push_back(DefaultDmlExecutionProvider()); + if (DefaultDmlExecutionProvider() != nullptr) { + execution_providers.push_back(DefaultDmlExecutionProvider()); + } +#endif +#ifdef USE_WEBGPU + execution_providers.push_back(DefaultWebGpuExecutionProvider()); #endif RunTest(opts, std::move(execution_providers)); @@ -506,8 +518,11 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura } // namespace TEST(MatMulNBits, Float16Cuda) { -#if defined(USE_CUDA) || defined(USE_ROCM) - auto has_gidx_options = {true, false}; +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) + std::vector has_gidx_options = {true, false}; + if (DefaultDmlExecutionProvider() != nullptr) { + has_gidx_options.assign(1, false); + } #else auto has_gidx_options = {false}; #endif @@ -518,7 +533,9 @@ TEST(MatMulNBits, Float16Cuda) { for (auto block_size : {16, 32, 64, 128}) { for (auto has_gidx : has_gidx_options) { #ifdef USE_DML - RunTest(M, N, K, block_size, 0, false, true, has_gidx, true, 0.04f); + if (DefaultDmlExecutionProvider() != nullptr) { + RunTest(M, N, K, block_size, 0, false, true, has_gidx, true, 0.04f); + } #else RunTest(M, N, K, block_size, 0, false, true, has_gidx); RunTest(M, N, K, block_size, 0, true, true, has_gidx, false); @@ -531,12 +548,19 @@ TEST(MatMulNBits, Float16Cuda) { } TEST(MatMulNBits, Float16Large) { -#ifdef USE_DML +#if defined(USE_CUDA) || defined(USE_DML) // For some reason, the A10 machine that runs these tests during CI has a much bigger error than all retail // machines we tested on. All consumer-grade machines from Nvidia/AMD/Intel seem to pass these tests with an // absolute error of 0.08, but the A10 has errors going as high as 0.22. Ultimately, given the large number // of elements in this test, ULPs should probably be used instead of absolute/relative tolerances. - float abs_error = 0.3f; + float abs_error = 0.05f; + if (DefaultDmlExecutionProvider() != nullptr) { + // it means the ep is dml in runtime, the abs_error is changed to 0.3f + abs_error = 0.3f; + } +#elif USE_WEBGPU + // See Intel A770 to pass these tests with an absolute error of 0.08. + float abs_error = 0.08f; #else float abs_error = 0.05f; #endif @@ -549,7 +573,6 @@ TEST(MatMulNBits, Float16Large) { } } } - #endif // defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc index 8d7629b5fda1c..d88c3131a4ca5 100644 --- a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc @@ -227,7 +227,7 @@ TEST(MatMulIntegerToFloat, HasZeroPoint_HasBias_test_U8S8) { } // DML EP supports Float16 output type and Signed A Matrix and Unsigned B Matric for Float32 output -#if defined(USE_DML) +#if defined(USE_DML) && !defined(USE_CUDA) TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_S8U8) { RunMatMulIntegerToFloatTest(); diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc index 1d167b5dffdb5..6b6799d73fb56 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc @@ -49,6 +49,7 @@ static void RunMultiHeadAttentionTest( bool use_float16 = false, bool disable_cpu = false, // some cases not supported in cpu right now. bool disable_cuda = false, + bool disable_webgpu = false, bool disable_rocm = DISABLE_ROCM, // not supported in rocm right now. bool disable_dml = false) { kv_sequence_length = (kv_sequence_length == 0 ? sequence_length : kv_sequence_length); @@ -59,6 +60,7 @@ static void RunMultiHeadAttentionTest( bool enable_rocm = (nullptr != DefaultRocmExecutionProvider(/*test_tunable_op=*/true).get()) && !disable_rocm; bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()) && !use_float16 && !disable_cpu; bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()) && !disable_dml; + bool enable_webgpu = (nullptr != DefaultWebGpuExecutionProvider().get()) && !disable_webgpu; if (enable_rocm && !use_float16) { LOGS_DEFAULT(WARNING) << "ROCm MHA only have kernel for half datatype implemented, skip float datatype tests"; @@ -70,7 +72,7 @@ static void RunMultiHeadAttentionTest( enable_rocm = false; } - if (enable_cpu || enable_cuda || enable_rocm || enable_dml) { + if (enable_cpu || enable_cuda || enable_rocm || enable_dml || enable_webgpu) { OpTester tester("MultiHeadAttention", 1, onnxruntime::kMSDomain); tester.AddAttribute("num_heads", static_cast(num_heads)); tester.AddAttribute("mask_filter_value", static_cast(-10000.0f)); @@ -266,6 +268,12 @@ static void RunMultiHeadAttentionTest( execution_providers.push_back(DefaultDmlExecutionProvider()); tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } + + if (enable_webgpu) { + std::vector> execution_providers; + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } } } @@ -295,6 +303,7 @@ static void RunMultiHeadAttentionKernel( bool is_static_kv = true, bool disable_cpu = false, // some cases not supported in cpu right now. bool disable_cuda = false, + bool disable_webgpu = false, bool disable_rocm = DISABLE_ROCM, bool disable_dml = false) { if (kernel_type == AttentionKernelType::AttentionKernel_Default) { @@ -309,7 +318,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, + disable_rocm, disable_dml); return; } @@ -325,7 +335,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, + disable_rocm, disable_dml); return; } @@ -341,7 +352,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, + disable_rocm, disable_dml); return; } @@ -358,7 +370,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, + disable_rocm, disable_dml); return; } #endif @@ -376,7 +389,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, + disable_rocm, disable_dml); } if (kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { @@ -392,11 +406,30 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu, + disable_rocm, disable_dml); } } -static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu = false, bool disable_cuda = false) { +enum RunMultiHeadAttentionTestToggles : uint32_t { + DISABLE_NONE = 0, + DISABLE_CPU = 1 << 0, + DISABLE_CUDA = 1 << 1, + DISABLE_WEBGPU = 1 << 2, +}; +inline RunMultiHeadAttentionTestToggles operator|(RunMultiHeadAttentionTestToggles a, RunMultiHeadAttentionTestToggles b) { + return static_cast(static_cast(a) | static_cast(b)); +} +inline RunMultiHeadAttentionTestToggles operator&(RunMultiHeadAttentionTestToggles a, RunMultiHeadAttentionTestToggles b) { + return static_cast(static_cast(a) & static_cast(b)); +} + +static void RunMultiHeadAttentionTests(AttentionTestData& data, + RunMultiHeadAttentionTestToggles toggles = DISABLE_NONE) { + bool disable_cpu = toggles & DISABLE_CPU; + bool disable_cuda = toggles & DISABLE_CUDA; + bool disable_webgpu = toggles & DISABLE_WEBGPU; + if (data.fp32_output_data.size() > 0) { constexpr bool use_float16 = false; @@ -407,7 +440,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } #if USE_MEMORY_EFFICIENT_ATTENTION @@ -420,7 +453,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } } #endif @@ -431,7 +464,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } if (data.fp16_output_data.size() > 0) { @@ -443,7 +476,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } kernel_type = AttentionKernelType::AttentionKernel_TrtFusedAttention; @@ -453,7 +486,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } #if USE_MEMORY_EFFICIENT_ATTENTION @@ -464,7 +497,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } #endif @@ -475,7 +508,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } kernel_type = AttentionKernelType::AttentionKernel_Default; @@ -484,7 +517,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } } @@ -503,40 +536,40 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_M ROCM_GTEST_SKIP("ROCm MHA does not support mask type of MASK_1D_KEY_SEQ_LEN"); AttentionTestData data; GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, true); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(data, true); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); } TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask2D) { AttentionTestData data; GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, false); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(data, false); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); } TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize32_LeftSidePadding_Mask2D) { AttentionTestData data; GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(data); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias(data); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU); } TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_NoBias_NoMask_PackedKV) { AttentionTestData data; GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, DISABLE_WEBGPU); } TEST(MultiHeadAttentionTest, SelfAttention_Batch2_HeadSize32_NoBias_NoMask_PackedQKV) { AttentionTestData data; GetSelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, DISABLE_WEBGPU); } // This tests qk_head_size != v_head_size @@ -561,7 +594,7 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize16) { TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize8) { AttentionTestData data; GetCrossAttentionData_HeadSize8_NoBias(data); - RunMultiHeadAttentionTests(data, false, true); + RunMultiHeadAttentionTests(data, DISABLE_CUDA); } // TODO (pavignol): Fix this regression @@ -571,7 +604,7 @@ TEST(MultiHeadAttentionTest, CrossAttentionWithPast) { ROCM_GTEST_SKIP("ROCm MHA only support head_size >= 8"); AttentionTestData data; GetCrossAttentionDataWithPast(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, DISABLE_WEBGPU); } #endif @@ -579,27 +612,27 @@ TEST(MultiHeadAttentionTest, SelfAttention_WithPast_WithAttnBias_ForT5) { ROCM_GTEST_SKIP("ROCm MHA only support head_size >= 8"); AttentionTestData data; GetSelfAttentionData_WithPast_WithAttnBias_ForT5(data); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, DISABLE_CPU); } TEST(MultiHeadAttentionTest, AttentionCutlassAttnBias) { // ROCM_GTEST_SKIP("ROCm does not support cutlass"); AttentionTestData data; GetAttentionDataCutlassAttnBias(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, DISABLE_WEBGPU); } TEST(MultiHeadAttentionTest, CrossAttention_DiffSequenceLengths) { // Whisper decoder cross attention without mask and different sequence lengths for Q and K/V AttentionTestData data; GetCrossAttentionData_DiffSequenceLengths(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, DISABLE_WEBGPU); GetCrossAttentionData_DiffSequenceLengths_HeadSize8(data); - RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); + RunMultiHeadAttentionTests(data, DISABLE_CUDA | DISABLE_WEBGPU); GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(data); - RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); + RunMultiHeadAttentionTests(data, DISABLE_CUDA | DISABLE_WEBGPU); } TEST(MultiHeadAttentionTest, SelfAttention_WithPastAndPresent_NoMask_NoAttnBias) { @@ -609,10 +642,10 @@ TEST(MultiHeadAttentionTest, SelfAttention_WithPastAndPresent_NoMask_NoAttnBias) RunMultiHeadAttentionTests(data); GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias(data); - RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); + RunMultiHeadAttentionTests(data, DISABLE_CUDA); GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias(data); - RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); + RunMultiHeadAttentionTests(data, DISABLE_CUDA); } // This test is disabled since it is not used in Whisper anymore, and it fails in ROCm. diff --git a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc index 8675a997d29a1..0e964cf64fbbd 100644 --- a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc +++ b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc @@ -67,6 +67,7 @@ static void RunTest( : 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()) && !disable_dml; + bool enable_webgpu = nullptr != DefaultWebGpuExecutionProvider().get(); if (enable_cuda && !disable_cuda) { execution_providers.push_back(DefaultCudaExecutionProvider()); @@ -74,9 +75,12 @@ static void RunTest( if (enable_dml && !disable_dml) { execution_providers.push_back(DefaultDmlExecutionProvider()); } - if (tensor_type == TensorType::kFloat && !disable_cpu) { + if ((tensor_type == TensorType::kFloat || tensor_type == TensorType::kFloat16) && !disable_cpu) { execution_providers.push_back(DefaultCpuExecutionProvider()); } + if (enable_webgpu) { + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + } if (execution_providers.size() == 0) { // Return early if CI pipeline does not support EP (e.g. CUDA EP for CPU CI pipeline) return; @@ -136,26 +140,7 @@ static void RunTests(const std::vector& input_data, int64_t interleaved = 0, int64_t is_packed_batching = 0, bool use_float16 = true) { - // FP32 test for CPU - RunTest(input_data, - position_ids, - cos_cache, - sin_cache, - output_data, - batch_size, - sequence_length, - head_size, - rotary_embedding_dim, - num_heads, - max_sequence_length, - interleaved, - is_packed_batching, - TensorType::kFloat, - false, /* disable_cpu */ - true, /* disable_cuda */ - true /* disable_dml */); - - // FP32 test for CUDA and DML + // FP32 test for CPU, CUDA and DML RunTest(input_data, position_ids, cos_cache, @@ -174,7 +159,7 @@ static void RunTests(const std::vector& input_data, false, /* disable_cuda */ false /* disable_dml */); - // FP16 test for CUDA and DML + // FP16 test for CPU, CUDA and DML if (use_float16) { RunTest(input_data, position_ids, @@ -190,26 +175,9 @@ static void RunTests(const std::vector& input_data, interleaved, is_packed_batching, TensorType::kFloat16, - true, /* disable_cpu */ + false, /* disable_cpu */ false, /* disable_cuda*/ false /* disable_dml */); - - // RunTest(input_data, - // position_ids, - // cos_cache, - // sin_cache, - // output_data, - // batch_size, - // sequence_length, - // head_size, - // rotary_embedding_dim, - // num_heads, - // max_sequence_length, - // interleaved, - // TensorType::kBFloat16, - // true, /* disable_cpu */ - // false, /* disable_cuda*/ - // false /* disable_dml */); } } diff --git a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc index edf9064bb43c9..4e8d1b9f016f0 100644 --- a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc +++ b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc @@ -62,6 +62,8 @@ static void RunOneTest( auto rocm_ep = DefaultRocmExecutionProvider(); auto dml_ep = DefaultDmlExecutionProvider(); auto cpu_ep = DefaultCpuExecutionProvider(); + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + std::vector> execution_providers; if (!use_float16) { OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain); @@ -95,10 +97,14 @@ static void RunOneTest( if (cpu_ep != nullptr) { execution_providers.push_back(DefaultCpuExecutionProvider()); } + if (webgpu_ep != nullptr) { + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + } test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } else if (HasCudaEnvironment(530 /*min_cuda_architecture*/) || dml_ep != nullptr || - rocm_ep != nullptr) { + rocm_ep != nullptr || + webgpu_ep != nullptr) { OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain); test.AddInput("input", input_dims, ToFloat16(input_data)); test.AddInput("skip", skip_dims, ToFloat16(skip_data)); @@ -132,7 +138,9 @@ static void RunOneTest( ToFloat16(sum_output_data)); } - if (dml_ep != nullptr) { + if (webgpu_ep != nullptr) { + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + } else if (dml_ep != nullptr) { execution_providers.push_back(DefaultDmlExecutionProvider()); } else if (rocm_ep != nullptr) { execution_providers.push_back(DefaultRocmExecutionProvider()); @@ -186,6 +194,32 @@ static void RunTest( } } +TEST(SkipLayerNormTest, SkipLayerNormPrePack) { + OpTester test("SkipLayerNormalization", 1, onnxruntime::kMSDomain); + test.AddAttribute("epsilon", 1e-05f); + + int batch_size = 1; + int sequence_length = 2; + int hidden_size = 2; + std::vector input_skip_output_dims = {batch_size, sequence_length, hidden_size}; + std::vector gamma_beta_bias_dims = {hidden_size}; + test.AddInput("x", input_skip_output_dims, ToFloat16({1.f, 1.f, 1.f, 1.f})); + test.AddInput("skip", input_skip_output_dims, ToFloat16({1.f, 1.f, 1.f, 1.f})); + test.AddInput("gamma", gamma_beta_bias_dims, ToFloat16({1.f, 1.f}), true); + test.AddInput("beta", gamma_beta_bias_dims, ToFloat16({1.f, 1.f}), true); + test.AddOutput("output", input_skip_output_dims, ToFloat16({ + 1.f, + 1.f, + 1.f, + 1.f, + })); + + // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, + kNnapiExecutionProvider, kQnnExecutionProvider}); +} + TEST(SkipLayerNormTest, SkipLayerNormNullInput) { int batch_size = 1; int sequence_length = 0; diff --git a/onnxruntime/test/contrib_ops/tensor_op_test.cc b/onnxruntime/test/contrib_ops/tensor_op_test.cc index bc2ff5f4f724d..d5e2ddebfe67f 100644 --- a/onnxruntime/test/contrib_ops/tensor_op_test.cc +++ b/onnxruntime/test/contrib_ops/tensor_op_test.cc @@ -121,7 +121,15 @@ void MeanVarianceNormalizationAcrossChannels(bool across_channels, bool normaliz test.AddAttribute("normalize_variance", normalize_variance ? one : zero); test.AddInput("input", {N, C, H, W}, X); test.AddOutput("output", {N, C, H, W}, result); +#if defined(USE_CUDA) && defined(USE_DML) + if (DefaultCudaExecutionProvider() == nullptr) { + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kCudaExecutionProvider, kTensorrtExecutionProvider}); + } else if (DefaultDmlExecutionProvider() == nullptr) { + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kDmlExecutionProvider, kTensorrtExecutionProvider}); + } +#else test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kTensorrtExecutionProvider}); // OpenVINO doesn't support MVN operator below opset 9. TensorRT doesn't support opset 8 of MVN operator. +#endif } void MeanVarianceNormalizationPerChannel(bool across_channels, bool normalize_variance) { @@ -188,7 +196,15 @@ void MeanVarianceNormalizationPerChannel(bool across_channels, bool normalize_va test.AddAttribute("normalize_variance", normalize_variance ? one : zero); test.AddInput("input", {N, C, H, W}, X); test.AddOutput("output", {N, C, H, W}, result); +#if defined(USE_CUDA) && defined(USE_DML) + if (DefaultCudaExecutionProvider() == nullptr) { + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kCudaExecutionProvider, kTensorrtExecutionProvider}); + } else if (DefaultDmlExecutionProvider() == nullptr) { + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kDmlExecutionProvider, kTensorrtExecutionProvider}); + } +#else test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kTensorrtExecutionProvider}); // OpenVINO doesn't support MVN operator below opset 9. TensorRT doesn't support opset 8 of MVN operator. +#endif } TEST(MVNContribOpTest, MeanVarianceNormalizationCPUTest_Version1_TO_8) { @@ -230,7 +246,9 @@ TEST(UnfoldTensorOpTest, LastDim) { std::vector> execution_providers; #ifdef USE_CUDA - execution_providers.push_back(DefaultCudaExecutionProvider()); + if (DefaultCudaExecutionProvider() != nullptr) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } #endif execution_providers.push_back(DefaultCpuExecutionProvider()); tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index 0105e90b5a24a..adab93908cdc4 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -28,6 +28,7 @@ using json = nlohmann::json; #ifdef USE_CUDA #include "core/providers/cuda/cuda_execution_provider.h" #include "core/providers/cuda/cuda_provider_factory.h" +#include "test/common/cuda_op_test_utils.h" #endif // USE_CUDA #include "core/session/onnxruntime_session_options_config_keys.h" using namespace ONNX_NAMESPACE; @@ -251,6 +252,7 @@ class PlannerTest : public ::testing::Test { void BindKernel(onnxruntime::Node* p_node, ::onnxruntime::KernelDef& kernel_def, KernelRegistry* reg, std::unordered_map>& kernel_create_info_map) { + const auto& logger = DefaultLoggingManager().DefaultLogger(); const IExecutionProvider* ep = execution_providers_.Get(*p_node); ASSERT_NE(ep, nullptr); auto info = std::make_unique( @@ -260,7 +262,7 @@ class PlannerTest : public ::testing::Test { op_kernel_infos_.push_back(std::move(info)); const auto kernel_type_str_resolver = OpSchemaKernelTypeStrResolver{}; if (!KernelRegistry::HasImplementationOf(*reg, *p_node, onnxruntime::kCpuExecutionProvider, - kernel_type_str_resolver)) { + kernel_type_str_resolver, logger)) { ASSERT_STATUS_OK(reg->Register( KernelCreateInfo(std::make_unique(kernel_def), [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { @@ -270,7 +272,7 @@ class PlannerTest : public ::testing::Test { } const KernelCreateInfo* kci; - ASSERT_STATUS_OK(reg->TryFindKernel(*p_node, "", kernel_type_str_resolver, &kci)); + ASSERT_STATUS_OK(reg->TryFindKernel(*p_node, "", kernel_type_str_resolver, logger, &kci)); kernel_create_info_map.insert({p_node->Index(), gsl::not_null(kci)}); } @@ -282,7 +284,8 @@ class PlannerTest : public ::testing::Test { } } - void CreatePlan(const std::vector& outer_scope_node_args = {}, bool invoke_createPlan_explicityly = true) { + void CreatePlan(const std::vector& outer_scope_node_args = {}, + bool invoke_createPlan_explicityly = true) { state_.reset(new SessionState(graph_, execution_providers_, tp_.get(), nullptr, dtm_, edlm_, DefaultLoggingManager().DefaultLogger(), profiler_, *sess_options_)); EXPECT_EQ(graph_.Resolve(), Status::OK()); @@ -894,6 +897,9 @@ TEST_F(PlannerTest, LocationPlanningForPassThroughExplicitAndImplicitSubgraphInp SessionOptions so; InferenceSession sess{so, GetEnvironment()}; + if (DefaultCudaExecutionProvider() == nullptr) { + return; + } auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider()); ASSERT_TRUE(status.IsOK()); @@ -1036,6 +1042,9 @@ TEST_F(PlannerTest, LocationPlanningForInitializersOnlyUsedInANestedSubgraph) { SessionOptions so; InferenceSession sess{so, GetEnvironment()}; + if (DefaultCudaExecutionProvider() == nullptr) { + return; + } auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider()); ASSERT_TRUE(status.IsOK()); @@ -1143,6 +1152,9 @@ TEST_F(PlannerTest, LocationPlanningForInitializersUsedOnDifferentDevicesInMainG SessionOptions so; InferenceSession sess{so, GetEnvironment()}; + if (DefaultCudaExecutionProvider() == nullptr) { + return; + } auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider()); ASSERT_TRUE(status.IsOK()); @@ -1235,6 +1247,9 @@ TEST_F(PlannerTest, LocationPlanningForImplicitInputsWithoutExplicitConsumersInM SessionOptions so; InferenceSession sess{so, GetEnvironment()}; + if (DefaultCudaExecutionProvider() == nullptr) { + return; + } auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider()); ASSERT_TRUE(status.IsOK()); @@ -1267,6 +1282,10 @@ TEST_F(PlannerTest, LocationPlanningForImplicitInputsWithoutExplicitConsumersInM // Test MultiStream scenario for the graph: // node1(CPU ep)->node2(CPU ep)->node3(CUDA ep)->node4(CPU ep) TEST_F(PlannerTest, MultiStream) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif + ONNX_NAMESPACE::TensorProto tensor; tensor.add_dims(1); tensor.add_float_data(1.0f); @@ -1285,6 +1304,7 @@ TEST_F(PlannerTest, MultiStream) { onnxruntime::ProviderInfo_CUDA& ep = onnxruntime::GetProviderInfo_CUDA(); auto epFactory = ep.CreateExecutionProviderFactory(epi); std::unique_ptr execution_provider = epFactory->CreateProvider(); + ORT_THROW_IF_ERROR(GetExecutionProviders().Add("CUDAExecutionProvider", std::move(execution_provider))); CreatePlan({}, false); @@ -1312,6 +1332,9 @@ TEST_F(PlannerTest, MultiStream) { // node3 // All 3 nodes are CUDA EP, node1 is in stream0, node2 is in stream1, node3 is in stream2 TEST_F(PlannerTest, MultiStream1StreamWaitFor2Streams) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif std::unique_ptr<::onnxruntime::KernelDef> cudaKernel = KernelDefBuilder().SetName("Transpose").Provider(kCudaExecutionProvider).SinceVersion(1, 10).Build(); std::unique_ptr<::onnxruntime::KernelDef> cudaKernelAdd = KernelDefBuilder().SetName("Add").Provider(kCudaExecutionProvider).SinceVersion(1, 10).Build(); std::string Graph_input("Graph_input"), Arg1("Arg1"), Arg2("Arg2"), Arg3("Arg3"), node1("node1"), node2("node2"), node3("node3"); @@ -1353,6 +1376,9 @@ TEST_F(PlannerTest, MultiStream1StreamWaitFor2Streams) { // stream 1: node2 (CPU EP) // node1's output, which is consumed by both node2 and node3, is in CPU. TEST_F(PlannerTest, MultiStreamCudaEPNodeCPUOutput) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif MemcpyToHostInCuda_TransposeInCudaAndCpu("./testdata/multi_stream_models/memcpyToHost_same_stream_with_transpose.json"); EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan.size(), 2) << "2 logic streams"; EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan[0]->steps_.size(), 5) << "stream 0 has 5 steps"; @@ -1374,6 +1400,11 @@ TEST_F(PlannerTest, MultiStreamCudaEPNodeCPUOutput) { // TODO(leca): there is a bug in the corresponding graph that node2 will be visited twice when traversing node1's output nodes // (see: for (auto it = node->OutputNodesBegin(); it != node->OutputNodesEnd(); ++it) in BuildExecutionPlan()). We can just break the loop and don't need the extra variables once it is fixed TEST_F(PlannerTest, MultiStreamMultiOutput) { +#if defined(USE_CUDA) && defined(USE_DML) + if (DefaultCudaExecutionProvider() == nullptr) { + return; + } +#endif std::unique_ptr<::onnxruntime::KernelDef> cudaKernel = KernelDefBuilder().SetName("RNN").Provider(kCudaExecutionProvider).SinceVersion(7).Build(); std::string Graph_input1("Graph_input1"), Graph_input2("Graph_input2"), Graph_input3("Graph_input3"), Arg1("Arg1"), Arg2("Arg2"), Arg3("Arg3"), node1("node1"), node2("node2"); std::vector input1{Arg(Graph_input1), Arg(Graph_input2), Arg(Graph_input3)}, output1{Arg(Arg1), Arg(Arg2)}, input2{Arg(Arg1), Arg(Arg2)}, output2{Arg(Arg3)}; @@ -1411,6 +1442,9 @@ TEST_F(PlannerTest, MultiStreamMultiOutput) { // TODO(leca): the ideal case is there is only 1 wait step before launching node3, // as there is a specific order between node1 and node2 if they are in the same stream, thus node3 will only need to wait the latter one TEST_F(PlannerTest, MultiStream2NodesSameStreamConsumedBy1NodeInDifferentStream) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif std::unique_ptr<::onnxruntime::KernelDef> cudaKernel = KernelDefBuilder().SetName("Transpose").Provider(kCudaExecutionProvider).SinceVersion(1, 10).Build(); std::string Graph_input1("Graph_input1"), Graph_input2("Graph_input2"), Graph_input3("Graph_input3"), Arg1("Arg1"), Arg2("Arg2"), Arg3("Arg3"), node1("node1"), node2("node2"), node3("node3"); std::vector input1{Arg(Graph_input1)}, input2{Arg(Graph_input2)}, output1{Arg(Arg1)}, output2{Arg(Arg2)}, input3{Arg(Arg1), Arg(Arg2)}, output3{Arg(Arg3)}; @@ -1448,6 +1482,9 @@ TEST_F(PlannerTest, MultiStream2NodesSameStreamConsumedBy1NodeInDifferentStream) #if !defined(__wasm__) && defined(ORT_ENABLE_STREAM) TEST_F(PlannerTest, ParaPlanCreation) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif TypeProto graph_in_type; graph_in_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); auto* graph_in_shape = graph_in_type.mutable_tensor_type()->mutable_shape(); @@ -1889,6 +1926,10 @@ TEST_F(PlannerTest, ParaPlanCreation) { } TEST_F(PlannerTest, TestMultiStreamConfig) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif + const char* type = "DeviceBasedPartitioner"; constexpr size_t type_len = 22; @@ -1962,6 +2003,10 @@ TEST_F(PlannerTest, TestMultiStreamSaveConfig) { // Load with partition config where a node is missing, session load expected to fail. TEST_F(PlannerTest, TestMultiStreamMissingNodeConfig) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif + const char* config_file_path = "./testdata/multi_stream_models/conv_add_relu_single_stream_missing_node.json"; SessionOptions sess_opt; sess_opt.graph_optimization_level = TransformerLevel::Default; @@ -1982,6 +2027,9 @@ TEST_F(PlannerTest, TestMultiStreamMissingNodeConfig) { // Load with partition config where streams and devices has mismatch TEST_F(PlannerTest, TestMultiStreamMismatchDevice) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif const char* config_file_path = "./testdata/multi_stream_models/conv_add_relu_single_stream_mismatch_device.json"; SessionOptions sess_opt; sess_opt.graph_optimization_level = TransformerLevel::Default; @@ -2007,6 +2055,9 @@ TEST_F(PlannerTest, TestCpuIf) { sess_opt.graph_optimization_level = TransformerLevel::Default; InferenceSession sess(sess_opt, GetEnvironment(), ORT_TSTR("./testdata/multi_stream_models/cpu_if.onnx")); + if (DefaultCudaExecutionProvider() == nullptr) { + return; + } ASSERT_STATUS_OK(sess.RegisterExecutionProvider(DefaultCudaExecutionProvider())); ASSERT_STATUS_OK(sess.Load()); ASSERT_STATUS_OK(sess.Initialize()); @@ -2067,10 +2118,17 @@ TEST_F(PlannerTest, TestCpuIf) { // onnx.save(model, 'issue_19480.onnx') // TEST(AllocationPlannerTest, ReusedInputCrossDifferentStreams) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif + SessionOptions sess_opt; sess_opt.graph_optimization_level = TransformerLevel::Default; InferenceSession sess(sess_opt, GetEnvironment(), ORT_TSTR("./testdata/multi_stream_models/issue_19480.onnx")); + if (DefaultCudaExecutionProvider() == nullptr) { + return; + } auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider()); status = sess.Load(); status = sess.Initialize(); diff --git a/onnxruntime/test/framework/allocator_test.cc b/onnxruntime/test/framework/allocator_test.cc index 8961058628490..fa6c4966d6953 100644 --- a/onnxruntime/test/framework/allocator_test.cc +++ b/onnxruntime/test/framework/allocator_test.cc @@ -3,6 +3,7 @@ #include #include "core/framework/allocator.h" +#include "core/framework/allocator_utils.h" #include "test_utils.h" #include "gtest/gtest.h" @@ -15,12 +16,10 @@ TEST(AllocatorTest, CPUAllocatorTest) { ASSERT_STREQ(cpu_arena->Info().name, CPU); EXPECT_EQ(cpu_arena->Info().id, 0); - // arena is disabled for CPUExecutionProvider on x86 and JEMalloc -#if (defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64)) && !defined(USE_JEMALLOC) && !defined(USE_MIMALLOC) && !defined(ABSL_HAVE_ADDRESS_SANITIZER) - EXPECT_EQ(cpu_arena->Info().alloc_type, OrtAllocatorType::OrtArenaAllocator); -#else - EXPECT_EQ(cpu_arena->Info().alloc_type, OrtAllocatorType::OrtDeviceAllocator); -#endif + const auto expected_allocator_type = DoesCpuAllocatorSupportArenaUsage() + ? OrtAllocatorType::OrtArenaAllocator + : OrtAllocatorType::OrtDeviceAllocator; + EXPECT_EQ(cpu_arena->Info().alloc_type, expected_allocator_type); size_t size = 1024; auto bytes = cpu_arena->Alloc(size); diff --git a/onnxruntime/test/framework/cuda/fence_cuda_test.cc b/onnxruntime/test/framework/cuda/fence_cuda_test.cc index e28327941dda4..3e5ef30e7ebef 100644 --- a/onnxruntime/test/framework/cuda/fence_cuda_test.cc +++ b/onnxruntime/test/framework/cuda/fence_cuda_test.cc @@ -115,6 +115,9 @@ TEST(CUDAFenceTests, DISABLED_PartOnCPU) { SessionOptions so; FenceCudaTestInferenceSession session(so, GetEnvironment()); ASSERT_STATUS_OK(LoadInferenceSessionFromModel(session, *model)); + if (DefaultCudaExecutionProvider() == nullptr) { + return; + } ASSERT_STATUS_OK(session.RegisterExecutionProvider(DefaultCudaExecutionProvider())); ASSERT_TRUE(session.Initialize().IsOK()); ASSERT_TRUE(1 == CountCopyNodes(graph)); @@ -164,6 +167,9 @@ TEST(CUDAFenceTests, TileWithInitializer) { SessionOptions so; FenceCudaTestInferenceSession session(so, GetEnvironment()); ASSERT_STATUS_OK(LoadInferenceSessionFromModel(session, *model)); + if (DefaultCudaExecutionProvider() == nullptr) { + return; + } ASSERT_STATUS_OK(session.RegisterExecutionProvider(DefaultCudaExecutionProvider())); ASSERT_STATUS_OK(session.Initialize()); @@ -224,6 +230,9 @@ TEST(CUDAFenceTests, TileWithComputedInput) { SessionOptions so; FenceCudaTestInferenceSession session(so, GetEnvironment()); ASSERT_STATUS_OK(LoadInferenceSessionFromModel(session, *model)); + if (DefaultCudaExecutionProvider() == nullptr) { + return; + } ASSERT_STATUS_OK(session.RegisterExecutionProvider(DefaultCudaExecutionProvider())); ASSERT_TRUE(session.Initialize().IsOK()); diff --git a/onnxruntime/test/framework/function_test.cc b/onnxruntime/test/framework/function_test.cc index fa3545ef27d72..180a75a64c10e 100644 --- a/onnxruntime/test/framework/function_test.cc +++ b/onnxruntime/test/framework/function_test.cc @@ -580,13 +580,7 @@ TEST(FunctionTest, TestInlinedLocalFunctionNotRemoved) { // myfun is not removed because it was claimed by InternalTestingEP model_proto = session_object.GetModel().ToProto(); -#ifdef USE_TVM - // TVM EP takes the whole graph and optimizes it within its own framework. - // It does not retain the original graph. - ASSERT_EQ(0, model_proto.functions_size()); -#else ASSERT_EQ(1, model_proto.functions_size()); -#endif } TEST(FunctionTest, TestInlinedFunctionDoesNotReserrectNonExistingArgs) { diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index da5fa2c3a5a24..7f4616c964e33 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -34,6 +34,7 @@ #ifdef USE_CUDA #include "core/providers/cuda/cuda_provider_factory.h" #include "core/providers/cuda/gpu_data_transfer.h" +#include "test/common/cuda_op_test_utils.h" #endif #ifdef USE_TENSORRT #include "core/providers/tensorrt/tensorrt_provider_options.h" @@ -45,7 +46,6 @@ #include "core/session/environment.h" #include "core/session/IOBinding.h" #include "core/session/inference_session_utils.h" -#include "core/session/onnxruntime_cxx_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_run_options_config_keys.h" #include "dummy_provider.h" @@ -65,8 +65,6 @@ using namespace ONNX_NAMESPACE; using namespace onnxruntime::logging; using namespace onnxruntime::concurrency; -extern std::unique_ptr ort_env; - namespace { struct KernelRegistryAndStatus { std::shared_ptr kernel_registry = std::make_shared(); @@ -499,57 +497,6 @@ TEST(InferenceSessionTests, TestModelSerialization) { ASSERT_TRUE(session_object_emptyValidation.Initialize().IsOK()); } -// Test feature serialize prepack weight is only used in PC with CPU on inference, -// disable this test for training, other device and eps -#if !ENABLE_TRAINING && !defined(USE_CUDA) && !defined(__wasm__) && !defined(USE_DNNL) && !defined(USE_QNN) && !defined(__ANDROID__) && !defined(USE_COREML) -// MLAS dispatcher used in matmul_nbits kernels here is 64 bit only -#if defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64) -TEST(InferenceSessionTests, TestPrePackSerialization) { - SessionOptions so; - std::string model_name = "model_with_matmul_nbits"; - - const std::string test_model = "testdata/prepack/" + model_name + ".onnx"; - const std::string optimized_model = "testdata/prepack/" + model_name + "_opt.onnx"; - - so.session_logid = "InferenceSessionTests.TestPrepackSerialization"; - so.enable_cpu_mem_arena = false; - so.graph_optimization_level = TransformerLevel::Default; - so.optimized_model_filepath = optimized_model; - std::string external_initializer_file_name = model_name + "_opt.onnx.data"; - - // enable serialize prepack initializer to data file - ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsSavePrePackedConstantInitializers, - "1")); - // always save external initializer to data file for test - ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes, - "0")); - ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsOptimizedModelExternalInitializersFileName, - external_initializer_file_name.c_str())); - - // optimize model with serialize prepack constant initializers - InferenceSessionWrapper session_object{so, GetEnvironment()}; - ASSERT_TRUE(session_object.Load(test_model).IsOK()); - ASSERT_TRUE(session_object.Initialize().IsOK()); - - // Verify prepack initializers are serialized into optimized model and data file - // load optimized model and check initializer are prepacked - auto logger = DefaultLoggingManager().CreateLogger("TestPrepackSerialization"); - std::shared_ptr model; - auto load_status = Model::Load(ToWideString(optimized_model), model, nullptr, *logger); - ASSERT_EQ(Status::OK(), load_status); - Graph& graph = model->MainGraph(); - - bool found_prepack_initializer = false; - for (const auto& item : graph.GetAllInitializedTensors()) { - if (item.first.find(':') != std::string::npos) { - found_prepack_initializer = true; - } - } - ASSERT_TRUE(found_prepack_initializer); -} -#endif -#endif - #ifdef ORT_RUN_EXTERNAL_ONNX_TESTS static bool Compare(const InputDefList& f_arg, const InputDefList& s_arg) { if (f_arg.size() != s_arg.size()) { @@ -689,6 +636,9 @@ TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions) { InferenceSession session_object(so, GetEnvironment()); #ifdef USE_CUDA + if (DefaultCudaExecutionProvider() == nullptr) { + return; + } ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultCudaExecutionProvider())); #endif #ifdef USE_ROCM @@ -743,10 +693,16 @@ TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions2) { InferenceSession session_object(so, GetEnvironment()); #ifdef USE_CUDA + if (DefaultCudaExecutionProvider() == nullptr) { + return; + } ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultCudaExecutionProvider())); #endif #ifdef USE_ROCM ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultRocmExecutionProvider())); +#endif +#ifdef USE_WEBGPU + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultWebGpuExecutionProvider())); #endif ASSERT_STATUS_OK(session_object.Load(MODEL_URI)); ASSERT_STATUS_OK(session_object.Initialize()); @@ -773,7 +729,7 @@ TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions2) { ASSERT_TRUE(lines[size - 1].find("]") != string::npos); std::vector tags = {"pid", "dur", "ts", "ph", "X", "name", "args"}; - bool has_api_info = false; + [[maybe_unused]] bool has_api_info = false; for (size_t i = 1; i < size - 1; ++i) { for (auto& s : tags) { ASSERT_TRUE(lines[i].find(s) != string::npos); @@ -784,14 +740,16 @@ TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions2) { #ifdef USE_ROCM has_api_info = has_api_info || lines[i].find("Api") != string::npos && lines[i].find("hipLaunch") != string::npos; +#endif +#ifdef USE_WEBGPU + has_api_info = has_api_info || lines[i].find("Api") != string::npos; #endif } } -#if defined(USE_ROCM) && defined(ENABLE_ROCM_PROFILING) +// Note that the apple device is a paravirtual device which may not support webgpu timestamp query. So skip the check on it. +#if (defined(USE_ROCM) && defined(ENABLE_ROCM_PROFILING)) || (defined(USE_WEBGPU) && !defined(__APPLE__)) ASSERT_TRUE(has_api_info); -#else - ASSERT_TRUE(has_api_info || true); #endif } @@ -860,6 +818,47 @@ TEST(InferenceSessionTests, CheckRunProfilerStartTime) { ASSERT_TRUE(before_start_time <= profiling_start_time && profiling_start_time <= after_start_time); } +TEST(InferenceSessionTests, CheckRunProfilerWithOptionalValues) { + // Test whether the profiler can work on model with optional values + SessionOptions so; + + so.session_logid = "CheckRunProfiler"; + so.enable_profiling = true; + so.profile_file_prefix = ORT_TSTR("onnxprofile_profile_test"); + + InferenceSession session_object(so, GetEnvironment()); + ASSERT_STATUS_OK(session_object.Load(ORT_TSTR("testdata/relu_with_optional.onnx"))); + ASSERT_STATUS_OK(session_object.Initialize()); + + RunOptions run_options; + run_options.run_tag = "RunTag"; + + // prepare inputs + std::vector dims_x = {1}; + std::vector values_x = {-4}; + OrtValue ml_value; + CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], dims_x, values_x, &ml_value); + NameMLValMap feeds; + feeds.insert(std::make_pair("input", ml_value)); + + // prepare outputs + std::vector output_names; + output_names.push_back("output"); + std::vector fetches; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {1}; + std::vector expected_values_y = {0}; + + // Now run + common::Status st = session_object.Run(run_options, feeds, output_names, &fetches); + if (!st.IsOK()) { + std::cout << "Run returned status: " << st.ErrorMessage() << std::endl; + } + ASSERT_TRUE(st.IsOK()); + VerifyOutputs(fetches.at(0).Get(), expected_dims_y, expected_values_y); +} + TEST(InferenceSessionTests, MultipleSessionsNoTimeout) { SessionOptions session_options; @@ -1050,6 +1049,9 @@ static void TestBindHelper(const std::string& log_str, if (bind_provider_type == kCudaExecutionProvider || bind_provider_type == kRocmExecutionProvider) { #ifdef USE_CUDA auto provider = DefaultCudaExecutionProvider(); + if (provider == nullptr) { + return; + } gpu_provider = provider.get(); ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(std::move(provider))); #endif @@ -1645,6 +1647,9 @@ TEST(InferenceSessionTests, Test3LayerNestedSubgraph) { #if USE_TENSORRT ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultTensorrtExecutionProvider())); #elif USE_CUDA + if (DefaultCudaExecutionProvider() == nullptr) { + return; + } ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultCudaExecutionProvider())); #elif USE_ROCM ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultRocmExecutionProvider())); @@ -1797,6 +1802,9 @@ TEST(InferenceSessionTests, Test2LayerNestedSubgraph) { #if USE_TENSORRT ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultTensorrtExecutionProvider())); #elif USE_CUDA + if (DefaultCudaExecutionProvider() == nullptr) { + return; + } ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultCudaExecutionProvider())); #elif USE_ROCM ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultRocmExecutionProvider())); @@ -2152,6 +2160,9 @@ TEST(InferenceSessionTests, TestStrictShapeInference) { #ifdef USE_CUDA // disable it, since we are going to enable parallel execution with cuda ep TEST(InferenceSessionTests, DISABLED_TestParallelExecutionWithCudaProvider) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif string model_uri = "testdata/transform/fusion/fuse-conv-bn-mul-add-unsqueeze.onnx"; SessionOptions so; @@ -2175,6 +2186,10 @@ TEST(InferenceSessionTests, DISABLED_TestParallelExecutionWithCudaProvider) { } TEST(InferenceSessionTests, TestArenaShrinkageAfterRun) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif + OrtArenaCfg arena_cfg; arena_cfg.arena_extend_strategy = 1; // kSameAsRequested diff --git a/onnxruntime/test/framework/memcpy_transformer_test.cc b/onnxruntime/test/framework/memcpy_transformer_test.cc index 6e86e5b58aead..2313f00e4d123 100644 --- a/onnxruntime/test/framework/memcpy_transformer_test.cc +++ b/onnxruntime/test/framework/memcpy_transformer_test.cc @@ -9,6 +9,9 @@ #include "default_providers.h" #include "gtest/gtest.h" #include "test_utils.h" +#ifdef USE_CUDA +#include "test/common/cuda_op_test_utils.h" +#endif #include "test/test_environment.h" #include "asserts.h" @@ -74,6 +77,9 @@ void ExpectCopy(const onnxruntime::Node& source, const std::string copy_op, #ifdef USE_CUDA TEST(TransformerTest, MemcpyTransformerTest) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif std::unordered_map domain_to_version; domain_to_version[kOnnxDomain] = 7; auto model = std::make_shared("test", false, ModelMetaData(), PathString(), @@ -106,7 +112,9 @@ TEST(TransformerTest, MemcpyTransformerTest) { KernelRegistryManager kernel_registry_manager; ExecutionProviders execution_providers; +#if defined(USE_CUDA) ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kCudaExecutionProvider, DefaultCudaExecutionProvider())); +#endif ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kCpuExecutionProvider, std::make_unique(CPUExecutionProviderInfo()))); KernelRegistryManager test_registry_manager; @@ -129,6 +137,9 @@ TEST(TransformerTest, MemcpyTransformerTest) { } TEST(TransformerTest, MemcpyTransformerTestCudaFirst) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif std::unordered_map domain_to_version; domain_to_version[kOnnxDomain] = 7; auto model = std::make_shared("test", false, ModelMetaData(), PathString(), @@ -161,7 +172,9 @@ TEST(TransformerTest, MemcpyTransformerTestCudaFirst) { KernelRegistryManager kernel_registry_manager; ExecutionProviders execution_providers; + ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kCudaExecutionProvider, DefaultCudaExecutionProvider())); + ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kCpuExecutionProvider, std::make_unique(CPUExecutionProviderInfo()))); KernelRegistryManager test_registry_manager; @@ -281,7 +294,11 @@ TEST(TransformerTest, TestInitializerDuplicationInSubgraph) { KernelRegistryManager kernel_registry_manager; ExecutionProviders execution_providers; +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kCudaExecutionProvider, DefaultCudaExecutionProvider())); + ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kCpuExecutionProvider, std::make_unique(CPUExecutionProviderInfo()))); KernelRegistryManager test_registry_manager; @@ -323,7 +340,11 @@ TEST(TransformerTest, MemcpyTransformerTestGraphInputConsumedOnMultipleDevices) KernelRegistryManager kernel_registry_manager; ExecutionProviders execution_providers; +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kCudaExecutionProvider, DefaultCudaExecutionProvider())); + ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kCpuExecutionProvider, std::make_unique(CPUExecutionProviderInfo()))); KernelRegistryManager test_registry_manager; @@ -425,7 +446,11 @@ TEST(TransformerTest, MemcpyTransformerTestImplicitInputConsumedOnMultipleDevice KernelRegistryManager kernel_registry_manager; ExecutionProviders execution_providers; +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kCudaExecutionProvider, DefaultCudaExecutionProvider())); + ASSERT_STATUS_OK(execution_providers.Add(onnxruntime::kCpuExecutionProvider, std::make_unique(CPUExecutionProviderInfo()))); KernelRegistryManager test_registry_manager; diff --git a/onnxruntime/test/framework/save_model_with_external_initializers.cc b/onnxruntime/test/framework/save_model_with_external_initializers.cc index 0f76cb61ace74..d0bc088175755 100644 --- a/onnxruntime/test/framework/save_model_with_external_initializers.cc +++ b/onnxruntime/test/framework/save_model_with_external_initializers.cc @@ -7,7 +7,6 @@ #include "core/framework/data_types.h" #include "core/graph/model.h" #include "core/framework/tensorprotoutils.h" -#include "core/framework/session_state.h" #include "test/test_environment.h" #include "test_utils.h" #include "test/util/include/asserts.h" @@ -20,34 +19,19 @@ using namespace onnxruntime; namespace onnxruntime { namespace test { -std::vector split(const std::string& str, char delimiter) { - std::vector result; - std::stringstream ss(str); - std::string token; - - // Use getline with a delimiter to split the string - while (std::getline(ss, token, delimiter)) { - result.push_back(token); - } - - return result; -} - Status LoadSaveAndCompareModel(const std::filesystem::path& input_onnx, const std::filesystem::path& input_external_init_file, const std::filesystem::path& output_onnx, const std::filesystem::path& output_external_init_file, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info, - Graph::PrePackedTensorProtoToSave& pre_packed_initializers_tensor_proto, - bool save_prepacked_constant_initializers = false) { + const Graph::OffsetAlignmentInfo& align_info) { auto logger = DefaultLoggingManager().CreateLogger("LoadSaveAndCompareModel"); std::shared_ptr model; ORT_RETURN_IF_ERROR(Model::Load(input_onnx, model, nullptr, *logger)); std::filesystem::remove(output_onnx); std::filesystem::remove(output_external_init_file); ORT_RETURN_IF_ERROR(Model::SaveWithExternalInitializers(*model, output_onnx, output_external_init_file, initializer_size_threshold, - align_info, save_prepacked_constant_initializers, pre_packed_initializers_tensor_proto)); + align_info)); std::shared_ptr model_from_external; ORT_RETURN_IF_ERROR(Model::Load(output_onnx.native(), model_from_external, nullptr, *logger)); @@ -66,11 +50,10 @@ Status LoadSaveAndCompareModel(const std::filesystem::path& input_onnx, // Compare the initializers of the two versions. std::filesystem::path model_path{}; std::filesystem::path external_data_path{}; - for (const auto& i : initializers_from_external) { + for (const auto& i : initializers) { const std::string kInitName = i.first; - const ONNX_NAMESPACE::TensorProto* from_external_tensor_proto = i.second; - // prepack initializer will have name as [original name]:[kernel name] in case initializer used by multiple kernels - const ONNX_NAMESPACE::TensorProto* tensor_proto = save_prepacked_constant_initializers ? initializers[split(kInitName, ':')[0]] : initializers[kInitName]; + const ONNX_NAMESPACE::TensorProto* tensor_proto = i.second; + const ONNX_NAMESPACE::TensorProto* from_external_tensor_proto = initializers_from_external[kInitName]; std::vector tensor_proto_data; model_path = input_onnx; @@ -92,12 +75,8 @@ Status LoadSaveAndCompareModel(const std::filesystem::path& input_onnx, ORT_RETURN_IF_NOT(from_external_tensor_proto->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL, "location mismatch"); } - if (!save_prepacked_constant_initializers) { - ORT_RETURN_IF_NOT(tensor_proto_size == from_external_tensor_proto_size, "size mismatch"); - ORT_RETURN_IF_NOT(memcmp(tensor_proto_data.data(), from_external_tensor_proto_data.data(), tensor_proto_size) == 0, "data mismatch"); - } else { - ORT_RETURN_IF_NOT(from_external_tensor_proto_size >= tensor_proto_size, "prepack initializer's size is at least same as original tensor, might be larger"); - } + ORT_RETURN_IF_NOT(tensor_proto_size == from_external_tensor_proto_size, "size mismatch"); + ORT_RETURN_IF_NOT(memcmp(tensor_proto_data.data(), from_external_tensor_proto_data.data(), tensor_proto_size) == 0, "data mismatch"); if (align_info.align_offset) { for (const StringStringEntryProto& entry : from_external_tensor_proto->external_data()) { @@ -110,7 +89,6 @@ Status LoadSaveAndCompareModel(const std::filesystem::path& input_onnx, } } } - // Cleanup. ORT_RETURN_IF_NOT(std::filesystem::remove(output_onnx), "delete file failed"); ORT_RETURN_IF_NOT(std::filesystem::remove(external_data_path), "delete file failed"); @@ -120,15 +98,13 @@ Status LoadSaveAndCompareModel(const std::filesystem::path& input_onnx, // Original model does not have external initializers TEST(SaveWithExternalInitializers, Mnist) { Graph::OffsetAlignmentInfo align_info; - Graph::PrePackedTensorProtoToSave pre_packed_initializers_tensor_proto; - ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/mnist.onnx"), ORT_TSTR(""), ORT_TSTR("testdata/mnist_with_external_initializers.onnx"), ORT_TSTR("mnist_external_initializers.bin"), 100, align_info, pre_packed_initializers_tensor_proto)); + ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/mnist.onnx"), ORT_TSTR(""), ORT_TSTR("testdata/mnist_with_external_initializers.onnx"), ORT_TSTR("mnist_external_initializers.bin"), 100, align_info)); } // Original model has external initializers TEST(SaveWithExternalInitializers, ModelWithOriginalExternalData) { Graph::OffsetAlignmentInfo align_info; - Graph::PrePackedTensorProtoToSave pre_packed_initializers_tensor_proto; - ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/model_with_orig_ext_data.onnx"), ORT_TSTR("model_with_orig_ext_data.onnx.data"), ORT_TSTR("testdata/model_with_new_external_initializers.onnx"), ORT_TSTR("model_with_new_external_initializers.bin"), 0, align_info, pre_packed_initializers_tensor_proto)); + ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/model_with_orig_ext_data.onnx"), ORT_TSTR("model_with_orig_ext_data.onnx.data"), ORT_TSTR("testdata/model_with_new_external_initializers.onnx"), ORT_TSTR("model_with_new_external_initializers.bin"), 0, align_info)); } // Original model has external initializers, align offset @@ -136,22 +112,7 @@ TEST(SaveWithExternalInitializers, ModelWithOriginalExternalDataAlignOffset) { Graph::OffsetAlignmentInfo align_info; align_info.align_offset = true; align_info.align_threshold = 0; - Graph::PrePackedTensorProtoToSave pre_packed_initializers_tensor_proto; - ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/model_with_orig_ext_data.onnx"), ORT_TSTR("model_with_orig_ext_data.onnx.data"), ORT_TSTR("testdata/model_with_new_external_initializers.onnx"), ORT_TSTR("model_with_new_external_initializers.bin"), 0, align_info, pre_packed_initializers_tensor_proto)); -} - -// Original model has external initializers, align offset and serialize prepacked external initializer to model file -TEST(SaveWithExternalInitializers, ModelWithOriginalExternalDataAlignOffsetAndSavePrepackTensors) { - Graph::OffsetAlignmentInfo align_info; - align_info.align_offset = true; - align_info.align_threshold = 0; - std::shared_ptr alloc = std::make_shared(); - TensorShape shape = {178}; - // prepack both initializers for test purpose - Graph::PrePackedTensorProtoToSave pre_packed_initializers_tensor_proto; - pre_packed_initializers_tensor_proto["MatMul.Weight"]["MatMul_0"] = utils::TensorToTensorProto(Tensor(DataTypeImpl::GetType(), shape, alloc), "MatMul.Weight:MatMul_0"); - pre_packed_initializers_tensor_proto["scales"]["MatMul_0"] = utils::TensorToTensorProto(Tensor(DataTypeImpl::GetType(), shape, alloc), "scales:MatMul_0"); - ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/prepack/model_with_matmul_nbits.onnx"), ORT_TSTR("model_with_matmul_nbits.onnx.data"), ORT_TSTR("testdata/prepack/model_with_matmul_nbits_opt.onnx"), ORT_TSTR("model_with_matmul_nbits_opt.onnx.data"), 0, align_info, pre_packed_initializers_tensor_proto, true)); + ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/model_with_orig_ext_data.onnx"), ORT_TSTR("model_with_orig_ext_data.onnx.data"), ORT_TSTR("testdata/model_with_new_external_initializers.onnx"), ORT_TSTR("model_with_new_external_initializers.bin"), 0, align_info)); } } // namespace test diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index 6265eccb7bd9b..3e694020f796b 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -5,6 +5,7 @@ #include #include "asserts.h" +#include "core/framework/allocator_utils.h" #include "core/framework/execution_providers.h" #include "core/framework/graph_partitioner.h" #include "core/framework/kernel_registry.h" @@ -216,10 +217,12 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { // Test that we allocate memory for an initializer from non-arena memory even if we provide an arena-based allocator // if the relevant session option config flag is set -// For this test we need to enable the arena-based allocator which is not supported on x86 builds, so -// enable this test only on x64 builds -#if (defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64)) && !defined(USE_MIMALLOC) && !defined(ABSL_HAVE_ADDRESS_SANITIZER) TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { + // For this test we need to enable the arena-based allocator. + if (!DoesCpuAllocatorSupportArenaUsage()) { + GTEST_SKIP() << "CPU allocator does not support arena usage."; + } + AllocatorPtr cpu_allocator = std::make_shared(); // Part 1: Feature turned ON (i.e.) allocate from non-arena memory { @@ -348,8 +351,6 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { } } -#endif - INSTANTIATE_TEST_SUITE_P(SessionStateTests, SessionStateTestP, testing::ValuesIn(param_list)); #ifndef ENABLE_TRAINING_CORE @@ -372,11 +373,10 @@ class PrePackingTestOpKernel : public OpKernel { return Status::OK(); } - Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool save_prepacked_initializers, + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override { ORT_UNUSED_PARAMETER(tensor); ORT_UNUSED_PARAMETER(input_idx); - ORT_UNUSED_PARAMETER(save_prepacked_initializers); size_t weight_packed_len = 8; weight_packed_ = IAllocator::MakeUniquePtr(alloc, weight_packed_len, true); @@ -394,20 +394,9 @@ class PrePackingTestOpKernel : public OpKernel { return Status::OK(); } - std::optional GetPrePackTensor(int input_idx) override { - ORT_UNUSED_PARAMETER(input_idx); - ++get_prepack_tensors_count; - - TensorShape shape = {2}; - packed_tensor = Tensor(DataTypeImpl::GetType(), shape, std::make_shared()); - return std::move(packed_tensor); - } - int prepack_calls_count = 0; int store_pre_packed_weight_calls_count = 0; - int get_prepack_tensors_count = 0; IAllocatorUniquePtr weight_packed_; - Tensor packed_tensor; }; static void CreateSimpleGraph(Graph& graph) { @@ -542,7 +531,6 @@ static void PlaceAllNodesToCPUEP(Graph& graph) { struct PrepackingTestParam { bool test_subgraph; bool test_prepacking; - bool test_save_prepack_initializer; }; class SessionStatePrepackingTest : public testing::TestWithParam {}; @@ -585,8 +573,6 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) { sess_options.enable_mem_reuse = true; sess_options.config_options.configurations[kOrtSessionOptionsConfigDisablePrepacking] = test_param.test_prepacking ? "0" : "1"; - sess_options.config_options.configurations[kOrtSessionOptionsSavePrePackedConstantInitializers] = - test_param.test_save_prepack_initializer ? "1" : "0"; SessionState session_state(model.MainGraph(), execution_providers, @@ -612,47 +598,12 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) { kernel_registry_manager.RegisterKernelRegistry(kernel_registry); PlaceAllNodesToCPUEP(model.MainGraph()); - SessionState::PrePackInitializers pre_packed_initializers; ASSERT_STATUS_OK(session_state.FinalizeSessionState(std::basic_string(), - kernel_registry_manager, - pre_packed_initializers)); + kernel_registry_manager)); const auto& const_initialized_tensors = session_state.GetConstantInitializedTensors(); // check prepacking ASSERT_EQ(const_initialized_tensors.size(), size_t(test_param.test_prepacking ? 0 : 1)); - - // check get prepack tensor method called when set save_prepacked_constant_initializers - if (!test_param.test_subgraph) { - const auto* kernel = reinterpret_cast(session_state.GetKernel(0)); - ASSERT_EQ(kernel->get_prepack_tensors_count, (test_param.test_prepacking && test_param.test_save_prepack_initializer) ? 1 : 0); - } else { - auto if_index = 1; - if (session_state.GetKernel(0)->Node().OpType() == "If") { - if_index = 0; - } - - const auto& subgraph_session_states = session_state.GetSubgraphSessionStateMap(); - const auto& if_node_session_states = subgraph_session_states.at(if_index); - const auto& session_state_1_then_branch_session_state = *if_node_session_states.at("then_branch"); - const auto& session_state_1_else_branch_session_state = *if_node_session_states.at("else_branch"); - - const auto* kernel_if_0 = reinterpret_cast(session_state_1_then_branch_session_state.GetKernel(0)); - const auto* kernel_if_1 = reinterpret_cast(session_state_1_else_branch_session_state.GetKernel(0)); - ASSERT_EQ(kernel_if_0->get_prepack_tensors_count, (test_param.test_prepacking && test_param.test_save_prepack_initializer) ? 1 : 0); - ASSERT_EQ(kernel_if_1->get_prepack_tensors_count, (test_param.test_prepacking && test_param.test_save_prepack_initializer) ? 1 : 0); - } - - // check pre_packed_initializers_to_save will be set properly when set save_prepacked_constant_initializers - if (!test_param.test_subgraph && test_param.test_prepacking && test_param.test_save_prepack_initializer) { - ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save.size(), size_t(1)); - ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save.count("node_0_input_1"), size_t(1)); - ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save["node_0_input_1"].count("node_0"), size_t(1)); - } else if (test_param.test_subgraph && test_param.test_prepacking && test_param.test_save_prepack_initializer) { - ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save.size(), size_t(1)); - ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save.count("if_shared"), size_t(1)); - ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save["if_shared"].count("if_node_1"), size_t(1)); - ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save["if_shared"].count("if_node_0"), size_t(1)); - } } class SessionStateTestSharedInitalizersWithPrePacking : public ::testing::Test { @@ -1050,14 +1001,10 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test4) { INSTANTIATE_TEST_SUITE_P(SessionStateTests, SessionStatePrepackingTest, - testing::Values(PrepackingTestParam{false, false, false}, - PrepackingTestParam{false, true, false}, - PrepackingTestParam{true, false, false}, - PrepackingTestParam{true, true, false}, - PrepackingTestParam{false, false, true}, - PrepackingTestParam{false, true, true}, - PrepackingTestParam{true, false, true}, - PrepackingTestParam{true, true, true})); + testing::Values(PrepackingTestParam{false, false}, + PrepackingTestParam{false, true}, + PrepackingTestParam{true, false}, + PrepackingTestParam{true, true})); #endif } // namespace test diff --git a/onnxruntime/test/framework/sparse_kernels_test.cc b/onnxruntime/test/framework/sparse_kernels_test.cc index 7bd6b47f52b7d..db9592c293fd0 100644 --- a/onnxruntime/test/framework/sparse_kernels_test.cc +++ b/onnxruntime/test/framework/sparse_kernels_test.cc @@ -1457,6 +1457,9 @@ TEST(SparseTensorConversionTests, CsrConversion) { #ifdef USE_CUDA auto cuda_provider = DefaultCudaExecutionProvider(); + if (cuda_provider == nullptr) { + return; + } auto cuda_allocator = cuda_provider->CreatePreferredAllocators()[0]; { auto cuda_transfer = cuda_provider->GetDataTransfer(); @@ -1684,6 +1687,9 @@ TEST(SparseTensorConversionTests, CooConversion) { #ifdef USE_CUDA auto cuda_provider = DefaultCudaExecutionProvider(); + if (cuda_provider == nullptr) { + return; + } auto cuda_allocator = cuda_provider->CreatePreferredAllocators()[0]; { auto cuda_transfer = cuda_provider->GetDataTransfer(); diff --git a/onnxruntime/test/framework/tensor_test.cc b/onnxruntime/test/framework/tensor_test.cc index 9202543b75a6f..fba099f9c55b3 100644 --- a/onnxruntime/test/framework/tensor_test.cc +++ b/onnxruntime/test/framework/tensor_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/framework/tensor.h" +#include "core/framework/allocator_utils.h" #include "test_utils.h" #include "gmock/gmock.h" @@ -137,12 +138,10 @@ TEST(TensorTest, EmptyTensorTest) { ASSERT_STREQ(location.name, CPU); EXPECT_EQ(location.id, 0); - // arena is disabled for CPUExecutionProvider on x86 and JEMalloc -#if (defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64)) && !defined(USE_JEMALLOC) && !defined(USE_MIMALLOC) && !defined(ABSL_HAVE_ADDRESS_SANITIZER) - EXPECT_EQ(location.alloc_type, OrtAllocatorType::OrtArenaAllocator); -#else - EXPECT_EQ(location.alloc_type, OrtAllocatorType::OrtDeviceAllocator); -#endif + const auto expected_allocator_type = DoesCpuAllocatorSupportArenaUsage() + ? OrtAllocatorType::OrtArenaAllocator + : OrtAllocatorType::OrtDeviceAllocator; + EXPECT_EQ(location.alloc_type, expected_allocator_type); } TEST(TensorTest, StringTensorTest) { diff --git a/onnxruntime/test/lora/lora_test.cc b/onnxruntime/test/lora/lora_test.cc index fde603858f9a9..9d8febb453739 100644 --- a/onnxruntime/test/lora/lora_test.cc +++ b/onnxruntime/test/lora/lora_test.cc @@ -200,63 +200,41 @@ TEST(LoraAdapterTest, Load) { } #ifdef USE_CUDA -TEST(LoraAdapterTest, VerifyCudaDeviceCopy) { - auto cpu_ep = DefaultCpuExecutionProvider(); - auto cpu_allocator = cpu_ep->CreatePreferredAllocators()[0]; - auto cuda_allocator = DefaultCudaExecutionProvider()->CreatePreferredAllocators()[0]; - auto cuda_transfer = DefaultCudaExecutionProvider()->GetDataTransfer(); - - auto test_params = GenerateTestParameters()(); - lora::LoraAdapter adapter(std::move(cuda_allocator)); - adapter.Load(std::move(test_params)); - - auto [begin, end] = adapter.GetParamIterators(); - for (; begin != end; ++begin) { - const auto& [_, param] = *begin; - const auto& tensor_device = param.GetDeviceOrMapped().Get(); - ASSERT_EQ(0, strcmp(tensor_device.Location().name, onnxruntime::CUDA)); - - const auto& tensor_cpu = param.GetMapped().Get(); - ASSERT_EQ(tensor_cpu.Shape().Size(), tensor_device.Shape().Size()); - - Tensor copy(tensor_cpu.DataType(), tensor_cpu.Shape(), cpu_allocator); - ASSERT_TRUE(cuda_transfer->CanCopy(tensor_device.Location().device, - copy.Location().device)); - ASSERT_STATUS_OK(cuda_transfer->CopyTensor(tensor_device, copy)); - - auto expected_span = tensor_cpu.DataAsSpan(); - auto copy_span = copy.DataAsSpan(); - - ASSERT_EQ(expected_span, copy_span); +TEST(LoraAdapterTest, VerifyDeviceCopy) { + // These checks for CUDA/DML combined Package, Be careful when you want to remove it! + if (DefaultCudaExecutionProvider() == nullptr) { + GTEST_SKIP() << "Skip This Test Due to this EP is null"; + } +#ifdef USE_DML + if (DefaultDmlExecutionProvider() != nullptr) { + GTEST_FAIL() << "It should not run with DML EP"; } -} #endif -#ifdef USE_DML -TEST(LoraAdapterTest, VerifyDmlDeviceCopy) { auto cpu_ep = DefaultCpuExecutionProvider(); auto cpu_allocator = cpu_ep->CreatePreferredAllocators()[0]; + auto cuda_ep = DefaultCudaExecutionProvider(); + auto cuda_allocator = cuda_ep->CreatePreferredAllocators()[0]; - auto dml_allocator = DefaultDmlExecutionProvider()->CreatePreferredAllocators()[0]; - auto dml_transfer = DefaultDmlExecutionProvider()->GetDataTransfer(); + auto gpu_transfer = cuda_ep->GetDataTransfer(); auto test_params = GenerateTestParameters()(); - lora::LoraAdapter adapter(std::move(dml_allocator)); + lora::LoraAdapter adapter(std::move(cuda_allocator)); adapter.Load(std::move(test_params)); auto [begin, end] = adapter.GetParamIterators(); for (; begin != end; ++begin) { const auto& [_, param] = *begin; const auto& tensor_device = param.GetDeviceOrMapped().Get(); - ASSERT_EQ(0, strcmp(tensor_device.Location().name, onnxruntime::DML)); + ASSERT_EQ(0, strcmp(tensor_device.Location().name, onnxruntime::CUDA)); const auto& tensor_cpu = param.GetMapped().Get(); ASSERT_EQ(tensor_cpu.Shape().Size(), tensor_device.Shape().Size()); Tensor copy(tensor_cpu.DataType(), tensor_cpu.Shape(), cpu_allocator); - ASSERT_TRUE(dml_transfer->CanCopy(tensor_device.Location().device, + ASSERT_TRUE(gpu_transfer->CanCopy(tensor_device.Location().device, copy.Location().device)); - ASSERT_STATUS_OK(dml_transfer->CopyTensor(tensor_device, copy)); + ASSERT_STATUS_OK(gpu_transfer->CopyTensor(tensor_device, copy)); auto expected_span = tensor_cpu.DataAsSpan(); auto copy_span = copy.DataAsSpan(); @@ -265,6 +243,5 @@ TEST(LoraAdapterTest, VerifyDmlDeviceCopy) { } } #endif - } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/mlas/bench/bench_fp16_neon_common.cpp b/onnxruntime/test/mlas/bench/bench_cast.cpp similarity index 100% rename from onnxruntime/test/mlas/bench/bench_fp16_neon_common.cpp rename to onnxruntime/test/mlas/bench/bench_cast.cpp diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp similarity index 53% rename from onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp rename to onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp index 71db7d81075b5..64d229889214b 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include "benchmark/benchmark.h" @@ -16,16 +17,16 @@ #include "core/util/thread_utils.h" #include "core/platform/env_var_utils.h" -template -void RunSQNBitGemmBenchmark(size_t BlkLen, - size_t M, size_t N, size_t K, - size_t Threads, - bool Symmetric, - bool HasBias, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - benchmark::State& state) { - if (!MlasIsSQNBitGemmAvailable(BlkBitWidth, BlkLen, ComputeType)) { - state.SkipWithMessage("SQNBitGemm is not available with the given configuration on the current machine."); +template +void RunQNBitGemmBenchmark(size_t BlkLen, + size_t M, size_t N, size_t K, + size_t Threads, + bool Symmetric, + bool HasBias, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + benchmark::State& state) { + if (!MlasIsQNBitGemmAvailable(BlkBitWidth, BlkLen, ComputeType)) { + state.SkipWithMessage("QNBitGemm is not available with the given configuration on the current machine."); return; } @@ -43,40 +44,40 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); - const auto A = RandomVectorUniform(M * K, -1.0f, 1.0f); - const auto B = RandomVectorUniform(K * N, -1.0f, 1.0f); + const auto A = RandomVectorUniform(M * K, AType(-1.0f), AType(1.0f)); + const auto B = RandomVectorUniform(K * N, AType(-1.0f), AType(1.0f)); - const auto Bias = HasBias ? RandomVectorUniform(N, -1.0f, 1.0f) : std::vector(); + const auto Bias = HasBias ? RandomVectorUniform(N, AType(-1.0f), AType(1.0f)) : std::vector(); - std::vector C(static_cast(M * N)); + std::vector C(static_cast(M * N)); std::vector QuantBData(QuantBDataSizeInBytes); - std::vector QuantBScale(QuantBScaleSize); + std::vector QuantBScale(QuantBScaleSize); std::vector QuantBZeroPoint(Symmetric ? 0 : QuantBZeroPointSizeInBytes); bool has_zp_input = !Symmetric; - MlasQuantizeBlockwise(QuantBData.data(), QuantBScale.data(), + MlasQuantizeBlockwise(QuantBData.data(), QuantBScale.data(), Symmetric ? nullptr : QuantBZeroPoint.data(), B.data(), static_cast(BlkLen), /* columnwise */ true, static_cast(K), static_cast(N), static_cast(N), tp.get()); std::unique_ptr Workspace; - if (const auto WorkspaceSize = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); + if (const auto WorkspaceSize = MlasQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); WorkspaceSize > 0) { Workspace = std::make_unique(WorkspaceSize); } std::unique_ptr PackedQuantBData; - if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); + if (const auto PackedQuantBDataSize = MlasQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); PackedQuantBDataSize > 0) { PackedQuantBData = std::make_unique(PackedQuantBDataSize); - MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData.data(), PackedQuantBData.get(), - QuantBScale.data(), has_zp_input, QuantBZeroPoint.data(), - tp.get()); + MlasQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData.data(), PackedQuantBData.get(), + QuantBScale.data(), has_zp_input, QuantBZeroPoint.data(), + tp.get()); } - MLAS_SQNBIT_GEMM_DATA_PARAMS params{}; + MLAS_QNBIT_GEMM_DATA_PARAMS params{}; params.A = A.data(); params.lda = K; if (PackedQuantBData != nullptr) @@ -92,15 +93,15 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, params.ldc = N; // warm up run - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); + MlasQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); for (auto _ : state) { - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); + MlasQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); } } -template -void SQNBITGEMM(benchmark::State& state) { +template +void QNBITGEMM(benchmark::State& state) { using onnxruntime::narrow; const auto BlkLen = narrow(state.range(0)); @@ -110,46 +111,50 @@ void SQNBITGEMM(benchmark::State& state) { const auto Threads = narrow(state.range(4)); const auto Symmetric = narrow(state.range(5)); const bool HasBias = narrow(state.range(6)); - const auto ComputeType = static_cast(state.range(7)); + const auto ComputeType = static_cast(state.range(7)); - RunSQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, HasBias, ComputeType, state); + RunQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, HasBias, ComputeType, state); } -static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) { +template +static void QNBitGemmArgs(benchmark::internal::Benchmark* b) { b->ArgNames({"BlkLen", "M", "N", "K", "Threads", "Symmetric", "HasBias", "ComputeType"}); b->ArgsProduct({ - {128}, // BlkLen - {1}, // M - {4096, 11008}, // N - {4096, 11008}, // K - {1, 8}, // Threads - {int64_t{false}, int64_t{true}}, // Symmetric - {int64_t{false}, int64_t{true}}, // HasBias - {int64_t{CompFp32}, int64_t{CompInt8}}, // ComputeType + {128}, // BlkLen + {1, 4096}, // M + {4096, 11008}, // N + {4096, 11008}, // K + {1, 8}, // Threads + {int64_t{false}, int64_t{true}}, // Symmetric + {int64_t{false}, int64_t{true}}, // HasBias + std::is_same_v + ? std::vector{int64_t{HQNBIT_CompFp16}} + : std::vector{int64_t{SQNBIT_CompFp32}, int64_t{SQNBIT_CompInt8}}, // ComputeType }); } -BENCHMARK(SQNBITGEMM<4>)->Apply(SQNBitGemmArgs)->UseRealTime(); +BENCHMARK(QNBITGEMM)->Apply(QNBitGemmArgs)->UseRealTime(); +BENCHMARK(QNBITGEMM)->Apply(QNBitGemmArgs)->UseRealTime(); // This test gets benchmark arguments from environment variables. -template -void SQNBITGEMM_ENV(benchmark::State& state) { +template +void QNBITGEMM_ENV(benchmark::State& state) { using onnxruntime::ParseEnvironmentVariableWithDefault; - const auto BlkLen = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_BLKLEN", 32); - const auto M = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_M", 1); - const auto N = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_N", 4096); - const auto K = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_K", 4096); - const auto Threads = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_THREADS", 1); - const auto Symmetric = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_SYMMETRIC", true); - const auto HasBias = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_HAS_BIAS", false); - const auto ComputeType = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_COMPUTE_TYPE", - static_cast(CompFp32)); + const auto BlkLen = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_BLKLEN", 32); + const auto M = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_M", 1); + const auto N = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_N", 4096); + const auto K = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_K", 4096); + const auto Threads = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_THREADS", 1); + const auto Symmetric = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_SYMMETRIC", true); + const auto HasBias = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_HAS_BIAS", false); + const auto ComputeType = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_COMPUTE_TYPE", + static_cast(SQNBIT_CompFp32)); - RunSQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, HasBias, - static_cast(ComputeType), - state); + RunQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, HasBias, + static_cast(ComputeType), + state); std::ostringstream s; s << "BlkBitWidth:" << BlkBitWidth << "/BlkLen:" << BlkLen @@ -159,4 +164,4 @@ void SQNBITGEMM_ENV(benchmark::State& state) { state.SetLabel(s.str()); } -BENCHMARK(SQNBITGEMM_ENV<4>)->UseRealTime(); +BENCHMARK(QNBITGEMM_ENV)->UseRealTime(); diff --git a/onnxruntime/test/mlas/bench/bench_util.h b/onnxruntime/test/mlas/bench/bench_util.h index f96dd5c673b3d..78789ef1cbc1a 100644 --- a/onnxruntime/test/mlas/bench/bench_util.h +++ b/onnxruntime/test/mlas/bench/bench_util.h @@ -8,8 +8,12 @@ #include #include +#include "core/framework/float16.h" +#include "core/mlas/inc/mlas.h" + template -std::vector RandomVectorUniform( +typename std::enable_if_t, std::vector> +RandomVectorUniform( size_t N, ElementType min_value = std::numeric_limits::lowest(), ElementType max_value = std::numeric_limits::max()) { @@ -26,6 +30,25 @@ std::vector RandomVectorUniform( return r; } +template +typename std::enable_if_t, std::vector> +RandomVectorUniform( + size_t N, + ElementType min_value, + ElementType max_value) { + if (min_value.ToFloat() >= max_value.ToFloat()) { + return std::vector(N, min_value); + } + std::default_random_engine generator(static_cast(N)); + std::uniform_real_distribution distribution(min_value.ToFloat(), max_value.ToFloat()); + + std::vector r(N); + for (size_t i = 0; i < N; i++) { + r[i] = ElementType(distribution(generator)); + } + return r; +} + std::vector RandomVectorUniform(std::vector shape, float min_value, float max_value); std::vector BenchArgsVector(benchmark::State& state, size_t& start, size_t count); diff --git a/onnxruntime/test/mlas/unittest/test_hqnbitgemm_neon.cpp b/onnxruntime/test/mlas/unittest/test_hqnbitgemm_neon.cpp new file mode 100644 index 0000000000000..b598c20e29280 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_hqnbitgemm_neon.cpp @@ -0,0 +1,501 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_hqnbitgemm_neon.cpp + +Abstract: + + Tests for MLAS n-bit int block quantized GEMM on ARM CPU with input A type T1 fp16. + +--*/ + +#include +#include + +#include "test_util.h" +#include "core/mlas/lib/mlasi.h" +#include "core/mlas/lib/qnbitgemm.h" +#include "mlas_qnbit.h" + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + +class MlasNeonFp16CastTest : public MlasTestBase { + private: + MatrixGuardBuffer fp32Buffer_; + MatrixGuardBuffer fp16Buffer_; + + template + void TestFp16ToFp32() { + const auto* src = fp16Buffer_.GetFilledBuffer(count, [](unsigned short* start, size_t size) { + for (size_t i = 0; i < size; i++) { + start[i] = static_cast(i); + } + }); + auto* dest = fp32Buffer_.GetBuffer(count, true); + + MlasCastF16ToF32KernelNeon(src, dest, count); + + for (size_t i = 0; i < count; i++) { + if ((src[i] & 0x1c00) == 0x1c00) continue; // skip inf and nan + ASSERT_EQ(dest[i], MLAS_FP16::FromBits(src[i]).ToFloat()); + } + } + + template + void TestFp32ToFp16() { + const auto* src = fp32Buffer_.GetFilledBuffer(count, [](float* p, size_t size) { + for (size_t i = 0; i < size; i++) { + p[i] = static_cast(i) + 0.125f; + } + }); + auto* dest = fp16Buffer_.GetBuffer(count, true); + + MlasCastF32ToF16KernelNeon(src, dest, count); + + for (size_t i = 0; i < count; i++) { + ASSERT_EQ(dest[i], MLAS_FP16(src[i]).val); + } + } + + public: + static const char* GetTestSuiteName() { + return "NeonFp16Cast"; + } + + void ExecuteShort(void) override { + TestFp16ToFp32<(1 << 16)>(); + TestFp16ToFp32<1>(); + TestFp16ToFp32<4>(); + TestFp16ToFp32<7>(); + TestFp32ToFp16<(1 << 16)>(); + TestFp32ToFp16<3>(); + TestFp32ToFp16<4>(); + TestFp32ToFp16<6>(); + } +}; + +class MlasNeonFp16PrepackTest : public MlasTestBase { + private: + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_int_distribution<> distrib_; + MatrixGuardBuffer input_, ref_, packed_; + + template + MLAS_FORCEINLINE void Transpose8x8(const uint8_t* src, size_t n, size_t k, uint8_t* dst) { + for (size_t c = 0; c < 8; c++) { + for (size_t r = 0; r < 8; r++) { + size_t i = (n + c) * Ldb + r + k; + size_t j = n * Ldb + (r + k) * 8 + c; + dst[j] = src[i]; + } + } + } + + MLAS_FORCEINLINE + uint8_t GetInt4(uint8_t v, size_t i) { + return (i & 1) ? (v >> 4) : (v & 0x0f); + } + + MLAS_FORCEINLINE + void PrepackSlice(const uint8_t* src, size_t j, uint8_t* dst) { + for (size_t i = 0; i < 8; i++) { + uint8_t v0 = GetInt4(src[j + (i >> 1)], i); + uint8_t v1 = GetInt4(src[j + ((8 + i) >> 1)], i + 8); + dst[j + i] = v0 | (v1 << 4); + } + } + + template + MLAS_FORCEINLINE void Prepack(const uint8_t* src, uint8_t* dst) { + size_t n = 0; + for (; n + 8 <= N; n += 8) { + for (size_t k = 0; k < Ldb; k += 8) { + Transpose8x8(src, n, k, dst); + } + } + + for (; n < N; ++n) { + for (size_t k = 0; k < Ldb; k += 8) { + PrepackSlice(src, n * Ldb + k, dst); + } + } + } + + template + MLAS_FORCEINLINE void Check(const uint8_t* packed, const uint8_t* ref) { + size_t n = 0; + for (; n + 8 <= N; n += 8) { + for (size_t i = 0; i < K; i += 2) { + for (size_t j = 0; j < 8; ++j) { + ASSERT_EQ(packed[n * Ldb + (i >> 1) * 8 + j], ref[n * Ldb + (i >> 1) * 8 + j]) + << " seed " << seed_ + << " n " << n << " i " << i << " j " << j; + } + } + } + + for (; n < N; ++n) { + for (size_t i = 0; i < K; i += 2) { + ASSERT_EQ(packed[n * Ldb + (i >> 1)], ref[n * Ldb + (i >> 1)]) + << " seed " << seed_ + << " n " << n << " i " << i; + } + } + } + + template + void TestPrepack() { + constexpr size_t Bits = 4; + constexpr size_t Ldb = (((K + BlkLen - 1) & (~(BlkLen - 1))) * Bits + 7) / 8; + constexpr size_t BufferSize = N * Ldb; + auto InitializeBuffer = [this](uint8_t* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = static_cast(distrib_(gen_)); + } + }; + + const auto* input = input_.GetFilledBuffer(BufferSize, InitializeBuffer); + auto* packed = packed_.GetBuffer(BufferSize, true); + auto* ref = ref_.GetBuffer(BufferSize, true); + MlasQNBitGemmPackQuantBData( + N, K, Bits, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::HQNBIT_CompFp16, input, packed, + nullptr, false, nullptr, nullptr); + Prepack(input, ref); + Check(packed, ref); + } + + public: + MlasNeonFp16PrepackTest() + : seed_(19287), gen_(seed_), distrib_(0, 255) { + } + + static const char* GetTestSuiteName() { + return "NeonFp16Prepack"; + } + + void ExecuteShort(void) override { + TestPrepack<1, 1, 16>(); + TestPrepack<1, 15, 16>(); + TestPrepack<1, 31, 16>(); + TestPrepack<8, 1, 16>(); + TestPrepack<8, 16, 16>(); + TestPrepack<9, 31, 16>(); + TestPrepack<9, 33, 32>(); + TestPrepack<15, 33, 16>(); + TestPrepack<17, 67, 16>(); + TestPrepack<17, 96, 128>(); + TestPrepack<263, 263, 16>(); + } +}; + +class MlasNeonFp16DequantBTest : public MlasTestBase { + private: + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_int_distribution<> distrib_; + std::uniform_real_distribution _distribFp; + MatrixGuardBuffer input_, zero_points_; + MatrixGuardBuffer dequant_, ref_, scales_; + + MLAS_FORCEINLINE + uint8_t GetInt4(uint8_t v, size_t i) { + return (i & 1) ? (v >> 4) : (v & 0x0f); + } + + template + void DequantB(const uint8_t* src, MLAS_FP16* dst, const MLAS_FP16* scales, const uint8_t* zero_points) { + constexpr size_t blkNum = (K + BlkLen - 1) / BlkLen; + constexpr size_t ld_src = (blkNum * BlkLen + 1) / 2; + constexpr size_t ld_dst = blkNum * BlkLen; + constexpr size_t ld_zp = (blkNum + 1) / 2; + size_t n = 0; + for (; n + 8 <= N; n += 8) { + size_t i_src = n * ld_src, i_dst = n * ld_dst, i_scale = n * blkNum, i_zp = n * ld_zp; + for (size_t blk = 0; blk < blkNum; i_zp += (blk & 1), ++blk, ++i_scale) { + for (size_t i = 0; i < BlkLen; i += 2, i_dst += 8) { + for (size_t j = 0; j < 8; ++j, ++i_src, ++i_dst) { + uint8_t v = src[i_src]; + float v0 = static_cast(GetInt4(v, 0)); + float v1 = static_cast(GetInt4(v, 1)); + float zp = static_cast(UseZeroPoints ? GetInt4(zero_points[i_zp + ld_zp * j], blk) : 8); + float scale = scales[i_scale + blkNum * j]; + dst[i_dst] = MLAS_FP16(v0 * scale - zp * scale); + dst[i_dst + 8] = MLAS_FP16(v1 * scale - zp * scale); + } + } + } + } + + for (; n < N; ++n) { + size_t i_src = n * ld_src, i_dst = n * ld_dst, i_scale = n * blkNum, i_zp = n * ld_zp; + for (size_t blk = 0; blk < blkNum; i_zp += (blk & 1), ++blk, ++i_scale) { + float zp = static_cast(UseZeroPoints ? GetInt4(zero_points[i_zp], blk) : 8); + float scale = scales[i_scale]; + for (size_t i = 0; i < BlkLen; i += 16, i_dst += 8) { + for (size_t j = 0; j < 16; j += 2, ++i_src, ++i_dst) { + uint8_t v = src[i_src]; + float v0 = static_cast(GetInt4(v, 0)); + float v1 = static_cast(GetInt4(v, 1)); + dst[i_dst] = MLAS_FP16(v0 * scale - zp * scale); + dst[i_dst + 8] = MLAS_FP16(v1 * scale - zp * scale); + } + } + } + } + } + + MLAS_FORCEINLINE + bool FloatEqual(MLAS_FP16 v0, MLAS_FP16 v1, float rtol, float atol) { + float f0 = std::abs(v0.ToFloat()), f1 = std::abs(v1.ToFloat()); + return std::abs(f0 - f1) <= f1 * rtol + atol; + } + + template + MLAS_FORCEINLINE void Check(const MLAS_FP16* target, const MLAS_FP16* ref) { + size_t n = 0; + for (; n + 8 <= N; n += 8) { + for (size_t i = 0; i < K; ++i) { + for (size_t j = 0; j < 8; ++j) { + size_t idx = n * Ldb + i * 8 + j; + ASSERT_TRUE(FloatEqual(target[idx], ref[idx], 0.01f, 0.01f)) + << " seed " << seed_ + << " v0 " << target[idx] << " v1 " << ref[idx] + << " n " << n << " i " << i << " j " << j; + } + } + } + + for (; n < N; ++n) { + for (size_t i = 0; i < K; ++i) { + size_t idx = n * Ldb + i; + ASSERT_TRUE(FloatEqual(target[idx], ref[idx], 0.01f, 0.01f)) + << " seed " << seed_ + << " v0 " << target[idx] << " v1 " << ref[idx] + << " n " << n << " i " << i; + } + } + } + + template + void TestDequant() { + constexpr size_t BlkNum = (K + BlkLen - 1) / BlkLen; + constexpr size_t BCount = BlkNum * BlkLen * N; + constexpr size_t ScaleCount = N * BlkNum; + constexpr size_t ZpSize = N * ((BlkNum + 1) / 2); + + auto InitializeBuffer_i8 = [this](uint8_t* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = static_cast(distrib_(gen_)); + } + }; + + auto InitializeBuffer_fp16 = [this](MLAS_FP16* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(_distribFp(gen_)); + } + }; + + const auto* input = input_.GetFilledBuffer(BCount / 2, InitializeBuffer_i8); + const auto* zero_points = zero_points_.GetFilledBuffer(ZpSize, InitializeBuffer_i8); + auto* dequant = dequant_.GetBuffer(BCount); + auto* ref = ref_.GetBuffer(BCount); + const auto* scales = scales_.GetFilledBuffer(ScaleCount, InitializeBuffer_fp16); + GetMlasPlatform().QNBitGemmDispatch->HQ4BitBlkDequantBForHgemm_CompFp16( + BlkLen, dequant, reinterpret_cast(input), scales, + UseZeroPoints ? reinterpret_cast(zero_points) : nullptr, + N, K, BlkNum); + DequantB(input, ref, scales, zero_points); + Check(dequant, ref); + } + + public: + MlasNeonFp16DequantBTest() + : seed_(19287), gen_(seed_), distrib_(0, 255), _distribFp(0.5f, 2.0f) { + } + + static const char* GetTestSuiteName() { + return "NeonFp16DequantB"; + } + + void ExecuteShort(void) override { + TestDequant<1, 1, 16, false>(); + TestDequant<1, 1, 16, true>(); + TestDequant<1, 15, 16, false>(); + TestDequant<1, 15, 16, true>(); + TestDequant<1, 31, 16, false>(); + TestDequant<1, 31, 16, true>(); + TestDequant<8, 1, 16, false>(); + TestDequant<8, 1, 16, true>(); + TestDequant<8, 16, 16, false>(); + TestDequant<8, 16, 16, true>(); + TestDequant<9, 31, 16, false>(); + TestDequant<9, 31, 16, true>(); + TestDequant<9, 33, 32, false>(); + TestDequant<9, 33, 32, true>(); + TestDequant<15, 33, 16, false>(); + TestDequant<15, 33, 16, true>(); + TestDequant<17, 67, 16, false>(); + TestDequant<17, 67, 16, true>(); + TestDequant<17, 96, 128, false>(); + TestDequant<17, 96, 128, true>(); + TestDequant<263, 263, 16, false>(); + TestDequant<263, 263, 16, true>(); + } +}; + +class MlasNeonFp16HQ4BitGemmKernelTest : public MlasTestBase { + private: + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + MatrixGuardBuffer A_, B_, C_, ref_, bias_; + + MLAS_FORCEINLINE + void InitializeBuffer(MLAS_FP16* buffer, float min, float max, size_t count) { + std::uniform_real_distribution distrib(min, max); + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(distrib(gen_)); + } + } + + MLAS_FORCEINLINE + bool FloatEqual(MLAS_FP16 v0, MLAS_FP16 v1, float rtol, float atol) { + float f0 = v0.ToFloat(), f1 = v1.ToFloat(); + return std::abs(f0 - f1) <= std::abs(f1 * rtol) + atol; + } + + template + float GetBVal(const MLAS_FP16* B, size_t n, size_t k) { + size_t i; + if ((N & (~7)) > n) { + size_t full8 = n & (~7); + i = full8 * ldb + 8 * k + (n - full8); + } else { + i = n * ldb + k; + } + return B[i].ToFloat(); + } + + template + void MatMul(const MLAS_FP16* A, const MLAS_FP16* B, const MLAS_FP16* bias, MLAS_FP16* C) { + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + float accu = UseBias ? bias[n] : 0.0f; + for (size_t k = 0; k < K; ++k) { + float a = A[m * K + k].ToFloat(); + float b = GetBVal(B, n, k); + accu = accu + a * b; + } + C[m * N + n] = MLAS_FP16(accu); + } + } + } + + template + MLAS_FORCEINLINE void Check(const MLAS_FP16* target, const MLAS_FP16* ref) { + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + size_t i = m * Ldc + n; + ASSERT_TRUE(FloatEqual(target[i], ref[i], 0.02f, 0.055f)) + << " seed " << seed_ + << " v0 " << target[i] << " v1 " << ref[i] + << " m " << m << " n " << n; + } + } + } + + template + void TestHQ4BitGemmKernel() { + static_assert(M <= 2); + constexpr size_t BlkNum = (K + BlkLen - 1) / BlkLen; + constexpr size_t ldb = BlkNum * BlkLen; + + const auto* A = A_.GetFilledBuffer(M * K, [this](MLAS_FP16* p, size_t t) { + InitializeBuffer(p, -0.25f, 0.25f, t); + }); + const auto* B = B_.GetFilledBuffer(ldb * N, [this](MLAS_FP16* p, size_t t) { + InitializeBuffer(p, -0.25f, 0.25f, t); + }); + auto* C = C_.GetBuffer(M * N, true); + auto* ref = ref_.GetBuffer(M * N, true); + auto* bias = bias_.GetFilledBuffer(N, [this](MLAS_FP16* p, size_t t) { + InitializeBuffer(p, -5.0f, 5.0f, t); + }); + + GetMlasPlatform().QNBitGemmDispatch->HQ4BitGemmKernel_CompFp16( + A, B, UseBias ? bias : nullptr, C, M, N, K, K, ldb, N); + + MatMul(A, B, bias, ref); + Check(C, ref); + } + + public: + MlasNeonFp16HQ4BitGemmKernelTest() + : seed_(19287), gen_(seed_) { + } + + static const char* GetTestSuiteName() { + return "NeonFp16HQ4BitGemmKernel"; + } + + template + void ExecuteShort_T(void) { + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + } + + void ExecuteShort(void) override { + ExecuteShort_T<1>(); + ExecuteShort_T<2>(); + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + if (GetMlasPlatform().QNBitGemmDispatch) { + if (GetMlasPlatform().QNBitGemmDispatch->HQ4BitGemmPackQuantBData) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + if (GetMlasPlatform().QNBitGemmDispatch->HQ4BitBlkDequantBForHgemm_CompFp16) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + if (GetMlasPlatform().QNBitGemmDispatch->HQ4BitGemmKernel_CompFp16) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + } + } + return count; +}); + +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 0710981fa17c6..e22018ae2877f 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -18,11 +18,11 @@ Module Name: #include "mlas_q4.h" #include "mlas_qnbit.h" -static constexpr const char* ComputeTypeName(MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType) { +static constexpr const char* ComputeTypeName(MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType) { switch (ComputeType) { - case CompFp32: + case SQNBIT_CompFp32: return "Fp32"; - case CompInt8: + case SQNBIT_CompInt8: return "Int8"; default: return "unknown"; @@ -63,16 +63,16 @@ class MlasSQNBitGemmTest : public MlasTestBase { float* C, size_t ldc, void* Workspace, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, MLAS_THREADPOOL* Threadpool) { - MLAS_SQNBIT_GEMM_DATA_PARAMS params; + MLAS_QNBIT_GEMM_DATA_PARAMS params; params.A = A; params.lda = lda; params.Bias = Bias; params.C = C; params.ldc = ldc; #ifdef MLAS_TARGET_AMD64_IX86 - if (ComputeType == CompInt8) { + if (ComputeType == SQNBIT_CompInt8) { params.QuantBDataWorkspace = PackedQuantBDataWorkspace; } #endif @@ -81,7 +81,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { params.QuantBZeroPoint = QuantBZeroPoint; params.PostProcessor = nullptr; - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace, Threadpool); + MlasQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace, Threadpool); } void QuantizeA(size_t M, size_t K, const float* A, int8_t* QuantAData, float* QuantAScale) { @@ -201,7 +201,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { public: void Test(size_t M, size_t N, size_t K, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, bool WithThreadpool, bool Symmetric, bool WithBias) { MLAS_THREADPOOL* Threadpool = WithThreadpool ? GetMlasThreadPool() : nullptr; @@ -265,19 +265,19 @@ class MlasSQNBitGemmTest : public MlasTestBase { } void* Workspace = nullptr; - if (const auto WorkspaceSize = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); + if (const auto WorkspaceSize = MlasQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); WorkspaceSize > 0) { Workspace = BufferWorkspace.GetBuffer(WorkspaceSize); } void* PackedQuantBDataWorkspace = nullptr; - if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); + if (const auto PackedQuantBDataSize = MlasQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); PackedQuantBDataSize > 0) { PackedQuantBDataWorkspace = BufferPackedQuantBData.GetBuffer(PackedQuantBDataSize); bool has_zp_input = QuantBZeroPoint != nullptr; - MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData, PackedQuantBDataWorkspace, - QuantBScale, has_zp_input, QuantBZeroPoint, - GetMlasThreadPool()); + MlasQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData, PackedQuantBDataWorkspace, + QuantBScale, has_zp_input, QuantBZeroPoint, + GetMlasThreadPool()); } CallGemm(M, N, K, @@ -289,9 +289,9 @@ class MlasSQNBitGemmTest : public MlasTestBase { ComputeType, Threadpool); - if (ComputeType == CompFp32) { + if (ComputeType == SQNBIT_CompFp32) { CallReferenceGemm_CompFp32(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference); - } else if (ComputeType == CompInt8) { + } else if (ComputeType == SQNBIT_CompInt8) { CallReferenceGemm_CompInt8(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference); } else { FAIL() << "Test is not implemented for compute type " @@ -324,7 +324,7 @@ template class SQNBitGemmShortExecuteTest : public MlasTestFixture> { public: explicit SQNBitGemmShortExecuteTest(size_t M, size_t N, size_t K, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, bool WithThreadpool, bool Symmetric, bool WithBias) : M_(M), N_(N), @@ -341,11 +341,11 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture= range ? FillValue - range : FillValue; } }); } diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 93a1bf9f30651..99c3e44e13013 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -454,8 +454,11 @@ int real_main(int argc, char* argv[], Ort::Env& env) { if (ep_context_enable) sf.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - if (disable_ep_context_embed_mode) + if (disable_ep_context_embed_mode) { sf.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0"); + } else { + sf.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "1"); + } for (auto& it : session_config_entries) { sf.AddConfigEntry(it.first.c_str(), it.second.c_str()); @@ -631,7 +634,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) { } if (enable_coreml) { #ifdef USE_COREML - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(sf, 0)); + sf.AppendExecutionProvider("CoreML", {}); #else fprintf(stderr, "CoreML is not supported in this build"); return -1; diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 3aec0d5a67e94..2ff0b599beebf 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -831,7 +831,8 @@ static void VerifyConstantFoldingWithDequantizeLinear(const std::unordered_mapName() == "ConstantFolding") { ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(transformer), TransformerLevel::Level1)); @@ -1764,6 +1765,35 @@ TEST_F(GraphTransformationTests, FuseMatmulBNDirectly) { } } +TEST_F(GraphTransformationTests, DoNotApplyFuseMatmulBNDirectly) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-directly-dont-fuse.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::string expected_output_name; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "BatchNormalization") { + expected_output_name = node.OutputDefs()[0]->Name(); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 1); + ASSERT_EQ(op_to_count["MatMul"], 1); + ASSERT_EQ(op_to_count["Gemm"], 0); +} + TEST_F(GraphTransformationTests, FuseMatmulBNWithOnlyReshape) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-reshape.onnx"; @@ -4675,7 +4705,8 @@ TEST_F(GraphTransformationTests, BiasGeluSwitchedInputOrder) { // Compare results double per_sample_tolerance = 1e-3; double relative_per_sample_tolerance = 0.0; - auto ret = CompareOrtValue(optimized_fetches[0], unoptimized_fetches[0], per_sample_tolerance, relative_per_sample_tolerance, false); + auto ret = CompareOrtValue(optimized_fetches[0], unoptimized_fetches[0], + per_sample_tolerance, relative_per_sample_tolerance, false); EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second; } @@ -4684,7 +4715,8 @@ static void VerifyGeluApproximation(bool is_enabled, SessionOptions& session_opt std::make_unique(CPUExecutionProviderInfo()); bool has_gelu_approximation = false; - auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, session_options, *e.get(), {}); + auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, session_options, *e.get(), + DefaultLoggingManager().DefaultLogger(), {}); for (auto& transformer : transformers) { if (transformer->Name() == "GeluApproximation") { has_gelu_approximation = true; @@ -4699,7 +4731,8 @@ TEST_F(GraphTransformationTests, DoubleQDQRemover_SessionOptionConfig) { auto verify_session_config = [&](bool is_enabled, SessionOptions& session_option) { std::unique_ptr cpu_ep = std::make_unique(CPUExecutionProviderInfo()); bool has_double_qdq_remover = false; - auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, session_option, *cpu_ep.get(), {}); + auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, session_option, *cpu_ep.get(), + DefaultLoggingManager().DefaultLogger(), {}); for (auto& transformer : transformers) { if (transformer->Name() == "DoubleQDQPairsRemover") { has_double_qdq_remover = true; @@ -5859,6 +5892,22 @@ TEST_F(GraphTransformationTests, MatMulIntegerToFloat16Test) { std::map op_to_count = CountOpsInGraph(graph); EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 1); } + +TEST_F(GraphTransformationTests, MatMulIntegerToFloatLargeTensorTest) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/matmul_integer_to_float_large_tensor.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + for (auto& node : graph.Nodes()) { + node.SetExecutionProviderType(kDmlExecutionProvider); + } + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 0); +} #endif // USE_DML #endif diff --git a/onnxruntime/test/optimizer/graph_transform_utils_test.cc b/onnxruntime/test/optimizer/graph_transform_utils_test.cc index 66b74641e41d3..caa64560426af 100644 --- a/onnxruntime/test/optimizer/graph_transform_utils_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_utils_test.cc @@ -36,9 +36,11 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) { std::string l2_transformer = "ConvActivationFusion"; InlinedHashSet disabled = {l1_rule1, l1_transformer, l2_transformer}; CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{}); + const auto& logger = DefaultLoggingManager().DefaultLogger(); - auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep); - auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, disabled); + auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, logger); + auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, logger, + disabled); // check ConstantFolding transformer was removed ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); @@ -61,8 +63,9 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) { #ifndef DISABLE_CONTRIB_OPS // check that ConvActivationFusion was removed - all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep); - filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, disabled); + all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger); + filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger, + disabled); ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); #endif } diff --git a/onnxruntime/test/optimizer/optimizer_test.cc b/onnxruntime/test/optimizer/optimizer_test.cc index 81c1a4ace1e33..b306f026b2dfd 100644 --- a/onnxruntime/test/optimizer/optimizer_test.cc +++ b/onnxruntime/test/optimizer/optimizer_test.cc @@ -27,6 +27,7 @@ namespace test { TEST(OptimizerTest, Basic) { Model model("OptimizerBasic", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger()); + const logging::Logger& logger = DefaultLoggingManager().DefaultLogger(); auto& graph = model.MainGraph(); constexpr int tensor_dim = 10; @@ -66,22 +67,21 @@ TEST(OptimizerTest, Basic) { auto cpu_execution_provider = std::make_unique(CPUExecutionProviderInfo()); #if !defined(DISABLE_SPARSE_TENSORS) - OptimizerExecutionFrame::Info info(nodes, initialized_tensor_set, - graph.ModelPath(), - *cpu_execution_provider.get(), - [&graph](const std::string& name) -> bool { - return graph.IsSparseInitializer(name); - }); + OptimizerExecutionFrame::Info info( + nodes, initialized_tensor_set, graph.ModelPath(), *cpu_execution_provider.get(), + [&graph](const std::string& name) -> bool { + return graph.IsSparseInitializer(name); + }, + logger); #else - OptimizerExecutionFrame::Info info(nodes, initialized_tensor_set, - graph.ModelPath(), - *cpu_execution_provider.get(), - [](std::string const&) { return false; }); + OptimizerExecutionFrame::Info info( + nodes, initialized_tensor_set, graph.ModelPath(), *cpu_execution_provider.get(), + [](std::string const&) { return false; }, + logger); #endif //! defined(DISABLE_SPARSE_TENSORS) std::vector fetch_mlvalue_idxs{info.GetMLValueIndex("out")}; OptimizerExecutionFrame frame(info, fetch_mlvalue_idxs); - const logging::Logger& logger = DefaultLoggingManager().DefaultLogger(); const ConfigOptions empty_config_options; diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index d07977d4b97b8..043b92d7ef121 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -11,6 +11,7 @@ #include "core/graph/onnx_protobuf.h" #include "core/mlas/inc/mlas.h" #include "core/optimizer/double_qdq_pairs_remover.h" +#include "core/optimizer/qdq_transformer/bias_quantization.h" #include "core/optimizer/qdq_transformer/qdq_final_cleanup.h" #include "core/optimizer/qdq_transformer/qdq_propagation.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" @@ -3927,6 +3928,7 @@ TEST(QDQTransformerTests, QDQPropagation_DQForward_SliceMultipleConsumers) { TEST(QDQTransformerTests, QDQ_Selector_Test) { const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/transform/qdq_conv.onnx"); + const auto& logger = DefaultLoggingManager().DefaultLogger(); SessionOptions so; // We want to keep the graph un-optimized to prevent QDQ transformer to kick in @@ -3961,7 +3963,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) { // Check if SelectorManager get a conv qdq group selection as expected { - const auto result = selector_mgr.GetQDQSelections(whole_graph_viewer); + const auto result = selector_mgr.GetQDQSelections(whole_graph_viewer, logger); ASSERT_FALSE(result.empty()); const auto& qdq_group = result.at(0); ASSERT_EQ(std::vector({0, 1, 2}), qdq_group.dq_nodes); @@ -3976,7 +3978,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) { std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(whole_graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(whole_graph_viewer, logger); // We should get a single QDQ Node unit in the result ASSERT_EQ(1, node_unit_holder.size()); @@ -4044,7 +4046,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) { // Check SelectorManager will get empty result { - const auto result = selector_mgr.GetQDQSelections(partial_graph_viewer); + const auto result = selector_mgr.GetQDQSelections(partial_graph_viewer, logger); ASSERT_TRUE(result.empty()); } } @@ -4846,5 +4848,95 @@ TEST(QDQTransformerTests, DropDQSelectorWithDQProducingGraphOutput) { } #endif // !defined(DISABLE_CONTRIB_OPS) +TEST(QDQTransformerTests, BiasQuantization_Conv) { + auto test_case = [](bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + NodeArg* input_arg = builder.MakeInput({1, 24, 128, 128}, std::numeric_limits::min(), + std::numeric_limits::max()); + NodeArg* weight_arg = builder.MakeInitializer({24, 1, 3, 3}, std::numeric_limits::min(), + std::numeric_limits::max()); + NodeArg* bias_arg = builder.MakeInitializer({24}, -0.1f, 0.1f); + NodeArg* input_dq_arg = builder.MakeIntermediate(); + NodeArg* weight_dq_arg = builder.MakeIntermediate(); + NodeArg* conv_dq_arg = builder.MakeIntermediate(); + NodeArg* output_arg = builder.MakeOutput(); + + builder.AddDequantizeLinearNode(input_arg, 0.07f, static_cast(0), input_dq_arg, + use_contrib_qdq); + auto& weight_dq_node = builder.AddDequantizeLinearNode(weight_arg, std::vector(24, 0.05f), + std::vector(24, static_cast(0)), + weight_dq_arg, nullptr, use_contrib_qdq); + weight_dq_node.AddAttribute("axis", static_cast(0)); + auto& conv_node = builder.AddNode("Conv", {input_dq_arg, weight_dq_arg, bias_arg}, {conv_dq_arg}); + conv_node.AddAttribute("dilations", std::vector{1, 1}); + conv_node.AddAttribute("kernel_shape", std::vector{3, 3}); + conv_node.AddAttribute("strides", std::vector{1, 1}); + conv_node.AddAttribute("group", static_cast(24)); + conv_node.AddAttribute("pads", std::vector{1, 1, 1, 1}); + builder.AddQuantizeLinearNode(conv_dq_arg, 0.14f, static_cast(127), output_arg, + use_contrib_qdq); + }; + + auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + EXPECT_EQ(op_to_count["QLinearConv"], 1); + }; + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 18); + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 19); + }; + + test_case(false); +#if !defined(DISABLE_CONTRIB_OPS) + test_case(true); +#endif +} + +TEST(QDQTransformerTests, BiasQuantization_Gemm) { + auto test_case = [](bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + NodeArg* input_arg = + builder.MakeInput({1, 32}, std::numeric_limits::min(), std::numeric_limits::max()); + NodeArg* weight_arg = builder.MakeInitializer({16, 32}, std::numeric_limits::min(), + std::numeric_limits::max()); + NodeArg* bias_arg = builder.MakeInitializer({16}, -0.1f, 0.1f); + NodeArg* input_dq_arg = builder.MakeIntermediate(); + NodeArg* weight_dq_arg = builder.MakeIntermediate(); + NodeArg* gemm_dq_arg = builder.MakeIntermediate(); + NodeArg* output_arg = builder.MakeOutput(); + + builder.AddDequantizeLinearNode(input_arg, 0.001f, static_cast(0), input_dq_arg, + use_contrib_qdq); + builder.AddDequantizeLinearNode(weight_arg, 0.26f, static_cast(0), weight_dq_arg, + use_contrib_qdq); + auto& gemm_node = builder.AddNode("Gemm", {input_dq_arg, weight_dq_arg, bias_arg}, {gemm_dq_arg}); + gemm_node.AddAttribute("transB", static_cast(1)); + builder.AddQuantizeLinearNode(gemm_dq_arg, 0.144f, static_cast(69), output_arg, + use_contrib_qdq); + }; + + auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 1); + }; + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 18); + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 19); + }; + + test_case(false); +#if !defined(DISABLE_CONTRIB_OPS) + test_case(true); +#endif +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/optimizer/transpose_optimizer_test.cc index 35ba1a3369597..f6fce37322c10 100644 --- a/onnxruntime/test/optimizer/transpose_optimizer_test.cc +++ b/onnxruntime/test/optimizer/transpose_optimizer_test.cc @@ -22,6 +22,7 @@ #include "test/optimizer/graph_transform_test_builder.h" #include "test/providers/internal_testing/internal_testing_execution_provider.h" #include "test/util/include/asserts.h" +#include "test/util/include/default_providers.h" #include "test/util/include/inference_session_wrapper.h" #include "test/util/include/test_utils.h" @@ -3800,6 +3801,46 @@ TEST(TransposeOptimizerTests, TestCast) { /*opset_version*/ {15, 18}); } +TEST(TransposeOptimizerTests, TestQLinearSoftmax) { + auto build_test_case_1 = [&](ModelTestBuilder& builder) { + auto* input0_arg = MakeInput(builder, std::nullopt, {1, 384, 384, 21}, 0, 255); + auto* transpose_1_out_0 = builder.MakeIntermediate(); + auto* input_x_scale = builder.MakeScalarInitializer(0.5086354613304138); + auto* input_x_zero_point = builder.MakeScalarInitializer(74); + auto* input_y_scale = builder.MakeScalarInitializer(0.003921568859368563); + auto* input_y_zero_point = builder.MakeScalarInitializer(0); + auto* qlinearsoftmax_1_out_0 = builder.MakeIntermediate(); + auto* transpose_2_out_0 = builder.MakeOutput(); + + auto& transpose_1 = builder.AddNode("Transpose", {input0_arg}, {transpose_1_out_0}); + transpose_1.AddAttribute("perm", std::vector{0, 3, 1, 2}); + auto& qlinearsoftmax_1 = builder.AddNode("QLinearSoftmax", + {transpose_1_out_0, input_x_scale, input_x_zero_point, input_y_scale, input_y_zero_point}, + {qlinearsoftmax_1_out_0}, kMSDomain); + qlinearsoftmax_1.AddAttribute("axis", static_cast(1)); + qlinearsoftmax_1.AddAttribute("opset", static_cast(13)); + auto& transpose_2 = builder.AddNode("Transpose", {qlinearsoftmax_1_out_0}, {transpose_2_out_0}); + transpose_2.AddAttribute("perm", std::vector{0, 2, 3, 1}); + }; + + auto check_optimized_graph_1 = [&](InferenceSessionWrapper& session) { + int transpose_cost = EstimateTransposeCost(session.GetGraph()); + EXPECT_EQ(transpose_cost, 0); + }; + + TransformerTester(build_test_case_1, + check_optimized_graph_1, + TransformerLevel::Level2, + TransformerLevel::Level3, + /*opset_version*/ 13, + /*per_sample_tolerance*/ 0.0, + /*relative_per_sample_tolerance*/ 0.0, + /*transformer*/ nullptr, + /*add_session_options*/ {}, + /*disabled_optimizers*/ {}, + /*ep*/ DefaultCpuExecutionProvider()); +} + TEST(TransposeOptimizerTests, TestBroadcastReusedInputs) { auto build_test_case_1 = [&](ModelTestBuilder& builder) { auto* input0_arg = MakeInput(builder, {{-1, -1, 3, 4}}, {1, 2, 3, 4}, 0.0, 1.0); diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index e40544d950ed7..23c3812ebd025 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -24,6 +24,7 @@ #include #include "test_configuration.h" +#include "strings_helper.h" namespace onnxruntime { namespace perftest { @@ -36,7 +37,7 @@ namespace perftest { "\t\tProvide 'duration' to run the test for a fix duration, and 'times' to repeated for a certain times. \n" "\t-M: Disable memory pattern.\n" "\t-A: Disable memory arena\n" - "\t-I: Generate tensor input binding (Free dimensions are treated as 1.)\n" + "\t-I: Generate tensor input binding. Free dimensions are treated as 1 unless overridden using -f.\n" "\t-c [parallel runs]: Specifies the (max) number of runs to invoke simultaneously. Default:1.\n" "\t-e [cpu|cuda|dnnl|tensorrt|openvino|dml|acl|nnapi|coreml|qnn|snpe|rocm|migraphx|xnnpack|vitisai|webgpu]: Specifies the provider 'cpu','cuda','dnnl','tensorrt', " "'openvino', 'dml', 'acl', 'nnapi', 'coreml', 'qnn', 'snpe', 'rocm', 'migraphx', 'xnnpack', 'vitisai' or 'webgpu'. " @@ -100,6 +101,7 @@ namespace perftest { "\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n" "\t [QNN only] [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" "\t Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). \n" + "\t [QNN only] [enable_htp_spill_fill_buffer]: Enable HTP spill file buffer, used while generating QNN context binary." "\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n" "\n" "\t [TensorRT only] [trt_max_partition_iterations]: Maximum iterations for TensorRT parser to get capability.\n" @@ -129,8 +131,14 @@ namespace perftest { "\t [NNAPI only] [NNAPI_FLAG_CPU_ONLY]: Using CPU only in NNAPI EP.\n" "\t [Example] [For NNAPI EP] -e nnapi -i \"NNAPI_FLAG_USE_FP16 NNAPI_FLAG_USE_NCHW NNAPI_FLAG_CPU_DISABLED\"\n" "\n" - "\t [CoreML only] [COREML_FLAG_CREATE_MLPROGRAM COREML_FLAG_USE_CPU_ONLY COREML_FLAG_USE_CPU_AND_GPU]: Create an ML Program model instead of Neural Network.\n" - "\t [Example] [For CoreML EP] -e coreml -i \"COREML_FLAG_CREATE_MLPROGRAM\"\n" + "\t [CoreML only] [ModelFormat]:[MLProgram, NeuralNetwork] Create an ML Program model or Neural Network. Default is NeuralNetwork.\n" + "\t [CoreML only] [MLComputeUnits]:[CPUAndNeuralEngine CPUAndGPU ALL CPUOnly] Specify to limit the backend device used to run the model.\n" + "\t [CoreML only] [AllowStaticInputShapes]:[0 1].\n" + "\t [CoreML only] [EnableOnSubgraphs]:[0 1].\n" + "\t [CoreML only] [SpecializationStrategy]:[Default FastPrediction].\n" + "\t [CoreML only] [ProfileComputePlan]:[0 1].\n" + "\t [CoreML only] [AllowLowPrecisionAccumulationOnGPU]:[0 1].\n" + "\t [Example] [For CoreML EP] -e coreml -i \"ModelFormat|MLProgram MLComputeUnits|CPUAndGPU\"\n" "\n" "\t [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n" "\t [SNPE only] [priority]: execution priority, options: 'low', 'normal'. \n" @@ -175,39 +183,6 @@ static bool ParseDimensionOverride(std::basic_string& dim_identifier, return true; } -static bool ParseSessionConfigs(const std::string& configs_string, - std::unordered_map& session_configs) { - std::istringstream ss(configs_string); - std::string token; - - while (ss >> token) { - if (token == "") { - continue; - } - - std::string_view token_sv(token); - - auto pos = token_sv.find("|"); - if (pos == std::string_view::npos || pos == 0 || pos == token_sv.length()) { - // Error: must use a '|' to separate the key and value for session configuration entries. - return false; - } - - std::string key(token_sv.substr(0, pos)); - std::string value(token_sv.substr(pos + 1)); - - auto it = session_configs.find(key); - if (it != session_configs.end()) { - // Error: specified duplicate session configuration entry: {key} - return false; - } - - session_configs.insert(std::make_pair(std::move(key), std::move(value))); - } - - return true; -} - /*static*/ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int argc, ORTCHAR_T* argv[]) { int ch; while ((ch = getopt(argc, argv, ORT_TSTR("m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqznlR:"))) != -1) { @@ -382,7 +357,13 @@ static bool ParseSessionConfigs(const std::string& configs_string, test_config.run_config.intra_op_thread_affinities = ToUTF8String(optarg); break; case 'C': { - if (!ParseSessionConfigs(ToUTF8String(optarg), test_config.run_config.session_config_entries)) { + ORT_TRY { + ParseSessionConfigs(ToUTF8String(optarg), test_config.run_config.session_config_entries); + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + fprintf(stderr, "Error parsing session configuration entries: %s\n", ex.what()); + }); return false; } break; diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index e69c87b2540e5..a96028ed3903e 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -17,6 +17,7 @@ #include #include "providers.h" #include "TestCase.h" +#include "strings_helper.h" #ifdef USE_OPENVINO #include "nlohmann/json.hpp" @@ -58,6 +59,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device Ort::SessionOptions session_options; provider_name_ = performance_test_config.machine_config.provider_type_name; + std::unordered_map provider_options; if (provider_name_ == onnxruntime::kDnnlExecutionProvider) { #ifdef USE_DNNL // Generate provider options @@ -72,24 +74,10 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; #endif // defined(_MSC_VER) int num_threads = 0; - std::istringstream ss(ov_string); - std::string token; - while (ss >> token) { - if (token == "") { - continue; - } - auto pos = token.find("|"); - if (pos == std::string::npos || pos == 0 || pos == token.length()) { - ORT_THROW( - "[ERROR] [OneDNN] Use a '|' to separate the key and value for the " - "run-time option you are trying to use.\n"); - } - - auto key = token.substr(0, pos); - auto value = token.substr(pos + 1); - - if (key == "num_of_threads") { - std::stringstream sstream(value); + ParseSessionConfigs(ov_string, provider_options, {"num_of_threads"}); + for (const auto& provider_option : provider_options) { + if (provider_option.first == "num_of_threads") { + std::stringstream sstream(provider_option.second); sstream >> num_threads; if (num_threads < 0) { ORT_THROW( @@ -97,10 +85,6 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device " set number of threads or use '0' for default\n"); // If the user doesnt define num_threads, auto detect threads later } - } else { - ORT_THROW( - "[ERROR] [OneDNN] wrong key type entered. " - "Choose from the following runtime key options that are available for OneDNN. ['num_of_threads']\n"); } } dnnl_options.threadpool_args = static_cast(&num_threads); @@ -144,22 +128,10 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #else std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; #endif - std::istringstream ss(ov_string); - std::string token; - while (ss >> token) { - if (token == "") { - continue; - } - auto pos = token.find("|"); - if (pos == std::string::npos || pos == 0 || pos == token.length()) { - ORT_THROW( - "[ERROR] [CUDA] Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); - } - - buffer.emplace_back(token.substr(0, pos)); - option_keys.push_back(buffer.back().c_str()); - buffer.emplace_back(token.substr(pos + 1)); - option_values.push_back(buffer.back().c_str()); + ParseSessionConfigs(ov_string, provider_options); + for (const auto& provider_option : provider_options) { + option_keys.push_back(provider_option.first.c_str()); + option_values.push_back(provider_option.second.c_str()); } Ort::Status status(api.UpdateCUDAProviderOptions(cuda_options, @@ -192,24 +164,11 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #else std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; #endif - std::istringstream ss(ov_string); - std::string token; - while (ss >> token) { - if (token == "") { - continue; - } - auto pos = token.find("|"); - if (pos == std::string::npos || pos == 0 || pos == token.length()) { - ORT_THROW( - "[ERROR] [TensorRT] Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); - } - - buffer.emplace_back(token.substr(0, pos)); - option_keys.push_back(buffer.back().c_str()); - buffer.emplace_back(token.substr(pos + 1)); - option_values.push_back(buffer.back().c_str()); + ParseSessionConfigs(ov_string, provider_options); + for (const auto& provider_option : provider_options) { + option_keys.push_back(provider_option.first.c_str()); + option_values.push_back(provider_option.second.c_str()); } - Ort::Status status(api.UpdateTensorRTProviderOptions(tensorrt_options, option_keys.data(), option_values.data(), option_keys.size())); if (!status.IsOK()) { @@ -239,22 +198,14 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #else std::string option_string = performance_test_config.run_config.ep_runtime_config_string; #endif - std::istringstream ss(option_string); - std::string token; - std::unordered_map qnn_options; - - while (ss >> token) { - if (token == "") { - continue; - } - auto pos = token.find("|"); - if (pos == std::string::npos || pos == 0 || pos == token.length()) { - ORT_THROW("Use a '|' to separate the key and value for the run-time option you are trying to use."); - } - - std::string key(token.substr(0, pos)); - std::string value(token.substr(pos + 1)); - + ParseSessionConfigs(option_string, provider_options, + {"backend_path", "profiling_file_path", "profiling_level", "rpc_control_latency", + "vtcm_mb", "soc_model", "device_id", "htp_performance_mode", "qnn_saver_path", + "htp_graph_finalization_optimization_mode", "qnn_context_priority", "htp_arch", + "enable_htp_fp16_precision", "offload_graph_io_quantization", "enable_htp_spill_fill_buffer"}); + for (const auto& provider_option : provider_options) { + const std::string& key = provider_option.first; + const std::string& value = provider_option.second; if (key == "backend_path" || key == "profiling_file_path") { if (value.empty()) { ORT_THROW("Please provide the valid file path."); @@ -302,7 +253,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device std::string str = str_stream.str(); ORT_THROW("Wrong value for htp_arch. select from: " + str); } - } else if (key == "enable_htp_fp16_precision" || key == "offload_graph_io_quantization") { + } else if (key == "enable_htp_fp16_precision" || key == "offload_graph_io_quantization" || key == "enable_htp_spill_fill_buffer") { std::unordered_set supported_options = {"0", "1"}; if (supported_options.find(value) == supported_options.end()) { std::ostringstream str_stream; @@ -311,16 +262,9 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device std::string str = str_stream.str(); ORT_THROW("Wrong value for ", key, ". select from: ", str); } - } else { - ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', -'profiling_level', 'profiling_file_path', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode', -'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority', 'soc_model', -'htp_arch', 'device_id', 'enable_htp_fp16_precision', 'offload_graph_io_quantization'])"); } - - qnn_options[key] = value; } - session_options.AppendExecutionProvider("QNN", qnn_options); + session_options.AppendExecutionProvider("QNN", provider_options); #else ORT_THROW("QNN is not supported in this build\n"); #endif @@ -331,22 +275,8 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #else std::string option_string = performance_test_config.run_config.ep_runtime_config_string; #endif - std::istringstream ss(option_string); - std::string token; - std::unordered_map snpe_options; - - while (ss >> token) { - if (token == "") { - continue; - } - auto pos = token.find("|"); - if (pos == std::string::npos || pos == 0 || pos == token.length()) { - ORT_THROW("Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); - } - - std::string key(token.substr(0, pos)); - std::string value(token.substr(pos + 1)); - + ParseSessionConfigs(option_string, provider_options, {"runtime", "priority", "buffer_type", "enable_init_cache"}); + for (const auto& provider_option : provider_options) { if (key == "runtime") { std::set supported_runtime = {"CPU", "GPU_FP32", "GPU", "GPU_FLOAT16", "DSP", "AIP_FIXED_TF"}; if (supported_runtime.find(value) == supported_runtime.end()) { @@ -365,14 +295,10 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); if (value != "1") { ORT_THROW("Set to 1 to enable_init_cache."); } - } else { - ORT_THROW("Wrong key type entered. Choose from options: ['runtime', 'priority', 'buffer_type', 'enable_init_cache'] \n"); } - - snpe_options[key] = value; } - session_options.AppendExecutionProvider("SNPE", snpe_options); + session_options.AppendExecutionProvider("SNPE", provider_options); #else ORT_THROW("SNPE is not supported in this build\n"); #endif @@ -416,30 +342,43 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); } else if (provider_name_ == onnxruntime::kCoreMLExecutionProvider) { #ifdef __APPLE__ #ifdef USE_COREML - uint32_t coreml_flags = 0; std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; - std::istringstream ss(ov_string); - - std::string key; - while (ss >> key) { - if (key == "COREML_FLAG_CREATE_MLPROGRAM") { - coreml_flags |= COREML_FLAG_CREATE_MLPROGRAM; - std::cout << "Enabling ML Program.\n"; - } else if (key == "COREML_FLAG_USE_CPU_ONLY") { - coreml_flags |= COREML_FLAG_USE_CPU_ONLY; - std::cout << "CoreML enabled COREML_FLAG_USE_CPU_ONLY.\n"; - } else if (key == "COREML_FLAG_USE_CPU_AND_GPU") { - coreml_flags |= COREML_FLAG_USE_CPU_AND_GPU; - std::cout << "CoreML enabled COREML_FLAG_USE_CPU_AND_GPU.\n"; - } else if (key.empty()) { + static const std::unordered_set available_keys = {kCoremlProviderOption_MLComputeUnits, + kCoremlProviderOption_ModelFormat, + kCoremlProviderOption_RequireStaticInputShapes, + kCoremlProviderOption_EnableOnSubgraphs, + kCoremlProviderOption_SpecializationStrategy, + kCoremlProviderOption_ProfileComputePlan, + kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU}; + ParseSessionConfigs(ov_string, provider_options, available_keys); + + std::unordered_map available_options = { + {"CPUAndNeuralEngine", "1"}, + {"CPUAndGPU", "1"}, + {"CPUOnly", "1"}, + {"ALL", "1"}, + }; + for (const auto& provider_option : provider_options) { + if (provider_option.first == kCoremlProviderOption_MLComputeUnits && + available_options.find(provider_option.second) != available_options.end()) { + } else if (provider_option.first == kCoremlProviderOption_ModelFormat && + (provider_option.second == "MLProgram" || provider_option.second == "NeuralNetwork")) { + } else if (provider_option.first == kCoremlProviderOption_RequireStaticInputShapes && + (provider_option.second == "1" || provider_option.second == "0")) { + } else if (provider_option.first == kCoremlProviderOption_EnableOnSubgraphs && + (provider_option.second == "0" || provider_option.second == "1")) { + } else if (provider_option.first == kCoremlProviderOption_SpecializationStrategy && + (provider_option.second == "Default" || provider_option.second == "FastPrediction")) { + } else if (provider_option.first == kCoremlProviderOption_ProfileComputePlan && + (provider_option.second == "0" || provider_option.second == "1")) { + } else if (provider_option.first == kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU && + (provider_option.second == "0" || provider_option.second == "1")) { } else { - ORT_THROW( - "[ERROR] [CoreML] wrong key type entered. Choose from the following runtime key options " - "that are available for CoreML. ['COREML_FLAG_CREATE_MLPROGRAM'] \n"); + ORT_THROW("Invalid value for option ", provider_option.first, ": ", provider_option.second); } } // COREML_FLAG_CREATE_MLPROGRAM - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, coreml_flags)); + session_options.AppendExecutionProvider("CoreML", provider_options); #else ORT_THROW("CoreML is not supported in this build\n"); #endif @@ -448,34 +387,20 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #endif } else if (provider_name_ == onnxruntime::kDmlExecutionProvider) { #ifdef USE_DML - std::unordered_map dml_options; - dml_options["performance_preference"] = "high_performance"; - dml_options["device_filter"] = "gpu"; - dml_options["disable_metacommands"] = "false"; - dml_options["enable_graph_capture"] = "false"; #ifdef _MSC_VER std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); #else std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; #endif - std::istringstream ss(ov_string); - std::string token; - while (ss >> token) { - if (token == "") { - continue; - } - auto pos = token.find("|"); - if (pos == std::string::npos || pos == 0 || pos == token.length()) { - ORT_THROW("[ERROR] [DML] Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); - } - - auto key = token.substr(0, pos); - auto value = token.substr(pos + 1); - + ParseSessionConfigs(ov_string, provider_options, + {"device_filter", "performance_preference", "disable_metacommands", + "enable_graph_capture", "enable_graph_serialization"}); + for (const auto& provider_option : provider_options) { + const std::string& key = provider_option.first; + const std::string& value = provider_option.second; if (key == "device_filter") { std::set ov_supported_device_types = {"gpu", "npu"}; if (ov_supported_device_types.find(value) != ov_supported_device_types.end()) { - dml_options[key] = value; } else { ORT_THROW( "[ERROR] [DML] You have selected a wrong configuration value for the key 'device_filter'. " @@ -484,7 +409,6 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); } else if (key == "performance_preference") { std::set ov_supported_values = {"default", "high_performance", "minimal_power"}; if (ov_supported_values.find(value) != ov_supported_values.end()) { - dml_options[key] = value; } else { ORT_THROW( "[ERROR] [DML] You have selected a wrong configuration value for the key 'performance_preference'. " @@ -493,7 +417,6 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); } else if (key == "disable_metacommands") { std::set ov_supported_values = {"true", "True", "false", "False"}; if (ov_supported_values.find(value) != ov_supported_values.end()) { - dml_options[key] = value; } else { ORT_THROW( "[ERROR] [DML] You have selected a wrong value for the key 'disable_metacommands'. " @@ -502,7 +425,6 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); } else if (key == "enable_graph_capture") { std::set ov_supported_values = {"true", "True", "false", "False"}; if (ov_supported_values.find(value) != ov_supported_values.end()) { - dml_options[key] = value; } else { ORT_THROW( "[ERROR] [DML] You have selected a wrong value for the key 'enable_graph_capture'. " @@ -519,7 +441,19 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); } } } - session_options.AppendExecutionProvider("DML", dml_options); + if (provider_options.find("performance_preference") == provider_options.end()) { + provider_options["performance_preference"] = "high_performance"; + } + if (provider_options.find("device_filter") == provider_options.end()) { + provider_options["device_filter"] = "gpu"; + } + if (provider_options.find("disable_metacommands") == provider_options.end()) { + provider_options["disable_metacommands"] = "false"; + } + if (provider_options.find("enable_graph_capture") == provider_options.end()) { + provider_options["enable_graph_capture"] = "false"; + } + session_options.AppendExecutionProvider("DML", provider_options); #else ORT_THROW("DML is not supported in this build\n"); #endif @@ -530,21 +464,9 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #else std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; #endif // defined(_MSC_VER) - std::istringstream ss(ov_string); - std::string token; bool enable_fast_math = false; - while (ss >> token) { - if (token == "") { - continue; - } - auto pos = token.find("|"); - if (pos == std::string::npos || pos == 0 || pos == token.length()) { - ORT_THROW("[ERROR] [ACL] Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); - } - - auto key = token.substr(0, pos); - auto value = token.substr(pos + 1); - + ParseSessionConfigs(ov_string, provider_options, {"enable_fast_math"}); + for (const auto& provider_option : provider_options) { if (key == "enable_fast_math") { std::set ov_supported_values = {"true", "True", "false", "False"}; if (ov_supported_values.find(value) != ov_supported_values.end()) { @@ -554,9 +476,6 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); "[ERROR] [ACL] You have selcted an invalid value for the key 'enable_fast_math'. " "Select from 'true' or 'false' \n"); } - } else { - ORT_THROW( - "[ERROR] [ACL] Unrecognized option: ", key); } } Ort::ThrowOnError( @@ -601,8 +520,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #endif } else if (provider_name_ == onnxruntime::kWebGpuExecutionProvider) { #ifdef USE_WEBGPU - session_options.AppendExecutionProvider( - "WebGPU", {{"intra_op_num_threads", std::to_string(performance_test_config.run_config.intra_op_num_threads)}}); + session_options.AppendExecutionProvider("WebGPU", {}); #else ORT_THROW("WebGPU is not supported in this build\n"); #endif @@ -613,24 +531,9 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #else std::string option_string = performance_test_config.run_config.ep_runtime_config_string; #endif - std::istringstream ss(option_string); - std::string token; - std::unordered_map vitisai_session_options; - - while (ss >> token) { - if (token == "") { - continue; - } - auto pos = token.find("|"); - if (pos == std::string::npos || pos == 0 || pos == token.length()) { - ORT_THROW("[ERROR] [VitisAI] Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); - } + ParseSessionConfigs(option_string, provider_options); - std::string key(token.substr(0, pos)); - std::string value(token.substr(pos + 1)); - vitisai_session_options[key] = value; - } - session_options.AppendExecutionProvider_VitisAI(vitisai_session_options); + session_options.AppendExecutionProvider_VitisAI(provider_options); #else ORT_THROW("VitisAI is not supported in this build\n"); #endif diff --git a/onnxruntime/test/perftest/strings_helper.cc b/onnxruntime/test/perftest/strings_helper.cc new file mode 100644 index 0000000000000..e09c8fac70887 --- /dev/null +++ b/onnxruntime/test/perftest/strings_helper.cc @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// Licensed under the MIT License. + +#include +#include + +#include "strings_helper.h" +#include "core/common/common.h" + +namespace onnxruntime { +namespace perftest { + +void ParseSessionConfigs(const std::string& configs_string, + std::unordered_map& session_configs, + const std::unordered_set& available_keys) { + std::istringstream ss(configs_string); + std::string token; + + while (ss >> token) { + if (token == "") { + continue; + } + + std::string_view token_sv(token); + + auto pos = token_sv.find("|"); + if (pos == std::string_view::npos || pos == 0 || pos == token_sv.length()) { + ORT_THROW("Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); + } + + std::string key(token_sv.substr(0, pos)); + std::string value(token_sv.substr(pos + 1)); + + if (available_keys.empty() == false && available_keys.count(key) == 0) { + // Error: unknown option: {key} + std::string available_keys_str; + for (const auto& av_key : available_keys) { + available_keys_str += av_key; + available_keys_str += ", "; + } + ORT_THROW("[ERROR] wrong key type entered : `", key, + "`. The following runtime key options are avaible: [", available_keys_str, "]"); + } + + auto it = session_configs.find(key); + if (it != session_configs.end()) { + // Error: specified duplicate session configuration entry: {key} + ORT_THROW("Specified duplicate session configuration entry: ", key); + } + + session_configs.insert(std::make_pair(std::move(key), std::move(value))); + } +} +} // namespace perftest +} // namespace onnxruntime diff --git a/onnxruntime/test/perftest/strings_helper.h b/onnxruntime/test/perftest/strings_helper.h new file mode 100644 index 0000000000000..0d6c56709fde6 --- /dev/null +++ b/onnxruntime/test/perftest/strings_helper.h @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// Licensed under the MIT License. +#include +#include +#include + +namespace onnxruntime { +namespace perftest { + +void ParseSessionConfigs(const std::string& configs_string, + std::unordered_map& session_configs, + const std::unordered_set& available_keys = {}); +} // namespace perftest +} // namespace onnxruntime diff --git a/onnxruntime/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm b/onnxruntime/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm index 32b4b32e299d6..fa95c1fc52b94 100644 --- a/onnxruntime/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm +++ b/onnxruntime/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm @@ -35,8 +35,9 @@ void testSigmoid(const char* modelPath, bool useCoreML = false, bool useWebGPU = #if COREML_EP_AVAILABLE if (useCoreML) { - const uint32_t flags = COREML_FLAG_USE_CPU_ONLY; - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, flags)); + std::unordered_map provider_options = { + {kCoremlProviderOption_MLComputeUnits, "CPUOnly"}}; + session_options.AppendExecutionProvider("CoreML", provider_options); } #else (void)useCoreML; diff --git a/onnxruntime/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm b/onnxruntime/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm index 86001b6cb50a5..b53a4a2df09b4 100644 --- a/onnxruntime/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm +++ b/onnxruntime/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm @@ -35,8 +35,9 @@ void testSigmoid(const char* modelPath, bool useCoreML = false, bool useWebGPU = #if COREML_EP_AVAILABLE if (useCoreML) { - const uint32_t flags = COREML_FLAG_USE_CPU_ONLY; - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, flags)); + std::unordered_map provider_options = { + {kCoremlProviderOption_MLComputeUnits, "CPUOnly"}}; + session_options.AppendExecutionProvider("CoreML", provider_options); } #else (void)useCoreML; diff --git a/onnxruntime/test/platform/apple/generate_ipa_export_options_plist.py b/onnxruntime/test/platform/apple/generate_ipa_export_options_plist.py new file mode 100644 index 0000000000000..4e5329dd5b09a --- /dev/null +++ b/onnxruntime/test/platform/apple/generate_ipa_export_options_plist.py @@ -0,0 +1,54 @@ +import argparse + +plist_file_content = """ + + + + + method + development + teamID + {team_id} + provisioningProfiles + + ai.onnxruntime.tests.ios-package-test + {provisioning_profile_uuid} + + signingStyle + manual + + +""" +if __name__ == "__main__": + # handle cli args + parser = argparse.ArgumentParser( + "Generates a PList file to the relevant destination. This PList file contains the properties to allow a user to generate an IPA file for the ios-package-test. " + ) + + parser.add_argument("--dest_file", type=str, help="Path to output the PList file to.", required=True) + parser.add_argument( + "--apple_team_id", + type=str, + help="The Team ID associated with the provisioning profile. You should be able to find this from the Apple developer portal under Membership.", + required=True, + ) + parser.add_argument( + "--provisioning_profile_uuid", + type=str, + help="The Provisioning Profile UUID, which can be found in the .mobileprovision file. ", + required=True, + ) + + args = parser.parse_args() + + formatted_plist = plist_file_content.format( + team_id=args.apple_team_id, provisioning_profile_uuid=args.provisioning_profile_uuid + ) + + with open(args.dest_file, "w") as file: + file.write(formatted_plist) + + print("Wrote plist file to ", args.dest_file) + print() + print("Contents of file:") + print(formatted_plist) diff --git a/onnxruntime/test/platform/windows/stacktrace_test.cc b/onnxruntime/test/platform/windows/stacktrace_test.cc index de09dbcf270a9..9b1840f4b5d65 100644 --- a/onnxruntime/test/platform/windows/stacktrace_test.cc +++ b/onnxruntime/test/platform/windows/stacktrace_test.cc @@ -14,7 +14,6 @@ namespace onnxruntime { namespace test { using namespace ::testing; -// TVM is not working with StackTrace now. #if !defined(ORT_NO_EXCEPTIONS) TEST(StacktraceTests, BasicTests) { auto result = ::onnxruntime::GetStackTrace(); diff --git a/onnxruntime/test/providers/base_tester.cc b/onnxruntime/test/providers/base_tester.cc index dea39bc99d3e9..b0958e05dc373 100644 --- a/onnxruntime/test/providers/base_tester.cc +++ b/onnxruntime/test/providers/base_tester.cc @@ -420,6 +420,7 @@ bool SetEpsForAllNodes(Graph& graph, continue; bool found = false; + const auto& logger = DefaultLoggingManager().DefaultLogger(); for (const auto& ep : execution_providers) { auto provider_type = ep->Type(); @@ -438,7 +439,8 @@ bool SetEpsForAllNodes(Graph& graph, } // Check the EP has an impl for the node from builtin registry. - if (KernelRegistry::HasImplementationOf(*ep->GetKernelRegistry(), node, ep->Type(), kernel_type_str_resolver)) { + if (KernelRegistry::HasImplementationOf(*ep->GetKernelRegistry(), node, ep->Type(), kernel_type_str_resolver, + logger)) { found = true; break; } @@ -451,6 +453,7 @@ bool SetEpsForAllNodes(Graph& graph, std::string_view(kMSInternalNHWCDomain), node.SinceVersion(), type_constraint_map, + logger, &kci); if (status.IsOK() && kci != nullptr) { found = true; @@ -463,7 +466,7 @@ bool SetEpsForAllNodes(Graph& graph, std::any_of(custom_registries->cbegin(), custom_registries->cend(), [&](auto reg) { return KernelRegistry::HasImplementationOf(*reg->GetKernelRegistry(), node, ep->Type(), - kernel_type_str_resolver); + kernel_type_str_resolver, logger); })) { found = true; break; @@ -529,6 +532,17 @@ void BaseTester::Run(ExpectResult expect_result, const std::string& expected_fai so.use_deterministic_compute = use_determinism_; so.graph_optimization_level = TransformerLevel::Default; // 'Default' == off + // remove nullptr in execution_providers. + // it's a little ugly but we need to do this because DefaultXXXExecutionProvider() can return nullptr in Runtime. + // And there're many places adding DefaultXXXExecutionProvider() to execution_providers directly. + if (execution_providers != nullptr) { + execution_providers->erase(std::remove(execution_providers->begin(), execution_providers->end(), nullptr), execution_providers->end()); + if (execution_providers->size() == 0) { + // In fact, no ep is needed to run + return; + } + } + Run(so, expect_result, expected_failure_string, excluded_provider_types, run_options, execution_providers, options); } diff --git a/onnxruntime/test/providers/compare_provider_test_utils.cc b/onnxruntime/test/providers/compare_provider_test_utils.cc index 386a5656d8a01..9acb37c24ddd0 100644 --- a/onnxruntime/test/providers/compare_provider_test_utils.cc +++ b/onnxruntime/test/providers/compare_provider_test_utils.cc @@ -53,6 +53,11 @@ void CompareOpTester::CompareWithCPU(const std::string& target_provider_type, SetTestFunctionCalled(); std::unique_ptr target_execution_provider = GetExecutionProvider(target_provider_type); +#if defined(USE_CUDA) && defined(USE_DML) + if (target_execution_provider == nullptr) { + return; + } +#endif ASSERT_TRUE(target_execution_provider != nullptr) << "provider_type " << target_provider_type << " is not supported."; diff --git a/onnxruntime/test/providers/coreml/coreml_basic_test.cc b/onnxruntime/test/providers/coreml/coreml_basic_test.cc index daa24db134114..a8480e7416de5 100644 --- a/onnxruntime/test/providers/coreml/coreml_basic_test.cc +++ b/onnxruntime/test/providers/coreml/coreml_basic_test.cc @@ -4,7 +4,7 @@ #include "core/common/logging/logging.h" #include "core/graph/graph.h" #include "core/graph/graph_viewer.h" -#include "core/providers/coreml/coreml_execution_provider.h" +#include "core/providers/coreml/coreml_provider_factory_creator.h" #include "core/providers/coreml/coreml_provider_factory.h" #include "core/session/inference_session.h" #include "test/common/tensor_op_test_utils.h" @@ -30,11 +30,11 @@ using namespace ::onnxruntime::logging; namespace onnxruntime { namespace test { -// We want to run UT on CPU only to get output value without losing precision to pass the verification -static constexpr uint32_t s_coreml_flags = COREML_FLAG_USE_CPU_ONLY; - -static std::unique_ptr MakeCoreMLExecutionProvider(uint32_t flags = s_coreml_flags) { - return std::make_unique(flags); +static std::unique_ptr MakeCoreMLExecutionProvider( + std::string ModelFormat = "NeuralNetwork", std::string ComputeUnits = "CPUOnly") { + std::unordered_map provider_options = {{kCoremlProviderOption_MLComputeUnits, ComputeUnits}, + {kCoremlProviderOption_ModelFormat, ModelFormat}}; + return CoreMLProviderFactoryCreator::Create(provider_options)->CreateProvider(); } #if !defined(ORT_MINIMAL_BUILD) @@ -127,6 +127,10 @@ TEST(CoreMLExecutionProviderTest, ArgMaxCastTest) { MakeCoreMLExecutionProvider(), feeds, verification_params); + RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(), + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + verification_params); #else TestModelLoad(model_file_name, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); #endif @@ -164,6 +168,11 @@ TEST(CoreMLExecutionProviderTest, ArgMaxUnsupportedCastTest) { MakeCoreMLExecutionProvider(), feeds, verification_params); + + RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(), + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + verification_params); #else TestModelLoad(model_file_name, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::Some); #endif diff --git a/onnxruntime/test/providers/coreml/dynamic_input_test.cc b/onnxruntime/test/providers/coreml/dynamic_input_test.cc index c91ef23650040..8294f65745256 100644 --- a/onnxruntime/test/providers/coreml/dynamic_input_test.cc +++ b/onnxruntime/test/providers/coreml/dynamic_input_test.cc @@ -7,6 +7,7 @@ #include #include "core/providers/coreml/coreml_execution_provider.h" +#include "core/providers/coreml/coreml_provider_factory_creator.h" #include "core/providers/coreml/coreml_provider_factory.h" // for COREMLFlags #include "test/common/random_generator.h" #include "test/providers/model_tester.h" @@ -20,8 +21,8 @@ TEST(CoreMLExecutionProviderDynamicInputShapeTest, MatMul) { auto test = [&](const size_t M) { SCOPED_TRACE(MakeString("M=", M)); - - auto coreml_ep = std::make_unique(0); + std::unordered_map options; + auto coreml_ep = CoreMLProviderFactoryCreator::Create(options)->CreateProvider(); const auto ep_verification_params = EPVerificationParams{ ExpectedEPNodeAssignment::All, @@ -54,8 +55,8 @@ TEST(CoreMLExecutionProviderDynamicInputShapeTest, MobileNetExcerpt) { auto test = [&](const size_t batch_size) { SCOPED_TRACE(MakeString("batch_size=", batch_size)); - - auto coreml_ep = std::make_unique(0); + std::unordered_map options; + auto coreml_ep = CoreMLProviderFactoryCreator::Create(options)->CreateProvider(); const auto ep_verification_params = EPVerificationParams{ ExpectedEPNodeAssignment::All, @@ -87,6 +88,7 @@ TEST(CoreMLExecutionProviderDynamicInputShapeTest, EmptyInputFails) { constexpr auto model_path = ORT_TSTR("testdata/matmul_with_dynamic_input_shape.onnx"); ModelTester tester(CurrentTestName(), model_path); + std::unordered_map options; tester.AddInput("A", {0, 2}, {}); tester.AddOutput("Y", {0, 4}, {}); @@ -94,14 +96,15 @@ TEST(CoreMLExecutionProviderDynamicInputShapeTest, EmptyInputFails) { tester .Config(ModelTester::ExpectResult::kExpectFailure, "the runtime shape ({0,2}) has zero elements. This is not supported by the CoreML EP.") - .ConfigEp(std::make_unique(0)) + .ConfigEp(CoreMLProviderFactoryCreator::Create(options)->CreateProvider()) .RunWithConfig(); } TEST(CoreMLExecutionProviderDynamicInputShapeTest, OnlyAllowStaticInputShapes) { constexpr auto model_path = ORT_TSTR("testdata/matmul_with_dynamic_input_shape.onnx"); - - auto coreml_ep = std::make_unique(COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES); + std::unordered_map options = {{kCoremlProviderOption_RequireStaticInputShapes, "1"}}; + auto coreml_ep = CoreMLProviderFactoryCreator::Create(options)->CreateProvider(); + ; TestModelLoad(model_path, std::move(coreml_ep), // expect no supported nodes because we disable dynamic input shape support diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.h b/onnxruntime/test/providers/cpu/activation/activation_op_test.h index 8ca0f6d845a09..59813f433dc41 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.h +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.h @@ -105,7 +105,12 @@ class ActivationOpTest : public ::testing::Test { std::random_device rd; std::mt19937 gen(rd()); std::uniform_real_distribution dist(low, high); +#ifdef COREML_ENABLE_MLPROGRAM + // please check onnxruntime/onnxruntime/core/providers/coreml/builders/helper.cc:81 + std::vector batch_size_list = {1, 2, 4, 9, 100}; +#else std::vector batch_size_list = {1, 2, 4, 9, 100000}; +#endif for (auto batch_size : batch_size_list) { std::vector vec(batch_size); for (size_t i = 0; i != batch_size; ++i) { diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index b2e9034653746..a74517840097c 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -414,6 +414,28 @@ TEST(MathOpTest, Add_Broadcast_3x2_3x1) { #endif } +TEST(MathOpTest, Add_Broadcast_2x2x2_1x2x2) { + OpTester test("Add"); + + test.AddInput("A", {2, 2, 2}, + {101.0f, 102.0f, + 103.0f, 104.0f, + + 201.0f, 202.0f, + 203.0f, 204.0f}); + test.AddInput("B", {1, 2, 2}, + {010.0f, 020.0f, + 030.0f, 040.0f}); + test.AddOutput("C", {2, 2, 2}, + {111.0f, 122.0f, + 133.0f, 144.0f, + + 211.0f, 222.0f, + 233.0f, 244.0f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + TEST(MathOpTest, Add_Broadcast_2x1x4_1x3x1) { OpTester test("Add"); @@ -2249,6 +2271,21 @@ TEST(MathOpTest, Max_12_MLFloat16_Scalar1) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } +TEST(MathOpTest, Max_12_MLFloat16_Scalar2) { + OpTester test("Max", 12); + test.AddInput("data_0", {1}, + MakeMLFloat16({-1.f})); + test.AddInput("data_1", {}, + MakeMLFloat16({2.f})); + test.AddInput("data_2", {1, 3}, + MakeMLFloat16({-2.f, -3.f, -4.f})); + test.AddInput("data_3", {1, 1, 3}, + MakeMLFloat16({-2.f, -3.f, -4.f})); + test.AddOutput("max", {1, 1, 3}, + MakeMLFloat16({2.f, 2.f, 2.f})); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent +} + TEST(MathOpTest, Max_13_Float16_MatrixVector) { TestFloat16MinMax("Max", {4, 3}, @@ -3181,7 +3218,14 @@ TEST(MathOpTest, Tan) { TEST(MathOpTest, Asin) { OpTester test("Asin"); - float abs_error = DefaultDmlExecutionProvider().get() != nullptr ? 0.0001f : -1.0f; + float abs_error = +#ifdef _WIN32 + // Set abs_error to 0.0001f for built-in function asin() in HLSL based EPs (DML and WebGPU) + DefaultDmlExecutionProvider().get() != nullptr || DefaultWebGpuExecutionProvider().get() != nullptr + ? 0.0001f + : +#endif + -1.0f; TrigFloatTest<::asinf>(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f}, abs_error); } diff --git a/onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc b/onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc new file mode 100644 index 0000000000000..49bb0ae65d1c9 --- /dev/null +++ b/onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc @@ -0,0 +1,294 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +static ONNX_NAMESPACE::TensorProto make_tensor(std::vector array, std::string name) { + ONNX_NAMESPACE::TensorProto array_as_tensor; + array_as_tensor.set_name(name); + array_as_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE); + array_as_tensor.add_dims(array.size()); + for (auto v : array) { + array_as_tensor.add_double_data(v); + } + + return array_as_tensor; +} + +static ONNX_NAMESPACE::TensorProto make_tensor(std::vector array, std::string name) { + ONNX_NAMESPACE::TensorProto array_as_tensor; + array_as_tensor.set_name(name); + array_as_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + array_as_tensor.add_dims(array.size()); + for (auto v : array) { + array_as_tensor.add_float_data(v); + } + + return array_as_tensor; +} + +static ONNX_NAMESPACE::TensorProto make_tensor(std::vector array, std::string name) { + ONNX_NAMESPACE::TensorProto array_as_tensor; + array_as_tensor.set_name(name); + array_as_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8); + array_as_tensor.add_dims(array.size()); + for (const auto v : array) { + array_as_tensor.add_int32_data(v); + } + + return array_as_tensor; +} + +template +void _multiply_update_array(std::vector& data, int n, T inc = 0) { + std::vector copy = data; + data.resize(copy.size() * n); + T cst = 0; + for (int i = 0; i < n; ++i) { + for (size_t j = 0; j < copy.size(); ++j) { + data[j + i * copy.size()] = copy[j] + cst; + } + cst += inc; + } +} + +template +void _multiply_update_childnode(std::vector& childnodes, std::vector& childleafs, std::vector& otherchildleafs, int n) { + int64_t leafs_cnt = 0; + int64_t nodes_cnt = childnodes.size(); + for (auto& childleaf : childleafs) { + if (childleaf) { + leafs_cnt++; + } + } + for (auto& childleaf : otherchildleafs) { + if (childleaf) { + leafs_cnt++; + } + } + + std::vector copy = childnodes; + childnodes.resize(copy.size() * n); + T leafs_cst = 0; + T nodes_cst = 0; + for (int i = 0; i < n; ++i) { + for (size_t j = 0; j < copy.size(); ++j) { + T curr_inc = childleafs[j] ? leafs_cst : nodes_cst; + childnodes[j + i * copy.size()] = copy[j] + curr_inc; + } + + leafs_cst += leafs_cnt; + nodes_cst += nodes_cnt; + } +} + +template +void _multiply_arrays_values(std::vector& data, int64_t val) { + for (auto& curr : data) { + curr *= val; + } +} + +template +void GenTreeAndRunTest(const std::vector& X, const std::vector& Y, const int64_t& aggregate_function, int n_trees = 1) { + OpTester test("TreeEnsemble", 5, onnxruntime::kMLDomain); + int64_t n_targets = 2; + + int64_t post_transform = 0; + std::vector tree_roots = {0}; + std::vector nodes_featureids = {0, 0, 0}; + std::vector nodes_modes = {0, 0, 0}; + std::vector nodes_splits = {3.14f, 1.2f, 4.2f}; + std::vector nodes_truenodeids = {1, 0, 1}; + std::vector nodes_trueleafs = {0, 1, 1}; + std::vector nodes_falsenodeids = {2, 2, 3}; + std::vector nodes_falseleafs = {0, 1, 1}; + + std::vector leaf_targetids = {0, 1, 0, 1}; + std::vector leaf_weights = {5.23f, 12.12f, -12.23f, 7.21f}; + + if (n_trees > 1) { + // Multiplies the number of trees to test the parallelization by trees. + _multiply_update_array(tree_roots, n_trees, (int64_t)nodes_truenodeids.size()); + _multiply_update_array(nodes_featureids, n_trees); + _multiply_update_childnode(nodes_truenodeids, nodes_trueleafs, nodes_falseleafs, n_trees); + _multiply_update_childnode(nodes_falsenodeids, nodes_falseleafs, nodes_trueleafs, n_trees); + _multiply_update_array(nodes_trueleafs, n_trees); + _multiply_update_array(nodes_falseleafs, n_trees); + _multiply_update_array(leaf_targetids, n_trees); + _multiply_update_array(nodes_modes, n_trees); + _multiply_update_array(nodes_splits, n_trees); + _multiply_update_array(leaf_weights, n_trees); + } + + auto nodes_modes_as_tensor = make_tensor(nodes_modes, "nodes_modes"); + auto nodes_splits_as_tensor = make_tensor(nodes_splits, "nodes_splits"); + auto leaf_weights_as_tensor = make_tensor(leaf_weights, "leaf_weight"); + + // add attributes + test.AddAttribute("n_targets", n_targets); + test.AddAttribute("aggregate_function", aggregate_function); + test.AddAttribute("post_transform", post_transform); + test.AddAttribute("tree_roots", tree_roots); + test.AddAttribute("nodes_modes", nodes_modes_as_tensor); + test.AddAttribute("nodes_featureids", nodes_featureids); + test.AddAttribute("nodes_splits", nodes_splits_as_tensor); + test.AddAttribute("nodes_truenodeids", nodes_truenodeids); + test.AddAttribute("nodes_trueleafs", nodes_trueleafs); + test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids); + test.AddAttribute("nodes_falseleafs", nodes_falseleafs); + test.AddAttribute("leaf_targetids", leaf_targetids); + test.AddAttribute("leaf_weights", leaf_weights_as_tensor); + + // fill input data + test.AddInput("X", {3, 2}, X); + test.AddOutput("Y", {3, 2}, Y); + test.Run(); +} + +template +void GenTreeAndRunTestWithSetMembership(const std::vector& X, const std::vector& Y, const int64_t& aggregate_function, int n_trees = 1) { + OpTester test("TreeEnsemble", 5, onnxruntime::kMLDomain); + int64_t n_targets = 4; + + int64_t post_transform = 0; + std::vector tree_roots = {0}; + std::vector nodes_featureids = {0, 0, 0}; + std::vector nodes_truenodeids = {1, 0, 1}; + std::vector nodes_trueleafs = {0, 1, 1}; + std::vector nodes_falsenodeids = {2, 2, 3}; + std::vector nodes_falseleafs = {1, 0, 1}; + std::vector leaf_targetids = {0, 1, 2, 3}; + + std::vector nodes_modes = {0, 6, 6}; + std::vector nodes_splits = {11.f, 232344.f, NAN}; + std::vector membership_values = {1.2f, 3.7f, 8.f, 9.f, NAN, 12.f, 7.f, NAN}; + std::vector leaf_weights = {1.f, 10.f, 1000.f, 100.f}; + + if (n_trees > 1) { + // Multiplies the number of trees to test the parallelization by trees. + _multiply_update_array(tree_roots, n_trees, (int64_t)nodes_truenodeids.size()); + _multiply_update_array(nodes_featureids, n_trees); + _multiply_update_childnode(nodes_truenodeids, nodes_trueleafs, nodes_falseleafs, n_trees); + _multiply_update_childnode(nodes_falsenodeids, nodes_falseleafs, nodes_trueleafs, n_trees); + _multiply_update_array(nodes_trueleafs, n_trees); + _multiply_update_array(nodes_falseleafs, n_trees); + _multiply_update_array(leaf_targetids, n_trees); + _multiply_update_array(nodes_modes, n_trees); + _multiply_update_array(nodes_splits, n_trees); + _multiply_update_array(membership_values, n_trees); + _multiply_update_array(leaf_weights, n_trees); + } + + auto nodes_modes_as_tensor = make_tensor(nodes_modes, "nodes_modes"); + auto nodes_splits_as_tensor = make_tensor(nodes_splits, "nodes_splits"); + auto membership_values_as_tensor = make_tensor(membership_values, "membership_values"); + auto leaf_weights_as_tensor = make_tensor(leaf_weights, "leaf_weight"); + + // add attributes + test.AddAttribute("n_targets", n_targets); + test.AddAttribute("aggregate_function", aggregate_function); + test.AddAttribute("post_transform", post_transform); + test.AddAttribute("tree_roots", tree_roots); + test.AddAttribute("nodes_modes", nodes_modes_as_tensor); + test.AddAttribute("nodes_featureids", nodes_featureids); + test.AddAttribute("nodes_splits", nodes_splits_as_tensor); + test.AddAttribute("membership_values", membership_values_as_tensor); + test.AddAttribute("nodes_truenodeids", nodes_truenodeids); + test.AddAttribute("nodes_trueleafs", nodes_trueleafs); + test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids); + test.AddAttribute("nodes_falseleafs", nodes_falseleafs); + test.AddAttribute("leaf_targetids", leaf_targetids); + test.AddAttribute("leaf_weights", leaf_weights_as_tensor); + + // fill input data + test.AddInput("X", {6, 1}, X); + test.AddOutput("Y", {6, 4}, Y); + test.Run(); +} + +TEST(MLOpTest, TreeEnsembleFloat) { + std::vector X = {1.2f, 3.4f, -0.12f, 1.66f, 4.14f, 1.77f}; + std::vector Y = {5.23f, 0.f, 5.23f, 0.f, 0.f, 12.12f}; + GenTreeAndRunTest(X, Y, 1, 1); + + Y = {15.69f, 0.f, 15.69f, 0.f, 0.f, 36.36f}; + GenTreeAndRunTest(X, Y, 1, 3); +} + +TEST(MLOpTest, TreeEnsembleDouble) { + std::vector X = {1.2f, 3.4f, -0.12f, 1.66f, 4.14f, 1.77f}; + std::vector Y = {5.23f, 0.f, 5.23f, 0.f, 0.f, 12.12f}; + GenTreeAndRunTest(X, Y, 1, 1); + + _multiply_arrays_values(Y, 3); + GenTreeAndRunTest(X, Y, 1, 3); +} + +TEST(MLOpTest, TreeEnsembleSetMembership) { + std::vector X = {1.2f, 3.4f, -0.12f, NAN, 12.0f, 7.0f}; + std::vector Y = { + 1.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 100.f, + 0.f, 0.f, 0.f, 100.f, + 0.f, 0.f, 1000.f, 0.f, + 0.f, 0.f, 1000.f, 0.f, + 0.f, 10.f, 0.f, 0.f}; + GenTreeAndRunTestWithSetMembership(X, Y, 1, 1); + + _multiply_arrays_values(Y, 5); + GenTreeAndRunTestWithSetMembership(X, Y, 1, 5); +} + +TEST(MLOpTest, TreeEnsembleLeafOnly) { + OpTester test("TreeEnsemble", 5, onnxruntime::kMLDomain); + int64_t n_targets = 1; + + int64_t aggregate_function = 1; + int64_t post_transform = 0; + std::vector tree_roots = {0}; + std::vector nodes_modes = {0}; + std::vector nodes_featureids = {0}; + std::vector nodes_splits = {0.f}; + std::vector nodes_truenodeids = {0}; + std::vector nodes_trueleafs = {1}; + std::vector nodes_falsenodeids = {0}; + std::vector nodes_falseleafs = {1}; + + std::vector leaf_targetids = {0}; + std::vector leaf_weights = {6.23f}; + + auto nodes_modes_as_tensor = make_tensor(nodes_modes, "nodes_modes"); + auto nodes_splits_as_tensor = make_tensor(nodes_splits, "nodes_splits"); + auto leaf_weights_as_tensor = make_tensor(leaf_weights, "leaf_weight"); + + // add attributes + test.AddAttribute("n_targets", n_targets); + test.AddAttribute("aggregate_function", aggregate_function); + test.AddAttribute("post_transform", post_transform); + test.AddAttribute("tree_roots", tree_roots); + test.AddAttribute("nodes_modes", nodes_modes_as_tensor); + test.AddAttribute("nodes_featureids", nodes_featureids); + test.AddAttribute("nodes_splits", nodes_splits_as_tensor); + test.AddAttribute("nodes_truenodeids", nodes_truenodeids); + test.AddAttribute("nodes_trueleafs", nodes_trueleafs); + test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids); + test.AddAttribute("nodes_falseleafs", nodes_falseleafs); + test.AddAttribute("leaf_targetids", leaf_targetids); + test.AddAttribute("leaf_weights", leaf_weights_as_tensor); + + // fill input data + std::vector X = {1.f, 4.f}; + std::vector Y = {6.23f, 6.23f}; + + test.AddInput("X", {2, 1}, X); + test.AddOutput("Y", {2, 1}, Y); + test.Run(); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc b/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc index 33c23b53fb5aa..eaf8fea03eaa0 100644 --- a/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc +++ b/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc @@ -679,6 +679,90 @@ TEST(MLOpTest, TreeRegressorSingleTargetSum_as_tensor_precision) { GenTreeAndRunTest1_as_tensor_precision(3); } +TEST(MLOpTest, TreeRegressorCategoricals) { + OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain); + + // tree + int64_t n_targets = 1; + std::vector nodes_featureids = {0, 0, 0, 0, 1, 0, 0}; + std::vector nodes_modes = {"BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "LEAF", "BRANCH_LEQ", "LEAF", "LEAF"}; + std::vector nodes_values = {1, 3, 4, 0, 5.5, 0, 0}; + + std::vector nodes_treeids = {0, 0, 0, 0, 0, 0, 0}; + std::vector nodes_nodeids = {0, 1, 2, 3, 4, 5, 6}; + std::vector nodes_falsenodeids = {1, 2, 3, 0, 5, 0, 0}; + std::vector nodes_truenodeids = {4, 4, 4, 0, 6, 0, 0}; + + std::string post_transform = "NONE"; + std::vector target_ids = {0, 0, 0}; + std::vector target_nodeids = {3, 5, 6}; + std::vector target_treeids = {0, 0, 0}; + std::vector target_weights = {-4.699999809265137, 17.700000762939453, 11.100000381469727}; + + // add attributes + test.AddAttribute("nodes_truenodeids", nodes_truenodeids); + test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids); + test.AddAttribute("nodes_treeids", nodes_treeids); + test.AddAttribute("nodes_nodeids", nodes_nodeids); + test.AddAttribute("nodes_featureids", nodes_featureids); + test.AddAttribute("nodes_values", nodes_values); + test.AddAttribute("nodes_modes", nodes_modes); + test.AddAttribute("target_treeids", target_treeids); + test.AddAttribute("target_nodeids", target_nodeids); + test.AddAttribute("target_ids", target_ids); + test.AddAttribute("target_weights", target_weights); + test.AddAttribute("n_targets", n_targets); + + // fill input data + std::vector X = {3.0f, 6.6f, 1.0f, 5.0f, 5.0f, 5.5f}; + std::vector Y = {17.700000762939453, 11.100000381469727, -4.699999809265137}; + test.AddInput("X", {3, 2}, X); + test.AddOutput("Y", {3, 1}, Y); + test.Run(); +} + +TEST(MLOpTest, TreeRegressorCategoricalsFolding) { + OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain); + + // tree + int64_t n_targets = 1; + std::vector nodes_featureids = {0, 0, 1, 1, 0, 0, 0}; + std::vector nodes_modes = {"BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "LEAF", "LEAF", "LEAF"}; + std::vector nodes_values = {1, 3, 2, 3, 0, 0, 0}; + + std::vector nodes_treeids = {0, 0, 0, 0, 0, 0, 0}; + std::vector nodes_nodeids = {0, 1, 2, 3, 4, 5, 6}; + std::vector nodes_falsenodeids = {1, 2, 3, 4, 0, 0, 0}; + std::vector nodes_truenodeids = {5, 5, 6, 6, 0, 0, 0}; + + std::string post_transform = "NONE"; + std::vector target_ids = {0, 0, 0}; + std::vector target_nodeids = {4, 5, 6}; + std::vector target_treeids = {0, 0, 0}; + std::vector target_weights = {17.700000762939453, 11.100000381469727, -4.699999809265137}; + + // add attributes + test.AddAttribute("nodes_truenodeids", nodes_truenodeids); + test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids); + test.AddAttribute("nodes_treeids", nodes_treeids); + test.AddAttribute("nodes_nodeids", nodes_nodeids); + test.AddAttribute("nodes_featureids", nodes_featureids); + test.AddAttribute("nodes_values", nodes_values); + test.AddAttribute("nodes_modes", nodes_modes); + test.AddAttribute("target_treeids", target_treeids); + test.AddAttribute("target_nodeids", target_nodeids); + test.AddAttribute("target_ids", target_ids); + test.AddAttribute("target_weights", target_weights); + test.AddAttribute("n_targets", n_targets); + + // fill input data + std::vector X = {1.0f, 2.0f, 3.0f, 1.0f, 2.0f, 3.0f, 2.0f, 1.0f}; + std::vector Y = {11.100000381469727, 11.100000381469727, -4.699999809265137, 17.700000762939453}; + test.AddInput("X", {4, 2}, X); + test.AddOutput("Y", {4, 1}, Y); + test.Run(); +} + TEST(MLOpTest, TreeRegressorTrueNodeBeforeNode) { OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain); diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index e3c86a137484f..b46c253fb8ed9 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -491,6 +491,18 @@ ::std::vector<::std::basic_string> GetParameterStrings() { // the number of times these are run to reduce the CI time. provider_names.erase(provider_name_cpu); #endif + +#if defined(USE_CUDA) && defined(USE_DML) + const std::string no_cuda_ep_test = Env::Default().GetEnvironmentVar("NO_CUDA_TEST"); + if (no_cuda_ep_test == "1") { + provider_names.erase(provider_name_cuda); + } + const std::string no_dml_ep_test = Env::Default().GetEnvironmentVar("NO_DML_TEST"); + if (no_dml_ep_test == "1") { + provider_names.erase(provider_name_dml); + } +#endif + std::vector> v; // Permanently exclude following tests because ORT support only opset starting from 7, // Please make no more changes to the list diff --git a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc index b0d97410ac9b3..08c4e608aada3 100644 --- a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc @@ -704,7 +704,7 @@ TEST(BatchNormTest, NonSpatial_Complicated) { } // Only CUDA and ROCm kernels have float 16 support -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) TEST(BatchNormTest, BatchNorm2d_fp16) { vector X{-0.91221f, -0.283559f, 0.937637f, 2.09818f, -0.100199f, -0.608113f, 0.444562f, -1.07505f, 0.940591f, -0.922262f, 0.0931303f, 0.69611f, 1.55187f, 0.159808f, 0.914874f, -1.24856f, -1.98928f, -0.331621f, @@ -765,9 +765,6 @@ TEST(BatchNormTest, BatchNorm2d_fp16) { -0.0989828f, -0.160014f, 0.362077f, 0.0649763f, -0.371465f, 0.727401f, 0.0320011f}; float epsilon = 1e-05f; - OpTester test("BatchNormalization"); - test.AddAttribute("epsilon", epsilon); - vector input_shape{2, 3, 6, 6}; int input_size = 2 * 3 * 6 * 6; @@ -785,13 +782,20 @@ TEST(BatchNormTest, BatchNorm2d_fp16) { ConvertFloatToMLFloat16(var.data(), f_var.data(), 3); ConvertFloatToMLFloat16(expected_output.data(), f_output.data(), input_size); - test.AddInput("X", input_shape, f_X); - test.AddInput("scale", {3}, f_scale); - test.AddInput("B", {3}, f_B); - test.AddInput("mean", {3}, f_mean); - test.AddInput("var", {3}, f_var); - test.AddOutput("output", input_shape, f_output); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + auto run_test = [&](bool is_initializer) { + OpTester test("BatchNormalization"); + test.AddAttribute("epsilon", epsilon); + test.AddInput("X", input_shape, f_X); + test.AddInput("scale", {3}, f_scale, is_initializer); + test.AddInput("B", {3}, f_B, is_initializer); + test.AddInput("mean", {3}, f_mean, is_initializer); + test.AddInput("var", {3}, f_var, is_initializer); + test.AddOutput("output", input_shape, f_output, is_initializer); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + }; + run_test(false); + // coreml EP requires initializer + run_test(true); } #endif diff --git a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc index 25caa732efa25..a3a3dd939cbf0 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc @@ -1,8 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. - +#include "core/graph/constants.h" #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" + using namespace std; namespace onnxruntime { namespace test { @@ -28,7 +29,8 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes, optional epsilon = optional(), OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& err_str = "", - int opset = 7) { + int opset = 7, + bool exclude_cuda_nhwc = false) { OpTester test("Conv", opset); test.AddAttribute("group", attributes.group); test.AddAttribute("kernel_shape", attributes.kernel_shape); @@ -65,6 +67,12 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes, // Disable TensorRT because weight as input is not supported excluded_providers.insert(kTensorrtExecutionProvider); + if (exclude_cuda_nhwc) { +#ifdef ENABLE_CUDA_NHWC_OPS + excluded_providers.insert(kCudaNHWCExecutionProvider); +#endif + } + // QNN SDK 2.10.0 has a bug that breaks support for dynamic bias inputs. excluded_providers.insert(kQnnExecutionProvider); @@ -197,10 +205,15 @@ TEST(ConvTest, Conv1D_Bias) { // as TF32 has a 10 bit mantissa. float epsilon = 1.1e-5f; - TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, false, epsilon); + // This case is not supported by cuDNN frontend, and the fallback (legacy code) requires weight to 4D tensor for NHWC. + constexpr bool exclude_cuda_nhwc = true; + + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, false, epsilon, + OpTester::ExpectResult::kExpectSuccess, "", 10, exclude_cuda_nhwc); // CoreML EP requires weight to be an initializer - TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true, epsilon); + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true, epsilon, + OpTester::ExpectResult::kExpectSuccess, "", 10, exclude_cuda_nhwc); } // Conv47 diff --git a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc index 0ce87fb65898b..83b27f10fe04f 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc @@ -4,6 +4,7 @@ #include "core/providers/xnnpack/xnnpack_init.h" #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" +#include "test/common/tensor_op_test_utils.h" #include "default_providers.h" using namespace std; @@ -130,17 +131,6 @@ TEST(ConvTransposeTest, ConvTranspose_1D) { TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } -template -static std::vector GetTypedArray(std::vector inputs, [[maybe_unused]] T v = T(0.f)) { - if constexpr (std::is_same::value) { - return inputs; - } else { - std::vector inputs_fp16(inputs.size()); - ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size()); - return inputs_fp16; - } -} - TYPED_TEST(ConvTransposeTest, ConvTranspose_2D_outputpadding_strides2) { ConvTransposeOpAttributes attrs = { vector{3, 3}, // kernel_shape diff --git a/onnxruntime/test/providers/cpu/nn/group_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/group_norm_op_test.cc new file mode 100644 index 0000000000000..ac517193a2c77 --- /dev/null +++ b/onnxruntime/test/providers/cpu/nn/group_norm_op_test.cc @@ -0,0 +1,144 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/util/include/default_providers.h" + +#ifdef COREML_ENABLE_MLPROGRAM +using namespace std; +namespace onnxruntime { +namespace test { + +template +class GroupNormalizationOpTest : public ::testing::Test { +}; +using GroupNormalizationOpTestTypes = ::testing::Types; +TYPED_TEST_SUITE(GroupNormalizationOpTest, GroupNormalizationOpTestTypes); + +// GroupSize = channel_dims to simulate InstanceNorm +// Disable TensorRT on some of the tests because its parser doesn't support weight as input +TYPED_TEST(GroupNormalizationOpTest, Equivalent_InstanceNorm_G_C) { + OpTester test("GroupNormalization", 18); + test.AddAttribute("epsilon", 0.3F); + test.AddAttribute("num_groups", int64_t(3)); + + vector input = {3.1513367F, 9.283596F, 1.4546119F, 5.4617004F, + 8.519701F, 1.2382338F, 1.7930176F, 5.1099434F, + 7.9195533F, 7.638727F, 8.065445F, 3.8082376F, + + 2.3667817F, 2.8248506F, 3.7754705F, 5.861325F, + 5.058735F, 3.2787242F, 3.6843839F, 9.755121F, + 2.7902672F, 7.3974323F, 8.283609F, 8.488337F}; + vector input_dims = {2, 3, 4}; + test.AddInput("X", input_dims, GetTypedArray(input)); + + vector scale = {1.F, 1.F, 1.F}; + vector scale_dims = {3}; + test.AddInput("scale", scale_dims, GetTypedArray(scale), true); + + vector B = {0.F, 0.F, 0.F}; + vector B_dims = {3}; + test.AddInput("bias", B_dims, GetTypedArray(B), true); + + // expected output is calculated using torch.nn.GroupNorm(3, 3, eps=0.3) + vector expected_output = {-0.56495477f, 1.48930046f, -1.13334329f, 0.20899761f, + 1.46688162f, -0.98600774f, -0.79911913f, 0.31824524f, + 0.57370438f, 0.42193634f, 0.6525492f, -1.64818992f, + + -0.92380346f, -0.60808484f, 0.04711878f, 1.48476953f, + -0.14644464f, -0.82262872f, -0.66852817f, 1.63760153f, + -1.65898662f, 0.27618144f, 0.64840618f, 0.734399f}; + + test.AddOutput("Y", input_dims, GetTypedArray(expected_output)); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); + // coreml EP requires weight and bias to be initializers + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// GroupSize = 1 to simulate LayerNorm, (LayerNorm) +// expected output is calculated using torch.nn.GroupNorm(1, 3, eps=1e-5f) +TYPED_TEST(GroupNormalizationOpTest, Equivalent_LayerNorm_G_1) { + auto run_test = [](bool is_initializer) { + OpTester test("GroupNormalization", 18); + test.AddAttribute("epsilon", 1e-5f); + test.AddAttribute("num_groups", int64_t(1)); + + std::vector dims{1, 2, 3}; + test.AddInput("x", dims, GetTypedArray({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f})); + test.AddInput("scale", {2}, GetTypedArray({1.0f, 1.0f}), is_initializer); + test.AddInput("bias", {2}, GetTypedArray({2.0f, 1.0f}), is_initializer); + test.AddOutput("output", dims, GetTypedArray({0.5361f, 1.1216f, 1.7072f, 1.2928f, 1.8783f, 2.4638f})); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); + // coreml EP requires weight and bias to be initializers + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + }; + + run_test(true); +} + +// expected output is calculated using torch.nn.GroupNorm(2, 6, eps=0.3) +TYPED_TEST(GroupNormalizationOpTest, GroupSize_N) { + OpTester test("GroupNormalization", 18); + test.AddAttribute("epsilon", 0.3F); + test.AddAttribute("num_groups", int64_t(2)); + + vector input = {-1.1258f, -1.1524f, -0.2506f, -0.4339f, + 0.8487f, 0.6920f, -0.3160f, -2.1152f, + 0.3223f, -1.2633f, 0.3500f, 0.3081f, + 0.1198f, 1.2377f, 1.1168f, -0.2473f, + -1.3527f, -1.6959f, 0.5667f, 0.7935f, + 0.5988f, -1.5551f, -0.3414f, 1.8530f, + + 0.7502f, -0.5855f, -0.1734f, 0.1835f, + 1.3894f, 1.5863f, 0.9463f, -0.8437f, + -0.6136f, 0.0316f, -0.4927f, 0.2484f, + 0.4397f, 0.1124f, 0.6408f, 0.4412f, + -0.1023f, 0.7924f, -0.2897f, 0.0525f, + 0.5229f, 2.3022f, -1.4689f, -1.5867f}; + vector input_dims = {2, 6, 4}; + test.AddInput("X", input_dims, GetTypedArray(input)); + + vector scale = {1.F, 1.F, 1.F, 1.F, 1.F, 1.F}; + vector scale_dims = {6}; + test.AddInput("scale", scale_dims, GetTypedArray(scale), true); + + vector B = {.0F, .0F, .0F, .0F, .0F, .0F}; + vector B_dims = {6}; + test.AddInput("bias", B_dims, GetTypedArray(B), true); + + vector expected_output = { + -0.7590f, -0.7848f, 0.0914f, -0.0867f, + 1.1595f, 1.0073f, 0.0278f, -1.7203f, + 0.6480f, -0.8926f, 0.6749f, 0.6343f, + 0.0232f, 0.9274f, 0.8296f, -0.2738f, + -1.1679f, -1.4456f, 0.3846f, 0.5681f, + 0.4107f, -1.3317f, -0.3499f, 1.4252f, + + 0.5772f, -0.8298f, -0.3957f, -0.0198f, + 1.2505f, 1.4580f, 0.7838f, -1.1017f, + -0.8594f, -0.1798f, -0.7320f, 0.0486f, + 0.2541f, -0.0377f, 0.4334f, 0.2554f, + -0.2291f, 0.5686f, -0.3962f, -0.0911f, + 0.3282f, 1.9145f, -1.4475f, -1.5525f}; + test.AddOutput("Y", input_dims, GetTypedArray(expected_output)); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); + // coreml EP requires weight and bias to be initializers + if constexpr (std::is_same::value) { + test.SetOutputTolerance(1e-4f); + } else { + test.SetOutputTolerance(0.005f); + } + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +} // namespace test +} // namespace onnxruntime +#endif diff --git a/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc index 31f119ec6b0e9..341bb8a4fc957 100644 --- a/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc @@ -3,71 +3,87 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" +#include "test/common/tensor_op_test_utils.h" + using namespace std; namespace onnxruntime { namespace test { -// Disable TensorRT on some of the tests because its parser doesn't support weight as input +template +class InstanceNormalizationOpTest : public ::testing::Test { +}; +using InstanceNormalizationOpTestTypes = ::testing::Types; +TYPED_TEST_SUITE(InstanceNormalizationOpTest, InstanceNormalizationOpTestTypes); -TEST(InstanceNormalizationOpTest, InstanceNorm) { - OpTester test("InstanceNormalization"); - test.AddAttribute("epsilon", 0.3F); - - vector input = {3.1513367F, 9.283596F, 1.4546119F, 5.4617004F, - 8.519701F, 1.2382338F, 1.7930176F, 5.1099434F, - 7.9195533F, 7.638727F, 8.065445F, 3.8082376F, - - 2.3667817F, 2.8248506F, 3.7754705F, 5.861325F, - 5.058735F, 3.2787242F, 3.6843839F, 9.755121F, - 2.7902672F, 7.3974323F, 8.283609F, 8.488337F}; - vector input_dims = {2, 3, 4}; - test.AddInput("input", input_dims, input); - - // vector scale = {2.1F, 0.1F, 1.F}; - vector scale = {1.0F, 1.0F, 1.F}; - vector scale_dims = {3}; - test.AddInput("scale", scale_dims, scale); - - // vector B = {2.3F, 1.5F, 0.F}; - vector B = {0.0F, 0.0F, 0.F}; - vector B_dims = {3}; - test.AddInput("B", B_dims, B); - - vector expected_output = {-0.56495477F, 1.48930046F, -1.13334329F, 0.20899761F, - 1.46688162F, -0.98600774F, -0.79911913F, 0.31824524F, - 0.57370438F, 0.42193634F, 0.6525492F, -1.64818992F, +// Disable TensorRT on some of the tests because its parser doesn't support weight as input - -0.92380346F, -0.60808484F, 0.04711878F, 1.48476953F, - -0.14644464F, -0.82262872F, -0.66852817F, 1.63760153F, - -1.65898662F, 0.27618144F, 0.64840618F, 0.734399F}; - test.AddOutput("Y", input_dims, expected_output); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +TYPED_TEST(InstanceNormalizationOpTest, InstanceNorm) { + auto run_test = [](bool is_initializer) { + OpTester test("InstanceNormalization"); + test.AddAttribute("epsilon", 0.3F); + + vector input = {3.1513367F, 9.283596F, 1.4546119F, 5.4617004F, + 8.519701F, 1.2382338F, 1.7930176F, 5.1099434F, + 7.9195533F, 7.638727F, 8.065445F, 3.8082376F, + + 2.3667817F, 2.8248506F, 3.7754705F, 5.861325F, + 5.058735F, 3.2787242F, 3.6843839F, 9.755121F, + 2.7902672F, 7.3974323F, 8.283609F, 8.488337F}; + vector input_dims = {2, 3, 4}; + test.AddInput("input", input_dims, GetTypedArray(input)); + + // vector scale = {2.1F, 0.1F, 1.F}; + vector scale = {1.0F, 1.0F, 1.F}; + vector scale_dims = {3}; + test.AddInput("scale", scale_dims, GetTypedArray(scale), is_initializer); + + // vector B = {2.3F, 1.5F, 0.F}; + vector B = {0.0F, 0.0F, 0.F}; + vector B_dims = {3}; + test.AddInput("B", B_dims, GetTypedArray(B), is_initializer); + + vector expected_output = {-0.56495477F, 1.48930046F, -1.13334329F, 0.20899761F, + 1.46688162F, -0.98600774F, -0.79911913F, 0.31824524F, + 0.57370438F, 0.42193634F, 0.6525492F, -1.64818992F, + + -0.92380346F, -0.60808484F, 0.04711878F, 1.48476953F, + -0.14644464F, -0.82262872F, -0.66852817F, 1.63760153F, + -1.65898662F, 0.27618144F, 0.64840618F, 0.734399F}; + test.AddOutput("Y", input_dims, GetTypedArray(expected_output)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + }; + run_test(false); + run_test(true); } -TEST(InstanceNormalizationOpTest, InstanceNormBatch1) { - OpTester test("InstanceNormalization"); - test.AddAttribute("epsilon", 0.3F); - - vector input = {3.1513367F, 9.283596F, 1.4546119F, 5.4617004F, - 8.519701F, 1.2382338F, 1.7930176F, 5.1099434F, - 7.9195533F, 7.638727F, 8.065445F, 3.8082376F}; - vector input_dims = {1, 3, 4}; - test.AddInput("input", input_dims, input); - - vector scale = {1.0F, 1.0F, 1.F}; - vector scale_dims = {3}; - test.AddInput("scale", scale_dims, scale); - - vector B = {0.0F, 0.0F, 0.F}; - vector B_dims = {3}; - test.AddInput("B", B_dims, B); - - vector expected_output = {-0.56495477F, 1.48930046F, -1.13334329F, 0.20899761F, - 1.46688162F, -0.98600774F, -0.79911913F, 0.31824524F, - 0.57370438F, 0.42193634F, 0.6525492F, -1.64818992F}; - test.AddOutput("Y", input_dims, expected_output); - - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +TYPED_TEST(InstanceNormalizationOpTest, InstanceNormBatch1) { + auto run_test = [](bool is_initializer) { + OpTester test("InstanceNormalization"); + test.AddAttribute("epsilon", 0.3F); + + vector input = {3.1513367F, 9.283596F, 1.4546119F, 5.4617004F, + 8.519701F, 1.2382338F, 1.7930176F, 5.1099434F, + 7.9195533F, 7.638727F, 8.065445F, 3.8082376F}; + vector input_dims = {1, 3, 4}; + test.AddInput("input", input_dims, GetTypedArray(input)); + + vector scale = {1.0F, 1.0F, 1.F}; + vector scale_dims = {3}; + test.AddInput("scale", scale_dims, GetTypedArray(scale), is_initializer); + + vector B = {0.0F, 0.0F, 0.F}; + vector B_dims = {3}; + test.AddInput("B", B_dims, GetTypedArray(B), is_initializer); + + vector expected_output = {-0.56495477F, 1.48930046F, -1.13334329F, 0.20899761F, + 1.46688162F, -0.98600774F, -0.79911913F, 0.31824524F, + 0.57370438F, 0.42193634F, 0.6525492F, -1.64818992F}; + test.AddOutput("Y", input_dims, GetTypedArray(expected_output)); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + }; + run_test(false); + run_test(true); } TEST(InstanceNormalizationOpTest, InstanceNormBatch2) { @@ -105,7 +121,7 @@ TEST(InstanceNormalizationOpTest, InstanceNormBatch2) { } // Only CUDA and ROCm kernels have float 16 support -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) TEST(InstanceNormalizationOpTest, InstanceNormBatch1_fp16) { OpTester test("InstanceNormalization"); diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc index 0968bc32e0de4..f68a245d103e1 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc +++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc @@ -3,6 +3,7 @@ #include #include +#include #include #include "gtest/gtest.h" #include "test/common/dnnl_op_test_utils.h" @@ -1374,7 +1375,7 @@ TEST(ReductionOpTest, ReduceMax_double) { test.Run(); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) TEST(ReductionOpTest, ReduceMax_half) { OpTester test("ReduceMax"); test.AddAttribute("axes", std::vector{1, 2}); @@ -2157,7 +2158,7 @@ TEST(ReductionOpTest, ReduceMin_double) { test.Run(); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) TEST(ReductionOpTest, ReduceMin_half) { OpTester test("ReduceMin"); test.AddAttribute("axes", std::vector{0, 2}); @@ -2355,7 +2356,7 @@ TEST(ReductionOpTest, ReduceSum_int32) { test.Run(); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) TEST(ReductionOpTest, ReduceSumHalfHalf) { OpTester test("ReduceSum"); test.AddAttribute("keepdims", (int64_t)0); @@ -3175,19 +3176,26 @@ TEST(ReductionOpTest, ReduceProd0DTensor) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } -TEST(ReductionOpTest, ArgMax) { +template +class ReductionOpTest : public ::testing::Test { +}; + +using ReductionOpTestTypes = ::testing::Types; +TYPED_TEST_SUITE(ReductionOpTest, ReductionOpTestTypes); + +TYPED_TEST(ReductionOpTest, ArgMax) { OpTester test("ArgMax"); test.AddAttribute("axis", (int64_t)1); test.AddAttribute("keepdims", (int64_t)1); - test.AddInput("data", {3, 2, 2}, - {1.0f, 2.0f, - 3.0f, 4.0f, + test.AddInput("data", {3, 2, 2}, + GetTypedArray({1.0f, 2.0f, + 3.0f, 4.0f, - 5.0f, 6.0f, - 7.0f, 8.0f, + 5.0f, 6.0f, + 7.0f, 8.0f, - 9.0f, 10.0f, - 11.0f, 12.0f}); + 9.0f, 10.0f, + 11.0f, 12.0f})); test.AddOutput("reduced", {3, 1, 2}, {1, 1, 1, 1, @@ -3330,6 +3338,41 @@ TEST(ReductionOpTest, ArgMax_int32_last_index_dups) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +TEST(ReductionOpTest, ArgMax_float_first_index_random) { + OpTester test("ArgMax", 12); + test.AddAttribute("axis", static_cast(0)); + test.AddAttribute("keepdims", static_cast(1)); + + // Since select_last_index is 0 by default, this test should run on both CPU and CUDA + test.AddAttribute("select_last_index", static_cast(0)); + + constexpr size_t vector_size = 64 * 1024; + constexpr float max_value = std::numeric_limits::infinity(); + + std::random_device rd; + std::mt19937 generator(rd()); + std::uniform_int_distribution distribution(0, static_cast(vector_size) - 1); + + std::vector data_vec(vector_size, 0.0f); + + int min_index = -1; + + // Try replace 8 elements with max_value. It is fine that some elements hit same index. + for (int i = 0; i < 8; ++i) { + int index = distribution(generator); + data_vec[index] = max_value; + if (i == 0 || index < min_index) { + min_index = index; + } + } + + test.AddInput("data", {vector_size}, data_vec); + test.AddOutput("reduced", {1}, {min_index}); + + // Exclude OpenVINO since it failed to handle this case. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + TEST(ReductionOpTest, ArgMax_int32_neg_axis) { OpTester test("ArgMax"); test.AddAttribute("axis", (int64_t)(-2)); @@ -3648,6 +3691,41 @@ TEST(ReductionOpTest, ArgMin_int32_neg_axis) { test.Run(); } +TEST(ReductionOpTest, ArgMin_float_first_index_random) { + OpTester test("ArgMin", 13); + test.AddAttribute("axis", static_cast(0)); + test.AddAttribute("keepdims", static_cast(1)); + + // Since select_last_index is 0 by default, this test should run on both CPU and CUDA + test.AddAttribute("select_last_index", static_cast(0)); + + constexpr size_t vector_size = 64 * 1024; + constexpr float min_value = -std::numeric_limits::infinity(); + + std::random_device rd; + std::mt19937 generator(rd()); + std::uniform_int_distribution distribution(0, static_cast(vector_size) - 1); + + std::vector data_vec(vector_size, 0.0f); + + int min_index = -1; + + // Try replace 8 elements with min_value. It is fine that some elements hit same index. + for (int i = 0; i < 8; ++i) { + int index = distribution(generator); + data_vec[index] = min_value; + if (i == 0 || index < min_index) { + min_index = index; + } + } + + test.AddInput("data", {vector_size}, data_vec); + test.AddOutput("reduced", {1}, {min_index}); + + // Exclude OpenVINO since it failed to handle this case. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + TEST(ReductionOpTest, OptimizeShapeForFastReduce_ReduceDimWithZero1) { FastReduceKind fast_kind; TensorShapeVector fast_shape, fast_output_shape, fast_axes; @@ -5603,7 +5681,7 @@ TEST(ReductionOpTest, ReduceSum_RK_parallel) { test.AddOutput("reduced", {32}, expected); // CoreML does not provide 1e-5 precision here (it's off by 1e-4) - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCoreMLExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess); } TEST(ReductionOpTest, ReduceSum_RK_keepdims) { diff --git a/onnxruntime/test/providers/cpu/rnn/deep_cpu_gru_op_test.cc b/onnxruntime/test/providers/cpu/rnn/deep_cpu_gru_op_test.cc index 30960e71c577f..de2aceda17f83 100644 --- a/onnxruntime/test/providers/cpu/rnn/deep_cpu_gru_op_test.cc +++ b/onnxruntime/test/providers/cpu/rnn/deep_cpu_gru_op_test.cc @@ -13,7 +13,7 @@ using namespace std; namespace onnxruntime { namespace test { -static const std::vector default_activations = {"sigmoid", "tanh"}; +static const std::vector default_activations = {"Sigmoid", "Tanh"}; static void RunGruTest(const std::vector& X_data, const std::vector& W_data, @@ -150,11 +150,6 @@ void DefaultActivationsSimpleWeightsNoBias(std::string direction, } TEST(GRUTest, ForwardDefaultActivationsSimpleWeightsNoBiasTwoRows) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - std::vector Y_data{ 0.4750208f, 0.450166f, 0.4255575f, 0.45016602f, 0.40131235f, 0.35434368f, @@ -173,11 +168,6 @@ TEST(GRUTest, ForwardDefaultActivationsSimpleWeightsNoBiasTwoRows) { } TEST(GRUTest, ReverseDefaultActivationsSimpleWeightsNoBiasTwoRows) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - std::vector Y_data{ 0.6082785f, 0.50623393f, 0.4426924f, 0.5803454f, 0.4527356f, 0.36886263f, @@ -193,11 +183,6 @@ TEST(GRUTest, ReverseDefaultActivationsSimpleWeightsNoBiasTwoRows) { } TEST(GRUTest, BidirectionalDefaultActivationsSimpleWeightsNoBias) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - std::vector Y_data{ // forward output for input sequence 0 0.4750208f, 0.450166f, 0.4255575f, @@ -228,11 +213,6 @@ TEST(GRUTest, BidirectionalDefaultActivationsSimpleWeightsNoBias) { } TEST(GRUTest, BidirectionalDefaultActivationsSimpleWeightsNoBiasLinearBeforeReset) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - std::vector Y_data{ // forward output for input sequence 0 0.4750208f, 0.450166f, 0.4255575f, @@ -317,11 +297,6 @@ void DefaultActivationsSimpleWeightsWithBias(std::string direction, } TEST(GRUTest, ForwardDefaultActivationsSimpleWeightsWithBiasBatchParallel) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - std::vector Y_data{ 0.16783132f, -0.11754231f, 0.11977843f, 0.2046872f, -0.10372487f, 0.15365849f, @@ -333,11 +308,6 @@ TEST(GRUTest, ForwardDefaultActivationsSimpleWeightsWithBiasBatchParallel) { } TEST(GRUTest, ForwardDefaultActivationsSimpleWeightsWithBiasBatchParallelLinearBeforeReset) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - std::vector Y_data{ 0.15024948f, -0.11097029f, -0.02121867f, 0.18887489f, -0.09747667f, 0.02093463f, @@ -350,11 +320,6 @@ TEST(GRUTest, ForwardDefaultActivationsSimpleWeightsWithBiasBatchParallelLinearB } TEST(GRUTest, ReverseDefaultActivationsSimpleWeightsWithBiasBatchParallelLinearBeforeReset) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - std::vector Y_data{ 0.20910699f, -0.18880953f, -0.04005555f, 0.29700265f, -0.15308119f, 0.04537245f, @@ -368,11 +333,6 @@ TEST(GRUTest, ReverseDefaultActivationsSimpleWeightsWithBiasBatchParallelLinearB // test forward !batch_parallel_ path with linear_before_reset TEST(GRUTest, ForwardDefaultActivationsSimpleWeightsWithBiasLinearBeforeReset) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - std::vector Y_data{ 0.15024948f, -0.11097029f, -0.02121867f, 0.19538902f, -0.19016478f, -0.05644283f}; @@ -384,11 +344,6 @@ TEST(GRUTest, ForwardDefaultActivationsSimpleWeightsWithBiasLinearBeforeReset) { // test reverse !batch_parallel_ path with linear_before_reset TEST(GRUTest, ReverseDefaultActivationsSimpleWeightsWithBiasLinearBeforeReset) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - std::vector Y_data{ 0.20910699f, -0.18880953f, -0.04005555f, 0.12252139f, -0.12032216f, -0.05064924f}; @@ -588,13 +543,8 @@ void DeepCpuGruOpTestContext::RunTest(const std::vector& X, } TEST(GRUTest, ONNXRuntime_TestGRUOpForwardBasic) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "forward"; - const std::vector activations = {"sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -611,13 +561,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpForwardBasic) { } TEST(GRUTest, ONNXRuntime_TestGRUOpBackwardBasic) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "reverse"; - const std::vector activations = {"sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -635,13 +580,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpBackwardBasic) { } TEST(GRUTest, ONNXRuntime_TestGRUOpBidirectionalBasic) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "bidirectional"; - const std::vector activations = {"sigmoid", "tanh", "sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh", "Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -663,13 +603,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpBidirectionalBasic) { } TEST(GRUTest, ONNXRuntime_TestGRUOpForwardActivation) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "forward"; - const std::vector activations = {"tanh", "sigmoid"}; + const std::vector activations = {"Tanh", "Sigmoid"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -687,13 +622,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpForwardActivation) { } TEST(GRUTest, ONNXRuntime_TestGRUOpForwardInitialHiddenState) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "forward"; - const std::vector activations = {"sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -711,13 +641,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpForwardInitialHiddenState) { } TEST(GRUTest, ONNXRuntime_TestGRUOpForwardBatch) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "forward"; - const std::vector activations = {"sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -743,13 +668,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpForwardBatch) { } TEST(GRUTest, ONNXRuntime_TestGRUOpForwardBatchLinearBeforeReset) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "forward"; - const std::vector activations = {"sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -775,13 +695,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpForwardBatchLinearBeforeReset) { } TEST(GRUTest, ONNXRuntime_TestGRUOpGrowBatchSequenceLength) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "forward"; - const std::vector activations = {"sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -820,13 +735,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpGrowBatchSequenceLength) { } TEST(GRUTest, ONNXRuntime_TestGRUOpGrowBatchSequenceLengthLinearBeforeReset) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "forward"; - const std::vector activations = {"sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -865,13 +775,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpGrowBatchSequenceLengthLinearBeforeReset) { } TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithBidirectionalLinearBeforeResetB1) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "bidirectional"; - const std::vector activations = {"sigmoid", "tanh", "sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh", "Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -891,13 +796,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithBidirectionalLinearBeforeRe } TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithBidirectionalLinearBeforeResetB2) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "bidirectional"; - const std::vector activations = {"sigmoid", "tanh", "sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh", "Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -916,13 +816,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithBidirectionalLinearBeforeRe } TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithBidirectionalLinearBeforeReset) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "bidirectional"; - const std::vector activations = {"sigmoid", "tanh", "sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh", "Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -949,13 +844,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithBidirectionalLinearBeforeRe } TEST(GRUTest, ONNXRuntime_TestGRUOpShorterSeqInMiddle) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "bidirectional"; - const std::vector activations = {"sigmoid", "tanh", "sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh", "Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -987,13 +877,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpShorterSeqInMiddle) { } TEST(GRUTest, ONNXRuntime_TestGRUOpZeroSeqInMiddle) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "bidirectional"; - const std::vector activations = {"sigmoid", "tanh", "sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh", "Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -1025,13 +910,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpZeroSeqInMiddle) { } TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithPartialZero) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "bidirectional"; - const std::vector activations = {"sigmoid", "tanh", "sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh", "Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -1058,13 +938,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthWithPartialZero) { } TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthShorterThanInputSequenceLength) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "bidirectional"; - const std::vector activations = {"sigmoid", "tanh", "sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh", "Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -1092,13 +967,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthShorterThanInputSequenceLength) } TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthAllZeros) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "forward"; - const std::vector activations = {"sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations); @@ -1125,13 +995,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpSequenceLengthAllZeros) { } TEST(GRUTest, ONNXRuntime_TestGRUOpSingleBatchMultipleHiddenThreads) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "forward"; - const std::vector activations = {"sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations, true, {}, {}, /*large_hidden*/ true); @@ -1160,13 +1025,8 @@ TEST(GRUTest, ONNXRuntime_TestGRUOpSingleBatchMultipleHiddenThreads) { } TEST(GRUTest, ONNXRuntime_TestGRUPositiveActivationClipping) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1817): The parameter is incorrect."; - } - const std::string direction = "forward"; - const std::vector activations = {"sigmoid", "tanh"}; + const std::vector activations = {"Sigmoid", "Tanh"}; DeepCpuGruOpTestContext ctx(direction, activations, true, {}, {}, /*large_hidden*/ true); diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index 421561a5a859b..384adb5916cc1 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -149,11 +149,6 @@ using CastNonStringTypes = MLFloat16, BFloat16>; TEST(CastOpTest, NonStringTypes) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: Expected equality of these values: true and true"; - } - boost::mp11::mp_for_each>( CastNonStringTester{}); } diff --git a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc index 4a1888a5ca7d6..9e0fb81cbb0fc 100644 --- a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc @@ -3,6 +3,7 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" +#include "test/common/tensor_op_test_utils.h" namespace onnxruntime { namespace test { @@ -75,17 +76,6 @@ TEST(ConcatOpTest, Concat1D_2) { kQnnExecutionProvider}); // QNN: not support dynamic shape tensor } -template -static std::vector GetTypedArray(std::vector inputs, [[maybe_unused]] T v = T(0.f)) { - if constexpr (std::is_same::value) { - return inputs; - } else { - std::vector inputs_fp16(inputs.size()); - ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size()); - return inputs_fp16; - } -} - TYPED_TEST(ConcatOpTest, Concat2D_1) { OpTester test("Concat"); test.AddAttribute("axis", int64_t{0}); diff --git a/onnxruntime/test/providers/cpu/tensor/expand_test.cc b/onnxruntime/test/providers/cpu/tensor/expand_test.cc index 4b0f4e84ca37d..38e3bc3af6d6b 100644 --- a/onnxruntime/test/providers/cpu/tensor/expand_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/expand_test.cc @@ -167,6 +167,20 @@ TEST(ExpandOpTest, Expand_2x2x1x2x1_float) { test.Run(); } +TEST(ExpandOpTest, Expand_3x1x8_float) { + OpTester test("Expand", 8); + test.AddInput("data_0", {3, 2, 1}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + test.AddInput("data_1", {3}, {3, 1, 8}); + test.AddOutput("result", {3, 2, 8}, + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, + 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, + 4.0f, 4.0f, 4.0f, 4.0f, 4.0f, 4.0f, 4.0f, 4.0f, + 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, + 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f}); + test.Run(); +} + #ifndef USE_TENSORRT TEST(ExpandOpTest, Expand_scalar_float) { OpTester test("Expand", 8); diff --git a/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc b/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc index be79a6d29d539..0f23e4c39d7e2 100644 --- a/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc @@ -3,6 +3,9 @@ #include "core/session/onnxruntime_session_options_config_keys.h" #include "gtest/gtest.h" +#if USE_CUDA +#include "test/common/cuda_op_test_utils.h" +#endif #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" @@ -122,6 +125,9 @@ TEST(GatherOpTest, Gather_invalid_index_gpu) { 4.0f, 5.0f, 6.0f, 7.0f, 0.0f, 0.0f, 0.0f, 0.0f}); +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif // On GPU, just set the value to 0 instead of report error. exclude all other providers test #if defined(USE_CUDA) diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc index 05cfb5c13d689..7e1a2384d7fc6 100644 --- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc @@ -15,11 +15,13 @@ std::vector> GetExecutionProviders(int opset execution_providers.emplace_back(DefaultCpuExecutionProvider()); #ifdef USE_CUDA - if (opset_version < 20) { - execution_providers.emplace_back(DefaultCudaExecutionProvider()); + if (DefaultCudaExecutionProvider() != nullptr) { + if (opset_version < 20) { + execution_providers.emplace_back(DefaultCudaExecutionProvider()); #ifdef ENABLE_CUDA_NHWC_OPS - execution_providers.push_back(DefaultCudaNHWCExecutionProvider()); + execution_providers.push_back(DefaultCudaNHWCExecutionProvider()); #endif + } } #endif diff --git a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc index a32d43f296250..2169436255727 100644 --- a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc +++ b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc @@ -5,6 +5,7 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" +#include "test/common/tensor_op_test_utils.h" namespace onnxruntime { namespace test { @@ -263,22 +264,6 @@ TEST(SliceTest, Slice3D) { 332.0f, 333.0f}); } -template -static std::vector GetTypedArray(std::vector inputs, [[maybe_unused]] T v = T(0.f)) { - std::vector inputs_T(inputs.size()); - if constexpr (std::is_same::value) { - return inputs; - } else if constexpr (std::is_integral_v) { - for (size_t i = 0; i < inputs.size(); i++) { - inputs_T[i] = static_cast(inputs[i]); - } - return inputs_T; - } else { - ConvertFloatToMLFloat16(inputs.data(), inputs_T.data(), inputs.size()); - return inputs_T; - } -} - template static void TestSlice1DIntData() { // static_assert(std::is_integral_v); diff --git a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc index 48872404f08bd..1c2a86bb808b5 100644 --- a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc @@ -4,6 +4,7 @@ #include "gtest/gtest.h" #include "core/framework/to_tensor_proto_element_type.h" #include "test/providers/provider_test_utils.h" +#include "test/common/tensor_op_test_utils.h" namespace onnxruntime { namespace test { @@ -178,17 +179,6 @@ TEST(SplitOperatorTest, Axis0UnequalSplitFloat) { RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}, false, true); } -template -std::vector GetTypedArray(std::vector inputs, [[maybe_unused]] T v = T(0.f)) { - if constexpr (std::is_same::value) { - return inputs; - } else { - std::vector inputs_fp16(inputs.size()); - ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size()); - return inputs_fp16; - } -} - TEST(SplitOperatorTest, Axis0UnequalSplitString) { constexpr int64_t axis = 0; std::vector outputs; diff --git a/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc b/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc index b517b1a2837f0..5902fbe3ddd6f 100644 --- a/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc @@ -142,7 +142,7 @@ void RunTestWrapper() { RunTest({2, 1, 3}, {2, 2, 1}); RunTest({2, 1, 3}, {2, 2, 1}, true); -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) // _TileMemcpyKernelFromInput, vectorized 4 RunTest({256, 512}, {3, 1}); @@ -253,7 +253,7 @@ TEST(TensorOpTest, TileStringType) { RunTestWrapper(); } TEST(TensorOpTest, TileBoolType) { RunTestWrapperForBool(); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) TEST(TensorOpTest, TileMLFloat16Type) { RunTestWrapper(); } #endif diff --git a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc index 3b46dc3f5d6a2..73a5bce768a2a 100644 --- a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc @@ -69,17 +69,6 @@ void TransposeTest(const std::vector& input_shape, } } -template -std::vector GetTypedArray(std::vector inputs, [[maybe_unused]] T v = T(0.f)) { - if constexpr (std::is_same::value) { - return inputs; - } else { - std::vector inputs_fp16(inputs.size()); - ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size()); - return inputs_fp16; - } -} - // Test 2 dimensional transpose, with no permutation attribute specified TYPED_TEST(TransposeOpTest, TwoDimNoAttr) { std::vector input_shape({2, 3}); diff --git a/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc b/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc index d2aa5dd428fec..d1910c89f76b7 100644 --- a/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc @@ -11,7 +11,7 @@ namespace test { // Disable TensorRT on the tests because of SegFault errors in the parser -TEST(TensorOpTest, Unsqueeze_1) { +TEST(UnsqueezeOpTest, Unsqueeze_1) { OpTester test("Unsqueeze"); test.AddAttribute("axes", std::vector{1}); @@ -20,7 +20,7 @@ TEST(TensorOpTest, Unsqueeze_1) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } -TEST(TensorOpTest, Unsqueeze_1_int32) { +TEST(UnsqueezeOpTest, Unsqueeze_1_int32) { OpTester test("Unsqueeze"); test.AddAttribute("axes", std::vector{1}); @@ -29,7 +29,7 @@ TEST(TensorOpTest, Unsqueeze_1_int32) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } -TEST(TensorOpTest, Unsqueeze_2) { +TEST(UnsqueezeOpTest, Unsqueeze_2) { OpTester test("Unsqueeze"); test.AddAttribute("axes", std::vector{0, 4}); @@ -38,7 +38,7 @@ TEST(TensorOpTest, Unsqueeze_2) { test.Run(); } -TEST(TensorOpTest, Unsqueeze_3) { +TEST(UnsqueezeOpTest, Unsqueeze_3) { OpTester test("Unsqueeze"); test.AddAttribute("axes", std::vector{2, 1, 0}); @@ -47,7 +47,7 @@ TEST(TensorOpTest, Unsqueeze_3) { test.Run(); } -TEST(TensorOpTest, Unsqueeze_scalar) { +TEST(UnsqueezeOpTest, Unsqueeze_scalar) { { OpTester test("Unsqueeze"); @@ -85,7 +85,7 @@ TEST(TensorOpTest, Unsqueeze_scalar) { run_test(true); } -TEST(TensorOpTest, Unsqueeze_scalar_2) { +TEST(UnsqueezeOpTest, Unsqueeze_scalar_2) { { OpTester test("Unsqueeze"); @@ -105,7 +105,7 @@ TEST(TensorOpTest, Unsqueeze_scalar_2) { run_test(true); } -TEST(TensorOpTest, Unsqueeze_Duplicate) { +TEST(UnsqueezeOpTest, Unsqueeze_Duplicate) { { OpTester test("Unsqueeze", 12); // opset 1-12 has axes attribute @@ -128,7 +128,7 @@ TEST(TensorOpTest, Unsqueeze_Duplicate) { } } -TEST(TensorOpTest, Unsqueeze_OutOfRange) { +TEST(UnsqueezeOpTest, Unsqueeze_OutOfRange) { { OpTester test("Unsqueeze", 12); // opset 1-12 has axes attribute test.AddAttribute("axes", std::vector{4}); @@ -149,7 +149,7 @@ TEST(TensorOpTest, Unsqueeze_OutOfRange) { } } -TEST(TensorOpTest, UnsqueezeNegAxis_3) { +TEST(UnsqueezeOpTest, UnsqueezeNegAxis_3) { { OpTester test("Unsqueeze", 12); // opset 1-12 has axes attribute test.AddAttribute("axes", std::vector{-4, 1, -6}); @@ -171,7 +171,7 @@ TEST(TensorOpTest, UnsqueezeNegAxis_3) { run_test(true); } -TEST(TensorOpTest, Unsqueeze_1_int32_axes_input) { +TEST(UnsqueezeOpTest, Unsqueeze_1_int32_axes_input) { auto run_test = [](bool axes_is_initializer) { OpTester test("Unsqueeze", 13); @@ -185,7 +185,7 @@ TEST(TensorOpTest, Unsqueeze_1_int32_axes_input) { run_test(true); } -TEST(TensorOpTest, Unsqueeze_3_axes_input) { +TEST(UnsqueezeOpTest, Unsqueeze_3_axes_input) { auto run_test = [](bool axes_is_initializer) { OpTester test("Unsqueeze", 13); @@ -200,7 +200,7 @@ TEST(TensorOpTest, Unsqueeze_3_axes_input) { } #if defined(USE_DNNL) -TEST(TensorOpTest, Unsqueeze_3_axes_input_bfloat16) { +TEST(UnsqueezeOpTest, Unsqueeze_3_axes_input_bfloat16) { #ifdef USE_DNNL if (!DnnlHasBF16Support()) { LOGS_DEFAULT(WARNING) << "Hardware does NOT support BF16"; @@ -218,7 +218,7 @@ TEST(TensorOpTest, Unsqueeze_3_axes_input_bfloat16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -TEST(TensorOpTest, UnsqueezeNegAxis_3_bfloat16) { +TEST(UnsqueezeOpTest, UnsqueezeNegAxis_3_bfloat16) { #ifdef USE_DNNL if (!DnnlHasBF16Support()) { LOGS_DEFAULT(WARNING) << "Hardware does NOT support BF16"; diff --git a/onnxruntime/test/providers/cuda/cuda_provider_test.cc b/onnxruntime/test/providers/cuda/cuda_provider_test.cc index e57cdd2350fab..e745e1bcb8171 100644 --- a/onnxruntime/test/providers/cuda/cuda_provider_test.cc +++ b/onnxruntime/test/providers/cuda/cuda_provider_test.cc @@ -11,7 +11,7 @@ ProviderInfo_CUDA& GetProviderInfo_CUDA_Test(); namespace test { namespace cuda { -TEST(CUDA_EP_Unittest, All) { +TEST(CudaEpUnittest, All) { onnxruntime::ProviderInfo_CUDA& ep = onnxruntime::GetProviderInfo_CUDA_Test(); ep.TestAll(); } diff --git a/onnxruntime/test/providers/cuda/test_cases/allocator_cuda_test.cc b/onnxruntime/test/providers/cuda/test_cases/allocator_cuda_test.cc index b413d04fe81e8..ec7c6ec4e1605 100644 --- a/onnxruntime/test/providers/cuda/test_cases/allocator_cuda_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/allocator_cuda_test.cc @@ -11,7 +11,7 @@ namespace onnxruntime { namespace test { -TEST(AllocatorTest, CUDAAllocatorTest) { +TEST(CudaEpAllocatorTest, CUDAAllocatorTest) { OrtDevice::DeviceId cuda_device_id = 0; // ensure CUDA device is available. @@ -77,7 +77,7 @@ TEST(AllocatorTest, CUDAAllocatorTest) { } // test that we fallback to smaller allocations if the growth of the arena exceeds the available memory -TEST(AllocatorTest, CUDAAllocatorFallbackTest) { +TEST(CudaEpAllocatorTest, CUDAAllocatorFallbackTest) { OrtDevice::DeviceId cuda_device_id = 0; size_t free = 0; diff --git a/onnxruntime/test/providers/cuda/test_cases/attention_kernel_options_test.cc b/onnxruntime/test/providers/cuda/test_cases/attention_kernel_options_test.cc index b2e986f680763..ccdc56de5937d 100644 --- a/onnxruntime/test/providers/cuda/test_cases/attention_kernel_options_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/attention_kernel_options_test.cc @@ -17,7 +17,7 @@ using onnxruntime::contrib::attention::AttentionBackend; namespace onnxruntime { namespace test { -TEST(AttentionKernelOptionsTest, NonZeroValue) { +TEST(CudaEpAttentionKernelOptionsTest, NonZeroValue) { { AttentionKernelOptions options; int value = static_cast(AttentionBackend::FLASH_ATTENTION) | static_cast(AttentionBackend::EFFICIENT_ATTENTION); @@ -156,7 +156,7 @@ TEST(AttentionKernelOptionsTest, NonZeroValue) { } // Test all environment variables take effect when option value is 0. -TEST(AttentionKernelOptionsTest, DefaultOptionWithEnvVar) { +TEST(CudaEpAttentionKernelOptionsTest, DefaultOptionWithEnvVar) { constexpr int value = 0; ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ @@ -186,7 +186,7 @@ TEST(AttentionKernelOptionsTest, DefaultOptionWithEnvVar) { } // Test default min sequence lengths when environment variables are not set. -TEST(AttentionKernelOptionsTest, DefaultMinSeqLens) { +TEST(CudaEpAttentionKernelOptionsTest, DefaultMinSeqLens) { constexpr int value = 0; ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ diff --git a/onnxruntime/test/providers/cuda/test_cases/beam_search_topk.cc b/onnxruntime/test/providers/cuda/test_cases/beam_search_topk.cc index a0d115c41c14b..97d50398a5550 100644 --- a/onnxruntime/test/providers/cuda/test_cases/beam_search_topk.cc +++ b/onnxruntime/test/providers/cuda/test_cases/beam_search_topk.cc @@ -68,7 +68,7 @@ void ComputeTopKReference(const std::vector& values, } } -TEST(TestBeamSearch, TopK) { +TEST(CudaEpTestBeamSearch, TopK) { int32_t batch_size = 4; int32_t beam_size = 4; int32_t vocab_size = 50257; diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc index 3fcb9045ee7e6..d8fb3c8256012 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc @@ -230,7 +230,7 @@ void testPrepack(int rows, int columns) { } // TODO: code runs on CPU, but this is for sm80 only, maybe enable only when test on sm80 -TEST(BlkQ4_GEMM, PrepackSm80Test) { +TEST(CudaEpBlkQ4_GEMM, PrepackSm80Test) { Status status = onnxruntime::cuda::test::sm80_supported(); if (!status.IsOK()) { // skip the test if sm80 is not supported @@ -263,7 +263,7 @@ TEST(BlkQ4_GEMM, PrepackSm80Test) { testPrepack(256, 256); } -TEST(BlkQ4_GEMM, Sm80RowBlockingTest) { +TEST(CudaEpBlkQ4_GEMM, Sm80RowBlockingTest) { Status status = onnxruntime::cuda::test::sm80_supported(); if (!status.IsOK()) { // skip the test if sm80 is not supported @@ -292,7 +292,7 @@ TEST(BlkQ4_GEMM, Sm80RowBlockingTest) { onnxruntime::cuda::test::run_blkq4_gemm<64, false, false, true>(256, 1024, 576); } -TEST(BlkQ4_GEMM, Sm80ColBlockingTest) { +TEST(CudaEpBlkQ4_GEMM, Sm80ColBlockingTest) { Status status = onnxruntime::cuda::test::sm80_supported(); if (!status.IsOK()) { // skip the test if sm80 is not supported @@ -305,7 +305,7 @@ TEST(BlkQ4_GEMM, Sm80ColBlockingTest) { onnxruntime::cuda::test::run_blkq4_gemm<64, true, false, true>(256, 1024, 576); } -TEST(BlkQ4_GEMM, Sm80SmallMTest) { +TEST(CudaEpBlkQ4_GEMM, Sm80SmallMTest) { Status status = onnxruntime::cuda::test::sm80_supported(); if (!status.IsOK()) { // skip the test if sm80 is not supported @@ -326,7 +326,7 @@ TEST(BlkQ4_GEMM, Sm80SmallMTest) { onnxruntime::cuda::test::run_blkq4_gemm<64, true, true, true>(16, 1024, 576); } -TEST(BlkQ4_GEMM, Sm80SmallTileKernelTest) { +TEST(CudaEpBlkQ4_GEMM, Sm80SmallTileKernelTest) { Status status = onnxruntime::cuda::test::sm80_supported(); if (!status.IsOK()) { // skip the test if sm80 is not supported diff --git a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc index 72357ec7e02d2..f3222c6f683b5 100644 --- a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc @@ -19,7 +19,7 @@ namespace cuda { namespace test { // TODO: Since the "DeferredRelease" has been migrated to CudaStream class, // we should migrate this test from CudaEP unit test to CudaStream unit test. -TEST(TestDeferredRelease, WithArena) { +TEST(CudaEpTestDeferredRelease, WithArena) { // Create CUDA EP. CUDAExecutionProviderInfo info; CUDAExecutionProvider ep(info); @@ -52,7 +52,7 @@ TEST(TestDeferredRelease, WithArena) { ORT_THROW_IF_ERROR(ep.OnRunEnd(true, run_opts)); } -TEST(TestDeferredRelease, WithoutArena) { +TEST(CudaEpTestDeferredRelease, WithoutArena) { // Create CUDA EP. CUDAExecutionProviderInfo info; CUDAExecutionProvider ep(info); diff --git a/onnxruntime/test/providers/cuda/test_cases/cuda_utils_test.cc b/onnxruntime/test/providers/cuda/test_cases/cuda_utils_test.cc index 7468a5718425e..3538c7add94d0 100644 --- a/onnxruntime/test/providers/cuda/test_cases/cuda_utils_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/cuda_utils_test.cc @@ -40,7 +40,7 @@ void TestFillCorrectness(size_t num_elements, TElement value) { } } // namespace -TEST(CudaUtilsTest, FillCorrectness) { +TEST(CudaEpUnittest, FillCorrectness) { TestFillCorrectness(1 << 20, 1); TestFillCorrectness(1 << 20, 2); TestFillCorrectness(1 << 20, 3); diff --git a/onnxruntime/test/providers/cuda/test_cases/gemm_options_test.cc b/onnxruntime/test/providers/cuda/test_cases/gemm_options_test.cc index 6636e15040393..518fde5804b23 100644 --- a/onnxruntime/test/providers/cuda/test_cases/gemm_options_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/gemm_options_test.cc @@ -10,7 +10,7 @@ namespace onnxruntime { namespace cuda { namespace test { -TEST(CudaGemmOptions, TestDefaultOptions) { +TEST(CudaEpGemmOptions, TestDefaultOptions) { HalfGemmOptions gemm_options; ASSERT_FALSE(gemm_options.IsCompute16F()); #if defined(USE_CUDA) @@ -22,7 +22,7 @@ TEST(CudaGemmOptions, TestDefaultOptions) { #endif } -TEST(CudaGemmOptions, TestCompute16F) { +TEST(CudaEpGemmOptions, TestCompute16F) { HalfGemmOptions gemm_options; gemm_options.Initialize(1); ASSERT_TRUE(gemm_options.IsCompute16F()); @@ -35,7 +35,7 @@ TEST(CudaGemmOptions, TestCompute16F) { #endif } -TEST(CudaGemmOptions, NoReducedPrecision) { +TEST(CudaEpGemmOptions, NoReducedPrecision) { HalfGemmOptions gemm_options; gemm_options.Initialize(2); ASSERT_FALSE(gemm_options.IsCompute16F()); @@ -48,7 +48,7 @@ TEST(CudaGemmOptions, NoReducedPrecision) { #endif } -TEST(CudaGemmOptions, Pedantic) { +TEST(CudaEpGemmOptions, Pedantic) { HalfGemmOptions gemm_options; gemm_options.Initialize(4); ASSERT_FALSE(gemm_options.IsCompute16F()); @@ -61,7 +61,7 @@ TEST(CudaGemmOptions, Pedantic) { #endif } -TEST(CudaGemmOptions, Compute16F_Pedantic) { +TEST(CudaEpGemmOptions, Compute16F_Pedantic) { HalfGemmOptions gemm_options; gemm_options.Initialize(5); ASSERT_TRUE(gemm_options.IsCompute16F()); @@ -74,7 +74,7 @@ TEST(CudaGemmOptions, Compute16F_Pedantic) { #endif } -TEST(CudaGemmOptions, Compute16F_NoReducedPrecision) { +TEST(CudaEpGemmOptions, Compute16F_NoReducedPrecision) { HalfGemmOptions gemm_options; gemm_options.Initialize(3); ASSERT_TRUE(gemm_options.IsCompute16F()); diff --git a/onnxruntime/test/providers/cuda/test_cases/greedy_search_top_one.cc b/onnxruntime/test/providers/cuda/test_cases/greedy_search_top_one.cc index 6b8cd68de0fca..ba24cf858e80f 100644 --- a/onnxruntime/test/providers/cuda/test_cases/greedy_search_top_one.cc +++ b/onnxruntime/test/providers/cuda/test_cases/greedy_search_top_one.cc @@ -41,7 +41,7 @@ void ComputeTop1Reference(const std::vector& values, } } -TEST(TestGreedySearch, TopOne) { +TEST(CudaEpTestGreedySearch, TopOne) { int32_t batch_size = 4; int32_t vocab_size = 50257; int32_t batch_x_vocab = batch_size * vocab_size; diff --git a/onnxruntime/test/providers/cuda/test_cases/reduction_functions_test.cc b/onnxruntime/test/providers/cuda/test_cases/reduction_functions_test.cc index ec7e98528504e..09c9c1e5f8f6a 100644 --- a/onnxruntime/test/providers/cuda/test_cases/reduction_functions_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/reduction_functions_test.cc @@ -179,7 +179,7 @@ void TestReduceColumnsToColumn(int m, int n, float relative_error_tolerance = 1e } } // namespace -TEST(ReductionFunctionsTest, ReduceRowToScalar) { +TEST(CudaEpReductionFunctionsTest, ReduceRowToScalar) { TestReduceRowToScalarApis(3); TestReduceRowToScalarApis(19); TestReduceRowToScalarApis(123); @@ -188,7 +188,7 @@ TEST(ReductionFunctionsTest, ReduceRowToScalar) { TestReduceRowToScalarApis(941736, 2e-4f); } -TEST(ReductionFunctionsTest, ReduceRowsToRow) { +TEST(CudaEpReductionFunctionsTest, ReduceRowsToRow) { for (int m : {3, 193, 2945}) { for (int n : {3, 193, 2945}) { TestReduceRowsToRow(m, n, true); @@ -197,7 +197,7 @@ TEST(ReductionFunctionsTest, ReduceRowsToRow) { } } -TEST(ReductionFunctionsTest, ReduceColumnsToColumn) { +TEST(CudaEpReductionFunctionsTest, ReduceColumnsToColumn) { for (int m : {3, 193, 2945}) { for (int n : {3, 193, 2945}) { TestReduceColumnsToColumn(m, n); @@ -205,7 +205,7 @@ TEST(ReductionFunctionsTest, ReduceColumnsToColumn) { } } -TEST(ReductionFunctionsTest, BufferOffsets) { +TEST(CudaEpReductionFunctionsTest, BufferOffsets) { const int m = 2048; const int n = 1024; const TensorShape shape{m, n}; @@ -240,7 +240,7 @@ TEST(ReductionFunctionsTest, BufferOffsets) { } } -TEST(ReductionFunctionsTest, InvalidBufferSize) { +TEST(CudaEpReductionFunctionsTest, InvalidBufferSize) { const int m = 2048; const int n = 1024; const TensorShape shape{m, n}; @@ -262,7 +262,7 @@ TEST(ReductionFunctionsTest, InvalidBufferSize) { ASSERT_FALSE(status.IsOK()); } -TEST(ReductionFunctionsTest, GetApplicableMatrixReduction) { +TEST(CudaEpReductionFunctionsTest, GetApplicableMatrixReduction) { auto test_get_applicable_matrix_reduction = [](cudnnReduceTensorOp_t cudnn_op, const std::vector& dims, const std::vector& axes, diff --git a/onnxruntime/test/providers/kernel_compute_test_utils.cc b/onnxruntime/test/providers/kernel_compute_test_utils.cc index 23ec48fa649dd..93e688570631e 100644 --- a/onnxruntime/test/providers/kernel_compute_test_utils.cc +++ b/onnxruntime/test/providers/kernel_compute_test_utils.cc @@ -42,8 +42,9 @@ void KernelComputeTester::Run(std::unordered_set strided_outputs) { } #endif + const auto& logger = DefaultLoggingManager().DefaultLogger(); Model model("test", false, ModelMetaData(), ORT_TSTR(""), IOnnxRuntimeOpSchemaRegistryList(), - {{domain_, opset_version_}}, {}, DefaultLoggingManager().DefaultLogger()); + {{domain_, opset_version_}}, {}, logger); std::vector input_args; std::unordered_map initializer_map; @@ -89,8 +90,7 @@ void KernelComputeTester::Run(std::unordered_set strided_outputs) { ASSERT_STATUS_OK(graph.Resolve()); node.SetExecutionProviderType(ep_type); - OptimizerExecutionFrame::Info info({&node}, initializer_map, graph.ModelPath(), *execution_providers.Get(ep_type), - [](std::string const&) { return false; }); + OptimizerExecutionFrame::Info info({&node}, initializer_map, graph.ModelPath(), *execution_providers.Get(ep_type), [](std::string const&) { return false; }, logger); const KernelCreateInfo* kernel_create_info = nullptr; ASSERT_STATUS_OK(info.TryFindKernel(&node, &kernel_create_info)); ASSERT_TRUE(kernel_create_info); @@ -139,7 +139,7 @@ void KernelComputeTester::Run(std::unordered_set strided_outputs) { #pragma warning(disable : 6387) #endif OptimizerExecutionFrame frame(info, fetch_mlvalue_idxs, outputs); - OpKernelContext op_kernel_context(&frame, kernel.get(), nullptr, nullptr, DefaultLoggingManager().DefaultLogger()); + OpKernelContext op_kernel_context(&frame, kernel.get(), nullptr, nullptr, logger); #ifdef _WIN32 #pragma warning(pop) #endif diff --git a/onnxruntime/test/providers/partitioning_utils_test.cc b/onnxruntime/test/providers/partitioning_utils_test.cc index 5db69489afaef..f1fbb1cea7ea2 100644 --- a/onnxruntime/test/providers/partitioning_utils_test.cc +++ b/onnxruntime/test/providers/partitioning_utils_test.cc @@ -51,7 +51,7 @@ TEST(PartitioningUtilsTest, TestQDQHandling) { std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); auto result = utils::CreateSupportedPartitions(graph_viewer, is_node_supported, on_group_closed, gen_metadef_name, "TEST", kCpuExecutionProvider, &node_unit_map, @@ -82,7 +82,7 @@ static void CheckAllNodesProcessed(const std::function& std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); const auto is_node_supported = [&](const Node& /*node*/) -> bool { return true; diff --git a/onnxruntime/test/providers/qnn/gather_op_htp_test.cc b/onnxruntime/test/providers/qnn/gather_op_htp_test.cc index e3f09e92593df..55177cc7ed131 100644 --- a/onnxruntime/test/providers/qnn/gather_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/gather_op_htp_test.cc @@ -131,11 +131,16 @@ TEST_F(QnnHTPBackendTests, GatherOp_IndicesDynamicInt32_Axis0) { ExpectedEPNodeAssignment::All); } +// disabled for QNN 2.28.0.241029 failed for accuracy validation +// Also fails on QNN 2.28.2. +// qdq@QNN_EP val: 3.6094117164611816 (err: 1.3094117641448975, err/output_range: 22.19342041015625%) +// qdq@CPU_EP val: 2.2905881404876709 (err: 0.0094118118286132812, err/output_range: 0.15952222049236298%) +// abs(qdq@QNN_EP - qdq@CPU_EP) / output_range = 22.033897399902344% // Test creates a DQ -> Gather -> Q -> DQ graph, and checks that all // nodes are supported by the QNN EP, and that the inference results are as accurate as CPU EP. // // Static int32 indices with axis = 1 -TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt32_Axis1) { +TEST_F(QnnHTPBackendTests, DISABLED_GatherOp_IndicesStaticInt32_Axis1) { RunQDQGatherOpTest(TestInputDef({3, 3}, false, {1.0f, 1.2f, 1.9f, 2.3f, 3.4f, 3.9f, 4.5f, 5.7f, 5.9f}), TestInputDef({1, 2}, true, {0, 2}), {utils::MakeAttribute("axis", static_cast(1))}, diff --git a/onnxruntime/test/providers/qnn/layer_norm_test.cc b/onnxruntime/test/providers/qnn/layer_norm_test.cc index 2773568dde717..947ac19be40a8 100644 --- a/onnxruntime/test/providers/qnn/layer_norm_test.cc +++ b/onnxruntime/test/providers/qnn/layer_norm_test.cc @@ -188,15 +188,11 @@ TEST_F(QnnHTPBackendTests, LayerNorm1D_LastAxis_StaticScale_StaticBias_AU8_WU8_B ExpectedEPNodeAssignment::All); } -// QNN 2.27 accuracy issue -// Inaccuracy detected for output 'output_0', element 0 -// output_range=1.2245157957077026, tolerance=0.40000000596046448%. -// Expected val (f32@CPU_EP): -0 -// qdq@QNN_EP val: 0.19133351743221283 (err: 0.19133351743221283, err/output_range: 15.625238418579102%) -// qdq@CPU_EP val: 0 (err: 0, err/output_range: 0%) -TEST_F(QnnHTPBackendTests, DISABLED_LayerNorm1D_QNN2_24_ImplicitBias_ValidationBug) { - // QNN 2.24 LayerNorm fails validation (intermittent) if the bias input is not provided. QNN EP will provide an - // explicit bias of all zeros to get around this bug. +TEST_F(QnnHTPBackendTests, LayerNorm1D_QNN2_24_ImplicitBias_ValidationBug) { + // QNN 2.24 to 2.27: LayerNorm fails validation (intermittent) if the bias input is not provided. QNN EP will provide + // an explicit bias of all zeros to get around this bug. + // QNN 2.28.0: Validation bug is fixed, but get accuracy errors. + // QNN 2.28.2: All fixed. for (size_t i = 0; i < 15; i++) { // Run it multiple times since this is an intermittent bug. RunLayerNormQDQTest(TestInputDef({1, 2, 3}, false, GetFloatDataInRange(0.0f, 1.0f, 6)), TestInputDef({3}, true, GetFloatDataInRange(0.0f, 1.0f, 3)), @@ -207,14 +203,9 @@ TEST_F(QnnHTPBackendTests, DISABLED_LayerNorm1D_QNN2_24_ImplicitBias_ValidationB } } -// Test accuracy of 16-bit QDQ LayerNorm with a static scale input. -// QNN 2.27 accuracy issue -// Inaccuracy detected for output 'output_0', element 0 -// output_range=1.224743127822876, tolerance=0.40000000596046448%. -// Expected val (f32@CPU_EP): -0 -// qdq@QNN_EP val: 0.19136904180049896 (err: 0.19136904180049896, err/output_range: 15.625238418579102%) -// qdq@CPU_EP val: 0 (err: 0, err/output_range: 0%) -TEST_F(QnnHTPBackendTests, DISABLED_LayerNorm1D_LastAxis_StaticScale_AU16_WU8) { +TEST_F(QnnHTPBackendTests, LayerNorm1D_LastAxis_StaticScale_AU16_WU8) { + // QNN 2.28.0: Get accuracy errors. + // QNN 2.28.2: All fixed. RunLayerNormQDQTest(TestInputDef({1, 2, 3}, false, GetFloatDataInRange(0.0f, 10.0f, 6)), TestInputDef({3}, true, GetFloatDataInRange(0.0f, 1.0f, 3)), // Static TestInputDef(), @@ -225,7 +216,7 @@ TEST_F(QnnHTPBackendTests, DISABLED_LayerNorm1D_LastAxis_StaticScale_AU16_WU8) { // Test accuracy of 8-bit QDQ LayerNorm with a dynamic scale input. // -// TODO(adrianlizarraga): Fails to finalize with QNN SDK 2.22. +// TODO(adrianlizarraga): Fails to finalize with QNN SDK 2.22. Still fails on QNN SDK 2.28.2. // Verbose logs: // Starting stage: Graph Transformations and Optimizations // C:\...\QNN\HTP\HTP\src\hexagon\prepare\graph_prepare.cc:203:ERROR:could not create op: q::flat_to_vtcm diff --git a/onnxruntime/test/providers/qnn/matmul_test.cpp b/onnxruntime/test/providers/qnn/matmul_test.cpp index 800457d906940..5c6967761b1db 100644 --- a/onnxruntime/test/providers/qnn/matmul_test.cpp +++ b/onnxruntime/test/providers/qnn/matmul_test.cpp @@ -273,7 +273,7 @@ TEST_F(QnnHTPBackendTests, MatMulOp_PerChannel_A16_WeightUInt4) { } // Test QDQ per-channel MatMul with int8 act, int4 weights (static) -// QNN 2.27 regression +// QNN 2.27 regression. Also fails on QNN 2.28.2. // Failed to finalize QNN graph. Error code: 1002 TEST_F(QnnHTPBackendTests, DISABLED_MatMulOp_PerChannel_AS8_WeightInt4) { std::vector input0_data = GetFloatDataInRange(-5.0f, 5.0f, 6); diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index 79e7d39e85518..4feeb5f830508 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -388,6 +388,7 @@ bool ReduceOpHasAxesInput(const std::string& op_type, int opset_version) { {"ReduceMean", 18}, {"ReduceProd", 18}, {"ReduceSum", 13}, + {"ReduceL2", 18}, }; const auto it = opset_with_axes_as_input.find(op_type); diff --git a/onnxruntime/test/providers/qnn/reduce_op_test.cc b/onnxruntime/test/providers/qnn/reduce_op_test.cc index 13173d9a87f55..e4abe85908373 100644 --- a/onnxruntime/test/providers/qnn/reduce_op_test.cc +++ b/onnxruntime/test/providers/qnn/reduce_op_test.cc @@ -309,6 +309,27 @@ TEST_F(QnnCPUBackendTests, ReduceMeanOpset13) { ExpectedEPNodeAssignment::All); } +// +// ReduceL2 +// +TEST_F(QnnCPUBackendTests, ReduceL2Opset18) { + RunReduceTest("ReduceL2", + TestInputDef({2, 2}, false, -10.0f, 10.0f), + std::vector{0, 1}, + true, // keepdims + 18, + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnCPUBackendTests, ReduceL2Opset13) { + RunReduceTest("ReduceL2", + TestInputDef({2, 2}, false, -10.0f, 10.0f), + std::vector{0, 1}, + true, // keepdims + 13, + ExpectedEPNodeAssignment::All); +} + #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) // Test creates a graph with a ReduceSum node, and checks that all nodes are supported by the QNN EP diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 018720fd8b71f..7541d94bac0c6 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -229,8 +229,16 @@ TEST_F(QnnHTPBackendTests, UnaryOp_Tanh) { ExpectedEPNodeAssignment::All); } +// disabled for QNN 2.28.0.241029 backendValidateOpConfig failed +// still fails on QNN 2.28.2. +// QnnDsp [4294967295] has incorrect Value -32768, expected equal to 0. +// QnnDsp validateNativeOps node_token_6:qti.aisw:Tanh htp op validator failed 3110 +// QnnDsp registered validator failed => 3110 +// QnnDsp QnnBackend_validateOpConfig failed 3110 +// QnnDsp Wake up free backend (id: 1)'s thread(s) +// QnnDsp Failed to validate op node_token_6 with error 0xc26 // Tests accuracy of 16-bit QDQ Tanh. -TEST_F(QnnHTPBackendTests, UnaryOp_Tanh_U16) { +TEST_F(QnnHTPBackendTests, DISABLED_UnaryOp_Tanh_U16) { RunQDQOpTest("Tanh", {TestInputDef({1, 2, 3}, false, GetFloatDataInRange(-10.0f, 10.0f, 6))}, {}, diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index 63327a028c6f4..0022d7fc0e184 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -342,8 +342,12 @@ TEST(TensorrtExecutionProviderTest, TRTModelIdGeneratorUsingModelHashing) { Graph& graph = model->MainGraph(); GraphViewer viewer(graph); + std::string trt_version = std::to_string(NV_TENSORRT_MAJOR) + "." + std::to_string(NV_TENSORRT_MINOR); + std::string cuda_version = std::to_string(CUDA_VERSION); + std::string ort_version = ORT_VERSION; + // get the hash for the model when loaded from file - HashValue model_hash = TRTGenerateId(viewer); + HashValue model_hash = TRTGenerateId(viewer, trt_version, cuda_version); ASSERT_NE(model_hash, 0); // now load the model from bytes and check the hash differs @@ -358,7 +362,7 @@ TEST(TensorrtExecutionProviderTest, TRTModelIdGeneratorUsingModelHashing) { // Test loading same model from file and byte steam. Hash values should be different Graph& graph2 = model2->MainGraph(); GraphViewer viewer2(graph2); - HashValue model_hash2 = TRTGenerateId(viewer2); + HashValue model_hash2 = TRTGenerateId(viewer2, trt_version, cuda_version); ASSERT_NE(model_hash, model_hash2); // Test loading same model from different path, see if hash values are same as well @@ -367,7 +371,7 @@ TEST(TensorrtExecutionProviderTest, TRTModelIdGeneratorUsingModelHashing) { ASSERT_TRUE(Model::Load(model_path, model3, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); Graph& graph3 = model3->MainGraph(); GraphViewer viewer3(graph3); - HashValue model_hash3 = TRTGenerateId(viewer3); + HashValue model_hash3 = TRTGenerateId(viewer3, trt_version, cuda_version); ASSERT_EQ(model_hash, model_hash3) << "model 1&3 are same models and they have same hash, no matter where they are loaded"; } diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py index 9b1e87f6ec02e..a274b90dc042f 100644 --- a/onnxruntime/test/python/onnx_backend_test_series.py +++ b/onnxruntime/test/python/onnx_backend_test_series.py @@ -105,7 +105,7 @@ def load_jsonc(basename: str): return json.loads("\n".join(lines)) -def create_backend_test(test_name=None): +def create_backend_test(devices: list[str], test_name=None): """Creates an OrtBackendTest and adds its TestCase's to global scope so unittest will find them.""" overrides = load_jsonc("onnx_backend_test_series_overrides.jsonc") @@ -126,37 +126,47 @@ def create_backend_test(test_name=None): else: filters = load_jsonc("onnx_backend_test_series_filters.jsonc") current_failing_tests = apply_filters(filters, "current_failing_tests") - if platform.architecture()[0] == "32bit": current_failing_tests += apply_filters(filters, "current_failing_tests_x86") - if backend.supports_device("DNNL"): + if backend.supports_device("DNNL") or "DNNL" in devices: current_failing_tests += apply_filters(filters, "current_failing_tests_DNNL") - if backend.supports_device("NNAPI"): + if backend.supports_device("NNAPI") or "NNAPI" in devices: current_failing_tests += apply_filters(filters, "current_failing_tests_NNAPI") - if backend.supports_device("OPENVINO_GPU"): + if backend.supports_device("OPENVINO_GPU") or "OPENVINO_GPU" in devices: current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_GPU") - if backend.supports_device("OPENVINO_CPU"): + if backend.supports_device("OPENVINO_CPU") or "OPENVINO_CPU" in devices: current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_CPU_FP32") current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_CPU_FP16") - if backend.supports_device("OPENVINO_NPU"): + if backend.supports_device("OPENVINO_NPU") or "OPENVINO_NPU" in devices: current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_NPU") - if backend.supports_device("OPENVINO"): + if backend.supports_device("OPENVINO") or "OPENVINO" in devices: current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_opset18") - if backend.supports_device("MIGRAPHX"): + if backend.supports_device("MIGRAPHX") or "MIGRAPHX" in devices: current_failing_tests += apply_filters(filters, "current_failing_tests_MIGRAPHX") + if backend.supports_device("WEBGPU"): + current_failing_tests += apply_filters(filters, "current_failing_tests_WEBGPU") + # Skip these tests for a "pure" DML onnxruntime python wheel. We keep these tests enabled for instances where both DML and CUDA # EPs are available (Windows GPU CI pipeline has this config) - these test will pass because CUDA has higher precedence than DML # and the nodes are assigned to only the CUDA EP (which supports these tests) - if backend.supports_device("DML") and not backend.supports_device("GPU"): + if (backend.supports_device("DML") and not backend.supports_device("GPU")) or "DML" in devices: current_failing_tests += apply_filters(filters, "current_failing_tests_pure_DML") + # exclude CUDA EP when DML test is running. + os.environ["ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS"] = "TensorrtExecutionProvider,CUDAExecutionProvider" + elif backend.supports_device("DML") and "DML" not in devices: + # exclude DML EP when CUDA test is running. + os.environ["ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS"] = "TensorrtExecutionProvider,DmlExecutionProvider" + else: + # exclude TRT EP temporarily and only test CUDA EP to retain previous behavior + os.environ["ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS"] = "TensorrtExecutionProvider" filters = ( current_failing_tests @@ -169,9 +179,6 @@ def create_backend_test(test_name=None): backend_test.exclude("(" + "|".join(filters) + ")") print("excluded tests:", filters) - # exclude TRT EP temporarily and only test CUDA EP to retain previous behavior - os.environ["ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS"] = "TensorrtExecutionProvider" - # import all test cases at global scope to make # them visible to python.unittest. globals().update(backend_test.enable_report().test_cases) @@ -196,6 +203,15 @@ def parse_args(): help="Only run tests that match this value. Matching is regex based, and '.*' is automatically appended", ) + parser.add_argument( + "--devices", + type=str, + choices=["CPU", "CUDA", "MIGRAPHX", "DNNL", "DML", "OPENVINO_GPU", "OPENVINO_CPU", "OPENVINO_NPU", "OPENVINO"], + nargs="+", # allows multiple values + default=["CPU"], # default to ["CPU"] if no input is given + help="Select one or more devices CPU, CUDA, MIGRAPHX, DNNL, DML, OPENVINO_GPU, OPENVINO_CPU, OPENVINO_NPU, OPENVINO", + ) + # parse just our args. python unittest has its own args and arg parsing, and that runs inside unittest.main() parsed, unknown = parser.parse_known_args() sys.argv = sys.argv[:1] + unknown @@ -206,5 +222,5 @@ def parse_args(): if __name__ == "__main__": args = parse_args() - create_backend_test(args.test_name) + create_backend_test(args.devices, args.test_name) unittest.main() diff --git a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py index 29680c98fb4de..2f8fb84c4c651 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py +++ b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py @@ -28,7 +28,6 @@ from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference import unittest -from pathlib import Path def unique_element(lst): @@ -41,6 +40,8 @@ def unique_element(lst): class TestSymbolicShapeInference(unittest.TestCase): def test_symbolic_shape_infer(self): + from pathlib import Path + cwd = os.getcwd() test_model_dir = os.path.join(cwd, "..", "models") for filename in Path(test_model_dir).rglob("*.onnx"): diff --git a/onnxruntime/test/python/onnxruntime_test_python_tvm.py b/onnxruntime/test/python/onnxruntime_test_python_tvm.py deleted file mode 100644 index 0080bf53520f2..0000000000000 --- a/onnxruntime/test/python/onnxruntime_test_python_tvm.py +++ /dev/null @@ -1,242 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -""" -Module for unit testing of TVM EP -""" - -import os -import sys -import tempfile -import unittest -from typing import Any, AnyStr, Dict, List, Tuple - -import numpy -import tvm -from numpy.testing import assert_almost_equal -from onnx import ModelProto, TensorProto, mapping -from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info - -import onnxruntime - -numpy.random.seed(32) - - -def is_windows(): - """ - Function to determine the Windows system - """ - return sys.platform.startswith("win") - - -def get_model_with_dynamic_shapes() -> ModelProto: - """ - Create model with Dynamic Shapes - """ - x = make_tensor_value_info("X", TensorProto.FLOAT, [None, None]) # pylint: disable=invalid-name, no-member - a = make_tensor_value_info("A", TensorProto.FLOAT, [None, None]) # pylint: disable=invalid-name, no-member - b = make_tensor_value_info("B", TensorProto.FLOAT, [None, None]) # pylint: disable=invalid-name, no-member - y = make_tensor_value_info("Y", TensorProto.FLOAT, [None, None]) # pylint: disable=invalid-name, no-member - node1 = make_node("MatMul", ["X", "A"], ["XA"]) - node2 = make_node("Add", ["XA", "B"], ["Y"]) - graph = make_graph([node1, node2], "lr", [x, a, b], [y]) - onnx_model = make_model(graph) - return onnx_model - - -def get_model_with_fixed_shapes() -> ModelProto: - """ - Create model with Static Shapes - """ - - def change_input_shape(model: ModelProto, ind: int, shape: Tuple) -> None: - """ - Function to change the input form - """ - dims = model.graph.input[ind].type.tensor_type.shape.dim - assert len(dims) == len(shape), "Input rank and new shape rank do not match." - for i, new_dim in enumerate(shape): - model.graph.input[ind].type.tensor_type.shape.dim[i].dim_value = new_dim - - dynamic_model = get_model_with_dynamic_shapes() - change_input_shape(dynamic_model, 0, (1, 2)) # X - change_input_shape(dynamic_model, 1, (2, 2)) # A - change_input_shape(dynamic_model, 2, (1, 2)) # B - return dynamic_model - - -def get_input_data_for_model_with_dynamic_shapes() -> Dict[AnyStr, numpy.ndarray]: - """ - Create input data for model with dynamic shapes - """ - a = numpy.random.randn(2, 2).astype(numpy.float32) # pylint: disable=invalid-name - b = numpy.random.randn(1, 2).astype(numpy.float32) # pylint: disable=invalid-name - x = numpy.random.randn(1, 2).astype(numpy.float32) # pylint: disable=invalid-name - data = {"A": a, "B": b, "X": x} - return data - - -def get_input_data_for_model_with_fixed_shapes(onnx_model: ModelProto) -> Dict[AnyStr, numpy.ndarray]: - """ - Create input data for model with static shapes - """ - - def get_onnx_input_names(model: ModelProto) -> List[AnyStr]: - inputs = [node.name for node in model.graph.input] - initializer = [node.name for node in model.graph.initializer] - inputs = list(set(inputs) - set(initializer)) - return sorted(inputs) - - def get_onnx_input_types(model: ModelProto) -> List[numpy.dtype]: - input_names = get_onnx_input_names(model) - return [ - mapping.TENSOR_TYPE_TO_NP_TYPE[node.type.tensor_type.elem_type] - for node in sorted(model.graph.input, key=lambda node: node.name) - if node.name in input_names - ] - - def get_onnx_input_shapes(model: ModelProto) -> List[List[int]]: - input_names = get_onnx_input_names(model) - return [ - [dv.dim_value for dv in node.type.tensor_type.shape.dim] - for node in sorted(model.graph.input, key=lambda node: node.name) - if node.name in input_names - ] - - input_names = get_onnx_input_names(onnx_model) - input_shapes = get_onnx_input_shapes(onnx_model) - input_types = get_onnx_input_types(onnx_model) - assert len(input_names) == len(input_types) == len(input_shapes) - random_inputs = [numpy.random.uniform(size=shape).astype(dtype) for shape, dtype in zip(input_shapes, input_types)] - return dict(zip(input_names, random_inputs)) - - -def get_input_names_and_shapes(data: Dict[AnyStr, numpy.ndarray]) -> Tuple[List[AnyStr], List[AnyStr]]: - """ - Create text representations for model input names and shapes - """ - keys = list(data.keys()) - values = [data[key] for key in keys] - return ( - list(data.keys()), - [str(value.shape).replace(",", "").replace("(", "[").replace(")", "]") for value in values], - ) - - -def get_cpu_output(onnx_model: ModelProto, data: Dict[AnyStr, numpy.ndarray]) -> List[numpy.ndarray]: - """ - Run inference with CPUExecutionProvider - """ - # pylint: disable=no-member - sess = onnxruntime.InferenceSession( - onnx_model.SerializeToString(), - providers=["CPUExecutionProvider"], - ) - output = sess.run(None, data) - return output - - -def get_tvm_output( - onnx_model: ModelProto, data: Dict[AnyStr, numpy.ndarray], provider_options: Dict[AnyStr, Any] -) -> List[numpy.ndarray]: - """ - Run inference with TVMExecutionProvider - """ - session_options = onnxruntime.SessionOptions() # pylint: disable=no-member - session_options.log_severity_level = 0 - session_options.log_verbosity_level = 0 - # pylint: disable=no-member - session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL - - sess = onnxruntime.InferenceSession( - onnx_model.SerializeToString(), - session_options, - providers=["TvmExecutionProvider"], - provider_options=[provider_options], - ) - - output = sess.run(None, data) - return output - - -# pylint: disable=no-member -def compile_virtual_machine(model: ModelProto, target_str: AnyStr) -> tvm.runtime.vm.Executable: - """ - Compile ONNX model using VirtualMachine - """ - ir_mod, _ = tvm.relay.frontend.from_onnx( - model, - opset=model.opset_import[0].version, - freeze_params=True, - ) - target = tvm.target.Target(target=target_str, host=target_str) - return tvm.relay.backend.vm.compile(ir_mod, target) - - -def serialize_virtual_machine(vm_exec: tvm.runtime.vm.Executable) -> AnyStr: - """ - Serialize VirtualMachine - """ - temp_directory = tempfile.mkdtemp() - path_consts = os.path.join(temp_directory, "consts") - vm_exec.move_late_bound_consts(path_consts, byte_limit=256) - lib_path = os.path.join(temp_directory, f"model.{'dll' if is_windows() else 'so'}") - code_path = os.path.join(temp_directory, "model.ro") - code, lib = vm_exec.save() - lib.export_library(lib_path) - with open(code_path, "wb") as code_file: - code_file.write(code) - return temp_directory - - -class TestTVM(unittest.TestCase): - """ - Unit tests for TVM EP - """ - - @staticmethod - def test_accuracy_for_model_with_dynamic_shapes(): - """ - Accuracy test for model with dynamic shapes - """ - onnx_model = get_model_with_dynamic_shapes() - data = get_input_data_for_model_with_dynamic_shapes() - - cpu_output = get_cpu_output(onnx_model, data) - names, shapes = get_input_names_and_shapes(data) - provider_options = dict( - target="llvm", - input_names=" ".join(names), - input_shapes=" ".join(shapes), - ) - tvm_output = get_tvm_output(onnx_model, data, provider_options) - - assert_almost_equal(cpu_output, tvm_output, decimal=5) - - @staticmethod - def test_accuracy_for_tvm_so(): - """ - Accuracy test for TVMso Ep - """ - onnx_model = get_model_with_fixed_shapes() - data = get_input_data_for_model_with_fixed_shapes(onnx_model) - - cpu_output = get_cpu_output(onnx_model, data) - - compiled_vm_exec = compile_virtual_machine(onnx_model, target_str="llvm") - so_folder = serialize_virtual_machine(compiled_vm_exec) - provider_options = dict( - target="llvm", - so_folder=so_folder, - ) - tvm_output = get_tvm_output(onnx_model, data, provider_options) - - assert_almost_equal(cpu_output, tvm_output, decimal=5) - - -if __name__ == "__main__": - if "TvmExecutionProvider" not in onnxruntime.get_available_providers(): - raise AssertionError(f"Unable to find 'TvmExecutionProvider' in {onnxruntime.get_available_providers()}") - unittest.main() diff --git a/onnxruntime/test/python/quantization/op_test_utils.py b/onnxruntime/test/python/quantization/op_test_utils.py index cf7fc292ea86b..82193d08684c6 100644 --- a/onnxruntime/test/python/quantization/op_test_utils.py +++ b/onnxruntime/test/python/quantization/op_test_utils.py @@ -1,3 +1,10 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from __future__ import annotations + import uuid from pathlib import Path @@ -661,3 +668,29 @@ def generate_random_initializer(initializer_name, tensor_shape, tensor_dtype, me tensor = np.random.normal(mean, dev, tensor_shape).astype(tensor_dtype) init = onnx.numpy_helper.from_array(tensor, initializer_name) return init + + +def get_tensor_consumers_and_producers( + model: onnx.ModelProto, +) -> tuple[dict[str, list[onnx.NodeProto]], dict[str, onnx.NodeProto]]: + """ + Returns a tuple containing the following python dictionaries: + - consumers: maps a tensor name to the list of nodes that have that tensor as an input. + - producers: maps a tensor name to the node that generates this tensor as an output. + """ + consumers: dict[str, list[onnx.NodeProto]] = {} + producers: dict[str, onnx.NodeProto] = {} + for node in model.graph.node: + # Iterate through node's inputs to build the consumers dictionary. + for input_name in node.input: + if input_name: + if input_name not in consumers: + consumers[input_name] = [] + + consumers[input_name].append(node) + + # Iterate through node's outputs to build the producers dictionary. + for output_name in node.output: + producers[output_name] = node + + return (consumers, producers) diff --git a/onnxruntime/test/python/quantization/test_get_qdq_config.py b/onnxruntime/test/python/quantization/test_get_qdq_config.py new file mode 100644 index 0000000000000..58d00272475cd --- /dev/null +++ b/onnxruntime/test/python/quantization/test_get_qdq_config.py @@ -0,0 +1,278 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from __future__ import annotations + +import os +import tempfile +import unittest + +import numpy as np +import onnx +from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count + +from onnxruntime.quantization import CalibrationMethod, QuantFormat, QuantType, get_qdq_config, quantize + + +class TestGetQDQConfig(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.int_qdq_config_") + + # Note: swap with the commented line if you want to see the models in local test dir. + cls._tmp_dir_path = cls._tmp_model_dir.name + # cls._tmp_dir_path = "." + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def build_add_model( + self, + shape: list[int], + tensor_type: onnx.TensorProto.DataType, + weight: onnx.TensorProto | None = None, + opset: int = 21, + ) -> onnx.ModelProto: + """ + Returns an onnx.ModelProto with a single Add operator. The second input can be optionally made + a static weight. + """ + graph_inputs = [onnx.helper.make_tensor_value_info("input_0", tensor_type, shape)] + graph_outputs = [onnx.helper.make_tensor_value_info("output_0", tensor_type, shape)] + initializers = [] + add_input_names = ["input_0"] + + if weight is not None: + initializers.append(weight) + add_input_names.append(weight.name) + else: + graph_inputs.append(onnx.helper.make_tensor_value_info("input_1", tensor_type, shape)) + add_input_names.append("input_1") + + add_node = onnx.helper.make_node("Add", add_input_names, ["output_0"], name="Add0") + + graph = onnx.helper.make_graph( + [add_node], + "AddGraph", + graph_inputs, + graph_outputs, + initializer=initializers, + ) + opset_imports = [onnx.helper.make_opsetid("", opset)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + onnx.checker.check_model(model, True) + return model + + def test_basic_args(self): + """ + Test that get_qdq_config() returns a config that sets the basic args. + """ + + shape = [1, 8, 8] + tensor_type = onnx.TensorProto.FLOAT + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor_type) + weight = onnx.numpy_helper.from_array(np.ones(shape, dtype=np_dtype), "weight") + float_model = self.build_add_model(shape, tensor_type, weight, opset=21) + + input_data_list = [ + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(-2, dtype=np_dtype)}, + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(2, dtype=np_dtype)}, + ] + data_reader = TestDataFeeds(input_data_list) + + qdq_config = get_qdq_config( + float_model, + data_reader, + calibrate_method=CalibrationMethod.Percentile, + calibrate_args={"percentile": 99.98}, # Converted to extra_options + activation_type=QuantType.QUInt16, + weight_type=QuantType.QInt16, + per_channel=True, + reduce_range=True, + nodes_to_exclude=["Mul"], + # Other options converted to extra_options: + min_real_range=0.0001, + keep_removable_activations=True, + activation_symmetric=True, + weight_symmetric=True, + ) + self.assertEqual(qdq_config.calibrate_method, CalibrationMethod.Percentile) + self.assertEqual(qdq_config.activation_type, QuantType.QUInt16) + self.assertEqual(qdq_config.weight_type, QuantType.QInt16) + self.assertTrue(qdq_config.per_channel) + self.assertTrue(qdq_config.reduce_range) + self.assertEqual(set(qdq_config.nodes_to_exclude), {"Mul"}) + self.assertEqual(set(qdq_config.op_types_to_quantize), {"Add"}) + + # Check that calibration args are translated to extra_options. + self.assertEqual(qdq_config.extra_options["CalibPercentile"], 99.98) + + # Check that other args are also translated to extra_options. + self.assertEqual(qdq_config.extra_options["MinimumRealRange"], 0.0001) + self.assertTrue(qdq_config.extra_options["QDQKeepRemovableActivations"]) + self.assertTrue(qdq_config.extra_options["ActivationSymmetric"]) + self.assertTrue(qdq_config.extra_options["WeightSymmetric"]) + + # The following options should always be set to specific values. + self.assertTrue(qdq_config.extra_options["ForceQuantizeNoInputCheck"]) + self.assertEqual(qdq_config.quant_format, QuantFormat.QDQ) + + # Should use onnx domain Q/DQ ops because onnx opset >= 21. + self.assertFalse(qdq_config.extra_options.get("UseQDQContribOps", False)) + + def test_exclude_nodes_callable(self): + """ + Test passing a function/callable to exclude nodes from quantization. + """ + + shape = [1, 8, 8] + tensor_type = onnx.TensorProto.FLOAT + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor_type) + weight = onnx.numpy_helper.from_array(np.ones(shape, dtype=np_dtype), "weight") + float_model = self.build_add_model(shape, tensor_type, weight, opset=21) + + input_data_list = [ + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(-2, dtype=np_dtype)}, + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(2, dtype=np_dtype)}, + ] + data_reader = TestDataFeeds(input_data_list) + + # Local function that excludes all "Add" nodes. + def should_exclude_node_(model: onnx.ModelProto, node: onnx.NodeProto) -> bool: + return node.op_type == "Add" + + qdq_config = get_qdq_config( + float_model, + data_reader, + nodes_to_exclude=should_exclude_node_, + ) + + expected_excluded_nodes = set([node.name for node in float_model.graph.node if node.op_type == "Add"]) + self.assertTrue(bool(expected_excluded_nodes)) + self.assertEqual(set(qdq_config.nodes_to_exclude), expected_excluded_nodes) + + def test_external_data(self): + """ + Test that get_qdq_config() returns a config that enables external data + if the input model has external data. + """ + + # Create model with a weight large enough (> 1024 bytes) to be stored externally. + shape = [1, 32, 32] + tensor_type = onnx.TensorProto.FLOAT + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor_type) + large_weight = onnx.numpy_helper.from_array(np.ones(shape, dtype=np_dtype), "weight") + float_model = self.build_add_model(shape, tensor_type, large_weight) + float_model_path = os.path.join(self._tmp_dir_path, "add_ext_data_int_qdq_config.onnx") + + onnx.save_model( + float_model, + float_model_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location="add_ext_data_int_qdq_config.bin", + ) + + input_data_list = [ + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(-2, dtype=np_dtype)}, + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(0, dtype=np_dtype)}, + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(2, dtype=np_dtype)}, + ] + data_reader = TestDataFeeds(input_data_list) + + # Create a quantization config and check that it sets boolean to use external data + qdq_config = get_qdq_config( + float_model_path, data_reader, activation_type=QuantType.QUInt8, weight_type=QuantType.QInt8 + ) + self.assertEqual(set(qdq_config.op_types_to_quantize), {"Add"}) + self.assertTrue(qdq_config.use_external_data_format) + + # Quantize the model and check computational correctness against float model. + qdq_model_path = os.path.join(self._tmp_dir_path, "add_ext_data_int_qdq_config.qdq.onnx") + quantize(float_model_path, qdq_model_path, qdq_config) + + expected_op_counts = {"DequantizeLinear": 3, "QuantizeLinear": 2, "Add": 1} + check_op_type_count(self, qdq_model_path, **expected_op_counts) + + data_reader.rewind() + check_model_correctness(self, float_model_path, qdq_model_path, data_reader.get_next()) + + # The quantized weight should still be stored in an external file. + qdq_model = onnx.load_model(qdq_model_path, load_external_data=False) + weight_quantized = next( + ( + initializer + for initializer in qdq_model.graph.initializer + if initializer.name == f"{large_weight.name}_quantized" + ), + None, + ) + self.assertIsNotNone(weight_quantized) + self.assertEqual(weight_quantized.data_location, onnx.TensorProto.EXTERNAL) + + def test_use_qdq_contrib_ops_for_int16_opset19(self): + """ + Test that get_qdq_config() returns a config that forces 'com.microsoft' Q/DQ ops for + use of int16 in opset < 21. + """ + + shape = [1, 8, 8] + tensor_type = onnx.TensorProto.FLOAT + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor_type) + weight = onnx.numpy_helper.from_array(np.ones(shape, dtype=np_dtype), "weight") + float_model = self.build_add_model(shape, tensor_type, weight, opset=19) + + input_data_list = [ + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(-2, dtype=np_dtype)}, + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(2, dtype=np_dtype)}, + ] + data_reader = TestDataFeeds(input_data_list) + + qdq_config = get_qdq_config( + float_model, + data_reader, + activation_type=QuantType.QUInt16, + weight_type=QuantType.QInt8, + ) + + self.assertEqual(qdq_config.activation_type, QuantType.QUInt16) + self.assertTrue(qdq_config.extra_options["UseQDQContribOps"]) + + def test_use_qdq_contrib_ops_for_int4_opset19(self): + """ + Test that get_qdq_config() returns a config that forces 'com.microsoft' Q/DQ ops for + use of int4 in opset < 21. + """ + + shape = [1, 8, 8] + tensor_type = onnx.TensorProto.FLOAT + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor_type) + weight = onnx.numpy_helper.from_array(np.ones(shape, dtype=np_dtype), "weight") + float_model = self.build_add_model(shape, tensor_type, weight, opset=19) + + input_data_list = [ + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(-2, dtype=np_dtype)}, + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(2, dtype=np_dtype)}, + ] + data_reader = TestDataFeeds(input_data_list) + + # Use int4 in tensor quantization overrides. This should still force use of 'com.microsoft' Q/DQ ops. + qdq_config = get_qdq_config( + float_model, + data_reader, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QInt8, + tensor_quant_overrides={"weight": [{"quant_type": QuantType.QInt4}]}, + ) + + self.assertEqual(qdq_config.extra_options["TensorQuantOverrides"]["weight"][0]["quant_type"], QuantType.QInt4) + self.assertTrue(qdq_config.extra_options["UseQDQContribOps"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/quantization/test_op_pad.py b/onnxruntime/test/python/quantization/test_op_pad.py index 291bf42405d58..755c7fae5e3e8 100644 --- a/onnxruntime/test/python/quantization/test_op_pad.py +++ b/onnxruntime/test/python/quantization/test_op_pad.py @@ -4,14 +4,23 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +from __future__ import annotations import itertools +import os +import tempfile import unittest import numpy as np import onnx from onnx import TensorProto, helper -from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type +from op_test_utils import ( + TestDataFeeds, + check_model_correctness, + check_op_type_count, + check_qtype_by_node_type, + get_tensor_consumers_and_producers, +) from onnxruntime.quantization import QuantFormat, QuantType, quantize_dynamic, quantize_static @@ -519,5 +528,160 @@ def test_pad_with_empty_string_input_name(self): self.assertNotEqual(name, "_quantized") +class TestQDQPad(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.qdq.pad_") + + # Note: swap with the commented line if you want to see the models in local test dir. + cls._tmp_dir_path = cls._tmp_model_dir.name + # cls._tmp_dir_path = "." + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def build_pad_model( + self, + mode: str, + constant_value: float | None = None, + opset: int = 21, + float_type: onnx.TensorProto.DataType = onnx.TensorProto.FLOAT, + ) -> onnx.ModelProto: + num_pads_start = 1 + input_0 = onnx.helper.make_tensor_value_info("input_0", float_type, (3, 2)) + output_0 = onnx.helper.make_tensor_value_info("output_0", float_type, (3, 2 + num_pads_start)) + + initializers = [] + pad_input_names = ["input_0"] + attrs = {"mode": mode} + + pads_data = np.array([0, num_pads_start, 0, 0], dtype=np.int64) # Pad one val at beginning of axis 1. + if opset >= 11: + initializers.append(onnx.numpy_helper.from_array(pads_data, "pads")) + pad_input_names.append("pads") + else: + attrs["pads"] = pads_data.tolist() + + if mode == "constant" and constant_value is not None: + if opset >= 11: + initializers.append(onnx.helper.make_tensor("constant_value", float_type, [], [constant_value])) + pad_input_names.append("constant_value") + else: + attrs["value"] = float(constant_value) + + pad_node = onnx.helper.make_node("Pad", pad_input_names, ["output_0"], name="Pad0", **attrs) + + graph = onnx.helper.make_graph( + [pad_node], + "PadFloat", + [input_0], + [output_0], + initializer=initializers, + ) + opset_imports = [onnx.helper.make_opsetid("", opset)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + onnx.checker.check_model(model, True) + return model + + def test_qdq_pad_qparams(self): + """ + Test that QDQ Pad has equal scale/zero-point for its input and output for certain configurations. + """ + test_configs = [ + # Opset 21 + ("constant", None, 21, onnx.TensorProto.FLOAT), + ("constant", None, 21, onnx.TensorProto.FLOAT16), + ("constant", 0, 21, onnx.TensorProto.FLOAT), + ("constant", 0, 21, onnx.TensorProto.FLOAT16), + ("constant", 10.0, 21, onnx.TensorProto.FLOAT), + ("constant", 10.0, 21, onnx.TensorProto.FLOAT16), + ("reflect", None, 21, onnx.TensorProto.FLOAT), + ("reflect", None, 21, onnx.TensorProto.FLOAT16), + ("edge", None, 21, onnx.TensorProto.FLOAT), + ("edge", None, 21, onnx.TensorProto.FLOAT16), + ("wrap", None, 21, onnx.TensorProto.FLOAT), + ("wrap", None, 21, onnx.TensorProto.FLOAT16), + # Model with opset 10 will use pad of opset 2, which uses attributes instead of inputs. + # Opset 10 Q/DQ ops don't support float16. + ("constant", None, 10, onnx.TensorProto.FLOAT), + ("constant", 0, 10, onnx.TensorProto.FLOAT), + ("constant", 10.0, 10, onnx.TensorProto.FLOAT), + ("reflect", None, 10, onnx.TensorProto.FLOAT), + ("edge", None, 10, onnx.TensorProto.FLOAT), + ] + + for pad_mode, constant_value, opset, float_type in test_configs: + with self.subTest(pad_mode=pad_mode, constant_value=constant_value, opset=opset, float_type=float_type): + label = f"_{pad_mode}_{constant_value}_opset{opset}_{onnx.TensorProto.DataType.Name(float_type)}" + float_model_path = os.path.join(self._tmp_dir_path, f"pad{label}.float.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, f"pad{label}.qdq.onnx") + + float_model = self.build_pad_model(pad_mode, constant_value, opset=opset, float_type=float_type) + onnx.save_model(float_model, float_model_path) + + # Create a data reader + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(float_type) + input_data_list = [ + {"input_0": np.array([[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]], dtype=np_dtype)}, + {"input_0": np.array([[2.3, 3.4], [4.5, 5.7], [1.0, 1.2]], dtype=np_dtype)}, + ] + data_reader = TestDataFeeds(input_data_list) + + # quantize model to QDQ + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QInt8, + ) + + expected_op_counts = {"DequantizeLinear": 2, "QuantizeLinear": 2, "Pad": 1} + if constant_value is not None and opset >= 11: + expected_op_counts["DequantizeLinear"] += 1 # The constant padding value is quantized. + check_op_type_count(self, qdq_model_path, **expected_op_counts) + + if pad_mode != "reflect": + # Do not check model correctness for 'reflect' mode because ONNX Runtime implementation does + # not match the ONNX reference implementation. See the following issue: + # https://github.com/microsoft/onnxruntime/issues/20801 + data_reader.rewind() + check_model_correctness(self, float_model_path, qdq_model_path, data_reader.get_next()) + + qdq_model = onnx.load_model(qdq_model_path) + quant_output_same_as_input = False + + if pad_mode in ("reflect", "edge", "wrap"): + quant_output_same_as_input = True + + if pad_mode == "constant" and constant_value in (None, 0): + quant_output_same_as_input = True + + pad_node = next((node for node in qdq_model.graph.node if node.op_type == "Pad"), None) + self.assertNotEqual(pad_node, None) + self.assertEqual(pad_node.op_type, "Pad") + + # Get the parent and child nodes of the Pad and check that they are DQ/Q. + consumers, producers = get_tensor_consumers_and_producers(qdq_model) + input_dq_node = producers.get(pad_node.input[0], None) + self.assertNotEqual(input_dq_node, None) + self.assertEqual(input_dq_node.op_type, "DequantizeLinear") + + output_q_node = consumers.get(pad_node.output[0], [None])[0] + self.assertNotEqual(output_q_node, None) + self.assertEqual(output_q_node.op_type, "QuantizeLinear") + + # Check that the Pad's input DQ uses the same scale/zp as the Pad's output Q. + if quant_output_same_as_input: + self.assertEqual(input_dq_node.input[1], output_q_node.input[1]) # Same scale + self.assertEqual(input_dq_node.input[2], output_q_node.input[2]) # Same zero-point + else: + self.assertNotEqual(input_dq_node.input[1], output_q_node.input[1]) + self.assertNotEqual(input_dq_node.input[2], output_q_node.input[2]) + + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/quantization/test_op_slice.py b/onnxruntime/test/python/quantization/test_op_slice.py new file mode 100644 index 0000000000000..bfb9fc6b46bbd --- /dev/null +++ b/onnxruntime/test/python/quantization/test_op_slice.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from __future__ import annotations + +import os +import tempfile +import unittest + +import numpy as np +import onnx +from op_test_utils import ( + TestDataFeeds, + check_model_correctness, + check_op_type_count, + get_tensor_consumers_and_producers, +) + +from onnxruntime.quantization import QuantFormat, QuantType, quantize_static + + +class TestQDQSlice(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.qdq.slice_") + + # Note: swap with the commented line if you want to see the models in local test dir. + cls._tmp_dir_path = cls._tmp_model_dir.name + # cls._tmp_dir_path = "." + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def build_slice_model( + self, + input_shape: list[int], + input_tensor_type: onnx.TensorProto.DataType, + starts: list[int], + ends: list[int], + axes: list[int] | None = None, + steps: list[int] | None = None, + ) -> onnx.ModelProto: + """ + Returns an onnx.ModelProto with a single Slice operator. + """ + input_0 = onnx.helper.make_tensor_value_info("input_0", input_tensor_type, input_shape) + output_0 = onnx.helper.make_tensor_value_info("output_0", input_tensor_type, None) + + initializers = [ + onnx.numpy_helper.from_array(np.array(starts, dtype=np.int64), "starts"), + onnx.numpy_helper.from_array(np.array(ends, dtype=np.int64), "ends"), + ] + slice_input_names = ["input_0", "starts", "ends"] + + if axes: + initializers.append(onnx.numpy_helper.from_array(np.array(axes, dtype=np.int64), "axes")) + slice_input_names.append("axes") + + if steps: + if not axes: + slice_input_names.append("") # Empty axes input. + initializers.append(onnx.numpy_helper.from_array(np.array(steps, dtype=np.int64), "steps")) + slice_input_names.append("steps") + + slice_node = onnx.helper.make_node("Slice", slice_input_names, ["output_0"], name="Slice0") + + graph = onnx.helper.make_graph( + [slice_node], + "SliceGraph", + [input_0], + [output_0], + initializer=initializers, + ) + opset_imports = [onnx.helper.make_opsetid("", 21)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + onnx.checker.check_model(model, True) + return model + + def test_qdq_slice_qparams(self): + """ + Test that QDQ Slice has equal scale/zero-point for its input and output. + """ + test_configs = [onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT16] + + for onnx_tensor_type in test_configs: + with self.subTest(onnx_tensor_type=onnx_tensor_type): + label = f"{onnx.TensorProto.DataType.Name(onnx_tensor_type)}" + float_model_path = os.path.join(self._tmp_dir_path, f"slice.{label}.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, f"slice.{label}.qdq.onnx") + + input_shape = [2, 4] + float_model = self.build_slice_model( + input_shape=input_shape, + input_tensor_type=onnx_tensor_type, + starts=[1, 0], + ends=[2, 3], + axes=None, + steps=[1, 2], + ) + onnx.save_model(float_model, float_model_path) + + # Create a data reader + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(onnx_tensor_type) + input_data_list = [ + {"input_0": np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], dtype=np_dtype)}, + {"input_0": np.array([[-1.0, -2.0, -3.0, -4.0], [-5.0, -6.0, -7.0, -8.0]], dtype=np_dtype)}, + ] + data_reader = TestDataFeeds(input_data_list) + + # quantize model to QDQ + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QInt8, + extra_options={"ForceQuantizeNoInputCheck": True}, + ) + expected_op_counts = {"DequantizeLinear": 2, "QuantizeLinear": 2, "Slice": 1} + check_op_type_count(self, qdq_model_path, **expected_op_counts) + + data_reader.rewind() + check_model_correctness(self, float_model_path, qdq_model_path, data_reader.get_next()) + + qdq_model = onnx.load_model(qdq_model_path) + + slice_node = next((node for node in qdq_model.graph.node if node.op_type == "Slice"), None) + self.assertNotEqual(slice_node, None) + self.assertEqual(slice_node.op_type, "Slice") + + # Get the parent and child nodes of the Slice and check that they are DQ/Q. + consumers, producers = get_tensor_consumers_and_producers(qdq_model) + input_dq_node = producers.get(slice_node.input[0], None) + self.assertNotEqual(input_dq_node, None) + self.assertEqual(input_dq_node.op_type, "DequantizeLinear") + + output_q_node = consumers.get(slice_node.output[0], [None])[0] + self.assertNotEqual(output_q_node, None) + self.assertEqual(output_q_node.op_type, "QuantizeLinear") + + # Check that the Slice's input DQ uses the same scale/zp as the Slice's output Q. + self.assertEqual(input_dq_node.input[1], output_q_node.input[1]) + self.assertEqual(input_dq_node.input[2], output_q_node.input[2]) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/quantization/test_op_softmax.py b/onnxruntime/test/python/quantization/test_op_softmax.py index 3416198450137..e5bc6288c91e2 100644 --- a/onnxruntime/test/python/quantization/test_op_softmax.py +++ b/onnxruntime/test/python/quantization/test_op_softmax.py @@ -213,6 +213,40 @@ def test_quantize_softmax(self): self.quantize_softmax_test_qop(QuantType.QUInt8, QuantType.QUInt8) self.quantize_softmax_test_qdq(QuantType.QUInt8, QuantType.QUInt8) + def test_bug_fix_exclude_softmax(self): + """ + Test fix to bug that happens when softmax is excluded from quantization, but + the quantization tool still tries to assign it a tensor range of [0.0, 1.0]. + """ + np.random.seed(1) + model_fp32_path = "softmax_fp32.onnx" + model_qdq_path = "softmax_bug_exclude_softmax.qdq.onnx" + self.construct_model_conv_softmax( + model_fp32_path, + [1, 2, 26, 42], + [3, 2, 3, 3], + [1, 3, 24, 40], + {"axis": -2}, + [1, 3, 24, 40], + add_ms_domain_opset=False, + ) + data_reader = self.input_feeds(1, {"input": [1, 2, 26, 42]}) + data_reader.rewind() + + # Bug would cause an exception during quantization. + quantize_static( + model_fp32_path, + model_qdq_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QInt8, + nodes_to_exclude=["Softmax"], + ) + + qdq_model = onnx.load(Path(model_qdq_path)) + self.assertIn("Softmax", {node.op_type for node in qdq_model.graph.node}) + def test_quantize_softmax_s8s8(self): self.quantize_softmax_test_qop( QuantType.QInt8, diff --git a/onnxruntime/test/python/quantization/test_qdq.py b/onnxruntime/test/python/quantization/test_qdq.py index b99c11abf6d2c..23b397ffd80e1 100644 --- a/onnxruntime/test/python/quantization/test_qdq.py +++ b/onnxruntime/test/python/quantization/test_qdq.py @@ -20,10 +20,12 @@ check_op_type_count, check_op_type_order, create_clip_node, + get_tensor_consumers_and_producers, ) from onnxruntime.quantization import QDQQuantizer, QuantFormat, QuantType, quantize_static, write_calibration_table from onnxruntime.quantization.calibrate import CalibrationMethod, TensorData, TensorsData +from onnxruntime.quantization.quant_utils import quantize_nparray class TestQDQFormat(unittest.TestCase): @@ -1726,5 +1728,479 @@ def test_json_serialization(self): write_calibration_table(new_calibrate_tensors_range) +class TestAdjustWeightScaleForInt32Bias(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.qdq.adj_int32_bias_") + + # Note: swap with the commented line if you want to see the models in local test dir. + cls._tmp_dir_path = cls._tmp_model_dir.name + # cls._tmp_dir_path = "." + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def build_conv_test_model( + self, + input0_shape: list[int], + weight_shape: list[int], + onnx_float_type: onnx.TensorProto.DataType, + ): + np_float_type = onnx.helper.tensor_dtype_to_np_dtype(onnx_float_type) + input_0 = onnx.helper.make_tensor_value_info("input_0", onnx_float_type, input0_shape) + output_0 = onnx.helper.make_tensor_value_info("output_0", onnx_float_type, None) + + tiny_value = 1e-7 if np_float_type == np.float32 else 0.007782 + # weight_scale = 2*tiny_value / 255.0 = 7.84313725490196e-10 + + weight_data = np.full(weight_shape, tiny_value, dtype=np_float_type) + with np.nditer(weight_data, op_flags=["readwrite"]) as it: + for i, x in enumerate(it): + if i % 2 == 0: + x[...] = -x + + weight = onnx.numpy_helper.from_array(weight_data, "weight") + + # if we set input_scale to 0.05, then normally bias_scale would be + # (input_scale * weight_scale) => (0.05 * 7.84314e-10) => 3.9215686274509805e-11 + # + # If we quantize the f32 bias with this bias_scale, we get + # [5.0/bias_scale, 4.0/bias_scale] = [127500000000, 102000000000]. These quantized bias values exceed the + # range of int32. + # + # The ORT quantization tool will clamp these out-of-bounds values to int32::max(), + # which can be very inaccurate. + bias_shape = [weight_shape[0]] + bias_data = np.ones(bias_shape, dtype=np_float_type) + with np.nditer(bias_data, op_flags=["readwrite"]) as it: + for i, x in enumerate(it): + if i % 2 == 0: + x[...] = 5.0 if np_float_type == np.float32 else 1400 + else: + x[...] = -4.5 if np_float_type == np.float32 else -1200 + + bias = onnx.numpy_helper.from_array(bias_data, "bias") + + conv_node = onnx.helper.make_node("Conv", ["input_0", "weight", "bias"], ["output_0"], name="Conv0") + graph = onnx.helper.make_graph( + [conv_node], + "Convfloat", + [input_0], + [output_0], + initializer=[weight, bias], + ) + opset_imports = [onnx.helper.make_opsetid("", 21)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + onnx.checker.check_model(model, True) + return model + + def test_adjust_weight_scale_for_int32_bias(self): + """ + Test adjustment of weight input's scale to ensure int32 bias's scale is not too small. + """ + test_configs = [ + (onnx.TensorProto.FLOAT, True), + (onnx.TensorProto.FLOAT, False), + (onnx.TensorProto.FLOAT16, True), + (onnx.TensorProto.FLOAT16, False), + ] + + for float_type, per_channel in test_configs: + with self.subTest(float_type=float_type, per_channel=per_channel): + label = f"_f{float_type}_perchannel{per_channel}" + float_model_path = os.path.join(self._tmp_dir_path, f"conv{label}.float.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, f"conv{label}.qdq.onnx") + + # Create float model with a Conv that has tiny weight values. + # This tiny weight scale would normally create a very small bias scale that will saturate + # bias's int32 range. But, the qdq_quantizer adjusts the weight's scale to ensure this doesn't happen. + input0_shape = [1, 2, 4, 4] + weight_shape = [2, 2, 2, 2] + float_model = self.build_conv_test_model(input0_shape, weight_shape, float_type) + onnx.save_model(float_model, float_model_path) + + # Create a data reader + np_float_type = onnx.helper.tensor_dtype_to_np_dtype(float_type) + input0_rmin = 0.0 + input0_scale = 0.05 if float_type == onnx.TensorProto.FLOAT else 0.01 + input0_rmax = (input0_scale * 255.0) + input0_rmin + input_data_list = [ + {"input_0": np.full(input0_shape, input0_rmin, dtype=np_float_type)}, + {"input_0": np.full(input0_shape, (input0_rmax - input0_rmin) / 2.0, dtype=np_float_type)}, + {"input_0": np.full(input0_shape, input0_rmax, dtype=np_float_type)}, + ] + data_reader = TestDataFeeds(input_data_list) + + # quantize model to QDQ + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QInt8, + per_channel=per_channel, + ) + + # Check correctness + data_reader.rewind() + check_model_correctness(self, float_model_path, qdq_model_path, data_reader.get_next()) + + def build_model_convs_share_bias( + self, + input0_shape: list[int], + weight_shape: list[int], + onnx_float_type: onnx.TensorProto.DataType, + ): + np_float_type = onnx.helper.tensor_dtype_to_np_dtype(onnx_float_type) + input_0 = onnx.helper.make_tensor_value_info("input_0", onnx_float_type, input0_shape) + output_0 = onnx.helper.make_tensor_value_info("output_0", onnx_float_type, None) + output_1 = onnx.helper.make_tensor_value_info("output_1", onnx_float_type, None) + + weight_0_data = np.ones(weight_shape, dtype=np_float_type) + weight_0 = onnx.numpy_helper.from_array(weight_0_data, "weight_0") + + weight_1_data = np.full(weight_shape, 0.5, dtype=np_float_type) + weight_1 = onnx.numpy_helper.from_array(weight_1_data, "weight_1") + + bias_shape = [weight_shape[0]] + bias_data = np.ones(bias_shape, dtype=np_float_type) + bias_shared = onnx.numpy_helper.from_array(bias_data, "bias_shared") + + conv_0_node = onnx.helper.make_node("Conv", ["input_0", "weight_0", "bias_shared"], ["output_0"], name="Conv0") + conv_1_node = onnx.helper.make_node("Conv", ["input_0", "weight_1", "bias_shared"], ["output_1"], name="Conv1") + graph = onnx.helper.make_graph( + [conv_0_node, conv_1_node], + "ConvWithSharedBiasToDup", + [input_0], + [output_0, output_1], + initializer=[weight_0, weight_1, bias_shared], + ) + opset_imports = [onnx.helper.make_opsetid("", 21)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + onnx.checker.check_model(model, True) + return model + + def test_dup_shared_bias(self): + """ + Test duplicating a bias that is shared by two nodes that want to quantize their bias to int32. + """ + float_model_path = os.path.join(self._tmp_dir_path, "convs_share_bias.float.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, "convs_share_bias.qdq.onnx") + + # Create float model with a Convs that share a bias input. The QDQ quantizer should add a + # duplicate bias so that each node has its own. + input0_shape = [1, 2, 4, 4] + weight_shape = [2, 2, 2, 2] + float_model = self.build_model_convs_share_bias(input0_shape, weight_shape, onnx.TensorProto.FLOAT) + onnx.save_model(float_model, float_model_path) + + # Create a data reader + input0_rmin = 0.0 + input0_scale = 0.05 + input0_rmax = (input0_scale * 255.0) + input0_rmin + input_data_list = [ + {"input_0": np.full(input0_shape, input0_rmin, dtype=np.float32)}, + {"input_0": np.full(input0_shape, (input0_rmax - input0_rmin) / 2.0, dtype=np.float32)}, + {"input_0": np.full(input0_shape, input0_rmax, dtype=np.float32)}, + ] + data_reader = TestDataFeeds(input_data_list) + + # quantize model to QDQ + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QInt8, + ) + + qdq_model = onnx.load_model(qdq_model_path) + bias_names = set() + + for node in qdq_model.graph.node: + if node.op_type == "DequantizeLinear" and node.input[0].startswith("bias_shared"): + bias_names.add(node.input[0]) + + self.assertEqual(len(bias_names), 2) + + +class TestQDQPrequantWeights(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.qdq.prequant_weight") + + # Note: swap with the commented line if you want to see the models in local test dir. + cls._tmp_dir_path = cls._tmp_model_dir.name + # cls._tmp_dir_path = "." + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def build_conv_model( + self, + inp_shape: list[int], + weight_quant_data: np.ndarray, + weight_scale_data: np.ndarray, + weight_zp_data: np.ndarray, + bias_data: np.ndarray, + float_type: onnx.TensorProto.DataType = onnx.TensorProto.FLOAT, + ): + """ + Builds a model with a Conv that has a pre-quantized constant weight input. + """ + input_0 = onnx.helper.make_tensor_value_info("input_0", float_type, inp_shape) + output_0 = onnx.helper.make_tensor_value_info("output_0", float_type, None) + weight_quant = onnx.numpy_helper.from_array(weight_quant_data, "weight_quant") + weight_scale = onnx.numpy_helper.from_array(weight_scale_data, "weight_scale") + weight_zp = onnx.numpy_helper.from_array(weight_zp_data, "weight_zp") + bias = onnx.numpy_helper.from_array(bias_data, "bias") + + dq_node = onnx.helper.make_node( + "DequantizeLinear", ["weight_quant", "weight_scale", "weight_zp"], ["weight_dequant"], name="DQ0" + ) + conv_node = onnx.helper.make_node("Conv", ["input_0", "weight_dequant", "bias"], ["output_0"], name="Conv0") + graph = onnx.helper.make_graph( + [dq_node, conv_node], + "ConvPreQuantWeight", + [input_0], + [output_0], + initializer=[weight_quant, weight_scale, weight_zp, bias], + ) + opset_imports = [onnx.helper.make_opsetid("", 21)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + + return onnx.shape_inference.infer_shapes(model) + + def build_conv_dynamic_weight_model( + self, + input_quant_data: np.ndarray, + input_scale_data: np.ndarray, + input_zp_data: np.ndarray, + weight_shape: list[int], + bias_data: np.ndarray, + float_type: onnx.TensorProto.DataType = onnx.TensorProto.FLOAT, + ): + """ + Builds a model with a Conv that has a dynamic float weight input, but a constant + pre-quantized input[0]. + """ + dyn_weight = onnx.helper.make_tensor_value_info("dyn_weight", float_type, weight_shape) + output_0 = onnx.helper.make_tensor_value_info("output_0", float_type, None) + input_quant = onnx.numpy_helper.from_array(input_quant_data, "input_quant") + input_scale = onnx.numpy_helper.from_array(input_scale_data, "input_scale") + input_zp = onnx.numpy_helper.from_array(input_zp_data, "input_zp") + bias = onnx.numpy_helper.from_array(bias_data, "bias") + + dq_node = onnx.helper.make_node( + "DequantizeLinear", ["input_quant", "input_scale", "input_zp"], ["input_dequant"], name="DQ0" + ) + conv_node = onnx.helper.make_node("Conv", ["input_dequant", "dyn_weight", "bias"], ["output_0"], name="Conv0") + graph = onnx.helper.make_graph( + [dq_node, conv_node], + "ConvPreQuantInput_DynamicWeight", + [dyn_weight], + [output_0], + initializer=[input_quant, input_scale, input_zp, bias], + ) + opset_imports = [onnx.helper.make_opsetid("", 21)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + + return onnx.shape_inference.infer_shapes(model) + + def test_quantize_with_prequantized_weights(self): + """ + Test quantization of Conv with pre-quantized weights. + """ + rng = np.random.default_rng(123) + test_configs = [onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT16] + + for float_type in test_configs: + with self.subTest(float_type=float_type): + label = f"_{onnx.TensorProto.DataType.Name(float_type)}" + float_model_path = os.path.join(self._tmp_dir_path, f"conv.f32.prequant_weight{label}.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, f"conv.prequant_weight{label}.qdq.onnx") + + inp_shape = [1, 2, 100, 100] + weight_shape = [2, 2, 20, 20] + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(float_type) + + # range = 2.0, scale = 2/254, zp = 0 + weight_scale_data = np.array(2 / 254, dtype=np_dtype) + weight_zp_data = np.array(0, dtype=np.int8) + weight_data = np.linspace(-1.0, 1.0, num=1600, dtype=np_dtype).reshape(weight_shape) + weight_quant_data = quantize_nparray( + onnx.TensorProto.INT8, weight_data, weight_scale_data, weight_zp_data + ) + + bias_data = np.array([-10.0, 10.0], dtype=np_dtype) + float_model = self.build_conv_model( + inp_shape, weight_quant_data, weight_scale_data, weight_zp_data, bias_data, float_type + ) + + onnx.checker.check_model(float_model, True) + onnx.save_model(float_model, float_model_path) + + # Check that the input model only has a pre-quantized weight and save its scale/zero-point + # to check that it doesn't change after quantization. + float_node_counts = {"QuantizeLinear": 0, "DequantizeLinear": 1} + check_op_type_count(self, float_model_path, **float_node_counts) + conv_node_original = next((node for node in float_model.graph.node if node.op_type == "Conv"), None) + self.assertNotEqual(conv_node_original, None) + + _, producers_original = get_tensor_consumers_and_producers(float_model) + weight_dq_node_original = producers_original.get(conv_node_original.input[1], None) + initializers_original = {initializer.name: initializer for initializer in float_model.graph.initializer} + scale_name_original = weight_dq_node_original.input[1] + scale_val_original = onnx.numpy_helper.to_array(initializers_original[scale_name_original]) + zp_name_original = weight_dq_node_original.input[2] + zp_val_original = onnx.numpy_helper.to_array(initializers_original[zp_name_original]) + + input_data_list = [ + {"input_0": rng.uniform(-10.0, 10.0, inp_shape).astype(np_dtype)}, + ] + data_reader = TestDataFeeds(input_data_list) + + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QInt8, + op_types_to_quantize=["Conv"], + ) + + # The final model should have everything quantized + qdq_node_counts = {"QuantizeLinear": 2, "DequantizeLinear": 4} + check_op_type_count(self, qdq_model_path, **qdq_node_counts) + + # Check that the pre-quantized weight still has the same scale/zp after quantization + qdq_model = onnx.load_model(qdq_model_path) + conv_node = next((node for node in qdq_model.graph.node if node.op_type == "Conv"), None) + self.assertNotEqual(conv_node, None) + + _, producers = get_tensor_consumers_and_producers(qdq_model) + weight_dq_node = producers.get(conv_node.input[1], None) + initializers = {initializer.name: initializer for initializer in qdq_model.graph.initializer} + + scale_name = weight_dq_node.input[1] + self.assertEqual(scale_name, scale_name_original) + scale_val = onnx.numpy_helper.to_array(initializers[scale_name]) + self.assertEqual(scale_val, scale_val_original) + + zp_name = weight_dq_node.input[2] + self.assertEqual(zp_name, zp_name_original) + zp_val = onnx.numpy_helper.to_array(initializers[zp_name]) + self.assertEqual(zp_val, zp_val_original) + + def test_quantize_with_prequantized_input(self): + """ + Test quantization of Conv with pre-quantized input and dynamic weight. + """ + rng = np.random.default_rng(123) + test_configs = [ + (onnx.TensorProto.FLOAT, False), + (onnx.TensorProto.FLOAT16, False), + (onnx.TensorProto.FLOAT, True), + (onnx.TensorProto.FLOAT16, True), + ] + + for float_type, convert_weight_qtype in test_configs: + with self.subTest(float_type=float_type): + convert_label = "_convert_qtype" if convert_weight_qtype else "" + label = f"_{onnx.TensorProto.DataType.Name(float_type)}{convert_label}" + float_model_path = os.path.join(self._tmp_dir_path, f"conv.f32.prequant_input{label}.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, f"conv.prequant_input{label}.qdq.onnx") + + inp_shape = [1, 2, 40, 40] + weight_shape = [2, 2, 20, 20] + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(float_type) + + # range = 3.0, scale = 3/255, zp = 127 + input_scale_data = np.array(3 / 255, dtype=np_dtype) + input_zp_data = np.array(127, dtype=np.uint8) + input_data = np.linspace(-1.5, 1.5, num=3200, dtype=np_dtype).reshape(inp_shape) + input_quant_data = quantize_nparray(onnx.TensorProto.UINT8, input_data, input_scale_data, input_zp_data) + + bias_data = np.array([-10.0, 10.0], dtype=np_dtype) + float_model = self.build_conv_dynamic_weight_model( + input_quant_data, input_scale_data, input_zp_data, weight_shape, bias_data, float_type + ) + + onnx.checker.check_model(float_model, True) + onnx.save_model(float_model, float_model_path) + + # Check that the input model only has a pre-quantized input and save its scale/zero-point + # to check that it doesn't change after quantization. + float_node_counts = {"QuantizeLinear": 0, "DequantizeLinear": 1} + check_op_type_count(self, float_model_path, **float_node_counts) + conv_node_original = next((node for node in float_model.graph.node if node.op_type == "Conv"), None) + self.assertNotEqual(conv_node_original, None) + + _, producers_original = get_tensor_consumers_and_producers(float_model) + input_dq_node_original = producers_original.get(conv_node_original.input[0], None) + initializers_original = {initializer.name: initializer for initializer in float_model.graph.initializer} + scale_name_original = input_dq_node_original.input[1] + scale_val_original = onnx.numpy_helper.to_array(initializers_original[scale_name_original]) + zp_name_original = input_dq_node_original.input[2] + zp_val_original = onnx.numpy_helper.to_array(initializers_original[zp_name_original]) + + # Create data reader with random input calibration data. + dyn_weight_data_list = [ + {"dyn_weight": rng.uniform(-10.0, 10.0, weight_shape).astype(np_dtype)}, + ] + data_reader = TestDataFeeds(dyn_weight_data_list) + + extra_options = {} + if convert_weight_qtype: + # Test converting the dynamic weight's quantization type, which results in + # dyn_weight -> Q(u16) -> DQ(f32) -> Q(u8) -> DQ(f32) -> Conv + extra_options["TensorQuantOverrides"] = { + "dyn_weight": [{"quant_type": QuantType.QUInt16, "convert": {"quant_type": QuantType.QUInt8}}], + } + + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QInt8, + op_types_to_quantize=["Conv"], + extra_options=extra_options, + ) + + # The final model should have everything quantized + qdq_node_counts = {"QuantizeLinear": 2, "DequantizeLinear": 4} + if convert_weight_qtype: + qdq_node_counts["QuantizeLinear"] += 1 + qdq_node_counts["DequantizeLinear"] += 1 + + check_op_type_count(self, qdq_model_path, **qdq_node_counts) + + # Check that the pre-quantized input still has the same scale/zp after quantization + qdq_model = onnx.load_model(qdq_model_path) + conv_node = next((node for node in qdq_model.graph.node if node.op_type == "Conv"), None) + self.assertNotEqual(conv_node, None) + + _, producers = get_tensor_consumers_and_producers(qdq_model) + input_dq_node = producers.get(conv_node.input[0], None) + initializers = {initializer.name: initializer for initializer in qdq_model.graph.initializer} + + scale_name = input_dq_node.input[1] + self.assertEqual(scale_name, scale_name_original) + scale_val = onnx.numpy_helper.to_array(initializers[scale_name]) + self.assertEqual(scale_val, scale_val_original) + + zp_name = input_dq_node.input[2] + self.assertEqual(zp_name, zp_name_original) + zp_val = onnx.numpy_helper.to_array(initializers[zp_name]) + self.assertEqual(zp_val, zp_val_original) + + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/quantization/test_quant_util.py b/onnxruntime/test/python/quantization/test_quant_util.py index 96d841654adbd..b23d53f2a04e8 100644 --- a/onnxruntime/test/python/quantization/test_quant_util.py +++ b/onnxruntime/test/python/quantization/test_quant_util.py @@ -145,7 +145,7 @@ def test_quantize_data_4bit(self): for onnx_type, symmetric in subtest_configs: with self.subTest(onnx_type=onnx_type, symmetric=symmetric): - _, _, zero_point, scale, data_quant = quantize_data(data_float, onnx_type, symmetric) + zero_point, scale, data_quant = quantize_data(data_float, onnx_type, symmetric) is_signed = onnx_type == onnx.TensorProto.INT4 np_int_type = numpy.int8 if is_signed else numpy.uint8 qmin = numpy.array(-8 if is_signed else 0, dtype=np_int_type) diff --git a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py index 21a772c5f56c7..41dae04f1c6ff 100644 --- a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py +++ b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py @@ -36,7 +36,7 @@ def setUp(self): self.bias = np.array([0.0, 1.0], dtype=np.float32) self.default_act_qtype = onnx.TensorProto.UINT8 self.default_wgt_qtype = onnx.TensorProto.UINT8 - self.default_wgt_qtype_per_channel = onnx.TensorProto.INT8 + self.default_wgt_qtype_per_channel = onnx.TensorProto.UINT8 self.default_bias_qtype = onnx.TensorProto.INT32 self.default_zp_scales = { @@ -49,7 +49,8 @@ def setUp(self): self.default_zp_scales_per_channel = { "INP": (0, np.float32(0.0235294122248888)), "SIG_OUT": (0, np.float32(0.003911871928721666)), - "WGT": ([0, 0], [np.float32(0.015748031437397003), np.float32(0.011811023578047752)]), + # per-channel weights are always symmetric (ie. zp = (qmin + qmax) / 2) + "WGT": ([127, 127], [np.float32(0.015748031437397003), np.float32(0.011811023578047752)]), "BIAS": ([0, 0], [np.float32(0.00006160428165458143), np.float32(0.00004620321124093607)]), "OUT": (0, np.float32(0.005075461231172085)), } @@ -420,12 +421,17 @@ def test_qdq_overrides_per_channel2(self): self.assertEqual(wgt_zp.data_type, quant_type.tensor_type) for index, (zp, scale) in enumerate(zip(wgt_zp.int32_data, wgt_sc.float_data)): - wgt_qmin, wgt_qmax = get_qmin_qmax_for_qType(wgt_zp.data_type, reduce_range=reduce_range) + wgt_qmin, wgt_qmax = get_qmin_qmax_for_qType( + wgt_zp.data_type, + symmetric=True, # per-channel is always symmetric + reduce_range=reduce_range, + ) expected_zp, expected_scale = compute_scale_zp( np.array(rmin_vals[index], dtype=np.float32), np.array(rmax_vals[index], dtype=np.float32), wgt_qmin, wgt_qmax, + symmetric=True, # per-channel is always symmetric ) self.assertEqual(zp, expected_zp) self.assertEqual(scale, np.float32(expected_scale)) diff --git a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py index 46ab905977f48..a74d5389e9047 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py +++ b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py @@ -24,7 +24,7 @@ from parameterized import parameterized from test_gqa_cpu import smooth_softmax_ref -from onnxruntime import InferenceSession, OrtValue, SessionOptions +from onnxruntime import InferenceSession, OrtValue, SessionOptions, get_available_providers torch.manual_seed(0) @@ -1999,6 +1999,8 @@ def parity_check_gqa_past_no_buff( def has_flash_attention(): if not torch.cuda.is_available(): return False + if "CUDAExecutionProvider" not in get_available_providers(): + return False major, _ = torch.cuda.get_device_capability() return major >= 8 and ( platform.system() == "Linux" @@ -2009,6 +2011,8 @@ def has_flash_attention(): def has_memory_efficient(): if not torch.cuda.is_available(): return False + if "CUDAExecutionProvider" not in get_available_providers(): + return False major, minor = torch.cuda.get_device_capability() if major < 5 or (major == 5 and minor < 3): return False @@ -2047,8 +2051,8 @@ def mha_test_cases(): (2048, 2048), ] ) - num_h = [1, 3] if pipeline_mode else [1, 6, 16] - h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [3] if pipeline_mode else [1, 6, 16] + h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] for b in batches: for s, s2 in seqs: @@ -2080,11 +2084,7 @@ def gqa_no_past_memory_efficient_test_cases(): batches = [3] if pipeline_mode else [1, 3, 5] seqs = ( [ - (127, 127), - (35, 35), (2000, 2000), - (200, 200), - (240, 240), ] if pipeline_mode else [ @@ -2095,8 +2095,8 @@ def gqa_no_past_memory_efficient_test_cases(): (240, 240), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] torch.manual_seed(69) for b in batches: @@ -2121,10 +2121,6 @@ def gqa_no_past_flash_attention_test_cases(): batches = [3] if pipeline_mode else [1, 3, 5] seqs = ( [ - (127, 127), - (35, 35), - (2000, 2000), - (200, 200), (240, 240), ] if pipeline_mode @@ -2136,8 +2132,8 @@ def gqa_no_past_flash_attention_test_cases(): (240, 240), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] torch.manual_seed(69) for b in batches: @@ -2163,7 +2159,7 @@ def gqa_no_past_flash_attention_test_cases(): def gqa_past_memory_efficient_test_cases(): batches = [5] if pipeline_mode else [1, 3, 5] seqs = ( - [(1, 128), (1, 1024), (1, 2048)] + [(1, 1024)] if pipeline_mode else [ (1, 128), @@ -2179,8 +2175,8 @@ def gqa_past_memory_efficient_test_cases(): # (128, 128), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) for b in batches: @@ -2205,7 +2201,7 @@ def gqa_past_memory_efficient_test_cases(): def gqa_past_flash_attention_test_cases(): batches = [5] if pipeline_mode else [1, 3, 5] seqs = ( - [(1, 128), (1, 1024), (1, 2048)] + [(1, 2048)] if pipeline_mode else [ (1, 128), @@ -2221,8 +2217,8 @@ def gqa_past_flash_attention_test_cases(): # (128, 128), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) for b in batches: @@ -2249,7 +2245,7 @@ def gqa_past_flash_attention_test_cases(): def gqa_interactive_one_batch_flash_attention_test_cases(): batches = [1] seqs = ( - [(2, 128), (128, 129), (32, 128), (256, 2048)] + [(128, 2048)] if pipeline_mode else [ (1, 128), @@ -2265,8 +2261,8 @@ def gqa_interactive_one_batch_flash_attention_test_cases(): # (128, 128), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) for b in batches: @@ -2290,7 +2286,7 @@ def gqa_interactive_one_batch_flash_attention_test_cases(): def gqa_interactive_one_batch_memory_efficient_attention_test_cases(): batches = [1] seqs = ( - [(2, 128), (128, 129), (32, 128), (256, 2048)] + [(32, 128)] if pipeline_mode else [ (1, 128), @@ -2306,8 +2302,8 @@ def gqa_interactive_one_batch_memory_efficient_attention_test_cases(): # (128, 128), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) for b in batches: @@ -2326,120 +2322,114 @@ def gqa_interactive_one_batch_memory_efficient_attention_test_cases(): ) -class TestGQA(unittest.TestCase): - @parameterized.expand(gqa_no_past_memory_efficient_test_cases()) - def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap): - if not has_memory_efficient(): - return - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------") +@unittest.skipIf(not has_flash_attention(), reason="Flash Attention is not available, skipping tests.") +class TestFlashGQA(unittest.TestCase): + @parameterized.expand(gqa_no_past_flash_attention_test_cases()) + def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): + print("------- FLASH ATTENTION (PROMPT CASE) --------") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" parity_check_gqa_prompt( config, - rtol=5e-3, - atol=5e-3, + local=local, past_format=Formats.BNSH, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, softcap=softcap, - use_smooth_softmax=False, + use_smooth_softmax=True, ) parity_check_gqa_prompt_no_buff( config, - rtol=5e-3, - atol=5e-3, + local=local, past_format=Formats.BNSH, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, softcap=softcap, - use_smooth_softmax=True, + use_smooth_softmax=False, ) - @parameterized.expand(gqa_no_past_flash_attention_test_cases()) - def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): - if not has_flash_attention(): - return - print("------- FLASH ATTENTION (PROMPT CASE) --------") + @parameterized.expand(gqa_past_flash_attention_test_cases()) + def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): + print("------- FLASH ATTENTION (TOKEN GEN) -------") os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - parity_check_gqa_prompt( + parity_check_gqa_past( config, local=local, past_format=Formats.BNSH, + rtol=1e-3, + atol=1e-3, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, softcap=softcap, - use_smooth_softmax=True, + use_smooth_softmax=False, ) - parity_check_gqa_prompt_no_buff( + parity_check_gqa_past_no_buff( config, local=local, past_format=Formats.BNSH, + rtol=1e-3, + atol=1e-3, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, softcap=softcap, - use_smooth_softmax=False, + use_smooth_softmax=True, ) - @parameterized.expand(gqa_past_memory_efficient_test_cases()) - def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap): - if not has_memory_efficient(): - return - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - print("-------- MEMORY EFFICIENT (TOKEN GEN) --------") + @parameterized.expand(gqa_interactive_one_batch_flash_attention_test_cases()) + def test_gqa_interactive_one_batch_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed): + print("------- FLASH ATTENTION (INTERACTIVE) -------") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" parity_check_gqa_past( config, + local=local, past_format=Formats.BNSH, - rtol=1e-3, - atol=1e-3, + rtol=5e-3, + atol=5e-3, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, - softcap=softcap, - use_smooth_softmax=True, ) parity_check_gqa_past_no_buff( config, + local=local, past_format=Formats.BNSH, - rtol=1e-3, - atol=1e-3, + rtol=5e-3, + atol=5e-3, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, - softcap=softcap, - use_smooth_softmax=False, ) - @parameterized.expand(gqa_past_flash_attention_test_cases()) - def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): - if not has_flash_attention(): - return - print("------- FLASH ATTENTION (TOKEN GEN) -------") - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - parity_check_gqa_past( +@unittest.skipIf(not has_memory_efficient(), reason="Memory efficient FMHA is not available, skipping tests.") +class TestMemoryEfficientGQA(unittest.TestCase): + @parameterized.expand(gqa_no_past_memory_efficient_test_cases()) + def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap): + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------") + + parity_check_gqa_prompt( config, - local=local, + rtol=5e-3, + atol=5e-3, past_format=Formats.BNSH, - rtol=1e-3, - atol=1e-3, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, softcap=softcap, use_smooth_softmax=False, ) - parity_check_gqa_past_no_buff( + parity_check_gqa_prompt_no_buff( config, - local=local, + rtol=5e-3, + atol=5e-3, past_format=Formats.BNSH, - rtol=1e-3, - atol=1e-3, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, @@ -2447,38 +2437,36 @@ def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interle use_smooth_softmax=True, ) - @parameterized.expand(gqa_interactive_one_batch_flash_attention_test_cases()) - def test_gqa_interactive_one_batch_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed): - if not has_flash_attention(): - return - print("------- FLASH ATTENTION (INTERACTIVE) -------") - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + @parameterized.expand(gqa_past_memory_efficient_test_cases()) + def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap): + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + print("-------- MEMORY EFFICIENT (TOKEN GEN) --------") parity_check_gqa_past( config, - local=local, past_format=Formats.BNSH, - rtol=5e-3, - atol=5e-3, + rtol=1e-3, + atol=1e-3, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + softcap=softcap, + use_smooth_softmax=True, ) parity_check_gqa_past_no_buff( config, - local=local, past_format=Formats.BNSH, - rtol=5e-3, - atol=5e-3, + rtol=1e-3, + atol=1e-3, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + softcap=softcap, + use_smooth_softmax=False, ) @parameterized.expand(gqa_interactive_one_batch_memory_efficient_attention_test_cases()) def test_gqa_interactive_one_batch_memory_efficient_attention(self, _, config, rotary, rotary_interleaved, packed): - if not has_memory_efficient(): - return os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" print("-------- MEMORY EFFICIENT (INTERACTIVE) --------") diff --git a/onnxruntime/test/python/transformers/test_flash_attn_rocm.py b/onnxruntime/test/python/transformers/test_flash_attn_rocm.py index 99460722c2469..a5910c28c2975 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn_rocm.py +++ b/onnxruntime/test/python/transformers/test_flash_attn_rocm.py @@ -16,16 +16,16 @@ import onnxruntime -class TestGQA(unittest.TestCase): +@unittest.skipIf( + (not torch.cuda.is_available()) + or (platform.system() != "Linux") + or ("ROCMExecutionProvider" not in onnxruntime.get_available_providers()), + reason="ROCm is not available, skipping tests.", +) +class TestRocmGQA(unittest.TestCase): @parameterized.expand(gqa_no_past_flash_attention_test_cases()) def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): config.ep = "ROCMExecutionProvider" - if not torch.cuda.is_available(): - return - if platform.system() != "Linux": - return - if "CUDAExecutionProvider" in onnxruntime.get_available_providers(): - return print("------- FLASH ATTENTION (PROMPT CASE) --------") parity_check_gqa_prompt( @@ -52,12 +52,6 @@ def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_inte @parameterized.expand(gqa_past_flash_attention_test_cases()) def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): config.ep = "ROCMExecutionProvider" - if not torch.cuda.is_available(): - return - if platform.system() != "Linux": - return - if "CUDAExecutionProvider" in onnxruntime.get_available_providers(): - return print("------- FLASH ATTENTION (TOKEN GEN) -------") parity_check_gqa_past( diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index 08ec5de328b9d..77b4b326bf645 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -1900,7 +1900,7 @@ class TestGQA(unittest.TestCase): def test_gqa_no_past(self): torch.manual_seed(69) print("-------- TEST GQA NO PAST (PROMPT CASE) ---------") - batches = [1, 3] if pipeline_mode else [1, 3, 5] + batches = [3] if pipeline_mode else [1, 3, 5] seqs = ( [ (127, 127), @@ -1916,8 +1916,8 @@ def test_gqa_no_past(self): (8000, 8000), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] for b in batches: for sq, skv in seqs: for n, n2 in num_h: @@ -1954,9 +1954,9 @@ def test_gqa_no_past(self): def test_gqa_past(self): print("-------- TEST GQA PAST (TOKEN GEN) ---------") - batches = [1, 3] if pipeline_mode else [1, 3, 5] + batches = [1] if pipeline_mode else [1, 3, 5] seqs = ( - [(1, 128), (1, 1024), (1, 2048)] + [(1, 128)] if pipeline_mode else [ (1, 128), @@ -1972,8 +1972,8 @@ def test_gqa_past(self): # (128, 128), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 64, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) for b in batches: for s, s2 in seqs: @@ -2018,7 +2018,7 @@ def test_gqa_interactive_one_batch(self): print("-------- TEST GQA INTERACTIVE ---------") batches = [1] seqs = ( - [(2, 128), (128, 129), (32, 128), (256, 2048)] + [(256, 2048)] if pipeline_mode else [ (1, 128), @@ -2034,8 +2034,8 @@ def test_gqa_interactive_one_batch(self): # (128, 128), ] ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 64, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [32] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) for b in batches: for s, s2 in seqs: diff --git a/onnxruntime/test/python/transformers/test_optimizer.py b/onnxruntime/test/python/transformers/test_optimizer.py index c7db636a2f11f..058b1d2c9e0fa 100644 --- a/onnxruntime/test/python/transformers/test_optimizer.py +++ b/onnxruntime/test/python/transformers/test_optimizer.py @@ -5,30 +5,21 @@ # license information. # -------------------------------------------------------------------------- -# For live logging, use the command: pytest -o log_cli=true --log-cli-level=DEBUG +# For live logging, use the following command: +# pytest -o log_cli=true --log-cli-level=DEBUG test_optimizer.py -import shutil import unittest -import pytest -import torch from model_loader import get_fusion_test_model, get_test_data_path from onnx import TensorProto, load_model from parity_utilities import find_transformers_source -from transformers import is_tf_available if find_transformers_source(): - from benchmark_helper import ConfigModifier, OptimizerInfo, Precision from fusion_options import FusionOptions - from huggingface_models import MODELS - from onnx_exporter import export_onnx_model_from_pt, export_onnx_model_from_tf from onnx_model import OnnxModel from optimizer import optimize_model else: - from onnxruntime.transformers.benchmark_helper import ConfigModifier, OptimizerInfo, Precision from onnxruntime.transformers.fusion_options import FusionOptions - from onnxruntime.transformers.huggingface_models import MODELS - from onnxruntime.transformers.onnx_exporter import export_onnx_model_from_pt, export_onnx_model_from_tf from onnxruntime.transformers.onnx_model import OnnxModel from onnxruntime.transformers.optimizer import optimize_model @@ -66,70 +57,6 @@ def verify_node_count(self, onnx_model, expected_node_count, test_name): self.assertEqual(len(onnx_model.get_nodes_by_op_type(op_type)), count) - # test huggingface pytorch model - def _test_optimizer_on_huggingface_model( - self, - model_name, - expected_fusion_result_list, - inputs_count=1, - validate_model=True, - ): - # Remove cached model so that CI machine has enough space. Do not remove cache models in dev machine. - if not find_transformers_source(): - shutil.rmtree("./cache_models", ignore_errors=True) - shutil.rmtree("./onnx_models", ignore_errors=True) - - # expect fusion result list have the following keys - # EmbedLayerNormalization, Attention, Gelu, FastGelu, BiasGelu, LayerNormalization, SkipLayerNormalization - model_fusion_statistics = {} - - input_names = MODELS[model_name][0] - - config_modifier = ConfigModifier(None) - fusion_options = None - model_class = "AutoModel" - with torch.no_grad(): - _, is_valid_onnx_model, _, _ = export_onnx_model_from_pt( - model_name, - MODELS[model_name][1], # opset version - MODELS[model_name][2], # use_external_data_format - MODELS[model_name][3], # optimization model type - model_class, - config_modifier, - "./cache_models", - "./onnx_models", - input_names[:inputs_count], - False, - Precision.FLOAT32, - OptimizerInfo.BYSCRIPT, - True, - True, - True, - model_fusion_statistics, - fusion_options, - ) - - if validate_model: - self.assertEqual(is_valid_onnx_model, True) - - expected_node_count = { - "EmbedLayerNormalization": expected_fusion_result_list[0], - "Attention": expected_fusion_result_list[1], - "Gelu": expected_fusion_result_list[2], - "FastGelu": expected_fusion_result_list[3], - "BiasGelu": expected_fusion_result_list[4], - "LayerNormalization": expected_fusion_result_list[5], - "SkipLayerNormalization": expected_fusion_result_list[6], - } - - for value in model_fusion_statistics.values(): - actual_node_count = value - - for op_type, count in expected_node_count.items(): - if op_type not in actual_node_count or actual_node_count[op_type] != count: - print(f"expected: {expected_node_count} got {actual_node_count}") - self.assertTrue(False) - def test_gpt2_past(self): for enable_skip_layer_norm_fusion in [False, True]: input_path = _get_test_model_path("gpt2_past") @@ -227,176 +154,6 @@ def test_embed_layer_norm_fusion(self): } self.verify_node_count(model, expected_node_count, file) - @pytest.mark.slow - def test_huggingface_bert_fusion_1(self): - self._test_optimizer_on_huggingface_model("bert-base-uncased", [1, 12, 0, 0, 12, 0, 24], inputs_count=1) - - @pytest.mark.slow - def test_huggingface_bert_fusion_2(self): - self._test_optimizer_on_huggingface_model("bert-base-uncased", [1, 12, 0, 0, 12, 0, 24], inputs_count=2) - - @pytest.mark.slow - def test_huggingface_bert_fusion_3(self): - self._test_optimizer_on_huggingface_model("bert-base-uncased", [1, 12, 0, 0, 12, 0, 24], inputs_count=3) - - @pytest.mark.slow - def test_huggingface_openaigpt_fusion(self): - self._test_optimizer_on_huggingface_model("openai-gpt", [0, 12, 0, 12, 0, 0, 24]) - - @pytest.mark.slow - @unittest.skip("skip failed fusion test of gpt-2 on PyTorch 1.12 and transformers 4.18. TODO: fix it") - def test_huggingface_gpt2_fusion(self): - self._test_optimizer_on_huggingface_model("gpt2", [0, 12, 0, 12, 0, 25, 0]) - - @pytest.mark.slow - @unittest.skip("skip failed fusion test of xlm on PyTorch 1.12 and transformers 4.18. TODO: fix it") - def test_huggingface_xlm_fusion(self): - self._test_optimizer_on_huggingface_model("xlm-mlm-ende-1024", [0, 6, 0, 0, 6, 0, 13]) - - @pytest.mark.slow - def test_huggingface_roberta_fusion(self): - self._test_optimizer_on_huggingface_model("roberta-base", [0, 12, 0, 0, 12, 1, 24]) - - @pytest.mark.slow - def test_huggingface_distillbert_fusion(self): - self._test_optimizer_on_huggingface_model("distilbert-base-uncased", [1, 6, 0, 0, 6, 0, 12], inputs_count=1) - self._test_optimizer_on_huggingface_model("distilbert-base-uncased", [1, 6, 0, 0, 6, 0, 12], inputs_count=2) - - @pytest.mark.slow - @unittest.skip("skip failed fusion test of camembert on PyTorch 1.12 and transformers 4.18. TODO: fix it") - def test_huggingface_camembert_fusion(self): - self._test_optimizer_on_huggingface_model("camembert-base", [0, 12, 0, 0, 12, 1, 24], validate_model=False) - - @pytest.mark.slow - @unittest.skip("skip failed fusion test of albert on PyTorch 1.12 and transformers 4.18. TODO: fix it") - def test_huggingface_albert_fusion(self): - self._test_optimizer_on_huggingface_model("albert-base-v1", [0, 12, 0, 0, 12, 1, 24]) - - @pytest.mark.slow - @unittest.skip("skip fusion test of t5 since it is not implemented yet") - def test_huggingface_t5_fusion(self): - self._test_optimizer_on_huggingface_model("t5-small", [0, 0, 0, 0, 0, 0, 0]) - - @pytest.mark.slow - def test_huggingface_xlmroberta_fusion(self): - self._test_optimizer_on_huggingface_model("xlm-roberta-base", [0, 12, 0, 0, 12, 1, 24]) - - @pytest.mark.slow - @unittest.skip("skip failed fusion test of flaubert on PyTorch 1.12 and transformers 4.18. TODO: fix it") - def test_huggingface_flaubert_fusion(self): - self._test_optimizer_on_huggingface_model( - "flaubert/flaubert_base_cased", - [0, 12, 0, 0, 12, 0, 25], - validate_model=False, - ) - self._test_optimizer_on_huggingface_model( - "flaubert/flaubert_small_cased", - [0, 6, 0, 0, 6, 12, 1], - validate_model=False, - ) - - @pytest.mark.slow - @unittest.skip("skip failed fusion test of dialogpt on PyTorch 1.12 and transformers 4.18. TODO: fix it") - def test_huggingface_dialogpt_fusion(self): - self._test_optimizer_on_huggingface_model("microsoft/DialoGPT-small", [0, 12, 0, 12, 0, 25, 0]) - - @pytest.mark.slow - def test_huggingface_bart_fusion(self): - self._test_optimizer_on_huggingface_model("facebook/bart-base", [0, 0, 0, 0, 12, 2, 30]) - - @pytest.mark.slow - def test_huggingface_vit_fusion(self): - self._test_optimizer_on_huggingface_model("google/vit-base-patch16-224", [0, 11, 0, 0, 12, 1, 24]) - - -@unittest.skipUnless(is_tf_available(), "skip TestBertOptimizationTF since tensorflow is not available") -class TestTensorflowModelOptimization(unittest.TestCase): - def setUp(self): - try: - import tf2onnx # noqa: F401 - except ImportError: - self.skipTest("skip TestBertOptimizationTF since tf2onnx not installed") - - def _test_optimizer_on_tf_model(self, model_name, expected_fusion_result_list, inputs_count, validate_model=True): - # Remove cached model so that CI machine has enough space. Do not remove cache models in dev machine. - if not find_transformers_source(): - shutil.rmtree("./cache_models", ignore_errors=True) - shutil.rmtree("./onnx_models", ignore_errors=True) - - # expect fusion result list have the following keys - # EmbedLayerNormalization, Attention, Gelu, FastGelu, BiasGelu, LayerNormalization, SkipLayerNormalization - model_fusion_statistics = {} - print("testing mode ", model_name) - print("testing input number = ", inputs_count) - input_names = MODELS[model_name][0] - - config_modifier = ConfigModifier(None) - fusion_options = None - model_class = "AutoModel" - with torch.no_grad(): - _, is_valid_onnx_model, _, _ = export_onnx_model_from_tf( - model_name, - MODELS[model_name][1], # opset version - MODELS[model_name][2], # use_external_data_format - MODELS[model_name][3], # optimization model - model_class, - config_modifier, - "./cache_models", - "./onnx_models", - input_names[:inputs_count], - False, - Precision.FLOAT32, - True, - True, - True, - True, - model_fusion_statistics, - fusion_options, - ) - - onnx_model = next(iter(model_fusion_statistics.keys())) - fusion_result_list = list(model_fusion_statistics[onnx_model].values()) - - if validate_model: - self.assertEqual(is_valid_onnx_model, True) - self.assertEqual(fusion_result_list, expected_fusion_result_list) - - @pytest.mark.slow - def test_huggingface_bert_base_cased_from_tf2onnx_1(self): - self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 1) - - @pytest.mark.slow - def test_huggingface_bert_base_cased_from_tf2onnx_2(self): - self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 2) - - @pytest.mark.slow - def test_huggingface_bert_base_cased_from_tf2onnx_3(self): - self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 3) - - @pytest.mark.slow - def test_huggingface_distilgpt2_from_tf2onnx(self): - self._test_optimizer_on_tf_model("distilgpt2", [0, 0, 0, 0, 0, 12, 1], 1) - - @pytest.mark.slow - def test_huggingface_albert_from_tf2onnx(self): - self._test_optimizer_on_tf_model("albert-base-v1", [0, 0, 0, 0, 0, 0, 25], 1) - - @pytest.mark.slow - def test_huggingface_gpt2_from_tf2onnx(self): - self._test_optimizer_on_tf_model("gpt2", [0, 0, 0, 0, 0, 24, 1], 1, validate_model=False) - - @pytest.mark.slow - def test_huggingface_roberta_from_tf2onnx(self): - self._test_optimizer_on_tf_model("roberta-base", [0, 12, 0, 0, 0, 0, 25], 1, validate_model=False) - - @pytest.mark.slow - def test_huggingface_distilbert_from_tf2onnx(self): - self._test_optimizer_on_tf_model("distilbert-base-uncased", [0, 0, 0, 0, 0, 0, 13], 1, validate_model=False) - - @pytest.mark.slow - def test_huggingface_xlm_from_tf2onnx(self): - self._test_optimizer_on_tf_model("xlm-mlm-ende-1024", [0, 0, 0, 0, 0, 1, 12], 1, validate_model=False) - if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_optimizer_huggingface_bert.py b/onnxruntime/test/python/transformers/test_optimizer_huggingface_bert.py new file mode 100644 index 0000000000000..e4f883dc8b45c --- /dev/null +++ b/onnxruntime/test/python/transformers/test_optimizer_huggingface_bert.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +# For live logging, use the following command: +# pytest -o log_cli=true --log-cli-level=DEBUG test_optimizer_huggingface_bert.py + +import shutil +import unittest +from pathlib import Path + +import torch +from parity_utilities import find_transformers_source +from transformers.utils import default_cache_path + +if find_transformers_source(): + from benchmark_helper import ConfigModifier, OptimizerInfo, Precision + from compare_bert_results import run_test as bert_parity_test + from onnx_exporter import export_onnx_model_from_pt +else: + from onnxruntime.transformers.benchmark_helper import ConfigModifier, OptimizerInfo, Precision + from onnxruntime.transformers.compare_bert_results import run_test as bert_parity_test + from onnxruntime.transformers.onnx_exporter import export_onnx_model_from_pt + + +class TestHuggingfaceBertModelOptimization(unittest.TestCase): + def run_optimizer_on_model( + self, + model_name, + expected_fusion_result_list, + inputs_count=1, + validate_model=True, + opset_version=16, + use_external_data_format=False, + model_type="bert", + ): + onnx_dir = Path(".") / "onnx_models" / model_name + shutil.rmtree(onnx_dir, ignore_errors=True) + + Path(onnx_dir).mkdir(parents=True, exist_ok=True) + + model_fusion_statistics = {} + + input_names = ["input_ids", "attention_mask", "token_type_ids"] + + config_modifier = ConfigModifier(None) + fusion_options = None + model_class = "AutoModel" + with torch.no_grad(): + optimized_model_path, is_valid_onnx_model, _, _ = export_onnx_model_from_pt( + model_name=model_name, + opset_version=opset_version, + use_external_data_format=use_external_data_format, + model_type=model_type, + model_class=model_class, + config_modifier=config_modifier, + cache_dir=default_cache_path, + onnx_dir=str(onnx_dir), + input_names=input_names[:inputs_count], + use_gpu=False, + precision=Precision.FLOAT32, + optimizer_info=OptimizerInfo.BYSCRIPT, + validate_onnx=True, + use_raw_attention_mask=True, + overwrite=True, + model_fusion_statistics=model_fusion_statistics, + fusion_options=fusion_options, + ) + + if validate_model: + self.assertEqual(is_valid_onnx_model, True) + + expected_node_count = { + "EmbedLayerNormalization": expected_fusion_result_list[0], + "Attention": expected_fusion_result_list[1], + "Gelu": expected_fusion_result_list[2], + "FastGelu": expected_fusion_result_list[3], + "BiasGelu": expected_fusion_result_list[4], + "LayerNormalization": expected_fusion_result_list[5], + "SkipLayerNormalization": expected_fusion_result_list[6], + } + + node_count = None + for value in model_fusion_statistics.values(): + node_count = value + self.assertIsNotNone(node_count) + + actual_node_count = {} + for op_type in expected_node_count: + actual_node_count[op_type] = node_count.get(op_type, 0) + + expected = ", ".join(f"{key}: {value}" for key, value in sorted(expected_node_count.items())) + actual = ", ".join(f"{key}: {value}" for key, value in sorted(actual_node_count.items())) + self.assertEqual(expected, actual) + + suffix = "_fp32_cpu.onnx" + assert optimized_model_path.endswith(suffix) + baseline_model_path = optimized_model_path[: -len(suffix)] + ".onnx" + for batch_size in [1, 2]: + for sequence_length in [1, 8]: + max_abs_diff, case_passed = bert_parity_test( + baseline_model_path, + optimized_model_path, + output_dir=None, + batch_size=batch_size, + sequence_length=sequence_length, + use_gpu=False, + test_cases=1, + seed=123, + verbose=False, + rtol=1e-4, + atol=1e-4, + input_ids_name=input_names[0], + segment_ids_name=input_names[2] if inputs_count > 2 else None, + input_mask_name=input_names[1] if inputs_count > 1 else None, + mask_type=2, + dictionary_size=1024, + ) + self.assertTrue( + case_passed, f"bert parity test failed: {batch_size=} {sequence_length=} {max_abs_diff=}" + ) + + def test_bert(self): + model_name = "hf-internal-testing/tiny-random-bert" + self.run_optimizer_on_model(model_name, [1, 5, 0, 0, 5, 0, 10], inputs_count=1) + self.run_optimizer_on_model(model_name, [1, 5, 0, 0, 5, 0, 10], inputs_count=2) + self.run_optimizer_on_model(model_name, [1, 5, 0, 0, 5, 0, 10], inputs_count=3) + + def test_roberta(self): + model_name = "hf-internal-testing/tiny-random-roberta" + # TODO: EmbedLayerNormalization fusion. + self.run_optimizer_on_model(model_name, [0, 5, 0, 0, 5, 1, 10], inputs_count=1) + self.run_optimizer_on_model(model_name, [0, 5, 0, 0, 5, 1, 10], inputs_count=2) + + def test_distillbert(self): + model_name = "hf-internal-testing/tiny-random-distilbert" + self.run_optimizer_on_model(model_name, [1, 5, 0, 0, 5, 0, 10], inputs_count=1) + self.run_optimizer_on_model(model_name, [1, 5, 0, 0, 5, 0, 10], inputs_count=2) + + def test_xlm_roberta(self): + model_name = "hf-internal-testing/tiny-xlm-roberta" + # TODO: EmbedLayerNormalization fusion. + self.run_optimizer_on_model(model_name, [0, 2, 0, 0, 2, 1, 4], inputs_count=1) + self.run_optimizer_on_model(model_name, [0, 2, 0, 0, 2, 1, 4], inputs_count=2) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_parity_moe.py b/onnxruntime/test/python/transformers/test_parity_moe.py index 1e7940e38335f..baaaeaa766db9 100644 --- a/onnxruntime/test/python/transformers/test_parity_moe.py +++ b/onnxruntime/test/python/transformers/test_parity_moe.py @@ -651,7 +651,6 @@ def parity_check(self): torch_output = self.forward(hidden_state) ort_output = self.ort_forward(hidden_state) if ort_output is not None: - assert torch.allclose(torch_output, ort_output.to(torch.float32), rtol=THRESHOLD, atol=THRESHOLD) print( "name:", self.__class__.__name__, @@ -661,8 +660,8 @@ def parity_check(self): self.sequence_length, " max_diff:", (torch_output - ort_output).abs().max(), - " parity: OK", ) + torch.testing.assert_close(ort_output.to(torch.float32), torch_output, rtol=THRESHOLD, atol=THRESHOLD) def benchmark_ort(self): hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim) @@ -996,6 +995,13 @@ def small_test_cases(): yield batch_size, sequence_length +def phi3_test_cases(): + # TODO: phi3 moe failed in long sequence lengths (max diff 0.22 > threshold 0.01), need investigation. + for batch_size in [1, 4, 16]: + for sequence_length in [128]: + yield batch_size, sequence_length + + class TestSwitchMoE(unittest.TestCase): @parameterized.expand(small_test_cases()) def test_switch_moe_parity(self, batch_size, sequence_length): @@ -1023,7 +1029,7 @@ def test_mixtral_moe_parity(self, batch_size, sequence_length): class TestPhiMoE(unittest.TestCase): - @parameterized.expand(small_test_cases()) + @parameterized.expand(phi3_test_cases()) def test_phi3_moe_parity(self, batch_size, sequence_length): config = PhiMoEConfig(hidden_size=256, intermediate_size=1024) phi3_moe = PhiMoESparseMoeBlock(config, batch_size, sequence_length) diff --git a/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc b/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc index 5b3720992c542..24c343c7b9541 100644 --- a/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc +++ b/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc @@ -50,6 +50,7 @@ namespace qnnctxgen { "\t [enable_htp_weight_sharing]: Allows common weights across graphs to be shared and stored in a single context binary. Defaults to '1' (enabled).\n" "\t [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" "\t Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). \n" + "\t [enable_htp_spill_fill_buffer]: Enable HTP spill file buffer, used while generating QNN context binary." "\t [Example] -i \"vtcm_mb|8 htp_arch|73\" \n" "\n" "\t-h: help\n"); @@ -146,7 +147,7 @@ static bool ParseSessionConfigs(const std::string& configs_string, ORT_THROW("Wrong value for htp_graph_finalization_optimization_mode. select from: " + str); } } else if (key == "enable_htp_fp16_precision" || key == "enable_htp_weight_sharing" || - key == "offload_graph_io_quantization") { + key == "offload_graph_io_quantization" || key == "enable_htp_spill_fill_buffer") { std::unordered_set supported_options = {"0", "1"}; if (supported_options.find(value) == supported_options.end()) { std::ostringstream str_stream; @@ -158,7 +159,7 @@ static bool ParseSessionConfigs(const std::string& configs_string, } else { ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'vtcm_mb', 'htp_performance_mode', 'htp_graph_finalization_optimization_mode', 'soc_model', 'htp_arch', 'enable_htp_fp16_precision', 'enable_htp_weight_sharing', - 'offload_graph_io_quantization'])"); + 'offload_graph_io_quantization', 'enable_htp_spill_fill_buffer'])"); } test_config.run_config.qnn_options[key] = value; diff --git a/onnxruntime/test/qnn_ctx_gen/main.cc b/onnxruntime/test/qnn_ctx_gen/main.cc index d568d5e78688a..3be0bd253c8a4 100644 --- a/onnxruntime/test/qnn_ctx_gen/main.cc +++ b/onnxruntime/test/qnn_ctx_gen/main.cc @@ -33,8 +33,11 @@ static void CheckStatus(const Status& status) { // from the last context cache Onnx model, find the EPContext node with main_context=1, // and get the QNN context binary file name, this context binary contains all graphs from all Onnx models +// get the max spill fill buffer size static void GetLastContextBinaryFileName(const std::basic_string last_onnx_ctx_file, - std::string& last_ctx_bin_file) { + std::string& last_ctx_bin_file, + int64_t& max_size) { + max_size = 0; std::shared_ptr ctx_model; CheckStatus(Model::Load(ToPathString(last_onnx_ctx_file), ctx_model, nullptr, (*((OrtEnv*)*ort_env.get())->GetEnvironment().GetLoggingManager()).DefaultLogger())); @@ -43,6 +46,7 @@ static void GetLastContextBinaryFileName(const std::basic_string last if (node.OpType() == "EPContext") { NodeAttrHelper node_helper(node); int64_t is_main_context = node_helper.Get("main_context", static_cast(0)); + max_size = node_helper.Get("max_size", static_cast(0)); if (1 == is_main_context) { last_ctx_bin_file = node_helper.Get("ep_cache_context", ""); return; @@ -55,7 +59,8 @@ static void GetLastContextBinaryFileName(const std::basic_string last // the last QNN context binary file // Remove not used QNN context binary file, only keep the last one which contains all graphs static void UpdateEpContextModel(const std::vector>& ep_ctx_files, - const std::string& last_qnn_ctx_binary_file_name) { + const std::string& last_qnn_ctx_binary_file_name, + int64_t max_size) { for (auto ep_ctx_file : ep_ctx_files) { std::shared_ptr ctx_model; auto path_str = ToPathString(ep_ctx_file); @@ -75,6 +80,8 @@ static void UpdateEpContextModel(const std::vector> std::remove(file_path.string().c_str()); node.ClearAttribute("ep_cache_context"); node.AddAttribute("ep_cache_context", last_qnn_ctx_binary_file_name); + node.ClearAttribute("max_size"); + node.AddAttribute("max_size", max_size); } } } @@ -181,7 +188,8 @@ int real_main(int argc, char* argv[]) { // Get the last context binary file name std::string last_qnn_ctx_binary_file_name; - GetLastContextBinaryFileName(ep_ctx_files.back(), last_qnn_ctx_binary_file_name); + int64_t max_size = 0; + GetLastContextBinaryFileName(ep_ctx_files.back(), last_qnn_ctx_binary_file_name, max_size); std::cout << "The last context binary file: " << last_qnn_ctx_binary_file_name << std::endl; if (last_qnn_ctx_binary_file_name.empty()) { throw Ort::Exception("Can't find QNN context binary file from the Onnx model.", OrtErrorCode::ORT_FAIL); @@ -191,7 +199,7 @@ int real_main(int argc, char* argv[]) { // Update generated context cache Onnx model to make the main EPContext node point to // the last QNN context binary file // Remove not used QNN context binary file, only keep the last one which contains all graphs - UpdateEpContextModel(ep_ctx_files, last_qnn_ctx_binary_file_name); + UpdateEpContextModel(ep_ctx_files, last_qnn_ctx_binary_file_name, max_size); } ORT_CATCH(const Ort::Exception& e) { fprintf(stderr, "Failed to generate context cache file: %s \n", e.what()); diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index e19362e0ec32d..0be1c0b1965ac 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -4600,86 +4600,3 @@ TEST(CApiTest, OrtCustomOp_GetInPlace) { ASSERT_EQ(len, static_cast(2)); mock_gqa.ReleaseAliasMap(input_index, output_index); } - -TEST(CApiTest, Serialize_PrePack_Initializers) { - std::string model_name = "model_with_matmul_nbits"; - - const std::string test_model = "testdata/prepack/" + model_name + ".onnx"; - const std::string optimized_model = "testdata/prepack/" + model_name + "_opt.onnx"; - std::string external_initializer_file_name = model_name + "_opt.onnx.data"; - - // Generate optimized with prepacked weights serialized - Ort::SessionOptions session_options_opt; - session_options_opt.AddConfigEntry(kOrtSessionOptionsOptimizedModelExternalInitializersFileName, external_initializer_file_name.c_str()); - session_options_opt.AddConfigEntry(kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes, "0"); - session_options_opt.AddConfigEntry(kOrtSessionOptionsSavePrePackedConstantInitializers, "1"); - -#if defined(_WIN32) || defined(_WIN64) - std::wstring test_model_wide = onnxruntime::ToWideString(test_model); - session_options_opt.SetOptimizedModelFilePath(onnxruntime::ToWideString(optimized_model).c_str()); - Ort::Session session_opt_model(*ort_env, test_model_wide.c_str(), session_options_opt); -#else - session_options_opt.SetOptimizedModelFilePath(optimized_model.c_str()); - Ort::Session session_opt_model(*ort_env, test_model.c_str(), session_options_opt); -#endif - - // Do inference with original model and optimized model and check output is identical - // set inputs and session options - Ort::SessionOptions session_options; - const char* input_names[] = {"A"}; - const char* const output_names[] = {"Y"}; - Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); - - std::vector ort_inputs; - std::vector input_0_data = {1.3f}; - std::vector input_0_dims = {1, 1}; - ort_inputs.emplace_back( - Ort::Value::CreateTensor(info, const_cast(input_0_data.data()), - input_0_data.size(), input_0_dims.data(), input_0_dims.size())); - - // run inference with original model - // Convert std::string to std::wstring -#if defined(_WIN32) || defined(_WIN64) - Ort::Session session(*ort_env, test_model_wide.c_str(), session_options); -#else - Ort::Session session(*ort_env, test_model.c_str(), session_options); -#endif - auto ort_outputs = session.Run(Ort::RunOptions{}, input_names, ort_inputs.data(), ort_inputs.size(), - output_names, 1); - - // run inference with optimized model which load serialized prepack initializer -#if defined(_WIN32) || defined(_WIN64) - std::wstring optimized_model_wide = onnxruntime::ToWideString(optimized_model); - Ort::Session session_opt(*ort_env, optimized_model_wide.c_str(), session_options); -#else - Ort::Session session_opt(*ort_env, optimized_model.c_str(), session_options); -#endif - auto ort_outputs_opt = session_opt.Run(Ort::RunOptions{}, input_names, ort_inputs.data(), ort_inputs.size(), - output_names, 1); - - // check output of original model and optimized model are equal - ASSERT_EQ(ort_outputs.size(), ort_outputs_opt.size()); - - for (size_t i = 0; i < ort_outputs.size(); ++i) { - const auto& sequences = ort_outputs[i]; - ASSERT_TRUE(sequences.IsTensor()); - - const auto& sequences_opt = ort_outputs_opt[i]; - ASSERT_TRUE(sequences_opt.IsTensor()); - - auto result_ts = sequences.GetTensorTypeAndShapeInfo(); - auto result_ts_opt = sequences_opt.GetTensorTypeAndShapeInfo(); - - ASSERT_EQ(result_ts.GetElementType(), result_ts_opt.GetElementType()); - - ASSERT_EQ(result_ts.GetShape(), result_ts_opt.GetShape()); - - const auto* result_vals = sequences.GetTensorData(); - auto result_span = gsl::make_span(result_vals, ort_outputs.size()); - - const auto* result_vals_opt = sequences_opt.GetTensorData(); - auto result_span_opt = gsl::make_span(result_vals_opt, ort_outputs_opt.size()); - - ASSERT_TRUE(std::equal(result_span_opt.begin(), result_span_opt.end(), result_span.begin(), result_span.end())); - } -} \ No newline at end of file diff --git a/onnxruntime/test/shared_lib/test_nontensor_types.cc b/onnxruntime/test/shared_lib/test_nontensor_types.cc index 8171a6eecc91d..ba16bd6c9888f 100644 --- a/onnxruntime/test/shared_lib/test_nontensor_types.cc +++ b/onnxruntime/test/shared_lib/test_nontensor_types.cc @@ -987,6 +987,32 @@ TEST(CApiTest, SparseTensorFillSparseTensorFormatAPI) { } } +TEST(CApi, TestResize) { + std::vector values; + values.resize(10); + + std::vector sts; + sts.resize(5); + + std::vector domains; + domains.resize(5); + + std::vector type_and_shape; + type_and_shape.resize(5); + + std::vector seq_type_info; + seq_type_info.resize(5); + + std::vector map_type_info; + map_type_info.resize(5); + + std::vector type_info; + type_info.resize(5); + + std::vector op_attr; + op_attr.resize(5); +} + TEST(CApiTest, SparseTensorFillSparseFormatStringsAPI) { auto allocator = Ort::AllocatorWithDefaultOptions(); Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); diff --git a/onnxruntime/test/testdata/dummy_t5.onnx b/onnxruntime/test/testdata/dummy_t5.onnx new file mode 100644 index 0000000000000..3a3bbf4767523 Binary files /dev/null and b/onnxruntime/test/testdata/dummy_t5.onnx differ diff --git a/onnxruntime/test/testdata/dummy_t5_with_outer_scope_initializers.onnx b/onnxruntime/test/testdata/dummy_t5_with_outer_scope_initializers.onnx new file mode 100644 index 0000000000000..4b36cc9b6eca0 Binary files /dev/null and b/onnxruntime/test/testdata/dummy_t5_with_outer_scope_initializers.onnx differ diff --git a/onnxruntime/test/testdata/dummy_t5_with_sequence_input_ids.onnx b/onnxruntime/test/testdata/dummy_t5_with_sequence_input_ids.onnx new file mode 100644 index 0000000000000..5a5c302914890 Binary files /dev/null and b/onnxruntime/test/testdata/dummy_t5_with_sequence_input_ids.onnx differ diff --git a/onnxruntime/test/testdata/model_with_external_initializers.onnx b/onnxruntime/test/testdata/model_with_external_initializers.onnx index 3538f01b53c18..f815b4000f98f 100644 --- a/onnxruntime/test/testdata/model_with_external_initializers.onnx +++ b/onnxruntime/test/testdata/model_with_external_initializers.onnx @@ -1,8 +1,7 @@ - - onnx-example:� -, + onnx-example:� +& X -PadsYpad0"Pad* +PadsY"Pad* mode"constant� test-model*"BPadsj locationPads.binpZ @@ -17,4 +16,4 @@ test-model*"BPadsj Y   -B \ No newline at end of file +B \ No newline at end of file diff --git a/onnxruntime/test/testdata/model_with_external_initializers.py b/onnxruntime/test/testdata/model_with_external_initializers.py index dc64d4a41424a..8d2589a9e6564 100644 --- a/onnxruntime/test/testdata/model_with_external_initializers.py +++ b/onnxruntime/test/testdata/model_with_external_initializers.py @@ -35,10 +35,9 @@ def GenerateModel(model_name, external_data_name): # noqa: N802 # Create a node (NodeProto) node_def = helper.make_node( - "Pad", # op type + "Pad", # node name ["X", external_data_name], # inputs ["Y"], # outputs - "pad0", # node name mode="constant", # Attributes ) diff --git a/onnxruntime/test/testdata/model_with_orig_ext_data.onnx b/onnxruntime/test/testdata/model_with_orig_ext_data.onnx index 47d0c68235099..6f9cce0bc5b4f 100644 --- a/onnxruntime/test/testdata/model_with_orig_ext_data.onnx +++ b/onnxruntime/test/testdata/model_with_orig_ext_data.onnx @@ -1,8 +1,7 @@ - - onnx-example:� -@ +  onnx-example:� +: X -model_with_orig_ext_dataYpad0"Pad* +model_with_orig_ext_dataY"Pad* mode"constant� test-model*JBmodel_with_orig_ext_dataj( locationmodel_with_orig_ext_data.binpZ @@ -17,4 +16,4 @@ test-model*JBmodel_with_orig_ext_dataj( Y   -B \ No newline at end of file +B \ No newline at end of file diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 4b14d50127aa9..7ecaab6fedb02 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -306,6 +306,11 @@ "^test_qlinearmatmul_3D_int8_float32_cuda", "^test_qlinearmatmul_3D_uint8_float16_cuda", "^test_qlinearmatmul_3D_uint8_float32_cuda", + // Tests that failed on CUDA 12.2. + "^test_Conv3d_dilated_cuda", + "^test_Conv3d_dilated_strided_cuda", + "^test_Conv3d_stride_cuda", + "^test_Conv3d_stride_padding_cuda", // Size(21) from ONNX 1.16.0 is not implemented in cuda. "^test_size_cuda", "^test_size_example_cuda", @@ -672,6 +677,30 @@ "^test_nonmaxsuppression_flipped_coordinates_cpu", "^test_nonmaxsuppression_center_point_box_format_cpu" ], + "current_failing_tests_WEBGPU": [ + "^test_layer_normalization_2d_axis0_cpu", + "^test_layer_normalization_2d_axis1_cpu", + "^test_layer_normalization_2d_axis_negative_1_cpu", + "^test_layer_normalization_2d_axis_negative_2_cpu", + "^test_layer_normalization_3d_axis0_epsilon_cpu", + "^test_layer_normalization_3d_axis1_epsilon_cpu", + "^test_layer_normalization_3d_axis2_epsilon_cpu", + "^test_layer_normalization_3d_axis_negative_1_epsilon_cpu", + "^test_layer_normalization_3d_axis_negative_2_epsilon_cpu", + "^test_layer_normalization_3d_axis_negative_3_epsilon_cpu", + "^test_layer_normalization_4d_axis0_cpu", + "^test_layer_normalization_4d_axis1_cpu", + "^test_layer_normalization_4d_axis2_cpu", + "^test_layer_normalization_4d_axis3_cpu", + "^test_layer_normalization_4d_axis_negative_1_cpu", + "^test_layer_normalization_4d_axis_negative_2_cpu", + "^test_layer_normalization_4d_axis_negative_3_cpu", + "^test_layer_normalization_4d_axis_negative_4_cpu", + "^test_layer_normalization_default_axis_cpu", + "^test_gelu_tanh_1_expanded_cpu", + "^test_gelu_tanh_2_expanded_cpu", + "^test_dynamicquantizelinear_expanded_cpu" + ], "current_failing_tests_pure_DML": [ "^test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_cpu", "^test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_expanded_cpu", @@ -721,6 +750,13 @@ "^test_reduce_log_sum_empty_set_cpu", "^test_reduce_log_sum_exp_empty_set_cpu", "^test_reduce_prod_empty_set_cpu", + // Bug: DML EP some how executes these CUDA tests and failed + // TODO: Remove these tests when DML EP is fixed + "^test_convtranspose_autopad_same_cuda", + "^test_asin_example_cuda", + "^test_dynamicquantizelinear_cuda", + "^test_dynamicquantizelinear_expanded_cuda", + "^test_reduce_min_empty_set_cuda", //Bug: DML EP does not execute operators with an empty input tensor //TODO: Resolve as a graph implementation that returns a constant inf tensor with appropriate strides "^test_reduce_min_empty_set_cpu" diff --git a/onnxruntime/test/testdata/prepack/MatMul.Weight.bin b/onnxruntime/test/testdata/prepack/MatMul.Weight.bin deleted file mode 100644 index 0f8a571589c10..0000000000000 Binary files a/onnxruntime/test/testdata/prepack/MatMul.Weight.bin and /dev/null differ diff --git a/onnxruntime/test/testdata/prepack/model_with_external_initializers_and_prepack_kernel.py b/onnxruntime/test/testdata/prepack/model_with_external_initializers_and_prepack_kernel.py deleted file mode 100644 index 86af461edc2c4..0000000000000 --- a/onnxruntime/test/testdata/prepack/model_with_external_initializers_and_prepack_kernel.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import os - -import numpy as np -import onnx -from onnx import TensorProto, helper -from onnx.external_data_helper import set_external_data -from onnx.numpy_helper import from_array - -M = 1 -K = 1 -N = 1 -q_cols = 1 -q_rows = 1 -q_scale_size = 1 - - -def create_external_data_tensor(value, tensor_name, data_type): - tensor = from_array(np.array(value)) - tensor.name = tensor_name - tensor_filename = f"{tensor_name}.bin" - set_external_data(tensor, location=tensor_filename) - - with open(os.path.join(tensor_filename), "wb") as data_file: - data_file.write(tensor.raw_data) - tensor.ClearField("raw_data") - tensor.data_location = onnx.TensorProto.EXTERNAL - tensor.data_type = data_type - return tensor - - -def create_internal_data_tensor(value, tensor_name, data_type): - tensor = helper.make_tensor(name=tensor_name, data_type=data_type, dims=value.shape, vals=value.flatten().tolist()) - print(tensor) - tensor.data_location = onnx.TensorProto.DEFAULT - return tensor - - -def GenerateMatmulNBitsModel(model_name, external_data_name): # noqa: N802 - A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [M, K]) # noqa: N806 - Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [M, N]) # noqa: N806 - - # Create a node (NodeProto) - node_def = helper.make_node( - op_type="MatMulNBits", # op type - inputs=["A", external_data_name, "scales"], # inputs - outputs=["Y"], # outputs - name="MatMul_0", # node name - domain="com.microsoft", # Custom domain for this operator - accuracy_level=4, # Attributes - bits=4, # Attributes - block_size=32, # Attributes - K=K, # Attributes - N=N, # Attributes - ) - - # Create the graph (GraphProto) - graph_def = helper.make_graph( - [node_def], - "test-model-matmul4bits", - [A], - [Y], - [ - create_external_data_tensor([[171]], external_data_name, TensorProto.UINT8), - create_internal_data_tensor(np.array([1.5], dtype=np.float32), "scales", TensorProto.FLOAT), - ], - ) - - # Create the model - model_def = helper.make_model( - graph_def, - producer_name="onnx-example", - opset_imports=[helper.make_operatorsetid("", 14), helper.make_operatorsetid("com.microsoft", 1)], - ) - - print(f"The ir_version in model: {model_def.ir_version}\n") - print(f"The producer_name in model: {model_def.producer_name}\n") - print(f"The graph in model:\n{model_def.graph}") - onnx.checker.check_model(model_def) - print("The model is checked!") - with open(model_name, "wb") as model_file: - model_file.write(model_def.SerializeToString()) - - -if __name__ == "__main__": - GenerateMatmulNBitsModel("model_with_matmul_nbits.onnx", "MatMul.Weight") diff --git a/onnxruntime/test/testdata/prepack/model_with_matmul_nbits.onnx b/onnxruntime/test/testdata/prepack/model_with_matmul_nbits.onnx deleted file mode 100644 index 0e06a75a5a7e8..0000000000000 Binary files a/onnxruntime/test/testdata/prepack/model_with_matmul_nbits.onnx and /dev/null differ diff --git a/onnxruntime/test/testdata/relu_with_optional.onnx b/onnxruntime/test/testdata/relu_with_optional.onnx new file mode 100644 index 0000000000000..b52c6927527bd Binary files /dev/null and b/onnxruntime/test/testdata/relu_with_optional.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly-dont-fuse.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly-dont-fuse.onnx new file mode 100644 index 0000000000000..8ca8282572db8 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly-dont-fuse.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float_large_tensor.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float_large_tensor.onnx new file mode 100644 index 0000000000000..2521a89b7bb56 --- /dev/null +++ b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float_large_tensor.onnx @@ -0,0 +1,41 @@ + :� +R +inputA a_quantizeda_scalea_zpDynamicQuantizeLinear"DynamicQuantizeLinear +Y + a_quantized +inputB +a_zp +inputBZPmatmulinteger_output MatMulInteger" MatMulInteger +- +a_scale + inputBScalemul_1 mul_right"Mul +: +matmulinteger_output cast_outputcast"Cast* +to� +- +mul_1 + cast_outputoutput +mul_bottom"Mul+matmul_integer_to_float_large_tensor_fusionZ" +inputA + + + +� + +� +Z +inputB + + +� + +� +Z +inputBZP + + +Z + inputBScale + + +B \ No newline at end of file diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float_large_tensor.py b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float_large_tensor.py new file mode 100644 index 0000000000000..543517cc015ef --- /dev/null +++ b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float_large_tensor.py @@ -0,0 +1,49 @@ +from enum import Enum # noqa: F401 + +import onnx +from onnx import TensorProto, helper + + +def GenerateModel(model_name): # noqa: N802 + inputs = [] + outputs = [] + initializers = [] + nodes = [] + + inputs.append(helper.make_tensor_value_info("inputA", TensorProto.FLOAT, [16, 32, 1280, 1280])) + inputs.append(helper.make_tensor_value_info("inputB", TensorProto.INT8, [1280, 1280])) + inputs.append(helper.make_tensor_value_info("inputBZP", TensorProto.INT8, [1])) + inputs.append(helper.make_tensor_value_info("inputBScale", TensorProto.FLOAT, [1])) + + nodes = [ # construct graph + helper.make_node( + "DynamicQuantizeLinear", + ["inputA"], + ["a_quantized", "a_scale", "a_zp"], + "DynamicQuantizeLinear", + ), + helper.make_node( + "MatMulInteger", + ["a_quantized", "inputB", "a_zp", "inputBZP"], + ["matmulinteger_output"], + "MatMulInteger", + ), + helper.make_node("Mul", ["a_scale", "inputBScale"], ["mul_1"], "mul_right"), + helper.make_node("Cast", ["matmulinteger_output"], ["cast_output"], "cast", to=1), + helper.make_node("Mul", ["mul_1", "cast_output"], ["output"], "mul_bottom"), + ] + + graph = helper.make_graph( + nodes, + "matmul_integer_to_float_large_tensor_fusion", # name + inputs, + outputs, + initializers, + ) + + model = helper.make_model(graph) + onnx.save(model, model_name) + + +if __name__ == "__main__": + GenerateModel("matmul_integer_to_float_large_tensor.onnx") diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index d57a22f024d5f..59926bbcd1c6f 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -122,6 +122,12 @@ std::unique_ptr DefaultOpenVINOExecutionProvider() { std::unique_ptr DefaultCudaExecutionProvider() { #ifdef USE_CUDA +#ifdef USE_DML + const std::string no_cuda_ep_test = Env::Default().GetEnvironmentVar("NO_CUDA_TEST"); + if (no_cuda_ep_test == "1") { + return nullptr; + } +#endif OrtCUDAProviderOptionsV2 provider_options{}; provider_options.do_copy_in_default_stream = true; provider_options.use_tf32 = false; @@ -134,6 +140,12 @@ std::unique_ptr DefaultCudaExecutionProvider() { #ifdef ENABLE_CUDA_NHWC_OPS std::unique_ptr DefaultCudaNHWCExecutionProvider() { #if defined(USE_CUDA) +#ifdef USE_DML + const std::string no_cuda_ep_test = Env::Default().GetEnvironmentVar("NO_CUDA_TEST"); + if (no_cuda_ep_test == "1") { + return nullptr; + } +#endif OrtCUDAProviderOptionsV2 provider_options{}; provider_options.do_copy_in_default_stream = true; provider_options.use_tf32 = false; @@ -176,14 +188,6 @@ std::unique_ptr DnnlExecutionProviderWithOptions(const OrtDn return nullptr; } -// std::unique_ptr DefaultTvmExecutionProvider() { -// #ifdef USE_TVM -// return TVMProviderFactoryCreator::Create("")->CreateProvider(); -// #else -// return nullptr; -// #endif -// } - std::unique_ptr DefaultNnapiExecutionProvider() { // The NNAPI EP uses a stub implementation on non-Android platforms so cannot be used to execute a model. // Manually append an NNAPI EP instance to the session to unit test the GetCapability and Compile implementation. @@ -247,14 +251,14 @@ std::unique_ptr DefaultCoreMLExecutionProvider(bool use_mlpr // The test will create a model but execution of it will obviously fail. #if defined(USE_COREML) && defined(__APPLE__) // We want to run UT on CPU only to get output value without losing precision - uint32_t coreml_flags = 0; - coreml_flags |= COREML_FLAG_USE_CPU_ONLY; + auto option = ProviderOptions(); + option[kCoremlProviderOption_MLComputeUnits] = "CPUOnly"; if (use_mlprogram) { - coreml_flags |= COREML_FLAG_CREATE_MLPROGRAM; + option[kCoremlProviderOption_ModelFormat] = "MLProgram"; } - return CoreMLProviderFactoryCreator::Create(coreml_flags)->CreateProvider(); + return CoreMLProviderFactoryCreator::Create(option)->CreateProvider(); #else ORT_UNUSED_PARAMETER(use_mlprogram); return nullptr; @@ -307,6 +311,10 @@ std::unique_ptr DefaultXnnpackExecutionProvider() { std::unique_ptr DefaultWebGpuExecutionProvider() { #ifdef USE_WEBGPU ConfigOptions config_options{}; + // Disable storage buffer cache + ORT_ENFORCE(config_options.AddConfigEntry(webgpu::options::kStorageBufferCacheMode, + webgpu::options::kBufferCacheMode_Disabled) + .IsOK()); return WebGpuProviderFactoryCreator::Create(config_options)->CreateProvider(); #else return nullptr; @@ -324,6 +332,12 @@ std::unique_ptr DefaultCannExecutionProvider() { std::unique_ptr DefaultDmlExecutionProvider() { #ifdef USE_DML +#ifdef USE_CUDA + const std::string no_dml_ep_test = Env::Default().GetEnvironmentVar("NO_DML_TEST"); + if (no_dml_ep_test == "1") { + return nullptr; + } +#endif ConfigOptions config_options{}; if (auto factory = DMLProviderFactoryCreator::CreateFromDeviceOptions(config_options, nullptr, false, false)) { return factory->CreateProvider(); diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h index ed95bf67f1ffb..9b44150d972db 100644 --- a/onnxruntime/test/util/include/default_providers.h +++ b/onnxruntime/test/util/include/default_providers.h @@ -20,7 +20,6 @@ std::shared_ptr CreateExecutionProviderFactory_Dnnl(c std::shared_ptr CreateExecutionProviderFactory_MIGraphX(const OrtMIGraphXProviderOptions* params); std::shared_ptr CreateExecutionProviderFactory_Nnapi( uint32_t flags, const optional& partitioning_stop_ops_list); -// std::shared_ptr CreateExecutionProviderFactory_Tvm(const char*); std::shared_ptr CreateExecutionProviderFactory_VSINPU(); std::shared_ptr CreateExecutionProviderFactory_Rknpu(); std::shared_ptr CreateExecutionProviderFactory_Rocm(const OrtROCMProviderOptions* provider_options); diff --git a/onnxruntime/test/util/include/providers.h b/onnxruntime/test/util/include/providers.h index a73b237ae10df..01be1a444646b 100644 --- a/onnxruntime/test/util/include/providers.h +++ b/onnxruntime/test/util/include/providers.h @@ -7,9 +7,6 @@ #ifdef USE_DNNL #include "core/providers/dnnl/dnnl_provider_factory.h" #endif -#ifdef USE_TVM -#include "core/providers/tvm/tvm_provider_factory.h" -#endif #ifdef USE_OPENVINO #include "core/providers/openvino/openvino_provider_factory.h" #endif diff --git a/onnxruntime/test/wasm/package-lock.json b/onnxruntime/test/wasm/package-lock.json index 522e96fc3188a..3bd5d173dbe79 100644 --- a/onnxruntime/test/wasm/package-lock.json +++ b/onnxruntime/test/wasm/package-lock.json @@ -27,9 +27,9 @@ } }, "node_modules/@socket.io/component-emitter": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.0.tgz", - "integrity": "sha512-+9jVqKhRSpsc591z5vX+X5Yyw+he/HCB4iQ/RYxw35CEPaY1gnsNE43nf9n9AaYjAQrTiI/mOwKUKdUs9vf7Xg==", + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.2.tgz", + "integrity": "sha512-9BCxFwvbGg/RsZK9tjXd8s4UcwR0MWeFQ1XEKIQVVvAGJyINdrqKMcTRyLoK8Rse1GjzLV9cwjWV1olXRWEXVA==", "dev": true }, "node_modules/@types/cookie": { @@ -39,19 +39,22 @@ "dev": true }, "node_modules/@types/cors": { - "version": "2.8.13", - "resolved": "https://registry.npmjs.org/@types/cors/-/cors-2.8.13.tgz", - "integrity": "sha512-RG8AStHlUiV5ysZQKq97copd2UmVYw3/pRMLefISZ3S1hK104Cwm7iLQ3fTKx+lsUH2CE8FlLaYeEA2LSeqYUA==", + "version": "2.8.17", + "resolved": "https://registry.npmjs.org/@types/cors/-/cors-2.8.17.tgz", + "integrity": "sha512-8CGDvrBj1zgo2qE+oS3pOCyYNqCPryMWY2bGfwA0dcfopWGgxs+78df0Rs3rc9THP4JkOhLsAa+15VdpAqkcUA==", "dev": true, "dependencies": { "@types/node": "*" } }, "node_modules/@types/node": { - "version": "18.13.0", - "resolved": "https://registry.npmjs.org/@types/node/-/node-18.13.0.tgz", - "integrity": "sha512-gC3TazRzGoOnoKAhUx+Q0t8S9Tzs74z7m0ipwGpSqQrleP14hKxP4/JUeEQcD3W1/aIpnWl8pHowI7WokuZpXg==", - "dev": true + "version": "22.10.1", + "resolved": "https://registry.npmjs.org/@types/node/-/node-22.10.1.tgz", + "integrity": "sha512-qKgsUwfHZV2WCWLAnVP1JqnpE6Im6h3Y0+fYgMTasNQ7V++CBX5OT1as0g0f+OyubbFqhf6XVNIsmN4IIhEgGQ==", + "dev": true, + "dependencies": { + "undici-types": "~6.20.0" + } }, "node_modules/accepts": { "version": "1.3.8", @@ -162,12 +165,12 @@ } }, "node_modules/braces": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz", - "integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==", + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", + "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", "dev": true, "dependencies": { - "fill-range": "^7.0.1" + "fill-range": "^7.1.1" }, "engines": { "node": ">=8" @@ -288,9 +291,9 @@ } }, "node_modules/cookie": { - "version": "0.4.2", - "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.4.2.tgz", - "integrity": "sha512-aSWTXFzaKWkvHO1Ny/s+ePFpvKsPnjc551iI41v3ny/ow6tBG5Vd+FuqGNhh1LxOmVzOlGUriIlOaokOvhaStA==", + "version": "0.7.2", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.2.tgz", + "integrity": "sha512-yki5XnKuf750l50uGTllt6kKILY4nQ1eNIQatoXEByZ5dWgnKqbnqmTrBE5B4N7lrMJKQ2ytWMiTO2o0v6Ew/w==", "dev": true, "engines": { "node": ">= 0.6" @@ -409,9 +412,9 @@ } }, "node_modules/engine.io": { - "version": "6.4.2", - "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.4.2.tgz", - "integrity": "sha512-FKn/3oMiJjrOEOeUub2WCox6JhxBXq/Zn3fZOMCBxKnNYtsdKjxhl7yR3fZhM9PV+rdE75SU5SYMc+2PGzo+Tg==", + "version": "6.6.2", + "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.6.2.tgz", + "integrity": "sha512-gmNvsYi9C8iErnZdVcJnvCpSKbWTt1E8+JZo8b+daLninywUWi5NQ5STSHZ9rFjFO7imNcvb8Pc5pe/wMR5xEw==", "dev": true, "dependencies": { "@types/cookie": "^0.4.1", @@ -419,32 +422,32 @@ "@types/node": ">=10.0.0", "accepts": "~1.3.4", "base64id": "2.0.0", - "cookie": "~0.4.1", + "cookie": "~0.7.2", "cors": "~2.8.5", "debug": "~4.3.1", - "engine.io-parser": "~5.0.3", - "ws": "~8.11.0" + "engine.io-parser": "~5.2.1", + "ws": "~8.17.1" }, "engines": { - "node": ">=10.0.0" + "node": ">=10.2.0" } }, "node_modules/engine.io-parser": { - "version": "5.0.6", - "resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.0.6.tgz", - "integrity": "sha512-tjuoZDMAdEhVnSFleYPCtdL2GXwVTGtNjoeJd9IhIG3C1xs9uwxqRNEu5WpnDZCaozwVlK/nuQhpodhXSIMaxw==", + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.2.3.tgz", + "integrity": "sha512-HqD3yTBfnBxIrbnM1DoD6Pcq8NECnh8d4As1Qgh0z5Gg3jRRIqijury0CL3ghu/edArpUYiYqQiDUQBIs4np3Q==", "dev": true, "engines": { "node": ">=10.0.0" } }, "node_modules/engine.io/node_modules/debug": { - "version": "4.3.4", - "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz", - "integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==", + "version": "4.3.7", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.7.tgz", + "integrity": "sha512-Er2nc/H7RrMXZBFCEim6TCmMk02Z8vLC2Rbi1KEBggpo0fS6l0S1nnapwmIi3yW/+GOJap1Krg4w0Hg80oCqgQ==", "dev": true, "dependencies": { - "ms": "2.1.2" + "ms": "^2.1.3" }, "engines": { "node": ">=6.0" @@ -456,9 +459,9 @@ } }, "node_modules/engine.io/node_modules/ms": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", - "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==", + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", "dev": true }, "node_modules/ent": { @@ -516,9 +519,9 @@ "dev": true }, "node_modules/fill-range": { - "version": "7.0.1", - "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", - "integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==", + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz", + "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", "dev": true, "dependencies": { "to-regex-range": "^5.0.1" @@ -1304,35 +1307,60 @@ } }, "node_modules/socket.io": { - "version": "4.6.0", - "resolved": "https://registry.npmjs.org/socket.io/-/socket.io-4.6.0.tgz", - "integrity": "sha512-b65bp6INPk/BMMrIgVvX12x3Q+NqlGqSlTuvKQWt0BUJ3Hyy3JangBl7fEoWZTXbOKlCqNPbQ6MbWgok/km28w==", + "version": "4.8.1", + "resolved": "https://registry.npmjs.org/socket.io/-/socket.io-4.8.1.tgz", + "integrity": "sha512-oZ7iUCxph8WYRHHcjBEc9unw3adt5CmSNlppj/5Q4k2RIrhl8Z5yY2Xr4j9zj0+wzVZ0bxmYoGSzKJnRl6A4yg==", "dev": true, "dependencies": { "accepts": "~1.3.4", "base64id": "~2.0.0", + "cors": "~2.8.5", "debug": "~4.3.2", - "engine.io": "~6.4.0", + "engine.io": "~6.6.0", "socket.io-adapter": "~2.5.2", - "socket.io-parser": "~4.2.1" + "socket.io-parser": "~4.2.4" }, "engines": { - "node": ">=10.0.0" + "node": ">=10.2.0" } }, "node_modules/socket.io-adapter": { - "version": "2.5.2", - "resolved": "https://registry.npmjs.org/socket.io-adapter/-/socket.io-adapter-2.5.2.tgz", - "integrity": "sha512-87C3LO/NOMc+eMcpcxUBebGjkpMDkNBS9tf7KJqcDsmL936EChtVva71Dw2q4tQcuVC+hAUy4an2NO/sYXmwRA==", + "version": "2.5.5", + "resolved": "https://registry.npmjs.org/socket.io-adapter/-/socket.io-adapter-2.5.5.tgz", + "integrity": "sha512-eLDQas5dzPgOWCk9GuuJC2lBqItuhKI4uxGgo9aIV7MYbk2h9Q6uULEh8WBzThoI7l+qU9Ast9fVUmkqPP9wYg==", + "dev": true, + "dependencies": { + "debug": "~4.3.4", + "ws": "~8.17.1" + } + }, + "node_modules/socket.io-adapter/node_modules/debug": { + "version": "4.3.7", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.7.tgz", + "integrity": "sha512-Er2nc/H7RrMXZBFCEim6TCmMk02Z8vLC2Rbi1KEBggpo0fS6l0S1nnapwmIi3yW/+GOJap1Krg4w0Hg80oCqgQ==", "dev": true, "dependencies": { - "ws": "~8.11.0" + "ms": "^2.1.3" + }, + "engines": { + "node": ">=6.0" + }, + "peerDependenciesMeta": { + "supports-color": { + "optional": true + } } }, + "node_modules/socket.io-adapter/node_modules/ms": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", + "dev": true + }, "node_modules/socket.io-parser": { - "version": "4.2.3", - "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.3.tgz", - "integrity": "sha512-JMafRntWVO2DCJimKsRTh/wnqVvO4hrfwOqtO7f+uzwsQMuxO6VwImtYxaQ+ieoyshWOTJyV0fA21lccEXRPpQ==", + "version": "4.2.4", + "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.4.tgz", + "integrity": "sha512-/GbIKmo8ioc+NIWIhwdecY0ge+qVBSMdgxGygevmdHj24bsfgtCmcUUcQ5ZzcylGFHsN3k4HB4Cgkl96KVnuew==", "dev": true, "dependencies": { "@socket.io/component-emitter": "~3.1.0", @@ -1343,12 +1371,12 @@ } }, "node_modules/socket.io-parser/node_modules/debug": { - "version": "4.3.4", - "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz", - "integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==", + "version": "4.3.7", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.7.tgz", + "integrity": "sha512-Er2nc/H7RrMXZBFCEim6TCmMk02Z8vLC2Rbi1KEBggpo0fS6l0S1nnapwmIi3yW/+GOJap1Krg4w0Hg80oCqgQ==", "dev": true, "dependencies": { - "ms": "2.1.2" + "ms": "^2.1.3" }, "engines": { "node": ">=6.0" @@ -1360,9 +1388,9 @@ } }, "node_modules/socket.io-parser/node_modules/ms": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", - "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==", + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", "dev": true }, "node_modules/socket.io/node_modules/debug": { @@ -1534,6 +1562,12 @@ "node": "*" } }, + "node_modules/undici-types": { + "version": "6.20.0", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.20.0.tgz", + "integrity": "sha512-Ny6QZ2Nju20vw1SRHe3d9jVu6gJ+4e3+MMpqu7pqE5HT6WsTSlce++GQmK5UXS8mzV8DSYHrQH+Xrf2jVcuKNg==", + "dev": true + }, "node_modules/universalify": { "version": "0.1.2", "resolved": "https://registry.npmjs.org/universalify/-/universalify-0.1.2.tgz", @@ -1615,16 +1649,16 @@ "dev": true }, "node_modules/ws": { - "version": "8.11.0", - "resolved": "https://registry.npmjs.org/ws/-/ws-8.11.0.tgz", - "integrity": "sha512-HPG3wQd9sNQoT9xHyNCXoDUa+Xw/VevmY9FoHyQ+g+rrMn4j6FB4np7Z0OhdTgjx6MgQLK7jwSy1YecU1+4Asg==", + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.17.1.tgz", + "integrity": "sha512-6XQFvXTkbfUOZOKKILFG1PDK2NDQs4azKQl26T0YS5CxqWLgXajbPZ+h4gZekJyRqFU8pvnbAbbs/3TgRPy+GQ==", "dev": true, "engines": { "node": ">=10.0.0" }, "peerDependencies": { "bufferutil": "^4.0.1", - "utf-8-validate": "^5.0.2" + "utf-8-validate": ">=5.0.2" }, "peerDependenciesMeta": { "bufferutil": { @@ -1686,9 +1720,9 @@ "dev": true }, "@socket.io/component-emitter": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.0.tgz", - "integrity": "sha512-+9jVqKhRSpsc591z5vX+X5Yyw+he/HCB4iQ/RYxw35CEPaY1gnsNE43nf9n9AaYjAQrTiI/mOwKUKdUs9vf7Xg==", + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.2.tgz", + "integrity": "sha512-9BCxFwvbGg/RsZK9tjXd8s4UcwR0MWeFQ1XEKIQVVvAGJyINdrqKMcTRyLoK8Rse1GjzLV9cwjWV1olXRWEXVA==", "dev": true }, "@types/cookie": { @@ -1698,19 +1732,22 @@ "dev": true }, "@types/cors": { - "version": "2.8.13", - "resolved": "https://registry.npmjs.org/@types/cors/-/cors-2.8.13.tgz", - "integrity": "sha512-RG8AStHlUiV5ysZQKq97copd2UmVYw3/pRMLefISZ3S1hK104Cwm7iLQ3fTKx+lsUH2CE8FlLaYeEA2LSeqYUA==", + "version": "2.8.17", + "resolved": "https://registry.npmjs.org/@types/cors/-/cors-2.8.17.tgz", + "integrity": "sha512-8CGDvrBj1zgo2qE+oS3pOCyYNqCPryMWY2bGfwA0dcfopWGgxs+78df0Rs3rc9THP4JkOhLsAa+15VdpAqkcUA==", "dev": true, "requires": { "@types/node": "*" } }, "@types/node": { - "version": "18.13.0", - "resolved": "https://registry.npmjs.org/@types/node/-/node-18.13.0.tgz", - "integrity": "sha512-gC3TazRzGoOnoKAhUx+Q0t8S9Tzs74z7m0ipwGpSqQrleP14hKxP4/JUeEQcD3W1/aIpnWl8pHowI7WokuZpXg==", - "dev": true + "version": "22.10.1", + "resolved": "https://registry.npmjs.org/@types/node/-/node-22.10.1.tgz", + "integrity": "sha512-qKgsUwfHZV2WCWLAnVP1JqnpE6Im6h3Y0+fYgMTasNQ7V++CBX5OT1as0g0f+OyubbFqhf6XVNIsmN4IIhEgGQ==", + "dev": true, + "requires": { + "undici-types": "~6.20.0" + } }, "accepts": { "version": "1.3.8", @@ -1796,12 +1833,12 @@ } }, "braces": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz", - "integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==", + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", + "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", "dev": true, "requires": { - "fill-range": "^7.0.1" + "fill-range": "^7.1.1" } }, "bytes": { @@ -1890,9 +1927,9 @@ "dev": true }, "cookie": { - "version": "0.4.2", - "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.4.2.tgz", - "integrity": "sha512-aSWTXFzaKWkvHO1Ny/s+ePFpvKsPnjc551iI41v3ny/ow6tBG5Vd+FuqGNhh1LxOmVzOlGUriIlOaokOvhaStA==", + "version": "0.7.2", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.2.tgz", + "integrity": "sha512-yki5XnKuf750l50uGTllt6kKILY4nQ1eNIQatoXEByZ5dWgnKqbnqmTrBE5B4N7lrMJKQ2ytWMiTO2o0v6Ew/w==", "dev": true }, "cors": { @@ -1986,9 +2023,9 @@ "dev": true }, "engine.io": { - "version": "6.4.2", - "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.4.2.tgz", - "integrity": "sha512-FKn/3oMiJjrOEOeUub2WCox6JhxBXq/Zn3fZOMCBxKnNYtsdKjxhl7yR3fZhM9PV+rdE75SU5SYMc+2PGzo+Tg==", + "version": "6.6.2", + "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.6.2.tgz", + "integrity": "sha512-gmNvsYi9C8iErnZdVcJnvCpSKbWTt1E8+JZo8b+daLninywUWi5NQ5STSHZ9rFjFO7imNcvb8Pc5pe/wMR5xEw==", "dev": true, "requires": { "@types/cookie": "^0.4.1", @@ -1996,34 +2033,34 @@ "@types/node": ">=10.0.0", "accepts": "~1.3.4", "base64id": "2.0.0", - "cookie": "~0.4.1", + "cookie": "~0.7.2", "cors": "~2.8.5", "debug": "~4.3.1", - "engine.io-parser": "~5.0.3", - "ws": "~8.11.0" + "engine.io-parser": "~5.2.1", + "ws": "~8.17.1" }, "dependencies": { "debug": { - "version": "4.3.4", - "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz", - "integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==", + "version": "4.3.7", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.7.tgz", + "integrity": "sha512-Er2nc/H7RrMXZBFCEim6TCmMk02Z8vLC2Rbi1KEBggpo0fS6l0S1nnapwmIi3yW/+GOJap1Krg4w0Hg80oCqgQ==", "dev": true, "requires": { - "ms": "2.1.2" + "ms": "^2.1.3" } }, "ms": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", - "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==", + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", "dev": true } } }, "engine.io-parser": { - "version": "5.0.6", - "resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.0.6.tgz", - "integrity": "sha512-tjuoZDMAdEhVnSFleYPCtdL2GXwVTGtNjoeJd9IhIG3C1xs9uwxqRNEu5WpnDZCaozwVlK/nuQhpodhXSIMaxw==", + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.2.3.tgz", + "integrity": "sha512-HqD3yTBfnBxIrbnM1DoD6Pcq8NECnh8d4As1Qgh0z5Gg3jRRIqijury0CL3ghu/edArpUYiYqQiDUQBIs4np3Q==", "dev": true }, "ent": { @@ -2072,9 +2109,9 @@ "dev": true }, "fill-range": { - "version": "7.0.1", - "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", - "integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==", + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz", + "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", "dev": true, "requires": { "to-regex-range": "^5.0.1" @@ -2651,17 +2688,18 @@ } }, "socket.io": { - "version": "4.6.0", - "resolved": "https://registry.npmjs.org/socket.io/-/socket.io-4.6.0.tgz", - "integrity": "sha512-b65bp6INPk/BMMrIgVvX12x3Q+NqlGqSlTuvKQWt0BUJ3Hyy3JangBl7fEoWZTXbOKlCqNPbQ6MbWgok/km28w==", + "version": "4.8.1", + "resolved": "https://registry.npmjs.org/socket.io/-/socket.io-4.8.1.tgz", + "integrity": "sha512-oZ7iUCxph8WYRHHcjBEc9unw3adt5CmSNlppj/5Q4k2RIrhl8Z5yY2Xr4j9zj0+wzVZ0bxmYoGSzKJnRl6A4yg==", "dev": true, "requires": { "accepts": "~1.3.4", "base64id": "~2.0.0", + "cors": "~2.8.5", "debug": "~4.3.2", - "engine.io": "~6.4.0", + "engine.io": "~6.6.0", "socket.io-adapter": "~2.5.2", - "socket.io-parser": "~4.2.1" + "socket.io-parser": "~4.2.4" }, "dependencies": { "debug": { @@ -2682,18 +2720,36 @@ } }, "socket.io-adapter": { - "version": "2.5.2", - "resolved": "https://registry.npmjs.org/socket.io-adapter/-/socket.io-adapter-2.5.2.tgz", - "integrity": "sha512-87C3LO/NOMc+eMcpcxUBebGjkpMDkNBS9tf7KJqcDsmL936EChtVva71Dw2q4tQcuVC+hAUy4an2NO/sYXmwRA==", + "version": "2.5.5", + "resolved": "https://registry.npmjs.org/socket.io-adapter/-/socket.io-adapter-2.5.5.tgz", + "integrity": "sha512-eLDQas5dzPgOWCk9GuuJC2lBqItuhKI4uxGgo9aIV7MYbk2h9Q6uULEh8WBzThoI7l+qU9Ast9fVUmkqPP9wYg==", "dev": true, "requires": { - "ws": "~8.11.0" + "debug": "~4.3.4", + "ws": "~8.17.1" + }, + "dependencies": { + "debug": { + "version": "4.3.7", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.7.tgz", + "integrity": "sha512-Er2nc/H7RrMXZBFCEim6TCmMk02Z8vLC2Rbi1KEBggpo0fS6l0S1nnapwmIi3yW/+GOJap1Krg4w0Hg80oCqgQ==", + "dev": true, + "requires": { + "ms": "^2.1.3" + } + }, + "ms": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", + "dev": true + } } }, "socket.io-parser": { - "version": "4.2.3", - "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.3.tgz", - "integrity": "sha512-JMafRntWVO2DCJimKsRTh/wnqVvO4hrfwOqtO7f+uzwsQMuxO6VwImtYxaQ+ieoyshWOTJyV0fA21lccEXRPpQ==", + "version": "4.2.4", + "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.4.tgz", + "integrity": "sha512-/GbIKmo8ioc+NIWIhwdecY0ge+qVBSMdgxGygevmdHj24bsfgtCmcUUcQ5ZzcylGFHsN3k4HB4Cgkl96KVnuew==", "dev": true, "requires": { "@socket.io/component-emitter": "~3.1.0", @@ -2701,18 +2757,18 @@ }, "dependencies": { "debug": { - "version": "4.3.4", - "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz", - "integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==", + "version": "4.3.7", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.7.tgz", + "integrity": "sha512-Er2nc/H7RrMXZBFCEim6TCmMk02Z8vLC2Rbi1KEBggpo0fS6l0S1nnapwmIi3yW/+GOJap1Krg4w0Hg80oCqgQ==", "dev": true, "requires": { - "ms": "2.1.2" + "ms": "^2.1.3" } }, "ms": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", - "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==", + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", "dev": true } } @@ -2817,6 +2873,12 @@ "integrity": "sha512-s8ax/CeZdK9R/56Sui0WM6y9OFREJarMRHqLB2EwkovemBxNQ+Bqu8GAsUnVcXKgphb++ghr/B2BZx4mahujPw==", "dev": true }, + "undici-types": { + "version": "6.20.0", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.20.0.tgz", + "integrity": "sha512-Ny6QZ2Nju20vw1SRHe3d9jVu6gJ+4e3+MMpqu7pqE5HT6WsTSlce++GQmK5UXS8mzV8DSYHrQH+Xrf2jVcuKNg==", + "dev": true + }, "universalify": { "version": "0.1.2", "resolved": "https://registry.npmjs.org/universalify/-/universalify-0.1.2.tgz", @@ -2874,9 +2936,9 @@ "dev": true }, "ws": { - "version": "8.11.0", - "resolved": "https://registry.npmjs.org/ws/-/ws-8.11.0.tgz", - "integrity": "sha512-HPG3wQd9sNQoT9xHyNCXoDUa+Xw/VevmY9FoHyQ+g+rrMn4j6FB4np7Z0OhdTgjx6MgQLK7jwSy1YecU1+4Asg==", + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.17.1.tgz", + "integrity": "sha512-6XQFvXTkbfUOZOKKILFG1PDK2NDQs4azKQl26T0YS5CxqWLgXajbPZ+h4gZekJyRqFU8pvnbAbbs/3TgRPy+GQ==", "dev": true, "requires": {} }, diff --git a/onnxruntime/test/webgpu/external_dawn/main.cc b/onnxruntime/test/webgpu/external_dawn/main.cc new file mode 100644 index 0000000000000..ed8d2eab94ce9 --- /dev/null +++ b/onnxruntime/test/webgpu/external_dawn/main.cc @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// Licensed under the MIT License. + +#include + +#include "core/session/onnxruntime_cxx_api.h" + +#include + +#include "dawn/native/DawnNative.h" + +#ifdef _WIN32 +int wmain(int argc, wchar_t* argv[]) { +#else +int main(int argc, char* argv[]) { +#endif + bool no_proc_table = argc > 0 && +#ifdef _WIN32 + wcscmp(L"--no_proc_table", argv[argc - 1]) == 0; +#else + strcmp("--no_proc_table", argv[argc - 1]) == 0; +#endif + + int retval = 0; + Ort::Env env{nullptr}; + try { + env = Ort::Env{ORT_LOGGING_LEVEL_WARNING, "Default"}; + + // model is https://github.com/onnx/onnx/blob/v1.15.0/onnx/backend/test/data/node/test_abs/model.onnx + constexpr uint8_t MODEL_DATA[] = {8, 7, 18, 12, 98, 97, 99, 107, 101, 110, + 100, 45, 116, 101, 115, 116, 58, 73, 10, 11, + 10, 1, 120, 18, 1, 121, 34, 3, 65, 98, + 115, 18, 8, 116, 101, 115, 116, 95, 97, 98, + 115, 90, 23, 10, 1, 120, 18, 18, 10, 16, + 8, 1, 18, 12, 10, 2, 8, 3, 10, 2, + 8, 4, 10, 2, 8, 5, 98, 23, 10, 1, + 121, 18, 18, 10, 16, 8, 1, 18, 12, 10, + 2, 8, 3, 10, 2, 8, 4, 10, 2, 8, + 5, 66, 4, 10, 0, 16, 13}; + + Ort::SessionOptions session_options; + session_options.DisableMemPattern(); + std::unordered_map provider_options; + if (!no_proc_table) { + provider_options["dawnProcTable"] = std::to_string(reinterpret_cast(&dawn::native::GetProcs())); + } + session_options.AppendExecutionProvider("WebGPU", provider_options); + Ort::Session session{env, MODEL_DATA, sizeof(MODEL_DATA), session_options}; + + if (no_proc_table) { + std::cerr << "DawnProcTable is not passing to ONNX Runtime, but no exception is thrown." << std::endl; + retval = -1; + } else { + // successfully initialized + std::cout << "Successfully initialized WebGPU EP." << std::endl; + retval = 0; + } + } catch (const std::exception& ex) { + std::cerr << ex.what() << std::endl; + + if (no_proc_table) { + std::cout << "DawnProcTable is not passing to ONNX Runtime, so an exception is thrown as expected." << std::endl; + retval = 0; + } else { + std::cerr << "Unexpected exception." << std::endl; + retval = -1; + } + } + + ::google::protobuf::ShutdownProtobufLibrary(); + return retval; +} diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index 0efbcab3a3238..45e2475548df5 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -237,11 +237,13 @@ Module['jsepInit'] = (name, params) => { } Module['jsepRegisterMLTensor'] = (tensor, dataType, shape) => { return backend['registerMLTensor'](tensor, dataType, shape); - } - - Module.jsepRegisterMLConstant = (externalFilePath, dataOffset, dataLength, builder, desc) => { + }; + Module['jsepCreateMLContext'] = (optionsOrGpuDevice) => { + return backend['createMLContext'](optionsOrGpuDevice); + }; + Module['jsepRegisterMLConstant'] = (externalFilePath, dataOffset, dataLength, builder, desc) => { return backend['registerMLConstant']( externalFilePath, dataOffset, dataLength, builder, desc, Module.MountedFiles); - } + }; } }; diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc index 87a7cbc0375a4..f1545e96481fa 100644 --- a/orttraining/orttraining/core/session/training_session.cc +++ b/orttraining/orttraining/core/session/training_session.cc @@ -758,7 +758,8 @@ Status TrainingSession::AddPredefinedTransformers( GraphTransformerManager& transformer_manager, TransformerLevel graph_optimization_level, MinimalBuildOptimizationHandling minimal_build_optimization_handling, - RecordRuntimeOptimizationProducedNodeOpSchemaFn /*record_runtime_optimization_produced_op_schema_fn*/) const { + RecordRuntimeOptimizationProducedNodeOpSchemaFn /*record_runtime_optimization_produced_op_schema_fn*/, + const logging::Logger& /*logger*/) const { ORT_RETURN_IF_NOT( minimal_build_optimization_handling == MinimalBuildOptimizationHandling::ApplyFullBuildOptimizations, "Only applying full build optimizations is supported by TrainingSession."); diff --git a/orttraining/orttraining/core/session/training_session.h b/orttraining/orttraining/core/session/training_session.h index 765f88e1c992e..58492dc62400f 100644 --- a/orttraining/orttraining/core/session/training_session.h +++ b/orttraining/orttraining/core/session/training_session.h @@ -489,7 +489,8 @@ class TrainingSession : public InferenceSession { GraphTransformerManager& transformer_manager, TransformerLevel graph_optimization_level, MinimalBuildOptimizationHandling minimal_build_optimization_handling, - RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn) const override; + RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn, + const logging::Logger& logger) const override; /** Perform auto-diff to add backward graph into the model. @param weights_to_train a set of weights to be training. diff --git a/orttraining/orttraining/models/bert/main.cc b/orttraining/orttraining/models/bert/main.cc index ec7a458237c77..c4c7a98ba116a 100644 --- a/orttraining/orttraining/models/bert/main.cc +++ b/orttraining/orttraining/models/bert/main.cc @@ -42,7 +42,6 @@ static SessionOptions session_options = { ExecutionMode::ORT_SEQUENTIAL, // execution_mode ExecutionOrder::PRIORITY_BASED, // execution_order false, // enable_profiling - false, // save prepacked initializer ORT_TSTR(""), // optimized_model_filepath true, // enable_mem_pattern true, // enable_mem_reuse diff --git a/orttraining/orttraining/models/pipeline_poc/main.cc b/orttraining/orttraining/models/pipeline_poc/main.cc index 0e40d04ddac8c..1b7d6b9ea26f6 100644 --- a/orttraining/orttraining/models/pipeline_poc/main.cc +++ b/orttraining/orttraining/models/pipeline_poc/main.cc @@ -89,7 +89,6 @@ int main(int argc, char* argv[]) { ExecutionMode::ORT_SEQUENTIAL, // execution_mode ExecutionOrder::DEFAULT, // execution_order false, // enable_profiling - false, // save prepacked initializer ORT_TSTR(""), // optimized_model_filepath true, // enable_mem_pattern true, // enable_mem_reuse diff --git a/orttraining/orttraining/models/runner/training_runner.cc b/orttraining/orttraining/models/runner/training_runner.cc index 5a2f1cd13683e..dae6f613f4329 100644 --- a/orttraining/orttraining/models/runner/training_runner.cc +++ b/orttraining/orttraining/models/runner/training_runner.cc @@ -37,7 +37,6 @@ static SessionOptions SESSION_OPTION = { ExecutionMode::ORT_SEQUENTIAL, // execution_mode ExecutionOrder::PRIORITY_BASED, // execution_order false, // enable_profiling - false, // save prepacked initializer ORT_TSTR(""), // optimized_model_filepath true, // enable_mem_pattern true, // enable_mem_reuse diff --git a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc index 0944e46ff8eaf..58c173ed90277 100644 --- a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc +++ b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc @@ -139,7 +139,8 @@ void GradientOpTester::Run(int output_index_to_use_as_loss, auto reg = execution_provider->GetKernelRegistry(); const KernelCreateInfo* kci; - auto st = reg->TryFindKernel(node, execution_provider->Type(), kernel_type_str_resolver, &kci); + auto st = reg->TryFindKernel(node, execution_provider->Type(), kernel_type_str_resolver, + DefaultLoggingManager().DefaultLogger(), &kci); if (!st.IsOK()) { // The goal here is unclear. It seems best to leave it to the Session // creation to figure out whether the model can be executed using some diff --git a/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc b/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc index 548f39bb0150c..1b8699d1de497 100644 --- a/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc @@ -23,8 +23,10 @@ TEST(GraphTransformerUtilsTestsForTraining, TestGenerateGraphTransformers) { InlinedHashSet disabled = {l1_rule1, l1_transformer, l2_transformer}; CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{}); - auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep); - auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, disabled); + const auto& logger = DefaultLoggingManager().DefaultLogger(); + auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, logger); + auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, logger, + disabled); // check ConstantFolding transformer was removed ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); @@ -47,8 +49,8 @@ TEST(GraphTransformerUtilsTestsForTraining, TestGenerateGraphTransformers) { #ifndef DISABLE_CONTRIB_OPS // check that ConvActivationFusion was removed - all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep); - filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, disabled); + all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger); + filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger, disabled); ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); #endif } diff --git a/orttraining/tools/amdgpu/Dockerfile.rocm4.3.1.pytorch b/orttraining/tools/amdgpu/Dockerfile.rocm4.3.1.pytorch index 3a408e2265fe7..29b8812c979e4 100644 --- a/orttraining/tools/amdgpu/Dockerfile.rocm4.3.1.pytorch +++ b/orttraining/tools/amdgpu/Dockerfile.rocm4.3.1.pytorch @@ -46,7 +46,7 @@ RUN cd MLNX_OFED_LINUX-${MOFED_VERSION}-${MOFED_OS}-x86_64 && \ rm -r MLNX_OFED_LINUX-${MOFED_VERSION}-${MOFED_OS}-x86_64 ENV PATH=${OLD_PATH} -ENV unset OLD_PATH +ENV unset=OLD_PATH # python env RUN pip3 install --upgrade setuptools diff --git a/packages.config b/packages.config index 597ca77a321c5..877e2a17fd83e 100644 --- a/packages.config +++ b/packages.config @@ -1,6 +1,6 @@  - + diff --git a/tools/android_custom_build/Dockerfile b/tools/android_custom_build/Dockerfile index 0ff365dd5ff74..fcaffd9ef5e78 100644 --- a/tools/android_custom_build/Dockerfile +++ b/tools/android_custom_build/Dockerfile @@ -15,7 +15,7 @@ RUN apt-get update && apt-get install --yes --no-install-recommends \ ca-certificates \ git \ ninja-build \ - openjdk-11-jdk-headless \ + openjdk-17-jdk-headless \ python3-dev \ python3-numpy \ python3-pip \ diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index ff246503e82b6..6a8154681ed97 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -21,6 +21,10 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path): s = s.replace("kCudaStreamCopyIn", "kHipStreamCopyIn") s = s.replace("kCudaStreamCopyOut", "kHipStreamCopyOut") s = s.replace("kTotalCudaStreams", "kTotalHipStreams") + + # in rocm 6.0, hipify-perl, the -roc option also maps __half -> rocblas_half which we don't want + s = s.replace("rocblas_half", "__half") + # these should be "hip" but it's easier to just use rocm to avoid complicated file renaming s = s.replace("CudaGraph", "RocmGraph") s = s.replace("CUDAGraph", "ROCMGraph") diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 85eb3ddad3c56..3527a89ca7a7b 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -13,6 +13,7 @@ import shutil import subprocess import sys +import warnings from pathlib import Path @@ -253,7 +254,12 @@ def convert_arg_line_to_args(self, arg_line): "--cudnn_home is not specified.", ) parser.add_argument("--enable_cuda_line_info", action="store_true", help="Enable CUDA line info.") - parser.add_argument("--enable_cuda_nhwc_ops", action="store_true", help="Enable CUDA NHWC ops in build.") + + parser.add_argument( + "--enable_cuda_nhwc_ops", action="store_true", help="Deprecated; default to enable CUDA NHWC ops in build." + ) + + parser.add_argument("--disable_cuda_nhwc_ops", action="store_true", help="Disable CUDA NHWC ops in build.") # Python bindings parser.add_argument("--enable_pybind", action="store_true", help="Enable Python Bindings.") @@ -571,19 +577,14 @@ def convert_arg_line_to_args(self, arg_line): ) parser.add_argument("--use_jsep", action="store_true", help="Build with JavaScript kernels.") parser.add_argument("--use_webgpu", action="store_true", help="Build with WebGPU support.") + parser.add_argument("--use_external_dawn", action="store_true", help="Treat Dawn as an external dependency.") parser.add_argument("--use_qnn", action="store_true", help="Build with QNN support.") parser.add_argument("--qnn_home", help="Path to QNN SDK dir.") parser.add_argument("--use_rknpu", action="store_true", help="Build with RKNPU.") parser.add_argument("--use_preinstalled_eigen", action="store_true", help="Use pre-installed Eigen.") parser.add_argument("--eigen_path", help="Path to pre-installed Eigen.") parser.add_argument("--enable_msinternal", action="store_true", help="Enable for Microsoft internal builds only.") - parser.add_argument("--llvm_path", help="Path to llvm dir") parser.add_argument("--use_vitisai", action="store_true", help="Build with Vitis-AI") - parser.add_argument("--use_tvm", action="store_true", help="Build with TVM") - parser.add_argument("--tvm_cuda_runtime", action="store_true", default=False, help="Build TVM with CUDA support") - parser.add_argument( - "--use_tvm_hash", action="store_true", help="Build ipp-crypto for hash generation. It is used by TVM EP only" - ) parser.add_argument("--use_tensorrt", action="store_true", help="Build with TensorRT") parser.add_argument( "--use_tensorrt_builtin_parser", action="store_true", default=True, help="Use TensorRT builtin parser" @@ -595,12 +596,6 @@ def convert_arg_line_to_args(self, arg_line): parser.add_argument("--migraphx_home", help="Path to MIGraphX installation dir") parser.add_argument("--use_full_protobuf", action="store_true", help="Use the full protobuf library") - parser.add_argument( - "--llvm_config", - type=str, - default="", - help="Path to llvm-config.exe for LLVM built from sources. It is strongly needed for build on Windows", - ) parser.add_argument( "--skip_onnx_tests", action="store_true", @@ -792,6 +787,11 @@ def convert_arg_line_to_args(self, arg_line): if args.cmake_generator is None and is_windows(): args.cmake_generator = "Ninja" if args.build_wasm else "Visual Studio 17 2022" + if args.enable_cuda_nhwc_ops: + warnings.warn( + "The argument '--enable_cuda_nhwc_ops' is deprecated and is default to True. ", DeprecationWarning + ) + return args @@ -1019,16 +1019,11 @@ def generate_build_tree( "-Donnxruntime_USE_NNAPI_BUILTIN=" + ("ON" if args.use_nnapi else "OFF"), "-Donnxruntime_USE_VSINPU=" + ("ON" if args.use_vsinpu else "OFF"), "-Donnxruntime_USE_RKNPU=" + ("ON" if args.use_rknpu else "OFF"), - "-Donnxruntime_USE_LLVM=" + ("ON" if args.use_tvm else "OFF"), "-Donnxruntime_ENABLE_MICROSOFT_INTERNAL=" + ("ON" if args.enable_msinternal else "OFF"), "-Donnxruntime_USE_VITISAI=" + ("ON" if args.use_vitisai else "OFF"), "-Donnxruntime_USE_TENSORRT=" + ("ON" if args.use_tensorrt else "OFF"), "-Donnxruntime_USE_TENSORRT_BUILTIN_PARSER=" + ("ON" if args.use_tensorrt_builtin_parser and not args.use_tensorrt_oss_parser else "OFF"), - # set vars for TVM - "-Donnxruntime_USE_TVM=" + ("ON" if args.use_tvm else "OFF"), - "-Donnxruntime_TVM_CUDA_RUNTIME=" + ("ON" if args.use_tvm and args.tvm_cuda_runtime else "OFF"), - "-Donnxruntime_TVM_USE_HASH=" + ("ON" if args.use_tvm_hash else "OFF"), # set vars for migraphx "-Donnxruntime_USE_MIGRAPHX=" + ("ON" if args.use_migraphx else "OFF"), "-Donnxruntime_DISABLE_CONTRIB_OPS=" + ("ON" if args.disable_contrib_ops else "OFF"), @@ -1058,6 +1053,7 @@ def generate_build_tree( "-Donnxruntime_ARMNN_BN_USE_CPU=" + ("OFF" if args.armnn_bn else "ON"), "-Donnxruntime_USE_JSEP=" + ("ON" if args.use_jsep else "OFF"), "-Donnxruntime_USE_WEBGPU=" + ("ON" if args.use_webgpu else "OFF"), + "-Donnxruntime_USE_EXTERNAL_DAWN=" + ("ON" if args.use_external_dawn else "OFF"), # Training related flags "-Donnxruntime_ENABLE_NVTX_PROFILE=" + ("ON" if args.enable_nvtx_profile else "OFF"), "-Donnxruntime_ENABLE_TRAINING=" + ("ON" if args.enable_training else "OFF"), @@ -1072,7 +1068,7 @@ def generate_build_tree( "-Donnxruntime_USE_MPI=" + ("ON" if args.use_mpi else "OFF"), "-Donnxruntime_ENABLE_MEMORY_PROFILE=" + ("ON" if args.enable_memory_profile else "OFF"), "-Donnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO=" + ("ON" if args.enable_cuda_line_info else "OFF"), - "-Donnxruntime_USE_CUDA_NHWC_OPS=" + ("ON" if args.enable_cuda_nhwc_ops else "OFF"), + "-Donnxruntime_USE_CUDA_NHWC_OPS=" + ("ON" if args.use_cuda and not args.disable_cuda_nhwc_ops else "OFF"), "-Donnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB=" + ("ON" if args.build_wasm_static_lib else "OFF"), "-Donnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING=" + ("OFF" if args.disable_wasm_exception_catching else "ON"), @@ -1159,8 +1155,6 @@ def generate_build_tree( cmake_args.append("-Donnxruntime_ROCM_VERSION=" + args.rocm_version) if args.use_tensorrt: cmake_args.append("-Donnxruntime_TENSORRT_HOME=" + tensorrt_home) - if args.llvm_config: - cmake_args.append("-Donnxruntime_TVM_USE_LLVM=" + args.llvm_config) if args.use_cuda: add_default_definition(cmake_extra_defines, "onnxruntime_USE_CUDA", "ON") @@ -1243,9 +1237,6 @@ def generate_build_tree( if args.use_full_protobuf or args.use_openvino or args.use_vitisai or args.gen_doc: cmake_args += ["-Donnxruntime_USE_FULL_PROTOBUF=ON", "-DProtobuf_USE_STATIC_LIBS=ON"] - if args.use_tvm and args.llvm_path is not None: - cmake_args += [f"-DLLVM_DIR={args.llvm_path}"] - if args.use_cuda and not is_windows(): nvml_stub_path = cuda_home + "/lib64/stubs" cmake_args += ["-DCUDA_CUDA_LIBRARY=" + nvml_stub_path] @@ -1321,6 +1312,9 @@ def generate_build_tree( if args.use_jsep and args.use_webgpu: raise BuildError("JSEP (--use_jsep) and WebGPU (--use_webgpu) cannot be enabled at the same time.") + if args.use_external_dawn and not args.use_webgpu: + raise BuildError("External Dawn (--use_external_dawn) must be enabled with WebGPU (--use_webgpu).") + if args.use_snpe: cmake_args += ["-Donnxruntime_USE_SNPE=ON"] @@ -1567,8 +1561,7 @@ def generate_build_tree( ldflags = ["/profile", "/DYNAMICBASE"] # Address Sanitizer libs do not have a Qspectre version. So they two cannot be both enabled. if not args.enable_address_sanitizer: - # Also enable a special perf patch that was made for Intel Meteor Lake mobile CPUs - cflags += ["/Qspectre", "/DONNXRUNTIME_ENABLE_INTEL_METEOR_LAKE_MOBILE_PLATFORM_PERF_PATCH"] + cflags += ["/Qspectre"] if config == "Release": cflags += ["/O2", "/Ob2", "/DNDEBUG"] elif config == "RelWithDebInfo": @@ -1643,16 +1636,6 @@ def generate_build_tree( cxxflags = cflags.copy() config_build_dir = get_config_build_dir(build_dir, config) os.makedirs(config_build_dir, exist_ok=True) - if args.use_tvm: - os.environ["PATH"] = ( - os.path.join(config_build_dir, "_deps", "tvm-build") - + os.pathsep - + os.path.join(config_build_dir, "_deps", "tvm-src") - + os.pathsep - + os.path.dirname(sys.executable) - + os.pathsep - + os.environ["PATH"] - ) preinstalled_dir = Path(build_dir) / config temp_cmake_args = cmake_args.copy() if cflags is not None and cxxflags is not None and len(cflags) != 0 and len(cxxflags) != 0: @@ -2081,8 +2064,6 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): if args.enable_pybind: python_path = None - if args.use_tvm: - python_path = str((Path(build_dir) / config / "_deps" / "tvm-src" / "python").resolve()) # Disable python tests in a reduced build as we don't know which ops have been included and which # models can run. @@ -2092,6 +2073,17 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): if is_windows(): cwd = os.path.join(cwd, config) + if args.enable_transformers_tool_test and not args.disable_contrib_ops and not args.use_rocm: + # PyTorch is required for transformers tests, and optional for some python tests. + # Install cpu only version of torch when cuda is not enabled in Linux. + extra = [] if args.use_cuda and is_linux() else ["--index-url", "https://download.pytorch.org/whl/cpu"] + run_subprocess( + [sys.executable, "-m", "pip", "install", "torch", *extra], + cwd=cwd, + dll_path=dll_path, + python_path=python_path, + ) + run_subprocess( [sys.executable, "onnxruntime_test_python.py"], cwd=cwd, dll_path=dll_path, python_path=python_path ) @@ -2120,10 +2112,10 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): if not args.disable_ml_ops and not args.use_tensorrt: run_subprocess([sys.executable, "onnxruntime_test_python_mlops.py"], cwd=cwd, dll_path=dll_path) - # if args.use_tensorrt: - # run_subprocess( - # [sys.executable, "onnxruntime_test_python_nested_control_flow_op.py"], cwd=cwd, dll_path=dll_path - # ) + if args.use_tensorrt: + run_subprocess( + [sys.executable, "onnxruntime_test_python_nested_control_flow_op.py"], cwd=cwd, dll_path=dll_path + ) try: import onnx # noqa: F401 @@ -2146,6 +2138,7 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): dll_path=dll_path, python_path=python_path, ) + if not args.disable_contrib_ops: run_subprocess( [sys.executable, "-m", "unittest", "discover", "-s", "quantization"], cwd=cwd, dll_path=dll_path @@ -2167,7 +2160,7 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): ], cwd=SCRIPT_DIR, ) - run_subprocess([sys.executable, "-m", "pytest", "transformers"], cwd=cwd) + run_subprocess([sys.executable, "-m", "pytest", "--durations=0", "transformers"], cwd=cwd) # Restore initial numpy/protobuf version in case other tests use it run_subprocess([sys.executable, "-m", "pip", "install", "numpy==" + numpy_init_version]) run_subprocess([sys.executable, "-m", "pip", "install", "protobuf==" + pb_init_version]) @@ -2205,17 +2198,6 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): run_subprocess([sys.executable, "onnxruntime_test_python_keras.py"], cwd=cwd, dll_path=dll_path) -def tvm_run_python_tests(build_dir, configs): - for config in configs: - cwd = get_config_build_dir(build_dir, config) - if is_windows(): - cwd = os.path.join(cwd, config) - python_path = os.path.join(build_dir, config, "_deps", "tvm-src", "python") - run_subprocess( - [sys.executable, "onnxruntime_test_python_tvm.py"], cwd=cwd, python_path=os.path.abspath(python_path) - ) - - def run_nodejs_tests(nodejs_binding_dir): args = ["npm", "test", "--", "--timeout=90000"] if is_windows(): @@ -2235,7 +2217,6 @@ def build_python_wheel( use_dnnl, use_tensorrt, use_openvino, - use_tvm, use_vitisai, use_acl, use_armnn, @@ -2288,8 +2269,6 @@ def build_python_wheel( args.append("--use_openvino") elif use_dnnl: args.append("--use_dnnl") - elif use_tvm: - args.append("--use_tvm") elif use_vitisai: args.append("--use_vitisai") elif use_acl: @@ -2318,7 +2297,6 @@ def build_nuget_package( use_openvino, use_tensorrt, use_dnnl, - use_tvm, use_winml, use_qnn, enable_training_apis, @@ -2354,7 +2332,7 @@ def build_nuget_package( target_name = "/t:CreateWindowsAIPackage" elif use_openvino: execution_provider = "/p:ExecutionProvider=openvino" - package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.OpenVino" + package_name = "/p:OrtPackageId=Intel.ML.OnnxRuntime.OpenVino" elif use_tensorrt: execution_provider = "/p:ExecutionProvider=tensorrt" package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.TensorRT" @@ -2365,9 +2343,6 @@ def build_nuget_package( package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.Gpu" elif use_rocm: package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.ROCm" - elif use_tvm: - execution_provider = "/p:ExecutionProvider=tvm" - package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.Tvm" elif use_qnn: execution_provider = "/p:ExecutionProvider=qnn" package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.QNN" @@ -2609,7 +2584,7 @@ def main(): if args.use_tensorrt: args.use_cuda = True - if args.build_wheel or args.gen_doc or args.use_tvm or args.enable_training: + if args.build_wheel or args.gen_doc or args.enable_training: args.enable_pybind = True if ( @@ -2891,12 +2866,6 @@ def main(): run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs) - # TODO(agladyshev): - # to support Windows, we need to update .github/workflows/windows.yml - # and add to the PATH variable the following value: C:Program Files\LLVM\bin - if args.enable_pybind and args.use_tvm and not is_windows(): - tvm_run_python_tests(build_dir, configs) - # run node.js binding tests if args.build_nodejs and not args.skip_nodejs_tests: nodejs_binding_dir = os.path.normpath(os.path.join(source_dir, "js", "node")) @@ -2924,7 +2893,6 @@ def main(): args.use_dnnl, args.use_tensorrt, args.use_openvino, - args.use_tvm, args.use_vitisai, args.use_acl, args.use_armnn, @@ -2952,7 +2920,6 @@ def main(): args.use_openvino, args.use_tensorrt, args.use_dnnl, - args.use_tvm, args.use_winml, args.use_qnn, args.enable_training_apis, diff --git a/tools/ci_build/github/android/build_aar_package.py b/tools/ci_build/github/android/build_aar_package.py index 19f66245a45e2..1b34b3d302e57 100644 --- a/tools/ci_build/github/android/build_aar_package.py +++ b/tools/ci_build/github/android/build_aar_package.py @@ -23,11 +23,11 @@ # Onnx Runtime native library is built against NDK API 21 by default # It is possible to build from source for Android API levels below 21, but it is not guaranteed -DEFAULT_ANDROID_MIN_SDK_VER = 21 +DEFAULT_ANDROID_MIN_SDK_VER = 24 # Android API 24 is the default target API version for Android builds, based on Microsoft 1CS requirements # It is possible to build from source using API level 21 and higher as the target SDK version -DEFAULT_ANDROID_TARGET_SDK_VER = 24 +DEFAULT_ANDROID_TARGET_SDK_VER = 34 def _parse_build_settings(args): diff --git a/tools/ci_build/github/android/default_full_aar_build_settings.json b/tools/ci_build/github/android/default_full_aar_build_settings.json index b0eff75812673..1c7769c623d41 100644 --- a/tools/ci_build/github/android/default_full_aar_build_settings.json +++ b/tools/ci_build/github/android/default_full_aar_build_settings.json @@ -5,8 +5,8 @@ "x86", "x86_64" ], - "android_min_sdk_version": 21, - "android_target_sdk_version": 24, + "android_min_sdk_version": 24, + "android_target_sdk_version": 34, "build_params": [ "--enable_lto", "--android", diff --git a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md index 0c28b272f7fa3..4991b4329646f 100644 --- a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md +++ b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md @@ -4,30 +4,47 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution |Operator|Note| |--------|------| |ai.onnx:Add|| +|ai.onnx:Argmax|| |ai.onnx:AveragePool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.| +|ai.onnx:Cast|| |ai.onnx:Clip|| |ai.onnx:Concat|| |ai.onnx:Conv|Only 1D/2D Conv is supported.
Bias if provided must be constant.| |ai.onnx:ConvTranspose|Weight and bias must be constant.
padding_type of SAME_UPPER/SAME_LOWER is not supported.
kernel_shape must have default values.
output_shape is not supported.
output_padding must have default values.| |ai.onnx:DepthToSpace|If 'mode' is 'CRD' the input must have a fixed shape.| |ai.onnx:Div|| +|ai.onnx:Erf|| |ai.onnx:Gemm|Input B must be constant.| +|ai.onnx:Gelu|| |ai.onnx:GlobalAveragePool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.| |ai.onnx:GlobalMaxPool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.| |ai.onnx:GridSample|4D input.
'mode' of 'linear' or 'zeros'.
(mode==linear && padding_mode==reflection && align_corners==0) is not supported.| +|ai.onnx:GroupNormalization|| +|ai.onnx:InstanceNormalization|| +|ai.onnx:LayerNormalization|| |ai.onnx:LeakyRelu|| |ai.onnx:MatMul|Only support for transA == 0, alpha == 1.0 and beta == 1.0 is currently implemented.| |ai.onnx:MaxPool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.| +|ai.onnx:Max|| |ai.onnx:Mul|| |ai.onnx:Pow|Only supports cases when both inputs are fp32.| +|ai.onnx:PRelu|| |ai.onnx:Reciprocal|this ask for a `epislon` (default 1e-4) where onnx don't provide| +|ai.onnx:ReduceSum|| +|ai.onnx:ReduceMean|| +|ai.onnx:ReduceMax|| |ai.onnx:Relu|| |ai.onnx:Reshape|| |ai.onnx:Resize|See [resize_op_builder.cc](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc) implementation. There are too many permutations to describe the valid combinations.| +|ai.onnx:Round|| +|ai.onnx:Shape|| |ai.onnx:Slice|starts/ends/axes/steps must be constant initializers.| |ai.onnx:Split|If provided, `splits` must be constant.| |ai.onnx:Sub|| |ai.onnx:Sigmoid|| +|ai.onnx:Softmax|| |ai.onnx:Sqrt|| +|ai.onnx:Squeeze|| |ai.onnx:Tanh|| |ai.onnx:Transpose|| +|ai.onnx:Unsqueeze|| diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index 9362a8b0ee18c..c3dbee336b69d 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.27.0.240926 + default: 2.28.2.241116 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml index 3ee4375329069..aca06c320d1d3 100644 --- a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml @@ -40,9 +40,8 @@ parameters: default: 0 variables: - - template: templates/common-variables.yml - name: docker_base_image - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241020.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241120.3 - name: linux_trt_version value: 10.3.0.26-1.cuda11.8 - name: Repository @@ -116,15 +115,15 @@ stages: set -ex; \ env; \ ccache -s; \ - /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/build.py \ + /opt/python/cp310-cp310/bin/python3 /onnxruntime_src/tools/ci_build/build.py \ --build_dir /build --cmake_generator Ninja \ --config Release --update --build \ --skip_submodule_sync \ --build_shared_lib \ --parallel \ --build_wheel \ - --enable_onnx_tests --use_cuda --cuda_version=${{variables.common_cuda_version}} --cuda_home=/usr/local/cuda-${{variables.common_cuda_version}} --cudnn_home=/usr/local/cuda-${{variables.common_cuda_version}} \ - --enable_cuda_profiling --enable_cuda_nhwc_ops \ + --enable_onnx_tests --use_cuda --cuda_version=11.8 --cuda_home=/usr/local/cuda-11.8 --cudnn_home=/usr/local/cuda-11.8 \ + --enable_cuda_profiling \ --enable_pybind --build_java \ --use_cache \ --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=75;86' ; \ @@ -181,6 +180,17 @@ stages: TargetPath: '$(Build.BinariesDirectory)/Release' SpecificArtifact: ${{ parameters.specificArtifact }} BuildId: ${{ parameters.BuildId }} + - template: templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2204_gpu_opencv + Context: tools/ci_build/github/linux/docker/ + ScriptName: tools/ci_build/get_docker_image.py + DockerBuildArgs: " + --build-arg BUILD_UID=$( id -u ) + " + Repository: onnxruntimeubuntupackagestest_cuda11 + UseImageCacheContainerRegistry: false + UpdateDepsTxt: false - task: Cache@2 inputs: @@ -197,18 +207,15 @@ stages: -v $(Build.BinariesDirectory)/Release:/Release \ -v $(STABLE_DIFFUSION_MODEL_CACHE):/model_cache:rw \ -v $(GenerateImage_DIR):/images:rw \ - nvcr.io/nvidia/pytorch:22.11-py3 \ + onnxruntimeubuntupackagestest_cuda11 \ bash -c ' \ set -ex; \ - pip uninstall -y $(pip list --format=freeze | grep opencv); \ - rm -rf /usr/local/lib/python3.8/dist-packages/cv2/; \ - apt-get update; \ - DEBIAN_FRONTEND="noninteractive" apt-get install --yes python3-opencv; \ python3 --version; \ python3 -m pip install --upgrade pip; \ python3 -m pip install /Release/*.whl; \ pushd /workspace/onnxruntime/python/tools/transformers/models/stable_diffusion; \ python3 -m pip install -r requirements/cuda11/requirements.txt; \ + python3 -m pip install numpy==1.22.2; \ python3 -m pip install --upgrade polygraphy onnx-graphsurgeon ; \ echo Generate an image guided by a text prompt; \ python3 demo_txt2img.py --framework-model-dir /model_cache --seed 1 --deterministic "astronaut riding a horse on mars" ; \ @@ -239,7 +246,7 @@ stages: - script: | docker run --rm --gpus all -v $PWD:/workspace \ -v $(CLIP_MODEL_CACHE):/model_cache:rw \ - nvcr.io/nvidia/pytorch:22.11-py3 \ + onnxruntimeubuntupackagestest_cuda11 \ bash -c ' set -x; \ python3 --version; \ @@ -266,7 +273,7 @@ stages: - script: | docker run --rm --gpus all -v $PWD:/workspace \ -v $(CLIP_MODEL_CACHE):/model_cache:rw \ - nvcr.io/nvidia/pytorch:22.11-py3 \ + onnxruntimeubuntupackagestest_cuda11 \ bash -c ' set -ex; \ python3 --version; \ @@ -274,6 +281,7 @@ stages: pushd /workspace/onnxruntime/python/tools/transformers/models/stable_diffusion/; \ image2=$(find $(pwd) -name "astronaut_riding_a_h*.png") ; \ pushd test; \ + python3 -m pip install numpy==1.22.2; \ python3 -m pip install -r requirements.txt; \ echo check demo_txt2image.py generate image; \ python3 -u check_image.py --image1 astronaut_riding_txt2image-DDIM-50.png --image2 $image2 --cache_dir /model_cache ; \ @@ -439,7 +447,7 @@ stages: - template: templates/get-docker-image-steps.yml parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu_ffmpeg + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2204_gpu_ffmpeg Context: tools/ci_build/github/linux/docker/ ScriptName: tools/ci_build/get_docker_image.py DockerBuildArgs: '--build-arg BUILD_UID=$( id -u )' diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index b12360d2710d0..798868f3b957e 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -62,7 +62,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.27.0.240926 + default: 2.28.0.241029 resources: repositories: @@ -77,13 +77,14 @@ resources: ref: 5eda9aded5462201e6310105728d33016e637ea7 variables: +- template: templates/common-variables.yml - name: ReleaseVersionSuffix value: '' - name: win_trt_version value: 11.8 - name: win_trt_home - value: $(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8 + value: $(Agent.TempDirectory)\${{ variables.win_trt_folder_cuda11 }} - name: win_cuda_home value: $(Agent.TempDirectory)\v11.8 @@ -111,6 +112,7 @@ stages: BuildVariant: 'default' SpecificArtifact: ${{ parameters.SpecificArtifact }} BuildId: ${{ parameters.BuildId }} + QnnSDKVersion: ${{ parameters.QnnSdk }} - template: stages/java-cuda-packaging-stage.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml index 7118e85e9ea4b..bc33aba57ec93 100644 --- a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml @@ -59,13 +59,14 @@ parameters: - 12.2 variables: + - template: templates/common-variables.yml - name: ReleaseVersionSuffix value: '' - name: win_trt_home ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: $(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8 + value: $(Agent.TempDirectory)\${{ variables.win_trt_folder_cuda11 }} ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: $(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-12.6 + value: $(Agent.TempDirectory)\${{ variables.win_trt_folder_cuda12 }} - name: win_cuda_home ${{ if eq(parameters.CudaVersion, '11.8') }}: value: $(Agent.TempDirectory)\v11.8 @@ -97,7 +98,6 @@ stages: jobs: - template: templates/c-api-linux-cpu.yml parameters: - BaseImage: 'registry.access.redhat.com/ubi8/ubi' OnnxruntimeArch: 'x64' OnnxruntimeNodejsBindingArch: 'x64' PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' diff --git a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml index e938b8a61335f..2eb2839cdac02 100644 --- a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml @@ -54,7 +54,7 @@ stages: parameters: Dockerfile: tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile Context: tools/ci_build/github/linux/docker/inference/x86_64/default/cpu - DockerBuildArgs: "--build-arg BUILD_UID=$( id -u ) --build-arg BASEIMAGE=registry.access.redhat.com/ubi8/ubi" + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" Repository: onnxruntimecpubuildcentos8x64 - template: templates/linux-build-step-with-cache.yml @@ -149,7 +149,7 @@ stages: parameters: Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu Context: tools/ci_build/github/linux/docker/ - DockerBuildArgs: "--build-arg BUILD_UID=$( id -u ) --build-arg BASEIMAGE=registry.access.redhat.com/ubi8/ubi" + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" Repository: onnxruntimecpubuild - task: PythonScript@0 @@ -217,7 +217,7 @@ stages: /bin/bash -c " set -ex; \ ccache -s; \ - /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/build.py \ + /opt/python/cp310-cp310/bin/python3 /onnxruntime_src/tools/ci_build/build.py \ --build_dir /build --cmake_generator 'Ninja' \ --config Release \ --skip_submodule_sync \ @@ -299,6 +299,7 @@ stages: machine_pool: 'onnxruntime-linux-ARM64-CPU-2019' with_cache: true cmake_build_type: Release + python_exe_path: '/opt/python/cp310-cp310/bin/python3.10' - stage: arm64_test dependsOn: ['arm64_build'] @@ -306,4 +307,27 @@ stages: - template: templates/py-packaging-linux-test-cpu.yml parameters: arch: 'aarch64' + ep: 'cpu' + machine_pool: 'onnxruntime-linux-ARM64-CPU-2019' + +- stage: arm64_build_xnnpack + dependsOn: [] + jobs: + - template: templates/py-linux.yml + parameters: + arch: 'aarch64' + machine_pool: 'onnxruntime-linux-ARM64-CPU-2019' + with_cache: true + cmake_build_type: Release + ep: 'XNNPack' + extra_build_arg: '--use_xnnpack' + python_exe_path: '/opt/python/cp310-cp310/bin/python3.10' + +- stage: arm64_test_xnnpack + dependsOn: ['arm64_build_xnnpack'] + jobs: + - template: templates/py-packaging-linux-test-cpu.yml + parameters: + arch: 'aarch64' + ep: 'XNNPack' machine_pool: 'onnxruntime-linux-ARM64-CPU-2019' diff --git a/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml index 2d3260a13f13a..4964d33067092 100644 --- a/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml @@ -128,7 +128,7 @@ jobs: -e NIGHTLY_BUILD \ -e BUILD_BUILDNUMBER \ onnxruntimecpubuild \ - /opt/python/cp38-cp38/bin/python3.8 /onnxruntime_src/tools/ci_build/build.py \ + /opt/python/cp310-cp310/bin/python3.10 /onnxruntime_src/tools/ci_build/build.py \ --build_dir /build/2 --cmake_generator Ninja \ --config Debug \ --skip_submodule_sync \ @@ -210,7 +210,7 @@ jobs: -e NIGHTLY_BUILD \ -e BUILD_BUILDNUMBER \ onnxruntimecpubuild \ - /opt/python/cp38-cp38/bin/python3.8 /onnxruntime_src/tools/ci_build/build.py \ + /opt/python/cp310-cp310/bin/python3.10 /onnxruntime_src/tools/ci_build/build.py \ --build_dir /build/5 --cmake_generator Ninja \ --config Debug \ --skip_submodule_sync \ @@ -231,7 +231,7 @@ jobs: -e NIGHTLY_BUILD \ -e BUILD_BUILDNUMBER \ onnxruntimecpubuild \ - /opt/python/cp38-cp38/bin/python3.8 /onnxruntime_src/tools/ci_build/build.py \ + /opt/python/cp310-cp310/bin/python3.10 /onnxruntime_src/tools/ci_build/build.py \ --build_dir /build/6a \ --cmake_generator Ninja \ --config MinSizeRel \ @@ -258,7 +258,7 @@ jobs: -e NIGHTLY_BUILD \ -e BUILD_BUILDNUMBER \ onnxruntimecpubuild \ - /opt/python/cp38-cp38/bin/python3.8 /onnxruntime_src/tools/ci_build/build.py \ + /opt/python/cp310-cp310/bin/python3.10 /onnxruntime_src/tools/ci_build/build.py \ --build_dir /build/6b \ --cmake_generator Ninja \ --config MinSizeRel \ @@ -287,7 +287,7 @@ jobs: -e NIGHTLY_BUILD \ -e BUILD_BUILDNUMBER \ onnxruntimecpubuild \ - /opt/python/cp38-cp38/bin/python3.8 /onnxruntime_src/tools/ci_build/build.py \ + /opt/python/cp310-cp310/bin/python3.10 /onnxruntime_src/tools/ci_build/build.py \ --build_dir /build/6c \ --cmake_generator Ninja \ --config MinSizeRel \ @@ -317,7 +317,7 @@ jobs: -e ALLOW_RELEASED_ONNX_OPSET_ONLY=1 \ -e NIGHTLY_BUILD \ onnxruntimecpubuild \ - /opt/python/cp38-cp38/bin/python3.8 /onnxruntime_src/tools/ci_build/build.py \ + /opt/python/cp310-cp310/bin/python3.10 /onnxruntime_src/tools/ci_build/build.py \ --build_dir /build/7 \ --cmake_generator Ninja \ --config MinSizeRel \ diff --git a/tools/ci_build/github/azure-pipelines/linux-dnnl-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-dnnl-ci-pipeline.yml index 7311c6e526d57..0391ecf4f5869 100644 --- a/tools/ci_build/github/azure-pipelines/linux-dnnl-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-dnnl-ci-pipeline.yml @@ -49,6 +49,7 @@ jobs: Repository: onnxruntimecpubuild - task: CmdLine@2 + displayName: 'Build and test' inputs: script: | mkdir -p $HOME/.onnx @@ -61,7 +62,7 @@ jobs: -e NIGHTLY_BUILD \ -e BUILD_BUILDNUMBER \ onnxruntimecpubuild \ - /opt/python/cp38-cp38/bin/python3.8 /onnxruntime_src/tools/ci_build/build.py \ + /opt/python/cp310-cp310/bin/python3.10 /onnxruntime_src/tools/ci_build/build.py \ --build_dir /build \ --config Debug Release \ --skip_submodule_sync \ diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml index b0f40429c1a1e..7bb1deb60c6ba 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml @@ -49,9 +49,9 @@ parameters: variables: - name: docker_base_image ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241020.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241120.3 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241020.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241120.3 - name: Repository ${{ if eq(parameters.CudaVersion, '11.8') }}: diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml index 87d5c7bd824d2..9d60c9ea17cd8 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml @@ -37,16 +37,17 @@ parameters: - 12.2 variables: + - template: templates/common-variables.yml - name: docker_base_image ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241020.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241120.3 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241020.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241120.3 - name: linux_trt_version ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: 10.4.0.26-1.cuda11.8 + value: ${{ variables.linux_trt_version_cuda11 }} ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: 10.4.0.26-1.cuda12.6 + value: ${{ variables.linux_trt_version_cuda12 }} jobs: - job: Linux_Build diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml index fb2c86dbf68e3..83cf26614a285 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml @@ -8,12 +8,12 @@ parameters: - name: TrtVersion displayName: TensorRT Version type: string - default: 10.4.cuda_12_5_cudnn_9 + default: 10.5.cuda_12_5_cudnn_9 values: - 8.6.cuda_11_8_cudnn_8 - 8.6.cuda_12_3_cudnn_9 - - 10.4.cuda_11_8_cudnn_8 - - 10.4.cuda_12_5_cudnn_9 + - 10.5.cuda_11_8_cudnn_8 + - 10.5.cuda_12_5_cudnn_9 - BIN - name: UseTensorrtOssParser diff --git a/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml index d8c0120fc9ee5..c7b814f3dd52c 100644 --- a/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml @@ -33,5 +33,5 @@ jobs: parameters: AgentPool : 'Linux-CPU-2019' JobName: 'Linux_CI_Dev' - RunDockerBuildArgs: '-o ubuntu22.04 -p 3.10 -d openvino -v 2024.3.0 -x "--use_openvino CPU --build_wheel"' + RunDockerBuildArgs: '-o ubuntu22.04 -p 3.10 -d openvino -v 2024.5.0 -x "--use_openvino CPU --build_wheel"' TimeoutInMinutes: 120 diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index 41f6b6a8d6d80..d3826d90f9073 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.27.0.240926 + default: 2.28.2.241116 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/nuget-windows-ai.yml b/tools/ci_build/github/azure-pipelines/nuget-windows-ai.yml index 1fdc5098579b2..c6ab33164035c 100644 --- a/tools/ci_build/github/azure-pipelines/nuget-windows-ai.yml +++ b/tools/ci_build/github/azure-pipelines/nuget-windows-ai.yml @@ -17,8 +17,13 @@ extends: name: onnxruntime-Win-CPU-2022 os: windows sdl: + git: + submodules: false tsa: enabled: true + codeSignValidation: + enabled: true + break: true policheck: enabled: true exclusionsFile: '$(Build.SourcesDirectory)\tools\ci_build\policheck_exclusions.xml' diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml index b1e5816fb748e..f9ecfb7cf7938 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml @@ -23,6 +23,7 @@ stages: pool: ${{ parameters.AgentPool }} variables: + - template: ../../templates/common-variables.yml - name: OnnxRuntimeBuildDirectory value: '$(Build.BinariesDirectory)' @@ -52,7 +53,7 @@ stages: inputs: script: | ln -sf /data/models $(Build.BinariesDirectory) - + # As for Debian installation, replace '-1.' by '-1+' when assigning trt version below - ${{if contains(parameters.StageSuffix , 'GPU') }}: - template: ../../templates/get-docker-image-steps.yml parameters: @@ -61,7 +62,7 @@ stages: ${{ if eq(parameters.CudaVersion, '12.2') }}: DockerBuildArgs: " --build-arg BASEIMAGE=nvidia/cuda:12.2.2-devel-ubuntu20.04 - --build-arg TRT_VERSION=10.4.0.26-1+cuda12.6 + --build-arg TRT_VERSION=${{ replace(variables.linux_trt_version_cuda12, '-1.', '-1+') }} --build-arg BUILD_UID=$( id -u ) " ${{ else }}: diff --git a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml index 7f131590c900b..3eafd7350b25b 100644 --- a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml @@ -1,3 +1,20 @@ +parameters: +- name: CudaVersion + displayName: CUDA version + type: string + default: '12.2' + values: + - 11.8 + - 12.2 + +variables: + - template: templates/common-variables.yml + - name: win_trt_folder + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: ${{ variables.win_trt_folder_cuda11 }} + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: ${{ variables.win_trt_folder_cuda12 }} + stages: - ${{ if or(startsWith(variables['System.CollectionUri'], 'https://dev.azure.com/aiinfra/'),startsWith(variables['System.CollectionUri'], 'https://aiinfra.visualstudio.com/')) }}: - template: templates/web-ci.yml @@ -206,7 +223,7 @@ stages: BuildConfig: 'RelWithDebInfo' EnvSetupScript: setup_env_cuda.bat buildArch: x64 - additionalBuildFlags: --enable_pybind --build_java --build_nodejs --use_cuda --cuda_home="$(Agent.TempDirectory)\v11.8" --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + additionalBuildFlags: --enable_pybind --build_java --build_nodejs --use_cuda --cuda_home="$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}" --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 msbuildPlatform: x64 isX86: false job_name_suffix: x64_RelWithDebInfo @@ -226,7 +243,7 @@ stages: BuildConfig: 'RelWithDebInfo' EnvSetupScript: setup_env_trt.bat buildArch: x64 - additionalBuildFlags: --enable_pybind --build_java --build_nodejs --use_cuda --cuda_home="$(Agent.TempDirectory)\v11.8" --enable_cuda_profiling --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8" --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + additionalBuildFlags: --enable_pybind --build_java --build_nodejs --use_cuda --cuda_home="$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}" --enable_cuda_profiling --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\${{ variables.win_trt_folder }}" --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 msbuildPlatform: x64 isX86: false job_name_suffix: x64_RelWithDebInfo diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-alt-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-alt-package-test-pipeline.yml index 0a8fe2f50a29f..9296928ad97e0 100644 --- a/tools/ci_build/github/azure-pipelines/py-cuda-alt-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-cuda-alt-package-test-pipeline.yml @@ -18,7 +18,7 @@ stages: machine_pool: 'Onnxruntime-Linux-GPU' python_wheel_suffix: '_gpu' timeout: 480 - docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241020.1 - trt_version: '10.4.0.26-1.cuda11.8' + docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241120.3 + trt_version: '10.6.0.26-1.cuda11.8' cuda_version: '11.8' diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-alt-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-alt-packaging-pipeline.yml index 844991c475ff7..93a38b212d934 100644 --- a/tools/ci_build/github/azure-pipelines/py-cuda-alt-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-cuda-alt-packaging-pipeline.yml @@ -28,7 +28,15 @@ extends: # For productions pipelines, use "Official". template: v1/1ES.Official.PipelineTemplate.yml@1esPipelines parameters: - # Update the pool with your team's 1ES hosted pool. + sdl: + tsa: + enabled: true + codeSignValidation: + enabled: true + break: true + policheck: + enabled: true + exclusionsFile: '$(Build.SourcesDirectory)\tools\ci_build\policheck_exclusions.xml' pool: name: 'onnxruntime-Win-CPU-2022' # Name of your hosted pool os: windows # OS of the image. This value cannot be a variable. Allowed values: windows, linux, macOS diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml index 5094c56956978..307415b7be16f 100644 --- a/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml @@ -18,7 +18,6 @@ stages: machine_pool: 'Onnxruntime-Linux-GPU' python_wheel_suffix: '_gpu' timeout: 480 - docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241020.1 - trt_version: '10.4.0.26-1.cuda12.6' + docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241120.3 cuda_version: '12.2' diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml index 7e6b1889687a3..2e040698fad2a 100644 --- a/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml @@ -24,13 +24,22 @@ parameters: - RelWithDebInfo - MinSizeRel + extends: # The pipeline extends the 1ES PT which will inject different SDL and compliance tasks. # For non-production pipelines, use "Unofficial" as defined below. # For productions pipelines, use "Official". template: v1/1ES.Official.PipelineTemplate.yml@1esPipelines parameters: - # Update the pool with your team's 1ES hosted pool. + sdl: + tsa: + enabled: true + codeSignValidation: + enabled: true + break: true + policheck: + enabled: true + exclusionsFile: '$(Build.SourcesDirectory)\tools\ci_build\policheck_exclusions.xml' pool: name: 'onnxruntime-Win-CPU-2022' # Name of your hosted pool os: windows # OS of the image. This value cannot be a variable. Allowed values: windows, linux, macOS diff --git a/tools/ci_build/github/azure-pipelines/py-dml-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-dml-packaging-pipeline.yml index 280b54e4b9c2d..371d233897c8d 100644 --- a/tools/ci_build/github/azure-pipelines/py-dml-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-dml-packaging-pipeline.yml @@ -20,7 +20,16 @@ extends: # For productions pipelines, use "Official". template: v1/1ES.Official.PipelineTemplate.yml@1esPipelines parameters: - # Update the pool with your team's 1ES hosted pool. + sdl: + tsa: + enabled: true + codeSignValidation: + enabled: true + break: true + policheck: + enabled: true + exclusionsFile: '$(Build.SourcesDirectory)\tools\ci_build\policheck_exclusions.xml' + pool: name: 'onnxruntime-Win-CPU-2022' # Name of your hosted pool os: windows # OS of the image. This value cannot be a variable. Allowed values: windows, linux, macOS diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index bb9ada7d6cb4b..bd33282fd494e 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -59,7 +59,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.27.0.240926 + default: 2.28.2.241116 trigger: none diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index fd3f31da4ab7e..d54b8018c232a 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -2,7 +2,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.27.0.240926 + default: 2.28.2.241116 - name: build_config displayName: Build Configuration diff --git a/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml index a38486995478d..716383fd61dbb 100644 --- a/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml @@ -148,9 +148,9 @@ stages: value: false - name: docker_base_image ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241020.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241120.3 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241020.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241120.3 timeoutInMinutes: 60 steps: diff --git a/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml b/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml index a33f757c24408..47092393e0039 100644 --- a/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml +++ b/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml @@ -42,16 +42,17 @@ jobs: dependsOn: [ ] timeoutInMinutes: ${{ parameters.timeout }} variables: + - template: ../../templates/common-variables.yml - name: docker_base_image ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: nvidia/cuda:11.8.0-cudnn8-devel-ubi8 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241120.3 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241020.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241120.3 - name: linux_trt_version ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: 10.4.0.26-1.cuda11.8 + value: ${{ variables.linux_trt_version_cuda11 }} ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: 10.4.0.26-1.cuda12.6 + value: ${{ variables.linux_trt_version_cuda12 }} pool: ${{ parameters.machine_pool }} steps: - checkout: self diff --git a/tools/ci_build/github/azure-pipelines/stages/jobs/steps/py_packaging_test_step.yml b/tools/ci_build/github/azure-pipelines/stages/jobs/steps/py_packaging_test_step.yml new file mode 100644 index 0000000000000..9a721c65de332 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/stages/jobs/steps/py_packaging_test_step.yml @@ -0,0 +1,21 @@ +parameters: +- name: EP_NAME + type: string + default: CPU + +- name: PYTHON_VERSION + type: string + +steps: +- powershell: | + python -m pip uninstall -y onnxruntime onnxruntime-gpu -qq + Get-ChildItem -Path $(Build.ArtifactStagingDirectory)/*cp${{ replace(parameters.PYTHON_VERSION,'.','') }}*.whl | foreach {pip --disable-pip-version-check install --upgrade $_.fullname tabulate} + mkdir -p $(Agent.TempDirectory)\ort_test_data + Copy-Item -Path $(Build.sourcesDirectory)/onnxruntime/test/python/onnx_backend_test_series.py -Destination $(Agent.TempDirectory)\ort_test_data + Copy-Item -Recurse -Path $(Build.sourcesDirectory)/onnxruntime/test/testdata -Destination $(Agent.TempDirectory)\ort_test_data + cd $(Agent.TempDirectory)\ort_test_data + python onnx_backend_test_series.py --devices ${{ parameters.EP_NAME }} -v + cd $(Agent.TempDirectory) + Remove-Item -Path $(Agent.TempDirectory)\ort_test_data -Recurse -Force + workingDirectory: '$(Build.sourcesDirectory)' + displayName: 'Run Python Tests with ${{ parameters.EP_NAME }} EP' \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml index 9c7fbc24ab1b6..0b3eac0110abc 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml @@ -50,6 +50,8 @@ stages: win_trt_home: ${{ parameters.win_trt_home }} win_cuda_home: ${{ parameters.win_cuda_home }} buildJava: ${{ parameters.buildJava }} + SpecificArtifact: ${{ parameters.SpecificArtifact }} + BuildId: ${{ parameters.BuildId }} - template: nuget-cuda-packaging-stage.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml index fc6da88917f62..d331c76bc264e 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml @@ -68,6 +68,7 @@ stages: timeoutInMinutes: 180 pool: 'onnxruntime-Ubuntu2204-AMD-CPU' variables: + - template: ../templates/common-variables.yml - name: CUDA_VERSION_MAJOR ${{ if eq(parameters.CudaVersion, '11.8') }}: value: '11' @@ -75,12 +76,11 @@ stages: value: '12' - name: CUDA_VERSION value: ${{ parameters.CudaVersion }} - - name: linux_trt_version ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: 10.4.0.26-1.cuda11.8 + value: ${{ variables.linux_trt_version_cuda11 }} ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: 10.4.0.26-1.cuda12.6 + value: ${{ variables.linux_trt_version_cuda12 }} steps: - checkout: self clean: true @@ -140,6 +140,7 @@ stages: clean: all pool: 'Onnxruntime-Linux-GPU' variables: + - template: ../templates/common-variables.yml - name: CUDA_VERSION_MAJOR ${{ if eq(parameters.CudaVersion, '11.8') }}: value: '11' @@ -147,9 +148,9 @@ stages: value: '12' - name: linux_trt_version ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: 10.4.0.26-1.cuda11.8 + value: ${{ variables.linux_trt_version_cuda11 }} ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: 10.4.0.26-1.cuda12.6 + value: ${{ variables.linux_trt_version_cuda12 }} steps: - checkout: self # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime submodules: false diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml index 445066f08995a..d6b25c98936f0 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml @@ -34,7 +34,7 @@ parameters: displayName: Specific Artifact's BuildId type: string default: '0' - + - name: buildJava type: boolean @@ -50,13 +50,14 @@ stages: msbuildPlatform: x64 packageName: x64-cuda CudaVersion: ${{ parameters.CudaVersion }} - buildparameter: --use_cuda --cuda_home=${{ parameters.win_cuda_home }} --enable_onnx_tests --enable_wcos --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80" + buildparameter: --use_cuda --cuda_home=${{ parameters.win_cuda_home }} --enable_onnx_tests --enable_wcos --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80" --use_dml --build_csharp --parallel runTests: ${{ parameters.RunOnnxRuntimeTests }} buildJava: ${{ parameters.buildJava }} java_artifact_id: onnxruntime_gpu UseIncreasedTimeoutForTests: ${{ parameters.UseIncreasedTimeoutForTests }} SpecificArtifact: ${{ parameters.SpecificArtifact }} BuildId: ${{ parameters.BuildId }} + ComboTests: true # Windows CUDA with TensorRT Packaging - template: ../templates/win-ci.yml parameters: @@ -68,7 +69,7 @@ stages: msbuildPlatform: x64 CudaVersion: ${{ parameters.CudaVersion }} packageName: x64-tensorrt - buildparameter: --use_tensorrt --tensorrt_home=${{ parameters.win_trt_home }} --cuda_home=${{ parameters.win_cuda_home }} --enable_onnx_tests --enable_wcos --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80" + buildparameter: --use_tensorrt --tensorrt_home=${{ parameters.win_trt_home }} --cuda_home=${{ parameters.win_cuda_home }} --enable_onnx_tests --enable_wcos --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80" --parallel runTests: ${{ parameters.RunOnnxRuntimeTests }} buildJava: ${{ parameters.buildJava }} java_artifact_id: onnxruntime_gpu diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index e92761e20d9e3..72df94c9ea672 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -59,7 +59,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.27.0.240926 + default: 2.28.2.241116 stages: - ${{ if eq(parameters.enable_windows_cpu, true) }}: diff --git a/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml index 0160fdd6ddd95..f7235e3ad2076 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml @@ -56,10 +56,8 @@ stages: PYTHON_VERSION: ${{ python_version }} EP_NAME: gpu CudaVersion: ${{ parameters.cuda_version }} - ${{ if eq(parameters.cuda_version, '11.8') }}: - EP_BUILD_FLAGS: --enable_lto --use_tensorrt --tensorrt_home=$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8 --cuda_home=$(Agent.TempDirectory)\v11.8 --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" - ${{ if eq(parameters.cuda_version, '12.2') }}: - EP_BUILD_FLAGS: --enable_lto --use_tensorrt --tensorrt_home=$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-12.6 --cuda_home=$(Agent.TempDirectory)\v12.2 --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_BUILD_FLAGS: --use_dml --enable_lto --cuda_home=$(Agent.TempDirectory)\v${{ parameters.cuda_version }} --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + use_tensorrt: True - ${{ if eq(parameters.enable_linux_cuda, true) }}: - template: py-linux-gpu-stage.yml @@ -70,11 +68,9 @@ stages: cmake_build_type: ${{ parameters.cmake_build_type }} cuda_version: ${{ parameters.cuda_version }} ${{ if eq(parameters.cuda_version, '11.8') }}: - docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241020.1 - trt_version: 10.4.0.26-1.cuda11.8 + docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20241120.3 ${{ if eq(parameters.cuda_version, '12.2') }}: - docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241020.1 - trt_version: 10.4.0.26-1.cuda12.6 + docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20241120.3 - ${{ if eq(parameters.enable_windows_dml, true) }}: - ${{ each python_version in parameters.PythonVersions }}: @@ -83,6 +79,5 @@ stages: MACHINE_POOL: 'onnxruntime-Win2022-GPU-dml-A10' PYTHON_VERSION: ${{ python_version }} EP_BUILD_FLAGS: --use_dml --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0 --enable_wcos - ENV_SETUP_SCRIPT: setup_env.bat EP_NAME: directml cmake_build_type: ${{ parameters.cmake_build_type }} \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/stages/py-linux-gpu-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-linux-gpu-stage.yml index 83b863f18fbc4..3f26d2d5aeca3 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-linux-gpu-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-linux-gpu-stage.yml @@ -20,12 +20,6 @@ parameters: - name: docker_base_image type: string -- name: trt_version - type: string - default: '10.4.0.26-1.cuda11.8' - values: - - 10.4.0.26-1.cuda11.8 - - 10.4.0.26-1.cuda12.6 - name: cuda_version type: string default: '11.8' @@ -71,6 +65,12 @@ stages: value: -x ${{ parameters.extra_build_arg }} ${{ if eq(parameters.extra_build_arg, '') }}: value: '' + - template: ../templates/common-variables.yml + - name: trt_version + ${{ if eq(parameters.cuda_version, '11.8') }}: + value: ${{ variables.linux_trt_version_cuda11 }} + ${{ if eq(parameters.cuda_version, '12.2') }}: + value: ${{ variables.linux_trt_version_cuda12 }} steps: - checkout: self clean: true @@ -82,7 +82,7 @@ stages: parameters: Dockerfile: tools/ci_build/github/linux/docker/inference/${{ parameters.arch }}/python/cuda/Dockerfile Context: tools/ci_build/github/linux/docker/inference/${{ parameters.arch }}/python/cuda - DockerBuildArgs: "--build-arg BASEIMAGE=${{ parameters.docker_base_image }} --build-arg TRT_VERSION=${{ parameters.trt_version }} --build-arg BUILD_UID=$( id -u )" + DockerBuildArgs: "--build-arg BASEIMAGE=${{ parameters.docker_base_image }} --build-arg TRT_VERSION=${{ variables.trt_version }} --build-arg BUILD_UID=$( id -u )" Repository: onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}xtrt86build${{ parameters.arch }} diff --git a/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml index 88937cc2e154d..dd0539f751c89 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml @@ -12,10 +12,6 @@ parameters: - name: EP_BUILD_FLAGS type: string -- name: ENV_SETUP_SCRIPT - type: string - default: '' - - name: BUILD_PY_PARAMETERS displayName: > Extra parameters to pass to build.py. Don't put newlines in here. @@ -38,6 +34,10 @@ parameters: - RelWithDebInfo - MinSizeRel +- name: use_tensorrt + type: boolean + default: false + stages: - stage: Win_py_${{ parameters.EP_NAME }}_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }}_Build dependsOn: [] @@ -65,9 +65,23 @@ stages: targetPath: $(Build.ArtifactStagingDirectory) artifactName: win_${{ parameters.EP_NAME }}_wheel_${{ parameters.PYTHON_VERSION }} variables: - GRADLE_OPTS: '-Dorg.gradle.daemon=false' - VSGenerator: 'Visual Studio 17 2022' - CUDA_MODULE_LOADING: 'LAZY' + - template: ../templates/common-variables.yml + - name: GRADLE_OPTS + value: '-Dorg.gradle.daemon=false' + - name: VSGenerator + value: 'Visual Studio 17 2022' + - name: CUDA_MODULE_LOADING + value: 'LAZY' + - name: win_trt_folder + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: ${{ variables.win_trt_folder_cuda11 }} + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: ${{ variables.win_trt_folder_cuda12 }} + - name: trt_build_flag + ${{ if eq(parameters.use_tensorrt, true) }}: + value: '--use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\${{ variables.win_trt_folder }}"' + ${{ if eq(parameters.use_tensorrt, false) }}: + value: '' steps: - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 displayName: 'Clean Agent Directories' @@ -85,25 +99,20 @@ stages: addToPath: true architecture: 'x64' + - task: PipAuthenticate@1 + displayName: 'Pip Authenticate' + inputs: + artifactFeeds: 'Lotus' + + - template: ../templates/download-deps.yml - - ${{ if ne(parameters.ENV_SETUP_SCRIPT, '') }}: - - template: ../templates/jobs/set-winenv.yml - parameters: - EnvSetupScript: ${{ parameters.ENV_SETUP_SCRIPT }} - ${{ if or(contains(parameters.EP_BUILD_FLAGS, 'use_cuda'), contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt')) }}: - DownloadCUDA: true - ${{ if contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt') }}: - DownloadTRT: true - - - ${{ if eq(parameters.ENV_SETUP_SCRIPT, '') }}: - - template: ../templates/jobs/download_win_gpu_library.yml - parameters: - CudaVersion: ${{ parameters.CudaVersion }} - ${{ if or(contains(parameters.EP_BUILD_FLAGS, 'use_cuda'), contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt')) }}: - DownloadCUDA: true - ${{ if contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt') }}: - DownloadTRT: true + - template: ../templates/jobs/download_win_gpu_library.yml + parameters: + CudaVersion: ${{ parameters.CudaVersion }} + ${{ if or(contains(parameters.EP_BUILD_FLAGS, 'use_cuda'), eq(parameters.use_tensorrt, true)) }}: + DownloadCUDA: true + DownloadTRT: ${{ parameters.use_tensorrt }} - task: PythonScript@0 displayName: 'Update deps.txt' @@ -125,8 +134,8 @@ stages: --cmake_generator "$(VSGenerator)" --enable_pybind --enable_onnx_tests - --parallel --use_binskim_compliant_compile_flags --update --build - $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} ${{ parameters.EP_BUILD_FLAGS }} + --parallel 4 --use_binskim_compliant_compile_flags --update --build + $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} ${{ parameters.EP_BUILD_FLAGS }} ${{ variables.trt_build_flag }} workingDirectory: '$(Build.BinariesDirectory)' # Esrp signing @@ -189,20 +198,28 @@ stages: StepName: 'Download Pipeline Artifact - Windows GPU Build' TargetPath: '$(Build.ArtifactStagingDirectory)' + - template: ../templates/jobs/download_win_gpu_library.yml + parameters: + CudaVersion: ${{ parameters.CudaVersion }} + ${{ if or(contains(parameters.EP_BUILD_FLAGS, 'use_cuda'), eq(parameters.use_tensorrt, true)) }}: + DownloadCUDA: true + DownloadTRT: ${{ parameters.use_tensorrt }} + - task: PowerShell@2 - displayName: 'Install ONNX' + displayName: 'Install Third Party Dependencies' inputs: filePath: '$(Build.SourcesDirectory)/tools/ci_build/github/windows/install_third_party_deps.ps1' workingDirectory: '$(Build.BinariesDirectory)' arguments: -cpu_arch x64 -install_prefix $(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\installed -build_config ${{ parameters.cmake_build_type }} - - powershell: | - python -m pip uninstall -y onnxruntime onnxruntime-gpu -qq - Get-ChildItem -Path $(Build.ArtifactStagingDirectory)/*cp${{ replace(parameters.PYTHON_VERSION,'.','') }}*.whl | foreach {pip --disable-pip-version-check install --upgrade $_.fullname tabulate} - mkdir -p $(Agent.TempDirectory)\ort_test_data - Copy-Item -Path $(Build.sourcesDirectory)/onnxruntime/test/python/onnx_backend_test_series.py -Destination $(Agent.TempDirectory)\ort_test_data - Copy-Item -Recurse -Path $(Build.sourcesDirectory)/onnxruntime/test/testdata -Destination $(Agent.TempDirectory)\ort_test_data - cd $(Agent.TempDirectory)\ort_test_data - python onnx_backend_test_series.py - workingDirectory: '$(Build.sourcesDirectory)' - displayName: 'Run Python Tests' + - template: jobs/steps/py_packaging_test_step.yml + parameters: + EP_NAME: DML + PYTHON_VERSION: ${{ parameters.PYTHON_VERSION }} + + - template: jobs/steps/py_packaging_test_step.yml + parameters: + EP_NAME: CUDA + PYTHON_VERSION: ${{ parameters.PYTHON_VERSION }} + + diff --git a/tools/ci_build/github/azure-pipelines/templates/android-binary-size-check-stage.yml b/tools/ci_build/github/azure-pipelines/templates/android-binary-size-check-stage.yml index acce2a4098ed0..4d9606d82ced2 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-binary-size-check-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-binary-size-check-stage.yml @@ -91,7 +91,7 @@ stages: -e BUILD_REASON=$(Build.Reason) \ -e BUILD_BRANCH=$(Build.SourceBranch) \ onnxruntimecpubuild \ - /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/github/linux/ort_minimal/build_ort_and_check_binary_size.py \ + /opt/python/cp310-cp310/bin/python3 /onnxruntime_src/tools/ci_build/github/linux/ort_minimal/build_ort_and_check_binary_size.py \ --build_dir /build/1a \ ${BINARY_SIZE_THRESHOLD_ARGS} \ "/onnxruntime_src/${{ parameters.BuildConfigFile }}" @@ -147,7 +147,7 @@ stages: -e BUILD_REASON=$(Build.Reason) \ -e BUILD_BRANCH=$(Build.SourceBranch) \ onnxruntimecpubuild \ - /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/github/linux/ort_minimal/build_ort_and_check_binary_size.py \ + /opt/python/cp310-cp310/bin/python3 /onnxruntime_src/tools/ci_build/github/linux/ort_minimal/build_ort_and_check_binary_size.py \ --build_dir /build/1b \ --with_debug_info \ "/onnxruntime_src/${{ parameters.BuildConfigFile }}" diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml index e162365c40ce7..29caa7fa4955a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml @@ -19,6 +19,11 @@ parameters: type: string default: '' +- name: QnnSDKVersion + displayName: QNN SDK Version + type: string + default: '2.28.0.241029' + jobs: - job: Final_AAR_Testing_Android_${{ parameters.job_name_suffix }} workspace: @@ -50,36 +55,61 @@ jobs: - template: use-android-ndk.yml - - template: use-android-emulator.yml - parameters: - create: true - start: true - - script: | - set -e -x - mkdir android_test - cd android_test - cp -av $(Build.SourcesDirectory)/java/src/test/android ./ - cd ./android - mkdir -p app/libs - cp $(Build.BinariesDirectory)/final-android-aar/${{parameters.packageName}}-$(OnnxRuntimeVersion)${{parameters.ReleaseVersionSuffix}}.aar app/libs/onnxruntime-android.aar - $(Build.SourcesDirectory)/java/gradlew --no-daemon clean connectedDebugAndroidTest --stacktrace - displayName: Run E2E test using Emulator + set -e -x + mkdir -p android_test/android/app/libs + cd android_test/android + cp -av $(Build.SourcesDirectory)/java/src/test/android/* ./ + cp $(Build.BinariesDirectory)/final-android-aar/${{parameters.packageName}}-$(OnnxRuntimeVersion)${{parameters.ReleaseVersionSuffix}}.aar app/libs/${{parameters.packageName}}.aar + displayName: Copy Android test files and AAR to android_test directory workingDirectory: $(Build.BinariesDirectory) - - template: use-android-emulator.yml - parameters: - stop: true + # skip emulator tests for qnn package as there are no arm64-v8a emulators and no qnn libraries for x86 + - ${{ if not(contains(parameters.packageName, 'qnn')) }}: + - template: use-android-emulator.yml + parameters: + create: true + start: true + + - script: | + set -e -x + cd android_test/android + $(Build.SourcesDirectory)/java/gradlew --no-daemon clean connectedDebugAndroidTest --stacktrace + displayName: Run E2E test using Emulator + workingDirectory: $(Build.BinariesDirectory) + + - template: use-android-emulator.yml + parameters: + stop: true + + - ${{ else }}: + - script: | + # QNN SDK version string, expected format: 2.28.0.241029 + # Extract the first three parts of the version string to get the Maven package version (e.g., 2.28.0) + QnnMavenPackageVersion=$(echo ${{ parameters.QnnSDKVersion }} | cut -d'.' -f1-3) + echo "QnnMavenPackageVersion: $QnnMavenPackageVersion" + echo "##vso[task.setvariable variable=QnnMavenPackageVersion]$QnnMavenPackageVersion" + displayName: Trim QNN SDK version to major.minor.patch + + - script: | + set -e -x + # build apks for qnn package as they are not built in the emulator test step + $(Build.SourcesDirectory)/java/gradlew --no-daemon clean assembleDebug assembleAndroidTest -DqnnVersion=$(QnnMavenPackageVersion) --stacktrace + displayName: Build QNN APK + workingDirectory: $(Build.BinariesDirectory)/android_test/android # we run e2e tests on one older device (Pixel 3) and one newer device (Galaxy 23) - script: | set -e -x pip install requests + python $(Build.SourcesDirectory)/tools/python/upload_and_run_browserstack_tests.py \ --test_platform espresso \ - --app_apk_path "debug/app-debug.apk" \ - --test_apk_path "androidTest/debug/app-debug-androidTest.apk" \ - --devices "Samsung Galaxy S23-13.0" "Google Pixel 3-9.0" + --app_path "debug/app-debug.apk" \ + --test_path "androidTest/debug/app-debug-androidTest.apk" \ + --devices "Samsung Galaxy S23-13.0" "Google Pixel 3-9.0" \ + --build_tag "${{ parameters.packageName }}" + displayName: Run E2E tests using Browserstack workingDirectory: $(Build.BinariesDirectory)/android_test/android/app/build/outputs/apk timeoutInMinutes: 15 diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml index 51e47fde74bb2..c38736edd58f1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml @@ -48,6 +48,11 @@ parameters: type: string default: '' +- name: QnnSDKVersion + displayName: QNN SDK Version + type: string + default: '2.28.0.241029' + jobs: - job: Android_Java_API_AAR_Packaging_${{ parameters.job_name_suffix }} timeoutInMinutes: 120 @@ -85,6 +90,8 @@ jobs: - ${{ if contains(parameters.packageName, 'qnn') }}: - template: jobs/download_linux_qnn_sdk.yml + parameters: + QnnSDKVersion: '${{parameters.QnnSDKVersion}}' - task: CmdLine@2 displayName: Build Android AAR Packages diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index a98efa8f3fc92..b105e919c5b12 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -48,6 +48,11 @@ parameters: type: string default: '0' +- name: QnnSDKVersion + displayName: QNN SDK Version + type: string + default: 2.28.0.241029 + stages: - template: linux-cpu-packaging-pipeline.yml parameters: @@ -98,7 +103,14 @@ stages: enable_code_sign: ${{ parameters.DoEsrp }} packageName: 'onnxruntime-android-qnn' ReleaseVersionSuffix: $(ReleaseVersionSuffix) - #TODO: Add test job for QNN Android AAR + QnnSDKVersion: ${{ parameters.QnnSDKVersion }} + + - template: android-java-api-aar-test.yml + parameters: + artifactName: 'onnxruntime-android-qnn-aar' + job_name_suffix: 'QNN' + packageName: 'onnxruntime-android-qnn' + QnnSDKVersion: ${{ parameters.QnnSDKVersion }} - stage: iOS_Full_xcframework dependsOn: [] diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml index c5bd4b93db947..d3b3315ebb04c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml @@ -5,9 +5,6 @@ parameters: type: string default: '' -- name: BaseImage - type: string - - name: OnnxruntimeArch type: string @@ -50,7 +47,7 @@ jobs: parameters: Dockerfile: tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile Context: tools/ci_build/github/linux/docker/inference/x86_64/default/cpu - DockerBuildArgs: "--build-arg BUILD_UID=$( id -u ) --build-arg BASEIMAGE=${{parameters.BaseImage}}" + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" Repository: onnxruntimecpubuildcentos8${{parameters.OnnxruntimeArch}}_packaging - ${{ if eq(parameters.OnnxruntimeArch, 'aarch64') }}: @@ -58,7 +55,7 @@ jobs: parameters: Dockerfile: tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile Context: tools/ci_build/github/linux/docker/inference/aarch64/default/cpu - DockerBuildArgs: "--build-arg BUILD_UID=$( id -u ) --build-arg BASEIMAGE=${{parameters.BaseImage}}" + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" Repository: onnxruntimecpubuildcentos8${{parameters.OnnxruntimeArch}}_packaging UpdateDepsTxt: false diff --git a/tools/ci_build/github/azure-pipelines/templates/common-variables.yml b/tools/ci_build/github/azure-pipelines/templates/common-variables.yml index e7f703fa592a3..d35bed69ee409 100644 --- a/tools/ci_build/github/azure-pipelines/templates/common-variables.yml +++ b/tools/ci_build/github/azure-pipelines/templates/common-variables.yml @@ -1,3 +1,7 @@ variables: - common_cuda_version: '11.8' - common_cuda_baseimg: 'nvidia/cuda:11.8.0-cudnn8-devel-ubi8' + common_trt_version: '10.6.0.26' + # As for Debian installation, replace '-1.' by '-1+' when assigning trt version below + linux_trt_version_cuda11: ${{ variables.common_trt_version }}-1.cuda11.8 + linux_trt_version_cuda12: ${{ variables.common_trt_version }}-1.cuda12.6 + win_trt_folder_cuda11: TensorRT-${{ variables.common_trt_version }}.Windows10.x86_64.cuda-11.8 + win_trt_folder_cuda12: TensorRT-${{ variables.common_trt_version }}.Windows10.x86_64.cuda-12.6 \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index b3a47039005a9..949479fb8b5e4 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.193 + version: 1.0.201 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.193 + version: 1.0.201 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. diff --git a/tools/ci_build/github/azure-pipelines/templates/install-appcenter.yml b/tools/ci_build/github/azure-pipelines/templates/install-appcenter.yml deleted file mode 100644 index 51be73d4c658a..0000000000000 --- a/tools/ci_build/github/azure-pipelines/templates/install-appcenter.yml +++ /dev/null @@ -1,12 +0,0 @@ -# Install appcenter CLI - -parameters: -- name: appcenterVersion - type: string - default: "2.13.7" - -steps: -- bash: | - set -e -x - npm install -g appcenter-cli@${{ parameters.appcenterVersion }} - displayName: Install appcenter CLI ${{ parameters.appcenterVersion }} diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml index f749f32456b25..179a846509cc1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.27.0.240926' + default: '2.28.2.241116' steps: - script: | @@ -16,6 +16,29 @@ steps: echo $(QnnSDKRootDir) displayName: 'Print QnnSDKRootDir after downloading QNN SDK' + - script: | + set -x + sdk_file="$(QnnSDKRootDir)/sdk.yaml" + # Parse the sdk.yaml file to get the QNN SDK version downloaded + downloaded_qnn_sdk_version=$(grep '^version:' "$sdk_file" | head -n 1 | cut -d':' -f2 | xargs | cut -d'.' -f1-3 | tr -d '\r') + + # Extract major.minor.patch part from QnnSDKVersion passed as parameter + expected_qnn_sdk_version=$(echo ${{ parameters.QnnSDKVersion }} | cut -d'.' -f1-3) + + if [[ -z "$downloaded_qnn_sdk_version" ]]; then + echo "QNN version not found in sdk.yaml." + exit 1 + fi + + # Compare provided version with version from sdk.yaml + if [[ "$downloaded_qnn_sdk_version" == "$expected_qnn_sdk_version" ]]; then + echo "Success: QnnSDKVersion matches sdk.yaml version ($downloaded_qnn_sdk_version)." + else + echo "Error: QnnSDKVersion ($expected_qnn_sdk_version) does not match sdk.yaml version ($downloaded_qnn_sdk_version) in the QNN SDK directory" + exit 1 + fi + displayName: "Sanity Check: QnnSDKVersion vs sdk.yaml version" + - script: | azcopy cp --recursive 'https://lotusscus.blob.core.windows.net/models/qnnsdk/Qualcomm AI Hub Proprietary License.pdf' $(QnnSDKRootDir) displayName: 'Download Qualcomm AI Hub license' diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml index e196ecb312f96..ae54b3849a862 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml @@ -13,10 +13,10 @@ parameters: - 12.2 - name: TrtVersion type: string - default: '10.4.0.26' + default: '10.6.0.26' values: - 8.6.1.6 - - 10.4.0.26 + - 10.6.0.26 steps: - ${{ if eq(parameters.DownloadCUDA, true) }}: @@ -42,7 +42,7 @@ steps: - powershell: | Write-Host "##vso[task.setvariable variable=trtCudaVersion;]12.0" displayName: Set trtCudaVersion - - ${{ if and(eq(parameters.CudaVersion, '12.2'), eq(parameters.TrtVersion, '10.4.0.26')) }}: + - ${{ if and(eq(parameters.CudaVersion, '12.2'), eq(parameters.TrtVersion, '10.6.0.26')) }}: - powershell: | Write-Host "##vso[task.setvariable variable=trtCudaVersion;]12.6" displayName: Set trtCudaVersion diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml index c56d81aefbec1..9df8b249f681e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.27.0.240926' + default: '2.28.2.241116' steps: - powershell: | diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml index 6a2b7f4566b61..dfaf237a711fe 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml @@ -13,6 +13,12 @@ parameters: - name: SecondaryCUDAVersion type: string default: '11.8' + - name: win_trt_folder_cuda11 + type: string + default: 'TensorRT-10.6.0.26.Windows10.x86_64.cuda-11.8' + - name: win_trt_folder_cuda12 + type: string + default: 'TensorRT-10.6.0.26.Windows10.x86_64.cuda-12.6' steps: - ${{ if eq(parameters.DownloadCUDA, 'true') }}: @@ -24,11 +30,11 @@ steps: displayName: 'Download Secondary CUDA SDK v${{ parameters.SecondaryCUDAVersion }}' - ${{ if eq(parameters.DownloadTRT, 'true') }}: - powershell: | - azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8" $(Agent.TempDirectory) - displayName: 'Download TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8' + azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/${{ parameters.win_trt_folder_cuda11 }}" $(Agent.TempDirectory) + displayName: 'Download ${{ parameters.win_trt_folder_cuda11 }}' - powershell: | - azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/TensorRT-10.4.0.26.Windows10.x86_64.cuda-12.6" $(Agent.TempDirectory) - displayName: 'Download TensorRT-10.4.0.26.Windows10.x86_64.cuda-12.6' + azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/${{ parameters.win_trt_folder_cuda12 }}" $(Agent.TempDirectory) + displayName: 'Download ${{ variables.win_trt_folder_cuda12 }}' - task: BatchScript@1 displayName: 'setup env' diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml index b39d7edb8fb22..7bdd069de711b 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml @@ -109,7 +109,7 @@ jobs: - ${{ if eq(parameters.buildArch, 'x64') }}: - task: JavaToolInstaller@0 inputs: - versionSpec: '11' + versionSpec: '17' jdkArchitectureOption: ${{ parameters.buildArch }} jdkSourceOption: 'PreInstalled' @@ -218,16 +218,32 @@ jobs: - powershell: | python3 -m pip uninstall -y onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml -qq Get-ChildItem -Path dist/*.whl | foreach {pip --disable-pip-version-check install --upgrade $_.fullname} - workingDirectory: '$(Build.BinariesDirectory)\${{ parameters.BuildConfig }}\${{ parameters.BuildConfig }}' displayName: 'Install onnxruntime wheel' - ${{ if eq(parameters.RunOnnxRuntimeTests, true) }}: - - powershell: | - python $(Build.SourcesDirectory)\tools\ci_build\build.py --config ${{ parameters.BuildConfig }} --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_onnx_tests ${{ parameters.additionalBuildFlags }} - - workingDirectory: '$(Build.BinariesDirectory)\${{ parameters.BuildConfig }}\${{ parameters.BuildConfig }}' - displayName: 'Run tests' + - ${{ if and(contains(parameters.additionalBuildFlags, 'use_cuda'), contains(parameters.additionalBuildFlags, 'use_dml')) }}: + - powershell: | + python $(Build.SourcesDirectory)\tools\ci_build\build.py --config ${{ parameters.BuildConfig }} --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --cmake_generator "Visual Studio 17 2022" --enable_onnx_tests ${{ parameters.additionalBuildFlags }} + workingDirectory: '$(Build.BinariesDirectory)\${{ parameters.BuildConfig }}\${{ parameters.BuildConfig }}' + displayName: 'Run tests excluding CUDA tests' + env: + NO_CUDA_TEST: '1' + GTEST_FILTER: '-CudaEp*:CudaNhwcTypedTest*:*cpu_*models*' # Exclude CUDA EP tests under providers/cuda/ and cpu models test + PATH: '$(Build.BinariesDirectory)\${{ parameters.BuildConfig }}\${{ parameters.BuildConfig }};$(PATH)' # For onnxruntime4j_test to find dependent dlls + - powershell: | + python $(Build.SourcesDirectory)\tools\ci_build\build.py --config ${{ parameters.BuildConfig }} --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --cmake_generator "Visual Studio 17 2022" --enable_onnx_tests ${{ parameters.additionalBuildFlags }} + workingDirectory: '$(Build.BinariesDirectory)\${{ parameters.BuildConfig }}\${{ parameters.BuildConfig }}' + displayName: 'Run tests excluding DML tests' + env: + NO_DML_TEST: '1' + GTEST_FILTER: '-*cpu_*models*' + PATH: '$(Build.BinariesDirectory)\${{ parameters.BuildConfig }}\${{ parameters.BuildConfig }};$(PATH)' + - ${{ else }}: + - powershell: | + python $(Build.SourcesDirectory)\tools\ci_build\build.py --config ${{ parameters.BuildConfig }} --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --cmake_generator "Visual Studio 17 2022" --enable_onnx_tests ${{ parameters.additionalBuildFlags }} + workingDirectory: '$(Build.BinariesDirectory)\${{ parameters.BuildConfig }}\${{ parameters.BuildConfig }}' + displayName: 'Run tests' - ${{ if eq(parameters.GenerateDocumentation, true) }}: - task: PythonScript@0 diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml index 8972d55f6e190..7ac2e3a8addb6 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml @@ -29,7 +29,6 @@ stages: - template: c-api-linux-cpu.yml parameters: AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} - BaseImage: 'registry.access.redhat.com/ubi8/ubi' OnnxruntimeArch: 'x64' OnnxruntimeNodejsBindingArch: 'x64' PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' @@ -40,7 +39,6 @@ stages: - template: c-api-linux-cpu.yml parameters: AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} - BaseImage: 'arm64v8/almalinux:8' OnnxruntimeArch: 'aarch64' OnnxruntimeNodejsBindingArch: 'arm64' PoolName: 'onnxruntime-linux-ARM64-CPU-2019' diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml index 080079388a76c..ab31e592d7d71 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml @@ -68,9 +68,6 @@ stages: jobs: - job: MacOS_C_API_Package_Publish pool: - ${{ if eq(parameters.DoESRP, true)}}: - vmImage: 'macOS-12' - ${{ else }}: vmImage: 'macOS-13' steps: - checkout: none diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml index 8a278b57e4aee..7a1addffee0e3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml @@ -45,16 +45,17 @@ jobs: submodules: none - task: UsePythonVersion@0 - displayName: Use Python 3.11 + displayName: Use Python 3.10 inputs: - versionSpec: 3.11 + versionSpec: 3.10 + - task: NodeTool@0 inputs: versionSpec: '20.x' - task: JavaToolInstaller@0 inputs: - versionSpec: "11" + versionSpec: "17" jdkArchitectureOption: "x64" jdkSourceOption: 'PreInstalled' @@ -83,7 +84,7 @@ jobs: - template: mac-cpu-packaging-steps.yml parameters: MacosArch: ${{ parameters.MacosArch }} - AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --use_coreml --cmake_extra_defines CMAKE_OSX_ARCHITECTURES="arm64;x86_64" + AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --use_coreml --use_webgpu --cmake_extra_defines CMAKE_OSX_ARCHITECTURES="arm64;x86_64" BuildJava: false BuildNodejs: false WithCache: ${{ parameters.WithCache }} @@ -95,7 +96,7 @@ jobs: - template: mac-cpu-packaging-steps.yml parameters: MacosArch: ${{ parameters.MacosArch }} - AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --build_nodejs --build_java --use_coreml --cmake_extra_defines CMAKE_OSX_ARCHITECTURES=arm64 + AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --build_nodejs --build_java --use_coreml --use_webgpu --cmake_extra_defines CMAKE_OSX_ARCHITECTURES=arm64 BuildJava: true BuildNodejs: true WithCache: ${{ parameters.WithCache }} @@ -107,7 +108,7 @@ jobs: - template: mac-cpu-packaging-steps.yml parameters: MacosArch: ${{ parameters.MacosArch }} - AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --build_nodejs --build_java --use_coreml + AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --build_nodejs --build_java --use_coreml --use_webgpu BuildJava: true BuildNodejs: true WithCache: ${{ parameters.WithCache }} diff --git a/tools/ci_build/github/azure-pipelines/templates/orttraining-linux-gpu-test-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/orttraining-linux-gpu-test-ci-pipeline.yml deleted file mode 100644 index 5f073433265fa..0000000000000 --- a/tools/ci_build/github/azure-pipelines/templates/orttraining-linux-gpu-test-ci-pipeline.yml +++ /dev/null @@ -1,41 +0,0 @@ -parameters: -- name: DockerImageTag - type: string -- name: BuildConfig - type: string - -steps: - -- template: jobs/download_training_test_data.yml - - # Entry point for all ORTModule tests - # The onnxruntime folder is deleted in the build directory - # to enforce use of the onnxruntime wheel - # Uninstall orttraining requirements.txt and install ortmodule requirements.txt before running tests. -- script: | - docker run \ - --gpus all \ - --shm-size=1024m \ - --rm \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory)/${{ parameters.BuildConfig }}:/build \ - --volume $(Agent.TempDirectory)/mnist:/mnist \ - ${{ parameters.DockerImageTag }} \ - bash -c "rm -rf /build/onnxruntime/ && python3 -m pip show torch && python3 -m pip install torch==2.3.1+cu118 --index-url https://download.pytorch.org/whl/cu118 && python3 -m pip install /build/dist/onnxruntime*.whl && python3 -m onnxruntime.training.ortmodule.torch_cpp_extensions.install && /build/launch_test.py --cmd_line_with_args 'python orttraining_ortmodule_tests.py --mnist /mnist --bert_data /bert_data/hf_data/glue_data/CoLA/original/raw' --cwd /build" \ - displayName: 'Run orttraining_ortmodule_tests.py' - condition: succeededOrFailed() - timeoutInMinutes: 60 - -# Entry point for all ort training api tests -- script: | - docker run \ - --gpus all \ - --shm-size=1024m \ - --rm \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory)/${{ parameters.BuildConfig }}:/build \ - ${{ parameters.DockerImageTag }} \ - bash -c "rm -rf /build/onnxruntime/ && python3 -m pip install /build/dist/onnxruntime*.whl && python3 -m pip install torch==2.3.1+cu118 --index-url https://download.pytorch.org/whl/cu118 && /build/launch_test.py --cmd_line_with_args 'python orttraining_test_ort_apis.py --cwd /build' --cwd /build" \ - displayName: 'Run ORT Training APIs Tests' - condition: succeededOrFailed() - timeoutInMinutes: 120 diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml index e663afb49dd99..b1cec2284df65 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml @@ -26,7 +26,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.27.0.240926 + default: 2.28.2.241116 jobs: - job: Linux_py_qnn_Wheels_x64 diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux.yml index dd9d2412f8f91..c7becac763e28 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux.yml @@ -26,8 +26,16 @@ parameters: type: string default: '' +- name: ep + type: string + default: 'cpu' + +- name: python_exe_path + type: string + default: '' + jobs: -- job: Linux_py_Wheels_${{ parameters.arch }} +- job: Linux_py_Wheels_${{ parameters.arch }}_${{parameters.ep}} timeoutInMinutes: 240 workspace: clean: all @@ -42,9 +50,15 @@ jobs: value: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] - name: extra_build_args ${{ if ne(parameters.extra_build_arg, '') }}: - value: -x ${{ parameters.extra_build_arg }} + value: '-x ${{ parameters.extra_build_arg }}' ${{ if eq(parameters.extra_build_arg, '') }}: value: '' + - name: python_exe_path + ${{ if ne(parameters.python_exe_path, '') }}: + value: '-p ${{ parameters.python_exe_path }}' + ${{ if eq(parameters.python_exe_path, '') }}: + value: '' + steps: - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 displayName: 'Clean Agent Directories' @@ -78,7 +92,7 @@ jobs: inputs: targetType: filePath filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh - arguments: -i onnxruntimecpubuildpython${{ parameters.arch }} -d "${{ parameters.device }}" -c ${{ parameters.cmake_build_type }} $(extra_build_args) + arguments: -i onnxruntimecpubuildpython${{ parameters.arch }} -d "${{ parameters.device }}" -c ${{ parameters.cmake_build_type }} $(extra_build_args) $(python_exe_path) ${{ if eq(parameters.with_cache, 'true') }}: env: ADDITIONAL_DOCKER_PARAMETER: "--volume $(ORT_CACHE_DIR):/cache -e CCACHE_DIR=/cache -e ORT_BUILD_WITH_CACHE=1" @@ -87,14 +101,14 @@ jobs: displayName: 'Publish Artifact: ONNXRuntime python wheel' inputs: PathtoPublish: '$(Build.BinariesDirectory)/dist' - ArtifactName: onnxruntime + ArtifactName: onnxruntime-${{ parameters.ep }} - task: PublishPipelineArtifact@0 displayName: 'Publish Test Binaries' inputs: - artifactName: 'drop-linux-cpu-${{ parameters.arch }}' + artifactName: 'drop-linux-cpu-${{ parameters.arch }}-${{ parameters.ep }}' targetPath: '$(Build.BinariesDirectory)/${{ parameters.cmake_build_type }}' - template: component-governance-component-detection-steps.yml parameters : - condition : 'succeeded' \ No newline at end of file + condition : 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml index 0c7c356393b54..bfa6b0d32cab5 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml @@ -28,6 +28,10 @@ parameters: type: number default: 120 +- name: ep + type: string + default: 'cpu' + jobs: - job: Linux_Test_CPU${{ parameters.extra_job_id }}_${{ parameters.arch }} timeoutInMinutes: ${{ parameters.timeout }} @@ -43,30 +47,30 @@ jobs: # The public ADO project - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: - download: current # pipeline resource identifier. - artifact: 'drop-linux-cpu-${{ parameters.arch }}' + artifact: 'drop-linux-cpu-${{ parameters.arch }}-${{parameters.ep}}' - download: current # pipeline resource identifier. - artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}' + artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}-${{ parameters.ep }}' - bash: | set -e -x - mv "$(Pipeline.Workspace)/drop-linux-cpu-${{ parameters.arch }}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} - mv "$(Pipeline.Workspace)/onnxruntime${{ parameters.python_wheel_suffix }}" "$(Build.BinariesDirectory)/whl" + mv "$(Pipeline.Workspace)/drop-linux-cpu-${{ parameters.arch }}-${{parameters.ep}}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} + mv "$(Pipeline.Workspace)/onnxruntime${{ parameters.python_wheel_suffix }}-${{parameters.ep}}" "$(Build.BinariesDirectory)/whl" cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; # The private ADO project - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: - download: build # pipeline resource identifier. - artifact: 'drop-linux-cpu-${{ parameters.arch }}' + artifact: 'drop-linux-cpu-${{ parameters.arch }}-${{parameters.ep}}' - download: build # pipeline resource identifier. - artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}' + artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}-${{ parameters.ep }}' - bash: | set -e -x ls $(Pipeline.Workspace)/build - mv "$(Pipeline.Workspace)/build/drop-linux-cpu-${{ parameters.arch }}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} - mv "$(Pipeline.Workspace)/build/onnxruntime${{ parameters.python_wheel_suffix }}" "$(Build.BinariesDirectory)/whl" + mv "$(Pipeline.Workspace)/build/drop-linux-cpu-${{ parameters.arch }}-${{parameters.ep}}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} + mv "$(Pipeline.Workspace)/build/onnxruntime${{ parameters.python_wheel_suffix }}-${{parameters.ep}}" "$(Build.BinariesDirectory)/whl" cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml index 6a74d0e7befd3..0473fc199a991 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml @@ -16,12 +16,6 @@ parameters: - name: docker_base_image type: string -- name: trt_version - type: string - default: '10.4.0.26-1.cuda11.8' - values: - - 10.4.0.26-1.cuda11.8 - - 10.4.0.26-1.cuda12.6 - name: cuda_version type: string default: '11.8' @@ -47,7 +41,14 @@ jobs: - job: Linux_Test_GPU${{ parameters.extra_job_id }}_${{ parameters.arch }} timeoutInMinutes: ${{ parameters.timeout }} variables: - skipComponentGovernanceDetection: true + - template: common-variables.yml + - name: skipComponentGovernanceDetection + value: true + - name: trt_version + ${{ if eq(parameters.cuda_version, '11.8') }}: + value: ${{ variables.linux_trt_version_cuda11 }} + ${{ if eq(parameters.cuda_version, '12.2') }}: + value: ${{ variables.linux_trt_version_cuda12 }} workspace: clean: all pool: ${{ parameters.machine_pool }} @@ -92,7 +93,7 @@ jobs: parameters: Dockerfile: tools/ci_build/github/linux/docker/inference/${{ parameters.arch }}/python/cuda/Dockerfile Context: tools/ci_build/github/linux/docker/inference/${{ parameters.arch }}/python/cuda - DockerBuildArgs: "--build-arg BASEIMAGE=${{ parameters.docker_base_image }} --build-arg TRT_VERSION=${{ parameters.trt_version }} --build-arg BUILD_UID=$( id -u )" + DockerBuildArgs: "--build-arg BASEIMAGE=${{ parameters.docker_base_image }} --build-arg TRT_VERSION=${{ variables.trt_version }} --build-arg BUILD_UID=$( id -u )" Repository: onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}xtrt86build${{ parameters.arch }} - task: Bash@3 diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml index 4310c7f7800fa..7f3a61997b2f8 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml @@ -379,9 +379,10 @@ stages: pool: 'onnxruntime-Win2022-GPU-A10' timeoutInMinutes: 300 variables: + - template: common-variables.yml CUDA_VERSION: '11.8' buildArch: x64 - EpBuildFlags: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8" --cuda_version=$(CUDA_VERSION) --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$(CUDA_VERSION)" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=37;50;52;60;61;70;75;80" + EpBuildFlags: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\${{ variables.win_trt_folder_cuda11 }}" --cuda_version=$(CUDA_VERSION) --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$(CUDA_VERSION)" --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=37;50;52;60;61;70;75;80" EnvSetupScript: setup_env_gpu.bat EP_NAME: gpu VSGenerator: 'Visual Studio 17 2022' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage-steps.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage-steps.yml deleted file mode 100644 index fc163d17e44a9..0000000000000 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage-steps.yml +++ /dev/null @@ -1,209 +0,0 @@ -parameters: - build_py_parameters: '' - torch_version: '' - opset_version: '' - cuda_version: '' - cmake_cuda_architectures: '' - docker_file: '' - upload_wheel: '' - debug_build: '' - python_version: '' - stage_name: '' - SpecificArtifact: false - BuildId: '0' - build_pool_name: '' - -stages: - - stage: Build_${{ parameters.stage_name }} - variables: - - name: isMain - value: ${{ or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-')) }} - - name: finalStorage - ${{ if eq(variables['isMain'], 'true') }}: - value: '--final_storage' - ${{ else }}: - value: '' - - name: buildConfig - ${{ if eq(parameters['debug_build'], 'true') }}: - value: 'Debug' - ${{ else }}: - value: 'Release' - - name: PythonVersion - value: ${{ parameters.python_version }} - - name: Repository - value: onnxruntimetraininggpubuild_cu${{ replace(parameters.cuda_version, '.', '') }}_py${{ replace(parameters.python_version, '.', '') }} - dependsOn: [] - - jobs: - - job: Build - pool: ${{ parameters.build_pool_name }} - timeoutInMinutes: 180 - steps: - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() - - - task: CmdLine@2 - displayName: 'check variables' - inputs: - script: | - echo "Branch is "${{ variables['Build.SourceBranch'] }} && \ - echo "isMain is "${{ variables['isMain'] }} && \ - echo "final_storage is "${{ variables['finalStorage'] }} - - - checkout: self - clean: true - submodules: recursive - - - template: set-python-manylinux-variables-step.yml - - - template: get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/${{ parameters.docker_file }} - Context: tools/ci_build/github/linux/docker - DockerBuildArgs: >- - --build-arg TORCH_VERSION=${{ parameters.torch_version }} - --build-arg OPSET_VERSION=${{ parameters.opset_version }} - --build-arg PYTHON_VERSION=${{ parameters.python_version }} - --build-arg INSTALL_DEPS_EXTRA_ARGS=-tu - --build-arg BUILD_UID=$(id -u) - Repository: $(Repository) - - - task: CmdLine@2 - displayName: 'build onnxruntime' - inputs: - script: | - set -e -x - mkdir -p $HOME/.onnx - docker run --rm \ - --volume /data/onnx:/data/onnx:ro \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume /data/models:/build/models:ro \ - --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ - -e NIGHTLY_BUILD \ - -e DEFAULT_TRAINING_PACKAGE_DEVICE \ - -e BUILD_BUILDNUMBER \ - -e ORT_DISABLE_PYTHON_PACKAGE_LOCAL_VERSION \ - $(Repository) \ - $(PythonManylinuxDir)/bin/python3 /onnxruntime_src/tools/ci_build/build.py \ - --build_dir /build \ - --config ${{ variables['buildConfig'] }} \ - --skip_submodule_sync \ - --parallel --use_binskim_compliant_compile_flags \ - --build_wheel \ - --enable_onnx_tests \ - ${{ parameters.build_py_parameters }} \ - --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=${{ parameters.cmake_cuda_architectures }}' onnxruntime_BUILD_UNIT_TESTS=OFF \ - --use_cuda --cuda_version=${{ parameters.cuda_version }} --cuda_home=/usr/local/cuda-${{ parameters.cuda_version }} --cudnn_home=/usr/local/cuda-${{ parameters.cuda_version }}; - workingDirectory: $(Build.SourcesDirectory) - - - task: CopyFiles@2 - displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.BinariesDirectory)' - Contents: "${{ variables['buildConfig'] }}/dist/*.whl" - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - task: PublishBuildArtifacts@1 - displayName: 'Publish Artifact: ONNXRuntime python wheel and documentation' - inputs: - ArtifactName: "onnxruntime_gpu_${{ variables['buildConfig'] }}_${{ parameters.python_version }}" - - - template: component-governance-component-detection-steps.yml - parameters: - condition: 'succeeded' - - - template: clean-agent-build-directory-step.yml - - - stage: Test_${{ parameters.stage_name }} - variables: - - name: isMain - value: ${{ or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-')) }} - - name: finalStorage - ${{ if eq(variables['isMain'], 'true') }}: - value: '--final_storage' - ${{ else }}: - value: '' - - name: buildConfig - ${{ if eq(parameters['debug_build'], 'true') }}: - value: 'Debug' - ${{ else }}: - value: 'Release' - - name: PythonVersion - value: ${{ parameters.python_version }} - - name: Repository - value: onnxruntimetraininggpubuild_cu${{ replace(parameters.cuda_version, '.', '') }}_py${{ replace(parameters.python_version, '.', '') }} - - name: UploadWheel - value: ${{ parameters.upload_wheel }} - dependsOn: Build_${{ parameters.stage_name }} - jobs: - - job: Test_GPU - pool: Onnxruntime-Linux-GPU - steps: - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() - - - checkout: self - clean: true - submodules: none - - - template: jobs/download_training_test_data.yml - - - template: set-python-manylinux-variables-step.yml - - - template: flex-downloadPipelineArtifact.yml - parameters: - ArtifactName: "onnxruntime_gpu_${{ variables['buildConfig'] }}_${{ parameters.python_version }}" - StepName: 'Download Pipeline Artifact - Linux Training Build' - TargetPath: '$(Build.ArtifactStagingDirectory)' - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - script: | - set -e -x - whlfilename=$(ls $(Build.ArtifactStagingDirectory)/Release/dist/*.whl | head -n 1) ; \ - echo $whlfilename ; du -sh $whlfilename ; \ - (( $(wc -c < "$whlfilename") - 400*1024*1024 < 0 )) || ( echo 'Wheel size bigger than 400M'; exit 1) - displayName: 'Check wheel size' - continueOnError: true - - - template: get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/${{ parameters.docker_file }} - Context: tools/ci_build/github/linux/docker - DockerBuildArgs: >- - --build-arg TORCH_VERSION=${{ parameters.torch_version }} - --build-arg OPSET_VERSION=${{ parameters.opset_version }} - --build-arg PYTHON_VERSION=${{ parameters.python_version }} - --build-arg INSTALL_DEPS_EXTRA_ARGS=-tu - --build-arg BUILD_UID=$(id -u) - Repository: $(Repository) - - - task: CmdLine@2 - displayName: 'test ortmodule' - inputs: - script: | - set -ex ; \ - whlfilename=$(ls $(Build.ArtifactStagingDirectory)/Release/dist/*.whl | head -n 1) ; \ - echo $whlfilename ; \ - basefilename=$(basename $whlfilename) ; \ - docker run --rm \ - --gpus all \ - -e NVIDIA_VISIBLE_DEVICES=all \ - --volume $(Build.ArtifactStagingDirectory):/build \ - --volume $(Agent.TempDirectory)/MNIST:/mnist \ - $(Repository) \ - bash -c " $(PythonManylinuxDir)/bin/python3 -m pip install /build/Release/dist/$basefilename && $(PythonManylinuxDir)/bin/python3 -m onnxruntime.training.ortmodule.torch_cpp_extensions.install " ; - workingDirectory: $(Build.SourcesDirectory) - - - task: CmdLine@2 - displayName: 'Upload wheel' - condition: and(succeeded(), and(eq(variables['UploadWheel'], 'yes'), ne(variables['ORT_DISABLE_PYTHON_PACKAGE_LOCAL_VERSION'], 'true'))) - inputs: - script: | - set -e -x - whlfilename=$(ls $(Build.ArtifactStagingDirectory)/Release/dist/*.whl | head -n 1) ; \ - python3 tools/ci_build/upload_python_package_to_azure_storage.py \ - --python_wheel_path $whlfilename ${{ variables['finalStorage'] }} diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index f47108a2a48cd..e07f0afa6109c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.27.0.240926 + default: 2.28.2.241116 - name: ENV_SETUP_SCRIPT type: string @@ -59,6 +59,11 @@ jobs: addToPath: true architecture: 'arm64' + - task: PipAuthenticate@1 + displayName: 'Pip Authenticate' + inputs: + artifactFeeds: 'Lotus' + - task: onebranch.pipeline.tsaoptions@1 displayName: 'OneBranch TSAOptions' inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml index 5839ee273c1fe..8cc647c2464f3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.27.0.240926 + default: 2.28.2.241116 - name: ENV_SETUP_SCRIPT type: string @@ -50,6 +50,11 @@ jobs: addToPath: true architecture: 'x64' + - task: PipAuthenticate@1 + displayName: 'Pip Authenticate' + inputs: + artifactFeeds: 'Lotus' + - task: onebranch.pipeline.tsaoptions@1 displayName: 'OneBranch TSAOptions' inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml index 9e01f4116b602..466fee92d0d5e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.27.0.240926 + default: 2.28.2.241116 - name: ENV_SETUP_SCRIPT type: string @@ -50,6 +50,11 @@ jobs: addToPath: true architecture: 'x64' + - task: PipAuthenticate@1 + displayName: 'Pip Authenticate' + inputs: + artifactFeeds: 'Lotus' + - task: onebranch.pipeline.tsaoptions@1 displayName: 'OneBranch TSAOptions' inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index 7ec84453321ef..aa0b6bf6d391e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -1,5 +1,5 @@ parameters: - QnnSdk: '2.27.0.240926' + QnnSdk: '2.28.2.241116' build_config: 'RelWithDebInfo' IsReleaseBuild: false DoEsrp: false diff --git a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml index d8ea1c35c89c4..29c5f6bb34d7a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml @@ -261,8 +261,6 @@ stages: publishJUnitResults: true testResultsFiles: '**/TEST-*.xml' testRunTitle: 'React Native Android Instrumented Test results' - javaHomeOption: 'path' - jdkDirectory: '$(JAVA_HOME_11_X64)' sonarQubeRunAnalysis: false spotBugsAnalysis: false displayName: Run React Native Android Instrumented Tests diff --git a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml index a3b6bc1025267..5d7ea5e7b2727 100644 --- a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml @@ -70,8 +70,6 @@ stages: parameters: xcodeVersion: $(xcodeVersion) - - template: ../install-appcenter.yml - - script: | pip install -r tools/ci_build/github/apple/ios_packaging/requirements.txt displayName: "Install Python requirements" @@ -100,6 +98,8 @@ stages: --prepare_test_project_only displayName: "Assemble test project for App Center" + # Xcode tasks require absolute paths because it searches for the paths and files relative to + # the root directory and not relative to the working directory - task: Xcode@5 inputs: actions: 'build-for-testing' @@ -107,8 +107,6 @@ stages: xcWorkspacePath: '$(Build.BinariesDirectory)/app_center_test/apple_package_test/apple_package_test.xcworkspace' sdk: 'iphoneos' scheme: 'ios_package_test' - xcodeVersion: 'specifyPath' - xcodeDeveloperDir: '/Applications/Xcode_${{ variables.xcodeVersion }}.app/Contents/Developer' signingOption: 'manual' signingIdentity: '$(APPLE_CERTIFICATE_SIGNING_IDENTITY)' provisioningProfileUuid: '$(APPLE_PROV_PROFILE_UUID)' @@ -117,16 +115,69 @@ stages: useXcpretty: false # xcpretty can hide useful error output so we will disable it displayName: 'Build App Center iPhone arm64 tests' + - script: | + zip -r --symlinks $(Build.ArtifactStagingDirectory)/package_tests.zip ios_package_testUITests-Runner.app + workingDirectory: '$(Build.BinariesDirectory)/app_center_test/apple_package_test/DerivedData/Build/Products/Debug-iphoneos' + displayName: "Create .zip file of the tests" + + - script: | + python $(Build.SourcesDirectory)/onnxruntime/test/platform/apple/generate_ipa_export_options_plist.py \ + --dest_file "exportOptions.plist" \ + --apple_team_id $(APPLE_TEAM_ID) \ + --provisioning_profile_uuid $(APPLE_PROV_PROFILE_UUID) + workingDirectory: '$(Build.BinariesDirectory)/app_center_test/apple_package_test/' + displayName: "Generate .plist file for the .ipa file" + + # Task only generates an .xcarchive file if the plist export options are included, but does + # not produce an IPA file. + # Source code: https://github.com/microsoft/azure-pipelines-tasks/blob/master/Tasks/XcodeV5/xcode.ts + - task: Xcode@5 + inputs: + actions: 'archive' + xcWorkspacePath: '$(Build.BinariesDirectory)/app_center_test/apple_package_test/apple_package_test.xcworkspace' + packageApp: true + archivePath: '$(Build.BinariesDirectory)/app_center_test/apple_package_test/' + exportOptions: 'plist' + exportOptionsPlist: '$(Build.BinariesDirectory)/app_center_test/apple_package_test/exportOptions.plist' + configuration: 'Debug' + sdk: 'iphoneos' + scheme: 'ios_package_test' + args: '-derivedDataPath $(Build.BinariesDirectory)/app_center_test/apple_package_test/DerivedData' + workingDirectory: '$(Build.BinariesDirectory)/app_center_test/apple_package_test/' + useXcpretty: false + displayName: 'Create archive for the .ipa file' + + # Use script step because exporting the .ipa file using the Xcode@5 task was too brittle (Xcode@5 is designed + # to handle both the .xcarchive step and the .ipa step in the same step -- ran into countless issues with signing + # and the .plist file) + - script: | + xcodebuild -exportArchive \ + -archivePath ios_package_test.xcarchive \ + -exportOptionsPlist exportOptions.plist \ + -exportPath $(Build.ArtifactStagingDirectory)/test_ipa + workingDirectory: '$(Build.BinariesDirectory)/app_center_test/apple_package_test/' + displayName: "Create .ipa file" + + # Publish the BrowserStack artifacts first so that if the next step fails, the artifacts will still be published + # so that users can attempt to locally debug + - publish: "$(Build.ArtifactStagingDirectory)" + artifact: "browserstack_test_artifacts_${{ lower(parameters.packageVariant) }}" + displayName: "Publish BrowserStack test artifacts" + - script: | set -e -x - appcenter test run xcuitest \ - --app "AI-Frameworks/ORT-Mobile-iOS" \ - --devices $(app_center_test_devices) \ - --test-series "master" \ - --locale "en_US" \ - --build-dir $(Build.BinariesDirectory)/app_center_test/apple_package_test/DerivedData/Build/Products/Debug-iphoneos \ - --token $(app_center_api_token) - displayName: "Run E2E tests on App Center" + pip install requests + python $(Build.SourcesDirectory)/tools/python/upload_and_run_browserstack_tests.py \ + --test_platform xcuitest \ + --app_path "$(Build.ArtifactStagingDirectory)/test_ipa/ios_package_test.ipa" \ + --test_path "$(Build.ArtifactStagingDirectory)/package_tests.zip" \ + --devices "iPhone 15-17" + displayName: Run E2E tests using Browserstack + workingDirectory: $(Build.BinariesDirectory)/app_center_test/apple_package_test + timeoutInMinutes: 15 + env: + BROWSERSTACK_ID: $(browserstack_username) + BROWSERSTACK_TOKEN: $(browserstack_access_key) - script: | set -e -x diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml index 6a0c34daa9bb9..e046997b4f49a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml @@ -25,7 +25,7 @@ parameters: - name: runTests type: boolean - default: true + default: false - name: buildJava type: boolean @@ -71,6 +71,10 @@ parameters: - 11.8 - 12.2 +- name: ComboTests + type: boolean + default: false + - name: SpecificArtifact displayName: Use Specific Artifact type: boolean @@ -118,20 +122,28 @@ stages: clean: true submodules: none + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.12' + addToPath: true + architecture: ${{ parameters.buildArch }} + - template: telemetry-steps.yml + # The private ADO project + - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: + - task: PipAuthenticate@1 + displayName: 'Pip Authenticate' + inputs: + artifactFeeds: 'Lotus' + - ${{ if eq(parameters['buildJava'], 'true') }}: - task: JavaToolInstaller@0 inputs: - versionSpec: "11" + versionSpec: "17" jdkArchitectureOption: ${{ parameters.buildArch }} jdkSourceOption: 'PreInstalled' - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.12' - addToPath: true - architecture: ${{ parameters.buildArch }} - task: NodeTool@0 condition: and(succeeded(), eq('${{ parameters.buildNodejs}}', true)) @@ -214,7 +226,7 @@ stages: condition: and(succeeded(), eq('${{ parameters.runTests}}', true)) inputs: scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --cmake_generator "$(VSGenerator)" --enable_onnx_tests $(TelemetryOption) ${{ parameters.buildparameter }}' + arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --test --skip_submodule_sync --build_shared_lib --cmake_generator "$(VSGenerator)" --enable_onnx_tests $(TelemetryOption) ${{ parameters.buildparameter }}' workingDirectory: '$(Build.BinariesDirectory)' - ${{ else }}: - powershell: | @@ -326,6 +338,10 @@ stages: displayName: 'Clean Agent Directories' condition: always() + - script: + echo ${{ parameters.SpecificArtifact }} + displayName: 'Print Specific Artifact' + - checkout: self clean: true submodules: none @@ -354,7 +370,7 @@ stages: - ${{ if eq(parameters['buildJava'], 'true') }}: - task: JavaToolInstaller@0 inputs: - versionSpec: "11" + versionSpec: "17" jdkArchitectureOption: ${{ parameters.buildArch }} jdkSourceOption: 'PreInstalled' @@ -364,6 +380,13 @@ stages: addToPath: true architecture: ${{ parameters.buildArch }} + # The private ADO project + - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: + - task: PipAuthenticate@1 + displayName: 'Pip Authenticate' + inputs: + artifactFeeds: 'Lotus' + - task: NodeTool@0 condition: and(succeeded(), eq('${{ parameters.buildNodejs}}', true)) inputs: @@ -384,13 +407,35 @@ stages: displayName: 'Append dotnet x86 Directory to PATH' condition: and(succeeded(), eq('${{ parameters.buildArch}}', 'x86')) - - task: PythonScript@0 - displayName: 'test' - condition: and(succeeded(), eq('${{ parameters.runTests}}', true)) - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --enable_onnx_tests $(TelemetryOption) ' - workingDirectory: '$(Build.BinariesDirectory)' + - ${{ if eq(parameters.ComboTests, 'true') }}: + - task: PythonScript@0 + displayName: 'test excludes CUDA' + condition: and(succeeded(), eq('${{ parameters.runTests}}', true)) + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --enable_onnx_tests $(TelemetryOption) ' + workingDirectory: '$(Build.BinariesDirectory)' + env: + NO_CUDA_TEST: '1' + GTEST_FILTER: '-CudaEp*:CudaNhwcTypedTest*' # Exclude CUDA EP tests under providers/cuda/ + - task: PythonScript@0 + displayName: 'test excludes DML' + condition: and(succeeded(), eq('${{ parameters.runTests}}', true)) + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --enable_onnx_tests $(TelemetryOption) ' + workingDirectory: '$(Build.BinariesDirectory)' + env: + NO_DML_TEST: '1' + - ${{ else }}: + - task: PythonScript@0 + displayName: 'test' + condition: and(succeeded(), eq('${{ parameters.runTests}}', true)) + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --enable_onnx_tests $(TelemetryOption) ' + workingDirectory: '$(Build.BinariesDirectory)' + # Previous stage only assembles the java binaries, testing will be done in this stage with GPU machine - ${{ if eq(parameters.buildJava, 'true') }}: - template: make_java_win_binaries.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml b/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml index 9d47ae65a37d0..fb3ebdc760a7b 100644 --- a/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml @@ -11,6 +11,7 @@ parameters: jobs: - job: Windows_Packaging_${{ parameters.BuildArch }}_${{ parameters.Runtime }} + timeoutInMinutes: 180 templateContext: outputs: - output: pipelineArtifact @@ -32,7 +33,7 @@ jobs: - task: PipAuthenticate@1 displayName: 'Pip Authenticate' inputs: - artifactFeeds: 'PublicPackages/ORT-Nightly' + artifactFeeds: 'Lotus' - template: telemetry-steps.yml @@ -87,10 +88,8 @@ jobs: # must call vsdevcmd first to add cmake to PATH - script: | - curl -O -L https://github.com/Kitware/CMake/releases/download/v3.28.3/cmake-3.28.3-windows-x86_64.zip - 7z x cmake-3.28.3-windows-x86_64.zip python --version - python "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --parallel --use_binskim_compliant_compile_flags --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos --windows_sdk_version "10.0.22621.0" $(BuildFlags) --cmake_extra_defines "CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" "CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" --cmake_path $(Build.BinariesDirectory)\cmake-3.28.3-windows-x86_64\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake-3.28.3-windows-x86_64\bin\ctest.exe + python "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --parallel --use_binskim_compliant_compile_flags --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos --windows_sdk_version "10.0.22621.0" $(BuildFlags) --cmake_extra_defines "CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" "CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" workingDirectory: '$(Build.BinariesDirectory)' displayName: 'Generate cmake config' diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-cuda-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-cuda-ci-pipeline.yml index 47ece37e66e09..67fd47c3150af 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-cuda-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-cuda-ci-pipeline.yml @@ -62,4 +62,28 @@ stages: RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} ORT_EP_NAME: CUDA WITH_CACHE: true - MachinePool: onnxruntime-Win2022-GPU-A10 \ No newline at end of file + MachinePool: onnxruntime-Win2022-GPU-A10 + +- stage: cuda_dml + dependsOn: [] + jobs: + - template: templates/jobs/win-ci-vs-2022-job.yml + parameters: + BuildConfig: 'RelWithDebInfo' + EnvSetupScript: setup_env_cuda.bat + buildArch: x64 + additionalBuildFlags: >- + --build_java --build_nodejs --use_cuda --cuda_home="$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}" + --enable_cuda_profiling --enable_transformers_tool_test + --use_dml + --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=ON + --cmake_extra_defines onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON + msbuildPlatform: x64 + isX86: false + job_name_suffix: x64_RelWithDebInfo + RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} + ORT_EP_NAME: CUDA + EnablePython: false + WITH_CACHE: true + MachinePool: onnxruntime-Win2022-GPU-A10 diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-dml-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-dml-ci-pipeline.yml index 94b0aa680d54d..911d99cd2adf3 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-dml-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-dml-ci-pipeline.yml @@ -43,11 +43,11 @@ stages: BuildConfig: 'RelWithDebInfo' EnvSetupScript: setup_env.bat buildArch: x64 - additionalBuildFlags: --enable_pybind --use_dml --enable_wcos --use_winml + additionalBuildFlags: --enable_pybind --use_dml --enable_wcos --use_winml msbuildPlatform: x64 isX86: false job_name_suffix: x64_RelWithDebInfo RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} ORT_EP_NAME: DML WITH_CACHE: false - MachinePool: onnxruntime-Win2022-GPU-dml-A10 \ No newline at end of file + MachinePool: onnxruntime-Win2022-GPU-dml-A10 diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml index 7c04d6aa2e739..f4ab9ee5b4a5c 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml @@ -37,11 +37,12 @@ parameters: - 12.2 variables: + - template: templates/common-variables.yml - name: win_trt_folder ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8 + value: ${{ variables.win_trt_folder_cuda11 }} ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: TensorRT-10.4.0.26.Windows10.x86_64.cuda-12.6 + value: ${{ variables.win_trt_folder_cuda12 }} jobs: - job: 'build' diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml index c4db7735aaf2f..06f374afca57a 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml @@ -41,10 +41,11 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_cuda.bat + EnvSetupScript: setup_env.bat buildArch: x64 - # add --enable_pybind and --build_java if necessary + # add --build_java if necessary additionalBuildFlags: >- + --enable_pybind --build_nodejs --use_webgpu --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=ON @@ -56,3 +57,52 @@ stages: EnablePython: false WITH_CACHE: true MachinePool: onnxruntime-Win2022-VS2022-webgpu-A10 + +- stage: webgpu_external_dawn + dependsOn: [] + jobs: + - job: build_x64_RelWithDebInfo + variables: + DEPS_CACHE_DIR: $(Agent.TempDirectory)/deps_ccache + ORT_CACHE_DIR: $(Agent.TempDirectory)/ort_ccache + TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + workspace: + clean: all + pool: onnxruntime-Win2022-VS2022-webgpu-A10 + timeoutInMinutes: 300 + steps: + - checkout: self + clean: true + submodules: none + + - template: templates/jobs/win-ci-prebuild-steps.yml + parameters: + EnvSetupScript: setup_env.bat + DownloadCUDA: false + DownloadTRT: false + BuildArch: x64 + BuildConfig: RelWithDebInfo + MachinePool: onnxruntime-Win2022-VS2022-webgpu-A10 + WithCache: true + Today: $(Today) + + - template: templates/jobs/win-ci-build-steps.yml + parameters: + WithCache: true + Today: $(TODAY) + CacheDir: $(ORT_CACHE_DIR) + AdditionalKey: " $(System.StageName) | RelWithDebInfo " + BuildPyArguments: '--config RelWithDebInfo --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --update --parallel --cmake_generator "Visual Studio 17 2022" --use_webgpu --use_external_dawn --skip_tests --target onnxruntime_webgpu_external_dawn_test' + MsbuildArguments: '-maxcpucount' + BuildArch: x64 + Platform: x64 + BuildConfig: RelWithDebInfo + + - script: | + onnxruntime_webgpu_external_dawn_test.exe + displayName: Run tests (onnxruntime_webgpu_external_dawn_test) + workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + - script: | + onnxruntime_webgpu_external_dawn_test.exe --no_proc_table + displayName: Run tests (onnxruntime_webgpu_external_dawn_test) + workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 8f971612dbc6d..5c013fae6be0b 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.27.0.240926 + default: 2.28.2.241116 jobs: - job: 'build' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index f55f476f70d30..53700c58c7e7d 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.27.0.240926 + default: 2.28.2.241116 jobs: - job: 'build' diff --git a/tools/ci_build/github/linux/build_cuda_ci.sh b/tools/ci_build/github/linux/build_cuda_ci.sh index 6b155da02030b..0533b7b394492 100755 --- a/tools/ci_build/github/linux/build_cuda_ci.sh +++ b/tools/ci_build/github/linux/build_cuda_ci.sh @@ -3,28 +3,31 @@ set -ex #Every cuda container has this $CUDA_VERSION env var set. SHORT_CUDA_VERSION=$(echo $CUDA_VERSION | sed 's/\([[:digit:]]\+\.[[:digit:]]\+\)\.[[:digit:]]\+/\1/') -BUILD_ARGS=('--config' 'Release' '--update' '--build' - '--skip_submodule_sync' - '--build_shared_lib' - '--parallel' '--use_binskim_compliant_compile_flags' - '--build_wheel' - '--enable_onnx_tests' - '--use_cuda' - "--cuda_version=$SHORT_CUDA_VERSION" - "--cuda_home=/usr/local/cuda-$SHORT_CUDA_VERSION" - "--cudnn_home=/usr/local/cuda-$SHORT_CUDA_VERSION" - "--enable_cuda_profiling" - "--enable_cuda_nhwc_ops" - "--enable_pybind" - "--build_java" - "--cmake_extra_defines" - "CMAKE_CUDA_ARCHITECTURES=75" - "onnxruntime_BUILD_UNIT_TESTS=ON" - "onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON") +BUILD_ARGS=('--config' + 'Release' + '--update' + '--build' + '--skip_submodule_sync' + '--build_shared_lib' + '--parallel' + '--use_binskim_compliant_compile_flags' + '--build_wheel' + '--enable_onnx_tests' + '--use_cuda' + "--cuda_version=$SHORT_CUDA_VERSION" + "--cuda_home=/usr/local/cuda-$SHORT_CUDA_VERSION" + "--cudnn_home=/usr/local/cuda-$SHORT_CUDA_VERSION" + "--enable_cuda_profiling" + "--enable_pybind" + "--build_java" + "--cmake_extra_defines" + "CMAKE_CUDA_ARCHITECTURES=75" + "onnxruntime_BUILD_UNIT_TESTS=ON" + "onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON") if [ -x "$(command -v ninja)" ]; then BUILD_ARGS+=('--cmake_generator' 'Ninja') fi - + if [ -d /build ]; then BUILD_ARGS+=('--build_dir' '/build') else @@ -33,14 +36,14 @@ fi if [ -x "$(command -v ccache)" ]; then ccache -s; - BUILD_ARGS+=("--use_cache") + #BUILD_ARGS+=("--use_cache") fi if [ -f /opt/python/cp312-cp312/bin/python3 ]; then /opt/python/cp312-cp312/bin/python3 tools/ci_build/build.py "${BUILD_ARGS[@]}" else python3 tools/ci_build/build.py "${BUILD_ARGS[@]}" fi -if [ -x "$(command -v ccache)" ]; then - ccache -sv +if [ -x "$(command -v ccache)" ]; then + ccache -sv ccache -z fi diff --git a/tools/ci_build/github/linux/build_linux_python_package.sh b/tools/ci_build/github/linux/build_linux_python_package.sh index 22095f3f898b6..e2e0cea69efb5 100755 --- a/tools/ci_build/github/linux/build_linux_python_package.sh +++ b/tools/ci_build/github/linux/build_linux_python_package.sh @@ -7,15 +7,29 @@ mkdir -p /build/dist EXTRA_ARG="" ENABLE_CACHE=false -# Put 3.10 at the last because Ubuntu 22.04 use python 3.10 and we will upload the intermediate build files of this +# Put 3.10 at the last because Ubuntu 22.04 use python 3.10 and we will upload the intermediate build files of this # config to Azure DevOps Artifacts and download them to a Ubuntu 22.04 machine to run the tests. -PYTHON_EXES=("/opt/python/cp311-cp311/bin/python3.11" "/opt/python/cp312-cp312/bin/python3.12" "/opt/python/cp313-cp313/bin/python3.13" "/opt/python/cp313-cp313t/bin/python3.13t" "/opt/python/cp310-cp310/bin/python3.10") +PYTHON_EXES=( + "/opt/python/cp311-cp311/bin/python3.11" + "/opt/python/cp312-cp312/bin/python3.12" + "/opt/python/cp313-cp313/bin/python3.13" + "/opt/python/cp313-cp313t/bin/python3.13t" + "/opt/python/cp310-cp310/bin/python3.10" + ) while getopts "d:p:x:c:e" parameter_Option do case "${parameter_Option}" in #GPU|CPU|NPU. d) BUILD_DEVICE=${OPTARG};; -p) PYTHON_EXES=${OPTARG};; +p) + # Check if OPTARG is empty or starts with a hyphen, indicating a missing or invalid argument for -p + if [[ -z "${OPTARG}" || "${OPTARG}" == -* ]]; then + echo "ERROR: Option -p requires a valid argument, not another option." + exit 1 + else + PYTHON_EXES=("${OPTARG}") # Use the provided argument for -p + fi + ;; x) EXTRA_ARG=${OPTARG};; c) BUILD_CONFIG=${OPTARG};; e) ENABLE_CACHE=true;; @@ -89,9 +103,11 @@ export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=ON -DONNX_WERROR=OFF" for PYTHON_EXE in "${PYTHON_EXES[@]}" do rm -rf /build/"$BUILD_CONFIG" - ${PYTHON_EXE} -m pip install -r /onnxruntime_src/tools/ci_build/github/linux/python/requirements.txt - ${PYTHON_EXE} /onnxruntime_src/tools/ci_build/build.py "${BUILD_ARGS[@]}" - + # that's a workaround for the issue that there's no python3 in the docker image + # like xnnpack's cmakefile, it uses pythone3 to run a external command + python3_dir=$(dirname "$PYTHON_EXE") + ${PYTHON_EXE} -m pip install -r /onnxruntime_src/tools/ci_build/github/linux/python/requirements.txt + PATH=$python3_dir:$PATH ${PYTHON_EXE} /onnxruntime_src/tools/ci_build/build.py "${BUILD_ARGS[@]}" cp /build/"$BUILD_CONFIG"/dist/*.whl /build/dist done diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu index 3ff213b16f3d1..d2d3aa1675c2e 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu @@ -1,6 +1,6 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc12:20241020.1 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc12:20241120.3 -ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-17 ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda index 07885ba65af8a..c42042b0ec639 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda @@ -32,7 +32,7 @@ else \ echo "TRT_VERSION is none skipping Tensor RT Installation" ; \ fi -ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-17 #Add our own dependencies ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts @@ -42,5 +42,5 @@ ARG BUILD_USER=onnxruntimedev RUN adduser --uid $BUILD_UID $BUILD_USER WORKDIR /home/$BUILD_USER USER $BUILD_USER -ENV PATH /usr/local/dotnet:$PATH -ENV CUDA_MODULE_LOADING "LAZY" \ No newline at end of file +ENV PATH=/usr/local/dotnet:$PATH +ENV CUDA_MODULE_LOADING="LAZY" \ No newline at end of file diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm index e4c3af05053ba..9a265b4249f0b 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm @@ -201,5 +201,5 @@ ARG BUILD_USER=onnxruntimedev RUN adduser --uid $BUILD_UID $BUILD_USER WORKDIR /home/$BUILD_USER USER $BUILD_USER -ENV PATH /usr/local/dotnet:$PATH +ENV PATH=/usr/local/dotnet:$PATH ENV ORTMODULE_ONNX_OPSET_VERSION=$OPSET_VERSION diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 index 8ef8e05b8ac77..9de88d1664b82 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 @@ -9,19 +9,19 @@ ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 ARG TRT_VERSION=8.6.1.6-1.cuda11.8 FROM $BASEIMAGE AS base ARG TRT_VERSION -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} RUN dnf install -y bash wget &&\ dnf clean dbcache # Install python3 RUN dnf install -y \ - python3.8 \ - python38-pip \ - python38-wheel &&\ + python3.10 \ + python310-pip \ + python310-wheel &&\ cd /usr/local/bin &&\ - ln -s /usr/bin/python3 python3.8 &&\ - ln -s /usr/bin/pip3 pip3.8; + ln -s /usr/bin/python3 python3.10 &&\ + ln -s /usr/bin/pip3 pip3.10; RUN pip3 install --upgrade pip RUN pip3 install setuptools>=68.2.2 diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 index c1a445e29fc89..c2bae5fd7ee59 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 @@ -6,10 +6,10 @@ # Build base image with required system packages ARG BASEIMAGE=nvidia/cuda:12.5.1-cudnn-devel-ubi8 -ARG TRT_VERSION=10.4.0.26-1.cuda12.6 +ARG TRT_VERSION=10.6.0.26-1.cuda12.6 FROM $BASEIMAGE AS base ARG TRT_VERSION -ENV PATH /opt/python/cp38-cp38/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} +ENV PATH=/opt/python/cp310-cp310/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} RUN dnf install -y bash wget &&\ dnf clean dbcache diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0_torch b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0_torch index a228ebed165eb..2ecc6d1918b1a 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0_torch +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0_torch @@ -6,10 +6,10 @@ # Build base image with required system packages ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 -ARG TRT_VERSION=10.4.0.26-1.cuda11.8 +ARG TRT_VERSION=10.6.0.26-1.cuda11.8 FROM $BASEIMAGE AS base ARG TRT_VERSION -ENV PATH /opt/python/cp38-cp38/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} +ENV PATH=/opt/python/cp310-cp310/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} RUN dnf install -y bash wget &&\ dnf clean dbcache diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu index 6a4244b7aad0d..81aeada6a4a46 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu @@ -6,11 +6,11 @@ # Build base image with required system packages ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 -ARG TRT_VERSION=10.4.0.26-1+cuda11.8 +ARG TRT_VERSION=10.6.0.26-1+cuda11.8 ARG LD_LIBRARY_PATH_ARG=/usr/local/lib64:/usr/local/cuda/lib64 FROM $BASEIMAGE AS base ARG TRT_VERSION -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH_ARG}:${LD_LIBRARY_PATH} diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu_ffmpeg b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2204_gpu_ffmpeg similarity index 93% rename from tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu_ffmpeg rename to tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2204_gpu_ffmpeg index 418c551ab38b4..4298dd53e4c66 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu_ffmpeg +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2204_gpu_ffmpeg @@ -5,12 +5,12 @@ # Dockerfile to run ONNXRuntime with TensorRT integration # Build base image with required system packages -ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 -ARG TRT_VERSION=10.4.0.26-1+cuda11.8 +ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04 +ARG TRT_VERSION=10.6.0.26-1+cuda11.8 ARG LD_LIBRARY_PATH_ARG=/usr/local/lib64:/usr/local/cuda/lib64 FROM $BASEIMAGE AS base ARG TRT_VERSION -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH_ARG}:${LD_LIBRARY_PATH} diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2204_gpu_opencv b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2204_gpu_opencv new file mode 100644 index 0000000000000..1312475ceca3a --- /dev/null +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2204_gpu_opencv @@ -0,0 +1,64 @@ +# -------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------- +# Dockerfile to run ONNXRuntime with TensorRT integration + +# Build base image with required system packages +ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04 +ARG TRT_VERSION=10.6.0.26-1+cuda11.8 +ARG LD_LIBRARY_PATH_ARG=/usr/local/lib64:/usr/local/cuda/lib64 +FROM $BASEIMAGE AS base +ARG TRT_VERSION +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} +ENV DEBIAN_FRONTEND=noninteractive + +ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH_ARG}:${LD_LIBRARY_PATH} + +RUN apt-get update &&\ + apt-get install -y git bash wget diffutils + +RUN DEBIAN_FRONTEND="noninteractive" apt-get install --yes python3-opencv + +# Install python3 +RUN apt-get install -y --no-install-recommends \ + python3 \ + python3-pip \ + python3-dev \ + python3-wheel + +RUN pip install --upgrade pip + +# Install TensorRT +RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub &&\ + apt-get update &&\ + apt-get install -y \ + libnvinfer-dev=${TRT_VERSION} \ + libnvinfer-dispatch-dev=${TRT_VERSION} \ + libnvinfer-dispatch10=${TRT_VERSION} \ + libnvinfer-headers-dev=${TRT_VERSION} \ + libnvinfer-headers-plugin-dev=${TRT_VERSION} \ + libnvinfer-lean-dev=${TRT_VERSION} \ + libnvinfer-lean10=${TRT_VERSION} \ + libnvinfer-plugin-dev=${TRT_VERSION} \ + libnvinfer-plugin10=${TRT_VERSION} \ + libnvinfer-vc-plugin-dev=${TRT_VERSION} \ + libnvinfer-vc-plugin10=${TRT_VERSION} \ + libnvinfer10=${TRT_VERSION} \ + libnvonnxparsers-dev=${TRT_VERSION} \ + libnvonnxparsers10=${TRT_VERSION} \ + tensorrt-dev=${TRT_VERSION} \ + libnvinfer-bin=${TRT_VERSION} &&\ + if [ $(echo $CUDA_VERSION | cut -d"." -f1) -ge 12 ]; then apt-get install -y cudnn9-cuda-12 ; fi +# ^^^^^^^^^^^If cuda version is 12 or higher, install cudnn 9 for cuda 12 + +ADD scripts /tmp/scripts +RUN cd /tmp/scripts && /tmp/scripts/install_dotnet.sh && rm -rf /tmp/scripts + +# Build final image from base. +FROM base as final +ARG BUILD_USER=onnxruntimedev +ARG BUILD_UID=1000 +RUN adduser --uid $BUILD_UID $BUILD_USER +WORKDIR /home/$BUILD_USER +USER $BUILD_USER diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 index dfc057b129f91..3b4d36a9a8fd8 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 @@ -10,7 +10,7 @@ FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 AS base # The local directory into which to build and install CMAKE ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update &&\ @@ -82,7 +82,7 @@ RUN if [ -z "$ONNXRUNTIME_COMMIT_ID" ] ; then echo "Building branch ${ONNXRUNTIM git reset --hard ${ONNXRUNTIME_COMMIT_ID} && git submodule update --recursive ; fi # Build ORT -ENV CUDA_MODULE_LOADING "LAZY" +ENV CUDA_MODULE_LOADING="LAZY" ARG PARSER_CONFIG="" RUN /bin/sh build.sh ${PARSER_CONFIG} --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"' diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10 index a7d8f220ea9b3..22d5e3b0248a8 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10 +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10 @@ -10,7 +10,7 @@ FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 AS base # The local directory into which to build and install CMAKE ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update &&\ @@ -31,26 +31,26 @@ RUN pip install --upgrade pip RUN pip install psutil setuptools>=68.2.2 # Install TensorRT -RUN version="10.4.0.26-1+cuda11.8" &&\ +RUN TRT_VERSION="10.6.0.26-1+cuda11.8" &&\ apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub &&\ apt-get update &&\ apt-get install -y \ - libnvinfer-dev=${version} \ - libnvinfer-dispatch-dev=${version} \ - libnvinfer-dispatch10=${version} \ - libnvinfer-headers-dev=${version} \ - libnvinfer-headers-plugin-dev=${version} \ - libnvinfer-lean-dev=${version} \ - libnvinfer-lean10=${version} \ - libnvinfer-plugin-dev=${version} \ - libnvinfer-plugin10=${version} \ - libnvinfer-vc-plugin-dev=${version} \ - libnvinfer-vc-plugin10=${version} \ - libnvinfer10=${version} \ - libnvonnxparsers-dev=${version} \ - libnvonnxparsers10=${version} \ - tensorrt-dev=${version} \ - libnvinfer-bin=${version} + libnvinfer-dev=${TRT_VERSION} \ + libnvinfer-dispatch-dev=${TRT_VERSION} \ + libnvinfer-dispatch10=${TRT_VERSION} \ + libnvinfer-headers-dev=${TRT_VERSION} \ + libnvinfer-headers-plugin-dev=${TRT_VERSION} \ + libnvinfer-lean-dev=${TRT_VERSION} \ + libnvinfer-lean10=${TRT_VERSION} \ + libnvinfer-plugin-dev=${TRT_VERSION} \ + libnvinfer-plugin10=${TRT_VERSION} \ + libnvinfer-vc-plugin-dev=${TRT_VERSION} \ + libnvinfer-vc-plugin10=${TRT_VERSION} \ + libnvinfer10=${TRT_VERSION} \ + libnvonnxparsers-dev=${TRT_VERSION} \ + libnvonnxparsers10=${TRT_VERSION} \ + tensorrt-dev=${TRT_VERSION} \ + libnvinfer-bin=${TRT_VERSION} # Compile trtexec if not installed RUN if [ ! -d /usr/src/tensorrt/bin ] || [ ! -f /usr/src/tensorrt/bin/trtexec ]; then \ @@ -98,7 +98,7 @@ RUN if [ -z "$ONNXRUNTIME_COMMIT_ID" ] ; then echo "Building branch ${ONNXRUNTIM git reset --hard ${ONNXRUNTIME_COMMIT_ID} && git submodule update --recursive ; fi # Build ORT -ENV CUDA_MODULE_LOADING "LAZY" +ENV CUDA_MODULE_LOADING="LAZY" ARG PARSER_CONFIG="" RUN /bin/sh build.sh ${PARSER_CONFIG} --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"' diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_3_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_3_tensorrt8_6 index f63112039fe8e..6d35df72894d8 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_3_tensorrt8_6 +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_3_tensorrt8_6 @@ -10,7 +10,7 @@ FROM nvidia/cuda:12.3.1-devel-ubuntu20.04 AS base # The local directory into which to build and install CMAKE ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update &&\ @@ -85,7 +85,7 @@ RUN if [ -z "$ONNXRUNTIME_COMMIT_ID" ] ; then echo "Building branch ${ONNXRUNTIM git reset --hard ${ONNXRUNTIME_COMMIT_ID} && git submodule update --recursive ; fi # Build ORT -ENV CUDA_MODULE_LOADING "LAZY" +ENV CUDA_MODULE_LOADING="LAZY" ARG PARSER_CONFIG="" RUN /bin/sh build.sh ${PARSER_CONFIG} --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"' diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10 index 523318f09aba6..819d9bab7be75 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10 +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10 @@ -10,7 +10,7 @@ FROM nvidia/cuda:12.5.1-cudnn-devel-ubuntu20.04 AS base # The local directory into which to build and install CMAKE ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update &&\ @@ -31,26 +31,26 @@ RUN pip install --upgrade pip RUN pip install setuptools>=68.2.2 psutil # Install TensorRT -RUN version="10.4.0.26-1+cuda12.6" &&\ +RUN TRT_VERSION="10.6.0.26-1+cuda12.6" &&\ apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub &&\ apt-get update &&\ apt-get install -y \ - libnvinfer-dev=${version} \ - libnvinfer-dispatch-dev=${version} \ - libnvinfer-dispatch10=${version} \ - libnvinfer-headers-dev=${version} \ - libnvinfer-headers-plugin-dev=${version} \ - libnvinfer-lean-dev=${version} \ - libnvinfer-lean10=${version} \ - libnvinfer-plugin-dev=${version} \ - libnvinfer-plugin10=${version} \ - libnvinfer-vc-plugin-dev=${version} \ - libnvinfer-vc-plugin10=${version} \ - libnvinfer10=${version} \ - libnvonnxparsers-dev=${version} \ - libnvonnxparsers10=${version} \ - tensorrt-dev=${version} \ - libnvinfer-bin=${version} + libnvinfer-dev=${TRT_VERSION} \ + libnvinfer-dispatch-dev=${TRT_VERSION} \ + libnvinfer-dispatch10=${TRT_VERSION} \ + libnvinfer-headers-dev=${TRT_VERSION} \ + libnvinfer-headers-plugin-dev=${TRT_VERSION} \ + libnvinfer-lean-dev=${TRT_VERSION} \ + libnvinfer-lean10=${TRT_VERSION} \ + libnvinfer-plugin-dev=${TRT_VERSION} \ + libnvinfer-plugin10=${TRT_VERSION} \ + libnvinfer-vc-plugin-dev=${TRT_VERSION} \ + libnvinfer-vc-plugin10=${TRT_VERSION} \ + libnvinfer10=${TRT_VERSION} \ + libnvonnxparsers-dev=${TRT_VERSION} \ + libnvonnxparsers10=${TRT_VERSION} \ + tensorrt-dev=${TRT_VERSION} \ + libnvinfer-bin=${TRT_VERSION} # Compile trtexec if not installed RUN if [ ! -d /usr/src/tensorrt/bin ] || [ ! -f /usr/src/tensorrt/bin/trtexec ]; then \ @@ -98,7 +98,7 @@ RUN if [ -z "$ONNXRUNTIME_COMMIT_ID" ] ; then echo "Building branch ${ONNXRUNTIM git reset --hard ${ONNXRUNTIME_COMMIT_ID} && git submodule update --recursive ; fi # Build ORT -ENV CUDA_MODULE_LOADING "LAZY" +ENV CUDA_MODULE_LOADING="LAZY" ARG PARSER_CONFIG="" RUN /bin/sh build.sh ${PARSER_CONFIG} --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"' diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino index 5f525c1310412..643c0d66d01f5 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino @@ -1,7 +1,7 @@ ARG UBUNTU_VERSION=22.04 FROM ubuntu:${UBUNTU_VERSION} -ARG OPENVINO_VERSION=2024.3.0 +ARG OPENVINO_VERSION=2024.5.0 ARG PYTHON_VERSION=3.10 ADD scripts /tmp/scripts @@ -12,16 +12,16 @@ RUN /tmp/scripts/install_python_deps.sh -p $PYTHON_VERSION -d EdgeDevice RUN apt update && apt install -y libnuma1 ocl-icd-libopencl1 && \ rm -rf /var/lib/apt/lists/* /tmp/scripts -ENV INTEL_OPENVINO_DIR /opt/intel/openvino_${OPENVINO_VERSION} -ENV LD_LIBRARY_PATH $INTEL_OPENVINO_DIR/runtime/lib/intel64:$INTEL_OPENVINO_DIR/runtime/3rdparty/tbb/lib:/usr/local/openblas/lib:$LD_LIBRARY_PATH -ENV OpenVINO_DIR $INTEL_OPENVINO_DIR/runtime/cmake -ENV IE_PLUGINS_PATH $INTEL_OPENVINO_DIR/runtime/lib/intel64 +ENV INTEL_OPENVINO_DIR=/opt/intel/openvino_${OPENVINO_VERSION} +ENV LD_LIBRARY_PATH=$INTEL_OPENVINO_DIR/runtime/lib/intel64:$INTEL_OPENVINO_DIR/runtime/3rdparty/tbb/lib:/usr/local/openblas/lib:$LD_LIBRARY_PATH +ENV OpenVINO_DIR=$INTEL_OPENVINO_DIR/runtime/cmake +ENV IE_PLUGINS_PATH=$INTEL_OPENVINO_DIR/runtime/lib/intel64 ENV DEBIAN_FRONTEND=noninteractive RUN cd /opt && mkdir -p intel && cd intel && \ - wget https://storage.openvinotoolkit.org/repositories/openvino/packages/2024.3/linux/l_openvino_toolkit_ubuntu22_2024.3.0.16041.1e3b88e4e3f_x86_64.tgz && \ - tar xzf l_openvino_toolkit_ubuntu22_2024.3.0.16041.1e3b88e4e3f_x86_64.tgz && rm -rf l_openvino_toolkit_ubuntu22_2024.3.0.16041.1e3b88e4e3f_x86_64.tgz && \ - mv l_openvino_toolkit_ubuntu22_2024.3.0.16041.1e3b88e4e3f_x86_64 openvino_2024.3.0 && \ + wget https://storage.openvinotoolkit.org/repositories/openvino/packages/2024.5/linux/l_openvino_toolkit_ubuntu22_2024.5.0.17288.7975fa5da0c_x86_64.tgz && \ + tar xzf l_openvino_toolkit_ubuntu22_2024.5.0.17288.7975fa5da0c_x86_64.tgz && rm -rf l_openvino_toolkit_ubuntu22_2024.5.0.17288.7975fa5da0c_x86_64.tgz && \ + mv l_openvino_toolkit_ubuntu22_2024.5.0.17288.7975fa5da0c_x86_64 openvino_2024.5.0 && \ cd $INTEL_OPENVINO_DIR/install_dependencies && ./install_openvino_dependencies.sh -y WORKDIR /root diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin index e8d8dc0a64feb..4f58dc89333ba 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin @@ -10,7 +10,7 @@ FROM nvidia/cuda:12.5.1-cudnn-devel-ubuntu20.04 AS base # The local directory into which to build and install CMAKE ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update &&\ @@ -44,7 +44,7 @@ COPY ${TRT_BINS_DIR}/TensorRT-${TAR_TRT_VERSION}.Linux.x86_64-gnu.cuda-${TAR_CUD RUN tar -xzvf /TensorRT-${TAR_TRT_VERSION}.tar.gz RUN cd /TensorRT-${TAR_TRT_VERSION}/python &&\ - python3 -m pip install tensorrt*cp38*.whl + python3 -m pip install tensorrt*cp310*.whl RUN cp -r /TensorRT-${TAR_TRT_VERSION}/lib/* /usr/lib/x86_64-linux-gnu/ RUN cp /TensorRT-${TAR_TRT_VERSION}/include/* /usr/local/include/ @@ -92,7 +92,7 @@ RUN if [ -z "$ONNXRUNTIME_COMMIT_ID" ] ; then echo "Building branch ${ONNXRUNTIM git reset --hard ${ONNXRUNTIME_COMMIT_ID} && git submodule update --recursive ; fi # Build ORT -ENV CUDA_MODULE_LOADING "LAZY" +ENV CUDA_MODULE_LOADING="LAZY" ARG PARSER_CONFIG="" RUN /bin/sh build.sh ${PARSER_CONFIG} --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"' diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile index ca00050121d67..246ef09f7be25 100644 --- a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile @@ -2,15 +2,14 @@ # Licensed under the MIT License. # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -ARG BASEIMAGE=arm64v8/almalinux:8 -FROM $BASEIMAGE +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_ubi8_gcc12_dotnet:20241120.3 ENV PATH=/opt/rh/gcc-toolset-12/root/usr/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 ADD scripts /tmp/scripts -RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts +RUN cd /tmp/scripts && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_centos.sh b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_centos.sh deleted file mode 100755 index 596a5ce436c57..0000000000000 --- a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_centos.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -set -e -x - -os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. -f1) - -echo "installing for CentOS version : $os_major_version" -dnf install -y python3.12-pip python3.12-devel glibc-langpack-\* glibc-locale-source which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel java-11-openjdk-devel graphviz gcc-toolset-12-binutils gcc-toolset-12-gcc gcc-toolset-12-gcc-c++ gcc-toolset-12-gcc-gfortran -locale diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh index bf08a853fe7f4..70bb373efb23f 100755 --- a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh +++ b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh @@ -39,9 +39,6 @@ mkdir -p /tmp/src cd /tmp/src CPU_ARCH=$(uname -m) -echo "Installing cmake" -GetFile "https://github.com/Kitware/CMake/releases/download/v3.31.0-rc2/cmake-3.31.0-rc2-linux-$CPU_ARCH.tar.gz" "/tmp/src/cmake.tar.gz" -tar -zxf /tmp/src/cmake.tar.gz --strip=1 -C /usr echo "Installing Ninja" GetFile https://github.com/ninja-build/ninja/archive/v1.10.0.tar.gz /tmp/src/ninja-linux.tar.gz diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile index 3f42b28497c7a..43dd3badef387 100644 --- a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile @@ -1,4 +1,4 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_ubi8_gcc12:20241020.1 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_ubi8_gcc12:20241120.3 ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/install_centos.sh b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/install_centos.sh index c81e57c60c9da..d0b58ed28b8c9 100755 --- a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/install_centos.sh +++ b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/install_centos.sh @@ -7,8 +7,6 @@ echo "installing for os major version : $os_major_version" dnf install -y glibc-langpack-\* yum install -y which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget -# export PATH=/opt/python/cp38-cp38/bin:$PATH - echo "installing rapidjson for AzureEP" wget https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.tar.gz tar zxvf v1.1.0.tar.gz diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile index ef28dde67617f..fffe92d2583a2 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -2,15 +2,14 @@ # Licensed under the MIT License. # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -ARG BASEIMAGE=amd64/almalinux:8 -FROM $BASEIMAGE +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc12_dotnet:20241120.3 -ENV PATH=/usr/lib/jvm/msopenjdk-11/bin:/opt/rh/gcc-toolset-12/root/usr/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin +ENV PATH=/usr/lib/jvm/msopenjdk-17/bin:/opt/rh/gcc-toolset-12/root/usr/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 -ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 + ADD scripts /tmp/scripts -RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts +RUN cd /tmp/scripts && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/scripts/install_centos.sh b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/scripts/install_centos.sh deleted file mode 100755 index 03534d8a2f447..0000000000000 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/scripts/install_centos.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -set -e -x -if [ ! -f /etc/yum.repos.d/microsoft-prod.repo ]; then - os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. -f1) - echo "installing for CentOS version : $os_major_version" - rpm -Uvh https://packages.microsoft.com/config/centos/$os_major_version/packages-microsoft-prod.rpm -fi -dnf install -y python3.12-pip python3.12-devel glibc-langpack-\* glibc-locale-source which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel msopenjdk-11 graphviz gcc-toolset-12-binutils gcc-toolset-12-gcc gcc-toolset-12-gcc-c++ gcc-toolset-12-gcc-gfortran gcc-toolset-12-libasan-devel libasan.x86_64 -locale diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/scripts/install_deps.sh b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/scripts/install_deps.sh index 0cc48a720b8f4..be906bf21a4fb 100755 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/scripts/install_deps.sh +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/scripts/install_deps.sh @@ -38,9 +38,6 @@ mkdir -p /tmp/src cd /tmp/src CPU_ARCH=$(uname -m) -echo "Installing cmake" -GetFile "https://github.com/Kitware/CMake/releases/download/v3.31.0-rc2/cmake-3.31.0-rc2-linux-$CPU_ARCH.tar.gz" "/tmp/src/cmake.tar.gz" -tar -zxf /tmp/src/cmake.tar.gz --strip=1 -C /usr echo "Installing Ninja" GetFile https://github.com/ninja-build/ninja/archive/v1.10.0.tar.gz /tmp/src/ninja-linux.tar.gz diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda11/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda11/Dockerfile index 6702474d75801..d386db7ab7bd8 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda11/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda11/Dockerfile @@ -2,10 +2,9 @@ # Licensed under the MIT License. # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11_dotnet:20241020.1 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11_dotnet:20241120.3 ARG TRT_VERSION -RUN rpm -Uvh https://packages.microsoft.com/config/centos/8/packages-microsoft-prod.rpm && dnf install -y msopenjdk-11 #Install TensorRT only if TRT_VERSION is not empty RUN if [ -n "$TRT_VERSION" ]; then \ echo "TRT_VERSION is $TRT_VERSION" && \ @@ -31,11 +30,11 @@ else \ echo "TRT_VERSION is none skipping Tensor RT Installation" ; \ fi -ENV PATH /usr/lib/jvm/msopenjdk-11/bin:$PATH +ENV PATH=/usr/lib/jvm/msopenjdk-17/bin:$PATH ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 -ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 -ENV CUDAHOSTCXX /opt/rh/gcc-toolset-11/root/usr/bin/g++ +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-17 +ENV CUDAHOSTCXX=/opt/rh/gcc-toolset-11/root/usr/bin/g++ ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile index 4059de23b2480..ba6f28be4636c 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile @@ -2,7 +2,7 @@ # Licensed under the MIT License. # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12_dotnet:20241020.1 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12_dotnet:20241120.3 ARG TRT_VERSION #Install TensorRT only if TRT_VERSION is not empty @@ -35,12 +35,12 @@ fi ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 -ENV CUDAHOSTCXX /opt/rh/gcc-toolset-12/root/usr/bin/g++ +ENV CUDAHOSTCXX=/opt/rh/gcc-toolset-12/root/usr/bin/g++ ADD scripts /tmp/scripts RUN sed -i 's/enabled\s*=\s*1/enabled = 1\nexclude=dotnet* aspnet* netstandard*/g' /etc/yum.repos.d/ubi.repo && \ - rpm -Uvh https://packages.microsoft.com/config/centos/8/packages-microsoft-prod.rpm && dnf install -y msopenjdk-11 && cd /tmp/scripts && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts -ENV PATH /usr/lib/jvm/msopenjdk-11/bin:$PATH -ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 + cd /tmp/scripts && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts +ENV PATH=/usr/lib/jvm/msopenjdk-17/bin:$PATH +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-17 ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev RUN adduser --uid $BUILD_UID $BUILD_USER diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile index 76b31e71a7dea..857fc445ef74a 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile @@ -1,4 +1,4 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc12:20241020.1 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc12:20241120.3 ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/install_centos.sh b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/install_centos.sh index c81e57c60c9da..d0b58ed28b8c9 100755 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/install_centos.sh +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/install_centos.sh @@ -7,8 +7,6 @@ echo "installing for os major version : $os_major_version" dnf install -y glibc-langpack-\* yum install -y which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget -# export PATH=/opt/python/cp38-cp38/bin:$PATH - echo "installing rapidjson for AzureEP" wget https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.tar.gz tar zxvf v1.1.0.tar.gz diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/Dockerfile index 54bdbebbd1319..a69b98f86ba1b 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/Dockerfile @@ -5,7 +5,7 @@ ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 FROM $BASEIMAGE -ARG TRT_VERSION=10.4.0.26-1.cuda11.8 +ARG TRT_VERSION=10.6.0.26-1.cuda11.8 #Install TensorRT only if TRT_VERSION is not empty RUN if [ -n "${TRT_VERSION}" ]; then \ @@ -32,8 +32,8 @@ else \ echo "TRT_VERSION is x${TRT_VERSION} skipping Tensor RT Installation" ; \ fi -ENV PATH /usr/local/cuda/bin:$PATH -ENV CUDA_MODULE_LOADING "LAZY" +ENV PATH=/usr/local/cuda/bin:$PATH +ENV CUDA_MODULE_LOADING="LAZY" ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/scripts/install_centos.sh b/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/scripts/install_centos.sh index c81e57c60c9da..d0b58ed28b8c9 100755 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/scripts/install_centos.sh +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/scripts/install_centos.sh @@ -7,8 +7,6 @@ echo "installing for os major version : $os_major_version" dnf install -y glibc-langpack-\* yum install -y which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget -# export PATH=/opt/python/cp38-cp38/bin:$PATH - echo "installing rapidjson for AzureEP" wget https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.tar.gz tar zxvf v1.1.0.tar.gz diff --git a/tools/ci_build/github/linux/docker/scripts/install_java.sh b/tools/ci_build/github/linux/docker/scripts/install_java.sh index d11e29f693b8b..f4ea49963f115 100755 --- a/tools/ci_build/github/linux/docker/scripts/install_java.sh +++ b/tools/ci_build/github/linux/docker/scripts/install_java.sh @@ -5,7 +5,7 @@ if [ -f /etc/redhat-release ]; then dnf install -y java-11-openjdk-devel \ && dnf clean dbcache elif [ -f /etc/os-release ]; then - apt-get update && apt-get install -y openjdk-11-jdk + apt-get update && apt-get install -y openjdk-17-jdk else echo "Unsupported OS" exit 1 diff --git a/tools/ci_build/github/linux/docker/scripts/install_os_deps.sh b/tools/ci_build/github/linux/docker/scripts/install_os_deps.sh index 7f3160371aa24..87b9b960b7ebc 100755 --- a/tools/ci_build/github/linux/docker/scripts/install_os_deps.sh +++ b/tools/ci_build/github/linux/docker/scripts/install_os_deps.sh @@ -12,7 +12,6 @@ d) DEVICE_TYPE=${OPTARG};; v) echo "Cuda version is no longer accepted as an input to this script. Ignoring the input argument -v.";; t) echo "Installing python training dependencies argument is no longer accepted as an input to this script. Ignoring the input argument -t.";; m) INSTALL_DEPS_DISTRIBUTED_SETUP=true;; -u) echo "Installing ortmodule python dependencies argument is no longer accepted as an input to this script. Ignoring the input argument -u.";; r) echo "Installing ROCM python dependencies argument is no longer accepted as an input to this script. Ignoring the input argument -r.";; esac done diff --git a/tools/ci_build/github/linux/docker/scripts/install_python_deps.sh b/tools/ci_build/github/linux/docker/scripts/install_python_deps.sh index 1ac1d226deec6..2d7acd1f701ff 100755 --- a/tools/ci_build/github/linux/docker/scripts/install_python_deps.sh +++ b/tools/ci_build/github/linux/docker/scripts/install_python_deps.sh @@ -3,7 +3,6 @@ set -e -x INSTALL_DEPS_TRAINING=false INSTALL_DEPS_DISTRIBUTED_SETUP=false -ORTMODULE_BUILD=false TARGET_ROCM=false CU_VER="11.8" TORCH_VERSION='2.0.0' @@ -18,7 +17,6 @@ d) DEVICE_TYPE=${OPTARG};; v) CU_VER=${OPTARG};; t) INSTALL_DEPS_TRAINING=true;; m) INSTALL_DEPS_DISTRIBUTED_SETUP=true;; -u) ORTMODULE_BUILD=true;; r) TARGET_ROCM=true;; c) USE_CONDA=true;; esac @@ -55,17 +53,3 @@ fi export ONNX_ML=1 export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=OFF -DONNX_WERROR=OFF" ${PYTHON_EXE} -m pip install -r ${0/%install_python_deps\.sh/requirements\.txt} -if [ $DEVICE_TYPE = "gpu" ]; then - if [[ $INSTALL_DEPS_TRAINING = true ]]; then - if [[ $ORTMODULE_BUILD = false ]]; then - ${PYTHON_EXE} -m pip install -r ${0/%install_python_deps.sh/training\/requirements.txt} - else - if [[ $TARGET_ROCM = false ]]; then - ${PYTHON_EXE} -m pip install -r ${0/%install_python_deps.sh/training\/ortmodule\/stage1\/requirements_torch${TORCH_VERSION}_cu${CU_VER}\/requirements.txt} - ${PYTHON_EXE} -m pip install -r ${0/%install_python_deps.sh/training\/ortmodule\/stage2\/requirements.txt} - else - ${PYTHON_EXE} -m pip install -r ${0/%install_python_deps.sh/training\/ortmodule\/stage1\/requirements_rocm\/requirements.txt} - fi - fi - fi -fi diff --git a/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh b/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh index a980963429034..4bc609fc0badb 100755 --- a/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh +++ b/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh @@ -5,6 +5,7 @@ do case "${parameter_Option}" in p) PYTHON_VER=${OPTARG};; d) DEVICE_TYPE=${OPTARG};; +*) echo "Usage: $0 -p PYTHON_VER -d DEVICE_TYPE";; esac done @@ -20,54 +21,65 @@ apt-get update && apt-get install -y software-properties-common lsb-release OS_VERSION=$(lsb_release -r -s) -PACKAGE_LIST="autotools-dev \ - automake \ - build-essential \ - git apt-transport-https apt-utils \ - ca-certificates \ - pkg-config \ - wget \ - zlib1g \ - zlib1g-dev \ - libssl-dev \ - curl libcurl4-openssl-dev \ - autoconf \ - sudo \ - gfortran \ - python3-dev \ - language-pack-en \ - liblttng-ust-dev \ - libcurl4 \ - libkrb5-3 \ - libtinfo-dev \ - libtinfo5 \ - libtool \ - openssh-server \ - aria2 \ - bzip2 \ - unzip \ - zip \ - rsync libunwind8 libpng-dev libexpat1-dev \ - python3-setuptools python3-numpy python3-wheel python3-pip python3-pytest python3-distutils \ - openjdk-11-jdk \ - graphviz" - - -if [ $DEVICE_TYPE = "Normal" ]; then - PACKAGE_LIST="$PACKAGE_LIST libedit-dev libxml2-dev python3-packaging" +PACKAGE_LIST=( + "apt-transport-https" + "apt-utils" + "aria2" + "autoconf" + "automake" + "autotools-dev" + "build-essential" + "bzip2" + "ca-certificates" + "curl" + "gfortran" + "git" + "graphviz" + "language-pack-en" + "libcurl4" + "libcurl4-openssl-dev" + "libexpat1-dev" + "libkrb5-3" + "liblttng-ust-dev" + "libpng-dev" + "libssl-dev" + "libtinfo-dev" + "libtinfo5" + "libtool" + "libunwind8" + "openjdk-17-jdk" + "openssh-server" + "pkg-config" + "python3-dev" + "python3-distutils" + "python3-numpy" + "python3-pip" + "python3-pytest" + "python3-setuptools" + "python3-wheel" + "rsync" + "sudo" + "unzip" + "wget" + "zip" + "zlib1g" + "zlib1g-dev" +) +if [ "$DEVICE_TYPE" = "Normal" ]; then + PACKAGE_LIST+=("libedit-dev" "libxml2-dev" "python3-packaging") fi -PACKAGE_LIST="$PACKAGE_LIST libicu-dev" +PACKAGE_LIST+=("libicu-dev") -apt-get install -y --no-install-recommends $PACKAGE_LIST +apt-get install -y --no-install-recommends "${PACKAGE_LIST[@]}" locale-gen en_US.UTF-8 update-locale LANG=en_US.UTF-8 if [ "$OS_VERSION" = "20.04" ]; then # The defaul version of python is 3.8 - major=$(echo $PYTHON_VER | cut -d. -f1) - minor=$(echo $PYTHON_VER | cut -d. -f2) + major=$(echo "$PYTHON_VER" | cut -d. -f1) + minor=$(echo "$PYTHON_VER" | cut -d. -f2) if [ "$major" -lt 3 ] || [ "$major" -eq 3 ] && [ "$minor" -lt 8 ]; then PYTHON_VER="3.8" fi @@ -75,19 +87,19 @@ if [ "$OS_VERSION" = "20.04" ]; then add-apt-repository -y ppa:deadsnakes/ppa apt-get update apt-get install -y --no-install-recommends \ - python${PYTHON_VER} \ - python${PYTHON_VER}-dev - update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VER} 1 + python"${PYTHON_VER}" \ + python"${PYTHON_VER}-"dev + update-alternatives --install /usr/bin/python3 python3 /usr/bin/python"${PYTHON_VER}" 1 update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 2 - update-alternatives --set python3 /usr/bin/python${PYTHON_VER} + update-alternatives --set python3 /usr/bin/python"${PYTHON_VER}" #TODO: the old one(/usr/bin/pip3) should be uninstalled first. Because the one will be #put at /usr/local/. Then there will be two pips. - /usr/bin/python${PYTHON_VER} -m pip install --upgrade --force-reinstall pip==19.0.3 + /usr/bin/python"${PYTHON_VER}" -m pip install --upgrade --force-reinstall pip==19.0.3 fi elif [ "$OS_VERSION" = "22.04" ] ; then # The defaul version of python is 3.10 - major=$(echo $PYTHON_VER | cut -d. -f1) - minor=$(echo $PYTHON_VER | cut -d. -f2) + major=$(echo "$PYTHON_VER" | cut -d. -f1) + minor=$(echo "$PYTHON_VER" | cut -d. -f2) if [ "$major" -lt 3 ] || [ "$major" -eq 3 ] && [ "$minor" -lt 10 ]; then PYTHON_VER="3.10" fi @@ -95,11 +107,11 @@ elif [ "$OS_VERSION" = "22.04" ] ; then add-apt-repository -y ppa:deadsnakes/ppa apt-get update apt-get install -y --no-install-recommends \ - python${PYTHON_VER} \ - python${PYTHON_VER}-dev - update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VER} 1 + python"${PYTHON_VER}" \ + python"${PYTHON_VER}"-dev + update-alternatives --install /usr/bin/python3 python3 /usr/bin/python"${PYTHON_VER}" 1 update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 2 - update-alternatives --set python3 /usr/bin/python${PYTHON_VER} + update-alternatives --set python3 /usr/bin/python"${PYTHON_VER}" fi else exit 1 diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh index dfda5ec73fdbe..a487bf7f91507 100755 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh @@ -8,9 +8,6 @@ if [ "$os_major_version" -gt 7 ]; then PACKAGE_MANAGER="dnf" $PACKAGE_MANAGER install -y which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget fi -if [ ! -f /etc/yum.repos.d/microsoft-prod.repo ]; then - rpm -Uvh https://packages.microsoft.com/config/centos/$os_major_version/packages-microsoft-prod.rpm -fi -# Install Java + # Install automatic documentation generation dependencies -$PACKAGE_MANAGER install -y msopenjdk-11 graphviz +$PACKAGE_MANAGER install -y graphviz diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_ubuntuos.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_ubuntuos.sh index 2f69435dc316e..69b0ea1321235 100755 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_ubuntuos.sh +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_ubuntuos.sh @@ -12,4 +12,4 @@ apt-get install -y gdb build-essential tar unzip make aria2 bzip2 # Install Java # Install automatic documentation generation dependencies apt-get update -apt-get install -y openjdk-11-jdk graphviz +apt-get install -y openjdk-17-jdk graphviz diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt index 12db3bd132bb7..2d714e3058da4 100644 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt @@ -6,7 +6,8 @@ setuptools>=68.2.2 wheel onnx==1.16.1 protobuf==4.21.12 -sympy==1.12 +sympy==1.12 ; python_version < '3.9' +sympy==1.13 ; python_version >= '3.9' flatbuffers neural-compressor>=2.2.1 triton diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_rocm/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_rocm/requirements.txt deleted file mode 100644 index 051f42dac335d..0000000000000 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_rocm/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -numpy==1.21.6 ; python_version < '3.9' -numpy==2.1.2 ; python_version >= '3.9' diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch2.0.0_cu11.8/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch2.0.0_cu11.8/requirements.txt deleted file mode 100644 index b3b2651c8d26d..0000000000000 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch2.0.0_cu11.8/requirements.txt +++ /dev/null @@ -1,8 +0,0 @@ ---pre --f https://download.pytorch.org/whl/torch_stable.html -torch==2.0.0+cu118 -torchvision==0.15.1+cu118 -torchtext==0.15.1 -# TODO(bmeswani): packaging 22.0 removes support for LegacyVersion leading to errors because transformers 4.4.2 uses LegacyVersion -packaging==21.3 -setuptools>=68.2.2 diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch2.1.0_cu12.2/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch2.1.0_cu12.2/requirements.txt deleted file mode 100644 index 152a17db90366..0000000000000 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch2.1.0_cu12.2/requirements.txt +++ /dev/null @@ -1,7 +0,0 @@ ---pre --f https://download.pytorch.org/whl/torch_stable.html -torch==2.1.0+cu121 -torchvision==0.16.0+cu121 -torchtext==0.16.0 -packaging==23.1 -setuptools>=68.2.2 diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_cpu/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_cpu/requirements.txt deleted file mode 100644 index 846f8c15b257d..0000000000000 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_cpu/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ --f https://download.pytorch.org/whl/torch_stable.html -torch==2.3.0+cpu -setuptools>=68.2.2 diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_nightly/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_nightly/requirements.txt deleted file mode 100644 index 01fa7b0ff956e..0000000000000 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_nightly/requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ -scikit-learn -packaging==21.3 -transformers==v4.36.0 -accelerate==0.25.0 -wget diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/torch_eager_cpu/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/torch_eager_cpu/requirements.txt deleted file mode 100644 index 6346c54decf9c..0000000000000 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/torch_eager_cpu/requirements.txt +++ /dev/null @@ -1,11 +0,0 @@ ---pre --f https://download.pytorch.org/whl/torch_stable.html -torch==2.2.0 -setuptools>=68.2.2 -cerberus -h5py -scikit-learn -numpy==1.21.6 ; python_version < '3.9' -numpy==2.1.2 ; python_version >= '3.9' -pandas -parameterized diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage2/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage2/requirements.txt deleted file mode 100644 index dd86b32f88c76..0000000000000 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage2/requirements.txt +++ /dev/null @@ -1,15 +0,0 @@ -pandas -scikit-learn -numpy==1.21.6 ; python_version < '3.9' -numpy==2.1.2 ; python_version >= '3.9' -transformers==v4.36.0 -accelerate==0.25.0 -rsa==4.9 -tensorboard==2.13.0 -h5py -wget -pytorch-lightning==2.3.3 -deepspeed==0.9.0 -fairscale==0.4.6 -parameterized>=0.8.1 -pydantic<2.0.0 diff --git a/tools/ci_build/github/linux/ort_minimal/build_full_ort_and_create_ort_files.sh b/tools/ci_build/github/linux/ort_minimal/build_full_ort_and_create_ort_files.sh index 640028ee7678c..aef9793f696b6 100755 --- a/tools/ci_build/github/linux/ort_minimal/build_full_ort_and_create_ort_files.sh +++ b/tools/ci_build/github/linux/ort_minimal/build_full_ort_and_create_ort_files.sh @@ -5,7 +5,7 @@ set -e set -x -export PATH=/opt/python/cp38-cp38/bin:$PATH +export PATH=/opt/python/cp310-cp310/bin:$PATH BUILD_DIR=${1:?"usage: $0 "} @@ -26,7 +26,7 @@ python3 /onnxruntime_src/tools/ci_build/build.py \ --build_wheel \ --skip_tests \ --enable_training_ops \ - --enable_pybind --cmake_extra_defines PYTHON_INCLUDE_DIR=/opt/python/cp38-cp38/include/python3.8 PYTHON_LIBRARY=/usr/lib64/librt.so \ + --enable_pybind --cmake_extra_defines PYTHON_INCLUDE_DIR=/opt/python/cp310-cp310/include/python3.10 PYTHON_LIBRARY=/usr/lib64/librt.so \ --use_nnapi \ --use_coreml diff --git a/tools/ci_build/github/linux/ort_minimal/build_minimal_ort_and_run_tests.sh b/tools/ci_build/github/linux/ort_minimal/build_minimal_ort_and_run_tests.sh index 58d493086ece9..c857d3f1036bc 100755 --- a/tools/ci_build/github/linux/ort_minimal/build_minimal_ort_and_run_tests.sh +++ b/tools/ci_build/github/linux/ort_minimal/build_minimal_ort_and_run_tests.sh @@ -7,7 +7,7 @@ set -e set -x -export PATH=/opt/python/cp38-cp38/bin:$PATH +export PATH=/opt/python/cp310-cp310/bin:$PATH USAGE_TEXT="Usage: -b|--build-directory Specifies the build directory. Required. diff --git a/tools/ci_build/github/linux/run_dockerbuild.sh b/tools/ci_build/github/linux/run_dockerbuild.sh index 9944861f519f4..6618810c77f6d 100755 --- a/tools/ci_build/github/linux/run_dockerbuild.sh +++ b/tools/ci_build/github/linux/run_dockerbuild.sh @@ -15,10 +15,6 @@ BUILD_DIR=$BUILD_BINARIESDIRECTORY YOCTO_VERSION="4.19" #Training only INSTALL_DEPS_DISTRIBUTED_SETUP=false -#Training only -ORTMODULE_BUILD=false -#Training only -USE_CONDA=false ALLOW_RELEASED_ONNX_OPSET_ONLY_ENV="ALLOW_RELEASED_ONNX_OPSET_ONLY="$ALLOW_RELEASED_ONNX_OPSET_ONLY echo "ALLOW_RELEASED_ONNX_OPSET_ONLY environment variable is set as $ALLOW_RELEASED_ONNX_OPSET_ONLY_ENV" @@ -44,10 +40,6 @@ t) EXTRA_IMAGE_TAG=${OPTARG};; i) IMAGE_CACHE_CONTAINER_REGISTRY_NAME=${OPTARG};; # install distributed setup dependencies m) INSTALL_DEPS_DISTRIBUTED_SETUP=true;; -# install ortmodule specific dependencies -u) ORTMODULE_BUILD=true;; -# install and use conda -e) USE_CONDA=true;; *) echo "Invalid option";; esac done @@ -82,24 +74,6 @@ if [ $BUILD_OS = "yocto" ]; then $GET_DOCKER_IMAGE_CMD --repository "onnxruntime-$IMAGE" \ --docker-build-args="--build-arg TOOL_CHAIN=$TOOL_CHAIN_SCRIPT --build-arg BUILD_USER=onnxruntimedev --build-arg BUILD_UID=$(id -u) --build-arg PYTHON_VERSION=${PYTHON_VER}" \ --dockerfile $DOCKER_FILE --context . -elif [ $BUILD_DEVICE = "gpu" ]; then - # This code path is only for training. Inferecing pipeline uses CentOS - IMAGE="$BUILD_OS-gpu_training" - # Current build script doesn't support building shared lib with Python dependency. To enable building with PythonOp, - # We need to avoid `--no-undefined` when building shared lib (Otherwise, CIs will report `undefined symbols`), but removing that would bring some other concerns. - # Plus the fact training did not need build shared library, we disable the --build_shared_lib for training CIs. - NEED_BUILD_SHARED_LIB=false - INSTALL_DEPS_EXTRA_ARGS="${INSTALL_DEPS_EXTRA_ARGS} -t" - if [[ $INSTALL_DEPS_DISTRIBUTED_SETUP = true ]]; then - INSTALL_DEPS_EXTRA_ARGS="${INSTALL_DEPS_EXTRA_ARGS} -m" - fi - if [[ $ORTMODULE_BUILD = true ]]; then - INSTALL_DEPS_EXTRA_ARGS="${INSTALL_DEPS_EXTRA_ARGS} -u" - fi - INSTALL_DEPS_EXTRA_ARGS="${INSTALL_DEPS_EXTRA_ARGS} -v 11.8" - $GET_DOCKER_IMAGE_CMD --repository "onnxruntime-$IMAGE" \ - --docker-build-args="--build-arg BASEIMAGE=nvcr.io/nvidia/cuda:11.8.0-cudnn8-devel-${BUILD_OS} --build-arg BUILD_USER=onnxruntimedev --build-arg BUILD_UID=$(id -u) --build-arg PYTHON_VERSION=${PYTHON_VER} --build-arg INSTALL_DEPS_EXTRA_ARGS=\"${INSTALL_DEPS_EXTRA_ARGS}\" --build-arg USE_CONDA=${USE_CONDA} --network=host" \ - --dockerfile Dockerfile.ubuntu_gpu_training --context . elif [[ $BUILD_DEVICE = "openvino"* ]]; then BUILD_ARGS="--build-arg BUILD_USER=onnxruntimedev --build-arg BUILD_UID=$(id -u) --build-arg PYTHON_VERSION=${PYTHON_VER} --build-arg OPENVINO_VERSION=${OPENVINO_VERSION} --build-arg UBUNTU_VERSION=${UBUNTU_VERSION}" IMAGE="$BUILD_OS-openvino" diff --git a/tools/ci_build/github/linux/run_python_dockerbuild.sh b/tools/ci_build/github/linux/run_python_dockerbuild.sh index eb3a0132f8aba..2fec98e569919 100755 --- a/tools/ci_build/github/linux/run_python_dockerbuild.sh +++ b/tools/ci_build/github/linux/run_python_dockerbuild.sh @@ -2,14 +2,15 @@ set -e -x BUILD_CONFIG="Release" -while getopts "i:d:x:c:" parameter_Option +while getopts "i:d:x:c:p:" parameter_Option do case "${parameter_Option}" in i) DOCKER_IMAGE=${OPTARG};; d) DEVICE=${OPTARG};; x) BUILD_EXTR_PAR=${OPTARG};; c) BUILD_CONFIG=${OPTARG};; -*) echo "Usage: $0 -i -d [-x ] [-c ]" +p) PYTHON_EXES=${OPTARG};; +*) echo "Usage: $0 -i -d [-x ] [-c ] [-p ]" exit 1;; esac done @@ -17,6 +18,10 @@ done mkdir -p "${HOME}/.onnx" DOCKER_SCRIPT_OPTIONS="-d ${DEVICE} -c ${BUILD_CONFIG}" +if [ "${PYTHON_EXES}" != "" ] ; then + DOCKER_SCRIPT_OPTIONS+=" -p ${PYTHON_EXES}" +fi + if [ "${BUILD_EXTR_PAR}" != "" ] ; then DOCKER_SCRIPT_OPTIONS+=" -x ${BUILD_EXTR_PAR}" fi diff --git a/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh b/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh index 9cd1222cabfa6..835f83e2b8bed 100755 --- a/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh +++ b/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh @@ -5,7 +5,7 @@ pip3 install --user --upgrade pip pip3 install --user numpy torch pytest pip3 install --user /build/Release/dist/*.whl -export PYTHONPATH=/onnxruntime_src/tools:/usr/local/lib/python3.8/site-packages:$PYTHONPATH +export PYTHONPATH=/onnxruntime_src/tools:/usr/local/lib/python3.10/site-packages:$PYTHONPATH python3 -m pytest -v /onnxruntime_src/tools/test/test_custom_ops_pytorch_exporter.py || exit 1 diff --git a/tools/ci_build/github/pai/pai_huggingface_bert_large_test.sh b/tools/ci_build/github/pai/pai_huggingface_bert_large_test.sh deleted file mode 100755 index fb4dbeb2e73d3..0000000000000 --- a/tools/ci_build/github/pai/pai_huggingface_bert_large_test.sh +++ /dev/null @@ -1,43 +0,0 @@ -#!/bin/bash - -set -ex - -usage() { echo "Usage: $0 [-v ]" 1>&2; exit 1; } - -while getopts "v:" parameter_Option -do case "${parameter_Option}" -in -v) ROCM_VERSION=${OPTARG};; -*) usage ;; -esac -done - -MI200_DEVICE_NUMBERS=$(rocm-smi --showproductname | grep -c "MI250" | xargs) - -if [ "$MI200_DEVICE_NUMBERS" -gt "0" ]; then - RESULT_FILE=ci-mi200.huggingface.bert-large-rocm${ROCM_VERSION}.json -else - RESULT_FILE=ci-mi100.huggingface.bert-large-rocm${ROCM_VERSION}.json -fi - -python \ - /stage/huggingface-transformers/examples/pytorch/language-modeling/run_mlm.py \ - --model_name_or_path bert-large-uncased \ - --dataset_name wikitext \ - --dataset_config_name wikitext-2-raw-v1 \ - --do_train \ - --max_steps 260 \ - --logging_steps 20 \ - --output_dir ./test-mlm-bbu \ - --overwrite_output_dir \ - --per_device_train_batch_size 8 \ - --fp16 \ - --dataloader_num_workers 1 \ - --ort \ - --skip_memory_metrics - -cat ci-pipeline-actual.json - -python /onnxruntime_src/orttraining/tools/ci_test/compare_huggingface.py \ - ci-pipeline-actual.json \ - /onnxruntime_src/orttraining/tools/ci_test/results/"$RESULT_FILE" diff --git a/tools/ci_build/github/windows/setup_env_gpu.bat b/tools/ci_build/github/windows/setup_env_gpu.bat index 6a660ecaa40d2..34ddd75da16fc 100644 --- a/tools/ci_build/github/windows/setup_env_gpu.bat +++ b/tools/ci_build/github/windows/setup_env_gpu.bat @@ -6,10 +6,10 @@ if exist PATH=%AGENT_TEMPDIRECTORY%\v12.2\ ( ) else ( set PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\extras\CUPTI\lib64;%PATH% ) -set PATH=%AGENT_TEMPDIRECTORY%\TensorRT-10.4.0.26.Windows10.x86_64.cuda-12.6\lib;%PATH% +set PATH=%AGENT_TEMPDIRECTORY%\TensorRT-10.6.0.26.Windows10.x86_64.cuda-12.6\lib;%PATH% @REM The default version is still cuda v12.2, because set cuda v11.8 after it -set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\TensorRT-10.4.0.26.Windows10.x86_64.cuda-11.8\lib +set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\TensorRT-10.6.0.26.Windows10.x86_64.cuda-11.8\lib if exist PATH=%AGENT_TEMPDIRECTORY%\v11.8\ ( set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\v11.8\bin;%AGENT_TEMPDIRECTORY%\v11.8\extras\CUPTI\lib64 ) else ( diff --git a/tools/ci_build/github/windows/setup_env_trt.bat b/tools/ci_build/github/windows/setup_env_trt.bat index 4f2272e306570..03734293be5c4 100644 --- a/tools/ci_build/github/windows/setup_env_trt.bat +++ b/tools/ci_build/github/windows/setup_env_trt.bat @@ -6,6 +6,6 @@ if exist PATH=%AGENT_TEMPDIRECTORY%\v12.2\ ( ) else ( set PATH=%PATH%;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\extras\CUPTI\lib64 ) -set PATH=%AGENT_TEMPDIRECTORY%\TensorRT-10.4.0.26.Windows10.x86_64.cuda-12.6\lib;%PATH% +set PATH=%AGENT_TEMPDIRECTORY%\TensorRT-10.6.0.26.Windows10.x86_64.cuda-12.6\lib;%PATH% set GRADLE_OPTS=-Dorg.gradle.daemon=false set CUDA_MODULE_LOADING=LAZY diff --git a/tools/ci_build/requirements/transformers-test/requirements.txt b/tools/ci_build/requirements/transformers-test/requirements.txt index 32c5ce7dd08d1..14aeff3df9c62 100644 --- a/tools/ci_build/requirements/transformers-test/requirements.txt +++ b/tools/ci_build/requirements/transformers-test/requirements.txt @@ -1,11 +1,12 @@ -# packages used by transformers python unittest (only enabled in Linux CPU CI Pipeline) +# packages used by transformers python unittest packaging -protobuf==3.20.2 -numpy==1.24.0 ; python_version < '3.12' -numpy==1.26.0 ; python_version >= '3.12' +# protobuf and numpy is same as tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt +protobuf==4.21.12 +numpy==1.21.6 ; python_version < '3.9' +numpy==2.0.0 ; python_version >= '3.9' torch coloredlogs==15.0 -transformers==4.38.0 +transformers==4.46.3 parameterized>=0.8.1 psutil einops diff --git a/tools/ci_build/set-trigger-rules.py b/tools/ci_build/set-trigger-rules.py index 0e9cd514d8aa5..b46d1e2559e46 100644 --- a/tools/ci_build/set-trigger-rules.py +++ b/tools/ci_build/set-trigger-rules.py @@ -30,14 +30,9 @@ "mac-ios-ci-pipeline.yml", "mac-ios-packaging-pipeline.yml", "mac-react-native-ci-pipeline.yml", - "orttraining-linux-ci-pipeline.yml", - "orttraining-linux-gpu-ci-pipeline.yml", - "orttraining-linux-gpu-ortmodule-distributed-test-ci-pipeline.yml", - "orttraining-mac-ci-pipeline.yml", "win-ci-pipeline.yml", "win-gpu-dml-ci-pipeline.yml", "win-gpu-cuda-ci-pipeline.yml", - "win-gpu-training-ci-pipeline.yml", "win-gpu-doc-gen-ci-pipeline.yml", "win-gpu-tensorrt-ci-pipeline.yml", "win-gpu-webgpu-ci-pipeline.yml", diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index 683d7b6be2aa8..11842f34ce45b 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -138,7 +138,7 @@ def parse_arguments(): required=False, default="None", type=str, - choices=["cuda", "dnnl", "openvino", "tensorrt", "snpe", "tvm", "qnn", "None"], + choices=["cuda", "dnnl", "openvino", "tensorrt", "snpe", "qnn", "None"], help="The selected execution provider for this build.", ) parser.add_argument("--sdk_info", required=False, default="", type=str, help="dependency SDK information.") @@ -182,6 +182,8 @@ def generate_description(line_list, package_name): description = "This package contains Linux native shared library artifacts for ONNX Runtime with CUDA." elif "Microsoft.ML.OnnxRuntime.Gpu.Windows" in package_name: description = "This package contains Windows native shared library artifacts for ONNX Runtime with CUDA." + elif "Intel.ML.OnnxRuntime" in package_name: + description = "This package contains native shared library artifacts for ONNX Runtime with OpenVINO." elif "Microsoft.ML.OnnxRuntime" in package_name: # This is a Microsoft.ML.OnnxRuntime.* package description = ( "This package contains native shared library artifacts for all supported platforms of ONNX Runtime." @@ -225,7 +227,7 @@ def add_common_dependencies(xml_text, package_name, version): def generate_dependencies(xml_text, package_name, version): - dml_dependency = '' + dml_dependency = '' if package_name == "Microsoft.AI.MachineLearning": xml_text.append("") @@ -375,13 +377,11 @@ def generate_files(line_list, args): "mklml": "mklml.dll", "openmp": "libiomp5md.dll", "dnnl": "dnnl.dll", - "tvm": "tvm.dll", "providers_shared_lib": "onnxruntime_providers_shared.dll", "dnnl_ep_shared_lib": "onnxruntime_providers_dnnl.dll", "tensorrt_ep_shared_lib": "onnxruntime_providers_tensorrt.dll", "openvino_ep_shared_lib": "onnxruntime_providers_openvino.dll", "cuda_ep_shared_lib": "onnxruntime_providers_cuda.dll", - "tvm_ep_shared_lib": "onnxruntime_providers_tvm.lib", "onnxruntime_perf_test": "onnxruntime_perf_test.exe", "onnx_test_runner": "onnx_test_runner.exe", } @@ -394,7 +394,6 @@ def generate_files(line_list, args): "mklml_1": "libmklml_gnu.so", "openmp": "libiomp5.so", "dnnl": "libdnnl.so.1", - "tvm": "libtvm.so.0.5.1", "providers_shared_lib": "libonnxruntime_providers_shared.so", "dnnl_ep_shared_lib": "libonnxruntime_providers_dnnl.so", "tensorrt_ep_shared_lib": "libonnxruntime_providers_tensorrt.so", @@ -456,14 +455,6 @@ def generate_files(line_list, args): + '" target="build\\native\\include" />' ) - if args.execution_provider == "tvm": - files_list.append( - "' - ) - if args.execution_provider == "openvino": files_list.append( "' ) - if args.execution_provider == "tvm": - files_list.append( - "' - ) - files_list.append( - "' - ) - - tvm_build_path = os.path.join(args.ort_build_path, args.build_config, "_deps", "tvm-build") - if is_windows(): - files_list.append( - "' - ) - else: - # TODO(agladyshev): Add support for Linux. - raise RuntimeError("Now only Windows is supported for TVM EP.") - if args.execution_provider == "rocm" or is_rocm_gpu_package and not is_ado_packaging_build: files_list.append( "' ) + if is_windows(): + dll_list_path = os.path.join(openvino_path, "runtime\\bin\\intel64\\Release\\") + tbb_list_path = os.path.join(openvino_path, "runtime\\3rdparty\\tbb\\bin\\") + for dll_element in os.listdir(dll_list_path): + if dll_element.endswith("dll"): + files_list.append( + "' + ) + for tbb_element in os.listdir(tbb_list_path): + if tbb_element.endswith("dll"): + files_list.append( + "' + ) + if args.execution_provider == "cuda" or is_cuda_gpu_win_sub_package and not is_ado_packaging_build: files_list.append( "" ) - # Process tvm dependency - if os.path.exists(os.path.join(args.native_build_path, nuget_dependencies["tvm"])): - files_list.append( - "" - ) - # Some tools to be packaged in nightly debug build only, should not be released # These are copied to the runtimes folder for convenience of loading with the dlls # NOTE: nuget gives a spurious error on linux if these aren't in a separate directory to the library so diff --git a/tools/python/upload_and_run_browserstack_tests.py b/tools/python/upload_and_run_browserstack_tests.py index 8751368e1b2fc..a4da87e4fe435 100644 --- a/tools/python/upload_and_run_browserstack_tests.py +++ b/tools/python/upload_and_run_browserstack_tests.py @@ -29,13 +29,16 @@ def upload_apk_parse_json(post_url, apk_path, id, token): return response_to_json(response) -def browserstack_build_request(devices, app_url, test_suite_url, test_platform, id, token): +def browserstack_build_request(devices, app_url, test_suite_url, test_platform, id, token, project, build_tag): headers = {} json_data = { "devices": devices, "app": app_url, "testSuite": test_suite_url, + "project": project, + "buildTag": build_tag, + "deviceLogs": True, } build_response = requests.post( @@ -78,22 +81,24 @@ def build_query_loop(build_id, test_platform, id, token): "--test_platform", type=str, help="Testing platform", choices=["espresso", "xcuitest"], required=True ) parser.add_argument( - "--app_apk_path", + "--app_path", type=Path, help=( - "Path to the app APK. " - "Typically, the app APK is in " + "Path to the app file. " + "For Android, typically, the app file (the APK) is in " "{build_output_dir}/android_test/android/app/build/outputs/apk/debug/app-debug.apk" + ". For iOS, you will have to build an IPA file from the test app, which is built from the .xcarchive path" ), required=True, ) parser.add_argument( - "--test_apk_path", + "--test_path", type=Path, help=( - "Path to the test APK. " + "Path to the test suite file. " "Typically, the test APK is in " "{build_output_dir}/android_test/android/app/build/outputs/apk/androidTest/debug/app-debug-androidTest.apk" + ". For iOS, you will have to create a .zip of the tests. After manually building the tests, the tests that you need to zip will be in {{Xcode DerivedData Folder Path}}/Build/Products" ), required=True, ) @@ -102,10 +107,17 @@ def build_query_loop(build_id, test_platform, id, token): type=str, nargs="+", help="List of devices to run the tests on. For more info, " - "see https://www.browserstack.com/docs/app-automate/espresso/specify-devices", + "see https://www.browserstack.com/docs/app-automate/espresso/specify-devices (Android) or https://www.browserstack.com/docs/app-automate/xcuitest/specify-devices (iOS)", required=True, ) + parser.add_argument( + "--project", + type=str, + help="Identifier to logically group multiple builds together", + default="ONNXRuntime tests", + ) + parser.add_argument("--build_tag", type=str, help="Identifier to tag the build with a unique name", default="") args = parser.parse_args() try: @@ -121,13 +133,13 @@ def build_query_loop(build_id, test_platform, id, token): # Upload the app and test suites upload_app_json = upload_apk_parse_json( f"https://api-cloud.browserstack.com/app-automate/{args.test_platform}/v2/app", - args.app_apk_path, + args.app_path, browserstack_id, browserstack_token, ) upload_test_json = upload_apk_parse_json( f"https://api-cloud.browserstack.com/app-automate/{args.test_platform}/v2/test-suite", - args.test_apk_path, + args.test_path, browserstack_id, browserstack_token, ) @@ -140,6 +152,8 @@ def build_query_loop(build_id, test_platform, id, token): args.test_platform, browserstack_id, browserstack_token, + args.project, + args.build_tag, ) # Get build status until the tests are no longer running diff --git a/tools/python/util/android/android.py b/tools/python/util/android/android.py index dd2dcce01bf4a..24004d6be761d 100644 --- a/tools/python/util/android/android.py +++ b/tools/python/util/android/android.py @@ -4,6 +4,7 @@ import collections import contextlib import datetime +import os import signal import subprocess import time @@ -105,8 +106,15 @@ def _stop_process_with_pid(pid: int): def start_emulator( - sdk_tool_paths: SdkToolPaths, avd_name: str, extra_args: typing.Optional[typing.Sequence[str]] = None + sdk_tool_paths: SdkToolPaths, + avd_name: str, + extra_args: typing.Optional[typing.Sequence[str]] = None, + timeout_minutes: int = 20, ) -> subprocess.Popen: + if check_emulator_running_using_avd_name(avd_name=avd_name): + raise RuntimeError( + f"An emulator with avd_name{avd_name} is already running. Please close it before starting a new one." + ) with contextlib.ExitStack() as emulator_stack, contextlib.ExitStack() as waiter_stack: emulator_args = [ sdk_tool_paths.emulator, @@ -122,6 +130,7 @@ def start_emulator( "-gpu", "guest", "-delay-adb", + "-verbose", ] # For Linux CIs we must use "-no-window" otherwise you'll get @@ -155,9 +164,9 @@ def start_emulator( waiter_stack.callback(_stop_process, waiter_process) # poll subprocesses. - # allow 20 minutes for startup as some CIs are slow. TODO: Make timeout configurable if needed. + # allow 20 minutes for startup as some CIs are slow. sleep_interval_seconds = 10 - end_time = datetime.datetime.now() + datetime.timedelta(minutes=20) + end_time = datetime.datetime.now() + datetime.timedelta(minutes=timeout_minutes) while True: waiter_ret, emulator_ret = waiter_process.poll(), emulator_process.poll() @@ -205,13 +214,127 @@ def start_emulator( _log.debug(f"sys.boot_completed='{getprop_value}'. Sleeping for {sleep_interval_seconds} before retrying.") time.sleep(sleep_interval_seconds) + # Verify if the emulator is now running + if not check_emulator_running_using_avd_name(avd_name=avd_name): + raise RuntimeError("Emulator failed to start.") return emulator_process -def stop_emulator(emulator_proc_or_pid: typing.Union[subprocess.Popen, int]): +def check_emulator_running_using_avd_name(avd_name: str) -> bool: + """ + Check if an emulator is running based on the provided AVD name. + :param avd_name: Name of the Android Virtual Device (AVD) to check. + :return: True if an emulator with the given AVD name is running, False otherwise. + """ + try: + # Step 1: List running devices + result = subprocess.check_output(["adb", "devices"], text=True).strip() + _log.info(f"adb devices output:\n{result}") + running_emulators = [line.split("\t")[0] for line in result.splitlines()[1:] if "emulator" in line] + + if not running_emulators: + _log.debug("No emulators running.") + return False # No emulators running + + # Step 2: Check each running emulator's AVD name + for emulator in running_emulators: + try: + avd_info = ( + subprocess.check_output(["adb", "-s", emulator, "emu", "avd", "name"], text=True) + .strip() + .split("\n")[0] + ) + _log.debug(f"AVD name for emulator {emulator}: {avd_info}") + if avd_info == avd_name: + return True + except subprocess.SubprocessError: + _log.warning(f"Error checking AVD name for emulator: {emulator}") + continue # Skip if there's an issue querying a specific emulator + + _log.warning(f"No emulator running with AVD name: {avd_name}") + return False # No matching AVD name found + except subprocess.SubprocessError as e: + _log.warning(f"Error checking emulator status: {e}") + return False + + +def check_emulator_running_using_process(emulator_proc: subprocess.Popen) -> bool: + """Check if the emulator process is running based on a Popen instance.""" + return emulator_proc.poll() is None + + +def check_emulator_running_using_pid(emulator_pid: int) -> bool: + """Check if the emulator process is running based on PID.""" + try: + os.kill(emulator_pid, 0) # Signal 0 checks process existence + return True + except OSError: + return False + + +def stop_emulator_by_proc(emulator_proc: subprocess.Popen, timeout_seconds: int = 120): + """ + Stops the emulator process using a subprocess.Popen instance. + :param emulator_proc: The emulator process as a subprocess.Popen instance. + :param timeout_seconds: Maximum time (in seconds) to wait for the emulator to stop. + """ + if not check_emulator_running_using_process(emulator_proc): + _log.warning("The specified emulator process is not running.") + return + + _log.info("Stopping emulator using subprocess.Popen instance.") + _stop_process(emulator_proc) + + # Wait for the process to stop + interval = 5 + end_time = datetime.datetime.now() + datetime.timedelta(seconds=timeout_seconds) + + while check_emulator_running_using_process(emulator_proc): + if datetime.datetime.now() > end_time: + raise RuntimeError(f"Failed to stop the emulator within the specified timeout = {timeout_seconds} seconds.") + _log.debug("Emulator still running. Checking again in 5 seconds...") + time.sleep(interval) + + _log.info("Emulator stopped successfully.") + + +def stop_emulator_by_pid(emulator_pid: int, timeout_seconds: int = 120): + """ + Stops the emulator process using a PID. + :param emulator_pid: The emulator process PID. + :param timeout_seconds: Maximum time (in seconds) to wait for the emulator to stop. + """ + if not check_emulator_running_using_pid(emulator_pid): + _log.warning(f"No emulator process with PID {emulator_pid} is currently running.") + return + + _log.info(f"Stopping emulator with PID: {emulator_pid}") + _stop_process_with_pid(emulator_pid) + + # Wait for the process to stop + interval = 5 + end_time = datetime.datetime.now() + datetime.timedelta(seconds=timeout_seconds) + + while check_emulator_running_using_pid(emulator_pid): + if datetime.datetime.now() > end_time: + raise RuntimeError( + f"Failed to stop the emulator with PID {emulator_pid} within the specified timeout = {timeout_seconds} seconds." + ) + _log.debug("Emulator still running. Checking again in 5 seconds...") + time.sleep(interval) + + _log.info("Emulator stopped successfully.") + + +def stop_emulator(emulator_proc_or_pid: typing.Union[subprocess.Popen, int], timeout_seconds: int = 120): + """ + Stops the emulator process, checking its running status before and after stopping. + :param emulator_proc_or_pid: The emulator process (subprocess.Popen) or PID (int). + :param timeout_seconds: Maximum time (in seconds) to wait for the emulator to stop. + """ if isinstance(emulator_proc_or_pid, subprocess.Popen): - _stop_process(emulator_proc_or_pid) + stop_emulator_by_proc(emulator_proc_or_pid, timeout_seconds) elif isinstance(emulator_proc_or_pid, int): - _stop_process_with_pid(emulator_proc_or_pid) + stop_emulator_by_pid(emulator_proc_or_pid, timeout_seconds) else: raise ValueError("Expected either a PID or subprocess.Popen instance.") diff --git a/tools/scripts/python_test.sh b/tools/scripts/python_test.sh index 39d9ed432a1dc..53d350cf30611 100755 --- a/tools/scripts/python_test.sh +++ b/tools/scripts/python_test.sh @@ -7,15 +7,12 @@ export build_dir=$2 export config=$3 # it's for manylinux image -export PATH=/opt/python/cp38-cp38/bin:$PATH +export PATH=/opt/python/cp310-cp310/bin:$PATH echo Install Python Deps cp $src_dir/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt $build_dir/requirements.txt python3 -m pip install -r $build_dir/requirements.txt -mkdir -p $build_dir/requirements_torch_cpu/ -cp $src_dir/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_cpu/requirements.txt $build_dir/requirements_torch_cpu/requirements.txt -python3 -m pip install -r $build_dir/requirements_torch_cpu/requirements.txt python3 -m pip list | grep onnx echo Install $config python package @@ -23,6 +20,5 @@ rm -rf $build_dir/$config/onnxruntime $build_dir/$config/pybind11 python3 -m pip install $build_dir/$config/dist/*.whl echo Run $config unit tests -pushd $build_dir/$config/ -python3 $src_dir/tools/ci_build/build.py --build_dir $build_dir --cmake_generator Ninja --config $config --test --skip_submodule_sync --build_shared_lib --parallel --use_binskim_compliant_compile_flags --build_wheel --enable_onnx_tests --enable_transformers_tool_test --ctest_path "" -popd +cd $build_dir/$config/ +python3 $src_dir/tools/ci_build/build.py --build_dir $build_dir --cmake_generator Ninja --config $config --test --skip_submodule_sync --build_shared_lib --parallel --use_binskim_compliant_compile_flags --build_wheel --enable_onnx_tests --enable_transformers_tool_test diff --git a/tools/scripts/symbolic_shape_infer_test.sh b/tools/scripts/symbolic_shape_infer_test.sh index d8d50c5e3fa91..6717c1d5a9f59 100755 --- a/tools/scripts/symbolic_shape_infer_test.sh +++ b/tools/scripts/symbolic_shape_infer_test.sh @@ -5,7 +5,7 @@ set -ex export build_dir=$1 # it's for manylinux image -export PATH=/opt/python/cp38-cp38/bin:$PATH +export PATH=/opt/python/cp310-cp310/bin:$PATH echo Run symbolic shape infer test pushd $build_dir/Release/ diff --git a/winml/lib/Api/HardwareCoreEnumerator.cpp b/winml/lib/Api/HardwareCoreEnumerator.cpp index 1763290718a8f..f1272fc1b8626 100644 --- a/winml/lib/Api/HardwareCoreEnumerator.cpp +++ b/winml/lib/Api/HardwareCoreEnumerator.cpp @@ -1,8 +1,8 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. + +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #include "lib/Api/pch/pch.h" - #include "HardwareCoreEnumerator.h" namespace WINMLP { @@ -88,22 +88,33 @@ uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() { #if !defined(_M_ARM64EC) && !defined(_M_ARM64) && !defined(__aarch64__) const int kVendorID_Intel[3] = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI" + bool isIntelSpecifiedPlatform = false; + const int kVendorID_IntelSpecifiedPlatformIDs[3] = { + // ExtendedModel,ExtendedFamily,Family Code, and Model Number + 0xa06a, // MTL + 0xc065, // ARL-H + 0xb065 // ARL-U + }; + int regs_leaf0[4]; - int regs_leaf7[4]; + int regs_leaf1[4]; __cpuid(regs_leaf0, 0); - __cpuid(regs_leaf7, 0x7); + __cpuid(regs_leaf1, 0x1); auto isIntel = (kVendorID_Intel[0] == regs_leaf0[1]) && (kVendorID_Intel[1] == regs_leaf0[2]) && (kVendorID_Intel[2] == regs_leaf0[3]); - auto isHybrid = (regs_leaf7[3] & (1 << 15)); + for (int intelSpecifiedPlatform : kVendorID_IntelSpecifiedPlatformIDs) { + if ((regs_leaf1[0] >> 4) == intelSpecifiedPlatform) { + isIntelSpecifiedPlatform = true; + } + } - if (isIntel && isHybrid) { + if (isIntel && isIntelSpecifiedPlatform) { // We want to use the number of physical cores, but exclude cores without an LLC return cores.LLCCores; } #endif - return cores.PhysicalCores; }