Blob Blame History Raw
/*
 * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
 * See file LICENSE for terms.
 */

#include "jucx_common_def.h"
#include "org_openucx_jucx_ucp_UcpWorker.h"

/**
 * Bridge method for creating ucp_worker from java
 */
JNIEXPORT jlong JNICALL
Java_org_openucx_jucx_ucp_UcpWorker_createWorkerNative(JNIEnv *env, jclass cls,
                                                       jobject jucx_worker_params,
                                                       jlong context_ptr)
{
    ucp_worker_params_t worker_params = { 0 };
    ucp_worker_h ucp_worker;
    ucp_context_h ucp_context = (ucp_context_h)context_ptr;
    jfieldID field;

    jclass jucx_param_class = env->GetObjectClass(jucx_worker_params);
    field = env->GetFieldID(jucx_param_class, "fieldMask", "J");
    worker_params.field_mask = env->GetLongField(jucx_worker_params, field);

    if (worker_params.field_mask & UCP_WORKER_PARAM_FIELD_THREAD_MODE) {
        field = env->GetFieldID(jucx_param_class, "threadMode", "I");
        worker_params.thread_mode = static_cast<ucs_thread_mode_t>(
            env->GetIntField(jucx_worker_params, field));
    }

    if (worker_params.field_mask & UCP_WORKER_PARAM_FIELD_CPU_MASK) {
        ucs_cpu_set_t cpu_mask;
        UCS_CPU_ZERO(&cpu_mask);
        field = env->GetFieldID(jucx_param_class, "cpuMask", "Ljava/util/BitSet;");
        jobject cpu_mask_bitset = env->GetObjectField(jucx_worker_params, field);
        jclass bitset_class = env->FindClass("java/util/BitSet");
        jmethodID next_set_bit = env->GetMethodID(bitset_class, "nextSetBit", "(I)I");
        for (jint bit_index = env->CallIntMethod(cpu_mask_bitset, next_set_bit, 0); bit_index >=0;
                  bit_index = env->CallIntMethod(cpu_mask_bitset, next_set_bit, bit_index + 1)) {
            UCS_CPU_SET(bit_index, &cpu_mask);
        }
        worker_params.cpu_mask = cpu_mask;
    }


    if (worker_params.field_mask & UCP_WORKER_PARAM_FIELD_EVENTS) {
        field = env->GetFieldID(jucx_param_class, "events", "J");
        worker_params.events = env->GetLongField(jucx_worker_params, field);
    }

    if (worker_params.field_mask & UCP_WORKER_PARAM_FIELD_USER_DATA) {
        field = env->GetFieldID(jucx_param_class, "userData", "Ljava/nio/ByteBuffer;");
        jobject user_data = env->GetObjectField(jucx_worker_params, field);
        worker_params.user_data = env->GetDirectBufferAddress(user_data);
    }

    if (worker_params.field_mask & UCP_WORKER_PARAM_FIELD_EVENT_FD) {
        field = env->GetFieldID(jucx_param_class, "eventFD", "I");
        worker_params.event_fd = env->GetIntField(jucx_worker_params, field);
    }

    ucs_status_t status = ucp_worker_create(ucp_context, &worker_params, &ucp_worker);
    if (status != UCS_OK) {
        JNU_ThrowExceptionByStatus(env, status);
    }
    return (native_ptr)ucp_worker;
}

/**
 * Bridge method for destroying a ucp_worker created by createWorkerNative.
 *
 * @param ucp_worker_ptr  native ucp_worker_h handle to destroy.
 */
JNIEXPORT void JNICALL
Java_org_openucx_jucx_ucp_UcpWorker_releaseWorkerNative(JNIEnv *env, jclass cls,
                                                        jlong ucp_worker_ptr)
{
    ucp_worker_h worker = (ucp_worker_h)ucp_worker_ptr;

    ucp_worker_destroy(worker);
}


/**
 * Bridge method returning the worker's address as a direct ByteBuffer.
 *
 * The buffer wraps memory owned by UCX (no copy is made). It must be passed
 * back to releaseAddressNative to free it.
 *
 * @param ucp_worker_ptr  native ucp_worker_h handle.
 * @return a direct ByteBuffer over the worker address, or NULL on failure.
 *         If ucp_worker_get_address fails, a Java exception is thrown via
 *         JNU_ThrowExceptionByStatus.
 */
JNIEXPORT jobject JNICALL
Java_org_openucx_jucx_ucp_UcpWorker_workerGetAddressNative(JNIEnv *env, jclass cls,
                                                           jlong ucp_worker_ptr)
{
    ucp_worker_h ucp_worker = (ucp_worker_h)ucp_worker_ptr;
    ucp_address_t *addr = NULL;
    size_t len = 0;
    ucs_status_t status;

    status = ucp_worker_get_address(ucp_worker, &addr, &len);

    if (status != UCS_OK) {
        JNU_ThrowExceptionByStatus(env, status);
        return NULL;
    }

    jobject buffer = env->NewDirectByteBuffer(addr, (jlong)len);
    if (buffer == NULL) {
        /* NewDirectByteBuffer failed, so Java will never call
         * releaseAddressNative for this address. Release it here so it is not
         * leaked. */
        ucp_worker_release_address(ucp_worker, addr);
    }
    return buffer;
}

/**
 * Bridge method releasing a worker address previously returned by
 * workerGetAddressNative.
 *
 * @param ucp_worker_ptr  native ucp_worker_h handle that produced the address.
 * @param ucp_address     direct ByteBuffer wrapping the address memory.
 */
JNIEXPORT void JNICALL
Java_org_openucx_jucx_ucp_UcpWorker_releaseAddressNative(JNIEnv *env, jclass cls,
                                                         jlong ucp_worker_ptr,
                                                         jobject ucp_address)
{
    ucp_worker_h worker = (ucp_worker_h)ucp_worker_ptr;
    void *raw_address = env->GetDirectBufferAddress(ucp_address);

    ucp_worker_release_address(worker, (ucp_address_t *)raw_address);
}

/**
 * Bridge method running one round of ucp_worker_progress.
 *
 * @param ucp_worker_ptr  native ucp_worker_h handle.
 * @return the value returned by ucp_worker_progress, converted to jint.
 */
JNIEXPORT jint JNICALL
Java_org_openucx_jucx_ucp_UcpWorker_progressWorkerNative(JNIEnv *env, jclass cls, jlong ucp_worker_ptr)
{
    ucp_worker_h worker = (ucp_worker_h)ucp_worker_ptr;
    unsigned progressed = ucp_worker_progress(worker);

    return (jint)progressed;
}

/**
 * Bridge method starting a non-blocking flush of the worker.
 *
 * The UCX request from ucp_worker_flush_nb (completion handled by
 * jucx_request_callback) is passed to process_request together with the Java
 * callback. The result of process_request is returned to Java.
 *
 * @param ucp_worker_ptr  native ucp_worker_h handle.
 * @param callback        Java callback object forwarded to process_request.
 */
JNIEXPORT jobject JNICALL
Java_org_openucx_jucx_ucp_UcpWorker_flushNonBlockingNative(JNIEnv *env, jclass cls,
                                                           jlong ucp_worker_ptr,
                                                           jobject callback)
{
    ucp_worker_h worker = (ucp_worker_h)ucp_worker_ptr;
    ucs_status_ptr_t flush_request;

    flush_request = ucp_worker_flush_nb(worker, 0, jucx_request_callback);
    return process_request(flush_request, callback);
}

/**
 * Bridge method calling ucp_worker_wait on the worker.
 *
 * @param ucp_worker_ptr  native ucp_worker_h handle.
 * Throws a Java exception via JNU_ThrowExceptionByStatus if the call does not
 * return UCS_OK.
 */
JNIEXPORT void JNICALL
Java_org_openucx_jucx_ucp_UcpWorker_waitWorkerNative(JNIEnv *env, jclass cls, jlong ucp_worker_ptr)
{
    ucp_worker_h worker = (ucp_worker_h)ucp_worker_ptr;
    const ucs_status_t wait_status = ucp_worker_wait(worker);

    if (wait_status == UCS_OK) {
        return;
    }
    JNU_ThrowExceptionByStatus(env, wait_status);
}

/**
 * Bridge method calling ucp_worker_signal on the worker.
 *
 * @param ucp_worker_ptr  native ucp_worker_h handle.
 * Throws a Java exception via JNU_ThrowExceptionByStatus if the call does not
 * return UCS_OK.
 */
JNIEXPORT void JNICALL
Java_org_openucx_jucx_ucp_UcpWorker_signalWorkerNative(JNIEnv *env, jclass cls, jlong ucp_worker_ptr)
{
    ucp_worker_h worker = (ucp_worker_h)ucp_worker_ptr;
    const ucs_status_t signal_status = ucp_worker_signal(worker);

    if (signal_status == UCS_OK) {
        return;
    }
    JNU_ThrowExceptionByStatus(env, signal_status);
}

/**
 * Bridge method posting a non-blocking tagged receive on the worker.
 *
 * Receives up to `size` bytes as a contiguous byte datatype into the memory at
 * `laddr`, matching `tag` under `tagMask`. Completion is handled by
 * recv_callback. The resulting request is passed to process_request with the
 * Java callback.
 *
 * @param ucp_worker_ptr  native ucp_worker_h handle.
 * @param laddr           native address of the receive buffer.
 * @param size            receive buffer length in bytes.
 * @param tag             tag to match.
 * @param tagMask         mask applied to tags when matching.
 * @param callback        Java callback object forwarded to process_request.
 */
JNIEXPORT jobject JNICALL
Java_org_openucx_jucx_ucp_UcpWorker_recvTaggedNonBlockingNative(JNIEnv *env, jclass cls,
                                                                jlong ucp_worker_ptr,
                                                                jlong laddr, jlong size,
                                                                jlong tag, jlong tagMask,
                                                                jobject callback)
{
    ucs_status_ptr_t request = ucp_tag_recv_nb((ucp_worker_h)ucp_worker_ptr,
                                                (void *)laddr, size,
                                                ucp_dt_make_contig(1), tag, tagMask,
                                                recv_callback);

    /* jlong is not size_t/long on every ABI: cast explicitly so the arguments
     * match the printf-style conversion specifiers. */
    ucs_trace_req("JUCX: recv_nb request %p, msg size: %zu, tag: %lld",
                  request, (size_t)size, (long long)tag);

    return process_request(request, callback);
}

/**
 * Bridge method passing an outstanding UCX request to ucp_request_cancel.
 *
 * @param ucp_worker_ptr   native ucp_worker_h handle that owns the request.
 * @param ucp_request_ptr  native pointer to the UCX request.
 */
JNIEXPORT void JNICALL
Java_org_openucx_jucx_ucp_UcpWorker_cancelRequestNative(JNIEnv *env, jclass cls,
                                                        jlong ucp_worker_ptr,
                                                        jlong ucp_request_ptr)
{
    ucp_worker_h worker = (ucp_worker_h)ucp_worker_ptr;
    void *pending_request = (void *)ucp_request_ptr;

    ucp_request_cancel(worker, pending_request);
}