Fix quantizeDequantize function to be able to handle non-contiguous input tensors

Signed-off-by: Kyunggeun Lee <[email protected]>
quic-kyunggeu authored May 12, 2023
1 parent 7db238f commit 912c8cd
Showing 2 changed files with 58 additions and 3 deletions.
15 changes: 12 additions & 3 deletions TrainingExtensions/torch/src/AimetTensorQuantizer.cpp
@@ -141,7 +141,7 @@ class AimetTensorQuantizer

         // Create an empty output tensor based on the dimension and options of input
         at::IntArrayRef sizes = input.sizes();
-        at::Tensor output = at::empty(sizes, input.options());
+        at::Tensor output = at::empty_like(input);


         size_t inputTensorSize = 1;
@@ -167,7 +167,7 @@

         // Create an empty output tensor based on the dimension and options of input
         at::IntArrayRef sizes = input.sizes();
-        at::Tensor output = at::empty(sizes, input.options());
+        at::Tensor output = at::empty_like(input);

         size_t inputTensorSize = 1;
         for (auto size: sizes)
@@ -257,8 +257,17 @@ class AimetTensorQuantizer
                                            size_t numChannel, size_t numElement, size_t numElementPerChannel,
                                            DlQuantization::RoundingMode roundingMode, bool useCuda)
     {
+        // Our per-channel quantizeDequantize kernel currently assumes that
+        // the input tensor has a contiguous memory format.
+        // `input.contiguous()` returns the input itself if it is already
+        // contiguous, and a contiguous copy of the input otherwise.
+        //
+        // This is a quick and dirty solution, but it is acceptable for now because
+        // the inputs to per-channel quantizeDequantize are almost always contiguous.
+        input = input.contiguous();
+
         // Allocate an output tensor with the same shape as the input
-        at::Tensor output = at::empty(input.sizes(), input.options());
+        at::Tensor output = at::empty_like(input);
         int encodingTensorSize = 2 * numChannel;

         // Collect encoding min/max data
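The changes above hinge on PyTorch/ATen memory-format semantics: `empty(sizes, options)` always allocates contiguous (row-major) strides, whereas `empty_like(input)` preserves the strides of a dense input, so the output buffer matches the input's physical layout and an elementwise kernel can walk both raw buffers in lockstep; `contiguous()` is a no-op (it returns the same tensor) when the input is already contiguous. A minimal Python sketch of these semantics, with illustrative tensor names:

    import torch

    # A dense but non-contiguous 4-D tensor (NHWC physical layout).
    x = torch.randn(2, 3, 4, 4).to(memory_format=torch.channels_last)
    assert not x.is_contiguous()

    # empty(sizes) always allocates contiguous strides, so its layout
    # would not match a channels_last input.
    old_style = torch.empty(x.shape)
    assert old_style.is_contiguous()

    # empty_like preserves the strides of a dense input, so the output
    # buffer matches the input's physical layout element for element.
    new_style = torch.empty_like(x)
    assert new_style.stride() == x.stride()

    # contiguous() returns the tensor itself when already contiguous,
    # and a contiguous copy otherwise.
    y = torch.randn(2, 3)
    assert y.contiguous().data_ptr() == y.data_ptr()   # no copy made
    assert x.contiguous().data_ptr() != x.data_ptr()   # copy made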
46 changes: 46 additions & 0 deletions TrainingExtensions/torch/test/python/test_tensor_quantizer.py
@@ -36,6 +36,7 @@
 # @@-COPYRIGHT-END-@@
 # =============================================================================

+import random
 import pytest
 import torch
 import aimet_common.libpymo as libpymo
@@ -48,6 +49,12 @@
 BUCKET_SIZE = 512


+@pytest.fixture(autouse=True)
+def set_seed():
+    random.seed(19521)
+    torch.random.manual_seed(19521)
+
+
 class TestStaticTensorQuantizer:

     def test_get_stats_histogram_per_tensor(self):
@@ -343,3 +350,42 @@ def test_learned_grid_update_encoding_invalid_input(self):

         with pytest.raises(RuntimeError):
             tensor_quantizer.encoding = enc_new

+    @pytest.mark.parametrize("quantizer",
+                             [StaticGridPerTensorQuantizer(bitwidth=8,
+                                                           round_mode=libpymo.RoundingMode.ROUND_NEAREST,
+                                                           quant_scheme=QuantScheme.post_training_tf,
+                                                           use_symmetric_encodings=False,
+                                                           enabled_by_default=True),
+                              StaticGridPerChannelQuantizer(bitwidth=8,
+                                                            round_mode=libpymo.RoundingMode.ROUND_NEAREST,
+                                                            quant_scheme=QuantScheme.post_training_tf,
+                                                            use_symmetric_encodings=True,
+                                                            enabled_by_default=True,
+                                                            num_channels=5,
+                                                            ch_axis=0),
+                              StaticGridPerChannelQuantizer(bitwidth=8,
+                                                            round_mode=libpymo.RoundingMode.ROUND_NEAREST,
+                                                            quant_scheme=QuantScheme.post_training_tf,
+                                                            use_symmetric_encodings=True,
+                                                            enabled_by_default=True,
+                                                            num_channels=5,
+                                                            ch_axis=1)
+                              ])
+    def test_non_contiguous_input_tensor(self, quantizer):
+        tensor = torch.randn((5, 5, 100, 100))
+        quantizer.update_encoding_stats(tensor)
+        quantizer.compute_encoding()
+
+        out_contiguous = quantizer.quantize_dequantize(tensor.to(memory_format=torch.contiguous_format),
+                                                       quantizer.round_mode)
+        out_channels_last = quantizer.quantize_dequantize(tensor.to(memory_format=torch.channels_last),
+                                                          quantizer.round_mode)
+        assert torch.allclose(out_contiguous, out_channels_last)
+
+        tensor = tensor.view(*tensor.shape, 1)
+        out_contiguous = quantizer.quantize_dequantize(tensor.to(memory_format=torch.contiguous_format),
+                                                       quantizer.round_mode)
+        out_channels_last_3d = quantizer.quantize_dequantize(tensor.to(memory_format=torch.channels_last_3d),
+                                                             quantizer.round_mode)
+        assert torch.allclose(out_contiguous, out_channels_last_3d)
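A note on the test's construction: torch.channels_last is defined for 4-D tensors and torch.channels_last_3d for 5-D tensors, which is why the test appends a trailing singleton dimension with view before exercising the 3-D variant. Converting the memory format permutes strides, not values, so both calls should quantize-dequantize to the same result. A standalone sketch of the property the test relies on:

    import torch

    t = torch.randn(5, 5, 100, 100)
    nhwc = t.to(memory_format=torch.channels_last)       # 4-D tensors only
    assert not nhwc.is_contiguous() and torch.equal(nhwc, t)

    t5 = t.view(*t.shape, 1)                             # 5-D for channels_last_3d
    ndhwc = t5.to(memory_format=torch.channels_last_3d)
    assert not ndhwc.is_contiguous() and torch.equal(ndhwc, t5)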
