
Pointwise Conv1D with code generation for "Latency" strategy (update of #811) #881

Conversation

@jmduarte (Member) commented Oct 8, 2023

Description

Update of #811 with code generation. This PR adds an explicit pointwise Conv1D implementation in which the reuse factor (RF) is used to split the layer execution and reuse the existing module RF times.

Original pointwise Conv1D:

  • (in_width, n_chan) -> (in_width, n_filt)

This PR splits it into RF calls of

  • (in_width/RF, n_chan) -> (in_width/RF, n_filt)
  • (in_width/RF, n_chan) -> (in_width/RF, n_filt)
  • (in_width/RF, n_chan) -> (in_width/RF, n_filt)
  • ...

The II scales roughly with RF (II ~ RF). The optimization is on by default, but you should still be able to use the standard conv1d implementation by skipping the optimizer.
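
For illustration, a minimal sketch of what the split looks like in HLS. Hedged: the wrapper name pointwise_conv_1d_split and the slicing arithmetic are illustrative, not the exact generated code; it assumes in_width == out_width (as for a stride-1 pointwise layer) and that in_width is divisible by RF:

template <class data_T, class res_T, typename CONFIG_T>
void pointwise_conv_1d_split(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],
                             res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
                             typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt],
                             typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
    // Each call processes an in_width/RF slice with the same module; the calls
    // execute back-to-back, which is why the II scales roughly with RF.
    for (unsigned r = 0; r < CONFIG_T::reuse_factor; r++) {
        pointwise_conv_1d_latency_cl<data_T, res_T, CONFIG_T>(
            &data[r * (CONFIG_T::in_width / CONFIG_T::reuse_factor) * CONFIG_T::n_chan],
            &res[r * (CONFIG_T::out_width / CONFIG_T::reuse_factor) * CONFIG_T::n_filt],
            weights, biases);
    }
}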

Limitations:

  • Assumes in_width is divisible by RF

Type of change

  • New feature (non-breaking change which adds functionality)
  • A new research paper code implementation

Tests

See test/pytest/test_pointwiseconv.py

Checklist

  • I have read the guidelines for contributing.
  • I have commented my code, particularly in hard-to-understand areas.
  • I have made corresponding changes to the documentation.
  • My changes generate no new warnings.
  • I have installed and run pre-commit on the files I edited or added.
  • I have added tests that prove my fix is effective or that my feature works.

Review threads (all resolved) were opened on:

  • hls4ml/backends/fpga/fpga_backend.py
  • hls4ml/backends/fpga/passes/codegen.py
  • hls4ml/backends/vivado/vivado_backend.py
  • hls4ml/templates/vivado/build_prj.tcl
  • hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h
  • hls4ml/templates/vivado/nnet_utils/nnet_common.h
  • hls4ml/templates/vivado/nnet_utils/nnet_conv1d.h
  • test/pytest/test_pointwiseconv.py

@jmduarte (Member, Author) commented:

pre-commit.ci autofix

@vloncar (Contributor) left a comment:

Looks very good; there are some minor tweaks needed to integrate it more cleanly. It works very well, which is more important :-)

template<unsigned K, unsigned S, unsigned W>
using scale_index = nnet::{scale_index_type}<K, S, W>;
template<class data_T, class res_T, class CONFIG_T>
using pointwise_conv = nnet::{pointwise_fn}<data_T, res_T, CONFIG_T>;

@vloncar (Contributor) commented:

In Dense layers we moved to using a function pointer like this, which the main dense() function calls, eliminating the need for checks in HLS. Here it would likewise simplify the call hierarchy (no special handling of pointwise) and remove the need for this template.
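
For context, a hedged sketch of the dispatch-through-config pattern being described (the member name dense_fn is hypothetical, not the actual hls4ml API):

template <class data_T, class res_T, typename CONFIG_T>
void dense(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out],
           typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out],
           typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) {
    #pragma HLS inline
    // The code generator selects the implementation once, so the HLS code has
    // no branching between latency/resource/pointwise variants.
    CONFIG_T::template dense_fn<data_T, res_T, CONFIG_T>(data, res, weights, biases);
}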

@jmduarte (Member, Author) replied:

Attempted in jmduarte@d56dc73.

In hls4ml/backends/vivado/passes/pointwise_codegen.py:
'''Generates code for pointwise 1D convolution'''

def match(self, node):
    return isinstance(node, Conv1D) and node.model.config.get_config_value('IOType') == 'io_parallel'

@vloncar (Contributor) commented:

Why is there no check for filt_width == 1 here? Otherwise we generate functions we don't use, and incorrect ones at that.

@jmduarte (Member, Author) replied on Nov 22, 2024:

Added in:

def match(self, node):
    return isinstance(node, Conv1D) and node.model.config.get_config_value('IOType') == 'io_parallel' and node.get_attr('filt_width') == 1

    return isinstance(node, Conv1D) and node.model.config.get_config_value('IOType') == 'io_parallel'

def transform(self, model, node):
    node_class = node.__class__.__name__

@vloncar (Contributor) commented:

Minor point: we have node.class_name for this purpose, though it wouldn't make a big difference in the check here.

The bigger question is why have the check at all? In what cases can it fail?

@jmduarte (Member, Author) replied on Nov 22, 2024:

Removed the check in:

def transform(self, model, node):
    self._generate_pointwise_conv1d(node)

#pragma HLS ARRAY_PARTITION variable=biases complete dim=0

// Limit multipliers to control parallelization
constexpr unsigned multiplier_limit = DIV_ROUNDUP(

@vloncar (Contributor) commented:

This was problematic before, and we moved to setting multiplier_limit in Python and using it here. For consistency, we should not re-introduce old approaches.

@jmduarte (Member, Author) replied on Nov 22, 2024:

Updated to use the multiplier_limit from the config here:

#pragma HLS ALLOCATION operation instances=mul limit=CONFIG_T::mult_config::multiplier_limit

but I need to check that the value is the same.

@jmduarte (Member, Author) added:

OK, fixed the multiplier limit now... but it would be good to warn the user somehow if they use a reuse_factor that doesn't divide in_width.

Related: it would probably be beneficial to factorize this into a reuse_factor (which limits the multipliers) and a parallelization_factor (which controls how many times the conv is split into separate calls).
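
A hedged sketch of how the divisibility requirement could be surfaced at compile time (illustrative only; not part of this PR), placed inside the generated pointwise_conv_1d_latency_cl body:

// Illustrative sketch: fail C simulation/synthesis early instead of silently
// mis-sizing the per-call slices when reuse_factor does not divide in_width.
static_assert(CONFIG_T::in_width % CONFIG_T::reuse_factor == 0,
              "reuse_factor must evenly divide in_width");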

for (int ii = 0; ii < CONFIG_T::out_width / CONFIG_T::reuse_factor; ii++) {
    for (int ff = 0; ff < CONFIG_T::n_filt; ff++) {
        #pragma HLS UNROLL
        res[ii * CONFIG_T::n_filt + ff] = (res_T)(acc[ii][ff]);

@vloncar (Contributor) commented:

Do you need to call cast() here to ensure compatibility with all variants of product?

@jmduarte (Member, Author) replied:

Added here:

res[ii * CONFIG_T::n_filt + ff] = cast<data_T, res_T, typename CONFIG_T::mult_config>(acc[ii][ff]);

@@ -2,6 +2,7 @@
 #define NNET_COMMON_H_

 #include "ap_fixed.h"
+#include "nnet_helpers.h"

@vloncar (Contributor) commented:

nnet_helpers.h (mostly) doesn't contain synthesizable code, and it is not intended to be included by files other than the testbench. Also, I don't see anything from it being used in the code introduced here.

@jmduarte (Member, Author) replied on Nov 22, 2024:

If I remove that, I get this error on my EL 8 machine with g++ 8.5.0:

In file included from firmware/nnet_utils/nnet_conv1d_latency.h:4,
                 from firmware/nnet_utils/nnet_code_gen.h:4,
                 from firmware/parameters.h:7,
                 from firmware/myproject.cpp:4:
firmware/nnet_utils/nnet_common.h: In function ‘T nnet::reduce(const T*, Op)’:
firmware/nnet_utils/nnet_common.h:37:39: error: there are no arguments to ‘floorlog2’ that depend on a template parameter, so a declaration of ‘floorlog2’ must be available [-fpermissive]
     static constexpr int leftN = pow2(floorlog2(N - 1)) > 0 ? pow2(floorlog2(N - 1)) : 0;
                                       ^~~~~~~~~
firmware/nnet_utils/nnet_common.h:37:39: note: (if you use ‘-fpermissive’, G++ will accept your code, but allowing the use of an undeclared name is deprecated)
firmware/nnet_utils/nnet_common.h:37:68: error: there are no arguments to ‘floorlog2’ that depend on a template parameter, so a declaration of ‘floorlog2’ must be available [-fpermissive]
     static constexpr int leftN = pow2(floorlog2(N - 1)) > 0 ? pow2(floorlog2(N - 1)) : 0;
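
A hedged note on the likely cause: under g++'s two-phase name lookup, floorlog2 must be declared before its non-dependent use inside reduce(), and the nnet_helpers.h include evidently supplies that declaration. An alternative sketch (illustrative only) would be to declare constexpr helpers directly in nnet_common.h, ahead of reduce():

// Illustrative sketch: declaring the helpers before reduce() lets the names
// resolve without -fpermissive and without including nnet_helpers.h.
constexpr int floorlog2(int x) { return (x < 2) ? 0 : 1 + floorlog2(x / 2); }
constexpr int pow2(int x) { return (x == 0) ? 1 : 2 * pow2(x - 1); }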

@@ -84,5 +84,85 @@ void conv_1d_latency_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],
}
}

template <class data_T, class res_T, typename CONFIG_T>
void pointwise_conv_1d_latency_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan / CONFIG_T::reuse_factor],

@vloncar (Contributor) commented:

The Vitis and Vivado files are mostly identical (apart from the very recent change in the final loop); perhaps we could remove the Vitis one (after we first follow up with the changes).

@jmduarte (Member, Author) replied:

Do you mean both nnet_conv1d.h and nnet_conv1d_latency.h?

There are a few differences, like the use of an inline region, etc. Do you mean to test whether just using the Vivado versions works in Vitis now?

#pragma HLS ARRAY_PARTITION variable=biases complete dim=0

// Limit multipliers to control parallelization
constexpr unsigned multiplier_limit = DIV_ROUNDUP(

@vloncar (Contributor) commented:

Same comment as in the other file

for (int ii = 0; ii < CONFIG_T::out_width / CONFIG_T::reuse_factor; ii++) {
    for (int ff = 0; ff < CONFIG_T::n_filt; ff++) {
        #pragma HLS UNROLL
        res[ii * CONFIG_T::n_filt + ff] = (res_T)(acc[ii][ff]);

@vloncar (Contributor) commented:

Same comment as in the other file

@JanFSchulte (Contributor) commented:

pre-commit.ci autofix

@JanFSchulte merged commit 2fc8941 into fastmachinelearning:main on Dec 4, 2024 (5 of 9 checks passed).