/* * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. * See file LICENSE for terms. */ #include "jucx_common_def.h" #include "org_openucx_jucx_ucp_UcpWorker.h" /** * Bridge method for creating ucp_worker from java */ JNIEXPORT jlong JNICALL Java_org_openucx_jucx_ucp_UcpWorker_createWorkerNative(JNIEnv *env, jclass cls, jobject jucx_worker_params, jlong context_ptr) { ucp_worker_params_t worker_params = { 0 }; ucp_worker_h ucp_worker; ucp_context_h ucp_context = (ucp_context_h)context_ptr; jfieldID field; jclass jucx_param_class = env->GetObjectClass(jucx_worker_params); field = env->GetFieldID(jucx_param_class, "fieldMask", "J"); worker_params.field_mask = env->GetLongField(jucx_worker_params, field); if (worker_params.field_mask & UCP_WORKER_PARAM_FIELD_THREAD_MODE) { field = env->GetFieldID(jucx_param_class, "threadMode", "I"); worker_params.thread_mode = static_cast( env->GetIntField(jucx_worker_params, field)); } if (worker_params.field_mask & UCP_WORKER_PARAM_FIELD_CPU_MASK) { ucs_cpu_set_t cpu_mask; UCS_CPU_ZERO(&cpu_mask); field = env->GetFieldID(jucx_param_class, "cpuMask", "Ljava/util/BitSet;"); jobject cpu_mask_bitset = env->GetObjectField(jucx_worker_params, field); jclass bitset_class = env->FindClass("java/util/BitSet"); jmethodID next_set_bit = env->GetMethodID(bitset_class, "nextSetBit", "(I)I"); for (jint bit_index = env->CallIntMethod(cpu_mask_bitset, next_set_bit, 0); bit_index >=0; bit_index = env->CallIntMethod(cpu_mask_bitset, next_set_bit, bit_index + 1)) { UCS_CPU_SET(bit_index, &cpu_mask); } worker_params.cpu_mask = cpu_mask; } if (worker_params.field_mask & UCP_WORKER_PARAM_FIELD_EVENTS) { field = env->GetFieldID(jucx_param_class, "events", "J"); worker_params.events = env->GetLongField(jucx_worker_params, field); } if (worker_params.field_mask & UCP_WORKER_PARAM_FIELD_USER_DATA) { field = env->GetFieldID(jucx_param_class, "userData", "Ljava/nio/ByteBuffer;"); jobject user_data = env->GetObjectField(jucx_worker_params, field); worker_params.user_data = env->GetDirectBufferAddress(user_data); } if (worker_params.field_mask & UCP_WORKER_PARAM_FIELD_EVENT_FD) { field = env->GetFieldID(jucx_param_class, "eventFD", "I"); worker_params.event_fd = env->GetIntField(jucx_worker_params, field); } ucs_status_t status = ucp_worker_create(ucp_context, &worker_params, &ucp_worker); if (status != UCS_OK) { JNU_ThrowExceptionByStatus(env, status); } return (native_ptr)ucp_worker; } JNIEXPORT void JNICALL Java_org_openucx_jucx_ucp_UcpWorker_releaseWorkerNative(JNIEnv *env, jclass cls, jlong ucp_worker_ptr) { ucp_worker_destroy((ucp_worker_h)ucp_worker_ptr); } JNIEXPORT jobject JNICALL Java_org_openucx_jucx_ucp_UcpWorker_workerGetAddressNative(JNIEnv *env, jclass cls, jlong ucp_worker_ptr) { ucp_address_t *addr; size_t len; ucs_status_t status; status = ucp_worker_get_address((ucp_worker_h)ucp_worker_ptr, &addr, &len); if (status != UCS_OK) { JNU_ThrowExceptionByStatus(env, status); return NULL; } return env->NewDirectByteBuffer(addr, len); } JNIEXPORT void JNICALL Java_org_openucx_jucx_ucp_UcpWorker_releaseAddressNative(JNIEnv *env, jclass cls, jlong ucp_worker_ptr, jobject ucp_address) { ucp_worker_release_address((ucp_worker_h)ucp_worker_ptr, (ucp_address_t *)env->GetDirectBufferAddress(ucp_address)); } JNIEXPORT jint JNICALL Java_org_openucx_jucx_ucp_UcpWorker_progressWorkerNative(JNIEnv *env, jclass cls, jlong ucp_worker_ptr) { return ucp_worker_progress((ucp_worker_h)ucp_worker_ptr); } JNIEXPORT jobject JNICALL Java_org_openucx_jucx_ucp_UcpWorker_flushNonBlockingNative(JNIEnv *env, jclass cls, jlong ucp_worker_ptr, jobject callback) { ucs_status_ptr_t request = ucp_worker_flush_nb((ucp_worker_h)ucp_worker_ptr, 0, jucx_request_callback); return process_request(request, callback); } JNIEXPORT void JNICALL Java_org_openucx_jucx_ucp_UcpWorker_waitWorkerNative(JNIEnv *env, jclass cls, jlong ucp_worker_ptr) { ucs_status_t status = ucp_worker_wait((ucp_worker_h)ucp_worker_ptr); if (status != UCS_OK) { JNU_ThrowExceptionByStatus(env, status); } } JNIEXPORT void JNICALL Java_org_openucx_jucx_ucp_UcpWorker_signalWorkerNative(JNIEnv *env, jclass cls, jlong ucp_worker_ptr) { ucs_status_t status = ucp_worker_signal((ucp_worker_h)ucp_worker_ptr); if (status != UCS_OK) { JNU_ThrowExceptionByStatus(env, status); } } JNIEXPORT jobject JNICALL Java_org_openucx_jucx_ucp_UcpWorker_recvTaggedNonBlockingNative(JNIEnv *env, jclass cls, jlong ucp_worker_ptr, jlong laddr, jlong size, jlong tag, jlong tagMask, jobject callback) { ucs_status_ptr_t request = ucp_tag_recv_nb((ucp_worker_h)ucp_worker_ptr, (void *)laddr, size, ucp_dt_make_contig(1), tag, tagMask, recv_callback); ucs_trace_req("JUCX: recv_nb request %p, msg size: %zu, tag: %ld", request, size, tag); return process_request(request, callback); } JNIEXPORT void JNICALL Java_org_openucx_jucx_ucp_UcpWorker_cancelRequestNative(JNIEnv *env, jclass cls, jlong ucp_worker_ptr, jlong ucp_request_ptr) { ucp_request_cancel((ucp_worker_h)ucp_worker_ptr, (void *)ucp_request_ptr); }