diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py
index 953b2d895cd72..6e5ab0f69eda9 100644
--- a/python/taichi/lang/ast/ast_transformer.py
+++ b/python/taichi/lang/ast/ast_transformer.py
@@ -12,6 +12,7 @@
 from taichi.lang import _ndarray, any_array, expr, impl, kernel_arguments, matrix, mesh
 from taichi.lang import ops as ti_ops
 from taichi.lang._ndrange import _Ndrange, ndrange
+from taichi.lang.argpack import ArgPackType
 from taichi.lang.ast.ast_transformer_utils import Builder, LoopStatus, ReturnStatus
 from taichi.lang.ast.symbol_resolver import ASTResolver
 from taichi.lang.exception import (
@@ -601,6 +602,44 @@ def build_FunctionDef(ctx, node):
         assert args.kw_defaults == []
         assert args.kwarg is None
 
+        def decl_and_create_variable(annotation, name, arg_features):
+            if not isinstance(annotation, primitive_types.RefType):
+                ctx.kernel_args.append(name)
+            if isinstance(annotation, ArgPackType):
+                d = {}
+                for j, (_name, anno) in enumerate(annotation.members.items()):
+                    d[_name] = decl_and_create_variable(anno, _name, arg_features[j])
+                return kernel_arguments.decl_argpack_arg(annotation, d)
+            if isinstance(annotation, annotations.template):
+                return ctx.global_vars[name]
+            if isinstance(annotation, annotations.sparse_matrix_builder):
+                return kernel_arguments.decl_sparse_matrix(
+                    to_taichi_type(arg_features),
+                    name,
+                )
+            if isinstance(annotation, ndarray_type.NdarrayType):
+                return kernel_arguments.decl_ndarray_arg(
+                    to_taichi_type(arg_features[0]),
+                    arg_features[1],
+                    name,
+                    arg_features[2],
+                    arg_features[3],
+                )
+            if isinstance(annotation, texture_type.TextureType):
+                return kernel_arguments.decl_texture_arg(arg_features[0], name)
+            if isinstance(annotation, texture_type.RWTextureType):
+                return kernel_arguments.decl_rw_texture_arg(
+                    arg_features[0],
+                    arg_features[1],
+                    arg_features[2],
+                    name,
+                )
+            if isinstance(annotation, MatrixType):
+                return kernel_arguments.decl_matrix_arg(annotation, name)
+            if isinstance(annotation, StructType):
+                return kernel_arguments.decl_struct_arg(annotation, name)
+            return kernel_arguments.decl_scalar_arg(annotation, name)
+
         def transform_as_kernel():
             # Treat return type
             if node.returns is not None:
@@ -608,59 +647,21 @@ def transform_as_kernel():
                 impl.get_runtime().compiling_callable.finalize_rets()
 
             for i, arg in enumerate(args.args):
-                if not isinstance(ctx.func.arguments[i].annotation, primitive_types.RefType):
-                    ctx.kernel_args.append(arg.arg)
-                if isinstance(ctx.func.arguments[i].annotation, annotations.template):
-                    ctx.create_variable(arg.arg, ctx.global_vars[arg.arg])
-                elif isinstance(ctx.func.arguments[i].annotation, annotations.sparse_matrix_builder):
-                    ctx.create_variable(
-                        arg.arg,
-                        kernel_arguments.decl_sparse_matrix(
-                            to_taichi_type(ctx.arg_features[i]),
-                            ctx.func.arguments[i].name,
-                        ),
-                    )
-                elif isinstance(ctx.func.arguments[i].annotation, ndarray_type.NdarrayType):
-                    ctx.create_variable(
-                        arg.arg,
-                        kernel_arguments.decl_ndarray_arg(
-                            to_taichi_type(ctx.arg_features[i][0]),
-                            ctx.arg_features[i][1],
-                            ctx.func.arguments[i].name,
-                            ctx.arg_features[i][2],
-                            ctx.arg_features[i][3],
-                        ),
-                    )
-                elif isinstance(ctx.func.arguments[i].annotation, texture_type.TextureType):
-                    ctx.create_variable(
-                        arg.arg,
-                        kernel_arguments.decl_texture_arg(ctx.arg_features[i][0], ctx.func.arguments[i].name),
-                    )
-                elif isinstance(ctx.func.arguments[i].annotation, texture_type.RWTextureType):
+                if isinstance(ctx.func.arguments[i].annotation, ArgPackType):
+                    d = {}
+                    for j, (name, anno) in enumerate(ctx.func.arguments[i].annotation.members.items()):
+                        d[name] = decl_and_create_variable(anno, name, ctx.arg_features[i][j])
+                    ctx.create_variable(arg.arg, kernel_arguments.decl_argpack_arg(ctx.func.arguments[i].annotation, d))
+                else:
                     ctx.create_variable(
                         arg.arg,
-                        kernel_arguments.decl_rw_texture_arg(
-                            ctx.arg_features[i][0],
-                            ctx.arg_features[i][1],
-                            ctx.arg_features[i][2],
+                        decl_and_create_variable(
+                            ctx.func.arguments[i].annotation,
                             ctx.func.arguments[i].name,
+                            ctx.arg_features[i] if ctx.arg_features is not None else None,
                         ),
                     )
-                elif isinstance(ctx.func.arguments[i].annotation, MatrixType):
-                    ctx.create_variable(
-                        arg.arg,
-                        kernel_arguments.decl_matrix_arg(ctx.func.arguments[i].annotation, ctx.func.arguments[i].name),
-                    )
-                elif isinstance(ctx.func.arguments[i].annotation, StructType):
-                    ctx.create_variable(
-                        arg.arg,
-                        kernel_arguments.decl_struct_arg(ctx.func.arguments[i].annotation, ctx.func.arguments[i].name),
-                    )
-                else:
-                    ctx.create_variable(
-                        arg.arg,
-                        kernel_arguments.decl_scalar_arg(ctx.func.arguments[i].annotation, ctx.func.arguments[i].name),
-                    )
+
             impl.get_runtime().compiling_callable.finalize_params()
             # remove original args
             node.args.args = []
diff --git a/python/taichi/lang/kernel_arguments.py b/python/taichi/lang/kernel_arguments.py
index 364f5d6f57c06..94f237622a725 100644
--- a/python/taichi/lang/kernel_arguments.py
+++ b/python/taichi/lang/kernel_arguments.py
@@ -94,6 +94,10 @@ def decl_struct_arg(structtype, name):
     return structtype.from_taichi_object(arg_load)
 
 
+def decl_argpack_arg(argpacktype, member_dict):
+    return argpacktype.from_taichi_object(member_dict)
+
+
 def decl_sparse_matrix(dtype, name):
     value_type = cook_dtype(dtype)
     ptr_type = cook_dtype(u64)
diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py
index 8138521ee1299..c6fe985cb34ee 100644
--- a/python/taichi/lang/kernel_impl.py
+++ b/python/taichi/lang/kernel_impl.py
@@ -13,6 +13,7 @@
 from taichi._lib import core as _ti_core
 from taichi.lang import impl, ops, runtime_ops
 from taichi.lang._wrap_inspect import getsourcefile, getsourcelines
+from taichi.lang.argpack import ArgPackType, ArgPack
 from taichi.lang.ast import (
     ASTTransformerContext,
     KernelSimplicityASTChecker,
@@ -368,6 +369,13 @@ def extract_arg(arg, anno):
             # [Primitive arguments] Return the value
             return arg
 
+        if isinstance(anno, ArgPackType):
+            if not isinstance(arg, ArgPack):
+                raise TaichiRuntimeTypeError(f"Argument must be an argument pack, got {type(arg)}")
+            return tuple(
+                TaichiCallableTemplateMapper.extract_arg(arg[name], dtype)
+                for index, (name, dtype) in enumerate(anno.members.items())
+            )
         if isinstance(anno, texture_type.TextureType):
             if not isinstance(arg, taichi.lang._texture.Texture):
                 raise TaichiRuntimeTypeError(f"Argument must be a texture, got {type(arg)}")
@@ -557,6 +565,8 @@ def extract_arguments(self):
                 pass
             elif isinstance(annotation, StructType):
                 pass
+            elif isinstance(annotation, ArgPackType):
+                pass
             else:
                 raise TaichiSyntaxError(f"Invalid type annotation (argument {i}) of Taichi kernel: {annotation}")
             self.arguments.append(KernelArgument(annotation, param.name, param.default))
@@ -622,182 +632,203 @@ def launch_kernel(self, t_kernel, *args):
         launch_ctx = t_kernel.make_launch_context()
         max_arg_num = 64
         exceed_max_arg_num = False
-        for i, v in enumerate(args):
-            needed = self.arguments[i].annotation
-            if isinstance(needed, template):
+        for i, val in enumerate(args):
+            _needed = self.arguments[i].annotation
+            if isinstance(_needed, template):
                 continue
-            if actual_argument_slot >= max_arg_num:
-                exceed_max_arg_num = True
-                break
-            provided = type(v)
-            # Note: do not use sth like "needed == f32". That would be slow.
-            if id(needed) in primitive_types.real_type_ids:
-                if not isinstance(v, (float, int, np.floating, np.integer)):
-                    raise TaichiRuntimeTypeError.get(i, needed.to_string(), provided)
-                launch_ctx.set_arg_float(actual_argument_slot, float(v))
-            elif id(needed) in primitive_types.integer_type_ids:
-                if not isinstance(v, (int, np.integer)):
-                    raise TaichiRuntimeTypeError.get(i, needed.to_string(), provided)
-                if is_signed(cook_dtype(needed)):
-                    launch_ctx.set_arg_int(actual_argument_slot, int(v))
-                else:
-                    launch_ctx.set_arg_uint(actual_argument_slot, int(v))
-            elif isinstance(needed, sparse_matrix_builder):
-                # Pass only the base pointer of the ti.types.sparse_matrix_builder() argument
-                launch_ctx.set_arg_uint(actual_argument_slot, v._get_ndarray_addr())
-            elif isinstance(needed, ndarray_type.NdarrayType) and isinstance(v, taichi.lang._ndarray.Ndarray):
-                v_primal = v.arr
-                v_grad = v.grad.arr if v.grad else None
-                if v_grad is None:
-                    launch_ctx.set_arg_ndarray(actual_argument_slot, v_primal)
-                else:
-                    launch_ctx.set_arg_ndarray_with_grad(actual_argument_slot, v_primal, v_grad)
-            elif isinstance(needed, texture_type.TextureType) and isinstance(v, taichi.lang._texture.Texture):
-                launch_ctx.set_arg_texture(actual_argument_slot, v.tex)
-            elif isinstance(needed, texture_type.RWTextureType) and isinstance(v, taichi.lang._texture.Texture):
-                launch_ctx.set_arg_rw_texture(actual_argument_slot, v.tex)
-            elif isinstance(needed, ndarray_type.NdarrayType):
-                # Element shapes are already specialized in Taichi codegen.
-                # The shape information for element dims are no longer needed.
-                # Therefore we strip the element shapes from the shape vector,
-                # so that it only holds "real" array shapes.
-                is_soa = needed.layout == Layout.SOA
-                array_shape = v.shape
-                if functools.reduce(operator.mul, array_shape, 1) > np.iinfo(np.int32).max:
-                    warnings.warn(
-                        "Ndarray index might be out of int32 boundary but int64 indexing is not supported yet."
-                    )
-                if needed.dtype is None or id(needed.dtype) in primitive_types.type_ids:
-                    element_dim = 0
-                else:
-                    element_dim = needed.dtype.ndim
-                array_shape = v.shape[element_dim:] if is_soa else v.shape[:-element_dim]
-                if isinstance(v, np.ndarray):
-                    if v.flags.c_contiguous:
-                        launch_ctx.set_arg_external_array_with_shape(
-                            actual_argument_slot, int(v.ctypes.data), v.nbytes, array_shape, 0
-                        )
-                    elif v.flags.f_contiguous:
-                        # TODO: A better way that avoids copying is saving strides info.
-                        tmp = np.ascontiguousarray(v)
-                        # Purpose: DO NOT GC |tmp|!
-                        tmps.append(tmp)
-
-                        def callback(original, updated):
-                            np.copyto(original, np.asfortranarray(updated))
-
-                        callbacks.append(functools.partial(callback, v, tmp))
-                        launch_ctx.set_arg_external_array_with_shape(
-                            actual_argument_slot, int(tmp.ctypes.data), tmp.nbytes, array_shape, 0
-                        )
+            needed_list, provided_list = [], []
+
+            def flatten_argpack(argpack, argpack_type):
+                for j, (name, anno) in enumerate(argpack_type.members.items()):
+                    if isinstance(anno, ArgPackType):
+                        flatten_argpack(argpack[name], anno)
                     else:
-                        raise ValueError(
-                            "Non contiguous numpy arrays are not supported, please call np.ascontiguousarray(arr) before passing it into taichi kernel."
- ) - elif has_pytorch(): - import torch # pylint: disable=C0415 + needed_list.append(anno) + provided_list.append(argpack[name]) - if isinstance(v, torch.Tensor): - if not v.is_contiguous(): - raise ValueError( - "Non contiguous tensors are not supported, please call tensor.contiguous() before passing it into taichi kernel." - ) - taichi_arch = self.runtime.prog.config().arch - - def get_call_back(u, v): - def call_back(): - u.copy_(v) - - return call_back - - # FIXME: only allocate when launching grad kernel - if v.requires_grad and v.grad is None: - v.grad = torch.zeros_like(v) - - tmp = v - if str(v.device).startswith("cuda") and taichi_arch != _ti_core.Arch.cuda: - # Getting a torch CUDA tensor on Taichi non-cuda arch: - # We just replace it with a CPU tensor and by the end of kernel execution we'll use the callback to copy the values back to the original CUDA tensor. - host_v = v.to(device="cpu", copy=True) - tmp = host_v - callbacks.append(get_call_back(v, host_v)) - - launch_ctx.set_arg_external_array_with_shape( - actual_argument_slot, - int(tmp.data_ptr()), - tmp.element_size() * tmp.nelement(), - array_shape, - int(v.grad.data_ptr()) if v.grad is not None else 0, + if isinstance(_needed, ArgPackType) and isinstance(val, ArgPack): + flatten_argpack(val, _needed) + else: + needed_list, provided_list = [_needed], [val] + + for j, _v in enumerate(needed_list): + needed, provided, v = _v, type(provided_list[j]), provided_list[j] + if actual_argument_slot >= max_arg_num: + exceed_max_arg_num = True + break + # Note: do not use sth like "needed == f32". That would be slow. + if id(needed) in primitive_types.real_type_ids: + if not isinstance(v, (float, int, np.floating, np.integer)): + raise TaichiRuntimeTypeError.get(i, needed.to_string(), provided) + launch_ctx.set_arg_float(actual_argument_slot, float(v)) + elif id(needed) in primitive_types.integer_type_ids: + if not isinstance(v, (int, np.integer)): + raise TaichiRuntimeTypeError.get(i, needed.to_string(), provided) + if is_signed(cook_dtype(needed)): + launch_ctx.set_arg_int(actual_argument_slot, int(v)) + else: + launch_ctx.set_arg_uint(actual_argument_slot, int(v)) + elif isinstance(needed, sparse_matrix_builder): + # Pass only the base pointer of the ti.types.sparse_matrix_builder() argument + launch_ctx.set_arg_uint(actual_argument_slot, v._get_ndarray_addr()) + elif isinstance(needed, ndarray_type.NdarrayType) and isinstance(v, taichi.lang._ndarray.Ndarray): + v_primal = v.arr + v_grad = v.grad.arr if v.grad else None + if v_grad is None: + launch_ctx.set_arg_ndarray(actual_argument_slot, v_primal) + else: + launch_ctx.set_arg_ndarray_with_grad(actual_argument_slot, v_primal, v_grad) + elif isinstance(needed, texture_type.TextureType) and isinstance(v, taichi.lang._texture.Texture): + launch_ctx.set_arg_texture(actual_argument_slot, v.tex) + elif isinstance(needed, texture_type.RWTextureType) and isinstance(v, taichi.lang._texture.Texture): + launch_ctx.set_arg_rw_texture(actual_argument_slot, v.tex) + elif isinstance(needed, ndarray_type.NdarrayType): + # Element shapes are already specialized in Taichi codegen. + # The shape information for element dims are no longer needed. + # Therefore we strip the element shapes from the shape vector, + # so that it only holds "real" array shapes. 
+ is_soa = needed.layout == Layout.SOA + array_shape = v.shape + if functools.reduce(operator.mul, array_shape, 1) > np.iinfo(np.int32).max: + warnings.warn( + "Ndarray index might be out of int32 boundary but int64 indexing is not supported yet." ) + if needed.dtype is None or id(needed.dtype) in primitive_types.type_ids: + element_dim = 0 else: - raise TaichiRuntimeTypeError.get(i, needed.to_string(), v) - elif has_paddle(): - import paddle # pylint: disable=C0415 - - if isinstance(v, paddle.Tensor): - # For now, paddle.fluid.core.Tensor._ptr() is only available on develop branch - def get_call_back(u, v): - def call_back(): - u.copy_(v, False) - - return call_back - - tmp = v.value().get_tensor() - taichi_arch = self.runtime.prog.config().arch - if v.place.is_gpu_place(): - if taichi_arch != _ti_core.Arch.cuda: - # Paddle cuda tensor on Taichi non-cuda arch - host_v = v.cpu() - tmp = host_v.value().get_tensor() + element_dim = needed.dtype.ndim + array_shape = v.shape[element_dim:] if is_soa else v.shape[:-element_dim] + if isinstance(v, np.ndarray): + if v.flags.c_contiguous: + launch_ctx.set_arg_external_array_with_shape( + actual_argument_slot, int(v.ctypes.data), v.nbytes, array_shape, 0 + ) + elif v.flags.f_contiguous: + # TODO: A better way that avoids copying is saving strides info. + tmp = np.ascontiguousarray(v) + # Purpose: DO NOT GC |tmp|! + tmps.append(tmp) + + def callback(original, updated): + np.copyto(original, np.asfortranarray(updated)) + + callbacks.append(functools.partial(callback, v, tmp)) + launch_ctx.set_arg_external_array_with_shape( + actual_argument_slot, int(tmp.ctypes.data), tmp.nbytes, array_shape, 0 + ) + else: + raise ValueError( + "Non contiguous numpy arrays are not supported, please call np.ascontiguousarray(arr) " + "before passing it into taichi kernel." + ) + elif has_pytorch(): + import torch # pylint: disable=C0415 + + if isinstance(v, torch.Tensor): + if not v.is_contiguous(): + raise ValueError( + "Non contiguous tensors are not supported, please call tensor.contiguous() before " + "passing it into taichi kernel." + ) + taichi_arch = self.runtime.prog.config().arch + + def get_call_back(u, v): + def call_back(): + u.copy_(v) + + return call_back + + # FIXME: only allocate when launching grad kernel + if v.requires_grad and v.grad is None: + v.grad = torch.zeros_like(v) + + tmp = v + if str(v.device).startswith("cuda") and taichi_arch != _ti_core.Arch.cuda: + # Getting a torch CUDA tensor on Taichi non-cuda arch: + # We just replace it with a CPU tensor and by the end of kernel execution we'll use the + # callback to copy the values back to the original CUDA tensor. 
+                            host_v = v.to(device="cpu", copy=True)
+                            tmp = host_v
                             callbacks.append(get_call_back(v, host_v))
-                        elif v.place.is_cpu_place():
-                            if taichi_arch == _ti_core.Arch.cuda:
-                                # Paddle cpu tensor on Taichi cuda arch
-                                gpu_v = v.cuda()
-                                tmp = gpu_v.value().get_tensor()
-                                callbacks.append(get_call_back(v, gpu_v))
+
+                        launch_ctx.set_arg_external_array_with_shape(
+                            actual_argument_slot,
+                            int(tmp.data_ptr()),
+                            tmp.element_size() * tmp.nelement(),
+                            array_shape,
+                            int(v.grad.data_ptr()) if v.grad is not None else 0,
+                        )
                         else:
-                            # Paddle do support many other backends like XPU, NPU, MLU, IPU
-                            raise TaichiRuntimeTypeError(f"Taichi do not support backend {v.place} that Paddle support")
-                        launch_ctx.set_arg_external_array_with_shape(
-                            actual_argument_slot, int(tmp._ptr()), v.element_size() * v.size, array_shape, 0
-                        )
+                            raise TaichiRuntimeTypeError.get(i, needed.to_string(), v)
+                    elif has_paddle():
+                        import paddle  # pylint: disable=C0415
+
+                        if isinstance(v, paddle.Tensor):
+                            # For now, paddle.fluid.core.Tensor._ptr() is only available on develop branch
+                            def get_call_back(u, v):
+                                def call_back():
+                                    u.copy_(v, False)
+
+                                return call_back
+
+                            tmp = v.value().get_tensor()
+                            taichi_arch = self.runtime.prog.config().arch
+                            if v.place.is_gpu_place():
+                                if taichi_arch != _ti_core.Arch.cuda:
+                                    # Paddle cuda tensor on Taichi non-cuda arch
+                                    host_v = v.cpu()
+                                    tmp = host_v.value().get_tensor()
+                                    callbacks.append(get_call_back(v, host_v))
+                            elif v.place.is_cpu_place():
+                                if taichi_arch == _ti_core.Arch.cuda:
+                                    # Paddle cpu tensor on Taichi cuda arch
+                                    gpu_v = v.cuda()
+                                    tmp = gpu_v.value().get_tensor()
+                                    callbacks.append(get_call_back(v, gpu_v))
+                            else:
+                                # Paddle does support many other backends like XPU, NPU, MLU, IPU
+                                raise TaichiRuntimeTypeError(
+                                    f"Taichi does not support backend {v.place} that Paddle supports"
+                                )
+                            launch_ctx.set_arg_external_array_with_shape(
+                                actual_argument_slot, int(tmp._ptr()), v.element_size() * v.size, array_shape, 0
+                            )
+                        else:
+                            raise TaichiRuntimeTypeError.get(i, needed.to_string(), v)
                     else:
                         raise TaichiRuntimeTypeError.get(i, needed.to_string(), v)
-            elif isinstance(needed, MatrixType):
-                if needed.dtype in primitive_types.real_types:
+                elif isinstance(needed, MatrixType):
+                    if needed.dtype in primitive_types.real_types:
 
-                    def cast_func(x):
-                        if not isinstance(x, (int, float, np.integer, np.floating)):
-                            raise TaichiRuntimeTypeError.get(i, needed.dtype.to_string(), type(x))
-                        return float(x)
+                        def cast_func(x):
+                            if not isinstance(x, (int, float, np.integer, np.floating)):
+                                raise TaichiRuntimeTypeError.get(i, needed.dtype.to_string(), type(x))
+                            return float(x)
 
-                elif needed.dtype in primitive_types.integer_types:
+                    elif needed.dtype in primitive_types.integer_types:
 
-                    def cast_func(x):
-                        if not isinstance(x, (int, np.integer)):
-                            raise TaichiRuntimeTypeError.get(i, needed.dtype.to_string(), type(x))
-                        return int(x)
+                        def cast_func(x):
+                            if not isinstance(x, (int, np.integer)):
+                                raise TaichiRuntimeTypeError.get(i, needed.dtype.to_string(), type(x))
+                            return int(x)
 
-                else:
-                    raise ValueError(f"Matrix dtype {needed.dtype} is not integer type or real type.")
+                    else:
+                        raise ValueError(f"Matrix dtype {needed.dtype} is not integer type or real type.")
 
-                if needed.ndim == 2:
-                    v = [cast_func(v[i, j]) for i in range(needed.n) for j in range(needed.m)]
+                    if needed.ndim == 2:
+                        v = [cast_func(v[i, j]) for i in range(needed.n) for j in range(needed.m)]
+                    else:
+                        v = [cast_func(v[i]) for i in range(needed.n)]
+                    v = needed(*v)
+                    needed.set_kernel_struct_args(v, launch_ctx, (actual_argument_slot,))
+                elif isinstance(needed, StructType):
+                    if not isinstance(v, needed):
+                        raise TaichiRuntimeTypeError.get(i, str(needed), provided)
+                    needed.set_kernel_struct_args(v, launch_ctx, (actual_argument_slot,))
                 else:
-                    v = [cast_func(v[i]) for i in range(needed.n)]
-                v = needed(*v)
-                needed.set_kernel_struct_args(v, launch_ctx, (actual_argument_slot,))
-            elif isinstance(needed, StructType):
-                if not isinstance(v, needed):
-                    raise TaichiRuntimeTypeError.get(i, str(needed), provided)
-                needed.set_kernel_struct_args(v, launch_ctx, (actual_argument_slot,))
-            else:
-                raise ValueError(f"Argument type mismatch. Expecting {needed}, got {type(v)}.")
-            actual_argument_slot += 1
+                    raise ValueError(f"Argument type mismatch. Expecting {needed}, got {type(v)}.")
+                actual_argument_slot += 1
     if exceed_max_arg_num:
         raise TaichiRuntimeError(
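
The hunks above wire argument packs through all three layers: AST-time declaration (`decl_and_create_variable`), template specialization (`extract_arg`), and kernel launch (`flatten_argpack`). The user-facing constructor itself never appears in this patch, so the following is only a minimal usage sketch; the `ti.types.argpack` factory name and the member-access syntax are assumptions inferred from the `ArgPackType`/`ArgPack` imports, not something this diff confirms.

```python
import taichi as ti

ti.init(arch=ti.cpu)

# Assumed factory producing an ArgPackType; not shown in the hunks above.
view_params = ti.types.argpack(far=ti.f32, n=ti.i32)


@ti.kernel
def use_pack(p: view_params) -> ti.f32:
    # Members are read by name inside the kernel, like a struct argument.
    return p.far * p.n


pack = view_params(far=0.5, n=4)  # an ArgPack instance
print(use_pack(pack))  # expected: 2.0
```

Each leaf member (`far`, `n`) occupies one launch-context argument slot after flattening, which is why the flattened count, not the Python-level argument count, is what gets checked against `max_arg_num`.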
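The launch-time traversal is easy to misread inside the large `launch_kernel` hunk, so here is a self-contained sketch of what `flatten_argpack` does, with plain dicts as hypothetical stand-ins for `ArgPack`/`ArgPackType` (this is an illustration of the recursion only, not Taichi API):

```python
# Depth-first walk over pack members in declaration (insertion) order, so the
# flattened leaf order matches the tuple that extract_arg() builds for
# template specialization.
def flatten_pack(pack, pack_type, needed, provided):
    for name, anno in pack_type.items():
        if isinstance(anno, dict):  # nested pack: recurse into its members
            flatten_pack(pack[name], anno, needed, provided)
        else:  # leaf member: consumes one launch-context argument slot
            needed.append(anno)
            provided.append(pack[name])


pack_type = {"far": float, "inner": {"n": int}}
pack = {"far": 0.1, "inner": {"n": 4}}
needed, provided = [], []
flatten_pack(pack, pack_type, needed, provided)
assert needed == [float, int]
assert provided == [0.1, 4]
```

After flattening, the existing per-slot dispatch (scalars, ndarrays, textures, matrices, structs) runs unchanged over `needed_list`/`provided_list`, which is why the body of the inner `for j, _v in enumerate(needed_list)` loop is the old loop body re-indented rather than new logic.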