-
Notifications
You must be signed in to change notification settings - Fork 744
/
kernel_impl.hpp
237 lines (200 loc) · 8.57 KB
/
kernel_impl.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
//==------- kernel_impl.hpp --- SYCL kernel implementation -----------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#pragma once
#include <detail/context_impl.hpp>
#include <detail/device_impl.hpp>
#include <detail/kernel_arg_mask.hpp>
#include <detail/kernel_info.hpp>
#include <sycl/detail/common.hpp>
#include <sycl/detail/ur.hpp>
#include <sycl/device.hpp>
#include <sycl/ext/oneapi/experimental/root_group.hpp>
#include <sycl/info/info_desc.hpp>

#include <cassert>
#include <cstdint>
#include <memory>
#include <mutex>
#include <type_traits>
namespace sycl {
inline namespace _V1 {
namespace detail {
// Forward declaration: the visible part of this header only needs the name of
// kernel_bundle_impl in order to form the shared_ptr alias below.
class kernel_bundle_impl;
// Shared-ownership pointer to a context_impl.
using ContextImplPtr = std::shared_ptr<context_impl>;
// Shared-ownership pointer to a kernel_bundle_impl.
using KernelBundleImplPtr = std::shared_ptr<kernel_bundle_impl>;
/// Implementation object behind a SYCL kernel.
///
/// Holds a Unified Runtime kernel handle together with the context, program,
/// device image and kernel bundle it is associated with, and exposes the
/// information queries that are implemented on top of that handle.
class kernel_impl {
public:
  /// Constructs a SYCL kernel instance from a UrKernel
  ///
  /// This constructor is used for plug-in interoperability. It always marks
  /// kernel as being created from source.
  ///
  /// \param Kernel is a valid UrKernel instance
  /// \param Context is a valid SYCL context
  /// \param KernelBundleImpl is a valid instance of kernel_bundle_impl
  /// \param ArgMask is an optional pointer to a kernel argument mask; it
  ///        defaults to nullptr when omitted
  kernel_impl(ur_kernel_handle_t Kernel, ContextImplPtr Context,
              KernelBundleImplPtr KernelBundleImpl,
              const KernelArgMask *ArgMask = nullptr);

  /// Constructs a SYCL kernel_impl instance from a SYCL device_image,
  /// kernel_bundle and UrKernel.
  ///
  /// \param Kernel is a valid UrKernel instance
  /// \param ContextImpl is a valid SYCL context
  /// \param DeviceImageImpl is the device image implementation associated
  ///        with the kernel
  /// \param KernelBundleImpl is a valid instance of kernel_bundle_impl
  /// \param ArgMask is a pointer to a kernel argument mask (may be nullptr)
  /// \param Program is the UR program handle associated with the kernel
  /// \param CacheMutex is a pointer to a mutex; presumably the one later
  ///        returned by getCacheMutex() (the constructor body is defined
  ///        elsewhere, so confirm there)
  kernel_impl(ur_kernel_handle_t Kernel, ContextImplPtr ContextImpl,
              DeviceImageImplPtr DeviceImageImpl,
              KernelBundleImplPtr KernelBundleImpl,
              const KernelArgMask *ArgMask, ur_program_handle_t Program,
              std::mutex *CacheMutex);

  // The object is deliberately non-copyable and non-movable: kernel_impl has
  // no need for copy or move operations. If they are ever added, they must
  // call urKernelRetain on MKernel.
  kernel_impl(const kernel_impl &) = delete;
  kernel_impl(kernel_impl &&) = delete;
  kernel_impl &operator=(const kernel_impl &) = delete;
  kernel_impl &operator=(kernel_impl &&) = delete;

  ~kernel_impl();

  /// Gets an OpenCL kernel handle
  ///
  /// Calls urKernelRetain on the underlying kernel, then queries its native
  /// handle and returns it cast to cl_kernel.
  ///
  /// NOTE(review): the extra retain presumably means the caller becomes
  /// responsible for releasing the returned handle - confirm against the
  /// interop contract.
  ///
  /// \return the native handle of the kernel as a cl_kernel
  cl_kernel get() const {
    getPlugin()->call(urKernelRetain, MKernel);
    ur_native_handle_t nativeHandle = 0;
    getPlugin()->call(urKernelGetNativeHandle, MKernel, &nativeHandle);
    return ur::cast<cl_kernel>(nativeHandle);
  }

  /// \return the plugin of the context this kernel belongs to.
  const PluginPtr &getPlugin() const { return MContext->getPlugin(); }

  /// Query information from the kernel object using the info::kernel_info
  /// descriptor.
  ///
  /// \return depends on information being queried.
  template <typename Param> typename Param::return_type get_info() const;

  /// Queries the kernel object for SYCL backend-specific information.
  ///
  /// \return depends on information being queried.
  template <typename Param>
  typename Param::return_type get_backend_info() const;

  /// Query device-specific information from a kernel object using the
  /// info::kernel_device_specific descriptor.
  ///
  /// \param Device is a valid SYCL device to query info for.
  /// \return depends on information being queried.
  template <typename Param>
  typename Param::return_type get_info(const device &Device) const;

  /// Query device-specific information from a kernel using the
  /// info::kernel_device_specific descriptor for a specific device and value.
  /// max_sub_group_size is the only valid descriptor for this function.
  ///
  /// \param Device is a valid SYCL device.
  /// \param WGSize is the work-group size the sub-group size is requested for.
  /// \return depends on information being queried.
  template <typename Param>
  typename Param::return_type get_info(const device &Device,
                                       const range<3> &WGSize) const;

  /// Query queue-specific information from the kernel.
  ///
  /// \param q is the SYCL queue the information is requested for.
  /// \return depends on information being queried.
  template <typename Param>
  typename Param::return_type ext_oneapi_get_info(const queue &q) const;

  /// Get a constant reference to a raw kernel object.
  ///
  /// \return a constant reference to a valid UrKernel instance with raw
  /// kernel object.
  const ur_kernel_handle_t &getHandleRef() const { return MKernel; }

  /// Check if kernel was created from a program that had been created from
  /// source.
  ///
  /// \return true if kernel was created from source.
  bool isCreatedFromSource() const;

  /// \return the device image implementation stored in this kernel.
  const DeviceImageImplPtr &getDeviceImage() const { return MDeviceImageImpl; }

  /// Gets the native handle of the kernel.
  ///
  /// The kernel is retained first only when the context's backend is OpenCL;
  /// for every other backend the handle is queried without an extra retain.
  /// NOTE(review): presumably OpenCL interop hands a reference to the caller -
  /// confirm.
  ///
  /// \return the native handle reported by urKernelGetNativeHandle.
  ur_native_handle_t getNative() const {
    const PluginPtr &Plugin = MContext->getPlugin();

    if (MContext->getBackend() == backend::opencl)
      Plugin->call(urKernelRetain, MKernel);

    ur_native_handle_t NativeKernel = 0;
    Plugin->call(urKernelGetNativeHandle, MKernel, &NativeKernel);

    return NativeKernel;
  }

  /// \return a copy of the shared pointer to the owning kernel bundle.
  KernelBundleImplPtr get_kernel_bundle() const { return MKernelBundleImpl; }

  /// \return the value of the interop flag (MIsInterop).
  bool isInterop() const { return MIsInterop; }

  /// \return the UR program handle stored in this kernel.
  ur_program_handle_t getProgramRef() const { return MProgram; }

  /// \return a copy of the shared pointer to the context implementation.
  ContextImplPtr getContextImplPtr() const { return MContext; }

  /// \return the per-kernel mutex MNoncacheableEnqueueMutex.
  std::mutex &getNoncacheableEnqueueMutex() {
    return MNoncacheableEnqueueMutex;
  }

  /// \return the kernel argument mask pointer (may be nullptr).
  const KernelArgMask *getKernelArgMask() const { return MKernelArgMaskPtr; }

  /// \return the cache mutex pointer (nullptr unless set by a constructor).
  std::mutex *getCacheMutex() const { return MCacheMutex; }

private:
  // NOTE: declaration order below is kept as-is because it determines member
  // initialization order.
  ur_kernel_handle_t MKernel = nullptr;
  const ContextImplPtr MContext;
  const ur_program_handle_t MProgram = nullptr;
  bool MCreatedFromSource = true;
  const DeviceImageImplPtr MDeviceImageImpl;
  const KernelBundleImplPtr MKernelBundleImpl;
  bool MIsInterop = false;
  std::mutex MNoncacheableEnqueueMutex;
  // Default-initialized like the other pointer members so the mask is never
  // read uninitialized, whichever constructor runs.
  const KernelArgMask *MKernelArgMaskPtr = nullptr;
  std::mutex *MCacheMutex = nullptr;

  // Used by get_info(const device &) to validate global_work_size queries.
  bool isBuiltInKernel(const device &Device) const;
  // Invoked by get_info<info::kernel::num_args>() before the query is issued.
  void checkIfValidForNumArgsInfoQuery() const;
};
template <typename Param>
inline typename Param::return_type kernel_impl::get_info() const {
static_assert(is_kernel_info_desc<Param>::value,
"Invalid kernel information descriptor");
if constexpr (std::is_same_v<Param, info::kernel::num_args>)
checkIfValidForNumArgsInfoQuery();
return get_kernel_info<Param>(this->getHandleRef(), getPlugin());
}
/// Specialization for info::kernel::context: builds the SYCL context object
/// from the context implementation stored in this kernel.
template <>
inline context kernel_impl::get_info<info::kernel::context>() const {
  const ContextImplPtr &OwningContext = MContext;
  return createSyclObjFromImpl<context>(OwningContext);
}
template <typename Param>
inline typename Param::return_type
kernel_impl::get_info(const device &Device) const {
if constexpr (std::is_same_v<
Param, info::kernel_device_specific::global_work_size>) {
bool isDeviceCustom = Device.get_info<info::device::device_type>() ==
info::device_type::custom;
if (!isDeviceCustom && !isBuiltInKernel(Device))
throw exception(
sycl::make_error_code(errc::invalid),
"info::kernel_device_specific::global_work_size descriptor may only "
"be used if the device type is device_type::custom or if the kernel "
"is a built-in kernel.");
}
return get_kernel_device_specific_info<Param>(
this->getHandleRef(), getSyclObjImpl(Device)->getHandleRef(),
getPlugin());
}
/// Device-specific kernel information query that takes a work-group size as
/// additional input. Forwards the kernel handle, device handle, work-group
/// size and plugin to get_kernel_device_specific_info_with_input.
///
/// \param Device is the SYCL device to query info for.
/// \param WGSize is the work-group size passed along with the query.
/// \return depends on information being queried.
template <typename Param>
inline typename Param::return_type
kernel_impl::get_info(const device &Device,
                      const sycl::range<3> &WGSize) const {
  const auto &DeviceImpl = getSyclObjImpl(Device);
  const ur_kernel_handle_t &KernelHandle = getHandleRef();

  return get_kernel_device_specific_info_with_input<Param>(
      KernelHandle, DeviceImpl->getHandleRef(), WGSize, getPlugin());
}
/// Specialization for the max_num_work_group_sync queue-specific descriptor.
///
/// Reads max_work_group_size from the queue's device and passes it, together
/// with a dynamic shared memory size of 0, to
/// urKernelSuggestMaxCooperativeGroupCountExp; the group count written by that
/// call is returned.
///
/// \param Queue is the SYCL queue whose device is used for the query.
/// \return the group count reported by the UR call.
template <>
inline typename ext::oneapi::experimental::info::kernel_queue_specific::
    max_num_work_group_sync::return_type
    kernel_impl::ext_oneapi_get_info<
        ext::oneapi::experimental::info::kernel_queue_specific::
            max_num_work_group_sync>(const queue &Queue) const {
  const auto TargetDevice = Queue.get_device();
  const auto WorkGroupSizeLimit =
      TargetDevice.get_info<info::device::max_work_group_size>();

  uint32_t SuggestedGroupCount = 0;
  getPlugin()->call(urKernelSuggestMaxCooperativeGroupCountExp, getHandleRef(),
                    WorkGroupSizeLimit, /* DynamicSharedMemorySize */ 0,
                    &SuggestedGroupCount);
  return SuggestedGroupCount;
}
} // namespace detail
} // namespace _V1
} // namespace sycl