From bd7351493e3ae2c0947b1d2fb92605360db4de08 Mon Sep 17 00:00:00 2001
From: Said Mazouz <95222894+smazouz42@users.noreply.github.com>
Date: Wed, 15 May 2024 12:58:50 +0100
Subject: [PATCH] Fix import handling (#49)

This pull request fixes https://github.com/pyccel/pyccel-cuda/issues/48, by implementing a tiny wrapper for CUDA and a wrapper for non-CUDA functionalities only with external 'C'.

**Commit Summary**

-    Implemented new header printer for CUDA.
-    Added CUDA wrapper assignment
-    Instead of wrapping all local headers, wrap only C functions with extern 'C'

---------

Co-authored-by: EmilyBourne <louise.bourne@gmail.com>
Co-authored-by: bauom <40796259+bauom@users.noreply.github.com>
---
 AUTHORS                                     |  1 +
 CHANGELOG.md                                |  3 +-
 pyccel/codegen/printing/cucode.py           | 45 ++++++++----
 pyccel/codegen/python_wrapper.py            |  4 ++
 pyccel/codegen/wrapper/cuda_to_c_wrapper.py | 78 +++++++++++++++++++++
 tests/epyccel/modules/cuda_module.py        | 13 ++++
 tests/epyccel/test_epyccel_modules.py       | 13 ++++
 7 files changed, 143 insertions(+), 14 deletions(-)
 create mode 100644 pyccel/codegen/wrapper/cuda_to_c_wrapper.py
 create mode 100644 tests/epyccel/modules/cuda_module.py

diff --git a/AUTHORS b/AUTHORS
index 6c30ce5830..3dbaa2f249 100644
--- a/AUTHORS
+++ b/AUTHORS
@@ -31,3 +31,4 @@ Contributors
 * Farouk Ech-Charef
 * Mustapha Belbiad
 * Varadarajan Rengaraj
+* Said Mazouz
diff --git a/CHANGELOG.md b/CHANGELOG.md
index ce9212abc6..1d99c60127 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,7 +5,8 @@ All notable changes to this project will be documented in this file.
 
 ### Added
 
--   #32 : add support for `nvcc` Compiler and `cuda` language as a possible option.
+-   #32 : Add support for `nvcc` Compiler and `cuda` language as a possible option.
+-   #48 : Fix incorrect handling of imports in `cuda`.
 
 ## \[UNRELEASED\]
 
diff --git a/pyccel/codegen/printing/cucode.py b/pyccel/codegen/printing/cucode.py
index 86146b065b..277d2a3a6a 100644
--- a/pyccel/codegen/printing/cucode.py
+++ b/pyccel/codegen/printing/cucode.py
@@ -52,19 +52,7 @@ def _print_Module(self, expr):
 
         # Print imports last to be sure that all additional_imports have been collected
         imports = [Import(expr.name, Module(expr.name,(),())), *self._additional_imports.values()]
-        c_headers_imports = ''
-        local_imports = ''
-
-        for imp in imports:
-            if imp.source in c_library_headers:
-                c_headers_imports += self._print(imp)
-            else:
-                local_imports += self._print(imp)
-
-        imports = f'{c_headers_imports}\
-                    extern "C"{{\n\
-                    {local_imports}\
-                    }}'
+        imports = ''.join(self._print(i) for i in imports)
 
         code = f'{imports}\n\
                  {global_variables}\n\
@@ -72,3 +60,34 @@ def _print_Module(self, expr):
 
         self.exit_scope()
         return code
+
+    def _print_ModuleHeader(self, expr):
+        self.set_scope(expr.module.scope)
+        self._in_header = True
+        name = expr.module.name
+
+        funcs = ""
+        cuda_headers = ""
+        for f in expr.module.funcs:
+            if not f.is_inline:
+                if 'kernel' in f.decorators:  # Checking for 'kernel' decorator
+                    cuda_headers += self.function_signature(f) + ';\n'
+                else:
+                    funcs += self.function_signature(f) + ';\n'
+        global_variables = ''.join('extern '+self._print(d) for d in expr.module.declarations if not d.variable.is_private)
+        # Print imports last to be sure that all additional_imports have been collected
+        imports = [*expr.module.imports, *self._additional_imports.values()]
+        imports = ''.join(self._print(i) for i in imports)
+
+        self._in_header = False
+        self.exit_scope()
+        function_declaration = f'{cuda_headers}\n\
+                    extern "C"{{\n\
+                    {funcs}\
+                    }}\n'
+        return '\n'.join((f"#ifndef {name.upper()}_H",
+                          f"#define {name.upper()}_H",
+                          global_variables,
+                          function_declaration,
+                          "#endif // {name.upper()}_H\n"))
+
diff --git a/pyccel/codegen/python_wrapper.py b/pyccel/codegen/python_wrapper.py
index 9437727042..62c303fa64 100644
--- a/pyccel/codegen/python_wrapper.py
+++ b/pyccel/codegen/python_wrapper.py
@@ -13,6 +13,7 @@
 from pyccel.codegen.printing.fcode               import FCodePrinter
 from pyccel.codegen.wrapper.fortran_to_c_wrapper import FortranToCWrapper
 from pyccel.codegen.wrapper.c_to_python_wrapper  import CToPythonWrapper
+from pyccel.codegen.wrapper.cuda_to_c_wrapper    import CudaToCWrapper
 from pyccel.codegen.utilities                    import recompile_object
 from pyccel.codegen.utilities                    import copy_internal_library
 from pyccel.codegen.utilities                    import internal_libs
@@ -144,6 +145,9 @@ def create_shared_library(codegen,
                 verbose=verbose)
         timings['Bind C wrapping'] = time.time() - start_bind_c_compiling
         c_ast = bind_c_mod
+    elif language == 'cuda':
+        wrapper = CudaToCWrapper()
+        c_ast = wrapper.wrap(codegen.ast)
     else:
         c_ast = codegen.ast
 
diff --git a/pyccel/codegen/wrapper/cuda_to_c_wrapper.py b/pyccel/codegen/wrapper/cuda_to_c_wrapper.py
new file mode 100644
index 0000000000..c0e24c7c09
--- /dev/null
+++ b/pyccel/codegen/wrapper/cuda_to_c_wrapper.py
@@ -0,0 +1,78 @@
+# coding: utf-8
+#------------------------------------------------------------------------------------------#
+# This file is part of Pyccel which is released under MIT License. See the LICENSE file or #
+# go to https://github.com/pyccel/pyccel/blob/master/LICENSE for full license details.     #
+#------------------------------------------------------------------------------------------#
+"""
+Module describing the code-wrapping class : CudaToPythonWrapper
+which creates an interface exposing Cuda code to C.
+"""
+
+from pyccel.ast.bind_c      import BindCModule
+from pyccel.errors.errors   import Errors
+from pyccel.ast.bind_c      import BindCVariable
+from .wrapper               import Wrapper
+
+errors = Errors()
+
+class CudaToCWrapper(Wrapper):
+    """
+    Class for creating a wrapper exposing Cuda code to C.
+
+    While CUDA is typically compatible with C by default.
+    this wrapper becomes necessary in scenarios where specific adaptations
+    or modifications are required to ensure seamless integration with C.
+    """
+
+    def _wrap_Module(self, expr):
+        """
+        Create a Module which is compatible with C.
+
+        Create a Module which provides an interface between C and the
+        Module described by expr.
+
+        Parameters
+        ----------
+        expr : pyccel.ast.core.Module
+            The module to be wrapped.
+
+        Returns
+        -------
+        pyccel.ast.core.BindCModule
+            The C-compatible module.
+        """
+        init_func = expr.init_func
+        if expr.interfaces:
+            errors.report("Interface wrapping is not yet supported for Cuda",
+                      severity='warning', symbol=expr)
+        if expr.classes:
+            errors.report("Class wrapping is not yet supported for Cuda",
+                      severity='warning', symbol=expr)
+
+        variables = [self._wrap(v) for v in expr.variables]
+
+        return BindCModule(expr.name, variables, expr.funcs,
+                init_func=init_func,
+                scope = expr.scope,
+                original_module=expr)
+
+    def _wrap_Variable(self, expr):
+        """
+        Create all objects necessary to expose a module variable to C.
+
+        Create and return the objects which must be printed in the wrapping
+        module in order to expose the variable to C
+
+        Parameters
+        ----------
+        expr : pyccel.ast.variables.Variable
+            The module variable.
+
+        Returns
+        -------
+        pyccel.ast.core.BindCVariable
+            The C-compatible variable. which must be printed in
+            the wrapping module to expose the variable.
+        """
+        return expr.clone(expr.name, new_class = BindCVariable)
+
diff --git a/tests/epyccel/modules/cuda_module.py b/tests/epyccel/modules/cuda_module.py
new file mode 100644
index 0000000000..bb7ae6b98a
--- /dev/null
+++ b/tests/epyccel/modules/cuda_module.py
@@ -0,0 +1,13 @@
+# pylint: disable=missing-function-docstring, missing-module-docstring
+import numpy as np
+
+g = np.float64(9.81)
+r0 = np.float32(1.0)
+rmin = 0.01
+rmax = 1.0
+
+skip_centre = True
+
+method = 3
+
+tiny = np.int32(4)
diff --git a/tests/epyccel/test_epyccel_modules.py b/tests/epyccel/test_epyccel_modules.py
index ad8ae0bd75..223f741bf0 100644
--- a/tests/epyccel/test_epyccel_modules.py
+++ b/tests/epyccel/test_epyccel_modules.py
@@ -200,3 +200,16 @@ def test_awkward_names(language):
     assert mod.function() == modnew.function()
     assert mod.pure() == modnew.pure()
     assert mod.allocate(1) == modnew.allocate(1)
+
+def test_cuda_module(language_with_cuda):
+    import modules.cuda_module as mod
+
+    modnew = epyccel(mod, language=language_with_cuda)
+
+    atts = ('g', 'r0', 'rmin', 'rmax', 'skip_centre',
+            'method', 'tiny')
+    for att in atts:
+        mod_att = getattr(mod, att)
+        modnew_att = getattr(modnew, att)
+        assert mod_att == modnew_att
+        assert type(mod_att) is type(modnew_att)