Skip to content

Commit

Permalink
Add gelu kernel.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Nov 11, 2024
1 parent 4bad02e commit aad0be6
Show file tree
Hide file tree
Showing 10 changed files with 894 additions and 109 deletions.
338 changes: 238 additions & 100 deletions lib/nnc/cmd/gelu/mps/ccv_nnc_gelu_mps.m

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions lib/nnc/mfa/ccv_nnc_mfa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "ccv_nnc_mfa_depalettize.hpp"
#include "ccv_nnc_mfa_adam.hpp"
#include "ccv_nnc_mfa_cmul.hpp"
#include "ccv_nnc_mfa_gelu.hpp"
#include "ccv_nnc_mfa_gemm.hpp"
#include "ccv_nnc_mfa_gemv.hpp"
#include "ccv_nnc_mfa_cast.hpp"
Expand Down
84 changes: 84 additions & 0 deletions lib/nnc/mfa/ccv_nnc_mfa_gelu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#include "ccv_nnc_mfa.hpp"
#include "ccv_nnc_mfa_hash.hpp"
#include "v2/GeluDescriptor.hpp"
#include "v2/GeluKernel.hpp"
#include <simd/simd.h>
using namespace ccv::nnc;

#include <string>

// MARK: - C

// Preparation hook for the GELU kernel.
// The body is empty: neither `context` nor `params` is read, and nothing is
// compiled or cached here.
void ccv_nnc_mfa_prepare_gelu(mfa::context* context, ccv_nnc_mfa_gelu_params_t params)
{
// Intentionally a no-op for now.
}

// Encodes one GELU dispatch (forward or gradient, chosen by `params.gradient`)
// into `command_batch`.
//
// context        - supplies the shader cache (`v2_cache`) and the Metal device.
// params         - gradient/tanh flags, storage data type and `length`.
// command_batch  - the batch the command is started on and finished into.
// tensors        - nullptr-terminated array of buffers; entry i is bound at
//                  buffer index i. The forward pass requires exactly 2 entries
//                  and the gradient pass exactly 3; in both cases the last
//                  entry is the one declared as written.
// tensor_offsets - offset passed to setBuffer for the matching `tensors` entry.
void ccv_nnc_mfa_encode_gelu(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_gelu_params_t params, mtl_command_batch_t* command_batch, mtl_buffer_t** tensors, size_t* tensor_offsets)
{
  auto encoder = command_batch->startCommand();

  // Bind every buffer up to the nullptr terminator, counting them as we go.
  int num_tensors = 0;
  while (tensors[num_tensors] != nullptr) {
    encoder->setBuffer(tensors[num_tensors], tensor_offsets[num_tensors], NS::UInteger(num_tensors));
    num_tensors += 1;
  }
  if (params.gradient) {
    CCV_NNC_MFA_PRECONDITION(num_tensors == 3);
  } else {
    CCV_NNC_MFA_PRECONDITION(num_tensors == 2);
  }

  GeluDescriptor descriptor;
  descriptor.gradient = params.gradient ? 1 : 0;
  descriptor.tanh = params.tanh ? 1 : 0;
  // Anything that is not MTL::DataTypeFloat is treated as FP16 storage.
  descriptor.memoryPrecision = (params.data_type == MTL::DataTypeFloat) ? GEMMOperandPrecision::FP32 : GEMMOperandPrecision::FP16;
  descriptor.length = params.length;

  // Pick the kernel variant from the divisibility of `length`:
  //   0 - multiple of 4 * 256, 1 - multiple of 4, 2 - anything else.
  // NOTE(review): presumably these select progressively less vectorized /
  // more bounds-checked shader variants - confirm against GeluKernel.
  if (params.length % (4 * 256) == 0) {
    descriptor.value = 0;
  } else if (params.length % 4 == 0) {
    descriptor.value = 1;
  } else {
    descriptor.value = 2;
  }

  // Look up the pipeline inside a scoped autorelease pool, which is drained
  // as soon as the lookup returns.
  auto pool = NS::AutoreleasePool::alloc()->init();
  auto &shaderCache = context->v2_cache;
  DeviceProperties dprops = DeviceProperties();
  auto pipelineValue = shaderCache.findKernel<GeluKernel, GeluDescriptor, GeluKernelDescriptor>(descriptor, context->device.get(), dprops);
  pool->drain();
  auto kernel = pipelineValue->kernel;
  auto pipeline = pipelineValue->pipeline;

  encoder->setComputePipelineState(pipeline.get());

  // Declare resource usage. The output is the last buffer; when it aliases an
  // input (in-place operation) that buffer is declared read + write once
  // instead of being declared twice.
  if (params.gradient) {
    if (tensors[0] == tensors[2]) {
      encoder->useResource(tensors[0], MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
      encoder->useResource(tensors[1], MTL::ResourceUsageRead);
    } else if (tensors[1] == tensors[2]) {
      encoder->useResource(tensors[0], MTL::ResourceUsageRead);
      encoder->useResource(tensors[1], MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
    } else {
      encoder->useResource(tensors[0], MTL::ResourceUsageRead);
      encoder->useResource(tensors[1], MTL::ResourceUsageRead);
      encoder->useResource(tensors[2], MTL::ResourceUsageWrite);
    }
  } else {
    if (tensors[0] == tensors[1]) {
      encoder->useResource(tensors[0], MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
    } else {
      encoder->useResource(tensors[0], MTL::ResourceUsageRead);
      encoder->useResource(tensors[1], MTL::ResourceUsageWrite);
    }
  }

  // ceil(length / 256) threadgroups. The sum is formed in 64 bits: in 32-bit
  // unsigned arithmetic `length + 255` wraps for lengths within 255 of
  // UINT32_MAX, which would silently dispatch far too few threadgroups.
  const uint64_t num_blocks = (static_cast<uint64_t>(params.length) + 255) / 256;
  MTL::Size gridSize = MTL::Size(NS::UInteger(num_blocks), 1, 1);
  // NOTE(review): depth is constructed as 1 just above, so this check cannot
  // fail; width is the only dimension that varies with `length`.
  CCV_NNC_MFA_PRECONDITION(gridSize.depth > 0);
  encoder->dispatchThreadgroups(gridSize, kernel->threadgroupSize);

  command_batch->finishCommand(encoder);
}

57 changes: 57 additions & 0 deletions lib/nnc/mfa/ccv_nnc_mfa_gelu.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#ifndef GUARD_ccv_nnc_mfa_gelu_hpp
#define GUARD_ccv_nnc_mfa_gelu_hpp

// Parameters for one GELU dispatch. Declared outside the __cplusplus guard,
// so it is visible to both C and C++ translation units.
typedef struct {
// Non-zero selects the gradient pass (3 buffers) instead of the forward pass (2 buffers).
uint8_t gradient;
// Non-zero sets the descriptor's `tanh` flag.
// NOTE(review): presumably the tanh approximation of GELU - confirm against the kernel source.
uint8_t tanh;
// Compared against MTL::DataTypeFloat by the encoder: equal means FP32
// storage, any other value is treated as FP16.
uint64_t data_type;
// The encoder dispatches ceil(length / 256) threadgroups; presumably the
// number of elements to process - confirm against the caller.
uint32_t length;
} ccv_nnc_mfa_gelu_params_t;

#ifdef __cplusplus
#include "nnc/mfa/3rdparty/metal-cpp/Dispatch.hpp"
#include "nnc/mfa/3rdparty/metal-cpp/Metal.hpp"
#include <simd/simd.h>

namespace ccv {
namespace nnc {
namespace mfa {
namespace gelu {

// Appears to be a hash-key type for GELU pipelines, built from the C params
// struct.
// NOTE(review): only the constructor declaration is visible here; no
// definition for it appears in the ccv_nnc_mfa_gelu.cpp shown alongside this
// header, and the stride/dim fields are not touched by the encoder there -
// confirm whether this class is still used.
class hash {
public:
uint64_t data_type;
// Three-element stride arrays and a four-element dim array; their meaning
// is not established by any code visible here.
uint32_t astride[3];
uint32_t bstride[3];
uint32_t cstride[3];
uint32_t dim[4];

// Constructs the key from the GELU parameters (declaration only).
hash(ccv_nnc_mfa_gelu_params_t);
};

// Bundles a compute pipeline state with a grid size and a group size.
// NOTE(review): only the constructor declaration is visible here; the encoder
// in ccv_nnc_mfa_gelu.cpp obtains its pipeline from the v2 shader cache
// instead - confirm whether this class is still used.
class pipeline {
public:
// Shared (reference-counted) handle to the compute pipeline state.
NS::SharedPtr<MTL::ComputePipelineState> gelu_pso;

MTL::Size grid_size;
MTL::Size group_size;

// Constructs the pipeline from a context and a hash key (declaration only).
pipeline(context* context, hash hash);
};

} // namespace gelu
} // namespace mfa
} // namespace nnc
} // namespace ccv

extern "C" {
#endif // __cplusplus

// Preparation step for GELU. The current implementation does nothing.
void ccv_nnc_mfa_prepare_gelu(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_gelu_params_t params);
// Encodes one GELU dispatch into `command_batch`. `tensors` must end with a
// null entry and hold exactly 2 buffers for the forward pass or 3 for the
// gradient pass; `tensor_offsets[i]` is the offset used when binding
// `tensors[i]`.
void ccv_nnc_mfa_encode_gelu(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_gelu_params_t params, mtl_command_batch_t* command_batch, mtl_buffer_t** tensors, size_t* tensor_offsets);

#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus

#endif
2 changes: 1 addition & 1 deletion lib/nnc/mfa/makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ include ../../config.mk

CFLAGS := -std=c++17 -O3 -Wall -I"../../" $(CFLAGS)

SRCS := Metal.cpp ccv_nnc_mfa.cpp ccv_nnc_mfa_attention.cpp ccv_nnc_mfa_error.cpp ccv_nnc_mfa_gemm.cpp ccv_nnc_mfa_normalization.cpp ccv_nnc_mfa_depalettize.cpp ccv_nnc_mfa_adam.cpp ccv_nnc_mfa_cmul.cpp ccv_nnc_mfa_gemv.cpp ccv_nnc_mfa_cast.cpp ccv_nnc_mfa_add.cpp 3rdparty/metal-cpp/Dispatch.cpp v2/CodeWriter.cpp v2/GEMMDescriptor.cpp v2/GEMMKernelDescriptor.cpp v2/GEMMHeaders.cpp v2/GEMMKernel.cpp v2/AttentionDescriptor.cpp v2/AttentionKernelDescriptor.cpp v2/AttentionKernel.cpp v2/CMulDescriptor.cpp v2/CMulKernel.cpp
SRCS := Metal.cpp ccv_nnc_mfa.cpp ccv_nnc_mfa_attention.cpp ccv_nnc_mfa_error.cpp ccv_nnc_mfa_gemm.cpp ccv_nnc_mfa_normalization.cpp ccv_nnc_mfa_depalettize.cpp ccv_nnc_mfa_adam.cpp ccv_nnc_mfa_cmul.cpp ccv_nnc_mfa_gelu.cpp ccv_nnc_mfa_gemv.cpp ccv_nnc_mfa_cast.cpp ccv_nnc_mfa_add.cpp 3rdparty/metal-cpp/Dispatch.cpp v2/CodeWriter.cpp v2/GEMMDescriptor.cpp v2/GEMMKernelDescriptor.cpp v2/GEMMHeaders.cpp v2/GEMMKernel.cpp v2/AttentionDescriptor.cpp v2/AttentionKernelDescriptor.cpp v2/AttentionKernel.cpp v2/CMulDescriptor.cpp v2/CMulKernel.cpp v2/GeluDescriptor.cpp v2/GeluKernel.cpp

SRC_OBJS := $(patsubst %.c,%.o,$(patsubst %.cpp,%.o,$(SRCS)))

Expand Down
79 changes: 79 additions & 0 deletions lib/nnc/mfa/v2/GeluDescriptor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#include "GeluDescriptor.hpp"
#include "GeluKernel.hpp"
#include "../ccv_nnc_mfa_hash.hpp"
#include "../ccv_nnc_mfa_error.hpp"

// Field-by-field equality: two descriptors match only when gradient, tanh,
// value, length and memoryPrecision are all equal.
bool GeluDescriptor::operator==(const GeluDescriptor& rhs) const {
  // Reject on the plain integer fields first.
  if (gradient != rhs.gradient || tanh != rhs.tanh || value != rhs.value || length != rhs.length) {
    return false;
  }
  return memoryPrecision == rhs.memoryPrecision;
}

// Hashes a GeluDescriptor so it can key an unordered container.
//
// Every field compared by GeluDescriptor::operator== feeds the hash, including
// `gradient`; leaving it out would make a forward descriptor and a gradient
// descriptor that otherwise agree always collide.
std::size_t std::hash<GeluDescriptor>::operator()(const GeluDescriptor& hash) const noexcept {
  using namespace ccv::nnc::mfa::hash;
  std::size_t seed = 0;
  combine_64(seed, pack_64(simd::uint2 { static_cast<unsigned int>(hash.memoryPrecision.value), static_cast<unsigned int>(hash.value) }));
  // `tanh` and `gradient` are both 8-bit fields, so they share one 32-bit lane
  // without overlapping. With gradient == 0 the lane equals `tanh` alone.
  const unsigned int flags = static_cast<unsigned int>(hash.tanh) | (static_cast<unsigned int>(hash.gradient) << 8);
  combine_64(seed, pack_64(simd::uint2 { static_cast<unsigned int>(hash.length), flags }));
  return seed;
}

// Resolves this descriptor to a (kernel descriptor, pipeline value) pair.
//
// device       - used to construct a new GeluKernel and to build the pipeline.
// dprops       - not read by this function.
// libraryCache - maps kernel descriptors to the GeluKernel objects it owns; a
//                missing entry is created and inserted here.
//
// Returns the kernel descriptor together with a heap-allocated PipelineValue;
// the comment at the end of the body states that the cache is expected to take
// ownership of it.
std::pair<GeluKernelDescriptor, PipelineValue<GeluKernel> *> GeluDescriptor::findKernel(MTL::Device *const device, const DeviceProperties &dprops, std::unordered_map<GeluKernelDescriptor, std::unique_ptr<GeluKernel>> *const libraryCache) const noexcept {
  // Project this descriptor onto the fields that identify a kernel; `length`
  // is not part of the kernel descriptor.
  GeluKernelDescriptor kernelDesc;
  kernelDesc.memoryPrecision = memoryPrecision;
  kernelDesc.value = value;
  kernelDesc.tanh = tanh;
  kernelDesc.gradient = gradient;

  // Reuse the cached kernel when there is one, otherwise build it and hand it
  // to the cache. The caller is not responsible for calling 'delete' on this
  // pointer: the reference lives in 'libraryCache' and is released whenever
  // the shader cache itself is cleaned up.
  GeluKernel* kernel = nullptr;
  const auto cached = libraryCache->find(kernelDesc);
  if (cached == libraryCache->end()) {
    kernel = new GeluKernel(kernelDesc, device);
    (*libraryCache)[kernelDesc] = std::unique_ptr<GeluKernel>(kernel);
  } else {
    kernel = cached->second.get();
  }

  // WARNING: The owner must explicitly retain the compute pipeline.
  auto buildPipeline =
  [=](MTL::Library* library) -> MTL::ComputePipelineState* {
    // Function constant 0 carries an element count for every variant except
    // value == 0, which sets no constant. The tanh variant with value == 1
    // receives length / 4; every other case receives length unchanged.
    auto constants = NS::TransferPtr
    (MTL::FunctionConstantValues::alloc()->init());
    if (value != 0) {
      uint32_t count = (tanh && value == 1) ? length / 4 : length;
      constants->setConstantValue(&count, MTL::DataTypeUInt, NS::UInteger(0));
    }

    NS::Error* error = nil;
    NS::String* swiftName = NS::String::string("gelu", NS::UTF8StringEncoding);

    auto function = NS::TransferPtr
    (library->newFunction(swiftName, constants.get(), &error));
    CCV_NNC_MFA_CHECK_ERROR(error);

    auto pipeline = device->newComputePipelineState(function.get(), &error);
    CCV_NNC_MFA_CHECK_ERROR(error);
    return pipeline;
  };
  auto pipeline = NS::TransferPtr(buildPipeline(kernel->library.get()));

  // Force the user to retrieve the return value from the cache. We ensure
  // the cache takes ownership, and the pointer doesn't become a zombie
  // object.
  PipelineValue<GeluKernel>* output = new PipelineValue<GeluKernel> { kernel, pipeline };
  return { kernelDesc, output };
}
49 changes: 49 additions & 0 deletions lib/nnc/mfa/v2/GeluDescriptor.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#ifndef MFA_GELUDESCRIPTOR_HPP_
#define MFA_GELUDESCRIPTOR_HPP_

#include <simd/simd.h>
#include <utility>
#include "PipelineValue.hpp"
#include "DeviceProperties.hpp"
#include "GEMMOperandPrecision.hpp"

// Identifies one GELU kernel variant. It carries the same fields as
// GeluDescriptor except `length`.
struct GeluKernelDescriptor {
  uint8_t gradient;
  uint8_t tanh;
  uint8_t value;
  GEMMOperandPrecision memoryPrecision;

  // Field-by-field equality over all four members.
  constexpr bool operator==(const GeluKernelDescriptor &rhs) const {
    return gradient == rhs.gradient
      && tanh == rhs.tanh
      && value == rhs.value
      && memoryPrecision == rhs.memoryPrecision;
  }
};

template<>
struct std::hash<GeluKernelDescriptor>
{
std::size_t operator()(const GeluKernelDescriptor& hash) const noexcept { return (size_t)hash.value; }
};

struct GeluKernel;

// Full description of one GELU dispatch, used as the lookup key when the
// encoder asks the shader cache for a pipeline.
struct GeluDescriptor {
// 1 for the gradient pass, 0 for the forward pass (as set by the encoder).
uint8_t gradient;

// 1 when the params' `tanh` flag is non-zero, else 0 (as set by the encoder).
uint8_t tanh;

// Kernel variant chosen by the encoder from `length`:
// 0 - multiple of 4 * 256, 1 - multiple of 4, 2 - anything else.
uint8_t value;

// FP32 when the params' data type is MTL::DataTypeFloat, otherwise FP16.
GEMMOperandPrecision memoryPrecision;

// Copied from the params; findKernel passes it (or length / 4) to the shader
// as function constant 0 when `value` is non-zero.
uint32_t length;

// Field-by-field equality over all five members.
bool operator==(const GeluDescriptor& rhs) const;

// Finds or creates the GeluKernel for this descriptor in `libraryCache`,
// builds a compute pipeline from its library, and returns the kernel
// descriptor together with a heap-allocated PipelineValue.
std::pair<GeluKernelDescriptor, PipelineValue<GeluKernel> *> findKernel(MTL::Device* const device, const DeviceProperties &dprops, std::unordered_map<GeluKernelDescriptor, std::unique_ptr<GeluKernel>> *const libraryCache) const noexcept;
};

// std::hash specialization so GeluDescriptor can key unordered containers.
// Only declared here; the call operator is defined in GeluDescriptor.cpp.
template<>
struct std::hash<GeluDescriptor>
{
std::size_t operator()(const GeluDescriptor& hash) const noexcept;
};

#endif

Loading

0 comments on commit aad0be6

Please sign in to comment.