Commit d56dc73: vladimir comments

jmduarte committed Nov 22, 2024
1 parent 9e3fc8d commit d56dc73
Showing 8 changed files with 114 additions and 110 deletions.

hls4ml/backends/vivado/passes/convolution_templates.py (15 additions, 5 deletions)

@@ -60,6 +60,8 @@
     typedef {config_t} mult_config;
     template<unsigned K, unsigned S, unsigned W>
     using scale_index = nnet::{scale_index_type}<K, S, W>;
+    template<class data_T, class res_T, class CONFIG_T>
+    using conv_kernel = nnet::{conv_fn}<data_T, res_T, CONFIG_T>;
 }};
 const ap_uint<config{index}::filt_width> config{index}::pixels[] = {{{instructions}}};\n"""

@@ -93,16 +95,24 @@ def format(self, node):
         else:
             params['fill_fn'] = 'FillConv1DBuffer'
 
-        if node.get_attr('filt_width') == 1 and node.model.config.get_config_value('IOType') == 'io_parallel':
-            params['pointwise_fn'] = f'pointwise_conv_{node.index}'
+        is_pointwise_parallel_latency = node.get_attr('filt_width') == 1 and node.get_attr('strategy').lower() == 'latency' and node.model.config.get_config_value('IOType') == 'io_parallel'
+        if is_pointwise_parallel_latency:
+            params['conv_fn'] = f'pointwise_conv_{node.index}'
         else:
-            params['pointwise_fn'] = 'PointwiseConv1D'
+            if node.get_attr('strategy').lower() == 'latency':
+                params['conv_fn'] = 'Conv1DLatency'
+            elif node.get_attr('strategy').lower() == 'resource':
+                params['conv_fn'] = 'Conv1DResource'
 
         conv_config = self.template.format(**params)
 
         mult_params = self._default_config_params(node)
-        mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_width')
-        mult_params['n_out'] = node.get_attr('n_filt')
+        if is_pointwise_parallel_latency:
+            mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_width') / mult_params['reuse']
+            mult_params['n_out'] = node.get_attr('n_filt') / mult_params['reuse']
+        else:
+            mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_width')
+            mult_params['n_out'] = node.get_attr('n_filt')
         mult_params['nzeros'] = node.get_weights('weight').nzeros
         mult_params['product_type'] = get_backend('vivado').product_type(
             node.get_input_variable().type.precision, node.get_weights('weight').type.precision
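
For orientation, here is roughly what a config rendered from the template above looks like once the backend substitutes {conv_fn}. This is a sketch with hypothetical index, shape, and precision values, an abridged field list, and an assumed name (config4_mult) for the companion mult config:

    // Hypothetical rendered output of the config template (abridged).
    struct config4 : nnet::conv1d_config {
        static const unsigned in_width = 32;
        static const unsigned n_chan = 16;
        static const unsigned filt_width = 1;
        static const unsigned n_filt = 8;
        static const unsigned reuse_factor = 2;
        typedef ap_fixed<16, 6> accum_t;
        typedef config4_mult mult_config;
        // {conv_fn} becomes Conv1DLatency, Conv1DResource, or the generated
        // pointwise_conv_<index> class, depending on strategy and layer shape.
        template<class data_T, class res_T, class CONFIG_T>
        using conv_kernel = nnet::Conv1DLatency<data_T, res_T, CONFIG_T>;
    };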

hls4ml/backends/vivado/passes/pointwise.py (2 additions, 34 deletions)

@@ -4,45 +4,13 @@
     Conv1DFunctionTemplate,
     Conv2DConfigTemplate,
     Conv2DFunctionTemplate,
+    conv1d_config_template,
     conv2d_config_template,
     conv_mult_config_template,
 )
 from hls4ml.model.layers import register_layer
 from hls4ml.model.optimizer import OptimizerPass
 
-pointwise_conv1d_config_template = """struct config{index} : nnet::conv1d_config {{
-    static const unsigned pad_left = {pad_left};
-    static const unsigned pad_right = {pad_right};
-    static const unsigned in_width = {in_width};
-    static const unsigned n_chan = {n_chan};
-    static const unsigned filt_width = {filt_width};
-    static const unsigned kernel_size = filt_width;
-    static const unsigned n_filt = {n_filt};
-    static const unsigned stride_width = {stride_width};
-    static const unsigned dilation = {dilation};
-    static const unsigned out_width = {out_width};
-    static const unsigned reuse_factor = {reuse};
-    static const unsigned n_zeros = {nzeros};
-    static const bool store_weights_in_bram = false;
-    static const unsigned strategy = nnet::{strategy};
-    static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation};
-    static const unsigned min_width = {min_width};
-    static const ap_uint<filt_width> pixels[min_width];
-    static const unsigned n_partitions = {n_partitions};
-    static const unsigned n_pixels = out_width / n_partitions;
-    template<class data_T, class CONFIG_T>
-    using fill_buffer = nnet::{fill_fn}<data_T, CONFIG_T>;
-    typedef {accum_t.name} accum_t;
-    typedef {bias_t.name} bias_t;
-    typedef {weight_t.name} weight_t;
-    typedef {config_t} mult_config;
-    template<unsigned K, unsigned S, unsigned W>
-    using scale_index = nnet::{scale_index_type}<K, S, W>;
-    template<class data_T, class res_T, class CONFIG_T>
-    using pointwise_conv = nnet::{pointwise_fn}<data_T, res_T, CONFIG_T>;
-}};
-const ap_uint<config{index}::filt_width> config{index}::pixels[] = {{{instructions}}};\n"""
-
 pointwise_conv1d_function_template = (
     'nnet::pointwise_conv_1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'
 )
@@ -57,7 +25,7 @@
 class PointwiseConv1DConfigTemplate(Conv1DConfigTemplate):
     def __init__(self):
         super(Conv1DConfigTemplate, self).__init__(PointwiseConv1D)
-        self.template = pointwise_conv1d_config_template
+        self.template = conv1d_config_template
         self.mult_template = conv_mult_config_template


hls4ml/backends/vivado/passes/pointwise_codegen.py (37 additions, 41 deletions)

@@ -15,48 +15,48 @@ def generate_pointwise_conv1d_fn(layer_idx, reuse_factor=1):
"""

generated_code = (
"template<class data_T, class res_T, typename CONFIG_T>\n"
"class pointwise_conv_{index} : public PointwiseConv1D<data_T, res_T, CONFIG_T> {{\n"
" public:\n"
" static void pointwise_conv(\n"
" data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],\n"
" res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],\n"
" typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt],\n"
" typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {{\n"
" data_T data_tmp[CONFIG_T::reuse_factor][CONFIG_T::in_width * CONFIG_T::n_chan / CONFIG_T::reuse_factor];\n" # noqa: E501
" #pragma HLS ARRAY_PARTITION variable=data_tmp complete dim=0\n"
" res_T res_tmp[CONFIG_T::reuse_factor][CONFIG_T::out_width * CONFIG_T::n_filt / CONFIG_T::reuse_factor];\n" # noqa: E501
" #pragma HLS ARRAY_PARTITION variable=res_tmp complete dim=0\n\n"
" RFInputLoop:\n"
" for (int jj = 0; jj < CONFIG_T::reuse_factor; jj++) {{\n"
" #pragma HLS UNROLL\n"
" InnerInputLoop:\n"
" for (int ii = 0; ii < CONFIG_T::in_width * CONFIG_T::n_chan / CONFIG_T::reuse_factor; ii++) {{\n"
" #pragma HLS UNROLL\n"
" data_tmp[jj][ii] = data[jj * CONFIG_T::in_width * CONFIG_T::n_chan / CONFIG_T::reuse_factor + ii];\n" # noqa: E501
" }}\n"
" }}\n\n"
'template<class data_T, class res_T, typename CONFIG_T>\n'
'class pointwise_conv_{index} : public Conv1DKernel<data_T, res_T, CONFIG_T> {{\n'
' public:\n'
' static void conv(\n'
' data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],\n'
' res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],\n'
' typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt],\n'
' typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {{\n'
' data_T data_tmp[CONFIG_T::reuse_factor][CONFIG_T::in_width * CONFIG_T::n_chan / CONFIG_T::reuse_factor];\n' # noqa: E501
' #pragma HLS ARRAY_PARTITION variable=data_tmp complete dim=0\n'
' res_T res_tmp[CONFIG_T::reuse_factor][CONFIG_T::out_width * CONFIG_T::n_filt / CONFIG_T::reuse_factor];\n' # noqa: E501
' #pragma HLS ARRAY_PARTITION variable=res_tmp complete dim=0\n\n'
' RFInputLoop:\n'
' for (int jj = 0; jj < CONFIG_T::reuse_factor; jj++) {{\n'
' #pragma HLS UNROLL\n'
' InnerInputLoop:\n'
' for (int ii = 0; ii < CONFIG_T::in_width * CONFIG_T::n_chan / CONFIG_T::reuse_factor; ii++) {{\n'
' #pragma HLS UNROLL\n'
' data_tmp[jj][ii] = data[jj * CONFIG_T::in_width * CONFIG_T::n_chan / CONFIG_T::reuse_factor + ii];\n' # noqa: E501
' }}\n'
' }}\n\n'
).format(index=layer_idx)
indent = " "
indent = ' '
for i in range(reuse_factor):
generated_code += indent
generated_code += (
f"pointwise_conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data_tmp[{i}], res_tmp[{i}], weights, biases);\n"
f'pointwise_conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data_tmp[{i}], res_tmp[{i}], weights, biases);\n'
)

generated_code += (
"\n"
" RFOutputLoop:\n"
" for (int jj = 0; jj < CONFIG_T::reuse_factor; jj++) {\n"
" #pragma HLS UNROLL\n"
" InnerOutputLoop:\n"
" for (int ii = 0; ii < CONFIG_T::out_width * CONFIG_T::n_filt / CONFIG_T::reuse_factor; ii++) {\n"
" #pragma HLS UNROLL\n"
" res[jj * CONFIG_T::out_width * CONFIG_T::n_filt / CONFIG_T::reuse_factor + ii] = res_tmp[jj][ii];\n" # noqa: E501
" }\n"
" }\n"
" }\n"
"};\n"
'\n'
' RFOutputLoop:\n'
' for (int jj = 0; jj < CONFIG_T::reuse_factor; jj++) {\n'
' #pragma HLS UNROLL\n'
' InnerOutputLoop:\n'
' for (int ii = 0; ii < CONFIG_T::out_width * CONFIG_T::n_filt / CONFIG_T::reuse_factor; ii++) {\n'
' #pragma HLS UNROLL\n'
' res[jj * CONFIG_T::out_width * CONFIG_T::n_filt / CONFIG_T::reuse_factor + ii] = res_tmp[jj][ii];\n' # noqa: E501
' }\n'
' }\n'
' }\n'
'};\n'
)

return generated_code
@@ -66,14 +66,10 @@ class GeneratePointwiseConv1D(OptimizerPass):
     '''Generates code for pointwise 1D convolution'''
 
     def match(self, node):
-        return isinstance(node, Conv1D) and node.model.config.get_config_value('IOType') == 'io_parallel'
+        return isinstance(node, Conv1D) and node.model.config.get_config_value('IOType') == 'io_parallel' and node.get_attr('filt_width') == 1
 
     def transform(self, model, node):
-        node_class = node.__class__.__name__
-        if '1D' in node_class:
-            self._generate_pointwise_conv1d(node)
-        else:
-            raise Exception(f'Cannot generate instructions for node {node.name} ({node_class})')
+        self._generate_pointwise_conv1d(node)
 
     def _generate_pointwise_conv1d(self, node):
         code_str = generate_pointwise_conv1d_fn(
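
To make the codegen concrete, here is approximately what generate_pointwise_conv1d_fn emits for a hypothetical layer with index 7 and reuse_factor=2 (the RFInputLoop/RFOutputLoop bodies are elided):

    // Approximate generator output for index=7, reuse_factor=2 (loops elided).
    template<class data_T, class res_T, typename CONFIG_T>
    class pointwise_conv_7 : public Conv1DKernel<data_T, res_T, CONFIG_T> {
      public:
        static void conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],
                         res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
                         typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt],
                         typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
            // RFInputLoop: split data into reuse_factor slices (data_tmp), then
            // one fully unrolled call per slice:
            pointwise_conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data_tmp[0], res_tmp[0], weights, biases);
            pointwise_conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data_tmp[1], res_tmp[1], weights, biases);
            // RFOutputLoop: copy the res_tmp slices back into res.
        }
    };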

hls4ml/templates/vitis/nnet_utils/nnet_conv1d.h (22 additions, 10 deletions)

@@ -4,6 +4,7 @@
#include "nnet_common.h"
#include "nnet_conv1d_latency.h"
#include "nnet_conv1d_resource.h"
#include "nnet_function_stubs.h"
#include <cstdlib>

namespace nnet {
@@ -38,11 +39,7 @@ void conv_1d_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CO
     // Inlining helps reduce latency, but may also cause timing issues in some cases, use carefully.
     //#pragma HLS INLINE recursive
 
-    if (CONFIG_T::strategy == nnet::latency) {
-        conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
-    } else {
-        conv_1d_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
-    }
+    CONFIG_T::template conv_kernel<data_T, res_T, CONFIG_T>::conv(data, res, weights, biases);
 }
 
 template <class data_T, class res_T, typename CONFIG_T>
@@ -55,13 +52,28 @@ void pointwise_conv_1d_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],
     // Inlining helps reduce latency, but may also cause timing issues in some cases, use carefully.
     //#pragma HLS INLINE recursive
 
-    if (CONFIG_T::strategy == nnet::latency) {
-        // Use pointwise unrolled implementation
-        CONFIG_T::template pointwise_conv<data_T, res_T, CONFIG_T>::pointwise_conv(data, res, weights, biases);
-    } else {
-        conv_1d_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
-    }
+    CONFIG_T::template conv_kernel<data_T, res_T, CONFIG_T>::conv(data, res, weights, biases);
 }
 
+template <class data_T, class res_T, typename CONFIG_T> class Conv1DLatency : public Conv1DKernel<data_T, res_T, CONFIG_T> {
+  public:
+    static void conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
+                     typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt],
+                     typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
+        //#pragma HLS INLINE region
+        conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
+    }
+};
+
+template <class data_T, class res_T, typename CONFIG_T> class Conv1DResource : public Conv1DKernel<data_T, res_T, CONFIG_T> {
+  public:
+    static void conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
+                     typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt],
+                     typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
+        //#pragma HLS INLINE region
+        conv_1d_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
+    }
+};
 
 } // namespace nnet
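
The effect of this hunk is that strategy selection moves from a branch on CONFIG_T::strategy (which HLS had to prune) to a type-level binding. A minimal usage sketch, where my_config, input_t, and result_t are hypothetical names and most config fields are elided:

    // Hypothetical config wiring: the backend emits the conv_kernel alias, so
    // the dispatch in conv_1d_cl resolves at compile time to one kernel class.
    struct my_config : nnet::conv1d_config {
        // ... fields rendered by the backend templates ...
        template<class data_T, class res_T, class CONFIG_T>
        using conv_kernel = nnet::Conv1DResource<data_T, res_T, CONFIG_T>;
    };

    // nnet::conv_1d_cl<input_t, result_t, my_config>(data, res, weights, biases);
    // compiles down to nnet::Conv1DResource<input_t, result_t, my_config>::conv(...).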

hls4ml/templates/vitis/nnet_utils/nnet_conv1d_latency.h (3 additions, 5 deletions)

@@ -107,9 +107,7 @@ void pointwise_conv_1d_latency_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_c
     #pragma HLS ARRAY_PARTITION variable=biases complete dim=0
 
     // Limit multipliers to control parallelization
-    constexpr unsigned multiplier_limit = DIV_ROUNDUP(
-        CONFIG_T::out_width * CONFIG_T::n_filt * CONFIG_T::n_chan / CONFIG_T::reuse_factor, CONFIG_T::reuse_factor);
-    #pragma HLS ALLOCATION operation instances=mul limit=multiplier_limit
+    #pragma HLS ALLOCATION operation instances=mul limit=CONFIG_T::mult_config::multiplier_limit
 
     // Convolve, saving all multiplication results to accumulate later
 ConvOut:
@@ -159,8 +157,8 @@ void pointwise_conv_1d_latency_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_c
     // Cast to "res_t" type
     for (int ii = 0; ii < CONFIG_T::out_width / CONFIG_T::reuse_factor; ii++) {
         for (int ff = 0; ff < CONFIG_T::n_filt; ff++) {
-            #pragma HLS UNROLL
-            res[ii * CONFIG_T::n_filt + ff] = (res_T)(acc[ii][ff]);
+            #pragma HLS UNROLL
+            res[ii * CONFIG_T::n_filt + ff] = cast<data_T, res_T, typename CONFIG_T::mult_config>(acc[ii][ff]);
         }
     }
 }
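
The final store also changes from a plain C-style cast to nnet::cast, routing the accumulator-to-result conversion through the mult_config. As a rough sketch of the idea only (the actual helper in hls4ml's nnet_mult.h has further specializations per weight type and strategy):

    // Sketch: a single conversion point lets specializations apply whatever
    // rounding/saturation behavior the mult config prescribes.
    template<class data_T, class res_T, typename CONFIG_T>
    inline res_T cast(typename CONFIG_T::accum_t x) {
        return (res_T)x;
    }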

hls4ml/templates/vivado/nnet_utils/nnet_conv1d.h (22 additions, 10 deletions)

@@ -4,6 +4,7 @@
#include "nnet_common.h"
#include "nnet_conv1d_latency.h"
#include "nnet_conv1d_resource.h"
#include "nnet_function_stubs.h"
#include <cstdlib>

namespace nnet {
@@ -37,11 +38,7 @@ void conv_1d_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CO
                 typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
     #pragma HLS INLINE region
 
-    if (CONFIG_T::strategy == nnet::latency) {
-        conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
-    } else {
-        conv_1d_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
-    }
+    CONFIG_T::template conv_kernel<data_T, res_T, CONFIG_T>::conv(data, res, weights, biases);
 }
 
 template <class data_T, class res_T, typename CONFIG_T>
@@ -53,13 +50,28 @@ void pointwise_conv_1d_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],

     #pragma HLS INLINE region
 
-    if (CONFIG_T::strategy == nnet::latency) {
-        // Use pointwise unrolled implementation
-        CONFIG_T::template pointwise_conv<data_T, res_T, CONFIG_T>::pointwise_conv(data, res, weights, biases);
-    } else {
-        conv_1d_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
-    }
+    CONFIG_T::template conv_kernel<data_T, res_T, CONFIG_T>::conv(data, res, weights, biases);
 }
 
+template <class data_T, class res_T, typename CONFIG_T> class Conv1DLatency : public Conv1DKernel<data_T, res_T, CONFIG_T> {
+  public:
+    static void conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
+                     typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt],
+                     typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
+        #pragma HLS INLINE region
+        conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
+    }
+};
+
+template <class data_T, class res_T, typename CONFIG_T> class Conv1DResource : public Conv1DKernel<data_T, res_T, CONFIG_T> {
+  public:
+    static void conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
+                     typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt],
+                     typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
+        #pragma HLS INLINE region
+        conv_1d_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
+    }
+};
 
 } // namespace nnet

hls4ml/templates/vivado/nnet_utils/nnet_conv1d_latency.h (3 additions, 5 deletions)

@@ -106,9 +106,7 @@ void pointwise_conv_1d_latency_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_c
     #pragma HLS ARRAY_PARTITION variable=biases complete dim=0
 
     // Limit multipliers to control parallelization
-    constexpr unsigned multiplier_limit = DIV_ROUNDUP(
-        CONFIG_T::out_width * CONFIG_T::n_filt * CONFIG_T::n_chan / CONFIG_T::reuse_factor, CONFIG_T::reuse_factor);
-    #pragma HLS ALLOCATION operation instances=mul limit=multiplier_limit
+    #pragma HLS ALLOCATION operation instances=mul limit=CONFIG_T::mult_config::multiplier_limit
 
     // Convolve, saving all multiplication results to accumulate later
 ConvOut:
@@ -158,8 +156,8 @@ void pointwise_conv_1d_latency_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_c
     // Cast to "res_t" type
     for (int ii = 0; ii < CONFIG_T::out_width / CONFIG_T::reuse_factor; ii++) {
         for (int ff = 0; ff < CONFIG_T::n_filt; ff++) {
-            #pragma HLS UNROLL
-            res[ii * CONFIG_T::n_filt + ff] = (res_T)(acc[ii][ff]);
+            #pragma HLS UNROLL
+            res[ii * CONFIG_T::n_filt + ff] = cast<data_T, res_T, typename CONFIG_T::mult_config>(acc[ii][ff]);
         }
     }
 }

hls4ml/templates/vivado/nnet_utils/nnet_function_stubs.h (10 additions, 0 deletions)

@@ -37,6 +37,16 @@ template <class data_T, class res_T, typename CONFIG_T> class DenseKernel {
     }
 };
 
+template <class data_T, class res_T, typename CONFIG_T> class Conv1DKernel {
+  public:
+    static void conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],
+                     res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
+                     typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt],
+                     typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
+        // To be implemented in subclasses
+    }
+};
+
 } // namespace nnet
 
 #endif
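
Conv1DKernel mirrors the existing DenseKernel stub: it fixes the static conv() signature that every kernel (Conv1DLatency, Conv1DResource, and the generated pointwise classes) provides. A hedged sketch of how a custom kernel could plug in the same way; MyConv1D is a hypothetical name, not part of this commit:

    // Hypothetical custom kernel; any class with this static conv() signature
    // can be bound as CONFIG_T::conv_kernel via the backend's {conv_fn} field.
    template<class data_T, class res_T, typename CONFIG_T>
    class MyConv1D : public nnet::Conv1DKernel<data_T, res_T, CONFIG_T> {
      public:
        static void conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],
                         res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
                         typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt],
                         typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
            // custom implementation goes here
        }
    };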
