GeneratorExp aren't supported by torch.jit.script when I try to export a previously trained model 'google/vit-base-patch16-224-in21k'. #15354
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hi @ssriram1978, it seems the problem is that the Python built-in `any()` applied to a generator expression isn't supported by `torch.jit.script`. Here's a minimal example that reproduces the error:

```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        a = any(i for i in range(10))  # generator expression: unsupported by scripting
        return x


model = MyModel()
x = torch.randn(1, 1)
out = model(x)                       # works fine in eager mode
scripted = torch.jit.script(model)   # throws UnsupportedNodeError
```

Is tracing an option for your application? If yes, here's a snippet that does the job:

```python
import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

# Load feature extractor and checkpoint
model_ckpt = "google/vit-base-patch16-224-in21k"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
model = AutoModelForImageClassification.from_pretrained(model_ckpt, torchscript=True)

# Download sample image and feed to original model
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = feature_extractor(images=image, return_tensors="pt")
original_outputs = model(**inputs)

# Trace model and save
traced_model = torch.jit.trace(model, [inputs["pixel_values"]])
torch.jit.save(traced_model, "traced_vit.pt")

# Reload traced model and compare outputs to original one
loaded_model = torch.jit.load("traced_vit.pt")
loaded_model.eval()
traced_outputs = loaded_model(**inputs)

assert np.allclose(
    original_outputs[0].detach().numpy(), traced_outputs[0].detach().numpy()
)
```

You can find more details on tracing in our guide. An alternative to tracing would be to export the model to ONNX.
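To see why tracing sidesteps the error, here is a toy, torch-free sketch (my own illustration, not PyTorch's actual tracer): tracing runs the model once in ordinary Python and records only the operations performed on the inputs, so Python-level constructs such as generator expressions execute at trace time and never reach a compiler.

```python
# Toy illustration of tracing: the generator expression runs once
# in plain Python while the "trace" is recorded; replaying the
# trace later never re-parses any Python source.
trace = []

def traced_forward(x):
    # This generator expression executes normally during tracing;
    # its result is not recorded, only the arithmetic below is.
    _ = any(i > 2 for i in range(5))
    trace.append(("mul", 2))
    return x * 2

out = traced_forward(3)  # run once to record the trace

def replay(x):
    # Replay uses only the recorded ops, no Python parsing involved.
    for op, arg in trace:
        if op == "mul":
            x = x * arg
    return x
```

The trade-off, of course, is that a trace only captures the code path taken for the example input, which is why scripting exists in the first place.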
I think this issue may have been fixed. Is anybody else still running into this issue, or can it be closed?
Environment info
- `transformers` version: 4.15.0

Who can help
Models:
- ViTModel

If the model isn't in the list, ping @LysandreJik who will redirect you to the correct contributor.

Documentation: @sgugger
Information

Model I am using: ViTModel

The problem arises when using `torch.jit.script` to export a previously trained 'google/vit-base-patch16-224-in21k' model:

```python
model_x = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=len(label2id),
    label2id=label2id,
    id2label=id2label,
)
model_scripted = torch.jit.script(model_x)  # Export to TorchScript
```
```
UnsupportedNodeError                      Traceback (most recent call last)
in ()
      6     id2label=id2label
      7 )
----> 8 model_scripted = torch.jit.script(model_x) # Export to TorchScript
      9 model_scripted.save('model_scripted.pt') # Save

14 frames
/usr/local/lib/python3.7/dist-packages/torch/jit/frontend.py in __call__(self, ctx, node)
    284         method = getattr(self, 'build_' + node.__class__.__name__, None)
    285         if method is None:
--> 286             raise UnsupportedNodeError(ctx, node)
    287         return method(ctx, node)
    288

UnsupportedNodeError: GeneratorExp aren't supported:
  File "/usr/local/lib/python3.7/dist-packages/transformers/modeling_utils.py", line 987
        activations".
        """
        return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
               ~ <--- HERE
```
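One possible workaround on the library side (an assumption on my part, not a confirmed transformers fix) would be to rewrite the offending generator expression as a list comprehension, since `torch.jit.script` can handle list comprehensions; whether the full line then scripts cleanly still depends on TorchScript's support for the other constructs involved. The two forms produce the same result for `any()`, the list form merely losing short-circuit evaluation. A plain-Python check, with a stand-in `Module` class in place of `nn.Module`:

```python
# Compare the TorchScript-incompatible generator-expression form of
# the line in modeling_utils.py against a list-comprehension rewrite,
# using a minimal stand-in for nn.Module (no torch required).
class Module:
    def __init__(self, gradient_checkpointing=False):
        self.gradient_checkpointing = gradient_checkpointing

modules = [Module(), Module(gradient_checkpointing=True), object()]

# Original form (generator expression, short-circuits):
gen = any(
    hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing
    for m in modules
)

# Rewrite (list comprehension, evaluates every element first):
lst = any(
    [hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing
     for m in modules]
)
```

Both forms evaluate to the same boolean for any sequence of modules.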
To reproduce
Steps to reproduce the behavior:
Expected behavior