-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
894 additions
and
109 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
#include "ccv_nnc_mfa.hpp" | ||
#include "ccv_nnc_mfa_hash.hpp" | ||
#include "v2/GeluDescriptor.hpp" | ||
#include "v2/GeluKernel.hpp" | ||
#include <simd/simd.h> | ||
using namespace ccv::nnc; | ||
|
||
#include <string> | ||
|
||
// MARK: - C | ||
|
||
// Preparation hook for the GELU operation. It is intentionally a no-op: both
// arguments are ignored. The pipeline lookup is performed by
// ccv_nnc_mfa_encode_gelu through the context's shader cache instead.
void ccv_nnc_mfa_prepare_gelu(mfa::context* context, ccv_nnc_mfa_gelu_params_t params)
{
  // Nothing to prepare; explicitly discard the parameters.
  static_cast<void>(context);
  static_cast<void>(params);
}
|
||
// Encodes a single GELU dispatch into `command_batch`.
//
// context        - MFA context; provides the shader cache (`v2_cache`) and the
//                  Metal device used to look up / build the pipeline.
// params         - gradient / tanh flags, data type and length of the op.
// command_batch  - supplies the compute encoder and receives the finished
//                  command.
// tensors        - nullptr-terminated array of buffers. Exactly 3 entries are
//                  required when `params.gradient` is set, otherwise exactly 2.
//                  The last entry is the one declared as written; it may alias
//                  one of the inputs.
// tensor_offsets - offset handed to `setBuffer` for the tensor at the same
//                  index.
//
// Preconditions (checked with CCV_NNC_MFA_PRECONDITION): the tensor count
// matches the mode, and the dispatch contains at least one threadgroup (i.e.
// `params.length` is non-zero).
void ccv_nnc_mfa_encode_gelu(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_gelu_params_t params, mtl_command_batch_t* command_batch, mtl_buffer_t** tensors, size_t* tensor_offsets)
{
  auto encoder = command_batch->startCommand();

  // Bind every buffer up to the nullptr terminator. The position in the array
  // doubles as the buffer binding index.
  int num_tensors = 0;
  while (tensors[num_tensors] != nullptr) {
    encoder->setBuffer(tensors[num_tensors], tensor_offsets[num_tensors], NS::UInteger(num_tensors));
    num_tensors += 1;
  }
  if (params.gradient) {
    CCV_NNC_MFA_PRECONDITION(num_tensors == 3);
  } else {
    CCV_NNC_MFA_PRECONDITION(num_tensors == 2);
  }

  GeluDescriptor descriptor;
  descriptor.gradient = params.gradient ? 1 : 0;
  descriptor.tanh = params.tanh ? 1 : 0;
  // Anything that is not MTL::DataTypeFloat is treated as FP16.
  descriptor.memoryPrecision = (params.data_type == MTL::DataTypeFloat) ? GEMMOperandPrecision::FP32 : GEMMOperandPrecision::FP16;
  descriptor.length = params.length;

  // Pick the dispatch variant from the divisibility of the length:
  //   0 - multiple of 4 * 256, 1 - multiple of 4, 2 - anything else.
  if (params.length % (4 * 256) == 0) {
    descriptor.value = 0;
  } else if (params.length % 4 == 0) {
    descriptor.value = 1;
  } else {
    descriptor.value = 2;
  }

  // The cache lookup runs inside its own autorelease pool, which is drained
  // as soon as the lookup returns.
  auto pool = NS::AutoreleasePool::alloc()->init();
  auto &shaderCache = context->v2_cache;
  DeviceProperties dprops = DeviceProperties();
  auto pipelineValue = shaderCache.findKernel<GeluKernel, GeluDescriptor, GeluKernelDescriptor>(descriptor, context->device.get(), dprops);
  pool->drain();
  auto kernel = pipelineValue->kernel;
  auto pipeline = pipelineValue->pipeline;

  encoder->setComputePipelineState(pipeline.get());

  // Declare resource usage. When the output buffer aliases an input, that
  // buffer is declared once as read + write instead of twice.
  if (params.gradient) {
    if (tensors[0] == tensors[2]) {
      encoder->useResource(tensors[0], MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
      encoder->useResource(tensors[1], MTL::ResourceUsageRead);
    } else if (tensors[1] == tensors[2]) {
      encoder->useResource(tensors[0], MTL::ResourceUsageRead);
      encoder->useResource(tensors[1], MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
    } else {
      encoder->useResource(tensors[0], MTL::ResourceUsageRead);
      encoder->useResource(tensors[1], MTL::ResourceUsageRead);
      encoder->useResource(tensors[2], MTL::ResourceUsageWrite);
    }
  } else {
    if (tensors[0] == tensors[1]) {
      encoder->useResource(tensors[0], MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
    } else {
      encoder->useResource(tensors[0], MTL::ResourceUsageRead);
      encoder->useResource(tensors[1], MTL::ResourceUsageWrite);
    }
  }

  // ceil(length / 256), computed without `length + 255` so that lengths close
  // to UINT32_MAX cannot wrap around to a tiny block count.
  const NS::UInteger num_blocks = NS::UInteger(params.length / 256) + (params.length % 256 != 0 ? 1 : 0);
  MTL::Size gridSize = MTL::Size(num_blocks, 1, 1);
  // Height and depth are the literal 1 above; width is the only dimension
  // that can be zero (when length == 0), so that is the one to validate.
  CCV_NNC_MFA_PRECONDITION(gridSize.width > 0);
  encoder->dispatchThreadgroups(gridSize, kernel->threadgroupSize);

  command_batch->finishCommand(encoder);
}
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
#ifndef GUARD_ccv_nnc_mfa_gelu_hpp | ||
#define GUARD_ccv_nnc_mfa_gelu_hpp | ||
|
||
// Parameters describing one GELU dispatch. Declared outside the __cplusplus
// guard, so the layout is visible to both C and C++ callers.
typedef struct { | ||
// Non-zero selects the gradient variant; the encoder then requires three
// tensors instead of two.
uint8_t gradient; | ||
// Non-zero is forwarded as the descriptor's `tanh` flag.
// NOTE(review): presumably selects the tanh approximation of GELU - confirm
// against the kernel source.
uint8_t tanh; | ||
// Compared against MTL::DataTypeFloat by the encoder: equal means FP32 memory
// precision, any other value is treated as FP16.
uint64_t data_type; | ||
// Forwarded as the descriptor length and used to size the dispatch.
// NOTE(review): presumably the element count - confirm with the caller.
uint32_t length; | ||
} ccv_nnc_mfa_gelu_params_t; | ||
|
||
#ifdef __cplusplus | ||
#include "nnc/mfa/3rdparty/metal-cpp/Dispatch.hpp" | ||
#include "nnc/mfa/3rdparty/metal-cpp/Metal.hpp" | ||
#include <simd/simd.h> | ||
|
||
namespace ccv { | ||
namespace nnc { | ||
namespace mfa { | ||
namespace gelu { | ||
|
||
// Key type declared for the gelu namespace, constructible from the C params
// struct. Only the declaration is visible here; the constructor is defined
// elsewhere.
// NOTE(review): ccv_nnc_mfa_encode_gelu keys its cache lookup on
// GeluDescriptor instead, and the params struct carries no stride or dim
// information - confirm whether this class is still used.
class hash { | ||
public: | ||
uint64_t data_type; | ||
// NOTE(review): stride / dim semantics are not established by any visible
// code - confirm before relying on them.
uint32_t astride[3]; | ||
uint32_t bstride[3]; | ||
uint32_t cstride[3]; | ||
uint32_t dim[4]; | ||
|
||
// Builds the key from the C params struct (definition not visible here).
hash(ccv_nnc_mfa_gelu_params_t); | ||
}; | ||
|
||
// Bundles a compute pipeline state with the grid and threadgroup sizes.
// Only the declaration is visible here; the constructor is defined elsewhere.
// NOTE(review): ccv_nnc_mfa_encode_gelu obtains its pipeline from the v2
// shader cache rather than from this class - confirm whether it is still used.
class pipeline { | ||
public: | ||
// Shared (reference-counted) handle to the compute pipeline state.
NS::SharedPtr<MTL::ComputePipelineState> gelu_pso; | ||
|
||
MTL::Size grid_size; | ||
MTL::Size group_size; | ||
|
||
// Constructs the pipeline for `hash` using `context` (definition not visible
// here).
pipeline(context* context, hash hash); | ||
}; | ||
|
||
} // namespace gelu | ||
} // namespace mfa | ||
} // namespace nnc | ||
} // namespace ccv | ||
|
||
extern "C" { | ||
#endif // __cplusplus | ||
|
||
// Preparation hook for GELU; the current implementation does nothing.
void ccv_nnc_mfa_prepare_gelu(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_gelu_params_t params); | ||
// Encodes one GELU dispatch into `command_batch`. `tensors` is a
// nullptr-terminated array (3 buffers when params.gradient is set, otherwise
// 2) and `tensor_offsets` holds the offset for the buffer at the same index.
void ccv_nnc_mfa_encode_gelu(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_gelu_params_t params, mtl_command_batch_t* command_batch, mtl_buffer_t** tensors, size_t* tensor_offsets); | ||
|
||
#ifdef __cplusplus | ||
} // extern "C" | ||
#endif // __cplusplus | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
#include "GeluDescriptor.hpp" | ||
#include "GeluKernel.hpp" | ||
#include "../ccv_nnc_mfa_hash.hpp" | ||
#include "../ccv_nnc_mfa_error.hpp" | ||
|
||
// Field-wise equality: two descriptors match only when the memory precision,
// the tanh and gradient flags, the dispatch variant (`value`) and the length
// are all identical.
bool GeluDescriptor::operator==(const GeluDescriptor& rhs) const {
  if (!(memoryPrecision == rhs.memoryPrecision)) {
    return false;
  }
  if (tanh != rhs.tanh || gradient != rhs.gradient) {
    return false;
  }
  return value == rhs.value && length == rhs.length;
}
|
||
// Hashes every field that GeluDescriptor::operator== compares, so descriptors
// that differ in any of them (including `gradient`, which was previously left
// out) do not systematically land on the same hash value.
std::size_t std::hash<GeluDescriptor>::operator()(const GeluDescriptor& hash) const noexcept {
  using namespace ccv::nnc::mfa::hash;
  std::size_t seed = 0;
  combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.memoryPrecision.value, (unsigned int)hash.value }));
  // `tanh` and `gradient` are both uint8_t, so they share one 32-bit lane:
  // tanh in bits 0-7, gradient in bits 8-15.
  const unsigned int flags = static_cast<unsigned int>(hash.tanh) | (static_cast<unsigned int>(hash.gradient) << 8);
  combine_64(seed, pack_64(simd::uint2 { static_cast<unsigned int>(hash.length), flags }));
  return seed;
}
|
||
// Resolves this descriptor to a (kernel descriptor, pipeline value) pair.
//
// device       - Metal device used to create the kernel and the pipeline.
// dprops       - device properties; not consulted by this implementation.
// libraryCache - cache of GeluKernel objects keyed by kernel descriptor. A
//                kernel is created and stored here on a miss; the cache keeps
//                ownership, so callers must not delete the kernel.
//
// Returns the kernel descriptor together with a heap-allocated PipelineValue.
// The caller is expected to hand that pointer to the shader cache, which takes
// ownership of it.
std::pair<GeluKernelDescriptor, PipelineValue<GeluKernel> *> GeluDescriptor::findKernel(MTL::Device *const device, const DeviceProperties &dprops, std::unordered_map<GeluKernelDescriptor, std::unique_ptr<GeluKernel>> *const libraryCache) const noexcept {
  // The kernel key carries every field of this descriptor except `length`.
  GeluKernelDescriptor kernelKey;
  kernelKey.gradient = gradient;
  kernelKey.tanh = tanh;
  kernelKey.value = value;
  kernelKey.memoryPrecision = memoryPrecision;

  // Reuse the cached kernel when there is one; otherwise build it and let the
  // cache own it.
  auto cached = libraryCache->find(kernelKey);
  if (cached == libraryCache->end()) {
    cached = libraryCache->emplace(kernelKey, std::make_unique<GeluKernel>(kernelKey, device)).first;
  }
  GeluKernel* kernel = cached->second.get();

  // WARNING: the returned pipeline state is not retained here; the owner must
  // take ownership of it explicitly.
  auto buildPipeline =
  [this, device](MTL::Library* library) -> MTL::ComputePipelineState* {
    auto constants = NS::TransferPtr
    (MTL::FunctionConstantValues::alloc()->init());
    // Variant 0 needs no function constant. Otherwise constant 0 is the
    // length, divided by 4 only for the tanh kernel in variant 1.
    if (value != 0) {
      const bool quarterLength = tanh && value == 1;
      uint32_t count = quarterLength ? length / 4 : length;
      constants->setConstantValue(&count, MTL::DataTypeUInt, NS::UInteger(0));
    }

    NS::String* functionName = NS::String::string("gelu", NS::UTF8StringEncoding);
    NS::Error* error = nil;

    auto function = NS::TransferPtr
    (library->newFunction(functionName, constants.get(), &error));
    CCV_NNC_MFA_CHECK_ERROR(error);

    MTL::ComputePipelineState* state = device->newComputePipelineState(function.get(), &error);
    CCV_NNC_MFA_CHECK_ERROR(error);
    return state;
  };
  auto pipeline = NS::TransferPtr(buildPipeline(kernel->library.get()));

  // Hand back a heap-allocated value so the caller has to fetch the result
  // through the cache, which then owns it and keeps it alive.
  PipelineValue<GeluKernel>* result = new PipelineValue<GeluKernel> { kernel, pipeline };
  return std::make_pair(kernelKey, result);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
#ifndef MFA_GELUDESCRIPTOR_HPP_ | ||
#define MFA_GELUDESCRIPTOR_HPP_ | ||
|
||
#include <simd/simd.h> | ||
#include <utility> | ||
#include "PipelineValue.hpp" | ||
#include "DeviceProperties.hpp" | ||
#include "GEMMOperandPrecision.hpp" | ||
|
||
struct GeluKernelDescriptor { | ||
uint8_t gradient; | ||
uint8_t tanh; | ||
uint8_t value; | ||
GEMMOperandPrecision memoryPrecision; | ||
constexpr bool operator==(const GeluKernelDescriptor &rhs) const { return value == rhs.value && memoryPrecision == rhs.memoryPrecision && tanh == rhs.tanh && gradient == rhs.gradient; } | ||
}; | ||
|
||
template<> | ||
struct std::hash<GeluKernelDescriptor> | ||
{ | ||
std::size_t operator()(const GeluKernelDescriptor& hash) const noexcept { return (size_t)hash.value; } | ||
}; | ||
|
||
struct GeluKernel; | ||
|
||
// Full description of one GELU dispatch, used as the key for pipeline lookup.
// The members have no default initializers, so the user must assign every
// field before the descriptor is compared or hashed.
struct GeluDescriptor { | ||
// 1 for the gradient variant, 0 otherwise (as set by the encoder).
uint8_t gradient; | ||
|
||
// 1 when the params' `tanh` flag is set, 0 otherwise (as set by the encoder).
uint8_t tanh; | ||
|
||
// Dispatch variant chosen by the encoder from the length: 0 when it is a
// multiple of 4 * 256, 1 when it is a multiple of 4, 2 otherwise.
uint8_t value; | ||
|
||
// Precision of the tensors in memory (FP32 or FP16 in the encoder).
GEMMOperandPrecision memoryPrecision; | ||
|
||
// Length copied from the params.
// NOTE(review): presumably the element count - confirm with the caller.
uint32_t length; | ||
|
||
// Field-wise equality over all five members.
bool operator==(const GeluDescriptor& rhs) const; | ||
|
||
// Looks up (or creates) the GeluKernel in `libraryCache` and builds a compute
// pipeline for this descriptor. The returned PipelineValue is heap-allocated;
// the comments in the definition state that the cache takes ownership of it.
std::pair<GeluKernelDescriptor, PipelineValue<GeluKernel> *> findKernel(MTL::Device* const device, const DeviceProperties &dprops, std::unordered_map<GeluKernelDescriptor, std::unique_ptr<GeluKernel>> *const libraryCache) const noexcept; | ||
}; | ||
|
||
template<> | ||
struct std::hash<GeluDescriptor> | ||
{ | ||
std::size_t operator()(const GeluDescriptor& hash) const noexcept; | ||
}; | ||
|
||
#endif | ||
|
Oops, something went wrong.