Skip to content

Commit

Permalink
[fbsync] Adding ViT to torchvision/models (#4594)
Browse files Browse the repository at this point in the history
Summary:
* [vit] Adding ViT to torchvision/models

* adding pre-logits layer + resolving comments

* Fix the model attribute bug

* Change version to arch

* fix failing unittests

* remove useless prints

* reduce input size to fix unittests

* Increase windows-cpu executor to 2xlarge

* Use `batch_first=True` and remove classifier

* Change resource_class back to xlarge

* Remove vit_h_14

* Remove vit_h_14 from __all__

* Move vision_transformer.py into prototype

* Fix formatting issue

* remove arch in builder

* Fix type err in model builder

* address comments and trigger unittests

* remove the prototype import in torchvision.models

* Adding vit back to models to trigger CircleCI test

* fix test_jit_forward_backward

* Move all to prototype.

* Adopt new helper methods and fix prototype tests.

* Remove unused import.

Reviewed By: NicolasHug

Differential Revision: D32694316

fbshipit-source-id: fa2867555fb7ae65f8dab537517386f6694585a2

Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Vasilis Vryniotis <[email protected]>
  • Loading branch information
3 people authored and facebook-github-bot committed Nov 30, 2021
1 parent 9e0f868 commit 0b66abe
Show file tree
Hide file tree
Showing 9 changed files with 438 additions and 6 deletions.
Binary file added test/expect/ModelTester.test_vit_b_16_expect.pkl
Binary file not shown.
Binary file added test/expect/ModelTester.test_vit_b_32_expect.pkl
Binary file not shown.
Binary file added test/expect/ModelTester.test_vit_l_16_expect.pkl
Binary file not shown.
Binary file added test/expect/ModelTester.test_vit_l_32_expect.pkl
Binary file not shown.
34 changes: 31 additions & 3 deletions test/test_backbone_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import random
from itertools import chain
from typing import Mapping, Sequence

import pytest
import torch
Expand Down Expand Up @@ -89,7 +90,16 @@ def _create_feature_extractor(self, *args, **kwargs):

def _get_return_nodes(self, model):
set_rng_seed(0)
exclude_nodes_filter = ["getitem", "floordiv", "size", "chunk"]
exclude_nodes_filter = [
"getitem",
"floordiv",
"size",
"chunk",
"_assert",
"eq",
"dim",
"getattr",
]
train_nodes, eval_nodes = get_graph_node_names(
model, tracer_kwargs={"leaf_modules": self.leaf_modules}, suppress_diff_warning=True
)
Expand Down Expand Up @@ -144,7 +154,16 @@ def test_forward_backward(self, model_name):
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
)
out = model(self.inp)
sum(o.mean() for o in out.values()).backward()
out_agg = 0
for node_out in out.values():
if isinstance(node_out, Sequence):
out_agg += sum(o.mean() for o in node_out if o is not None)
elif isinstance(node_out, Mapping):
out_agg += sum(o.mean() for o in node_out.values() if o is not None)
else:
# Assume that the only other alternative at this point is a Tensor
out_agg += node_out.mean()
out_agg.backward()

def test_feature_extraction_methods_equivalence(self):
model = models.resnet18(**self.model_defaults).eval()
Expand Down Expand Up @@ -176,7 +195,16 @@ def test_jit_forward_backward(self, model_name):
)
model = torch.jit.script(model)
fgn_out = model(self.inp)
sum(o.mean() for o in fgn_out.values()).backward()
out_agg = 0
for node_out in fgn_out.values():
if isinstance(node_out, Sequence):
out_agg += sum(o.mean() for o in node_out if o is not None)
elif isinstance(node_out, Mapping):
out_agg += sum(o.mean() for o in node_out.values() if o is not None)
else:
# Assume that the only other alternative at this point is a Tensor
out_agg += node_out.mean()
out_agg.backward()

def test_train_eval(self):
class TestModel(torch.nn.Module):
Expand Down
3 changes: 2 additions & 1 deletion test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,7 @@ def test_classification_model(model_fn, dev):
}
model_name = model_fn.__name__
kwargs = {**defaults, **_model_params.get(model_name, {})}
num_classes = kwargs.get("num_classes")
input_shape = kwargs.pop("input_shape")

model = model_fn(**kwargs)
Expand All @@ -515,7 +516,7 @@ def test_classification_model(model_fn, dev):
x = torch.rand(input_shape).to(device=dev)
out = model(x)
_assert_expected(out.cpu(), model_name, prec=0.1)
assert out.shape[-1] == 50
assert out.shape[-1] == num_classes
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
_check_fx_compatible(model, x)

Expand Down
7 changes: 5 additions & 2 deletions test/test_prototype_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,11 @@ def test_old_vs_new_factory(model_fn, module_name, dev):
x = [x]

# compare with new model builder parameterized in the old fashion way
model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev)
model_new = _build_model(model_fn, **kwargs).to(device=dev)
try:
model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev)
model_new = _build_model(model_fn, **kwargs).to(device=dev)
except ModuleNotFoundError:
pytest.skip(f"Model '{model_name}' not available in both modules.")
torch.testing.assert_close(model_new(x), model_old(x), rtol=0.0, atol=0.0, check_dtype=False)


Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .shufflenetv2 import *
from .squeezenet import *
from .vgg import *
from .vision_transformer import *
from . import detection
from . import quantization
from . import segmentation
Expand Down
Loading

0 comments on commit 0b66abe

Please sign in to comment.