Add support for kernels (#42)
This pull request addresses issue #28 by implementing a new feature in
Pyccel that allows users to define custom GPU kernels. The syntax for
creating these kernels is inspired by Numba. Issue #45 also needs to be
fixed for testing purposes.

**Commit Summary**

- Introduced KernelCall class
- Added CUDA printer methods _print_KernelCall and _print_FunctionDef to
generate the corresponding CUDA representation for both kernel calls and
definitions
- Added IndexedFunctionCall, which represents an indexed function call
- Added CUDA module and cuda.synchronize()
- Fixed a bug found in the header: it did not import the necessary
header for the functions used

---------

Co-authored-by: EmilyBourne <[email protected]>
Co-authored-by: bauom <[email protected]>
Co-authored-by: Emily Bourne <[email protected]>
4 people committed Jul 26, 2024
1 parent 7ad90da commit b3de549
Showing 19 changed files with 599 additions and 9 deletions.
1 change: 1 addition & 0 deletions .dict_custom.txt
@@ -120,3 +120,4 @@ indexable
traceback
STC
gFTL
GPUs
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,8 @@ All notable changes to this project will be documented in this file.

- #32 : Add support for `nvcc` Compiler and `cuda` language as a possible option.
- #48 : Fix incorrect handling of imports in `cuda`.
- #42 : Add support for custom kernels in `cuda`.
- #42 : Add `cuda` module to Pyccel. Add support for the `cuda.synchronize` function.

## \[UNRELEASED\]

23 changes: 23 additions & 0 deletions docs/cuda.md
@@ -0,0 +1,23 @@
# Getting started with GPUs

Pyccel now supports NVIDIA CUDA, allowing users to accelerate numerical computations on GPUs. With Pyccel's high-level syntax and automatic code generation, CUDA kernels can be written directly in Python. This documentation provides a quick guide to enabling CUDA in Pyccel.

## CUDA Decorator

### kernel

The kernel decorator allows the user to declare a CUDA kernel. The kernel can be defined in Python, and the syntax is similar to that of Numba.

```python
from pyccel.decorators import kernel

@kernel
def my_kernel():
    pass

blockspergrid = 1
threadsperblock = 1
# Call your kernel function
my_kernel[blockspergrid, threadsperblock]()

```
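
The example above launches one block containing one thread. For larger problems the launch configuration is normally derived from the problem size. The following is a minimal sketch of the usual ceiling-division convention for choosing the number of blocks; `blocks_needed` is a hypothetical helper written for this illustration and is not part of Pyccel.

```python
def blocks_needed(n_elements, threads_per_block):
    """Return the smallest number of blocks whose threads cover n_elements."""
    if threads_per_block <= 0:
        raise ValueError("threads_per_block must be positive")
    # Ceiling division using integers only
    return (n_elements + threads_per_block - 1) // threads_per_block

threadsperblock = 256
blockspergrid = blocks_needed(1000, threadsperblock)
print(blockspergrid, threadsperblock)
```

The two values can then be used as the launch configuration in `my_kernel[blockspergrid, threadsperblock]()`. The last block may contain more threads than there are remaining elements.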
37 changes: 37 additions & 0 deletions pyccel/ast/core.py
@@ -73,6 +73,7 @@
    'If',
    'IfSection',
    'Import',
    'IndexedFunctionCall',
    'InProgram',
    'InlineFunctionDef',
    'Interface',
@@ -2065,6 +2066,42 @@ def _ignore(cls, c):
"""
return c is None or isinstance(c, (FunctionDef, *cls._ignored_types))

class IndexedFunctionCall(FunctionCall):
    """
    Represents an indexed function call in the code.

    Class representing indexed function calls, encapsulating all
    relevant information for such calls within the code base.

    Parameters
    ----------
    func : FunctionDef
        The function being called.

    args : iterable of FunctionCallArgument
        The arguments passed to the function.

    indexes : iterable of TypedAstNode
        The indexes of the function call.

    current_function : FunctionDef, optional
        The function where the call takes place.
    """
    __slots__ = ('_indexes',)
    _attribute_nodes = FunctionCall._attribute_nodes + ('_indexes',)
    def __init__(self, func, args, indexes, current_function = None):
        self._indexes = indexes
        super().__init__(func, args, current_function)

    @property
    def indexes(self):
        """
        Indexes of function call.

        Represents the indexes of the function call
        """
        return self._indexes

class ConstructorCall(FunctionCall):

"""
65 changes: 65 additions & 0 deletions pyccel/ast/cuda.py
@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
#------------------------------------------------------------------------------------------#
# This file is part of Pyccel which is released under MIT License. See the LICENSE file or #
# go to https://github.com/pyccel/pyccel/blob/master/LICENSE for full license details. #
#------------------------------------------------------------------------------------------#
"""
CUDA Module
This module provides a collection of classes and utilities for CUDA programming.
"""
from pyccel.ast.core import FunctionCall

__all__ = (
    'KernelCall',
)

class KernelCall(FunctionCall):
    """
    Represents a kernel function call in the code.

    The class serves as a representation of a kernel
    function call within the codebase.

    Parameters
    ----------
    func : FunctionDef
        The definition of the function being called.

    args : iterable of FunctionCallArgument
        The arguments passed to the function.

    num_blocks : TypedAstNode
        The number of blocks. These objects must have a primitive type of `PrimitiveIntegerType`.

    tp_block : TypedAstNode
        The number of threads per block. These objects must have a primitive type of `PrimitiveIntegerType`.

    current_function : FunctionDef, optional
        The function where the call takes place.
    """
    __slots__ = ('_num_blocks','_tp_block')
    _attribute_nodes = (*FunctionCall._attribute_nodes, '_num_blocks', '_tp_block')

    def __init__(self, func, args, num_blocks, tp_block, current_function = None):
        self._num_blocks = num_blocks
        self._tp_block = tp_block
        super().__init__(func, args, current_function)

    @property
    def num_blocks(self):
        """
        The number of blocks in the kernel being called.

        The number of blocks in the kernel being called.
        """
        return self._num_blocks

    @property
    def tp_block(self):
        """
        The number of threads per block.

        The number of threads per block.
        """
        return self._tp_block
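
A `KernelCall` carries three pieces of information: the function, its arguments, and the two-entry launch configuration written in square brackets. As a rough illustration of where those pieces sit in the Python syntax tree, the sketch below uses the standard-library `ast` module (Python 3.9 or later is assumed, for `ast.unparse`). `split_indexed_call` is a hypothetical helper for this illustration; it is not how Pyccel's own parser is implemented.

```python
import ast

def split_indexed_call(source):
    """Split 'f[i, j](args)' into (function name, index sources, argument count)."""
    call = ast.parse(source, mode="eval").body
    # An indexed call is a Call node whose callee is a Subscript node
    if not (isinstance(call, ast.Call) and isinstance(call.func, ast.Subscript)):
        raise ValueError("expected an indexed function call such as f[a, b](x)")
    subscript = call.func
    if not isinstance(subscript.value, ast.Name):
        raise ValueError("expected a plain function name before the brackets")
    index = subscript.slice
    # f[a, b] stores a Tuple; f[a] stores the single expression directly
    elements = index.elts if isinstance(index, ast.Tuple) else [index]
    return subscript.value.id, [ast.unparse(e) for e in elements], len(call.args)

print(split_indexed_call("my_kernel[blockspergrid, threadsperblock](x)"))
```

In this picture the bracketed entries correspond to `num_blocks` and `tp_block`, and the parenthesised entries to `args`.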

42 changes: 42 additions & 0 deletions pyccel/ast/cudaext.py
@@ -0,0 +1,42 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
#------------------------------------------------------------------------------------------#
# This file is part of Pyccel which is released under MIT License. See the LICENSE file or #
# go to https://github.com/pyccel/pyccel/blob/master/LICENSE for full license details. #
#------------------------------------------------------------------------------------------#
"""
CUDA Extension Module
Provides CUDA functionality for code generation.
"""
from .internals import PyccelFunction

from .datatypes import VoidType
from .core import Module, PyccelFunctionDef

__all__ = (
    'CudaSynchronize',
)

class CudaSynchronize(PyccelFunction):
    """
    Represents a call to Cuda.synchronize for code generation.

    This class serves as a representation of the Cuda.synchronize method.
    """
    __slots__ = ()
    _attribute_nodes = ()
    _shape = None
    _class_type = VoidType()
    def __init__(self):
        super().__init__()

cuda_funcs = {
    'synchronize' : PyccelFunctionDef('synchronize' , CudaSynchronize),
}

cuda_mod = Module('cuda',
    variables=[],
    funcs=cuda_funcs.values(),
    imports=[]
)

4 changes: 3 additions & 1 deletion pyccel/ast/utilities.py
@@ -25,6 +25,7 @@
from .literals import LiteralInteger, LiteralEllipsis, Nil
from .mathext import math_mod
from .sysext import sys_mod
from .cudaext import cuda_mod

from .numpyext import (NumpyEmpty, NumpyArray, numpy_mod,
NumpyTranspose, NumpyLinspace)
@@ -49,7 +50,8 @@
decorators_mod = Module('decorators',(),
funcs = [PyccelFunctionDef(d, PyccelFunction) for d in pyccel_decorators.__all__])
pyccel_mod = Module('pyccel',(),(),
-        imports = [Import('decorators', decorators_mod)])
+        imports = [Import('decorators', decorators_mod),
+                   Import('cuda', cuda_mod)])

# TODO add documentation
builtin_import_registry = Module('__main__',
46 changes: 43 additions & 3 deletions pyccel/codegen/printing/cucode.py
@@ -9,11 +9,12 @@
enabling the direct translation of high-level Pyccel expressions into CUDA code.
"""

-from pyccel.codegen.printing.ccode import CCodePrinter, c_library_headers
+from pyccel.codegen.printing.ccode import CCodePrinter

-from pyccel.ast.core import Import, Module
+from pyccel.ast.core import Import, Module
+from pyccel.ast.literals import Nil

-from pyccel.errors.errors import Errors
+from pyccel.errors.errors import Errors


errors = Errors()
@@ -61,6 +62,44 @@ def _print_Module(self, expr):
        self.exit_scope()
        return code

    def function_signature(self, expr, print_arg_names = True):
        """
        Get the Cuda representation of the function signature.

        Extract from the function definition `expr` all the
        information (name, input, output) needed to create the
        function signature and return a string describing the
        function.
        This is not a declaration as the signature does not end
        with a semi-colon.

        Parameters
        ----------
        expr : FunctionDef
            The function definition for which a signature is needed.

        print_arg_names : bool, default : True
            Indicates whether argument names should be printed.

        Returns
        -------
        str
            Signature of the function.
        """
        cuda_decorator = '__global__' if 'kernel' in expr.decorators else ''
        c_function_signature = super().function_signature(expr, print_arg_names)
        return f'{cuda_decorator} {c_function_signature}'

    def _print_KernelCall(self, expr):
        func = expr.funcdef
        args = [a.value or Nil() for a in expr.args]

        args = ', '.join(self._print(a) for a in args)
        return f"{func.name}<<<{expr.num_blocks}, {expr.tp_block}>>>({args});\n"

    def _print_CudaSynchronize(self, expr):
        return 'cudaDeviceSynchronize();\n'

    def _print_ModuleHeader(self, expr):
        self.set_scope(expr.module.scope)
        self._in_header = True
@@ -87,6 +126,7 @@ def _print_ModuleHeader(self, expr):
}}\n'
return '\n'.join((f"#ifndef {name.upper()}_H",
f"#define {name.upper()}_H",
imports,
global_variables,
function_declaration,
f"#endif // {name.upper()}_H\n"))
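
To make the output of the new printer methods concrete, here is a minimal standalone sketch of the string formatting they perform. `kernel_signature` and `kernel_launch` are hypothetical stand-ins that take plain strings and integers instead of Pyccel AST nodes, and the C signature `void my_kernel(void)` is an assumed example rather than output captured from Pyccel.

```python
def kernel_signature(c_signature, is_kernel):
    """Prefix a C function signature with __global__ when it belongs to a kernel."""
    prefix = '__global__' if is_kernel else ''
    # Same f-string shape as in the diff: a non-kernel keeps a leading space
    return f'{prefix} {c_signature}'

def kernel_launch(name, num_blocks, tp_block, args=()):
    """Format a CUDA kernel launch statement with its execution configuration."""
    arg_list = ', '.join(str(a) for a in args)
    return f"{name}<<<{num_blocks}, {tp_block}>>>({arg_list});\n"

print(kernel_signature('void my_kernel(void)', True))
print(kernel_launch('my_kernel', 1, 1), end='')
print('cudaDeviceSynchronize();')
```

Kernel launches are asynchronous with respect to the host, which is why a `cuda.synchronize()` call, printed as `cudaDeviceSynchronize();`, typically follows a launch when the host needs the result.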
10 changes: 10 additions & 0 deletions pyccel/cuda/__init__.py
@@ -0,0 +1,10 @@
#------------------------------------------------------------------------------------------#
# This file is part of Pyccel which is released under MIT License. See the LICENSE file or #
# go to https://github.com/pyccel/pyccel/blob/master/LICENSE for full license details. #
#------------------------------------------------------------------------------------------#
"""
This module is for exposing the CudaSubmodule functions.
"""
from .cuda_sync_primitives import synchronize

__all__ = ['synchronize']
16 changes: 16 additions & 0 deletions pyccel/cuda/cuda_sync_primitives.py
@@ -0,0 +1,16 @@
#------------------------------------------------------------------------------------------#
# This file is part of Pyccel which is released under MIT License. See the LICENSE file or #
# go to https://github.com/pyccel/pyccel/blob/master/LICENSE for full license details. #
#------------------------------------------------------------------------------------------#
"""
This submodule contains CUDA methods for Pyccel.
"""


def synchronize():
    """
    Synchronize CUDA device execution.

    Synchronize CUDA device execution.
    """

32 changes: 32 additions & 0 deletions pyccel/decorators.py
@@ -19,6 +19,7 @@
    'sympy',
    'template',
    'types',
    'kernel'
)


@@ -109,3 +110,34 @@ def allow_negative_index(f,*args):
    def identity(f):
        return f
    return identity

def kernel(f):
    """
    Decorator for marking a Python function as a kernel.

    This function serves as a decorator to mark a Python function
    as a kernel function, typically used for GPU computations.
    This allows the function to be indexed with the number of blocks and threads.

    Parameters
    ----------
    f : function
        The function to which the decorator is applied.

    Returns
    -------
    KernelAccessor
        A class representing the kernel function.
    """
    class KernelAccessor:
        """
        Class representing the kernel function.

        Class representing the kernel function.
        """
        def __init__(self, f):
            self._f = f
        def __getitem__(self, args):
            return self._f

    return KernelAccessor(f)
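
When the decorated function is run by the Python interpreter rather than translated, the accessor simply discards the launch configuration and returns the original function. The sketch below reproduces the decorator from the diff so that it runs on its own; the `record` function and the `calls` list are hypothetical names used only for this illustration.

```python
def kernel(f):
    """Copy of the decorator above, repeated so the example is self-contained."""
    class KernelAccessor:
        def __init__(self, f):
            self._f = f
        def __getitem__(self, args):
            # The launch configuration is ignored in pure Python
            return self._f
    return KernelAccessor(f)

calls = []

@kernel
def record(value):
    calls.append(value)

record[4, 128](7)   # the body runs once, whatever the configuration says
record[1, 1](8)
print(calls)
print(callable(record))  # the accessor defines no __call__, so indexing is required
```

In other words, the pure-Python fallback does not emulate a grid of threads: each indexed call executes the function body a single time.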
8 changes: 8 additions & 0 deletions pyccel/errors/messages.py
@@ -162,3 +162,11 @@
WRONG_LINSPACE_ENDPOINT = 'endpoint argument must be boolean'
NON_LITERAL_KEEP_DIMS = 'keep_dims argument must be a literal, otherwise rank is unknown'
NON_LITERAL_AXIS = 'axis argument must be a literal, otherwise pyccel cannot determine which dimension to operate on'
MISSING_KERNEL_CONFIGURATION = 'Kernel launch configuration not specified'
INVALID_KERNEL_LAUNCH_CONFIG = 'Expected exactly 2 parameters for kernel launch'
INVALID_KERNEL_CALL_BP_GRID = 'Invalid Block per grid parameter for Kernel call'
INVALID_KERNEL_CALL_TP_BLOCK = 'Invalid Thread per Block parameter for Kernel call'
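
The checks that raise these four messages are not visible in this part of the diff. The sketch below is one plausible reading, inferred only from the message texts and from the `KernelCall` docstring's requirement of integer-typed values. `check_launch_config` is a hypothetical function operating on plain Python values, not Pyccel's semantic stage.

```python
MISSING_KERNEL_CONFIGURATION = 'Kernel launch configuration not specified'
INVALID_KERNEL_LAUNCH_CONFIG = 'Expected exactly 2 parameters for kernel launch'
INVALID_KERNEL_CALL_BP_GRID = 'Invalid Block per grid parameter for Kernel call'
INVALID_KERNEL_CALL_TP_BLOCK = 'Invalid Thread per Block parameter for Kernel call'

def is_plain_int(value):
    """True for integers, excluding bool (which is a subclass of int)."""
    return isinstance(value, int) and not isinstance(value, bool)

def check_launch_config(config):
    """Return None for a valid (blocks, threads) pair, else an error message."""
    if config is None:
        # Assumed case: the kernel was called without square brackets
        return MISSING_KERNEL_CONFIGURATION
    if len(config) != 2:
        return INVALID_KERNEL_LAUNCH_CONFIG
    num_blocks, tp_block = config
    if not is_plain_int(num_blocks):
        return INVALID_KERNEL_CALL_BP_GRID
    if not is_plain_int(tp_block):
        return INVALID_KERNEL_CALL_TP_BLOCK
    return None

print(check_launch_config((1, 1)))
print(check_launch_config((1, 2.5)))
```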



