forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Import & Cache Mechanism (apache#26)
* Import \& Cache Mechanism * unused * use None as default cache_dir * invalid name + more layout * fix bert * remove
- Loading branch information
Showing
2 changed files
with
240 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
import multiprocessing | ||
import os | ||
import pickle | ||
from typing import Dict, List, Optional, Tuple | ||
|
||
import tvm | ||
import tvm.relay.testing | ||
from tvm import relay | ||
from tvm.ir import IRModule | ||
from tvm.runtime import NDArray, load_param_dict, save_param_dict | ||
|
||
SUPPORTED = [ | ||
# TorchVision | ||
"resnet_18", | ||
"resnet_50", | ||
"mobilenet_v2", | ||
"mobilenet_v3", | ||
"wide_resnet_50", | ||
"resnext_50", | ||
"resnet3d_18", | ||
"inception_v3", | ||
"densenet_121", | ||
"vgg_16", | ||
# Transformer | ||
"bert_tiny", | ||
"bert_base", | ||
"bert_medium", | ||
"bert_large", | ||
# Relay testing | ||
"dcgan", | ||
] | ||
|
||
|
||
def _get_network( | ||
args: Tuple[str, List[int]] | ||
) -> Tuple[IRModule, bytearray, Tuple[str, List[int], str]]: | ||
name: str | ||
input_shape: List[int] | ||
name, input_shape = args | ||
|
||
mod: IRModule | ||
|
||
if name in [ | ||
"resnet_18", | ||
"resnet_50", | ||
"wide_resnet_50", | ||
"resnext_50", | ||
"mobilenet_v2", | ||
"mobilenet_v3", | ||
"inception_v3", | ||
"densenet_121", | ||
"resnet3d_18", | ||
"vgg_16", | ||
]: | ||
# torchvision>=0.9.0 | ||
import torch # type: ignore | ||
import torchvision.models as models # type: ignore | ||
|
||
if name in ["resnet_18", "resnet_50"]: | ||
model = getattr(models, name.replace("_", ""))(pretrained=False) | ||
elif name == "wide_resnet_50": | ||
model = getattr(models, "wide_resnet50_2")(pretrained=False) | ||
elif name == "resnext_50": | ||
model = getattr(models, "resnext50_32x4d")(pretrained=False) | ||
elif name == "mobilenet_v2": | ||
model = getattr(models, name)(pretrained=False) | ||
elif name == "mobilenet_v3": | ||
model = getattr(models, name + "_large")(pretrained=False) | ||
elif name == "inception_v3": | ||
model = getattr(models, name)(pretrained=False, aux_logits=False) | ||
elif name == "densenet_121": | ||
model = getattr(models, name.replace("_", ""))(pretrained=False) | ||
elif name == "resnet3d_18": | ||
model = models.video.r3d_18(pretrained=False) | ||
elif name == "vgg_16": | ||
model = getattr(models, name.replace("_", ""))(pretrained=False) | ||
|
||
dtype = "float32" | ||
input_data = torch.randn(input_shape).type( | ||
{ | ||
"float32": torch.float32, | ||
}[dtype] | ||
) | ||
scripted_model = torch.jit.trace(model, input_data).eval() | ||
input_name = "input0" | ||
shape_list = [(input_name, input_shape)] | ||
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) | ||
with tvm.transform.PassContext(opt_level=3): | ||
mod = tvm.transform.Sequential( | ||
[ | ||
relay.transform.RemoveUnusedFunctions(), | ||
relay.transform.ConvertLayout( | ||
{ | ||
"nn.conv2d": ["NHWC", "default"], | ||
"nn.conv3d": ["NDHWC", "default"], | ||
"nn.max_pool2d": ["NHWC", "default"], | ||
"nn.avg_pool2d": ["NHWC", "default"], | ||
} | ||
), | ||
] | ||
)(mod) | ||
inputs = (input_name, input_shape, dtype) | ||
elif name in ["bert_tiny", "bert_base", "bert_medium", "bert_large"]: | ||
os.environ["TOKENIZERS_PARALLELISM"] = "false" | ||
# pip3 install transformers==3.5 torch==1.7 | ||
import torch # type: ignore | ||
import transformers # type: ignore | ||
|
||
config_dict = { | ||
"bert_tiny": transformers.BertConfig( | ||
num_hidden_layers=6, | ||
hidden_size=512, | ||
intermediate_size=2048, | ||
num_attention_heads=8, | ||
return_dict=False, | ||
), | ||
"bert_base": transformers.BertConfig( | ||
num_hidden_layers=12, | ||
hidden_size=768, | ||
intermediate_size=3072, | ||
num_attention_heads=12, | ||
return_dict=False, | ||
), | ||
"bert_medium": transformers.BertConfig( | ||
num_hidden_layers=12, | ||
hidden_size=1024, | ||
intermediate_size=4096, | ||
num_attention_heads=16, | ||
return_dict=False, | ||
), | ||
"bert_large": transformers.BertConfig( | ||
num_hidden_layers=24, | ||
hidden_size=1024, | ||
intermediate_size=4096, | ||
num_attention_heads=16, | ||
return_dict=False, | ||
), | ||
} | ||
configuration = config_dict[name] | ||
model = transformers.BertModel(configuration) | ||
input_name = "input_ids" | ||
input_dtype = "int64" | ||
A = torch.randint(10000, input_shape) | ||
model.eval() | ||
scripted_model = torch.jit.trace(model, [A], strict=False) | ||
input_name = "input_ids" | ||
shape_list = [(input_name, input_shape)] | ||
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) | ||
mod = relay.transform.FastMath()(mod) | ||
mod = relay.transform.CombineParallelBatchMatmul()(mod) | ||
inputs = (input_name, input_shape, input_dtype) | ||
elif name == "dcgan": | ||
output_shape = input_shape | ||
batch_size = output_shape[0] | ||
oshape = output_shape[1:] | ||
mod, params = relay.testing.dcgan.get_workload( | ||
batch_size=batch_size, | ||
oshape=oshape, | ||
layout="NHWC", | ||
) | ||
inputs = ("data", [100], "float32") | ||
else: | ||
raise ValueError("Invalid name: " + name) | ||
|
||
params_bytearray: bytearray = save_param_dict(params) | ||
return mod, params_bytearray, inputs | ||
|
||
|
||
def get_network( | ||
name: str, | ||
input_shape: List[int], | ||
cache_dir: Optional[str] = None, | ||
) -> Tuple[IRModule, Dict[str, NDArray], Tuple[str, List[int], str]]: | ||
mod: IRModule | ||
params_bytearray: bytearray | ||
params: Dict[str, NDArray] | ||
inputs: Tuple[str, List[int], str] | ||
keyword = f'{name}-{",".join(str(i) for i in input_shape)}.json' | ||
if cache_dir is not None: | ||
path = os.path.join(cache_dir, keyword) | ||
if os.path.exists(path): | ||
print(f"Load cached network file: {path}") | ||
with open(path, "rb") as i_f: | ||
mod, params_bytearray, inputs = pickle.load(i_f) | ||
params = load_param_dict(params_bytearray) | ||
return mod, params, inputs | ||
with multiprocessing.Pool(processes=1) as pool: | ||
result = pool.map(_get_network, [(name, input_shape)]) | ||
((mod, params_bytearray, inputs),) = result | ||
params = load_param_dict(params_bytearray) | ||
if cache_dir is not None: | ||
path = os.path.join(cache_dir, keyword) | ||
with open(path, "wb") as o_f: | ||
pickle.dump((mod, params_bytearray, inputs), o_f) | ||
return mod, params, inputs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from tvm.meta_schedule.testing.e2e import get_network | ||
|
||
|
||
def test_import(): | ||
network_keys = [] | ||
for name in [ | ||
"resnet_18", | ||
"resnet_50", | ||
"mobilenet_v2", | ||
"mobilenet_v3", | ||
"wide_resnet_50", | ||
"resnext_50", | ||
"densenet_121", | ||
]: | ||
for batch_size in [1, 4, 8]: | ||
for image_size in [224, 240, 256]: | ||
network_keys.append((name, [batch_size, 3, image_size, image_size])) | ||
# inception-v3 | ||
for name in ["inception_v3"]: | ||
for batch_size in [1, 2, 4]: | ||
for image_size in [299]: | ||
network_keys.append((name, [batch_size, 3, image_size, image_size])) | ||
# resnet3d | ||
for name in ["resnet3d_18"]: | ||
for batch_size in [1, 2, 4]: | ||
for image_size in [112, 128, 144]: | ||
network_keys.append((name, [batch_size, 3, image_size, image_size, 16])) | ||
# bert | ||
for name in ["bert_tiny", "bert_base", "bert_medium", "bert_large"]: | ||
for batch_size in [1, 2, 4]: | ||
for seq_length in [64, 128, 256]: | ||
network_keys.append((name, [batch_size, seq_length])) | ||
# dcgan | ||
for name in ["dcgan"]: | ||
for batch_size in [1, 4, 8]: | ||
for image_size in [64]: | ||
network_keys.append((name, [batch_size, 3, image_size, image_size])) | ||
|
||
for i, (name, input_shape) in enumerate(network_keys, 1): | ||
print(f"[{i} / {len(network_keys)}] {name}, input_shape = {input_shape}") | ||
get_network(name, input_shape, cache_dir="/tmp/relay/") | ||
|
||
|
||
if __name__ == "__main__": | ||
test_import() |