diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml
index 2587e0fe3..a0682abe7 100644
--- a/.github/workflows/documentation.yml
+++ b/.github/workflows/documentation.yml
@@ -10,7 +10,6 @@ jobs:
build:
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
with:
- repo_owner: xenova
commit_sha: ${{ github.sha }}
package: transformers.js
path_to_docs: transformers.js/docs/source
diff --git a/.github/workflows/pr-documentation.yml b/.github/workflows/pr-documentation.yml
index 5ac60b4fb..0e6415b4d 100644
--- a/.github/workflows/pr-documentation.yml
+++ b/.github/workflows/pr-documentation.yml
@@ -11,7 +11,6 @@ jobs:
build:
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
with:
- repo_owner: xenova
commit_sha: ${{ github.sha }}
pr_number: ${{ github.event.number }}
package: transformers.js
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 11427c54b..3b87f8b39 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -7,17 +7,20 @@ on:
pull_request:
branches:
- main
-
-env:
- TESTING_REMOTELY: true
+ types:
+ - opened
+ - reopened
+ - synchronize
+ - ready_for_review
jobs:
build:
+ if: github.event.pull_request.draft == false
runs-on: ubuntu-latest
strategy:
matrix:
- node-version: [18.x, latest, node]
+ node-version: [18, 20, 22]
steps:
- uses: actions/checkout@v4
@@ -27,11 +30,9 @@ jobs:
node-version: ${{ matrix.node-version }}
- run: npm ci
- run: npm run build
- - run: pip install -r tests/requirements.txt
# Setup the testing environment
- - run: npm run generate-tests
- - run: git lfs install && GIT_CLONE_PROTECTION_ACTIVE=false git clone https://huggingface.co/Xenova/t5-small ./models/t5-small
+ - run: git lfs install && GIT_CLONE_PROTECTION_ACTIVE=false git clone https://huggingface.co/hf-internal-testing/tiny-random-T5ForConditionalGeneration ./models/hf-internal-testing/tiny-random-T5ForConditionalGeneration
# Actually run tests
- run: npm run test
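The updated workflow clones a tiny random T5 checkpoint into `./models/` so the suite can run offline against small, fast weights. A minimal sketch of a test exercising that local copy might look like the following (the repository's actual tests may be structured differently; the task name and logging are illustrative):

```javascript
import { env, pipeline } from '@huggingface/transformers';

// Resolve models from the directory populated by the `git clone` step above,
// and never fall back to downloading from the Hub during CI.
env.localModelPath = './models/';
env.allowRemoteModels = false;

const generator = await pipeline(
    'text2text-generation',
    'hf-internal-testing/tiny-random-T5ForConditionalGeneration',
);
const [result] = await generator('translate English to German: hello');
console.log(result.generated_text); // random weights produce gibberish, but the run is fast
```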
diff --git a/.prettierignore b/.prettierignore
new file mode 100644
index 000000000..bd1927ab2
--- /dev/null
+++ b/.prettierignore
@@ -0,0 +1,8 @@
+# Ignore artifacts:
+.github
+dist
+docs
+examples
+scripts
+types
+*.md
diff --git a/.prettierrc b/.prettierrc
new file mode 100644
index 000000000..57d5ce89a
--- /dev/null
+++ b/.prettierrc
@@ -0,0 +1,10 @@
+{
+ "overrides": [
+ {
+ "files": ["tests/**/*.js"],
+ "options": {
+ "printWidth": 10000000
+ }
+ }
+ ]
+}
diff --git a/README.md b/README.md
index 52e449516..49776b05d 100644
--- a/README.md
+++ b/README.md
@@ -3,19 +3,29 @@
-
-
-
+
+
+
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -23,9 +33,9 @@ State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in
Transformers.js is designed to be functionally equivalent to Hugging Face's [transformers](https://github.com/huggingface/transformers) python library, meaning you can run the same pretrained models using a very similar API. These models support common tasks in different modalities, such as:
- 📝 **Natural Language Processing**: text classification, named entity recognition, question answering, language modeling, summarization, translation, multiple choice, and text generation.
- - 🖼️ **Computer Vision**: image classification, object detection, and segmentation.
- - 🗣️ **Audio**: automatic speech recognition and audio classification.
- - 🐙 **Multimodal**: zero-shot image classification.
+ - 🖼️ **Computer Vision**: image classification, object detection, segmentation, and depth estimation.
+ - 🗣️ **Audio**: automatic speech recognition, audio classification, and text-to-speech.
+ - 🐙 **Multimodal**: embeddings, zero-shot audio classification, zero-shot image classification, and zero-shot object detection.
Transformers.js uses [ONNX Runtime](https://onnxruntime.ai/) to run models in the browser. The best part is that you can easily [convert](#convert-your-models-to-onnx) your pretrained PyTorch, TensorFlow, or JAX models to ONNX using [🤗 Optimum](https://github.com/huggingface/optimum#onnx--onnx-runtime).
@@ -59,7 +69,7 @@ out = pipe('I love transformers!')
```javascript
-import { pipeline } from '@xenova/transformers';
+import { pipeline } from '@huggingface/transformers';
// Allocate a pipeline for sentiment-analysis
let pipe = await pipeline('sentiment-analysis');
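// The rest of the quick tour is unchanged by the rename; sketch mirroring the Python example above (the printed score is illustrative):
// Run the model on some text
let out = await pipe('I love transformers!');
// e.g. [{ label: 'POSITIVE', score: 0.9998 }]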
@@ -83,15 +93,15 @@ let pipe = await pipeline('sentiment-analysis', 'Xenova/bert-base-multilingual-u
## Installation
-To install via [NPM](https://www.npmjs.com/package/@xenova/transformers), run:
+To install via [NPM](https://www.npmjs.com/package/@huggingface/transformers), run:
```bash
-npm i @xenova/transformers
+npm i @huggingface/transformers
```
Alternatively, you can use it in vanilla JS, without any bundler, by using a CDN or static hosting. For example, using [ES Modules](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Modules), you can import the library with:
```html
```
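After the rename, that HTML snippet is the usual ES-module import from a CDN; a sketch under the assumption that the jsDelivr URL pattern stays the same (version pin omitted here):

```javascript
// Place inside a <script type="module"> tag on the page:
import { pipeline } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers';

const classifier = await pipeline('sentiment-analysis');
console.log(await classifier('CDN imports need no build step.'));
```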
@@ -104,18 +114,18 @@ Want to jump straight in? Get started with one of our sample applications/templa
|-------------------|----------------------------------|-------------------------------|
| Whisper Web | Speech recognition w/ Whisper | [code](https://github.com/xenova/whisper-web), [demo](https://huggingface.co/spaces/Xenova/whisper-web) |
| Doodle Dash | Real-time sketch-recognition game | [blog](https://huggingface.co/blog/ml-web-games), [code](https://github.com/xenova/doodle-dash), [demo](https://huggingface.co/spaces/Xenova/doodle-dash) |
-| Code Playground | In-browser code completion website | [code](https://github.com/xenova/transformers.js/tree/main/examples/code-completion/), [demo](https://huggingface.co/spaces/Xenova/ai-code-playground) |
-| Semantic Image Search (client-side) | Search for images with text | [code](https://github.com/xenova/transformers.js/tree/main/examples/semantic-image-search-client/), [demo](https://huggingface.co/spaces/Xenova/semantic-image-search-client) |
-| Semantic Image Search (server-side) | Search for images with text (Supabase) | [code](https://github.com/xenova/transformers.js/tree/main/examples/semantic-image-search/), [demo](https://huggingface.co/spaces/Xenova/semantic-image-search) |
-| Vanilla JavaScript | In-browser object detection | [video](https://scrimba.com/scrim/cKm9bDAg), [code](https://github.com/xenova/transformers.js/tree/main/examples/vanilla-js/), [demo](https://huggingface.co/spaces/Scrimba/vanilla-js-object-detector) |
-| React | Multilingual translation website | [code](https://github.com/xenova/transformers.js/tree/main/examples/react-translator/), [demo](https://huggingface.co/spaces/Xenova/react-translator) |
-| Text to speech (client-side) | In-browser speech synthesis | [code](https://github.com/xenova/transformers.js/tree/main/examples/text-to-speech-client/), [demo](https://huggingface.co/spaces/Xenova/text-to-speech-client) |
-| Browser extension | Text classification extension | [code](https://github.com/xenova/transformers.js/tree/main/examples/extension/) |
-| Electron | Text classification application | [code](https://github.com/xenova/transformers.js/tree/main/examples/electron/) |
-| Next.js (client-side) | Sentiment analysis (in-browser inference) | [code](https://github.com/xenova/transformers.js/tree/main/examples/next-client/), [demo](https://huggingface.co/spaces/Xenova/next-example-app) |
-| Next.js (server-side) | Sentiment analysis (Node.js inference) | [code](https://github.com/xenova/transformers.js/tree/main/examples/next-server/), [demo](https://huggingface.co/spaces/Xenova/next-server-example-app) |
-| Node.js | Sentiment analysis API | [code](https://github.com/xenova/transformers.js/tree/main/examples/node/) |
-| Demo site | A collection of demos | [code](https://github.com/xenova/transformers.js/tree/main/examples/demo-site/), [demo](https://xenova.github.io/transformers.js/) |
+| Code Playground | In-browser code completion website | [code](https://github.com/huggingface/transformers.js/tree/main/examples/code-completion/), [demo](https://huggingface.co/spaces/Xenova/ai-code-playground) |
+| Semantic Image Search (client-side) | Search for images with text | [code](https://github.com/huggingface/transformers.js/tree/main/examples/semantic-image-search-client/), [demo](https://huggingface.co/spaces/Xenova/semantic-image-search-client) |
+| Semantic Image Search (server-side) | Search for images with text (Supabase) | [code](https://github.com/huggingface/transformers.js/tree/main/examples/semantic-image-search/), [demo](https://huggingface.co/spaces/Xenova/semantic-image-search) |
+| Vanilla JavaScript | In-browser object detection | [video](https://scrimba.com/scrim/cKm9bDAg), [code](https://github.com/huggingface/transformers.js/tree/main/examples/vanilla-js/), [demo](https://huggingface.co/spaces/Scrimba/vanilla-js-object-detector) |
+| React | Multilingual translation website | [code](https://github.com/huggingface/transformers.js/tree/main/examples/react-translator/), [demo](https://huggingface.co/spaces/Xenova/react-translator) |
+| Text to speech (client-side) | In-browser speech synthesis | [code](https://github.com/huggingface/transformers.js/tree/main/examples/text-to-speech-client/), [demo](https://huggingface.co/spaces/Xenova/text-to-speech-client) |
+| Browser extension | Text classification extension | [code](https://github.com/huggingface/transformers.js/tree/main/examples/extension/) |
+| Electron | Text classification application | [code](https://github.com/huggingface/transformers.js/tree/main/examples/electron/) |
+| Next.js (client-side) | Sentiment analysis (in-browser inference) | [code](https://github.com/huggingface/transformers.js/tree/main/examples/next-client/), [demo](https://huggingface.co/spaces/Xenova/next-example-app) |
+| Next.js (server-side) | Sentiment analysis (Node.js inference) | [code](https://github.com/huggingface/transformers.js/tree/main/examples/next-server/), [demo](https://huggingface.co/spaces/Xenova/next-server-example-app) |
+| Node.js | Sentiment analysis API | [code](https://github.com/huggingface/transformers.js/tree/main/examples/node/) |
+| Demo site | A collection of demos | [code](https://github.com/huggingface/transformers.js/tree/main/examples/demo-site/), [demo](https://xenova.github.io/transformers.js/) |
Check out the Transformers.js [template](https://huggingface.co/new-space?template=static-templates%2Ftransformers.js) on Hugging Face to get started in one click!
@@ -124,13 +134,12 @@ Check out the Transformers.js [template](https://huggingface.co/new-space?templa
-By default, Transformers.js uses [hosted pretrained models](https://huggingface.co/models?library=transformers.js) and [precompiled WASM binaries](https://cdn.jsdelivr.net/npm/@xenova/transformers@2.17.2/dist/), which should work out-of-the-box. You can customize this as follows:
-
+By default, Transformers.js uses [hosted pretrained models](https://huggingface.co/models?library=transformers.js) and [precompiled WASM binaries](https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.0/dist/), which should work out-of-the-box. You can customize this as follows:
### Settings
```javascript
-import { env } from '@xenova/transformers';
+import { env } from '@huggingface/transformers';
// Specify a custom location for models (defaults to '/models/').
env.localModelPath = '/path/to/models/';
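// Continuation sketch: the other settings the README shows alongside localModelPath
// (these fields exist on `env`; exact comments and defaults may differ).
// Disable loading of remote models from the Hugging Face Hub:
env.allowRemoteModels = false;

// Set the location of the .wasm files (defaults to a CDN):
env.backends.onnx.wasm.wasmPaths = '/path/to/files/';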
@@ -146,7 +155,7 @@ For a full list of available settings, check out the [API Reference](https://hug
### Convert your models to ONNX
-We recommend using our [conversion script](https://github.com/xenova/transformers.js/blob/main/scripts/convert.py) to convert your PyTorch, TensorFlow, or JAX models to ONNX in a single command. Behind the scenes, it uses [🤗 Optimum](https://huggingface.co/docs/optimum) to perform conversion and quantization of your model.
+We recommend using our [conversion script](https://github.com/huggingface/transformers.js/blob/main/scripts/convert.py) to convert your PyTorch, TensorFlow, or JAX models to ONNX in a single command. Behind the scenes, it uses [🤗 Optimum](https://huggingface.co/docs/optimum) to perform conversion and quantization of your model.
```bash
python -m scripts.convert --quantize --model_id <model_name_or_path>
@@ -176,7 +185,7 @@ For the full list of supported architectures, see the [Optimum documentation](ht
Here is the list of all tasks and architectures currently supported by Transformers.js.
If you don't see your task/model listed here or it is not yet supported, feel free
-to open up a feature request [here](https://github.com/xenova/transformers.js/issues/new/choose).
+to open up a feature request [here](https://github.com/huggingface/transformers.js/issues/new/choose).
To find compatible models on the Hub, select the "transformers.js" library tag in the filter menu (or visit [this link](https://huggingface.co/models?library=transformers.js)).
You can refine your search by selecting the task you're interested in (e.g., [text-classification](https://huggingface.co/models?pipeline_tag=text-classification&library=transformers.js)).
@@ -271,6 +280,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
1. **[CLIPSeg](https://huggingface.co/docs/transformers/model_doc/clipseg)** (from University of Göttingen) released with the paper [Image Segmentation Using Text and Image Prompts](https://arxiv.org/abs/2112.10003) by Timo Lüddecke and Alexander Ecker.
1. **[CodeGen](https://huggingface.co/docs/transformers/model_doc/codegen)** (from Salesforce) released with the paper [A Conversational Paradigm for Program Synthesis](https://arxiv.org/abs/2203.13474) by Erik Nijkamp, Bo Pang, Hiroaki Hayashi, Lifu Tu, Huan Wang, Yingbo Zhou, Silvio Savarese, Caiming Xiong.
1. **[CodeLlama](https://huggingface.co/docs/transformers/model_doc/llama_code)** (from MetaAI) released with the paper [Code Llama: Open Foundation Models for Code](https://ai.meta.com/research/publications/code-llama-open-foundation-models-for-code/) by Baptiste Rozière, Jonas Gehring, Fabian Gloeckle, Sten Sootla, Itai Gat, Xiaoqing Ellen Tan, Yossi Adi, Jingyu Liu, Tal Remez, Jérémy Rapin, Artyom Kozhevnikov, Ivan Evtimov, Joanna Bitton, Manish Bhatt, Cristian Canton Ferrer, Aaron Grattafiori, Wenhan Xiong, Alexandre Défossez, Jade Copet, Faisal Azhar, Hugo Touvron, Louis Martin, Nicolas Usunier, Thomas Scialom, Gabriel Synnaeve.
+1. **[Cohere](https://huggingface.co/docs/transformers/main/model_doc/cohere)** (from Cohere) released with the paper *Command-R: Retrieval Augmented Generation at Production Scale* by Cohere.
1. **[ConvBERT](https://huggingface.co/docs/transformers/model_doc/convbert)** (from YituTech) released with the paper [ConvBERT: Improving BERT with Span-based Dynamic Convolution](https://arxiv.org/abs/2008.02496) by Zihang Jiang, Weihao Yu, Daquan Zhou, Yunpeng Chen, Jiashi Feng, Shuicheng Yan.
1. **[ConvNeXT](https://huggingface.co/docs/transformers/model_doc/convnext)** (from Facebook AI) released with the paper [A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) by Zhuang Liu, Hanzi Mao, Chao-Yuan Wu, Christoph Feichtenhofer, Trevor Darrell, Saining Xie.
1. **[ConvNeXTV2](https://huggingface.co/docs/transformers/model_doc/convnextv2)** (from Facebook AI) released with the paper [ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808) by Sanghyun Woo, Shoubhik Debnath, Ronghang Hu, Xinlei Chen, Zhuang Liu, In So Kweon, Saining Xie.
@@ -279,6 +289,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
1. **[Decision Transformer](https://huggingface.co/docs/transformers/model_doc/decision_transformer)** (from Berkeley/Facebook/Google) released with the paper [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345) by Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas, Igor Mordatch.
1. **[DeiT](https://huggingface.co/docs/transformers/model_doc/deit)** (from Facebook) released with the paper [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877) by Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, Hervé Jégou.
1. **[Depth Anything](https://huggingface.co/docs/transformers/main/model_doc/depth_anything)** (from University of Hong Kong and TikTok) released with the paper [Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data](https://arxiv.org/abs/2401.10891) by Lihe Yang, Bingyi Kang, Zilong Huang, Xiaogang Xu, Jiashi Feng, Hengshuang Zhao.
+1. **Depth Pro** (from Apple) released with the paper [Depth Pro: Sharp Monocular Metric Depth in Less Than a Second](https://arxiv.org/abs/2410.02073) by Aleksei Bochkovskii, Amaël Delaunoy, Hugo Germain, Marcel Santos, Yichao Zhou, Stephan R. Richter, Vladlen Koltun.
1. **[DETR](https://huggingface.co/docs/transformers/model_doc/detr)** (from Facebook) released with the paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko.
1. **[DINOv2](https://huggingface.co/docs/transformers/model_doc/dinov2)** (from Meta AI) released with the paper [DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193) by Maxime Oquab, Timothée Darcet, Théo Moutakanni, Huy Vo, Marc Szafraniec, Vasil Khalidov, Pierre Fernandez, Daniel Haziza, Francisco Massa, Alaaeldin El-Nouby, Mahmoud Assran, Nicolas Ballas, Wojciech Galuba, Russell Howes, Po-Yao Huang, Shang-Wen Li, Ishan Misra, Michael Rabbat, Vasu Sharma, Gabriel Synnaeve, Hu Xu, Hervé Jegou, Julien Mairal, Patrick Labatut, Armand Joulin, Piotr Bojanowski.
1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation) and a German version of DistilBERT.
@@ -291,39 +302,61 @@ You can refine your search by selecting the task you're interested in (e.g., [te
1. **[Falcon](https://huggingface.co/docs/transformers/model_doc/falcon)** (from Technology Innovation Institute) by Almazrouei, Ebtesam and Alobeidli, Hamza and Alshamsi, Abdulaziz and Cappelli, Alessandro and Cojocaru, Ruxandra and Debbah, Merouane and Goffinet, Etienne and Heslow, Daniel and Launay, Julien and Malartic, Quentin and Noune, Badreddine and Pannier, Baptiste and Penedo, Guilherme.
1. **FastViT** (from Apple) released with the paper [FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization](https://arxiv.org/abs/2303.14189) by Pavan Kumar Anasosalu Vasu, James Gabriel, Jeff Zhu, Oncel Tuzel and Anurag Ranjan.
1. **[FLAN-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5)** (from Google AI) released in the repository [google-research/t5x](https://github.com/google-research/t5x/blob/main/docs/models.md#flan-t5-checkpoints) by Hyung Won Chung, Le Hou, Shayne Longpre, Barret Zoph, Yi Tay, William Fedus, Eric Li, Xuezhi Wang, Mostafa Dehghani, Siddhartha Brahma, Albert Webson, Shixiang Shane Gu, Zhuyun Dai, Mirac Suzgun, Xinyun Chen, Aakanksha Chowdhery, Sharan Narang, Gaurav Mishra, Adams Yu, Vincent Zhao, Yanping Huang, Andrew Dai, Hongkun Yu, Slav Petrov, Ed H. Chi, Jeff Dean, Jacob Devlin, Adam Roberts, Denny Zhou, Quoc V. Le, and Jason Wei
+1. **Florence2** (from Microsoft) released with the paper [Florence-2: Advancing a Unified Representation for a Variety of Vision Tasks](https://arxiv.org/abs/2311.06242) by Bin Xiao, Haiping Wu, Weijian Xu, Xiyang Dai, Houdong Hu, Yumao Lu, Michael Zeng, Ce Liu, Lu Yuan.
+1. **[Gemma](https://huggingface.co/docs/transformers/main/model_doc/gemma)** (from Google) released with the paper [Gemma: Open Models Based on Gemini Technology and Research](https://blog.google/technology/developers/gemma-open-models/) by the Gemma Google team.
+1. **[Gemma2](https://huggingface.co/docs/transformers/main/model_doc/gemma2)** (from Google) released with the paper [Gemma2: Open Models Based on Gemini Technology and Research](https://blog.google/technology/developers/google-gemma-2/) by the Gemma Google team.
1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim.
1. **[GPT Neo](https://huggingface.co/docs/transformers/model_doc/gpt_neo)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy.
1. **[GPT NeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox)** (from EleutherAI) released with the paper [GPT-NeoX-20B: An Open-Source Autoregressive Language Model](https://arxiv.org/abs/2204.06745) by Sid Black, Stella Biderman, Eric Hallahan, Quentin Anthony, Leo Gao, Laurence Golding, Horace He, Connor Leahy, Kyle McDonell, Jason Phang, Michael Pieler, USVSN Sai Prashanth, Shivanshu Purohit, Laria Reynolds, Jonathan Tow, Ben Wang, Samuel Weinbach
1. **[GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2)** (from OpenAI) released with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**.
1. **[GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj)** (from EleutherAI) released in the repository [kingoflolz/mesh-transformer-jax](https://github.com/kingoflolz/mesh-transformer-jax/) by Ben Wang and Aran Komatsuzaki.
1. **[GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode)** (from BigCode) released with the paper [SantaCoder: don't reach for the stars!](https://arxiv.org/abs/2301.03988) by Loubna Ben Allal, Raymond Li, Denis Kocetkov, Chenghao Mou, Christopher Akiki, Carlos Munoz Ferrandis, Niklas Muennighoff, Mayank Mishra, Alex Gu, Manan Dey, Logesh Kumar Umapathi, Carolyn Jane Anderson, Yangtian Zi, Joel Lamy Poirier, Hailey Schoelkopf, Sergey Troshin, Dmitry Abulkhanov, Manuel Romero, Michael Lappert, Francesco De Toni, Bernardo García del Río, Qian Liu, Shamik Bose, Urvashi Bhattacharyya, Terry Yue Zhuo, Ian Yu, Paulo Villegas, Marco Zocca, Sourab Mangrulkar, David Lansky, Huu Nguyen, Danish Contractor, Luis Villa, Jia Li, Dzmitry Bahdanau, Yacine Jernite, Sean Hughes, Daniel Fried, Arjun Guha, Harm de Vries, Leandro von Werra.
+1. **[Granite](https://huggingface.co/docs/transformers/main/model_doc/granite)** (from IBM) released with the paper [Power Scheduler: A Batch Size and Token Number Agnostic Learning Rate Scheduler](https://arxiv.org/abs/2408.13359) by Yikang Shen, Matthew Stallone, Mayank Mishra, Gaoyuan Zhang, Shawn Tan, Aditya Prasad, Adriana Meza Soria, David D. Cox, Rameswar Panda.
+1. **[GroupViT](https://huggingface.co/docs/transformers/model_doc/groupvit)** (from UCSD, NVIDIA) released with the paper [GroupViT: Semantic Segmentation Emerges from Text Supervision](https://arxiv.org/abs/2202.11094) by Jiarui Xu, Shalini De Mello, Sifei Liu, Wonmin Byeon, Thomas Breuel, Jan Kautz, Xiaolong Wang.
1. **[HerBERT](https://huggingface.co/docs/transformers/model_doc/herbert)** (from Allegro.pl, AGH University of Science and Technology) released with the paper [KLEJ: Comprehensive Benchmark for Polish Language Understanding](https://www.aclweb.org/anthology/2020.acl-main.111.pdf) by Piotr Rybak, Robert Mroczkowski, Janusz Tracz, Ireneusz Gawlik.
+1. **[Hiera](https://huggingface.co/docs/transformers/model_doc/hiera)** (from Meta) released with the paper [Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles](https://arxiv.org/pdf/2306.00989) by Chaitanya Ryali, Yuan-Ting Hu, Daniel Bolya, Chen Wei, Haoqi Fan, Po-Yao Huang, Vaibhav Aggarwal, Arkabandhu Chowdhury, Omid Poursaeed, Judy Hoffman, Jitendra Malik, Yanghao Li, Christoph Feichtenhofer.
1. **[Hubert](https://huggingface.co/docs/transformers/model_doc/hubert)** (from Facebook) released with the paper [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed.
+1. **JAIS** (from Core42) released with the paper [Jais and Jais-chat: Arabic-Centric Foundation and Instruction-Tuned Open Generative Large Language Models](https://arxiv.org/pdf/2308.16149) by Neha Sengupta, Sunil Kumar Sahu, Bokang Jia, Satheesh Katipomu, Haonan Li, Fajri Koto, William Marshall, Gurpreet Gosal, Cynthia Liu, Zhiming Chen, Osama Mohammed Afzal, Samta Kamboj, Onkar Pandit, Rahul Pal, Lalit Pradhan, Zain Muhammad Mujahid, Massa Baali, Xudong Han, Sondos Mahmoud Bsharat, Alham Fikri Aji, Zhiqiang Shen, Zhengzhong Liu, Natalia Vassilieva, Joel Hestness, Andy Hock, Andrew Feldman, Jonathan Lee, Andrew Jackson, Hector Xuguang Ren, Preslav Nakov, Timothy Baldwin, Eric Xing.
1. **[LongT5](https://huggingface.co/docs/transformers/model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang.
1. **[LLaMA](https://huggingface.co/docs/transformers/model_doc/llama)** (from The FAIR team of Meta AI) released with the paper [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) by Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample.
1. **[Llama2](https://huggingface.co/docs/transformers/model_doc/llama2)** (from The FAIR team of Meta AI) released with the paper [Llama 2: Open Foundation and Fine-Tuned Chat Models](https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/) by Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, Dan Bikel, Lukas Blecher, Cristian Canton Ferrer, Moya Chen, Guillem Cucurull, David Esiobu, Jude Fernandes, Jeremy Fu, Wenyin Fu, Brian Fuller, Cynthia Gao, Vedanuj Goswami, Naman Goyal, Anthony Hartshorn, Saghar Hosseini, Rui Hou, Hakan Inan, Marcin Kardas, Viktor Kerkez, Madian Khabsa, Isabel Kloumann, Artem Korenev, Punit Singh Koura, Marie-Anne Lachaux, Thibaut Lavril, Jenya Lee, Diana Liskovich, Yinghai Lu, Yuning Mao, Xavier Martinet, Todor Mihaylov, Pushkar Mishra, Igor Molybog, Yixin Nie, Andrew Poulton, Jeremy Reizenstein, Rashi Rungta, Kalyan Saladi, Alan Schelten, Ruan Silva, Eric Michael Smith, Ranjan Subramanian, Xiaoqing Ellen Tan, Binh Tang, Ross Taylor, Adina Williams, Jian Xiang Kuan, Puxin Xu, Zheng Yan, Iliyan Zarov, Yuchen Zhang, Angela Fan, Melanie Kambadur, Sharan Narang, Aurelien Rodriguez, Robert Stojnic, Sergey Edunov, Thomas Scialom.
+1. **[LLaVa](https://huggingface.co/docs/transformers/model_doc/llava)** (from Microsoft Research & University of Wisconsin-Madison) released with the paper [Visual Instruction Tuning](https://arxiv.org/abs/2304.08485) by Haotian Liu, Chunyuan Li, Yuheng Li and Yong Jae Lee.
1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin.
1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team.
+1. **[MaskFormer](https://huggingface.co/docs/transformers/model_doc/maskformer)** (from Meta and UIUC) released with the paper [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) by Bowen Cheng, Alexander G. Schwing, Alexander Kirillov.
1. **[mBART](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer.
1. **[mBART-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan.
+1. **[MusicGen](https://huggingface.co/docs/transformers/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
1. **[Mistral](https://huggingface.co/docs/transformers/model_doc/mistral)** (from Mistral AI) by The [Mistral AI](https://mistral.ai) team: Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed.
1. **[MMS](https://huggingface.co/docs/transformers/model_doc/mms)** (from Facebook) released with the paper [Scaling Speech Technology to 1,000+ Languages](https://arxiv.org/abs/2305.13516) by Vineel Pratap, Andros Tjandra, Bowen Shi, Paden Tomasello, Arun Babu, Sayani Kundu, Ali Elkahky, Zhaoheng Ni, Apoorv Vyas, Maryam Fazel-Zarandi, Alexei Baevski, Yossi Adi, Xiaohui Zhang, Wei-Ning Hsu, Alexis Conneau, Michael Auli.
1. **[MobileBERT](https://huggingface.co/docs/transformers/model_doc/mobilebert)** (from CMU/Google Brain) released with the paper [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) by Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou.
+1. **MobileCLIP** (from Apple) released with the paper [MobileCLIP: Fast Image-Text Models through Multi-Modal Reinforced Training](https://arxiv.org/abs/2311.17049) by Pavan Kumar Anasosalu Vasu, Hadi Pouransari, Fartash Faghri, Raviteja Vemulapalli, Oncel Tuzel.
+1. **[MobileNetV1](https://huggingface.co/docs/transformers/model_doc/mobilenet_v1)** (from Google Inc.) released with the paper [MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications](https://arxiv.org/abs/1704.04861) by Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, Tobias Weyand, Marco Andreetto, Hartwig Adam.
+1. **[MobileNetV2](https://huggingface.co/docs/transformers/model_doc/mobilenet_v2)** (from Google Inc.) released with the paper [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381) by Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen.
+1. **MobileNetV3** (from Google Inc.) released with the paper [Searching for MobileNetV3](https://arxiv.org/abs/1905.02244) by Andrew Howard, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang, Yukun Zhu, Ruoming Pang, Vijay Vasudevan, Quoc V. Le, Hartwig Adam.
+1. **MobileNetV4** (from Google Inc.) released with the paper [MobileNetV4 - Universal Models for the Mobile Ecosystem](https://arxiv.org/abs/2404.10518) by Danfeng Qin, Chas Leichner, Manolis Delakis, Marco Fornoni, Shixin Luo, Fan Yang, Weijun Wang, Colby Banbury, Chengxi Ye, Berkin Akin, Vaibhav Aggarwal, Tenghui Zhu, Daniele Moro, Andrew Howard.
1. **[MobileViT](https://huggingface.co/docs/transformers/model_doc/mobilevit)** (from Apple) released with the paper [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer](https://arxiv.org/abs/2110.02178) by Sachin Mehta and Mohammad Rastegari.
1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (from Apple) released with the paper [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) by Sachin Mehta and Mohammad Rastegari.
+1. **Moondream1** released in the repository [moondream](https://github.com/vikhyat/moondream) by vikhyat.
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
1. **[MPT](https://huggingface.co/docs/transformers/model_doc/mpt)** (from MosaicML) released with the repository [llm-foundry](https://github.com/mosaicml/llm-foundry/) by the MosaicML NLP Team.
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team.
1. **[Nougat](https://huggingface.co/docs/transformers/model_doc/nougat)** (from Meta AI) released with the paper [Nougat: Neural Optical Understanding for Academic Documents](https://arxiv.org/abs/2308.13418) by Lukas Blecher, Guillem Cucurull, Thomas Scialom, Robert Stojnic.
+1. **OpenELM** (from Apple) released with the paper [OpenELM: An Efficient Language Model Family with Open-source Training and Inference Framework](https://arxiv.org/abs/2404.14619) by Sachin Mehta, Mohammad Hossein Sekhavat, Qingqing Cao, Maxwell Horton, Yanzi Jin, Chenfan Sun, Iman Mirzadeh, Mahyar Najibi, Dmitry Belenko, Peter Zatloukal, Mohammad Rastegari.
1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
1. **[OWLv2](https://huggingface.co/docs/transformers/model_doc/owlv2)** (from Google AI) released with the paper [Scaling Open-Vocabulary Object Detection](https://arxiv.org/abs/2306.09683) by Matthias Minderer, Alexey Gritsenko, Neil Houlsby.
1. **[Phi](https://huggingface.co/docs/transformers/main/model_doc/phi)** (from Microsoft) released with the papers - [Textbooks Are All You Need](https://arxiv.org/abs/2306.11644) by Suriya Gunasekar, Yi Zhang, Jyoti Aneja, Caio César Teodoro Mendes, Allie Del Giorno, Sivakanth Gopi, Mojan Javaheripi, Piero Kauffmann, Gustavo de Rosa, Olli Saarikivi, Adil Salim, Shital Shah, Harkirat Singh Behl, Xin Wang, Sébastien Bubeck, Ronen Eldan, Adam Tauman Kalai, Yin Tat Lee and Yuanzhi Li, [Textbooks Are All You Need II: phi-1.5 technical report](https://arxiv.org/abs/2309.05463) by Yuanzhi Li, Sébastien Bubeck, Ronen Eldan, Allie Del Giorno, Suriya Gunasekar and Yin Tat Lee.
+1. **[Phi3](https://huggingface.co/docs/transformers/main/model_doc/phi3)** (from Microsoft) released with the paper [Phi-3 Technical Report: A Highly Capable Language Model Locally on Your Phone](https://arxiv.org/abs/2404.14219) by Marah Abdin, Sam Ade Jacobs, Ammar Ahmad Awan, Jyoti Aneja, Ahmed Awadallah, Hany Awadalla, Nguyen Bach, Amit Bahree, Arash Bakhtiari, Harkirat Behl, Alon Benhaim, Misha Bilenko, Johan Bjorck, Sébastien Bubeck, Martin Cai, Caio César Teodoro Mendes, Weizhu Chen, Vishrav Chaudhary, Parul Chopra, Allie Del Giorno, Gustavo de Rosa, Matthew Dixon, Ronen Eldan, Dan Iter, Amit Garg, Abhishek Goswami, Suriya Gunasekar, Emman Haider, Junheng Hao, Russell J. Hewett, Jamie Huynh, Mojan Javaheripi, Xin Jin, Piero Kauffmann, Nikos Karampatziakis, Dongwoo Kim, Mahoud Khademi, Lev Kurilenko, James R. Lee, Yin Tat Lee, Yuanzhi Li, Chen Liang, Weishung Liu, Eric Lin, Zeqi Lin, Piyush Madan, Arindam Mitra, Hardik Modi, Anh Nguyen, Brandon Norick, Barun Patra, Daniel Perez-Becker, Thomas Portet, Reid Pryzant, Heyang Qin, Marko Radmilac, Corby Rosset, Sambudha Roy, Olatunji Ruwase, Olli Saarikivi, Amin Saied, Adil Salim, Michael Santacroce, Shital Shah, Ning Shang, Hiteshi Sharma, Xia Song, Masahiro Tanaka, Xin Wang, Rachel Ward, Guanhua Wang, Philipp Witte, Michael Wyatt, Can Xu, Jiahang Xu, Sonali Yadav, Fan Yang, Ziyi Yang, Donghan Yu, Chengruidong Zhang, Cyril Zhang, Jianwen Zhang, Li Lyna Zhang, Yi Zhang, Yue Zhang, Yunan Zhang, Xiren Zhou.
+1. **[PVT](https://huggingface.co/docs/transformers/main/model_doc/pvt)** (from Nanjing University, The University of Hong Kong etc.) released with the paper [Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions](https://arxiv.org/pdf/2102.12122.pdf) by Wenhai Wang, Enze Xie, Xiang Li, Deng-Ping Fan, Kaitao Song, Ding Liang, Tong Lu, Ping Luo, Ling Shao.
+1. **PyAnnote** released in the repository [pyannote/pyannote-audio](https://github.com/pyannote/pyannote-audio) by Hervé Bredin.
1. **[Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2)** (from the Qwen team, Alibaba Group) released with the paper [Qwen Technical Report](https://arxiv.org/abs/2309.16609) by Jinze Bai, Shuai Bai, Yunfei Chu, Zeyu Cui, Kai Dang, Xiaodong Deng, Yang Fan, Wenbin Ge, Yu Han, Fei Huang, Binyuan Hui, Luo Ji, Mei Li, Junyang Lin, Runji Lin, Dayiheng Liu, Gao Liu, Chengqiang Lu, Keming Lu, Jianxin Ma, Rui Men, Xingzhang Ren, Xuancheng Ren, Chuanqi Tan, Sinan Tan, Jianhong Tu, Peng Wang, Shijie Wang, Wei Wang, Shengguang Wu, Benfeng Xu, Jin Xu, An Yang, Hao Yang, Jian Yang, Shusheng Yang, Yang Yao, Bowen Yu, Hongyi Yuan, Zheng Yuan, Jianwei Zhang, Xingxuan Zhang, Yichang Zhang, Zhenru Zhang, Chang Zhou, Jingren Zhou, Xiaohuan Zhou and Tianhang Zhu.
1. **[ResNet](https://huggingface.co/docs/transformers/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
+1. **[RT-DETR](https://huggingface.co/docs/transformers/model_doc/rt_detr)** (from Baidu), released together with the paper [DETRs Beat YOLOs on Real-time Object Detection](https://arxiv.org/abs/2304.08069) by Yian Zhao, Wenyu Lv, Shangliang Xu, Jinman Wei, Guanzhong Wang, Qingqing Dang, Yi Liu, Jie Chen.
+1. **Sapiens** (from Meta AI) released with the paper [Sapiens: Foundation for Human Vision Models](https://arxiv.org/pdf/2408.12569) by Rawal Khirodkar, Timur Bagautdinov, Julieta Martinez, Su Zhaoen, Austin James, Peter Selednik, Stuart Anderson, Shunsuke Saito.
1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
1. **[Segment Anything](https://huggingface.co/docs/transformers/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick.
1. **[SigLIP](https://huggingface.co/docs/transformers/main/model_doc/siglip)** (from Google AI) released with the paper [Sigmoid Loss for Language Image Pre-Training](https://arxiv.org/abs/2303.15343) by Xiaohua Zhai, Basil Mustafa, Alexander Kolesnikov, Lucas Beyer.
@@ -340,7 +373,9 @@ You can refine your search by selecting the task you're interested in (e.g., [te
1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (from Microsoft Research) released with the paper [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang.
1. **[UniSpeechSat](https://huggingface.co/docs/transformers/model_doc/unispeech-sat)** (from Microsoft Research) released with the paper [UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu.
1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
+1. **[ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae)** (from Meta AI) released with the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) by Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick.
1. **[ViTMatte](https://huggingface.co/docs/transformers/model_doc/vitmatte)** (from HUST-VL) released with the paper [ViTMatte: Boosting Image Matting with Pretrained Plain Vision Transformers](https://arxiv.org/abs/2305.15272) by Jingfeng Yao, Xinggang Wang, Shusheng Yang, Baoyuan Wang.
+1. **[ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn)** (from Meta AI) released with the paper [Masked Siamese Networks for Label-Efficient Learning](https://arxiv.org/abs/2204.07141) by Mahmoud Assran, Mathilde Caron, Ishan Misra, Piotr Bojanowski, Florian Bordes, Pascal Vincent, Armand Joulin, Michael Rabbat, Nicolas Ballas.
1. **[VITS](https://huggingface.co/docs/transformers/model_doc/vits)** (from Kakao Enterprise) released with the paper [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103) by Jaehyeon Kim, Jungil Kong, Juhee Son.
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
1. **[Wav2Vec2-BERT](https://huggingface.co/docs/transformers/main/model_doc/wav2vec2-bert)** (from Meta AI) released with the paper [Seamless: Multilingual Expressive and Streaming Speech Translation](https://ai.meta.com/research/publications/seamless-multilingual-expressive-and-streaming-speech-translation/) by the Seamless Communication team.
diff --git a/docs/scripts/build_readme.py b/docs/scripts/build_readme.py
index 44faf1a77..611c5b3f6 100644
--- a/docs/scripts/build_readme.py
+++ b/docs/scripts/build_readme.py
@@ -5,19 +5,29 @@
-
-
-
+
+
+
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
{intro}
@@ -42,7 +52,7 @@
Here is the list of all tasks and architectures currently supported by Transformers.js.
If you don't see your task/model listed here or it is not yet supported, feel free
-to open up a feature request [here](https://github.com/xenova/transformers.js/issues/new/choose).
+to open up a feature request [here](https://github.com/huggingface/transformers.js/issues/new/choose).
To find compatible models on the Hub, select the "transformers.js" library tag in the filter menu (or visit [this link](https://huggingface.co/models?library=transformers.js)).
You can refine your search by selecting the task you're interested in (e.g., [text-classification](https://huggingface.co/models?pipeline_tag=text-classification&library=transformers.js)).
diff --git a/docs/snippets/0_introduction.snippet b/docs/snippets/0_introduction.snippet
index a0ede3821..d25a0e513 100644
--- a/docs/snippets/0_introduction.snippet
+++ b/docs/snippets/0_introduction.snippet
@@ -3,9 +3,9 @@ State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in
Transformers.js is designed to be functionally equivalent to Hugging Face's [transformers](https://github.com/huggingface/transformers) python library, meaning you can run the same pretrained models using a very similar API. These models support common tasks in different modalities, such as:
- 📝 **Natural Language Processing**: text classification, named entity recognition, question answering, language modeling, summarization, translation, multiple choice, and text generation.
- - 🖼️ **Computer Vision**: image classification, object detection, and segmentation.
- - 🗣️ **Audio**: automatic speech recognition and audio classification.
- - 🐙 **Multimodal**: zero-shot image classification.
+ - 🖼️ **Computer Vision**: image classification, object detection, segmentation, and depth estimation.
+ - 🗣️ **Audio**: automatic speech recognition, audio classification, and text-to-speech.
+ - 🐙 **Multimodal**: embeddings, zero-shot audio classification, zero-shot image classification, and zero-shot object detection.
Transformers.js uses [ONNX Runtime](https://onnxruntime.ai/) to run models in the browser. The best part is that you can easily [convert](#convert-your-models-to-onnx) your pretrained PyTorch, TensorFlow, or JAX models to ONNX using [🤗 Optimum](https://github.com/huggingface/optimum#onnx--onnx-runtime).
diff --git a/docs/snippets/1_quick-tour.snippet b/docs/snippets/1_quick-tour.snippet
index dec6b341f..2e906a0f1 100644
--- a/docs/snippets/1_quick-tour.snippet
+++ b/docs/snippets/1_quick-tour.snippet
@@ -23,7 +23,7 @@ out = pipe('I love transformers!')
```javascript
-import { pipeline } from '@xenova/transformers';
+import { pipeline } from '@huggingface/transformers';
// Allocate a pipeline for sentiment-analysis
let pipe = await pipeline('sentiment-analysis');
diff --git a/docs/snippets/2_installation.snippet b/docs/snippets/2_installation.snippet
index 5f739c98f..6c8b6146e 100644
--- a/docs/snippets/2_installation.snippet
+++ b/docs/snippets/2_installation.snippet
@@ -1,12 +1,12 @@
-To install via [NPM](https://www.npmjs.com/package/@xenova/transformers), run:
+To install via [NPM](https://www.npmjs.com/package/@huggingface/transformers), run:
```bash
-npm i @xenova/transformers
+npm i @huggingface/transformers
```
Alternatively, you can use it in vanilla JS, without any bundler, by using a CDN or static hosting. For example, using [ES Modules](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Modules), you can import the library with:
```html
```
diff --git a/docs/snippets/3_examples.snippet b/docs/snippets/3_examples.snippet
index 1ee5cc49a..f8bf7ed1c 100644
--- a/docs/snippets/3_examples.snippet
+++ b/docs/snippets/3_examples.snippet
@@ -4,17 +4,17 @@ Want to jump straight in? Get started with one of our sample applications/templa
|-------------------|----------------------------------|-------------------------------|
| Whisper Web | Speech recognition w/ Whisper | [code](https://github.com/xenova/whisper-web), [demo](https://huggingface.co/spaces/Xenova/whisper-web) |
| Doodle Dash | Real-time sketch-recognition game | [blog](https://huggingface.co/blog/ml-web-games), [code](https://github.com/xenova/doodle-dash), [demo](https://huggingface.co/spaces/Xenova/doodle-dash) |
-| Code Playground | In-browser code completion website | [code](https://github.com/xenova/transformers.js/tree/main/examples/code-completion/), [demo](https://huggingface.co/spaces/Xenova/ai-code-playground) |
-| Semantic Image Search (client-side) | Search for images with text | [code](https://github.com/xenova/transformers.js/tree/main/examples/semantic-image-search-client/), [demo](https://huggingface.co/spaces/Xenova/semantic-image-search-client) |
-| Semantic Image Search (server-side) | Search for images with text (Supabase) | [code](https://github.com/xenova/transformers.js/tree/main/examples/semantic-image-search/), [demo](https://huggingface.co/spaces/Xenova/semantic-image-search) |
-| Vanilla JavaScript | In-browser object detection | [video](https://scrimba.com/scrim/cKm9bDAg), [code](https://github.com/xenova/transformers.js/tree/main/examples/vanilla-js/), [demo](https://huggingface.co/spaces/Scrimba/vanilla-js-object-detector) |
-| React | Multilingual translation website | [code](https://github.com/xenova/transformers.js/tree/main/examples/react-translator/), [demo](https://huggingface.co/spaces/Xenova/react-translator) |
-| Text to speech (client-side) | In-browser speech synthesis | [code](https://github.com/xenova/transformers.js/tree/main/examples/text-to-speech-client/), [demo](https://huggingface.co/spaces/Xenova/text-to-speech-client) |
-| Browser extension | Text classification extension | [code](https://github.com/xenova/transformers.js/tree/main/examples/extension/) |
-| Electron | Text classification application | [code](https://github.com/xenova/transformers.js/tree/main/examples/electron/) |
-| Next.js (client-side) | Sentiment analysis (in-browser inference) | [code](https://github.com/xenova/transformers.js/tree/main/examples/next-client/), [demo](https://huggingface.co/spaces/Xenova/next-example-app) |
-| Next.js (server-side) | Sentiment analysis (Node.js inference) | [code](https://github.com/xenova/transformers.js/tree/main/examples/next-server/), [demo](https://huggingface.co/spaces/Xenova/next-server-example-app) |
-| Node.js | Sentiment analysis API | [code](https://github.com/xenova/transformers.js/tree/main/examples/node/) |
-| Demo site | A collection of demos | [code](https://github.com/xenova/transformers.js/tree/main/examples/demo-site/), [demo](https://xenova.github.io/transformers.js/) |
+| Code Playground | In-browser code completion website | [code](https://github.com/huggingface/transformers.js/tree/main/examples/code-completion/), [demo](https://huggingface.co/spaces/Xenova/ai-code-playground) |
+| Semantic Image Search (client-side) | Search for images with text | [code](https://github.com/huggingface/transformers.js/tree/main/examples/semantic-image-search-client/), [demo](https://huggingface.co/spaces/Xenova/semantic-image-search-client) |
+| Semantic Image Search (server-side) | Search for images with text (Supabase) | [code](https://github.com/huggingface/transformers.js/tree/main/examples/semantic-image-search/), [demo](https://huggingface.co/spaces/Xenova/semantic-image-search) |
+| Vanilla JavaScript | In-browser object detection | [video](https://scrimba.com/scrim/cKm9bDAg), [code](https://github.com/huggingface/transformers.js/tree/main/examples/vanilla-js/), [demo](https://huggingface.co/spaces/Scrimba/vanilla-js-object-detector) |
+| React | Multilingual translation website | [code](https://github.com/huggingface/transformers.js/tree/main/examples/react-translator/), [demo](https://huggingface.co/spaces/Xenova/react-translator) |
+| Text to speech (client-side) | In-browser speech synthesis | [code](https://github.com/huggingface/transformers.js/tree/main/examples/text-to-speech-client/), [demo](https://huggingface.co/spaces/Xenova/text-to-speech-client) |
+| Browser extension | Text classification extension | [code](https://github.com/huggingface/transformers.js/tree/main/examples/extension/) |
+| Electron | Text classification application | [code](https://github.com/huggingface/transformers.js/tree/main/examples/electron/) |
+| Next.js (client-side) | Sentiment analysis (in-browser inference) | [code](https://github.com/huggingface/transformers.js/tree/main/examples/next-client/), [demo](https://huggingface.co/spaces/Xenova/next-example-app) |
+| Next.js (server-side) | Sentiment analysis (Node.js inference) | [code](https://github.com/huggingface/transformers.js/tree/main/examples/next-server/), [demo](https://huggingface.co/spaces/Xenova/next-server-example-app) |
+| Node.js | Sentiment analysis API | [code](https://github.com/huggingface/transformers.js/tree/main/examples/node/) |
+| Demo site | A collection of demos | [code](https://github.com/huggingface/transformers.js/tree/main/examples/demo-site/), [demo](https://xenova.github.io/transformers.js/) |
Check out the Transformers.js [template](https://huggingface.co/new-space?template=static-templates%2Ftransformers.js) on Hugging Face to get started in one click!
diff --git a/docs/snippets/4_custom-usage.snippet b/docs/snippets/4_custom-usage.snippet
index 787c8f579..d272c7617 100644
--- a/docs/snippets/4_custom-usage.snippet
+++ b/docs/snippets/4_custom-usage.snippet
@@ -1,12 +1,11 @@
-By default, Transformers.js uses [hosted pretrained models](https://huggingface.co/models?library=transformers.js) and [precompiled WASM binaries](https://cdn.jsdelivr.net/npm/@xenova/transformers@2.17.2/dist/), which should work out-of-the-box. You can customize this as follows:
-
+By default, Transformers.js uses [hosted pretrained models](https://huggingface.co/models?library=transformers.js) and [precompiled WASM binaries](https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.0/dist/), which should work out-of-the-box. You can customize this as follows:
### Settings
```javascript
-import { env } from '@xenova/transformers';
+import { env } from '@huggingface/transformers';
// Specify a custom location for models (defaults to '/models/').
env.localModelPath = '/path/to/models/';
@@ -22,7 +21,7 @@ For a full list of available settings, check out the [API Reference](./api/env).
### Convert your models to ONNX
-We recommend using our [conversion script](https://github.com/xenova/transformers.js/blob/main/scripts/convert.py) to convert your PyTorch, TensorFlow, or JAX models to ONNX in a single command. Behind the scenes, it uses [🤗 Optimum](https://huggingface.co/docs/optimum) to perform conversion and quantization of your model.
+We recommend using our [conversion script](https://github.com/huggingface/transformers.js/blob/main/scripts/convert.py) to convert your PyTorch, TensorFlow, or JAX models to ONNX in a single command. Behind the scenes, it uses [🤗 Optimum](https://huggingface.co/docs/optimum) to perform conversion and quantization of your model.
```bash
python -m scripts.convert --quantize --model_id <model_name_or_path>
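# For example, converting and quantizing a hypothetical model id (illustrative):
python -m scripts.convert --quantize --model_id bert-base-uncased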
diff --git a/docs/snippets/6_supported-models.snippet b/docs/snippets/6_supported-models.snippet
index f8ad89ae0..f1bcdad44 100644
--- a/docs/snippets/6_supported-models.snippet
+++ b/docs/snippets/6_supported-models.snippet
@@ -16,6 +16,7 @@
1. **[CLIPSeg](https://huggingface.co/docs/transformers/model_doc/clipseg)** (from University of Göttingen) released with the paper [Image Segmentation Using Text and Image Prompts](https://arxiv.org/abs/2112.10003) by Timo Lüddecke and Alexander Ecker.
1. **[CodeGen](https://huggingface.co/docs/transformers/model_doc/codegen)** (from Salesforce) released with the paper [A Conversational Paradigm for Program Synthesis](https://arxiv.org/abs/2203.13474) by Erik Nijkamp, Bo Pang, Hiroaki Hayashi, Lifu Tu, Huan Wang, Yingbo Zhou, Silvio Savarese, Caiming Xiong.
1. **[CodeLlama](https://huggingface.co/docs/transformers/model_doc/llama_code)** (from MetaAI) released with the paper [Code Llama: Open Foundation Models for Code](https://ai.meta.com/research/publications/code-llama-open-foundation-models-for-code/) by Baptiste Rozière, Jonas Gehring, Fabian Gloeckle, Sten Sootla, Itai Gat, Xiaoqing Ellen Tan, Yossi Adi, Jingyu Liu, Tal Remez, Jérémy Rapin, Artyom Kozhevnikov, Ivan Evtimov, Joanna Bitton, Manish Bhatt, Cristian Canton Ferrer, Aaron Grattafiori, Wenhan Xiong, Alexandre Défossez, Jade Copet, Faisal Azhar, Hugo Touvron, Louis Martin, Nicolas Usunier, Thomas Scialom, Gabriel Synnaeve.
+1. **[Cohere](https://huggingface.co/docs/transformers/main/model_doc/cohere)** (from Cohere) released with the paper [Command-R: Retrieval Augmented Generation at Production Scale]( ) by Cohere.
1. **[ConvBERT](https://huggingface.co/docs/transformers/model_doc/convbert)** (from YituTech) released with the paper [ConvBERT: Improving BERT with Span-based Dynamic Convolution](https://arxiv.org/abs/2008.02496) by Zihang Jiang, Weihao Yu, Daquan Zhou, Yunpeng Chen, Jiashi Feng, Shuicheng Yan.
1. **[ConvNeXT](https://huggingface.co/docs/transformers/model_doc/convnext)** (from Facebook AI) released with the paper [A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) by Zhuang Liu, Hanzi Mao, Chao-Yuan Wu, Christoph Feichtenhofer, Trevor Darrell, Saining Xie.
1. **[ConvNeXTV2](https://huggingface.co/docs/transformers/model_doc/convnextv2)** (from Facebook AI) released with the paper [ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808) by Sanghyun Woo, Shoubhik Debnath, Ronghang Hu, Xinlei Chen, Zhuang Liu, In So Kweon, Saining Xie.
@@ -24,6 +25,7 @@
1. **[Decision Transformer](https://huggingface.co/docs/transformers/model_doc/decision_transformer)** (from Berkeley/Facebook/Google) released with the paper [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345) by Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas, Igor Mordatch.
1. **[DeiT](https://huggingface.co/docs/transformers/model_doc/deit)** (from Facebook) released with the paper [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877) by Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, Hervé Jégou.
1. **[Depth Anything](https://huggingface.co/docs/transformers/main/model_doc/depth_anything)** (from University of Hong Kong and TikTok) released with the paper [Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data](https://arxiv.org/abs/2401.10891) by Lihe Yang, Bingyi Kang, Zilong Huang, Xiaogang Xu, Jiashi Feng, Hengshuang Zhao.
+1. **Depth Pro** (from Apple) released with the paper [Depth Pro: Sharp Monocular Metric Depth in Less Than a Second](https://arxiv.org/abs/2410.02073) by Aleksei Bochkovskii, Amaël Delaunoy, Hugo Germain, Marcel Santos, Yichao Zhou, Stephan R. Richter, Vladlen Koltun.
1. **[DETR](https://huggingface.co/docs/transformers/model_doc/detr)** (from Facebook) released with the paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko.
1. **[DINOv2](https://huggingface.co/docs/transformers/model_doc/dinov2)** (from Meta AI) released with the paper [DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193) by Maxime Oquab, Timothée Darcet, Théo Moutakanni, Huy Vo, Marc Szafraniec, Vasil Khalidov, Pierre Fernandez, Daniel Haziza, Francisco Massa, Alaaeldin El-Nouby, Mahmoud Assran, Nicolas Ballas, Wojciech Galuba, Russell Howes, Po-Yao Huang, Shang-Wen Li, Ishan Misra, Michael Rabbat, Vasu Sharma, Gabriel Synnaeve, Hu Xu, Hervé Jegou, Julien Mairal, Patrick Labatut, Armand Joulin, Piotr Bojanowski.
1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation) and a German version of DistilBERT.
@@ -36,39 +38,61 @@
1. **[Falcon](https://huggingface.co/docs/transformers/model_doc/falcon)** (from Technology Innovation Institute) by Almazrouei, Ebtesam and Alobeidli, Hamza and Alshamsi, Abdulaziz and Cappelli, Alessandro and Cojocaru, Ruxandra and Debbah, Merouane and Goffinet, Etienne and Heslow, Daniel and Launay, Julien and Malartic, Quentin and Noune, Badreddine and Pannier, Baptiste and Penedo, Guilherme.
1. **FastViT** (from Apple) released with the paper [FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization](https://arxiv.org/abs/2303.14189) by Pavan Kumar Anasosalu Vasu, James Gabriel, Jeff Zhu, Oncel Tuzel and Anurag Ranjan.
1. **[FLAN-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5)** (from Google AI) released in the repository [google-research/t5x](https://github.com/google-research/t5x/blob/main/docs/models.md#flan-t5-checkpoints) by Hyung Won Chung, Le Hou, Shayne Longpre, Barret Zoph, Yi Tay, William Fedus, Eric Li, Xuezhi Wang, Mostafa Dehghani, Siddhartha Brahma, Albert Webson, Shixiang Shane Gu, Zhuyun Dai, Mirac Suzgun, Xinyun Chen, Aakanksha Chowdhery, Sharan Narang, Gaurav Mishra, Adams Yu, Vincent Zhao, Yanping Huang, Andrew Dai, Hongkun Yu, Slav Petrov, Ed H. Chi, Jeff Dean, Jacob Devlin, Adam Roberts, Denny Zhou, Quoc V. Le, and Jason Wei
+1. **Florence2** (from Microsoft) released with the paper [Florence-2: Advancing a Unified Representation for a Variety of Vision Tasks](https://arxiv.org/abs/2311.06242) by Bin Xiao, Haiping Wu, Weijian Xu, Xiyang Dai, Houdong Hu, Yumao Lu, Michael Zeng, Ce Liu, Lu Yuan.
+1. **[Gemma](https://huggingface.co/docs/transformers/main/model_doc/gemma)** (from Google) released with the paper [Gemma: Open Models Based on Gemini Technology and Research](https://blog.google/technology/developers/gemma-open-models/) by the Gemma Google team.
+1. **[Gemma2](https://huggingface.co/docs/transformers/main/model_doc/gemma2)** (from Google) released with the paper [Gemma2: Open Models Based on Gemini Technology and Research](https://blog.google/technology/developers/google-gemma-2/) by the Gemma Google team.
1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim.
1. **[GPT Neo](https://huggingface.co/docs/transformers/model_doc/gpt_neo)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy.
1. **[GPT NeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox)** (from EleutherAI) released with the paper [GPT-NeoX-20B: An Open-Source Autoregressive Language Model](https://arxiv.org/abs/2204.06745) by Sid Black, Stella Biderman, Eric Hallahan, Quentin Anthony, Leo Gao, Laurence Golding, Horace He, Connor Leahy, Kyle McDonell, Jason Phang, Michael Pieler, USVSN Sai Prashanth, Shivanshu Purohit, Laria Reynolds, Jonathan Tow, Ben Wang, Samuel Weinbach
1. **[GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2)** (from OpenAI) released with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**.
1. **[GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj)** (from EleutherAI) released in the repository [kingoflolz/mesh-transformer-jax](https://github.com/kingoflolz/mesh-transformer-jax/) by Ben Wang and Aran Komatsuzaki.
1. **[GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode)** (from BigCode) released with the paper [SantaCoder: don't reach for the stars!](https://arxiv.org/abs/2301.03988) by Loubna Ben Allal, Raymond Li, Denis Kocetkov, Chenghao Mou, Christopher Akiki, Carlos Munoz Ferrandis, Niklas Muennighoff, Mayank Mishra, Alex Gu, Manan Dey, Logesh Kumar Umapathi, Carolyn Jane Anderson, Yangtian Zi, Joel Lamy Poirier, Hailey Schoelkopf, Sergey Troshin, Dmitry Abulkhanov, Manuel Romero, Michael Lappert, Francesco De Toni, Bernardo García del Río, Qian Liu, Shamik Bose, Urvashi Bhattacharyya, Terry Yue Zhuo, Ian Yu, Paulo Villegas, Marco Zocca, Sourab Mangrulkar, David Lansky, Huu Nguyen, Danish Contractor, Luis Villa, Jia Li, Dzmitry Bahdanau, Yacine Jernite, Sean Hughes, Daniel Fried, Arjun Guha, Harm de Vries, Leandro von Werra.
+1. **[Granite](https://huggingface.co/docs/transformers/main/model_doc/granite)** (from IBM) released with the paper [Power Scheduler: A Batch Size and Token Number Agnostic Learning Rate Scheduler](https://arxiv.org/abs/2408.13359) by Yikang Shen, Matthew Stallone, Mayank Mishra, Gaoyuan Zhang, Shawn Tan, Aditya Prasad, Adriana Meza Soria, David D. Cox, Rameswar Panda.
+1. **[GroupViT](https://huggingface.co/docs/transformers/model_doc/groupvit)** (from UCSD, NVIDIA) released with the paper [GroupViT: Semantic Segmentation Emerges from Text Supervision](https://arxiv.org/abs/2202.11094) by Jiarui Xu, Shalini De Mello, Sifei Liu, Wonmin Byeon, Thomas Breuel, Jan Kautz, Xiaolong Wang.
1. **[HerBERT](https://huggingface.co/docs/transformers/model_doc/herbert)** (from Allegro.pl, AGH University of Science and Technology) released with the paper [KLEJ: Comprehensive Benchmark for Polish Language Understanding](https://www.aclweb.org/anthology/2020.acl-main.111.pdf) by Piotr Rybak, Robert Mroczkowski, Janusz Tracz, Ireneusz Gawlik.
+1. **[Hiera](https://huggingface.co/docs/transformers/model_doc/hiera)** (from Meta) released with the paper [Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles](https://arxiv.org/pdf/2306.00989) by Chaitanya Ryali, Yuan-Ting Hu, Daniel Bolya, Chen Wei, Haoqi Fan, Po-Yao Huang, Vaibhav Aggarwal, Arkabandhu Chowdhury, Omid Poursaeed, Judy Hoffman, Jitendra Malik, Yanghao Li, Christoph Feichtenhofer.
1. **[Hubert](https://huggingface.co/docs/transformers/model_doc/hubert)** (from Facebook) released with the paper [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed.
+1. **JAIS** (from Core42) released with the paper [Jais and Jais-chat: Arabic-Centric Foundation and Instruction-Tuned Open Generative Large Language Models](https://arxiv.org/pdf/2308.16149) by Neha Sengupta, Sunil Kumar Sahu, Bokang Jia, Satheesh Katipomu, Haonan Li, Fajri Koto, William Marshall, Gurpreet Gosal, Cynthia Liu, Zhiming Chen, Osama Mohammed Afzal, Samta Kamboj, Onkar Pandit, Rahul Pal, Lalit Pradhan, Zain Muhammad Mujahid, Massa Baali, Xudong Han, Sondos Mahmoud Bsharat, Alham Fikri Aji, Zhiqiang Shen, Zhengzhong Liu, Natalia Vassilieva, Joel Hestness, Andy Hock, Andrew Feldman, Jonathan Lee, Andrew Jackson, Hector Xuguang Ren, Preslav Nakov, Timothy Baldwin, Eric Xing.
1. **[LongT5](https://huggingface.co/docs/transformers/model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang.
1. **[LLaMA](https://huggingface.co/docs/transformers/model_doc/llama)** (from The FAIR team of Meta AI) released with the paper [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) by Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample.
1. **[Llama2](https://huggingface.co/docs/transformers/model_doc/llama2)** (from The FAIR team of Meta AI) released with the paper [Llama2: Open Foundation and Fine-Tuned Chat Models](https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/) by Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, Dan Bikel, Lukas Blecher, Cristian Canton Ferrer, Moya Chen, Guillem Cucurull, David Esiobu, Jude Fernandes, Jeremy Fu, Wenyin Fu, Brian Fuller, Cynthia Gao, Vedanuj Goswami, Naman Goyal, Anthony Hartshorn, Saghar Hosseini, Rui Hou, Hakan Inan, Marcin Kardas, Viktor Kerkez, Madian Khabsa, Isabel Kloumann, Artem Korenev, Punit Singh Koura, Marie-Anne Lachaux, Thibaut Lavril, Jenya Lee, Diana Liskovich, Yinghai Lu, Yuning Mao, Xavier Martinet, Todor Mihaylov, Pushkar Mishra, Igor Molybog, Yixin Nie, Andrew Poulton, Jeremy Reizenstein, Rashi Rungta, Kalyan Saladi, Alan Schelten, Ruan Silva, Eric Michael Smith, Ranjan Subramanian, Xiaoqing Ellen Tan, Binh Tang, Ross Taylor, Adina Williams, Jian Xiang Kuan, Puxin Xu, Zheng Yan, Iliyan Zarov, Yuchen Zhang, Angela Fan, Melanie Kambadur, Sharan Narang, Aurelien Rodriguez, Robert Stojnic, Sergey Edunov, Thomas Scialom.
+1. **[LLaVa](https://huggingface.co/docs/transformers/model_doc/llava)** (from Microsoft Research & University of Wisconsin-Madison) released with the paper [Visual Instruction Tuning](https://arxiv.org/abs/2304.08485) by Haotian Liu, Chunyuan Li, Yuheng Li and Yong Jae Lee.
1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin.
1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team.
+1. **[MaskFormer](https://huggingface.co/docs/transformers/model_doc/maskformer)** (from Meta and UIUC) released with the paper [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) by Bowen Cheng, Alexander G. Schwing, Alexander Kirillov.
1. **[mBART](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer.
1. **[mBART-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan.
+1. **[MusicGen](https://huggingface.co/docs/transformers/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
1. **[Mistral](https://huggingface.co/docs/transformers/model_doc/mistral)** (from Mistral AI) by The [Mistral AI](https://mistral.ai) team: Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed.
1. **[MMS](https://huggingface.co/docs/transformers/model_doc/mms)** (from Facebook) released with the paper [Scaling Speech Technology to 1,000+ Languages](https://arxiv.org/abs/2305.13516) by Vineel Pratap, Andros Tjandra, Bowen Shi, Paden Tomasello, Arun Babu, Sayani Kundu, Ali Elkahky, Zhaoheng Ni, Apoorv Vyas, Maryam Fazel-Zarandi, Alexei Baevski, Yossi Adi, Xiaohui Zhang, Wei-Ning Hsu, Alexis Conneau, Michael Auli.
1. **[MobileBERT](https://huggingface.co/docs/transformers/model_doc/mobilebert)** (from CMU/Google Brain) released with the paper [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) by Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou.
+1. **MobileCLIP** (from Apple) released with the paper [MobileCLIP: Fast Image-Text Models through Multi-Modal Reinforced Training](https://arxiv.org/abs/2311.17049) by Pavan Kumar Anasosalu Vasu, Hadi Pouransari, Fartash Faghri, Raviteja Vemulapalli, Oncel Tuzel.
+1. **[MobileNetV1](https://huggingface.co/docs/transformers/model_doc/mobilenet_v1)** (from Google Inc.) released with the paper [MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications](https://arxiv.org/abs/1704.04861) by Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, Tobias Weyand, Marco Andreetto, Hartwig Adam.
+1. **[MobileNetV2](https://huggingface.co/docs/transformers/model_doc/mobilenet_v2)** (from Google Inc.) released with the paper [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381) by Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen.
+1. **MobileNetV3** (from Google Inc.) released with the paper [Searching for MobileNetV3](https://arxiv.org/abs/1905.02244) by Andrew Howard, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang, Yukun Zhu, Ruoming Pang, Vijay Vasudevan, Quoc V. Le, Hartwig Adam.
+1. **MobileNetV4** (from Google Inc.) released with the paper [MobileNetV4 - Universal Models for the Mobile Ecosystem](https://arxiv.org/abs/2404.10518) by Danfeng Qin, Chas Leichner, Manolis Delakis, Marco Fornoni, Shixin Luo, Fan Yang, Weijun Wang, Colby Banbury, Chengxi Ye, Berkin Akin, Vaibhav Aggarwal, Tenghui Zhu, Daniele Moro, Andrew Howard.
1. **[MobileViT](https://huggingface.co/docs/transformers/model_doc/mobilevit)** (from Apple) released with the paper [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer](https://arxiv.org/abs/2110.02178) by Sachin Mehta and Mohammad Rastegari.
1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (from Apple) released with the paper [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) by Sachin Mehta and Mohammad Rastegari.
+1. **Moondream1** released in the repository [moondream](https://github.com/vikhyat/moondream) by vikhyat.
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
1. **[MPT](https://huggingface.co/docs/transformers/model_doc/mpt)** (from MosaiML) released with the repository [llm-foundry](https://github.com/mosaicml/llm-foundry/) by the MosaicML NLP Team.
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team.
1. **[Nougat](https://huggingface.co/docs/transformers/model_doc/nougat)** (from Meta AI) released with the paper [Nougat: Neural Optical Understanding for Academic Documents](https://arxiv.org/abs/2308.13418) by Lukas Blecher, Guillem Cucurull, Thomas Scialom, Robert Stojnic.
+1. **OpenELM** (from Apple) released with the paper [OpenELM: An Efficient Language Model Family with Open-source Training and Inference Framework](https://arxiv.org/abs/2404.14619) by Sachin Mehta, Mohammad Hossein Sekhavat, Qingqing Cao, Maxwell Horton, Yanzi Jin, Chenfan Sun, Iman Mirzadeh, Mahyar Najibi, Dmitry Belenko, Peter Zatloukal, Mohammad Rastegari.
1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
1. **[OWLv2](https://huggingface.co/docs/transformers/model_doc/owlv2)** (from Google AI) released with the paper [Scaling Open-Vocabulary Object Detection](https://arxiv.org/abs/2306.09683) by Matthias Minderer, Alexey Gritsenko, Neil Houlsby.
1. **[Phi](https://huggingface.co/docs/transformers/main/model_doc/phi)** (from Microsoft) released with the papers - [Textbooks Are All You Need](https://arxiv.org/abs/2306.11644) by Suriya Gunasekar, Yi Zhang, Jyoti Aneja, Caio César Teodoro Mendes, Allie Del Giorno, Sivakanth Gopi, Mojan Javaheripi, Piero Kauffmann, Gustavo de Rosa, Olli Saarikivi, Adil Salim, Shital Shah, Harkirat Singh Behl, Xin Wang, Sébastien Bubeck, Ronen Eldan, Adam Tauman Kalai, Yin Tat Lee and Yuanzhi Li, [Textbooks Are All You Need II: phi-1.5 technical report](https://arxiv.org/abs/2309.05463) by Yuanzhi Li, Sébastien Bubeck, Ronen Eldan, Allie Del Giorno, Suriya Gunasekar and Yin Tat Lee.
+1. **[Phi3](https://huggingface.co/docs/transformers/main/model_doc/phi3)** (from Microsoft) released with the paper [Phi-3 Technical Report: A Highly Capable Language Model Locally on Your Phone](https://arxiv.org/abs/2404.14219) by Marah Abdin, Sam Ade Jacobs, Ammar Ahmad Awan, Jyoti Aneja, Ahmed Awadallah, Hany Awadalla, Nguyen Bach, Amit Bahree, Arash Bakhtiari, Harkirat Behl, Alon Benhaim, Misha Bilenko, Johan Bjorck, Sébastien Bubeck, Martin Cai, Caio César Teodoro Mendes, Weizhu Chen, Vishrav Chaudhary, Parul Chopra, Allie Del Giorno, Gustavo de Rosa, Matthew Dixon, Ronen Eldan, Dan Iter, Amit Garg, Abhishek Goswami, Suriya Gunasekar, Emman Haider, Junheng Hao, Russell J. Hewett, Jamie Huynh, Mojan Javaheripi, Xin Jin, Piero Kauffmann, Nikos Karampatziakis, Dongwoo Kim, Mahoud Khademi, Lev Kurilenko, James R. Lee, Yin Tat Lee, Yuanzhi Li, Chen Liang, Weishung Liu, Eric Lin, Zeqi Lin, Piyush Madan, Arindam Mitra, Hardik Modi, Anh Nguyen, Brandon Norick, Barun Patra, Daniel Perez-Becker, Thomas Portet, Reid Pryzant, Heyang Qin, Marko Radmilac, Corby Rosset, Sambudha Roy, Olatunji Ruwase, Olli Saarikivi, Amin Saied, Adil Salim, Michael Santacroce, Shital Shah, Ning Shang, Hiteshi Sharma, Xia Song, Masahiro Tanaka, Xin Wang, Rachel Ward, Guanhua Wang, Philipp Witte, Michael Wyatt, Can Xu, Jiahang Xu, Sonali Yadav, Fan Yang, Ziyi Yang, Donghan Yu, Chengruidong Zhang, Cyril Zhang, Jianwen Zhang, Li Lyna Zhang, Yi Zhang, Yue Zhang, Yunan Zhang, Xiren Zhou.
+1. **[PVT](https://huggingface.co/docs/transformers/main/model_doc/pvt)** (from Nanjing University, The University of Hong Kong etc.) released with the paper [Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions](https://arxiv.org/pdf/2102.12122.pdf) by Wenhai Wang, Enze Xie, Xiang Li, Deng-Ping Fan, Kaitao Song, Ding Liang, Tong Lu, Ping Luo, Ling Shao.
+1. **PyAnnote** released in the repository [pyannote/pyannote-audio](https://github.com/pyannote/pyannote-audio) by Hervé Bredin.
1. **[Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2)** (from the Qwen team, Alibaba Group) released with the paper [Qwen Technical Report](https://arxiv.org/abs/2309.16609) by Jinze Bai, Shuai Bai, Yunfei Chu, Zeyu Cui, Kai Dang, Xiaodong Deng, Yang Fan, Wenbin Ge, Yu Han, Fei Huang, Binyuan Hui, Luo Ji, Mei Li, Junyang Lin, Runji Lin, Dayiheng Liu, Gao Liu, Chengqiang Lu, Keming Lu, Jianxin Ma, Rui Men, Xingzhang Ren, Xuancheng Ren, Chuanqi Tan, Sinan Tan, Jianhong Tu, Peng Wang, Shijie Wang, Wei Wang, Shengguang Wu, Benfeng Xu, Jin Xu, An Yang, Hao Yang, Jian Yang, Shusheng Yang, Yang Yao, Bowen Yu, Hongyi Yuan, Zheng Yuan, Jianwei Zhang, Xingxuan Zhang, Yichang Zhang, Zhenru Zhang, Chang Zhou, Jingren Zhou, Xiaohuan Zhou and Tianhang Zhu.
1. **[ResNet](https://huggingface.co/docs/transformers/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
+1. **[RT-DETR](https://huggingface.co/docs/transformers/model_doc/rt_detr)** (from Baidu), released together with the paper [DETRs Beat YOLOs on Real-time Object Detection](https://arxiv.org/abs/2304.08069) by Yian Zhao, Wenyu Lv, Shangliang Xu, Jinman Wei, Guanzhong Wang, Qingqing Dang, Yi Liu, Jie Chen.
+1. **Sapiens** (from Meta AI) released with the paper [Sapiens: Foundation for Human Vision Models](https://arxiv.org/pdf/2408.12569) by Rawal Khirodkar, Timur Bagautdinov, Julieta Martinez, Su Zhaoen, Austin James, Peter Selednik, Stuart Anderson, Shunsuke Saito.
1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
1. **[Segment Anything](https://huggingface.co/docs/transformers/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick.
1. **[SigLIP](https://huggingface.co/docs/transformers/main/model_doc/siglip)** (from Google AI) released with the paper [Sigmoid Loss for Language Image Pre-Training](https://arxiv.org/abs/2303.15343) by Xiaohua Zhai, Basil Mustafa, Alexander Kolesnikov, Lucas Beyer.
@@ -85,7 +109,9 @@
1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (from Microsoft Research) released with the paper [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang.
1. **[UniSpeechSat](https://huggingface.co/docs/transformers/model_doc/unispeech-sat)** (from Microsoft Research) released with the paper [UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu.
1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
+1. **[ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae)** (from Meta AI) released with the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) by Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick.
1. **[ViTMatte](https://huggingface.co/docs/transformers/model_doc/vitmatte)** (from HUST-VL) released with the paper [ViTMatte: Boosting Image Matting with Pretrained Plain Vision Transformers](https://arxiv.org/abs/2305.15272) by Jingfeng Yao, Xinggang Wang, Shusheng Yang, Baoyuan Wang.
+1. **[ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn)** (from Meta AI) released with the paper [Masked Siamese Networks for Label-Efficient Learning](https://arxiv.org/abs/2204.07141) by Mahmoud Assran, Mathilde Caron, Ishan Misra, Piotr Bojanowski, Florian Bordes, Pascal Vincent, Armand Joulin, Michael Rabbat, Nicolas Ballas.
1. **[VITS](https://huggingface.co/docs/transformers/model_doc/vits)** (from Kakao Enterprise) released with the paper [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103) by Jaehyeon Kim, Jungil Kong, Juhee Son.
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
1. **[Wav2Vec2-BERT](https://huggingface.co/docs/transformers/main/model_doc/wav2vec2-bert)** (from Meta AI) released with the paper [Seamless: Multilingual Expressive and Streaming Speech Translation](https://ai.meta.com/research/publications/seamless-multilingual-expressive-and-streaming-speech-translation/) by the Seamless Communication team.
diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 1fe9150f6..4458c049b 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -48,6 +48,21 @@
title: ONNX
title: Backends
isExpanded: false
+ - sections:
+ - local: api/generation/parameters
+ title: Parameters
+ - local: api/generation/configuration_utils
+ title: Configuration
+ - local: api/generation/logits_process
+ title: Logits Processors
+ - local: api/generation/logits_sampler
+ title: Logits Samplers
+ - local: api/generation/stopping_criteria
+ title: Stopping Criteria
+ - local: api/generation/streamers
+ title: Streamers
+ title: Generation
+ isExpanded: false
- sections:
- local: api/utils/core
title: Core
@@ -61,8 +76,6 @@
title: Tensor
- local: api/utils/maths
title: Maths
- - local: api/utils/generation
- title: Generation
- local: api/utils/data-structures
title: Data Structures
title: Utilities
diff --git a/docs/source/guides/node-audio-processing.md b/docs/source/guides/node-audio-processing.md
index 88d93df2d..1b9e3cfea 100644
--- a/docs/source/guides/node-audio-processing.md
+++ b/docs/source/guides/node-audio-processing.md
@@ -13,7 +13,7 @@ This tutorial will be written as an ES module, but you can easily adapt it to us
**Useful links:**
-- [Source code](https://github.com/xenova/transformers.js/tree/main/examples/node-audio-processing)
+- [Source code](https://github.com/huggingface/transformers.js/tree/main/examples/node-audio-processing)
- [Documentation](https://huggingface.co/docs/transformers.js)
@@ -26,11 +26,11 @@ This tutorial will be written as an ES module, but you can easily adapt it to us
## Getting started
-Let's start by creating a new Node.js project and installing Transformers.js via [NPM](https://www.npmjs.com/package/@xenova/transformers):
+Let's start by creating a new Node.js project and installing Transformers.js via [NPM](https://www.npmjs.com/package/@huggingface/transformers):
```bash
npm init -y
-npm i @xenova/transformers
+npm i @huggingface/transformers
```
@@ -52,7 +52,7 @@ npm i wavefile
Start by creating a new file called `index.js`, which will be the entry point for our application. Let's also import the necessary modules:
```js
-import { pipeline } from '@xenova/transformers';
+import { pipeline } from '@huggingface/transformers';
import wavefile from 'wavefile';
```
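For context, a minimal sketch of how these imports are typically combined later in the guide (the file path and model name below are illustrative assumptions, not prescribed by this hunk):

```js
import { pipeline } from '@huggingface/transformers';
import wavefile from 'wavefile';
import fs from 'fs';

// Decode the audio file (path is an example) and prepare it for the model.
const wav = new wavefile.WaveFile(fs.readFileSync('audio.wav'));
wav.toBitDepth('32f'); // The pipeline expects floating-point audio
wav.toSampleRate(16000); // Whisper models expect 16 kHz audio
let audioData = wav.getSamples();
if (Array.isArray(audioData)) {
    audioData = audioData[0]; // Use the first channel for this sketch
}

// Create the pipeline (model name assumed for illustration) and transcribe.
const transcriber = await pipeline('automatic-speech-recognition', 'Xenova/whisper-tiny.en');
const output = await transcriber(audioData);
console.log(output.text);
```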
diff --git a/docs/source/guides/private.md b/docs/source/guides/private.md
index a687e1789..6715f0d4e 100644
--- a/docs/source/guides/private.md
+++ b/docs/source/guides/private.md
@@ -28,7 +28,7 @@ Transformers.js will attach an Authorization header to requests made to the Hugg
One way to do this is to call your program with the environment variable set. For example, let's say you have a file called `llama.js` with the following code:
```js
-import { AutoTokenizer } from '@xenova/transformers';
+import { AutoTokenizer } from '@huggingface/transformers';
// Load tokenizer for a gated repository.
const tokenizer = await AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf');
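// Assuming the snippet above is saved as llama.js, one way to supply the token
// is via an environment variable when launching the program, e.g. (illustrative):
//   HF_TOKEN=hf_xxxxxxxxxxxxxxxx node llama.js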
diff --git a/docs/source/index.md b/docs/source/index.md
index 1b94c115f..6551e303f 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -35,7 +35,7 @@ The documentation is organized into 4 sections:
Here is the list of all tasks and architectures currently supported by Transformers.js.
If you don't see your task/model listed here or it is not yet supported, feel free
-to open up a feature request [here](https://github.com/xenova/transformers.js/issues/new/choose).
+to open up a feature request [here](https://github.com/huggingface/transformers.js/issues/new/choose).
To find compatible models on the Hub, select the "transformers.js" library tag in the filter menu (or visit [this link](https://huggingface.co/models?library=transformers.js)).
You can refine your search by selecting the task you're interested in (e.g., [text-classification](https://huggingface.co/models?pipeline_tag=text-classification&library=transformers.js)).
diff --git a/docs/source/pipelines.md b/docs/source/pipelines.md
index 93f4ee216..0c1b3d584 100644
--- a/docs/source/pipelines.md
+++ b/docs/source/pipelines.md
@@ -14,7 +14,7 @@ For the full list of available tasks/pipelines, check out [this table](#availabl
Start by creating an instance of `pipeline()` and specifying a task you want to use it for. For example, to create a sentiment analysis pipeline, you can do:
```javascript
-import { pipeline } from '@xenova/transformers';
+import { pipeline } from '@huggingface/transformers';
let classifier = await pipeline('sentiment-analysis');
```
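As a quick sketch of how the classifier is then used (the input text and printed output are illustrative):

```javascript
// Run the classifier on some text.
let output = await classifier('I love transformers!');
// Example output: [{ label: 'POSITIVE', score: 0.9998 }]
console.log(output);
```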
diff --git a/docs/source/tutorials/browser-extension.md b/docs/source/tutorials/browser-extension.md
index a5fd391bd..a8853c5d4 100644
--- a/docs/source/tutorials/browser-extension.md
+++ b/docs/source/tutorials/browser-extension.md
@@ -1,4 +1,4 @@
# Building a browser extension
-*Full tutorial coming soon...* In the meantime, check out the example application: https://github.com/xenova/transformers.js/tree/main/examples/extension
+*Full tutorial coming soon...* In the meantime, check out the example application: https://github.com/huggingface/transformers.js/tree/main/examples/extension
diff --git a/docs/source/tutorials/electron.md b/docs/source/tutorials/electron.md
index 6962e4b7e..5fb3650c1 100644
--- a/docs/source/tutorials/electron.md
+++ b/docs/source/tutorials/electron.md
@@ -1,3 +1,3 @@
# Building an Electron application
-*Full tutorial coming soon...* In the meantime, check out the example application: https://github.com/xenova/transformers.js/tree/main/examples/electron
+*Full tutorial coming soon...* In the meantime, check out the example application: https://github.com/huggingface/transformers.js/tree/main/examples/electron
diff --git a/docs/source/tutorials/next.md b/docs/source/tutorials/next.md
index b3bcff659..0c8c70279 100644
--- a/docs/source/tutorials/next.md
+++ b/docs/source/tutorials/next.md
@@ -9,7 +9,7 @@ The final product will look something like this:
Useful links:
- Demo site: [client-side](https://huggingface.co/spaces/Xenova/next-example-app) or [server-side](https://huggingface.co/spaces/Xenova/next-server-example-app)
-- Source code: [client-side](https://github.com/xenova/transformers.js/tree/main/examples/next-client) or [server-side](https://github.com/xenova/transformers.js/tree/main/examples/next-server)
+- Source code: [client-side](https://github.com/huggingface/transformers.js/tree/main/examples/next-client) or [server-side](https://github.com/huggingface/transformers.js/tree/main/examples/next-server)
## Prerequisites
@@ -42,11 +42,11 @@ On installation, you'll see various prompts. For this demo, we'll be selecting t
### Step 2: Install and configure Transformers.js
-You can install Transformers.js from [NPM](https://www.npmjs.com/package/@xenova/transformers) with the following command:
+You can install Transformers.js from [NPM](https://www.npmjs.com/package/@huggingface/transformers) with the following command:
```bash
-npm i @xenova/transformers
+npm i @huggingface/transformers
```
We also need to update the `next.config.js` file to ignore node-specific modules when bundling for the browser:
@@ -76,7 +76,7 @@ module.exports = nextConfig
Next, we'll create a new [Web Worker](https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Using_web_workers) script where we'll place all ML-related code. This is to ensure that the main thread is not blocked while the model is loading and performing inference. For this application, we'll be using [`Xenova/distilbert-base-uncased-finetuned-sst-2-english`](https://huggingface.co/Xenova/distilbert-base-uncased-finetuned-sst-2-english), a ~67M parameter model finetuned on the [Stanford Sentiment Treebank](https://huggingface.co/datasets/sst) dataset. Add the following code to `./src/app/worker.js`:
```js
-import { pipeline, env } from "@xenova/transformers";
+import { pipeline, env } from "@huggingface/transformers";
// Skip local model check
env.allowLocalModels = false;
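// A sketch of the singleton wrapper described above (class name and exact
// implementation are illustrative; the model is the one mentioned in the text).
// It lazily constructs the pipeline the first time getInstance is called.
class PipelineSingleton {
    static task = 'text-classification';
    static model = 'Xenova/distilbert-base-uncased-finetuned-sst-2-english';
    static instance = null;

    static async getInstance(progress_callback = null) {
        this.instance ??= pipeline(this.task, this.model, { progress_callback });
        return this.instance;
    }
}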
@@ -264,11 +264,11 @@ On installation, you'll see various prompts. For this demo, we'll be selecting t
### Step 2: Install and configure Transformers.js
-You can install Transformers.js from [NPM](https://www.npmjs.com/package/@xenova/transformers) with the following command:
+You can install Transformers.js from [NPM](https://www.npmjs.com/package/@huggingface/transformers) with the following command:
```bash
-npm i @xenova/transformers
+npm i @huggingface/transformers
```
We also need to update the `next.config.js` file to prevent Webpack from bundling certain packages:
@@ -294,7 +294,7 @@ Next, let's set up our Route Handler. We can do this by creating two files in a
1. `pipeline.js` - to handle the construction of our pipeline.
```js
- import { pipeline } from "@xenova/transformers";
+ import { pipeline } from "@huggingface/transformers";
// Use the Singleton pattern to enable lazy construction of the pipeline.
// NOTE: We wrap the class in a function to prevent code duplication (see below).
@@ -413,7 +413,7 @@ Visit the URL shown in the terminal (e.g., [http://localhost:3000/](http://local
For this demo, we will build and deploy our application to [Hugging Face Spaces](https://huggingface.co/docs/hub/spaces). If you haven't already, you can create a free Hugging Face account [here](https://huggingface.co/join).
-1. Create a new `Dockerfile` in your project's root folder. You can use our [example Dockerfile](https://github.com/xenova/transformers.js/blob/main/examples/next-server/Dockerfile) as a template.
+1. Create a new `Dockerfile` in your project's root folder. You can use our [example Dockerfile](https://github.com/huggingface/transformers.js/blob/main/examples/next-server/Dockerfile) as a template.
2. Visit [https://huggingface.co/new-space](https://huggingface.co/new-space) and fill in the form. Remember to select "Docker" as the space type (you can choose the "Blank" Docker template).
3. Click the "Create space" button at the bottom of the page.
4. Go to "Files" → "Add file" → "Upload files". Drag the files from your project folder (excluding `node_modules` and `.next`, if present) into the upload box and click "Upload". After they have uploaded, scroll down to the button and click "Commit changes to main".
diff --git a/docs/source/tutorials/node.md b/docs/source/tutorials/node.md
index 2d9e3adf4..7cc5cc6be 100644
--- a/docs/source/tutorials/node.md
+++ b/docs/source/tutorials/node.md
@@ -19,7 +19,7 @@ Although you can always use the [Python library](https://github.com/huggingface/
**Useful links:**
-- Source code ([ESM](https://github.com/xenova/transformers.js/tree/main/examples/node/esm/app.js) or [CommonJS](https://github.com/xenova/transformers.js/tree/main/examples/node/commonjs/app.js))
+- Source code ([ESM](https://github.com/huggingface/transformers.js/tree/main/examples/node/esm/app.js) or [CommonJS](https://github.com/huggingface/transformers.js/tree/main/examples/node/commonjs/app.js))
- [Documentation](https://huggingface.co/docs/transformers.js)
@@ -31,11 +31,11 @@ Although you can always use the [Python library](https://github.com/huggingface/
## Getting started
-Let's start by creating a new Node.js project and installing Transformers.js via [NPM](https://www.npmjs.com/package/@xenova/transformers):
+Let's start by creating a new Node.js project and installing Transformers.js via [NPM](https://www.npmjs.com/package/@huggingface/transformers):
```bash
npm init -y
-npm i @xenova/transformers
+npm i @huggingface/transformers
```
Next, create a new file called `app.js`, which will be the entry point for our application. Depending on whether you're using [ECMAScript modules](#ecmascript-modules-esm) or [CommonJS](#commonjs), you will need to do some things differently (see below).
@@ -66,7 +66,7 @@ import url from 'url';
Following that, let's import Transformers.js and define the `MyClassificationPipeline` class.
```javascript
-import { pipeline, env } from '@xenova/transformers';
+import { pipeline, env } from '@huggingface/transformers';
class MyClassificationPipeline {
static task = 'text-classification';
@@ -107,7 +107,7 @@ class MyClassificationPipeline {
static async getInstance(progress_callback = null) {
if (this.instance === null) {
// Dynamically import the Transformers.js library
- let { pipeline, env } = await import('@xenova/transformers');
+ let { pipeline, env } = await import('@huggingface/transformers');
// NOTE: Uncomment this to change the cache directory
// env.cacheDir = './.cache';
@@ -195,7 +195,7 @@ Great! We've successfully created a basic HTTP server that uses Transformers.js
### Model caching
-By default, the first time you run the application, it will download the model files and cache them on your file system (in `./node_modules/@xenova/transformers/.cache/`). All subsequent requests will then use this model. You can change the location of the cache by setting `env.cacheDir`. For example, to cache the model in the `.cache` directory in the current working directory, you can add:
+By default, the first time you run the application, it will download the model files and cache them on your file system (in `./node_modules/@huggingface/transformers/.cache/`). All subsequent requests will then use this model. You can change the location of the cache by setting `env.cacheDir`. For example, to cache the model in the `.cache` directory in the current working directory, you can add:
```javascript
env.cacheDir = './.cache';
diff --git a/docs/source/tutorials/react.md b/docs/source/tutorials/react.md
index ab50d4de9..e617d8a05 100644
--- a/docs/source/tutorials/react.md
+++ b/docs/source/tutorials/react.md
@@ -7,7 +7,7 @@ In this tutorial, we'll be building a simple React application that performs mul
Useful links:
- [Demo site](https://huggingface.co/spaces/Xenova/react-translator)
-- [Source code](https://github.com/xenova/transformers.js/tree/main/examples/react-translator)
+- [Source code](https://github.com/huggingface/transformers.js/tree/main/examples/react-translator)
## Prerequisites
@@ -44,10 +44,10 @@ You can stop the development server by pressing Ctrl + C i
## Step 2: Install and configure Transformers.js
-Now we get to the fun part: adding machine learning to our application! First, install Transformers.js from [NPM](https://www.npmjs.com/package/@xenova/transformers) with the following command:
+Now we get to the fun part: adding machine learning to our application! First, install Transformers.js from [NPM](https://www.npmjs.com/package/@huggingface/transformers) with the following command:
```bash
-npm install @xenova/transformers
+npm install @huggingface/transformers
```
For this application, we will use the [Xenova/nllb-200-distilled-600M](https://huggingface.co/Xenova/nllb-200-distilled-600M) model, which can perform multilingual translation among 200 languages. Before we start, there are 2 things we need to take note of:
@@ -58,7 +58,7 @@ We can achieve both of these goals by using a [Web Worker](https://developer.moz
1. Create a file called `worker.js` in the `src` directory. This script will do all the heavy-lifting for us, including loading and running the translation pipeline. To ensure the model is only loaded once, we will create the `MyTranslationPipeline` class which uses the [singleton pattern](https://en.wikipedia.org/wiki/Singleton_pattern) to lazily create a single instance of the pipeline when `getInstance` is first called, and use this pipeline for all subsequent calls:
```javascript
- import { pipeline } from '@xenova/transformers';
+ import { pipeline } from '@huggingface/transformers';
class MyTranslationPipeline {
static task = 'translation';
@@ -127,7 +127,7 @@ We recommend starting the development server again with `npm run dev`
First, let's define our components. Create a folder called `components` in the `src` directory, and create the following files:
-1. `LanguageSelector.jsx`: This component will allow the user to select the input and output languages. Check out the full list of languages [here](https://github.com/xenova/transformers.js/blob/main/examples/react-translator/src/components/LanguageSelector.jsx).
+1. `LanguageSelector.jsx`: This component will allow the user to select the input and output languages. Check out the full list of languages [here](https://github.com/huggingface/transformers.js/blob/main/examples/react-translator/src/components/LanguageSelector.jsx).
```jsx
const LANGUAGES = {
"Acehnese (Arabic script)": "ace_Arab",
diff --git a/docs/source/tutorials/vanilla-js.md b/docs/source/tutorials/vanilla-js.md
index 7bc503006..58e336f12 100644
--- a/docs/source/tutorials/vanilla-js.md
+++ b/docs/source/tutorials/vanilla-js.md
@@ -10,7 +10,7 @@ Useful links:
- [Demo site](https://huggingface.co/spaces/Scrimba/vanilla-js-object-detector)
- [Interactive code walk-through (scrim)](https://scrimba.com/scrim/cKm9bDAg)
-- [Source code](https://github.com/xenova/transformers.js/tree/main/examples/vanilla-js)
+- [Source code](https://github.com/huggingface/transformers.js/tree/main/examples/vanilla-js)
## Step 1: HTML and CSS setup
@@ -104,7 +104,7 @@ The `type="module"` attribute is important, as it turns our file into a [JavaScr
Moving into `index.js`, let's import Transformers.js by adding the following line to the top of the file:
```js
-import { pipeline, env } from "https://cdn.jsdelivr.net/npm/@xenova/transformers@2.6.0";
+import { pipeline, env } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers";
```
Since we will be downloading the model from the Hugging Face Hub, we can skip the local model check by setting:
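A minimal sketch of that setting, using the same flag shown in the Next.js tutorial above:

```js
env.allowLocalModels = false;
```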
diff --git a/examples/code-completion/src/App.jsx b/examples/code-completion/src/App.jsx
index a532f7299..7fc84f538 100644
--- a/examples/code-completion/src/App.jsx
+++ b/examples/code-completion/src/App.jsx
@@ -162,7 +162,7 @@ function App() {
diff --git a/examples/demo-site/src/index.html b/examples/demo-site/src/index.html
index 9613acf63..49344c159 100644
--- a/examples/demo-site/src/index.html
+++ b/examples/demo-site/src/index.html
@@ -65,7 +65,7 @@ Transformers.js
diff --git a/examples/electron/README.md b/examples/electron/README.md
index 3d40e0c20..898801a12 100644
--- a/examples/electron/README.md
+++ b/examples/electron/README.md
@@ -6,7 +6,7 @@ An example project to show how to run 🤗 Transformers in an [Electron](https:/
## Getting Started
1. Clone the repo and enter the project directory:
```bash
- git clone https://github.com/xenova/transformers.js.git
+ git clone https://github.com/huggingface/transformers.js.git
cd transformers.js/examples/electron/
```
1. Install the necessary dependencies:
diff --git a/examples/extension/README.md b/examples/extension/README.md
index dfc81946f..4c4e0bceb 100644
--- a/examples/extension/README.md
+++ b/examples/extension/README.md
@@ -6,7 +6,7 @@ An example project to show how to run 🤗 Transformers in a browser extension.
## Getting Started
1. Clone the repo and enter the project directory:
```bash
- git clone https://github.com/xenova/transformers.js.git
+ git clone https://github.com/huggingface/transformers.js.git
cd transformers.js/examples/extension/
```
1. Install the necessary dependencies:
diff --git a/examples/florence2-webgpu/.eslintrc.cjs b/examples/florence2-webgpu/.eslintrc.cjs
new file mode 100644
index 000000000..3e212e1d4
--- /dev/null
+++ b/examples/florence2-webgpu/.eslintrc.cjs
@@ -0,0 +1,21 @@
+module.exports = {
+ root: true,
+ env: { browser: true, es2020: true },
+ extends: [
+ 'eslint:recommended',
+ 'plugin:react/recommended',
+ 'plugin:react/jsx-runtime',
+ 'plugin:react-hooks/recommended',
+ ],
+ ignorePatterns: ['dist', '.eslintrc.cjs'],
+ parserOptions: { ecmaVersion: 'latest', sourceType: 'module' },
+ settings: { react: { version: '18.2' } },
+ plugins: ['react-refresh'],
+ rules: {
+ 'react/jsx-no-target-blank': 'off',
+ 'react-refresh/only-export-components': [
+ 'warn',
+ { allowConstantExport: true },
+ ],
+ },
+}
diff --git a/examples/florence2-webgpu/.gitignore b/examples/florence2-webgpu/.gitignore
new file mode 100644
index 000000000..a547bf36d
--- /dev/null
+++ b/examples/florence2-webgpu/.gitignore
@@ -0,0 +1,24 @@
+# Logs
+logs
+*.log
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+pnpm-debug.log*
+lerna-debug.log*
+
+node_modules
+dist
+dist-ssr
+*.local
+
+# Editor directories and files
+.vscode/*
+!.vscode/extensions.json
+.idea
+.DS_Store
+*.suo
+*.ntvs*
+*.njsproj
+*.sln
+*.sw?
diff --git a/examples/florence2-webgpu/README.md b/examples/florence2-webgpu/README.md
new file mode 100644
index 000000000..f768e33fc
--- /dev/null
+++ b/examples/florence2-webgpu/README.md
@@ -0,0 +1,8 @@
+# React + Vite
+
+This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.
+
+Currently, two official plugins are available:
+
+- [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react/README.md) uses [Babel](https://babeljs.io/) for Fast Refresh
+- [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh
diff --git a/examples/florence2-webgpu/index.html b/examples/florence2-webgpu/index.html
new file mode 100644
index 000000000..77f8f0a0c
--- /dev/null
+++ b/examples/florence2-webgpu/index.html
@@ -0,0 +1,12 @@
+<!doctype html>
+<html lang="en">
+  <head>
+    <meta charset="UTF-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+    <title>Florence2 WebGPU</title>
+  </head>
+  <body>
+    <div id="root"></div>
+    <script type="module" src="/src/main.jsx"></script>
+  </body>
+</html>
diff --git a/examples/florence2-webgpu/package.json b/examples/florence2-webgpu/package.json
new file mode 100644
index 000000000..490ad589f
--- /dev/null
+++ b/examples/florence2-webgpu/package.json
@@ -0,0 +1,30 @@
+{
+ "name": "florence2-webgpu",
+ "private": true,
+ "version": "0.0.0",
+ "type": "module",
+ "scripts": {
+ "dev": "vite",
+ "build": "vite build",
+ "lint": "eslint . --ext js,jsx --report-unused-disable-directives --max-warnings 0",
+ "preview": "vite preview"
+ },
+ "dependencies": {
+ "@xenova/transformers": "github:xenova/transformers.js#v3",
+ "react": "^18.3.1",
+ "react-dom": "^18.3.1"
+ },
+ "devDependencies": {
+ "@types/react": "^18.3.3",
+ "@types/react-dom": "^18.3.0",
+ "@vitejs/plugin-react": "^4.3.1",
+ "autoprefixer": "^10.4.19",
+ "eslint": "^8.57.0",
+ "eslint-plugin-react": "^7.34.2",
+ "eslint-plugin-react-hooks": "^4.6.2",
+ "eslint-plugin-react-refresh": "^0.4.7",
+ "postcss": "^8.4.38",
+ "tailwindcss": "^3.4.4",
+ "vite": "^5.3.1"
+ }
+}
diff --git a/examples/florence2-webgpu/postcss.config.js b/examples/florence2-webgpu/postcss.config.js
new file mode 100644
index 000000000..2e7af2b7f
--- /dev/null
+++ b/examples/florence2-webgpu/postcss.config.js
@@ -0,0 +1,6 @@
+export default {
+ plugins: {
+ tailwindcss: {},
+ autoprefixer: {},
+ },
+}
diff --git a/examples/florence2-webgpu/src/App.jsx b/examples/florence2-webgpu/src/App.jsx
new file mode 100644
index 000000000..36ac67e0f
--- /dev/null
+++ b/examples/florence2-webgpu/src/App.jsx
@@ -0,0 +1,218 @@
+import { useEffect, useState, useRef, useCallback } from 'react';
+
+import Progress from './components/Progress';
+import ImageInput from './components/ImageInput';
+
+const IS_WEBGPU_AVAILABLE = !!navigator.gpu;
+
+function App() {
+
+ // Create a reference to the worker object.
+ const worker = useRef(null);
+
+ // Model loading and progress
+ const [status, setStatus] = useState(null);
+ const [loadingMessage, setLoadingMessage] = useState('');
+ const [progressItems, setProgressItems] = useState([]);
+
+ const [task, setTask] = useState('');
+ const [text, setText] = useState('');
+ const [image, setImage] = useState(null);
+ const [result, setResult] = useState(null);
+ const [time, setTime] = useState(null);
+
+ // We use the `useEffect` hook to set up the worker as soon as the `App` component is mounted.
+ useEffect(() => {
+ if (!worker.current) {
+ // Create the worker if it does not yet exist.
+ worker.current = new Worker(new URL('./worker.js', import.meta.url), {
+ type: 'module'
+ });
+ }
+
+ // Create a callback function for messages from the worker thread.
+ const onMessageReceived = (e) => {
+ switch (e.data.status) {
+ case 'loading':
+ // Model file start load: add a new progress item to the list.
+ setStatus('loading');
+ setLoadingMessage(e.data.data);
+ break;
+
+ case 'initiate':
+ setProgressItems(prev => [...prev, e.data]);
+ break;
+
+ case 'progress':
+ // Model file progress: update one of the progress items.
+ setProgressItems(
+ prev => prev.map(item => {
+ if (item.file === e.data.file) {
+ return { ...item, ...e.data }
+ }
+ return item;
+ })
+ );
+ break;
+
+ case 'done':
+ // Model file loaded: remove the progress item from the list.
+ setProgressItems(
+ prev => prev.filter(item => item.file !== e.data.file)
+ );
+ break;
+
+ case 'ready':
+ // Pipeline ready: the worker is ready to accept messages.
+ setStatus('ready');
+ break;
+
+ case 'complete':
+ setResult(e.data.result);
+ setTime(e.data.time);
+ setStatus('ready');
+ break;
+ }
+ };
+
+ // Attach the callback function as an event listener.
+ worker.current.addEventListener('message', onMessageReceived);
+
+ // Define a cleanup function for when the component is unmounted.
+ return () => {
+ worker.current.removeEventListener('message', onMessageReceived);
+ };
+ }, []);
+
+ const handleClick = useCallback(() => {
+ if (status === null) {
+ setStatus('loading');
+ worker.current.postMessage({ type: 'load' });
+ } else {
+ setStatus('running');
+ worker.current.postMessage({
+ type: 'run', data: { text, url: image, task }
+ });
+ }
+ }, [status, task, image, text]);
+
+ return (
+ IS_WEBGPU_AVAILABLE
+ ? (
+
+ {status === 'loading' && (
+
+
+
{loadingMessage}
+ {progressItems.map(({ file, progress, total }, i) => (
+
+ ))}
+
+
+ )}
+
+
+
Florence2 WebGPU
+ Powerful vision foundation model running locally in your browser.
+
+
+
+
+
+ You are about to download Florence-2-base-ft ,
+ a 230 million parameter vision foundation model that uses a prompt-based approach to handle a wide range of vision and vision-language tasks like captioning, object detection, and segmentation.
+ Once loaded, the model (340 MB) will be cached and reused when you revisit the page.
+
+ Everything runs locally in your browser using 🤗 Transformers.js and ONNX Runtime Web,
+ meaning no API calls are made to a server for inference. You can even disconnect from the internet after the model has loaded!
+
+
+
+
+
+ Task
+ setTask(e.target.value)}
+ >
+ Caption
+ Detailed Caption
+ More Detailed Caption
+ OCR
+ OCR with Region
+ Object Detection
+ Dense Region Caption
+ Caption to Phrase Grounding
+ {/* Referring Expression Segmentation */}
+ {/* Region to Segmentation */}
+ {/* Open Vocabulary Detection */}
+ {/* Region to Category */}
+ {/* Region to Description */}
+ {/* Region to OCR */}
+ {/* Region Proposal */}
+
+
+
+ Input Image
+ {
+ worker.current.postMessage({ type: 'reset' }); // Reset image cache
+ setResult(null);
+ setImage(result);
+ }} />
+
+
+
+ {
+ task === '<CAPTION_TO_PHRASE_GROUNDING>'
+ && (
+ Text input
+ setText(e.target.value)}
+ />
+
)
+ }
+
+
+
Output
+
+ {result?.[task] && (<>
+ {
+ typeof result[task] === 'string'
+ ?
{result[task]}
+ :
+ {JSON.stringify(result[task], null, 2)}
+
+ }
+ {
+ time &&
Execution time: {time.toFixed(2)} ms
+ }
+ >)
+ }
+
+
+
+
+
+
+
+ {status === null ? 'Load model' :
+ status === 'running'
+ ? 'Running...'
+ : 'Run model'
+ }
+
+
+
+
+
)
+ : (WebGPU is not supported by this browser :(
)
+ )
+}
+
+export default App
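
The component above drives all inference through a dedicated Web Worker, so the message shapes matter more than the (heavily styled) JSX. Below is a minimal sketch of that protocol, with the message fields taken from the `onMessageReceived` handler above and from `worker.js` later in this diff; the image URL and the `'<CAPTION>'` task token are illustrative placeholders, not part of this PR.

```js
// Minimal sketch of the App <-> worker message protocol used above.
const worker = new Worker(new URL('./worker.js', import.meta.url), { type: 'module' });

worker.addEventListener('message', (e) => {
  switch (e.data.status) {
    case 'loading':   // { status, data }                   a human-readable loading step
    case 'initiate':  // { status, file, ... }              a model file starts downloading
    case 'progress':  // { status, file, progress, total }  download progress for one file
    case 'done':      // { status, file }                   that file finished downloading
      console.log(e.data);
      break;
    case 'ready':     // { status }  shaders compiled and model warmed up; safe to run
      worker.postMessage({ type: 'run', data: { text: '', url: '/example.png', task: '<CAPTION>' } });
      break;
    case 'complete':  // { status, result, time }  inference output plus latency in ms
      console.log(e.data.result, `${e.data.time.toFixed(2)} ms`);
      break;
  }
});

worker.postMessage({ type: 'load' });     // kick off model download and warm-up
// worker.postMessage({ type: 'reset' }); // clears the cached vision inputs before a new image
```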
diff --git a/examples/florence2-webgpu/src/components/ImageInput.jsx b/examples/florence2-webgpu/src/components/ImageInput.jsx
new file mode 100644
index 000000000..9f24d9d5b
--- /dev/null
+++ b/examples/florence2-webgpu/src/components/ImageInput.jsx
@@ -0,0 +1,68 @@
+import { useState, useRef } from 'react';
+
+const EXAMPLE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/beetle.png';
+
+const ImageInput = ({ onImageChange, ...props }) => {
+ const [imagePreview, setImagePreview] = useState(null);
+ const fileInputRef = useRef(null);
+
+ const readFile = (file) => {
+ if (!file) return;
+ const reader = new FileReader();
+ reader.onloadend = () => {
+ setImagePreview(reader.result);
+ if (onImageChange) {
+ onImageChange(file, reader.result);
+ }
+ };
+ reader.readAsDataURL(file);
+ }
+
+ const handleImageChange = (event) => {
+ readFile(event.target.files[0]);
+ };
+
+ const handleDragOver = (event) => {
+ event.preventDefault();
+ };
+
+ const handleDrop = (event) => {
+ event.preventDefault();
+ readFile(event.dataTransfer.files[0]);
+ };
+
+ const handleClick = () => {
+ fileInputRef.current.click();
+ };
+
+ return (
+
+
+ {imagePreview ? (
+
+ ) : (
+
+ Drag & drop or click to select an image
+ {
+ e.stopPropagation();
+ setImagePreview(EXAMPLE_URL);
+ onImageChange(null, EXAMPLE_URL);
+ }}>(or try an example )
+
+ )}
+
+ );
+};
+
+export default ImageInput;
diff --git a/examples/florence2-webgpu/src/components/Progress.jsx b/examples/florence2-webgpu/src/components/Progress.jsx
new file mode 100644
index 000000000..9ce024cc8
--- /dev/null
+++ b/examples/florence2-webgpu/src/components/Progress.jsx
@@ -0,0 +1,15 @@
+function formatBytes(size) {
+ const i = size == 0 ? 0 : Math.floor(Math.log(size) / Math.log(1024));
+ return +((size / Math.pow(1024, i)).toFixed(2)) * 1 + ['B', 'kB', 'MB', 'GB', 'TB'][i];
+}
+
+export default function Progress({ text, percentage, total }) {
+ percentage ??= 0;
+ return (
+
+
+ {text} ({percentage.toFixed(2)}%{isNaN(total) ? '' : ` of ${formatBytes(total)}`})
+
+
+ );
+}
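
`formatBytes` picks the unit by taking the base-1024 logarithm of the size and keeps at most two decimals. A few worked values (arbitrary inputs, with the helper repeated so the snippet runs standalone) make the rounding behaviour concrete:

```js
// Same helper as in Progress.jsx above, repeated for a self-contained example.
function formatBytes(size) {
  const i = size == 0 ? 0 : Math.floor(Math.log(size) / Math.log(1024));
  return +((size / Math.pow(1024, i)).toFixed(2)) * 1 + ['B', 'kB', 'MB', 'GB', 'TB'][i];
}

console.log(formatBytes(0));                 // '0B'    (the size == 0 guard avoids log(0))
console.log(formatBytes(1536));              // '1.5kB'
console.log(formatBytes(340 * 1024 * 1024)); // '340MB'
```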
diff --git a/examples/florence2-webgpu/src/index.css b/examples/florence2-webgpu/src/index.css
new file mode 100644
index 000000000..c4a1285e0
--- /dev/null
+++ b/examples/florence2-webgpu/src/index.css
@@ -0,0 +1,21 @@
+@tailwind base;
+@tailwind components;
+@tailwind utilities;
+
+@layer utilities {
+ .scrollbar-thin::-webkit-scrollbar {
+ @apply w-2;
+ }
+
+ .scrollbar-thin::-webkit-scrollbar-track {
+ @apply rounded-full bg-gray-100 dark:bg-gray-700;
+ }
+
+ .scrollbar-thin::-webkit-scrollbar-thumb {
+ @apply rounded-full bg-gray-300 dark:bg-gray-600;
+ }
+
+ .scrollbar-thin::-webkit-scrollbar-thumb:hover {
+ @apply bg-gray-500;
+ }
+}
diff --git a/examples/florence2-webgpu/src/main.jsx b/examples/florence2-webgpu/src/main.jsx
new file mode 100644
index 000000000..54b39dd1d
--- /dev/null
+++ b/examples/florence2-webgpu/src/main.jsx
@@ -0,0 +1,10 @@
+import React from 'react'
+import ReactDOM from 'react-dom/client'
+import App from './App.jsx'
+import './index.css'
+
+ReactDOM.createRoot(document.getElementById('root')).render(
+  <React.StrictMode>
+    <App />
+  </React.StrictMode>,
+)
diff --git a/examples/florence2-webgpu/src/worker.js b/examples/florence2-webgpu/src/worker.js
new file mode 100644
index 000000000..92c1732f4
--- /dev/null
+++ b/examples/florence2-webgpu/src/worker.js
@@ -0,0 +1,140 @@
+
+import {
+ Florence2ForConditionalGeneration,
+ AutoProcessor,
+ AutoTokenizer,
+ RawImage,
+ full,
+} from '@xenova/transformers';
+
+async function hasFp16() {
+ try {
+ const adapter = await navigator.gpu.requestAdapter();
+ return adapter.features.has('shader-f16');
+ } catch (e) {
+ return false;
+ }
+}
+
+/**
+ * This class uses the Singleton pattern to ensure that only one instance of the model is loaded.
+ */
+class Florence2Singleton {
+ static model_id = 'onnx-community/Florence-2-base-ft';
+
+ static async getInstance(progress_callback = null) {
+ this.processor ??= AutoProcessor.from_pretrained(this.model_id);
+ this.tokenizer ??= AutoTokenizer.from_pretrained(this.model_id);
+
+ this.supports_fp16 ??= await hasFp16();
+ this.model ??= Florence2ForConditionalGeneration.from_pretrained(this.model_id, {
+ dtype: {
+ embed_tokens: this.supports_fp16 ? 'fp16' : 'fp32',
+ vision_encoder: this.supports_fp16 ? 'fp16' : 'fp32',
+ encoder_model: 'q4', // or 'fp16' or 'fp32'
+ decoder_model_merged: 'q4', // or 'fp16' or 'fp32'
+ },
+ device: 'webgpu',
+ progress_callback,
+ });
+
+ return Promise.all([this.model, this.tokenizer, this.processor]);
+ }
+}
+
+
+async function load() {
+ self.postMessage({
+ status: 'loading',
+ data: 'Loading model...'
+ });
+
+ // Load the pipeline and save it for future use.
+ const [model, tokenizer, processor] = await Florence2Singleton.getInstance(x => {
+ // We also add a progress callback to the pipeline so that we can
+ // track model loading.
+ self.postMessage(x);
+ });
+
+ self.postMessage({
+ status: 'loading',
+ data: 'Compiling shaders and warming up model...'
+ });
+
+ // Dummy text and vision inputs
+ const text_inputs = tokenizer('a');
+ const pixel_values = full([1, 3, 768, 768], 0.0);
+
+ // Run model with dummy input to compile shaders
+ await model.generate({
+ ...text_inputs,
+ pixel_values,
+ max_new_tokens: 1,
+ });
+
+ self.postMessage({ status: 'ready' });
+}
+
+const TASKS_WITH_INPUTS = [
+ '<CAPTION_TO_PHRASE_GROUNDING>',
+]
+
+let vision_inputs;
+let image_size;
+async function run({ text, url, task }) {
+ const [model, tokenizer, processor] = await Florence2Singleton.getInstance();
+
+ // Read and preprocess image
+ const start = performance.now();
+ if (!vision_inputs) {
+ // Cache vision inputs when possible
+ const image = await RawImage.fromURL(url);
+ image_size = image.size;
+ vision_inputs = await processor(image);
+ }
+
+ let user_input = task;
+ if (TASKS_WITH_INPUTS.includes(task) && text) {
+ user_input += text;
+ }
+ const prompts = processor.construct_prompts(user_input);
+ const text_inputs = tokenizer(prompts);
+
+ // Generate text
+ const generated_ids = await model.generate({
+ ...text_inputs,
+ ...vision_inputs,
+ max_new_tokens: 128,
+ num_beams: 1,
+ do_sample: false,
+ });
+
+ // Decode generated text
+ const generated_text = tokenizer.batch_decode(generated_ids, { skip_special_tokens: false })[0];
+
+ // Post-process the generated text
+ const result = processor.post_process_generation(generated_text, task, image_size);
+
+ const end = performance.now();
+
+ self.postMessage({ status: 'complete', result, time: end - start });
+}
+
+// Listen for messages from the main thread
+self.addEventListener('message', async (e) => {
+ const { type, data } = e.data;
+
+ switch (type) {
+ case 'load':
+ load();
+ break;
+
+ case 'run':
+ run(data);
+ break;
+
+ case 'reset':
+ vision_inputs = image_size = null;
+ break;
+ }
+});
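
The worker above wraps the whole Florence-2 flow (prompt construction, `generate`, post-processing) behind its message handler. For reference, here is the same flow condensed into a single, self-contained sketch that uses only the calls already present in `worker.js`; the `'<OD>'` (object detection) task token, the example image URL, and the omission of the WebGPU/dtype options are illustrative simplifications, not part of this PR.

```js
// Condensed version of worker.js's load() + run() flow, for reference only.
import {
  Florence2ForConditionalGeneration,
  AutoProcessor,
  AutoTokenizer,
  RawImage,
} from '@xenova/transformers';

const model_id = 'onnx-community/Florence-2-base-ft';
const model = await Florence2ForConditionalGeneration.from_pretrained(model_id);
const processor = await AutoProcessor.from_pretrained(model_id);
const tokenizer = await AutoTokenizer.from_pretrained(model_id);

// Preprocess the image and build the task prompt.
const image = await RawImage.fromURL('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/beetle.png');
const vision_inputs = await processor(image);
const task = '<OD>'; // illustrative task token
const text_inputs = tokenizer(processor.construct_prompts(task));

// Generate, decode, and post-process into task-specific output.
const generated_ids = await model.generate({ ...text_inputs, ...vision_inputs, max_new_tokens: 128 });
const generated_text = tokenizer.batch_decode(generated_ids, { skip_special_tokens: false })[0];
console.log(processor.post_process_generation(generated_text, task, image.size));
```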
diff --git a/examples/florence2-webgpu/tailwind.config.js b/examples/florence2-webgpu/tailwind.config.js
new file mode 100644
index 000000000..d37737fc0
--- /dev/null
+++ b/examples/florence2-webgpu/tailwind.config.js
@@ -0,0 +1,12 @@
+/** @type {import('tailwindcss').Config} */
+export default {
+ content: [
+ "./index.html",
+ "./src/**/*.{js,ts,jsx,tsx}",
+ ],
+ theme: {
+ extend: {},
+ },
+ plugins: [],
+}
+
diff --git a/examples/florence2-webgpu/vite.config.js b/examples/florence2-webgpu/vite.config.js
new file mode 100644
index 000000000..5a33944a9
--- /dev/null
+++ b/examples/florence2-webgpu/vite.config.js
@@ -0,0 +1,7 @@
+import { defineConfig } from 'vite'
+import react from '@vitejs/plugin-react'
+
+// https://vitejs.dev/config/
+export default defineConfig({
+ plugins: [react()],
+})
diff --git a/examples/musicgen-web/.eslintrc.cjs b/examples/musicgen-web/.eslintrc.cjs
new file mode 100644
index 000000000..3e212e1d4
--- /dev/null
+++ b/examples/musicgen-web/.eslintrc.cjs
@@ -0,0 +1,21 @@
+module.exports = {
+ root: true,
+ env: { browser: true, es2020: true },
+ extends: [
+ 'eslint:recommended',
+ 'plugin:react/recommended',
+ 'plugin:react/jsx-runtime',
+ 'plugin:react-hooks/recommended',
+ ],
+ ignorePatterns: ['dist', '.eslintrc.cjs'],
+ parserOptions: { ecmaVersion: 'latest', sourceType: 'module' },
+ settings: { react: { version: '18.2' } },
+ plugins: ['react-refresh'],
+ rules: {
+ 'react/jsx-no-target-blank': 'off',
+ 'react-refresh/only-export-components': [
+ 'warn',
+ { allowConstantExport: true },
+ ],
+ },
+}
diff --git a/examples/musicgen-web/.gitignore b/examples/musicgen-web/.gitignore
new file mode 100644
index 000000000..a547bf36d
--- /dev/null
+++ b/examples/musicgen-web/.gitignore
@@ -0,0 +1,24 @@
+# Logs
+logs
+*.log
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+pnpm-debug.log*
+lerna-debug.log*
+
+node_modules
+dist
+dist-ssr
+*.local
+
+# Editor directories and files
+.vscode/*
+!.vscode/extensions.json
+.idea
+.DS_Store
+*.suo
+*.ntvs*
+*.njsproj
+*.sln
+*.sw?
diff --git a/examples/musicgen-web/README.md b/examples/musicgen-web/README.md
new file mode 100644
index 000000000..f768e33fc
--- /dev/null
+++ b/examples/musicgen-web/README.md
@@ -0,0 +1,8 @@
+# React + Vite
+
+This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.
+
+Currently, two official plugins are available:
+
+- [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react/README.md) uses [Babel](https://babeljs.io/) for Fast Refresh
+- [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh
diff --git a/examples/musicgen-web/index.html b/examples/musicgen-web/index.html
new file mode 100644
index 000000000..cad1bcd1a
--- /dev/null
+++ b/examples/musicgen-web/index.html
@@ -0,0 +1,12 @@
+<!doctype html>
+<html lang="en">
+  <head>
+    <meta charset="UTF-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+    <title>MusicGen Web | In-browser text-to-music w/ 🤗 Transformers.js!</title>
+  </head>
+  <body>
+    <div id="root"></div>
+    <script type="module" src="/src/main.jsx"></script>
+  </body>
+</html>
diff --git a/examples/musicgen-web/package.json b/examples/musicgen-web/package.json
new file mode 100644
index 000000000..0175494d7
--- /dev/null
+++ b/examples/musicgen-web/package.json
@@ -0,0 +1,30 @@
+{
+ "name": "musicgen-web",
+ "private": true,
+ "version": "0.0.0",
+ "type": "module",
+ "scripts": {
+ "dev": "vite",
+ "build": "vite build",
+ "lint": "eslint . --ext js,jsx --report-unused-disable-directives --max-warnings 0",
+ "preview": "vite preview"
+ },
+ "dependencies": {
+ "@xenova/transformers": "github:xenova/transformers.js#v3",
+ "react": "^18.2.0",
+ "react-dom": "^18.2.0"
+ },
+ "devDependencies": {
+ "@types/react": "^18.2.66",
+ "@types/react-dom": "^18.2.22",
+ "@vitejs/plugin-react": "^4.2.1",
+ "autoprefixer": "^10.4.19",
+ "eslint": "^8.57.0",
+ "eslint-plugin-react": "^7.34.1",
+ "eslint-plugin-react-hooks": "^4.6.0",
+ "eslint-plugin-react-refresh": "^0.4.6",
+ "postcss": "^8.4.38",
+ "tailwindcss": "^3.4.3",
+ "vite": "^5.2.0"
+ }
+}
diff --git a/examples/musicgen-web/postcss.config.js b/examples/musicgen-web/postcss.config.js
new file mode 100644
index 000000000..2e7af2b7f
--- /dev/null
+++ b/examples/musicgen-web/postcss.config.js
@@ -0,0 +1,6 @@
+export default {
+ plugins: {
+ tailwindcss: {},
+ autoprefixer: {},
+ },
+}
diff --git a/examples/musicgen-web/src/App.css b/examples/musicgen-web/src/App.css
new file mode 100644
index 000000000..91ab868f6
--- /dev/null
+++ b/examples/musicgen-web/src/App.css
@@ -0,0 +1,9 @@
+#root {
+ max-width: 960px;
+ height: 100vh;
+ margin: 0 auto;
+ text-align: center;
+ display: flex;
+ justify-content: center;
+ align-items: center;
+}
diff --git a/examples/musicgen-web/src/App.jsx b/examples/musicgen-web/src/App.jsx
new file mode 100644
index 000000000..a64e8b655
--- /dev/null
+++ b/examples/musicgen-web/src/App.jsx
@@ -0,0 +1,229 @@
+import { useEffect, useState, useRef } from 'react';
+import { AutoTokenizer, MusicgenForConditionalGeneration, BaseStreamer } from '@xenova/transformers';
+import { encodeWAV, share } from './utils.js';
+
+import './App.css';
+
+const MODEL_ID = 'Xenova/musicgen-small';
+
+// Adapted from https://huggingface.co/spaces/facebook/MusicGen
+const EXAMPLES = [
+ '80s pop track with bassy drums and synth',
+ '90s rock song with loud guitars and heavy drums',
+ 'a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130',
+ 'A cheerful country song with acoustic guitars',
+ 'lofi slow bpm electro chill with organic samples',
+];
+
+// Enable sharing if running on Hugging Face Spaces
+const SHARING_ENABLED = window.location.host.endsWith('.hf.space');
+
+// Streamer to update progress
+class CallbackStreamer extends BaseStreamer {
+ constructor(callback_fn) {
+ super();
+ this.callback_fn = callback_fn;
+ }
+
+ put(value) {
+ return this.callback_fn(value);
+ }
+
+ end() {
+ return this.callback_fn();
+ }
+}
+
+// Main App component
+const App = () => {
+ // Input/output state
+ const [textInput, setTextInput] = useState(EXAMPLES[0]);
+ const [progress, setProgress] = useState(0);
+ const [loadProgress, setLoadProgress] = useState({});
+ const [statusText, setStatusText] = useState('Loading model (656MB)...');
+ const [result, setResult] = useState(null);
+ const audioRef = useRef(null);
+
+ // Model and tokenizer references
+ const modelPromise = useRef(null);
+ const tokenizerPromise = useRef(null);
+
+ // Generation parameters
+ const [guidance_scale, setGuidanceScale] = useState(3);
+ const [temperature, setTemperature] = useState(1);
+ const [duration, setDuration] = useState(10);
+
+ // Load model and tokenizer on first render
+ useEffect(() => {
+ modelPromise.current ??= MusicgenForConditionalGeneration.from_pretrained(MODEL_ID, {
+ progress_callback: (data) => {
+ if (data.status !== 'progress') return;
+ setLoadProgress(prev => ({ ...prev, [data.file]: data }))
+ },
+ dtype: {
+ text_encoder: 'q8',
+ decoder_model_merged: 'q8',
+ encodec_decode: 'fp32',
+ },
+ device: 'wasm',
+ });
+
+ tokenizerPromise.current ??= AutoTokenizer.from_pretrained(MODEL_ID);
+ }, []);
+
+ // Update progress bar based on load progress
+ useEffect(() => {
+ const items = Object.values(loadProgress);
+ if (items.length !== 5) return; // 5 files to load
+ let loaded = 0;
+ let total = 0;
+ for (const data of Object.values(loadProgress)) {
+ loaded += data.loaded;
+ total += data.total;
+ }
+ const progress = loaded / total;
+ setProgress(progress);
+ setStatusText(progress === 1
+ ? 'Ready!'
+ : `Loading model (${(progress * 100).toFixed()}% of 656MB)...`
+ );
+ }, [loadProgress]);
+
+ // Function to handle generating music
+ const generateMusic = async () => {
+ // Reset audio player and result
+ audioRef.current.src = '';
+ setResult(null);
+
+ // Get model and tokenizer
+ const tokenizer = await tokenizerPromise.current;
+ const model = await modelPromise.current;
+
+ // Get number of tokens to match user-specified duration (more intuitive for user)
+ // 503 tokens -> 10 seconds generated => ~50 tokens per second
+ // https://huggingface.co/docs/transformers/model_doc/musicgen#generation
+ const max_length = Math.min(
+ Math.max(Math.floor(duration * 50), 1) + 4,
+ model.generation_config.max_length ?? 1500,
+ );
+
+ // Create a streamer to update progress
+ let num_tokens = 0;
+ const streamer = new CallbackStreamer((value) => {
+ const percent = value === undefined ? 1 : ++num_tokens / max_length;
+ setStatusText(`Generating (${(percent * 100).toFixed()}%)...`);
+ setProgress(percent);
+ });
+
+ // Tokenize input text
+ const inputs = tokenizer(textInput);
+
+ // Generate music
+ const audio_values = await model.generate({
+ // Inputs
+ ...inputs,
+
+ // Generation parameters
+ max_length,
+ guidance_scale,
+ temperature,
+
+ // Outputs
+ streamer,
+ });
+
+ setStatusText('Encoding audio...');
+
+ // Encode audio values to WAV
+ const sampling_rate = model.config.audio_encoder.sampling_rate;
+ const wav = encodeWAV(audio_values.data, sampling_rate);
+ const blob = new Blob([wav], { type: 'audio/wav' });
+ setResult(blob);
+
+ audioRef.current.src = URL.createObjectURL(blob);
+ setStatusText('Done!');
+ };
+
+ return (
+
+
MusicGen Web
+
+
+ {/* Text input for user */}
+
setTextInput(e.target.value)}
+ className="border border-gray-300 p-2 mb-4 w-full rounded"
+ />
+
+ {/* Example buttons */}
+
+ {EXAMPLES.map((example, i) => (
+ setTextInput(e.target.innerText)}>{example}
+ ))}
+
+
+ {/* Generation parameters */}
+
+ {/* Duration */}
+
+
Duration
+
setDuration(e.target.value)} />
+
{`${duration} second${duration > 1 ? 's' : ''}`}
+
+
+ {/* Guidance Scale */}
+
+
Guidance Scale
+
setGuidanceScale(e.target.value)} />
+
{guidance_scale}
+
+
+ {/* Temperature */}
+
+
Temperature
+
setTemperature(e.target.value)} />
+
{temperature}
+
+
+
+ {/* Button to generate music */}
+
Generate Music
+
+ {/* Progress bar */}
+
+
+ {/* Audio player */}
+ {
+
+ {SHARING_ENABLED && result &&
+
{
+ e.target.disabled = true;
+ e.target.innerText = 'Uploading...';
+ await share(result, {
+ prompt: textInput,
+ duration,
+ guidance_scale,
+ temperature,
+ });
+ e.target.disabled = false;
+ e.target.innerText = 'Share';
+ }
+ }>Share
+ }
+
}
+
+ );
+};
+
+export default App;
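
The only non-obvious arithmetic in the component is the mapping from the duration slider to `max_length`: roughly 50 tokens per second of audio, plus 4 special tokens, capped by the model's `generation_config.max_length` (1500 is the fallback used above). A small sketch of just that mapping, with a couple of worked values:

```js
// Duration -> token budget, mirroring the calculation in generateMusic() above.
function durationToMaxLength(durationSeconds, modelMaxLength = 1500) {
  return Math.min(Math.max(Math.floor(durationSeconds * 50), 1) + 4, modelMaxLength);
}

console.log(durationToMaxLength(10)); // 504  -> roughly 10 seconds of audio
console.log(durationToMaxLength(60)); // 1500 -> capped, i.e. about 30 seconds at most
```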
diff --git a/examples/musicgen-web/src/index.css b/examples/musicgen-web/src/index.css
new file mode 100644
index 000000000..bd6213e1d
--- /dev/null
+++ b/examples/musicgen-web/src/index.css
@@ -0,0 +1,3 @@
+@tailwind base;
+@tailwind components;
+@tailwind utilities;
\ No newline at end of file
diff --git a/examples/musicgen-web/src/main.jsx b/examples/musicgen-web/src/main.jsx
new file mode 100644
index 000000000..54b39dd1d
--- /dev/null
+++ b/examples/musicgen-web/src/main.jsx
@@ -0,0 +1,10 @@
+import React from 'react'
+import ReactDOM from 'react-dom/client'
+import App from './App.jsx'
+import './index.css'
+
+ReactDOM.createRoot(document.getElementById('root')).render(
+  <React.StrictMode>
+    <App />
+  </React.StrictMode>,
+)
diff --git a/examples/musicgen-web/src/utils.js b/examples/musicgen-web/src/utils.js
new file mode 100644
index 000000000..436c9daab
--- /dev/null
+++ b/examples/musicgen-web/src/utils.js
@@ -0,0 +1,59 @@
+
+// Adapted from https://www.npmjs.com/package/audiobuffer-to-wav
+export function encodeWAV(samples, sampleRate = 16000) {
+ let offset = 44;
+ const buffer = new ArrayBuffer(offset + samples.length * 4);
+ const view = new DataView(buffer);
+
+ /* RIFF identifier */
+ writeString(view, 0, 'RIFF')
+ /* RIFF chunk length */
+ view.setUint32(4, 36 + samples.length * 4, true)
+ /* RIFF type */
+ writeString(view, 8, 'WAVE')
+ /* format chunk identifier */
+ writeString(view, 12, 'fmt ')
+ /* format chunk length */
+ view.setUint32(16, 16, true)
+ /* sample format (raw) */
+ view.setUint16(20, 3, true)
+ /* channel count */
+ view.setUint16(22, 1, true)
+ /* sample rate */
+ view.setUint32(24, sampleRate, true)
+ /* byte rate (sample rate * block align) */
+ view.setUint32(28, sampleRate * 4, true)
+ /* block align (channel count * bytes per sample) */
+ view.setUint16(32, 4, true)
+ /* bits per sample */
+ view.setUint16(34, 32, true)
+ /* data chunk identifier */
+ writeString(view, 36, 'data')
+ /* data chunk length */
+ view.setUint32(40, samples.length * 4, true)
+
+ for (let i = 0; i < samples.length; ++i, offset += 4) {
+ view.setFloat32(offset, samples[i], true)
+ }
+
+ return buffer
+}
+function writeString(view, offset, string) {
+ for (let i = 0; i < string.length; ++i) {
+ view.setUint8(offset + i, string.charCodeAt(i))
+ }
+}
+
+export async function share(body, settings) {
+ const response = await fetch('https://huggingface.co/uploads', { method: 'POST', body });
+ if (!response.ok) throw new Error(`Failed to upload audio: ${response.statusText}`);
+ const url = await response.text();
+
+ const params = new URLSearchParams({
+ title: `🎵 ${settings.prompt}`,
+ description: ` \n${JSON.stringify(settings, null, 2)}`,
+ });
+
+ const shareURL = `https://huggingface.co/spaces/Xenova/musicgen-web/discussions/new?${params.toString()}`;
+ window.open(shareURL, '_blank');
+}
\ No newline at end of file
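
`encodeWAV` writes a 44-byte RIFF/WAVE header for mono, 32-bit float PCM followed by the raw samples, which is why the buffer is `44 + samples.length * 4` bytes. A usage sketch mirroring `App.jsx`; the 32 kHz sampling rate and the silent samples are placeholder assumptions, standing in for `model.config.audio_encoder.sampling_rate` and the output of `model.generate`:

```js
import { encodeWAV } from './utils.js';

const sampling_rate = 32000;                         // assumed; real value comes from the model config
const samples = new Float32Array(sampling_rate * 2); // 2 seconds of silence as dummy data
const wav = encodeWAV(samples, sampling_rate);       // ArrayBuffer: 44-byte header + float32 PCM
const blob = new Blob([wav], { type: 'audio/wav' });
const url = URL.createObjectURL(blob);               // can be assigned to an <audio> element's src
```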
diff --git a/examples/musicgen-web/tailwind.config.js b/examples/musicgen-web/tailwind.config.js
new file mode 100644
index 000000000..d37737fc0
--- /dev/null
+++ b/examples/musicgen-web/tailwind.config.js
@@ -0,0 +1,12 @@
+/** @type {import('tailwindcss').Config} */
+export default {
+ content: [
+ "./index.html",
+ "./src/**/*.{js,ts,jsx,tsx}",
+ ],
+ theme: {
+ extend: {},
+ },
+ plugins: [],
+}
+
diff --git a/examples/musicgen-web/vite.config.js b/examples/musicgen-web/vite.config.js
new file mode 100644
index 000000000..5a33944a9
--- /dev/null
+++ b/examples/musicgen-web/vite.config.js
@@ -0,0 +1,7 @@
+import { defineConfig } from 'vite'
+import react from '@vitejs/plugin-react'
+
+// https://vitejs.dev/config/
+export default defineConfig({
+ plugins: [react()],
+})
diff --git a/examples/next-client/package-lock.json b/examples/next-client/package-lock.json
index 2a76aef90..d03dc91e7 100644
--- a/examples/next-client/package-lock.json
+++ b/examples/next-client/package-lock.json
@@ -8,7 +8,7 @@
"name": "next",
"version": "0.1.0",
"dependencies": {
- "@xenova/transformers": "^2.4.2",
+ "@huggingface/transformers": "^3.0.0-alpha.5",
"autoprefixer": "10.4.14",
"eslint": "8.45.0",
"eslint-config-next": "13.4.12",
@@ -49,6 +49,15 @@
"node": ">=6.9.0"
}
},
+ "node_modules/@emnapi/runtime": {
+ "version": "1.2.0",
+ "resolved": "https://registry.npmjs.org/@emnapi/runtime/-/runtime-1.2.0.tgz",
+ "integrity": "sha512-bV21/9LQmcQeCPEg3BDFtvwL6cwiTMksYNWQQ4KOxCZikEGalWtenoZ0wCiukJINlGCIi2KXx01g4FoH/LxpzQ==",
+ "optional": true,
+ "dependencies": {
+ "tslib": "^2.4.0"
+ }
+ },
"node_modules/@eslint-community/eslint-utils": {
"version": "4.4.0",
"resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.4.0.tgz",
@@ -57,79 +66,663 @@
"eslint-visitor-keys": "^3.3.0"
},
"engines": {
- "node": "^12.22.0 || ^14.17.0 || >=16.0.0"
+ "node": "^12.22.0 || ^14.17.0 || >=16.0.0"
+ },
+ "peerDependencies": {
+ "eslint": "^6.0.0 || ^7.0.0 || >=8.0.0"
+ }
+ },
+ "node_modules/@eslint-community/regexpp": {
+ "version": "4.6.0",
+ "resolved": "https://registry.npmjs.org/@eslint-community/regexpp/-/regexpp-4.6.0.tgz",
+ "integrity": "sha512-uiPeRISaglZnaZk8vwrjQZ1CxogZeY/4IYft6gBOTqu1WhVXWmCmZMWxUv2Q/pxSvPdp1JPaO62kLOcOkMqWrw==",
+ "engines": {
+ "node": "^12.0.0 || ^14.0.0 || >=16.0.0"
+ }
+ },
+ "node_modules/@eslint/eslintrc": {
+ "version": "2.1.0",
+ "resolved": "https://registry.npmjs.org/@eslint/eslintrc/-/eslintrc-2.1.0.tgz",
+ "integrity": "sha512-Lj7DECXqIVCqnqjjHMPna4vn6GJcMgul/wuS0je9OZ9gsL0zzDpKPVtcG1HaDVc+9y+qgXneTeUMbCqXJNpH1A==",
+ "dependencies": {
+ "ajv": "^6.12.4",
+ "debug": "^4.3.2",
+ "espree": "^9.6.0",
+ "globals": "^13.19.0",
+ "ignore": "^5.2.0",
+ "import-fresh": "^3.2.1",
+ "js-yaml": "^4.1.0",
+ "minimatch": "^3.1.2",
+ "strip-json-comments": "^3.1.1"
+ },
+ "engines": {
+ "node": "^12.22.0 || ^14.17.0 || >=16.0.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/eslint"
+ }
+ },
+ "node_modules/@eslint/js": {
+ "version": "8.44.0",
+ "resolved": "https://registry.npmjs.org/@eslint/js/-/js-8.44.0.tgz",
+ "integrity": "sha512-Ag+9YM4ocKQx9AarydN0KY2j0ErMHNIocPDrVo8zAE44xLTjEtz81OdR68/cydGtk6m6jDb5Za3r2useMzYmSw==",
+ "engines": {
+ "node": "^12.22.0 || ^14.17.0 || >=16.0.0"
+ }
+ },
+ "node_modules/@huggingface/jinja": {
+ "version": "0.3.0",
+ "resolved": "https://registry.npmjs.org/@huggingface/jinja/-/jinja-0.3.0.tgz",
+ "integrity": "sha512-GLJzso0M07ZncFkrJMIXVU4os6GFbPocD4g8fMQPMGJubf48FtGOsUORH2rtFdXPIPelz8SLBMn8ZRmOTwXm9Q==",
+ "engines": {
+ "node": ">=18"
+ }
+ },
+ "node_modules/@huggingface/transformers": {
+ "version": "3.0.0-alpha.5",
+ "resolved": "https://registry.npmjs.org/@huggingface/transformers/-/transformers-3.0.0-alpha.5.tgz",
+ "integrity": "sha512-GFJ3YfOq+Ax1LvDECOhvLay0sqCbkE1q3roloRYrYoflOUY+YX1A5ez+hfmDyN65blC7eFf4UQ9yWHmyKBkBiw==",
+ "dependencies": {
+ "@huggingface/jinja": "^0.3.0",
+ "onnxruntime-node": "1.18.0",
+ "onnxruntime-web": "1.19.0-dev.20240804-ee2fe87e2d",
+ "sharp": "^0.33.2"
+ }
+ },
+ "node_modules/@huggingface/transformers/node_modules/long": {
+ "version": "5.2.3",
+ "resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz",
+ "integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q=="
+ },
+ "node_modules/@huggingface/transformers/node_modules/onnxruntime-common": {
+ "version": "1.18.0",
+ "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.18.0.tgz",
+ "integrity": "sha512-lufrSzX6QdKrktAELG5x5VkBpapbCeS3dQwrXbN0eD9rHvU0yAWl7Ztju9FvgAKWvwd/teEKJNj3OwM6eTZh3Q=="
+ },
+ "node_modules/@huggingface/transformers/node_modules/onnxruntime-node": {
+ "version": "1.18.0",
+ "resolved": "https://registry.npmjs.org/onnxruntime-node/-/onnxruntime-node-1.18.0.tgz",
+ "integrity": "sha512-iTnFcxKpmywCatx8ov4GTbECe3tJk2Bp1OA2mWRJde78q+7tpPYBhKMnwhlaoKy9oKQcy4UoEuuhoy2PSD13ww==",
+ "hasInstallScript": true,
+ "os": [
+ "win32",
+ "darwin",
+ "linux"
+ ],
+ "dependencies": {
+ "onnxruntime-common": "1.18.0",
+ "tar": "^7.0.1"
+ }
+ },
+ "node_modules/@huggingface/transformers/node_modules/onnxruntime-web": {
+ "version": "1.19.0-dev.20240804-ee2fe87e2d",
+ "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.19.0-dev.20240804-ee2fe87e2d.tgz",
+ "integrity": "sha512-uz93GKeBjHHq0150qIAxGGMhf5YLnfh12OChvYyLG2H6LzXymXhorvcxV7sklofw6fVooL3IutMz8nbZLMQxYg==",
+ "dependencies": {
+ "flatbuffers": "^1.12.0",
+ "guid-typescript": "^1.0.9",
+ "long": "^5.2.3",
+ "onnxruntime-common": "1.19.0-dev.20240730-530a2d7b41",
+ "platform": "^1.3.6",
+ "protobufjs": "^7.2.4"
+ }
+ },
+ "node_modules/@huggingface/transformers/node_modules/onnxruntime-web/node_modules/onnxruntime-common": {
+ "version": "1.19.0-dev.20240730-530a2d7b41",
+ "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.19.0-dev.20240730-530a2d7b41.tgz",
+ "integrity": "sha512-fWyg0USjvdHY5JL+3y/fXUDTOl9OLfhrX+sttfM2LW7jT/O8VNxjc16oAjyJHJruOQdrH2qo+KnxjOLA68i2dw=="
+ },
+ "node_modules/@huggingface/transformers/node_modules/sharp": {
+ "version": "0.33.4",
+ "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.33.4.tgz",
+ "integrity": "sha512-7i/dt5kGl7qR4gwPRD2biwD2/SvBn3O04J77XKFgL2OnZtQw+AG9wnuS/csmu80nPRHLYE9E41fyEiG8nhH6/Q==",
+ "hasInstallScript": true,
+ "dependencies": {
+ "color": "^4.2.3",
+ "detect-libc": "^2.0.3",
+ "semver": "^7.6.0"
+ },
+ "engines": {
+ "libvips": ">=8.15.2",
+ "node": "^18.17.0 || ^20.3.0 || >=21.0.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ },
+ "optionalDependencies": {
+ "@img/sharp-darwin-arm64": "0.33.4",
+ "@img/sharp-darwin-x64": "0.33.4",
+ "@img/sharp-libvips-darwin-arm64": "1.0.2",
+ "@img/sharp-libvips-darwin-x64": "1.0.2",
+ "@img/sharp-libvips-linux-arm": "1.0.2",
+ "@img/sharp-libvips-linux-arm64": "1.0.2",
+ "@img/sharp-libvips-linux-s390x": "1.0.2",
+ "@img/sharp-libvips-linux-x64": "1.0.2",
+ "@img/sharp-libvips-linuxmusl-arm64": "1.0.2",
+ "@img/sharp-libvips-linuxmusl-x64": "1.0.2",
+ "@img/sharp-linux-arm": "0.33.4",
+ "@img/sharp-linux-arm64": "0.33.4",
+ "@img/sharp-linux-s390x": "0.33.4",
+ "@img/sharp-linux-x64": "0.33.4",
+ "@img/sharp-linuxmusl-arm64": "0.33.4",
+ "@img/sharp-linuxmusl-x64": "0.33.4",
+ "@img/sharp-wasm32": "0.33.4",
+ "@img/sharp-win32-ia32": "0.33.4",
+ "@img/sharp-win32-x64": "0.33.4"
+ }
+ },
+ "node_modules/@humanwhocodes/config-array": {
+ "version": "0.11.10",
+ "resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.10.tgz",
+ "integrity": "sha512-KVVjQmNUepDVGXNuoRRdmmEjruj0KfiGSbS8LVc12LMsWDQzRXJ0qdhN8L8uUigKpfEHRhlaQFY0ib1tnUbNeQ==",
+ "dependencies": {
+ "@humanwhocodes/object-schema": "^1.2.1",
+ "debug": "^4.1.1",
+ "minimatch": "^3.0.5"
+ },
+ "engines": {
+ "node": ">=10.10.0"
+ }
+ },
+ "node_modules/@humanwhocodes/module-importer": {
+ "version": "1.0.1",
+ "resolved": "https://registry.npmjs.org/@humanwhocodes/module-importer/-/module-importer-1.0.1.tgz",
+ "integrity": "sha512-bxveV4V8v5Yb4ncFTT3rPSgZBOpCkjfK0y4oVVVJwIuDVBRMDXrPyXRL988i5ap9m9bnyEEjWfm5WkBmtffLfA==",
+ "engines": {
+ "node": ">=12.22"
+ },
+ "funding": {
+ "type": "github",
+ "url": "https://github.com/sponsors/nzakas"
+ }
+ },
+ "node_modules/@humanwhocodes/object-schema": {
+ "version": "1.2.1",
+ "resolved": "https://registry.npmjs.org/@humanwhocodes/object-schema/-/object-schema-1.2.1.tgz",
+ "integrity": "sha512-ZnQMnLV4e7hDlUvw8H+U8ASL02SS2Gn6+9Ac3wGGLIe7+je2AeAOxPY+izIPJDfFDb7eDjev0Us8MO1iFRN8hA=="
+ },
+ "node_modules/@img/sharp-darwin-arm64": {
+ "version": "0.33.4",
+ "resolved": "https://registry.npmjs.org/@img/sharp-darwin-arm64/-/sharp-darwin-arm64-0.33.4.tgz",
+ "integrity": "sha512-p0suNqXufJs9t3RqLBO6vvrgr5OhgbWp76s5gTRvdmxmuv9E1rcaqGUsl3l4mKVmXPkTkTErXediAui4x+8PSA==",
+ "cpu": [
+ "arm64"
+ ],
+ "optional": true,
+ "os": [
+ "darwin"
+ ],
+ "engines": {
+ "glibc": ">=2.26",
+ "node": "^18.17.0 || ^20.3.0 || >=21.0.0",
+ "npm": ">=9.6.5",
+ "pnpm": ">=7.1.0",
+ "yarn": ">=3.2.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ },
+ "optionalDependencies": {
+ "@img/sharp-libvips-darwin-arm64": "1.0.2"
+ }
+ },
+ "node_modules/@img/sharp-darwin-x64": {
+ "version": "0.33.4",
+ "resolved": "https://registry.npmjs.org/@img/sharp-darwin-x64/-/sharp-darwin-x64-0.33.4.tgz",
+ "integrity": "sha512-0l7yRObwtTi82Z6ebVI2PnHT8EB2NxBgpK2MiKJZJ7cz32R4lxd001ecMhzzsZig3Yv9oclvqqdV93jo9hy+Dw==",
+ "cpu": [
+ "x64"
+ ],
+ "optional": true,
+ "os": [
+ "darwin"
+ ],
+ "engines": {
+ "glibc": ">=2.26",
+ "node": "^18.17.0 || ^20.3.0 || >=21.0.0",
+ "npm": ">=9.6.5",
+ "pnpm": ">=7.1.0",
+ "yarn": ">=3.2.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ },
+ "optionalDependencies": {
+ "@img/sharp-libvips-darwin-x64": "1.0.2"
+ }
+ },
+ "node_modules/@img/sharp-libvips-darwin-arm64": {
+ "version": "1.0.2",
+ "resolved": "https://registry.npmjs.org/@img/sharp-libvips-darwin-arm64/-/sharp-libvips-darwin-arm64-1.0.2.tgz",
+ "integrity": "sha512-tcK/41Rq8IKlSaKRCCAuuY3lDJjQnYIW1UXU1kxcEKrfL8WR7N6+rzNoOxoQRJWTAECuKwgAHnPvqXGN8XfkHA==",
+ "cpu": [
+ "arm64"
+ ],
+ "optional": true,
+ "os": [
+ "darwin"
+ ],
+ "engines": {
+ "macos": ">=11",
+ "npm": ">=9.6.5",
+ "pnpm": ">=7.1.0",
+ "yarn": ">=3.2.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ }
+ },
+ "node_modules/@img/sharp-libvips-darwin-x64": {
+ "version": "1.0.2",
+ "resolved": "https://registry.npmjs.org/@img/sharp-libvips-darwin-x64/-/sharp-libvips-darwin-x64-1.0.2.tgz",
+ "integrity": "sha512-Ofw+7oaWa0HiiMiKWqqaZbaYV3/UGL2wAPeLuJTx+9cXpCRdvQhCLG0IH8YGwM0yGWGLpsF4Su9vM1o6aer+Fw==",
+ "cpu": [
+ "x64"
+ ],
+ "optional": true,
+ "os": [
+ "darwin"
+ ],
+ "engines": {
+ "macos": ">=10.13",
+ "npm": ">=9.6.5",
+ "pnpm": ">=7.1.0",
+ "yarn": ">=3.2.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ }
+ },
+ "node_modules/@img/sharp-libvips-linux-arm": {
+ "version": "1.0.2",
+ "resolved": "https://registry.npmjs.org/@img/sharp-libvips-linux-arm/-/sharp-libvips-linux-arm-1.0.2.tgz",
+ "integrity": "sha512-iLWCvrKgeFoglQxdEwzu1eQV04o8YeYGFXtfWU26Zr2wWT3q3MTzC+QTCO3ZQfWd3doKHT4Pm2kRmLbupT+sZw==",
+ "cpu": [
+ "arm"
+ ],
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "engines": {
+ "glibc": ">=2.28",
+ "npm": ">=9.6.5",
+ "pnpm": ">=7.1.0",
+ "yarn": ">=3.2.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ }
+ },
+ "node_modules/@img/sharp-libvips-linux-arm64": {
+ "version": "1.0.2",
+ "resolved": "https://registry.npmjs.org/@img/sharp-libvips-linux-arm64/-/sharp-libvips-linux-arm64-1.0.2.tgz",
+ "integrity": "sha512-x7kCt3N00ofFmmkkdshwj3vGPCnmiDh7Gwnd4nUwZln2YjqPxV1NlTyZOvoDWdKQVDL911487HOueBvrpflagw==",
+ "cpu": [
+ "arm64"
+ ],
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "engines": {
+ "glibc": ">=2.26",
+ "npm": ">=9.6.5",
+ "pnpm": ">=7.1.0",
+ "yarn": ">=3.2.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ }
+ },
+ "node_modules/@img/sharp-libvips-linux-s390x": {
+ "version": "1.0.2",
+ "resolved": "https://registry.npmjs.org/@img/sharp-libvips-linux-s390x/-/sharp-libvips-linux-s390x-1.0.2.tgz",
+ "integrity": "sha512-cmhQ1J4qVhfmS6szYW7RT+gLJq9dH2i4maq+qyXayUSn9/3iY2ZeWpbAgSpSVbV2E1JUL2Gg7pwnYQ1h8rQIog==",
+ "cpu": [
+ "s390x"
+ ],
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "engines": {
+ "glibc": ">=2.28",
+ "npm": ">=9.6.5",
+ "pnpm": ">=7.1.0",
+ "yarn": ">=3.2.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ }
+ },
+ "node_modules/@img/sharp-libvips-linux-x64": {
+ "version": "1.0.2",
+ "resolved": "https://registry.npmjs.org/@img/sharp-libvips-linux-x64/-/sharp-libvips-linux-x64-1.0.2.tgz",
+ "integrity": "sha512-E441q4Qdb+7yuyiADVi5J+44x8ctlrqn8XgkDTwr4qPJzWkaHwD489iZ4nGDgcuya4iMN3ULV6NwbhRZJ9Z7SQ==",
+ "cpu": [
+ "x64"
+ ],
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "engines": {
+ "glibc": ">=2.26",
+ "npm": ">=9.6.5",
+ "pnpm": ">=7.1.0",
+ "yarn": ">=3.2.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ }
+ },
+ "node_modules/@img/sharp-libvips-linuxmusl-arm64": {
+ "version": "1.0.2",
+ "resolved": "https://registry.npmjs.org/@img/sharp-libvips-linuxmusl-arm64/-/sharp-libvips-linuxmusl-arm64-1.0.2.tgz",
+ "integrity": "sha512-3CAkndNpYUrlDqkCM5qhksfE+qSIREVpyoeHIU6jd48SJZViAmznoQQLAv4hVXF7xyUB9zf+G++e2v1ABjCbEQ==",
+ "cpu": [
+ "arm64"
+ ],
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "engines": {
+ "musl": ">=1.2.2",
+ "npm": ">=9.6.5",
+ "pnpm": ">=7.1.0",
+ "yarn": ">=3.2.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ }
+ },
+ "node_modules/@img/sharp-libvips-linuxmusl-x64": {
+ "version": "1.0.2",
+ "resolved": "https://registry.npmjs.org/@img/sharp-libvips-linuxmusl-x64/-/sharp-libvips-linuxmusl-x64-1.0.2.tgz",
+ "integrity": "sha512-VI94Q6khIHqHWNOh6LLdm9s2Ry4zdjWJwH56WoiJU7NTeDwyApdZZ8c+SADC8OH98KWNQXnE01UdJ9CSfZvwZw==",
+ "cpu": [
+ "x64"
+ ],
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "engines": {
+ "musl": ">=1.2.2",
+ "npm": ">=9.6.5",
+ "pnpm": ">=7.1.0",
+ "yarn": ">=3.2.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ }
+ },
+ "node_modules/@img/sharp-linux-arm": {
+ "version": "0.33.4",
+ "resolved": "https://registry.npmjs.org/@img/sharp-linux-arm/-/sharp-linux-arm-0.33.4.tgz",
+ "integrity": "sha512-RUgBD1c0+gCYZGCCe6mMdTiOFS0Zc/XrN0fYd6hISIKcDUbAW5NtSQW9g/powkrXYm6Vzwd6y+fqmExDuCdHNQ==",
+ "cpu": [
+ "arm"
+ ],
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "engines": {
+ "glibc": ">=2.28",
+ "node": "^18.17.0 || ^20.3.0 || >=21.0.0",
+ "npm": ">=9.6.5",
+ "pnpm": ">=7.1.0",
+ "yarn": ">=3.2.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ },
+ "optionalDependencies": {
+ "@img/sharp-libvips-linux-arm": "1.0.2"
+ }
+ },
+ "node_modules/@img/sharp-linux-arm64": {
+ "version": "0.33.4",
+ "resolved": "https://registry.npmjs.org/@img/sharp-linux-arm64/-/sharp-linux-arm64-0.33.4.tgz",
+ "integrity": "sha512-2800clwVg1ZQtxwSoTlHvtm9ObgAax7V6MTAB/hDT945Tfyy3hVkmiHpeLPCKYqYR1Gcmv1uDZ3a4OFwkdBL7Q==",
+ "cpu": [
+ "arm64"
+ ],
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "engines": {
+ "glibc": ">=2.26",
+ "node": "^18.17.0 || ^20.3.0 || >=21.0.0",
+ "npm": ">=9.6.5",
+ "pnpm": ">=7.1.0",
+ "yarn": ">=3.2.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ },
+ "optionalDependencies": {
+ "@img/sharp-libvips-linux-arm64": "1.0.2"
+ }
+ },
+ "node_modules/@img/sharp-linux-s390x": {
+ "version": "0.33.4",
+ "resolved": "https://registry.npmjs.org/@img/sharp-linux-s390x/-/sharp-linux-s390x-0.33.4.tgz",
+ "integrity": "sha512-h3RAL3siQoyzSoH36tUeS0PDmb5wINKGYzcLB5C6DIiAn2F3udeFAum+gj8IbA/82+8RGCTn7XW8WTFnqag4tQ==",
+ "cpu": [
+ "s390x"
+ ],
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "engines": {
+ "glibc": ">=2.31",
+ "node": "^18.17.0 || ^20.3.0 || >=21.0.0",
+ "npm": ">=9.6.5",
+ "pnpm": ">=7.1.0",
+ "yarn": ">=3.2.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ },
+ "optionalDependencies": {
+ "@img/sharp-libvips-linux-s390x": "1.0.2"
+ }
+ },
+ "node_modules/@img/sharp-linux-x64": {
+ "version": "0.33.4",
+ "resolved": "https://registry.npmjs.org/@img/sharp-linux-x64/-/sharp-linux-x64-0.33.4.tgz",
+ "integrity": "sha512-GoR++s0XW9DGVi8SUGQ/U4AeIzLdNjHka6jidVwapQ/JebGVQIpi52OdyxCNVRE++n1FCLzjDovJNozif7w/Aw==",
+ "cpu": [
+ "x64"
+ ],
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "engines": {
+ "glibc": ">=2.26",
+ "node": "^18.17.0 || ^20.3.0 || >=21.0.0",
+ "npm": ">=9.6.5",
+ "pnpm": ">=7.1.0",
+ "yarn": ">=3.2.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ },
+ "optionalDependencies": {
+ "@img/sharp-libvips-linux-x64": "1.0.2"
+ }
+ },
+ "node_modules/@img/sharp-linuxmusl-arm64": {
+ "version": "0.33.4",
+ "resolved": "https://registry.npmjs.org/@img/sharp-linuxmusl-arm64/-/sharp-linuxmusl-arm64-0.33.4.tgz",
+ "integrity": "sha512-nhr1yC3BlVrKDTl6cO12gTpXMl4ITBUZieehFvMntlCXFzH2bvKG76tBL2Y/OqhupZt81pR7R+Q5YhJxW0rGgQ==",
+ "cpu": [
+ "arm64"
+ ],
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "engines": {
+ "musl": ">=1.2.2",
+ "node": "^18.17.0 || ^20.3.0 || >=21.0.0",
+ "npm": ">=9.6.5",
+ "pnpm": ">=7.1.0",
+ "yarn": ">=3.2.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ },
+ "optionalDependencies": {
+ "@img/sharp-libvips-linuxmusl-arm64": "1.0.2"
+ }
+ },
+ "node_modules/@img/sharp-linuxmusl-x64": {
+ "version": "0.33.4",
+ "resolved": "https://registry.npmjs.org/@img/sharp-linuxmusl-x64/-/sharp-linuxmusl-x64-0.33.4.tgz",
+ "integrity": "sha512-uCPTku0zwqDmZEOi4ILyGdmW76tH7dm8kKlOIV1XC5cLyJ71ENAAqarOHQh0RLfpIpbV5KOpXzdU6XkJtS0daw==",
+ "cpu": [
+ "x64"
+ ],
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "engines": {
+ "musl": ">=1.2.2",
+ "node": "^18.17.0 || ^20.3.0 || >=21.0.0",
+ "npm": ">=9.6.5",
+ "pnpm": ">=7.1.0",
+ "yarn": ">=3.2.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ },
+ "optionalDependencies": {
+ "@img/sharp-libvips-linuxmusl-x64": "1.0.2"
+ }
+ },
+ "node_modules/@img/sharp-wasm32": {
+ "version": "0.33.4",
+ "resolved": "https://registry.npmjs.org/@img/sharp-wasm32/-/sharp-wasm32-0.33.4.tgz",
+ "integrity": "sha512-Bmmauh4sXUsUqkleQahpdNXKvo+wa1V9KhT2pDA4VJGKwnKMJXiSTGphn0gnJrlooda0QxCtXc6RX1XAU6hMnQ==",
+ "cpu": [
+ "wasm32"
+ ],
+ "optional": true,
+ "dependencies": {
+ "@emnapi/runtime": "^1.1.1"
+ },
+ "engines": {
+ "node": "^18.17.0 || ^20.3.0 || >=21.0.0",
+ "npm": ">=9.6.5",
+ "pnpm": ">=7.1.0",
+ "yarn": ">=3.2.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ }
+ },
+ "node_modules/@img/sharp-win32-ia32": {
+ "version": "0.33.4",
+ "resolved": "https://registry.npmjs.org/@img/sharp-win32-ia32/-/sharp-win32-ia32-0.33.4.tgz",
+ "integrity": "sha512-99SJ91XzUhYHbx7uhK3+9Lf7+LjwMGQZMDlO/E/YVJ7Nc3lyDFZPGhjwiYdctoH2BOzW9+TnfqcaMKt0jHLdqw==",
+ "cpu": [
+ "ia32"
+ ],
+ "optional": true,
+ "os": [
+ "win32"
+ ],
+ "engines": {
+ "node": "^18.17.0 || ^20.3.0 || >=21.0.0",
+ "npm": ">=9.6.5",
+ "pnpm": ">=7.1.0",
+ "yarn": ">=3.2.0"
},
- "peerDependencies": {
- "eslint": "^6.0.0 || ^7.0.0 || >=8.0.0"
+ "funding": {
+ "url": "https://opencollective.com/libvips"
}
},
- "node_modules/@eslint-community/regexpp": {
- "version": "4.6.0",
- "resolved": "https://registry.npmjs.org/@eslint-community/regexpp/-/regexpp-4.6.0.tgz",
- "integrity": "sha512-uiPeRISaglZnaZk8vwrjQZ1CxogZeY/4IYft6gBOTqu1WhVXWmCmZMWxUv2Q/pxSvPdp1JPaO62kLOcOkMqWrw==",
+ "node_modules/@img/sharp-win32-x64": {
+ "version": "0.33.4",
+ "resolved": "https://registry.npmjs.org/@img/sharp-win32-x64/-/sharp-win32-x64-0.33.4.tgz",
+ "integrity": "sha512-3QLocdTRVIrFNye5YocZl+KKpYKP+fksi1QhmOArgx7GyhIbQp/WrJRu176jm8IxromS7RIkzMiMINVdBtC8Aw==",
+ "cpu": [
+ "x64"
+ ],
+ "optional": true,
+ "os": [
+ "win32"
+ ],
"engines": {
- "node": "^12.0.0 || ^14.0.0 || >=16.0.0"
+ "node": "^18.17.0 || ^20.3.0 || >=21.0.0",
+ "npm": ">=9.6.5",
+ "pnpm": ">=7.1.0",
+ "yarn": ">=3.2.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
}
},
- "node_modules/@eslint/eslintrc": {
- "version": "2.1.0",
- "resolved": "https://registry.npmjs.org/@eslint/eslintrc/-/eslintrc-2.1.0.tgz",
- "integrity": "sha512-Lj7DECXqIVCqnqjjHMPna4vn6GJcMgul/wuS0je9OZ9gsL0zzDpKPVtcG1HaDVc+9y+qgXneTeUMbCqXJNpH1A==",
+ "node_modules/@isaacs/cliui": {
+ "version": "8.0.2",
+ "resolved": "https://registry.npmjs.org/@isaacs/cliui/-/cliui-8.0.2.tgz",
+ "integrity": "sha512-O8jcjabXaleOG9DQ0+ARXWZBTfnP4WNAqzuiJK7ll44AmxGKv/J2M4TPjxjY3znBCfvBXFzucm1twdyFybFqEA==",
"dependencies": {
- "ajv": "^6.12.4",
- "debug": "^4.3.2",
- "espree": "^9.6.0",
- "globals": "^13.19.0",
- "ignore": "^5.2.0",
- "import-fresh": "^3.2.1",
- "js-yaml": "^4.1.0",
- "minimatch": "^3.1.2",
- "strip-json-comments": "^3.1.1"
+ "string-width": "^5.1.2",
+ "string-width-cjs": "npm:string-width@^4.2.0",
+ "strip-ansi": "^7.0.1",
+ "strip-ansi-cjs": "npm:strip-ansi@^6.0.1",
+ "wrap-ansi": "^8.1.0",
+ "wrap-ansi-cjs": "npm:wrap-ansi@^7.0.0"
},
"engines": {
- "node": "^12.22.0 || ^14.17.0 || >=16.0.0"
- },
- "funding": {
- "url": "https://opencollective.com/eslint"
+ "node": ">=12"
}
},
- "node_modules/@eslint/js": {
- "version": "8.44.0",
- "resolved": "https://registry.npmjs.org/@eslint/js/-/js-8.44.0.tgz",
- "integrity": "sha512-Ag+9YM4ocKQx9AarydN0KY2j0ErMHNIocPDrVo8zAE44xLTjEtz81OdR68/cydGtk6m6jDb5Za3r2useMzYmSw==",
+ "node_modules/@isaacs/cliui/node_modules/ansi-regex": {
+ "version": "6.0.1",
+ "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-6.0.1.tgz",
+ "integrity": "sha512-n5M855fKb2SsfMIiFFoVrABHJC8QtHwVx+mHWP3QcEqBHYienj5dHSgjbxtC0WEZXYt4wcD6zrQElDPhFuZgfA==",
"engines": {
- "node": "^12.22.0 || ^14.17.0 || >=16.0.0"
+ "node": ">=12"
+ },
+ "funding": {
+ "url": "https://github.com/chalk/ansi-regex?sponsor=1"
}
},
- "node_modules/@humanwhocodes/config-array": {
- "version": "0.11.10",
- "resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.10.tgz",
- "integrity": "sha512-KVVjQmNUepDVGXNuoRRdmmEjruj0KfiGSbS8LVc12LMsWDQzRXJ0qdhN8L8uUigKpfEHRhlaQFY0ib1tnUbNeQ==",
+ "node_modules/@isaacs/cliui/node_modules/strip-ansi": {
+ "version": "7.1.0",
+ "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-7.1.0.tgz",
+ "integrity": "sha512-iq6eVVI64nQQTRYq2KtEg2d2uU7LElhTJwsH4YzIHZshxlgZms/wIc4VoDQTlG/IvVIrBKG06CrZnp0qv7hkcQ==",
"dependencies": {
- "@humanwhocodes/object-schema": "^1.2.1",
- "debug": "^4.1.1",
- "minimatch": "^3.0.5"
+ "ansi-regex": "^6.0.1"
},
"engines": {
- "node": ">=10.10.0"
- }
- },
- "node_modules/@humanwhocodes/module-importer": {
- "version": "1.0.1",
- "resolved": "https://registry.npmjs.org/@humanwhocodes/module-importer/-/module-importer-1.0.1.tgz",
- "integrity": "sha512-bxveV4V8v5Yb4ncFTT3rPSgZBOpCkjfK0y4oVVVJwIuDVBRMDXrPyXRL988i5ap9m9bnyEEjWfm5WkBmtffLfA==",
- "engines": {
- "node": ">=12.22"
+ "node": ">=12"
},
"funding": {
- "type": "github",
- "url": "https://github.com/sponsors/nzakas"
+ "url": "https://github.com/chalk/strip-ansi?sponsor=1"
}
},
- "node_modules/@humanwhocodes/object-schema": {
- "version": "1.2.1",
- "resolved": "https://registry.npmjs.org/@humanwhocodes/object-schema/-/object-schema-1.2.1.tgz",
- "integrity": "sha512-ZnQMnLV4e7hDlUvw8H+U8ASL02SS2Gn6+9Ac3wGGLIe7+je2AeAOxPY+izIPJDfFDb7eDjev0Us8MO1iFRN8hA=="
+ "node_modules/@isaacs/fs-minipass": {
+ "version": "4.0.1",
+ "resolved": "https://registry.npmjs.org/@isaacs/fs-minipass/-/fs-minipass-4.0.1.tgz",
+ "integrity": "sha512-wgm9Ehl2jpeqP3zw/7mo3kRHFp5MEDhqAdwy1fTGkHAwnkGOVsgpvQhL8B5n1qlb01jV3n/bI0ZfZp5lWA1k4w==",
+ "dependencies": {
+ "minipass": "^7.0.4"
+ },
+ "engines": {
+ "node": ">=18.0.0"
+ }
},
"node_modules/@jridgewell/gen-mapping": {
"version": "0.3.3",
@@ -359,6 +952,15 @@
"node": ">= 8"
}
},
+ "node_modules/@pkgjs/parseargs": {
+ "version": "0.11.0",
+ "resolved": "https://registry.npmjs.org/@pkgjs/parseargs/-/parseargs-0.11.0.tgz",
+ "integrity": "sha512-+1VkjdD0QBLPodGrJUeqarH8VAIvQODIbwh9XpP5Syisf7YoQgsJKPNFoqqLQlu+VQ/tVSshMR6loPMn8U+dPg==",
+ "optional": true,
+ "engines": {
+ "node": ">=14"
+ }
+ },
"node_modules/@pkgr/utils": {
"version": "2.4.2",
"resolved": "https://registry.npmjs.org/@pkgr/utils/-/utils-2.4.2.tgz",
@@ -557,18 +1159,6 @@
"url": "https://opencollective.com/typescript-eslint"
}
},
- "node_modules/@xenova/transformers": {
- "version": "2.4.2",
- "resolved": "https://registry.npmjs.org/@xenova/transformers/-/transformers-2.4.2.tgz",
- "integrity": "sha512-m1QlvNsic/kQJ1F1N02TpYkIBPwB68hZGljO32EM4mHEw4nKlPoQ/9gZ+oUKkavKC/LqgCnmiNQ8jWfa4Zl5AQ==",
- "dependencies": {
- "onnxruntime-web": "1.14.0",
- "sharp": "^0.32.0"
- },
- "optionalDependencies": {
- "onnxruntime-node": "1.14.0"
- }
- },
"node_modules/acorn": {
"version": "8.10.0",
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.10.0.tgz",
@@ -827,35 +1417,11 @@
"dequal": "^2.0.3"
}
},
- "node_modules/b4a": {
- "version": "1.6.4",
- "resolved": "https://registry.npmjs.org/b4a/-/b4a-1.6.4.tgz",
- "integrity": "sha512-fpWrvyVHEKyeEvbKZTVOeZF3VSKKWtJxFIxX/jaVPf+cLbGUSitjb49pHLqPV2BUNNZ0LcoeEGfE/YCpyDYHIw=="
- },
"node_modules/balanced-match": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz",
"integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw=="
},
- "node_modules/base64-js": {
- "version": "1.5.1",
- "resolved": "https://registry.npmjs.org/base64-js/-/base64-js-1.5.1.tgz",
- "integrity": "sha512-AKpaYlHn8t4SVbOHCy+b5+KKgvR4vrsD8vbvrbiQJps7fKDTkjkDry6ji0rUJjC0kzbNePLwzxq8iypo41qeWA==",
- "funding": [
- {
- "type": "github",
- "url": "https://github.com/sponsors/feross"
- },
- {
- "type": "patreon",
- "url": "https://www.patreon.com/feross"
- },
- {
- "type": "consulting",
- "url": "https://feross.org/support"
- }
- ]
- },
"node_modules/big-integer": {
"version": "1.6.51",
"resolved": "https://registry.npmjs.org/big-integer/-/big-integer-1.6.51.tgz",
@@ -872,16 +1438,6 @@
"node": ">=8"
}
},
- "node_modules/bl": {
- "version": "4.1.0",
- "resolved": "https://registry.npmjs.org/bl/-/bl-4.1.0.tgz",
- "integrity": "sha512-1W07cM9gS6DcLperZfFSj+bWLtaPGSOHWhPiGzXmvVJbRLdG82sH/Kn8EtW1VqWVA54AKf2h5k5BbnIbwF3h6w==",
- "dependencies": {
- "buffer": "^5.5.0",
- "inherits": "^2.0.4",
- "readable-stream": "^3.4.0"
- }
- },
"node_modules/bplist-parser": {
"version": "0.2.0",
"resolved": "https://registry.npmjs.org/bplist-parser/-/bplist-parser-0.2.0.tgz",
@@ -944,29 +1500,6 @@
"node": "^6 || ^7 || ^8 || ^9 || ^10 || ^11 || ^12 || >=13.7"
}
},
- "node_modules/buffer": {
- "version": "5.7.1",
- "resolved": "https://registry.npmjs.org/buffer/-/buffer-5.7.1.tgz",
- "integrity": "sha512-EHcyIPBQ4BSGlvjB16k5KgAJ27CIsHY/2JBmCRReo48y9rQ3MaUzWX3KVlBa4U7MyX02HdVj0K7C3WaB3ju7FQ==",
- "funding": [
- {
- "type": "github",
- "url": "https://github.com/sponsors/feross"
- },
- {
- "type": "patreon",
- "url": "https://www.patreon.com/feross"
- },
- {
- "type": "consulting",
- "url": "https://feross.org/support"
- }
- ],
- "dependencies": {
- "base64-js": "^1.3.1",
- "ieee754": "^1.1.13"
- }
- },
"node_modules/bundle-name": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/bundle-name/-/bundle-name-3.0.0.tgz",
@@ -1091,11 +1624,6 @@
"node": ">= 6"
}
},
- "node_modules/chownr": {
- "version": "1.1.4",
- "resolved": "https://registry.npmjs.org/chownr/-/chownr-1.1.4.tgz",
- "integrity": "sha512-jJ0bqzaylmJtVnNgzTeSOs8DPavpbYgEr/b0YL8/2GO3xJEhInFmhKMUnEJQjZumK7KXGFhUy89PrsJWlakBVg=="
- },
"node_modules/client-only": {
"version": "0.0.1",
"resolved": "https://registry.npmjs.org/client-only/-/client-only-0.0.1.tgz",
@@ -1196,28 +1724,6 @@
}
}
},
- "node_modules/decompress-response": {
- "version": "6.0.0",
- "resolved": "https://registry.npmjs.org/decompress-response/-/decompress-response-6.0.0.tgz",
- "integrity": "sha512-aW35yZM6Bb/4oJlZncMH2LCoZtJXTRxES17vE3hoRiowU2kWHaJKFkSBDnDR+cm9J+9QhXmREyIfv0pji9ejCQ==",
- "dependencies": {
- "mimic-response": "^3.1.0"
- },
- "engines": {
- "node": ">=10"
- },
- "funding": {
- "url": "https://github.com/sponsors/sindresorhus"
- }
- },
- "node_modules/deep-extend": {
- "version": "0.6.0",
- "resolved": "https://registry.npmjs.org/deep-extend/-/deep-extend-0.6.0.tgz",
- "integrity": "sha512-LOHxIOaPYdHlJRtCQfDIVZtfw/ufM8+rVj649RIHzcm/vGwQRXFt6OPqIFWsm2XEMrNIEtWR64sY1LEKD2vAOA==",
- "engines": {
- "node": ">=4.0.0"
- }
- },
"node_modules/deep-is": {
"version": "0.1.4",
"resolved": "https://registry.npmjs.org/deep-is/-/deep-is-0.1.4.tgz",
@@ -1290,9 +1796,9 @@
}
},
"node_modules/detect-libc": {
- "version": "2.0.2",
- "resolved": "https://registry.npmjs.org/detect-libc/-/detect-libc-2.0.2.tgz",
- "integrity": "sha512-UX6sGumvvqSaXgdKGUsgZWqcUyIXZ/vZTrlRT/iobiKhGL0zL4d3osHj3uqllWJK+i+sixDS/3COVEOFbupFyw==",
+ "version": "2.0.3",
+ "resolved": "https://registry.npmjs.org/detect-libc/-/detect-libc-2.0.3.tgz",
+ "integrity": "sha512-bwy0MGW55bG41VqxxypOsdSdGqLwXPI/focwgTYCFMbdUiBAxLg9CFzG08sz2aqzknwiX7Hkl0bQENjg8iLByw==",
"engines": {
"node": ">=8"
}
@@ -1329,6 +1835,11 @@
"node": ">=6.0.0"
}
},
+ "node_modules/eastasianwidth": {
+ "version": "0.2.0",
+ "resolved": "https://registry.npmjs.org/eastasianwidth/-/eastasianwidth-0.2.0.tgz",
+ "integrity": "sha512-I88TYZWc9XiYHRQ4/3c5rjjfgkjhLyW2luGIheGERbNQ6OY7yTybanSpDXZa8y7VUP9YmDcYa+eyq4ca7iLqWA=="
+ },
"node_modules/electron-to-chromium": {
"version": "1.4.468",
"resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.4.468.tgz",
@@ -1339,14 +1850,6 @@
"resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-9.2.2.tgz",
"integrity": "sha512-L18DaJsXSUk2+42pv8mLs5jJT2hqFkFE4j21wOmgbUqsZ2hL72NsUU785g9RXgo3s0ZNgVl42TiHp3ZtOv/Vyg=="
},
- "node_modules/end-of-stream": {
- "version": "1.4.4",
- "resolved": "https://registry.npmjs.org/end-of-stream/-/end-of-stream-1.4.4.tgz",
- "integrity": "sha512-+uw1inIHVPQoaVuHzRyXd21icM+cnt4CzD5rW+NC1wjOUSTOs+Te7FOv7AhN7vS9x/oIyhLP5PR1H+phQAHu5Q==",
- "dependencies": {
- "once": "^1.4.0"
- }
- },
"node_modules/enhanced-resolve": {
"version": "5.15.0",
"resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.15.0.tgz",
@@ -1909,24 +2412,11 @@
"url": "https://github.com/sindresorhus/execa?sponsor=1"
}
},
- "node_modules/expand-template": {
- "version": "2.0.3",
- "resolved": "https://registry.npmjs.org/expand-template/-/expand-template-2.0.3.tgz",
- "integrity": "sha512-XYfuKMvj4O35f/pOXLObndIRvyQ+/+6AhODh+OKWj9S9498pHHn/IMszH+gt0fBCRWMNfk1ZSp5x3AifmnI2vg==",
- "engines": {
- "node": ">=6"
- }
- },
"node_modules/fast-deep-equal": {
"version": "3.1.3",
"resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz",
"integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q=="
},
- "node_modules/fast-fifo": {
- "version": "1.3.0",
- "resolved": "https://registry.npmjs.org/fast-fifo/-/fast-fifo-1.3.0.tgz",
- "integrity": "sha512-IgfweLvEpwyA4WgiQe9Nx6VV2QkML2NkvZnk1oKnIzXgXdWxuhF7zw4DvLTPZJn6PIUneiAXPF24QmoEqHTjyw=="
- },
"node_modules/fast-glob": {
"version": "3.3.1",
"resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.1.tgz",
@@ -2038,6 +2528,32 @@
"is-callable": "^1.1.3"
}
},
+ "node_modules/foreground-child": {
+ "version": "3.3.0",
+ "resolved": "https://registry.npmjs.org/foreground-child/-/foreground-child-3.3.0.tgz",
+ "integrity": "sha512-Ld2g8rrAyMYFXBhEqMz8ZAHBi4J4uS1i/CxGMDnjyFWddMXLVcDp051DZfu+t7+ab7Wv6SMqpWmyFIj5UbfFvg==",
+ "dependencies": {
+ "cross-spawn": "^7.0.0",
+ "signal-exit": "^4.0.1"
+ },
+ "engines": {
+ "node": ">=14"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/isaacs"
+ }
+ },
+ "node_modules/foreground-child/node_modules/signal-exit": {
+ "version": "4.1.0",
+ "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-4.1.0.tgz",
+ "integrity": "sha512-bzyZ1e88w9O1iNJbKnOlvYTrWPDl46O1bG0D3XInv+9tkPrxrN8jUUTiFlDkkmKWgn1M6CfIA13SuGqOa9Korw==",
+ "engines": {
+ "node": ">=14"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/isaacs"
+ }
+ },
"node_modules/fraction.js": {
"version": "4.2.0",
"resolved": "https://registry.npmjs.org/fraction.js/-/fraction.js-4.2.0.tgz",
@@ -2050,11 +2566,6 @@
"url": "https://www.patreon.com/infusion"
}
},
- "node_modules/fs-constants": {
- "version": "1.0.0",
- "resolved": "https://registry.npmjs.org/fs-constants/-/fs-constants-1.0.0.tgz",
- "integrity": "sha512-y6OAwoSIf7FyjMIv94u+b5rdheZEjzR63GTyZJm5qh4Bi+2YgwLCcI/fPFZkL5PSixOt6ZNKm+w+Hfp/Bciwow=="
- },
"node_modules/fs.realpath": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz",
@@ -2154,11 +2665,6 @@
"url": "https://github.com/privatenumber/get-tsconfig?sponsor=1"
}
},
- "node_modules/github-from-package": {
- "version": "0.0.0",
- "resolved": "https://registry.npmjs.org/github-from-package/-/github-from-package-0.0.0.tgz",
- "integrity": "sha512-SyHy3T1v2NUXn29OsWdxmK6RwHD+vkj3v8en8AOBZ1wBQ/hCAQ5bAQTD02kW4W9tUp/3Qh6J8r9EvntiyCmOOw=="
- },
"node_modules/glob": {
"version": "7.1.7",
"resolved": "https://registry.npmjs.org/glob/-/glob-7.1.7.tgz",
@@ -2344,25 +2850,6 @@
"node": ">=14.18.0"
}
},
- "node_modules/ieee754": {
- "version": "1.2.1",
- "resolved": "https://registry.npmjs.org/ieee754/-/ieee754-1.2.1.tgz",
- "integrity": "sha512-dcyqhDvX1C46lXZcVqCpK+FtMRQVdIMN6/Df5js2zouUsqG7I6sFxitIC+7KYK29KdXOLHdu9zL4sFnoVQnqaA==",
- "funding": [
- {
- "type": "github",
- "url": "https://github.com/sponsors/feross"
- },
- {
- "type": "patreon",
- "url": "https://www.patreon.com/feross"
- },
- {
- "type": "consulting",
- "url": "https://feross.org/support"
- }
- ]
- },
"node_modules/ignore": {
"version": "5.2.4",
"resolved": "https://registry.npmjs.org/ignore/-/ignore-5.2.4.tgz",
@@ -2408,11 +2895,6 @@
"resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz",
"integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ=="
},
- "node_modules/ini": {
- "version": "1.3.8",
- "resolved": "https://registry.npmjs.org/ini/-/ini-1.3.8.tgz",
- "integrity": "sha512-JV/yugV2uzW5iMRSiZAyDtQd+nxtUnjeLt0acNdw98kKLrvuRVyB80tsREOE7yvGVgalhZ6RNXCmEHkUKBKxew=="
- },
"node_modules/internal-slot": {
"version": "1.0.5",
"resolved": "https://registry.npmjs.org/internal-slot/-/internal-slot-1.0.5.tgz",
@@ -2539,6 +3021,14 @@
"node": ">=0.10.0"
}
},
+ "node_modules/is-fullwidth-code-point": {
+ "version": "3.0.0",
+ "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz",
+ "integrity": "sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==",
+ "engines": {
+ "node": ">=8"
+ }
+ },
"node_modules/is-glob": {
"version": "4.0.3",
"resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz",
@@ -2733,6 +3223,20 @@
"resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz",
"integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw=="
},
+ "node_modules/jackspeak": {
+ "version": "3.4.3",
+ "resolved": "https://registry.npmjs.org/jackspeak/-/jackspeak-3.4.3.tgz",
+ "integrity": "sha512-OGlZQpz2yfahA/Rd1Y8Cd9SIEsqvXkLVoSw/cgwhnhFMDbsQFeZYoJJ7bIZBS9BcamUW96asq/npPWugM+RQBw==",
+ "dependencies": {
+ "@isaacs/cliui": "^8.0.2"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/isaacs"
+ },
+ "optionalDependencies": {
+ "@pkgjs/parseargs": "^0.11.0"
+ }
+ },
"node_modules/jiti": {
"version": "1.19.1",
"resolved": "https://registry.npmjs.org/jiti/-/jiti-1.19.1.tgz",
@@ -2849,11 +3353,6 @@
"resolved": "https://registry.npmjs.org/lodash.merge/-/lodash.merge-4.6.2.tgz",
"integrity": "sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ=="
},
- "node_modules/long": {
- "version": "4.0.0",
- "resolved": "https://registry.npmjs.org/long/-/long-4.0.0.tgz",
- "integrity": "sha512-XsP+KhQif4bjX1kbuSiySJFNAehNxgLb6hPRGJ9QsUr8ajHkuXGdrHmFUTUUXhDwVX2R5bY4JNZEwbUiMhV+MA=="
- },
"node_modules/loose-envify": {
"version": "1.4.0",
"resolved": "https://registry.npmjs.org/loose-envify/-/loose-envify-1.4.0.tgz",
@@ -2866,15 +3365,9 @@
}
},
"node_modules/lru-cache": {
- "version": "6.0.0",
- "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz",
- "integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==",
- "dependencies": {
- "yallist": "^4.0.0"
- },
- "engines": {
- "node": ">=10"
- }
+ "version": "10.4.3",
+ "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-10.4.3.tgz",
+ "integrity": "sha512-JNAzZcXrCt42VGLuYz0zfAzDfAvJWW6AfYlDBQyDV5DClI2m5sAmK+OIO7s59XfsRsWHp02jAJrRadPRGTt6SQ=="
},
"node_modules/merge-stream": {
"version": "2.0.0",
@@ -2912,17 +3405,6 @@
"url": "https://github.com/sponsors/sindresorhus"
}
},
- "node_modules/mimic-response": {
- "version": "3.1.0",
- "resolved": "https://registry.npmjs.org/mimic-response/-/mimic-response-3.1.0.tgz",
- "integrity": "sha512-z0yWI+4FDrrweS8Zmt4Ej5HdJmky15+L2e6Wgn3+iK5fWzb6T3fhNFq2+MeTRb064c6Wr4N/wv0DzQTjNzHNGQ==",
- "engines": {
- "node": ">=10"
- },
- "funding": {
- "url": "https://github.com/sponsors/sindresorhus"
- }
- },
"node_modules/minimatch": {
"version": "3.1.2",
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz",
@@ -2939,13 +3421,97 @@
"resolved": "https://registry.npmjs.org/minimist/-/minimist-1.2.8.tgz",
"integrity": "sha512-2yyAR8qBkN3YuheJanUpWC5U3bb5osDywNB8RzDVlDwDHbocAJveqqj1u8+SVD7jkWT4yvsHCpWqqWqAxb0zCA==",
"funding": {
- "url": "https://github.com/sponsors/ljharb"
+ "url": "https://github.com/sponsors/ljharb"
+ }
+ },
+ "node_modules/minipass": {
+ "version": "7.1.2",
+ "resolved": "https://registry.npmjs.org/minipass/-/minipass-7.1.2.tgz",
+ "integrity": "sha512-qOOzS1cBTWYF4BH8fVePDBOO9iptMnGUEZwNc/cMWnTV2nVLZ7VoNWEPHkYczZA0pdoA7dl6e7FL659nX9S2aw==",
+ "engines": {
+ "node": ">=16 || 14 >=14.17"
+ }
+ },
+ "node_modules/minizlib": {
+ "version": "3.0.1",
+ "resolved": "https://registry.npmjs.org/minizlib/-/minizlib-3.0.1.tgz",
+ "integrity": "sha512-umcy022ILvb5/3Djuu8LWeqUa8D68JaBzlttKeMWen48SjabqS3iY5w/vzeMzMUNhLDifyhbOwKDSznB1vvrwg==",
+ "dependencies": {
+ "minipass": "^7.0.4",
+ "rimraf": "^5.0.5"
+ },
+ "engines": {
+ "node": ">= 18"
+ }
+ },
+ "node_modules/minizlib/node_modules/brace-expansion": {
+ "version": "2.0.1",
+ "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz",
+ "integrity": "sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==",
+ "dependencies": {
+ "balanced-match": "^1.0.0"
+ }
+ },
+ "node_modules/minizlib/node_modules/glob": {
+ "version": "10.4.5",
+ "resolved": "https://registry.npmjs.org/glob/-/glob-10.4.5.tgz",
+ "integrity": "sha512-7Bv8RF0k6xjo7d4A/PxYLbUCfb6c+Vpd2/mB2yRDlew7Jb5hEXiCD9ibfO7wpk8i4sevK6DFny9h7EYbM3/sHg==",
+ "dependencies": {
+ "foreground-child": "^3.1.0",
+ "jackspeak": "^3.1.2",
+ "minimatch": "^9.0.4",
+ "minipass": "^7.1.2",
+ "package-json-from-dist": "^1.0.0",
+ "path-scurry": "^1.11.1"
+ },
+ "bin": {
+ "glob": "dist/esm/bin.mjs"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/isaacs"
+ }
+ },
+ "node_modules/minizlib/node_modules/minimatch": {
+ "version": "9.0.5",
+ "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz",
+ "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==",
+ "dependencies": {
+ "brace-expansion": "^2.0.1"
+ },
+ "engines": {
+ "node": ">=16 || 14 >=14.17"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/isaacs"
+ }
+ },
+ "node_modules/minizlib/node_modules/rimraf": {
+ "version": "5.0.10",
+ "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-5.0.10.tgz",
+ "integrity": "sha512-l0OE8wL34P4nJH/H2ffoaniAokM2qSmrtXHmlpvYr5AVVX8msAyW0l8NVJFDxlSK4u3Uh/f41cQheDVdnYijwQ==",
+ "dependencies": {
+ "glob": "^10.3.7"
+ },
+ "bin": {
+ "rimraf": "dist/esm/bin.mjs"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/isaacs"
}
},
- "node_modules/mkdirp-classic": {
- "version": "0.5.3",
- "resolved": "https://registry.npmjs.org/mkdirp-classic/-/mkdirp-classic-0.5.3.tgz",
- "integrity": "sha512-gKLcREMhtuZRwRAfqP3RFW+TK4JqApVBtOIftVgjuABpAtpxhPGaDcfvbhNvD0B8iD1oUr/txX35NjcaY6Ns/A=="
+ "node_modules/mkdirp": {
+ "version": "3.0.1",
+ "resolved": "https://registry.npmjs.org/mkdirp/-/mkdirp-3.0.1.tgz",
+ "integrity": "sha512-+NsyUUAZDmo6YVHzL/stxSu3t9YS1iljliy3BSDrXJ/dkn1KYdmtZODGGjLcc9XLgVVpH4KshHB8XmZgMhaBXg==",
+ "bin": {
+ "mkdirp": "dist/cjs/src/bin.js"
+ },
+ "engines": {
+ "node": ">=10"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/isaacs"
+ }
},
"node_modules/ms": {
"version": "2.1.2",
@@ -2979,11 +3545,6 @@
"node": "^10 || ^12 || ^13.7 || ^14 || >=15.0.1"
}
},
- "node_modules/napi-build-utils": {
- "version": "1.0.2",
- "resolved": "https://registry.npmjs.org/napi-build-utils/-/napi-build-utils-1.0.2.tgz",
- "integrity": "sha512-ONmRUqK7zj7DWX0D9ADe03wbwOBZxNAfF20PlGfCWQcD3+/MakShIHrMqx9YwPTfxDdF1zLeL+RGZiR9kGMLdg=="
- },
"node_modules/natural-compare": {
"version": "1.4.0",
"resolved": "https://registry.npmjs.org/natural-compare/-/natural-compare-1.4.0.tgz",
@@ -3038,22 +3599,6 @@
}
}
},
- "node_modules/node-abi": {
- "version": "3.45.0",
- "resolved": "https://registry.npmjs.org/node-abi/-/node-abi-3.45.0.tgz",
- "integrity": "sha512-iwXuFrMAcFVi/ZoZiqq8BzAdsLw9kxDfTC0HMyjXfSL/6CSDAGD5UmR7azrAgWV1zKYq7dUUMj4owusBWKLsiQ==",
- "dependencies": {
- "semver": "^7.3.5"
- },
- "engines": {
- "node": ">=10"
- }
- },
- "node_modules/node-addon-api": {
- "version": "6.1.0",
- "resolved": "https://registry.npmjs.org/node-addon-api/-/node-addon-api-6.1.0.tgz",
- "integrity": "sha512-+eawOlIgy680F0kBzPUNFhMZGtJ1YmqM6l4+Crf4IkImjYrO/mqPwRMh352g23uIaQKFItcQ64I7KMaJxHgAVA=="
- },
"node_modules/node-releases": {
"version": "2.0.13",
"resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.13.tgz",
@@ -3228,46 +3773,6 @@
"url": "https://github.com/sponsors/sindresorhus"
}
},
- "node_modules/onnx-proto": {
- "version": "4.0.4",
- "resolved": "https://registry.npmjs.org/onnx-proto/-/onnx-proto-4.0.4.tgz",
- "integrity": "sha512-aldMOB3HRoo6q/phyB6QRQxSt895HNNw82BNyZ2CMh4bjeKv7g/c+VpAFtJuEMVfYLMbRx61hbuqnKceLeDcDA==",
- "dependencies": {
- "protobufjs": "^6.8.8"
- }
- },
- "node_modules/onnxruntime-common": {
- "version": "1.14.0",
- "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.14.0.tgz",
- "integrity": "sha512-3LJpegM2iMNRX2wUmtYfeX/ytfOzNwAWKSq1HbRrKc9+uqG/FsEA0bbKZl1btQeZaXhC26l44NWpNUeXPII7Ew=="
- },
- "node_modules/onnxruntime-node": {
- "version": "1.14.0",
- "resolved": "https://registry.npmjs.org/onnxruntime-node/-/onnxruntime-node-1.14.0.tgz",
- "integrity": "sha512-5ba7TWomIV/9b6NH/1x/8QEeowsb+jBEvFzU6z0T4mNsFwdPqXeFUM7uxC6QeSRkEbWu3qEB0VMjrvzN/0S9+w==",
- "optional": true,
- "os": [
- "win32",
- "darwin",
- "linux"
- ],
- "dependencies": {
- "onnxruntime-common": "~1.14.0"
- }
- },
- "node_modules/onnxruntime-web": {
- "version": "1.14.0",
- "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.14.0.tgz",
- "integrity": "sha512-Kcqf43UMfW8mCydVGcX9OMXI2VN17c0p6XvR7IPSZzBf/6lteBzXHvcEVWDPmCKuGombl997HgLqj91F11DzXw==",
- "dependencies": {
- "flatbuffers": "^1.12.0",
- "guid-typescript": "^1.0.9",
- "long": "^4.0.0",
- "onnx-proto": "^4.0.4",
- "onnxruntime-common": "~1.14.0",
- "platform": "^1.3.6"
- }
- },
"node_modules/open": {
"version": "9.1.0",
"resolved": "https://registry.npmjs.org/open/-/open-9.1.0.tgz",
@@ -3329,6 +3834,11 @@
"url": "https://github.com/sponsors/sindresorhus"
}
},
+ "node_modules/package-json-from-dist": {
+ "version": "1.0.0",
+ "resolved": "https://registry.npmjs.org/package-json-from-dist/-/package-json-from-dist-1.0.0.tgz",
+ "integrity": "sha512-dATvCeZN/8wQsGywez1mzHtTlP22H8OEfPrVMLNr4/eGa+ijtLn/6M5f0dY8UKNrC2O9UCU6SSoG3qRKnt7STw=="
+ },
"node_modules/parent-module": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/parent-module/-/parent-module-1.0.1.tgz",
@@ -3369,6 +3879,21 @@
"resolved": "https://registry.npmjs.org/path-parse/-/path-parse-1.0.7.tgz",
"integrity": "sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw=="
},
+ "node_modules/path-scurry": {
+ "version": "1.11.1",
+ "resolved": "https://registry.npmjs.org/path-scurry/-/path-scurry-1.11.1.tgz",
+ "integrity": "sha512-Xa4Nw17FS9ApQFJ9umLiJS4orGjm7ZzwUrwamcGQuHSzDyth9boKDaycYdDcZDuqYATXw4HFXgaqWTctW/v1HA==",
+ "dependencies": {
+ "lru-cache": "^10.2.0",
+ "minipass": "^5.0.0 || ^6.0.2 || ^7.0.0"
+ },
+ "engines": {
+ "node": ">=16 || 14 >=14.18"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/isaacs"
+ }
+ },
"node_modules/path-type": {
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/path-type/-/path-type-4.0.0.tgz",
@@ -3538,57 +4063,6 @@
"resolved": "https://registry.npmjs.org/postcss-value-parser/-/postcss-value-parser-4.2.0.tgz",
"integrity": "sha512-1NNCs6uurfkVbeXG4S8JFT9t19m45ICnif8zWLd5oPSZ50QnwMfK+H3jv408d4jw/7Bttv5axS5IiHoLaVNHeQ=="
},
- "node_modules/prebuild-install": {
- "version": "7.1.1",
- "resolved": "https://registry.npmjs.org/prebuild-install/-/prebuild-install-7.1.1.tgz",
- "integrity": "sha512-jAXscXWMcCK8GgCoHOfIr0ODh5ai8mj63L2nWrjuAgXE6tDyYGnx4/8o/rCgU+B4JSyZBKbeZqzhtwtC3ovxjw==",
- "dependencies": {
- "detect-libc": "^2.0.0",
- "expand-template": "^2.0.3",
- "github-from-package": "0.0.0",
- "minimist": "^1.2.3",
- "mkdirp-classic": "^0.5.3",
- "napi-build-utils": "^1.0.1",
- "node-abi": "^3.3.0",
- "pump": "^3.0.0",
- "rc": "^1.2.7",
- "simple-get": "^4.0.0",
- "tar-fs": "^2.0.0",
- "tunnel-agent": "^0.6.0"
- },
- "bin": {
- "prebuild-install": "bin.js"
- },
- "engines": {
- "node": ">=10"
- }
- },
- "node_modules/prebuild-install/node_modules/tar-fs": {
- "version": "2.1.1",
- "resolved": "https://registry.npmjs.org/tar-fs/-/tar-fs-2.1.1.tgz",
- "integrity": "sha512-V0r2Y9scmbDRLCNex/+hYzvp/zyYjvFbHPNgVTKfQvVrb6guiE/fxP+XblDNR011utopbkex2nM4dHNV6GDsng==",
- "dependencies": {
- "chownr": "^1.1.1",
- "mkdirp-classic": "^0.5.2",
- "pump": "^3.0.0",
- "tar-stream": "^2.1.4"
- }
- },
- "node_modules/prebuild-install/node_modules/tar-stream": {
- "version": "2.2.0",
- "resolved": "https://registry.npmjs.org/tar-stream/-/tar-stream-2.2.0.tgz",
- "integrity": "sha512-ujeqbceABgwMZxEJnk2HDY2DlnUZ+9oEcb1KzTVfYHio0UE6dG71n60d8D2I4qNvleWrrXpmjpt7vZeF1LnMZQ==",
- "dependencies": {
- "bl": "^4.0.3",
- "end-of-stream": "^1.4.1",
- "fs-constants": "^1.0.0",
- "inherits": "^2.0.3",
- "readable-stream": "^3.1.1"
- },
- "engines": {
- "node": ">=6"
- }
- },
"node_modules/prelude-ls": {
"version": "1.2.1",
"resolved": "https://registry.npmjs.org/prelude-ls/-/prelude-ls-1.2.1.tgz",
@@ -3635,15 +4109,6 @@
"resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz",
"integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q=="
},
- "node_modules/pump": {
- "version": "3.0.0",
- "resolved": "https://registry.npmjs.org/pump/-/pump-3.0.0.tgz",
- "integrity": "sha512-LwZy+p3SFs1Pytd/jYct4wpv49HiYCqd9Rlc5ZVdk0V+8Yzv6jR5Blk3TRmPL1ft69TxP0IMZGJ+WPFU2BFhww==",
- "dependencies": {
- "end-of-stream": "^1.1.0",
- "once": "^1.3.1"
- }
- },
"node_modules/punycode": {
"version": "2.3.0",
"resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.0.tgz",
@@ -3671,33 +4136,6 @@
}
]
},
- "node_modules/queue-tick": {
- "version": "1.0.1",
- "resolved": "https://registry.npmjs.org/queue-tick/-/queue-tick-1.0.1.tgz",
- "integrity": "sha512-kJt5qhMxoszgU/62PLP1CJytzd2NKetjSRnyuj31fDd3Rlcz3fzlFdFLD1SItunPwyqEOkca6GbV612BWfaBag=="
- },
- "node_modules/rc": {
- "version": "1.2.8",
- "resolved": "https://registry.npmjs.org/rc/-/rc-1.2.8.tgz",
- "integrity": "sha512-y3bGgqKj3QBdxLbLkomlohkvsA8gdAiUQlSBJnBhfn+BPxg4bc62d8TcBW15wavDfgexCgccckhcZvywyQYPOw==",
- "dependencies": {
- "deep-extend": "^0.6.0",
- "ini": "~1.3.0",
- "minimist": "^1.2.0",
- "strip-json-comments": "~2.0.1"
- },
- "bin": {
- "rc": "cli.js"
- }
- },
- "node_modules/rc/node_modules/strip-json-comments": {
- "version": "2.0.1",
- "resolved": "https://registry.npmjs.org/strip-json-comments/-/strip-json-comments-2.0.1.tgz",
- "integrity": "sha512-4gB8na07fecVVkOI6Rs4e7T6NOTki5EmL7TUduTs6bu3EdnSycntVJ4re8kgZA+wx9IueI2Y11bfbgwtzuE0KQ==",
- "engines": {
- "node": ">=0.10.0"
- }
- },
"node_modules/react": {
"version": "18.2.0",
"resolved": "https://registry.npmjs.org/react/-/react-18.2.0.tgz",
@@ -3734,19 +4172,6 @@
"pify": "^2.3.0"
}
},
- "node_modules/readable-stream": {
- "version": "3.6.2",
- "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-3.6.2.tgz",
- "integrity": "sha512-9u/sniCrY3D5WdsERHzHE4G2YCXqoG5FTHUiCC4SIbr6XcLZBY05ya9EKjYek9O5xOAwjGq+1JdGBAS7Q9ScoA==",
- "dependencies": {
- "inherits": "^2.0.3",
- "string_decoder": "^1.1.1",
- "util-deprecate": "^1.0.1"
- },
- "engines": {
- "node": ">= 6"
- }
- },
"node_modules/readdirp": {
"version": "3.6.0",
"resolved": "https://registry.npmjs.org/readdirp/-/readdirp-3.6.0.tgz",
@@ -3969,25 +4394,6 @@
"url": "https://github.com/sponsors/ljharb"
}
},
- "node_modules/safe-buffer": {
- "version": "5.2.1",
- "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz",
- "integrity": "sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==",
- "funding": [
- {
- "type": "github",
- "url": "https://github.com/sponsors/feross"
- },
- {
- "type": "patreon",
- "url": "https://www.patreon.com/feross"
- },
- {
- "type": "consulting",
- "url": "https://feross.org/support"
- }
- ]
- },
"node_modules/safe-regex-test": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/safe-regex-test/-/safe-regex-test-1.0.0.tgz",
@@ -4010,12 +4416,9 @@
}
},
"node_modules/semver": {
- "version": "7.5.4",
- "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz",
- "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==",
- "dependencies": {
- "lru-cache": "^6.0.0"
- },
+ "version": "7.6.3",
+ "resolved": "https://registry.npmjs.org/semver/-/semver-7.6.3.tgz",
+ "integrity": "sha512-oVekP1cKtI+CTDvHWYFUcMtsK/00wmAEfyqKfNdARm8u1wNVhSgaX7A8d4UuIlUI5e84iEwOhs7ZPYRmzU9U6A==",
"bin": {
"semver": "bin/semver.js"
},
@@ -4023,28 +4426,6 @@
"node": ">=10"
}
},
- "node_modules/sharp": {
- "version": "0.32.6",
- "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.32.6.tgz",
- "integrity": "sha512-KyLTWwgcR9Oe4d9HwCwNM2l7+J0dUQwn/yf7S0EnTtb0eVS4RxO0eUSvxPtzT4F3SY+C4K6fqdv/DO27sJ/v/w==",
- "hasInstallScript": true,
- "dependencies": {
- "color": "^4.2.3",
- "detect-libc": "^2.0.2",
- "node-addon-api": "^6.1.0",
- "prebuild-install": "^7.1.1",
- "semver": "^7.5.4",
- "simple-get": "^4.0.1",
- "tar-fs": "^3.0.4",
- "tunnel-agent": "^0.6.0"
- },
- "engines": {
- "node": ">=14.15.0"
- },
- "funding": {
- "url": "https://opencollective.com/libvips"
- }
- },
"node_modules/shebang-command": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz",
@@ -4082,49 +4463,6 @@
"resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-3.0.7.tgz",
"integrity": "sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ=="
},
- "node_modules/simple-concat": {
- "version": "1.0.1",
- "resolved": "https://registry.npmjs.org/simple-concat/-/simple-concat-1.0.1.tgz",
- "integrity": "sha512-cSFtAPtRhljv69IK0hTVZQ+OfE9nePi/rtJmw5UjHeVyVroEqJXP1sFztKUy1qU+xvz3u/sfYJLa947b7nAN2Q==",
- "funding": [
- {
- "type": "github",
- "url": "https://github.com/sponsors/feross"
- },
- {
- "type": "patreon",
- "url": "https://www.patreon.com/feross"
- },
- {
- "type": "consulting",
- "url": "https://feross.org/support"
- }
- ]
- },
- "node_modules/simple-get": {
- "version": "4.0.1",
- "resolved": "https://registry.npmjs.org/simple-get/-/simple-get-4.0.1.tgz",
- "integrity": "sha512-brv7p5WgH0jmQJr1ZDDfKDOSeWWg+OVypG99A/5vYGPqJ6pxiaHLy8nxtFjBA7oMa01ebA9gfh1uMCFqOuXxvA==",
- "funding": [
- {
- "type": "github",
- "url": "https://github.com/sponsors/feross"
- },
- {
- "type": "patreon",
- "url": "https://www.patreon.com/feross"
- },
- {
- "type": "consulting",
- "url": "https://feross.org/support"
- }
- ],
- "dependencies": {
- "decompress-response": "^6.0.0",
- "once": "^1.3.1",
- "simple-concat": "^1.0.0"
- }
- },
"node_modules/simple-swizzle": {
"version": "0.2.2",
"resolved": "https://registry.npmjs.org/simple-swizzle/-/simple-swizzle-0.2.2.tgz",
@@ -4157,21 +4495,64 @@
"node": ">=10.0.0"
}
},
- "node_modules/streamx": {
- "version": "2.15.0",
- "resolved": "https://registry.npmjs.org/streamx/-/streamx-2.15.0.tgz",
- "integrity": "sha512-HcxY6ncGjjklGs1xsP1aR71INYcsXFJet5CU1CHqihQ2J5nOsbd4OjgjHO42w/4QNv9gZb3BueV+Vxok5pLEXg==",
+ "node_modules/string-width": {
+ "version": "5.1.2",
+ "resolved": "https://registry.npmjs.org/string-width/-/string-width-5.1.2.tgz",
+ "integrity": "sha512-HnLOCR3vjcY8beoNLtcjZ5/nxn2afmME6lhrDrebokqMap+XbeW8n9TXpPDOqdGK5qcI3oT0GKTW6wC7EMiVqA==",
"dependencies": {
- "fast-fifo": "^1.1.0",
- "queue-tick": "^1.0.1"
+ "eastasianwidth": "^0.2.0",
+ "emoji-regex": "^9.2.2",
+ "strip-ansi": "^7.0.1"
+ },
+ "engines": {
+ "node": ">=12"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/sindresorhus"
}
},
- "node_modules/string_decoder": {
- "version": "1.3.0",
- "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.3.0.tgz",
- "integrity": "sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA==",
+ "node_modules/string-width-cjs": {
+ "name": "string-width",
+ "version": "4.2.3",
+ "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz",
+ "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==",
+ "dependencies": {
+ "emoji-regex": "^8.0.0",
+ "is-fullwidth-code-point": "^3.0.0",
+ "strip-ansi": "^6.0.1"
+ },
+ "engines": {
+ "node": ">=8"
+ }
+ },
+ "node_modules/string-width-cjs/node_modules/emoji-regex": {
+ "version": "8.0.0",
+ "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz",
+ "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A=="
+ },
+ "node_modules/string-width/node_modules/ansi-regex": {
+ "version": "6.0.1",
+ "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-6.0.1.tgz",
+ "integrity": "sha512-n5M855fKb2SsfMIiFFoVrABHJC8QtHwVx+mHWP3QcEqBHYienj5dHSgjbxtC0WEZXYt4wcD6zrQElDPhFuZgfA==",
+ "engines": {
+ "node": ">=12"
+ },
+ "funding": {
+ "url": "https://github.com/chalk/ansi-regex?sponsor=1"
+ }
+ },
+ "node_modules/string-width/node_modules/strip-ansi": {
+ "version": "7.1.0",
+ "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-7.1.0.tgz",
+ "integrity": "sha512-iq6eVVI64nQQTRYq2KtEg2d2uU7LElhTJwsH4YzIHZshxlgZms/wIc4VoDQTlG/IvVIrBKG06CrZnp0qv7hkcQ==",
"dependencies": {
- "safe-buffer": "~5.2.0"
+ "ansi-regex": "^6.0.1"
+ },
+ "engines": {
+ "node": ">=12"
+ },
+ "funding": {
+ "url": "https://github.com/chalk/strip-ansi?sponsor=1"
}
},
"node_modules/string.prototype.matchall": {
@@ -4245,6 +4626,18 @@
"node": ">=8"
}
},
+ "node_modules/strip-ansi-cjs": {
+ "name": "strip-ansi",
+ "version": "6.0.1",
+ "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz",
+ "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==",
+ "dependencies": {
+ "ansi-regex": "^5.0.1"
+ },
+ "engines": {
+ "node": ">=8"
+ }
+ },
"node_modules/strip-bom": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/strip-bom/-/strip-bom-3.0.0.tgz",
@@ -4418,24 +4811,36 @@
"node": ">=6"
}
},
- "node_modules/tar-fs": {
- "version": "3.0.4",
- "resolved": "https://registry.npmjs.org/tar-fs/-/tar-fs-3.0.4.tgz",
- "integrity": "sha512-5AFQU8b9qLfZCX9zp2duONhPmZv0hGYiBPJsyUdqMjzq/mqVpy/rEUSeHk1+YitmxugaptgBh5oDGU3VsAJq4w==",
+ "node_modules/tar": {
+ "version": "7.4.3",
+ "resolved": "https://registry.npmjs.org/tar/-/tar-7.4.3.tgz",
+ "integrity": "sha512-5S7Va8hKfV7W5U6g3aYxXmlPoZVAwUMy9AOKyF2fVuZa2UD3qZjg578OrLRt8PcNN1PleVaL/5/yYATNL0ICUw==",
"dependencies": {
- "mkdirp-classic": "^0.5.2",
- "pump": "^3.0.0",
- "tar-stream": "^3.1.5"
+ "@isaacs/fs-minipass": "^4.0.0",
+ "chownr": "^3.0.0",
+ "minipass": "^7.1.2",
+ "minizlib": "^3.0.1",
+ "mkdirp": "^3.0.1",
+ "yallist": "^5.0.0"
+ },
+ "engines": {
+ "node": ">=18"
}
},
- "node_modules/tar-stream": {
- "version": "3.1.6",
- "resolved": "https://registry.npmjs.org/tar-stream/-/tar-stream-3.1.6.tgz",
- "integrity": "sha512-B/UyjYwPpMBv+PaFSWAmtYjwdrlEaZQEhMIBFNC5oEG8lpiW8XjcSdmEaClj28ArfKScKHs2nshz3k2le6crsg==",
- "dependencies": {
- "b4a": "^1.6.4",
- "fast-fifo": "^1.2.0",
- "streamx": "^2.15.0"
+ "node_modules/tar/node_modules/chownr": {
+ "version": "3.0.0",
+ "resolved": "https://registry.npmjs.org/chownr/-/chownr-3.0.0.tgz",
+ "integrity": "sha512-+IxzY9BZOQd/XuYPRmrvEVjF/nqj5kgT4kEq7VofrDoM1MxoRjEWkrCC3EtLi59TVawxTAn+orJwFQcrqEN1+g==",
+ "engines": {
+ "node": ">=18"
+ }
+ },
+ "node_modules/tar/node_modules/yallist": {
+ "version": "5.0.0",
+ "resolved": "https://registry.npmjs.org/yallist/-/yallist-5.0.0.tgz",
+ "integrity": "sha512-YgvUTfwqyc7UXVMrB+SImsVYSmTS8X/tSrtdNZMImM+n7+QTriRXyXim0mBrTXNeqzVF0KWGgHPeiyViFFrNDw==",
+ "engines": {
+ "node": ">=18"
}
},
"node_modules/text-table": {
@@ -4524,17 +4929,6 @@
"resolved": "https://registry.npmjs.org/tslib/-/tslib-1.14.1.tgz",
"integrity": "sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg=="
},
- "node_modules/tunnel-agent": {
- "version": "0.6.0",
- "resolved": "https://registry.npmjs.org/tunnel-agent/-/tunnel-agent-0.6.0.tgz",
- "integrity": "sha512-McnNiV1l8RYeY8tBgEpuodCC1mLUdbSN+CYBL7kJsJNInOP8UjDDEwdk6Mw60vdLLrr5NHKZhMAOSrR2NZuQ+w==",
- "dependencies": {
- "safe-buffer": "^5.0.1"
- },
- "engines": {
- "node": "*"
- }
- },
"node_modules/type-check": {
"version": "0.4.0",
"resolved": "https://registry.npmjs.org/type-check/-/type-check-0.4.0.tgz",
@@ -4742,16 +5136,98 @@
"url": "https://github.com/sponsors/ljharb"
}
},
+ "node_modules/wrap-ansi": {
+ "version": "8.1.0",
+ "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-8.1.0.tgz",
+ "integrity": "sha512-si7QWI6zUMq56bESFvagtmzMdGOtoxfR+Sez11Mobfc7tm+VkUckk9bW2UeffTGVUbOksxmSw0AA2gs8g71NCQ==",
+ "dependencies": {
+ "ansi-styles": "^6.1.0",
+ "string-width": "^5.0.1",
+ "strip-ansi": "^7.0.1"
+ },
+ "engines": {
+ "node": ">=12"
+ },
+ "funding": {
+ "url": "https://github.com/chalk/wrap-ansi?sponsor=1"
+ }
+ },
+ "node_modules/wrap-ansi-cjs": {
+ "name": "wrap-ansi",
+ "version": "7.0.0",
+ "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-7.0.0.tgz",
+ "integrity": "sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==",
+ "dependencies": {
+ "ansi-styles": "^4.0.0",
+ "string-width": "^4.1.0",
+ "strip-ansi": "^6.0.0"
+ },
+ "engines": {
+ "node": ">=10"
+ },
+ "funding": {
+ "url": "https://github.com/chalk/wrap-ansi?sponsor=1"
+ }
+ },
+ "node_modules/wrap-ansi-cjs/node_modules/emoji-regex": {
+ "version": "8.0.0",
+ "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz",
+ "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A=="
+ },
+ "node_modules/wrap-ansi-cjs/node_modules/string-width": {
+ "version": "4.2.3",
+ "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz",
+ "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==",
+ "dependencies": {
+ "emoji-regex": "^8.0.0",
+ "is-fullwidth-code-point": "^3.0.0",
+ "strip-ansi": "^6.0.1"
+ },
+ "engines": {
+ "node": ">=8"
+ }
+ },
+ "node_modules/wrap-ansi/node_modules/ansi-regex": {
+ "version": "6.0.1",
+ "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-6.0.1.tgz",
+ "integrity": "sha512-n5M855fKb2SsfMIiFFoVrABHJC8QtHwVx+mHWP3QcEqBHYienj5dHSgjbxtC0WEZXYt4wcD6zrQElDPhFuZgfA==",
+ "engines": {
+ "node": ">=12"
+ },
+ "funding": {
+ "url": "https://github.com/chalk/ansi-regex?sponsor=1"
+ }
+ },
+ "node_modules/wrap-ansi/node_modules/ansi-styles": {
+ "version": "6.2.1",
+ "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-6.2.1.tgz",
+ "integrity": "sha512-bN798gFfQX+viw3R7yrGWRqnrN2oRkEkUjjl4JNn4E8GxxbjtG3FbrEIIY3l8/hrwUwIeCZvi4QuOTP4MErVug==",
+ "engines": {
+ "node": ">=12"
+ },
+ "funding": {
+ "url": "https://github.com/chalk/ansi-styles?sponsor=1"
+ }
+ },
+ "node_modules/wrap-ansi/node_modules/strip-ansi": {
+ "version": "7.1.0",
+ "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-7.1.0.tgz",
+ "integrity": "sha512-iq6eVVI64nQQTRYq2KtEg2d2uU7LElhTJwsH4YzIHZshxlgZms/wIc4VoDQTlG/IvVIrBKG06CrZnp0qv7hkcQ==",
+ "dependencies": {
+ "ansi-regex": "^6.0.1"
+ },
+ "engines": {
+ "node": ">=12"
+ },
+ "funding": {
+ "url": "https://github.com/chalk/strip-ansi?sponsor=1"
+ }
+ },
"node_modules/wrappy": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz",
"integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ=="
},
- "node_modules/yallist": {
- "version": "4.0.0",
- "resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz",
- "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A=="
- },
"node_modules/yaml": {
"version": "2.3.1",
"resolved": "https://registry.npmjs.org/yaml/-/yaml-2.3.1.tgz",
diff --git a/examples/next-client/package.json b/examples/next-client/package.json
index 814c663c9..7bccaea67 100644
--- a/examples/next-client/package.json
+++ b/examples/next-client/package.json
@@ -9,7 +9,7 @@
"lint": "next lint"
},
"dependencies": {
- "@xenova/transformers": "^2.4.2",
+ "@huggingface/transformers": "^3.0.0-alpha.5",
"autoprefixer": "10.4.14",
"eslint": "8.45.0",
"eslint-config-next": "13.4.12",
diff --git a/examples/next-client/src/app/worker.js b/examples/next-client/src/app/worker.js
index c7704df8a..4b9960009 100644
--- a/examples/next-client/src/app/worker.js
+++ b/examples/next-client/src/app/worker.js
@@ -1,7 +1,4 @@
-import { pipeline, env } from "@xenova/transformers";
-
-// Skip local model check
-env.allowLocalModels = false;
+import { pipeline } from "@huggingface/transformers";
// Use the Singleton pattern to enable lazy construction of the pipeline.
class PipelineSingleton {
@@ -10,9 +7,7 @@ class PipelineSingleton {
static instance = null;
static async getInstance(progress_callback = null) {
- if (this.instance === null) {
- this.instance = pipeline(this.task, this.model, { progress_callback });
- }
+ this.instance ??= pipeline(this.task, this.model, { progress_callback });
return this.instance;
}
}
@@ -21,14 +16,14 @@ class PipelineSingleton {
self.addEventListener('message', async (event) => {
// Retrieve the classification pipeline. When called for the first time,
// this will load the pipeline and save it for future use.
- let classifier = await PipelineSingleton.getInstance(x => {
+ const classifier = await PipelineSingleton.getInstance(x => {
// We also add a progress callback to the pipeline so that we can
// track model loading.
self.postMessage(x);
});
// Actually perform the classification
- let output = await classifier(event.data.text);
+ const output = await classifier(event.data.text);
// Send the output back to the main thread
self.postMessage({
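The rewritten worker relies on the nullish-assignment operator (`??=`), which only assigns when the left-hand side is `null` or `undefined`, so the pipeline promise is created on the first call and reused afterwards. A minimal, self-contained sketch of that pattern (the task and model id below are illustrative placeholders, not values taken from this example):

    import { pipeline } from "@huggingface/transformers";

    // Lazy singleton: `??=` assigns only when `instance` is still null/undefined,
    // so the pipeline is constructed once and every later call reuses the same promise.
    class PipelineSingleton {
        static task = "text-classification"; // illustrative task
        static model = "Xenova/distilbert-base-uncased-finetuned-sst-2-english"; // illustrative model id
        static instance = null;

        static async getInstance(progress_callback = null) {
            // Equivalent to: if (this.instance === null) { this.instance = pipeline(...); }
            this.instance ??= pipeline(this.task, this.model, { progress_callback });
            return this.instance;
        }
    }

    // First call creates the pipeline; later calls await the same promise.
    const classifier = await PipelineSingleton.getInstance();
    const output = await classifier("I love Transformers.js!");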
diff --git a/examples/remove-background-client/index.html b/examples/remove-background-client/index.html
index d20f9eaba..a85cef65f 100644
--- a/examples/remove-background-client/index.html
+++ b/examples/remove-background-client/index.html
@@ -8,7 +8,7 @@
- Background Removal w/ 🤗 Transformers.js
+
Runs locally in your browser, powered by the RMBG V1.4 model from BRIA AI
diff --git a/examples/segment-anything-client/.gitignore b/examples/segment-anything-client/.gitignore
new file mode 100644
index 000000000..1521c8b76
--- /dev/null
+++ b/examples/segment-anything-client/.gitignore
@@ -0,0 +1 @@
+dist
diff --git a/examples/segment-anything-client/index.css b/examples/segment-anything-client/index.css
index a896b8846..fc556bcac 100644
--- a/examples/segment-anything-client/index.css
+++ b/examples/segment-anything-client/index.css
@@ -23,7 +23,7 @@ body,
align-items: center;
}
-h1 {
+h1, h3 {
text-align: center;
}
diff --git a/examples/segment-anything-client/index.html b/examples/segment-anything-client/index.html
index 5e8a2e9b9..9dba925fe 100644
--- a/examples/segment-anything-client/index.html
+++ b/examples/segment-anything-client/index.html
@@ -6,11 +6,13 @@
- Transformers.js - Segment Anything
+ Transformers.js - Segment Anything WebGPU
- Segment Anything w/ 🤗 Transformers.js
+ Segment Anything WebGPU
+
diff --git a/examples/segment-anything-client/index.js b/examples/segment-anything-client/index.js
index e01b59c49..979db0582 100644
--- a/examples/segment-anything-client/index.js
+++ b/examples/segment-anything-client/index.js
@@ -23,9 +23,10 @@ const BASE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/re
const EXAMPLE_URL = BASE_URL + 'corgi.jpg';
// Create a web worker so that the main (UI) thread is not blocked during inference.
-const worker = new Worker('worker.js', {
- type: 'module',
-});
+const worker = new Worker(
+ new URL('./worker.js', import.meta.url),
+ { type: 'module' }
+);
// Preload star and cross images to avoid lag on first click
const star = new Image();
diff --git a/examples/segment-anything-client/package.json b/examples/segment-anything-client/package.json
new file mode 100644
index 000000000..aa790ea74
--- /dev/null
+++ b/examples/segment-anything-client/package.json
@@ -0,0 +1,17 @@
+{
+ "name": "segment-anything-client",
+ "private": true,
+ "version": "0.0.0",
+ "type": "module",
+ "scripts": {
+ "dev": "vite",
+ "build": "vite build",
+ "preview": "vite preview"
+ },
+ "dependencies": {
+ "@huggingface/transformers": "^3.0.0-alpha.0"
+ },
+ "devDependencies": {
+ "vite": "^5.2.9"
+ }
+}
diff --git a/examples/segment-anything-client/vite.config.js b/examples/segment-anything-client/vite.config.js
new file mode 100644
index 000000000..b2d374d1b
--- /dev/null
+++ b/examples/segment-anything-client/vite.config.js
@@ -0,0 +1,18 @@
+import { defineConfig } from 'vite';
+export default defineConfig(env => {
+ const config = {
+ build: {
+ target: 'esnext'
+ }
+ };
+
+ // TODO: Add this back when .wasm files are served locally
+ // if (env.mode === 'development') {
+ // // The .wasm files are not correctly served using Vite in development mode.
+ // // This is a workaround to exclude the onnxruntime-web package from Vite's optimization.
+ // // See also: https://github.com/vitejs/vite/issues/8427
+ // config.optimizeDeps = { exclude: ["onnxruntime-web"] };
+ // }
+
+ return config;
+});
diff --git a/examples/segment-anything-client/worker.js b/examples/segment-anything-client/worker.js
index 5dd636973..bb783e0b5 100644
--- a/examples/segment-anything-client/worker.js
+++ b/examples/segment-anything-client/worker.js
@@ -1,33 +1,25 @@
-import { env, SamModel, AutoProcessor, RawImage, Tensor } from 'https://cdn.jsdelivr.net/npm/@xenova/transformers@2.14.0';
-
-// Since we will download the model from the Hugging Face Hub, we can skip the local model check
-env.allowLocalModels = false;
+import { SamModel, AutoProcessor, RawImage, Tensor } from '@huggingface/transformers';
// We adopt the singleton pattern to enable lazy-loading of the model and processor.
export class SegmentAnythingSingleton {
static model_id = 'Xenova/slimsam-77-uniform';
static model;
static processor;
- static quantized = true;
static getInstance() {
- if (!this.model) {
- this.model = SamModel.from_pretrained(this.model_id, {
- quantized: this.quantized,
- });
- }
- if (!this.processor) {
- this.processor = AutoProcessor.from_pretrained(this.model_id);
- }
+ this.model ??= SamModel.from_pretrained(this.model_id, {
+ dtype: 'fp16',
+ device: 'webgpu',
+ });
+ this.processor ??= AutoProcessor.from_pretrained(this.model_id);
return Promise.all([this.model, this.processor]);
}
}
-
// State variables
-let image_embeddings = null;
-let image_inputs = null;
+let imageEmbeddings = null;
+let imageInputs = null;
let ready = false;
self.onmessage = async (e) => {
@@ -42,8 +34,8 @@ self.onmessage = async (e) => {
const { type, data } = e.data;
if (type === 'reset') {
- image_inputs = null;
- image_embeddings = null;
+ imageInputs = null;
+ imageEmbeddings = null;
} else if (type === 'segment') {
// Indicate that we are starting to segment the image
@@ -54,8 +46,8 @@ self.onmessage = async (e) => {
// Read the image and recompute image embeddings
const image = await RawImage.read(e.data.data);
- image_inputs = await processor(image);
- image_embeddings = await model.get_image_embeddings(image_inputs)
+ imageInputs = await processor(image);
+ imageEmbeddings = await model.get_image_embeddings(imageInputs)
// Indicate that we have computed the image embeddings, and we are ready to accept decoding requests
self.postMessage({
@@ -65,7 +57,7 @@ self.onmessage = async (e) => {
} else if (type === 'decode') {
// Prepare inputs for decoding
- const reshaped = image_inputs.reshaped_input_sizes[0];
+ const reshaped = imageInputs.reshaped_input_sizes[0];
const points = data.map(x => [x.point[0] * reshaped[1], x.point[1] * reshaped[0]])
const labels = data.map(x => BigInt(x.label));
@@ -81,17 +73,17 @@ self.onmessage = async (e) => {
)
// Generate the mask
- const outputs = await model({
- ...image_embeddings,
+ const { pred_masks, iou_scores } = await model({
+ ...imageEmbeddings,
input_points,
input_labels,
})
// Post-process the mask
const masks = await processor.post_process_masks(
- outputs.pred_masks,
- image_inputs.original_sizes,
- image_inputs.reshaped_input_sizes,
+ pred_masks,
+ imageInputs.original_sizes,
+ imageInputs.reshaped_input_sizes,
);
// Send the result back to the main thread
@@ -99,7 +91,7 @@ self.onmessage = async (e) => {
type: 'decode_result',
data: {
mask: RawImage.fromTensor(masks[0][0]),
- scores: outputs.iou_scores.data,
+ scores: iou_scores.data,
},
});
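The worker above also shows the v3-style loading options: the old `quantized` boolean is replaced by an explicit `dtype`, and the execution backend is selected with `device`. A minimal sketch of just the loading and embedding step, reusing the model id and options from the diff (the image URL is a placeholder):

    import { SamModel, AutoProcessor, RawImage } from '@huggingface/transformers';

    const model_id = 'Xenova/slimsam-77-uniform';
    const model = await SamModel.from_pretrained(model_id, {
        dtype: 'fp16',    // half-precision weights (replaces the v2 `quantized` flag)
        device: 'webgpu', // run inference on the GPU via WebGPU
    });
    const processor = await AutoProcessor.from_pretrained(model_id);

    // Compute image embeddings once; they can be reused for every point prompt.
    const imageUrl = 'https://example.com/image.jpg'; // placeholder URL
    const image = await RawImage.read(imageUrl);
    const inputs = await processor(image);
    const embeddings = await model.get_image_embeddings(inputs);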
diff --git a/examples/text-to-speech-client/src/worker.js b/examples/text-to-speech-client/src/worker.js
index 76b8f76ef..4644890d3 100644
--- a/examples/text-to-speech-client/src/worker.js
+++ b/examples/text-to-speech-client/src/worker.js
@@ -25,14 +25,14 @@ class MyTextToSpeechPipeline {
if (this.model_instance === null) {
this.model_instance = SpeechT5ForTextToSpeech.from_pretrained(this.model_id, {
- quantized: false,
+ dtype: 'fp32',
progress_callback,
});
}
if (this.vocoder_instance === null) {
this.vocoder_instance = SpeechT5HifiGan.from_pretrained(this.vocoder_id, {
- quantized: false,
+ dtype: 'fp32',
progress_callback,
});
}
diff --git a/examples/tokenizer-playground/src/App.jsx b/examples/tokenizer-playground/src/App.jsx
index 1e1a286c3..1307e41c8 100644
--- a/examples/tokenizer-playground/src/App.jsx
+++ b/examples/tokenizer-playground/src/App.jsx
@@ -98,7 +98,7 @@ function App() {
The Tokenizer Playground
-
Experiment with different tokenizers (running locally in your browser).
+
Experiment with different tokenizers (running locally in your browser).
diff --git a/examples/video-object-detection/index.html b/examples/video-object-detection/index.html
index 680b1e3bf..bd731f32c 100644
--- a/examples/video-object-detection/index.html
+++ b/examples/video-object-detection/index.html
@@ -10,7 +10,7 @@
Runs locally in your browser, powered by
diff --git a/examples/video-object-detection/main.js b/examples/video-object-detection/main.js
index 5eea3aa91..12c3552a4 100644
--- a/examples/video-object-detection/main.js
+++ b/examples/video-object-detection/main.js
@@ -18,6 +18,13 @@ const thresholdSlider = document.getElementById('threshold');
const thresholdLabel = document.getElementById('threshold-value');
const sizeSlider = document.getElementById('size');
const sizeLabel = document.getElementById('size-value');
+const scaleSlider = document.getElementById('scale');
+const scaleLabel = document.getElementById('scale-value');
+
+function setStreamSize(width, height) {
+ video.width = canvas.width = Math.round(width);
+ video.height = canvas.height = Math.round(height);
+}
status.textContent = 'Loading model...';
@@ -27,6 +34,14 @@ const model = await AutoModel.from_pretrained(model_id);
const processor = await AutoProcessor.from_pretrained(model_id);
// Set up controls
+let scale = 0.5;
+scaleSlider.addEventListener('input', () => {
+ scale = Number(scaleSlider.value);
+ setStreamSize(video.videoWidth * scale, video.videoHeight * scale);
+ scaleLabel.textContent = scale;
+});
+scaleSlider.disabled = false;
+
let threshold = 0.25;
thresholdSlider.addEventListener('input', () => {
threshold = Number(thresholdSlider.value);
@@ -130,8 +145,7 @@ navigator.mediaDevices.getUserMedia(
const videoTrack = stream.getVideoTracks()[0];
const { width, height } = videoTrack.getSettings();
- canvas.width = width;
- canvas.height = height;
+ setStreamSize(width * scale, height * scale);
// Set container width and height depending on the image aspect ratio
const ar = width / height;
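The new scale slider shrinks both the video element and the canvas through `setStreamSize`, so drawing and inference happen on a downscaled frame. A self-contained sketch of that arithmetic, written as a variant that takes the elements as parameters and assuming a 1280x720 camera stream (the element ids are illustrative):

    // Keep the displayed video and the drawing surface the same size.
    function setStreamSize(video, canvas, width, height) {
        video.width = canvas.width = Math.round(width);
        video.height = canvas.height = Math.round(height);
    }

    const video = document.getElementById('video');   // illustrative element ids
    const canvas = document.getElementById('canvas');

    const scale = 0.5; // value read from the scale slider
    setStreamSize(video, canvas, 1280 * scale, 720 * scale); // 1280x720 stream -> 640x360 canvas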
diff --git a/examples/webgpu-chat/.eslintrc.cjs b/examples/webgpu-chat/.eslintrc.cjs
new file mode 100644
index 000000000..ce8fffe57
--- /dev/null
+++ b/examples/webgpu-chat/.eslintrc.cjs
@@ -0,0 +1,21 @@
+module.exports = {
+ root: true,
+ env: { browser: true, es2020: true },
+ extends: [
+ 'eslint:recommended',
+ 'plugin:react/recommended',
+ 'plugin:react/jsx-runtime',
+ 'plugin:react-hooks/recommended',
+ ],
+ ignorePatterns: ['dist', '.eslintrc.cjs'],
+ parserOptions: { ecmaVersion: 'latest', sourceType: 'module' },
+ settings: { react: { version: '18.2' } },
+ plugins: ['react-refresh'],
+ rules: {
+ 'react-refresh/only-export-components': [
+ 'warn',
+ { allowConstantExport: true },
+ ],
+ 'react/prop-types': 'off'
+ },
+}
diff --git a/examples/webgpu-chat/.gitignore b/examples/webgpu-chat/.gitignore
new file mode 100644
index 000000000..a547bf36d
--- /dev/null
+++ b/examples/webgpu-chat/.gitignore
@@ -0,0 +1,24 @@
+# Logs
+logs
+*.log
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+pnpm-debug.log*
+lerna-debug.log*
+
+node_modules
+dist
+dist-ssr
+*.local
+
+# Editor directories and files
+.vscode/*
+!.vscode/extensions.json
+.idea
+.DS_Store
+*.suo
+*.ntvs*
+*.njsproj
+*.sln
+*.sw?
diff --git a/examples/webgpu-chat/README.md b/examples/webgpu-chat/README.md
new file mode 100644
index 000000000..f768e33fc
--- /dev/null
+++ b/examples/webgpu-chat/README.md
@@ -0,0 +1,8 @@
+# React + Vite
+
+This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.
+
+Currently, two official plugins are available:
+
+- [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react/README.md) uses [Babel](https://babeljs.io/) for Fast Refresh
+- [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh
diff --git a/examples/webgpu-chat/index.html b/examples/webgpu-chat/index.html
new file mode 100644
index 000000000..404c33b9a
--- /dev/null
+++ b/examples/webgpu-chat/index.html
@@ -0,0 +1,13 @@
+
+
+
+
+
+
+ Phi-3 WebGPU
+
+
+
+
+
+
diff --git a/examples/webgpu-chat/package.json b/examples/webgpu-chat/package.json
new file mode 100644
index 000000000..34e6e95e6
--- /dev/null
+++ b/examples/webgpu-chat/package.json
@@ -0,0 +1,32 @@
+{
+ "name": "webgpu-chat",
+ "private": true,
+ "version": "0.0.0",
+ "type": "module",
+ "scripts": {
+ "dev": "vite",
+ "build": "vite build",
+ "lint": "eslint . --ext js,jsx --report-unused-disable-directives --max-warnings 0",
+ "preview": "vite preview"
+ },
+ "dependencies": {
+ "@xenova/transformers": "github:xenova/transformers.js#v3",
+ "dompurify": "^3.1.2",
+ "marked": "^12.0.2",
+ "react": "^18.2.0",
+ "react-dom": "^18.2.0"
+ },
+ "devDependencies": {
+ "@types/react": "^18.2.43",
+ "@types/react-dom": "^18.2.17",
+ "@vitejs/plugin-react": "^4.2.1",
+ "autoprefixer": "^10.4.19",
+ "eslint": "^8.55.0",
+ "eslint-plugin-react": "^7.33.2",
+ "eslint-plugin-react-hooks": "^4.6.0",
+ "eslint-plugin-react-refresh": "^0.4.5",
+ "postcss": "^8.4.38",
+ "tailwindcss": "^3.4.3",
+ "vite": "^5.2.11"
+ }
+}
diff --git a/examples/webgpu-chat/postcss.config.js b/examples/webgpu-chat/postcss.config.js
new file mode 100644
index 000000000..2e7af2b7f
--- /dev/null
+++ b/examples/webgpu-chat/postcss.config.js
@@ -0,0 +1,6 @@
+export default {
+ plugins: {
+ tailwindcss: {},
+ autoprefixer: {},
+ },
+}
diff --git a/examples/webgpu-chat/public/logo.png b/examples/webgpu-chat/public/logo.png
new file mode 100644
index 000000000..73ecb940a
Binary files /dev/null and b/examples/webgpu-chat/public/logo.png differ
diff --git a/examples/webgpu-chat/src/App.jsx b/examples/webgpu-chat/src/App.jsx
new file mode 100644
index 000000000..fd5e3124e
--- /dev/null
+++ b/examples/webgpu-chat/src/App.jsx
@@ -0,0 +1,282 @@
+import { useEffect, useState, useRef } from 'react';
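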
+
+import Chat from './components/Chat';
+import ArrowRightIcon from './components/icons/ArrowRightIcon';
+import StopIcon from './components/icons/StopIcon';
+import Progress from './components/Progress';
+
+const IS_WEBGPU_AVAILABLE = !!navigator.gpu;
+const STICKY_SCROLL_THRESHOLD = 120;
+
+function App() {
+
+ // Create a reference to the worker object.
+ const worker = useRef(null);
+
+ const textareaRef = useRef(null);
+ const chatContainerRef = useRef(null);
+
+ // Model loading and progress
+ const [status, setStatus] = useState(null);
+ const [loadingMessage, setLoadingMessage] = useState('');
+ const [progressItems, setProgressItems] = useState([]);
+ const [isRunning, setIsRunning] = useState(false);
+
+ // Inputs and outputs
+ const [input, setInput] = useState('');
+ const [messages, setMessages] = useState([]);
+ const [tps, setTps] = useState(null);
+ const [numTokens, setNumTokens] = useState(null);
+
+ function onEnter(message) {
+ setMessages(prev => [
+ ...prev,
+ { "role": "user", "content": message },
+ ]);
+ setTps(null);
+ setIsRunning(true);
+ setInput('');
+ }
+
+ useEffect(() => {
+ resizeInput();
+ }, [input]);
+
+ function onInterrupt() {
+ // NOTE: We do not set isRunning to false here because the worker
+ // will send a 'complete' message when it is done.
+ worker.current.postMessage({ type: 'interrupt' });
+ }
+
+ function resizeInput() {
+ if (!textareaRef.current) return;
+
+ const target = textareaRef.current;
+ target.style.height = 'auto';
+ const newHeight = Math.min(Math.max(target.scrollHeight, 24), 200);
+ target.style.height = `${newHeight}px`;
+ }
+
+ // We use the `useEffect` hook to set up the worker as soon as the `App` component is mounted.
+ useEffect(() => {
+ if (!worker.current) {
+ // Create the worker if it does not yet exist.
+ worker.current = new Worker(new URL('./worker.js', import.meta.url), {
+ type: 'module'
+ });
+ }
+
+ // Create a callback function for messages from the worker thread.
+ const onMessageReceived = (e) => {
+ switch (e.data.status) {
+ case 'loading':
+ // Model file has started loading: add a new progress item to the list.
+ setStatus('loading');
+ setLoadingMessage(e.data.data);
+ break;
+
+ case 'initiate':
+ setProgressItems(prev => [...prev, e.data]);
+ break;
+
+ case 'progress':
+ // Model file progress: update one of the progress items.
+ setProgressItems(
+ prev => prev.map(item => {
+ if (item.file === e.data.file) {
+ return { ...item, ...e.data }
+ }
+ return item;
+ })
+ );
+ break;
+
+ case 'done':
+ // Model file loaded: remove the progress item from the list.
+ setProgressItems(
+ prev => prev.filter(item => item.file !== e.data.file)
+ );
+ break;
+
+ case 'ready':
+ // Pipeline ready: the worker is ready to accept messages.
+ setStatus('ready');
+ break;
+
+ case 'start': {
+ // Start generation
+ setMessages(prev => [...prev, { "role": "assistant", "content": "" }]);
+ }
+ break;
+
+ case 'update': {
+ // Generation update: update the output text.
+ // Parse messages
+ const { output, tps, numTokens } = e.data;
+ setTps(tps);
+ setNumTokens(numTokens)
+ setMessages(prev => {
+ const cloned = [...prev];
+ const last = cloned.at(-1);
+ cloned[cloned.length - 1] = { ...last, content: last.content + output };
+ return cloned;
+ });
+ }
+ break;
+
+ case 'complete':
+ // Generation complete: re-enable the "Generate" button
+ setIsRunning(false);
+ break;
+ }
+ };
+
+ // Attach the callback function as an event listener.
+ worker.current.addEventListener('message', onMessageReceived);
+
+ // Define a cleanup function for when the component is unmounted.
+ return () => {
+ worker.current.removeEventListener('message', onMessageReceived);
+ };
+ }, []);
+
+ // Send the messages to the worker thread whenever the `messages` state changes.
+ useEffect(() => {
+ if (messages.filter(x => x.role === 'user').length === 0) {
+ // No user messages yet: do nothing.
+ return;
+ }
+ if (messages.at(-1).role === 'assistant') {
+ // Do not update if the last message is from the assistant
+ return;
+ }
+ setTps(null);
+ worker.current.postMessage({ type: 'generate', data: messages });
+ }, [messages, isRunning]);
+
+ useEffect(() => {
+ if (!chatContainerRef.current) return;
+ if (isRunning) {
+ const element = chatContainerRef.current;
+ if (element.scrollHeight - element.scrollTop - element.clientHeight < STICKY_SCROLL_THRESHOLD) {
+ element.scrollTop = element.scrollHeight;
+ }
+ }
+ }, [messages, isRunning]);
+
+ return (
+ IS_WEBGPU_AVAILABLE
+ ? (
+
+ {status === null && messages.length === 0 && (
+
+
+
+
Phi-3 WebGPU
+
A private and powerful AI chatbot that runs locally in your browser.
+
+
+
+
+
+ You are about to load Phi-3-mini-4k-instruct ,
+ a 3.82 billion parameter LLM that is optimized for inference on the web. Once downloaded, the model (2.3 GB) will be cached and reused when you revisit the page.
+
+ Everything runs directly in your browser using 🤗 Transformers.js and ONNX Runtime Web, meaning your conversations aren't sent to a server. You can even disconnect from the internet after the model has loaded!
+
+
+
{
+ worker.current.postMessage({ type: 'load' });
+ setStatus('loading');
+ }}
+ disabled={status !== null}
+ >
+ Load model
+
+
+
+ )}
+ {status === 'loading' && (<>
+
+
{loadingMessage}
+ {progressItems.map(({ file, progress, total }, i) => (
+
+ ))}
+
+ >)}
+
+ {status === 'ready' && (
+
+
+ {tps && messages.length > 0 && (<>
+ {!isRunning &&
+ Generated {numTokens} tokens in {(numTokens / tps).toFixed(2)} seconds ( }
+ {<>
+
+ {tps.toFixed(2)}
+
+ tokens/second
+ >}
+ {!isRunning && <>
+ ).
+ {
+ worker.current.postMessage({ type: 'reset' });
+ setMessages([]);
+ }}>Reset
+ >}
+ >)}
+
+
)}
+
+
+
+
+ Disclaimer: Generated content may be inaccurate or false.
+
+
)
+ : (WebGPU is not supported by this browser :(
)
+ )
+}
+
+export default App
diff --git a/examples/webgpu-chat/src/components/Chat.css b/examples/webgpu-chat/src/components/Chat.css
new file mode 100644
index 000000000..f8ab98d4b
--- /dev/null
+++ b/examples/webgpu-chat/src/components/Chat.css
@@ -0,0 +1,112 @@
+@scope (.markdown) {
+
+ /* Code blocks */
+ pre {
+ margin: 0.5rem 0;
+ white-space: break-spaces;
+ }
+
+ code {
+ padding: 0.2em 0.4em;
+ border-radius: 4px;
+ font-family: Consolas, Monaco, 'Andale Mono', 'Ubuntu Mono', monospace;
+ font-size: 0.9em;
+ }
+
+ pre,
+ code {
+ background-color: #f2f2f2;
+ }
+
+ @media (prefers-color-scheme: dark) {
+
+ pre,
+ code {
+ background-color: #333;
+ }
+
+ }
+
+ pre:has(code) {
+ padding: 1rem 0.5rem;
+ }
+
+ pre>code {
+ padding: 0;
+ }
+
+ /* Headings */
+ h1,
+ h2,
+ h3,
+ h4,
+ h5,
+ h6 {
+ font-weight: 600;
+ line-height: 1.2;
+ }
+
+ h1 {
+ font-size: 2em;
+ margin: 1rem 0;
+ }
+
+ h2 {
+ font-size: 1.5em;
+ margin: 0.83rem 0;
+ }
+
+ h3 {
+ font-size: 1.25em;
+ margin: 0.67rem 0;
+ }
+
+ h4 {
+ font-size: 1em;
+ margin: 0.5rem 0;
+ }
+
+ h5 {
+ font-size: 0.875em;
+ margin: 0.33rem 0;
+ }
+
+ h6 {
+ font-size: 0.75em;
+ margin: 0.25rem 0;
+ }
+
+ h1,
+ h2,
+ h3,
+ h4,
+ h5,
+ h6:first-child {
+ margin-top: 0;
+ }
+
+ /* Unordered List */
+ ul {
+ list-style-type: disc;
+ margin-left: 1.5rem;
+ }
+
+ /* Ordered List */
+ ol {
+ list-style-type: decimal;
+ margin-left: 1.5rem;
+ }
+
+ /* List Items */
+ li {
+ margin: 0.25rem 0;
+ }
+
+ p:not(:first-child) {
+ margin-top: 0.75rem;
+ }
+
+ p:not(:last-child) {
+ margin-bottom: 0.75rem;
+ }
+}
\ No newline at end of file
diff --git a/examples/webgpu-chat/src/components/Chat.jsx b/examples/webgpu-chat/src/components/Chat.jsx
new file mode 100644
index 000000000..2fe7442bf
--- /dev/null
+++ b/examples/webgpu-chat/src/components/Chat.jsx
@@ -0,0 +1,42 @@
+import { marked } from 'marked';
+import DOMPurify from 'dompurify';
+
+import BotIcon from './icons/BotIcon';
+import UserIcon from './icons/UserIcon';
+
+import './Chat.css';
+
+export default function Chat({ messages }) {
+ const empty = messages.length === 0;
+
+ return (
+ {empty
+ ?
Ready!
+ : messages.map((msg, i) => (
+
+ {msg.role === 'assistant'
+ ? (<>
+
+
+
{
+ msg.content.length > 0
+ ?
+ : (
+
+
+
+ )
+ }
+
+ >
+ ) : (<>
+
+
+ >)
+ }
+
+ ))}
+
)
+}
diff --git a/examples/webgpu-chat/src/components/Progress.jsx b/examples/webgpu-chat/src/components/Progress.jsx
new file mode 100644
index 000000000..9ce024cc8
--- /dev/null
+++ b/examples/webgpu-chat/src/components/Progress.jsx
@@ -0,0 +1,15 @@
+function formatBytes(size) {
+ const i = size == 0 ? 0 : Math.floor(Math.log(size) / Math.log(1024));
+ return +((size / Math.pow(1024, i)).toFixed(2)) * 1 + ['B', 'kB', 'MB', 'GB', 'TB'][i];
+}
+
+export default function Progress({ text, percentage, total }) {
+ percentage ??= 0;
+ return (
+
+
+ {text} ({percentage.toFixed(2)}%{isNaN(total) ? '' : ` of ${formatBytes(total)}`})
+
+
+ );
+}
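For reference, `formatBytes` selects the unit from the base-1024 logarithm of the size and trims the value to at most two decimals. A few worked values (not part of the diff):

    function formatBytes(size) {
        const i = size == 0 ? 0 : Math.floor(Math.log(size) / Math.log(1024));
        return +((size / Math.pow(1024, i)).toFixed(2)) * 1 + ['B', 'kB', 'MB', 'GB', 'TB'][i];
    }

    console.log(formatBytes(0));             // "0B"
    console.log(formatBytes(1536));          // "1.5kB"  (1536 / 1024 = 1.5)
    console.log(formatBytes(2_300_000_000)); // "2.14GB" (2.3e9 / 1024^3 ≈ 2.142)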
diff --git a/examples/webgpu-chat/src/components/icons/ArrowRightIcon.jsx b/examples/webgpu-chat/src/components/icons/ArrowRightIcon.jsx
new file mode 100644
index 000000000..0ca5ed917
--- /dev/null
+++ b/examples/webgpu-chat/src/components/icons/ArrowRightIcon.jsx
@@ -0,0 +1,19 @@
+export default function ArrowRightIcon(props) {
+ return (
+
+
+
+
+ )
+}
\ No newline at end of file
diff --git a/examples/webgpu-chat/src/components/icons/BotIcon.jsx b/examples/webgpu-chat/src/components/icons/BotIcon.jsx
new file mode 100644
index 000000000..b8bd0ceae
--- /dev/null
+++ b/examples/webgpu-chat/src/components/icons/BotIcon.jsx
@@ -0,0 +1,23 @@
+export default function BotIcon(props) {
+ return (
+
+
+
+
+
+
+
+
+ )
+}
\ No newline at end of file
diff --git a/examples/webgpu-chat/src/components/icons/StopIcon.jsx b/examples/webgpu-chat/src/components/icons/StopIcon.jsx
new file mode 100644
index 000000000..9b97f3723
--- /dev/null
+++ b/examples/webgpu-chat/src/components/icons/StopIcon.jsx
@@ -0,0 +1,19 @@
+export default function StopIcon(props) {
+ return (
+
+
+
+
+ )
+}
\ No newline at end of file
diff --git a/examples/webgpu-chat/src/components/icons/UserIcon.jsx b/examples/webgpu-chat/src/components/icons/UserIcon.jsx
new file mode 100644
index 000000000..cb09e7574
--- /dev/null
+++ b/examples/webgpu-chat/src/components/icons/UserIcon.jsx
@@ -0,0 +1,19 @@
+export default function UserIcon(props) {
+ return (
+
+
+
+
+ )
+}
\ No newline at end of file
diff --git a/examples/webgpu-chat/src/index.css b/examples/webgpu-chat/src/index.css
new file mode 100644
index 000000000..8848bbd6d
--- /dev/null
+++ b/examples/webgpu-chat/src/index.css
@@ -0,0 +1,32 @@
+@tailwind base;
+@tailwind components;
+@tailwind utilities;
+
+@layer utilities {
+ .scrollbar-thin::-webkit-scrollbar {
+ @apply w-2;
+ }
+
+ .scrollbar-thin::-webkit-scrollbar-track {
+ @apply rounded-full bg-gray-100 dark:bg-gray-700;
+ }
+
+ .scrollbar-thin::-webkit-scrollbar-thumb {
+ @apply rounded-full bg-gray-300 dark:bg-gray-600;
+ }
+
+ .scrollbar-thin::-webkit-scrollbar-thumb:hover {
+ @apply bg-gray-500;
+ }
+
+ .animation-delay-200 {
+ animation-delay: 200ms;
+ }
+ .animation-delay-400 {
+ animation-delay: 400ms;
+ }
+
+ .overflow-wrap-anywhere {
+ overflow-wrap: anywhere;
+ }
+}
diff --git a/examples/webgpu-chat/src/main.jsx b/examples/webgpu-chat/src/main.jsx
new file mode 100644
index 000000000..54b39dd1d
--- /dev/null
+++ b/examples/webgpu-chat/src/main.jsx
@@ -0,0 +1,10 @@
+import React from 'react'
+import ReactDOM from 'react-dom/client'
+import App from './App.jsx'
+import './index.css'
+
+ReactDOM.createRoot(document.getElementById('root')).render(
+  <React.StrictMode>
+    <App />
+  </React.StrictMode>,
+)
diff --git a/examples/webgpu-chat/src/worker.js b/examples/webgpu-chat/src/worker.js
new file mode 100644
index 000000000..65d679670
--- /dev/null
+++ b/examples/webgpu-chat/src/worker.js
@@ -0,0 +1,174 @@
+
+import {
+ AutoTokenizer,
+ AutoModelForCausalLM,
+ TextStreamer,
+ StoppingCriteria,
+} from '@xenova/transformers';
+
+
+class CallbackTextStreamer extends TextStreamer {
+ constructor(tokenizer, cb) {
+ super(tokenizer, {
+ skip_prompt: true,
+ skip_special_tokens: true,
+ });
+ this.cb = cb;
+ }
+
+ on_finalized_text(text) {
+ this.cb(text);
+ }
+}
+
+class InterruptableStoppingCriteria extends StoppingCriteria {
+ constructor() {
+ super();
+ this.interrupted = false;
+ }
+
+ interrupt() {
+ this.interrupted = true;
+ }
+
+ reset() {
+ this.interrupted = false;
+ }
+
+ _call(input_ids, scores) {
+ return new Array(input_ids.length).fill(this.interrupted);
+ }
+}
+
+const stopping_criteria = new InterruptableStoppingCriteria();
+
+async function hasFp16() {
+ try {
+ const adapter = await navigator.gpu.requestAdapter();
+ return adapter.features.has('shader-f16');
+ } catch (e) {
+ return false;
+ }
+}
+
+/**
+ * This class uses the Singleton pattern to ensure that only one instance of the model is loaded.
+ */
+class TextGenerationPipeline {
+ static model_id = null;
+ static model = null;
+ static tokenizer = null;
+ static streamer = null;
+
+ static async getInstance(progress_callback = null) {
+ // Choose the model based on whether fp16 is available
+ this.model_id ??= (await hasFp16())
+ ? 'Xenova/Phi-3-mini-4k-instruct_fp16'
+ : 'Xenova/Phi-3-mini-4k-instruct';
+
+ this.tokenizer ??= AutoTokenizer.from_pretrained(this.model_id, {
+ legacy: true,
+ progress_callback,
+ });
+
+ this.model ??= AutoModelForCausalLM.from_pretrained(this.model_id, {
+ dtype: 'q4',
+ device: 'webgpu',
+ use_external_data_format: true,
+ progress_callback,
+ });
+
+ return Promise.all([this.tokenizer, this.model]);
+ }
+}
+
+async function generate(messages) {
+ // Retrieve the text-generation pipeline.
+ const [tokenizer, model] = await TextGenerationPipeline.getInstance();
+
+ const inputs = tokenizer.apply_chat_template(messages, {
+ add_generation_prompt: true,
+ return_dict: true,
+ });
+
+ let startTime;
+ let numTokens = 0;
+ const cb = (output) => {
+ startTime ??= performance.now();
+
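+        // Measure tokens/second from the first generated token onwards, so prompt processing (prefill) time is excluded.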
+ let tps;
+ if (numTokens++ > 0) {
+ tps = numTokens / (performance.now() - startTime) * 1000;
+ }
+ self.postMessage({
+ status: 'update',
+ output, tps, numTokens,
+ });
+ }
+
+ const streamer = new CallbackTextStreamer(tokenizer, cb);
+
+ // Tell the main thread we are starting
+ self.postMessage({ status: 'start' });
+
+ const outputs = await model.generate({
+ ...inputs,
+ max_new_tokens: 512,
+ streamer,
+ stopping_criteria,
+ });
+ const outputText = tokenizer.batch_decode(outputs, { skip_special_tokens: false });
+
+ // Send the output back to the main thread
+ self.postMessage({
+ status: 'complete',
+ output: outputText,
+ });
+}
+
+async function load() {
+ self.postMessage({
+ status: 'loading',
+ data: 'Loading model...'
+ });
+
+ // Load the pipeline and save it for future use.
+ const [tokenizer, model] = await TextGenerationPipeline.getInstance(x => {
+ // We also add a progress callback to the pipeline so that we can
+ // track model loading.
+ self.postMessage(x);
+ });
+
+ self.postMessage({
+ status: 'loading',
+ data: 'Compiling shaders and warming up model...'
+ });
+
+ // Run model with dummy input to compile shaders
+ const inputs = tokenizer('a');
+ await model.generate({ ...inputs, max_new_tokens: 1 });
+ self.postMessage({ status: 'ready' });
+}
+// Listen for messages from the main thread
+self.addEventListener('message', async (e) => {
+ const { type, data } = e.data;
+
+ switch (type) {
+ case 'load':
+ load();
+ break;
+
+ case 'generate':
+ stopping_criteria.reset();
+ generate(data);
+ break;
+
+ case 'interrupt':
+ stopping_criteria.interrupt();
+ break;
+
+ case 'reset':
+ stopping_criteria.reset();
+ break;
+ }
+});
diff --git a/examples/webgpu-chat/tailwind.config.js b/examples/webgpu-chat/tailwind.config.js
new file mode 100644
index 000000000..d37737fc0
--- /dev/null
+++ b/examples/webgpu-chat/tailwind.config.js
@@ -0,0 +1,12 @@
+/** @type {import('tailwindcss').Config} */
+export default {
+ content: [
+ "./index.html",
+ "./src/**/*.{js,ts,jsx,tsx}",
+ ],
+ theme: {
+ extend: {},
+ },
+ plugins: [],
+}
+
diff --git a/examples/webgpu-chat/vite.config.js b/examples/webgpu-chat/vite.config.js
new file mode 100644
index 000000000..5a33944a9
--- /dev/null
+++ b/examples/webgpu-chat/vite.config.js
@@ -0,0 +1,7 @@
+import { defineConfig } from 'vite'
+import react from '@vitejs/plugin-react'
+
+// https://vitejs.dev/config/
+export default defineConfig({
+ plugins: [react()],
+})
diff --git a/examples/webgpu-clip/.gitignore b/examples/webgpu-clip/.gitignore
new file mode 100644
index 000000000..a547bf36d
--- /dev/null
+++ b/examples/webgpu-clip/.gitignore
@@ -0,0 +1,24 @@
+# Logs
+logs
+*.log
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+pnpm-debug.log*
+lerna-debug.log*
+
+node_modules
+dist
+dist-ssr
+*.local
+
+# Editor directories and files
+.vscode/*
+!.vscode/extensions.json
+.idea
+.DS_Store
+*.suo
+*.ntvs*
+*.njsproj
+*.sln
+*.sw?
diff --git a/examples/webgpu-clip/index.html b/examples/webgpu-clip/index.html
new file mode 100644
index 000000000..4a87dacb4
--- /dev/null
+++ b/examples/webgpu-clip/index.html
@@ -0,0 +1,39 @@
+
+
+
+
+
+
+ Transformers.js | real-time CLIP
+
+
+
+
+ Real-time zero-shot image classification (WebGPU)
+
+
+ Runs locally in your browser w/
+ 🤗 Transformers.js
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/examples/webgpu-clip/main.js b/examples/webgpu-clip/main.js
new file mode 100644
index 000000000..35f17512a
--- /dev/null
+++ b/examples/webgpu-clip/main.js
@@ -0,0 +1,169 @@
+
+import {
+ AutoTokenizer,
+ CLIPTextModelWithProjection,
+ AutoProcessor,
+ CLIPVisionModelWithProjection,
+ RawImage,
+ dot,
+ softmax,
+} from '@xenova/transformers';
+
+import './style.css';
+
+// Reference the elements that we will need
+const status = document.getElementById('status');
+const container = document.getElementById('container');
+const video = document.getElementById('video');
+const labelsInput = document.getElementById('labels');
+const templateInput = document.getElementById('template');
+const overlay = document.getElementById('overlay');
+
+status.textContent = 'Loading model (300MB)...';
+
+// Use fp16 if available, otherwise use fp32
+async function hasFp16() {
+ try {
+ const adapter = await navigator.gpu.requestAdapter();
+ return adapter.features.has('shader-f16');
+ } catch (e) {
+ return false;
+ }
+}
+const dtype = (await hasFp16()) ? 'fp16' : 'fp32';
+
+// Load the CLIP tokenizer, text model, processor, and vision model
+const model_id = 'Xenova/clip-vit-base-patch16';
+let tokenizer, text_model, processor, vision_model;
+try {
+ // Load tokenizer and text model
+ tokenizer = await AutoTokenizer.from_pretrained(model_id);
+ text_model = await CLIPTextModelWithProjection.from_pretrained(model_id, {
+ device: 'webgpu',
+ dtype,
+ });
+
+ // Load processor and vision model
+ processor = await AutoProcessor.from_pretrained(model_id);
+ vision_model = await CLIPVisionModelWithProjection.from_pretrained(model_id, {
+ device: 'webgpu',
+ dtype,
+ });
+
+} catch (err) {
+ status.textContent = err.message;
+ alert(err.message)
+ throw err;
+}
+
+labelsInput.disabled = false;
+templateInput.disabled = false;
+
+status.textContent = 'Ready';
+
+// See `model.logit_scale` parameter of original model
+const exp_logit_scale = Math.exp(4.6052);
+
+const IMAGE_SIZE = 224;
+const canvas = document.createElement('canvas');
+canvas.width = canvas.height = IMAGE_SIZE;
+const context = canvas.getContext('2d', { willReadFrequently: true });
+
+let isProcessing = false;
+let previousTime;
+let textEmbeddings;
+let prevTextInputs;
+let prevTemplate;
+let labels;
+
+function onFrameUpdate() {
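+    // Only start a new inference when the previous frame has finished processing; otherwise just schedule the next frame.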
+ if (!isProcessing) {
+ isProcessing = true;
+ (async function () {
+
+ // If text inputs have changed, update the embeddings
+ if (prevTextInputs !== labelsInput.value || prevTemplate !== templateInput.value) {
+ textEmbeddings = null;
+ prevTextInputs = labelsInput.value;
+ prevTemplate = templateInput.value;
+ labels = prevTextInputs.split(/\s*,\s*/).filter(x => x);
+
+ if (labels.length > 0) {
+ const texts = labels.map(x => templateInput.value.replaceAll('{}', x));
+
+ const text_inputs = tokenizer(texts, { padding: true, truncation: true });
+
+ // Compute embeddings
+ const { text_embeds } = await text_model(text_inputs);
+ textEmbeddings = text_embeds.normalize().tolist();
+ } else {
+ overlay.innerHTML = '';
+ }
+ }
+
+ if (textEmbeddings) {
+ // Read the current frame from the video
+ context.drawImage(video, 0, 0, IMAGE_SIZE, IMAGE_SIZE);
+ const pixelData = context.getImageData(0, 0, IMAGE_SIZE, IMAGE_SIZE).data;
+ const image = new RawImage(pixelData, IMAGE_SIZE, IMAGE_SIZE, 4);
+
+ const image_inputs = await processor(image);
+
+ // Compute embeddings
+ const { image_embeds } = await vision_model(image_inputs);
+ const imageEmbedding = image_embeds.normalize().tolist()[0];
+
+ // Compute similarity
+ const similarities = textEmbeddings.map(
+ x => dot(x, imageEmbedding) * exp_logit_scale
+ );
+
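+          // Convert the scaled similarities into probabilities and sort the labels by descending score.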
+ const sortedIndices = softmax(similarities)
+ .map((x, i) => [x, i])
+ .sort((a, b) => b[0] - a[0]);
+
+ // Update UI
+ overlay.innerHTML = '';
+ for (const [score, index] of sortedIndices) {
+ overlay.appendChild(document.createTextNode(`${labels[index]}: ${score.toFixed(2)}`));
+ overlay.appendChild(document.createElement('br'));
+ }
+ }
+
+ if (previousTime !== undefined) {
+ const fps = 1000 / (performance.now() - previousTime);
+ status.textContent = `FPS: ${fps.toFixed(2)}`;
+ }
+ previousTime = performance.now();
+ isProcessing = false;
+ })();
+ }
+
+ window.requestAnimationFrame(onFrameUpdate);
+}
+
+// Start the video stream
+navigator.mediaDevices.getUserMedia(
+ { video: true }, // Ask for video
+).then((stream) => {
+ // Set up the video and canvas elements.
+ video.srcObject = stream;
+ video.play();
+
+ const videoTrack = stream.getVideoTracks()[0];
+ const { width, height } = videoTrack.getSettings();
+
+ video.width = width;
+ video.height = height;
+
+ // Set container width and height depending on the image aspect ratio
+ const ar = width / height;
+ const [cw, ch] = (ar > 720 / 405) ? [720, 720 / ar] : [405 * ar, 405];
+ container.style.width = `${cw}px`;
+ container.style.height = `${ch}px`;
+
+ // Start the animation loop
+ window.requestAnimationFrame(onFrameUpdate);
+}).catch((error) => {
+ alert(error);
+});
diff --git a/examples/webgpu-clip/package.json b/examples/webgpu-clip/package.json
new file mode 100644
index 000000000..44888248a
--- /dev/null
+++ b/examples/webgpu-clip/package.json
@@ -0,0 +1,17 @@
+{
+ "name": "webgpu-clip",
+ "private": true,
+ "version": "0.0.0",
+ "type": "module",
+ "scripts": {
+ "dev": "vite",
+ "build": "vite build",
+ "preview": "vite preview"
+ },
+ "devDependencies": {
+ "vite": "^5.2.10"
+ },
+ "dependencies": {
+ "@xenova/transformers": "^3.0.0"
+ }
+}
diff --git a/examples/webgpu-clip/style.css b/examples/webgpu-clip/style.css
new file mode 100644
index 000000000..e08c41d1a
--- /dev/null
+++ b/examples/webgpu-clip/style.css
@@ -0,0 +1,91 @@
+* {
+ box-sizing: border-box;
+ padding: 0;
+ margin: 0;
+ font-family: sans-serif;
+}
+
+html,
+body {
+ height: 100%;
+}
+
+body {
+ padding: 16px 32px;
+}
+
+body,
+#container {
+ display: flex;
+ flex-direction: column;
+ justify-content: center;
+ align-items: center;
+}
+
+#controls {
+ display: flex;
+ padding: 1rem;
+ gap: 1rem;
+}
+
+#controls>div {
+ text-align: center;
+}
+
+h1,
+h3 {
+ text-align: center;
+}
+
+h3 {
+ margin-top: 0.5rem;
+}
+
+#container {
+ position: relative;
+ width: 720px;
+ height: 405px;
+ max-width: 100%;
+ max-height: 100%;
+ border: 2px dashed #D1D5DB;
+ border-radius: 0.75rem;
+ overflow: hidden;
+ margin-top: 1rem;
+ background-size: 100% 100%;
+ background-position: center;
+ background-repeat: no-repeat;
+}
+
+#status {
+ min-height: 16px;
+ margin: 8px 0;
+}
+
+video {
+ width: 100%;
+ height: 100%;
+}
+
+input[type="text"] {
+ padding: 0.25rem 0.5rem;
+ border: 1px solid #D1D5DB;
+ border-radius: 0.25rem;
+ margin-top: 2px;
+}
+
+input[type="range"] {
+ margin-top: 6px;
+}
+
+#overlay {
+ position: absolute;
+ top: 0;
+ left: 0;
+ background-color: rgba(255, 255, 255, 0.9);
+ font-size: 1.25rem;
+ border-radius: 2px;
+}
+
+#overlay:not(:empty) {
+ padding: 0.5rem;
+}
\ No newline at end of file
diff --git a/examples/webgpu-clip/vite.config.js b/examples/webgpu-clip/vite.config.js
new file mode 100644
index 000000000..6c32f52df
--- /dev/null
+++ b/examples/webgpu-clip/vite.config.js
@@ -0,0 +1,6 @@
+import { defineConfig } from 'vite';
+export default defineConfig({
+ build: {
+ target: 'esnext'
+ }
+});
diff --git a/examples/webgpu-embedding-benchmark/.gitignore b/examples/webgpu-embedding-benchmark/.gitignore
new file mode 100644
index 000000000..a547bf36d
--- /dev/null
+++ b/examples/webgpu-embedding-benchmark/.gitignore
@@ -0,0 +1,24 @@
+# Logs
+logs
+*.log
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+pnpm-debug.log*
+lerna-debug.log*
+
+node_modules
+dist
+dist-ssr
+*.local
+
+# Editor directories and files
+.vscode/*
+!.vscode/extensions.json
+.idea
+.DS_Store
+*.suo
+*.ntvs*
+*.njsproj
+*.sln
+*.sw?
diff --git a/examples/webgpu-embedding-benchmark/index.html b/examples/webgpu-embedding-benchmark/index.html
new file mode 100644
index 000000000..8b4a9d361
--- /dev/null
+++ b/examples/webgpu-embedding-benchmark/index.html
@@ -0,0 +1,64 @@
+
+
+
+
+
+
+ Transformers.js | WebGPU Benchmark
+
+
+
+
+
+ This benchmark measures the execution time of BERT-based embedding models
+ using the WASM and WebGPU execution providers across different batch sizes.
+
+
+
+
+
+ Start Benchmark
+ Stop Benchmark
+
+
+
+ Options
+
+ WASM (int8)
+ WASM (fp16)
+ WASM (fp32)
+
+ WebGPU (fp16)
+ WebGPU (fp32)
+
+
+
+ Model ID
+
+
+
+ Batch sizes
+
+
+
+ Sequence length
+
+
+
+
+ Log scale (x)
+ Log scale (y)
+
+
+
+
+
+
\ No newline at end of file
diff --git a/examples/webgpu-embedding-benchmark/main.js b/examples/webgpu-embedding-benchmark/main.js
new file mode 100644
index 000000000..bdf731395
--- /dev/null
+++ b/examples/webgpu-embedding-benchmark/main.js
@@ -0,0 +1,305 @@
+import './style.css';
+import { env, AutoModel, ones } from '@xenova/transformers';
+import Chart from 'chart.js/auto';
+
+// Throw an error if WebGPU is not supported
+if (!navigator.gpu) {
+ const err = 'WebGPU is not supported by this browser.';
+ alert(err)
+ throw Error(err);
+}
+
+env.backends.onnx.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.1/dist/';
+env.backends.onnx.wasm.numThreads = 1;
+
+// Reference the elements that we will need
+const ctx = document.getElementById('chart');
+const batchSizes = document.getElementById('batch-sizes');
+const xscale = document.getElementById('x-scale');
+const yscale = document.getElementById('y-scale');
+const sequenceLength = document.getElementById('sequence-length');
+const modelID = document.getElementById('model-id');
+const status = document.getElementById('status');
+const start = document.getElementById('start');
+const stop = document.getElementById('stop');
+const tests = document.getElementsByClassName('tests');
+
+// Benchmark settings
+const NUM_WARMUP_STEPS = 3;
+const MODEL_CACHE = new Map();
+
+// Chart configuration
+const initChart = () => {
+ const config = {
+ type: 'line',
+ data: {
+ labels: [],
+ datasets: [],
+ },
+ options: {
+ responsive: true,
+ maintainAspectRatio: false,
+ plugins: {
+ legend: {
+ position: 'top',
+ },
+ },
+ scales: {
+ x: {
+ title: {
+ display: true,
+ text: 'Batch size',
+ },
+ min: 1,
+ },
+ y: {
+ title: {
+ display: true,
+ text: 'Time (ms)',
+ },
+ }
+ }
+ },
+ };
+ const chart = new Chart(ctx, config);
+ return chart;
+}
+let chart = initChart();
+const toggleScale = (axis, enabled) => {
+ chart.options.scales[axis].type = enabled ? 'logarithmic' : 'linear';
+ chart.update();
+}
+
+const getSelectedTests = () => {
+ return [...tests].filter(x => x.checked);
+}
+
+const updateDatasets = () => {
+ chart.data.datasets = getSelectedTests().map(test => {
+ const color = test.getAttribute('data-color');
+ return {
+ label: test.value,
+ data: [],
+ borderColor: `rgba(${color}, 1)`,
+ backgroundColor: `rgba(${color}, 0.5)`,
+ }
+ })
+ chart.update();
+}
+updateDatasets();
+[...tests].forEach(test => test.addEventListener('change', updateDatasets));
+
+xscale.addEventListener('change', () => toggleScale('x', xscale.checked));
+yscale.addEventListener('change', () => toggleScale('y', yscale.checked));
+
+const generateDummyInputs = (batch_size, seqLength) => {
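+    // Build all-ones tensors of shape [batch_size, seqLength] to serve as dummy input_ids and attention_mask.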
+ const inputs = ones([batch_size, seqLength]);
+
+ const model_inputs = {
+ input_ids: inputs,
+ attention_mask: inputs,
+ }
+ return model_inputs;
+}
+
+let adapterInfo;
+let gpuHasFp16 = false;
+try {
+  // Query the GPU adapter for device info and fp16 support (WebGPU availability was already verified above)
+ const adapter = await navigator.gpu.requestAdapter();
+ adapterInfo = await adapter.requestAdapterInfo();
+ gpuHasFp16 = adapter.features.has('shader-f16')
+} catch (err) {
+ adapterInfo = {};
+}
+if (!gpuHasFp16) {
+ const element = document.querySelector('.tests[data-device="webgpu"][data-dtype="fp16"]');
+ element.setAttribute('unsupported', true);
+ element.disabled = true;
+ element.title = 'This device does not support fp16 on WebGPU';
+}
+
+status.textContent = 'Ready';
+
+let interrupted = false;
+start.addEventListener('click', async () => {
+ const validTests = [...tests].filter(test => !test.getAttribute('unsupported'))
+ // Update UI
+ start.disabled = true;
+ stop.disabled = false;
+ batchSizes.disabled = true;
+ sequenceLength.disabled = true;
+ modelID.disabled = true;
+ validTests.forEach(test => test.disabled = true);
+ interrupted = false;
+
+ // Get parameters
+ const model_id = modelID.value;
+ const batch_sizes = batchSizes.value.split(',').map(x => parseInt(x)).filter(x => x);
+ const seqLength = parseInt(sequenceLength.value);
+ const selectedTests = getSelectedTests().map(x => ({
+ label: x.value,
+ dtype: x.getAttribute('data-dtype'),
+ device: x.getAttribute('data-device'),
+ }));
+
+ // Reset
+ chart.destroy();
+ chart = initChart();
+ updateDatasets();
+
+ // NOTE: Models must be loaded sequentially (otherwise it will fail due to multiple calls to initWasm())
+ const testsToRun = new Map();
+ for (const test of selectedTests) {
+ const { label, dtype, device, quantized } = test;
+
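+    // Cache loaded models per (model id, test) pair so repeated benchmark runs skip re-downloading and re-initializing.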
+ const key = `${model_id}///${label}`;
+
+ const cached = MODEL_CACHE.get(key);
+ if (cached) {
+ testsToRun.set(label, cached);
+ continue;
+ }
+ status.textContent = 'Loading model(s)...';
+
+ try {
+ const model = await AutoModel.from_pretrained(model_id, {
+ quantized,
+ device,
+ dtype,
+ });
+ MODEL_CACHE.set(key, model);
+ testsToRun.set(label, model);
+ } catch (err) {
+ status.textContent = err.message;
+ alert(err.message)
+ throw err;
+ }
+ }
+
+ status.textContent = 'Warming up...';
+
+ // Warm up: This is important for the WebGPU execution provider, which compiles the shaders on first load
+ for (let i = 0; i < NUM_WARMUP_STEPS; ++i) {
+ const model_inputs = generateDummyInputs(1, seqLength);
+ for (const [label, model] of testsToRun) {
+ await model(model_inputs);
+ }
+ }
+
+ status.textContent = 'Running benchmark...';
+
+ for (const batch_size of batch_sizes) {
+ if (interrupted) break;
+
+ const model_inputs = generateDummyInputs(batch_size, seqLength);
+
+ const times = []
+
+ for (const [label, model] of testsToRun) {
+ const start = performance.now();
+ await model(model_inputs);
+ const end = performance.now();
+ times.push(end - start);
+ }
+
+ chart.data.labels.push(batch_size);
+ for (let i = 0; i < times.length; ++i) {
+ chart.data.datasets[i].data.push(times[i]);
+ }
+ chart.update();
+ }
+
+ // Calculate max speedup:
+ if (chart.data.labels.length === 0) return;
+
+ const testNames = [...testsToRun.keys()];
+ const table = generateResultsTable(model_id, testNames, chart.data, seqLength);
+
+
+ // Calculate slowest and fastest times
+ let minMaxTimes = [Infinity, 0];
+ let minMaxIndices = [0, 0];
+ for (let i = 0; i < chart.data.datasets.length; i++) {
+ const lastTime = chart.data.datasets[i].data.at(-1);
+ if (lastTime < minMaxTimes[0]) {
+ minMaxTimes[0] = lastTime;
+ minMaxIndices[0] = i;
+ }
+ if (lastTime > minMaxTimes[1]) {
+ minMaxTimes[1] = lastTime;
+ minMaxIndices[1] = i;
+ }
+ }
+
+ const speedup = minMaxTimes[1] / minMaxTimes[0];
+ const roundedSpeedup = speedup.toFixed(2);
+ const params = new URLSearchParams({
+ title: `⚡ WebGPU Benchmark Results (${roundedSpeedup}x speedup)`,
+ description: table.outerHTML,
+ });
+
+ const paramsStr = params.toString();
+  status.innerHTML = `⚡ Done! ${testNames.at(minMaxIndices[0])} is ${roundedSpeedup}x faster than ${testNames.at(minMaxIndices[1])}! ⚡ Share results`;
+ start.disabled = false;
+ stop.disabled = true;
+ batchSizes.disabled = false;
+ sequenceLength.disabled = false;
+ modelID.disabled = false;
+ validTests.forEach(test => test.disabled = false);
+});
+
+start.disabled = false;
+
+stop.addEventListener('click', () => {
+ status.textContent = 'Stopping...';
+ interrupted = true;
+ stop.disabled = true;
+});
+
+function generateResultsTable(model_id, testNames, data, sequence_length) {
+
+ const datasets = data.datasets.map(d => d.data);
+ const batch_sizes = data.labels;
+
+ const container = document.createElement('div');
+
+ const table = document.createElement('table');
+ const thead = table.createTHead();
+ const tbody = table.createTBody();
+
+ // Add header row
+ const headerRow = thead.insertRow();
+ headerRow.insertCell().textContent = 'Batch Size';
+ testNames.forEach(model => {
+ headerRow.insertCell().textContent = model;
+ });
+
+ // Add data rows
+ batch_sizes.forEach((batchSize, rowIndex) => {
+ const row = tbody.insertRow();
+ row.insertCell().textContent = batchSize;
+ datasets.forEach(dataset => {
+ row.insertCell().textContent = dataset[rowIndex].toFixed(2);
+ });
+ });
+
+ container.appendChild(table);
+
+ const createBulletPoint = (text) => {
+ const li = document.createElement('li');
+ li.textContent = text;
+ return li;
+ }
+
+ // Add other information
+ const info = document.createElement('ul');
+ info.appendChild(createBulletPoint(`Model: ${model_id}`));
+ info.appendChild(createBulletPoint(`Tests run: ${testNames.join(', ')}`));
+ info.appendChild(createBulletPoint(`Sequence length: ${sequence_length}`));
+ info.appendChild(createBulletPoint(`Browser: ${navigator.userAgent}`));
+ info.appendChild(createBulletPoint(`GPU: vendor=${adapterInfo.vendor}, architecture=${adapterInfo.architecture}, device=${adapterInfo.device}, description=${adapterInfo.description}`));
+ container.appendChild(info);
+
+ return container;
+}
diff --git a/examples/webgpu-embedding-benchmark/package.json b/examples/webgpu-embedding-benchmark/package.json
new file mode 100644
index 000000000..d90288d7a
--- /dev/null
+++ b/examples/webgpu-embedding-benchmark/package.json
@@ -0,0 +1,18 @@
+{
+ "name": "webgpu-embedding-benchmark",
+ "private": true,
+ "version": "0.0.0",
+ "type": "module",
+ "scripts": {
+ "dev": "vite",
+ "build": "vite build",
+ "preview": "vite preview"
+ },
+ "devDependencies": {
+ "vite": "^5.0.12"
+ },
+ "dependencies": {
+ "@xenova/transformers": "^3.0.0",
+ "chart.js": "^4.4.2"
+ }
+}
diff --git a/examples/webgpu-embedding-benchmark/style.css b/examples/webgpu-embedding-benchmark/style.css
new file mode 100644
index 000000000..9253d75e3
--- /dev/null
+++ b/examples/webgpu-embedding-benchmark/style.css
@@ -0,0 +1,87 @@
+* {
+ box-sizing: border-box;
+ padding: 0;
+ margin: 0;
+ font-family: sans-serif;
+}
+
+html,
+body {
+ height: 100%;
+}
+
+body {
+ padding: 16px 32px;
+ display: flex;
+ flex-direction: column;
+ justify-content: center;
+ align-items: center;
+}
+
+h1 {
+ text-align: center;
+}
+
+#status {
+ min-height: 16px;
+ margin: 8px 0;
+ text-align: center;
+}
+
+button {
+ transition: all .25s;
+ background: rgba(40, 44, 52, 0.05);
+ border: 1px solid transparent;
+ border-radius: 6px;
+ color: #3080d0;
+ text-decoration: none !important;
+ display: inline-block;
+ font-size: 14px;
+ font-weight: 500;
+ padding: 8px 16px;
+ cursor: pointer;
+ -webkit-user-select: none;
+ -moz-user-select: none;
+ user-select: none;
+}
+
+button:disabled {
+ background: rgba(40, 44, 52, 0.1);
+ color: #a0a0a0;
+ cursor: not-allowed;
+}
+
+button:hover {
+ background: rgba(40, 44, 52, 0.1);
+}
+
+p {
+ text-align: center;
+ font-size: 12px;
+ max-width: 600px;
+ padding: 8px;
+}
+
+#chart-container {
+ position: relative;
+ height: 60vh;
+ width: min(90vw, 800px);
+ padding-right: 50px;
+ margin-bottom: 10px;
+}
+
+details {
+ position: fixed;
+ background-color: white;
+ right: 0;
+ top: 0;
+ padding: 16px;
+}
+
+summary {
+ text-align: right;
+}
+
+hr {
+ margin: 8px 0;
+}
diff --git a/examples/webgpu-embedding-benchmark/vite.config.js b/examples/webgpu-embedding-benchmark/vite.config.js
new file mode 100644
index 000000000..6c32f52df
--- /dev/null
+++ b/examples/webgpu-embedding-benchmark/vite.config.js
@@ -0,0 +1,6 @@
+import { defineConfig } from 'vite';
+export default defineConfig({
+ build: {
+ target: 'esnext'
+ }
+});
diff --git a/examples/webgpu-video-background-removal/.gitignore b/examples/webgpu-video-background-removal/.gitignore
new file mode 100644
index 000000000..a547bf36d
--- /dev/null
+++ b/examples/webgpu-video-background-removal/.gitignore
@@ -0,0 +1,24 @@
+# Logs
+logs
+*.log
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+pnpm-debug.log*
+lerna-debug.log*
+
+node_modules
+dist
+dist-ssr
+*.local
+
+# Editor directories and files
+.vscode/*
+!.vscode/extensions.json
+.idea
+.DS_Store
+*.suo
+*.ntvs*
+*.njsproj
+*.sln
+*.sw?
diff --git a/examples/webgpu-video-background-removal/index.html b/examples/webgpu-video-background-removal/index.html
new file mode 100644
index 000000000..8e71df5a9
--- /dev/null
+++ b/examples/webgpu-video-background-removal/index.html
@@ -0,0 +1,43 @@
+
+
+
+
+
+
+
+  Transformers.js | Real-time background removal
+
+
+
+
+ Real-time background removal w/
+ 🤗 Transformers.js
+
+
+ Runs locally in your browser, powered by
+ MODNet
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/examples/webgpu-video-background-removal/main.js b/examples/webgpu-video-background-removal/main.js
new file mode 100644
index 000000000..620f21afb
--- /dev/null
+++ b/examples/webgpu-video-background-removal/main.js
@@ -0,0 +1,128 @@
+import './style.css';
+
+import { env, AutoModel, AutoProcessor, RawImage } from '@xenova/transformers';
+
+env.backends.onnx.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.1/dist/';
+env.backends.onnx.wasm.numThreads = 1;
+
+// Reference the elements that we will need
+const status = document.getElementById('status');
+const container = document.getElementById('container');
+const canvas = document.getElementById('canvas');
+const outputCanvas = document.getElementById('output-canvas');
+const video = document.getElementById('video');
+const sizeSlider = document.getElementById('size');
+const sizeLabel = document.getElementById('size-value');
+const scaleSlider = document.getElementById('scale');
+const scaleLabel = document.getElementById('scale-value');
+
+function setStreamSize(width, height) {
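+    // Keep the hidden video element and both canvases at the same (scaled) resolution.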
+ video.width = outputCanvas.width = canvas.width = Math.round(width);
+ video.height = outputCanvas.height = canvas.height = Math.round(height);
+}
+
+status.textContent = 'Loading model...';
+
+// Load model and processor
+const model_id = 'Xenova/modnet';
+let model;
+try {
+ model = await AutoModel.from_pretrained(model_id, {
+ device: 'webgpu',
+ dtype: 'fp32', // TODO: add fp16 support
+ });
+} catch (err) {
+ status.textContent = err.message;
+ alert(err.message)
+ throw err;
+}
+
+const processor = await AutoProcessor.from_pretrained(model_id);
+
+// Set up controls
+let size = 256;
+processor.feature_extractor.size = { shortest_edge: size };
+sizeSlider.addEventListener('input', () => {
+ size = Number(sizeSlider.value);
+ processor.feature_extractor.size = { shortest_edge: size };
+ sizeLabel.textContent = size;
+});
+sizeSlider.disabled = false;
+
+let scale = 0.5;
+scaleSlider.addEventListener('input', () => {
+ scale = Number(scaleSlider.value);
+ setStreamSize(video.videoWidth * scale, video.videoHeight * scale);
+ scaleLabel.textContent = scale;
+});
+scaleSlider.disabled = false;
+
+status.textContent = 'Ready';
+
+let isProcessing = false;
+let previousTime;
+const context = canvas.getContext('2d', { willReadFrequently: true });
+const outputContext = outputCanvas.getContext('2d', { willReadFrequently: true });
+function updateCanvas() {
+ const { width, height } = canvas;
+
+ if (!isProcessing) {
+ isProcessing = true;
+ (async function () {
+ // Read the current frame from the video
+ context.drawImage(video, 0, 0, width, height);
+ const currentFrame = context.getImageData(0, 0, width, height);
+ const image = new RawImage(currentFrame.data, width, height, 4);
+
+ // Pre-process image
+ const inputs = await processor(image);
+
+ // Predict alpha matte
+ const { output } = await model({ input: inputs.pixel_values });
+
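+      // The model outputs a single-channel alpha matte; scale it to 0-255, convert to uint8, and resize it back to the frame size.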
+ const mask = await RawImage.fromTensor(output[0].mul(255).to('uint8')).resize(width, height);
+
+ // Update alpha channel
+ const outPixelData = currentFrame;
+ for (let i = 0; i < mask.data.length; ++i) {
+ outPixelData.data[4 * i + 3] = mask.data[i];
+ }
+ outputContext.putImageData(outPixelData, 0, 0);
+
+ if (previousTime !== undefined) {
+ const fps = 1000 / (performance.now() - previousTime);
+ status.textContent = `FPS: ${fps.toFixed(2)}`;
+ }
+ previousTime = performance.now();
+
+ isProcessing = false;
+ })();
+ }
+
+ window.requestAnimationFrame(updateCanvas);
+}
+
+// Start the video stream
+navigator.mediaDevices.getUserMedia(
+ { video: true }, // Ask for video
+).then((stream) => {
+ // Set up the video and canvas elements.
+ video.srcObject = stream;
+ video.play();
+
+ const videoTrack = stream.getVideoTracks()[0];
+ const { width, height } = videoTrack.getSettings();
+
+ setStreamSize(width * scale, height * scale);
+
+ // Set container width and height depending on the image aspect ratio
+ const ar = width / height;
+ const [cw, ch] = (ar > 720 / 405) ? [720, 720 / ar] : [405 * ar, 405];
+ container.style.width = `${cw}px`;
+ container.style.height = `${ch}px`;
+
+ // Start the animation loop
+ setTimeout(updateCanvas, 50);
+}).catch((error) => {
+ alert(error);
+});
diff --git a/examples/webgpu-video-background-removal/package.json b/examples/webgpu-video-background-removal/package.json
new file mode 100644
index 000000000..9ebe47afe
--- /dev/null
+++ b/examples/webgpu-video-background-removal/package.json
@@ -0,0 +1,17 @@
+{
+ "name": "webgpu-video-background-removal",
+ "private": true,
+ "version": "0.0.0",
+ "type": "module",
+ "scripts": {
+ "dev": "vite",
+ "build": "vite build",
+ "preview": "vite preview"
+ },
+ "devDependencies": {
+ "vite": "^5.0.12"
+ },
+ "dependencies": {
+ "@xenova/transformers": "^3.0.0"
+ }
+}
diff --git a/examples/webgpu-video-background-removal/style.css b/examples/webgpu-video-background-removal/style.css
new file mode 100644
index 000000000..a86729e1c
--- /dev/null
+++ b/examples/webgpu-video-background-removal/style.css
@@ -0,0 +1,87 @@
+* {
+ box-sizing: border-box;
+ padding: 0;
+ margin: 0;
+ font-family: sans-serif;
+}
+
+html,
+body {
+ height: 100%;
+}
+
+body {
+ padding: 16px 32px;
+}
+
+body,
+#container {
+ display: flex;
+ flex-direction: column;
+ justify-content: center;
+ align-items: center;
+}
+
+#controls {
+ display: flex;
+ padding: 1rem;
+ gap: 1rem;
+}
+
+#controls>div {
+ text-align: center;
+}
+
+h1,
+h4 {
+ text-align: center;
+}
+
+h4 {
+ margin-top: 0.5rem;
+}
+
+#container {
+ position: relative;
+ width: 720px;
+ height: 405px;
+ max-width: 100%;
+ max-height: 100%;
+ border: 2px dashed #D1D5DB;
+ border-radius: 0.75rem;
+ overflow: hidden;
+ margin-top: 1rem;
+ background-size: 100% 100%;
+ background-position: center;
+ background-repeat: no-repeat;
+}
+
+#overlay,
+canvas {
+ position: absolute;
+ width: 100%;
+ height: 100%;
+}
+
+#status {
+ min-height: 16px;
+ margin: 8px 0;
+}
+
+.bounding-box {
+ position: absolute;
+ box-sizing: border-box;
+ border: solid 2px;
+}
+
+.bounding-box-label {
+ color: white;
+ position: absolute;
+ font-size: 12px;
+ margin: -16px 0 0 -2px;
+ padding: 1px;
+}
+
+#video, #canvas {
+ display: none;
+}
diff --git a/examples/webgpu-video-background-removal/vite.config.js b/examples/webgpu-video-background-removal/vite.config.js
new file mode 100644
index 000000000..6c32f52df
--- /dev/null
+++ b/examples/webgpu-video-background-removal/vite.config.js
@@ -0,0 +1,6 @@
+import { defineConfig } from 'vite';
+export default defineConfig({
+ build: {
+ target: 'esnext'
+ }
+});
diff --git a/examples/webgpu-video-depth-estimation/.gitignore b/examples/webgpu-video-depth-estimation/.gitignore
new file mode 100644
index 000000000..a547bf36d
--- /dev/null
+++ b/examples/webgpu-video-depth-estimation/.gitignore
@@ -0,0 +1,24 @@
+# Logs
+logs
+*.log
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+pnpm-debug.log*
+lerna-debug.log*
+
+node_modules
+dist
+dist-ssr
+*.local
+
+# Editor directories and files
+.vscode/*
+!.vscode/extensions.json
+.idea
+.DS_Store
+*.suo
+*.ntvs*
+*.njsproj
+*.sln
+*.sw?
diff --git a/examples/webgpu-video-depth-estimation/index.html b/examples/webgpu-video-depth-estimation/index.html
new file mode 100644
index 000000000..c05574f67
--- /dev/null
+++ b/examples/webgpu-video-depth-estimation/index.html
@@ -0,0 +1,42 @@
+
+
+
+
+
+
+
+  Transformers.js | Real-time depth estimation
+
+
+
+
+ Real-time depth estimation w/
+ Depth Anything V2
+
+
+ Runs locally in your browser, powered by
+ 🤗 Transformers.js
+
+
+
+
+
+
+
+  Loading model...
+
+
+
+
+
\ No newline at end of file
diff --git a/examples/webgpu-video-depth-estimation/main.js b/examples/webgpu-video-depth-estimation/main.js
new file mode 100644
index 000000000..a745da774
--- /dev/null
+++ b/examples/webgpu-video-depth-estimation/main.js
@@ -0,0 +1,145 @@
+import './style.css';
+
+import { AutoModel, AutoProcessor, RawImage } from '@xenova/transformers';
+
+async function hasFp16() {
+ try {
+ const adapter = await navigator.gpu.requestAdapter()
+ return adapter.features.has('shader-f16')
+ } catch (e) {
+ return false
+ }
+}
+
+// Reference the elements that we will need
+const status = document.getElementById('status');
+const canvas = document.createElement('canvas');
+const outputCanvas = document.getElementById('output-canvas');
+const video = document.getElementById('video');
+const sizeSlider = document.getElementById('size');
+const sizeLabel = document.getElementById('size-value');
+const scaleSlider = document.getElementById('scale');
+const scaleLabel = document.getElementById('scale-value');
+
+function setStreamSize(width, height) {
+ video.width = outputCanvas.width = canvas.width = Math.round(width);
+ video.height = outputCanvas.height = canvas.height = Math.round(height);
+}
+
+status.textContent = 'Loading model...';
+
+// Load model and processor
+const model_id = 'onnx-community/depth-anything-v2-small';
+
+let model;
+try {
+ model = await AutoModel.from_pretrained(model_id, {
+ device: 'webgpu',
+ // Use fp16 if available, otherwise use fp32
+ dtype: (await hasFp16()) ? 'fp16' : 'fp32',
+ });
+} catch (err) {
+ status.textContent = err.message;
+ alert(err.message)
+ throw err;
+}
+
+const processor = await AutoProcessor.from_pretrained(model_id);
+
+// Set up controls
+let size = 504;
+processor.feature_extractor.size = { width: size, height: size };
+sizeSlider.addEventListener('input', () => {
+ size = Number(sizeSlider.value);
+ processor.feature_extractor.size = { width: size, height: size };
+ sizeLabel.textContent = size;
+});
+sizeSlider.disabled = false;
+
+let scale = 0.4;
+scaleSlider.addEventListener('input', () => {
+ scale = Number(scaleSlider.value);
+ setStreamSize(video.videoWidth * scale, video.videoHeight * scale);
+ scaleLabel.textContent = scale;
+});
+scaleSlider.disabled = false;
+
+status.textContent = 'Ready';
+
+let isProcessing = false;
+let previousTime;
+const context = canvas.getContext('2d', { willReadFrequently: true });
+const outputContext = outputCanvas.getContext('2d', { willReadFrequently: true });
+function updateCanvas() {
+ const { width, height } = canvas;
+
+ if (!isProcessing) {
+ isProcessing = true;
+ (async function () {
+ // Read the current frame from the video
+ context.drawImage(video, 0, 0, width, height);
+ const currentFrame = context.getImageData(0, 0, width, height);
+ const image = new RawImage(currentFrame.data, width, height, 4);
+
+ // Pre-process image
+ const inputs = await processor(image);
+
+ // Predict depth map
+ const { predicted_depth } = await model(inputs);
+ const data = predicted_depth.data;
+ const [bs, oh, ow] = predicted_depth.dims;
+
+ // Normalize the depth map
+ let min = Infinity;
+ let max = -Infinity;
+ outputCanvas.width = ow;
+ outputCanvas.height = oh;
+ for (let i = 0; i < data.length; ++i) {
+ const v = data[i];
+ if (v < min) min = v;
+ if (v > max) max = v;
+ }
+ const range = max - min;
+
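+            // Build an RGBA buffer where red is constant and the alpha channel encodes the inverted, normalized depth value.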
+ const imageData = new Uint8ClampedArray(4 * data.length);
+ for (let i = 0; i < data.length; ++i) {
+ const offset = 4 * i;
+ imageData[offset] = 255; // Set base color to red
+
+ // Set alpha to normalized depth value
+ imageData[offset + 3] = 255 * (1 - (data[i] - min) / range);
+ }
+ const outPixelData = new ImageData(imageData, ow, oh);
+ outputContext.putImageData(outPixelData, 0, 0);
+
+ if (previousTime !== undefined) {
+ const fps = 1000 / (performance.now() - previousTime);
+ status.textContent = `FPS: ${fps.toFixed(2)}`;
+ }
+ previousTime = performance.now();
+
+ isProcessing = false;
+ })();
+ }
+
+ window.requestAnimationFrame(updateCanvas);
+}
+
+// Start the video stream
+navigator.mediaDevices.getUserMedia(
+ { video: { width: 720, height: 720 } }, // Ask for square video
+).then((stream) => {
+ // Set up the video and canvas elements.
+ video.srcObject = stream;
+ video.play();
+
+ const videoTrack = stream.getVideoTracks()[0];
+ const { width, height } = videoTrack.getSettings();
+
+ setStreamSize(width * scale, height * scale);
+
+ // Start the animation loop
+ setTimeout(updateCanvas, 50);
+}).catch((error) => {
+ alert(error);
+});
diff --git a/examples/webgpu-video-depth-estimation/package.json b/examples/webgpu-video-depth-estimation/package.json
new file mode 100644
index 000000000..041dd86e0
--- /dev/null
+++ b/examples/webgpu-video-depth-estimation/package.json
@@ -0,0 +1,17 @@
+{
+ "name": "webgpu-video-depth-estimation",
+ "private": true,
+ "version": "0.0.0",
+ "type": "module",
+ "scripts": {
+ "dev": "vite",
+ "build": "vite build",
+ "preview": "vite preview"
+ },
+ "devDependencies": {
+ "vite": "^5.2.0"
+ },
+ "dependencies": {
+ "@xenova/transformers": "github:xenova/transformers.js#v3"
+ }
+}
diff --git a/examples/webgpu-video-depth-estimation/style.css b/examples/webgpu-video-depth-estimation/style.css
new file mode 100644
index 000000000..bd4796b95
--- /dev/null
+++ b/examples/webgpu-video-depth-estimation/style.css
@@ -0,0 +1,71 @@
+* {
+ box-sizing: border-box;
+ padding: 0;
+ margin: 0;
+ font-family: sans-serif;
+}
+
+html,
+body {
+ height: 100%;
+}
+
+body {
+ padding: 16px 32px;
+}
+
+body,
+#container {
+ display: flex;
+ flex-direction: column;
+ justify-content: center;
+ align-items: center;
+}
+
+#controls {
+ display: flex;
+ padding: 1rem;
+ gap: 1rem;
+}
+
+#controls>div {
+ text-align: center;
+}
+
+h1,
+h3 {
+ text-align: center;
+}
+
+h3 {
+ margin-top: 0.5rem;
+}
+
+#container {
+ display: flex;
+ flex-direction: row;
+ position: relative;
+ max-width: 100%;
+ max-height: 100%;
+ border: 2px dashed #D1D5DB;
+ border-radius: 0.75rem;
+ overflow: hidden;
+ margin-top: 1rem;
+ background-size: 100% 100%;
+ background-position: center;
+ background-repeat: no-repeat;
+}
+#video, #output-canvas {
+ width: 504px;
+ height: 504px;
+}
+
+canvas {
+ width: 100%;
+ height: 100%;
+}
+
+#status {
+ min-height: 16px;
+ margin: 8px 0;
+}
diff --git a/examples/webgpu-video-depth-estimation/vite.config.js b/examples/webgpu-video-depth-estimation/vite.config.js
new file mode 100644
index 000000000..6c32f52df
--- /dev/null
+++ b/examples/webgpu-video-depth-estimation/vite.config.js
@@ -0,0 +1,6 @@
+import { defineConfig } from 'vite';
+export default defineConfig({
+ build: {
+ target: 'esnext'
+ }
+});
diff --git a/examples/webgpu-vlm/.eslintrc.cjs b/examples/webgpu-vlm/.eslintrc.cjs
new file mode 100644
index 000000000..ce8fffe57
--- /dev/null
+++ b/examples/webgpu-vlm/.eslintrc.cjs
@@ -0,0 +1,21 @@
+module.exports = {
+ root: true,
+ env: { browser: true, es2020: true },
+ extends: [
+ 'eslint:recommended',
+ 'plugin:react/recommended',
+ 'plugin:react/jsx-runtime',
+ 'plugin:react-hooks/recommended',
+ ],
+ ignorePatterns: ['dist', '.eslintrc.cjs'],
+ parserOptions: { ecmaVersion: 'latest', sourceType: 'module' },
+ settings: { react: { version: '18.2' } },
+ plugins: ['react-refresh'],
+ rules: {
+ 'react-refresh/only-export-components': [
+ 'warn',
+ { allowConstantExport: true },
+ ],
+ 'react/prop-types': 'off'
+ },
+}
diff --git a/examples/webgpu-vlm/.gitignore b/examples/webgpu-vlm/.gitignore
new file mode 100644
index 000000000..a547bf36d
--- /dev/null
+++ b/examples/webgpu-vlm/.gitignore
@@ -0,0 +1,24 @@
+# Logs
+logs
+*.log
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+pnpm-debug.log*
+lerna-debug.log*
+
+node_modules
+dist
+dist-ssr
+*.local
+
+# Editor directories and files
+.vscode/*
+!.vscode/extensions.json
+.idea
+.DS_Store
+*.suo
+*.ntvs*
+*.njsproj
+*.sln
+*.sw?
diff --git a/examples/webgpu-vlm/README.md b/examples/webgpu-vlm/README.md
new file mode 100644
index 000000000..f768e33fc
--- /dev/null
+++ b/examples/webgpu-vlm/README.md
@@ -0,0 +1,8 @@
+# React + Vite
+
+This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.
+
+Currently, two official plugins are available:
+
+- [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react/README.md) uses [Babel](https://babeljs.io/) for Fast Refresh
+- [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh
diff --git a/examples/webgpu-vlm/index.html b/examples/webgpu-vlm/index.html
new file mode 100644
index 000000000..4ed94aa49
--- /dev/null
+++ b/examples/webgpu-vlm/index.html
@@ -0,0 +1,13 @@
+
+
+
+
+
+
+
+  Moondream WebGPU
+
+
+
+
+
+
diff --git a/examples/webgpu-vlm/package.json b/examples/webgpu-vlm/package.json
new file mode 100644
index 000000000..34e6e95e6
--- /dev/null
+++ b/examples/webgpu-vlm/package.json
@@ -0,0 +1,32 @@
+{
+ "name": "webgpu-chat",
+ "private": true,
+ "version": "0.0.0",
+ "type": "module",
+ "scripts": {
+ "dev": "vite",
+ "build": "vite build",
+ "lint": "eslint . --ext js,jsx --report-unused-disable-directives --max-warnings 0",
+ "preview": "vite preview"
+ },
+ "dependencies": {
+ "@xenova/transformers": "github:xenova/transformers.js#v3",
+ "dompurify": "^3.1.2",
+ "marked": "^12.0.2",
+ "react": "^18.2.0",
+ "react-dom": "^18.2.0"
+ },
+ "devDependencies": {
+ "@types/react": "^18.2.43",
+ "@types/react-dom": "^18.2.17",
+ "@vitejs/plugin-react": "^4.2.1",
+ "autoprefixer": "^10.4.19",
+ "eslint": "^8.55.0",
+ "eslint-plugin-react": "^7.33.2",
+ "eslint-plugin-react-hooks": "^4.6.0",
+ "eslint-plugin-react-refresh": "^0.4.5",
+ "postcss": "^8.4.38",
+ "tailwindcss": "^3.4.3",
+ "vite": "^5.2.11"
+ }
+}
diff --git a/examples/webgpu-vlm/postcss.config.js b/examples/webgpu-vlm/postcss.config.js
new file mode 100644
index 000000000..2e7af2b7f
--- /dev/null
+++ b/examples/webgpu-vlm/postcss.config.js
@@ -0,0 +1,6 @@
+export default {
+ plugins: {
+ tailwindcss: {},
+ autoprefixer: {},
+ },
+}
diff --git a/examples/webgpu-vlm/public/logo.png b/examples/webgpu-vlm/public/logo.png
new file mode 100644
index 000000000..ee800de7f
Binary files /dev/null and b/examples/webgpu-vlm/public/logo.png differ
diff --git a/examples/webgpu-vlm/src/App.jsx b/examples/webgpu-vlm/src/App.jsx
new file mode 100644
index 000000000..08cf1797b
--- /dev/null
+++ b/examples/webgpu-vlm/src/App.jsx
@@ -0,0 +1,318 @@
+import { useEffect, useState, useRef } from 'react';
+
+import Chat from './components/Chat';
+import ArrowRightIcon from './components/icons/ArrowRightIcon';
+import StopIcon from './components/icons/StopIcon';
+import Progress from './components/Progress';
+import ImageIcon from './components/icons/ImageIcon';
+import ImagePreview from './components/ImagePreview';
+
+const IS_WEBGPU_AVAILABLE = !!navigator.gpu;
+const STICKY_SCROLL_THRESHOLD = 120;
+
+function App() {
+
+ // Create a reference to the worker object.
+ const worker = useRef(null);
+
+ const textareaRef = useRef(null);
+ const chatContainerRef = useRef(null);
+ const imageRef = useRef(null);
+ const imageUploadRef = useRef(null);
+
+ // Model loading and progress
+ const [status, setStatus] = useState(null);
+ const [loadingMessage, setLoadingMessage] = useState('');
+ const [progressItems, setProgressItems] = useState([]);
+ const [isRunning, setIsRunning] = useState(false);
+
+ // Inputs and outputs
+ const [input, setInput] = useState('');
+ const [image, setImage] = useState(null);
+ const [messages, setMessages] = useState([]);
+ const [tps, setTps] = useState(null);
+ const [numTokens, setNumTokens] = useState(null);
+
+ function onEnter(message, image = null) {
+ setMessages(prev => [
+ ...prev,
+ { role: "user", content: message, image },
+ ]);
+ setTps(null);
+ setIsRunning(true);
+ setInput('');
+ setImage(null);
+ }
+
+ useEffect(() => {
+ resizeInput();
+ }, [input]);
+
+ function onInterrupt() {
+ // NOTE: We do not set isRunning to false here because the worker
+ // will send a 'complete' message when it is done.
+ worker.current.postMessage({ type: 'interrupt' });
+ }
+
+ function resizeInput() {
+ if (!textareaRef.current) return;
+
+ const target = textareaRef.current;
+ target.style.height = 'auto';
+ const newHeight = Math.min(Math.max(target.scrollHeight, 24), 200);
+ target.style.height = `${newHeight}px`;
+ }
+
+ // We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted.
+ useEffect(() => {
+ if (!worker.current) {
+ // Create the worker if it does not yet exist.
+ worker.current = new Worker(new URL('./worker.js', import.meta.url), {
+ type: 'module'
+ });
+ }
+
+ // Create a callback function for messages from the worker thread.
+ const onMessageReceived = (e) => {
+ switch (e.data.status) {
+ case 'loading':
+ // Model file start load: add a new progress item to the list.
+ setStatus('loading');
+ setLoadingMessage(e.data.data);
+ break;
+
+ case 'initiate':
+ setProgressItems(prev => [...prev, e.data]);
+ break;
+
+ case 'progress':
+ // Model file progress: update one of the progress items.
+ setProgressItems(
+ prev => prev.map(item => {
+ if (item.file === e.data.file) {
+ return { ...item, ...e.data }
+ }
+ return item;
+ })
+ );
+ break;
+
+ case 'done':
+ // Model file loaded: remove the progress item from the list.
+ setProgressItems(
+ prev => prev.filter(item => item.file !== e.data.file)
+ );
+ break;
+
+ case 'ready':
+ // Pipeline ready: the worker is ready to accept messages.
+ setStatus('ready');
+ break;
+
+ case 'start': {
+ // Start generation
+ setMessages(prev => [...prev, { "role": "assistant", "content": "" }]);
+ }
+ break;
+
+ case 'update': {
+ // Generation update: update the output text.
+ // Parse messages
+ const { output, tps, numTokens } = e.data;
+ setTps(tps);
+ setNumTokens(numTokens)
+ setMessages(prev => {
+ const cloned = [...prev];
+ const last = cloned.at(-1);
+ cloned[cloned.length - 1] = { ...last, content: last.content + output };
+ return cloned;
+ });
+ }
+ break;
+
+ case 'complete':
+ // Generation complete: re-enable the "Generate" button
+ setIsRunning(false);
+ break;
+ }
+ };
+
+ // Attach the callback function as an event listener.
+ worker.current.addEventListener('message', onMessageReceived);
+
+ // Define a cleanup function for when the component is unmounted.
+ return () => {
+ worker.current.removeEventListener('message', onMessageReceived);
+ };
+ }, []);
+
+ // Send the messages to the worker thread whenever the `messages` state changes.
+ useEffect(() => {
+ if (messages.filter(x => x.role === 'user').length === 0) {
+ // No user messages yet: do nothing.
+ return;
+ }
+ if (messages.at(-1).role === 'assistant') {
+ // Do not update if the last message is from the assistant
+ return;
+ }
+ setTps(null);
+ worker.current.postMessage({ type: 'generate', data: messages });
+ }, [messages, isRunning]);
+
+ useEffect(() => {
+ if (!chatContainerRef.current) return;
+ if (isRunning) {
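+      // Only auto-scroll to the bottom if the user is already near it, so scrolling up to read earlier messages is not overridden.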
+ const element = chatContainerRef.current;
+ if (element.scrollHeight - element.scrollTop - element.clientHeight < STICKY_SCROLL_THRESHOLD) {
+ element.scrollTop = element.scrollHeight;
+ }
+ }
+ }, [messages, isRunning]);
+
+ return (
+ IS_WEBGPU_AVAILABLE
+ ? (
+
+ {status === null && messages.length === 0 && (
+
+
+
+
+  Moondream WebGPU
+
+  A private and powerful multimodal AI chatbot that runs locally in your browser.
+
+
+
+
+
+  You are about to load moondream2,
+ a 1.86 billion parameter VLM (Vision-Language Model) that is optimized for inference on the web. Once downloaded, the model (1.8 GB) will be cached and reused when you revisit the page.
+
+ Everything runs directly in your browser using 🤗 Transformers.js and ONNX Runtime Web, meaning your conversations aren't sent to a server. You can even disconnect from the internet after the model has loaded!
+
+
+
{
+ worker.current.postMessage({ type: 'load' });
+ setStatus('loading');
+ }}
+ disabled={status !== null}
+ >
+ Load model
+
+
+
+ )}
+ {status === 'loading' && (<>
+
+
{loadingMessage}
+ {progressItems.map(({ file, progress, total }, i) => (
+
+ ))}
+
+ >)}
+
+ {status === 'ready' && (
+
+
+ {tps && messages.length > 0 && (<>
+ {!isRunning &&
+ Generated {numTokens} tokens in {(numTokens / tps).toFixed(2)} seconds ( }
+ {<>
+
+ {tps.toFixed(2)}
+
+ tokens/second
+ >}
+ {!isRunning && <>
+ ).
+ setMessages([])}>Reset
+ >}
+ >)}
+
+
)}
+
+
+
+
+ {
+ const file = e.target.files[0];
+ if (!file) {
+ return;
+ }
+
+ const reader = new FileReader();
+
+ // Set up a callback when the file is loaded
+ reader.onload = e2 => {
+ setImage(e2.target.result);
+ e.target.value = '';
+ };
+
+ reader.readAsDataURL(file);
+ }}>
+
+
+ {image && (
+ {
+ setImage(null);
+ }} src={image} className="w-20 h-20 min-w-20 min-h-20 relative p-2" />
+ )}
+
+
+
+ {isRunning
+ ? (
+
+
)
+ : input.length > 0
+ ? (
onEnter(input, image)}>
+
+
)
+ : (
)
+ }
+
+
+
+ Disclaimer: Generated content may be inaccurate or false.
+
+
)
+ : (
WebGPU is not supported by this browser :(
)
+ )
+}
+
+export default App
diff --git a/examples/webgpu-vlm/src/components/Chat.css b/examples/webgpu-vlm/src/components/Chat.css
new file mode 100644
index 000000000..f8ab98d4b
--- /dev/null
+++ b/examples/webgpu-vlm/src/components/Chat.css
@@ -0,0 +1,112 @@
+@scope (.markdown) {
+
+ /* Code blocks */
+ pre {
+ margin: 0.5rem 0;
+ white-space: break-spaces;
+ }
+
+ code {
+ padding: 0.2em 0.4em;
+ border-radius: 4px;
+ font-family: Consolas, Monaco, 'Andale Mono', 'Ubuntu Mono', monospace;
+ font-size: 0.9em;
+ }
+
+ pre,
+ code {
+ background-color: #f2f2f2;
+ }
+
+ @media (prefers-color-scheme: dark) {
+
+ pre,
+ code {
+ background-color: #333;
+ }
+
+ }
+
+ pre:has(code) {
+ padding: 1rem 0.5rem;
+ }
+
+ pre>code {
+ padding: 0;
+ }
+
+ /* Headings */
+ h1,
+ h2,
+ h3,
+ h4,
+ h5,
+ h6 {
+ font-weight: 600;
+ line-height: 1.2;
+ }
+
+ h1 {
+ font-size: 2em;
+ margin: 1rem 0;
+ }
+
+ h2 {
+ font-size: 1.5em;
+ margin: 0.83rem 0;
+ }
+
+ h3 {
+ font-size: 1.25em;
+ margin: 0.67rem 0;
+ }
+
+ h4 {
+ font-size: 1em;
+ margin: 0.5rem 0;
+ }
+
+ h5 {
+ font-size: 0.875em;
+ margin: 0.33rem 0;
+ }
+
+ h6 {
+ font-size: 0.75em;
+ margin: 0.25rem 0;
+ }
+
+ h1,
+ h2,
+ h3,
+ h4,
+ h5,
+ h6:first-child {
+ margin-top: 0;
+ }
+
+ /* Unordered List */
+ ul {
+ list-style-type: disc;
+ margin-left: 1.5rem;
+ }
+
+ /* Ordered List */
+ ol {
+ list-style-type: decimal;
+ margin-left: 1.5rem;
+ }
+
+ /* List Items */
+ li {
+ margin: 0.25rem 0;
+ }
+
+ p:not(:first-child) {
+ margin-top: 0.75rem;
+ }
+
+ p:not(:last-child) {
+ margin-bottom: 0.75rem;
+ }
+}
\ No newline at end of file
diff --git a/examples/webgpu-vlm/src/components/Chat.jsx b/examples/webgpu-vlm/src/components/Chat.jsx
new file mode 100644
index 000000000..49516896c
--- /dev/null
+++ b/examples/webgpu-vlm/src/components/Chat.jsx
@@ -0,0 +1,43 @@
+import { marked } from 'marked';
+import DOMPurify from 'dompurify';
+
+import BotIcon from './icons/BotIcon';
+import UserIcon from './icons/UserIcon';
+
+import './Chat.css';
+
+export default function Chat({ messages }) {
+ const empty = messages.length === 0;
+
+ return (
+ {empty
+ ?
Ready!
+ : messages.map((msg, i) => (
+
+ {msg.role === 'assistant'
+ ? (<>
+
+
+
{
+ msg.content.length > 0
+ ?
+ : (
+
+
+
+ )
+ }
+
+ >
+ ) : (<>
+
+
+ {msg.image &&
}
+
{msg.content}
+
+ >)
+ }
+
+ ))}
+
)
+}
diff --git a/examples/webgpu-vlm/src/components/ImagePreview.jsx b/examples/webgpu-vlm/src/components/ImagePreview.jsx
new file mode 100644
index 000000000..9e5ccc0c9
--- /dev/null
+++ b/examples/webgpu-vlm/src/components/ImagePreview.jsx
@@ -0,0 +1,16 @@
+import { useState } from "react";
+import CrossIcon from "./icons/CrossIcon"
+
+export default function ImagePreview({ src, onRemove, ...props }) {
+ const [hover, setHover] = useState(false);
+
+ return (
+
setHover(true)}
+ onMouseLeave={() => setHover(false)}
+ >
+
+
+
)
+}
diff --git a/examples/webgpu-vlm/src/components/Progress.jsx b/examples/webgpu-vlm/src/components/Progress.jsx
new file mode 100644
index 000000000..9ce024cc8
--- /dev/null
+++ b/examples/webgpu-vlm/src/components/Progress.jsx
@@ -0,0 +1,15 @@
+function formatBytes(size) {
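+  // Convert a byte count into a human-readable string (B, kB, MB, ...) using powers of 1024.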
+ const i = size == 0 ? 0 : Math.floor(Math.log(size) / Math.log(1024));
+ return +((size / Math.pow(1024, i)).toFixed(2)) * 1 + ['B', 'kB', 'MB', 'GB', 'TB'][i];
+}
+
+export default function Progress({ text, percentage, total }) {
+ percentage ??= 0;
+ return (
+
+
+ {text} ({percentage.toFixed(2)}%{isNaN(total) ? '' : ` of ${formatBytes(total)}`})
+
+
+ );
+}
diff --git a/examples/webgpu-vlm/src/components/icons/ArrowRightIcon.jsx b/examples/webgpu-vlm/src/components/icons/ArrowRightIcon.jsx
new file mode 100644
index 000000000..0ca5ed917
--- /dev/null
+++ b/examples/webgpu-vlm/src/components/icons/ArrowRightIcon.jsx
@@ -0,0 +1,19 @@
+export default function ArrowRightIcon(props) {
+ return (
+
+
+
+
+ )
+}
\ No newline at end of file
diff --git a/examples/webgpu-vlm/src/components/icons/BotIcon.jsx b/examples/webgpu-vlm/src/components/icons/BotIcon.jsx
new file mode 100644
index 000000000..b8bd0ceae
--- /dev/null
+++ b/examples/webgpu-vlm/src/components/icons/BotIcon.jsx
@@ -0,0 +1,23 @@
+export default function BotIcon(props) {
+ return (
+
+
+
+
+
+
+
+
+ )
+}
\ No newline at end of file
diff --git a/examples/webgpu-vlm/src/components/icons/CrossIcon.jsx b/examples/webgpu-vlm/src/components/icons/CrossIcon.jsx
new file mode 100644
index 000000000..d2e03d480
--- /dev/null
+++ b/examples/webgpu-vlm/src/components/icons/CrossIcon.jsx
@@ -0,0 +1,18 @@
+export default function CrossIcon(props) {
+ return (
+
+
+
+ )
+}
diff --git a/examples/webgpu-vlm/src/components/icons/ImageIcon.jsx b/examples/webgpu-vlm/src/components/icons/ImageIcon.jsx
new file mode 100644
index 000000000..93409108f
--- /dev/null
+++ b/examples/webgpu-vlm/src/components/icons/ImageIcon.jsx
@@ -0,0 +1,19 @@
+export default function ImageIcon(props) {
+ return (
+
+
+
+ )
+}
+
diff --git a/examples/webgpu-vlm/src/components/icons/StopIcon.jsx b/examples/webgpu-vlm/src/components/icons/StopIcon.jsx
new file mode 100644
index 000000000..9b97f3723
--- /dev/null
+++ b/examples/webgpu-vlm/src/components/icons/StopIcon.jsx
@@ -0,0 +1,19 @@
+export default function StopIcon(props) {
+ return (
+
+
+
+
+ )
+}
\ No newline at end of file
diff --git a/examples/webgpu-vlm/src/components/icons/UserIcon.jsx b/examples/webgpu-vlm/src/components/icons/UserIcon.jsx
new file mode 100644
index 000000000..cb09e7574
--- /dev/null
+++ b/examples/webgpu-vlm/src/components/icons/UserIcon.jsx
@@ -0,0 +1,19 @@
+export default function UserIcon(props) {
+ return (
+
+
+
+
+ )
+}
\ No newline at end of file
diff --git a/examples/webgpu-vlm/src/index.css b/examples/webgpu-vlm/src/index.css
new file mode 100644
index 000000000..8848bbd6d
--- /dev/null
+++ b/examples/webgpu-vlm/src/index.css
@@ -0,0 +1,32 @@
+@tailwind base;
+@tailwind components;
+@tailwind utilities;
+
+@layer utilities {
+ .scrollbar-thin::-webkit-scrollbar {
+ @apply w-2;
+ }
+
+ .scrollbar-thin::-webkit-scrollbar-track {
+ @apply rounded-full bg-gray-100 dark:bg-gray-700;
+ }
+
+ .scrollbar-thin::-webkit-scrollbar-thumb {
+ @apply rounded-full bg-gray-300 dark:bg-gray-600;
+ }
+
+ .scrollbar-thin::-webkit-scrollbar-thumb:hover {
+ @apply bg-gray-500;
+ }
+
+ .animation-delay-200 {
+ animation-delay: 200ms;
+ }
+ .animation-delay-400 {
+ animation-delay: 400ms;
+ }
+
+ .overflow-wrap-anywhere {
+ overflow-wrap: anywhere;
+ }
+}
diff --git a/examples/webgpu-vlm/src/main.jsx b/examples/webgpu-vlm/src/main.jsx
new file mode 100644
index 000000000..54b39dd1d
--- /dev/null
+++ b/examples/webgpu-vlm/src/main.jsx
@@ -0,0 +1,10 @@
+import React from 'react'
+import ReactDOM from 'react-dom/client'
+import App from './App.jsx'
+import './index.css'
+
+ReactDOM.createRoot(document.getElementById('root')).render(
+
+
+ ,
+)
diff --git a/examples/webgpu-vlm/src/worker.js b/examples/webgpu-vlm/src/worker.js
new file mode 100644
index 000000000..d145b17a3
--- /dev/null
+++ b/examples/webgpu-vlm/src/worker.js
@@ -0,0 +1,209 @@
+
+import {
+ env,
+ AutoTokenizer,
+ Moondream1ForConditionalGeneration,
+ TextStreamer,
+ StoppingCriteria,
+ RawImage,
+ AutoProcessor,
+ Tensor,
+ full,
+} from '@xenova/transformers';
+
+const DEVICE = 'webgpu';
+const MAX_NEW_TOKENS = 256;
+
+env.backends.onnx.wasm.proxy = DEVICE !== 'webgpu';
+
+async function hasFp16() {
+ try {
+ const adapter = await navigator.gpu.requestAdapter();
+ return adapter.features.has('shader-f16');
+ } catch (e) {
+ return false;
+ }
+}
+/**
+ * This class uses the Singleton pattern to ensure that only one instance of the model is loaded.
+ */
+class TextGenerationPipeline {
+ static model_id = 'Xenova/moondream2';
+ static tokenizer = null;
+ static processor = null;
+ static model = null;
+ static supportsFp16 = null;
+
+ static async getInstance(progress_callback = null) {
+
+ this.tokenizer ??= AutoTokenizer.from_pretrained(this.model_id, {
+ progress_callback,
+ });
+
+ this.processor ??= AutoProcessor.from_pretrained(this.model_id);
+
+ // Choose the model based on whether fp16 is available
+ this.supportsFp16 ??= await hasFp16();
+ this.model ??= Moondream1ForConditionalGeneration.from_pretrained(this.model_id, {
+ dtype: {
+ embed_tokens: this.supportsFp16 ? 'fp16' : 'fp32', // or 'fp32'
+ vision_encoder: this.supportsFp16 ? 'fp16' : 'fp32', // or 'q8'
+ decoder_model_merged: 'q4', // or 'q4f16' or 'q8'
+ },
+ device: DEVICE,
+ progress_callback,
+ });
+
+ return Promise.all([this.tokenizer, this.processor, this.model]);
+ }
+}
+
+
+class CallbackTextStreamer extends TextStreamer {
+ constructor(tokenizer, cb) {
+ super(tokenizer, {
+ skip_prompt: true,
+ skip_special_tokens: true,
+ });
+ this.cb = cb;
+ }
+
+ on_finalized_text(text) {
+ this.cb(text);
+ }
+}
+
+class InterruptableStoppingCriteria extends StoppingCriteria {
+ constructor() {
+ super();
+ this.interrupted = false;
+ }
+
+ interrupt() {
+ this.interrupted = true;
+ }
+
+ reset() {
+ this.interrupted = false;
+ }
+
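+ // Called by the generation loop; returning `true` for a sequence tells the
+ // model to stop generating that sequence.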
+ _call(input_ids, scores) {
+ return new Array(input_ids.length).fill(this.interrupted);
+ }
+}
+
+const stopping_criteria = new InterruptableStoppingCriteria();
+
+async function generate(messages) {
+
+ // Only support a single image for now
+ const images = messages.filter(x => x.image).map(x => x.image);
+ if (images.length > 1) {
+ self.postMessage({
+ status: 'error',
+ error: 'Currently, at most one image is supported.',
+ });
+ return;
+ }
+
+ // Retrieve the text-generation pipeline.
+ const [tokenizer, processor, model] = await TextGenerationPipeline.getInstance();
+
+ // Construct and tokenize prompt
+ const prompt = messages.map(x => `${x.image ? '<image>\n\n' : ''}${x.role === 'user' ? 'Question: ' : 'Answer: '}${x.content.trim()}`).join('\n\n') + '\n\nAnswer:'
+ let inputs = tokenizer(prompt);
+
+ if (images.length > 0) {
+ const image = await RawImage.fromURL(images[0]);
+ const vision_inputs = await processor(image);
+
+ inputs = { ...inputs, ...vision_inputs };
+ }
+
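+ // Streaming callback: record the time of the first token and report
+ // tokens-per-second back to the main thread with every partial output.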
+ let startTime;
+ let numTokens = 0;
+ const cb = (output) => {
+ startTime ??= performance.now();
+
+ let tps;
+ if (numTokens++ > 0) {
+ tps = numTokens / (performance.now() - startTime) * 1000;
+ }
+ self.postMessage({
+ status: 'update',
+ output, tps, numTokens,
+ });
+ }
+
+ const streamer = new CallbackTextStreamer(tokenizer, cb);
+
+ // Tell the main thread we are starting
+ self.postMessage({ status: 'start' });
+
+ const outputs = await model.generate({
+ ...inputs,
+ max_new_tokens: MAX_NEW_TOKENS,
+ streamer,
+ stopping_criteria,
+ });
+ const outputText = tokenizer.batch_decode(outputs, { skip_special_tokens: false });
+
+ // Send the output back to the main thread
+ self.postMessage({
+ status: 'complete',
+ output: outputText,
+ });
+}
+
+async function load() {
+ self.postMessage({
+ status: 'loading',
+ data: 'Loading model...'
+ });
+
+ // Load the pipeline and save it for future use.
+ const [tokenizer, processor, model] = await TextGenerationPipeline.getInstance(x => {
+ // We also add a progress callback to the pipeline so that we can
+ // track model loading.
+ self.postMessage(x);
+ });
+
+ self.postMessage({
+ status: 'loading',
+ data: 'Compiling shaders and warming up model...'
+ });
+
+ // Run model with dummy input to compile shaders
+ const text_inputs = tokenizer('a');
+
+ const vision_inputs = {
+ pixel_values: full([1, 3, 378, 378], 0.0)
+ }
+
+ const inputs = { ...text_inputs, ...vision_inputs };
+ await model.generate({ ...inputs, max_new_tokens: 1 });
+ self.postMessage({ status: 'ready' });
+}
+// Listen for messages from the main thread
+self.addEventListener('message', async (e) => {
+ const { type, data } = e.data;
+
+ switch (type) {
+ case 'load':
+ load();
+ break;
+
+ case 'generate':
+ stopping_criteria.reset();
+ generate(data);
+ break;
+
+ case 'interrupt':
+ stopping_criteria.interrupt();
+ break;
+
+ case 'reset':
+ stopping_criteria.reset();
+ break;
+ }
+});
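+
+// Illustrative sketch of how a main thread can drive this worker
+// (`messages` is a placeholder for an array of { role, content, image? } objects):
+//   const worker = new Worker(new URL('./worker.js', import.meta.url), { type: 'module' });
+//   worker.postMessage({ type: 'load' });
+//   worker.postMessage({ type: 'generate', data: messages });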
diff --git a/examples/webgpu-vlm/tailwind.config.js b/examples/webgpu-vlm/tailwind.config.js
new file mode 100644
index 000000000..d37737fc0
--- /dev/null
+++ b/examples/webgpu-vlm/tailwind.config.js
@@ -0,0 +1,12 @@
+/** @type {import('tailwindcss').Config} */
+export default {
+ content: [
+ "./index.html",
+ "./src/**/*.{js,ts,jsx,tsx}",
+ ],
+ theme: {
+ extend: {},
+ },
+ plugins: [],
+}
+
diff --git a/examples/webgpu-vlm/vite.config.js b/examples/webgpu-vlm/vite.config.js
new file mode 100644
index 000000000..5a33944a9
--- /dev/null
+++ b/examples/webgpu-vlm/vite.config.js
@@ -0,0 +1,7 @@
+import { defineConfig } from 'vite'
+import react from '@vitejs/plugin-react'
+
+// https://vitejs.dev/config/
+export default defineConfig({
+ plugins: [react()],
+})
diff --git a/examples/webgpu-whisper/.eslintrc.cjs b/examples/webgpu-whisper/.eslintrc.cjs
new file mode 100644
index 000000000..ce8fffe57
--- /dev/null
+++ b/examples/webgpu-whisper/.eslintrc.cjs
@@ -0,0 +1,21 @@
+module.exports = {
+ root: true,
+ env: { browser: true, es2020: true },
+ extends: [
+ 'eslint:recommended',
+ 'plugin:react/recommended',
+ 'plugin:react/jsx-runtime',
+ 'plugin:react-hooks/recommended',
+ ],
+ ignorePatterns: ['dist', '.eslintrc.cjs'],
+ parserOptions: { ecmaVersion: 'latest', sourceType: 'module' },
+ settings: { react: { version: '18.2' } },
+ plugins: ['react-refresh'],
+ rules: {
+ 'react-refresh/only-export-components': [
+ 'warn',
+ { allowConstantExport: true },
+ ],
+ 'react/prop-types': 'off'
+ },
+}
diff --git a/examples/webgpu-whisper/.gitignore b/examples/webgpu-whisper/.gitignore
new file mode 100644
index 000000000..a547bf36d
--- /dev/null
+++ b/examples/webgpu-whisper/.gitignore
@@ -0,0 +1,24 @@
+# Logs
+logs
+*.log
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+pnpm-debug.log*
+lerna-debug.log*
+
+node_modules
+dist
+dist-ssr
+*.local
+
+# Editor directories and files
+.vscode/*
+!.vscode/extensions.json
+.idea
+.DS_Store
+*.suo
+*.ntvs*
+*.njsproj
+*.sln
+*.sw?
diff --git a/examples/webgpu-whisper/README.md b/examples/webgpu-whisper/README.md
new file mode 100644
index 000000000..f768e33fc
--- /dev/null
+++ b/examples/webgpu-whisper/README.md
@@ -0,0 +1,8 @@
+# React + Vite
+
+This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.
+
+Currently, two official plugins are available:
+
+- [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react/README.md) uses [Babel](https://babeljs.io/) for Fast Refresh
+- [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh
diff --git a/examples/webgpu-whisper/index.html b/examples/webgpu-whisper/index.html
new file mode 100644
index 000000000..da24b23cb
--- /dev/null
+++ b/examples/webgpu-whisper/index.html
@@ -0,0 +1,13 @@
+
+
+
+
+
+
+ Whisper WebGPU
+
+
+
+
+
+
diff --git a/examples/webgpu-whisper/package.json b/examples/webgpu-whisper/package.json
new file mode 100644
index 000000000..325990590
--- /dev/null
+++ b/examples/webgpu-whisper/package.json
@@ -0,0 +1,30 @@
+{
+ "name": "webgpu-whisper",
+ "private": true,
+ "version": "0.0.0",
+ "type": "module",
+ "scripts": {
+ "dev": "vite",
+ "build": "vite build",
+ "lint": "eslint . --ext js,jsx --report-unused-disable-directives --max-warnings 0",
+ "preview": "vite preview"
+ },
+ "dependencies": {
+ "@huggingface/transformers": "^3.0.0-alpha.18",
+ "react": "^18.2.0",
+ "react-dom": "^18.2.0"
+ },
+ "devDependencies": {
+ "@types/react": "^18.2.43",
+ "@types/react-dom": "^18.2.17",
+ "@vitejs/plugin-react": "^4.2.1",
+ "autoprefixer": "^10.4.19",
+ "eslint": "^8.55.0",
+ "eslint-plugin-react": "^7.33.2",
+ "eslint-plugin-react-hooks": "^4.6.0",
+ "eslint-plugin-react-refresh": "^0.4.5",
+ "postcss": "^8.4.38",
+ "tailwindcss": "^3.4.3",
+ "vite": "^5.2.11"
+ }
+}
diff --git a/examples/webgpu-whisper/postcss.config.js b/examples/webgpu-whisper/postcss.config.js
new file mode 100644
index 000000000..2e7af2b7f
--- /dev/null
+++ b/examples/webgpu-whisper/postcss.config.js
@@ -0,0 +1,6 @@
+export default {
+ plugins: {
+ tailwindcss: {},
+ autoprefixer: {},
+ },
+}
diff --git a/examples/webgpu-whisper/public/banner.png b/examples/webgpu-whisper/public/banner.png
new file mode 100644
index 000000000..b9b0e75f4
Binary files /dev/null and b/examples/webgpu-whisper/public/banner.png differ
diff --git a/examples/webgpu-whisper/public/logo.png b/examples/webgpu-whisper/public/logo.png
new file mode 100644
index 000000000..fc3b13f6b
Binary files /dev/null and b/examples/webgpu-whisper/public/logo.png differ
diff --git a/examples/webgpu-whisper/src/App.jsx b/examples/webgpu-whisper/src/App.jsx
new file mode 100644
index 000000000..5f74ecba8
--- /dev/null
+++ b/examples/webgpu-whisper/src/App.jsx
@@ -0,0 +1,257 @@
+import { useEffect, useState, useRef } from 'react';
+
+import { AudioVisualizer } from './components/AudioVisualizer';
+import Progress from './components/Progress';
+import { LanguageSelector } from './components/LanguageSelector';
+
+const IS_WEBGPU_AVAILABLE = !!navigator.gpu;
+
+const WHISPER_SAMPLING_RATE = 16_000;
+const MAX_AUDIO_LENGTH = 30; // seconds
+const MAX_SAMPLES = WHISPER_SAMPLING_RATE * MAX_AUDIO_LENGTH;
+
+function App() {
+
+ // Create a reference to the worker object.
+ const worker = useRef(null);
+
+ const recorderRef = useRef(null);
+
+ // Model loading and progress
+ const [status, setStatus] = useState(null);
+ const [loadingMessage, setLoadingMessage] = useState('');
+ const [progressItems, setProgressItems] = useState([]);
+
+ // Inputs and outputs
+ const [text, setText] = useState('');
+ const [tps, setTps] = useState(null);
+ const [language, setLanguage] = useState('en');
+
+ // Processing
+ const [recording, setRecording] = useState(false);
+ const [isProcessing, setIsProcessing] = useState(false);
+ const [chunks, setChunks] = useState([]);
+ const [stream, setStream] = useState(null);
+ const audioContextRef = useRef(null);
+
+ // We use the `useEffect` hook to set up the worker as soon as the `App` component is mounted.
+ useEffect(() => {
+ if (!worker.current) {
+ // Create the worker if it does not yet exist.
+ worker.current = new Worker(new URL('./worker.js', import.meta.url), {
+ type: 'module'
+ });
+ }
+
+ // Create a callback function for messages from the worker thread.
+ const onMessageReceived = (e) => {
+ switch (e.data.status) {
+ case 'loading':
+ // Model file start load: add a new progress item to the list.
+ setStatus('loading');
+ setLoadingMessage(e.data.data);
+ break;
+
+ case 'initiate':
+ setProgressItems(prev => [...prev, e.data]);
+ break;
+
+ case 'progress':
+ // Model file progress: update one of the progress items.
+ setProgressItems(
+ prev => prev.map(item => {
+ if (item.file === e.data.file) {
+ return { ...item, ...e.data }
+ }
+ return item;
+ })
+ );
+ break;
+
+ case 'done':
+ // Model file loaded: remove the progress item from the list.
+ setProgressItems(
+ prev => prev.filter(item => item.file !== e.data.file)
+ );
+ break;
+
+ case 'ready':
+ // Pipeline ready: the worker is ready to accept messages.
+ setStatus('ready');
+ recorderRef.current?.start();
+ break;
+
+ case 'start': {
+ // Start generation
+ setIsProcessing(true);
+
+ // Request new data from the recorder
+ recorderRef.current?.requestData();
+ }
+ break;
+
+ case 'update': {
+ // Generation update: update the output text.
+ const { tps } = e.data;
+ setTps(tps);
+ }
+ break;
+
+ case 'complete':
+ // Generation complete: update the output text and allow new audio to be processed
+ setIsProcessing(false);
+ setText(e.data.output);
+ break;
+ }
+ };
+
+ // Attach the callback function as an event listener.
+ worker.current.addEventListener('message', onMessageReceived);
+
+ // Define a cleanup function for when the component is unmounted.
+ return () => {
+ worker.current.removeEventListener('message', onMessageReceived);
+ };
+ }, []);
+
+ useEffect(() => {
+ if (recorderRef.current) return; // Already set
+
+ if (navigator.mediaDevices.getUserMedia) {
+ navigator.mediaDevices.getUserMedia({ audio: true })
+ .then(stream => {
+ setStream(stream);
+
+ recorderRef.current = new MediaRecorder(stream);
+ audioContextRef.current = new AudioContext({ sampleRate: WHISPER_SAMPLING_RATE });
+
+ recorderRef.current.onstart = () => {
+ setRecording(true);
+ setChunks([]);
+ }
+ recorderRef.current.ondataavailable = (e) => {
+ if (e.data.size > 0) {
+ setChunks((prev) => [...prev, e.data]);
+ } else {
+ // Empty chunk received, so we request new data after a short timeout
+ setTimeout(() => {
+ recorderRef.current.requestData();
+ }, 25);
+ }
+ };
+
+ recorderRef.current.onstop = () => {
+ setRecording(false);
+ };
+
+ })
+ .catch(err => console.error("The following error occurred: ", err));
+ } else {
+ console.error("getUserMedia not supported on your browser!");
+ }
+
+ return () => {
+ recorderRef.current?.stop();
+ recorderRef.current = null;
+ };
+ }, []);
+
+ useEffect(() => {
+ if (!recorderRef.current) return;
+ if (!recording) return;
+ if (isProcessing) return;
+ if (status !== 'ready') return;
+
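+ // Decode the recorded chunks, keep only the most recent MAX_AUDIO_LENGTH seconds
+ // of samples, and send them to the worker for transcription.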
+ if (chunks.length > 0) {
+ // Generate from data
+ const blob = new Blob(chunks, { type: recorderRef.current.mimeType });
+
+ const fileReader = new FileReader();
+
+ fileReader.onloadend = async () => {
+ const arrayBuffer = fileReader.result;
+ const decoded = await audioContextRef.current.decodeAudioData(arrayBuffer);
+ let audio = decoded.getChannelData(0);
+ if (audio.length > MAX_SAMPLES) { // Get last MAX_SAMPLES
+ audio = audio.slice(-MAX_SAMPLES);
+ }
+
+ worker.current.postMessage({ type: 'generate', data: { audio, language } });
+ }
+ fileReader.readAsArrayBuffer(blob);
+ } else {
+ recorderRef.current?.requestData();
+ }
+ }, [status, recording, isProcessing, chunks, language]);
+
+ return (
+ IS_WEBGPU_AVAILABLE
+ ? (
+ {(
+
+
+
+
Whisper WebGPU
+
Real-time in-browser speech recognition
+
+
+
+ {status === null && (<>
+
+
+ You are about to load whisper-base ,
+ a 73 million parameter speech recognition model that is optimized for inference on the web. Once downloaded, the model (~200 MB) will be cached and reused when you revisit the page.
+
+ Everything runs directly in your browser using 🤗 Transformers.js and ONNX Runtime Web,
+ meaning no data is sent to a server. You can even disconnect from the internet after the model has loaded!
+
+
+
{
+ worker.current.postMessage({ type: 'load' });
+ setStatus('loading');
+ }}
+ disabled={status !== null}
+ >
+ Load model
+
+ >)}
+
+
+
+ {status === 'ready' &&
+
{text}
+ {tps &&
{tps.toFixed(2)} tok/s }
+
}
+
+
+ {status === 'ready' &&
+ {
+ recorderRef.current?.stop();
+ setLanguage(e);
+ recorderRef.current?.start();
+ }} />
+ {
+ recorderRef.current?.stop();
+ recorderRef.current?.start();
+ }}>Reset
+
+ }
+ {status === 'loading' && (
+
+
{loadingMessage}
+ {progressItems.map(({ file, progress, total }, i) => (
+
+ ))}
+
+ )}
+
+
+ )}
+
)
+ : (WebGPU is not supported by this browser :(
)
+ )
+}
+
+export default App
diff --git a/examples/webgpu-whisper/src/components/AudioVisualizer.jsx b/examples/webgpu-whisper/src/components/AudioVisualizer.jsx
new file mode 100644
index 000000000..5935a3e76
--- /dev/null
+++ b/examples/webgpu-whisper/src/components/AudioVisualizer.jsx
@@ -0,0 +1,58 @@
+import { useRef, useCallback, useEffect } from "react";
+
+export function AudioVisualizer({ stream, ...props }) {
+ const canvasRef = useRef(null);
+
+ const visualize = useCallback((stream) => {
+ const audioContext = new (window.AudioContext || window.webkitAudioContext)();
+ const source = audioContext.createMediaStreamSource(stream);
+ const analyser = audioContext.createAnalyser();
+ analyser.fftSize = 2048;
+ source.connect(analyser);
+
+ const canvas = canvasRef.current;
+ const canvasCtx = canvas.getContext('2d');
+ const bufferLength = analyser.frequencyBinCount;
+ const dataArray = new Uint8Array(bufferLength);
+
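+ // Redraw the time-domain waveform on every animation frame.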
+ const drawVisual = () => {
+ requestAnimationFrame(drawVisual);
+ analyser.getByteTimeDomainData(dataArray);
+
+ canvasCtx.fillStyle = 'rgb(255, 255, 255)';
+ canvasCtx.fillRect(0, 0, canvas.width, canvas.height);
+
+ canvasCtx.lineWidth = 2;
+ canvasCtx.strokeStyle = 'rgb(0, 0, 0)';
+ canvasCtx.beginPath();
+
+ const sliceWidth = canvas.width * 1.0 / bufferLength;
+
+ let x = 0;
+ for (let i = 0; i < bufferLength; ++i) {
+ const v = dataArray[i] / 128.0;
+ const y = v * canvas.height / 2;
+
+ if (i === 0) {
+ canvasCtx.moveTo(x, y);
+ } else {
+ canvasCtx.lineTo(x, y);
+ }
+
+ x += sliceWidth;
+ }
+
+ canvasCtx.lineTo(canvas.width, canvas.height / 2);
+ canvasCtx.stroke();
+ };
+
+ drawVisual();
+ }, []);
+
+ useEffect(() => {
+ stream && visualize(stream);
+ }, [visualize, stream]);
+ return (
+
+ )
+}
diff --git a/examples/webgpu-whisper/src/components/LanguageSelector.jsx b/examples/webgpu-whisper/src/components/LanguageSelector.jsx
new file mode 100644
index 000000000..9383d640e
--- /dev/null
+++ b/examples/webgpu-whisper/src/components/LanguageSelector.jsx
@@ -0,0 +1,133 @@
+
+function titleCase(str) {
+ str = str.toLowerCase();
+ return (str.match(/\w+.?/g) || [])
+ .map((word) => {
+ return word.charAt(0).toUpperCase() + word.slice(1);
+ })
+ .join("");
+}
+
+// List of supported languages:
+// https://help.openai.com/en/articles/7031512-whisper-api-faq
+// https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L79
+const LANGUAGES = {
+ en: "english",
+ zh: "chinese",
+ de: "german",
+ es: "spanish/castilian",
+ ru: "russian",
+ ko: "korean",
+ fr: "french",
+ ja: "japanese",
+ pt: "portuguese",
+ tr: "turkish",
+ pl: "polish",
+ ca: "catalan/valencian",
+ nl: "dutch/flemish",
+ ar: "arabic",
+ sv: "swedish",
+ it: "italian",
+ id: "indonesian",
+ hi: "hindi",
+ fi: "finnish",
+ vi: "vietnamese",
+ he: "hebrew",
+ uk: "ukrainian",
+ el: "greek",
+ ms: "malay",
+ cs: "czech",
+ ro: "romanian/moldavian/moldovan",
+ da: "danish",
+ hu: "hungarian",
+ ta: "tamil",
+ no: "norwegian",
+ th: "thai",
+ ur: "urdu",
+ hr: "croatian",
+ bg: "bulgarian",
+ lt: "lithuanian",
+ la: "latin",
+ mi: "maori",
+ ml: "malayalam",
+ cy: "welsh",
+ sk: "slovak",
+ te: "telugu",
+ fa: "persian",
+ lv: "latvian",
+ bn: "bengali",
+ sr: "serbian",
+ az: "azerbaijani",
+ sl: "slovenian",
+ kn: "kannada",
+ et: "estonian",
+ mk: "macedonian",
+ br: "breton",
+ eu: "basque",
+ is: "icelandic",
+ hy: "armenian",
+ ne: "nepali",
+ mn: "mongolian",
+ bs: "bosnian",
+ kk: "kazakh",
+ sq: "albanian",
+ sw: "swahili",
+ gl: "galician",
+ mr: "marathi",
+ pa: "punjabi/panjabi",
+ si: "sinhala/sinhalese",
+ km: "khmer",
+ sn: "shona",
+ yo: "yoruba",
+ so: "somali",
+ af: "afrikaans",
+ oc: "occitan",
+ ka: "georgian",
+ be: "belarusian",
+ tg: "tajik",
+ sd: "sindhi",
+ gu: "gujarati",
+ am: "amharic",
+ yi: "yiddish",
+ lo: "lao",
+ uz: "uzbek",
+ fo: "faroese",
+ ht: "haitian creole/haitian",
+ ps: "pashto/pushto",
+ tk: "turkmen",
+ nn: "nynorsk",
+ mt: "maltese",
+ sa: "sanskrit",
+ lb: "luxembourgish/letzeburgesch",
+ my: "myanmar/burmese",
+ bo: "tibetan",
+ tl: "tagalog",
+ mg: "malagasy",
+ as: "assamese",
+ tt: "tatar",
+ haw: "hawaiian",
+ ln: "lingala",
+ ha: "hausa",
+ ba: "bashkir",
+ jw: "javanese",
+ su: "sundanese",
+};
+export function LanguageSelector({ language, setLanguage }) {
+ const handleLanguageChange = (event) => {
+ setLanguage(event.target.value);
+ };
+
+ const names = Object.values(LANGUAGES).map(titleCase);
+
+ return (
+
+ {Object.keys(LANGUAGES).map((key, i) => (
+
+ {names[i]}
+
+ ))}
+
+ );
+}
\ No newline at end of file
diff --git a/examples/webgpu-whisper/src/components/Progress.jsx b/examples/webgpu-whisper/src/components/Progress.jsx
new file mode 100644
index 000000000..9ce024cc8
--- /dev/null
+++ b/examples/webgpu-whisper/src/components/Progress.jsx
@@ -0,0 +1,15 @@
+function formatBytes(size) {
+ const i = size == 0 ? 0 : Math.floor(Math.log(size) / Math.log(1024));
+ return +((size / Math.pow(1024, i)).toFixed(2)) * 1 + ['B', 'kB', 'MB', 'GB', 'TB'][i];
+}
+
+export default function Progress({ text, percentage, total }) {
+ percentage ??= 0;
+ return (
+
+
+ {text} ({percentage.toFixed(2)}%{isNaN(total) ? '' : ` of ${formatBytes(total)}`})
+
+
+ );
+}
diff --git a/examples/webgpu-whisper/src/index.css b/examples/webgpu-whisper/src/index.css
new file mode 100644
index 000000000..8848bbd6d
--- /dev/null
+++ b/examples/webgpu-whisper/src/index.css
@@ -0,0 +1,32 @@
+@tailwind base;
+@tailwind components;
+@tailwind utilities;
+
+@layer utilities {
+ .scrollbar-thin::-webkit-scrollbar {
+ @apply w-2;
+ }
+
+ .scrollbar-thin::-webkit-scrollbar-track {
+ @apply rounded-full bg-gray-100 dark:bg-gray-700;
+ }
+
+ .scrollbar-thin::-webkit-scrollbar-thumb {
+ @apply rounded-full bg-gray-300 dark:bg-gray-600;
+ }
+
+ .scrollbar-thin::-webkit-scrollbar-thumb:hover {
+ @apply bg-gray-500;
+ }
+
+ .animation-delay-200 {
+ animation-delay: 200ms;
+ }
+ .animation-delay-400 {
+ animation-delay: 400ms;
+ }
+
+ .overflow-wrap-anywhere {
+ overflow-wrap: anywhere;
+ }
+}
diff --git a/examples/webgpu-whisper/src/main.jsx b/examples/webgpu-whisper/src/main.jsx
new file mode 100644
index 000000000..54b39dd1d
--- /dev/null
+++ b/examples/webgpu-whisper/src/main.jsx
@@ -0,0 +1,10 @@
+import React from 'react'
+import ReactDOM from 'react-dom/client'
+import App from './App.jsx'
+import './index.css'
+
+ReactDOM.createRoot(document.getElementById('root')).render(
+  <React.StrictMode>
+    <App />
+  </React.StrictMode>,
+)
diff --git a/examples/webgpu-whisper/src/worker.js b/examples/webgpu-whisper/src/worker.js
new file mode 100644
index 000000000..12735b1cb
--- /dev/null
+++ b/examples/webgpu-whisper/src/worker.js
@@ -0,0 +1,134 @@
+
+import {
+ AutoTokenizer,
+ AutoProcessor,
+ WhisperForConditionalGeneration,
+ TextStreamer,
+ full,
+} from '@huggingface/transformers';
+
+
+const MAX_NEW_TOKENS = 64;
+
+/**
+ * This class uses the Singleton pattern to ensure that only one instance of the model is loaded.
+ */
+class AutomaticSpeechRecognitionPipeline {
+ static model_id = null;
+ static tokenizer = null;
+ static processor = null;
+ static model = null;
+
+ static async getInstance(progress_callback = null) {
+ this.model_id = 'onnx-community/whisper-base';
+
+ this.tokenizer ??= AutoTokenizer.from_pretrained(this.model_id, {
+ progress_callback,
+ });
+ this.processor ??= AutoProcessor.from_pretrained(this.model_id, {
+ progress_callback,
+ });
+
+ this.model ??= WhisperForConditionalGeneration.from_pretrained(this.model_id, {
+ dtype: {
+ encoder_model: 'fp32', // 'fp16' works too
+ decoder_model_merged: 'q4', // or 'fp32' ('fp16' is broken)
+ },
+ device: 'webgpu',
+ progress_callback,
+ });
+
+ return Promise.all([this.tokenizer, this.processor, this.model]);
+ }
+}
+
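+// Guard against overlapping requests: new audio chunks keep arriving from the
+// recorder while the previous chunk may still be transcribing.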
+let processing = false;
+async function generate({ audio, language }) {
+ if (processing) return;
+ processing = true;
+
+ // Tell the main thread we are starting
+ self.postMessage({ status: 'start' });
+
+ // Retrieve the speech-recognition pipeline.
+ const [tokenizer, processor, model] = await AutomaticSpeechRecognitionPipeline.getInstance();
+
+ let startTime;
+ let numTokens = 0;
+ const callback_function = (output) => {
+ startTime ??= performance.now();
+
+ let tps;
+ if (numTokens++ > 0) {
+ tps = numTokens / (performance.now() - startTime) * 1000;
+ }
+ self.postMessage({
+ status: 'update',
+ output, tps, numTokens,
+ });
+ }
+
+ const streamer = new TextStreamer(tokenizer, {
+ skip_prompt: true,
+ skip_special_tokens: true,
+ callback_function,
+ });
+
+ const inputs = await processor(audio);
+
+ const outputs = await model.generate({
+ ...inputs,
+ max_new_tokens: MAX_NEW_TOKENS,
+ language,
+ streamer,
+ });
+
+ const outputText = tokenizer.batch_decode(outputs, { skip_special_tokens: true });
+
+ // Send the output back to the main thread
+ self.postMessage({
+ status: 'complete',
+ output: outputText,
+ });
+ processing = false;
+}
+
+async function load() {
+ self.postMessage({
+ status: 'loading',
+ data: 'Loading model...'
+ });
+
+ // Load the pipeline and save it for future use.
+ const [tokenizer, processor, model] = await AutomaticSpeechRecognitionPipeline.getInstance(x => {
+ // We also add a progress callback to the pipeline so that we can
+ // track model loading.
+ self.postMessage(x);
+ });
+
+ self.postMessage({
+ status: 'loading',
+ data: 'Compiling shaders and warming up model...'
+ });
+
+ // Run model with dummy input to compile shaders
+ await model.generate({
+ input_features: full([1, 80, 3000], 0.0),
+ max_new_tokens: 1,
+ });
+ self.postMessage({ status: 'ready' });
+}
+// Listen for messages from the main thread
+self.addEventListener('message', async (e) => {
+ const { type, data } = e.data;
+
+ switch (type) {
+ case 'load':
+ load();
+ break;
+
+ case 'generate':
+ generate(data);
+ break;
+ }
+});
diff --git a/examples/webgpu-whisper/tailwind.config.js b/examples/webgpu-whisper/tailwind.config.js
new file mode 100644
index 000000000..d37737fc0
--- /dev/null
+++ b/examples/webgpu-whisper/tailwind.config.js
@@ -0,0 +1,12 @@
+/** @type {import('tailwindcss').Config} */
+export default {
+ content: [
+ "./index.html",
+ "./src/**/*.{js,ts,jsx,tsx}",
+ ],
+ theme: {
+ extend: {},
+ },
+ plugins: [],
+}
+
diff --git a/examples/webgpu-whisper/vite.config.js b/examples/webgpu-whisper/vite.config.js
new file mode 100644
index 000000000..5a33944a9
--- /dev/null
+++ b/examples/webgpu-whisper/vite.config.js
@@ -0,0 +1,7 @@
+import { defineConfig } from 'vite'
+import react from '@vitejs/plugin-react'
+
+// https://vitejs.dev/config/
+export default defineConfig({
+ plugins: [react()],
+})
diff --git a/examples/whisper-word-timestamps/.eslintrc.cjs b/examples/whisper-word-timestamps/.eslintrc.cjs
new file mode 100644
index 000000000..3e212e1d4
--- /dev/null
+++ b/examples/whisper-word-timestamps/.eslintrc.cjs
@@ -0,0 +1,21 @@
+module.exports = {
+ root: true,
+ env: { browser: true, es2020: true },
+ extends: [
+ 'eslint:recommended',
+ 'plugin:react/recommended',
+ 'plugin:react/jsx-runtime',
+ 'plugin:react-hooks/recommended',
+ ],
+ ignorePatterns: ['dist', '.eslintrc.cjs'],
+ parserOptions: { ecmaVersion: 'latest', sourceType: 'module' },
+ settings: { react: { version: '18.2' } },
+ plugins: ['react-refresh'],
+ rules: {
+ 'react/jsx-no-target-blank': 'off',
+ 'react-refresh/only-export-components': [
+ 'warn',
+ { allowConstantExport: true },
+ ],
+ },
+}
diff --git a/examples/whisper-word-timestamps/.gitignore b/examples/whisper-word-timestamps/.gitignore
new file mode 100644
index 000000000..a547bf36d
--- /dev/null
+++ b/examples/whisper-word-timestamps/.gitignore
@@ -0,0 +1,24 @@
+# Logs
+logs
+*.log
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+pnpm-debug.log*
+lerna-debug.log*
+
+node_modules
+dist
+dist-ssr
+*.local
+
+# Editor directories and files
+.vscode/*
+!.vscode/extensions.json
+.idea
+.DS_Store
+*.suo
+*.ntvs*
+*.njsproj
+*.sln
+*.sw?
diff --git a/examples/whisper-word-timestamps/README.md b/examples/whisper-word-timestamps/README.md
new file mode 100644
index 000000000..f768e33fc
--- /dev/null
+++ b/examples/whisper-word-timestamps/README.md
@@ -0,0 +1,8 @@
+# React + Vite
+
+This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.
+
+Currently, two official plugins are available:
+
+- [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react/README.md) uses [Babel](https://babeljs.io/) for Fast Refresh
+- [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh
diff --git a/examples/whisper-word-timestamps/index.html b/examples/whisper-word-timestamps/index.html
new file mode 100644
index 000000000..5f620b00e
--- /dev/null
+++ b/examples/whisper-word-timestamps/index.html
@@ -0,0 +1,12 @@
+
+
+
+
+
+ Whisper Timestamped
+
+
+
+
+
+
diff --git a/examples/whisper-word-timestamps/package.json b/examples/whisper-word-timestamps/package.json
new file mode 100644
index 000000000..3af99d9ef
--- /dev/null
+++ b/examples/whisper-word-timestamps/package.json
@@ -0,0 +1,30 @@
+{
+ "name": "whisper-word-timestamps",
+ "private": true,
+ "version": "0.0.0",
+ "type": "module",
+ "scripts": {
+ "dev": "vite",
+ "build": "vite build",
+ "lint": "eslint . --ext js,jsx --report-unused-disable-directives --max-warnings 0",
+ "preview": "vite preview"
+ },
+ "dependencies": {
+ "@xenova/transformers": "github:xenova/transformers.js#v3",
+ "react": "^18.3.1",
+ "react-dom": "^18.3.1"
+ },
+ "devDependencies": {
+ "@types/react": "^18.3.3",
+ "@types/react-dom": "^18.3.0",
+ "@vitejs/plugin-react": "^4.3.1",
+ "autoprefixer": "^10.4.19",
+ "eslint": "^8.57.0",
+ "eslint-plugin-react": "^7.34.2",
+ "eslint-plugin-react-hooks": "^4.6.2",
+ "eslint-plugin-react-refresh": "^0.4.7",
+ "postcss": "^8.4.38",
+ "tailwindcss": "^3.4.4",
+ "vite": "^5.3.1"
+ }
+}
diff --git a/examples/whisper-word-timestamps/postcss.config.js b/examples/whisper-word-timestamps/postcss.config.js
new file mode 100644
index 000000000..2e7af2b7f
--- /dev/null
+++ b/examples/whisper-word-timestamps/postcss.config.js
@@ -0,0 +1,6 @@
+export default {
+ plugins: {
+ tailwindcss: {},
+ autoprefixer: {},
+ },
+}
diff --git a/examples/whisper-word-timestamps/src/App.jsx b/examples/whisper-word-timestamps/src/App.jsx
new file mode 100644
index 000000000..c7b6e89fc
--- /dev/null
+++ b/examples/whisper-word-timestamps/src/App.jsx
@@ -0,0 +1,217 @@
+import { useEffect, useState, useRef, useCallback } from 'react';
+
+import Progress from './components/Progress';
+import MediaInput from './components/MediaInput';
+import Transcript from './components/Transcript';
+import LanguageSelector from './components/LanguageSelector';
+
+
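+// Detect WebGPU support by attempting to request an adapter; the app falls back to WASM otherwise.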
+async function hasWebGPU() {
+ if (!navigator.gpu) {
+ return false;
+ }
+ try {
+ const adapter = await navigator.gpu.requestAdapter();
+ return !!adapter;
+ } catch (e) {
+ return false;
+ }
+}
+
+function App() {
+
+ // Create a reference to the worker object.
+ const worker = useRef(null);
+
+ // Model loading and progress
+ const [status, setStatus] = useState(null);
+ const [loadingMessage, setLoadingMessage] = useState('');
+ const [progressItems, setProgressItems] = useState([]);
+
+ const mediaInputRef = useRef(null);
+ const [audio, setAudio] = useState(null);
+ const [language, setLanguage] = useState('en');
+
+ const [result, setResult] = useState(null);
+ const [time, setTime] = useState(null);
+ const [currentTime, setCurrentTime] = useState(0);
+
+ const [device, setDevice] = useState('webgpu'); // Try to use WebGPU first
+ const [modelSize, setModelSize] = useState('gpu' in navigator ? 196 : 77); // WebGPU=196MB, WebAssembly=77MB
+ useEffect(() => {
+ hasWebGPU().then((result) => {
+ setModelSize(result ? 196 : 77);
+ setDevice(result ? 'webgpu' : 'wasm');
+ });
+ }, []);
+
+ // We use the `useEffect` hook to set up the worker as soon as the `App` component is mounted.
+ useEffect(() => {
+ if (!worker.current) {
+ // Create the worker if it does not yet exist.
+ worker.current = new Worker(new URL('./worker.js', import.meta.url), {
+ type: 'module'
+ });
+ }
+
+ // Create a callback function for messages from the worker thread.
+ const onMessageReceived = (e) => {
+ switch (e.data.status) {
+ case 'loading':
+ // Model file start load: add a new progress item to the list.
+ setStatus('loading');
+ setLoadingMessage(e.data.data);
+ break;
+
+ case 'initiate':
+ setProgressItems(prev => [...prev, e.data]);
+ break;
+
+ case 'progress':
+ // Model file progress: update one of the progress items.
+ setProgressItems(
+ prev => prev.map(item => {
+ if (item.file === e.data.file) {
+ return { ...item, ...e.data }
+ }
+ return item;
+ })
+ );
+ break;
+
+ case 'done':
+ // Model file loaded: remove the progress item from the list.
+ setProgressItems(
+ prev => prev.filter(item => item.file !== e.data.file)
+ );
+ break;
+
+ case 'ready':
+ // Pipeline ready: the worker is ready to accept messages.
+ setStatus('ready');
+ break;
+
+ case 'complete':
+ setResult(e.data.result);
+ setTime(e.data.time);
+ setStatus('ready');
+ break;
+ }
+ };
+
+ // Attach the callback function as an event listener.
+ worker.current.addEventListener('message', onMessageReceived);
+
+ // Define a cleanup function for when the component is unmounted.
+ return () => {
+ worker.current.removeEventListener('message', onMessageReceived);
+ };
+ }, []);
+
+ const handleClick = useCallback(() => {
+ setResult(null);
+ setTime(null);
+ if (status === null) {
+ setStatus('loading');
+ worker.current.postMessage({ type: 'load', data: { device } });
+ } else {
+ setStatus('running');
+ worker.current.postMessage({
+ type: 'run', data: { audio, language }
+ });
+ }
+ }, [status, audio, language, device]);
+
+ return (
+
+
+ {status === 'loading' && (
+
+
+
{loadingMessage}
+ {progressItems.map(({ file, progress, total }, i) => (
+
+ ))}
+
+
+ )}
+
+
+
Whisper Timestamped
+ In-browser speech recognition w/ word-level timestamps
+
+
+
+ {
+ !audio && (
+
+ You are about to download whisper-base (timestamped) ,
+ a 73 million parameter speech recognition model with the ability to generate word-level timestamps across 100 different languages.
+ Once loaded, the model ({modelSize} MB) will be cached and reused when you revisit the page.
+
+ Everything runs locally in your browser using 🤗 Transformers.js and ONNX Runtime Web,
+ meaning no API calls are made to a server for inference. You can even disconnect from the internet after the model has loaded!
+
+ )
+ }
+
+
+ Input audio/video
+ setAudio(result)}
+ onTimeUpdate={(time) => setCurrentTime(time)}
+ />
+
+
+
+
+ {status === null ? 'Load model' :
+ status === 'running'
+ ? 'Running...'
+ : 'Run model'
+ }
+
+
+ {status !== null &&
+
+ Language:
+
+
+
+ }
+
+
+ {
+ result && time && (
+ <>
+
+ {
+ setCurrentTime(time);
+ mediaInputRef.current.setMediaTime(time);
+ }}
+ />
+
+
Generation time: {time.toFixed(2)}ms
+ >
+ )
+
+
+ }
+
+
+
+
+ )
+}
+
+export default App
diff --git a/examples/whisper-word-timestamps/src/components/LanguageSelector.jsx b/examples/whisper-word-timestamps/src/components/LanguageSelector.jsx
new file mode 100644
index 000000000..74c02a62a
--- /dev/null
+++ b/examples/whisper-word-timestamps/src/components/LanguageSelector.jsx
@@ -0,0 +1,134 @@
+
+function titleCase(str) {
+ str = str.toLowerCase();
+ return (str.match(/\w+.?/g) || [])
+ .map((word) => {
+ return word.charAt(0).toUpperCase() + word.slice(1);
+ })
+ .join("");
+}
+
+// List of supported languages:
+// https://help.openai.com/en/articles/7031512-whisper-api-faq
+// https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L79
+const LANGUAGES = {
+ en: "english",
+ zh: "chinese",
+ de: "german",
+ es: "spanish/castilian",
+ ru: "russian",
+ ko: "korean",
+ fr: "french",
+ ja: "japanese",
+ pt: "portuguese",
+ tr: "turkish",
+ pl: "polish",
+ ca: "catalan/valencian",
+ nl: "dutch/flemish",
+ ar: "arabic",
+ sv: "swedish",
+ it: "italian",
+ id: "indonesian",
+ hi: "hindi",
+ fi: "finnish",
+ vi: "vietnamese",
+ he: "hebrew",
+ uk: "ukrainian",
+ el: "greek",
+ ms: "malay",
+ cs: "czech",
+ ro: "romanian/moldavian/moldovan",
+ da: "danish",
+ hu: "hungarian",
+ ta: "tamil",
+ no: "norwegian",
+ th: "thai",
+ ur: "urdu",
+ hr: "croatian",
+ bg: "bulgarian",
+ lt: "lithuanian",
+ la: "latin",
+ mi: "maori",
+ ml: "malayalam",
+ cy: "welsh",
+ sk: "slovak",
+ te: "telugu",
+ fa: "persian",
+ lv: "latvian",
+ bn: "bengali",
+ sr: "serbian",
+ az: "azerbaijani",
+ sl: "slovenian",
+ kn: "kannada",
+ et: "estonian",
+ mk: "macedonian",
+ br: "breton",
+ eu: "basque",
+ is: "icelandic",
+ hy: "armenian",
+ ne: "nepali",
+ mn: "mongolian",
+ bs: "bosnian",
+ kk: "kazakh",
+ sq: "albanian",
+ sw: "swahili",
+ gl: "galician",
+ mr: "marathi",
+ pa: "punjabi/panjabi",
+ si: "sinhala/sinhalese",
+ km: "khmer",
+ sn: "shona",
+ yo: "yoruba",
+ so: "somali",
+ af: "afrikaans",
+ oc: "occitan",
+ ka: "georgian",
+ be: "belarusian",
+ tg: "tajik",
+ sd: "sindhi",
+ gu: "gujarati",
+ am: "amharic",
+ yi: "yiddish",
+ lo: "lao",
+ uz: "uzbek",
+ fo: "faroese",
+ ht: "haitian creole/haitian",
+ ps: "pashto/pushto",
+ tk: "turkmen",
+ nn: "nynorsk",
+ mt: "maltese",
+ sa: "sanskrit",
+ lb: "luxembourgish/letzeburgesch",
+ my: "myanmar/burmese",
+ bo: "tibetan",
+ tl: "tagalog",
+ mg: "malagasy",
+ as: "assamese",
+ tt: "tatar",
+ haw: "hawaiian",
+ ln: "lingala",
+ ha: "hausa",
+ ba: "bashkir",
+ jw: "javanese",
+ su: "sundanese",
+};
+function LanguageSelector({ language, setLanguage, ...props }) {
+ const handleLanguageChange = (event) => {
+ setLanguage(event.target.value);
+ };
+
+ const names = Object.values(LANGUAGES).map(titleCase);
+
+ return (
+
+ {Object.keys(LANGUAGES).map((key, i) => (
+
+ {names[i]}
+
+ ))}
+
+ );
+}
+export default LanguageSelector
diff --git a/examples/whisper-word-timestamps/src/components/MediaInput.jsx b/examples/whisper-word-timestamps/src/components/MediaInput.jsx
new file mode 100644
index 000000000..4bf7afcb6
--- /dev/null
+++ b/examples/whisper-word-timestamps/src/components/MediaInput.jsx
@@ -0,0 +1,194 @@
+import { useState, forwardRef, useRef, useImperativeHandle, useEffect, useCallback } from 'react';
+
+const EXAMPLE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/whisper-timestamps-demo.mp4';
+
+const MediaInput = forwardRef(({ onInputChange, onTimeUpdate, ...props }, ref) => {
+ // UI states
+ const [dragging, setDragging] = useState(false);
+ const fileInputRef = useRef(null);
+
+ // Create a reference to the audio and video elements
+ const audioElement = useRef(null);
+ const videoElement = useRef(null);
+
+ const currentTimeRef = useRef(0);
+ useImperativeHandle(ref, () => ({
+ setMediaTime(time) {
+ if (audioElement.current?.src) {
+ audioElement.current.currentTime = time;
+ } else if (videoElement.current?.src) {
+ videoElement.current.currentTime = time;
+ }
+ currentTimeRef.current = time;
+ }
+ }));
+
+ const onBufferLoad = (arrayBuffer, type) => {
+ const blob = new Blob([arrayBuffer.slice(0)], { type: type });
+ const url = URL.createObjectURL(blob);
+ processFile(arrayBuffer);
+
+ // Create a URL for the Blob
+ if (type.startsWith('audio/')) {
+ // Dispose the previous source
+ videoElement.current.pause();
+ videoElement.current.removeAttribute('src');
+ videoElement.current.load();
+
+ audioElement.current.src = url;
+ } else if (type.startsWith('video/')) {
+ // Dispose the previous source
+ audioElement.current.pause();
+ audioElement.current.removeAttribute('src');
+ audioElement.current.load();
+
+ videoElement.current.src = url;
+ } else {
+ alert(`Unsupported file type: ${type}`);
+ }
+ }
+
+ const readFile = (file) => {
+ if (!file) return;
+
+ // file.type
+ const reader = new FileReader();
+ reader.onload = (e) => {
+ onBufferLoad(e.target.result, file.type);
+ }
+ reader.readAsArrayBuffer(file);
+ }
+
+ const handleInputChange = (event) => {
+ readFile(event.target.files[0]);
+ };
+
+ const handleDragOver = (event) => {
+ event.preventDefault();
+ };
+
+ const handleDrop = (event) => {
+ event.preventDefault();
+ setDragging(false);
+ readFile(event.dataTransfer.files[0]);
+ };
+
+ const handleClick = (e) => {
+ if (e.target.tagName === 'VIDEO' || e.target.tagName === 'AUDIO') {
+ e.preventDefault();
+ fileInputRef.current.click();
+ } else if (e.target.tagName === 'INPUT') {
+ e.stopPropagation();
+ } else {
+ fileInputRef.current.click();
+ e.stopPropagation();
+ }
+ };
+
+ const processFile = async (buffer) => {
+ const audioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: 16_000 });
+
+ try {
+ const audioBuffer = await audioContext.decodeAudioData(buffer);
+ let audio;
+ if (audioBuffer.numberOfChannels === 2) {
+ // Merge channels
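+ // (downmix stereo to mono; the sqrt(2) factor compensates for the drop in
+ // level caused by averaging the two channels)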
+ const SCALING_FACTOR = Math.sqrt(2);
+ const left = audioBuffer.getChannelData(0);
+ const right = audioBuffer.getChannelData(1);
+ audio = new Float32Array(left.length);
+ for (let i = 0; i < audioBuffer.length; ++i) {
+ audio[i] = SCALING_FACTOR * (left[i] + right[i]) / 2;
+ }
+ } else {
+ audio = audioBuffer.getChannelData(0);
+ }
+ onInputChange(audio);
+
+ } catch (e) {
+ alert(e);
+ }
+ };
+
+ const requestRef = useRef();
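+ // Poll the playback position of whichever media element is active on every
+ // animation frame, and notify the parent only when it changes.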
+
+ const updateTime = useCallback(() => {
+ let elem;
+ if (audioElement.current?.src) {
+ elem = audioElement.current;
+
+ } else if (videoElement.current?.src) {
+ elem = videoElement.current;
+ }
+
+ if (elem && currentTimeRef.current !== elem.currentTime) {
+ currentTimeRef.current = elem.currentTime;
+ onTimeUpdate(elem.currentTime);
+ }
+
+ // Request the next frame
+ requestRef.current = requestAnimationFrame(updateTime);
+ }, [onTimeUpdate]);
+
+ useEffect(() => {
+ // Start the animation
+ requestRef.current = requestAnimationFrame(updateTime);
+
+ return () => {
+ // Cleanup on component unmount
+ cancelAnimationFrame(requestRef.current);
+ };
+ }, [updateTime]);
+ return (
+ setDragging(true)}
+ onDragLeave={(e) => setDragging(false)}
+ >
+
+ {
+
+ }
+ {
+
+ }
+ {
+ !audioElement.current?.src && !videoElement.current?.src && (
+
+ Drag & drop or click to select media
+ {
+ e.stopPropagation();
+ const buffer = await fetch(EXAMPLE_URL).then((r) => r.arrayBuffer());
+ videoElement.current.src = URL.createObjectURL(new Blob([buffer], { type: 'video/mp4' }));
+ onBufferLoad(buffer, 'video/mp4');
+ }}>(or try an example )
+
+ )
+ }
+
+ );
+});
+MediaInput.displayName = 'MediaInput';
+
+export default MediaInput;
diff --git a/examples/whisper-word-timestamps/src/components/Progress.jsx b/examples/whisper-word-timestamps/src/components/Progress.jsx
new file mode 100644
index 000000000..9ce024cc8
--- /dev/null
+++ b/examples/whisper-word-timestamps/src/components/Progress.jsx
@@ -0,0 +1,15 @@
+function formatBytes(size) {
+ const i = size == 0 ? 0 : Math.floor(Math.log(size) / Math.log(1024));
+ return +((size / Math.pow(1024, i)).toFixed(2)) * 1 + ['B', 'kB', 'MB', 'GB', 'TB'][i];
+}
+
+export default function Progress({ text, percentage, total }) {
+ percentage ??= 0;
+ return (
+
+
+ {text} ({percentage.toFixed(2)}%{isNaN(total) ? '' : ` of ${formatBytes(total)}`})
+
+
+ );
+}
diff --git a/examples/whisper-word-timestamps/src/components/Transcript.jsx b/examples/whisper-word-timestamps/src/components/Transcript.jsx
new file mode 100644
index 000000000..542014323
--- /dev/null
+++ b/examples/whisper-word-timestamps/src/components/Transcript.jsx
@@ -0,0 +1,68 @@
+import { useMemo } from "react";
+
+const Chunk = ({ chunk, currentTime, onClick, ...props }) => {
+ const { text, timestamp } = chunk;
+ const [start, end] = timestamp;
+
+ const bolded = start <= currentTime && currentTime < end;
+
+ return (
+
+ {text.startsWith(' ') ? " " : ""}
+ x.toFixed(2)).join(' → ')}
+ style={{
+ textDecoration: bolded ? 'underline' : 'none',
+ textShadow: bolded ? '0 0 1px #000' : 'none',
+ }}
+ >{text.trim()}
+
+ )
+}
+
+const Transcript = ({ transcript, currentTime, setCurrentTime, ...props }) => {
+
+
+ const jsonTranscript = useMemo(() => {
+ return JSON.stringify(transcript, null, 2)
+ // post-process the JSON to make it more readable
+ .replace(/( {4}"timestamp": )\[\s+(\S+)\s+(\S+)\s+\]/gm, "$1[$2 $3]");
+ }, [transcript]);
+
+ const downloadTranscript = () => {
+ const blob = new Blob([jsonTranscript], { type: 'application/json' });
+ const url = URL.createObjectURL(blob);
+ const a = document.createElement('a');
+ a.href = url;
+ a.download = 'transcript.json';
+ a.click();
+ URL.revokeObjectURL(url);
+ }
+
+ return (<>
+
+ {
+ transcript.chunks.map((chunk, i) => {
+ setCurrentTime(chunk.timestamp[0]) // Set to start of chunk
+ }} />)
+ }
+
+
+
+
+
+
+
+ Download transcript
+
+
+
+
+ >)
+};
+export default Transcript;
diff --git a/examples/whisper-word-timestamps/src/index.css b/examples/whisper-word-timestamps/src/index.css
new file mode 100644
index 000000000..87bbb9dac
--- /dev/null
+++ b/examples/whisper-word-timestamps/src/index.css
@@ -0,0 +1,25 @@
+@tailwind base;
+@tailwind components;
+@tailwind utilities;
+
+@layer utilities {
+ .scrollbar-thin::-webkit-scrollbar {
+ @apply w-2;
+ }
+
+ .scrollbar-thin::-webkit-scrollbar-track {
+ @apply rounded-full bg-gray-100 dark:bg-gray-700;
+ }
+
+ .scrollbar-thin::-webkit-scrollbar-thumb {
+ @apply rounded-full bg-gray-300 dark:bg-gray-600;
+ }
+
+ .scrollbar-thin::-webkit-scrollbar-thumb:hover {
+ @apply bg-gray-500;
+ }
+}
+
+html {
+ @apply scrollbar-thin;
+}
\ No newline at end of file
diff --git a/examples/whisper-word-timestamps/src/main.jsx b/examples/whisper-word-timestamps/src/main.jsx
new file mode 100644
index 000000000..54b39dd1d
--- /dev/null
+++ b/examples/whisper-word-timestamps/src/main.jsx
@@ -0,0 +1,10 @@
+import React from 'react'
+import ReactDOM from 'react-dom/client'
+import App from './App.jsx'
+import './index.css'
+
+ReactDOM.createRoot(document.getElementById('root')).render(
+  <React.StrictMode>
+    <App />
+  </React.StrictMode>,
+)
diff --git a/examples/whisper-word-timestamps/src/worker.js b/examples/whisper-word-timestamps/src/worker.js
new file mode 100644
index 000000000..efa029a97
--- /dev/null
+++ b/examples/whisper-word-timestamps/src/worker.js
@@ -0,0 +1,94 @@
+
+import { pipeline } from '@xenova/transformers';
+
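+// Per-backend settings: fp32 encoder + 4-bit decoder on WebGPU, fully 8-bit quantized on WASM.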
+const PER_DEVICE_CONFIG = {
+ webgpu: {
+ dtype: {
+ encoder_model: 'fp32',
+ decoder_model_merged: 'q4',
+ },
+ device: 'webgpu',
+ },
+ wasm: {
+ dtype: 'q8',
+ device: 'wasm',
+ },
+};
+
+/**
+ * This class uses the Singleton pattern to ensure that only one instance of the model is loaded.
+ */
+class PipelineSingleton {
+ static model_id = 'onnx-community/whisper-base_timestamped';
+ static instance = null;
+
+ static async getInstance(progress_callback = null, device = 'webgpu') {
+
+ if (!this.instance) {
+ this.instance = pipeline('automatic-speech-recognition', this.model_id, {
+ ...PER_DEVICE_CONFIG[device],
+ progress_callback,
+ });
+ }
+ return this.instance;
+ }
+}
+
+async function load({ device }) {
+ self.postMessage({
+ status: 'loading',
+ data: `Loading model (${device})...`
+ });
+
+ // Load the pipeline and save it for future use.
+ const transcriber = await PipelineSingleton.getInstance(x => {
+ // We also add a progress callback to the pipeline so that we can
+ // track model loading.
+ self.postMessage(x);
+ }, device);
+
+ if (device === 'webgpu') {
+ self.postMessage({
+ status: 'loading',
+ data: 'Compiling shaders and warming up model...'
+ });
+
+ await transcriber(new Float32Array(16_000), {
+ language: 'en',
+ });
+ }
+
+ self.postMessage({ status: 'ready' });
+}
+
+async function run({ audio, language }) {
+ const transcriber = await PipelineSingleton.getInstance();
+
+ // Transcribe the audio and time the run
+ const start = performance.now();
+
+ const result = await transcriber(audio, {
+ language,
+ return_timestamps: 'word',
+ chunk_length_s: 30,
+ });
+
+ const end = performance.now();
+
+ self.postMessage({ status: 'complete', result, time: end - start });
+}
+
+// Listen for messages from the main thread
+self.addEventListener('message', async (e) => {
+ const { type, data } = e.data;
+
+ switch (type) {
+ case 'load':
+ load(data);
+ break;
+
+ case 'run':
+ run(data);
+ break;
+ }
+});
diff --git a/examples/whisper-word-timestamps/tailwind.config.js b/examples/whisper-word-timestamps/tailwind.config.js
new file mode 100644
index 000000000..d37737fc0
--- /dev/null
+++ b/examples/whisper-word-timestamps/tailwind.config.js
@@ -0,0 +1,12 @@
+/** @type {import('tailwindcss').Config} */
+export default {
+ content: [
+ "./index.html",
+ "./src/**/*.{js,ts,jsx,tsx}",
+ ],
+ theme: {
+ extend: {},
+ },
+ plugins: [],
+}
+
diff --git a/examples/whisper-word-timestamps/vite.config.js b/examples/whisper-word-timestamps/vite.config.js
new file mode 100644
index 000000000..5a33944a9
--- /dev/null
+++ b/examples/whisper-word-timestamps/vite.config.js
@@ -0,0 +1,7 @@
+import { defineConfig } from 'vite'
+import react from '@vitejs/plugin-react'
+
+// https://vitejs.dev/config/
+export default defineConfig({
+ plugins: [react()],
+})
diff --git a/jest.config.mjs b/jest.config.mjs
index b6a5fdb3d..0d15ce842 100644
--- a/jest.config.mjs
+++ b/jest.config.mjs
@@ -23,9 +23,10 @@ export default {
coverageDirectory: "coverage",
// An array of regexp pattern strings used to skip coverage collection
- // coveragePathIgnorePatterns: [
- // "\\\\node_modules\\\\"
- // ],
+ coveragePathIgnorePatterns: [
+ "node_modules",
+ "tests",
+ ],
// Indicates which provider should be used to instrument code for coverage
coverageProvider: "v8",
@@ -121,9 +122,7 @@ export default {
// rootDir: undefined,
// A list of paths to directories that Jest should use to search for files in
- roots: [
- "./tests/"
- ],
+ roots: ["./tests/"],
// Allows you to use a custom runner instead of Jest's default test runner
// runner: "jest-runner",
@@ -170,7 +169,7 @@ export default {
// testRunner: "jest-circus/runner",
// A map from regular expressions to paths to transformers
- transform: {}
+ transform: {},
// An array of regexp pattern strings that are matched against all source file paths, matched files will skip transformation
// transformIgnorePatterns: [
diff --git a/jsconfig.json b/jsconfig.json
index 5430d98f2..9af7d54be 100644
--- a/jsconfig.json
+++ b/jsconfig.json
@@ -1,18 +1,14 @@
{
- // Only include files in the src directory
- "include": [
- "src/**/*"
- ],
- "compilerOptions": {
- // Tells the compiler to check JS files
- "checkJs": true,
- "target": "esnext",
- "module": "esnext",
- "moduleResolution": "nodenext",
- },
- "typeAcquisition": {
- "include": [
- "jest"
- ]
- }
-}
\ No newline at end of file
+ // Only include files in the src directory
+ "include": ["src/**/*"],
+ "compilerOptions": {
+ // Tells the compiler to check JS files
+ "checkJs": true,
+ "target": "esnext",
+ "module": "nodenext",
+ "moduleResolution": "nodenext"
+ },
+ "typeAcquisition": {
+ "include": ["jest"]
+ }
+}
diff --git a/package-lock.json b/package-lock.json
index 3fba478d5..17b45d57e 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -1,33 +1,32 @@
{
- "name": "@xenova/transformers",
- "version": "2.17.2",
+ "name": "@huggingface/transformers",
+ "version": "3.0.0",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
- "name": "@xenova/transformers",
- "version": "2.17.2",
+ "name": "@huggingface/transformers",
+ "version": "3.0.0",
"license": "Apache-2.0",
"dependencies": {
- "@huggingface/jinja": "^0.2.2",
- "onnxruntime-web": "1.14.0",
- "sharp": "^0.32.0"
+ "@huggingface/jinja": "^0.3.0",
+ "onnxruntime-node": "1.19.2",
+ "onnxruntime-web": "1.20.0-dev.20241016-2b8fc5529b",
+ "sharp": "^0.33.5"
},
"devDependencies": {
"@types/jest": "^29.5.1",
+ "@webgpu/types": "^0.1.44",
"catharsis": "github:xenova/catharsis",
- "copy-webpack-plugin": "^11.0.0",
"jest": "^29.5.0",
"jest-environment-node": "^29.5.0",
"jsdoc-to-markdown": "^8.0.1",
+ "prettier": "3.3.3",
"typescript": "^5.2.2",
"wavefile": "^11.0.0",
"webpack": "^5.80.0",
"webpack-cli": "^5.0.2",
"webpack-dev-server": "^4.13.3"
- },
- "optionalDependencies": {
- "onnxruntime-node": "1.14.0"
}
},
"node_modules/@ampproject/remapping": {
@@ -744,14 +743,465 @@
"node": ">=10.0.0"
}
},
+ "node_modules/@emnapi/runtime": {
+ "version": "1.2.0",
+ "resolved": "https://registry.npmjs.org/@emnapi/runtime/-/runtime-1.2.0.tgz",
+ "integrity": "sha512-bV21/9LQmcQeCPEg3BDFtvwL6cwiTMksYNWQQ4KOxCZikEGalWtenoZ0wCiukJINlGCIi2KXx01g4FoH/LxpzQ==",
+ "optional": true,
+ "dependencies": {
+ "tslib": "^2.4.0"
+ }
+ },
"node_modules/@huggingface/jinja": {
- "version": "0.2.2",
- "resolved": "https://registry.npmjs.org/@huggingface/jinja/-/jinja-0.2.2.tgz",
- "integrity": "sha512-/KPde26khDUIPkTGU82jdtTW9UAuvUTumCAbFs/7giR0SxsvZC4hru51PBvpijH6BVkHcROcvZM/lpy5h1jRRA==",
+ "version": "0.3.0",
+ "resolved": "https://registry.npmjs.org/@huggingface/jinja/-/jinja-0.3.0.tgz",
+ "integrity": "sha512-GLJzso0M07ZncFkrJMIXVU4os6GFbPocD4g8fMQPMGJubf48FtGOsUORH2rtFdXPIPelz8SLBMn8ZRmOTwXm9Q==",
"engines": {
"node": ">=18"
}
},
+ "node_modules/@img/sharp-darwin-arm64": {
+ "version": "0.33.5",
+ "resolved": "https://registry.npmjs.org/@img/sharp-darwin-arm64/-/sharp-darwin-arm64-0.33.5.tgz",
+ "integrity": "sha512-UT4p+iz/2H4twwAoLCqfA9UH5pI6DggwKEGuaPy7nCVQ8ZsiY5PIcrRvD1DzuY3qYL07NtIQcWnBSY/heikIFQ==",
+ "cpu": [
+ "arm64"
+ ],
+ "optional": true,
+ "os": [
+ "darwin"
+ ],
+ "engines": {
+ "node": "^18.17.0 || ^20.3.0 || >=21.0.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ },
+ "optionalDependencies": {
+ "@img/sharp-libvips-darwin-arm64": "1.0.4"
+ }
+ },
+ "node_modules/@img/sharp-darwin-x64": {
+ "version": "0.33.5",
+ "resolved": "https://registry.npmjs.org/@img/sharp-darwin-x64/-/sharp-darwin-x64-0.33.5.tgz",
+ "integrity": "sha512-fyHac4jIc1ANYGRDxtiqelIbdWkIuQaI84Mv45KvGRRxSAa7o7d1ZKAOBaYbnepLC1WqxfpimdeWfvqqSGwR2Q==",
+ "cpu": [
+ "x64"
+ ],
+ "optional": true,
+ "os": [
+ "darwin"
+ ],
+ "engines": {
+ "node": "^18.17.0 || ^20.3.0 || >=21.0.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ },
+ "optionalDependencies": {
+ "@img/sharp-libvips-darwin-x64": "1.0.4"
+ }
+ },
+ "node_modules/@img/sharp-libvips-darwin-arm64": {
+ "version": "1.0.4",
+ "resolved": "https://registry.npmjs.org/@img/sharp-libvips-darwin-arm64/-/sharp-libvips-darwin-arm64-1.0.4.tgz",
+ "integrity": "sha512-XblONe153h0O2zuFfTAbQYAX2JhYmDHeWikp1LM9Hul9gVPjFY427k6dFEcOL72O01QxQsWi761svJ/ev9xEDg==",
+ "cpu": [
+ "arm64"
+ ],
+ "optional": true,
+ "os": [
+ "darwin"
+ ],
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ }
+ },
+ "node_modules/@img/sharp-libvips-darwin-x64": {
+ "version": "1.0.4",
+ "resolved": "https://registry.npmjs.org/@img/sharp-libvips-darwin-x64/-/sharp-libvips-darwin-x64-1.0.4.tgz",
+ "integrity": "sha512-xnGR8YuZYfJGmWPvmlunFaWJsb9T/AO2ykoP3Fz/0X5XV2aoYBPkX6xqCQvUTKKiLddarLaxpzNe+b1hjeWHAQ==",
+ "cpu": [
+ "x64"
+ ],
+ "optional": true,
+ "os": [
+ "darwin"
+ ],
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ }
+ },
+ "node_modules/@img/sharp-libvips-linux-arm": {
+ "version": "1.0.5",
+ "resolved": "https://registry.npmjs.org/@img/sharp-libvips-linux-arm/-/sharp-libvips-linux-arm-1.0.5.tgz",
+ "integrity": "sha512-gvcC4ACAOPRNATg/ov8/MnbxFDJqf/pDePbBnuBDcjsI8PssmjoKMAz4LtLaVi+OnSb5FK/yIOamqDwGmXW32g==",
+ "cpu": [
+ "arm"
+ ],
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ }
+ },
+ "node_modules/@img/sharp-libvips-linux-arm64": {
+ "version": "1.0.4",
+ "resolved": "https://registry.npmjs.org/@img/sharp-libvips-linux-arm64/-/sharp-libvips-linux-arm64-1.0.4.tgz",
+ "integrity": "sha512-9B+taZ8DlyyqzZQnoeIvDVR/2F4EbMepXMc/NdVbkzsJbzkUjhXv/70GQJ7tdLA4YJgNP25zukcxpX2/SueNrA==",
+ "cpu": [
+ "arm64"
+ ],
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ }
+ },
+ "node_modules/@img/sharp-libvips-linux-s390x": {
+ "version": "1.0.4",
+ "resolved": "https://registry.npmjs.org/@img/sharp-libvips-linux-s390x/-/sharp-libvips-linux-s390x-1.0.4.tgz",
+ "integrity": "sha512-u7Wz6ntiSSgGSGcjZ55im6uvTrOxSIS8/dgoVMoiGE9I6JAfU50yH5BoDlYA1tcuGS7g/QNtetJnxA6QEsCVTA==",
+ "cpu": [
+ "s390x"
+ ],
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ }
+ },
+ "node_modules/@img/sharp-libvips-linux-x64": {
+ "version": "1.0.4",
+ "resolved": "https://registry.npmjs.org/@img/sharp-libvips-linux-x64/-/sharp-libvips-linux-x64-1.0.4.tgz",
+ "integrity": "sha512-MmWmQ3iPFZr0Iev+BAgVMb3ZyC4KeFc3jFxnNbEPas60e1cIfevbtuyf9nDGIzOaW9PdnDciJm+wFFaTlj5xYw==",
+ "cpu": [
+ "x64"
+ ],
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ }
+ },
+ "node_modules/@img/sharp-libvips-linuxmusl-arm64": {
+ "version": "1.0.4",
+ "resolved": "https://registry.npmjs.org/@img/sharp-libvips-linuxmusl-arm64/-/sharp-libvips-linuxmusl-arm64-1.0.4.tgz",
+ "integrity": "sha512-9Ti+BbTYDcsbp4wfYib8Ctm1ilkugkA/uscUn6UXK1ldpC1JjiXbLfFZtRlBhjPZ5o1NCLiDbg8fhUPKStHoTA==",
+ "cpu": [
+ "arm64"
+ ],
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ }
+ },
+ "node_modules/@img/sharp-libvips-linuxmusl-x64": {
+ "version": "1.0.4",
+ "resolved": "https://registry.npmjs.org/@img/sharp-libvips-linuxmusl-x64/-/sharp-libvips-linuxmusl-x64-1.0.4.tgz",
+ "integrity": "sha512-viYN1KX9m+/hGkJtvYYp+CCLgnJXwiQB39damAO7WMdKWlIhmYTfHjwSbQeUK/20vY154mwezd9HflVFM1wVSw==",
+ "cpu": [
+ "x64"
+ ],
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ }
+ },
+ "node_modules/@img/sharp-linux-arm": {
+ "version": "0.33.5",
+ "resolved": "https://registry.npmjs.org/@img/sharp-linux-arm/-/sharp-linux-arm-0.33.5.tgz",
+ "integrity": "sha512-JTS1eldqZbJxjvKaAkxhZmBqPRGmxgu+qFKSInv8moZ2AmT5Yib3EQ1c6gp493HvrvV8QgdOXdyaIBrhvFhBMQ==",
+ "cpu": [
+ "arm"
+ ],
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "engines": {
+ "node": "^18.17.0 || ^20.3.0 || >=21.0.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ },
+ "optionalDependencies": {
+ "@img/sharp-libvips-linux-arm": "1.0.5"
+ }
+ },
+ "node_modules/@img/sharp-linux-arm64": {
+ "version": "0.33.5",
+ "resolved": "https://registry.npmjs.org/@img/sharp-linux-arm64/-/sharp-linux-arm64-0.33.5.tgz",
+ "integrity": "sha512-JMVv+AMRyGOHtO1RFBiJy/MBsgz0x4AWrT6QoEVVTyh1E39TrCUpTRI7mx9VksGX4awWASxqCYLCV4wBZHAYxA==",
+ "cpu": [
+ "arm64"
+ ],
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "engines": {
+ "node": "^18.17.0 || ^20.3.0 || >=21.0.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ },
+ "optionalDependencies": {
+ "@img/sharp-libvips-linux-arm64": "1.0.4"
+ }
+ },
+ "node_modules/@img/sharp-linux-s390x": {
+ "version": "0.33.5",
+ "resolved": "https://registry.npmjs.org/@img/sharp-linux-s390x/-/sharp-linux-s390x-0.33.5.tgz",
+ "integrity": "sha512-y/5PCd+mP4CA/sPDKl2961b+C9d+vPAveS33s6Z3zfASk2j5upL6fXVPZi7ztePZ5CuH+1kW8JtvxgbuXHRa4Q==",
+ "cpu": [
+ "s390x"
+ ],
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "engines": {
+ "node": "^18.17.0 || ^20.3.0 || >=21.0.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ },
+ "optionalDependencies": {
+ "@img/sharp-libvips-linux-s390x": "1.0.4"
+ }
+ },
+ "node_modules/@img/sharp-linux-x64": {
+ "version": "0.33.5",
+ "resolved": "https://registry.npmjs.org/@img/sharp-linux-x64/-/sharp-linux-x64-0.33.5.tgz",
+ "integrity": "sha512-opC+Ok5pRNAzuvq1AG0ar+1owsu842/Ab+4qvU879ippJBHvyY5n2mxF1izXqkPYlGuP/M556uh53jRLJmzTWA==",
+ "cpu": [
+ "x64"
+ ],
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "engines": {
+ "node": "^18.17.0 || ^20.3.0 || >=21.0.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ },
+ "optionalDependencies": {
+ "@img/sharp-libvips-linux-x64": "1.0.4"
+ }
+ },
+ "node_modules/@img/sharp-linuxmusl-arm64": {
+ "version": "0.33.5",
+ "resolved": "https://registry.npmjs.org/@img/sharp-linuxmusl-arm64/-/sharp-linuxmusl-arm64-0.33.5.tgz",
+ "integrity": "sha512-XrHMZwGQGvJg2V/oRSUfSAfjfPxO+4DkiRh6p2AFjLQztWUuY/o8Mq0eMQVIY7HJ1CDQUJlxGGZRw1a5bqmd1g==",
+ "cpu": [
+ "arm64"
+ ],
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "engines": {
+ "node": "^18.17.0 || ^20.3.0 || >=21.0.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ },
+ "optionalDependencies": {
+ "@img/sharp-libvips-linuxmusl-arm64": "1.0.4"
+ }
+ },
+ "node_modules/@img/sharp-linuxmusl-x64": {
+ "version": "0.33.5",
+ "resolved": "https://registry.npmjs.org/@img/sharp-linuxmusl-x64/-/sharp-linuxmusl-x64-0.33.5.tgz",
+ "integrity": "sha512-WT+d/cgqKkkKySYmqoZ8y3pxx7lx9vVejxW/W4DOFMYVSkErR+w7mf2u8m/y4+xHe7yY9DAXQMWQhpnMuFfScw==",
+ "cpu": [
+ "x64"
+ ],
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "engines": {
+ "node": "^18.17.0 || ^20.3.0 || >=21.0.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ },
+ "optionalDependencies": {
+ "@img/sharp-libvips-linuxmusl-x64": "1.0.4"
+ }
+ },
+ "node_modules/@img/sharp-wasm32": {
+ "version": "0.33.5",
+ "resolved": "https://registry.npmjs.org/@img/sharp-wasm32/-/sharp-wasm32-0.33.5.tgz",
+ "integrity": "sha512-ykUW4LVGaMcU9lu9thv85CbRMAwfeadCJHRsg2GmeRa/cJxsVY9Rbd57JcMxBkKHag5U/x7TSBpScF4U8ElVzg==",
+ "cpu": [
+ "wasm32"
+ ],
+ "optional": true,
+ "dependencies": {
+ "@emnapi/runtime": "^1.2.0"
+ },
+ "engines": {
+ "node": "^18.17.0 || ^20.3.0 || >=21.0.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ }
+ },
+ "node_modules/@img/sharp-win32-ia32": {
+ "version": "0.33.5",
+ "resolved": "https://registry.npmjs.org/@img/sharp-win32-ia32/-/sharp-win32-ia32-0.33.5.tgz",
+ "integrity": "sha512-T36PblLaTwuVJ/zw/LaH0PdZkRz5rd3SmMHX8GSmR7vtNSP5Z6bQkExdSK7xGWyxLw4sUknBuugTelgw2faBbQ==",
+ "cpu": [
+ "ia32"
+ ],
+ "optional": true,
+ "os": [
+ "win32"
+ ],
+ "engines": {
+ "node": "^18.17.0 || ^20.3.0 || >=21.0.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ }
+ },
+ "node_modules/@img/sharp-win32-x64": {
+ "version": "0.33.5",
+ "resolved": "https://registry.npmjs.org/@img/sharp-win32-x64/-/sharp-win32-x64-0.33.5.tgz",
+ "integrity": "sha512-MpY/o8/8kj+EcnxwvrP4aTJSWw/aZ7JIGR4aBeZkZw5B7/Jn+tY9/VNwtcoGmdT7GfggGIU4kygOMSbYnOrAbg==",
+ "cpu": [
+ "x64"
+ ],
+ "optional": true,
+ "os": [
+ "win32"
+ ],
+ "engines": {
+ "node": "^18.17.0 || ^20.3.0 || >=21.0.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ }
+ },
+ "node_modules/@isaacs/cliui": {
+ "version": "8.0.2",
+ "resolved": "https://registry.npmjs.org/@isaacs/cliui/-/cliui-8.0.2.tgz",
+ "integrity": "sha512-O8jcjabXaleOG9DQ0+ARXWZBTfnP4WNAqzuiJK7ll44AmxGKv/J2M4TPjxjY3znBCfvBXFzucm1twdyFybFqEA==",
+ "dependencies": {
+ "string-width": "^5.1.2",
+ "string-width-cjs": "npm:string-width@^4.2.0",
+ "strip-ansi": "^7.0.1",
+ "strip-ansi-cjs": "npm:strip-ansi@^6.0.1",
+ "wrap-ansi": "^8.1.0",
+ "wrap-ansi-cjs": "npm:wrap-ansi@^7.0.0"
+ },
+ "engines": {
+ "node": ">=12"
+ }
+ },
+ "node_modules/@isaacs/cliui/node_modules/ansi-regex": {
+ "version": "6.0.1",
+ "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-6.0.1.tgz",
+ "integrity": "sha512-n5M855fKb2SsfMIiFFoVrABHJC8QtHwVx+mHWP3QcEqBHYienj5dHSgjbxtC0WEZXYt4wcD6zrQElDPhFuZgfA==",
+ "engines": {
+ "node": ">=12"
+ },
+ "funding": {
+ "url": "https://github.com/chalk/ansi-regex?sponsor=1"
+ }
+ },
+ "node_modules/@isaacs/cliui/node_modules/ansi-styles": {
+ "version": "6.2.1",
+ "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-6.2.1.tgz",
+ "integrity": "sha512-bN798gFfQX+viw3R7yrGWRqnrN2oRkEkUjjl4JNn4E8GxxbjtG3FbrEIIY3l8/hrwUwIeCZvi4QuOTP4MErVug==",
+ "engines": {
+ "node": ">=12"
+ },
+ "funding": {
+ "url": "https://github.com/chalk/ansi-styles?sponsor=1"
+ }
+ },
+ "node_modules/@isaacs/cliui/node_modules/emoji-regex": {
+ "version": "9.2.2",
+ "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-9.2.2.tgz",
+ "integrity": "sha512-L18DaJsXSUk2+42pv8mLs5jJT2hqFkFE4j21wOmgbUqsZ2hL72NsUU785g9RXgo3s0ZNgVl42TiHp3ZtOv/Vyg=="
+ },
+ "node_modules/@isaacs/cliui/node_modules/string-width": {
+ "version": "5.1.2",
+ "resolved": "https://registry.npmjs.org/string-width/-/string-width-5.1.2.tgz",
+ "integrity": "sha512-HnLOCR3vjcY8beoNLtcjZ5/nxn2afmME6lhrDrebokqMap+XbeW8n9TXpPDOqdGK5qcI3oT0GKTW6wC7EMiVqA==",
+ "dependencies": {
+ "eastasianwidth": "^0.2.0",
+ "emoji-regex": "^9.2.2",
+ "strip-ansi": "^7.0.1"
+ },
+ "engines": {
+ "node": ">=12"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/sindresorhus"
+ }
+ },
+ "node_modules/@isaacs/cliui/node_modules/strip-ansi": {
+ "version": "7.1.0",
+ "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-7.1.0.tgz",
+ "integrity": "sha512-iq6eVVI64nQQTRYq2KtEg2d2uU7LElhTJwsH4YzIHZshxlgZms/wIc4VoDQTlG/IvVIrBKG06CrZnp0qv7hkcQ==",
+ "dependencies": {
+ "ansi-regex": "^6.0.1"
+ },
+ "engines": {
+ "node": ">=12"
+ },
+ "funding": {
+ "url": "https://github.com/chalk/strip-ansi?sponsor=1"
+ }
+ },
+ "node_modules/@isaacs/cliui/node_modules/wrap-ansi": {
+ "version": "8.1.0",
+ "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-8.1.0.tgz",
+ "integrity": "sha512-si7QWI6zUMq56bESFvagtmzMdGOtoxfR+Sez11Mobfc7tm+VkUckk9bW2UeffTGVUbOksxmSw0AA2gs8g71NCQ==",
+ "dependencies": {
+ "ansi-styles": "^6.1.0",
+ "string-width": "^5.0.1",
+ "strip-ansi": "^7.0.1"
+ },
+ "engines": {
+ "node": ">=12"
+ },
+ "funding": {
+ "url": "https://github.com/chalk/wrap-ansi?sponsor=1"
+ }
+ },
+ "node_modules/@isaacs/fs-minipass": {
+ "version": "4.0.1",
+ "resolved": "https://registry.npmjs.org/@isaacs/fs-minipass/-/fs-minipass-4.0.1.tgz",
+ "integrity": "sha512-wgm9Ehl2jpeqP3zw/7mo3kRHFp5MEDhqAdwy1fTGkHAwnkGOVsgpvQhL8B5n1qlb01jV3n/bI0ZfZp5lWA1k4w==",
+ "dependencies": {
+ "minipass": "^7.0.4"
+ },
+ "engines": {
+ "node": ">=18.0.0"
+ }
+ },
"node_modules/@istanbuljs/load-nyc-config": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/@istanbuljs/load-nyc-config/-/load-nyc-config-1.1.0.tgz",
@@ -1116,14 +1566,14 @@
}
},
"node_modules/@jridgewell/gen-mapping": {
- "version": "0.3.3",
- "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.3.tgz",
- "integrity": "sha512-HLhSWOLRi875zjjMG/r+Nv0oCW8umGb0BgEhyX3dDX3egwZtB8PqLnjz3yedt8R5StBrzcg4aBpnh8UA9D1BoQ==",
+ "version": "0.3.5",
+ "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.5.tgz",
+ "integrity": "sha512-IzL8ZoEDIBRWEzlCcRhOaCupYyN5gdIK+Q6fbFdPDg6HqX6jpkItn7DFIpW9LQzXG6Df9sA7+OKnq0qlz/GaQg==",
"dev": true,
"dependencies": {
- "@jridgewell/set-array": "^1.0.1",
+ "@jridgewell/set-array": "^1.2.1",
"@jridgewell/sourcemap-codec": "^1.4.10",
- "@jridgewell/trace-mapping": "^0.3.9"
+ "@jridgewell/trace-mapping": "^0.3.24"
},
"engines": {
"node": ">=6.0.0"
@@ -1139,22 +1589,22 @@
}
},
"node_modules/@jridgewell/set-array": {
- "version": "1.1.2",
- "resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.1.2.tgz",
- "integrity": "sha512-xnkseuNADM0gt2bs+BvhO0p78Mk762YnZdsuzFV018NoG1Sj1SCQvpSqa7XUaTam5vAGasABV9qXASMKnFMwMw==",
+ "version": "1.2.1",
+ "resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.2.1.tgz",
+ "integrity": "sha512-R8gLRTZeyp03ymzP/6Lil/28tGeGEzhx1q2k703KGWRAI1VdvPIXdG70VJc2pAMw3NA6JKL5hhFu1sJX0Mnn/A==",
"dev": true,
"engines": {
"node": ">=6.0.0"
}
},
"node_modules/@jridgewell/source-map": {
- "version": "0.3.3",
- "resolved": "https://registry.npmjs.org/@jridgewell/source-map/-/source-map-0.3.3.tgz",
- "integrity": "sha512-b+fsZXeLYi9fEULmfBrhxn4IrPlINf8fiNarzTof004v3lFdntdwa9PF7vFJqm3mg7s+ScJMxXaE3Acp1irZcg==",
+ "version": "0.3.6",
+ "resolved": "https://registry.npmjs.org/@jridgewell/source-map/-/source-map-0.3.6.tgz",
+ "integrity": "sha512-1ZJTZebgqllO79ue2bm3rIGud/bOe0pP5BjSRCRxxYkEZS8STV7zN84UBbiYu7jy+eCKSnVIUgoWWE/tt+shMQ==",
"dev": true,
"dependencies": {
- "@jridgewell/gen-mapping": "^0.3.0",
- "@jridgewell/trace-mapping": "^0.3.9"
+ "@jridgewell/gen-mapping": "^0.3.5",
+ "@jridgewell/trace-mapping": "^0.3.25"
}
},
"node_modules/@jridgewell/sourcemap-codec": {
@@ -1164,21 +1614,15 @@
"dev": true
},
"node_modules/@jridgewell/trace-mapping": {
- "version": "0.3.18",
- "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.18.tgz",
- "integrity": "sha512-w+niJYzMHdd7USdiH2U6869nqhD2nbfZXND5Yp93qIbEmnDNk7PD48o+YchRVpzMU7M6jVCbenTR7PA1FLQ9pA==",
+ "version": "0.3.25",
+ "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.25.tgz",
+ "integrity": "sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==",
"dev": true,
"dependencies": {
- "@jridgewell/resolve-uri": "3.1.0",
- "@jridgewell/sourcemap-codec": "1.4.14"
+ "@jridgewell/resolve-uri": "^3.1.0",
+ "@jridgewell/sourcemap-codec": "^1.4.14"
}
},
- "node_modules/@jridgewell/trace-mapping/node_modules/@jridgewell/sourcemap-codec": {
- "version": "1.4.14",
- "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.4.14.tgz",
- "integrity": "sha512-XPSJHWmi394fuUuzDnGz1wiKqWfo1yXecHQMRf2l6hztTO+nPru658AyDngaBe7isIxEkRsPR3FZh+s7iVa4Uw==",
- "dev": true
- },
"node_modules/@jsdoc/salty": {
"version": "0.2.5",
"resolved": "https://registry.npmjs.org/@jsdoc/salty/-/salty-0.2.5.tgz",
@@ -1206,39 +1650,13 @@
"semver": "bin/semver.js"
}
},
- "node_modules/@nodelib/fs.scandir": {
- "version": "2.1.5",
- "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz",
- "integrity": "sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==",
- "dev": true,
- "dependencies": {
- "@nodelib/fs.stat": "2.0.5",
- "run-parallel": "^1.1.9"
- },
- "engines": {
- "node": ">= 8"
- }
- },
- "node_modules/@nodelib/fs.stat": {
- "version": "2.0.5",
- "resolved": "https://registry.npmjs.org/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz",
- "integrity": "sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A==",
- "dev": true,
- "engines": {
- "node": ">= 8"
- }
- },
- "node_modules/@nodelib/fs.walk": {
- "version": "1.2.8",
- "resolved": "https://registry.npmjs.org/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz",
- "integrity": "sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg==",
- "dev": true,
- "dependencies": {
- "@nodelib/fs.scandir": "2.1.5",
- "fastq": "^1.6.0"
- },
+ "node_modules/@pkgjs/parseargs": {
+ "version": "0.11.0",
+ "resolved": "https://registry.npmjs.org/@pkgjs/parseargs/-/parseargs-0.11.0.tgz",
+ "integrity": "sha512-+1VkjdD0QBLPodGrJUeqarH8VAIvQODIbwh9XpP5Syisf7YoQgsJKPNFoqqLQlu+VQ/tVSshMR6loPMn8U+dPg==",
+ "optional": true,
"engines": {
- "node": ">= 8"
+ "node": ">=14"
}
},
"node_modules/@protobufjs/aspromise": {
@@ -1398,30 +1816,10 @@
"@types/node": "*"
}
},
- "node_modules/@types/eslint": {
- "version": "8.37.0",
- "resolved": "https://registry.npmjs.org/@types/eslint/-/eslint-8.37.0.tgz",
- "integrity": "sha512-Piet7dG2JBuDIfohBngQ3rCt7MgO9xCO4xIMKxBThCq5PNRB91IjlJ10eJVwfoNtvTErmxLzwBZ7rHZtbOMmFQ==",
- "dev": true,
- "dependencies": {
- "@types/estree": "*",
- "@types/json-schema": "*"
- }
- },
- "node_modules/@types/eslint-scope": {
- "version": "3.7.4",
- "resolved": "https://registry.npmjs.org/@types/eslint-scope/-/eslint-scope-3.7.4.tgz",
- "integrity": "sha512-9K4zoImiZc3HlIp6AVUDE4CWYx22a+lhSZMYNpbjW04+YF0KWj4pJXnEMjdnFTiQibFFmElcsasJXDbdI/EPhA==",
- "dev": true,
- "dependencies": {
- "@types/eslint": "*",
- "@types/estree": "*"
- }
- },
"node_modules/@types/estree": {
- "version": "1.0.1",
- "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.1.tgz",
- "integrity": "sha512-LG4opVs2ANWZ1TJoKc937iMmNstM/d0ae1vNbnBvBhqCSezgVUOzcLCqbI5elV8Vy6WKwKjaqR+zO9VKirBBCA==",
+ "version": "1.0.5",
+ "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.5.tgz",
+ "integrity": "sha512-/kYRxGDLWzHOB7q+wtSUQlFrtcdUccpfy+X+9iMBpHK8QLLhx2wIPYuS5DYtR9Wa/YlZAbIovy7qVdB1Aq6Lyw==",
"dev": true
},
"node_modules/@types/express": {
@@ -1621,151 +2019,157 @@
"dev": true
},
"node_modules/@webassemblyjs/ast": {
- "version": "1.11.5",
- "resolved": "https://registry.npmjs.org/@webassemblyjs/ast/-/ast-1.11.5.tgz",
- "integrity": "sha512-LHY/GSAZZRpsNQH+/oHqhRQ5FT7eoULcBqgfyTB5nQHogFnK3/7QoN7dLnwSE/JkUAF0SrRuclT7ODqMFtWxxQ==",
+ "version": "1.12.1",
+ "resolved": "https://registry.npmjs.org/@webassemblyjs/ast/-/ast-1.12.1.tgz",
+ "integrity": "sha512-EKfMUOPRRUTy5UII4qJDGPpqfwjOmZ5jeGFwid9mnoqIFK+e0vqoi1qH56JpmZSzEL53jKnNzScdmftJyG5xWg==",
"dev": true,
"dependencies": {
- "@webassemblyjs/helper-numbers": "1.11.5",
- "@webassemblyjs/helper-wasm-bytecode": "1.11.5"
+ "@webassemblyjs/helper-numbers": "1.11.6",
+ "@webassemblyjs/helper-wasm-bytecode": "1.11.6"
}
},
"node_modules/@webassemblyjs/floating-point-hex-parser": {
- "version": "1.11.5",
- "resolved": "https://registry.npmjs.org/@webassemblyjs/floating-point-hex-parser/-/floating-point-hex-parser-1.11.5.tgz",
- "integrity": "sha512-1j1zTIC5EZOtCplMBG/IEwLtUojtwFVwdyVMbL/hwWqbzlQoJsWCOavrdnLkemwNoC/EOwtUFch3fuo+cbcXYQ==",
+ "version": "1.11.6",
+ "resolved": "https://registry.npmjs.org/@webassemblyjs/floating-point-hex-parser/-/floating-point-hex-parser-1.11.6.tgz",
+ "integrity": "sha512-ejAj9hfRJ2XMsNHk/v6Fu2dGS+i4UaXBXGemOfQ/JfQ6mdQg/WXtwleQRLLS4OvfDhv8rYnVwH27YJLMyYsxhw==",
"dev": true
},
"node_modules/@webassemblyjs/helper-api-error": {
- "version": "1.11.5",
- "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-api-error/-/helper-api-error-1.11.5.tgz",
- "integrity": "sha512-L65bDPmfpY0+yFrsgz8b6LhXmbbs38OnwDCf6NpnMUYqa+ENfE5Dq9E42ny0qz/PdR0LJyq/T5YijPnU8AXEpA==",
+ "version": "1.11.6",
+ "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-api-error/-/helper-api-error-1.11.6.tgz",
+ "integrity": "sha512-o0YkoP4pVu4rN8aTJgAyj9hC2Sv5UlkzCHhxqWj8butaLvnpdc2jOwh4ewE6CX0txSfLn/UYaV/pheS2Txg//Q==",
"dev": true
},
"node_modules/@webassemblyjs/helper-buffer": {
- "version": "1.11.5",
- "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-buffer/-/helper-buffer-1.11.5.tgz",
- "integrity": "sha512-fDKo1gstwFFSfacIeH5KfwzjykIE6ldh1iH9Y/8YkAZrhmu4TctqYjSh7t0K2VyDSXOZJ1MLhht/k9IvYGcIxg==",
+ "version": "1.12.1",
+ "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-buffer/-/helper-buffer-1.12.1.tgz",
+ "integrity": "sha512-nzJwQw99DNDKr9BVCOZcLuJJUlqkJh+kVzVl6Fmq/tI5ZtEyWT1KZMyOXltXLZJmDtvLCDgwsyrkohEtopTXCw==",
"dev": true
},
"node_modules/@webassemblyjs/helper-numbers": {
- "version": "1.11.5",
- "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-numbers/-/helper-numbers-1.11.5.tgz",
- "integrity": "sha512-DhykHXM0ZABqfIGYNv93A5KKDw/+ywBFnuWybZZWcuzWHfbp21wUfRkbtz7dMGwGgT4iXjWuhRMA2Mzod6W4WA==",
+ "version": "1.11.6",
+ "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-numbers/-/helper-numbers-1.11.6.tgz",
+ "integrity": "sha512-vUIhZ8LZoIWHBohiEObxVm6hwP034jwmc9kuq5GdHZH0wiLVLIPcMCdpJzG4C11cHoQ25TFIQj9kaVADVX7N3g==",
"dev": true,
"dependencies": {
- "@webassemblyjs/floating-point-hex-parser": "1.11.5",
- "@webassemblyjs/helper-api-error": "1.11.5",
+ "@webassemblyjs/floating-point-hex-parser": "1.11.6",
+ "@webassemblyjs/helper-api-error": "1.11.6",
"@xtuc/long": "4.2.2"
}
},
"node_modules/@webassemblyjs/helper-wasm-bytecode": {
- "version": "1.11.5",
- "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-bytecode/-/helper-wasm-bytecode-1.11.5.tgz",
- "integrity": "sha512-oC4Qa0bNcqnjAowFn7MPCETQgDYytpsfvz4ujZz63Zu/a/v71HeCAAmZsgZ3YVKec3zSPYytG3/PrRCqbtcAvA==",
+ "version": "1.11.6",
+ "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-bytecode/-/helper-wasm-bytecode-1.11.6.tgz",
+ "integrity": "sha512-sFFHKwcmBprO9e7Icf0+gddyWYDViL8bpPjJJl0WHxCdETktXdmtWLGVzoHbqUcY4Be1LkNfwTmXOJUFZYSJdA==",
"dev": true
},
"node_modules/@webassemblyjs/helper-wasm-section": {
- "version": "1.11.5",
- "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-section/-/helper-wasm-section-1.11.5.tgz",
- "integrity": "sha512-uEoThA1LN2NA+K3B9wDo3yKlBfVtC6rh0i4/6hvbz071E8gTNZD/pT0MsBf7MeD6KbApMSkaAK0XeKyOZC7CIA==",
+ "version": "1.12.1",
+ "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-section/-/helper-wasm-section-1.12.1.tgz",
+ "integrity": "sha512-Jif4vfB6FJlUlSbgEMHUyk1j234GTNG9dBJ4XJdOySoj518Xj0oGsNi59cUQF4RRMS9ouBUxDDdyBVfPTypa5g==",
"dev": true,
"dependencies": {
- "@webassemblyjs/ast": "1.11.5",
- "@webassemblyjs/helper-buffer": "1.11.5",
- "@webassemblyjs/helper-wasm-bytecode": "1.11.5",
- "@webassemblyjs/wasm-gen": "1.11.5"
+ "@webassemblyjs/ast": "1.12.1",
+ "@webassemblyjs/helper-buffer": "1.12.1",
+ "@webassemblyjs/helper-wasm-bytecode": "1.11.6",
+ "@webassemblyjs/wasm-gen": "1.12.1"
}
},
"node_modules/@webassemblyjs/ieee754": {
- "version": "1.11.5",
- "resolved": "https://registry.npmjs.org/@webassemblyjs/ieee754/-/ieee754-1.11.5.tgz",
- "integrity": "sha512-37aGq6qVL8A8oPbPrSGMBcp38YZFXcHfiROflJn9jxSdSMMM5dS5P/9e2/TpaJuhE+wFrbukN2WI6Hw9MH5acg==",
+ "version": "1.11.6",
+ "resolved": "https://registry.npmjs.org/@webassemblyjs/ieee754/-/ieee754-1.11.6.tgz",
+ "integrity": "sha512-LM4p2csPNvbij6U1f19v6WR56QZ8JcHg3QIJTlSwzFcmx6WSORicYj6I63f9yU1kEUtrpG+kjkiIAkevHpDXrg==",
"dev": true,
"dependencies": {
"@xtuc/ieee754": "^1.2.0"
}
},
"node_modules/@webassemblyjs/leb128": {
- "version": "1.11.5",
- "resolved": "https://registry.npmjs.org/@webassemblyjs/leb128/-/leb128-1.11.5.tgz",
- "integrity": "sha512-ajqrRSXaTJoPW+xmkfYN6l8VIeNnR4vBOTQO9HzR7IygoCcKWkICbKFbVTNMjMgMREqXEr0+2M6zukzM47ZUfQ==",
+ "version": "1.11.6",
+ "resolved": "https://registry.npmjs.org/@webassemblyjs/leb128/-/leb128-1.11.6.tgz",
+ "integrity": "sha512-m7a0FhE67DQXgouf1tbN5XQcdWoNgaAuoULHIfGFIEVKA6tu/edls6XnIlkmS6FrXAquJRPni3ZZKjw6FSPjPQ==",
"dev": true,
"dependencies": {
"@xtuc/long": "4.2.2"
}
},
"node_modules/@webassemblyjs/utf8": {
- "version": "1.11.5",
- "resolved": "https://registry.npmjs.org/@webassemblyjs/utf8/-/utf8-1.11.5.tgz",
- "integrity": "sha512-WiOhulHKTZU5UPlRl53gHR8OxdGsSOxqfpqWeA2FmcwBMaoEdz6b2x2si3IwC9/fSPLfe8pBMRTHVMk5nlwnFQ==",
+ "version": "1.11.6",
+ "resolved": "https://registry.npmjs.org/@webassemblyjs/utf8/-/utf8-1.11.6.tgz",
+ "integrity": "sha512-vtXf2wTQ3+up9Zsg8sa2yWiQpzSsMyXj0qViVP6xKGCUT8p8YJ6HqI7l5eCnWx1T/FYdsv07HQs2wTFbbof/RA==",
"dev": true
},
"node_modules/@webassemblyjs/wasm-edit": {
- "version": "1.11.5",
- "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-edit/-/wasm-edit-1.11.5.tgz",
- "integrity": "sha512-C0p9D2fAu3Twwqvygvf42iGCQ4av8MFBLiTb+08SZ4cEdwzWx9QeAHDo1E2k+9s/0w1DM40oflJOpkZ8jW4HCQ==",
+ "version": "1.12.1",
+ "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-edit/-/wasm-edit-1.12.1.tgz",
+ "integrity": "sha512-1DuwbVvADvS5mGnXbE+c9NfA8QRcZ6iKquqjjmR10k6o+zzsRVesil54DKexiowcFCPdr/Q0qaMgB01+SQ1u6g==",
"dev": true,
"dependencies": {
- "@webassemblyjs/ast": "1.11.5",
- "@webassemblyjs/helper-buffer": "1.11.5",
- "@webassemblyjs/helper-wasm-bytecode": "1.11.5",
- "@webassemblyjs/helper-wasm-section": "1.11.5",
- "@webassemblyjs/wasm-gen": "1.11.5",
- "@webassemblyjs/wasm-opt": "1.11.5",
- "@webassemblyjs/wasm-parser": "1.11.5",
- "@webassemblyjs/wast-printer": "1.11.5"
+ "@webassemblyjs/ast": "1.12.1",
+ "@webassemblyjs/helper-buffer": "1.12.1",
+ "@webassemblyjs/helper-wasm-bytecode": "1.11.6",
+ "@webassemblyjs/helper-wasm-section": "1.12.1",
+ "@webassemblyjs/wasm-gen": "1.12.1",
+ "@webassemblyjs/wasm-opt": "1.12.1",
+ "@webassemblyjs/wasm-parser": "1.12.1",
+ "@webassemblyjs/wast-printer": "1.12.1"
}
},
"node_modules/@webassemblyjs/wasm-gen": {
- "version": "1.11.5",
- "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-gen/-/wasm-gen-1.11.5.tgz",
- "integrity": "sha512-14vteRlRjxLK9eSyYFvw1K8Vv+iPdZU0Aebk3j6oB8TQiQYuO6hj9s4d7qf6f2HJr2khzvNldAFG13CgdkAIfA==",
+ "version": "1.12.1",
+ "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-gen/-/wasm-gen-1.12.1.tgz",
+ "integrity": "sha512-TDq4Ojh9fcohAw6OIMXqiIcTq5KUXTGRkVxbSo1hQnSy6lAM5GSdfwWeSxpAo0YzgsgF182E/U0mDNhuA0tW7w==",
"dev": true,
"dependencies": {
- "@webassemblyjs/ast": "1.11.5",
- "@webassemblyjs/helper-wasm-bytecode": "1.11.5",
- "@webassemblyjs/ieee754": "1.11.5",
- "@webassemblyjs/leb128": "1.11.5",
- "@webassemblyjs/utf8": "1.11.5"
+ "@webassemblyjs/ast": "1.12.1",
+ "@webassemblyjs/helper-wasm-bytecode": "1.11.6",
+ "@webassemblyjs/ieee754": "1.11.6",
+ "@webassemblyjs/leb128": "1.11.6",
+ "@webassemblyjs/utf8": "1.11.6"
}
},
"node_modules/@webassemblyjs/wasm-opt": {
- "version": "1.11.5",
- "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-opt/-/wasm-opt-1.11.5.tgz",
- "integrity": "sha512-tcKwlIXstBQgbKy1MlbDMlXaxpucn42eb17H29rawYLxm5+MsEmgPzeCP8B1Cl69hCice8LeKgZpRUAPtqYPgw==",
+ "version": "1.12.1",
+ "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-opt/-/wasm-opt-1.12.1.tgz",
+ "integrity": "sha512-Jg99j/2gG2iaz3hijw857AVYekZe2SAskcqlWIZXjji5WStnOpVoat3gQfT/Q5tb2djnCjBtMocY/Su1GfxPBg==",
"dev": true,
"dependencies": {
- "@webassemblyjs/ast": "1.11.5",
- "@webassemblyjs/helper-buffer": "1.11.5",
- "@webassemblyjs/wasm-gen": "1.11.5",
- "@webassemblyjs/wasm-parser": "1.11.5"
+ "@webassemblyjs/ast": "1.12.1",
+ "@webassemblyjs/helper-buffer": "1.12.1",
+ "@webassemblyjs/wasm-gen": "1.12.1",
+ "@webassemblyjs/wasm-parser": "1.12.1"
}
},
"node_modules/@webassemblyjs/wasm-parser": {
- "version": "1.11.5",
- "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-parser/-/wasm-parser-1.11.5.tgz",
- "integrity": "sha512-SVXUIwsLQlc8srSD7jejsfTU83g7pIGr2YYNb9oHdtldSxaOhvA5xwvIiWIfcX8PlSakgqMXsLpLfbbJ4cBYew==",
+ "version": "1.12.1",
+ "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-parser/-/wasm-parser-1.12.1.tgz",
+ "integrity": "sha512-xikIi7c2FHXysxXe3COrVUPSheuBtpcfhbpFj4gmu7KRLYOzANztwUU0IbsqvMqzuNK2+glRGWCEqZo1WCLyAQ==",
"dev": true,
"dependencies": {
- "@webassemblyjs/ast": "1.11.5",
- "@webassemblyjs/helper-api-error": "1.11.5",
- "@webassemblyjs/helper-wasm-bytecode": "1.11.5",
- "@webassemblyjs/ieee754": "1.11.5",
- "@webassemblyjs/leb128": "1.11.5",
- "@webassemblyjs/utf8": "1.11.5"
+ "@webassemblyjs/ast": "1.12.1",
+ "@webassemblyjs/helper-api-error": "1.11.6",
+ "@webassemblyjs/helper-wasm-bytecode": "1.11.6",
+ "@webassemblyjs/ieee754": "1.11.6",
+ "@webassemblyjs/leb128": "1.11.6",
+ "@webassemblyjs/utf8": "1.11.6"
}
},
"node_modules/@webassemblyjs/wast-printer": {
- "version": "1.11.5",
- "resolved": "https://registry.npmjs.org/@webassemblyjs/wast-printer/-/wast-printer-1.11.5.tgz",
- "integrity": "sha512-f7Pq3wvg3GSPUPzR0F6bmI89Hdb+u9WXrSKc4v+N0aV0q6r42WoF92Jp2jEorBEBRoRNXgjp53nBniDXcqZYPA==",
+ "version": "1.12.1",
+ "resolved": "https://registry.npmjs.org/@webassemblyjs/wast-printer/-/wast-printer-1.12.1.tgz",
+ "integrity": "sha512-+X4WAlOisVWQMikjbcvY2e0rwPsKQ9F688lksZhBcPycBBuii3O7m8FACbDMWDojpAqvjIncrG8J0XHKyQfVeA==",
"dev": true,
"dependencies": {
- "@webassemblyjs/ast": "1.11.5",
+ "@webassemblyjs/ast": "1.12.1",
"@xtuc/long": "4.2.2"
}
},
+ "node_modules/@webgpu/types": {
+ "version": "0.1.44",
+ "resolved": "https://registry.npmjs.org/@webgpu/types/-/types-0.1.44.tgz",
+ "integrity": "sha512-JDpYJN5E/asw84LTYhKyvPpxGnD+bAKPtpW9Ilurf7cZpxaTbxkQcGwOd7jgB9BPBrTYQ+32ufo4HiuomTjHNQ==",
+ "dev": true
+ },
"node_modules/@webpack-cli/configtest": {
"version": "2.0.1",
"resolved": "https://registry.npmjs.org/@webpack-cli/configtest/-/configtest-2.0.1.tgz",
@@ -1836,9 +2240,9 @@
}
},
"node_modules/acorn": {
- "version": "8.8.2",
- "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.8.2.tgz",
- "integrity": "sha512-xjIYgE8HBrkpd/sJqOGNspf8uHG+NOHGOw6a/Urj8taM2EXfdNAH2oFcPeIFfsv3+kz/mJrS5VuMqbNLjCa2vw==",
+ "version": "8.12.1",
+ "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.12.1.tgz",
+ "integrity": "sha512-tcpGyI9zbizT9JbV6oYE477V6mTlXvvi0T0G3SNIYE2apm/G5huBa1+K89VGeovbg+jycCrfhl3ADxErOuO6Jg==",
"dev": true,
"bin": {
"acorn": "bin/acorn"
@@ -1847,10 +2251,10 @@
"node": ">=0.4.0"
}
},
- "node_modules/acorn-import-assertions": {
- "version": "1.8.0",
- "resolved": "https://registry.npmjs.org/acorn-import-assertions/-/acorn-import-assertions-1.8.0.tgz",
- "integrity": "sha512-m7VZ3jwz4eK6A4Vtt8Ew1/mNbP24u0FhdyfA7fSvnJR6LMdfOYnmuIrrJAgrYfYJ10F/otaHTtrtrtmHdMNzEw==",
+ "node_modules/acorn-import-attributes": {
+ "version": "1.9.5",
+ "resolved": "https://registry.npmjs.org/acorn-import-attributes/-/acorn-import-attributes-1.9.5.tgz",
+ "integrity": "sha512-n02Vykv5uA3eHGM/Z2dQrcD56kL8TyDb2p1+0P83PClMnC/nc+anbQRhIOWnSq4Ke/KvDPrY3C9hDtC/A3eHnQ==",
"dev": true,
"peerDependencies": {
"acorn": "^8"
@@ -1972,7 +2376,6 @@
"version": "5.0.1",
"resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz",
"integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==",
- "dev": true,
"engines": {
"node": ">=8"
}
@@ -1981,7 +2384,6 @@
"version": "4.3.0",
"resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz",
"integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==",
- "dev": true,
"dependencies": {
"color-convert": "^2.0.1"
},
@@ -2026,11 +2428,6 @@
"integrity": "sha512-hNfzcOV8W4NdualtqBFPyVO+54DSJuZGY9qT4pRroB6S9e3iiido2ISIC5h9R2sPJ8H3FHCIiEnsv1lPXO3KtQ==",
"dev": true
},
- "node_modules/b4a": {
- "version": "1.6.4",
- "resolved": "https://registry.npmjs.org/b4a/-/b4a-1.6.4.tgz",
- "integrity": "sha512-fpWrvyVHEKyeEvbKZTVOeZF3VSKKWtJxFIxX/jaVPf+cLbGUSitjb49pHLqPV2BUNNZ0LcoeEGfE/YCpyDYHIw=="
- },
"node_modules/babel-jest": {
"version": "29.6.1",
"resolved": "https://registry.npmjs.org/babel-jest/-/babel-jest-29.6.1.tgz",
@@ -2134,27 +2531,7 @@
"node_modules/balanced-match": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz",
- "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==",
- "dev": true
- },
- "node_modules/base64-js": {
- "version": "1.5.1",
- "resolved": "https://registry.npmjs.org/base64-js/-/base64-js-1.5.1.tgz",
- "integrity": "sha512-AKpaYlHn8t4SVbOHCy+b5+KKgvR4vrsD8vbvrbiQJps7fKDTkjkDry6ji0rUJjC0kzbNePLwzxq8iypo41qeWA==",
- "funding": [
- {
- "type": "github",
- "url": "https://github.com/sponsors/feross"
- },
- {
- "type": "patreon",
- "url": "https://www.patreon.com/feross"
- },
- {
- "type": "consulting",
- "url": "https://feross.org/support"
- }
- ]
+ "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw=="
},
"node_modules/batch": {
"version": "0.6.1",
@@ -2171,16 +2548,6 @@
"node": ">=8"
}
},
- "node_modules/bl": {
- "version": "4.1.0",
- "resolved": "https://registry.npmjs.org/bl/-/bl-4.1.0.tgz",
- "integrity": "sha512-1W07cM9gS6DcLperZfFSj+bWLtaPGSOHWhPiGzXmvVJbRLdG82sH/Kn8EtW1VqWVA54AKf2h5k5BbnIbwF3h6w==",
- "dependencies": {
- "buffer": "^5.5.0",
- "inherits": "^2.0.4",
- "readable-stream": "^3.4.0"
- }
- },
"node_modules/bluebird": {
"version": "3.7.2",
"resolved": "https://registry.npmjs.org/bluebird/-/bluebird-3.7.2.tgz",
@@ -2188,9 +2555,9 @@
"dev": true
},
"node_modules/body-parser": {
- "version": "1.20.2",
- "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.2.tgz",
- "integrity": "sha512-ml9pReCu3M61kGlqoTm2umSXTlRTuGTx0bfYj+uIUKKYycG5NtSbeetV3faSU6R7ajOPw0g/J1PvK4qNy7s5bA==",
+ "version": "1.20.3",
+ "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.3.tgz",
+ "integrity": "sha512-7rAxByjUMqQ3/bHJy7D6OGXvx/MMc4IqBn/X0fcM1QUcAItpZrBEYhWGem+tzXH90c+G01ypMcYJBO9Y30203g==",
"dev": true,
"dependencies": {
"bytes": "3.1.2",
@@ -2201,7 +2568,7 @@
"http-errors": "2.0.0",
"iconv-lite": "0.4.24",
"on-finished": "2.4.1",
- "qs": "6.11.0",
+ "qs": "6.13.0",
"raw-body": "2.5.2",
"type-is": "~1.6.18",
"unpipe": "1.0.0"
@@ -2255,9 +2622,9 @@
}
},
"node_modules/browserslist": {
- "version": "4.21.9",
- "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.21.9.tgz",
- "integrity": "sha512-M0MFoZzbUrRU4KNfCrDLnvyE7gub+peetoTid3TBIqtunaDJyXlwhakT+/VkvSXcfIzFfK/nkCs4nmyTmxdNSg==",
+ "version": "4.23.3",
+ "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.23.3.tgz",
+ "integrity": "sha512-btwCFJVjI4YWDNfau8RhZ+B1Q/VLoUITrm3RlP6y1tYGWIOa+InuYiRGXUBXo8nA1qKmHMyLB/iVQg5TT4eFoA==",
"dev": true,
"funding": [
{
@@ -2274,10 +2641,10 @@
}
],
"dependencies": {
- "caniuse-lite": "^1.0.30001503",
- "electron-to-chromium": "^1.4.431",
- "node-releases": "^2.0.12",
- "update-browserslist-db": "^1.0.11"
+ "caniuse-lite": "^1.0.30001646",
+ "electron-to-chromium": "^1.5.4",
+ "node-releases": "^2.0.18",
+ "update-browserslist-db": "^1.1.0"
},
"bin": {
"browserslist": "cli.js"
@@ -2295,29 +2662,6 @@
"node-int64": "^0.4.0"
}
},
- "node_modules/buffer": {
- "version": "5.7.1",
- "resolved": "https://registry.npmjs.org/buffer/-/buffer-5.7.1.tgz",
- "integrity": "sha512-EHcyIPBQ4BSGlvjB16k5KgAJ27CIsHY/2JBmCRReo48y9rQ3MaUzWX3KVlBa4U7MyX02HdVj0K7C3WaB3ju7FQ==",
- "funding": [
- {
- "type": "github",
- "url": "https://github.com/sponsors/feross"
- },
- {
- "type": "patreon",
- "url": "https://www.patreon.com/feross"
- },
- {
- "type": "consulting",
- "url": "https://feross.org/support"
- }
- ],
- "dependencies": {
- "base64-js": "^1.3.1",
- "ieee754": "^1.1.13"
- }
- },
"node_modules/buffer-from": {
"version": "1.1.2",
"resolved": "https://registry.npmjs.org/buffer-from/-/buffer-from-1.1.2.tgz",
@@ -2394,9 +2738,9 @@
}
},
"node_modules/caniuse-lite": {
- "version": "1.0.30001513",
- "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001513.tgz",
- "integrity": "sha512-pnjGJo7SOOjAGytZZ203Em95MRM8Cr6jhCXNF/FAXTpCTRTECnqQWLpiTRqrFtdYcth8hf4WECUpkezuYsMVww==",
+ "version": "1.0.30001653",
+ "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001653.tgz",
+ "integrity": "sha512-XGWQVB8wFQ2+9NZwZ10GxTYC5hk0Fa+q8cSkr0tgvMhYhMHP/QC+WTgrePMDBWiWc/pV+1ik82Al20XOK25Gcw==",
"dev": true,
"funding": [
{
@@ -2490,9 +2834,12 @@
}
},
"node_modules/chownr": {
- "version": "1.1.4",
- "resolved": "https://registry.npmjs.org/chownr/-/chownr-1.1.4.tgz",
- "integrity": "sha512-jJ0bqzaylmJtVnNgzTeSOs8DPavpbYgEr/b0YL8/2GO3xJEhInFmhKMUnEJQjZumK7KXGFhUy89PrsJWlakBVg=="
+ "version": "3.0.0",
+ "resolved": "https://registry.npmjs.org/chownr/-/chownr-3.0.0.tgz",
+ "integrity": "sha512-+IxzY9BZOQd/XuYPRmrvEVjF/nqj5kgT4kEq7VofrDoM1MxoRjEWkrCC3EtLi59TVawxTAn+orJwFQcrqEN1+g==",
+ "engines": {
+ "node": ">=18"
+ }
},
"node_modules/chrome-trace-event": {
"version": "1.0.3",
@@ -2814,119 +3161,30 @@
"integrity": "sha512-nTjqfcBFEipKdXCv4YDQWCfmcLZKm81ldF0pAopTvyrFGVbcR6P/VAAd5G7N+0tTr8QqiU0tFadD6FK4NtJwOA==",
"dev": true,
"engines": {
- "node": ">= 0.6"
- }
- },
- "node_modules/convert-source-map": {
- "version": "2.0.0",
- "resolved": "https://registry.npmjs.org/convert-source-map/-/convert-source-map-2.0.0.tgz",
- "integrity": "sha512-Kvp459HrV2FEJ1CAsi1Ku+MY3kasH19TFykTz2xWmMeq6bk2NU3XXvfJ+Q61m0xktWwt+1HSYf3JZsTms3aRJg==",
- "dev": true
- },
- "node_modules/cookie": {
- "version": "0.6.0",
- "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.6.0.tgz",
- "integrity": "sha512-U71cyTamuh1CRNCfpGY6to28lxvNwPG4Guz/EVjgf3Jmzv0vlDp1atT9eS5dDjMYHucpHbWns6Lwf3BKz6svdw==",
- "dev": true,
- "engines": {
- "node": ">= 0.6"
- }
- },
- "node_modules/cookie-signature": {
- "version": "1.0.6",
- "resolved": "https://registry.npmjs.org/cookie-signature/-/cookie-signature-1.0.6.tgz",
- "integrity": "sha512-QADzlaHc8icV8I7vbaJXJwod9HWYp8uCqf1xa4OfNu1T7JVxQIrUgOWtHdNDtPiywmFbiS12VjotIXLrKM3orQ==",
- "dev": true
- },
- "node_modules/copy-webpack-plugin": {
- "version": "11.0.0",
- "resolved": "https://registry.npmjs.org/copy-webpack-plugin/-/copy-webpack-plugin-11.0.0.tgz",
- "integrity": "sha512-fX2MWpamkW0hZxMEg0+mYnA40LTosOSa5TqZ9GYIBzyJa9C3QUaMPSE2xAi/buNr8u89SfD9wHSQVBzrRa/SOQ==",
- "dev": true,
- "dependencies": {
- "fast-glob": "^3.2.11",
- "glob-parent": "^6.0.1",
- "globby": "^13.1.1",
- "normalize-path": "^3.0.0",
- "schema-utils": "^4.0.0",
- "serialize-javascript": "^6.0.0"
- },
- "engines": {
- "node": ">= 14.15.0"
- },
- "funding": {
- "type": "opencollective",
- "url": "https://opencollective.com/webpack"
- },
- "peerDependencies": {
- "webpack": "^5.1.0"
- }
- },
- "node_modules/copy-webpack-plugin/node_modules/ajv": {
- "version": "8.12.0",
- "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.12.0.tgz",
- "integrity": "sha512-sRu1kpcO9yLtYxBKvqfTeh9KzZEwO3STyX1HT+4CaDzC6HpTGYhIhPIzj9XuKU7KYDwnaeh5hcOwjy1QuJzBPA==",
- "dev": true,
- "dependencies": {
- "fast-deep-equal": "^3.1.1",
- "json-schema-traverse": "^1.0.0",
- "require-from-string": "^2.0.2",
- "uri-js": "^4.2.2"
- },
- "funding": {
- "type": "github",
- "url": "https://github.com/sponsors/epoberezkin"
- }
- },
- "node_modules/copy-webpack-plugin/node_modules/ajv-keywords": {
- "version": "5.1.0",
- "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-5.1.0.tgz",
- "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==",
- "dev": true,
- "dependencies": {
- "fast-deep-equal": "^3.1.3"
- },
- "peerDependencies": {
- "ajv": "^8.8.2"
- }
- },
- "node_modules/copy-webpack-plugin/node_modules/glob-parent": {
- "version": "6.0.2",
- "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz",
- "integrity": "sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A==",
- "dev": true,
- "dependencies": {
- "is-glob": "^4.0.3"
- },
- "engines": {
- "node": ">=10.13.0"
+ "node": ">= 0.6"
}
},
- "node_modules/copy-webpack-plugin/node_modules/json-schema-traverse": {
- "version": "1.0.0",
- "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz",
- "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==",
+ "node_modules/convert-source-map": {
+ "version": "2.0.0",
+ "resolved": "https://registry.npmjs.org/convert-source-map/-/convert-source-map-2.0.0.tgz",
+ "integrity": "sha512-Kvp459HrV2FEJ1CAsi1Ku+MY3kasH19TFykTz2xWmMeq6bk2NU3XXvfJ+Q61m0xktWwt+1HSYf3JZsTms3aRJg==",
"dev": true
},
- "node_modules/copy-webpack-plugin/node_modules/schema-utils": {
- "version": "4.0.1",
- "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.0.1.tgz",
- "integrity": "sha512-lELhBAAly9NowEsX0yZBlw9ahZG+sK/1RJ21EpzdYHKEs13Vku3LJ+MIPhh4sMs0oCCeufZQEQbMekiA4vuVIQ==",
+ "node_modules/cookie": {
+ "version": "0.7.1",
+ "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.1.tgz",
+ "integrity": "sha512-6DnInpx7SJ2AK3+CTUE/ZM0vWTUboZCegxhC2xiIydHR9jNuTAASBrfEpHhiGOZw/nX51bHt6YQl8jsGo4y/0w==",
"dev": true,
- "dependencies": {
- "@types/json-schema": "^7.0.9",
- "ajv": "^8.9.0",
- "ajv-formats": "^2.1.1",
- "ajv-keywords": "^5.1.0"
- },
"engines": {
- "node": ">= 12.13.0"
- },
- "funding": {
- "type": "opencollective",
- "url": "https://opencollective.com/webpack"
+ "node": ">= 0.6"
}
},
+ "node_modules/cookie-signature": {
+ "version": "1.0.6",
+ "resolved": "https://registry.npmjs.org/cookie-signature/-/cookie-signature-1.0.6.tgz",
+ "integrity": "sha512-QADzlaHc8icV8I7vbaJXJwod9HWYp8uCqf1xa4OfNu1T7JVxQIrUgOWtHdNDtPiywmFbiS12VjotIXLrKM3orQ==",
+ "dev": true
+ },
"node_modules/core-util-is": {
"version": "1.0.3",
"resolved": "https://registry.npmjs.org/core-util-is/-/core-util-is-1.0.3.tgz",
@@ -2937,7 +3195,6 @@
"version": "7.0.3",
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz",
"integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==",
- "dev": true,
"dependencies": {
"path-key": "^3.1.0",
"shebang-command": "^2.0.0",
@@ -2956,20 +3213,6 @@
"ms": "2.0.0"
}
},
- "node_modules/decompress-response": {
- "version": "6.0.0",
- "resolved": "https://registry.npmjs.org/decompress-response/-/decompress-response-6.0.0.tgz",
- "integrity": "sha512-aW35yZM6Bb/4oJlZncMH2LCoZtJXTRxES17vE3hoRiowU2kWHaJKFkSBDnDR+cm9J+9QhXmREyIfv0pji9ejCQ==",
- "dependencies": {
- "mimic-response": "^3.1.0"
- },
- "engines": {
- "node": ">=10"
- },
- "funding": {
- "url": "https://github.com/sponsors/sindresorhus"
- }
- },
"node_modules/dedent": {
"version": "0.7.0",
"resolved": "https://registry.npmjs.org/dedent/-/dedent-0.7.0.tgz",
@@ -2980,6 +3223,7 @@
"version": "0.6.0",
"resolved": "https://registry.npmjs.org/deep-extend/-/deep-extend-0.6.0.tgz",
"integrity": "sha512-LOHxIOaPYdHlJRtCQfDIVZtfw/ufM8+rVj649RIHzcm/vGwQRXFt6OPqIFWsm2XEMrNIEtWR64sY1LEKD2vAOA==",
+ "dev": true,
"engines": {
"node": ">=4.0.0"
}
@@ -3051,9 +3295,9 @@
}
},
"node_modules/detect-libc": {
- "version": "2.0.2",
- "resolved": "https://registry.npmjs.org/detect-libc/-/detect-libc-2.0.2.tgz",
- "integrity": "sha512-UX6sGumvvqSaXgdKGUsgZWqcUyIXZ/vZTrlRT/iobiKhGL0zL4d3osHj3uqllWJK+i+sixDS/3COVEOFbupFyw==",
+ "version": "2.0.3",
+ "resolved": "https://registry.npmjs.org/detect-libc/-/detect-libc-2.0.3.tgz",
+ "integrity": "sha512-bwy0MGW55bG41VqxxypOsdSdGqLwXPI/focwgTYCFMbdUiBAxLg9CFzG08sz2aqzknwiX7Hkl0bQENjg8iLByw==",
"engines": {
"node": ">=8"
}
@@ -3082,18 +3326,6 @@
"node": "^14.15.0 || ^16.10.0 || >=18.0.0"
}
},
- "node_modules/dir-glob": {
- "version": "3.0.1",
- "resolved": "https://registry.npmjs.org/dir-glob/-/dir-glob-3.0.1.tgz",
- "integrity": "sha512-WkrWp9GR4KXfKGYzOLmTuGVi1UWFfws377n9cc55/tb6DuqyF6pcQ5AbiHEshaDpY9v6oaSr2XCDidGmMwdzIA==",
- "dev": true,
- "dependencies": {
- "path-type": "^4.0.0"
- },
- "engines": {
- "node": ">=8"
- }
- },
"node_modules/dmd": {
"version": "6.2.0",
"resolved": "https://registry.npmjs.org/dmd/-/dmd-6.2.0.tgz",
@@ -3135,6 +3367,11 @@
"node": ">=6"
}
},
+ "node_modules/eastasianwidth": {
+ "version": "0.2.0",
+ "resolved": "https://registry.npmjs.org/eastasianwidth/-/eastasianwidth-0.2.0.tgz",
+ "integrity": "sha512-I88TYZWc9XiYHRQ4/3c5rjjfgkjhLyW2luGIheGERbNQ6OY7yTybanSpDXZa8y7VUP9YmDcYa+eyq4ca7iLqWA=="
+ },
"node_modules/ee-first": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/ee-first/-/ee-first-1.1.1.tgz",
@@ -3142,9 +3379,9 @@
"dev": true
},
"node_modules/electron-to-chromium": {
- "version": "1.4.454",
- "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.4.454.tgz",
- "integrity": "sha512-pmf1rbAStw8UEQ0sr2cdJtWl48ZMuPD9Sto8HVQOq9vx9j2WgDEN6lYoaqFvqEHYOmGA9oRGn7LqWI9ta0YugQ==",
+ "version": "1.5.13",
+ "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.13.tgz",
+ "integrity": "sha512-lbBcvtIJ4J6sS4tb5TLp1b4LyfCdMkwStzXPyAgVgTRAsep4bvrAGaBOP7ZJtQMNJpSQ9SqG4brWOroNaQtm7Q==",
"dev": true
},
"node_modules/emittery": {
@@ -3162,30 +3399,21 @@
"node_modules/emoji-regex": {
"version": "8.0.0",
"resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz",
- "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==",
- "dev": true
+ "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A=="
},
"node_modules/encodeurl": {
- "version": "1.0.2",
- "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-1.0.2.tgz",
- "integrity": "sha512-TPJXq8JqFaVYm2CWmPvnP2Iyo4ZSM7/QKcSmuMLDObfpH5fi7RUGmd/rTDf+rut/saiDiQEeVTNgAmJEdAOx0w==",
+ "version": "2.0.0",
+ "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz",
+ "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==",
"dev": true,
"engines": {
"node": ">= 0.8"
}
},
- "node_modules/end-of-stream": {
- "version": "1.4.4",
- "resolved": "https://registry.npmjs.org/end-of-stream/-/end-of-stream-1.4.4.tgz",
- "integrity": "sha512-+uw1inIHVPQoaVuHzRyXd21icM+cnt4CzD5rW+NC1wjOUSTOs+Te7FOv7AhN7vS9x/oIyhLP5PR1H+phQAHu5Q==",
- "dependencies": {
- "once": "^1.4.0"
- }
- },
"node_modules/enhanced-resolve": {
- "version": "5.13.0",
- "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.13.0.tgz",
- "integrity": "sha512-eyV8f0y1+bzyfh8xAwW/WTSZpLbjhqc4ne9eGSH4Zo2ejdyiNG9pU6mf9DG8a7+Auk6MFTlNOT4Y2y/9k8GKVg==",
+ "version": "5.17.1",
+ "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.17.1.tgz",
+ "integrity": "sha512-LMHl3dXhTcfv8gM4kEzIUeTQ+7fpdA0l2tUf34BddXPkz2A5xJ5L/Pchd5BL6rdccM9QGvu0sWZzK1Z1t4wwyg==",
"dev": true,
"dependencies": {
"graceful-fs": "^4.2.4",
@@ -3259,9 +3487,9 @@
"dev": true
},
"node_modules/escalade": {
- "version": "3.1.1",
- "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.1.1.tgz",
- "integrity": "sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw==",
+ "version": "3.1.2",
+ "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.1.2.tgz",
+ "integrity": "sha512-ErCHMCae19vR8vQGe50xIsVomy19rg6gFu3+r3jkEO46suLMWBksvVyoGgQV+jOfl84ZSOSlmv6Gxa89PmTGmA==",
"dev": true,
"engines": {
"node": ">=6"
@@ -3394,14 +3622,6 @@
"node": ">= 0.8.0"
}
},
- "node_modules/expand-template": {
- "version": "2.0.3",
- "resolved": "https://registry.npmjs.org/expand-template/-/expand-template-2.0.3.tgz",
- "integrity": "sha512-XYfuKMvj4O35f/pOXLObndIRvyQ+/+6AhODh+OKWj9S9498pHHn/IMszH+gt0fBCRWMNfk1ZSp5x3AifmnI2vg==",
- "engines": {
- "node": ">=6"
- }
- },
"node_modules/expect": {
"version": "29.6.1",
"resolved": "https://registry.npmjs.org/expect/-/expect-29.6.1.tgz",
@@ -3420,37 +3640,37 @@
}
},
"node_modules/express": {
- "version": "4.19.2",
- "resolved": "https://registry.npmjs.org/express/-/express-4.19.2.tgz",
- "integrity": "sha512-5T6nhjsT+EOMzuck8JjBHARTHfMht0POzlA60WV2pMD3gyXw2LZnZ+ueGdNxG+0calOJcWKbpFcuzLZ91YWq9Q==",
+ "version": "4.21.1",
+ "resolved": "https://registry.npmjs.org/express/-/express-4.21.1.tgz",
+ "integrity": "sha512-YSFlK1Ee0/GC8QaO91tHcDxJiE/X4FbpAyQWkxAvG6AXCuR65YzK8ua6D9hvi/TzUfZMpc+BwuM1IPw8fmQBiQ==",
"dev": true,
"dependencies": {
"accepts": "~1.3.8",
"array-flatten": "1.1.1",
- "body-parser": "1.20.2",
+ "body-parser": "1.20.3",
"content-disposition": "0.5.4",
"content-type": "~1.0.4",
- "cookie": "0.6.0",
+ "cookie": "0.7.1",
"cookie-signature": "1.0.6",
"debug": "2.6.9",
"depd": "2.0.0",
- "encodeurl": "~1.0.2",
+ "encodeurl": "~2.0.0",
"escape-html": "~1.0.3",
"etag": "~1.8.1",
- "finalhandler": "1.2.0",
+ "finalhandler": "1.3.1",
"fresh": "0.5.2",
"http-errors": "2.0.0",
- "merge-descriptors": "1.0.1",
+ "merge-descriptors": "1.0.3",
"methods": "~1.1.2",
"on-finished": "2.4.1",
"parseurl": "~1.3.3",
- "path-to-regexp": "0.1.7",
+ "path-to-regexp": "0.1.10",
"proxy-addr": "~2.0.7",
- "qs": "6.11.0",
+ "qs": "6.13.0",
"range-parser": "~1.2.1",
"safe-buffer": "5.2.1",
- "send": "0.18.0",
- "serve-static": "1.15.0",
+ "send": "0.19.0",
+ "serve-static": "1.16.2",
"setprototypeof": "1.2.0",
"statuses": "2.0.1",
"type-is": "~1.6.18",
@@ -3473,27 +3693,6 @@
"integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==",
"dev": true
},
- "node_modules/fast-fifo": {
- "version": "1.3.2",
- "resolved": "https://registry.npmjs.org/fast-fifo/-/fast-fifo-1.3.2.tgz",
- "integrity": "sha512-/d9sfos4yxzpwkDkuN7k2SqFKtYNmCTzgfEpz82x34IM9/zc8KGxQoXg1liNC/izpRM/MBdt44Nmx41ZWqk+FQ=="
- },
- "node_modules/fast-glob": {
- "version": "3.2.12",
- "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.2.12.tgz",
- "integrity": "sha512-DVj4CQIYYow0BlaelwK1pHl5n5cRSJfM60UA0zK891sVInoPri2Ekj7+e1CT3/3qxXenpI+nBBmQAcJPJgaj4w==",
- "dev": true,
- "dependencies": {
- "@nodelib/fs.stat": "^2.0.2",
- "@nodelib/fs.walk": "^1.2.3",
- "glob-parent": "^5.1.2",
- "merge2": "^1.3.0",
- "micromatch": "^4.0.4"
- },
- "engines": {
- "node": ">=8.6.0"
- }
- },
"node_modules/fast-json-stable-stringify": {
"version": "2.1.0",
"resolved": "https://registry.npmjs.org/fast-json-stable-stringify/-/fast-json-stable-stringify-2.1.0.tgz",
@@ -3509,15 +3708,6 @@
"node": ">= 4.9.1"
}
},
- "node_modules/fastq": {
- "version": "1.15.0",
- "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.15.0.tgz",
- "integrity": "sha512-wBrocU2LCXXa+lWBt8RoIRD89Fi8OdABODa/kEnyeyjS5aZO5/GNvI5sEINADqP/h8M29UHTHUb53sUu5Ihqdw==",
- "dev": true,
- "dependencies": {
- "reusify": "^1.0.4"
- }
- },
"node_modules/faye-websocket": {
"version": "0.11.4",
"resolved": "https://registry.npmjs.org/faye-websocket/-/faye-websocket-0.11.4.tgz",
@@ -3574,13 +3764,13 @@
}
},
"node_modules/finalhandler": {
- "version": "1.2.0",
- "resolved": "https://registry.npmjs.org/finalhandler/-/finalhandler-1.2.0.tgz",
- "integrity": "sha512-5uXcUVftlQMFnWC9qu/svkWv3GTd2PfUhK/3PLkYNAe7FbqJMt3515HaxE6eRL74GdsriiwujiawdaB1BpEISg==",
+ "version": "1.3.1",
+ "resolved": "https://registry.npmjs.org/finalhandler/-/finalhandler-1.3.1.tgz",
+ "integrity": "sha512-6BN9trH7bp3qvnrRyzsBz+g3lZxTNZTbVO2EV1CS0WIcDbawYVdYvGflME/9QP0h0pYlCDBCTjYa9nZzMDpyxQ==",
"dev": true,
"dependencies": {
"debug": "2.6.9",
- "encodeurl": "~1.0.2",
+ "encodeurl": "~2.0.0",
"escape-html": "~1.0.3",
"on-finished": "2.4.1",
"parseurl": "~1.3.3",
@@ -3650,6 +3840,32 @@
}
}
},
+ "node_modules/foreground-child": {
+ "version": "3.1.1",
+ "resolved": "https://registry.npmjs.org/foreground-child/-/foreground-child-3.1.1.tgz",
+ "integrity": "sha512-TMKDUnIte6bfb5nWv7V/caI169OHgvwjb7V4WkeUvbQQdjr5rWKqHFiKWb/fcOwB+CzBT+qbWjvj+DVwRskpIg==",
+ "dependencies": {
+ "cross-spawn": "^7.0.0",
+ "signal-exit": "^4.0.1"
+ },
+ "engines": {
+ "node": ">=14"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/isaacs"
+ }
+ },
+ "node_modules/foreground-child/node_modules/signal-exit": {
+ "version": "4.1.0",
+ "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-4.1.0.tgz",
+ "integrity": "sha512-bzyZ1e88w9O1iNJbKnOlvYTrWPDl46O1bG0D3XInv+9tkPrxrN8jUUTiFlDkkmKWgn1M6CfIA13SuGqOa9Korw==",
+ "engines": {
+ "node": ">=14"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/isaacs"
+ }
+ },
"node_modules/forwarded": {
"version": "0.2.0",
"resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.2.0.tgz",
@@ -3668,11 +3884,6 @@
"node": ">= 0.6"
}
},
- "node_modules/fs-constants": {
- "version": "1.0.0",
- "resolved": "https://registry.npmjs.org/fs-constants/-/fs-constants-1.0.0.tgz",
- "integrity": "sha512-y6OAwoSIf7FyjMIv94u+b5rdheZEjzR63GTyZJm5qh4Bi+2YgwLCcI/fPFZkL5PSixOt6ZNKm+w+Hfp/Bciwow=="
- },
"node_modules/fs-monkey": {
"version": "1.0.3",
"resolved": "https://registry.npmjs.org/fs-monkey/-/fs-monkey-1.0.3.tgz",
@@ -3695,9 +3906,9 @@
"dev": true
},
"node_modules/fsevents": {
- "version": "2.3.2",
- "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz",
- "integrity": "sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA==",
+ "version": "2.3.3",
+ "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz",
+ "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==",
"dev": true,
"hasInstallScript": true,
"optional": true,
@@ -3775,11 +3986,6 @@
"url": "https://github.com/sponsors/sindresorhus"
}
},
- "node_modules/github-from-package": {
- "version": "0.0.0",
- "resolved": "https://registry.npmjs.org/github-from-package/-/github-from-package-0.0.0.tgz",
- "integrity": "sha512-SyHy3T1v2NUXn29OsWdxmK6RwHD+vkj3v8en8AOBZ1wBQ/hCAQ5bAQTD02kW4W9tUp/3Qh6J8r9EvntiyCmOOw=="
- },
"node_modules/glob": {
"version": "7.2.3",
"resolved": "https://registry.npmjs.org/glob/-/glob-7.2.3.tgz",
@@ -3827,25 +4033,6 @@
"node": ">=4"
}
},
- "node_modules/globby": {
- "version": "13.1.4",
- "resolved": "https://registry.npmjs.org/globby/-/globby-13.1.4.tgz",
- "integrity": "sha512-iui/IiiW+QrJ1X1hKH5qwlMQyv34wJAYwH1vrf8b9kBA4sNiif3gKsMHa+BrdnOpEudWjpotfa7LrTzB1ERS/g==",
- "dev": true,
- "dependencies": {
- "dir-glob": "^3.0.1",
- "fast-glob": "^3.2.11",
- "ignore": "^5.2.0",
- "merge2": "^1.4.1",
- "slash": "^4.0.0"
- },
- "engines": {
- "node": "^12.20.0 || ^14.13.1 || >=16.0.0"
- },
- "funding": {
- "url": "https://github.com/sponsors/sindresorhus"
- }
- },
"node_modules/gopd": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/gopd/-/gopd-1.0.1.tgz",
@@ -3859,9 +4046,9 @@
}
},
"node_modules/graceful-fs": {
- "version": "4.2.10",
- "resolved": "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.10.tgz",
- "integrity": "sha512-9ByhssR2fPVsNZj478qUUbKfmL0+t5BDVyjShtyZZLiK7ZDAArFFfopyOTj0M05wE2tJPisA4iTnnXl2YoPvOA==",
+ "version": "4.2.11",
+ "resolved": "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.11.tgz",
+ "integrity": "sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==",
"dev": true
},
"node_modules/guid-typescript": {
@@ -4106,34 +4293,6 @@
"node": ">=0.10.0"
}
},
- "node_modules/ieee754": {
- "version": "1.2.1",
- "resolved": "https://registry.npmjs.org/ieee754/-/ieee754-1.2.1.tgz",
- "integrity": "sha512-dcyqhDvX1C46lXZcVqCpK+FtMRQVdIMN6/Df5js2zouUsqG7I6sFxitIC+7KYK29KdXOLHdu9zL4sFnoVQnqaA==",
- "funding": [
- {
- "type": "github",
- "url": "https://github.com/sponsors/feross"
- },
- {
- "type": "patreon",
- "url": "https://www.patreon.com/feross"
- },
- {
- "type": "consulting",
- "url": "https://feross.org/support"
- }
- ]
- },
- "node_modules/ignore": {
- "version": "5.2.4",
- "resolved": "https://registry.npmjs.org/ignore/-/ignore-5.2.4.tgz",
- "integrity": "sha512-MAb38BcSbH0eHNBxn7ql2NH/kX33OkB3lZ1BNdh7ENeRChHTYsTvWrMubiIAMNS2llXEEgZ1MUOBtXChP3kaFQ==",
- "dev": true,
- "engines": {
- "node": ">= 4"
- }
- },
"node_modules/import-local": {
"version": "3.1.0",
"resolved": "https://registry.npmjs.org/import-local/-/import-local-3.1.0.tgz",
@@ -4175,12 +4334,8 @@
"node_modules/inherits": {
"version": "2.0.4",
"resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz",
- "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ=="
- },
- "node_modules/ini": {
- "version": "1.3.8",
- "resolved": "https://registry.npmjs.org/ini/-/ini-1.3.8.tgz",
- "integrity": "sha512-JV/yugV2uzW5iMRSiZAyDtQd+nxtUnjeLt0acNdw98kKLrvuRVyB80tsREOE7yvGVgalhZ6RNXCmEHkUKBKxew=="
+ "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==",
+ "dev": true
},
"node_modules/interpret": {
"version": "3.1.1",
@@ -4257,7 +4412,6 @@
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz",
"integrity": "sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==",
- "dev": true,
"engines": {
"node": ">=8"
}
@@ -4349,8 +4503,7 @@
"node_modules/isexe": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz",
- "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==",
- "dev": true
+ "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw=="
},
"node_modules/isobject": {
"version": "3.0.1",
@@ -4386,6 +4539,15 @@
"node": ">=8"
}
},
+ "node_modules/istanbul-lib-instrument/node_modules/semver": {
+ "version": "6.3.1",
+ "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz",
+ "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==",
+ "dev": true,
+ "bin": {
+ "semver": "bin/semver.js"
+ }
+ },
"node_modules/istanbul-lib-report": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/istanbul-lib-report/-/istanbul-lib-report-3.0.0.tgz",
@@ -4462,6 +4624,23 @@
"node": ">=8"
}
},
+ "node_modules/jackspeak": {
+ "version": "3.1.2",
+ "resolved": "https://registry.npmjs.org/jackspeak/-/jackspeak-3.1.2.tgz",
+ "integrity": "sha512-kWmLKn2tRtfYMF/BakihVVRzBKOxz4gJMiL2Rj91WnAB5TPZumSH99R/Yf1qE1u4uRimvCSJfm6hnxohXeEXjQ==",
+ "dependencies": {
+ "@isaacs/cliui": "^8.0.2"
+ },
+ "engines": {
+ "node": ">=14"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/isaacs"
+ },
+ "optionalDependencies": {
+ "@pkgjs/parseargs": "^0.11.0"
+ }
+ },
"node_modules/jest": {
"version": "29.6.1",
"resolved": "https://registry.npmjs.org/jest/-/jest-29.6.1.tgz",
@@ -5422,20 +5601,9 @@
"dev": true
},
"node_modules/long": {
- "version": "4.0.0",
- "resolved": "https://registry.npmjs.org/long/-/long-4.0.0.tgz",
- "integrity": "sha512-XsP+KhQif4bjX1kbuSiySJFNAehNxgLb6hPRGJ9QsUr8ajHkuXGdrHmFUTUUXhDwVX2R5bY4JNZEwbUiMhV+MA=="
- },
- "node_modules/lru-cache": {
- "version": "6.0.0",
- "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz",
- "integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==",
- "dependencies": {
- "yallist": "^4.0.0"
- },
- "engines": {
- "node": ">=10"
- }
+ "version": "5.2.3",
+ "resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz",
+ "integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q=="
},
"node_modules/make-dir": {
"version": "3.1.0",
@@ -5452,6 +5620,15 @@
"url": "https://github.com/sponsors/sindresorhus"
}
},
+ "node_modules/make-dir/node_modules/semver": {
+ "version": "6.3.1",
+ "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz",
+ "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==",
+ "dev": true,
+ "bin": {
+ "semver": "bin/semver.js"
+ }
+ },
"node_modules/makeerror": {
"version": "1.0.12",
"resolved": "https://registry.npmjs.org/makeerror/-/makeerror-1.0.12.tgz",
@@ -5527,10 +5704,13 @@
}
},
"node_modules/merge-descriptors": {
- "version": "1.0.1",
- "resolved": "https://registry.npmjs.org/merge-descriptors/-/merge-descriptors-1.0.1.tgz",
- "integrity": "sha512-cCi6g3/Zr1iqQi6ySbseM1Xvooa98N0w31jzUYrXPX2xqObmFGHJ0tQ5u74H3mVh7wLouTseZyYIq39g8cNp1w==",
- "dev": true
+ "version": "1.0.3",
+ "resolved": "https://registry.npmjs.org/merge-descriptors/-/merge-descriptors-1.0.3.tgz",
+ "integrity": "sha512-gaNvAS7TZ897/rVaZ0nMtAyxNyi/pdbjbAwUpFQpN70GqnVfOiXpeUUMKRBmzXaSQ8DdTX4/0ms62r2K+hE6mQ==",
+ "dev": true,
+ "funding": {
+ "url": "https://github.com/sponsors/sindresorhus"
+ }
},
"node_modules/merge-stream": {
"version": "2.0.0",
@@ -5538,15 +5718,6 @@
"integrity": "sha512-abv/qOcuPfk3URPfDzmZU1LKmuw8kT+0nIHvKrKgFrwifol/doWcdA4ZqsWQ8ENrFKkd67Mfpo/LovbIUsbt3w==",
"dev": true
},
- "node_modules/merge2": {
- "version": "1.4.1",
- "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz",
- "integrity": "sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg==",
- "dev": true,
- "engines": {
- "node": ">= 8"
- }
- },
"node_modules/methods": {
"version": "1.1.2",
"resolved": "https://registry.npmjs.org/methods/-/methods-1.1.2.tgz",
@@ -5557,12 +5728,12 @@
}
},
"node_modules/micromatch": {
- "version": "4.0.5",
- "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.5.tgz",
- "integrity": "sha512-DMy+ERcEW2q8Z2Po+WNXuw3c5YaUSFjAO5GsJqfEl7UjvtIuFKO6ZrKvcItdy98dwFI2N1tg3zNIdKaQT+aNdA==",
+ "version": "4.0.8",
+ "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.8.tgz",
+ "integrity": "sha512-PXwfBhYu0hBCPw8Dn0E+WDYb7af3dSLVWKi3HGv84IdF4TyFoC0ysxFd0Goxw7nSv4T/PzEJQxsYsEiFCKo2BA==",
"dev": true,
"dependencies": {
- "braces": "^3.0.2",
+ "braces": "^3.0.3",
"picomatch": "^2.3.1"
},
"engines": {
@@ -5611,17 +5782,6 @@
"node": ">=6"
}
},
- "node_modules/mimic-response": {
- "version": "3.1.0",
- "resolved": "https://registry.npmjs.org/mimic-response/-/mimic-response-3.1.0.tgz",
- "integrity": "sha512-z0yWI+4FDrrweS8Zmt4Ej5HdJmky15+L2e6Wgn3+iK5fWzb6T3fhNFq2+MeTRb064c6Wr4N/wv0DzQTjNzHNGQ==",
- "engines": {
- "node": ">=10"
- },
- "funding": {
- "url": "https://github.com/sponsors/sindresorhus"
- }
- },
"node_modules/minimalistic-assert": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/minimalistic-assert/-/minimalistic-assert-1.0.1.tgz",
@@ -5644,10 +5804,91 @@
"version": "1.2.8",
"resolved": "https://registry.npmjs.org/minimist/-/minimist-1.2.8.tgz",
"integrity": "sha512-2yyAR8qBkN3YuheJanUpWC5U3bb5osDywNB8RzDVlDwDHbocAJveqqj1u8+SVD7jkWT4yvsHCpWqqWqAxb0zCA==",
+ "dev": true,
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
+ "node_modules/minipass": {
+ "version": "7.1.2",
+ "resolved": "https://registry.npmjs.org/minipass/-/minipass-7.1.2.tgz",
+ "integrity": "sha512-qOOzS1cBTWYF4BH8fVePDBOO9iptMnGUEZwNc/cMWnTV2nVLZ7VoNWEPHkYczZA0pdoA7dl6e7FL659nX9S2aw==",
+ "engines": {
+ "node": ">=16 || 14 >=14.17"
+ }
+ },
+ "node_modules/minizlib": {
+ "version": "3.0.1",
+ "resolved": "https://registry.npmjs.org/minizlib/-/minizlib-3.0.1.tgz",
+ "integrity": "sha512-umcy022ILvb5/3Djuu8LWeqUa8D68JaBzlttKeMWen48SjabqS3iY5w/vzeMzMUNhLDifyhbOwKDSznB1vvrwg==",
+ "dependencies": {
+ "minipass": "^7.0.4",
+ "rimraf": "^5.0.5"
+ },
+ "engines": {
+ "node": ">= 18"
+ }
+ },
+ "node_modules/minizlib/node_modules/brace-expansion": {
+ "version": "2.0.1",
+ "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz",
+ "integrity": "sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==",
+ "dependencies": {
+ "balanced-match": "^1.0.0"
+ }
+ },
+ "node_modules/minizlib/node_modules/glob": {
+ "version": "10.4.1",
+ "resolved": "https://registry.npmjs.org/glob/-/glob-10.4.1.tgz",
+ "integrity": "sha512-2jelhlq3E4ho74ZyVLN03oKdAZVUa6UDZzFLVH1H7dnoax+y9qyaq8zBkfDIggjniU19z0wU18y16jMB2eyVIw==",
+ "dependencies": {
+ "foreground-child": "^3.1.0",
+ "jackspeak": "^3.1.2",
+ "minimatch": "^9.0.4",
+ "minipass": "^7.1.2",
+ "path-scurry": "^1.11.1"
+ },
+ "bin": {
+ "glob": "dist/esm/bin.mjs"
+ },
+ "engines": {
+ "node": ">=16 || 14 >=14.18"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/isaacs"
+ }
+ },
+ "node_modules/minizlib/node_modules/minimatch": {
+ "version": "9.0.4",
+ "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.4.tgz",
+ "integrity": "sha512-KqWh+VchfxcMNRAJjj2tnsSJdNbHsVgnkBhTNrW7AjVo6OvLtxw8zfT9oLw1JSohlFzJ8jCoTgaoXvJ+kHt6fw==",
+ "dependencies": {
+ "brace-expansion": "^2.0.1"
+ },
+ "engines": {
+ "node": ">=16 || 14 >=14.17"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/isaacs"
+ }
+ },
+ "node_modules/minizlib/node_modules/rimraf": {
+ "version": "5.0.7",
+ "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-5.0.7.tgz",
+ "integrity": "sha512-nV6YcJo5wbLW77m+8KjH8aB/7/rxQy9SZ0HY5shnwULfS+9nmTtVXAJET5NdZmCzA4fPI/Hm1wo/Po/4mopOdg==",
+ "dependencies": {
+ "glob": "^10.3.7"
+ },
+ "bin": {
+ "rimraf": "dist/esm/bin.mjs"
+ },
+ "engines": {
+ "node": ">=14.18"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/isaacs"
+ }
+ },
"node_modules/mkdirp": {
"version": "1.0.4",
"resolved": "https://registry.npmjs.org/mkdirp/-/mkdirp-1.0.4.tgz",
@@ -5660,11 +5901,6 @@
"node": ">=10"
}
},
- "node_modules/mkdirp-classic": {
- "version": "0.5.3",
- "resolved": "https://registry.npmjs.org/mkdirp-classic/-/mkdirp-classic-0.5.3.tgz",
- "integrity": "sha512-gKLcREMhtuZRwRAfqP3RFW+TK4JqApVBtOIftVgjuABpAtpxhPGaDcfvbhNvD0B8iD1oUr/txX35NjcaY6Ns/A=="
- },
"node_modules/mkdirp2": {
"version": "1.0.5",
"resolved": "https://registry.npmjs.org/mkdirp2/-/mkdirp2-1.0.5.tgz",
@@ -5690,11 +5926,6 @@
"multicast-dns": "cli.js"
}
},
- "node_modules/napi-build-utils": {
- "version": "1.0.2",
- "resolved": "https://registry.npmjs.org/napi-build-utils/-/napi-build-utils-1.0.2.tgz",
- "integrity": "sha512-ONmRUqK7zj7DWX0D9ADe03wbwOBZxNAfF20PlGfCWQcD3+/MakShIHrMqx9YwPTfxDdF1zLeL+RGZiR9kGMLdg=="
- },
"node_modules/natural-compare": {
"version": "1.4.0",
"resolved": "https://registry.npmjs.org/natural-compare/-/natural-compare-1.4.0.tgz",
@@ -5711,26 +5942,10 @@
}
},
"node_modules/neo-async": {
- "version": "2.6.2",
- "resolved": "https://registry.npmjs.org/neo-async/-/neo-async-2.6.2.tgz",
- "integrity": "sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw==",
- "dev": true
- },
- "node_modules/node-abi": {
- "version": "3.35.0",
- "resolved": "https://registry.npmjs.org/node-abi/-/node-abi-3.35.0.tgz",
- "integrity": "sha512-jAlSOFR1Bls963NmFwxeQkNTzqjUF0NThm8Le7eRIRGzFUVJuMOFZDLv5Y30W/Oaw+KEebEJLAigwO9gQHoEmw==",
- "dependencies": {
- "semver": "^7.3.5"
- },
- "engines": {
- "node": ">=10"
- }
- },
- "node_modules/node-addon-api": {
- "version": "6.1.0",
- "resolved": "https://registry.npmjs.org/node-addon-api/-/node-addon-api-6.1.0.tgz",
- "integrity": "sha512-+eawOlIgy680F0kBzPUNFhMZGtJ1YmqM6l4+Crf4IkImjYrO/mqPwRMh352g23uIaQKFItcQ64I7KMaJxHgAVA=="
+ "version": "2.6.2",
+ "resolved": "https://registry.npmjs.org/neo-async/-/neo-async-2.6.2.tgz",
+ "integrity": "sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw==",
+ "dev": true
},
"node_modules/node-forge": {
"version": "1.3.1",
@@ -5748,9 +5963,9 @@
"dev": true
},
"node_modules/node-releases": {
- "version": "2.0.13",
- "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.13.tgz",
- "integrity": "sha512-uYr7J37ae/ORWdZeQ1xxMJe3NtdmqMC/JZK+geofDrkLUApKRHPd18/TxtBOJ4A0/+uUIliorNrfYV6s1b02eQ==",
+ "version": "2.0.18",
+ "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.18.tgz",
+ "integrity": "sha512-d9VeXT4SJ7ZeOqGX6R5EM022wpL+eWPooLI+5UpWn2jCT1aosUQEhQP214x33Wkwx3JQMvIm+tIoVOdodFS40g==",
"dev": true
},
"node_modules/normalize-path": {
@@ -5781,10 +5996,13 @@
"dev": true
},
"node_modules/object-inspect": {
- "version": "1.13.1",
- "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.1.tgz",
- "integrity": "sha512-5qoj1RUiKOMsCCNLV1CBiPYE10sziTsnmNxkAI/rZhiD63CF7IqdFGC/XzjWjpSgLf0LxXX3bDFIh0E18f6UhQ==",
+ "version": "1.13.2",
+ "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.2.tgz",
+ "integrity": "sha512-IRZSRuzJiynemAXPYtPe5BoI/RESNYR7TYm50MC5Mqbd3Jmw5y790sErYw3V6SryFJD64b74qQQs9wn5Bg/k3g==",
"dev": true,
+ "engines": {
+ "node": ">= 0.4"
+ },
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
@@ -5829,6 +6047,7 @@
"version": "1.4.0",
"resolved": "https://registry.npmjs.org/once/-/once-1.4.0.tgz",
"integrity": "sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w==",
+ "dev": true,
"dependencies": {
"wrappy": "1"
}
@@ -5848,46 +6067,44 @@
"url": "https://github.com/sponsors/sindresorhus"
}
},
- "node_modules/onnx-proto": {
- "version": "4.0.4",
- "resolved": "https://registry.npmjs.org/onnx-proto/-/onnx-proto-4.0.4.tgz",
- "integrity": "sha512-aldMOB3HRoo6q/phyB6QRQxSt895HNNw82BNyZ2CMh4bjeKv7g/c+VpAFtJuEMVfYLMbRx61hbuqnKceLeDcDA==",
- "dependencies": {
- "protobufjs": "^6.8.8"
- }
- },
"node_modules/onnxruntime-common": {
- "version": "1.14.0",
- "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.14.0.tgz",
- "integrity": "sha512-3LJpegM2iMNRX2wUmtYfeX/ytfOzNwAWKSq1HbRrKc9+uqG/FsEA0bbKZl1btQeZaXhC26l44NWpNUeXPII7Ew=="
+ "version": "1.19.2",
+ "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.19.2.tgz",
+ "integrity": "sha512-a4R7wYEVFbZBlp0BfhpbFWqe4opCor3KM+5Wm22Az3NGDcQMiU2hfG/0MfnBs+1ZrlSGmlgWeMcXQkDk1UFb8Q=="
},
"node_modules/onnxruntime-node": {
- "version": "1.14.0",
- "resolved": "https://registry.npmjs.org/onnxruntime-node/-/onnxruntime-node-1.14.0.tgz",
- "integrity": "sha512-5ba7TWomIV/9b6NH/1x/8QEeowsb+jBEvFzU6z0T4mNsFwdPqXeFUM7uxC6QeSRkEbWu3qEB0VMjrvzN/0S9+w==",
- "optional": true,
+ "version": "1.19.2",
+ "resolved": "https://registry.npmjs.org/onnxruntime-node/-/onnxruntime-node-1.19.2.tgz",
+ "integrity": "sha512-9eHMP/HKbbeUcqte1JYzaaRC8JPn7ojWeCeoyShO86TOR97OCyIyAIOGX3V95ErjslVhJRXY8Em/caIUc0hm1Q==",
+ "hasInstallScript": true,
"os": [
"win32",
"darwin",
"linux"
],
"dependencies": {
- "onnxruntime-common": "~1.14.0"
+ "onnxruntime-common": "1.19.2",
+ "tar": "^7.0.1"
}
},
"node_modules/onnxruntime-web": {
- "version": "1.14.0",
- "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.14.0.tgz",
- "integrity": "sha512-Kcqf43UMfW8mCydVGcX9OMXI2VN17c0p6XvR7IPSZzBf/6lteBzXHvcEVWDPmCKuGombl997HgLqj91F11DzXw==",
+ "version": "1.20.0-dev.20241016-2b8fc5529b",
+ "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.20.0-dev.20241016-2b8fc5529b.tgz",
+ "integrity": "sha512-1XovqtgqeEFtupuyzdDQo7Tqj4GRyNHzOoXjapCEo4rfH3JrXok5VtqucWfRXHPsOI5qoNxMQ9VE+drDIp6woQ==",
"dependencies": {
"flatbuffers": "^1.12.0",
"guid-typescript": "^1.0.9",
- "long": "^4.0.0",
- "onnx-proto": "^4.0.4",
- "onnxruntime-common": "~1.14.0",
- "platform": "^1.3.6"
+ "long": "^5.2.3",
+ "onnxruntime-common": "1.20.0-dev.20241016-2b8fc5529b",
+ "platform": "^1.3.6",
+ "protobufjs": "^7.2.4"
}
},
+ "node_modules/onnxruntime-web/node_modules/onnxruntime-common": {
+ "version": "1.20.0-dev.20241016-2b8fc5529b",
+ "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.20.0-dev.20241016-2b8fc5529b.tgz",
+ "integrity": "sha512-KZK8b6zCYGZFjd4ANze0pqBnqnFTS3GIVeclQpa2qseDpXrCQJfkWBixRcrZShNhm3LpFOZ8qJYFC5/qsJK9WQ=="
+ },
"node_modules/open": {
"version": "8.4.2",
"resolved": "https://registry.npmjs.org/open/-/open-8.4.2.tgz",
@@ -6003,7 +6220,6 @@
"version": "3.1.1",
"resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz",
"integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==",
- "dev": true,
"engines": {
"node": ">=8"
}
@@ -6014,25 +6230,39 @@
"integrity": "sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw==",
"dev": true
},
- "node_modules/path-to-regexp": {
- "version": "0.1.7",
- "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.7.tgz",
- "integrity": "sha512-5DFkuoqlv1uYQKxy8omFBeJPQcdoE07Kv2sferDCrAq1ohOU+MSDswDIbnx3YAM60qIOnYa53wBhXW0EbMonrQ==",
- "dev": true
+ "node_modules/path-scurry": {
+ "version": "1.11.1",
+ "resolved": "https://registry.npmjs.org/path-scurry/-/path-scurry-1.11.1.tgz",
+ "integrity": "sha512-Xa4Nw17FS9ApQFJ9umLiJS4orGjm7ZzwUrwamcGQuHSzDyth9boKDaycYdDcZDuqYATXw4HFXgaqWTctW/v1HA==",
+ "dependencies": {
+ "lru-cache": "^10.2.0",
+ "minipass": "^5.0.0 || ^6.0.2 || ^7.0.0"
+ },
+ "engines": {
+ "node": ">=16 || 14 >=14.18"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/isaacs"
+ }
},
- "node_modules/path-type": {
- "version": "4.0.0",
- "resolved": "https://registry.npmjs.org/path-type/-/path-type-4.0.0.tgz",
- "integrity": "sha512-gDKb8aZMDeD/tZWs9P6+q0J9Mwkdl6xMV8TjnGP3qJVJ06bdMgkbBlLU8IdfOsIsFz2BW1rNVT3XuNEl8zPAvw==",
- "dev": true,
+ "node_modules/path-scurry/node_modules/lru-cache": {
+ "version": "10.2.2",
+ "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-10.2.2.tgz",
+ "integrity": "sha512-9hp3Vp2/hFQUiIwKo8XCeFVnrg8Pk3TYNPIR7tJADKi5YfcF7vEaK7avFHTlSy3kOKYaJQaalfEo6YuXdceBOQ==",
"engines": {
- "node": ">=8"
+ "node": "14 || >=16.14"
}
},
+ "node_modules/path-to-regexp": {
+ "version": "0.1.10",
+ "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.10.tgz",
+ "integrity": "sha512-7lf7qcQidTku0Gu3YDPc8DJ1q7OOucfa/BSsIwjuh56VU7katFvuM8hULfkwB3Fns/rsVF7PwPKVw1sl5KQS9w==",
+ "dev": true
+ },
"node_modules/picocolors": {
- "version": "1.0.0",
- "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.0.0.tgz",
- "integrity": "sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ==",
+ "version": "1.0.1",
+ "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.0.1.tgz",
+ "integrity": "sha512-anP1Z8qwhkbmu7MFP5iTt+wQKXgwzf7zTyGlcdzabySa9vd0Xt392U0rVmz9poOaBj0uHJKyyo9/upk0HrEQew==",
"dev": true
},
"node_modules/picomatch": {
@@ -6073,29 +6303,19 @@
"resolved": "https://registry.npmjs.org/platform/-/platform-1.3.6.tgz",
"integrity": "sha512-fnWVljUchTro6RiCFvCXBbNhJc2NijN7oIQxbwsyL0buWJPG85v81ehlHI9fXrJsMNgTofEoWIQeClKpgxFLrg=="
},
- "node_modules/prebuild-install": {
- "version": "7.1.1",
- "resolved": "https://registry.npmjs.org/prebuild-install/-/prebuild-install-7.1.1.tgz",
- "integrity": "sha512-jAXscXWMcCK8GgCoHOfIr0ODh5ai8mj63L2nWrjuAgXE6tDyYGnx4/8o/rCgU+B4JSyZBKbeZqzhtwtC3ovxjw==",
- "dependencies": {
- "detect-libc": "^2.0.0",
- "expand-template": "^2.0.3",
- "github-from-package": "0.0.0",
- "minimist": "^1.2.3",
- "mkdirp-classic": "^0.5.3",
- "napi-build-utils": "^1.0.1",
- "node-abi": "^3.3.0",
- "pump": "^3.0.0",
- "rc": "^1.2.7",
- "simple-get": "^4.0.0",
- "tar-fs": "^2.0.0",
- "tunnel-agent": "^0.6.0"
- },
+ "node_modules/prettier": {
+ "version": "3.3.3",
+ "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.3.3.tgz",
+ "integrity": "sha512-i2tDNA0O5IrMO757lfrdQZCc2jPNDVntV0m/+4whiDfWaTKfMNgR7Qz0NAeGz/nRqF4m5/6CLzbP4/liHt12Ew==",
+ "dev": true,
"bin": {
- "prebuild-install": "bin.js"
+ "prettier": "bin/prettier.cjs"
},
"engines": {
- "node": ">=10"
+ "node": ">=14"
+ },
+ "funding": {
+ "url": "https://github.com/prettier/prettier?sponsor=1"
}
},
"node_modules/pretty-format": {
@@ -6166,11 +6386,6 @@
"node": ">=12.0.0"
}
},
- "node_modules/protobufjs/node_modules/long": {
- "version": "5.2.3",
- "resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz",
- "integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q=="
- },
"node_modules/proxy-addr": {
"version": "2.0.7",
"resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.7.tgz",
@@ -6193,15 +6408,6 @@
"node": ">= 0.10"
}
},
- "node_modules/pump": {
- "version": "3.0.0",
- "resolved": "https://registry.npmjs.org/pump/-/pump-3.0.0.tgz",
- "integrity": "sha512-LwZy+p3SFs1Pytd/jYct4wpv49HiYCqd9Rlc5ZVdk0V+8Yzv6jR5Blk3TRmPL1ft69TxP0IMZGJ+WPFU2BFhww==",
- "dependencies": {
- "end-of-stream": "^1.1.0",
- "once": "^1.3.1"
- }
- },
"node_modules/punycode": {
"version": "2.3.0",
"resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.0.tgz",
@@ -6228,12 +6434,12 @@
]
},
"node_modules/qs": {
- "version": "6.11.0",
- "resolved": "https://registry.npmjs.org/qs/-/qs-6.11.0.tgz",
- "integrity": "sha512-MvjoMCJwEarSbUYk5O+nmoSzSutSsTwF85zcHPQ9OrlFoZOYIjaqBAJIqIXjptyD5vThxGq52Xu/MaJzRkIk4Q==",
+ "version": "6.13.0",
+ "resolved": "https://registry.npmjs.org/qs/-/qs-6.13.0.tgz",
+ "integrity": "sha512-+38qI9SOr8tfZ4QmJNplMUxqjbe7LKvvZgWdExBOmd+egZTtjLB67Gu0HRX3u/XOq7UU2Nx6nsjvS16Z9uwfpg==",
"dev": true,
"dependencies": {
- "side-channel": "^1.0.4"
+ "side-channel": "^1.0.6"
},
"engines": {
"node": ">=0.6"
@@ -6242,31 +6448,6 @@
"url": "https://github.com/sponsors/ljharb"
}
},
- "node_modules/queue-microtask": {
- "version": "1.2.3",
- "resolved": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz",
- "integrity": "sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==",
- "dev": true,
- "funding": [
- {
- "type": "github",
- "url": "https://github.com/sponsors/feross"
- },
- {
- "type": "patreon",
- "url": "https://www.patreon.com/feross"
- },
- {
- "type": "consulting",
- "url": "https://feross.org/support"
- }
- ]
- },
- "node_modules/queue-tick": {
- "version": "1.0.1",
- "resolved": "https://registry.npmjs.org/queue-tick/-/queue-tick-1.0.1.tgz",
- "integrity": "sha512-kJt5qhMxoszgU/62PLP1CJytzd2NKetjSRnyuj31fDd3Rlcz3fzlFdFLD1SItunPwyqEOkca6GbV612BWfaBag=="
- },
"node_modules/randombytes": {
"version": "2.1.0",
"resolved": "https://registry.npmjs.org/randombytes/-/randombytes-2.1.0.tgz",
@@ -6309,20 +6490,6 @@
"node": ">= 0.8"
}
},
- "node_modules/rc": {
- "version": "1.2.8",
- "resolved": "https://registry.npmjs.org/rc/-/rc-1.2.8.tgz",
- "integrity": "sha512-y3bGgqKj3QBdxLbLkomlohkvsA8gdAiUQlSBJnBhfn+BPxg4bc62d8TcBW15wavDfgexCgccckhcZvywyQYPOw==",
- "dependencies": {
- "deep-extend": "^0.6.0",
- "ini": "~1.3.0",
- "minimist": "^1.2.0",
- "strip-json-comments": "~2.0.1"
- },
- "bin": {
- "rc": "cli.js"
- }
- },
"node_modules/react-is": {
"version": "18.2.0",
"resolved": "https://registry.npmjs.org/react-is/-/react-is-18.2.0.tgz",
@@ -6333,6 +6500,7 @@
"version": "3.6.1",
"resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-3.6.1.tgz",
"integrity": "sha512-+rQmrWMYGA90yenhTYsLWAsLsqVC8osOw6PKE1HDYiO0gdPeKe/xDHNzIAIn4C91YQ6oenEhfYqqc1883qHbjQ==",
+ "dev": true,
"dependencies": {
"inherits": "^2.0.3",
"string_decoder": "^1.1.1",
@@ -6547,16 +6715,6 @@
"node": ">= 4"
}
},
- "node_modules/reusify": {
- "version": "1.0.4",
- "resolved": "https://registry.npmjs.org/reusify/-/reusify-1.0.4.tgz",
- "integrity": "sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw==",
- "dev": true,
- "engines": {
- "iojs": ">=1.0.0",
- "node": ">=0.10.0"
- }
- },
"node_modules/rimraf": {
"version": "3.0.2",
"resolved": "https://registry.npmjs.org/rimraf/-/rimraf-3.0.2.tgz",
@@ -6572,33 +6730,11 @@
"url": "https://github.com/sponsors/isaacs"
}
},
- "node_modules/run-parallel": {
- "version": "1.2.0",
- "resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz",
- "integrity": "sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA==",
- "dev": true,
- "funding": [
- {
- "type": "github",
- "url": "https://github.com/sponsors/feross"
- },
- {
- "type": "patreon",
- "url": "https://www.patreon.com/feross"
- },
- {
- "type": "consulting",
- "url": "https://feross.org/support"
- }
- ],
- "dependencies": {
- "queue-microtask": "^1.2.2"
- }
- },
"node_modules/safe-buffer": {
"version": "5.2.1",
"resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz",
"integrity": "sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==",
+ "dev": true,
"funding": [
{
"type": "github",
@@ -6621,9 +6757,9 @@
"dev": true
},
"node_modules/schema-utils": {
- "version": "3.1.2",
- "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-3.1.2.tgz",
- "integrity": "sha512-pvjEHOgWc9OWA/f/DE3ohBWTD6EleVLf7iFUkoSwAxttdBhB9QUebQgxER2kWueOvRJXPHNnyrvvh9eZINB8Eg==",
+ "version": "3.3.0",
+ "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-3.3.0.tgz",
+ "integrity": "sha512-pN/yOAvcC+5rQ5nERGuwrjLlYvLTbCibnZ1I7B1LaiAz9BRBlE9GMgE/eqV30P7aJQUf7Ddimy/RsbYO/GrVGg==",
"dev": true,
"dependencies": {
"@types/json-schema": "^7.0.8",
@@ -6657,12 +6793,10 @@
}
},
"node_modules/semver": {
- "version": "7.5.4",
- "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz",
- "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==",
- "dependencies": {
- "lru-cache": "^6.0.0"
- },
+ "version": "7.6.3",
+ "resolved": "https://registry.npmjs.org/semver/-/semver-7.6.3.tgz",
+ "integrity": "sha512-oVekP1cKtI+CTDvHWYFUcMtsK/00wmAEfyqKfNdARm8u1wNVhSgaX7A8d4UuIlUI5e84iEwOhs7ZPYRmzU9U6A==",
+ "license": "ISC",
"bin": {
"semver": "bin/semver.js"
},
@@ -6671,9 +6805,9 @@
}
},
"node_modules/send": {
- "version": "0.18.0",
- "resolved": "https://registry.npmjs.org/send/-/send-0.18.0.tgz",
- "integrity": "sha512-qqWzuOjSFOuqPjFe4NOsMLafToQQwBSOEpS+FwEt3A2V3vKubTquT3vmLTQpFgMXp8AlFWFuP1qKaJZOtPpVXg==",
+ "version": "0.19.0",
+ "resolved": "https://registry.npmjs.org/send/-/send-0.19.0.tgz",
+ "integrity": "sha512-dW41u5VfLXu8SJh5bwRmyYUbAoSB3c9uQh6L8h/KtsFREPWpbX1lrljJo186Jc4nmci/sGUZ9a0a0J2zgfq2hw==",
"dev": true,
"dependencies": {
"debug": "2.6.9",
@@ -6694,6 +6828,15 @@
"node": ">= 0.8.0"
}
},
+ "node_modules/send/node_modules/encodeurl": {
+ "version": "1.0.2",
+ "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-1.0.2.tgz",
+ "integrity": "sha512-TPJXq8JqFaVYm2CWmPvnP2Iyo4ZSM7/QKcSmuMLDObfpH5fi7RUGmd/rTDf+rut/saiDiQEeVTNgAmJEdAOx0w==",
+ "dev": true,
+ "engines": {
+ "node": ">= 0.8"
+ }
+ },
"node_modules/send/node_modules/ms": {
"version": "2.1.3",
"resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz",
@@ -6701,9 +6844,9 @@
"dev": true
},
"node_modules/serialize-javascript": {
- "version": "6.0.1",
- "resolved": "https://registry.npmjs.org/serialize-javascript/-/serialize-javascript-6.0.1.tgz",
- "integrity": "sha512-owoXEFjWRllis8/M1Q+Cw5k8ZH40e3zhp/ovX+Xr/vi1qj6QesbyXXViFbpNvWvPNAD62SutwEXavefrLJWj7w==",
+ "version": "6.0.2",
+ "resolved": "https://registry.npmjs.org/serialize-javascript/-/serialize-javascript-6.0.2.tgz",
+ "integrity": "sha512-Saa1xPByTTq2gdeFZYLLo+RFE35NHZkAbqZeWNd3BpzppeVisAqpDjcp8dyf6uIvEqJRd46jemmyA4iFIeVk8g==",
"dev": true,
"dependencies": {
"randombytes": "^2.1.0"
@@ -6773,15 +6916,15 @@
}
},
"node_modules/serve-static": {
- "version": "1.15.0",
- "resolved": "https://registry.npmjs.org/serve-static/-/serve-static-1.15.0.tgz",
- "integrity": "sha512-XGuRDNjXUijsUL0vl6nSD7cwURuzEgglbOaFuZM9g3kwDXOWVTck0jLzjPzGD+TazWbboZYu52/9/XPdUgne9g==",
+ "version": "1.16.2",
+ "resolved": "https://registry.npmjs.org/serve-static/-/serve-static-1.16.2.tgz",
+ "integrity": "sha512-VqpjJZKadQB/PEbEwvFdO43Ax5dFBZ2UECszz8bQ7pi7wt//PWe1P6MN7eCnjsatYtBT6EuiClbjSWP2WrIoTw==",
"dev": true,
"dependencies": {
- "encodeurl": "~1.0.2",
+ "encodeurl": "~2.0.0",
"escape-html": "~1.0.3",
"parseurl": "~1.3.3",
- "send": "0.18.0"
+ "send": "0.19.0"
},
"engines": {
"node": ">= 0.8.0"
@@ -6823,52 +6966,47 @@
}
},
"node_modules/sharp": {
- "version": "0.32.6",
- "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.32.6.tgz",
- "integrity": "sha512-KyLTWwgcR9Oe4d9HwCwNM2l7+J0dUQwn/yf7S0EnTtb0eVS4RxO0eUSvxPtzT4F3SY+C4K6fqdv/DO27sJ/v/w==",
+ "version": "0.33.5",
+ "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.33.5.tgz",
+ "integrity": "sha512-haPVm1EkS9pgvHrQ/F3Xy+hgcuMV0Wm9vfIBSiwZ05k+xgb0PkBQpGsAA/oWdDobNaZTH5ppvHtzCFbnSEwHVw==",
"hasInstallScript": true,
"dependencies": {
"color": "^4.2.3",
- "detect-libc": "^2.0.2",
- "node-addon-api": "^6.1.0",
- "prebuild-install": "^7.1.1",
- "semver": "^7.5.4",
- "simple-get": "^4.0.1",
- "tar-fs": "^3.0.4",
- "tunnel-agent": "^0.6.0"
+ "detect-libc": "^2.0.3",
+ "semver": "^7.6.3"
},
"engines": {
- "node": ">=14.15.0"
+ "node": "^18.17.0 || ^20.3.0 || >=21.0.0"
},
"funding": {
"url": "https://opencollective.com/libvips"
- }
- },
- "node_modules/sharp/node_modules/tar-fs": {
- "version": "3.0.4",
- "resolved": "https://registry.npmjs.org/tar-fs/-/tar-fs-3.0.4.tgz",
- "integrity": "sha512-5AFQU8b9qLfZCX9zp2duONhPmZv0hGYiBPJsyUdqMjzq/mqVpy/rEUSeHk1+YitmxugaptgBh5oDGU3VsAJq4w==",
- "dependencies": {
- "mkdirp-classic": "^0.5.2",
- "pump": "^3.0.0",
- "tar-stream": "^3.1.5"
- }
- },
- "node_modules/sharp/node_modules/tar-stream": {
- "version": "3.1.6",
- "resolved": "https://registry.npmjs.org/tar-stream/-/tar-stream-3.1.6.tgz",
- "integrity": "sha512-B/UyjYwPpMBv+PaFSWAmtYjwdrlEaZQEhMIBFNC5oEG8lpiW8XjcSdmEaClj28ArfKScKHs2nshz3k2le6crsg==",
- "dependencies": {
- "b4a": "^1.6.4",
- "fast-fifo": "^1.2.0",
- "streamx": "^2.15.0"
+ },
+ "optionalDependencies": {
+ "@img/sharp-darwin-arm64": "0.33.5",
+ "@img/sharp-darwin-x64": "0.33.5",
+ "@img/sharp-libvips-darwin-arm64": "1.0.4",
+ "@img/sharp-libvips-darwin-x64": "1.0.4",
+ "@img/sharp-libvips-linux-arm": "1.0.5",
+ "@img/sharp-libvips-linux-arm64": "1.0.4",
+ "@img/sharp-libvips-linux-s390x": "1.0.4",
+ "@img/sharp-libvips-linux-x64": "1.0.4",
+ "@img/sharp-libvips-linuxmusl-arm64": "1.0.4",
+ "@img/sharp-libvips-linuxmusl-x64": "1.0.4",
+ "@img/sharp-linux-arm": "0.33.5",
+ "@img/sharp-linux-arm64": "0.33.5",
+ "@img/sharp-linux-s390x": "0.33.5",
+ "@img/sharp-linux-x64": "0.33.5",
+ "@img/sharp-linuxmusl-arm64": "0.33.5",
+ "@img/sharp-linuxmusl-x64": "0.33.5",
+ "@img/sharp-wasm32": "0.33.5",
+ "@img/sharp-win32-ia32": "0.33.5",
+ "@img/sharp-win32-x64": "0.33.5"
}
},
"node_modules/shebang-command": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz",
"integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==",
- "dev": true,
"dependencies": {
"shebang-regex": "^3.0.0"
},
@@ -6880,7 +7018,6 @@
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz",
"integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==",
- "dev": true,
"engines": {
"node": ">=8"
}
@@ -6918,49 +7055,6 @@
"integrity": "sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==",
"dev": true
},
- "node_modules/simple-concat": {
- "version": "1.0.1",
- "resolved": "https://registry.npmjs.org/simple-concat/-/simple-concat-1.0.1.tgz",
- "integrity": "sha512-cSFtAPtRhljv69IK0hTVZQ+OfE9nePi/rtJmw5UjHeVyVroEqJXP1sFztKUy1qU+xvz3u/sfYJLa947b7nAN2Q==",
- "funding": [
- {
- "type": "github",
- "url": "https://github.com/sponsors/feross"
- },
- {
- "type": "patreon",
- "url": "https://www.patreon.com/feross"
- },
- {
- "type": "consulting",
- "url": "https://feross.org/support"
- }
- ]
- },
- "node_modules/simple-get": {
- "version": "4.0.1",
- "resolved": "https://registry.npmjs.org/simple-get/-/simple-get-4.0.1.tgz",
- "integrity": "sha512-brv7p5WgH0jmQJr1ZDDfKDOSeWWg+OVypG99A/5vYGPqJ6pxiaHLy8nxtFjBA7oMa01ebA9gfh1uMCFqOuXxvA==",
- "funding": [
- {
- "type": "github",
- "url": "https://github.com/sponsors/feross"
- },
- {
- "type": "patreon",
- "url": "https://www.patreon.com/feross"
- },
- {
- "type": "consulting",
- "url": "https://feross.org/support"
- }
- ],
- "dependencies": {
- "decompress-response": "^6.0.0",
- "once": "^1.3.1",
- "simple-concat": "^1.0.0"
- }
- },
"node_modules/simple-swizzle": {
"version": "0.2.2",
"resolved": "https://registry.npmjs.org/simple-swizzle/-/simple-swizzle-0.2.2.tgz",
@@ -6975,18 +7069,6 @@
"integrity": "sha512-bLGGlR1QxBcynn2d5YmDX4MGjlZvy2MRBDRNHLJ8VI6l6+9FUiyTFNJ0IveOSP0bcXgVDPRcfGqA0pjaqUpfVg==",
"dev": true
},
- "node_modules/slash": {
- "version": "4.0.0",
- "resolved": "https://registry.npmjs.org/slash/-/slash-4.0.0.tgz",
- "integrity": "sha512-3dOsAHXXUkQTpOYcoAxLIorMTp4gIQr5IW3iVb7A7lFIp0VHhnynm9izx6TssdrIcVIESAlVjtnO2K8bg+Coew==",
- "dev": true,
- "engines": {
- "node": ">=12"
- },
- "funding": {
- "url": "https://github.com/sponsors/sindresorhus"
- }
- },
"node_modules/sockjs": {
"version": "0.3.24",
"resolved": "https://registry.npmjs.org/sockjs/-/sockjs-0.3.24.tgz",
@@ -7184,19 +7266,11 @@
"node": ">=0.10.0"
}
},
- "node_modules/streamx": {
- "version": "2.15.5",
- "resolved": "https://registry.npmjs.org/streamx/-/streamx-2.15.5.tgz",
- "integrity": "sha512-9thPGMkKC2GctCzyCUjME3yR03x2xNo0GPKGkRw2UMYN+gqWa9uqpyNWhmsNCutU5zHmkUum0LsCRQTXUgUCAg==",
- "dependencies": {
- "fast-fifo": "^1.1.0",
- "queue-tick": "^1.0.1"
- }
- },
"node_modules/string_decoder": {
"version": "1.3.0",
"resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.3.0.tgz",
"integrity": "sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA==",
+ "dev": true,
"dependencies": {
"safe-buffer": "~5.2.0"
}
@@ -7218,7 +7292,20 @@
"version": "4.2.3",
"resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz",
"integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==",
- "dev": true,
+ "dependencies": {
+ "emoji-regex": "^8.0.0",
+ "is-fullwidth-code-point": "^3.0.0",
+ "strip-ansi": "^6.0.1"
+ },
+ "engines": {
+ "node": ">=8"
+ }
+ },
+ "node_modules/string-width-cjs": {
+ "name": "string-width",
+ "version": "4.2.3",
+ "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz",
+ "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==",
"dependencies": {
"emoji-regex": "^8.0.0",
"is-fullwidth-code-point": "^3.0.0",
@@ -7232,7 +7319,18 @@
"version": "6.0.1",
"resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz",
"integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==",
- "dev": true,
+ "dependencies": {
+ "ansi-regex": "^5.0.1"
+ },
+ "engines": {
+ "node": ">=8"
+ }
+ },
+ "node_modules/strip-ansi-cjs": {
+ "name": "strip-ansi",
+ "version": "6.0.1",
+ "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz",
+ "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==",
"dependencies": {
"ansi-regex": "^5.0.1"
},
@@ -7258,14 +7356,6 @@
"node": ">=6"
}
},
- "node_modules/strip-json-comments": {
- "version": "2.0.1",
- "resolved": "https://registry.npmjs.org/strip-json-comments/-/strip-json-comments-2.0.1.tgz",
- "integrity": "sha512-4gB8na07fecVVkOI6Rs4e7T6NOTki5EmL7TUduTs6bu3EdnSycntVJ4re8kgZA+wx9IueI2Y11bfbgwtzuE0KQ==",
- "engines": {
- "node": ">=0.10.0"
- }
- },
"node_modules/supports-color": {
"version": "8.1.1",
"resolved": "https://registry.npmjs.org/supports-color/-/supports-color-8.1.1.tgz",
@@ -7330,30 +7420,42 @@
"node": ">=6"
}
},
- "node_modules/tar-fs": {
- "version": "2.1.1",
- "resolved": "https://registry.npmjs.org/tar-fs/-/tar-fs-2.1.1.tgz",
- "integrity": "sha512-V0r2Y9scmbDRLCNex/+hYzvp/zyYjvFbHPNgVTKfQvVrb6guiE/fxP+XblDNR011utopbkex2nM4dHNV6GDsng==",
+ "node_modules/tar": {
+ "version": "7.2.0",
+ "resolved": "https://registry.npmjs.org/tar/-/tar-7.2.0.tgz",
+ "integrity": "sha512-hctwP0Nb4AB60bj8WQgRYaMOuJYRAPMGiQUAotms5igN8ppfQM+IvjQ5HcKu1MaZh2Wy2KWVTe563Yj8dfc14w==",
"dependencies": {
- "chownr": "^1.1.1",
- "mkdirp-classic": "^0.5.2",
- "pump": "^3.0.0",
- "tar-stream": "^2.1.4"
+ "@isaacs/fs-minipass": "^4.0.0",
+ "chownr": "^3.0.0",
+ "minipass": "^7.1.0",
+ "minizlib": "^3.0.1",
+ "mkdirp": "^3.0.1",
+ "yallist": "^5.0.0"
+ },
+ "engines": {
+ "node": ">=18"
}
},
- "node_modules/tar-stream": {
- "version": "2.2.0",
- "resolved": "https://registry.npmjs.org/tar-stream/-/tar-stream-2.2.0.tgz",
- "integrity": "sha512-ujeqbceABgwMZxEJnk2HDY2DlnUZ+9oEcb1KzTVfYHio0UE6dG71n60d8D2I4qNvleWrrXpmjpt7vZeF1LnMZQ==",
- "dependencies": {
- "bl": "^4.0.3",
- "end-of-stream": "^1.4.1",
- "fs-constants": "^1.0.0",
- "inherits": "^2.0.3",
- "readable-stream": "^3.1.1"
+ "node_modules/tar/node_modules/mkdirp": {
+ "version": "3.0.1",
+ "resolved": "https://registry.npmjs.org/mkdirp/-/mkdirp-3.0.1.tgz",
+ "integrity": "sha512-+NsyUUAZDmo6YVHzL/stxSu3t9YS1iljliy3BSDrXJ/dkn1KYdmtZODGGjLcc9XLgVVpH4KshHB8XmZgMhaBXg==",
+ "bin": {
+ "mkdirp": "dist/cjs/src/bin.js"
},
"engines": {
- "node": ">=6"
+ "node": ">=10"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/isaacs"
+ }
+ },
+ "node_modules/tar/node_modules/yallist": {
+ "version": "5.0.0",
+ "resolved": "https://registry.npmjs.org/yallist/-/yallist-5.0.0.tgz",
+ "integrity": "sha512-YgvUTfwqyc7UXVMrB+SImsVYSmTS8X/tSrtdNZMImM+n7+QTriRXyXim0mBrTXNeqzVF0KWGgHPeiyViFFrNDw==",
+ "engines": {
+ "node": ">=18"
}
},
"node_modules/temp-path": {
@@ -7363,13 +7465,13 @@
"dev": true
},
"node_modules/terser": {
- "version": "5.17.1",
- "resolved": "https://registry.npmjs.org/terser/-/terser-5.17.1.tgz",
- "integrity": "sha512-hVl35zClmpisy6oaoKALOpS0rDYLxRFLHhRuDlEGTKey9qHjS1w9GMORjuwIMt70Wan4lwsLYyWDVnWgF+KUEw==",
+ "version": "5.31.6",
+ "resolved": "https://registry.npmjs.org/terser/-/terser-5.31.6.tgz",
+ "integrity": "sha512-PQ4DAriWzKj+qgehQ7LK5bQqCFNMmlhjR2PFFLuqGCpuCAauxemVBWwWOxo3UIwWQx8+Pr61Df++r76wDmkQBg==",
"dev": true,
"dependencies": {
- "@jridgewell/source-map": "^0.3.2",
- "acorn": "^8.5.0",
+ "@jridgewell/source-map": "^0.3.3",
+ "acorn": "^8.8.2",
"commander": "^2.20.0",
"source-map-support": "~0.5.20"
},
@@ -7381,16 +7483,16 @@
}
},
"node_modules/terser-webpack-plugin": {
- "version": "5.3.7",
- "resolved": "https://registry.npmjs.org/terser-webpack-plugin/-/terser-webpack-plugin-5.3.7.tgz",
- "integrity": "sha512-AfKwIktyP7Cu50xNjXF/6Qb5lBNzYaWpU6YfoX3uZicTx0zTy0stDDCsvjDapKsSDvOeWo5MEq4TmdBy2cNoHw==",
+ "version": "5.3.10",
+ "resolved": "https://registry.npmjs.org/terser-webpack-plugin/-/terser-webpack-plugin-5.3.10.tgz",
+ "integrity": "sha512-BKFPWlPDndPs+NGGCr1U59t0XScL5317Y0UReNrHaw9/FwhPENlq6bfgs+4yPfyP51vqC1bQ4rp1EfXW5ZSH9w==",
"dev": true,
"dependencies": {
- "@jridgewell/trace-mapping": "^0.3.17",
+ "@jridgewell/trace-mapping": "^0.3.20",
"jest-worker": "^27.4.5",
"schema-utils": "^3.1.1",
"serialize-javascript": "^6.0.1",
- "terser": "^5.16.5"
+ "terser": "^5.26.0"
},
"engines": {
"node": ">= 10.13.0"
@@ -7495,16 +7597,11 @@
"node": ">=0.6"
}
},
- "node_modules/tunnel-agent": {
- "version": "0.6.0",
- "resolved": "https://registry.npmjs.org/tunnel-agent/-/tunnel-agent-0.6.0.tgz",
- "integrity": "sha512-McnNiV1l8RYeY8tBgEpuodCC1mLUdbSN+CYBL7kJsJNInOP8UjDDEwdk6Mw60vdLLrr5NHKZhMAOSrR2NZuQ+w==",
- "dependencies": {
- "safe-buffer": "^5.0.1"
- },
- "engines": {
- "node": "*"
- }
+ "node_modules/tslib": {
+ "version": "2.6.3",
+ "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.3.tgz",
+ "integrity": "sha512-xNvxJEOUiWPGhUuUdQgAJPKOOJfGnIyKySOc09XkKsgdUV/3E2zvwZYdejjmRgPCgcym1juLH3226yA7sEFJKQ==",
+ "optional": true
},
"node_modules/type-detect": {
"version": "4.0.8",
@@ -7594,9 +7691,9 @@
}
},
"node_modules/update-browserslist-db": {
- "version": "1.0.11",
- "resolved": "https://registry.npmjs.org/update-browserslist-db/-/update-browserslist-db-1.0.11.tgz",
- "integrity": "sha512-dCwEFf0/oT85M1fHBg4F0jtLwJrutGoHSQXCh7u4o2t1drG+c0a9Flnqww6XUKSfQMPpJBRjU8d4RXB09qtvaA==",
+ "version": "1.1.0",
+ "resolved": "https://registry.npmjs.org/update-browserslist-db/-/update-browserslist-db-1.1.0.tgz",
+ "integrity": "sha512-EdRAaAyk2cUE1wOf2DkEhzxqOQvFOoRJFNS6NeyJ01Gp2beMRpBAINjM2iDXE3KCuKhwnvHIQCJm6ThL2Z+HzQ==",
"dev": true,
"funding": [
{
@@ -7613,8 +7710,8 @@
}
],
"dependencies": {
- "escalade": "^3.1.1",
- "picocolors": "^1.0.0"
+ "escalade": "^3.1.2",
+ "picocolors": "^1.0.1"
},
"bin": {
"update-browserslist-db": "cli.js"
@@ -7635,7 +7732,8 @@
"node_modules/util-deprecate": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz",
- "integrity": "sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw=="
+ "integrity": "sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==",
+ "dev": true
},
"node_modules/utils-merge": {
"version": "1.0.1",
@@ -7703,9 +7801,9 @@
}
},
"node_modules/watchpack": {
- "version": "2.4.0",
- "resolved": "https://registry.npmjs.org/watchpack/-/watchpack-2.4.0.tgz",
- "integrity": "sha512-Lcvm7MGST/4fup+ifyKi2hjyIAwcdI4HRgtvTpIUxBRhB+RFtUh8XtDOxUfctVCnhVi+QQj49i91OyvzkJl6cg==",
+ "version": "2.4.2",
+ "resolved": "https://registry.npmjs.org/watchpack/-/watchpack-2.4.2.tgz",
+ "integrity": "sha512-TnbFSbcOCcDgjZ4piURLCbJ3nJhznVh9kw6F6iokjiFPl8ONxe9A6nMDVXDiNbrSfLILs6vB07F7wLBrwPYzJw==",
"dev": true,
"dependencies": {
"glob-to-regexp": "^0.4.1",
@@ -7737,34 +7835,33 @@
}
},
"node_modules/webpack": {
- "version": "5.80.0",
- "resolved": "https://registry.npmjs.org/webpack/-/webpack-5.80.0.tgz",
- "integrity": "sha512-OIMiq37XK1rWO8mH9ssfFKZsXg4n6klTEDL7S8/HqbAOBBaiy8ABvXvz0dDCXeEF9gqwxSvVk611zFPjS8hJxA==",
+ "version": "5.94.0",
+ "resolved": "https://registry.npmjs.org/webpack/-/webpack-5.94.0.tgz",
+ "integrity": "sha512-KcsGn50VT+06JH/iunZJedYGUJS5FGjow8wb9c0v5n1Om8O1g4L6LjtfxwlXIATopoQu+vOXXa7gYisWxCoPyg==",
"dev": true,
"dependencies": {
- "@types/eslint-scope": "^3.7.3",
- "@types/estree": "^1.0.0",
- "@webassemblyjs/ast": "^1.11.5",
- "@webassemblyjs/wasm-edit": "^1.11.5",
- "@webassemblyjs/wasm-parser": "^1.11.5",
+ "@types/estree": "^1.0.5",
+ "@webassemblyjs/ast": "^1.12.1",
+ "@webassemblyjs/wasm-edit": "^1.12.1",
+ "@webassemblyjs/wasm-parser": "^1.12.1",
"acorn": "^8.7.1",
- "acorn-import-assertions": "^1.7.6",
- "browserslist": "^4.14.5",
+ "acorn-import-attributes": "^1.9.5",
+ "browserslist": "^4.21.10",
"chrome-trace-event": "^1.0.2",
- "enhanced-resolve": "^5.13.0",
+ "enhanced-resolve": "^5.17.1",
"es-module-lexer": "^1.2.1",
"eslint-scope": "5.1.1",
"events": "^3.2.0",
"glob-to-regexp": "^0.4.1",
- "graceful-fs": "^4.2.9",
+ "graceful-fs": "^4.2.11",
"json-parse-even-better-errors": "^2.3.1",
"loader-runner": "^4.2.0",
"mime-types": "^2.1.27",
"neo-async": "^2.6.2",
- "schema-utils": "^3.1.2",
+ "schema-utils": "^3.2.0",
"tapable": "^2.1.1",
- "terser-webpack-plugin": "^5.3.7",
- "watchpack": "^2.4.0",
+ "terser-webpack-plugin": "^5.3.10",
+ "watchpack": "^2.4.1",
"webpack-sources": "^3.2.3"
},
"bin": {
@@ -8074,7 +8171,6 @@
"version": "2.0.2",
"resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz",
"integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==",
- "dev": true,
"dependencies": {
"isexe": "^2.0.0"
},
@@ -8136,10 +8232,28 @@
"url": "https://github.com/chalk/wrap-ansi?sponsor=1"
}
},
+ "node_modules/wrap-ansi-cjs": {
+ "name": "wrap-ansi",
+ "version": "7.0.0",
+ "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-7.0.0.tgz",
+ "integrity": "sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==",
+ "dependencies": {
+ "ansi-styles": "^4.0.0",
+ "string-width": "^4.1.0",
+ "strip-ansi": "^6.0.0"
+ },
+ "engines": {
+ "node": ">=10"
+ },
+ "funding": {
+ "url": "https://github.com/chalk/wrap-ansi?sponsor=1"
+ }
+ },
"node_modules/wrappy": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz",
- "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ=="
+ "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==",
+ "dev": true
},
"node_modules/write-file-atomic": {
"version": "4.0.2",
@@ -8190,11 +8304,6 @@
"node": ">=10"
}
},
- "node_modules/yallist": {
- "version": "4.0.0",
- "resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz",
- "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A=="
- },
"node_modules/yargs": {
"version": "17.7.2",
"resolved": "https://registry.npmjs.org/yargs/-/yargs-17.7.2.tgz",
diff --git a/package.json b/package.json
index 224682fb9..cb64c59e7 100644
--- a/package.json
+++ b/package.json
@@ -1,24 +1,47 @@
{
- "name": "@xenova/transformers",
- "version": "2.17.2",
+ "name": "@huggingface/transformers",
+ "version": "3.0.0",
"description": "State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!",
"main": "./src/transformers.js",
"types": "./types/transformers.d.ts",
"type": "module",
+ "exports": {
+ "node": {
+ "import": {
+ "types": "./types/transformers.d.ts",
+ "default": "./dist/transformers.mjs"
+ },
+ "require": {
+ "types": "./types/transformers.d.ts",
+ "default": "./dist/transformers.cjs"
+ }
+ },
+ "default": {
+ "types": "./types/transformers.d.ts",
+ "default": "./dist/transformers.js"
+ }
+ },
+ "imports": {
+ "#onnxruntime-webgpu": {
+ "node": "onnxruntime-web",
+ "default": "onnxruntime-web/webgpu"
+ }
+ },
"scripts": {
+ "format": "prettier --write .",
+ "format:check": "prettier --check .",
"typegen": "tsc ./src/transformers.js --allowJs --declaration --emitDeclarationOnly --declarationMap --outDir types",
"dev": "webpack serve --no-client-overlay",
"build": "webpack && npm run typegen",
- "generate-tests": "python -m tests.generate_tests",
- "test": "node --experimental-vm-modules node_modules/jest/bin/jest.js --verbose --maxConcurrency 1",
+ "test": "node --experimental-vm-modules node_modules/jest/bin/jest.js --verbose",
"readme": "python ./docs/scripts/build_readme.py",
"docs-api": "node ./docs/scripts/generate.js",
"docs-preview": "doc-builder preview transformers.js ./docs/source/ --not_python_module",
- "docs-build": "doc-builder build transformers.js ./docs/source/ --not_python_module --build_dir ./docs/build/ --repo_owner xenova"
+ "docs-build": "doc-builder build transformers.js ./docs/source/ --not_python_module --build_dir ./docs/build/"
},
"repository": {
"type": "git",
- "url": "git+https://github.com/xenova/transformers.js.git"
+ "url": "git+https://github.com/huggingface/transformers.js.git"
},
"keywords": [
"transformers",
@@ -31,37 +54,32 @@
"AI",
"ML"
],
- "author": "Xenova",
+ "author": "Hugging Face",
"license": "Apache-2.0",
"bugs": {
- "url": "https://github.com/xenova/transformers.js/issues"
+ "url": "https://github.com/huggingface/transformers.js/issues"
},
- "homepage": "https://github.com/xenova/transformers.js#readme",
+ "homepage": "https://github.com/huggingface/transformers.js#readme",
"dependencies": {
- "onnxruntime-web": "1.14.0",
- "sharp": "^0.32.0",
- "@huggingface/jinja": "^0.2.2"
- },
- "optionalDependencies": {
- "onnxruntime-node": "1.14.0"
+ "@huggingface/jinja": "^0.3.0",
+ "onnxruntime-node": "1.19.2",
+ "onnxruntime-web": "1.20.0-dev.20241016-2b8fc5529b",
+ "sharp": "^0.33.5"
},
"devDependencies": {
"@types/jest": "^29.5.1",
+ "@webgpu/types": "^0.1.44",
"catharsis": "github:xenova/catharsis",
- "copy-webpack-plugin": "^11.0.0",
"jest": "^29.5.0",
"jest-environment-node": "^29.5.0",
"jsdoc-to-markdown": "^8.0.1",
+ "prettier": "3.3.3",
"typescript": "^5.2.2",
"wavefile": "^11.0.0",
"webpack": "^5.80.0",
"webpack-cli": "^5.0.2",
"webpack-dev-server": "^4.13.3"
},
- "overrides": {
- "semver": "^7.5.4",
- "protobufjs": "^7.2.6"
- },
"files": [
"src",
"dist",
diff --git a/scripts/convert.py b/scripts/convert.py
index 3a2b223e8..bf9265e48 100644
--- a/scripts/convert.py
+++ b/scripts/convert.py
@@ -2,9 +2,9 @@
import json
import os
import shutil
-from dataclasses import dataclass, field
-from typing import Optional, Set
-from tqdm import tqdm
+from dataclasses import dataclass, field, asdict
+from typing import Optional
+from enum import Enum
from transformers import (
AutoConfig,
@@ -12,117 +12,46 @@
HfArgumentParser
)
-import onnx
+import onnxslim
from optimum.exporters.onnx import main_export, export_models
+from optimum.onnx.graph_transformations import check_and_save_model
from optimum.exporters.tasks import TasksManager
-from onnxruntime.quantization import (
- quantize_dynamic,
- QuantType
-)
-DEFAULT_QUANTIZE_PARAMS = {
- 'per_channel': True,
- 'reduce_range': True,
-}
+from .quantize import QuantizationArguments, quantize
-MODEL_SPECIFIC_QUANTIZE_PARAMS = {
+NO_PER_CHANNEL_REDUCE_RANGE_MODELS = {
# Decoder-only models
- 'codegen': {
- 'per_channel': False,
- 'reduce_range': False,
- },
- 'gpt2': {
- 'per_channel': False,
- 'reduce_range': False,
- },
- 'gpt_bigcode': {
- 'per_channel': False,
- 'reduce_range': False,
- },
- 'gptj': {
- 'per_channel': False,
- 'reduce_range': False,
- },
- 'gpt-neo': {
- 'per_channel': False,
- 'reduce_range': False,
- },
- 'gpt-neox': {
- 'per_channel': False,
- 'reduce_range': False,
- },
- 'mpt': {
- 'per_channel': False,
- 'reduce_range': False,
- },
- 'bloom': {
- 'per_channel': False,
- 'reduce_range': False,
- },
- 'llama': {
- 'per_channel': False,
- 'reduce_range': False,
- },
- 'opt': {
- 'per_channel': False,
- 'reduce_range': False,
- },
- 'mistral': {
- 'per_channel': False,
- 'reduce_range': False,
- },
- 'falcon': {
- 'per_channel': False,
- 'reduce_range': False,
- },
- 'phi': {
- 'per_channel': False,
- 'reduce_range': False,
- },
- 'qwen2': {
- 'per_channel': False,
- 'reduce_range': False,
- },
- 'stablelm': {
- 'per_channel': False,
- 'reduce_range': False,
- },
- 'starcoder2': {
- 'per_channel': False,
- 'reduce_range': False,
- },
+ 'codegen',
+ 'gpt2',
+ 'gpt_bigcode',
+ 'gptj',
+ 'gpt-neo',
+ 'gpt-neox',
+ 'mpt',
+ 'bloom',
+ 'llama',
+ 'gemma',
+ 'opt',
+ 'mistral',
+ 'falcon',
+ 'phi',
+ 'phi3',
+ 'qwen2',
+ 'stablelm',
+ 'starcoder2',
+ 'openelm',
+ 'gemma',
# Encoder-decoder models
- 'whisper': {
- 'per_channel': False,
- 'reduce_range': False,
- },
- 'vision-encoder-decoder': {
- 'per_channel': False,
- 'reduce_range': False,
- },
+ 'whisper',
+ 'vision-encoder-decoder',
# Encoder-only models
- 'owlv2': {
- 'per_channel': False,
- 'reduce_range': False,
- },
- 'wavlm': {
- 'per_channel': False,
- 'reduce_range': False,
- },
- 'wav2vec2': {
- 'per_channel': False,
- 'reduce_range': False,
- },
- 'unispeech': {
- 'per_channel': False,
- 'reduce_range': False,
- },
- 'unispeech-sat': {
- 'per_channel': False,
- 'reduce_range': False,
- },
+ 'owlv2',
+ 'wavlm',
+ 'wav2vec2',
+ 'unispeech',
+ 'unispeech-sat',
}
MODELS_WITHOUT_TOKENIZERS = [
@@ -135,6 +64,16 @@
]
+class QuantMode(Enum):
+ # F32 = 'fp32'
+ FP16 = 'fp16'
+ Q8 = 'q8'
+ QI8 = 'int8'
+ QU8 = 'uint8'
+ Q4 = 'q4'
+ BNB4 = 'bnb4'
+
+
@dataclass
class ConversionArguments:
"""
@@ -174,7 +113,22 @@ class ConversionArguments:
)
}
)
+ library_name: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": (
+ "The library name to use for the export. If not specified, the library name will be auto-inferred based on the model."
+ )
+ }
+ )
+
+ variant: Optional[str] = field(
+ default='default',
+ metadata={
+ "help": "The variant of the ONNX export to use."
+ }
+ )
opset: int = field(
default=None,
metadata={
@@ -197,19 +151,6 @@ class ConversionArguments:
}
)
- per_channel: bool = field(
- default=None,
- metadata={
- "help": "Whether to quantize weights per channel"
- }
- )
- reduce_range: bool = field(
- default=None,
- metadata={
- "help": "Whether to quantize weights with 7-bits. It may improve the accuracy for some models running on non-VNNI machine, especially for per-channel mode"
- }
- )
-
output_attentions: bool = field(
default=False,
metadata={
@@ -239,90 +180,19 @@ class ConversionArguments:
"that desire a finer-grained control on the export."
}
)
-
-
-def get_operators(model: onnx.ModelProto) -> Set[str]:
- operators = set()
-
- def traverse_graph(graph):
- for node in graph.node:
- operators.add(node.op_type)
- for attr in node.attribute:
- if attr.type == onnx.AttributeProto.GRAPH:
- subgraph = attr.g
- traverse_graph(subgraph)
-
- traverse_graph(model.graph)
- return operators
-
-
-def quantize(model_names_or_paths, **quantize_kwargs):
- """
- Quantize the weights of the model from float32 to int8 to allow very efficient inference on modern CPU
-
- Uses unsigned ints for activation values, signed ints for weights, per
- https://onnxruntime.ai/docs/performance/quantization.html#data-type-selection
- it is faster on most CPU architectures
- Args:
- onnx_model_path: Path to location the exported ONNX model is stored
- Returns: The Path generated for the quantized
- """
-
- quantize_config = dict(
- **quantize_kwargs,
- per_model_config={}
+ skip_onnxslim: bool = field(
+ default=False,
+ metadata={
+ "help": "Whether or not to skip onnxslim."
+ }
)
- for model in tqdm(model_names_or_paths, desc='Quantizing'):
- directory_path = os.path.dirname(model)
- file_name_without_extension = os.path.splitext(
- os.path.basename(model))[0]
-
- # NOTE:
- # As of 2023/04/20, the current latest version of onnxruntime-web is 1.14.0, and does not support INT8 weights for Conv layers.
- # For this reason, we choose model weight types to ensure compatibility with onnxruntime-web.
- #
- # As per docs, signed weight type (QInt8) is faster on most CPUs, so, we use that unless the model contains a Conv layer.
- # For more information, see:
- # - https://github.com/microsoft/onnxruntime/issues/3130#issuecomment-1105200621
- # - https://github.com/microsoft/onnxruntime/issues/2339
-
- loaded_model = onnx.load_model(model)
- op_types = get_operators(loaded_model)
- weight_type = QuantType.QUInt8 if 'Conv' in op_types else QuantType.QInt8
-
- quantize_dynamic(
- model_input=model,
- model_output=os.path.join(
- directory_path, f'{file_name_without_extension}_quantized.onnx'),
-
- weight_type=weight_type,
- optimize_model=False,
-
- # TODO allow user to specify these
- # op_types_to_quantize=['MatMul', 'Add', 'Conv'],
- extra_options=dict(
- EnableSubgraph=True
- ),
- **quantize_kwargs
- )
-
- quantize_config['per_model_config'][file_name_without_extension] = dict(
- op_types=list(op_types),
- weight_type=str(weight_type),
- )
-
- # Save quantization config
- with open(os.path.join(directory_path, 'quantize_config.json'), 'w') as fp:
- json.dump(quantize_config, fp, indent=4)
-
-
def main():
parser = HfArgumentParser(
- (ConversionArguments, )
+ (ConversionArguments, QuantizationArguments)
)
- conv_args, = parser.parse_args_into_dataclasses()
+ conv_args, quantization_args = parser.parse_args_into_dataclasses()
model_id = conv_args.model_id
tokenizer_id = conv_args.tokenizer_id or model_id
@@ -339,30 +209,38 @@ def main():
# Saving the model config
config = AutoConfig.from_pretrained(model_id, **from_pretrained_kwargs)
- custom_kwargs={}
+ custom_kwargs = {}
if conv_args.custom_onnx_configs is not None:
if conv_args.task == 'auto':
- raise Exception('`--task` must be set when exporting with `--custom_onnx_configs`')
+ raise Exception(
+ '`--task` must be set when exporting with `--custom_onnx_configs`')
custom_onnx_configs = json.loads(conv_args.custom_onnx_configs)
for key in custom_onnx_configs:
onnx_configs = TasksManager._SUPPORTED_MODEL_TYPE[custom_onnx_configs[key]]['onnx']
mapping = onnx_configs[conv_args.task]
- custom_onnx_configs[key] = mapping.func(config, **mapping.keywords)
+ new_kwargs = {}
+ if conv_args.task.startswith('text-generation'):
+ new_kwargs['use_past_in_inputs'] = True
+
+ custom_onnx_configs[key] = mapping.func(
+ config, **mapping.keywords, **new_kwargs)
custom_kwargs['custom_onnx_configs'] = custom_onnx_configs
tokenizer = None
try:
# Load tokenizer
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, **from_pretrained_kwargs)
+ tokenizer = AutoTokenizer.from_pretrained(
+ tokenizer_id, **from_pretrained_kwargs)
# To avoid inserting all chat templates into tokenizers.js, we save the chat template
# to the tokenizer_config.json file, and load it when the tokenizer is loaded.
if getattr(tokenizer, 'chat_template', None) is None and \
- getattr(tokenizer, 'use_default_system_prompt', False):
+ getattr(tokenizer, 'use_default_system_prompt', False):
# No chat template specified, and we use the default
- setattr(tokenizer, 'chat_template', tokenizer.default_chat_template)
+ setattr(tokenizer, 'chat_template',
+ tokenizer.default_chat_template)
except KeyError:
pass # No Tokenizer
@@ -383,7 +261,8 @@ def main():
output=output_model_folder,
task=conv_args.task,
do_validation=not conv_args.skip_validation,
- library_name='transformers',
+ _variant=conv_args.variant,
+ library_name=conv_args.library_name,
**core_export_kwargs,
)
@@ -398,7 +277,8 @@ def main():
elif config.model_type == 'esm':
from .extra.esm import generate_fast_tokenizer
fast_tokenizer = generate_fast_tokenizer(tokenizer)
- fast_tokenizer.save(os.path.join(output_model_folder, 'tokenizer.json'))
+ fast_tokenizer.save(os.path.join(
+ output_model_folder, 'tokenizer.json'))
elif config.model_type == 'whisper':
if conv_args.output_attentions:
@@ -408,14 +288,14 @@ def main():
**get_main_export_kwargs(config, "automatic-speech-recognition")
)
- elif config.model_type in ('wav2vec2', 'wav2vec2-bert', 'hubert', 'unispeech' , 'unispeech-sat'):
+ elif config.model_type in ('wav2vec2', 'wav2vec2-bert', 'hubert', 'unispeech', 'unispeech-sat'):
if tokenizer is not None:
from .extra.wav2vec2 import generate_tokenizer_json
tokenizer_json = generate_tokenizer_json(tokenizer)
with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp:
json.dump(tokenizer_json, fp, indent=4)
-
+
elif config.model_type == 'vits':
if tokenizer is not None:
from .extra.vits import generate_tokenizer_json
@@ -423,10 +303,11 @@ def main():
with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp:
json.dump(tokenizer_json, fp, indent=4)
-
+
elif config.model_type == 'speecht5':
# TODO allow user to specify vocoder path
- export_kwargs["model_kwargs"] = {"vocoder": "microsoft/speecht5_hifigan"}
+ export_kwargs["model_kwargs"] = {
+ "vocoder": "microsoft/speecht5_hifigan"}
if tokenizer is not None:
from .extra.speecht5 import generate_tokenizer_json
@@ -440,6 +321,26 @@ def main():
# For more information, see https://github.com/huggingface/optimum/blob/e3b7efb1257c011db907ef40ab340e795cc5684c/optimum/exporters/onnx/model_configs.py#L1028-L1032
export_kwargs['batch_size'] = 1
+ elif config.model_type == 'openelm':
+ from .extra.openelm import OpenElmOnnxConfig
+
+ config = AutoConfig.from_pretrained(
+ model_id, trust_remote_code=conv_args.trust_remote_code)
+
+ onnx_config = OpenElmOnnxConfig(
+ config=config,
+ task="text-generation",
+ use_past=True,
+ use_past_in_inputs=True,
+ )
+
+ custom_onnx_configs = {
+ "model": onnx_config,
+ }
+
+ export_kwargs['task'] = "text-generation-with-past"
+ export_kwargs['custom_onnx_configs'] = custom_onnx_configs
+
else:
pass # TODO
@@ -457,8 +358,10 @@ def main():
from .extra.clip import CLIPTextModelWithProjectionOnnxConfig, CLIPVisionModelWithProjectionOnnxConfig
from transformers.models.clip import CLIPTextModelWithProjection, CLIPVisionModelWithProjection
- text_model = CLIPTextModelWithProjection.from_pretrained(model_id, **from_pretrained_kwargs)
- vision_model = CLIPVisionModelWithProjection.from_pretrained(model_id, **from_pretrained_kwargs)
+ text_model = CLIPTextModelWithProjection.from_pretrained(
+ model_id, **from_pretrained_kwargs)
+ vision_model = CLIPVisionModelWithProjection.from_pretrained(
+ model_id, **from_pretrained_kwargs)
export_models(
models_and_onnx_configs={
@@ -473,8 +376,10 @@ def main():
from .extra.siglip import SiglipTextModelOnnxConfig, SiglipVisionModelOnnxConfig
from transformers.models.siglip import SiglipTextModel, SiglipVisionModel
- text_model = SiglipTextModel.from_pretrained(model_id, **from_pretrained_kwargs)
- vision_model = SiglipVisionModel.from_pretrained(model_id, **from_pretrained_kwargs)
+ text_model = SiglipTextModel.from_pretrained(
+ model_id, **from_pretrained_kwargs)
+ vision_model = SiglipVisionModel.from_pretrained(
+ model_id, **from_pretrained_kwargs)
export_models(
models_and_onnx_configs={
@@ -500,32 +405,43 @@ def main():
# },
# **custom_export_kwargs,
# )
-
else:
- raise Exception(f'Unable to export {config.model_type} model with `--split_modalities`.')
+ raise Exception(
+ f'Unable to export {config.model_type} model with `--split_modalities`.')
+
+ os.makedirs(os.path.join(output_model_folder, 'onnx'), exist_ok=True)
+ if not conv_args.skip_onnxslim:
+ onnx_models = [os.path.join(output_model_folder, x)
+ for x in os.listdir(output_model_folder) if x.endswith('.onnx')]
+
+ for model in onnx_models:
+ try:
+ slimmed_model = onnxslim.slim(model)
+ check_and_save_model(slimmed_model, model)
+ except Exception as e:
+ print(f"Failed to slim {model}: {e}")
# Step 2. (optional, recommended) quantize the converted model for fast inference and to reduce model size.
if conv_args.quantize:
- # Update quantize config with model specific defaults
- quantize_config = MODEL_SPECIFIC_QUANTIZE_PARAMS.get(
- config.model_type, DEFAULT_QUANTIZE_PARAMS)
- # Update if user specified values
- if conv_args.per_channel is not None:
- quantize_config['per_channel'] = conv_args.per_channel
+ # Possibly update quantize config with model specific defaults
+ use_per_channel_reduce_range = config.model_type not in NO_PER_CHANNEL_REDUCE_RANGE_MODELS
- if conv_args.reduce_range is not None:
- quantize_config['reduce_range'] = conv_args.reduce_range
+ if quantization_args.per_channel is None:
+ quantization_args.per_channel = use_per_channel_reduce_range
+ if quantization_args.reduce_range is None:
+ quantization_args.reduce_range = use_per_channel_reduce_range
- quantize([
- os.path.join(output_model_folder, x)
- for x in os.listdir(output_model_folder)
- if x.endswith('.onnx') and not x.endswith('_quantized.onnx')
- ], **quantize_config)
+ quantize(
+ output_model_folder,
+ os.path.join(output_model_folder, 'onnx'),
+ quantization_args,
+ )
+ with open(os.path.join(output_model_folder, 'quantize_config.json'), 'w') as fp:
+ json.dump(asdict(quantization_args), fp, indent=4)
# Step 3. Move .onnx files to the 'onnx' subfolder
- os.makedirs(os.path.join(output_model_folder, 'onnx'), exist_ok=True)
for file in os.listdir(output_model_folder):
if file.endswith(('.onnx', '.onnx_data')):
shutil.move(os.path.join(output_model_folder, file),
@@ -536,7 +452,8 @@ def main():
from transformers import GenerationConfig
from .extra.whisper import get_alignment_heads
- generation_config = GenerationConfig.from_pretrained(model_id, **from_pretrained_kwargs)
+ generation_config = GenerationConfig.from_pretrained(
+ model_id, **from_pretrained_kwargs)
generation_config.alignment_heads = get_alignment_heads(config)
generation_config.save_pretrained(output_model_folder)
diff --git a/scripts/extra/marian.py b/scripts/extra/marian.py
index e5f370021..ef9bd279d 100644
--- a/scripts/extra/marian.py
+++ b/scripts/extra/marian.py
@@ -1,61 +1,6 @@
import json
from transformers.utils import cached_file
-# NOTE: In total, there are 1440 models available on the HuggingFace hub (https://huggingface.co/Helsinki-NLP).
-# We have converted some of these (listed below). If you don't see your model here, feel free to convert it yourself
-# and make a pull request to this repo.
-
-SUPPORTED_HELSINKI_NLP_MODELS = [
- 'en-es', 'es-en', # English <-> Spanish
- 'en-fr', 'fr-en', # English <-> French
- 'en-hi', 'hi-en', # English <-> Hindi
- 'en-de', 'de-en', # English <-> German
- 'en-ru', 'ru-en', # English <-> Russian
- 'en-it', 'it-en', # English <-> Italian
- 'en-ar', 'ar-en', # English <-> Arabic
- 'en-zh', 'zh-en', # English <-> Chinese
- 'en-sv', 'sv-en', # English <-> Swedish
- 'en-mul', 'mul-en', # English <-> Multilingual
- 'en-nl', 'nl-en', # English <-> Dutch
- 'en-fi', 'fi-en', # English <-> Finnish
- 'en-jap', 'jap-en', # English <-> Japanese
- 'en-cs', 'cs-en', # English <-> Czech
- 'en-vi', 'vi-en', # English <-> Vietnamese
- 'en-xh', 'xh-en', # English <-> Xhosa
- 'en-hu', 'hu-en', # English <-> Hungarian
- 'en-da', 'da-en', # English <-> Danish
- 'en-id', 'id-en', # English <-> Indonesia
- 'en-uk', 'uk-en', # English <-> Ukranian
- 'en-af', 'af-en', # English <-> Afrikaans
- 'en-ROMANCE', 'ROMANCE-en', # English <-> ROMANCE
- 'de-es', 'es-de', # German <-> Spanish
- 'fr-es', 'es-fr', # French <-> Spanish
- 'fr-de', 'de-fr', # French <-> German
- 'es-it', 'it-es', # Spanish <-> Italian
- 'es-ru', 'ru-es', # Spanish <-> Russian
- 'fr-ru', 'ru-fr', # French <-> Russian
- 'fr-ro', 'ro-fr', # French <-> Romanian
- 'uk-ru', 'ru-uk', # Ukranian <-> Russian
-
- 'it-fr', # Italian --> French
- 'en-ro', # English --> Romanian
- 'pl-en', # Poland --> English
- 'tr-en', # Turkey --> English
- 'ko-en', # Korean --> English
- 'bat-en', # Baltic --> English
- 'et-en', # Estonian --> English
- 'fi-de', # Finnish --> German
- 'gem-gem', # Germanic <-> Germanic
- 'gmw-gmw', # West Germanic <-> West Germanic
- 'da-de', # Danish <-> German
- 'ja-en', # Japanese --> English
- 'nl-fr', # Netherlands --> French
- 'no-de', # Norwegian --> German
- 'tc-big-tr-en', # Turkish --> English
- 'th-en', # Thai --> English
- 'en-cs', # English --> Czech
-]
-
def generate_tokenizer_json(model_path, tokenizer):
# Marian models use two separate tokenizers for source and target languages.
diff --git a/scripts/extra/openelm.py b/scripts/extra/openelm.py
new file mode 100644
index 000000000..28ce793ee
--- /dev/null
+++ b/scripts/extra/openelm.py
@@ -0,0 +1,64 @@
+import random
+from typing import Optional, Tuple
+
+from optimum.exporters.onnx.config import TextDecoderOnnxConfig
+from optimum.utils import NormalizedTextConfig, DummyInputGenerator, DEFAULT_DUMMY_SHAPES, DummyTextInputGenerator, NormalizedConfig
+
+class OpenElmDummyPastKeyValuesGenerator(DummyInputGenerator):
+
+ SUPPORTED_INPUT_NAMES = ("past_key_values", )
+
+ def __init__(
+ self,
+ task: str,
+ normalized_config: NormalizedTextConfig,
+ batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+ sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+ random_batch_size_range: Optional[Tuple[int, int]] = None,
+ random_sequence_length_range: Optional[Tuple[int, int]] = None,
+ **kwargs,
+ ):
+ self.num_layers = normalized_config.num_layers
+ self.num_kv_heads = normalized_config.num_kv_heads
+ self.num_query_heads = normalized_config.num_query_heads
+ self.head_dim = normalized_config.head_dim
+
+ self.hidden_size = normalized_config.model_dim
+ if random_batch_size_range:
+ low, high = random_batch_size_range
+ self.batch_size = random.randint(low, high)
+ else:
+ self.batch_size = batch_size
+ if random_sequence_length_range:
+ low, high = random_sequence_length_range
+ self.sequence_length = random.randint(low, high)
+ else:
+ self.sequence_length = sequence_length
+
+ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+ data = []
+ for i in range(self.num_layers):
+ kv_shape = (
+ self.batch_size,
+ self.num_kv_heads[i],
+ self.sequence_length,
+ self.head_dim,
+ )
+ data.append((
+ self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype),
+ self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype),
+ ))
+ return data
+
+
+class OpenElmOnnxConfig(TextDecoderOnnxConfig):
+ DEFAULT_ONNX_OPSET = 14
+
+ DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, OpenElmDummyPastKeyValuesGenerator)
+ DUMMY_PKV_GENERATOR_CLASS = OpenElmDummyPastKeyValuesGenerator
+ NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
+ num_kv_heads="num_kv_heads",
+ num_query_heads="num_query_heads",
+ num_layers="num_transformer_layers",
+ allow_new=True,
+ )
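
For context, here is a minimal sketch of how a custom config like `OpenElmOnnxConfig` plugs into Optimum's exporter outside of `scripts/convert.py`. The `main_export` call, the output folder, and the checkpoint id below are illustrative assumptions, not code from this patch:

```python
# Hedged sketch: exporting an OpenELM checkpoint with the custom ONNX config above.
# Assumes Optimum's `main_export` entry point and that the repository root is on
# PYTHONPATH; the model id and output path are placeholders.
from transformers import AutoConfig
from optimum.exporters.onnx import main_export

from scripts.extra.openelm import OpenElmOnnxConfig

model_id = "apple/OpenELM-270M"  # placeholder; any OpenELM checkpoint
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

onnx_config = OpenElmOnnxConfig(
    config=config,
    task="text-generation",
    use_past=True,
    use_past_in_inputs=True,
)

main_export(
    model_id,
    output="models/openelm-onnx",
    task="text-generation-with-past",
    trust_remote_code=True,
    custom_onnx_configs={"model": onnx_config},
)
```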
diff --git a/scripts/extra/whisper.py b/scripts/extra/whisper.py
index 1a1a70aab..2a9937c96 100644
--- a/scripts/extra/whisper.py
+++ b/scripts/extra/whisper.py
@@ -14,44 +14,30 @@
'whisper-small': [[5, 3], [5, 9], [8, 0], [8, 4], [8, 7], [8, 8], [9, 0], [9, 7], [9, 9], [10, 5]],
'whisper-medium.en': [[11, 4], [14, 1], [14, 12], [14, 14], [15, 4], [16, 0], [16, 4], [16, 9], [17, 12], [17, 14], [18, 7], [18, 10], [18, 15], [20, 0], [20, 3], [20, 9], [20, 14], [21, 12]],
'whisper-medium': [[13, 15], [15, 4], [15, 15], [16, 1], [20, 0], [23, 4]],
+ 'whisper-large-v3-turbo': [[2, 4], [2, 11], [3, 3], [3, 6], [3, 11], [3, 14]],
'whisper-large-v2': [[10, 12], [13, 17], [16, 11], [16, 12], [16, 13], [17, 15], [17, 16], [18, 4], [18, 11], [18, 19], [19, 11], [21, 2], [21, 3], [22, 3], [22, 9], [22, 12], [23, 5], [23, 7], [23, 13], [25, 5], [26, 1], [26, 12], [27, 15]],
'whisper-large': [[9, 19], [11, 2], [11, 4], [11, 17], [22, 7], [22, 11], [22, 17], [23, 2], [23, 15]],
}
class CustomWhisperOnnxConfig(WhisperOnnxConfig):
+ """
+ Custom ONNX config for Whisper models to output cross attentions.
+ Needed to compute token-level timestamps.
+ """
@property
def outputs(self) -> Dict[str, Dict[int, str]]:
common_outputs = super().outputs
- if self._behavior is ConfigBehavior.ENCODER:
- for i in range(self._config.encoder_layers):
- common_outputs[f"encoder_attentions.{i}"] = {0: "batch_size"}
- elif self._behavior is ConfigBehavior.DECODER:
- for i in range(self._config.decoder_layers):
- common_outputs[f"decoder_attentions.{i}"] = {
- 0: "batch_size",
- 2: "decoder_sequence_length",
- 3: "past_decoder_sequence_length + 1"
- }
+ if self._behavior is ConfigBehavior.DECODER:
for i in range(self._config.decoder_layers):
common_outputs[f"cross_attentions.{i}"] = {
0: "batch_size",
2: "decoder_sequence_length",
3: "encoder_sequence_length_out"
}
-
return common_outputs
- @property
- def torch_to_onnx_output_map(self):
- if self._behavior is ConfigBehavior.ENCODER:
- # The encoder export uses WhisperEncoder that returns the key "attentions"
- return {"attentions": "encoder_attentions"}
- else:
- return {}
-
-
def get_main_export_kwargs(config, task):
# See https://github.com/huggingface/optimum/blob/a39b1f5637af9725c0c788b86ca1fdf71ad3dcc2/docs/source/exporters/onnx/usage_guides/export_a_model.mdx#L264
@@ -59,9 +45,8 @@ def get_main_export_kwargs(config, task):
custom_onnx_configs = dict(
encoder_model=custom_config.with_behavior("encoder"),
- decoder_model=custom_config.with_behavior("decoder", use_past=False),
- decoder_with_past_model=custom_config.with_behavior(
- "decoder", use_past=True),
+ decoder_model=custom_config.with_behavior("decoder", use_past=True, use_past_in_inputs=False),
+ decoder_with_past_model=custom_config.with_behavior("decoder", use_past=True, use_past_in_inputs=True),
)
return dict(
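
To see the effect of the config change above, the custom config can be instantiated directly and the decoder's declared output names inspected. This is a rough sketch; the way `scripts/convert.py` actually constructs `custom_config` is not shown in this hunk, so treat the instantiation below as an assumption:

```python
# Hedged sketch: verifying that the decoder config exposes cross-attention outputs,
# which transformers.js uses to compute token-level timestamps.
from transformers import AutoConfig

from scripts.extra.whisper import CustomWhisperOnnxConfig

config = AutoConfig.from_pretrained("openai/whisper-tiny")  # example checkpoint
custom_config = CustomWhisperOnnxConfig(config=config, task="automatic-speech-recognition")

decoder_config = custom_config.with_behavior("decoder", use_past=True, use_past_in_inputs=False)

# Alongside the usual logits/present.* outputs, one `cross_attentions.{i}` output
# is declared per decoder layer.
print([name for name in decoder_config.outputs if name.startswith("cross_attentions")])
```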
diff --git a/scripts/quantize.py b/scripts/quantize.py
new file mode 100644
index 000000000..1ace2d353
--- /dev/null
+++ b/scripts/quantize.py
@@ -0,0 +1,345 @@
+from enum import Enum
+
+from tqdm import tqdm
+from typing import Set
+import onnx
+import os
+
+from dataclasses import dataclass, field
+
+from transformers import HfArgumentParser
+from optimum.onnx.graph_transformations import check_and_save_model
+
+from onnxruntime.quantization import QuantType, QuantizationMode
+from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer
+from onnxruntime.quantization.registry import IntegerOpsRegistry
+from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer
+from onnxruntime.quantization.matmul_bnb4_quantizer import MatMulBnb4Quantizer
+from onnxconverter_common import float16
+import onnx_graphsurgeon as gs
+
+
+class QuantMode(Enum):
+ # F32 = 'fp32'
+ FP16 = "fp16"
+ Q8 = "q8"
+ QI8 = "int8"
+ QU8 = "uint8"
+ Q4 = "q4"
+ Q4F16 = "q4f16"
+ BNB4 = "bnb4"
+
+
+QUANTIZE_SUFFIX_MAPPING = {
+ QuantMode.Q8: "quantized",
+}
+
+QUANTIZE_OPTIONS = tuple(x.value for x in QuantMode)
+
+
+@dataclass
+class IOArguments:
+ """
+ Arguments to specify input and output folders
+ """
+ input_folder: str = field(
+ metadata={
+ "help": "Path of the input folder containing the .onnx models to quantize"
+ }
+ )
+ output_folder: str = field(
+ metadata={
+ "help": "Path of the output folder where the quantized .onnx models will be saved"
+ }
+ )
+
+@dataclass
+class QuantizationArguments:
+ """
+ Arguments for quantizing ONNX models
+ """
+
+ modes: QuantMode = field(
+ default=QUANTIZE_OPTIONS,
+ metadata={
+ "help": "Quantization mode to use.",
+ "choices": QUANTIZE_OPTIONS,
+ "nargs": "+",
+ },
+ )
+
+ # 8-bit quantization
+ per_channel: bool = field(
+ default=None, metadata={"help": "Whether to quantize weights per channel"}
+ )
+ reduce_range: bool = field(
+ default=None,
+ metadata={
+ "help": "Whether to quantize weights with 7-bits. It may improve the accuracy for some models running on non-VNNI machine, especially for per-channel mode"
+ },
+ )
+
+ # 4-bit quantization
+ block_size: int = field(
+ default=None,
+ metadata={
+ "help": "Block size for blockwise quantization. Note: bnb.nn.Linear4bit only uses block_size=64"
+ },
+ )
+
+ # MatMul4BitsQuantizer
+ is_symmetric: bool = field(
+ default=True,
+ metadata={"help": "Indicate whether to quantize the model symmetrically"},
+ )
+ accuracy_level: int = field(
+ default=None,
+ metadata={
+ "help": "Accuracy level of the 4-bit quantized MatMul computation. "
+ "Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details "
+ "(https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits)."
+ },
+ )
+
+ # MatMulBnb4Quantizer
+ quant_type: int = field(
+ default=MatMulBnb4Quantizer.NF4,
+ metadata={
+ "help": "Quantization data type. 0: FP4, 1: NF4",
+ "choices": [MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4],
+ },
+ )
+
+
+def get_operators(model: onnx.ModelProto) -> Set[str]:
+ operators = set()
+
+ def traverse_graph(graph):
+ for node in graph.node:
+ operators.add(node.op_type)
+ for attr in node.attribute:
+ if attr.type == onnx.AttributeProto.GRAPH:
+ traverse_graph(attr.g)
+
+ traverse_graph(model.graph)
+ return operators
+
+
+def quantize_q8(
+ model: onnx.ModelProto,
+ save_path: str,
+ per_channel: bool,
+ reduce_range: bool,
+ weight_type: QuantType,
+):
+ """
+ Quantize the weights of the model from float32 to int8/uint8
+
+ Uses unsigned ints for activation values and signed ints for weights; per
+ https://onnxruntime.ai/docs/performance/quantization.html#data-type-selection
+ this is faster on most CPU architectures.
+ """
+
+ quantizer = ONNXQuantizer(
+ model,
+ per_channel,
+ reduce_range,
+ mode=QuantizationMode.IntegerOps,
+ static=False,
+ weight_qType=weight_type,
+ activation_qType=QuantType.QUInt8, # dynamic activation only supports uint8
+ tensors_range=None,
+ nodes_to_quantize=[],
+ nodes_to_exclude=[],
+ op_types_to_quantize=list(IntegerOpsRegistry.keys()),
+ extra_options=dict(
+ EnableSubgraph=True,
+ MatMulConstBOnly=True,
+ ),
+ )
+
+ quantizer.quantize_model()
+ check_and_save_model(quantizer.model.model, save_path)
+
+
+def quantize_fp16(
+ model: onnx.ModelProto,
+ save_path: str,
+):
+ """
+ Quantize the weights of the model from float32 to float16
+ """
+
+ # Check whether we should disable shape infer:
+ # ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB: 2338583841
+ disable_shape_infer = model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF
+
+ model_fp16 = float16.convert_float_to_float16(
+ model,
+ keep_io_types=True,
+ disable_shape_infer=disable_shape_infer,
+ )
+ graph = gs.import_onnx(model_fp16)
+ graph.toposort()
+ model_fp16 = gs.export_onnx(graph)
+ check_and_save_model(model_fp16, save_path)
+
+
+def quantize_q4(
+ model: onnx.ModelProto,
+ save_path: str | None,
+ block_size: int,
+ is_symmetric: bool,
+ accuracy_level: int,
+):
+ """
+ Quantize the weights of the model from float32 to 4-bit int
+ """
+
+ quantizer = MatMul4BitsQuantizer(
+ model=model,
+ block_size=block_size,
+ is_symmetric=is_symmetric,
+ accuracy_level=accuracy_level,
+ )
+ quantizer.process()
+ if save_path:
+ check_and_save_model(quantizer.model.model, save_path)
+ return quantizer.model.model
+
+
+def quantize_bnb4(
+ model: onnx.ModelProto,
+ save_path: str,
+ block_size: int,
+ quant_type: int,
+):
+ """
+ Quantize the weights of the model from float32 to 4-bit int using MatMulBnb4Quantizer
+ """
+
+ quantizer = MatMulBnb4Quantizer(
+ model=model,
+ block_size=block_size,
+ quant_type=quant_type,
+ )
+ quantizer.process()
+ check_and_save_model(quantizer.model.model, save_path)
+ return quantizer.model.model
+
+
+def quantize(input_folder, output_folder, quantization_args: QuantizationArguments):
+
+ # (Step 1) Validate the arguments
+ if not quantization_args.modes:
+ raise ValueError("At least one quantization mode must be specified")
+
+ if not os.path.exists(input_folder):
+ raise ValueError(f"Input folder {input_folder} does not exist")
+
+ model_names_or_paths = [
+ os.path.join(input_folder, file)
+ for file in os.listdir(input_folder)
+ if file.endswith(".onnx")
+ ]
+ if not model_names_or_paths:
+ raise ValueError(f"No .onnx models found in {input_folder}")
+
+ os.makedirs(output_folder, exist_ok=True)
+
+ # (Step 2) Quantize the models
+ for model_path in (progress_models := tqdm(model_names_or_paths)):
+ progress_models.set_description(f"Processing {model_path}")
+
+ file_name_without_extension = os.path.splitext(os.path.basename(model_path))[0]
+
+ for mode in (progress := tqdm(quantization_args.modes)):
+ progress.set_description(f" - Quantizing to {mode}")
+ mode = QuantMode(mode)
+ suffix = QUANTIZE_SUFFIX_MAPPING.get(mode, mode.value)
+ save_path = os.path.join(
+ output_folder,
+ f"{file_name_without_extension}_{suffix}.onnx",
+ )
+
+ # NOTE: Unfortunately, we need to reload the model for each quantization mode,
+ # which is memory inefficient. This is because the quantization functions
+ # modify the model in-place, and we need to keep the original model for each mode.
+ model = onnx.load_model(model_path)
+
+ if mode == QuantMode.FP16:
+ quantize_fp16(
+ model,
+ save_path,
+ )
+
+ elif mode in (QuantMode.Q4, QuantMode.Q4F16):
+ block_size = quantization_args.block_size or 32
+
+ q4_model = quantize_q4(
+ model,
+ save_path=None if mode == QuantMode.Q4F16 else save_path,
+ block_size=block_size,
+ is_symmetric=quantization_args.is_symmetric,
+ accuracy_level=quantization_args.accuracy_level,
+ )
+ if mode == QuantMode.Q4F16:
+ quantize_fp16(
+ q4_model,
+ save_path,
+ )
+
+ elif mode == QuantMode.BNB4:
+ quantize_bnb4(
+ model,
+ save_path,
+ block_size=quantization_args.block_size or 64,
+ quant_type=(
+ quantization_args.quant_type
+ if quantization_args.quant_type is not None
+ else MatMulBnb4Quantizer.NF4
+ ),
+ )
+
+ elif mode in (QuantMode.Q8, QuantMode.QI8, QuantMode.QU8):
+ if mode == QuantMode.Q8:
+ # NOTE:
+ # As of 2024/06/28, the latest version of onnxruntime-web is 1.18.0, which does not support INT8 weights for Conv layers.
+ # If you attempt to run a model with INT8 weights for Conv layers, you will get an error like:
+ # `Can't create a session. ERROR_CODE: 9, ERROR_MESSAGE: Could not find an implementation for ConvInteger(10) node with name '/.../Conv_quant'`
+ #
+ # For this reason, we choose model weight types to ensure compatibility with onnxruntime-web.
+ #
+ # As per the docs, the signed weight type (QInt8) is faster on most CPUs, so we use that unless the model contains a Conv layer.
+ # For more information, see:
+ # - https://github.com/microsoft/onnxruntime/issues/3130#issuecomment-1105200621
+ # - https://github.com/microsoft/onnxruntime/issues/2339
+ op_types = get_operators(model)
+ weight_type = (
+ QuantType.QUInt8 if "Conv" in op_types else QuantType.QInt8
+ )
+
+ elif mode == QuantMode.QI8:
+ weight_type = QuantType.QInt8
+
+ else: # mode == QuantMode.QU8:
+ weight_type = QuantType.QUInt8
+
+ quantize_q8(
+ model,
+ save_path,
+ per_channel=quantization_args.per_channel,
+ reduce_range=quantization_args.reduce_range,
+ weight_type=weight_type,
+ )
+
+
+def main():
+ parser = HfArgumentParser((IOArguments, QuantizationArguments))
+ io_args, quantization_args = parser.parse_args_into_dataclasses()
+ input_folder = io_args.input_folder
+ output_folder = io_args.output_folder
+ quantize(input_folder, output_folder, quantization_args)
+
+if __name__ == "__main__":
+ main()
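
Besides the CLI entry point above, the module can be driven programmatically. A minimal sketch, assuming the repository root is on PYTHONPATH (folder paths are placeholders, and only a subset of the supported modes is requested):

```python
# Hedged sketch: calling the quantizer defined above from Python instead of the CLI.
# Paths are placeholders; mode strings must be valid QuantMode values.
from scripts.quantize import QuantizationArguments, quantize

args = QuantizationArguments(
    modes=["q8", "fp16"],  # produces *_quantized.onnx and *_fp16.onnx
    per_channel=True,
    reduce_range=False,
)

# Quantizes every .onnx file found in the input folder and writes the
# suffixed outputs to the output folder.
quantize("models/my-model", "models/my-model/onnx", args)
```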
diff --git a/scripts/requirements.txt b/scripts/requirements.txt
index f0b3867ae..9773d04e7 100644
--- a/scripts/requirements.txt
+++ b/scripts/requirements.txt
@@ -1,5 +1,9 @@
-transformers[torch]==4.33.2
-onnxruntime<1.16.0
-optimum==1.13.2
-tqdm
-onnx==1.13.1
+transformers[torch]==4.43.4
+onnxruntime==1.19.2
+optimum==1.21.3
+onnx==1.16.2
+onnxconverter-common==1.14.0
+tqdm==4.66.5
+onnxslim==0.1.31
+--extra-index-url https://pypi.ngc.nvidia.com
+onnx_graphsurgeon==0.3.27
diff --git a/scripts/supported_models.py b/scripts/supported_models.py
deleted file mode 100644
index dc044167e..000000000
--- a/scripts/supported_models.py
+++ /dev/null
@@ -1,1206 +0,0 @@
-from .extra.marian import SUPPORTED_HELSINKI_NLP_MODELS
-
-
-SUPPORTED_MODELS = {
- # NOTE: keys of `SUPPORTED_MODELS` are subsets of https://github.com/huggingface/optimum/blob/7f8e606689365931300ef5e6d3b20cb88771cb08/optimum/exporters/tasks.py#L281-L965
- 'albert': {
- # Masked language modelling
- 'fill-mask': [
- 'albert-base-v2',
- 'albert-large-v2',
- ],
-
- # Feature extraction
- 'feature-extraction': [
- 'sentence-transformers/paraphrase-albert-small-v2',
- 'sentence-transformers/paraphrase-albert-base-v2',
- ],
- },
- 'audio-spectrogram-transformer': {
- # Audio classification
- 'audio-classification': {
- 'MIT/ast-finetuned-audioset-10-10-0.4593',
- 'MIT/ast-finetuned-audioset-16-16-0.442',
- 'MIT/ast-finetuned-speech-commands-v2',
- 'mtg-upf/discogs-maest-30s-pw-73e-ts',
- }
- },
- 'bart': {
- # Summarization
- 'summarization': [
- 'sshleifer/distilbart-xsum-12-1',
- 'sshleifer/distilbart-xsum-6-6',
- 'sshleifer/distilbart-xsum-12-3',
- 'sshleifer/distilbart-xsum-9-6',
- 'sshleifer/distilbart-xsum-12-6',
- 'sshleifer/distilbart-cnn-12-3',
- 'sshleifer/distilbart-cnn-12-6',
- 'sshleifer/distilbart-cnn-6-6',
- 'facebook/bart-large-cnn',
- 'facebook/bart-large-xsum',
- ],
- # Zero-shot classification
- 'zero-shot-classification': {
- 'facebook/bart-large-mnli',
- },
- },
- 'beit': {
- # Image classification
- 'image-classification': [
- 'microsoft/beit-base-patch16-224',
- 'microsoft/beit-base-patch16-224-pt22k',
- 'microsoft/beit-base-patch16-384',
- 'microsoft/beit-base-patch16-224-pt22k-ft22k',
- 'microsoft/beit-large-patch16-224',
- 'microsoft/beit-large-patch16-224-pt22k',
- 'microsoft/beit-large-patch16-512',
- 'microsoft/beit-large-patch16-224-pt22k-ft22k',
- 'microsoft/beit-large-patch16-384',
- 'microsoft/dit-base-finetuned-rvlcdip',
- 'microsoft/dit-large-finetuned-rvlcdip',
- ],
- },
- 'bert': {
- # Feature extraction
- 'feature-extraction': [
- 'sentence-transformers/all-MiniLM-L6-v2',
- 'sentence-transformers/all-MiniLM-L12-v2',
- 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
- 'sentence-transformers/paraphrase-MiniLM-L6-v2',
- 'sentence-transformers/paraphrase-MiniLM-L3-v2',
- 'sentence-transformers/bert-base-nli-mean-tokens',
- 'sentence-transformers/multi-qa-MiniLM-L6-cos-v1',
- 'sentence-transformers/xlm-r-100langs-bert-base-nli-stsb-mean-tokens',
- 'sentence-transformers/LaBSE',
- 'deepset/sentence_bert',
- 'intfloat/e5-small',
- 'intfloat/e5-small-v2',
- 'intfloat/e5-base',
- 'intfloat/e5-base-v2',
- 'intfloat/e5-large',
- 'intfloat/e5-large-v2',
- 'intfloat/multilingual-e5-base',
- 'thenlper/gte-small',
- 'thenlper/gte-base',
- 'thenlper/gte-large',
- 'BAAI/bge-small-en',
- 'BAAI/bge-base-en',
- 'BAAI/bge-large-en',
- 'BAAI/bge-large-en-v1.5',
- 'BAAI/bge-base-en-v1.5',
- 'BAAI/bge-small-en-v1.5',
- 'BAAI/bge-large-zh-v1.5',
- 'BAAI/bge-base-zh-v1.5',
- 'BAAI/bge-small-zh-v1.5',
- 'allenai/scibert_scivocab_uncased',
- 'SpanBERT/spanbert-large-cased',
- 'SpanBERT/spanbert-base-cased',
- 'cambridgeltl/SapBERT-from-PubMedBERT-fulltext',
- 'indobenchmark/indobert-base-p1',
- 'GanjinZero/UMLSBert_ENG',
- 'DeepPavlov/rubert-base-cased',
- 'monologg/kobert',
- ],
-
- # Text classification
- 'text-classification': [
- 'nlptown/bert-base-multilingual-uncased-sentiment',
- 'ProsusAI/finbert',
- 'unitary/toxic-bert',
- 'BAAI/bge-reranker-large',
- 'BAAI/bge-reranker-base',
- 'cross-encoder/ms-marco-TinyBERT-L-2-v2',
- 'cross-encoder/ms-marco-MiniLM-L-2-v2',
- 'cross-encoder/ms-marco-MiniLM-L-4-v2',
- 'cross-encoder/ms-marco-MiniLM-L-6-v2',
- 'cross-encoder/ms-marco-MiniLM-L-12-v2',
- ],
-
- # Token classification
- 'token-classification': [
- 'Davlan/bert-base-multilingual-cased-ner-hrl',
- 'ckiplab/bert-base-chinese-ner',
- 'ckiplab/bert-base-chinese-ws',
- 'ckiplab/bert-base-chinese-pos',
- 'dslim/bert-base-NER',
- 'dslim/bert-base-NER-uncased',
- ],
-
- # Masked language modelling
- 'fill-mask': [
- 'bert-base-uncased',
- 'bert-base-cased',
- 'bert-base-multilingual-uncased',
- 'bert-base-multilingual-cased',
- 'bert-base-chinese',
- 'emilyalsentzer/Bio_ClinicalBERT',
- ],
- },
- 'blenderbot': {
- # Text-to-text (TODO add conversational)
- 'text2text-generation': [
- 'facebook/blenderbot-400M-distill',
- # 'facebook/blenderbot-1B-distill',
- ],
- },
- 'blenderbot-small': {
- # Text-to-text (TODO add conversational)
- 'text2text-generation': [
- # 'facebook/blenderbot-90M', # DEPRECATED
- 'facebook/blenderbot_small-90M',
- ],
- },
- 'bloom': {
- # Text generation
- 'text-generation': [
- 'bigscience/bloom-560m',
- 'bigscience/bloomz-560m',
- ],
- },
-
- 'camembert': {
- # Feature extraction
- 'feature-extraction': [
- 'dangvantuan/sentence-camembert-large',
- ],
-
- # Token classification
- 'token-classification': [
- 'Jean-Baptiste/camembert-ner',
- 'Jean-Baptiste/camembert-ner-with-dates',
- 'pythainlp/thainer-corpus-v2-base-model',
- 'gilf/french-camembert-postag-model',
- ],
-
- # Masked language modelling
- 'fill-mask': [
- 'camembert-base',
- 'airesearch/wangchanberta-base-att-spm-uncased',
- ],
- },
- 'clap': {
- # Zero-shot audio classification and feature extraction
- # (with and without `--split_modalities`)
- 'zero-shot-audio-classification': {
- 'laion/clap-htsat-unfused',
- # TODO add 'laion/clap-htsat-fused',
- 'laion/larger_clap_general',
- 'laion/larger_clap_music_and_speech',
- # 'Xenova/tiny-random-ClapModel',
- }
- },
- 'chinese_clip': {
- # Zero-shot image classification
- # TODO: Add `--split_modalities` option
- 'zero-shot-image-classification': [
- 'OFA-Sys/chinese-clip-vit-base-patch16',
- 'OFA-Sys/chinese-clip-vit-large-patch14',
- 'OFA-Sys/chinese-clip-vit-large-patch14-336px',
- # 'OFA-Sys/chinese-clip-vit-huge-patch14', # TODO add
- ],
- },
- 'clip': {
- # Zero-shot image classification (and feature extraction)
- # (with and without `--split_modalities`)
- 'zero-shot-image-classification': [
- 'openai/clip-vit-base-patch16',
- 'openai/clip-vit-base-patch32',
- 'openai/clip-vit-large-patch14',
- 'openai/clip-vit-large-patch14-336',
- ],
- },
- 'clipseg': {
- # Image segmentation
- 'image-segmentation': [
- 'CIDAS/clipseg-rd64-refined',
- 'CIDAS/clipseg-rd64',
- 'CIDAS/clipseg-rd16',
- ],
- },
- 'codegen': {
- # Text generation
- 'text-generation': [
- 'Salesforce/codegen-350M-mono',
- 'Salesforce/codegen-350M-multi',
- 'Salesforce/codegen-350M-nl',
- ],
- },
- 'convbert': {
- # Feature extraction
- 'feature-extraction': [
- 'YituTech/conv-bert-small',
- 'YituTech/conv-bert-medium-small',
- 'YituTech/conv-bert-base',
- ],
- },
- 'convnext': {
- # Image classification
- 'image-classification': [
- 'facebook/convnext-tiny-224',
- 'facebook/convnext-small-224',
- 'facebook/convnext-base-224',
- 'facebook/convnext-base-224-22k',
- 'facebook/convnext-base-224-22k-1k',
- 'facebook/convnext-base-384',
- 'facebook/convnext-base-384-22k-1k',
- 'facebook/convnext-large-224',
- 'facebook/convnext-large-224-22k',
- 'facebook/convnext-large-224-22k-1k',
- 'facebook/convnext-large-384',
- 'facebook/convnext-large-384-22k-1k',
- 'facebook/convnext-xlarge-224-22k',
- 'facebook/convnext-xlarge-224-22k-1k',
- 'facebook/convnext-xlarge-384-22k-1k',
- ],
- },
- 'convnextv2': {
- # Image classification
- 'image-classification': [
- 'facebook/convnextv2-atto-1k-224',
- 'facebook/convnextv2-femto-1k-224',
- 'facebook/convnextv2-pico-1k-224',
- 'facebook/convnextv2-tiny-1k-224',
- 'facebook/convnextv2-tiny-22k-384',
- 'facebook/convnextv2-tiny-22k-224',
- 'facebook/convnextv2-nano-1k-224',
- 'facebook/convnextv2-nano-22k-384',
- 'facebook/convnextv2-base-22k-224',
- 'facebook/convnextv2-base-1k-224',
- 'facebook/convnextv2-base-22k-384',
- 'facebook/convnextv2-large-22k-224',
- 'facebook/convnextv2-large-1k-224',
- 'facebook/convnextv2-large-22k-384',
- # 'facebook/convnextv2-huge-22k-512',
- # 'facebook/convnextv2-huge-1k-224',
- # 'facebook/convnextv2-huge-22k-384',
- # 'facebook/convnextv2-nano-22k-224',
- ],
- },
- 'deberta': {
- # Zero-shot classification
- 'zero-shot-classification': [
- 'cross-encoder/nli-deberta-base',
- 'Narsil/deberta-large-mnli-zero-cls',
- ],
- },
- 'deberta-v2': {
- # Zero-shot classification
- 'zero-shot-classification': [
- 'cross-encoder/nli-deberta-v3-xsmall',
- 'cross-encoder/nli-deberta-v3-small',
- 'cross-encoder/nli-deberta-v3-base',
- 'cross-encoder/nli-deberta-v3-large',
- 'MoritzLaurer/DeBERTa-v3-xsmall-mnli-fever-anli-ling-binary',
- 'MoritzLaurer/DeBERTa-v3-base-mnli',
- 'MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli',
- 'MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli',
- 'MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7',
- 'sileod/deberta-v3-base-tasksource-nli',
- 'sileod/deberta-v3-large-tasksource-nli',
- ],
- },
- # TODO: Add back in v3
- # 'decision-transformer': {
- # # Reinforcement learning
- # 'reinforcement-learning': [
- # 'edbeeching/decision-transformer-gym-hopper-expert',
- # 'edbeeching/decision-transformer-gym-hopper-medium',
- # 'edbeeching/decision-transformer-gym-hopper-medium-replay',
- # 'edbeeching/decision-transformer-gym-hopper-expert-new',
- # 'edbeeching/decision-transformer-gym-halfcheetah-expert',
- # 'edbeeching/decision-transformer-gym-halfcheetah-medium',
- # 'edbeeching/decision-transformer-gym-halfcheetah-medium-replay',
- # 'edbeeching/decision-transformer-gym-walker2d-expert',
- # 'edbeeching/decision-transformer-gym-walker2d-medium',
- # 'edbeeching/decision-transformer-gym-walker2d-medium-replay',
- # ],
- # },
- 'deit': {
- # Image classification
- 'image-classification': [
- 'facebook/deit-tiny-distilled-patch16-224',
- 'facebook/deit-small-distilled-patch16-224',
- 'facebook/deit-base-distilled-patch16-224',
- 'facebook/deit-base-distilled-patch16-384',
- ],
- },
- 'detr': {
- # Object detection
- 'object-detection': [
- 'facebook/detr-resnet-50',
- 'facebook/detr-resnet-101',
- ],
-
- # Image segmentation
- 'image-segmentation': [
- 'facebook/detr-resnet-50-panoptic',
- ],
- },
- 'dinov2': {
- # Feature extraction
- 'feature-extraction': [
- 'facebook/dinov2-small',
- 'facebook/dinov2-base',
- 'facebook/dinov2-large',
- # 'facebook/dinov2-giant', # TODO add
- ],
-
- # Image classification
- 'image-classification': [
- 'facebook/dinov2-small-imagenet1k-1-layer',
- 'facebook/dinov2-base-imagenet1k-1-layer',
- 'facebook/dinov2-large-imagenet1k-1-layer',
- # 'facebook/dinov2-giant-imagenet1k-1-layer', # TODO add
- ],
- },
- 'distilbert': {
- # Feature extraction
- 'feature-extraction': [
- 'sentence-transformers/multi-qa-distilbert-cos-v1',
- 'sentence-transformers/distiluse-base-multilingual-cased-v1',
- 'sentence-transformers/distiluse-base-multilingual-cased-v2',
- 'sentence-transformers/distilbert-base-nli-mean-tokens',
- 'sentence-transformers/distilbert-base-nli-stsb-mean-tokens',
- 'sentence-transformers/msmarco-distilbert-base-v4',
- ],
-
- # Text classification
- 'text-classification': [
- 'distilbert-base-uncased-finetuned-sst-2-english',
- ],
-
- # Question answering
- 'question-answering': [
- 'distilbert-base-uncased-distilled-squad',
- 'distilbert-base-cased-distilled-squad',
- ],
-
- # Zero-shot classification
- 'zero-shot-classification': [
- 'typeform/distilbert-base-uncased-mnli',
- ],
-
- # Token classification
- 'token-classification': [
- 'Davlan/distilbert-base-multilingual-cased-ner-hrl',
- ],
-
- # Masked language modelling
- 'fill-mask': [
- 'distilbert-base-uncased',
- 'distilbert-base-cased',
- ],
- },
- 'dit': { # NOTE: DiT has the same architecture as BEiT.
- # Feature extraction
- # NOTE: requires --task feature-extraction
- 'feature-extraction': [
- 'microsoft/dit-base',
- 'microsoft/dit-large',
- ],
-
- # Image classification
- 'image-classification': [
- 'microsoft/dit-base-finetuned-rvlcdip',
- 'microsoft/dit-large-finetuned-rvlcdip',
- ],
- },
- 'donut': { # NOTE: also a `vision-encoder-decoder`
- # Image-to-text
- 'image-to-text': [
- 'naver-clova-ix/donut-base-finetuned-cord-v2',
- 'naver-clova-ix/donut-base-finetuned-zhtrainticket',
- ],
-
- # Document Question Answering
- 'document-question-answering': [
- 'naver-clova-ix/donut-base-finetuned-docvqa',
- ],
- },
- 'dpt': {
- # Depth estimation
- 'depth-estimation': [
- 'Intel/dpt-hybrid-midas',
- 'Intel/dpt-large',
- ],
- },
- 'depth_anything': {
- # Depth estimation
- # NOTE: requires --task depth-estimation
- 'depth-estimation': [
- 'LiheYoung/depth-anything-small-hf',
- 'LiheYoung/depth-anything-base-hf',
- 'LiheYoung/depth-anything-large-hf',
- ],
- },
- 'electra': {
- # Feature extraction
- 'feature-extraction': [
- # NOTE: requires --task feature-extraction
- 'google/electra-small-discriminator',
- 'google/electra-base-discriminator',
- ],
- },
- 'esm': {
- # Masked language modelling
- 'fill-mask': [
- # with and without --task feature-extraction
- 'InstaDeepAI/nucleotide-transformer-500m-human-ref',
- 'InstaDeepAI/nucleotide-transformer-500m-1000g',
-
- # NOTE: requires --opset 12
- 'facebook/esm2_t6_8M_UR50D',
- 'facebook/esm2_t12_35M_UR50D',
- 'facebook/esm2_t30_150M_UR50D',
- 'facebook/esm2_t33_650M_UR50D',
- ],
-
- # Token classification
- 'token-classification': [
- 'AmelieSchreiber/esm2_t6_8M_UR50D_rna_binding_site_predictor',
- ],
-
- # Zero-shot classification
- 'zero-shot-classification': [
- 'AmelieSchreiber/esm2_t6_8M_UR50D_sequence_classifier_v1',
- ],
- },
- 'falcon': {
- # Text generation
- 'text-generation': [
- 'Rocketknight1/tiny-random-falcon-7b',
- 'fxmarty/really-tiny-falcon-testing',
- ],
- },
- 'fastvit': {
- # Image classification
- 'image-classification': [
- # NOTE: Supported by timm, but not by transformers
- # 'timm/fastvit_t8.apple_in1k',
- # 'timm/fastvit_t8.apple_dist_in1k',
- # 'timm/fastvit_t12.apple_in1k',
- # 'timm/fastvit_t12.apple_dist_in1k',
- # 'timm/fastvit_s12.apple_in1k',
- # 'timm/fastvit_s12.apple_dist_in1k',
- # 'timm/fastvit_sa12.apple_in1k',
- # 'timm/fastvit_sa12.apple_dist_in1k',
- # 'timm/fastvit_sa24.apple_in1k',
- # 'timm/fastvit_sa24.apple_dist_in1k',
- # 'timm/fastvit_sa36.apple_in1k',
- # 'timm/fastvit_sa36.apple_dist_in1k',
- # 'timm/fastvit_ma36.apple_in1k',
- # 'timm/fastvit_ma36.apple_dist_in1k',
- ],
- },
- 'glpn': {
- # Depth estimation
- 'depth-estimation': [
- 'vinvino02/glpn-kitti',
- 'vinvino02/glpn-nyu',
- ],
- },
- 'gpt_neo': {
- # Text generation
- 'text-generation': [
- 'EleutherAI/gpt-neo-125M',
- 'MBZUAI/LaMini-Neo-125M',
- # 'MBZUAI/LaMini-Neo-1.3B', # TODO add
- 'iliemihai/gpt-neo-romanian-125m',
- ],
- },
- 'gpt_neox': {
- # Text generation
- 'text-generation': [
- 'EleutherAI/pythia-14m',
- 'EleutherAI/pythia-31m',
- 'EleutherAI/pythia-70m',
- 'EleutherAI/pythia-70m-deduped',
- 'EleutherAI/pythia-160m',
- 'EleutherAI/pythia-160m-deduped',
- 'EleutherAI/pythia-410m',
- 'EleutherAI/pythia-410m-deduped',
- ],
- },
- 'gpt2': {
- # Text generation
- 'text-generation': [
- 'gpt2',
- 'distilgpt2',
- 'MBZUAI/LaMini-Cerebras-111M',
- 'MBZUAI/LaMini-Cerebras-256M',
- 'MBZUAI/LaMini-Cerebras-590M',
- # 'MBZUAI/LaMini-Cerebras-1.3B', # TODO add
- 'MBZUAI/LaMini-GPT-124M',
- 'MBZUAI/LaMini-GPT-774M',
- # 'MBZUAI/LaMini-GPT-1.5B', # TODO add
- 'aisquared/dlite-v2-774m',
- 'Locutusque/gpt2-large-conversational',
- ],
- },
- 'gpt_bigcode': {
- # Text generation
- 'text-generation': [
- 'bigcode/tiny_starcoder_py',
- 'abacaj/starcoderbase-1b-sft',
- # 'bigcode/starcoderbase-1b', # NOTE: This model is gated, so we ignore it when testing
- ],
- },
- 'gptj': {
- # Text generation
- 'text-generation': [
- 'TabbyML/J-350M',
- 'Milos/slovak-gpt-j-405M',
- 'heegyu/kogpt-j-350m',
- ],
- },
- 'herbert': {
- # Feature extraction
- 'feature-extraction': [
- 'allegro/herbert-base-cased',
- 'allegro/herbert-large-cased',
- ],
- },
- 'hubert': {
- # Feature extraction
- 'feature-extraction': [
- 'facebook/hubert-base-ls960',
- ],
-
- # Audio classification
- 'audio-classification': [
- 'superb/hubert-base-superb-ks',
- ],
-
- # Automatic speech recognition
- 'automatic-speech-recognition': [
- 'facebook/hubert-large-ls960-ft',
- ],
- },
- 'llama': {
- # Text generation
- 'text-generation': [
- 'Xenova/llama2.c-stories15M',
- 'Xenova/llama2.c-stories42M',
- 'Xenova/llama2.c-stories110M',
- 'RajuKandasamy/tamillama_tiny_30m',
- 'JackFram/llama-68m',
- 'JackFram/llama-160m',
- ],
- },
- 'longt5': {
- # Text-to-text
- 'text2text-generation': [
- 'google/long-t5-local-base',
- 'google/long-t5-tglobal-base',
- # 'google/long-t5-tglobal-xl', # too large
- # 'google/long-t5-tglobal-large', # too large
- # 'google/long-t5-local-large', # too large
- ],
-
- # Summarization
- 'summarization': [
- 'pszemraj/long-t5-tglobal-base-16384-book-summary',
- ],
-
- # Feature extraction
- 'feature-extraction': [
- # NOTE: requires --task feature-extraction
- 'voidful/long-t5-encodec-tglobal-base',
- ],
- },
- 'm2m_100': {
- # Translation
- 'translation': [
- 'facebook/nllb-200-distilled-600M',
- 'facebook/m2m100_418M',
- ],
- },
- 'marian': {
- # Translation
- 'translation': [
- f'Helsinki-NLP/opus-mt-{x}'
- for x in SUPPORTED_HELSINKI_NLP_MODELS
- ],
- },
- 'mbart': {
- # Translation
- 'translation': [
- 'facebook/mbart-large-50-many-to-many-mmt',
- 'facebook/mbart-large-50-many-to-one-mmt',
- 'facebook/mbart-large-50',
- ],
- },
- 'mistral': {
- # Text generation
- 'text-generation': [
- 'echarlaix/tiny-random-mistral',
- ],
- },
- 'mobilebert': {
- # Zero-shot classification
- 'zero-shot-classification': [
- 'typeform/mobilebert-uncased-mnli',
-
- # TODO:
- # https://github.com/huggingface/optimum/issues/1027
- # 'google/mobilebert-uncased',
- ],
- },
- 'mobilevit': {
- # Image classification
- 'image-classification': [
- 'apple/mobilevit-small',
- 'apple/mobilevit-x-small',
- 'apple/mobilevit-xx-small',
- ],
-
- # TODO: Image segmentation
- # 'image-segmentation': [
- # 'apple/deeplabv3-mobilevit-small',
- # 'apple/deeplabv3-mobilevit-x-small',
- # 'apple/deeplabv3-mobilevit-xx-small',
- # ],
- },
- 'mobilevitv2': {
- # Image classification
- 'image-classification': [
- 'apple/mobilevitv2-1.0-imagenet1k-256',
- ],
-
- # TODO: Image segmentation
- # 'image-segmentation': [
- # 'apple/mobilevitv2-1.0-voc-deeplabv3',
- # ],
- },
- 'mpt': {
- # Text generation
- 'text-generation': [
- 'efederici/ipt-350m',
- ],
- },
- 'mpnet': {
- # Feature extraction
- 'feature-extraction': [
- 'sentence-transformers/all-mpnet-base-v2',
- 'sentence-transformers/nli-mpnet-base-v2',
- 'sentence-transformers/paraphrase-mpnet-base-v2',
- 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
- 'sentence-transformers/multi-qa-mpnet-base-cos-v1',
- 'sentence-transformers/multi-qa-mpnet-base-dot-v1',
- ],
- },
- 'mt5': {
- # Text-to-text
- 'text2text-generation': [
- 'google/mt5-small',
- 'google/mt5-base',
- ],
- },
- 'nougat': {
- # Image-to-text
- 'image-to-text': [
- 'facebook/nougat-small',
- 'facebook/nougat-base',
- ],
- },
- 'opt': {
- # Text generation
- 'text-generation': [
- # Text generation
- 'facebook/opt-125m',
- 'facebook/opt-350m',
- # (TODO conversational)
- 'PygmalionAI/pygmalion-350m',
- ],
- },
- 'owlv2': {
- # Object detection (Zero-shot object detection)
- # NOTE: Exported with --batch_size 1
- 'zero-shot-object-detection': [
- 'google/owlv2-base-patch16',
- 'google/owlv2-base-patch16-finetuned',
- 'google/owlv2-base-patch16-ensemble',
- # TODO: add
- # 'google/owlv2-large-patch14',
- # 'google/owlv2-large-patch14-finetuned',
- # 'google/owlv2-large-patch14-ensemble',
- ],
- },
- 'owlvit': {
- # Object detection (Zero-shot object detection)
- # NOTE: Exported with --batch_size 1
- 'zero-shot-object-detection': [
- 'google/owlvit-base-patch32',
- 'google/owlvit-base-patch16',
- 'google/owlvit-large-patch14',
- ],
- },
- 'resnet': {
- # Image classification
- 'image-classification': [
- 'microsoft/resnet-18',
- 'microsoft/resnet-26',
- 'microsoft/resnet-34',
- 'microsoft/resnet-50',
- 'microsoft/resnet-101',
- 'microsoft/resnet-152',
- ],
- },
- 'roformer': {
- # Feature extraction
- 'feature-extraction': [
- 'hf-tiny-model-private/tiny-random-RoFormerModel',
- ],
-
- # Text classification
- 'text-classification': [
- 'hf-tiny-model-private/tiny-random-RoFormerForSequenceClassification',
- ],
-
- # Token classification
- 'token-classification': [
- 'hf-tiny-model-private/tiny-random-RoFormerForTokenClassification',
- ],
-
- # TODO
- # # Text generation
- # 'text-generation': [
- # 'hf-tiny-model-private/tiny-random-RoFormerForCausalLM',
- # ],
-
- # Masked language modelling
- 'fill-mask': [
- 'alchemab/antiberta2',
- 'hf-tiny-model-private/tiny-random-RoFormerForMaskedLM',
- ],
-
- # Question answering
- 'question-answering': [
- 'hf-tiny-model-private/tiny-random-RoFormerForQuestionAnswering',
- ],
-
- # Multiple choice
- 'multiple-choice': [
- 'hf-tiny-model-private/tiny-random-RoFormerForMultipleChoice',
- ],
- },
- 'phi': {
- # Text generation
- 'text-generation': [
- 'hf-internal-testing/tiny-random-PhiForCausalLM',
- 'susnato/phi-1_5_dev',
- ],
- },
- 'qwen2': {
- # Text generation
- 'text-generation': [
- 'Qwen/Qwen1.5-0.5B',
- 'Qwen/Qwen1.5-0.5B-Chat',
- 'Qwen/Qwen1.5-1.8B',
- 'Qwen/Qwen1.5-1.8B-Chat',
- ],
- },
- 'roberta': {
- # Feature extraction
- 'feature-extraction': [
- 'sentence-transformers/all-distilroberta-v1',
- 'sentence-transformers/all-roberta-large-v1',
- ],
-
- # Text classification
- 'text-classification': [
- 'roberta-large-mnli',
- ],
-
- # Token classification
- 'token-classification': [
- 'julien-c/EsperBERTo-small-pos',
- ],
-
- # Masked language modelling
- 'fill-mask': [
- 'roberta-base',
- 'distilroberta-base',
- ],
- },
- 'sam': {
- # Mask generation
- 'mask-generation': [
- # SAM
- 'facebook/sam-vit-base',
- 'facebook/sam-vit-large',
- 'facebook/sam-vit-huge',
- 'wanglab/medsam-vit-base',
-
- # SlimSAM
- 'nielsr/slimsam-50-uniform',
- 'nielsr/slimsam-77-uniform',
- ],
- },
- 'segformer': {
- # Image segmentation
- 'image-segmentation': [
- 'mattmdjaga/segformer_b0_clothes',
- 'mattmdjaga/segformer_b2_clothes',
- 'jonathandinu/face-parsing',
-
- 'nvidia/segformer-b0-finetuned-cityscapes-768-768',
- 'nvidia/segformer-b0-finetuned-cityscapes-512-1024',
- 'nvidia/segformer-b0-finetuned-cityscapes-640-1280',
- 'nvidia/segformer-b0-finetuned-cityscapes-1024-1024',
- 'nvidia/segformer-b1-finetuned-cityscapes-1024-1024',
- 'nvidia/segformer-b2-finetuned-cityscapes-1024-1024',
- 'nvidia/segformer-b3-finetuned-cityscapes-1024-1024',
- 'nvidia/segformer-b4-finetuned-cityscapes-1024-1024',
- 'nvidia/segformer-b5-finetuned-cityscapes-1024-1024',
- 'nvidia/segformer-b0-finetuned-ade-512-512',
- 'nvidia/segformer-b1-finetuned-ade-512-512',
- 'nvidia/segformer-b2-finetuned-ade-512-512',
- 'nvidia/segformer-b3-finetuned-ade-512-512',
- 'nvidia/segformer-b4-finetuned-ade-512-512',
- 'nvidia/segformer-b5-finetuned-ade-640-640',
- ],
-
- # Image classification
- 'image-classification': [
- 'nvidia/mit-b0',
- 'nvidia/mit-b1',
- 'nvidia/mit-b2',
- 'nvidia/mit-b3',
- 'nvidia/mit-b4',
- 'nvidia/mit-b5',
- ],
- },
- 'siglip': {
- # Zero-shot image classification and feature extraction
- # (with and without `--split_modalities`)
- # NOTE: requires --opset 13
- 'zero-shot-image-classification': [
- 'nielsr/siglip-base-patch16-224',
- ],
- },
- 'speecht5': {
- # Text-to-audio/Text-to-speech
- 'text-to-audio': [
- 'microsoft/speecht5_tts',
- ],
- },
- 'stablelm': {
- # Text generation
- 'text-generation': [
- 'hf-internal-testing/tiny-random-StableLmForCausalLM',
- 'stabilityai/stablelm-2-1_6b',
- 'stabilityai/stablelm-2-zephyr-1_6b',
- ],
- },
- 'squeezebert': {
- # Feature extraction
- 'feature-extraction': [
- 'squeezebert/squeezebert-uncased',
- 'squeezebert/squeezebert-mnli',
- ],
- },
- 'starcoder2': {
- # Text generation
- 'text-generation': [
- 'hf-internal-testing/tiny-random-Starcoder2ForCausalLM',
- ],
- },
- 'swin': {
- # Image classification
- 'image-classification': [
- 'microsoft/swin-tiny-patch4-window7-224',
- 'microsoft/swin-base-patch4-window7-224',
- 'microsoft/swin-large-patch4-window12-384-in22k',
- 'microsoft/swin-base-patch4-window7-224-in22k',
- 'microsoft/swin-base-patch4-window12-384-in22k',
- 'microsoft/swin-base-patch4-window12-384',
- 'microsoft/swin-large-patch4-window7-224',
- 'microsoft/swin-small-patch4-window7-224',
- 'microsoft/swin-large-patch4-window7-224-in22k',
- 'microsoft/swin-large-patch4-window12-384',
- ],
- },
- 'swin2sr': {
- # Image-to-image (Super-resolution)
- 'image-to-image': [
- 'caidas/swin2SR-classical-sr-x2-64',
- 'caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr',
- 'caidas/swin2SR-classical-sr-x4-64',
- 'caidas/swin2SR-compressed-sr-x4-48',
- 'caidas/swin2SR-lightweight-x2-64',
- ],
-
- # Feature extraction
- 'feature-extraction': [
- 'hf-tiny-model-private/tiny-random-Swin2SRModel',
- ],
- },
- 't5': {
- # Translation/Summarization
- ('translation', 'summarization'): [
- 't5-small',
- 't5-base',
- 'google/t5-v1_1-small',
- 'google/t5-v1_1-base',
- 'google/flan-t5-small',
- 'google/flan-t5-base',
- ],
-
- # Text-to-text
- 'text2text-generation': [
- 'MBZUAI/LaMini-Flan-T5-77M',
- 'MBZUAI/LaMini-Flan-T5-248M',
- 'MBZUAI/LaMini-Flan-T5-783M',
- 'MBZUAI/LaMini-T5-61M',
- 'MBZUAI/LaMini-T5-223M',
- 'MBZUAI/LaMini-T5-738M',
- 'declare-lab/flan-alpaca-base',
- 'declare-lab/flan-alpaca-large',
- ],
-
- # Feature extraction
- 'feature-extraction': [
- 'sentence-transformers/sentence-t5-large',
- 'hkunlp/instructor-base',
- 'hkunlp/instructor-large',
- ],
- },
- 'table-transformer': {
- # Object detection
- 'object-detection': [
- 'microsoft/table-transformer-detection',
- 'microsoft/table-transformer-structure-recognition',
- 'microsoft/table-transformer-structure-recognition-v1.1-all',
- 'microsoft/table-transformer-structure-recognition-v1.1-fin',
- 'microsoft/table-transformer-structure-recognition-v1.1-pub',
- ],
- },
- 'trocr': { # NOTE: also a `vision-encoder-decoder`
- # Text-to-image
- 'text-to-image': [
- 'microsoft/trocr-small-printed',
- 'microsoft/trocr-base-printed',
- 'microsoft/trocr-small-handwritten',
- 'microsoft/trocr-base-handwritten',
- ],
- },
- 'unispeech': {
- # Feature extraction
- 'feature-extraction': [
- # Requires --task feature-extraction
- 'microsoft/unispeech-large-1500h-cv',
- ],
- # TODO: add support for
- # # Automatic speech recognition
- # 'automatic-speech-recognition': [
- # 'microsoft/unispeech-1350-en-353-fr-ft-1h',
- # 'microsoft/unispeech-1350-en-17h-ky-ft-1h',
- # 'microsoft/unispeech-1350-en-90-it-ft-1h',
- # 'microsoft/unispeech-1350-en-168-es-ft-1h',
- # ],
- },
- 'unispeech-sat': {
- # Feature extraction
- 'feature-extraction': [
- # Requires --task feature-extraction
- 'microsoft/unispeech-sat-base',
- ],
-
- # Audio XVector (e.g., for speaker verification)
- 'audio-xvector': [
- 'microsoft/unispeech-sat-base-plus-sv',
- 'microsoft/unispeech-sat-base-sv',
- 'microsoft/unispeech-sat-large-sv',
- ],
-
- # Audio frame classification
- 'audio-frame-classification': [
- 'microsoft/unispeech-sat-base-plus-sd',
- ],
-
- # Automatic speech recognition
- 'automatic-speech-recognition': [
- 'microsoft/unispeech-sat-base-100h-libri-ft',
- ],
- },
- 'vision-encoder-decoder': {
- # Image-to-text
- 'image-to-text': [
- 'nlpconnect/vit-gpt2-image-captioning',
- ],
- },
- 'vit': {
- # Feature extraction
- 'feature-extraction': [
- 'google/vit-base-patch16-224-in21k',
- 'facebook/dino-vitb16',
- 'facebook/dino-vits8',
- 'facebook/dino-vitb8',
- 'facebook/dino-vits16',
- ],
- # Image classification
- 'image-classification': [
- 'google/vit-base-patch16-224',
- ],
- },
- 'vitmatte': {
- # Image matting
- 'image-matting': [
- 'hustvl/vitmatte-small-distinctions-646',
- 'hustvl/vitmatte-base-distinctions-646',
- 'hustvl/vitmatte-small-composition-1k',
- 'hustvl/vitmatte-base-composition-1k',
- ],
- },
- 'vits': {
- # Text-to-audio/Text-to-speech/Text-to-waveform
- 'text-to-waveform': {
- # NOTE: requires --task text-to-waveform --skip_validation
- 'echarlaix/tiny-random-vits',
- 'facebook/mms-tts-eng',
- 'facebook/mms-tts-rus',
- 'facebook/mms-tts-hin',
- 'facebook/mms-tts-yor',
- 'facebook/mms-tts-spa',
- 'facebook/mms-tts-fra',
- 'facebook/mms-tts-ara',
- 'facebook/mms-tts-ron',
- 'facebook/mms-tts-vie',
- 'facebook/mms-tts-deu',
- 'facebook/mms-tts-kor',
- 'facebook/mms-tts-por',
- # TODO add more checkpoints from
- # https://huggingface.co/models?other=vits&sort=trending&search=facebook-tts
- }
- },
- 'wav2vec2': {
- # Feature extraction # NOTE: requires --task feature-extraction
- 'feature-extraction': [
- 'facebook/mms-300m',
- 'facebook/mms-1b',
- ],
-
- # Audio classification
- 'audio-classification': [
- 'alefiury/wav2vec2-large-xlsr-53-gender-recognition-librispeech',
- 'superb/wav2vec2-base-superb-ks',
- 'facebook/mms-lid-126',
- 'facebook/mms-lid-256',
- 'facebook/mms-lid-512',
- 'facebook/mms-lid-1024',
- 'facebook/mms-lid-2048',
- 'facebook/mms-lid-4017',
- ],
-
- # Audio frame classification
- 'audio-frame-classification': [
- 'anton-l/wav2vec2-base-superb-sd',
- ],
-
- # Automatic speech recognition
- 'automatic-speech-recognition': [
- 'jonatasgrosman/wav2vec2-large-xlsr-53-english',
- 'facebook/wav2vec2-base-960h',
- 'facebook/mms-1b-l1107',
- 'facebook/mms-1b-all',
- 'facebook/mms-1b-fl102',
- ],
- },
- 'wav2vec2-bert': {
- 'feature-extraction': [
- 'facebook/w2v-bert-2.0',
- ],
-
- # Automatic speech recognition
- 'automatic-speech-recognition': [
- 'hf-audio/wav2vec2-bert-CV16-en',
- ],
- },
- 'wavlm': {
- # Feature extraction
- 'feature-extraction': [
- 'microsoft/wavlm-base',
- 'microsoft/wavlm-base-plus',
- 'microsoft/wavlm-large',
- ],
-
- # Audio frame classification
- 'audio-frame-classification': [
- 'anton-l/wav2vec2-base-superb-sd',
- 'microsoft/wavlm-base-plus-sd',
- ],
-
- # Audio XVector (e.g., for speaker verification)
- 'audio-xvector': [
- 'microsoft/wavlm-base-plus-sv',
- 'microsoft/wavlm-base-sv',
- ],
- },
- 'whisper': {
- # Automatic speech recognition
- 'automatic-speech-recognition': [
- 'openai/whisper-tiny',
- 'openai/whisper-tiny.en',
- 'openai/whisper-base',
- 'openai/whisper-base.en',
- 'openai/whisper-small',
- 'openai/whisper-small.en',
- 'openai/whisper-medium',
- 'openai/whisper-medium.en',
- 'openai/whisper-large',
- 'openai/whisper-large-v2',
- 'NbAiLab/nb-whisper-tiny-beta',
- 'NbAiLab/nb-whisper-base-beta',
- 'NbAiLab/nb-whisper-small-beta',
- 'NbAiLab/nb-whisper-medium-beta',
- 'NbAiLab/nb-whisper-large-beta',
- ],
- },
- 'xlm': {
- # Masked language modelling
- 'fill-mask': [
- 'xlm-clm-ende-1024',
- 'xlm-mlm-ende-1024',
- 'xlm-clm-enfr-1024',
- 'xlm-mlm-enfr-1024',
- 'xlm-mlm-17-1280',
- 'xlm-mlm-100-1280',
- 'xlm-mlm-en-2048',
- 'xlm-mlm-enro-1024',
- 'xlm-mlm-tlm-xnli15-1024',
- 'xlm-mlm-xnli15-1024',
- ],
- },
- 'xlm-roberta': {
- # Masked language modelling
- 'fill-mask': [
- 'xlm-roberta-base'
- ],
- },
- 'yolos': {
- # Object detection
- 'object-detection': [
- # Object detection
- 'hustvl/yolos-tiny',
- 'hustvl/yolos-small',
- 'hustvl/yolos-base',
- 'hustvl/yolos-small-dwr',
- 'hustvl/yolos-small-300',
- ],
- },
-}
-
-
-def main():
- for model_type, tasks in SUPPORTED_MODELS.items():
- for task, model_ids in tasks.items():
- print(f'# {model_type:=^80}')
- for model_id in model_ids:
- print(
- f'python -m scripts.convert --quantize --model_id {model_id}')
- print()
-
-
-if __name__ == '__main__':
- main()
diff --git a/src/backends/onnx.js b/src/backends/onnx.js
index 0bee3dce7..de89da037 100644
--- a/src/backends/onnx.js
+++ b/src/backends/onnx.js
@@ -9,42 +9,208 @@
*
* This module is not directly exported, but can be accessed through the environment variables:
* ```javascript
- * import { env } from '@xenova/transformers';
+ * import { env } from '@huggingface/transformers';
* console.log(env.backends.onnx);
* ```
*
* @module backends/onnx
*/
+import { env, apis } from '../env.js';
+
// NOTE: Import order matters here. We need to import `onnxruntime-node` before `onnxruntime-web`.
// In either case, we select the default export if it exists, otherwise we use the named export.
import * as ONNX_NODE from 'onnxruntime-node';
-import * as ONNX_WEB from 'onnxruntime-web';
-/** @type {import('onnxruntime-web')} The ONNX runtime module. */
-export let ONNX;
+// Use subpath-imports to ensure Node.js and browser interoperability.
+// See package.json and https://nodejs.org/api/packages.html#subpath-imports
+// for more information.
+// @ts-ignore
+import * as ONNX_WEB from '#onnxruntime-webgpu';
+
+export { Tensor } from 'onnxruntime-common';
+
+/**
+ * @typedef {import('onnxruntime-common').InferenceSession.ExecutionProviderConfig} ONNXExecutionProviders
+ */
+
+/** @type {Record<import("../utils/devices.js").DeviceType, ONNXExecutionProviders>} */
+const DEVICE_TO_EXECUTION_PROVIDER_MAPPING = Object.freeze({
+ auto: null, // Auto-detect based on device and environment
+ gpu: null, // Auto-detect GPU
+ cpu: 'cpu', // CPU
+ wasm: 'wasm', // WebAssembly
+ webgpu: 'webgpu', // WebGPU
+ cuda: 'cuda', // CUDA
+ dml: 'dml', // DirectML
+
+ webnn: { name: 'webnn', deviceType: 'cpu' }, // WebNN (default)
+ 'webnn-npu': { name: 'webnn', deviceType: 'npu' }, // WebNN NPU
+ 'webnn-gpu': { name: 'webnn', deviceType: 'gpu' }, // WebNN GPU
+ 'webnn-cpu': { name: 'webnn', deviceType: 'cpu' }, // WebNN CPU
+});
-export const executionProviders = [
- // 'webgpu',
- 'wasm'
-];
+/**
+ * The list of supported devices, sorted by priority/performance.
+ * @type {import("../utils/devices.js").DeviceType[]}
+ */
+const supportedDevices = [];
+
+/** @type {ONNXExecutionProviders[]} */
+let defaultDevices;
+let ONNX;
+const ORT_SYMBOL = Symbol.for('onnxruntime');
-if (typeof process !== 'undefined' && process?.release?.name === 'node') {
- // Running in a node-like environment.
+if (ORT_SYMBOL in globalThis) {
+ // If the JS runtime exposes its own ONNX runtime, use it
+ ONNX = globalThis[ORT_SYMBOL];
+
+} else if (apis.IS_NODE_ENV) {
ONNX = ONNX_NODE.default ?? ONNX_NODE;
- // Add `cpu` execution provider, with higher precedence that `wasm`.
- executionProviders.unshift('cpu');
+ // Updated as of ONNX Runtime 1.18.0
+ // The following table lists the execution providers supported by the pre-built binaries of the ONNX Runtime Node.js binding.
+ // | EPs/Platforms | Windows x64 | Windows arm64 | Linux x64 | Linux arm64 | MacOS x64 | MacOS arm64 |
+ // | ------------- | ----------- | ------------- | ----------------- | ----------- | --------- | ----------- |
+ // | CPU | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+ // | DirectML | ✔️ | ✔️ | ❌ | ❌ | ❌ | ❌ |
+ // | CUDA | ❌ | ❌ | ✔️ (CUDA v11.8) | ❌ | ❌ | ❌ |
+ switch (process.platform) {
+ case 'win32': // Windows x64 and Windows arm64
+ supportedDevices.push('dml');
+ break;
+ case 'linux': // Linux x64 and Linux arm64
+ if (process.arch === 'x64') {
+ supportedDevices.push('cuda');
+ }
+ break;
+ case 'darwin': // MacOS x64 and MacOS arm64
+ break;
+ }
+ supportedDevices.push('cpu');
+ defaultDevices = ['cpu'];
} else {
- // Running in a browser-environment
- ONNX = ONNX_WEB.default ?? ONNX_WEB;
-
- // SIMD for WebAssembly does not operate correctly in some recent versions of iOS (16.4.x).
- // As a temporary fix, we disable it for now.
- // For more information, see: https://github.com/microsoft/onnxruntime/issues/15644
- const isIOS = typeof navigator !== 'undefined' && /iP(hone|od|ad).+16_4.+AppleWebKit/.test(navigator.userAgent);
- if (isIOS) {
- ONNX.env.wasm.simd = false;
+ ONNX = ONNX_WEB;
+
+ if (apis.IS_WEBNN_AVAILABLE) {
+ // TODO: Only push supported providers (depending on available hardware)
+ supportedDevices.push('webnn-npu', 'webnn-gpu', 'webnn-cpu', 'webnn');
+ }
+
+ if (apis.IS_WEBGPU_AVAILABLE) {
+ supportedDevices.push('webgpu');
}
+
+ supportedDevices.push('wasm');
+ defaultDevices = ['wasm'];
}
+
+// @ts-ignore
+const InferenceSession = ONNX.InferenceSession;
+
+/**
+ * Map a device to the execution providers to use for the given device.
+ * @param {import("../utils/devices.js").DeviceType|"auto"|null} [device=null] (Optional) The device to run the inference on.
+ * @returns {ONNXExecutionProviders[]} The execution providers to use for the given device.
+ */
+export function deviceToExecutionProviders(device = null) {
+ // Use the default execution providers if the user hasn't specified anything
+ if (!device) return defaultDevices;
+
+ // Handle overloaded cases
+ switch (device) {
+ case "auto":
+ return supportedDevices;
+ case "gpu":
+ return supportedDevices.filter(x =>
+ ["webgpu", "cuda", "dml", "webnn-gpu"].includes(x),
+ );
+ }
+
+ if (supportedDevices.includes(device)) {
+ return [DEVICE_TO_EXECUTION_PROVIDER_MAPPING[device] ?? device];
+ }
+
+ throw new Error(`Unsupported device: "${device}". Should be one of: ${supportedDevices.join(', ')}.`)
+}
+
+
+/**
+ * To prevent multiple calls to `initWasm()`, we store the first call in a Promise
+ * that is resolved when the first InferenceSession is created. Subsequent calls
+ * will wait for this Promise to resolve before creating their own InferenceSession.
+ * @type {Promise|null}
+ */
+let wasmInitPromise = null;
+
+/**
+ * Create an ONNX inference session.
+ * @param {Uint8Array} buffer The ONNX model buffer.
+ * @param {import('onnxruntime-common').InferenceSession.SessionOptions} session_options ONNX inference session options.
+ * @param {Object} session_config ONNX inference session configuration.
+ * @returns {Promise} The ONNX inference session.
+ */
+export async function createInferenceSession(buffer, session_options, session_config) {
+ if (wasmInitPromise) {
+ // A previous session has already initialized the WASM runtime
+ // so we wait for it to resolve before creating this new session.
+ await wasmInitPromise;
+ }
+
+ const sessionPromise = InferenceSession.create(buffer, session_options);
+ wasmInitPromise ??= sessionPromise;
+ const session = await sessionPromise;
+ session.config = session_config;
+ return session;
+}
+
+/**
+ * Check if an object is an ONNX tensor.
+ * @param {any} x The object to check
+ * @returns {boolean} Whether the object is an ONNX tensor.
+ */
+export function isONNXTensor(x) {
+ return x instanceof ONNX.Tensor;
+}
+
+/** @type {import('onnxruntime-common').Env} */
+// @ts-ignore
+const ONNX_ENV = ONNX?.env;
+if (ONNX_ENV?.wasm) {
+ // Initialize wasm backend with suitable default settings.
+
+ // (Optional) Set path to wasm files. This is needed when running in a web worker.
+ // https://onnxruntime.ai/docs/api/js/interfaces/Env.WebAssemblyFlags.html#wasmPaths
+ // We use remote wasm files by default to make it easier for newer users.
+ // In practice, users should probably self-host the necessary .wasm files.
+ ONNX_ENV.wasm.wasmPaths = `https://cdn.jsdelivr.net/npm/@huggingface/transformers@${env.version}/dist/`;
+
+ // TODO: Add support for loading WASM files from cached buffer when we upgrade to onnxruntime-web@1.19.0
+ // https://github.com/microsoft/onnxruntime/pull/21534
+
+    // Users may wish to proxy the WASM backend to prevent the UI from freezing.
+    // However, this is not necessary when using WebGPU, so we default to false.
+ ONNX_ENV.wasm.proxy = false;
+
+ // https://developer.mozilla.org/en-US/docs/Web/API/crossOriginIsolated
+ if (typeof crossOriginIsolated === 'undefined' || !crossOriginIsolated) {
+ ONNX_ENV.wasm.numThreads = 1;
+ }
+}
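+
+// Users who prefer to self-host the runtime can override the CDN default after
+// importing the library, e.g. (the path below is illustrative):
+//
+//   import { env } from '@huggingface/transformers';
+//   env.backends.onnx.wasm.wasmPaths = '/static/ort-dist/';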
+
+if (ONNX_ENV?.webgpu) {
+ ONNX_ENV.webgpu.powerPreference = 'high-performance';
+}
+
+/**
+ * Check if ONNX's WASM backend is being proxied.
+ * @returns {boolean} Whether ONNX's WASM backend is being proxied.
+ */
+export function isONNXProxy() {
+ // TODO: Update this when allowing non-WASM backends.
+ return ONNX_ENV?.wasm?.proxy;
+}
+
+// Expose ONNX environment variables to `env.backends.onnx`
+env.backends.onnx = ONNX_ENV;
diff --git a/src/configs.js b/src/configs.js
index 4506d2d9c..4bc95cf80 100644
--- a/src/configs.js
+++ b/src/configs.js
@@ -6,8 +6,8 @@
* **Example:** Load an `AutoConfig`.
*
* ```javascript
- * import { AutoConfig } from '@xenova/transformers';
- * let config = await AutoConfig.from_pretrained('bert-base-uncased');
+ * import { AutoConfig } from '@huggingface/transformers';
+ * const config = await AutoConfig.from_pretrained('bert-base-uncased');
* console.log(config);
* // PretrainedConfig {
* // "model_type": "bert",
@@ -27,6 +27,7 @@
* @module configs
*/
+import { pick } from './utils/core.js';
import {
getModelJSON,
} from './utils/hub.js';
@@ -40,13 +41,255 @@ import {
* Loads a config from the specified path.
* @param {string} pretrained_model_name_or_path The path to the config directory.
* @param {PretrainedOptions} options Additional options for loading the config.
- * @returns {Promise} A promise that resolves with information about the loaded config.
+ * @returns {Promise} A promise that resolves with information about the loaded config.
*/
async function loadConfig(pretrained_model_name_or_path, options) {
- let info = await getModelJSON(pretrained_model_name_or_path, 'config.json', true, options);
- return info;
+ return await getModelJSON(pretrained_model_name_or_path, 'config.json', true, options);
}
+/**
+ *
+ * @param {PretrainedConfig} config
+ * @returns {Object} The normalized configuration.
+ */
+function getNormalizedConfig(config) {
+ const mapping = {};
+
+ let init_normalized_config = {};
+ switch (config.model_type) {
+ // Sub-configs
+ case 'llava':
+ case 'paligemma':
+ case 'florence2':
+ init_normalized_config = getNormalizedConfig(config.text_config);
+ break;
+ case 'moondream1':
+ init_normalized_config = getNormalizedConfig(config.phi_config);
+ break;
+ case 'musicgen':
+ init_normalized_config = getNormalizedConfig(config.decoder);
+ break;
+
+ // Decoder-only models
+ case 'gpt2':
+ case 'gptj':
+ case 'jais':
+ case 'codegen':
+ case 'gpt_bigcode':
+ mapping['num_heads'] = 'n_head';
+ mapping['num_layers'] = 'n_layer';
+ mapping['hidden_size'] = 'n_embd';
+ break;
+ case 'gpt_neox':
+ case 'stablelm':
+ case 'opt':
+ case 'phi':
+ case 'phi3':
+ case 'falcon':
+ mapping['num_heads'] = 'num_attention_heads';
+ mapping['num_layers'] = 'num_hidden_layers';
+ mapping['hidden_size'] = 'hidden_size';
+ break;
+ case 'llama':
+ case 'granite':
+ case 'cohere':
+ case 'mistral':
+ case 'starcoder2':
+ case 'qwen2':
+ mapping['num_heads'] = 'num_key_value_heads';
+ mapping['num_layers'] = 'num_hidden_layers';
+ mapping['hidden_size'] = 'hidden_size';
+ mapping['num_attention_heads'] = 'num_attention_heads';
+ break;
+ case 'gemma':
+ case 'gemma2':
+ mapping['num_heads'] = 'num_key_value_heads';
+ mapping['num_layers'] = 'num_hidden_layers';
+ mapping['dim_kv'] = 'head_dim';
+ break;
+ case 'openelm':
+ mapping['num_heads'] = 'num_kv_heads';
+ mapping['num_layers'] = 'num_transformer_layers';
+ mapping['dim_kv'] = 'head_dim';
+ break;
+ case 'gpt_neo':
+ case 'donut-swin':
+ mapping['num_heads'] = 'num_heads';
+ mapping['num_layers'] = 'num_layers';
+ mapping['hidden_size'] = 'hidden_size';
+ break;
+ case 'bloom':
+ mapping['num_heads'] = 'n_head';
+ mapping['num_layers'] = 'n_layer';
+ mapping['hidden_size'] = 'hidden_size';
+ break;
+ case 'mpt':
+ mapping['num_heads'] = 'n_heads';
+ mapping['num_layers'] = 'n_layers';
+ mapping['hidden_size'] = 'd_model';
+ break;
+
+ // Encoder-decoder models
+ case 't5':
+ case 'mt5':
+ case 'longt5':
+ mapping['num_decoder_layers'] = 'num_decoder_layers';
+ mapping['num_decoder_heads'] = 'num_heads';
+ mapping['decoder_dim_kv'] = 'd_kv';
+ mapping['num_encoder_layers'] = 'num_layers';
+ mapping['num_encoder_heads'] = 'num_heads';
+ mapping['encoder_dim_kv'] = 'd_kv';
+ break;
+ case 'bart':
+ case 'mbart':
+ case 'marian':
+ case 'whisper':
+ case 'm2m_100':
+ case 'blenderbot':
+ case 'blenderbot-small':
+ case 'florence2_language':
+ mapping['num_decoder_layers'] = 'decoder_layers';
+ mapping['num_decoder_heads'] = 'decoder_attention_heads';
+ mapping['decoder_hidden_size'] = 'd_model';
+ mapping['num_encoder_layers'] = 'encoder_layers';
+ mapping['num_encoder_heads'] = 'encoder_attention_heads';
+ mapping['encoder_hidden_size'] = 'd_model';
+ break;
+ case 'speecht5':
+ mapping['num_decoder_layers'] = 'decoder_layers';
+ mapping['num_decoder_heads'] = 'decoder_attention_heads';
+ mapping['decoder_hidden_size'] = 'hidden_size';
+ mapping['num_encoder_layers'] = 'encoder_layers';
+ mapping['num_encoder_heads'] = 'encoder_attention_heads';
+ mapping['encoder_hidden_size'] = 'hidden_size';
+ break;
+ case 'trocr':
+ mapping['num_encoder_layers'] = mapping['num_decoder_layers'] = 'decoder_layers';
+ mapping['num_encoder_heads'] = mapping['num_decoder_heads'] = 'decoder_attention_heads';
+ mapping['encoder_hidden_size'] = mapping['decoder_hidden_size'] = 'd_model';
+ break;
+ case 'musicgen_decoder':
+ mapping['num_encoder_layers'] = mapping['num_decoder_layers'] = 'num_hidden_layers';
+ mapping['num_encoder_heads'] = mapping['num_decoder_heads'] = 'num_attention_heads';
+ mapping['encoder_hidden_size'] = mapping['decoder_hidden_size'] = 'hidden_size';
+ break;
+
+ case 'vision-encoder-decoder':
+ const decoderConfig = getNormalizedConfig(config.decoder);
+
+ const add_encoder_pkv = 'num_decoder_layers' in decoderConfig;
+ const result = pick(config, ['model_type', 'is_encoder_decoder']);
+ if (add_encoder_pkv) {
+ // Decoder is part of an encoder-decoder model
+ result.num_decoder_layers = decoderConfig.num_decoder_layers;
+ result.num_decoder_heads = decoderConfig.num_decoder_heads;
+ result.decoder_hidden_size = decoderConfig.decoder_hidden_size;
+
+ result.num_encoder_layers = decoderConfig.num_encoder_layers;
+ result.num_encoder_heads = decoderConfig.num_encoder_heads;
+ result.encoder_hidden_size = decoderConfig.encoder_hidden_size;
+ } else {
+ // Decoder is a decoder-only model
+ result.num_layers = decoderConfig.num_layers;
+ result.num_heads = decoderConfig.num_heads;
+ result.hidden_size = decoderConfig.hidden_size;
+ }
+ return result;
+
+ }
+
+ // NOTE: If `num_attention_heads` is not set, it is assumed to be equal to `num_heads`
+ const normalized_config = {
+ ...init_normalized_config,
+ ...pick(config, ['model_type', 'multi_query', 'is_encoder_decoder']),
+ };
+ for (const key in mapping) {
+ normalized_config[key] = config[mapping[key]];
+ }
+ return normalized_config;
+}
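+
+// For illustration (hypothetical input): a GPT-2 style config such as
+//   { model_type: 'gpt2', n_head: 12, n_layer: 12, n_embd: 768 }
+// is normalized via the mapping above to roughly
+//   { model_type: 'gpt2', num_heads: 12, num_layers: 12, hidden_size: 768 }
+// so that downstream helpers (e.g. getKeyValueShapes) can read a uniform set of keys.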
+
+/**
+ *
+ * @param {PretrainedConfig} config
+ * @returns {Record<string, number[]>}
+ */
+export function getKeyValueShapes(config, {
+ prefix = 'past_key_values',
+} = {}) {
+    /** @type {Record<string, number[]>} */
+ const decoderFeeds = {};
+ const normalized_config = config.normalized_config;
+
+ // TODO support batches (i.e., batch_size > 1)
+ const batch_size = 1;
+
+ if (normalized_config.is_encoder_decoder && (
+ 'num_encoder_heads' in normalized_config && 'num_decoder_heads' in normalized_config
+ )) {
+ const encoder_dim_kv = normalized_config.encoder_dim_kv ?? (
+ normalized_config.encoder_hidden_size / normalized_config.num_encoder_heads
+ );
+ const decoder_dim_kv = normalized_config.decoder_dim_kv ?? (
+ normalized_config.decoder_hidden_size / normalized_config.num_decoder_heads
+ );
+
+ const encoder_dims = [batch_size, normalized_config.num_encoder_heads, 0, encoder_dim_kv];
+ const decoder_dims = [batch_size, normalized_config.num_decoder_heads, 0, decoder_dim_kv];
+ for (let i = 0; i < normalized_config.num_decoder_layers; ++i) {
+ decoderFeeds[`${prefix}.${i}.encoder.key`] = encoder_dims;
+ decoderFeeds[`${prefix}.${i}.encoder.value`] = encoder_dims;
+ decoderFeeds[`${prefix}.${i}.decoder.key`] = decoder_dims;
+ decoderFeeds[`${prefix}.${i}.decoder.value`] = decoder_dims;
+ }
+ } else { // Decoders
+ const num_heads = normalized_config.num_heads;
+ const num_layers = normalized_config.num_layers;
+ const dim_kv = normalized_config.dim_kv ?? (
+ normalized_config.hidden_size /
+ (normalized_config.num_attention_heads ?? num_heads)
+ );
+
+ if (normalized_config.model_type === 'falcon') {
+ // NOTE: Custom implementation for Falcon
+ const dims = [batch_size * num_heads, 0, dim_kv]
+ for (let i = 0; i < num_layers; ++i) {
+ decoderFeeds[`${prefix}.${i}.key`] = dims;
+ decoderFeeds[`${prefix}.${i}.value`] = dims;
+ }
+ } else if (normalized_config.multi_query) { // e.g., for `gpt_bigcode`
+ const dims = [batch_size * num_heads, 0, 2 * dim_kv]
+
+ for (let i = 0; i < num_layers; ++i) {
+ decoderFeeds[`${prefix}.${i}.key_value`] = dims;
+ }
+ } else if (normalized_config.model_type === 'bloom') {
+ // NOTE: Custom implementation for Bloom
+
+ const keyDims = [batch_size * num_heads, dim_kv, 0] // [batch_size x num_heads,64,past_sequence_length]
+ const valueDims = [batch_size * num_heads, 0, dim_kv] // [batch_size x num_heads,past_sequence_length,64]
+ for (let i = 0; i < num_layers; ++i) {
+ decoderFeeds[`${prefix}.${i}.key`] = keyDims;
+ decoderFeeds[`${prefix}.${i}.value`] = valueDims;
+ }
+ } else if (normalized_config.model_type === 'openelm') {
+ for (let i = 0; i < num_layers; ++i) {
+ const dims = [batch_size, num_heads[i], 0, dim_kv]
+
+ decoderFeeds[`${prefix}.${i}.key`] = dims;
+ decoderFeeds[`${prefix}.${i}.value`] = dims;
+ }
+ } else { // Decoder-only
+ const dims = [batch_size, num_heads, 0, dim_kv]
+ for (let i = 0; i < num_layers; ++i) {
+ decoderFeeds[`${prefix}.${i}.key`] = dims;
+ decoderFeeds[`${prefix}.${i}.value`] = dims;
+ }
+ }
+ }
+
+ return decoderFeeds;
+}
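+
+// Rough example of the shapes produced for a hypothetical decoder-only config with
+// num_layers = 2, num_heads = 12 and hidden_size = 768 (so dim_kv = 768 / 12 = 64):
+//
+//   getKeyValueShapes(config);
+//   // {
+//   //   'past_key_values.0.key':   [1, 12, 0, 64],
+//   //   'past_key_values.0.value': [1, 12, 0, 64],
+//   //   'past_key_values.1.key':   [1, 12, 0, 64],
+//   //   'past_key_values.1.value': [1, 12, 0, 64],
+//   // }
+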
/**
* Base class for all configuration classes. For more information, see the corresponding
* [Python documentation](https://huggingface.co/docs/transformers/main/en/main_classes/configuration#transformers.PretrainedConfig).
@@ -54,15 +297,25 @@ async function loadConfig(pretrained_model_name_or_path, options) {
export class PretrainedConfig {
// NOTE: Typo in original
+ /** @type {string|null} */
+ model_type = null;
+
+ /** @type {boolean} */
+ is_encoder_decoder = false;
+
+ /** @type {number} */
+ max_position_embeddings;
+
+ /** @type {TransformersJSConfig} */
+ 'transformers.js_config';
+
/**
* Create a new PreTrainedTokenizer instance.
* @param {Object} configJSON The JSON of the config.
*/
constructor(configJSON) {
- this.model_type = null;
- this.is_encoder_decoder = false;
-
Object.assign(this, configJSON);
+ this.normalized_config = getNormalizedConfig(this);
}
/**
@@ -81,8 +334,11 @@ export class PretrainedConfig {
local_files_only = false,
revision = 'main',
} = {}) {
+ if (config && !(config instanceof PretrainedConfig)) {
+ config = new PretrainedConfig(config);
+ }
- let data = config ?? await loadConfig(pretrained_model_name_or_path, {
+ const data = config ?? await loadConfig(pretrained_model_name_or_path, {
progress_callback,
config,
cache_dir,
@@ -97,11 +353,23 @@ export class PretrainedConfig {
* Helper class which is used to instantiate pretrained configs with the `from_pretrained` function.
*
* @example
- * let config = await AutoConfig.from_pretrained('bert-base-uncased');
+ * const config = await AutoConfig.from_pretrained('Xenova/bert-base-uncased');
*/
export class AutoConfig {
- /** @type {PretrainedConfig.from_pretrained} */
+ /** @type {typeof PretrainedConfig.from_pretrained} */
static async from_pretrained(...args) {
return PretrainedConfig.from_pretrained(...args);
}
}
+
+/**
+ * Transformers.js-specific configuration, possibly present in config.json under the key `transformers.js_config`.
+ * @typedef {Object} TransformersJSConfig
+ * @property {import('./utils/tensor.js').DataType|Record<string, import('./utils/tensor.js').DataType>} [kv_cache_dtype] The data type of the key-value cache.
+ * @property {Record<string, number>} [free_dimension_overrides] Override the free dimensions of the model.
+ * See https://onnxruntime.ai/docs/tutorials/web/env-flags-and-session-options.html#freedimensionoverrides
+ * for more information.
+ * @property {import('./utils/devices.js').DeviceType} [device] The default device to use for the model.
+ * @property {import('./utils/dtypes.js').DataType} [dtype] The default data type to use for the model.
+ * @property {boolean|Record<string, boolean>} [use_external_data_format=false] Whether to load the model using the external data format (used for models >= 2GB in size).
+ */
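+
+// For reference, a model's config.json might carry these options under the
+// "transformers.js_config" key, e.g. (hypothetical values):
+//
+//   "transformers.js_config": {
+//     "device": "webgpu",
+//     "dtype": "q4",
+//     "kv_cache_dtype": "float16",
+//     "use_external_data_format": false
+//   }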
diff --git a/src/env.js b/src/env.js
index 2ed670021..5e604efdb 100644
--- a/src/env.js
+++ b/src/env.js
@@ -3,19 +3,19 @@
*
* **Example:** Disable remote models.
* ```javascript
- * import { env } from '@xenova/transformers';
+ * import { env } from '@huggingface/transformers';
* env.allowRemoteModels = false;
* ```
*
* **Example:** Set local model path.
* ```javascript
- * import { env } from '@xenova/transformers';
+ * import { env } from '@huggingface/transformers';
* env.localModelPath = '/path/to/local/models/';
* ```
*
* **Example:** Set cache directory.
* ```javascript
- * import { env } from '@xenova/transformers';
+ * import { env } from '@huggingface/transformers';
* env.cacheDir = '/path/to/cache/directory/';
* ```
*
@@ -26,19 +26,53 @@ import fs from 'fs';
import path from 'path';
import url from 'url';
-import { ONNX } from './backends/onnx.js';
-const { env: onnx_env } = ONNX;
-
-const VERSION = '2.17.2';
+const VERSION = '3.0.0';
// Check if various APIs are available (depends on environment)
-const WEB_CACHE_AVAILABLE = typeof self !== 'undefined' && 'caches' in self;
-const FS_AVAILABLE = !isEmpty(fs); // check if file system is available
-const PATH_AVAILABLE = !isEmpty(path); // check if path is available
+const IS_BROWSER_ENV = typeof self !== 'undefined';
+const IS_WEBWORKER_ENV = IS_BROWSER_ENV && self.constructor.name === 'DedicatedWorkerGlobalScope';
+const IS_WEB_CACHE_AVAILABLE = IS_BROWSER_ENV && 'caches' in self;
+const IS_WEBGPU_AVAILABLE = typeof navigator !== 'undefined' && 'gpu' in navigator;
+const IS_WEBNN_AVAILABLE = typeof navigator !== 'undefined' && 'ml' in navigator;
+
+const IS_PROCESS_AVAILABLE = typeof process !== 'undefined';
+const IS_NODE_ENV = IS_PROCESS_AVAILABLE && process?.release?.name === 'node';
+const IS_FS_AVAILABLE = !isEmpty(fs);
+const IS_PATH_AVAILABLE = !isEmpty(path);
+
+/**
+ * A read-only object containing information about the APIs available in the current environment.
+ */
+export const apis = Object.freeze({
+ /** Whether we are running in a browser environment */
+ IS_BROWSER_ENV,
+
+ /** Whether we are running in a web worker environment */
+ IS_WEBWORKER_ENV,
+
+ /** Whether the Cache API is available */
+ IS_WEB_CACHE_AVAILABLE,
+
+ /** Whether the WebGPU API is available */
+ IS_WEBGPU_AVAILABLE,
+
+ /** Whether the WebNN API is available */
+ IS_WEBNN_AVAILABLE,
+
+ /** Whether the Node.js process API is available */
+ IS_PROCESS_AVAILABLE,
+
+ /** Whether we are running in a Node.js environment */
+ IS_NODE_ENV,
-const RUNNING_LOCALLY = FS_AVAILABLE && PATH_AVAILABLE;
+ /** Whether the filesystem API is available */
+ IS_FS_AVAILABLE,
-// __dirname is reserved so we use dirname__ instead.
+ /** Whether the path API is available */
+ IS_PATH_AVAILABLE,
+});
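+
+// A typical (illustrative) use of these flags elsewhere in the library is to branch
+// on the runtime environment, as the ONNX backend does when collecting supported devices:
+//
+//   if (apis.IS_WEBGPU_AVAILABLE) {
+//     // e.g. prefer the 'webgpu' device
+//   } else if (apis.IS_NODE_ENV) {
+//     // e.g. fall back to native execution providers
+//   }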
+
+const RUNNING_LOCALLY = IS_FS_AVAILABLE && IS_PATH_AVAILABLE;
const dirname__ = RUNNING_LOCALLY
? path.dirname(path.dirname(url.fileURLToPath(import.meta.url)))
: './';
@@ -54,27 +88,17 @@ const localModelPath = RUNNING_LOCALLY
? path.join(dirname__, DEFAULT_LOCAL_MODEL_PATH)
: DEFAULT_LOCAL_MODEL_PATH;
-if (onnx_env?.wasm) {
- // Set path to wasm files. This is needed when running in a web worker.
- // https://onnxruntime.ai/docs/api/js/interfaces/Env.WebAssemblyFlags.html#wasmPaths
- // We use remote wasm files by default to make it easier for newer users.
- // In practice, users should probably self-host the necessary .wasm files.
- onnx_env.wasm.wasmPaths = RUNNING_LOCALLY
- ? path.join(dirname__, '/dist/')
- : `https://cdn.jsdelivr.net/npm/@xenova/transformers@${VERSION}/dist/`;
-}
-
/**
- * Global variable used to control execution. This provides users a simple way to configure Transformers.js.
- * @property {Object} backends Expose environment variables of different backends,
- * allowing users to set these variables if they want to.
- * @property {string} __dirname Directory name of module. Useful for resolving local paths.
+ * Global variable made visible to users to control execution. This provides users with a simple way to configure Transformers.js.
+ * @typedef {Object} TransformersEnvironment
* @property {string} version This version of Transformers.js.
+ * @property {{onnx: Partial<import('onnxruntime-common').Env>}} backends Expose environment variables of different backends,
+ * allowing users to set these variables if they want to.
* @property {boolean} allowRemoteModels Whether to allow loading of remote files, defaults to `true`.
* If set to `false`, it will have the same effect as setting `local_files_only=true` when loading pipelines, models, tokenizers, processors, etc.
* @property {string} remoteHost Host URL to load models from. Defaults to the Hugging Face Hub.
* @property {string} remotePathTemplate Path template to fill in and append to `remoteHost` when loading models.
- * @property {boolean} allowLocalModels Whether to allow loading of local files, defaults to `true`.
+ * @property {boolean} allowLocalModels Whether to allow loading of local files, defaults to `false` if running in-browser, and `true` otherwise.
* If set to `false`, it will skip the local file check and try to load the model from the remote host.
* @property {string} localModelPath Path to load local models from. Defaults to `/models/`.
* @property {boolean} useFS Whether to use the file system to load files. By default, it is `true` if available.
@@ -85,32 +109,31 @@ if (onnx_env?.wasm) {
* @property {Object} customCache The custom cache to use. Defaults to `null`. Note: this must be an object which
* implements the `match` and `put` functions of the Web Cache API. For more information, see https://developer.mozilla.org/en-US/docs/Web/API/Cache
*/
+
+/** @type {TransformersEnvironment} */
export const env = {
+ version: VERSION,
+
/////////////////// Backends settings ///////////////////
+ // NOTE: These will be populated later by the backends themselves.
backends: {
// onnxruntime-web/onnxruntime-node
- onnx: onnx_env,
-
- // TensorFlow.js
- tfjs: {},
+ onnx: {},
},
- __dirname: dirname__,
- version: VERSION,
-
/////////////////// Model settings ///////////////////
allowRemoteModels: true,
remoteHost: 'https://huggingface.co/',
remotePathTemplate: '{model}/resolve/{revision}/',
- allowLocalModels: true,
+ allowLocalModels: !IS_BROWSER_ENV,
localModelPath: localModelPath,
- useFS: FS_AVAILABLE,
+ useFS: IS_FS_AVAILABLE,
/////////////////// Cache settings ///////////////////
- useBrowserCache: WEB_CACHE_AVAILABLE,
+ useBrowserCache: IS_WEB_CACHE_AVAILABLE,
- useFSCache: FS_AVAILABLE,
+ useFSCache: IS_FS_AVAILABLE,
cacheDir: DEFAULT_CACHE_DIR,
useCustomCache: false,
diff --git a/src/generation/configuration_utils.js b/src/generation/configuration_utils.js
new file mode 100644
index 000000000..33a6fbe81
--- /dev/null
+++ b/src/generation/configuration_utils.js
@@ -0,0 +1,381 @@
+
+/**
+ * @module generation/configuration_utils
+ */
+
+import { pick } from "../utils/core.js";
+
+/**
+ * Class that holds a configuration for a generation task.
+ */
+export class GenerationConfig {
+ // Parameters that control the length of the output
+ /**
+ * The maximum length the generated tokens can have.
+ * Corresponds to the length of the input prompt + `max_new_tokens`.
+ * Its effect is overridden by `max_new_tokens`, if also set.
+ * @type {number}
+ * @default 20
+ */
+ max_length = 20;
+
+ /**
+ * The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
+ * @type {number}
+ * @default null
+ */
+ max_new_tokens = null;
+
+ /**
+ * The minimum length of the sequence to be generated.
+ * Corresponds to the length of the input prompt + `min_new_tokens`.
+ * Its effect is overridden by `min_new_tokens`, if also set.
+ * @type {number}
+ * @default 0
+ */
+ min_length = 0;
+
+ /**
+ * The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt.
+ * @type {number}
+ * @default null
+ */
+ min_new_tokens = null;
+
+ /**
+ * Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
+ * - `true`, where the generation stops as soon as there are `num_beams` complete candidates;
+     * - `false`, where a heuristic is applied and the generation stops when it is very unlikely to find better candidates;
+ * - `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm).
+ * @type {boolean|"never"}
+ * @default false
+ */
+ early_stopping = false;
+
+ /**
+ * The maximum amount of time you allow the computation to run for in seconds.
+     * Generation will still finish the current pass after the allocated time has passed.
+ * @type {number}
+ * @default null
+ */
+ max_time = null;
+
+ // Parameters that control the generation strategy used
+ /**
+ * Whether or not to use sampling; use greedy decoding otherwise.
+ * @type {boolean}
+ * @default false
+ */
+ do_sample = false;
+
+ /**
+ * Number of beams for beam search. 1 means no beam search.
+ * @type {number}
+ * @default 1
+ */
+ num_beams = 1;
+
+ /**
+ * Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
+ * See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
+ * @type {number}
+ * @default 1
+ */
+ num_beam_groups = 1;
+
+ /**
+ * The values balance the model confidence and the degeneration penalty in contrastive search decoding.
+ * @type {number}
+ * @default null
+ */
+ penalty_alpha = null;
+
+ /**
+ * Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding.
+ * @type {boolean}
+ * @default true
+ */
+ use_cache = true;
+
+ // Parameters for manipulation of the model output logits
+ /**
+ * The value used to modulate the next token probabilities.
+ * @type {number}
+ * @default 1.0
+ */
+ temperature = 1.0;
+
+ /**
+ * The number of highest probability vocabulary tokens to keep for top-k-filtering.
+ * @type {number}
+ * @default 50
+ */
+ top_k = 50;
+
+ /**
+ * If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation.
+ * @type {number}
+ * @default 1.0
+ */
+ top_p = 1.0;
+
+ /**
+ * Local typicality measures how similar the conditional probability of predicting a target token next is to the expected conditional probability of predicting a random token next, given the partial text already generated.
+ * If set to float < 1, the smallest set of the most locally typical tokens with probabilities that add up to `typical_p` or higher are kept for generation.
+ * See [this paper](https://arxiv.org/pdf/2202.00666.pdf) for more details.
+ * @type {number}
+ * @default 1.0
+ */
+ typical_p = 1.0;
+
+ /**
+ * If set to float strictly between 0 and 1, only tokens with a conditional probability greater than `epsilon_cutoff` will be sampled.
+ * In the paper, suggested values range from 3e-4 to 9e-4, depending on the size of the model.
+ * See [Truncation Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more details.
+ * @type {number}
+ * @default 0.0
+ */
+ epsilon_cutoff = 0.0;
+
+ /**
+ * Eta sampling is a hybrid of locally typical sampling and epsilon sampling.
+ * If set to float strictly between 0 and 1, a token is only considered if it is greater than either `eta_cutoff` or `sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits)))`.
+ * The latter term is intuitively the expected next token probability, scaled by `sqrt(eta_cutoff)`. In the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model.
+ * See [Truncation Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more details.
+ * @type {number}
+ * @default 0.0
+ */
+ eta_cutoff = 0.0;
+
+ /**
+     * This value is subtracted from a beam's score if it generates a token that is the same as a token generated by any beam from another group at a particular time step.
+ * Note that `diversity_penalty` is only effective if `group beam search` is enabled.
+ * @type {number}
+ * @default 0.0
+ */
+ diversity_penalty = 0.0;
+
+ /**
+ * The parameter for repetition penalty. 1.0 means no penalty.
+ * See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+ * @type {number}
+ * @default 1.0
+ */
+ repetition_penalty = 1.0;
+
+ /**
+     * The parameter for encoder_repetition_penalty.
+ * An exponential penalty on sequences that are not in the original input.
+ * 1.0 means no penalty.
+ * @type {number}
+ * @default 1.0
+ */
+ encoder_repetition_penalty = 1.0;
+
+ /**
+ * Exponential penalty to the length that is used with beam-based generation.
+ * It is applied as an exponent to the sequence length, which in turn is used to divide the score of the sequence.
+ * Since the score is the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences.
+ * @type {number}
+ * @default 1.0
+ */
+ length_penalty = 1.0;
+
+ /**
+ * If set to int > 0, all ngrams of that size can only occur once.
+ * @type {number}
+ * @default 0
+ */
+ no_repeat_ngram_size = 0;
+
+ /**
+ * List of token ids that are not allowed to be generated.
+ * In order to get the token ids of the words that should not appear in the generated text, use
+ * `tokenizer(bad_words, { add_prefix_space: true, add_special_tokens: false }).input_ids`.
+ * @type {number[][]}
+ * @default null
+ */
+ bad_words_ids = null;
+
+ /**
+ * List of token ids that must be generated.
+ * If given a `number[][]`, this is treated as a simple list of words that must be included, the opposite to `bad_words_ids`.
+ * If given `number[][][]`, this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one can allow different forms of each word.
+ * @type {number[][]|number[][][]}
+ * @default null
+ */
+ force_words_ids = null;
+
+ /**
+ * Whether to renormalize the logits after applying all the logits processors or warpers (including the custom ones).
+ * It's highly recommended to set this flag to `true` as the search algorithms suppose the score logits are normalized but some logit processors or warpers break the normalization.
+ * @type {boolean}
+ * @default false
+ */
+ renormalize_logits = false;
+
+ /**
+ * Custom constraints that can be added to the generation to ensure that the output will contain the use of certain tokens as defined by `Constraint` objects, in the most sensible way possible.
+ * @type {Object[]}
+ * @default null
+ */
+ constraints = null;
+
+ /**
+ * The id of the token to force as the first generated token after the `decoder_start_token_id`.
+ * Useful for multilingual models like mBART where the first generated token needs to be the target language token.
+ * @type {number}
+ * @default null
+ */
+ forced_bos_token_id = null;
+
+ /**
+ * The id of the token to force as the last generated token when `max_length` is reached.
+ * Optionally, use a list to set multiple *end-of-sequence* tokens.
+ * @type {number|number[]}
+ * @default null
+ */
+ forced_eos_token_id = null;
+
+ /**
+     * Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method from crashing. Note that using `remove_invalid_values` can slow down generation.
+ * @type {boolean}
+ */
+ remove_invalid_values = false;
+
+ /**
+     * This tuple adds an exponentially increasing length penalty after a certain number of tokens have been generated.
+ * The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty starts and `decay_factor` represents the factor of exponential decay.
+ * @type {[number, number]}
+ * @default null
+ */
+ exponential_decay_length_penalty = null;
+
+ /**
+ * A list of tokens that will be suppressed at generation.
+ * The `SuppressTokens` logit processor will set their log probs to `-inf` so that they are not sampled.
+ * @type {number[]}
+ * @default null
+ */
+ suppress_tokens = null;
+
+ /**
+ * A list of tokens that will be suppressed at the beginning of the generation.
+ * The `SuppressBeginTokens` logit processor will set their log probs to `-inf` so that they are not sampled.
+ * @type {number[]}
+ * @default null
+ */
+ begin_suppress_tokens = null;
+
+ /**
+ * A list of pairs of integers which indicates a mapping from generation indices to token indices that will be forced before sampling.
+ * For example, `[[1, 123]]` means the second generated token will always be a token of index 123.
+ * @type {[number, number][]}
+ * @default null
+ */
+ forced_decoder_ids = null;
+
+ /**
+ * The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
+ * Higher guidance scale encourages the model to generate samples that are more closely linked to the input
+ * prompt, usually at the expense of poorer quality.
+ * @type {number}
+ * @default null
+ */
+ guidance_scale = null;
+
+ // Parameters that define the output variables of `generate`
+ /**
+ * The number of independently computed returned sequences for each element in the batch.
+ * @type {number}
+ * @default 1
+ */
+ num_return_sequences = 1;
+
+ /**
+ * Whether or not to return the attentions tensors of all attention layers.
+ * See `attentions` under returned tensors for more details.
+ * @type {boolean}
+ * @default false
+ */
+ output_attentions = false;
+
+ /**
+ * Whether or not to return the hidden states of all layers.
+ * See `hidden_states` under returned tensors for more details.
+ * @type {boolean}
+ * @default false
+ */
+ output_hidden_states = false;
+
+ /**
+ * Whether or not to return the prediction scores.
+ * See `scores` under returned tensors for more details.
+ * @type {boolean}
+ * @default false
+ */
+ output_scores = false;
+
+ /**
+ * Whether or not to return a `ModelOutput` instead of a plain tuple.
+ * @type {boolean}
+ * @default false
+ */
+ return_dict_in_generate = false;
+
+ // Special tokens that can be used at generation time
+ /**
+ * The id of the *padding* token.
+ * @type {number}
+ * @default null
+ */
+ pad_token_id = null;
+
+ /**
+ * The id of the *beginning-of-sequence* token.
+ * @type {number}
+ * @default null
+ */
+ bos_token_id = null;
+
+ /**
+ * The id of the *end-of-sequence* token.
+ * Optionally, use a list to set multiple *end-of-sequence* tokens.
+ * @type {number|number[]}
+ * @default null
+ */
+ eos_token_id = null;
+
+ // Generation parameters exclusive to encoder-decoder models
+ /**
+ * If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the `decoder_input_ids`.
+ * @type {number}
+ * @default 0
+ */
+ encoder_no_repeat_ngram_size = 0;
+
+ /**
+ * If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
+ * @type {number}
+ * @default null
+ */
+ decoder_start_token_id = null;
+
+ // Wild card
+ /**
+ * Additional generation kwargs will be forwarded to the `generate` function of the model.
+ * Kwargs that are not present in `generate`'s signature will be used in the model forward pass.
+ * @type {Object}
+ * @default {}
+ */
+ generation_kwargs = {};
+
+ /**
+ *
+ * @param {GenerationConfig|import('../configs.js').PretrainedConfig} config
+ */
+ constructor(config) {
+ Object.assign(this, pick(config, Object.getOwnPropertyNames(this)));
+ }
+}
+
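+// A minimal usage sketch (values are illustrative): the constructor only copies keys
+// that already exist on the class, so unknown keys are silently dropped by `pick`.
+//
+//   const generation_config = new GenerationConfig({
+//     max_new_tokens: 64,
+//     do_sample: true,
+//     top_k: 5,
+//     some_unknown_key: 123, // ignored
+//   });
+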
diff --git a/src/generation/logits_process.js b/src/generation/logits_process.js
new file mode 100644
index 000000000..732af4f3f
--- /dev/null
+++ b/src/generation/logits_process.js
@@ -0,0 +1,719 @@
+
+/**
+ * @module generation/logits_process
+ */
+
+import { Callable } from "../utils/generic.js";
+import { Tensor } from "../utils/tensor.js";
+
+import { max, log_softmax } from "../utils/maths.js";
+
+/**
+ * Abstract base class for all logit processors that can be applied during generation.
+ */
+export class LogitsProcessor extends Callable {
+ /**
+ * Apply the processor to the input logits.
+ *
+ * @abstract
+ * @param {bigint[][]} input_ids The input ids.
+ * @param {Tensor} logits The logits to process.
+ * @throws {Error} Throws an error if `_call` is not implemented in the subclass.
+ */
+ _call(input_ids, logits) {
+ throw Error("`_call` should be implemented in a subclass")
+ }
+}
+
+
+/**
+ * Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.
+ */
+export class LogitsWarper extends Callable {
+ /**
+ * Apply the processor to the input logits.
+ *
+ * @abstract
+ * @param {bigint[][]} input_ids The input ids.
+ * @param {Tensor} logits The logits to process.
+ * @throws {Error} Throws an error if `_call` is not implemented in the subclass.
+ */
+ _call(input_ids, logits) {
+ throw Error("`_call` should be implemented in a subclass")
+ }
+}
+
+
+/**
+ * A class representing a list of logits processors. A logits processor is a function that modifies the logits
+ * output of a language model. This class provides methods for adding new processors and applying all processors to a
+ * batch of logits.
+ */
+export class LogitsProcessorList extends Callable {
+ /**
+ * Constructs a new instance of `LogitsProcessorList`.
+ */
+ constructor() {
+ super();
+ this.processors = [];
+ }
+
+ /**
+ * Adds a new logits processor to the list.
+ *
+ * @param {LogitsProcessor} item The logits processor function to add.
+ */
+ push(item) {
+ this.processors.push(item);
+ }
+
+ /**
+ * Adds multiple logits processors to the list.
+ *
+ * @param {LogitsProcessor[]} items The logits processor functions to add.
+ */
+ extend(items) {
+ this.processors.push(...items);
+ }
+
+ /**
+ * Applies all logits processors in the list to a batch of logits, modifying them in-place.
+ *
+ * @param {bigint[][]} input_ids The input IDs for the language model.
+ * @param {Tensor} logits
+ */
+ _call(input_ids, logits) {
+ let toReturn = logits;
+ // NOTE: Most processors modify logits inplace
+ for (const processor of this.processors) {
+ toReturn = processor(input_ids, toReturn);
+ }
+ return toReturn;
+ }
+
+ [Symbol.iterator]() {
+ return this.processors.values();
+ }
+}
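+
+// Sketch of how a processor list is typically assembled and applied (the processors
+// and inputs below are placeholders):
+//
+//   const processors = new LogitsProcessorList();
+//   processors.push(new MinLengthLogitsProcessor(5, eos_token_id));
+//   processors.push(new RepetitionPenaltyLogitsProcessor(1.2));
+//   const processed_logits = processors(input_ids, logits); // Callable: dispatches to _call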
+
+// DEPRECATED: https://github.com/huggingface/transformers/pull/29485
+// /**
+// * A logits processor that forces a specific token to be generated by the decoder.
+// */
+// export class ForceTokensLogitsProcessor extends LogitsProcessor {
+// /**
+// * Constructs a new instance of `ForceTokensLogitsProcessor`.
+// *
+// * @param {[number, number][]} forced_decoder_ids The ids of tokens that should be forced.
+// */
+// constructor(forced_decoder_ids) {
+// super();
+// // TODO: convert to `new Map(forced_decoder_ids)`
+// this.force_token_map = Object.fromEntries(forced_decoder_ids ?? []);
+// }
+
+// /**
+// * Apply the processor to the input logits.
+// *
+// * @param {bigint[][]} input_ids The input ids.
+// * @param {Tensor} logits The logits to process.
+// * @returns {Tensor} The processed logits.
+// */
+// _call(input_ids, logits) {
+// console.log('this.force_token_map', this.force_token_map)
+// console.log('call ForceTokensLogitsProcessor', input_ids, logits)
+// console.log('input_ids.length', input_ids.length)
+// let map = this.force_token_map[input_ids.length];
+// if (map) { // There exists a mapping
+// logits.data.fill(-Infinity)
+// logits.data[map] = 0;
+// }
+// console.log('map', map)
+// // throw Error("Not implemented")
+// return logits;
+// }
+// }
+
+/**
+ * A LogitsProcessor that forces a BOS token at the beginning of the generated sequence.
+ */
+export class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
+ /**
+ * Create a ForcedBOSTokenLogitsProcessor.
+ * @param {number} bos_token_id The ID of the beginning-of-sequence token to be forced.
+ */
+ constructor(bos_token_id) {
+ super();
+ this.bos_token_id = bos_token_id;
+ }
+
+ /**
+ * Apply the BOS token forcing to the logits.
+ * @param {bigint[][]} input_ids The input IDs.
+ * @param {Tensor} logits The logits.
+ * @returns {Object} The logits with BOS token forcing.
+ */
+ _call(input_ids, logits) {
+ for (let i = 0; i < input_ids.length; ++i) {
+ if (input_ids[i].length === 1) {
+ const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
+ batch_logits_data.fill(-Infinity);
+ batch_logits_data[this.bos_token_id] = 0;
+ }
+ }
+ return logits;
+ }
+}
+
+/**
+ * A logits processor that enforces the specified token as the last generated token when `max_length` is reached.
+ */
+export class ForcedEOSTokenLogitsProcessor extends LogitsProcessor {
+ /**
+ * Create a ForcedEOSTokenLogitsProcessor.
+ * @param {number} max_length The maximum length of the sequence to be generated.
+ * @param {number|number[]} eos_token_id The id(s) of the *end-of-sequence* token.
+ */
+ constructor(max_length, eos_token_id) {
+ super();
+ this.max_length = max_length;
+ this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
+ }
+
+ /**
+ * Apply the processor to input_ids and logits.
+ *
+ * @param {bigint[][]} input_ids The input ids.
+ * @param {Tensor} logits The logits tensor.
+ */
+ _call(input_ids, logits) {
+ for (let i = 0; i < input_ids.length; ++i) {
+ if (input_ids[i].length === this.max_length - 1) {
+ const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
+ batch_logits_data.fill(-Infinity);
+ for (const eos_token of this.eos_token_id) {
+ batch_logits_data[eos_token] = 0;
+ }
+ }
+ }
+ return logits;
+ }
+}
+
+/**
+ * A LogitsProcessor that suppresses a list of tokens as soon as the `generate` function starts
+ * generating using `begin_index` tokens. This should ensure that the tokens defined by
+ * `begin_suppress_tokens` are not sampled at the beginning of the generation.
+ */
+export class SuppressTokensAtBeginLogitsProcessor extends LogitsProcessor {
+ /**
+ * Create a SuppressTokensAtBeginLogitsProcessor.
+ * @param {number[]} begin_suppress_tokens The IDs of the tokens to suppress.
+ * @param {number} begin_index The number of tokens to generate before suppressing tokens.
+ */
+ constructor(begin_suppress_tokens, begin_index) {
+ super();
+ this.begin_suppress_tokens = begin_suppress_tokens;
+ this.begin_index = begin_index;
+ }
+
+ /**
+     * Apply the token suppression to the logits.
+     * @param {bigint[][]} input_ids The input IDs.
+     * @param {Tensor} logits The logits.
+     * @returns {Object} The logits with the begin-suppressed tokens masked out.
+ */
+ _call(input_ids, logits) {
+ for (let i = 0; i < input_ids.length; ++i) {
+ if (input_ids[i].length === this.begin_index) {
+ const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
+ for (const token_id of this.begin_suppress_tokens) {
+ batch_logits_data[token_id] = -Infinity;
+ }
+ }
+ }
+ return logits;
+ }
+}
+
+/**
+ * A LogitsProcessor that handles adding timestamps to generated text.
+ */
+export class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
+ /**
+ * Constructs a new WhisperTimeStampLogitsProcessor.
+ * @param {import('../models/whisper/generation_whisper.js').WhisperGenerationConfig} generate_config The config object passed to the `generate()` method of a transformer model.
+ * @param {number[]} init_tokens The initial tokens of the input sequence.
+ */
+ constructor(generate_config, init_tokens) {
+ super();
+ this.eos_token_id =
+ Array.isArray(generate_config.eos_token_id)
+ ? generate_config.eos_token_id[0]
+ : generate_config.eos_token_id;
+
+ this.no_timestamps_token_id = generate_config.no_timestamps_token_id;
+ this.timestamp_begin = this.no_timestamps_token_id + 1;
+
+ this.begin_index = init_tokens.length;
+ if (init_tokens.at(-1) === this.no_timestamps_token_id) {
+ this.begin_index -= 1;
+ }
+ this.max_initial_timestamp_index = generate_config.max_initial_timestamp_index;
+ }
+
+ /**
+ * Modify the logits to handle timestamp tokens.
+ * @param {bigint[][]} input_ids The input sequence of tokens.
+ * @param {Tensor} logits The logits output by the model.
+ * @returns {Tensor} The modified logits.
+ */
+ _call(input_ids, logits) {
+ for (let i = 0; i < input_ids.length; ++i) {
+ const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
+
+ // suppress <|notimestamps|> which is handled by without_timestamps
+ batch_logits_data[this.no_timestamps_token_id] = -Infinity;
+
+ if (input_ids[i].length === this.begin_index - 1) {
+ batch_logits_data.fill(-Infinity);
+ batch_logits_data[this.timestamp_begin] = 0;
+ continue;
+ }
+
+ // timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
+ const seq = input_ids[i].slice(this.begin_index);
+ const last_was_timestamp = seq.length >= 1 && seq[seq.length - 1] >= this.timestamp_begin;
+ const penultimate_was_timestamp = seq.length < 2 || seq[seq.length - 2] >= this.timestamp_begin;
+
+ if (last_was_timestamp) {
+ if (penultimate_was_timestamp) { // has to be non-timestamp
+ batch_logits_data.subarray(this.timestamp_begin).fill(-Infinity);
+ } else { // cannot be normal text tokens
+ batch_logits_data.subarray(0, this.eos_token_id).fill(-Infinity);
+ }
+ }
+
+ // apply the `max_initial_timestamp` option
+ if (input_ids[i].length === this.begin_index && this.max_initial_timestamp_index !== null) {
+ const last_allowed = this.timestamp_begin + this.max_initial_timestamp_index;
+ batch_logits_data.subarray(last_allowed + 1).fill(-Infinity);
+ }
+
+ // if sum of probability over timestamps is above any other token, sample timestamp
+ const logprobs = log_softmax(batch_logits_data);
+ const timestamp_logprob = Math.log(logprobs.subarray(this.timestamp_begin).map(Math.exp).reduce((a, b) => a + b));
+ const max_text_token_logprob = max(logprobs.subarray(0, this.timestamp_begin))[0];
+
+ if (timestamp_logprob > max_text_token_logprob) {
+ batch_logits_data.subarray(0, this.timestamp_begin).fill(-Infinity);
+ }
+ }
+
+ return logits;
+ }
+}
+
+/**
+ * A logits processor that disallows ngrams of a certain size to be repeated.
+ */
+export class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
+ /**
+ * Create a NoRepeatNGramLogitsProcessor.
+ * @param {number} no_repeat_ngram_size The no-repeat-ngram size. All ngrams of this size can only occur once.
+ */
+ constructor(no_repeat_ngram_size) {
+ super();
+ this.no_repeat_ngram_size = no_repeat_ngram_size;
+ }
+
+ /**
+ * Generate n-grams from a sequence of token ids.
+ * @param {bigint[]} prevInputIds List of previous input ids
+     * @returns {Map<string, number[]>} Map of generated n-grams
+ */
+ getNgrams(prevInputIds) {
+ const curLen = prevInputIds.length;
+
+ /**@type {number[][]} */
+ const ngrams = [];
+ for (let j = 0; j < curLen + 1 - this.no_repeat_ngram_size; ++j) {
+ const ngram = [];
+ for (let k = 0; k < this.no_repeat_ngram_size; ++k) {
+ ngram.push(prevInputIds[j + k]);
+ }
+ ngrams.push(ngram.map(Number));
+ }
+
+        /** @type {Map<string, number[]>} */
+ const generatedNgram = new Map();
+ for (const ngram of ngrams) {
+ const prevNgram = ngram.slice(0, ngram.length - 1);
+ const prevNgramKey = JSON.stringify(prevNgram);
+ const prevNgramValue = generatedNgram.get(prevNgramKey) ?? [];
+ prevNgramValue.push(ngram[ngram.length - 1]);
+ generatedNgram.set(prevNgramKey, prevNgramValue);
+ }
+ return generatedNgram;
+ }
+
+ /**
+     * Retrieve the tokens banned for the next position, given the current n-gram prefix.
+     * @param {Map<string, number[]>} bannedNgrams Map of banned n-grams
+     * @param {bigint[]} prevInputIds List of previous input ids
+     * @returns {number[]} List of banned tokens
+ */
+ getGeneratedNgrams(bannedNgrams, prevInputIds) {
+ const ngramIdx = prevInputIds.slice(prevInputIds.length + 1 - this.no_repeat_ngram_size, prevInputIds.length);
+ const banned = bannedNgrams.get(JSON.stringify(ngramIdx.map(Number))) ?? [];
+ return banned;
+ }
+
+ /**
+ * Calculate banned n-gram tokens
+ * @param {bigint[]} prevInputIds List of previous input ids
+     * @returns {number[]} List of banned tokens
+ */
+    calcBannedNgramTokens(prevInputIds) {
+        if (prevInputIds.length + 1 < this.no_repeat_ngram_size) {
+            // Return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
+            return [];
+        }
+        const generatedNgrams = this.getNgrams(prevInputIds);
+        return this.getGeneratedNgrams(generatedNgrams, prevInputIds);
+    }
+
+ /**
+ * Apply the no-repeat-ngram processor to the logits.
+ * @param {bigint[][]} input_ids The input IDs.
+ * @param {Tensor} logits The logits.
+ * @returns {Object} The logits with no-repeat-ngram processing.
+ */
+ _call(input_ids, logits) {
+ for (let i = 0; i < input_ids.length; ++i) {
+ const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
+ const bannedTokens = this.calcBannedNgramTokens(input_ids[i]);
+ for (const token of bannedTokens) {
+ batch_logits_data[token] = -Infinity;
+ }
+ }
+ return logits;
+ }
+}
+
+/**
+ * A logits processor that penalises repeated output tokens.
+ */
+export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
+ /**
+ * Create a RepetitionPenaltyLogitsProcessor.
+ * @param {number} penalty The penalty to apply for repeated tokens.
+ */
+ constructor(penalty) {
+ super();
+ this.penalty = penalty;
+ }
+
+ /**
+ * Apply the repetition penalty to the logits.
+ * @param {bigint[][]} input_ids The input IDs.
+ * @param {Tensor} logits The logits.
+ * @returns {Object} The logits with repetition penalty processing.
+ */
+ _call(input_ids, logits) {
+ // Modify the logits corresponding to each element in `input_ids`.
+ // As a consequence, the logits corresponding to tokens that appear
+ // many times in the output will be penalised more.
+
+ for (let i = 0; i < input_ids.length; ++i) {
+ const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
+ for (const input_id of input_ids[i]) {
+ const token = Number(input_id);
+ if (batch_logits_data[token] < 0) {
+ batch_logits_data[token] *= this.penalty;
+ } else {
+ batch_logits_data[token] /= this.penalty;
+ }
+ }
+ }
+
+ return logits
+ }
+}
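+
+// Worked example (numbers are illustrative): with penalty = 1.25, a previously generated
+// token with logit 2.0 is rescaled to 2.0 / 1.25 = 1.6, while one with logit -2.0 is
+// rescaled to -2.0 * 1.25 = -2.5, making repeats less likely in both cases.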
+
+/**
+ * A logits processor that enforces a minimum number of tokens.
+ */
+export class MinLengthLogitsProcessor extends LogitsProcessor {
+ /**
+ * Create a MinLengthLogitsProcessor.
+ * @param {number} min_length The minimum length below which the score of `eos_token_id` is set to negative infinity.
+ * @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token.
+ */
+ constructor(min_length, eos_token_id) {
+ super();
+ this.min_length = min_length;
+ this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
+ }
+
+ /**
+ * Apply logit processor.
+ * @param {bigint[][]} input_ids The input IDs.
+ * @param {Tensor} logits The logits.
+ * @returns {Object} The processed logits.
+ */
+ _call(input_ids, logits) {
+ for (let i = 0; i < input_ids.length; ++i) {
+ if (input_ids[i].length < this.min_length) {
+ const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
+
+ for (const eos_token of this.eos_token_id) {
+ batch_logits_data[eos_token] = -Infinity;
+ }
+ }
+ }
+
+ return logits
+ }
+}
+
+/**
+ * A logits processor that enforces a minimum number of new tokens.
+ */
+export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor {
+ /**
+ * Create a MinNewTokensLengthLogitsProcessor.
+ * @param {number} prompt_length_to_skip The input tokens length.
+ * @param {number} min_new_tokens The minimum *new* tokens length below which the score of `eos_token_id` is set to negative infinity.
+ * @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token.
+ */
+ constructor(prompt_length_to_skip, min_new_tokens, eos_token_id) {
+ super();
+ this.prompt_length_to_skip = prompt_length_to_skip;
+ this.min_new_tokens = min_new_tokens;
+ this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
+ }
+
+ /**
+ * Apply logit processor.
+ * @param {bigint[][]} input_ids The input IDs.
+ * @param {Tensor} logits The logits.
+ * @returns {Object} The processed logits.
+ */
+ _call(input_ids, logits) {
+ for (let i = 0; i < input_ids.length; ++i) {
+ const new_tokens_length = input_ids[i].length - this.prompt_length_to_skip;
+ if (new_tokens_length < this.min_new_tokens) {
+ const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
+
+ for (const eos_token of this.eos_token_id) {
+ batch_logits_data[eos_token] = -Infinity;
+ }
+ }
+ }
+ return logits
+ }
+}
+
+export class NoBadWordsLogitsProcessor extends LogitsProcessor {
+ /**
+ * Create a `NoBadWordsLogitsProcessor`.
+ * @param {number[][]} bad_words_ids List of list of token ids that are not allowed to be generated.
+ * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+ */
+ constructor(bad_words_ids, eos_token_id) {
+ super();
+ this.bad_words_ids = bad_words_ids;
+ this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
+ }
+
+ /**
+ * Apply logit processor.
+ * @param {bigint[][]} input_ids The input IDs.
+ * @param {Tensor} logits The logits.
+ * @returns {Object} The processed logits.
+ */
+ _call(input_ids, logits) {
+ for (let i = 0; i < input_ids.length; ++i) {
+ const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
+ const ids = input_ids[i];
+ for (const bad_word_ids of this.bad_words_ids) {
+ // Whether to modify the logits of the last token in the bad word id sequence
+ let mark = true;
+
+ // For each bad word in the list, if the current sequence of input ids ends with this sequence (excluding the last),
+ // then we set the logits of the last bad word id to -Infinity.
+ for (let j = 1; j <= bad_word_ids.length - 1 && bad_word_ids.length < ids.length; ++j) {
+
+ // NOTE: We use != instead of !== to compare bigint and number
+ // @ts-ignore
+ if (bad_word_ids.at(-j - 1) != ids.at(-j)) {
+ // We have found a mismatch
+ mark = false;
+ break;
+ }
+ }
+ if (mark) {
+ batch_logits_data[bad_word_ids.at(-1)] = -Infinity;
+ }
+ }
+ }
+ return logits
+ }
+}
+
+/**
+ * [`LogitsProcessor`] for classifier free guidance (CFG). The scores are split over the batch dimension,
+ * where the first half correspond to the conditional logits (predicted from the input prompt) and the second half
+ * correspond to the unconditional logits (predicted from an empty or 'null' prompt). The processor computes a
+ * weighted average across the conditional and unconditional logits, parameterised by the `guidance_scale`.
+ *
+ * See [the paper](https://arxiv.org/abs/2306.05284) for more information.
+ */
+export class ClassifierFreeGuidanceLogitsProcessor extends LogitsProcessor {
+
+ /**
+ * Create a `ClassifierFreeGuidanceLogitsProcessor`.
+ * @param {number} guidance_scale The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
+ * Higher guidance scale encourages the model to generate samples that are more closely linked to the input
+ * prompt, usually at the expense of poorer quality.
+ */
+ constructor(guidance_scale) {
+ super();
+ if (guidance_scale <= 1) {
+ throw new Error(
+ `Require guidance scale >1 to use the classifier free guidance processor, got guidance scale ${guidance_scale}.`
+ )
+ }
+ this.guidance_scale = guidance_scale;
+ }
+
+ /**
+ * Apply logit processor.
+ * @param {bigint[][]} input_ids The input IDs.
+ * @param {Tensor} logits The logits.
+ * @returns {Object} The processed logits.
+ */
+ _call(input_ids, logits) {
+ if (logits.dims[0] !== 2 * input_ids.length) {
+ throw new Error(
+ `Logits should have twice the batch size of the input ids, the first half of batches corresponding to ` +
+ `the conditional inputs, and the second half of batches corresponding to the unconditional inputs. Got ` +
+ `batch size ${logits.dims[0]} for the logits and ${input_ids.length} for the input ids.`
+ )
+ }
+
+ const unguided_bsz = input_ids.length;
+ const cond_logits = logits.slice([0, unguided_bsz], null);
+ const uncond_logits = logits.slice([unguided_bsz, logits.dims[0]], null);
+
+ // Merge into uncond_logits (to save memory). This is equivalent to the following:
+ // scores = uncond_logits + (cond_logits - uncond_logits) * guidance_scale
+ for (let i = 0; i < uncond_logits.data.length; ++i) {
+ uncond_logits.data[i] += (cond_logits.data[i] - uncond_logits.data[i]) * this.guidance_scale;
+ }
+
+ return uncond_logits;
+ }
+}
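+
+// The merge above is the usual CFG combination. For a single logit, with
+// guidance_scale = 3, cond = 2.0 and uncond = 1.0 (illustrative numbers):
+//
+//   uncond + (cond - uncond) * guidance_scale = 1.0 + (2.0 - 1.0) * 3 = 4.0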
+
+/**
+ * [`LogitsWarper`] for temperature (exponential scaling output probability distribution), which effectively means
+ * that it can control the randomness of the predicted tokens. Often used together with [`TopPLogitsWarper`] and [`TopKLogitsWarper`].
+ */
+export class TemperatureLogitsWarper extends LogitsWarper {
+ /**
+ * Create a `TemperatureLogitsWarper`.
+ * @param {number} temperature Strictly positive float value used to modulate the logits distribution.
+ * A value smaller than `1` decreases randomness (and vice versa), with `0` being equivalent to shifting
+ * all probability mass to the most likely token.
+ */
+ constructor(temperature) {
+ super();
+
+ if (typeof temperature !== 'number' || temperature <= 0) {
+ let errorMessage =
+ `\`temperature\` (=${temperature}) must be a strictly positive float, otherwise your next token scores will be invalid.`;
+
+            if (temperature === 0) {
+                errorMessage += " If you're looking for greedy decoding strategies, set `do_sample=false`."
+            }
+            throw new Error(errorMessage);
+        }
+ this.temperature = temperature;
+ }
+
+ /**
+ * Apply logit warper.
+ * @param {bigint[][]} input_ids The input IDs.
+ * @param {Tensor} logits The logits.
+ * @returns {Object} The processed logits.
+ */
+ _call(input_ids, logits) {
+ const batch_logits_data = /** @type {Float32Array} */(logits.data);
+ for (let i = 0; i < batch_logits_data.length; ++i) {
+ batch_logits_data[i] /= this.temperature;
+ }
+ return logits;
+ }
+}
+
+/**
+ * [`LogitsWarper`] that performs top-p, i.e. restricting to the smallest set of most probable tokens whose cumulative probability is at least `top_p`.
+ * Often used together with [`TemperatureLogitsWarper`] and [`TopKLogitsWarper`].
+ */
+export class TopPLogitsWarper extends LogitsWarper {
+ /**
+ * Create a `TopPLogitsWarper`.
+ * @param {number} top_p If set to < 1, only the smallest set of most probable tokens with
+ * probabilities that add up to `top_p` or higher are kept for generation.
+ * @param {Object} options Additional options for the top-p sampling.
+ * @param {number} [options.filter_value=-Infinity] All filtered values will be set to this float value.
+ * @param {number} [options.min_tokens_to_keep=1] Minimum number of tokens that cannot be filtered.
+ */
+ constructor(top_p, {
+ filter_value = -Infinity,
+ min_tokens_to_keep = 1,
+ } = {}) {
+ super();
+ if (top_p < 0 || top_p > 1.0) {
+ throw new Error(`\`top_p\` must be a float > 0 and < 1, but is ${top_p}`)
+ }
+ if (!Number.isInteger(min_tokens_to_keep) || min_tokens_to_keep < 1) {
+ throw new Error(`\`min_tokens_to_keep\` must be a positive integer, but is ${min_tokens_to_keep}`)
+ }
+
+ this.top_p = top_p
+ this.filter_value = filter_value
+ this.min_tokens_to_keep = min_tokens_to_keep
+ }
+}
+
+/**
+ * [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
+ * Often used together with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`].
+ */
+export class TopKLogitsWarper extends LogitsWarper {
+ /**
+ * Create a `TopKLogitsWarper`.
+ * @param {number} top_k If set to > 0, only the top `top_k` tokens are kept for generation.
+ * @param {Object} options Additional options for the top-k sampling.
+ * @param {number} [options.filter_value=-Infinity] All filtered values will be set to this float value.
+ * @param {number} [options.min_tokens_to_keep=1] Minimum number of tokens that cannot be filtered.
+ */
+ constructor(top_k, {
+ filter_value = -Infinity,
+ min_tokens_to_keep = 1,
+ } = {}) {
+ super();
+ if (!Number.isInteger(top_k) || top_k < 0) {
+ throw new Error(`\`top_k\` must be a non-negative integer, but is ${top_k}`)
+ }
+
+ this.top_k = Math.max(top_k, min_tokens_to_keep)
+ this.filter_value = filter_value
+ }
+}
\ No newline at end of file
diff --git a/src/generation/logits_sampler.js b/src/generation/logits_sampler.js
new file mode 100644
index 000000000..46b74e081
--- /dev/null
+++ b/src/generation/logits_sampler.js
@@ -0,0 +1,204 @@
+
+/**
+ * @module generation/logits_sampler
+ */
+
+import { Callable } from "../utils/generic.js";
+import { Tensor, topk } from "../utils/tensor.js";
+
+import {
+ max,
+ softmax,
+} from '../utils/maths.js';
+import { GenerationConfig } from '../generation/configuration_utils.js';
+
+/**
+ * Sampler is a base class for all sampling methods used for text generation.
+ */
+export class LogitsSampler extends Callable {
+ /**
+ * Creates a new Sampler object with the specified generation config.
+ * @param {GenerationConfig} generation_config The generation config.
+ */
+ constructor(generation_config) {
+ super();
+ this.generation_config = generation_config;
+ }
+
+ /**
+ * Executes the sampler, using the specified logits.
+ * @param {Tensor} logits
+ * @returns {Promise<[bigint, number][]>}
+ */
+ async _call(logits) {
+ // Sample from logits, of dims [batch, sequence_length, vocab_size].
+ // If index is specified, sample from [batch, index, vocab_size].
+ return this.sample(logits);
+ }
+
+ /**
+ * Abstract method for sampling the logits.
+ * @param {Tensor} logits
+ * @throws {Error} If not implemented in subclass.
+ * @returns {Promise<[bigint, number][]>}
+ */
+ async sample(logits) {
+ throw Error("sample should be implemented in subclasses.")
+ }
+
+ /**
+ * Returns the specified logits as an array, with temperature applied.
+ * @param {Tensor} logits
+ * @param {number} index
+ * @returns {Float32Array}
+ */
+ getLogits(logits, index) {
+ let vocabSize = logits.dims.at(-1);
+
+ let logs = /** @type {Float32Array} */(logits.data);
+
+ if (index === -1) {
+ logs = logs.slice(-vocabSize);
+ } else {
+ let startIndex = index * vocabSize;
+ logs = logs.slice(startIndex, startIndex + vocabSize);
+ }
+ return logs;
+ }
+
+ /**
+ * Selects an item randomly based on the specified probabilities.
+ * @param {import("../transformers.js").DataArray} probabilities An array of probabilities to use for selection.
+ * @returns {number} The index of the selected item.
+ */
+ randomSelect(probabilities) {
+ // Return index of chosen item
+ let sumProbabilities = 0;
+ for (let i = 0; i < probabilities.length; ++i) {
+ sumProbabilities += probabilities[i];
+ }
+
+ let r = Math.random() * sumProbabilities;
+ for (let i = 0; i < probabilities.length; ++i) {
+ r -= probabilities[i];
+ if (r <= 0) {
+ return i;
+ }
+ }
+ return 0; // return first (most probable) as a fallback
+ }
+
+ /**
+ * Returns a Sampler object based on the specified options.
+ * @param {GenerationConfig} generation_config An object containing options for the sampler.
+ * @returns {LogitsSampler} A Sampler object.
+ */
+ static getSampler(generation_config) {
+ // - *greedy decoding*: `num_beams=1` and `do_sample=False`
+ // - *contrastive search*: `penalty_alpha>0` and `top_k>1`
+ // - *multinomial sampling*: `num_beams=1` and `do_sample=True`
+ // - *beam-search decoding*: `num_beams>1` and `do_sample=False`
+ // - *beam-search multinomial sampling*: `num_beams>1` and `do_sample=True`
+ // - *diverse beam-search decoding*: `num_beams>1` and `num_beam_groups>1`
+ // - *constrained beam-search decoding*: `constraints!=None` or `force_words_ids!=None`
+
+ // NOTE: beam search is implemented directly into the generation function
+ if (generation_config.do_sample) {
+ return new MultinomialSampler(generation_config);
+
+ } else if (generation_config.num_beams > 1) {
+ return new BeamSearchSampler(generation_config);
+
+ } else {
+ if (generation_config.num_return_sequences > 1) {
+ throw Error(`num_return_sequences has to be 1 when doing greedy search, but is ${generation_config.num_return_sequences}.`)
+ }
+ return new GreedySampler(generation_config);
+ }
+ }
+}
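+
+// Usage sketch (illustrative; assumes a populated `GenerationConfig` and a logits
+// `Tensor` holding the scores for the last position of a single sequence):
+//
+//   const sampler = LogitsSampler.getSampler(generation_config);
+//   const sampled = await sampler(logits); // array of [tokenId, logProb] pairs:
+//                                          // one entry for greedy decoding, or
+//                                          // `num_beams` entries when sampling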
+
+/**
+ * Class representing a Greedy Sampler.
+ */
+class GreedySampler extends LogitsSampler {
+ /**
+ * Sample the maximum probability of a given logits tensor.
+ * @param {Tensor} logits
+ * @returns {Promise<[bigint, number][]>} An array with a single tuple, containing the index of the maximum value and a meaningless score (since this is a greedy search).
+ */
+ async sample(logits) {
+ // NOTE: no need to do log_softmax here since we only take the maximum
+ const argmax = max(logits.data)[1];
+
+ // Note: score is meaningless in this context, since we are performing
+ // greedy search (p = 1 => log(p) = 0)
+ return [
+ [BigInt(argmax), 0]
+ ];
+ }
+}
+
+/**
+ * Class representing a MultinomialSampler.
+ */
+class MultinomialSampler extends LogitsSampler {
+
+ /**
+ * Sample from the logits.
+ * @param {Tensor} logits
+ * @returns {Promise<[bigint, number][]>}
+ */
+ async sample(logits) {
+ let k = logits.dims.at(-1); // defaults to vocab size
+ if (this.generation_config.top_k > 0) {
+ k = Math.min(this.generation_config.top_k, k);
+ }
+
+ // Get top k tokens
+ const [v, i] = await topk(logits, k);
+
+ // Compute softmax over logits
+ const probabilities = softmax(/** @type {Float32Array} */(v.data));
+
+ return Array.from({ length: this.generation_config.num_beams }, () => {
+ const sampledIndex = this.randomSelect(probabilities);
+ return [
+ i.data[sampledIndex], // token id
+ Math.log(probabilities[sampledIndex]), // score
+ ];
+ });
+ }
+}
+
+
+/**
+ * Class representing a BeamSearchSampler.
+ */
+class BeamSearchSampler extends LogitsSampler {
+
+ /**
+ * Sample from the logits.
+ * @param {Tensor} logits
+ * @returns {Promise<[bigint, number][]>}
+ */
+ async sample(logits) {
+ let k = logits.dims.at(-1); // defaults to vocab size
+ if (this.generation_config.top_k > 0) {
+ k = Math.min(this.generation_config.top_k, k);
+ }
+
+ // Get top k tokens
+ const [v, i] = await topk(logits, k);
+
+ // Compute softmax over logits
+ const probabilities = softmax(/** @type {Float32Array} */(v.data));
+
+ return Array.from({ length: this.generation_config.num_beams }, (_, x) => {
+ return [
+ i.data[x], // token id
+ Math.log(probabilities[x]), // score
+ ];
+ });
+ }
+}
diff --git a/src/generation/parameters.js b/src/generation/parameters.js
new file mode 100644
index 000000000..1e2f2def3
--- /dev/null
+++ b/src/generation/parameters.js
@@ -0,0 +1,35 @@
+
+/**
+ * @module generation/parameters
+ */
+
+/**
+ * @typedef {Object} GenerationFunctionParameters
+ * @property {import('../utils/tensor.js').Tensor} [inputs=null] (`Tensor` of varying shape depending on the modality, *optional*):
+ * The sequence used as a prompt for the generation or as model inputs to the encoder. If `null` the
+ * method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
+ * should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
+ * `input_ids`, `input_values`, `input_features`, or `pixel_values`.
+ * @property {import('./configuration_utils.js').GenerationConfig} [generation_config=null] (`GenerationConfig`, *optional*):
+ * The generation configuration to be used as base parametrization for the generation call.
+ * `**kwargs` passed to generate matching the attributes of `generation_config` will override them.
+ * If `generation_config` is not provided, the default will be used, which has the following loading
+ * priority:
+ * - (1) from the `generation_config.json` model file, if it exists;
+ * - (2) from the model configuration. Please note that unspecified parameters will inherit [`GenerationConfig`]'s
+ * default values, whose documentation should be checked to parameterize generation.
+ * @property {import('./logits_process.js').LogitsProcessorList} [logits_processor=null] (`LogitsProcessorList`, *optional*):
+ * Custom logits processors that complement the default logits processors built from arguments and
+ * generation config. If a logit processor is passed that is already created with the arguments or a
+ * generation config an error is thrown. This feature is intended for advanced users.
+ * @property {import('./stopping_criteria.js').StoppingCriteriaList} [stopping_criteria=null] (`StoppingCriteriaList`, *optional*):
+ * Custom stopping criteria that complements the default stopping criteria built from arguments and a
+ * generation config. If a stopping criteria is passed that is already created with the arguments or a
+ * generation config an error is thrown. This feature is intended for advanced users.
+ * @property {import('./streamers.js').BaseStreamer} [streamer=null] (`BaseStreamer`, *optional*):
+ * Streamer object that will be used to stream the generated sequences. Generated tokens are passed
+ * through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+ * @property {number[]} [decoder_input_ids=null] (`number[]`, *optional*):
+ * If the model is an encoder-decoder model, this argument is used to pass the `decoder_input_ids`.
+ * @property {any} [kwargs] (`Dict[str, any]`, *optional*):
+ * Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model.
+ */
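+
+// Minimal usage sketch of the parameters documented above (illustrative only;
+// field names follow the typedef, and `model`/`input_ids`/`generation_config`/
+// `stopping_criteria`/`streamer` are assumed to exist):
+//
+//   const output_ids = await model.generate({
+//       inputs: input_ids,     // prompt token ids as a Tensor
+//       generation_config,     // optional GenerationConfig override
+//       stopping_criteria,     // optional StoppingCriteriaList
+//       streamer,              // optional BaseStreamer
+//   });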
diff --git a/src/generation/stopping_criteria.js b/src/generation/stopping_criteria.js
new file mode 100644
index 000000000..08434f2b4
--- /dev/null
+++ b/src/generation/stopping_criteria.js
@@ -0,0 +1,156 @@
+
+/**
+ * @module generation/stopping_criteria
+ */
+
+import { Callable } from "../utils/generic.js";
+
+// NOTE:
+// Stopping Criteria returns a list of `batch_size` booleans, indicating whether each sequence in the batch should be stopped.
+
+/**
+ * Abstract base class for all stopping criteria that can be applied during generation.
+ */
+export class StoppingCriteria extends Callable {
+ /**
+ *
+ * @param {number[][]} input_ids (`number[][]` of shape `(batch_size, sequence_length)`):
+ * Indices of input sequence tokens in the vocabulary.
+ * @param {number[][]} scores (`number[][]` of shape `(batch_size, config.vocab_size)`):
+ * Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
+ * or scores for each vocabulary token after SoftMax.
+ * @returns {boolean[]} A list of booleans indicating whether each sequence should be stopped.
+ */
+ _call(input_ids, scores) {
+ throw Error("StoppingCriteria needs to be subclassed");
+ }
+}
+/**
+ * A list of stopping criteria. A sequence is marked as done as soon as any criterion in the list reports it as done.
+ */
+export class StoppingCriteriaList extends Callable {
+ /**
+ * Constructs a new instance of `StoppingCriteriaList`.
+ */
+ constructor() {
+ super();
+ this.criteria = [];
+ }
+
+ /**
+ * Adds a new stopping criterion to the list.
+ *
+ * @param {StoppingCriteria} item The stopping criterion to add.
+ */
+ push(item) {
+ this.criteria.push(item);
+ }
+
+ /**
+ * Adds multiple stopping criteria to the list.
+ *
+ * @param {StoppingCriteria|StoppingCriteriaList|StoppingCriteria[]} items The stopping criteria to add.
+ */
+ extend(items) {
+ if (items instanceof StoppingCriteriaList) {
+ items = items.criteria;
+ } else if (items instanceof StoppingCriteria) {
+ items = [items];
+ }
+ this.criteria.push(...items);
+ }
+
+ _call(input_ids, scores) {
+ const is_done = new Array(input_ids.length).fill(false);
+ for (const criterion of this.criteria) {
+ const criterion_done = criterion(input_ids, scores);
+ for (let i = 0; i < is_done.length; ++i) {
+ is_done[i] ||= criterion_done[i];
+ }
+ }
+ return is_done;
+ }
+
+ [Symbol.iterator]() {
+ return this.criteria.values();
+ }
+}
+
+/**
+ * This class can be used to stop generation whenever the full generated number of tokens exceeds `max_length`.
+ * Keep in mind that, for decoder-only transformers, this will include the initial prompt tokens.
+ */
+export class MaxLengthCriteria extends StoppingCriteria {
+
+ /**
+ *
+ * @param {number} max_length The maximum length that the output sequence can have in number of tokens.
+ * @param {number} [max_position_embeddings=null] The maximum model length, as defined by the model's `config.max_position_embeddings` attribute.
+ */
+ constructor(max_length, max_position_embeddings = null) {
+ super();
+ this.max_length = max_length;
+ this.max_position_embeddings = max_position_embeddings;
+ }
+
+ _call(input_ids) {
+ return input_ids.map(ids => ids.length >= this.max_length);
+ }
+}
+
+// TODO: add MaxTimeCriteria
+
+/**
+ * This class can be used to stop generation whenever the "end-of-sequence" token is generated.
+ * By default, it uses the `model.generation_config.eos_token_id`.
+ */
+export class EosTokenCriteria extends StoppingCriteria {
+
+ /**
+ *
+ * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token.
+ * Optionally, use a list to set multiple *end-of-sequence* tokens.
+ */
+ constructor(eos_token_id) {
+ super();
+ if (!Array.isArray(eos_token_id)) {
+ eos_token_id = [eos_token_id];
+ }
+ this.eos_token_id = eos_token_id;
+ }
+
+ /**
+ *
+ * @param {number[][]} input_ids
+ * @param {number[][]} scores
+ * @returns {boolean[]}
+ */
+ _call(input_ids, scores) {
+ return input_ids.map(ids => {
+ const last = ids.at(-1);
+ // NOTE: We use == instead of === to allow for number/bigint comparison
+ return this.eos_token_id.some(eos_id => last == eos_id);
+ });
+ }
+}
+
+/**
+ * This class can be used to stop generation whenever the user interrupts the process.
+ */
+export class InterruptableStoppingCriteria extends StoppingCriteria {
+ constructor() {
+ super();
+ this.interrupted = false;
+ }
+
+ interrupt() {
+ this.interrupted = true;
+ }
+
+ reset() {
+ this.interrupted = false;
+ }
+
+ _call(input_ids, scores) {
+ return new Array(input_ids.length).fill(this.interrupted);
+ }
+}
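+
+// Composition sketch (illustrative): criteria are OR-ed together per sequence by
+// `StoppingCriteriaList`, so generation for a sequence stops as soon as any one fires.
+//
+//   const stopping_criteria = new StoppingCriteriaList();
+//   stopping_criteria.push(new MaxLengthCriteria(128));
+//   stopping_criteria.push(new EosTokenCriteria(eos_token_id)); // id(s) from the generation config
+//   const interrupter = new InterruptableStoppingCriteria();
+//   stopping_criteria.push(interrupter); // call interrupter.interrupt() to stop early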
diff --git a/src/generation/streamers.js b/src/generation/streamers.js
new file mode 100644
index 000000000..64afc71c7
--- /dev/null
+++ b/src/generation/streamers.js
@@ -0,0 +1,212 @@
+
+/**
+ * @module generation/streamers
+ */
+
+import { mergeArrays } from '../utils/core.js';
+import { is_chinese_char } from '../tokenizers.js';
+import { apis } from '../env.js';
+
+export class BaseStreamer {
+ /**
+ * Function that is called by `.generate()` to push new tokens
+ * @param {bigint[][]} value
+ */
+ put(value) {
+ throw Error('Not implemented');
+ }
+
+ /**
+ * Function that is called by `.generate()` to signal the end of generation
+ */
+ end() {
+ throw Error('Not implemented');
+ }
+}
+
+const stdout_write = apis.IS_PROCESS_AVAILABLE
+ ? x => process.stdout.write(x)
+ : x => console.log(x);
+
+/**
+ * Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.
+ */
+export class TextStreamer extends BaseStreamer {
+ /**
+ * @param {import('../tokenizers.js').PreTrainedTokenizer} tokenizer The tokenizer used to decode the generated tokens.
+ * @param {Object} options
+ * @param {boolean} [options.skip_prompt=false] Whether to skip the prompt tokens (i.e., the first call to `put`)
+ * @param {function(string): void} [options.callback_function=null] Function to call when a piece of text is ready to display (defaults to printing to stdout)
+ * @param {function(bigint[]): void} [options.token_callback_function=null] Function to call with the raw token ids of each generation step
+ * @param {Object} [options.decode_kwargs={}] Additional keyword arguments to pass to the tokenizer's decode method
+ */
+ constructor(tokenizer, {
+ skip_prompt = false,
+ callback_function = null,
+ token_callback_function = null,
+ decode_kwargs = {},
+ ...kwargs
+ } = {}) {
+ super();
+ this.tokenizer = tokenizer;
+ this.skip_prompt = skip_prompt;
+ this.callback_function = callback_function ?? stdout_write;
+ this.token_callback_function = token_callback_function;
+ this.decode_kwargs = { ...decode_kwargs, ...kwargs };
+
+ // variables used in the streaming process
+ this.token_cache = [];
+ this.print_len = 0;
+ this.next_tokens_are_prompt = true;
+ }
+
+ /**
+ * Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
+ * @param {bigint[][]} value
+ */
+ put(value) {
+ if (value.length > 1) {
+ throw Error('TextStreamer only supports batch size of 1');
+ }
+
+ if (this.skip_prompt && this.next_tokens_are_prompt) {
+ this.next_tokens_are_prompt = false;
+ return;
+ }
+
+ const tokens = value[0];
+ this.token_callback_function?.(tokens)
+
+ // Add the new tokens to the cache and decode the entire thing.
+ this.token_cache = mergeArrays(this.token_cache, tokens);
+ const text = this.tokenizer.decode(this.token_cache, this.decode_kwargs);
+
+ let printable_text;
+ if (text.endsWith('\n')) {
+ // After the symbol for a new line, we flush the cache.
+ printable_text = text.slice(this.print_len);
+ this.token_cache = [];
+ this.print_len = 0;
+ } else if (text.length > 0 && is_chinese_char(text.charCodeAt(text.length - 1))) {
+ // If the last token is a CJK character, we print the characters.
+ printable_text = text.slice(this.print_len);
+ this.print_len += printable_text.length;
+ } else {
+ // Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
+ // which may change with the subsequent token -- there are probably smarter ways to do this!)
+ printable_text = text.slice(this.print_len, text.lastIndexOf(' ') + 1);
+ this.print_len += printable_text.length;
+ }
+
+ this.on_finalized_text(printable_text, false);
+ }
+
+ /**
+ * Flushes any remaining cache and prints a newline to stdout.
+ */
+ end() {
+ let printable_text;
+ if (this.token_cache.length > 0) {
+ const text = this.tokenizer.decode(this.token_cache, this.decode_kwargs);
+ printable_text = text.slice(this.print_len);
+ this.token_cache = [];
+ this.print_len = 0;
+ } else {
+ printable_text = '';
+ }
+ this.next_tokens_are_prompt = true;
+ this.on_finalized_text(printable_text, true);
+ }
+
+ /**
+ * Prints the new text to stdout. If the stream is ending, also prints a newline.
+ * @param {string} text
+ * @param {boolean} stream_end
+ */
+ on_finalized_text(text, stream_end) {
+ if (text.length > 0) {
+ this.callback_function?.(text);
+ }
+ if (stream_end && this.callback_function === stdout_write && apis.IS_PROCESS_AVAILABLE) {
+ this.callback_function?.('\n');
+ }
+ }
+}
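+
+// Usage sketch (illustrative; `tokenizer` is assumed to be a PreTrainedTokenizer):
+//
+//   const streamer = new TextStreamer(tokenizer, {
+//       skip_prompt: true,                               // don't echo the prompt tokens
+//       callback_function: (text) => console.log(text),  // override the default stdout printing
+//   });
+//   // Pass the instance as the `streamer` generation parameter; decoded text is
+//   // emitted word-by-word as tokens arrive via `put()`.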
+
+/**
+ * Utility class to handle streaming of tokens generated by Whisper speech-to-text models.
+ * Callback functions are invoked when each of the following events occur:
+ * - A new chunk starts (on_chunk_start)
+ * - A new token is generated (callback_function)
+ * - A chunk ends (on_chunk_end)
+ * - The stream is finalized (on_finalize)
+ */
+export class WhisperTextStreamer extends TextStreamer {
+ /**
+ * @param {import('../tokenizers.js').WhisperTokenizer} tokenizer
+ * @param {Object} options
+ * @param {boolean} [options.skip_prompt=false] Whether to skip the prompt tokens
+ * @param {function(string): void} [options.callback_function=null] Function to call when a piece of text is ready to display
+ * @param {function(bigint[]): void} [options.token_callback_function=null] Function to call when a new token is generated
+ * @param {function(number): void} [options.on_chunk_start=null] Function to call when a new chunk starts
+ * @param {function(number): void} [options.on_chunk_end=null] Function to call when a chunk ends
+ * @param {function(): void} [options.on_finalize=null] Function to call when the stream is finalized
+ * @param {number} [options.time_precision=0.02] Precision of the timestamps
+ * @param {boolean} [options.skip_special_tokens=true] Whether to skip special tokens when decoding
+ * @param {Object} [options.decode_kwargs={}] Additional keyword arguments to pass to the tokenizer's decode method
+ */
+ constructor(tokenizer, {
+ skip_prompt = false,
+ callback_function = null,
+ token_callback_function = null,
+ on_chunk_start = null,
+ on_chunk_end = null,
+ on_finalize = null,
+ time_precision = 0.02,
+ skip_special_tokens = true,
+ decode_kwargs = {},
+ } = {}) {
+ super(tokenizer, {
+ skip_prompt,
+ callback_function,
+ token_callback_function,
+ decode_kwargs: { skip_special_tokens, ...decode_kwargs },
+ });
+ this.timestamp_begin = tokenizer.timestamp_begin;
+
+ this.on_chunk_start = on_chunk_start;
+ this.on_chunk_end = on_chunk_end;
+ this.on_finalize = on_finalize;
+
+ this.time_precision = time_precision;
+
+ this.waiting_for_timestamp = false;
+ }
+
+ /**
+ * @param {bigint[][]} value
+ */
+ put(value) {
+ if (value.length > 1) {
+ throw Error('WhisperTextStreamer only supports batch size of 1');
+ }
+ const tokens = value[0];
+
+ // Check if the token is a timestamp
+ if (tokens.length === 1) {
+ const offset = Number(tokens[0]) - this.timestamp_begin;
+ if (offset >= 0) {
+ const time = offset * this.time_precision;
+ if (this.waiting_for_timestamp) {
+ this.on_chunk_end?.(time);
+ } else {
+ this.on_chunk_start?.(time);
+ }
+ this.waiting_for_timestamp = !this.waiting_for_timestamp; // Toggle
+ value = [[]]; // Skip timestamp
+ }
+ }
+ return super.put(value);
+ }
+
+ end() {
+ super.end();
+ this.on_finalize?.();
+ }
+}
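+
+// Callback-flow sketch (illustrative; `tokenizer` here is a WhisperTokenizer):
+//
+//   const streamer = new WhisperTextStreamer(tokenizer, {
+//       on_chunk_start: (t) => console.log('chunk starts at', t, 'seconds'),
+//       callback_function: (text) => console.log(text),
+//       on_chunk_end: (t) => console.log('chunk ends at', t, 'seconds'),
+//       on_finalize: () => console.log('done'),
+//   });
+//   // Timestamp tokens toggle between the chunk start/end events; ordinary tokens
+//   // are decoded and forwarded through `callback_function` as in `TextStreamer`.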
diff --git a/src/models.js b/src/models.js
index b6b1c71b1..b7d2b0ee2 100644
--- a/src/models.js
+++ b/src/models.js
@@ -5,11 +5,11 @@
* **Example:** Load and run an `AutoModel`.
*
* ```javascript
- * import { AutoModel, AutoTokenizer } from '@xenova/transformers';
- *
+ * import { AutoModel, AutoTokenizer } from '@huggingface/transformers';
+ *
* let tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');
* let model = await AutoModel.from_pretrained('Xenova/bert-base-uncased');
- *
+ *
* let inputs = await tokenizer('I love transformers!');
* let { logits } = await model(inputs);
* // Tensor {
@@ -24,11 +24,11 @@
*
* **Example:** Load and run an `AutoModelForSeq2SeqLM`.
* ```javascript
- * import { AutoModelForSeq2SeqLM, AutoTokenizer } from '@xenova/transformers';
+ * import { AutoModelForSeq2SeqLM, AutoTokenizer } from '@huggingface/transformers';
*
* let tokenizer = await AutoTokenizer.from_pretrained('Xenova/t5-small');
* let model = await AutoModelForSeq2SeqLM.from_pretrained('Xenova/t5-small');
- *
+ *
* let { input_ids } = await tokenizer('translate English to German: I love transformers!');
* let outputs = await model.generate(input_ids);
* let decoded = tokenizer.decode(outputs[0], { skip_special_tokens: true });
@@ -40,13 +40,30 @@
import {
AutoConfig,
+ getKeyValueShapes,
} from './configs.js';
+import {
+ deviceToExecutionProviders,
+ createInferenceSession,
+ isONNXTensor,
+ isONNXProxy,
+} from './backends/onnx.js';
+import {
+ DATA_TYPES,
+ DEFAULT_DEVICE_DTYPE_MAPPING,
+ DEFAULT_DTYPE_SUFFIX_MAPPING,
+ isWebGpuFp16Supported,
+} from './utils/dtypes.js';
+
import {
Callable,
+} from './utils/generic.js';
+
+import {
isIntegralNumber,
- isTypedArray,
mergeArrays,
+ pick,
} from './utils/core.js';
import {
@@ -54,10 +71,12 @@ import {
getModelJSON,
} from './utils/hub.js';
+import {
+ GITHUB_ISSUE_URL,
+} from './utils/constants.js';
+
import {
LogitsProcessorList,
- GenerationConfig,
- ForceTokensLogitsProcessor,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
SuppressTokensAtBeginLogitsProcessor,
@@ -68,24 +87,35 @@ import {
MinLengthLogitsProcessor,
MinNewTokensLengthLogitsProcessor,
- Sampler,
-} from './utils/generation.js';
+ TemperatureLogitsWarper,
+ TopKLogitsWarper,
+ TopPLogitsWarper,
+ ClassifierFreeGuidanceLogitsProcessor,
+} from './generation/logits_process.js';
+
+import {
+ GenerationConfig,
+} from './generation/configuration_utils.js';
import {
cat,
- dynamicTimeWarping,
+ full_like,
mean,
+ ones,
ones_like,
stack,
std_mean,
Tensor,
+ zeros_like,
} from './utils/tensor.js';
-import { executionProviders, ONNX } from './backends/onnx.js';
-import { medianFilter } from './transformers.js';
-const { InferenceSession, Tensor: ONNXTensor, env } = ONNX;
+import { dynamic_time_warping, medianFilter } from './utils/maths.js';
+import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js';
+import { LogitsSampler } from './generation/logits_sampler.js';
+import { apis } from './env.js';
-/** @typedef {import('onnxruntime-web').InferenceSession} InferenceSession */
+import { WhisperGenerationConfig } from './models/whisper/generation_whisper.js';
+import { whisper_language_to_code } from './models/whisper/common_whisper.js';
//////////////////////////////////////////////////
// Model types: used internally
@@ -96,6 +126,8 @@ const MODEL_TYPES = {
Vision2Seq: 3,
DecoderOnly: 4,
MaskGeneration: 5,
+ ImageTextToText: 6,
+ Musicgen: 7,
}
//////////////////////////////////////////////////
@@ -113,40 +145,183 @@ const MODEL_CLASS_TO_NAME_MAPPING = new Map();
* Constructs an InferenceSession using a model file located at the specified path.
* @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
* @param {string} fileName The name of the model file.
- * @param {import('./utils/hub.js').PretrainedOptions} options Additional options for loading the model.
- * @returns {Promise} A Promise that resolves to an InferenceSession object.
+ * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
+ * @returns {Promise<{buffer: Uint8Array, session_options: Object, session_config: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
* @private
*/
-async function constructSession(pretrained_model_name_or_path, fileName, options) {
- // TODO add option for user to force specify their desired execution provider
- let modelFileName = `onnx/${fileName}${options.quantized ? '_quantized' : ''}.onnx`;
- let buffer = await getModelFile(pretrained_model_name_or_path, modelFileName, true, options);
+async function getSession(pretrained_model_name_or_path, fileName, options) {
+ const custom_config = options.config?.['transformers.js_config'] ?? {};
+ let device = options.device ?? custom_config.device;
+ if (device && typeof device !== 'string') {
+ if (device.hasOwnProperty(fileName)) {
+ device = device[fileName];
+ } else {
+ console.warn(`device not specified for "${fileName}". Using the default device.`);
+ device = null;
+ }
+ }
- try {
- return await InferenceSession.create(buffer, {
- executionProviders,
- });
- } catch (err) {
- // If the execution provided was only wasm, throw the error
- if (executionProviders.length === 1 && executionProviders[0] === 'wasm') {
- throw err;
+ // If the device is not specified, we use the default (supported) execution providers.
+ const selectedDevice = /** @type {import("./utils/devices.js").DeviceType} */(
+ device ?? (apis.IS_NODE_ENV ? 'cpu' : 'wasm')
+ );
+ const executionProviders = deviceToExecutionProviders(selectedDevice);
+
+ // If options.dtype is specified, we use it to choose the suffix for the model file.
+ // Otherwise, we use the default dtype for the device.
+ let dtype = options.dtype ?? custom_config.dtype;
+ if (typeof dtype !== 'string') {
+ if (dtype && dtype.hasOwnProperty(fileName)) {
+ dtype = dtype[fileName];
+ } else {
+ dtype = DEFAULT_DEVICE_DTYPE_MAPPING[selectedDevice] ?? DATA_TYPES.fp32;
+ console.warn(`dtype not specified for "${fileName}". Using the default dtype (${dtype}) for this device (${selectedDevice}).`);
}
+ }
+
+ const selectedDtype = /** @type {import("./utils/dtypes.js").DataType} */(dtype);
+
+ if (!DEFAULT_DTYPE_SUFFIX_MAPPING.hasOwnProperty(selectedDtype)) {
+ throw new Error(`Invalid dtype: ${selectedDtype}. Should be one of: ${Object.keys(DATA_TYPES).join(', ')}`);
+ } else if (selectedDtype === DATA_TYPES.fp16 && selectedDevice === 'webgpu' && !(await isWebGpuFp16Supported())) {
+ throw new Error(`The device (${selectedDevice}) does not support fp16.`);
+ }
+
+ // Only valid for models with a decoder
+ const kv_cache_dtype = custom_config.kv_cache_dtype
+ ? (typeof custom_config.kv_cache_dtype === 'string'
+ ? custom_config.kv_cache_dtype
+ : custom_config.kv_cache_dtype[selectedDtype] ?? 'float32')
+ : undefined;
+
+ if (kv_cache_dtype && !['float32', 'float16'].includes(kv_cache_dtype)) {
+ throw new Error(`Invalid kv_cache_dtype: ${kv_cache_dtype}. Should be one of: float32, float16`);
+ }
+
+ const session_config = {
+ dtype: selectedDtype,
+ kv_cache_dtype,
+ }
+
+ // Construct the model file name
+ const suffix = DEFAULT_DTYPE_SUFFIX_MAPPING[selectedDtype];
+ const modelFileName = `${options.subfolder ?? ''}/${fileName}${suffix}.onnx`;
+
+ const session_options = { ...options.session_options };
+
+ // Overwrite `executionProviders` if not specified
+ session_options.executionProviders ??= executionProviders;
- console.warn(err);
+ // Overwrite `freeDimensionOverrides` if specified in config and not set in session options
+ const free_dimension_overrides = custom_config.free_dimension_overrides;
+ if (free_dimension_overrides) {
+ session_options.freeDimensionOverrides ??= free_dimension_overrides;
+ } else if (selectedDevice.startsWith('webnn') && !session_options.freeDimensionOverrides) {
console.warn(
- 'Something went wrong during model construction (most likely a missing operation). ' +
- 'Using `wasm` as a fallback. '
+ 'WebNN does not currently support dynamic shapes and requires `free_dimension_overrides` to be set in config.json as a field within "transformers.js_config". ' +
+ 'When `free_dimension_overrides` is not set, you may experience significant performance degradation.'
+ );
+ }
+
+ const bufferPromise = getModelFile(pretrained_model_name_or_path, modelFileName, true, options);
+
+ // handle onnx external data files
+ const use_external_data_format = options.use_external_data_format ?? custom_config.use_external_data_format;
+ /** @type {Promise<{path: string, data: Uint8Array}>[]} */
+ let externalDataPromises = [];
+ if (use_external_data_format && (
+ use_external_data_format === true ||
+ (
+ typeof use_external_data_format === 'object' &&
+ use_external_data_format.hasOwnProperty(fileName) &&
+ use_external_data_format[fileName] === true
)
- return await InferenceSession.create(buffer, {
- executionProviders: ['wasm']
+ )) {
+ if (apis.IS_NODE_ENV) {
+ throw new Error('External data format is not yet supported in Node.js');
+ }
+ const path = `${fileName}${suffix}.onnx_data`;
+ const fullPath = `${options.subfolder ?? ''}/${path}`;
+ externalDataPromises.push(new Promise(async (resolve, reject) => {
+ const data = await getModelFile(pretrained_model_name_or_path, fullPath, true, options);
+ resolve({ path, data })
+ }));
+
+ } else if (session_options.externalData !== undefined) {
+ externalDataPromises = session_options.externalData.map(async (ext) => {
+ // if the external data is a string, fetch the file and replace the string with its content
+ if (typeof ext.data === "string") {
+ const ext_buffer = await getModelFile(pretrained_model_name_or_path, ext.data, true, options);
+ return { ...ext, data: ext_buffer };
+ }
+ return ext;
});
}
+
+ if (externalDataPromises.length > 0) {
+ session_options.externalData = await Promise.all(externalDataPromises);
+ }
+
+ if (selectedDevice === 'webgpu') {
+ const shapes = getKeyValueShapes(options.config, {
+ prefix: 'present',
+ });
+ if (Object.keys(shapes).length > 0 && !isONNXProxy()) {
+ // Only set preferredOutputLocation if shapes are present and we aren't proxying ONNX
+ /** @type {Record} */
+ const preferredOutputLocation = {};
+ for (const key in shapes) {
+ preferredOutputLocation[key] = 'gpu-buffer';
+ }
+ session_options.preferredOutputLocation = preferredOutputLocation;
+ }
+ }
+
+ const buffer = await bufferPromise;
+
+ return { buffer, session_options, session_config };
+}
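+
+// Path-resolution sketch (illustrative, not normative): with `subfolder: 'onnx'`,
+// `fileName: 'decoder_model_merged'` and a dtype whose suffix is '_fp16', the code
+// above requests 'onnx/decoder_model_merged_fp16.onnx'; the execution providers are
+// then derived from the resolved device (e.g. 'wasm' in browsers when unspecified).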
+
+/**
+ * Helper function to create multiple InferenceSession objects.
+ *
+ * @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
+ * @param {Record<string, string>} names The names of the model files to load.
+ * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
+ * @returns {Promise<Record<string, any>>} A Promise that resolves to a dictionary of InferenceSession objects.
+ * @private
+ */
+async function constructSessions(pretrained_model_name_or_path, names, options) {
+ return Object.fromEntries(await Promise.all(
+ Object.keys(names).map(async (name) => {
+ const { buffer, session_options, session_config } = await getSession(pretrained_model_name_or_path, names[name], options);
+ const session = await createInferenceSession(buffer, session_options, session_config);
+ return [name, session];
+ })
+ ));
+}
+
+/**
+ * Helper function to load multiple optional configuration files
+ * @param {string} pretrained_model_name_or_path The path to the directory containing the config file.
+ * @param {Record<string, string>} names The names of the config files to load.
+ * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the configs.
+ * @returns {Promise<Record<string, Object>>} A Promise that resolves to a dictionary of configuration objects.
+ * @private
+ */
+async function getOptionalConfigs(pretrained_model_name_or_path, names, options) {
+ return Object.fromEntries(await Promise.all(
+ Object.keys(names).map(async (name) => {
+ const config = await getModelJSON(pretrained_model_name_or_path, names[name], false, options);
+ return [name, config];
+ })
+ ));
}
/**
* Validate model inputs
- * @param {InferenceSession} session The InferenceSession object that will be run.
- * @param {Record} inputs The inputs to check.
+ * @param {Object} session The InferenceSession object that will be run.
+ * @param {Object} inputs The inputs to check.
* @returns {Record} The checked inputs.
* @throws {Error} If any inputs are missing.
* @private
@@ -170,7 +345,7 @@ function validateInputs(session, inputs) {
// NOTE: When `env.wasm.proxy is true` the tensor is moved across the Worker
// boundary, transferring ownership to the worker and invalidating the tensor.
// So, in this case, we simply sacrifice a clone for it.
- checkedInputs[inputName] = env.wasm.proxy ? tensor.clone() : tensor;
+ checkedInputs[inputName] = isONNXProxy() ? tensor.clone() : tensor;
}
if (missingInputs.length > 0) {
throw new Error(
@@ -195,7 +370,7 @@ function validateInputs(session, inputs) {
* - If additional inputs are passed, they will be ignored.
* - If inputs are missing, an error will be thrown.
*
- * @param {InferenceSession} session The InferenceSession object to run.
+ * @param {Object} session The InferenceSession object to run.
* @param {Object} inputs An object that maps input names to input tensors.
* @returns {Promise} A Promise that resolves to an object that maps output names to output tensors.
* @private
@@ -203,8 +378,9 @@ function validateInputs(session, inputs) {
async function sessionRun(session, inputs) {
const checkedInputs = validateInputs(session, inputs);
try {
- // @ts-ignore
- let output = await session.run(checkedInputs);
+ // pass the original ort tensor
+ const ortFeed = Object.fromEntries(Object.entries(checkedInputs).map(([k, v]) => [k, v.ort_tensor]));
+ let output = await session.run(ortFeed);
output = replaceTensors(output);
return output;
} catch (e) {
@@ -223,7 +399,7 @@ async function sessionRun(session, inputs) {
*/
function replaceTensors(obj) {
for (let prop in obj) {
- if (obj[prop] instanceof ONNXTensor) {
+ if (isONNXTensor(obj[prop])) {
obj[prop] = new Tensor(obj[prop]);
} else if (typeof obj[prop] === 'object') {
replaceTensors(obj[prop]);
@@ -268,72 +444,6 @@ function toI64Tensor(items) {
}
}
-/**
- * Prepares an attention mask for a sequence of tokens based on configuration options.
- * @param {Object} self The calling object instance.
- * @param {Tensor} tokens The input tokens.
- * @returns {Tensor} The attention mask tensor.
- * @private
- */
-function prepareAttentionMask(self, tokens) {
-
- // Prepare attention mask
- let pad_token_id = self.config.pad_token_id ?? null;
- let eos_token_id = self.config.eos_token_id ?? null;
- if (isIntegralNumber(eos_token_id)) {
- eos_token_id = [eos_token_id];
- }
-
- let is_pad_token_in_inputs = tokens.indexOf(pad_token_id) !== -1;
- let is_pad_token_not_equal_to_eos_token_id = (eos_token_id === null) || !eos_token_id.includes(pad_token_id)
-
- if (is_pad_token_in_inputs && is_pad_token_not_equal_to_eos_token_id) {
- let data = BigInt64Array.from(
- // Note: != so that int matches bigint
- // @ts-ignore
- tokens.data.map(x => x != pad_token_id)
- )
- return new Tensor('int64', data, tokens.dims)
- } else {
- return ones_like(tokens);
- }
-}
-
-/**
- * Add position IDs to the feeds object.
- * @param {Object} session The inference session.
- * @param {Object} feeds The input to the model.
- * @param {boolean} use_cache_branch Whether to use the cache branch of the model.
- * @returns {void}
- * @private
- */
-function preparePositionIds(session, feeds, use_cache_branch) {
- if (!session.inputNames.includes('position_ids')) return;
-
- const data = new BigInt64Array(feeds.attention_mask.data.length);
-
- // Compute cumulative sum of the attention mask along the sequence length dimension
- for (let i = 0; i < feeds.attention_mask.dims[0]; ++i) {
- let start = i * feeds.attention_mask.dims[1];
- let sum = BigInt(0);
- for (let j = 0; j < feeds.attention_mask.dims[1]; ++j) {
- const index = start + j;
- if (feeds.attention_mask.data[index] === 0n) {
- data[index] = BigInt(1);
- } else { // === 1n
- data[index] = sum;
- sum += feeds.attention_mask.data[index];
- }
- }
- }
-
- feeds.position_ids = new Tensor('int64', data, feeds.attention_mask.dims);
-
- if (use_cache_branch) {
- feeds.position_ids = feeds.position_ids.slice(null, -1).unsqueeze_(-1);
- }
-}
-
/**
* Creates a boolean tensor with a single value.
* @param {boolean} value The value of the tensor.
@@ -353,162 +463,44 @@ function boolTensor(value) {
* @private
*/
async function seq2seqForward(self, model_inputs) {
-
- let { encoder_outputs, past_key_values } = model_inputs;
-
+ let { encoder_outputs, input_ids, decoder_input_ids, ...other_decoder_inputs } = model_inputs;
+ // Encode if needed
if (!encoder_outputs) {
+ const encoder_inputs = pick(model_inputs, self.sessions['model'].inputNames);
// Encoder outputs are not given, so we must compute them.
- encoder_outputs = (await encoderForward(self, model_inputs)).last_hidden_state;
+ encoder_outputs = (await encoderForward(self, encoder_inputs)).last_hidden_state;
}
- let decoderFeeds = {
- input_ids: model_inputs.decoder_input_ids,
- encoder_hidden_states: encoder_outputs,
- };
- const use_cache_branch = !!past_key_values;
-
- if (self.decoder_merged_session.inputNames.includes('use_cache_branch')) {
- decoderFeeds.use_cache_branch = boolTensor(use_cache_branch);
- }
-
- if (self.decoder_merged_session.inputNames.includes('encoder_attention_mask')) {
- decoderFeeds.encoder_attention_mask = model_inputs.attention_mask
- }
-
- preparePositionIds(self.decoder_merged_session, decoderFeeds, use_cache_branch);
- self.addPastKeyValues(decoderFeeds, past_key_values);
-
- const decoderResults = await sessionRun(self.decoder_merged_session, decoderFeeds);
- let logits = decoderResults.logits;
- past_key_values = self.getPastKeyValues(decoderResults, past_key_values);
-
- // Get cross attention and/or decoder attentions if they are present
- const attns = self.getAttentions(decoderResults);
-
- return new Seq2SeqLMOutput({ logits, past_key_values, encoder_outputs, ...attns });
-}
-
-/**
- * Start the beam search process for the seq2seq model.
- * @param {PreTrainedModel} self The seq2seq model object.
- * @param {Tensor} inputTokenIds Array of input token ids for each input sequence.
- * @param {Object} generation_config The generation config.
- * @param {number} numOutputTokens The maximum number of output tokens for the model.
- * @returns {Object[]} Array of beam search objects.
- * @private
- */
-function seq2seqStartBeams(self, inputTokenIds, generation_config, numOutputTokens) {
- let beams = [];
- let beamId = 0;
-
- // @ts-ignore
- const requires_attention_mask = self.requires_attention_mask ?? true;
-
- // decoder_input_ids == output_token_ids
- let decoder_input_ids =
- generation_config.decoder_input_ids
- ?? generation_config.decoder_start_token_id
- ?? generation_config.bos_token_id
- ?? generation_config.eos_token_id;
-
- // Support input as tensor or list
- // TODO support batched decoder_input_ids
- if (decoder_input_ids instanceof Tensor) {
- decoder_input_ids = decoder_input_ids.tolist().flat();
- } else if (!Array.isArray(decoder_input_ids)) {
- decoder_input_ids = [decoder_input_ids];
- }
-
- for (let tokens of inputTokenIds) {
- // TODO: Improve
- // Currently, just add back batch dimension.
- // In future, allow for true parallel execution
- tokens.dims = [1, ...tokens.dims]
-
- // Create beam
- let start = {
- inputs: tokens,
- encoder_outputs: null,
- prev_model_outputs: null,
-
- output_token_ids: decoder_input_ids,
- done: false,
- score: 0,
- id: beamId++ // assign unique id to beams
- }
-
- if (requires_attention_mask) {
- start.attention_mask = prepareAttentionMask(self, tokens);
- }
-
- beams.push(start);
- }
-
- return beams;
-}
-
-/**
- * Run beam search on the seq2seq model for a single beam.
- * @param {PreTrainedModel} self The seq2seq model object.
- * @param {Object} beam The beam search object for which to run the model.
- * @param {Object} options options
- * @param {string} [options.input_name='input_ids'] The name of the input tensor for the encoder.
- * @returns {Promise} Promise that resolves with the output of the seq2seq model for the given beam.
- * @private
- */
-async function seq2seqRunBeam(self, beam) {
- const input_name = self.main_input_name;
- let decoder_input_ids = beam.output_token_ids;
- if (beam.prev_model_outputs) {
- // After the first step, `prev_model_outputs` won't be null.
- // So, we cut decoder_input_ids if past is used
- decoder_input_ids = decoder_input_ids.slice(-1);
- }
+ other_decoder_inputs.input_ids = decoder_input_ids;
+ other_decoder_inputs.encoder_hidden_states = encoder_outputs;
- // 1. Prepare
- let model_inputs = {
- [input_name]: beam.inputs,
- decoder_input_ids: toI64Tensor(decoder_input_ids),
- encoder_outputs: beam.encoder_outputs,
- past_key_values: beam.prev_model_outputs?.past_key_values,
+ if (self.sessions['decoder_model_merged'].inputNames.includes('encoder_attention_mask')) {
+ other_decoder_inputs.encoder_attention_mask = model_inputs.attention_mask
}
- if (beam.attention_mask) {
- model_inputs.attention_mask = beam.attention_mask
- }
-
- // 2. Run
- let output = await self.forward(model_inputs);
- // 3. Update
- beam.prev_model_outputs = output;
- beam.encoder_outputs = output.encoder_outputs;
+ const decoderResults = await decoderForward(self, other_decoder_inputs, true);
- return output;
-}
-
-/**
- * Update a beam with a new token ID.
- * @param {Object} beam The beam to update.
- * @param {number} newTokenId The new token ID to add to the beam's output.
- * @private
- */
-function seq2seqUpdatebeam(beam, newTokenId) {
- beam.output_token_ids = [...beam.output_token_ids, newTokenId];
+ return decoderResults;
}
/**
* Forward pass of an encoder model.
* @param {Object} self The encoder model.
* @param {Object} model_inputs The input data to be used for the forward pass.
- * @returns {Promise} Promise that resolves with an object containing the model's outputs.
+ * @returns {Promise} The model's outputs.
* @private
*/
async function encoderForward(self, model_inputs) {
- const encoderFeeds = Object.create(null);
- for (const key of self.session.inputNames) {
- encoderFeeds[key] = model_inputs[key];
+ const session = self.sessions['model'];
+ const encoderFeeds = pick(model_inputs, session.inputNames);
+
+ if (session.inputNames.includes('inputs_embeds') && !encoderFeeds.inputs_embeds) {
+ if (!model_inputs.input_ids) {
+ throw new Error('Both `input_ids` and `inputs_embeds` are missing in the model inputs.');
+ }
+ encoderFeeds.inputs_embeds = await self.encode_text({ input_ids: model_inputs.input_ids });
}
- if (self.session.inputNames.includes('token_type_ids') && !encoderFeeds.token_type_ids) {
+ if (session.inputNames.includes('token_type_ids') && !encoderFeeds.token_type_ids) {
// Assign default `token_type_ids` (all zeroes) to the `encoderFeeds` if the model expects it,
// but they weren't created by the tokenizer.
encoderFeeds.token_type_ids = new Tensor(
@@ -517,136 +509,211 @@ async function encoderForward(self, model_inputs) {
encoderFeeds.input_ids.dims
)
}
- return await sessionRun(self.session, encoderFeeds);
+ return await sessionRun(session, encoderFeeds);
}
-
/**
* Forward pass of a decoder model.
* @param {Object} self The decoder model.
* @param {Object} model_inputs The input data to be used for the forward pass.
- * @returns {Promise} Promise that resolves with an object containing the logits and past key values.
+ * @returns {Promise} The logits and past key values.
* @private
*/
-async function decoderForward(self, model_inputs) {
- let { input_ids, past_key_values, attention_mask } = model_inputs;
- let decoderFeeds = {
- input_ids: input_ids,
- attention_mask: attention_mask ?? prepareAttentionMask(self, input_ids),
- }
- const use_cache_branch = !!past_key_values;
-
- if (self.session.inputNames.includes('use_cache_branch')) {
- decoderFeeds.use_cache_branch = boolTensor(use_cache_branch);
- }
+async function decoderForward(self, model_inputs, is_encoder_decoder = false) {
- preparePositionIds(self.session, decoderFeeds, use_cache_branch);
+ const session = self.sessions[
+ is_encoder_decoder ? 'decoder_model_merged' : 'model'
+ ]
- self.addPastKeyValues(decoderFeeds, past_key_values);
+ const { past_key_values, ...new_model_inputs } = model_inputs;
- let decoderResults = await sessionRun(self.session, decoderFeeds);
+ if (session.inputNames.includes('use_cache_branch')) {
+ new_model_inputs.use_cache_branch = boolTensor(!!past_key_values);
+ }
+ if (session.inputNames.includes('position_ids') && new_model_inputs.attention_mask && !new_model_inputs.position_ids) {
+ new_model_inputs.position_ids = createPositionIds(new_model_inputs, past_key_values);
+ }
- let logits = decoderResults.logits;
+ // Unpack the `past_key_values` object into model inputs
+ self.addPastKeyValues(new_model_inputs, past_key_values);
- past_key_values = self.getPastKeyValues(decoderResults, past_key_values);
- return { logits, past_key_values };
+ // Select only the inputs that are needed for the current session
+ const fixed = pick(new_model_inputs, session.inputNames);
+ return await sessionRun(session, fixed);
}
+
/**
- * Starts the generation of text by initializing the beams for the given input token IDs.
- * @param {Object} self The text generation model object.
- * @param {Tensor} inputTokenIds An tensor of input token IDs to generate text from.
- * @param {Object} generation_config The generation config.
- * @param {number} numOutputTokens The maximum number of tokens to generate for each beam.
- * @param {Tensor} [inputs_attention_mask] The attention mask tensor for the input token IDs.
- * @returns {Object[]} An array of beams initialized with the given inputs and parameters.
+ * Forward pass of an image-text-to-text model.
+ * @param {Object} self The image-text-to-text model.
+ * @param {Object} model_inputs The input data to be used for the forward pass.
+ * @param {Tensor} [model_inputs.input_ids=null]
+ * @param {Tensor} [model_inputs.attention_mask=null]
+ * @param {Tensor} [model_inputs.pixel_values=null]
+ * @param {Tensor} [model_inputs.position_ids=null]
+ * @param {Tensor} [model_inputs.inputs_embeds=null]
+ * @param {Tensor} [model_inputs.past_key_values=null]
+ * @param {Object} [model_inputs.generation_config=null]
+ * @param {Object} [model_inputs.logits_processor=null]
+ * @returns {Promise} The model's output tensor
* @private
*/
-function decoderStartBeams(self, inputTokenIds, generation_config, numOutputTokens, inputs_attention_mask) {
- let beams = [];
-
- let beamId = 0;
- for (let tokens of inputTokenIds) {
- let output_token_ids = tokens.tolist().map(Number);
-
- // TODO: Improve
- // Currently, just add back batch dimension.
- // In future, allow for true parallel execution
- tokens.dims = [1, ...tokens.dims]
+async function imageTextToTextForward(self, {
+ // Produced by the tokenizer/processor:
+ input_ids = null,
+ attention_mask = null,
+ pixel_values = null,
+
+ // Used during generation:
+ position_ids = null,
+ inputs_embeds = null,
+ past_key_values = null,
+
+ // Generic generation parameters
+ generation_config = null,
+ logits_processor = null,
+
+ // TODO: needed?
+ ...kwargs
+}) {
+
+ if (!inputs_embeds) {
+ // 1. Extract the input embeddings
+ inputs_embeds = await self.encode_text({ input_ids });
+
+ // 2. Possibly, merge text and images
+ if (pixel_values && input_ids.dims[1] !== 1) {
+ const image_features = await self.encode_image({ pixel_values });
+
+ ({ inputs_embeds, attention_mask } = self._merge_input_ids_with_image_features({
+ image_features,
+ inputs_embeds,
+ input_ids,
+ attention_mask,
+ }));
- let attn_mask;
- if (inputs_attention_mask) {
- attn_mask = inputs_attention_mask[beamId];
- attn_mask.dims = [1, ...attn_mask.dims]
+ } else if (past_key_values && pixel_values && input_ids.dims[1] === 1) {
+ // This is the case when we are generating with cache
+ const target_length = input_ids.dims[1]; // always 1
+ const past_length = Object.values(past_key_values)[0].dims.at(-2);
- } else {
- attn_mask = prepareAttentionMask(self, tokens)
+ attention_mask = cat([
+ ones([input_ids.dims[0], past_length]),
+ attention_mask.slice(null, [attention_mask.dims[1] - target_length, attention_mask.dims[1]]),
+ ], 1);
}
+ }
- let start = {
- input: tokens,
- model_input_ids: tokens,
- attention_mask: attn_mask,
- prev_model_outputs: null,
-
- output_token_ids: output_token_ids,
- num_output_tokens: numOutputTokens,
-
- done: false,
- score: 0,
- id: beamId++ // assign unique id to beams
+ const outputs = await decoderForward(self, {
+ inputs_embeds,
+ past_key_values,
+ attention_mask,
+ position_ids,
+ generation_config,
+ logits_processor,
+ }, true);
+ return outputs;
+}
+
+function createPositionIds(model_inputs, past_key_values = null) {
+ // If the model supports providing position_ids, we create position_ids on the fly for batch generation,
+ // by computing the cumulative sum of the attention mask along the sequence length dimension.
+ //
+ // Equivalent to:
+ // position_ids = attention_mask.long().cumsum(-1) - 1
+ // position_ids.masked_fill_(attention_mask == 0, 1)
+ // if past_key_values:
+ // position_ids = position_ids[:, -input_ids.shape[1] :]
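+ //
+ // Worked example (illustrative): for a left-padded row with attention_mask
+ // [0, 0, 1, 1, 1], the loop below produces position_ids [1, 1, 0, 1, 2]:
+ // padded positions are pinned to 1 and real tokens count up from 0.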
+ const { input_ids, inputs_embeds, attention_mask } = model_inputs;
+ const [bz, seq_len] = attention_mask.dims;
+
+ const data = new BigInt64Array(attention_mask.data.length);
+ for (let i = 0; i < bz; ++i) {
+ const start = i * seq_len;
+ let sum = BigInt(0);
+ for (let j = 0; j < seq_len; ++j) {
+ const index = start + j;
+ if (attention_mask.data[index] === 0n) {
+ data[index] = BigInt(1);
+ } else { // === 1n
+ data[index] = sum;
+ sum += attention_mask.data[index];
+ }
}
+ }
- beams.push(start);
+ let position_ids = new Tensor('int64', data, attention_mask.dims);
+ if (past_key_values) {
+ const offset = -(input_ids ?? inputs_embeds).dims.at(1);
+ position_ids = position_ids.slice(null, [offset, null]);
}
- return beams;
+ return position_ids;
}
-/**
- * Runs a single step of the text generation process for a given beam.
- *
- * @param {Object} self The decoder object.
- * @param {Object} beam The beam to run.
- * @param {Tensor} beam.input The input tensor.
- * @param {Tensor} beam.model_input_ids The input ids to the model.
- * @param {Tensor} beam.attention_mask The attention mask.
- * @param {Object} beam.prev_model_outputs The past key values.
- * @param {number[]} beam.output_token_ids The output token ids.
- * @returns {Promise} The output of the generation step.
- * @private
- */
-async function decoderRunBeam(self, beam) {
- let attnMaskData = new BigInt64Array(beam.output_token_ids.length).fill(1n)
+function decoder_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
+ if (model_inputs.past_key_values) {
+ const past_length = Object.values(model_inputs.past_key_values)[0].dims.at(-2);
+ const { input_ids, attention_mask } = model_inputs;
- // 1. Prepare
- let model_inputs = {
- input_ids: beam.model_input_ids,
- attention_mask: new Tensor(
- 'int64',
- attnMaskData,
- [1, attnMaskData.length]
- ),
- past_key_values: beam.prev_model_outputs?.past_key_values,
+ // Keep only the unprocessed tokens:
+ // 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+ // some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+ // input)
+ if (attention_mask && attention_mask.dims[1] > input_ids.dims[1]) {
+ // NOTE: not needed since we only pass the generated tokens to the next forward pass
+ // const offset = -(attention_mask.dims[1] - past_length);
+ // model_inputs.input_ids = input_ids.slice(null, [offset, null]);
+ }
+ // 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens.
+ // We can discard input_ids based on the past_length.
+ else if (past_length < input_ids.dims[1]) {
+ // NOTE: Required for phi models.
+ // See https://github.com/huggingface/transformers/issues/30809#issuecomment-2111918479 for more information.
+ model_inputs.input_ids = input_ids.slice(null, [past_length, null]);
+ }
+ // 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+ else {
+ if (
+ // NOTE: Only used by VLMs (!= so that null matches undefined)
+ self.config.image_token_index != null &&
+ // Equivalent to `self.config.image_token_index in input_ids` (== so that int matches bigint)
+ input_ids.data.some(x => x == self.config.image_token_index)
+ ) {
+ // TODO: Support multiple image tokens
+ const num_image_tokens = self.config.num_image_tokens;
+ if (!num_image_tokens) {
+ throw new Error('`num_image_tokens` is missing in the model configuration.');
+ }
+
+ const num_new_tokens = input_ids.dims[1] - (past_length - num_image_tokens);
+ model_inputs.input_ids = input_ids.slice(null, [-num_new_tokens, null]);
+
+ // TODO: The attention mask should be formed from the attention mask passed in model_inputs
+ model_inputs.attention_mask = ones([1, past_length + num_new_tokens]);
+ }
+ }
}
- // 2. Run
- let output = await self.forward(model_inputs);
+ return model_inputs;
+}
- // 3. Update
- beam.prev_model_outputs = output;
+function encoder_decoder_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
+ if (model_inputs.past_key_values) {
+ input_ids = input_ids.map(x => [x.at(-1)]);
+ }
- return output;
+ return {
+ ...model_inputs,
+ decoder_input_ids: toI64Tensor(input_ids),
+ };
}
-/**
- * Update a beam with a new token ID.
- * @param {Object} beam The beam to update.
- * @param {number} newTokenId The new token ID to add to the beam's output.
- * @private
- */
-function decoderUpdatebeam(beam, newTokenId) {
- beam.output_token_ids = [...beam.output_token_ids, newTokenId];
- beam.model_input_ids = new Tensor('int64', [BigInt(newTokenId)], [1, 1]);
+function image_text_to_text_prepare_inputs_for_generation(self, ...args) {
+ if (self.config.is_encoder_decoder) {
+ return encoder_decoder_prepare_inputs_for_generation(self, ...args);
+ } else {
+ return decoder_prepare_inputs_for_generation(self, ...args);
+ }
}
//////////////////////////////////////////////////
@@ -657,48 +724,63 @@ function decoderUpdatebeam(beam, newTokenId) {
*/
export class PreTrainedModel extends Callable {
main_input_name = 'input_ids';
-
+ forward_params = ['input_ids', 'attention_mask'];
/**
* Creates a new instance of the `PreTrainedModel` class.
- * @param {Object} config The model configuration.
- * @param {any} session session for the model.
+ * @param {import('./configs.js').PretrainedConfig} config The model configuration.
+ * @param {Record<string, any>} sessions The inference sessions for the model.
+ * @param {Record<string, Object>} configs Additional configuration files (e.g., generation_config.json).
*/
- constructor(config, session) {
+ constructor(config, sessions, configs) {
super();
this.config = config;
- this.session = session;
+ this.sessions = sessions;
+ this.configs = configs;
const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor);
const modelType = MODEL_TYPE_MAPPING.get(modelName);
this.can_generate = false;
- this._runBeam = null;
- this._getStartBeams = null;
- this._updateBeam = null;
this._forward = null;
- if (modelType === MODEL_TYPES.DecoderOnly) {
- this.can_generate = true;
- this._runBeam = decoderRunBeam;
- this._getStartBeams = decoderStartBeams;
- this._updateBeam = decoderUpdatebeam;
- this._forward = decoderForward;
+ this._prepare_inputs_for_generation = null;
+ switch (modelType) {
+ case MODEL_TYPES.DecoderOnly:
+ this.can_generate = true;
+ this._forward = decoderForward;
+ this._prepare_inputs_for_generation = decoder_prepare_inputs_for_generation;
+ break;
+ case MODEL_TYPES.Seq2Seq:
+ case MODEL_TYPES.Vision2Seq:
+ case MODEL_TYPES.Musicgen:
+ this.can_generate = true;
- } else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
- this.can_generate = true;
+ this._forward = seq2seqForward;
+ this._prepare_inputs_for_generation = encoder_decoder_prepare_inputs_for_generation;
+ break;
- this._runBeam = seq2seqRunBeam;
- this._getStartBeams = seq2seqStartBeams;
- this._updateBeam = seq2seqUpdatebeam;
- this._forward = seq2seqForward;
+ case MODEL_TYPES.EncoderDecoder:
+ this._forward = seq2seqForward;
+ break;
+ case MODEL_TYPES.ImageTextToText:
+ this.can_generate = true;
+ this._forward = imageTextToTextForward;
+ this._prepare_inputs_for_generation = image_text_to_text_prepare_inputs_for_generation;
+ break;
- } else if (modelType === MODEL_TYPES.EncoderDecoder) {
- this._forward = encoderForward;
+ default:
+ // should be MODEL_TYPES.EncoderOnly
+ this._forward = encoderForward;
+ break;
+ }
- } else { // should be MODEL_TYPES.EncoderOnly
- this._forward = encoderForward;
+ if (this.can_generate) {
+ this.forward_params.push('past_key_values');
}
+
+ /** @type {import('./configs.js').TransformersJSConfig} */
+ this.custom_config = this.config['transformers.js_config'] ?? {};
}
/**
@@ -708,11 +790,9 @@ export class PreTrainedModel extends Callable {
*/
async dispose() {
const promises = [];
- for (let key of Object.keys(this)) {
- const item = this[key];
- // @ts-ignore
- if (item instanceof InferenceSession) {
- promises.push(item.handler.dispose())
+ for (const session of Object.values(this.sessions)) {
+ if (session?.handler?.dispose) {
+ promises.push(session.handler.dispose())
}
}
return await Promise.all(promises);
@@ -729,75 +809,122 @@ export class PreTrainedModel extends Callable {
* Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
* user or organization name, like `dbmdz/bert-base-german-cased`.
* - A path to a *directory* containing model weights, e.g., `./my_model_directory/`.
- * @param {import('./utils/hub.js').PretrainedOptions} options Additional options for loading the model.
+ * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
*
* @returns {Promise} A new instance of the `PreTrainedModel` class.
*/
static async from_pretrained(pretrained_model_name_or_path, {
- quantized = true,
progress_callback = null,
config = null,
cache_dir = null,
local_files_only = false,
revision = 'main',
model_file_name = null,
+ subfolder = 'onnx',
+ device = null,
+ dtype = null,
+ use_external_data_format = null,
+ session_options = {},
} = {}) {
let options = {
- quantized,
progress_callback,
config,
cache_dir,
local_files_only,
revision,
model_file_name,
+ subfolder,
+ device,
+ dtype,
+ use_external_data_format,
+ session_options,
}
const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this);
const modelType = MODEL_TYPE_MAPPING.get(modelName);
+ config = options.config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options);
+
let info;
if (modelType === MODEL_TYPES.DecoderOnly) {
info = await Promise.all([
- AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
- constructSession(pretrained_model_name_or_path, options.model_file_name ?? 'decoder_model_merged', options),
- getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options),
+ constructSessions(pretrained_model_name_or_path, {
+ model: options.model_file_name ?? 'model',
+ }, options),
+ getOptionalConfigs(pretrained_model_name_or_path, {
+ generation_config: 'generation_config.json',
+ }, options),
]);
} else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
info = await Promise.all([
- AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
- constructSession(pretrained_model_name_or_path, 'encoder_model', options),
- constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options),
- getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options),
+ constructSessions(pretrained_model_name_or_path, {
+ model: 'encoder_model',
+ decoder_model_merged: 'decoder_model_merged',
+ }, options),
+ getOptionalConfigs(pretrained_model_name_or_path, {
+ generation_config: 'generation_config.json',
+ }, options),
]);
} else if (modelType === MODEL_TYPES.MaskGeneration) {
info = await Promise.all([
- AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
- constructSession(pretrained_model_name_or_path, 'vision_encoder', options),
- constructSession(pretrained_model_name_or_path, 'prompt_encoder_mask_decoder', options),
+ constructSessions(pretrained_model_name_or_path, {
+ model: 'vision_encoder',
+ prompt_encoder_mask_decoder: 'prompt_encoder_mask_decoder',
+ }, options),
]);
} else if (modelType === MODEL_TYPES.EncoderDecoder) {
info = await Promise.all([
- AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
- constructSession(pretrained_model_name_or_path, 'encoder_model', options),
- constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options),
+ constructSessions(pretrained_model_name_or_path, {
+ model: 'encoder_model',
+ decoder_model_merged: 'decoder_model_merged',
+ }, options),
+ ]);
+
+ } else if (modelType === MODEL_TYPES.ImageTextToText) {
+ const sessions = {
+ embed_tokens: 'embed_tokens',
+ vision_encoder: 'vision_encoder',
+ decoder_model_merged: 'decoder_model_merged',
+ }
+ if (config.is_encoder_decoder) {
+ sessions['model'] = 'encoder_model';
+ }
+ info = await Promise.all([
+ constructSessions(pretrained_model_name_or_path, sessions, options),
+ getOptionalConfigs(pretrained_model_name_or_path, {
+ generation_config: 'generation_config.json',
+ }, options),
+ ]);
+
+ } else if (modelType === MODEL_TYPES.Musicgen) {
+ info = await Promise.all([
+ constructSessions(pretrained_model_name_or_path, {
+ model: 'text_encoder',
+ decoder_model_merged: 'decoder_model_merged',
+ encodec_decode: 'encodec_decode',
+ }, options),
+ getOptionalConfigs(pretrained_model_name_or_path, {
+ generation_config: 'generation_config.json',
+ }, options),
]);
} else { // should be MODEL_TYPES.EncoderOnly
if (modelType !== MODEL_TYPES.EncoderOnly) {
- console.warn(`Model type for '${modelName ?? config?.model_type}' not found, assuming encoder-only architecture. Please report this at https://github.com/xenova/transformers.js/issues/new/choose.`)
+ console.warn(`Model type for '${modelName ?? config?.model_type}' not found, assuming encoder-only architecture. Please report this at ${GITHUB_ISSUE_URL}.`)
}
info = await Promise.all([
- AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
- constructSession(pretrained_model_name_or_path, options.model_file_name ?? 'model', options)
+ constructSessions(pretrained_model_name_or_path, {
+ model: options.model_file_name ?? 'model',
+ }, options),
]);
}
// @ts-ignore
- return new this(...info);
+ return new this(config, ...info);
}
/**
@@ -821,7 +948,41 @@ export class PreTrainedModel extends Callable {
}
/**
- * @param {import('./utils/generation.js').GenerationConfigType} generation_config
+ * Get the model's generation config, if it exists.
+ * @returns {GenerationConfig|null} The model's generation config if it exists, otherwise `null`.
+ */
+ get generation_config() {
+ return this.configs?.generation_config ?? null;
+ }
+
+ /**
+ * This function returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`]
+ * instances used for multinomial sampling.
+ * @param {GenerationConfig} generation_config The generation config.
+     * @returns {LogitsProcessorList} The list of logits warpers to apply during multinomial sampling.
+ */
+ _get_logits_warper(generation_config) {
+
+ // instantiate warpers list
+ const warpers = new LogitsProcessorList();
+
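+        // Warpers are applied in sequence: temperature scaling first, then top-k filtering, then top-p (nucleus) filtering.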
+ if (generation_config.temperature !== null && generation_config.temperature !== 1.0) {
+ warpers.push(new TemperatureLogitsWarper(generation_config.temperature));
+ }
+ if (generation_config.top_k !== null && generation_config.top_k !== 0) {
+ // TODO: add min_tokens_to_keep
+ warpers.push(new TopKLogitsWarper(generation_config.top_k));
+ }
+ if (generation_config.top_p !== null && generation_config.top_p < 1.0) {
+ // TODO: add min_tokens_to_keep
+ warpers.push(new TopPLogitsWarper(generation_config.top_p));
+ }
+
+ return warpers;
+ }
+
+ /**
+ * @param {GenerationConfig} generation_config
* @param {number} input_ids_seq_length The starting sequence length for the input ids.
* @returns {LogitsProcessorList}
* @private
@@ -921,19 +1082,22 @@ export class PreTrainedModel extends Callable {
// }
if (generation_config.begin_suppress_tokens !== null) {
- let begin_index = (input_ids_seq_length > 1 || generation_config.forced_bos_token_id === null)
+ const begin_index = (input_ids_seq_length > 1 || generation_config.forced_bos_token_id === null)
? input_ids_seq_length
: input_ids_seq_length + 1;
- if (generation_config.forced_decoder_ids !== null) {
- // generation starts after the last token that is forced
- begin_index += generation_config.forced_decoder_ids[generation_config.forced_decoder_ids.length - 1][0];
- }
processors.push(new SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index));
}
- if (generation_config.forced_decoder_ids !== null) {
- processors.push(new ForceTokensLogitsProcessor(generation_config.forced_decoder_ids));
+ // DEPRECATED: https://github.com/huggingface/transformers/pull/29485
+ // if (generation_config.forced_decoder_ids !== null) {
+ // processors.push(new ForceTokensLogitsProcessor(generation_config.forced_decoder_ids));
+ // }
+
+
+ // 8. prepare batched CFG externally
+ if (generation_config.guidance_scale !== null && generation_config.guidance_scale > 1) {
+ processors.push(new ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale));
}
if (logits_processor !== null) {
@@ -951,287 +1115,473 @@ export class PreTrainedModel extends Callable {
/**
* This function merges multiple generation configs together to form a final generation config to be used by the model for text generation.
* It first creates an empty `GenerationConfig` object, then it applies the model's own `generation_config` property to it. Finally, if a `generation_config` object was passed in the arguments, it overwrites the corresponding properties in the final config with those of the passed config object.
- * @param {import('./utils/generation.js').GenerationConfigType} generation_config A `GenerationConfig` object containing generation parameters.
- * @returns {import('./utils/generation.js').GenerationConfigType} The final generation config object to be used by the model for text generation.
+ * @param {GenerationConfig|null} generation_config A `GenerationConfig` object containing generation parameters.
+ * @param {Object} kwargs Additional generation parameters to be used in place of those in the `generation_config` object.
+ * @returns {GenerationConfig} The final generation config object to be used by the model for text generation.
*/
- _get_generation_config(generation_config) {
+ _prepare_generation_config(generation_config, kwargs, cls = GenerationConfig) {
// Create empty generation config (contains defaults)
// We pass `this.config` so that if `eos_token_id` or `bos_token_id` exist in the model's config, we will use them
- let gen_config = new GenerationConfig(this.config);
+ const config = { ...this.config };
+ for (const key of ["decoder", "generator", "text_config"]) {
+ // Special case: some models have generation attributes set in the decoder.
+ // Use them if still unset in the generation config.
+ if (key in config) {
+ Object.assign(config, config[key]);
+ }
+ }
+
+ const gen_config = new cls(config);
// Apply model's generation config, if it exists
- if ('generation_config' in this) {
- Object.assign(gen_config, this.generation_config);
- }
+ Object.assign(gen_config, this.generation_config ?? {});
- // Finally, use any generation config specified by the user
+ // Next, use any generation config specified by the user
// when calling `generate`
- if (generation_config !== null) {
+ if (generation_config) {
Object.assign(gen_config, generation_config);
}
+
+ // Finally, if any kwargs were passed, use them to overwrite
+ if (kwargs) {
+ Object.assign(gen_config, pick(kwargs, Object.getOwnPropertyNames(gen_config)));
+ }
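+        // Resulting precedence (lowest to highest): model config defaults, the model's generation_config.json,
+        // the user-supplied `generation_config` object, and finally any call-time kwargs.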
+
return gen_config;
}
/**
- * @typedef {import('./utils/maths.js').TypedArray} TypedArray
+     * Builds the list of stopping criteria applied during generation.
+ * @param {GenerationConfig} generation_config
+ * @param {StoppingCriteriaList} [stopping_criteria=null]
*/
+ _get_stopping_criteria(generation_config, stopping_criteria = null) {
+ const criteria = new StoppingCriteriaList();
+
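+        // Generation stops once every sequence has either reached `max_length` or produced an EOS token;
+        // any user-supplied criteria are appended after these defaults.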
+ if (generation_config.max_length !== null) {
+ criteria.push(new MaxLengthCriteria(
+ generation_config.max_length,
+ this.config.max_position_embeddings ?? null,
+ ));
+ }
+ // if (generation_config.max_time !== null) {
+ // criteria.push(new MaxTimeCriteria(generation_config.max_time));
+ // }
+ if (generation_config.eos_token_id !== null) {
+ criteria.push(new EosTokenCriteria(generation_config.eos_token_id));
+ }
+
+ if (stopping_criteria) {
+ criteria.extend(stopping_criteria);
+ }
+ return criteria;
+
+ }
/**
- * @typedef {{ sequences: Tensor, decoder_attentions: Tensor, cross_attentions: Tensor }} EncoderDecoderOutput
- * @typedef {Object} DecoderOutput
- *
- * Generates text based on the given inputs and generation configuration using the model.
- * @param {Tensor|Array|TypedArray} inputs An array of input token IDs.
- * @param {Object|GenerationConfig|null} generation_config The generation configuration to use. If null, default configuration will be used.
- * @param {Object|null} logits_processor An optional logits processor to use. If null, a new LogitsProcessorList instance will be created.
- * @param {Object} options options
- * @param {Object} [options.inputs_attention_mask=null] An optional attention mask for the inputs.
- * @returns {Promise} An array of generated output sequences, where each sequence is an array of token IDs.
- * @throws {Error} Throws an error if the inputs array is empty.
- */
- async generate(
- inputs,
- generation_config = null,
- logits_processor = null,
- {
- inputs_attention_mask = null
- } = {},
- ) {
+ * Confirms that the model class is compatible with generation.
+ * If not, raises an exception that points to the right class to use.
+ */
+ _validate_model_class() {
if (!this.can_generate) {
+ const generate_compatible_mappings = [
+ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
+ // MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, // TODO
+ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
+ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
+ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
+ ];
+
const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor);
- let errorMessage = `The current model class (${modelName}) is not compatible with \`.generate()\`, as it doesn't have a language model head.`
+ const generate_compatible_classes = new Set();
const modelType = this.config.model_type;
- const possibleInfo =
- MODEL_WITH_LM_HEAD_MAPPING_NAMES.get(modelType)
- ?? MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES.get(modelType)
- ?? MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.get(modelType)
- // ?? MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES.get(modelType) // TODO
- ?? MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES.get(modelType);
-
- if (possibleInfo) {
- // TODO: support multiple possible classes
- errorMessage += ` Please use the following class instead: '${possibleInfo[0]}'`;
+ for (const model_mapping of generate_compatible_mappings) {
+ const supported_models = model_mapping.get(modelType);
+ if (supported_models) {
+ generate_compatible_classes.add(supported_models[0]);
+ }
+ }
+
+ let errorMessage = `The current model class (${modelName}) is not compatible with \`.generate()\`, as it doesn't have a language model head.`
+ if (generate_compatible_classes.size > 0) {
+ errorMessage += ` Please use the following class instead: ${[...generate_compatible_classes].join(', ')}`;
}
throw Error(errorMessage);
}
+ }
- if (!(inputs instanceof Tensor) && !isTypedArray(inputs) && !Array.isArray(inputs)) {
- throw Error(`\`inputs\` must be a Tensor, TypedArray, or Array, but is "${inputs.constructor.name}".`);
- }
+ prepare_inputs_for_generation(...args) {
+ return this._prepare_inputs_for_generation(this, ...args);
+ }
- let input_ids_seq_length;
+ /**
+     * Updates the model inputs in place for the next generation step (KV cache, new input ids, attention mask).
+ * @param {Object} inputs
+ * @param {bigint[][]} inputs.generated_input_ids
+ * @param {Object} inputs.outputs
+ * @param {Object} inputs.model_inputs
+ * @param {boolean} inputs.is_encoder_decoder
+ * @returns {Object} The updated model inputs for the next generation iteration.
+ */
+ _update_model_kwargs_for_generation({ generated_input_ids, outputs, model_inputs, is_encoder_decoder }) {
+ // update past_key_values
+ model_inputs['past_key_values'] = this.getPastKeyValues(outputs, model_inputs.past_key_values);
+
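+        // Only the most recently generated token per sequence is fed back in; earlier tokens are
+        // represented by the KV cache.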
+ // update inputs for next run
+ model_inputs['input_ids'] = new Tensor('int64', generated_input_ids.flat(), [generated_input_ids.length, 1]);
+
+ if (!is_encoder_decoder) {
+ // update attention mask
+ model_inputs.attention_mask = cat(
+ [
+ model_inputs.attention_mask,
+ ones([model_inputs.attention_mask.dims[0], 1]),
+ ], 1
+ );
+ } else if ('decoder_attention_mask' in model_inputs) {
+ // TODO: update decoder attention mask if the model requires it
+ }
- // Prepare `input_ids` which will be used for auto-regressive generation
- // TODO: Update to align with HF transformers' implementation
- if (this.config.is_encoder_decoder) {
- // Generating from the encoder outputs
- input_ids_seq_length = 0;
+ // force recreate position_ids in next iteration
+ model_inputs['position_ids'] = null;
- } else {
- input_ids_seq_length = inputs instanceof Tensor ? inputs.dims.at(-1) : inputs.length;
+ return model_inputs;
+ }
- // decoder-only
- if (input_ids_seq_length === 0) {
- throw Error("Must supply a non-empty array of input token ids.")
+ /**
+ * This function extracts the model-specific `inputs` for generation.
+ * @param {Object} params
+ * @param {Tensor} [params.inputs=null]
+ * @param {number} [params.bos_token_id=null]
+     * @param {Record<string, Tensor>} [params.model_kwargs]
+     * @returns {{inputs_tensor: Tensor, model_inputs: Record<string, Tensor>, model_input_name: string}} The model-specific inputs for generation.
+ */
+ _prepare_model_inputs({ inputs, bos_token_id, model_kwargs }) {
+ const model_inputs = pick(model_kwargs, this.forward_params);
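+        // Only kwargs whose names appear in `this.forward_params` are kept as model inputs; remaining kwargs are ignored here.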
+ const input_name = this.main_input_name;
+ if (input_name in model_inputs) {
+ if (inputs) {
+ throw new Error(
+ "`inputs`: {inputs}` were passed alongside {input_name} which is not allowed. " +
+ "Make sure to either pass {inputs} or {input_name}=..."
+ );
}
+ } else {
+ model_inputs[input_name] = inputs;
}
- // Update generation config with defaults
- generation_config = this._get_generation_config(generation_config);
-
- logits_processor = logits_processor ?? new LogitsProcessorList()
+ const inputs_tensor = model_inputs[input_name];
- // Update logits processor
- logits_processor = this._get_logits_processor(
- generation_config,
- input_ids_seq_length,
- logits_processor
- )
+ return { inputs_tensor, model_inputs, model_input_name: input_name };
+ }
- /** @type {number[]} */
- let eos_token_ids = generation_config.eos_token_id;
- if (eos_token_ids !== null && !Array.isArray(eos_token_ids)) {
- eos_token_ids = [eos_token_ids];
+ async _prepare_encoder_decoder_kwargs_for_generation({ inputs_tensor, model_inputs, model_input_name, generation_config }) {
+ if (
+ this.sessions['model'].inputNames.includes('inputs_embeds')
+ && !model_inputs.inputs_embeds
+ && '_prepare_inputs_embeds' in this
+ ) {
+ // Encoder expects `inputs_embeds` instead of `input_ids`
+ const { input_ids, pixel_values, attention_mask, ...kwargs } = model_inputs;
+ // @ts-ignore
+ const prepared_inputs = await this._prepare_inputs_embeds(model_inputs);
+ model_inputs = {
+ ...kwargs,
+ ...pick(prepared_inputs, ['inputs_embeds', 'attention_mask']),
+ };
}
+ let { last_hidden_state } = await encoderForward(this, model_inputs);
- // TODO implement early_stopping
- // https://huggingface.co/blog/how-to-generate
+ // for classifier free guidance we need to add a 'null' input to our encoder hidden states
+ if (generation_config.guidance_scale !== null && generation_config.guidance_scale > 1) {
- let numOutputTokens = 1;
- const maxOutputTokens = numOutputTokens + (generation_config.max_new_tokens ?? Infinity);
+ last_hidden_state = cat([
+ last_hidden_state,
+ full_like(last_hidden_state, 0.0),
+ ], 0);
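+            // The appended all-zeros batch is the unconditional ("null") input that the
+            // ClassifierFreeGuidanceLogitsProcessor contrasts with the conditional batch.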
- // Only use max length if max_new_tokens is not provided
- const useMaxLength = Number.isInteger(generation_config.max_length) && (generation_config.max_new_tokens ?? null) === null;
- let sampler = Sampler.getSampler(generation_config);
+ if ('attention_mask' in model_inputs) {
+ model_inputs['attention_mask'] = cat([
+ model_inputs['attention_mask'],
+ zeros_like(model_inputs['attention_mask']),
+ ], 0);
+ }
- // @ts-ignore
- let beams = this.getStartBeams(inputs, generation_config, numOutputTokens, inputs_attention_mask);
-
- while (beams.some(x => !x.done) && numOutputTokens < maxOutputTokens) {
- let newest_beams = [];
- for (let beam of beams) {
- if (beam.done) {
- // Add this beam back into the pool
- newest_beams.push(beam);
- continue
- }
- if (useMaxLength && beam.output_token_ids.length >= generation_config.max_length) {
- // Set this beam to done and add it back into the pool
- beam.done = true;
- newest_beams.push(beam);
- continue
+ } else if (model_inputs.decoder_input_ids) {
+ // Ensure that the encoder outputs have the same batch size as the decoder inputs,
+ // allowing for more efficient batched generation for single inputs
+ const decoder_input_ids_batch_size = toI64Tensor(model_inputs.decoder_input_ids).dims[0];
+ if (decoder_input_ids_batch_size !== last_hidden_state.dims[0]) {
+ if (last_hidden_state.dims[0] !== 1) {
+ throw new Error(
+ `The encoder outputs have a different batch size (${last_hidden_state.dims[0]}) than the decoder inputs (${decoder_input_ids_batch_size}).`
+ )
}
+ last_hidden_state = cat(Array.from({ length: decoder_input_ids_batch_size }, () => last_hidden_state), 0);
+ }
+ }
+ model_inputs['encoder_outputs'] = last_hidden_state;
- // @ts-ignore
- let output = await this.runBeam(beam);
+ return model_inputs;
+ }
- // add attentions/scores to beam only if user requested
- if (generation_config.output_attentions) {
- this.addAttentionsToBeam(beam, output);
- }
- if (generation_config.output_scores) {
- // TODO add
+ /**
+ * Prepares `decoder_input_ids` for generation with encoder-decoder models
+ * @param {*} param0
+ */
+ _prepare_decoder_input_ids_for_generation({ batch_size, model_input_name, model_kwargs, decoder_start_token_id, bos_token_id, generation_config }) {
+ let { decoder_input_ids, ...model_inputs } = model_kwargs;
+
+ // Prepare input ids if the user has not defined `decoder_input_ids` manually.
+ if (!decoder_input_ids) {
+ decoder_start_token_id ??= bos_token_id;
+
+ if (this.config.model_type === 'musicgen') {
+ // Custom logic (TODO: move to Musicgen class)
+ decoder_input_ids = Array.from({
+ length: batch_size * this.config.decoder.num_codebooks
+ }, () => [decoder_start_token_id]);
+
+ } else if (Array.isArray(decoder_start_token_id)) {
+ if (decoder_start_token_id.length !== batch_size) {
+ throw new Error(
+                        `\`decoder_start_token_id\` expected to have length ${batch_size} but got ${decoder_start_token_id.length}`
+ )
}
+ decoder_input_ids = decoder_start_token_id;
+ } else {
+ decoder_input_ids = Array.from({
+ length: batch_size,
+ }, () => [decoder_start_token_id]);
+ }
+ } else if (!Array.isArray(decoder_input_ids[0])) {
+ // Correct batch size
+ decoder_input_ids = Array.from({
+ length: batch_size,
+ }, () => decoder_input_ids);
+ }
- // Logits are of the form [batch_size, out_seq_length, vocab_size]
- // In most cases, this will be [batch_size, 1, vocab_size]
- // So, we select the last token's logits:
- // (equivalent to `logits = outputs.logits[:, -1, :]`)
- let logits = output.logits.slice(null, -1, null);
+ decoder_input_ids = toI64Tensor(decoder_input_ids);
+ model_kwargs['decoder_attention_mask'] = ones_like(decoder_input_ids);
- // Apply logits processor
- logits_processor(beam.output_token_ids, logits);
+ return { input_ids: decoder_input_ids, model_inputs };
+ }
- let sampledTokens = sampler(logits);
- for (let [newTokenId, logProb] of sampledTokens) {
- // use previous beam as a starting point
- let newBeam = { ...beam };
+ /**
+ * Generates sequences of token ids for models with a language modeling head.
+ * @param {import('./generation/parameters.js').GenerationFunctionParameters} options
+     * @returns {Promise<ModelOutput|Tensor>} The output of the model, which can contain the generated token ids, attentions, and scores.
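+     *
+     * @example
+     * // Illustrative usage sketch (assumes `tokenizer` and `model` were loaded elsewhere via `from_pretrained`):
+     * const { input_ids, attention_mask } = tokenizer('Hello');
+     * const output_ids = await model.generate({ input_ids, attention_mask, max_new_tokens: 20 });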
+ */
+ async generate({
+ inputs = null,
+ generation_config = null,
+ logits_processor = null,
+ stopping_criteria = null,
+ streamer = null,
- // update new beam
- // @ts-ignore
- this.updateBeam(newBeam, newTokenId);
+ // inputs_attention_mask = null,
+ ...kwargs
+ }) {
+ this._validate_model_class();
- newBeam.score += logProb;
+ // Update generation config with defaults and kwargs
+ generation_config = this._prepare_generation_config(generation_config, kwargs);
- if (eos_token_ids && eos_token_ids.includes(newTokenId)) {
- newBeam.done = true;
- }
+ // 3. Define model inputs
+ let { inputs_tensor, model_inputs, model_input_name } = this._prepare_model_inputs({
+ inputs,
+ model_kwargs: kwargs,
+ });
- newest_beams.push(newBeam);
- }
- }
- ++numOutputTokens;
+ const is_encoder_decoder = this.config.is_encoder_decoder;
- // Next, we get the best beams, per ID
- newest_beams = this.groupBeams(newest_beams).map(
- group => group
- .sort((a, b) => b.score - a.score) // sort by score
- .slice(0, generation_config.num_beams) // remove outside beam width
- );
+ // 4. Define other model kwargs
+ if (!is_encoder_decoder) {
+ // decoder-only models should use left-padding for generation
+ } else if (!('encoder_outputs' in model_inputs)) {
+ // if model is encoder decoder encoder_outputs are created
+ // and added to `model_kwargs`
+ model_inputs = await this._prepare_encoder_decoder_kwargs_for_generation(
+ { inputs_tensor, model_inputs, model_input_name, generation_config }
+ )
+ }
- // Flatten beams
- beams = newest_beams.flat();
+ // 5. Prepare `input_ids` which will be used for auto-regressive generation
+ // TODO: Update to align with HF transformers' implementation
+ let input_ids;
+ if (is_encoder_decoder) {
+ // Generating from the encoder outputs
+ ({ input_ids, model_inputs } = this._prepare_decoder_input_ids_for_generation({
+ batch_size: model_inputs[model_input_name].dims.at(0),
+ model_input_name,
+ model_kwargs: model_inputs,
+ decoder_start_token_id: generation_config.decoder_start_token_id,
+ bos_token_id: generation_config.bos_token_id,
+ generation_config,
+ }));
+ } else {
+ input_ids = model_inputs[model_input_name]
+ }
- // Run callback
- if (generation_config.callback_function) {
- generation_config.callback_function(beams);
- }
+ // 6. Prepare `max_length` depending on other stopping criteria.
+ let input_ids_length = input_ids.dims.at(-1);
+
+ if (generation_config.max_new_tokens !== null) {
+ generation_config.max_length = input_ids_length + generation_config.max_new_tokens;
}
- // TODO: Ensure that we can return non-batched outputs
+ // input_ids_length = model_inputs[model_input_name].dims.at(1);
+ // // inputs instanceof Tensor ? : inputs.length;
- const groupedBeams = this.groupBeams(beams);
+ // // decoder-only
+ // if (input_ids_length === 0) {
+ // throw Error("Must supply a non-empty array of input token ids.")
+ // }
- const getFlattened = (key) => groupedBeams.map(
- batch => {
- if (generation_config.num_return_sequences > 1) {
- return batch.slice(0, generation_config.num_return_sequences).map(x => x[key]);
- } else {
- return [batch[0][key]];
- }
- }
- ).flat(); // Flatten across batches (depth=1)
+ // let decoder_input_ids =
+ // generation_config.decoder_input_ids
+ // ?? generation_config.decoder_start_token_id
+ // ?? generation_config.bos_token_id
+ // ?? generation_config.eos_token_id;
- const sequences = getFlattened('output_token_ids'); // [1, seqLength]
+ // Update logits processor
+ // 8. prepare distribution pre_processing samplers
+ const prepared_logits_processor = this._get_logits_processor(
+ generation_config,
+ input_ids_length,
+ logits_processor,
+ )
- if (generation_config.return_dict_in_generate) {
- // NOTE: `decoder_attentions` and `cross_attentions` should be:
- // list (one element for each generated token)
- // of list (one element for each layer of the decoder)
- // of torch.FloatTensor of shape (batch_size, num_heads, generated_length, sequence_length)
- // However, since we are only generating one batch at a time, they are of the form:
- // list (batches)
- // of list (one element for each generated token)
- // of list (one element for each layer of the decoder)
- // of torch.FloatTensor of shape (1, num_heads, generated_length, sequence_length)
- //
- // TODO: In future (when true parallelism, we should be able to return the correct shape)
-
- const decoder_attentions = getFlattened('decoder_attentions');
- const cross_attentions = getFlattened('cross_attentions');
+ // 9. prepare stopping criteria
+ const prepared_stopping_criteria = this._get_stopping_criteria(
+ generation_config, stopping_criteria
+ )
- return {
- sequences,
+ // /** @type {number[]} */
+ // let eos_token_ids = generation_config.eos_token_id;
+ // if (eos_token_ids !== null && !Array.isArray(eos_token_ids)) {
+ // eos_token_ids = [eos_token_ids];
+ // }
- decoder_attentions,
- cross_attentions,
- }
- } else {
- return sequences;
+ const numInputs = model_inputs[model_input_name].dims.at(0);
+
+ // TODO:
+ // done is a list of booleans to keep track of which inputs are done
+ // const done = new Array(numInputs).fill(false);
+ // For efficiency purposes, we remove completed rows from model_inputs
+ // when the beam is complete, and we keep track of the row index
+ // const rowIndexToBatchIndex = new Map();
+
+ const sampler = LogitsSampler.getSampler(generation_config);
+
+ // TODO make > numInputs
+ const scores = new Array(numInputs).fill(0);
+ /** @type {bigint[][]} */
+ const all_input_ids = input_ids.tolist();
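+        // Generated ids are accumulated as plain bigint arrays and only converted back into a
+        // Tensor once generation has finished (see `sequences` below).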
+ if (streamer) {
+ streamer.put(all_input_ids);
}
- }
+ // const all_generated_input_ids = Array.from({ length: numInputs }, () => []);
+
+ // NOTE: For now, we don't support spawning new beams
+ // TODO: when we do, we simply copy past key values and accumulate into single large tensor
+
+ ////////////////////////////////////////////////////
+ // Generic search which handles 4 generation modes:
+ // - GenerationMode.GREEDY_SEARCH
+ // - GenerationMode.SAMPLE
+ // - GenerationMode.BEAM_SEARCH
+ // - GenerationMode.BEAM_SAMPLE
+ ////////////////////////////////////////////////////
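+        // Each iteration: trim the inputs to the not-yet-processed tokens, run a forward pass,
+        // apply the logits processors, sample one token per batch row, check the stopping criteria,
+        // and update the model inputs (KV cache, new input ids, attention mask) for the next step.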
+ let outputs;
+ let attentions = {};
+ while (true) {
+ // prepare model inputs
+ model_inputs = this.prepare_inputs_for_generation(all_input_ids, model_inputs, generation_config);
+ outputs = await this.forward(model_inputs);
+
+ if (generation_config.output_attentions && generation_config.return_dict_in_generate) {
+ // Get attentions if they are present
+ const token_attentions = this.getAttentions(outputs);
+ for (const key in token_attentions) {
+ if (!(key in attentions)) {
+ attentions[key] = [];
+ }
+ attentions[key].push(token_attentions[key]);
+ }
+ }
- /**
- * Helper function to add attentions to beam
- * @param {Object} beam
- * @param {Object} output
- * @private
- */
- addAttentionsToBeam(beam, output) {
- if (this.config.is_encoder_decoder) {
- if (!output.cross_attentions || output.cross_attentions.length === 0) {
- throw Error(
- "`output_attentions` is true, but the model did not produce cross-attentions. " +
- "This is most likely because the model was not exported with `output_attentions=True`."
- )
+ // Logits are of the form [batch_size, out_seq_length, vocab_size]
+ // In most cases, this will be [batch_size, 1, vocab_size]
+ // So, we select the last token's logits:
+ // (equivalent to `logits = outputs.logits[:, -1, :]`)
+ const logits = outputs.logits.slice(null, -1, null);
+
+ const next_tokens_scores = prepared_logits_processor(all_input_ids, logits);
+
+ /** @type {[bigint][]} */
+ const generated_input_ids = [];
+ // const new_kv_cache = [];// NOTE: Only used for beam search when concatenating new kv
+ // Loop over each batch
+ for (let batch_idx = 0; batch_idx < next_tokens_scores.dims.at(0); ++batch_idx) {
+ const logs = next_tokens_scores[batch_idx];
+
+ const sampledTokens = await sampler(logs);
+ for (const [newTokenId, logProb] of sampledTokens) {
+ const bigint = BigInt(newTokenId);
+ // TODO: If branching, use previous beam as a starting point
+ // update generated ids, model inputs, and length for next step
+ scores[batch_idx] += logProb;
+ all_input_ids[batch_idx].push(bigint);
+ generated_input_ids.push([bigint]);
+
+ // TODO: Support beam search
+ break;
+ }
}
- if (!beam.cross_attentions) {
- beam.cross_attentions = [];
+ if (streamer) {
+ streamer.put(generated_input_ids);
}
- beam.cross_attentions.push(output.cross_attentions);
- }
- if (!output.decoder_attentions || output.decoder_attentions.length === 0) {
- throw Error(
- "`output_attentions` is true, but the model did not produce decoder-attentions. " +
- "This is most likely because the model was not exported with `output_attentions=True`."
- )
+ const stop = prepared_stopping_criteria(all_input_ids);
+ if (stop.every(x => x)) {
+ break;
+ }
+
+ model_inputs = this._update_model_kwargs_for_generation({
+ generated_input_ids, outputs, model_inputs, is_encoder_decoder,
+ });
}
- if (!beam.decoder_attentions) {
- beam.decoder_attentions = [];
+
+ if (streamer) {
+ streamer.end();
}
- beam.decoder_attentions.push(output.decoder_attentions);
- }
- /**
- * Groups an array of beam objects by their ids.
- *
- * @param {Array} beams The array of beam objects to group.
- * @returns {Array} An array of arrays, where each inner array contains beam objects with the same id.
- */
- groupBeams(beams) {
- // Group beams by their ids
- const groups = Object.create(null);
- for (const obj of beams) {
- if (groups[obj.id] === undefined) {
- groups[obj.id] = [obj];
- } else {
- groups[obj.id].push(obj);
+ // Retrieve and dispose all final past key values (including encoder attentions)
+ const past_key_values = this.getPastKeyValues(outputs, model_inputs.past_key_values, true);
+
+ // TODO: ensure all_input_ids is padded correctly...
+ const sequences = new Tensor('int64', all_input_ids.flat(), [all_input_ids.length, all_input_ids[0].length]);
+
+ if (generation_config.return_dict_in_generate) {
+ return {
+ sequences,
+ past_key_values,
+ ...attentions,
+ // TODO:
+ // scores,
+ // logits,
}
+ } else {
+ // Dispose all remaining tensors
+ for (const tensor of Object.values(outputs)) {
+ if (tensor.location === 'gpu-buffer') {
+ tensor.dispose();
+ }
+ }
+ return sequences;
}
-
- return Object.values(groups);
}
/**
@@ -1241,47 +1591,55 @@ export class PreTrainedModel extends Callable {
* @param {Object} pastKeyValues The previous past key values.
* @returns {Object} An object containing past key values.
*/
- getPastKeyValues(decoderResults, pastKeyValues) {
-
+ getPastKeyValues(decoderResults, pastKeyValues, disposeEncoderPKVs = false) {
const pkvs = Object.create(null);
for (const name in decoderResults) {
if (name.startsWith('present')) {
- let newName = name.replace('present', 'past_key_values');
-
- if (pastKeyValues && name.includes('encoder')) {
- // Optimization introduced by optimum to reuse past key values. So, we just replace the constant
- // outputs with the previous past key values.
+ const newName = name.replace('present', 'past_key_values');
+ const is_encoder_pkv = name.includes('encoder');
+ if (is_encoder_pkv && pastKeyValues) {
+ // Optimization introduced by optimum to reuse past key values.
+ // So, we just replace the constant outputs (`decoderResults[name]`) with the previous past key values.
// https://github.com/huggingface/optimum/blob/0bf2c05fb7e1182b52d21b703cfc95fd9e4ea3dc/optimum/onnxruntime/base.py#L677-L704
pkvs[newName] = pastKeyValues[newName];
- } else {
+ } else { // decoder or using first encoder PKVs
pkvs[newName] = decoderResults[name];
}
+
+ if (pastKeyValues && (!is_encoder_pkv || disposeEncoderPKVs)) {
+ // - Always dispose decoder PKVs
+ // - Only dispose encoder past key values when requested (after generation)
+ const t = pastKeyValues[newName];
+ if (t.location === 'gpu-buffer') {
+ t.dispose();
+ }
+ }
}
}
return pkvs;
}
/**
- * Returns an object containing attentions from the given decoder results object.
+ * Returns an object containing attentions from the given model output object.
*
- * @param {Object} decoderResults The decoder results object.
- * @returns {Object} An object containing attentions.
+ * @param {Object} model_output The output of the model.
+ * @returns {{cross_attentions?: Tensor[]}} An object containing attentions.
*/
- getAttentions(decoderResults) {
- const attns = Object.create(null);
+ getAttentions(model_output) {
+ const attentions = {};
- for (const attnName of ['cross_attentions', 'decoder_attentions']) {
- const result = [];
- for (const name in decoderResults) {
+ for (const attnName of ['cross_attentions', 'encoder_attentions', 'decoder_attentions']) {
+ for (const name in model_output) {
if (name.startsWith(attnName)) {
- const index = name.split('.').pop()
- result[index] = decoderResults[name];
+ if (!(attnName in attentions)) {
+ attentions[attnName] = [];
+ }
+ attentions[attnName].push(model_output[name]);
}
}
- attns[attnName] = result;
}
- return attns;
+ return attentions;
}
/**
@@ -1294,93 +1652,34 @@ export class PreTrainedModel extends Callable {
if (pastKeyValues) {
Object.assign(decoderFeeds, pastKeyValues)
} else {
- // TODO support batches (i.e., batch_size > 1)
- const batch_size = 1;
+ const session = this.sessions['decoder_model_merged'] ?? this.sessions['model'];
+ const dtype = session?.config?.kv_cache_dtype ?? 'float32';
+ const empty = (dtype === 'float16') ? new Uint16Array() : [];
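+            // Zero-length tensors of the session's KV-cache dtype act as the "empty" past key/values
+            // for the first forward pass; their shapes are provided by `getKeyValueShapes(this.config)`.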
- // @ts-ignore
- if (this.config.is_encoder_decoder && (this.add_encoder_pkv ?? true)) {
- // @ts-ignore
- let encoder_dims = [batch_size, this.num_encoder_heads, 0, this.encoder_dim_kv];
- // @ts-ignore
- let decoder_dims = [batch_size, this.num_decoder_heads, 0, this.decoder_dim_kv];
- // @ts-ignore
- for (let i = 0; i < this.num_decoder_layers; ++i) {
- decoderFeeds[`past_key_values.${i}.encoder.key`] = new Tensor('float32', [], encoder_dims)
- decoderFeeds[`past_key_values.${i}.encoder.value`] = new Tensor('float32', [], encoder_dims)
- decoderFeeds[`past_key_values.${i}.decoder.key`] = new Tensor('float32', [], decoder_dims)
- decoderFeeds[`past_key_values.${i}.decoder.value`] = new Tensor('float32', [], decoder_dims)
- }
- } else if (this.config.model_type === 'falcon') {
- // NOTE: Custom implementation for Falcon
- // @ts-ignore
- let dims = [batch_size * this.num_heads, 0, this.dim_kv]
- // @ts-ignore
- for (let i = 0; i < this.num_layers; ++i) {
- decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], dims)
- decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], dims)
- }
- } else if (this.config.multi_query) { // e.g., for `gpt_bigcode`
- // @ts-ignore
- let dims = [batch_size * this.num_heads, 0, 2 * this.dim_kv]
- // @ts-ignore
- for (let i = 0; i < this.num_layers; ++i) {
- decoderFeeds[`past_key_values.${i}.key_value`] = new Tensor('float32', [], dims)
- }
- } else if (this.config.model_type === 'bloom') {
- // NOTE: Custom implementation for Bloom
-
- // @ts-ignore
- let keyDims = [batch_size * this.num_heads, this.dim_kv, 0] // [batch_size x num_heads,64,past_sequence_length]
- // @ts-ignore
- let valueDims = [batch_size * this.num_heads, 0, this.dim_kv] // [batch_size x num_heads,past_sequence_length,64]
- // @ts-ignore
- for (let i = 0; i < this.num_layers; ++i) {
- decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], keyDims)
- decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], valueDims)
- }
- } else { // Decoder-only
- // @ts-ignore
- let dims = [batch_size, this.num_heads, 0, this.dim_kv]
- // @ts-ignore
- for (let i = 0; i < this.num_layers; ++i) {
- decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], dims)
- decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], dims)
- }
+ const shapes = getKeyValueShapes(this.config);
+
+ for (const name in shapes) {
+ decoderFeeds[name] = new Tensor(dtype, empty, shapes[name]);
}
}
}
- /**
- * Initializes and returns the beam for text generation task
- * @param {Tensor} inputTokenIds The input token ids.
- * @param {Object} generation_config The generation config.
- * @param {number} numOutputTokens The number of tokens to be generated.
- * @param {Tensor} inputs_attention_mask Optional input attention mask.
- * @returns {any} A Beam object representing the initialized beam.
- * @private
- */
- getStartBeams(inputTokenIds, generation_config, numOutputTokens, inputs_attention_mask) {
- return this._getStartBeams(this, inputTokenIds, generation_config, numOutputTokens, inputs_attention_mask)
- }
-
- /**
- * Runs a single step of the beam search generation algorithm.
- * @param {any} beam The current beam being generated.
- * @returns {Promise} The updated beam after a single generation step.
- * @private
- */
- async runBeam(beam) {
- return await this._runBeam(this, beam);
+ async encode_image({ pixel_values }) {
+ // image_inputs === { pixel_values }
+ const features = (await sessionRun(this.sessions['vision_encoder'], { pixel_values })).image_features;
+ if (!this.config.num_image_tokens) {
+ console.warn(
+ 'The number of image tokens was not set in the model configuration. ' +
+ `Setting it to the number of features detected by the vision encoder (${features.dims[1]}).`
+ )
+ this.config.num_image_tokens = features.dims[1];
+ }
+ return features;
}
- /**
- * Update a beam with a new token ID.
- * @param {Object} beam The beam to update.
- * @param {number} newTokenId The new token ID to add to the beam's output.
- * @private
- */
- updateBeam(beam, newTokenId) {
- return this._updateBeam(beam, newTokenId);
+ async encode_text({ input_ids }) {
+ // text_inputs === { input_ids, attention_mask }
+ return (await sessionRun(this.sessions['embed_tokens'], { input_ids })).inputs_embeds;
}
}
@@ -2238,36 +2537,23 @@ export class AlbertForMaskedLM extends AlbertPreTrainedModel {
//////////////////////////////////////////////////
// T5 models
-export class T5PreTrainedModel extends PreTrainedModel { };
+export class T5PreTrainedModel extends PreTrainedModel {
+ forward_params = [
+ 'input_ids',
+ 'attention_mask',
+ 'encoder_outputs',
+ 'decoder_input_ids',
+ 'decoder_attention_mask',
+ 'past_key_values',
+ ];
+};
export class T5Model extends T5PreTrainedModel { }
/**
* T5Model is a class representing a T5 model for conditional generation.
*/
-export class T5ForConditionalGeneration extends T5PreTrainedModel {
-
- /**
- * Creates a new instance of the `T5ForConditionalGeneration` class.
- * @param {Object} config The model configuration.
- * @param {any} session session for the model.
- * @param {any} decoder_merged_session session for the decoder.
- * @param {GenerationConfig} generation_config The generation configuration.
- */
- constructor(config, session, decoder_merged_session, generation_config) {
- super(config, session);
- this.decoder_merged_session = decoder_merged_session;
- this.generation_config = generation_config;
-
- this.num_decoder_layers = this.config.num_decoder_layers;
- this.num_decoder_heads = this.config.num_heads;
- this.decoder_dim_kv = this.config.d_kv;
-
- this.num_encoder_layers = this.config.num_layers;
- this.num_encoder_heads = this.config.num_heads;
- this.encoder_dim_kv = this.config.d_kv;
- }
-}
+export class T5ForConditionalGeneration extends T5PreTrainedModel { }
//////////////////////////////////////////////////
@@ -2286,28 +2572,7 @@ export class LongT5Model extends LongT5PreTrainedModel { }
/**
* LONGT5 Model with a `language modeling` head on top.
*/
-export class LongT5ForConditionalGeneration extends LongT5PreTrainedModel {
- /**
- * Creates a new instance of the `LongT5ForConditionalGeneration` class.
- * @param {Object} config The model configuration.
- * @param {any} session session for the model.
- * @param {any} decoder_merged_session session for the decoder.
- * @param {GenerationConfig} generation_config The generation configuration.
- */
- constructor(config, session, decoder_merged_session, generation_config) {
- super(config, session);
- this.decoder_merged_session = decoder_merged_session;
- this.generation_config = generation_config;
-
- this.num_decoder_layers = this.config.num_decoder_layers;
- this.num_decoder_heads = this.config.num_heads;
- this.decoder_dim_kv = this.config.d_kv;
-
- this.num_encoder_layers = this.config.num_layers;
- this.num_encoder_heads = this.config.num_heads;
- this.encoder_dim_kv = this.config.d_kv;
- }
-}
+export class LongT5ForConditionalGeneration extends LongT5PreTrainedModel { }
//////////////////////////////////////////////////
@@ -2320,29 +2585,7 @@ export class MT5Model extends MT5PreTrainedModel { }
/**
* A class representing a conditional sequence-to-sequence model based on the MT5 architecture.
*/
-export class MT5ForConditionalGeneration extends MT5PreTrainedModel {
-
- /**
- * Creates a new instance of the `MT5ForConditionalGeneration` class.
- * @param {any} config The model configuration.
- * @param {any} session The ONNX session containing the encoder weights.
- * @param {any} decoder_merged_session The ONNX session containing the merged decoder weights.
- * @param {GenerationConfig} generation_config The generation configuration.
- */
- constructor(config, session, decoder_merged_session, generation_config) {
- super(config, session);
- this.decoder_merged_session = decoder_merged_session;
- this.generation_config = generation_config;
-
- this.num_decoder_layers = this.config.num_decoder_layers;
- this.num_decoder_heads = this.config.num_heads;
- this.decoder_dim_kv = this.config.d_kv;
-
- this.num_encoder_layers = this.config.num_layers;
- this.num_encoder_heads = this.config.num_heads;
- this.encoder_dim_kv = this.config.d_kv;
- }
-}
+export class MT5ForConditionalGeneration extends MT5PreTrainedModel { }
//////////////////////////////////////////////////
//////////////////////////////////////////////////
@@ -2357,30 +2600,7 @@ export class BartModel extends BartPretrainedModel { }
/**
* The BART Model with a language modeling head. Can be used for summarization.
*/
-export class BartForConditionalGeneration extends BartPretrainedModel {
-
- /**
- * Creates a new instance of the `BartForConditionalGeneration` class.
- * @param {Object} config The configuration object for the Bart model.
- * @param {Object} session The ONNX session used to execute the model.
- * @param {Object} decoder_merged_session The ONNX session used to execute the decoder.
- * @param {Object} generation_config The generation configuration object.
- */
- constructor(config, session, decoder_merged_session, generation_config) {
- super(config, session);
- this.decoder_merged_session = decoder_merged_session;
- this.generation_config = generation_config;
-
- this.num_decoder_layers = this.config.decoder_layers;
- this.num_decoder_heads = this.config.decoder_attention_heads;
- this.decoder_dim_kv = this.config.d_model / this.num_decoder_heads;
-
- this.num_encoder_layers = this.config.encoder_layers;
- this.num_encoder_heads = this.config.encoder_attention_heads;
- this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads;
- }
-
-}
+export class BartForConditionalGeneration extends BartPretrainedModel { }
/**
* Bart model with a sequence classification/head on top (a linear layer on top of the pooled output)
@@ -2411,30 +2631,7 @@ export class MBartModel extends MBartPreTrainedModel { }
/**
* The MBART Model with a language modeling head. Can be used for summarization, after fine-tuning the pretrained models.
*/
-export class MBartForConditionalGeneration extends MBartPreTrainedModel {
-
- /**
- * Creates a new instance of the `MBartForConditionalGeneration` class.
- * @param {Object} config The configuration object for the Bart model.
- * @param {Object} session The ONNX session used to execute the model.
- * @param {Object} decoder_merged_session The ONNX session used to execute the decoder.
- * @param {Object} generation_config The generation configuration object.
- */
- constructor(config, session, decoder_merged_session, generation_config) {
- super(config, session);
- this.decoder_merged_session = decoder_merged_session;
- this.generation_config = generation_config;
-
- this.num_decoder_layers = this.config.decoder_layers;
- this.num_decoder_heads = this.config.decoder_attention_heads;
- this.decoder_dim_kv = this.config.d_model / this.num_decoder_heads;
-
- this.num_encoder_layers = this.config.encoder_layers;
- this.num_encoder_heads = this.config.encoder_attention_heads;
- this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads;
- }
-
-}
+export class MBartForConditionalGeneration extends MBartPreTrainedModel { }
/**
* MBart model with a sequence classification/head on top (a linear layer on top of the pooled output).
@@ -2452,26 +2649,7 @@ export class MBartForSequenceClassification extends MBartPreTrainedModel {
}
-export class MBartForCausalLM extends MBartPreTrainedModel {
- /**
- * Creates a new instance of the `MBartForCausalLM` class.
- * @param {Object} config Configuration object for the model.
- * @param {Object} decoder_merged_session ONNX Session object for the decoder.
- * @param {Object} generation_config Configuration object for the generation process.
- */
- constructor(config, decoder_merged_session, generation_config) {
- super(config, decoder_merged_session);
- this.generation_config = generation_config;
-
- this.num_decoder_layers = this.config.decoder_layers;
- this.num_decoder_heads = this.config.decoder_attention_heads;
- this.decoder_dim_kv = this.config.d_model / this.num_decoder_heads;
-
- this.num_encoder_layers = this.config.encoder_layers;
- this.num_encoder_heads = this.config.encoder_attention_heads;
- this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads;
- }
-}
+export class MBartForCausalLM extends MBartPreTrainedModel { }
//////////////////////////////////////////////////
@@ -2487,29 +2665,7 @@ export class BlenderbotModel extends BlenderbotPreTrainedModel { }
/**
* The Blenderbot Model with a language modeling head. Can be used for summarization.
*/
-export class BlenderbotForConditionalGeneration extends BlenderbotPreTrainedModel {
-
- /**
- * Creates a new instance of the `BlenderbotForConditionalGeneration` class.
- * @param {any} config The model configuration.
- * @param {any} session The ONNX session containing the encoder weights.
- * @param {any} decoder_merged_session The ONNX session containing the merged decoder weights.
- * @param {GenerationConfig} generation_config The generation configuration.
- */
- constructor(config, session, decoder_merged_session, generation_config) {
- super(config, session);
- this.decoder_merged_session = decoder_merged_session;
- this.generation_config = generation_config;
-
- this.num_decoder_layers = this.config.decoder_layers;
- this.num_decoder_heads = this.config.decoder_attention_heads;
- this.decoder_dim_kv = this.config.d_model / this.num_decoder_heads;
-
- this.num_encoder_layers = this.config.encoder_layers;
- this.num_encoder_heads = this.config.encoder_attention_heads;
- this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads;
- }
-}
+export class BlenderbotForConditionalGeneration extends BlenderbotPreTrainedModel { }
//////////////////////////////////////////////////
@@ -2525,29 +2681,7 @@ export class BlenderbotSmallModel extends BlenderbotSmallPreTrainedModel { }
/**
* The BlenderbotSmall Model with a language modeling head. Can be used for summarization.
*/
-export class BlenderbotSmallForConditionalGeneration extends BlenderbotSmallPreTrainedModel {
-
- /**
- * Creates a new instance of the `BlenderbotForConditionalGeneration` class.
- * @param {any} config The model configuration.
- * @param {any} session The ONNX session containing the encoder weights.
- * @param {any} decoder_merged_session The ONNX session containing the merged decoder weights.
- * @param {GenerationConfig} generation_config The generation configuration.
- */
- constructor(config, session, decoder_merged_session, generation_config) {
- super(config, session);
- this.decoder_merged_session = decoder_merged_session;
- this.generation_config = generation_config;
-
- this.num_decoder_layers = this.config.decoder_layers;
- this.num_decoder_heads = this.config.decoder_attention_heads;
- this.decoder_dim_kv = this.config.d_model / this.num_decoder_heads;
-
- this.num_encoder_layers = this.config.encoder_layers;
- this.num_encoder_heads = this.config.encoder_attention_heads;
- this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads;
- }
-}
+export class BlenderbotSmallForConditionalGeneration extends BlenderbotSmallPreTrainedModel { }
//////////////////////////////////////////////////
@@ -2775,119 +2909,169 @@ export class ASTForAudioClassification extends ASTPreTrainedModel { }
//////////////////////////////////////////////////
// Whisper models
-export class WhisperPreTrainedModel extends PreTrainedModel { };
+export class WhisperPreTrainedModel extends PreTrainedModel {
+
+ requires_attention_mask = false;
+ main_input_name = 'input_features';
+ forward_params = [
+ 'input_features',
+ 'attention_mask',
+ 'decoder_input_ids',
+ 'decoder_attention_mask',
+ 'past_key_values',
+ ];
+};
/**
* WhisperModel class for training Whisper models without a language model head.
*/
export class WhisperModel extends WhisperPreTrainedModel { }
+
/**
* WhisperForConditionalGeneration class for generating conditional outputs from Whisper models.
*/
export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
- requires_attention_mask = false;
- main_input_name = 'input_features';
+ _prepare_generation_config(generation_config, kwargs) {
+ return /** @type {WhisperGenerationConfig} */ (super._prepare_generation_config(generation_config, kwargs, WhisperGenerationConfig));
+ }
/**
- * Creates a new instance of the `WhisperForConditionalGeneration` class.
- * @param {Object} config Configuration object for the model.
- * @param {Object} session ONNX Session object for the model.
- * @param {Object} decoder_merged_session ONNX Session object for the decoder.
- * @param {Object} generation_config Configuration object for the generation process.
- */
- constructor(config, session, decoder_merged_session, generation_config) {
- super(config, session);
- this.decoder_merged_session = decoder_merged_session;
- this.generation_config = generation_config;
+ *
+ * @param {WhisperGenerationConfig} generation_config
+ */
+ _retrieve_init_tokens(generation_config) {
+ // prefix tokens are of the form:
+ // - Multilingual: <|startoftranscript|> <|lang_id|> <|task|> [<|notimestamps|>]
+ // - English-only: <|startoftranscript|> [<|notimestamps|>]
+
+ // 1. Handle <|startoftranscript|> token
+ const init_tokens = [generation_config.decoder_start_token_id];
+
+ // 2. Handle <|lang_id|> and <|task> tokens
+ let language = generation_config.language;
+ const task = generation_config.task;
+ if (generation_config.is_multilingual) {
+ if (!language) {
+ // TODO: Implement language detection
+ console.warn('No language specified - defaulting to English (en).');
+ language = 'en';
+ }
- this.num_decoder_layers = this.config.decoder_layers;
- this.num_decoder_heads = this.config.decoder_attention_heads;
- this.decoder_dim_kv = this.config.d_model / this.num_decoder_heads;
+ // Add language token
+ const language_code = whisper_language_to_code(language);
+ const language_token = `<|${language_code}|>`;
+ init_tokens.push(generation_config.lang_to_id[language_token])
- this.num_encoder_layers = this.config.encoder_layers;
- this.num_encoder_heads = this.config.encoder_attention_heads;
- this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads;
- }
+ // Add task token
+ // NOTE: Defaults to 'transcribe' if no task is specified
+ init_tokens.push(generation_config.task_to_id[task ?? 'transcribe']);
- /**
- * @typedef {Object} WhisperGenerationConfig
- * @extends GenerationConfig
- * @property {boolean} [return_timestamps=null] Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`.
- * @property {boolean} [return_token_timestamps=null] Whether to return token-level timestamps
- * with the text. This can be used with or without the `return_timestamps` option. To get word-level
- * timestamps, use the tokenizer to group the tokens into words.
- * @property {number} [num_frames=null] The number of audio frames available in this chunk. This is only used generating word-level timestamps.
- */
+ } else if (language || task) {
+ throw new Error(
+ "Cannot specify `task` or `language` for an English-only model. If the model is intended to be multilingual, pass `is_multilingual=true` to generate, or update the generation config."
+ )
+ }
+
+ // 3. Handle <|notimestamps|> token
+ if (
+ !generation_config.return_timestamps
+ && generation_config.no_timestamps_token_id
+ && init_tokens.at(-1) !== generation_config.no_timestamps_token_id
+ ) {
+ init_tokens.push(generation_config.no_timestamps_token_id);
+ } else if (
+ generation_config.return_timestamps
+ && init_tokens.at(-1) === generation_config.no_timestamps_token_id
+ ) {
+ console.warn("<|notimestamps|> prompt token is removed from generation_config since `return_timestamps` is set to `true`.");
+ init_tokens.pop();
+ }
+
+ // Make sure we don't pass `null` tokens as prompt tokens
+ return init_tokens.filter(token => token != null);
+ }
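+
+ /**
+ * A rough sketch of the resulting prefix (token ids are illustrative, not tied to a
+ * specific checkpoint): a multilingual model asked to transcribe English without
+ * timestamps starts decoding from `<|startoftranscript|> <|en|> <|transcribe|> <|notimestamps|>`.
+ * ```javascript
+ * const init_tokens = this._retrieve_init_tokens({
+ *     decoder_start_token_id: 50258, // <|startoftranscript|>
+ *     is_multilingual: true,
+ *     language: 'en',
+ *     lang_to_id: { '<|en|>': 50259 },
+ *     task_to_id: { transcribe: 50359, translate: 50358 },
+ *     no_timestamps_token_id: 50363,
+ *     return_timestamps: false,
+ * });
+ * // init_tokens => [50258, 50259, 50359, 50363]
+ * ```
+ */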
/**
- * Generates outputs based on input and generation configuration.
- * @param {Object} inputs Input data for the model.
- * @param {WhisperGenerationConfig} generation_config Configuration object for the generation process.
- * @param {Object} logits_processor Optional logits processor object.
- * @returns {Promise} Promise object represents the generated outputs.
+ * Transcribes or translates log-mel input features to a sequence of auto-regressively generated token ids.
+ * @param {import('./models/whisper/generation_whisper.js').WhisperGenerationFunctionParameters} options
+ * @returns {Promise} The output of the model, which can contain the generated token ids, attentions, and scores.
*/
- async generate(
- inputs,
+ async generate({
+ inputs = null,
generation_config = null,
logits_processor = null,
- // {
- // return_timestamps = null,
- // return_token_timestamps = null,
- // language = null,
- // task = null,
- // } = {},
- ) {
- // Create generation config object
- generation_config = this._get_generation_config(generation_config);
+ stopping_criteria = null,
+ // Whisper-specific options (passed to kwargs)
+ // prompt_ids = null,
+ // language = null,
+ // task = null,
- // Whisper has additional options for returning timestamps
- generation_config.return_timestamps ??= false;
+ ...kwargs
+ }) {
+ generation_config = this._prepare_generation_config(generation_config, kwargs);
- // TODO add language and task
+ const init_tokens = kwargs.decoder_input_ids ?? this._retrieve_init_tokens(generation_config);
if (generation_config.return_timestamps) {
- logits_processor = [new WhisperTimeStampLogitsProcessor(generation_config)]
+ logits_processor ??= new LogitsProcessorList();
+ logits_processor.push(
+ new WhisperTimeStampLogitsProcessor(generation_config, init_tokens)
+ );
}
- if (generation_config.return_token_timestamps) {
- generation_config.output_attentions = true;
- generation_config.return_dict_in_generate = true;
-
- if (generation_config.task === 'translate') {
- console.warn("Token-level timestamps may not be reliable for task 'translate'.")
- }
+ if (generation_config.begin_suppress_tokens) {
+ logits_processor ??= new LogitsProcessorList();
+ logits_processor.push(
+ new SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, init_tokens.length)
+ );
+ }
+ if (generation_config.return_token_timestamps) {
if (!generation_config.alignment_heads) {
throw new Error(
"Model generation config has no `alignment_heads`, token-level timestamps not available. " +
"See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
)
}
+
+ if (generation_config.task === 'translate') {
+ console.warn("Token-level timestamps may not be reliable for task 'translate'.")
+ }
+
+ generation_config.output_attentions = true;
+ generation_config.return_dict_in_generate = true;
}
- const outputs = await super.generate(inputs, generation_config, logits_processor);
+ const outputs = await super.generate({
+ inputs,
+ generation_config,
+ logits_processor,
+ decoder_input_ids: init_tokens,
+ ...kwargs
+ });
- if (generation_config.return_token_timestamps && generation_config.alignment_heads) {
+ if (generation_config.return_token_timestamps) {
outputs["token_timestamps"] = this._extract_token_timestamps(
outputs,
generation_config.alignment_heads,
generation_config.num_frames,
- )
+ );
}
- return outputs
+ return outputs;
}
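+
+ /**
+ * A minimal usage sketch (checkpoint id and audio URL are illustrative; for everyday use,
+ * the `automatic-speech-recognition` pipeline wraps these steps):
+ * ```javascript
+ * import { AutoProcessor, WhisperForConditionalGeneration, read_audio } from '@huggingface/transformers';
+ *
+ * const model_id = 'Xenova/whisper-tiny.en';
+ * const processor = await AutoProcessor.from_pretrained(model_id);
+ * const model = await WhisperForConditionalGeneration.from_pretrained(model_id);
+ *
+ * // Read and preprocess audio
+ * const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
+ * const audio = await read_audio(url, processor.feature_extractor.config.sampling_rate);
+ * const { input_features } = await processor(audio);
+ *
+ * // Setting `return_timestamps: true` attaches the `WhisperTimeStampLogitsProcessor` above,
+ * // so the generated ids include timestamp tokens alongside the text tokens.
+ * const output_ids = await model.generate({ inputs: input_features, return_timestamps: true });
+ * ```
+ */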
/**
* Calculates token-level timestamps using the encoder-decoder cross-attentions and
* dynamic time-warping (DTW) to map each output token to a position in the input audio.
+ * If `num_frames` is specified, the encoder-decoder cross-attentions will be cropped before applying DTW.
* @param {Object} generate_outputs Outputs generated by the model
- * @param {Tensor[][][]} generate_outputs.cross_attentions The cross attentions output by the model
- * @param {Tensor[][][]} generate_outputs.decoder_attentions The decoder attentions output by the model
- * @param {number[][]} generate_outputs.sequences The sequences output by the model
+ * @param {Tensor[][]} generate_outputs.cross_attentions The cross attentions output by the model
+ * @param {Tensor} generate_outputs.sequences The sequences output by the model
* @param {number[][]} alignment_heads Alignment heads of the model
* @param {number} [num_frames=null] Number of frames in the input audio.
* @param {number} [time_precision=0.02] Precision of the timestamps in seconds
@@ -2900,6 +3084,12 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
"This is most likely because the model was not exported with `output_attentions=True`."
)
}
+ if (num_frames == null) {
+ console.warn(
+ "`num_frames` has not been set, meaning the entire audio will be analyzed. " +
+ "This may lead to inaccurate token-level timestamps for short audios (< 30 seconds)."
+ );
+ }
let median_filter_width = this.config.median_filter_width;
if (median_filter_width === undefined) {
@@ -2907,53 +3097,55 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
median_filter_width = 7;
}
- const batchedMatrices = generate_outputs.cross_attentions.map(batch => {
- // Create a list with `decoder_layers` elements, each a tensor of shape
- // (batch size, attention_heads, output length, input length).
- let cross_attentions = Array.from({ length: this.config.decoder_layers },
- (_, i) => cat(batch.map(x => x[i]), 2)
- );
-
- let weights = stack(alignment_heads.map(([l, h]) => {
- return num_frames
- ? cross_attentions[l].slice(null, h, null, [0, num_frames])
- : cross_attentions[l].slice(null, h);
- }));
- weights = weights.transpose(1, 0, 2, 3)
+ // TODO: Improve batch processing
+ const batch = generate_outputs.cross_attentions;
+ // Create a list with `decoder_layers` elements, each a tensor of shape
+ // (batch size, attention_heads, output length, input length).
+ const cross_attentions = Array.from({ length: this.config.decoder_layers },
+ // Concatenate the cross attentions for each layer across sequence length dimension.
+ (_, i) => cat(batch.map(x => x[i]), 2)
+ );
- let [std, calculatedMean] = std_mean(weights, -2, 0, true);
+ const weights = stack(alignment_heads.map(([l, h]) => {
+ if (l >= cross_attentions.length) {
+ throw new Error(`Layer index ${l} is out of bounds for cross attentions (length ${cross_attentions.length}).`)
+ }
+ return num_frames
+ ? cross_attentions[l].slice(null, h, null, [0, num_frames])
+ : cross_attentions[l].slice(null, h);
+ })).transpose(1, 0, 2, 3);
- // Normalize and smoothen the weights.
- let smoothedWeights = weights.clone(); // [1, 8, seqLength, 1500]
+ const [std, calculatedMean] = std_mean(weights, -2, 0, true);
- for (let a = 0; a < smoothedWeights.dims[0]; ++a) {
- let aTensor = smoothedWeights[a]; // [8, seqLength, 1500]
+ // Normalize and smooth the weights.
+ const smoothedWeights = weights.clone(); // [1, 8, seqLength, 1500]
- for (let b = 0; b < aTensor.dims[0]; ++b) {
- let bTensor = aTensor[b]; // [seqLength, 1500]
+ for (let a = 0; a < smoothedWeights.dims[0]; ++a) {
+ const aTensor = smoothedWeights[a]; // [8, seqLength, 1500]
- const stdTensor = std[a][b][0]; // [1500]
- const meanTensor = calculatedMean[a][b][0]; // [1500]
+ for (let b = 0; b < aTensor.dims[0]; ++b) {
+ const bTensor = aTensor[b]; // [seqLength, 1500]
- for (let c = 0; c < bTensor.dims[0]; ++c) {
+ const stdTensorData = std[a][b][0].data; // [1500]
+ const meanTensorData = calculatedMean[a][b][0].data; // [1500]
- let cTensor = bTensor[c]; // [1500]
- for (let d = 0; d < cTensor.data.length; ++d) {
- cTensor.data[d] = (cTensor.data[d] - meanTensor.data[d]) / stdTensor.data[d]
- }
+ for (let c = 0; c < bTensor.dims[0]; ++c) {
- // Apply median filter.
- cTensor.data.set(medianFilter(cTensor.data, median_filter_width))
+ let cTensorData = bTensor[c].data; // [1500]
+ for (let d = 0; d < cTensorData.length; ++d) {
+ cTensorData[d] = (cTensorData[d] - meanTensorData[d]) / stdTensorData[d]
}
+
+ // Apply median filter.
+ cTensorData.set(medianFilter(cTensorData, median_filter_width))
}
}
+ }
- // Average the different cross-attention heads.
- const matrix = mean(smoothedWeights, 1);
- return matrix;
- });
+ // Average the different cross-attention heads.
+ const batchedMatrices = [mean(smoothedWeights, 1)];
- const timestampsShape = [generate_outputs.sequences.length, generate_outputs.sequences[0].length];
+ const timestampsShape = generate_outputs.sequences.dims;
const timestamps = new Tensor(
'float32',
@@ -2966,16 +3158,16 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
// NOTE: Since we run only one batch at a time, we can squeeze to get the same dimensions
// as the python implementation
const matrix = batchedMatrices[batch_idx].neg().squeeze_(0);
- let [text_indices, time_indices] = dynamicTimeWarping(matrix);
+ const [text_indices, time_indices] = dynamic_time_warping(matrix.tolist());
- let diffs = Array.from({ length: text_indices.length - 1 }, (v, i) => text_indices[i + 1] - text_indices[i]);
- let jumps = mergeArrays([1], diffs).map(x => !!x); // convert to boolean
+ const diffs = Array.from({ length: text_indices.length - 1 }, (v, i) => text_indices[i + 1] - text_indices[i]);
+ const jumps = mergeArrays([1], diffs).map(x => !!x); // convert to boolean
- let jump_times = [];
+ const jump_times = [];
for (let i = 0; i < jumps.length; ++i) {
if (jumps[i]) {
- jump_times.push(time_indices[i] * time_precision);
// NOTE: No point in rounding here, since we set to Float32Array later
+ jump_times.push(time_indices[i] * time_precision);
}
}
timestamps[batch_idx].data.set(jump_times, 1)
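+
+ // Worked example of the jump detection above (illustrative values, time_precision = 0.02):
+ // if the DTW path gives text_indices = [0, 0, 1, 2, 2, 3] and time_indices = [0, 1, 1, 2, 3, 3],
+ // then diffs = [0, 1, 1, 0, 1] and jumps = [true, false, true, true, false, true], i.e. a jump
+ // is recorded each time the path advances to a new output token, yielding
+ // jump_times = [0, 1, 2, 3].map(t => t * 0.02) = [0, 0.02, 0.04, 0.06].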
@@ -2992,66 +3184,203 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
*/
export class VisionEncoderDecoderModel extends PreTrainedModel {
main_input_name = 'pixel_values';
+ forward_params = [
+ 'pixel_values',
+ 'input_ids',
+ 'encoder_hidden_states',
+ 'past_key_values',
+ ];
+}
+//////////////////////////////////////////////////
- /**
- * Creates a new instance of the `VisionEncoderDecoderModel` class.
- * @param {Object} config The configuration object specifying the hyperparameters and other model settings.
- * @param {Object} session The ONNX session containing the encoder model.
- * @param {any} decoder_merged_session The ONNX session containing the merged decoder model.
- * @param {Object} generation_config Configuration object for the generation process.
- */
- constructor(config, session, decoder_merged_session, generation_config) {
- super(config, session);
- this.decoder_merged_session = decoder_merged_session;
- this.generation_config = generation_config;
-
- // Extract configs
- const encoderConfig = this.config.encoder;
- const decoderConfig = this.config.decoder;
-
- // Validate encoder
- const encoderModelType = encoderConfig.model_type;
- const encoderModel =
- MODEL_MAPPING_NAMES_ENCODER_ONLY.get(encoderModelType)
- ?? MODEL_MAPPING_NAMES_ENCODER_DECODER.get(encoderModelType);
- if (!encoderModel) {
- console.warn(`Model type for encoder '${encoderModelType}' not found, assuming encoder-only architecture. Please report this at https://github.com/xenova/transformers.js/issues/new/choose.`);
- }
- // Validate decoder
- const decoderModel = MODEL_WITH_LM_HEAD_MAPPING_NAMES.get(decoderConfig.model_type);
- if (!decoderModel) {
- throw new Error(`Unable to construct \`VisionEncoderDecoder\` due to unsupported decoder: "${this.config.decoder.model_type}"`);
+//////////////////////////////////////////////////
+// LLaVa Models
+export class LlavaPreTrainedModel extends PreTrainedModel {
+ forward_params = [
+ 'input_ids',
+ 'pixel_values',
+ 'attention_mask',
+ 'position_ids',
+ 'past_key_values',
+ ];
+}
+
+/**
+ * The LLAVA model which consists of a vision backbone and a language model.
+ */
+export class LlavaForConditionalGeneration extends LlavaPreTrainedModel {
+
+ _merge_input_ids_with_image_features({
+ inputs_embeds,
+ image_features,
+ input_ids,
+ attention_mask,
+ }) {
+
+ const image_token_index = this.config.image_token_index;
+
+ const idsList = input_ids.tolist();
+
+ // NOTE: we use .findIndex instead of .indexOf to perform weak comparison (==) between BigInt and Number
+ const indexOfImage = idsList.map(ids => ids.findIndex(id => id == image_token_index));
+
+ const noImages = indexOfImage.every(x => x === -1);
+ const allImages = indexOfImage.every(x => x !== -1);
+ if (!noImages && !allImages) {
+ // Check for padding reasons
+ throw new Error('Every input should contain either 0 or 1 image token.');
}
- // @ts-ignore
- const decoderModelClass = decoderModel[1];
- // @ts-ignore
- const decoder = new decoderModelClass(decoderConfig, decoder_merged_session, generation_config);
+ if (noImages) {
+ return {
+ inputs_embeds,
+ attention_mask,
+ }
+ }
- this.add_encoder_pkv = 'num_decoder_layers' in decoder;
- if (this.add_encoder_pkv) {
- // Decoder is part of an encoder-decoder model
- this.num_decoder_layers = decoder.num_decoder_layers;
- this.num_decoder_heads = decoder.num_decoder_heads;
- this.decoder_dim_kv = decoder.decoder_dim_kv;
+ const stacked = [];
+ const stacked_attention_mask = [];
+ for (let i = 0; i < indexOfImage.length; ++i) {
+ const index = indexOfImage[i];
+
+ const e = inputs_embeds[i];
+ const im = image_features[i];
+ const am = attention_mask[i];
+ stacked.push(
+ cat([
+ e.slice([0, index]),
+ im,
+ e.slice([index + 1, e.dims[0]]),
+ ], 0)
+ );
- this.num_encoder_layers = decoder.num_encoder_layers;
- this.num_encoder_heads = decoder.num_encoder_heads;
- this.encoder_dim_kv = decoder.encoder_dim_kv;
+ stacked_attention_mask.push(
+ cat([
+ am.slice([0, index]),
+ ones([im.dims[0]]),
+ am.slice([index + 1, am.dims[0]])
+ ], 0)
+ )
+ }
- } else {
- // Decoder is a decoder-only model
- this.num_layers = decoder.num_layers;
- this.num_heads = decoder.num_heads;
- this.dim_kv = decoder.dim_kv;
+ return {
+ inputs_embeds: stack(stacked, 0),
+ attention_mask: stack(stacked_attention_mask, 0),
}
}
}
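+
+ /**
+ * Conceptual sketch of `_merge_input_ids_with_image_features` using plain arrays instead of
+ * Tensors (names and values are illustrative): the text embeddings are split at the position
+ * of the image token and the image patch embeddings are spliced in, while the attention mask
+ * is extended with ones over the inserted image positions.
+ * ```javascript
+ * const text_embeds = ['t0', 't1', '<image>', 't3'];
+ * const image_embeds = ['i0', 'i1', 'i2'];
+ * const index = text_embeds.indexOf('<image>');
+ * const merged = [
+ *     ...text_embeds.slice(0, index),
+ *     ...image_embeds,
+ *     ...text_embeds.slice(index + 1),
+ * ];
+ * // merged => ['t0', 't1', 'i0', 'i1', 'i2', 't3']
+ * // The attention mask grows by image_embeds.length ones at the same position.
+ * ```
+ */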
//////////////////////////////////////////////////
-//////////////////////////////////////////////////
-// CLIP models
+export class Moondream1ForConditionalGeneration extends LlavaForConditionalGeneration { } // NOTE: extends LlavaForConditionalGeneration
+
+export class Florence2PreTrainedModel extends PreTrainedModel {
+ forward_params = [
+ // Encoder inputs
+ 'input_ids',
+ 'inputs_embeds',
+ 'attention_mask',
+ 'pixel_values',
+
+ // Decoder inputs
+ 'encoder_outputs',
+ 'decoder_input_ids',
+ 'decoder_inputs_embeds',
+ 'decoder_attention_mask',
+ 'past_key_values',
+ ];
+ main_input_name = 'inputs_embeds';
+}
+
+export class Florence2ForConditionalGeneration extends Florence2PreTrainedModel {
+
+ _merge_input_ids_with_image_features({
+ inputs_embeds,
+ image_features,
+ input_ids,
+ attention_mask,
+ }) {
+ return {
+ inputs_embeds: cat([
+ image_features, // image embeds
+ inputs_embeds, // task prefix embeds
+ ], 1),
+ attention_mask: cat([
+ ones(image_features.dims.slice(0, 2)), // image attention mask
+ attention_mask, // task prefix attention mask
+ ], 1),
+ }
+ }
+
+ async _prepare_inputs_embeds({ input_ids, pixel_values, inputs_embeds, attention_mask }) {
+ if (!input_ids && !pixel_values) {
+ throw new Error('Either `input_ids` or `pixel_values` should be provided.');
+ }
+
+ // 1. Possibly, extract the input embeddings
+ let text_features, image_features;
+ if (input_ids) {
+ text_features = await this.encode_text({ input_ids });
+ }
+ if (pixel_values) {
+ image_features = await this.encode_image({ pixel_values });
+ }
+
+ // 2. Possibly, merge text and images
+ if (text_features && image_features) {
+ ({ inputs_embeds, attention_mask } = this._merge_input_ids_with_image_features({
+ inputs_embeds: text_features,
+ image_features,
+ input_ids,
+ attention_mask,
+ }));
+ } else {
+ inputs_embeds = text_features || image_features;
+ }
+
+ return { inputs_embeds, attention_mask };
+ }
+
+ async forward({
+ input_ids,
+ pixel_values,
+ attention_mask,
+ decoder_input_ids,
+ decoder_attention_mask,
+ encoder_outputs,
+ past_key_values,
+
+ inputs_embeds,
+ decoder_inputs_embeds,
+ }) {
+ if (!inputs_embeds) {
+ ({ inputs_embeds, attention_mask } = await this._prepare_inputs_embeds({ input_ids, pixel_values, inputs_embeds, attention_mask }));
+ }
+
+ if (!encoder_outputs) {
+ // Must compute encoder outputs
+ let { last_hidden_state } = await encoderForward(this, { inputs_embeds, attention_mask });
+ encoder_outputs = last_hidden_state;
+ }
+
+ if (!decoder_inputs_embeds) {
+ if (!decoder_input_ids) {
+ throw new Error('Either `decoder_input_ids` or `decoder_inputs_embeds` should be provided.');
+ }
+ decoder_inputs_embeds = await this.encode_text({ input_ids: decoder_input_ids });
+ }
+
+ const decoderFeeds = {
+ inputs_embeds: decoder_inputs_embeds,
+ attention_mask: decoder_attention_mask,
+ encoder_attention_mask: attention_mask,
+ encoder_hidden_states: encoder_outputs,
+ past_key_values,
+ };
+ const decoder_outputs = await decoderForward(this, decoderFeeds, true);
+ return decoder_outputs;
+ }
+}
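+
+ // Shape sketch for the merge above (sizes are illustrative):
+ //   image_features: [batch, num_image_tokens, hidden]  e.g. [1, 577, 768]
+ //   inputs_embeds:  [batch, prefix_length, hidden]     e.g. [1,   8, 768]
+ // After `_merge_input_ids_with_image_features`, the encoder receives
+ //   inputs_embeds:  [1, 577 + 8, 768]
+ //   attention_mask: [1, 577 + 8]  (ones over the image tokens, followed by the text mask)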
export class CLIPPreTrainedModel extends PreTrainedModel { }
/**
@@ -3060,7 +3389,7 @@ export class CLIPPreTrainedModel extends PreTrainedModel { }
* **Example:** Perform zero-shot image classification with a `CLIPModel`.
*
* ```javascript
- * import { AutoTokenizer, AutoProcessor, CLIPModel, RawImage } from '@xenova/transformers';
+ * import { AutoTokenizer, AutoProcessor, CLIPModel, RawImage } from '@huggingface/transformers';
*
* // Load tokenizer, processor, and model
* let tokenizer = await AutoTokenizer.from_pretrained('Xenova/clip-vit-base-patch16');
@@ -3117,7 +3446,7 @@ export class CLIPTextModel extends CLIPPreTrainedModel {
* **Example:** Compute text embeddings with `CLIPTextModelWithProjection`.
*
* ```javascript
- * import { AutoTokenizer, CLIPTextModelWithProjection } from '@xenova/transformers';
+ * import { AutoTokenizer, CLIPTextModelWithProjection } from '@huggingface/transformers';
*
* // Load tokenizer and text model
* const tokenizer = await AutoTokenizer.from_pretrained('Xenova/clip-vit-base-patch16');
@@ -3164,7 +3493,7 @@ export class CLIPVisionModel extends CLIPPreTrainedModel {
* **Example:** Compute vision embeddings with `CLIPVisionModelWithProjection`.
*
* ```javascript
- * import { AutoProcessor, CLIPVisionModelWithProjection, RawImage} from '@xenova/transformers';
+ * import { AutoProcessor, CLIPVisionModelWithProjection, RawImage} from '@huggingface/transformers';
*
* // Load processor and vision model
* const processor = await AutoProcessor.from_pretrained('Xenova/clip-vit-base-patch16');
@@ -3205,7 +3534,7 @@ export class SiglipPreTrainedModel extends PreTrainedModel { }
* **Example:** Perform zero-shot image classification with a `SiglipModel`.
*
* ```javascript
- * import { AutoTokenizer, AutoProcessor, SiglipModel, RawImage } from '@xenova/transformers';
+ * import { AutoTokenizer, AutoProcessor, SiglipModel, RawImage } from '@huggingface/transformers';
*
* // Load tokenizer, processor, and model
* const tokenizer = await AutoTokenizer.from_pretrained('Xenova/siglip-base-patch16-224');
@@ -3250,7 +3579,7 @@ export class SiglipModel extends SiglipPreTrainedModel { }
* **Example:** Compute text embeddings with `SiglipTextModel`.
*
* ```javascript
- * import { AutoTokenizer, SiglipTextModel } from '@xenova/transformers';
+ * import { AutoTokenizer, SiglipTextModel } from '@huggingface/transformers';
*
* // Load tokenizer and text model
* const tokenizer = await AutoTokenizer.from_pretrained('Xenova/siglip-base-patch16-224');
@@ -3286,7 +3615,7 @@ export class SiglipTextModel extends SiglipPreTrainedModel {
* **Example:** Compute vision embeddings with `SiglipVisionModel`.
*
* ```javascript
- * import { AutoProcessor, SiglipVisionModel, RawImage} from '@xenova/transformers';
+ * import { AutoProcessor, SiglipVisionModel, RawImage} from '@huggingface/transformers';
*
* // Load processor and vision model
* const processor = await AutoProcessor.from_pretrained('Xenova/siglip-base-patch16-224');
@@ -3334,7 +3663,7 @@ export class CLIPSegModel extends CLIPSegPreTrainedModel { }
* **Example:** Perform zero-shot image segmentation with a `CLIPSegForImageSegmentation` model.
*
* ```javascript
- * import { AutoTokenizer, AutoProcessor, CLIPSegForImageSegmentation, RawImage } from '@xenova/transformers';
+ * import { AutoTokenizer, AutoProcessor, CLIPSegForImageSegmentation, RawImage } from '@huggingface/transformers';
*
* // Load tokenizer, processor, and model
* const tokenizer = await AutoTokenizer.from_pretrained('Xenova/clipseg-rd64-refined');
@@ -3380,25 +3709,7 @@ export class CLIPSegForImageSegmentation extends CLIPSegPreTrainedModel { }
//////////////////////////////////////////////////
// GPT2 models
-export class GPT2PreTrainedModel extends PreTrainedModel {
- /**
- * Creates a new instance of the `GPT2PreTrainedModel` class.
- * @param {Object} config The configuration of the model.
- * @param {any} session The ONNX session containing the model weights.
- * @param {GenerationConfig} generation_config The generation configuration.
- */
- constructor(config, session, generation_config) {
- super(config, session);
- this.generation_config = generation_config;
-
- // config doesn't contain pad_token_id, so we assume it is the eos_token_id
- this.config.pad_token_id = this.config.eos_token_id
-
- this.num_heads = this.config.n_head
- this.num_layers = this.config.n_layer
- this.dim_kv = this.config.n_embd / this.num_heads;
- }
-}
+export class GPT2PreTrainedModel extends PreTrainedModel { }
export class GPT2Model extends GPT2PreTrainedModel { }
@@ -3412,26 +3723,24 @@ export class GPT2LMHeadModel extends GPT2PreTrainedModel { }
//////////////////////////////////////////////////
//////////////////////////////////////////////////
-// GPTNeo models
-export class GPTNeoPreTrainedModel extends PreTrainedModel {
- /**
- * Creates a new instance of the `GPTNeoPreTrainedModel` class.
- * @param {Object} config The configuration of the model.
- * @param {any} session The ONNX session containing the model weights.
- * @param {GenerationConfig} generation_config The generation configuration.
- */
- constructor(config, session, generation_config) {
- super(config, session);
- this.generation_config = generation_config;
+// JAIS models
+export class JAISPreTrainedModel extends PreTrainedModel { }
- // config doesn't contain pad_token_id, so we assume it is the eos_token_id
- this.config.pad_token_id = this.config.eos_token_id
+/**
+ * The bare JAIS Model transformer outputting raw hidden-states without any specific head on top.
+ */
+export class JAISModel extends JAISPreTrainedModel { }
- this.num_heads = this.config.num_heads;
- this.num_layers = this.config.num_layers;
- this.dim_kv = this.config.hidden_size / this.num_heads;
- }
-}
+/**
+ * The JAIS Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).
+ */
+export class JAISLMHeadModel extends JAISPreTrainedModel { }
+//////////////////////////////////////////////////
+
+
+//////////////////////////////////////////////////
+// GPTNeo models
+export class GPTNeoPreTrainedModel extends PreTrainedModel { }
export class GPTNeoModel extends GPTNeoPreTrainedModel { }
export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { }
@@ -3439,25 +3748,7 @@ export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { }
//////////////////////////////////////////////////
// GPTNeoX models
-export class GPTNeoXPreTrainedModel extends PreTrainedModel {
- /**
- * Creates a new instance of the `GPTNeoXPreTrainedModel` class.
- * @param {Object} config The configuration of the model.
- * @param {any} session The ONNX session containing the model weights.
- * @param {GenerationConfig} generation_config The generation configuration.
- */
- constructor(config, session, generation_config) {
- super(config, session);
- this.generation_config = generation_config;
-
- // config doesn't contain pad_token_id, so we assume it is the eos_token_id
- this.config.pad_token_id = this.config.eos_token_id
-
- this.num_heads = this.config.num_attention_heads;
- this.num_layers = this.config.num_hidden_layers;
- this.dim_kv = this.config.hidden_size / this.num_heads;
- }
-}
+export class GPTNeoXPreTrainedModel extends PreTrainedModel { }
export class GPTNeoXModel extends GPTNeoXPreTrainedModel { }
export class GPTNeoXForCausalLM extends GPTNeoXPreTrainedModel { }
@@ -3466,25 +3757,7 @@ export class GPTNeoXForCausalLM extends GPTNeoXPreTrainedModel { }
//////////////////////////////////////////////////
// GPT-J models
-export class GPTJPreTrainedModel extends PreTrainedModel {
- /**
- * Creates a new instance of the `GPTJPreTrainedModel` class.
- * @param {Object} config The configuration of the model.
- * @param {any} session The ONNX session containing the model weights.
- * @param {GenerationConfig} generation_config The generation configuration.
- */
- constructor(config, session, generation_config) {
- super(config, session);
- this.generation_config = generation_config;
-
- // config doesn't contain pad_token_id, so we assume it is the eos_token_id
- this.config.pad_token_id = this.config.eos_token_id
-
- this.num_heads = this.config.n_head
- this.num_layers = this.config.n_layer
- this.dim_kv = this.config.n_embd / this.num_heads;
- }
-}
+export class GPTJPreTrainedModel extends PreTrainedModel { }
export class GPTJModel extends GPTJPreTrainedModel { }
@@ -3494,25 +3767,7 @@ export class GPTJForCausalLM extends GPTJPreTrainedModel { }
//////////////////////////////////////////////////
// GPTBigCode models
-export class GPTBigCodePreTrainedModel extends PreTrainedModel {
- /**
- * Creates a new instance of the `GPTBigCodePreTrainedModel` class.
- * @param {Object} config The configuration of the model.
- * @param {any} session The ONNX session containing the model weights.
- * @param {GenerationConfig} generation_config The generation configuration.
- */
- constructor(config, session, generation_config) {
- super(config, session);
- this.generation_config = generation_config;
-
- // config doesn't contain pad_token_id, so we assume it is the eos_token_id
- this.config.pad_token_id = this.config.eos_token_id
-
- this.num_heads = this.config.n_head
- this.num_layers = this.config.n_layer
- this.dim_kv = this.config.n_embd / this.num_heads;
- }
-}
+export class GPTBigCodePreTrainedModel extends PreTrainedModel { }
export class GPTBigCodeModel extends GPTBigCodePreTrainedModel { }
@@ -3521,25 +3776,7 @@ export class GPTBigCodeForCausalLM extends GPTBigCodePreTrainedModel { }
//////////////////////////////////////////////////
// CodeGen models
-export class CodeGenPreTrainedModel extends PreTrainedModel {
- /**
- * Creates a new instance of the `CodeGenPreTrainedModel` class.
- * @param {Object} config The model configuration object.
- * @param {Object} session The ONNX session object.
- * @param {GenerationConfig} generation_config The generation configuration.
- */
- constructor(config, session, generation_config) {
- super(config, session);
- this.generation_config = generation_config;
-
- // config doesn't contain pad_token_id, so we assume it is the eos_token_id
- this.config.pad_token_id = this.config.eos_token_id
-
- this.num_heads = this.config.n_head
- this.num_layers = this.config.n_layer
- this.dim_kv = this.config.n_embd / this.num_heads;
- }
-}
+export class CodeGenPreTrainedModel extends PreTrainedModel { }
/**
* CodeGenModel is a class representing a code generation model without a language model head.
*/
@@ -3558,25 +3795,7 @@ export class CodeGenForCausalLM extends CodeGenPreTrainedModel { }
/**
* The bare LLaMA Model outputting raw hidden-states without any specific head on top.
*/
-export class LlamaPreTrainedModel extends PreTrainedModel {
- /**
- * Creates a new instance of the `LlamaPreTrainedModel` class.
- * @param {Object} config The model configuration object.
- * @param {Object} session The ONNX session object.
- * @param {GenerationConfig} generation_config The generation configuration.
- */
- constructor(config, session, generation_config) {
- super(config, session);
- this.generation_config = generation_config;
-
- // config doesn't contain pad_token_id, so we assume it is the eos_token_id
- this.config.pad_token_id = this.config.eos_token_id
-
- this.num_heads = this.config.num_key_value_heads ?? this.config.num_attention_heads
- this.num_layers = this.config.num_hidden_layers
- this.dim_kv = this.config.hidden_size / this.config.num_attention_heads
- }
-}
+export class LlamaPreTrainedModel extends PreTrainedModel { }
/**
* The bare LLaMA Model outputting raw hidden-states without any specific head on top.
*/
@@ -3585,31 +3804,71 @@ export class LlamaModel extends LlamaPreTrainedModel { }
export class LlamaForCausalLM extends LlamaPreTrainedModel { }
//////////////////////////////////////////////////
+
//////////////////////////////////////////////////
-// Qwen2 models
+// Granite models
+export class GranitePreTrainedModel extends PreTrainedModel { }
+export class GraniteModel extends GranitePreTrainedModel { }
+export class GraniteForCausalLM extends GranitePreTrainedModel { }
+//////////////////////////////////////////////////
+
+
+//////////////////////////////////////////////////
+// Cohere models
/**
- * The bare Qwen2 Model outputting raw hidden-states without any specific head on top.
+ * The bare Cohere Model outputting raw hidden-states without any specific head on top.
*/
-export class Qwen2PreTrainedModel extends PreTrainedModel {
- /**
- * Creates a new instance of the `Qwen2PreTrainedModel` class.
- * @param {Object} config The model configuration object.
- * @param {Object} session The ONNX session object.
- * @param {GenerationConfig} generation_config The generation configuration.
- */
- constructor(config, session, generation_config) {
- super(config, session);
- this.generation_config = generation_config;
+export class CoherePreTrainedModel extends PreTrainedModel { }
+export class CohereModel extends CoherePreTrainedModel { }
+
+export class CohereForCausalLM extends CoherePreTrainedModel { }
+//////////////////////////////////////////////////
- // config doesn't contain pad_token_id, so we assume it is the eos_token_id
- this.config.pad_token_id = this.config.eos_token_id
+//////////////////////////////////////////////////
+// Gemma models
- this.num_heads = this.config.num_key_value_heads ?? this.config.num_attention_heads
- this.num_layers = this.config.num_hidden_layers
- this.dim_kv = this.config.hidden_size / this.config.num_attention_heads
- }
-}
+/**
+ * The bare Gemma Model outputting raw hidden-states without any specific head on top.
+ */
+export class GemmaPreTrainedModel extends PreTrainedModel { }
+/**
+ * The bare Gemma Model outputting raw hidden-states without any specific head on top.
+ */
+export class GemmaModel extends GemmaPreTrainedModel { }
+
+export class GemmaForCausalLM extends GemmaPreTrainedModel { }
+//////////////////////////////////////////////////
+
+//////////////////////////////////////////////////
+// Gemma2 models
+
+/**
+ * The bare Gemma2 Model outputting raw hidden-states without any specific head on top.
+ */
+export class Gemma2PreTrainedModel extends PreTrainedModel { }
+/**
+ * The bare Gemma2 Model outputting raw hidden-states without any specific head on top.
+ */
+export class Gemma2Model extends Gemma2PreTrainedModel { }
+
+export class Gemma2ForCausalLM extends Gemma2PreTrainedModel { }
+//////////////////////////////////////////////////
+
+//////////////////////////////////////////////////
+export class OpenELMPreTrainedModel extends PreTrainedModel { }
+export class OpenELMModel extends OpenELMPreTrainedModel { }
+
+export class OpenELMForCausalLM extends OpenELMPreTrainedModel { }
+
+
+//////////////////////////////////////////////////
+// Qwen2 models
+
+/**
+ * The bare Qwen2 Model outputting raw hidden-states without any specific head on top.
+ */
+export class Qwen2PreTrainedModel extends PreTrainedModel { }
/**
* The bare Qwen2 Model outputting raw hidden-states without any specific head on top.
*/
@@ -3621,26 +3880,7 @@ export class Qwen2ForCausalLM extends Qwen2PreTrainedModel { }
//////////////////////////////////////////////////
// Phi models
-
-export class PhiPreTrainedModel extends PreTrainedModel {
- /**
- * Creates a new instance of the `PhiPreTrainedModel` class.
- * @param {Object} config The model configuration object.
- * @param {Object} session The ONNX session object.
- * @param {GenerationConfig} generation_config The generation configuration.
- */
- constructor(config, session, generation_config) {
- super(config, session);
- this.generation_config = generation_config;
-
- // config doesn't contain pad_token_id, so we assume it is the eos_token_id
- this.config.pad_token_id = this.config.eos_token_id;
-
- this.num_heads = this.config.num_attention_heads;
- this.num_layers = this.config.num_hidden_layers;
- this.dim_kv = this.config.hidden_size / this.num_heads;
- }
-}
+export class PhiPreTrainedModel extends PreTrainedModel { }
/**
* The bare Phi Model outputting raw hidden-states without any specific head on top.
*/
@@ -3649,31 +3889,25 @@ export class PhiModel extends PhiPreTrainedModel { }
export class PhiForCausalLM extends PhiPreTrainedModel { }
//////////////////////////////////////////////////
+//////////////////////////////////////////////////
+// Phi3 models
+export class Phi3PreTrainedModel extends PreTrainedModel { }
+
+/**
+ * The bare Phi3 Model outputting raw hidden-states without any specific head on top.
+ */
+export class Phi3Model extends Phi3PreTrainedModel { }
+
+export class Phi3ForCausalLM extends Phi3PreTrainedModel { }
+//////////////////////////////////////////////////
+
//////////////////////////////////////////////////
// Bloom models
/**
* The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).
*/
-export class BloomPreTrainedModel extends PreTrainedModel {
- /**
- * Creates a new instance of the `BloomPreTrainedModel` class.
- * @param {Object} config The configuration of the model.
- * @param {any} session The ONNX session containing the model weights.
- * @param {GenerationConfig} generation_config The generation configuration.
- */
- constructor(config, session, generation_config) {
- super(config, session);
- this.generation_config = generation_config;
-
- // config doesn't contain pad_token_id, so we assume it is the eos_token_id
- this.config.pad_token_id = this.config.eos_token_id
-
- this.num_heads = this.config.n_head
- this.num_layers = this.config.n_layer
- this.dim_kv = this.config.hidden_size / this.num_heads;
- }
-}
+export class BloomPreTrainedModel extends PreTrainedModel { }
/**
* The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.
@@ -3688,25 +3922,7 @@ export class BloomForCausalLM extends BloomPreTrainedModel { }
//////////////////////////////////////////////////
// MPT models
-export class MptPreTrainedModel extends PreTrainedModel {
- /**
- * Creates a new instance of the `MptPreTrainedModel` class.
- * @param {Object} config The model configuration object.
- * @param {Object} session The ONNX session object.
- * @param {GenerationConfig} generation_config The generation configuration.
- */
- constructor(config, session, generation_config) {
- super(config, session);
- this.generation_config = generation_config;
-
- // config doesn't contain pad_token_id, so we assume it is the eos_token_id
- this.config.pad_token_id = this.config.eos_token_id
-
- this.num_heads = this.config.n_heads
- this.num_layers = this.config.n_layers
- this.dim_kv = this.config.d_model / this.num_heads;
- }
-}
+export class MptPreTrainedModel extends PreTrainedModel { }
/**
* The bare Mpt Model transformer outputting raw hidden-states without any specific head on top.
@@ -3722,25 +3938,7 @@ export class MptForCausalLM extends MptPreTrainedModel { }
//////////////////////////////////////////////////
// OPT models
-export class OPTPreTrainedModel extends PreTrainedModel {
- /**
- * Creates a new instance of the `OPTPreTrainedModel` class.
- * @param {Object} config The model configuration object.
- * @param {Object} session The ONNX session object.
- * @param {GenerationConfig} generation_config The generation configuration.
- */
- constructor(config, session, generation_config) {
- super(config, session);
- this.generation_config = generation_config;
-
- // config doesn't contain pad_token_id, so we assume it is the eos_token_id
- this.config.pad_token_id = this.config.eos_token_id
-
- this.num_heads = this.config.num_attention_heads;
- this.num_layers = this.config.num_hidden_layers;
- this.dim_kv = this.config.hidden_size / this.num_heads;
- }
-}
+export class OPTPreTrainedModel extends PreTrainedModel { }
/**
* The bare OPT Model outputting raw hidden-states without any specific head on top.
@@ -3766,6 +3964,43 @@ export class ViTForImageClassification extends ViTPreTrainedModel {
}
//////////////////////////////////////////////////
+//////////////////////////////////////////////////
+export class PvtPreTrainedModel extends PreTrainedModel { }
+export class PvtModel extends PvtPreTrainedModel { }
+export class PvtForImageClassification extends PvtPreTrainedModel {
+ /**
+ * @param {any} model_inputs
+ */
+ async _call(model_inputs) {
+ return new SequenceClassifierOutput(await super._call(model_inputs));
+ }
+}
+//////////////////////////////////////////////////
+
+//////////////////////////////////////////////////
+export class ViTMAEPreTrainedModel extends PreTrainedModel { }
+export class ViTMAEModel extends ViTMAEPreTrainedModel { }
+//////////////////////////////////////////////////
+
+
+//////////////////////////////////////////////////
+export class ViTMSNPreTrainedModel extends PreTrainedModel { }
+export class ViTMSNModel extends ViTMSNPreTrainedModel { }
+export class ViTMSNForImageClassification extends ViTMSNPreTrainedModel {
+ /**
+ * @param {any} model_inputs
+ */
+ async _call(model_inputs) {
+ return new SequenceClassifierOutput(await super._call(model_inputs));
+ }
+}
+//////////////////////////////////////////////////
+
+//////////////////////////////////////////////////
+export class GroupViTPreTrainedModel extends PreTrainedModel { }
+export class GroupViTModel extends GroupViTPreTrainedModel { }
+//////////////////////////////////////////////////
+
//////////////////////////////////////////////////
export class FastViTPreTrainedModel extends PreTrainedModel { }
@@ -3788,7 +4023,7 @@ export class VitMattePreTrainedModel extends PreTrainedModel { }
*
* **Example:** Perform image matting with a `VitMatteForImageMatting` model.
* ```javascript
- * import { AutoProcessor, VitMatteForImageMatting, RawImage } from '@xenova/transformers';
+ * import { AutoProcessor, VitMatteForImageMatting, RawImage } from '@huggingface/transformers';
*
* // Load processor and model
* const processor = await AutoProcessor.from_pretrained('Xenova/vitmatte-small-distinctions-646');
@@ -3813,7 +4048,7 @@ export class VitMattePreTrainedModel extends PreTrainedModel { }
*
* You can visualize the alpha matte as follows:
* ```javascript
- * import { Tensor, cat } from '@xenova/transformers';
+ * import { Tensor, cat } from '@huggingface/transformers';
*
* // Visualize predicted alpha matte
* const imageTensor = image.toTensor();
@@ -3954,6 +4189,33 @@ export class DetrSegmentationOutput extends ModelOutput {
}
//////////////////////////////////////////////////
+//////////////////////////////////////////////////
+export class RTDetrPreTrainedModel extends PreTrainedModel { }
+export class RTDetrModel extends RTDetrPreTrainedModel { }
+export class RTDetrForObjectDetection extends RTDetrPreTrainedModel {
+ /**
+ * @param {any} model_inputs
+ */
+ async _call(model_inputs) {
+ return new RTDetrObjectDetectionOutput(await super._call(model_inputs));
+ }
+}
+
+export class RTDetrObjectDetectionOutput extends ModelOutput {
+ /**
+ * @param {Object} output The output of the model.
+ * @param {Tensor} output.logits Classification logits (including no-object) for all queries.
+ * @param {Tensor} output.pred_boxes Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height).
+ * These values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding possible padding).
+ */
+ constructor({ logits, pred_boxes }) {
+ super();
+ this.logits = logits;
+ this.pred_boxes = pred_boxes;
+ }
+}
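+
+ /**
+ * Since `pred_boxes` are normalized `(center_x, center_y, width, height)` values, converting
+ * a box to pixel-space corners only needs the original image size. A small sketch (function
+ * name and values are illustrative):
+ * ```javascript
+ * function toPixelCorners([cx, cy, w, h], imageWidth, imageHeight) {
+ *     return [
+ *         (cx - w / 2) * imageWidth,  // xmin
+ *         (cy - h / 2) * imageHeight, // ymin
+ *         (cx + w / 2) * imageWidth,  // xmax
+ *         (cy + h / 2) * imageHeight, // ymax
+ *     ];
+ * }
+ * toPixelCorners([0.5, 0.5, 0.2, 0.4], 640, 480); // => [256, 144, 384, 336]
+ * ```
+ */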
+//////////////////////////////////////////////////
+
//////////////////////////////////////////////////
export class TableTransformerPreTrainedModel extends PreTrainedModel { }
@@ -3992,6 +4254,19 @@ export class DeiTForImageClassification extends DeiTPreTrainedModel {
}
//////////////////////////////////////////////////
+//////////////////////////////////////////////////
+export class HieraPreTrainedModel extends PreTrainedModel { }
+export class HieraModel extends HieraPreTrainedModel { }
+export class HieraForImageClassification extends HieraPreTrainedModel {
+ /**
+ * @param {any} model_inputs
+ */
+ async _call(model_inputs) {
+ return new SequenceClassifierOutput(await super._call(model_inputs));
+ }
+}
+//////////////////////////////////////////////////
+
//////////////////////////////////////////////////
/**
@@ -4045,7 +4320,7 @@ export class Swin2SRModel extends Swin2SRPreTrainedModel { }
* **Example:** Super-resolution w/ `Xenova/swin2SR-classical-sr-x2-64`.
*
* ```javascript
- * import { AutoProcessor, Swin2SRForImageSuperResolution, RawImage } from '@xenova/transformers';
+ * import { AutoProcessor, Swin2SRForImageSuperResolution, RawImage } from '@huggingface/transformers';
*
* // Load processor and model
* const model_id = 'Xenova/swin2SR-classical-sr-x2-64';
@@ -4087,7 +4362,7 @@ export class DPTModel extends DPTPreTrainedModel { }
*
* **Example:** Depth estimation w/ `Xenova/dpt-hybrid-midas`.
* ```javascript
- * import { DPTForDepthEstimation, AutoProcessor, RawImage, interpolate, max } from '@xenova/transformers';
+ * import { DPTForDepthEstimation, AutoProcessor, RawImage, interpolate, max } from '@huggingface/transformers';
*
* // Load model and processor
* const model_id = 'Xenova/dpt-hybrid-midas';
@@ -4131,6 +4406,24 @@ export class DepthAnythingForDepthEstimation extends DepthAnythingPreTrainedMode
//////////////////////////////////////////////////
+//////////////////////////////////////////////////
+export class SapiensPreTrainedModel extends PreTrainedModel { }
+export class SapiensForSemanticSegmentation extends SapiensPreTrainedModel { }
+export class SapiensForDepthEstimation extends SapiensPreTrainedModel { }
+export class SapiensForNormalEstimation extends SapiensPreTrainedModel { }
+//////////////////////////////////////////////////
+
+//////////////////////////////////////////////////
+export class DepthProPreTrainedModel extends PreTrainedModel { }
+export class DepthProForDepthEstimation extends DepthProPreTrainedModel { }
+//////////////////////////////////////////////////
+
+//////////////////////////////////////////////////
+export class MaskFormerPreTrainedModel extends PreTrainedModel { }
+export class MaskFormerModel extends MaskFormerPreTrainedModel { }
+export class MaskFormerForInstanceSegmentation extends MaskFormerPreTrainedModel { }
+//////////////////////////////////////////////////
+
//////////////////////////////////////////////////
export class GLPNPreTrainedModel extends PreTrainedModel { }
@@ -4144,7 +4437,7 @@ export class GLPNModel extends GLPNPreTrainedModel { }
*
* **Example:** Depth estimation w/ `Xenova/glpn-kitti`.
* ```javascript
- * import { GLPNForDepthEstimation, AutoProcessor, RawImage, interpolate, max } from '@xenova/transformers';
+ * import { GLPNForDepthEstimation, AutoProcessor, RawImage, interpolate, max } from '@huggingface/transformers';
*
* // Load model and processor
* const model_id = 'Xenova/glpn-kitti';
@@ -4187,7 +4480,7 @@ export class DonutSwinPreTrainedModel extends PreTrainedModel { }
* **Example:** Step-by-step Document Parsing.
*
* ```javascript
- * import { AutoProcessor, AutoTokenizer, AutoModelForVision2Seq, RawImage } from '@xenova/transformers';
+ * import { AutoProcessor, AutoTokenizer, AutoModelForVision2Seq, RawImage } from '@huggingface/transformers';
*
* // Choose model to use
* const model_id = 'Xenova/donut-base-finetuned-cord-v2';
@@ -4222,7 +4515,7 @@ export class DonutSwinPreTrainedModel extends PreTrainedModel { }
* **Example:** Step-by-step Document Visual Question Answering (DocVQA)
*
* ```javascript
- * import { AutoProcessor, AutoTokenizer, AutoModelForVision2Seq, RawImage } from '@xenova/transformers';
+ * import { AutoProcessor, AutoTokenizer, AutoModelForVision2Seq, RawImage } from '@huggingface/transformers';
*
* // Choose model to use
* const model_id = 'Xenova/donut-base-finetuned-docvqa';
@@ -4352,6 +4645,8 @@ export class YolosObjectDetectionOutput extends ModelOutput {
//////////////////////////////////////////////////
+
+
//////////////////////////////////////////////////
export class SamPreTrainedModel extends PreTrainedModel { }
@@ -4361,7 +4656,7 @@ export class SamPreTrainedModel extends PreTrainedModel { }
*
* **Example:** Perform mask generation w/ `Xenova/sam-vit-base`.
* ```javascript
- * import { SamModel, AutoProcessor, RawImage } from '@xenova/transformers';
+ * import { SamModel, AutoProcessor, RawImage } from '@huggingface/transformers';
*
* const model = await SamModel.from_pretrained('Xenova/sam-vit-base');
* const processor = await AutoProcessor.from_pretrained('Xenova/sam-vit-base');
@@ -4370,7 +4665,7 @@ export class SamPreTrainedModel extends PreTrainedModel { }
* const raw_image = await RawImage.read(img_url);
* const input_points = [[[450, 600]]] // 2D localization of a window
*
- * const inputs = await processor(raw_image, input_points);
+ * const inputs = await processor(raw_image, { input_points });
* const outputs = await model(inputs);
*
* const masks = await processor.post_process_masks(outputs.pred_masks, inputs.original_sizes, inputs.reshaped_input_sizes);
@@ -4396,16 +4691,6 @@ export class SamPreTrainedModel extends PreTrainedModel { }
* ```
*/
export class SamModel extends SamPreTrainedModel {
- /**
- * Creates a new instance of the `SamModel` class.
- * @param {Object} config The configuration object specifying the hyperparameters and other model settings.
- * @param {Object} vision_encoder The ONNX session containing the vision encoder model.
- * @param {any} prompt_encoder_mask_decoder The ONNX session containing the prompt encoder and mask decoder model.
- */
- constructor(config, vision_encoder, prompt_encoder_mask_decoder) {
- super(config, vision_encoder);
- this.prompt_encoder_mask_decoder = prompt_encoder_mask_decoder;
- }
/**
* Compute image embeddings and positional image embeddings, given the pixel values of an image.
@@ -4427,7 +4712,7 @@ export class SamModel extends SamPreTrainedModel {
* @typedef {Object} SamModelInputs Object containing the model inputs.
* @property {Tensor} pixel_values Pixel values as a Tensor with shape `(batch_size, num_channels, height, width)`.
* These can be obtained using a `SamProcessor`.
- * @property {Tensor} input_points Input 2D spatial points with shape `(batch_size, num_points, 2)`.
+ * @property {Tensor} [input_points] Input 2D spatial points with shape `(batch_size, num_points, 2)`.
* This is used by the prompt encoder to encode the prompt.
* @property {Tensor} [input_labels] Input labels for the points, as a Tensor of shape `(batch_size, point_batch_size, num_points)`.
* This is used by the prompt encoder to encode the prompt. There are 4 types of labels:
@@ -4435,6 +4720,7 @@ export class SamModel extends SamPreTrainedModel {
* - `0`: the point is a point that does not contain the object of interest
* - `-1`: the point corresponds to the background
* - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
+ * @property {Tensor} [input_boxes] Input bounding boxes with shape `(batch_size, num_boxes, 4)`.
* @property {Tensor} [image_embeddings] Image embeddings used by the mask decoder.
* @property {Tensor} [image_positional_embeddings] Image positional embeddings used by the mask decoder.
*/
@@ -4452,7 +4738,7 @@ export class SamModel extends SamPreTrainedModel {
}
}
- if (!model_inputs.input_labels) {
+ if (!model_inputs.input_labels && model_inputs.input_points) {
// Set default input labels if they are missing
const shape = model_inputs.input_points.dims.slice(0, -1);
const numElements = shape.reduce((a, b) => a * b, 1);
@@ -4463,15 +4749,24 @@ export class SamModel extends SamPreTrainedModel {
);
}
+ const decoder_inputs = {
+ image_embeddings: model_inputs.image_embeddings,
+ image_positional_embeddings: model_inputs.image_positional_embeddings,
+ };
+ if (model_inputs.input_points) {
+ decoder_inputs.input_points = model_inputs.input_points;
+ }
+ if (model_inputs.input_labels) {
+ decoder_inputs.input_labels = model_inputs.input_labels;
+ }
+ if (model_inputs.input_boxes) {
+ decoder_inputs.input_boxes = model_inputs.input_boxes;
+ }
+
// Returns:
// - iou_scores: tensor.float32[batch_size,point_batch_size,3]
// - pred_masks: tensor.float32[batch_size,point_batch_size,3,256,256]
- return await sessionRun(this.prompt_encoder_mask_decoder, {
- input_points: model_inputs.input_points,
- input_labels: model_inputs.input_labels,
- image_embeddings: model_inputs.image_embeddings,
- image_positional_embeddings: model_inputs.image_positional_embeddings,
- });
+ return await sessionRun(this.sessions['prompt_encoder_mask_decoder'], decoder_inputs);
}
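+
+ /**
+ * A minimal sketch of prompting with a bounding box instead of points, continuing the example
+ * above (box coordinates are illustrative and, like `input_points` passed directly to the model,
+ * are expected in the processor's reshaped input space; the image-embedding helper documented
+ * above is assumed to be `get_image_embeddings`, and `Tensor` is imported from '@huggingface/transformers'):
+ * ```javascript
+ * // Reuse the cached image embeddings, then decode masks for one box per image.
+ * const embeddings = await model.get_image_embeddings({ pixel_values: inputs.pixel_values });
+ * const input_boxes = new Tensor('float32', [100, 150, 400, 500], [1, 1, 4]); // (xmin, ymin, xmax, ymax)
+ * const box_outputs = await model({ ...embeddings, input_boxes });
+ * ```
+ */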
/**
@@ -4509,29 +4804,7 @@ export class MarianPreTrainedModel extends PreTrainedModel { };
export class MarianModel extends MarianPreTrainedModel { }
-export class MarianMTModel extends MarianPreTrainedModel {
-
- /**
- * Creates a new instance of the `MarianMTModel` class.
- * @param {Object} config The model configuration object.
- * @param {Object} session The ONNX session object.
- * @param {any} decoder_merged_session
- * @param {any} generation_config
- */
- constructor(config, session, decoder_merged_session, generation_config) {
- super(config, session);
- this.decoder_merged_session = decoder_merged_session;
- this.generation_config = generation_config;
-
- this.num_decoder_layers = this.config.decoder_layers;
- this.num_decoder_heads = this.config.decoder_attention_heads;
- this.decoder_dim_kv = this.config.d_model / this.num_decoder_heads;
-
- this.num_encoder_layers = this.config.encoder_layers;
- this.num_encoder_heads = this.config.encoder_attention_heads;
- this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads;
- }
-}
+export class MarianMTModel extends MarianPreTrainedModel { }
//////////////////////////////////////////////////
//////////////////////////////////////////////////
@@ -4540,30 +4813,7 @@ export class M2M100PreTrainedModel extends PreTrainedModel { };
export class M2M100Model extends M2M100PreTrainedModel { }
-export class M2M100ForConditionalGeneration extends M2M100PreTrainedModel {
-
- /**
- * Creates a new instance of the `M2M100ForConditionalGeneration` class.
- * @param {Object} config The model configuration object.
- * @param {Object} session The ONNX session object.
- * @param {any} decoder_merged_session
- * @param {any} generation_config
- */
- constructor(config, session, decoder_merged_session, generation_config) {
- super(config, session);
- this.decoder_merged_session = decoder_merged_session;
- this.generation_config = generation_config;
-
- this.num_decoder_layers = this.config.decoder_layers;
- this.num_decoder_heads = this.config.decoder_attention_heads;
- this.decoder_dim_kv = this.config.d_model / this.num_decoder_heads;
-
- this.num_encoder_layers = this.config.encoder_layers;
- this.num_encoder_heads = this.config.encoder_attention_heads;
- this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads;
- }
-
-}
+export class M2M100ForConditionalGeneration extends M2M100PreTrainedModel { }
//////////////////////////////////////////////////
//////////////////////////////////////////////////
@@ -4576,7 +4826,7 @@ export class Wav2Vec2PreTrainedModel extends PreTrainedModel { };
* **Example:** Load and run a `Wav2Vec2Model` for feature extraction.
*
* ```javascript
- * import { AutoProcessor, AutoModel, read_audio } from '@xenova/transformers';
+ * import { AutoProcessor, AutoModel, read_audio } from '@huggingface/transformers';
*
* // Read and preprocess audio
* const processor = await AutoProcessor.from_pretrained('Xenova/mms-300m');
@@ -4635,6 +4885,92 @@ export class Wav2Vec2ForAudioFrameClassification extends Wav2Vec2PreTrainedModel
}
//////////////////////////////////////////////////
+
+//////////////////////////////////////////////////
+// PyAnnote models
+export class PyAnnotePreTrainedModel extends PreTrainedModel { };
+
+/**
+ * The bare PyAnnote Model transformer outputting raw hidden-states without any specific head on top.
+ */
+export class PyAnnoteModel extends PyAnnotePreTrainedModel { }
+
+/**
+ * PyAnnote Model with a frame classification head on top for tasks like Speaker Diarization.
+ *
+ * **Example:** Load and run a `PyAnnoteForAudioFrameClassification` for speaker diarization.
+ *
+ * ```javascript
+ * import { AutoProcessor, AutoModelForAudioFrameClassification, read_audio } from '@huggingface/transformers';
+ *
+ * // Load model and processor
+ * const model_id = 'onnx-community/pyannote-segmentation-3.0';
+ * const model = await AutoModelForAudioFrameClassification.from_pretrained(model_id);
+ * const processor = await AutoProcessor.from_pretrained(model_id);
+ *
+ * // Read and preprocess audio
+ * const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/mlk.wav';
+ * const audio = await read_audio(url, processor.feature_extractor.config.sampling_rate);
+ * const inputs = await processor(audio);
+ *
+ * // Run model with inputs
+ * const { logits } = await model(inputs);
+ * // {
+ * // logits: Tensor {
+ * // dims: [ 1, 767, 7 ], // [batch_size, num_frames, num_classes]
+ * // type: 'float32',
+ * // data: Float32Array(5369) [ ... ],
+ * // size: 5369
+ * // }
+ * // }
+ *
+ * const result = processor.post_process_speaker_diarization(logits, audio.length);
+ * // [
+ * // [
+ * // { id: 0, start: 0, end: 1.0512535626298245, confidence: 0.8220156481664611 },
+ * // { id: 2, start: 1.0512535626298245, end: 2.3398869619825127, confidence: 0.9008811707860472 },
+ * // ...
+ * // ]
+ * // ]
+ *
+ * // Display result
+ * console.table(result[0], ['start', 'end', 'id', 'confidence']);
+ * // ┌─────────┬────────────────────┬────────────────────┬────┬─────────────────────┐
+ * // │ (index) │ start │ end │ id │ confidence │
+ * // ├─────────┼────────────────────┼────────────────────┼────┼─────────────────────┤
+ * // │ 0 │ 0 │ 1.0512535626298245 │ 0 │ 0.8220156481664611 │
+ * // │ 1 │ 1.0512535626298245 │ 2.3398869619825127 │ 2 │ 0.9008811707860472 │
+ * // │ 2 │ 2.3398869619825127 │ 3.5946089560890773 │ 0 │ 0.7521651315796233 │
+ * // │ 3 │ 3.5946089560890773 │ 4.578039708226655 │ 2 │ 0.8491978128022479 │
+ * // │ 4 │ 4.578039708226655 │ 4.594995410849717 │ 0 │ 0.2935352600416393 │
+ * // │ 5 │ 4.594995410849717 │ 6.121008646925269 │ 3 │ 0.6788051309866024 │
+ * // │ 6 │ 6.121008646925269 │ 6.256654267909762 │ 0 │ 0.37125512393851134 │
+ * // │ 7 │ 6.256654267909762 │ 8.630452635138397 │ 2 │ 0.7467035186353542 │
+ * // │ 8 │ 8.630452635138397 │ 10.088643060721703 │ 0 │ 0.7689364814666032 │
+ * // │ 9 │ 10.088643060721703 │ 12.58113134631177 │ 2 │ 0.9123324509131324 │
+ * // │ 10 │ 12.58113134631177 │ 13.005023911888312 │ 0 │ 0.4828358177572041 │
+ * // └─────────┴────────────────────┴────────────────────┴────┴─────────────────────┘
+ * ```
+ */
+export class PyAnnoteForAudioFrameClassification extends PyAnnotePreTrainedModel {
+ /**
+ * Calls the model on new inputs.
+ * @param {Object} model_inputs The inputs to the model.
+ * @returns {Promise<TokenClassifierOutput>} An object containing the model's output logits for audio frame classification.
+ */
+ async _call(model_inputs) {
+ return new TokenClassifierOutput(await super._call(model_inputs));
+ }
+}
+//////////////////////////////////////////////////
+
+//////////////////////////////////////////////////
+// WeSpeakerResNet models
+export class WeSpeakerResNetPreTrainedModel extends PreTrainedModel { };
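+
+/**
+ * WeSpeaker ResNet model for extracting speaker embeddings (e.g. for speaker verification or diarization).
+ *
+ * **Example:** A sketch of speaker-embedding extraction with `WeSpeakerResNetModel`. The model id below is
+ * illustrative (any WeSpeaker ResNet checkpoint exported to ONNX should work), and the name of the output
+ * tensor depends on the export.
+ * ```javascript
+ * import { AutoProcessor, AutoModel, read_audio } from '@huggingface/transformers';
+ *
+ * // Load processor and model
+ * const model_id = 'Xenova/wespeaker-voxceleb-resnet34-LM';
+ * const processor = await AutoProcessor.from_pretrained(model_id);
+ * const model = await AutoModel.from_pretrained(model_id);
+ *
+ * // Read and preprocess audio
+ * const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/mlk.wav';
+ * const audio = await read_audio(url, processor.feature_extractor.config.sampling_rate);
+ * const inputs = await processor(audio);
+ *
+ * // Compute the speaker embedding
+ * const output = await model(inputs);
+ * ```
+ */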
+export class WeSpeakerResNetModel extends WeSpeakerResNetPreTrainedModel { }
+//////////////////////////////////////////////////
+
+
//////////////////////////////////////////////////
// UniSpeech models
export class UniSpeechPreTrainedModel extends PreTrainedModel { };
@@ -4773,7 +5109,7 @@ export class HubertPreTrainedModel extends PreTrainedModel { }
* **Example:** Load and run a `HubertModel` for feature extraction.
*
* ```javascript
- * import { AutoProcessor, AutoModel, read_audio } from '@xenova/transformers';
+ * import { AutoProcessor, AutoModel, read_audio } from '@huggingface/transformers';
*
* // Read and preprocess audio
* const processor = await AutoProcessor.from_pretrained('Xenova/hubert-base-ls960');
@@ -4837,7 +5173,7 @@ export class WavLMPreTrainedModel extends PreTrainedModel { };
* **Example:** Load and run a `WavLMModel` for feature extraction.
*
* ```javascript
- * import { AutoProcessor, AutoModel, read_audio } from '@xenova/transformers';
+ * import { AutoProcessor, AutoModel, read_audio } from '@huggingface/transformers';
*
* // Read and preprocess audio
* const processor = await AutoProcessor.from_pretrained('Xenova/wavlm-base');
@@ -4892,7 +5228,7 @@ export class WavLMForSequenceClassification extends WavLMPreTrainedModel {
*
* **Example:** Extract speaker embeddings with `WavLMForXVector`.
* ```javascript
- * import { AutoProcessor, AutoModel, read_audio } from '@xenova/transformers';
+ * import { AutoProcessor, AutoModel, read_audio } from '@huggingface/transformers';
*
* // Read and preprocess audio
* const processor = await AutoProcessor.from_pretrained('Xenova/wavlm-base-plus-sv');
@@ -4935,7 +5271,7 @@ export class WavLMForXVector extends WavLMPreTrainedModel {
*
* **Example:** Perform speaker diarization with `WavLMForAudioFrameClassification`.
* ```javascript
- * import { AutoProcessor, AutoModelForAudioFrameClassification, read_audio } from '@xenova/transformers';
+ * import { AutoProcessor, AutoModelForAudioFrameClassification, read_audio } from '@huggingface/transformers';
*
* // Read and preprocess audio
* const processor = await AutoProcessor.from_pretrained('Xenova/wavlm-base-plus-sd');
@@ -4995,16 +5331,16 @@ export class SpeechT5Model extends SpeechT5PreTrainedModel { };
*
* **Example:** Generate speech from text with `SpeechT5ForTextToSpeech`.
* ```javascript
- * import { AutoTokenizer, AutoProcessor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, Tensor } from '@xenova/transformers';
+ * import { AutoTokenizer, AutoProcessor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, Tensor } from '@huggingface/transformers';
*
* // Load the tokenizer and processor
* const tokenizer = await AutoTokenizer.from_pretrained('Xenova/speecht5_tts');
* const processor = await AutoProcessor.from_pretrained('Xenova/speecht5_tts');
*
* // Load the models
- * // NOTE: We use the unquantized versions as they are more accurate
- * const model = await SpeechT5ForTextToSpeech.from_pretrained('Xenova/speecht5_tts', { quantized: false });
- * const vocoder = await SpeechT5HifiGan.from_pretrained('Xenova/speecht5_hifigan', { quantized: false });
+ * // NOTE: We use the full-precision versions as they are more accurate
+ * const model = await SpeechT5ForTextToSpeech.from_pretrained('Xenova/speecht5_tts', { dtype: 'fp32' });
+ * const vocoder = await SpeechT5HifiGan.from_pretrained('Xenova/speecht5_hifigan', { dtype: 'fp32' });
*
* // Load speaker embeddings from URL
* const speaker_embeddings_data = new Float32Array(
@@ -5037,27 +5373,6 @@ export class SpeechT5ForSpeechToText extends SpeechT5PreTrainedModel { }
*/
export class SpeechT5ForTextToSpeech extends SpeechT5PreTrainedModel {
- /**
- * Creates a new instance of the `SpeechT5ForTextToSpeech` class.
- * @param {Object} config The model configuration.
- * @param {any} session session for the model.
- * @param {any} decoder_merged_session session for the decoder.
- * @param {GenerationConfig} generation_config The generation configuration.
- */
- constructor(config, session, decoder_merged_session, generation_config) {
- super(config, session);
- this.decoder_merged_session = decoder_merged_session;
- this.generation_config = generation_config;
-
- this.num_decoder_layers = this.config.decoder_layers;
- this.num_decoder_heads = this.config.decoder_attention_heads;
- this.decoder_dim_kv = this.config.hidden_size / this.num_decoder_heads;
-
- this.num_encoder_layers = this.config.encoder_layers;
- this.num_encoder_heads = this.config.encoder_attention_heads;
- this.encoder_dim_kv = this.config.hidden_size / this.num_encoder_heads;
- }
-
/**
* @typedef {Object} SpeechOutput
* @property {Tensor} [spectrogram] The predicted log-mel spectrogram of shape
@@ -5127,7 +5442,7 @@ export class SpeechT5ForTextToSpeech extends SpeechT5PreTrainedModel {
};
this.addPastKeyValues(decoderFeeds, past_key_values);
- decoder_outputs = await sessionRun(this.decoder_merged_session, decoderFeeds);
+ decoder_outputs = await sessionRun(this.sessions['decoder_model_merged'], decoderFeeds);
past_key_values = this.getPastKeyValues(decoder_outputs, past_key_values);
const { prob, spectrum } = decoder_outputs;
@@ -5142,7 +5457,7 @@ export class SpeechT5ForTextToSpeech extends SpeechT5PreTrainedModel {
}
const spectrogram = cat(spectrogramParts);
- const { waveform } = await sessionRun(vocoder.session, { spectrogram });
+ const { waveform } = await sessionRun(vocoder.sessions['model'], { spectrogram });
return {
spectrogram,
@@ -5165,25 +5480,7 @@ export class SpeechT5HifiGan extends PreTrainedModel {
//////////////////////////////////////////////////
// TrOCR models
-export class TrOCRPreTrainedModel extends PreTrainedModel {
- /**
- * Creates a new instance of the `TrOCRPreTrainedModel` class.
- * @param {Object} config The configuration of the model.
- * @param {any} session The ONNX session containing the model weights.
- * @param {GenerationConfig} generation_config The generation configuration.
- */
- constructor(config, session, generation_config) {
- super(config, session);
- this.generation_config = generation_config;
-
- // config doesn't contain pad_token_id, so we assume it is the eos_token_id
- this.config.pad_token_id = this.config.eos_token_id;
-
- this.num_encoder_layers = this.num_decoder_layers = this.config.decoder_layers;
- this.num_encoder_heads = this.num_decoder_heads = this.config.decoder_attention_heads;
- this.encoder_dim_kv = this.decoder_dim_kv = this.config.d_model / this.num_decoder_heads;
- }
-}
+export class TrOCRPreTrainedModel extends PreTrainedModel { }
/**
* The TrOCR Decoder with a language modeling head.
@@ -5198,25 +5495,7 @@ export class TrOCRForCausalLM extends TrOCRPreTrainedModel { }
/**
* The bare Mistral Model outputting raw hidden-states without any specific head on top.
*/
-export class MistralPreTrainedModel extends PreTrainedModel {
- /**
- * Creates a new instance of the `MistralPreTrainedModel` class.
- * @param {Object} config The configuration of the model.
- * @param {any} session The ONNX session containing the model weights.
- * @param {GenerationConfig} generation_config The generation configuration.
- */
- constructor(config, session, generation_config) {
- super(config, session);
- this.generation_config = generation_config;
-
- // config doesn't contain pad_token_id, so we assume it is the eos_token_id
- this.config.pad_token_id = this.config.eos_token_id
-
- this.num_heads = this.config.num_key_value_heads;
- this.num_layers = this.config.num_hidden_layers;
- this.dim_kv = this.config.hidden_size / this.config.num_attention_heads;
- }
-}
+export class MistralPreTrainedModel extends PreTrainedModel { }
export class MistralModel extends MistralPreTrainedModel { }
@@ -5229,25 +5508,7 @@ export class MistralForCausalLM extends MistralPreTrainedModel { }
/**
* The bare Starcoder2 Model outputting raw hidden-states without any specific head on top.
*/
-export class Starcoder2PreTrainedModel extends PreTrainedModel {
- /**
- * Creates a new instance of the `Starcoder2PreTrainedModel` class.
- * @param {Object} config The configuration of the model.
- * @param {any} session The ONNX session containing the model weights.
- * @param {GenerationConfig} generation_config The generation configuration.
- */
- constructor(config, session, generation_config) {
- super(config, session);
- this.generation_config = generation_config;
-
- // config doesn't contain pad_token_id, so we assume it is the eos_token_id
- this.config.pad_token_id = this.config.eos_token_id
-
- this.num_heads = this.config.num_key_value_heads;
- this.num_layers = this.config.num_hidden_layers;
- this.dim_kv = this.config.hidden_size / this.config.num_attention_heads;
- }
-}
+export class Starcoder2PreTrainedModel extends PreTrainedModel { }
export class Starcoder2Model extends Starcoder2PreTrainedModel { }
@@ -5260,25 +5521,7 @@ export class Starcoder2ForCausalLM extends Starcoder2PreTrainedModel { }
/**
* The bare Falcon Model outputting raw hidden-states without any specific head on top.
*/
-export class FalconPreTrainedModel extends PreTrainedModel {
- /**
- * Creates a new instance of the `FalconPreTrainedModel` class.
- * @param {Object} config The configuration of the model.
- * @param {any} session The ONNX session containing the model weights.
- * @param {GenerationConfig} generation_config The generation configuration.
- */
- constructor(config, session, generation_config) {
- super(config, session);
- this.generation_config = generation_config;
-
- // config doesn't contain pad_token_id, so we assume it is the eos_token_id
- this.config.pad_token_id = this.config.eos_token_id
-
- this.num_heads = this.config.num_attention_heads;
- this.num_layers = this.config.num_hidden_layers;
- this.dim_kv = this.config.hidden_size / this.config.num_attention_heads;
- }
-}
+export class FalconPreTrainedModel extends PreTrainedModel { }
export class FalconModel extends FalconPreTrainedModel { }
@@ -5298,7 +5541,7 @@ export class ClapModel extends ClapPreTrainedModel { }
* **Example:** Compute text embeddings with `ClapTextModelWithProjection`.
*
* ```javascript
- * import { AutoTokenizer, ClapTextModelWithProjection } from '@xenova/transformers';
+ * import { AutoTokenizer, ClapTextModelWithProjection } from '@huggingface/transformers';
*
* // Load tokenizer and text model
* const tokenizer = await AutoTokenizer.from_pretrained('Xenova/clap-htsat-unfused');
@@ -5334,7 +5577,7 @@ export class ClapTextModelWithProjection extends ClapPreTrainedModel {
* **Example:** Compute audio embeddings with `ClapAudioModelWithProjection`.
*
* ```javascript
- * import { AutoProcessor, ClapAudioModelWithProjection, read_audio } from '@xenova/transformers';
+ * import { AutoProcessor, ClapAudioModelWithProjection, read_audio } from '@huggingface/transformers';
*
* // Load processor and audio model
* const processor = await AutoProcessor.from_pretrained('Xenova/clap-htsat-unfused');
@@ -5374,7 +5617,7 @@ export class VitsPreTrainedModel extends PreTrainedModel { }
*
* **Example:** Generate speech from text with `VitsModel`.
* ```javascript
- * import { AutoTokenizer, VitsModel } from '@xenova/transformers';
+ * import { AutoTokenizer, VitsModel } from '@huggingface/transformers';
*
* // Load the tokenizer and model
* const tokenizer = await AutoTokenizer.from_pretrained('Xenova/mms-tts-eng');
@@ -5428,25 +5671,7 @@ export class SegformerForSemanticSegmentation extends SegformerPreTrainedModel {
//////////////////////////////////////////////////
// StableLm models
-export class StableLmPreTrainedModel extends PreTrainedModel {
- /**
- * Creates a new instance of the `StableLmPreTrainedModel` class.
- * @param {Object} config The configuration of the model.
- * @param {any} session The ONNX session containing the model weights.
- * @param {GenerationConfig} generation_config The generation configuration.
- */
- constructor(config, session, generation_config) {
- super(config, session);
- this.generation_config = generation_config;
-
- // config doesn't contain pad_token_id, so we assume it is the eos_token_id
- this.config.pad_token_id = this.config.eos_token_id
-
- this.num_heads = this.config.num_attention_heads;
- this.num_layers = this.config.num_hidden_layers;
- this.dim_kv = this.config.hidden_size / this.num_heads;
- }
-}
+export class StableLmPreTrainedModel extends PreTrainedModel { }
/**
* The bare StableLm Model transformer outputting raw hidden-states without any specific head on top.
@@ -5481,6 +5706,237 @@ export class EfficientNetForImageClassification extends EfficientNetPreTrainedMo
}
//////////////////////////////////////////////////
+//////////////////////////////////////////////////
+// Musicgen models
+export class MusicgenPreTrainedModel extends PreTrainedModel { }
+
+/**
+ * The bare Musicgen decoder model outputting raw hidden-states without any specific head on top.
+ */
+export class MusicgenModel extends MusicgenPreTrainedModel { }
+
+/**
+ * The MusicGen decoder model with a language modelling head on top.
+ */
+export class MusicgenForCausalLM extends MusicgenPreTrainedModel { }
+
+/**
+ * The composite MusicGen model with a text encoder, audio encoder and Musicgen decoder,
+ * for music generation tasks with one or both of text and audio prompts.
+ *
+ * **Example:** Generate music from text with `Xenova/musicgen-small`.
+ * ```javascript
+ * import { AutoTokenizer, MusicgenForConditionalGeneration } from '@huggingface/transformers';
+ *
+ * // Load tokenizer and model
+ * const tokenizer = await AutoTokenizer.from_pretrained('Xenova/musicgen-small');
+ * const model = await MusicgenForConditionalGeneration.from_pretrained(
+ * 'Xenova/musicgen-small', { dtype: 'fp32' }
+ * );
+ *
+ * // Prepare text input
+ * const prompt = '80s pop track with bassy drums and synth';
+ * const inputs = tokenizer(prompt);
+ *
+ * // Generate audio
+ * const audio_values = await model.generate({
+ * ...inputs,
+ * max_new_tokens: 512,
+ * do_sample: true,
+ * guidance_scale: 3,
+ * });
+ *
+ * // (Optional) Write the output to a WAV file
+ * import wavefile from 'wavefile';
+ * import fs from 'fs';
+ *
+ * const wav = new wavefile.WaveFile();
+ * wav.fromScratch(1, model.config.audio_encoder.sampling_rate, '32f', audio_values.data);
+ * fs.writeFileSync('musicgen_out.wav', wav.toBuffer());
+ * ```
+ */
+export class MusicgenForConditionalGeneration extends PreTrainedModel { // NOTE: not MusicgenPreTrainedModel
+ forward_params = [
+ 'input_ids',
+ 'attention_mask',
+ 'encoder_outputs',
+ 'decoder_input_ids',
+ 'decoder_attention_mask',
+ 'past_key_values',
+ ];
+
+ /**
+ * Applies the delay pattern mask to the final ids, then reverts it by filtering out the pad token ids, in a single pass.
+ * @param {Tensor} outputs The output tensor from the model.
+ * @returns {Tensor} The filtered output tensor.
+ */
+ _apply_and_filter_by_delay_pattern_mask(outputs) {
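+ // Note: `outputs` has dims [batch_size * num_codebooks, seq_length], where codebook k of each batch
+ // item is delayed by k steps. For each codebook, the loop below keeps only the positions strictly
+ // after its delay offset (and within the common length), skipping pad tokens, before reshaping to
+ // [batch_size, num_codebooks, new_seq_length].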
+ const [bs_x_codebooks, seqLength] = outputs.dims;
+ const num_codebooks = this.config.decoder.num_codebooks;
+ const upperBound = (seqLength - num_codebooks);
+
+ let newDataSize = 0;
+ for (let i = 0; i < outputs.size; ++i) {
+ if (outputs.data[i] === this.config.decoder.pad_token_id) {
+ continue;
+ }
+
+ const row = (i % seqLength);
+ const col = Math.floor(i / seqLength) % num_codebooks;
+
+ const diff = row - col;
+ if (diff > 0 && diff <= upperBound) {
+ outputs.data[newDataSize++] = outputs.data[i];
+ }
+ }
+
+ const batch_size = Math.floor(bs_x_codebooks / num_codebooks);
+ const inferred = newDataSize / (batch_size * num_codebooks);
+ // TODO: assert `inferred` is an integer
+ return new Tensor(
+ outputs.type,
+ outputs.data.slice(0, newDataSize),
+ [batch_size, num_codebooks, inferred]
+ );
+ }
+
+
+ prepare_inputs_for_generation(input_ids, model_inputs, generation_config) {
+ // apply the delay pattern mask
+ let clonedInputIds = structuredClone(input_ids);
+ for (let i = 0; i < clonedInputIds.length; ++i) {
+ for (let j = 0; j < clonedInputIds[i].length; ++j) {
+ if ((i % this.config.decoder.num_codebooks) >= j) {
+ clonedInputIds[i][j] = BigInt(this.config.decoder.pad_token_id);
+ }
+ }
+ }
+ // for classifier free guidance we need to replicate the decoder args across the batch dim
+ // (we'll split these before sampling)
+ if (generation_config.guidance_scale !== null && generation_config.guidance_scale > 1) {
+ // [batch, seqLength] -> [2 * batch, seqLength]
+ clonedInputIds = clonedInputIds.concat(clonedInputIds);
+ }
+
+ const prepped = super.prepare_inputs_for_generation(clonedInputIds, model_inputs, generation_config);
+ return prepped;
+ }
+
+ /**
+ * Generates sequences of token ids for models with a language modeling head.
+ * @param {import('./generation/parameters.js').GenerationFunctionParameters} options
+ * @returns {Promise<Tensor>} The generated audio codes, decoded into a waveform tensor (`audio_values`).
+ */
+ async generate(options) {
+
+ const output_ids = await super.generate(options);
+
+ // apply the pattern mask to the final ids
+ // tensor: int64[1,batch_size,4,chunk_length]
+ const audio_codes = this._apply_and_filter_by_delay_pattern_mask(
+ /** @type {Tensor} */(output_ids)
+ ).unsqueeze_(0); // append the frame dimension back to the audio codes
+
+ const { audio_values } = await sessionRun(this.sessions['encodec_decode'], { audio_codes })
+
+ return audio_values;
+ }
+}
+//////////////////////////////////////////////////
+
+//////////////////////////////////////////////////
+// MobileNetV1 models
+export class MobileNetV1PreTrainedModel extends PreTrainedModel { }
+
+/**
+ * The bare MobileNetV1 model outputting raw hidden-states without any specific head on top.
+ */
+export class MobileNetV1Model extends MobileNetV1PreTrainedModel { }
+
+/**
+ * MobileNetV1 model with an image classification head on top (a linear layer on top of the pooled features),
+ * e.g. for ImageNet.
+ */
+export class MobileNetV1ForImageClassification extends MobileNetV1PreTrainedModel {
+ /**
+ * @param {any} model_inputs The inputs to the model.
+ * @returns {Promise<SequenceClassifierOutput>} An object containing the model's output logits for image classification.
+ */
+ async _call(model_inputs) {
+ return new SequenceClassifierOutput(await super._call(model_inputs));
+ }
+}
+//////////////////////////////////////////////////
+
+//////////////////////////////////////////////////
+// MobileNetV2 models
+export class MobileNetV2PreTrainedModel extends PreTrainedModel { }
+
+/**
+ * The bare MobileNetV2 model outputting raw hidden-states without any specific head on top.
+ */
+export class MobileNetV2Model extends MobileNetV2PreTrainedModel { }
+
+/**
+ * MobileNetV2 model with an image classification head on top (a linear layer on top of the pooled features),
+ * e.g. for ImageNet.
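+ *
+ * **Example:** A sketch of image classification via the `image-classification` pipeline. The model id is
+ * illustrative; any `mobilenet_v2` checkpoint converted to ONNX should work.
+ * ```javascript
+ * import { pipeline } from '@huggingface/transformers';
+ *
+ * const classifier = await pipeline('image-classification', 'Xenova/mobilenet_v2_1.0_224');
+ * const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/tiger.jpg';
+ * const output = await classifier(url);
+ * // [{ label: '...', score: ... }]
+ * ```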
+ */
+export class MobileNetV2ForImageClassification extends MobileNetV2PreTrainedModel {
+ /**
+ * @param {any} model_inputs The inputs to the model.
+ * @returns {Promise<SequenceClassifierOutput>} An object containing the model's output logits for image classification.
+ */
+ async _call(model_inputs) {
+ return new SequenceClassifierOutput(await super._call(model_inputs));
+ }
+}
+//////////////////////////////////////////////////
+
+//////////////////////////////////////////////////
+// MobileNetV3 models
+export class MobileNetV3PreTrainedModel extends PreTrainedModel { }
+
+/**
+ * The bare MobileNetV3 model outputting raw hidden-states without any specific head on top.
+ */
+export class MobileNetV3Model extends MobileNetV3PreTrainedModel { }
+
+/**
+ * MobileNetV3 model with an image classification head on top (a linear layer on top of the pooled features),
+ * e.g. for ImageNet.
+ */
+export class MobileNetV3ForImageClassification extends MobileNetV3PreTrainedModel {
+ /**
+ * @param {any} model_inputs The inputs to the model.
+ * @returns {Promise<SequenceClassifierOutput>} An object containing the model's output logits for image classification.
+ */
+ async _call(model_inputs) {
+ return new SequenceClassifierOutput(await super._call(model_inputs));
+ }
+}
+//////////////////////////////////////////////////
+
+//////////////////////////////////////////////////
+// MobileNetV4 models
+export class MobileNetV4PreTrainedModel extends PreTrainedModel { }
+
+/**
+ * The bare MobileNetV4 model outputting raw hidden-states without any specific head on top.
+ */
+export class MobileNetV4Model extends MobileNetV4PreTrainedModel { }
+
+/**
+ * MobileNetV4 model with an image classification head on top (a linear layer on top of the pooled features),
+ * e.g. for ImageNet.
+ */
+export class MobileNetV4ForImageClassification extends MobileNetV4PreTrainedModel {
+ /**
+ * @param {any} model_inputs The inputs to the model.
+ * @returns {Promise<SequenceClassifierOutput>} An object containing the model's output logits for image classification.
+ */
+ async _call(model_inputs) {
+ return new SequenceClassifierOutput(await super._call(model_inputs));
+ }
+}
+//////////////////////////////////////////////////
+
//////////////////////////////////////////////////
// Decision Transformer models
export class DecisionTransformerPreTrainedModel extends PreTrainedModel { }
@@ -5515,38 +5971,42 @@ export class PretrainedMixin {
static BASE_IF_FAIL = false;
- /** @type {PreTrainedModel.from_pretrained} */
+ /** @type {typeof PreTrainedModel.from_pretrained} */
static async from_pretrained(pretrained_model_name_or_path, {
- quantized = true,
progress_callback = null,
config = null,
cache_dir = null,
local_files_only = false,
revision = 'main',
model_file_name = null,
+ subfolder = 'onnx',
+ device = null,
+ dtype = null,
+ use_external_data_format = null,
+ session_options = {},
} = {}) {
- let options = {
- quantized,
+ const options = {
progress_callback,
config,
cache_dir,
local_files_only,
revision,
model_file_name,
+ subfolder,
+ device,
+ dtype,
+ use_external_data_format,
+ session_options,
}
- config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options);
- if (!options.config) {
- // If no config was passed, reuse this config for future processing
- options.config = config;
- }
+ options.config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options);
if (!this.MODEL_CLASS_MAPPINGS) {
throw new Error("`MODEL_CLASS_MAPPINGS` not implemented for this type of `AutoClass`: " + this.name);
}
- for (let MODEL_CLASS_MAPPING of this.MODEL_CLASS_MAPPINGS) {
- const modelInfo = MODEL_CLASS_MAPPING.get(config.model_type);
+ for (const MODEL_CLASS_MAPPING of this.MODEL_CLASS_MAPPINGS) {
+ const modelInfo = MODEL_CLASS_MAPPING.get(options.config.model_type);
if (!modelInfo) {
continue; // Item not found in this mapping
}
@@ -5554,10 +6014,10 @@ export class PretrainedMixin {
}
if (this.BASE_IF_FAIL) {
- console.warn(`Unknown model class "${config.model_type}", attempting to construct from base class.`);
+ console.warn(`Unknown model class "${options.config.model_type}", attempting to construct from base class.`);
return await PreTrainedModel.from_pretrained(pretrained_model_name_or_path, options);
} else {
- throw Error(`Unsupported model type: ${config.model_type}`)
+ throw Error(`Unsupported model type: ${options.config.model_type}`)
}
}
}
@@ -5593,10 +6053,17 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
['wavlm', ['WavLMModel', WavLMModel]],
['audio-spectrogram-transformer', ['ASTModel', ASTModel]],
['vits', ['VitsModel', VitsModel]],
+ ['pyannote', ['PyAnnoteModel', PyAnnoteModel]],
+ ['wespeaker-resnet', ['WeSpeakerResNetModel', WeSpeakerResNetModel]],
['detr', ['DetrModel', DetrModel]],
+ ['rt_detr', ['RTDetrModel', RTDetrModel]],
['table-transformer', ['TableTransformerModel', TableTransformerModel]],
['vit', ['ViTModel', ViTModel]],
+ ['pvt', ['PvtModel', PvtModel]],
+ ['vit_msn', ['ViTMSNModel', ViTMSNModel]],
+ ['vit_mae', ['ViTMAEModel', ViTMAEModel]],
+ ['groupvit', ['GroupViTModel', GroupViTModel]],
['fastvit', ['FastViTModel', FastViTModel]],
['mobilevit', ['MobileViTModel', MobileViTModel]],
['mobilevitv2', ['MobileViTV2Model', MobileViTV2Model]],
@@ -5604,6 +6071,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
['owlv2', ['Owlv2Model', Owlv2Model]],
['beit', ['BeitModel', BeitModel]],
['deit', ['DeiTModel', DeiTModel]],
+ ['hiera', ['HieraModel', HieraModel]],
['convnext', ['ConvNextModel', ConvNextModel]],
['convnextv2', ['ConvNextV2Model', ConvNextV2Model]],
['dinov2', ['Dinov2Model', Dinov2Model]],
@@ -5619,6 +6087,13 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
['efficientnet', ['EfficientNetModel', EfficientNetModel]],
['decision_transformer', ['DecisionTransformerModel', DecisionTransformerModel]],
+
+ ['mobilenet_v1', ['MobileNetV1Model', MobileNetV1Model]],
+ ['mobilenet_v2', ['MobileNetV2Model', MobileNetV2Model]],
+ ['mobilenet_v3', ['MobileNetV3Model', MobileNetV3Model]],
+ ['mobilenet_v4', ['MobileNetV4Model', MobileNetV4Model]],
+
+ ['maskformer', ['MaskFormerModel', MaskFormerModel]],
]);
const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([
@@ -5637,6 +6112,7 @@ const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([
const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
['bloom', ['BloomModel', BloomModel]],
+ ['jais', ['JAISModel', JAISModel]],
['gpt2', ['GPT2Model', GPT2Model]],
['gptj', ['GPTJModel', GPTJModel]],
['gpt_bigcode', ['GPTBigCodeModel', GPTBigCodeModel]],
@@ -5644,13 +6120,20 @@ const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
['gpt_neox', ['GPTNeoXModel', GPTNeoXModel]],
['codegen', ['CodeGenModel', CodeGenModel]],
['llama', ['LlamaModel', LlamaModel]],
+ ['granite', ['GraniteModel', GraniteModel]],
+ ['cohere', ['CohereModel', CohereModel]],
+ ['gemma', ['GemmaModel', GemmaModel]],
+ ['gemma2', ['Gemma2Model', Gemma2Model]],
+ ['openelm', ['OpenELMModel', OpenELMModel]],
['qwen2', ['Qwen2Model', Qwen2Model]],
['phi', ['PhiModel', PhiModel]],
+ ['phi3', ['Phi3Model', Phi3Model]],
['mpt', ['MptModel', MptModel]],
['opt', ['OPTModel', OPTModel]],
['mistral', ['MistralModel', MistralModel]],
['starcoder2', ['Starcoder2Model', Starcoder2Model]],
['falcon', ['FalconModel', FalconModel]],
+ ['stablelm', ['StableLmModel', StableLmModel]],
]);
const MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = new Map([
@@ -5664,6 +6147,7 @@ const MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = new Map([
const MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = new Map([
['vits', ['VitsModel', VitsModel]],
+ ['musicgen', ['MusicgenForConditionalGeneration', MusicgenForConditionalGeneration]],
]);
const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([
@@ -5715,17 +6199,24 @@ const MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = new Map([
['blenderbot-small', ['BlenderbotSmallForConditionalGeneration', BlenderbotSmallForConditionalGeneration]],
]);
-const MODEL_WITH_LM_HEAD_MAPPING_NAMES = new Map([
+const MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = new Map([
['bloom', ['BloomForCausalLM', BloomForCausalLM]],
['gpt2', ['GPT2LMHeadModel', GPT2LMHeadModel]],
+ ['jais', ['JAISLMHeadModel', JAISLMHeadModel]],
['gptj', ['GPTJForCausalLM', GPTJForCausalLM]],
['gpt_bigcode', ['GPTBigCodeForCausalLM', GPTBigCodeForCausalLM]],
['gpt_neo', ['GPTNeoForCausalLM', GPTNeoForCausalLM]],
['gpt_neox', ['GPTNeoXForCausalLM', GPTNeoXForCausalLM]],
['codegen', ['CodeGenForCausalLM', CodeGenForCausalLM]],
['llama', ['LlamaForCausalLM', LlamaForCausalLM]],
+ ['granite', ['GraniteForCausalLM', GraniteForCausalLM]],
+ ['cohere', ['CohereForCausalLM', CohereForCausalLM]],
+ ['gemma', ['GemmaForCausalLM', GemmaForCausalLM]],
+ ['gemma2', ['Gemma2ForCausalLM', Gemma2ForCausalLM]],
+ ['openelm', ['OpenELMForCausalLM', OpenELMForCausalLM]],
['qwen2', ['Qwen2ForCausalLM', Qwen2ForCausalLM]],
['phi', ['PhiForCausalLM', PhiForCausalLM]],
+ ['phi3', ['Phi3ForCausalLM', Phi3ForCausalLM]],
['mpt', ['MptForCausalLM', MptForCausalLM]],
['opt', ['OPTForCausalLM', OPTForCausalLM]],
['mbart', ['MBartForCausalLM', MBartForCausalLM]],
@@ -5777,17 +6268,26 @@ const MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = new Map([
['vision-encoder-decoder', ['VisionEncoderDecoderModel', VisionEncoderDecoderModel]],
]);
+const MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = new Map([
+ ['llava', ['LlavaForConditionalGeneration', LlavaForConditionalGeneration]],
+ ['moondream1', ['Moondream1ForConditionalGeneration', Moondream1ForConditionalGeneration]],
+ ['florence2', ['Florence2ForConditionalGeneration', Florence2ForConditionalGeneration]],
+]);
+
const MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = new Map([
['vision-encoder-decoder', ['VisionEncoderDecoderModel', VisionEncoderDecoderModel]],
]);
const MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = new Map([
['vit', ['ViTForImageClassification', ViTForImageClassification]],
+ ['pvt', ['PvtForImageClassification', PvtForImageClassification]],
+ ['vit_msn', ['ViTMSNForImageClassification', ViTMSNForImageClassification]],
['fastvit', ['FastViTForImageClassification', FastViTForImageClassification]],
['mobilevit', ['MobileViTForImageClassification', MobileViTForImageClassification]],
['mobilevitv2', ['MobileViTV2ForImageClassification', MobileViTV2ForImageClassification]],
['beit', ['BeitForImageClassification', BeitForImageClassification]],
['deit', ['DeiTForImageClassification', DeiTForImageClassification]],
+ ['hiera', ['HieraForImageClassification', HieraForImageClassification]],
['convnext', ['ConvNextForImageClassification', ConvNextForImageClassification]],
['convnextv2', ['ConvNextV2ForImageClassification', ConvNextV2ForImageClassification]],
['dinov2', ['Dinov2ForImageClassification', Dinov2ForImageClassification]],
@@ -5795,10 +6295,15 @@ const MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = new Map([
['swin', ['SwinForImageClassification', SwinForImageClassification]],
['segformer', ['SegformerForImageClassification', SegformerForImageClassification]],
['efficientnet', ['EfficientNetForImageClassification', EfficientNetForImageClassification]],
+ ['mobilenet_v1', ['MobileNetV1ForImageClassification', MobileNetV1ForImageClassification]],
+ ['mobilenet_v2', ['MobileNetV2ForImageClassification', MobileNetV2ForImageClassification]],
+ ['mobilenet_v3', ['MobileNetV3ForImageClassification', MobileNetV3ForImageClassification]],
+ ['mobilenet_v4', ['MobileNetV4ForImageClassification', MobileNetV4ForImageClassification]],
]);
const MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = new Map([
['detr', ['DetrForObjectDetection', DetrForObjectDetection]],
+ ['rt_detr', ['RTDetrForObjectDetection', RTDetrForObjectDetection]],
['table-transformer', ['TableTransformerForObjectDetection', TableTransformerForObjectDetection]],
['yolos', ['YolosForObjectDetection', YolosForObjectDetection]],
]);
@@ -5809,12 +6314,19 @@ const MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = new Map([
]);
const MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = new Map([
+ // TODO: Do not add new models here
['detr', ['DetrForSegmentation', DetrForSegmentation]],
['clipseg', ['CLIPSegForImageSegmentation', CLIPSegForImageSegmentation]],
]);
const MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = new Map([
['segformer', ['SegformerForSemanticSegmentation', SegformerForSemanticSegmentation]],
+ ['sapiens', ['SapiensForSemanticSegmentation', SapiensForSemanticSegmentation]],
+]);
+
+const MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = new Map([
+ ['detr', ['DetrForSegmentation', DetrForSegmentation]],
+ ['maskformer', ['MaskFormerForInstanceSegmentation', MaskFormerForInstanceSegmentation]],
]);
const MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = new Map([
@@ -5848,6 +6360,7 @@ const MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = new Map([
['unispeech-sat', ['UniSpeechSatForAudioFrameClassification', UniSpeechSatForAudioFrameClassification]],
['wavlm', ['WavLMForAudioFrameClassification', WavLMForAudioFrameClassification]],
['wav2vec2', ['Wav2Vec2ForAudioFrameClassification', Wav2Vec2ForAudioFrameClassification]],
+ ['pyannote', ['PyAnnoteForAudioFrameClassification', PyAnnoteForAudioFrameClassification]],
]);
const MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES = new Map([
@@ -5862,6 +6375,12 @@ const MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = new Map([
['dpt', ['DPTForDepthEstimation', DPTForDepthEstimation]],
['depth_anything', ['DepthAnythingForDepthEstimation', DepthAnythingForDepthEstimation]],
['glpn', ['GLPNForDepthEstimation', GLPNForDepthEstimation]],
+ ['sapiens', ['SapiensForDepthEstimation', SapiensForDepthEstimation]],
+ ['depth_pro', ['DepthProForDepthEstimation', DepthProForDepthEstimation]],
+])
+
+const MODEL_FOR_NORMAL_ESTIMATION_MAPPING_NAMES = new Map([
+ ['sapiens', ['SapiensForNormalEstimation', SapiensForNormalEstimation]],
])
// NOTE: This is custom to Transformers.js, and is necessary because certain models
@@ -5879,16 +6398,19 @@ const MODEL_CLASS_TYPE_MAPPING = [
[MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
[MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
- [MODEL_WITH_LM_HEAD_MAPPING_NAMES, MODEL_TYPES.DecoderOnly],
+ [MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES.DecoderOnly],
[MODEL_FOR_MASKED_LM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Vision2Seq],
+ [MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, MODEL_TYPES.ImageTextToText],
[MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
+ [MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
+ [MODEL_FOR_NORMAL_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.MaskGeneration],
@@ -5913,6 +6435,10 @@ for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) {
}
const CUSTOM_MAPPING = [
+ // OVERRIDE:
+ // TODO: Refactor to allow class to specify model
+ ['MusicgenForConditionalGeneration', MusicgenForConditionalGeneration, MODEL_TYPES.Musicgen],
+
['CLIPTextModelWithProjection', CLIPTextModelWithProjection, MODEL_TYPES.EncoderOnly],
['SiglipTextModel', SiglipTextModel, MODEL_TYPES.EncoderOnly],
['ClapTextModelWithProjection', ClapTextModelWithProjection, MODEL_TYPES.EncoderOnly],
@@ -5930,7 +6456,7 @@ for (const [name, model, type] of CUSTOM_MAPPING) {
* The chosen model class is determined by the type specified in the model config.
*
* @example
- * let model = await AutoModel.from_pretrained('bert-base-uncased');
+ * let model = await AutoModel.from_pretrained('Xenova/bert-base-uncased');
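+ *
+ * @example
+ * // A sketch of the loading options introduced here (device and dtype); WebGPU availability
+ * // depends on the environment, and the dtype must match the files provided by the checkpoint.
+ * let model = await AutoModel.from_pretrained('Xenova/bert-base-uncased', { device: 'webgpu', dtype: 'fp32' });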
*/
export class AutoModel extends PretrainedMixin {
/** @type {Map[]} */
@@ -5944,7 +6470,7 @@ export class AutoModel extends PretrainedMixin {
* The chosen model class is determined by the type specified in the model config.
*
* @example
- * let model = await AutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased-finetuned-sst-2-english');
+ * let model = await AutoModelForSequenceClassification.from_pretrained('Xenova/distilbert-base-uncased-finetuned-sst-2-english');
*/
export class AutoModelForSequenceClassification extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES];
@@ -5955,7 +6481,7 @@ export class AutoModelForSequenceClassification extends PretrainedMixin {
* The chosen model class is determined by the type specified in the model config.
*
* @example
- * let model = await AutoModelForTokenClassification.from_pretrained('Davlan/distilbert-base-multilingual-cased-ner-hrl');
+ * let model = await AutoModelForTokenClassification.from_pretrained('Xenova/distilbert-base-multilingual-cased-ner-hrl');
*/
export class AutoModelForTokenClassification extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES];
@@ -5966,7 +6492,7 @@ export class AutoModelForTokenClassification extends PretrainedMixin {
* The chosen model class is determined by the type specified in the model config.
*
* @example
- * let model = await AutoModelForSeq2SeqLM.from_pretrained('t5-small');
+ * let model = await AutoModelForSeq2SeqLM.from_pretrained('Xenova/t5-small');
*/
export class AutoModelForSeq2SeqLM extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES];
@@ -6010,10 +6536,10 @@ export class AutoModelForTextToWaveform extends PretrainedMixin {
* The chosen model class is determined by the type specified in the model config.
*
* @example
- * let model = await AutoModelForCausalLM.from_pretrained('gpt2');
+ * let model = await AutoModelForCausalLM.from_pretrained('Xenova/gpt2');
*/
export class AutoModelForCausalLM extends PretrainedMixin {
- static MODEL_CLASS_MAPPINGS = [MODEL_WITH_LM_HEAD_MAPPING_NAMES];
+ static MODEL_CLASS_MAPPINGS = [MODEL_FOR_CAUSAL_LM_MAPPING_NAMES];
}
/**
@@ -6021,7 +6547,7 @@ export class AutoModelForCausalLM extends PretrainedMixin {
* The chosen model class is determined by the type specified in the model config.
*
* @example
- * let model = await AutoModelForMaskedLM.from_pretrained('bert-base-uncased');
+ * let model = await AutoModelForMaskedLM.from_pretrained('Xenova/bert-base-uncased');
*/
export class AutoModelForMaskedLM extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_MASKED_LM_MAPPING_NAMES];
@@ -6032,7 +6558,7 @@ export class AutoModelForMaskedLM extends PretrainedMixin {
* The chosen model class is determined by the type specified in the model config.
*
* @example
- * let model = await AutoModelForQuestionAnswering.from_pretrained('distilbert-base-cased-distilled-squad');
+ * let model = await AutoModelForQuestionAnswering.from_pretrained('Xenova/distilbert-base-cased-distilled-squad');
*/
export class AutoModelForQuestionAnswering extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES];
@@ -6043,7 +6569,7 @@ export class AutoModelForQuestionAnswering extends PretrainedMixin {
* The chosen model class is determined by the type specified in the model config.
*
* @example
- * let model = await AutoModelForVision2Seq.from_pretrained('nlpconnect/vit-gpt2-image-captioning');
+ * let model = await AutoModelForVision2Seq.from_pretrained('Xenova/vit-gpt2-image-captioning');
*/
export class AutoModelForVision2Seq extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES];
@@ -6054,7 +6580,7 @@ export class AutoModelForVision2Seq extends PretrainedMixin {
* The chosen model class is determined by the type specified in the model config.
*
* @example
- * let model = await AutoModelForImageClassification.from_pretrained('google/vit-base-patch16-224');
+ * let model = await AutoModelForImageClassification.from_pretrained('Xenova/vit-base-patch16-224');
*/
export class AutoModelForImageClassification extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES];
@@ -6065,7 +6591,7 @@ export class AutoModelForImageClassification extends PretrainedMixin {
* The chosen model class is determined by the type specified in the model config.
*
* @example
- * let model = await AutoModelForImageSegmentation.from_pretrained('facebook/detr-resnet-50-panoptic');
+ * let model = await AutoModelForImageSegmentation.from_pretrained('Xenova/detr-resnet-50-panoptic');
*/
export class AutoModelForImageSegmentation extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES];
@@ -6082,12 +6608,23 @@ export class AutoModelForSemanticSegmentation extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES];
}
+/**
+ * Helper class which is used to instantiate pretrained universal image segmentation models with the `from_pretrained` function.
+ * The chosen model class is determined by the type specified in the model config.
+ *
+ * @example
+ * let model = await AutoModelForUniversalSegmentation.from_pretrained('hf-internal-testing/tiny-random-MaskFormerForInstanceSegmentation');
+ */
+export class AutoModelForUniversalSegmentation extends PretrainedMixin {
+ static MODEL_CLASS_MAPPINGS = [MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES];
+}
+
/**
* Helper class which is used to instantiate pretrained object detection models with the `from_pretrained` function.
* The chosen model class is determined by the type specified in the model config.
*
* @example
- * let model = await AutoModelForObjectDetection.from_pretrained('facebook/detr-resnet-50');
+ * let model = await AutoModelForObjectDetection.from_pretrained('Xenova/detr-resnet-50');
*/
export class AutoModelForObjectDetection extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES];
@@ -6141,6 +6678,10 @@ export class AutoModelForDepthEstimation extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES];
}
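+
+/**
+ * Helper class which is used to instantiate pretrained normal estimation models with the `from_pretrained` function.
+ * The chosen model class is determined by the type specified in the model config.
+ *
+ * @example
+ * // A sketch; the model id below is illustrative, and stands in for any ONNX export of a
+ * // Sapiens normal-estimation checkpoint.
+ * let model = await AutoModelForNormalEstimation.from_pretrained('onnx-community/sapiens-normal-estimation');
+ */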
+export class AutoModelForNormalEstimation extends PretrainedMixin {
+ static MODEL_CLASS_MAPPINGS = [MODEL_FOR_NORMAL_ESTIMATION_MAPPING_NAMES];
+}
+
export class AutoModelForImageFeatureExtraction extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES];
}
diff --git a/src/models/whisper/common_whisper.js b/src/models/whisper/common_whisper.js
new file mode 100644
index 000000000..df4cce4d5
--- /dev/null
+++ b/src/models/whisper/common_whisper.js
@@ -0,0 +1,151 @@
+
+
+const WHISPER_LANGUAGES = [
+ ["en", "english"],
+ ["zh", "chinese"],
+ ["de", "german"],
+ ["es", "spanish"],
+ ["ru", "russian"],
+ ["ko", "korean"],
+ ["fr", "french"],
+ ["ja", "japanese"],
+ ["pt", "portuguese"],
+ ["tr", "turkish"],
+ ["pl", "polish"],
+ ["ca", "catalan"],
+ ["nl", "dutch"],
+ ["ar", "arabic"],
+ ["sv", "swedish"],
+ ["it", "italian"],
+ ["id", "indonesian"],
+ ["hi", "hindi"],
+ ["fi", "finnish"],
+ ["vi", "vietnamese"],
+ ["he", "hebrew"],
+ ["uk", "ukrainian"],
+ ["el", "greek"],
+ ["ms", "malay"],
+ ["cs", "czech"],
+ ["ro", "romanian"],
+ ["da", "danish"],
+ ["hu", "hungarian"],
+ ["ta", "tamil"],
+ ["no", "norwegian"],
+ ["th", "thai"],
+ ["ur", "urdu"],
+ ["hr", "croatian"],
+ ["bg", "bulgarian"],
+ ["lt", "lithuanian"],
+ ["la", "latin"],
+ ["mi", "maori"],
+ ["ml", "malayalam"],
+ ["cy", "welsh"],
+ ["sk", "slovak"],
+ ["te", "telugu"],
+ ["fa", "persian"],
+ ["lv", "latvian"],
+ ["bn", "bengali"],
+ ["sr", "serbian"],
+ ["az", "azerbaijani"],
+ ["sl", "slovenian"],
+ ["kn", "kannada"],
+ ["et", "estonian"],
+ ["mk", "macedonian"],
+ ["br", "breton"],
+ ["eu", "basque"],
+ ["is", "icelandic"],
+ ["hy", "armenian"],
+ ["ne", "nepali"],
+ ["mn", "mongolian"],
+ ["bs", "bosnian"],
+ ["kk", "kazakh"],
+ ["sq", "albanian"],
+ ["sw", "swahili"],
+ ["gl", "galician"],
+ ["mr", "marathi"],
+ ["pa", "punjabi"],
+ ["si", "sinhala"],
+ ["km", "khmer"],
+ ["sn", "shona"],
+ ["yo", "yoruba"],
+ ["so", "somali"],
+ ["af", "afrikaans"],
+ ["oc", "occitan"],
+ ["ka", "georgian"],
+ ["be", "belarusian"],
+ ["tg", "tajik"],
+ ["sd", "sindhi"],
+ ["gu", "gujarati"],
+ ["am", "amharic"],
+ ["yi", "yiddish"],
+ ["lo", "lao"],
+ ["uz", "uzbek"],
+ ["fo", "faroese"],
+ ["ht", "haitian creole"],
+ ["ps", "pashto"],
+ ["tk", "turkmen"],
+ ["nn", "nynorsk"],
+ ["mt", "maltese"],
+ ["sa", "sanskrit"],
+ ["lb", "luxembourgish"],
+ ["my", "myanmar"],
+ ["bo", "tibetan"],
+ ["tl", "tagalog"],
+ ["mg", "malagasy"],
+ ["as", "assamese"],
+ ["tt", "tatar"],
+ ["haw", "hawaiian"],
+ ["ln", "lingala"],
+ ["ha", "hausa"],
+ ["ba", "bashkir"],
+ ["jw", "javanese"],
+ ["su", "sundanese"],
+]
+
+// @ts-ignore
+export const WHISPER_LANGUAGE_MAPPING = new Map(WHISPER_LANGUAGES);
+// @ts-ignore
+export const WHISPER_TO_LANGUAGE_CODE_MAPPING = new Map([
+ ...WHISPER_LANGUAGES.map(([k, v]) => [v, k]),
+ ...[
+ ["burmese", "my"],
+ ["valencian", "ca"],
+ ["flemish", "nl"],
+ ["haitian", "ht"],
+ ["letzeburgesch", "lb"],
+ ["pushto", "ps"],
+ ["panjabi", "pa"],
+ ["moldavian", "ro"],
+ ["moldovan", "ro"],
+ ["sinhalese", "si"],
+ ["castilian", "es"],
+ ]
+]);
+
+/**
+ * @param {string} language The language name or code
+ * @returns {string} The language code
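+ *
+ * @example
+ * // Accepts either a language name or a language code (see the mappings above):
+ * whisper_language_to_code('english');   // 'en'
+ * whisper_language_to_code('fr');        // 'fr'
+ * whisper_language_to_code('castilian'); // 'es'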
+ */
+export function whisper_language_to_code(language) {
+ language = language.toLowerCase();
+
+ // Map to code from user-friendly name (e.g., "english" -> "en")
+ let language_code = WHISPER_TO_LANGUAGE_CODE_MAPPING.get(language);
+
+ if (language_code === undefined) {
+ // User provided something that is not a language name
+
+ if (WHISPER_LANGUAGE_MAPPING.has(language)) {
+ // User provided the language code directly (e.g., "en")
+ language_code = language;
+
+ } else {
+ // User provided something that is not a language code or name
+ const is_language_code = language.length === 2;
+ const langs = is_language_code ? WHISPER_LANGUAGE_MAPPING.keys() : WHISPER_LANGUAGE_MAPPING.values();
+
+ throw new Error(`Language "${language}" is not supported. Must be one of: ${JSON.stringify([...langs])}`);
+ }
+ }
+ return language_code;
+}
diff --git a/src/models/whisper/generation_whisper.js b/src/models/whisper/generation_whisper.js
new file mode 100644
index 000000000..690455ff7
--- /dev/null
+++ b/src/models/whisper/generation_whisper.js
@@ -0,0 +1,89 @@
+import { GenerationConfig } from "../../generation/configuration_utils.js";
+
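+/**
+ * Generation configuration specific to Whisper models, extending the base `GenerationConfig` with
+ * speech-recognition options such as `language`, `task`, and timestamp handling.
+ *
+ * **Example:** A sketch of how these options are typically forwarded to generation by the
+ * automatic-speech-recognition pipeline (the model id and audio URL are reused from other examples
+ * in these docs; option support depends on the checkpoint).
+ * ```javascript
+ * import { pipeline } from '@huggingface/transformers';
+ *
+ * const transcriber = await pipeline('automatic-speech-recognition', 'Xenova/whisper-tiny');
+ * const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/mlk.wav';
+ * const output = await transcriber(url, { language: 'english', task: 'transcribe', return_timestamps: true });
+ * ```
+ */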
+export class WhisperGenerationConfig extends GenerationConfig {
+
+ /**
+ * Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`.
+ * @type {boolean}
+ */
+ return_timestamps = null;
+
+ /**
+ * Whether to return token-level timestamps
+ * with the text. This can be used with or without the `return_timestamps` option. To get word-level
+ * timestamps, use the tokenizer to group the tokens into words.
+ * @type {boolean}
+ */
+ return_token_timestamps = null;
+
+ /**
+ * The number of audio frames available in this chunk. This is only used for generating word-level timestamps.
+ * @type {number}
+ */
+ num_frames = null;
+
+ /**
+ * Alignment heads to predict word-level timestamps. This is a list of [layer, head] pairs that
+ * select the cross-attention heads that are highly correlated to word-level timing.
+ * @type {[number, number][]}
+ */
+ alignment_heads = null;
+
+ /**
+ * Task to use for generation, either "translate" or "transcribe".
+ * @type {string}
+ */
+ task = null;
+
+ /**
+ * Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`.
+ * You can find all the possible language tokens in the `model.generation_config.lang_to_id` dictionary.
+ * @type {string}
+ */
+ language = null;
+
+ /**
+ * The id of the `"<|notimestamps|>"` token.
+ * @type {number}
+ */
+ no_timestamps_token_id = null;
+
+ /**
+ * Rank-1 list of token IDs created by passing text to [`~WhisperProcessor.get_prompt_ids`] that is
+ * provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for
+ * transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words
+ * correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value.
+ * @type {number[]}
+ */
+ prompt_ids = null;
+
+ /**
+ * Whether the model is multilingual or not.
+ * @type {boolean}
+ */
+ is_multilingual = null;
+
+ /**
+ * (Optional) A mapping from language tokens to their corresponding IDs.
+ * Only required if the model is multilingual.
+ * @type {Record<string, number>|null}
+ */
+ lang_to_id = null;
+
+ /**
+ * (Optional) A mapping from task tokens to their corresponding IDs.
+ * @type {Record|null}
+ */
+ task_to_id = null;
+
+ /**
+ * Used to set the maximum value of the initial timestamp. This is used to prevent the model from
+ * predicting timestamps that are too far in the future.
+ * @type {number}
+ */
+ max_initial_timestamp_index = 1;
+}
+
+/**
+ * @typedef {import('../../generation/parameters.js').GenerationFunctionParameters & {generation_config: WhisperGenerationConfig} & WhisperGenerationConfig} WhisperGenerationFunctionParameters
+ */
diff --git a/src/ops/registry.js b/src/ops/registry.js
new file mode 100644
index 000000000..9b65fa4a8
--- /dev/null
+++ b/src/ops/registry.js
@@ -0,0 +1,103 @@
+import { createInferenceSession } from "../backends/onnx.js";
+import { Tensor } from "../utils/tensor.js";
+
+/**
+ * Asynchronously creates a wrapper function for running an ONNX inference session.
+ *
+ * @param {number[]} session_bytes The session data in bytes.
+ * @param {import('onnxruntime-common').InferenceSession.SessionOptions} session_options The options for the ONNX session.
+ * @template {string | [string] | string[]} T
+ * @param {T} names The name(s) of the output tensor(s).
+ *
+ * @returns {Promise<function(Record<string, Tensor>): Promise<Tensor|Tensor[]>>}
+ * The wrapper function for running the ONNX inference session.
+ */
+const wrap = async (session_bytes, session_options, names) => {
+ const session = await createInferenceSession(
+ new Uint8Array(session_bytes), session_options,
+ );
+ return /** @type {any} */(async (/** @type {Record<string, Tensor>} */ inputs) => {
+ const ortFeed = Object.fromEntries(Object.entries(inputs).map(([k, v]) => [k, v.ort_tensor]));
+ const outputs = await session.run(ortFeed);
+
+ if (Array.isArray(names)) {
+ return names.map((n) => new Tensor(outputs[n]));
+ } else {
+ return new Tensor(outputs[/** @type {string} */(names)]);
+ }
+ })
+}
+
+// In-memory registry of initialized ONNX operators
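+//
+// Each getter below lazily deserializes a small hand-written ONNX graph and exposes it as an async
+// function over named input tensors. For example (a sketch, using the `matmul` op defined below):
+//
+//   const matmul = await TensorOpRegistry.matmul;
+//   const a = new Tensor('float32', [1, 2, 3, 4], [2, 2]);
+//   const b = new Tensor('float32', [5, 6, 7, 8], [2, 2]);
+//   const c = await matmul({ a, b }); // Tensor of dims [2, 2]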
+export class TensorOpRegistry {
+ static session_options = {
+ // TODO: Allow for multiple execution providers
+ // executionProviders: ['webgpu'],
+ };
+
+ static get bilinear_interpolate_4d() {
+ if (!this._bilinear_interpolate_4d) {
+ this._bilinear_interpolate_4d = wrap(
+ [8, 9, 18, 0, 58, 128, 1, 10, 40, 10, 1, 120, 10, 0, 10, 0, 10, 1, 115, 18, 1, 121, 34, 6, 82, 101, 115, 105, 122, 101, 42, 17, 10, 4, 109, 111, 100, 101, 34, 6, 108, 105, 110, 101, 97, 114, 160, 1, 3, 18, 1, 114, 90, 31, 10, 1, 120, 18, 26, 10, 24, 8, 1, 18, 20, 10, 3, 18, 1, 98, 10, 3, 18, 1, 99, 10, 3, 18, 1, 104, 10, 3, 18, 1, 119, 90, 15, 10, 1, 115, 18, 10, 10, 8, 8, 7, 18, 4, 10, 2, 8, 4, 98, 31, 10, 1, 121, 18, 26, 10, 24, 8, 1, 18, 20, 10, 3, 18, 1, 98, 10, 3, 18, 1, 99, 10, 3, 18, 1, 104, 10, 3, 18, 1, 119, 66, 2, 16, 20],
+ this.session_options,
+ 'y',
+ );
+ }
+ return this._bilinear_interpolate_4d;
+ }
+
+ static get bicubic_interpolate_4d() {
+ if (!this._bicubic_interpolate_4d) {
+ this._bicubic_interpolate_4d = wrap(
+ [8, 9, 18, 0, 58, 127, 10, 39, 10, 1, 120, 10, 0, 10, 0, 10, 1, 115, 18, 1, 121, 34, 6, 82, 101, 115, 105, 122, 101, 42, 16, 10, 4, 109, 111, 100, 101, 34, 5, 99, 117, 98, 105, 99, 160, 1, 3, 18, 1, 114, 90, 31, 10, 1, 120, 18, 26, 10, 24, 8, 1, 18, 20, 10, 3, 18, 1, 98, 10, 3, 18, 1, 99, 10, 3, 18, 1, 104, 10, 3, 18, 1, 119, 90, 15, 10, 1, 115, 18, 10, 10, 8, 8, 7, 18, 4, 10, 2, 8, 4, 98, 31, 10, 1, 121, 18, 26, 10, 24, 8, 1, 18, 20, 10, 3, 18, 1, 98, 10, 3, 18, 1, 99, 10, 3, 18, 1, 104, 10, 3, 18, 1, 119, 66, 2, 16, 20],
+ this.session_options,
+ 'y',
+ );
+ }
+ return this._bicubic_interpolate_4d;
+ }
+
+ static get matmul() {
+ if (!this._matmul) {
+ this._matmul = wrap(
+ [8, 9, 18, 0, 58, 55, 10, 17, 10, 1, 97, 10, 1, 98, 18, 1, 99, 34, 6, 77, 97, 116, 77, 117, 108, 18, 1, 114, 90, 9, 10, 1, 97, 18, 4, 10, 2, 8, 1, 90, 9, 10, 1, 98, 18, 4, 10, 2, 8, 1, 98, 9, 10, 1, 99, 18, 4, 10, 2, 8, 1, 66, 2, 16, 20],
+ this.session_options,
+ 'c',
+ );
+ }
+ return this._matmul;
+ }
+
+ static get stft() {
+ if (!this._stft) {
+ this._stft = wrap(
+ [8, 7, 18, 0, 58, 148, 1, 10, 38, 10, 1, 115, 10, 1, 106, 10, 1, 119, 10, 1, 108, 18, 1, 111, 34, 4, 83, 84, 70, 84, 42, 15, 10, 8, 111, 110, 101, 115, 105, 100, 101, 100, 24, 1, 160, 1, 2, 18, 1, 115, 90, 26, 10, 1, 115, 18, 21, 10, 19, 8, 1, 18, 15, 10, 3, 18, 1, 98, 10, 3, 18, 1, 115, 10, 3, 18, 1, 99, 90, 11, 10, 1, 106, 18, 6, 10, 4, 8, 7, 18, 0, 90, 16, 10, 1, 119, 18, 11, 10, 9, 8, 1, 18, 5, 10, 3, 18, 1, 119, 90, 11, 10, 1, 108, 18, 6, 10, 4, 8, 7, 18, 0, 98, 31, 10, 1, 111, 18, 26, 10, 24, 8, 1, 18, 20, 10, 3, 18, 1, 98, 10, 3, 18, 1, 102, 10, 3, 18, 1, 100, 10, 3, 18, 1, 99, 66, 2, 16, 17],
+ this.session_options,
+ 'o',
+ )
+ }
+ return this._stft;
+ }
+
+ static get rfft() {
+ if (!this._rfft) {
+ this._rfft = wrap(
+ [8, 9, 18, 0, 58, 97, 10, 33, 10, 1, 120, 10, 0, 10, 1, 97, 18, 1, 121, 34, 3, 68, 70, 84, 42, 15, 10, 8, 111, 110, 101, 115, 105, 100, 101, 100, 24, 1, 160, 1, 2, 18, 1, 100, 90, 21, 10, 1, 120, 18, 16, 10, 14, 8, 1, 18, 10, 10, 3, 18, 1, 115, 10, 3, 18, 1, 99, 90, 11, 10, 1, 97, 18, 6, 10, 4, 8, 7, 18, 0, 98, 21, 10, 1, 121, 18, 16, 10, 14, 8, 1, 18, 10, 10, 3, 18, 1, 115, 10, 3, 18, 1, 99, 66, 2, 16, 20],
+ this.session_options,
+ 'y',
+ )
+ }
+ return this._rfft;
+ }
+
+ static get top_k() {
+ if (!this._top_k) {
+ this._top_k = wrap(
+ [8, 10, 18, 0, 58, 73, 10, 18, 10, 1, 120, 10, 1, 107, 18, 1, 118, 18, 1, 105, 34, 4, 84, 111, 112, 75, 18, 1, 116, 90, 9, 10, 1, 120, 18, 4, 10, 2, 8, 1, 90, 15, 10, 1, 107, 18, 10, 10, 8, 8, 7, 18, 4, 10, 2, 8, 1, 98, 9, 10, 1, 118, 18, 4, 10, 2, 8, 1, 98, 9, 10, 1, 105, 18, 4, 10, 2, 8, 7, 66, 2, 16, 21],
+ this.session_options,
+ [ /* Values */ 'v', /* Indices */ 'i']
+ )
+ }
+ return this._top_k;
+ }
+}
diff --git a/src/pipelines.js b/src/pipelines.js
index c7772aa55..d955803e6 100644
--- a/src/pipelines.js
+++ b/src/pipelines.js
@@ -3,7 +3,7 @@
*
* **Example:** Instantiate pipeline using the `pipeline` function.
* ```javascript
- * import { pipeline } from '@xenova/transformers';
+ * import { pipeline } from '@huggingface/transformers';
*
* const classifier = await pipeline('sentiment-analysis');
* const output = await classifier('I love transformers!');
@@ -34,6 +34,7 @@ import {
AutoModelForImageClassification,
AutoModelForImageSegmentation,
AutoModelForSemanticSegmentation,
+ AutoModelForUniversalSegmentation,
AutoModelForObjectDetection,
AutoModelForZeroShotObjectDetection,
AutoModelForDocumentQuestionAnswering,
@@ -47,9 +48,11 @@ import {
Processor
} from './processors.js';
-
import {
Callable,
+} from './utils/generic.js';
+
+import {
dispatchCallback,
pop,
product,
@@ -57,7 +60,6 @@ import {
import {
softmax,
max,
- getTopItems,
round,
} from './utils/maths.js';
import {
@@ -68,6 +70,7 @@ import {
mean_pooling,
interpolate,
quantize_embeddings,
+ topk,
} from './utils/tensor.js';
import { RawImage } from './utils/image.js';
@@ -218,7 +221,7 @@ export class Pipeline extends Callable {
* @typedef {TextClassificationSingle[]} TextClassificationOutput
*
* @typedef {Object} TextClassificationPipelineOptions Parameters specific to text classification pipelines.
- * @property {number} [topk=1] The number of top predictions to be returned.
+ * @property {number} [top_k=1] The number of top predictions to be returned.
*
* @callback TextClassificationPipelineCallback Classify the text(s) given as inputs.
* @param {string|string[]} texts The input text(s) to be classified.
@@ -241,7 +244,7 @@ export class Pipeline extends Callable {
* **Example:** Multilingual sentiment-analysis w/ `Xenova/bert-base-multilingual-uncased-sentiment` (and return top 5 classes).
* ```javascript
* const classifier = await pipeline('sentiment-analysis', 'Xenova/bert-base-multilingual-uncased-sentiment');
- * const output = await classifier('Le meilleur film de tous les temps.', { topk: 5 });
+ * const output = await classifier('Le meilleur film de tous les temps.', { top_k: 5 });
* // [
* // { label: '5 stars', score: 0.9610759615898132 },
* // { label: '4 stars', score: 0.03323351591825485 },
@@ -254,7 +257,7 @@ export class Pipeline extends Callable {
* **Example:** Toxic comment classification w/ `Xenova/toxic-bert` (and return all classes).
* ```javascript
* const classifier = await pipeline('text-classification', 'Xenova/toxic-bert');
- * const output = await classifier('I hate you!', { topk: null });
+ * const output = await classifier('I hate you!', { top_k: null });
* // [
* // { label: 'toxic', score: 0.9593140482902527 },
* // { label: 'insult', score: 0.16187334060668945 },
@@ -277,7 +280,7 @@ export class TextClassificationPipeline extends (/** @type {new (options: TextPi
/** @type {TextClassificationPipelineCallback} */
async _call(texts, {
- topk = 1
+ top_k = 1
} = {}) {
// Run tokenization
@@ -292,28 +295,35 @@ export class TextClassificationPipeline extends (/** @type {new (options: TextPi
// TODO: Use softmax tensor function
const function_to_apply =
this.model.config.problem_type === 'multi_label_classification'
- ? batch => batch.sigmoid().data
- : batch => softmax(batch.data); // single_label_classification (default)
+ ? batch => batch.sigmoid()
+ : batch => new Tensor(
+ 'float32',
+ softmax(batch.data),
+ batch.dims,
+ ); // single_label_classification (default)
const id2label = this.model.config.id2label;
const toReturn = [];
for (const batch of outputs.logits) {
const output = function_to_apply(batch);
- const scores = getTopItems(output, topk);
- const vals = scores.map(x => ({
- label: id2label[x[0]],
- score: x[1],
+ const scores = await topk(output, top_k);
+
+ const values = scores[0].tolist();
+ const indices = scores[1].tolist();
+ const vals = indices.map((x, i) => ({
+ label: id2label ? id2label[x] : `LABEL_${x}`,
+ score: values[i],
}));
- if (topk === 1) {
+ if (top_k === 1) {
toReturn.push(...vals);
} else {
toReturn.push(vals);
}
}
- return Array.isArray(texts) || topk === 1 ? /** @type {TextClassificationOutput} */ (toReturn) : /** @type {TextClassificationOutput[]} */ (toReturn)[0];
+ return Array.isArray(texts) || top_k === 1 ? /** @type {TextClassificationOutput} */ (toReturn) : /** @type {TextClassificationOutput[]} */ (toReturn)[0];
}
}
@@ -428,9 +438,9 @@ export class TokenClassificationPipeline extends (/** @type {new (options: TextP
index: j,
word: word,
- // TODO: null for now, but will add
- start: null,
- end: null,
+ // TODO: Add support for start and end
+ // start: null,
+ // end: null,
});
}
toReturn.push(tokens);
@@ -447,7 +457,7 @@ export class TokenClassificationPipeline extends (/** @type {new (options: TextP
* @property {string} answer The answer to the question.
*
* @typedef {Object} QuestionAnsweringPipelineOptions Parameters specific to question answering pipelines.
- * @property {number} [topk=1] The number of top answer predictions to be returned.
+ * @property {number} [top_k=1] The number of top answer predictions to be returned.
*
* @callback QuestionAnsweringPipelineCallback Answer the question(s) given as inputs by using the context(s).
* @param {string|string[]} question One or several question(s) (must be used in conjunction with the `context` argument).
@@ -485,7 +495,7 @@ export class QuestionAnsweringPipeline extends (/** @type {new (options: TextPip
/** @type {QuestionAnsweringPipelineCallback} */
async _call(question, context, {
- topk = 1
+ top_k = 1
} = {}) {
// Run tokenization
@@ -495,30 +505,70 @@ export class QuestionAnsweringPipeline extends (/** @type {new (options: TextPip
truncation: true,
});
- const output = await this.model(inputs);
+ const { start_logits, end_logits } = await this.model(inputs);
+ const input_ids = inputs.input_ids.tolist();
+ const attention_mask = inputs.attention_mask.tolist();
+
+ // TODO: add support for `return_special_tokens_mask`
+ const special_tokens = this.tokenizer.all_special_ids;
/** @type {QuestionAnsweringOutput[]} */
const toReturn = [];
- for (let j = 0; j < output.start_logits.dims[0]; ++j) {
- const ids = inputs.input_ids[j];
- const sepIndex = ids.indexOf(this.tokenizer.sep_token_id);
+ for (let j = 0; j < start_logits.dims[0]; ++j) {
+ const ids = input_ids[j];
+ const sepIndex = ids.findIndex(x =>
+ // We use == to match bigint with number
+ // @ts-ignore
+ x == this.tokenizer.sep_token_id
+ );
- const s1 = Array.from(softmax(output.start_logits[j].data))
- .map((x, i) => [x, i])
- .filter(x => x[1] > sepIndex);
- const e1 = Array.from(softmax(output.end_logits[j].data))
- .map((x, i) => [x, i])
- .filter(x => x[1] > sepIndex);
- const options = product(s1, e1)
+ const valid_mask = attention_mask[j].map((y, ix) => (
+ y == 1
+ && (
+ ix === 0 // is cls_token
+ || (
+ ix > sepIndex
+ && special_tokens.findIndex(x => x == ids[ix]) === -1 // token is not a special token (special_tokens_mask == 0)
+ )
+ )
+ ));
+
+ const start = start_logits[j].tolist();
+ const end = end_logits[j].tolist();
+
+ // Now, we mask out values that can't be in the answer
+ // NOTE: We keep the cls_token unmasked (some models use it to indicate unanswerable questions)
+ for (let i = 1; i < start.length; ++i) {
+ if (
+ attention_mask[j] == 0 // is part of padding
+ || i <= sepIndex // is before the sep_token
+ || special_tokens.findIndex(x => x == ids[i]) !== -1 // Is a special token
+ ) {
+ // Make sure non-context indexes in the tensor cannot contribute to the softmax
+ start[i] = -Infinity;
+ end[i] = -Infinity;
+ }
+ }
+
+ // Normalize logits and spans to retrieve the answer
+ const start_scores = softmax(start).map((x, i) => [x, i]);
+ const end_scores = softmax(end).map((x, i) => [x, i]);
+
+ // Mask CLS
+ start_scores[0][0] = 0;
+ end_scores[0][0] = 0;
+
+ // Generate all valid spans and select best ones
+ const options = product(start_scores, end_scores)
.filter(x => x[0][1] <= x[1][1])
.map(x => [x[0][1], x[1][1], x[0][0] * x[1][0]])
.sort((a, b) => b[2] - a[2]);
- for (let k = 0; k < Math.min(options.length, topk); ++k) {
+ for (let k = 0; k < Math.min(options.length, top_k); ++k) {
const [start, end, score] = options[k];
- const answer_tokens = [...ids].slice(start, end + 1)
+ const answer_tokens = ids.slice(start, end + 1)
const answer = this.tokenizer.decode(answer_tokens, {
skip_special_tokens: true,
@@ -532,8 +582,8 @@ export class QuestionAnsweringPipeline extends (/** @type {new (options: TextPip
}
}
- // Mimic HF's return type based on topk
- return (topk === 1) ? toReturn[0] : toReturn;
+ // Mimic HF's return type based on top_k
+ return (top_k === 1) ? toReturn[0] : toReturn;
}
}
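To make the span selection above easier to follow: after padding, question, and special tokens are masked to `-Infinity`, the pipeline scores every candidate `(start, end)` pair by the product of the start and end probabilities, keeps only pairs with `start <= end`, and sorts them in descending order. A tiny self-contained sketch of that scoring step with toy probabilities (not real model output):

```javascript
// Hedged sketch of the (start, end) span scoring used above, with toy numbers.
function bestSpans(startProbs, endProbs, top_k = 1) {
    const candidates = [];
    for (let s = 0; s < startProbs.length; ++s) {
        for (let e = s; e < endProbs.length; ++e) { // enforce start <= end
            candidates.push([s, e, startProbs[s] * endProbs[e]]);
        }
    }
    return candidates.sort((a, b) => b[2] - a[2]).slice(0, top_k);
}

// Here tokens 2..3 form the highest-scoring span.
console.log(bestSpans([0.01, 0.1, 0.8, 0.09], [0.02, 0.08, 0.2, 0.7]));
```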
@@ -547,7 +597,7 @@ export class QuestionAnsweringPipeline extends (/** @type {new (options: TextPip
* @typedef {FillMaskSingle[]} FillMaskOutput
*
* @typedef {Object} FillMaskPipelineOptions Parameters specific to fill mask pipelines.
- * @property {number} [topk=5] When passed, overrides the number of predictions to return.
+ * @property {number} [top_k=5] When passed, overrides the number of predictions to return.
*
* @callback FillMaskPipelineCallback Fill the masked token in the text(s) given as inputs.
* @param {string|string[]} texts One or several texts (or one list of prompts) with masked tokens.
@@ -579,7 +629,7 @@ export class QuestionAnsweringPipeline extends (/** @type {new (options: TextPip
* **Example:** Perform masked language modelling (a.k.a. "fill-mask") with `Xenova/bert-base-cased` (and return top result).
* ```javascript
* const unmasker = await pipeline('fill-mask', 'Xenova/bert-base-cased');
- * const output = await unmasker('The Milky Way is a [MASK] galaxy.', { topk: 1 });
+ * const output = await unmasker('The Milky Way is a [MASK] galaxy.', { top_k: 1 });
* // [{ token_str: 'spiral', score: 0.6299987435340881, token: 14061, sequence: 'The Milky Way is a spiral galaxy.' }]
* ```
*/
@@ -595,7 +645,7 @@ export class FillMaskPipeline extends (/** @type {new (options: TextPipelineCons
/** @type {FillMaskPipelineCallback} */
async _call(texts, {
- topk = 5
+ top_k = 5
} = {}) {
// Run tokenization
@@ -605,30 +655,40 @@ export class FillMaskPipeline extends (/** @type {new (options: TextPipelineCons
});
// Run model
- const outputs = await this.model(model_inputs)
+ const { logits } = await this.model(model_inputs)
const toReturn = [];
- for (let i = 0; i < model_inputs.input_ids.dims[0]; ++i) {
- const ids = model_inputs.input_ids[i];
- const mask_token_index = ids.indexOf(this.tokenizer.mask_token_id)
-
+ /** @type {bigint[][]} */
+ const input_ids = model_inputs.input_ids.tolist();
+ for (let i = 0; i < input_ids.length; ++i) {
+ const ids = input_ids[i];
+ const mask_token_index = ids.findIndex(x =>
+ // We use == to match bigint with number
+ // @ts-ignore
+ x == this.tokenizer.mask_token_id
+ );
if (mask_token_index === -1) {
throw Error(`Mask token (${this.tokenizer.mask_token}) not found in text.`)
}
- const logits = outputs.logits[i];
- const itemLogits = logits[mask_token_index];
+ const itemLogits = logits[i][mask_token_index];
- const scores = getTopItems(softmax(itemLogits.data), topk);
+ const scores = await topk(new Tensor(
+ 'float32',
+ softmax(itemLogits.data),
+ itemLogits.dims,
+ ), top_k);
+ const values = scores[0].tolist();
+ const indices = scores[1].tolist();
- toReturn.push(scores.map(x => {
- const sequence = [...ids];
- sequence[mask_token_index] = x[0];
+ toReturn.push(indices.map((x, i) => {
+ const sequence = ids.slice();
+ sequence[mask_token_index] = x;
return {
- score: x[1],
- token: x[0],
- token_str: this.tokenizer.model.vocab[x[0]],
+ score: values[i],
+ token: Number(x),
+ token_str: this.tokenizer.model.vocab[x],
sequence: this.tokenizer.decode(sequence, { skip_special_tokens: true }),
}
}));
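A note on the `x == this.tokenizer.mask_token_id` comparison introduced above: token ids obtained via `tolist()` on an int64 tensor are `bigint`s, while the tokenizer's ids are plain `number`s, so loose equality is used on purpose. A tiny sketch (all ids are hypothetical):

```javascript
// Hedged sketch: why == is used to compare bigint token ids with number ids,
// and how the unmasked sequence is rebuilt. Ids below are illustrative only.
const mask_token_id = 103;              // number, from the tokenizer
const ids = [101n, 1996n, 103n, 102n];  // bigints, as returned by Tensor.tolist()

console.log(ids[2] === mask_token_id);  // false: strict equality never mixes types
console.log(ids[2] == mask_token_id);   // true: loose equality coerces bigint/number

const mask_token_index = ids.findIndex((x) => x == mask_token_id);
const sequence = ids.slice();
sequence[mask_token_index] = 14061n;    // hypothetical predicted token id
console.log(sequence);                  // [101n, 1996n, 14061n, 102n]
```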
@@ -645,7 +705,7 @@ export class FillMaskPipeline extends (/** @type {new (options: TextPipelineCons
*
* @callback Text2TextGenerationPipelineCallback Generate the output text(s) using text(s) given as inputs.
* @param {string|string[]} texts Input text for the encoder.
- * @param {import('./utils/generation.js').GenerationConfigType} [options] Additional keyword arguments to pass along to the generate method of the model.
+ * @param {Partial<import('./generation/configuration_utils.js').GenerationConfig>} [options] Additional keyword arguments to pass along to the generate method of the model.
* @returns {Promise}
*
* @typedef {TextPipelineConstructorArgs & Text2TextGenerationPipelineCallback & Disposable} Text2TextGenerationPipelineType
@@ -703,20 +763,19 @@ export class Text2TextGenerationPipeline extends (/** @type {new (options: TextP
padding: true,
truncation: true,
}
- let input_ids;
+ let inputs;
if (this instanceof TranslationPipeline && '_build_translation_inputs' in tokenizer) {
// TODO: move to Translation pipeline?
// Currently put here to avoid code duplication
// @ts-ignore
- input_ids = tokenizer._build_translation_inputs(texts, tokenizer_options, generate_kwargs).input_ids;
+ inputs = tokenizer._build_translation_inputs(texts, tokenizer_options, generate_kwargs);
} else {
- input_ids = tokenizer(texts, tokenizer_options).input_ids;
+ inputs = tokenizer(texts, tokenizer_options);
}
- const outputTokenIds = await this.model.generate(input_ids, generate_kwargs);
-
- return tokenizer.batch_decode(outputTokenIds, {
+ const outputTokenIds = await this.model.generate({ ...inputs, ...generate_kwargs });
+ return tokenizer.batch_decode(/** @type {Tensor} */(outputTokenIds), {
skip_special_tokens: true,
}).map(text => ({ [this._key]: text }));
}
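The key API change in this hunk is that `model.generate` now takes a single object that merges the tokenized model inputs with the generation kwargs, and returns a `Tensor` of sequences. A hedged usage sketch of the new calling convention from user code (the checkpoint name is just an example):

```javascript
// Hedged sketch of the v3 generate() calling convention used above.
import { AutoTokenizer, AutoModelForSeq2SeqLM } from '@huggingface/transformers';

const model_id = 'Xenova/t5-small'; // example checkpoint
const tokenizer = await AutoTokenizer.from_pretrained(model_id);
const model = await AutoModelForSeq2SeqLM.from_pretrained(model_id);

const inputs = tokenizer('translate English to German: I love transformers!');
const outputTokenIds = await model.generate({ ...inputs, max_new_tokens: 40 });

console.log(tokenizer.batch_decode(outputTokenIds, { skip_special_tokens: true }));
```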
@@ -730,7 +789,7 @@ export class Text2TextGenerationPipeline extends (/** @type {new (options: TextP
*
* @callback SummarizationPipelineCallback Summarize the text(s) given as inputs.
* @param {string|string[]} texts One or several articles (or one list of articles) to summarize.
- * @param {import('./utils/generation.js').GenerationConfigType} [options] Additional keyword arguments to pass along to the generate method of the model.
+ * @param {import('./generation/configuration_utils.js').GenerationConfig} [options] Additional keyword arguments to pass along to the generate method of the model.
* @returns {Promise}
*
* @typedef {TextPipelineConstructorArgs & SummarizationPipelineCallback & Disposable} SummarizationPipelineType
@@ -777,7 +836,7 @@ export class SummarizationPipeline extends (/** @type {new (options: TextPipelin
*
* @callback TranslationPipelineCallback Translate the text(s) given as inputs.
* @param {string|string[]} texts Texts to be translated.
- * @param {import('./utils/generation.js').GenerationConfigType} [options] Additional keyword arguments to pass along to the generate method of the model.
+ * @param {import('./generation/configuration_utils.js').GenerationConfig} [options] Additional keyword arguments to pass along to the generate method of the model.
* @returns {Promise}
*
* @typedef {TextPipelineConstructorArgs & TranslationPipelineCallback & Disposable} TranslationPipelineType
@@ -855,11 +914,11 @@ function isChat(x) {
* @typedef {Object} TextGenerationSpecificParams Parameters specific to text-generation pipelines.
* @property {boolean} [add_special_tokens] Whether or not to add special tokens when tokenizing the sequences.
* @property {boolean} [return_full_text=true] If set to `false` only added text is returned, otherwise the full text is returned.
- * @typedef {import('./utils/generation.js').GenerationConfigType & TextGenerationSpecificParams} TextGenerationConfig
+ * @typedef {import('./generation/configuration_utils.js').GenerationConfig & TextGenerationSpecificParams} TextGenerationConfig
*
* @callback TextGenerationPipelineCallback Complete the prompt(s) given as inputs.
* @param {string|string[]|Chat|Chat[]} texts One or several prompts (or one list of prompts) to complete.
- * @param {TextGenerationConfig} [options] Additional keyword arguments to pass along to the generate method of the model.
+ * @param {Partial<TextGenerationConfig>} [options] Additional keyword arguments to pass along to the generate method of the model.
* @returns {Promise} An array or object containing the generated texts.
*
* @typedef {TextPipelineConstructorArgs & TextGenerationPipelineCallback & Disposable} TextGenerationPipelineType
@@ -966,24 +1025,24 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli
: generate_kwargs.return_full_text ?? true;
this.tokenizer.padding_side = 'left';
- const { input_ids, attention_mask } = this.tokenizer(inputs, {
+ const text_inputs = this.tokenizer(inputs, {
add_special_tokens,
padding: true,
truncation: true,
});
- const outputTokenIds = await this.model.generate(input_ids, generate_kwargs, null, {
- inputs_attention_mask: attention_mask
- });
+ const outputTokenIds = /** @type {Tensor} */(await this.model.generate({
+ ...text_inputs,
+ ...generate_kwargs
+ }));
- let decoded = this.tokenizer.batch_decode(outputTokenIds, {
+ const decoded = this.tokenizer.batch_decode(outputTokenIds, {
skip_special_tokens: true,
});
-
let promptLengths;
- if (!return_full_text && input_ids.dims.at(-1) > 0) {
- promptLengths = this.tokenizer.batch_decode(input_ids, {
+ if (!return_full_text && text_inputs.input_ids.dims.at(-1) > 0) {
+ promptLengths = this.tokenizer.batch_decode(text_inputs.input_ids, {
skip_special_tokens: true,
}).map(x => x.length);
}
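The hunk above decodes the prompt tokens so that, when `return_full_text` is `false`, their length can be sliced off the decoded output and only the newly generated text is returned. A hedged usage sketch of that option (checkpoint name is an example):

```javascript
// Hedged sketch: return_full_text controls whether the prompt is echoed back.
import { pipeline } from '@huggingface/transformers';

const generator = await pipeline('text-generation', 'Xenova/gpt2'); // example checkpoint
const output = await generator('Once upon a time,', {
    max_new_tokens: 20,
    return_full_text: false, // only the continuation; the prompt is trimmed off
});
console.log(output[0].generated_text);
```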
@@ -991,7 +1050,7 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli
/** @type {TextGenerationOutput[]} */
const toReturn = Array.from({ length: texts.length }, _ => []);
for (let i = 0; i < decoded.length; ++i) {
- const textIndex = Math.floor(i / outputTokenIds.length * texts.length);
+ const textIndex = Math.floor(i / outputTokenIds.dims[0] * texts.length);
if (promptLengths) {
// Trim the decoded text to only include the generated part
@@ -1365,7 +1424,7 @@ export class ImageFeatureExtractionPipeline extends (/** @type {new (options: Im
* @typedef {AudioClassificationSingle[]} AudioClassificationOutput
*
* @typedef {Object} AudioClassificationPipelineOptions Parameters specific to audio classification pipelines.
- * @property {number} [topk=null] The number of top labels that will be returned by the pipeline.
+ * @property {number} [top_k=5] The number of top labels that will be returned by the pipeline.
* If the provided number is `null` or higher than the number of labels available in the model configuration,
* it will default to the number of labels.
*
@@ -1400,7 +1459,7 @@ export class ImageFeatureExtractionPipeline extends (/** @type {new (options: Im
* ```javascript
* const classifier = await pipeline('audio-classification', 'Xenova/ast-finetuned-audioset-10-10-0.4593');
* const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cat_meow.wav';
- * const output = await classifier(url, { topk: 4 });
+ * const output = await classifier(url, { top_k: 4 });
* // [
* // { label: 'Meow', score: 0.5617874264717102 },
* // { label: 'Cat', score: 0.22365376353263855 },
@@ -1421,11 +1480,9 @@ export class AudioClassificationPipeline extends (/** @type {new (options: Audio
/** @type {AudioClassificationPipelineCallback} */
async _call(audio, {
- topk = null
+ top_k = 5
} = {}) {
- const single = !Array.isArray(audio);
-
const sampling_rate = this.processor.feature_extractor.config.sampling_rate;
const preparedAudios = await prepareAudios(audio, sampling_rate);
@@ -1437,20 +1494,23 @@ export class AudioClassificationPipeline extends (/** @type {new (options: Audio
const output = await this.model(inputs);
const logits = output.logits[0];
- const scores = getTopItems(softmax(logits.data), topk);
+ const scores = await topk(new Tensor(
+ 'float32',
+ softmax(logits.data),
+ logits.dims,
+ ), top_k);
- const vals = scores.map(x => ({
- label: /** @type {string} */ (id2label[x[0]]),
- score: /** @type {number} */ (x[1]),
+ const values = scores[0].tolist();
+ const indices = scores[1].tolist();
+
+ const vals = indices.map((x, i) => ({
+ label: /** @type {string} */ (id2label ? id2label[x] : `LABEL_${x}`),
+ score: /** @type {number} */ (values[i]),
}));
- if (topk === 1) {
- toReturn.push(...vals);
- } else {
- toReturn.push(vals);
- }
- }
- return !single || topk === 1 ? /** @type {AudioClassificationOutput} */ (toReturn) : /** @type {AudioClassificationOutput[]} */ (toReturn)[0];
+ toReturn.push(vals);
+ };
+ return Array.isArray(audio) ? toReturn : toReturn[0];
}
}
@@ -1546,12 +1606,6 @@ export class ZeroShotAudioClassificationPipeline extends (/** @type {new (option
}
}
-/**
- * @typedef {{stride: number[], input_features: Tensor, is_last: boolean, tokens?: number[], token_timestamps?: number[]}} ChunkCallbackItem
- * @callback ChunkCallback
- * @param {ChunkCallbackItem} chunk The chunk to process.
- */
-
/**
* @typedef {Object} Chunk
* @property {[number, number]} timestamp The start and end timestamp of the chunk in seconds.
@@ -1565,17 +1619,14 @@ export class ZeroShotAudioClassificationPipeline extends (/** @type {new (option
* containing all the various text chunks identified by the model.
*
* @typedef {Object} AutomaticSpeechRecognitionSpecificParams Parameters specific to automatic-speech-recognition pipelines.
- * @property {boolean|'word'} [kwargs.return_timestamps] Whether to return timestamps or not. Default is `false`.
- * @property {number} [kwargs.chunk_length_s] The length of audio chunks to process in seconds. Default is 0 (no chunking).
- * @property {number} [kwargs.stride_length_s] The length of overlap between consecutive audio chunks in seconds. If not provided, defaults to `chunk_length_s / 6`.
- * @property {ChunkCallback} [kwargs.chunk_callback] Callback function to be called with each chunk processed.
- * @property {boolean} [kwargs.force_full_sequences] Whether to force outputting full sequences or not. Default is `false`.
- * @property {string} [kwargs.language] The source language. Default is `null`, meaning it should be auto-detected. Use this to potentially improve performance if the source language is known.
- * @property {string} [kwargs.task] The task to perform. Default is `null`, meaning it should be auto-detected.
- * @property {number[][]} [kwargs.forced_decoder_ids] A list of pairs of integers which indicates a mapping from generation indices to token indices
- * that will be forced before sampling. For example, [[1, 123]] means the second generated token will always be a token of index 123.
+ * @property {boolean|'word'} [return_timestamps] Whether to return timestamps or not. Default is `false`.
+ * @property {number} [chunk_length_s] The length of audio chunks to process in seconds. Default is 0 (no chunking).
+ * @property {number} [stride_length_s] The length of overlap between consecutive audio chunks in seconds. If not provided, defaults to `chunk_length_s / 6`.
+ * @property {boolean} [force_full_sequences] Whether to force outputting full sequences or not. Default is `false`.
+ * @property {string} [language] The source language. Default is `null`, meaning it should be auto-detected. Use this to potentially improve performance if the source language is known.
+ * @property {string} [task] The task to perform. Default is `null`, meaning it should be auto-detected.
* @property {number} [num_frames] The number of frames in the input audio.
- * @typedef {import('./utils/generation.js').GenerationConfigType & AutomaticSpeechRecognitionSpecificParams} AutomaticSpeechRecognitionConfig
+ * @typedef {import('./generation/configuration_utils.js').GenerationConfig & AutomaticSpeechRecognitionSpecificParams} AutomaticSpeechRecognitionConfig
*
* @callback AutomaticSpeechRecognitionPipelineCallback Transcribe the audio sequence(s) given as inputs to text.
* @param {AudioPipelineInputs} audio The input audio file(s) to be transcribed. The input is either:
@@ -1583,7 +1634,7 @@ export class ZeroShotAudioClassificationPipeline extends (/** @type {new (option
* to get the waveform using the [`AudioContext`](https://developer.mozilla.org/en-US/docs/Web/API/AudioContext) API.
* If `AudioContext` is not available, you should pass the raw waveform in as a Float32Array of shape `(n, )`.
* - `Float32Array` or `Float64Array` of shape `(n, )`, representing the raw audio at the correct sampling rate (no further check will be done).
- * @param {AutomaticSpeechRecognitionConfig} [options] Additional keyword arguments to pass along to the generate method of the model.
+ * @param {Partial<AutomaticSpeechRecognitionConfig>} [options] Additional keyword arguments to pass along to the generate method of the model.
* @returns {Promise} An object containing the transcription text and optionally timestamps if `return_timestamps` is `true`.
*
* @typedef {TextAudioPipelineConstructorArgs & AutomaticSpeechRecognitionPipelineCallback & Disposable} AutomaticSpeechRecognitionPipelineType
@@ -1687,7 +1738,7 @@ export class AutomaticSpeechRecognitionPipeline extends (/** @type {new (options
* @type {AutomaticSpeechRecognitionPipelineCallback}
* @private
*/
- async _call_wav2vec2(audio, kwargs = {}) {
+ async _call_wav2vec2(audio, kwargs) {
// TODO use kwargs
if (kwargs.language) {
@@ -1725,30 +1776,17 @@ export class AutomaticSpeechRecognitionPipeline extends (/** @type {new (options
* @type {AutomaticSpeechRecognitionPipelineCallback}
* @private
*/
- async _call_whisper(audio, kwargs = {}) {
-
+ async _call_whisper(audio, kwargs) {
const return_timestamps = kwargs.return_timestamps ?? false;
const chunk_length_s = kwargs.chunk_length_s ?? 0;
- const chunk_callback = kwargs.chunk_callback ?? null;
const force_full_sequences = kwargs.force_full_sequences ?? false;
let stride_length_s = kwargs.stride_length_s ?? null;
- if (return_timestamps === 'word') {
- kwargs['return_token_timestamps'] = true;
- }
-
- const language = pop(kwargs, 'language', null);
- const task = pop(kwargs, 'task', null);
+ const generation_config = { ...kwargs }
- if (language || task || return_timestamps) {
- if (kwargs.forced_decoder_ids) {
- throw new Error("Cannot specify `language`/`task`/`return_timestamps` and `forced_decoder_ids` at the same time.")
- }
- // @ts-ignore
- const decoder_prompt_ids = this.tokenizer.get_decoder_prompt_ids({ language, task, no_timestamps: !return_timestamps })
- if (decoder_prompt_ids.length > 0) {
- kwargs.forced_decoder_ids = decoder_prompt_ids;
- }
+ if (return_timestamps === 'word') {
+ generation_config['return_token_timestamps'] = true;
+ generation_config['return_timestamps'] = false; // Do not predict timestamp tokens
}
const single = !Array.isArray(audio);
@@ -1764,7 +1802,7 @@ export class AutomaticSpeechRecognitionPipeline extends (/** @type {new (options
const toReturn = [];
for (const aud of preparedAudios) {
- /** @type {ChunkCallbackItem[]} */
+ /** @type {{stride: number[], input_features: Tensor, is_last: boolean, tokens?: bigint[], token_timestamps?: number[]}[]} */
let chunks = [];
if (chunk_length_s > 0) {
if (stride_length_s === null) {
@@ -1781,22 +1819,23 @@ export class AutomaticSpeechRecognitionPipeline extends (/** @type {new (options
let offset = 0;
// Create subarrays of audio with overlaps
-
- while (offset < aud.length) {
- const subarr = aud.subarray(offset, offset + window);
+ while (true) {
+ const offset_end = offset + window;
+ const subarr = aud.subarray(offset, offset_end);
const feature = await this.processor(subarr);
- const isFirst = offset === 0;
- const isLast = offset + jump >= aud.length;
+ const is_first = offset === 0;
+ const is_last = offset_end >= aud.length;
chunks.push({
stride: [
subarr.length,
- isFirst ? 0 : stride,
- isLast ? 0 : stride
+ is_first ? 0 : stride,
+ is_last ? 0 : stride
],
input_features: feature.input_features,
- is_last: isLast
+ is_last,
})
+ if (is_last) break;
offset += jump;
}
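A worked example may help with the chunking arithmetic above. The `window`, `stride`, and `jump` values are computed just before this hunk (not shown); the sketch below assumes the usual derivation from `chunk_length_s` and `stride_length_s`, with illustrative numbers:

```javascript
// Hedged sketch of the sliding-window chunking above. Assumes
// window = chunk_length_s * sampling_rate and jump = window - 2 * stride,
// which is how the surrounding (unshown) code is typically set up.
const sampling_rate = 16000;
const chunk_length_s = 30;   // seconds of audio per chunk
const stride_length_s = 5;   // overlap on each side, in seconds

const window = chunk_length_s * sampling_rate;   // 480000 samples
const stride = stride_length_s * sampling_rate;  // 80000 samples
const jump = window - 2 * stride;                // 320000 samples between chunk starts

const audioLength = 1_000_000;                   // ~62.5 s of (fake) audio
for (let offset = 0; ; offset += jump) {
    const offset_end = offset + window;
    const is_last = offset_end >= audioLength;
    console.log({ offset, is_last });            // offsets 0, 320000, 640000 (last)
    if (is_last) break;
}
```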
@@ -1810,28 +1849,27 @@ export class AutomaticSpeechRecognitionPipeline extends (/** @type {new (options
// Generate for each set of input features
for (const chunk of chunks) {
- kwargs.num_frames = Math.floor(chunk.stride[0] / hop_length);
+ generation_config.num_frames = Math.floor(chunk.stride[0] / hop_length);
// NOTE: doing sequentially for now
- const data = await this.model.generate(chunk.input_features, kwargs);
+ const data = await this.model.generate({
+ inputs: chunk.input_features,
+ ...generation_config
+ });
// TODO: Right now we only get top beam
if (return_timestamps === 'word') {
- chunk.tokens = data.sequences[0];
+ chunk.tokens = data.sequences.tolist()[0];
chunk.token_timestamps = data.token_timestamps.tolist()[0].map(
(/** @type {number} */ x) => round(x, 2)
);
} else {
- chunk.tokens = data[0];
+ chunk.tokens = (/** @type {Tensor} */(data))[0].tolist();
}
// convert stride to seconds
chunk.stride = chunk.stride.map(x => x / sampling_rate);
-
- if (chunk_callback !== null) {
- chunk_callback(chunk)
- }
}
// Merge text chunks
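As the hunk above shows, word-level timestamps are now requested by setting `return_token_timestamps` in the generation config rather than by manipulating forced decoder ids. A hedged end-user sketch (checkpoint and audio URL are examples):

```javascript
// Hedged sketch: word-level timestamps with the rewritten Whisper path above.
import { pipeline } from '@huggingface/transformers';

const transcriber = await pipeline(
    'automatic-speech-recognition',
    'Xenova/whisper-tiny.en', // example checkpoint
);
const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
const output = await transcriber(url, {
    return_timestamps: 'word', // mapped internally to return_token_timestamps
    chunk_length_s: 30,
    stride_length_s: 5,
});
console.log(output.chunks); // [{ text: '...', timestamp: [start, end] }, ...]
```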
@@ -1853,7 +1891,7 @@ export class AutomaticSpeechRecognitionPipeline extends (/** @type {new (options
*
* @callback ImageToTextPipelineCallback Assign labels to the image(s) passed as inputs.
* @param {ImagePipelineInputs} texts The images to be captioned.
- * @param {import('./utils/generation.js').GenerationConfigType} [options] Additional keyword arguments to pass along to the generate method of the model.
+ * @param {Partial<import('./generation/configuration_utils.js').GenerationConfig>} [options] Additional keyword arguments to pass along to the generate method of the model.
* @returns {Promise} An object (or array of objects) containing the generated text(s).
*
* @typedef {TextImagePipelineConstructorArgs & ImageToTextPipelineCallback & Disposable} ImageToTextPipelineType
@@ -1899,8 +1937,8 @@ export class ImageToTextPipeline extends (/** @type {new (options: TextImagePipe
const toReturn = [];
for (const batch of pixel_values) {
batch.dims = [1, ...batch.dims]
- const output = await this.model.generate(batch, generate_kwargs);
- const decoded = this.tokenizer.batch_decode(output, {
+ const output = await this.model.generate({ inputs: batch, ...generate_kwargs });
+ const decoded = this.tokenizer.batch_decode(/** @type {Tensor} */(output), {
skip_special_tokens: true,
}).map(x => ({ generated_text: x.trim() }))
toReturn.push(decoded);
@@ -1917,7 +1955,7 @@ export class ImageToTextPipeline extends (/** @type {new (options: TextImagePipe
* @typedef {ImageClassificationSingle[]} ImageClassificationOutput
*
* @typedef {Object} ImageClassificationPipelineOptions Parameters specific to image classification pipelines.
- * @property {number} [topk=1] The number of top labels that will be returned by the pipeline.
+ * @property {number} [top_k=5] The number of top labels that will be returned by the pipeline.
*
* @callback ImageClassificationPipelineCallback Assign labels to the image(s) passed as inputs.
* @param {ImagePipelineInputs} images The input images(s) to be classified.
@@ -1945,7 +1983,7 @@ export class ImageToTextPipeline extends (/** @type {new (options: TextImagePipe
* ```javascript
* const classifier = await pipeline('image-classification', 'Xenova/vit-base-patch16-224');
* const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/tiger.jpg';
- * const output = await classifier(url, { topk: 3 });
+ * const output = await classifier(url, { top_k: 3 });
* // [
* // { label: 'tiger, Panthera tigris', score: 0.632695734500885 },
* // { label: 'tiger cat', score: 0.3634825646877289 },
@@ -1957,7 +1995,7 @@ export class ImageToTextPipeline extends (/** @type {new (options: TextImagePipe
* ```javascript
* const classifier = await pipeline('image-classification', 'Xenova/vit-base-patch16-224');
* const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/tiger.jpg';
- * const output = await classifier(url, { topk: 0 });
+ * const output = await classifier(url, { top_k: 0 });
* // [
* // { label: 'tiger, Panthera tigris', score: 0.632695734500885 },
* // { label: 'tiger cat', score: 0.3634825646877289 },
@@ -1979,32 +2017,36 @@ export class ImageClassificationPipeline extends (/** @type {new (options: Image
/** @type {ImageClassificationPipelineCallback} */
async _call(images, {
- topk = 1
+ top_k = 5
} = {}) {
- const isBatched = Array.isArray(images);
const preparedImages = await prepareImages(images);
const { pixel_values } = await this.processor(preparedImages);
const output = await this.model({ pixel_values });
const id2label = this.model.config.id2label;
+
+ /** @type {ImageClassificationOutput[]} */
const toReturn = [];
for (const batch of output.logits) {
- const scores = getTopItems(softmax(batch.data), topk);
+ const scores = await topk(new Tensor(
+ 'float32',
+ softmax(batch.data),
+ batch.dims,
+ ), top_k);
- const vals = scores.map(x => ({
- label: id2label[x[0]],
- score: x[1],
+ const values = scores[0].tolist();
+ const indices = scores[1].tolist();
+
+ const vals = indices.map((x, i) => ({
+ label: /** @type {string} */ (id2label ? id2label[x] : `LABEL_${x}`),
+ score: /** @type {number} */ (values[i]),
}));
- if (topk === 1) {
- toReturn.push(...vals);
- } else {
- toReturn.push(vals);
- }
+ toReturn.push(vals);
}
- return isBatched || topk === 1 ? /** @type {ImageClassificationOutput} */ (toReturn) : /** @type {ImageClassificationOutput[]} */ (toReturn)[0];
+ return Array.isArray(images) ? toReturn : toReturn[0];
}
}
@@ -2348,7 +2390,7 @@ export class ObjectDetectionPipeline extends (/** @type {new (options: ImagePipe
*
* @typedef {Object} ZeroShotObjectDetectionPipelineOptions Parameters specific to zero-shot object detection pipelines.
* @property {number} [threshold=0.1] The probability necessary to make a prediction.
- * @property {number} [topk=null] The number of top predictions that will be returned by the pipeline.
+ * @property {number} [top_k=null] The number of top predictions that will be returned by the pipeline.
* If the provided number is `null` or higher than the number of predictions available, it will default
* to the number of predictions.
* @property {boolean} [percentage=false] Whether to return the boxes coordinates in percentage (true) or in pixels (false).
@@ -2401,7 +2443,7 @@ export class ObjectDetectionPipeline extends (/** @type {new (options: ImagePipe
* const detector = await pipeline('zero-shot-object-detection', 'Xenova/owlvit-base-patch32');
* const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/beach.png';
* const candidate_labels = ['hat', 'book', 'sunglasses', 'camera'];
- * const output = await detector(url, candidate_labels, { topk: 4, threshold: 0.05 });
+ * const output = await detector(url, candidate_labels, { top_k: 4, threshold: 0.05 });
* // [
* // {
* // score: 0.1606510728597641,
@@ -2439,7 +2481,7 @@ export class ZeroShotObjectDetectionPipeline extends (/** @type {new (options: T
/** @type {ZeroShotObjectDetectionPipelineCallback} */
async _call(images, candidate_labels, {
threshold = 0.1,
- topk = null,
+ top_k = null,
percentage = false,
} = {}) {
@@ -2474,8 +2516,8 @@ export class ZeroShotObjectDetectionPipeline extends (/** @type {new (options: T
label: candidate_labels[processed.classes[i]],
box: get_bounding_box(box, !percentage),
})).sort((a, b) => b.score - a.score);
- if (topk !== null) {
- result = result.slice(0, topk);
+ if (top_k !== null) {
+ result = result.slice(0, top_k);
}
toReturn.push(result)
}
@@ -2492,7 +2534,7 @@ export class ZeroShotObjectDetectionPipeline extends (/** @type {new (options: T
* @callback DocumentQuestionAnsweringPipelineCallback Answer the question given as input by using the document.
* @param {ImageInput} image The image of the document to use.
* @param {string} question A question to ask of the document.
- * @param {import('./utils/generation.js').GenerationConfigType} [options] Additional keyword arguments to pass along to the generate method of the model.
+ * @param {Partial<import('./generation/configuration_utils.js').GenerationConfig>} [options] Additional keyword arguments to pass along to the generate method of the model.
* @returns {Promise} An object (or array of objects) containing the answer(s).
*
* @typedef {TextImagePipelineConstructorArgs & DocumentQuestionAnsweringPipelineCallback & Disposable} DocumentQuestionAnsweringPipelineType
@@ -2524,6 +2566,7 @@ export class DocumentQuestionAnsweringPipeline extends (/** @type {new (options:
/** @type {DocumentQuestionAnsweringPipelineCallback} */
async _call(image, question, generate_kwargs = {}) {
+ throw new Error('This pipeline is not yet supported in Transformers.js v3.'); // TODO: Remove when implemented
// NOTE: For now, we only support a batch size of 1
@@ -2540,17 +2583,15 @@ export class DocumentQuestionAnsweringPipeline extends (/** @type {new (options:
}).input_ids;
// Run model
- const output = await this.model.generate(
- pixel_values,
- {
- ...generate_kwargs,
- decoder_input_ids,
- max_length: this.model.config.decoder.max_position_embeddings,
- }
- );
+ const output = await this.model.generate({
+ inputs: pixel_values,
+ max_length: this.model.config.decoder.max_position_embeddings,
+ decoder_input_ids,
+ ...generate_kwargs,
+ });
// Decode output
- const decoded = this.tokenizer.batch_decode(output)[0];
+ const decoded = this.tokenizer.batch_decode(/** @type {Tensor} */(output))[0];
// Parse answer
const match = decoded.match(/<s_answer>(.*?)<\/s_answer>/);
@@ -2671,7 +2712,7 @@ export class TextToAudioPipeline extends (/** @type {new (options: TextToAudioPi
// Load vocoder, if not provided
if (!this.vocoder) {
console.log('No vocoder specified, using default HifiGan vocoder.');
- this.vocoder = await AutoModel.from_pretrained(this.DEFAULT_VOCODER_ID, { quantized: false });
+ this.vocoder = await AutoModel.from_pretrained(this.DEFAULT_VOCODER_ID, { dtype: 'fp32' });
}
// Load speaker embeddings as Float32Array from path/URL
@@ -3005,7 +3046,7 @@ const SUPPORTED_TASKS = Object.freeze({
"image-segmentation": {
// no tokenizer
"pipeline": ImageSegmentationPipeline,
- "model": [AutoModelForImageSegmentation, AutoModelForSemanticSegmentation],
+ "model": [AutoModelForImageSegmentation, AutoModelForSemanticSegmentation, AutoModelForUniversalSegmentation],
"processor": AutoProcessor,
"default": {
// TODO: replace with original
@@ -3164,7 +3205,7 @@ const TASK_ALIASES = Object.freeze({
* - `"zero-shot-image-classification"`: will return a `ZeroShotImageClassificationPipeline`.
* - `"zero-shot-object-detection"`: will return a `ZeroShotObjectDetectionPipeline`.
* @param {string} [model=null] The name of the pre-trained model to use. If not specified, the default model for the task will be used.
- * @param {import('./utils/hub.js').PretrainedOptions} [options] Optional parameters for the pipeline.
+ * @param {import('./utils/hub.js').PretrainedModelOptions} [options] Optional parameters for the pipeline.
* @returns {Promise} A Pipeline object for the specified task.
* @throws {Error} If an unsupported pipeline is requested.
*/
@@ -3172,13 +3213,15 @@ export async function pipeline(
task,
model = null,
{
- quantized = true,
progress_callback = null,
config = null,
cache_dir = null,
local_files_only = false,
revision = 'main',
+ device = null,
+ dtype = null,
model_file_name = null,
+ session_options = {},
} = {}
) {
// Helper method to construct pipeline
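The `device`, `dtype`, and `session_options` options added to `pipeline()` here replace the old boolean `quantized` flag. A hedged usage sketch (checkpoint name and option values are examples; the accepted values depend on the runtime backends available):

```javascript
// Hedged sketch: the new pipeline() loading options introduced above.
import { pipeline } from '@huggingface/transformers';

const classifier = await pipeline(
    'text-classification',
    'Xenova/distilbert-base-uncased-finetuned-sst-2-english', // example checkpoint
    {
        device: 'webgpu',  // e.g. 'wasm' | 'webgpu', depending on the environment
        dtype: 'q8',       // precision/quantization, e.g. 'fp32' | 'fp16' | 'q8'
        progress_callback: (p) => console.log(p.status),
    },
);
console.log(await classifier('I love transformers!'));
```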
@@ -3200,13 +3243,15 @@ export async function pipeline(
}
const pretrainedOptions = {
- quantized,
progress_callback,
config,
cache_dir,
local_files_only,
revision,
+ device,
+ dtype,
model_file_name,
+ session_options,
}
const classes = new Map([
@@ -3243,7 +3288,7 @@ async function loadItems(mapping, model, pretrainedOptions) {
/**@type {Promise[]} */
const promises = [];
- for (let [name, cls] of mapping.entries()) {
+ for (const [name, cls] of mapping.entries()) {
if (!cls) continue;
/**@type {Promise} */
@@ -3251,7 +3296,7 @@ async function loadItems(mapping, model, pretrainedOptions) {
if (Array.isArray(cls)) {
promise = new Promise(async (resolve, reject) => {
let e;
- for (let c of cls) {
+ for (const c of cls) {
if (c === null) {
// If null, we resolve it immediately, meaning the relevant
// class was not found, but it is optional.
@@ -3262,7 +3307,17 @@ async function loadItems(mapping, model, pretrainedOptions) {
resolve(await c.from_pretrained(model, pretrainedOptions));
return;
} catch (err) {
- e = err;
+ if (err.message?.includes('Unsupported model type')) {
+ // If the error is due to an unsupported model type, we
+ // save the error and try the next class.
+ e = err;
+ } else if (err.message?.includes('Could not locate file')) {
+ e = err;
+ } else {
+ reject(err);
+ return;
+ }
+
}
}
reject(e);
@@ -3279,7 +3334,7 @@ async function loadItems(mapping, model, pretrainedOptions) {
await Promise.all(promises);
// Then assign to result
- for (let [name, promise] of Object.entries(result)) {
+ for (const [name, promise] of Object.entries(result)) {
result[name] = await promise;
}
diff --git a/src/processors.js b/src/processors.js
index 4b9a60b51..e95dc31e9 100644
--- a/src/processors.js
+++ b/src/processors.js
@@ -4,7 +4,7 @@
*
* **Example:** Using a `WhisperProcessor` to prepare an audio input for a model.
* ```javascript
- * import { AutoProcessor, read_audio } from '@xenova/transformers';
+ * import { AutoProcessor, read_audio } from '@huggingface/transformers';
*
* let processor = await AutoProcessor.from_pretrained('openai/whisper-tiny.en');
* let audio = await read_audio('https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac', 16000);
@@ -21,6 +21,9 @@
*/
import {
Callable,
+} from './utils/generic.js';
+
+import {
calculateDimensions,
calculateReflectOffset,
} from './utils/core.js';
@@ -37,7 +40,7 @@ import {
} from './utils/maths.js';
-import { Tensor, permute, cat, interpolate, stack } from './utils/tensor.js';
+import { Tensor, cat, interpolate, stack, interpolate_4d, full } from './utils/tensor.js';
import { RawImage } from './utils/image.js';
import {
@@ -70,7 +73,7 @@ function center_to_corners_format([centerX, centerY, width, height]) {
* @param {Tensor} outputs.logits The logits
* @param {Tensor} outputs.pred_boxes The predicted boxes.
* @param {number} [threshold=0.5] The threshold to use for the scores.
- * @param {number[][]} [target_sizes=null] The sizes of the original images.
+ * @param {[number, number][]} [target_sizes=null] The sizes of the original images.
* @param {boolean} [is_zero_shot=false] Whether zero-shot object detection was performed.
* @return {Object[]} An array of objects containing the post-processed outputs.
* @private
@@ -116,10 +119,13 @@ function post_process_object_detection(outputs, threshold = 0.5, target_sizes =
// This is the background class, skip it
continue;
}
- indices.push(maxIndex);
-
// Compute softmax over classes
probs = softmax(logit.data);
+
+ if (probs[maxIndex] < threshold) {
+ continue;
+ }
+ indices.push(maxIndex);
}
for (const index of indices) {
@@ -144,6 +150,364 @@ function post_process_object_detection(outputs, threshold = 0.5, target_sizes =
return toReturn;
}
+
+/**
+ * Post-processes the outputs of the model (for semantic segmentation).
+ * @param {*} outputs Raw outputs of the model.
+ * @param {[number, number][]} [target_sizes=null] List of tuples corresponding to the requested final size
+ * (height, width) of each prediction. If unset, predictions will not be resized.
+ * @returns {{segmentation: Tensor; labels: number[]}[]} The semantic segmentation maps.
+ */
+function post_process_semantic_segmentation(outputs, target_sizes = null) {
+
+ const logits = outputs.logits;
+ const batch_size = logits.dims[0];
+
+ if (target_sizes !== null && target_sizes.length !== batch_size) {
+ throw Error("Make sure that you pass in as many target sizes as the batch dimension of the logits")
+ }
+
+ const toReturn = [];
+ for (let i = 0; i < batch_size; ++i) {
+ const target_size = target_sizes !== null ? target_sizes[i] : null;
+
+ let data = logits[i];
+
+ // 1. If target_size is not null, we need to resize the masks to the target size
+ if (target_size !== null) {
+ // resize the masks to the target size
+ data = interpolate(data, target_size, 'bilinear', false);
+ }
+ const [height, width] = target_size ?? data.dims.slice(-2);
+
+ const segmentation = new Tensor(
+ 'int32',
+ new Int32Array(height * width),
+ [height, width]
+ );
+
+ // Buffer to store current largest value
+ const buffer = data[0].data;
+ const segmentation_data = segmentation.data;
+ for (let j = 1; j < data.dims[0]; ++j) {
+ const row = data[j].data;
+ for (let k = 0; k < row.length; ++k) {
+ if (row[k] > buffer[k]) {
+ buffer[k] = row[k];
+ segmentation_data[k] = j;
+ }
+ }
+ }
+
+ // Store which objects have labels
+ // This is much more efficient than creating a set of the final values
+ const hasLabel = new Array(data.dims[0]);
+ for (let j = 0; j < segmentation_data.length; ++j) {
+ const index = segmentation_data[j];
+ hasLabel[index] = index;
+ }
+ /** @type {number[]} The unique list of labels that were detected */
+ const labels = hasLabel.filter(x => x !== undefined);
+
+ toReturn.push({ segmentation, labels });
+ }
+ return toReturn;
+}
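The shared `post_process_semantic_segmentation` above performs a per-pixel argmax over the class dimension of the logits, optionally after resizing to the requested target size, and reports which labels actually occur. A hedged sketch of calling it through a feature extractor that exposes it (as the Segformer/Sapiens wrappers later in this diff do; the checkpoint name is an example):

```javascript
// Hedged sketch: semantic segmentation post-processing, as exposed by the
// feature extractor wrappers later in this diff. Checkpoint is illustrative.
import { AutoProcessor, AutoModelForSemanticSegmentation, RawImage } from '@huggingface/transformers';

const model_id = 'Xenova/segformer-b0-finetuned-ade-512-512'; // example checkpoint
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await AutoModelForSemanticSegmentation.from_pretrained(model_id);

const image = await RawImage.read('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cats.jpg');
const inputs = await processor(image);
const outputs = await model(inputs);

// Resize back to the original image size, then take the per-pixel argmax.
const [{ segmentation, labels }] = processor.feature_extractor.post_process_semantic_segmentation(
    outputs,
    [[image.height, image.width]],
);
console.log(segmentation.dims, labels);
```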
+
+
+/**
+ * Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores` and `labels`.
+ * @param {Tensor} class_logits The class logits.
+ * @param {Tensor} mask_logits The mask logits.
+ * @param {number} object_mask_threshold A number between 0 and 1 used to binarize the masks.
+ * @param {number} num_labels The number of labels.
+ * @returns {[Tensor[], number[], number[]]} The binarized masks, the scores, and the labels.
+ * @private
+ */
+function remove_low_and_no_objects(class_logits, mask_logits, object_mask_threshold, num_labels) {
+
+ const mask_probs_item = [];
+ const pred_scores_item = [];
+ const pred_labels_item = [];
+
+ for (let j = 0; j < class_logits.dims[0]; ++j) {
+ const cls = class_logits[j];
+ const mask = mask_logits[j];
+
+ const pred_label = max(cls.data)[1];
+ if (pred_label === num_labels) {
+ // Is the background, so we ignore it
+ continue;
+ }
+
+ const scores = softmax(cls.data);
+ const pred_score = scores[pred_label];
+ if (pred_score > object_mask_threshold) {
+ mask_probs_item.push(mask);
+ pred_scores_item.push(pred_score);
+ pred_labels_item.push(pred_label);
+ }
+ }
+
+ return [mask_probs_item, pred_scores_item, pred_labels_item];
+}
+
+/**
+ * Checks whether the segment is valid or not.
+ * @param {Int32Array} mask_labels Labels for each pixel in the mask.
+ * @param {Tensor[]} mask_probs Probabilities for each pixel in the masks.
+ * @param {number} k The class id of the segment.
+ * @param {number} mask_threshold The mask threshold.
+ * @param {number} overlap_mask_area_threshold The overlap mask area threshold.
+ * @returns {[boolean, number[]]} Whether the segment is valid or not, and the indices of the valid labels.
+ * @private
+ */
+function check_segment_validity(
+ mask_labels,
+ mask_probs,
+ k,
+ mask_threshold = 0.5,
+ overlap_mask_area_threshold = 0.8
+) {
+ // mask_k is a 1D array of indices, indicating where the mask is equal to k
+ const mask_k = [];
+ let mask_k_area = 0;
+ let original_area = 0;
+
+ const mask_probs_k_data = mask_probs[k].data;
+
+ // Compute the area of all the stuff in query k
+ for (let i = 0; i < mask_labels.length; ++i) {
+ if (mask_labels[i] === k) {
+ mask_k.push(i);
+ ++mask_k_area;
+ }
+
+ if (mask_probs_k_data[i] >= mask_threshold) {
+ ++original_area;
+ }
+ }
+ let mask_exists = mask_k_area > 0 && original_area > 0;
+
+ // Eliminate disconnected tiny segments
+ if (mask_exists) {
+ // Perform additional check
+ let area_ratio = mask_k_area / original_area;
+ mask_exists = area_ratio > overlap_mask_area_threshold;
+ }
+
+ return [mask_exists, mask_k]
+}
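To make the validity check above concrete: a query's segment is kept only if it claims at least one pixel in the argmax map and the area it claims is a large enough fraction of the area where its own mask probability clears `mask_threshold`. A tiny worked example (numbers are made up):

```javascript
// Hedged numeric illustration of the area-ratio test in check_segment_validity.
const mask_k_area = 120;    // pixels the argmax map assigned to query k
const original_area = 140;  // pixels where query k's own mask prob >= mask_threshold
const overlap_mask_area_threshold = 0.8;

const area_ratio = mask_k_area / original_area;         // ~0.857
console.log(area_ratio > overlap_mask_area_threshold);  // true -> segment is kept
```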
+
+/**
+ * Computes the segments.
+ * @param {Tensor[]} mask_probs The mask probabilities.
+ * @param {number[]} pred_scores The predicted scores.
+ * @param {number[]} pred_labels The predicted labels.
+ * @param {number} mask_threshold The mask threshold.
+ * @param {number} overlap_mask_area_threshold The overlap mask area threshold.
+ * @param {Set<number>} label_ids_to_fuse The label ids to fuse.
+ * @param {number[]} target_size The target size of the image.
+ * @returns {[Tensor, Array<{id: number, label_id: number, score: number}>]} The computed segments.
+ * @private
+ */
+function compute_segments(
+ mask_probs,
+ pred_scores,
+ pred_labels,
+ mask_threshold,
+ overlap_mask_area_threshold,
+ label_ids_to_fuse = null,
+ target_size = null,
+) {
+ const [height, width] = target_size ?? mask_probs[0].dims;
+
+ const segmentation = new Tensor(
+ 'int32',
+ new Int32Array(height * width),
+ [height, width]
+ );
+ const segments = [];
+
+ // 1. If target_size is not null, we need to resize the masks to the target size
+ if (target_size !== null) {
+ // resize the masks to the target size
+ for (let i = 0; i < mask_probs.length; ++i) {
+ mask_probs[i] = interpolate(mask_probs[i], target_size, 'bilinear', false);
+ }
+ }
+
+ // 2. Weigh each mask by its prediction score
+ // NOTE: `mask_probs` is updated in-place
+ //
+ // Temporary storage for the best label/scores for each pixel ([height, width]):
+ const mask_labels = new Int32Array(mask_probs[0].data.length);
+ const bestScores = new Float32Array(mask_probs[0].data.length);
+
+ for (let i = 0; i < mask_probs.length; ++i) {
+ let score = pred_scores[i];
+
+ const mask_probs_i_data = mask_probs[i].data;
+
+ for (let j = 0; j < mask_probs_i_data.length; ++j) {
+ mask_probs_i_data[j] *= score
+ if (mask_probs_i_data[j] > bestScores[j]) {
+ mask_labels[j] = i;
+ bestScores[j] = mask_probs_i_data[j];
+ }
+ }
+ }
+
+ let current_segment_id = 0;
+
+ // let stuff_memory_list = {}
+ const segmentation_data = segmentation.data;
+ for (let k = 0; k < pred_labels.length; ++k) {
+ const pred_class = pred_labels[k];
+
+ // TODO add `should_fuse`
+ // let should_fuse = pred_class in label_ids_to_fuse
+
+ // Check if mask exists and large enough to be a segment
+ const [mask_exists, mask_k] = check_segment_validity(
+ mask_labels,
+ mask_probs,
+ k,
+ mask_threshold,
+ overlap_mask_area_threshold
+ )
+
+ if (!mask_exists) {
+ // Nothing to see here
+ continue;
+ }
+
+ // TODO
+ // if (pred_class in stuff_memory_list) {
+ // current_segment_id = stuff_memory_list[pred_class]
+ // } else {
+ // current_segment_id += 1;
+ // }
+ ++current_segment_id;
+
+
+ // Add current object segment to final segmentation map
+ for (const index of mask_k) {
+ segmentation_data[index] = current_segment_id;
+ }
+
+ segments.push({
+ id: current_segment_id,
+ label_id: pred_class,
+ // was_fused: should_fuse, TODO
+ score: pred_scores[k],
+ })
+
+ // TODO
+ // if(should_fuse){
+ // stuff_memory_list[pred_class] = current_segment_id
+ // }
+ }
+
+ return [segmentation, segments];
+}
+
+
+/**
+ * Post-process the model output to generate the final panoptic segmentation.
+ * @param {*} outputs The model output to post process
+ * @param {number} [threshold=0.5] The probability score threshold to keep predicted instance masks.
+ * @param {number} [mask_threshold=0.5] Threshold to use when turning the predicted masks into binary values.
+ * @param {number} [overlap_mask_area_threshold=0.8] The overlap mask area threshold to merge or discard small disconnected parts within each binary instance mask.
+ * @param {Set<number>} [label_ids_to_fuse=null] The labels in this set will have all their instances fused together.
+ * @param {[number, number][]} [target_sizes=null] The target sizes to resize the masks to.
+ * @returns {Array<{ segmentation: Tensor, segments_info: Array<{id: number, label_id: number, score: number}>}>}
+ */
+function post_process_panoptic_segmentation(
+ outputs,
+ threshold = 0.5,
+ mask_threshold = 0.5,
+ overlap_mask_area_threshold = 0.8,
+ label_ids_to_fuse = null,
+ target_sizes = null,
+) {
+ if (label_ids_to_fuse === null) {
+ console.warn("`label_ids_to_fuse` unset. No instance will be fused.")
+ label_ids_to_fuse = new Set();
+ }
+
+ const class_queries_logits = outputs.class_queries_logits ?? outputs.logits; // [batch_size, num_queries, num_classes+1]
+ const masks_queries_logits = outputs.masks_queries_logits ?? outputs.pred_masks; // [batch_size, num_queries, height, width]
+
+ const mask_probs = masks_queries_logits.sigmoid() // [batch_size, num_queries, height, width]
+
+ let [batch_size, num_queries, num_labels] = class_queries_logits.dims;
+ num_labels -= 1; // Remove last class (background)
+
+ if (target_sizes !== null && target_sizes.length !== batch_size) {
+ throw Error("Make sure that you pass in as many target sizes as the batch dimension of the logits")
+ }
+
+ let toReturn = [];
+ for (let i = 0; i < batch_size; ++i) {
+ let target_size = target_sizes !== null ? target_sizes[i] : null;
+
+ let class_logits = class_queries_logits[i];
+ let mask_logits = mask_probs[i];
+
+ let [mask_probs_item, pred_scores_item, pred_labels_item] = remove_low_and_no_objects(class_logits, mask_logits, threshold, num_labels);
+
+ if (pred_labels_item.length === 0) {
+ // No mask found
+ let [height, width] = target_size ?? mask_logits.dims.slice(-2);
+
+ let segmentation = new Tensor(
+ 'int32',
+ new Int32Array(height * width).fill(-1),
+ [height, width]
+ )
+ toReturn.push({
+ segmentation: segmentation,
+ segments_info: []
+ });
+ continue;
+ }
+
+
+ // Get segmentation map and segment information of batch item
+ let [segmentation, segments] = compute_segments(
+ mask_probs_item,
+ pred_scores_item,
+ pred_labels_item,
+ mask_threshold,
+ overlap_mask_area_threshold,
+ label_ids_to_fuse,
+ target_size,
+ )
+
+ toReturn.push({
+ segmentation: segmentation,
+ segments_info: segments
+ })
+ }
+
+ return toReturn;
+}
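Putting the helpers above together: panoptic post-processing drops low-confidence queries (`remove_low_and_no_objects`), assigns each pixel to its best-scoring query (`compute_segments`), and returns a segmentation map plus per-segment metadata. The simplest way to exercise this path is through the image-segmentation pipeline; a hedged sketch (checkpoint name is an example):

```javascript
// Hedged sketch: panoptic segmentation end to end via the pipeline API.
import { pipeline } from '@huggingface/transformers';

const segmenter = await pipeline('image-segmentation', 'Xenova/detr-resnet-50-panoptic'); // example checkpoint
const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cats.jpg';
const output = await segmenter(url);
// Expected shape of each entry: { score, label, mask: RawImage }
console.log(output.map(({ label, score }) => ({ label, score })));
```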
+
+
+/**
+ * Post-processes the outputs of the model (for instance segmentation).
+ * @param {*} outputs Raw outputs of the model.
+ * @param {number} [threshold=0.5] The probability score threshold to keep predicted instance masks.
+ * @param {[number, number][]} [target_sizes=null] List of tuples corresponding to the requested final size
+ * (height, width) of each prediction. If unset, predictions will not be resized.
+ * @returns {Array<{ segmentation: Tensor, segments_info: Array<{id: number, label_id: number, score: number}>}>}
+ */
+function post_process_instance_segmentation(outputs, threshold = 0.5, target_sizes = null) {
+ throw new Error('Not implemented yet');
+ return [];
+}
+
/**
* Named tuple to indicate the order we are using is (height x width), even though
* the Graphics’ industry standard is (width x height).
@@ -334,10 +698,11 @@ export class ImageFeatureExtractor extends FeatureExtractor {
const threshold = gray_threshold / 255;
let x_min = gray_image.width, y_min = gray_image.height, x_max = 0, y_max = 0;
+ const gray_image_data = gray_image.data;
for (let j = 0; j < gray_image.height; ++j) {
const row = j * gray_image.width;
for (let i = 0; i < gray_image.width; ++i) {
- if ((gray_image.data[row + i] - minValue) / diff < threshold) {
+ if ((gray_image_data[row + i] - minValue) / diff < threshold) {
// We have a non-zero pixel, so we update the min/max values accordingly
x_min = Math.min(x_min, i);
y_min = Math.min(y_min, j);
@@ -684,7 +1049,7 @@ export class ImageFeatureExtractor extends FeatureExtractor {
return {
original_size: [srcHeight, srcWidth],
reshaped_input_size: reshaped_input_size,
- pixel_values: pixel_values,
+ pixel_values,
}
}
@@ -707,7 +1072,7 @@ export class ImageFeatureExtractor extends FeatureExtractor {
const pixel_values = stack(imageData.map(x => x.pixel_values), 0);
return {
- pixel_values: pixel_values,
+ pixel_values,
// Original sizes of images
original_sizes: imageData.map(x => x.original_size),
@@ -719,76 +1084,25 @@ export class ImageFeatureExtractor extends FeatureExtractor {
}
+export class SapiensFeatureExtractor extends ImageFeatureExtractor {
+ /** @type {typeof post_process_semantic_segmentation} */
+ post_process_semantic_segmentation(...args) {
+ return post_process_semantic_segmentation(...args);
+ }
+}
export class SegformerFeatureExtractor extends ImageFeatureExtractor {
-
- /**
- * Converts the output of `SegformerForSemanticSegmentation` into semantic segmentation maps.
- * @param {*} outputs Raw outputs of the model.
- * @param {number[][]} [target_sizes=null] List of tuples corresponding to the requested final size
- * (height, width) of each prediction. If unset, predictions will not be resized.
- * @returns {{segmentation: Tensor; labels: number[]}[]} The semantic segmentation maps.
- */
- post_process_semantic_segmentation(outputs, target_sizes = null) {
-
- const logits = outputs.logits;
- const batch_size = logits.dims[0];
-
- if (target_sizes !== null && target_sizes.length !== batch_size) {
- throw Error("Make sure that you pass in as many target sizes as the batch dimension of the logits")
- }
-
- const toReturn = [];
- for (let i = 0; i < batch_size; ++i) {
- const target_size = target_sizes !== null ? target_sizes[i] : null;
-
- let data = logits[i];
-
- // 1. If target_size is not null, we need to resize the masks to the target size
- if (target_size !== null) {
- // resize the masks to the target size
- data = interpolate(data, target_size, 'bilinear', false);
- }
- const [height, width] = target_size ?? data.dims.slice(-2);
-
- const segmentation = new Tensor(
- 'int32',
- new Int32Array(height * width),
- [height, width]
- );
-
- // Buffer to store current largest value
- const buffer = data[0].data;
- for (let j = 1; j < data.dims[0]; ++j) {
- const row = data[j].data;
- for (let k = 0; k < row.length; ++k) {
- if (row[k] > buffer[k]) {
- buffer[k] = row[k];
- segmentation.data[k] = j;
- }
- }
- }
-
- // Store which objects have labels
- // This is much more efficient that creating a set of the final values
- const hasLabel = new Array(data.dims[0]);
- const out = segmentation.data;
- for (let j = 0; j < out.length; ++j) {
- const index = out[j];
- hasLabel[index] = index;
- }
- /** @type {number[]} The unique list of labels that were detected */
- const labels = hasLabel.filter(x => x !== undefined);
-
- toReturn.push({ segmentation, labels });
- }
- return toReturn;
+ /** @type {typeof post_process_semantic_segmentation} */
+ post_process_semantic_segmentation(...args) {
+ return post_process_semantic_segmentation(...args);
}
}
+export class PvtImageProcessor extends ImageFeatureExtractor { }
export class DPTFeatureExtractor extends ImageFeatureExtractor { }
export class DPTImageProcessor extends DPTFeatureExtractor { } // NOTE: extends DPTFeatureExtractor
export class BitImageProcessor extends ImageFeatureExtractor { }
export class GLPNFeatureExtractor extends ImageFeatureExtractor { }
export class CLIPFeatureExtractor extends ImageFeatureExtractor { }
+export class CLIPImageProcessor extends CLIPFeatureExtractor { } // NOTE: extends CLIPFeatureExtractor
export class ChineseCLIPFeatureExtractor extends ImageFeatureExtractor { }
export class SiglipImageProcessor extends ImageFeatureExtractor { }
export class ConvNextFeatureExtractor extends ImageFeatureExtractor {
@@ -845,17 +1159,28 @@ export class EfficientNetImageProcessor extends ImageFeatureExtractor {
}
}
+export class MobileNetV1FeatureExtractor extends ImageFeatureExtractor { }
+export class MobileNetV2FeatureExtractor extends ImageFeatureExtractor { }
+export class MobileNetV3FeatureExtractor extends ImageFeatureExtractor { }
+export class MobileNetV4FeatureExtractor extends ImageFeatureExtractor { }
export class MobileViTFeatureExtractor extends ImageFeatureExtractor { }
export class MobileViTImageProcessor extends MobileViTFeatureExtractor { } // NOTE extends MobileViTFeatureExtractor
export class OwlViTFeatureExtractor extends ImageFeatureExtractor {
- /** @type {post_process_object_detection} */
+ /** @type {typeof post_process_object_detection} */
post_process_object_detection(...args) {
return post_process_object_detection(...args);
}
}
export class Owlv2ImageProcessor extends OwlViTFeatureExtractor { } // NOTE extends OwlViTFeatureExtractor
+export class RTDetrImageProcessor extends ImageFeatureExtractor {
+ /** @type {typeof post_process_object_detection} */
+ post_process_object_detection(...args) {
+ return post_process_object_detection(...args);
+ }
+}
+
export class DeiTFeatureExtractor extends ImageFeatureExtractor { }
export class BeitFeatureExtractor extends ImageFeatureExtractor { }
export class DonutFeatureExtractor extends ImageFeatureExtractor {
@@ -911,297 +1236,32 @@ export class DetrFeatureExtractor extends ImageFeatureExtractor {
// TODO support different mask sizes (not just 64x64)
// Currently, just fill pixel mask with 1s
const maskSize = [result.pixel_values.dims[0], 64, 64];
- const pixel_mask = new Tensor(
- 'int64',
- new BigInt64Array(maskSize.reduce((a, b) => a * b)).fill(1n),
- maskSize
- );
+ const pixel_mask = full(maskSize, 1n);
return { ...result, pixel_mask };
}
- /**
- * Post-processes the outputs of the model (for object detection).
- * @param {Object} outputs The outputs of the model that must be post-processed
- * @param {Tensor} outputs.logits The logits
- * @param {Tensor} outputs.pred_boxes The predicted boxes.
- * @return {Object[]} An array of objects containing the post-processed outputs.
- */
-
- /** @type {post_process_object_detection} */
+ /** @type {typeof post_process_object_detection} */
post_process_object_detection(...args) {
return post_process_object_detection(...args);
}
- /**
- * Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores` and `labels`.
- * @param {Tensor} class_logits The class logits.
- * @param {Tensor} mask_logits The mask logits.
- * @param {number} object_mask_threshold A number between 0 and 1 used to binarize the masks.
- * @param {number} num_labels The number of labels.
- * @returns {[Tensor[], number[], number[]]} The binarized masks, the scores, and the labels.
- */
- remove_low_and_no_objects(class_logits, mask_logits, object_mask_threshold, num_labels) {
-
- let mask_probs_item = [];
- let pred_scores_item = [];
- let pred_labels_item = [];
-
- for (let j = 0; j < class_logits.dims[0]; ++j) {
- let cls = class_logits[j];
- let mask = mask_logits[j];
-
- let pred_label = max(cls.data)[1];
- if (pred_label === num_labels) {
- // Is the background, so we ignore it
- continue;
- }
-
- let scores = softmax(cls.data);
- let pred_score = scores[pred_label];
- if (pred_score > object_mask_threshold) {
- mask_probs_item.push(mask);
- pred_scores_item.push(pred_score);
- pred_labels_item.push(pred_label);
- }
- }
-
- return [mask_probs_item, pred_scores_item, pred_labels_item];
-
- }
-
- /**
- * Checks whether the segment is valid or not.
- * @param {Int32Array} mask_labels Labels for each pixel in the mask.
- * @param {Tensor[]} mask_probs Probabilities for each pixel in the masks.
- * @param {number} k The class id of the segment.
- * @param {number} mask_threshold The mask threshold.
- * @param {number} overlap_mask_area_threshold The overlap mask area threshold.
- * @returns {[boolean, number[]]} Whether the segment is valid or not, and the indices of the valid labels.
- */
- check_segment_validity(
- mask_labels,
- mask_probs,
- k,
- mask_threshold = 0.5,
- overlap_mask_area_threshold = 0.8
- ) {
- // mask_k is a 1D array of indices, indicating where the mask is equal to k
- let mask_k = [];
- let mask_k_area = 0;
- let original_area = 0;
-
- // Compute the area of all the stuff in query k
- for (let i = 0; i < mask_labels.length; ++i) {
- if (mask_labels[i] === k) {
- mask_k.push(i);
- ++mask_k_area;
- }
-
- if (mask_probs[k].data[i] >= mask_threshold) {
- ++original_area;
- }
- }
- let mask_exists = mask_k_area > 0 && original_area > 0;
-
- // Eliminate disconnected tiny segments
- if (mask_exists) {
- // Perform additional check
- let area_ratio = mask_k_area / original_area;
- mask_exists = area_ratio > overlap_mask_area_threshold;
- }
-
- return [mask_exists, mask_k]
+ /** @type {typeof post_process_panoptic_segmentation} */
+ post_process_panoptic_segmentation(...args) {
+ return post_process_panoptic_segmentation(...args);
}
- /**
- * Computes the segments.
- * @param {Tensor[]} mask_probs The mask probabilities.
- * @param {number[]} pred_scores The predicted scores.
- * @param {number[]} pred_labels The predicted labels.
- * @param {number} mask_threshold The mask threshold.
- * @param {number} overlap_mask_area_threshold The overlap mask area threshold.
- * @param {Set} label_ids_to_fuse The label ids to fuse.
- * @param {number[]} target_size The target size of the image.
- * @returns {[Tensor, Array<{id: number, label_id: number, score: number}>]} The computed segments.
- */
- compute_segments(
- mask_probs,
- pred_scores,
- pred_labels,
- mask_threshold,
- overlap_mask_area_threshold,
- label_ids_to_fuse = null,
- target_size = null,
- ) {
- let [height, width] = target_size ?? mask_probs[0].dims;
-
- let segmentation = new Tensor(
- 'int32',
- new Int32Array(height * width),
- [height, width]
- );
- let segments = [];
-
- // 1. If target_size is not null, we need to resize the masks to the target size
- if (target_size !== null) {
- // resize the masks to the target size
- for (let i = 0; i < mask_probs.length; ++i) {
- mask_probs[i] = interpolate(mask_probs[i], target_size, 'bilinear', false);
- }
- }
-
- // 2. Weigh each mask by its prediction score
- // NOTE: `mask_probs` is updated in-place
- //
- // Temporary storage for the best label/scores for each pixel ([height, width]):
- let mask_labels = new Int32Array(mask_probs[0].data.length);
- let bestScores = new Float32Array(mask_probs[0].data.length);
-
- for (let i = 0; i < mask_probs.length; ++i) {
- let score = pred_scores[i];
-
- for (let j = 0; j < mask_probs[i].data.length; ++j) {
- mask_probs[i].data[j] *= score
- if (mask_probs[i].data[j] > bestScores[j]) {
- mask_labels[j] = i;
- bestScores[j] = mask_probs[i].data[j];
- }
- }
- }
-
- let current_segment_id = 0;
-
- // let stuff_memory_list = {}
- for (let k = 0; k < pred_labels.length; ++k) {
- let pred_class = pred_labels[k];
-
- // TODO add `should_fuse`
- // let should_fuse = pred_class in label_ids_to_fuse
-
- // Check if mask exists and large enough to be a segment
- let [mask_exists, mask_k] = this.check_segment_validity(
- mask_labels,
- mask_probs,
- k,
- mask_threshold,
- overlap_mask_area_threshold
- )
-
- if (!mask_exists) {
- // Nothing to see here
- continue;
- }
-
- // TODO
- // if (pred_class in stuff_memory_list) {
- // current_segment_id = stuff_memory_list[pred_class]
- // } else {
- // current_segment_id += 1;
- // }
- ++current_segment_id;
-
-
- // Add current object segment to final segmentation map
- for (let index of mask_k) {
- segmentation.data[index] = current_segment_id;
- }
-
- segments.push({
- id: current_segment_id,
- label_id: pred_class,
- // was_fused: should_fuse, TODO
- score: pred_scores[k],
- })
-
- // TODO
- // if(should_fuse){
- // stuff_memory_list[pred_class] = current_segment_id
- // }
- }
-
- return [segmentation, segments];
+ post_process_instance_segmentation() {
+ // TODO
+ throw Error("Not implemented yet");
}
+}
- /**
- * Post-process the model output to generate the final panoptic segmentation.
- * @param {*} outputs The model output to post process
- * @param {number} [threshold=0.5] The probability score threshold to keep predicted instance masks.
- * @param {number} [mask_threshold=0.5] Threshold to use when turning the predicted masks into binary values.
- * @param {number} [overlap_mask_area_threshold=0.8] The overlap mask area threshold to merge or discard small disconnected parts within each binary instance mask.
- * @param {Set} [label_ids_to_fuse=null] The labels in this state will have all their instances be fused together.
- * @param {number[][]} [target_sizes=null] The target sizes to resize the masks to.
- * @returns {Array<{ segmentation: Tensor, segments_info: Array<{id: number, label_id: number, score: number}>}>}
- */
- post_process_panoptic_segmentation(
- outputs,
- threshold = 0.5,
- mask_threshold = 0.5,
- overlap_mask_area_threshold = 0.8,
- label_ids_to_fuse = null,
- target_sizes = null,
- ) {
- if (label_ids_to_fuse === null) {
- console.warn("`label_ids_to_fuse` unset. No instance will be fused.")
- label_ids_to_fuse = new Set();
- }
-
- const class_queries_logits = outputs.logits; // [batch_size, num_queries, num_classes+1]
- const masks_queries_logits = outputs.pred_masks; // [batch_size, num_queries, height, width]
-
- const mask_probs = masks_queries_logits.sigmoid() // [batch_size, num_queries, height, width]
-
- let [batch_size, num_queries, num_labels] = class_queries_logits.dims;
- num_labels -= 1; // Remove last class (background)
-
- if (target_sizes !== null && target_sizes.length !== batch_size) {
- throw Error("Make sure that you pass in as many target sizes as the batch dimension of the logits")
- }
-
- let toReturn = [];
- for (let i = 0; i < batch_size; ++i) {
- let target_size = target_sizes !== null ? target_sizes[i] : null;
-
- let class_logits = class_queries_logits[i];
- let mask_logits = mask_probs[i];
-
- let [mask_probs_item, pred_scores_item, pred_labels_item] = this.remove_low_and_no_objects(class_logits, mask_logits, threshold, num_labels);
-
- if (pred_labels_item.length === 0) {
- // No mask found
- let [height, width] = target_size ?? mask_logits.dims.slice(-2);
-
- let segmentation = new Tensor(
- 'int32',
- new Int32Array(height * width).fill(-1),
- [height, width]
- )
- toReturn.push({
- segmentation: segmentation,
- segments_info: []
- });
- continue;
- }
-
-
- // Get segmentation map and segment information of batch item
- let [segmentation, segments] = this.compute_segments(
- mask_probs_item,
- pred_scores_item,
- pred_labels_item,
- mask_threshold,
- overlap_mask_area_threshold,
- label_ids_to_fuse,
- target_size,
- )
-
- toReturn.push({
- segmentation: segmentation,
- segments_info: segments
- })
- }
+export class MaskFormerFeatureExtractor extends ImageFeatureExtractor {
- return toReturn;
+ /** @type {typeof post_process_panoptic_segmentation} */
+ post_process_panoptic_segmentation(...args) {
+ return post_process_panoptic_segmentation(...args);
}
post_process_instance_segmentation() {
@@ -1210,8 +1270,9 @@ export class DetrFeatureExtractor extends ImageFeatureExtractor {
}
}
+
export class YolosFeatureExtractor extends ImageFeatureExtractor {
- /** @type {post_process_object_detection} */
+ /** @type {typeof post_process_object_detection} */
post_process_object_detection(...args) {
return post_process_object_detection(...args);
}
@@ -1224,6 +1285,7 @@ export class YolosFeatureExtractor extends ImageFeatureExtractor {
* @property {HeightWidth[]} reshaped_input_sizes
* @property {Tensor} [input_points]
* @property {Tensor} [input_labels]
+ * @property {Tensor} [input_boxes]
*/
export class SamImageProcessor extends ImageFeatureExtractor {
@@ -1235,7 +1297,7 @@ export class SamImageProcessor extends ImageFeatureExtractor {
* @param {HeightWidth[]} reshaped_input_sizes
* @returns {Tensor}
*/
- reshape_input_points(input_points, original_sizes, reshaped_input_sizes) {
+ reshape_input_points(input_points, original_sizes, reshaped_input_sizes, is_bounding_box = false) {
// Make deep copy to avoid altering user's input
input_points = structuredClone(input_points);
@@ -1244,7 +1306,9 @@ export class SamImageProcessor extends ImageFeatureExtractor {
// TODO: add support for 2D input_points
if (shape.length === 3) {
// Correct user's input
- shape = [1, ...shape];
+ if (!is_bounding_box) {
+ shape = [1, ...shape];
+ }
input_points = [input_points];
} else if (shape.length !== 4) {
throw Error("The input_points must be a 4D tensor of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.")
@@ -1262,8 +1326,8 @@ export class SamImageProcessor extends ImageFeatureExtractor {
for (let j = 0; j < input_points[i].length; ++j) { // point_batch_size
for (let k = 0; k < input_points[i][j].length; ++k) { // nb_points_per_image
- for (let w = 0; w < input_points[i][j][k].length; ++w) { // 2
- input_points[i][j][k][w] *= resizeFactors[w];
+ for (let w = 0; w < input_points[i][j][k].length; ++w) { // 2 or 4
+ input_points[i][j][k][w] *= resizeFactors[w % 2];
}
}
}
@@ -1304,15 +1368,29 @@ export class SamImageProcessor extends ImageFeatureExtractor {
}
/**
* @param {any[]} images The URL(s) of the image(s) to extract features from.
- * @param {any} [input_points] A 3D or 4D array, representing the input points provided by the user.
+ * @param {Object} [options] Additional options for the processor.
+ * @param {any} [options.input_points=null] A 3D or 4D array, representing the input points provided by the user.
* - 3D: `[point_batch_size, nb_points_per_image, 2]`. In this case, `batch_size` is assumed to be 1.
* - 4D: `[batch_size, point_batch_size, nb_points_per_image, 2]`.
- * @param {any} [input_labels] A 2D or 3D array, representing the input labels for the points, used by the prompt encoder to encode the prompt.
+ * @param {any} [options.input_labels=null] A 2D or 3D array, representing the input labels for the points, used by the prompt encoder to encode the prompt.
* - 2D: `[point_batch_size, nb_points_per_image]`. In this case, `batch_size` is assumed to be 1.
* - 3D: `[batch_size, point_batch_size, nb_points_per_image]`.
+ * @param {number[][][]} [options.input_boxes=null] A 3D array of shape `(batch_size, num_boxes, 4)`, representing the input boxes provided by the user.
+ * This is used by the prompt encoder to encode the prompt. Generally yields much better generated masks.
+ * The processor will generate a tensor, with each dimension corresponding respectively to the image batch size,
+ * the number of boxes per image and the coordinates of the top left and bottom right point of the box.
+ * In the order (`x1`, `y1`, `x2`, `y2`):
+ * - `x1`: the x coordinate of the top left point of the input box
+ * - `y1`: the y coordinate of the top left point of the input box
+ * - `x2`: the x coordinate of the bottom right point of the input box
+ * - `y2`: the y coordinate of the bottom right point of the input box
* @returns {Promise<SamImageProcessorResult>}
*/
- async _call(images, input_points = null, input_labels = null) {
+ async _call(images, {
+ input_points = null,
+ input_labels = null,
+ input_boxes = null
+ } = {}) {
// TODO allow user to use preprocessed images
/** @type {SamImageProcessorResult} */
const processed = await super._call(images);
@@ -1330,23 +1408,29 @@ export class SamImageProcessor extends ImageFeatureExtractor {
processed.input_labels = this.add_input_labels(input_labels, processed.input_points);
}
+ if (input_boxes) {
+ processed.input_boxes = this.reshape_input_points(
+ input_boxes, processed.original_sizes, processed.reshaped_input_sizes, true,
+ );
+ }
+
return processed;
}
/**
* Remove padding and upscale masks to the original image size.
* @param {Tensor} masks Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
- * @param {number[][]} original_sizes The original sizes of each image before it was resized to the model's expected input shape, in (height, width) format.
- * @param {number[][]} reshaped_input_sizes The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
+ * @param {[number, number][]} original_sizes The original sizes of each image before it was resized to the model's expected input shape, in (height, width) format.
+ * @param {[number, number][]} reshaped_input_sizes The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
* @param {Object} options Optional parameters for post-processing.
* @param {number} [options.mask_threshold] The threshold to use for binarizing the masks.
* @param {boolean} [options.binarize] Whether to binarize the masks.
* @param {Object} [options.pad_size] The target size the images were padded to before being passed to the model. If `null`, the target size is assumed to be the processor's `pad_size`.
* @param {number} [options.pad_size.height] The height the images were padded to.
* @param {number} [options.pad_size.width] The width the images were padded to.
- * @returns {Tensor[]} Batched masks in batch_size, num_channels, height, width) format, where (height, width) is given by original_size.
+ * @returns {Promise<Tensor[]>} Batched masks in (batch_size, num_channels, height, width) format, where (height, width) is given by original_size.
*/
- post_process_masks(masks, original_sizes, reshaped_input_sizes, {
+ async post_process_masks(masks, original_sizes, reshaped_input_sizes, {
mask_threshold = 0.0,
binarize = true,
pad_size = null,
@@ -1357,50 +1441,72 @@ export class SamImageProcessor extends ImageFeatureExtractor {
pad_size = pad_size ?? this.pad_size;
+ /** @type {[number, number]} */
const target_image_size = [pad_size.height, pad_size.width];
for (let i = 0; i < original_sizes.length; ++i) {
const original_size = original_sizes[i];
const reshaped_input_size = reshaped_input_sizes[i];
- const mask = masks[i]; // [b, c, h, w]
-
- // TODO: improve
- const interpolated_masks = [];
- for (let j = 0; j < mask.dims[0]; ++j) {
- const m = mask[j]; // 3d tensor
-
- // Upscale mask to padded size
- let interpolated_mask = interpolate(m, target_image_size, 'bilinear', false);
-
- // Crop mask
- interpolated_mask = interpolated_mask.slice(null, [0, reshaped_input_size[0]], [0, reshaped_input_size[1]]);
-
- // Downscale mask
- interpolated_mask = interpolate(interpolated_mask, original_size, 'bilinear', false);
-
- if (binarize) {
- const binarizedMaskData = new Uint8Array(interpolated_mask.data.length);
- for (let i = 0; i < interpolated_mask.data.length; ++i) {
- if (interpolated_mask.data[i] > mask_threshold) {
- binarizedMaskData[i] = 1;
- }
+ // Upscale mask to padded size
+ let interpolated_mask = (await interpolate_4d(
+ masks[i],
+ { mode: 'bilinear', size: target_image_size }
+ ));
+
+ // Crop mask
+ interpolated_mask = interpolated_mask.slice(null, null, [0, reshaped_input_size[0]], [0, reshaped_input_size[1]]);
+
+ // Downscale mask
+ interpolated_mask = (await interpolate_4d(
+ interpolated_mask,
+ { mode: 'bilinear', size: original_size }
+ ));
+
+ if (binarize) {
+ const data = interpolated_mask.data;
+ const binarizedMaskData = new Uint8Array(data.length);
+ for (let i = 0; i < data.length; ++i) {
+ if (data[i] > mask_threshold) {
+ binarizedMaskData[i] = 1;
}
- interpolated_mask = new Tensor(
- 'bool',
- binarizedMaskData,
- interpolated_mask.dims
- )
}
-
- interpolated_masks.push(interpolated_mask);
+ interpolated_mask = new Tensor(
+ 'bool',
+ binarizedMaskData,
+ interpolated_mask.dims
+ )
}
- output_masks.push(stack(interpolated_masks));
+ output_masks.push(interpolated_mask);
}
return output_masks;
}
+
+ /**
+ * Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
+ * @param {RawImage} image Input original image
+ * @param {number} target_size Target size of the resized image
+ * @param {Object} options Options for generating crop boxes
+ * @param {number} [options.crop_n_layers] If >0, mask prediction will be run again on crops of the image.
+ * Sets the number of layers to run, where each layer has 2**i_layer number of image crops.
+ * @param {number} [options.overlap_ratio] Sets the degree to which crops overlap. In the first crop layer,
+ * crops will overlap by this fraction of the image length. Later layers with more crops scale down this overlap.
+ * @param {number} [options.points_per_crop] Number of points to sample from each crop.
+ * @param {number} [options.crop_n_points_downscale_factor] The number of points-per-side sampled in layer n is
+ * scaled down by crop_n_points_downscale_factor**n.
+ * @returns {Object} An object containing the crop boxes, number of points per crop, cropped images, and input labels.
+ */
+ generate_crop_boxes(image, target_size, {
+ crop_n_layers = 0,
+ overlap_ratio = 512 / 1500,
+ points_per_crop = 32,
+ crop_n_points_downscale_factor = 1,
+ } = {}) {
+ // TODO: Implement
+ // return { crop_boxes, points_per_crop, cropped_images, input_labels }
+ }
}
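
To illustrate the new `input_boxes` option (boxes are given in pixel coordinates of the original image and rescaled internally via `resizeFactors[w % 2]`) and the now-async `post_process_masks`, a sketch assuming an illustrative SAM checkpoint:

import { SamProcessor, SamModel, RawImage } from '@huggingface/transformers';

const model_id = 'Xenova/slimsam-77-uniform'; // placeholder checkpoint
const processor = await SamProcessor.from_pretrained(model_id);
const model = await SamModel.from_pretrained(model_id);

const image = await RawImage.read('https://example.com/cat.jpg');

// One image, one box, in (x1, y1, x2, y2) order.
const input_boxes = [[[100, 50, 400, 300]]];

const inputs = await processor(image, { input_boxes });
const outputs = await model(inputs);

// Now returns a Promise (interpolation runs through interpolate_4d).
const masks = await processor.post_process_masks(
    outputs.pred_masks, inputs.original_sizes, inputs.reshaped_input_sizes,
);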
export class Swin2SRImageProcessor extends ImageFeatureExtractor {
@@ -1455,7 +1561,7 @@ export class VitMatteImageProcessor extends ImageFeatureExtractor {
), 0);
return {
- pixel_values: pixel_values,
+ pixel_values,
// Original sizes of images
original_sizes: imageData.map(x => x.original_size),
@@ -1488,10 +1594,10 @@ export class WhisperFeatureExtractor extends FeatureExtractor {
/**
* Computes the log-Mel spectrogram of the provided audio waveform.
* @param {Float32Array|Float64Array} waveform The audio waveform to process.
- * @returns {{data: Float32Array, dims: number[]}} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers.
+ * @returns {Promise<Tensor>} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers.
*/
- _extract_fbank_features(waveform) {
- const { data, dims } = spectrogram(
+ async _extract_fbank_features(waveform) {
+ const features = await spectrogram(
waveform,
this.window, // window
this.config.n_fft, // frame_length
@@ -1506,13 +1612,14 @@ export class WhisperFeatureExtractor extends FeatureExtractor {
}
)
+ const data = features.data;
const maxValue = max(data)[0];
for (let i = 0; i < data.length; ++i) {
data[i] = (Math.max(data[i], maxValue - 8.0) + 4.0) / 4.0;
}
- return { data, dims };
+ return features;
}
/**
@@ -1537,13 +1644,10 @@ export class WhisperFeatureExtractor extends FeatureExtractor {
waveform.set(audio);
}
- const { data, dims } = this._extract_fbank_features(waveform);
+ const features = await this._extract_fbank_features(waveform);
return {
- input_features: new Tensor('float32',
- data,
- [1, ...dims]
- )
+ input_features: features.unsqueeze_(0)
};
}
}
@@ -1622,9 +1726,9 @@ export class SeamlessM4TFeatureExtractor extends FeatureExtractor {
* Computes the log-Mel spectrogram of the provided audio waveform.
* @param {Float32Array|Float64Array} waveform The audio waveform to process.
* @param {number} max_length The maximum number of frames to return.
- * @returns {{data: Float32Array, dims: number[]}} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers.
+ * @returns {Promise<Tensor>} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers.
*/
- _extract_fbank_features(waveform, max_length) {
+ async _extract_fbank_features(waveform, max_length) {
// NOTE: We don't pad/truncate since that is passed in as `max_num_frames`
// Kaldi compliance: 16-bit signed integers
@@ -1671,28 +1775,29 @@ export class SeamlessM4TFeatureExtractor extends FeatureExtractor {
} = {}) {
validate_audio_inputs(audio, 'SeamlessM4TFeatureExtractor');
- let features = this._extract_fbank_features(audio, this.config.max_length);
+ let features = await this._extract_fbank_features(audio, this.config.max_length);
if (do_normalize_per_mel_bins) {
const [num_features, feature_size] = features.dims;
+ const data = features.data;
for (let i = 0; i < feature_size; ++i) {
let sum = 0;
for (let j = 0; j < num_features; ++j) {
- sum += features.data[j * feature_size + i];
+ sum += data[j * feature_size + i];
}
const mean = sum / num_features;
let variance = 0;
for (let j = 0; j < num_features; ++j) {
- variance += (features.data[j * feature_size + i] - mean) ** 2;
+ variance += (data[j * feature_size + i] - mean) ** 2;
}
variance /= num_features - 1; // NOTE: We use ddof=1
const std = Math.sqrt(variance + 1e-7);
for (let j = 0; j < num_features; ++j) {
const index = j * feature_size + i;
- features.data[index] = (features.data[index] - mean) / std;
+ data[index] = (data[index] - mean) / std;
}
}
}
@@ -1700,18 +1805,20 @@ export class SeamlessM4TFeatureExtractor extends FeatureExtractor {
let padded_attention_mask;
if (padding) {
const [num_frames, num_channels] = features.dims;
+ const data = /** @type {Float32Array} */(features.data);
const pad_size = num_frames % pad_to_multiple_of;
if (pad_size > 0) {
const padded_data = new Float32Array(num_channels * (num_frames + pad_size));
- padded_data.set(features.data)
- padded_data.fill(this.config.padding_value, features.data.length)
+ padded_data.set(data)
+ padded_data.fill(this.config.padding_value, data.length)
const numPaddedFrames = num_frames + pad_size;
- features = {
- data: padded_data,
- dims: [numPaddedFrames, num_channels],
- }
+ features = new Tensor(
+ features.type,
+ padded_data,
+ [numPaddedFrames, num_channels],
+ )
if (return_attention_mask) {
padded_attention_mask = new Tensor(
@@ -1732,10 +1839,7 @@ export class SeamlessM4TFeatureExtractor extends FeatureExtractor {
throw new Error(`The number of frames (${num_frames}) must be a multiple of the stride (${stride}).`)
}
- const input_features = new Tensor('float32',
- features.data,
- features.dims,
- ).view(
+ const input_features = features.view(
1,
Math.floor(num_frames / stride),
num_channels * stride,
@@ -1746,20 +1850,21 @@ export class SeamlessM4TFeatureExtractor extends FeatureExtractor {
if (return_attention_mask) {
const reshapedNumFrames = input_features.dims[1];
- const attention_mask = new Tensor(
- 'int64',
- new BigInt64Array(reshapedNumFrames),
- [1, reshapedNumFrames],
- );
+ const attention_mask_data = new BigInt64Array(reshapedNumFrames);
+
if (padded_attention_mask) {
+ const padded_attention_mask_data = padded_attention_mask.data;
for (let i = 1, j = 0; i < num_frames; i += stride, ++j) {
- attention_mask.data[j] = padded_attention_mask.data[i];
+ attention_mask_data[j] = padded_attention_mask_data[i];
}
} else {
- attention_mask.data.fill(1n);
+ attention_mask_data.fill(1n);
}
-
- result.attention_mask = attention_mask;
+ result.attention_mask = new Tensor(
+ 'int64',
+ attention_mask_data,
+ [1, reshapedNumFrames],
+ );
}
return result;
@@ -1802,9 +1907,9 @@ export class ASTFeatureExtractor extends FeatureExtractor {
* Computes the log-Mel spectrogram of the provided audio waveform.
* @param {Float32Array|Float64Array} waveform The audio waveform to process.
* @param {number} max_length The maximum number of frames to return.
- * @returns {{data: Float32Array, dims: number[]}} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers.
+ * @returns {Promise<Tensor>} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers.
*/
- _extract_fbank_features(waveform, max_length) {
+ async _extract_fbank_features(waveform, max_length) {
// NOTE: We don't pad/truncate since that is passed in as `max_num_frames`
return spectrogram(
waveform,
@@ -1837,20 +1942,18 @@ export class ASTFeatureExtractor extends FeatureExtractor {
async _call(audio) {
validate_audio_inputs(audio, 'ASTFeatureExtractor');
- const features = this._extract_fbank_features(audio, this.config.max_length);
+ const features = await this._extract_fbank_features(audio, this.config.max_length);
if (this.config.do_normalize) {
// Normalize the input audio spectrogram to have mean=0, std=0.5
const denom = this.std * 2;
- for (let i = 0; i < features.data.length; ++i) {
- features.data[i] = (features.data[i] - this.mean) / denom;
+ const features_data = features.data;
+ for (let i = 0; i < features_data.length; ++i) {
+ features_data[i] = (features_data[i] - this.mean) / denom;
}
}
return {
- input_values: new Tensor('float32',
- features.data,
- [1, ...features.dims]
- )
+ input_values: features.unsqueeze_(0)
};
}
}
@@ -1903,11 +2006,12 @@ export class ClapFeatureExtractor extends FeatureExtractor {
* @param {number} max_length The maximum length of the waveform.
* @param {string} truncation The truncation strategy to use.
* @param {string} padding The padding strategy to use.
- * @returns {{ data: Float32Array; dims: number[]; longer: boolean; }} An object containing the mel spectrogram data as a Float32Array, its dimensions as an array of numbers, and a boolean indicating whether the waveform was longer than the max length.
+ * @returns {Promise<Tensor>} An object containing the mel spectrogram data as a Float32Array, its dimensions as an array of numbers, and a boolean indicating whether the waveform was longer than the max length.
+ * @private
*/
- _get_input_mel(waveform, max_length, truncation, padding) {
+ async _get_input_mel(waveform, max_length, truncation, padding) {
- /** @type {{ data: Float32Array; dims: number[]}} */
+ /** @type {Tensor} */
let input_mel;
let longer = false;
const diff = waveform.length - max_length;
@@ -1917,8 +2021,7 @@ export class ClapFeatureExtractor extends FeatureExtractor {
const idx = Math.floor(Math.random() * (diff + 1));
waveform = waveform.subarray(idx, idx + max_length);
- input_mel = this._extract_fbank_features(waveform, this.mel_filters_slaney, this.config.nb_max_samples);
- input_mel.dims = [1, ...input_mel.dims]; // "unsqueeze"
+ input_mel = await this._extract_fbank_features(waveform, this.mel_filters_slaney, this.config.nb_max_samples);
} else {
// TODO implement fusion strategy
throw new Error(`Truncation strategy "${truncation}" not implemented`)
@@ -1944,14 +2047,10 @@ export class ClapFeatureExtractor extends FeatureExtractor {
throw new Error(`Truncation strategy "${truncation}" not implemented`)
}
- input_mel = this._extract_fbank_features(waveform, this.mel_filters_slaney, this.config.nb_max_samples);
- input_mel.dims = [1, ...input_mel.dims]; // "unsqueeze"
+ input_mel = await this._extract_fbank_features(waveform, this.mel_filters_slaney, this.config.nb_max_samples);
}
- return {
- ...input_mel,
- longer,
- }
+ return input_mel.unsqueeze_(0);
}
/**
@@ -1967,9 +2066,9 @@ export class ClapFeatureExtractor extends FeatureExtractor {
* @param {Float32Array|Float64Array} waveform The audio waveform to process.
* @param {number[][]} mel_filters The mel filters to use.
* @param {number} [max_length=null] The maximum number of frames to return.
- * @returns {{data: Float32Array, dims: number[]}} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers.
+ * @returns {Promise<Tensor>} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers.
*/
- _extract_fbank_features(waveform, mel_filters, max_length = null) {
+ async _extract_fbank_features(waveform, mel_filters, max_length = null) {
// NOTE: We don't pad/truncate since that is passed in as `max_num_frames`
return spectrogram(
waveform,
@@ -2001,24 +2100,195 @@ export class ClapFeatureExtractor extends FeatureExtractor {
validate_audio_inputs(audio, 'ClapFeatureExtractor');
// convert to mel spectrogram, truncate and pad if needed.
- const padded_inputs = this._get_input_mel(
+ const padded_inputs = await this._get_input_mel(
audio,
max_length ?? this.config.nb_max_samples,
this.config.truncation,
this.config.padding,
);
+ return {
+ input_features: padded_inputs.unsqueeze_(0),
+ }
+ }
+}
+
+
+export class PyAnnoteFeatureExtractor extends FeatureExtractor {
+ /**
+ * Asynchronously extracts features from a given audio using the provided configuration.
+ * @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array.
+ * @returns {Promise<{ input_values: Tensor; }>} The extracted input features.
+ */
+ async _call(audio) {
+ validate_audio_inputs(audio, 'PyAnnoteFeatureExtractor');
+
+ if (audio instanceof Float64Array) {
+ audio = new Float32Array(audio);
+ }
+ const shape = [
+ 1, /* batch_size */
+ 1, /* num_channels */
+ audio.length, /* num_samples */
+ ];
return {
- input_features: new Tensor('float32',
- padded_inputs.data,
- [1, ...padded_inputs.dims]
- )
+ input_values: new Tensor('float32', audio, shape),
};
}
+
+ /**
+ * NOTE: Can return fractional values. `Math.ceil` will ensure correct value.
+ * @param {number} samples The number of samples in the audio.
+ * @returns {number} The number of frames in the audio.
+ */
+ samples_to_frames(samples) {
+ return ((samples - this.config.offset) / this.config.step);
+ }
+
+ /**
+ * Post-processes the speaker diarization logits output by the model.
+ * @param {Tensor} logits The speaker diarization logits output by the model.
+ * @param {number} num_samples Number of samples in the input audio.
+ * @returns {Array<Array<{ id: number, start: number, end: number, confidence: number }>>} The post-processed speaker diarization results.
+ */
+ post_process_speaker_diarization(logits, num_samples) {
+ const ratio = (
+ num_samples / this.samples_to_frames(num_samples)
+ ) / this.config.sampling_rate;
+
+ const results = [];
+ for (const scores of logits.tolist()) {
+ const accumulated_segments = [];
+
+ let current_speaker = -1;
+ for (let i = 0; i < scores.length; ++i) {
+ const probabilities = softmax(scores[i]);
+ const [score, id] = max(probabilities);
+ const [start, end] = [i, i + 1];
+
+ if (id !== current_speaker) {
+ // Speaker has changed
+ current_speaker = id;
+ accumulated_segments.push({ id, start, end, score });
+ } else {
+ // Continue the current segment
+ accumulated_segments.at(-1).end = end;
+ accumulated_segments.at(-1).score += score;
+ }
+ }
+
+ results.push(accumulated_segments.map(
+ // Convert frame-space to time-space
+ // and compute the confidence
+ ({ id, start, end, score }) => ({
+ id,
+ start: start * ratio,
+ end: end * ratio,
+ confidence: score / (end - start),
+ })
+ ));
+ }
+ return results;
+ }
+
}
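
A usage sketch for the diarization post-processing above (the checkpoint id and audio URL are placeholders; the `logits` output name follows the convention used elsewhere in this file):

import { AutoProcessor, AutoModel, read_audio } from '@huggingface/transformers';

const model_id = 'onnx-community/pyannote-segmentation-3.0'; // placeholder checkpoint
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await AutoModel.from_pretrained(model_id);

const audio = await read_audio('https://example.com/meeting.wav', 16000);
const inputs = await processor(audio); // { input_values } of shape [1, 1, num_samples]

const { logits } = await model(inputs);
const [segments] = processor.post_process_speaker_diarization(logits, audio.length);
// segments: [{ id, start, end, confidence }, ...] with start/end in seconds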
+export class WeSpeakerFeatureExtractor extends FeatureExtractor {
+
+ constructor(config) {
+ super(config);
+
+ const sampling_rate = this.config.sampling_rate;
+ const mel_filters = mel_filter_bank(
+ 256, // num_frequency_bins
+ this.config.num_mel_bins, // num_mel_filters
+ 20, // min_frequency
+ Math.floor(sampling_rate / 2), // max_frequency
+ sampling_rate, // sampling_rate
+ null, // norm
+ "kaldi", // mel_scale
+ true, // triangularize_in_mel_space
+ );
+
+ // Do padding:
+ for (let i = 0; i < mel_filters.length; ++i) {
+ mel_filters[i].push(0);
+ }
+ this.mel_filters = mel_filters;
+
+ this.window = window_function(400, 'hamming', {
+ periodic: false,
+ })
+ this.min_num_frames = this.config.min_num_frames;
+ }
+
+ /**
+ * Computes the log-Mel spectrogram of the provided audio waveform.
+ * @param {Float32Array|Float64Array} waveform The audio waveform to process.
+ * @returns {Promise<Tensor>} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers.
+ */
+ async _extract_fbank_features(waveform) {
+ // Kaldi compliance: 16-bit signed integers
+ // 32768 == 2 ** 15
+ waveform = waveform.map((/** @type {number} */ x) => x * 32768)
+
+ return spectrogram(
+ waveform,
+ this.window, // window
+ 400, // frame_length
+ 160, // hop_length
+ {
+ fft_length: 512,
+ power: 2.0,
+ center: false,
+ preemphasis: 0.97,
+ mel_filters: this.mel_filters,
+ log_mel: 'log',
+ mel_floor: 1.192092955078125e-07,
+ remove_dc_offset: true,
+
+ // Custom
+ transpose: true,
+ min_num_frames: this.min_num_frames,
+ }
+ )
+ }
+
+
+ /**
+ * Asynchronously extracts features from a given audio using the provided configuration.
+ * @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array.
+ * @returns {Promise<{ input_features: Tensor }>} A Promise resolving to an object containing the extracted input features as a Tensor.
+ */
+ async _call(audio) {
+ validate_audio_inputs(audio, 'WeSpeakerFeatureExtractor');
+
+ const features = (await this._extract_fbank_features(audio)).unsqueeze_(0);
+
+ if (this.config.fbank_centering_span === null) {
+ // center features with global average
+ const meanData = /** @type {Float32Array} */ (features.mean(1).data);
+ const featuresData = /** @type {Float32Array} */(features.data);
+ const [batch_size, num_frames, feature_size] = features.dims;
+
+ for (let i = 0; i < batch_size; ++i) {
+ const offset1 = i * num_frames * feature_size;
+ const offset2 = i * feature_size;
+ for (let j = 0; j < num_frames; ++j) {
+ const offset3 = offset1 + j * feature_size;
+ for (let k = 0; k < feature_size; ++k) {
+ featuresData[offset3 + k] -= meanData[offset2 + k];
+ }
+ }
+ }
+ }
+ return {
+ input_features: features
+ };
+ }
+}
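
A corresponding sketch for the WeSpeaker extractor (checkpoint id is a placeholder; the model's output names depend on the exported graph, so they are not shown here):

import { AutoProcessor, AutoModel, read_audio } from '@huggingface/transformers';

const model_id = 'onnx-community/wespeaker-voxceleb-resnet34-LM'; // placeholder checkpoint
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await AutoModel.from_pretrained(model_id);

const audio = await read_audio('https://example.com/speaker.wav', 16000);
const inputs = await processor(audio); // { input_features }, mean-centred as above
const output = await model(inputs);    // speaker-embedding tensor(s)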
export class SpeechT5FeatureExtractor extends FeatureExtractor { }
@@ -2099,6 +2369,23 @@ export class Wav2Vec2ProcessorWithLM extends Processor {
}
}
+export class PyAnnoteProcessor extends Processor {
+ /**
+ * Calls the feature_extractor function with the given audio input.
+ * @param {any} audio The audio input to extract features from.
+ * @returns {Promise} A Promise that resolves with the extracted features.
+ */
+ async _call(audio) {
+ return await this.feature_extractor(audio)
+ }
+
+ post_process_speaker_diarization(...args) {
+ // @ts-ignore
+ return this.feature_extractor.post_process_speaker_diarization(...args);
+ }
+
+}
+
export class SpeechT5Processor extends Processor {
/**
* Calls the feature_extractor function with the given input.
@@ -2112,6 +2399,110 @@ export class SpeechT5Processor extends Processor {
export class OwlViTProcessor extends Processor { }
+export class Florence2Processor extends Processor {
+ constructor(feature_extractor) {
+ super(feature_extractor);
+
+ const {
+ tasks_answer_post_processing_type,
+ task_prompts_without_inputs,
+ task_prompts_with_input,
+ } = feature_extractor.config;
+
+ /** @type {Map<string, string>} */
+ this.tasks_answer_post_processing_type = new Map(Object.entries(tasks_answer_post_processing_type ?? {}));
+
+ /** @type {Map<string, string>} */
+ this.task_prompts_without_inputs = new Map(Object.entries(task_prompts_without_inputs ?? {}));
+
+ /** @type {Map<string, string>} */
+ this.task_prompts_with_input = new Map(Object.entries(task_prompts_with_input ?? {}));
+
+ this.regexes = {
+ quad_boxes: /(.+?)<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>/gm,
+ bboxes: /([^<]+)?<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>/gm,
+ }
+ this.size_per_bin = 1000;
+ }
+
+ /**
+ * Helper function to construct prompts from input texts
+ * @param {string|string[]} text
+ * @returns {string[]}
+ */
+ construct_prompts(text) {
+ if (typeof text === 'string') {
+ text = [text];
+ }
+
+ const prompts = [];
+ for (const t of text) {
+ // 1. fixed task prompts without additional inputs
+ if (this.task_prompts_without_inputs.has(t)) {
+ prompts.push(this.task_prompts_without_inputs.get(t));
+ }
+ // 2. task prompts with additional inputs
+ else {
+ for (const [task, prompt] of this.task_prompts_with_input) {
+ if (t.includes(task)) {
+ prompts.push(prompt.replaceAll('{input}', t).replaceAll(task, ''));
+ break;
+ }
+ }
+
+ // 3. default prompt
+ if (prompts.length !== text.length) {
+ prompts.push(t);
+ }
+ }
+ }
+ return prompts;
+ }
+
+ /**
+ * Post-process the output of the model to each of the task outputs.
+ * @param {string} text The text to post-process.
+ * @param {string} task The task to post-process the text for.
+ * @param {[number, number]} image_size The size of the image. width x height.
+ */
+ post_process_generation(text, task, image_size) {
+ const task_answer_post_processing_type = this.tasks_answer_post_processing_type.get(task) ?? 'pure_text';
+
+ // remove the special tokens
+ text = text.replaceAll('<s>', '').replaceAll('</s>', '');
+
+ let final_answer;
+ switch (task_answer_post_processing_type) {
+ case 'pure_text':
+ final_answer = text;
+ break;
+
+ case 'description_with_bboxes':
+ case 'bboxes':
+ case 'phrase_grounding':
+ case 'ocr':
+ const key = task_answer_post_processing_type === 'ocr' ? 'quad_boxes' : 'bboxes';
+ const matches = text.matchAll(this.regexes[key]);
+ const labels = [];
+ const items = [];
+ for (const [_, label, ...locations] of matches) {
+ // Push new label, or duplicate the last label
+ labels.push(label ? label.trim() : labels.at(-1) ?? '');
+ items.push(locations.map((x, i) =>
+ // NOTE: Add 0.5 to use the center position of the bin as the coordinate.
+ (Number(x) + 0.5) / this.size_per_bin * image_size[i % 2])
+ );
+ }
+ final_answer = { labels, [key]: items };
+ break;
+
+ default:
+ throw new Error(`Task "${task}" (of type "${task_answer_post_processing_type}") not yet implemented.`);
+ }
+
+ return { [task]: final_answer }
+ }
+}
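
A worked example of the coordinate un-binning in `post_process_generation`, assuming `processor` is an already-loaded `Florence2Processor` and that the checkpoint config maps the `<OD>` task to `description_with_bboxes`:

// Hypothetical generation for an image of size 640x480 (width x height):
const generated_text = 'A cat<loc_100><loc_200><loc_800><loc_900>';
const { '<OD>': detection } = processor.post_process_generation(generated_text, '<OD>', [640, 480]);
// Each bin index i maps to (i + 0.5) / 1000 * dimension, alternating width/height, so:
// detection => { labels: ['A cat'], bboxes: [[64.32, 96.24, 512.32, 432.24]] }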
//////////////////////////////////////////////////
/**
@@ -2151,21 +2542,31 @@ export class AutoProcessor {
ViTFeatureExtractor,
MobileViTFeatureExtractor,
MobileViTImageProcessor,
+ MobileNetV1FeatureExtractor,
+ MobileNetV2FeatureExtractor,
+ MobileNetV3FeatureExtractor,
+ MobileNetV4FeatureExtractor,
OwlViTFeatureExtractor,
Owlv2ImageProcessor,
CLIPFeatureExtractor,
+ CLIPImageProcessor,
+ Florence2Processor,
ChineseCLIPFeatureExtractor,
SiglipImageProcessor,
ConvNextFeatureExtractor,
ConvNextImageProcessor,
SegformerFeatureExtractor,
+ SapiensFeatureExtractor,
BitImageProcessor,
DPTImageProcessor,
DPTFeatureExtractor,
+ PvtImageProcessor,
GLPNFeatureExtractor,
BeitFeatureExtractor,
DeiTFeatureExtractor,
DetrFeatureExtractor,
+ RTDetrImageProcessor,
+ MaskFormerFeatureExtractor,
YolosFeatureExtractor,
DonutFeatureExtractor,
NougatImageProcessor,
@@ -2180,14 +2581,18 @@ export class AutoProcessor {
SpeechT5FeatureExtractor,
ASTFeatureExtractor,
ClapFeatureExtractor,
+ PyAnnoteFeatureExtractor,
+ WeSpeakerFeatureExtractor,
}
static PROCESSOR_CLASS_MAPPING = {
WhisperProcessor,
Wav2Vec2ProcessorWithLM,
+ PyAnnoteProcessor,
SamProcessor,
SpeechT5Processor,
OwlViTProcessor,
+ Florence2Processor,
}
/**
diff --git a/src/tokenizers.js b/src/tokenizers.js
index 234eef15e..5b4e0170c 100644
--- a/src/tokenizers.js
+++ b/src/tokenizers.js
@@ -5,7 +5,7 @@
* **Example:** Create an `AutoTokenizer` and use it to tokenize a sentence.
* This will automatically detect the tokenizer type based on the tokenizer class defined in `tokenizer.json`.
* ```javascript
- * import { AutoTokenizer } from '@xenova/transformers';
+ * import { AutoTokenizer } from '@huggingface/transformers';
*
* const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');
* const { input_ids } = await tokenizer('I love transformers!');
@@ -19,13 +19,16 @@
*
* @module tokenizers
*/
-
import {
Callable,
+} from './utils/generic.js';
+
+import {
reverseDictionary,
escapeRegExp,
isIntegralNumber,
mergeArrays,
+ len,
} from './utils/core.js';
import {
@@ -43,6 +46,11 @@ import {
import { Template } from '@huggingface/jinja';
+import {
+ WHISPER_LANGUAGE_MAPPING,
+ whisper_language_to_code,
+} from './models/whisper/common_whisper.js';
+import { GITHUB_ISSUE_URL } from './utils/constants.js';
/**
* @typedef {Object} TokenizerProperties Additional tokenizer-specific properties.
@@ -188,7 +196,7 @@ function clean_up_tokenization(text) {
* @returns {string} The text with accents removed.
*/
function remove_accents(text) {
- return text.replace(/[\u0300-\u036f]/g, '');
+ return text.replace(/\p{M}/gu, '');
}
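
Switching from the fixed `[\u0300-\u036f]` range to `\p{M}` widens the match from the Combining Diacritical Marks block to every Unicode mark, for example:

'cafe\u0301'.replace(/\p{M}/gu, '');            // 'cafe' — also matched by the old range
'\u05E9\u05B8'.replace(/\p{M}/gu, '');          // '\u05E9' — Hebrew qamats (U+05B8) is a mark
'\u05E9\u05B8'.replace(/[\u0300-\u036f]/g, ''); // unchanged — U+05B8 lies outside the old range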
/**
@@ -200,24 +208,55 @@ function lowercase_and_remove_accent(text) {
return remove_accents(text.toLowerCase());
}
+
+/**
+ * Checks whether the given Unicode codepoint represents a CJK (Chinese, Japanese, or Korean) character.
+ *
+ * A "chinese character" is defined as anything in the CJK Unicode block:
+ * https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+ *
+ * Note that the CJK Unicode block is NOT all Japanese and Korean characters, despite its name.
+ * The modern Korean Hangul alphabet is a different block, as is Japanese Hiragana and Katakana.
+ * Those alphabets are used to write space-separated words, so they are not treated specially
+ * and are handled like all other languages.
+ *
+ * @param {number|bigint} cp The Unicode codepoint to check.
+ * @returns {boolean} True if the codepoint represents a CJK character, false otherwise.
+ */
+export function is_chinese_char(cp) {
+ return (
+ (cp >= 0x4E00 && cp <= 0x9FFF)
+ || (cp >= 0x3400 && cp <= 0x4DBF)
+ || (cp >= 0x20000 && cp <= 0x2A6DF)
+ || (cp >= 0x2A700 && cp <= 0x2B73F)
+ || (cp >= 0x2B740 && cp <= 0x2B81F)
+ || (cp >= 0x2B820 && cp <= 0x2CEAF)
+ || (cp >= 0xF900 && cp <= 0xFAFF)
+ || (cp >= 0x2F800 && cp <= 0x2FA1F)
+ )
+}
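
For example:

is_chinese_char('中'.codePointAt(0)); // true  — U+4E2D, CJK Unified Ideographs
is_chinese_char('の'.codePointAt(0)); // false — Hiragana is a separate block
is_chinese_char('한'.codePointAt(0)); // false — Hangul syllables are a separate block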
+
/**
- * Helper function to fuse consecutive values in an array equal to the specified value.
- * @param {string[]} arr The input array
- * @param {any} value The value to fuse on.
- * @param {Map} mapping The mapping from input domain to value.
+ * Helper function to fuse consecutive unknown tokens.
+ * @param {string[]} arr The list of input tokens
+ * @param {Map<string, number>} tokens_to_ids The mapping from tokens to token ids.
+ * @param {number} unk_token_id The value to fuse on.
+ * @private
*/
-function fuse(arr, value, mapping) {
+function fuse_unk(arr, tokens_to_ids, unk_token_id) {
const fused = [];
let i = 0;
while (i < arr.length) {
fused.push(arr[i])
- if ((mapping.get(arr[i]) ?? value) !== value) {
+ if ((tokens_to_ids.get(arr[i]) ?? unk_token_id) !== unk_token_id) {
++i;
continue;
}
- while (i < arr.length && (mapping.get(arr[i]) ?? value) === value) {
- ++i;
+ while (++i < arr.length && (tokens_to_ids.get(arr[i]) ?? unk_token_id) === unk_token_id) {
+ if (tokens_to_ids.get(fused.at(-1)) !== unk_token_id) {
+ fused[fused.length - 1] += arr[i];
+ }
}
}
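
A worked trace of the new fusing behaviour (hypothetical vocabulary, not part of the diff):

// With tokens_to_ids = new Map([['a', 1], ['b', 2]]) and unk_token_id = 0:
//   fuse_unk(['a', 'x', 'y', 'z', 'b'], tokens_to_ids, 0)
// keeps 'a', concatenates the consecutive unknowns into a single token instead of dropping them,
// then keeps 'b':
//   => ['a', 'xyz', 'b']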
@@ -234,12 +273,18 @@ function whitespace_split(text) {
}
const PUNCTUATION_REGEX = '\\p{P}\\u0021-\\u002F\\u003A-\\u0040\\u005B-\\u0060\\u007B-\\u007E';
+const PUNCTUATION_ONLY_REGEX = new RegExp(`^[${PUNCTUATION_REGEX}]+$`, 'gu');
+const BLOOM_SPLIT_CHARS = '.,!?\u2026\u3002\uff0c\u3001\u0964\u06d4\u060c';
-// A mapping of regex patterns to their equivalent (but longer) JS-compatible versions.
+// A mapping of regex patterns to their equivalent (but possibly longer) JS-compatible versions.
const PROBLEMATIC_REGEX_MAP = new Map([
// This uses the case insensitive group modifier, which is not supported in JavaScript.
// When parsing the regex, an "Invalid group" error is thrown.
["(?i:'s|'t|'re|'ve|'m|'ll|'d)", "(?:'([sS]|[tT]|[rR][eE]|[vV][eE]|[mM]|[lL][lL]|[dD]))"],
+
+ // Used to override the default (invalid) regex of the bloom pretokenizer.
+ // For more information, see https://github.com/huggingface/transformers.js/issues/94
+ [` ?[^(\\s|[${BLOOM_SPLIT_CHARS}])]+`, ` ?[^\\s${BLOOM_SPLIT_CHARS}]+`],
])
@@ -317,14 +362,21 @@ export class TokenizerModel extends Callable {
case 'Unigram':
// @ts-ignore
return new Unigram(config, ...args);
-
case 'BPE':
return new BPE(config);
default:
+ // Some tokenizers, like for google-t5/t5-small, do not have a `type` field.
+ // In this case, we can infer the tokenizer type based on the structure of the `vocab` field.
if (config.vocab) {
- // @ts-ignore
- return new LegacyTokenizerModel(config, ...args);
+ if (Array.isArray(config.vocab)) {
+ // config.vocab is of type `[string, number][]`
+ // @ts-ignore
+ return new Unigram(config, ...args);
+ } else {
+ // @ts-ignore
+ return new LegacyTokenizerModel(config, ...args);
+ }
}
throw new Error(`Unknown TokenizerModel type: ${config.type}`);
}
@@ -333,15 +385,15 @@ export class TokenizerModel extends Callable {
/**
* Internal function to call the TokenizerModel instance.
* @param {string[]} tokens The tokens to encode.
- * @returns {string[]} The encoded token IDs.
+ * @returns {string[]} The encoded tokens.
*/
_call(tokens) {
- let ids = this.encode(tokens);
+ tokens = this.encode(tokens);
if (this.fuse_unk) {
// Fuse unknown tokens
- ids = fuse(ids, this.unk_token_id, this.tokens_to_ids);
+ tokens = fuse_unk(tokens, this.tokens_to_ids, this.unk_token_id);
}
- return ids;
+ return tokens;
}
/**
@@ -365,7 +417,7 @@ export class TokenizerModel extends Callable {
/**
* Converts a list of token IDs into a list of tokens.
- * @param {number[]} ids The token IDs to convert.
+ * @param {number[]|bigint[]} ids The token IDs to convert.
* @returns {string[]} The converted tokens.
*/
convert_ids_to_tokens(ids) {
@@ -502,18 +554,18 @@ class Unigram extends TokenizerModel {
this.unk_token = this.vocab[config.unk_id];
this.tokens_to_ids = new Map(this.vocab.map((x, i) => [x, i]));
- this.bosToken = ' '; // beginning of a sentence token
+ this.bos_token = ' '; // beginning of a sentence token
- this.bosTokenId = this.tokens_to_ids.get(this.bosToken); // NOTE: may be undefined
- this.eosToken = moreConfig.eos_token;
+ this.bos_token_id = this.tokens_to_ids.get(this.bos_token); // NOTE: may be undefined
+ this.eos_token = moreConfig.eos_token;
- this.eosTokenId = this.tokens_to_ids.get(this.eosToken);
- this.unkToken = this.vocab[this.unk_token_id];
+ this.eos_token_id = this.tokens_to_ids.get(this.eos_token);
+ this.unk_token = this.vocab[this.unk_token_id];
this.minScore = min(this.scores)[0];
- this.unkScore = this.minScore - 10.0;
- this.scores[this.unk_token_id] = this.unkScore;
+ this.unk_score = this.minScore - 10.0;
+ this.scores[this.unk_token_id] = this.unk_score;
this.trie = new CharTrie();
this.trie.extend(this.vocab);
@@ -528,26 +580,27 @@ class Unigram extends TokenizerModel {
* @param {TokenLattice} lattice The token lattice to populate with nodes.
*/
populateNodes(lattice) {
- const sentence = lattice.sentence;
- const len = sentence.length;
+ const chars = lattice.chars;
+ const mblen = 1;
let beginPos = 0;
- while (beginPos < len) {
- const mblen = 1;
+ while (beginPos < chars.length) {
let hasSingleNode = false;
- const tokens = [];
- for (let token of this.trie.commonPrefixSearch(sentence.slice(beginPos))) {
+ const tokens = [];
+ const sliced = chars.slice(beginPos).join('');
+ const prefixedTokens = this.trie.commonPrefixSearch(sliced);
+ for (const token of prefixedTokens) {
tokens.push(token);
const tokenId = this.tokens_to_ids.get(token);
const tokenScore = this.scores[tokenId];
- const n = token.length;
+ const n = len(token);
lattice.insert(beginPos, n, tokenScore, tokenId);
if (!hasSingleNode && n === mblen) {
hasSingleNode = true;
}
}
if (!hasSingleNode) {
- lattice.insert(beginPos, mblen, this.unkScore, this.unk_token_id);
+ lattice.insert(beginPos, mblen, this.unk_score, this.unk_token_id);
}
beginPos += mblen;
}
@@ -560,7 +613,7 @@ class Unigram extends TokenizerModel {
* @returns {string[]} An array of subtokens obtained by encoding the input tokens using the unigram model.
*/
tokenize(normalized) {
- const lattice = new TokenLattice(normalized, this.bosTokenId, this.eosTokenId);
+ const lattice = new TokenLattice(normalized, this.bos_token_id, this.eos_token_id);
this.populateNodes(lattice);
return lattice.tokens();
}
@@ -630,7 +683,7 @@ class BPE extends TokenizerModel {
* Create a BPE instance.
* @param {Object} config The configuration object for BPE.
* @param {Object} config.vocab A mapping of tokens to ids.
- * @param {string[]} config.merges An array of BPE merges as strings.
+ * @param {string[]|[string, string][]} config.merges An array of BPE merges as strings.
* @param {string} config.unk_token The unknown token used for out of vocabulary words.
* @param {string} config.end_of_word_suffix The suffix to place at the end of each word.
* @param {string} [config.continuing_subword_suffix] The suffix to insert between words.
@@ -640,8 +693,6 @@ class BPE extends TokenizerModel {
constructor(config) {
super(config);
- this.BPE_SPLIT_TOKEN = ' ';
-
/** @type {Map<string, number>} */
this.tokens_to_ids = objectToMap(config.vocab);
@@ -653,8 +704,15 @@ class BPE extends TokenizerModel {
this.vocab[value] = key;
}
- this.bpe_ranks = new Map(config.merges.map((x, i) => [x, i]));
- this.merges = config.merges.map(x => x.split(this.BPE_SPLIT_TOKEN));
+ // Tokenizers >= 0.20.0 serializes BPE merges as a [string, string][] instead of a string[],
+ // which resolves the ambiguity for merges containing spaces.
+ const use_new_merge_format = Array.isArray(config.merges[0]);
+
+ /** @type {[string, string][]} */
+ this.merges = use_new_merge_format
+ ? /** @type {[string, string][]} */(config.merges)
+ : (/** @type {string[]} */(config.merges)).map(x => /** @type {[string, string]} */(x.split(' ', 2)));
+ this.bpe_ranks = new Map(this.merges.map((x, i) => [JSON.stringify(x), i]));
this.end_of_word_suffix = config.end_of_word_suffix;
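
For reference, the two on-disk serializations this constructor now accepts, and the lookup keys they produce (values are illustrative):

// tokenizers < 0.20.0:   "merges": ["Ġ t", "h e", ...]            — string[], split on the first space
// tokenizers >= 0.20.0:  "merges": [["Ġ", "t"], ["h", "e"], ...]  — [string, string][]
// Both normalize to this.merges = [["Ġ", "t"], ["h", "e"], ...] and
// this.bpe_ranks = Map { '["Ġ","t"]' => 0, '["h","e"]' => 1, ... }, queried via JSON.stringify([left, right]).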
@@ -814,7 +872,7 @@ class BPE extends TokenizerModel {
// `score` is a measure of the merge priority: lower means higher priority
// We use the BPE rank as a measure of priority (i.e., the location of the merge in the merges list)
// We also add a fractional component to the score to break ties (with the earlier character having higher priority)
- const rank = this.bpe_ranks.get(node.token + this.BPE_SPLIT_TOKEN + node.next.token);
+ const rank = this.bpe_ranks.get(JSON.stringify([node.token, node.next.token]));
if (rank !== undefined) {
node.score = rank + node.bias;
queue.push(node);
@@ -839,15 +897,19 @@ class BPE extends TokenizerModel {
for (const t of bpe_token_list) {
if (this.tokens_to_ids.has(t)) {
outputTokens.push(t);
- } else {
- if (this.byte_fallback) {
- outputTokens.push(
- ...Array.from(this.text_encoder.encode(t))
- .map(x => `<0x${x.toString(16).toUpperCase().padStart(2, '0')}>`)
- );
+ } else if (this.byte_fallback) {
+ const byteTokens = Array.from(this.text_encoder.encode(t))
+ .map(x => `<0x${x.toString(16).toUpperCase().padStart(2, '0')}>`);
+ if (byteTokens.every(x => this.tokens_to_ids.has(x))) {
+ // Ensure the byte tokens are actually in the vocabulary, otherwise
+ // we fall back to the unknown token. For more information, see
+ // https://github.com/huggingface/transformers/issues/28096.
+ outputTokens.push(...byteTokens);
} else {
outputTokens.push(this.unk_token);
}
+ } else {
+ outputTokens.push(this.unk_token);
}
}
}
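
To make the byte-fallback path concrete (the character is arbitrary):

// 'é' encodes to the UTF-8 bytes 0xC3 0xA9, so the candidate byte tokens are:
Array.from(new TextEncoder().encode('é'))
    .map(x => `<0x${x.toString(16).toUpperCase().padStart(2, '0')}>`);
// => ['<0xC3>', '<0xA9>']
// These are only emitted if every one of them is in the vocabulary; otherwise the whole
// piece becomes unk_token (see huggingface/transformers#28096 referenced above).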
@@ -1154,7 +1216,7 @@ class BertNormalizer extends Normalizer {
for (let i = 0; i < text.length; ++i) {
const char = text[i];
const cp = char.charCodeAt(0);
- if (this._is_chinese_char(cp)) {
+ if (is_chinese_char(cp)) {
output.push(" ");
output.push(char);
output.push(" ");
@@ -1165,39 +1227,14 @@ class BertNormalizer extends Normalizer {
return output.join("");
}
- /**
- * Checks whether the given Unicode codepoint represents a CJK (Chinese, Japanese, or Korean) character.
- *
- * A "chinese character" is defined as anything in the CJK Unicode block:
- * https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
- *
- * Note that the CJK Unicode block is NOT all Japanese and Korean characters, despite its name.
- * The modern Korean Hangul alphabet is a different block, as is Japanese Hiragana and Katakana.
- * Those alphabets are used to write space-separated words, so they are not treated specially
- * and are handled like all other languages.
- *
- * @param {number} cp The Unicode codepoint to check.
- * @returns {boolean} True if the codepoint represents a CJK character, false otherwise.
- */
- _is_chinese_char(cp) {
- return (
- (cp >= 0x4E00 && cp <= 0x9FFF)
- || (cp >= 0x3400 && cp <= 0x4DBF)
- || (cp >= 0x20000 && cp <= 0x2A6DF)
- || (cp >= 0x2A700 && cp <= 0x2B73F)
- || (cp >= 0x2B740 && cp <= 0x2B81F)
- || (cp >= 0x2B820 && cp <= 0x2CEAF)
- || (cp >= 0xF900 && cp <= 0xFAFF)
- || (cp >= 0x2F800 && cp <= 0x2FA1F)
- )
- }
/**
* Strips accents from the given text.
* @param {string} text The text to strip accents from.
* @returns {string} The text with accents removed.
*/
stripAccents(text) {
- return text.normalize('NFD').replace(/[\u0300-\u036f]/g, '');
+ // "Mark, Nonspacing" (Mn)
+ return text.normalize('NFD').replace(/\p{Mn}/gu, '');
}
@@ -2315,7 +2352,7 @@ class Precompiled extends Normalizer {
// TODO: detect when a different `this.charsmap` is used.
text = text.replace(/[\u0001-\u0008\u000B\u000E-\u001F\u007F\u008F\u009F]/gm, ''); // Remove control characters
- text = text.replace(/[\u0009\u000A\u000C\u000D\u1680\u200B\u200C\u200E\u200F\u2028\u2029\u2581\uFEFF\uFFFD]/gm, '\u0020'); // Replace certain characters with a space
+ text = text.replace(/[\u0009\u000A\u000C\u000D\u00A0\u1680\u2000-\u200F\u2028\u2029\u202F\u205F\u2581\u3000\uFEFF\uFFFD]/gm, '\u0020'); // Replace certain characters with a space
if (text.includes('\uFF5E')) {
// To match the sentencepiece implementation 100%, we must handle a very strange edge-case.
@@ -2452,7 +2489,7 @@ const SPECIAL_TOKEN_ATTRIBUTES = [
* @param {Record} item The input object.
* @param {number} length The length to pad to.
* @param {(key: string) => any} value_fn Determine the value to fill the array, based on its key.
- * @param {'right'|'left'} side Which side to pad the array.
+ * @param {string} side Which side to pad the array.
* @private
*/
function padHelper(item, length, value_fn, side) {
@@ -2492,8 +2529,7 @@ function truncateHelper(item, length) {
export class PreTrainedTokenizer extends Callable {
return_token_type_ids = false;
- _default_chat_template = `{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}`;
-
+ padding_side = 'right';
/**
* Create a new PreTrainedTokenizer instance.
* @param {Object} tokenizerJSON The JSON of the tokenizer.
@@ -2541,14 +2577,17 @@ export class PreTrainedTokenizer extends Callable {
// Another slight hack to add `end_of_word_suffix` (if present) to the decoder
// This is needed for cases where BPE model and ByteLevel decoder are used
- // For more information, see https://github.com/xenova/transformers.js/issues/74
+ // For more information, see https://github.com/huggingface/transformers.js/issues/74
// TODO: save this to the decoder when exporting?
this.decoder.end_of_word_suffix = this.model.end_of_word_suffix;
}
-
this.added_tokens_regex = this.added_tokens.length > 0 ? new RegExp(
- this.added_tokens.map(x => `${x.lstrip ? '\\s*' : ''}(${escapeRegExp(x.content)})${x.rstrip ? '\\s*' : ''}`).join('|')
+ this.added_tokens.slice()
+ // Sort by length (desc) to avoid early partial matches
+ .sort((a, b) => b.content.length - a.content.length)
+ .map(x => `${x.lstrip ? '\\s*' : ''}(${escapeRegExp(x.content)})${x.rstrip ? '\\s*' : ''}`)
+ .join('|')
) : null;
// Set mask token if present (otherwise will be undefined, which is fine)
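
Sorting the added tokens by content length (descending) matters because JavaScript regex alternation is first-match-wins rather than longest-match-wins, so a shorter token that is a prefix of a longer one would otherwise split it. A tiny illustration with made-up token strings:

```javascript
// Made-up added tokens 'foo' and 'foobar', where the shorter is a prefix of the longer.
console.log('foobar'.match(/(foo)|(foobar)/)[0]); // 'foo'     (early partial match breaks the longer token)
console.log('foobar'.match(/(foobar)|(foo)/)[0]); // 'foobar'  (listing the longest token first fixes it)
```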
@@ -2572,9 +2611,9 @@ export class PreTrainedTokenizer extends Callable {
this.clean_up_tokenization_spaces = tokenizerConfig.clean_up_tokenization_spaces ?? true;
this.do_lowercase_and_remove_accent = tokenizerConfig.do_lowercase_and_remove_accent ?? false;
- // TODO allow user to change this
- /** @type {'right'|'left'} */
- this.padding_side = 'right';
+ if (tokenizerConfig.padding_side) {
+ this.padding_side = tokenizerConfig.padding_side;
+ }
this.legacy = false;
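
With `padding_side` now read from `tokenizer_config.json` (and overridable per class, e.g. `LlamaTokenizer` below sets `'left'`), batched padding can be placed on either side. A rough sketch of the observable effect; the checkpoint name and the exact lengths are illustrative only:

```javascript
import { AutoTokenizer } from '@huggingface/transformers';

// Hypothetical checkpoint whose tokenizer_config.json contains "padding_side": "left".
const tokenizer = await AutoTokenizer.from_pretrained('user/some-decoder-tokenizer');

const { attention_mask } = tokenizer(['hi', 'a much longer input'], {
    padding: true,
    return_tensor: false,
});
// Left padding prepends the pad tokens, so the mask of the shorter row looks like
//   [0, 0, 0, 1, 1]   (with the default 'right', the zeros would trail instead).
```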
@@ -2599,6 +2638,7 @@ export class PreTrainedTokenizer extends Callable {
* @param {...string} keys One or more keys to search for in the tokenizer config object.
* @returns {string|null} The value associated with the first matching key, or null if no match is found.
* @throws {Error} If an object is found for a matching key and its __type property is not "AddedToken".
+ * @private
*/
getToken(...keys) {
for (const key of keys) {
@@ -2707,11 +2747,11 @@ export class PreTrainedTokenizer extends Callable {
}
encodedTokens = text.map(
- (t, i) => this._encode_plus(t, text_pair[i], { add_special_tokens, return_token_type_ids })
+ (t, i) => this._encode_plus(t, { text_pair: text_pair[i], add_special_tokens, return_token_type_ids })
)
} else {
- encodedTokens = text.map(x => this._encode_plus(x, null, { add_special_tokens, return_token_type_ids }));
+ encodedTokens = text.map(x => this._encode_plus(x, { add_special_tokens, return_token_type_ids }));
}
} else {
@@ -2724,7 +2764,7 @@ export class PreTrainedTokenizer extends Callable {
}
// For single input, we just wrap in an array, and then unwrap later.
- encodedTokens = [this._encode_plus(text, text_pair, { add_special_tokens, return_token_type_ids })];
+ encodedTokens = [this._encode_plus(text, { text_pair, add_special_tokens, return_token_type_ids })];
}
// At this point, tokens is batched: [batch_size, tokens]
// However, array may be jagged. So, we pad to max_length
@@ -2743,7 +2783,7 @@ export class PreTrainedTokenizer extends Callable {
}
// Ensure it is less than model max length
- max_length = Math.min(max_length, this.model_max_length)
+ max_length = Math.min(max_length, this.model_max_length ?? Infinity);
if (padding || truncation) {
@@ -2879,56 +2919,88 @@ export class PreTrainedTokenizer extends Callable {
* Encodes a single text or a pair of texts using the model's tokenizer.
*
* @param {string} text The text to encode.
- * @param {string|null} text_pair The optional second text to encode.
* @param {Object} options An optional object containing the following properties:
+ * @param {string} [options.text_pair=null] The optional second text to encode.
* @param {boolean} [options.add_special_tokens=true] Whether or not to add the special tokens associated with the corresponding model.
* @param {boolean} [options.return_token_type_ids=null] Whether to return token_type_ids.
* @returns {EncodingSingle} An object containing the encoded text.
* @private
*/
- _encode_plus(text, text_pair = null, {
+ _encode_plus(text, {
+ text_pair = null,
add_special_tokens = true,
return_token_type_ids = null,
} = {}) {
- // Function called by users to encode possibly multiple texts
- const tokens = this._encode_text(text);
- const tokens2 = this._encode_text(text_pair);
- const combinedTokens = this.post_processor
- ? this.post_processor(tokens, tokens2, { add_special_tokens })
- : { tokens: mergeArrays(tokens ?? [], tokens2 ?? []) };
+ const { tokens, token_type_ids } = this._tokenize_helper(text, { pair: text_pair, add_special_tokens });
- const input_ids = this.model.convert_tokens_to_ids(combinedTokens.tokens);
+ const input_ids = this.model.convert_tokens_to_ids(tokens);
const result = {
input_ids,
attention_mask: new Array(input_ids.length).fill(1),
}
- if ((return_token_type_ids ?? this.return_token_type_ids) && combinedTokens.token_type_ids) {
- result.token_type_ids = combinedTokens.token_type_ids;
+ if ((return_token_type_ids ?? this.return_token_type_ids) && token_type_ids) {
+ result.token_type_ids = token_type_ids;
}
return result;
}
+ /**
+ * Internal helper function to tokenize a text, and optionally a pair of texts.
+ * @param {string} text The text to tokenize.
+ * @param {Object} options An optional object containing the following properties:
+ * @param {string} [options.pair=null] The optional second text to tokenize.
+ * @param {boolean} [options.add_special_tokens=false] Whether or not to add the special tokens associated with the corresponding model.
+ * @returns {{tokens: string[], token_type_ids?: number[]}} An object containing the tokens and optionally the token type IDs.
+ */
+ _tokenize_helper(text, {
+ pair = null,
+ add_special_tokens = false,
+ } = {}) {
+ const tokens = this._encode_text(text);
+ const tokens2 = this._encode_text(pair);
+
+ return this.post_processor
+ ? this.post_processor(tokens, tokens2, { add_special_tokens })
+ : { tokens: mergeArrays(tokens ?? [], tokens2 ?? []) };
+ }
+
+ /**
+ * Converts a string into a sequence of tokens.
+ * @param {string} text The sequence to be encoded.
+ * @param {Object} options An optional object containing the following properties:
+ * @param {string} [options.pair] A second sequence to be encoded with the first.
+ * @param {boolean} [options.add_special_tokens=false] Whether or not to add the special tokens associated with the corresponding model.
+ * @returns {string[]} The list of tokens.
+ */
+ tokenize(text, {
+ pair = null,
+ add_special_tokens = false,
+ } = {}) {
+ return this._tokenize_helper(text, { pair, add_special_tokens }).tokens;
+ }
+
/**
* Encodes a single text or a pair of texts using the model's tokenizer.
*
* @param {string} text The text to encode.
- * @param {string|null} text_pair The optional second text to encode.
* @param {Object} options An optional object containing the following properties:
+ * @param {string} [options.text_pair=null] The optional second text to encode.
* @param {boolean} [options.add_special_tokens=true] Whether or not to add the special tokens associated with the corresponding model.
* @param {boolean} [options.return_token_type_ids=null] Whether to return token_type_ids.
* @returns {number[]} An array of token IDs representing the encoded text(s).
*/
- encode(text, text_pair = null, {
+ encode(text, {
+ text_pair = null,
add_special_tokens = true,
return_token_type_ids = null,
} = {}) {
- const { input_ids } = this._encode_plus(text, text_pair, {
+ return this._encode_plus(text, {
+ text_pair,
add_special_tokens,
return_token_type_ids,
- });
- return input_ids;
+ }).input_ids;
}
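
`encode` (and the internal `_encode_plus`) now take the second sequence as a named `text_pair` option, and the new public `tokenize` method exposes the string tokens directly. A short usage sketch; the checkpoint and the exact tokens shown are illustrative:

```javascript
import { AutoTokenizer } from '@huggingface/transformers';

const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');

// Before: tokenizer.encode(text, text_pair, { ... })  -- now the pair is a named option.
const ids = tokenizer.encode('What is the capital of France?', {
    text_pair: 'Paris is the capital of France.',
});

// New: inspect string tokens without converting to ids.
const tokens = tokenizer.tokenize('Hello world');
// e.g. ['hello', 'world'] for a WordPiece vocabulary
```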
/**
@@ -2947,7 +3019,7 @@ export class PreTrainedTokenizer extends Callable {
/**
* Decodes a sequence of token IDs back to a string.
*
- * @param {number[]|Tensor} token_ids List/Tensor of token IDs to decode.
+ * @param {number[]|bigint[]|Tensor} token_ids List/Tensor of token IDs to decode.
* @param {Object} [decode_args={}]
* @param {boolean} [decode_args.skip_special_tokens=false] If true, special tokens are removed from the output string.
* @param {boolean} [decode_args.clean_up_tokenization_spaces=true] If true, spaces before punctuations and abbreviated forms are removed.
@@ -2972,7 +3044,7 @@ export class PreTrainedTokenizer extends Callable {
/**
* Decode a single list of token ids to a string.
- * @param {number[]} token_ids List of token ids to decode
+ * @param {number[]|bigint[]} token_ids List of token ids to decode
* @param {Object} decode_args Optional arguments for decoding
* @param {boolean} [decode_args.skip_special_tokens=false] Whether to skip special tokens during decoding
* @param {boolean} [decode_args.clean_up_tokenization_spaces=null] Whether to clean up tokenization spaces during decoding.
@@ -3012,32 +3084,77 @@ export class PreTrainedTokenizer extends Callable {
return decoded;
}
- get default_chat_template() {
- if (!this._warned_about_chat_template) {
- console.warn(
- "No chat template is defined for this tokenizer - using a default chat template " +
- "that implements the ChatML format. If the default is not appropriate for " +
- "your model, please set `tokenizer.chat_template` to an appropriate template. " +
- "See https://huggingface.co/docs/transformers/main/chat_templating for more information."
- )
- this._warned_about_chat_template = true; // TODO move to logger.warning_once()
- }
+ /**
+ * Retrieve the chat template string used for tokenizing chat messages. This template is used
+ * internally by the `apply_chat_template` method and can also be used externally to retrieve the model's chat
+ * template for better generation tracking.
+ *
+ * @param {Object} options An optional object containing the following properties:
+ * @param {string} [options.chat_template=null]
+ * A Jinja template or the name of a template to use for this conversion.
+ * It is usually not necessary to pass anything to this argument,
+ * as the model's template will be used by default.
+ * @param {Object[]} [options.tools=null]
+ * A list of tools (callable functions) that will be accessible to the model. If the template does not
+ * support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
+ * giving the name, description and argument types for the tool. See our
+ * [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
+ * for more information.
+ * @returns {string} The chat template string.
+ */
+ get_chat_template({
+ chat_template = null,
+ tools = null,
+ } = {}) {
- return this._default_chat_template;
+ // First, handle the cases when the model has a dict of multiple templates
+ if (this.chat_template && typeof this.chat_template === 'object') {
+ const template_dict = this.chat_template;
+
+ if (chat_template !== null && Object.hasOwn(template_dict, chat_template)) {
+ // The user can pass the name of a template to the chat template argument instead of an entire template
+ chat_template = template_dict[chat_template];
+ } else if (chat_template === null) {
+ if (tools !== null && 'tool_use' in template_dict) {
+ chat_template = template_dict['tool_use'];
+ } else if ('default' in template_dict) {
+ chat_template = template_dict['default'];
+ } else {
+ throw Error(
+ `This model has multiple chat templates with no default specified! Please either pass a chat ` +
+ `template or the name of the template you wish to use to the 'chat_template' argument. Available ` +
+ `template names are ${Object.keys(template_dict).sort()}.`
+ )
+ }
+ }
+ } else if (chat_template === null) {
+ // These are the cases when the model has a single template
+ // priority: `chat_template` argument > `tokenizer.chat_template`
+ if (this.chat_template) {
+ chat_template = this.chat_template;
+ } else {
+ throw Error(
+ "Cannot use apply_chat_template() because tokenizer.chat_template is not set and no template " +
+ "argument was passed! For information about writing templates and setting the " +
+ "tokenizer.chat_template attribute, please see the documentation at " +
+ "https://huggingface.co/docs/transformers/main/en/chat_templating"
+ )
+ }
+ }
+ return chat_template;
}
/**
* Converts a list of message objects with `"role"` and `"content"` keys to a list of token
* ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to
- * determine the format and control tokens to use when converting. When chat_template is None, it will fall back
- * to the default_chat_template specified at the class level.
+ * determine the format and control tokens to use when converting.
*
* See [here](https://huggingface.co/docs/transformers/chat_templating) for more information.
*
* **Example:** Applying a chat template to a conversation.
*
* ```javascript
- * import { AutoTokenizer } from "@xenova/transformers";
+ * import { AutoTokenizer } from "@huggingface/transformers";
*
* const tokenizer = await AutoTokenizer.from_pretrained("Xenova/mistral-tokenizer-v1");
*
@@ -3054,10 +3171,23 @@ export class PreTrainedTokenizer extends Callable {
* // [1, 733, 16289, 28793, 22557, 28725, 910, 460, 368, 28804, 733, 28748, 16289, 28793, 28737, 28742, 28719, 2548, 1598, 28723, 1602, 541, 315, 1316, 368, 3154, 28804, 2, 28705, 733, 16289, 28793, 315, 28742, 28715, 737, 298, 1347, 805, 910, 10706, 5752, 1077, 3791, 28808, 733, 28748, 16289, 28793]
* ```
*
- * @param {Message[]} conversation A list of message objects with `"role"` and `"content"` keys.
+ * @param {Message[]} conversation A list of message objects with `"role"` and `"content"` keys,
+ * representing the chat history so far.
* @param {Object} options An optional object containing the following properties:
* @param {string} [options.chat_template=null] A Jinja template to use for this conversion. If
- * this is not passed, the model's default chat template will be used instead.
+ * this is not passed, the model's chat template will be used instead.
+ * @param {Object[]} [options.tools=null]
+ * A list of tools (callable functions) that will be accessible to the model. If the template does not
+ * support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
+ * giving the name, description and argument types for the tool. See our
+ * [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
+ * for more information.
+ * @param {Record[]} [options.documents=null]
+ * A list of dicts representing documents that will be accessible to the model if it is performing RAG
+ * (retrieval-augmented generation). If the template does not support RAG, this argument will have no
+ * effect. We recommend that each document should be a dict containing "title" and "text" keys. Please
+ * see the RAG section of the [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG)
+ * for examples of passing documents with chat templates.
* @param {boolean} [options.add_generation_prompt=false] Whether to end the prompt with the token(s) that indicate
* the start of an assistant message. This is useful when you want to generate a response from the model.
* Note that this argument will be passed to the chat template, and so it must be supported in the
@@ -3068,10 +3198,13 @@ export class PreTrainedTokenizer extends Callable {
* @param {number} [options.max_length=null] Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is false.
* If not specified, the tokenizer's `max_length` attribute will be used as a default.
* @param {boolean} [options.return_tensor=true] Whether to return the output as a Tensor or an Array. Has no effect if tokenize is false.
+ * @param {boolean} [options.return_dict=true] Whether to return a dictionary with named outputs. Has no effect if tokenize is false.
* @param {Object} [options.tokenizer_kwargs={}] Additional options to pass to the tokenizer.
- * @returns {string | Tensor | number[]| number[][]} The tokenized output.
+ * @returns {string | Tensor | number[]| number[][]|BatchEncoding} The tokenized output.
*/
apply_chat_template(conversation, {
+ tools = null,
+ documents = null,
chat_template = null,
add_generation_prompt = false,
tokenize = true,
@@ -3079,34 +3212,13 @@ export class PreTrainedTokenizer extends Callable {
truncation = false,
max_length = null,
return_tensor = true,
+ return_dict = false,
tokenizer_kwargs = {},
...kwargs
} = {}) {
- // First, handle the cases when the model has a dict of multiple templates
- if (
- (this.chat_template && typeof this.chat_template === 'object') ||
- (this.chat_template === null && this.default_chat_template && typeof this.default_chat_template === 'object')
- ) {
- const template_dict = this.chat_template ?? this.default_chat_template; // Guaranteed to be a non-null object
+ chat_template = this.get_chat_template({ chat_template, tools });
- if (chat_template !== null && Object.hasOwn(template_dict, chat_template)) {
- // The user can pass the name of a template to the chat template argument instead of an entire template
- chat_template = template_dict[chat_template];
- } else if (chat_template === null && 'default' in template_dict) {
- chat_template = template_dict['default'];
- } else if (chat_template === null) {
- throw Error(
- `This model has multiple chat templates with no default specified! Please either pass a chat ` +
- `template or the name of the template you wish to use to the 'chat_template' argument. Available ` +
- `template names are ${Object.keys(template_dict).sort()}.`
- )
- }
- } else {
- // These are the cases when the model has a single template
- // priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template
- chat_template ??= this.chat_template ?? this.default_chat_template;
- }
if (typeof chat_template !== 'string') {
throw Error(`chat_template must be a string, but got ${typeof chat_template}`);
}
@@ -3128,21 +3240,23 @@ export class PreTrainedTokenizer extends Callable {
const rendered = compiledTemplate.render({
messages: conversation,
- add_generation_prompt: add_generation_prompt,
-
+ add_generation_prompt,
+ tools,
+ documents,
...special_tokens_map,
...kwargs,
});
if (tokenize) {
- return this._call(rendered, {
+ const out = this._call(rendered, {
add_special_tokens: false,
padding,
truncation,
max_length,
return_tensor,
...tokenizer_kwargs,
- }).input_ids;
+ });
+ return return_dict ? out : out.input_ids;
}
return rendered;
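
Together with `get_chat_template`, the new `tools`, `documents` and `return_dict` options bring the method in line with the Python API. A sketch of the new call shape; the tool schema is a made-up example and only has an effect if the resolved template actually supports tool calling:

```javascript
import { AutoTokenizer } from '@huggingface/transformers';

const tokenizer = await AutoTokenizer.from_pretrained('Xenova/mistral-tokenizer-v1');

const messages = [{ role: 'user', content: "What's the weather in Paris?" }];

// Hypothetical JSON-Schema tool definition.
const tools = [{
    name: 'get_weather',
    description: 'Look up the current weather for a city',
    parameters: { type: 'object', properties: { city: { type: 'string' } }, required: ['city'] },
}];

// `return_dict: true` returns { input_ids, attention_mask } rather than just input_ids.
const inputs = tokenizer.apply_chat_template(messages, {
    tools,
    add_generation_prompt: true,
    return_dict: true,
});

// The resolved template string can now also be inspected directly:
const template = tokenizer.get_chat_template({ tools });
```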
@@ -3199,9 +3313,7 @@ export class ElectraTokenizer extends PreTrainedTokenizer {
}
export class T5Tokenizer extends PreTrainedTokenizer { }
-export class GPT2Tokenizer extends PreTrainedTokenizer {
- _default_chat_template = `{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}`
-}
+export class GPT2Tokenizer extends PreTrainedTokenizer { }
export class BartTokenizer extends PreTrainedTokenizer { }
export class MBartTokenizer extends PreTrainedTokenizer {
constructor(tokenizerJSON, tokenizerConfig) {
@@ -3227,35 +3339,16 @@ export class MBart50Tokenizer extends MBartTokenizer { } // NOTE: extends MBartT
export class RobertaTokenizer extends PreTrainedTokenizer { }
-export class BloomTokenizer extends GPT2Tokenizer { // NOTE: `GPT2Tokenizer` to get the correct chat template
-
- constructor(tokenizerJSON, tokenizerConfig) {
- // Override the default (invalid) regex of the pretokenizer.
- // For more information, see https://github.com/xenova/transformers.js/issues/94
- const splitChars = '.,!?\u2026\u3002\uff0c\u3001\u0964\u06d4\u060c';
- const patternObject = tokenizerJSON.pre_tokenizer?.pretokenizers[0]?.pattern;
- if (patternObject && patternObject.Regex === ` ?[^(\\s|[${splitChars}])]+`) {
- patternObject.Regex = ` ?[^\\s${splitChars}]+`;
- }
- super(tokenizerJSON, tokenizerConfig);
- }
-}
+export class BloomTokenizer extends PreTrainedTokenizer { }
const SPIECE_UNDERLINE = "▁";
export class LlamaTokenizer extends PreTrainedTokenizer {
- _default_chat_template = `{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\n' + system_message + '\n< >\n\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\n' + content.strip() + '\n< >\n\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}`
- DEFAULT_SYSTEM_PROMPT =
- "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your " +
- "answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure " +
- "that your responses are socially unbiased and positive in nature.\n\n" +
- "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not " +
- "correct. If you don't know the answer to a question, please don't share false information."
+ padding_side = 'left';
constructor(tokenizerJSON, tokenizerConfig) {
super(tokenizerJSON, tokenizerConfig);
- this.use_default_system_prompt = tokenizerConfig.use_default_system_prompt ?? false;
this.legacy = tokenizerConfig.legacy ?? true;
if (!this.legacy) {
@@ -3288,14 +3381,8 @@ export class LlamaTokenizer extends PreTrainedTokenizer {
}
return tokens;
}
-
- get default_chat_template() {
- return super.default_chat_template
- .replaceAll('USE_DEFAULT_PROMPT', this.use_default_system_prompt ? 'true' : 'false')
- .replaceAll('DEFAULT_SYSTEM_MESSAGE', this.DEFAULT_SYSTEM_PROMPT.replaceAll("\n", "\\n").replaceAll("'", "\\'"));
- }
}
-export class CodeLlamaTokenizer extends LlamaTokenizer { } // NOTE: `LlamaTokenizer` to get the correct chat template
+export class CodeLlamaTokenizer extends PreTrainedTokenizer { }
export class XLMRobertaTokenizer extends PreTrainedTokenizer { }
export class MPNetTokenizer extends PreTrainedTokenizer { }
@@ -3308,9 +3395,7 @@ export class EsmTokenizer extends PreTrainedTokenizer { }
export class Qwen2Tokenizer extends PreTrainedTokenizer { }
-export class GemmaTokenizer extends PreTrainedTokenizer {
- _default_chat_template = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}"
-}
+export class GemmaTokenizer extends PreTrainedTokenizer { }
export class Grok1Tokenizer extends PreTrainedTokenizer { }
@@ -3433,139 +3518,19 @@ export class M2M100Tokenizer extends PreTrainedTokenizer {
}
}
-
-const WHISPER_LANGUAGES = [
- ["en", "english"],
- ["zh", "chinese"],
- ["de", "german"],
- ["es", "spanish"],
- ["ru", "russian"],
- ["ko", "korean"],
- ["fr", "french"],
- ["ja", "japanese"],
- ["pt", "portuguese"],
- ["tr", "turkish"],
- ["pl", "polish"],
- ["ca", "catalan"],
- ["nl", "dutch"],
- ["ar", "arabic"],
- ["sv", "swedish"],
- ["it", "italian"],
- ["id", "indonesian"],
- ["hi", "hindi"],
- ["fi", "finnish"],
- ["vi", "vietnamese"],
- ["he", "hebrew"],
- ["uk", "ukrainian"],
- ["el", "greek"],
- ["ms", "malay"],
- ["cs", "czech"],
- ["ro", "romanian"],
- ["da", "danish"],
- ["hu", "hungarian"],
- ["ta", "tamil"],
- ["no", "norwegian"],
- ["th", "thai"],
- ["ur", "urdu"],
- ["hr", "croatian"],
- ["bg", "bulgarian"],
- ["lt", "lithuanian"],
- ["la", "latin"],
- ["mi", "maori"],
- ["ml", "malayalam"],
- ["cy", "welsh"],
- ["sk", "slovak"],
- ["te", "telugu"],
- ["fa", "persian"],
- ["lv", "latvian"],
- ["bn", "bengali"],
- ["sr", "serbian"],
- ["az", "azerbaijani"],
- ["sl", "slovenian"],
- ["kn", "kannada"],
- ["et", "estonian"],
- ["mk", "macedonian"],
- ["br", "breton"],
- ["eu", "basque"],
- ["is", "icelandic"],
- ["hy", "armenian"],
- ["ne", "nepali"],
- ["mn", "mongolian"],
- ["bs", "bosnian"],
- ["kk", "kazakh"],
- ["sq", "albanian"],
- ["sw", "swahili"],
- ["gl", "galician"],
- ["mr", "marathi"],
- ["pa", "punjabi"],
- ["si", "sinhala"],
- ["km", "khmer"],
- ["sn", "shona"],
- ["yo", "yoruba"],
- ["so", "somali"],
- ["af", "afrikaans"],
- ["oc", "occitan"],
- ["ka", "georgian"],
- ["be", "belarusian"],
- ["tg", "tajik"],
- ["sd", "sindhi"],
- ["gu", "gujarati"],
- ["am", "amharic"],
- ["yi", "yiddish"],
- ["lo", "lao"],
- ["uz", "uzbek"],
- ["fo", "faroese"],
- ["ht", "haitian creole"],
- ["ps", "pashto"],
- ["tk", "turkmen"],
- ["nn", "nynorsk"],
- ["mt", "maltese"],
- ["sa", "sanskrit"],
- ["lb", "luxembourgish"],
- ["my", "myanmar"],
- ["bo", "tibetan"],
- ["tl", "tagalog"],
- ["mg", "malagasy"],
- ["as", "assamese"],
- ["tt", "tatar"],
- ["haw", "hawaiian"],
- ["ln", "lingala"],
- ["ha", "hausa"],
- ["ba", "bashkir"],
- ["jw", "javanese"],
- ["su", "sundanese"],
-]
-
-// @ts-ignore
-const WHISPER_LANGUAGE_MAPPING = new Map(WHISPER_LANGUAGES);
-// @ts-ignore
-const WHISPER_TO_LANGUAGE_CODE_MAPPING = new Map([
- ...WHISPER_LANGUAGES.map(([k, v]) => [v, k]),
- ...[
- ["burmese", "my"],
- ["valencian", "ca"],
- ["flemish", "nl"],
- ["haitian", "ht"],
- ["letzeburgesch", "lb"],
- ["pushto", "ps"],
- ["panjabi", "pa"],
- ["moldavian", "ro"],
- ["moldovan", "ro"],
- ["sinhalese", "si"],
- ["castilian", "es"],
- ]
-]);
-
/**
* WhisperTokenizer tokenizer
* @extends PreTrainedTokenizer
*/
export class WhisperTokenizer extends PreTrainedTokenizer {
- _default_chat_template = `{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}`;
+
+ get timestamp_begin() {
+ return this.model.convert_tokens_to_ids(["<|notimestamps|>"])[0] + 1;
+ }
/**
* Decodes automatic speech recognition (ASR) sequences.
- * @param {Array<{tokens: number[], token_timestamps?: number[], stride: number[]}>} sequences The sequences to decode.
+ * @param {Array<{tokens: bigint[], token_timestamps?: number[], stride: number[]}>} sequences The sequences to decode.
* @param {Object} options The options to use for decoding.
     * @returns {Array<string|{chunks?: undefined|Array<{language: string|null, timestamp: Array<number|null>, text: string}>}>} The decoded sequences.
*/
@@ -3609,7 +3574,7 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
const chunks = [];
let chunk = new_chunk();
let time_offset = 0.0;
- const timestamp_begin = this.model.convert_tokens_to_ids(["<|notimestamps|>"])[0] + 1;
+ const timestamp_begin = this.timestamp_begin;
let previous_tokens = [];
let previous_token_timestamps = [];
@@ -3647,7 +3612,7 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
if (stride_right) {
for (let i = token_ids.length - 1; i >= 0; --i) {
- const token = token_ids[i];
+ const token = Number(token_ids[i]);
if (token >= timestamp_begin) {
                // There can be several tokens in the right stride
// But the last one is ALWAYS going to be skipped
@@ -3665,7 +3630,7 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
// - all tokens within output
for (let i = 0; i < token_ids.length; ++i) {
- const token = token_ids[i];
+ const token = Number(token_ids[i]);
// 4 possible states for each token
// - 1/ Language code
// - 2/ all other special tokens (which we ignore)
@@ -3766,6 +3731,14 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
let end_time;
if (i + 1 < token_timestamps.length) {
end_time = round(token_timestamps[i + 1] + time_offset, 2);
+
+ // Do not allow punctuation-only tokens to have a duration.
+ // This prevents long pauses from messing up the timestamps.
+ const decoded_text = this.decode([token]);
+ if (PUNCTUATION_ONLY_REGEX.test(decoded_text)) {
+ // Add `time_precision` to avoid overlapping timestamps
+ end_time = round(Math.min(start_time + time_precision, end_time), 2);
+ }
} else {
// should never happen
end_time = null;
@@ -3909,7 +3882,9 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
const rightLength = rightSequence.length;
for (let j = 1; j < leftLength + rightLength; ++j) {
- const eps = j / 10000.0;
+            // Slightly convoluted because we don't want out-of-bounds indices.
+            // This will be necessary for a small conflict-resolution optimization later.
const leftStart = Math.max(0, leftLength - j);
const leftStop = Math.min(leftLength, leftLength + rightLength - j);
const left = leftSequence.slice(leftStart, leftStop);
@@ -3919,7 +3894,21 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
if (left.length !== right.length) {
throw new Error("There is a bug within whisper `decode_asr` function, please report it. Dropping to prevent bad inference.");
}
- const matches = left.filter((elem, idx) => elem === right[idx]).length;
+
+ let matches;
+ if (use_token_timestamp_sequences) {
+ // Get length of longest subsequence of tokens that match
+ // and have timestamps that are in order
+ matches = left.filter((elem, idx) => (
+ elem === right[idx]
+ && left_token_timestamp_sequence[leftStart + idx] <= token_timestamp_sequences[i][rightStart + idx]
+ )).length;
+ } else {
+ matches = left.filter((elem, idx) => elem === right[idx]).length;
+ }
+
+ // epsilon to favor long perfect matches
+ const eps = j / 10000.0;
const matching = matches / j + eps;
if (matches > 1 && matching > max) {
max = matching;
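
The scoring above slides one chunk's tokens over the other and keeps the overlap with the highest `matches / j + eps` value, where the epsilon favours longer perfect overlaps. A self-contained sketch of that idea with made-up token values; the windowing of the right sequence is assumed to mirror the left one shown above:

```javascript
function bestOverlap(leftSequence, rightSequence) {
    const leftLength = leftSequence.length;
    const rightLength = rightSequence.length;
    let max = 0.0;
    let best = null;
    for (let j = 1; j < leftLength + rightLength; ++j) {
        const leftStart = Math.max(0, leftLength - j);
        const leftStop = Math.min(leftLength, leftLength + rightLength - j);
        const left = leftSequence.slice(leftStart, leftStop);

        const rightStart = Math.max(0, j - leftLength);
        const rightStop = Math.min(rightLength, j);
        const right = rightSequence.slice(rightStart, rightStop);

        const matches = left.filter((elem, idx) => elem === right[idx]).length;
        const matching = matches / j + j / 10000.0; // epsilon favours long perfect matches
        if (matches > 1 && matching > max) {
            max = matching;
            best = { leftStop, rightStart, matches };
        }
    }
    return best;
}

console.log(bestOverlap([1, 2, 3, 4, 5], [3, 4, 5, 6, 7]));
// { leftStop: 5, rightStart: 0, matches: 3 }  (the chunks overlap on the three shared tokens)
```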
@@ -3999,7 +3988,7 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
) {
let text;
// @ts-ignore
- if (decode_args && decode_args.decode_with_timestamps) {
+ if (decode_args?.decode_with_timestamps) {
if (token_ids instanceof Tensor) {
token_ids = prepareTensorForDecode(token_ids);
}
@@ -4015,7 +4004,7 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
}
/**
- * @param {number[]} token_ids List of token IDs to decode.
+ * @param {number[]|bigint[]} token_ids List of token IDs to decode.
* @param {Object} decode_args Optional arguments for decoding
* @private
*/
@@ -4025,9 +4014,10 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
const timestamp_begin = Array.from(this.all_special_ids).at(-1) + 1;
/**@type {Array} */
let outputs = [[]];
- for (const token of token_ids) {
+ for (let token of token_ids) {
+ token = Number(token);
if (token >= timestamp_begin) {
- const timestamp = round((token - timestamp_begin) * time_precision, 2);
+ const timestamp = ((token - timestamp_begin) * time_precision).toFixed(2);
outputs.push(`<|${timestamp}|>`);
outputs.push([]);
} else {
@@ -4035,13 +4025,7 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
}
}
outputs = outputs.map(
- s => {
- if (typeof s === 'string') {
- return s;
- } else {
- return super.decode(s, decode_args);
- }
- }
+ s => typeof s === 'string' ? s : super.decode(s, decode_args)
)
return outputs.join('');
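
With `.toFixed(2)`, timestamp markers are always rendered with two decimal places (previously `round(..., 2)` produced e.g. `<|0|>` instead of `<|0.00|>`). A usage sketch, assuming `token_ids` is the generated sequence from a Whisper model and therefore contains timestamp tokens:

```javascript
import { AutoTokenizer } from '@huggingface/transformers';

const tokenizer = await AutoTokenizer.from_pretrained('Xenova/whisper-tiny');

// `token_ids` is assumed to come from the model's generate() output (number[] or bigint[]).
const text = tokenizer.decode(token_ids, { decode_with_timestamps: true });
// e.g. "<|0.00|> Hello world.<|2.16|>"
```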
@@ -4192,105 +4176,6 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
newIndices.filter(x => x.length > 0),
]
}
-
- /**
- * Helper function to build translation inputs for a `WhisperTokenizer`,
- * depending on the language, task, and whether to predict timestamp tokens.
- *
- * Used to override the prefix tokens appended to the start of the label sequence.
- *
- * **Example: Get ids for a language**
- * ```javascript
- * // instantiate the tokenizer and set the prefix token to Spanish
- * const tokenizer = await WhisperTokenizer.from_pretrained('Xenova/whisper-tiny');
- * const forced_decoder_ids = tokenizer.get_decoder_prompt_ids({ language: 'spanish' });
- * // [(1, 50262), (2, 50363)]
- * ```
- *
- * @param {Object} options Options to generate the decoder prompt.
- * @param {string} [options.language] The language of the transcription text.
- * The corresponding language id token is appended to the start of the sequence for multilingual
- * speech recognition and speech translation tasks, e.g. for "Spanish" the token "<|es|>" is appended
- * to the start of sequence.
- * @param {string} [options.task] Task identifier to append at the start of sequence (if any).
- * This should be used for mulitlingual fine-tuning, with "transcribe" for speech recognition and
- * "translate" for speech translation.
- * @param {boolean} [options.no_timestamps] Whether to add the <|notimestamps|> token at the start of the sequence.
- * @returns {number[][]} The decoder prompt ids.
- */
- get_decoder_prompt_ids({
- language = null,
- task = null,
- no_timestamps = true,
- } = {}) {
-
- // <|lang_id|> <|task|> <|notimestamps|>
-
- const forced_decoder_ids = [];
-
- if (language) {
- // User wishes to specify the language
- language = language.toLowerCase();
-
- // Map to code from user-friendly name (e.g., "english" -> "en")
- let language_code = WHISPER_TO_LANGUAGE_CODE_MAPPING.get(language);
-
- if (language_code === undefined) {
- // User provided something that is not a language name
-
- if (WHISPER_LANGUAGE_MAPPING.has(language)) {
- // User provided the language code directly (e.g., "en")
- language_code = language;
-
- } else {
- // User provided something that is not a language code or name
- const is_language_code = language.length === 2;
- const langs = is_language_code ? WHISPER_LANGUAGE_MAPPING.keys() : WHISPER_LANGUAGE_MAPPING.values();
-
- throw new Error(`Language "${language}" is not supported. Must be one of: ${JSON.stringify(langs)}`);
- }
- }
-
- const language_token_id = this.model.tokens_to_ids.get(`<|${language_code}|>`);
- if (language_token_id === undefined) {
- throw new Error(`Unable to find language "${language_code}" in model vocabulary. Please report this issue at https://github.com/xenova/transformers.js/issues/new/choose.`)
- }
-
- forced_decoder_ids.push(language_token_id);
- } else {
- // No token will be forced, which leaves the model to predict the language
- forced_decoder_ids.push(null);
- }
-
- if (task) {
- task = task.toLowerCase();
- if (task !== 'transcribe' && task !== 'translate') {
- throw new Error(`Task "${task}" is not supported. Must be one of: ["transcribe", "translate"]`);
- }
-
- const task_token_id = this.model.tokens_to_ids.get(`<|${task}|>`);
- if (task_token_id === undefined) {
- throw new Error(`Unable to find task "${task}" in model vocabulary. Please report this issue at https://github.com/xenova/transformers.js/issues/new/choose.`)
- }
-
- forced_decoder_ids.push(task_token_id);
- } else {
- // No token will be forced, which leaves the model to predict the task
- forced_decoder_ids.push(null);
- }
-
- if (no_timestamps) {
- const no_timestamps_id = this.model.tokens_to_ids.get(`<|notimestamps|>`);
- if (no_timestamps_id === undefined) {
- throw new Error('Unable to find "<|notimestamps|>" in model vocabulary. Please report this issue at https://github.com/xenova/transformers.js/issues/new/choose.')
- }
-
- forced_decoder_ids.push(no_timestamps_id);
- }
-
- return forced_decoder_ids.map((x, i) => [i + 1, x]).filter(x => x[1] !== null);
-
- }
}
export class CodeGenTokenizer extends PreTrainedTokenizer { }
export class CLIPTokenizer extends PreTrainedTokenizer { }
@@ -4351,10 +4236,8 @@ export class MarianTokenizer extends PreTrainedTokenizer {
export class Wav2Vec2CTCTokenizer extends PreTrainedTokenizer { }
-export class BlenderbotTokenizer extends PreTrainedTokenizer {
- _default_chat_template = `{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}`;
-}
-export class BlenderbotSmallTokenizer extends BlenderbotTokenizer { } // NOTE `BlenderbotTokenizer` to get the correct chat template
+export class BlenderbotTokenizer extends PreTrainedTokenizer { }
+export class BlenderbotSmallTokenizer extends PreTrainedTokenizer { }
export class SpeechT5Tokenizer extends PreTrainedTokenizer { }
@@ -4447,7 +4330,6 @@ export class AutoTokenizer {
* @returns {Promise} A new instance of the PreTrainedTokenizer class.
*/
static async from_pretrained(pretrained_model_name_or_path, {
- quantized = true,
progress_callback = null,
config = null,
cache_dir = null,
@@ -4457,7 +4339,6 @@ export class AutoTokenizer {
} = {}) {
const [tokenizerJSON, tokenizerConfig] = await loadTokenizer(pretrained_model_name_or_path, {
- quantized,
progress_callback,
config,
cache_dir,
diff --git a/src/transformers.js b/src/transformers.js
index 9dcd0160c..be7ad176e 100644
--- a/src/transformers.js
+++ b/src/transformers.js
@@ -11,8 +11,8 @@
* @module transformers
*/
+export { env } from './env.js';
export * from './pipelines.js';
-export * from './env.js';
export * from './models.js';
export * from './tokenizers.js';
export * from './processors.js';
@@ -22,3 +22,7 @@ export * from './utils/audio.js';
export * from './utils/image.js';
export * from './utils/tensor.js';
export * from './utils/maths.js';
+
+export * from './generation/streamers.js';
+export * from './generation/stopping_criteria.js';
+
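
Making `env` an explicit named export keeps the public surface intact while avoiding `export *` collisions, and the two new re-exports surface the generation utilities from the package root. Import sketch; the `TextStreamer` name is an assumption about what `generation/streamers.js` exports:

```javascript
import { env, pipeline, TextStreamer } from '@huggingface/transformers';

env.allowLocalModels = false; // configure the environment before building pipelines
```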
diff --git a/src/utils/audio.js b/src/utils/audio.js
index 59c2705db..a1b1326df 100644
--- a/src/utils/audio.js
+++ b/src/utils/audio.js
@@ -14,6 +14,7 @@ import { FFT, max } from './maths.js';
import {
calculateReflectOffset,
} from './core.js';
+import { Tensor, matmul } from './tensor.js';
/**
@@ -78,28 +79,54 @@ export async function read_audio(url, sampling_rate) {
}
/**
- * Generates a Hanning window of length M.
- *
- * @param {number} M The length of the Hanning window to generate.
- * @returns {Float64Array} The generated Hanning window.
+ * Helper function to generate windows that are special cases of the generalized cosine window.
+ * See https://www.mathworks.com/help/signal/ug/generalized-cosine-windows.html for more information.
+ * @param {number} M Number of points in the output window. If zero or less, an empty array is returned.
+ * @param {number} a_0 Offset for the generalized cosine window.
+ * @returns {Float64Array} The generated window.
*/
-export function hanning(M) {
+function generalized_cosine_window(M, a_0) {
if (M < 1) {
return new Float64Array();
}
if (M === 1) {
return new Float64Array([1]);
}
- const denom = M - 1;
- const factor = Math.PI / denom;
+
+ const a_1 = 1 - a_0;
+ const factor = 2 * Math.PI / (M - 1);
+
const cos_vals = new Float64Array(M);
for (let i = 0; i < M; ++i) {
- const n = 2 * i - denom;
- cos_vals[i] = 0.5 + 0.5 * Math.cos(factor * n);
+ cos_vals[i] = a_0 - a_1 * Math.cos(i * factor);
}
return cos_vals;
}
+/**
+ * Generates a Hanning window of length M.
+ * See https://numpy.org/doc/stable/reference/generated/numpy.hanning.html for more information.
+ *
+ * @param {number} M The length of the Hanning window to generate.
+ * @returns {Float64Array} The generated Hanning window.
+ */
+export function hanning(M) {
+ return generalized_cosine_window(M, 0.5);
+}
+
+
+/**
+ * Generates a Hamming window of length M.
+ * See https://numpy.org/doc/stable/reference/generated/numpy.hamming.html for more information.
+ *
+ * @param {number} M The length of the Hamming window to generate.
+ * @returns {Float64Array} The generated Hamming window.
+ */
+export function hamming(M) {
+ return generalized_cosine_window(M, 0.54);
+}
+
+
const HERTZ_TO_MEL_MAPPING = {
"htk": (/** @type {number} */ freq) => 2595.0 * Math.log10(1.0 + (freq / 700.0)),
"kaldi": (/** @type {number} */ freq) => 1127.0 * Math.log(1.0 + (freq / 700.0)),
@@ -427,11 +454,12 @@ function power_to_db(spectrogram, reference = 1.0, min_value = 1e-10, db_range =
* @param {boolean} [options.remove_dc_offset=null] Subtract mean from waveform on each frame, applied before pre-emphasis. This should be set to `true` in
* order to get the same results as `torchaudio.compliance.kaldi.fbank` when computing mel filters.
* @param {number} [options.max_num_frames=null] If provided, limits the number of frames to compute to this value.
+ * @param {number} [options.min_num_frames=null] If provided, ensures the number of frames to compute is at least this value.
* @param {boolean} [options.do_pad=true] If `true`, pads the output spectrogram to have `max_num_frames` frames.
* @param {boolean} [options.transpose=false] If `true`, the returned spectrogram will have shape `(num_frames, num_frequency_bins/num_mel_filters)`. If `false`, the returned spectrogram will have shape `(num_frequency_bins/num_mel_filters, num_frames)`.
- * @returns {{data: Float32Array, dims: number[]}} Spectrogram of shape `(num_frequency_bins, length)` (regular spectrogram) or shape `(num_mel_filters, length)` (mel spectrogram).
+ * @returns {Promise} Spectrogram of shape `(num_frequency_bins, length)` (regular spectrogram) or shape `(num_mel_filters, length)` (mel spectrogram).
*/
-export function spectrogram(
+export async function spectrogram(
waveform,
window,
frame_length,
@@ -452,6 +480,7 @@ export function spectrogram(
remove_dc_offset = null,
// Custom parameters for efficiency reasons
+ min_num_frames = null,
max_num_frames = null,
do_pad = true,
transpose = false,
@@ -489,8 +518,10 @@ export function spectrogram(
}
// split waveform into frames of frame_length size
- const num_frames = Math.floor(1 + Math.floor((waveform.length - frame_length) / hop_length))
-
+ let num_frames = Math.floor(1 + Math.floor((waveform.length - frame_length) / hop_length))
+ if (min_num_frames !== null && num_frames < min_num_frames) {
+ num_frames = min_num_frames
+ }
const num_frequency_bins = onesided ? Math.floor(fft_length / 2) + 1 : fft_length
let d1 = num_frames;
@@ -511,34 +542,43 @@ export function spectrogram(
const fft = new FFT(fft_length);
const inputBuffer = new Float64Array(fft_length);
const outputBuffer = new Float64Array(fft.outputBufferSize);
- const magnitudes = new Array(d1);
+ const transposedMagnitudeData = new Float32Array(num_frequency_bins * d1Max);
for (let i = 0; i < d1; ++i) {
// Populate buffer with waveform data
const offset = i * hop_length;
- for (let j = 0; j < frame_length; ++j) {
+ const buffer_size = Math.min(waveform.length - offset, frame_length);
+ if (buffer_size !== frame_length) {
+ // The full buffer is not needed, so we need to reset it (avoid overflow from previous iterations)
+ // NOTE: We don't need to reset the buffer if it's full since we overwrite the first
+ // `frame_length` values and the rest (`fft_length - frame_length`) remains zero.
+ inputBuffer.fill(0, 0, frame_length);
+ }
+
+ for (let j = 0; j < buffer_size; ++j) {
inputBuffer[j] = waveform[offset + j];
}
if (remove_dc_offset) {
let sum = 0;
- for (let j = 0; j < frame_length; ++j) {
+ for (let j = 0; j < buffer_size; ++j) {
sum += inputBuffer[j];
}
- const mean = sum / frame_length;
- for (let j = 0; j < frame_length; ++j) {
+ const mean = sum / buffer_size;
+ for (let j = 0; j < buffer_size; ++j) {
inputBuffer[j] -= mean;
}
}
if (preemphasis !== null) {
            // Done in reverse to avoid copies and destructive modification
- for (let j = frame_length - 1; j >= 1; --j) {
+ for (let j = buffer_size - 1; j >= 1; --j) {
inputBuffer[j] -= preemphasis * inputBuffer[j - 1];
}
inputBuffer[0] *= 1 - preemphasis;
}
+ // Apply window function
for (let j = 0; j < window.length; ++j) {
inputBuffer[j] *= window[j];
}
@@ -546,74 +586,63 @@ export function spectrogram(
fft.realTransform(outputBuffer, inputBuffer);
// compute magnitudes
- const row = new Array(num_frequency_bins);
- for (let j = 0; j < row.length; ++j) {
+ for (let j = 0; j < num_frequency_bins; ++j) {
const j2 = j << 1;
- row[j] = outputBuffer[j2] ** 2 + outputBuffer[j2 + 1] ** 2;
+
+ // NOTE: We transpose the data here to avoid doing it later
+ transposedMagnitudeData[j * d1Max + i] = outputBuffer[j2] ** 2 + outputBuffer[j2 + 1] ** 2;
}
- magnitudes[i] = row;
}
if (power !== null && power !== 2) {
// slight optimization to not sqrt
const pow = 2 / power; // we use 2 since we already squared
- for (let i = 0; i < magnitudes.length; ++i) {
- const magnitude = magnitudes[i];
- for (let j = 0; j < magnitude.length; ++j) {
- magnitude[j] **= pow;
- }
+ for (let i = 0; i < transposedMagnitudeData.length; ++i) {
+ transposedMagnitudeData[i] **= pow;
}
}
// TODO: What if `mel_filters` is null?
const num_mel_filters = mel_filters.length;
- // Only here do we create Float32Array
- const mel_spec = new Float32Array(num_mel_filters * d1Max);
-
    // Perform matrix multiplication:
// mel_spec = mel_filters @ magnitudes.T
// - mel_filters.shape=(80, 201)
- // - magnitudes.shape=(3000, 201) => - magnitudes.T.shape=(201, 3000)
+ // - magnitudes.shape=(3000, 201) => magnitudes.T.shape=(201, 3000)
// - mel_spec.shape=(80, 3000)
- const dims = transpose ? [d1Max, num_mel_filters] : [num_mel_filters, d1Max];
- for (let i = 0; i < num_mel_filters; ++i) { // num melfilters (e.g., 80)
- const filter = mel_filters[i];
- for (let j = 0; j < d1; ++j) { // num frames (e.g., 3000)
- const magnitude = magnitudes[j];
-
- let sum = 0;
- for (let k = 0; k < num_frequency_bins; ++k) { // num frequency bins (e.g., 201)
- sum += filter[k] * magnitude[k];
- }
+ let mel_spec = await matmul(
+ // TODO: Make `mel_filters` a Tensor during initialization
+ new Tensor('float32', mel_filters.flat(), [num_mel_filters, num_frequency_bins]),
+ new Tensor('float32', transposedMagnitudeData, [num_frequency_bins, d1Max]),
+ );
+ if (transpose) {
+ mel_spec = mel_spec.transpose(1, 0);
+ }
- mel_spec[
- transpose
- ? j * num_mel_filters + i
- : i * d1 + j
- ] = Math.max(mel_floor, sum);
- }
+ const mel_spec_data = /** @type {Float32Array} */(mel_spec.data);
+ for (let i = 0; i < mel_spec_data.length; ++i) {
+ mel_spec_data[i] = Math.max(mel_floor, mel_spec_data[i]);
}
if (power !== null && log_mel !== null) {
- const o = Math.min(mel_spec.length, d1 * num_mel_filters);
+ const o = Math.min(mel_spec_data.length, d1 * num_mel_filters);
+ // NOTE: operates in-place
switch (log_mel) {
case 'log':
for (let i = 0; i < o; ++i) {
- mel_spec[i] = Math.log(mel_spec[i]);
+ mel_spec_data[i] = Math.log(mel_spec_data[i]);
}
break;
case 'log10':
for (let i = 0; i < o; ++i) {
- mel_spec[i] = Math.log10(mel_spec[i]);
+ mel_spec_data[i] = Math.log10(mel_spec_data[i]);
}
break;
case 'dB':
if (power === 1.0) {
- // NOTE: operates in-place
- amplitude_to_db(mel_spec, reference, min_value, db_range);
+ amplitude_to_db(mel_spec_data, reference, min_value, db_range);
} else if (power === 2.0) {
- power_to_db(mel_spec, reference, min_value, db_range);
+ power_to_db(mel_spec_data, reference, min_value, db_range);
} else {
throw new Error(`Cannot use log_mel option '${log_mel}' with power ${power}`)
}
@@ -623,7 +652,7 @@ export function spectrogram(
}
}
- return { data: mel_spec, dims };
+ return mel_spec;
}
/**
@@ -652,6 +681,9 @@ export function window_function(window_length, name, {
case 'hann_window':
window = hanning(length);
break;
+ case 'hamming':
+ window = hamming(length);
+ break;
case 'povey':
window = hanning(length).map(x => Math.pow(x, 0.85));
break;
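
Because the mel projection now goes through `matmul` and returns a `Tensor`, `spectrogram` is async and callers must `await` it. A rough Whisper-style usage sketch; the `mel_filter_bank` argument order is assumed to mirror the Python feature-extraction utilities, and the input here is just placeholder silence:

```javascript
import { spectrogram, window_function, mel_filter_bank } from '@huggingface/transformers';

const waveform = new Float32Array(16000); // 1 second of silence at 16 kHz (placeholder)

// Assumed argument order: (num_frequency_bins, num_mel_filters, min_freq, max_freq, sampling_rate)
const mel_filters = mel_filter_bank(201, 80, 0.0, 8000.0, 16000);
const hannWindow = window_function(400, 'hann_window');

const features = await spectrogram(waveform, hannWindow, 400, 160, {
    fft_length: 400,
    power: 2.0,
    mel_filters,
    log_mel: 'log10',
});

// The result is a Tensor rather than the previous { data, dims } pair.
console.log(features.dims); // [80, 98]  (num_mel_filters x num_frames)
```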
diff --git a/src/utils/constants.js b/src/utils/constants.js
new file mode 100644
index 000000000..9d0e9ee42
--- /dev/null
+++ b/src/utils/constants.js
@@ -0,0 +1,2 @@
+
+export const GITHUB_ISSUE_URL = 'https://github.com/huggingface/transformers.js/issues/new/choose';
\ No newline at end of file
diff --git a/src/utils/core.js b/src/utils/core.js
index 4ed0f15ef..6a6137dff 100644
--- a/src/utils/core.js
+++ b/src/utils/core.js
@@ -42,40 +42,6 @@ export function escapeRegExp(string) {
return string.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); // $& means the whole matched string
}
-/**
- * A base class for creating callable objects.
- *
- * @type {new () => {(...args: any[]): any, _call(...args: any[]): any}}
- */
-export const Callable = /** @type {any} */ (class {
- /**
- * Creates a new instance of the Callable class.
- */
- constructor() {
- /**
- * Creates a closure that delegates to a private method '_call' with the given arguments.
- * @type {any}
- * @param {...any} args Zero or more arguments to pass to the '_call' method.
- * @returns {*} The result of calling the '_call' method.
- */
- let closure = function (...args) {
- return closure._call(...args)
- }
- return Object.setPrototypeOf(closure, new.target.prototype)
- }
-
- /**
- * This method should be implemented in subclasses to provide the
- * functionality of the callable object.
- *
- * @param {any[]} args
- * @throws {Error} If the subclass does not implement the `_call` method.
- */
- _call(...args) {
- throw Error('Must implement _call method in subclass')
- }
-});
-
/**
* Check if a value is a typed array.
* @param {*} val The value to check.
@@ -97,15 +63,6 @@ export function isIntegralNumber(x) {
return Number.isInteger(x) || typeof x === 'bigint'
}
-/**
- * Check if a value is exists.
- * @param {*} x The value to check.
- * @returns {boolean} True if the value exists, false otherwise.
- */
-export function exists(x) {
- return x !== undefined && x !== null;
-}
-
/**
* Calculates the dimensions of a nested array.
*
@@ -173,3 +130,32 @@ export function product(...a) {
export function calculateReflectOffset(i, w) {
return Math.abs((i + w) % (2 * w) - w);
}
+
+/**
+ *
+ * @param {Object} o
+ * @param {string[]} props
+ * @returns {Object}
+ */
+export function pick(o, props) {
+ return Object.assign(
+ {},
+ ...props.map((prop) => {
+ if (o[prop] !== undefined) {
+ return { [prop]: o[prop] };
+ }
+ })
+ );
+}
+
+/**
+ * Calculate the length of a string, taking multi-byte characters into account.
+ * This mimics the behavior of Python's `len` function.
+ * @param {string} s The string to calculate the length of.
+ * @returns {number} The length of the string.
+ */
+export function len(s) {
+ let length = 0;
+ for (const c of s) ++length;
+ return length;
+}
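
Two behaviours worth pinning down for the new helpers: `pick` copies only the listed properties that are actually defined, and `len` counts Unicode code points rather than UTF-16 code units:

```javascript
// Using the helpers defined above in src/utils/core.js:
pick({ padding: true, truncation: undefined, other: 1 }, ['padding', 'truncation']);
// -> { padding: true }   (undefined values and unlisted keys are dropped)

len('😊');   // 1  (one code point)
'😊'.length; // 2  (two UTF-16 code units)
```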
diff --git a/src/utils/data-structures.js b/src/utils/data-structures.js
index dd8a78867..2340d12c0 100644
--- a/src/utils/data-structures.js
+++ b/src/utils/data-structures.js
@@ -22,11 +22,12 @@ export class PriorityQueue {
/**
* Create a new PriorityQueue.
- * @param {Function} comparator Comparator function to determine priority. Defaults to a MaxHeap.
+ * @param {function(any, any): boolean} comparator Comparator function to determine priority. Defaults to a MaxHeap.
*/
- constructor(comparator = (a, b) => a > b) {
+ constructor(comparator = (a, b) => a > b, maxSize = Infinity) {
this._heap = [];
this._comparator = comparator;
+ this._maxSize = maxSize;
}
/**
@@ -68,8 +69,20 @@ export class PriorityQueue {
*/
extend(values) {
for (const value of values) {
- this._heap.push(value);
- this._siftUp();
+ if (this.size < this._maxSize) {
+ this._heap.push(value);
+ this._siftUp();
+ } else {
+ // Get index of value with the lowest priority
+ const smallest = this._smallest();
+
+ // If the new value has higher priority than the smallest value in the heap
+ // then replace the smallest value with the new value and update the heap
+ if (this._comparator(value, this._heap[smallest])) {
+ this._heap[smallest] = value;
+ this._siftUpFrom(smallest);
+ }
+ }
}
return this.size;
}
@@ -160,12 +173,20 @@ export class PriorityQueue {
* @private
*/
_siftUp() {
- let node = this.size - 1;
+ this._siftUpFrom(this.size - 1);
+ }
+
+ /**
+ * Helper function to sift up from a given node.
+ * @param {number} node The index of the node to start sifting up from.
+ */
+ _siftUpFrom(node) {
while (node > 0 && this._greater(node, this._parent(node))) {
this._swap(node, this._parent(node));
node = this._parent(node);
}
}
+
/**
* Maintain the heap property by updating positions in the heap,
* starting at the first element and moving down the heap.
@@ -184,6 +205,15 @@ export class PriorityQueue {
node = maxChild;
}
}
+
+ /**
+ * Get the index of the smallest element in the heap. Since we use an array-based heap,
+ * the index can be computed without needing to traverse the heap.
+ * @private
+ */
+ _smallest() {
+ return (2 ** (Math.floor(Math.log2(this.size))) - 1);
+ }
}
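
The optional `maxSize` turns the queue into a bounded heap: once it is full, a new value is only admitted if it out-prioritises the current lowest-priority element, which keeps the top-k candidates without letting the heap grow. A small usage sketch with the class above:

```javascript
// Max-heap on `score`, capped at 3 elements.
const queue = new PriorityQueue((a, b) => a.score > b.score, 3);

queue.extend([
    { id: 'a', score: 0.1 },
    { id: 'b', score: 0.9 },
    { id: 'c', score: 0.5 },
    { id: 'd', score: 0.7 }, // replaces a lower-priority element instead of growing the heap
]);

console.log(queue.size); // 3
```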
/**
@@ -199,7 +229,7 @@ export class CharTrie {
* @param {string[]} texts The strings to add to the trie.
*/
extend(texts) {
- for (let text of texts) {
+ for (const text of texts) {
this.push(text);
}
}
@@ -210,7 +240,7 @@ export class CharTrie {
*/
push(text) {
let node = this.root;
- for (let ch of text) {
+ for (const ch of text) {
let child = node.children.get(ch);
if (child === undefined) {
child = CharTrieNode.default();
@@ -228,12 +258,14 @@ export class CharTrie {
*/
*commonPrefixSearch(text) {
let node = this.root;
+ if (node === undefined) return;
+
let prefix = "";
- for (let i = 0; i < text.length && node !== undefined; ++i) {
- const ch = text[i];
+ for (const ch of text) {
prefix += ch;
node = node.children.get(ch);
- if (node !== undefined && node.isLeaf) {
+ if (node === undefined) return;
+ if (node.isLeaf) {
yield prefix;
}
}
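
`commonPrefixSearch` lazily yields every stored string that is a prefix of the query (used by the Unigram model to enumerate candidate tokens at a position), and the new guards simply stop the generator as soon as the walk leaves the trie. A small illustration with the class above:

```javascript
const trie = new CharTrie();
trie.extend(['a', 'ab', 'abc', 'b']);

console.log([...trie.commonPrefixSearch('abcdef')]); // ['a', 'ab', 'abc']
console.log([...trie.commonPrefixSearch('xyz')]);    // []
```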
@@ -275,8 +307,8 @@ export class TokenLattice {
* @param {number} eosTokenId The end-of-sequence token ID.
*/
constructor(sentence, bosTokenId, eosTokenId) {
- this.sentence = sentence;
- this.len = sentence.length;
+ this.chars = Array.from(sentence);
+ this.len = this.chars.length;
this.bosTokenId = bosTokenId;
this.eosTokenId = eosTokenId;
this.nodes = [];
@@ -310,7 +342,7 @@ export class TokenLattice {
/**
* Implements the Viterbi algorithm to compute the most likely sequence of tokens.
*
- * @returns {TokenLatticeNode[]} The array of nodes representing the most likely sequence of tokens.
+ * @returns {TokenLatticeNode[]} The most likely sequence of tokens.
*/
viterbi() {
const len = this.len;
@@ -364,11 +396,11 @@ export class TokenLattice {
* @returns {string} The array of nodes representing the most likely sequence of tokens.
*/
piece(node) {
- return this.sentence.slice(node.pos, node.pos + node.length);
+ return this.chars.slice(node.pos, node.pos + node.length).join('');
}
/**
- * @returns {Array} The array of nodes representing the most likely sequence of tokens.
+ * @returns {string[]} The most likely sequence of tokens.
*/
tokens() {
const nodes = this.viterbi();
@@ -376,7 +408,7 @@ export class TokenLattice {
}
/**
- * @returns {Array} The array of nodes representing the most likely sequence of tokens.
+ * @returns {number[]} The most likely sequence of token ids.
*/
tokenIds() {
const nodes = this.viterbi();
diff --git a/src/utils/devices.js b/src/utils/devices.js
new file mode 100644
index 000000000..1086b33e4
--- /dev/null
+++ b/src/utils/devices.js
@@ -0,0 +1,22 @@
+
+/**
+ * The list of devices supported by Transformers.js
+ */
+export const DEVICE_TYPES = Object.freeze({
+ auto: 'auto', // Auto-detect based on device and environment
+ gpu: 'gpu', // Auto-detect GPU
+ cpu: 'cpu', // CPU
+ wasm: 'wasm', // WebAssembly
+ webgpu: 'webgpu', // WebGPU
+ cuda: 'cuda', // CUDA
+ dml: 'dml', // DirectML
+
+ webnn: 'webnn', // WebNN (default)
+ 'webnn-npu': 'webnn-npu', // WebNN NPU
+ 'webnn-gpu': 'webnn-gpu', // WebNN GPU
+ 'webnn-cpu': 'webnn-cpu', // WebNN CPU
+});
+
+/**
+ * @typedef {keyof typeof DEVICE_TYPES} DeviceType
+ */
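
These identifiers are the values a user can request when loading a model; the sketch below assumes the loading APIs (e.g. `pipeline`) accept a `device` option keyed by `DEVICE_TYPES`, which is not shown in this part of the diff:

```javascript
import { pipeline } from '@huggingface/transformers';

// Assumed API shape: run on WebGPU if the environment provides it.
const extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2', {
    device: 'webgpu', // one of the DEVICE_TYPES keys above
});
```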
diff --git a/src/utils/dtypes.js b/src/utils/dtypes.js
new file mode 100644
index 000000000..fa6d94be5
--- /dev/null
+++ b/src/utils/dtypes.js
@@ -0,0 +1,60 @@
+import { apis } from "../env.js";
+
+import { DEVICE_TYPES } from "./devices.js";
+
+// TODO: Use the adapter from `env.backends.onnx.webgpu.adapter` to check for `shader-f16` support,
+// when available in https://github.com/microsoft/onnxruntime/pull/19940.
+// For more information, see https://github.com/microsoft/onnxruntime/pull/19857#issuecomment-1999984753
+
+/**
+ * Checks if WebGPU fp16 support is available in the current environment.
+ */
+export const isWebGpuFp16Supported = (function () {
+ /** @type {boolean} */
+ let cachedResult;
+
+ return async function () {
+ if (cachedResult === undefined) {
+ if (!apis.IS_WEBGPU_AVAILABLE) {
+ cachedResult = false;
+ } else {
+ try {
+ const adapter = await navigator.gpu.requestAdapter();
+ cachedResult = adapter.features.has('shader-f16');
+ } catch (e) {
+ cachedResult = false;
+ }
+ }
+ }
+ return cachedResult;
+ };
+})();
+
+export const DATA_TYPES = Object.freeze({
+ fp32: 'fp32',
+ fp16: 'fp16',
+ q8: 'q8',
+ int8: 'int8',
+ uint8: 'uint8',
+ q4: 'q4',
+ bnb4: 'bnb4',
+ q4f16: 'q4f16', // fp16 model with int4 block weight quantization
+});
+/** @typedef {keyof typeof DATA_TYPES} DataType */
+
+export const DEFAULT_DEVICE_DTYPE_MAPPING = Object.freeze({
+ // NOTE: If not specified, will default to fp32
+ [DEVICE_TYPES.wasm]: DATA_TYPES.q8,
+});
+
+/** @type {Record} */
+export const DEFAULT_DTYPE_SUFFIX_MAPPING = Object.freeze({
+ [DATA_TYPES.fp32]: '',
+ [DATA_TYPES.fp16]: '_fp16',
+ [DATA_TYPES.int8]: '_int8',
+ [DATA_TYPES.uint8]: '_uint8',
+ [DATA_TYPES.q8]: '_quantized',
+ [DATA_TYPES.q4]: '_q4',
+ [DATA_TYPES.q4f16]: '_q4f16',
+ [DATA_TYPES.bnb4]: '_bnb4',
+});
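A hedged sketch of how these helpers compose (the library's real dtype-selection logic lives in the model-loading code, not in this file): fall back to the per-device default, downgrade `fp16` when the WebGPU adapter lacks `shader-f16`, then append the file-name suffix. The function name below is ours, for illustration only.

```js
import {
    isWebGpuFp16Supported,
    DATA_TYPES,
    DEFAULT_DEVICE_DTYPE_MAPPING,
    DEFAULT_DTYPE_SUFFIX_MAPPING,
} from './dtypes.js';

async function resolveModelFileName(baseName, device, requestedDtype) {
    // Fall back to the device default (e.g. q8 on wasm), then to fp32.
    let dtype = requestedDtype ?? DEFAULT_DEVICE_DTYPE_MAPPING[device] ?? DATA_TYPES.fp32;

    // fp16 only makes sense when the WebGPU adapter advertises 'shader-f16'.
    if (dtype === DATA_TYPES.fp16 && device === 'webgpu' && !(await isWebGpuFp16Supported())) {
        dtype = DATA_TYPES.fp32;
    }

    return `${baseName}${DEFAULT_DTYPE_SUFFIX_MAPPING[dtype]}.onnx`; // e.g. "model_fp16.onnx"
}
```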
diff --git a/src/utils/generation.js b/src/utils/generation.js
deleted file mode 100644
index 1f9dc898b..000000000
--- a/src/utils/generation.js
+++ /dev/null
@@ -1,873 +0,0 @@
-
-/**
- * @file Classes, functions, and utilities for generation.
- *
- * @todo Describe how to create a custom `GenerationConfig`.
- *
- * @module utils/generation
- */
-import { Tensor } from './tensor.js';
-import {
- Callable,
- exists,
-} from './core.js';
-import {
- max,
- softmax,
- log_softmax,
- getTopItems,
-} from './maths.js';
-
-/**
- * A class representing a list of logits processors. A logits processor is a function that modifies the logits
- * output of a language model. This class provides methods for adding new processors and applying all processors to a
- * batch of logits.
- *
- * @extends Callable
- */
-export class LogitsProcessorList extends Callable {
- /**
- * Constructs a new instance of `LogitsProcessorList`.
- */
- constructor() {
- super();
- this.processors = [];
- }
-
- /**
- * Adds a new logits processor to the list.
- *
- * @param {LogitsProcessor} item The logits processor function to add.
- */
- push(item) {
- this.processors.push(item);
- }
-
- /**
- * Adds multiple logits processors to the list.
- *
- * @param {LogitsProcessor[]} items The logits processor functions to add.
- */
- extend(items) {
- this.processors.push(...items);
- }
-
- /**
- * Applies all logits processors in the list to a batch of logits, modifying them in-place.
- *
- * @param {number[]} input_ids The input IDs for the language model.
- * @param {number[][]} batchedLogits A 2D array of logits, where each row corresponds to a single
- * input sequence in the batch.
- */
- _call(input_ids, batchedLogits) {
- // NOTE: This is different from the Python code, since vanilla JS does not support vectorized operations.
- // As a result, we apply each processor to each item in the batch.
- for (let logits of batchedLogits) {
- // Modifies logits inplace
- this.processors.forEach(
- func => func(input_ids, logits)
- )
- }
- }
-
- [Symbol.iterator]() {
- return this.processors.values();
- }
-}
-
-/**
- * Base class for processing logits.
- * @extends Callable
- */
-export class LogitsProcessor extends Callable {
- /**
- * Apply the processor to the input logits.
- *
- * @abstract
- * @param {Array} input_ids The input ids.
- * @param {Tensor} logits The logits to process.
- * @throws {Error} Throws an error if `_call` is not implemented in the subclass.
- */
- _call(input_ids, logits) {
- throw Error("`_call` should be implemented in a subclass")
- }
-}
-
-/**
- * A logits processor that forces a specific token to be generated by the decoder.
- *
- * @extends LogitsProcessor
- */
-export class ForceTokensLogitsProcessor extends LogitsProcessor {
- /**
- * Constructs a new instance of `ForceTokensLogitsProcessor`.
- *
- * @param {Array} forced_decoder_ids The ids of tokens that should be forced.
- */
- constructor(forced_decoder_ids) {
- super();
- this.force_token_map = Object.fromEntries(forced_decoder_ids ?? []);
- }
-
- /**
- * Apply the processor to the input logits.
- *
- * @param {Array} input_ids The input ids.
- * @param {Tensor} logits The logits to process.
- * @returns {Tensor} The processed logits.
- */
- _call(input_ids, logits) {
- let map = this.force_token_map[input_ids.length];
- if (exists(map)) { // There exists a mapping
- logits.data.fill(-Infinity)
- logits.data[map] = 0;
- }
- return logits;
- }
-}
-
-/**
- * A LogitsProcessor that forces a BOS token at the beginning of the generated sequence.
- * @extends LogitsProcessor
- */
-export class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
- /**
- * Create a ForcedBOSTokenLogitsProcessor.
- * @param {number} bos_token_id The ID of the beginning-of-sequence token to be forced.
- */
- constructor(bos_token_id) {
- super();
- this.bos_token_id = bos_token_id;
- }
-
- /**
- * Apply the BOS token forcing to the logits.
- * @param {Array} input_ids The input IDs.
- * @param {Object} logits The logits.
- * @returns {Object} The logits with BOS token forcing.
- */
- _call(input_ids, logits) {
- if (input_ids.length === 1) {
- logits.data.fill(-Infinity)
- logits.data[this.bos_token_id] = 0;
- }
- return logits;
- }
-}
-
-/**
- * A logits processor that forces end-of-sequence token probability to 1.
- *
- * @extends LogitsProcessor
- */
-export class ForcedEOSTokenLogitsProcessor extends LogitsProcessor {
- /**
- * Create a ForcedEOSTokenLogitsProcessor.
- * @param {number} max_length Max length of the sequence.
- * @param {number|number[]} forced_eos_token_id The ID of the end-of-sequence token to be forced.
- */
- constructor(max_length, forced_eos_token_id) {
- super();
- this.max_length = max_length;
- this.forced_eos_token_id = forced_eos_token_id;
- }
-
- /**
- * Apply the processor to input_ids and logits.
- *
- * @param {number[]} input_ids The input ids.
- * @param {Tensor} logits The logits tensor.
- */
- _call(input_ids, logits) {
- // console.log('call ForcedEOSTokenLogitsProcessor')
- // TODO
- }
-}
-
-/**
- * A LogitsProcessor that suppresses a list of tokens as soon as the `generate` function starts
- * generating using `begin_index` tokens. This should ensure that the tokens defined by
- * `begin_suppress_tokens` at not sampled at the begining of the generation.
- * @extends LogitsProcessor
- */
-export class SuppressTokensAtBeginLogitsProcessor extends LogitsProcessor {
- /**
- * Create a SuppressTokensAtBeginLogitsProcessor.
- * @param {number[]} begin_suppress_tokens The IDs of the tokens to suppress.
- * @param {number} begin_index The number of tokens to generate before suppressing tokens.
- */
- constructor(begin_suppress_tokens, begin_index) {
- super();
- this.begin_suppress_tokens = begin_suppress_tokens;
- this.begin_index = begin_index;
- }
-
- /**
- * Apply the BOS token forcing to the logits.
- * @param {Array} input_ids The input IDs.
- * @param {Object} logits The logits.
- * @returns {Object} The logits with BOS token forcing.
- */
- _call(input_ids, logits) {
- if (input_ids.length === this.begin_index) {
- for (let token_id of this.begin_suppress_tokens) {
- logits.data[token_id] = -Infinity;
- }
- }
- return logits;
- }
-}
-
-/**
- * A LogitsProcessor that handles adding timestamps to generated text.
- * @extends LogitsProcessor
- */
-export class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
- /**
- * Constructs a new WhisperTimeStampLogitsProcessor.
- * @param {Object} generate_config The config object passed to the `generate()` method of a transformer model.
- * @param {number} generate_config.eos_token_id The ID of the end-of-sequence token.
- * @param {number} generate_config.no_timestamps_token_id The ID of the token used to indicate that a token should not have a timestamp.
- * @param {number[][]} [generate_config.forced_decoder_ids] An array of two-element arrays representing decoder IDs that are forced to appear in the output. The second element of each array indicates whether the token is a timestamp.
- * @param {number} [generate_config.max_initial_timestamp_index] The maximum index at which an initial timestamp can appear.
- */
- constructor(generate_config) {
- super();
- this.eos_token_id = generate_config.eos_token_id;
- this.no_timestamps_token_id = generate_config.no_timestamps_token_id;
- this.timestamp_begin = this.no_timestamps_token_id + 1;
-
- this.begin_index = (generate_config.forced_decoder_ids || []).length + 2;
- if (generate_config.forced_decoder_ids.slice(-1)[0][1] === this.no_timestamps_token_id) {
- this.begin_index -= 1;
- }
- this.max_initial_timestamp_index = generate_config.max_initial_timestamp_index;
-
- }
-
- /**
- * Modify the logits to handle timestamp tokens.
- * @param {Array} input_ids The input sequence of tokens.
- * @param {Tensor} logits The logits output by the model.
- * @returns {Tensor} The modified logits.
- */
- _call(input_ids, logits) {
- const logitsData = /** @type {Float32Array} */(logits.data);
-
- // suppress <|notimestamps|> which is handled by without_timestamps
- logitsData[this.no_timestamps_token_id] = -Infinity;
-
- if (input_ids.length === this.begin_index - 1) {
- logitsData.fill(-Infinity);
- logitsData[this.timestamp_begin] = 0;
- return logits;
- }
-
- // timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
- const seq = input_ids.slice(this.begin_index);
- const last_was_timestamp = seq.length >= 1 && seq[seq.length - 1] >= this.timestamp_begin;
- const penultimate_was_timestamp = seq.length < 2 || seq[seq.length - 2] >= this.timestamp_begin;
-
- if (last_was_timestamp) {
- if (penultimate_was_timestamp) { // has to be non-timestamp
- logitsData.subarray(this.timestamp_begin).fill(-Infinity);
- } else { // cannot be normal text tokens
- logitsData.subarray(0, this.eos_token_id).fill(-Infinity);
- }
- }
-
- // apply the `max_initial_timestamp` option
- if (input_ids.length === this.begin_index && this.max_initial_timestamp_index !== null) {
- const last_allowed = this.timestamp_begin + this.max_initial_timestamp_index;
- logitsData.subarray(last_allowed + 1).fill(-Infinity);
- }
-
- // if sum of probability over timestamps is above any other token, sample timestamp
- const logprobs = log_softmax(logitsData);
- const timestamp_logprob = Math.log(logprobs.subarray(this.timestamp_begin).map(Math.exp).reduce((a, b) => a + b));
- const max_text_token_logprob = max(logprobs.subarray(0, this.timestamp_begin))[0];
-
- if (timestamp_logprob > max_text_token_logprob) {
- logitsData.subarray(0, this.timestamp_begin).fill(-Infinity);
- }
-
- return logits;
- }
-}
-
-/**
- * A logits processor that disallows ngrams of a certain size to be repeated.
- *
- * @extends LogitsProcessor
- */
-export class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
- /**
- * Create a NoRepeatNGramLogitsProcessor.
- * @param {number} no_repeat_ngram_size The no-repeat-ngram size. All ngrams of this size can only occur once.
- */
- constructor(no_repeat_ngram_size) {
- super();
- this.no_repeat_ngram_size = no_repeat_ngram_size;
- }
-
- /**
- * Generate n-grams from a sequence of token ids.
- * @param {number[]} prevInputIds List of previous input ids
- * @returns {Map} Map of generated n-grams
- */
- getNgrams(prevInputIds) {
- const curLen = prevInputIds.length;
-
- /**@type {number[][]} */
- const ngrams = [];
- for (let j = 0; j < curLen + 1 - this.no_repeat_ngram_size; ++j) {
- const ngram = [];
- for (let k = 0; k < this.no_repeat_ngram_size; ++k) {
- ngram.push(prevInputIds[j + k]);
- }
- ngrams.push(ngram);
- }
-
- /** @type {Map} */
- const generatedNgram = new Map();
- for (const ngram of ngrams) {
- const prevNgram = ngram.slice(0, ngram.length - 1);
- const prevNgramKey = JSON.stringify(prevNgram);
- const prevNgramValue = generatedNgram.get(prevNgramKey) ?? [];
- prevNgramValue.push(ngram[ngram.length - 1]);
- generatedNgram.set(prevNgramKey, prevNgramValue);
- }
- return generatedNgram;
- }
-
- /**
- * Generate n-grams from a sequence of token ids.
- * @param {Map} bannedNgrams Map of banned n-grams
- * @param {number[]} prevInputIds List of previous input ids
- * @returns {number[]} Map of generated n-grams
- */
- getGeneratedNgrams(bannedNgrams, prevInputIds) {
- const ngramIdx = prevInputIds.slice(prevInputIds.length + 1 - this.no_repeat_ngram_size, prevInputIds.length);
- const banned = bannedNgrams.get(JSON.stringify(ngramIdx)) ?? [];
- return banned;
- }
-
- /**
- * Calculate banned n-gram tokens
- * @param {number[]} prevInputIds List of previous input ids
- * @returns {number[]} Map of generated n-grams
- */
- calcBannedNgramTokens(prevInputIds) {
- const bannedTokens = [];
- if (prevInputIds.length + 1 < this.no_repeat_ngram_size) {
- // return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
- return bannedTokens;
-
- } else {
- const generatedNgrams = this.getNgrams(prevInputIds);
- const bannedTokens = this.getGeneratedNgrams(generatedNgrams, prevInputIds);
- return bannedTokens;
- }
- }
-
- /**
- * Apply the no-repeat-ngram processor to the logits.
- * @param {Array} input_ids The input IDs.
- * @param {Object} logits The logits.
- * @returns {Object} The logits with no-repeat-ngram processing.
- */
- _call(input_ids, logits) {
- const bannedTokens = this.calcBannedNgramTokens(input_ids);
-
- for (const token of bannedTokens) {
- logits.data[token] = -Infinity;
- }
- return logits;
- }
-}
-
-/**
- * A logits processor that penalises repeated output tokens.
- *
- * @extends LogitsProcessor
- */
-export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
- /**
- * Create a RepetitionPenaltyLogitsProcessor.
- * @param {number} penalty The penalty to apply for repeated tokens.
- */
- constructor(penalty) {
- super();
- this.penalty = penalty;
- }
-
- /**
- * Apply the repetition penalty to the logits.
- * @param {Array} input_ids The input IDs.
- * @param {Object} logits The logits.
- * @returns {Object} The logits with repetition penalty processing.
- */
- _call(input_ids, logits) {
- // Modify the logits corresponding to each element in `input_ids`.
- // As a consequence, the logits corresponding to tokens that appear
- // many times in the output will be penalised more.
- for (const input_id of input_ids) {
- if (logits.data[input_id] < 0) {
- logits.data[input_id] *= this.penalty;
- } else {
- logits.data[input_id] /= this.penalty;
- }
- }
- return logits
- }
-}
-
-/**
- * A logits processor that enforces a minimum number of tokens.
- *
- * @extends LogitsProcessor
- */
-export class MinLengthLogitsProcessor extends LogitsProcessor {
- /**
- * Create a MinLengthLogitsProcessor.
- * @param {number} min_length The minimum length below which the score of `eos_token_id` is set to negative infinity.
- * @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token.
- */
- constructor(min_length, eos_token_id) {
- super();
- this.min_length = min_length;
- this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
- }
-
- /**
- * Apply logit processor.
- * @param {Array} input_ids The input IDs.
- * @param {Object} logits The logits.
- * @returns {Object} The processed logits.
- */
- _call(input_ids, logits) {
- if (input_ids.length < this.min_length) {
- for (const eos_token of this.eos_token_id) {
- logits.data[eos_token] = -Infinity;
- }
- }
-
- return logits
- }
-}
-
-/**
- * A logits processor that enforces a minimum number of new tokens.
- *
- * @extends LogitsProcessor
- */
-export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor {
- /**
- * Create a MinNewTokensLengthLogitsProcessor.
- * @param {number} prompt_length_to_skip The input tokens length.
- * @param {number} min_new_tokens The minimum *new* tokens length below which the score of `eos_token_id` is set to negative infinity.
- * @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token.
- */
- constructor(prompt_length_to_skip, min_new_tokens, eos_token_id) {
- super();
- this.prompt_length_to_skip = prompt_length_to_skip;
- this.min_new_tokens = min_new_tokens;
- this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
- }
-
- /**
- * Apply logit processor.
- * @param {Array} input_ids The input IDs.
- * @param {Object} logits The logits.
- * @returns {Object} The processed logits.
- */
- _call(input_ids, logits) {
- const new_tokens_length = input_ids.length - this.prompt_length_to_skip;
- if (new_tokens_length < this.min_new_tokens) {
- for (const eos_token of this.eos_token_id) {
- logits.data[eos_token] = -Infinity;
- }
- }
-
- return logits
- }
-}
-
-export class NoBadWordsLogitsProcessor extends LogitsProcessor {
- /**
- * Create a `NoBadWordsLogitsProcessor`.
- * @param {number[][]} bad_words_ids List of list of token ids that are not allowed to be generated.
- * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
- */
- constructor(bad_words_ids, eos_token_id) {
- super();
- this.bad_words_ids = bad_words_ids;
- this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
- }
-
- /**
- * Apply logit processor.
- * @param {Array} input_ids The input IDs.
- * @param {Object} logits The logits.
- * @returns {Object} The processed logits.
- */
- _call(input_ids, logits) {
-
- for (const bad_word_ids of this.bad_words_ids) {
- // Whether to modify the logits of the last token in the bad word id sequence
- let mark = true;
-
- // For each bad word in the list, if the current sequence of input ids ends with this sequence (excluding the last),
- // then we set the logits of the last bad word id to -Infinity.
- for (let i = 1; i <= bad_word_ids.length - 1 && bad_word_ids.length < input_ids.length; ++i) {
-
- if (bad_word_ids.at(-i - 1) !== input_ids.at(-i)) {
- // We have found a mismatch
- mark = false;
- break;
- }
- }
- if (mark) {
- logits.data[bad_word_ids.at(-1)] = -Infinity;
- }
- }
-
- return logits
- }
-}
-
-/**
- * @typedef {Object} GenerationConfigType The default configuration parameters.
- * @property {number} [max_length=20] The maximum length the generated tokens can have. Corresponds to the length of the input prompt + `max_new_tokens`. Its effect is overridden by `max_new_tokens`, if also set.
- * @property {number} [max_new_tokens=null] The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
- * @property {number} [min_length=0] The minimum length of the sequence to be generated. Corresponds to the length of the input prompt + `min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set.
- * @property {number} [min_new_tokens=null] The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt.
- * @property {boolean|"never"} [early_stopping=false] Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
- * - `true`, where the generation stops as soon as there are `num_beams` complete candidates;
- * - `false`, where an heuristic is applied and the generation stops when is it very unlikely to find better candidates;
- * - `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm).
- * @property {number} [max_time=null] The maximum amount of time you allow the computation to run for in seconds. Generation will still finish the current pass after allocated time has been passed.
- *
- * @property {boolean} [do_sample=false] Whether or not to use sampling; use greedy decoding otherwise.
- * @property {number} [num_beams=1] Number of beams for beam search. 1 means no beam search.
- * @property {number} [num_beam_groups=1] Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams. See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
- * @property {number} [penalty_alpha=null] The values balance the model confidence and the degeneration penalty in contrastive search decoding.
- * @property {boolean} [use_cache=true] Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding.
- *
- * @property {number} [temperature=1.0] The value used to modulate the next token probabilities.
- * @property {number} [top_k=50] The number of highest probability vocabulary tokens to keep for top-k-filtering.
- * @property {number} [top_p=1.0] If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation.
- * @property {number} [typical_p=1.0] Local typicality measures how similar the conditional probability of predicting a target token next is to the expected conditional probability of predicting a random token next, given the partial text already generated. If set to float < 1, the smallest set of the most locally typical tokens with probabilities that add up to `typical_p` or higher are kept for generation. See [this paper](https://arxiv.org/pdf/2202.00666.pdf) for more details.
- * @property {number} [epsilon_cutoff=0.0] If set to float strictly between 0 and 1, only tokens with a conditional probability greater than `epsilon_cutoff` will be sampled. In the paper, suggested values range from 3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more details.
- * @property {number} [eta_cutoff=0.0] Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to float strictly between 0 and 1, a token is only considered if it is greater than either `eta_cutoff` or `sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits)))`. The latter term is intuitively the expected next token probability, scaled by `sqrt(eta_cutoff)`. In the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model. See [Truncation Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more details.
- * @property {number} [diversity_penalty=0.0] This value is subtracted from a beam's score if it generates a token same as any beam from other group at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is enabled.
- * @property {number} [repetition_penalty=1.0] The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
- * @property {number} [encoder_repetition_penalty=1.0] The paramater for encoder_repetition_penalty. An exponential penalty on sequences that are not in the original input. 1.0 means no penalty.
- * @property {number} [length_penalty=1.0] Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences.
- * @property {number} [no_repeat_ngram_size=0] If set to int > 0, all ngrams of that size can only occur once.
- * @property {number[][]} [bad_words_ids=null] List of token ids that are not allowed to be generated. In order to get the token ids of the words that should not appear in the generated text, use `(await tokenizer(bad_words, {add_prefix_space: true, add_special_tokens: false})).input_ids`.
- * @property {number[][]|number[][][]} [force_words_ids=null] List of token ids that must be generated. If given a `number[][]`, this is treated as a simple list of words that must be included, the opposite to `bad_words_ids`. If given `number[][][]`, this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one can allow different forms of each word.
- * @property {boolean} [renormalize_logits=false] Whether to renormalize the logits after applying all the logits processors or warpers (including the custom ones). It's highly recommended to set this flag to `true` as the search algorithms suppose the score logits are normalized but some logit processors or warpers break the normalization.
- * @property {Object[]} [constraints=null] Custom constraints that can be added to the generation to ensure that the output will contain the use of certain tokens as defined by `Constraint` objects, in the most sensible way possible.
- *
- * @property {number} [forced_bos_token_id=null] The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for multilingual models like mBART where the first generated token needs to be the target language token.
- * @property {number|number[]} [forced_eos_token_id=null] The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a list to set multiple *end-of-sequence* tokens.
- * @property {boolean} [remove_invalid_values=false] Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to crash. Note that using `remove_invalid_values` can slow down generation.
- * @property {number[]} [exponential_decay_length_penalty=null] This Tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty starts and `decay_factor` represents the factor of exponential decay.
- * @property {number[]} [suppress_tokens=null] A list of tokens that will be suppressed at generation. The `SupressTokens` logit processor will set their log probs to `-inf` so that they are not sampled.
- * @property {number[]} [begin_suppress_tokens=null] A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens` logit processor will set their log probs to `-inf` so that they are not sampled.
- * @property {number[][]} [forced_decoder_ids=null] A list of pairs of integers which indicates a mapping from generation indices to token indices that will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token of index 123.
- *
- * @property {number} [num_return_sequences=1] The number of independently computed returned sequences for each element in the batch.
- * @property {boolean} [output_attentions=false] Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details.
- * @property {boolean} [output_hidden_states=false] Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more details.
- * @property {boolean} [output_scores=false] Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
- * @property {boolean} [return_dict_in_generate=false] Whether or not to return a `ModelOutput` instead of a plain tuple.
- *
- * @property {number} [pad_token_id=null] The id of the *padding* token.
- * @property {number} [bos_token_id=null] The id of the *beginning-of-sequence* token.
- * @property {number|number[]} [eos_token_id=null] The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
- *
- * @property {number} [encoder_no_repeat_ngram_size=0] If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the `decoder_input_ids`.
- * @property {number} [decoder_start_token_id=null] If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
- *
- * @property {Object} [generation_kwargs={}] Additional generation kwargs will be forwarded to the `generate` function of the model. Kwargs that are not present in `generate`'s signature will be used in the model forward pass.
- */
-
-/**
- * Class that holds a configuration for a generation task.
- * @type {new (kwargs?: GenerationConfigType) => GenerationConfigType}
- */
-export const GenerationConfig = /** @type {any} */ (class {
-
- /**
- * Create a new GenerationConfig object.
- * @param {GenerationConfigType} kwargs
- */
- constructor(kwargs = {}) {
- // Parameters that control the length of the output
- this.max_length = kwargs.max_length ?? 20;
- this.max_new_tokens = kwargs.max_new_tokens ?? null;
- this.min_length = kwargs.min_length ?? 0;
- this.min_new_tokens = kwargs.min_new_tokens ?? null;
- this.early_stopping = kwargs.early_stopping ?? false;
- this.max_time = kwargs.max_time ?? null;
-
- // Parameters that control the generation strategy used
- this.do_sample = kwargs.do_sample ?? false;
- this.num_beams = kwargs.num_beams ?? 1;
- this.num_beam_groups = kwargs.num_beam_groups ?? 1;
- this.penalty_alpha = kwargs.penalty_alpha ?? null;
- this.use_cache = kwargs.use_cache ?? true;
-
- // Parameters for manipulation of the model output logits
- this.temperature = kwargs.temperature ?? 1.0;
- this.top_k = kwargs.top_k ?? 50;
- this.top_p = kwargs.top_p ?? 1.0;
- this.typical_p = kwargs.typical_p ?? 1.0;
- this.epsilon_cutoff = kwargs.epsilon_cutoff ?? 0.0;
- this.eta_cutoff = kwargs.eta_cutoff ?? 0.0;
- this.diversity_penalty = kwargs.diversity_penalty ?? 0.0;
- this.repetition_penalty = kwargs.repetition_penalty ?? 1.0;
- this.encoder_repetition_penalty = kwargs.encoder_repetition_penalty ?? 1.0;
- this.length_penalty = kwargs.length_penalty ?? 1.0;
- this.no_repeat_ngram_size = kwargs.no_repeat_ngram_size ?? 0;
- this.bad_words_ids = kwargs.bad_words_ids ?? null;
- this.force_words_ids = kwargs.force_words_ids ?? null;
- this.renormalize_logits = kwargs.renormalize_logits ?? false;
- this.constraints = kwargs.constraints ?? null;
- this.forced_bos_token_id = kwargs.forced_bos_token_id ?? null;
- this.forced_eos_token_id = kwargs.forced_eos_token_id ?? null;
- this.remove_invalid_values = kwargs.remove_invalid_values ?? false;
- this.exponential_decay_length_penalty = kwargs.exponential_decay_length_penalty ?? null;
- this.suppress_tokens = kwargs.suppress_tokens ?? null;
- this.begin_suppress_tokens = kwargs.begin_suppress_tokens ?? null;
- this.forced_decoder_ids = kwargs.forced_decoder_ids ?? null;
-
- // Parameters that define the output variables of `generate`
- this.num_return_sequences = kwargs.num_return_sequences ?? 1;
- this.output_attentions = kwargs.output_attentions ?? false;
- this.output_hidden_states = kwargs.output_hidden_states ?? false;
- this.output_scores = kwargs.output_scores ?? false;
- this.return_dict_in_generate = kwargs.return_dict_in_generate ?? false;
-
- // Special tokens that can be used at generation time
- this.pad_token_id = kwargs.pad_token_id ?? null;
- this.bos_token_id = kwargs.bos_token_id ?? null;
- this.eos_token_id = kwargs.eos_token_id ?? null;
-
- // Generation parameters exclusive to encoder-decoder models
- this.encoder_no_repeat_ngram_size = kwargs.encoder_no_repeat_ngram_size ?? 0;
- this.decoder_start_token_id = kwargs.decoder_start_token_id ?? null;
-
- // Wild card
- this.generation_kwargs = kwargs.generation_kwargs ?? {};
- }
-});
-
-/**
- * Sampler is a base class for all sampling methods used for text generation.
- */
-export class Sampler extends Callable {
- /**
- * Creates a new Sampler object with the specified generation config.
- * @param {GenerationConfigType} generation_config The generation config.
- */
- constructor(generation_config) {
- super();
- this.generation_config = generation_config;
- }
-
- /**
- * Executes the sampler, using the specified logits.
- * @param {Tensor} logits
- * @param {number} index
- * @returns {void}
- */
- _call(logits, index = -1) {
- // Sample from logits, of dims [batch, sequence_length, vocab_size].
- // If index is specified, sample from [batch, index, vocab_size].
- return this.sample(logits, index);
- }
-
- /**
- * Abstract method for sampling the logits.
- * @param {Tensor} logits
- * @param {number} index
- * @throws {Error}
- */
- sample(logits, index) {
- throw Error("sample should be implemented in subclasses.")
- }
-
- /**
- * Returns the specified logits as an array, with temperature applied.
- * @param {Tensor} logits
- * @param {number} index
- * @returns {Float32Array}
- */
- getLogits(logits, index) {
- let vocabSize = logits.dims.at(-1);
-
- let logs = /** @type {Float32Array} */(logits.data);
-
- if (index === -1) {
- logs = logs.slice(-vocabSize);
- } else {
- let startIndex = index * vocabSize;
- logs = logs.slice(startIndex, startIndex + vocabSize);
- }
-
- // add temperature
- if (this.generation_config.temperature > 0) {
- logs = logs.map(x => x / this.generation_config.temperature)
- }
- return logs;
- }
-
- /**
- * Selects an item randomly based on the specified probabilities.
- * @param {Array} probabilities An array of probabilities to use for selection.
- * @returns {number} The index of the selected item.
- */
- randomSelect(probabilities) {
- // Return index of chosen item
- let sumProbabilities = probabilities.reduce((acc, curr) => acc + curr, 0);
-
- let r = Math.random() * sumProbabilities;
- for (let i = 0; i < probabilities.length; ++i) {
- r -= probabilities[i];
- if (r <= 0) {
- return i;
- }
- }
- return 0; // return first (most probable) as a fallback
- }
-
- /**
- * Returns a Sampler object based on the specified options.
- * @param {GenerationConfigType} generation_config An object containing options for the sampler.
- * @returns {Sampler} A Sampler object.
- */
- static getSampler(generation_config) {
- // - *greedy decoding*: `num_beams=1` and `do_sample=False`
- // - *contrastive search*: `penalty_alpha>0` and `top_k>1`
- // - *multinomial sampling*: `num_beams=1` and `do_sample=True`
- // - *beam-search decoding*: `num_beams>1` and `do_sample=False`
- // - *beam-search multinomial sampling*: `num_beams>1` and `do_sample=True`
- // - *diverse beam-search decoding*: `num_beams>1` and `num_beam_groups>1`
- // - *constrained beam-search decoding*: `constraints!=None` or `force_words_ids!=None`
-
- // NOTE: beam search is implemented directly into the generation function
- if (generation_config.do_sample) {
- return new MultinomialSampler(generation_config);
-
- } else if (generation_config.num_beams > 1) {
- return new BeamSearchSampler(generation_config);
-
- } else {
- if (generation_config.num_return_sequences > 1) {
- throw Error(`num_return_sequences has to be 1 when doing greedy search, but is ${generation_config.num_return_sequences}.`)
- }
- return new GreedySampler(generation_config);
- }
- }
-}
-
-/**
- * Class representing a Greedy Sampler.
- * @extends Sampler
- */
-class GreedySampler extends Sampler {
- /**
- * Sample the maximum probability of a given logits tensor.
- * @param {Tensor} logits
- * @param {number} [index=-1]
- * @returns {Array} An array with a single tuple, containing the index of the maximum value and a meaningless score (since this is a greedy search).
- */
- sample(logits, index = -1) {
- // NOTE: no need to do log_softmax here since we only take the maximum
- let logs = this.getLogits(logits, index);
- let argmax = max(logs)[1];
-
- // Note: score is meaningless in this context, since we are performing
- // greedy search (p = 1 => log(p) = 0)
- return [
- [argmax, 0]
- ];
- }
-}
-
-/**
- * Class representing a MultinomialSampler.
- * @extends Sampler
- */
-class MultinomialSampler extends Sampler {
-
- /**
- * Sample from the logits.
- * @param {Tensor} logits
- * @param {number} index
- * @returns {Array}
- */
- sample(logits, index = -1) {
- let k = logits.dims.at(-1); // defaults to vocab size
- if (this.generation_config.top_k > 0) {
- k = Math.min(this.generation_config.top_k, k);
- }
-
- // Get logits of nth token
- const logs = this.getLogits(logits, index);
-
- // Get top k tokens
- const topLogits = getTopItems(logs, k);
-
- // Compute softmax over logits
- const probabilities = softmax(topLogits.map(x => x[1]));
-
- return Array.from({ length: this.generation_config.num_beams }, () => {
- const sampledIndex = this.randomSelect(probabilities);
- return [
- topLogits[sampledIndex][0], // token id
- Math.log(probabilities[sampledIndex]), // score
- ];
- });
- }
-}
-
-
-/**
- * Class representing a BeamSearchSampler.
- * @extends Sampler
- */
-class BeamSearchSampler extends Sampler {
-
- /**
- * Sample from the logits.
- * @param {Tensor} logits
- * @param {number} index
- * @returns {Array}
- */
- sample(logits, index = -1) {
- let k = logits.dims.at(-1); // defaults to vocab size
- if (this.generation_config.top_k > 0) {
- k = Math.min(this.generation_config.top_k, k);
- }
-
- // Get logits of nth token
- const logs = this.getLogits(logits, index);
-
- // Get top k tokens
- const topLogits = getTopItems(logs, k);
-
- // Compute softmax over logits
- const probabilities = softmax(topLogits.map(x => x[1]));
-
- return Array.from({ length: this.generation_config.num_beams }, (_, i) => {
- return [
- topLogits[i][0], // token id
- Math.log(probabilities[i]), // score
- ];
- });
- }
-}
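For readers tracing where this logic went: the n-gram banning implemented by the removed `NoRepeatNGramLogitsProcessor` can be summarised in a small standalone sketch (the function name is ours, not the library's). After generating `[5, 8, 5]` with `no_repeat_ngram_size = 2`, the bigram starting with the last token `5` has already been followed by `8`, so `8` is banned for the next step.

```js
// Standalone illustration of the removed n-gram banning logic.
function bannedNextTokens(prevInputIds, n) {
    // Map from JSON-encoded (n-1)-token prefix to the tokens that followed it.
    const generated = new Map();
    for (let j = 0; j <= prevInputIds.length - n; ++j) {
        const ngram = prevInputIds.slice(j, j + n);
        const key = JSON.stringify(ngram.slice(0, -1));
        generated.set(key, [...(generated.get(key) ?? []), ngram.at(-1)]);
    }
    // Tokens banned for the next position are those that previously followed the current prefix.
    const prefix = prevInputIds.slice(prevInputIds.length - (n - 1));
    return generated.get(JSON.stringify(prefix)) ?? [];
}

console.log(bannedNextTokens([5, 8, 5], 2)); // [8]
```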
diff --git a/src/utils/generic.js b/src/utils/generic.js
new file mode 100644
index 000000000..5ccd467ad
--- /dev/null
+++ b/src/utils/generic.js
@@ -0,0 +1,35 @@
+
+/**
+ * A base class for creating callable objects.
+ * See [here](https://stackoverflow.com/q/76073890) for more information.
+ *
+ * @type {new () => {(...args: any[]): any, _call(...args: any[]): any}}
+ */
+export const Callable = /** @type {any} */ (class {
+ /**
+ * Creates a new instance of the Callable class.
+ */
+ constructor() {
+ /**
+ * Creates a closure that delegates to a private method '_call' with the given arguments.
+ * @type {any}
+ * @param {...any} args Zero or more arguments to pass to the '_call' method.
+ * @returns {*} The result of calling the '_call' method.
+ */
+ let closure = function (...args) {
+ return closure._call(...args)
+ }
+ return Object.setPrototypeOf(closure, new.target.prototype)
+ }
+
+ /**
+ * This method should be implemented in subclasses to provide the
+ * functionality of the callable object.
+ *
+ * @param {any[]} args
+ * @throws {Error} If the subclass does not implement the `_call` method.
+ */
+ _call(...args) {
+ throw Error('Must implement _call method in subclass')
+ }
+});
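A quick illustration of the `Callable` pattern: subclasses implement `_call`, and instances can then be invoked like plain functions while keeping their prototype chain.

```js
import { Callable } from './generic.js';

class Doubler extends Callable {
    _call(x) {
        return 2 * x;
    }
}

const double = new Doubler();
console.log(double(21));                // 42 (the call is forwarded to `_call`)
console.log(double instanceof Doubler); // true (the constructor preserves the prototype)
```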
diff --git a/src/utils/hub.js b/src/utils/hub.js
old mode 100644
new mode 100755
index 32cab6c5b..71c20c861
--- a/src/utils/hub.js
+++ b/src/utils/hub.js
@@ -13,9 +13,8 @@ import { dispatchCallback } from './core.js';
/**
* @typedef {Object} PretrainedOptions Options for loading a pretrained model.
- * @property {boolean?} [quantized=true] Whether to load the 8-bit quantized version of the model (only applicable when loading model files).
* @property {function} [progress_callback=null] If specified, this function will be called during model construction, to provide the user with progress updates.
- * @property {Object} [config=null] Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
+ * @property {import('../configs.js').PretrainedConfig} [config=null] Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
* - The model is a model provided by the library (loaded with the *model id* string of a pretrained model).
* - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named *config.json* is found in the directory.
* @property {string} [cache_dir=null] Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
@@ -23,24 +22,39 @@ import { dispatchCallback } from './core.js';
* @property {string} [revision='main'] The specific model version to use. It can be a branch name, a tag name, or a commit id,
* since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
* NOTE: This setting is ignored for local requests.
+ */
+
+/**
+ * @typedef {Object} ModelSpecificPretrainedOptions Options for loading a pretrained model.
+ * @property {string} [subfolder='onnx'] In case the relevant files are located inside a subfolder of the model repo on huggingface.co,
+ * you can specify the folder name here.
* @property {string} [model_file_name=null] If specified, load the model with this name (excluding the .onnx suffix). Currently only valid for encoder- or decoder-only models.
+ * @property {import("./devices.js").DeviceType|Record} [device=null] The device to run the model on. If not specified, the device will be chosen from the environment settings.
+ * @property {import("./dtypes.js").DataType|Record} [dtype=null] The data type to use for the model. If not specified, the data type will be chosen from the environment settings.
+ * @property {boolean|Record} [use_external_data_format=false] Whether to load the model using the external data format (used for models >= 2GB in size).
+ * @property {import('onnxruntime-common').InferenceSession.SessionOptions} [session_options] (Optional) User-specified session options passed to the runtime. If not provided, suitable defaults will be chosen.
*/
+/**
+ * @typedef {PretrainedOptions & ModelSpecificPretrainedOptions} PretrainedModelOptions Options for loading a pretrained model.
+ */
+
+/**
+ * Mapping from file extensions to MIME types.
+ */
+const CONTENT_TYPE_MAP = {
+ 'txt': 'text/plain',
+ 'html': 'text/html',
+ 'css': 'text/css',
+ 'js': 'text/javascript',
+ 'json': 'application/json',
+ 'png': 'image/png',
+ 'jpg': 'image/jpeg',
+ 'jpeg': 'image/jpeg',
+ 'gif': 'image/gif',
+}
class FileResponse {
- /**
- * Mapping from file extensions to MIME types.
- */
- _CONTENT_TYPE_MAP = {
- 'txt': 'text/plain',
- 'html': 'text/html',
- 'css': 'text/css',
- 'js': 'text/javascript',
- 'json': 'application/json',
- 'png': 'image/png',
- 'jpg': 'image/jpeg',
- 'jpeg': 'image/jpeg',
- 'gif': 'image/gif',
- }
+
/**
* Creates a new `FileResponse` object.
* @param {string|URL} filePath
@@ -83,7 +97,7 @@ class FileResponse {
updateContentType() {
// Set content-type header based on file extension
const extension = this.filePath.toString().split('.').pop().toLowerCase();
- this.headers.set('content-type', this._CONTENT_TYPE_MAP[extension] ?? 'application/octet-stream');
+ this.headers.set('content-type', CONTENT_TYPE_MAP[extension] ?? 'application/octet-stream');
}
/**
@@ -323,7 +337,7 @@ async function tryCache(cache, ...names) {
* @param {PretrainedOptions} [options] An object containing optional parameters.
*
* @throws Will throw an error if the file is not found and `fatal` is true.
- * @returns {Promise} A Promise that resolves with the file content as a buffer.
+ * @returns {Promise} A Promise that resolves with the file content as a buffer.
*/
export async function getModelFile(path_or_repo_id, filename, fatal = true, options = {}) {
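A hedged example of the expanded loading options (the model id and per-file keys are illustrative; the `Record` forms of `device`/`dtype` are what allow overriding individual model files):

```js
import { AutoModelForSeq2SeqLM } from '@huggingface/transformers';

const model = await AutoModelForSeq2SeqLM.from_pretrained('Xenova/t5-small', {
    device: 'webgpu',          // a DeviceType, or a Record keyed by model file name
    dtype: {                   // per-file dtype overrides (keys are illustrative)
        encoder_model: 'fp16',
        decoder_model_merged: 'q4',
    },
    use_external_data_format: false,
});
```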
diff --git a/src/utils/image.js b/src/utils/image.js
index 89bb2481b..33bdf11d8 100644
--- a/src/utils/image.js
+++ b/src/utils/image.js
@@ -39,7 +39,7 @@ if (BROWSER_ENV) {
const metadata = await img.metadata();
const rawChannels = metadata.channels;
- let { data, info } = await img.rotate().raw().toBuffer({ resolveWithObject: true });
+ const { data, info } = await img.rotate().raw().toBuffer({ resolveWithObject: true });
const newImage = new RawImage(new Uint8ClampedArray(data), info.width, info.height, info.channels);
if (rawChannels !== undefined && rawChannels !== info.channels) {
@@ -125,6 +125,20 @@ export class RawImage {
}
}
+ /**
+ * Read an image from a canvas.
+ * @param {HTMLCanvasElement|OffscreenCanvas} canvas The canvas to read the image from.
+ * @returns {RawImage} The image object.
+ */
+ static fromCanvas(canvas) {
+ if (!BROWSER_ENV) {
+ throw new Error('fromCanvas() is only supported in browser environments.')
+ }
+
+ const ctx = canvas.getContext('2d');
+ const data = ctx.getImageData(0, 0, canvas.width, canvas.height).data;
+ return new RawImage(data, canvas.width, canvas.height, 4);
+ }
/**
* Read an image from a URL or file path.
@@ -132,11 +146,11 @@ export class RawImage {
     * @returns {Promise<RawImage>} The image object.
*/
static async fromURL(url) {
- let response = await getFile(url);
+ const response = await getFile(url);
if (response.status !== 200) {
throw new Error(`Unable to read image from "${url}" (${response.status} ${response.statusText})`);
}
- let blob = await response.blob();
+ const blob = await response.blob();
return this.fromBlob(blob);
}
@@ -148,7 +162,7 @@ export class RawImage {
static async fromBlob(blob) {
if (BROWSER_ENV) {
// Running in environment with canvas
- let img = await loadImageFunction(blob);
+ const img = await loadImageFunction(blob);
const ctx = createCanvasFunction(img.width, img.height).getContext('2d');
@@ -159,7 +173,7 @@ export class RawImage {
} else {
             // Use sharp.js to read (and possibly resize) the image.
- let img = sharp(await blob.arrayBuffer());
+ const img = sharp(await blob.arrayBuffer());
return await loadImageFunction(img);
}
@@ -204,7 +218,7 @@ export class RawImage {
return this;
}
- let newData = new Uint8ClampedArray(this.width * this.height * 1);
+ const newData = new Uint8ClampedArray(this.width * this.height * 1);
switch (this.channels) {
case 3: // rgb to grayscale
case 4: // rgba to grayscale
@@ -231,7 +245,7 @@ export class RawImage {
return this;
}
- let newData = new Uint8ClampedArray(this.width * this.height * 3);
+ const newData = new Uint8ClampedArray(this.width * this.height * 3);
switch (this.channels) {
case 1: // grayscale to rgb
@@ -264,7 +278,7 @@ export class RawImage {
return this;
}
- let newData = new Uint8ClampedArray(this.width * this.height * 4);
+ const newData = new Uint8ClampedArray(this.width * this.height * 4);
switch (this.channels) {
case 1: // grayscale to rgba
@@ -309,10 +323,10 @@ export class RawImage {
// TODO use `resample` in browser environment
// Store number of channels before resizing
- let numChannels = this.channels;
+ const numChannels = this.channels;
// Create canvas object for this image
- let canvas = this.toCanvas();
+ const canvas = this.toCanvas();
// Actually perform resizing using the canvas API
const ctx = createCanvasFunction(width, height).getContext('2d');
@@ -321,7 +335,7 @@ export class RawImage {
ctx.drawImage(canvas, 0, 0, width, height);
// Create image from the resized data
- let resizedImage = new RawImage(ctx.getImageData(0, 0, width, height).data, width, height, 4);
+ const resizedImage = new RawImage(ctx.getImageData(0, 0, width, height).data, width, height, 4);
// Convert back so that image has the same number of channels as before
return resizedImage.convert(numChannels);
@@ -380,13 +394,13 @@ export class RawImage {
if (BROWSER_ENV) {
// Store number of channels before padding
- let numChannels = this.channels;
+ const numChannels = this.channels;
// Create canvas object for this image
- let canvas = this.toCanvas();
+ const canvas = this.toCanvas();
- let newWidth = this.width + left + right;
- let newHeight = this.height + top + bottom;
+ const newWidth = this.width + left + right;
+ const newHeight = this.height + top + bottom;
// Create a new canvas of the desired size.
const ctx = createCanvasFunction(newWidth, newHeight).getContext('2d');
@@ -398,7 +412,7 @@ export class RawImage {
);
// Create image from the padded data
- let paddedImage = new RawImage(
+ const paddedImage = new RawImage(
ctx.getImageData(0, 0, newWidth, newHeight).data,
newWidth, newHeight, 4);
@@ -406,7 +420,7 @@ export class RawImage {
return paddedImage.convert(numChannels);
} else {
- let img = this.toSharp().extend({ left, right, top, bottom });
+ const img = this.toSharp().extend({ left, right, top, bottom });
return await loadImageFunction(img);
}
}
@@ -470,16 +484,16 @@ export class RawImage {
}
// Determine bounds of the image in the new canvas
- let width_offset = (this.width - crop_width) / 2;
- let height_offset = (this.height - crop_height) / 2;
+ const width_offset = (this.width - crop_width) / 2;
+ const height_offset = (this.height - crop_height) / 2;
if (BROWSER_ENV) {
// Store number of channels before resizing
- let numChannels = this.channels;
+ const numChannels = this.channels;
// Create canvas object for this image
- let canvas = this.toCanvas();
+ const canvas = this.toCanvas();
// Create a new canvas of the desired size. This is needed since if the
// image is too small, we need to pad it with black pixels.
@@ -509,7 +523,7 @@ export class RawImage {
);
// Create image from the resized data
- let resizedImage = new RawImage(ctx.getImageData(0, 0, crop_width, crop_height).data, crop_width, crop_height, 4);
+ const resizedImage = new RawImage(ctx.getImageData(0, 0, crop_width, crop_height).data, crop_width, crop_height, 4);
// Convert back so that image has the same number of channels as before
return resizedImage.convert(numChannels);
@@ -529,8 +543,8 @@ export class RawImage {
} else if (width_offset <= 0 && height_offset <= 0) {
// Cropped image lies entirely outside the original image,
// so we add padding
- let top = Math.floor(-height_offset);
- let left = Math.floor(-width_offset);
+ const top = Math.floor(-height_offset);
+ const left = Math.floor(-width_offset);
img = img.extend({
top: top,
left: left,
@@ -611,13 +625,13 @@ export class RawImage {
// Clone, and convert data to RGBA before drawing to canvas.
// This is because the canvas API only supports RGBA
- let cloned = this.clone().rgba();
+ const cloned = this.clone().rgba();
// Create canvas object for the cloned image
- let clonedCanvas = createCanvasFunction(cloned.width, cloned.height);
+ const clonedCanvas = createCanvasFunction(cloned.width, cloned.height);
// Draw image to context
- let data = new ImageDataClass(cloned.data, cloned.width, cloned.height);
+ const data = new ImageDataClass(cloned.data, cloned.width, cloned.height);
clonedCanvas.getContext('2d').putImageData(data, 0, 0);
return clonedCanvas;
@@ -728,4 +742,4 @@ export class RawImage {
}
});
}
-}
+}
\ No newline at end of file
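A browser-only sketch of the new `RawImage.fromCanvas()` helper: draw to a canvas, then wrap its RGBA pixels in a `RawImage` for further processing.

```js
import { RawImage } from './image.js';

const canvas = document.createElement('canvas');
canvas.width = canvas.height = 64;
const ctx = canvas.getContext('2d');
ctx.fillStyle = 'rebeccapurple';
ctx.fillRect(0, 0, 64, 64);

const image = RawImage.fromCanvas(canvas);
console.log(image.width, image.height, image.channels); // 64 64 4
```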
diff --git a/src/utils/maths.js b/src/utils/maths.js
index 319f4a347..e6cb2d6ca 100644
--- a/src/utils/maths.js
+++ b/src/utils/maths.js
@@ -190,27 +190,6 @@ export function dot(arr1, arr2) {
return result;
}
-
-/**
- * Get the top k items from an iterable, sorted by descending order
- * @param {any[]|TypedArray} items The items to be sorted
- * @param {number|null} [top_k=0] The number of top items to return (default: 0 = return all)
- * @returns {[number, any][]} The top k items, sorted by descending order
- */
-export function getTopItems(items, top_k = 0) {
- // if top == 0, return all
-
- items = Array.from(items)
- .map((x, i) => [i, x]) // Get indices ([index, score])
- .sort((a, b) => b[1] - a[1]) // Sort by log probabilities
-
- if (top_k !== null && top_k > 0) {
- items = items.slice(0, top_k); // Get top k items
- }
-
- return items
-}
-
/**
* Computes the cosine similarity between two arrays.
*
@@ -247,7 +226,7 @@ export function magnitude(arr) {
/**
* Returns the value and index of the minimum element in an array.
* @param {number[]|TypedArray} arr array of numbers.
- * @returns {number[]} the value and index of the minimum element, of the form: [valueOfMin, indexOfMin]
+ * @returns {[number, number]} the value and index of the minimum element, of the form: [valueOfMin, indexOfMin]
* @throws {Error} If array is empty.
*/
export function min(arr) {
@@ -992,3 +971,89 @@ export function bankers_round(x) {
const br = Math.abs(x) % 1 === 0.5 ? (r % 2 === 0 ? r : r - 1) : r;
return br;
}
+
+
+/**
+ * Measures similarity between two temporal sequences (e.g., input audio and output tokens
+ * to generate token-level timestamps).
+ * @param {number[][]} matrix
+ * @returns {number[][]}
+ */
+export function dynamic_time_warping(matrix) {
+ const output_length = matrix.length;
+ const input_length = matrix[0].length;
+
+ const outputShape = [output_length + 1, input_length + 1];
+
+ const cost = Array.from(
+ { length: outputShape[0] },
+ () => Array(outputShape[1]).fill(Infinity)
+ );
+ cost[0][0] = 0;
+
+ const trace = Array.from(
+ { length: outputShape[0] },
+ () => Array(outputShape[1]).fill(-1)
+ );
+
+ for (let j = 1; j < outputShape[1]; ++j) {
+ for (let i = 1; i < outputShape[0]; ++i) {
+ const c0 = cost[i - 1][j - 1];
+ const c1 = cost[i - 1][j];
+ const c2 = cost[i][j - 1];
+
+ let c, t;
+ if (c0 < c1 && c0 < c2) {
+ c = c0;
+ t = 0;
+ } else if (c1 < c0 && c1 < c2) {
+ c = c1;
+ t = 1;
+ } else {
+ c = c2;
+ t = 2;
+ }
+ cost[i][j] = matrix[i - 1][j - 1] + c;
+ trace[i][j] = t;
+ }
+ }
+
+ for (let i = 0; i < outputShape[1]; ++i) { // trace[0, :] = 2
+ trace[0][i] = 2;
+ }
+ for (let i = 0; i < outputShape[0]; ++i) { // trace[:, 0] = 1
+ trace[i][0] = 1;
+ }
+
+ // backtrace
+ let i = output_length;
+ let j = input_length;
+ let text_indices = [];
+ let time_indices = [];
+ while (i > 0 || j > 0) {
+ text_indices.push(i - 1);
+ time_indices.push(j - 1);
+
+ switch (trace[i][j]) {
+ case 0:
+ --i; --j;
+ break;
+ case 1:
+ --i;
+ break;
+ case 2:
+ --j;
+ break;
+ default:
+ throw new Error(
+ `Internal error in dynamic time warping. Unexpected trace[${i}, ${j}]. Please file a bug report.`
+ )
+ }
+ }
+
+ text_indices.reverse();
+ time_indices.reverse();
+
+ return [text_indices, time_indices];
+
+}
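A toy worked example of `dynamic_time_warping` (the costs are made up; lower means a better token/frame match): token 0 aligns to frame 0, and token 1 stretches across frames 1 and 2.

```js
import { dynamic_time_warping } from './maths.js';

const cost = [
    [0, 1, 1], // output token 0 vs. input frames 0..2
    [1, 0, 0], // output token 1 vs. input frames 0..2
];

const [text_indices, time_indices] = dynamic_time_warping(cost);
console.log(text_indices); // [0, 1, 1]
console.log(time_indices); // [0, 1, 2]
```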
diff --git a/src/utils/tensor.js b/src/utils/tensor.js
index 469054cac..536a8c249 100644
--- a/src/utils/tensor.js
+++ b/src/utils/tensor.js
@@ -1,22 +1,26 @@
/**
* @file Helper module for `Tensor` processing.
- *
- * These functions and classes are only used internally,
+ *
+ * These functions and classes are only used internally,
* meaning an end-user shouldn't need to access anything here.
- *
+ *
* @module utils/tensor
*/
-import { ONNX } from '../backends/onnx.js';
-
import {
interpolate_data,
permute_data
} from './maths.js';
+import {
+ Tensor as ONNXTensor, isONNXTensor,
+} from '../backends/onnx.js';
+
+import { TensorOpRegistry } from '../ops/registry.js';
const DataTypeMap = Object.freeze({
float32: Float32Array,
+ float16: Uint16Array,
float64: Float64Array,
string: Array, // string[]
int8: Int8Array,
@@ -35,37 +39,55 @@ const DataTypeMap = Object.freeze({
* @typedef {import('./maths.js').AnyTypedArray | any[]} DataArray
*/
-const ONNXTensor = ONNX.Tensor;
export class Tensor {
/** @type {number[]} Dimensions of the tensor. */
- dims;
+ get dims() {
+ // @ts-ignore
+ return this.ort_tensor.dims;
+ }
+ set dims(value) {
+ // FIXME: ONNXTensor declares dims as readonly so one needs to use the constructor() if dims change.
+ // @ts-ignore
+ this.ort_tensor.dims = value;
+ }
/** @type {DataType} Type of the tensor. */
- type;
+ get type() {
+ return this.ort_tensor.type;
+ };
/** @type {DataArray} The data stored in the tensor. */
- data;
+ get data() {
+ return this.ort_tensor.data;
+ }
/** @type {number} The number of elements in the tensor. */
- size;
+ get size() {
+ return this.ort_tensor.size;
+ };
+
+ /** @type {string} The location of the tensor data. */
+ get location() {
+ return this.ort_tensor.location;
+ };
+
+ ort_tensor;
/**
* Create a new Tensor or copy an existing Tensor.
- * @param {[DataType, DataArray, number[]]|[import('onnxruntime-common').Tensor]} args
+ * @param {[DataType, DataArray, number[]]|[ONNXTensor]} args
*/
constructor(...args) {
- if (args[0] instanceof ONNXTensor) {
- // Create shallow copy
- Object.assign(this, args[0]);
-
+ if (isONNXTensor(args[0])) {
+ this.ort_tensor = /** @type {ONNXTensor} */ (args[0]);
} else {
// Create new tensor
- Object.assign(this, new ONNXTensor(
+ this.ort_tensor = new ONNXTensor(
/** @type {DataType} */(args[0]),
/** @type {Exclude} */(args[1]),
args[2]
- ));
+ );
}
return new Proxy(this, {
@@ -89,6 +111,11 @@ export class Tensor {
});
}
+ dispose() {
+ this.ort_tensor.dispose();
+ // this.ort_tensor = undefined;
+ }
+
/**
* Returns an iterator object for iterating over the tensor data in row-major order.
* If the tensor has more than one dimension, the iterator will yield subarrays.
@@ -131,9 +158,10 @@ export class Tensor {
* @returns {number} The index of the first occurrence of item in the tensor data.
*/
indexOf(item) {
- for (let index = 0; index < this.data.length; ++index) {
+ const this_data = this.data;
+ for (let index = 0; index < this_data.length; ++index) {
// Note: == instead of === so we can match Ints with BigInts
- if (this.data[index] == item) {
+ if (this_data[index] == item) {
return index;
}
}
@@ -141,9 +169,9 @@ export class Tensor {
}
/**
- * @param {number} index
- * @param {number} iterSize
- * @param {any} iterDims
+ * @param {number} index
+ * @param {number} iterSize
+ * @param {any} iterDims
* @returns {Tensor}
*/
_subarray(index, iterSize, iterDims) {
@@ -165,10 +193,11 @@ export class Tensor {
* @throws {Error} If the tensor has more than one element.
*/
item() {
- if (this.data.length !== 1) {
- throw new Error(`a Tensor with ${this.data.length} elements cannot be converted to Scalar`);
+ const this_data = this.data;
+ if (this_data.length !== 1) {
+ throw new Error(`a Tensor with ${this_data.length} elements cannot be converted to Scalar`);
}
- return this.data[0];
+ return this_data[0];
}
/**
@@ -192,8 +221,33 @@ export class Tensor {
* @returns {Tensor} Returns `this`.
*/
sigmoid_() {
- for (let i = 0; i < this.data.length; ++i) {
- this.data[i] = 1 / (1 + Math.exp(-this.data[i]));
+ const this_data = this.data;
+ for (let i = 0; i < this_data.length; ++i) {
+ this_data[i] = 1 / (1 + Math.exp(-this_data[i]));
+ }
+ return this;
+ }
+
+ /**
+ * Return a new Tensor with a callback function applied to each element.
+ * @param {Function} callback - The function to apply to each element. It should take three arguments:
+ * the current element, its index, and the tensor's data array.
+ * @returns {Tensor} A new Tensor with the callback function applied to each element.
+ */
+ map(callback) {
+ return this.clone().map_(callback);
+ }
+
+ /**
+ * Apply a callback function to each element of the tensor in place.
+ * @param {Function} callback - The function to apply to each element. It should take three arguments:
+ * the current element, its index, and the tensor's data array.
+ * @returns {Tensor} Returns `this`.
+ */
+ map_(callback) {
+ const this_data = this.data;
+ for (let i = 0; i < this_data.length; ++i) {
+ this_data[i] = callback(this_data[i], i, this_data);
}
return this;
}
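// A short sketch of the new helpers (import path assumed): `map()` clones before applying the
// callback, while `map_()` mutates in place; both follow the Array#map callback signature.
import { Tensor } from "./tensor.js";

const logits = new Tensor("float32", new Float32Array([-1, 0, 2]), [3]);
const relu = logits.map((x) => Math.max(0, x)); // new tensor: [0, 0, 2]
logits.map_((x) => x * 2);                      // in place:   [-2, 0, 4]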
@@ -213,12 +267,34 @@ export class Tensor {
* @returns {Tensor} Returns `this`.
*/
mul_(val) {
- for (let i = 0; i < this.data.length; ++i) {
- this.data[i] *= val;
+ const this_data = this.data;
+ for (let i = 0; i < this_data.length; ++i) {
+ this_data[i] *= val;
}
return this;
}
+ /**
+ * Return a new Tensor with every element divided by a constant.
+ * @param {number} val The value to divide by.
+ * @returns {Tensor} The new tensor.
+ */
+ div(val) {
+ return this.clone().div_(val);
+ }
+
+ /**
+ * Divide the tensor by a constant in place.
+ * @param {number} val The value to divide by.
+ * @returns {Tensor} Returns `this`.
+ */
+ div_(val) {
+ const this_data = this.data;
+ for (let i = 0; i < this_data.length; ++i) {
+ this_data[i] /= val;
+ }
+ return this;
+ }
/**
* Return a new Tensor with every element added by a constant.
@@ -235,19 +311,43 @@ export class Tensor {
* @returns {Tensor} Returns `this`.
*/
add_(val) {
- for (let i = 0; i < this.data.length; ++i) {
- this.data[i] += val;
+ const this_data = this.data;
+ for (let i = 0; i < this_data.length; ++i) {
+ this_data[i] += val;
}
return this;
}
+
+ /**
+ * Return a new Tensor with every element subtracted by a constant.
+ * @param {number} val The value to subtract by.
+ * @returns {Tensor} The new tensor.
+ */
+ sub(val) {
+ return this.clone().sub_(val);
+ }
+
+ /**
+ * Subtract the tensor by a constant in place.
+ * @param {number} val The value to subtract by.
+ * @returns {Tensor} Returns `this`.
+ */
+ sub_(val) {
+ const this_data = this.data;
+ for (let i = 0; i < this_data.length; ++i) {
+ this_data[i] -= val;
+ }
+ return this;
+ }
+
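// Usage sketch (import path assumed): the new `sub`/`div` helpers mirror `add`/`mul` —
// the plain variants return a copy, the underscore variants mutate in place.
import { Tensor } from "./tensor.js";

const pixels = new Tensor("float32", new Float32Array([0, 127.5, 255]), [3]);
const normalized = pixels.sub(127.5).div(127.5); // [-1, 0, 1]; `pixels` is unchanged
pixels.sub_(127.5).div_(127.5);                  // same result, computed in place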
clone() {
return new Tensor(this.type, this.data.slice(), this.dims.slice());
}
slice(...slices) {
// This allows for slicing with ranges and numbers
- let newTensorDims = [];
- let newOffsets = [];
+ const newTensorDims = [];
+ const newOffsets = [];
// slices is an array of numbers or arrays of numbers
// e.g., slices = [0, [1, 3], null, [0, 3]]
@@ -267,14 +367,21 @@ export class Tensor {
} else if (Array.isArray(slice) && slice.length === 2) {
// An array of length 2 means take a range of elements
-
- if (slice[0] > slice[1]) {
+ let [start, end] = slice;
+ start = start === null
+ ? 0
+ : safeIndex(start, this.dims[sliceIndex], sliceIndex, false);
+ end = end === null
+ ? this.dims[sliceIndex]
+ : safeIndex(end, this.dims[sliceIndex], sliceIndex, false);
+
+ if (start > end) {
throw new Error(`Invalid slice: ${slice}`);
}
- let offsets = [
- Math.max(slice[0], 0),
- Math.min(slice[1], this.dims[sliceIndex])
+ const offsets = [
+ Math.max(start, 0),
+ Math.min(end, this.dims[sliceIndex])
];
newOffsets.push(offsets);
@@ -285,12 +392,13 @@ export class Tensor {
}
}
- let newDims = newOffsets.map(([start, end]) => end - start);
- let newBufferSize = newDims.reduce((a, b) => a * b);
+ const newDims = newOffsets.map(([start, end]) => end - start);
+ const newBufferSize = newDims.reduce((a, b) => a * b);
+ const this_data = this.data;
// Allocate memory
// @ts-ignore
- let data = new this.data.constructor(newBufferSize);
+ const data = new this_data.constructor(newBufferSize);
// Precompute strides
const stride = this.stride();
@@ -302,7 +410,7 @@ export class Tensor {
originalIndex += ((num % size) + newOffsets[j][0]) * stride[j];
num = Math.floor(num / size);
}
- data[i] = this.data[originalIndex];
+ data[i] = this_data[originalIndex];
}
return new Tensor(this.type, data, newTensorDims);
@@ -326,7 +434,7 @@ export class Tensor {
/**
* Returns the sum of each row of the input tensor in the given dimension dim.
- *
+ *
* @param {number} [dim=null] The dimension or dimensions to reduce. If `null`, all dimensions are reduced.
* @param {boolean} keepdim Whether the output tensor has `dim` retained or not.
* @returns The summed tensor
@@ -351,9 +459,11 @@ export class Tensor {
throw Error(`Unsupported norm: ${p}`);
}
+ const this_data = this.data;
+
if (dim === null) {
// @ts-ignore
- let val = this.data.reduce((a, b) => a + (b ** p), 0) ** (1 / p);
+ let val = this_data.reduce((a, b) => a + (b ** p), 0) ** (1 / p);
return new Tensor(this.type, [val], []);
}
@@ -366,10 +476,10 @@ export class Tensor {
// Create a new array to store the accumulated values
// @ts-ignore
- const result = new this.data.constructor(this.data.length / this.dims[dim]);
+ const result = new this_data.constructor(this_data.length / this.dims[dim]);
// Iterate over the data array
- for (let i = 0; i < this.data.length; ++i) {
+ for (let i = 0; i < this_data.length; ++i) {
// Calculate the index in the resulting array
let resultIndex = 0;
@@ -385,7 +495,7 @@ export class Tensor {
}
// Accumulate the value at the current index
- result[resultIndex] += (this.data[i]) ** p;
+ result[resultIndex] += (this_data[i]) ** p;
}
if (p !== 1) {
@@ -412,7 +522,9 @@ export class Tensor {
const norm = this.norm(p, dim, true);
- for (let i = 0; i < this.data.length; ++i) {
+ const this_data = this.data;
+ const norm_data = norm.data;
+ for (let i = 0; i < this_data.length; ++i) {
// Calculate the index in the resulting array
let resultIndex = 0;
@@ -428,7 +540,7 @@ export class Tensor {
}
// Divide by normalized value
- this.data[i] /= norm.data[resultIndex];
+ this_data[i] /= norm_data[resultIndex];
}
return this;
@@ -455,12 +567,12 @@ export class Tensor {
/**
* Returns a tensor with all specified dimensions of input of size 1 removed.
- *
+ *
* NOTE: The returned tensor shares the storage with the input tensor, so changing the contents of one will change the contents of the other.
* If you would like a copy, use `tensor.clone()` before squeezing.
- *
+ *
* @param {number} [dim=null] If given, the input will be squeezed only in the specified dimensions.
- * @returns The squeezed tensor
+ * @returns {Tensor} The squeezed tensor
*/
squeeze(dim = null) {
return new Tensor(
@@ -480,11 +592,11 @@ export class Tensor {
/**
* Returns a new tensor with a dimension of size one inserted at the specified position.
- *
+ *
* NOTE: The returned tensor shares the same underlying data with this tensor.
- *
+ *
* @param {number} dim The index at which to insert the singleton dimension
- * @returns The unsqueezed tensor
+ * @returns {Tensor} The unsqueezed tensor
*/
unsqueeze(dim = null) {
return new Tensor(
@@ -523,7 +635,7 @@ export class Tensor {
* and ending with `end_dim` are flattened. The order of elements in input is unchanged.
* @param {number} start_dim the first dim to flatten
* @param {number} end_dim the last dim to flatten
- * @returns The flattened tensor.
+ * @returns {Tensor} The flattened tensor.
*/
flatten(start_dim = 0, end_dim = -1) {
return this.clone().flatten_(start_dim, end_dim);
@@ -546,20 +658,22 @@ export class Tensor {
}
}
+ const this_data = this.data;
if (inferredIndex !== -1) {
// Some dimension must be inferred
const productOther = dims.reduce((product, curr, index) => {
return index !== inferredIndex ? product * curr : product
}, 1);
- dims[inferredIndex] = this.data.length / productOther;
+ dims[inferredIndex] = this_data.length / productOther;
}
- return new Tensor(this.type, this.data, dims); // NOTE: uses same underlying storage
+ return new Tensor(this.type, this_data, dims); // NOTE: uses same underlying storage
}
neg_() {
- for (let i = 0; i < this.data.length; ++i) {
- this.data[i] = -this.data[i];
+ const this_data = this.data;
+ for (let i = 0; i < this_data.length; ++i) {
+ this_data[i] = -this_data[i];
}
return this;
}
@@ -571,8 +685,9 @@ export class Tensor {
* In-place version of @see {@link Tensor.clamp}
*/
clamp_(min, max) {
- for (let i = 0; i < this.data.length; ++i) {
- this.data[i] = Math.min(Math.max(this.data[i], min), max);
+ const this_data = this.data;
+ for (let i = 0; i < this_data.length; ++i) {
+ this_data[i] = Math.min(Math.max(this_data[i], min), max);
}
return this;
}
@@ -581,7 +696,7 @@ export class Tensor {
* Clamps all elements in input into the range [ min, max ]
* @param {number} min lower-bound of the range to be clamped to
* @param {number} max upper-bound of the range to be clamped to
- * @returns the output tensor.
+ * @returns {Tensor} the output tensor.
*/
clamp(min, max) {
return this.clone().clamp_(min, max);
@@ -591,20 +706,25 @@ export class Tensor {
* In-place version of @see {@link Tensor.round}
*/
round_() {
- for (let i = 0; i < this.data.length; ++i) {
- this.data[i] = Math.round(this.data[i]);
+ const this_data = this.data;
+ for (let i = 0; i < this_data.length; ++i) {
+ this_data[i] = Math.round(this_data[i]);
}
return this;
}
/**
* Rounds elements of input to the nearest integer.
- * @returns the output tensor.
+ * @returns {Tensor} the output tensor.
*/
round() {
return this.clone().round_();
}
+ mean(dim = null, keepdim = false) {
+ return mean(this, dim, keepdim);
+ }
+
/**
* Performs Tensor dtype conversion.
* @param {DataType} type The desired data type.
@@ -625,7 +745,7 @@ export class Tensor {
/**
* This creates a nested array of a given type and depth (see examples).
- *
+ *
* @example
 * NestArray<string, 1>; // string[]
* @example
@@ -718,6 +838,105 @@ export function interpolate(input, [out_height, out_width], mode = 'bilinear', a
return new Tensor(input.type, output, [in_channels, out_height, out_width]);
}
+
+/**
+ * Down/up samples the input.
+ * Inspired by https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html.
+ * @param {Tensor} input the input tensor
+ * @param {Object} options the options for the interpolation
+ * @param {[number, number]|[number, number, number]|[number, number, number, number]} [options.size=null] output spatial size.
+ * @param {"bilinear"|"bicubic"} [options.mode='bilinear'] algorithm used for upsampling
+ * @returns {Promise<Tensor>} The interpolated tensor.
+ */
+export async function interpolate_4d(input, {
+ size = null,
+ mode = 'bilinear',
+} = {}) {
+
+ // Error checking
+ if (input.dims.length !== 4) {
+ throw new Error('`interpolate_4d` currently only supports 4D input.');
+ }
+ if (!size) {
+ // TODO: support scale_factor
+ throw new Error('`interpolate_4d` requires a `size` argument.');
+ }
+
+ // Fill in missing dimensions
+ let targetDims;
+ if (size.length === 2) {
+ targetDims = [...input.dims.slice(0, 2), ...size];
+ } else if (size.length === 3) {
+ targetDims = [input.dims[0], ...size];
+ } else if (size.length === 4) {
+ targetDims = size;
+ } else {
+ throw new Error('`size` must be of length 2, 3, or 4.');
+ }
+
+ let op;
+ if (mode === 'bilinear') {
+ op = await TensorOpRegistry.bilinear_interpolate_4d;
+ } else if (mode === 'bicubic') {
+ op = await TensorOpRegistry.bicubic_interpolate_4d;
+ } else {
+ throw new Error(`Unsupported mode: ${mode}`);
+ }
+
+ const sizeTensor = new Tensor('int64', new BigInt64Array(targetDims.map(BigInt)), [targetDims.length]);
+ return await op({ x: input, s: sizeTensor });
+}
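// Usage sketch for the typical NCHW case (import path assumed): a 2-element `size`
// keeps the batch and channel dimensions and resizes only the spatial ones.
import { Tensor, interpolate_4d } from "./tensor.js";

const x = new Tensor("float32", new Float32Array([1, 2, 3, 4]), [1, 1, 2, 2]);
const resized = await interpolate_4d(x, { size: [4, 4], mode: "bilinear" });
console.log(resized.dims); // [1, 1, 4, 4]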
+
+/**
+ * Matrix product of two tensors.
+ * Inspired by https://pytorch.org/docs/stable/generated/torch.matmul.html
+ * @param {Tensor} a the first tensor to be multiplied
+ * @param {Tensor} b the second tensor to be multiplied
+ * @returns {Promise<Tensor>} The matrix product of the two tensors.
+ */
+export async function matmul(a, b) {
+ const op = await TensorOpRegistry.matmul;
+ return await op({ a, b });
+}
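// Usage sketch (import path assumed): `matmul` is asynchronous because it runs through
// an ONNX op under the hood.
import { Tensor, matmul } from "./tensor.js";

const a = new Tensor("float32", new Float32Array([1, 2, 3, 4]), [2, 2]);
const b = new Tensor("float32", new Float32Array([5, 6, 7, 8]), [2, 2]);
const c = await matmul(a, b); // dims [2, 2], data [19, 22, 43, 50]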
+
+/**
+ * Computes the one dimensional Fourier transform of real-valued input.
+ * Inspired by https://pytorch.org/docs/stable/generated/torch.fft.rfft.html
+ * @param {Tensor} x the real input tensor
+ * @param {Tensor} a The dimension along which to take the one dimensional real FFT.
+ * @returns {Promise<Tensor>} the output tensor.
+ */
+export async function rfft(x, a) {
+ const op = await TensorOpRegistry.rfft;
+ return await op({ x, a });
+}
+
+
+/**
+ * Returns the k largest elements of the given input tensor.
+ * Inspired by https://pytorch.org/docs/stable/generated/torch.topk.html
+ * @param {Tensor} x the input tensor
+ * @param {number} k the k in "top-k"
+ * @returns {Promise<[Tensor, Tensor]>} the output tuple of (Tensor, LongTensor) of top-k elements and their indices.
+ */
+export async function topk(x, k) {
+ const op = await TensorOpRegistry.top_k;
+
+ if (k === null) {
+ k = x.dims.at(-1);
+ } else {
+ k = Math.min(k, x.dims.at(-1));
+ }
+ return await op({
+ x,
+ k: new Tensor(
+ 'int64',
+ [BigInt(k)],
+ [1]
+ )
+ });
+}
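// Usage sketch (import path assumed): passing `k = null` returns all elements along the
// last dimension; otherwise `k` is clamped to that dimension's size.
import { Tensor, topk } from "./tensor.js";

const scores = new Tensor("float32", new Float32Array([0.1, 0.7, 0.2]), [1, 3]);
const [values, indices] = await topk(scores, 2);
// values  -> float32 tensor [[0.7, 0.2]]
// indices -> int64 tensor   [[1, 2]]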
+
/**
* Perform mean pooling of the last hidden state followed by a normalization step.
* @param {Tensor} last_hidden_state Tensor of shape [batchSize, seqLength, embedDim]
@@ -727,32 +946,35 @@ export function interpolate(input, [out_height, out_width], mode = 'bilinear', a
export function mean_pooling(last_hidden_state, attention_mask) {
// last_hidden_state: [batchSize, seqLength, embedDim]
// attention_mask: [batchSize, seqLength]
+ const lastHiddenStateData = last_hidden_state.data;
+ const attentionMaskData = attention_mask.data;
+
+ const shape = [last_hidden_state.dims[0], last_hidden_state.dims[2]];
- let shape = [last_hidden_state.dims[0], last_hidden_state.dims[2]];
// @ts-ignore
- let returnedData = new last_hidden_state.data.constructor(shape[0] * shape[1]);
- let [batchSize, seqLength, embedDim] = last_hidden_state.dims;
+ const returnedData = new lastHiddenStateData.constructor(shape[0] * shape[1]);
+ const [batchSize, seqLength, embedDim] = last_hidden_state.dims;
let outIndex = 0;
for (let i = 0; i < batchSize; ++i) {
- let offset = i * embedDim * seqLength;
+ const offset = i * embedDim * seqLength;
for (let k = 0; k < embedDim; ++k) {
let sum = 0;
let count = 0;
- let attnMaskOffset = i * seqLength;
- let offset2 = offset + k;
+ const attnMaskOffset = i * seqLength;
+ const offset2 = offset + k;
// Pool over all words in sequence
for (let j = 0; j < seqLength; ++j) {
// index into attention mask
- let attn = Number(attention_mask.data[attnMaskOffset + j]);
+ const attn = Number(attentionMaskData[attnMaskOffset + j]);
count += attn;
- sum += last_hidden_state.data[offset2 + j * embedDim] * attn;
+ sum += lastHiddenStateData[offset2 + j * embedDim] * attn;
}
- let avg = sum / count;
+ const avg = sum / count;
returnedData[outIndex++] = avg;
}
}
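// Usage sketch of `mean_pooling` (import path assumed): masked positions contribute neither
// to the sum nor to the count, so only unmasked tokens are averaged.
import { Tensor, mean_pooling } from "./tensor.js";

const hidden = new Tensor("float32", new Float32Array([1, 2, 3, 4]), [1, 2, 2]); // [batch, seq, dim]
const mask = new Tensor("int64", new BigInt64Array([1n, 0n]), [1, 2]);            // second token masked
const pooled = mean_pooling(hidden, mask);
console.log(pooled.dims); // [1, 2]; data [1, 2] since only the first token is unmasked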
@@ -786,15 +1008,19 @@ export function layer_norm(input, normalized_shape, {
}
const [std, mean] = std_mean(input, 1, 0, true);
+ const stdData = /** @type {Float32Array} */(std.data);
+ const meanData = /** @type {Float32Array} */(mean.data);
+
+ const inputData = /** @type {Float32Array} */(input.data);
// @ts-ignore
- const returnedData = new input.data.constructor(input.data.length);
+ const returnedData = new inputData.constructor(inputData.length);
for (let i = 0; i < batchSize; ++i) {
const offset = i * featureDim;
for (let j = 0; j < featureDim; ++j) {
const offset2 = offset + j;
- returnedData[offset2] = (input.data[offset2] - mean.data[i]) / (std.data[i] + eps);
+ returnedData[offset2] = (inputData[offset2] - meanData[i]) / (stdData[i] + eps);
}
}
return new Tensor(input.type, returnedData, input.dims);
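// Usage sketch (import path assumed): `layer_norm` currently supports 2D input and normalizes
// each row over the feature dimension, using the population variance computed by `std_mean`.
import { Tensor, layer_norm } from "./tensor.js";

const x = new Tensor("float32", new Float32Array([1, 2, 3, 10, 20, 30]), [2, 3]);
const normed = layer_norm(x, [3]);
console.log(normed.dims); // [2, 3]; each row now has ~zero mean and ~unit variance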
@@ -804,7 +1030,7 @@ export function layer_norm(input, normalized_shape, {
* Helper function to calculate new dimensions when performing a squeeze operation.
* @param {number[]} dims The dimensions of the tensor.
* @param {number|number[]|null} dim The dimension(s) to squeeze.
- * @returns The new dimensions.
+ * @returns {number[]} The new dimensions.
* @private
*/
function calc_squeeze_dims(dims, dim) {
@@ -827,7 +1053,7 @@ function calc_squeeze_dims(dims, dim) {
* Helper function to calculate new dimensions when performing an unsqueeze operation.
* @param {number[]} dims The dimensions of the tensor.
* @param {number} dim The dimension to unsqueeze.
- * @returns The new dimensions.
+ * @returns {number[]} The new dimensions.
* @private
*/
function calc_unsqueeze_dims(dims, dim) {
@@ -846,12 +1072,12 @@ function calc_unsqueeze_dims(dims, dim) {
* @param {number} size The size of the array.
* @param {number} [dimension=null] The dimension that the index is for (optional).
* @returns {number} The index, guaranteed to be non-negative and less than `arrayLength`.
- *
+ *
* @throws {Error} If the index is out of range.
* @private
*/
-function safeIndex(index, size, dimension = null) {
- if (index < -size || index >= size) {
+function safeIndex(index, size, dimension = null, boundsCheck = true) {
+ if (boundsCheck && (index < -size || index >= size)) {
throw new Error(`IndexError: index ${index} is out of bounds for dimension${dimension === null ? '' : ' ' + dimension} with size ${size}`);
}
@@ -888,9 +1114,10 @@ export function cat(tensors, dim = 0) {
// Handle special case for performance reasons
let offset = 0;
- for (let t of tensors) {
- result.set(t.data, offset);
- offset += t.data.length;
+ for (const tensor of tensors) {
+ const tensorData = tensor.data;
+ result.set(tensorData, offset);
+ offset += tensorData.length;
}
} else {
@@ -898,15 +1125,15 @@ export function cat(tensors, dim = 0) {
let currentDim = 0;
for (let t = 0; t < tensors.length; ++t) {
- let tensor = tensors[t];
+ const { data, dims } = tensors[t];
// Iterate over the data array
- for (let i = 0; i < tensor.data.length; ++i) {
+ for (let i = 0; i < data.length; ++i) {
// Calculate the index in the resulting array
let resultIndex = 0;
- for (let j = tensor.dims.length - 1, num = i, resultMultiplier = 1; j >= 0; --j) {
- const size = tensor.dims[j];
+ for (let j = dims.length - 1, num = i, resultMultiplier = 1; j >= 0; --j) {
+ const size = dims[j];
let index = num % size;
if (j === dim) {
index += currentDim;
@@ -916,10 +1143,10 @@ export function cat(tensors, dim = 0) {
num = Math.floor(num / size);
}
// Accumulate the value at the current index
- result[resultIndex] = tensor.data[i];
+ result[resultIndex] = data[i];
}
- currentDim += tensor.dims[dim];
+ currentDim += dims[dim];
}
}
return new Tensor(resultType, result, resultDims);
@@ -947,14 +1174,14 @@ export function stack(tensors, dim = 0) {
* @returns {Tensor[]} A tuple of (std, mean) tensors.
*/
export function std_mean(input, dim = null, correction = 1, keepdim = false) {
+ const inputData = /** @type {Float32Array} */(input.data);
+ const inputDims = input.dims;
if (dim === null) {
// None to reduce over all dimensions.
- // @ts-ignore
- const sum = input.data.reduce((a, b) => a + b, 0);
- const mean = sum / input.data.length;
- // @ts-ignore
- const std = Math.sqrt(input.data.reduce((a, b) => a + (b - mean) ** 2, 0) / (input.data.length - correction));
+ const sum = inputData.reduce((a, b) => a + b, 0);
+ const mean = sum / inputData.length;
+ const std = Math.sqrt(inputData.reduce((a, b) => a + (b - mean) ** 2, 0) / (inputData.length - correction));
const meanTensor = new Tensor(input.type, [mean], [/* scalar */]);
const stdTensor = new Tensor(input.type, [std], [/* scalar */]);
@@ -963,26 +1190,27 @@ export function std_mean(input, dim = null, correction = 1, keepdim = false) {
}
// Negative indexing
- dim = safeIndex(dim, input.dims.length);
+ dim = safeIndex(dim, inputDims.length);
const meanTensor = mean(input, dim, keepdim);
+ const meanTensorData = meanTensor.data;
// Calculate the shape of the resulting array after summation
- const resultDims = input.dims.slice(); // Copy the original dimensions
+ const resultDims = inputDims.slice(); // Copy the original dimensions
resultDims[dim] = 1; // Remove the specified axis
// Create a new array to store the accumulated values
// @ts-ignore
- const result = new input.data.constructor(input.data.length / input.dims[dim]);
+ const result = new inputData.constructor(inputData.length / inputDims[dim]);
// Iterate over the data array
- for (let i = 0; i < input.data.length; ++i) {
+ for (let i = 0; i < inputData.length; ++i) {
// Calculate the index in the resulting array
let resultIndex = 0;
- for (let j = input.dims.length - 1, num = i, resultMultiplier = 1; j >= 0; --j) {
- const size = input.dims[j];
+ for (let j = inputDims.length - 1, num = i, resultMultiplier = 1; j >= 0; --j) {
+ const size = inputDims[j];
if (j !== dim) {
const index = num % size;
resultIndex += index * resultMultiplier;
@@ -992,11 +1220,11 @@ export function std_mean(input, dim = null, correction = 1, keepdim = false) {
}
// Accumulate the value at the current index
- result[resultIndex] += (input.data[i] - meanTensor.data[resultIndex]) ** 2;
+ result[resultIndex] += (inputData[i] - meanTensorData[resultIndex]) ** 2;
}
for (let i = 0; i < result.length; ++i) {
- result[i] = Math.sqrt(result[i] / (input.dims[dim] - correction));
+ result[i] = Math.sqrt(result[i] / (inputDims[dim] - correction));
}
if (!keepdim) {
@@ -1014,36 +1242,38 @@ export function std_mean(input, dim = null, correction = 1, keepdim = false) {
* @param {Tensor} input the input tensor.
* @param {number|null} dim the dimension to reduce.
* @param {boolean} keepdim whether the output tensor has dim retained or not.
- * @returns A new tensor with means taken along the specified dimension.
+ * @returns {Tensor} A new tensor with means taken along the specified dimension.
*/
export function mean(input, dim = null, keepdim = false) {
+ const inputData = /** @type {Float32Array} */(input.data);
if (dim === null) {
// None to reduce over all dimensions.
// @ts-ignore
- let val = input.data.reduce((a, b) => a + b, 0);
- return new Tensor(input.type, [val / input.data.length], [/* scalar */]);
+ const val = inputData.reduce((a, b) => a + b, 0);
+ return new Tensor(input.type, [val / inputData.length], [/* scalar */]);
}
+ const inputDims = input.dims;
// Negative indexing
- dim = safeIndex(dim, input.dims.length);
+ dim = safeIndex(dim, inputDims.length);
// Calculate the shape of the resulting array after summation
- const resultDims = input.dims.slice(); // Copy the original dimensions
+ const resultDims = inputDims.slice(); // Copy the original dimensions
resultDims[dim] = 1; // Remove the specified axis
// Create a new array to store the accumulated values
// @ts-ignore
- const result = new input.data.constructor(input.data.length / input.dims[dim]);
+ const result = new inputData.constructor(inputData.length / inputDims[dim]);
// Iterate over the data array
- for (let i = 0; i < input.data.length; ++i) {
+ for (let i = 0; i < inputData.length; ++i) {
// Calculate the index in the resulting array
let resultIndex = 0;
- for (let j = input.dims.length - 1, num = i, resultMultiplier = 1; j >= 0; --j) {
- const size = input.dims[j];
+ for (let j = inputDims.length - 1, num = i, resultMultiplier = 1; j >= 0; --j) {
+ const size = inputDims[j];
if (j !== dim) {
const index = num % size;
resultIndex += index * resultMultiplier;
@@ -1053,12 +1283,12 @@ export function mean(input, dim = null, keepdim = false) {
}
// Accumulate the value at the current index
- result[resultIndex] += input.data[i];
+ result[resultIndex] += inputData[i];
}
- if (input.dims[dim] !== 1) {
+ if (inputDims[dim] !== 1) {
for (let i = 0; i < result.length; ++i) {
- result[i] = result[i] / input.dims[dim];
+ result[i] = result[i] / inputDims[dim];
}
}
@@ -1070,99 +1300,6 @@ export function mean(input, dim = null, keepdim = false) {
}
-/**
- *
- * Measures similarity between two temporal sequences (e.g., input audio and output tokens
- * to generate token-level timestamps).
- * @param {Tensor} matrix
- * @returns {number[][]}
- */
-export function dynamicTimeWarping(matrix) {
- const [output_length, input_length] = matrix.dims;
-
- const outputShape = [output_length + 1, input_length + 1];
-
- const cost = new Tensor(
- 'float32',
- new Float32Array(outputShape[0] * outputShape[1]).fill(Infinity),
- outputShape
- );
-
- const trace = new Tensor(
- 'float32',
- new Float32Array(outputShape[0] * outputShape[1]).fill(-1),
- outputShape
- )
-
- // same as `cost[0][0] = 0`;
- cost[0].data[0] = 0;
-
- for (let j = 1; j < input_length + 1; ++j) {
- for (let i = 1; i < output_length + 1; ++i) {
-
- const c0 = cost[i - 1][j - 1].item();
- const c1 = cost[i - 1][j].item();
- const c2 = cost[i][j - 1].item();
-
- let c, t;
- if (c0 < c1 && c0 < c2) {
- c = c0;
- t = 0;
- } else if (c1 < c0 && c1 < c2) {
- c = c1;
- t = 1;
- } else {
- c = c2;
- t = 2;
- }
-
- cost[i].data[j] = matrix[i - 1][j - 1].item() + c;
- trace[i].data[j] = t;
- }
- }
-
- // backtrace
- let i = output_length;
- let j = input_length;
-
- // @ts-ignore
- trace.data.fill(2, 0, outputShape[1]) // trace[0, :] = 2
- for (let i = 0; i < outputShape[0]; ++i) { // trace[:, 0] = 1
- trace[i].data[0] = 1;
- }
-
- let text_indices = [];
- let time_indices = [];
-
- while (i > 0 || j > 0) {
- text_indices.push(i - 1);
- time_indices.push(j - 1);
-
- const t = trace[i][j].item();
- switch (t) {
- case 0:
- --i; --j;
- break;
- case 1:
- --i;
- break;
- case 2:
- --j;
- break;
- default:
- throw new Error(
- `Internal error in dynamic time warping. Unexpected trace[${i}, ${j}]. Please file a bug report.`
- )
- }
- }
-
- text_indices.reverse();
- time_indices.reverse();
-
- return [text_indices, time_indices];
-
-}
-
function dimsToStride(dims) {
const stride = new Array(dims.length);
for (let i = dims.length - 1, s2 = 1; i >= 0; --i) {
@@ -1172,28 +1309,77 @@ function dimsToStride(dims) {
return stride;
}
+function fullHelper(size, fill_value, dtype, cls) {
+ const numElements = size.reduce((a, b) => a * b, 1);
+ return new Tensor(
+ dtype,
+ new cls(numElements).fill(fill_value),
+ size
+ )
+}
+
+/**
+ * Creates a tensor of size size filled with fill_value. The tensor's dtype is inferred from fill_value.
+ * @param {number[]} size A sequence of integers defining the shape of the output tensor.
+ * @param {number|bigint} fill_value The value to fill the output tensor with.
+ * @returns {Tensor} The filled tensor.
+ */
+export function full(size, fill_value) {
+ let dtype;
+ let typedArrayCls;
+ if (typeof fill_value === 'number') {
+ dtype = 'float32';
+ typedArrayCls = Float32Array;
+ } else if (typeof fill_value === 'bigint') {
+ dtype = 'int64';
+ typedArrayCls = BigInt64Array;
+ } else {
+ // TODO: support other dtypes
+ throw new Error(`Unsupported data type: ${typeof fill_value}`);
+ }
+ return fullHelper(size, fill_value, dtype, typedArrayCls);
+}
+
+export function full_like(tensor, fill_value) {
+ return full(tensor.dims, fill_value);
+}
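// Usage sketch (import path assumed): the dtype is inferred from the fill value —
// JavaScript numbers produce float32 tensors, bigints produce int64 tensors.
import { full, full_like } from "./tensor.js";

const probs = full([2, 3], 0.5);  // float32 tensor of shape [2, 3] filled with 0.5
const ids = full([2, 3], 1n);     // int64 tensor of shape [2, 3] filled with 1n
const twos = full_like(probs, 2); // same shape as `probs`, filled with 2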
+
/**
* Returns a tensor filled with the scalar value 1, with the shape defined by the variable argument size.
* @param {number[]} size A sequence of integers defining the shape of the output tensor.
+ * @returns {Tensor} The ones tensor.
*/
export function ones(size) {
- const numElements = size.reduce((a, b) => a * b, 1);
- return new Tensor(
- 'int64',
- new BigInt64Array(numElements).fill(1n),
- size
- )
+ return fullHelper(size, 1n, 'int64', BigInt64Array);
}
/**
* Returns a tensor filled with the scalar value 1, with the same size as input.
* @param {Tensor} tensor The size of input will determine size of the output tensor.
- * @returns The ones tensor.
+ * @returns {Tensor} The ones tensor.
*/
export function ones_like(tensor) {
return ones(tensor.dims);
}
+/**
+ * Returns a tensor filled with the scalar value 0, with the shape defined by the variable argument size.
+ * @param {number[]} size A sequence of integers defining the shape of the output tensor.
+ * @returns {Tensor} The zeros tensor.
+ */
+export function zeros(size) {
+ return fullHelper(size, 0n, 'int64', BigInt64Array);
+}
+
+/**
+ * Returns a tensor filled with the scalar value 0, with the same size as input.
+ * @param {Tensor} tensor The size of input will determine size of the output tensor.
+ * @returns {Tensor} The zeros tensor.
+ */
+export function zeros_like(tensor) {
+ return zeros(tensor.dims);
+}
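// Usage sketch (import path assumed): both helpers produce int64 tensors, matching the
// previous `ones` implementation, which is what attention masks and token ids expect.
import { ones, zeros_like } from "./tensor.js";

const attention_mask = ones([1, 4]);             // int64 tensor of 1n values, shape [1, 4]
const padding_mask = zeros_like(attention_mask); // int64 zeros with the same shape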
+
/**
* Quantizes the embeddings tensor to binary or unsigned binary precision.
* @param {Tensor} tensor The tensor to quantize.
diff --git a/tests/configs.test.js b/tests/configs.test.js
index 8cdfe28c7..f66a8a887 100644
--- a/tests/configs.test.js
+++ b/tests/configs.test.js
@@ -1,25 +1,23 @@
-
-
-import { AutoConfig, env } from '../src/transformers.js';
-import { getFile } from '../src/utils/hub.js';
-import { m } from './init.js';
+import { AutoConfig, env } from "../src/transformers.js";
+import { getFile } from "../src/utils/hub.js";
// Initialise the testing environment
-env.allowLocalModels=false;
-env.useFSCache=false;
-
-// Load test data generated by the python tests
-// TODO do this dynamically?
-let testsData = await (await getFile('./tests/data/config_tests.json')).json()
-
-describe('Configs', () => {
+env.allowLocalModels = false;
+env.useFSCache = false;
- for (let [configName, targetConfig] of Object.entries(testsData)) {
+const TEST_DATA = {
+ "Xenova/bert-base-uncased": {
+ model_type: "bert",
+ },
+};
- it(configName, async () => {
- let config = await AutoConfig.from_pretrained(m(configName));
- expect(config.model_type).toEqual(targetConfig.model_type);
- expect(config.is_encoder_decoder).toEqual(targetConfig.is_encoder_decoder);
- });
- }
+describe("Configs", () => {
+ for (const [model_id, minimal_config] of Object.entries(TEST_DATA)) {
+ it(model_id, async () => {
+ const config = await AutoConfig.from_pretrained(model_id);
+ for (const [key, value] of Object.entries(minimal_config)) {
+ expect(config[key]).toEqual(value);
+ }
+ });
+ }
});
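// A sketch of how the rewritten test extends to other architectures: each TEST_DATA entry maps
// a Hub model id to the config fields to assert on. The model id below is an illustrative
// assumption, not part of the patch.
import { AutoConfig } from "../src/transformers.js";

it("Xenova/gpt2 (sketch)", async () => {
  const config = await AutoConfig.from_pretrained("Xenova/gpt2");
  expect(config.model_type).toEqual("gpt2");
});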
diff --git a/tests/data/.gitignore b/tests/data/.gitignore
deleted file mode 100644
index 5b8e8d398..000000000
--- a/tests/data/.gitignore
+++ /dev/null
@@ -1,3 +0,0 @@
-# Folder to store generated test data
-# Do not commit these files to the repository
-*.json
diff --git a/tests/generate_tests.py b/tests/generate_tests.py
deleted file mode 100644
index 3f103778b..000000000
--- a/tests/generate_tests.py
+++ /dev/null
@@ -1,467 +0,0 @@
-# Helper file to dynamically generate unit tests
-# This is done by running the python Transformers library and comparing its outputs with ours.
-
-import json
-import os
-from itertools import product
-
-from transformers import AutoTokenizer, AutoConfig
-import numpy as np
-
-from scripts.supported_models import SUPPORTED_MODELS
-
-# List of tokenizers where the model isn't yet supported, but the tokenizer is
-ADDITIONAL_TOKENIZERS_TO_TEST = {
- 'falcon': [
- 'tiiuae/falcon-7b',
- ],
- "llama": [
- 'hf-internal-testing/llama-tokenizer', # Special tokens: normalized=true
- 'Xenova/llama2-tokenizer', # Special tokens: normalized=false
- 'Xenova/llama2-chat-tokenizer', # Special tokens: normalized=false
- 'hf-internal-testing/llama-code-tokenizer',
-
- # TODO: add back when llama tests are fixed
- # 'Xenova/llama3-tokenizer-new', # PostProcessor type: Sequence
- ],
- 'mpt': [
- 'mosaicml/mpt-7b',
- ],
- 't5': [
- # TODO: Add back when https://github.com/huggingface/transformers/issues/26318 is fixed
- # 'Xenova/t5-tokenizer-new',
- ],
- 'bert': [
- # Uses `Whitespace` pretokenizer
- 'Xenova/jina-embeddings-v2-base-zh-tokenizer',
- ],
- 'qwen2': [
- # Uses a pretokenizer regex which is not compatible with JavaScript.
- 'Qwen/Qwen1.5-0.5B-Chat',
- ],
- 'gemma': [
- 'Xenova/gemma-tokenizer',
- ],
-}
-
-MODELS_TO_IGNORE = [
- # TODO: remove when https://github.com/huggingface/tokenizers/issues/251 is fixed
- 'xlm',
-
- # TODO: remove when https://github.com/huggingface/transformers/issues/26018 is fixed
- 'marian',
-
- # TODO: remove when https://github.com/huggingface/transformers/issues/26547 is fixed
- 'speecht5',
-
- # TODO: remove when https://github.com/huggingface/transformers/pull/26522 is merged
- 'siglip',
-
- # TODO: remove when https://github.com/huggingface/transformers/issues/28164 is fixed
- 'roformer',
-
- # TODO: remove when https://github.com/huggingface/transformers/issues/28173 is fixed. Issues include:
- # - decoding with `skip_special_tokens=True`.
- # - interspersing the pad token is broken.
- 'vits',
-]
-
-TOKENIZERS_TO_IGNORE = [
- # TODO: remove when https://github.com/huggingface/transformers/pull/25478 is merged
- 'facebook/m2m100_418M',
-
- # TODO: remove when https://github.com/huggingface/transformers/issues/28096 is addressed
- 'RajuKandasamy/tamillama_tiny_30m',
-
- # Requires `trust_remote_code`
- 'monologg/kobert',
-]
-
-MAX_TESTS = {
- 'marian': 10,
-}
-
-TOKENIZER_TEST_DATA = {
- "shared": [
- "hello world",
- "Hello World",
- "How are you doing?",
- "You should've done this",
- "A\n'll !!to?'d''d of, can't.",
- "def main():\n\tpass",
- "This\n\nis\na\ntest.",
- "let a = obj.toString();\ntoString();",
- 'Hi Hello',
- "trailing space ",
- " leading space",
- "生活的真谛是",
- "The company was founded in 2016.",
- "test $1 R2 #3 €4 £5 ¥6 ₣7 ₹8 ₱9 test",
- "I bought an apple for $1.00 at the store.",
- "you… ",
- "\u0079\u006F\u0075\u2026\u00A0\u00A0",
- "\u0079\u006F\u0075\u2026\u00A0\u00A0\u0079\u006F\u0075\u2026\u00A0\u00A0",
- "▁This ▁is ▁a ▁test ▁.",
- "weird \uFF5E edge \uFF5E case",
-
- # SentencePiece-specific test cases
- "\n",
- " test ",
- "test",
-
- # Control characters
- "1\u00002\uFFFD3",
- ],
- "custom_by_model_type": {
- "llama": [
- # Additional test-cases for the Llama tokenizer, adapted from
- # https://github.com/belladoreai/llama-tokenizer-js/blob/master/llama-tokenizer.js#L381-L452
- "grabbed",
- " grabbed",
- " grabbed",
- "\n",
- " \n",
- " tabs out here",
- "\n\t\n",
- "ax\n####\nboo",
- "镇",
- "🦙",
- "🦙Ꙋ",
- "Ꙋ🦙",
- "The llama (/ˈlɑːmə/; 🦙Spanish pronunciation: [ˈʎama]) (Lama glama) is a domesticated South American " \
- "camelid, widely used as a meat and pack animal by Andean cultures since the Pre-Columbian era. Llamas " \
- "are social animals and live with others as a herd. Their wool is soft and contains only a small " \
- "amount of lanolin.[2] Llamas can learn simple tasks after a few repetitions. When using a pack, they " \
- "can carry about 25 to 30% of their body weight for 8 to 13 km (5–8 miles).[3] The name llama (in the " \
- "past also spelled \"lama\" or \"glama\") was adopted by European settlers from native Peruvians.[4] " \
- "The ancestors of llamas are thought to have originated from the Great Plains of North America about " \
- "40 million years ago, and subsequently migrated to South America about three million years ago during " \
- "the Great American Interchange. By the end of the last ice age (10,000–12,000 years ago), camelids were " \
- "extinct in North America.[3] As of 2007, there were over seven million llamas and alpacas in South " \
- "America and over 158,000 llamas and 100,000Ꙋ🦙 alpacas, descended from progenitors imported late in " \
- "the 20th century, in the United States and Canada.[5] In Aymara mythology, llamas are important beings. " \
- "The Heavenly Llama is said to drink water from the ocean and urinates as it rains.[6] According to " \
- "Aymara eschatology, llamas will return to the water springs and lagoons where they come from at the " \
- "end of time.[6]",
- ],
-
- "vits": [
- "abcdefghijklmnopqrstuvwxyz01234567890",
- # Special treatment of characters in certain language
- "ț ţ",
- ],
-
- "qwen2": [
- "i'm i'M i've i've i'Ve i'vE i'VE",
- ],
- },
- "custom": {
- "facebook/blenderbot_small-90M": [
- # Test special tokens
- "__start__hello world__end__",
- # The original (python) tokenizer simply joins by spaces (regardless of special tokens or not)
- "__start__ hey __end__" # --> ... --> "__start__ hey __end__"
- "__start__hey __end__" # --> ... --> "__start__ hey __end__"
- ],
- "tiiuae/falcon-7b": [
- "12 and 123 and 1234", # Special case for splitting on 3 numbers
- ],
- "InstaDeepAI/nucleotide-transformer-500m-human-ref": [
- # Actual protein sequences
- "ATTCCGATTCCGATTCCG",
- "ATTTCTCTCTCTCTCTGAGATCGATCGATCGAT",
-
- # Special tokens
- "",
- ],
-
- "distil-whisper/distil-small.en": [
- " <|startoftranscript|> <|en|> ", # Tests lstrip+rstrip
- ],
-
- "Xenova/t5-tokenizer-new": [
- # Tests the new T5 tokenizer, which uses a different prepend_scheme for its pre_tokenizer:
- # tokenizer._tokenizer.pre_tokenizer = Metaspace(add_prefix_space = True, replacement = "▁", prepend_scheme = "first")
- # See https://github.com/huggingface/transformers/pull/26678 for more information.
- # - Old (incorrect): ['▁Hey', '▁', '</s>', '▁', '.', '▁how', '▁are', '▁you']
- # - New (correct): ['▁Hey', '▁', '</s>', '.', '▁how', '▁are', '▁you']
- "Hey </s>. how are you",
- ],
- },
-}
-
-TOKENIZER_TEXT_PAIR_TEST_DATA = [
- {
- 'text': 'a',
- 'text_pair': 'b'
- },
- {
- 'text': 'a b',
- 'text_pair': 'c d e'
- },
- {
- 'text': ['a b c', 'd'],
- 'text_pair': ['e f', 'g h'],
- },
- {
- 'text': ['a', 'b c', 'd e f'],
- 'text_pair': ['g h i', 'j k', 'l'],
- }
-]
-
-CHAT_MESSAGES_EXAMPLES = {
- 'basic': [
- {"role": "user", "content": "Hello, how are you?"},
- {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
- {"role": "user", "content": "I'd like to show off how chat templating works!"},
- ],
-
- 'system': [
- {"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"},
- {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
- ],
-
- 'system + assistant': [
- {"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"},
- {"role": "user", "content": "Hello, how are you?"},
- {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
- {"role": "user", "content": "I'd like to show off how chat templating works!"},
- ],
-}
-
-TOKENIZERS_WITH_CHAT_TEMPLATES = {
- # https://huggingface.co/docs/transformers/main/en/chat_templating
- 'Xenova/mistral-tokenizer-v1': [
- 'basic',
- ],
-
- 'HuggingFaceH4/zephyr-7b-beta': [
- 'system',
- ],
-
- 'Xenova/llama2-chat-tokenizer': [
- 'basic',
- 'system',
- 'system + assistant',
- ],
-}
-
-
-FLATTENED_SUPPORTED_MODELS = [
- (model_type, [
- model for task_models in tasks.values() for model in task_models
- ]) for model_type, tasks in SUPPORTED_MODELS.items()
-]
-
-
-def generate_tokenizer_tests():
-
- tokenization_results = {}
-
- tokenizers_to_test = FLATTENED_SUPPORTED_MODELS + \
- list(ADDITIONAL_TOKENIZERS_TO_TEST.items())
-
- for model_type, tokenizer_names in tokenizers_to_test:
- if model_type in MODELS_TO_IGNORE:
- continue
- if model_type in MAX_TESTS:
- tokenizer_names = tokenizer_names[:MAX_TESTS[model_type]]
-
- custom_by_model_type_texts = TOKENIZER_TEST_DATA["custom_by_model_type"].get(
- model_type, [])
-
- print(f'Generating tests for {model_type}')
- for tokenizer_name in tokenizer_names:
- if tokenizer_name in TOKENIZERS_TO_IGNORE:
- continue
-
- print(' -', tokenizer_name)
-
- try:
- # Load tokenizer
- if model_type == 'llama':
- # As of 17/12/2023, there are a few issues with the Llama tokenizers in transformers.
- # (1) Encoding with fast tokenizer adds whitespace after special tokens:
- # - https://github.com/huggingface/transformers/issues/25881
- # - https://github.com/huggingface/transformers/issues/26318
- # - https://github.com/huggingface/transformers/issues/26455
- # - https://github.com/huggingface/transformers/issues/27544
- # (2) Decoding with slow tokenizer adds whitespace after special tokens:
- # - https://github.com/huggingface/transformers/issues/25073
- #
- # So for now, we mix and match the tokenizers:
- # i.e., use the fast tokenizer for encoding, and the slow tokenizer for decoding.
- # TODO: remove when the above issues are fixed:
- tokenizer = AutoTokenizer.from_pretrained(
- tokenizer_name,
- use_fast=False,
- )
- decoder_tokenizer = AutoTokenizer.from_pretrained(
- tokenizer_name,
- use_fast=True,
- )
-
- else:
- decoder_tokenizer = tokenizer = AutoTokenizer.from_pretrained(
- tokenizer_name)
-
- except (KeyError, EnvironmentError):
- # If a KeyError/EnvironmentError is raised from the AutoTokenizer, it
- # means the model does not use a tokenizer (e.g., vision models)
- continue
-
- try:
- # Disable dropout, if the model allows it
- tokenizer.backend_tokenizer.model.dropout = 0
- except AttributeError:
- pass
-
- tokenizer_results = []
-
- for data in TOKENIZER_TEXT_PAIR_TEST_DATA:
- try:
- output = tokenizer(**data).data
- except Exception:
- # Ignore testing tokenizers which fail in the python library
- continue
- tokenizer_results.append(dict(
- input=data,
- output=output,
- ))
-
- shared_texts = TOKENIZER_TEST_DATA["shared"]
- custom_texts = TOKENIZER_TEST_DATA["custom"].get(
- tokenizer_name, [])
-
- # Run tokenizer on test cases
- for text in shared_texts + custom_texts + custom_by_model_type_texts:
- try:
- encoded = tokenizer(text).data
- except Exception:
- # Ignore testing tokenizers which fail in the python library
- continue
-
- decoded_with_special = decoder_tokenizer.decode(
- encoded["input_ids"], skip_special_tokens=False)
- decoded_without_special = decoder_tokenizer.decode(
- encoded["input_ids"], skip_special_tokens=True)
-
- tokenizer_results.append(dict(
- input=text,
- encoded=encoded,
- decoded_with_special=decoded_with_special,
- decoded_without_special=decoded_without_special,
- ))
-
- if tokenizer_results:
- tokenization_results[tokenizer_name] = tokenizer_results
-
- template_results = {}
-
- for tokenizer_id in TOKENIZERS_WITH_CHAT_TEMPLATES:
- print(f'Generating chat templates for {tokenizer_id}')
- tokenizer = AutoTokenizer.from_pretrained(
- tokenizer_id,
-
- # TODO: Remove once https://github.com/huggingface/transformers/pull/26678 is fixed
- use_fast='llama' not in tokenizer_id,
- )
- tokenizer_results = []
- for key in TOKENIZERS_WITH_CHAT_TEMPLATES[tokenizer_id]:
- messages = CHAT_MESSAGES_EXAMPLES[key]
-
- for add_generation_prompt, tokenize in product([True, False], [True, False]):
- tokenizer_results.append(dict(
- messages=messages,
- add_generation_prompt=add_generation_prompt,
- tokenize=tokenize,
- target=tokenizer.apply_chat_template(
- messages,
- add_generation_prompt=add_generation_prompt,
- tokenize=tokenize,
- ),
- ))
-
- template_results[tokenizer_id] = tokenizer_results
-
- return dict(
- tokenization=tokenization_results,
- templates=template_results,
- )
-
-
-def generate_config_tests():
- results = {}
- for model_type, config_names in FLATTENED_SUPPORTED_MODELS:
- print(f'Generating tests for {model_type}')
-
- for config_name in config_names:
- print(' -', config_name)
- try:
- # Load config
- config = AutoConfig.from_pretrained(config_name)
- except Exception:
- # Something went wrong, skip this config
- continue
- results[config_name] = config.to_dict()
-
- # TODO: Remove after https://github.com/huggingface/transformers/issues/23876 fixed
- results[config_name].pop('torch_dtype', None)
-
- return results
-
-
-ARRAY_SIZES = sorted(set([2 ** i for i in range(1, 10)])
- | set([3 ** i for i in range(1, 8)])
- | set([5 ** i for i in range(1, 6)])
- | set([7 ** i for i in range(1, 4)]))
-
-
-def serialize_complex_array(arr):
- return [float(x) for y in arr for x in [y.real, y.imag]]
-
-
-def serialize_real_array(arr):
- return arr.tolist()
-
-
-def generate_fft_tests():
- np.random.seed(0)
- tests = {}
- for complex in [False, True]:
- serialize_fn = serialize_complex_array if complex else serialize_real_array
- for size in ARRAY_SIZES:
- arr = np.random.randn(size).astype(
- np.complex64 if complex else np.float64)
- if complex:
- arr += np.random.randn(size) * 1j
- tests[f"fft_{size}_{'complex' if complex else 'real'}"] = {
- "complex": complex,
- "input": serialize_fn(arr),
- "output": serialize_complex_array(np.fft.fft(arr)),
- }
- return tests
-
-
-def main():
- # TODO add option to cache generated data + force build tests
-
- data_dir = os.path.join(
- os.path.dirname(os.path.abspath(__file__)), "data",
- )
-
- tokenizer_tests = generate_tokenizer_tests()
- with open(os.path.join(data_dir, "tokenizer_tests.json"), "w", encoding="utf-8") as fp:
- json.dump(tokenizer_tests, fp)
-
- config_tests = generate_config_tests()
- with open(os.path.join(data_dir, "config_tests.json"), "w", encoding="utf-8") as fp:
- json.dump(config_tests, fp)
-
- fft_tests = generate_fft_tests()
- with open(os.path.join(data_dir, "fft_tests.json"), "w", encoding="utf-8") as fp:
- json.dump(fft_tests, fp)
-
-
-if __name__ == "__main__":
- main()
diff --git a/tests/generation.test.js b/tests/generation.test.js
deleted file mode 100644
index da50388aa..000000000
--- a/tests/generation.test.js
+++ /dev/null
@@ -1,173 +0,0 @@
-
-import { pipeline } from '../src/transformers.js';
-import { init, m, MAX_TEST_EXECUTION_TIME } from './init.js';
-
-// Initialise the testing environment
-init();
-
-describe('Generation parameters', () => {
-
- // List all models which will be tested
- const models = [
- 'MBZUAI/LaMini-Flan-T5-77M', // encoder-decoder
- 'MBZUAI/LaMini-GPT-124M', // decoder-only
-
- 'Xenova/llama2.c-stories15M', // decoder-only
- ];
-
- // encoder-decoder model
- it(models[0], async () => {
- const text = 'how can I become more healthy?';
-
- const generator = await pipeline('text2text-generation', m(models[0]));
-
- // default
- // NOTE: Since `max_length` defaults to 20, this case also tests that.
- {
- const outputs = await generator(text);
-
- const tokens = generator.tokenizer.encode(outputs[0].generated_text)
- expect(tokens.length).toEqual(20);
- }
-
- // max_new_tokens
- {
- // NOTE: Without setting `min_new_tokens` (but setting `max_new_tokens`), 64 tokens are generated.
- // So, the following tests are valid.
- const MAX_NEW_TOKENS = 20;
- const outputs = await generator(text, {
- max_new_tokens: MAX_NEW_TOKENS,
- });
-
- const tokens = generator.tokenizer.encode(outputs[0].generated_text)
- expect(tokens.length).toEqual(MAX_NEW_TOKENS + 1); // + 1 due to forced BOS token
- }
-
- // min_length
- {
- // NOTE: Without setting `min_length` (but setting `max_new_tokens`), 64 tokens are generated.
- // So, the following tests are valid.
- const MAX_NEW_TOKENS = 128;
- const MIN_LENGTH = 65;
- const outputs = await generator(text, {
- max_new_tokens: MAX_NEW_TOKENS,
- min_length: MIN_LENGTH,
- });
-
- const tokens = generator.tokenizer.encode(outputs[0].generated_text)
- expect(tokens.length).toBeGreaterThanOrEqual(MIN_LENGTH);
- }
-
- // min_new_tokens
- {
- // NOTE: Without setting `min_new_tokens` (but setting `max_new_tokens`), 64 tokens are generated.
- // So, the following tests are valid.
- const MAX_NEW_TOKENS = 128;
- const MIN_NEW_TOKENS = 65;
- const outputs = await generator(text, {
- max_new_tokens: MAX_NEW_TOKENS,
- min_new_tokens: MIN_NEW_TOKENS,
- });
-
- const tokens = generator.tokenizer.encode(outputs[0].generated_text)
- expect(tokens.length).toBeGreaterThanOrEqual(MIN_NEW_TOKENS);
- }
-
- await generator.dispose();
-
- }, MAX_TEST_EXECUTION_TIME);
-
- // decoder-only model
- it(models[1], async () => {
- const text = "### Instruction:\nTrue or False: The earth is flat?\n\n### Response: ";
-
- const generator = await pipeline('text-generation', m(models[1]));
-
- // default
- // NOTE: Since `max_length` defaults to 20, this case also tests that.
- {
- const outputs = await generator(text);
- const tokens = generator.tokenizer.encode(outputs[0].generated_text)
- expect(tokens.length).toEqual(20);
- }
-
- // max_new_tokens
- {
- const MAX_NEW_TOKENS = 20;
- const outputs = await generator(text, {
- max_new_tokens: MAX_NEW_TOKENS,
- });
- const promptTokens = generator.tokenizer.encode(text)
- const tokens = generator.tokenizer.encode(outputs[0].generated_text)
- expect(tokens.length).toBeGreaterThan(promptTokens.length);
- }
-
- // min_length
- {
- // NOTE: Without setting `min_length` (but setting `max_new_tokens`), 22 tokens are generated.
- // So, the following tests are valid.
- const MAX_NEW_TOKENS = 10;
- const MIN_LENGTH = 25;
- const outputs = await generator(text, {
- max_new_tokens: MAX_NEW_TOKENS,
- min_length: MIN_LENGTH,
- });
-
- const tokens = generator.tokenizer.encode(outputs[0].generated_text)
- expect(tokens.length).toBeGreaterThanOrEqual(MIN_LENGTH);
- }
-
- // min_new_tokens
- {
- // NOTE: Without setting `min_new_tokens` (but setting `max_new_tokens`), 22 tokens are generated.
- // So, the following tests are valid.
- const MAX_NEW_TOKENS = 32;
- const MIN_NEW_TOKENS = 10;
- const outputs = await generator(text, {
- max_new_tokens: MAX_NEW_TOKENS,
- min_new_tokens: MIN_NEW_TOKENS,
- });
-
- const tokens = generator.tokenizer.encode(outputs[0].generated_text)
- const promptTokens = generator.tokenizer.encode(text)
- expect(tokens.length).toBeGreaterThanOrEqual(promptTokens.length + MIN_NEW_TOKENS);
- }
-
- await generator.dispose();
-
- }, MAX_TEST_EXECUTION_TIME);
-
- // decoder-only model
- it(models[2], async () => {
- const MAX_NEW_TOKENS = 1;
-
- const text = [
- 'Once upon a time,',
- 'Lily',
- 'Suddenly,',
- ];
-
- const generator = await pipeline('text-generation', m(models[2]));
-
- { // return_full_text=false
- const output = await generator(text, {
- return_full_text: false,
- max_new_tokens: MAX_NEW_TOKENS,
- num_beams: 2,
- num_return_sequences: 2,
- });
- const lengths = output.flatMap(
- x => x.flatMap(
- y => generator.tokenizer.encode(y.generated_text.trim(), null, {
- add_special_tokens: false,
- }).length
- )
- ).every(x => x === MAX_NEW_TOKENS);
-
- expect(lengths).toBe(true);
- }
- await generator.dispose();
-
- }, MAX_TEST_EXECUTION_TIME);
-
-});
\ No newline at end of file
diff --git a/tests/hub.test.js b/tests/hub.test.js
deleted file mode 100644
index 76c0ad143..000000000
--- a/tests/hub.test.js
+++ /dev/null
@@ -1,35 +0,0 @@
-
-
-import { AutoModel, PreTrainedModel } from '../src/transformers.js';
-import { MAX_TEST_EXECUTION_TIME } from './init.js';
-
-// TODO: Set cache folder to a temp directory
-
-describe('Hub', () => {
-
- describe('Loading models', () => {
-
- it('should load a model from the local cache', async () => {
- // 1. Local model exists (doesn't matter about status of remote file since local is tried first)
- let model = await AutoModel.from_pretrained('t5-small');
- expect(model).toBeInstanceOf(PreTrainedModel);
- }, MAX_TEST_EXECUTION_TIME);
-
- it('should load a model from the remote cache', async () => {
- // 2. Local model doesn't exist, remote file exists
- // This tests that fallback functionality is working
- let model = await AutoModel.from_pretrained('Xenova/t5-small');
- expect(model).toBeInstanceOf(PreTrainedModel);
- }, MAX_TEST_EXECUTION_TIME);
-
- it('should fail to load a model', async () => {
- // 3. Local model doesn't exist, remote file doesn't exist
- // This tests that error handling is working.
- await expect(
- AutoModel.from_pretrained('Xenova/this-model-does-not-exist')
- ).rejects
- .toBeInstanceOf(Error);
- }, MAX_TEST_EXECUTION_TIME);
- });
-
-});
diff --git a/tests/init.js b/tests/init.js
index 2fcd8609e..65f079086 100644
--- a/tests/init.js
+++ b/tests/init.js
@@ -1,100 +1,64 @@
// Helper functions used when initialising the testing environment.
-
// Import Node typing utilities
import * as types from "node:util/types";
// Import onnxruntime-node's default backend
import { onnxruntimeBackend } from "onnxruntime-node/dist/backend";
-import ONNX_COMMON from "onnxruntime-common";
+import * as ONNX_COMMON from "onnxruntime-common";
+/**
+ * A workaround to define a new backend for onnxruntime, which
+ * will not throw an error when running tests with jest.
+ * For more information, see: https://github.com/jestjs/jest/issues/11864#issuecomment-1261468011
+ */
export function init() {
- // In rare cases (specifically when running unit tests with GitHub actions), possibly due to
- // a large number of concurrent executions, onnxruntime might fallback to use the WASM backend.
- // In this case, we set the number of threads to 1 to avoid errors like:
- // - `TypeError: The worker script or module filename must be an absolute path or a relative path starting with './' or '../'. Received "blob:nodedata:..."`
- ONNX_COMMON.env.wasm.numThreads = 1;
-
- // A workaround to define a new backend for onnxruntime, which
- // will not throw an error when running tests with jest.
- // For more information, see: https://github.com/jestjs/jest/issues/11864#issuecomment-1261468011
-
- let registerBackend = ONNX_COMMON.registerBackend;
-
- // Define the constructors to monkey-patch
- const TYPED_ARRAYS_CONSTRUCTOR_NAMES = [
- "Int8Array",
- "Int16Array",
- "Int32Array",
- "BigInt64Array",
- "Uint8Array",
- "Uint8ClampedArray",
- "Uint16Array",
- "Uint32Array",
- "BigUint64Array",
- "Float32Array",
- "Float64Array",
- ];
-
- // Keep a reference to the original initialization method
- const originalMethod = onnxruntimeBackend.init;
-
- // Monkey-patch the initialization function
- onnxruntimeBackend.init = function (...args) {
- // There is probably a better way to do this
- Array.isArray = x =>
- typeof x === "object" &&
- x !== null &&
- typeof x.length === "number" &&
- x?.constructor.toString() === Array.toString();
-
- // For each typed array constructor
- for (const ctorName of TYPED_ARRAYS_CONSTRUCTOR_NAMES) {
- // Get the constructor from the current context
- const ctor = global[ctorName];
-
- // Get the corresponding test function from the `util` module
- const value = types[`is${ctorName}`].bind(types);
-
- // Monkey-patch the constructor so "x instanceof ctor" returns "types[`is${ctorName}`](x)"
- Object.defineProperty(ctor, Symbol.hasInstance, {
- value,
- writable: false,
- configurable: false,
- enumerable: false,
- });
- }
-
- // Call the original method
- return originalMethod.apply(this, args);
- };
+ // In rare cases (specifically when running unit tests with GitHub actions), possibly due to
+ // a large number of concurrent executions, onnxruntime might fallback to use the WASM backend.
+ // In this case, we set the number of threads to 1 to avoid errors like:
+ // - `TypeError: The worker script or module filename must be an absolute path or a relative path starting with './' or '../'. Received "blob:nodedata:..."`
+ ONNX_COMMON.env.wasm.numThreads = 1;
+
+ let registerBackend = ONNX_COMMON.registerBackend;
+
+ // Define the constructors to monkey-patch
+ const TYPED_ARRAYS_CONSTRUCTOR_NAMES = ["Int8Array", "Int16Array", "Int32Array", "BigInt64Array", "Uint8Array", "Uint8ClampedArray", "Uint16Array", "Uint32Array", "BigUint64Array", "Float32Array", "Float64Array"];
+
+ // Keep a reference to the original initialization method
+ const originalMethod = onnxruntimeBackend.init;
+
+ // Monkey-patch the initialization function
+ onnxruntimeBackend.init = function (...args) {
+ // There is probably a better way to do this
+ Array.isArray = (x) => typeof x === "object" && x !== null && typeof x.length === "number" && x?.constructor.toString() === Array.toString();
+
+ // For each typed array constructor
+ for (const ctorName of TYPED_ARRAYS_CONSTRUCTOR_NAMES) {
+ // Get the constructor from the current context
+ const ctor = globalThis[ctorName];
+
+ // Get the corresponding test function from the `util` module
+ const value = types[`is${ctorName}`].bind(types);
+
+ // Monkey-patch the constructor so "x instanceof ctor" returns "types[`is${ctorName}`](x)"
+ Object.defineProperty(ctor, Symbol.hasInstance, {
+ value,
+ writable: true, // writable=true is necessary to overwrite the default implementation (and allow subsequent overwrites)
+ configurable: false,
+ enumerable: false,
+ });
+ }
- // Register the backend with the highest priority, so it is used instead of the default one
- registerBackend("test", onnxruntimeBackend, Number.POSITIVE_INFINITY);
+ // Call the original method
+ return originalMethod.apply(this, args);
+ };
+ // Register the backend with the highest priority, so it is used instead of the default one
+ registerBackend("test", onnxruntimeBackend, Number.POSITIVE_INFINITY);
}
+export const MAX_MODEL_LOAD_TIME = 10_000; // 10 seconds
+export const MAX_TEST_EXECUTION_TIME = 30_000; // 30 seconds
+export const MAX_MODEL_DISPOSE_TIME = 1_000; // 1 second
-export let m = x => x;
-if (process.env.TESTING_REMOTELY) {
- // Running in a remote environment where models are not present locally (e.g., GitHub actions).
-
- // In this case, we use the "test" models, under the following org/username:
- const TEST_USERNAME = 'Xenova';
-
- m = (name) => {
- // Split into parts: [username, model]
- let parts = name.split(/\/+/, 2);
- if (parts.length === 2) {
- // Replace username
- parts[0] = TEST_USERNAME;
- } else {
- // Add username
- parts.unshift(TEST_USERNAME);
- }
-
- return parts.join('/');
- }
-}
-
-export const MAX_TEST_EXECUTION_TIME = 60_000; // 60 seconds
+export const MAX_TEST_TIME = MAX_MODEL_LOAD_TIME + MAX_TEST_EXECUTION_TIME + MAX_MODEL_DISPOSE_TIME;
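// A sketch of how a test file is expected to consume these helpers: call `init()` before any
// onnxruntime session is created, and pass the combined timeout to long-running tests.
// The commented-out pipeline call is illustrative only.
import { init, MAX_TEST_TIME } from "./init.js";

init();

describe("smoke test", () => {
  it(
    "loads and runs a tiny model",
    async () => {
      // e.g. await pipeline("text2text-generation", "./models/hf-internal-testing/tiny-random-T5ForConditionalGeneration");
    },
    MAX_TEST_TIME,
  );
});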
diff --git a/tests/maths.test.js b/tests/maths.test.js
deleted file mode 100644
index 3c00cfa26..000000000
--- a/tests/maths.test.js
+++ /dev/null
@@ -1,156 +0,0 @@
-
-import { compare } from './test_utils.js';
-
-import { getFile } from '../src/utils/hub.js';
-import { FFT, medianFilter, bankers_round, log_softmax } from '../src/utils/maths.js';
-
-
-const fft = (arr, complex = false) => {
- let output;
- let fft;
- if (complex) {
- fft = new FFT(arr.length / 2);
- output = new Float64Array(fft.outputBufferSize);
- fft.transform(output, arr);
- } else {
- fft = new FFT(arr.length);
- output = new Float64Array(fft.outputBufferSize);
- fft.realTransform(output, arr);
- }
- if (!fft.isPowerOfTwo) {
- output = output.slice(0, complex ? arr.length : 2 * arr.length);
- }
- return output;
-}
-
-const fftTestsData = await (await getFile('./tests/data/fft_tests.json')).json()
-
-describe('Mathematical operations', () => {
-
- describe('bankers rounding', () => {
- it('should round up to nearest even', () => {
- expect(bankers_round(-0.5)).toBeCloseTo(0);
- expect(bankers_round(1.5)).toBeCloseTo(2);
- expect(bankers_round(19.5)).toBeCloseTo(20);
- });
- it('should round down to nearest even', () => {
- expect(bankers_round(-1.5)).toBeCloseTo(-2);
- expect(bankers_round(2.5)).toBeCloseTo(2);
- expect(bankers_round(18.5)).toBeCloseTo(18);
- });
- });
-
- describe('median filtering', () => {
-
-
- it('should compute median filter', async () => {
- const t1 = new Float32Array([5, 12, 2, 6, 3, 10, 9, 1, 4, 8, 11, 7]);
- const window = 3;
-
- const target = new Float32Array([12, 5, 6, 3, 6, 9, 9, 4, 4, 8, 8, 11]);
-
- const output = medianFilter(t1, window);
- compare(output, target, 1e-3);
- });
-
-
- // TODO add tests for errors
- });
-
- describe('FFT', () => {
- // Should match output of numpy fft
- it('should compute real FFT for power of two', () => {
- { // size = 4
- // np.fft.fft([1,2,3,4]) == array([10.+0.j, -2.+2.j, -2.+0.j, -2.-2.j])
- const input = new Float32Array([1, 2, 3, 4]);
- const target = new Float32Array([10, 0, -2, 2, -2, 0, -2, -2]);
-
- const output = fft(input);
- compare(output, target, 1e-3);
- }
-
- { // size = 16
- // np.fft.fft([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16])
- // == array([136. +0.j , -8.+40.21871594j, -8.+19.3137085j ,
- // -8.+11.9728461j , -8. +8.j , -8. +5.3454291j ,
- // -8. +3.3137085j , -8. +1.59129894j, -8. +0.j ,
- // -8. -1.59129894j, -8. -3.3137085j , -8. -5.3454291j ,
- // -8. -8.j , -8.-11.9728461j , -8.-19.3137085j ,
- // -8.-40.21871594j])
- const input = new Float32Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
- const target = new Float32Array([136.0, 0.0, -8.0, 40.218715937006785, -8.0, 19.31370849898476, -8.0, 11.972846101323912, -8.0, 8.0, -8.0, 5.345429103354389, -8.0, 3.313708498984761, -8.0, 1.5912989390372658, -8.0, 0.0, -8.0, -1.5912989390372658, -8.0, -3.313708498984761, -8.0, -5.345429103354389, -8.0, -8.0, -8.0, -11.972846101323912, -8.0, -19.31370849898476, -8.0, -40.218715937006785]);
-
- const output = fft(input);
- compare(output, target, 1e-3);
- }
- });
-
- it('should compute real FFT for non-power of two', () => {
- { // size = 3
- // np.fft.fft([1,2,3]) == array([ 6. +0.j, -1.5+0.8660254j, -1.5-0.8660254j])
- const input = new Float32Array([1, 2, 3]);
- const target = new Float32Array([6, 0, -1.5, 0.8660254, -1.5, -0.8660254]);
-
- const output = fft(input);
- compare(output, target, 1e-3);
- }
- });
-
- it('should compute complex FFT for non-power of two', () => {
- { // size = 3
- // np.fft.fft([1+3j,2-2j,3+1j]) == array([ 6. +2.j, -4.09807621+4.3660254j, 1.09807621+2.6339746j])
- const input = new Float32Array([1, 3, 2, -2, 3, 1]);
- const target = new Float32Array([6, 2, -4.09807621, 4.3660254, 1.09807621, 2.6339746]);
-
- const output = fft(input, true);
- compare(output, target, 1e-3);
- }
- });
-
- it('should compute complex FFT for power of two', () => {
- { // size = 4
- // np.fft.fft([1+4j, 2-3j,3+2j, 4-1j]) == array([10. +2.j, -4. +4.j, -2.+10.j, 0. +0.j])
- const input = new Float32Array([1, 4, 2, -3, 3, 2, 4, -1]);
- const target = new Float32Array([10, 2, -4, 4, -2, 10, 0, 0]);
-
- const output = fft(input, true);
- compare(output, target, 1e-3);
- }
- });
- })
-
- describe('FFT (dynamic)', () => {
- // Should match output of numpy fft
- for (const [name, test] of Object.entries(fftTestsData)) {
- // if (test.input.length > 5) continue;
- it(name, () => {
- const output = fft(test.input, test.complex);
-
- if (output.map((v, i) => Math.abs(v - test.output[i])).some(v => v > 1e-4)) {
- console.log('input', test.input)
- console.log('output', output)
- console.log('target', test.output)
- }
- compare(output, test.output, 1e-4);
-
- });
- }
- });
-
- describe('log softmax', () => {
- // Should match output of scipy log_softmax
- it('should compute log softmax correctly for usual values', () => {
- const input = [0, 1, 2, 3];
- const expected = [-3.4401896985611953, -2.4401896985611953, -1.4401896985611953, -0.44018969856119533];
- const output = log_softmax(input);
- compare(output, expected, 1e-13);
- });
-
- it('should compute log softmax correctly for values with large differences', () => {
- const input = [1000, 1];
- const expected = [0, -999];
- const output = log_softmax(input);
- compare(output, expected, 1e-13);
- });
- });
-});
diff --git a/tests/models.test.js b/tests/models.test.js
index 126e10e1d..f1bc7961c 100644
--- a/tests/models.test.js
+++ b/tests/models.test.js
@@ -2,147 +2,129 @@
* Test that models loaded outside of the `pipeline` function work correctly (e.g., `AutoModel.from_pretrained(...)`);
*/
-import {
- AutoTokenizer,
- AutoModel,
- AutoProcessor,
+import { AutoTokenizer, AutoModel, AutoProcessor, BertModel, GPT2Model, T5ForConditionalGeneration, CLIPTextModelWithProjection, CLIPVisionModelWithProjection, BertTokenizer, GPT2Tokenizer, T5Tokenizer, RawImage } from "../src/transformers.js";
- BertModel,
- GPT2Model,
- T5Model,
- CLIPTextModelWithProjection,
- CLIPVisionModelWithProjection,
+import { init, MAX_TEST_EXECUTION_TIME } from "./init.js";
- BertTokenizer,
- GPT2Tokenizer,
- T5Tokenizer,
-
- RawImage,
-} from '../src/transformers.js';
-
-import { init, m, MAX_TEST_EXECUTION_TIME } from './init.js';
-
-import { compare } from './test_utils.js';
+import { compare } from "./test_utils.js";
// Initialise the testing environment
init();
-describe('Models', () => {
-
- describe('Loading different architecture types', () => {
-
- // List all models which will be tested
- const models_to_test = [
- // [name, modelClass, tokenizerClass]
- ['bert-base-uncased', BertModel, BertTokenizer], // Encoder-only
- ['gpt2', GPT2Model, GPT2Tokenizer], // Decoder-only
- ['t5-small', T5Model, T5Tokenizer], // Encoder-decoder
- ];
-
- let texts = [
- 'Once upon a time',
- 'I like to eat apples',
- ];
-
- for (let [name, modelClass, tokenizerClass] of models_to_test) {
-
- // Test that both the auto model and the specific model work
- let tokenizers = [AutoTokenizer, tokenizerClass];
- let models = [AutoModel, modelClass];
-
- for (let i = 0; i < tokenizers.length; ++i) {
- const tokenizerClassToTest = tokenizers[i];
- const modelClassToTest = models[i];
-
- it(`${name} (${modelClassToTest.name})`, async () => {
- const model_id = m(name);
-
- // Load model and tokenizer
- let tokenizer = await tokenizerClassToTest.from_pretrained(model_id);
- let model = await modelClassToTest.from_pretrained(model_id);
-
- let tests = [
- texts[0], // single
- texts, // batched
- ]
- for (let test of tests) {
- let encodings = await tokenizer(test, { truncation: true, padding: true });
- let output = await model(encodings);
-
- if (output.logits) {
- // Ensure correct shapes
- let expected_shape = [...encodings.input_ids.dims, model.config.vocab_size];
- let actual_shape = output.logits.dims;
- compare(expected_shape, actual_shape);
- } else if (output.last_hidden_state) {
- let expected_shape = [...encodings.input_ids.dims, model.config.d_model];
- let actual_shape = output.last_hidden_state.dims;
- compare(expected_shape, actual_shape);
- } else {
- console.warn('Unexpected output', output);
- throw new Error('Unexpected output');
- }
-
- }
-
- await model.dispose();
-
- }, MAX_TEST_EXECUTION_TIME);
-
+describe("Models", () => {
+ describe("Loading different architecture types", () => {
+ // List all models which will be tested
+ const models_to_test = [
+ // [name, modelClass, tokenizerClass]
+ ["hf-internal-testing/tiny-random-BertForMaskedLM", BertModel, BertTokenizer], // Encoder-only
+ ["hf-internal-testing/tiny-random-GPT2LMHeadModel", GPT2Model, GPT2Tokenizer], // Decoder-only
+ ["hf-internal-testing/tiny-random-T5ForConditionalGeneration", T5ForConditionalGeneration, T5Tokenizer], // Encoder-decoder
+ ];
+
+ const texts = ["Once upon a time", "I like to eat apples"];
+
+ for (const [model_id, modelClass, tokenizerClass] of models_to_test) {
+ // Test that both the auto model and the specific model work
+ const tokenizers = [AutoTokenizer, tokenizerClass];
+ const models = [AutoModel, modelClass];
+
+ for (let i = 0; i < tokenizers.length; ++i) {
+ const tokenizerClassToTest = tokenizers[i];
+ const modelClassToTest = models[i];
+
+ it(
+ `${model_id} (${modelClassToTest.name})`,
+ async () => {
+ // Load model and tokenizer
+ const tokenizer = await tokenizerClassToTest.from_pretrained(model_id);
+ const model = await modelClassToTest.from_pretrained(model_id);
+
+ const tests = [
+ texts[0], // single
+ texts, // batched
+ ];
+ for (const test of tests) {
+ const inputs = await tokenizer(test, { truncation: true, padding: true });
+ if (model.config.is_encoder_decoder) {
+ inputs.decoder_input_ids = inputs.input_ids;
+ }
+ const output = await model(inputs);
+
+ if (output.logits) {
+ // Ensure correct shapes
+ const expected_shape = [...inputs.input_ids.dims, model.config.vocab_size];
+ const actual_shape = output.logits.dims;
+ compare(expected_shape, actual_shape);
+ } else if (output.last_hidden_state) {
+ const expected_shape = [...inputs.input_ids.dims, model.config.d_model];
+ const actual_shape = output.last_hidden_state.dims;
+ compare(expected_shape, actual_shape);
+ } else {
+ console.warn("Unexpected output", output);
+ throw new Error("Unexpected output");
+ }
}
- }
-
- });
-
- describe('Running specific models', () => {
- const models_to_test = [
- 'openai/clip-vit-base-patch16',
- ];
- it(`CLIP (text)`, async () => {
- const model_id = m(models_to_test[0]);
-
- // Load tokenizer and text model
- const tokenizer = await AutoTokenizer.from_pretrained(model_id);
- const text_model = await CLIPTextModelWithProjection.from_pretrained(model_id);
-
- // Run tokenization
- const texts = ['a photo of a car', 'a photo of a football match'];
- const text_inputs = tokenizer(texts, { padding: true, truncation: true });
-
- // Compute embeddings
- const { text_embeds } = await text_model(text_inputs);
-
- // Ensure correct shapes
- const expected_shape = [texts.length, text_model.config.projection_dim];
- const actual_shape = text_embeds.dims;
- compare(expected_shape, actual_shape);
-
- await text_model.dispose();
-
- }, MAX_TEST_EXECUTION_TIME);
-
- it(`CLIP (vision)`, async () => {
- const model_id = m(models_to_test[0]);
-
- // Load processor and vision model
- const processor = await AutoProcessor.from_pretrained(model_id);
- const vision_model = await CLIPVisionModelWithProjection.from_pretrained(model_id);
-
- // Read image and run processor
- const image = await RawImage.read('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/football-match.jpg');
- const image_inputs = await processor(image);
-
- // Compute embeddings
- const { image_embeds } = await vision_model(image_inputs);
-
- // Ensure correct shapes
- const expected_shape = [1, vision_model.config.projection_dim];
- const actual_shape = image_embeds.dims;
- compare(expected_shape, actual_shape);
-
- await vision_model.dispose();
-
- }, MAX_TEST_EXECUTION_TIME);
- });
+ await model.dispose();
+ },
+ MAX_TEST_EXECUTION_TIME,
+ );
+ }
+ }
+ });
+
+ describe("Running specific models", () => {
+ const models_to_test = ["hf-internal-testing/tiny-random-CLIPModel"];
+ it(
+ `CLIP (text)`,
+ async () => {
+ const model_id = models_to_test[0];
+
+ // Load tokenizer and text model
+ const tokenizer = await AutoTokenizer.from_pretrained(model_id);
+ const text_model = await CLIPTextModelWithProjection.from_pretrained(model_id, { revision: "refs/pr/5" });
+
+ // Run tokenization
+ const texts = ["a photo of a car", "a photo of a football match"];
+ const text_inputs = tokenizer(texts, { padding: true, truncation: true });
+
+ // Compute embeddings
+ const { text_embeds } = await text_model(text_inputs);
+
+ // Ensure correct shapes
+ const expected_shape = [texts.length, text_model.config.projection_dim];
+ const actual_shape = text_embeds.dims;
+ compare(expected_shape, actual_shape);
+
+ await text_model.dispose();
+ },
+ MAX_TEST_EXECUTION_TIME,
+ );
+
+ it(
+ `CLIP (vision)`,
+ async () => {
+ const model_id = models_to_test[0];
+
+ // Load processor and vision model
+ const processor = await AutoProcessor.from_pretrained(model_id);
+ const vision_model = await CLIPVisionModelWithProjection.from_pretrained(model_id, { revision: "refs/pr/5" });
+
+ // Read image and run processor
+ const image = await RawImage.read("https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/football-match.jpg");
+ const image_inputs = await processor(image);
+
+ // Compute embeddings
+ const { image_embeds } = await vision_model(image_inputs);
+
+ // Ensure correct shapes
+ const expected_shape = [1, vision_model.config.projection_dim];
+ const actual_shape = image_embeds.dims;
+ compare(expected_shape, actual_shape);
+
+ await vision_model.dispose();
+ },
+ MAX_TEST_EXECUTION_TIME,
+ );
+ });
});
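Note how the rewritten test sets `decoder_input_ids` whenever `model.config.is_encoder_decoder` is true: a direct forward pass through an encoder-decoder model needs decoder inputs in addition to the encoder inputs. A minimal standalone sketch of the same pattern, assuming it is run from inside the repository (it reuses the tiny CI checkpoint and the relative import used by these tests):

import { AutoTokenizer, T5ForConditionalGeneration } from "../src/transformers.js";

const model_id = "hf-internal-testing/tiny-random-T5ForConditionalGeneration";
const tokenizer = await AutoTokenizer.from_pretrained(model_id);
const model = await T5ForConditionalGeneration.from_pretrained(model_id);

const inputs = await tokenizer("Once upon a time", { padding: true, truncation: true });
inputs.decoder_input_ids = inputs.input_ids; // required for a direct encoder-decoder forward pass
const { logits } = await model(inputs);
console.log(logits.dims); // [batch_size, sequence_length, vocab_size]

await model.dispose();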
diff --git a/tests/models/albert/tokenization.js b/tests/models/albert/tokenization.js
new file mode 100644
index 000000000..875bc418f
--- /dev/null
+++ b/tests/models/albert/tokenization.js
@@ -0,0 +1,183 @@
+import { AlbertTokenizer } from "../../../src/tokenizers.js";
+import { BASE_TEST_STRINGS, BERT_TEST_STRINGS } from "../test_strings.js";
+
+export const TOKENIZER_CLASS = AlbertTokenizer;
+export const TEST_CONFIG = {
+ // - uses `StripAccents` normalizer
+ "Xenova/albert-base-v2": {
+ SIMPLE: {
+ text: BASE_TEST_STRINGS.SIMPLE,
+ tokens: ["\u2581how", "\u2581are", "\u2581you", "\u2581doing", "?"],
+ ids: [2, 184, 50, 42, 845, 60, 3],
+ decoded: "[CLS] how are you doing?[SEP]",
+ },
+ SIMPLE_WITH_PUNCTUATION: {
+ text: BASE_TEST_STRINGS.SIMPLE_WITH_PUNCTUATION,
+ tokens: ["\u2581you", "\u2581should", "'", "ve", "\u2581done", "\u2581this"],
+ ids: [2, 42, 378, 22, 195, 677, 48, 3],
+ decoded: "[CLS] you should've done this[SEP]",
+ },
+ NUMBERS: {
+ text: BASE_TEST_STRINGS.NUMBERS,
+ tokens: ["\u25810", "12", "345", "67", "89", "\u25810", "\u25811", "\u25812", "\u25813", "\u25814", "\u25815", "\u25816", "\u25817", "\u25818", "\u25819", "\u258110", "\u2581100", "\u25811000"],
+ ids: [2, 713, 918, 21997, 4167, 3877, 713, 137, 172, 203, 268, 331, 400, 453, 469, 561, 332, 808, 6150, 3],
+ decoded: "[CLS] 0123456789 0 1 2 3 4 5 6 7 8 9 10 100 1000[SEP]",
+ },
+ TEXT_WITH_NUMBERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_NUMBERS,
+ tokens: ["\u2581the", "\u2581company", "\u2581was", "\u2581founded", "\u2581in", "\u25812016", "."],
+ ids: [2, 14, 237, 23, 785, 19, 690, 9, 3],
+ decoded: "[CLS] the company was founded in 2016.[SEP]",
+ },
+ PUNCTUATION: {
+ text: BASE_TEST_STRINGS.PUNCTUATION,
+ tokens: ["\u2581a", "\u2581", "'", "ll", "\u2581", "!!", "to", "?'", "d", '"', "d", "\u2581of", ",", "\u2581can", "'", "t", "."],
+ ids: [2, 21, 13, 22, 211, 13, 19015, 262, 5663, 43, 7, 43, 16, 15, 92, 22, 38, 9, 3],
+ decoded: "[CLS] a 'll!!to?'d\"d of, can't.[SEP]",
+ },
+ PYTHON_CODE: {
+ text: BASE_TEST_STRINGS.PYTHON_CODE,
+ tokens: ["\u2581def", "\u2581main", "(", ")", ":", "\u2581pass"],
+ ids: [2, 6312, 407, 5, 6, 45, 1477, 3],
+ decoded: "[CLS] def main(): pass[SEP]",
+ },
+ JAVASCRIPT_CODE: {
+ text: BASE_TEST_STRINGS.JAVASCRIPT_CODE,
+ tokens: ["\u2581let", "\u2581a", "\u2581=", "\u2581ob", "j", ".", "to", "string", "(", ")", ";", "\u2581to", "string", "(", ")", ";"],
+ ids: [2, 408, 21, 800, 5122, 728, 9, 262, 11130, 5, 6, 73, 20, 11130, 5, 6, 73, 3],
+ decoded: "[CLS] let a = obj.tostring(); tostring();[SEP]",
+ },
+ NEWLINES: {
+ text: BASE_TEST_STRINGS.NEWLINES,
+ tokens: ["\u2581this", "\u2581is", "\u2581a", "\u2581test", "."],
+ ids: [2, 48, 25, 21, 1289, 9, 3],
+ decoded: "[CLS] this is a test.[SEP]",
+ },
+ BASIC: {
+ text: BASE_TEST_STRINGS.BASIC,
+ tokens: ["\u2581unwanted", ",", "running"],
+ ids: [2, 21095, 15, 11325, 3],
+ decoded: "[CLS] unwanted,running[SEP]",
+ },
+ CONTROL_TOKENS: {
+ text: BASE_TEST_STRINGS.CONTROL_TOKENS,
+ tokens: ["\u25811", "\u0000", "2", "\u25813"],
+ ids: [2, 137, 1, 135, 203, 3],
+ decoded: "[CLS] 12 3[SEP]",
+ },
+ HELLO_WORLD_TITLECASE: {
+ text: BASE_TEST_STRINGS.HELLO_WORLD_TITLECASE,
+ tokens: ["\u2581hello", "\u2581world"],
+ ids: [2, 10975, 126, 3],
+ decoded: "[CLS] hello world[SEP]",
+ },
+ HELLO_WORLD_LOWERCASE: {
+ text: BASE_TEST_STRINGS.HELLO_WORLD_LOWERCASE,
+ tokens: ["\u2581hello", "\u2581world"],
+ ids: [2, 10975, 126, 3],
+ decoded: "[CLS] hello world[SEP]",
+ },
+ CHINESE_ONLY: {
+ text: BASE_TEST_STRINGS.CHINESE_ONLY,
+ tokens: ["\u2581", "\u751f\u6d3b\u7684\u771f\u8c1b\u662f"],
+ ids: [2, 13, 1, 3],
+ decoded: "[CLS] [SEP]",
+ },
+ LEADING_SPACE: {
+ text: BASE_TEST_STRINGS.LEADING_SPACE,
+ tokens: ["\u2581leading", "\u2581space"],
+ ids: [2, 1005, 726, 3],
+ decoded: "[CLS] leading space[SEP]",
+ },
+ TRAILING_SPACE: {
+ text: BASE_TEST_STRINGS.TRAILING_SPACE,
+ tokens: ["\u2581trailing", "\u2581space"],
+ ids: [2, 14323, 726, 3],
+ decoded: "[CLS] trailing space[SEP]",
+ },
+ DOUBLE_SPACE: {
+ text: BASE_TEST_STRINGS.DOUBLE_SPACE,
+ tokens: ["\u2581hi", "\u2581hello"],
+ ids: [2, 4148, 10975, 3],
+ decoded: "[CLS] hi hello[SEP]",
+ },
+ CURRENCY: {
+ text: BASE_TEST_STRINGS.CURRENCY,
+ tokens: ["\u2581test", "\u2581$1", "\u2581r", "2", "\u2581#3", "\u2581", "\u20ac", "4", "\u2581", "\u00a3", "5", "\u2581", "\u00a5", "6", "\u2581", "\u20a3", "7", "\u2581", "\u20b9", "8", "\u2581", "\u20b1", "9", "\u2581test"],
+ ids: [2, 1289, 3742, 761, 135, 11489, 13, 12, 300, 13, 11, 264, 13, 1, 379, 13, 1, 465, 13, 1, 457, 13, 1, 518, 1289, 3],
+ decoded: "[CLS] test $1 r2 #3 \u20ac4 \u00a35 6 7 8 9 test[SEP]",
+ },
+ CURRENCY_WITH_DECIMALS: {
+ text: BASE_TEST_STRINGS.CURRENCY_WITH_DECIMALS,
+ tokens: ["\u2581i", "\u2581bought", "\u2581an", "\u2581apple", "\u2581for", "\u2581$1", ".", "00", "\u2581at", "\u2581the", "\u2581store", "."],
+ ids: [2, 31, 2448, 40, 4037, 26, 3742, 9, 2032, 35, 14, 1718, 9, 3],
+ decoded: "[CLS] i bought an apple for $1.00 at the store.[SEP]",
+ },
+ ELLIPSIS: {
+ text: BASE_TEST_STRINGS.ELLIPSIS,
+ tokens: ["\u2581you", ".", ".", "."],
+ ids: [2, 42, 9, 9, 9, 3],
+ decoded: "[CLS] you...[SEP]",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS,
+ tokens: ["\u2581you", ".", ".", "."],
+ ids: [2, 42, 9, 9, 9, 3],
+ decoded: "[CLS] you...[SEP]",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS_2: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS_2,
+ tokens: ["\u2581you", ".", ".", ".", "\u2581you", ".", ".", "."],
+ ids: [2, 42, 9, 9, 9, 42, 9, 9, 9, 3],
+ decoded: "[CLS] you... you...[SEP]",
+ },
+ TILDE_NORMALIZATION: {
+ text: BASE_TEST_STRINGS.TILDE_NORMALIZATION,
+ tokens: ["\u2581weird", "\u2581", "~", "\u2581edge", "\u2581", "~", "\u2581case"],
+ ids: [2, 5455, 13, 1, 1407, 13, 1, 610, 3],
+ decoded: "[CLS] weird edge case[SEP]",
+ },
+ SPIECE_UNDERSCORE: {
+ text: BASE_TEST_STRINGS.SPIECE_UNDERSCORE,
+ tokens: ["\u2581this", "\u2581is", "\u2581a", "\u2581test", "\u2581", "."],
+ ids: [2, 48, 25, 21, 1289, 13, 9, 3],
+ decoded: "[CLS] this is a test.[SEP]",
+ },
+ POPULAR_EMOJIS: {
+ text: BASE_TEST_STRINGS.POPULAR_EMOJIS,
+ tokens: ["\u2581", "\ud83d\ude02", "\u2581", "\ud83d\udc4d", "\u2581", "\ud83e\udd23", "\u2581", "\ud83d\ude0d", "\u2581", "\ud83d\ude2d", "\u2581", "\ud83c\udf89", "\u2581", "\ud83d\ude4f", "\u2581", "\ud83d\ude0a", "\u2581", "\ud83d\udd25", "\u2581", "\ud83d\ude01", "\u2581", "\ud83d\ude05", "\u2581", "\ud83e\udd17", "\u2581", "\ud83d\ude06", "\u2581", "\ud83d\udc4f", "\u2581", "\u2764", "\u2581", "\ud83d\udc9c", "\u2581", "\ud83d\udc9a", "\u2581", "\ud83d\udc97", "\u2581", "\ud83d\udc99", "\u2581", "\ud83d\udda4", "\u2581", "\ud83d\ude0e", "\u2581", "\ud83d\udc4c", "\u2581", "\ud83e\udd73", "\u2581", "\ud83d\udcaa", "\u2581", "\u2728", "\u2581", "\ud83d\udc49", "\u2581", "\ud83d\udc40", "\u2581", "\ud83d\udcaf", "\u2581", "\ud83c\udf88", "\u2581", "\ud83d\ude48", "\u2581", "\ud83d\ude4c", "\u2581", "\ud83d\udc80", "\u2581", "\ud83d\udc47", "\u2581", "\ud83d\udc4b", "\u2581", "\u2705", "\u2581", "\ud83c\udf81", "\u2581", "\ud83c\udf1e", "\u2581", "\ud83c\udf38", "\u2581", "\ud83d\udcb0"],
+ ids: [2, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 3],
+ decoded: "[CLS] [SEP]",
+ },
+ MULTIBYTE_EMOJIS: {
+ text: BASE_TEST_STRINGS.MULTIBYTE_EMOJIS,
+ tokens: ["\u2581", "\u2728", "\u2581", "\ud83e\udd17", "\u2581", "\ud83d\udc41", "\u2581", "\ud83d\udc71\ud83c\udffb", "\u2581", "\ud83d\udd75", "\u2581", "\u2642", "\u2581", "\ud83e\uddd9\ud83c\udffb", "\u2581", "\u2642", "\u2581", "\ud83d\udc68\ud83c\udffb", "\u2581", "\ud83c\udf3e", "\u2581", "\ud83e\uddd1", "\u2581", "\ud83e\udd1d", "\u2581", "\ud83e\uddd1", "\u2581", "\ud83d\udc69", "\u2581", "\u2764", "\u2581", "\ud83d\udc8b", "\u2581", "\ud83d\udc68", "\u2581", "\ud83d\udc69", "\u2581", "\ud83d\udc69", "\u2581", "\ud83d\udc67", "\u2581", "\ud83d\udc66", "\u2581", "\ud83e\uddd1\ud83c\udffb", "\u2581", "\ud83e\udd1d", "\u2581", "\ud83e\uddd1\ud83c\udffb", "\u2581", "\ud83c\udff4\udb40\udc67\udb40\udc62\udb40\udc65\udb40\udc6e\udb40\udc67\udb40\udc7f", "\u2581", "\ud83d\udc68\ud83c\udffb", "\u2581", "\u2764", "\u2581", "\ud83d\udc8b", "\u2581", "\ud83d\udc68\ud83c\udffc"],
+ ids: [2, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 13, 1, 3],
+ decoded: "[CLS] [SEP]",
+ },
+ CHINESE_LATIN_MIXED: {
+ text: BERT_TEST_STRINGS.CHINESE_LATIN_MIXED,
+ tokens: ["\u2581", "ah", "\u535a\u63a8", "zz"],
+ ids: [2, 13, 1307, 1, 5092, 3],
+ decoded: "[CLS] ahzz[SEP]",
+ },
+ SIMPLE_WITH_ACCENTS: {
+ text: BERT_TEST_STRINGS.SIMPLE_WITH_ACCENTS,
+ tokens: ["\u2581hello"],
+ ids: [2, 10975, 3],
+ decoded: "[CLS] hello[SEP]",
+ },
+ MIXED_CASE_WITHOUT_ACCENTS: {
+ text: BERT_TEST_STRINGS.MIXED_CASE_WITHOUT_ACCENTS,
+ tokens: ["\u2581hello", "!", "how", "\u2581are", "\u2581you", "?"],
+ ids: [2, 10975, 187, 1544, 50, 42, 60, 3],
+ decoded: "[CLS] hello!how are you?[SEP]",
+ },
+ MIXED_CASE_WITH_ACCENTS: {
+ text: BERT_TEST_STRINGS.MIXED_CASE_WITH_ACCENTS,
+ tokens: ["\u2581hall", "o", "!", "how", "\u2581are", "\u2581you", "?"],
+ ids: [2, 554, 111, 187, 1544, 50, 42, 60, 3],
+ decoded: "[CLS] hallo!how are you?[SEP]",
+ },
+ },
+};
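To spot-check one of the expected cases above by hand, a sketch along these lines could be used (assuming it sits next to tests/models/albert/tokenization.js so the relative imports resolve; only calls already used elsewhere in this diff are relied on):

import { AlbertTokenizer } from "../../../src/tokenizers.js";
import { BASE_TEST_STRINGS } from "../test_strings.js";

const tokenizer = await AlbertTokenizer.from_pretrained("Xenova/albert-base-v2");

// SIMPLE case from the table above.
const { input_ids } = await tokenizer(BASE_TEST_STRINGS.SIMPLE);
console.log(input_ids.dims); // [1, 7] — seven ids, including [CLS] (2) and [SEP] (3)

// Decoding the expected ids should give back the `decoded` string listed above:
console.log(tokenizer.decode([2, 184, 50, 42, 845, 60, 3]));
// expected: "[CLS] how are you doing?[SEP]"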
diff --git a/tests/models/all_tokenization_tests.js b/tests/models/all_tokenization_tests.js
new file mode 100644
index 000000000..00ec6d639
--- /dev/null
+++ b/tests/models/all_tokenization_tests.js
@@ -0,0 +1,22 @@
+export * as AlbertTokenizer from "./albert/tokenization.js";
+export * as BertTokenizer from "./bert/tokenization.js";
+export * as BlenderbotSmallTokenizer from "./blenderbot_small/tokenization.js";
+export * as BloomTokenizer from "./bloom/tokenization.js";
+export * as CLIPTokenizer from "./clip/tokenization.js";
+export * as DebertaV2Tokenizer from "./deberta-v2/tokenization.js";
+export * as DistilBertTokenizer from "./distilbert/tokenization.js";
+export * as EsmTokenizer from "./esm/tokenization.js";
+export * as FalconTokenizer from "./falcon/tokenization.js";
+export * as GPT2Tokenizer from "./gpt2/tokenization.js";
+export * as GemmaTokenizer from "./gemma/tokenization.js";
+export * as LlamaTokenizer from "./llama/tokenization.js";
+export * as M2M100Tokenizer from "./m2m_100/tokenization.js";
+export * as MPNetTokenizer from "./mpnet/tokenization.js";
+export * as NllbTokenizer from "./nllb/tokenization.js";
+export * as Qwen2Tokenizer from "./qwen2/tokenization.js";
+export * as RobertaTokenizer from "./roberta/tokenization.js";
+export * as T5Tokenizer from "./t5/tokenization.js";
+export * as VitsTokenizer from "./vits/tokenization.js";
+export * as Wav2Vec2CTCTokenizer from "./wav2vec2/tokenization.js";
+export * as WhisperTokenizer from "./whisper/tokenization.js";
+export * as XLMRobertaTokenizer from "./xlm-roberta/tokenization.js";
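The barrel file above re-exports every per-model suite as a `{ TOKENIZER_CLASS, TEST_CONFIG }` pair. The actual runner is not part of this hunk, but a hypothetical consumer would iterate the exports roughly as follows (sketch only; the `text_pair` handling mirrors the TEXT_PAIR cases in the config files):

import * as TOKENIZER_TEST_SUITES from "./models/all_tokenization_tests.js";

// Hypothetical driver: each export bundles a tokenizer class with expected
// tokens/ids/decoded strings, keyed by checkpoint id and test-string name.
for (const [suite_name, { TOKENIZER_CLASS, TEST_CONFIG }] of Object.entries(TOKENIZER_TEST_SUITES)) {
  for (const [model_id, cases] of Object.entries(TEST_CONFIG)) {
    const tokenizer = await TOKENIZER_CLASS.from_pretrained(model_id);
    for (const [case_name, data] of Object.entries(cases)) {
      const { input_ids } = await tokenizer(data.text, { text_pair: data.text_pair });
      console.log(suite_name, model_id, case_name, input_ids.dims, "expected ids:", data.ids.length);
    }
  }
}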
diff --git a/tests/models/bert/tokenization.js b/tests/models/bert/tokenization.js
new file mode 100644
index 000000000..54b253260
--- /dev/null
+++ b/tests/models/bert/tokenization.js
@@ -0,0 +1,1335 @@
+import { BertTokenizer } from "../../../src/tokenizers.js";
+import { BASE_TEST_STRINGS, BERT_TEST_STRINGS } from "../test_strings.js";
+
+export const TOKENIZER_CLASS = BertTokenizer;
+export const TEST_CONFIG = {
+ "Xenova/bert-base-uncased": {
+ SIMPLE: {
+ text: BASE_TEST_STRINGS.SIMPLE,
+ tokens: ["how", "are", "you", "doing", "?"],
+ ids: [101, 2129, 2024, 2017, 2725, 1029, 102],
+ decoded: "[CLS] how are you doing? [SEP]",
+ },
+ SIMPLE_WITH_PUNCTUATION: {
+ text: BASE_TEST_STRINGS.SIMPLE_WITH_PUNCTUATION,
+ tokens: ["you", "should", "'", "ve", "done", "this"],
+ ids: [101, 2017, 2323, 1005, 2310, 2589, 2023, 102],
+ decoded: "[CLS] you should've done this [SEP]",
+ },
+ TEXT_WITH_NUMBERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_NUMBERS,
+ tokens: ["the", "company", "was", "founded", "in", "2016", "."],
+ ids: [101, 1996, 2194, 2001, 2631, 1999, 2355, 1012, 102],
+ decoded: "[CLS] the company was founded in 2016. [SEP]",
+ },
+ PUNCTUATION: {
+ text: BASE_TEST_STRINGS.PUNCTUATION,
+ tokens: ["a", "'", "ll", "!", "!", "to", "?", "'", "d", "'", "'", "d", "of", ",", "can", "'", "t", "."],
+ ids: [101, 1037, 1005, 2222, 999, 999, 2000, 1029, 1005, 1040, 1005, 1005, 1040, 1997, 1010, 2064, 1005, 1056, 1012, 102],
+ decoded: "[CLS] a'll!! to?'d'' d of, can't. [SEP]",
+ },
+ PYTHON_CODE: {
+ text: BASE_TEST_STRINGS.PYTHON_CODE,
+ tokens: ["def", "main", "(", ")", ":", "pass"],
+ ids: [101, 13366, 2364, 1006, 1007, 1024, 3413, 102],
+ decoded: "[CLS] def main ( ) : pass [SEP]",
+ },
+ JAVASCRIPT_CODE: {
+ text: BASE_TEST_STRINGS.JAVASCRIPT_CODE,
+ tokens: ["let", "a", "=", "ob", "##j", ".", "to", "##st", "##ring", "(", ")", ";", "to", "##st", "##ring", "(", ")", ";"],
+ ids: [101, 2292, 1037, 1027, 27885, 3501, 1012, 2000, 3367, 4892, 1006, 1007, 1025, 2000, 3367, 4892, 1006, 1007, 1025, 102],
+ decoded: "[CLS] let a = obj. tostring ( ) ; tostring ( ) ; [SEP]",
+ },
+ NEWLINES: {
+ text: BASE_TEST_STRINGS.NEWLINES,
+ tokens: ["this", "is", "a", "test", "."],
+ ids: [101, 2023, 2003, 1037, 3231, 1012, 102],
+ decoded: "[CLS] this is a test. [SEP]",
+ },
+ BASIC: {
+ text: BASE_TEST_STRINGS.BASIC,
+ tokens: ["unwanted", ",", "running"],
+ ids: [101, 18162, 1010, 2770, 102],
+ decoded: "[CLS] unwanted, running [SEP]",
+ },
+ CONTROL_TOKENS: {
+ text: BASE_TEST_STRINGS.CONTROL_TOKENS,
+ tokens: ["123"],
+ ids: [101, 13138, 102],
+ decoded: "[CLS] 123 [SEP]",
+ },
+ HELLO_WORLD_TITLECASE: {
+ text: BASE_TEST_STRINGS.HELLO_WORLD_TITLECASE,
+ tokens: ["hello", "world"],
+ ids: [101, 7592, 2088, 102],
+ decoded: "[CLS] hello world [SEP]",
+ },
+ HELLO_WORLD_LOWERCASE: {
+ text: BASE_TEST_STRINGS.HELLO_WORLD_LOWERCASE,
+ tokens: ["hello", "world"],
+ ids: [101, 7592, 2088, 102],
+ decoded: "[CLS] hello world [SEP]",
+ },
+ CHINESE_ONLY: {
+ text: BASE_TEST_STRINGS.CHINESE_ONLY,
+ tokens: ["\u751f", "[UNK]", "\u7684", "\u771f", "[UNK]", "[UNK]"],
+ ids: [101, 1910, 100, 1916, 1921, 100, 100, 102],
+ decoded: "[CLS] \u751f [UNK] \u7684 \u771f [UNK] [UNK] [SEP]",
+ },
+ LEADING_SPACE: {
+ text: BASE_TEST_STRINGS.LEADING_SPACE,
+ tokens: ["leading", "space"],
+ ids: [101, 2877, 2686, 102],
+ decoded: "[CLS] leading space [SEP]",
+ },
+ TRAILING_SPACE: {
+ text: BASE_TEST_STRINGS.TRAILING_SPACE,
+ tokens: ["trailing", "space"],
+ ids: [101, 12542, 2686, 102],
+ decoded: "[CLS] trailing space [SEP]",
+ },
+ DOUBLE_SPACE: {
+ text: BASE_TEST_STRINGS.DOUBLE_SPACE,
+ tokens: ["hi", "hello"],
+ ids: [101, 7632, 7592, 102],
+ decoded: "[CLS] hi hello [SEP]",
+ },
+ CURRENCY: {
+ text: BASE_TEST_STRINGS.CURRENCY,
+ tokens: ["test", "$", "1", "r", "##2", "#", "3", "\u20ac", "##4", "\u00a35", "\u00a5", "##6", "[UNK]", "\u20b9", "##8", "\u20b1", "##9", "test"],
+ ids: [101, 3231, 1002, 1015, 1054, 2475, 1001, 1017, 1574, 2549, 27813, 1071, 2575, 100, 1576, 2620, 1575, 2683, 3231, 102],
+ decoded: "[CLS] test $ 1 r2 # 3 \u20ac4 \u00a35 \u00a56 [UNK] \u20b98 \u20b19 test [SEP]",
+ },
+ CURRENCY_WITH_DECIMALS: {
+ text: BASE_TEST_STRINGS.CURRENCY_WITH_DECIMALS,
+ tokens: ["i", "bought", "an", "apple", "for", "$", "1", ".", "00", "at", "the", "store", "."],
+ ids: [101, 1045, 4149, 2019, 6207, 2005, 1002, 1015, 1012, 4002, 2012, 1996, 3573, 1012, 102],
+ decoded: "[CLS] i bought an apple for $ 1. 00 at the store. [SEP]",
+ },
+ ELLIPSIS: {
+ text: BASE_TEST_STRINGS.ELLIPSIS,
+ tokens: ["you", "\u2026"],
+ ids: [101, 2017, 1529, 102],
+ decoded: "[CLS] you \u2026 [SEP]",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS,
+ tokens: ["you", "\u2026"],
+ ids: [101, 2017, 1529, 102],
+ decoded: "[CLS] you \u2026 [SEP]",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS_2: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS_2,
+ tokens: ["you", "\u2026", "you", "\u2026"],
+ ids: [101, 2017, 1529, 2017, 1529, 102],
+ decoded: "[CLS] you \u2026 you \u2026 [SEP]",
+ },
+ TILDE_NORMALIZATION: {
+ text: BASE_TEST_STRINGS.TILDE_NORMALIZATION,
+ tokens: ["weird", "\uff5e", "edge", "\uff5e", "case"],
+ ids: [101, 6881, 1995, 3341, 1995, 2553, 102],
+ decoded: "[CLS] weird \uff5e edge \uff5e case [SEP]",
+ },
+ SPIECE_UNDERSCORE: {
+ text: BASE_TEST_STRINGS.SPIECE_UNDERSCORE,
+ tokens: ["[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "."],
+ ids: [101, 100, 100, 100, 100, 100, 1012, 102],
+ decoded: "[CLS] [UNK] [UNK] [UNK] [UNK] [UNK]. [SEP]",
+ },
+ POPULAR_EMOJIS: {
+ text: BASE_TEST_STRINGS.POPULAR_EMOJIS,
+ tokens: ["[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]"],
+ ids: [101, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 102],
+ decoded: "[CLS] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [SEP]",
+ },
+ MULTIBYTE_EMOJIS: {
+ text: BASE_TEST_STRINGS.MULTIBYTE_EMOJIS,
+ tokens: ["[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]"],
+ ids: [101, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 102],
+ decoded: "[CLS] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [SEP]",
+ },
+ CHINESE_LATIN_MIXED: {
+ text: BERT_TEST_STRINGS.CHINESE_LATIN_MIXED,
+ tokens: ["ah", "\u535a", "[UNK]", "z", "##z"],
+ ids: [101, 6289, 1786, 100, 1062, 2480, 102],
+ decoded: "[CLS] ah \u535a [UNK] zz [SEP]",
+ },
+ SIMPLE_WITH_ACCENTS: {
+ text: BERT_TEST_STRINGS.SIMPLE_WITH_ACCENTS,
+ tokens: ["hello"],
+ ids: [101, 7592, 102],
+ decoded: "[CLS] hello [SEP]",
+ },
+ MIXED_CASE_WITHOUT_ACCENTS: {
+ text: BERT_TEST_STRINGS.MIXED_CASE_WITHOUT_ACCENTS,
+ tokens: ["hello", "!", "how", "are", "you", "?"],
+ ids: [101, 7592, 999, 2129, 2024, 2017, 1029, 102],
+ decoded: "[CLS] hello! how are you? [SEP]",
+ },
+ MIXED_CASE_WITH_ACCENTS: {
+ text: BERT_TEST_STRINGS.MIXED_CASE_WITH_ACCENTS,
+ tokens: ["hall", "##o", "!", "how", "are", "you", "?"],
+ ids: [101, 2534, 2080, 999, 2129, 2024, 2017, 1029, 102],
+ decoded: "[CLS] hallo! how are you? [SEP]",
+ },
+ ONLY_WHITESPACE: {
+ text: BASE_TEST_STRINGS.ONLY_WHITESPACE,
+ tokens: [],
+ ids: [101, 102],
+ decoded: "[CLS] [SEP]",
+ },
+
+ TEXT_PAIR: {
+ text: "hello",
+ text_pair: "world",
+ tokens: ["hello", "world"],
+ ids: [101, 7592, 102, 2088, 102],
+ decoded: "[CLS] hello [SEP] world [SEP]",
+ },
+ },
+ "Xenova/bert-base-cased": {
+ JAVASCRIPT_CODE: {
+ text: BASE_TEST_STRINGS.JAVASCRIPT_CODE,
+ tokens: ["let", "a", "=", "o", "##b", "##j", ".", "to", "##S", "##tring", "(", ")", ";", "to", "##S", "##tring", "(", ")", ";"],
+ ids: [101, 1519, 170, 134, 184, 1830, 3361, 119, 1106, 1708, 28108, 113, 114, 132, 1106, 1708, 28108, 113, 114, 132, 102],
+ decoded: "[CLS] let a = obj. toString ( ) ; toString ( ) ; [SEP]",
+ },
+ BASIC: {
+ text: BASE_TEST_STRINGS.BASIC,
+ tokens: ["UN", "##wan", "##t\u00e9", "##d", ",", "running"],
+ ids: [101, 7414, 5491, 14608, 1181, 117, 1919, 102],
+ decoded: "[CLS] UNwant\u00e9d, running [SEP]",
+ },
+ CHINESE_ONLY: {
+ text: BASE_TEST_STRINGS.CHINESE_ONLY,
+ tokens: ["\u751f", "[UNK]", "[UNK]", "\u771f", "[UNK]", "[UNK]"],
+ ids: [101, 1056, 100, 100, 1061, 100, 100, 102],
+ decoded: "[CLS] \u751f [UNK] [UNK] \u771f [UNK] [UNK] [SEP]",
+ },
+ CURRENCY: {
+ text: BASE_TEST_STRINGS.CURRENCY,
+ tokens: ["test", "$", "1", "R", "##2", "#", "3", "\u20ac", "##4", "\u00a3", "##5", "\u00a5", "##6", "[UNK]", "\u20b9", "##8", "\u20b1", "##9", "test"],
+ ids: [101, 2774, 109, 122, 155, 1477, 108, 124, 836, 1527, 202, 1571, 203, 1545, 100, 838, 1604, 837, 1580, 2774, 102],
+ decoded: "[CLS] test $ 1 R2 # 3 \u20ac4 \u00a35 \u00a56 [UNK] \u20b98 \u20b19 test [SEP]",
+ },
+ CURRENCY_WITH_DECIMALS: {
+ text: BASE_TEST_STRINGS.CURRENCY_WITH_DECIMALS,
+ tokens: ["I", "bought", "an", "apple", "for", "$", "1", ".", "00", "at", "the", "store", "."],
+ ids: [101, 146, 3306, 1126, 12075, 1111, 109, 122, 119, 3135, 1120, 1103, 2984, 119, 102],
+ decoded: "[CLS] I bought an apple for $ 1. 00 at the store. [SEP]",
+ },
+ TILDE_NORMALIZATION: {
+ text: BASE_TEST_STRINGS.TILDE_NORMALIZATION,
+ tokens: ["weird", "[UNK]", "edge", "[UNK]", "case"],
+ ids: [101, 6994, 100, 2652, 100, 1692, 102],
+ decoded: "[CLS] weird [UNK] edge [UNK] case [SEP]",
+ },
+ CHINESE_LATIN_MIXED: {
+ text: BERT_TEST_STRINGS.CHINESE_LATIN_MIXED,
+ tokens: ["ah", "[UNK]", "[UNK]", "z", "##z"],
+ ids: [101, 18257, 100, 100, 195, 1584, 102],
+ decoded: "[CLS] ah [UNK] [UNK] zz [SEP]",
+ },
+ SIMPLE_WITH_ACCENTS: {
+ text: BERT_TEST_STRINGS.SIMPLE_WITH_ACCENTS,
+ tokens: ["H", "##\u00e9", "##llo"],
+ ids: [101, 145, 2744, 6643, 102],
+ decoded: "[CLS] H\u00e9llo [SEP]",
+ },
+ },
+
+ "Xenova/bert-base-multilingual-cased-ner-hrl": {
+ SIMPLE: {
+ text: BASE_TEST_STRINGS.SIMPLE,
+ tokens: ["How", "are", "you", "doing", "?"],
+ ids: [101, 14962, 10301, 13028, 30918, 136, 102],
+ decoded: "[CLS] How are you doing? [SEP]",
+ },
+ SIMPLE_WITH_PUNCTUATION: {
+ text: BASE_TEST_STRINGS.SIMPLE_WITH_PUNCTUATION,
+ tokens: ["You", "should", "'", "ve", "done", "this"],
+ ids: [101, 11065, 14819, 112, 10323, 20378, 10531, 102],
+ decoded: "[CLS] You should've done this [SEP]",
+ },
+ TEXT_WITH_NUMBERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_NUMBERS,
+ tokens: ["The", "company", "was", "founded", "in", "2016", "."],
+ ids: [101, 10117, 12100, 10134, 14078, 10106, 10255, 119, 102],
+ decoded: "[CLS] The company was founded in 2016. [SEP]",
+ },
+ PUNCTUATION: {
+ text: BASE_TEST_STRINGS.PUNCTUATION,
+ tokens: ["A", "'", "ll", "!", "!", "to", "?", "'", "d", "'", "'", "d", "of", ",", "can", "'", "t", "."],
+ ids: [101, 138, 112, 22469, 106, 106, 10114, 136, 112, 172, 112, 112, 172, 10108, 117, 10944, 112, 188, 119, 102],
+ decoded: "[CLS] A'll!! to?'d'' d of, can't. [SEP]",
+ },
+ JAVASCRIPT_CODE: {
+ text: BASE_TEST_STRINGS.JAVASCRIPT_CODE,
+ tokens: ["let", "a", "=", "ob", "##j", ".", "to", "##S", "##trin", "##g", "(", ")", ";", "to", "##S", "##trin", "##g", "(", ")", ";"],
+ ids: [101, 13595, 169, 134, 17339, 10418, 119, 10114, 10731, 109163, 10240, 113, 114, 132, 10114, 10731, 109163, 10240, 113, 114, 132, 102],
+ decoded: "[CLS] let a = obj. toString ( ) ; toString ( ) ; [SEP]",
+ },
+ NEWLINES: {
+ text: BASE_TEST_STRINGS.NEWLINES,
+ tokens: ["This", "is", "a", "test", "."],
+ ids: [101, 10747, 10124, 169, 15839, 119, 102],
+ decoded: "[CLS] This is a test. [SEP]",
+ },
+ BASIC: {
+ text: BASE_TEST_STRINGS.BASIC,
+ tokens: ["UN", "##want", "##\u00e9d", ",", "running"],
+ ids: [101, 26578, 104216, 84193, 117, 18020, 102],
+ decoded: "[CLS] UNwant\u00e9d, running [SEP]",
+ },
+ HELLO_WORLD_TITLECASE: {
+ text: BASE_TEST_STRINGS.HELLO_WORLD_TITLECASE,
+ tokens: ["Hello", "World"],
+ ids: [101, 31178, 10315, 102],
+ decoded: "[CLS] Hello World [SEP]",
+ },
+ HELLO_WORLD_LOWERCASE: {
+ text: BASE_TEST_STRINGS.HELLO_WORLD_LOWERCASE,
+ tokens: ["hell", "##o", "world"],
+ ids: [101, 61694, 10133, 11356, 102],
+ decoded: "[CLS] hello world [SEP]",
+ },
+ CHINESE_ONLY: {
+ text: BASE_TEST_STRINGS.CHINESE_ONLY,
+ tokens: ["\u751f", "\u6d3b", "\u7684", "\u771f", "\u8c1b", "\u662f"],
+ ids: [101, 5600, 4978, 5718, 5769, 7378, 4380, 102],
+ decoded: "[CLS] \u751f \u6d3b \u7684 \u771f \u8c1b \u662f [SEP]",
+ },
+ TRAILING_SPACE: {
+ text: BASE_TEST_STRINGS.TRAILING_SPACE,
+ tokens: ["trail", "##ing", "space"],
+ ids: [101, 56559, 10230, 16199, 102],
+ decoded: "[CLS] trailing space [SEP]",
+ },
+ DOUBLE_SPACE: {
+ text: BASE_TEST_STRINGS.DOUBLE_SPACE,
+ tokens: ["Hi", "Hello"],
+ ids: [101, 20065, 31178, 102],
+ decoded: "[CLS] Hi Hello [SEP]",
+ },
+ CURRENCY: {
+ text: BASE_TEST_STRINGS.CURRENCY,
+ tokens: ["test", "$", "1", "R2", "#", "3", "\u20ac", "##4", "\u00a3", "##5", "\u00a5", "##6", "[UNK]", "\u20b9", "##8", "[UNK]", "test"],
+ ids: [101, 15839, 109, 122, 94000, 108, 124, 1775, 11011, 201, 11166, 202, 11211, 100, 1776, 11396, 100, 15839, 102],
+ decoded: "[CLS] test $ 1 R2 # 3 \u20ac4 \u00a35 \u00a56 [UNK] \u20b98 [UNK] test [SEP]",
+ },
+ CURRENCY_WITH_DECIMALS: {
+ text: BASE_TEST_STRINGS.CURRENCY_WITH_DECIMALS,
+ tokens: ["I", "bought", "an", "app", "##le", "for", "$", "1", ".", "00", "at", "the", "store", "."],
+ ids: [101, 146, 28870, 10151, 72894, 10284, 10142, 109, 122, 119, 11025, 10160, 10105, 13708, 119, 102],
+ decoded: "[CLS] I bought an apple for $ 1. 00 at the store. [SEP]",
+ },
+ ELLIPSIS: {
+ text: BASE_TEST_STRINGS.ELLIPSIS,
+ tokens: ["you", "[UNK]"],
+ ids: [101, 13028, 100, 102],
+ decoded: "[CLS] you [UNK] [SEP]",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS,
+ tokens: ["you", "[UNK]"],
+ ids: [101, 13028, 100, 102],
+ decoded: "[CLS] you [UNK] [SEP]",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS_2: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS_2,
+ tokens: ["you", "[UNK]", "you", "[UNK]"],
+ ids: [101, 13028, 100, 13028, 100, 102],
+ decoded: "[CLS] you [UNK] you [UNK] [SEP]",
+ },
+ TILDE_NORMALIZATION: {
+ text: BASE_TEST_STRINGS.TILDE_NORMALIZATION,
+ tokens: ["wei", "##rd", "\uff5e", "edge", "\uff5e", "case"],
+ ids: [101, 86981, 12023, 10096, 30599, 10096, 13474, 102],
+ decoded: "[CLS] weird \uff5e edge \uff5e case [SEP]",
+ },
+ CHINESE_LATIN_MIXED: {
+ text: BERT_TEST_STRINGS.CHINESE_LATIN_MIXED,
+ tokens: ["ah", "\u535a", "\u63a8", "z", "##z"],
+ ids: [101, 69863, 2684, 4163, 194, 10305, 102],
+ decoded: "[CLS] ah \u535a \u63a8 zz [SEP]",
+ },
+ SIMPLE_WITH_ACCENTS: {
+ text: BERT_TEST_STRINGS.SIMPLE_WITH_ACCENTS,
+ tokens: ["H", "##\u00e9l", "##lo"],
+ ids: [101, 145, 24817, 10715, 102],
+ decoded: "[CLS] H\u00e9llo [SEP]",
+ },
+ MIXED_CASE_WITHOUT_ACCENTS: {
+ text: BERT_TEST_STRINGS.MIXED_CASE_WITHOUT_ACCENTS,
+ tokens: ["He", "##LL", "##o", "!", "how", "Are", "yo", "##U", "?"],
+ ids: [101, 10357, 82834, 10133, 106, 14796, 13491, 13672, 12022, 136, 102],
+ decoded: "[CLS] HeLLo! how Are yoU? [SEP]",
+ },
+ MIXED_CASE_WITH_ACCENTS: {
+ text: BERT_TEST_STRINGS.MIXED_CASE_WITH_ACCENTS,
+ tokens: ["H", "##\u00e4", "##LL", "##o", "!", "how", "Are", "yo", "##U", "?"],
+ ids: [101, 145, 11013, 82834, 10133, 106, 14796, 13491, 13672, 12022, 136, 102],
+ decoded: "[CLS] H\u00e4LLo! how Are yoU? [SEP]",
+ },
+ },
+ "Xenova/paraphrase-multilingual-MiniLM-L12-v2": {
+ SIMPLE: {
+ text: BASE_TEST_STRINGS.SIMPLE,
+ tokens: ["\u2581How", "\u2581are", "\u2581you", "\u2581doing", "?"],
+ ids: [0, 11249, 621, 398, 20594, 32, 2],
+ decoded: " How are you doing? ",
+ },
+ SIMPLE_WITH_PUNCTUATION: {
+ text: BASE_TEST_STRINGS.SIMPLE_WITH_PUNCTUATION,
+ tokens: ["\u2581You", "\u2581should", "'", "ve", "\u2581done", "\u2581this"],
+ ids: [0, 2583, 5608, 25, 272, 16940, 903, 2],
+ decoded: " You should've done this ",
+ },
+ TEXT_WITH_NUMBERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_NUMBERS,
+ tokens: ["\u2581The", "\u2581company", "\u2581was", "\u2581found", "ed", "\u2581in", "\u25812016."],
+ ids: [0, 581, 14380, 509, 14037, 297, 23, 6360, 2],
+ decoded: " The company was founded in 2016. ",
+ },
+ PUNCTUATION: {
+ text: BASE_TEST_STRINGS.PUNCTUATION,
+ tokens: ["\u2581A", "\u2581'", "ll", "\u2581!!", "to", "?", "'", "d", "''", "d", "\u2581of", ",", "\u2581can", "'", "t", "."],
+ ids: [0, 62, 242, 1181, 6506, 188, 32, 25, 71, 4765, 71, 111, 4, 831, 25, 18, 5, 2],
+ decoded: " A 'll!!to?'d''d of, can't. ",
+ },
+ PYTHON_CODE: {
+ text: BASE_TEST_STRINGS.PYTHON_CODE,
+ tokens: ["\u2581de", "f", "\u2581main", "(", "):", "\u2581pass"],
+ ids: [0, 8, 420, 5201, 132, 2077, 27875, 2],
+ decoded: " def main(): pass ",
+ },
+ JAVASCRIPT_CODE: {
+ text: BASE_TEST_STRINGS.JAVASCRIPT_CODE,
+ tokens: ["\u2581let", "\u2581a", "\u2581=", "\u2581ob", "j", ".", "to", "Str", "ing", "(", ");", "\u2581to", "Str", "ing", "(", ");"],
+ ids: [0, 2633, 10, 2203, 995, 170, 5, 188, 71713, 214, 132, 3142, 47, 71713, 214, 132, 3142, 2],
+ decoded: " let a = obj.toString(); toString(); ",
+ },
+ NEWLINES: {
+ text: BASE_TEST_STRINGS.NEWLINES,
+ tokens: ["\u2581This", "\u2581is", "\u2581a", "\u2581test", "."],
+ ids: [0, 3293, 83, 10, 3034, 5, 2],
+ decoded: " This is a test. ",
+ },
+ BASIC: {
+ text: BASE_TEST_STRINGS.BASIC,
+ tokens: ["\u2581UN", "wan", "t\u00e9", "d", ",", "run", "ning"],
+ ids: [0, 8274, 3206, 2312, 71, 4, 16428, 592, 2],
+ decoded: " UNwant\u00e9d,running ",
+ },
+ CONTROL_TOKENS: {
+ text: BASE_TEST_STRINGS.CONTROL_TOKENS,
+ tokens: ["\u25811", "\u0000", "2", "\u25813"],
+ ids: [0, 106, 3, 304, 138, 2],
+ decoded: " 12 3 ",
+ },
+ HELLO_WORLD_TITLECASE: {
+ text: BASE_TEST_STRINGS.HELLO_WORLD_TITLECASE,
+ tokens: ["\u2581Hello", "\u2581World"],
+ ids: [0, 35378, 6661, 2],
+ decoded: " Hello World ",
+ },
+ HELLO_WORLD_LOWERCASE: {
+ text: BASE_TEST_STRINGS.HELLO_WORLD_LOWERCASE,
+ tokens: ["\u2581hell", "o", "\u2581world"],
+ ids: [0, 33600, 31, 8999, 2],
+ decoded: " hello world ",
+ },
+ CHINESE_ONLY: {
+ text: BASE_TEST_STRINGS.CHINESE_ONLY,
+ tokens: ["\u2581", "\u751f\u6d3b\u7684", "\u771f", "\u8c1b", "\u662f"],
+ ids: [0, 6, 62668, 5364, 245875, 354, 2],
+ decoded: " \u751f\u6d3b\u7684\u771f\u8c1b\u662f ",
+ },
+ LEADING_SPACE: {
+ text: BASE_TEST_STRINGS.LEADING_SPACE,
+ tokens: ["\u2581leading", "\u2581space"],
+ ids: [0, 105207, 32628, 2],
+ decoded: " leading space ",
+ },
+ TRAILING_SPACE: {
+ text: BASE_TEST_STRINGS.TRAILING_SPACE,
+ tokens: ["\u2581trail", "ing", "\u2581space"],
+ ids: [0, 141037, 214, 32628, 2],
+ decoded: " trailing space ",
+ },
+ DOUBLE_SPACE: {
+ text: BASE_TEST_STRINGS.DOUBLE_SPACE,
+ tokens: ["\u2581Hi", "\u2581Hello"],
+ ids: [0, 2673, 35378, 2],
+ decoded: " Hi Hello ",
+ },
+ CURRENCY: {
+ text: BASE_TEST_STRINGS.CURRENCY,
+ tokens: ["\u2581test", "\u2581$1", "\u2581R", "2", "\u2581#3", "\u2581\u20ac", "4", "\u2581\u00a3", "5", "\u2581", "\u00a5", "6", "\u2581", "\u20a3", "7", "\u2581\u20b9", "8", "\u2581", "\u20b1", "9", "\u2581test"],
+ ids: [0, 3034, 38629, 627, 304, 111378, 2505, 617, 11762, 758, 6, 32389, 910, 6, 3, 966, 87316, 1019, 6, 247425, 1126, 3034, 2],
+ decoded: " test $1 R2 #3 \u20ac4 \u00a35 \u00a56 7 \u20b98 \u20b19 test ",
+ },
+ CURRENCY_WITH_DECIMALS: {
+ text: BASE_TEST_STRINGS.CURRENCY_WITH_DECIMALS,
+ tokens: ["\u2581I", "\u2581bought", "\u2581an", "\u2581apple", "\u2581for", "\u2581$", "1.00", "\u2581at", "\u2581the", "\u2581store", "."],
+ ids: [0, 87, 123997, 142, 108787, 100, 3650, 146533, 99, 70, 4343, 5, 2],
+ decoded: " I bought an apple for $1.00 at the store. ",
+ },
+ ELLIPSIS: {
+ text: BASE_TEST_STRINGS.ELLIPSIS,
+ tokens: ["\u2581you", "..."],
+ ids: [0, 398, 27, 2],
+ decoded: " you... ",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS,
+ tokens: ["\u2581you", "..."],
+ ids: [0, 398, 27, 2],
+ decoded: " you... ",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS_2: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS_2,
+ tokens: ["\u2581you", "...", "\u2581you", "..."],
+ ids: [0, 398, 27, 398, 27, 2],
+ decoded: " you... you... ",
+ },
+ TILDE_NORMALIZATION: {
+ text: BASE_TEST_STRINGS.TILDE_NORMALIZATION,
+ tokens: ["\u2581weird", "\u2581", "\uff5e", "\u2581edge", "\u2581", "\uff5e", "\u2581case"],
+ ids: [0, 179459, 6, 6087, 121303, 6, 6087, 7225, 2],
+ decoded: " weird \uff5e edge \uff5e case ",
+ },
+ SPIECE_UNDERSCORE: {
+ text: BASE_TEST_STRINGS.SPIECE_UNDERSCORE,
+ tokens: ["\u2581This", "\u2581is", "\u2581a", "\u2581test", "\u2581", "."],
+ ids: [0, 3293, 83, 10, 3034, 6, 5, 2],
+ decoded: " This is a test. ",
+ },
+ POPULAR_EMOJIS: {
+ text: BASE_TEST_STRINGS.POPULAR_EMOJIS,
+ tokens: ["\u2581", "\ud83d\ude02", "\u2581", "\ud83d\udc4d", "\u2581", "\ud83e\udd23", "\u2581", "\ud83d\ude0d", "\u2581", "\ud83d\ude2d", "\u2581", "\ud83c\udf89", "\u2581", "\ud83d\ude4f", "\u2581", "\ud83d\ude0a", "\u2581", "\ud83d\udd25", "\u2581", "\ud83d\ude01", "\u2581", "\ud83d\ude05", "\u2581", "\ud83e\udd17", "\u2581", "\ud83d\ude06", "\u2581", "\ud83d\udc4f", "\u2581\u2764", "\ufe0f", "\u2581", "\ud83d\udc9c", "\u2581", "\ud83d\udc9a", "\u2581", "\ud83d\udc97", "\u2581", "\ud83d\udc99", "\u2581", "\ud83d\udda4", "\u2581", "\ud83d\ude0e", "\u2581", "\ud83d\udc4c", "\u2581", "\ud83e\udd73", "\u2581", "\ud83d\udcaa", "\u2581", "\u2728", "\u2581", "\ud83d\udc49", "\u2581", "\ud83d\udc40", "\u2581", "\ud83d\udcaf", "\u2581", "\ud83c\udf88", "\u2581", "\ud83d\ude48", "\u2581", "\ud83d\ude4c", "\u2581", "\ud83d\udc80", "\u2581", "\ud83d\udc47", "\u2581", "\ud83d\udc4b", "\u2581", "\u2705", "\u2581", "\ud83c\udf81", "\u2581", "\ud83c\udf1e", "\u2581", "\ud83c\udf38", "\u2581", "\ud83d\udcb0"],
+ ids: [0, 6, 115114, 6, 118280, 6, 243385, 6, 84464, 6, 232773, 6, 243816, 6, 113612, 6, 82803, 6, 222326, 6, 201344, 6, 239569, 6, 243544, 6, 191876, 6, 243404, 49933, 15755, 6, 244233, 6, 244162, 6, 244181, 6, 243892, 6, 245820, 6, 161546, 6, 204811, 6, 3, 6, 238992, 6, 167474, 6, 120242, 6, 245561, 6, 244864, 6, 246144, 6, 244459, 6, 244703, 6, 246887, 6, 144400, 6, 246511, 6, 142325, 6, 244230, 6, 245559, 6, 243374, 6, 245200, 2],
+ decoded: " \ud83d\ude02 \ud83d\udc4d \ud83e\udd23 \ud83d\ude0d \ud83d\ude2d \ud83c\udf89 \ud83d\ude4f \ud83d\ude0a \ud83d\udd25 \ud83d\ude01 \ud83d\ude05 \ud83e\udd17 \ud83d\ude06 \ud83d\udc4f \u2764\ufe0f \ud83d\udc9c \ud83d\udc9a \ud83d\udc97 \ud83d\udc99 \ud83d\udda4 \ud83d\ude0e \ud83d\udc4c \ud83d\udcaa \u2728 \ud83d\udc49 \ud83d\udc40 \ud83d\udcaf \ud83c\udf88 \ud83d\ude48 \ud83d\ude4c \ud83d\udc80 \ud83d\udc47 \ud83d\udc4b \u2705 \ud83c\udf81 \ud83c\udf1e \ud83c\udf38 \ud83d\udcb0 ",
+ },
+ MULTIBYTE_EMOJIS: {
+ text: BASE_TEST_STRINGS.MULTIBYTE_EMOJIS,
+ tokens: ["\u2581", "\u2728", "\u2581", "\ud83e\udd17", "\u2581", "\ud83d\udc41", "\ufe0f", "\u2581", "\ud83d\udc71", "\ud83c\udffb", "\u2581", "\ud83d\udd75", "\u2581", "\u2642", "\ufe0f", "\u2581", "\ud83e\uddd9", "\ud83c\udffb", "\u2581", "\u2642", "\u2581", "\ud83d\udc68", "\ud83c\udffb", "\u2581", "\ud83c\udf3e", "\u2581", "\ud83e\uddd1", "\u2581", "\ud83e\udd1d", "\u2581", "\ud83e\uddd1", "\u2581", "\ud83d\udc69", "\u2581\u2764", "\u2581", "\ud83d\udc8b", "\u2581", "\ud83d\udc68", "\u2581", "\ud83d\udc69", "\u2581", "\ud83d\udc69", "\u2581", "\ud83d\udc67", "\u2581", "\ud83d\udc66", "\u2581", "\ud83e\uddd1", "\ud83c\udffb", "\u2581", "\ud83e\udd1d", "\u2581", "\ud83e\uddd1", "\ud83c\udffb", "\u2581", "\ud83c\udff4\udb40\udc67\udb40\udc62\udb40\udc65\udb40\udc6e\udb40\udc67\udb40\udc7f", "\u2581", "\ud83d\udc68", "\ud83c\udffb", "\u2581\u2764", "\ufe0f", "\u2581", "\ud83d\udc8b", "\u2581", "\ud83d\udc68", "\ud83c\udffc"],
+ ids: [0, 6, 167474, 6, 243544, 6, 246984, 15755, 6, 247201, 79500, 6, 248325, 6, 228250, 15755, 6, 3, 79500, 6, 228250, 6, 244314, 79500, 6, 246529, 6, 3, 6, 247443, 6, 3, 6, 244785, 49933, 6, 244960, 6, 244314, 6, 244785, 6, 244785, 6, 245719, 6, 246167, 6, 3, 79500, 6, 247443, 6, 3, 79500, 6, 3, 6, 244314, 79500, 49933, 15755, 6, 244960, 6, 244314, 239719, 2],
+ decoded: " \u2728 \ud83e\udd17 \ud83d\udc41\ufe0f \ud83d\udc71\ud83c\udffb \ud83d\udd75 \u2642\ufe0f \ud83c\udffb \u2642 \ud83d\udc68\ud83c\udffb \ud83c\udf3e \ud83e\udd1d \ud83d\udc69 \u2764 \ud83d\udc8b \ud83d\udc68 \ud83d\udc69 \ud83d\udc69 \ud83d\udc67 \ud83d\udc66 \ud83c\udffb \ud83e\udd1d \ud83c\udffb \ud83d\udc68\ud83c\udffb \u2764\ufe0f \ud83d\udc8b \ud83d\udc68\ud83c\udffc ",
+ },
+ CHINESE_LATIN_MIXED: {
+ text: BERT_TEST_STRINGS.CHINESE_LATIN_MIXED,
+ tokens: ["\u2581ah", "\u535a", "\u63a8", "zz"],
+ ids: [0, 1263, 11173, 10238, 13894, 2],
+ decoded: " ah\u535a\u63a8zz ",
+ },
+ SIMPLE_WITH_ACCENTS: {
+ text: BERT_TEST_STRINGS.SIMPLE_WITH_ACCENTS,
+ tokens: ["\u2581H\u00e9", "llo"],
+ ids: [0, 88064, 9284, 2],
+ decoded: " H\u00e9llo ",
+ },
+ MIXED_CASE_WITHOUT_ACCENTS: {
+ text: BERT_TEST_STRINGS.MIXED_CASE_WITHOUT_ACCENTS,
+ tokens: ["\u2581He", "LL", "o", "!", "how", "\u2581Are", "\u2581yo", "U", "?"],
+ ids: [0, 1529, 23708, 31, 38, 47251, 15901, 3005, 1062, 32, 2],
+ decoded: " HeLLo!how Are yoU? ",
+ },
+ MIXED_CASE_WITH_ACCENTS: {
+ text: BERT_TEST_STRINGS.MIXED_CASE_WITH_ACCENTS,
+ tokens: ["\u2581H\u00e4", "LL", "o", "!", "how", "\u2581Are", "\u2581yo", "U", "?"],
+ ids: [0, 28863, 23708, 31, 38, 47251, 15901, 3005, 1062, 32, 2],
+ decoded: " H\u00e4LLo!how Are yoU? ",
+ },
+ },
+ "Xenova/bert-base-multilingual-uncased-sentiment": {
+ JAVASCRIPT_CODE: {
+ text: BASE_TEST_STRINGS.JAVASCRIPT_CODE,
+ tokens: ["let", "a", "=", "ob", "##j", ".", "tos", "##tri", "##ng", "(", ")", ";", "tos", "##tri", "##ng", "(", ")", ";"],
+ ids: [101, 12421, 143, 134, 15547, 10428, 119, 53564, 27711, 10422, 113, 114, 132, 53564, 27711, 10422, 113, 114, 132, 102],
+ decoded: "[CLS] let a = obj. tostring ( ) ; tostring ( ) ; [SEP]",
+ },
+ BASIC: {
+ text: BASE_TEST_STRINGS.BASIC,
+ tokens: ["un", "##wan", "##ted", ",", "running"],
+ ids: [101, 10119, 15134, 11894, 117, 16484, 102],
+ decoded: "[CLS] unwanted, running [SEP]",
+ },
+ CURRENCY: {
+ text: BASE_TEST_STRINGS.CURRENCY,
+ tokens: ["test", "$", "1", "r2", "#", "3", "\u20ac", "##4", "\u00a3", "##5", "\u00a5", "##6", "[UNK]", "\u20b9", "##8", "\u20b1", "##9", "test"],
+ ids: [101, 14084, 109, 122, 85583, 108, 124, 1329, 11124, 175, 11301, 177, 11325, 100, 1332, 11544, 1330, 11518, 14084, 102],
+ decoded: "[CLS] test $ 1 r2 # 3 \u20ac4 \u00a35 \u00a56 [UNK] \u20b98 \u20b19 test [SEP]",
+ },
+ },
+ "Xenova/multilingual-e5-small": {
+ TRAILING_SPACE: {
+ text: BASE_TEST_STRINGS.TRAILING_SPACE,
+ tokens: ["\u2581trail", "ing", "\u2581space", "\u2581"],
+ ids: [0, 141037, 214, 32628, 6, 2],
+ decoded: " trailing space ",
+ },
+ ELLIPSIS: {
+ text: BASE_TEST_STRINGS.ELLIPSIS,
+ tokens: ["\u2581you", "...", "\u2581"],
+ ids: [0, 398, 27, 6, 2],
+ decoded: " you... ",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS,
+ tokens: ["\u2581you", "...", "\u2581"],
+ ids: [0, 398, 27, 6, 2],
+ decoded: " you... ",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS_2: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS_2,
+ tokens: ["\u2581you", "...", "\u2581you", "...", "\u2581"],
+ ids: [0, 398, 27, 398, 27, 6, 2],
+ decoded: " you... you... ",
+ },
+ MIXED_CASE_WITHOUT_ACCENTS: {
+ text: BERT_TEST_STRINGS.MIXED_CASE_WITHOUT_ACCENTS,
+ tokens: ["\u2581He", "LL", "o", "!", "how", "\u2581Are", "\u2581yo", "U", "?", "\u2581"],
+ ids: [0, 1529, 23708, 31, 38, 47251, 15901, 3005, 1062, 32, 6, 2],
+ decoded: " HeLLo!how Are yoU? ",
+ },
+ MIXED_CASE_WITH_ACCENTS: {
+ text: BERT_TEST_STRINGS.MIXED_CASE_WITH_ACCENTS,
+ tokens: ["\u2581H\u00e4", "LL", "o", "!", "how", "\u2581Are", "\u2581yo", "U", "?", "\u2581"],
+ ids: [0, 28863, 23708, 31, 38, 47251, 15901, 3005, 1062, 32, 6, 2],
+ decoded: " H\u00e4LLo!how Are yoU? ",
+ },
+ },
+ "Xenova/bge-small-zh-v1.5": {
+ SIMPLE: {
+ text: BASE_TEST_STRINGS.SIMPLE,
+ tokens: ["[UNK]", "are", "you", "doi", "##ng", "?"],
+ ids: [101, 100, 8995, 8357, 9962, 8291, 136, 102],
+ decoded: "[CLS] [UNK] are you doing? [SEP]",
+ },
+ SIMPLE_WITH_PUNCTUATION: {
+ text: BASE_TEST_STRINGS.SIMPLE_WITH_PUNCTUATION,
+ tokens: ["[UNK]", "sh", "##ould", "'", "ve", "don", "##e", "this"],
+ ids: [101, 100, 11167, 11734, 112, 12810, 9524, 8154, 8554, 102],
+ decoded: "[CLS] [UNK] should've done this [SEP]",
+ },
+ TEXT_WITH_NUMBERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_NUMBERS,
+ tokens: ["[UNK]", "company", "was", "f", "##ound", "##ed", "in", "2016", "."],
+ ids: [101, 100, 10007, 9947, 148, 11477, 8303, 8217, 8112, 119, 102],
+ decoded: "[CLS] [UNK] company was founded in 2016. [SEP]",
+ },
+ PUNCTUATION: {
+ text: BASE_TEST_STRINGS.PUNCTUATION,
+ tokens: ["[UNK]", "'", "ll", "!", "!", "to", "?", "'", "d", "'", "'", "d", "of", ",", "can", "'", "t", "."],
+ ids: [101, 100, 112, 10856, 106, 106, 8228, 136, 112, 146, 112, 112, 146, 8205, 117, 9109, 112, 162, 119, 102],
+ decoded: "[CLS] [UNK]'ll!! to?'d'' d of, can't. [SEP]",
+ },
+ PYTHON_CODE: {
+ text: BASE_TEST_STRINGS.PYTHON_CODE,
+ tokens: ["de", "##f", "main", "(", ")", ":", "pass"],
+ ids: [101, 8363, 8189, 9139, 113, 114, 131, 9703, 102],
+ decoded: "[CLS] def main ( ) : pass [SEP]",
+ },
+ JAVASCRIPT_CODE: {
+ text: BASE_TEST_STRINGS.JAVASCRIPT_CODE,
+ tokens: ["let", "a", "=", "ob", "##j", ".", "[UNK]", "(", ")", ";", "[UNK]", "(", ")", ";"],
+ ids: [101, 9946, 143, 134, 12639, 8334, 119, 100, 113, 114, 132, 100, 113, 114, 132, 102],
+ decoded: "[CLS] let a = obj. [UNK] ( ) ; [UNK] ( ) ; [SEP]",
+ },
+ NEWLINES: {
+ text: BASE_TEST_STRINGS.NEWLINES,
+ tokens: ["[UNK]", "is", "a", "test", "."],
+ ids: [101, 100, 8310, 143, 10060, 119, 102],
+ decoded: "[CLS] [UNK] is a test. [SEP]",
+ },
+ BASIC: {
+ text: BASE_TEST_STRINGS.BASIC,
+ tokens: ["[UNK]", ",", "running"],
+ ids: [101, 100, 117, 11620, 102],
+ decoded: "[CLS] [UNK], running [SEP]",
+ },
+ HELLO_WORLD_TITLECASE: {
+ text: BASE_TEST_STRINGS.HELLO_WORLD_TITLECASE,
+ tokens: ["[UNK]", "[UNK]"],
+ ids: [101, 100, 100, 102],
+ decoded: "[CLS] [UNK] [UNK] [SEP]",
+ },
+ LEADING_SPACE: {
+ text: BASE_TEST_STRINGS.LEADING_SPACE,
+ tokens: ["le", "##ad", "##ing", "space"],
+ ids: [101, 8983, 8695, 8221, 9634, 102],
+ decoded: "[CLS] leading space [SEP]",
+ },
+ TRAILING_SPACE: {
+ text: BASE_TEST_STRINGS.TRAILING_SPACE,
+ tokens: ["t", "##rail", "##ing", "space"],
+ ids: [101, 162, 12783, 8221, 9634, 102],
+ decoded: "[CLS] trailing space [SEP]",
+ },
+ DOUBLE_SPACE: {
+ text: BASE_TEST_STRINGS.DOUBLE_SPACE,
+ tokens: ["[UNK]", "[UNK]"],
+ ids: [101, 100, 100, 102],
+ decoded: "[CLS] [UNK] [UNK] [SEP]",
+ },
+ CURRENCY: {
+ text: BASE_TEST_STRINGS.CURRENCY,
+ tokens: ["test", "$", "1", "[UNK]", "#", "3", "\u20ac", "##4", "\u00a3", "##5", "\u00a5", "##6", "[UNK]", "[UNK]", "[UNK]", "test"],
+ ids: [101, 10060, 109, 122, 100, 108, 124, 359, 8159, 173, 8157, 175, 8158, 100, 100, 100, 10060, 102],
+ decoded: "[CLS] test $ 1 [UNK] # 3 \u20ac4 \u00a35 \u00a56 [UNK] [UNK] [UNK] test [SEP]",
+ },
+ CURRENCY_WITH_DECIMALS: {
+ text: BASE_TEST_STRINGS.CURRENCY_WITH_DECIMALS,
+ tokens: ["[UNK]", "bo", "##ugh", "##t", "an", "apple", "for", "$", "1", ".", "00", "at", "the", "store", "."],
+ ids: [101, 100, 11059, 12667, 8165, 9064, 8350, 8330, 109, 122, 119, 8136, 8243, 8174, 8719, 119, 102],
+ decoded: "[CLS] [UNK] bought an apple for $ 1. 00 at the store. [SEP]",
+ },
+ POPULAR_EMOJIS: {
+ text: BASE_TEST_STRINGS.POPULAR_EMOJIS,
+ tokens: ["\ud83d\ude02", "\ud83d\udc4d", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "\ud83d\udd25", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "\ud83d\ude0e", "[UNK]", "[UNK]", "[UNK]", "\u2728", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]"],
+ ids: [101, 8104, 8102, 100, 100, 100, 100, 100, 100, 8103, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 8105, 100, 100, 100, 501, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 102],
+ decoded: "[CLS] \ud83d\ude02 \ud83d\udc4d [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] \ud83d\udd25 [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] \ud83d\ude0e [UNK] [UNK] [UNK] \u2728 [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [SEP]",
+ },
+ MULTIBYTE_EMOJIS: {
+ text: BASE_TEST_STRINGS.MULTIBYTE_EMOJIS,
+ tokens: ["\u2728", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]"],
+ ids: [101, 501, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 102],
+ decoded: "[CLS] \u2728 [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [SEP]",
+ },
+ SIMPLE_WITH_ACCENTS: {
+ text: BERT_TEST_STRINGS.SIMPLE_WITH_ACCENTS,
+ tokens: ["[UNK]"],
+ ids: [101, 100, 102],
+ decoded: "[CLS] [UNK] [SEP]",
+ },
+ MIXED_CASE_WITHOUT_ACCENTS: {
+ text: BERT_TEST_STRINGS.MIXED_CASE_WITHOUT_ACCENTS,
+ tokens: ["[UNK]", "!", "how", "[UNK]", "[UNK]", "?"],
+ ids: [101, 100, 106, 9510, 100, 100, 136, 102],
+ decoded: "[CLS] [UNK]! how [UNK] [UNK]? [SEP]",
+ },
+ MIXED_CASE_WITH_ACCENTS: {
+ text: BERT_TEST_STRINGS.MIXED_CASE_WITH_ACCENTS,
+ tokens: ["[UNK]", "!", "how", "[UNK]", "[UNK]", "?"],
+ ids: [101, 100, 106, 9510, 100, 100, 136, 102],
+ decoded: "[CLS] [UNK]! how [UNK] [UNK]? [SEP]",
+ },
+ },
+ "Xenova/bge-base-zh-v1.5": {
+ SIMPLE: {
+ text: BASE_TEST_STRINGS.SIMPLE,
+ tokens: ["how", "are", "you", "doi", "##ng", "?"],
+ ids: [101, 9510, 8995, 8357, 9962, 8291, 136, 102],
+ decoded: "[CLS] how are you doing? [SEP]",
+ },
+ SIMPLE_WITH_PUNCTUATION: {
+ text: BASE_TEST_STRINGS.SIMPLE_WITH_PUNCTUATION,
+ tokens: ["you", "sh", "##ould", "'", "ve", "don", "##e", "this"],
+ ids: [101, 8357, 11167, 11734, 112, 12810, 9524, 8154, 8554, 102],
+ decoded: "[CLS] you should've done this [SEP]",
+ },
+ TEXT_WITH_NUMBERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_NUMBERS,
+ tokens: ["the", "company", "was", "f", "##ound", "##ed", "in", "2016", "."],
+ ids: [101, 8174, 10007, 9947, 148, 11477, 8303, 8217, 8112, 119, 102],
+ decoded: "[CLS] the company was founded in 2016. [SEP]",
+ },
+ BASIC: {
+ text: BASE_TEST_STRINGS.BASIC,
+ tokens: ["u", "##n", "##wan", "##ted", ",", "running"],
+ ids: [101, 163, 8171, 9951, 9255, 117, 11620, 102],
+ decoded: "[CLS] unwanted, running [SEP]",
+ },
+ CURRENCY: {
+ text: BASE_TEST_STRINGS.CURRENCY,
+ tokens: ["test", "$", "1", "r2", "#", "3", "\u20ac", "##4", "\u00a3", "##5", "\u00a5", "##6", "[UNK]", "[UNK]", "[UNK]", "test"],
+ ids: [101, 10060, 109, 122, 11345, 108, 124, 359, 8159, 173, 8157, 175, 8158, 100, 100, 100, 10060, 102],
+ decoded: "[CLS] test $ 1 r2 # 3 \u20ac4 \u00a35 \u00a56 [UNK] [UNK] [UNK] test [SEP]",
+ },
+ CURRENCY_WITH_DECIMALS: {
+ text: BASE_TEST_STRINGS.CURRENCY_WITH_DECIMALS,
+ tokens: ["i", "bo", "##ugh", "##t", "an", "apple", "for", "$", "1", ".", "00", "at", "the", "store", "."],
+ ids: [101, 151, 11059, 12667, 8165, 9064, 8350, 8330, 109, 122, 119, 8136, 8243, 8174, 8719, 119, 102],
+ decoded: "[CLS] i bought an apple for $ 1. 00 at the store. [SEP]",
+ },
+ POPULAR_EMOJIS: {
+ text: BASE_TEST_STRINGS.POPULAR_EMOJIS,
+ tokens: ["\ud83d\ude02", "\ud83d\udc4d", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "\ud83d\udd25", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "\u2764", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "\ud83d\ude0e", "[UNK]", "[UNK]", "[UNK]", "\u2728", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]"],
+ ids: [101, 8104, 8102, 100, 100, 100, 100, 100, 100, 8103, 100, 100, 100, 100, 100, 506, 100, 100, 100, 100, 100, 8105, 100, 100, 100, 501, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 102],
+ decoded: "[CLS] \ud83d\ude02 \ud83d\udc4d [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] \ud83d\udd25 [UNK] [UNK] [UNK] [UNK] [UNK] \u2764 [UNK] [UNK] [UNK] [UNK] [UNK] \ud83d\ude0e [UNK] [UNK] [UNK] \u2728 [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [SEP]",
+ },
+ },
+ "Xenova/indobert-base-p1": {
+ SIMPLE_WITH_PUNCTUATION: {
+ text: BASE_TEST_STRINGS.SIMPLE_WITH_PUNCTUATION,
+ tokens: ["you", "sho", "##uld", "'", "ve", "don", "##e", "this"],
+ ids: [2, 3299, 9596, 15370, 30463, 28239, 4081, 30357, 5379, 3],
+ decoded: "[CLS] you should've done this [SEP]",
+ },
+ TEXT_WITH_NUMBERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_NUMBERS,
+ tokens: ["the", "company", "was", "found", "##ed", "in", "2016", "."],
+ ids: [2, 1002, 9105, 2738, 11009, 133, 48, 1538, 30470, 3],
+ decoded: "[CLS] the company was founded in 2016. [SEP]",
+ },
+ JAVASCRIPT_CODE: {
+ text: BASE_TEST_STRINGS.JAVASCRIPT_CODE,
+ tokens: ["let", "a", "=", "ob", "##j", ".", "tos", "##trin", "##g", "(", ")", ";", "tos", "##trin", "##g", "(", ")", ";"],
+ ids: [2, 4734, 253, 30475, 559, 30372, 30470, 20498, 12448, 30365, 30464, 30465, 30473, 20498, 12448, 30365, 30464, 30465, 30473, 3],
+ decoded: "[CLS] let a = obj. tostring ( ) ; tostring ( ) ; [SEP]",
+ },
+ BASIC: {
+ text: BASE_TEST_STRINGS.BASIC,
+ tokens: ["un", "##wan", "##te", "##d", ",", "running"],
+ ids: [2, 78, 1322, 3298, 30364, 30468, 22715, 3],
+ decoded: "[CLS] unwanted, running [SEP]",
+ },
+ CHINESE_ONLY: {
+ text: BASE_TEST_STRINGS.CHINESE_ONLY,
+ tokens: ["[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]"],
+ ids: [2, 1, 1, 1, 1, 1, 1, 3],
+ decoded: "[CLS] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [SEP]",
+ },
+ LEADING_SPACE: {
+ text: BASE_TEST_STRINGS.LEADING_SPACE,
+ tokens: ["lead", "##ing", "space"],
+ ids: [2, 9196, 55, 14561, 3],
+ decoded: "[CLS] leading space [SEP]",
+ },
+ CURRENCY: {
+ text: BASE_TEST_STRINGS.CURRENCY,
+ tokens: ["test", "$", "1", "r", "##2", "#", "3", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "test"],
+ ids: [2, 4243, 30460, 111, 56, 30378, 30459, 283, 1, 1, 1, 1, 1, 1, 4243, 3],
+ decoded: "[CLS] test $ 1 r2 # 3 [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] test [SEP]",
+ },
+ CURRENCY_WITH_DECIMALS: {
+ text: BASE_TEST_STRINGS.CURRENCY_WITH_DECIMALS,
+ tokens: ["i", "bo", "##ught", "an", "apple", "for", "$", "1", ".", "00", "at", "the", "store", "."],
+ ids: [2, 89, 1880, 25009, 223, 7761, 1548, 30460, 111, 30470, 4230, 117, 1002, 8052, 30470, 3],
+ decoded: "[CLS] i bought an apple for $ 1. 00 at the store. [SEP]",
+ },
+ TILDE_NORMALIZATION: {
+ text: BASE_TEST_STRINGS.TILDE_NORMALIZATION,
+ tokens: ["wei", "##rd", "[UNK]", "edge", "[UNK]", "case"],
+ ids: [2, 27753, 12548, 1, 21418, 1, 13687, 3],
+ decoded: "[CLS] weird [UNK] edge [UNK] case [SEP]",
+ },
+ MIXED_CASE_WITH_ACCENTS: {
+ text: BERT_TEST_STRINGS.MIXED_CASE_WITH_ACCENTS,
+ tokens: ["hallo", "!", "how", "are", "you", "?"],
+ ids: [2, 19598, 30457, 11088, 5811, 3299, 30477, 3],
+ decoded: "[CLS] hallo! how are you? [SEP]",
+ },
+ },
+ "Xenova/spanbert-large-cased": {
+ JAVASCRIPT_CODE: {
+ text: BASE_TEST_STRINGS.JAVASCRIPT_CODE,
+ tokens: ["let", "a", "=", "o", "##b", "##j", ".", "to", "##st", "##ring", "(", ")", ";", "to", "##st", "##ring", "(", ")", ";"],
+ ids: [101, 1519, 170, 134, 184, 1830, 3361, 119, 1106, 2050, 3384, 113, 114, 132, 1106, 2050, 3384, 113, 114, 132, 102],
+ decoded: "[CLS] let a = obj. tostring ( ) ; tostring ( ) ; [SEP]",
+ },
+ CURRENCY: {
+ text: BASE_TEST_STRINGS.CURRENCY,
+ tokens: ["test", "$", "1", "r", "##2", "#", "3", "\u20ac", "##4", "\u00a3", "##5", "\u00a5", "##6", "[UNK]", "\u20b9", "##8", "\u20b1", "##9", "test"],
+ ids: [101, 2774, 109, 122, 187, 1477, 108, 124, 836, 1527, 202, 1571, 203, 1545, 100, 838, 1604, 837, 1580, 2774, 102],
+ decoded: "[CLS] test $ 1 r2 # 3 \u20ac4 \u00a35 \u00a56 [UNK] \u20b98 \u20b19 test [SEP]",
+ },
+ },
+ "Xenova/UMLSBert_ENG": {
+ JAVASCRIPT_CODE: {
+ text: BASE_TEST_STRINGS.JAVASCRIPT_CODE,
+ tokens: ["let", "a", "=", "obj", ".", "tos", "##tring", "(", ")", ";", "tos", "##tring", "(", ")", ";"],
+ ids: [2, 8894, 42, 32, 2473, 17, 22660, 23640, 11, 12, 30, 22660, 23640, 11, 12, 30, 3],
+ decoded: "[CLS] let a = obj. tostring ( ) ; tostring ( ) ; [SEP]",
+ },
+ HELLO_WORLD_TITLECASE: {
+ text: BASE_TEST_STRINGS.HELLO_WORLD_TITLECASE,
+ tokens: ["hel", "##lo", "world"],
+ ids: [2, 3018, 5368, 4517, 3],
+ decoded: "[CLS] hello world [SEP]",
+ },
+ HELLO_WORLD_LOWERCASE: {
+ text: BASE_TEST_STRINGS.HELLO_WORLD_LOWERCASE,
+ tokens: ["hel", "##lo", "world"],
+ ids: [2, 3018, 5368, 4517, 3],
+ decoded: "[CLS] hello world [SEP]",
+ },
+ DOUBLE_SPACE: {
+ text: BASE_TEST_STRINGS.DOUBLE_SPACE,
+ tokens: ["hi", "hel", "##lo"],
+ ids: [2, 11245, 3018, 5368, 3],
+ decoded: "[CLS] hi hello [SEP]",
+ },
+ CURRENCY: {
+ text: BASE_TEST_STRINGS.CURRENCY,
+ tokens: ["test", "$", "1", "r2", "#", "3", "\u20ac", "##4", "\u00a3", "##5", "\u00a5", "##6", "\u20a3", "##7", "\u20b9", "##8", "\u20b1", "##9", "test"],
+ ids: [2, 2313, 7, 20, 9663, 6, 22, 528, 1017, 74, 1009, 76, 1018, 524, 1019, 531, 1011, 529, 1038, 2313, 3],
+ decoded: "[CLS] test $ 1 r2 # 3 \u20ac4 \u00a35 \u00a56 \u20a37 \u20b98 \u20b19 test [SEP]",
+ },
+ TILDE_NORMALIZATION: {
+ text: BASE_TEST_STRINGS.TILDE_NORMALIZATION,
+ tokens: ["we", "##ir", "##d", "\uff5e", "edge", "\uff5e", "case"],
+ ids: [2, 1802, 1753, 1022, 943, 9676, 943, 2632, 3],
+ decoded: "[CLS] weird \uff5e edge \uff5e case [SEP]",
+ },
+ SIMPLE_WITH_ACCENTS: {
+ text: BERT_TEST_STRINGS.SIMPLE_WITH_ACCENTS,
+ tokens: ["hel", "##lo"],
+ ids: [2, 3018, 5368, 3],
+ decoded: "[CLS] hello [SEP]",
+ },
+ MIXED_CASE_WITHOUT_ACCENTS: {
+ text: BERT_TEST_STRINGS.MIXED_CASE_WITHOUT_ACCENTS,
+ tokens: ["hel", "##lo", "!", "how", "are", "you", "?"],
+ ids: [2, 3018, 5368, 5, 2135, 1810, 17915, 34, 3],
+ decoded: "[CLS] hello! how are you? [SEP]",
+ },
+ },
+ "Xenova/SapBERT-from-PubMedBERT-fulltext": {
+ CHINESE_ONLY: {
+ text: BASE_TEST_STRINGS.CHINESE_ONLY,
+ tokens: ["\u751f", "\u6d3b", "\u7684", "[UNK]", "[UNK]", "\u662f"],
+ ids: [2, 799, 776, 811, 1, 1, 731, 3],
+ decoded: "[CLS] \u751f \u6d3b \u7684 [UNK] [UNK] \u662f [SEP]",
+ },
+ CURRENCY: {
+ text: BASE_TEST_STRINGS.CURRENCY,
+ tokens: ["test", "$", "1", "r2", "#", "3", "\u20ac", "##4", "\u00a3", "##5", "\u00a5", "##6", "[UNK]", "\u20b9", "##8", "[UNK]", "test"],
+ ids: [2, 2648, 8, 21, 7261, 7, 23, 281, 1006, 76, 1015, 78, 1016, 1, 282, 1025, 1, 2648, 3],
+ decoded: "[CLS] test $ 1 r2 # 3 \u20ac4 \u00a35 \u00a56 [UNK] \u20b98 [UNK] test [SEP]",
+ },
+ },
+ "Xenova/rubert-base-cased": {
+ SIMPLE: {
+ text: BASE_TEST_STRINGS.SIMPLE,
+ tokens: ["How", "are", "you", "do", "##ing", "?"],
+ ids: [101, 15474, 10813, 13540, 10661, 7729, 166, 102],
+ decoded: "[CLS] How are you doing? [SEP]",
+ },
+ SIMPLE_WITH_PUNCTUATION: {
+ text: BASE_TEST_STRINGS.SIMPLE_WITH_PUNCTUATION,
+ tokens: ["You", "sh", "##oul", "##d", "'", "ve", "don", "##e", "this"],
+ ids: [101, 11577, 45942, 76143, 239, 118, 10835, 17450, 241, 11043, 102],
+ decoded: "[CLS] You should've done this [SEP]",
+ },
+ TEXT_WITH_NUMBERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_NUMBERS,
+ tokens: ["The", "comp", "##any", "was", "f", "##ound", "##ed", "in", "2016", "."],
+ ids: [101, 6821, 71382, 17927, 10646, 242, 71129, 7491, 10618, 8273, 132, 102],
+ decoded: "[CLS] The company was founded in 2016. [SEP]",
+ },
+ JAVASCRIPT_CODE: {
+ text: BASE_TEST_STRINGS.JAVASCRIPT_CODE,
+ tokens: ["let", "a", "=", "ob", "##j", ".", "to", "##St", "##ring", "(", ")", ";", "to", "##St", "##ring", "(", ")", ";"],
+ ids: [101, 14107, 232, 162, 17851, 251, 132, 10626, 21568, 13647, 120, 122, 158, 10626, 21568, 13647, 120, 122, 158, 102],
+ decoded: "[CLS] let a = obj. toString ( ) ; toString ( ) ; [SEP]",
+ },
+ BASIC: {
+ text: BASE_TEST_STRINGS.BASIC,
+ tokens: ["UN", "##wan", "##t", "##\u00e9d", ",", "run", "##ning"],
+ ids: [101, 27090, 14906, 271, 84705, 128, 14607, 11781, 102],
+ decoded: "[CLS] UNwant\u00e9d, running [SEP]",
+ },
+ CHINESE_ONLY: {
+ text: BASE_TEST_STRINGS.CHINESE_ONLY,
+ tokens: ["\u751f", "\u6d3b", "\u7684", "\u771f", "[UNK]", "\u662f"],
+ ids: [101, 6104, 5480, 6222, 6273, 100, 4877, 102],
+ decoded: "[CLS] \u751f \u6d3b \u7684 \u771f [UNK] \u662f [SEP]",
+ },
+ LEADING_SPACE: {
+ text: BASE_TEST_STRINGS.LEADING_SPACE,
+ tokens: ["le", "##ading", "sp", "##ace"],
+ ids: [101, 10653, 73130, 33162, 13967, 102],
+ decoded: "[CLS] leading space [SEP]",
+ },
+ TRAILING_SPACE: {
+ text: BASE_TEST_STRINGS.TRAILING_SPACE,
+ tokens: ["tra", "##ili", "##ng", "sp", "##ace"],
+ ids: [101, 11776, 14296, 10888, 33162, 13967, 102],
+ decoded: "[CLS] trailing space [SEP]",
+ },
+ CURRENCY_WITH_DECIMALS: {
+ text: BASE_TEST_STRINGS.CURRENCY_WITH_DECIMALS,
+ tokens: ["I", "bo", "##ught", "an", "app", "##le", "for", "$", "1", ".", "00", "at", "the", "st", "##ore", "."],
+ ids: [101, 186, 21018, 53718, 10663, 73406, 7159, 10654, 112, 138, 132, 11537, 10672, 10617, 28668, 13536, 132, 102],
+ decoded: "[CLS] I bought an apple for $ 1. 00 at the store. [SEP]",
+ },
+ TILDE_NORMALIZATION: {
+ text: BASE_TEST_STRINGS.TILDE_NORMALIZATION,
+ tokens: ["we", "##ird", "\uff5e", "ed", "##ge", "\uff5e", "cas", "##e"],
+ ids: [101, 12463, 36865, 10608, 11051, 11037, 10608, 15501, 241, 102],
+ decoded: "[CLS] weird \uff5e edge \uff5e case [SEP]",
+ },
+ CHINESE_LATIN_MIXED: {
+ text: BERT_TEST_STRINGS.CHINESE_LATIN_MIXED,
+ tokens: ["a", "##h", "\u535a", "\u63a8", "z", "##z"],
+ ids: [101, 232, 247, 3166, 4657, 282, 283, 102],
+ decoded: "[CLS] ah \u535a \u63a8 zz [SEP]",
+ },
+ MIXED_CASE_WITHOUT_ACCENTS: {
+ text: BERT_TEST_STRINGS.MIXED_CASE_WITHOUT_ACCENTS,
+ tokens: ["He", "##LL", "##o", "!", "ho", "##w", "Are", "yo", "##U", "?"],
+ ids: [101, 10869, 83346, 261, 106, 13685, 277, 14003, 14184, 211, 166, 102],
+ decoded: "[CLS] HeLLo! how Are yoU? [SEP]",
+ },
+ MIXED_CASE_WITH_ACCENTS: {
+ text: BERT_TEST_STRINGS.MIXED_CASE_WITH_ACCENTS,
+ tokens: ["H", "##\u00e4", "##LL", "##o", "!", "ho", "##w", "Are", "yo", "##U", "?"],
+ ids: [101, 184, 384, 83346, 261, 106, 13685, 277, 14003, 14184, 211, 166, 102],
+ decoded: "[CLS] H\u00e4LLo! how Are yoU? [SEP]",
+ },
+ },
+ "Xenova/kobert": {
+ SIMPLE: {
+ text: BASE_TEST_STRINGS.SIMPLE,
+ tokens: ["[UNK]", "[UNK]", "[UNK]", "[UNK]", "?"],
+ ids: [2, 0, 0, 0, 0, 258, 3],
+ decoded: "[CLS] [UNK] [UNK] [UNK] [UNK]? [SEP]",
+ },
+ SIMPLE_WITH_PUNCTUATION: {
+ text: BASE_TEST_STRINGS.SIMPLE_WITH_PUNCTUATION,
+ tokens: ["[UNK]", "[UNK]", "'", "[UNK]", "[UNK]", "[UNK]"],
+ ids: [2, 0, 0, 15, 0, 0, 0, 3],
+ decoded: "[CLS] [UNK] [UNK]'[UNK] [UNK] [UNK] [SEP]",
+ },
+ TEXT_WITH_NUMBERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_NUMBERS,
+ tokens: ["The", "[UNK]", "[UNK]", "[UNK]", "in", "[UNK]", "."],
+ ids: [2, 355, 0, 0, 0, 409, 0, 54, 3],
+ decoded: "[CLS] The [UNK] [UNK] [UNK] in [UNK]. [SEP]",
+ },
+ PUNCTUATION: {
+ text: BASE_TEST_STRINGS.PUNCTUATION,
+ tokens: ["A", "'", "[UNK]", "!", "!", "[UNK]", "?", "'", "d", "'", "'", "d", "[UNK]", ",", "[UNK]", "'", "t", "."],
+ ids: [2, 264, 15, 0, 5, 5, 0, 258, 15, 388, 15, 15, 388, 0, 46, 0, 15, 442, 54, 3],
+ decoded: "[CLS] A'[UNK]!! [UNK]?'d'' d [UNK], [UNK]'t. [SEP]",
+ },
+ PYTHON_CODE: {
+ text: BASE_TEST_STRINGS.PYTHON_CODE,
+ tokens: ["[UNK]", "[UNK]", "(", ")", ":", "[UNK]"],
+ ids: [2, 0, 0, 18, 40, 249, 0, 3],
+ decoded: "[CLS] [UNK] [UNK] ( ) : [UNK] [SEP]",
+ },
+ JAVASCRIPT_CODE: {
+ text: BASE_TEST_STRINGS.JAVASCRIPT_CODE,
+ tokens: ["[UNK]", "a", "=", "[UNK]", ".", "[UNK]", "(", ")", ";", "[UNK]", "(", ")", ";"],
+ ids: [2, 0, 367, 254, 0, 54, 0, 18, 40, 252, 0, 18, 40, 252, 3],
+ decoded: "[CLS] [UNK] a = [UNK]. [UNK] ( ) ; [UNK] ( ) ; [SEP]",
+ },
+ NEWLINES: {
+ text: BASE_TEST_STRINGS.NEWLINES,
+ tokens: ["[UNK]", "is", "a", "[UNK]", "."],
+ ids: [2, 0, 412, 367, 0, 54, 3],
+ decoded: "[CLS] [UNK] is a [UNK]. [SEP]",
+ },
+ BASIC: {
+ text: BASE_TEST_STRINGS.BASIC,
+ tokens: ["[UNK]", ",", "[UNK]"],
+ ids: [2, 0, 46, 0, 3],
+ decoded: "[CLS] [UNK], [UNK] [SEP]",
+ },
+ CHINESE_ONLY: {
+ text: BASE_TEST_STRINGS.CHINESE_ONLY,
+ tokens: ["\u751f", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]"],
+ ids: [2, 5298, 0, 0, 0, 0, 0, 3],
+ decoded: "[CLS] \u751f [UNK] [UNK] [UNK] [UNK] [UNK] [SEP]",
+ },
+ CURRENCY: {
+ text: BASE_TEST_STRINGS.CURRENCY,
+ tokens: ["[UNK]", "$", "1", "[UNK]", "#", "3", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]"],
+ ids: [2, 0, 10, 93, 0, 9, 142, 0, 0, 0, 0, 0, 0, 0, 3],
+ decoded: "[CLS] [UNK] $ 1 [UNK] # 3 [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [SEP]",
+ },
+ CURRENCY_WITH_DECIMALS: {
+ text: BASE_TEST_STRINGS.CURRENCY_WITH_DECIMALS,
+ tokens: ["I", "[UNK]", "an", "[UNK]", "[UNK]", "$", "1", ".", "00", "at", "[UNK]", "[UNK]", "."],
+ ids: [2, 296, 0, 374, 0, 0, 10, 93, 54, 79, 377, 0, 0, 54, 3],
+ decoded: "[CLS] I [UNK] an [UNK] [UNK] $ 1. 00 at [UNK] [UNK]. [SEP]",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS_2: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS_2,
+ tokens: ["[UNK]", "[UNK]", "[UNK]", "[UNK]"],
+ ids: [2, 0, 0, 0, 0, 3],
+ decoded: "[CLS] [UNK] [UNK] [UNK] [UNK] [SEP]",
+ },
+ TILDE_NORMALIZATION: {
+ text: BASE_TEST_STRINGS.TILDE_NORMALIZATION,
+ tokens: ["[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]"],
+ ids: [2, 0, 0, 0, 0, 0, 3],
+ decoded: "[CLS] [UNK] [UNK] [UNK] [UNK] [UNK] [SEP]",
+ },
+ SPIECE_UNDERSCORE: {
+ text: BASE_TEST_STRINGS.SPIECE_UNDERSCORE,
+ tokens: ["[UNK]", "[UNK]", "[UNK]", "[UNK]", "\u2581", "."],
+ ids: [2, 0, 0, 0, 0, 517, 54, 3],
+ decoded: "[CLS] [UNK] [UNK] [UNK] [UNK] \u2581. [SEP]",
+ },
+ CHINESE_LATIN_MIXED: {
+ text: BERT_TEST_STRINGS.CHINESE_LATIN_MIXED,
+ tokens: ["[UNK]", "[UNK]", "[UNK]", "[UNK]"],
+ ids: [2, 0, 0, 0, 0, 3],
+ decoded: "[CLS] [UNK] [UNK] [UNK] [UNK] [SEP]",
+ },
+ MIXED_CASE_WITHOUT_ACCENTS: {
+ text: BERT_TEST_STRINGS.MIXED_CASE_WITHOUT_ACCENTS,
+ tokens: ["[UNK]", "!", "[UNK]", "[UNK]", "[UNK]", "?"],
+ ids: [2, 0, 5, 0, 0, 0, 258, 3],
+ decoded: "[CLS] [UNK]! [UNK] [UNK] [UNK]? [SEP]",
+ },
+ MIXED_CASE_WITH_ACCENTS: {
+ text: BERT_TEST_STRINGS.MIXED_CASE_WITH_ACCENTS,
+ tokens: ["[UNK]", "!", "[UNK]", "[UNK]", "[UNK]", "?"],
+ ids: [2, 0, 5, 0, 0, 0, 258, 3],
+ decoded: "[CLS] [UNK]! [UNK] [UNK] [UNK]? [SEP]",
+ },
+ },
+ "Xenova/scibert_scivocab_uncased": {
+ JAVASCRIPT_CODE: {
+ text: BASE_TEST_STRINGS.JAVASCRIPT_CODE,
+ tokens: ["let", "a", "=", "obj", ".", "to", "##string", "(", ")", ";", "to", "##string", "(", ")", ";"],
+ ids: [102, 1293, 106, 275, 2324, 205, 147, 20301, 145, 546, 1814, 147, 20301, 145, 546, 1814, 103],
+ decoded: "[CLS] let a = obj. tostring ( ) ; tostring ( ) ; [SEP]",
+ },
+ DOUBLE_SPACE: {
+ text: BASE_TEST_STRINGS.DOUBLE_SPACE,
+ tokens: ["hi", "hell", "##o"],
+ ids: [102, 5305, 29423, 30112, 103],
+ decoded: "[CLS] hi hello [SEP]",
+ },
+ CURRENCY: {
+ text: BASE_TEST_STRINGS.CURRENCY,
+ tokens: ["test", "$", "1", "r", "##2", "#", "3", "\u20ac", "##4", "\u00a3", "##5", "\u00a5", "##6", "[UNK]", "[UNK]", "[UNK]", "test"],
+ ids: [102, 856, 3250, 158, 182, 30132, 3000, 239, 20801, 30140, 11221, 30139, 20704, 30142, 101, 101, 101, 856, 103],
+ decoded: "[CLS] test $ 1 r2 # 3 \u20ac4 \u00a35 \u00a56 [UNK] [UNK] [UNK] test [SEP]",
+ },
+ CHINESE_LATIN_MIXED: {
+ text: BERT_TEST_STRINGS.CHINESE_LATIN_MIXED,
+ tokens: ["ah", "[UNK]", "[UNK]", "zz"],
+ ids: [102, 7839, 101, 101, 23591, 103],
+ decoded: "[CLS] ah [UNK] [UNK] zz [SEP]",
+ },
+ SIMPLE_WITH_ACCENTS: {
+ text: BERT_TEST_STRINGS.SIMPLE_WITH_ACCENTS,
+ tokens: ["hell", "##o"],
+ ids: [102, 29423, 30112, 103],
+ decoded: "[CLS] hello [SEP]",
+ },
+ MIXED_CASE_WITHOUT_ACCENTS: {
+ text: BERT_TEST_STRINGS.MIXED_CASE_WITHOUT_ACCENTS,
+ tokens: ["hell", "##o", "!", "how", "are", "you", "?"],
+ ids: [102, 29423, 30112, 3190, 539, 220, 3034, 3912, 103],
+ decoded: "[CLS] hello! how are you? [SEP]",
+ },
+ },
+ "Xenova/LaBSE": {
+ JAVASCRIPT_CODE: {
+ text: BASE_TEST_STRINGS.JAVASCRIPT_CODE,
+ tokens: ["let", "a", "=", "obj", ".", "to", "##String", "(", ")", ";", "to", "##String", "(", ")", ";"],
+ ids: [101, 17214, 170, 134, 228877, 119, 14986, 368304, 113, 114, 132, 14986, 368304, 113, 114, 132, 102],
+ decoded: "[CLS] let a = obj. toString ( ) ; toString ( ) ; [SEP]",
+ },
+ CURRENCY: {
+ text: BASE_TEST_STRINGS.CURRENCY,
+ tokens: ["test", "$", "1", "R2", "#", "3", "\u20ac", "##4", "\u00a35", "\u00a5", "##6", "\u20a3", "##7", "\u20b9", "##8", "\u20b1", "##9", "test"],
+ ids: [101, 17678, 109, 122, 51222, 108, 124, 3030, 16006, 279082, 205, 16151, 3023, 16187, 3037, 16175, 3033, 16236, 17678, 102],
+ decoded: "[CLS] test $ 1 R2 # 3 \u20ac4 \u00a35 \u00a56 \u20a37 \u20b98 \u20b19 test [SEP]",
+ },
+ SPIECE_UNDERSCORE: {
+ text: BASE_TEST_STRINGS.SPIECE_UNDERSCORE,
+ tokens: ["\u2581", "##This", "\u2581", "##is", "\u2581", "##a", "\u2581", "##test", "\u2581", "."],
+ ids: [101, 3283, 342068, 3283, 15319, 3283, 14983, 3283, 50149, 3283, 119, 102],
+ decoded: "[CLS] \u2581This \u2581is \u2581a \u2581test \u2581. [SEP]",
+ },
+ POPULAR_EMOJIS: {
+ text: BASE_TEST_STRINGS.POPULAR_EMOJIS,
+ tokens: ["\ud83d\ude02", "\ud83d\udc4d", "\ud83e\udd23", "\ud83d\ude0d", "\ud83d\ude2d", "\ud83c\udf89", "\ud83d\ude4f", "\ud83d\ude0a", "\ud83d\udd25", "\ud83d\ude01", "\ud83d\ude05", "\ud83e\udd17", "\ud83d\ude06", "\ud83d\udc4f", "\u2764\ufe0f", "\ud83d\udc9c", "\ud83d\udc9a", "\ud83d\udc97", "\ud83d\udc99", "\ud83d\udda4", "\ud83d\ude0e", "\ud83d\udc4c", "\ud83e\udd73", "\ud83d\udcaa", "\u2728", "\ud83d\udc49", "\ud83d\udc40", "\ud83d\udcaf", "\ud83c\udf88", "\ud83d\ude48", "\ud83d\ude4c", "\ud83d\udc80", "\ud83d\udc47", "\ud83d\udc4b", "\u2705", "\ud83c\udf81", "\ud83c\udf1e", "\ud83c\udf38", "\ud83d\udcb0"],
+ ids: [101, 14820, 14617, 14933, 14831, 14863, 14496, 14893, 14828, 14775, 14819, 14823, 14926, 14824, 14619, 91822, 14687, 14685, 14682, 14684, 14810, 14832, 14616, 14956, 14701, 3496, 14613, 14606, 14706, 14495, 14887, 14891, 14660, 14611, 14615, 3465, 14488, 14416, 14430, 14707, 102],
+ decoded: "[CLS] \ud83d\ude02 \ud83d\udc4d \ud83e\udd23 \ud83d\ude0d \ud83d\ude2d \ud83c\udf89 \ud83d\ude4f \ud83d\ude0a \ud83d\udd25 \ud83d\ude01 \ud83d\ude05 \ud83e\udd17 \ud83d\ude06 \ud83d\udc4f \u2764\ufe0f \ud83d\udc9c \ud83d\udc9a \ud83d\udc97 \ud83d\udc99 \ud83d\udda4 \ud83d\ude0e \ud83d\udc4c \ud83e\udd73 \ud83d\udcaa \u2728 \ud83d\udc49 \ud83d\udc40 \ud83d\udcaf \ud83c\udf88 \ud83d\ude48 \ud83d\ude4c \ud83d\udc80 \ud83d\udc47 \ud83d\udc4b \u2705 \ud83c\udf81 \ud83c\udf1e \ud83c\udf38 \ud83d\udcb0 [SEP]",
+ },
+ MULTIBYTE_EMOJIS: {
+ text: BASE_TEST_STRINGS.MULTIBYTE_EMOJIS,
+ tokens: ["\u2728", "\ud83e\udd17", "\ud83d\udc41\ufe0f", "\ud83d\udc71", "##\ud83c\udffb", "[UNK]", "[UNK]", "\ud83d\udc68", "##\ud83c\udffb", "##\ud83c\udf3e", "[UNK]", "\ud83d\udc69", "##\u2764", "##\ud83d\udc8b", "##\ud83d\udc68", "\ud83d\udc69", "##\ud83d\udc69", "##\ud83d\udc67", "##\ud83d\udc66", "[UNK]", "\ud83c\udff4", "\ud83d\udc68", "##\ud83c\udffb", "##\u2764", "##\ufe0f", "##\ud83d\udc8b", "##\ud83d\udc68", "##\ud83c\udffc"],
+ ids: [101, 3496, 14926, 350545, 14648, 130826, 100, 100, 14639, 130826, 498832, 100, 14640, 488649, 499065, 499034, 14640, 499035, 499033, 499032, 100, 14555, 14639, 130826, 488649, 44450, 499065, 499034, 421916, 102],
+ decoded: "[CLS] \u2728 \ud83e\udd17 \ud83d\udc41\ufe0f \ud83d\udc71\ud83c\udffb [UNK] [UNK] \ud83d\udc68\ud83c\udffb\ud83c\udf3e [UNK] \ud83d\udc69\u2764\ud83d\udc8b\ud83d\udc68 \ud83d\udc69\ud83d\udc69\ud83d\udc67\ud83d\udc66 [UNK] \ud83c\udff4 \ud83d\udc68\ud83c\udffb\u2764\ufe0f\ud83d\udc8b\ud83d\udc68\ud83c\udffc [SEP]",
+ },
+ CHINESE_LATIN_MIXED: {
+ text: BERT_TEST_STRINGS.CHINESE_LATIN_MIXED,
+ tokens: ["ah", "\u535a", "\u63a8", "zz"],
+ ids: [101, 15524, 4573, 6405, 441764, 102],
+ decoded: "[CLS] ah \u535a \u63a8 zz [SEP]",
+ },
+ SIMPLE_WITH_ACCENTS: {
+ text: BERT_TEST_STRINGS.SIMPLE_WITH_ACCENTS,
+ tokens: ["H\u00e9", "##llo"],
+ ids: [101, 220855, 23025, 102],
+ decoded: "[CLS] H\u00e9llo [SEP]",
+ },
+ },
+ "Xenova/herbert-large-cased": {
+ SIMPLE: {
+ text: BASE_TEST_STRINGS.SIMPLE,
+ tokens: ["Ho", "w", "are", "you", "do", "ing", "?"],
+ ids: [0, 5213, 1019, 25720, 20254, 2065, 5129, 1550, 2],
+ decoded: "How are you doing? ",
+ },
+ SIMPLE_WITH_PUNCTUATION: {
+ text: BASE_TEST_STRINGS.SIMPLE_WITH_PUNCTUATION,
+ tokens: ["You", "sho", "uld", "'", "ve", "d", "one", "this"],
+ ids: [0, 32795, 14924, 48273, 1571, 6647, 72, 2290, 48846, 2],
+ decoded: "You should've done this ",
+ },
+ TEXT_WITH_NUMBERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_NUMBERS,
+ tokens: ["The", "co", "mpany", "was", "fo", "un", "de", "d", "in", "20", "16", "."],
+ ids: [0, 7117, 2406, 41449, 9873, 3435, 2195, 2101, 1038, 2651, 5646, 2555, 1899, 2],
+ decoded: "The company was founded in 2016. ",
+ },
+ PUNCTUATION: {
+ text: BASE_TEST_STRINGS.PUNCTUATION,
+ tokens: ["A", "'", "ll", "!", "!", "to", "?", "'", "d", "'", "'", "d", "of", ",", "can", "'", "t", "."],
+ ids: [0, 1012, 1571, 9396, 1725, 1725, 2063, 1550, 1571, 1038, 1571, 1571, 1038, 6595, 1947, 26794, 1571, 1026, 1899, 2],
+ decoded: "A'll!! to?'d'' d of, can't. ",
+ },
+ PYTHON_CODE: {
+ text: BASE_TEST_STRINGS.PYTHON_CODE,
+ tokens: ["de", "f", "main", "(", ")", ":", "pa", "ss"],
+ ids: [0, 2101, 1050, 41851, 1341, 1940, 1335, 2083, 5357, 2],
+ decoded: "def main ( ) : pass ",
+ },
+ JAVASCRIPT_CODE: {
+ text: BASE_TEST_STRINGS.JAVASCRIPT_CODE,
+ tokens: ["let", "a", "=", "ob", "j", ".", "to", "S", "tr", "ing", "(", ")", ";", "to", "S", "tr", "ing", "(", ")", ";"],
+ ids: [0, 11324, 1011, 1789, 2033, 1013, 1899, 2146, 55, 2518, 5129, 1341, 1940, 1195, 2146, 55, 2518, 5129, 1341, 1940, 1195, 2],
+ decoded: "let a = obj. toString ( ) ; toString ( ) ; ",
+ },
+ NEWLINES: {
+ text: BASE_TEST_STRINGS.NEWLINES,
+ tokens: ["T", "his", "is", "a", "test", "."],
+ ids: [0, 56, 22855, 6869, 1011, 14825, 1899, 2],
+ decoded: "This is a test. ",
+ },
+ BASIC: {
+ text: BASE_TEST_STRINGS.BASIC,
+ tokens: ["UN", "wan", "t", "\u00e9", "d", ",", "run", "ning"],
+ ids: [0, 23029, 2688, 88, 163, 1038, 1947, 4980, 17843, 2],
+ decoded: "UNwant\u00e9d, running ",
+ },
+ CONTROL_TOKENS: {
+ text: BASE_TEST_STRINGS.CONTROL_TOKENS,
+ tokens: ["123"],
+ ids: [0, 19049, 2],
+ decoded: "123 ",
+ },
+ HELLO_WORLD_TITLECASE: {
+ text: BASE_TEST_STRINGS.HELLO_WORLD_TITLECASE,
+ tokens: ["Hel", "lo", "World"],
+ ids: [0, 12156, 6170, 21207, 2],
+ decoded: "Hello World ",
+ },
+ HELLO_WORLD_LOWERCASE: {
+ text: BASE_TEST_STRINGS.HELLO_WORLD_LOWERCASE,
+ tokens: ["hel", "lo", "world"],
+ ids: [0, 11526, 6170, 38188, 2],
+ decoded: "hello world ",
+ },
+ CHINESE_ONLY: {
+ text: BASE_TEST_STRINGS.CHINESE_ONLY,
+ tokens: ["", "", "", "", "", "\u662f"],
+ ids: [0, 3, 3, 3, 3, 3, 1776, 2],
+ decoded: "\u662f ",
+ },
+ LEADING_SPACE: {
+ text: BASE_TEST_STRINGS.LEADING_SPACE,
+ tokens: ["le", "ad", "ing", "space"],
+ ids: [0, 2018, 2035, 5129, 46489, 2],
+ decoded: "leading space ",
+ },
+ TRAILING_SPACE: {
+ text: BASE_TEST_STRINGS.TRAILING_SPACE,
+ tokens: ["tra", "i", "ling", "space"],
+ ids: [0, 2201, 77, 16342, 46489, 2],
+ decoded: "trailing space ",
+ },
+ DOUBLE_SPACE: {
+ text: BASE_TEST_STRINGS.DOUBLE_SPACE,
+ tokens: ["H", "i", "Hel", "lo"],
+ ids: [0, 44, 1009, 12156, 6170, 2],
+ decoded: "Hi Hello ",
+ },
+ CURRENCY: {
+ text: BASE_TEST_STRINGS.CURRENCY,
+ tokens: ["test", "$", "1", "R", "2", "#", "3", "\u20ac", "4", "\u00a3", "5", "", "6", "", "7", "", "8", "", "9", "test"],
+ ids: [0, 14825, 1927, 1029, 54, 1025, 1393, 1034, 706, 1018, 100, 1008, 3, 1036, 3, 1030, 3, 1064, 3, 1017, 14825, 2],
+ decoded: "test $ 1 R2 # 3 \u20ac4 \u00a35 6 7 8 9 test ",
+ },
+ CURRENCY_WITH_DECIMALS: {
+ text: BASE_TEST_STRINGS.CURRENCY_WITH_DECIMALS,
+ tokens: ["I", "bou", "ght", "an", "ap", "ple", "for", "$", "1", ".", "00", "at", "the", "st", "ore", "."],
+ ids: [0, 1056, 13016, 15272, 2879, 10309, 20861, 15181, 1927, 1029, 1899, 2291, 4772, 6854, 1989, 24005, 1899, 2],
+ decoded: "I bought an apple for $ 1. 00 at the store. ",
+ },
+ ELLIPSIS: {
+ text: BASE_TEST_STRINGS.ELLIPSIS,
+ tokens: ["you", "\u2026"],
+ ids: [0, 20254, 1826, 2],
+ decoded: "you \u2026 ",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS,
+ tokens: ["you", "\u2026"],
+ ids: [0, 20254, 1826, 2],
+ decoded: "you \u2026 ",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS_2: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS_2,
+ tokens: ["you", "\u2026", "you", "\u2026"],
+ ids: [0, 20254, 1826, 20254, 1826, 2],
+ decoded: "you \u2026 you \u2026 ",
+ },
+ TILDE_NORMALIZATION: {
+ text: BASE_TEST_STRINGS.TILDE_NORMALIZATION,
+ tokens: ["we", "ir", "d", "", "e", "dge", "", "ca", "se"],
+ ids: [0, 2149, 17435, 1038, 3, 73, 25801, 3, 3833, 4417, 2],
+ decoded: "weird edge case ",
+ },
+ SPIECE_UNDERSCORE: {
+ text: BASE_TEST_STRINGS.SPIECE_UNDERSCORE,
+ tokens: ["", "T", "his", "", "is", "", "a", "", "test", "", "."],
+ ids: [0, 3, 56, 22855, 3, 6869, 3, 1011, 3, 14825, 3, 1899, 2],
+ decoded: "This is a test . ",
+ },
+ CHINESE_LATIN_MIXED: {
+ text: BERT_TEST_STRINGS.CHINESE_LATIN_MIXED,
+ tokens: ["a", "h", "", "", "zz"],
+ ids: [0, 69, 1021, 3, 3, 49185, 2],
+ decoded: "ah zz ",
+ },
+ SIMPLE_WITH_ACCENTS: {
+ text: BERT_TEST_STRINGS.SIMPLE_WITH_ACCENTS,
+ tokens: ["H", "\u00e9", "l", "lo"],
+ ids: [0, 44, 163, 80, 6170, 2],
+ decoded: "H\u00e9llo ",
+ },
+ MIXED_CASE_WITHOUT_ACCENTS: {
+ text: BERT_TEST_STRINGS.MIXED_CASE_WITHOUT_ACCENTS,
+ tokens: ["He", "L", "L", "o", "!", "ho", "w", "Ar", "e", "yo", "U", "?"],
+ ids: [0, 4596, 48, 48, 1007, 1725, 3145, 1019, 2921, 1015, 13908, 1041, 1550, 2],
+ decoded: "HeLLo! how Are yoU? ",
+ },
+ MIXED_CASE_WITH_ACCENTS: {
+ text: BERT_TEST_STRINGS.MIXED_CASE_WITH_ACCENTS,
+ tokens: ["H", "\u00e4", "L", "L", "o", "!", "ho", "w", "Ar", "e", "yo", "U", "?"],
+ ids: [0, 44, 158, 48, 48, 1007, 1725, 3145, 1019, 2921, 1015, 13908, 1041, 1550, 2],
+ decoded: "H\u00e4LLo! how Are yoU? ",
+ },
+ },
+ "Xenova/ernie-gram-zh": {
+ CURRENCY: {
+ text: BASE_TEST_STRINGS.CURRENCY,
+ tokens: ["test", "$", "1", "r2", "#", "3", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "test"],
+ ids: [1, 6943, 18005, 208, 6847, 9474, 284, 18017, 18017, 18017, 18017, 18017, 18017, 6943, 2],
+ decoded: "[CLS] test $ 1 r2 # 3 [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] test [SEP]",
+ },
+ },
+};
diff --git a/tests/models/blenderbot_small/tokenization.js b/tests/models/blenderbot_small/tokenization.js
new file mode 100644
index 000000000..6bf4bbb93
--- /dev/null
+++ b/tests/models/blenderbot_small/tokenization.js
@@ -0,0 +1,166 @@
+import { BlenderbotSmallTokenizer } from "../../../src/tokenizers.js";
+import { BASE_TEST_STRINGS, BLENDERBOT_SMALL_TEST_STRINGS } from "../test_strings.js";
+
+export const TOKENIZER_CLASS = BlenderbotSmallTokenizer;
+
+// NOTE: `.tokenize()` is disabled for BlenderbotSmallTokenizer, so the expected "tokens" arrays below are kept as comments for reference only.
+export const TEST_CONFIG = {
+ "Xenova/blenderbot_small-90M": {
+ SIMPLE: {
+ text: BASE_TEST_STRINGS.SIMPLE,
+ // "tokens": ["how", "are", "you", "doing", "?"],
+ ids: [102, 46, 15, 267, 20],
+ decoded: "how are you doing?",
+ },
+ SIMPLE_WITH_PUNCTUATION: {
+ text: BASE_TEST_STRINGS.SIMPLE_WITH_PUNCTUATION,
+ // "tokens": ["you", "should", "'", "ve", "done", "this"],
+ ids: [15, 197, 8, 117, 369, 36],
+ decoded: "you should've done this",
+ },
+ NUMBERS: {
+ text: BASE_TEST_STRINGS.NUMBERS,
+ // "tokens": ["0@@", "1@@", "2@@", "3@@", "4@@", "5@@", "6@@", "7@@", "89", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "100", "1000"],
+ ids: [1988, 2388, 735, 801, 827, 948, 981, 1110, 4814, 520, 143, 176, 216, 260, 253, 345, 374, 420, 475, 316, 773, 6217],
+ decoded: "0123456789 0 1 2 3 4 5 6 7 8 9 10 100 1000",
+ },
+ TEXT_WITH_NUMBERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_NUMBERS,
+ // "tokens": ["the", "company", "was", "founded", "in", "2016", "."],
+ ids: [7, 293, 18, 912, 13, 845, 5],
+ decoded: "the company was founded in 2016.",
+ },
+ PUNCTUATION: {
+ text: BASE_TEST_STRINGS.PUNCTUATION,
+ // "tokens": ["a", "__newln__", "'", "ll", "!", "!@@", "to", "?", "'", "d", "'", "'", "d", "of", ",", "can", "'", "t", "."],
+ ids: [12, 4, 8, 97, 37, 3, 11, 20, 8, 85, 8, 8, 85, 10, 6, 62, 8, 30, 5],
+ decoded: "a __newln__'ll! __unk__ to?'d'' d of, can't.",
+ },
+ PYTHON_CODE: {
+ text: BASE_TEST_STRINGS.PYTHON_CODE,
+ // "tokens": ["def", "main", "(", ")@@", ":", "__newln__", "pass"],
+ ids: [21996, 550, 40, 3, 106, 4, 1314],
+ decoded: "def main ( __unk__ : __newln__ pass",
+ },
+ JAVASCRIPT_CODE: {
+ text: BASE_TEST_STRINGS.JAVASCRIPT_CODE,
+ // "tokens": ["let", "a", "=", "ob@@", "j", ".@@", "to@@", "string", "(", ")@@", ";", "__newln__", "to@@", "string", "(", ")@@", ";"],
+ ids: [131, 12, 1381, 2808, 755, 3, 752, 4529, 40, 3, 118, 4, 752, 4529, 40, 3, 118],
+ decoded: "let a = obj __unk__ tostring ( __unk__ ; __newln__ tostring ( __unk__ ;",
+ },
+ NEWLINES: {
+ text: BASE_TEST_STRINGS.NEWLINES,
+ // "tokens": ["this", "__newln__", "is", "__newln__", "a", "__newln__", "test", "."],
+ ids: [36, 4, 24, 4, 12, 4, 1248, 5],
+ decoded: "this __newln__ is __newln__ a __newln__ test.",
+ },
+ BASIC: {
+ text: BASE_TEST_STRINGS.BASIC,
+ // "tokens": ["un@@", "wan@@", "t@@", "\u00e9@@", "d", ",@@", "running"],
+ ids: [204, 4151, 291, 1677, 85, 3, 785],
+ decoded: "unwant\u00e9d __unk__ running",
+ },
+ CONTROL_TOKENS: {
+ text: BASE_TEST_STRINGS.CONTROL_TOKENS,
+ // "tokens": ["1@@", "\u0000@@", "2@@", "\ufffd@@", "3"],
+ ids: [2388, 3, 735, 3, 216],
+ decoded: "1__unk__ 2__unk__ 3",
+ },
+ HELLO_WORLD_TITLECASE: {
+ text: BASE_TEST_STRINGS.HELLO_WORLD_TITLECASE,
+ // "tokens": ["hello", "world"],
+ ids: [880, 159],
+ decoded: "hello world",
+ },
+ HELLO_WORLD_LOWERCASE: {
+ text: BASE_TEST_STRINGS.HELLO_WORLD_LOWERCASE,
+ // "tokens": ["hello", "world"],
+ ids: [880, 159],
+ decoded: "hello world",
+ },
+ CHINESE_ONLY: {
+ text: BASE_TEST_STRINGS.CHINESE_ONLY,
+ // "tokens": ["\u751f@@", "\u6d3b@@", "\u7684@@", "\u771f@@", "\u8c1b@@", "\u662f"],
+ ids: [30488, 32756, 29891, 30813, 3, 34037],
+ decoded: "\u751f\u6d3b\u7684\u771f__unk__ \u662f",
+ },
+ LEADING_SPACE: {
+ text: BASE_TEST_STRINGS.LEADING_SPACE,
+ // "tokens": ["leading", "space"],
+ ids: [1164, 833],
+ decoded: "leading space",
+ },
+ TRAILING_SPACE: {
+ text: BASE_TEST_STRINGS.TRAILING_SPACE,
+ // "tokens": ["trailing", "space"],
+ ids: [12499, 833],
+ decoded: "trailing space",
+ },
+ DOUBLE_SPACE: {
+ text: BASE_TEST_STRINGS.DOUBLE_SPACE,
+ // "tokens": ["hi", "hello"],
+ ids: [792, 880],
+ decoded: "hi hello",
+ },
+ CURRENCY: {
+ text: BASE_TEST_STRINGS.CURRENCY,
+ // "tokens": ["test", "$@@", "1", "r@@", "2", "#@@", "3", "\u20ac@@", "4", "\u00a3@@", "5", "\u00a5@@", "6", "\u20a3@@", "7", "\u20b9@@", "8", "\u20b1@@", "9", "test"],
+ ids: [1248, 3, 143, 510, 176, 3, 216, 3, 260, 3, 253, 3, 345, 3, 374, 3, 420, 3, 475, 1248],
+ decoded: "test __unk__ 1 r2 __unk__ 3 __unk__ 4 __unk__ 5 __unk__ 6 __unk__ 7 __unk__ 8 __unk__ 9 test",
+ },
+ CURRENCY_WITH_DECIMALS: {
+ text: BASE_TEST_STRINGS.CURRENCY_WITH_DECIMALS,
+ // "tokens": ["i", "bought", "an", "apple", "for", "$@@", "1", ".@@", "00", "at", "the", "store", "."],
+ ids: [14, 1890, 50, 4758, 26, 3, 143, 3, 1966, 32, 7, 1640, 5],
+ decoded: "i bought an apple for __unk__ 1 __unk__ 00 at the store.",
+ },
+ ELLIPSIS: {
+ text: BASE_TEST_STRINGS.ELLIPSIS,
+ // "tokens": ["you@@", "\u2026"],
+ ids: [7984, 1244],
+ decoded: "you\u2026",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS,
+ // "tokens": ["you@@", "\u2026"],
+ ids: [7984, 1244],
+ decoded: "you\u2026",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS_2: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS_2,
+ // "tokens": ["you@@", "\u2026", "you@@", "\u2026"],
+ ids: [7984, 1244, 7984, 1244],
+ decoded: "you\u2026 you\u2026",
+ },
+ TILDE_NORMALIZATION: {
+ text: BASE_TEST_STRINGS.TILDE_NORMALIZATION,
+ // "tokens": ["weird", "\uff5e", "edge", "\uff5e", "case"],
+ ids: [2614, 30831, 1649, 30831, 543],
+ decoded: "weird \uff5e edge \uff5e case",
+ },
+ SPIECE_UNDERSCORE: {
+ text: BASE_TEST_STRINGS.SPIECE_UNDERSCORE,
+ // "tokens": ["\u2581@@", "this", "\u2581@@", "is", "\u2581@@", "a", "\u2581@@", "test", "\u2581", "."],
+ ids: [3, 36, 3, 24, 3, 12, 3, 1248, 50106, 5],
+ decoded: "__unk__ this __unk__ is __unk__ a __unk__ test \u2581.",
+ },
+ SPECIAL_TOKENS: {
+ text: BLENDERBOT_SMALL_TEST_STRINGS.SPECIAL_TOKENS,
+ // "tokens": ["__start__", "hello", "world", "__end__"],
+ ids: [1, 880, 159, 2],
+ decoded: "__start__ hello world __end__",
+ },
+ WHITESPACE_1: {
+ text: BLENDERBOT_SMALL_TEST_STRINGS.WHITESPACE_1,
+ // "tokens": ["__start__", "hey", "__end__"],
+ ids: [1, 226, 2],
+ decoded: "__start__ hey __end__",
+ },
+ WHITESPACE_2: {
+ text: BLENDERBOT_SMALL_TEST_STRINGS.WHITESPACE_2,
+ // "tokens": ["__start__", "hey", "__end__"],
+ ids: [1, 226, 2],
+ decoded: "__start__ hey __end__",
+ },
+ },
+};
diff --git a/tests/models/bloom/tokenization.js b/tests/models/bloom/tokenization.js
new file mode 100644
index 000000000..03b95d63a
--- /dev/null
+++ b/tests/models/bloom/tokenization.js
@@ -0,0 +1,194 @@
+import { BloomTokenizer } from "../../../src/tokenizers.js";
+import { BASE_TEST_STRINGS, BLOOM_TEST_STRINGS, SENTENCEPIECE_TEST_STRINGS } from "../test_strings.js";
+
+export const TOKENIZER_CLASS = BloomTokenizer;
+export const TEST_CONFIG = {
+ "Xenova/bloom-560m": {
+ SIMPLE: {
+ text: BASE_TEST_STRINGS.SIMPLE,
+ tokens: ["How", "\u0120are", "\u0120you", "\u0120doing", "?"],
+ ids: [7572, 1306, 1152, 12491, 34],
+ decoded: "How are you doing?",
+ },
+ SIMPLE_WITH_PUNCTUATION: {
+ text: BASE_TEST_STRINGS.SIMPLE_WITH_PUNCTUATION,
+ tokens: ["You", "\u0120should", "'ve", "\u0120done", "\u0120this"],
+ ids: [5448, 3403, 7300, 11541, 1119],
+ decoded: "You should've done this",
+ },
+ NUMBERS: {
+ text: BASE_TEST_STRINGS.NUMBERS,
+ tokens: ["0123", "456789", "\u01200", "\u01201", "\u01202", "\u01203", "\u01204", "\u01205", "\u01206", "\u01207", "\u01208", "\u01209", "\u012010", "\u0120100", "\u01201000"],
+ ids: [166660, 145647, 931, 404, 415, 735, 934, 973, 1231, 1392, 1445, 1575, 1581, 4334, 19526],
+ decoded: "0123456789 0 1 2 3 4 5 6 7 8 9 10 100 1000",
+ },
+ TEXT_WITH_NUMBERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_NUMBERS,
+ tokens: ["The", "\u0120company", "\u0120was", "\u0120founded", "\u0120in", "\u01202016", "."],
+ ids: [2175, 16333, 1620, 88289, 361, 5854, 17],
+ decoded: "The company was founded in 2016.",
+ },
+ PUNCTUATION: {
+ text: BASE_TEST_STRINGS.PUNCTUATION,
+ tokens: ["A", "\u010a", "'ll", "\u0120!!", "to", "?", "'d", "''", "d", "\u0120of", ",", "\u0120can't", "."],
+ ids: [36, 189, 8722, 49825, 1025, 34, 10628, 2328, 71, 461, 15, 11229, 17],
+ decoded: "A\n'll !!to?'d''d of, can't.",
+ },
+ PYTHON_CODE: {
+ text: BASE_TEST_STRINGS.PYTHON_CODE,
+ tokens: ["def", "\u0120main", "()", ":", "\u010a\u0109", "pass"],
+ ids: [7564, 4291, 883, 29, 1582, 12608],
+ decoded: "def main():\n\tpass",
+ },
+ JAVASCRIPT_CODE: {
+ text: BASE_TEST_STRINGS.JAVASCRIPT_CODE,
+ tokens: ["let", "\u0120a", "\u0120=", "\u0120obj", ".", "toString", "()", ";", "\u010a", "toString", "()", ";"],
+ ids: [2963, 267, 564, 17949, 17, 27392, 883, 30, 189, 27392, 883, 30],
+ decoded: "let a = obj.toString();\ntoString();",
+ },
+ NEWLINES: {
+ text: BASE_TEST_STRINGS.NEWLINES,
+ tokens: ["This", "\u010a\u010a", "is", "\u010a", "a", "\u010a", "test", "."],
+ ids: [6168, 603, 290, 189, 68, 189, 9234, 17],
+ decoded: "This\n\nis\na\ntest.",
+ },
+ BASIC: {
+ text: BASE_TEST_STRINGS.BASIC,
+ tokens: ["UN", "want", "\u00c3\u00a9d", ",", "running"],
+ ids: [5777, 75642, 2454, 15, 101897],
+ decoded: "UNwant\u00e9d,running",
+ },
+ CONTROL_TOKENS: {
+ text: BASE_TEST_STRINGS.CONTROL_TOKENS,
+ tokens: ["1", "\u0100", "2", "\u00ef\u00bf\u00bd", "3"],
+ ids: [20, 179, 21, 23181, 22],
+ decoded: "1\u00002\ufffd3",
+ },
+ HELLO_WORLD_TITLECASE: {
+ text: BASE_TEST_STRINGS.HELLO_WORLD_TITLECASE,
+ tokens: ["Hello", "\u0120World"],
+ ids: [59414, 12155],
+ decoded: "Hello World",
+ },
+ HELLO_WORLD_LOWERCASE: {
+ text: BASE_TEST_STRINGS.HELLO_WORLD_LOWERCASE,
+ tokens: ["hello", "\u0120world"],
+ ids: [101579, 8876],
+ decoded: "hello world",
+ },
+ CHINESE_ONLY: {
+ text: BASE_TEST_STRINGS.CHINESE_ONLY,
+ tokens: ["\u00e7\u0136\u0141\u00e6\u00b4\u00bb\u00e7\u013c\u0126", "\u00e7\u013e\u0141", "\u00e8\u00b0", "\u013d", "\u00e6\u013a\u00af"],
+ ids: [71167, 4137, 1927, 239, 644],
+ decoded: "\u751f\u6d3b\u7684\u771f\u8c1b\u662f",
+ },
+ LEADING_SPACE: {
+ text: BASE_TEST_STRINGS.LEADING_SPACE,
+ tokens: ["\u0120\u0120", "\u0120leading", "\u0120space"],
+ ids: [250, 36128, 12978],
+ decoded: " leading space",
+ },
+ TRAILING_SPACE: {
+ text: BASE_TEST_STRINGS.TRAILING_SPACE,
+ tokens: ["tra", "iling", "\u0120space", "\u0120\u0120\u0120"],
+ ids: [1900, 17022, 12978, 416],
+ decoded: "trailing space ",
+ },
+ SURROUNDING_SPACE: {
+ text: BASE_TEST_STRINGS.SURROUNDING_SPACE,
+ tokens: ["\u0120\u0120", "\u0120surrounding", "\u0120space", "\u0120\u0120\u0120"],
+ ids: [250, 66599, 12978, 416],
+ decoded: " surrounding space ",
+ },
+ DOUBLE_SPACE: {
+ text: BASE_TEST_STRINGS.DOUBLE_SPACE,
+ tokens: ["Hi", "\u0120", "\u0120Hello"],
+ ids: [30050, 210, 86153],
+ decoded: "Hi Hello",
+ },
+ CURRENCY: {
+ text: BASE_TEST_STRINGS.CURRENCY,
+ tokens: ["test", "\u0120$1", "\u0120R2", "\u0120#3", "\u0120\u00e2\u0124\u00ac", "4", "\u0120\u00c2\u00a3", "5", "\u0120\u00c2\u00a5", "6", "\u0120\u00e2\u0124", "\u00a3", "7", "\u0120\u00e2\u0124\u00b9", "8", "\u0120\u00e2\u0124", "\u00b1", "9", "\u0120test"],
+ ids: [9234, 41448, 80774, 201642, 20117, 23, 40300, 24, 62153, 25, 72279, 100, 26, 120434, 27, 72279, 113, 28, 4006],
+ decoded: "test $1 R2 #3 \u20ac4 \u00a35 \u00a56 \u20a37 \u20b98 \u20b19 test",
+ },
+ CURRENCY_WITH_DECIMALS: {
+ text: BASE_TEST_STRINGS.CURRENCY_WITH_DECIMALS,
+ tokens: ["I", "\u0120bought", "\u0120an", "\u0120apple", "\u0120for", "\u0120$1", ".", "00", "\u0120at", "\u0120the", "\u0120store", "."],
+ ids: [44, 87926, 660, 101091, 613, 41448, 17, 462, 919, 368, 18706, 17],
+ decoded: "I bought an apple for $1.00 at the store.",
+ },
+ ELLIPSIS: {
+ text: BASE_TEST_STRINGS.ELLIPSIS,
+ tokens: ["you", "\u00e2\u0122\u00a6", "\u0120\u0120"],
+ ids: [23438, 4346, 250],
+ decoded: "you\u2026 ",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS,
+ tokens: ["you", "\u00e2\u0122\u00a6", "\u00c2\u0142\u00c2\u0142"],
+ ids: [23438, 4346, 12361],
+ decoded: "you\u2026\u00a0\u00a0",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS_2: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS_2,
+ tokens: ["you", "\u00e2\u0122\u00a6", "\u00c2\u0142\u00c2\u0142", "you", "\u00e2\u0122\u00a6", "\u00c2\u0142\u00c2\u0142"],
+ ids: [23438, 4346, 12361, 23438, 4346, 12361],
+ decoded: "you\u2026\u00a0\u00a0you\u2026\u00a0\u00a0",
+ },
+ TILDE_NORMALIZATION: {
+ text: BASE_TEST_STRINGS.TILDE_NORMALIZATION,
+ tokens: ["we", "ird", "\u0120\u00ef\u00bd", "\u0140", "\u0120edge", "\u0120\u00ef\u00bd", "\u0140", "\u0120case"],
+ ids: [2136, 7589, 122354, 242, 29655, 122354, 242, 4462],
+ decoded: "weird \uff5e edge \uff5e case",
+ },
+ SPIECE_UNDERSCORE: {
+ text: BASE_TEST_STRINGS.SPIECE_UNDERSCORE,
+ tokens: ["\u00e2\u0138", "\u0123", "This", "\u0120\u00e2\u0138", "\u0123", "is", "\u0120\u00e2\u0138", "\u0123", "a", "\u0120\u00e2\u0138", "\u0123", "test", "\u0120\u00e2\u0138", "\u0123", "."],
+ ids: [26127, 213, 6168, 15299, 213, 290, 15299, 213, 68, 15299, 213, 9234, 15299, 213, 17],
+ decoded: "\u2581This \u2581is \u2581a \u2581test \u2581.",
+ },
+ POPULAR_EMOJIS: {
+ text: BASE_TEST_STRINGS.POPULAR_EMOJIS,
+ tokens: ["\u00f0\u0141\u013a", "\u0124", "\u0120\u00f0\u0141", "\u0133", "\u012f", "\u0120\u00f0\u0141", "\u00a4", "\u00a3", "\u0120\u00f0\u0141\u013a", "\u012f", "\u0120\u00f0\u0141\u013a", "\u0143", "\u0120\u00f0\u0141", "\u0130", "\u012b", "\u0120\u00f0\u0141", "\u013b", "\u0131", "\u0120\u00f0\u0141\u013a", "\u012c", "\u0120\u00f0\u0141", "\u0136", "\u00a5", "\u0120\u00f0\u0141\u013a", "\u0123", "\u0120\u00f0\u0141\u013a", "\u0127", "\u0120\u00f0\u0141", "\u00a4", "\u0139", "\u0120\u00f0\u0141\u013a", "\u0128", "\u0120\u00f0\u0141", "\u0133", "\u0131", "\u0120\u00e2\u013f", "\u00a4", "\u00ef\u00b8\u0131", "\u0120\u00f0\u0141", "\u0134", "\u013e", "\u0120\u00f0\u0141", "\u0134", "\u013c", "\u0120\u00f0\u0141", "\u0134", "\u0139", "\u0120\u00f0\u0141", "\u0134", "\u013b", "\u0120\u00f0\u0141", "\u0138", "\u00a4", "\u0120\u00f0\u0141\u013a", "\u0130", "\u0120\u00f0\u0141", "\u0133", "\u012e", "\u0120\u00f0\u0141", "\u00a5", "\u00b3", "\u0120\u00f0\u0141", "\u0134", "\u00aa", "\u0120\u00e2\u013e", "\u00a8", "\u0120\u00f0\u0141", "\u0133", "\u012b", "\u0120\u00f0\u0141", "\u0133", "\u0122", "\u0120\u00f0\u0141", "\u0134", "\u00af", "\u0120\u00f0\u0141", "\u0130", "\u012a", "\u0120\u00f0\u0141", "\u013b", "\u012a", "\u0120\u00f0\u0141", "\u013b", "\u012e", "\u0120\u00f0\u0141", "\u0134", "\u0122", "\u0120\u00f0\u0141", "\u0133", "\u0129", "\u0120\u00f0\u0141", "\u0133", "\u012d", "\u0120\u00e2\u013e", "\u0127", "\u0120\u00f0\u0141", "\u0130", "\u0123", "\u0120\u00f0\u0141", "\u012e", "\u0140", "\u0120\u00f0\u0141", "\u012e", "\u00b8", "\u0120\u00f0\u0141", "\u0134", "\u00b0"],
+ ids: [127322, 214, 41234, 229, 225, 41234, 101, 100, 126342, 225, 126342, 245, 41234, 226, 221, 41234, 237, 227, 126342, 222, 41234, 232, 102, 126342, 213, 126342, 217, 41234, 101, 235, 126342, 218, 41234, 229, 227, 189367, 101, 116057, 41234, 230, 240, 41234, 230, 238, 41234, 230, 235, 41234, 230, 237, 41234, 234, 101, 126342, 226, 41234, 229, 224, 41234, 102, 115, 41234, 230, 107, 76758, 105, 41234, 229, 221, 41234, 229, 212, 41234, 230, 111, 41234, 226, 220, 41234, 237, 220, 41234, 237, 224, 41234, 230, 212, 41234, 229, 219, 41234, 229, 223, 76758, 217, 41234, 226, 213, 41234, 224, 242, 41234, 224, 120, 41234, 230, 112],
+ decoded: "\ud83d\ude02 \ud83d\udc4d \ud83e\udd23 \ud83d\ude0d \ud83d\ude2d \ud83c\udf89 \ud83d\ude4f \ud83d\ude0a \ud83d\udd25 \ud83d\ude01 \ud83d\ude05 \ud83e\udd17 \ud83d\ude06 \ud83d\udc4f \u2764\ufe0f \ud83d\udc9c \ud83d\udc9a \ud83d\udc97 \ud83d\udc99 \ud83d\udda4 \ud83d\ude0e \ud83d\udc4c \ud83e\udd73 \ud83d\udcaa \u2728 \ud83d\udc49 \ud83d\udc40 \ud83d\udcaf \ud83c\udf88 \ud83d\ude48 \ud83d\ude4c \ud83d\udc80 \ud83d\udc47 \ud83d\udc4b \u2705 \ud83c\udf81 \ud83c\udf1e \ud83c\udf38 \ud83d\udcb0",
+ },
+ MULTIBYTE_EMOJIS: {
+ text: BASE_TEST_STRINGS.MULTIBYTE_EMOJIS,
+ tokens: ["\u00e2\u013e", "\u00a8", "\u0120\u00f0\u0141", "\u00a4", "\u0139", "\u0120\u00f0\u0141", "\u0133", "\u0123", "\u00ef\u00b8\u0131", "\u0120\u00f0\u0141", "\u0133", "\u00b1", "\u00f0\u0141\u0131", "\u00bb", "\u0120\u00f0\u0141", "\u0137", "\u00b5", "\u00e2\u0122\u012f", "\u00e2\u013b", "\u0124", "\u00ef\u00b8\u0131", "\u0120\u00f0\u0141", "\u00a7", "\u013b", "\u00f0\u0141\u0131", "\u00bb", "\u00e2\u0122\u012f", "\u00e2\u013b", "\u0124", "\u0120\u00f0\u0141", "\u0133", "\u00a8", "\u00f0\u0141\u0131", "\u00bb", "\u00e2\u0122\u012f", "\u00f0\u0141", "\u012e", "\u00be", "\u0120\u00f0\u0141", "\u00a7", "\u0133", "\u00e2\u0122\u012f", "\u00f0\u0141", "\u00a4", "\u013f", "\u00e2\u0122\u012f", "\u00f0\u0141", "\u00a7", "\u0133", "\u0120\u00f0\u0141", "\u0133", "\u00a9", "\u00e2\u0122\u012f", "\u00e2\u013f", "\u00a4", "\u00e2\u0122\u012f", "\u00f0\u0141\u0134", "\u012d", "\u00e2\u0122\u012f", "\u00f0\u0141", "\u0133", "\u00a8", "\u0120\u00f0\u0141", "\u0133", "\u00a9", "\u00e2\u0122\u012f", "\u00f0\u0141", "\u0133", "\u00a9", "\u00e2\u0122\u012f", "\u00f0\u0141", "\u0133", "\u00a7", "\u00e2\u0122\u012f", "\u00f0\u0141", "\u0133", "\u00a6", "\u0120\u00f0\u0141", "\u00a7", "\u0133", "\u00f0\u0141\u0131", "\u00bb", "\u00e2\u0122\u012f", "\u00f0\u0141", "\u00a4", "\u013f", "\u00e2\u0122\u012f", "\u00f0\u0141", "\u00a7", "\u0133", "\u00f0\u0141\u0131", "\u00bb", "\u0120\u00f0\u0141", "\u0131", "\u00b4", "\u00f3", "\u0142", "\u0123", "\u00a7", "\u00f3", "\u0142", "\u0123", "\u00a2", "\u00f3", "\u0142", "\u0123", "\u00a5", "\u00f3", "\u0142", "\u0123", "\u00ae", "\u00f3", "\u0142", "\u0123", "\u00a7", "\u00f3", "\u0142", "\u0123", "\u00bf", "\u0120\u00f0\u0141", "\u0133", "\u00a8", "\u00f0\u0141\u0131", "\u00bb", "\u00e2\u0122\u012f", "\u00e2\u013f", "\u00a4", "\u00ef\u00b8\u0131", "\u00e2\u0122\u012f", "\u00f0\u0141\u0134", "\u012d", "\u00e2\u0122\u012f", "\u00f0\u0141", "\u0133", "\u00a8", "\u00f0\u0141\u0131", "\u00bc"],
+ ids: [120709, 105, 41234, 101, 235, 41234, 229, 213, 116057, 41234, 229, 113, 244635, 123, 41234, 233, 117, 1553, 15596, 214, 116057, 41234, 104, 237, 244635, 123, 1553, 15596, 214, 41234, 229, 105, 244635, 123, 1553, 22618, 224, 126, 41234, 104, 229, 1553, 22618, 101, 241, 1553, 22618, 104, 229, 41234, 229, 106, 1553, 157147, 101, 1553, 139500, 223, 1553, 22618, 229, 105, 41234, 229, 106, 1553, 22618, 229, 106, 1553, 22618, 229, 104, 1553, 22618, 229, 103, 41234, 104, 229, 244635, 123, 1553, 22618, 101, 241, 1553, 22618, 104, 229, 244635, 123, 41234, 227, 116, 177, 244, 213, 104, 177, 244, 213, 99, 177, 244, 213, 102, 177, 244, 213, 110, 177, 244, 213, 104, 177, 244, 213, 127, 41234, 229, 105, 244635, 123, 1553, 157147, 101, 116057, 1553, 139500, 223, 1553, 22618, 229, 105, 244635, 124],
+ decoded: "\u2728 \ud83e\udd17 \ud83d\udc41\ufe0f \ud83d\udc71\ud83c\udffb \ud83d\udd75\u200d\u2642\ufe0f \ud83e\uddd9\ud83c\udffb\u200d\u2642 \ud83d\udc68\ud83c\udffb\u200d\ud83c\udf3e \ud83e\uddd1\u200d\ud83e\udd1d\u200d\ud83e\uddd1 \ud83d\udc69\u200d\u2764\u200d\ud83d\udc8b\u200d\ud83d\udc68 \ud83d\udc69\u200d\ud83d\udc69\u200d\ud83d\udc67\u200d\ud83d\udc66 \ud83e\uddd1\ud83c\udffb\u200d\ud83e\udd1d\u200d\ud83e\uddd1\ud83c\udffb \ud83c\udff4\udb40\udc67\udb40\udc62\udb40\udc65\udb40\udc6e\udb40\udc67\udb40\udc7f \ud83d\udc68\ud83c\udffb\u200d\u2764\ufe0f\u200d\ud83d\udc8b\u200d\ud83d\udc68\ud83c\udffc",
+ },
+ ONLY_WHITESPACE: {
+ text: BASE_TEST_STRINGS.ONLY_WHITESPACE,
+ tokens: ["\u0120\u0109", "\u010a"],
+ ids: [33651, 189],
+ decoded: " \t\n",
+ },
+ END_OF_SENTENCE_PUNCTUATION: {
+ text: BLOOM_TEST_STRINGS.END_OF_SENTENCE_PUNCTUATION,
+ tokens: ["test", ".", "\u0120test", ",", "\u0120test", "!", "\u0120test", "?", "\u0120test", "\u00e2\u0122\u00a6", "\u0120test", "\u00e3\u0122\u0124", "\u0120test", "\u00ef\u00bc\u012e", "\u0120test", "\u00e3\u0122\u0123", "\u0120test", "\u00e0\u00a5\u00a4", "\u0120test", "\u00db\u0136", "\u0120test", "\u00d8\u012e", "\u0120test"],
+ ids: [9234, 17, 4006, 15, 4006, 4, 4006, 34, 4006, 4346, 4006, 420, 4006, 355, 4006, 594, 4006, 527, 4006, 1174, 4006, 687, 4006],
+ decoded: "test. test, test! test? test\u2026 test\u3002 test\uff0c test\u3001 test\u0964 test\u06d4 test\u060c test",
+ },
+ SPECIAL_WITH_TRAILING_WHITESPACE: {
+ text: SENTENCEPIECE_TEST_STRINGS.SPECIAL_WITH_TRAILING_WHITESPACE,
+ tokens: ["", "\u010a"],
+ ids: [1, 189],
+ decoded: "\n",
+ },
+ SPECIAL_SURROUNDED_BY_WHITESPACE: {
+ text: SENTENCEPIECE_TEST_STRINGS.SPECIAL_SURROUNDED_BY_WHITESPACE,
+ tokens: ["\u0120", " ", "\u0120test", "\u0120", " ", "\u0120"],
+ ids: [210, 2, 4006, 210, 2, 210],
+ decoded: " test ",
+ },
+ SPECIAL_NO_WHITESPACE: {
+ text: SENTENCEPIECE_TEST_STRINGS.SPECIAL_NO_WHITESPACE,
+ tokens: ["", "test", ""],
+ ids: [2, 9234, 2],
+ decoded: "test",
+ },
+ },
+};
diff --git a/tests/models/clip/tokenization.js b/tests/models/clip/tokenization.js
new file mode 100644
index 000000000..73cacda3c
--- /dev/null
+++ b/tests/models/clip/tokenization.js
@@ -0,0 +1,166 @@
+import { CLIPTokenizer } from "../../../src/tokenizers.js";
+import { BASE_TEST_STRINGS } from "../test_strings.js";
+
+export const TOKENIZER_CLASS = CLIPTokenizer;
+export const TEST_CONFIG = {
+ "Xenova/clip-vit-base-patch16": {
+ SIMPLE: {
+ text: BASE_TEST_STRINGS.SIMPLE,
+ tokens: ["how", "are", "you", "doing", "?"],
+ ids: [49406, 829, 631, 592, 1960, 286, 49407],
+ decoded: "<|startoftext|>how are you doing? <|endoftext|>",
+ },
+ SIMPLE_WITH_PUNCTUATION: {
+ text: BASE_TEST_STRINGS.SIMPLE_WITH_PUNCTUATION,
+ tokens: ["you", "should", "'ve", "done", "this"],
+ ids: [49406, 592, 1535, 1200, 1700, 589, 49407],
+ decoded: "<|startoftext|>you should've done this <|endoftext|>",
+ },
+ NUMBERS: {
+ text: BASE_TEST_STRINGS.NUMBERS,
+ tokens: ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "1", "0", "1", "0", "0", "1", "0", "0", "0"],
+ ids: [49406, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 272, 271, 272, 271, 271, 272, 271, 271, 271, 49407],
+ decoded: "<|startoftext|>0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 1 0 1 0 0 1 0 0 0 <|endoftext|>",
+ },
+ TEXT_WITH_NUMBERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_NUMBERS,
+ tokens: ["the", "company", "was", "founded", "in", "2", "0", "1", "6", "."],
+ ids: [49406, 518, 2634, 739, 12240, 530, 273, 271, 272, 277, 269, 49407],
+ decoded: "<|startoftext|>the company was founded in 2 0 1 6. <|endoftext|>",
+ },
+ PUNCTUATION: {
+ text: BASE_TEST_STRINGS.PUNCTUATION,
+ tokens: ["a", "'ll", "!!", "to", "?'", "d", "''", "d", "of", ",", "can", "'t", "."],
+ ids: [49406, 320, 1342, 748, 531, 13610, 323, 8445, 323, 539, 267, 753, 713, 269, 49407],
+ decoded: "<|startoftext|>a 'll!! to?' d '' d of, can 't. <|endoftext|>",
+ },
+ PYTHON_CODE: {
+ text: BASE_TEST_STRINGS.PYTHON_CODE,
+ tokens: ["def", "main", "(", "):", "pass"],
+ ids: [49406, 11649, 2623, 7, 4143, 3511, 49407],
+ decoded: "<|startoftext|>def main (): pass <|endoftext|>",
+ },
+ JAVASCRIPT_CODE: {
+ text: BASE_TEST_STRINGS.JAVASCRIPT_CODE,
+ tokens: ["let", "a", "=", "ob", "j", ".", "to", "string", "(", ");", "to", "string", "(", ");"],
+ ids: [49406, 1094, 320, 284, 1411, 329, 269, 580, 9696, 7, 19686, 580, 9696, 7, 19686, 49407],
+ decoded: "<|startoftext|>let a = obj. tostring (); tostring (); <|endoftext|>",
+ },
+ NEWLINES: {
+ text: BASE_TEST_STRINGS.NEWLINES,
+ tokens: ["this", "is", "a", "test", "."],
+ ids: [49406, 589, 533, 320, 1628, 269, 49407],
+ decoded: "<|startoftext|>this is a test. <|endoftext|>",
+ },
+ BASIC: {
+ text: BASE_TEST_STRINGS.BASIC,
+ tokens: ["un", "want", "\u00c3\u00a9", "d", ",", "running"],
+ ids: [49406, 569, 18356, 3459, 323, 267, 2761, 49407],
+ decoded: "<|startoftext|>unwant\u00e9d, running <|endoftext|>",
+ },
+ CONTROL_TOKENS: {
+ text: BASE_TEST_STRINGS.CONTROL_TOKENS,
+ tokens: ["1", "\u0100", "2", "\u00ef\u00bf\u00bd", "3"],
+ ids: [49406, 272, 444, 273, 39802, 274, 49407],
+ decoded: "<|startoftext|>1 \u0000 2 \ufffd 3 <|endoftext|>",
+ },
+ HELLO_WORLD_TITLECASE: {
+ text: BASE_TEST_STRINGS.HELLO_WORLD_TITLECASE,
+ tokens: ["hello", "world"],
+ ids: [49406, 3306, 1002, 49407],
+ decoded: "<|startoftext|>hello world <|endoftext|>",
+ },
+ HELLO_WORLD_LOWERCASE: {
+ text: BASE_TEST_STRINGS.HELLO_WORLD_LOWERCASE,
+ tokens: ["hello", "world"],
+ ids: [49406, 3306, 1002, 49407],
+ decoded: "<|startoftext|>hello world <|endoftext|>",
+ },
+ CHINESE_ONLY: {
+ text: BASE_TEST_STRINGS.CHINESE_ONLY,
+ tokens: ["\u00e7\u0136\u0141", "\u00e6", "\u00b4", "\u00bb", "\u00e7", "\u013c", "\u0126", "\u00e7\u013e\u0141", "\u00e8", "\u00b0", "\u013d", "\u00e6\u013a", "\u00af"],
+ ids: [49406, 33375, 162, 112, 119, 163, 248, 226, 41570, 164, 108, 249, 42891, 363, 49407],
+ decoded: "<|startoftext|>\u751f\u6d3b\u7684\u771f\u8c1b\u662f <|endoftext|>",
+ },
+ LEADING_SPACE: {
+ text: BASE_TEST_STRINGS.LEADING_SPACE,
+ tokens: ["leading", "space"],
+ ids: [49406, 3833, 2138, 49407],
+ decoded: "<|startoftext|>leading space <|endoftext|>",
+ },
+ TRAILING_SPACE: {
+ text: BASE_TEST_STRINGS.TRAILING_SPACE,
+ tokens: ["trailing", "space"],
+ ids: [49406, 37427, 2138, 49407],
+ decoded: "<|startoftext|>trailing space <|endoftext|>",
+ },
+ DOUBLE_SPACE: {
+ text: BASE_TEST_STRINGS.DOUBLE_SPACE,
+ tokens: ["hi", "hello"],
+ ids: [49406, 1883, 3306, 49407],
+ decoded: "<|startoftext|>hi hello <|endoftext|>",
+ },
+ CURRENCY: {
+ text: BASE_TEST_STRINGS.CURRENCY,
+ tokens: ["test", "$", "1", "r", "2", "#", "3", "\u00e2\u0124\u00ac", "4", "\u00c2\u00a3", "5", "\u00c2\u00a5", "6", "\u00e2\u0124", "\u00a3", "7", "\u00e2\u0124\u00b9", "8", "\u00e2\u0124", "\u00b1", "9", "test"],
+ ids: [49406, 1628, 259, 272, 337, 273, 258, 274, 6309, 275, 1950, 276, 20199, 277, 5227, 352, 278, 21777, 279, 5227, 365, 280, 1628, 49407],
+ decoded: "<|startoftext|>test $ 1 r 2 # 3 \u20ac 4 \u00a3 5 \u00a5 6 \u20a3 7 \u20b9 8 \u20b1 9 test <|endoftext|>",
+ },
+ CURRENCY_WITH_DECIMALS: {
+ text: BASE_TEST_STRINGS.CURRENCY_WITH_DECIMALS,
+ tokens: ["i", "bought", "an", "apple", "for", "$", "1", ".", "0", "0", "at", "the", "store", "."],
+ ids: [49406, 328, 4142, 550, 3055, 556, 259, 272, 269, 271, 271, 536, 518, 2183, 269, 49407],
+ decoded: "<|startoftext|>i bought an apple for $ 1. 0 0 at the store. <|endoftext|>",
+ },
+ ELLIPSIS: {
+ text: BASE_TEST_STRINGS.ELLIPSIS,
+ tokens: ["you", "\u00e2\u0122\u00a6"],
+ ids: [49406, 592, 959, 49407],
+ decoded: "<|startoftext|>you \u2026 <|endoftext|>",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS,
+ tokens: ["you", "\u00e2\u0122\u00a6"],
+ ids: [49406, 592, 959, 49407],
+ decoded: "<|startoftext|>you \u2026 <|endoftext|>",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS_2: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS_2,
+ tokens: ["you", "\u00e2\u0122\u00a6", "you", "\u00e2\u0122\u00a6"],
+ ids: [49406, 592, 959, 592, 959, 49407],
+ decoded: "<|startoftext|>you \u2026 you \u2026 <|endoftext|>",
+ },
+ TILDE_NORMALIZATION: {
+ text: BASE_TEST_STRINGS.TILDE_NORMALIZATION,
+ tokens: ["weird", "\u00ef", "\u00bd", "\u0140", "edge", "\u00ef", "\u00bd", "\u0140", "case"],
+ ids: [49406, 5613, 171, 121, 508, 5461, 171, 121, 508, 2068, 49407],
+ decoded: "<|startoftext|>weird \uff5e edge \uff5e case <|endoftext|>",
+ },
+ SPIECE_UNDERSCORE: {
+ text: BASE_TEST_STRINGS.SPIECE_UNDERSCORE,
+ tokens: ["\u00e2\u0138", "\u0123", "this", "\u00e2\u0138", "\u0123", "is", "\u00e2\u0138", "\u0123", "a", "\u00e2\u0138", "\u0123", "test", "\u00e2\u0138", "\u0123", "."],
+ ids: [49406, 4168, 479, 589, 4168, 479, 533, 4168, 479, 320, 4168, 479, 1628, 4168, 223, 269, 49407],
+ decoded: "<|startoftext|>\u2581 this \u2581 is \u2581 a \u2581 test \u2581. <|endoftext|>",
+ },
+ POPULAR_EMOJIS: {
+ text: BASE_TEST_STRINGS.POPULAR_EMOJIS,
+ tokens: ["\u00f0\u0141\u013a\u0124", "\u00f0\u0141\u0133\u012f", "\u00f0\u0141\u00a4\u00a3", "\u00f0\u0141\u013a\u012f", "\u00f0\u0141\u013a\u0143", "\u00f0\u0141\u0130\u012b", "\u00f0\u0141\u013b\u0131", "\u00f0\u0141\u013a\u012c", "\u00f0\u0141\u0136\u00a5", "\u00f0\u0141\u013a\u0123", "\u00f0\u0141\u013a\u0127", "\u00f0\u0141\u00a4\u0139", "\u00f0\u0141\u013a\u0128", "\u00f0\u0141\u0133\u0131", "\u00e2\u013f\u00a4\u00ef\u00b8\u0131", "\u00f0\u0141\u0134\u013e", "\u00f0\u0141\u0134\u013c", "\u00f0\u0141\u0134\u0139", "\u00f0\u0141\u0134\u013b", "\u00f0\u0141\u0138\u00a4", "\u00f0\u0141\u013a\u0130", "\u00f0\u0141\u0133\u012e", "\u00f0\u0141\u00a5\u00b3", "\u00f0\u0141\u0134\u00aa", "\u00e2\u013e\u00a8", "\u00f0\u0141\u0133\u012b", "\u00f0\u0141\u0133\u0122", "\u00f0\u0141\u0134\u00af", "\u00f0\u0141\u0130\u012a", "\u00f0\u0141\u013b\u012a", "\u00f0\u0141\u013b\u012e", "\u00f0\u0141\u0134\u0122", "\u00f0\u0141\u0133\u0129", "\u00f0\u0141\u0133\u012d", "\u00e2\u013e\u0127", "\u00f0\u0141\u0130\u0123", "\u00f0\u0141\u012e\u0140", "\u00f0\u0141\u012e\u00b8", "\u00f0\u0141\u0134\u00b0"],
+ ids: [49406, 1558, 4201, 9909, 1754, 3915, 3986, 5503, 3020, 3016, 4821, 9188, 10465, 10943, 4829, 1752, 4882, 6521, 6690, 4074, 10860, 4345, 4494, 28055, 6440, 3531, 3988, 5908, 7018, 14448, 9516, 4855, 12158, 7475, 17686, 5564, 13462, 12980, 10980, 14078, 49407],
+ decoded: "<|startoftext|>\ud83d\ude02 \ud83d\udc4d \ud83e\udd23 \ud83d\ude0d \ud83d\ude2d \ud83c\udf89 \ud83d\ude4f \ud83d\ude0a \ud83d\udd25 \ud83d\ude01 \ud83d\ude05 \ud83e\udd17 \ud83d\ude06 \ud83d\udc4f \u2764\ufe0f \ud83d\udc9c \ud83d\udc9a \ud83d\udc97 \ud83d\udc99 \ud83d\udda4 \ud83d\ude0e \ud83d\udc4c \ud83e\udd73 \ud83d\udcaa \u2728 \ud83d\udc49 \ud83d\udc40 \ud83d\udcaf \ud83c\udf88 \ud83d\ude48 \ud83d\ude4c \ud83d\udc80 \ud83d\udc47 \ud83d\udc4b \u2705 \ud83c\udf81 \ud83c\udf1e \ud83c\udf38 \ud83d\udcb0 <|endoftext|>",
+ },
+ MULTIBYTE_EMOJIS: {
+ text: BASE_TEST_STRINGS.MULTIBYTE_EMOJIS,
+ tokens: ["\u00e2\u013e\u00a8", "\u00f0\u0141\u00a4\u0139", "\u00f0\u0141\u0133\u0123", "\u00ef\u00b8\u0131", "\u00f0\u0141\u0133", "\u00b1", "\u00f0\u0141\u0131\u00bb", "\u00f0\u0141\u0137", "\u00b5", "\u00e2\u0122\u012f\u00e2\u013b\u0124\u00ef\u00b8\u0131", "\u00f0\u0141\u00a7", "\u013b", "\u00f0\u0141\u0131\u00bb", "\u00e2\u0122\u012f\u00e2\u013b", "\u0124", "\u00f0\u0141\u0133\u00a8", "\u00f0\u0141\u0131\u00bb\u00e2\u0122\u012f", "\u00f0\u0141\u012e\u00be", "\u00f0\u0141\u00a7", "\u0133", "\u00e2\u0122\u012f", "\u00f0\u0141\u00a4", "\u013f", "\u00e2\u0122\u012f", "\u00f0\u0141\u00a7", "\u0133", "\u00f0\u0141\u0133\u00a9\u00e2\u0122\u012f", "\u00e2\u013f\u00a4", "\u00e2\u0122\u012f", "\u00f0\u0141\u0134\u012d", "\u00e2\u0122\u012f", "\u00f0\u0141\u0133", "\u00a8", "\u00f0\u0141\u0133\u00a9\u00e2\u0122\u012f", "\u00f0\u0141\u0133\u00a9\u00e2\u0122\u012f", "\u00f0\u0141\u0133\u00a7", "\u00e2\u0122\u012f", "\u00f0\u0141\u0133", "\u00a6", "\u00f0\u0141\u00a7", "\u0133", "\u00f0\u0141\u0131\u00bb\u00e2\u0122\u012f", "\u00f0\u0141\u00a4", "\u013f", "\u00e2\u0122\u012f", "\u00f0\u0141\u00a7", "\u0133", "\u00f0\u0141\u0131\u00bb", "\u00f0\u0141\u0131\u00b4", "\u00f3", "\u0142", "\u0123", "\u00a7", "\u00f3", "\u0142", "\u0123", "\u00a2", "\u00f3", "\u0142", "\u0123", "\u00a5", "\u00f3", "\u0142", "\u0123", "\u00ae", "\u00f3", "\u0142", "\u0123", "\u00a7", "\u00f3", "\u0142", "\u0123", "\u00bf", "\u00f0\u0141\u0133\u00a8", "\u00f0\u0141\u0131\u00bb\u00e2\u0122\u012f", "\u00e2\u013f\u00a4\u00ef\u00b8\u0131", "\u00e2\u0122\u012f", "\u00f0\u0141\u0134\u012d", "\u00e2\u0122\u012f", "\u00f0\u0141\u0133\u00a8", "\u00f0\u0141\u0131\u00bc"],
+ ids: [49406, 3531, 10465, 47796, 1001, 964, 109, 3702, 7692, 113, 10613, 8792, 247, 5042, 5177, 480, 18966, 46250, 39796, 8792, 239, 4244, 1793, 251, 4244, 8792, 495, 26304, 1266, 4244, 12217, 4244, 964, 357, 26304, 26304, 48938, 4244, 964, 355, 8792, 239, 46250, 1793, 251, 4244, 8792, 239, 3702, 39690, 175, 254, 223, 100, 175, 254, 223, 95, 175, 254, 223, 98, 175, 254, 223, 106, 175, 254, 223, 100, 175, 254, 223, 379, 18966, 46250, 2626, 4244, 12217, 4244, 18966, 4027, 49407],
+ decoded: "<|startoftext|>\u2728 \ud83e\udd17 \ud83d\udc41\ufe0f \ud83d\udc71\ud83c\udffb \ud83d\udd75\u200d\u2642\ufe0f \ud83e\uddd9\ud83c\udffb\u200d\u2642 \ud83d\udc68\ud83c\udffb\u200d\ud83c\udf3e \ud83e\uddd1\u200d\ud83e\udd1d\u200d\ud83e\uddd1 \ud83d\udc69\u200d\u2764\u200d\ud83d\udc8b\u200d\ud83d\udc68 \ud83d\udc69\u200d\ud83d\udc69\u200d\ud83d\udc67\u200d\ud83d\udc66 \ud83e\uddd1\ud83c\udffb\u200d\ud83e\udd1d\u200d\ud83e\uddd1\ud83c\udffb \ud83c\udff4\udb40\udc67\udb40\udc62\udb40\udc65\udb40\udc6e\udb40\udc67\udb40\udc7f \ud83d\udc68\ud83c\udffb\u200d\u2764\ufe0f\u200d\ud83d\udc8b\u200d\ud83d\udc68\ud83c\udffc <|endoftext|>",
+ },
+ },
+ "Xenova/owlvit-base-patch32": {
+ PUNCTUATION: {
+ text: BASE_TEST_STRINGS.PUNCTUATION,
+ tokens: ["a", "'ll", "!", "!", "to", "?'", "d", "''", "d", "of", ",", "can", "'t", "."],
+ ids: [49406, 320, 1342, 0, 0, 531, 13610, 323, 8445, 323, 539, 267, 753, 713, 269, 49407],
+ decoded: "<|startoftext|>a 'll!!to?' d '' d of, can 't. <|endoftext|>",
+ },
+ },
+};
diff --git a/tests/models/deberta-v2/tokenization.js b/tests/models/deberta-v2/tokenization.js
new file mode 100644
index 000000000..177502340
--- /dev/null
+++ b/tests/models/deberta-v2/tokenization.js
@@ -0,0 +1,304 @@
+import { DebertaV2Tokenizer } from "../../../src/tokenizers.js";
+import { BASE_TEST_STRINGS, BERT_TEST_STRINGS } from "../test_strings.js";
+
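+// Expected tokenization fixtures, keyed by model repo. Each case pairs an input string with
+// the tokens, the input ids (including the [CLS]/[SEP] special tokens), and the decoded
+// round-trip text the tokenizer is expected to produce.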
+export const TOKENIZER_CLASS = DebertaV2Tokenizer;
+export const TEST_CONFIG = {
+ "Xenova/nli-deberta-v3-small": {
+ SIMPLE: {
+ text: BASE_TEST_STRINGS.SIMPLE,
+ tokens: ["\u2581How", "\u2581are", "\u2581you", "\u2581doing", "?"],
+ ids: [1, 577, 281, 274, 653, 302, 2],
+ decoded: "[CLS] How are you doing?[SEP]",
+ },
+ SIMPLE_WITH_PUNCTUATION: {
+ text: BASE_TEST_STRINGS.SIMPLE_WITH_PUNCTUATION,
+ tokens: ["\u2581You", "\u2581should", "'", "ve", "\u2581done", "\u2581this"],
+ ids: [1, 367, 403, 280, 415, 619, 291, 2],
+ decoded: "[CLS] You should've done this[SEP]",
+ },
+ NUMBERS: {
+ text: BASE_TEST_STRINGS.NUMBERS,
+ tokens: ["\u25810", "123456", "789", "\u25810", "\u25811", "\u25812", "\u25813", "\u25814", "\u25815", "\u25816", "\u25817", "\u25818", "\u25819", "\u258110", "\u2581100", "\u25811000"],
+ ids: [1, 767, 120304, 51535, 767, 376, 392, 404, 453, 456, 525, 574, 578, 712, 466, 803, 4985, 2],
+ decoded: "[CLS] 0123456789 0 1 2 3 4 5 6 7 8 9 10 100 1000[SEP]",
+ },
+ TEXT_WITH_NUMBERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_NUMBERS,
+ tokens: ["\u2581The", "\u2581company", "\u2581was", "\u2581founded", "\u2581in", "\u25812016", "."],
+ ids: [1, 279, 483, 284, 3679, 267, 892, 260, 2],
+ decoded: "[CLS] The company was founded in 2016.[SEP]",
+ },
+ PUNCTUATION: {
+ text: BASE_TEST_STRINGS.PUNCTUATION,
+ tokens: ["\u2581A", "\u2581'", "ll", "\u2581!", "!", "to", "?", "'", "d", "'", "'", "d", "\u2581of", ",", "\u2581can", "'", "t", "."],
+ ids: [1, 336, 382, 436, 1084, 300, 725, 302, 280, 407, 280, 280, 407, 265, 261, 295, 280, 297, 260, 2],
+ decoded: "[CLS] A 'll!!to?'d''d of, can't.[SEP]",
+ },
+ PYTHON_CODE: {
+ text: BASE_TEST_STRINGS.PYTHON_CODE,
+ tokens: ["\u2581def", "\u2581main", "(", ")", ":", "\u2581pass"],
+ ids: [1, 23097, 872, 555, 285, 294, 1633, 2],
+ decoded: "[CLS] def main(): pass[SEP]",
+ },
+ JAVASCRIPT_CODE: {
+ text: BASE_TEST_STRINGS.JAVASCRIPT_CODE,
+ tokens: ["\u2581let", "\u2581a", "\u2581=", "\u2581obj", ".", "to", "String", "(", ")", ";", "\u2581to", "String", "(", ")", ";"],
+ ids: [1, 678, 266, 1842, 68215, 260, 725, 29867, 555, 285, 346, 264, 29867, 555, 285, 346, 2],
+ decoded: "[CLS] let a = obj.toString(); toString();[SEP]",
+ },
+ NEWLINES: {
+ text: BASE_TEST_STRINGS.NEWLINES,
+ tokens: ["\u2581This", "\u2581is", "\u2581a", "\u2581test", "."],
+ ids: [1, 329, 269, 266, 1010, 260, 2],
+ decoded: "[CLS] This is a test.[SEP]",
+ },
+ BASIC: {
+ text: BASE_TEST_STRINGS.BASIC,
+ tokens: ["\u2581UN", "want", "\u00e9", "d", ",", "running"],
+ ids: [1, 4647, 27364, 5858, 407, 261, 15243, 2],
+ decoded: "[CLS] UNwant\u00e9d,running[SEP]",
+ },
+ CONTROL_TOKENS: {
+ text: BASE_TEST_STRINGS.CONTROL_TOKENS,
+ tokens: ["\u25811", "\u0000", "2", "\u25813"],
+ ids: [1, 376, 3, 445, 404, 2],
+ decoded: "[CLS] 1[UNK]2 3[SEP]",
+ },
+ HELLO_WORLD_TITLECASE: {
+ text: BASE_TEST_STRINGS.HELLO_WORLD_TITLECASE,
+ tokens: ["\u2581Hello", "\u2581World"],
+ ids: [1, 5365, 964, 2],
+ decoded: "[CLS] Hello World[SEP]",
+ },
+ HELLO_WORLD_LOWERCASE: {
+ text: BASE_TEST_STRINGS.HELLO_WORLD_LOWERCASE,
+ tokens: ["\u2581hello", "\u2581world"],
+ ids: [1, 12018, 447, 2],
+ decoded: "[CLS] hello world[SEP]",
+ },
+ CHINESE_ONLY: {
+ text: BASE_TEST_STRINGS.CHINESE_ONLY,
+ tokens: ["\u2581", "\u751f", "\u6d3b", "\u7684", "\u771f", "\u8c1b", "\u662f"],
+ ids: [1, 507, 41065, 101952, 9301, 98186, 3, 30060, 2],
+ decoded: "[CLS] \u751f\u6d3b\u7684\u771f[UNK]\u662f[SEP]",
+ },
+ LEADING_SPACE: {
+ text: BASE_TEST_STRINGS.LEADING_SPACE,
+ tokens: ["\u2581leading", "\u2581space"],
+ ids: [1, 1249, 754, 2],
+ decoded: "[CLS] leading space[SEP]",
+ },
+ TRAILING_SPACE: {
+ text: BASE_TEST_STRINGS.TRAILING_SPACE,
+ tokens: ["\u2581trailing", "\u2581space"],
+ ids: [1, 18347, 754, 2],
+ decoded: "[CLS] trailing space[SEP]",
+ },
+ DOUBLE_SPACE: {
+ text: BASE_TEST_STRINGS.DOUBLE_SPACE,
+ tokens: ["\u2581Hi", "\u2581Hello"],
+ ids: [1, 2684, 5365, 2],
+ decoded: "[CLS] Hi Hello[SEP]",
+ },
+ CURRENCY: {
+ text: BASE_TEST_STRINGS.CURRENCY,
+ tokens: ["\u2581test", "\u2581$", "1", "\u2581R", "2", "\u2581#", "3", "\u2581\u20ac4", "\u2581\u00a35", "\u2581\u00a5", "6", "\u2581", "\u20a3", "7", "\u2581\u20b9", "8", "\u2581\u20b1", "9", "\u2581test"],
+ ids: [1, 1010, 419, 435, 909, 445, 953, 508, 56238, 14636, 56478, 765, 507, 3, 819, 34880, 804, 121499, 1088, 1010, 2],
+ decoded: "[CLS] test $1 R2 #3 \u20ac4 \u00a35 \u00a56 [UNK]7 \u20b98 \u20b19 test[SEP]",
+ },
+ CURRENCY_WITH_DECIMALS: {
+ text: BASE_TEST_STRINGS.CURRENCY_WITH_DECIMALS,
+ tokens: ["\u2581I", "\u2581bought", "\u2581an", "\u2581apple", "\u2581for", "\u2581$", "1", ".", "00", "\u2581at", "\u2581the", "\u2581store", "."],
+ ids: [1, 273, 2031, 299, 6038, 270, 419, 435, 260, 962, 288, 262, 1106, 260, 2],
+ decoded: "[CLS] I bought an apple for $1.00 at the store.[SEP]",
+ },
+ ELLIPSIS: {
+ text: BASE_TEST_STRINGS.ELLIPSIS,
+ tokens: ["\u2581you", ".", ".", "."],
+ ids: [1, 274, 260, 260, 260, 2],
+ decoded: "[CLS] you...[SEP]",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS,
+ tokens: ["\u2581you", ".", ".", "."],
+ ids: [1, 274, 260, 260, 260, 2],
+ decoded: "[CLS] you...[SEP]",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS_2: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS_2,
+ tokens: ["\u2581you", ".", ".", ".", "\u2581you", ".", ".", "."],
+ ids: [1, 274, 260, 260, 260, 274, 260, 260, 260, 2],
+ decoded: "[CLS] you... you...[SEP]",
+ },
+ TILDE_NORMALIZATION: {
+ text: BASE_TEST_STRINGS.TILDE_NORMALIZATION,
+ tokens: ["\u2581weird", "\u2581", "\uff5e", "\u2581edge", "\u2581", "\uff5e", "\u2581case"],
+ ids: [1, 4926, 507, 96622, 2363, 507, 96622, 571, 2],
+ decoded: "[CLS] weird \uff5e edge \uff5e case[SEP]",
+ },
+ SPIECE_UNDERSCORE: {
+ text: BASE_TEST_STRINGS.SPIECE_UNDERSCORE,
+ tokens: ["\u2581This", "\u2581is", "\u2581a", "\u2581test", "\u2581."],
+ ids: [1, 329, 269, 266, 1010, 323, 2],
+ decoded: "[CLS] This is a test.[SEP]",
+ },
+ POPULAR_EMOJIS: {
+ text: BASE_TEST_STRINGS.POPULAR_EMOJIS,
+ tokens: ["\u2581\ud83d\ude02", "\u2581", "\ud83d\udc4d", "\u2581", "\ud83e\udd23", "\u2581", "\ud83d\ude0d", "\u2581", "\ud83d\ude2d", "\u2581", "\ud83c\udf89", "\u2581", "\ud83d\ude4f", "\u2581\ud83d\ude0a", "\u2581\ud83d\udd25", "\u2581", "\ud83d\ude01", "\u2581", "\ud83d\ude05", "\u2581", "\ud83e\udd17", "\u2581", "\ud83d\ude06", "\u2581", "\ud83d\udc4f", "\u2581\u2764", "\ufe0f", "\u2581", "\ud83d\udc9c", "\u2581", "\ud83d\udc9a", "\u2581", "\ud83d\udc97", "\u2581", "\ud83d\udc99", "\u2581", "\ud83d\udda4", "\u2581", "\ud83d\ude0e", "\u2581", "\ud83d\udc4c", "\u2581", "\ud83e\udd73", "\u2581", "\ud83d\udcaa", "\u2581", "\u2728", "\u2581", "\ud83d\udc49", "\u2581", "\ud83d\udc40", "\u2581", "\ud83d\udcaf", "\u2581", "\ud83c\udf88", "\u2581", "\ud83d\ude48", "\u2581", "\ud83d\ude4c", "\u2581", "\ud83d\udc80", "\u2581", "\ud83d\udc47", "\u2581", "\ud83d\udc4b", "\u2581\u2705", "\u2581", "\ud83c\udf81", "\u2581", "\ud83c\udf1e", "\u2581", "\ud83c\udf38", "\u2581", "\ud83d\udcb0"],
+ ids: [1, 97504, 507, 117545, 507, 123057, 507, 96353, 507, 123058, 507, 123169, 507, 121772, 109976, 115475, 507, 122874, 507, 124017, 507, 123983, 507, 123571, 507, 122632, 49509, 25377, 507, 123614, 507, 124105, 507, 124077, 507, 123384, 507, 124382, 507, 123340, 507, 123492, 507, 3, 507, 123306, 507, 110119, 507, 122633, 507, 123659, 507, 123765, 507, 125799, 507, 124322, 507, 122878, 507, 125843, 507, 124011, 507, 125021, 88523, 507, 124698, 507, 125612, 507, 123887, 507, 123979, 2],
+ decoded: "[CLS] \ud83d\ude02 \ud83d\udc4d \ud83e\udd23 \ud83d\ude0d \ud83d\ude2d \ud83c\udf89 \ud83d\ude4f \ud83d\ude0a \ud83d\udd25 \ud83d\ude01 \ud83d\ude05 \ud83e\udd17 \ud83d\ude06 \ud83d\udc4f \u2764\ufe0f \ud83d\udc9c \ud83d\udc9a \ud83d\udc97 \ud83d\udc99 \ud83d\udda4 \ud83d\ude0e \ud83d\udc4c [UNK] \ud83d\udcaa \u2728 \ud83d\udc49 \ud83d\udc40 \ud83d\udcaf \ud83c\udf88 \ud83d\ude48 \ud83d\ude4c \ud83d\udc80 \ud83d\udc47 \ud83d\udc4b \u2705 \ud83c\udf81 \ud83c\udf1e \ud83c\udf38 \ud83d\udcb0[SEP]",
+ },
+ MULTIBYTE_EMOJIS: {
+ text: BASE_TEST_STRINGS.MULTIBYTE_EMOJIS,
+ tokens: ["\u2581", "\u2728", "\u2581", "\ud83e\udd17", "\u2581", "\ud83d\udc41", "\ufe0f", "\u2581", "\ud83d\udc71", "\ud83c\udffb", "\u2581", "\ud83d\udd75", "\u2581", "\u2642", "\ufe0f", "\u2581", "\ud83e\uddd9", "\ud83c\udffb", "\u2581", "\u2642", "\u2581", "\ud83d\udc68", "\ud83c\udffb", "\u2581", "\ud83c\udf3e", "\u2581", "\ud83e\uddd1", "\u2581", "\ud83e\udd1d", "\u2581", "\ud83e\uddd1", "\u2581", "\ud83d\udc69", "\u2581\u2764", "\u2581", "\ud83d\udc8b", "\u2581", "\ud83d\udc68", "\u2581", "\ud83d\udc69", "\u2581", "\ud83d\udc69", "\u2581", "\ud83d\udc67", "\u2581", "\ud83d\udc66", "\u2581", "\ud83e\uddd1", "\ud83c\udffb", "\u2581", "\ud83e\udd1d", "\u2581", "\ud83e\uddd1", "\ud83c\udffb", "\u2581", "\ud83c\udff4", "\udb40\udc67\udb40\udc62\udb40\udc65\udb40\udc6e\udb40\udc67\udb40\udc7f", "\u2581", "\ud83d\udc68", "\ud83c\udffb", "\u2581\u2764", "\ufe0f", "\u2581", "\ud83d\udc8b", "\u2581", "\ud83d\udc68", "\ud83c\udffc"],
+ ids: [1, 507, 110119, 507, 123983, 507, 127294, 25377, 507, 3, 108391, 507, 3, 507, 117868, 25377, 507, 3, 108391, 507, 117868, 507, 125199, 108391, 507, 3, 507, 3, 507, 3, 507, 3, 507, 124709, 49509, 507, 124327, 507, 125199, 507, 124709, 507, 124709, 507, 126640, 507, 126853, 507, 3, 108391, 507, 3, 507, 3, 108391, 507, 126132, 3, 507, 125199, 108391, 49509, 25377, 507, 124327, 507, 125199, 118155, 2],
+ decoded: "[CLS] \u2728 \ud83e\udd17 \ud83d\udc41\ufe0f [UNK]\ud83c\udffb [UNK] \u2642\ufe0f [UNK]\ud83c\udffb \u2642 \ud83d\udc68\ud83c\udffb [UNK] [UNK] [UNK] [UNK] \ud83d\udc69 \u2764 \ud83d\udc8b \ud83d\udc68 \ud83d\udc69 \ud83d\udc69 \ud83d\udc67 \ud83d\udc66 [UNK]\ud83c\udffb [UNK] [UNK]\ud83c\udffb \ud83c\udff4[UNK] \ud83d\udc68\ud83c\udffb \u2764\ufe0f \ud83d\udc8b \ud83d\udc68\ud83c\udffc[SEP]",
+ },
+ CHINESE_LATIN_MIXED: {
+ text: BERT_TEST_STRINGS.CHINESE_LATIN_MIXED,
+ tokens: ["\u2581a", "h", "\u535a", "\u63a8", "zz"],
+ ids: [1, 266, 1537, 122598, 111743, 23260, 2],
+ decoded: "[CLS] ah\u535a\u63a8zz[SEP]",
+ },
+ SIMPLE_WITH_ACCENTS: {
+ text: BERT_TEST_STRINGS.SIMPLE_WITH_ACCENTS,
+ tokens: ["\u2581H\u00e9", "llo"],
+ ids: [1, 93519, 25341, 2],
+ decoded: "[CLS] H\u00e9llo[SEP]",
+ },
+ MIXED_CASE_WITHOUT_ACCENTS: {
+ text: BERT_TEST_STRINGS.MIXED_CASE_WITHOUT_ACCENTS,
+ tokens: ["\u2581He", "LL", "o", "!", "how", "\u2581Are", "\u2581yo", "U", "?"],
+ ids: [1, 383, 17145, 795, 300, 5608, 1396, 14469, 2628, 302, 2],
+ decoded: "[CLS] HeLLo!how Are yoU?[SEP]",
+ },
+ MIXED_CASE_WITH_ACCENTS: {
+ text: BERT_TEST_STRINGS.MIXED_CASE_WITH_ACCENTS,
+ tokens: ["\u2581H\u00e4", "LL", "o", "!", "how", "\u2581Are", "\u2581yo", "U", "?"],
+ ids: [1, 62693, 17145, 795, 300, 5608, 1396, 14469, 2628, 302, 2],
+ decoded: "[CLS] H\u00e4LLo!how Are yoU?[SEP]",
+ },
+ },
+ "Xenova/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7": {
+ SIMPLE: {
+ text: BASE_TEST_STRINGS.SIMPLE,
+ tokens: ["\u2581How", "\u2581are", "\u2581you", "\u2581do", "ing", "?"],
+ ids: [1, 5101, 419, 522, 343, 348, 292, 2],
+ decoded: "[CLS] How are you doing?[SEP]",
+ },
+ NUMBERS: {
+ text: BASE_TEST_STRINGS.NUMBERS,
+ tokens: ["\u2581", "0123456789", "\u25810", "\u25811", "\u25812", "\u25813", "\u25814", "\u25815", "\u25816", "\u25817", "\u25818", "\u25819", "\u258110", "\u2581100", "\u25811000"],
+ ids: [1, 260, 170160, 498, 334, 357, 382, 420, 431, 571, 618, 631, 775, 476, 967, 3884, 2],
+ decoded: "[CLS] 0123456789 0 1 2 3 4 5 6 7 8 9 10 100 1000[SEP]",
+ },
+ TEXT_WITH_NUMBERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_NUMBERS,
+ tokens: ["\u2581The", "\u2581company", "\u2581was", "\u2581found", "ed", "\u2581in", "\u25812016."],
+ ids: [1, 487, 5836, 640, 5898, 346, 282, 13792, 2],
+ decoded: "[CLS] The company was founded in 2016.[SEP]",
+ },
+ PUNCTUATION: {
+ text: BASE_TEST_STRINGS.PUNCTUATION,
+ tokens: ["\u2581A", "\u2581", "'", "ll", "\u2581", "!!", "to", "?", "'", "d", "''", "d", "\u2581of", ",", "\u2581can", "'", "t", "."],
+ ids: [1, 299, 260, 278, 1579, 260, 1524, 477, 292, 278, 286, 4461, 286, 305, 262, 739, 278, 271, 261, 2],
+ decoded: "[CLS] A 'll!!to?'d''d of, can't.[SEP]",
+ },
+ PYTHON_CODE: {
+ text: BASE_TEST_STRINGS.PYTHON_CODE,
+ tokens: ["\u2581de", "f", "\u2581main", "():", "\u2581pass"],
+ ids: [1, 270, 368, 4398, 78612, 4748, 2],
+ decoded: "[CLS] def main(): pass[SEP]",
+ },
+ JAVASCRIPT_CODE: {
+ text: BASE_TEST_STRINGS.JAVASCRIPT_CODE,
+ tokens: ["\u2581let", "\u2581", "a", "\u2581", "=", "\u2581obj", ".", "toString", "();", "\u2581", "toString", "();"],
+ ids: [1, 3257, 260, 263, 260, 350, 50670, 261, 64577, 1994, 260, 64577, 1994, 2],
+ decoded: "[CLS] let a = obj.toString(); toString();[SEP]",
+ },
+ NEWLINES: {
+ text: BASE_TEST_STRINGS.NEWLINES,
+ tokens: ["\u2581This", "\u2581is", "\u2581", "a", "\u2581test", "."],
+ ids: [1, 1495, 340, 260, 263, 2979, 261, 2],
+ decoded: "[CLS] This is a test.[SEP]",
+ },
+ BASIC: {
+ text: BASE_TEST_STRINGS.BASIC,
+ tokens: ["\u2581UN", "wan", "t\u00e9", "d", ",", "running"],
+ ids: [1, 10970, 3016, 3986, 286, 262, 170565, 2],
+ decoded: "[CLS] UNwant\u00e9d,running[SEP]",
+ },
+ CHINESE_ONLY: {
+ text: BASE_TEST_STRINGS.CHINESE_ONLY,
+ tokens: ["\u2581", "\u751f\u6d3b\u7684", "\u771f", "\u8c1b", "\u662f"],
+ ids: [1, 260, 197263, 7275, 241962, 1544, 2],
+ decoded: "[CLS] \u751f\u6d3b\u7684\u771f\u8c1b\u662f[SEP]",
+ },
+ LEADING_SPACE: {
+ text: BASE_TEST_STRINGS.LEADING_SPACE,
+ tokens: ["\u2581", "leading", "\u2581space"],
+ ids: [1, 260, 22120, 11496, 2],
+ decoded: "[CLS] leading space[SEP]",
+ },
+ TRAILING_SPACE: {
+ text: BASE_TEST_STRINGS.TRAILING_SPACE,
+ tokens: ["\u2581trail", "ing", "\u2581space"],
+ ids: [1, 66699, 348, 11496, 2],
+ decoded: "[CLS] trailing space[SEP]",
+ },
+ CURRENCY: {
+ text: BASE_TEST_STRINGS.CURRENCY,
+ tokens: ["\u2581test", "\u2581$1", "\u2581R", "2", "\u2581#3", "\u2581\u20ac4", "\u2581\u00a35", "\u2581\u00a5", "6", "\u2581", "\u20a3", "7", "\u2581\u20b9", "8", "\u2581", "\u20b1", "9", "\u2581test"],
+ ids: [1, 2979, 21793, 532, 339, 19403, 157186, 156260, 33481, 452, 260, 242687, 488, 39568, 450, 260, 211232, 496, 2979, 2],
+ decoded: "[CLS] test $1 R2 #3 \u20ac4 \u00a35 \u00a56 \u20a37 \u20b98 \u20b19 test[SEP]",
+ },
+ CURRENCY_WITH_DECIMALS: {
+ text: BASE_TEST_STRINGS.CURRENCY_WITH_DECIMALS,
+ tokens: ["\u2581I", "\u2581b", "ought", "\u2581an", "\u2581apple", "\u2581for", "\u2581$", "1.00", "\u2581at", "\u2581the", "\u2581store", "."],
+ ids: [1, 337, 331, 22280, 462, 44791, 333, 1161, 42645, 345, 288, 5318, 261, 2],
+ decoded: "[CLS] I bought an apple for $1.00 at the store.[SEP]",
+ },
+ ELLIPSIS: {
+ text: BASE_TEST_STRINGS.ELLIPSIS,
+ tokens: ["\u2581you", "..."],
+ ids: [1, 522, 303, 2],
+ decoded: "[CLS] you...[SEP]",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS,
+ tokens: ["\u2581you", "..."],
+ ids: [1, 522, 303, 2],
+ decoded: "[CLS] you...[SEP]",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS_2: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS_2,
+ tokens: ["\u2581you", "...", "\u2581you", "..."],
+ ids: [1, 522, 303, 522, 303, 2],
+ decoded: "[CLS] you... you...[SEP]",
+ },
+ TILDE_NORMALIZATION: {
+ text: BASE_TEST_STRINGS.TILDE_NORMALIZATION,
+ tokens: ["\u2581w", "eird", "\u2581", "\uff5e", "\u2581edge", "\u2581", "\uff5e", "\u2581case"],
+ ids: [1, 415, 116640, 260, 2790, 53876, 260, 2790, 4073, 2],
+ decoded: "[CLS] weird \uff5e edge \uff5e case[SEP]",
+ },
+ SPIECE_UNDERSCORE: {
+ text: BASE_TEST_STRINGS.SPIECE_UNDERSCORE,
+ tokens: ["\u2581This", "\u2581is", "\u2581", "a", "\u2581test", "\u2581", "."],
+ ids: [1, 1495, 340, 260, 263, 2979, 260, 261, 2],
+ decoded: "[CLS] This is a test.[SEP]",
+ },
+ POPULAR_EMOJIS: {
+ text: BASE_TEST_STRINGS.POPULAR_EMOJIS,
+ tokens: ["\u2581", "\ud83d\ude02", "\u2581", "\ud83d\udc4d", "\u2581", "\ud83e\udd23", "\u2581", "\ud83d\ude0d", "\u2581", "\ud83d\ude2d", "\u2581", "\ud83c\udf89", "\u2581", "\ud83d\ude4f", "\u2581", "\ud83d\ude0a", "\u2581", "\ud83d\udd25", "\u2581", "\ud83d\ude01", "\u2581", "\ud83d\ude05", "\u2581", "\ud83e\udd17", "\u2581", "\ud83d\ude06", "\u2581", "\ud83d\udc4f", "\u2581\u2764", "\ufe0f", "\u2581", "\ud83d\udc9c", "\u2581", "\ud83d\udc9a", "\u2581", "\ud83d\udc97", "\u2581", "\ud83d\udc99", "\u2581", "\ud83d\udda4", "\u2581", "\ud83d\ude0e", "\u2581", "\ud83d\udc4c", "\u2581", "\ud83e\udd73", "\u2581", "\ud83d\udcaa", "\u2581", "\u2728", "\u2581\ud83d\udc49", "\u2581", "\ud83d\udc40", "\u2581", "\ud83d\udcaf", "\u2581", "\ud83c\udf88", "\u2581", "\ud83d\ude48", "\u2581", "\ud83d\ude4c", "\u2581", "\ud83d\udc80", "\u2581", "\ud83d\udc47", "\u2581", "\ud83d\udc4b", "\u2581\u2705", "\u2581", "\ud83c\udf81", "\u2581", "\ud83c\udf1e", "\u2581", "\ud83c\udf38", "\u2581", "\ud83d\udcb0"],
+ ids: [1, 260, 116844, 260, 72330, 260, 160951, 260, 78796, 260, 180546, 260, 212774, 260, 102930, 260, 71509, 260, 96089, 260, 137652, 260, 194608, 260, 182033, 260, 164467, 260, 149267, 56787, 4668, 260, 210251, 260, 195202, 260, 178523, 260, 167604, 260, 236081, 260, 157800, 260, 162843, 260, 242580, 260, 174590, 260, 65271, 113700, 260, 239652, 260, 237474, 260, 240937, 260, 239131, 260, 216701, 260, 242618, 260, 133395, 260, 240645, 82147, 260, 49599, 260, 239888, 260, 152102, 260, 239168, 2],
+ decoded: "[CLS] \ud83d\ude02 \ud83d\udc4d \ud83e\udd23 \ud83d\ude0d \ud83d\ude2d \ud83c\udf89 \ud83d\ude4f \ud83d\ude0a \ud83d\udd25 \ud83d\ude01 \ud83d\ude05 \ud83e\udd17 \ud83d\ude06 \ud83d\udc4f \u2764\ufe0f \ud83d\udc9c \ud83d\udc9a \ud83d\udc97 \ud83d\udc99 \ud83d\udda4 \ud83d\ude0e \ud83d\udc4c \ud83e\udd73 \ud83d\udcaa \u2728 \ud83d\udc49 \ud83d\udc40 \ud83d\udcaf \ud83c\udf88 \ud83d\ude48 \ud83d\ude4c \ud83d\udc80 \ud83d\udc47 \ud83d\udc4b \u2705 \ud83c\udf81 \ud83c\udf1e \ud83c\udf38 \ud83d\udcb0[SEP]",
+ },
+ MULTIBYTE_EMOJIS: {
+ text: BASE_TEST_STRINGS.MULTIBYTE_EMOJIS,
+ tokens: ["\u2581", "\u2728", "\u2581", "\ud83e\udd17", "\u2581", "\ud83d\udc41", "\ufe0f", "\u2581", "\ud83d\udc71", "\ud83c\udffb", "\u2581", "\ud83d\udd75", "\u2581", "\u2642", "\ufe0f", "\u2581", "\ud83e\uddd9", "\ud83c\udffb", "\u2581", "\u2642", "\u2581", "\ud83d\udc68", "\ud83c\udffb", "\u2581", "\ud83c\udf3e", "\u2581", "\ud83e\uddd1", "\u2581", "\ud83e\udd1d", "\u2581", "\ud83e\uddd1", "\u2581", "\ud83d\udc69", "\u2581\u2764", "\u2581", "\ud83d\udc8b", "\u2581", "\ud83d\udc68", "\u2581", "\ud83d\udc69", "\u2581", "\ud83d\udc69", "\u2581", "\ud83d\udc67", "\u2581", "\ud83d\udc66", "\u2581", "\ud83e\uddd1", "\ud83c\udffb", "\u2581", "\ud83e\udd1d", "\u2581", "\ud83e\uddd1", "\ud83c\udffb", "\u2581", "\ud83c\udff4", "\udb40\udc67", "\udb40\udc62", "\udb40\udc65", "\udb40\udc6e", "\udb40\udc67", "\udb40\udc7f", "\u2581", "\ud83d\udc68", "\ud83c\udffb", "\u2581\u2764", "\ufe0f", "\u2581", "\ud83d\udc8b", "\u2581", "\ud83d\udc68", "\ud83c\udffc"],
+ ids: [1, 260, 65271, 260, 182033, 260, 16307, 4668, 260, 244774, 75846, 260, 247133, 260, 50622, 4668, 260, 3, 75846, 260, 50622, 260, 239432, 75846, 260, 243052, 260, 244250, 260, 243394, 260, 244250, 260, 239098, 56787, 260, 223802, 260, 239432, 260, 239098, 260, 239098, 260, 241727, 260, 242446, 260, 244250, 75846, 260, 243394, 260, 244250, 75846, 260, 244177, 245994, 247023, 248837, 248531, 245994, 245953, 260, 239432, 75846, 56787, 4668, 260, 223802, 260, 239432, 159667, 2],
+ decoded: "[CLS] \u2728 \ud83e\udd17 \ud83d\udc41\ufe0f \ud83d\udc71\ud83c\udffb \ud83d\udd75 \u2642\ufe0f [UNK]\ud83c\udffb \u2642 \ud83d\udc68\ud83c\udffb \ud83c\udf3e \ud83e\uddd1 \ud83e\udd1d \ud83e\uddd1 \ud83d\udc69 \u2764 \ud83d\udc8b \ud83d\udc68 \ud83d\udc69 \ud83d\udc69 \ud83d\udc67 \ud83d\udc66 \ud83e\uddd1\ud83c\udffb \ud83e\udd1d \ud83e\uddd1\ud83c\udffb \ud83c\udff4\udb40\udc67\udb40\udc62\udb40\udc65\udb40\udc6e\udb40\udc67\udb40\udc7f \ud83d\udc68\ud83c\udffb \u2764\ufe0f \ud83d\udc8b \ud83d\udc68\ud83c\udffc[SEP]",
+ },
+ },
+};
diff --git a/tests/models/distilbert/tokenization.js b/tests/models/distilbert/tokenization.js
new file mode 100644
index 000000000..5fc1f3b93
--- /dev/null
+++ b/tests/models/distilbert/tokenization.js
@@ -0,0 +1,306 @@
+import { DistilBertTokenizer } from "../../../src/tokenizers.js";
+import { BASE_TEST_STRINGS } from "../test_strings.js";
+
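+// Same fixture layout as the other tokenizer test files; note the WordPiece continuation
+// markers ("##") in the expected tokens and the [UNK] fallbacks for out-of-vocabulary characters.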
+export const TOKENIZER_CLASS = DistilBertTokenizer;
+export const TEST_CONFIG = {
+ "Xenova/distilbert-base-cased-distilled-squad": {
+ SIMPLE: {
+ text: BASE_TEST_STRINGS.SIMPLE,
+ tokens: ["How", "are", "you", "doing", "?"],
+ ids: [101, 1731, 1132, 1128, 1833, 136, 102],
+ decoded: "[CLS] How are you doing? [SEP]",
+ },
+ SIMPLE_WITH_PUNCTUATION: {
+ text: BASE_TEST_STRINGS.SIMPLE_WITH_PUNCTUATION,
+ tokens: ["You", "should", "'", "ve", "done", "this"],
+ ids: [101, 1192, 1431, 112, 1396, 1694, 1142, 102],
+ decoded: "[CLS] You should've done this [SEP]",
+ },
+ NUMBERS: {
+ text: BASE_TEST_STRINGS.NUMBERS,
+ tokens: ["01", "##23", "##45", "##6", "##7", "##8", "##9", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "100", "1000"],
+ ids: [101, 5187, 22737, 21336, 1545, 1559, 1604, 1580, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 1275, 1620, 6087, 102],
+ decoded: "[CLS] 0123456789 0 1 2 3 4 5 6 7 8 9 10 100 1000 [SEP]",
+ },
+ TEXT_WITH_NUMBERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_NUMBERS,
+ tokens: ["The", "company", "was", "founded", "in", "2016", "."],
+ ids: [101, 1109, 1419, 1108, 1771, 1107, 1446, 119, 102],
+ decoded: "[CLS] The company was founded in 2016. [SEP]",
+ },
+ PUNCTUATION: {
+ text: BASE_TEST_STRINGS.PUNCTUATION,
+ tokens: ["A", "'", "ll", "!", "!", "to", "?", "'", "d", "'", "'", "d", "of", ",", "can", "'", "t", "."],
+ ids: [101, 138, 112, 1325, 106, 106, 1106, 136, 112, 173, 112, 112, 173, 1104, 117, 1169, 112, 189, 119, 102],
+ decoded: "[CLS] A'll!! to?'d'' d of, can't. [SEP]",
+ },
+ PYTHON_CODE: {
+ text: BASE_TEST_STRINGS.PYTHON_CODE,
+ tokens: ["def", "main", "(", ")", ":", "pass"],
+ ids: [101, 19353, 1514, 113, 114, 131, 2789, 102],
+ decoded: "[CLS] def main ( ) : pass [SEP]",
+ },
+ JAVASCRIPT_CODE: {
+ text: BASE_TEST_STRINGS.JAVASCRIPT_CODE,
+ tokens: ["let", "a", "=", "o", "##b", "##j", ".", "to", "##S", "##tring", "(", ")", ";", "to", "##S", "##tring", "(", ")", ";"],
+ ids: [101, 1519, 170, 134, 184, 1830, 3361, 119, 1106, 1708, 28108, 113, 114, 132, 1106, 1708, 28108, 113, 114, 132, 102],
+ decoded: "[CLS] let a = obj. toString ( ) ; toString ( ) ; [SEP]",
+ },
+ NEWLINES: {
+ text: BASE_TEST_STRINGS.NEWLINES,
+ tokens: ["This", "is", "a", "test", "."],
+ ids: [101, 1188, 1110, 170, 2774, 119, 102],
+ decoded: "[CLS] This is a test. [SEP]",
+ },
+ BASIC: {
+ text: BASE_TEST_STRINGS.BASIC,
+ tokens: ["UN", "##wan", "##t\u00e9", "##d", ",", "running"],
+ ids: [101, 7414, 5491, 14608, 1181, 117, 1919, 102],
+ decoded: "[CLS] UNwant\u00e9d, running [SEP]",
+ },
+ CONTROL_TOKENS: {
+ text: BASE_TEST_STRINGS.CONTROL_TOKENS,
+ tokens: ["123"],
+ ids: [101, 13414, 102],
+ decoded: "[CLS] 123 [SEP]",
+ },
+ HELLO_WORLD_TITLECASE: {
+ text: BASE_TEST_STRINGS.HELLO_WORLD_TITLECASE,
+ tokens: ["Hello", "World"],
+ ids: [101, 8667, 1291, 102],
+ decoded: "[CLS] Hello World [SEP]",
+ },
+ HELLO_WORLD_LOWERCASE: {
+ text: BASE_TEST_STRINGS.HELLO_WORLD_LOWERCASE,
+ tokens: ["hello", "world"],
+ ids: [101, 19082, 1362, 102],
+ decoded: "[CLS] hello world [SEP]",
+ },
+ CHINESE_ONLY: {
+ text: BASE_TEST_STRINGS.CHINESE_ONLY,
+ tokens: ["\u751f", "[UNK]", "[UNK]", "\u771f", "[UNK]", "[UNK]"],
+ ids: [101, 1056, 100, 100, 1061, 100, 100, 102],
+ decoded: "[CLS] \u751f [UNK] [UNK] \u771f [UNK] [UNK] [SEP]",
+ },
+ LEADING_SPACE: {
+ text: BASE_TEST_STRINGS.LEADING_SPACE,
+ tokens: ["leading", "space"],
+ ids: [101, 2020, 2000, 102],
+ decoded: "[CLS] leading space [SEP]",
+ },
+ TRAILING_SPACE: {
+ text: BASE_TEST_STRINGS.TRAILING_SPACE,
+ tokens: ["trailing", "space"],
+ ids: [101, 13161, 2000, 102],
+ decoded: "[CLS] trailing space [SEP]",
+ },
+ DOUBLE_SPACE: {
+ text: BASE_TEST_STRINGS.DOUBLE_SPACE,
+ tokens: ["Hi", "Hello"],
+ ids: [101, 8790, 8667, 102],
+ decoded: "[CLS] Hi Hello [SEP]",
+ },
+ CURRENCY: {
+ text: BASE_TEST_STRINGS.CURRENCY,
+ tokens: ["test", "$", "1", "R", "##2", "#", "3", "\u20ac", "##4", "\u00a3", "##5", "\u00a5", "##6", "[UNK]", "\u20b9", "##8", "\u20b1", "##9", "test"],
+ ids: [101, 2774, 109, 122, 155, 1477, 108, 124, 836, 1527, 202, 1571, 203, 1545, 100, 838, 1604, 837, 1580, 2774, 102],
+ decoded: "[CLS] test $ 1 R2 # 3 \u20ac4 \u00a35 \u00a56 [UNK] \u20b98 \u20b19 test [SEP]",
+ },
+ CURRENCY_WITH_DECIMALS: {
+ text: BASE_TEST_STRINGS.CURRENCY_WITH_DECIMALS,
+ tokens: ["I", "bought", "an", "apple", "for", "$", "1", ".", "00", "at", "the", "store", "."],
+ ids: [101, 146, 3306, 1126, 12075, 1111, 109, 122, 119, 3135, 1120, 1103, 2984, 119, 102],
+ decoded: "[CLS] I bought an apple for $ 1. 00 at the store. [SEP]",
+ },
+ ELLIPSIS: {
+ text: BASE_TEST_STRINGS.ELLIPSIS,
+ tokens: ["you", "\u2026"],
+ ids: [101, 1128, 795, 102],
+ decoded: "[CLS] you \u2026 [SEP]",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS,
+ tokens: ["you", "\u2026"],
+ ids: [101, 1128, 795, 102],
+ decoded: "[CLS] you \u2026 [SEP]",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS_2: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS_2,
+ tokens: ["you", "\u2026", "you", "\u2026"],
+ ids: [101, 1128, 795, 1128, 795, 102],
+ decoded: "[CLS] you \u2026 you \u2026 [SEP]",
+ },
+ TILDE_NORMALIZATION: {
+ text: BASE_TEST_STRINGS.TILDE_NORMALIZATION,
+ tokens: ["weird", "[UNK]", "edge", "[UNK]", "case"],
+ ids: [101, 6994, 100, 2652, 100, 1692, 102],
+ decoded: "[CLS] weird [UNK] edge [UNK] case [SEP]",
+ },
+ SPIECE_UNDERSCORE: {
+ text: BASE_TEST_STRINGS.SPIECE_UNDERSCORE,
+ tokens: ["[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "."],
+ ids: [101, 100, 100, 100, 100, 100, 119, 102],
+ decoded: "[CLS] [UNK] [UNK] [UNK] [UNK] [UNK]. [SEP]",
+ },
+ POPULAR_EMOJIS: {
+ text: BASE_TEST_STRINGS.POPULAR_EMOJIS,
+ tokens: ["[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]"],
+ ids: [101, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 102],
+ decoded: "[CLS] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [SEP]",
+ },
+ MULTIBYTE_EMOJIS: {
+ text: BASE_TEST_STRINGS.MULTIBYTE_EMOJIS,
+ tokens: ["[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[UNK]"],
+ ids: [101, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 102],
+ decoded: "[CLS] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [SEP]",
+ },
+ },
+ "Xenova/distilbert-base-uncased-finetuned-sst-2-english": {
+ SIMPLE: {
+ text: BASE_TEST_STRINGS.SIMPLE,
+ tokens: ["how", "are", "you", "doing", "?"],
+ ids: [101, 2129, 2024, 2017, 2725, 1029, 102],
+ decoded: "[CLS] how are you doing? [SEP]",
+ },
+ SIMPLE_WITH_PUNCTUATION: {
+ text: BASE_TEST_STRINGS.SIMPLE_WITH_PUNCTUATION,
+ tokens: ["you", "should", "'", "ve", "done", "this"],
+ ids: [101, 2017, 2323, 1005, 2310, 2589, 2023, 102],
+ decoded: "[CLS] you should've done this [SEP]",
+ },
+ TEXT_WITH_NUMBERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_NUMBERS,
+ tokens: ["the", "company", "was", "founded", "in", "2016", "."],
+ ids: [101, 1996, 2194, 2001, 2631, 1999, 2355, 1012, 102],
+ decoded: "[CLS] the company was founded in 2016. [SEP]",
+ },
+ PUNCTUATION: {
+ text: BASE_TEST_STRINGS.PUNCTUATION,
+ tokens: ["a", "'", "ll", "!", "!", "to", "?", "'", "d", "'", "'", "d", "of", ",", "can", "'", "t", "."],
+ ids: [101, 1037, 1005, 2222, 999, 999, 2000, 1029, 1005, 1040, 1005, 1005, 1040, 1997, 1010, 2064, 1005, 1056, 1012, 102],
+ decoded: "[CLS] a'll!! to?'d'' d of, can't. [SEP]",
+ },
+ JAVASCRIPT_CODE: {
+ text: BASE_TEST_STRINGS.JAVASCRIPT_CODE,
+ tokens: ["let", "a", "=", "ob", "##j", ".", "to", "##st", "##ring", "(", ")", ";", "to", "##st", "##ring", "(", ")", ";"],
+ ids: [101, 2292, 1037, 1027, 27885, 3501, 1012, 2000, 3367, 4892, 1006, 1007, 1025, 2000, 3367, 4892, 1006, 1007, 1025, 102],
+ decoded: "[CLS] let a = obj. tostring ( ) ; tostring ( ) ; [SEP]",
+ },
+ NEWLINES: {
+ text: BASE_TEST_STRINGS.NEWLINES,
+ tokens: ["this", "is", "a", "test", "."],
+ ids: [101, 2023, 2003, 1037, 3231, 1012, 102],
+ decoded: "[CLS] this is a test. [SEP]",
+ },
+ BASIC: {
+ text: BASE_TEST_STRINGS.BASIC,
+ tokens: ["unwanted", ",", "running"],
+ ids: [101, 18162, 1010, 2770, 102],
+ decoded: "[CLS] unwanted, running [SEP]",
+ },
+ CHINESE_ONLY: {
+ text: BASE_TEST_STRINGS.CHINESE_ONLY,
+ tokens: ["\u751f", "[UNK]", "\u7684", "\u771f", "[UNK]", "[UNK]"],
+ ids: [101, 1910, 100, 1916, 1921, 100, 100, 102],
+ decoded: "[CLS] \u751f [UNK] \u7684 \u771f [UNK] [UNK] [SEP]",
+ },
+ DOUBLE_SPACE: {
+ text: BASE_TEST_STRINGS.DOUBLE_SPACE,
+ tokens: ["hi", "hello"],
+ ids: [101, 7632, 7592, 102],
+ decoded: "[CLS] hi hello [SEP]",
+ },
+ CURRENCY: {
+ text: BASE_TEST_STRINGS.CURRENCY,
+ tokens: ["test", "$", "1", "r", "##2", "#", "3", "\u20ac", "##4", "\u00a35", "\u00a5", "##6", "[UNK]", "\u20b9", "##8", "\u20b1", "##9", "test"],
+ ids: [101, 3231, 1002, 1015, 1054, 2475, 1001, 1017, 1574, 2549, 27813, 1071, 2575, 100, 1576, 2620, 1575, 2683, 3231, 102],
+ decoded: "[CLS] test $ 1 r2 # 3 \u20ac4 \u00a35 \u00a56 [UNK] \u20b98 \u20b19 test [SEP]",
+ },
+ CURRENCY_WITH_DECIMALS: {
+ text: BASE_TEST_STRINGS.CURRENCY_WITH_DECIMALS,
+ tokens: ["i", "bought", "an", "apple", "for", "$", "1", ".", "00", "at", "the", "store", "."],
+ ids: [101, 1045, 4149, 2019, 6207, 2005, 1002, 1015, 1012, 4002, 2012, 1996, 3573, 1012, 102],
+ decoded: "[CLS] i bought an apple for $ 1. 00 at the store. [SEP]",
+ },
+ TILDE_NORMALIZATION: {
+ text: BASE_TEST_STRINGS.TILDE_NORMALIZATION,
+ tokens: ["weird", "\uff5e", "edge", "\uff5e", "case"],
+ ids: [101, 6881, 1995, 3341, 1995, 2553, 102],
+ decoded: "[CLS] weird \uff5e edge \uff5e case [SEP]",
+ },
+ },
+ "Xenova/distiluse-base-multilingual-cased-v2": {
+ NUMBERS: {
+ text: BASE_TEST_STRINGS.NUMBERS,
+ tokens: ["012", "##34", "##5", "##6", "##7", "##8", "##9", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "100", "1000"],
+ ids: [101, 69878, 78301, 11166, 11211, 11305, 11396, 11373, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 10150, 10407, 12186, 102],
+ decoded: "[CLS] 0123456789 0 1 2 3 4 5 6 7 8 9 10 100 1000 [SEP]",
+ },
+ JAVASCRIPT_CODE: {
+ text: BASE_TEST_STRINGS.JAVASCRIPT_CODE,
+ tokens: ["let", "a", "=", "ob", "##j", ".", "to", "##S", "##trin", "##g", "(", ")", ";", "to", "##S", "##trin", "##g", "(", ")", ";"],
+ ids: [101, 13595, 169, 134, 17339, 10418, 119, 10114, 10731, 109163, 10240, 113, 114, 132, 10114, 10731, 109163, 10240, 113, 114, 132, 102],
+ decoded: "[CLS] let a = obj. toString ( ) ; toString ( ) ; [SEP]",
+ },
+ BASIC: {
+ text: BASE_TEST_STRINGS.BASIC,
+ tokens: ["UN", "##want", "##\u00e9d", ",", "running"],
+ ids: [101, 26578, 104216, 84193, 117, 18020, 102],
+ decoded: "[CLS] UNwant\u00e9d, running [SEP]",
+ },
+ HELLO_WORLD_LOWERCASE: {
+ text: BASE_TEST_STRINGS.HELLO_WORLD_LOWERCASE,
+ tokens: ["hell", "##o", "world"],
+ ids: [101, 61694, 10133, 11356, 102],
+ decoded: "[CLS] hello world [SEP]",
+ },
+ CHINESE_ONLY: {
+ text: BASE_TEST_STRINGS.CHINESE_ONLY,
+ tokens: ["\u751f", "\u6d3b", "\u7684", "\u771f", "\u8c1b", "\u662f"],
+ ids: [101, 5600, 4978, 5718, 5769, 7378, 4380, 102],
+ decoded: "[CLS] \u751f \u6d3b \u7684 \u771f \u8c1b \u662f [SEP]",
+ },
+ TRAILING_SPACE: {
+ text: BASE_TEST_STRINGS.TRAILING_SPACE,
+ tokens: ["trail", "##ing", "space"],
+ ids: [101, 56559, 10230, 16199, 102],
+ decoded: "[CLS] trailing space [SEP]",
+ },
+ CURRENCY: {
+ text: BASE_TEST_STRINGS.CURRENCY,
+ tokens: ["test", "$", "1", "R2", "#", "3", "\u20ac", "##4", "\u00a3", "##5", "\u00a5", "##6", "[UNK]", "\u20b9", "##8", "[UNK]", "test"],
+ ids: [101, 15839, 109, 122, 94000, 108, 124, 1775, 11011, 201, 11166, 202, 11211, 100, 1776, 11396, 100, 15839, 102],
+ decoded: "[CLS] test $ 1 R2 # 3 \u20ac4 \u00a35 \u00a56 [UNK] \u20b98 [UNK] test [SEP]",
+ },
+ CURRENCY_WITH_DECIMALS: {
+ text: BASE_TEST_STRINGS.CURRENCY_WITH_DECIMALS,
+ tokens: ["I", "bought", "an", "app", "##le", "for", "$", "1", ".", "00", "at", "the", "store", "."],
+ ids: [101, 146, 28870, 10151, 72894, 10284, 10142, 109, 122, 119, 11025, 10160, 10105, 13708, 119, 102],
+ decoded: "[CLS] I bought an apple for $ 1. 00 at the store. [SEP]",
+ },
+ ELLIPSIS: {
+ text: BASE_TEST_STRINGS.ELLIPSIS,
+ tokens: ["you", "[UNK]"],
+ ids: [101, 13028, 100, 102],
+ decoded: "[CLS] you [UNK] [SEP]",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS,
+ tokens: ["you", "[UNK]"],
+ ids: [101, 13028, 100, 102],
+ decoded: "[CLS] you [UNK] [SEP]",
+ },
+ TEXT_WITH_ESCAPE_CHARACTERS_2: {
+ text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS_2,
+ tokens: ["you", "[UNK]", "you", "[UNK]"],
+ ids: [101, 13028, 100, 13028, 100, 102],
+ decoded: "[CLS] you [UNK] you [UNK] [SEP]",
+ },
+ TILDE_NORMALIZATION: {
+ text: BASE_TEST_STRINGS.TILDE_NORMALIZATION,
+ tokens: ["wei", "##rd", "\uff5e", "edge", "\uff5e", "case"],
+ ids: [101, 86981, 12023, 10096, 30599, 10096, 13474, 102],
+ decoded: "[CLS] weird \uff5e edge \uff5e case [SEP]",
+ },
+ },
+};
diff --git a/tests/models/esm/tokenization.js b/tests/models/esm/tokenization.js
new file mode 100644
index 000000000..c072d5251
--- /dev/null
+++ b/tests/models/esm/tokenization.js
@@ -0,0 +1,322 @@
+import { EsmTokenizer } from "../../../src/tokenizers.js";
+import { BASE_TEST_STRINGS, ESM_TEST_STRINGS } from "../test_strings.js";
+
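+// The natural-language test strings are out of vocabulary for this nucleotide tokenizer, so the
+// per-word ids collapse to the unknown token; the expected "tokens" are kept as comments only.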
+export const TOKENIZER_CLASS = EsmTokenizer;
+export const TEST_CONFIG = {
+ "Xenova/nucleotide-transformer-500m-human-ref": {
+ SIMPLE: {
+ text: BASE_TEST_STRINGS.SIMPLE,
+ // "tokens": ["How", "are", "you", "doing?"],
+ ids: [3, 0, 0, 0, 0],
+ decoded: " ",
+ },
+ SIMPLE_WITH_PUNCTUATION: {
+ text: BASE_TEST_STRINGS.SIMPLE_WITH_PUNCTUATION,
+ // "tokens": ["You", "should've", "done", "this"],
+ ids: [3, 0, 0, 0, 0],
+ decoded: " ",
+ },
+ NUMBERS: {
+ text: BASE_TEST_STRINGS.NUMBERS,
+ // "tokens": ["0123456789", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "100", "1000"],
+ ids: [3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ decoded: "