Blob Blame History Raw
/*
 * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
 * See file LICENSE for terms.
 */

#include "jucx_common_def.h"
#include "org_openucx_jucx_ucp_UcpEndpoint.h"

#include <string.h>    /* memset */

#include <ucp/core/ucp_ep.inl> /* ucp_ep_peer_name */


/* UCP endpoint error callback (installed by createEndpointNative).
 * Raises a Java exception for `status` on the JNIEnv returned by
 * get_jni_env(), then logs the status string. `arg` and `ep` are unused. */
static void error_handler(void *arg, ucp_ep_h ep, ucs_status_t status)
{
    JNU_ThrowExceptionByStatus(get_jni_env(), status);
    ucs_error("JUCX: endpoint error handler: %s", ucs_status_string(status));
}

/* Creates a UCP endpoint on the worker `worker_ptr`.
 *
 * The endpoint is configured from the Java parameters object
 * `ucp_ep_params`. Only the members whose UCP_EP_PARAM_FIELD_* bit is set in
 * its "fieldMask" field are read. error_handler is always installed as the
 * endpoint error callback.
 *
 * Returns the new ucp_ep_h as a jlong. Returns 0, with a Java exception
 * pending, if the socket address cannot be converted or ucp_ep_create()
 * fails. */
JNIEXPORT jlong JNICALL
Java_org_openucx_jucx_ucp_UcpEndpoint_createEndpointNative(JNIEnv *env, jclass cls,
                                                           jobject ucp_ep_params,
                                                           jlong worker_ptr)
{
    ucp_ep_params_t ep_params;
    jfieldID field;
    ucp_worker_h ucp_worker = (ucp_worker_h)worker_ptr;
    ucp_ep_h endpoint = NULL;
    /* Deliberately at function scope: ep_params.sockaddr.addr may point at
     * worker_addr, so it must stay alive until ucp_ep_create() returns. */
    struct sockaddr_storage worker_addr;
    socklen_t addrlen = 0;

    /* Zero everything so fields not selected by field_mask (and
     * err_handler.arg) never carry stack garbage into UCP. */
    memset(&ep_params, 0, sizeof(ep_params));
    memset(&worker_addr, 0, sizeof(worker_addr));

    // Get field mask
    jclass ucp_ep_params_class = env->GetObjectClass(ucp_ep_params);
    field = env->GetFieldID(ucp_ep_params_class, "fieldMask", "J");
    ep_params.field_mask = env->GetLongField(ucp_ep_params, field);

    if (ep_params.field_mask & UCP_EP_PARAM_FIELD_REMOTE_ADDRESS) {
        field = env->GetFieldID(ucp_ep_params_class, "ucpAddress", "Ljava/nio/ByteBuffer;");
        jobject buf = env->GetObjectField(ucp_ep_params, field);
        ep_params.address = static_cast<const ucp_address_t *>(env->GetDirectBufferAddress(buf));
    }

    if (ep_params.field_mask & UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE) {
        field = env->GetFieldID(ucp_ep_params_class, "errorHandlingMode", "I");
        ep_params.err_mode =  static_cast<ucp_err_handling_mode_t>(env->GetIntField(ucp_ep_params, field));
    }

    if (ep_params.field_mask & UCP_EP_PARAM_FIELD_USER_DATA) {
        field = env->GetFieldID(ucp_ep_params_class, "userData", "Ljava/nio/ByteBuffer;");
        jobject user_data = env->GetObjectField(ucp_ep_params, field);
        ep_params.user_data = env->GetDirectBufferAddress(user_data);
    }

    if (ep_params.field_mask & UCP_EP_PARAM_FIELD_FLAGS) {
        field = env->GetFieldID(ucp_ep_params_class, "flags", "J");
        ep_params.flags = env->GetLongField(ucp_ep_params, field);
    }

    if (ep_params.field_mask & UCP_EP_PARAM_FIELD_SOCK_ADDR) {
        field = env->GetFieldID(ucp_ep_params_class,
                                "socketAddress", "Ljava/net/InetSocketAddress;");
        jobject sock_addr = env->GetObjectField(ucp_ep_params, field);

        if (!j2cInetSockAddr(env, sock_addr, worker_addr, addrlen)) {
            /* A socket address was requested but could not be converted.
             * Creating the endpoint with an unset sockaddr would be invalid.
             * NOTE(review): j2cInetSockAddr may already throw on failure,
             * so only raise our own exception if none is pending. */
            if (!env->ExceptionCheck()) {
                JNU_ThrowExceptionByStatus(env, UCS_ERR_INVALID_PARAM);
            }
            return 0;
        }

        ep_params.sockaddr.addr = (const struct sockaddr*)&worker_addr;
        ep_params.sockaddr.addrlen = addrlen;
    }

    if (ep_params.field_mask & UCP_EP_PARAM_FIELD_CONN_REQUEST) {
        field = env->GetFieldID(ucp_ep_params_class, "connectionRequest", "J");
        ep_params.conn_request = reinterpret_cast<ucp_conn_request_h>(env->GetLongField(ucp_ep_params, field));
    }

    ep_params.field_mask |= UCP_EP_PARAM_FIELD_ERR_HANDLER;
    ep_params.err_handler.cb = error_handler;
    ep_params.err_handler.arg = NULL; /* error_handler does not use its arg */

    ucs_status_t status = ucp_ep_create(ucp_worker, &ep_params, &endpoint);
    if (status != UCS_OK) {
        JNU_ThrowExceptionByStatus(env, status);
        return 0;
    }

    return (native_ptr)endpoint;
}

/* Destroys the UCP endpoint whose raw handle is carried in `ep_ptr`
 * (presumably a value returned earlier by createEndpointNative). */
JNIEXPORT void JNICALL
Java_org_openucx_jucx_ucp_UcpEndpoint_destroyEndpointNative(JNIEnv *env, jclass cls,
                                                            jlong ep_ptr)
{
    ucp_ep_h endpoint = reinterpret_cast<ucp_ep_h>(ep_ptr);

    ucp_ep_destroy(endpoint);
}

/* Starts a non-blocking close of the endpoint `ep_ptr` using close `mode`.
 * The resulting request is handed to process_request() with no Java
 * callback attached. */
JNIEXPORT jobject JNICALL
Java_org_openucx_jucx_ucp_UcpEndpoint_closeNonBlockingNative(JNIEnv *env, jclass cls,
                                                             jlong ep_ptr, jint mode)
{
    ucp_ep_h endpoint = reinterpret_cast<ucp_ep_h>(ep_ptr);

    return process_request(ucp_ep_close_nb(endpoint, mode), NULL);
}

/* Unpacks the packed remote key at native address `addr` for endpoint
 * `ep_ptr` and wraps it in a Java object via new_rkey_instance().
 *
 * Returns NULL, with a Java exception pending, if ucp_ep_rkey_unpack()
 * fails. In that case rkey is not valid, and JNI object creation must not be
 * attempted while an exception is pending. */
JNIEXPORT jobject JNICALL
Java_org_openucx_jucx_ucp_UcpEndpoint_unpackRemoteKey(JNIEnv *env, jclass cls,
                                                      jlong ep_ptr, jlong addr)
{
    ucp_rkey_h rkey = NULL;

    ucs_status_t status = ucp_ep_rkey_unpack((ucp_ep_h)ep_ptr, (void *)addr, &rkey);
    if (status != UCS_OK) {
        JNU_ThrowExceptionByStatus(env, status);
        return NULL;
    }

    jobject result = new_rkey_instance(env, rkey);

    /* Coverity thinks that rkey is a leaked object here,
     * but it's stored in a UcpRemoteKey object */
    /* coverity[leaked_storage] */
    return result;
}

/* Issues a non-blocking put of `size` bytes from local address `laddr` to
 * remote address `raddr`, accessed with remote key `rkey_ptr`. The request
 * and the Java `callback` are handed to process_request(). */
JNIEXPORT jobject JNICALL
Java_org_openucx_jucx_ucp_UcpEndpoint_putNonBlockingNative(JNIEnv *env, jclass cls,
                                                           jlong ep_ptr, jlong laddr,
                                                           jlong size, jlong raddr,
                                                           jlong rkey_ptr, jobject callback)
{
    ucs_status_ptr_t request = ucp_put_nb((ucp_ep_h)ep_ptr, (void *)laddr, size, raddr,
                                          (ucp_rkey_h)rkey_ptr, jucx_request_callback);

    /* jlong is a signed 64-bit type whose underlying C type varies by
     * platform; cast so each argument matches its %zu conversion. */
    ucs_trace_req("JUCX: put_nb request %p to %s, of size: %zu, raddr: %zu",
                  request, ucp_ep_peer_name((ucp_ep_h)ep_ptr),
                  (size_t)size, (size_t)raddr);
    return process_request(request, callback);
}

/* Issues an implicit non-blocking put of `size` bytes from `laddr` to remote
 * address `raddr` using remote key `rkey_ptr`. No request is returned; only
 * an error status turns into a Java exception. */
JNIEXPORT void JNICALL
Java_org_openucx_jucx_ucp_UcpEndpoint_putNonBlockingImplicitNative(JNIEnv *env, jclass cls,
                                                                   jlong ep_ptr, jlong laddr,
                                                                   jlong size, jlong raddr,
                                                                   jlong rkey_ptr)
{
    ucp_ep_h   endpoint  = reinterpret_cast<ucp_ep_h>(ep_ptr);
    ucp_rkey_h rkey      = reinterpret_cast<ucp_rkey_h>(rkey_ptr);
    void      *local_buf = reinterpret_cast<void *>(laddr);

    ucs_status_t put_status = ucp_put_nbi(endpoint, local_buf, size, raddr, rkey);
    if (!UCS_STATUS_IS_ERR(put_status)) {
        return;
    }

    JNU_ThrowExceptionByStatus(env, put_status);
}

/* Issues a non-blocking get of `size` bytes from remote address `raddr`,
 * accessed with remote key `rkey_ptr`, into local address `laddr`. The
 * request and the Java `callback` are handed to process_request(). */
JNIEXPORT jobject JNICALL
Java_org_openucx_jucx_ucp_UcpEndpoint_getNonBlockingNative(JNIEnv *env, jclass cls,
                                                           jlong ep_ptr, jlong raddr,
                                                           jlong rkey_ptr, jlong laddr,
                                                           jlong size, jobject callback)
{
    ucs_status_ptr_t request = ucp_get_nb((ucp_ep_h)ep_ptr, (void *)laddr, size,
                                          raddr, (ucp_rkey_h)rkey_ptr, jucx_request_callback);

    /* jlong is a signed 64-bit type whose underlying C type varies by
     * platform; cast so each argument matches its %zu conversion. */
    ucs_trace_req("JUCX: get_nb request %p to %s, raddr: %zu, size: %zu, result address: %zu",
                  request, ucp_ep_peer_name((ucp_ep_h)ep_ptr),
                  (size_t)raddr, (size_t)size, (size_t)laddr);
    return process_request(request, callback);
}

/* Issues an implicit non-blocking get of `size` bytes from remote address
 * `raddr` (remote key `rkey_ptr`) into local address `laddr`. Only an error
 * status turns into a Java exception. */
JNIEXPORT void JNICALL
Java_org_openucx_jucx_ucp_UcpEndpoint_getNonBlockingImplicitNative(JNIEnv *env, jclass cls,
                                                                   jlong ep_ptr, jlong raddr,
                                                                   jlong rkey_ptr, jlong laddr,
                                                                   jlong size)
{
    ucp_ep_h   endpoint = reinterpret_cast<ucp_ep_h>(ep_ptr);
    ucp_rkey_h rkey     = reinterpret_cast<ucp_rkey_h>(rkey_ptr);
    void      *dest_buf = reinterpret_cast<void *>(laddr);

    ucs_status_t get_status = ucp_get_nbi(endpoint, dest_buf, size, raddr, rkey);
    if (!UCS_STATUS_IS_ERR(get_status)) {
        return;
    }

    JNU_ThrowExceptionByStatus(env, get_status);
}

/* Sends `size` contiguous bytes at `addr` with tag `tag` using a
 * non-blocking tagged send. The request and the Java `callback` are handed
 * to process_request(). */
JNIEXPORT jobject JNICALL
Java_org_openucx_jucx_ucp_UcpEndpoint_sendTaggedNonBlockingNative(JNIEnv *env, jclass cls,
                                                                  jlong ep_ptr, jlong addr,
                                                                  jlong size, jlong tag,
                                                                  jobject callback)
{
    ucs_status_ptr_t request = ucp_tag_send_nb((ucp_ep_h)ep_ptr, (void *)addr, size,
                                               ucp_dt_make_contig(1), tag, jucx_request_callback);

    /* jlong may be `long` or `long long` depending on the platform; cast so
     * the arguments match the %zu / %ld conversions exactly. */
    ucs_trace_req("JUCX: send_tag_nb request %p to %s, size: %zu, tag: %ld",
                  request, ucp_ep_peer_name((ucp_ep_h)ep_ptr),
                  (size_t)size, (long)tag);
    return process_request(request, callback);
}

/* Sends `size` contiguous bytes at `addr` on the endpoint's stream with no
 * flags. The request and the Java `callback` are handed to
 * process_request(). */
JNIEXPORT jobject JNICALL
Java_org_openucx_jucx_ucp_UcpEndpoint_sendStreamNonBlockingNative(JNIEnv *env, jclass cls,
                                                                  jlong ep_ptr, jlong addr,
                                                                  jlong size, jobject callback)
{
    ucs_status_ptr_t request = ucp_stream_send_nb((ucp_ep_h)ep_ptr, (void *)addr, size,
                                                  ucp_dt_make_contig(1), jucx_request_callback, 0);

    /* Cast the signed jlong so it matches the %zu conversion. */
    ucs_trace_req("JUCX: send_stream_nb request %p to %s, size: %zu",
                  request, ucp_ep_peer_name((ucp_ep_h)ep_ptr), (size_t)size);
    return process_request(request, callback);
}

/* Posts a non-blocking stream receive of up to `size` bytes into `addr`,
 * passing `flags` through to ucp_stream_recv_nb().
 *
 * If UCP returns NULL (completed immediately), the received length in
 * rlength is reported via process_completed_stream_recv(). Otherwise the
 * request is handed to process_request() with the Java `callback`. */
JNIEXPORT jobject JNICALL
Java_org_openucx_jucx_ucp_UcpEndpoint_recvStreamNonBlockingNative(JNIEnv *env, jclass cls,
                                                                  jlong ep_ptr, jlong addr,
                                                                  jlong size, jlong flags,
                                                                  jobject callback)
{
    size_t rlength;
    ucs_status_ptr_t request = ucp_stream_recv_nb((ucp_ep_h)ep_ptr, (void *)addr, size,
                                                  ucp_dt_make_contig(1), stream_recv_callback,
                                                  &rlength, flags);

    /* Cast the signed jlong so it matches the %zu conversion. */
    ucs_trace_req("JUCX: recv_stream_nb request %p to %s, size: %zu",
                  request, ucp_ep_peer_name((ucp_ep_h)ep_ptr), (size_t)size);

    if (request == NULL) {
        // The request completed immediately; rlength holds the received length.
        return process_completed_stream_recv(rlength, callback);
    }

    return process_request(request, callback);
}

/* Starts a non-blocking flush of all outstanding operations on the
 * endpoint `ep_ptr` (flags = 0). The flush request and the Java `callback`
 * are handed to process_request(). */
JNIEXPORT jobject JNICALL
Java_org_openucx_jucx_ucp_UcpEndpoint_flushNonBlockingNative(JNIEnv *env, jclass cls,
                                                             jlong ep_ptr,
                                                             jobject callback)
{
    ucp_ep_h endpoint = reinterpret_cast<ucp_ep_h>(ep_ptr);
    ucs_status_ptr_t flush_req = ucp_ep_flush_nb(endpoint, 0, jucx_request_callback);

    return process_request(flush_req, callback);
}