support multi_tower_din_trt #30

Merged · 4 commits · Nov 11, 2024
4 changes: 2 additions & 2 deletions requirements/runtime.txt
@@ -7,8 +7,8 @@
 graphlearn @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/graphlearn-1.3.0-cp311-cp311-linux_x86_64.whl ; python_version=="3.11"
 graphlearn @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/graphlearn-1.3.0-cp310-cp310-linux_x86_64.whl ; python_version=="3.10"
 grpcio-tools<1.63.0
 pandas
-pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.4-cp311-cp311-linux_x86_64.whl ; python_version=="3.11"
-pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.4-cp310-cp310-linux_x86_64.whl ; python_version=="3.10"
+pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.5-cp311-cp311-linux_x86_64.whl ; python_version=="3.11"
+pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.5-cp310-cp310-linux_x86_64.whl ; python_version=="3.10"
 scikit-learn
 tensorboard
 torch==2.5.0
231 changes: 231 additions & 0 deletions tzrec/acc/trt_utils.py
@@ -0,0 +1,231 @@
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import os
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch

# cpu image has no torch_tensorrt
try:
import torch_tensorrt
except Exception:
pass
from torch import nn
from torch.profiler import ProfilerActivity, profile, record_function
from torchrec.fx import symbolic_trace

from tzrec.acc.utils import is_debug_trt
from tzrec.models.model import ScriptWrapper
from tzrec.utils.logging_util import logger


def trt_convert(
module: nn.Module,
# pyre-ignore [2]
inputs: Optional[Sequence[Sequence[Any]]],
# pyre-ignore [2]
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]],
) -> torch.fx.GraphModule:
"""Convert model use trt.

Args:
module (nn.Module): Source module
        inputs (List[Union[torch_tensorrt.Input, torch.Tensor]]): example inputs.
        dynamic_shapes: dynamic shape specification forwarded to torch.export.export.

Returns:
        torch.fx.GraphModule: compiled FX module that executes via TensorRT when run.
"""
logger.info("trt convert start...")
# torch_tensorrt.runtime.set_multi_device_safe_mode(True)
enabled_precisions = {torch.float32}

    # Workspace size for TensorRT (2 << 30 bytes = 2 GiB)
workspace_size = 2 << 30

    # Minimum number of operators per TRT engine block
    # (a lower value allows more graph segmentation)
    min_block_size = 2

exp_program = torch.export.export(module, (inputs,), dynamic_shapes=dynamic_shapes)
    # scripted models do not support dict inputs, so inputs is a list of tensors
    # compile once; emit graph-level debug logs only when DEBUG_TRT is set
    log_ctx = (
        torch_tensorrt.logging.graphs() if is_debug_trt() else contextlib.nullcontext()
    )
    with log_ctx:
        optimized_model = torch_tensorrt.dynamo.compile(
            exp_program,
            inputs,
            # pyre-ignore [6]
            enabled_precisions=enabled_precisions,
            workspace_size=workspace_size,
            min_block_size=min_block_size,
            hardware_compatible=True,
            assume_dynamic_shape_support=True,
            # truncate_long_and_double=True,
            allow_shape_tensors=True,
        )

logger.info("trt convert end")
return optimized_model


class ScriptWrapperList(ScriptWrapper):
"""Model inference wrapper for jit.script.

    ScriptWrapperList is used when tracing ScriptWrapperTRT(emb_trace_gpu,
    dense_layer_trt); it returns a list of Tensors instead of a dict of Tensors.
"""

def __init__(self, module: nn.Module) -> None:
super().__init__(module)

# pyre-ignore [15]
def forward(
self,
data: Dict[str, torch.Tensor],
# pyre-ignore [9]
device: torch.device = "cpu",
) -> List[torch.Tensor]:
"""Predict the model.

Args:
data (dict): a dict of input data for Batch.
device (torch.device): inference device.

        Return:
            predictions (List[torch.Tensor]): a list of predicted results.
"""
batch = self.get_batch(data, device)
return self.model.predict(batch)


class ScriptWrapperTRT(nn.Module):
"""Model inference wrapper for jit.script."""

def __init__(self, embedding_group: nn.Module, dense: nn.Module) -> None:
super().__init__()
self.embedding_group = embedding_group
self.dense = dense

def forward(
self,
data: Dict[str, torch.Tensor],
# pyre-ignore [9]
device: torch.device = "cuda:0",
) -> Dict[str, torch.Tensor]:
"""Predict the model.

Args:
data (dict): a dict of input data for Batch.
device (torch.device): inference device.

Return:
            predictions (dict): a dict of predicted results.
"""
grouped_features = self.embedding_group(data, device)
y = self.dense(grouped_features)
return y


def export_model_trt(
model: nn.Module, data: Dict[str, torch.Tensor], save_dir: str
) -> None:
"""Export trt model.

Args:
model (nn.Module): the model
data (Dict[str, torch.Tensor]): the test data
save_dir (str): model save dir
"""
    # ScriptWrapperList lets us trace ScriptWrapperTRT(emb_trace_gpu, dense_layer_trt)
emb_trace_gpu = ScriptWrapperList(model.model.embedding_group)
emb_res = emb_trace_gpu(data, "cuda:0")
emb_trace_gpu = symbolic_trace(emb_trace_gpu)
emb_trace_gpu = torch.jit.script(emb_trace_gpu)

# dynamic shapes
batch = torch.export.Dim("batch", min=1, max=10000)
dynamic_shapes_list = []
values_list_cuda = []
for i, value in enumerate(emb_res):
v = value.detach().to("cuda:0")
values_list_cuda.append(v)
dict_dy = {0: batch}
if v.dim() == 3:
dict_dy[1] = torch.export.Dim("seq_len" + str(i), min=1, max=10000)
dynamic_shapes_list.append(dict_dy)

# convert dense
dense = model.model.dense
logger.info("dense res: %s", dense(values_list_cuda))
dense_layer = symbolic_trace(dense)
dynamic_shapes = {"args": dynamic_shapes_list}
dense_layer_trt = trt_convert(dense_layer, values_list_cuda, dynamic_shapes)
dict_res = dense_layer_trt(values_list_cuda)
logger.info("dense trt res: %s", dict_res)

# save combined_model
combined_model = ScriptWrapperTRT(emb_trace_gpu, dense_layer_trt)
result = combined_model(data, "cuda:0")
logger.info("combined model result: %s", result)
# combined_model = symbolic_trace(combined_model)
combined_model = torch.jit.trace(
combined_model, example_inputs=(data,), strict=False
)
scripted_model = torch.jit.script(combined_model)
# pyre-ignore [16]
scripted_model.save(os.path.join(save_dir, "scripted_model.pt"))

if is_debug_trt():
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
) as prof:
with record_function("model_inference_dense"):
dict_res = dense(values_list_cuda)
logger.info(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100))

with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
) as prof:
with record_function("model_inference_dense_trt"):
dict_res = dense_layer_trt(values_list_cuda)
logger.info(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100))

model_gpu_combined = torch.jit.load(
os.path.join(save_dir, "scripted_model.pt"), map_location="cuda:0"
)
res = model_gpu_combined(data)
logger.info("final res: %s", res)
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
) as prof:
with record_function("model_inference_combined_trt"):
dict_res = model_gpu_combined(data)
logger.info(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100))

logger.info("trt convert success")
60 changes: 59 additions & 1 deletion tzrec/acc/utils.py
@@ -9,8 +9,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright (c) Alibaba, Inc. and its affiliates.
import json
import os
-from typing import Dict
+from typing import Dict, List

import torch

@@ -34,6 +36,35 @@ def is_input_tile_emb() -> bool:
return False


def is_trt() -> bool:
"""Judge is trt or not."""
is_trt = os.environ.get("ENABLE_TRT")
if is_trt and is_trt[0] == "1":
return True
return False


def is_trt_predict(model_path: str) -> bool:
"""Judge is trt or not in predict."""
with open(model_path + "/model_acc.json", "r", encoding="utf-8") as file:
data = json.load(file)
is_trt = data.get("ENABLE_TRT")
if is_trt and is_trt[0] == "1":
return True
return False


def is_debug_trt() -> bool:
"""Judge is debug trt or not.

Embedding Split user/item
"""
is_trt = os.environ.get("DEBUG_TRT")
if is_trt and is_trt[0] == "1":
return True
return False


def is_quant() -> bool:
"""Judge is quant or not."""
is_quant = os.environ.get("QUANT_EMB")
@@ -116,4 +147,31 @@ def export_acc_config() -> Dict[str, str]:
acc_config["INPUT_TILE"] = os.environ["INPUT_TILE"]
if "QUANT_EMB" in os.environ:
acc_config["QUANT_EMB"] = os.environ["QUANT_EMB"]
if "ENABLE_TRT" in os.environ:
acc_config["ENABLE_TRT"] = os.environ["ENABLE_TRT"]
return acc_config


def dicts_are_equal(
dict1: Dict[str, torch.Tensor], dict2: Dict[str, torch.Tensor]
) -> bool:
"""Compare dict[str,torch.Tensor]."""
if dict1.keys() != dict2.keys():
return False

for key in dict1:
if not torch.equal(dict1[key], dict2[key]):
return False

return True


def lists_are_equal(list1: List[torch.Tensor], list2: List[torch.Tensor]) -> bool:
"""Compare List[torch.Tensor]."""
if len(list1) != len(list2):
return False

for i in range(len(list1)):
if not torch.equal(list1[i], list2[i]):
return False
return True
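
Taken together, these helpers support a simple gate-and-validate flow around a TRT export. A rough sketch, assuming (as is_trt_predict implies) that the export pipeline serializes export_acc_config() into model_acc.json under the model directory; the tensors below are placeholders rather than real model outputs:

# Hypothetical validation sketch using the env-var gates and the
# equality helpers from tzrec.acc.utils.
import os

import torch

from tzrec.acc.utils import export_acc_config, is_trt, lists_are_equal

os.environ["ENABLE_TRT"] = "1"
assert is_trt()
assert export_acc_config().get("ENABLE_TRT") == "1"

# compare eager outputs against TRT outputs on the same batch
eager_out = [torch.tensor([0.1, 0.9])]
trt_out = [torch.tensor([0.1, 0.9])]
assert lists_are_equal(eager_out, trt_out)

Note that torch.equal demands bit-exact matches; real TRT engines often diverge from eager execution in the last few ulps, so a tolerance-based comparison (e.g. torch.allclose) may be the more practical check outside of unit tests.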