Skip to content

Commit

Permalink
Update oneAPI.jl for the release 2024.2.0 (#446)
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison authored Jul 3, 2024
1 parent 2c0299f commit 457e020
Show file tree
Hide file tree
Showing 7 changed files with 4,055 additions and 3,429 deletions.
2 changes: 1 addition & 1 deletion deps/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ oneAPI_Level_Zero_Headers_jll = "f4bc562b-d309-54f8-9efb-476e56f0410d"
oneAPI_Support_Headers_jll = "24f86df5-245d-5634-a4cc-32433d9800b3"

[compat]
oneAPI_Support_Headers_jll = "=2024.1.0"
oneAPI_Support_Headers_jll = "=2024.2.0"
2 changes: 1 addition & 1 deletion deps/build_local.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ if !isfile(joinpath(conda_dir, "condarc-julia.yml"))
mkpath(joinpath(conda_dir, "conda-meta"))
touch(joinpath(conda_dir, "conda-meta", "history"))
end
Conda.add(["dpcpp_linux-64=2024.1.0", "mkl-devel-dpcpp=2024.1.0"], conda_dir;
Conda.add(["dpcpp_linux-64=2024.2.0", "mkl-devel-dpcpp=2024.2.0"], conda_dir;
channel="intel")

Conda.list(conda_dir)
Expand Down
29 changes: 24 additions & 5 deletions deps/generate_interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ using oneAPI_Support_Headers_jll
include("generate_helpers.jl")

blas = [joinpath(oneAPI_Support_Headers_jll.artifact_dir, "include", "oneapi", "mkl", "blas", "buffer_decls.hpp")]
lapack = [joinpath(oneAPI_Support_Headers_jll.artifact_dir, "include", "oneapi", "mkl", "lapack", "lapack.hpp")]
lapack = [joinpath(oneAPI_Support_Headers_jll.artifact_dir, "include", "oneapi", "mkl", "lapack", "lapack.hpp"),
joinpath(oneAPI_Support_Headers_jll.artifact_dir, "include", "oneapi", "mkl", "lapack", "scratchpad.hpp")]
sparse = [joinpath(oneAPI_Support_Headers_jll.artifact_dir, "include", "oneapi", "mkl", "spblas", "sparse_structures.hpp"),
joinpath(oneAPI_Support_Headers_jll.artifact_dir, "include", "oneapi", "mkl", "spblas", "sparse_auxiliary.hpp"),
joinpath(oneAPI_Support_Headers_jll.artifact_dir, "include", "oneapi", "mkl", "spblas", "sparse_operations.hpp")]
Expand Down Expand Up @@ -64,13 +65,15 @@ function generate_headers(library::String, filename::Vector{String}, output::Str
occursin("get_matmat_data", header) && continue # SPARSE routine
occursin("matmat(", header) && continue # SPARSE routine
occursin("gemm_bias", header) && continue # BLAS routine
occursin("heevx", header) && continue # LAPACK routine (compiler bug)
occursin("hegvx", header) && continue # LAPACK routine (compiler bug)
occursin("getri_batch", header) && occursin("ldainv", header) && continue # LAPACK routine

# Check if the routine is a template
template = occursin("template", header)
if template
header = replace(header, "template <typename fp, oneapi::mkl::lapack::internal::is_floating_point<fp> = nullptr> " => "")
header = replace(header, "template <typename fp, oneapi::mkl::lapack::internal::is_real_floating_point<fp> = nullptr> " => "")
header = replace(header, "template <typename fp, oneapi::mkl::lapack::internal::is_complex_floating_point<fp> = nullptr> " => "")

header = replace(header, "template <typename data_t, oneapi::mkl::lapack::internal::is_floating_point<data_t> = nullptr>" => "")
header = replace(header, "template <typename data_t, oneapi::mkl::lapack::internal::is_real_floating_point<data_t> = nullptr>" => "")
header = replace(header, "template <typename data_t, oneapi::mkl::lapack::internal::is_complex_floating_point<data_t> = nullptr>" => "")
Expand Down Expand Up @@ -99,6 +102,7 @@ function generate_headers(library::String, filename::Vector{String}, output::Str

# Replace the types
header = replace(header, "sycl::queue &queue" => "syclQueue_t device_queue")
header = replace(header, "sycl::queue& queue" => "syclQueue_t device_queue")

if library ("blas", "sparse")
header = replace(header, "compute_mode mode = MKL_BLAS_COMPUTE_MODE" => "")
Expand Down Expand Up @@ -214,16 +218,21 @@ function generate_headers(library::String, filename::Vector{String}, output::Str
copy_header = header
copy_header = replace(copy_header, "typename fp_type::value_type" => version_types_header[blas_version])
copy_header = replace(copy_header, "fp_type" => version_types_header[blas_version])
copy_header = replace(copy_header, "fp" => version_types_header[blas_version])
copy_header = replace(copy_header, name_routine => "onemkl$(blas_version)$(name_routine)")
if name_routine ("heevx_scratchpad_size", "hegvx_scratchpad_size")
copy_header = replace(copy_header, "typename float _Complex::value_type" => "float")
copy_header = replace(copy_header, "typename double _Complex::value_type" => "double")
end
if occursin("batch", name_routine) && !occursin("*", header)
copy_header = replace(copy_header, "_batch" => "_batch_strided")
end
push!(signatures, (copy_header, name_routine, blas_version, type_routine, template))
end
else
if isempty(list_versions)
suffix = ""
# The routine "optimize_trsm" has two versions.
suffix = ""
(name_routine == "optimize_trsm") && occursin("columns", header) && (suffix = "_advanced")
name_routine ("set_csr_data", "set_coo_data") && occursin("int64_t", header) && (suffix = "_64")
occursin("batch", name_routine) && !occursin("**", header) && (suffix = "_strided")
Expand Down Expand Up @@ -281,6 +290,13 @@ function generate_headers(library::String, filename::Vector{String}, output::Str
copy_header = replace(copy_header, "_batch" => "_batch_strided")
end
if library == "blas"
# Out-of-place variants of trsm and trmm
if occursin("trsm", header) && occursin("ldc", header)
copy_header = replace(copy_header, "trsm" => "trsm_variant")
end
if occursin("trmm", header) && occursin("ldc", header)
copy_header = replace(copy_header, "trmm" => "trmm_variant")
end
copy_header = replace(copy_header, "compute_mode mode," => "")
copy_header = replace(copy_header, ", compute_mode mode)" => ")")
copy_header = replace(copy_header, "value_or_pointer<float _Complex>" => "float _Complex")
Expand Down Expand Up @@ -380,11 +396,14 @@ function generate_cpp(library::String, filename::Vector{String}, output::String;
parameters = replace(parameters, ", double " => ", ")
parameters = replace(parameters, ", **" => ", ")
parameters = replace(parameters, ", *" => ", ")

parameters = replace(parameters, "onemklTranspose *trans," => "convert(trans, group_count),")
parameters = replace(parameters, "onemklTranspose* trans," => "convert(trans, group_count),")
parameters = replace(parameters, "onemklUplo *uplo," => "convert(uplo, group_count),")
parameters = replace(parameters, "onemklUplo* uplo," => "convert(uplo, group_count),")
parameters = replace(parameters, "onemklDiag *diag," => "convert(diag, group_count),")
parameters = replace(parameters, "onemklDiag* diag," => "convert(diag, group_count),")
parameters = replace(parameters, "onemklSide *side," => "convert(side, group_count),")
parameters = replace(parameters, "onemklSide* side," => "convert(side, group_count),")

for type in ("onemklTranspose", "onemklSide", "onemklUplo", "onemklDiag", "onemklGenerate",
"onemklLayout", "onemklJob", "onemklJobsvd", "onemklCompz", "onemklRangev",
Expand Down
Loading

0 comments on commit 457e020

Please sign in to comment.