Is your feature request related to a problem? Please describe.
Right now the cuBLAS wrapper requires every input and output of a call to share a single dtype.
Wrapping the cublas*Ex APIs would let us partially lift that restriction and mix dtypes between inputs and outputs.
For example, we currently wrap cublasSgemm, which requires all inputs and outputs to be float32. If we wrapped cublasGemmEx instead, the inputs could be float16 and the output float32.
Describe the solution you'd like
In detail/cublas_wrapper.cuh we could change:
cublas<t>gemm to cublasGemmEx
cublas<t>gemmBatched to cublasGemmBatchedEx
cublas<t>axpy to cublasAxpyEx
cublas<t>dot to cublasDotEx
etc.
The cudaDataType parameters in the cublas*Ex calls would be set automatically from the template parameters.
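Deriving the cudaDataType arguments from the template parameters could be done with a small type trait. The following is only a sketch of that idea, not the wrapper's actual code: `data_type`, `half_t`, `data_type_of`, `gemm_types` and `gemm_types_for` are hypothetical names, and the enum and `half_t` are stand-ins so the example compiles without the CUDA toolkit. Real code would map to `CUDA_R_16F`, `CUDA_R_32F` and `CUDA_R_64F` and use `__half`.

```cpp
#include <type_traits>

// Stand-in for cudaDataType so the sketch compiles without the CUDA toolkit.
// Real code would use CUDA_R_16F, CUDA_R_32F and CUDA_R_64F from <library_types.h>.
enum class data_type { r_16f, r_32f, r_64f };

// Stand-in for __half from <cuda_fp16.h>.
struct half_t { unsigned short bits; };

// The primary template is left undefined, so an unsupported element type
// fails at compile time instead of silently picking a wrong tag.
template <typename T>
struct data_type_of;

template <>
struct data_type_of<half_t> : std::integral_constant<data_type, data_type::r_16f> {};
template <>
struct data_type_of<float> : std::integral_constant<data_type, data_type::r_32f> {};
template <>
struct data_type_of<double> : std::integral_constant<data_type, data_type::r_64f> {};

// Hypothetical bundle of the three type tags a GemmEx-style call takes
// (one each for A, B and C).
struct gemm_types {
  data_type a;
  data_type b;
  data_type c;
};

// Builds the tags from the wrapper's template parameters.
template <typename A, typename B, typename C>
constexpr gemm_types gemm_types_for() {
  return gemm_types{data_type_of<A>::value, data_type_of<B>::value,
                    data_type_of<C>::value};
}
```

A wrapper templated on `<A, B, C>` could then pass the three tags straight through to the Ex call, so callers never spell out a cudaDataType by hand.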
Additional context
One issue is that the cublas*Ex APIs don't support all combinations of type parameters. For example, as far as I can tell, cublasGemmEx requires the A and B inputs to have the same type, even though it takes a separate datatype parameter for each. https://docs.nvidia.com/cuda/cublas/index.html#cublas-GemmEx
Likewise, the supported datatype combinations for cublasDotEx require both inputs and the output to have the same type, so this change currently won't have much benefit for the linalg::dot function. https://docs.nvidia.com/cuda/cublas/index.html#cublas-dotEx
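One way to surface these restrictions early would be to reject unsupported combinations at compile time. The sketch below encodes only the two constraints described above (matching A/B types for GemmEx, one shared type for DotEx), not the full support tables in the cuBLAS documentation. All names (`half_t`, `gemm_ex_inputs_ok`, `dot_ex_types_ok`, `gemm_checked`) are hypothetical, and `half_t` is a stand-in for `__half`.

```cpp
#include <type_traits>

// Stand-in for __half from <cuda_fp16.h>, so the sketch needs no CUDA headers.
struct half_t { unsigned short bits; };

// GemmEx constraint as described in this issue: A and B share one element type.
// The output type C is deliberately not constrained here.
template <typename A, typename B>
struct gemm_ex_inputs_ok : std::is_same<A, B> {};

// DotEx constraint as described in this issue: both inputs and the result
// share one element type.
template <typename X, typename Y, typename R>
struct dot_ex_types_ok
  : std::integral_constant<bool, std::is_same<X, Y>::value &&
                                 std::is_same<X, R>::value> {};

// Hypothetical wrapper entry point: the check fires before any cuBLAS call.
template <typename A, typename B, typename C>
void gemm_checked(const A* /*a*/, const B* /*b*/, C* /*c*/) {
  static_assert(gemm_ex_inputs_ok<A, B>::value,
                "A and B must have the same element type");
  // A real wrapper would forward to cublasGemmEx here.
}
```

With this in place, `gemm_checked<half_t, half_t, float>` compiles, while mixing `half_t` and `float` inputs stops with a readable message rather than a runtime status code.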