Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding ViT to torchvision/models #4594

Merged
merged 39 commits into from
Nov 27, 2021
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
fbd0024
[vit] Adding ViT to torchvision/models
yiwen-song Oct 12, 2021
7521ffe
adding pre-logits layer + resolving comments
yiwen-song Oct 20, 2021
7e63685
Merge branch 'pytorch:main' into main
yiwen-song Oct 22, 2021
2dd878a
Merge branch 'pytorch:main' into main
yiwen-song Oct 25, 2021
53b6967
Fix the model attribute bug
yiwen-song Oct 26, 2021
fe248f0
Merge branch 'main' of https://github.com/sallysyw/vision into main
yiwen-song Oct 26, 2021
a84361a
Change version to arch
yiwen-song Oct 26, 2021
f981519
Merge branch 'pytorch:main' into main
yiwen-song Oct 26, 2021
9d2ef95
Merge branch 'main' into main
datumbox Oct 27, 2021
0aaac5b
Merge branch 'pytorch:main' into main
yiwen-song Nov 1, 2021
1cf8b92
Merge branch 'pytorch:main' into main
yiwen-song Nov 5, 2021
c2f3826
fix failing unittests
yiwen-song Nov 6, 2021
35c1d22
remove useless prints
yiwen-song Nov 6, 2021
1aff5cd
Merge branch 'pytorch:main' into main
yiwen-song Nov 13, 2021
568c560
reduce input size to fix unittests
yiwen-song Nov 15, 2021
8e71e4b
Increase windows-cpu executor to 2xlarge
yiwen-song Nov 16, 2021
f9860ec
Use `batch_first=True` and remove classifier
yiwen-song Nov 17, 2021
4d7d7fe
Merge branch 'pytorch:main' into main
yiwen-song Nov 17, 2021
b795e85
Change resource_class back to xlarge
yiwen-song Nov 17, 2021
ff64591
Remove vit_h_14
yiwen-song Nov 17, 2021
bd3a747
Remove vit_h_14 from __all__
yiwen-song Nov 17, 2021
8f88592
Move vision_transformer.py into prototype
yiwen-song Nov 19, 2021
22025ac
Fix formatting issue
yiwen-song Nov 19, 2021
26bc529
remove arch in builder
yiwen-song Nov 19, 2021
cc22238
Fix type err in model builder
yiwen-song Nov 19, 2021
1d4e2aa
Merge branch 'main' into main
yiwen-song Nov 19, 2021
091bf6b
Merge branch 'pytorch:main' into main
yiwen-song Nov 23, 2021
41edd15
address comments and trigger unittests
yiwen-song Nov 24, 2021
48ce69e
remove the prototype import in torchvision.models
yiwen-song Nov 24, 2021
0caf745
Merge branch 'main' into main
yiwen-song Nov 24, 2021
3a6b445
Adding vit back to models to trigger CircleCI test
yiwen-song Nov 24, 2021
72c5af7
fix test_jit_forward_backward
yiwen-song Nov 24, 2021
aae308c
Move all to prototype.
datumbox Nov 25, 2021
7b1e59e
Merge branch 'main' into main
datumbox Nov 25, 2021
717b6af
Merge branch 'main' into main
datumbox Nov 25, 2021
f0df7f8
Adopt new helper methods and fix prototype tests.
datumbox Nov 25, 2021
3807b23
Remove unused import.
datumbox Nov 25, 2021
eabec95
Merge branch 'main' into main
yiwen-song Nov 26, 2021
40b566b
Merge branch 'main' into main
yiwen-song Nov 27, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added test/expect/ModelTester.test_vit_b_16_expect.pkl
Binary file not shown.
Binary file added test/expect/ModelTester.test_vit_b_32_expect.pkl
Binary file not shown.
Binary file added test/expect/ModelTester.test_vit_l_16_expect.pkl
Binary file not shown.
Binary file added test/expect/ModelTester.test_vit_l_32_expect.pkl
Binary file not shown.
23 changes: 21 additions & 2 deletions test/test_backbone_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import random
from itertools import chain
from typing import Mapping, Sequence

import pytest
import torch
Expand Down Expand Up @@ -89,7 +90,16 @@ def _create_feature_extractor(self, *args, **kwargs):

def _get_return_nodes(self, model):
set_rng_seed(0)
exclude_nodes_filter = ["getitem", "floordiv", "size", "chunk"]
exclude_nodes_filter = [
"getitem",
"floordiv",
"size",
"chunk",
"_assert",
"eq",
"dim",
"getattr",
]
train_nodes, eval_nodes = get_graph_node_names(
model, tracer_kwargs={"leaf_modules": self.leaf_modules}, suppress_diff_warning=True
)
Expand Down Expand Up @@ -144,7 +154,16 @@ def test_forward_backward(self, model_name):
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
)
out = model(self.inp)
sum(o.mean() for o in out.values()).backward()
out_agg = 0
for node_out in out.values():
if isinstance(node_out, Sequence):
out_agg += sum(o.mean() for o in node_out if o is not None)
elif isinstance(node_out, Mapping):
out_agg += sum(o.mean() for o in node_out.values() if o is not None)
else:
# Assume that the only other alternative at this point is a Tensor
out_agg += node_out.mean()
out_agg.backward()

def test_feature_extraction_methods_equivalence(self):
model = models.resnet18(**self.model_defaults).eval()
Expand Down
3 changes: 2 additions & 1 deletion test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,7 @@ def test_classification_model(model_fn, dev):
}
model_name = model_fn.__name__
kwargs = {**defaults, **_model_params.get(model_name, {})}
num_classes = kwargs.get("num_classes")
input_shape = kwargs.pop("input_shape")

model = model_fn(**kwargs)
Expand All @@ -515,7 +516,7 @@ def test_classification_model(model_fn, dev):
x = torch.rand(input_shape).to(device=dev)
out = model(x)
_assert_expected(out.cpu(), model_name, prec=0.1)
assert out.shape[-1] == 50
assert out.shape[-1] == num_classes
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
_check_fx_compatible(model, x)

Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .shufflenetv2 import *
from .squeezenet import *
from .vgg import *
from .vision_transformer import *
from . import detection
from . import quantization
from . import segmentation
Expand Down
Loading