From e1d1033131114dc2634e664d009e061d900a9554 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Thu, 30 Nov 2023 18:32:36 +0800 Subject: [PATCH] [ORTModule] Remove Unused Arguments from Generated Triton Code (#18636) This PR: - Remove unused arguments from generated triton code, - Remove unnecessary mask for symbolic shape case from generated triton code. - Add doc for usage of ORTMODULE_TRITON_CONFIG_FILE. --- docs/ORTModule_Training_Guidelines.md | 24 ++++++++++++ .../python/training/ort_triton/_codegen.py | 4 +- .../python/training/ort_triton/_ir.py | 39 +++++++++++++------ 3 files changed, 53 insertions(+), 14 deletions(-) diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md index 7fa89cca381d9..d3ec61e86779b 100644 --- a/docs/ORTModule_Training_Guidelines.md +++ b/docs/ORTModule_Training_Guidelines.md @@ -379,6 +379,30 @@ Check [FP16_Optimizer implementation](../orttraining/orttraining/python/training export ORTMODULE_USE_TRITON=1 ``` +#### ORTMODULE_TRITON_CONFIG_FILE + +- **Feature Area**: *ORTMODULE/TritonOp* +- **Description**: Triton codegen currently supported some Ops such as some elementwise Ops and some reduction Ops. If Triton optimization is enabled, all these supported Ops will be optimized by default if possible. User can provide a customized JSON config file to control which Ops to optimize and how to optimize them. Below is a sample of config JSON. For each Op, Opset version list and domain is needed. Currently "conditions" field can be used to control axis/axes attribute or input, by specify the real value, or "single" means it contains only one dimension, or "constant" means it must be constant tensor. Save the JSON as a file somewhere and assign its path to below env variable to enable the customized config. + + ```json + { + "ops": { + "Add": {"versions": [13, 14]}, + "Sub": {"versions": [13, 14]}, + "Identity": {"versions": [13], "is_no_op": True}, + "ReduceSum": {"versions": [13], "conditions": {"axes": "[-1]"}}, + "Softmax": {"versions": [13]}, + "SoftmaxGrad_13": {"domain": "com.microsoft", "versions": [1]} + }, + "initializer": "scalar", + "min_nodes": 2 + } + ``` + + ```bash + export ORTMODULE_TRITON_CONFIG_FILE=triton_config.json + ``` + #### ORTMODULE_ENABLE_TUNING - **Feature Area**: *ORTMODULE/TritonOp* diff --git a/orttraining/orttraining/python/training/ort_triton/_codegen.py b/orttraining/orttraining/python/training/ort_triton/_codegen.py index 462491365c1fa..e0f65ed272d38 100644 --- a/orttraining/orttraining/python/training/ort_triton/_codegen.py +++ b/orttraining/orttraining/python/training/ort_triton/_codegen.py @@ -159,7 +159,7 @@ def _gen_kernel_signature(self, node: KernelNode, context: CodegenContext, code_ other_input_args = "seed_cuda, " if node.has_dropout else "" # Support symbolic shape if any. - symbolic_shape_args_str = ", ".join(node.symbolic_shape_variables) + symbolic_shape_args_str = ", ".join(sorted(node.offset_calc.symbolic_shape_variables)) if symbolic_shape_args_str: other_input_args += f"{symbolic_shape_args_str}, " @@ -490,7 +490,7 @@ def ModuleNode(self, node: ModuleNode, context: CodegenContext, code_buffer: Cod kernel_args_str += ", seed_cuda" # Support symbolic shape if any. - symbolic_shape_args_str = ", ".join(kernel_node.symbolic_shape_variables) + symbolic_shape_args_str = ", ".join(sorted(kernel_node.offset_calc.symbolic_shape_variables)) if symbolic_shape_args_str: kernel_args_str += f", {symbolic_shape_args_str}" diff --git a/orttraining/orttraining/python/training/ort_triton/_ir.py b/orttraining/orttraining/python/training/ort_triton/_ir.py index 50121cbf49804..a2b8407645c46 100644 --- a/orttraining/orttraining/python/training/ort_triton/_ir.py +++ b/orttraining/orttraining/python/training/ort_triton/_ir.py @@ -91,13 +91,16 @@ def __init__(self, target_shape: List[sympy.Expr], reduce_axes: List[int]): self.autotune_configs: AutotuneConfigs = AutotuneConfigs( self.x_numel, self.r_numel, not self.is_reduction or self.reduce_axes[-1] == self.rank - 1 ) - self.requires_x_mask: bool = not self.x_numel.is_number or any( - int(self.x_numel) % config[0] != 0 for config in self.autotune_configs.configs + simplified_x_numel = self.x_numel.subs({symbol: sympy.Integer(1) for symbol in self.x_numel.free_symbols}) + self.requires_x_mask: bool = any( + simplified_x_numel % sympy.Integer(config[0]) != 0 for config in self.autotune_configs.configs ) - self.requires_r_mask: bool = not self.r_numel.is_number or any( - int(self.r_numel) % config[1] != 0 for config in self.autotune_configs.configs + simplified_r_numel = self.r_numel.subs({symbol: sympy.Integer(1) for symbol in self.r_numel.free_symbols}) + self.requires_r_mask: bool = any( + simplified_r_numel % sympy.Integer(config[1]) != 0 for config in self.autotune_configs.configs ) self.reduced_args: Set[str] = set() + self.symbolic_shape_variables: Set[str] = set() def get_input_strides(self, name: str) -> List[sympy.Expr]: assert name in self.input_strides @@ -151,14 +154,32 @@ def register_tensor_arg(self, tensor_arg: TensorArg): else: strides.insert(0, sympy.Integer(0)) self.input_strides[tensor_arg.name] = strides + x_input_strides = self.get_x_input_strides(tensor_arg.name) if not self.is_same_x_shape(tensor_arg.name): - for idx, dim in enumerate(self.get_x_input_strides(tensor_arg.name)): + for idx, dim in enumerate(x_input_strides): if dim != sympy.Integer(0): self.x_compute_dims.add(idx) + if idx != self.x_rank - 1: + self.symbolic_shape_variables.update( + [symbol.name for symbol in self.x_strides[idx].free_symbols] + ) + if idx != 0: + self.symbolic_shape_variables.update([symbol.name for symbol in self.x_dims[idx].free_symbols]) + elif len(x_input_strides) > 0 and x_input_strides[-1] != sympy.Integer(1): + self.symbolic_shape_variables.update([symbol.name for symbol in x_input_strides[-1].free_symbols]) + r_input_strides = self.get_r_input_strides(tensor_arg.name) if not self.is_same_r_shape(tensor_arg.name): - for idx, dim in enumerate(self.get_r_input_strides(tensor_arg.name)): + for idx, dim in enumerate(r_input_strides): if dim != sympy.Integer(0): self.r_compute_dims.add(idx) + if idx != self.r_rank - 1: + self.symbolic_shape_variables.update( + [symbol.name for symbol in self.r_strides[idx].free_symbols] + ) + if idx != 0: + self.symbolic_shape_variables.update([symbol.name for symbol in self.r_dims[idx].free_symbols]) + elif len(r_input_strides) > 0 and r_input_strides[-1] != sympy.Integer(1): + self.symbolic_shape_variables.update([symbol.name for symbol in r_input_strides[-1].free_symbols]) def is_x_reduced(self, name: str) -> bool: strides = self.get_input_strides(name) @@ -288,7 +309,6 @@ def __init__(self, inputs: List[TensorArg], outputs: List[TensorArg], target_sha self.target_shape: List[sympy.Expr] = target_shape self.sub_nodes: List[IRNode] = [] self.var_map: Dict[str, str] = dict() - self.symbolic_shape_variables: List[str] = [] self.has_dropout: bool = False self.offset_calc: OffsetCalculator = OffsetCalculator(target_shape, reduce_axes) @@ -313,11 +333,6 @@ def gen_variable_names(self): variable_name = self.var_map[name] assert variable_name not in self.var_map self.var_map[variable_name] = str(np.array(value.item(), value.dtype)) - seen = set() - for dim in self.target_shape: - if dim.is_symbol and dim not in seen: - seen.add(dim) - self.symbolic_shape_variables.append(str(dim)) class ElementwiseKernelNode(KernelNode):