Automatic type inference for param_t in Parametrised Activations #1139
Open
nghielme wants to merge 8 commits into main from leaky_relu_quant_alpha
+28 −8
Commits (changes shown from 7 of 8 commits):
11601cd  Added automatic inference of `param_t` constant for parametrised acti… (nghielme)
72026fb  pre-commit fixes (nghielme)
10ec7a2  Fix the case the param is a power of 2 (nghielme)
29f0831  Fix for a specific case related to no bits in the mantissa (nghielme)
ecf5c2c  Merge branch 'main' into leaky_relu_quant_alpha (nghielme)
49e5a75  Merge branch 'main' into leaky_relu_quant_alpha (nghielme)
baba0f3  Merge branch 'main' into leaky_relu_quant_alpha (JanFSchulte)
0808580  Merge branch 'main' into leaky_relu_quant_alpha (nghielme)
Diff:

@@ -1,4 +1,5 @@
 import math
+import struct
 from typing import Iterable

 import numpy as np

@@ -561,15 +562,34 @@ def _infer_rnn_precision(self, node, types_to_infer):

         return inferred_types

-    def _infer_par_act_precision(self, node, types_to_infer):
+    def _infer_const_precision(self, node, type_to_infer, attr_name):
         inferred_types = []

-        # For threshold relu, set the parameter precision to be the input precision by default;
-        # for other parametrized activations, just allow the default precision to be used.
-        # Can override these values in the configuration by explicitly setting them.
-        if 'param_t' in inferred_types and self.get_attr('activation').lower() == 'thresholdedrelu':
-            in_type = node.get_input_variable().type.precision
-            node.attributes['param_t'].type = in_type
-            inferred_types.append('param_t')
+        def get_man_exp(f):
+            # Unpack the float32 bit pattern into mantissa and exponent fields.
+            f = np.abs(f)
+            s = struct.pack('>f', f)
+            l_float = struct.unpack('>l', s)[0]
+            bits = f'{l_float:032b}'
+            m = bits[-23:]
+            e = bits[-23 - 8 : -23]
+            return m, e
+
+        param = node.get_attr(attr_name)
+        m, e = get_man_exp(param)
+        I_pos = int(e, 2) - 127 + 1  # 127 is the bias of the float32 exponent
+        try:
+            W_bits = m.rindex('1') + 2  # +1 because the index starts at 0, +1 for the implicit leading 1 of the significand
+        except ValueError:  # no set bit in the mantissa: the value is a power of 2
+            W_bits = 1  # 1 bit is needed; I_pos offsets the bit to the proper place
+        if param < 0 and W_bits > 1:  # for power-of-2 values the increment is not needed
+            I_pos += 1
+            W_bits += 1
+        node.attributes[type_to_infer].precision = FixedPrecisionType(W_bits, I_pos, param < 0)
+        inferred_types.append(type_to_infer)
+        return inferred_types
+
+    def _infer_par_act_precision(self, node, types_to_infer):
+        inferred_types = []
+        if 'param_t' in types_to_infer:
+            inferred_types.extend(self._infer_const_precision(node, 'param_t', 'activ_param'))
         return inferred_types

Review comment (on the struct.pack line):
Use calculations: based on the value, you can easily determine how many bits you need. Going to structs is hard to follow.
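To make the new helper concrete, here is a minimal standalone sketch of the same bit-extraction logic; the function name infer_fixed_precision and the sample values are illustrative, not part of the PR. It prints the width, integer position, and signedness that would be passed to FixedPrecisionType:

import struct

def infer_fixed_precision(param):
    # Unpack the float32 bit pattern, as get_man_exp does in the diff above.
    bits = f"{struct.unpack('>l', struct.pack('>f', abs(param)))[0]:032b}"
    m, e = bits[-23:], bits[-31:-23]  # 23-bit mantissa, 8-bit exponent
    I_pos = int(e, 2) - 127 + 1       # float32 exponent bias is 127
    try:
        W_bits = m.rindex('1') + 2    # bits up to the last set mantissa bit, plus the implicit leading 1
    except ValueError:
        W_bits = 1                    # all-zero mantissa: the value is an exact power of 2
    if param < 0 and W_bits > 1:      # negative non-power-of-2 values need one extra bit
        I_pos += 1
        W_bits += 1
    return W_bits, I_pos, param < 0

for p in (0.25, 1.5, -1.5, 0.625):
    print(p, infer_fixed_precision(p))
# 0.25  -> (1, -1, False)   one bit at 2**-2
# 1.5   -> (2,  1, False)   bits at 2**0 and 2**-1
# -1.5  -> (3,  2, True)
# 0.625 -> (3,  0, False)   0.101 in binary

Note that struct.pack('>f', …) first rounds the parameter to float32, so non-dyadic values such as 0.1 are measured at float32 precision (up to 24 significant bits).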
Review comment:
Struct is much too low level for what we are doing here. We have a Python float. We should use it, not look at the bits.
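The review suggestion can be read as: derive the same numbers from the float's value rather than from its bit pattern. A hedged sketch of that alternative using math.frexp; the function name and the 24-bit cap mimicking float32 are assumptions, not code from this PR:

import math

def infer_fixed_precision(param, max_bits=24):
    # abs(param) == f * 2**exp with 0.5 <= f < 1, so exp is the position
    # of the leading set bit, i.e. the integer-bit count I_pos.
    f, exp = math.frexp(abs(param))
    I_pos = exp
    # Count significant bits by doubling until the fraction becomes an integer.
    W_bits = 1
    scaled = f * 2  # 1 <= scaled < 2: the implicit leading 1 is now in place
    while scaled != int(scaled) and W_bits < max_bits:
        scaled *= 2
        W_bits += 1
    if param < 0 and W_bits > 1:  # extra bit for negative non-power-of-2 values
        I_pos += 1
        W_bits += 1
    return W_bits, I_pos, param < 0

For example, infer_fixed_precision(-1.5) returns (3, 2, True), matching the struct-based version. One caveat: frexp sees the full float64 mantissa, while the PR's struct.pack('>f', …) rounds to float32 first, so the two can disagree on values that are not exactly representable in float32; hence the max_bits cap here.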