Releases: huggingface/transformers.js
3.1.0
🚀 Transformers.js v3.1 — any-to-any, text-to-image, image-to-text, pose estimation, time series forecasting, and more!
Table of contents:
- 🤖 New models: Janus, Qwen2-VL, JinaCLIP, LLaVA-OneVision, ViTPose, MGP-STR, PatchTST, PatchTSMixer.
- 🐛 Bug fixes
- 📝 Documentation improvements
- 🛠️ Other improvements
- 🤗 New contributors
🤖 New models: Janus, Qwen2-VL, JinaCLIP, LLaVA-OneVision, ViTPose, MGP-STR, PatchTST, PatchTSMixer.
Janus for Any-to-Any generation (e.g., image-to-text and text-to-image)
First of all, this release adds support for Janus, a novel autoregressive framework that unifies multimodal understanding and generation. The most popular model, deepseek-ai/Janus-1.3B, is tagged as an "any-to-any" model, and has specifically been trained for image-text-to-text and text-to-image generation, as demonstrated in the examples below.
Example: Image-Text-to-Text
import { AutoProcessor, MultiModalityCausalLM } from "@huggingface/transformers";
// Load processor and model
const model_id = "onnx-community/Janus-1.3B-ONNX";
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await MultiModalityCausalLM.from_pretrained(model_id);
// Prepare inputs
const conversation = [
{
role: "User",
content: "<image_placeholder>\nConvert the formula into latex code.",
images: ["https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/quadratic_formula.png"],
},
];
const inputs = await processor(conversation);
// Generate response
const outputs = await model.generate({
...inputs,
max_new_tokens: 150,
do_sample: false,
});
// Decode output
const new_tokens = outputs.slice(null, [inputs.input_ids.dims.at(-1), null]);
const decoded = processor.batch_decode(new_tokens, { skip_special_tokens: true });
console.log(decoded[0]);
Sample output:
Sure, here is the LaTeX code for the given formula:
```
x = \frac{-b \pm \sqrt{b^2 - 4a c}}{2a}
```
This code represents the mathematical expression for the variable \( x \).
Example: Text-to-Image
import { AutoProcessor, MultiModalityCausalLM } from "@huggingface/transformers";
// Load processor and model
const model_id = "onnx-community/Janus-1.3B-ONNX";
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await MultiModalityCausalLM.from_pretrained(model_id);
// Prepare inputs
const conversation = [
{
role: "User",
content: "A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting,immortal,fluffy, shiny mane,Petals,fairyism,unreal engine 5 and Octane Render,highly detailed, photorealistic, cinematic, natural colors.",
},
];
const inputs = await processor(conversation, { chat_template: "text_to_image" });
// Generate response
const num_image_tokens = processor.num_image_tokens;
const outputs = await model.generate_images({
...inputs,
min_new_tokens: num_image_tokens,
max_new_tokens: num_image_tokens,
do_sample: true,
});
// Save the generated image
await outputs[0].save("test.png");
Sample outputs:
Want to play around with the model? Check out our online WebGPU demo! 👇
Janus-WebGPU.mp4
Qwen2-VL for Image-Text-to-Text
Next, we added support for Qwen2-VL, the multimodal large language model series developed by the Qwen team at Alibaba Cloud. It introduces the Naive Dynamic Resolution mechanism, allowing the model to process images of varying resolutions, leading to more efficient and accurate visual representations.
Example: Image-Text-to-Text
import { AutoProcessor, Qwen2VLForConditionalGeneration, RawImage } from "@huggingface/transformers";
// Load processor and model
const model_id = "onnx-community/Qwen2-VL-2B-Instruct";
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await Qwen2VLForConditionalGeneration.from_pretrained(model_id);
// Prepare inputs
const url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg";
const image = await (await RawImage.read(url)).resize(448, 448);
const conversation = [
{
role: "user",
content: [
{ type: "image" },
{ type: "text", text: "Describe this image." },
],
},
];
const text = processor.apply_chat_template(conversation, { add_generation_prompt: true });
const inputs = await processor(text, image);
// Perform inference
const outputs = await model.generate({
...inputs,
max_new_tokens: 128,
});
// Decode output
const decoded = processor.batch_decode(
outputs.slice(null, [inputs.input_ids.dims.at(-1), null]),
{ skip_special_tokens: true },
);
console.log(decoded[0]);
// The image depicts a serene beach scene with a woman and a dog. The woman is sitting on the sand, wearing a plaid shirt, and appears to be engaged in a playful interaction with the dog. The dog, which is a large breed, is sitting on its hind legs and appears to be reaching out to the woman, possibly to give her a high-five or a paw. The background shows the ocean with gentle waves, and the sky is clear, suggesting it might be either sunrise or sunset. The overall atmosphere is calm and relaxed, capturing a moment of connection between the woman and the dog.
JinaCLIP for multimodal embeddings
JinaCLIP is a series of general-purpose multilingual multimodal embedding models for text & images, created by Jina AI.
Example: Compute text and/or image embeddings with jinaai/jina-clip-v2:
import { AutoModel, AutoProcessor, RawImage, matmul } from "@huggingface/transformers";
// Load processor and model
const model_id = "jinaai/jina-clip-v2";
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await AutoModel.from_pretrained(model_id, { dtype: "q4" /* e.g., "fp16", "q8", or "q4" */ });
// Prepare inputs
const urls = ["https://i.ibb.co/nQNGqL0/beach1.jpg", "https://i.ibb.co/r5w8hG8/beach2.jpg"];
const images = await Promise.all(urls.map(url => RawImage.read(url)));
const sentences = [
"غروب جميل على الشاطئ", // Arabic
"海滩上美丽的日落", // Chinese
"Un beau coucher de soleil sur la plage", // French
"Ein wunderschöner Sonnenuntergang am Strand", // German
"Ένα όμορφο ηλιοβασίλεμα πάνω από την παραλία", // Greek
"समुद्र तट पर एक खूबसूरत सूर्यास्त", // Hindi
"Un bellissimo tramonto sulla spiaggia", // Italian
"浜辺に沈む美しい夕日", // Japanese
"해변 위로 아름다운 일몰", // Korean
];
// Encode text and images
const inputs = await processor(sentences, images, { padding: true, truncation: true });
const { l2norm_text_embeddings, l2norm_image_embeddings } = await model(inputs);
// Encode query (text-only)
const query_prefix = "Represent the query for retrieving evidence documents: ";
const query_inputs = await processor(query_prefix + "beautiful sunset over the beach");
const { l2norm_text_embeddings: query_embeddings } = await model(query_inputs);
// Compute text-image similarity scores
const text_to_image_scores = await matmul(query_embeddings, l2norm_image_embeddings.transpose(1, 0));
console.log("text-image similarity scores", text_to_image_scores.tolist()[0]); // [0.29530206322669983, 0.3183615803718567]
// Compute image-image similarity scores
const image_to_image_score = await matmul(l2norm_image_embeddings[0], l2norm_image_embeddings[1]);
console.log("image-image similarity score", image_to_image_score.item()); // 0.9344457387924194
// Compute text-text similarity scores
const text_to_text_scores = await matmul(query_embeddings, l2norm_text_embeddings.transpose(1, 0));
console.log("text-text similarity scores", text_to_text_scores.tolist()[0]); // [0.5566609501838684, 0.7028406858444214, 0.582255482673645, 0.6648036241531372, 0.5462006330490112, 0.6791588068008423, 0.6192430257797241, 0.6258729100227356, 0.6453716158866882]
LLaVA-OneVision for Image-Text-to-Text
LLaVA-OneVision is a Vision-Language Model that can generate text conditioned on one or several images/videos. The model consists of a SigLIP vision encoder and a Qwen2 language backbone.
Example: Multi-round conversations w/ PKV caching
import { AutoProcessor, AutoTokenizer, LlavaOnevisionForConditionalGeneration, RawImage } from '@huggingface/transformers';
// Load tokenizer, processor and model
const model_id = 'llava-hf/llava-onevision-qwen2-0.5b-ov-hf';
...
3.0.2
What's new?
- Add support for MobileLLM in #1003

Example: Text generation with onnx-community/MobileLLM-125M.

import { pipeline } from "@huggingface/transformers";

// Create a text generation pipeline
const generator = await pipeline(
  "text-generation",
  "onnx-community/MobileLLM-125M",
  { dtype: "fp32" },
);

// Define the prompt
const text = "Q: What is the capital of France?\nA: Paris\nQ: What is the capital of England?\nA:";

// Generate a response
const output = await generator(text, { max_new_tokens: 30 });
console.log(output[0].generated_text);
Example output
Q: What is the capital of France?
A: Paris
Q: What is the capital of England?
A: London
Q: What is the capital of Scotland?
A: Edinburgh
Q: What is the capital of Wales?
A: Cardiff
- Add support for OLMo in #1011

Example: Text generation with onnx-community/AMD-OLMo-1B-SFT-DPO.

import { pipeline } from "@huggingface/transformers";

// Create a text generation pipeline
const generator = await pipeline(
  "text-generation",
  "onnx-community/AMD-OLMo-1B-SFT-DPO",
  { dtype: "q4" },
);

// Define the list of messages
const messages = [
  { role: "system", content: "You are a helpful assistant." },
  { role: "user", content: "Tell me a joke." },
];

// Generate a response
const output = await generator(messages, { max_new_tokens: 128 });
console.log(output[0].generated_text.at(-1).content);
Example output
Why don't scientists trust atoms? Because they make up everything!
- Fix CommonJS bundling in #1012. Thanks @jens-ghc for reporting!
- Remove duplicate gemma value from NO_PER_CHANNEL_REDUCE_RANGE_MODEL by @bekzod in #1005
🤗 New contributors
Full Changelog: 3.0.1...3.0.2
3.0.1
3.0.0
Transformers.js v3: WebGPU Support, New Models & Tasks, New Quantizations, Deno & Bun Compatibility, and More…
After more than a year of development, we're excited to announce the release of 🤗 Transformers.js v3!
You can get started by installing Transformers.js v3 from NPM using:
npm i @huggingface/transformers
Then, importing the library with
import { pipeline } from "@huggingface/transformers";
or, via a CDN
import { pipeline } from "https://cdn.jsdelivr.net/npm/@huggingface/[email protected]";
For more information, check out the documentation.
⚡ WebGPU support (up to 100x faster than WASM!)
WebGPU is a new web standard for accelerated graphics and compute. The API enables web developers to use the underlying system's GPU to carry out high-performance computations directly in the browser. WebGPU is the successor to WebGL and provides significantly better performance, because it allows for more direct interaction with modern GPUs. Lastly, it supports general-purpose GPU computations, which makes it just perfect for machine learning!
Warning
As of October 2024, global WebGPU support is around 70% (according to caniuse.com), meaning some users may not be able to use the API.
If the following demos do not work in your browser, you may need to enable WebGPU via a feature flag in your browser's settings.
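To quickly check whether the current browser exposes WebGPU at all, you can query the standard navigator.gpu API. This is a minimal sketch using the browser's built-in WebGPU API (it is not part of Transformers.js):
```
// Minimal WebGPU availability check (standard browser API, not part of Transformers.js)
if ("gpu" in navigator) {
  const adapter = await navigator.gpu.requestAdapter();
  console.log(adapter ? "WebGPU is available" : "No suitable GPU adapter found");
} else {
  console.log("WebGPU is not supported in this browser");
}
```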
Usage in Transformers.js v3
Thanks to our collaboration with ONNX Runtime Web, enabling WebGPU acceleration is as simple as setting device: 'webgpu' when loading a model. Let's see some examples!
Example: Compute text embeddings on WebGPU (demo)
import { pipeline } from "@huggingface/transformers";
// Create a feature-extraction pipeline
const extractor = await pipeline(
"feature-extraction",
"mixedbread-ai/mxbai-embed-xsmall-v1",
{ device: "webgpu" },
);
// Compute embeddings
const texts = ["Hello world!", "This is an example sentence."];
const embeddings = await extractor(texts, { pooling: "mean", normalize: true });
console.log(embeddings.tolist());
// [
// [-0.016986183822155, 0.03228696808218956, -0.0013630966423079371, ... ],
// [0.09050482511520386, 0.07207386940717697, 0.05762749910354614, ... ],
// ]
Example: Perform automatic speech recognition with OpenAI whisper on WebGPU (demo)
import { pipeline } from "@huggingface/transformers";
// Create automatic speech recognition pipeline
const transcriber = await pipeline(
"automatic-speech-recognition",
"onnx-community/whisper-tiny.en",
{ device: "webgpu" },
);
// Transcribe audio from a URL
const url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav";
const output = await transcriber(url);
console.log(output);
// { text: ' And so my fellow Americans ask not what your country can do for you, ask what you can do for your country.' }
Example: Perform image classification with MobileNetV4 on WebGPU (demo)
import { pipeline } from "@huggingface/transformers";
// Create image classification pipeline
const classifier = await pipeline(
"image-classification",
"onnx-community/mobilenetv4_conv_small.e2400_r224_in1k",
{ device: "webgpu" },
);
// Classify an image from a URL
const url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/tiger.jpg";
const output = await classifier(url);
console.log(output);
// [
// { label: 'tiger, Panthera tigris', score: 0.6149784922599792 },
// { label: 'tiger cat', score: 0.30281734466552734 },
// { label: 'tabby, tabby cat', score: 0.0019135422771796584 },
// { label: 'lynx, catamount', score: 0.0012161266058683395 },
// { label: 'Egyptian cat', score: 0.0011465961579233408 }
// ]
🔢 New quantization formats (dtypes)
Before Transformers.js v3, we used the quantized option to specify whether to use a quantized (q8) or full-precision (fp32) variant of the model by setting quantized to true or false, respectively. Now, we've added the ability to select from a much larger list with the dtype parameter.
The list of available quantizations depends on the model, but some common ones are: full-precision ("fp32"), half-precision ("fp16"), 8-bit ("q8", "int8", "uint8"), and 4-bit ("q4", "bnb4", "q4f16").
Basic usage
Example: Run Qwen2.5-0.5B-Instruct in 4-bit quantization (demo)
import { pipeline } from "@huggingface/transformers";
// Create a text generation pipeline
const generator = await pipeline(
"text-generation",
"onnx-community/Qwen2.5-0.5B-Instruct",
{ dtype: "q4", device: "webgpu" },
);
// Define the list of messages
const messages = [
{ role: "system", content: "You are a helpful assistant." },
{ role: "user", content: "Tell me a funny joke." },
];
// Generate a response
const output = await generator(messages, { max_new_tokens: 128 });
console.log(output[0].generated_text.at(-1).content);
Per-module dtypes
Some encoder-decoder models, like Whisper or Florence-2, are extremely sensitive to quantization settings, especially for the encoder. For this reason, we added the ability to select per-module dtypes, which can be done by providing a mapping from module name to dtype.
Example: Run Florence-2 on WebGPU (demo)
import { Florence2ForConditionalGeneration } from "@huggingface/transformers";
const model = await Florence2ForConditionalGeneration.from_pretrained(
"onnx-community/Florence-2-base-ft",
{
dtype: {
embed_tokens: "fp16",
vision_encoder: "fp16",
encoder_model: "q4",
decoder_model_merged: "q4",
},
device: "webgpu",
},
);
See full code example
import {
Florence2ForConditionalGeneration,
AutoProcessor,
AutoTokenizer,
RawImage,
} from "@huggingface/transformers";
// Load model, processor, and tokenizer
const model_id = "onnx-community/Florence-2-base-ft";
const model = await Florence2ForConditionalGeneration.from_pretrained(
model_id,
{
dtype: {
embed_tokens: "fp16",
vision_encoder: "fp16",
encoder_model: "q4",
decoder_model_merged: "q4",
},
device: "webgpu",
},
);
const processor = await AutoProcessor.from_pretrained(model_id);
const tokenizer = await AutoTokenizer.from_pretrained(model_id);
// Load image and prepare vision inputs
const url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg";
const image = await RawImage.fromURL(url);
const vision_inputs = await processor(image);
// Specify task and prepare text inputs
const task = "<MORE_DETAILED_CAPTION>";
const prompts = processor.construct_prompts(task);
const text_inputs = tokenizer(prompts);
// Generate text
const generated_ids = await model.generate({
...text_inputs,
...vision_inputs,
max_new_tokens: 100,
});
// Decode generated text
const generated_text = tokenizer.batch_decode(generated_ids, {
skip_special_tokens: false,
})[0];
// Post-process the generated text
const result = processor.post_process_generation(
generated_text,
task,
image.size,
);
console.log(result);
// { '<MORE_DETAILED_CAPTION>': 'A green car is parked in front of a tan building. The building has a brown door and two brown windows. The car is a two door and the door is closed. The green car has black tires.' }
🏛 A total of 120 supported architectures
This release increases the total number of supported architectures to 120 (see full list), spanning a wide range of input modalities and tasks. Notable ...
2.17.2
🚀 What's new?
- Add support for MobileViTv2 in #721

import { pipeline } from '@xenova/transformers';

// Create an image classification pipeline
const classifier = await pipeline('image-classification', 'Xenova/mobilevitv2-1.0-imagenet1k-256', {
  quantized: false,
});

// Classify an image
const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/tiger.jpg';
const output = await classifier(url);
// [{ label: 'tiger, Panthera tigris', score: 0.6491137742996216 }]
See here for the full list of supported models.
- Add support for FastViT in #749

import { pipeline } from '@xenova/transformers';

// Create an image classification pipeline
const classifier = await pipeline('image-classification', 'Xenova/fastvit_t12.apple_in1k', { quantized: false });

// Classify an image
const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/tiger.jpg';
const output = await classifier(url, { topk: 5 });
// [
//   { label: 'tiger, Panthera tigris', score: 0.6649345755577087 },
//   { label: 'tiger cat', score: 0.12454754114151001 },
//   { label: 'lynx, catamount', score: 0.0010689536575227976 },
//   { label: 'dhole, Cuon alpinus', score: 0.0010422508930787444 },
//   { label: 'silky terrier, Sydney silky', score: 0.0009548701345920563 }
// ]
See here for the full list of supported models.
- Optimize FFT in #766
- Add sequence post processor in #771
- Update pipelines.js to allow for token_embeddings as well by @NikhilVerma in #770
- Remove old import from stream/web for ReadableStream in #752
- docs: update vanilla-js.md by @eltociear in #738
- Fix CI in #768
- Update Next.js demos to 14.2.3 in #772
🤗 New contributors
- @eltociear made their first contribution in #738
- @KTibow made their first contribution in #737
- @NawarA made their first contribution in #594
- @NikhilVerma made their first contribution in #770
Full Changelog: 2.17.1...2.17.2
2.17.1
2.17.0
What's new?
💬 Improved text-generation pipeline for conversational models
This version adds support for passing an array of chat messages (with "role" and "content" properties) to the text-generation pipeline (PR). Check out the list of supported models here.
Example: Chat with Xenova/Qwen1.5-0.5B-Chat.
import { pipeline } from '@xenova/transformers';
// Create text-generation pipeline
const generator = await pipeline('text-generation', 'Xenova/Qwen1.5-0.5B-Chat');
// Define the list of messages
const messages = [
{ role: 'system', content: 'You are a helpful assistant.' },
{ role: 'user', content: 'Tell me a funny joke.' }
]
// Generate text
const output = await generator(messages, {
max_new_tokens: 128,
do_sample: false,
})
console.log(output[0].generated_text);
// [
// { role: 'system', content: 'You are a helpful assistant.' },
// { role: 'user', content: 'Tell me a funny joke.' },
// { role: 'assistant', content: "Sure, here's one:\n\nWhy was the math book sad?\n\nBecause it had too many problems.\n\nI hope you found that joke amusing! Do you have any other questions or topics you'd like to discuss?" },
// ]
We also added the return_full_text parameter, which means if you set return_full_text=false, only the newly-generated tokens will be returned (only applicable if passing the raw text prompt to the pipeline).
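For instance, here is a minimal sketch (reusing the chat model from the example above, with an illustrative raw text prompt) showing the effect of return_full_text: false:
```
import { pipeline } from '@xenova/transformers';

// Create text-generation pipeline
const generator = await pipeline('text-generation', 'Xenova/Qwen1.5-0.5B-Chat');

// Passing a raw text prompt with return_full_text: false returns only the continuation,
// without echoing the prompt back in generated_text
const output = await generator('The capital of France is', {
  max_new_tokens: 10,
  return_full_text: false,
});
console.log(output[0].generated_text);
```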
🔢 Binary embedding quantization support
Transformers.js v2.17 adds two new parameters to the feature-extraction pipeline ("quantize" and "precision"), enabling you to generate binary embeddings. These can be used with certain embedding models to shrink the size of the document embeddings for retrieval. This results in reductions in index size/memory usage (for storage) and improvements in retrieval speed. Surprisingly, you can still achieve up to ~95% of the original performance, but at 32x storage savings and up to 32x retrieval speeds! 🤯 Thanks to @jonathanpv for this addition in #691!
import { pipeline } from '@xenova/transformers';
// Create feature-extraction pipeline
const extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2');
// Compute binary embeddings
const output = await extractor('This is a simple test.', { pooling: 'mean', quantize: true, precision: 'binary' });
// Tensor {
// type: 'int8',
// data: Int8Array [49, 108, 24, ...],
// dims: [1, 48]
// }
As you can see, this produces a 32x smaller output tensor (a 4x reduction in data type with Float32Array → Int8Array, as well as an 8x reduction in dimensionality from 384 → 48). For more information, check out this PR in sentence-transformers, which inspired this update!
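As a rough sketch of how these compact embeddings might be compared (assuming cos_sim accepts the underlying Int8Array data directly; the sentences are just illustrative):
```
import { pipeline, cos_sim } from '@xenova/transformers';

// Create feature-extraction pipeline
const extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2');

// Compute binary embeddings for two sentences
const options = { pooling: 'mean', quantize: true, precision: 'binary' };
const a = await extractor('The weather is lovely today.', options);
const b = await extractor('It is sunny outside!', options);

// Compare the compact 48-dimensional int8 embeddings
console.log(cos_sim(a.data, b.data));
```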
🛠️ Misc. improvements
🤗 New contributors
- @pulsejet made their first contribution in #667
- @jonathanpv made their first contribution in #691
Full Changelog: 2.16.1...2.17.0
2.16.1
What's new?
- Add support for the image-feature-extraction pipeline in #650.

Example: Perform image feature extraction with Xenova/vit-base-patch16-224-in21k.

import { pipeline } from '@xenova/transformers';

const image_feature_extractor = await pipeline('image-feature-extraction', 'Xenova/vit-base-patch16-224-in21k');
const url = 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.png';
const features = await image_feature_extractor(url);
// Tensor {
//   dims: [ 1, 197, 768 ],
//   type: 'float32',
//   data: Float32Array(151296) [ ... ],
//   size: 151296
// }
Example: Compute image embeddings with Xenova/clip-vit-base-patch32.

import { pipeline } from '@xenova/transformers';

const image_feature_extractor = await pipeline('image-feature-extraction', 'Xenova/clip-vit-base-patch32');
const url = 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.png';
const features = await image_feature_extractor(url);
// Tensor {
//   dims: [ 1, 512 ],
//   type: 'float32',
//   data: Float32Array(512) [ ... ],
//   size: 512
// }
- Fix channel format when padding non-square images for certain models in #655. This means you can now perform super-resolution for non-square images with APISR models:

Example: Upscale an image with Xenova/4x_APISR_GRL_GAN_generator-onnx.

import { pipeline } from '@xenova/transformers';

// Create image-to-image pipeline
const upscaler = await pipeline('image-to-image', 'Xenova/4x_APISR_GRL_GAN_generator-onnx', {
  quantized: false,
});

// Upscale an image
const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/anime.png';
const output = await upscaler(url);
// RawImage {
//   data: Uint8Array(16588800) [ ... ],
//   width: 2560,
//   height: 1920,
//   channels: 3
// }

// (Optional) Save the upscaled image
output.save('upscaled.png');
- Update tokenizer apply_chat_template functionality in #647. This PR added functionality to support the new C4AI Command-R tokenizer.

See example tool usage:

import { AutoTokenizer } from "@xenova/transformers";

const tokenizer = await AutoTokenizer.from_pretrained("Xenova/c4ai-command-r-v01-tokenizer");

// Define conversation input:
const conversation = [
  { role: "user", content: "Whats the biggest penguin in the world?" }
];

// Define tools available for the model to use:
const tools = [
  {
    name: "internet_search",
    description: "Returns a list of relevant document snippets for a textual query retrieved from the internet",
    parameter_definitions: {
      query: {
        description: "Query to search the internet with",
        type: "str",
        required: true
      }
    }
  },
  {
    name: "directly_answer",
    description: "Calls a standard (un-augmented) AI chatbot to generate a response given the conversation history",
    parameter_definitions: {}
  }
];

// Render the tool use prompt as a string:
const tool_use_prompt = tokenizer.apply_chat_template(
  conversation,
  {
    chat_template: "tool_use",
    tokenize: false,
    add_generation_prompt: true,
    tools,
  }
);
console.log(tool_use_prompt);
See example RAG usage:

import { AutoTokenizer } from "@xenova/transformers";

const tokenizer = await AutoTokenizer.from_pretrained("Xenova/c4ai-command-r-v01-tokenizer");

// Define conversation input:
const conversation = [
  { role: "user", content: "Whats the biggest penguin in the world?" }
];

// Define documents to ground on:
const documents = [
  { title: "Tall penguins", text: "Emperor penguins are the tallest growing up to 122 cm in height." },
  { title: "Penguin habitats", text: "Emperor penguins only live in Antarctica." }
];

// Render the RAG prompt as a string:
const grounded_generation_prompt = tokenizer.apply_chat_template(
  conversation,
  {
    chat_template: "rag",
    tokenize: false,
    add_generation_prompt: true,
    documents,
    citation_mode: "accurate", // or "fast"
  }
);
console.log(grounded_generation_prompt);
- Add support for EfficientNet in #639.

Example: Classify images with chriamue/bird-species-classifier.

import { pipeline } from '@xenova/transformers';

// Create image classification pipeline
const classifier = await pipeline('image-classification', 'chriamue/bird-species-classifier', {
  quantized: false, // Quantized model doesn't work
  revision: 'refs/pr/1', // Needed until the model author merges the PR
});

// Classify an image
const url = 'https://upload.wikimedia.org/wikipedia/commons/7/73/Short_tailed_Albatross1.jpg';
const output = await classifier(url);
console.log(output);
// [{ label: 'ALBATROSS', score: 0.9999023079872131 }]
Full Changelog: 2.16.0...2.16.1
2.16.0
What's new?
💬 StableLM text-generation models
This version adds support for the StableLM family of text-generation models (up to 1.6B params), developed by Stability AI. Huge thanks to @D4ve-R for this contribution in #616! See here for the full list of supported models.
Example: Text generation with Xenova/stablelm-2-zephyr-1_6b.
import { pipeline } from '@xenova/transformers';
// Create text generation pipeline
const generator = await pipeline('text-generation', 'Xenova/stablelm-2-zephyr-1_6b');
// Define the prompt and list of messages
const prompt = "Tell me a funny joke."
const messages = [
{ "role": "system", "content": "You are a helpful assistant." },
{ "role": "user", "content": prompt },
]
// Apply chat template
const inputs = generator.tokenizer.apply_chat_template(messages, {
tokenize: false,
add_generation_prompt: true,
});
// Generate text
const output = await generator(inputs, { max_new_tokens: 20 });
console.log(output[0].generated_text);
// "<|system|>\nYou are a helpful assistant.\n<|user|>\nTell me a funny joke.\n<|assistant|>\nHere's a joke for you:\n\nWhy don't scientists trust atoms?\n\nBecause they make up everything!"
Note: these models may be too large to run in your browser at the moment, so for now, we recommend using them in Node.js. Stay tuned for updates on this!
🔉 Speaker verification and diarization models
Example: Speaker verification w/ Xenova/wavlm-base-plus-sv.
import { AutoProcessor, AutoModel, read_audio, cos_sim } from '@xenova/transformers';
// Load processor and model
const processor = await AutoProcessor.from_pretrained('Xenova/wavlm-base-plus-sv');
const model = await AutoModel.from_pretrained('Xenova/wavlm-base-plus-sv');
// Helper function to compute speaker embedding from audio URL
async function compute_embedding(url) {
const audio = await read_audio(url, 16000);
const inputs = await processor(audio);
const { embeddings } = await model(inputs);
return embeddings.data;
}
// Generate speaker embeddings
const BASE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/sv_speaker';
const speaker_1_1 = await compute_embedding(`${BASE_URL}-1_1.wav`);
const speaker_1_2 = await compute_embedding(`${BASE_URL}-1_2.wav`);
const speaker_2_1 = await compute_embedding(`${BASE_URL}-2_1.wav`);
const speaker_2_2 = await compute_embedding(`${BASE_URL}-2_2.wav`);
// Compute similarity scores
console.log(cos_sim(speaker_1_1, speaker_1_2)); // 0.959439158881247 (Both are speaker 1)
console.log(cos_sim(speaker_1_2, speaker_2_1)); // 0.618130172602329 (Different speakers)
console.log(cos_sim(speaker_2_1, speaker_2_2)); // 0.962999814169370 (Both are speaker 2)
Example: Perform speaker diarization with Xenova/wavlm-base-plus-sd.
import { AutoProcessor, AutoModelForAudioFrameClassification, read_audio } from '@xenova/transformers';
// Read and preprocess audio
const processor = await AutoProcessor.from_pretrained('Xenova/wavlm-base-plus-sd');
const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
const audio = await read_audio(url, 16000);
const inputs = await processor(audio);
// Run model with inputs
const model = await AutoModelForAudioFrameClassification.from_pretrained('Xenova/wavlm-base-plus-sd');
const { logits } = await model(inputs);
// {
// logits: Tensor {
// dims: [ 1, 549, 2 ], // [batch_size, num_frames, num_speakers]
// type: 'float32',
// data: Float32Array(1098) [-3.5301010608673096, ...],
// size: 1098
// }
// }
const labels = logits[0].sigmoid().tolist().map(
frames => frames.map(speaker => speaker > 0.5 ? 1 : 0)
);
console.log(labels); // labels is a one-hot array of shape (num_frames, num_speakers)
// [
// [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0],
// [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0],
// [0, 0], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1],
// ...
// ]
These additions were made possible thanks to the following PRs:
- Add support for WavLMForXVector by @D4ve-R in #603
- Add support for WavLMForAudioFrameClassification and Wav2Vec2ForAudioFrameClassification by @D4ve-R in #611
- Add support for UniSpeech and UniSpeechSat models in #624
📝 Improved chat templating operation coverage
With this release, we're pleased to announce that Transformers.js is now able to parse every single valid chat template that is currently on the Hugging Face Hub! 🤯 As of 2024/03/05, this is around ~12k conversational models (of which there were ~250 unique templates). Of course, future models may introduce more complex chat templates, and we'll continue to add support for them!
For example, Transformers.js can now generate the prompt for highly complex function-calling models (e.g., fireworks-ai/firefunction-v1):
See code
import { AutoTokenizer } from '@xenova/transformers';
const tokenizer = await AutoTokenizer.from_pretrained('fireworks-ai/firefunction-v1')
const function_spec = [
{
name: 'get_stock_price',
description: 'Get the current stock price',
parameters: {
type: 'object',
properties: {
symbol: {
type: 'string',
description: 'The stock symbol, e.g. AAPL, GOOG'
}
},
required: ['symbol']
}
},
{
name: 'check_word_anagram',
description: 'Check if two words are anagrams of each other',
parameters: {
type: 'object',
properties: {
word1: {
type: 'string',
description: 'The first word'
},
word2: {
type: 'string',
description: 'The second word'
}
},
required: ['word1', 'word2']
}
}
]
const messages = [
{ role: 'functions', content: JSON.stringify(function_spec, null, 4) },
{ role: 'system', content: 'You are a helpful assistant with access to functions. Use them if required.' },
{ role: 'user', content: 'Hi, can you tell me the current stock price of AAPL?' }
]
const inputs = tokenizer.apply_chat_template(messages, { tokenize: false });
console.log(inputs);
// <s>SYSTEM: You are a helpful assistant ...
🎨 New example applications and demos
- Create video object detection demo in #607 (try it out).
- Create cross-encoder demo in #617 (try it out).
- Add Claude 3 and Mistral to the tokenizer playground in #625 (try it out).
🛠️ Misc. improvements
- Add support for the starcoder2 architecture in #622. Note: we haven't yet added transformers.js-compatible versions of the 3B and 7B models.
- Check for existence of onnx_env.wasm before updating wasmPaths in #621
🤗 New contributors
Full Changelog: 2.15.1...2.16.0
2.15.1
What's new?
- Add Background Removal demo in #576 (online demo).
- Add support for owlv2 models in #579

Example: Zero-shot object detection w/ Xenova/owlv2-base-patch16-ensemble.

import { pipeline } from '@xenova/transformers';

const detector = await pipeline('zero-shot-object-detection', 'Xenova/owlv2-base-patch16-ensemble');

const url = 'http://images.cocodataset.org/val2017/000000039769.jpg';
const candidate_labels = ['a photo of a cat', 'a photo of a dog'];
const output = await detector(url, candidate_labels);
console.log(output);
// [
//   { score: 0.7400985360145569, label: 'a photo of a cat', box: { xmin: 0, ymin: 50, xmax: 323, ymax: 485 } },
//   { score: 0.6315087080001831, label: 'a photo of a cat', box: { xmin: 333, ymin: 23, xmax: 658, ymax: 378 } }
// ]
- Add support for Adaptive Retrieval w/ Matryoshka Embeddings (nomic-ai/nomic-embed-text-v1.5) in #587 and #588 (online demo).
Full Changelog: 2.15.0...2.15.1