
torch.jit.script raises "GeneratorExp aren't supported" when I try to export a previously trained model 'google/vit-base-patch16-224-in21k'. #15354

ssriram1978 opened this issue Jan 26, 2022 · 4 comments

Environment info

  • transformers version: 4.15.0
  • Platform: Linux-5.4.144+-x86_64-with-Ubuntu-18.04-bionic
  • Python version: 3.7.12
  • PyTorch version (GPU?): 1.10.0+cu111 (True)
  • Tensorflow version (GPU?): 2.7.0 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help

Models:

ViTModel

If the model isn't in the list, ping @LysandreJik who will redirect you to the correct contributor.

Library:

Documentation: @sgugger


Information

torch.jit.script raises "GeneratorExp aren't supported" when I try to export a previously trained model 'google/vit-base-patch16-224-in21k'.

Model I am using (ViTModel):

The problem arises when using:

  • my own modified scripts: (give details below)

from transformers import ViTForImageClassification
import torch

# label2id / id2label are my own mappings for the fine-tuning labels
model_x = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=len(label2id),
    label2id=label2id,
    id2label=id2label,
)
model_scripted = torch.jit.script(model_x)  # Export to TorchScript


UnsupportedNodeError                      Traceback (most recent call last)
<ipython-input> in <module>()
      6     id2label=id2label
      7 )
----> 8 model_scripted = torch.jit.script(model_x)  # Export to TorchScript
      9 model_scripted.save('model_scripted.pt')  # Save

14 frames
/usr/local/lib/python3.7/dist-packages/torch/jit/frontend.py in __call__(self, ctx, node)
    284         method = getattr(self, 'build_' + node.__class__.__name__, None)
    285         if method is None:
--> 286             raise UnsupportedNodeError(ctx, node)
    287         return method(ctx, node)
    288

UnsupportedNodeError: GeneratorExp aren't supported:
File "/usr/local/lib/python3.7/dist-packages/transformers/modeling_utils.py", line 987
activations".
"""
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
~ <--- HERE

To reproduce

Steps to reproduce the behavior:

  1. from transformers import ViTForImageClassification
  2. Instantiate the pretrained checkpoint 'google/vit-base-patch16-224-in21k' using the ViTForImageClassification.from_pretrained() API.
  3. Invoke torch.jit.script(model_x) and you will see the error.

Expected behavior

NielsRogge (Contributor) commented Jan 28, 2022

Hi,

The Vision Transformer isn't currently supported by TorchScript (its TorchScript test is disabled here).

cc'ing @lewtun who can perhaps assist here to add support for it.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

lewtun (Member) commented Mar 7, 2022

Hi @ssriram1978, it seems the problem is that TorchScript does not support Python's any() when it is applied to a generator expression. Here's a simple example that reproduces the error:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, x):
        a = any(i for i in range(10))
        return x


model = MyModel()
x = torch.randn(1, 1)
out = model(x)

scripted = torch.jit.script(model)  # Throws UnsupportedNodeError
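
For what it's worth, the check becomes scriptable once the generator expression is unrolled into an explicit loop, which TorchScript can compile. A minimal sketch of the same toy model rewritten that way (MyScriptableModel is just an illustrative name):

import torch
import torch.nn as nn

class MyScriptableModel(nn.Module):
    def forward(self, x):
        # Equivalent to any(i for i in range(10)), but written as an
        # explicit loop so that TorchScript can compile it.
        a = False
        for i in range(10):
            if i != 0:
                a = True
                break
        return x

scripted = torch.jit.script(MyScriptableModel())  # No UnsupportedNodeError

A proper fix inside transformers would presumably need a similar rewrite of the any(...) generator expression in modeling_utils.py.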

Is tracing an option for your application? If yes, here's a snippet that does the job:

import numpy as np
import requests
import torch
from PIL import Image

from transformers import AutoFeatureExtractor, AutoModelForImageClassification

# Load feature extractor and checkpoint
model_ckpt = "google/vit-base-patch16-224-in21k"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
model = AutoModelForImageClassification.from_pretrained(model_ckpt, torchscript=True)
# Download sample image and feed to original model
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = feature_extractor(images=image, return_tensors="pt")
original_outputs = model(**inputs)
# Trace model and save
traced_model = torch.jit.trace(model, [inputs["pixel_values"]])
torch.jit.save(traced_model, "traced_vit.pt")
# Reload traced model and compare outputs to original one
loaded_model = torch.jit.load("traced_vit.pt")
loaded_model.eval()
traced_outputs = loaded_model(**inputs)
assert np.allclose(
    original_outputs[0].detach().numpy(), traced_outputs[0].detach().numpy()
)
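
One general caveat with tracing (not specific to ViT): the trace only records the operations executed for the example input, so it's worth sanity-checking the saved artifact on a fresh input of the same shape. A minimal sketch, reusing the traced_vit.pt file saved above:

import torch

# Reload the traced model and run it on a new random input
loaded_model = torch.jit.load("traced_vit.pt")
loaded_model.eval()

new_pixel_values = torch.randn(1, 3, 224, 224)  # same shape as the example input
with torch.no_grad():
    logits = loaded_model(new_pixel_values)[0]
print(logits.shape)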

You can find more details on tracing in our guide.

An alternative to tracing would be to export the model to ONNX; once torch v1.11 is released, you'll be able to do this via #15658.
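
If you need ONNX before then, calling torch.onnx.export directly may already work. A hedged sketch (this is not the supported transformers.onnx route that #15658 adds; it assumes the checkpoint's fixed 3x224x224 input resolution, and the input/output names are illustrative):

import torch
from transformers import AutoModelForImageClassification

# torchscript=True makes the model return plain tuples, which the ONNX
# exporter traces through more easily than dict outputs.
model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k", torchscript=True
)
model.eval()

dummy_pixel_values = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model,
    dummy_pixel_values,
    "vit.onnx",
    input_names=["pixel_values"],
    output_names=["logits"],
    dynamic_axes={"pixel_values": {0: "batch"}, "logits": {0: "batch"}},
    opset_version=12,
)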

@evanwong1020

I think this issue may have been fixed. Is anybody else running into this issue still, or can it be closed?
