
Converter for pytorch models that have multiple inputs #29

Closed · bayleef1 opened this issue Jan 19, 2022 · 4 comments
Labels: question (Further information is requested)

bayleef1 commented Jan 19, 2022

Hi, thanks for your excellent work. I've successfully converted a PyTorch model with a single input to a TFLite model.
However, the converter does not seem to support PyTorch models with multiple inputs yet. Any plans for it?

peterjc123 (Collaborator) commented Jan 19, 2022

@bayleef1 We do support models with multiple inputs. You just need to pass the example inputs as a tuple, e.g.

import torch
from tinynn.converter import TFLiteConverter

# one dummy tensor per model input, wrapped in a tuple
input_1 = torch.randn(1, 3, 224, 224)
input_2 = torch.randn(1, 3, 224, 224)
converter = TFLiteConverter(model, (input_1, input_2), ...)
converter.convert()
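
For context, here is a minimal sketch of a module with two inputs that the snippet above could convert (the TwoInputModel name and layer sizes are illustrative, not from the original thread):

import torch
import torch.nn as nn

class TwoInputModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x, y):
        # each positional argument of forward() maps to one
        # element of the dummy-input tuple given to the converter
        return self.conv(x) + self.conv(y)

model = TwoInputModel().eval()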

bayleef1 (Author) commented

OK, I've successfully converted a PyTorch model with multiple inputs and outputs, as long as the LSTM state parameters are excluded.
But when I tried to carry the LSTM state across calls by adding parameters to the signature, 'forward(self, lstm_h, lstm_c, ...)', and calling 'output, (lstm_h, lstm_c) = self.lstm(input, (lstm_h, lstm_c))', I got this error:

assert len(missing_outputs) == 0, f'Some output nodes are missing: {missing_outputs}'
AssertionError: Some output nodes are missing: ['hx1.1', 'hx2.1']

I wonder whether the converter can handle the case where both the model inputs and the model outputs include the LSTM state.
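
Roughly, a minimal version of the setup (the StatefulLSTM name and layer sizes are placeholders):

import torch
import torch.nn as nn

class StatefulLSTM(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=40, hidden_size=64, batch_first=True)

    def forward(self, input, lstm_h, lstm_c):
        # the hidden/cell states appear as both graph inputs and
        # graph outputs, which is what trips the assertion above
        output, (lstm_h, lstm_c) = self.lstm(input, (lstm_h, lstm_c))
        return output, lstm_h, lstm_c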

peterjc123 (Collaborator) commented Jan 21, 2022

@bayleef1 Yes, it's due to a limitation of TFLite: you cannot read or write variables through the regular graph inputs and outputs. So you need to export the model without adding the states to the inputs or outputs; you can still access them through the variable mechanism.
For example, consider the following subgraph.

(image: example subgraph)

You are free to access the state tensors through set_tensor and get_tensor, provided you've found their locations.

import tensorflow as tf
import numpy as np

interpreter = tf.lite.Interpreter(model_path='xxx.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
tensor_details = interpreter.get_tensor_details()
tensor_shapes = {d['index']: d['shape'] for d in tensor_details}
tensor_dtypes = {d['index']: d['dtype'] for d in tensor_details}

# set inputs (cast to the expected dtype; np.random.random yields float64)
for d in input_details:
    interpreter.set_tensor(d['index'], np.random.random(d['shape']).astype(d['dtype']))

# set states (the tensor indices depend on your model; see below)
state_tensors = [3, 21]
for i in state_tensors:
    interpreter.set_tensor(i, np.zeros(tensor_shapes[i], dtype=tensor_dtypes[i]))

# invoke
interpreter.invoke()

# get outputs
outputs = [interpreter.get_tensor(d['index']) for d in output_details]

# get states (read back the updated values so they can be fed into the next call)
states = [interpreter.get_tensor(i) for i in state_tensors]
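
One way to find the state tensor indices is to search the tensor details by name. This is just a sketch: the 'hx' substring comes from the error message above, so confirm the actual tensor names first (e.g. by opening the model in Netron).

state_tensors = [d['index'] for d in tensor_details if 'hx' in d['name']]
print(state_tensors)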
