/*
 * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
 * See file LICENSE for terms.
 */

#include "jucx_common_def.h"
#include "org_openucx_jucx_ucp_UcpEndpoint.h"

/* NOTE(review): the two header names below are presumed from the trailing
 * comments - confirm against the build. */
#include <string.h>    /* memset */
#include <ucp/core/ucp_ep.inl>    /* ucp_ep_peer_name */


/**
 * Endpoint error callback registered on every endpoint created by
 * createEndpointNative. Reports @a status through JNU_ThrowExceptionByStatus
 * using the JNI environment returned by get_jni_env(), then logs the status.
 *
 * @param arg     User argument of the error handler (unused).
 * @param ep      Endpoint the error was reported for (unused).
 * @param status  Error status reported for the endpoint.
 */
static void error_handler(void *arg, ucp_ep_h ep, ucs_status_t status)
{
    JNIEnv* env = get_jni_env();
    JNU_ThrowExceptionByStatus(env, status);
    ucs_error("JUCX: endpoint error handler: %s", ucs_status_string(status));
}

/**
 * Builds a ucp_ep_params_t from the fields of the Java parameters object
 * (only the fields selected by its "fieldMask") and creates an endpoint on
 * the given worker. The error handler field is always added to the mask.
 *
 * @param env            JNI environment.
 * @param cls            Calling Java class (unused).
 * @param ucp_ep_params  Java object holding the endpoint parameters.
 * @param worker_ptr     Native ucp_worker_h handle passed as a jlong.
 *
 * @return The created endpoint handle converted to native_ptr. If the socket
 *         address cannot be converted or ucp_ep_create fails, the failure is
 *         reported through JNU_ThrowExceptionByStatus and 0 is returned.
 */
JNIEXPORT jlong JNICALL
Java_org_openucx_jucx_ucp_UcpEndpoint_createEndpointNative(JNIEnv *env, jclass cls,
                                                           jobject ucp_ep_params,
                                                           jlong worker_ptr)
{
    ucp_ep_params_t ep_params;
    jfieldID field;
    ucp_worker_h ucp_worker = (ucp_worker_h)worker_ptr;
    ucp_ep_h endpoint       = NULL;
    /* ep_params.sockaddr.addr points at worker_addr, so it must stay alive
     * until ucp_ep_create() has returned - keep it at function scope. */
    struct sockaddr_storage worker_addr;
    socklen_t addrlen       = 0;

    memset(&ep_params, 0, sizeof(ep_params));
    memset(&worker_addr, 0, sizeof(struct sockaddr_storage));

    // Get field mask
    jclass ucp_ep_params_class = env->GetObjectClass(ucp_ep_params);
    field = env->GetFieldID(ucp_ep_params_class, "fieldMask", "J");
    ep_params.field_mask = env->GetLongField(ucp_ep_params, field);

    if (ep_params.field_mask & UCP_EP_PARAM_FIELD_REMOTE_ADDRESS) {
        field = env->GetFieldID(ucp_ep_params_class, "ucpAddress", "Ljava/nio/ByteBuffer;");
        jobject buf = env->GetObjectField(ucp_ep_params, field);
        ep_params.address = static_cast<const ucp_address_t*>(env->GetDirectBufferAddress(buf));
    }

    if (ep_params.field_mask & UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE) {
        field = env->GetFieldID(ucp_ep_params_class, "errorHandlingMode", "I");
        ep_params.err_mode = static_cast<ucp_err_handling_mode_t>(env->GetIntField(ucp_ep_params, field));
    }

    if (ep_params.field_mask & UCP_EP_PARAM_FIELD_USER_DATA) {
        field = env->GetFieldID(ucp_ep_params_class, "userData", "Ljava/nio/ByteBuffer;");
        jobject user_data = env->GetObjectField(ucp_ep_params, field);
        ep_params.user_data = env->GetDirectBufferAddress(user_data);
    }

    if (ep_params.field_mask & UCP_EP_PARAM_FIELD_FLAGS) {
        field = env->GetFieldID(ucp_ep_params_class, "flags", "J");
        ep_params.flags = env->GetLongField(ucp_ep_params, field);
    }

    if (ep_params.field_mask & UCP_EP_PARAM_FIELD_SOCK_ADDR) {
        field = env->GetFieldID(ucp_ep_params_class,
                                "socketAddress", "Ljava/net/InetSocketAddress;");
        jobject sock_addr = env->GetObjectField(ucp_ep_params, field);

        if (!j2cInetSockAddr(env, sock_addr, worker_addr, addrlen)) {
            /* The mask requests a socket address but none could be produced;
             * do not hand an unset sockaddr to ucp_ep_create(). */
            if (!env->ExceptionCheck()) {
                JNU_ThrowExceptionByStatus(env, UCS_ERR_INVALID_PARAM);
            }
            return 0;
        }

        ep_params.sockaddr.addr    = (const struct sockaddr*)&worker_addr;
        ep_params.sockaddr.addrlen = addrlen;
    }

    if (ep_params.field_mask & UCP_EP_PARAM_FIELD_CONN_REQUEST) {
        field = env->GetFieldID(ucp_ep_params_class, "connectionRequest", "J");
        ep_params.conn_request = reinterpret_cast<ucp_conn_request_h>(env->GetLongField(ucp_ep_params, field));
    }

    // Always install the native error handler
    ep_params.field_mask     |= UCP_EP_PARAM_FIELD_ERR_HANDLER;
    ep_params.err_handler.cb  = error_handler;
    ep_params.err_handler.arg = NULL;

    ucs_status_t status = ucp_ep_create(ucp_worker, &ep_params, &endpoint);
    if (status != UCS_OK) {
        JNU_ThrowExceptionByStatus(env, status);
    }

    return (native_ptr)endpoint;
}

/**
 * Destroys the endpoint with ucp_ep_destroy.
 *
 * @param ep_ptr  Native ucp_ep_h handle passed as a jlong.
 */
JNIEXPORT void JNICALL
Java_org_openucx_jucx_ucp_UcpEndpoint_destroyEndpointNative(JNIEnv *env, jclass cls,
                                                            jlong ep_ptr)
{
    ucp_ep_destroy((ucp_ep_h)ep_ptr);
}

/**
 * Starts a non-blocking close of the endpoint with ucp_ep_close_nb.
 *
 * @param ep_ptr  Native ucp_ep_h handle passed as a jlong.
 * @param mode    Close mode forwarded to ucp_ep_close_nb.
 *
 * @return Result of process_request for the close request (no callback).
 */
JNIEXPORT jobject JNICALL
Java_org_openucx_jucx_ucp_UcpEndpoint_closeNonBlockingNative(JNIEnv *env, jclass cls,
                                                             jlong ep_ptr, jint mode)
{
    ucs_status_ptr_t request = ucp_ep_close_nb((ucp_ep_h)ep_ptr, mode);

    return process_request(request, NULL);
}

/**
 * Unpacks a remote key from the buffer at @a addr with ucp_ep_rkey_unpack.
 *
 * @param ep_ptr  Native ucp_ep_h handle passed as a jlong.
 * @param addr    Address of the packed remote key buffer passed as a jlong.
 *
 * @return Object produced by new_rkey_instance for the unpacked key, or NULL
 *         after the failure status was reported through
 *         JNU_ThrowExceptionByStatus.
 */
JNIEXPORT jobject JNICALL
Java_org_openucx_jucx_ucp_UcpEndpoint_unpackRemoteKey(JNIEnv *env, jclass cls,
                                                      jlong ep_ptr, jlong addr)
{
    ucp_rkey_h rkey = NULL;

    ucs_status_t status = ucp_ep_rkey_unpack((ucp_ep_h)ep_ptr, (void *)addr, &rkey);
    if (status != UCS_OK) {
        JNU_ThrowExceptionByStatus(env, status);
        /* rkey is not valid here and an exception may be pending, so no
         * further JNI object construction is attempted. */
        return NULL;
    }

    jobject result = new_rkey_instance(env, rkey);

    /* Coverity thinks that rkey is a leaked object here,
     * but it's stored in a UcpRemoteKey object */
    /* coverity[leaked_storage] */
    return result;
}

/**
 * Starts a non-blocking put with ucp_put_nb.
 *
 * @param ep_ptr    Native ucp_ep_h handle passed as a jlong.
 * @param laddr     Local source address.
 * @param size      Number of bytes to write.
 * @param raddr     Remote destination address.
 * @param rkey_ptr  Native ucp_rkey_h handle passed as a jlong.
 * @param callback  Java callback object forwarded to process_request.
 *
 * @return Result of process_request for the put request.
 */
JNIEXPORT jobject JNICALL
Java_org_openucx_jucx_ucp_UcpEndpoint_putNonBlockingNative(JNIEnv *env, jclass cls,
                                                           jlong ep_ptr, jlong laddr,
                                                           jlong size, jlong raddr,
                                                           jlong rkey_ptr, jobject callback)
{
    ucs_status_ptr_t request = ucp_put_nb((ucp_ep_h)ep_ptr, (void *)laddr, size, raddr,
                                          (ucp_rkey_h)rkey_ptr, jucx_request_callback);

    /* jlong values are cast so that they match the %zu conversions */
    ucs_trace_req("JUCX: put_nb request %p to %s, of size: %zu, raddr: %zu",
                  request, ucp_ep_peer_name((ucp_ep_h)ep_ptr), (size_t)size, (size_t)raddr);

    return process_request(request, callback);
}

/**
 * Issues an implicit non-blocking put with ucp_put_nbi; an error status is
 * reported through JNU_ThrowExceptionByStatus.
 *
 * @param ep_ptr    Native ucp_ep_h handle passed as a jlong.
 * @param laddr     Local source address.
 * @param size      Number of bytes to write.
 * @param raddr     Remote destination address.
 * @param rkey_ptr  Native ucp_rkey_h handle passed as a jlong.
 */
JNIEXPORT void JNICALL
Java_org_openucx_jucx_ucp_UcpEndpoint_putNonBlockingImplicitNative(JNIEnv *env, jclass cls,
                                                                   jlong ep_ptr, jlong laddr,
                                                                   jlong size, jlong raddr,
                                                                   jlong rkey_ptr)
{
    ucs_status_t status = ucp_put_nbi((ucp_ep_h)ep_ptr, (void *)laddr, size, raddr,
                                      (ucp_rkey_h)rkey_ptr);

    if (UCS_STATUS_IS_ERR(status)) {
        JNU_ThrowExceptionByStatus(env, status);
    }
}

/**
 * Starts a non-blocking get with ucp_get_nb.
 *
 * @param ep_ptr    Native ucp_ep_h handle passed as a jlong.
 * @param raddr     Remote source address.
 * @param rkey_ptr  Native ucp_rkey_h handle passed as a jlong.
 * @param laddr     Local destination address.
 * @param size      Number of bytes to read.
 * @param callback  Java callback object forwarded to process_request.
 *
 * @return Result of process_request for the get request.
 */
JNIEXPORT jobject JNICALL
Java_org_openucx_jucx_ucp_UcpEndpoint_getNonBlockingNative(JNIEnv *env, jclass cls,
                                                           jlong ep_ptr, jlong raddr,
                                                           jlong rkey_ptr, jlong laddr,
                                                           jlong size, jobject callback)
{
    ucs_status_ptr_t request = ucp_get_nb((ucp_ep_h)ep_ptr, (void *)laddr, size,
                                          raddr, (ucp_rkey_h)rkey_ptr, jucx_request_callback);

    /* jlong values are cast so that they match the %zu conversions */
    ucs_trace_req("JUCX: get_nb request %p to %s, raddr: %zu, size: %zu, result address: %zu",
                  request, ucp_ep_peer_name((ucp_ep_h)ep_ptr), (size_t)raddr, (size_t)size,
                  (size_t)laddr);

    return process_request(request, callback);
}

/**
 * Issues an implicit non-blocking get with ucp_get_nbi; an error status is
 * reported through JNU_ThrowExceptionByStatus.
 *
 * @param ep_ptr    Native ucp_ep_h handle passed as a jlong.
 * @param raddr     Remote source address.
 * @param rkey_ptr  Native ucp_rkey_h handle passed as a jlong.
 * @param laddr     Local destination address.
 * @param size      Number of bytes to read.
 */
JNIEXPORT void JNICALL
Java_org_openucx_jucx_ucp_UcpEndpoint_getNonBlockingImplicitNative(JNIEnv *env, jclass cls,
                                                                   jlong ep_ptr, jlong raddr,
                                                                   jlong rkey_ptr, jlong laddr,
                                                                   jlong size)
{
    ucs_status_t status = ucp_get_nbi((ucp_ep_h)ep_ptr, (void *)laddr, size,
                                      raddr, (ucp_rkey_h)rkey_ptr);

    if (UCS_STATUS_IS_ERR(status)) {
        JNU_ThrowExceptionByStatus(env, status);
    }
}

/**
 * Starts a non-blocking tagged send with ucp_tag_send_nb, using a contiguous
 * datatype with element size 1.
 *
 * @param ep_ptr    Native ucp_ep_h handle passed as a jlong.
 * @param addr      Address of the send buffer.
 * @param size      Element count passed to ucp_tag_send_nb.
 * @param tag       Message tag.
 * @param callback  Java callback object forwarded to process_request.
 *
 * @return Result of process_request for the send request.
 */
JNIEXPORT jobject JNICALL
Java_org_openucx_jucx_ucp_UcpEndpoint_sendTaggedNonBlockingNative(JNIEnv *env, jclass cls,
                                                                  jlong ep_ptr, jlong addr,
                                                                  jlong size, jlong tag,
                                                                  jobject callback)
{
    ucs_status_ptr_t request = ucp_tag_send_nb((ucp_ep_h)ep_ptr, (void *)addr, size,
                                               ucp_dt_make_contig(1), tag,
                                               jucx_request_callback);

    /* jlong values are cast so that they match the %zu / %ld conversions */
    ucs_trace_req("JUCX: send_tag_nb request %p to %s, size: %zu, tag: %ld",
                  request, ucp_ep_peer_name((ucp_ep_h)ep_ptr), (size_t)size, (long)tag);

    return process_request(request, callback);
}

/**
 * Starts a non-blocking stream send with ucp_stream_send_nb (flags = 0),
 * using a contiguous datatype with element size 1.
 *
 * @param ep_ptr    Native ucp_ep_h handle passed as a jlong.
 * @param addr      Address of the send buffer.
 * @param size      Element count passed to ucp_stream_send_nb.
 * @param callback  Java callback object forwarded to process_request.
 *
 * @return Result of process_request for the send request.
 */
JNIEXPORT jobject JNICALL
Java_org_openucx_jucx_ucp_UcpEndpoint_sendStreamNonBlockingNative(JNIEnv *env, jclass cls,
                                                                  jlong ep_ptr, jlong addr,
                                                                  jlong size, jobject callback)
{
    ucs_status_ptr_t request = ucp_stream_send_nb((ucp_ep_h)ep_ptr, (void *)addr, size,
                                                  ucp_dt_make_contig(1),
                                                  jucx_request_callback, 0);

    /* jlong value is cast so that it matches the %zu conversion */
    ucs_trace_req("JUCX: send_stream_nb request %p to %s, size: %zu",
                  request, ucp_ep_peer_name((ucp_ep_h)ep_ptr), (size_t)size);

    return process_request(request, callback);
}

/**
 * Starts a non-blocking stream receive with ucp_stream_recv_nb, using a
 * contiguous datatype with element size 1.
 *
 * @param ep_ptr    Native ucp_ep_h handle passed as a jlong.
 * @param addr      Address of the receive buffer.
 * @param size      Element count passed to ucp_stream_recv_nb.
 * @param flags     Flags forwarded to ucp_stream_recv_nb.
 * @param callback  Java callback object.
 *
 * @return Result of process_completed_stream_recv when the request is NULL,
 *         otherwise the result of process_request.
 */
JNIEXPORT jobject JNICALL
Java_org_openucx_jucx_ucp_UcpEndpoint_recvStreamNonBlockingNative(JNIEnv *env, jclass cls,
                                                                  jlong ep_ptr, jlong addr,
                                                                  jlong size, jlong flags,
                                                                  jobject callback)
{
    size_t rlength = 0;

    ucs_status_ptr_t request = ucp_stream_recv_nb((ucp_ep_h)ep_ptr, (void *)addr, size,
                                                  ucp_dt_make_contig(1), stream_recv_callback,
                                                  &rlength, flags);

    /* jlong value is cast so that it matches the %zu conversion */
    ucs_trace_req("JUCX: recv_stream_nb request %p to %s, size: %zu",
                  request, ucp_ep_peer_name((ucp_ep_h)ep_ptr), (size_t)size);

    if (request == NULL) {
        // If request completed immediately.
        return process_completed_stream_recv(rlength, callback);
    }

    return process_request(request, callback);
}

/**
 * Starts a non-blocking flush of the endpoint with ucp_ep_flush_nb (flags = 0).
 *
 * @param ep_ptr    Native ucp_ep_h handle passed as a jlong.
 * @param callback  Java callback object forwarded to process_request.
 *
 * @return Result of process_request for the flush request.
 */
JNIEXPORT jobject JNICALL
Java_org_openucx_jucx_ucp_UcpEndpoint_flushNonBlockingNative(JNIEnv *env, jclass cls,
                                                             jlong ep_ptr, jobject callback)
{
    ucs_status_ptr_t request = ucp_ep_flush_nb((ucp_ep_h)ep_ptr, 0, jucx_request_callback);

    return process_request(request, callback);
}