Releases: huggingface/transformers.js

3.1.0

26 Nov 23:19
2c92943

🚀 Transformers.js v3.1 — any-to-any, text-to-image, image-to-text, pose estimation, time series forecasting, and more!

Table of contents:

🤖 New models: Janus, Qwen2-VL, JinaCLIP, LLaVA-OneVision, ViTPose, MGP-STR, PatchTST, PatchTSMixer.

Janus for Any-to-Any generation (e.g., image-to-text and text-to-image)

First of all, this release adds support for Janus, a novel autoregressive framework that unifies multimodal understanding and generation. The most popular model, deepseek-ai/Janus-1.3B, is tagged as an "any-to-any" model, and has specifically been trained for the following tasks:

Example: Image-Text-to-Text

import { AutoProcessor, MultiModalityCausalLM } from "@huggingface/transformers";

// Load processor and model
const model_id = "onnx-community/Janus-1.3B-ONNX";
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await MultiModalityCausalLM.from_pretrained(model_id);

// Prepare inputs
const conversation = [
  {
    role: "User",
    content: "<image_placeholder>\nConvert the formula into latex code.",
    images: ["https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/quadratic_formula.png"],
  },
];
const inputs = await processor(conversation);

// Generate response
const outputs = await model.generate({
  ...inputs,
  max_new_tokens: 150,
  do_sample: false,
});

// Decode output
const new_tokens = outputs.slice(null, [inputs.input_ids.dims.at(-1), null]);
const decoded = processor.batch_decode(new_tokens, { skip_special_tokens: true });
console.log(decoded[0]);

Sample output:

Sure, here is the LaTeX code for the given formula:

```
x = \frac{-b \pm \sqrt{b^2 - 4a c}}{2a}
```

This code represents the mathematical expression for the variable \( x \).

Example: Text-to-Image

import { AutoProcessor, MultiModalityCausalLM } from "@huggingface/transformers";

// Load processor and model
const model_id = "onnx-community/Janus-1.3B-ONNX";
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await MultiModalityCausalLM.from_pretrained(model_id);

// Prepare inputs
const conversation = [
  {
    role: "User",
    content: "A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting,immortal,fluffy, shiny mane,Petals,fairyism,unreal engine 5 and Octane Render,highly detailed, photorealistic, cinematic, natural colors.",
  },
];
const inputs = await processor(conversation, { chat_template: "text_to_image" });

// Generate response
const num_image_tokens = processor.num_image_tokens;
const outputs = await model.generate_images({
  ...inputs,
  min_new_tokens: num_image_tokens,
  max_new_tokens: num_image_tokens,
  do_sample: true,
});

// Save the generated image
await outputs[0].save("test.png");

Sample outputs:

fox_1 fox_2 fox_3 fox_4
fox_5 fox_6 fox_7 fox_8

Want to play around with the model? Check out our online WebGPU demo! 👇

Janus-WebGPU.mp4

Qwen2-VL for Image-Text-to-Text

Next, we added support for Qwen2-VL, the multimodal large language model series developed by the Qwen team at Alibaba Cloud. It introduces the Naive Dynamic Resolution mechanism, which lets the model process images of varying resolutions and leads to more efficient and accurate visual representations.

Example: Image-Text-to-Text

import { AutoProcessor, Qwen2VLForConditionalGeneration, RawImage } from "@huggingface/transformers";

// Load processor and model
const model_id = "onnx-community/Qwen2-VL-2B-Instruct";
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await Qwen2VLForConditionalGeneration.from_pretrained(model_id);

// Prepare inputs
const url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg";
const image = await (await RawImage.read(url)).resize(448, 448);
const conversation = [
  {
    role: "user",
    content: [
      { type: "image" },
      { type: "text", text: "Describe this image." },
    ],
  },
];
const text = processor.apply_chat_template(conversation, { add_generation_prompt: true });
const inputs = await processor(text, image);

// Perform inference
const outputs = await model.generate({
  ...inputs,
  max_new_tokens: 128,
});

// Decode output
const decoded = processor.batch_decode(
  outputs.slice(null, [inputs.input_ids.dims.at(-1), null]),
  { skip_special_tokens: true },
);
console.log(decoded[0]);
// The image depicts a serene beach scene with a woman and a dog. The woman is sitting on the sand, wearing a plaid shirt, and appears to be engaged in a playful interaction with the dog. The dog, which is a large breed, is sitting on its hind legs and appears to be reaching out to the woman, possibly to give her a high-five or a paw. The background shows the ocean with gentle waves, and the sky is clear, suggesting it might be either sunrise or sunset. The overall atmosphere is calm and relaxed, capturing a moment of connection between the woman and the dog.
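
Both this example and the Janus one slice the generated tensor before decoding so that only the columns after the prompt are kept. As a rough illustration of that idea on plain nested arrays (not the library's Tensor type), here is a minimal sketch; `dropPromptTokens` is a hypothetical helper and the token IDs are made up:

```javascript
// Hypothetical helper: keep only the tokens that follow the prompt.
// `sequences` is a batch of token-ID rows; `promptLength` is the number of
// prompt tokens at the start of each row.
function dropPromptTokens(sequences, promptLength) {
  return sequences.map((row) => row.slice(promptLength));
}

// Made-up token IDs: the first 3 entries of each row stand in for the prompt.
const sequences = [
  [101, 102, 103, 7, 8, 9],
  [101, 102, 103, 4, 5, 6],
];
const newTokens = dropPromptTokens(sequences, 3);
console.log(newTokens); // [ [ 7, 8, 9 ], [ 4, 5, 6 ] ]
```

In the examples above, `inputs.input_ids.dims.at(-1)` plays the role of `promptLength`.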

JinaCLIP for multimodal embeddings

JinaCLIP is a series of general-purpose multilingual multimodal embedding models for text & images, created by Jina AI.

Example: Compute text and/or image embeddings with jinaai/jina-clip-v2:

import { AutoModel, AutoProcessor, RawImage, matmul } from "@huggingface/transformers";

// Load processor and model
const model_id = "jinaai/jina-clip-v2";
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await AutoModel.from_pretrained(model_id, { dtype: "q4" /* e.g., "fp16", "q8", or "q4" */ });

// Prepare inputs
const urls = ["https://i.ibb.co/nQNGqL0/beach1.jpg", "https://i.ibb.co/r5w8hG8/beach2.jpg"];
const images = await Promise.all(urls.map(url => RawImage.read(url)));
const sentences = [
    "غروب جميل على الشاطئ", // Arabic
    "海滩上美丽的日落", // Chinese
    "Un beau coucher de soleil sur la plage", // French
    "Ein wunderschöner Sonnenuntergang am Strand", // German
    "Ένα όμορφο ηλιοβασίλεμα πάνω από την παραλία", // Greek
    "समुद्र तट पर एक खूबसूरत सूर्यास्त", // Hindi
    "Un bellissimo tramonto sulla spiaggia", // Italian
    "浜辺に沈む美しい夕日", // Japanese
    "해변 위로 아름다운 일몰", // Korean
];

// Encode text and images
const inputs = await processor(sentences, images, { padding: true, truncation: true });
const { l2norm_text_embeddings, l2norm_image_embeddings } = await model(inputs);

// Encode query (text-only)
const query_prefix = "Represent the query for retrieving evidence documents: ";
const query_inputs = await processor(query_prefix + "beautiful sunset over the beach");
const { l2norm_text_embeddings: query_embeddings } = await model(query_inputs);

// Compute text-image similarity scores
const text_to_image_scores = await matmul(query_embeddings, l2norm_image_embeddings.transpose(1, 0));
console.log("text-image similarity scores", text_to_image_scores.tolist()[0]); // [0.29530206322669983, 0.3183615803718567]

// Compute image-image similarity scores
const image_to_image_score = await matmul(l2norm_image_embeddings[0], l2norm_image_embeddings[1]);
console.log("image-image similarity score", image_to_image_score.item()); // 0.9344457387924194

// Compute text-text similarity scores
const text_to_text_scores = await matmul(query_embeddings, l2norm_text_embeddings.transpose(1, 0));
console.log("text-text similarity scores", text_to_text_scores.tolist()[0]); // [0.5566609501838684, 0.7028406858444214, 0.582255482673645, 0.6648036241531372, 0.5462006330490112, 0.6791588068008423, 0.6192430257797241, 0.6258729100227356, 0.6453716158866882]
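
The scores above come from a plain matrix product because the embeddings are already L2-normalized: for unit-length vectors, the dot product equals the cosine similarity. A minimal sketch of that identity on plain arrays (the vectors are made up, and the helper names are illustrative rather than part of the library):

```javascript
// Plain-array sketch; the vectors below are made up, not real model outputs.
function dot(a, b) {
  return a.reduce((sum, x, i) => sum + x * b[i], 0);
}

function l2normalize(v) {
  const norm = Math.sqrt(dot(v, v));
  return v.map((x) => x / norm);
}

function cosineSimilarity(a, b) {
  return dot(a, b) / (Math.sqrt(dot(a, a)) * Math.sqrt(dot(b, b)));
}

const a = [3, 4, 0];
const b = [4, 3, 0];

// Dot product of the normalized vectors...
const viaDot = dot(l2normalize(a), l2normalize(b));
// ...matches the cosine similarity of the raw vectors (up to floating-point error).
const viaCosine = cosineSimilarity(a, b);
console.log(viaDot.toFixed(4), viaCosine.toFixed(4)); // 0.9600 0.9600
```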

LLaVA-OneVision for Image-Text-to-Text

LLaVA-OneVision is a Vision-Language Model that can generate text conditioned on one or several images/videos. The model consists of a SigLIP vision encoder and a Qwen2 language backbone.

Example: Multi-round conversations w/ PKV caching

import { AutoProcessor, AutoTokenizer, LlavaOnevisionForConditionalGeneration, RawImage } from '@huggingface/transformers';

// Load tokenizer, processor and model
const model_id = 'llava-hf/llava-onevision-qwen2-0.5b-ov-hf';

...
Read more

3.0.2

04 Nov 07:59

What's new?

  • Add support for MobileLLM in #1003

    Example: Text generation with onnx-community/MobileLLM-125M.

    import { pipeline } from "@huggingface/transformers";
    
    // Create a text generation pipeline
    const generator = await pipeline(
      "text-generation",
      "onnx-community/MobileLLM-125M",
      { dtype: "fp32" },
    );
    
    // Define the prompt
    const text = "Q: What is the capital of France?\nA: Paris\nQ: What is the capital of England?\nA:";
    
    // Generate a response
    const output = await generator(text, { max_new_tokens: 30 });
    console.log(output[0].generated_text);
    Example output
    Q: What is the capital of France?
    A: Paris
    Q: What is the capital of England?
    A: London
    Q: What is the capital of Scotland?
    A: Edinburgh
    Q: What is the capital of Wales?
    A: Cardiff
    
  • Add support for OLMo in #1011

    Example: Text generation with onnx-community/AMD-OLMo-1B-SFT-DPO.

    import { pipeline } from "@huggingface/transformers";
    
    // Create a text generation pipeline
    const generator = await pipeline(
      "text-generation",
      "onnx-community/AMD-OLMo-1B-SFT-DPO",
      { dtype: "q4" },
    );
    
    // Define the list of messages
    const messages = [
      { role: "system", content: "You are a helpful assistant." },
      { role: "user", content: "Tell me a joke." },
    ];
    
    // Generate a response
    const output = await generator(messages, { max_new_tokens: 128 });
    console.log(output[0].generated_text.at(-1).content);
    Example output
    Why don't scientists trust atoms?
    
    Because they make up everything!
    
  • Fix CommonJS bundling in #1012. Thanks @jens-ghc for reporting!

  • Doc fixes by @roschler in #1002

  • Remove duplicate gemma value from NO_PER_CHANNEL_REDUCE_RANGE_MODEL by @bekzod in #1005

🤗 New contributors

Full Changelog: 3.0.1...3.0.2

3.0.1

25 Oct 17:27
e129c47

What's new?

  • Fix Document QA pipeline in #987. Thanks @martinsomm for reporting!
  • Next.js 15 (code; demo) and SvelteKit 5 (code; demo) server-side templates
  • Minor documentation fixes

Full Changelog: 3.0.0...3.0.1

3.0.0

22 Oct 14:52
e8c0f77

Transformers.js v3: WebGPU Support, New Models & Tasks, New Quantizations, Deno & Bun Compatibility, and More…

After more than a year of development, we're excited to announce the release of 🤗 Transformers.js v3!

You can get started by installing Transformers.js v3 from NPM using:

npm i @huggingface/transformers

Then, importing the library with

import { pipeline } from "@huggingface/transformers";

or, via a CDN

import { pipeline } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.0";

For more information, check out the documentation.

⚡ WebGPU support (up to 100x faster than WASM!)

WebGPU is a new web standard for accelerated graphics and compute. The API enables web developers to use the underlying system's GPU to carry out high-performance computations directly in the browser. WebGPU is the successor to WebGL and provides significantly better performance, because it allows for more direct interaction with modern GPUs. Lastly, it supports general-purpose GPU computations, which makes it just perfect for machine learning!

Warning

As of October 2024, global WebGPU support is around 70% (according to caniuse.com), meaning some users may not be able to use the API.

If the following demos do not work in your browser, you may need to enable WebGPU using a feature flag:

  • Firefox: with the dom.webgpu.enabled flag (see here).
  • Safari: with the WebGPU feature flag (see here).
  • Older Chromium browsers (on Windows, macOS, Linux): with the enable-unsafe-webgpu flag (see here).
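
Since not every browser exposes WebGPU, an application may want to check for it before asking for the `webgpu` device. The following is a minimal sketch under stated assumptions: `pickDevice` is a hypothetical helper (not part of Transformers.js), and it assumes `"wasm"` is an acceptable fallback value for the `device` option.

```javascript
// Hypothetical helper: choose a value for the `device` option.
// `nav` defaults to the global navigator, which may be undefined outside browsers.
function pickDevice(nav = globalThis.navigator) {
  // `navigator.gpu` is only defined where the WebGPU API is exposed.
  // Its presence does not guarantee that a usable GPU adapter exists.
  const hasWebGPU = Boolean(nav && nav.gpu);
  return hasWebGPU ? "webgpu" : "wasm"; // assumption: fall back to WASM
}

console.log(pickDevice({ gpu: {} })); // webgpu
console.log(pickDevice({}));          // wasm
```

The result could then be passed as `{ device: pickDevice() }` when loading a model. A stricter check would also request an adapter through the WebGPU API before committing to `webgpu`.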

Usage in Transformers.js v3

Thanks to our collaboration with ONNX Runtime Web, enabling WebGPU acceleration is as simple as setting device: 'webgpu' when loading a model. Let's see some examples!

Example: Compute text embeddings on WebGPU (demo)

import { pipeline } from "@huggingface/transformers";

// Create a feature-extraction pipeline
const extractor = await pipeline(
  "feature-extraction",
  "mixedbread-ai/mxbai-embed-xsmall-v1",
  { device: "webgpu" },
);

// Compute embeddings
const texts = ["Hello world!", "This is an example sentence."];
const embeddings = await extractor(texts, { pooling: "mean", normalize: true });
console.log(embeddings.tolist());
// [
//   [-0.016986183822155, 0.03228696808218956, -0.0013630966423079371, ... ],
//   [0.09050482511520386, 0.07207386940717697, 0.05762749910354614, ... ],
// ]

Example: Perform automatic speech recognition with OpenAI Whisper on WebGPU (demo)

import { pipeline } from "@huggingface/transformers";

// Create automatic speech recognition pipeline
const transcriber = await pipeline(
  "automatic-speech-recognition",
  "onnx-community/whisper-tiny.en",
  { device: "webgpu" },
);

// Transcribe audio from a URL
const url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav";
const output = await transcriber(url);
console.log(output);
// { text: ' And so my fellow Americans ask not what your country can do for you, ask what you can do for your country.' }

Example: Perform image classification with MobileNetV4 on WebGPU (demo)

import { pipeline } from "@huggingface/transformers";

// Create image classification pipeline
const classifier = await pipeline(
  "image-classification",
  "onnx-community/mobilenetv4_conv_small.e2400_r224_in1k",
  { device: "webgpu" },
);

// Classify an image from a URL
const url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/tiger.jpg";
const output = await classifier(url);
console.log(output);
// [
//   { label: 'tiger, Panthera tigris', score: 0.6149784922599792 },
//   { label: 'tiger cat', score: 0.30281734466552734 },
//   { label: 'tabby, tabby cat', score: 0.0019135422771796584 },
//   { label: 'lynx, catamount', score: 0.0012161266058683395 },
//   { label: 'Egyptian cat', score: 0.0011465961579233408 }
// ]

🔢 New quantization formats (dtypes)

Before Transformers.js v3, we used the quantized option to specify whether to use a quantized (q8) or full-precision (fp32) variant of the model by setting quantized to true or false, respectively. Now, we've added the ability to select from a much larger list with the dtype parameter.

The list of available quantizations depends on the model, but some common ones are: full-precision ("fp32"), half-precision ("fp16"), 8-bit ("q8", "int8", "uint8"), and 4-bit ("q4", "bnb4", "q4f16").
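
For code written against the old `quantized` option, the mapping described above (true meant q8, false meant fp32) can be expressed as a small migration helper. This is only a sketch: `legacyOptionsToV3` is a hypothetical function, not something the library provides.

```javascript
// Hypothetical migration helper based on the mapping described above:
// quantized: true -> "q8", quantized: false -> "fp32".
function legacyOptionsToV3(options = {}) {
  const { quantized, ...rest } = options;
  if (quantized === undefined) return rest; // nothing to migrate
  return { ...rest, dtype: quantized ? "q8" : "fp32" };
}

console.log(legacyOptionsToV3({ quantized: false }));
// { dtype: 'fp32' }
console.log(legacyOptionsToV3({ quantized: true, revision: "main" }));
// { revision: 'main', dtype: 'q8' }
```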

Available dtypes for mixedbread-ai/mxbai-embed-xsmall-v1

Basic usage

Example: Run Qwen2.5-0.5B-Instruct in 4-bit quantization (demo)

import { pipeline } from "@huggingface/transformers";

// Create a text generation pipeline
const generator = await pipeline(
  "text-generation",
  "onnx-community/Qwen2.5-0.5B-Instruct",
  { dtype: "q4", device: "webgpu" },
);

// Define the list of messages
const messages = [
  { role: "system", content: "You are a helpful assistant." },
  { role: "user", content: "Tell me a funny joke." },
];

// Generate a response
const output = await generator(messages, { max_new_tokens: 128 });
console.log(output[0].generated_text.at(-1).content);

Per-module dtypes

Some encoder-decoder models, like Whisper or Florence-2, are extremely sensitive to quantization settings, especially those of the encoder. For this reason, we added the ability to select per-module dtypes, which can be done by providing a mapping from module name to dtype.

Example: Run Florence-2 on WebGPU (demo)

import { Florence2ForConditionalGeneration } from "@huggingface/transformers";

const model = await Florence2ForConditionalGeneration.from_pretrained(
  "onnx-community/Florence-2-base-ft",
  {
    dtype: {
      embed_tokens: "fp16",
      vision_encoder: "fp16",
      encoder_model: "q4",
      decoder_model_merged: "q4",
    },
    device: "webgpu",
  },
);

Florence-2 running on WebGPU

See full code example
import {
  Florence2ForConditionalGeneration,
  AutoProcessor,
  AutoTokenizer,
  RawImage,
} from "@huggingface/transformers";

// Load model, processor, and tokenizer
const model_id = "onnx-community/Florence-2-base-ft";
const model = await Florence2ForConditionalGeneration.from_pretrained(
  model_id,
  {
    dtype: {
      embed_tokens: "fp16",
      vision_encoder: "fp16",
      encoder_model: "q4",
      decoder_model_merged: "q4",
    },
    device: "webgpu",
  },
);
const processor = await AutoProcessor.from_pretrained(model_id);
const tokenizer = await AutoTokenizer.from_pretrained(model_id);

// Load image and prepare vision inputs
const url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg";
const image = await RawImage.fromURL(url);
const vision_inputs = await processor(image);

// Specify task and prepare text inputs
const task = "<MORE_DETAILED_CAPTION>";
const prompts = processor.construct_prompts(task);
const text_inputs = tokenizer(prompts);

// Generate text
const generated_ids = await model.generate({
  ...text_inputs,
  ...vision_inputs,
  max_new_tokens: 100,
});

// Decode generated text
const generated_text = tokenizer.batch_decode(generated_ids, {
  skip_special_tokens: false,
})[0];

// Post-process the generated text
const result = processor.post_process_generation(
  generated_text,
  task,
  image.size,
);
console.log(result);
// { '<MORE_DETAILED_CAPTION>': 'A green car is parked in front of a tan building. The building has a brown door and two brown windows. The car is a two door and the door is closed. The green car has black tires.' }

🏛 A total of 120 supported architectures

This release increases the total number of supported architectures to 120 (see full list), spanning a wide range of input modalities and tasks. Notable ...

Read more

2.17.2

29 May 14:36

🚀 What's new?

  • Add support for MobileViTv2 in #721

    import { pipeline } from '@xenova/transformers';
    
    // Create an image classification pipeline
    const classifier = await pipeline('image-classification', 'Xenova/mobilevitv2-1.0-imagenet1k-256', {
        quantized: false,
    });
    
    // Classify an image
    const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/tiger.jpg';
    const output = await classifier(url);
    // [{ label: 'tiger, Panthera tigris', score: 0.6491137742996216 }]

    See here for the full list of supported models.

  • Add support for FastViT in #749

    import { pipeline } from '@xenova/transformers';
    
    // Create an image classification pipeline
    const classifier = await pipeline('image-classification', 'Xenova/fastvit_t12.apple_in1k', {
      quantized: false
    });
    
    // Classify an image
    const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/tiger.jpg';
    const output = await classifier(url, { topk: 5 });
    // [
    //   { label: 'tiger, Panthera tigris', score: 0.6649345755577087 },
    //   { label: 'tiger cat', score: 0.12454754114151001 },
    //   { label: 'lynx, catamount', score: 0.0010689536575227976 },
    //   { label: 'dhole, Cuon alpinus', score: 0.0010422508930787444 },
    //   { label: 'silky terrier, Sydney silky', score: 0.0009548701345920563 }
    // ]

    See here for the full list of supported models.

  • Optimize FFT in #766

  • Auto rotate image by @KTibow in #737

  • Support reading data from blob URI by @hans00 in #645

  • Add sequence post processor in #771

  • Add model file name by @NawarA in #594

  • Update pipelines.js to allow for token_embeddings as well by @NikhilVerma in #770

  • Remove old import from stream/web for ReadableStream in #752

  • Update tokenizer playground by @xenova in #717

  • Use ungated version of mistral tokenizer by @xenova in #718

  • docs: update vanilla-js.md by @eltociear in #738

  • Fix CI in #768

  • Update Next.js demos to 14.2.3 in #772

🤗 New contributors

Full Changelog: 2.17.1...2.17.2

2.17.1

18 Apr 15:46

What's new?

  • Add ignore_merges option to BPE tokenizers in #716

Full Changelog: 2.17.0...2.17.1

2.17.0

11 Apr 00:08

What's new?

💬 Improved text-generation pipeline for conversational models

This version adds support for passing an array of chat messages (with "role" and "content" properties) to the text-generation pipeline (PR). Check out the list of supported models here.

Example: Chat with Xenova/Qwen1.5-0.5B-Chat.

import { pipeline } from '@xenova/transformers';

// Create text-generation pipeline
const generator = await pipeline('text-generation', 'Xenova/Qwen1.5-0.5B-Chat');

// Define the list of messages
const messages = [
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'user', content: 'Tell me a funny joke.' }
]

// Generate text
const output = await generator(messages, {
    max_new_tokens: 128,
    do_sample: false,
})
console.log(output[0].generated_text);
// [
//   { role: 'system', content: 'You are a helpful assistant.' },
//   { role: 'user', content: 'Tell me a funny joke.' },
//   { role: 'assistant', content: "Sure, here's one:\n\nWhy was the math book sad?\n\nBecause it had too many problems.\n\nI hope you found that joke amusing! Do you have any other questions or topics you'd like to discuss?" },
// ]

We also added the return_full_text parameter: if you set return_full_text to false, only the newly generated tokens are returned. This applies only when passing a raw text prompt to the pipeline.

🔢 Binary embedding quantization support

Transformers.js v2.17 adds two new parameters to the feature-extraction pipeline ("quantize" and "precision"), enabling you to generate binary embeddings. These can be used with certain embedding models to shrink the size of the document embeddings for retrieval, which reduces index size and memory usage (for storage) and improves retrieval speed. Surprisingly, you can still achieve up to ~95% of the original performance, with 32x storage savings and up to 32x faster retrieval! 🤯 Thanks to @jonathanpv for this addition in #691!

import { pipeline } from '@xenova/transformers';

// Create feature-extraction pipeline
const extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2');

// Compute binary embeddings
const output = await extractor('This is a simple test.', { pooling: 'mean', quantize: true, precision: 'binary' });
// Tensor {
//   type: 'int8',
//   data: Int8Array [49, 108, 24, ...],
//   dims: [1, 48]
// }

As you can see, this produces a 32x smaller output tensor (a 4x reduction in data type with Float32Array → Int8Array, as well as an 8x reduction in dimensionality from 384 → 48). For more information, check out this PR in sentence-transformers, which inspired this update!
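
To make the 32x figure concrete, here is a minimal sketch of the general technique: keep one sign bit per dimension and pack eight of them into each byte. It is not the library's implementation. The bit order and the use of unsigned bytes are assumptions (the pipeline output above is an Int8Array, so its byte values will differ), and the helper names are hypothetical. Packed embeddings of this kind are typically compared with a Hamming distance, which is also sketched.

```javascript
// Sketch only: pack the sign bits of a float embedding into bytes (8 dimensions per byte).
// Assumption: a value > 0 maps to bit 1, and bits fill each byte from most to least significant.
function packSignBits(embedding) {
  const packed = new Uint8Array(Math.ceil(embedding.length / 8));
  embedding.forEach((value, i) => {
    if (value > 0) packed[i >> 3] |= 1 << (7 - (i % 8));
  });
  return packed;
}

// Hamming distance between two packed embeddings: count the differing bits.
function hammingDistance(a, b) {
  let distance = 0;
  for (let i = 0; i < a.length; ++i) {
    let diff = a[i] ^ b[i];
    while (diff) {
      distance += diff & 1;
      diff >>= 1;
    }
  }
  return distance;
}

// A made-up 384-dimensional float32 embedding packs into 48 bytes.
const embedding = Float32Array.from({ length: 384 }, (_, i) => Math.sin(i));
const packed = packSignBits(embedding);
console.log(embedding.byteLength, packed.byteLength); // 1536 48

// Two single-byte codes that differ in exactly 2 bit positions.
console.log(hammingDistance(Uint8Array.of(0b10110000), Uint8Array.of(0b10010001))); // 2
```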

🛠️ Misc. improvements

🤗 New contributors

Full Changelog: 2.16.1...2.17.0

2.16.1

20 Mar 15:11

What's new?

  • Add support for the image-feature-extraction pipeline in #650.

    Example: Perform image feature extraction with Xenova/vit-base-patch16-224-in21k.

    const image_feature_extractor = await pipeline('image-feature-extraction', 'Xenova/vit-base-patch16-224-in21k');
    const url = 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.png';
    const features = await image_feature_extractor(url);
    // Tensor {
    //   dims: [ 1, 197, 768 ],
    //   type: 'float32',
    //   data: Float32Array(151296) [ ... ],
    //   size: 151296
    // }

    Example: Compute image embeddings with Xenova/clip-vit-base-patch32.

    const image_feature_extractor = await pipeline('image-feature-extraction', 'Xenova/clip-vit-base-patch32');
    const url = 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.png';
    const features = await image_feature_extractor(url);
    // Tensor {
    //   dims: [ 1, 512 ],
    //   type: 'float32',
    //   data: Float32Array(512) [ ... ],
    //   size: 512
    // }
  • Fix channel format when padding non-square images for certain models in #655. This means you can now perform super-resolution for non-square images with APISR models:

    Example: Upscale an image with Xenova/4x_APISR_GRL_GAN_generator-onnx.

    import { pipeline } from '@xenova/transformers';
    
    // Create image-to-image pipeline
    const upscaler = await pipeline('image-to-image', 'Xenova/4x_APISR_GRL_GAN_generator-onnx', {
        quantized: false,
    });
    
    // Upscale an image
    const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/anime.png';
    const output = await upscaler(url);
    // RawImage {
    //   data: Uint8Array(16588800) [ ... ],
    //   width: 2560,
    //   height: 1920,
    //   channels: 3
    // }
    
    // (Optional) Save the upscaled image
    output.save('upscaled.png');
    See example output

    Input image:
    image

    Output image:
    image

  • Update tokenizer apply_chat_template functionality in #647. This PR added functionality to support the new C4AI Command-R tokenizer.

    See example tool usage
    import { AutoTokenizer } from "@xenova/transformers";
    
    const tokenizer = await AutoTokenizer.from_pretrained("Xenova/c4ai-command-r-v01-tokenizer")
    
    // define conversation input:
    const conversation = [
      { role: "user", content: "Whats the biggest penguin in the world?" }
    ]
    // Define tools available for the model to use:
    const tools = [
      {
        name: "internet_search",
        description: "Returns a list of relevant document snippets for a textual query retrieved from the internet",
        parameter_definitions: {
          query: {
            description: "Query to search the internet with",
            type: "str",
            required: true
          }
        }
      },
      {
        name: "directly_answer",
        description: "Calls a standard (un-augmented) AI chatbot to generate a response given the conversation history",
        parameter_definitions: {}
      }
    ]
    
    
    // render the tool use prompt as a string:
    const tool_use_prompt = tokenizer.apply_chat_template(
      conversation,
      {
        chat_template: "tool_use",
        tokenize: false,
        add_generation_prompt: true,
        tools,
      }
    )
    console.log(tool_use_prompt)
    See example RAG usage
    import { AutoTokenizer } from "@xenova/transformers";
    
    const tokenizer = await AutoTokenizer.from_pretrained("Xenova/c4ai-command-r-v01-tokenizer")
    
    // define conversation input:
    const conversation = [
      { role: "user", content: "Whats the biggest penguin in the world?" }
    ]
    // define documents to ground on:
    const documents = [
      { title: "Tall penguins", text: "Emperor penguins are the tallest growing up to 122 cm in height." },
      { title: "Penguin habitats", text: "Emperor penguins only live in Antarctica." }
    ]
    
    // render the RAG prompt as a string:
    const grounded_generation_prompt = tokenizer.apply_chat_template(
      conversation,
      {
        chat_template: "rag",
        tokenize: false,
        add_generation_prompt: true,
    
        documents,
        citation_mode: "accurate", // or "fast"
      }
    )
    console.log(grounded_generation_prompt);
  • Add support for EfficientNet in #639.

    Example: Classify images with chriamue/bird-species-classifier

    import { pipeline } from '@xenova/transformers';
    
    // Create image classification pipeline
    const classifier = await pipeline('image-classification', 'chriamue/bird-species-classifier', {
        quantized: false,      // Quantized model doesn't work
        revision: 'refs/pr/1', // Needed until the model author merges the PR
    });
    
    // Classify an image
    const url = 'https://upload.wikimedia.org/wikipedia/commons/7/73/Short_tailed_Albatross1.jpg';
    const output = await classifier(url);
    console.log(output)
    // [{ label: 'ALBATROSS', score: 0.9999023079872131 }]

Full Changelog: 2.16.0...2.16.1

2.16.0

07 Mar 14:52

What's new?

💬 StableLM text-generation models

This version adds support for the StableLM family of text-generation models (up to 1.6B params), developed by Stability AI. Huge thanks to @D4ve-R for this contribution in #616! See here for the full list of supported models.

Example: Text generation with Xenova/stablelm-2-zephyr-1_6b.

import { pipeline } from '@xenova/transformers';

// Create text generation pipeline
const generator = await pipeline('text-generation', 'Xenova/stablelm-2-zephyr-1_6b');

// Define the prompt and list of messages
const prompt = "Tell me a funny joke."
const messages = [
    { "role": "system", "content": "You are a helpful assistant." },
    { "role": "user", "content": prompt },
]

// Apply chat template
const inputs = generator.tokenizer.apply_chat_template(messages, {
    tokenize: false,
    add_generation_prompt: true,
});

// Generate text
const output = await generator(inputs, { max_new_tokens: 20 });
console.log(output[0].generated_text);
// "<|system|>\nYou are a helpful assistant.\n<|user|>\nTell me a funny joke.\n<|assistant|>\nHere's a joke for you:\n\nWhy don't scientists trust atoms?\n\nBecause they make up everything!"

Note: these models may be too large to run in your browser at the moment, so for now, we recommend using them in Node.js. Stay tuned for updates on this!
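
To give a feel for what applying a chat template does, the sketch below rebuilds the prompt portion of the output string shown above by hand. It is a hand-written stand-in that copies the layout visible in that string; it is not the model's actual template, and `renderChat` is a hypothetical function.

```javascript
// Hand-written stand-in for a chat template. It copies the layout visible in the
// example output and is NOT the model's actual template.
function renderChat(messages, { add_generation_prompt = false } = {}) {
  let prompt = messages
    .map(({ role, content }) => `<|${role}|>\n${content}\n`)
    .join("");
  if (add_generation_prompt) prompt += "<|assistant|>\n"; // cue the model to answer
  return prompt;
}

const messages = [
  { role: "system", content: "You are a helpful assistant." },
  { role: "user", content: "Tell me a funny joke." },
];
console.log(JSON.stringify(renderChat(messages, { add_generation_prompt: true })));
// "<|system|>\nYou are a helpful assistant.\n<|user|>\nTell me a funny joke.\n<|assistant|>\n"
```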

🔉 Speaker verification and diarization models

Example: Speaker verification w/ Xenova/wavlm-base-plus-sv.

import { AutoProcessor, AutoModel, read_audio, cos_sim } from '@xenova/transformers';

// Load processor and model
const processor = await AutoProcessor.from_pretrained('Xenova/wavlm-base-plus-sv');
const model = await AutoModel.from_pretrained('Xenova/wavlm-base-plus-sv');

// Helper function to compute speaker embedding from audio URL
async function compute_embedding(url) {
    const audio = await read_audio(url, 16000);
    const inputs = await processor(audio);
    const { embeddings } = await model(inputs);
    return embeddings.data;
}

// Generate speaker embeddings
const BASE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/sv_speaker';
const speaker_1_1 = await compute_embedding(`${BASE_URL}-1_1.wav`);
const speaker_1_2 = await compute_embedding(`${BASE_URL}-1_2.wav`);
const speaker_2_1 = await compute_embedding(`${BASE_URL}-2_1.wav`);
const speaker_2_2 = await compute_embedding(`${BASE_URL}-2_2.wav`);

// Compute similarity scores
console.log(cos_sim(speaker_1_1, speaker_1_2)); // 0.959439158881247 (Both are speaker 1)
console.log(cos_sim(speaker_1_2, speaker_2_1)); // 0.618130172602329 (Different speakers)
console.log(cos_sim(speaker_2_1, speaker_2_2)); // 0.962999814169370 (Both are speaker 2)

Example: Perform speaker diarization with Xenova/wavlm-base-plus-sd.

import { AutoProcessor, AutoModelForAudioFrameClassification, read_audio } from '@xenova/transformers';

// Read and preprocess audio
const processor = await AutoProcessor.from_pretrained('Xenova/wavlm-base-plus-sd');
const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
const audio = await read_audio(url, 16000);
const inputs = await processor(audio);

// Run model with inputs
const model = await AutoModelForAudioFrameClassification.from_pretrained('Xenova/wavlm-base-plus-sd');
const { logits } = await model(inputs);
// {
//   logits: Tensor {
//     dims: [ 1, 549, 2 ],  // [batch_size, num_frames, num_speakers]
//     type: 'float32',
//     data: Float32Array(1098) [-3.5301010608673096, ...],
//     size: 1098
//   }
// }

const labels = logits[0].sigmoid().tolist().map(
    frames => frames.map(speaker => speaker > 0.5 ? 1 : 0)
);
console.log(labels); // labels is a binary array of shape (num_frames, num_speakers); 1 = speaker active in that frame
// [
//     [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0],
//     [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0],
//     [0, 0], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1],
//     ...
// ]
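
Per-frame labels are often easier to use once they are collapsed into contiguous segments. The sketch below shows one way to do that; `frames_to_segments` is a hypothetical helper, not a library function. It reports frame indices only, because converting frames to seconds requires the model's frame rate, which is not stated here.

```javascript
// Hypothetical helper: collapse per-frame 0/1 labels of shape
// (num_frames, num_speakers) into [start_frame, end_frame) ranges per speaker.
function frames_to_segments(labels) {
    const num_speakers = labels.length > 0 ? labels[0].length : 0;
    const segments = [];
    for (let speaker = 0; speaker < num_speakers; ++speaker) {
        let start = null; // first frame of the currently open segment, if any
        // Iterate one step past the end so an open segment is always closed.
        for (let frame = 0; frame <= labels.length; ++frame) {
            const active = frame < labels.length && labels[frame][speaker] === 1;
            if (active && start === null) {
                start = frame;
            } else if (!active && start !== null) {
                segments.push({ speaker, start_frame: start, end_frame: frame });
                start = null;
            }
        }
    }
    return segments;
}

// Toy input standing in for the thresholded model output.
const toy_labels = [[0, 0], [0, 1], [0, 1], [1, 0], [1, 0], [0, 0], [0, 1]];
console.log(frames_to_segments(toy_labels));
// [
//   { speaker: 0, start_frame: 3, end_frame: 5 },
//   { speaker: 1, start_frame: 1, end_frame: 3 },
//   { speaker: 1, start_frame: 6, end_frame: 7 }
// ]
```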

These additions were made possible thanks to the following PRs:

  • Add support for WavLMForXVector by @D4ve-R in #603
  • Add support for WavLMForAudioFrameClassification and Wav2Vec2ForAudioFrameClassification by @D4ve-R in #611
  • Add support for UniSpeech and UniSpeechSat models in #624

📝 Improved chat templating operation coverage

With this release, we're pleased to announce that Transformers.js can now parse every valid chat template currently on the Hugging Face Hub! 🤯 As of 2024/03/05, this covers ~12k conversational models (which together use ~250 unique templates). Of course, future models may introduce more complex chat templates, and we'll continue to add support for them!

For example, Transformers.js can now generate the prompt for highly complex function-calling models (e.g., fireworks-ai/firefunction-v1):

import { AutoTokenizer } from '@xenova/transformers';

const tokenizer = await AutoTokenizer.from_pretrained('fireworks-ai/firefunction-v1');

const function_spec = [
    {
        name: 'get_stock_price',
        description: 'Get the current stock price',
        parameters: {
            type: 'object',
            properties: {
                symbol: {
                    type: 'string',
                    description: 'The stock symbol, e.g. AAPL, GOOG'
                }
            },
            required: ['symbol']
        }
    },
    {
        name: 'check_word_anagram',
        description: 'Check if two words are anagrams of each other',
        parameters: {
            type: 'object',
            properties: {
                word1: {
                    type: 'string',
                    description: 'The first word'
                },
                word2: {
                    type: 'string',
                    description: 'The second word'
                }
            },
            required: ['word1', 'word2']
        }
    }
];

const messages = [
    { role: 'functions', content: JSON.stringify(function_spec, null, 4) },
    { role: 'system', content: 'You are a helpful assistant with access to functions. Use them if required.' },
    { role: 'user', content: 'Hi, can you tell me the current stock price of AAPL?' }
];

const inputs = tokenizer.apply_chat_template(messages, { tokenize: false });
console.log(inputs);
// <s>SYSTEM: You are a helpful assistant ...

🎨 New example applications and demos

🛠️ Misc. improvements

  • Add support for the starcoder2 architecture in #622. Note: we haven't yet added transformers.js-compatible versions of the 3B and 7B models.
  • Check for existence of onnx_env.wasm before updating wasmPaths in #621

🤗 New contributors

Full Changelog: 2.15.1...2.16.0

2.15.1

21 Feb 14:47

What's new?

  • Add Background Removal demo in #576 (online demo).

  • Add support for owlv2 models in #579

    Example: Zero-shot object detection w/ Xenova/owlv2-base-patch16-ensemble.

    import { pipeline } from '@xenova/transformers';
    
    const detector = await pipeline('zero-shot-object-detection', 'Xenova/owlv2-base-patch16-ensemble');
    
    const url = 'http://images.cocodataset.org/val2017/000000039769.jpg';
    const candidate_labels = ['a photo of a cat', 'a photo of a dog'];
    const output = await detector(url, candidate_labels);
    console.log(output);
    // [
    //   { score: 0.7400985360145569, label: 'a photo of a cat', box: { xmin: 0, ymin: 50, xmax: 323, ymax: 485 } },
    //   { score: 0.6315087080001831, label: 'a photo of a cat', box: { xmin: 333, ymin: 23, xmax: 658, ymax: 378 } }
    // ]

  • Add support for Adaptive Retrieval w/ Matryoshka Embeddings (nomic-ai/nomic-embed-text-v1.5) in #587 and #588 (online demo).

  • Add support for Gemma Tokenizer in #597 and #598
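
For the Matryoshka Embeddings item above, the core idea behind adaptive retrieval is that an embedding can be shortened to its first few dimensions and still be compared meaningfully. The sketch below illustrates this on plain arrays, assuming that truncation followed by L2 normalization is sufficient; `truncate_embedding` and `dot` are hypothetical helpers, and a given model's documentation may recommend additional steps (for example, a normalization step before truncating).

```javascript
// Hypothetical helper: keep the first `dim` values of an embedding and
// L2-normalize the result so that dot products act as cosine similarities.
// Assumes the truncated vector is not all zeros.
function truncate_embedding(embedding, dim) {
    const truncated = Array.from(embedding.slice(0, dim));
    const norm = Math.sqrt(truncated.reduce((sum, x) => sum + x * x, 0));
    return truncated.map(x => x / norm);
}

// Dot product of two equal-length vectors.
function dot(a, b) {
    return a.reduce((sum, x, i) => sum + x * b[i], 0);
}

// Toy 4-dimensional vectors standing in for real text embeddings.
const query_embedding = [3, 4, 1, 2];
const doc_embedding = [6, 8, 0, 1];

// Compare using only the first 2 dimensions.
const query_2d = truncate_embedding(query_embedding, 2);
const doc_2d = truncate_embedding(doc_embedding, 2);
console.log(dot(query_2d, doc_2d)); // close to 1: the truncated vectors point the same way
```

A common pattern is to rank candidates with the short vectors first and then re-score only the top results with the full-length embeddings.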

Full Changelog: 2.15.0...2.15.1