Blob Blame History Raw
/*
 * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
 * See file LICENSE for terms.
 */

#include "jucx_common_def.h"
extern "C" {
  #include <ucs/arch/cpu.h>
  #include <ucs/debug/assert.h>
  #include <ucs/debug/debug.h>
}

#include <string.h>    /* memset */
#include <arpa/inet.h> /* inet_addr */
#include <pthread.h>   /* pthread_yield */


// JavaVM captured in JNI_OnLoad; get_jni_env() attaches threads through it.
static JavaVM *jvm_global;

// org.openucx.jucx.ucp.UcpRequest: global class ref and cached member IDs
// (constructor "(J)V", field "nativeId" of type Long, field "recvSize" long).
static jclass    jucx_request_cls;
static jmethodID jucx_request_constructor;
static jfieldID  native_id_field;
static jfieldID  recv_size_field;

// org.openucx.jucx.UcxCallback.onSuccess(UcpRequest) method ID.
static jmethodID on_success;

// org.openucx.jucx.ucp.UcpRemoteKey: global class ref and its "(J)V" constructor.
static jclass    ucp_rkey_cls;
static jmethodID ucp_rkey_cls_constructor;

/**
 * JNI library load hook.
 *
 * Disables UCX debug signal handlers, remembers the JavaVM (used later by
 * get_jni_env() to attach callback threads) and caches the classes and member
 * IDs used by the request / remote-key helpers in this file.
 *
 * @return JNI_VERSION_1_1 on success; JNI_ERR if the JNI environment cannot be
 *         obtained or any class/member lookup fails, which makes the library
 *         load fail instead of crashing later on a NULL class or member ID.
 */
extern "C" JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *jvm, void* reserved) {
    ucs_debug_disable_signals();
    jvm_global = jvm;
    JNIEnv* env;
    if (jvm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_1) != JNI_OK) {
       return JNI_ERR;
    }

    // UcpRequest: promote to a global ref, since local refs die when OnLoad returns.
    jclass jucx_request_cls_local = env->FindClass("org/openucx/jucx/ucp/UcpRequest");
    if (jucx_request_cls_local == NULL) {
        return JNI_ERR;
    }
    jucx_request_cls = (jclass) env->NewGlobalRef(jucx_request_cls_local);
    env->DeleteLocalRef(jucx_request_cls_local);
    if (jucx_request_cls == NULL) {
        return JNI_ERR;
    }

    native_id_field = env->GetFieldID(jucx_request_cls, "nativeId", "Ljava/lang/Long;");
    if (native_id_field == NULL) {
        return JNI_ERR;
    }
    recv_size_field = env->GetFieldID(jucx_request_cls, "recvSize", "J");
    if (recv_size_field == NULL) {
        return JNI_ERR;
    }
    jucx_request_constructor = env->GetMethodID(jucx_request_cls, "<init>", "(J)V");
    if (jucx_request_constructor == NULL) {
        return JNI_ERR;
    }

    // UcxCallback: only the onSuccess method ID is cached, not the class.
    jclass jucx_callback_cls = env->FindClass("org/openucx/jucx/UcxCallback");
    if (jucx_callback_cls == NULL) {
        return JNI_ERR;
    }
    on_success = env->GetMethodID(jucx_callback_cls, "onSuccess",
                                  "(Lorg/openucx/jucx/ucp/UcpRequest;)V");
    env->DeleteLocalRef(jucx_callback_cls);
    if (on_success == NULL) {
        return JNI_ERR;
    }

    // UcpRemoteKey: global class ref plus its (long) constructor.
    jclass ucp_rkey_cls_local = env->FindClass("org/openucx/jucx/ucp/UcpRemoteKey");
    if (ucp_rkey_cls_local == NULL) {
        return JNI_ERR;
    }
    ucp_rkey_cls = (jclass) env->NewGlobalRef(ucp_rkey_cls_local);
    env->DeleteLocalRef(ucp_rkey_cls_local);
    if (ucp_rkey_cls == NULL) {
        return JNI_ERR;
    }
    ucp_rkey_cls_constructor = env->GetMethodID(ucp_rkey_cls, "<init>", "(J)V");
    if (ucp_rkey_cls_constructor == NULL) {
        return JNI_ERR;
    }

    return JNI_VERSION_1_1;
}

/**
 * JNI library unload hook: releases every global class reference created in
 * JNI_OnLoad (UcpRequest and UcpRemoteKey) and clears the cached handles.
 */
extern "C" JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *jvm, void *reserved) {
    JNIEnv* env;
    if (jvm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_1) != JNI_OK) {
        return;
    }

    if (jucx_request_cls != NULL) {
        env->DeleteGlobalRef(jucx_request_cls);
        jucx_request_cls = NULL;
    }

    // Previously leaked: JNI_OnLoad also pins UcpRemoteKey with a global ref.
    if (ucp_rkey_cls != NULL) {
        env->DeleteGlobalRef(ucp_rkey_cls);
        ucp_rkey_cls = NULL;
    }
}

/**
 * Convert a java.net.InetSocketAddress into a native socket address.
 *
 * Reads the port and InetAddress of sock_addr. The address family comes from
 * the JDK-internal InetAddress.holder.family field, compared against the
 * InetAddress.IPv4 / IPv6 constants. ss is filled with a sockaddr_in or a
 * sockaddr_in6 (including the scope id), port in network byte order.
 *
 * @param env       JNI environment of the calling thread.
 * @param sock_addr java.net.InetSocketAddress to convert (may be NULL).
 * @param ss        [out] zeroed, then filled with the native address.
 * @param sa_len    [out] size of the filled address; stays 0 on failure.
 * @return true on success. On failure returns false and the error is either
 *         reported via JNU_ThrowException or left pending by the failing JNI
 *         lookup/call; no further JNI calls are made once that happens.
 */
bool j2cInetSockAddr(JNIEnv *env, jobject sock_addr, sockaddr_storage& ss,  socklen_t& sa_len)
{
    jfieldID field;
    memset(&ss, 0, sizeof(ss));
    sa_len = 0;

    if (sock_addr == NULL) {
        JNU_ThrowException(env, "j2cInetSockAddr: InetSocketAddr is null");
        return false;
    }

    jclass inetsockaddr_cls = env->GetObjectClass(sock_addr);

    // Get sockAddr->port
    jmethodID getPort = env->GetMethodID(inetsockaddr_cls, "getPort", "()I");
    if (getPort == NULL) {
        return false;
    }
    jint port = env->CallIntMethod(sock_addr, getPort);
    if (env->ExceptionCheck()) {
        return false;
    }

    // Get sockAddr->getAddress (InetAddress)
    jmethodID getAddress = env->GetMethodID(inetsockaddr_cls, "getAddress",
                                            "()Ljava/net/InetAddress;");
    if (getAddress == NULL) {
        return false;
    }
    jobject inet_address = env->CallObjectMethod(sock_addr, getAddress);
    if (env->ExceptionCheck()) {
        return false;
    }

    if (inet_address == NULL) {
        JNU_ThrowException(env, "j2cInetSockAddr: InetSocketAddr.getAddress is null");
        return false;
    }

    jclass inetaddr_cls = env->GetObjectClass(inet_address);

    // Get address family. In Java IPv4 has addressFamily = 1, IPv6 = 2.
    // "holder" is a JDK-internal field, so each lookup below is checked
    // instead of being assumed to succeed.
    field = env->GetFieldID(inetaddr_cls, "holder",
                            "Ljava/net/InetAddress$InetAddressHolder;");
    if (field == NULL) {
        return false;
    }
    jobject inet_addr_holder = env->GetObjectField(inet_address, field);
    if (inet_addr_holder == NULL) {
        JNU_ThrowException(env, "j2cInetSockAddr: InetAddress.holder is null");
        return false;
    }
    jclass inet_addr_holder_cls = env->GetObjectClass(inet_addr_holder);
    field = env->GetFieldID(inet_addr_holder_cls, "family", "I");
    if (field == NULL) {
        return false;
    }
    jint family = env->GetIntField(inet_addr_holder, field);

    field = env->GetStaticFieldID(inetaddr_cls, "IPv4", "I");
    if (field == NULL) {
        return false;
    }
    const int JAVA_IPV4_FAMILY = env->GetStaticIntField(inetaddr_cls, field);
    field = env->GetStaticFieldID(inetaddr_cls, "IPv6", "I");
    if (field == NULL) {
        return false;
    }
    const int JAVA_IPV6_FAMILY = env->GetStaticIntField(inetaddr_cls, field);

    // Get the byte array that stores the IP address bytes in the InetAddress.
    jmethodID get_addr_bytes = env->GetMethodID(inetaddr_cls, "getAddress", "()[B");
    if (get_addr_bytes == NULL) {
        return false;
    }
    jobject ip_byte_array = env->CallObjectMethod(inet_address, get_addr_bytes);
    if (env->ExceptionCheck()) {
        return false;
    }

    if (ip_byte_array == NULL) {
        JNU_ThrowException(env, "j2cInetSockAddr: InetAddr.getAddress.getAddress is null");
        return false;
    }

    jbyteArray addressBytes = static_cast<jbyteArray>(ip_byte_array);
    // Validated per family below so GetByteArrayRegion never reads out of
    // bounds (which would leave an ArrayIndexOutOfBoundsException pending).
    const jsize addr_len = env->GetArrayLength(addressBytes);

    if (family == JAVA_IPV4_FAMILY) {
        // Deal with Inet4Address instances.
        // We should represent this Inet4Address as an IPv4 sockaddr_in.
        sockaddr_in &sin = reinterpret_cast<sockaddr_in &>(ss);
        if (addr_len != (jsize)sizeof(sin.sin_addr.s_addr)) {
            JNU_ThrowException(env, "j2cInetSockAddr: invalid IPv4 address length");
            return false;
        }
        ss.ss_family = AF_INET;
        sin.sin_port = htons(static_cast<uint16_t>(port));
        jbyte *dst = reinterpret_cast<jbyte *>(&sin.sin_addr.s_addr);
        env->GetByteArrayRegion(addressBytes, 0, addr_len, dst);
        sa_len = sizeof(sockaddr_in);
        return true;
    } else if (family == JAVA_IPV6_FAMILY) {
        sockaddr_in6& sin6 = reinterpret_cast<sockaddr_in6&>(ss);
        if (addr_len != (jsize)sizeof(sin6.sin6_addr.s6_addr)) {
            JNU_ThrowException(env, "j2cInetSockAddr: invalid IPv6 address length");
            return false;
        }
        jclass inet6_addr_cls = env->FindClass("java/net/Inet6Address");
        if (inet6_addr_cls == NULL) {
            return false;
        }
        jmethodID getScopeId = env->GetMethodID(inet6_addr_cls, "getScopeId", "()I");
        if (getScopeId == NULL) {
            return false;
        }
        ss.ss_family = AF_INET6;
        sin6.sin6_port = htons(static_cast<uint16_t>(port));
        // IPv6 address. Copy the bytes...
        jbyte *dst = reinterpret_cast<jbyte *>(&sin6.sin6_addr.s6_addr);
        env->GetByteArrayRegion(addressBytes, 0, addr_len, dst);
        // ...and set the scope id...
        jint scope_id = env->CallIntMethod(inet_address, getScopeId);
        if (env->ExceptionCheck()) {
            return false;
        }
        sin6.sin6_scope_id = scope_id;
        sa_len = sizeof(sockaddr_in6);
        return true;
    }
    JNU_ThrowException(env, "Unknown InetAddress family");
    return false;
}

/**
 * Put a jucx_context back into its idle state: status UCS_INPROGRESS, zero
 * received length, and no Java request or callback attached.
 */
static inline void jucx_context_reset(struct jucx_context* ctx)
{
    ctx->status       = UCS_INPROGRESS;
    ctx->length       = 0;
    ctx->jucx_request = NULL;
    ctx->callback     = NULL;
}

/**
 * Prepare the jucx_context stored in a UCP request: the request pointer is
 * reinterpreted as a jucx_context, reset to idle and its spinlock initialized.
 */
void jucx_request_init(void *request)
{
    struct jucx_context *const context = static_cast<struct jucx_context *>(request);

    jucx_context_reset(context);
    ucs_spinlock_init(&context->lock);
}

/**
 * Return the JNIEnv of the calling thread, attaching the thread to the JVM
 * stored in jvm_global (AttachCurrentThread returns the existing env for an
 * already-attached thread). Attachment failure trips ucs_assert_always.
 */
JNIEnv* get_jni_env()
{
    void *thread_env = NULL;
    const jint attach_status = jvm_global->AttachCurrentThread(&thread_env, NULL);

    ucs_assert_always(attach_status == JNI_OK);
    return static_cast<JNIEnv*>(thread_env);
}

/**
 * Mark a Java UcpRequest as completed. nativeId is cleared; recvSize is only
 * written when ctx is non-NULL and reports a positive received length.
 */
static inline void set_jucx_request_completed(JNIEnv *env, jobject jucx_request,
                                              struct jucx_context *ctx)
{
    env->SetObjectField(jucx_request, native_id_field, NULL);

    const bool has_received_bytes = (ctx != NULL) && (ctx->length > 0);
    if (!has_received_bytes) {
        return;
    }

    env->SetLongField(jucx_request, recv_size_field, ctx->length);
}

/** Invoke callback.onSuccess(request) using the calling thread's JNIEnv. */
static inline void call_on_success(jobject callback, jobject request)
{
    get_jni_env()->CallVoidMethod(callback, on_success, request);
}

/**
 * Report a failed request to Java: log the status (debug level for
 * cancellation, error level otherwise) and invoke
 * callback.onError(int status, String message).
 *
 * The local references created here are released explicitly, because many
 * completions can be delivered before control returns to Java. If onError
 * cannot be resolved or the message string cannot be allocated, the Java
 * exception raised by that JNI call is left pending and onError is skipped.
 */
static inline void call_on_error(jobject callback, ucs_status_t status)
{
    if (status == UCS_ERR_CANCELED) {
        ucs_debug("JUCX: Request canceled");
    } else {
        ucs_error("JUCX: request error: %s", ucs_status_string(status));
    }

    JNIEnv *env = get_jni_env();
    jclass callback_cls = env->GetObjectClass(callback);
    jmethodID on_error = env->GetMethodID(callback_cls, "onError", "(ILjava/lang/String;)V");
    env->DeleteLocalRef(callback_cls);
    if (on_error == NULL) {
        return;
    }

    jstring error_msg = env->NewStringUTF(ucs_status_string(status));
    if (error_msg == NULL) {
        return;
    }

    // Varargs: pass the status explicitly as the jint the "(I...)" signature expects.
    env->CallVoidMethod(callback, on_error, (jint)status, error_msg);
    env->DeleteLocalRef(error_msg);
}

/**
 * Dispatch a completion to a Java UcxCallback: onError(status, message) for
 * any non-UCS_OK status, otherwise a profiled onSuccess(jucx_request).
 */
static inline void jucx_call_callback(jobject callback, jobject jucx_request,
                                      ucs_status_t status)
{
    if (status != UCS_OK) {
        call_on_error(callback, status);
        return;
    }

    UCS_PROFILE_CALL_VOID(call_on_success, callback, jucx_request);
}

/**
 * Common UCP completion callback for requests wrapped by process_request().
 *
 * If process_request() has not attached a Java UcpRequest yet, only the
 * status is recorded; process_request() then completes and frees the request.
 * Otherwise the Java request is marked completed, the user callback (if any)
 * is invoked, both global refs are dropped and the UCP request is freed.
 */
UCS_PROFILE_FUNC_VOID(jucx_request_callback, (request, status), void *request, ucs_status_t status)
{
    struct jucx_context *ctx = (struct jucx_context *)request;
    ucs_spin_lock(&ctx->lock);
    if (ctx->jucx_request == NULL) {
        // here because 1 of 2 reasons:
        // 1. progress is in another thread and got here earlier then process_request happened.
        // 2. this callback is inside ucp_tag_recv_nb function.
        ctx->status = status;
        ucs_spin_unlock(&ctx->lock);
        return;
    }

    JNIEnv *env = get_jni_env();
    set_jucx_request_completed(env, ctx->jucx_request, ctx);

    if (ctx->callback != NULL) {
        jucx_call_callback(ctx->callback, ctx->jucx_request, status);
        env->DeleteGlobalRef(ctx->callback);
    }

    env->DeleteGlobalRef(ctx->jucx_request);
    jucx_context_reset(ctx);
    // ctx (and its lock) live inside the UCP request memory: release the lock
    // first, and do not touch ctx after the request is handed back to UCX.
    // process_request() has already finished with this request, so nothing
    // else can take the lock in between.
    ucs_spin_unlock(&ctx->lock);
    ucp_request_free(request);
}

/**
 * Tag-receive completion callback: stores the received length reported in
 * info into the request's jucx_context, then runs jucx_request_callback.
 */
void recv_callback(void *request, ucs_status_t status, ucp_tag_recv_info_t *info)
{
    static_cast<struct jucx_context *>(request)->length = info->length;
    jucx_request_callback(request, status);
}

/**
 * Stream-receive completion callback: stores length into the request's
 * jucx_context, then runs jucx_request_callback.
 */
void stream_recv_callback(void *request, ucs_status_t status, size_t length)
{
    static_cast<struct jucx_context *>(request)->length = length;
    jucx_request_callback(request, status);
}

/**
 * Wrap the value returned by a non-blocking UCP call in a Java UcpRequest.
 *
 * - Request pointer still in progress: the callback and Java request are
 *   pinned with global refs so jucx_request_callback can finish them.
 * - Request pointer already completed (progress in another thread, or
 *   completion inside the posting call): completed, reported and freed here.
 * - Error pointer: the callback's onError runs, then a Java exception for
 *   the status is thrown unless the callback itself left one pending.
 * - Otherwise (immediate completion): the callback's onSuccess runs.
 *
 * @param request  value returned by the UCP *_nb call.
 * @param callback optional org.openucx.jucx.UcxCallback (may be NULL).
 * @return the new org.openucx.jucx.ucp.UcpRequest object.
 */
UCS_PROFILE_FUNC(jobject, process_request, (request, callback), void *request, jobject callback)
{
    JNIEnv *env = get_jni_env();
    jobject jucx_request = env->NewObject(jucx_request_cls, jucx_request_constructor,
                                          (native_ptr)request);

    if (UCS_PTR_IS_PTR(request)) {
        struct jucx_context *ctx = (struct jucx_context *)request;
        ucs_spin_lock(&ctx->lock);
        if (ctx->status == UCS_INPROGRESS) {
            // request not completed yet, install user callback
            if (callback != NULL) {
                ctx->callback = env->NewGlobalRef(callback);
            }
            ctx->jucx_request = env->NewGlobalRef(jucx_request);
            ucs_spin_unlock(&ctx->lock);
        } else {
            // request was completed whether by progress in other thread or inside
            // ucp_tag_recv_nb function call.
            set_jucx_request_completed(env, jucx_request, ctx);
            if (callback != NULL) {
                jucx_call_callback(callback, jucx_request, ctx->status);
            }
            jucx_context_reset(ctx);
            // ctx lives inside the UCP request memory: unlock before returning the
            // request to UCX and never touch it after ucp_request_free().
            ucs_spin_unlock(&ctx->lock);
            ucp_request_free(request);
        }
    } else {
        set_jucx_request_completed(env, jucx_request, NULL);
        if (UCS_PTR_IS_ERR(request)) {
            ucs_status_t status = UCS_PTR_STATUS(request);
            // Notify the callback before throwing: JNI calls such as invoking
            // onError are not allowed while an exception is pending.
            if (callback != NULL) {
                call_on_error(callback, status);
            }
            if (!env->ExceptionCheck()) {
                JNU_ThrowExceptionByStatus(env, status);
            }
        } else if (callback != NULL) {
            call_on_success(callback, jucx_request);
        }
    }
    return jucx_request;
}

/**
 * Build an already-completed UcpRequest for a stream receive that finished
 * immediately: nativeId is cleared, recvSize is set to length, and the
 * optional callback receives onSuccess.
 *
 * @return the new UcpRequest, or NULL if the object could not be allocated
 *         (the JVM's exception is then left pending).
 */
jobject process_completed_stream_recv(size_t length, jobject callback)
{
    JNIEnv *env = get_jni_env();
    // The constructor signature is "(J)V", so the variadic argument must be a
    // jlong; a bare NULL is not guaranteed to have that width through varargs.
    jobject jucx_request = env->NewObject(jucx_request_cls, jucx_request_constructor,
                                          (jlong)0);
    if (jucx_request == NULL) {
        return NULL;
    }
    env->SetObjectField(jucx_request, native_id_field, NULL);
    env->SetLongField(jucx_request, recv_size_field, (jlong)length);
    if (callback != NULL) {
        jucx_call_callback(callback, jucx_request, UCS_OK);
    }
    return jucx_request;
}

void jucx_connection_handler(ucp_conn_request_h conn_request, void *arg)
{
    jobject jucx_conn_handler = reinterpret_cast<jobject>(arg);

    JNIEnv *env = get_jni_env();

    // Construct connection request class instance
    jclass conn_request_cls = env->FindClass("org/openucx/jucx/ucp/UcpConnectionRequest");
    jmethodID conn_request_constructor = env->GetMethodID(conn_request_cls, "<init>", "(J)V");
    jobject jucx_conn_request = env->NewObject(conn_request_cls, conn_request_constructor,
                                               (native_ptr)conn_request);

    // Call onConnectionRequest method
    jclass jucx_conn_hndl_cls = env->FindClass("org/openucx/jucx/ucp/UcpListenerConnectionHandler");
    jmethodID on_conn_request = env->GetMethodID(jucx_conn_hndl_cls, "onConnectionRequest",
                                       "(Lorg/openucx/jucx/ucp/UcpConnectionRequest;)V");
    env->CallVoidMethod(jucx_conn_handler, on_conn_request, jucx_conn_request);
    env->DeleteGlobalRef(jucx_conn_handler);
}


/** Wrap a native ucp_rkey_h handle in a new org.openucx.jucx.ucp.UcpRemoteKey. */
jobject new_rkey_instance(JNIEnv *env, ucp_rkey_h rkey)
{
    const native_ptr rkey_handle = (native_ptr)rkey;
    return env->NewObject(ucp_rkey_cls, ucp_rkey_cls_constructor, rkey_handle);
}