Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No longer make ExponentPrecisionType and XnorPrecisionType inherit from IntegerPrecisionType #845

Merged
merged 4 commits into from
Aug 17, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 33 additions & 28 deletions hls4ml/model/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
higher-dimensional tensors, which are defined as arrays or FIFO streams in the generated code.
"""

import re
from enum import Enum

import numpy as np
Expand Down Expand Up @@ -46,7 +45,7 @@ class BinaryQuantizer(Quantizer):

def __init__(self, bits=2):
if bits == 1:
hls_type = IntegerPrecisionType(width=1, signed=False)
hls_type = XnorPrecisionType()
elif bits == 2:
hls_type = IntegerPrecisionType(width=2)
else:
Expand Down Expand Up @@ -221,6 +220,10 @@ def __init__(self, width, signed):
self.width = width
self.signed = signed

def __eq__(self, other):
eq = self.width == other.width
eq = eq and self.signed == other.signed


class IntegerPrecisionType(PrecisionType):
"""Arbitrary precision integer data type.
Expand Down Expand Up @@ -311,16 +314,21 @@ def __eq__(self, other):
return eq


class XnorPrecisionType(IntegerPrecisionType):
class XnorPrecisionType(PrecisionType):
"""
Convenience class to differentiate 'regular' integers from BNN Xnor ones
"""

def __init__(self):
super().__init__(width=1, signed=False)
Copy link
Contributor Author

@jmitrevs jmitrevs Aug 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Logically xnor is signed, even if we represent it as unit<1>. How it gets implemented should be a backend issue, not a definition issue. Leaving it unsigned for now, though, since that's a bigger change.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the XNOR idea was that it is a single bit implementation, the more general definition of BNN is signed (-1 and 1), but that's a different case.

self.integer = 1

def __str__(self):
typestring = 'uint<1>'
Copy link
Contributor Author

@jmitrevs jmitrevs Aug 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am truthfully not much of a fan of the typestring. What is it's function? Is it to have a string representation in the configs? There is not a 1:1 between a typestring and and the type. Part of me thinks this should be 'xnor' and not 'uint<1>', though this requires a few more changes downstream. (There would be a corresponding change for exponent precision types.)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not 100% sure, but this may be the last leftover from the previous way of doing things. I think that it is only used in WeightVariable.update_precision() to set a new string format that is then not used anywhere (it was used only in print_array_to_cpp in the writer, but this was changed). If it doesn't hurt your eyes too much I would leave it as-is here, and have a follow up PR that removes it, size_cpp() and perhaps the whole dim_names concept as one of our usual "cleanup" PRs that we do prior to a release.

return typestring


class ExponentPrecisionType(IntegerPrecisionType):
class ExponentPrecisionType(PrecisionType):
"""
Convenience class to differentiate 'regular' integers from those which represent exponents,
for QKeras po2 quantizers, for example.
Expand All @@ -329,6 +337,10 @@ class ExponentPrecisionType(IntegerPrecisionType):
def __init__(self, width=16, signed=True):
super().__init__(width=width, signed=signed)

def __str__(self):
typestring = '{signed}int<{width}>'.format(signed='u' if not self.signed else '', width=self.width)
return typestring


def find_minimum_width(data, signed=True):
"""
Expand Down Expand Up @@ -536,34 +548,27 @@ def __next__(self):
if not self._iterator.finished:
value = self._iterator[0]
self._iterator.iternext()
return self.precision_fmt % value
return self.precision_fmt.format(value)
else:
raise StopIteration

next = __next__

def update_precision(self, new_precision):
self.type.precision = new_precision
precision_str = str(self.type.precision)
if 'int' in precision_str:
self.precision_fmt = '%d'
else:
match = re.search('.+<(.+?)>', precision_str)
if match is not None:
precision_bits = match.group(1).split(',')
width_bits = int(precision_bits[0])
integer_bits = int(precision_bits[1])
fractional_bits = integer_bits - width_bits
lsb = 2**fractional_bits
if lsb < 1:
# Use str to represent the float with digits, get the length
# to right of decimal point
decimal_spaces = len(str(lsb).split('.')[1])
else:
decimal_spaces = len(str(2**integer_bits))
self.precision_fmt = f'%.{decimal_spaces}f'
if isinstance(new_precision, (IntegerPrecisionType, XnorPrecisionType, ExponentPrecisionType)):
self.precision_fmt = '{:.0f}'
elif isinstance(new_precision, FixedPrecisionType):
if new_precision.fractional > 0:
# Use str to represent the float with digits, get the length
# to right of decimal point
lsb = 2**-new_precision.fractional
decimal_spaces = len(str(lsb).split('.')[1])
self.precision_fmt = f'{{:{decimal_spaces}f}}'
else:
self.precision_fmt = '%f'
self.precision_fmt = '{:.0f}'
else:
raise RuntimeError(f"Unexpected new precision type: {new_precision}")


class CompressedWeightVariable(WeightVariable):
Expand Down Expand Up @@ -618,8 +623,8 @@ def __iter__(self):

def __next__(self):
value = next(self._iterator)
value_fmt = self.precision_fmt % value[2]
return '{ %u, %u, %s }' % (value[1], value[0], value_fmt)
value_fmt = self.precision_fmt.format(value[2])
return f'{{{value[1]}, {value[0]}, {value_fmt}}}'

next = __next__

Expand Down Expand Up @@ -656,8 +661,8 @@ def __iter__(self):

def __next__(self):
value = next(self._iterator)
value_fmt = self.precision_fmt % value[1]
return '{%d, %s}' % (value[0], value_fmt)
value_fmt = self.precision_fmt.format(value[1])
return f'{{{value[0]}, {value_fmt}}}'

next = __next__

Expand Down