From e7139b0f15d0b5e9ebd8ea36c3982eb60fc1d42f Mon Sep 17 00:00:00 2001
From: Nikita Kornev <nikita.kornev@intel.com>
Date: Mon, 16 Oct 2023 15:12:44 +0200
Subject: [PATCH] [SYCL] Implement interface of sycl_ext_oneapi_prefetch
 (#11458)

Spec:
https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/proposed/sycl_ext_oneapi_prefetch.asciidoc

Properties are not yet fully functional and being ignored, they require other changes in the SW stack to be properly passed through SPIR-V layer. Will be done in follow-up patches.
---
 .../SYCLLowerIR/CompileTimePropertiesPass.cpp |  10 +-
 .../sycl/ext/oneapi/experimental/prefetch.hpp | 269 ++++++++++++++++++
 sycl/include/sycl/sycl.hpp                    |   1 +
 sycl/source/feature_test.hpp.in               |   1 +
 sycl/test/extensions/prefetch.cpp             |  62 ++++
 5 files changed, 338 insertions(+), 5 deletions(-)
 create mode 100644 sycl/include/sycl/ext/oneapi/experimental/prefetch.hpp
 create mode 100644 sycl/test/extensions/prefetch.cpp

diff --git a/llvm/lib/SYCLLowerIR/CompileTimePropertiesPass.cpp b/llvm/lib/SYCLLowerIR/CompileTimePropertiesPass.cpp
index 8714ec70493b5..b16b6b515d04f 100644
--- a/llvm/lib/SYCLLowerIR/CompileTimePropertiesPass.cpp
+++ b/llvm/lib/SYCLLowerIR/CompileTimePropertiesPass.cpp
@@ -628,13 +628,13 @@ bool CompileTimePropertiesPass::transformSYCLPropertiesAnnotation(
   // Read the annotation values and create the new annotation string.
   std::string NewAnnotString = "";
   auto Properties = parseSYCLPropertiesString(M, IntrInst);
-  for (auto &Property : Properties) {
+  for (const auto &[PropName, PropVal] : Properties) {
     // sycl-alignment is converted to align on
     // previous parseAlignmentAndApply(), dropping here
-    if (*Property.first == "sycl-alignment")
+    if (PropName == "sycl-alignment")
       continue;
 
-    auto DecorIt = SpirvDecorMap.find(*Property.first);
+    auto DecorIt = SpirvDecorMap.find(*PropName);
     if (DecorIt == SpirvDecorMap.end())
       continue;
     uint32_t DecorCode = DecorIt->second.Code;
@@ -644,8 +644,8 @@ bool CompileTimePropertiesPass::transformSYCLPropertiesAnnotation(
     // string values are handled correctly. Note that " around values are
     // always valid, even if the decoration parameters are not strings.
     NewAnnotString += "{" + std::to_string(DecorCode);
-    if (Property.second)
-      NewAnnotString += ":\"" + Property.second->str() + "\"";
+    if (PropVal)
+      NewAnnotString += ":\"" + PropVal->str() + "\"";
     NewAnnotString += "}";
   }
 
diff --git a/sycl/include/sycl/ext/oneapi/experimental/prefetch.hpp b/sycl/include/sycl/ext/oneapi/experimental/prefetch.hpp
new file mode 100644
index 0000000000000..9271af5059402
--- /dev/null
+++ b/sycl/include/sycl/ext/oneapi/experimental/prefetch.hpp
@@ -0,0 +1,269 @@
+//==--------------- prefetch.hpp --- SYCL prefetch extension ---------------==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include <CL/__spirv/spirv_ops.hpp>
+#include <sycl/ext/oneapi/properties/properties.hpp>
+
+namespace sycl {
+inline namespace _V1 {
+namespace ext::oneapi::experimental {
+
+enum class cache_level { L1 = 0, L2 = 1, L3 = 2, L4 = 3 };
+
+struct nontemporal;
+
+struct prefetch_hint_key {
+  template <cache_level Level, typename Hint>
+  using value_t =
+      property_value<prefetch_hint_key,
+                     std::integral_constant<cache_level, Level>, Hint>;
+};
+
+template <cache_level Level, typename Hint>
+inline constexpr prefetch_hint_key::value_t<Level, Hint> prefetch_hint;
+
+inline constexpr prefetch_hint_key::value_t<cache_level::L1, void>
+    prefetch_hint_L1;
+inline constexpr prefetch_hint_key::value_t<cache_level::L2, void>
+    prefetch_hint_L2;
+inline constexpr prefetch_hint_key::value_t<cache_level::L3, void>
+    prefetch_hint_L3;
+inline constexpr prefetch_hint_key::value_t<cache_level::L4, void>
+    prefetch_hint_L4;
+
+inline constexpr prefetch_hint_key::value_t<cache_level::L1, nontemporal>
+    prefetch_hint_L1_nt;
+inline constexpr prefetch_hint_key::value_t<cache_level::L2, nontemporal>
+    prefetch_hint_L2_nt;
+inline constexpr prefetch_hint_key::value_t<cache_level::L3, nontemporal>
+    prefetch_hint_L3_nt;
+inline constexpr prefetch_hint_key::value_t<cache_level::L4, nontemporal>
+    prefetch_hint_L4_nt;
+
+namespace detail {
+template <> struct IsCompileTimeProperty<prefetch_hint_key> : std::true_type {};
+
+template <cache_level Level, typename Hint>
+struct PropertyMetaInfo<prefetch_hint_key::value_t<Level, Hint>> {
+  static constexpr const char *name = std::is_same_v<Hint, nontemporal>
+                                          ? "sycl-prefetch-hint-nt"
+                                          : "sycl-prefetch-hint";
+  static constexpr int value = static_cast<int>(Level);
+};
+
+template <access::address_space AS>
+inline constexpr bool check_prefetch_AS =
+    AS == access::address_space::global_space ||
+    AS == access::address_space::generic_space;
+
+template <access_mode mode>
+inline constexpr bool check_prefetch_acc_mode =
+    mode == access_mode::read || mode == access_mode::read_write;
+
+template <typename T, typename Properties>
+void prefetch_impl(T *ptr, size_t bytes, Properties properties) {
+#ifdef __SYCL_DEVICE_ONLY__
+  auto *ptrGlobalAS = __SYCL_GenericCastToPtrExplicit_ToGlobal<const char>(ptr);
+  const __attribute__((opencl_global)) char *ptrAnnotated = nullptr;
+  if constexpr (!properties.template has_property<prefetch_hint_key>()) {
+    ptrAnnotated = __builtin_intel_sycl_ptr_annotation(
+        ptrGlobalAS, "sycl-prefetch-hint", static_cast<int>(cache_level::L1));
+  } else {
+    auto prop = properties.template get_property<prefetch_hint_key>();
+    ptrAnnotated = __builtin_intel_sycl_ptr_annotation(
+        ptrGlobalAS, PropertyMetaInfo<decltype(prop)>::name,
+        PropertyMetaInfo<decltype(prop)>::value);
+  }
+  __spirv_ocl_prefetch(ptrAnnotated, bytes);
+#else
+  std::ignore = ptr;
+  std::ignore = bytes;
+  std::ignore = properties;
+#endif
+}
+
+template <typename Group, typename T, typename Properties>
+void joint_prefetch_impl(Group g, T *ptr, size_t bytes, Properties properties) {
+  // Although calling joint_prefetch is functionally equivalent to calling
+  // prefetch from every work-item in a group, native suppurt may be added to to
+  // issue cooperative prefetches more efficiently on some hardware.
+  std::ignore = g;
+  prefetch_impl(ptr, bytes, properties);
+}
+} // namespace detail
+
+template <typename Properties = empty_properties_t>
+void prefetch(void *ptr, Properties properties = {}) {
+  detail::prefetch_impl(ptr, 1, properties);
+}
+
+template <typename Properties = empty_properties_t>
+void prefetch(void *ptr, size_t bytes, Properties properties = {}) {
+  detail::prefetch_impl(ptr, bytes, properties);
+}
+
+template <typename T, typename Properties = empty_properties_t>
+void prefetch(T *ptr, Properties properties = {}) {
+  detail::prefetch_impl(ptr, sizeof(T), properties);
+}
+
+template <typename T, typename Properties = empty_properties_t>
+void prefetch(T *ptr, size_t count, Properties properties = {}) {
+  detail::prefetch_impl(ptr, count * sizeof(T), properties);
+}
+
+template <access::address_space AddressSpace, access::decorated IsDecorated,
+          typename Properties = empty_properties_t>
+std::enable_if_t<detail::check_prefetch_AS<AddressSpace>>
+prefetch(multi_ptr<void, AddressSpace, IsDecorated> ptr,
+         Properties properties = {}) {
+  detail::prefetch_impl(ptr.get(), 1, properties);
+}
+
+template <access::address_space AddressSpace, access::decorated IsDecorated,
+          typename Properties = empty_properties_t>
+std::enable_if_t<detail::check_prefetch_AS<AddressSpace>>
+prefetch(multi_ptr<void, AddressSpace, IsDecorated> ptr, size_t bytes,
+         Properties properties = {}) {
+  detail::prefetch_impl(ptr.get(), bytes, properties);
+}
+
+template <typename T, access::address_space AddressSpace,
+          access::decorated IsDecorated,
+          typename Properties = empty_properties_t>
+std::enable_if_t<detail::check_prefetch_AS<AddressSpace>>
+prefetch(multi_ptr<T, AddressSpace, IsDecorated> ptr,
+         Properties properties = {}) {
+  detail::prefetch_impl(ptr.get(), sizeof(T), properties);
+}
+
+template <typename T, access::address_space AddressSpace,
+          access::decorated IsDecorated,
+          typename Properties = empty_properties_t>
+std::enable_if_t<detail::check_prefetch_AS<AddressSpace>>
+prefetch(multi_ptr<T, AddressSpace, IsDecorated> ptr, size_t count,
+         Properties properties = {}) {
+  detail::prefetch_impl(ptr.get(), count * sizeof(T), properties);
+}
+
+template <typename DataT, int Dimensions, access_mode AccessMode,
+          access::placeholder IsPlaceholder,
+          typename Properties = empty_properties_t>
+std::enable_if_t<detail::check_prefetch_acc_mode<AccessMode> &&
+                 (Dimensions > 0)>
+prefetch(
+    accessor<DataT, Dimensions, AccessMode, target::device, IsPlaceholder> acc,
+    id<Dimensions> offset, Properties properties = {}) {
+  detail::prefetch_impl(&acc[offset], sizeof(DataT), properties);
+}
+
+template <typename DataT, int Dimensions, access_mode AccessMode,
+          access::placeholder IsPlaceholder,
+          typename Properties = empty_properties_t>
+std::enable_if_t<detail::check_prefetch_acc_mode<AccessMode> &&
+                 (Dimensions > 0)>
+prefetch(
+    accessor<DataT, Dimensions, AccessMode, target::device, IsPlaceholder> acc,
+    size_t offset, size_t count, Properties properties = {}) {
+  detail::prefetch_impl(&acc[offset], count * sizeof(DataT), properties);
+}
+
+template <typename Group, typename Properties = empty_properties_t>
+std::enable_if_t<sycl::is_group_v<std::decay_t<Group>>>
+joint_prefetch(Group g, void *ptr, Properties properties = {}) {
+  detail::joint_prefetch_impl(g, ptr, 1, properties);
+}
+
+template <typename Group, typename Properties = empty_properties_t>
+std::enable_if_t<sycl::is_group_v<std::decay_t<Group>>>
+joint_prefetch(Group g, void *ptr, size_t bytes, Properties properties = {}) {
+  detail::joint_prefetch_impl(g, ptr, bytes, properties);
+}
+
+template <typename Group, typename T, typename Properties = empty_properties_t>
+std::enable_if_t<sycl::is_group_v<std::decay_t<Group>>>
+joint_prefetch(Group g, T *ptr, Properties properties = {}) {
+  detail::joint_prefetch_impl(g, ptr, sizeof(T), properties);
+}
+
+template <typename Group, typename T, typename Properties = empty_properties_t>
+std::enable_if_t<sycl::is_group_v<std::decay_t<Group>>>
+joint_prefetch(Group g, T *ptr, size_t count, Properties properties = {}) {
+  detail::joint_prefetch_impl(g, ptr, count * sizeof(T), properties);
+}
+
+template <typename Group, access::address_space AddressSpace,
+          access::decorated IsDecorated,
+          typename Properties = empty_properties_t>
+std::enable_if_t<detail::check_prefetch_AS<AddressSpace> &&
+                 sycl::is_group_v<std::decay_t<Group>>>
+joint_prefetch(Group g, multi_ptr<void, AddressSpace, IsDecorated> ptr,
+               Properties properties = {}) {
+  detail::joint_prefetch_impl(g, ptr.get(), 1, properties);
+}
+
+template <typename Group, access::address_space AddressSpace,
+          access::decorated IsDecorated,
+          typename Properties = empty_properties_t>
+std::enable_if_t<detail::check_prefetch_AS<AddressSpace> &&
+                 sycl::is_group_v<std::decay_t<Group>>>
+joint_prefetch(Group g, multi_ptr<void, AddressSpace, IsDecorated> ptr,
+               size_t bytes, Properties properties = {}) {
+  detail::joint_prefetch_impl(g, ptr.get(), bytes, properties);
+}
+
+template <typename Group, typename T, access::address_space AddressSpace,
+          access::decorated IsDecorated,
+          typename Properties = empty_properties_t>
+std::enable_if_t<detail::check_prefetch_AS<AddressSpace> &&
+                 sycl::is_group_v<std::decay_t<Group>>>
+joint_prefetch(Group g, multi_ptr<T, AddressSpace, IsDecorated> ptr,
+               Properties properties = {}) {
+  detail::joint_prefetch_impl(g, ptr.get(), sizeof(T), properties);
+}
+
+template <typename Group, typename T, access::address_space AddressSpace,
+          access::decorated IsDecorated,
+          typename Properties = empty_properties_t>
+std::enable_if_t<detail::check_prefetch_AS<AddressSpace> &&
+                 sycl::is_group_v<std::decay_t<Group>>>
+joint_prefetch(Group g, multi_ptr<T, AddressSpace, IsDecorated> ptr,
+               size_t count, Properties properties = {}) {
+  detail::joint_prefetch_impl(g, ptr.get(), count * sizeof(T), properties);
+}
+
+template <typename Group, typename DataT, int Dimensions,
+          access_mode AccessMode, access::placeholder IsPlaceholder,
+          typename Properties = empty_properties_t>
+std::enable_if_t<detail::check_prefetch_acc_mode<AccessMode> &&
+                 (Dimensions > 0) && sycl::is_group_v<std::decay_t<Group>>>
+joint_prefetch(
+    Group g,
+    accessor<DataT, Dimensions, AccessMode, target::device, IsPlaceholder> acc,
+    size_t offset, Properties properties = {}) {
+  detail::joint_prefetch_impl(g, &acc[offset], sizeof(DataT), properties);
+}
+
+template <typename Group, typename DataT, int Dimensions,
+          access_mode AccessMode, access::placeholder IsPlaceholder,
+          typename Properties = empty_properties_t>
+std::enable_if_t<detail::check_prefetch_acc_mode<AccessMode> &&
+                 (Dimensions > 0) && sycl::is_group_v<std::decay_t<Group>>>
+joint_prefetch(
+    Group g,
+    accessor<DataT, Dimensions, AccessMode, target::device, IsPlaceholder> acc,
+    size_t offset, size_t count, Properties properties = {}) {
+  detail::joint_prefetch_impl(g, &acc[offset], count * sizeof(DataT),
+                              properties);
+}
+
+} // namespace ext::oneapi::experimental
+} // namespace _V1
+} // namespace sycl
diff --git a/sycl/include/sycl/sycl.hpp b/sycl/include/sycl/sycl.hpp
index 5b55801796e01..ca663c981c5e4 100644
--- a/sycl/include/sycl/sycl.hpp
+++ b/sycl/include/sycl/sycl.hpp
@@ -82,6 +82,7 @@
 #include <sycl/ext/oneapi/experimental/cuda/barrier.hpp>
 #include <sycl/ext/oneapi/experimental/fixed_size_group.hpp>
 #include <sycl/ext/oneapi/experimental/opportunistic_group.hpp>
+#include <sycl/ext/oneapi/experimental/prefetch.hpp>
 #include <sycl/ext/oneapi/experimental/tangle_group.hpp>
 #include <sycl/ext/oneapi/filter_selector.hpp>
 #include <sycl/ext/oneapi/functional.hpp>
diff --git a/sycl/source/feature_test.hpp.in b/sycl/source/feature_test.hpp.in
index 9ef771010320b..f0d686a889fe9 100644
--- a/sycl/source/feature_test.hpp.in
+++ b/sycl/source/feature_test.hpp.in
@@ -90,6 +90,7 @@ inline namespace _V1 {
 #define SYCL_EXT_CODEPLAY_MAX_REGISTERS_PER_WORK_GROUP_QUERY 1
 #define SYCL_EXT_ONEAPI_DEVICE_GLOBAL 1
 #define SYCL_EXT_INTEL_QUEUE_IMMEDIATE_COMMAND_LIST 1
+#define SYCL_EXT_ONEAPI_PREFETCH 1
 
 #ifndef __has_include
 #define __has_include(x) 0
diff --git a/sycl/test/extensions/prefetch.cpp b/sycl/test/extensions/prefetch.cpp
new file mode 100644
index 0000000000000..56fd678f83336
--- /dev/null
+++ b/sycl/test/extensions/prefetch.cpp
@@ -0,0 +1,62 @@
+// RUN: %clangxx -fsycl -fsyntax-only %s
+
+#include <sycl/sycl.hpp>
+
+int data[] = {0, 1, 2, 3};
+
+int main() {
+  namespace syclex = sycl::ext::oneapi::experimental;
+  void *dataPtrVoid = data;
+  int *dataPtrInt = data;
+  auto prop = syclex::properties{syclex::prefetch_hint_L1};
+
+  {
+    sycl::buffer<int, 1> buf(data, 4);
+    sycl::queue q;
+    q.submit([&](sycl::handler &h) {
+      auto acc = buf.get_access<sycl::access_mode::read>(h);
+      h.parallel_for<class Kernel>(
+          sycl::nd_range<1>(1, 1), ([=](sycl::nd_item<1> index) {
+            syclex::prefetch(dataPtrVoid, prop);
+            syclex::prefetch(dataPtrVoid, 16, prop);
+
+            syclex::prefetch(dataPtrInt, prop);
+            syclex::prefetch(dataPtrInt, 4, prop);
+
+            auto mPtrVoid = sycl::address_space_cast<
+                sycl::access::address_space::global_space,
+                sycl::access::decorated::yes>(dataPtrVoid);
+            syclex::prefetch(mPtrVoid, prop);
+            syclex::prefetch(mPtrVoid, 16, prop);
+
+            auto mPtrInt = sycl::address_space_cast<
+                sycl::access::address_space::global_space,
+                sycl::access::decorated::yes>(dataPtrInt);
+            syclex::prefetch(mPtrInt, prop);
+            syclex::prefetch(mPtrInt, 8, prop);
+
+            syclex::prefetch(acc, sycl::id(0), prop);
+            syclex::prefetch(acc, sycl::id(0), 4, prop);
+
+            auto g = index.get_group();
+            syclex::joint_prefetch(g, dataPtrVoid, prop);
+            syclex::joint_prefetch(g, dataPtrVoid, 16, prop);
+
+            syclex::joint_prefetch(g, dataPtrInt, prop);
+            syclex::joint_prefetch(g, dataPtrInt, 4, prop);
+
+            syclex::joint_prefetch(g, mPtrVoid, prop);
+            syclex::joint_prefetch(g, mPtrVoid, 16, prop);
+
+            syclex::joint_prefetch(g, mPtrInt, prop);
+            syclex::joint_prefetch(g, mPtrInt, 8, prop);
+
+            syclex::joint_prefetch(g, acc, sycl::id(0), prop);
+            syclex::joint_prefetch(g, acc, sycl::id(0), 4, prop);
+          }));
+    });
+    q.wait();
+  }
+
+  return 0;
+}