[QST] Windows compatibility #1126

Closed
ucgggg opened this issue Oct 5, 2023 · 10 comments

Comments

@ucgggg

ucgggg commented Oct 5, 2023

I'm trying to compile FlashAttention-2 on Windows, which it does not support yet, and got a lot of errors. One of the errors can be reproduced with the following code, using only CUTLASS:

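// Assumed umbrella header; the original snippet omitted its includes.
#include <cute/tensor.hpp>

using namespace cute;

int main() {
    auto composedLayout = composition(
        Swizzle<2, 0, 3> {},
        Layout<Shape<_4, _8>, Stride<_8, _1>> {}
    );
    tile_to_shape(composedLayout, Shape<_64, _256> {});
}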

The errors are as follows:
include/cute/algorithm/functional.hpp(104): error : no instance of overloaded function "cute::abs" matches the argument list
include/cute/layout.hpp(590): error : no instance of overloaded function "cute::as_arithmetic_tuple" matches the argument list
......

Visual Studio 2022
CUDA 11.8
Windows 10
CUTLASS 3.2.1

Is there an error in the above code, or does CUTLASS not fully support Windows yet?

@thakkarV
Collaborator

thakkarV commented Oct 5, 2023

@mhoemmen

@mhoemmen
Contributor

mhoemmen commented Oct 5, 2023

@ucgggg Thanks for reporting! Would you happen to have the opportunity to test with a newer CUDA Toolkit? Windows support is new to the CUTLASS 3.x series, so we've mainly been testing with CUDA 12.2.

@ucgggg
Author

ucgggg commented Oct 6, 2023

@mhoemmen Thanks for your reply!
This problem is indeed related to the CUDA version.

I delved into the issue, and then wrote the following code:

// test.cu

#include <iostream>
#include <type_traits>
#include <utility>  // std::forward
namespace abstest {
template <auto v>
struct C
{
    static constexpr auto value = v;
};

template <class T, typename std::enable_if<std::is_arithmetic<T>::value>::type* = nullptr>
constexpr auto abs(T const& t)
{
    if constexpr(std::is_signed<T>::value) {
        return t < T(0) ? -t : t;
    } else {
        return t;
    }
}

template <auto t>
constexpr C<abs(t)> abs(C<t> a)
{
    return {};
}

struct abs_fn
{
    template <class T>
    constexpr decltype(auto) operator()(T&& arg) const
    {
        return abs(std::forward<T>(arg));
    }
};
}

int main()
{
    std::cout << abstest::abs_fn{}(abstest::C<16> {}).value << std::endl;
}

Compile:

nvcc -std=c++17 test.cu

It works well with nvcc 12.2
But it fails with nvcc 11.8:

test.cu(34): error: no instance of overloaded function "abstest::abs" matches the argument list
            argument types are: (abstest::C<16>)
          detected during instantiation of "decltype(auto) abstest::abs_fn::operator()(T &&) const [with T=abstest::C<16>]"
(41): here
1 error detected in the compilation of "test.cu".
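
To narrow down which overload the older toolchain rejects, one option is to probe each call with a detection trait and check it with static_assert, so the failure points at a single line instead of at the forwarding functor. The following is only a sketch: has_abs and the file name probe.cu are hypothetical names introduced here, the namespace mirrors test.cu above, and it has not been run against nvcc 11.8.

```cpp
// probe.cu (hypothetical file name); also builds as plain C++17.
#include <type_traits>
#include <utility>

namespace abstest {
template <auto v>
struct C
{
    static constexpr auto value = v;
};

// Overload 1: plain arithmetic types, constrained through enable_if.
template <class T, typename std::enable_if<std::is_arithmetic<T>::value>::type* = nullptr>
constexpr auto abs(T const& t)
{
    if constexpr (std::is_signed<T>::value) {
        return t < T(0) ? -t : t;
    } else {
        return t;
    }
}

// Overload 2: compile-time constants wrapped in C<t>.
template <auto t>
constexpr C<abs(t)> abs(C<t>)
{
    return {};
}

// Detection idiom (hypothetical helper): true when abs(T) is a well-formed call.
template <class T, class = void>
struct has_abs : std::false_type {};

template <class T>
struct has_abs<T, std::void_t<decltype(abs(std::declval<T>()))>> : std::true_type {};
}  // namespace abstest

// Each assertion isolates one overload. On a conforming compiler all three pass.
static_assert(abstest::has_abs<int>::value, "arithmetic overload is not viable");
static_assert(abstest::has_abs<abstest::C<16>>::value, "C<v> overload is not viable");
static_assert(!abstest::has_abs<void*>::value, "abs should reject non-arithmetic types");
```

Because the file has no main, it can be compiled on its own with nvcc -std=c++17 -c probe.cu. If the second assertion is the only one that fires, that would suggest the problem lies in how the compiler handles the C<abs(t)> return type rather than in abs_fn itself.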

@mhoemmen
Contributor

mhoemmen commented Oct 7, 2023

@ucgggg Thanks for testing! Are you able to switch to CUDA 12.2?

@ucgggg
Author

ucgggg commented Oct 7, 2023

@mhoemmen I tried to compile FlashAttention-2 with CUDA 12.2 and got a few different errors. But these new errors have nothing to do with CUTLASS. They happen in FlashAttention-2. CUTLASS works well with CUDA 12.2.

So I think CUTLASS is compatible with CUDA 12.2, but not fully compatible with CUDA 11.8. Use this to reproduce:
std::cout << cute::abs_fn{}(cute::C<16> {}).value << std::endl;

I just tested on Windows. I don't know if the same problem occurs on Linux.
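
For projects that vendor CUTLASS, a small compile-time guard could turn the long list of template errors into one readable message. This is a minimal sketch under the assumptions of this thread only (Windows plus CUDA 11.8 fails, Windows plus CUDA 12.2 works); toolkit_ok and the file name are hypothetical, and the cutoff at major version 12 is a guess, since 12.0 and 12.1 were not tested here. __CUDACC_VER_MAJOR__ is defined by nvcc and _WIN32 by Windows compilers.

```cpp
// toolkit_guard.hpp (hypothetical file name)

// Returns true when the (platform, CUDA major version) pair is not known to fail.
// Assumption taken from this thread: only Windows builds with CUDA < 12 are rejected.
constexpr bool toolkit_ok(bool on_windows, int cuda_major)
{
    return !on_windows || cuda_major >= 12;
}

// The guard is active only when compiling with nvcc on Windows;
// with any other compiler or platform it expands to nothing.
#if defined(__CUDACC_VER_MAJOR__) && defined(_WIN32)
static_assert(toolkit_ok(true, __CUDACC_VER_MAJOR__),
              "CuTe code in CUTLASS 3.x was reported to fail with CUDA 11.8 on Windows; "
              "try CUDA 12.2 or newer");
#endif
```

Including this header before any CUTLASS header makes the guard fire first.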

@mhoemmen
Contributor

mhoemmen commented Oct 9, 2023

@ucgggg Thanks so much for testing CUDA 11.8 on Windows!

We test several different CUDA Toolkit versions on Linux, so I'm pretty confident that 11.8 would work for you there.

@toothache

> @mhoemmen I tried to compile FlashAttention-2 with CUDA 12.2 and got a few different errors. But these new errors have nothing to do with CUTLASS. They happen in FlashAttention-2. CUTLASS works well with CUDA 12.2.
>
> So I think CUTLASS is compatible with CUDA 12.2, but not fully compatible with CUDA 11.8. Use this to reproduce: std::cout << cute::abs_fn{}(cute::C<16> {}).value << std::endl;
>
> I just tested on Windows. I don't know if the same problem occurs on Linux.

I came across the same issue when trying to build vLLM on Windows. With CUDA 11.8, the following check fails, which causes unary operations such as abs_fn to fail to build, since all of these functions expect an arithmetic argument.

// Assertion failed with CUDA 11.8
static_assert(std::is_arithmetic_v<cute::_1> == true);

template <class T,
__CUTE_REQUIRES(is_arithmetic<T>::value)>
CUTE_HOST_DEVICE constexpr
auto
abs(T const& t) {

@mhoemmen
Contributor

mhoemmen commented Aug 2, 2024

@toothache Could you please post exactly what version of CUTLASS you used, and how you were calling CUTLASS?

@toothache

> @toothache Could you please post exactly what version of CUTLASS you used, and how you were calling CUTLASS?

I was trying to compile the vLLM repo on Windows and ran into several build errors related to CUTLASS. I later found that the build failures were caused by the following check failing with CUDA 11.8. Switching to CUDA 12.x resolved those issues.

// Assertion failed with CUDA 11.8
static_assert(std::is_arithmetic_v<cute::_1> == true);

@mhoemmen
Contributor

mhoemmen commented Aug 6, 2024

@toothache Thanks for reporting this! I've spawned a separate bug to track this: #1689. Please put all further discussion there. Thanks!
