From f5226a5fc7d3cbaaee356fb4e89bab8dc265507e Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Thu, 18 Jul 2024 14:14:35 -0400 Subject: [PATCH 01/16] Add q_* ops to match float ops --- crates/burn-autodiff/src/ops/qtensor.rs | 84 +- crates/burn-candle/src/ops/qtensor.rs | 77 +- crates/burn-core/src/module/quantize.rs | 13 + crates/burn-fusion/src/ops/qtensor.rs | 84 +- crates/burn-jit/src/ops/base.rs | 1 + crates/burn-jit/src/ops/qtensor.rs | 77 +- crates/burn-ndarray/src/ops/base.rs | 9 + crates/burn-ndarray/src/ops/qtensor.rs | 107 +- crates/burn-ndarray/src/ops/tensor.rs | 5 +- crates/burn-tch/src/ops/base.rs | 9 + crates/burn-tch/src/ops/qtensor.rs | 199 ++- crates/burn-tch/src/ops/tensor.rs | 8 + crates/burn-tensor/src/tensor/api/float.rs | 16 + crates/burn-tensor/src/tensor/ops/qtensor.rs | 1550 ++++++++++++++++- .../src/tensor/quantization/calibration.rs | 22 + .../src/tensor/quantization/scheme.rs | 17 +- 16 files changed, 2263 insertions(+), 15 deletions(-) diff --git a/crates/burn-autodiff/src/ops/qtensor.rs b/crates/burn-autodiff/src/ops/qtensor.rs index 156efe08d0..061a3d5db1 100644 --- a/crates/burn-autodiff/src/ops/qtensor.rs +++ b/crates/burn-autodiff/src/ops/qtensor.rs @@ -1,6 +1,8 @@ +use std::ops::Range; + use burn_tensor::{ backend::Backend, - ops::{FloatTensor, QTensorOps, QuantizedTensor}, + ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, quantization::{QuantizationParametersPrimitive, QuantizationScheme}, Device, Shape, TensorData, }; @@ -23,6 +25,13 @@ impl QTensorOps for Autodiff { todo!() // required for QAT } + fn quantize_dynamic( + _tensor: FloatTensor, + _scheme: &QuantizationScheme, + ) -> QuantizedTensor { + todo!() + } + fn dequantize(_tensor: QuantizedTensor) -> FloatTensor { todo!() } @@ -35,6 +44,13 @@ impl QTensorOps for Autodiff { B::q_device(tensor) } + fn q_to_device( + _tensor: QuantizedTensor, + _device: &Device, + ) -> QuantizedTensor { + unimplemented!() + } + fn q_reshape( tensor: QuantizedTensor, shape: Shape, @@ -45,4 +61,70 @@ impl QTensorOps for Autodiff { async fn q_into_data(tensor: QuantizedTensor) -> TensorData { B::q_into_data(tensor).await } + + fn q_swap_dims( + _tensor: QuantizedTensor, + _dim1: usize, + _dim2: usize, + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_permute( + _tensor: QuantizedTensor, + _axes: [usize; D], + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_flip( + _tensor: QuantizedTensor, + _axes: &[usize], + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_gather( + _dim: usize, + _tensor: QuantizedTensor, + _indices: IntTensor, + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_select( + _tensor: QuantizedTensor, + _dim: usize, + _indices: IntTensor, + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_slice( + _tensor: QuantizedTensor, + _ranges: [Range; D2], + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_argmax( + _tensor: QuantizedTensor, + _dim: usize, + ) -> IntTensor { + unimplemented!() + } + + fn q_argmin( + _tensor: QuantizedTensor, + _dim: usize, + ) -> IntTensor { + unimplemented!() + } + + fn q_expand( + _tensor: QuantizedTensor, + _shape: Shape, + ) -> QuantizedTensor { + unimplemented!() + } } diff --git a/crates/burn-candle/src/ops/qtensor.rs b/crates/burn-candle/src/ops/qtensor.rs index 0070aec097..038da97521 100644 --- a/crates/burn-candle/src/ops/qtensor.rs +++ b/crates/burn-candle/src/ops/qtensor.rs @@ -1,6 +1,8 @@ +use std::ops::Range; + use burn_tensor::{ backend::Backend, - ops::{FloatTensor, QTensorOps, QuantizedTensor}, + 
ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, quantization::{QuantizationParametersPrimitive, QuantizationScheme, QuantizationStrategy}, DType, Device, Shape, TensorData, }; @@ -38,6 +40,13 @@ impl QTensorOps for Candle( + _tensor: QuantizedTensor, + _device: &Device, + ) -> QuantizedTensor { + unimplemented!() + } + fn q_reshape( tensor: QuantizedTensor, shape: Shape, @@ -51,4 +60,70 @@ impl QTensorOps for Candle(tensor: QuantizedTensor) -> TensorData { unimplemented!() } + + fn q_swap_dims( + _tensor: QuantizedTensor, + _dim1: usize, + _dim2: usize, + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_permute( + _tensor: QuantizedTensor, + _axes: [usize; D], + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_flip( + _tensor: QuantizedTensor, + _axes: &[usize], + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_gather( + _dim: usize, + _tensor: QuantizedTensor, + _indices: IntTensor, + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_select( + _tensor: QuantizedTensor, + _dim: usize, + _indices: IntTensor, + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_slice( + _tensor: QuantizedTensor, + _ranges: [Range; D2], + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_argmax( + _tensor: QuantizedTensor, + _dim: usize, + ) -> IntTensor { + unimplemented!() + } + + fn q_argmin( + _tensor: QuantizedTensor, + _dim: usize, + ) -> IntTensor { + unimplemented!() + } + + fn q_expand( + _tensor: QuantizedTensor, + _shape: Shape, + ) -> QuantizedTensor { + unimplemented!() + } } diff --git a/crates/burn-core/src/module/quantize.rs b/crates/burn-core/src/module/quantize.rs index fbbad115a2..b6c34ee029 100644 --- a/crates/burn-core/src/module/quantize.rs +++ b/crates/burn-core/src/module/quantize.rs @@ -12,6 +12,10 @@ pub struct Quantizer { pub calibration: C, /// The quantization scheme. pub scheme: QuantizationScheme, + // TODO: dynamic quant? I think we won't support fully static (with observers to record the values on data samples) + // just yet so this is not required. + // /// Dynamic quantization computes the quantized parameters at runtime. + // pub dynamic: bool, } impl ModuleMapper for Quantizer { @@ -21,3 +25,12 @@ impl ModuleMapper for Quantizer { tensor.quantize(&self.scheme, qparams) } } + +// /// Describes how to quantize a module by providing quantizer settings for activations and weights respectively. +// pub struct QuantizationConfig { +// // TODO: quantization config +// /// The quantizer used to quantize the activations (i.e., a layer's output). +// // pub activations: Quantizer, +// /// The quantizer used to quantize the weights. 
+// pub weights: Quantizer, +// } diff --git a/crates/burn-fusion/src/ops/qtensor.rs b/crates/burn-fusion/src/ops/qtensor.rs index 7d07054c10..34b06a3487 100644 --- a/crates/burn-fusion/src/ops/qtensor.rs +++ b/crates/burn-fusion/src/ops/qtensor.rs @@ -1,6 +1,8 @@ +use std::ops::Range; + use burn_tensor::{ backend::Backend, - ops::{QTensorOps, QuantizedTensor}, + ops::{IntTensor, QTensorOps, QuantizedTensor}, quantization::{QuantizationParametersPrimitive, QuantizationScheme}, Device, Shape, TensorData, }; @@ -23,6 +25,13 @@ impl QTensorOps for Fusion { unimplemented!() } + fn quantize_dynamic( + _tensor: ::FloatTensorPrimitive, + _scheme: &QuantizationScheme, + ) -> QuantizedTensor { + unimplemented!() + } + fn dequantize( _tensor: ::QuantizedTensorPrimitive, ) -> ::FloatTensorPrimitive { @@ -37,6 +46,13 @@ impl QTensorOps for Fusion { tensor.qtensor.client.device().clone() } + fn q_to_device( + _tensor: QuantizedTensor, + _device: &Device, + ) -> QuantizedTensor { + unimplemented!() + } + fn q_reshape( _tensor: QuantizedTensor, _shape: Shape, @@ -47,4 +63,70 @@ impl QTensorOps for Fusion { async fn q_into_data(_tensor: QuantizedTensor) -> TensorData { unimplemented!() } + + fn q_swap_dims( + _tensor: QuantizedTensor, + _dim1: usize, + _dim2: usize, + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_permute( + _tensor: QuantizedTensor, + _axes: [usize; D], + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_flip( + _tensor: QuantizedTensor, + _axes: &[usize], + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_gather( + _dim: usize, + _tensor: QuantizedTensor, + _indices: IntTensor, + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_select( + _tensor: QuantizedTensor, + _dim: usize, + _indices: IntTensor, + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_slice( + _tensor: QuantizedTensor, + _ranges: [Range; D2], + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_argmax( + _tensor: QuantizedTensor, + _dim: usize, + ) -> IntTensor { + unimplemented!() + } + + fn q_argmin( + _tensor: QuantizedTensor, + _dim: usize, + ) -> IntTensor { + unimplemented!() + } + + fn q_expand( + _tensor: QuantizedTensor, + _shape: Shape, + ) -> QuantizedTensor { + unimplemented!() + } } diff --git a/crates/burn-jit/src/ops/base.rs b/crates/burn-jit/src/ops/base.rs index ae4aca6f0d..f5249b2cdc 100644 --- a/crates/burn-jit/src/ops/base.rs +++ b/crates/burn-jit/src/ops/base.rs @@ -7,6 +7,7 @@ pub(crate) fn from_data( data: TensorData, device: &R::Device, ) -> JitTensor { + // TODO: from_data QFloat should not convert let shape: Shape = (&data.shape).into(); let client = R::client(device); let buffer = client.create(data.convert::().as_bytes()); diff --git a/crates/burn-jit/src/ops/qtensor.rs b/crates/burn-jit/src/ops/qtensor.rs index ada52bbb07..cfd68637c6 100644 --- a/crates/burn-jit/src/ops/qtensor.rs +++ b/crates/burn-jit/src/ops/qtensor.rs @@ -1,5 +1,7 @@ +use std::ops::Range; + use burn_tensor::{ - ops::{FloatTensor, QTensorOps, QuantizedTensor}, + ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, quantization::{QuantizationParametersPrimitive, QuantizationScheme}, Device, Shape, TensorData, }; @@ -39,6 +41,13 @@ where tensor.qtensor.device.clone() } + fn q_to_device( + _tensor: QuantizedTensor, + _device: &Device, + ) -> QuantizedTensor { + unimplemented!() + } + fn q_reshape( tensor: QuantizedTensor, shape: Shape, @@ -52,4 +61,70 @@ where async fn q_into_data(_tensor: QuantizedTensor) -> TensorData { unimplemented!() } + + fn q_swap_dims( + _tensor: 
QuantizedTensor, + _dim1: usize, + _dim2: usize, + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_permute( + _tensor: QuantizedTensor, + _axes: [usize; D], + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_flip( + _tensor: QuantizedTensor, + _axes: &[usize], + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_gather( + _dim: usize, + _tensor: QuantizedTensor, + _indices: IntTensor, + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_select( + _tensor: QuantizedTensor, + _dim: usize, + _indices: IntTensor, + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_slice( + _tensor: QuantizedTensor, + _ranges: [Range; D2], + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_argmax( + _tensor: QuantizedTensor, + _dim: usize, + ) -> IntTensor { + unimplemented!() + } + + fn q_argmin( + _tensor: QuantizedTensor, + _dim: usize, + ) -> IntTensor { + unimplemented!() + } + + fn q_expand( + _tensor: QuantizedTensor, + _shape: Shape, + ) -> QuantizedTensor { + unimplemented!() + } } diff --git a/crates/burn-ndarray/src/ops/base.rs b/crates/burn-ndarray/src/ops/base.rs index 60d472937e..eed8915e7a 100644 --- a/crates/burn-ndarray/src/ops/base.rs +++ b/crates/burn-ndarray/src/ops/base.rs @@ -115,6 +115,15 @@ where NdArrayTensor::new(array) } + pub fn permute( + tensor: NdArrayTensor, + axes: [usize; D], + ) -> NdArrayTensor { + let array = tensor.array.permuted_axes(axes.into_dimension()); + + NdArrayTensor::new(array) + } + /// Broadcasts the tensor to the given shape pub(crate) fn expand( tensor: NdArrayTensor, diff --git a/crates/burn-ndarray/src/ops/qtensor.rs b/crates/burn-ndarray/src/ops/qtensor.rs index d5f855586e..300bb4a7fb 100644 --- a/crates/burn-ndarray/src/ops/qtensor.rs +++ b/crates/burn-ndarray/src/ops/qtensor.rs @@ -1,5 +1,7 @@ +use core::ops::Range; + use burn_tensor::{ - ops::{FloatTensor, QTensorOps, QuantizedTensor}, + ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, quantization::{ AffineQuantization, Quantization, QuantizationParametersPrimitive, QuantizationScheme, QuantizationStrategy, QuantizationType, SymmetricQuantization, @@ -12,7 +14,7 @@ use crate::{ FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor, }; -use super::NdArrayOps; +use super::{NdArrayMathOps, NdArrayOps}; fn into_data(tensor: NdArrayTensor) -> TensorData { let shape = tensor.shape(); @@ -99,6 +101,13 @@ impl QTensorOps for NdArray NdArrayDevice::Cpu } + fn q_to_device( + tensor: QuantizedTensor, + _device: &NdArrayDevice, + ) -> QuantizedTensor { + tensor + } + fn q_reshape( tensor: QuantizedTensor, shape: Shape, @@ -115,4 +124,98 @@ impl QTensorOps for NdArray let values = tensor.qtensor.array.into_iter().collect(); TensorData::quantized(values, shape, tensor.strategy) } + + fn q_swap_dims( + tensor: QuantizedTensor, + dim1: usize, + dim2: usize, + ) -> QuantizedTensor { + NdArrayQTensor { + qtensor: NdArrayOps::swap_dims(tensor.qtensor, dim1, dim2), + scheme: tensor.scheme, + strategy: tensor.strategy, + } + } + + fn q_permute( + tensor: QuantizedTensor, + axes: [usize; D], + ) -> QuantizedTensor { + NdArrayQTensor { + qtensor: NdArrayOps::permute(tensor.qtensor, axes), + scheme: tensor.scheme, + strategy: tensor.strategy, + } + } + + fn q_flip( + tensor: QuantizedTensor, + axes: &[usize], + ) -> QuantizedTensor { + NdArrayQTensor { + qtensor: NdArrayOps::flip(tensor.qtensor, axes), + scheme: tensor.scheme, + strategy: tensor.strategy, + } + } + + fn q_gather( + dim: usize, + tensor: QuantizedTensor, + indices: IntTensor, + ) -> QuantizedTensor 
{ + NdArrayQTensor { + qtensor: NdArrayMathOps::gather(dim, tensor.qtensor, indices), + scheme: tensor.scheme, + strategy: tensor.strategy, + } + } + + fn q_select( + tensor: QuantizedTensor, + dim: usize, + indices: IntTensor, + ) -> QuantizedTensor { + NdArrayQTensor { + qtensor: NdArrayMathOps::select(tensor.qtensor, dim, indices), + scheme: tensor.scheme, + strategy: tensor.strategy, + } + } + + fn q_slice( + tensor: QuantizedTensor, + ranges: [Range; D2], + ) -> QuantizedTensor { + NdArrayQTensor { + qtensor: NdArrayOps::slice(tensor.qtensor, ranges), + scheme: tensor.scheme, + strategy: tensor.strategy, + } + } + + fn q_argmax( + tensor: QuantizedTensor, + dim: usize, + ) -> IntTensor { + NdArrayMathOps::argmax(tensor.qtensor, dim) + } + + fn q_argmin( + tensor: QuantizedTensor, + dim: usize, + ) -> IntTensor { + NdArrayMathOps::argmin(tensor.qtensor, dim) + } + + fn q_expand( + tensor: QuantizedTensor, + shape: Shape, + ) -> QuantizedTensor { + NdArrayQTensor { + qtensor: NdArrayOps::expand(tensor.qtensor, shape), + scheme: tensor.scheme, + strategy: tensor.strategy, + } + } } diff --git a/crates/burn-ndarray/src/ops/tensor.rs b/crates/burn-ndarray/src/ops/tensor.rs index 640c3bc0fa..56fe096441 100644 --- a/crates/burn-ndarray/src/ops/tensor.rs +++ b/crates/burn-ndarray/src/ops/tensor.rs @@ -1,7 +1,7 @@ // Language use alloc::vec::Vec; use core::ops::Range; -use ndarray::{IntoDimension, Zip}; +use ndarray::Zip; // Current crate use super::{matmul::matmul, NdArrayMathOps, NdArrayOps}; @@ -481,8 +481,7 @@ impl FloatTensorOps for NdArray, axes: [usize; D], ) -> burn_tensor::ops::FloatTensor { - let array = tensor.array.permuted_axes(axes.into_dimension()); - NdArrayTensor { array } + NdArrayOps::permute(tensor, axes) } fn float_flip( diff --git a/crates/burn-tch/src/ops/base.rs b/crates/burn-tch/src/ops/base.rs index cc65774041..3271cac714 100644 --- a/crates/burn-tch/src/ops/base.rs +++ b/crates/burn-tch/src/ops/base.rs @@ -509,6 +509,15 @@ impl TchOps { TchTensor::new(tensor.tensor.sort(dim as i64, descending).0) } + pub fn sort_with_indices( + tensor: TchTensor, + dim: usize, + descending: bool, + ) -> (TchTensor, TchTensor) { + let sorted = tensor.tensor.sort(dim as i64, descending); + (TchTensor::new(sorted.0), TchTensor::new(sorted.1)) + } + pub fn argsort( tensor: TchTensor, dim: usize, diff --git a/crates/burn-tch/src/ops/qtensor.rs b/crates/burn-tch/src/ops/qtensor.rs index 16d7641b1f..b42f93ede1 100644 --- a/crates/burn-tch/src/ops/qtensor.rs +++ b/crates/burn-tch/src/ops/qtensor.rs @@ -1,5 +1,7 @@ +use std::ops::Range; + use burn_tensor::{ - ops::{FloatTensor, QTensorOps, QuantizedTensor}, + ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, quantization::{ QTensorPrimitive, Quantization, QuantizationParametersPrimitive, QuantizationScheme, QuantizationStrategy, QuantizationType, @@ -91,6 +93,30 @@ impl QTensorOps for LibTorch { } } + fn quantize_dynamic( + tensor: FloatTensor, + scheme: &QuantizationScheme, + ) -> QuantizedTensor { + let qtensor = match &scheme { + QuantizationScheme::PerTensorAffine(dtype) => match dtype { + // Notes on `reduce_range`: + // https://github.com/pytorch/pytorch/issues/93140 + // https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + QuantizationType::QInt8 => tensor + .tensor + .quantize_per_tensor_dynamic(tch::Kind::QInt8, /*reduce_range*/ false), + }, + QuantizationScheme::PerTensorSymmetric(_) => { + panic!("LibTorch backend does not support symmetric quantize_dynamic") + } + }; 
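+        // Illustrative arithmetic (a sketch with assumed values): for a tensor spanning
+        // [-1.0, 3.0] quantized per-tensor affine to qint8 ([-128, 127]), dynamic
+        // quantization derives roughly:
+        //   scale      = (max - min) / (q_max - q_min) = 4.0 / 255 ≈ 0.0157
+        //   zero_point = q_min - round(min / scale)    = -128 - (-64) = -64
+        // so that -1.0 maps to -128 and 3.0 maps to 127.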
+ + TchQTensor { + qtensor: TchTensor::new(qtensor), + scheme: scheme.clone(), + } + } + fn dequantize(tensor: QuantizedTensor) -> FloatTensor { TchTensor::new(tensor.qtensor.tensor.dequantize().to_kind(E::KIND)) } @@ -103,6 +129,15 @@ impl QTensorOps for LibTorch { tensor.qtensor.tensor.device().into() } + fn q_to_device( + tensor: QuantizedTensor, + device: &burn_tensor::Device, + ) -> QuantizedTensor { + let mut tensor = tensor; + tensor.qtensor = TchOps::to_device(tensor.qtensor, device); + tensor + } + fn q_reshape( tensor: QuantizedTensor, shape: Shape, @@ -123,4 +158,166 @@ impl QTensorOps for LibTorch { TensorData::quantized(values.unwrap(), shape, strategy) } + + fn q_swap_dims( + tensor: QuantizedTensor, + dim1: usize, + dim2: usize, + ) -> QuantizedTensor { + // NOTE: with per-channel quantization (future), the channel axis could be impacted by this op + let mut tensor = tensor; + tensor.qtensor = TchOps::swap_dims(tensor.qtensor, dim1, dim2); + tensor + } + + fn q_permute( + tensor: QuantizedTensor, + axes: [usize; D], + ) -> QuantizedTensor { + // NOTE: with per-channel quantization (future), the channel axis could be impacted by this op + let mut tensor = tensor; + tensor.qtensor = TchOps::permute(tensor.qtensor, axes); + tensor + } + + fn q_flip( + tensor: QuantizedTensor, + axes: &[usize], + ) -> QuantizedTensor { + let mut tensor = tensor; + tensor.qtensor = TchOps::flip(tensor.qtensor, axes); + tensor + } + + fn q_gather( + dim: usize, + tensor: QuantizedTensor, + indices: IntTensor, + ) -> QuantizedTensor { + let mut tensor = tensor; + tensor.qtensor = TchOps::gather(dim, tensor.qtensor, indices); + tensor + } + + fn q_select( + tensor: QuantizedTensor, + dim: usize, + indices: IntTensor, + ) -> QuantizedTensor { + let mut tensor = tensor; + tensor.qtensor = TchOps::index_select_dim(tensor.qtensor, dim, indices); + tensor + } + + fn q_slice( + tensor: QuantizedTensor, + ranges: [Range; D2], + ) -> QuantizedTensor { + let mut tensor = tensor; + tensor.qtensor = TchOps::slice(tensor.qtensor, ranges); + tensor + } + + fn q_argmax( + tensor: QuantizedTensor, + dim: usize, + ) -> IntTensor { + TchOps::argmax(tensor.qtensor, dim) + } + + fn q_argmin( + tensor: QuantizedTensor, + dim: usize, + ) -> IntTensor { + TchOps::argmin(tensor.qtensor, dim) + } + + fn q_max_dim( + tensor: QuantizedTensor, + dim: usize, + ) -> QuantizedTensor { + TchQTensor { + qtensor: TchOps::max_dim(tensor.qtensor, dim), + scheme: tensor.scheme, + } + } + + fn q_min_dim( + tensor: QuantizedTensor, + dim: usize, + ) -> QuantizedTensor { + TchQTensor { + qtensor: TchOps::min_dim(tensor.qtensor, dim), + scheme: tensor.scheme, + } + } + + fn q_narrow( + tensor: QuantizedTensor, + dim: usize, + start: usize, + length: usize, + ) -> QuantizedTensor { + TchQTensor { + qtensor: TchOps::narrow(tensor.qtensor, dim, start, length), + scheme: tensor.scheme, + } + } + + fn q_chunk( + tensor: QuantizedTensor, + chunks: usize, + dim: usize, + ) -> Vec> { + TchOps::chunk(tensor.qtensor, chunks, dim) + .into_iter() + .map(|x| TchQTensor { + qtensor: x, + scheme: tensor.scheme.clone(), + }) + .collect() + } + + fn q_expand( + tensor: QuantizedTensor, + shape: Shape, + ) -> QuantizedTensor { + // NOTE: with per-channel quantization (future), the channel axis could be impacted by this op + TchQTensor { + qtensor: TchOps::expand(tensor.qtensor, shape), + scheme: tensor.scheme, + } + } + + fn q_sort( + tensor: QuantizedTensor, + dim: usize, + descending: bool, + ) -> QuantizedTensor { + TchQTensor { + qtensor: 
TchOps::sort(tensor.qtensor, dim, descending), + scheme: tensor.scheme, + } + } + + fn q_sort_with_indices( + tensor: QuantizedTensor, + dim: usize, + descending: bool, + ) -> (QuantizedTensor, IntTensor) { + let (qtensor, indices) = TchOps::sort_with_indices(tensor.qtensor, dim, descending); + let tensor = TchQTensor { + qtensor, + scheme: tensor.scheme, + }; + (tensor, indices) + } + + fn q_argsort( + tensor: QuantizedTensor, + dim: usize, + descending: bool, + ) -> IntTensor { + TchOps::argsort(tensor.qtensor, dim, descending) + } } diff --git a/crates/burn-tch/src/ops/tensor.rs b/crates/burn-tch/src/ops/tensor.rs index e2b979ca5c..88008f3ff4 100644 --- a/crates/burn-tch/src/ops/tensor.rs +++ b/crates/burn-tch/src/ops/tensor.rs @@ -506,6 +506,14 @@ impl FloatTensorOps for LibTorch { TchOps::sort(tensor, dim, descending) } + fn float_sort_with_indices( + tensor: TchTensor, + dim: usize, + descending: bool, + ) -> (TchTensor, TchTensor) { + TchOps::sort_with_indices(tensor, dim, descending) + } + fn float_argsort( tensor: as Backend>::FloatTensorPrimitive, dim: usize, diff --git a/crates/burn-tensor/src/tensor/api/float.rs b/crates/burn-tensor/src/tensor/api/float.rs index 4b1eef3cba..2990bf5b38 100644 --- a/crates/burn-tensor/src/tensor/api/float.rs +++ b/crates/burn-tensor/src/tensor/api/float.rs @@ -334,6 +334,22 @@ where ))) } + // /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme. + // /// + // /// # Arguments + // /// + // /// * `scheme` - The quantization scheme. + // /// + // /// # Returns + // /// + // /// The quantized tensor. + // pub fn quantize_dynamic(self, scheme: QuantizationScheme) -> Tensor { + // Tensor::new(TensorPrimitive::QFloat(B::quantize_dynamic( + // self.primitive.tensor(), + // scheme, + // ))) + // } + /// Convert the tensor back to a higher precision data type. /// /// If the tensor is not quantized, its value is simply returned. diff --git a/crates/burn-tensor/src/tensor/ops/qtensor.rs b/crates/burn-tensor/src/tensor/ops/qtensor.rs index 1722bda7d4..0a637c7b5c 100644 --- a/crates/burn-tensor/src/tensor/ops/qtensor.rs +++ b/crates/burn-tensor/src/tensor/ops/qtensor.rs @@ -1,12 +1,13 @@ -use core::future::Future; +use alloc::vec::Vec; +use core::{future::Future, ops::Range}; use crate::{ backend::Backend, - quantization::{QuantizationParametersPrimitive, QuantizationScheme}, + quantization::{QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme}, Device, Shape, TensorData, }; -use super::{FloatTensor, QuantizedTensor}; +use super::{BoolTensor, FloatElem, FloatTensor, IntElem, IntTensor, QuantizedTensor}; /// Quantized Tensor API for basic operations, see [tensor](crate::Tensor) /// for documentation on each function. @@ -30,6 +31,18 @@ pub trait QTensorOps { qparams: QuantizationParametersPrimitive, ) -> QuantizedTensor; + /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme. + fn quantize_dynamic( + tensor: FloatTensor, + scheme: &QuantizationScheme, + ) -> QuantizedTensor { + // Dynamically compute min/max tensor range and qparams before quantizing + let min = B::float_min(tensor.clone()); + let max = B::float_max(tensor.clone()); + let qparams = scheme.compute_q_params_primitive(min, max); + Self::quantize(tensor, scheme, qparams) + } + /// Convert the tensor back to a higher precision data type. fn dequantize(tensor: QuantizedTensor) -> FloatTensor; @@ -55,6 +68,21 @@ pub trait QTensorOps { /// The device of the tensor. 
fn q_device(tensor: &QuantizedTensor) -> Device; + /// Moves the tensor to the given device. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `device` - The device to move the tensor to. + /// + /// # Returns + /// + /// The tensor on the given device. + fn q_to_device( + tensor: QuantizedTensor, + device: &Device, + ) -> QuantizedTensor; + /// Reshapes a tensor. /// /// # Arguments @@ -83,6 +111,12 @@ pub trait QTensorOps { tensor: QuantizedTensor, ) -> impl Future + Send; + /// Detaches a tensor from the computation graph. + fn q_detach(tensor: QuantizedTensor) -> QuantizedTensor { + // Should only be overridden by autodiff backends. + tensor + } + /// Sets the `require_grad` flag of a tensor. fn q_set_require_grad( tensor: QuantizedTensor, @@ -97,4 +131,1514 @@ pub trait QTensorOps { // Should only be overridden by autodiff backends. false } + + /// Adds two tensors together. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of adding the two tensors together. + fn q_add( + lhs: QuantizedTensor, + rhs: QuantizedTensor, + ) -> QuantizedTensor { + // Heuristic: prioritize lhs scheme + let scheme = lhs.scheme().clone(); + + let lhs_f = Self::dequantize(lhs); + let rhs_f = Self::dequantize(rhs); + let out_f = B::float_add(lhs_f, rhs_f); + + Self::quantize_dynamic(out_f, &scheme) + } + + /// Adds a scalar to a tensor. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of adding the scalar to the tensor. + fn q_add_scalar( + lhs: QuantizedTensor, + rhs: FloatElem, + ) -> QuantizedTensor { + let scheme = lhs.scheme().clone(); + + let lhs_f = Self::dequantize(lhs); + let out_f = B::float_add_scalar(lhs_f, rhs); + + Self::quantize_dynamic(out_f, &scheme) + } + + /// Clamps a tensor under a minimum value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to clamp. + /// * `min` - The minimum value. + /// + /// # Returns + /// + /// The clamped tensor. + fn q_clamp_min( + tensor: QuantizedTensor, + min: FloatElem, + ) -> QuantizedTensor { + let scheme = tensor.scheme().clone(); + + let tensor_f = Self::dequantize(tensor); + let out_f = B::float_clamp_min(tensor_f, min); + + Self::quantize_dynamic(out_f, &scheme) + } + + /// Clamps a tensor over a maximum value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to clamp. + /// * `max` - The maximum value. + /// + /// # Returns + /// + /// The clamped tensor. + fn q_clamp_max( + tensor: QuantizedTensor, + max: FloatElem, + ) -> QuantizedTensor { + let scheme = tensor.scheme().clone(); + + let tensor_f = Self::dequantize(tensor); + let out_f = B::float_clamp_max(tensor_f, max); + + Self::quantize_dynamic(out_f, &scheme) + } + + /// Clamps a tensor between a minimum and maximum value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to clamp. + /// * `min` - The minimum value. + /// * `max` - The maximum value. + /// + /// # Returns + /// + /// The clamped tensor. + fn q_clamp( + tensor: QuantizedTensor, + min: FloatElem, + max: FloatElem, + ) -> QuantizedTensor { + let scheme = tensor.scheme().clone(); + + let tensor_f = Self::dequantize(tensor); + let out_f = B::float_clamp(tensor_f, min, max); + + Self::quantize_dynamic(out_f, &scheme) + } + + /// Subtracts two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. 
+ /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of subtracting the two tensors. + fn q_sub( + lhs: QuantizedTensor, + rhs: QuantizedTensor, + ) -> QuantizedTensor { + // Heuristic: prioritize lhs scheme + let scheme = lhs.scheme().clone(); + + let lhs_f = Self::dequantize(lhs); + let rhs_f = Self::dequantize(rhs); + let out_f = B::float_sub(lhs_f, rhs_f); + + Self::quantize_dynamic(out_f, &scheme) + } + + /// Subtracts a scalar from a tensor. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of subtracting the scalar from the tensor. + fn q_sub_scalar( + lhs: QuantizedTensor, + rhs: FloatElem, + ) -> QuantizedTensor { + let scheme = lhs.scheme().clone(); + + let lhs_f = Self::dequantize(lhs); + let out_f = B::float_sub_scalar(lhs_f, rhs); + + Self::quantize_dynamic(out_f, &scheme) + } + + /// Multiplies two tensors together element-wise. + fn q_mul( + lhs: QuantizedTensor, + rhs: QuantizedTensor, + ) -> QuantizedTensor { + // Heuristic: prioritize lhs scheme + let scheme = lhs.scheme().clone(); + + let lhs_f = Self::dequantize(lhs); + let rhs_f = Self::dequantize(rhs); + let out_f = B::float_mul(lhs_f, rhs_f); + + Self::quantize_dynamic(out_f, &scheme) + } + + /// Multiplies a tensor by a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of multiplying the tensor by the scalar. + fn q_mul_scalar( + lhs: QuantizedTensor, + rhs: FloatElem, + ) -> QuantizedTensor { + let scheme = lhs.scheme().clone(); + + let lhs_f = Self::dequantize(lhs); + let out_f = B::float_mul_scalar(lhs_f, rhs); + + Self::quantize_dynamic(out_f, &scheme) + } + + /// Divides two tensors element-wise. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of dividing the two tensors. + fn q_div( + lhs: QuantizedTensor, + rhs: QuantizedTensor, + ) -> QuantizedTensor { + // Heuristic: prioritize lhs scheme + let scheme = lhs.scheme().clone(); + + let lhs_f = Self::dequantize(lhs); + let rhs_f = Self::dequantize(rhs); + let out_f = B::float_div(lhs_f, rhs_f); + + Self::quantize_dynamic(out_f, &scheme) + } + + /// Divides a tensor by a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of dividing the tensor by the scalar. + fn q_div_scalar( + lhs: QuantizedTensor, + rhs: FloatElem, + ) -> QuantizedTensor { + let scheme = lhs.scheme().clone(); + + let lhs_f = Self::dequantize(lhs); + let out_f = B::float_div_scalar(lhs_f, rhs); + + Self::quantize_dynamic(out_f, &scheme) + } + + /// Computes the modulus of a tensor given a scalar. + /// + /// # Arguments + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of applying the modulus of the scalar to the tensor. + fn q_remainder_scalar( + lhs: QuantizedTensor, + rhs: FloatElem, + ) -> QuantizedTensor { + let scheme = lhs.scheme().clone(); + + let lhs_f = Self::dequantize(lhs); + let out_f = B::float_remainder_scalar(lhs_f, rhs); + + Self::quantize_dynamic(out_f, &scheme) + } + + /// Multiplies two tensors together using matrix multiplication. 
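+    ///
+    /// The provided implementation is a sketch of the dequantize → float op → requantize
+    /// fallback used by most defaults in this trait: both operands are dequantized,
+    /// multiplied in floating point, and the result is requantized dynamically, roughly:
+    ///
+    /// ```ignore
+    /// let scheme = lhs.scheme().clone();
+    /// let out = Self::quantize_dynamic(
+    ///     B::float_matmul(Self::dequantize(lhs), Self::dequantize(rhs)),
+    ///     &scheme,
+    /// );
+    /// ```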
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side tensor.
+    ///
+    /// # Returns
+    ///
+    /// The result of multiplying the two tensors together using matrix multiplication.
+    fn q_matmul(
+        lhs: QuantizedTensor,
+        rhs: QuantizedTensor,
+    ) -> QuantizedTensor {
+        // Heuristic: prioritize lhs scheme
+        let scheme = lhs.scheme().clone();
+
+        let lhs_f = Self::dequantize(lhs);
+        let rhs_f = Self::dequantize(rhs);
+        let out_f = B::float_matmul(lhs_f, rhs_f);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Negates a tensor element-wise.
+    fn q_neg(tensor: QuantizedTensor) -> QuantizedTensor {
+        let scheme = tensor.scheme().clone();
+
+        let tensor_f = Self::dequantize(tensor);
+        let out_f = B::float_neg(tensor_f);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Calculates the reciprocals element-wise.
+    fn q_recip(tensor: QuantizedTensor) -> QuantizedTensor {
+        let scheme = tensor.scheme().clone();
+
+        let tensor_f = Self::dequantize(tensor);
+        let out_f = B::float_recip(tensor_f);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Transposes a tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to transpose.
+    ///
+    /// # Returns
+    ///
+    /// The transposed tensor.
+    fn q_transpose(tensor: QuantizedTensor) -> QuantizedTensor {
+        Self::q_swap_dims(tensor, D - 2, D - 1)
+    }
+
+    /// Swaps two dimensions of a tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to swap the dimensions of.
+    /// * `dim1` - The first dimension to swap.
+    /// * `dim2` - The second dimension to swap.
+    ///
+    /// # Returns
+    ///
+    /// The tensor with the dimensions swapped.
+    fn q_swap_dims(
+        tensor: QuantizedTensor,
+        dim1: usize,
+        dim2: usize,
+    ) -> QuantizedTensor;
+
+    /// Permutes the dimensions of a tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to permute the dimensions of.
+    /// * `axes` - The new order of the dimensions.
+    ///
+    /// # Returns
+    ///
+    /// The tensor with the dimensions permuted.
+    fn q_permute(
+        tensor: QuantizedTensor,
+        axes: [usize; D],
+    ) -> QuantizedTensor;
+
+    /// Reverse the order of elements in a tensor along the given axes.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to reverse.
+    /// * `axes` - The axes to reverse.
+    ///
+    /// # Returns
+    ///
+    /// The tensor with the elements reversed.
+    fn q_flip(
+        tensor: QuantizedTensor,
+        axes: &[usize],
+    ) -> QuantizedTensor;
+
+    /// Gather elements from a tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `dim` - The dimension to gather from.
+    /// * `tensor` - The tensor to gather from.
+    /// * `indices` - The indices to gather.
+    ///
+    /// # Returns
+    ///
+    /// The gathered elements.
+    fn q_gather(
+        dim: usize,
+        tensor: QuantizedTensor,
+        indices: IntTensor,
+    ) -> QuantizedTensor;
+
+    /// Scatter elements into a tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `dim` - The dimension to scatter into.
+    /// * `tensor` - The tensor to scatter into.
+    /// * `indices` - The indices to scatter into.
+    /// * `value` - The value to scatter.
+    ///
+    /// # Returns
+    ///
+    /// The tensor with the scattered elements.
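+    ///
+    /// Note that in the default implementation below only `tensor` and `value` are
+    /// dequantized; `indices` is an `IntTensor` and is forwarded to `float_scatter` as-is.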
+    fn q_scatter(
+        dim: usize,
+        tensor: QuantizedTensor,
+        indices: IntTensor,
+        value: QuantizedTensor,
+    ) -> QuantizedTensor {
+        let scheme = tensor.scheme().clone();
+
+        let tensor_f = Self::dequantize(tensor);
+        let value_f = Self::dequantize(value);
+        let out_f = B::float_scatter(dim, tensor_f, indices, value_f);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Select tensor elements along the given dimension corresponding to the given indices.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to select from.
+    /// * `dim` - The dimension to select from.
+    /// * `indices` - The indices to select.
+    ///
+    /// # Returns
+    ///
+    /// The selected elements.
+    fn q_select(
+        tensor: QuantizedTensor,
+        dim: usize,
+        indices: IntTensor,
+    ) -> QuantizedTensor;
+
+    /// Assign the selected elements along the given dimension corresponding to the given indices
+    /// to the given value.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to select from.
+    /// * `dim` - The dimension to select from.
+    /// * `indices` - The indices to select.
+    /// * `value` - The value to assign.
+    ///
+    /// # Returns
+    ///
+    /// The tensor with the selected elements assigned to the given value.
+    fn q_select_assign(
+        tensor: QuantizedTensor,
+        dim: usize,
+        indices: IntTensor,
+        value: QuantizedTensor,
+    ) -> QuantizedTensor {
+        let scheme = tensor.scheme().clone();
+
+        let tensor_f = Self::dequantize(tensor);
+        let value_f = Self::dequantize(value);
+        let out_f = B::float_select_assign(tensor_f, dim, indices, value_f);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Select tensor elements corresponding to the given ranges.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to select from.
+    /// * `ranges` - The ranges to select.
+    ///
+    /// # Returns
+    ///
+    /// The selected elements in a new tensor.
+    fn q_slice(
+        tensor: QuantizedTensor,
+        ranges: [Range; D2],
+    ) -> QuantizedTensor;
+
+    /// Assign the selected elements corresponding to the given ranges to the given value.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to select from.
+    /// * `ranges` - The ranges to select.
+    /// * `value` - The value to assign.
+    ///
+    /// # Returns
+    ///
+    /// The tensor with the selected elements assigned to the given value.
+    fn q_slice_assign(
+        tensor: QuantizedTensor,
+        ranges: [Range; D2],
+        value: QuantizedTensor,
+    ) -> QuantizedTensor {
+        let scheme = tensor.scheme().clone();
+
+        let tensor_f = Self::dequantize(tensor);
+        let value_f = Self::dequantize(value);
+        let out_f = B::float_slice_assign(tensor_f, ranges, value_f);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Update the given tensor with the value tensor where the mask is true.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to select from.
+    /// * `mask` - The boolean mask to select with.
+    /// * `value` - The value to assign to the selected elements from the value tensor.
+    ///
+    /// # Returns
+    ///
+    /// The tensor with the selected elements assigned to the given value.
+    fn q_mask_where(
+        tensor: QuantizedTensor,
+        mask: BoolTensor,
+        value: QuantizedTensor,
+    ) -> QuantizedTensor {
+        let scheme = tensor.scheme().clone();
+
+        let tensor_f = Self::dequantize(tensor);
+        let value_f = Self::dequantize(value);
+        let out_f = B::float_mask_where(tensor_f, mask, value_f);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Update the given tensor with the value where the mask is true.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to select from.
+ /// * `mask` - The boolean mask to select with. + /// * `value` - The value to assign to the selected elements. + /// + /// # Returns + /// + /// The tensor with the selected elements assigned to the given value. + fn q_mask_fill( + tensor: QuantizedTensor, + mask: BoolTensor, + value: FloatElem, + ) -> QuantizedTensor { + let scheme = tensor.scheme().clone(); + + let tensor_f = Self::dequantize(tensor); + let out_f = B::float_mask_fill(tensor_f, mask, value); + + Self::quantize_dynamic(out_f, &scheme) + } + + /// Equal comparison of two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. + fn q_equal( + lhs: QuantizedTensor, + rhs: QuantizedTensor, + ) -> BoolTensor { + let lhs_f = Self::dequantize(lhs); + let rhs_f = Self::dequantize(rhs); + + B::float_equal(lhs_f, rhs_f) + } + + /// Element-wise non-equality comparison. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. + fn q_not_equal( + lhs: QuantizedTensor, + rhs: QuantizedTensor, + ) -> BoolTensor { + let lhs_f = Self::dequantize(lhs); + let rhs_f = Self::dequantize(rhs); + + B::float_not_equal(lhs_f, rhs_f) + } + + /// Equal comparison of a tensor and a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. + fn q_equal_elem( + lhs: QuantizedTensor, + rhs: FloatElem, + ) -> BoolTensor { + let lhs_f = Self::dequantize(lhs); + + B::float_equal_elem(lhs_f, rhs) + } + + /// Element-wise non-equality comparison with a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. + fn q_not_equal_elem( + lhs: QuantizedTensor, + rhs: FloatElem, + ) -> BoolTensor { + let lhs_f = Self::dequantize(lhs); + + B::float_not_equal_elem(lhs_f, rhs) + } + + /// Greater than comparison of two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. + fn q_greater( + lhs: QuantizedTensor, + rhs: QuantizedTensor, + ) -> BoolTensor { + let lhs_f = Self::dequantize(lhs); + let rhs_f = Self::dequantize(rhs); + + B::float_greater(lhs_f, rhs_f) + } + + /// Greater than comparison of a tensor and a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. + fn q_greater_elem( + lhs: QuantizedTensor, + rhs: FloatElem, + ) -> BoolTensor { + let lhs_f = Self::dequantize(lhs); + + B::float_greater_elem(lhs_f, rhs) + } + + /// Greater than or equal comparison of two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// A boolean tensor with the result of the comparison. 
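+    ///
+    /// As with the other comparison defaults, both sides are dequantized first, so tensors
+    /// are compared by the real values they represent rather than by their raw integer
+    /// codes. For example (a sketch with assumed schemes): 0.5 stored with scale 0.1
+    /// (q = 5) and 0.5 stored with scale 0.05 (q = 10) dequantize to the same value.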
+    fn q_greater_equal(
+        lhs: QuantizedTensor,
+        rhs: QuantizedTensor,
+    ) -> BoolTensor {
+        let lhs_f = Self::dequantize(lhs);
+        let rhs_f = Self::dequantize(rhs);
+
+        B::float_greater_equal(lhs_f, rhs_f)
+    }
+
+    /// Greater than or equal comparison of a tensor and a scalar.
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side scalar.
+    ///
+    /// # Returns
+    ///
+    /// A boolean tensor with the result of the comparison.
+    fn q_greater_equal_elem(
+        lhs: QuantizedTensor,
+        rhs: FloatElem,
+    ) -> BoolTensor {
+        let lhs_f = Self::dequantize(lhs);
+
+        B::float_greater_equal_elem(lhs_f, rhs)
+    }
+
+    /// Less than comparison of two tensors.
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side tensor.
+    ///
+    /// # Returns
+    ///
+    /// A boolean tensor with the result of the comparison.
+    fn q_lower(
+        lhs: QuantizedTensor,
+        rhs: QuantizedTensor,
+    ) -> BoolTensor {
+        let lhs_f = Self::dequantize(lhs);
+        let rhs_f = Self::dequantize(rhs);
+
+        B::float_lower(lhs_f, rhs_f)
+    }
+
+    /// Less than comparison of a tensor and a scalar.
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side scalar.
+    ///
+    /// # Returns
+    ///
+    /// A boolean tensor with the result of the comparison.
+    fn q_lower_elem(
+        lhs: QuantizedTensor,
+        rhs: FloatElem,
+    ) -> BoolTensor {
+        let lhs_f = Self::dequantize(lhs);
+
+        B::float_lower_elem(lhs_f, rhs)
+    }
+
+    /// Less than or equal comparison of two tensors.
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side tensor.
+    ///
+    /// # Returns
+    ///
+    /// A boolean tensor with the result of the comparison.
+    fn q_lower_equal(
+        lhs: QuantizedTensor,
+        rhs: QuantizedTensor,
+    ) -> BoolTensor {
+        let lhs_f = Self::dequantize(lhs);
+        let rhs_f = Self::dequantize(rhs);
+
+        B::float_lower_equal(lhs_f, rhs_f)
+    }
+
+    /// Less than or equal comparison of a tensor and a scalar.
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side scalar.
+    ///
+    /// # Returns
+    ///
+    /// A boolean tensor with the result of the comparison.
+    fn q_lower_equal_elem(
+        lhs: QuantizedTensor,
+        rhs: FloatElem,
+    ) -> BoolTensor {
+        let lhs_f = Self::dequantize(lhs);
+
+        B::float_lower_equal_elem(lhs_f, rhs)
+    }
+
+    /// Sum of all elements in a tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to sum.
+    ///
+    /// # Returns
+    ///
+    /// A scalar tensor with the sum of all elements in `tensor`.
+    fn q_sum(tensor: QuantizedTensor) -> QuantizedTensor {
+        let scheme = tensor.scheme().clone();
+
+        let tensor_f = Self::dequantize(tensor);
+        let out_f = B::float_sum(tensor_f);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Sum of all elements in a tensor along a dimension.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to sum.
+    /// * `dim` - The dimension along which to sum.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the sum of all elements in `tensor` along `dim`.
+    fn q_sum_dim(
+        tensor: QuantizedTensor,
+        dim: usize,
+    ) -> QuantizedTensor {
+        let scheme = tensor.scheme().clone();
+
+        let tensor_f = Self::dequantize(tensor);
+        let out_f = B::float_sum_dim(tensor_f, dim);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Product of all elements in a tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to take the product of.
+    ///
+    /// # Returns
+    ///
+    /// A scalar tensor with the product of all elements in `tensor`.
+    fn q_prod(tensor: QuantizedTensor) -> QuantizedTensor {
+        let scheme = tensor.scheme().clone();
+
+        let tensor_f = Self::dequantize(tensor);
+        let out_f = B::float_prod(tensor_f);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Product of all elements in a tensor along a dimension.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to take the product of.
+    /// * `dim` - The dimension along which to take the product.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the product of all elements in `tensor` along `dim`.
+    fn q_prod_dim(
+        tensor: QuantizedTensor,
+        dim: usize,
+    ) -> QuantizedTensor {
+        let scheme = tensor.scheme().clone();
+
+        let tensor_f = Self::dequantize(tensor);
+        let out_f = B::float_prod_dim(tensor_f, dim);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Mean of all elements in a tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to take the mean of.
+    ///
+    /// # Returns
+    ///
+    /// A scalar tensor with the mean of all elements in `tensor`.
+    fn q_mean(tensor: QuantizedTensor) -> QuantizedTensor {
+        let scheme = tensor.scheme().clone();
+
+        let tensor_f = Self::dequantize(tensor);
+        let out_f = B::float_mean(tensor_f);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Mean of all elements in a tensor along a dimension.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to take the mean of.
+    /// * `dim` - The dimension along which to take the mean.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the mean of all elements in `tensor` along `dim`.
+    fn q_mean_dim(
+        tensor: QuantizedTensor,
+        dim: usize,
+    ) -> QuantizedTensor {
+        let scheme = tensor.scheme().clone();
+
+        let tensor_f = Self::dequantize(tensor);
+        let out_f = B::float_mean_dim(tensor_f, dim);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Returns a new tensor with exponential values.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to exponentiate.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape as `tensor` with exponential values.
+    fn q_exp(tensor: QuantizedTensor) -> QuantizedTensor {
+        let scheme = tensor.scheme().clone();
+
+        let tensor_f = Self::dequantize(tensor);
+        let out_f = B::float_exp(tensor_f);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Returns a new tensor with natural logarithm values.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to take the logarithm of.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape as `tensor` with natural logarithm values.
+    fn q_log(tensor: QuantizedTensor) -> QuantizedTensor {
+        let scheme = tensor.scheme().clone();
+
+        let tensor_f = Self::dequantize(tensor);
+        let out_f = B::float_log(tensor_f);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Returns a new tensor with logarithm values of (1 + Xi).
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to take the logarithm of.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
+    fn q_log1p(tensor: QuantizedTensor) -> QuantizedTensor {
+        let scheme = tensor.scheme().clone();
+
+        let tensor_f = Self::dequantize(tensor);
+        let out_f = B::float_log1p(tensor_f);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Element-wise power with a FloatTensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side float tensor.
+    ///
+    /// # Returns
+    ///
+    /// The elements of `lhs` raised to the power of the elements of `rhs`.
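+    ///
+    /// Note that `rhs` is a plain (non-quantized) `FloatTensor`, so only `lhs` goes through
+    /// the dequantize → float op → requantize round-trip in the default implementation
+    /// below, roughly:
+    ///
+    /// ```ignore
+    /// let scheme = lhs.scheme().clone();
+    /// let out = Self::quantize_dynamic(B::float_powf(Self::dequantize(lhs), rhs), &scheme);
+    /// ```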
+    fn q_powf(
+        lhs: QuantizedTensor,
+        rhs: FloatTensor,
+    ) -> QuantizedTensor {
+        let scheme = lhs.scheme().clone();
+
+        let lhs_f = Self::dequantize(lhs);
+        let out_f = B::float_powf(lhs_f, rhs);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Element-wise power with an IntTensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side int tensor.
+    ///
+    /// # Returns
+    ///
+    /// The elements of `lhs` raised to the power of the elements of `rhs`.
+    fn q_powi(
+        lhs: QuantizedTensor,
+        rhs: IntTensor,
+    ) -> QuantizedTensor {
+        let scheme = lhs.scheme().clone();
+
+        let lhs_f = Self::dequantize(lhs);
+        let out_f = B::float_powi(lhs_f, rhs);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Element-wise power with an int scalar.
+    ///
+    /// # Arguments
+    ///
+    /// * `lhs` - The left hand side tensor.
+    /// * `rhs` - The right hand side scalar.
+    ///
+    /// # Returns
+    ///
+    /// The elements of `lhs` raised to the value of `rhs`.
+    fn q_powi_scalar(
+        lhs: QuantizedTensor,
+        rhs: IntElem,
+    ) -> QuantizedTensor {
+        let scheme = lhs.scheme().clone();
+
+        let lhs_f = Self::dequantize(lhs);
+        let out_f = B::float_powi_scalar(lhs_f, rhs);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Element-wise power with a float scalar.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to exponentiate.
+    /// * `value` - The exponent.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
+    fn q_powf_scalar(
+        tensor: QuantizedTensor,
+        value: f32,
+    ) -> QuantizedTensor {
+        let scheme = tensor.scheme().clone();
+
+        let tensor_f = Self::dequantize(tensor);
+        let out_f = B::float_powf_scalar(tensor_f, value);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Returns a new tensor with square root values.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to take the square root of.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape as `tensor` with square root values.
+    fn q_sqrt(tensor: QuantizedTensor) -> QuantizedTensor {
+        let scheme = tensor.scheme().clone();
+
+        let tensor_f = Self::dequantize(tensor);
+        let out_f = B::float_sqrt(tensor_f);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Returns a new tensor with absolute values.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to take the absolute value of.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape as `tensor` with absolute values.
+    fn q_abs(tensor: QuantizedTensor) -> QuantizedTensor {
+        let scheme = tensor.scheme().clone();
+
+        let tensor_f = Self::dequantize(tensor);
+        let out_f = B::float_abs(tensor_f);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Returns a new tensor with cosine values.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to take the cosine of.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape as `tensor` with cosine values.
+    fn q_cos(tensor: QuantizedTensor) -> QuantizedTensor {
+        let scheme = tensor.scheme().clone();
+
+        let tensor_f = Self::dequantize(tensor);
+        let out_f = B::float_cos(tensor_f);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Returns a new tensor with sine values.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to take the sine of.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape as `tensor` with sine values.
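+    ///
+    /// Since sine outputs lie in [-1, 1], the dynamic requantization step derives a fresh
+    /// range for the result; as a rough sketch, a full [-1, 1] output quantized to an
+    /// 8-bit affine scheme would get scale = (1 - (-1)) / 255 ≈ 0.0078.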
+    fn q_sin(tensor: QuantizedTensor) -> QuantizedTensor {
+        let scheme = tensor.scheme().clone();
+
+        let tensor_f = Self::dequantize(tensor);
+        let out_f = B::float_sin(tensor_f);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Returns a new tensor with hyperbolic tangent values.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to take the hyperbolic tangent of.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
+    fn q_tanh(tensor: QuantizedTensor) -> QuantizedTensor {
+        let scheme = tensor.scheme().clone();
+
+        let tensor_f = Self::dequantize(tensor);
+        let out_f = B::float_tanh(tensor_f);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Returns a new tensor with the error function values.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to take the error function of.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape as `tensor` with error function values.
+    fn q_erf(tensor: QuantizedTensor) -> QuantizedTensor {
+        let scheme = tensor.scheme().clone();
+
+        let tensor_f = Self::dequantize(tensor);
+        let out_f = B::float_erf(tensor_f);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Concatenates tensors along a dimension.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensors` - The tensors to concatenate.
+    /// * `dim` - The dimension along which to concatenate.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the concatenated tensors along `dim`.
+    fn q_cat(
+        tensors: Vec>,
+        dim: usize,
+    ) -> QuantizedTensor {
+        // Heuristic: prioritize first tensor scheme
+        let scheme = tensors.first().unwrap().scheme().clone();
+
+        let tensor_f = tensors
+            .into_iter()
+            .map(|tensor| Self::dequantize(tensor))
+            .collect();
+
+        let out_f = B::float_cat(tensor_f, dim);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Gets the indices of the maximum elements of a tensor along an axis.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to get the maximum elements of.
+    /// * `dim` - The dimension along which to get the maximum elements.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the indices of the maximum elements of `tensor` along `dim`.
+    fn q_argmax(tensor: QuantizedTensor, dim: usize) -> IntTensor;
+
+    /// Gets the indices of the minimum elements of a tensor along an axis.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to get the minimum elements of.
+    /// * `dim` - The dimension along which to get the minimum elements.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the indices of the minimum elements of `tensor` along `dim`.
+    fn q_argmin(tensor: QuantizedTensor, dim: usize) -> IntTensor;
+
+    /// Gets the maximum element of a tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to get the maximum elements of.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the maximum element of `tensor`.
+    fn q_max(tensor: QuantizedTensor) -> QuantizedTensor {
+        let shape = B::q_shape(&tensor);
+        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
+
+        B::q_max_dim(tensor, 0)
+    }
+
+    /// Gets the maximum elements of a tensor along an axis.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to get the maximum elements of.
+    /// * `dim` - The dimension along which to get the maximum elements.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the maximum elements of `tensor` along `dim`.
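+    ///
+    /// Unlike the arithmetic defaults, the implementation below stays in the quantized
+    /// domain: it composes `q_argmax` and `q_gather`, so no dequantize/requantize
+    /// round-trip occurs and the original quantization parameters are preserved.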
+    fn q_max_dim(
+        tensor: QuantizedTensor,
+        dim: usize,
+    ) -> QuantizedTensor {
+        let index = B::q_argmax(tensor.clone(), dim);
+
+        B::q_gather(dim, tensor, index)
+    }
+
+    /// Gets the maximum elements of a tensor along an axis and their indices.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to get the maximum elements of.
+    /// * `dim` - The dimension along which to get the maximum elements.
+    ///
+    /// # Returns
+    ///
+    /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
+    fn q_max_dim_with_indices(
+        tensor: QuantizedTensor,
+        dim: usize,
+    ) -> (QuantizedTensor, IntTensor) {
+        let index = B::q_argmax(tensor.clone(), dim);
+        let values = B::q_gather(dim, tensor, index.clone());
+
+        (values, index)
+    }
+
+    /// Gets the minimum element of a tensor.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to get the minimum elements of.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the minimum element of `tensor`.
+    fn q_min(tensor: QuantizedTensor) -> QuantizedTensor {
+        let shape = B::q_shape(&tensor);
+        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
+
+        B::q_min_dim(tensor, 0)
+    }
+
+    /// Gets the minimum elements of a tensor along an axis.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to get the minimum elements of.
+    /// * `dim` - The dimension along which to get the minimum elements.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the minimum elements of `tensor` along `dim`.
+    fn q_min_dim(
+        tensor: QuantizedTensor,
+        dim: usize,
+    ) -> QuantizedTensor {
+        let index = B::q_argmin(tensor.clone(), dim);
+
+        B::q_gather(dim, tensor, index)
+    }
+
+    /// Gets the minimum elements of a tensor along an axis and their indices.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to get the minimum elements of.
+    /// * `dim` - The dimension along which to get the minimum elements.
+    ///
+    /// # Returns
+    ///
+    /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
+    fn q_min_dim_with_indices(
+        tensor: QuantizedTensor,
+        dim: usize,
+    ) -> (QuantizedTensor, IntTensor) {
+        let index = B::q_argmin(tensor.clone(), dim);
+        let values = B::q_gather(dim, tensor, index.clone());
+
+        (values, index)
+    }
+
+    /// Returns a new tensor with the given dimension narrowed to the given range.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to narrow.
+    /// * `dim` - The dimension along which the tensor will be narrowed.
+    /// * `start` - The starting point of the given range.
+    /// * `length` - The length of the range.
+    ///
+    /// # Panics
+    ///
+    /// - If the dimension is greater than the number of dimensions of the tensor.
+    /// - If the given range exceeds the number of elements on the given dimension.
+    ///
+    /// # Returns
+    ///
+    /// A new tensor with the given dimension narrowed to the given range.
+    fn q_narrow(
+        tensor: QuantizedTensor,
+        dim: usize,
+        start: usize,
+        length: usize,
+    ) -> QuantizedTensor {
+        let scheme = tensor.scheme().clone();
+
+        let tensor_f = Self::dequantize(tensor);
+        let out_f = B::float_narrow(tensor_f, dim, start, length);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Split the tensor along the given dimension into chunks.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor.
+    /// * `chunks` - The number of chunks to be produced.
+    /// * `dim` - The dimension along which the tensor will be split.
+    ///
+    /// # Returns
+    ///
+    /// A vector of tensors.
+    fn q_chunk<const D: usize>(
+        tensor: QuantizedTensor<B, D>,
+        chunks: usize,
+        dim: usize,
+    ) -> Vec<QuantizedTensor<B, D>> {
+        let scheme = tensor.scheme().clone();
+
+        let tensor_f = Self::dequantize(tensor);
+        let out_f = B::float_chunk(tensor_f, chunks, dim);
+
+        out_f
+            .into_iter()
+            .map(|tensor| Self::quantize_dynamic(tensor, &scheme))
+            .collect()
+    }
+
+    /// Tests if any element in the `tensor` evaluates to True.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to test.
+    ///
+    /// # Returns
+    ///
+    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
+    fn q_any<const D: usize>(tensor: QuantizedTensor<B, D>) -> BoolTensor<B, 1> {
+        let tensor_f = Self::dequantize(tensor);
+        B::float_any(tensor_f)
+    }
+
+    /// Tests if any element in the `tensor` evaluates to True along a given dimension `dim`.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to test.
+    /// * `dim` - The axis along which to test.
+    ///
+    /// # Returns
+    ///
+    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
+    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
+    /// input evaluates to True, False otherwise.
+    fn q_any_dim<const D: usize>(tensor: QuantizedTensor<B, D>, dim: usize) -> BoolTensor<B, D> {
+        let tensor_f = Self::dequantize(tensor);
+        B::float_any_dim(tensor_f, dim)
+    }
+
+    /// Tests if all elements in the `tensor` evaluate to True.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to test.
+    ///
+    /// # Returns
+    ///
+    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
+    /// evaluate to True, False otherwise.
+    fn q_all<const D: usize>(tensor: QuantizedTensor<B, D>) -> BoolTensor<B, 1> {
+        let tensor_f = Self::dequantize(tensor);
+        B::float_all(tensor_f)
+    }
+
+    /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to test.
+    /// * `dim` - The axis along which to test.
+    ///
+    /// # Returns
+    ///
+    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
+    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
+    /// evaluate to True, False otherwise.
+    fn q_all_dim<const D: usize>(tensor: QuantizedTensor<B, D>, dim: usize) -> BoolTensor<B, D> {
+        let tensor_f = Self::dequantize(tensor);
+        B::float_all_dim(tensor_f, dim)
+    }
+
+    /// Broadcasts the `tensor` to the given `shape`.
+    fn q_expand<const D1: usize, const D2: usize>(
+        tensor: QuantizedTensor<B, D1>,
+        shape: Shape<D2>,
+    ) -> QuantizedTensor<B, D2>;
+
+    /// Sort the elements of the input `tensor` by value along a given dimension.
+    ///
+    /// This sort is unstable (i.e., may reorder equal elements).
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The input tensor.
+    /// * `dim` - The axis along which to sort.
+    /// * `descending` - The sorting order.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
+    fn q_sort<const D: usize>(
+        tensor: QuantizedTensor<B, D>,
+        dim: usize,
+        descending: bool,
+    ) -> QuantizedTensor<B, D> {
+        // Default implementation. Backends can sort on the int values since qparams remain the same.
+        let scheme = tensor.scheme().clone();
+
+        let tensor_f = Self::dequantize(tensor);
+        let out_f = B::float_sort(tensor_f, dim, descending);
+
+        Self::quantize_dynamic(out_f, &scheme)
+    }
+
+    /// Sort the elements of the input `tensor` by value along a given dimension.
+    ///
+    /// This sort is unstable (i.e., may reorder equal elements).
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The input tensor.
+    /// * `dim` - The axis along which to sort.
+    /// * `descending` - The sorting order.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape as the input tensor and corresponding indices, where
+    /// the elements are sorted by value and the indices map back to the original input tensor.
+    fn q_sort_with_indices<const D: usize>(
+        tensor: QuantizedTensor<B, D>,
+        dim: usize,
+        descending: bool,
+    ) -> (QuantizedTensor<B, D>, IntTensor<B, D>) {
+        // Default implementation. Backends can sort on the int values since qparams remain the same.
+        let scheme = tensor.scheme().clone();
+
+        let tensor_f = Self::dequantize(tensor);
+        let (out_f, indices) = B::float_sort_with_indices(tensor_f, dim, descending);
+
+        (Self::quantize_dynamic(out_f, &scheme), indices)
+    }
+
+    /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
+    ///
+    /// This sort is unstable (i.e., may reorder equal elements).
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The input tensor.
+    /// * `dim` - The axis along which to sort.
+    /// * `descending` - The sorting order.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape as the input tensor, where the indices map back to the original input tensor.
+    fn q_argsort<const D: usize>(
+        tensor: QuantizedTensor<B, D>,
+        dim: usize,
+        descending: bool,
+    ) -> IntTensor<B, D> {
+        // Default implementation. Backends can sort on the int values since qparams remain the same.
+        let tensor_f = Self::dequantize(tensor);
+        B::float_argsort(tensor_f, dim, descending)
+    }
 }
diff --git a/crates/burn-tensor/src/tensor/quantization/calibration.rs b/crates/burn-tensor/src/tensor/quantization/calibration.rs
index c8060f6547..b7aac15d34 100644
--- a/crates/burn-tensor/src/tensor/quantization/calibration.rs
+++ b/crates/burn-tensor/src/tensor/quantization/calibration.rs
@@ -32,3 +32,25 @@ impl Calibration for MinMaxCalibration {
         CalibrationRange { min, max }
     }
 }
+
+// Observers keep a running min/max, so for static quantization this can be computed multiple times w/ representative data to get the "global" min/max
+
+// pub struct PerChannelCalibrationSettings {
+//     pub dtype: QuantizationType,
+//     pub symmetric: bool,
+// }
+
+// For now, we only support static quantization. Since the tensor is dequantized to a float at the first operation, the remaining operations will all be performed on floats anyways.
+// But to test dynamic quantization, just make the first layer use dynamic quantization.
+
+/*
+let q_activation = Quantizer {
+    calibration: MinMaxCalibration {scheme: QuantizationScheme::PerTensorAffine(QuantizationType::QInt8)},
+    dynamic: true,
+};
+let q_weights = Quantizer {
+    calibration: MinMaxCalibration {scheme: QuantizationScheme::PerTensorAffine(QuantizationType::QInt8)},
+    dynamic: false,
+}
+
+*/
diff --git a/crates/burn-tensor/src/tensor/quantization/scheme.rs b/crates/burn-tensor/src/tensor/quantization/scheme.rs
index 03a5da58f5..f93c79a5e5 100644
--- a/crates/burn-tensor/src/tensor/quantization/scheme.rs
+++ b/crates/burn-tensor/src/tensor/quantization/scheme.rs
@@ -1,6 +1,6 @@
-use crate::{backend::Backend, Tensor};
+use crate::{backend::Backend, Tensor, TensorPrimitive};
 
-use super::{CalibrationRange, QuantizationParameters};
+use super::{CalibrationRange, QuantizationParameters, QuantizationParametersPrimitive};
 
 /// Quantization data type.
#[derive(Clone, Debug, PartialEq)] @@ -65,4 +65,17 @@ impl QuantizationScheme { }, } } + + /// Compute the quantization parameters. + pub(crate) fn compute_q_params_primitive( + &self, + min: B::FloatTensorPrimitive<1>, + max: B::FloatTensorPrimitive<1>, + ) -> QuantizationParametersPrimitive { + let range = CalibrationRange { + min: Tensor::from_primitive(TensorPrimitive::Float(min)), + max: Tensor::from_primitive(TensorPrimitive::Float(max)), + }; + self.compute_q_params(range).into() + } } From 65e83824a5a22b2c3c0bee426f74348c58796ee6 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Fri, 19 Jul 2024 09:53:20 -0400 Subject: [PATCH 02/16] Refactor q_* ops w/ dequant_op_quant macro --- crates/burn-tensor/src/tensor/ops/qtensor.rs | 400 +++++++++---------- 1 file changed, 200 insertions(+), 200 deletions(-) diff --git a/crates/burn-tensor/src/tensor/ops/qtensor.rs b/crates/burn-tensor/src/tensor/ops/qtensor.rs index 0a637c7b5c..37502dafe2 100644 --- a/crates/burn-tensor/src/tensor/ops/qtensor.rs +++ b/crates/burn-tensor/src/tensor/ops/qtensor.rs @@ -9,6 +9,37 @@ use crate::{ use super::{BoolTensor, FloatElem, FloatTensor, IntElem, IntTensor, QuantizedTensor}; +/// Automatically applies dequantization -> float operation -> quantization. +#[macro_export] +macro_rules! dequant_op_quant { + // Binary tensor float op w/ lhs & rhs + ( + ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr + ) => {{ + // Heuristic: prioritize lhs scheme + let scheme = $t1.scheme().clone(); + + let t1_f = <$ty>::dequantize($t1); + let t2_f = <$ty>::dequantize($t2); + #[allow(clippy::redundant_closure_call)] + let out_f = $float_op(t1_f, t2_f); + + <$ty>::quantize_dynamic(out_f, &scheme) + }}; + // Unary tensor float op + ( + ty $ty:ty, float_op $float_op:expr, $tensor:expr + ) => {{ + let scheme = $tensor.scheme().clone(); + + let tensor_f = <$ty>::dequantize($tensor); + #[allow(clippy::redundant_closure_call)] + let out_f = $float_op(tensor_f); + + <$ty>::quantize_dynamic(out_f, &scheme) + }}; +} + /// Quantized Tensor API for basic operations, see [tensor](crate::Tensor) /// for documentation on each function. pub trait QTensorOps { @@ -146,14 +177,13 @@ pub trait QTensorOps { lhs: QuantizedTensor, rhs: QuantizedTensor, ) -> QuantizedTensor { - // Heuristic: prioritize lhs scheme - let scheme = lhs.scheme().clone(); - - let lhs_f = Self::dequantize(lhs); - let rhs_f = Self::dequantize(rhs); - let out_f = B::float_add(lhs_f, rhs_f); - - Self::quantize_dynamic(out_f, &scheme) + // dequant_op_quant!(Self, B::float_add, lhs, rhs) + dequant_op_quant!( + ty Self, + float_op |lhs, rhs| B::float_add(lhs, rhs), + lhs, + rhs + ) } /// Adds a scalar to a tensor. @@ -260,14 +290,13 @@ pub trait QTensorOps { lhs: QuantizedTensor, rhs: QuantizedTensor, ) -> QuantizedTensor { - // Heuristic: prioritize lhs scheme - let scheme = lhs.scheme().clone(); - - let lhs_f = Self::dequantize(lhs); - let rhs_f = Self::dequantize(rhs); - let out_f = B::float_sub(lhs_f, rhs_f); - - Self::quantize_dynamic(out_f, &scheme) + // dequant_op_quant!(Self, B::float_sub, lhs, rhs) + dequant_op_quant!( + ty Self, + float_op |lhs, rhs| B::float_sub(lhs, rhs), + lhs, + rhs + ) } /// Subtracts a scalar from a tensor. 
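To make the refactor above easier to review: the binary arm of `dequant_op_quant!` hand-expands to roughly the sketch below, written here as a hypothetical free function (`q_add_expanded` is an illustrative name; the real code is the trait's default `q_add` implementation).

    // Rough hand-expansion of:
    //     dequant_op_quant!(ty Self, float_op |lhs, rhs| B::float_add(lhs, rhs), lhs, rhs)
    fn q_add_expanded<B: Backend, const D: usize>(
        lhs: QuantizedTensor<B, D>,
        rhs: QuantizedTensor<B, D>,
    ) -> QuantizedTensor<B, D> {
        // Heuristic: prioritize lhs scheme
        let scheme = lhs.scheme().clone();

        let lhs_f = B::dequantize(lhs);
        let rhs_f = B::dequantize(rhs);
        // The `float_op` closure is invoked directly on the dequantized tensors.
        let out_f = (|lhs, rhs| B::float_add(lhs, rhs))(lhs_f, rhs_f);

        B::quantize_dynamic(out_f, &scheme)
    }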
@@ -297,14 +326,13 @@ pub trait QTensorOps { lhs: QuantizedTensor, rhs: QuantizedTensor, ) -> QuantizedTensor { - // Heuristic: prioritize lhs scheme - let scheme = lhs.scheme().clone(); - - let lhs_f = Self::dequantize(lhs); - let rhs_f = Self::dequantize(rhs); - let out_f = B::float_mul(lhs_f, rhs_f); - - Self::quantize_dynamic(out_f, &scheme) + // dequant_op_quant!(Self, B::float_mul, lhs, rhs) + dequant_op_quant!( + ty Self, + float_op |lhs, rhs| B::float_mul(lhs, rhs), + lhs, + rhs + ) } /// Multiplies a tensor by a scalar. @@ -343,14 +371,13 @@ pub trait QTensorOps { lhs: QuantizedTensor, rhs: QuantizedTensor, ) -> QuantizedTensor { - // Heuristic: prioritize lhs scheme - let scheme = lhs.scheme().clone(); - - let lhs_f = Self::dequantize(lhs); - let rhs_f = Self::dequantize(rhs); - let out_f = B::float_div(lhs_f, rhs_f); - - Self::quantize_dynamic(out_f, &scheme) + // dequant_op_quant!(Self, B::float_div, lhs, rhs) + dequant_op_quant!( + ty Self, + float_op |lhs, rhs| B::float_div(lhs, rhs), + lhs, + rhs + ) } /// Divides a tensor by a scalar. @@ -410,14 +437,13 @@ pub trait QTensorOps { lhs: QuantizedTensor, rhs: QuantizedTensor, ) -> QuantizedTensor { - // Heuristic: prioritize lhs scheme - let scheme = lhs.scheme().clone(); - - let lhs_f = Self::dequantize(lhs); - let rhs_f = Self::dequantize(rhs); - let out_f = B::float_matmul(lhs_f, rhs_f); - - Self::quantize_dynamic(out_f, &scheme) + // dequant_op_quant!(Self, B::float_matmul, lhs, rhs) + dequant_op_quant!( + ty Self, + float_op |lhs, rhs| B::float_matmul(lhs, rhs), + lhs, + rhs + ) } /// Negates a tensor element-wise. @@ -532,13 +558,12 @@ pub trait QTensorOps { indices: IntTensor, value: QuantizedTensor, ) -> QuantizedTensor { - let scheme = tensor.scheme().clone(); - - let tensor_f = Self::dequantize(tensor); - let value_f = Self::dequantize(value); - let out_f = B::float_scatter(dim, tensor_f, indices, value_f); - - Self::quantize_dynamic(out_f, &scheme) + dequant_op_quant!( + ty Self, + float_op |tensor, value| B::float_scatter(dim, tensor, indices, value), + tensor, + value + ) } /// Select tensor elements along the given dimension corresponding for the given indices. @@ -577,13 +602,12 @@ pub trait QTensorOps { indices: IntTensor, value: QuantizedTensor, ) -> QuantizedTensor { - let scheme = tensor.scheme().clone(); - - let tensor_f = Self::dequantize(tensor); - let value_f = Self::dequantize(value); - let out_f = B::float_select_assign(tensor_f, dim, indices, value_f); - - Self::quantize_dynamic(out_f, &scheme) + dequant_op_quant!( + ty Self, + float_op |tensor, value| B::float_select_assign(tensor, dim, indices, value), + tensor, + value + ) } /// Select tensor elements corresponding for the given ranges. @@ -617,13 +641,12 @@ pub trait QTensorOps { ranges: [Range; D2], value: QuantizedTensor, ) -> QuantizedTensor { - let scheme = tensor.scheme().clone(); - - let tensor_f = Self::dequantize(tensor); - let value_f = Self::dequantize(value); - let out_f = B::float_slice_assign(tensor_f, ranges, value_f); - - Self::quantize_dynamic(out_f, &scheme) + dequant_op_quant!( + ty Self, + float_op |tensor, value| B::float_slice_assign(tensor, ranges, value), + tensor, + value + ) } /// Update the given tensor with the value tensor where the mask is true. 
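For intuition on why these binary defaults dequantize first: under an affine scheme, x is approximately scale * (q - offset), so multiplying the stored integers directly would produce values on a scale_lhs * scale_rhs grid with an offset cross-term that the output scheme does not encode. A minimal numeric sketch with made-up parameters:

    // Hypothetical per-tensor affine parameters (dequantization: x = scale * (q - offset)).
    let (scale_a, offset_a) = (0.1_f32, -3_i32);
    let (scale_b, offset_b) = (0.05_f32, 10_i32);
    let (q_a, q_b) = (25_i32, 50_i32);
    let a = scale_a * (q_a - offset_a) as f32; // 0.1 * 28  = 2.8
    let b = scale_b * (q_b - offset_b) as f32; // 0.05 * 40 = 2.0
    // Multiply in f32, then `quantize_dynamic` derives a fresh scheme for the result,
    // so no integer-domain arithmetic has to reason about mixed scales/offsets.
    let out = a * b; // 5.6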
@@ -642,13 +665,12 @@ pub trait QTensorOps { mask: BoolTensor, value: QuantizedTensor, ) -> QuantizedTensor { - let scheme = tensor.scheme().clone(); - - let tensor_f = Self::dequantize(tensor); - let value_f = Self::dequantize(value); - let out_f = B::float_mask_where(tensor_f, mask, value_f); - - Self::quantize_dynamic(out_f, &scheme) + dequant_op_quant!( + ty Self, + float_op |tensor, value| B::float_mask_where(tensor, mask, value), + tensor, + value + ) } /// Update the given tensor with the value where the mask is true. @@ -667,12 +689,11 @@ pub trait QTensorOps { mask: BoolTensor, value: FloatElem, ) -> QuantizedTensor { - let scheme = tensor.scheme().clone(); - - let tensor_f = Self::dequantize(tensor); - let out_f = B::float_mask_fill(tensor_f, mask, value); - - Self::quantize_dynamic(out_f, &scheme) + dequant_op_quant!( + ty Self, + float_op |tensor| B::float_mask_fill(tensor, mask, value), + tensor + ) } /// Equal comparison of two tensors. @@ -919,12 +940,11 @@ pub trait QTensorOps { /// /// A scalar tensor with the sum of all elements in `tensor`. fn q_sum(tensor: QuantizedTensor) -> QuantizedTensor { - let scheme = tensor.scheme().clone(); - - let tensor_f = Self::dequantize(tensor); - let out_f = B::float_sum(tensor_f); - - Self::quantize_dynamic(out_f, &scheme) + dequant_op_quant!( + ty Self, + float_op |tensor| B::float_sum(tensor), + tensor + ) } /// Sum of all elements in a tensor along a dimension. @@ -941,12 +961,11 @@ pub trait QTensorOps { tensor: QuantizedTensor, dim: usize, ) -> QuantizedTensor { - let scheme = tensor.scheme().clone(); - - let tensor_f = Self::dequantize(tensor); - let out_f = B::float_sum_dim(tensor_f, dim); - - Self::quantize_dynamic(out_f, &scheme) + dequant_op_quant!( + ty Self, + float_op |tensor| B::float_sum_dim(tensor, dim), + tensor + ) } /// Product of all elements in a tensor. @@ -959,12 +978,11 @@ pub trait QTensorOps { /// /// A scalar tensor with the product of all elements in `tensor`. fn q_prod(tensor: QuantizedTensor) -> QuantizedTensor { - let scheme = tensor.scheme().clone(); - - let tensor_f = Self::dequantize(tensor); - let out_f = B::float_prod(tensor_f); - - Self::quantize_dynamic(out_f, &scheme) + dequant_op_quant!( + ty Self, + float_op |tensor| B::float_prod(tensor), + tensor + ) } /// Product of all elements in a tensor along a dimension. @@ -980,12 +998,11 @@ pub trait QTensorOps { tensor: QuantizedTensor, dim: usize, ) -> QuantizedTensor { - let scheme = tensor.scheme().clone(); - - let tensor_f = Self::dequantize(tensor); - let out_f = B::float_prod_dim(tensor_f, dim); - - Self::quantize_dynamic(out_f, &scheme) + dequant_op_quant!( + ty Self, + float_op |tensor| B::float_prod_dim(tensor, dim), + tensor + ) } /// Mean of all elements in a tensor. @@ -998,12 +1015,11 @@ pub trait QTensorOps { /// /// A scalar tensor with the mean of all elements in `tensor`. fn q_mean(tensor: QuantizedTensor) -> QuantizedTensor { - let scheme = tensor.scheme().clone(); - - let tensor_f = Self::dequantize(tensor); - let out_f = B::float_mean(tensor_f); - - Self::quantize_dynamic(out_f, &scheme) + dequant_op_quant!( + ty Self, + float_op |tensor| B::float_mean(tensor), + tensor + ) } /// Mean of all elements in a tensor along a dimension. 
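As a usage-level sketch of the reductions above (hypothetical values; this relies on the `Tensor::quantize_dynamic` API and the quantized dispatch in `Numeric` introduced later in this series):

    // Summing a dynamically quantized tensor; `sum` ends up in the default
    // `q_sum`: dequantize -> float_sum -> re-quantize with the input scheme.
    fn sum_quantized<B: Backend>(device: &B::Device) {
        let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8);
        let x = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], device);
        let x_q = x.quantize_dynamic(&scheme);
        let total = x_q.sum(); // one-element quantized tensor, value close to 10.0
    }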
@@ -1020,12 +1036,11 @@ pub trait QTensorOps { tensor: QuantizedTensor, dim: usize, ) -> QuantizedTensor { - let scheme = tensor.scheme().clone(); - - let tensor_f = Self::dequantize(tensor); - let out_f = B::float_mean_dim(tensor_f, dim); - - Self::quantize_dynamic(out_f, &scheme) + dequant_op_quant!( + ty Self, + float_op |tensor| B::float_mean_dim(tensor, dim), + tensor + ) } /// Returns a new tensor with exponential values. @@ -1038,12 +1053,11 @@ pub trait QTensorOps { /// /// A tensor with the same shape as `tensor` with exponential values. fn q_exp(tensor: QuantizedTensor) -> QuantizedTensor { - let scheme = tensor.scheme().clone(); - - let tensor_f = Self::dequantize(tensor); - let out_f = B::float_exp(tensor_f); - - Self::quantize_dynamic(out_f, &scheme) + dequant_op_quant!( + ty Self, + float_op |tensor| B::float_exp(tensor), + tensor + ) } /// Returns a new tensor with natural logarithm values. @@ -1056,12 +1070,11 @@ pub trait QTensorOps { /// /// A tensor with the same shape as `tensor` with natural logarithm values. fn q_log(tensor: QuantizedTensor) -> QuantizedTensor { - let scheme = tensor.scheme().clone(); - - let tensor_f = Self::dequantize(tensor); - let out_f = B::float_log(tensor_f); - - Self::quantize_dynamic(out_f, &scheme) + dequant_op_quant!( + ty Self, + float_op |tensor| B::float_log(tensor), + tensor + ) } /// Returns a new tensor with logarithm values of (1 + Xi). @@ -1074,12 +1087,11 @@ pub trait QTensorOps { /// /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi). fn q_log1p(tensor: QuantizedTensor) -> QuantizedTensor { - let scheme = tensor.scheme().clone(); - - let tensor_f = Self::dequantize(tensor); - let out_f = B::float_log1p(tensor_f); - - Self::quantize_dynamic(out_f, &scheme) + dequant_op_quant!( + ty Self, + float_op |tensor| B::float_log1p(tensor), + tensor + ) } /// Element-wise power with a FloatTensor. @@ -1096,12 +1108,11 @@ pub trait QTensorOps { lhs: QuantizedTensor, rhs: FloatTensor, ) -> QuantizedTensor { - let scheme = lhs.scheme().clone(); - - let lhs_f = Self::dequantize(lhs); - let out_f = B::float_powf(lhs_f, rhs); - - Self::quantize_dynamic(out_f, &scheme) + dequant_op_quant!( + ty Self, + float_op |tensor| B::float_powf(tensor, rhs), + lhs + ) } /// Element-wise power with an IntTensor. @@ -1118,12 +1129,11 @@ pub trait QTensorOps { lhs: QuantizedTensor, rhs: IntTensor, ) -> QuantizedTensor { - let scheme = lhs.scheme().clone(); - - let lhs_f = Self::dequantize(lhs); - let out_f = B::float_powi(lhs_f, rhs); - - Self::quantize_dynamic(out_f, &scheme) + dequant_op_quant!( + ty Self, + float_op |tensor| B::float_powi(tensor, rhs), + lhs + ) } /// Element-wise power with an int scalar. @@ -1140,12 +1150,11 @@ pub trait QTensorOps { lhs: QuantizedTensor, rhs: IntElem, ) -> QuantizedTensor { - let scheme = lhs.scheme().clone(); - - let lhs_f = Self::dequantize(lhs); - let out_f = B::float_powi_scalar(lhs_f, rhs); - - Self::quantize_dynamic(out_f, &scheme) + dequant_op_quant!( + ty Self, + float_op |tensor| B::float_powi_scalar(tensor, rhs), + lhs + ) } /// Element-wise power with a float scalar. 
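Note the pattern for ops with an extra non-quantized operand such as `q_powf`: the unary macro arm suffices because the second operand is simply captured by the `float_op` closure. A hand-expanded sketch, again as a hypothetical free function:

    // Only `lhs` goes through the dequantize/quantize round trip; `rhs` is a
    // plain float tensor moved into the closure.
    fn q_powf_expanded<B: Backend, const D: usize>(
        lhs: QuantizedTensor<B, D>,
        rhs: FloatTensor<B, D>,
    ) -> QuantizedTensor<B, D> {
        let scheme = lhs.scheme().clone();

        let lhs_f = B::dequantize(lhs);
        let out_f = (|tensor| B::float_powf(tensor, rhs))(lhs_f);

        B::quantize_dynamic(out_f, &scheme)
    }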
@@ -1162,12 +1171,11 @@ pub trait QTensorOps<B: Backend> {
         tensor: QuantizedTensor<B, D>,
         value: f32,
     ) -> QuantizedTensor<B, D> {
-        let scheme = tensor.scheme().clone();
-
-        let tensor_f = Self::dequantize(tensor);
-        let out_f = B::float_powf_scalar(tensor_f, value);
-
-        Self::quantize_dynamic(out_f, &scheme)
+        dequant_op_quant!(
+            ty Self,
+            float_op |tensor| B::float_powf_scalar(tensor, value),
+            tensor
+        )
     }
 
     /// Returns a new tensor with square root values.
@@ -1180,12 +1188,11 @@ pub trait QTensorOps<B: Backend> {
     ///
     /// A tensor with the same shape as `tensor` with square root values.
     fn q_sqrt<const D: usize>(tensor: QuantizedTensor<B, D>) -> QuantizedTensor<B, D> {
-        let scheme = tensor.scheme().clone();
-
-        let tensor_f = Self::dequantize(tensor);
-        let out_f = B::float_sqrt(tensor_f);
-
-        Self::quantize_dynamic(out_f, &scheme)
+        dequant_op_quant!(
+            ty Self,
+            float_op |tensor| B::float_sqrt(tensor),
+            tensor
+        )
     }
 
     /// Returns a new tensor with absolute values.
@@ -1198,12 +1205,11 @@ pub trait QTensorOps<B: Backend> {
     ///
     /// A tensor with the same shape as `tensor` with absolute values.
     fn q_abs<const D: usize>(tensor: QuantizedTensor<B, D>) -> QuantizedTensor<B, D> {
-        let scheme = tensor.scheme().clone();
-
-        let tensor_f = Self::dequantize(tensor);
-        let out_f = B::float_abs(tensor_f);
-
-        Self::quantize_dynamic(out_f, &scheme)
+        dequant_op_quant!(
+            ty Self,
+            float_op |tensor| B::float_abs(tensor),
+            tensor
+        )
     }
 
     /// Returns a new tensor with cosine values.
@@ -1216,12 +1222,11 @@ pub trait QTensorOps<B: Backend> {
     ///
     /// A tensor with the same shape as `tensor` with cosine values.
     fn q_cos<const D: usize>(tensor: QuantizedTensor<B, D>) -> QuantizedTensor<B, D> {
-        let scheme = tensor.scheme().clone();
-
-        let tensor_f = Self::dequantize(tensor);
-        let out_f = B::float_cos(tensor_f);
-
-        Self::quantize_dynamic(out_f, &scheme)
+        dequant_op_quant!(
+            ty Self,
+            float_op |tensor| B::float_cos(tensor),
+            tensor
+        )
     }
 
     /// Returns a new tensor with sine values.
@@ -1234,12 +1239,11 @@ pub trait QTensorOps<B: Backend> {
     ///
     /// A tensor with the same shape as `tensor` with sine values.
     fn q_sin<const D: usize>(tensor: QuantizedTensor<B, D>) -> QuantizedTensor<B, D> {
-        let scheme = tensor.scheme().clone();
-
-        let tensor_f = Self::dequantize(tensor);
-        let out_f = B::float_sin(tensor_f);
-
-        Self::quantize_dynamic(out_f, &scheme)
+        dequant_op_quant!(
+            ty Self,
+            float_op |tensor| B::float_sin(tensor),
+            tensor
+        )
     }
 
     /// Returns a new tensor with hyperbolic tangent values.
@@ -1252,12 +1256,11 @@ pub trait QTensorOps<B: Backend> {
     ///
     /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
     fn q_tanh<const D: usize>(tensor: QuantizedTensor<B, D>) -> QuantizedTensor<B, D> {
-        let scheme = tensor.scheme().clone();
-
-        let tensor_f = Self::dequantize(tensor);
-        let out_f = B::float_tanh(tensor_f);
-
-        Self::quantize_dynamic(out_f, &scheme)
+        dequant_op_quant!(
+            ty Self,
+            float_op |tensor| B::float_tanh(tensor),
+            tensor
+        )
     }
 
     /// Returns a new tensor with the error function values.
@@ -1270,12 +1273,11 @@ pub trait QTensorOps<B: Backend> {
     ///
     /// A tensor with the same shape as `tensor` with error function values.
     fn q_erf<const D: usize>(tensor: QuantizedTensor<B, D>) -> QuantizedTensor<B, D> {
-        let scheme = tensor.scheme().clone();
-
-        let tensor_f = Self::dequantize(tensor);
-        let out_f = B::float_erf(tensor_f);
-
-        Self::quantize_dynamic(out_f, &scheme)
+        dequant_op_quant!(
+            ty Self,
+            float_op |tensor| B::float_erf(tensor),
+            tensor
+        )
     }
 
     /// Concatenates tensors along a dimension.
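One caveat that applies to all of these defaults: every op ends with a fresh quantization, so chained q_* calls accumulate grid/rounding error. A small illustration with a made-up int8 symmetric scale:

    // One round trip snaps the value to the int8 grid of the current scheme.
    let scale = 0.05_f32;
    let x = 0.123_f32;
    let q = (x / scale).round() as i8; // 2
    let x_hat = f32::from(q) * scale;  // 0.1, i.e. an absolute error of 0.023
    // A second op re-quantizes onto a new grid, so the error can grow with
    // each chained default q_* op.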
@@ -1460,12 +1462,11 @@ pub trait QTensorOps { start: usize, length: usize, ) -> QuantizedTensor { - let scheme = tensor.scheme().clone(); - - let tensor_f = Self::dequantize(tensor); - let out_f = B::float_narrow(tensor_f, dim, start, length); - - Self::quantize_dynamic(out_f, &scheme) + dequant_op_quant!( + ty Self, + float_op |tensor| B::float_narrow(tensor, dim, start, length), + tensor + ) } /// Split the tensor along the given dimension into chunks. @@ -1583,12 +1584,11 @@ pub trait QTensorOps { descending: bool, ) -> QuantizedTensor { // Default implementation. Backends can sort on the int values since qparams remain the same. - let scheme = tensor.scheme().clone(); - - let tensor_f = Self::dequantize(tensor); - let out_f = B::float_sort(tensor_f, dim, descending); - - Self::quantize_dynamic(out_f, &scheme) + dequant_op_quant!( + ty Self, + float_op |tensor| B::float_sort(tensor, dim, descending), + tensor + ) } /// Sort the elements of the input `tensor` by value in along a given dimension. From a87a54c90e1d987361caac648633a9e501d3ad37 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Mon, 22 Jul 2024 14:24:31 -0400 Subject: [PATCH 03/16] Comparison ops are already implemented by default to compare dequantized values --- crates/burn-tensor/src/tensor/ops/qtensor.rs | 235 ------------------- 1 file changed, 235 deletions(-) diff --git a/crates/burn-tensor/src/tensor/ops/qtensor.rs b/crates/burn-tensor/src/tensor/ops/qtensor.rs index 37502dafe2..4db2614f1f 100644 --- a/crates/burn-tensor/src/tensor/ops/qtensor.rs +++ b/crates/burn-tensor/src/tensor/ops/qtensor.rs @@ -177,7 +177,6 @@ pub trait QTensorOps { lhs: QuantizedTensor, rhs: QuantizedTensor, ) -> QuantizedTensor { - // dequant_op_quant!(Self, B::float_add, lhs, rhs) dequant_op_quant!( ty Self, float_op |lhs, rhs| B::float_add(lhs, rhs), @@ -696,240 +695,6 @@ pub trait QTensorOps { ) } - /// Equal comparison of two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn q_equal( - lhs: QuantizedTensor, - rhs: QuantizedTensor, - ) -> BoolTensor { - let lhs_f = Self::dequantize(lhs); - let rhs_f = Self::dequantize(rhs); - - B::float_equal(lhs_f, rhs_f) - } - - /// Element-wise non-equality comparison. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn q_not_equal( - lhs: QuantizedTensor, - rhs: QuantizedTensor, - ) -> BoolTensor { - let lhs_f = Self::dequantize(lhs); - let rhs_f = Self::dequantize(rhs); - - B::float_not_equal(lhs_f, rhs_f) - } - - /// Equal comparison of a tensor and a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn q_equal_elem( - lhs: QuantizedTensor, - rhs: FloatElem, - ) -> BoolTensor { - let lhs_f = Self::dequantize(lhs); - - B::float_equal_elem(lhs_f, rhs) - } - - /// Element-wise non-equality comparison with a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. 
- fn q_not_equal_elem( - lhs: QuantizedTensor, - rhs: FloatElem, - ) -> BoolTensor { - let lhs_f = Self::dequantize(lhs); - - B::float_not_equal_elem(lhs_f, rhs) - } - - /// Greater than comparison of two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn q_greater( - lhs: QuantizedTensor, - rhs: QuantizedTensor, - ) -> BoolTensor { - let lhs_f = Self::dequantize(lhs); - let rhs_f = Self::dequantize(rhs); - - B::float_greater(lhs_f, rhs_f) - } - - /// Greater than comparison of a tensor and a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn q_greater_elem( - lhs: QuantizedTensor, - rhs: FloatElem, - ) -> BoolTensor { - let lhs_f = Self::dequantize(lhs); - - B::float_greater_elem(lhs_f, rhs) - } - - /// Greater than or equal comparison of two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn q_greater_equal( - lhs: QuantizedTensor, - rhs: QuantizedTensor, - ) -> BoolTensor { - let lhs_f = Self::dequantize(lhs); - let rhs_f = Self::dequantize(rhs); - - B::float_greater_equal(lhs_f, rhs_f) - } - - /// Greater than or equal comparison of a tensor and a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn q_greater_equal_elem( - lhs: QuantizedTensor, - rhs: FloatElem, - ) -> BoolTensor { - let lhs_f = Self::dequantize(lhs); - - B::float_greater_equal_elem(lhs_f, rhs) - } - - /// Less than comparison of two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn q_lower( - lhs: QuantizedTensor, - rhs: QuantizedTensor, - ) -> BoolTensor { - let lhs_f = Self::dequantize(lhs); - let rhs_f = Self::dequantize(rhs); - - B::float_lower(lhs_f, rhs_f) - } - - /// Less than comparison of a tensor and a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn q_lower_elem( - lhs: QuantizedTensor, - rhs: FloatElem, - ) -> BoolTensor { - let lhs_f = Self::dequantize(lhs); - - B::float_lower_elem(lhs_f, rhs) - } - - /// Less than or equal comparison of two tensors. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. - fn q_lower_equal( - lhs: QuantizedTensor, - rhs: QuantizedTensor, - ) -> BoolTensor { - let lhs_f = Self::dequantize(lhs); - let rhs_f = Self::dequantize(rhs); - - B::float_lower_equal(lhs_f, rhs_f) - } - - /// Less than or equal comparison of a tensor and a scalar. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// A boolean tensor with the result of the comparison. 
- fn q_lower_equal_elem( - lhs: QuantizedTensor, - rhs: FloatElem, - ) -> BoolTensor { - let lhs_f = Self::dequantize(lhs); - - B::float_lower_equal_elem(lhs_f, rhs) - } - /// Sum of all elements in a tensor. /// /// # Arguments From e7166352fead3e6501bc6fb62c9e779834ea03df Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 23 Jul 2024 14:33:05 -0400 Subject: [PATCH 04/16] Add default arg min/max implementation and fix tch implementation --- crates/burn-autodiff/src/ops/qtensor.rs | 12 ++++++------ crates/burn-candle/src/ops/qtensor.rs | 14 -------------- crates/burn-fusion/src/ops/qtensor.rs | 14 -------------- crates/burn-jit/src/ops/qtensor.rs | 14 -------------- crates/burn-tch/src/ops/qtensor.rs | 10 ++++++++-- crates/burn-tensor/src/tensor/ops/qtensor.rs | 12 ++++++++++-- 6 files changed, 24 insertions(+), 52 deletions(-) diff --git a/crates/burn-autodiff/src/ops/qtensor.rs b/crates/burn-autodiff/src/ops/qtensor.rs index 061a3d5db1..46956d3497 100644 --- a/crates/burn-autodiff/src/ops/qtensor.rs +++ b/crates/burn-autodiff/src/ops/qtensor.rs @@ -108,17 +108,17 @@ impl QTensorOps for Autodiff { } fn q_argmax( - _tensor: QuantizedTensor, - _dim: usize, + tensor: QuantizedTensor, + dim: usize, ) -> IntTensor { - unimplemented!() + B::q_argmax(tensor, dim) } fn q_argmin( - _tensor: QuantizedTensor, - _dim: usize, + tensor: QuantizedTensor, + dim: usize, ) -> IntTensor { - unimplemented!() + B::q_argmin(tensor, dim) } fn q_expand( diff --git a/crates/burn-candle/src/ops/qtensor.rs b/crates/burn-candle/src/ops/qtensor.rs index 038da97521..157ebae305 100644 --- a/crates/burn-candle/src/ops/qtensor.rs +++ b/crates/burn-candle/src/ops/qtensor.rs @@ -106,20 +106,6 @@ impl QTensorOps for Candle( - _tensor: QuantizedTensor, - _dim: usize, - ) -> IntTensor { - unimplemented!() - } - - fn q_argmin( - _tensor: QuantizedTensor, - _dim: usize, - ) -> IntTensor { - unimplemented!() - } - fn q_expand( _tensor: QuantizedTensor, _shape: Shape, diff --git a/crates/burn-fusion/src/ops/qtensor.rs b/crates/burn-fusion/src/ops/qtensor.rs index 34b06a3487..df30c43b10 100644 --- a/crates/burn-fusion/src/ops/qtensor.rs +++ b/crates/burn-fusion/src/ops/qtensor.rs @@ -109,20 +109,6 @@ impl QTensorOps for Fusion { unimplemented!() } - fn q_argmax( - _tensor: QuantizedTensor, - _dim: usize, - ) -> IntTensor { - unimplemented!() - } - - fn q_argmin( - _tensor: QuantizedTensor, - _dim: usize, - ) -> IntTensor { - unimplemented!() - } - fn q_expand( _tensor: QuantizedTensor, _shape: Shape, diff --git a/crates/burn-jit/src/ops/qtensor.rs b/crates/burn-jit/src/ops/qtensor.rs index cfd68637c6..d6ccb79add 100644 --- a/crates/burn-jit/src/ops/qtensor.rs +++ b/crates/burn-jit/src/ops/qtensor.rs @@ -107,20 +107,6 @@ where unimplemented!() } - fn q_argmax( - _tensor: QuantizedTensor, - _dim: usize, - ) -> IntTensor { - unimplemented!() - } - - fn q_argmin( - _tensor: QuantizedTensor, - _dim: usize, - ) -> IntTensor { - unimplemented!() - } - fn q_expand( _tensor: QuantizedTensor, _shape: Shape, diff --git a/crates/burn-tch/src/ops/qtensor.rs b/crates/burn-tch/src/ops/qtensor.rs index b42f93ede1..8453e31ee1 100644 --- a/crates/burn-tch/src/ops/qtensor.rs +++ b/crates/burn-tch/src/ops/qtensor.rs @@ -222,14 +222,20 @@ impl QTensorOps for LibTorch { tensor: QuantizedTensor, dim: usize, ) -> IntTensor { - TchOps::argmax(tensor.qtensor, dim) + TchOps::argmax( + TchTensor::::new(tensor.qtensor.tensor.int_repr()), + dim, + ) } fn q_argmin( tensor: QuantizedTensor, dim: usize, ) -> IntTensor { - TchOps::argmin(tensor.qtensor, 
dim) + TchOps::argmin( + TchTensor::::new(tensor.qtensor.tensor.int_repr()), + dim, + ) } fn q_max_dim( diff --git a/crates/burn-tensor/src/tensor/ops/qtensor.rs b/crates/burn-tensor/src/tensor/ops/qtensor.rs index 4db2614f1f..fd3bf2704b 100644 --- a/crates/burn-tensor/src/tensor/ops/qtensor.rs +++ b/crates/burn-tensor/src/tensor/ops/qtensor.rs @@ -1082,7 +1082,11 @@ pub trait QTensorOps { /// # Returns /// /// A tensor with the indices of the maximum elements of `tensor` along `dim`. - fn q_argmax(tensor: QuantizedTensor, dim: usize) -> IntTensor; + fn q_argmax(tensor: QuantizedTensor, dim: usize) -> IntTensor { + // Default implementation. Backends can sort on the int values since qparams remain the same. + let tensor_f = Self::dequantize(tensor); + B::float_argmax(tensor_f, dim) + } /// Gets the indices of the minimum elements of a tensor along an axis. /// @@ -1094,7 +1098,11 @@ pub trait QTensorOps { /// # Returns /// /// A tensor with the indices of the minimum elements of `tensor` along `dim`. - fn q_argmin(tensor: QuantizedTensor, dim: usize) -> IntTensor; + fn q_argmin(tensor: QuantizedTensor, dim: usize) -> IntTensor { + // Default implementation. Backends can sort on the int values since qparams remain the same. + let tensor_f = Self::dequantize(tensor); + B::float_argmin(tensor_f, dim) + } /// Gets the maximum element of a tensor. /// From 1554eb374e0479aaa0dd0b9addd7ac1bbc22ba68 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Wed, 24 Jul 2024 10:02:19 -0400 Subject: [PATCH 05/16] Avoid division by zero scale --- .../burn-tensor/src/tensor/quantization/strategy.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/crates/burn-tensor/src/tensor/quantization/strategy.rs b/crates/burn-tensor/src/tensor/quantization/strategy.rs index fd3f77c326..1e13935e33 100644 --- a/crates/burn-tensor/src/tensor/quantization/strategy.rs +++ b/crates/burn-tensor/src/tensor/quantization/strategy.rs @@ -62,6 +62,12 @@ pub struct AffineQuantization { impl AffineQuantization { /// Initialize an affine quantization scheme with the given parameters. pub fn init(scale: E, offset: Q) -> Self { + let mut scale = scale; + // If scale is 0 (most likely due to a tensor full of zeros), we arbitrarily adjust the + // scale to 0.1 to avoid division by zero. + if scale.eq(&E::zero()) { + scale = E::from(0.1).unwrap(); + } Self { scale, offset, @@ -132,6 +138,12 @@ pub struct SymmetricQuantization { impl SymmetricQuantization { /// Initialize a symmetric quantization scheme with the given parameters. pub fn init(scale: E) -> Self { + let mut scale = scale; + // If scale is 0 (most likely due to a tensor full of zeros), we arbitrarily adjust the + // scale to 0.1 to avoid division by zero. 
+ if scale.eq(&E::zero()) { + scale = E::from(0.1).unwrap(); + } Self { scale, _q: PhantomData, From 768a1feddbe04db7ae5af9ca1e6b63e9b5cc379b Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Wed, 24 Jul 2024 13:40:25 -0400 Subject: [PATCH 06/16] Add default q_gather implementation (tch does not support on quantized tensor) --- crates/burn-tch/src/ops/qtensor.rs | 10 ---------- crates/burn-tensor/src/tensor/ops/qtensor.rs | 9 ++++++++- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/crates/burn-tch/src/ops/qtensor.rs b/crates/burn-tch/src/ops/qtensor.rs index 8453e31ee1..d8dba7498f 100644 --- a/crates/burn-tch/src/ops/qtensor.rs +++ b/crates/burn-tch/src/ops/qtensor.rs @@ -189,16 +189,6 @@ impl QTensorOps for LibTorch { tensor } - fn q_gather( - dim: usize, - tensor: QuantizedTensor, - indices: IntTensor, - ) -> QuantizedTensor { - let mut tensor = tensor; - tensor.qtensor = TchOps::gather(dim, tensor.qtensor, indices); - tensor - } - fn q_select( tensor: QuantizedTensor, dim: usize, diff --git a/crates/burn-tensor/src/tensor/ops/qtensor.rs b/crates/burn-tensor/src/tensor/ops/qtensor.rs index fd3bf2704b..66fa7126ab 100644 --- a/crates/burn-tensor/src/tensor/ops/qtensor.rs +++ b/crates/burn-tensor/src/tensor/ops/qtensor.rs @@ -537,7 +537,14 @@ pub trait QTensorOps { dim: usize, tensor: QuantizedTensor, indices: IntTensor, - ) -> QuantizedTensor; + ) -> QuantizedTensor { + // Default implementation. Backends can gather on the quantized values when supported. + dequant_op_quant!( + ty Self, + float_op |tensor| B::float_gather(dim, tensor, indices), + tensor + ) + } /// Scatter elements into a tensor. /// From 0dd18aef72c594977c5b77fb412ae776ebc2f3c3 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Wed, 24 Jul 2024 15:16:52 -0400 Subject: [PATCH 07/16] Add warning instead for tch quantize_dynamic --- Cargo.lock | 1 + crates/burn-tch/Cargo.toml | 1 + crates/burn-tch/src/lib.rs | 7 +++++-- crates/burn-tch/src/ops/qtensor.rs | 9 +++++++-- 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7e7cad83d0..7197bf8265 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -704,6 +704,7 @@ dependencies = [ "burn-tensor", "half", "libc", + "log", "rand", "tch", ] diff --git a/crates/burn-tch/Cargo.toml b/crates/burn-tch/Cargo.toml index b96e165465..c4d0b550e0 100644 --- a/crates/burn-tch/Cargo.toml +++ b/crates/burn-tch/Cargo.toml @@ -21,6 +21,7 @@ half = { workspace = true, features = ["std"] } libc = { workspace = true } rand = { workspace = true, features = ["std"] } tch = { workspace = true, features = ["download-libtorch"] } +log = { workspace = true } [dev-dependencies] burn-autodiff = { path = "../burn-autodiff", version = "0.15.0", default-features = false, features = [ diff --git a/crates/burn-tch/src/lib.rs b/crates/burn-tch/src/lib.rs index c858edd2bc..293e2743d4 100644 --- a/crates/burn-tch/src/lib.rs +++ b/crates/burn-tch/src/lib.rs @@ -22,7 +22,10 @@ mod tests { type TestTensor = burn_tensor::Tensor; type TestTensorInt = burn_tensor::Tensor; type TestTensorBool = burn_tensor::Tensor; + // type TestBackendQInt8 = crate::LibTorch; + // type TestTensorQInt8 = burn_tensor::Tensor; - burn_tensor::testgen_all!(); - burn_autodiff::testgen_all!(); + // burn_tensor::testgen_all!(); + // burn_autodiff::testgen_all!(); + burn_tensor::testgen_quantization!(); } diff --git a/crates/burn-tch/src/ops/qtensor.rs b/crates/burn-tch/src/ops/qtensor.rs index d8dba7498f..935451e59b 100644 --- a/crates/burn-tch/src/ops/qtensor.rs +++ 
b/crates/burn-tch/src/ops/qtensor.rs @@ -106,8 +106,13 @@ impl QTensorOps for LibTorch { .tensor .quantize_per_tensor_dynamic(tch::Kind::QInt8, /*reduce_range*/ false), }, - QuantizationScheme::PerTensorSymmetric(_) => { - panic!("LibTorch backend does not support symmetric quantize_dynamic") + QuantizationScheme::PerTensorSymmetric(dtype) => { + log::warn!("LibTorch backend does not support symmetric per-tensor scheme for dynamic quantization, reverting to the default per-tensor affine quantization"); + match dtype { + QuantizationType::QInt8 => tensor + .tensor + .quantize_per_tensor_dynamic(tch::Kind::QInt8, /*reduce_range*/ false), + } } }; From 16cedd9bdcd38f47e5fd12d55971dda9caf1a15e Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Wed, 24 Jul 2024 15:20:26 -0400 Subject: [PATCH 08/16] Call chunk backend implementation --- crates/burn-tensor/src/tensor/api/base.rs | 153 +++++++++++++++++++--- 1 file changed, 134 insertions(+), 19 deletions(-) diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs index 9280291daf..d9d13c7b03 100644 --- a/crates/burn-tensor/src/tensor/api/base.rs +++ b/crates/burn-tensor/src/tensor/api/base.rs @@ -15,7 +15,6 @@ use serde::{Deserialize, Deserializer}; use serde::{Serialize, Serializer}; use crate::check::TensorCheck; -use crate::tensor::api::chunk::chunk; use crate::tensor::api::narrow::narrow; use crate::{backend::Backend, check, Bool, Float, Int, Shape, TensorData, TensorKind}; use crate::{DType, Element, TensorPrimitive}; @@ -834,9 +833,9 @@ where /// A vector of tensors. pub fn chunk(self, chunks: usize, dim: usize) -> Vec { check!(TensorCheck::dim_ops::("chunk", dim)); - chunk::(self.primitive, chunks, dim) + K::chunk(self.primitive, chunks, dim) .into_iter() - .map(|v| Self::new(v)) + .map(Self::new) .collect() } @@ -1588,6 +1587,33 @@ pub trait BasicOps: TensorKind { /// which is more high-level and designed for public use. fn cat(vectors: Vec>, dim: usize) -> Self::Primitive; + /// Attempts to split the tensor along the given dimension into chunks. + /// May return less chunks than requested if the tensor size is not divisible by the number of chunks. + /// + /// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size. + /// Otherwise all chunks will be of equal size except for the last one. + /// + /// # Panics + /// + /// If the dimension is greater than the number of dimensions of the tensor. + /// + /// # Returns + /// A vector of tensors. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// To split a tensor, users should prefer the [Tensor::chunk](Tensor::chunk) function, + /// which is more high-level and designed for public use. + fn chunk( + tensor: Self::Primitive, + chunks: usize, + dim: usize, + ) -> Vec>; + /// Equates the given tensors. 
/// /// # Arguments @@ -1759,7 +1785,10 @@ impl BasicOps for Float { } fn transpose(tensor: Self::Primitive) -> Self::Primitive { - TensorPrimitive::Float(B::float_transpose(tensor.tensor())) + match tensor { + TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_transpose(tensor)), + TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_transpose(tensor)), + } } fn swap_dims( @@ -1768,14 +1797,26 @@ impl BasicOps for Float { dim2: usize, ) -> Self::Primitive { check!(TensorCheck::swap_dims::(dim1, dim2)); - TensorPrimitive::Float(B::float_swap_dims(tensor.tensor(), dim1, dim2)) + match tensor { + TensorPrimitive::Float(tensor) => { + TensorPrimitive::Float(B::float_swap_dims(tensor, dim1, dim2)) + } + TensorPrimitive::QFloat(tensor) => { + TensorPrimitive::QFloat(B::q_swap_dims(tensor, dim1, dim2)) + } + } } fn slice( tensor: Self::Primitive, ranges: [Range; D2], ) -> Self::Primitive { - TensorPrimitive::Float(B::float_slice(tensor.tensor(), ranges)) + match tensor { + TensorPrimitive::Float(tensor) => { + TensorPrimitive::Float(B::float_slice(tensor, ranges)) + } + TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_slice(tensor, ranges)), + } } fn slice_assign( @@ -1783,11 +1824,15 @@ impl BasicOps for Float { ranges: [Range; D2], value: Self::Primitive, ) -> Self::Primitive { - TensorPrimitive::Float(B::float_slice_assign( - tensor.tensor(), - ranges, - value.tensor(), - )) + match (tensor, value) { + (TensorPrimitive::Float(tensor), TensorPrimitive::Float(value)) => { + TensorPrimitive::Float(B::float_slice_assign(tensor, ranges, value)) + } + (TensorPrimitive::QFloat(tensor), TensorPrimitive::QFloat(value)) => { + TensorPrimitive::QFloat(B::q_slice_assign(tensor, ranges, value)) + } + _ => panic!("Primitive type mismatch for tensor and value"), + } } fn device(tensor: &Self::Primitive) -> ::Device { @@ -1801,7 +1846,14 @@ impl BasicOps for Float { tensor: Self::Primitive, device: &::Device, ) -> Self::Primitive { - TensorPrimitive::Float(B::float_to_device(tensor.tensor(), device)) + match tensor { + TensorPrimitive::Float(tensor) => { + TensorPrimitive::Float(B::float_to_device(tensor, device)) + } + TensorPrimitive::QFloat(tensor) => { + TensorPrimitive::QFloat(B::q_to_device(tensor, device)) + } + } } async fn into_data_async(tensor: Self::Primitive) -> TensorData { @@ -1823,14 +1875,36 @@ impl BasicOps for Float { dim: usize, times: usize, ) -> Self::Primitive { - TensorPrimitive::Float(B::float_repeat_dim(tensor.tensor(), dim, times)) + match tensor { + TensorPrimitive::Float(tensor) => { + TensorPrimitive::Float(B::float_repeat_dim(tensor, dim, times)) + } + TensorPrimitive::QFloat(tensor) => { + TensorPrimitive::QFloat(B::q_repeat_dim(tensor, dim, times)) + } + } } fn cat(vectors: Vec>, dim: usize) -> Self::Primitive { - TensorPrimitive::Float(B::float_cat( - vectors.into_iter().map(|tensor| tensor.tensor()).collect(), - dim, - )) + match vectors.get(0).unwrap() { + TensorPrimitive::Float(_) => TensorPrimitive::Float(B::float_cat( + vectors.into_iter().map(|tensor| tensor.tensor()).collect(), + dim, + )), + TensorPrimitive::QFloat(_) => TensorPrimitive::QFloat(B::q_cat( + vectors + .into_iter() + .map(|tensor| { + if let TensorPrimitive::QFloat(t) = tensor { + t + } else { + panic!("Concatenation only works with vector of QFloat") + } + }) + .collect(), + dim, + )), + } } fn equal( @@ -1864,7 +1938,12 @@ impl BasicOps for Float { } fn permute(tensor: Self::Primitive, axes: [usize; D]) -> Self::Primitive { - 
TensorPrimitive::Float(B::float_permute(tensor.tensor(), axes)) + match tensor { + TensorPrimitive::Float(tensor) => { + TensorPrimitive::Float(B::float_permute(tensor, axes)) + } + TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_permute(tensor, axes)), + } } fn expand( @@ -1875,7 +1954,27 @@ impl BasicOps for Float { } fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive { - TensorPrimitive::Float(B::float_flip(tensor.tensor(), axes)) + match tensor { + TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_flip(tensor, axes)), + TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_flip(tensor, axes)), + } + } + + fn chunk( + tensor: Self::Primitive, + chunks: usize, + dim: usize, + ) -> Vec> { + match tensor { + TensorPrimitive::Float(tensor) => B::float_chunk(tensor, chunks, dim) + .into_iter() + .map(TensorPrimitive::Float) + .collect(), + TensorPrimitive::QFloat(tensor) => B::q_chunk(tensor, chunks, dim) + .into_iter() + .map(TensorPrimitive::QFloat) + .collect(), + } } } @@ -1999,6 +2098,14 @@ impl BasicOps for Int { fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive { B::int_flip(tensor, axes) } + + fn chunk( + tensor: Self::Primitive, + chunks: usize, + dim: usize, + ) -> Vec> { + B::int_chunk(tensor, chunks, dim) + } } impl BasicOps for Bool { @@ -2121,6 +2228,14 @@ impl BasicOps for Bool { fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive { B::bool_flip(tensor, axes) } + + fn chunk( + tensor: Self::Primitive, + chunks: usize, + dim: usize, + ) -> Vec> { + B::bool_chunk(tensor, chunks, dim) + } } /// Trait used for movedim arguments From 928d6c6ecd59cdef72e4cc6905e610393ab21fd6 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Wed, 24 Jul 2024 15:21:45 -0400 Subject: [PATCH 09/16] Add QFloat check for q_ ops --- crates/burn-tensor/src/tensor/api/float.rs | 30 +- crates/burn-tensor/src/tensor/api/numeric.rs | 346 +++++++++++++++---- crates/burn-tensor/src/tensor/ops/qtensor.rs | 32 +- 3 files changed, 331 insertions(+), 77 deletions(-) diff --git a/crates/burn-tensor/src/tensor/api/float.rs b/crates/burn-tensor/src/tensor/api/float.rs index 2990bf5b38..7d6d4c20db 100644 --- a/crates/burn-tensor/src/tensor/api/float.rs +++ b/crates/burn-tensor/src/tensor/api/float.rs @@ -334,21 +334,21 @@ where ))) } - // /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme. - // /// - // /// # Arguments - // /// - // /// * `scheme` - The quantization scheme. - // /// - // /// # Returns - // /// - // /// The quantized tensor. - // pub fn quantize_dynamic(self, scheme: QuantizationScheme) -> Tensor { - // Tensor::new(TensorPrimitive::QFloat(B::quantize_dynamic( - // self.primitive.tensor(), - // scheme, - // ))) - // } + /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme. + /// + /// # Arguments + /// + /// * `scheme` - The quantization scheme. + /// + /// # Returns + /// + /// The quantized tensor. + pub fn quantize_dynamic(self, scheme: &QuantizationScheme) -> Tensor { + Tensor::new(TensorPrimitive::QFloat(B::quantize_dynamic( + self.primitive.tensor(), + scheme, + ))) + } /// Convert the tensor back to a higher precision data type. 
/// diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 71b4e4d7ba..b67293715b 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -2460,58 +2460,128 @@ impl Numeric for Float { lhs: Self::Primitive, rhs: Self::Primitive, ) -> >::Primitive { - TensorPrimitive::Float(B::float_add(lhs.tensor(), rhs.tensor())) + match (lhs, rhs) { + (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => { + TensorPrimitive::Float(B::float_add(lhs, rhs)) + } + (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => { + TensorPrimitive::QFloat(B::q_add(lhs, rhs)) + } + _ => panic!("Primitive type mismatch for lhs and rhs"), + } } fn add_scalar( lhs: Self::Primitive, rhs: E, ) -> Self::Primitive { - TensorPrimitive::Float(B::float_add_scalar(lhs.tensor(), rhs.elem())) + match lhs { + TensorPrimitive::Float(lhs) => { + TensorPrimitive::Float(B::float_add_scalar(lhs, rhs.elem())) + } + TensorPrimitive::QFloat(lhs) => { + TensorPrimitive::QFloat(B::q_add_scalar(lhs, rhs.elem())) + } + } } fn sub( lhs: Self::Primitive, rhs: Self::Primitive, ) -> >::Primitive { - TensorPrimitive::Float(B::float_sub(lhs.tensor(), rhs.tensor())) + match (lhs, rhs) { + (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => { + TensorPrimitive::Float(B::float_sub(lhs, rhs)) + } + (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => { + TensorPrimitive::QFloat(B::q_sub(lhs, rhs)) + } + _ => panic!("Primitive type mismatch for lhs and rhs"), + } } fn sub_scalar( lhs: Self::Primitive, rhs: E, ) -> Self::Primitive { - TensorPrimitive::Float(B::float_sub_scalar(lhs.tensor(), rhs.elem())) + match lhs { + TensorPrimitive::Float(lhs) => { + TensorPrimitive::Float(B::float_sub_scalar(lhs, rhs.elem())) + } + TensorPrimitive::QFloat(lhs) => { + TensorPrimitive::QFloat(B::q_sub_scalar(lhs, rhs.elem())) + } + } } fn div( lhs: Self::Primitive, rhs: Self::Primitive, ) -> >::Primitive { - TensorPrimitive::Float(B::float_div(lhs.tensor(), rhs.tensor())) + match (lhs, rhs) { + (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => { + TensorPrimitive::Float(B::float_div(lhs, rhs)) + } + (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => { + TensorPrimitive::QFloat(B::q_div(lhs, rhs)) + } + _ => panic!("Primitive type mismatch for lhs and rhs"), + } } fn div_scalar( lhs: Self::Primitive, rhs: E, ) -> Self::Primitive { - TensorPrimitive::Float(B::float_div_scalar(lhs.tensor(), rhs.elem())) + match lhs { + TensorPrimitive::Float(lhs) => { + TensorPrimitive::Float(B::float_div_scalar(lhs, rhs.elem())) + } + TensorPrimitive::QFloat(lhs) => { + TensorPrimitive::QFloat(B::q_div_scalar(lhs, rhs.elem())) + } + } } fn remainder_scalar( lhs: Self::Primitive, rhs: E, ) -> Self::Primitive { - TensorPrimitive::Float(B::float_remainder_scalar(lhs.tensor(), rhs.elem())) + match lhs { + TensorPrimitive::Float(lhs) => { + TensorPrimitive::Float(B::float_remainder_scalar(lhs, rhs.elem())) + } + TensorPrimitive::QFloat(lhs) => { + TensorPrimitive::QFloat(B::q_remainder_scalar(lhs, rhs.elem())) + } + } } fn mul( lhs: Self::Primitive, rhs: Self::Primitive, ) -> >::Primitive { - TensorPrimitive::Float(B::float_mul(lhs.tensor(), rhs.tensor())) + match (lhs, rhs) { + (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => { + TensorPrimitive::Float(B::float_mul(lhs, rhs)) + } + (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => { + TensorPrimitive::QFloat(B::q_mul(lhs, rhs)) + } + _ => 
panic!("Primitive type mismatch for lhs and rhs"), + } } fn mul_scalar( lhs: Self::Primitive, rhs: E, ) -> Self::Primitive { - TensorPrimitive::Float(B::float_mul_scalar(lhs.tensor(), rhs.elem())) + match lhs { + TensorPrimitive::Float(lhs) => { + TensorPrimitive::Float(B::float_mul_scalar(lhs, rhs.elem())) + } + TensorPrimitive::QFloat(lhs) => { + TensorPrimitive::QFloat(B::q_mul_scalar(lhs, rhs.elem())) + } + } } fn neg(tensor: Self::Primitive) -> Self::Primitive { - TensorPrimitive::Float(B::float_neg(tensor.tensor())) + match tensor { + TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_neg(tensor)), + TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_neg(tensor)), + } } fn zeros(shape: Shape, device: &B::Device) -> Self::Primitive { TensorPrimitive::Float(B::float_zeros(shape, device)) @@ -2529,27 +2599,49 @@ impl Numeric for Float { } fn sum(tensor: Self::Primitive) -> Self::Primitive<1> { - TensorPrimitive::Float(B::float_sum(tensor.tensor())) + match tensor { + TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum(tensor)), + TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_sum(tensor)), + } } fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - TensorPrimitive::Float(B::float_sum_dim(tensor.tensor(), dim)) + match tensor { + TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum_dim(tensor, dim)), + TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_sum_dim(tensor, dim)), + } } fn prod(tensor: Self::Primitive) -> Self::Primitive<1> { - TensorPrimitive::Float(B::float_prod(tensor.tensor())) + match tensor { + TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_prod(tensor)), + TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_prod(tensor)), + } } fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - TensorPrimitive::Float(B::float_prod_dim(tensor.tensor(), dim)) + match tensor { + TensorPrimitive::Float(tensor) => { + TensorPrimitive::Float(B::float_prod_dim(tensor, dim)) + } + TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_prod_dim(tensor, dim)), + } } fn mean(tensor: Self::Primitive) -> Self::Primitive<1> { - TensorPrimitive::Float(B::float_mean(tensor.tensor())) + match tensor { + TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_mean(tensor)), + TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_mean(tensor)), + } } fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - TensorPrimitive::Float(B::float_mean_dim(tensor.tensor(), dim)) + match tensor { + TensorPrimitive::Float(tensor) => { + TensorPrimitive::Float(B::float_mean_dim(tensor, dim)) + } + TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_mean_dim(tensor, dim)), + } } fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor { @@ -2619,11 +2711,15 @@ impl Numeric for Float { mask: Tensor, source: Self::Primitive, ) -> Self::Primitive { - TensorPrimitive::Float(B::float_mask_where( - tensor.tensor(), - mask.primitive, - source.tensor(), - )) + match (tensor, source) { + (TensorPrimitive::Float(tensor), TensorPrimitive::Float(source)) => { + TensorPrimitive::Float(B::float_mask_where(tensor, mask.primitive, source)) + } + (TensorPrimitive::QFloat(tensor), TensorPrimitive::QFloat(source)) => { + TensorPrimitive::QFloat(B::q_mask_where(tensor, mask.primitive, source)) + } + _ => panic!("Primitive type mismatch for tensor and source"), + } } fn mask_fill( @@ -2631,7 +2727,14 @@ impl Numeric for 
Float { mask: Tensor, value: Self::Elem, ) -> Self::Primitive { - TensorPrimitive::Float(B::float_mask_fill(tensor.tensor(), mask.primitive, value)) + match tensor { + TensorPrimitive::Float(tensor) => { + TensorPrimitive::Float(B::float_mask_fill(tensor, mask.primitive, value)) + } + TensorPrimitive::QFloat(tensor) => { + TensorPrimitive::QFloat(B::q_mask_fill(tensor, mask.primitive, value)) + } + } } fn select( @@ -2639,7 +2742,14 @@ impl Numeric for Float { dim: usize, indices: Tensor, ) -> Self::Primitive { - TensorPrimitive::Float(B::float_select(tensor.tensor(), dim, indices.primitive)) + match tensor { + TensorPrimitive::Float(tensor) => { + TensorPrimitive::Float(B::float_select(tensor, dim, indices.primitive)) + } + TensorPrimitive::QFloat(tensor) => { + TensorPrimitive::QFloat(B::q_select(tensor, dim, indices.primitive)) + } + } } fn select_assign( @@ -2648,12 +2758,20 @@ impl Numeric for Float { indices: Tensor, values: Self::Primitive, ) -> Self::Primitive { - TensorPrimitive::Float(B::float_select_assign( - tensor.tensor(), - dim, - indices.primitive, - values.tensor(), - )) + match (tensor, values) { + (TensorPrimitive::Float(tensor), TensorPrimitive::Float(values)) => { + TensorPrimitive::Float(B::float_select_assign( + tensor, + dim, + indices.primitive, + values, + )) + } + (TensorPrimitive::QFloat(tensor), TensorPrimitive::QFloat(values)) => { + TensorPrimitive::QFloat(B::q_select_assign(tensor, dim, indices.primitive, values)) + } + _ => panic!("Primitive type mismatch for tensor and values"), + } } fn gather( @@ -2661,7 +2779,14 @@ impl Numeric for Float { tensor: Self::Primitive, indices: Tensor, ) -> Self::Primitive { - TensorPrimitive::Float(B::float_gather(dim, tensor.tensor(), indices.primitive)) + match tensor { + TensorPrimitive::Float(tensor) => { + TensorPrimitive::Float(B::float_gather(dim, tensor, indices.primitive)) + } + TensorPrimitive::QFloat(tensor) => { + TensorPrimitive::QFloat(B::q_gather(dim, tensor, indices.primitive)) + } + } } fn scatter( @@ -2670,58 +2795,95 @@ impl Numeric for Float { indices: Tensor, values: Self::Primitive, ) -> Self::Primitive { - TensorPrimitive::Float(B::float_scatter( - dim, - tensor.tensor(), - indices.primitive, - values.tensor(), - )) + match (tensor, values) { + (TensorPrimitive::Float(tensor), TensorPrimitive::Float(values)) => { + TensorPrimitive::Float(B::float_scatter(dim, tensor, indices.primitive, values)) + } + (TensorPrimitive::QFloat(tensor), TensorPrimitive::QFloat(values)) => { + TensorPrimitive::QFloat(B::q_scatter(dim, tensor, indices.primitive, values)) + } + _ => panic!("Primitive type mismatch for tensor and values"), + } } fn argmax( tensor: Self::Primitive, dim: usize, ) -> ::IntTensorPrimitive { - B::float_argmax(tensor.tensor(), dim) + match tensor { + TensorPrimitive::Float(tensor) => B::float_argmax(tensor, dim), + TensorPrimitive::QFloat(tensor) => B::q_argmax(tensor, dim), + } } fn argmin( tensor: Self::Primitive, dim: usize, ) -> ::IntTensorPrimitive { - B::float_argmin(tensor.tensor(), dim) + match tensor { + TensorPrimitive::Float(tensor) => B::float_argmin(tensor, dim), + TensorPrimitive::QFloat(tensor) => B::q_argmin(tensor, dim), + } } fn max(tensor: Self::Primitive) -> Self::Primitive<1> { - TensorPrimitive::Float(B::float_max(tensor.tensor())) + match tensor { + TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max(tensor)), + TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max(tensor)), + } } fn max_dim(tensor: Self::Primitive, dim: usize) -> 
Self::Primitive { - TensorPrimitive::Float(B::float_max_dim(tensor.tensor(), dim)) + match tensor { + TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_dim(tensor, dim)), + TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_dim(tensor, dim)), + } } fn max_dim_with_indices( tensor: Self::Primitive, dim: usize, ) -> (Self::Primitive, ::IntTensorPrimitive) { - let (tensor, indices) = B::float_max_dim_with_indices(tensor.tensor(), dim); - (TensorPrimitive::Float(tensor), indices) + match tensor { + TensorPrimitive::Float(tensor) => { + let (values, indices) = B::float_max_dim_with_indices(tensor, dim); + (TensorPrimitive::Float(values), indices) + } + TensorPrimitive::QFloat(tensor) => { + let (values, indices) = B::q_max_dim_with_indices(tensor, dim); + (TensorPrimitive::QFloat(values), indices) + } + } } fn min(tensor: Self::Primitive) -> Self::Primitive<1> { - TensorPrimitive::Float(B::float_min(tensor.tensor())) + match tensor { + TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min(tensor)), + TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min(tensor)), + } } fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { - TensorPrimitive::Float(B::float_min_dim(tensor.tensor(), dim)) + match tensor { + TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min_dim(tensor, dim)), + TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min_dim(tensor, dim)), + } } fn min_dim_with_indices( tensor: Self::Primitive, dim: usize, ) -> (Self::Primitive, ::IntTensorPrimitive) { - let (tensor, indices) = B::float_min_dim_with_indices(tensor.tensor(), dim); - (TensorPrimitive::Float(tensor), indices) + match tensor { + TensorPrimitive::Float(tensor) => { + let (values, indices) = B::float_min_dim_with_indices(tensor, dim); + (TensorPrimitive::Float(values), indices) + } + TensorPrimitive::QFloat(tensor) => { + let (values, indices) = B::q_min_dim_with_indices(tensor, dim); + (TensorPrimitive::QFloat(values), indices) + } + } } fn clamp( @@ -2729,53 +2891,103 @@ impl Numeric for Float { min: B::FloatElem, max: B::FloatElem, ) -> Self::Primitive { - TensorPrimitive::Float(B::float_clamp(tensor.tensor(), min, max)) + match tensor { + TensorPrimitive::Float(tensor) => { + TensorPrimitive::Float(B::float_clamp(tensor, min, max)) + } + TensorPrimitive::QFloat(tensor) => { + TensorPrimitive::QFloat(B::q_clamp(tensor, min, max)) + } + } } fn clamp_min( tensor: Self::Primitive, min: B::FloatElem, ) -> Self::Primitive { - TensorPrimitive::Float(B::float_clamp_min(tensor.tensor(), min)) + match tensor { + TensorPrimitive::Float(tensor) => { + TensorPrimitive::Float(B::float_clamp_min(tensor, min)) + } + TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_clamp_min(tensor, min)), + } } fn clamp_max( tensor: Self::Primitive, max: B::FloatElem, ) -> Self::Primitive { - TensorPrimitive::Float(B::float_clamp_max(tensor.tensor(), max)) + match tensor { + TensorPrimitive::Float(tensor) => { + TensorPrimitive::Float(B::float_clamp_max(tensor, max)) + } + TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_clamp_max(tensor, max)), + } } fn abs(tensor: Self::Primitive) -> Self::Primitive { - TensorPrimitive::Float(B::float_abs(tensor.tensor())) + match tensor { + TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_abs(tensor)), + TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_abs(tensor)), + } } fn powf( lhs: Self::Primitive, rhs: Self::Primitive, ) -> Self::Primitive 
{ - TensorPrimitive::Float(B::float_powf(lhs.tensor(), rhs.tensor())) + match (lhs, rhs) { + (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => { + TensorPrimitive::Float(B::float_powf(lhs, rhs)) + } + (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => { + TensorPrimitive::QFloat(B::q_powf(lhs, rhs)) + } + _ => panic!("Primitive type mismatch for lhs and rhs"), + } } fn powf_scalar( lhs: Self::Primitive, rhs: E, ) -> Self::Primitive { - TensorPrimitive::Float(B::float_powf_scalar(lhs.tensor(), rhs.elem())) + match lhs { + TensorPrimitive::Float(lhs) => { + TensorPrimitive::Float(B::float_powf_scalar(lhs, rhs.elem())) + } + TensorPrimitive::QFloat(lhs) => { + TensorPrimitive::QFloat(B::q_powf_scalar(lhs, rhs.elem())) + } + } } fn powi( lhs: Self::Primitive, rhs: Self::Primitive, ) -> Self::Primitive { - TensorPrimitive::Float(B::float_powf(lhs.tensor(), rhs.tensor())) + match (lhs, rhs) { + (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => { + TensorPrimitive::Float(B::float_powf(lhs, rhs)) + } + (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => { + TensorPrimitive::QFloat(B::q_powf(lhs, rhs)) + } + _ => panic!("Primitive type mismatch for lhs and rhs"), + } } fn powi_scalar( lhs: Self::Primitive, rhs: E, ) -> Self::Primitive { - TensorPrimitive::Float(B::float_powf_scalar(lhs.tensor(), rhs.elem())) + match lhs { + TensorPrimitive::Float(lhs) => { + TensorPrimitive::Float(B::float_powf_scalar(lhs, rhs.elem())) + } + TensorPrimitive::QFloat(lhs) => { + TensorPrimitive::QFloat(B::q_powf_scalar(lhs, rhs.elem())) + } + } } fn random( @@ -2795,7 +3007,14 @@ impl Numeric for Float { dim: usize, descending: bool, ) -> Self::Primitive { - TensorPrimitive::Float(B::float_sort(tensor.tensor(), dim, descending)) + match tensor { + TensorPrimitive::Float(tensor) => { + TensorPrimitive::Float(B::float_sort(tensor, dim, descending)) + } + TensorPrimitive::QFloat(tensor) => { + TensorPrimitive::QFloat(B::q_sort(tensor, dim, descending)) + } + } } fn sort_with_indices( @@ -2803,8 +3022,16 @@ impl Numeric for Float { dim: usize, descending: bool, ) -> (Self::Primitive, >::Primitive) { - let (tensor, indices) = B::float_sort_with_indices(tensor.tensor(), dim, descending); - (TensorPrimitive::Float(tensor), indices) + match tensor { + TensorPrimitive::Float(tensor) => { + let (values, indices) = B::float_sort_with_indices(tensor, dim, descending); + (TensorPrimitive::Float(values), indices) + } + TensorPrimitive::QFloat(tensor) => { + let (values, indices) = B::q_sort_with_indices(tensor, dim, descending); + (TensorPrimitive::QFloat(values), indices) + } + } } fn argsort( @@ -2812,7 +3039,10 @@ impl Numeric for Float { dim: usize, descending: bool, ) -> >::Primitive { - B::float_argsort(tensor.tensor(), dim, descending) + match tensor { + TensorPrimitive::Float(tensor) => B::float_argsort(tensor, dim, descending), + TensorPrimitive::QFloat(tensor) => B::q_argsort(tensor, dim, descending), + } } } diff --git a/crates/burn-tensor/src/tensor/ops/qtensor.rs b/crates/burn-tensor/src/tensor/ops/qtensor.rs index 66fa7126ab..c94668738a 100644 --- a/crates/burn-tensor/src/tensor/ops/qtensor.rs +++ b/crates/burn-tensor/src/tensor/ops/qtensor.rs @@ -163,6 +163,29 @@ pub trait QTensorOps { false } + /// Repeat the tensor along the given dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `dim` - The dimension to repeat. + /// * `times` - The number of times to repeat the dimension. 
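+ /// + /// The provided default goes through a dequantize -> `float_repeat_dim` -> quantize round + /// trip via the `dequant_op_quant!` macro below, so a backend can override it with a native + /// quantized kernel where available.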
+ /// + /// # Returns + /// + /// The tensor with the given dimension repeated. + fn q_repeat_dim( + tensor: QuantizedTensor, + dim: usize, + times: usize, + ) -> QuantizedTensor { + dequant_op_quant!( + ty Self, + float_op |tensor| B::float_repeat_dim(tensor, dim, times), + tensor + ) + } + /// Adds two tensors together. /// /// # Arguments @@ -866,7 +889,7 @@ pub trait QTensorOps { ) } - /// Element-wise power with a FloatTensor. + /// Element-wise power with another tensor. /// /// # Arguments /// @@ -878,12 +901,13 @@ pub trait QTensorOps { /// The elements of `lhs` raised to the power of the elements of `rhs`. fn q_powf( lhs: QuantizedTensor, - rhs: FloatTensor, + rhs: QuantizedTensor, ) -> QuantizedTensor { dequant_op_quant!( ty Self, - float_op |tensor| B::float_powf(tensor, rhs), - lhs + float_op |lhs, rhs| B::float_powf(lhs, rhs), + lhs, + rhs ) } From 07ba2608d676995956fac82cb858b5866598a139 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 6 Aug 2024 15:10:32 -0400 Subject: [PATCH 10/16] Add tch q_min/max_dim_with_indices --- crates/burn-tch/src/ops/qtensor.rs | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/crates/burn-tch/src/ops/qtensor.rs b/crates/burn-tch/src/ops/qtensor.rs index 935451e59b..2157bcd8fa 100644 --- a/crates/burn-tch/src/ops/qtensor.rs +++ b/crates/burn-tch/src/ops/qtensor.rs @@ -233,6 +233,18 @@ impl QTensorOps for LibTorch { ) } + fn q_max_dim_with_indices( + tensor: QuantizedTensor, + dim: usize, + ) -> (QuantizedTensor, IntTensor) { + let (qtensor, indices) = TchOps::max_dim_with_indices(tensor.qtensor, dim); + let values = TchQTensor { + qtensor, + scheme: tensor.scheme, + }; + (values, indices) + } + fn q_max_dim( tensor: QuantizedTensor, dim: usize, @@ -253,6 +265,18 @@ impl QTensorOps for LibTorch { } } + fn q_min_dim_with_indices( + tensor: QuantizedTensor, + dim: usize, + ) -> (QuantizedTensor, IntTensor) { + let (qtensor, indices) = TchOps::min_dim_with_indices(tensor.qtensor, dim); + let values = TchQTensor { + qtensor, + scheme: tensor.scheme, + }; + (values, indices) + } + fn q_narrow( tensor: QuantizedTensor, dim: usize, From e5c23301d0323e656662d8c7177844662abf7470 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Wed, 24 Jul 2024 15:22:12 -0400 Subject: [PATCH 11/16] Add q_ ops tests --- crates/burn-ndarray/src/lib.rs | 1 + crates/burn-tch/src/lib.rs | 6 +- crates/burn-tensor/src/tests/mod.rs | 51 +++ .../burn-tensor/src/tests/quantization/mod.rs | 1 + .../src/tests/quantization/ops/abs.rs | 25 ++ .../src/tests/quantization/ops/add.rs | 163 +++++++++ .../src/tests/quantization/ops/aggregation.rs | 233 +++++++++++++ .../src/tests/quantization/ops/all.rs | 45 +++ .../src/tests/quantization/ops/any.rs | 48 +++ .../src/tests/quantization/ops/arg.rs | 74 ++++ .../src/tests/quantization/ops/cat.rs | 137 ++++++++ .../src/tests/quantization/ops/chunk.rs | 133 ++++++++ .../src/tests/quantization/ops/clamp.rs | 63 ++++ .../src/tests/quantization/ops/cos.rs | 26 ++ .../src/tests/quantization/ops/div.rs | 82 +++++ .../src/tests/quantization/ops/erf.rs | 49 +++ .../src/tests/quantization/ops/exp.rs | 27 ++ .../src/tests/quantization/ops/expand.rs | 153 +++++++++ .../src/tests/quantization/ops/flip.rs | 60 ++++ .../tests/quantization/ops/gather_scatter.rs | 293 ++++++++++++++++ .../src/tests/quantization/ops/log.rs | 29 ++ .../src/tests/quantization/ops/log1p.rs | 29 ++ .../tests/quantization/ops/map_comparison.rs | 271 +++++++++++++++ .../src/tests/quantization/ops/mask.rs | 63 ++++ 
.../src/tests/quantization/ops/matmul.rs | 314 +++++++++++++++++ .../src/tests/quantization/ops/maxmin.rs | 211 ++++++++++++ .../src/tests/quantization/ops/mod.rs | 43 +++ .../src/tests/quantization/ops/mul.rs | 105 ++++++ .../src/tests/quantization/ops/narrow.rs | 90 +++++ .../src/tests/quantization/ops/neg.rs | 31 ++ .../src/tests/quantization/ops/permute.rs | 82 +++++ .../src/tests/quantization/ops/powf.rs | 120 +++++++ .../src/tests/quantization/ops/powf_scalar.rs | 89 +++++ .../src/tests/quantization/ops/quantize.rs | 100 ++++++ .../src/tests/quantization/ops/recip.rs | 26 ++ .../src/tests/quantization/ops/remainder.rs | 162 +++++++++ .../src/tests/quantization/ops/repeat_dim.rs | 50 +++ .../src/tests/quantization/ops/reshape.rs | 111 ++++++ .../src/tests/quantization/ops/select.rs | 174 ++++++++++ .../src/tests/quantization/ops/sin.rs | 25 ++ .../src/tests/quantization/ops/slice.rs | 323 ++++++++++++++++++ .../tests/quantization/ops/sort_argsort.rs | 267 +++++++++++++++ .../src/tests/quantization/ops/sqrt.rs | 26 ++ .../src/tests/quantization/ops/stack.rs | 138 ++++++++ .../src/tests/quantization/ops/sub.rs | 81 +++++ .../src/tests/quantization/ops/tanh.rs | 25 ++ .../src/tests/quantization/ops/topk.rs | 91 +++++ .../src/tests/quantization/ops/transpose.rs | 53 +++ 48 files changed, 4795 insertions(+), 4 deletions(-) create mode 100644 crates/burn-tensor/src/tests/quantization/ops/abs.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/add.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/aggregation.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/all.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/any.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/arg.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/cat.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/chunk.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/clamp.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/cos.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/div.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/erf.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/exp.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/expand.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/flip.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/gather_scatter.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/log.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/log1p.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/map_comparison.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/mask.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/matmul.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/maxmin.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/mod.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/mul.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/narrow.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/neg.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/permute.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/powf.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/powf_scalar.rs create mode 100644 
crates/burn-tensor/src/tests/quantization/ops/quantize.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/recip.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/remainder.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/repeat_dim.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/reshape.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/select.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/sin.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/slice.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/sort_argsort.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/sqrt.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/stack.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/sub.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/tanh.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/topk.rs create mode 100644 crates/burn-tensor/src/tests/quantization/ops/transpose.rs diff --git a/crates/burn-ndarray/src/lib.rs b/crates/burn-ndarray/src/lib.rs index bafc4ba502..3784caeb5e 100644 --- a/crates/burn-ndarray/src/lib.rs +++ b/crates/burn-ndarray/src/lib.rs @@ -39,6 +39,7 @@ mod tests { use alloc::vec; burn_tensor::testgen_all!(); + burn_tensor::testgen_quantization!(); #[cfg(feature = "std")] burn_autodiff::testgen_all!(); diff --git a/crates/burn-tch/src/lib.rs b/crates/burn-tch/src/lib.rs index 293e2743d4..c15e98b5c3 100644 --- a/crates/burn-tch/src/lib.rs +++ b/crates/burn-tch/src/lib.rs @@ -22,10 +22,8 @@ mod tests { type TestTensor = burn_tensor::Tensor; type TestTensorInt = burn_tensor::Tensor; type TestTensorBool = burn_tensor::Tensor; - // type TestBackendQInt8 = crate::LibTorch; - // type TestTensorQInt8 = burn_tensor::Tensor; - // burn_tensor::testgen_all!(); - // burn_autodiff::testgen_all!(); + burn_tensor::testgen_all!(); + burn_autodiff::testgen_all!(); burn_tensor::testgen_quantization!(); } diff --git a/crates/burn-tensor/src/tests/mod.rs b/crates/burn-tensor/src/tests/mod.rs index 9512eae838..86966b3c04 100644 --- a/crates/burn-tensor/src/tests/mod.rs +++ b/crates/burn-tensor/src/tests/mod.rs @@ -116,9 +116,60 @@ macro_rules! testgen_all { // test padding burn_tensor::testgen_padding!(); + }; +} +#[allow(missing_docs)] +#[macro_export] +macro_rules! 
testgen_quantization { + () => { // test quantization burn_tensor::testgen_calibration!(); burn_tensor::testgen_scheme!(); + burn_tensor::testgen_quantize!(); + + // test ops + burn_tensor::testgen_q_abs!(); + burn_tensor::testgen_q_add!(); + burn_tensor::testgen_q_aggregation!(); + burn_tensor::testgen_q_all!(); + burn_tensor::testgen_q_any!(); + burn_tensor::testgen_q_arg!(); + burn_tensor::testgen_q_cat!(); + burn_tensor::testgen_q_chunk!(); + burn_tensor::testgen_q_clamp!(); + burn_tensor::testgen_q_cos!(); + burn_tensor::testgen_q_div!(); + burn_tensor::testgen_q_erf!(); + burn_tensor::testgen_q_exp!(); + burn_tensor::testgen_q_expand!(); + burn_tensor::testgen_q_flip!(); + burn_tensor::testgen_q_gather_scatter!(); + burn_tensor::testgen_q_log!(); + burn_tensor::testgen_q_log1p!(); + burn_tensor::testgen_q_map_comparison!(); + burn_tensor::testgen_q_mask!(); + burn_tensor::testgen_q_matmul!(); + burn_tensor::testgen_q_maxmin!(); + burn_tensor::testgen_q_mul!(); + burn_tensor::testgen_q_narrow!(); + burn_tensor::testgen_q_neg!(); + burn_tensor::testgen_q_permute!(); + burn_tensor::testgen_q_powf_scalar!(); + burn_tensor::testgen_q_powf!(); + burn_tensor::testgen_q_recip!(); + burn_tensor::testgen_q_remainder!(); + burn_tensor::testgen_q_repeat_dim!(); + burn_tensor::testgen_q_reshape!(); + burn_tensor::testgen_q_select!(); + burn_tensor::testgen_q_sin!(); + burn_tensor::testgen_q_slice!(); + burn_tensor::testgen_q_sort_argsort!(); + burn_tensor::testgen_q_sqrt!(); + burn_tensor::testgen_q_stack!(); + burn_tensor::testgen_q_sub!(); + burn_tensor::testgen_q_tanh!(); + burn_tensor::testgen_q_topk!(); + burn_tensor::testgen_q_transpose!(); }; } diff --git a/crates/burn-tensor/src/tests/quantization/mod.rs b/crates/burn-tensor/src/tests/quantization/mod.rs index 36539db297..bc9bec8673 100644 --- a/crates/burn-tensor/src/tests/quantization/mod.rs +++ b/crates/burn-tensor/src/tests/quantization/mod.rs @@ -1,2 +1,3 @@ mod calibration; +mod ops; mod scheme; diff --git a/crates/burn-tensor/src/tests/quantization/ops/abs.rs b/crates/burn-tensor/src/tests/quantization/ops/abs.rs new file mode 100644 index 0000000000..e02823f130 --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/abs.rs @@ -0,0 +1,25 @@ +#[burn_tensor_testgen::testgen(q_abs)] +mod tests { + use super::*; + use burn_tensor::quantization::{QuantizationStrategy, SymmetricQuantization}; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn should_support_abs_ops() { + // Quantized [[0.0, -1.0, 2.0], [3.0, 4.0, -5.0]] + let data = TensorData::quantized( + vec![0i8, -25, 51, 76, 102, -127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.abs(); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]), 1); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/add.rs b/crates/burn-tensor/src/tests/quantization/ops/add.rs new file mode 100644 index 0000000000..188ad8b005 --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/add.rs @@ -0,0 +1,163 @@ +#[burn_tensor_testgen::testgen(q_add)] +mod tests { + use super::*; + use burn_tensor::quantization::{QuantizationStrategy, SymmetricQuantization}; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn test_add_d2() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = 
TensorData::quantized( + vec![0i8, 25, 51, 76, 102, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor_1 = TestTensor::<2>::from_data(data, &Default::default()); + // Quantized [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]] + let data = TensorData::quantized( + vec![69i8, 81, 92, 104, 115, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.08661418)), + ); + let tensor_2 = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor_1 + tensor_2; + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&TensorData::from([[6.0, 8.0, 10.0], [12.0, 14.0, 16.0]]), 1); + } + + #[test] + fn test_add_broadcast() { + // Quantized [[0.0, 1.0, 2.0]] + let data = TensorData::quantized( + vec![0i8, 64, 127], + [1, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.015748031)), + ); + let tensor_1 = TestTensor::<2>::from_data(data, &Default::default()); + // Quantized [[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]] + let data = TensorData::quantized( + vec![48i8, 64, 79, 95, 111, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.062992126)), + ); + let tensor_2 = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor_1 + tensor_2; + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&TensorData::from([[3.0, 5.0, 7.0], [6.0, 8.0, 10.0]]), 1); + } + + #[test] + fn test_add_different_strides_rhs() { + // Quantized [[0.0, 1.0], [2.0, 3.0]] + let data = TensorData::quantized( + vec![0i8, 42, 85, 127], + [2, 2], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.023622047)), + ); + // We need to execute an operation after `from_data` to trigger the in-place path in some + // backends, since that is the operation that might be problematic in this case. + let tensor_1 = TestTensor::<2>::from_data(data, &Default::default()) * 1; + // Quantized [[4.0, 5.0], [6.0, 7.0]] + let data = TensorData::quantized( + vec![73i8, 91, 109, 127], + [2, 2], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.05511811)), + ); + let tensor_2 = TestTensor::<2>::from_data(data, &Default::default()) * 1; + + let output = tensor_1 + tensor_2.transpose(); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&TensorData::from([[4.0, 7.0], [7.0, 10.0]]), 1); + } + + #[test] + fn test_add_different_strides_lhs() { + // Quantized [[0.0, 1.0], [2.0, 3.0]] + let data = TensorData::quantized( + vec![0i8, 42, 85, 127], + [2, 2], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.023622047)), + ); + // We need to execute an operation after `from_data` to trigger the in-place path in some + // backends, since that is the operation that might be problematic in this case.
+ let tensor_1 = TestTensor::<2>::from_data(data, &Default::default()) * 1; + // Quantized [[4.0, 5.0], [6.0, 7.0]] + let data = TensorData::quantized( + vec![73i8, 91, 109, 127], + [2, 2], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.05511811)), + ); + let tensor_2 = TestTensor::<2>::from_data(data, &Default::default()) * 1; + + let output = tensor_1.transpose() + tensor_2; + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&TensorData::from([[4.0, 7.0], [7.0, 10.0]]), 1); + } + + #[test] + fn test_add_different_strides_broadcast() { + // Quantized [[0.0, 1.0], [2.0, 3.0]] + let data = TensorData::quantized( + vec![0i8, 42, 85, 127], + [2, 2], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.023622047)), + ); + // We need to execute an operation after `from_data` to trigger the in-place path in some + // backends, since that is the operation that might be problematic in this case. + let tensor_1 = TestTensor::<2>::from_data(data, &Default::default()) * 1; + // Quantized [[4.0, 5.0]] + let data = TensorData::quantized( + vec![102i8, 127], + [1, 2], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor_2 = TestTensor::<2>::from_data(data, &Default::default()) * 1; + + let output = tensor_1.transpose() + tensor_2; + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&TensorData::from([[4.0, 7.0], [5.0, 8.0]]), 1); + } + + #[test] + fn should_support_add_scalar_ops() { + let scalar = 2.0; + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![0i8, 25, 51, 76, 102, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor + scalar; + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&TensorData::from([[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]]), 1); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/aggregation.rs b/crates/burn-tensor/src/tests/quantization/ops/aggregation.rs new file mode 100644 index 0000000000..448d0a78a4 --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/aggregation.rs @@ -0,0 +1,233 @@ +#[burn_tensor_testgen::testgen(q_aggregation)] +mod tests { + use super::*; + use burn_tensor::quantization::{ + AffineQuantization, QuantizationStrategy, SymmetricQuantization, + }; + use burn_tensor::{Shape, Tensor, TensorData}; + + #[test] + fn test_should_mean() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![0i8, 25, 51, 76, 102, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.mean(); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&TensorData::from([15.0 / 6.0]), 1); + } + + #[test] + fn test_should_sum() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![0i8, 25, 51, 76, 102, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.sum(); +
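+ // For reference: these i8 fixtures follow the per-tensor symmetric int8 scheme, where + // scale = max|x| / 127 (here 5.0 / 127 ≈ 0.03937008) and q = round(x / scale), e.g. + // 1.0 -> 25, 2.0 -> 51, 5.0 -> 127; dequantizing as q * scale recovers each value to + // within scale / 2, which is why these tests only assert with precision 1.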
+ // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&TensorData::from([15.0]), 1); + } + + #[test] + fn test_should_mean_last_dim() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![0i8, 25, 51, 76, 102, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.mean_dim(1); + let expected = TensorData::from([[3.0 / 3.0], [12.0 / 3.0]]); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } + + #[test] + fn test_should_sum_last_dim() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![0i8, 25, 51, 76, 102, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.sum_dim(1); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&TensorData::from([[3.0], [12.0]]), 1); + } + + #[test] + fn test_should_sum_first_dim() { + // Quantized [[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]] + let data = TensorData::quantized( + vec![95i8, 32, 64, 127, 64, 95], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.031496063)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.sum_dim(0); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&TensorData::from([[7.0, 3.0, 5.0]]), 1); + } + + #[test] + fn test_should_mean_first_dim() { + // Quantized [[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]] + let data = TensorData::quantized( + vec![95i8, 32, 64, 127, 64, 95], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.031496063)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.mean_dim(0); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&TensorData::from([[7.0 / 2.0, 3.0 / 2.0, 5.0 / 2.0]]), 1); + } + + #[test] + fn test_should_sum_mid_dim_3d_non_contiguous_1() { + // Quantized [ + // [[2.0, 4.0, 1.0], [7.0, -5.0, 3.0]], + // [[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]], + // ] + let data = TensorData::quantized( + vec![36i8, 73, 18, 127, -91, 54, 54, 18, 36, 73, 36, 54], + [2, 2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.05511811)), + ); + let tensor = TestTensor::<3>::from_data(data, &Default::default()); + + let output = tensor.swap_dims(0, 2).sum_dim(1); + + // Precision 1 to approximate de/quantization errors + output.dequantize().into_data().assert_approx_eq( + &TensorData::new(vec![9.0, 7.0, -1.0, 3.0, 4.0, 5.0], [3, 1, 2]), + 1, + ); + } + + #[test] + fn test_should_sum_mid_dim_3d_non_contiguous_2() { + // Quantized [ + // [[2.0, 4.0, 1.0], [7.0, -5.0, 3.0]], + // [[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]], + // ] + let data = TensorData::quantized( + vec![36i8, 73, 18, 127, -91, 54, 54, 18, 36, 73, 36, 54], + [2, 2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.05511811)), + ); + let tensor = TestTensor::<3>::from_data(data, &Default::default()); + + let output = tensor.swap_dims(0, 
1).sum_dim(1); + + // Precision 1 to approximate de/quantization errors + output.dequantize().into_data().assert_approx_eq( + &TensorData::new(vec![5.0, 5.0, 3.0, 11.0, -3.0, 6.0], [2, 1, 3]), + 1, + ); + } + + #[test] + fn test_prod_float() { + // Quantized [[2.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + // NOTE: we use affine quantization to reduce quantization errors since `prod()` amplifies the error + let data = TensorData::quantized( + vec![-26i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + let output = tensor.prod(); + + output + .dequantize() + .into_data() + .assert_eq(&TensorData::from([240.0]), false); + + // Quantized [[2.0, 0.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![51i8, 0, 51, 76, 102, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor_with_zero = TestTensor::<2>::from_data(data, &Default::default()); + let output = tensor_with_zero.prod(); + + output + .dequantize() + .into_data() + .assert_eq(&TensorData::from([0.0]), false); + } + + #[test] + fn test_prod_dim_float() { + // Quantized [[2.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![51i8, 25, 51, 76, 102, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + let output = tensor.prod_dim(1); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&TensorData::from([[4.0], [60.0]]), 1); + + // Quantized [[2.0, 0.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![51i8, 0, 51, 76, 102, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor_with_zero = TestTensor::<2>::from_data(data, &Default::default()); + let output = tensor_with_zero.prod_dim(1); + let expected = TensorData::from([[0.0], [60.0]]); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/all.rs b/crates/burn-tensor/src/tests/quantization/ops/all.rs new file mode 100644 index 0000000000..0789b28207 --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/all.rs @@ -0,0 +1,45 @@ +#[burn_tensor_testgen::testgen(q_all)] +mod tests { + use super::*; + use burn_tensor::quantization::{QuantizationStrategy, SymmetricQuantization}; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn test_all() { + // Quantized [[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]] + let data = TensorData::quantized( + vec![0i8, 127, 0, 127, -127, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.007874016)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + let data_actual = tensor.all().into_data(); + let data_expected = TensorData::from([false]); + assert_eq!(data_expected, data_actual); + + // Quantized [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] + let data = TensorData::quantized( + vec![127i8, 127, 127, 127, 127, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.007874016)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + let data_actual = tensor.all().into_data(); + let 
data_expected = TensorData::from([true]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn test_all_dim() { + // Quantized [[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]] + let data = TensorData::quantized( + vec![0i8, 127, 0, 127, -127, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.007874016)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + let data_actual = tensor.all_dim(1).into_data(); + let data_expected = TensorData::from([[false], [true]]); + assert_eq!(data_expected, data_actual); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/any.rs b/crates/burn-tensor/src/tests/quantization/ops/any.rs new file mode 100644 index 0000000000..3146488ed0 --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/any.rs @@ -0,0 +1,48 @@ +#[burn_tensor_testgen::testgen(q_any)] +mod tests { + use super::*; + use burn_tensor::quantization::{QuantizationStrategy, SymmetricQuantization}; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn test_any() { + // Quantized [[0.0, 0.0, 0.0], [1.0, -1.0, 1.0]] + let data = TensorData::quantized( + vec![0i8, 0, 0, 127, -127, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.007874016)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let data_actual = tensor.any().into_data(); + let data_expected = TensorData::from([true]); + assert_eq!(data_expected, data_actual); + + // Quantized [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]] + let data = TensorData::quantized( + vec![0i8, 0, 0, 0, 0, 0], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.007874016)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let data_actual = tensor.any().into_data(); + let data_expected = TensorData::from([false]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn test_any_dim() { + // Quantized [[0.0, 0.0, 0.0], [1.0, -1.0, 1.0]] + let data = TensorData::quantized( + vec![0i8, 0, 0, 127, -127, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.007874016)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let data_actual = tensor.any_dim(1).into_data(); + let data_expected = TensorData::from([[false], [true]]); + assert_eq!(data_expected, data_actual); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/arg.rs b/crates/burn-tensor/src/tests/quantization/ops/arg.rs new file mode 100644 index 0000000000..925b91ab2d --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/arg.rs @@ -0,0 +1,74 @@ +#[burn_tensor_testgen::testgen(q_arg)] +mod tests { + use super::*; + use burn_tensor::quantization::{QuantizationStrategy, SymmetricQuantization}; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn test_argmax_2d_dim0() { + // Quantized [[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![115i8, 127, 23, 35, 46, 58], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.08661418)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.argmax(0); + + output + .into_data() + .assert_eq(&TensorData::from([[0, 0, 1]]), false); + } + + #[test] + fn test_argmin_2d_dim0() { + // Quantized [[10.0, 11.0, 2.0], [30.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![42i8, 47, 8, 127, 17, 21], + [2, 3], +
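// scale = max|x| / 127 = 30.0 / 127 ≈ 0.23622048 for this fixture +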
QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.23622048)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.argmin(0); + + output + .into_data() + .assert_eq(&TensorData::from([[0, 1, 0]]), false); + } + + #[test] + fn test_argmax_2d_dim1() { + // Quantized [[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![115i8, 127, 23, 35, 46, 58], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.08661418)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.argmax(1); + + output + .into_data() + .assert_eq(&TensorData::from([[1], [2]]), false); + } + + #[test] + fn test_argmin_2d_dim1() { + // Quantized [[10.0, 11.0, 2.0], [30.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![42i8, 47, 8, 127, 17, 21], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.23622048)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.argmin(1); + + output + .into_data() + .assert_eq(&TensorData::from([[2], [1]]), false); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/cat.rs b/crates/burn-tensor/src/tests/quantization/ops/cat.rs new file mode 100644 index 0000000000..3a7c04cb2a --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/cat.rs @@ -0,0 +1,137 @@ +#[burn_tensor_testgen::testgen(q_cat)] +mod tests { + use super::*; + use alloc::vec::Vec; + use burn_tensor::quantization::{QuantizationStrategy, SymmetricQuantization}; + use burn_tensor::{Bool, Int, Tensor, TensorData}; + + #[test] + fn should_support_cat_ops_2d_dim0() { + let device = Default::default(); + // Quantized [[1.0, 2.0, 3.0]] + let data = TensorData::quantized( + vec![42i8, 85, 127], + [1, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.023622047)), + ); + let tensor_1 = TestTensor::<2>::from_data(data, &device); + // Quantized [[4.0, 5.0, 6.0]] + let data = TensorData::quantized( + vec![85i8, 106, 127], + [1, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.047244094)), + ); + let tensor_2 = TestTensor::<2>::from_data(data, &device); + + let output = TestTensor::cat(vec![tensor_1, tensor_2], 0); + let expected = TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } + + #[test] + fn should_support_cat_ops_2d_dim1() { + let device = Default::default(); + // Quantized [[1.0, 2.0, 3.0]] + let data = TensorData::quantized( + vec![42i8, 85, 127], + [1, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.023622047)), + ); + let tensor_1 = TestTensor::<2>::from_data(data, &device); + // Quantized [[4.0, 5.0, 6.0]] + let data = TensorData::quantized( + vec![85i8, 106, 127], + [1, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.047244094)), + ); + let tensor_2 = TestTensor::<2>::from_data(data, &device); + + let output = TestTensor::cat(vec![tensor_1, tensor_2], 1); + let expected = TensorData::from([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } + + #[test] + fn should_support_cat_ops_3d() { + let device = Default::default(); + // Quantized [[[1.0, 2.0, 3.0]], [[1.1, 2.1, 
3.1]]] + let data = TensorData::quantized( + vec![41i8, 82, 123, 45, 86, 127], + [2, 1, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.024409449)), + ); + let tensor_1 = TestTensor::<3>::from_data(data, &device); + // Quantized [[4.0, 5.0, 6.0]] + let data = TensorData::quantized( + vec![85i8, 106, 127], + [1, 1, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.047244094)), + ); + let tensor_2 = TestTensor::<3>::from_data(data, &device); + + let output = TestTensor::cat(vec![tensor_1, tensor_2], 0); + let expected = TensorData::from([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]], [[4.0, 5.0, 6.0]]]); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } + + #[test] + #[should_panic] + fn should_panic_when_dimensions_are_not_the_same() { + let device = Default::default(); + // Quantized [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]] + let data = TensorData::quantized( + vec![42i8, 85, 127, 42, 85, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.023622047)), + ); + let tensor_1 = TestTensor::<2>::from_data(data, &device); + // Quantized [[4.0], [5.0]] + let data = TensorData::quantized( + vec![102i8, 127], + [2, 1], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor_2 = TestTensor::<2>::from_data(data, &device); + + let output = TestTensor::cat(vec![tensor_1, tensor_2], 0); + } + + #[test] + #[should_panic] + fn should_panic_when_cat_exceeds_dimension() { + let device = Default::default(); + // Quantized [[1.0, 2.0, 3.0]] + let data = TensorData::quantized( + vec![42i8, 85, 127], + [1, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.023622047)), + ); + let tensor_1 = TestTensor::<2>::from_data(data, &device); + // Quantized [[4.0, 5.0, 6.0]] + let data = TensorData::quantized( + vec![85i8, 106, 127], + [1, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.047244094)), + ); + let tensor_2 = TestTensor::<2>::from_data(data, &device); + + let output = TestTensor::cat(vec![tensor_1, tensor_2], 3); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/chunk.rs b/crates/burn-tensor/src/tests/quantization/ops/chunk.rs new file mode 100644 index 0000000000..71bec0395d --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/chunk.rs @@ -0,0 +1,133 @@ +#[burn_tensor_testgen::testgen(q_chunk)] +mod tests { + use super::*; + use alloc::vec::Vec; + use burn_tensor::quantization::{QuantizationStrategy, SymmetricQuantization}; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn test_chunk_evenly_divisible() { + // Quantized [0.0, 1.0, 2.0, 3.0, 4.0, 5.0] + let data = TensorData::quantized( + vec![0i8, 25, 51, 76, 102, 127], + [6], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor = TestTensor::<1>::from_data(data, &Default::default()); + + let tensors: Vec> = tensor.chunk(3, 0); + assert_eq!(tensors.len(), 3); + + let expected = vec![ + TensorData::from([0., 1.]), + TensorData::from([2., 3.]), + TensorData::from([4., 5.]), + ]; + + // Precision 1 to approximate de/quantization errors + for (index, tensor) in tensors.into_iter().enumerate() { + tensor + .dequantize() + .to_data() +
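// with per-tensor quantization each chunk should carry the same scale as the source + // tensor, so dequantizing a chunk recovers the corresponding slice of the original values +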
.assert_approx_eq(&expected[index], 1); + } + } + + #[test] + fn test_chunk_not_evenly_divisible() { + // Quantized [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + let data = TensorData::quantized( + vec![0i8, 21, 42, 64, 85, 106, 127], + [7], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.047244094)), + ); + let tensor = TestTensor::<1>::from_data(data, &Default::default()); + + let tensors: Vec> = tensor.chunk(4, 0); + assert_eq!(tensors.len(), 4); + + let expected = vec![ + TensorData::from([0., 1.]), + TensorData::from([2., 3.]), + TensorData::from([4., 5.]), + TensorData::from([6.]), + ]; + + // Precision 1 to approximate de/quantization errors + for (index, tensor) in tensors.into_iter().enumerate() { + tensor + .dequantize() + .to_data() + .assert_approx_eq(&expected[index], 1); + } + } + + #[test] + fn test_chunk_not_divisible() { + // Quantized [0.0, 1.0, 2.0, 3.0, 4.0, 5.0] + let data = TensorData::quantized( + vec![0i8, 25, 51, 76, 102, 127], + [6], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor = TestTensor::<1>::from_data(data, &Default::default()); + + let tensors: Vec> = tensor.chunk(7, 0); + assert_eq!(tensors.len(), 6); + + let expected = vec![ + TensorData::from([0.]), + TensorData::from([1.]), + TensorData::from([2.]), + TensorData::from([3.]), + TensorData::from([4.]), + TensorData::from([5.]), + ]; + + // Precision 1 to approximate de/quantization errors + for (index, tensor) in tensors.into_iter().enumerate() { + tensor + .dequantize() + .to_data() + .assert_approx_eq(&expected[index], 1); + } + } + + #[test] + fn test_chunk_multi_dimension() { + // Quantized [[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![0i8, 25, 51, 76, 102, 127], + [1, 6], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let tensors: Vec> = tensor.chunk(2, 1); + assert_eq!(tensors.len(), 2); + + let expected = vec![ + TensorData::from([[0., 1., 2.]]), + TensorData::from([[3., 4., 5.]]), + ]; + + // Precision 1 to approximate de/quantization errors + for (index, tensor) in tensors.into_iter().enumerate() { + tensor + .dequantize() + .to_data() + .assert_approx_eq(&expected[index], 1); + } + } + + #[test] + #[should_panic] + fn test_invalid_dim() { + // Quantized [0.0, 1.0, 2.0, 3.0, 4.0, 5.0] + let data = TensorData::quantized( + vec![0i8, 25, 51, 76, 102, 127], + [6], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensors = TestTensor::<1>::from_data(data, &Default::default()).chunk(6, 1); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/clamp.rs b/crates/burn-tensor/src/tests/quantization/ops/clamp.rs new file mode 100644 index 0000000000..d239eb6c2e --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/clamp.rs @@ -0,0 +1,63 @@ +#[burn_tensor_testgen::testgen(q_clamp)] +mod tests { + use super::*; + use burn_tensor::quantization::{QuantizationStrategy, SymmetricQuantization}; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn clamp_min() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![0i8, 25, 51, 76, 102, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.clamp_min(2.0); + + // Precision 
1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&TensorData::from([[2.0, 2.0, 2.0], [3.0, 4.0, 5.0]]), 1); + } + + #[test] + fn clamp_max() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![0i8, 25, 51, 76, 102, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.clamp_max(2.0); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&TensorData::from([[0.0, 1.0, 2.0], [2.0, 2.0, 2.0]]), 1); + } + + #[test] + fn clamp_min_max() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![0i8, 25, 51, 76, 102, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.clamp(1.0, 4.0); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 4.0]]), 1); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/cos.rs b/crates/burn-tensor/src/tests/quantization/ops/cos.rs new file mode 100644 index 0000000000..2af539a41f --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/cos.rs @@ -0,0 +1,26 @@ +#[burn_tensor_testgen::testgen(q_cos)] +mod tests { + use super::*; + use burn_tensor::quantization::{QuantizationStrategy, SymmetricQuantization}; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn should_support_cos_ops() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![0i8, 25, 51, 76, 102, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.cos(); + let expected = TensorData::from([[1.0, 0.5403, -0.4161], [-0.9899, -0.6536, 0.2836]]); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/div.rs b/crates/burn-tensor/src/tests/quantization/ops/div.rs new file mode 100644 index 0000000000..b485673557 --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/div.rs @@ -0,0 +1,82 @@ +#[burn_tensor_testgen::testgen(q_div)] +mod tests { + use super::*; + use burn_tensor::quantization::{QuantizationStrategy, SymmetricQuantization}; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn should_support_div_ops() { + let device = Default::default(); + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![0i8, 25, 51, 76, 102, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor_1 = TestTensor::<2>::from_data(data, &device); + // Quantized [[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![25i8, 25, 51, 76, 102, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor_2 = TestTensor::<2>::from_data(data, &device); + + let output = tensor_1 / tensor_2; + let expected = TensorData::from([[0.0, 1.0, 1.0], [1.0, 1.0, 1.0]]); + + // Precision 1 to 
approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } + + #[test] + fn test_div_broadcast() { + let device = Default::default(); + // Quantized [[0.0, 1.0, 2.0]] + let data = TensorData::quantized( + vec![0i8, 64, 127], + [1, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.015748031)), + ); + let tensor_1 = TestTensor::<2>::from_data(data, &device); + // Quantized [[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![25i8, 25, 51, 76, 102, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor_2 = TestTensor::<2>::from_data(data, &device); + + let output = tensor_1 / tensor_2; + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&TensorData::from([[0.0, 1.0, 1.0], [0.0, 0.25, 0.4]]), 1); + } + + #[test] + fn should_support_div_scalar_ops() { + let scalar = 2.0; + let device = Default::default(); + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![0i8, 25, 51, 76, 102, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor = TestTensor::<2>::from_data(data, &device); + + let output = tensor / scalar; + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&TensorData::from([[0.0, 0.5, 1.0], [1.5, 2.0, 2.5]]), 1); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/erf.rs b/crates/burn-tensor/src/tests/quantization/ops/erf.rs new file mode 100644 index 0000000000..03e20197ff --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/erf.rs @@ -0,0 +1,49 @@ +#[burn_tensor_testgen::testgen(q_erf)] +mod tests { + use super::*; + use burn_tensor::quantization::{QuantizationStrategy, SymmetricQuantization}; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn should_support_erf_ops() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![0i8, 25, 51, 76, 102, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.erf(); + let expected = TensorData::from([[0.0000, 0.8427, 0.9953], [1.0000, 1.0000, 1.0000]]); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } + + #[test] + fn should_support_erf_ops_with_negative_number() { + // Quantized [[-0.056, -0.043, -0.089], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-1i8, -1, -2, 76, 102, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.erf(); + let expected = TensorData::from([ + [-0.06312324, -0.048490416, -0.10016122], + [1.0000, 1.0000, 1.0000], + ]); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/exp.rs b/crates/burn-tensor/src/tests/quantization/ops/exp.rs new file mode 100644 index 0000000000..54aea097c3 --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/exp.rs @@ -0,0 +1,27 @@ 
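+// For reference, the affine int8 mapping behind this fixture appears to be +// q = round(x / scale) + zero_point, with scale = (max - min) / 255 = (5.0 - 0.0) / 255 ≈ 0.019607844 +// and zero_point = -128, so 0.0 -> -128, 1.0 -> -77, ..., 5.0 -> 127.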
+#[burn_tensor_testgen::testgen(q_exp)] +mod tests { + use super::*; + use burn_tensor::quantization::{AffineQuantization, QuantizationStrategy}; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn should_support_exp_ops() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + // NOTE: we use affine quantization to reduce quantization errors since `exp()` amplifies the error + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.exp(); + let expected = TensorData::from([[1.0, 2.71830, 7.3891], [20.0855, 54.5981, 148.4132]]); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/expand.rs b/crates/burn-tensor/src/tests/quantization/ops/expand.rs new file mode 100644 index 0000000000..fb1c530a47 --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/expand.rs @@ -0,0 +1,153 @@ +#[burn_tensor_testgen::testgen(q_expand)] +mod tests { + use super::*; + use burn_tensor::quantization::{QuantizationStrategy, SymmetricQuantization}; + use burn_tensor::{Shape, Tensor, TensorData}; + + #[test] + fn expand_2d() { + // Quantized [1.0, 2.0, 3.0] + let data = TensorData::quantized( + vec![42i8, 85, 127], + [3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.023622047)), + ); + let tensor = TestTensor::<1>::from_data(data, &Default::default()); + let output = tensor.expand([3, 3]); + + // Precision 1 to approximate de/quantization errors + output.dequantize().into_data().assert_approx_eq( + &TensorData::from([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), + 1, + ); + + // Quantized [4.0, 7.0, 2.0, 3.0] + let data = TensorData::quantized( + vec![73i8, 127, 36, 54], + [4], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.05511811)), + ); + let tensor = TestTensor::<1>::from_data(data, &Default::default()); + let output = tensor.expand([2, 4]); + + // Precision 1 to approximate de/quantization errors + output.dequantize().into_data().assert_approx_eq( + &TensorData::from([[4.0, 7.0, 2.0, 3.0], [4.0, 7.0, 2.0, 3.0]]), + 1, + ); + } + + #[test] + fn expand_3d() { + // Quantized [[1.0, 2.0], [3.0, 4.0]] + let data = TensorData::quantized( + vec![32i8, 64, 95, 127], + [2, 2], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.031496063)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + let output = tensor.expand([3, 2, 2]); + let expected = TensorData::from([ + [[1.0, 2.0], [3.0, 4.0]], + [[1.0, 2.0], [3.0, 4.0]], + [[1.0, 2.0], [3.0, 4.0]], + ]); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } + + #[test] + fn expand_higher_dimensions() { + // Quantized [[1.0, 2.0, 3.0, 4.0]] + let data = TensorData::quantized( + vec![32i8, 64, 95, 127], + [1, 4], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.031496063)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + let output = tensor.expand([2, 3, 4]); + let expected = TensorData::from([ + [ + [1.0, 2.0, 3.0, 4.0], + [1.0, 2.0, 3.0, 4.0], + [1.0, 2.0, 3.0, 4.0], + ], + [ + [1.0, 2.0, 3.0, 4.0], + [1.0, 2.0, 3.0, 4.0], + [1.0, 2.0, 3.0, 
4.0], + ], + ]); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } + + #[test] + fn broadcast_single() { + // Quantized [1.0] + let data = TensorData::quantized( + vec![127i8], + [1], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.007874016)), + ); + let tensor = TestTensor::<1>::from_data(data, &Default::default()); + let output = tensor.expand([2, 3]); + + output + .dequantize() + .into_data() + .assert_eq(&TensorData::from([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]), false); + } + + #[test] + #[should_panic] + fn should_fail_expand_incompatible_shapes() { + // Quantized [1.0, 2.0, 3.0] + let data = TensorData::quantized( + vec![42i8, 85, 127], + [3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.023622047)), + ); + let tensor = TestTensor::<1>::from_data(data, &Default::default()); + let _expanded_tensor = tensor.expand([2, 2]); + } + + #[test] + fn should_all_negative_one() { + // Quantized [1.0, 2.0, 3.0] + let data = TensorData::quantized( + vec![42i8, 85, 127], + [3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.023622047)), + ); + let tensor = TestTensor::<1>::from_data(data, &Default::default()); + let output = tensor.expand([2, -1]); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&TensorData::from([[1., 2., 3.], [1., 2., 3.]]), 1); + } + + #[test] + #[should_panic] + fn should_panic_negative_one_on_non_existing_dim() { + // Quantized [1.0, 2.0, 3.0] + let data = TensorData::quantized( + vec![42i8, 85, 127], + [3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.023622047)), + ); + let tensor = TestTensor::<1>::from_data(data, &Default::default()); + let _expanded_tensor = tensor.expand([-1, 3]); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/flip.rs b/crates/burn-tensor/src/tests/quantization/ops/flip.rs new file mode 100644 index 0000000000..de22ef124a --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/flip.rs @@ -0,0 +1,60 @@ +#[burn_tensor_testgen::testgen(q_flip)] +mod tests { + use super::*; + use burn_tensor::quantization::{QuantizationStrategy, SymmetricQuantization}; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn flip_float() { + // Quantized [[[0.0, 1.0, 2.0]], [[3.0, 4.0, 5.0]]] + let data = TensorData::quantized( + vec![0i8, 25, 51, 76, 102, 127], + [2, 1, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor = TestTensor::<3>::from_data(data, &Default::default()); + + let flipped = tensor.clone().flip([0, 2]); + let expected = TensorData::from([[[5., 4., 3.]], [[2., 1., 0.]]]); + + // Precision 1 to approximate de/quantization errors + flipped + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + + // Test with no flip + let flipped = tensor.clone().flip([]); + tensor.into_data().assert_eq(&flipped.into_data(), true); + } + + #[test] + #[should_panic] + fn flip_duplicated_axes() { + // Quantized [[[0.0, 1.0, 2.0]], [[3.0, 4.0, 5.0]]] + let data = TensorData::quantized( + vec![0i8, 25, 51, 76, 102, 127], + [2, 1, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor = TestTensor::<3>::from_data(data, &Default::default()); + + // Test with a duplicated axis + let _ = tensor.flip([0, 0, 1]); + } + + #[test] + #[should_panic] + fn 
flip_out_of_bound_axis() { + // Quantized [[[0.0, 1.0, 2.0]], [[3.0, 4.0, 5.0]]] + let data = TensorData::quantized( + vec![0i8, 25, 51, 76, 102, 127], + [2, 1, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor = TestTensor::<3>::from_data(data, &Default::default()); + + // Test with an out of bound axis + let _ = tensor.clone().flip([3, 0, 1]); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/gather_scatter.rs b/crates/burn-tensor/src/tests/quantization/ops/gather_scatter.rs new file mode 100644 index 0000000000..58cee7d4fa --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/gather_scatter.rs @@ -0,0 +1,293 @@ +#[burn_tensor_testgen::testgen(q_gather_scatter)] +mod tests { + use super::*; + use burn_tensor::quantization::{QuantizationStrategy, SymmetricQuantization}; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn should_gather_1d_dim0() { + let device = Default::default(); + // Quantized [0.0, 1.0, 2.0] + let data = TensorData::quantized( + vec![0i8, 64, 127], + [3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.015748031)), + ); + let tensor = TestTensor::<1>::from_data(data, &device); + let indices = TestTensorInt::from_ints([1, 1, 0, 1, 2], &device); + + let output = tensor.gather(0, indices); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&TensorData::from([1.0, 1.0, 0.0, 1.0, 2.0]), 1); + } + + #[test] + fn should_gather_2d_dim0() { + let device = Default::default(); + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![0i8, 25, 51, 76, 102, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor = TestTensor::<2>::from_data(data, &device); + let indices = TestTensorInt::from_ints([[0, 1, 0], [1, 0, 1]], &device); + + let output = tensor.gather(0, indices); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&TensorData::from([[0.0, 4.0, 2.0], [3.0, 1.0, 5.0]]), 1); + } + + #[test] + fn should_gather_2d_dim1() { + let device = Default::default(); + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![0i8, 25, 51, 76, 102, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor = TestTensor::<2>::from_data(data, &device); + let indices = TestTensorInt::from_ints([[2, 1, 0, 0], [2, 0, 1, 2]], &device); + + let output = tensor.gather(1, indices); + + // Precision 1 to approximate de/quantization errors + output.dequantize().into_data().assert_approx_eq( + &TensorData::from([[2.0, 1.0, 0.0, 0.0], [5.0, 3.0, 4.0, 5.0]]), + 1, + ); + } + + #[test] + fn should_gather_3d_dim1() { + let device = Default::default(); + // Quantized [ + // [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + // [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + // ] + let data = TensorData::quantized( + vec![0i8, 12, 23, 35, 46, 58, 69, 81, 92, 104, 115, 127], + [2, 2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.08661418)), + ); + let tensor = TestTensor::<3>::from_data(data, &device); + let indices = + TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]], &device); + + let output = tensor.gather(1, indices); + let expected = TensorData::from([ + [[3.0, 1.0, 2.0], [0.0, 4.0, 2.0]], + [[6.0, 7.0, 11.0], [6.0, 
10.0, 11.0]], + ]); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } + + #[test] + fn should_gather_2d_only_1dim() { + let device = Default::default(); + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![0i8, 25, 51, 76, 102, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor = TestTensor::<2>::from_data(data, &device); + let indices = TestTensorInt::<2>::from_ints([[1, 2]], &device).reshape([2, 1]); + + let output = tensor.gather(1, indices); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&TensorData::from([[1.0], [5.0]]), 1); + } + + #[test] + fn should_scatter_1d() { + let device = Default::default(); + // Quantized [0.0, 0.0, 0.0] + let data = TensorData::quantized( + vec![0i8, 0, 0], + [3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.1)), + ); + let tensor = TestTensor::<1>::from_data(data, &device); + // Quantized [5.0, 4.0, 3.0] + let data = TensorData::quantized( + vec![127i8, 102, 76], + [3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let values = TestTensor::<1>::from_data(data, &device); + let indices = TestTensorInt::from_ints([1, 0, 2], &device); + + let output = tensor.scatter(0, indices, values); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&TensorData::from([4.0, 5.0, 3.0]), 1); + } + + #[test] + fn should_scatter_2d_dim0() { + let device = Default::default(); + // Quantized [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]] + let data = TensorData::quantized( + vec![0i8, 0, 0, 0, 0, 0], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.1)), + ); + let tensor = TestTensor::<2>::from_data(data, &device); + // Quantized [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] + let data = TensorData::quantized( + vec![21i8, 42, 64, 85, 106, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.047244094)), + ); + let values = TestTensor::<2>::from_data(data, &device); + let indices = TestTensorInt::from_ints([[1, 0, 1], [1, 1, 0]], &device); + + let output = tensor.scatter(0, indices, values); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&TensorData::from([[0.0, 2.0, 6.0], [5.0, 5.0, 3.0]]), 1); + } + + #[test] + fn should_scatter_2d_dim1() { + let device = Default::default(); + // Quantized [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]] + let data = TensorData::quantized( + vec![0i8, 0, 0, 0, 0, 0], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.1)), + ); + let tensor = TestTensor::<2>::from_data(data, &device); + // Quantized [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] + let data = TensorData::quantized( + vec![21i8, 42, 64, 85, 106, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.047244094)), + ); + let values = TestTensor::<2>::from_data(data, &device); + let indices = TestTensorInt::from_ints([[1, 0, 2], [1, 2, 0]], &device); + + let output = tensor.scatter(1, indices, values); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&TensorData::from([[2.0, 1.0, 3.0], [6.0, 4.0, 5.0]]), 1); + } + + #[test] + fn 
should_scatter_3d_dim1() { + let device = Default::default(); + // Quantized [ + // [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + // [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + // ] + let data = TensorData::quantized( + vec![0i8, 12, 23, 35, 46, 58, 69, 81, 92, 104, 115, 127], + [2, 2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.08661418)), + ); + let tensor = TestTensor::<3>::from_data(data, &device); + // Quantized [ + // [[12.0, 13.0, 14.0], [15.0, 16.0, 17.0]], + // [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0]], + // ] + let data = TensorData::quantized( + vec![66i8, 72, 77, 83, 88, 94, 99, 105, 110, 116, 121, 127], + [2, 2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.18110237)), + ); + let values = TestTensor::<3>::from_data(data, &device); + let indices = + TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]], &device); + + let output = tensor.scatter(1, indices, values); + let expected = TensorData::from([ + [[15.0, 14.0, 33.0], [15.0, 20.0, 5.0]], + [[45.0, 26.0, 8.0], [9.0, 32.0, 54.0]], + ]); + + // Set higher tolerance (0.2) due to larger de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq_diff(&expected, 0.2); + } + + #[test] + fn should_scatter_2d_dim1_diff_shape() { + let device = Default::default(); + // Quantized [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]] + let data = TensorData::quantized( + vec![0i8, 0, 0, 0, 0, 0], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.1)), + ); + let tensor = TestTensor::<2>::from_data(data, &device); + // Quantized [[1.0], [4.0]] + let data = TensorData::quantized( + vec![32i8, 127], + [2, 1], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.031496063)), + ); + let values = TestTensor::<2>::from_data(data, &device); + let indices = TestTensorInt::from_ints([[1], [2]], &device); + + let output = tensor.scatter(1, indices, values); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&TensorData::from([[0.0, 1.0, 0.0], [0.0, 0.0, 4.0]]), 1); + } + + #[test] + #[should_panic] + fn scatter_should_panic_on_mismatch_of_shapes() { + let device = Default::default(); + // Quantized [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]] + let data = TensorData::quantized( + vec![0i8, 0, 0, 0, 0, 0], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.1)), + ); + let tensor = TestTensor::<2>::from_data(data, &device); + // Quantized [1.0, 4.0] + let data = TensorData::quantized( + vec![32i8, 127], + [2], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.031496063)), + ); + let values = TestTensor::<2>::from_data(data, &device); + let indices = TestTensorInt::from_ints([1, 0, 2], &device); + + tensor.scatter(0, indices, values); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/log.rs b/crates/burn-tensor/src/tests/quantization/ops/log.rs new file mode 100644 index 0000000000..f9b9b37a2c --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/log.rs @@ -0,0 +1,29 @@ +#[burn_tensor_testgen::testgen(q_log)] +mod tests { + use super::*; + use burn_tensor::quantization::{QuantizationStrategy, SymmetricQuantization}; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn should_support_log_ops() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![0i8, 25, 51, 76, 102, 127], + [2, 3], + 
QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.log(); + let expected = TensorData::from([ + [-f32::INFINITY, 0.0, core::f32::consts::LN_2], + [1.0986, 1.3862, 1.6094], + ]); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/log1p.rs b/crates/burn-tensor/src/tests/quantization/ops/log1p.rs new file mode 100644 index 0000000000..83682f4a71 --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/log1p.rs @@ -0,0 +1,29 @@ +#[burn_tensor_testgen::testgen(q_log1p)] +mod tests { + use super::*; + use burn_tensor::quantization::{QuantizationStrategy, SymmetricQuantization}; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn should_support_exp_log1p() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![0i8, 25, 51, 76, 102, 127], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.log1p(); + let expected = TensorData::from([ + [0.0, core::f32::consts::LN_2, 1.0986], + [1.3862, 1.6094, 1.7917], + ]); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/map_comparison.rs b/crates/burn-tensor/src/tests/quantization/ops/map_comparison.rs new file mode 100644 index 0000000000..8ffd83aaf9 --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/map_comparison.rs @@ -0,0 +1,271 @@ +#[burn_tensor_testgen::testgen(q_map_comparison)] +mod tests { + use super::*; + use burn_tensor::quantization::{AffineQuantization, QuantizationStrategy}; + use burn_tensor::{Tensor, TensorData}; + + // NOTE: we use affine quantization to reduce quantization errors since equality tests are precise + #[test] + fn test_equal() { + let device = Default::default(); + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor_1 = TestTensor::<2>::from_data(data, &device); + // Quantized [[0.0, 1.0, 1.0], [3.0, 5.0, 4.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -77, 25, 127, 76], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor_2 = TestTensor::<2>::from_data(data, &device); + + let data_actual_cloned = tensor_1.clone().equal(tensor_2.clone()); + let data_actual_inplace = tensor_1.equal(tensor_2); + + let data_expected = TensorData::from([[true, true, false], [true, false, false]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + #[test] + fn test_not_equal() { + let device = Default::default(); + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor_1 = TestTensor::<2>::from_data(data, &device); + // Quantized [[0.0, 1.0, 1.0], [3.0, 5.0, 4.0]] + let data = 
TensorData::quantized( + vec![-128i8, -77, -77, 25, 127, 76], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor_2 = TestTensor::<2>::from_data(data, &device); + + let data_actual_cloned = tensor_1.clone().not_equal(tensor_2.clone()); + let data_actual_inplace = tensor_1.not_equal(tensor_2); + + let data_expected = TensorData::from([[false, false, true], [false, true, true]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + #[test] + fn test_equal_elem() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 2.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, -26, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor_1 = TestTensor::<2>::from_data(data, &Default::default()); + + let data_actual_cloned = tensor_1.clone().equal_elem(2); + let data_actual_inplace = tensor_1.equal_elem(2); + + let data_expected = TensorData::from([[false, false, true], [false, true, false]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + #[test] + fn test_not_equal_elem() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 2.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, -26, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor_1 = TestTensor::<2>::from_data(data, &Default::default()); + + let data_actual_cloned = tensor_1.clone().not_equal_elem(2); + let data_actual_inplace = tensor_1.not_equal_elem(2); + + let data_expected = TensorData::from([[true, true, false], [true, false, true]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + #[test] + fn test_greater_elem() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor_1 = TestTensor::<2>::from_data(data, &Default::default()); + + let data_actual_cloned = tensor_1.clone().greater_elem(4); + let data_actual_inplace = tensor_1.greater_elem(4); + + let data_expected = TensorData::from([[false, false, false], [false, false, true]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + #[test] + fn test_greater_equal_elem() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor_1 = TestTensor::<2>::from_data(data, &Default::default()); + + let data_actual_cloned = tensor_1.clone().greater_equal_elem(4.0); + let data_actual_inplace = tensor_1.greater_equal_elem(4.0); + + let data_expected = TensorData::from([[false, false, false], [false, true, true]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + #[test] + fn test_greater() { + let device = Default::default(); + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + 
QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor_1 = TestTensor::<2>::from_data(data, &device); + // Quantized [[0.0, 1.0, 1.0], [3.0, 5.0, 4.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -77, 25, 127, 76], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor_2 = TestTensor::<2>::from_data(data, &device); + + let data_actual_cloned = tensor_1.clone().greater(tensor_2.clone()); + let data_actual_inplace = tensor_1.greater(tensor_2); + + let data_expected = TensorData::from([[false, false, true], [false, false, true]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + #[test] + fn test_greater_equal() { + let device = Default::default(); + // Quantized [[0.0, 1.0, 1.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -77, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor_1 = TestTensor::<2>::from_data(data, &device); + // Quantized [[0.0, 1.0, 2.0], [3.0, 5.0, 4.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 127, 76], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor_2 = TestTensor::<2>::from_data(data, &device); + + let data_actual_cloned = tensor_1.clone().greater_equal(tensor_2.clone()); + let data_actual_inplace = tensor_1.greater_equal(tensor_2); + + let data_expected = TensorData::from([[true, true, false], [true, false, true]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + #[test] + fn test_lower_elem() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor_1 = TestTensor::<2>::from_data(data, &Default::default()); + + let data_actual_cloned = tensor_1.clone().lower_elem(4.0); + let data_actual_inplace = tensor_1.lower_elem(4.0); + + let data_expected = TensorData::from([[true, true, true], [true, false, false]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + #[test] + fn test_lower_equal_elem() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor_1 = TestTensor::<2>::from_data(data, &Default::default()); + + let data_actual_cloned = tensor_1.clone().lower_equal_elem(4.0); + let data_actual_inplace = tensor_1.lower_equal_elem(4.0); + + let data_expected = TensorData::from([[true, true, true], [true, true, false]]); + assert_eq!(data_expected, data_actual_cloned.into_data()); + assert_eq!(data_expected, data_actual_inplace.into_data()); + } + + #[test] + fn test_lower() { + let device = Default::default(); + // Quantized [[0.0, 1.0, 1.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -77, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor_1 = TestTensor::<2>::from_data(data, &device); + // Quantized [[0.0, 1.0, 2.0], [3.0, 5.0, 
4.0]]
+        let data = TensorData::quantized(
+            vec![-128i8, -77, -26, 25, 127, 76],
+            [2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)),
+        );
+        let tensor_2 = TestTensor::<2>::from_data(data, &device);
+
+        let data_actual_cloned = tensor_1.clone().lower(tensor_2.clone());
+        let data_actual_inplace = tensor_1.lower(tensor_2);
+
+        let data_expected = TensorData::from([[false, false, true], [false, true, false]]);
+        assert_eq!(data_expected, data_actual_cloned.into_data());
+        assert_eq!(data_expected, data_actual_inplace.into_data());
+    }
+
+    #[test]
+    fn test_lower_equal() {
+        let device = Default::default();
+        // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
+        let data = TensorData::quantized(
+            vec![-128i8, -77, -26, 25, 76, 127],
+            [2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)),
+        );
+        let tensor_1 = TestTensor::<2>::from_data(data, &device);
+        // Quantized [[0.0, 1.0, 1.0], [3.0, 5.0, 4.0]]
+        let data = TensorData::quantized(
+            vec![-128i8, -77, -77, 25, 127, 76],
+            [2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)),
+        );
+        let tensor_2 = TestTensor::<2>::from_data(data, &device);
+
+        let data_actual_cloned = tensor_1.clone().lower_equal(tensor_2.clone());
+        let data_actual_inplace = tensor_1.lower_equal(tensor_2);
+
+        let data_expected = TensorData::from([[true, true, false], [true, true, false]]);
+        assert_eq!(data_expected, data_actual_cloned.into_data());
+        assert_eq!(data_expected, data_actual_inplace.into_data());
+    }
+}
diff --git a/crates/burn-tensor/src/tests/quantization/ops/mask.rs b/crates/burn-tensor/src/tests/quantization/ops/mask.rs
new file mode 100644
index 0000000000..1ec47dc30d
--- /dev/null
+++ b/crates/burn-tensor/src/tests/quantization/ops/mask.rs
@@ -0,0 +1,63 @@
+#[burn_tensor_testgen::testgen(q_mask)]
+mod tests {
+    use super::*;
+    use burn_tensor::quantization::{QuantizationStrategy, SymmetricQuantization};
+    use burn_tensor::{Bool, Int, Tensor, TensorData};
+
+    #[test]
+    fn should_support_mask_where_ops() {
+        let device = Default::default();
+        // Quantized [[1.0, 7.0], [2.0, 3.0]]
+        let data = TensorData::quantized(
+            vec![18i8, 127, 36, 54],
+            [2, 2],
+            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.05511811)),
+        );
+        let tensor = TestTensor::<2>::from_data(data, &device);
+        let mask = Tensor::<TestBackend, 2, Bool>::from_bool(
+            TensorData::from([[true, false], [false, true]]),
+            &device,
+        );
+        // Quantized [[1.8, 2.8], [3.8, 4.8]]
+        let data = TensorData::quantized(
+            vec![48i8, 74, 101, 127],
+            [2, 2],
+            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.037795275)),
+        );
+        let value = TestTensor::<2>::from_data(data, &device);
+
+        let output = tensor.mask_where(mask, value);
+        let expected = TensorData::from([[1.8, 7.0], [2.0, 4.8]]);
+
+        // Precision 1 to approximate de/quantization errors
+        output
+            .dequantize()
+            .into_data()
+            .assert_approx_eq(&expected, 1);
+    }
+
+    #[test]
+    fn should_support_mask_fill_ops() {
+        let device = Default::default();
+        // Quantized [[1.0, 7.0], [2.0, 3.0]]
+        let data = TensorData::quantized(
+            vec![18i8, 127, 36, 54],
+            [2, 2],
+            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.05511811)),
+        );
+        let tensor = TestTensor::<2>::from_data(data, &device);
+        let mask = Tensor::<TestBackend, 2, Bool>::from_bool(
+            TensorData::from([[true, false], [false, true]]),
+            &device,
+        );
+
+        let output = tensor.mask_fill(mask, 2.0);
+        let expected = 
TensorData::from([[2.0, 7.0], [2.0, 2.0]]); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/matmul.rs b/crates/burn-tensor/src/tests/quantization/ops/matmul.rs new file mode 100644 index 0000000000..7230fd58f0 --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/matmul.rs @@ -0,0 +1,314 @@ +#[burn_tensor_testgen::testgen(q_matmul)] +mod tests { + use super::*; + use burn_tensor::quantization::{ + AffineQuantization, QuantizationStrategy, SymmetricQuantization, + }; + use burn_tensor::{Int, Tensor, TensorData}; + + // NOTE: we set higher tolerance (0.3) due to larger de/quantization errors accumulation + #[test] + fn test_matmul_d2() { + let device = Default::default(); + // Quantized [[1.0, 7.0], [2.0, 3.0], [1.0, 5.0]] + let data = TensorData::quantized( + vec![18i8, 127, 36, 54, 18, 91], + [3, 2], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.05511811)), + ); + let tensor_1 = TestTensor::<2>::from_data(data, &device); + // Quantized [[4.0, 7.0, 5.0], [2.0, 3.0, 5.0]] + let data = TensorData::quantized( + vec![73i8, 127, 91, 36, 54, 91], + [2, 3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.05511811)), + ); + let tensor_2 = TestTensor::<2>::from_data(data, &device); + + let tensor_3 = tensor_1.matmul(tensor_2); + let expected = + TensorData::from([[18.0, 28.0, 40.0], [14.0, 23.0, 25.0], [14.0, 22.0, 30.0]]); + + tensor_3 + .dequantize() + .into_data() + .assert_approx_eq_diff(&expected, 0.3); + } + + #[test] + fn test_matmul_d3() { + let device = Default::default(); + // Quantized [[[1.0, 7.0], [2.0, 3.0]]] + let data = TensorData::quantized( + vec![18i8, 127, 36, 54], + [1, 2, 2], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.05511811)), + ); + let tensor_1 = TestTensor::<3>::from_data(data, &device); + // Quantized [[[4.0, 7.0], [2.0, 3.0]]] + let data = TensorData::quantized( + vec![73i8, 127, 36, 54], + [1, 2, 2], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.05511811)), + ); + let tensor_2 = TestTensor::<3>::from_data(data, &device); + + let tensor_3 = tensor_1.matmul(tensor_2); + let expected = TensorData::from([[[18.0, 28.0], [14.0, 23.0]]]); + + tensor_3 + .dequantize() + .into_data() + .assert_approx_eq_diff(&expected, 0.3); + } + + #[test] + fn test_matmul_broadcast_1() { + let device = Default::default(); + // Quantized [[[1.0, 7.0], [2.0, 3.0]]] + let data = TensorData::quantized( + vec![18i8, 127, 36, 54], + [1, 2, 2], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.05511811)), + ); + let tensor_1 = TestTensor::<3>::from_data(data, &device); + // Quantized [[[4.0, 7.0], [2.0, 3.0]], [[2.0, 5.0], [6.0, 3.0]]] + let data = TensorData::quantized( + vec![73i8, 127, 36, 54, 36, 91, 109, 54], + [2, 2, 2], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.05511811)), + ); + let tensor_2 = TestTensor::<3>::from_data(data, &device); + + let tensor_3 = tensor_1.matmul(tensor_2); + let expected = + TensorData::from([[[18.0, 28.0], [14.0, 23.0]], [[44.0, 26.0], [22.0, 19.0]]]); + + tensor_3 + .dequantize() + .into_data() + .assert_approx_eq_diff(&expected, 0.3); + } + + #[test] + fn test_matmul_broadcast_4d() { + let device = Default::default(); + // Quantized [[[[1.0, 7.0], [2.0, 3.0]]], [[[2.0, 5.0], [6.0, 3.0]]]] + let data = 
TensorData::quantized( + vec![18i8, 127, 36, 54, 36, 91, 109, 54], + [2, 1, 2, 2], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.05511811)), + ); + let tensor_1 = TestTensor::<4>::from_data(data, &device); + // Quantized [[[[9.0, 8.0], [1.0, 4.0]], [[2.0, 7.0], [3.0, 5.0]]]] + let data = TensorData::quantized( + vec![127i8, 113, 14, 56, 28, 99, 42, 71], + [1, 2, 2, 2], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.070866145)), + ); + let tensor_2 = TestTensor::<4>::from_data(data, &device); + + // [2, 1, 2, 2] @ [1, 2, 2, 2] -> [2, 2, 2, 2] + let tensor_3 = tensor_1.matmul(tensor_2); + let expected = TensorData::from([ + [[[16.0, 36.0], [21.0, 28.0]], [[23.0, 42.0], [13.0, 29.0]]], + [[[23.0, 36.0], [57.0, 60.0]], [[19.0, 39.0], [21.0, 57.0]]], + ]); + + tensor_3 + .dequantize() + .into_data() + .assert_approx_eq_diff(&expected, 0.3); + } + + #[test] + fn test_matmul_simple_1() { + let device = Default::default(); + // NOTE: we use affine quantization to lower de/quantization errors + // Quantized [[5.0, 14.0], [14.0, 25.0]] + let data = TensorData::quantized( + vec![-77i8, 15, 15, 127], + [2, 2], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.09803922, -128)), + ); + let tensor_1 = TestTensor::<2>::from_data(data, &device); + // Quantized [[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]] + let data = TensorData::quantized( + vec![25i8, 76, 127, -128, -77, -26], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor_2 = TestTensor::<2>::from_data(data, &device); + + let tensor_3 = tensor_1.matmul(tensor_2); + let expected = TensorData::from([[15.0, 34.0, 53.0], [42.0, 81.0, 120.0]]); + + tensor_3 + .dequantize() + .into_data() + .assert_approx_eq_diff(&expected, 0.3); + } + + #[test] + fn test_matmul_4_3() { + let device = Default::default(); + // NOTE: we use affine quantization to lower de/quantization errors + // Quantized [[0., 1., 2., 3.], [4., 5., 6., 7.], [8., 9., 10., 11.]] + let data = TensorData::quantized( + vec![-128i8, -105, -82, -58, -35, -12, 11, 34, 57, 81, 104, 127], + [3, 4], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.043137256, -128)), + ); + let tensor_1 = TestTensor::<2>::from_data(data, &device); + // Quantized [[0., 1., 2.], [4., 5., 6.], [8., 9., 10.], [13., 14., 15.]] + let data = TensorData::quantized( + vec![-128i8, -111, -94, -60, -43, -26, 8, 25, 42, 93, 110, 127], + [4, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.05882353, -128)), + ); + let tensor_2 = TestTensor::<2>::from_data(data, &device); + + let tensor_3 = tensor_1.matmul(tensor_2); + let expected = TensorData::from([[59., 65., 71.], [159., 181., 203.], [259., 297., 335.]]); + + tensor_3 + .dequantize() + .into_data() + .assert_approx_eq_diff(&expected, 0.3); + } + + #[test] + fn test_matmul_trivial() { + let device = Default::default(); + // NOTE: we use affine quantization to lower de/quantization errors + // Quantized [[0., 1., 2., 3.], [4., 5., 6., 7.], [8., 9., 10., 11.], [12., 13., 14., 15.]] + let data = TensorData::quantized( + vec![ + -128i8, -111, -94, -77, -60, -43, -26, -9, 8, 25, 42, 59, 76, 93, 110, 127, + ], + [4, 4], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.05882353, -128)), + ); + let tensor_1 = TestTensor::<2>::from_data(data, &device); + + let tensor_3 = tensor_1.clone().matmul(tensor_1); + + tensor_3.dequantize().into_data().assert_approx_eq( + &TensorData::from([ 
+                [56., 62., 68., 74.],
+                [152., 174., 196., 218.],
+                [248., 286., 324., 362.],
+                [344., 398., 452., 506.],
+            ]),
+            3,
+        );
+    }
+
+    #[test]
+    fn test_matmul_trivial_transposed() {
+        let device = Default::default();
+        // NOTE: we use affine quantization to lower de/quantization errors
+        // Quantized [[0., 1., 2., 3.], [4., 5., 6., 7.], [8., 9., 10., 11.], [12., 13., 14., 15.]]
+        let data = TensorData::quantized(
+            vec![
+                -128i8, -111, -94, -77, -60, -43, -26, -9, 8, 25, 42, 59, 76, 93, 110, 127,
+            ],
+            [4, 4],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.05882353, -128)),
+        );
+        let tensor_1 = TestTensor::<2>::from_data(data, &device);
+
+        let tensor_3 = tensor_1.clone().matmul(tensor_1.transpose());
+
+        tensor_3.dequantize().into_data().assert_approx_eq(
+            &TensorData::from([
+                [14., 38., 62., 86.],
+                [38., 126., 214., 302.],
+                [62., 214., 366., 518.],
+                [86., 302., 518., 734.],
+            ]),
+            1,
+        );
+    }
+
+    #[test]
+    fn test_matmul_simple_2() {
+        let device = Default::default();
+        // Quantized [[1.0, 2.0, 3.0, 4.0]]
+        let data = TensorData::quantized(
+            vec![32i8, 64, 95, 127],
+            [1, 4],
+            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.031496063)),
+        );
+        let tensor_1 = TestTensor::<2>::from_data(data, &device);
+        // Quantized [[3.0], [4.0], [5.0], [6.0]]
+        let data = TensorData::quantized(
+            vec![64i8, 85, 106, 127],
+            [4, 1],
+            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.047244094)),
+        );
+        let tensor_2 = TestTensor::<2>::from_data(data, &device);
+
+        let tensor_3 = tensor_1.matmul(tensor_2);
+        let expected = TensorData::from([[50.0]]);
+
+        tensor_3
+            .dequantize()
+            .into_data()
+            .assert_approx_eq_diff(&expected, 0.3);
+    }
+
+    #[test]
+    fn test_matmul_simple_3() {
+        let device = Default::default();
+        // Quantized [[3., 3., 3.], [4., 4., 4.], [5., 5., 5.], [6., 6., 6.]]
+        let data = TensorData::quantized(
+            vec![64i8, 64, 64, 85, 85, 85, 106, 106, 106, 127, 127, 127],
+            [4, 3],
+            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.047244094)),
+        );
+        let tensor_1 = TestTensor::<2>::from_data(data, &device);
+        // Quantized [[1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.]]
+        let data = TensorData::quantized(
+            vec![32i8, 64, 95, 127, 32, 64, 95, 127, 32, 64, 95, 127],
+            [3, 4],
+            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.031496063)),
+        );
+        let tensor_2 = TestTensor::<2>::from_data(data, &device);
+
+        let tensor_3 = tensor_1.matmul(tensor_2);
+        let expected = TensorData::from([
+            [9., 18., 27., 36.],
+            [12., 24., 36., 48.],
+            [15., 30., 45., 60.],
+            [18., 36., 54., 72.],
+        ]);
+
+        tensor_3
+            .dequantize()
+            .into_data()
+            .assert_approx_eq_diff(&expected, 0.3);
+    }
+
+    #[test]
+    #[should_panic]
+    fn should_panic_when_inner_dimensions_are_not_equal() {
+        let device = Default::default();
+        // Quantized [[3., 3.], [4., 4.], [5., 5.], [6., 6.]]
+        let data = TensorData::quantized(
+            vec![64i8, 64, 85, 85, 106, 106, 127, 127],
+            [4, 2],
+            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.047244094)),
+        );
+        let tensor_1 = TestTensor::<2>::from_data(data, &device);
+        // Quantized [[1., 2., 3.], [4., 1., 2.], [3., 4., 1.], [2., 3., 4.]] (shape [4, 3], so inner dimensions don't match)
+        let data = TensorData::quantized(
+            vec![32i8, 64, 95, 127, 32, 64, 95, 127, 32, 64, 95, 127],
+            [4, 3],
+            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.031496063)),
+        );
+        let tensor_2 = TestTensor::<2>::from_data(data, &device);
+
+        let _ = tensor_1.matmul(tensor_2);
+    }
+}
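The integer vectors and scale constants in the tests above follow from the standard int8 quantization formulas: a per-tensor symmetric scheme maps [-max_abs, max_abs] onto [-127, 127] with scale = max_abs / 127, while a per-tensor affine scheme maps [min, max] onto [-128, 127] with scale = (max - min) / 255 and a zero-point offset. The following minimal sketch re-derives the constants used throughout this suite, e.g. 5.0 / 127 ≈ 0.03937008 and 5.0 / 255 ≈ 0.019607844; it is standalone Rust, independent of the burn API, and `symmetric_scale`/`affine_params` are illustrative helpers rather than crate items.

// Standalone illustration of the per-tensor int8 parameters used in these tests.
// `symmetric_scale` and `affine_params` are hypothetical helpers, not burn APIs.
fn symmetric_scale(max_abs: f32) -> f32 {
    // Symmetric int8: [-max_abs, max_abs] -> [-127, 127].
    max_abs / 127.0
}

fn affine_params(min: f32, max: f32) -> (f32, i32) {
    // Affine int8: [min, max] -> [-128, 127].
    let scale = (max - min) / 255.0;
    let zero_point = (-128.0 - min / scale).round() as i32;
    (scale, zero_point)
}

fn main() {
    // Range [0.0, 5.0], as in most tests in this suite.
    assert!((symmetric_scale(5.0) - 0.03937008).abs() < 1e-6);
    let (scale, zero_point) = affine_params(0.0, 5.0);
    assert!((scale - 0.019607844).abs() < 1e-6);
    assert_eq!(zero_point, -128);

    // Quantize 3.0 under the affine scheme: round(3.0 / scale) + zp = 153 - 128 = 25,
    // which matches the `25` entries in the affine test vectors above.
    let q = (3.0_f32 / scale).round() as i32 + zero_point;
    assert_eq!(q, 25);

    // Dequantize: (q - zp) * scale = 153 * (5.0 / 255.0) = 3.0 (up to f32 rounding).
    let dequant = (q - zero_point) as f32 * scale;
    assert!((dequant - 3.0).abs() < 1e-3);
}

This arithmetic also explains the recurring NOTE comments: for a non-negative range such as [0, 5], the affine scheme spends all 256 int8 bins on the data, whereas the symmetric scheme leaves the negative half of the range unused, roughly doubling the rounding step.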
diff --git a/crates/burn-tensor/src/tests/quantization/ops/maxmin.rs b/crates/burn-tensor/src/tests/quantization/ops/maxmin.rs new file mode 100644 index 0000000000..03937846de --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/maxmin.rs @@ -0,0 +1,211 @@ +#[burn_tensor_testgen::testgen(q_maxmin)] +mod tests { + use super::*; + use burn_tensor::quantization::{AffineQuantization, QuantizationStrategy}; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn test_max_dim_2d() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.max_dim(1); + let expected = TensorData::from([[2.], [5.]]); + + output.dequantize().into_data().assert_eq(&expected, false); + } + + #[test] + fn test_max_dim_with_indices_2d_with_dim_0th() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let (output, index) = tensor.max_dim_with_indices(0); + + let output_expected = TensorData::from([[3., 4., 5.]]); + let index_expected = TensorData::from([[1, 1, 1]]); + + output + .dequantize() + .into_data() + .assert_eq(&output_expected, false); + index.into_data().assert_eq(&index_expected, false); + } + + #[test] + fn test_max_dim_with_indices_2d() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let (output, index) = tensor.max_dim_with_indices(1); + + let output_expected = TensorData::from([[2.], [5.]]); + let index_expected = TensorData::from([[2], [2]]); + + output + .dequantize() + .into_data() + .assert_eq(&output_expected, false); + index.into_data().assert_eq(&index_expected, false); + } + + #[test] + fn test_min_dim_2d() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.min_dim(1); + + let expected = TensorData::from([[0.], [3.]]); + + output.dequantize().into_data().assert_eq(&expected, false); + } + + #[test] + fn test_min_dim_with_indices_2d() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let (output, index) = tensor.min_dim_with_indices(1); + + let output_expected = TensorData::from([[0.], [3.]]); + let index_expected = TensorData::from([[0], [0]]); + + output + .dequantize() + .into_data() + .assert_eq(&output_expected, false); + index.into_data().assert_eq(&index_expected, false); + } + + #[test] + fn test_min_dim_2d_with_0th_dim() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = 
TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.min_dim(0); + let expected = TensorData::from([[0., 1., 2.]]); + + output.dequantize().into_data().assert_eq(&expected, false); + } + + #[test] + fn test_max_dim_2d_with_0th_dim() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.max_dim(0); + let expected = TensorData::from([[3., 4., 5.]]); + + output.dequantize().into_data().assert_eq(&expected, false); + } + + #[test] + fn test_min_dim_with_indices_2d_with_0th_dim() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let (output, index) = tensor.min_dim_with_indices(0); + + let output_expected = TensorData::from([[0., 1., 2.]]); + let index_expected = TensorData::from([[0, 0, 0]]); + + output + .dequantize() + .into_data() + .assert_eq(&output_expected, false); + index.into_data().assert_eq(&index_expected, false); + } + + #[test] + fn test_maximum_pair() { + // Quantized [1.0, 2.0, 3.0, 4.0] (with range [0., 5.]) + let data = TensorData::quantized( + vec![-77i8, -26, 25, 76], + [4], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let a = TestTensor::<1>::from_data(data, &Default::default()); + // Quantized [2.0, 1.0, 4.0, 5.0] (with range [0., 5.]) + let data = TensorData::quantized( + vec![-26i8, -77, 76, 127], + [4], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let b = TestTensor::<1>::from_data(data, &Default::default()); + + let output = a.max_pair(b); + let expected = TensorData::from([2.0, 2.0, 4.0, 5.0]); + + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } + + #[test] + fn test_minimum_pair() { + // Quantized [1.0, 2.0, 3.0, 4.0] (with range [0., 5.]) + let data = TensorData::quantized( + vec![-77i8, -26, 25, 76], + [4], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let a = TestTensor::<1>::from_data(data, &Default::default()); + // Quantized [2.0, 1.0, 4.0, 5.0] (with range [0., 5.]) + let data = TensorData::quantized( + vec![-26i8, -77, 76, 127], + [4], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let b = TestTensor::<1>::from_data(data, &Default::default()); + + let output = a.min_pair(b); + let expected = TensorData::from([1.0, 1.0, 3.0, 4.0]); + + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/mod.rs b/crates/burn-tensor/src/tests/quantization/ops/mod.rs new file mode 100644 index 0000000000..e64d072617 --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/mod.rs @@ -0,0 +1,43 @@ +mod abs; +mod add; +mod aggregation; +mod all; +mod any; +mod arg; +mod cat; +mod chunk; +mod clamp; +mod cos; +mod div; +mod erf; +mod 
exp;
+mod expand;
+mod flip;
+mod gather_scatter;
+mod log;
+mod log1p;
+mod map_comparison;
+mod mask;
+mod matmul;
+mod maxmin;
+mod mul;
+mod narrow;
+mod neg;
+mod permute;
+mod powf;
+mod powf_scalar;
+mod quantize;
+mod recip;
+mod remainder;
+mod repeat_dim;
+mod reshape;
+mod select;
+mod sin;
+mod slice;
+mod sort_argsort;
+mod sqrt;
+mod stack;
+mod sub;
+mod tanh;
+mod topk;
+mod transpose;
diff --git a/crates/burn-tensor/src/tests/quantization/ops/mul.rs b/crates/burn-tensor/src/tests/quantization/ops/mul.rs
new file mode 100644
index 0000000000..b69c880cca
--- /dev/null
+++ b/crates/burn-tensor/src/tests/quantization/ops/mul.rs
@@ -0,0 +1,105 @@
+#[burn_tensor_testgen::testgen(q_mul)]
+mod tests {
+    use super::*;
+    use burn_tensor::quantization::{AffineQuantization, QuantizationStrategy};
+    use burn_tensor::{Tensor, TensorData};
+
+    #[test]
+    fn should_support_mul_ops() {
+        let device = Default::default();
+        // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
+        let data = TensorData::quantized(
+            vec![-128i8, -77, -26, 25, 76, 127],
+            [2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)),
+        );
+        let tensor_1 = TestTensor::<2>::from_data(data.clone(), &device);
+        let tensor_2 = TestTensor::<2>::from_data(data, &device);
+
+        let output = tensor_1 * tensor_2;
+        let expected = TensorData::from([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]]);
+
+        // Precision 1 to approximate de/quantization errors
+        output
+            .dequantize()
+            .into_data()
+            .assert_approx_eq(&expected, 1);
+    }
+
+    #[test]
+    fn test_mul_broadcast() {
+        let device = Default::default();
+        // Quantized [[0.0, 1.0, 2.0]]
+        let data = TensorData::quantized(
+            vec![-128i8, -1, 127],
+            [1, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.007843138, -128)),
+        );
+        let tensor_1 = TestTensor::<2>::from_data(data.clone(), &device);
+        // Quantized [[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]
+        let data = TensorData::quantized(
+            vec![-32i8, -1, 31, 63, 95, 127],
+            [2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.03137255, -128)),
+        );
+        let tensor_2 = TestTensor::<2>::from_data(data.clone(), &device);
+
+        let output = tensor_1 * tensor_2;
+        let expected = TensorData::from([[0.0, 4.0, 10.0], [0.0, 7.0, 16.0]]);
+
+        // Precision 1 to approximate de/quantization errors
+        output
+            .dequantize()
+            .into_data()
+            .assert_approx_eq(&expected, 1);
+    }
+
+    #[test]
+    fn test_mul_broadcast_2_dims() {
+        let device = Default::default();
+        // Quantized [[0.0], [1.0], [2.0]]
+        let data = TensorData::quantized(
+            vec![-128i8, -1, 127],
+            [3, 1],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.007843138, -128)),
+        );
+        let tensor_1 = TestTensor::<2>::from_data(data.clone(), &device);
+        // Quantized [[3.0, 4.0, 5.0]]
+        let data = TensorData::quantized(
+            vec![25i8, 76, 127],
+            [1, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)),
+        );
+        let tensor_2 = TestTensor::<2>::from_data(data.clone(), &device);
+
+        let output = tensor_1 * tensor_2;
+        let expected = TensorData::from([[0.0, 0.0, 0.0], [3.0, 4.0, 5.0], [6.0, 8.0, 10.0]]);
+
+        // Precision 1 to approximate de/quantization errors
+        output
+            .dequantize()
+            .into_data()
+            .assert_approx_eq(&expected, 1);
+    }
+
+    #[test]
+    fn should_support_mul_scalar_ops() {
+        // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
+        let data = TensorData::quantized(
+            vec![-128i8, -77, -26, 25, 76, 127],
+            [2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)),
+        );
+        let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default());
+        let scalar = 2.0;
+
+        let output = tensor * scalar;
+        let expected = TensorData::from([[0.0, 2.0, 4.0], [6.0, 8.0, 10.0]]);
+
+        // Precision 1 to approximate de/quantization errors
+        output
+            .dequantize()
+            .into_data()
+            .assert_approx_eq(&expected, 1);
+    }
+}
diff --git a/crates/burn-tensor/src/tests/quantization/ops/narrow.rs b/crates/burn-tensor/src/tests/quantization/ops/narrow.rs
new file mode 100644
index 0000000000..2c290a2234
--- /dev/null
+++ b/crates/burn-tensor/src/tests/quantization/ops/narrow.rs
@@ -0,0 +1,90 @@
+#[burn_tensor_testgen::testgen(q_narrow)]
+mod tests {
+    use super::*;
+    use burn_tensor::quantization::{AffineQuantization, QuantizationStrategy};
+    use burn_tensor::{Shape, Tensor, TensorData};
+
+    #[test]
+    fn test_narrow() {
+        // Quantized [[1., 2., 3.], [7., 8., 9.], [13., 14., 15.]]
+        let data = TensorData::quantized(
+            vec![-111i8, -94, -77, -9, 8, 25, 93, 110, 127],
+            [3, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.05882353, -128)),
+        );
+        let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default());
+
+        let output = tensor.clone().narrow(0, 0, 2);
+        let expected = TensorData::from([[1., 2., 3.], [7., 8., 9.]]);
+
+        assert_eq!(output.shape(), Shape::from([2, 3]));
+        output
+            .dequantize()
+            .into_data()
+            .assert_approx_eq(&expected, 3);
+
+        let output = tensor.narrow(1, 1, 2);
+        let expected = TensorData::from([[2., 3.], [8., 9.], [14., 15.]]);
+        assert_eq!(output.shape(), Shape::from([3, 2]));
+        output
+            .dequantize()
+            .into_data()
+            .assert_approx_eq(&expected, 3);
+    }
+
+    #[test]
+    #[should_panic]
+    fn test_narrow_invalid_dim() {
+        // Quantized [[1., 2., 3.], [7., 8., 9.], [13., 14., 15.]]
+        let data = TensorData::quantized(
+            vec![-111i8, -94, -77, -9, 8, 25, 93, 110, 127],
+            [3, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.05882353, -128)),
+        );
+        let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default());
+
+        let _output = tensor.narrow(2, 0, 2);
+    }
+
+    #[test]
+    #[should_panic]
+    fn test_narrow_invalid_start() {
+        // Quantized [[1., 2., 3.], [7., 8., 9.], [13., 14., 15.]]
+        let data = TensorData::quantized(
+            vec![-111i8, -94, -77, -9, 8, 25, 93, 110, 127],
+            [3, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.05882353, -128)),
+        );
+        let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default());
+
+        let _output = tensor.narrow(0, 3, 2);
+    }
+
+    #[test]
+    #[should_panic]
+    fn test_narrow_invalid_zero_length() {
+        // Quantized [[1., 2., 3.], [7., 8., 9.], [13., 14., 15.]]
+        let data = TensorData::quantized(
+            vec![-111i8, -94, -77, -9, 8, 25, 93, 110, 127],
+            [3, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.05882353, -128)),
+        );
+        let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default());
+
+        let _output = tensor.narrow(0, 1, 0);
+    }
+
+    #[test]
+    #[should_panic]
+    fn test_narrow_invalid_length() {
+        // Quantized [[1., 2., 3.], [7., 8., 9.], [13., 14., 15.]]
+        let data = TensorData::quantized(
+            vec![-111i8, -94, -77, -9, 8, 25, 93, 110, 127],
+            [3, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.05882353, -128)),
+        );
+        let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default());
+
+        let _output = tensor.narrow(0, 0, 4);
+    }
+}
diff --git a/crates/burn-tensor/src/tests/quantization/ops/neg.rs b/crates/burn-tensor/src/tests/quantization/ops/neg.rs
new file mode 100644
index 0000000000..8a27c3834f
--- /dev/null
+++ b/crates/burn-tensor/src/tests/quantization/ops/neg.rs
@@ -0,0 +1,31 @@
+#[burn_tensor_testgen::testgen(q_neg)]
+mod tests {
+    use super::*;
+    use burn_tensor::quantization::{AffineQuantization, QuantizationStrategy};
+    use burn_tensor::{Tensor, TensorData};
+
+    #[test]
+    fn should_support_neg_ops() {
+        // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
+        let data = TensorData::quantized(
+            vec![-128i8, -77, -26, 25, 76, 127],
+            [2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)),
+        );
+        let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default());
+
+        let output = tensor.neg();
+        let expected = TensorData::from([[-0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]).convert::<f32>();
+
+        // -0.0 is represented differently than 0.0 so we make sure the values are the same in f32
+        assert_eq!(
+            output
+                .dequantize()
+                .into_data()
+                .convert::<f32>()
+                .as_slice::<f32>()
+                .unwrap(),
+            expected.as_slice::<f32>().unwrap()
+        );
+    }
+}
diff --git a/crates/burn-tensor/src/tests/quantization/ops/permute.rs b/crates/burn-tensor/src/tests/quantization/ops/permute.rs
new file mode 100644
index 0000000000..668852c4b8
--- /dev/null
+++ b/crates/burn-tensor/src/tests/quantization/ops/permute.rs
@@ -0,0 +1,82 @@
+#[burn_tensor_testgen::testgen(q_permute)]
+mod tests {
+    use super::*;
+    use burn_tensor::backend::Backend;
+    use burn_tensor::quantization::{AffineQuantization, QuantizationStrategy};
+    use burn_tensor::{Device, Tensor, TensorData};
+
+    #[test]
+    fn permute_float() {
+        let device = Default::default();
+        // Quantized [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.]
+        let data = TensorData::quantized(
+            vec![
+                -128i8, -111, -94, -77, -60, -43, -26, -9, 8, 25, 42, 59, 76, 93, 110, 127,
+            ],
+            [2, 2, 4],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.05882353, -128)),
+        );
+        let tensor = TestTensor::<3>::from_data(data.clone(), &device);
+
+        let permuted = tensor.clone().permute([2, 1, 0]);
+
+        let expected = TensorData::from([
+            [[0., 8.], [4., 12.]],
+            [[1., 9.], [5., 13.]],
+            [[2., 10.], [6., 14.]],
+            [[3., 11.], [7., 15.]],
+        ]);
+
+        permuted
+            .dequantize()
+            .into_data()
+            .assert_eq(&expected, false);
+
+        // Test with negative axis
+        let permuted = tensor.clone().permute([-1, 1, 0]);
+        permuted
+            .dequantize()
+            .into_data()
+            .assert_eq(&expected, false);
+
+        // Test with the same axis
+        let permuted = tensor.clone().permute([0, 1, 2]);
+        permuted.into_data().assert_eq(&tensor.into_data(), true);
+    }
+
+    #[test]
+    #[should_panic]
+    fn edge_repeated_axes() {
+        let device = Default::default();
+        // Quantized [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.]
+        let data = TensorData::quantized(
+            vec![
+                -128i8, -111, -94, -77, -60, -43, -26, -9, 8, 25, 42, 59, 76, 93, 110, 127,
+            ],
+            [2, 2, 4],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.05882353, -128)),
+        );
+        let tensor = TestTensor::<3>::from_data(data.clone(), &device);
+
+        // Test with a repeated axis
+        let _ = tensor.clone().permute([0, 0, 1]);
+    }
+
+    #[test]
+    #[should_panic]
+    fn edge_out_of_bound_axis() {
+        let device = Default::default();
+        // Quantized [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.]
+        let data = TensorData::quantized(
+            vec![
+                -128i8, -111, -94, -77, -60, -43, -26, -9, 8, 25, 42, 59, 76, 93, 110, 127,
+            ],
+            [2, 2, 4],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.05882353, -128)),
+        );
+        let tensor = TestTensor::<3>::from_data(data.clone(), &device);
+
+        // Test with an invalid axis
+        let _ = tensor.clone().permute([3, 0, 1]);
+    }
+}
diff --git a/crates/burn-tensor/src/tests/quantization/ops/powf.rs b/crates/burn-tensor/src/tests/quantization/ops/powf.rs
new file mode 100644
index 0000000000..d944d39517
--- /dev/null
+++ b/crates/burn-tensor/src/tests/quantization/ops/powf.rs
@@ -0,0 +1,120 @@
+#[burn_tensor_testgen::testgen(q_powf)]
+mod tests {
+    use super::*;
+    use burn_tensor::quantization::{
+        AffineQuantization, QuantizationStrategy, SymmetricQuantization,
+    };
+    use burn_tensor::{Tensor, TensorData};
+
+    #[test]
+    fn should_support_powf_ops() {
+        let device = Default::default();
+        // Quantized [[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
+        let data = TensorData::quantized(
+            vec![-77i8, -77, -26, 25, 76, 127],
+            [2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)),
+        );
+        let tensor = TestTensor::<2>::from_data(data, &device);
+        // Quantized [[1.0, 1.0, 2.0], [3.0, 4.0, 2.0]] (with range [1., 5.] to reduce quantization errors)
+        let data = TensorData::quantized(
+            vec![-77i8, -77, -26, 25, 76, -26],
+            [2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)),
+        );
+        let tensor_pow = TestTensor::<2>::from_data(data, &device);
+
+        let output = tensor.powf(tensor_pow);
+        let expected = TensorData::from([[1.0, 1.0, 4.0], [27.0, 256.0, 25.0]]);
+
+        // NOTE: we set higher tolerance (0.2) due to larger de/quantization errors accumulation w/ powers
+        output
+            .dequantize()
+            .into_data()
+            .assert_approx_eq_diff(&expected, 0.2);
+    }
+
+    #[test]
+    fn should_support_neg_power() {
+        let device = Default::default();
+        // Quantized [[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
+        let data = TensorData::quantized(
+            vec![-77i8, -77, -26, 25, 76, 127],
+            [2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)),
+        );
+        let tensor = TestTensor::<2>::from_data(data, &device);
+        // Quantized [[-0.95, -0.67, -0.45], [-0.24, -0.5, -0.6]]
+        let data = TensorData::quantized(
+            vec![-128i8, -53, 6, 63, -7, -34],
+            [2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.00372549, 127)),
+        );
+        let tensor_pow = TestTensor::<2>::from_data(data, &device);
+
+        let output = tensor.powf(tensor_pow);
+        let expected = TensorData::from([[1., 1., 0.73204285], [0.76822936, 0.5, 0.38073079]]);
+
+        // Precision 1 to approximate de/quantization errors
+        output
+            .dequantize()
+            .into_data()
+            .assert_approx_eq(&expected, 1);
+    }
+
+    #[test]
+    fn should_support_neg_values_with_even_power() {
+        let device = Default::default();
+        // Quantized [[0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]
+        let data = TensorData::quantized(
+            vec![126i8, 75, 24, -27, -78, -128],
+            [2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, 126)),
+        );
+        let tensor = Tensor::<TestBackend, 2>::from_data(data, &device);
+        // Quantized [[4.0, 2.0, 4.0], [2.0, 4.0, 2.0]] (with range [2., 5.] to reduce quantization errors)
+        let data = TensorData::quantized(
+            vec![76i8, -26, 76, -26, 76, -26],
+            [2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)),
+        );
+        let tensor_pow = TestTensor::<2>::from_data(data, &device);
+
+        let output = tensor.powf(tensor_pow);
+        let expected = TensorData::from([[0.0, 1.0, 16.0], [9.0, 256.0, 25.0]]);
+
+        // Precision 1 to approximate de/quantization errors
+        output
+            .dequantize()
+            .into_data()
+            .assert_approx_eq(&expected, 1);
+    }
+
+    #[test]
+    fn should_support_neg_values_with_odd_power() {
+        let device = Default::default();
+        // Quantized [[0.0, -1.0, -2.0], [-3.0, -4.0, -4.0]] (with range [-5., 0.] to reduce quantization errors)
+        let data = TensorData::quantized(
+            vec![126i8, 75, 24, -27, -78, -78],
+            [2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, 126)),
+        );
+        let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default());
+        // Quantized [[3.0, 3.0, 3.0], [3.0, 3.0, 3.0]]
+        let data = TensorData::quantized(
+            vec![127i8, 127, 127, 127, 127, 127],
+            [2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.011764706, -128)),
+        );
+        let tensor_pow = TestTensor::<2>::from_data(data, &device);
+
+        let output = tensor.powf(tensor_pow);
+        let expected = TensorData::from([[0.0, -1.0, -8.0], [-27.0, -64.0, -64.0]]);
+
+        // NOTE: we set higher tolerance (0.2) due to larger de/quantization errors accumulation w/ powers
+        output
+            .dequantize()
+            .into_data()
+            .assert_approx_eq_diff(&expected, 0.2);
+    }
+}
diff --git a/crates/burn-tensor/src/tests/quantization/ops/powf_scalar.rs b/crates/burn-tensor/src/tests/quantization/ops/powf_scalar.rs
new file mode 100644
index 0000000000..bdf65b5b92
--- /dev/null
+++ b/crates/burn-tensor/src/tests/quantization/ops/powf_scalar.rs
@@ -0,0 +1,89 @@
+#[burn_tensor_testgen::testgen(q_powf_scalar)]
+mod tests {
+    use super::*;
+    use burn_tensor::quantization::{
+        AffineQuantization, QuantizationStrategy, SymmetricQuantization,
+    };
+    use burn_tensor::{Tensor, TensorData};
+
+    #[test]
+    fn should_support_powf_ops() {
+        // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
+        let data = TensorData::quantized(
+            vec![0i8, 25, 51, 76, 102, 127],
+            [2, 3],
+            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)),
+        );
+        let tensor = TestTensor::<2>::from_data(data, &Default::default());
+
+        let output = tensor.powf_scalar(0.71);
+        let expected = TensorData::from([[0.0, 1.0, 1.6358], [2.182, 2.6759, 3.1352]]);
+
+        // Precision 1 to approximate de/quantization errors
+        output
+            .dequantize()
+            .into_data()
+            .assert_approx_eq(&expected, 1);
+    }
+
+    #[test]
+    fn should_support_neg_power() {
+        // Quantized [[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
+        let data = TensorData::quantized(
+            vec![25i8, 25, 51, 76, 102, 127],
+            [2, 3],
+            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)),
+        );
+        let tensor = TestTensor::<2>::from_data(data, &Default::default());
+
+        let output = tensor.powf_scalar(-0.33);
+        let expected =
+            TensorData::from([[1.0, 1.0, 0.79553646], [0.695905, 0.6328783, 0.58794934]]);
+
+        // Precision 1 to approximate de/quantization errors
+        output
+            .dequantize()
+            .into_data()
+            .assert_approx_eq(&expected, 1);
+    }
+
+    #[test]
+    fn should_support_neg_values_with_even_power() {
+        // Quantized [[0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]
+        let data = TensorData::quantized(
+            vec![126i8, 75, 24, -27, -78, -128],
+            [2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, 126)),
+        );
+        let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default());
+
+        let output = tensor.powf_scalar(2.0);
+        let expected = TensorData::from([[0., 1., 4.], [9., 16., 25.]]);
+
+        // NOTE: we set higher tolerance (0.2) due to larger de/quantization errors accumulation w/ powers
+        output
+            .dequantize()
+            .into_data()
+            .assert_approx_eq_diff(&expected, 0.2);
+    }
+
+    #[test]
+    fn should_support_neg_values_with_odd_power() {
+        // Quantized [[0.0, -1.0, -2.0], [-3.0, -4.0, -4.0]] (with range [-5., 0.] to reduce quantization errors)
+        let data = TensorData::quantized(
+            vec![126i8, 75, 24, -27, -78, -78],
+            [2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, 126)),
+        );
+        let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default());
+
+        let output = tensor.powf_scalar(3.0);
+        let expected = TensorData::from([[0.0, -1.0, -8.0], [-27.0, -64.0, -64.0]]);
+
+        // NOTE: we set higher tolerance (0.2) due to larger de/quantization errors accumulation w/ powers
+        output
+            .dequantize()
+            .into_data()
+            .assert_approx_eq_diff(&expected, 0.2);
+    }
+}
diff --git a/crates/burn-tensor/src/tests/quantization/ops/quantize.rs b/crates/burn-tensor/src/tests/quantization/ops/quantize.rs
new file mode 100644
index 0000000000..872abeb26d
--- /dev/null
+++ b/crates/burn-tensor/src/tests/quantization/ops/quantize.rs
@@ -0,0 +1,100 @@
+#[burn_tensor_testgen::testgen(quantize)]
+mod tests {
+    use super::*;
+    use burn_tensor::ops::QTensorOps;
+    use burn_tensor::quantization::{
+        AffineQuantization, QuantizationParameters, QuantizationScheme, QuantizationStrategy,
+        QuantizationType, SymmetricQuantization,
+    };
+    use burn_tensor::{Tensor, TensorData};
+
+    #[test]
+    fn should_support_quantize_affine_int8() {
+        let device = Default::default();
+        let tensor = Tensor::<TestBackend, 1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device);
+        let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8);
+        let qparams = QuantizationParameters {
+            scale: Tensor::from_floats([0.009_019_608], &device),
+            offset: Some(Tensor::from_ints([72], &device)),
+        };
+
+        let x_q = tensor.quantize(&scheme, qparams);
+
+        let expected = TensorData::quantized(
+            vec![-128i8, -39, 72, 127],
+            [4],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.009_019_608, 72)),
+        );
+
+        x_q.to_data().assert_eq(&expected, true);
+    }
+
+    #[test]
+    fn should_support_quantize_symmetric_int8() {
+        let device = Default::default();
+        let tensor = Tensor::<TestBackend, 1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device);
+        let scheme = QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8);
+        let qparams = QuantizationParameters {
+            scale: Tensor::from_floats([0.014_173_228], &device),
+            offset: None,
+        };
+
+        let x_q = tensor.quantize(&scheme, qparams);
+
+        let expected = TensorData::quantized(
+            vec![-127i8, -71, 0, 35],
+            [4],
+            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(
+                0.014_173_228,
+            )),
+        );
+
+        x_q.to_data().assert_eq(&expected, true);
+    }
+
+    #[test]
+    fn should_support_dequantize() {
+        let device = Default::default();
+
+        // Quantized [-1.8, -1.0, 0.0, 0.5]
+        let data = TensorData::quantized(
+            vec![-127i8, -71, 0, 35],
+            [4],
+            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(
+                0.014_173_228,
+            )),
+        );
+        let x_q = Tensor::<TestBackend, 1>::from_data(data, &device);
+
+        let x = x_q.dequantize();
+
+        // Precision 2 for dequantization errors
+        x.to_data()
+            .assert_approx_eq(&TensorData::from([-1.8, -1.0, 0.0, 0.5]), 2);
+    }
+
+    #[test]
+    fn should_support_quantize_dynamic_int8() {
+        let device = Default::default();
+        let tensor = Tensor::<TestBackend, 1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device);
+        let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8);
+
+        let x_q = tensor.quantize_dynamic(&scheme);
+
+        let expected = TensorData::quantized(
+            vec![-128i8, -39, 72, 127],
+            [4],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.009_019_608, 72)),
+        );
+
+        x_q.to_data().assert_eq(&expected, true);
+    }
+}
diff --git a/crates/burn-tensor/src/tests/quantization/ops/recip.rs b/crates/burn-tensor/src/tests/quantization/ops/recip.rs
new file mode 100644
index 0000000000..c121023fb8
--- /dev/null
+++ b/crates/burn-tensor/src/tests/quantization/ops/recip.rs
@@ -0,0 +1,26 @@
+#[burn_tensor_testgen::testgen(q_recip)]
+mod tests {
+    use super::*;
+    use burn_tensor::quantization::{AffineQuantization, QuantizationStrategy};
+    use burn_tensor::{Tensor, TensorData};
+
+    #[test]
+    fn should_support_recip_ops() {
+        // Quantized [[0.5, 1.0, 2.0], [3.0, -4.0, -5.0]]
+        let data = TensorData::quantized(
+            vec![47i8, 63, 95, 127, -96, -128],
+            [2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.03137255, 31)),
+        );
+        let tensor = TestTensor::<2>::from_data(data, &Default::default());
+
+        let output = tensor.recip();
+        let expected = TensorData::from([[2.0, 1.0, 0.5], [0.33333, -0.25, -0.2]]);
+
+        // Precision 1 to approximate de/quantization errors
+        output
+            .dequantize()
+            .into_data()
+            .assert_approx_eq(&expected, 1);
+    }
+}
diff --git a/crates/burn-tensor/src/tests/quantization/ops/remainder.rs b/crates/burn-tensor/src/tests/quantization/ops/remainder.rs
new file mode 100644
index 0000000000..c53906ed17
--- /dev/null
+++ b/crates/burn-tensor/src/tests/quantization/ops/remainder.rs
@@ -0,0 +1,162 @@
+#[burn_tensor_testgen::testgen(q_remainder)]
+mod tests {
+    use super::*;
+    use burn_tensor::quantization::{
+        AffineQuantization, QuantizationStrategy, SymmetricQuantization,
+    };
+    use burn_tensor::{Tensor, TensorData};
+
+    #[test]
+    fn should_support_remainder_basic() {
+        // Quantized [-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]
+        let data = TensorData::quantized(
+            vec![-128i8, -85, -43, 43, 85, 127],
+            [6],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.023529412, 0)),
+        );
+        let tensor = TestTensor::<1>::from_data(data, &Default::default());
+
+        let output = tensor.remainder_scalar(2.0);
+        let expected = TensorData::from([1.0, 0.0, 1.0, 1.0, 0.0, 1.0]);
+
+        // Precision 1 to approximate de/quantization errors
+        output
+            .dequantize()
+            .into_data()
+            .assert_approx_eq(&expected, 1);
+    }
+
+    #[test]
+    fn should_support_remainder_float() {
+        // Quantized [1.0, 2.0, 3.0, 4.0, 5.0]
+        let data = TensorData::quantized(
+            vec![-77i8, -26, 25, 76, 127],
+            [5],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)),
+        );
+        let tensor = TestTensor::<1>::from_data(data, &Default::default());
+
+        let output = tensor.remainder_scalar(-1.5);
+        let expected = TensorData::from([-0.5, -1.0, 0.0, -0.5, -1.0]);
+
+        // Precision 1 to approximate de/quantization errors
+        output
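+            // dequantize() maps the i8 data back to f32 with the tensor's scale/offset;
+            // the comparison below therefore runs on floats rather than on raw i8 values +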
.dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } + + #[test] + fn should_be_zero() { + // Quantized [0.0, 0.0, 0.0] + let data = TensorData::quantized( + vec![0i8, 0, 0], + [3], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.1)), + ); + let tensor = TestTensor::<1>::from_data(data, &Default::default()); + + let output = tensor.remainder_scalar(3.5); + let expected = TensorData::from([0.0, 0.0, 0.0]); + + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 3); + } + + #[test] + fn should_have_no_remainder() { + // Quantized [-4.0, 4.0] + let data = TensorData::quantized( + vec![-127i8, 127], + [2], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.031496063)), + ); + let tensor = TestTensor::<1>::from_data(data, &Default::default()); + + let output = tensor.remainder_scalar(4.0); + let expected = TensorData::from([-0.0, 0.0]); + + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 3); + } + + #[test] + fn should_be_negative() { + // Quantized [-7.0, -3.0, 2.0, 6.0] + let data = TensorData::quantized( + vec![-128i8, -50, 48, 127], + [4], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.050980393, 9)), + ); + let tensor = TestTensor::<1>::from_data(data, &Default::default()); + + let output = tensor.remainder_scalar(-2.5); + let expected = TensorData::from([-2.0, -0.50, -0.50, -1.5]); + + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } + + #[test] + fn should_support_fp_dividends() { + // Quantized [-7.5, -2.5, 2.5, 7.5] + let data = TensorData::quantized( + vec![-127i8, -42, 42, 127], + [4], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.05905512)), + ); + let tensor = TestTensor::<1>::from_data(data, &Default::default()); + + let output = tensor.remainder_scalar(3.0); + let expected = TensorData::from([1.5, 0.5, 2.5, 1.5]); + + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } + + #[test] + fn should_support_large_divisor() { + // Quantized [-1.0, 1.0] + let data = TensorData::quantized( + vec![-127i8, 127], + [2], + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.007874016)), + ); + let tensor = TestTensor::<1>::from_data(data, &Default::default()); + + let output = tensor.remainder_scalar(10.0); + let expected = TensorData::from([9.0, 1.0]); + + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } + + #[test] + fn should_support_remainder_op() { + // Quantized [-3.0, -2.0, -1.0, 1.0, 2.0, 3.0] + let data = TensorData::quantized( + vec![-128i8, -85, -43, 43, 85, 127], + [6], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.023529412, 0)), + ); + let tensor = TestTensor::<1>::from_data(data, &Default::default()); + + let output = tensor % 2.0; + let expected = TensorData::from([1.0, 0.0, 1.0, 1.0, 0.0, 1.0]); + + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/repeat_dim.rs b/crates/burn-tensor/src/tests/quantization/ops/repeat_dim.rs new file mode 100644 index 0000000000..91ee496066 --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/repeat_dim.rs @@ -0,0 +1,50 @@ +#[burn_tensor_testgen::testgen(q_repeat_dim)] +mod tests { + use super::*; + use burn_tensor::quantization::{AffineQuantization, QuantizationStrategy}; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn should_support_repeat_ops() { 
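+        // repeat_dim only duplicates the stored i8 values; with a per-tensor scheme the
+        // same scale/offset apply to the copies, so the output dequantizes to the
+        // repeated floats.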
+ // Quantized [[0.0, 1.0, 2.0, 3.0]] + let data = TensorData::quantized( + vec![-128i8, -43, 42, 127], + [1, 4], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.011764706, -128)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.repeat_dim(0, 4); + let expected = TensorData::from([ + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + ]); + + output.dequantize().into_data().assert_eq(&expected, false); + } + + #[test] + fn should_support_repeat_on_dims_larger_than_1() { + // Quantized [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.] + let data = TensorData::quantized( + vec![ + -128i8, -111, -94, -77, -60, -43, -26, -9, 8, 25, 42, 59, 76, 93, 110, 127, + ], + [4, 2, 2], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.05882353, -128)), + ); + let tensor = TestTensor::<3>::from_data(data, &Default::default()); + + let output = tensor.repeat_dim(2, 2); + let expected = TensorData::from([ + [[0., 1., 0., 1.], [2., 3., 2., 3.]], + [[4., 5., 4., 5.], [6., 7., 6., 7.]], + [[8., 9., 8., 9.], [10., 11., 10., 11.]], + [[12., 13., 12., 13.], [14., 15., 14., 15.]], + ]); + + output.dequantize().into_data().assert_eq(&expected, false); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/reshape.rs b/crates/burn-tensor/src/tests/quantization/ops/reshape.rs new file mode 100644 index 0000000000..6d42ee56aa --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/reshape.rs @@ -0,0 +1,111 @@ +#[burn_tensor_testgen::testgen(q_reshape)] +mod tests { + use super::*; + use burn_tensor::quantization::{ + AffineQuantization, QuantizationStrategy, SymmetricQuantization, + }; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn should_support_reshape_1d() { + // Quantized [0.0, 1.0, 2.0, 3.0] + let data = TensorData::quantized( + vec![-128i8, -43, 42, 127], + [4], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.011764706, -128)), + ); + let tensor = TestTensor::<1>::from_data(data, &Default::default()); + + let output = tensor.clone().reshape([1, 4]); + let expected = TensorData::from([[0.0, 1.0, 2.0, 3.0]]); + + output.dequantize().into_data().assert_eq(&expected, false); + } + + #[test] + fn should_support_reshape_2d() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.clone().reshape([6]); + let expected = TensorData::from([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]); + + output.dequantize().into_data().assert_eq(&expected, false); + } + + #[test] + fn should_support_dim_infererence() { + // Quantized [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0] + let data = TensorData::quantized( + vec![-128i8, -105, -82, -58, -35, -12, 11, 34, 57, 81, 104, 127], + [4, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.043137256, -128)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + // Infer the dimension via -1 + let reshaped = tensor.clone().reshape([2, -1]); + assert_eq!(reshaped.shape(), [2, 6].into()); + + // Infer the dimension via 0 (keep from the source) and -1 (infer) + let reshaped = reshaped.reshape([0, 2, -1]); + assert_eq!(reshaped.shape(), [2, 2, 3].into()); + + // This 
is effectively as if we did a flatten
+        let reshaped = tensor.clone().reshape([-1]);
+        assert_eq!(reshaped.shape(), [12].into());
+
+        // Keeping the first dimension the same (using 0)
+        let reshaped = tensor.clone().reshape([0, 3]);
+        assert_eq!(reshaped.shape(), [4, 3].into());
+    }
+
+    #[test]
+    fn should_not_corrupt_after_slice() {
+        // Quantized [0.0, 0.0]
+        let data = TensorData::quantized(
+            vec![0i8, 0],
+            [2],
+            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.1)),
+        );
+        let zeros = TestTensor::<1>::from_data(data, &Default::default());
+        zeros.clone().slice([1..2]).reshape([1]).exp();
+
+        // May lead to zeroes being equal to [0.0, 1.0]
+        zeros.dequantize().into_data().assert_eq(
+            &Tensor::<TestBackend, 1>::zeros([2], &Default::default()).to_data(),
+            true,
+        );
+    }
+
+    #[test]
+    #[should_panic]
+    fn multiple_neg_ones() {
+        // Quantized [0.0, 1.0, 2.0]
+        let data = TensorData::quantized(
+            vec![0i8, 64, 127],
+            [3],
+            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.015748031)),
+        );
+        let tensor = TestTensor::<1>::from_data(data, &Default::default());
+        let data_actual = tensor.reshape([-1, -1]).into_data();
+    }
+
+    #[test]
+    #[should_panic]
+    fn neg_value() {
+        // Quantized [0.0, 1.0, 2.0]
+        let data = TensorData::quantized(
+            vec![0i8, 64, 127],
+            [3],
+            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.015748031)),
+        );
+        let tensor = TestTensor::<1>::from_data(data, &Default::default());
+        let data_actual = tensor.reshape([-2, -1]).into_data();
+    }
+}
diff --git a/crates/burn-tensor/src/tests/quantization/ops/select.rs b/crates/burn-tensor/src/tests/quantization/ops/select.rs
new file mode 100644
index 0000000000..2d13ab7a9a
--- /dev/null
+++ b/crates/burn-tensor/src/tests/quantization/ops/select.rs
@@ -0,0 +1,174 @@
+#[burn_tensor_testgen::testgen(q_select)]
+mod tests {
+    use super::*;
+    use burn_tensor::quantization::{AffineQuantization, QuantizationStrategy};
+    use burn_tensor::{Tensor, TensorData};
+
+    #[test]
+    fn should_select_1d() {
+        let device = Default::default();
+        // Quantized [0.0, 1.0, 2.0, 3.0]
+        let data = TensorData::quantized(
+            vec![-128i8, -43, 42, 127],
+            [4],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.011764706, -128)),
+        );
+        let tensor = TestTensor::<1>::from_data(data, &device);
+        let indices = TestTensorInt::from_data([1, 1, 0, 1, 2], &device);
+
+        let output = tensor.select(0, indices);
+        let expected = TensorData::from([1.0, 1.0, 0.0, 1.0, 2.0]);
+
+        output.dequantize().into_data().assert_eq(&expected, false);
+    }
+
+    #[test]
+    fn should_select_2d_dim0_same_num_dim() {
+        let device = Default::default();
+        // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
+        let data = TensorData::quantized(
+            vec![-128i8, -77, -26, 25, 76, 127],
+            [2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)),
+        );
+        let tensor = TestTensor::<2>::from_data(data, &device);
+        let indices = TestTensorInt::from_data([1, 0], &device);
+
+        let output = tensor.select(0, indices);
+        let expected = TensorData::from([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]);
+
+        output.dequantize().into_data().assert_eq(&expected, false);
+    }
+
+    #[test]
+    fn should_select_2d_dim0_more_num_dim() {
+        let device = Default::default();
+        // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
+        let data = TensorData::quantized(
+            vec![-128i8, -77, -26, 25, 76, 127],
+            [2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)),
+        );
+        let tensor =
TestTensor::<2>::from_data(data, &device); + let indices = TestTensorInt::from_data([1, 0, 1, 1], &device); + + let output = tensor.select(0, indices); + let expected = TensorData::from([ + [3.0, 4.0, 5.0], + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + [3.0, 4.0, 5.0], + ]); + + output.dequantize().into_data().assert_eq(&expected, false); + } + + #[test] + fn should_select_2d_dim1() { + let device = Default::default(); + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor = TestTensor::<2>::from_data(data, &device); + let indices = TestTensorInt::from_data([1, 1, 0, 1, 2], &device); + + let output = tensor.select(1, indices); + let expected = TensorData::from([[1.0, 1.0, 0.0, 1.0, 2.0], [4.0, 4.0, 3.0, 4.0, 5.0]]); + + output.dequantize().into_data().assert_eq(&expected, false); + } + + #[test] + fn should_select_assign_1d() { + let device = Default::default(); + // Quantized [0.0, 1.0, 2.0] + let data = TensorData::quantized( + vec![-128i8, -1, 127], + [3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.007843138, -128)), + ); + let tensor = TestTensor::<1>::from_data(data, &device); + // Quantized [5.0, 4.0, 3.0, 2.0, 1.0] + let data = TensorData::quantized( + vec![127i8, 76, 25, -26, -77], + [5], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let values = TestTensor::<1>::from_data(data, &device); + let indices = TestTensorInt::from_data(TensorData::from([1, 1, 0, 1, 2]), &device); + + let output = tensor.select_assign(0, indices, values); + let expected = TensorData::from([3.0, 12.0, 3.0]); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } + + #[test] + fn should_select_assign_2d_dim0() { + let device = Default::default(); + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor = TestTensor::<2>::from_data(data, &device); + let values = tensor.clone(); + let indices = TestTensorInt::from_data(TensorData::from([1, 0]), &device); + + let output = tensor.select_assign(0, indices, values); + let expected = TensorData::from([[3.0, 5.0, 7.0], [3.0, 5.0, 7.0]]); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } + + #[test] + fn should_select_assign_2d_dim1() { + let device = Default::default(); + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor = TestTensor::<2>::from_data(data, &device); + let values = tensor.clone(); + let indices = TestTensorInt::from_data(TensorData::from([1, 0, 2]), &device); + + let output = tensor.select_assign(1, indices, values); + let expected = TensorData::from([[1.0, 1.0, 4.0], [7.0, 7.0, 10.0]]); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } + + #[test] + #[should_panic] + fn should_select_panic_invalid_dimension() { + let device = Default::default(); + // Quantized [[0.0, 1.0, 
2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor = TestTensor::<2>::from_data(data, &device); + let indices = TestTensorInt::from_data([1, 1, 0, 1, 2], &device); + + tensor.select(10, indices); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/sin.rs b/crates/burn-tensor/src/tests/quantization/ops/sin.rs new file mode 100644 index 0000000000..725a806b29 --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/sin.rs @@ -0,0 +1,25 @@ +#[burn_tensor_testgen::testgen(q_sin)] +mod tests { + use super::*; + use burn_tensor::quantization::{AffineQuantization, QuantizationStrategy}; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn should_support_sin_ops() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.sin(); + let expected = TensorData::from([[0.0, 0.8414, 0.9092], [0.1411, -0.7568, -0.9589]]); + + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 3); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/slice.rs b/crates/burn-tensor/src/tests/quantization/ops/slice.rs new file mode 100644 index 0000000000..5beedb6765 --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/slice.rs @@ -0,0 +1,323 @@ +#[burn_tensor_testgen::testgen(q_slice)] +mod tests { + use super::*; + use burn_tensor::quantization::{AffineQuantization, QuantizationStrategy}; + use burn_tensor::{Int, Tensor, TensorData}; + + #[test] + fn should_support_full_sliceing_1d() { + // Quantized [0.0, 1.0, 2.0, 3.0] + let data = TensorData::quantized( + vec![-128i8, -43, 42, 127], + [4], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.011764706, -128)), + ); + let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default()); + + let output = tensor.slice([0..4]); + + output.into_data().assert_eq(&data, false); + } + + #[test] + fn should_support_partial_sliceing_1d() { + // Quantized [0.0, 1.0, 2.0, 3.0] + let data = TensorData::quantized( + vec![-128i8, -43, 42, 127], + [4], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.011764706, -128)), + ); + let tensor = TestTensor::<1>::from_data(data, &Default::default()); + + let output = tensor.slice([1..3]); + let expected = TensorData::from([1.0, 2.0]); + + output.dequantize().into_data().assert_eq(&expected, false); + } + + #[test] + fn should_support_full_sliceing_2d() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default()); + + let output = tensor.clone().slice([0..2]); + output.into_data().assert_eq(&data, true); + + let output = tensor.slice([0..2, 0..3]); + output.into_data().assert_eq(&data, true); + } + + #[test] + fn should_support_partial_sliceing_2d() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + 
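// scale = 5 / 255 ≈ 0.019607844 with offset -128: floats in [0., 5.] span the full i8 range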
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)),
+        );
+        let tensor = TestTensor::<2>::from_data(data, &Default::default());
+
+        let output = tensor.slice([0..2, 0..2]);
+        let expected = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);
+
+        output.dequantize().into_data().assert_eq(&expected, false);
+    }
+
+    #[test]
+    fn should_support_partial_sliceing_3d() {
+        // Quantized [[[0., 1., 2., 3.], [4., 5., 6., 7.]], [[8., 9., 10., 11.], [12., 13., 14., 15.]]]
+        let data = TensorData::quantized(
+            vec![
+                -128i8, -111, -94, -77, -60, -43, -26, -9, 8, 25, 42, 59, 76, 93, 110, 127,
+            ],
+            [2, 2, 4],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.05882353, -128)),
+        );
+        let tensor = TestTensor::<3>::from_data(data, &Default::default());
+
+        let output = tensor.slice([1..2, 1..2, 0..2]);
+        let expected = TensorData::from([[[12.0, 13.0]]]);
+
+        output.dequantize().into_data().assert_eq(&expected, false);
+    }
+
+    #[test]
+    fn should_support_partial_sliceing_3d_non_contiguous() {
+        // Quantized [[[0., 1., 2., 3.], [4., 5., 6., 7.]], [[8., 9., 10., 11.], [12., 13., 14., 15.]]]
+        let data = TensorData::quantized(
+            vec![
+                -128i8, -111, -94, -77, -60, -43, -26, -9, 8, 25, 42, 59, 76, 93, 110, 127,
+            ],
+            [2, 2, 4],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.05882353, -128)),
+        );
+        let tensor = TestTensor::<3>::from_data(data, &Default::default());
+
+        let output = tensor.transpose().slice([1..2, 1..2, 0..2]);
+        let expected = TensorData::from([[[9.0, 13.0]]]);
+
+        output.dequantize().into_data().assert_eq(&expected, false);
+    }
+
+    #[test]
+    fn should_support_slice_assign_1d() {
+        let device = Default::default();
+        // Quantized [0.0, 1.0, 2.0]
+        let data = TensorData::quantized(
+            vec![-128i8, -1, 127],
+            [3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.007843138, -128)),
+        );
+        let tensor = TestTensor::<1>::from_data(data, &device);
+        // Quantized [10.0, 5.0]
+        let data = TensorData::quantized(
+            vec![127i8, -1],
+            [2],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.039215688, -128)),
+        );
+        let tensor_assigned = Tensor::<TestBackend, 1>::from_data(data, &device);
+
+        let output = tensor.slice_assign([0..2], tensor_assigned);
+        let expected = TensorData::from([10.0, 5.0, 2.0]);
+
+        // Precision 1 to approximate de/quantization errors
+        output
+            .dequantize()
+            .into_data()
+            .assert_approx_eq(&expected, 1);
+    }
+
+    #[test]
+    fn should_support_slice_assign_2d() {
+        let device = Default::default();
+        // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
+        let data = TensorData::quantized(
+            vec![-128i8, -77, -26, 25, 76, 127],
+            [2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)),
+        );
+        let tensor = TestTensor::<2>::from_data(data, &device);
+        // Quantized [[10.0, 5.0]]
+        let data = TensorData::quantized(
+            vec![127i8, -1],
+            [1, 2],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.039215688, -128)),
+        );
+        let tensor_assigned = Tensor::<TestBackend, 2>::from_data(data, &device);
+
+        let output = tensor.slice_assign([1..2, 0..2], tensor_assigned);
+        let expected = TensorData::from([[0.0, 1.0, 2.0], [10.0, 5.0, 5.0]]);
+
+        // Precision 1 to approximate de/quantization errors
+        output
+            .dequantize()
+            .into_data()
+            .assert_approx_eq(&expected, 1);
+    }
+
+    #[test]
+    fn slice_should_not_corrupt_potentially_inplace_operations() {
+        // Quantized [1.0, 2.0, 3.0, 4.0, 5.0]
+        let data = TensorData::quantized(
+            vec![-77i8, -26, 25, 76, 127], +
[5], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor = TestTensor::<1>::from_data(data, &Default::default()); + let tensor = tensor.clone().slice([0..3]) + tensor.clone().slice([2..5]); + + let expected = TensorData::from([4., 6., 8.]); + + // Precision 1 to approximate de/quantization errors + tensor + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } + + #[test] + fn slice_assign_should_not_corrupt_potentially_inplace_operations() { + let device = Default::default(); + // Quantized [1.0, 2.0, 3.0, 4.0, 5.0] + let data = TensorData::quantized( + vec![-77i8, -26, 25, 76, 127], + [5], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor = TestTensor::<1>::from_data(data, &device); + // Quantized [10., 20., 30.] + let data = TensorData::quantized( + vec![-43i8, 42, 127], + [3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.11764706, -128)), + ); + let values = TestTensor::<1>::from_data(data, &device); + let tensor_1 = tensor.clone().slice_assign([0..3], values); + let tensor_2 = tensor + 2; + + let expected = TensorData::from([10., 20., 30., 4., 5.]); + + // Precision 1 to approximate de/quantization errors + tensor_1 + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + + let expected = TensorData::from([3., 4., 5., 6., 7.]); + + // Precision 1 to approximate de/quantization errors + tensor_2 + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } + + #[test] + fn clamp_when_slice_exceeds_dimension() { + // Quantized [0.0, 1.0, 2.0] + let data = TensorData::quantized( + vec![-128i8, -1, 127], + [3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.007843138, -128)), + ); + let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default()); + + let output = tensor.slice([0..4]); + output.into_data().assert_eq(&data, true); + } + + #[test] + fn negative_dimensions() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default()); + + // Clamping to the tensor dimensions + let output = tensor.clone().slice([(0, 4), (0, 4)]); + output.into_data().assert_eq(&data, true); + + // Negative dimensions + let output = tensor.clone().slice([(0, 1), (0, 1)]); + let data = TensorData::from([[0.0f32]]); + output.dequantize().into_data().assert_eq(&data, false); + + let output = tensor.slice([(0, -1), (0, -2)]); + output.dequantize().into_data().assert_eq(&data, false); + } + + #[test] + fn missing_dimensions() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default()); + + // Clamping to the tensor dimensions + let output = tensor.clone().slice([Some((0, 4)), Some((0, 4))]); + output.into_data().assert_eq(&data, true); + + // Negative dimensions + let data = TensorData::from([[0.0f32]]); + let output = tensor.clone().slice([Some((0, -1)), Some((0, -2))]); + output.dequantize().into_data().assert_eq(&data, false); + + // Missing dimensions + let output = tensor.clone().slice([Some((0, 1)), None]); + let data 
= TensorData::from([[0.0f32, 1.0, 2.0]]);
+        output.dequantize().into_data().assert_eq(&data, false);
+
+        let output = tensor.clone().slice([None, Some((0, 2))]);
+        let data = TensorData::from([[0.0f32, 1.0], [3.0, 4.0]]);
+        output.dequantize().into_data().assert_eq(&data, false);
+
+        let output = tensor.clone().slice([None, None]);
+        let data = TensorData::from([[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]]);
+        output.dequantize().into_data().assert_eq(&data, false);
+    }
+
+    #[test]
+    #[should_panic]
+    fn should_panic_when_slice_with_too_many_dimensions() {
+        let data = TensorData::from([0.0, 1.0, 2.0]);
+        let tensor = Tensor::<TestBackend, 1>::from_data(data.clone(), &Default::default());
+
+        let output = tensor.slice([0..1, 0..1]);
+
+        output.into_data().assert_eq(&data, false);
+    }
+
+    #[test]
+    #[should_panic]
+    fn should_panic_when_slice_is_desc() {
+        let data = TensorData::from([0.0, 1.0, 2.0]);
+        let tensor = Tensor::<TestBackend, 1>::from_data(data.clone(), &Default::default());
+
+        #[allow(clippy::reversed_empty_ranges)]
+        let output = tensor.slice([2..1]);
+
+        output.into_data().assert_eq(&data, false);
+    }
+
+    #[test]
+    #[should_panic]
+    fn should_panic_when_slice_is_equal() {
+        let data = TensorData::from([0.0, 1.0, 2.0]);
+        let tensor = Tensor::<TestBackend, 1>::from_data(data.clone(), &Default::default());
+
+        let output = tensor.slice([1..1]);
+
+        output.into_data().assert_eq(&data, false);
+    }
+}
diff --git a/crates/burn-tensor/src/tests/quantization/ops/sort_argsort.rs b/crates/burn-tensor/src/tests/quantization/ops/sort_argsort.rs
new file mode 100644
index 0000000000..e69d4be36e
--- /dev/null
+++ b/crates/burn-tensor/src/tests/quantization/ops/sort_argsort.rs
@@ -0,0 +1,267 @@
+#[burn_tensor_testgen::testgen(q_sort_argsort)]
+mod tests {
+    use super::*;
+    use burn_tensor::quantization::{AffineQuantization, QuantizationStrategy};
+    use burn_tensor::{Shape, Tensor, TensorData};
+
+    #[test]
+    fn test_sort_1d_float() {
+        // Quantized [0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 5.2, 4., 0.99, 3., -8.1]
+        let data = TensorData::quantized(
+            vec![37i8, 50, 23, 27, 67, 45, 21, 71, 127, 104, 46, 85, -128],
+            [13],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.052156862, 27)),
+        );
+        let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default());
+
+        // Sort along dim=0
+        let values = tensor.sort(0);
+
+        let values_expected = TensorData::from([
+            -8.1, -0.3, -0.21, 0., 0.5, 0.94, 0.99, 1.2, 2.1, 2.3, 3., 4., 5.2,
+        ]);
+
+        // Precision 1 to approximate de/quantization errors
+        values
+            .dequantize()
+            .into_data()
+            .assert_approx_eq(&values_expected, 1);
+    }
+
+    #[test]
+    fn test_argsort_1d_float() {
+        // Quantized [0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 5.2, 4., 0.99, 3., -8.1]
+        let data = TensorData::quantized(
+            vec![37i8, 50, 23, 27, 67, 45, 21, 71, 127, 104, 46, 85, -128],
+            [13],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.052156862, 27)),
+        );
+        let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default());
+
+        // Sort along dim=0
+        let indices = tensor.argsort(0);
+
+        let indices_expected = TensorData::from([12, 6, 2, 3, 0, 5, 10, 1, 4, 7, 11, 9, 8]);
+        indices.into_data().assert_eq(&indices_expected, false);
+    }
+
+    #[test]
+    fn test_sort_with_indices_descending_float() {
+        // 1D
+        // Quantized [0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 5.2, 4., 0.99, 3., -8.1]
+        let data = TensorData::quantized(
+            vec![37i8, 50, 23, 27, 67, 45, 21, 71, 127, 104, 46, 85, -128],
+            [13],
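+            // asymmetric range [-8.1, 5.2]: scale = (5.2 + 8.1) / 255 ≈ 0.052156862 and the
+            // zero-point works out to 27, so 0. is encoded exactly +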
QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.052156862, 27)), + ); + let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default()); + + // Sort along dim=0 + let (values, indices) = tensor.sort_descending_with_indices(0); + + let values_expected = TensorData::from([ + 5.2, 4., 3., 2.3, 2.1, 1.2, 0.99, 0.94, 0.5, 0., -0.21, -0.3, -8.1, + ]); + // Precision 1 to approximate de/quantization errors + values + .dequantize() + .into_data() + .assert_approx_eq(&values_expected, 1); + + let indices_expected = TensorData::from([8, 9, 11, 7, 4, 1, 10, 5, 0, 3, 2, 6, 12]); + indices.into_data().assert_eq(&indices_expected, false); + + // 3D + // Quantized [-0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 4., 0.99, 3., -8.1] + let data = TensorData::quantized( + vec![31i8, 67, 38, 42, 86, 62, 36, 90, 126, 63, 105, -128], + [2, 2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.047450982, 42)), + ); + let tensor = TestTensor::<3>::from_data(data.clone(), &Default::default()); + + // Sort along dim=1 + let (values, indices) = tensor.sort_descending_with_indices(1); + + let values_expected = TensorData::from([ + [[0., 2.1, 0.94], [-0.5, 1.2, -0.21]], + [[0.99, 3., 4.], [-0.3, 2.3, -8.1]], + ]); + // Precision 1 to approximate de/quantization errors + values + .dequantize() + .into_data() + .assert_approx_eq(&values_expected, 1); + + let indices_expected = TensorData::from([[[1, 1, 1], [0, 0, 0]], [[1, 1, 0], [0, 0, 1]]]); + indices.into_data().assert_eq(&indices_expected, false); + } + + #[test] + fn test_sort_float() { + // Quantized [-0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 4., 0.99, 3., -8.1] + let data = TensorData::quantized( + vec![31i8, 67, 38, 42, 86, 62, 36, 90, 126, 63, 105, -128], + [2, 2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.047450982, 42)), + ); + let tensor = TestTensor::<3>::from_data(data.clone(), &Default::default()); + + // Sort along dim=0 + let values = tensor.clone().sort(0); + + let values_expected = TensorData::from([ + [[-0.5, 1.2, -0.21], [0., 2.1, -8.1]], + [[-0.3, 2.3, 4.], [0.99, 3., 0.94]], + ]); + // Precision 1 to approximate de/quantization errors + values + .dequantize() + .into_data() + .assert_approx_eq(&values_expected, 1); + + // Sort along dim=1 + let values = tensor.clone().sort(1); + + let values_expected = TensorData::from([ + [[-0.5, 1.2, -0.21], [0., 2.1, 0.94]], + [[-0.3, 2.3, -8.1], [0.99, 3., 4.]], + ]); + // Precision 1 to approximate de/quantization errors + values + .dequantize() + .into_data() + .assert_approx_eq(&values_expected, 1); + + // Sort along dim=2 + let values = tensor.sort(2); + + let values_expected = TensorData::from([ + [[-0.5, -0.21, 1.2], [0., 0.94, 2.1]], + [[-0.3, 2.3, 4.], [-8.1, 0.99, 3.]], + ]); + // Precision 1 to approximate de/quantization errors + values + .dequantize() + .into_data() + .assert_approx_eq(&values_expected, 1); + } + + #[test] + fn test_sort_with_indices_float() { + // Quantized [-0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 4., 0.99, 3., -8.1] + let data = TensorData::quantized( + vec![31i8, 67, 38, 42, 86, 62, 36, 90, 126, 63, 105, -128], + [2, 2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.047450982, 42)), + ); + let tensor = TestTensor::<3>::from_data(data.clone(), &Default::default()); + + // Sort along dim=0 + let (values, indices) = tensor.clone().sort_with_indices(0); + let values_expected = TensorData::from([ + [[-0.5, 1.2, -0.21], [0., 2.1, -8.1]], + [[-0.3, 2.3, 4.], [0.99, 3., 
0.94]], + ]); + // Precision 1 to approximate de/quantization errors + values + .dequantize() + .into_data() + .assert_approx_eq(&values_expected, 1); + + let indices_expected = TensorData::from([[[0, 0, 0], [0, 0, 1]], [[1, 1, 1], [1, 1, 0]]]); + indices.into_data().assert_eq(&indices_expected, false); + + // Sort along dim=1 + let (values, indices) = tensor.clone().sort_with_indices(1); + + let values_expected = TensorData::from([ + [[-0.5, 1.2, -0.21], [0., 2.1, 0.94]], + [[-0.3, 2.3, -8.1], [0.99, 3., 4.]], + ]); + // Precision 1 to approximate de/quantization errors + values + .dequantize() + .into_data() + .assert_approx_eq(&values_expected, 1); + + let indices_expected = TensorData::from([[[0, 0, 0], [1, 1, 1]], [[0, 0, 1], [1, 1, 0]]]); + indices.into_data().assert_eq(&indices_expected, false); + + // Sort along dim=2 + let (values, indices) = tensor.sort_with_indices(2); + + let values_expected = TensorData::from([ + [[-0.5, -0.21, 1.2], [0., 0.94, 2.1]], + [[-0.3, 2.3, 4.], [-8.1, 0.99, 3.]], + ]); + // Precision 1 to approximate de/quantization errors + values + .dequantize() + .into_data() + .assert_approx_eq(&values_expected, 1); + + let indices_expected = TensorData::from([[[0, 2, 1], [0, 2, 1]], [[0, 1, 2], [2, 0, 1]]]); + indices.into_data().assert_eq(&indices_expected, false); + } + + #[test] + fn test_argsort_float() { + // Quantized [-0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 4., 0.99, 3., -8.1] + let data = TensorData::quantized( + vec![31i8, 67, 38, 42, 86, 62, 36, 90, 126, 63, 105, -128], + [2, 2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.047450982, 42)), + ); + let tensor = TestTensor::<3>::from_data(data.clone(), &Default::default()); + + // Sort along dim=0 + let indices = tensor.clone().argsort(0); + + let indices_expected = TensorData::from([[[0, 0, 0], [0, 0, 1]], [[1, 1, 1], [1, 1, 0]]]); + indices.into_data().assert_eq(&indices_expected, false); + + // Sort along dim=1 + let indices = tensor.clone().argsort(1); + + let indices_expected = TensorData::from([[[0, 0, 0], [1, 1, 1]], [[0, 0, 1], [1, 1, 0]]]); + indices.into_data().assert_eq(&indices_expected, false); + + // Sort along dim=2 + let indices = tensor.argsort(2); + + let indices_expected = TensorData::from([[[0, 2, 1], [0, 2, 1]], [[0, 1, 2], [2, 0, 1]]]); + indices.into_data().assert_eq(&indices_expected, false); + } + + #[test] + fn test_sort_float_nan() { + let tensor = TestTensor::<2>::from([[-0.5, f32::NAN], [0., 0.94], [-0.3, f32::NAN]]); + + // Sort along dim=0 + let values = tensor.sort(0); + + let values_expected = TensorData::from([[-0.5, 0.94], [-0.3, f32::NAN], [0., f32::NAN]]); + values.into_data().assert_approx_eq(&values_expected, 5); + } + + #[test] + fn test_sort_descending_1d() { + // Quantized [1.0, 2.0, 3.0, 4.0, 5.0] + let data = TensorData::quantized( + vec![-77i8, -26, 25, 76, 127], + [5], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor = TestTensor::<1>::from_data(data, &Default::default()); + + // Sort along dim=0 + let values = tensor.sort_descending(0); + + let values_expected = TensorData::from([5., 4., 3., 2., 1.]); + values + .dequantize() + .into_data() + .assert_approx_eq(&values_expected, 5); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/sqrt.rs b/crates/burn-tensor/src/tests/quantization/ops/sqrt.rs new file mode 100644 index 0000000000..b75f5d95e9 --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/sqrt.rs @@ -0,0 +1,26 @@ 
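+// q_sqrt has no exact integer counterpart, so the expectation below allows for
+// de/quantization error (the result is compared at precision 3 after dequantize()).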
+#[burn_tensor_testgen::testgen(q_sqrt)] +mod tests { + use super::*; + use burn_tensor::quantization::{AffineQuantization, QuantizationStrategy}; + use burn_tensor::{Tensor, TensorData}; + use core::f32::consts::SQRT_2; + + #[test] + fn should_support_sqrt_ops() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.sqrt(); + let expected = TensorData::from([[0.0, 1.0, SQRT_2], [1.73205, 2.0, 2.2360]]); + + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 3); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/stack.rs b/crates/burn-tensor/src/tests/quantization/ops/stack.rs new file mode 100644 index 0000000000..b760841da3 --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/stack.rs @@ -0,0 +1,138 @@ +#[burn_tensor_testgen::testgen(q_stack)] +mod tests { + use super::*; + use alloc::vec::Vec; + use burn_tensor::quantization::{AffineQuantization, QuantizationStrategy}; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn should_support_stack_ops_2d_dim0() { + let device = Default::default(); + // Quantized [[1.0, 2.0, 3.0]] + let data = TensorData::quantized( + vec![-43i8, 42, 127], + [1, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.011764706, -128)), + ); + let tensor_1 = TestTensor::<2>::from_data(data, &device); + // Quantized [[4.0, 5.0, 6.0]] + let data = TensorData::quantized( + vec![42i8, 85, 127], + [1, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.023529412, -128)), + ); + let tensor_2 = TestTensor::<2>::from_data(data, &device); + + let output = Tensor::stack::<3>(vec![tensor_1, tensor_2], 0); + let expected = TensorData::from([[[1.0, 2.0, 3.0]], [[4.0, 5.0, 6.0]]]); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } + + #[test] + fn should_support_stack_ops_2d_dim1() { + let device = Default::default(); + // Quantized [[1.0, 2.0, 3.0]] + let data = TensorData::quantized( + vec![-43i8, 42, 127], + [1, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.011764706, -128)), + ); + let tensor_1 = TestTensor::<2>::from_data(data, &device); + // Quantized [[4.0, 5.0, 6.0]] + let data = TensorData::quantized( + vec![42i8, 85, 127], + [1, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.023529412, -128)), + ); + let tensor_2 = TestTensor::<2>::from_data(data, &device); + + let output = Tensor::stack::<3>(vec![tensor_1, tensor_2], 1); + let expected = TensorData::from([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } + + #[test] + fn should_support_stack_ops_3d() { + let device = Default::default(); + // Quantized [[[1.0, 2.0, 3.0]], [[3.0, 2.0, 1.0]]] + let data = TensorData::quantized( + vec![-43i8, 42, 127, 127, 42, -43], + [2, 1, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.011764706, -128)), + ); + let tensor_1 = TestTensor::<3>::from_data(data, &device); + // Quantized [[[4.0, 5.0, 6.0]], [[6.0, 5.0, 4.0]]] + let data = TensorData::quantized( + vec![42i8, 85, 127, 127, 85, 42], + [2, 1, 3], + 
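// tensor_2 holds values up to 6., so it carries its own per-tensor scale (6 / 255 ≈ 0.023529412), independent of tensor_1's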
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.023529412, -128)),
+        );
+        let tensor_2 = TestTensor::<3>::from_data(data, &device);
+
+        let output = Tensor::stack::<4>(vec![tensor_1, tensor_2], 0);
+        let expected = TensorData::from([
+            [[[1.0, 2.0, 3.0]], [[3.0, 2.0, 1.0]]],
+            [[[4.0, 5.0, 6.0]], [[6.0, 5.0, 4.0]]],
+        ]);
+
+        // Precision 1 to approximate de/quantization errors
+        output
+            .dequantize()
+            .into_data()
+            .assert_approx_eq(&expected, 1);
+    }
+
+    #[test]
+    #[should_panic]
+    fn should_panic_when_dimensions_are_not_the_same() {
+        let device = Default::default();
+        // Quantized [[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]]
+        let data = TensorData::quantized(
+            vec![-43i8, 42, 127, 127, 42, -43],
+            [2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.011764706, -128)),
+        );
+        let tensor_1 = TestTensor::<2>::from_data(data, &device);
+        // Quantized [[4.0, 5.0]]
+        let data = TensorData::quantized(
+            vec![76i8, 127],
+            [1, 2],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)),
+        );
+        let tensor_2 = TestTensor::<2>::from_data(data, &device);
+
+        let output: Tensor<TestBackend, 3> = Tensor::stack(vec![tensor_1, tensor_2], 0);
+    }
+
+    #[test]
+    #[should_panic]
+    fn should_panic_when_stack_exceeds_dimension() {
+        let device = Default::default();
+        // Quantized [[[1.0, 2.0, 3.0]], [[3.0, 2.0, 1.0]]]
+        let data = TensorData::quantized(
+            vec![-43i8, 42, 127, 127, 42, -43],
+            [1, 2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.011764706, -128)),
+        );
+        let tensor_1 = TestTensor::<3>::from_data(data, &device);
+        // Quantized [[[4.0, 5.0, 6.0]]]
+        let data = TensorData::quantized(
+            vec![42i8, 85, 127],
+            [1, 1, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.023529412, -128)),
+        );
+        let tensor_2 = TestTensor::<3>::from_data(data, &device);
+
+        let output: Tensor<TestBackend, 4> = TestTensor::stack(vec![tensor_1, tensor_2], 3);
+    }
+}
diff --git a/crates/burn-tensor/src/tests/quantization/ops/sub.rs b/crates/burn-tensor/src/tests/quantization/ops/sub.rs
new file mode 100644
index 0000000000..819a2fdb40
--- /dev/null
+++ b/crates/burn-tensor/src/tests/quantization/ops/sub.rs
@@ -0,0 +1,81 @@
+#[burn_tensor_testgen::testgen(q_sub)]
+mod tests {
+    use super::*;
+    use burn_tensor::quantization::{AffineQuantization, QuantizationStrategy};
+    use burn_tensor::{backend::Backend, Tensor, TensorData};
+
+    #[test]
+    fn should_support_sub_ops() {
+        let device = Default::default();
+        // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
+        let data = TensorData::quantized(
+            vec![-128i8, -77, -26, 25, 76, 127],
+            [2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)),
+        );
+        let tensor_1 = TestTensor::<2>::from_data(data, &device);
+        // Quantized [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]
+        let data = TensorData::quantized(
+            vec![11i8, 34, 57, 81, 104, 127],
+            [2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.043137256, -128)),
+        );
+        let tensor_2 = TestTensor::<2>::from_data(data, &device);
+
+        let output = tensor_1 - tensor_2;
+        let expected = TensorData::from([[-6.0, -6.0, -6.0], [-6.0, -6.0, -6.0]]);
+
+        // Precision 1 to approximate de/quantization errors
+        output
+            .dequantize()
+            .into_data()
+            .assert_approx_eq(&expected, 1);
+    }
+
+    #[test]
+    fn test_sub_broadcast() {
+        let device = Default::default();
+        // Quantized [[0.0, 
1.0, 2.0]] + let data = TensorData::quantized( + vec![-128i8, -1, 127], + [1, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.007843138, -128)), + ); + let tensor_1 = TestTensor::<2>::from_data(data, &device); + // Quantized [[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]] + let data = TensorData::quantized( + vec![-32i8, -1, 31, 63, 95, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.03137255, -128)), + ); + let tensor_2 = TestTensor::<2>::from_data(data, &device); + + let output = tensor_1 - tensor_2; + let expected = TensorData::from([[-3.0, -3.0, -3.0], [-6.0, -6.0, -6.0]]); + + // Precision 1 to approximate de/quantization errors + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + } + + #[test] + fn should_support_sub_scalar_ops() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + let scalar = 2.0; + + let output = tensor - scalar; + let expected = TensorData::from([[-2.0, -1.0, 0.0], [1.0, 2.0, 3.0]]); + + output.dequantize().into_data().assert_eq(&expected, false); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/tanh.rs b/crates/burn-tensor/src/tests/quantization/ops/tanh.rs new file mode 100644 index 0000000000..1ace09502b --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/tanh.rs @@ -0,0 +1,25 @@ +#[burn_tensor_testgen::testgen(q_tanh)] +mod tests { + use super::*; + use burn_tensor::quantization::{AffineQuantization, QuantizationStrategy}; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn should_support_tanh_ops() { + // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + let data = TensorData::quantized( + vec![-128i8, -77, -26, 25, 76, 127], + [2, 3], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.tanh(); + let expected = TensorData::from([[0.0, 0.7615, 0.9640], [0.9950, 0.9993, 0.9999]]); + + output + .dequantize() + .into_data() + .assert_approx_eq(&expected, 3); + } +} diff --git a/crates/burn-tensor/src/tests/quantization/ops/topk.rs b/crates/burn-tensor/src/tests/quantization/ops/topk.rs new file mode 100644 index 0000000000..7913fb79e1 --- /dev/null +++ b/crates/burn-tensor/src/tests/quantization/ops/topk.rs @@ -0,0 +1,91 @@ +#[burn_tensor_testgen::testgen(q_topk)] +mod tests { + use super::*; + use burn_tensor::quantization::{AffineQuantization, QuantizationStrategy}; + use burn_tensor::{Shape, Tensor, TensorData}; + + #[test] + fn test_topk_1d() { + // Quantized [1.0, 2.0, 3.0, 4.0, 5.0] + let data = TensorData::quantized( + vec![-77i8, -26, 25, 76, 127], + [5], + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)), + ); + let tensor = TestTensor::<1>::from_data(data, &Default::default()); + + let values = tensor.topk(3, /*dim*/ 0); + let expected = TensorData::from([5., 4., 3.]); + + values + .dequantize() + .into_data() + .assert_approx_eq(&expected, 3); + } + + #[test] + fn test_topk() { + // Quantized [[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]] + let data = TensorData::quantized( + vec![-100i8, -15, 70, -71, 14, 42, -43, -128, 127, 99, -71, 70], + [2, 2, 3], + 
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.03529412, -128)),
+        );
+        let tensor = TestTensor::<3>::from_data(data, &Default::default());
+
+        let values = tensor.topk(2, /*dim*/ 2);
+        let expected = TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 7.]]]);
+
+        // Precision 1 to approximate de/quantization errors
+        values
+            .dequantize()
+            .into_data()
+            .assert_approx_eq(&expected, 1);
+    }
+
+    #[test]
+    fn test_topk_with_indices() {
+        // 1D
+        // Quantized [1.0, 2.0, 3.0, 4.0, 5.0]
+        let data = TensorData::quantized(
+            vec![-77i8, -26, 25, 76, 127],
+            [5],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)),
+        );
+        let tensor = TestTensor::<1>::from_data(data, &Default::default());
+
+        let (values, indices) = tensor.topk_with_indices(3, /*dim*/ 0);
+
+        let values_expected = TensorData::from([5., 4., 3.]);
+        values
+            .dequantize()
+            .into_data()
+            .assert_eq(&values_expected, false);
+
+        let indices_expected = TensorData::from([4, 3, 2]);
+        indices.into_data().assert_eq(&indices_expected, false);
+
+        // 3D
+        // Quantized [[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]]
+        let data = TensorData::quantized(
+            vec![-100i8, -15, 70, -71, 14, 42, -43, -128, 127, 99, -71, 70],
+            [2, 2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.03529412, -128)),
+        );
+        let tensor = TestTensor::<3>::from_data(data, &Default::default());
+
+        let (values, indices) = tensor.topk_with_indices(2, /*dim*/ 2);
+
+        let values_expected = TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 7.]]]);
+
+        // Precision 1 to approximate de/quantization errors
+        values
+            .dequantize()
+            .into_data()
+            .assert_approx_eq(&values_expected, 1);
+
+        let indices_expected = TensorData::from([[[2, 1], [2, 1]], [[2, 0], [0, 2]]]);
+
+        indices.into_data().assert_eq(&indices_expected, false);
+    }
+}
diff --git a/crates/burn-tensor/src/tests/quantization/ops/transpose.rs b/crates/burn-tensor/src/tests/quantization/ops/transpose.rs
new file mode 100644
index 0000000000..259d313deb
--- /dev/null
+++ b/crates/burn-tensor/src/tests/quantization/ops/transpose.rs
@@ -0,0 +1,53 @@
+#[burn_tensor_testgen::testgen(q_transpose)]
+mod tests {
+    use super::*;
+    use burn_tensor::quantization::{AffineQuantization, QuantizationStrategy};
+    use burn_tensor::{Tensor, TensorData};
+
+    #[test]
+    fn should_support_transpose_ops() {
+        // Quantized [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0]
+        let data = TensorData::quantized(
+            vec![-128i8, -105, -82, -58, -35, -12, 11, 34, 57, 81, 104, 127],
+            [2, 2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.043137256, -128)),
+        );
+        let tensor = TestTensor::<3>::from_data(data, &Default::default());
+
+        let output = tensor.transpose();
+        let expected = TensorData::from([
+            [[0.0, 3.0], [1.0, 4.0], [2.0, 5.0]],
+            [[6.0, 9.0], [7.0, 10.0], [8.0, 11.0]],
+        ]);
+
+        // Precision 1 to approximate de/quantization errors
+        output
+            .dequantize()
+            .into_data()
+            .assert_approx_eq(&expected, 1);
+    }
+
+    #[test]
+    fn should_support_swap_dims() {
+        // Quantized [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0]
+        let data = TensorData::quantized(
+            vec![-128i8, -105, -82, -58, -35, -12, 11, 34, 57, 81, 104, 127],
+            [2, 2, 3],
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.043137256, -128)),
+        );
+        let tensor = TestTensor::<3>::from_data(data, &Default::default());
+
+        let output = tensor.swap_dims(0, 2);
+        let expected = TensorData::from([
+            [[0.0, 6.0], [3.0, 9.0]],
+            [[1.0, 7.0], [4.0, 10.0]],
+            [[2.0, 8.0], [5.0, 11.0]],
+        ]);
+
+        // Precision 1 to approximate de/quantization errors
+        output
+            .dequantize()
+            .into_data()
+            .assert_approx_eq(&expected, 1);
+    }
+}

From 7226467890528cd197068cce7807dc8affa1ee3e Mon Sep 17 00:00:00 2001
From: Guillaume Lagrange
Date: Tue, 6 Aug 2024 15:20:54 -0400
Subject: [PATCH 12/16] Clippy fix

---
 crates/burn-tensor/src/tensor/api/base.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs
index d9d13c7b03..b2c58d9d5b 100644
--- a/crates/burn-tensor/src/tensor/api/base.rs
+++ b/crates/burn-tensor/src/tensor/api/base.rs
@@ -1886,7 +1886,7 @@ impl<B: Backend> BasicOps<B> for Float {
     }
 
     fn cat<const D: usize>(vectors: Vec<Self::Primitive<D>>, dim: usize) -> Self::Primitive<D> {
-        match vectors.get(0).unwrap() {
+        match vectors.first().unwrap() {
             TensorPrimitive::Float(_) => TensorPrimitive::Float(B::float_cat(
                 vectors.into_iter().map(|tensor| tensor.tensor()).collect(),
                 dim,

From 59652f3e5afeb488674c25035a2fe2f71486f558 Mon Sep 17 00:00:00 2001
From: Guillaume Lagrange
Date: Tue, 6 Aug 2024 15:26:50 -0400
Subject: [PATCH 13/16] Remove dead code/comments

---
 crates/burn-core/src/module/quantize.rs        | 13 -----------
 .../src/tensor/quantization/calibration.rs     | 22 -------------------
 2 files changed, 35 deletions(-)

diff --git a/crates/burn-core/src/module/quantize.rs b/crates/burn-core/src/module/quantize.rs
index b6c34ee029..fbbad115a2 100644
--- a/crates/burn-core/src/module/quantize.rs
+++ b/crates/burn-core/src/module/quantize.rs
@@ -12,10 +12,6 @@ pub struct Quantizer<C: Calibration> {
     pub calibration: C,
     /// The quantization scheme.
     pub scheme: QuantizationScheme,
-    // TODO: dynamic quant? I think we won't support fully static (with observers to record the values on data samples)
-    // just yet so this is not required.
-    // /// Dynamic quantization computes the quantized parameters at runtime.
-    // pub dynamic: bool,
 }
 
 impl<B: Backend, C: Calibration> ModuleMapper<B> for Quantizer<C> {
@@ -25,12 +21,3 @@ impl<B: Backend, C: Calibration> ModuleMapper<B> for Quantizer<C> {
         tensor.quantize(&self.scheme, qparams)
     }
 }
-
-// /// Describes how to quantize a module by providing quantizer settings for activations and weights respectively.
-// pub struct QuantizationConfig {
-//     // TODO: quantization config
-//     /// The quantizer used to quantize the activations (i.e., a layer's output).
-//     // pub activations: Quantizer,
-//     /// The quantizer used to quantize the weights.
-//     pub weights: Quantizer,
-// }
diff --git a/crates/burn-tensor/src/tensor/quantization/calibration.rs b/crates/burn-tensor/src/tensor/quantization/calibration.rs
index b7aac15d34..c8060f6547 100644
--- a/crates/burn-tensor/src/tensor/quantization/calibration.rs
+++ b/crates/burn-tensor/src/tensor/quantization/calibration.rs
@@ -32,25 +32,3 @@ impl Calibration for MinMaxCalibration {
         CalibrationRange { min, max }
     }
 }
-
-// Observers keep a running min/max, so for static quantization this can be computed multiple times w/ representative data to get the "global" min/max
-
-// pub struct PerChannelCalibrationSettings {
-//     pub dtype: QuantizationType,
-//     pub symmetric: bool,
-// }
-
-// For now, we only support static quantization. Since the tensor is dequantized to a float at the first operation, the remaining operations will all be performed on floats anyways.
-// But to test dynamic quantization, just make the first layer use dynamic quantization.
-
-/*
-let q_activation = Quantizer {
-    calibration: MinMaxCalibration {scheme: QuantizationScheme::PerTensorAffine(QuantizationType::QInt8)},
-    dynamic: true,
-};
-let q_weights = Quantizer {
-    calibration: MinMaxCalibration {scheme: QuantizationScheme::PerTensorAffine(QuantizationType::QInt8)},
-    dynamic: false,
-}
-
-*/
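
A note on the dynamic quantization mentioned in the comments removed above (illustrative
only — the helper below is a hypothetical sketch, not part of burn's API): dynamic
quantization derives the quantization parameters at runtime from the observed min/max
instead of a calibration pass. For an affine int8 scheme the computation is roughly:

    fn affine_params(min: f32, max: f32) -> (f32, i32) {
        // map [min, max] onto the full i8 range [-128, 127]
        let scale = (max - min) / 255.0;
        let offset = -128 - (min / scale).round() as i32;
        (scale, offset)
    }

For example, `affine_params(-10.0, 5.0)` yields scale = 15.0 / 255.0 ≈ 0.05882353 and
offset = 42 — the parameters asserted by `should_support_quantize_dynamic_int8` in the
next patch. Exact rounding behavior may differ between backends, which is why that test
sticks to fully representable values.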
From 36cc7e5f65b11b5a907a9ae2c8ddfb3ac7b45e6d Mon Sep 17 00:00:00 2001
From: Guillaume Lagrange
Date: Wed, 7 Aug 2024 09:10:47 -0400
Subject: [PATCH 14/16] Fix quantization tests precision

---
 .../src/tests/quantization/ops/aggregation.rs      |  2 +-
 .../burn-tensor/src/tests/quantization/ops/neg.rs  | 13 ++++---------
 .../src/tests/quantization/ops/quantize.rs         |  8 +++++---
 3 files changed, 10 insertions(+), 13 deletions(-)

diff --git a/crates/burn-tensor/src/tests/quantization/ops/aggregation.rs b/crates/burn-tensor/src/tests/quantization/ops/aggregation.rs
index 448d0a78a4..4bf7a7e76c 100644
--- a/crates/burn-tensor/src/tests/quantization/ops/aggregation.rs
+++ b/crates/burn-tensor/src/tests/quantization/ops/aggregation.rs
@@ -180,7 +180,7 @@ mod tests {
         output
             .dequantize()
             .into_data()
-            .assert_eq(&TensorData::from([240.0]), false);
+            .assert_approx_eq(&TensorData::from([240.0]), 3);
 
         // Quantized [[2.0, 0.0, 2.0], [3.0, 4.0, 5.0]]
         let data = TensorData::quantized(
diff --git a/crates/burn-tensor/src/tests/quantization/ops/neg.rs b/crates/burn-tensor/src/tests/quantization/ops/neg.rs
index 8a27c3834f..79afba5ce6 100644
--- a/crates/burn-tensor/src/tests/quantization/ops/neg.rs
+++ b/crates/burn-tensor/src/tests/quantization/ops/neg.rs
@@ -18,14 +18,9 @@ mod tests {
         let expected =
            TensorData::from([[-0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]).convert::<f32>();
         // -0.0 is represented differently than 0.0 so we make sure the values are the same in f32
-        assert_eq!(
-            output
-                .dequantize()
-                .into_data()
-                .convert::<f32>()
-                .as_slice::<f32>()
-                .unwrap(),
-            expected.as_slice::<f32>().unwrap()
-        );
+        output
+            .dequantize()
+            .into_data()
+            .assert_approx_eq(&expected, 3);
     }
 }
diff --git a/crates/burn-tensor/src/tests/quantization/ops/quantize.rs b/crates/burn-tensor/src/tests/quantization/ops/quantize.rs
index 872abeb26d..6991001416 100644
--- a/crates/burn-tensor/src/tests/quantization/ops/quantize.rs
+++ b/crates/burn-tensor/src/tests/quantization/ops/quantize.rs
@@ -84,15 +84,17 @@ mod tests {
     #[test]
     fn should_support_quantize_dynamic_int8() {
         let device = Default::default();
-        let tensor = Tensor::<TestBackend, 1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device);
+        // NOTE: we use fully representable values since different backend implementations could differ slightly
+        // due to rounding discrepancies
+        let tensor = Tensor::<TestBackend, 1>::from_floats([5., 0., 4., -10.], &device);
         let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8);
 
         let x_q = tensor.quantize_dynamic(&scheme);
 
         let expected = TensorData::quantized(
-            vec![-128i8, -39, 72, 127],
+            vec![127i8, 42, 110, -128],
             [4],
-            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.009_019_608, 72)),
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.05882353, 42)),
         );
 
         x_q.to_data().assert_eq(&expected, true);

From bbb5ece25da2ec1c65a9861cf62142b83558da0d Mon Sep 17 00:00:00 2001
From: Guillaume Lagrange
Date: Wed, 7 Aug 2024 09:34:38 -0400
Subject: [PATCH 15/16] Set higher tolerance for ndarray backend

---
 .../src/tests/quantization/ops/aggregation.rs         | 9 +++++----
 crates/burn-tensor/src/tests/quantization/ops/neg.rs  | 3 ++-
 crates/burn-tensor/src/tests/quantization/ops/powf.rs | 5 +++--
 .../src/tests/quantization/ops/powf_scalar.rs         | 5 +++--
 4 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/crates/burn-tensor/src/tests/quantization/ops/aggregation.rs b/crates/burn-tensor/src/tests/quantization/ops/aggregation.rs
index 4bf7a7e76c..85a1db1b0d 100644
--- a/crates/burn-tensor/src/tests/quantization/ops/aggregation.rs
+++ b/crates/burn-tensor/src/tests/quantization/ops/aggregation.rs
@@ -200,10 +200,11 @@ mod tests {
     #[test]
     fn test_prod_dim_float() {
         // Quantized [[2.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
+        // NOTE: we use affine quantization to reduce quantization errors since `prod()` amplifies the error
         let data = TensorData::quantized(
-            vec![51i8, 25, 51, 76, 102, 127],
+            vec![-26i8, -77, -26, 25, 76, 127],
             [2, 3],
-            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)),
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)),
         );
         let tensor = TestTensor::<2>::from_data(data, &Default::default());
         let output = tensor.prod_dim(1);
@@ -216,9 +217,9 @@ mod tests {
 
         // Quantized [[2.0, 0.0, 2.0], [3.0, 4.0, 5.0]]
         let data = TensorData::quantized(
-            vec![51i8, 0, 51, 76, 102, 127],
+            vec![-26i8, -128, -26, 25, 76, 127],
             [2, 3],
-            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)),
+            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)),
         );
         let tensor_with_zero = TestTensor::<2>::from_data(data, &Default::default());
         let output = tensor_with_zero.prod_dim(1);
diff --git a/crates/burn-tensor/src/tests/quantization/ops/neg.rs b/crates/burn-tensor/src/tests/quantization/ops/neg.rs
index 79afba5ce6..6098b5d35c 100644
--- a/crates/burn-tensor/src/tests/quantization/ops/neg.rs
+++ b/crates/burn-tensor/src/tests/quantization/ops/neg.rs
@@ -18,9 +18,10 @@ mod tests {
         let expected =
            TensorData::from([[-0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]).convert::<f32>();
         // -0.0 is represented differently than 0.0 so we make sure the values are the same in f32
+        // Precision 1 to approximate de/quantization errors
         output
             .dequantize()
             .into_data()
-            .assert_approx_eq(&expected, 3);
+            .assert_approx_eq(&expected, 1);
     }
 }
diff --git a/crates/burn-tensor/src/tests/quantization/ops/powf.rs b/crates/burn-tensor/src/tests/quantization/ops/powf.rs
index d944d39517..79252ec6fe 100644
--- a/crates/burn-tensor/src/tests/quantization/ops/powf.rs
+++ b/crates/burn-tensor/src/tests/quantization/ops/powf.rs
@@ -111,10 +111,11 @@ mod tests {
         let output = tensor.powf(tensor_pow);
         let expected = TensorData::from([[0.0, -1.0, -8.0], [-27.0, -64.0, -64.0]]);
 
-        // NOTE: we set higher tolerance (0.2) due to larger de/quantization errors accumulation w/ powers
+        // NOTE: we set a higher tolerance (0.3) due to the larger accumulation of de/quantization errors
+        // with powers and the large output range
         output
             .dequantize()
             .into_data()
-            .assert_approx_eq_diff(&expected, 0.2);
+            .assert_approx_eq_diff(&expected, 0.3);
     }
 }
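
A rough sense of why 0.2 was too tight (illustrative arithmetic, not from the patch): the
expected outputs span [-64, 0], so the per-tensor int8 output scale is about 64 / 255 ≈ 0.25,
and re-quantizing the result can by itself shift a value by ~scale / 2 ≈ 0.13; on top of that,
input rounding error (~4 / 255 / 2 ≈ 0.008) is amplified by the cube's derivative, 3x² ≈ 48
near x = -4. The looser 0.3 bound leaves headroom for both effects.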
diff --git a/crates/burn-tensor/src/tests/quantization/ops/powf_scalar.rs b/crates/burn-tensor/src/tests/quantization/ops/powf_scalar.rs
index bdf65b5b92..fe2b9ed67d 100644
--- a/crates/burn-tensor/src/tests/quantization/ops/powf_scalar.rs
+++ b/crates/burn-tensor/src/tests/quantization/ops/powf_scalar.rs
@@ -80,10 +80,11 @@ mod tests {
         let output = tensor.powf_scalar(3.0);
         let expected = TensorData::from([[0.0, -1.0, -8.0], [-27.0, -64.0, -64.0]]);
 
-        // NOTE: we set higher tolerance (0.2) due to larger de/quantization errors accumulation w/ powers
+        // NOTE: we set a higher tolerance (0.3) due to the larger accumulation of de/quantization errors
+        // with powers and the large output range
         output
             .dequantize()
             .into_data()
-            .assert_approx_eq_diff(&expected, 0.2);
+            .assert_approx_eq_diff(&expected, 0.3);
     }
 }

From d5e6976fb35a5f2ce28c59a0f77d3a63396420cd Mon Sep 17 00:00:00 2001
From: Guillaume Lagrange
Date: Mon, 9 Sep 2024 11:59:51 -0400
Subject: [PATCH 16/16] Remove comment

---
 crates/burn-tensor/src/tests/quantization/ops/add.rs | 1 -
 1 file changed, 1 deletion(-)

diff --git a/crates/burn-tensor/src/tests/quantization/ops/add.rs b/crates/burn-tensor/src/tests/quantization/ops/add.rs
index 188ad8b005..7ac50f7f8e 100644
--- a/crates/burn-tensor/src/tests/quantization/ops/add.rs
+++ b/crates/burn-tensor/src/tests/quantization/ops/add.rs
@@ -56,7 +56,6 @@ mod tests {
             .assert_approx_eq(&TensorData::from([[3.0, 5.0, 7.0], [6.0, 8.0, 10.0]]), 1);
     }
 
-    // TODO: tests
     #[test]
     fn test_add_different_strides_rhs() {
         // Quantized [[0.0, 1.0], [2.0, 3.0]]
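+        // NOTE (illustrative, added editorially): "different strides" means the rhs is
+        // produced by a layout-changing op such as `transpose`/`swap_dims`, so the
+        // quantized add must read a non-contiguous operand correctly; the i8 values are
+        // encoded with the same per-tensor affine scheme sketched in the sub tests.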