Skip to content

Commit

Permalink
JUCX: add request status field.
Browse files Browse the repository at this point in the history
  • Loading branch information
petro-rudenko committed Dec 12, 2020
1 parent 392443a commit ab06b46
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 14 deletions.
24 changes: 22 additions & 2 deletions bindings/java/src/main/java/org/openucx/jucx/ucp/UcpRequest.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import org.openucx.jucx.UcxCallback;
import org.openucx.jucx.UcxNativeStruct;
import org.openucx.jucx.ucs.UcsConstants;

import java.io.Closeable;
import java.nio.ByteBuffer;
Expand All @@ -21,6 +22,14 @@ public class UcpRequest extends UcxNativeStruct implements Closeable {

private long senderTag;

/**
* The only request that doesn't have automatic status update (no callback)
* so need to go into ucx to check it's real status
*/
private boolean isCloseRequest = false;

private int status = UcsConstants.STATUS.UCS_INPROGRESS;

private UcpRequest(long nativeId) {
setNativeId(nativeId);
}
Expand Down Expand Up @@ -49,7 +58,18 @@ public long getSenderTag() {
* @return whether this request is completed.
*/
public boolean isCompleted() {
return (getNativeId() == null) || isCompletedNative(getNativeId());
if ((getNativeId() != null) && isCloseRequest) {
updateRequestStatus(getNativeId());
}

return (status != UcsConstants.STATUS.UCS_INPROGRESS);
}

/**
* @return status of the current request
*/
public int getStatus() {
return status;
}

/**
Expand All @@ -65,7 +85,7 @@ public void close() {
}
}

private static native boolean isCompletedNative(long ucpRequest);
private native void updateRequestStatus(long ucpRequest);

private static native void closeRequestNative(long ucpRequest);
}
8 changes: 7 additions & 1 deletion bindings/java/src/main/native/endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,13 @@ Java_org_openucx_jucx_ucp_UcpEndpoint_closeNonBlockingNative(JNIEnv *env, jclass
{
ucs_status_ptr_t request = ucp_ep_close_nb((ucp_ep_h)ep_ptr, mode);

return process_request(request, NULL);
jobject jucx_request = process_request(request, NULL);

jclass jucx_request_cls = env->GetObjectClass(jucx_request);
jfieldID is_close_request_field = env->GetFieldID(jucx_request_cls, "isCloseRequest", "Z");
env->SetBooleanField(jucx_request, is_close_request_field, true);

return jucx_request;
}

JNIEXPORT jobject JNICALL
Expand Down
17 changes: 13 additions & 4 deletions bindings/java/src/main/native/jucx_common_def.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ static jclass jucx_request_cls;
static jfieldID native_id_field;
static jfieldID recv_size_field;
static jfieldID sender_tag_field;
static jfieldID request_status;
static jmethodID on_success;
static jmethodID jucx_request_constructor;
static jclass ucp_rkey_cls;
Expand All @@ -40,6 +41,7 @@ extern "C" JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *jvm, void* reserved) {
jucx_request_cls = (jclass) env->NewGlobalRef(jucx_request_cls_local);
jclass jucx_callback_cls = env->FindClass("org/openucx/jucx/UcxCallback");
native_id_field = env->GetFieldID(jucx_request_cls, "nativeId", "Ljava/lang/Long;");
request_status = env->GetFieldID(jucx_request_cls, "status", "I");
recv_size_field = env->GetFieldID(jucx_request_cls, "recvSize", "J");
sender_tag_field = env->GetFieldID(jucx_request_cls, "senderTag", "J");
on_success = env->GetMethodID(jucx_callback_cls, "onSuccess",
Expand Down Expand Up @@ -172,10 +174,16 @@ JNIEnv* get_jni_env()
return (JNIEnv*)env;
}

void jucx_request_update_status(JNIEnv *env, jobject jucx_request, ucs_status_t status)
{
env->SetIntField(jucx_request, request_status, status);
}

static inline void set_jucx_request_completed(JNIEnv *env, jobject jucx_request,
struct jucx_context *ctx)
struct jucx_context *ctx, ucs_status_t status)
{
env->SetObjectField(jucx_request, native_id_field, NULL);
jucx_request_update_status(env, jucx_request, status);
if (ctx != NULL) {
/* sender_tag and length are initialized to 0,
* so try to avoid the overhead of setting them again */
Expand Down Expand Up @@ -238,7 +246,7 @@ UCS_PROFILE_FUNC_VOID(jucx_request_callback, (request, status), void *request, u
}

JNIEnv *env = get_jni_env();
set_jucx_request_completed(env, ctx->jucx_request, ctx);
set_jucx_request_completed(env, ctx->jucx_request, ctx, status);

if (ctx->callback != NULL) {
jucx_call_callback(ctx->callback, ctx->jucx_request, status);
Expand Down Expand Up @@ -285,7 +293,7 @@ UCS_PROFILE_FUNC(jobject, process_request, (request, callback), void *request, j
} else {
// request was completed whether by progress in other thread or inside
// ucp_tag_recv_nb function call.
set_jucx_request_completed(env, jucx_request, ctx);
set_jucx_request_completed(env, jucx_request, ctx, ctx->status);
if (callback != NULL) {
jucx_call_callback(callback, jucx_request, ctx->status);
}
Expand All @@ -296,7 +304,7 @@ UCS_PROFILE_FUNC(jobject, process_request, (request, callback), void *request, j
} else {
jmethodID empty_constructor = env->GetMethodID(jucx_request_cls, "<init>", "()V");
jucx_request = env->NewObject(jucx_request_cls, empty_constructor);
set_jucx_request_completed(env, jucx_request, NULL);
set_jucx_request_completed(env, jucx_request, NULL, UCS_PTR_RAW_STATUS(request));
if (UCS_PTR_IS_ERR(request)) {
JNU_ThrowExceptionByStatus(env, UCS_PTR_STATUS(request));
if (callback != NULL) {
Expand All @@ -315,6 +323,7 @@ jobject process_completed_stream_recv(size_t length, jobject callback)
jobject jucx_request = env->NewObject(jucx_request_cls, jucx_request_constructor, NULL);
env->SetObjectField(jucx_request, native_id_field, NULL);
env->SetLongField(jucx_request, recv_size_field, length);
jucx_request_update_status(env, jucx_request, UCS_OK);
if (callback != NULL) {
jucx_call_callback(callback, jucx_request, UCS_OK);
}
Expand Down
6 changes: 6 additions & 0 deletions bindings/java/src/main/native/jucx_common_def.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ void stream_recv_callback(void *request, ucs_status_t status, size_t length);
*/
jobject process_request(void *request, jobject callback);

/**
* @ingroup JUCX_REQ
* @brief Utility to update status of JUCX request to corresponding ucx request.
*/
void jucx_request_update_status(JNIEnv *env, jobject jucx_request, ucs_status_t status);

/**
* @brief Call java callback on completed stream recv operation, that didn't invoke callback.
*/
Expand Down
13 changes: 6 additions & 7 deletions bindings/java/src/main/native/request.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@
* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/

#include "jucx_common_def.h"
#include "org_openucx_jucx_ucp_UcpRequest.h"

#include <ucp/api/ucp.h>
#include <ucs/type/status.h>

JNIEXPORT jboolean JNICALL
Java_org_openucx_jucx_ucp_UcpRequest_isCompletedNative(JNIEnv *env, jclass cls,
jlong ucp_req_ptr)
JNIEXPORT void JNICALL
Java_org_openucx_jucx_ucp_UcpRequest_updateRequestStatus(JNIEnv *env,
jobject jucx_request, jlong ucp_req_ptr)
{
return ucp_request_check_status((void *)ucp_req_ptr) != UCS_INPROGRESS;
jucx_request_update_status(env, jucx_request,
ucp_request_check_status((void *)ucp_req_ptr));
}

JNIEXPORT void JNICALL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import org.junit.Test;
import org.openucx.jucx.ucp.*;
import org.openucx.jucx.ucs.UcsConstants;

import java.nio.ByteBuffer;
import static org.junit.Assert.*;
Expand All @@ -22,6 +23,7 @@ public void testCancelRequest() throws Exception {
worker.progress();
}

assertEquals(UcsConstants.STATUS.UCS_ERR_CANCELED, recv.getStatus());
assertTrue(recv.isCompleted());
assertNull(recv.getNativeId());

Expand Down

0 comments on commit ab06b46

Please sign in to comment.