Change fractional (and others) to be a property, move quantizers #964

Merged · 13 commits · Feb 16, 2024
16 changes: 15 additions & 1 deletion hls4ml/backends/fpga/fpga_types.py
@@ -42,6 +42,10 @@ def definition_cpp(self):
            self._saturation_mode_cpp(self.saturation_mode),
            self.saturation_bits,
        ]
+        if args[2] == 'AP_TRN' and args[3] == 'AP_WRAP' and args[4] == 0:
+            # This is the default, so we won't write the full definition for brevity
+            args[2] = args[3] = args[4] = None
+
        args = ','.join([str(arg) for arg in args if arg is not None])
        typestring = 'ap_{signed}fixed<{args}>'.format(signed='u' if not self.signed else '', args=args)
        return typestring
@@ -71,7 +75,17 @@ def definition_cpp(self):
            self._saturation_mode_cpp(self.saturation_mode),
            self.saturation_bits,
        ]
-        args = ','.join([str(arg) for arg in args if arg is not None])
+        if args[3] == 'AC_TRN' and args[4] == 'AC_WRAP':
+            # This is the default, so we won't write the full definition for brevity
+            args[3] = args[4] = None
+        if args[5] > 0:
+            print(
+                f'WARNING: Invalid setting of saturation bits ({args[5]}) for ac_fixed type, only 0 is allowed.'
+                'Ignoring set value.'
+            )
+            args[5] = None
+
+        args = ','.join([str(arg) for arg in args[:5] if arg is not None])
        typestring = f'ac_fixed<{args}>'
        return typestring

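The net effect of the two hunks above is that default rounding/saturation settings are no longer spelled out in the emitted C++ type. A minimal sketch of that behaviour, using an illustrative helper (ap_fixed_decl is a stand-in for the logic above, not hls4ml API):

    def ap_fixed_decl(width, integer, signed=True,
                      rounding='AP_TRN', saturation='AP_WRAP', sat_bits=0):
        """Build an ap_fixed declaration, omitting the default mode arguments."""
        args = [width, integer, rounding, saturation, sat_bits]
        if args[2] == 'AP_TRN' and args[3] == 'AP_WRAP' and args[4] == 0:
            args[2] = args[3] = args[4] = None  # defaults are implied, keep the declaration short
        args = ','.join(str(a) for a in args if a is not None)
        return 'ap_{}fixed<{}>'.format('' if signed else 'u', args)

    print(ap_fixed_decl(16, 6))                                           # ap_fixed<16,6>
    print(ap_fixed_decl(16, 6, rounding='AP_RND', saturation='AP_SAT'))   # ap_fixed<16,6,AP_RND,AP_SAT,0>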
2 changes: 0 additions & 2 deletions hls4ml/backends/quartus/passes/convolution_winograd.py
@@ -118,7 +118,6 @@ def transform(self, model, node):
            # Fractional precision is increased by 2 bits (division by 4),
            # for low-precision (less than 8) fractional weights
            if node.weights['weight'].type.precision.fractional < 8:
-                node.weights['weight'].type.precision.fractional += 2
                node.weights['weight'].type.precision.width += 2

            # Modified kernel size
@@ -163,7 +162,6 @@ def transform(self, model, node):
            # Fractional precision is increased by 2 bits (division by 4),
            # for low-precision (less than 8) fractional weights
            if node.weights['weight'].type.precision.fractional < 8:
-                node.weights['weight'].type.precision.fractional += 2
                node.weights['weight'].type.precision.width += 2

            # Modified kernel size
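These two deletions follow from the PR title: fractional is now a derived, read-only property, so only width needs to be bumped and the fractional part grows with it. A toy illustration of the property semantics (the FixedPrecision class below is a stand-in, not the hls4ml type):

    class FixedPrecision:
        def __init__(self, width, integer):
            self.width = width
            self.integer = integer

        @property
        def fractional(self):
            # Derived from width and integer, so there is no separate counter to keep in sync
            return self.width - self.integer

    p = FixedPrecision(width=8, integer=4)
    p.width += 2            # what the Winograd pass now does
    print(p.fractional)     # 6: grew by 2 without assigning to 'fractional'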
5 changes: 3 additions & 2 deletions hls4ml/backends/vivado/vivado_backend.py
@@ -376,8 +376,9 @@ def init_depconv2d(self, layer):
    def _set_pooling_accum_t(self, layer, pool_size):
        extra_bits = ceil_log2(pool_size)
        accum_t = layer.get_attr('accum_t')
-        accum_t.precision.fractional += extra_bits
-        accum_t.precision.integer += extra_bits
+        accum_t.precision.width += extra_bits * 2
+        if isinstance(accum_t.precision, FixedPrecisionType):
+            accum_t.precision.integer += extra_bits

    @layer_optimizer(Pooling1D)
    def init_pooling1d(self, layer):
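Same idea as in the Winograd pass: the old fractional/integer bumps are replaced by a single width bump (plus an integer bump for fixed types), and the fractional part grows implicitly. A quick check of the arithmetic with illustrative numbers:

    n = 3                    # extra_bits for a pool of size 8, since ceil(log2(8)) == 3
    width, integer = 16, 6   # illustrative accum_t precision
    width += 2 * n           # new code: width grows by 2n
    integer += n             # new code (fixed types only): integer grows by n
    print(width, integer, width - integer)   # 22 9 13 -> fractional went from 10 to 13, i.e. +n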
3 changes: 2 additions & 1 deletion hls4ml/converters/keras/core.py
@@ -1,5 +1,6 @@
from hls4ml.converters.keras_to_hls import get_weights_data, keras_handler, parse_default_keras_layer
-from hls4ml.model.types import BinaryQuantizer, IntegerPrecisionType, TernaryQuantizer
+from hls4ml.model.quantizers import BinaryQuantizer, TernaryQuantizer
+from hls4ml.model.types import IntegerPrecisionType


@keras_handler('InputLayer')
2 changes: 1 addition & 1 deletion hls4ml/converters/keras/graph.py
@@ -1,5 +1,5 @@
-from hls4ml.converters.keras.core import TernaryQuantizer
 from hls4ml.converters.keras_to_hls import get_weights_data, keras_handler, parse_default_keras_layer
+from hls4ml.model.quantizers import TernaryQuantizer


@keras_handler('GarNet', 'GarNetStack')
3 changes: 2 additions & 1 deletion hls4ml/converters/keras/qkeras.py
@@ -4,7 +4,8 @@
from hls4ml.converters.keras.core import parse_batchnorm_layer, parse_dense_layer
from hls4ml.converters.keras.recurrent import parse_rnn_layer
from hls4ml.converters.keras_to_hls import keras_handler, parse_default_keras_layer
-from hls4ml.model.types import FixedPrecisionType, QKerasBinaryQuantizer, QKerasPO2Quantizer, QKerasQuantizer
+from hls4ml.model.quantizers import QKerasBinaryQuantizer, QKerasPO2Quantizer, QKerasQuantizer
+from hls4ml.model.types import FixedPrecisionType


def get_quantizer_from_config(keras_layer, quantizer_var):
8 changes: 4 additions & 4 deletions hls4ml/model/optimizer/passes/precision_merge.py
@@ -1,5 +1,5 @@
from hls4ml.model.optimizer import OptimizerPass
-from hls4ml.model.types import FixedPrecisionType
+from hls4ml.model.types import FixedPrecisionType, RoundingMode, SaturationMode


def get_concat_type(itype1, itype2):
@@ -8,9 +8,9 @@ def get_concat_type(itype1, itype2):
    if itype1.signed ^ itype2.signed:  # XOR
        newint += 1
        newwidth += 1
-    newrmode = itype1.rounding_mode if itype1.rounding_mode is not None else itype2.rounding_mode
-    newsmode = itype1.saturation_mode if itype1.saturation_mode is not None else itype2.saturation_mode
-    newsbits = itype1.saturation_bits if itype1.saturation_bits is not None else itype2.saturation_bits
+    newrmode = itype1.rounding_mode if itype1.rounding_mode != RoundingMode.TRN else itype2.rounding_mode
+    newsmode = itype1.saturation_mode if itype1.saturation_mode != SaturationMode.WRAP else itype2.saturation_mode
+    newsbits = itype1.saturation_bits if itype1.saturation_bits != 0 else itype2.saturation_bits

    newtype = FixedPrecisionType(newwidth, newint, itype1.signed or itype2.signed, newrmode, newsmode, newsbits)
    return newtype
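With rounding and saturation now always set to a concrete default (rather than None), the merge keeps whichever operand deviates from that default. A toy sketch of the selection rule (plain strings stand in for the RoundingMode enum members):

    DEFAULT_ROUNDING = 'TRN'   # stand-in for RoundingMode.TRN

    def merge_rounding(mode1, mode2):
        # Keep the first operand's mode unless it is just the default,
        # in which case the second operand's (possibly non-default) mode wins.
        return mode1 if mode1 != DEFAULT_ROUNDING else mode2

    print(merge_rounding('TRN', 'RND'))   # RND
    print(merge_rounding('RND', 'TRN'))   # RND
    print(merge_rounding('TRN', 'TRN'))   # TRN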
3 changes: 2 additions & 1 deletion hls4ml/model/optimizer/passes/qkeras.py
@@ -3,7 +3,8 @@

from hls4ml.model.layers import BatchNormalization, register_layer
from hls4ml.model.optimizer import ConfigurableOptimizerPass, OptimizerPass, register_pass
-from hls4ml.model.types import FixedPrecisionType, IntegerPrecisionType, NamedType, QKerasPO2Quantizer
+from hls4ml.model.quantizers import QKerasPO2Quantizer
+from hls4ml.model.types import FixedPrecisionType, IntegerPrecisionType, NamedType


class OutputRoundingSaturationMode(ConfigurableOptimizerPass):
160 changes: 160 additions & 0 deletions hls4ml/model/quantizers.py
@@ -0,0 +1,160 @@
"""
This module contains definitions of hls4ml quantizer classes. These classes apply a quantization function on the
provided data. The quantization function may be defined locally or taken from a library in which case the classes
behave like simple wrappers.
"""

import numpy as np
import tensorflow as tf
from qkeras.quantizers import get_quantizer

from hls4ml.model.types import ExponentPrecisionType, FixedPrecisionType, IntegerPrecisionType, XnorPrecisionType


class Quantizer:
"""
Base class for representing quantizers in hls4ml.

Subclasses of ``Quantizer`` are expected to wrap the quantizers of upstream tools (e.g., QKeras).

Args:
bits (int): Total number of bits used by the quantizer.
hls_type (NamedType): The hls4ml type used by the quantizer.
"""

def __init__(self, bits, hls_type):
self.bits = bits
self.hls_type = hls_type

def __call__(self, data):
raise NotImplementedError


class BinaryQuantizer(Quantizer):
"""Quantizer that quantizes to 0 and 1 (``bits=1``) or -1 and 1 (``bits==2``).

Args:
bits (int, optional): Number of bits used by the quantizer. Defaults to 2.

Raises:
Exception: Raised if ``bits>2``
"""

def __init__(self, bits=2):
if bits == 1:
hls_type = XnorPrecisionType()
elif bits == 2:
hls_type = IntegerPrecisionType(width=2)
else:
            raise Exception(f'BinaryQuantizer supports 1 or 2 bits, but called with bits={bits}')
        super().__init__(bits, hls_type)

    def __call__(self, data):
        zeros = np.zeros_like(data)
        ones = np.ones_like(data)
        quant_data = data
        if self.bits == 1:
            quant_data = np.where(data > 0, ones, zeros).astype('int')
        if self.bits == 2:
            quant_data = np.where(data > 0, ones, -ones)
        return quant_data


class TernaryQuantizer(Quantizer):
    """Quantizer that quantizes to -1, 0 and 1."""

    def __init__(self):
        super().__init__(2, IntegerPrecisionType(width=2))

    def __call__(self, data):
        zeros = np.zeros_like(data)
        ones = np.ones_like(data)
        return np.where(data > 0.5, ones, np.where(data <= -0.5, -ones, zeros))


class QKerasQuantizer(Quantizer):
    """Wrapper around QKeras quantizers.

    Args:
        config (dict): Config of the QKeras quantizer to wrap.
    """

    def __init__(self, config):
        self.quantizer_fn = get_quantizer(config)
        self.alpha = config['config'].get('alpha', None)
        if config['class_name'] == 'quantized_bits':
            self.bits = config['config']['bits']
            self.hls_type = self._get_type(config)
        # ! includes stochastic_ternary
        elif 'ternary' in config['class_name']:
            self.bits = 2
            self.hls_type = IntegerPrecisionType(width=2, signed=True)
        # ! includes stochastic_binary
        elif 'binary' in config['class_name']:
            self.bits = 1
            self.hls_type = XnorPrecisionType()
        else:
            print('Unsupported quantizer: ' + config['class_name'])
            self.bits = 16
            self.hls_type = FixedPrecisionType(width=16, integer=6, signed=True)

    def __call__(self, data):
        tf_data = tf.convert_to_tensor(data)
        return self.quantizer_fn(tf_data).numpy()
        # return self.quantizer_fn(data)

    def _get_type(self, quantizer_config):
        width = quantizer_config['config']['bits']
        integer = quantizer_config['config'].get('integer', 0)
        if quantizer_config['class_name'] == 'quantized_po2':
            return ExponentPrecisionType(width=width, signed=True)
        if width == integer:
            if width == 1:
                return XnorPrecisionType()
            else:
                return IntegerPrecisionType(width=width, signed=True)
        else:
            return FixedPrecisionType(width=width, integer=integer + 1, signed=True)


class QKerasBinaryQuantizer(Quantizer):
    """Wrapper around QKeras binary quantizer.

    Args:
        config (dict): Config of the QKeras quantizer to wrap.
    """

    def __init__(self, config, xnor=False):
        self.bits = 1 if xnor else 2
        self.hls_type = XnorPrecisionType() if xnor else IntegerPrecisionType(width=2, signed=True)
        self.alpha = config['config']['alpha']
        # Use the QKeras quantizer to handle any stochastic / alpha stuff
        self.quantizer_fn = get_quantizer(config)
        # Then we use our BinaryQuantizer to convert to '0,1' format
        self.binary_quantizer = BinaryQuantizer(1) if xnor else BinaryQuantizer(2)

    def __call__(self, data):
        x = tf.convert_to_tensor(data)
        y = self.quantizer_fn(x).numpy()
        return self.binary_quantizer(y)


class QKerasPO2Quantizer(Quantizer):
    """Wrapper around QKeras power-of-2 quantizers.

    Args:
        config (dict): Config of the QKeras quantizer to wrap.
    """

    def __init__(self, config):
        self.bits = config['config']['bits']
        self.quantizer_fn = get_quantizer(config)
        self.hls_type = ExponentPrecisionType(width=self.bits, signed=True)

    def __call__(self, data):
        # Weights are quantized to nearest power of two
        x = tf.convert_to_tensor(data)
        y = self.quantizer_fn(x)
        if hasattr(y, 'numpy'):
            y = y.numpy()
        return y
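A quick sanity check of the locally defined quantizers in the new module (arbitrary input values; assumes an hls4ml install that includes this PR so the new import path resolves):

    import numpy as np

    from hls4ml.model.quantizers import BinaryQuantizer, TernaryQuantizer

    data = np.array([-0.8, -0.3, 0.1, 0.6])
    print(BinaryQuantizer(bits=2)(data))   # [-1. -1.  1.  1.]
    print(BinaryQuantizer(bits=1)(data))   # [0 0 1 1]
    print(TernaryQuantizer()(data))        # [-1.  0.  0.  1.]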