diff --git a/hls4ml/converters/pytorch/reshape.py b/hls4ml/converters/pytorch/reshape.py index 3d415e783..3f60698f8 100644 --- a/hls4ml/converters/pytorch/reshape.py +++ b/hls4ml/converters/pytorch/reshape.py @@ -38,6 +38,7 @@ def parse_squeeze_layer(operation, layer_name, input_names, input_shapes, node, layer = {} layer['class_name'] = 'Reshape' layer['name'] = layer_name + layer['inputs'] = input_names if len(node.args) > 1 or len(node.kwargs) > 0: # 'dim' argument is specified output_shape = [i for i in input_shapes[0]] diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py index 79ca1fa5c..dc9435ec4 100644 --- a/hls4ml/converters/pytorch_to_hls.py +++ b/hls4ml/converters/pytorch_to_hls.py @@ -147,6 +147,7 @@ def parse_pytorch_model(config, verbose=True): inputs_map = {} input_layers = [] + output_layers = [] # Output shape tracking output_shapes = {} @@ -365,11 +366,22 @@ def parse_pytorch_model(config, verbose=True): if len(input_layers) == 0: input_layers = None - return layer_list, input_layers + for layer in layer_list: + if layer['class_name'] == 'InputLayer': + continue + is_input = False + for lay in layer_list: + if 'inputs' not in lay.keys(): + continue + if layer['name'] in lay['inputs']: + is_input = True + if not is_input: + output_layers.append(layer['name']) + return layer_list, input_layers, output_layers def pytorch_to_hls(config): - layer_list, input_layers = parse_pytorch_model(config) + layer_list, input_layers, output_layers = parse_pytorch_model(config) print('Creating HLS model') - hls_model = ModelGraph(config, layer_list, inputs=input_layers) + hls_model = ModelGraph(config, layer_list, inputs=input_layers, outputs=output_layers) return hls_model diff --git a/hls4ml/utils/config.py b/hls4ml/utils/config.py index e45008409..1a7a0cccd 100644 --- a/hls4ml/utils/config.py +++ b/hls4ml/utils/config.py @@ -368,6 +368,7 @@ def config_from_pytorch_model( ( layer_list, _, + _, ) = parse_pytorch_model(config, verbose=False) def make_layer_config(layer):