/**
* Copyright (C) Mellanox Technologies Ltd. 2001-2016. ALL RIGHTS RESERVED.
* Copyright (C) UT-Battelle, LLC. 2015. ALL RIGHTS RESERVED.
*
* See file LICENSE for terms.
*/
#include "ucp_test.h"
#include <common/test_helpers.h>
#if _OPENMP
#include "omp.h"
#endif
using namespace ucs; /* For vector<char> serialization */
class test_ucp_rma_mt : public ucp_test {
public:
static ucp_params_t get_ctx_params() {
ucp_params_t params = ucp_test::get_ctx_params();
params.features = UCP_FEATURE_RMA;
return params;
}
void init()
{
ucp_test::init();
sender().connect(&receiver(), get_ep_params());
for (int i = 0; i < sender().get_num_workers(); i++) {
/* avoid deadlock for blocking rma */
flush_worker(sender(), i);
}
}
static void send_cb(void *req, ucs_status_t status)
{
}
static std::vector<ucp_test_param> enum_test_params(const ucp_params_t& ctx_params,
const std::string& name,
const std::string& test_case_name,
const std::string& tls)
{
std::vector<ucp_test_param> result;
generate_test_params_variant(ctx_params, name, test_case_name, tls, 0,
result, MULTI_THREAD_CONTEXT);
generate_test_params_variant(ctx_params, name, test_case_name, tls, 0,
result, MULTI_THREAD_WORKER);
return result;
}
};
UCS_TEST_P(test_ucp_rma_mt, put_get) {
ucs_status_t st;
uint64_t orig_data[MT_TEST_NUM_THREADS] GTEST_ATTRIBUTE_UNUSED_;
uint64_t target_data[MT_TEST_NUM_THREADS] GTEST_ATTRIBUTE_UNUSED_;
ucp_mem_map_params_t params;
ucp_mem_h memh;
void *memheap = target_data;
params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
UCP_MEM_MAP_PARAM_FIELD_LENGTH |
UCP_MEM_MAP_PARAM_FIELD_FLAGS;
params.address = memheap;
params.length = sizeof(uint64_t) * MT_TEST_NUM_THREADS;
params.flags = GetParam().variant;
st = ucp_mem_map(receiver().ucph(), ¶ms, &memh);
ASSERT_UCS_OK(st);
void *rkey_buffer;
size_t rkey_buffer_size;
st = ucp_rkey_pack(receiver().ucph(), memh, &rkey_buffer, &rkey_buffer_size);
ASSERT_UCS_OK(st);
std::vector<ucp_rkey_h> rkey;
rkey.resize(MT_TEST_NUM_THREADS);
/* test parallel rkey unpack */
#if _OPENMP && ENABLE_MT
#pragma omp parallel for
for (int i = 0; i < MT_TEST_NUM_THREADS; i++) {
int worker_index = 0;
if (GetParam().thread_type == MULTI_THREAD_CONTEXT) {
worker_index = i;
}
ucs_status_t status = ucp_ep_rkey_unpack(sender().ep(worker_index),
rkey_buffer, &rkey[i]);
ASSERT_UCS_OK(status);
}
#endif
ucp_rkey_buffer_release(rkey_buffer);
/* test blocking PUT */
for (int i = 0; i < MT_TEST_NUM_THREADS; i++) {
orig_data[i] = 0xdeadbeefdeadbeef + 10 * i;
target_data[i] = 0;
}
#if _OPENMP && ENABLE_MT
#pragma omp parallel for
for (int i = 0; i < MT_TEST_NUM_THREADS; i++) {
int worker_index = 0;
if (GetParam().thread_type == MULTI_THREAD_CONTEXT) {
worker_index = i;
}
void* req = ucp_put_nb(sender().ep(worker_index), &orig_data[i],
sizeof(uint64_t), (uintptr_t)((uint64_t*)memheap + i),
rkey[i], send_cb);
wait(req, worker_index);
flush_worker(sender(), worker_index);
EXPECT_EQ(orig_data[i], target_data[i]);
}
#endif
/* test nonblocking PUT */
for (int i = 0; i < MT_TEST_NUM_THREADS; i++) {
orig_data[i] = 0xdeadbeefdeadbeef + 10 * i;
target_data[i] = 0;
}
#if _OPENMP && ENABLE_MT
#pragma omp parallel for
for (int i = 0; i < MT_TEST_NUM_THREADS; i++) {
ucs_status_t status;
int worker_index = 0;
if (GetParam().thread_type == MULTI_THREAD_CONTEXT)
worker_index = i;
status = ucp_put_nbi(sender().ep(worker_index), &orig_data[i], sizeof(uint64_t),
(uintptr_t)((uint64_t*)memheap + i), rkey[i]);
ASSERT_UCS_OK_OR_INPROGRESS(status);
flush_worker(sender(), worker_index);
EXPECT_EQ(orig_data[i], target_data[i]);
}
#endif
/* test blocking GET */
for (int i = 0; i < MT_TEST_NUM_THREADS; i++) {
orig_data[i] = 0;
target_data[i] = 0xdeadbeefdeadbeef + 10 * i;
}
#if _OPENMP && ENABLE_MT
#pragma omp parallel for
for (int i = 0; i < MT_TEST_NUM_THREADS; i++) {
int worker_index = 0;
if (GetParam().thread_type == MULTI_THREAD_CONTEXT) {
worker_index = i;
}
void *req = ucp_get_nb(sender().ep(worker_index), &orig_data[i],
sizeof(uint64_t), (uintptr_t)((uint64_t*)memheap + i),
rkey[i], send_cb);
wait(req, worker_index);
flush_worker(sender(), worker_index);
EXPECT_EQ(orig_data[i], target_data[i]);
}
#endif
/* test nonblocking GET */
for (int i = 0; i < MT_TEST_NUM_THREADS; i++) {
orig_data[i] = 0;
target_data[i] = 0xdeadbeefdeadbeef + 10 * i;
}
#if _OPENMP && ENABLE_MT
#pragma omp parallel for
for (int i = 0; i < MT_TEST_NUM_THREADS; i++) {
ucs_status_t status;
int worker_index = 0;
if (GetParam().thread_type == MULTI_THREAD_CONTEXT)
worker_index = i;
status = ucp_get_nbi(sender().ep(worker_index), &orig_data[i], sizeof(uint64_t),
(uintptr_t)((uint64_t *)memheap + i), rkey[i]);
ASSERT_UCS_OK_OR_INPROGRESS(status);
flush_worker(sender(), worker_index);
EXPECT_EQ(orig_data[i], target_data[i]);
}
#endif
#if _OPENMP && ENABLE_MT
#pragma omp parallel for
for (int i = 0; i < MT_TEST_NUM_THREADS; i++) {
ucp_rkey_destroy(rkey[i]);
}
#endif
st = ucp_mem_unmap(receiver().ucph(), memh);
ASSERT_UCS_OK(st);
}
UCP_INSTANTIATE_TEST_CASE(test_ucp_rma_mt)