From 68853501a610c1422f41de48252ace54a4c68d4f Mon Sep 17 00:00:00 2001
From: chengzeyi
Date: Mon, 18 Dec 2023 18:16:02 +0800
Subject: [PATCH] bump version to 0.0.15.post1 & refactor cutlass code & build
 for torch 2.1.2

---
 .github/workflows/wheels.yml                  |   2 +-
 setup.py                                      |   2 +-
 src/sfast/__init__.py                         |   2 +-
 .../cutlass/cutlass_dual_linear_kernel.cu     |  17 +--
 .../cutlass/cutlass_qlinear_dynamic_kernel.cu | 100 ++++++------------
 version.txt                                   |   2 +-
 6 files changed, 46 insertions(+), 79 deletions(-)

diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml
index 3d5ca81..a2a8bff 100644
--- a/.github/workflows/wheels.yml
+++ b/.github/workflows/wheels.yml
@@ -31,8 +31,8 @@ jobs:
           - "3.10"
           - "3.11"
         torch_version:
-          - "2.1.0"
           - "2.1.1"
+          - "2.1.2"
         cuda_short_version:
           - "118"
           - "121"
diff --git a/setup.py b/setup.py
index 97b05a7..ce778b4 100644
--- a/setup.py
+++ b/setup.py
@@ -131,7 +131,7 @@ def get_extensions():
     if cuda_version >= 1102:
         extra_compile_args["nvcc"] += [
             "--threads",
-            "4",
+            "2",
             "--ptxas-options=-v",
         ]
     if platform.system() == "Windows":
diff --git a/src/sfast/__init__.py b/src/sfast/__init__.py
index 25069ef..0bbe994 100644
--- a/src/sfast/__init__.py
+++ b/src/sfast/__init__.py
@@ -32,4 +32,4 @@ def new_lru_cache(*args, **kwargs):
 
 # This line will be programatically read/write by setup.py.
 # Leave them at the bottom of this file and don't touch them.
-__version__ = "0.0.15"
+__version__ = "0.0.15.post1"
diff --git a/src/sfast/csrc/operators/cutlass/cutlass_dual_linear_kernel.cu b/src/sfast/csrc/operators/cutlass/cutlass_dual_linear_kernel.cu
index a18ee5c..bb70a6d 100644
--- a/src/sfast/csrc/operators/cutlass/cutlass_dual_linear_kernel.cu
+++ b/src/sfast/csrc/operators/cutlass/cutlass_dual_linear_kernel.cu
@@ -214,14 +214,6 @@ torch::Tensor cutlass_dual_gemm(
           ElementComputeEpilogue(
               bias0.has_value() ? 1.0 : 0.0)}, // <- tuple of alpha and beta
       epilogue2_params};
-  // Allocate workspace memory
-  size_t workspace_size = Gemm::get_workspace_size(arguments);
-  auto workspace =
-      torch::empty({static_cast<int64_t>(workspace_size)},
-                   torch::dtype(torch::kUInt8).device(input.device()));
-
-  torch::DeviceGuard device_guard(input.device());
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   cutlass::Status status;
   Gemm gemm_op;
@@ -232,11 +224,20 @@ torch::Tensor cutlass_dual_gemm(
       "This problem size is not supported by this Gemm implementation: ",
       cutlass::cutlassGetStatusString(status));
 
+  // Allocate workspace memory
+  size_t workspace_size = Gemm::get_workspace_size(arguments);
+  auto workspace =
+      torch::empty({static_cast<int64_t>(workspace_size)},
+                   torch::dtype(torch::kUInt8).device(input.device()));
+
   status = gemm_op.initialize(arguments, workspace.data_ptr());
   TORCH_CHECK(status == cutlass::Status::kSuccess,
               "Failed to initialize cutlass gemm: ",
               cutlass::cutlassGetStatusString(status));
 
+  torch::DeviceGuard device_guard(input.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
   status = gemm_op(stream);
   TORCH_CHECK(status == cutlass::Status::kSuccess,
               "Failed to execute cutlass gemm: ",
diff --git a/src/sfast/csrc/operators/cutlass/cutlass_qlinear_dynamic_kernel.cu b/src/sfast/csrc/operators/cutlass/cutlass_qlinear_dynamic_kernel.cu
index a21443f..9f15a84 100644
--- a/src/sfast/csrc/operators/cutlass/cutlass_qlinear_dynamic_kernel.cu
+++ b/src/sfast/csrc/operators/cutlass/cutlass_qlinear_dynamic_kernel.cu
@@ -28,48 +28,36 @@ namespace sm80_space {
 using SmArch = cutlass::arch::Sm80;
 constexpr int NumStages = 4;
 
-template <typename scalar_t, typename acc_t> struct GemmWrapper {
+template <typename scalar_t, typename acc_t> struct GemmConfig {
   using ElementA = scalar_t;
   using ElementB = int8_t;
   using ElementOutput = scalar_t;
   using ElementAccumulator = acc_t;
   using ElementComputeEpilogue = acc_t;
 
-  using Gemm = cutlass::gemm::device::GemmUniversal<
-      ElementA, LayoutInputA, ElementB, LayoutInputB, ElementOutput,
-      LayoutOutput, ElementAccumulator, MMAOp, SmArch,
-      cutlass::gemm::GemmShape<128, 128, 64>,
-      cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
-      cutlass::epilogue::thread::LinearCombination<
-          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
-          ElementAccumulator, ElementComputeEpilogue,
-          cutlass::epilogue::thread::ScaleType::NoBetaScaling>,
-      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, NumStages,
-      128 / cutlass::sizeof_bits<ElementA>::value,
-      128 / cutlass::sizeof_bits<ElementB>::value,
-      cutlass::arch::OpMultiplyAddMixedInputUpcast,
-      cutlass::ComplexTransform::kNone, cutlass::ComplexTransform::kNone>;
+  using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 64>;
+  using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
+};
+} // namespace sm80_space
 
-  using GemmNoBias = cutlass::gemm::device::GemmUniversal<
-      ElementA, LayoutInputA, ElementB, LayoutInputB, ElementOutput,
-      LayoutOutput, ElementAccumulator, MMAOp, SmArch,
-      cutlass::gemm::GemmShape<128, 128, 64>,
-      cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
-      cutlass::epilogue::thread::LinearCombination<
-          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
-          ElementAccumulator, ElementComputeEpilogue,
-          cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>,
-      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, NumStages,
-      128 / cutlass::sizeof_bits<ElementA>::value,
-      128 / cutlass::sizeof_bits<ElementB>::value,
-      cutlass::arch::OpMultiplyAddMixedInputUpcast,
-      cutlass::ComplexTransform::kNone, cutlass::ComplexTransform::kNone>;
+using namespace sm80_space;
+
+template <typename config> struct GemmWrapper {
+  using ElementA = typename config::ElementA;
+  using ElementB = typename config::ElementB;
+  using ElementOutput = typename config::ElementOutput;
+  using ElementAccumulator = typename config::ElementAccumulator;
+  using ElementComputeEpilogue = typename config::ElementComputeEpilogue;
 
-  using GemmSmall = cutlass::gemm::device::GemmUniversal<
+  using ThreadBlockShape = typename config::ThreadBlockShape;
+  using WarpShape = typename config::WarpShape;
+  using InstructionShape = typename config::InstructionShape;
+
+  using Gemm = cutlass::gemm::device::GemmUniversal<
       ElementA, LayoutInputA, ElementB, LayoutInputB, ElementOutput,
       LayoutOutput, ElementAccumulator, MMAOp, SmArch,
-      cutlass::gemm::GemmShape<64, 64, 32>,
-      cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
+      ThreadBlockShape, WarpShape, InstructionShape,
       cutlass::epilogue::thread::LinearCombination<
           ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
           ElementAccumulator, ElementComputeEpilogue,
@@ -80,11 +68,10 @@ template <typename scalar_t, typename acc_t> struct GemmWrapper {
       cutlass::arch::OpMultiplyAddMixedInputUpcast,
       cutlass::ComplexTransform::kNone, cutlass::ComplexTransform::kNone>;
 
-  using GemmNoBiasSmall = cutlass::gemm::device::GemmUniversal<
+  using GemmNoBias = cutlass::gemm::device::GemmUniversal<
       ElementA, LayoutInputA, ElementB, LayoutInputB, ElementOutput,
       LayoutOutput, ElementAccumulator, MMAOp, SmArch,
-      cutlass::gemm::GemmShape<64, 64, 32>,
-      cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>,
+      ThreadBlockShape, WarpShape, InstructionShape,
       cutlass::epilogue::thread::LinearCombination<
           ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
           ElementAccumulator, ElementComputeEpilogue,
@@ -95,9 +82,6 @@ template <typename scalar_t, typename acc_t> struct GemmWrapper {
       cutlass::arch::OpMultiplyAddMixedInputUpcast,
       cutlass::ComplexTransform::kNone, cutlass::ComplexTransform::kNone>;
 };
-} // namespace sm80_space
-
-using namespace sm80_space;
 
 void get_input_layout(const torch::Tensor &input, const torch::Tensor &weight,
                       int &B, int &M, int &K, int &N,
@@ -185,14 +169,6 @@ cutlass_gemm(const torch::Tensor &input, const torch::Tensor &weight,
       weight_ref.stride(0),
       bias_ref.stride(0),
       output_ref.stride(0)};
-  // Allocate workspace memory
-  size_t workspace_size = Gemm::get_workspace_size(arguments);
-  auto workspace =
-      torch::empty({static_cast<int64_t>(workspace_size)},
-                   torch::dtype(torch::kUInt8).device(input.device()));
-
-  torch::DeviceGuard device_guard(input.device());
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   cutlass::Status status;
   Gemm gemm_op;
@@ -203,11 +179,20 @@ cutlass_gemm(const torch::Tensor &input, const torch::Tensor &weight,
       "This problem size is not supported by this Gemm implementation: ",
       cutlass::cutlassGetStatusString(status));
 
+  // Allocate workspace memory
+  size_t workspace_size = Gemm::get_workspace_size(arguments);
+  auto workspace =
+      torch::empty({static_cast<int64_t>(workspace_size)},
+                   torch::dtype(torch::kUInt8).device(input.device()));
+
   status = gemm_op.initialize(arguments, workspace.data_ptr());
   TORCH_CHECK(status == cutlass::Status::kSuccess,
               "Failed to initialize cutlass gemm: ",
               cutlass::cutlassGetStatusString(status));
 
+  torch::DeviceGuard device_guard(input.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
   status = gemm_op(stream);
   TORCH_CHECK(status == cutlass::Status::kSuccess,
               "Failed to execute cutlass gemm: ",
@@ -235,11 +220,9 @@ template <> struct acc_type { using type = float; };
 template struct CutlassGemmLauncher {
   using scalar_t = typename cutlass_type::type;
   using acc_t = typename acc_type::type;
-  using GemmWrapper_ = GemmWrapper<scalar_t, acc_t>;
+  using GemmWrapper_ = GemmWrapper<GemmConfig<scalar_t, acc_t>>;
   using Gemm = typename GemmWrapper_::Gemm;
   using GemmNoBias = typename GemmWrapper_::GemmNoBias;
-  using GemmSmall = typename GemmWrapper_::GemmSmall;
-  using GemmNoBiasSmall = typename GemmWrapper_::GemmNoBiasSmall;
 
   static torch::Tensor launch(const torch::Tensor &input,
                               const torch::Tensor &weight,
@@ -247,42 +230,25 @@ template struct CutlassGemmLauncher {
                               float dq_scale) {
     auto N = weight.size(0);
     auto K = weight.size(1);
-    auto M = input.numel() / K;
-
-    bool use_small_kernel = M <= Gemm::ThreadblockShape::kM ||
-                            N <= Gemm::ThreadblockShape::kN ||
-                            K <= Gemm::ThreadblockShape::kK;
+    // auto M = input.numel() / K;
 
     if (K % Gemm::kAlignmentA != 0 || K % Gemm::kAlignmentB != 0 ||
         N % Gemm::kAlignmentC != 0) {
-      if (K % GemmSmall::kAlignmentA != 0 || K % GemmSmall::kAlignmentB != 0 ||
-          N % GemmSmall::kAlignmentC != 0) {
         auto weight_ = input.scalar_type() == at::kFloat
                            ? weight.dequantize()
                            : weight.int_repr()
                                  .to(input.scalar_type())
                                  .mul_(weight.q_scale());
         return cublas_lowp_linear(input, weight_, bias);
-      } else {
-        use_small_kernel = true;
-      }
     }
     auto input_ = input.contiguous();
     auto weight_ = weight.contiguous();
     if (bias.has_value()) {
       c10::optional<torch::Tensor> bias_;
       bias_.emplace(bias.value().contiguous());
-      if (use_small_kernel) {
-        return cutlass_gemm<GemmSmall>(input_, weight_, bias_, dq_scale);
-      } else {
        return cutlass_gemm<Gemm>(input_, weight_, bias_, dq_scale);
-      }
     } else {
-      if (use_small_kernel) {
-        return cutlass_gemm<GemmNoBiasSmall>(input_, weight_, bias, dq_scale);
-      } else {
        return cutlass_gemm<GemmNoBias>(input_, weight_, bias, dq_scale);
-      }
     }
   }
 };
diff --git a/version.txt b/version.txt
index ceddfb2..06ee82c 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.0.15
+0.0.15.post1