diff --git a/hls4ml/backends/fpga/fpga_types.py b/hls4ml/backends/fpga/fpga_types.py index ceac0b5e4d..c5327dab8c 100644 --- a/hls4ml/backends/fpga/fpga_types.py +++ b/hls4ml/backends/fpga/fpga_types.py @@ -42,6 +42,10 @@ def definition_cpp(self): self._saturation_mode_cpp(self.saturation_mode), self.saturation_bits, ] + if args[2] == 'AP_TRN' and args[3] == 'AP_WRAP' and args[4] == 0: + # This is the default, so we won't write the full definition for brevity + args[2] = args[3] = args[4] = None + args = ','.join([str(arg) for arg in args if arg is not None]) typestring = 'ap_{signed}fixed<{args}>'.format(signed='u' if not self.signed else '', args=args) return typestring @@ -71,7 +75,17 @@ def definition_cpp(self): self._saturation_mode_cpp(self.saturation_mode), self.saturation_bits, ] - args = ','.join([str(arg) for arg in args if arg is not None]) + if args[3] == 'AC_TRN' and args[4] == 'AC_WRAP': + # This is the default, so we won't write the full definition for brevity + args[3] = args[4] = None + if args[5] > 0: + print( + f'WARNING: Invalid setting of saturation bits ({args[5]}) for ac_fixed type, only 0 is allowed.' + 'Ignoring set value.' + ) + args[5] = None + + args = ','.join([str(arg) for arg in args[:5] if arg is not None]) typestring = f'ac_fixed<{args}>' return typestring diff --git a/hls4ml/backends/quartus/passes/convolution_winograd.py b/hls4ml/backends/quartus/passes/convolution_winograd.py index 9a66864129..8b25ab41b8 100644 --- a/hls4ml/backends/quartus/passes/convolution_winograd.py +++ b/hls4ml/backends/quartus/passes/convolution_winograd.py @@ -118,7 +118,6 @@ def transform(self, model, node): # Fractional precision is increased by 2 bits (division by 4), # for low-precision (less than 8) fractional weights if node.weights['weight'].type.precision.fractional < 8: - node.weights['weight'].type.precision.fractional += 2 node.weights['weight'].type.precision.width += 2 # Modified kernel size @@ -163,7 +162,6 @@ def transform(self, model, node): # Fractional precision is increased by 2 bits (division by 4), # for low-precision (less than 8) fractional weights if node.weights['weight'].type.precision.fractional < 8: - node.weights['weight'].type.precision.fractional += 2 node.weights['weight'].type.precision.width += 2 # Modified kernel size diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py index 05caca6737..0c056a0c5c 100644 --- a/hls4ml/backends/vivado/vivado_backend.py +++ b/hls4ml/backends/vivado/vivado_backend.py @@ -376,8 +376,9 @@ def init_depconv2d(self, layer): def _set_pooling_accum_t(self, layer, pool_size): extra_bits = ceil_log2(pool_size) accum_t = layer.get_attr('accum_t') - accum_t.precision.fractional += extra_bits - accum_t.precision.integer += extra_bits + accum_t.precision.width += extra_bits * 2 + if isinstance(accum_t.precision, FixedPrecisionType): + accum_t.precision.integer += extra_bits @layer_optimizer(Pooling1D) def init_pooling1d(self, layer): diff --git a/hls4ml/converters/keras/core.py b/hls4ml/converters/keras/core.py index f6119c016d..ca7d0b3541 100644 --- a/hls4ml/converters/keras/core.py +++ b/hls4ml/converters/keras/core.py @@ -1,5 +1,6 @@ from hls4ml.converters.keras_to_hls import get_weights_data, keras_handler, parse_default_keras_layer -from hls4ml.model.types import BinaryQuantizer, IntegerPrecisionType, TernaryQuantizer +from hls4ml.model.quantizers import BinaryQuantizer, TernaryQuantizer +from hls4ml.model.types import IntegerPrecisionType @keras_handler('InputLayer') diff --git 
a/hls4ml/converters/keras/graph.py b/hls4ml/converters/keras/graph.py index 5c5c2247c0..954bf20b8f 100644 --- a/hls4ml/converters/keras/graph.py +++ b/hls4ml/converters/keras/graph.py @@ -1,5 +1,5 @@ -from hls4ml.converters.keras.core import TernaryQuantizer from hls4ml.converters.keras_to_hls import get_weights_data, keras_handler, parse_default_keras_layer +from hls4ml.model.quantizers import TernaryQuantizer @keras_handler('GarNet', 'GarNetStack') diff --git a/hls4ml/converters/keras/qkeras.py b/hls4ml/converters/keras/qkeras.py index cae0b2caf1..a8038da46d 100644 --- a/hls4ml/converters/keras/qkeras.py +++ b/hls4ml/converters/keras/qkeras.py @@ -4,7 +4,8 @@ from hls4ml.converters.keras.core import parse_batchnorm_layer, parse_dense_layer from hls4ml.converters.keras.recurrent import parse_rnn_layer from hls4ml.converters.keras_to_hls import keras_handler, parse_default_keras_layer -from hls4ml.model.types import FixedPrecisionType, QKerasBinaryQuantizer, QKerasPO2Quantizer, QKerasQuantizer +from hls4ml.model.quantizers import QKerasBinaryQuantizer, QKerasPO2Quantizer, QKerasQuantizer +from hls4ml.model.types import FixedPrecisionType def get_quantizer_from_config(keras_layer, quantizer_var): diff --git a/hls4ml/model/optimizer/passes/precision_merge.py b/hls4ml/model/optimizer/passes/precision_merge.py index 019bfd7236..9e79b11000 100644 --- a/hls4ml/model/optimizer/passes/precision_merge.py +++ b/hls4ml/model/optimizer/passes/precision_merge.py @@ -1,5 +1,5 @@ from hls4ml.model.optimizer import OptimizerPass -from hls4ml.model.types import FixedPrecisionType +from hls4ml.model.types import FixedPrecisionType, RoundingMode, SaturationMode def get_concat_type(itype1, itype2): @@ -8,9 +8,9 @@ def get_concat_type(itype1, itype2): if itype1.signed ^ itype2.signed: # XOR newint += 1 newwidth += 1 - newrmode = itype1.rounding_mode if itype1.rounding_mode is not None else itype2.rounding_mode - newsmode = itype1.saturation_mode if itype1.saturation_mode is not None else itype2.saturation_mode - newsbits = itype1.saturation_bits if itype1.saturation_bits is not None else itype2.saturation_bits + newrmode = itype1.rounding_mode if itype1.rounding_mode != RoundingMode.TRN else itype2.rounding_mode + newsmode = itype1.saturation_mode if itype1.saturation_mode != SaturationMode.WRAP else itype2.saturation_mode + newsbits = itype1.saturation_bits if itype1.saturation_bits != 0 else itype2.saturation_bits newtype = FixedPrecisionType(newwidth, newint, itype1.signed or itype2.signed, newrmode, newsmode, newsbits) return newtype diff --git a/hls4ml/model/optimizer/passes/qkeras.py b/hls4ml/model/optimizer/passes/qkeras.py index cdbb56ec46..ebc66fe59e 100644 --- a/hls4ml/model/optimizer/passes/qkeras.py +++ b/hls4ml/model/optimizer/passes/qkeras.py @@ -3,7 +3,8 @@ from hls4ml.model.layers import BatchNormalization, register_layer from hls4ml.model.optimizer import ConfigurableOptimizerPass, OptimizerPass, register_pass -from hls4ml.model.types import FixedPrecisionType, IntegerPrecisionType, NamedType, QKerasPO2Quantizer +from hls4ml.model.quantizers import QKerasPO2Quantizer +from hls4ml.model.types import FixedPrecisionType, IntegerPrecisionType, NamedType class OutputRoundingSaturationMode(ConfigurableOptimizerPass): diff --git a/hls4ml/model/quantizers.py b/hls4ml/model/quantizers.py new file mode 100644 index 0000000000..c857ef51ac --- /dev/null +++ b/hls4ml/model/quantizers.py @@ -0,0 +1,160 @@ +""" +This module contains definitions of hls4ml quantizer classes. 
These classes apply a quantization function on the +provided data. The quantization function may be defined locally or taken from a library in which case the classes +behave like simple wrappers. +""" + +import numpy as np +import tensorflow as tf +from qkeras.quantizers import get_quantizer + +from hls4ml.model.types import ExponentPrecisionType, FixedPrecisionType, IntegerPrecisionType, XnorPrecisionType + + +class Quantizer: + """ + Base class for representing quantizers in hls4ml. + + Subclasses of ``Quantizer`` are expected to wrap the quantizers of upstream tools (e.g., QKeras). + + Args: + bits (int): Total number of bits used by the quantizer. + hls_type (NamedType): The hls4ml type used by the quantizer. + """ + + def __init__(self, bits, hls_type): + self.bits = bits + self.hls_type = hls_type + + def __call__(self, data): + raise NotImplementedError + + +class BinaryQuantizer(Quantizer): + """Quantizer that quantizes to 0 and 1 (``bits=1``) or -1 and 1 (``bits=2``). + + Args: + bits (int, optional): Number of bits used by the quantizer. Defaults to 2. + + Raises: + Exception: Raised if ``bits>2`` + """ + + def __init__(self, bits=2): + if bits == 1: + hls_type = XnorPrecisionType() + elif bits == 2: + hls_type = IntegerPrecisionType(width=2) + else: + raise Exception(f'BinaryQuantizer supports 1 or 2 bits, but called with bits={bits}') + super().__init__(bits, hls_type) + + def __call__(self, data): + zeros = np.zeros_like(data) + ones = np.ones_like(data) + quant_data = data + if self.bits == 1: + quant_data = np.where(data > 0, ones, zeros).astype('int') + if self.bits == 2: + quant_data = np.where(data > 0, ones, -ones) + return quant_data + + +class TernaryQuantizer(Quantizer): + """Quantizer that quantizes to -1, 0 and 1.""" + + def __init__(self): + super().__init__(2, IntegerPrecisionType(width=2)) + + def __call__(self, data): + zeros = np.zeros_like(data) + ones = np.ones_like(data) + return np.where(data > 0.5, ones, np.where(data <= -0.5, -ones, zeros)) + + +class QKerasQuantizer(Quantizer): + """Wrapper around QKeras quantizers. + + Args: + config (dict): Config of the QKeras quantizer to wrap. + """ + + def __init__(self, config): + self.quantizer_fn = get_quantizer(config) + self.alpha = config['config'].get('alpha', None) + if config['class_name'] == 'quantized_bits': + self.bits = config['config']['bits'] + self.hls_type = self._get_type(config) + # ! includes stochastic_ternary + elif 'ternary' in config['class_name']: + self.bits = 2 + self.hls_type = IntegerPrecisionType(width=2, signed=True) + # !
includes stochastic_binary + elif 'binary' in config['class_name']: + self.bits = 1 + self.hls_type = XnorPrecisionType() + else: + print('Unsupported quantizer: ' + config['class_name']) + self.bits = 16 + self.hls_type = FixedPrecisionType(width=16, integer=6, signed=True) + + def __call__(self, data): + tf_data = tf.convert_to_tensor(data) + return self.quantizer_fn(tf_data).numpy() + # return self.quantizer_fn(data) + + def _get_type(self, quantizer_config): + width = quantizer_config['config']['bits'] + integer = quantizer_config['config'].get('integer', 0) + if quantizer_config['class_name'] == 'quantized_po2': + return ExponentPrecisionType(width=width, signed=True) + if width == integer: + if width == 1: + return XnorPrecisionType() + else: + return IntegerPrecisionType(width=width, signed=True) + else: + return FixedPrecisionType(width=width, integer=integer + 1, signed=True) + + +class QKerasBinaryQuantizer(Quantizer): + """Wrapper around QKeras binary quantizer. + + Args: + config (dict): Config of the QKeras quantizer to wrap. + """ + + def __init__(self, config, xnor=False): + self.bits = 1 if xnor else 2 + self.hls_type = XnorPrecisionType() if xnor else IntegerPrecisionType(width=2, signed=True) + self.alpha = config['config']['alpha'] + # Use the QKeras quantizer to handle any stochastic / alpha stuff + self.quantizer_fn = get_quantizer(config) + # Then we use our BinaryQuantizer to convert to '0,1' format + self.binary_quantizer = BinaryQuantizer(1) if xnor else BinaryQuantizer(2) + + def __call__(self, data): + x = tf.convert_to_tensor(data) + y = self.quantizer_fn(x).numpy() + return self.binary_quantizer(y) + + +class QKerasPO2Quantizer(Quantizer): + """Wrapper around QKeras power-of-2 quantizers. + + Args: + config (dict): Config of the QKeras quantizer to wrap. + """ + + def __init__(self, config): + self.bits = config['config']['bits'] + self.quantizer_fn = get_quantizer(config) + self.hls_type = ExponentPrecisionType(width=self.bits, signed=True) + + def __call__(self, data): + # Weights are quantized to nearest power of two + x = tf.convert_to_tensor(data) + y = self.quantizer_fn(x) + if hasattr(y, 'numpy'): + y = y.numpy() + return y diff --git a/hls4ml/model/types.py b/hls4ml/model/types.py index f83707f6cb..ba926b11dc 100644 --- a/hls4ml/model/types.py +++ b/hls4ml/model/types.py @@ -8,162 +8,6 @@ from enum import Enum import numpy as np -import tensorflow as tf -from qkeras.quantizers import get_quantizer - -# region Quantizer definition - - -class Quantizer: - """ - Base class for representing quantizers in hls4ml. - - Subclasses of ``Quantizer`` are expected to wrap the quantizers of upstream tools (e.g., QKeras). - - Args: - bits (int): Total number of bits used by the quantizer. - hls_type (NamedType): The hls4ml type used by the quantizer. - """ - - def __init__(self, bits, hls_type): - self.bits = bits - self.hls_type = hls_type - - def __call__(self, data): - raise NotImplementedError - - -class BinaryQuantizer(Quantizer): - """Quantizer that quantizes to 0 and 1 (``bits=1``) or -1 and 1 (``bits==2``). - - Args: - bits (int, optional): Number of bits used by the quantizer. Defaults to 2. 
- - Raises: - Exception: Raised if ``bits>2`` - """ - - def __init__(self, bits=2): - if bits == 1: - hls_type = XnorPrecisionType() - elif bits == 2: - hls_type = IntegerPrecisionType(width=2) - else: - raise Exception(f'BinaryQuantizer suppots 1 or 2 bits, but called with bits={bits}') - super().__init__(bits, hls_type) - - def __call__(self, data): - zeros = np.zeros_like(data) - ones = np.ones_like(data) - quant_data = data - if self.bits == 1: - quant_data = np.where(data > 0, ones, zeros).astype('int') - if self.bits == 2: - quant_data = np.where(data > 0, ones, -ones) - return quant_data - - -class TernaryQuantizer(Quantizer): - """Quantizer that quantizes to -1, 0 and 1.""" - - def __init__(self): - super().__init__(2, IntegerPrecisionType(width=2)) - - def __call__(self, data): - zeros = np.zeros_like(data) - ones = np.ones_like(data) - return np.where(data > 0.5, ones, np.where(data <= -0.5, -ones, zeros)) - - -class QKerasQuantizer(Quantizer): - """Wrapper around QKeras quantizers. - - Args: - config (dict): Config of the QKeras quantizer to wrap. - """ - - def __init__(self, config): - self.quantizer_fn = get_quantizer(config) - self.alpha = config['config'].get('alpha', None) - if config['class_name'] == 'quantized_bits': - self.bits = config['config']['bits'] - self.hls_type = self._get_type(config) - # ! includes stochastic_ternary - elif 'ternary' in config['class_name']: - self.bits = 2 - self.hls_type = IntegerPrecisionType(width=2, signed=True) - # ! includes stochastic_binary - elif 'binary' in config['class_name']: - self.bits = 1 - self.hls_type = XnorPrecisionType() - else: - print("Unsupported quantizer: " + config['class_name']) - self.bits = 16 - self.hls_type = FixedPrecisionType(width=16, integer=6, signed=True) - - def __call__(self, data): - tf_data = tf.convert_to_tensor(data) - return self.quantizer_fn(tf_data).numpy() - # return self.quantizer_fn(data) - - def _get_type(self, quantizer_config): - width = quantizer_config['config']['bits'] - integer = quantizer_config['config'].get('integer', 0) - if quantizer_config['class_name'] == 'quantized_po2': - return ExponentPrecisionType(width=width, signed=True) - if width == integer: - if width == 1: - return XnorPrecisionType() - else: - return IntegerPrecisionType(width=width, signed=True) - else: - return FixedPrecisionType(width=width, integer=integer + 1, signed=True) - - -class QKerasBinaryQuantizer(Quantizer): - """Wrapper around QKeras binary quantizer. - - Args: - config (dict): Config of the QKeras quantizer to wrap. - """ - - def __init__(self, config, xnor=False): - self.bits = 1 if xnor else 2 - self.hls_type = XnorPrecisionType() if xnor else IntegerPrecisionType(width=2, signed=True) - self.alpha = config['config']['alpha'] - # Use the QKeras quantizer to handle any stochastic / alpha stuff - self.quantizer_fn = get_quantizer(config) - # Then we use our BinaryQuantizer to convert to '0,1' format - self.binary_quantizer = BinaryQuantizer(1) if xnor else BinaryQuantizer(2) - - def __call__(self, data): - x = tf.convert_to_tensor(data) - y = self.quantizer_fn(x).numpy() - return self.binary_quantizer(y) - - -class QKerasPO2Quantizer(Quantizer): - """Wrapper around QKeras power-of-2 quantizers. - - Args: - config (dict): Config of the QKeras quantizer to wrap. 
- """ - - def __init__(self, config): - self.bits = config['config']['bits'] - self.quantizer_fn = get_quantizer(config) - self.hls_type = ExponentPrecisionType(width=self.bits, signed=True) - - def __call__(self, data): - # Weights are quantized to nearest power of two - x = tf.convert_to_tensor(data) - y = self.quantizer_fn(x) - if hasattr(y, 'numpy'): - y = y.numpy() - return y - - -# endregion # region Precision types @@ -224,6 +68,8 @@ def __eq__(self, other): eq = self.width == other.width eq = eq and self.signed == other.signed + return eq + class IntegerPrecisionType(PrecisionType): """Arbitrary precision integer data type. @@ -237,20 +83,36 @@ class IntegerPrecisionType(PrecisionType): def __init__(self, width=16, signed=True): super().__init__(width=width, signed=signed) - self.integer = width - self.fractional = 0 def __str__(self): typestring = '{signed}int<{width}>'.format(signed='u' if not self.signed else '', width=self.width) return typestring def __eq__(self, other): - eq = self.width == other.width - eq = eq and self.signed == other.signed - # These are probably unnecessary - eq = eq and self.integer == other.integer - eq = eq and self.fractional == other.fractional - return eq + if isinstance(other, IntegerPrecisionType): + return super().__eq__(other) + + return False + + @property + def integer(self): + return self.width + + @property + def fractional(self): + return 0 + + @property + def rounding_mode(self): + return RoundingMode.TRN + + @property + def saturation_mode(self): + return SaturationMode.WRAP + + @property + def saturation_bits(self): + return 0 class FixedPrecisionType(PrecisionType): @@ -270,18 +132,23 @@ class FixedPrecisionType(PrecisionType): def __init__(self, width=16, integer=6, signed=True, rounding_mode=None, saturation_mode=None, saturation_bits=None): super().__init__(width=width, signed=signed) self.integer = integer - self.fractional = width - integer self.rounding_mode = rounding_mode self.saturation_mode = saturation_mode self.saturation_bits = saturation_bits + @property + def fractional(self): + return self.width - self.integer + @property def rounding_mode(self): return self._rounding_mode @rounding_mode.setter def rounding_mode(self, mode): - if isinstance(mode, str): + if mode is None: + self._rounding_mode = RoundingMode.TRN + elif isinstance(mode, str): self._rounding_mode = RoundingMode.from_string(mode) else: self._rounding_mode = mode @@ -292,26 +159,40 @@ def saturation_mode(self): @saturation_mode.setter def saturation_mode(self, mode): - if isinstance(mode, str): + if mode is None: + self._saturation_mode = SaturationMode.WRAP + elif isinstance(mode, str): self._saturation_mode = SaturationMode.from_string(mode) else: self._saturation_mode = mode + @property + def saturation_bits(self): + return self._saturation_bits + + @saturation_bits.setter + def saturation_bits(self, bits): + if bits is None: + self._saturation_bits = 0 + else: + self._saturation_bits = bits + def __str__(self): args = [self.width, self.integer, self.rounding_mode, self.saturation_mode, self.saturation_bits] - args = ','.join([str(arg) for arg in args if arg is not None]) + args = ','.join([str(arg) for arg in args]) typestring = '{signed}fixed<{args}>'.format(signed='u' if not self.signed else '', args=args) return typestring def __eq__(self, other): - eq = self.width == other.width - eq = eq and self.integer == other.integer - eq = eq and self.fractional == other.fractional - eq = eq and self.signed == other.signed - eq = eq and self.rounding_mode == 
other.rounding_mode - eq = eq and self.saturation_mode == other.saturation_mode - eq = eq and self.saturation_bits == other.saturation_bits - return eq + if isinstance(other, FixedPrecisionType): + eq = super().__eq__(other) + eq = eq and self.integer == other.integer + eq = eq and self.rounding_mode == other.rounding_mode + eq = eq and self.saturation_mode == other.saturation_mode + eq = eq and self.saturation_bits == other.saturation_bits + return eq + + return False class XnorPrecisionType(PrecisionType): diff --git a/test/pytest/test_precision_parsing.py b/test/pytest/test_precision_parsing.py deleted file mode 100644 index 5569a3a6ad..0000000000 --- a/test/pytest/test_precision_parsing.py +++ /dev/null @@ -1,29 +0,0 @@ -import pytest - -import hls4ml - - -@pytest.mark.parametrize( - 'prec_pair', - [ - ('ap_fixed<3, 2>', True), - ('ap_ufixed<3, 2>', False), - ('ac_fixed<3, 2, true>', True), - ('ac_fixed<3, 2, false>', False), - ('ac_fixed<3, 2, 1>', True), - ('ac_fixed<3, 2, 0>', False), - ('ap_int<3, 2>', True), - ('ap_uint<3>', False), - ('ac_int<3, TRue>', True), - ('ac_int<3, FALse>', False), - ('ac_int<3, 1>', True), - ('ac_int<3, 0>', False), - ], -) -def test_sign_parsing(prec_pair): - '''Test that convert_precions_string determines the signedness correctly''' - strprec = prec_pair[0] - signed = prec_pair[1] - - evalprec = hls4ml.backends.fpga.fpga_backend.FPGABackend.convert_precision_string(strprec) - assert evalprec.signed == signed diff --git a/test/pytest/test_types.py b/test/pytest/test_types.py new file mode 100644 index 0000000000..8f4857fec9 --- /dev/null +++ b/test/pytest/test_types.py @@ -0,0 +1,87 @@ +import pytest + +from hls4ml.backends.fpga.fpga_backend import FPGABackend +from hls4ml.backends.fpga.fpga_types import ACFixedPrecisionDefinition, APFixedPrecisionDefinition +from hls4ml.model.types import ( + ExponentPrecisionType, + FixedPrecisionType, + IntegerPrecisionType, + RoundingMode, + SaturationMode, + XnorPrecisionType, +) + + +def test_precision_type_creation(capsys): + int_type = IntegerPrecisionType(width=1, signed=False) + xnr_type = XnorPrecisionType() + + assert int_type != xnr_type # Must ensure that similar types are not matched + + int_type = IntegerPrecisionType(width=8, signed=True) + exp_type = ExponentPrecisionType(width=8, signed=True) + + assert int_type != exp_type # Must ensure that similar types are not matched + + fp_type = FixedPrecisionType(12, 6) + fp_type.integer += 2 + fp_type.rounding_mode = None + fp_type.saturation_mode = 'SAT' + + assert fp_type.integer == 8 + assert fp_type.fractional == 4 # Should be automatically updated + assert fp_type.rounding_mode == RoundingMode.TRN # None should be changed to default + assert fp_type.saturation_mode == SaturationMode.SAT # Strings should parse correctly + + # Setting saturation mode but not rounding mode should still result in correct type being written out + fp_type = FixedPrecisionType(12, 6, rounding_mode=None, saturation_mode=SaturationMode.SAT_SYM, saturation_bits=1) + # Circumvent the type wrapping that happens in the backend + fp_type.__class__ = type('APFixedPrecisionType', (type(fp_type), APFixedPrecisionDefinition), {}) + fp_cpp = fp_type.definition_cpp() + assert fp_cpp == 'ap_fixed<12,6,AP_TRN,AP_SAT_SYM,1>' # Should include the whole type definition, including rounding + # Reset to default + fp_type.saturation_mode = 'WRAP' + fp_type.saturation_bits = 0 + fp_cpp = fp_type.definition_cpp() + assert fp_cpp == 'ap_fixed<12,6>' # Should not include defaults + + # Same test for 
AC types + fp_type = FixedPrecisionType(12, 6, rounding_mode=None, saturation_mode=SaturationMode.SAT_SYM, saturation_bits=1) + # Circumvent the type wrapping that happens in the backend + fp_type.__class__ = type('ACFixedPrecisionType', (type(fp_type), ACFixedPrecisionDefinition), {}) + fp_cpp = fp_type.definition_cpp() + assert fp_cpp == 'ac_fixed<12,6,true,AC_TRN,AC_SAT_SYM>' # Should include the whole type definition, including rounding + # The invalid saturation bit setting should produce a warning + captured = capsys.readouterr() + assert 'WARNING: Invalid setting of saturation bits' in captured.out + # Reset to default + fp_type.saturation_mode = 'WRAP' + fp_type.saturation_bits = 0 + fp_cpp = fp_type.definition_cpp() + assert fp_cpp == 'ac_fixed<12,6,true>' # Should not include defaults + + +@pytest.mark.parametrize( + 'prec_pair', + [ + ('ap_fixed<3, 2>', True), + ('ap_ufixed<3, 2>', False), + ('ac_fixed<3, 2, true>', True), + ('ac_fixed<3, 2, false>', False), + ('ac_fixed<3, 2, 1>', True), + ('ac_fixed<3, 2, 0>', False), + ('ap_int<3, 2>', True), + ('ap_uint<3>', False), + ('ac_int<3, TRue>', True), + ('ac_int<3, FALse>', False), + ('ac_int<3, 1>', True), + ('ac_int<3, 0>', False), + ], +) +def test_sign_parsing(prec_pair): + '''Test that convert_precision_string determines the signedness correctly''' + strprec = prec_pair[0] + signed = prec_pair[1] + + evalprec = FPGABackend.convert_precision_string(strprec) + assert evalprec.signed == signed
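
For reference, a minimal usage sketch of the relocated quantizer classes and the new FixedPrecisionType defaults, assuming the changes above are applied; the example array and printed values are illustrative and not taken from the changeset:

import numpy as np

from hls4ml.model.quantizers import BinaryQuantizer, TernaryQuantizer
from hls4ml.model.types import FixedPrecisionType, RoundingMode, SaturationMode

weights = np.array([-0.7, -0.2, 0.1, 0.9])

# BinaryQuantizer maps positive values to 1; the rest go to -1 (bits=2) or 0 (bits=1)
print(BinaryQuantizer(bits=2)(weights))  # [-1. -1.  1.  1.]
print(BinaryQuantizer(bits=1)(weights))  # [0 0 1 1]

# TernaryQuantizer thresholds at +/-0.5
print(TernaryQuantizer()(weights))  # [-1.  0.  0.  1.]

# Rounding/saturation settings left unset now fall back to TRN/WRAP/0 instead of None
fp = FixedPrecisionType(width=12, integer=6)
assert fp.rounding_mode == RoundingMode.TRN
assert fp.saturation_mode == SaturationMode.WRAP
assert fp.saturation_bits == 0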