Skip to content

Commit

Permalink
Fix/jni (#152)
Browse files Browse the repository at this point in the history
* Use io/shapelets/khiva/KhivaException instead of java/lang/Exception
* Refactor JNI bindings to remove C Macro functions
  • Loading branch information
raulbocanegra authored Apr 24, 2020
1 parent ff8c31c commit f61cc00
Show file tree
Hide file tree
Showing 12 changed files with 482 additions and 376 deletions.
3 changes: 2 additions & 1 deletion bindings/jni/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ set(KHIVALIB_JNI_HEADERS ${KHIVALIB_JNI_INC}/khiva_jni/array.h
${KHIVALIB_JNI_INC}/khiva_jni/regression.h
${KHIVALIB_JNI_INC}/khiva_jni/regularization.h
${KHIVALIB_JNI_INC}/khiva_jni/statistics.h
${KHIVALIB_JNI_INC}/khiva_jni/internal/utils.h)
${KHIVALIB_JNI_INC}/khiva_jni/internal/utils.h
${KHIVALIB_JNI_INC}/khiva_jni/internal/jni_traits.h)

# Sources to add to compilation
set(KHIVALIB_JNI_SOURCES ${KHIVALIB_JNI_SRC}/array.cpp
Expand Down
131 changes: 109 additions & 22 deletions bindings/jni/include/khiva_jni/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,72 @@ extern "C" {
*
* @return The array reference.
*/
#define CREATE_T_ARRAY(Ty, ty, dty) \
JNIEXPORT jlong JNICALL Java_io_shapelets_khiva_Array_createArrayFrom##Ty(JNIEnv *env, jobject, \
j##ty##Array elems, jlongArray dims);
CREATE_T_ARRAY(Float, float, khiva::dtype::f32)
CREATE_T_ARRAY(Double, double, khiva::dtype::f64)
CREATE_T_ARRAY(Int, int, khiva::dtype::s32)
CREATE_T_ARRAY(Boolean, boolean, khiva::dtype::b8)
CREATE_T_ARRAY(Long, long, khiva::dtype::s64)
CREATE_T_ARRAY(Short, short, khiva::dtype::s16)
CREATE_T_ARRAY(Byte, byte, khiva::dtype::u8)
JNIEXPORT jlong JNICALL Java_io_shapelets_khiva_Array_createArrayFromFloat(JNIEnv *env, jobject, jfloatArray elems,
jlongArray dims);

#undef CREATE_T_ARRAY
/**
* @brief Creates an Array object of Double.
*
* @param elems Data used in order to create the array.
* @param dims Cardinality of dimensions of the data.
*
* @return The array reference.
*/
JNIEXPORT jlong JNICALL Java_io_shapelets_khiva_Array_createArrayFromDouble(JNIEnv *env, jobject, jdoubleArray elems,
jlongArray dims);
/**
* @brief Creates an Array object of Int.
*
* @param elems Data used in order to create the array.
* @param dims Cardinality of dimensions of the data.
*
* @return The array reference.
*/
JNIEXPORT jlong JNICALL Java_io_shapelets_khiva_Array_createArrayFromInt(JNIEnv *env, jobject, jintArray elems,
jlongArray dims);
/**
* @brief Creates an Array object of Boolean.
*
* @param elems Data used in order to create the array.
* @param dims Cardinality of dimensions of the data.
*
* @return The array reference.
*/
JNIEXPORT jlong JNICALL Java_io_shapelets_khiva_Array_createArrayFromBoolean(JNIEnv *env, jobject, jbooleanArray elems,
jlongArray dims);

/**
* @brief Creates an Array object of Long.
*
* @param elems Data used in order to create the array.
* @param dims Cardinality of dimensions of the data.
*
* @return The array reference.
*/
JNIEXPORT jlong JNICALL Java_io_shapelets_khiva_Array_createArrayFromLong(JNIEnv *env, jobject, jlongArray elems,
jlongArray dims);

/**
* @brief Creates an Array object of Short.
*
* @param elems Data used in order to create the array.
* @param dims Cardinality of dimensions of the data.
*
* @return The array reference.
*/
JNIEXPORT jlong JNICALL Java_io_shapelets_khiva_Array_createArrayFromShort(JNIEnv *env, jobject, jshortArray elems,
jlongArray dims);

/**
* @brief Creates an Array object of Byte.
*
* @param elems Data used in order to create the array.
* @param dims Cardinality of dimensions of the data.
*
* @return The array reference.
*/
JNIEXPORT jlong JNICALL Java_io_shapelets_khiva_Array_createArrayFromByte(JNIEnv *env, jobject, jbyteArray elems,
jlongArray dims);

/**
* @brief Creates an Array object of Float Complex.
Expand Down Expand Up @@ -62,20 +116,53 @@ JNIEXPORT jlong JNICALL Java_io_shapelets_khiva_Array_createArrayFromDoubleCompl
JNIEXPORT void JNICALL Java_io_shapelets_khiva_Array_deleteArray(JNIEnv *env, jobject thisObj);

/**
* @brief Retrieves data from the device to host (Float, Double, Int, Boolean, Long, Short or Byte).
* @brief Retrieves data from the device to host.
*
* @return Array with the data.
*/
JNIEXPORT jfloatArray JNICALL Java_io_shapelets_khiva_Array_getFloatFromArray(JNIEnv *env, jobject thisObj);

/**
* @brief Retrieves data from the device to host.
*
* @return Array with the data.
*/
JNIEXPORT jdoubleArray JNICALL Java_io_shapelets_khiva_Array_getDoubleFromArray(JNIEnv *env, jobject thisObj);

/**
* @brief Retrieves data from the device to host.
*
* @return Array with the data.
*/
JNIEXPORT jintArray JNICALL Java_io_shapelets_khiva_Array_getIntFromArray(JNIEnv *env, jobject thisObj);

/**
* @brief Retrieves data from the device to host.
*
* @return Array with the data.
*/
JNIEXPORT jbooleanArray JNICALL Java_io_shapelets_khiva_Array_getBooleanFromArray(JNIEnv *env, jobject thisObj);

/**
* @brief Retrieves data from the device to host.
*
* @return Array with the data.
*/
JNIEXPORT jlongArray JNICALL Java_io_shapelets_khiva_Array_getLongFromArray(JNIEnv *env, jobject thisObj);

/**
* @brief Retrieves data from the device to host.
*
* @return Array with the data.
*/
JNIEXPORT jshortArray JNICALL Java_io_shapelets_khiva_Array_getShortFromArray(JNIEnv *env, jobject thisObj);

/**
* @brief Retrieves data from the device to host.
*
* @return Array with the data.
*/
#define GET_T_FROM_ARRAY(Ty, ty) \
JNIEXPORT j##ty##Array JNICALL Java_io_shapelets_khiva_Array_get##Ty##FromArray(JNIEnv *env, jobject thisObj);
GET_T_FROM_ARRAY(Float, float)
GET_T_FROM_ARRAY(Double, double)
GET_T_FROM_ARRAY(Int, int)
GET_T_FROM_ARRAY(Boolean, boolean)
GET_T_FROM_ARRAY(Long, long)
GET_T_FROM_ARRAY(Short, short)
GET_T_FROM_ARRAY(Byte, byte)
#undef GET_T_FROM_ARRAY
JNIEXPORT jbyteArray JNICALL Java_io_shapelets_khiva_Array_getByteFromArray(JNIEnv *env, jobject thisObj);

/**
* @brief Retrieves data from the device to host (Double Complex).
Expand Down
104 changes: 104 additions & 0 deletions bindings/jni/include/khiva_jni/internal/jni_traits.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Copyright (c) 2019 Shapelets.io
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef KHIVA_BINDINGJAVA_INTERNAL_JNI_TRAITS_H
#define KHIVA_BINDINGJAVA_INTERNAL_JNI_TRAITS_H

#include <jni.h>
#include <khiva/array.h>

namespace khiva {
namespace jni {

template <typename T>
struct ArrayTraits {};

template <>
struct ArrayTraits<float> {
using JavaType = jfloat;
using JavaTypePtr = jfloat *;
using JavaArrayType = jfloatArray;

static constexpr auto type = khiva::dtype::f32;
static constexpr auto newArray = &JNIEnv::NewFloatArray;
static constexpr auto getArrayElements = &JNIEnv::GetFloatArrayElements;
static constexpr auto releaseArrayElements = &JNIEnv::ReleaseFloatArrayElements;
};

template <>
struct ArrayTraits<double> {
using JavaType = jdouble;
using JavaTypePtr = jdouble *;
using JavaArrayType = jdoubleArray;

static constexpr auto type = khiva::dtype::f64;
static constexpr auto newArray = &JNIEnv::NewDoubleArray;
static constexpr auto getArrayElements = &JNIEnv::GetDoubleArrayElements;
static constexpr auto releaseArrayElements = &JNIEnv::ReleaseDoubleArrayElements;
};

template <>
struct ArrayTraits<int> {
using JavaType = jint;
using JavaTypePtr = jint *;
using JavaArrayType = jintArray;

static constexpr auto type = khiva::dtype::s32;
static constexpr auto newArray = &JNIEnv::NewIntArray;
static constexpr auto getArrayElements = &JNIEnv::GetIntArrayElements;
static constexpr auto releaseArrayElements = &JNIEnv::ReleaseIntArrayElements;
};

template <>
struct ArrayTraits<bool> {
using JavaType = jboolean;
using JavaTypePtr = jboolean *;
using JavaArrayType = jbooleanArray;

static constexpr auto type = khiva::dtype::b8;
static constexpr auto newArray = &JNIEnv::NewBooleanArray;
static constexpr auto getArrayElements = &JNIEnv::GetBooleanArrayElements;
static constexpr auto releaseArrayElements = &JNIEnv::ReleaseBooleanArrayElements;
};

template <>
struct ArrayTraits<long> {
using JavaType = jlong;
using JavaTypePtr = jlong *;
using JavaArrayType = jlongArray;

static constexpr auto type = khiva::dtype::s64;
static constexpr auto newArray = &JNIEnv::NewLongArray;
static constexpr auto getArrayElements = &JNIEnv::GetLongArrayElements;
static constexpr auto releaseArrayElements = &JNIEnv::ReleaseLongArrayElements;
};

template <>
struct ArrayTraits<short> {
using JavaType = jshort;
using JavaTypePtr = jshort *;
using JavaArrayType = jshortArray;

static constexpr auto type = khiva::dtype::s16;
static constexpr auto newArray = &JNIEnv::NewShortArray;
static constexpr auto getArrayElements = &JNIEnv::GetShortArrayElements;
static constexpr auto releaseArrayElements = &JNIEnv::ReleaseShortArrayElements;
};

template <>
struct ArrayTraits<jbyte> {
using JavaType = jbyte;
using JavaTypePtr = jbyte *;
using JavaArrayType = jbyteArray;

static constexpr auto type = khiva::dtype::u8;
static constexpr auto newArray = &JNIEnv::NewByteArray;
static constexpr auto getArrayElements = &JNIEnv::GetByteArrayElements;
static constexpr auto releaseArrayElements = &JNIEnv::ReleaseByteArrayElements;
};
} // namespace jni
} // namespace khiva
#endif
8 changes: 4 additions & 4 deletions bindings/jni/include/khiva_jni/internal/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ jlong KhivaCall(JNIEnv *env, Func f, jlong ref, Args &&... args) {
auto result = f(arr, std::forward<Args>(args)...);
return reinterpret_cast<jlong>(new af::array(result));
} catch (const std::exception &e) {
auto exceptionClass = env->FindClass("java/lang/Exception");
auto exceptionClass = env->FindClass("io/shapelets/khiva/KhivaException");
env->ThrowNew(exceptionClass, e.what());
} catch (...) {
auto exceptionClass = env->FindClass("java/lang/Exception");
auto exceptionClass = env->FindClass("io/shapelets/khiva/KhivaException");
env->ThrowNew(exceptionClass, "Unknown error executing native function");
}
return 0;
Expand All @@ -36,10 +36,10 @@ jlong KhivaCallTwoArrays(JNIEnv *env, Func f, jlong ref_a, jlong ref_b, Args &&.
auto result = f(arr_a, arr_b, std::forward<Args>(args)...);
return reinterpret_cast<jlong>(new af::array(result));
} catch (const std::exception &e) {
auto exceptionClass = env->FindClass("java/lang/Exception");
auto exceptionClass = env->FindClass("io/shapelets/khiva/KhivaException");
env->ThrowNew(exceptionClass, e.what());
} catch (...) {
auto exceptionClass = env->FindClass("java/lang/Exception");
auto exceptionClass = env->FindClass("io/shapelets/khiva/KhivaException");
env->ThrowNew(exceptionClass, "Unknown error executing native function");
}
return 0;
Expand Down
Loading

0 comments on commit f61cc00

Please sign in to comment.