diff --git a/hls4ml/backends/fpga/passes/codegen.py b/hls4ml/backends/fpga/passes/im2col_codegen.py similarity index 100% rename from hls4ml/backends/fpga/passes/codegen.py rename to hls4ml/backends/fpga/passes/im2col_codegen.py diff --git a/hls4ml/backends/vivado/passes/convolution_templates.py b/hls4ml/backends/vivado/passes/convolution_templates.py index dd77bee85e..e098107eae 100644 --- a/hls4ml/backends/vivado/passes/convolution_templates.py +++ b/hls4ml/backends/vivado/passes/convolution_templates.py @@ -60,6 +60,8 @@ typedef {config_t} mult_config; template using scale_index = nnet::{scale_index_type}; + template + using conv_kernel = nnet::{conv_fn}; }}; const ap_uint config{index}::pixels[] = {{{instructions}}};\n""" @@ -93,11 +95,30 @@ def format(self, node): else: params['fill_fn'] = 'FillConv1DBuffer' + is_pointwise_parallel_latency = ( + node.get_attr('filt_width') == 1 + and node.get_attr('strategy').lower() == 'latency' + and node.model.config.get_config_value('IOType') == 'io_parallel' + ) + if is_pointwise_parallel_latency: + params['conv_fn'] = f'pointwise_conv_{node.index}' + else: + if node.get_attr('strategy').lower() == 'latency': + params['conv_fn'] = 'Conv1DLatency' + else: + params['conv_fn'] = 'Conv1DResource' + conv_config = self.template.format(**params) mult_params = self._default_config_params(node) - mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_width') - mult_params['n_out'] = node.get_attr('n_filt') + if is_pointwise_parallel_latency: + mult_params['n_in'] = int( + node.get_attr('in_width') * node.get_attr('n_chan') * node.get_attr('filt_width') / mult_params['reuse'] + ) + mult_params['n_out'] = int(node.get_attr('in_width') * node.get_attr('n_filt') / mult_params['reuse']) + else: + mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_width') + mult_params['n_out'] = node.get_attr('n_filt') mult_params['nzeros'] = node.get_weights('weight').nzeros mult_params['product_type'] = get_backend('vivado').product_type( node.get_input_variable().type.precision, node.get_weights('weight').type.precision diff --git a/hls4ml/backends/vivado/passes/pointwise_codegen.py b/hls4ml/backends/vivado/passes/pointwise_codegen.py new file mode 100644 index 0000000000..d41d51f82f --- /dev/null +++ b/hls4ml/backends/vivado/passes/pointwise_codegen.py @@ -0,0 +1,84 @@ +from hls4ml.model.layers import Conv1D +from hls4ml.model.optimizer import OptimizerPass +from hls4ml.model.types import Source + + +def generate_pointwise_conv1d_fn(layer_idx, reuse_factor=1): + """Generate a C++ function for a pointwise convolution layer. + + Args: + layer_idx (int): Index of layer ('index' attribute). + reuse_factor (int): Number of partitions to divide the input into. + + Returns: + str: Generated C++ function + """ + + generated_code = ( + 'template\n' + 'class pointwise_conv_{index} : public Conv1DKernel {{\n' + ' public:\n' + ' static void conv(\n' + ' data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],\n' + ' res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],\n' + ' typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt],\n' + ' typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {{\n' + ' data_T data_tmp[CONFIG_T::reuse_factor][CONFIG_T::in_width * CONFIG_T::n_chan / CONFIG_T::reuse_factor];\n' # noqa: E501 + ' #pragma HLS ARRAY_PARTITION variable=data_tmp complete dim=0\n' + ' res_T res_tmp[CONFIG_T::reuse_factor][CONFIG_T::out_width * CONFIG_T::n_filt / CONFIG_T::reuse_factor];\n' # noqa: E501 + ' #pragma HLS ARRAY_PARTITION variable=res_tmp complete dim=0\n\n' + ' RFInputLoop:\n' + ' for (int jj = 0; jj < CONFIG_T::reuse_factor; jj++) {{\n' + ' #pragma HLS UNROLL\n' + ' InnerInputLoop:\n' + ' for (int ii = 0; ii < CONFIG_T::in_width * CONFIG_T::n_chan / CONFIG_T::reuse_factor; ii++) {{\n' + ' #pragma HLS UNROLL\n' + ' data_tmp[jj][ii] = data[jj * CONFIG_T::in_width * CONFIG_T::n_chan / CONFIG_T::reuse_factor + ii];\n' # noqa: E501 + ' }}\n' + ' }}\n\n' + ).format(index=layer_idx) + indent = ' ' + for i in range(reuse_factor): + generated_code += indent + generated_code += ( + f'pointwise_conv_1d_latency_cl(data_tmp[{i}], res_tmp[{i}], weights, biases);\n' + ) + + generated_code += ( + '\n' + ' RFOutputLoop:\n' + ' for (int jj = 0; jj < CONFIG_T::reuse_factor; jj++) {\n' + ' #pragma HLS UNROLL\n' + ' InnerOutputLoop:\n' + ' for (int ii = 0; ii < CONFIG_T::out_width * CONFIG_T::n_filt / CONFIG_T::reuse_factor; ii++) {\n' + ' #pragma HLS UNROLL\n' + ' res[jj * CONFIG_T::out_width * CONFIG_T::n_filt / CONFIG_T::reuse_factor + ii] = res_tmp[jj][ii];\n' # noqa: E501 + ' }\n' + ' }\n' + ' }\n' + '};\n' + ) + + return generated_code + + +class GeneratePointwiseConv1D(OptimizerPass): + '''Generates code for pointwise 1D convolution''' + + def match(self, node): + return ( + isinstance(node, Conv1D) + and node.model.config.get_config_value('IOType') == 'io_parallel' + and node.get_attr('filt_width') == 1 + ) + + def transform(self, model, node): + self._generate_pointwise_conv1d(node) + + def _generate_pointwise_conv1d(self, node): + code_str = generate_pointwise_conv1d_fn( + node.get_attr('index'), + node.get_attr('reuse_factor'), + ) + + node.set_attr('pointwise_conv1d_codegen', Source(code_str)) diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py index e88af278f0..3656908816 100644 --- a/hls4ml/backends/vivado/vivado_backend.py +++ b/hls4ml/backends/vivado/vivado_backend.py @@ -70,7 +70,6 @@ def _register_layer_attributes(self): cnn_layers = [Conv1D, Conv2D, SeparableConv1D, SeparableConv2D, DepthwiseConv2D, Pooling1D, Pooling2D] for layer in cnn_layers: attrs = self.attribute_map.get(layer, []) - # attrs.append(ConfigurableAttribute('conv_implementation', value_type=str, default='LineBuffer')) attrs.append(ChoiceAttribute('conv_implementation', choices=['LineBuffer', 'Encoded'], default='LineBuffer')) self.attribute_map[layer] = attrs @@ -114,6 +113,7 @@ def _register_flows(self): 'vivado:generate_conv_streaming_instructions', 'vivado:apply_resource_strategy', 'vivado:generate_conv_im2col', + 'vivado:generate_pointwise_conv1_d', 'vivado:generate_unrolled_dense_resource', 'vivado:set_pipeline_style', ] diff --git a/hls4ml/templates/vitis/nnet_utils/nnet_conv1d.h b/hls4ml/templates/vitis/nnet_utils/nnet_conv1d.h index 52a404672c..46beeacb03 100644 --- a/hls4ml/templates/vitis/nnet_utils/nnet_conv1d.h +++ b/hls4ml/templates/vitis/nnet_utils/nnet_conv1d.h @@ -4,6 +4,7 @@ #include "nnet_common.h" #include "nnet_conv1d_latency.h" #include "nnet_conv1d_resource.h" +#include "nnet_function_stubs.h" #include namespace nnet { @@ -38,11 +39,7 @@ void conv_1d_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CO // Inlining helps reduce latency, but may also cause timing issues in some cases, use carefully. //#pragma HLS INLINE recursive - if (CONFIG_T::strategy == nnet::latency) { - conv_1d_latency_cl(data, res, weights, biases); - } else { - conv_1d_resource_cl(data, res, weights, biases); - } + CONFIG_T::template conv_kernel::conv(data, res, weights, biases); } template @@ -55,13 +52,28 @@ void pointwise_conv_1d_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], // Inlining helps reduce latency, but may also cause timing issues in some cases, use carefully. //#pragma HLS INLINE recursive - // Nothing special to be done for io_parallel implementation - if (CONFIG_T::strategy == nnet::latency) { + CONFIG_T::template conv_kernel::conv(data, res, weights, biases); +} + +template class Conv1DLatency : public Conv1DKernel { + public: + static void conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + //#pragma HLS INLINE region conv_1d_latency_cl(data, res, weights, biases); - } else { + } +}; + +template class Conv1DResource : public Conv1DKernel { + public: + static void conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + //#pragma HLS INLINE region conv_1d_resource_cl(data, res, weights, biases); } -} +}; } // namespace nnet diff --git a/hls4ml/templates/vitis/nnet_utils/nnet_conv1d_latency.h b/hls4ml/templates/vitis/nnet_utils/nnet_conv1d_latency.h index 1bf25cc89c..e166cdd470 100644 --- a/hls4ml/templates/vitis/nnet_utils/nnet_conv1d_latency.h +++ b/hls4ml/templates/vitis/nnet_utils/nnet_conv1d_latency.h @@ -85,5 +85,83 @@ void conv_1d_latency_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], } } +template +void pointwise_conv_1d_latency_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan / CONFIG_T::reuse_factor], + res_T res[CONFIG_T::out_width * CONFIG_T::n_filt / CONFIG_T::reuse_factor], + typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + assert(CONFIG_T::filt_width == 1); + + typename CONFIG_T::accum_t mult[CONFIG_T::out_width * CONFIG_T::n_filt * CONFIG_T::n_chan / CONFIG_T::reuse_factor]; + typename CONFIG_T::accum_t acc[CONFIG_T::out_width / CONFIG_T::reuse_factor][CONFIG_T::n_filt]; + + #pragma HLS ARRAY_PARTITION variable=mult complete dim=0 + #pragma HLS ARRAY_PARTITION variable=acc complete dim=0 + + // Use a function_instantiate in case it helps to explicitly optimize unchanging weights/biases + #pragma HLS function_instantiate variable=weights,biases + + // Parallel mode + #pragma HLS PIPELINE II=CONFIG_T::reuse_factor + #pragma HLS ARRAY_PARTITION variable=weights complete dim=0 + #pragma HLS ARRAY_PARTITION variable=biases complete dim=0 + + // Limit multipliers to control parallelization + #pragma HLS ALLOCATION operation instances=mul limit=CONFIG_T::mult_config::multiplier_limit + +// Convolve, saving all multiplication results to accumulate later +ConvOut: + for (int ii = 0; ii < CONFIG_T::out_width / CONFIG_T::reuse_factor; ii++) { + ConvFilt: + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + ConvChan: + for (int cc = 0; cc < CONFIG_T::n_chan; cc++) { + #pragma HLS UNROLL + int index_mult = ii * CONFIG_T::n_filt * CONFIG_T::n_chan + ff * CONFIG_T::n_chan + cc; + int index_weight = cc * CONFIG_T::n_filt + ff; + int index_data = (ii * CONFIG_T::stride_width - CONFIG_T::pad_left) * CONFIG_T::n_chan + cc; + + if ((ii * CONFIG_T::stride_width) < CONFIG_T::pad_left || + (ii * CONFIG_T::stride_width) >= (CONFIG_T::pad_left + CONFIG_T::in_width)) { + mult[index_mult] = 0; + } else { + mult[index_mult] = CONFIG_T::mult_config::template product::product( + data[index_data], weights[index_weight]); + } + } // end channel loop + } // end filter loop + } // end output loop + + // Initialize accumulator with input biases + for (int ii = 0; ii < CONFIG_T::out_width / CONFIG_T::reuse_factor; ii++) { + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + #pragma HLS UNROLL + acc[ii][ff] = biases[ff]; + } + } + +// Accumulate multiplication result +AccumOut: + for (int ii = 0; ii < CONFIG_T::out_width / CONFIG_T::reuse_factor; ii++) { + AccumFilt: + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + // Do "dot product" sum within filter and sum over channels + AccumChan: + for (int cc = 0; cc < CONFIG_T::n_chan; cc++) { + int index_mult = ii * CONFIG_T::n_filt * CONFIG_T::n_chan + ff * CONFIG_T::n_chan + cc; + acc[ii][ff] += mult[index_mult]; + } // end channel loop + } // end filter loop + } // end output loop + + // Cast to "res_t" type + for (int ii = 0; ii < CONFIG_T::out_width / CONFIG_T::reuse_factor; ii++) { + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + #pragma HLS UNROLL + res[ii * CONFIG_T::n_filt + ff] = cast(acc[ii][ff]); + } + } +} + } // namespace nnet #endif diff --git a/hls4ml/templates/vivado/build_prj.tcl b/hls4ml/templates/vivado/build_prj.tcl index 7d0420611a..05d4b8a4d5 100644 --- a/hls4ml/templates/vivado/build_prj.tcl +++ b/hls4ml/templates/vivado/build_prj.tcl @@ -161,7 +161,7 @@ if {$opt(reset)} { } else { open_solution "solution1" } -catch {config_array_partition -maximum_size 4096} +catch {config_array_partition -maximum_size $maximum_size} config_compile -name_max_length 80 set_part $part config_schedule -enable_dsp_full_reg=false diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h b/hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h index 4a8a40cd10..6011e20cca 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h @@ -1,6 +1,7 @@ #ifndef NNET_INSTR_GEN_H_ #define NNET_INSTR_GEN_H_ +#include "nnet_conv1d_latency.h" #include "nnet_helpers.h" #include "hls_stream.h" @@ -10,6 +11,16 @@ namespace nnet { +template class PointwiseConv1D { + public: + static void pointwise_conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], + res_T res[CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + // To be implemented in subclasses + } +}; + // hls4ml insert code } // namespace nnet diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_common.h b/hls4ml/templates/vivado/nnet_utils/nnet_common.h index a14517df5b..6db3f62f6e 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_common.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_common.h @@ -2,6 +2,7 @@ #define NNET_COMMON_H_ #include "ap_fixed.h" +#include "nnet_helpers.h" // This is a substitute for "ceil(n/(float)d)". #define DIV_ROUNDUP(n, d) ((n + d - 1) / d) diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv1d.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv1d.h index e2e0211b49..72bce78067 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv1d.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv1d.h @@ -4,6 +4,7 @@ #include "nnet_common.h" #include "nnet_conv1d_latency.h" #include "nnet_conv1d_resource.h" +#include "nnet_function_stubs.h" #include namespace nnet { @@ -37,11 +38,7 @@ void conv_1d_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CO typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { #pragma HLS INLINE region - if (CONFIG_T::strategy == nnet::latency) { - conv_1d_latency_cl(data, res, weights, biases); - } else { - conv_1d_resource_cl(data, res, weights, biases); - } + CONFIG_T::template conv_kernel::conv(data, res, weights, biases); } template @@ -53,13 +50,28 @@ void pointwise_conv_1d_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], #pragma HLS INLINE region - // Nothing special to be done for io_parallel implementation - if (CONFIG_T::strategy == nnet::latency) { + CONFIG_T::template conv_kernel::conv(data, res, weights, biases); +} + +template class Conv1DLatency : public Conv1DKernel { + public: + static void conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + #pragma HLS INLINE region conv_1d_latency_cl(data, res, weights, biases); - } else { + } +}; + +template class Conv1DResource : public Conv1DKernel { + public: + static void conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + #pragma HLS INLINE region conv_1d_resource_cl(data, res, weights, biases); } -} +}; } // namespace nnet diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_latency.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_latency.h index 0d9afb10cb..ef2f94dcaf 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_latency.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_latency.h @@ -84,5 +84,83 @@ void conv_1d_latency_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], } } +template +void pointwise_conv_1d_latency_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan / CONFIG_T::reuse_factor], + res_T res[CONFIG_T::out_width * CONFIG_T::n_filt / CONFIG_T::reuse_factor], + typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + assert(CONFIG_T::filt_width == 1); + + typename CONFIG_T::accum_t mult[CONFIG_T::out_width * CONFIG_T::n_filt * CONFIG_T::n_chan / CONFIG_T::reuse_factor]; + typename CONFIG_T::accum_t acc[CONFIG_T::out_width / CONFIG_T::reuse_factor][CONFIG_T::n_filt]; + + #pragma HLS ARRAY_PARTITION variable=mult complete dim=0 + #pragma HLS ARRAY_PARTITION variable=acc complete dim=0 + + // Use a function_instantiate in case it helps to explicitly optimize unchanging weights/biases + #pragma HLS function_instantiate variable=weights,biases + + // Parallel mode + #pragma HLS PIPELINE II=CONFIG_T::reuse_factor + #pragma HLS ARRAY_PARTITION variable=weights complete dim=0 + #pragma HLS ARRAY_PARTITION variable=biases complete dim=0 + + // Limit multipliers to control parallelization + #pragma HLS ALLOCATION operation instances=mul limit=CONFIG_T::mult_config::multiplier_limit + +// Convolve, saving all multiplication results to accumulate later +ConvOut: + for (int ii = 0; ii < CONFIG_T::out_width / CONFIG_T::reuse_factor; ii++) { + ConvFilt: + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + ConvChan: + for (int cc = 0; cc < CONFIG_T::n_chan; cc++) { + #pragma HLS UNROLL + int index_mult = ii * CONFIG_T::n_filt * CONFIG_T::n_chan + ff * CONFIG_T::n_chan + cc; + int index_weight = cc * CONFIG_T::n_filt + ff; + int index_data = (ii * CONFIG_T::stride_width - CONFIG_T::pad_left) * CONFIG_T::n_chan + cc; + + if ((ii * CONFIG_T::stride_width) < CONFIG_T::pad_left || + (ii * CONFIG_T::stride_width) >= (CONFIG_T::pad_left + CONFIG_T::in_width)) { + mult[index_mult] = 0; + } else { + mult[index_mult] = CONFIG_T::mult_config::template product::product( + data[index_data], weights[index_weight]); + } + } // end channel loop + } // end filter loop + } // end output loop + + // Initialize accumulator with input biases + for (int ii = 0; ii < CONFIG_T::out_width / CONFIG_T::reuse_factor; ii++) { + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + #pragma HLS UNROLL + acc[ii][ff] = biases[ff]; + } + } + +// Accumulate multiplication result +AccumOut: + for (int ii = 0; ii < CONFIG_T::out_width / CONFIG_T::reuse_factor; ii++) { + AccumFilt: + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + // Do "dot product" sum within filter and sum over channels + AccumChan: + for (int cc = 0; cc < CONFIG_T::n_chan; cc++) { + int index_mult = ii * CONFIG_T::n_filt * CONFIG_T::n_chan + ff * CONFIG_T::n_chan + cc; + acc[ii][ff] += mult[index_mult]; + } // end channel loop + } // end filter loop + } // end output loop + + // Cast to "res_t" type + for (int ii = 0; ii < CONFIG_T::out_width / CONFIG_T::reuse_factor; ii++) { + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + #pragma HLS UNROLL + res[ii * CONFIG_T::n_filt + ff] = cast(acc[ii][ff]); + } + } +} + } // namespace nnet #endif diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_function_stubs.h b/hls4ml/templates/vivado/nnet_utils/nnet_function_stubs.h index 1316bbe776..97774bc95b 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_function_stubs.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_function_stubs.h @@ -37,6 +37,15 @@ template class DenseKernel { } }; +template class Conv1DKernel { + public: + static void conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + // To be implemented in subclasses + } +}; + } // namespace nnet #endif diff --git a/hls4ml/writer/vivado_writer.py b/hls4ml/writer/vivado_writer.py index 5ab13736ec..0341959045 100644 --- a/hls4ml/writer/vivado_writer.py +++ b/hls4ml/writer/vivado_writer.py @@ -717,6 +717,8 @@ def write_build_script(self, model): f.write('set clock_uncertainty {}\n'.format(model.config.get_config_value('ClockUncertainty', '12.5%'))) f.write('variable version\n') f.write('set version "{}"\n'.format(model.config.get_config_value('Version', '1.0.0'))) + f.write('variable maximum_size\n') + f.write('set maximum_size {}\n'.format(model.config.get_config_value('MaximumSize', '4096'))) # build_prj.tcl srcpath = (filedir / '../templates/vivado/build_prj.tcl').resolve() diff --git a/test/pytest/test_pointwiseconv.py b/test/pytest/test_pointwiseconv.py index 678b22bfeb..1cfb43e4cd 100644 --- a/test/pytest/test_pointwiseconv.py +++ b/test/pytest/test_pointwiseconv.py @@ -19,25 +19,27 @@ @pytest.mark.parametrize('padds', padds_options) @pytest.mark.parametrize('strides', strides1d_options) @pytest.mark.parametrize( - 'backend, io_type, strategy', + 'backend, io_type, strategy, rf', [ - ('Quartus', 'io_parallel', 'resource'), - ('Quartus', 'io_stream', 'resource'), - ('oneAPI', 'io_parallel', 'resource'), - ('oneAPI', 'io_stream', 'resource'), - ('Vivado', 'io_parallel', 'resource'), - ('Vitis', 'io_parallel', 'resource'), - ('Vivado', 'io_parallel', 'latency'), - ('Vitis', 'io_parallel', 'latency'), - ('Vivado', 'io_stream', 'latency'), - ('Vivado', 'io_stream', 'resource'), - ('Vitis', 'io_stream', 'latency'), - ('Vitis', 'io_stream', 'resource'), - ('Catapult', 'io_stream', 'latency'), - ('Catapult', 'io_stream', 'resource'), + ('Quartus', 'io_parallel', 'resource', 1), + ('Quartus', 'io_stream', 'resource', 1), + ('oneAPI', 'io_parallel', 'resource', 1), + ('oneAPI', 'io_stream', 'resource', 1), + ('Vivado', 'io_parallel', 'resource', 1), + ('Vitis', 'io_parallel', 'resource', 1), + ('Vivado', 'io_parallel', 'latency', 1), + ('Vitis', 'io_parallel', 'latency', 1), + ('Vivado', 'io_parallel', 'latency', 14), + ('Vitis', 'io_parallel', 'latency', 14), + ('Vivado', 'io_stream', 'latency', 1), + ('Vivado', 'io_stream', 'resource', 1), + ('Vitis', 'io_stream', 'latency', 1), + ('Vitis', 'io_stream', 'resource', 1), + ('Catapult', 'io_stream', 'latency', 1), + ('Catapult', 'io_stream', 'resource', 1), ], ) -def test_pointwiseconv1d(chans, padds, strides, backend, io_type, strategy): +def test_pointwiseconv1d(chans, padds, strides, backend, io_type, strategy, rf): model = tf.keras.models.Sequential() input_shape = (28, 3) model.add( @@ -50,6 +52,7 @@ def test_pointwiseconv1d(chans, padds, strides, backend, io_type, strategy): kernel_initializer='normal', use_bias=False, data_format=chans, + name='pointwise1d', ) ) model.compile(optimizer='adam', loss='mse') @@ -58,14 +61,12 @@ def test_pointwiseconv1d(chans, padds, strides, backend, io_type, strategy): keras_prediction = model.predict(X_input) default_precision = 'fixed<32,16>' - config = hls4ml.utils.config_from_keras_model(model, default_precision=default_precision) + config = hls4ml.utils.config_from_keras_model(model, default_precision=default_precision, granularity='name') config['Model']['Strategy'] = strategy + config['LayerName']['pointwise1d']['ReuseFactor'] = rf output_dir = str( - test_root_path - / 'hls4mlprj_pointwise1d_{}_strides_{}_{}_padding_{}_{}_{}'.format( - chans, strides[0], padds, backend, io_type, strategy - ) + test_root_path / f'hls4mlprj_pointwise1d_{chans}_{strides[0]}_{padds}_{backend}_{io_type}_{strategy}_rf{rf}' ) hls_model = hls4ml.converters.convert_from_keras_model( model, hls_config=config, output_dir=output_dir, io_type=io_type, backend=backend @@ -110,6 +111,7 @@ def test_pointwiseconv2d(chans, padds, strides, backend, io_type, strategy): kernel_initializer='normal', use_bias=False, data_format=chans, + name='pointwise2d', ) ) @@ -123,10 +125,7 @@ def test_pointwiseconv2d(chans, padds, strides, backend, io_type, strategy): config['Model']['Strategy'] = strategy stride_cfg = str(strides).replace(', ', '_').replace('(', '').replace(')', '') output_dir = str( - test_root_path - / 'hls4mlprj_pointwise2d_{}_strides_{}_{}_padding_{}_{}_{}'.format( - chans, stride_cfg, padds, backend, io_type, strategy - ) + test_root_path / f'hls4mlprj_pointwise2d_{chans}_strides_{stride_cfg}_{padds}_padding_{backend}_{io_type}_{strategy}' ) hls_model = hls4ml.converters.convert_from_keras_model(