Skip to content

Commit

Permalink
Change class instantiation
Browse files Browse the repository at this point in the history
Signed-off-by: Dheeraj Peri <[email protected]>
  • Loading branch information
peri044 committed Mar 15, 2021
1 parent 71adb44 commit 322a415
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 52 deletions.
11 changes: 4 additions & 7 deletions docs/_sources/tutorials/ptq.rst.txt
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ TRTorch Python API provides an easy and convenient way to use pytorch dataloader
a TensorRT calibrator by providing desired configuration. The following code demonstrates an example on how to use it

.. code-block:: python
self.testing_dataset = torchvision.datasets.CIFAR10(root='./data',
train=False,
download=True,
Expand All @@ -170,7 +171,7 @@ a TensorRT calibrator by providing desired configuration. The following code dem
compile_spec = {
"input_shapes": [[1, 3, 32, 32]],
"op_precision": torch.int8,
"calibrator": self.calibrator(),
"calibrator": self.calibrator,
"device": {
"device_type": trtorch.DeviceType.GPU,
"gpu_id": 0,
Expand All @@ -189,13 +190,9 @@ to use ``CacheCalibrator`` to use in INT8 mode.
calibrator = trtorch.ptq.CacheCalibrator("./calibration.cache")
compile_settings = {
"input_shapes": [{
"min": [1, 3, 32, 32],
"opt": [1, 3, 32, 32],
"max": [1, 3, 32, 32]
},],
"input_shapes": [[1, 3, 32, 32]],
"op_precision": torch.int8,
"calibrator": calibrator(),
"calibrator": calibrator,
"max_batch_size": 32,
}
Expand Down
106 changes: 63 additions & 43 deletions py/trtorch/ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,51 +51,64 @@ def write_calibration_cache(self, cache):


class DataLoaderCalibrator(object):

def __init__(self, dataloader, **kwargs):
self.algo_type = kwargs.get("algo_type", trtorch.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2)
self.cache_file = kwargs.get("cache_file", None)
self.use_cache = kwargs.get("use_cache", False)
self.device = kwargs.get("device", torch.device("cuda:0"))
"""
Constructs a calibrator class in TensorRT and uses pytorch dataloader to load/preproces
data which is passed during calibration.
Args:
dataloader: an instance of pytorch dataloader which iterates through a given dataset.
algo_type: choice of calibration algorithm.
cache_file: path to cache file.
use_cache: flag which enables usage of pre-existing cache.
device: device on which calibration data is copied to.
"""

def __init__(self, **kwargs):
pass

def __new__(cls, *args, **kwargs):
dataloader = args[0]
algo_type = kwargs.get("algo_type", trtorch.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2)
cache_file = kwargs.get("cache_file", None)
use_cache = kwargs.get("use_cache", False)
device = kwargs.get("device", torch.device("cuda:0"))

if not isinstance(dataloader, torch.utils.data.DataLoader):
log(Level.Error,
"Dataloader : {} is not a valid instance of torch.utils.data.DataLoader".format(dataloader))

if not self.cache_file:
if self.use_cache:
log(Level.Debug, "Using existing cache_file {} for calibration".format(self.cache_file))
if not cache_file:
if use_cache:
log(Level.Debug, "Using existing cache_file {} for calibration".format(cache_file))
else:
log(Level.Debug, "Overwriting existing calibration cache file.")
else:
if self.use_cache:
if use_cache:
log(Level.Error, "Input cache file is None but use_cache is set to True in INT8 mode.")

# Define attributes and member functions for the calibrator class
self.attribute_mapping = {
attribute_mapping = {
'data_loader': dataloader,
'current_batch_idx': 0,
'batch_size': dataloader.batch_size,
'dataset_iterator': iter(dataloader),
'cache_file': self.cache_file,
'device': self.device,
'use_cache': self.use_cache,
'cache_file': cache_file,
'device': device,
'use_cache': use_cache,
'get_batch_size': get_batch_size,
'get_batch': get_cache_mode_batch if self.use_cache else get_batch,
'get_batch': get_cache_mode_batch if use_cache else get_batch,
'read_calibration_cache': read_calibration_cache,
'write_calibration_cache': write_calibration_cache
}

def __call__(self):
# Using type metaclass to construct calibrator class based on algorithm type
if self.algo_type == CalibrationAlgo.ENTROPY_CALIBRATION:
return type('DataLoaderCalibrator', (trtorch._C.IInt8EntropyCalibrator,), self.attribute_mapping)()
elif self.algo_type == CalibrationAlgo.ENTROPY_CALIBRATION_2:
return type('DataLoaderCalibrator', (trtorch._C.IInt8MinMaxCalibrator,), self.attribute_mapping)()
elif self.algo_type == CalibrationAlgo.LEGACY_CALIBRATION:
return type('DataLoaderCalibrator', (trtorch._C.IInt8LegacyCalibrator,), self.attribute_mapping)()
elif self.algo_type == CalibrationAlgo.MINMAX_CALIBRATION:
return type('DataLoaderCalibrator', (trtorch._C.IInt8MinMaxCalibrator,), self.attribute_mapping)()
if algo_type == CalibrationAlgo.ENTROPY_CALIBRATION:
return type('DataLoaderCalibrator', (trtorch._C.IInt8EntropyCalibrator,), attribute_mapping)()
elif algo_type == CalibrationAlgo.ENTROPY_CALIBRATION_2:
return type('DataLoaderCalibrator', (trtorch._C.IInt8MinMaxCalibrator,), attribute_mapping)()
elif algo_type == CalibrationAlgo.LEGACY_CALIBRATION:
return type('DataLoaderCalibrator', (trtorch._C.IInt8LegacyCalibrator,), attribute_mapping)()
elif algo_type == CalibrationAlgo.MINMAX_CALIBRATION:
return type('DataLoaderCalibrator', (trtorch._C.IInt8MinMaxCalibrator,), attribute_mapping)()
else:
log(
Level.Error,
Expand All @@ -104,36 +117,43 @@ def __call__(self):


class CacheCalibrator(object):

def __init__(self, cache_file, **kwargs):
self.cache_file = cache_file
self.algo_type = kwargs.get("algo_type", trtorch.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2)

if os.path.isfile(self.cache_file):
log(Level.Debug, "Using existing cache_file {} for calibration".format(self.cache_file))
"""
Constructs a calibrator class in TensorRT which directly uses pre-existing cache file for calibration.
Args:
cache_file: path to cache file.
algo_type: choice of calibration algorithm.
"""

def __init__(self, **kwargs):
pass

def __new__(cls, *args, **kwargs):
cache_file = args[0]
algo_type = kwargs.get("algo_type", trtorch.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2)

if os.path.isfile(cache_file):
log(Level.Debug, "Using existing cache_file {} for calibration".format(cache_file))
else:
log(Level.Error, "Invalid calibration cache file.")

# Define attributes and member functions for the calibrator class
self.attribute_mapping = {
attribute_mapping = {
'use_cache': True,
'cache_file': self.cache_file,
'cache_file': cache_file,
'get_batch_size': get_batch_size,
'get_batch': get_cache_mode_batch,
'read_calibration_cache': read_calibration_cache,
'write_calibration_cache': write_calibration_cache
}

def __call__(self):
# Using type metaclass to construct calibrator class based on algorithm type
if self.algo_type == CalibrationAlgo.ENTROPY_CALIBRATION:
return type('DataLoaderCalibrator', (trtorch._C.IInt8EntropyCalibrator,), self.attribute_mapping)()
elif self.algo_type == CalibrationAlgo.ENTROPY_CALIBRATION_2:
return type('DataLoaderCalibrator', (trtorch._C.IInt8MinMaxCalibrator,), self.attribute_mapping)()
elif self.algo_type == CalibrationAlgo.LEGACY_CALIBRATION:
return type('DataLoaderCalibrator', (trtorch._C.IInt8LegacyCalibrator,), self.attribute_mapping)()
elif self.algo_type == CalibrationAlgo.MINMAX_CALIBRATION:
return type('DataLoaderCalibrator', (trtorch._C.IInt8MinMaxCalibrator,), self.attribute_mapping)()
if algo_type == CalibrationAlgo.ENTROPY_CALIBRATION:
return type('DataLoaderCalibrator', (trtorch._C.IInt8EntropyCalibrator,), attribute_mapping)()
elif algo_type == CalibrationAlgo.ENTROPY_CALIBRATION_2:
return type('DataLoaderCalibrator', (trtorch._C.IInt8MinMaxCalibrator,), attribute_mapping)()
elif algo_type == CalibrationAlgo.LEGACY_CALIBRATION:
return type('DataLoaderCalibrator', (trtorch._C.IInt8LegacyCalibrator,), attribute_mapping)()
elif algo_type == CalibrationAlgo.MINMAX_CALIBRATION:
return type('DataLoaderCalibrator', (trtorch._C.IInt8MinMaxCalibrator,), attribute_mapping)()
else:
log(
Level.Error,
Expand Down
4 changes: 2 additions & 2 deletions tests/py/test_ptq_dataloader_calibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def setUp(self):
shuffle=False,
num_workers=1)
self.calibrator = trtorch.ptq.DataLoaderCalibrator(self.testing_dataloader,
cache_file=None,
cache_file='./calibration.cache',
use_cache=False,
algo_type=trtorch.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
device=torch.device('cuda:0'))
Expand Down Expand Up @@ -63,7 +63,7 @@ def test_compile_script(self):
compile_spec = {
"input_shapes": [[1, 3, 32, 32]],
"op_precision": torch.int8,
"calibrator": self.calibrator(),
"calibrator": self.calibrator,
"device": {
"device_type": trtorch.DeviceType.GPU,
"gpu_id": 0,
Expand Down

0 comments on commit 322a415

Please sign in to comment.