/**
* Copyright (C) Mellanox Technologies Ltd. 2017. ALL RIGHTS RESERVED.
*
* See file LICENSE for terms.
*/
#include <list>
#include <numeric>
#include <set>
#include <vector>
#include "ucp_datatype.h"
#include "ucp_test.h"
/**
 * Common fixture for the UCP stream tests: requests the stream feature and
 * provides helpers to post sends and to wait for stream receive requests.
 */
class test_ucp_stream_base : public ucp_test {
public:
    /* Context parameters with the stream feature; note the feature set is
     * assigned, not OR-ed, so only UCP_FEATURE_STREAM is requested. */
    static ucp_params_t get_ctx_params() {
        ucp_params_t ctx_params = ucp_test::get_ctx_params();
        ctx_params.features    = UCP_FEATURE_STREAM;
        ctx_params.field_mask |= UCP_PARAM_FIELD_FEATURES;
        return ctx_params;
    }

    /* Completion callbacks do nothing: the tests poll requests instead */
    static void ucp_send_cb(void *request, ucs_status_t status) {
    }

    static void ucp_recv_cb(void *request, ucs_status_t status, size_t length) {
    }

    size_t wait_stream_recv(void *request);

protected:
    ucs_status_ptr_t stream_send_nb(const ucp::data_type_desc_t& dt_desc);
};
/**
 * Progress the test until a stream receive request is no longer in progress,
 * release it, and return the number of bytes it received.
 *
 * @param request  Request handle returned by ucp_stream_recv_nb().
 * @return Received length in bytes; the test is aborted if the request
 *         completed with an error.
 */
size_t test_ucp_stream_base::wait_stream_recv(void *request)
{
    ucs_status_t status;
    size_t       length = 0;

    do {
        progress();
        status = ucp_stream_recv_request_test(request, &length);
    } while (status == UCS_INPROGRESS);

    /* Release the request before checking the status, so a failed receive
     * does not leak the request when the assertion aborts the test */
    ucp_request_free(request);
    ASSERT_UCS_OK(status);
    return length;
}
/* Post a non-blocking stream send of the buffer described by dt_desc on the
 * sender's endpoint and return whatever ucp_stream_send_nb() returned. */
ucs_status_ptr_t
test_ucp_stream_base::stream_send_nb(const ucp::data_type_desc_t& dt_desc)
{
    ucp_ep_h ep = sender().ep();
    return ucp_stream_send_nb(ep, dt_desc.buf(), dt_desc.count(),
                              dt_desc.dt(), ucp_send_cb, 0);
}
/* Fixture for tests where the endpoints are created explicitly by the test
 * body; every endpoint gets UCP_EP_PARAMS_FLAGS_NO_LOOPBACK. */
class test_ucp_stream_onesided : public test_ucp_stream_base {
public:
    ucp_ep_params_t get_ep_params() {
        ucp_ep_params_t ep_params = test_ucp_stream_base::get_ep_params();
        ep_params.flags      |= UCP_EP_PARAMS_FLAGS_NO_LOOPBACK;
        ep_params.field_mask |= UCP_EP_PARAM_FIELD_FLAGS;
        return ep_params;
    }
};
/*
 * Connect the sender only and send a single 64-bit word. Until the receiver
 * creates its own endpoint to the sender, ucp_stream_worker_poll() must not
 * report any ready endpoint. Once it is created (with random user_data), that
 * endpoint must be reported with the same user_data and the word must arrive.
 */
UCS_TEST_P(test_ucp_stream_onesided, send_recv_no_ep) {
    /* connect from sender side only and send */
    sender().connect(&receiver(), get_ep_params());

    uint64_t send_data = ucs::rand();
    ucp::data_type_desc_t dt_desc(ucp_dt_make_contig(sizeof(uint64_t)),
                                  &send_data, sizeof(send_data));
    void *sreq = stream_send_nb(dt_desc);
    wait(sreq);

    /* must not receive data before ep is created on receiver side */
    static const size_t max_eps = 10;
    ucp_stream_poll_ep_t poll_eps[max_eps];
    ssize_t count = ucp_stream_worker_poll(receiver().worker(), poll_eps,
                                           max_eps, 0);
    EXPECT_EQ(0l, count) << "ucp_stream_worker_poll returned ep too early";

    /* create receiver side ep */
    ucp_ep_params_t recv_ep_param = get_ep_params();
    recv_ep_param.field_mask |= UCP_EP_PARAM_FIELD_USER_DATA;
    recv_ep_param.user_data   = reinterpret_cast<void*>(
                                    static_cast<uintptr_t>(ucs::rand()));
    receiver().connect(&sender(), recv_ep_param);

    /* expect ep to be ready; poll for up to 10 seconds scaled by the test
     * time multiplier */
    ucs_time_t deadline = ucs_get_time() +
                          (ucs_time_from_sec(10.0) *
                           ucs::test_time_multiplier());
    do {
        progress();
        count = ucp_stream_worker_poll(receiver().worker(), poll_eps, max_eps,
                                       0);
    } while ((count == 0) && (ucs_get_time() < deadline));

    /* Fatal check: otherwise poll_eps[0] would be read below even when the
     * poll never filled it in */
    ASSERT_EQ(1l, count);
    EXPECT_EQ(recv_ep_param.user_data, poll_eps[0].user_data);
    EXPECT_EQ(receiver().ep(0), poll_eps[0].ep);

    /* expect data to be received */
    uint64_t recv_data   = 0;
    size_t   recv_length = 0;
    void *rreq = ucp_stream_recv_nb(receiver().ep(), &recv_data, 1,
                                    ucp_dt_make_contig(sizeof(uint64_t)),
                                    ucp_recv_cb, &recv_length, 0);
    ASSERT_UCS_PTR_OK(rreq);
    if (rreq != NULL) {
        recv_length = wait_stream_recv(rreq);
    }

    EXPECT_EQ(sizeof(uint64_t), recv_length);
    EXPECT_EQ(send_data, recv_data);
}

UCP_INSTANTIATE_TEST_CASE(test_ucp_stream_onesided)
/**
 * Two-sided stream fixture. init() connects the sender to the receiver and,
 * unless is_loopback() is true, also connects the receiver back to the sender.
 */
class test_ucp_stream : public test_ucp_stream_base
{
public:
    virtual void init() {
        ucp_test::init();

        sender().connect(&receiver(), get_ep_params());
        if (is_loopback()) {
            return;
        }

        receiver().connect(&sender(), get_ep_params());
    }

protected:
    void do_send_recv_data_test(ucp_datatype_t datatype);

    template <typename T, unsigned recv_flags>
    void do_send_recv_test(ucp_datatype_t datatype);

    template <typename T, unsigned recv_flags>
    void do_send_exp_recv_test(ucp_datatype_t datatype);

    void do_send_recv_data_recv_test(ucp_datatype_t datatype);

    /* Expected byte stream used for self-validation of the generic datatype.
     * NOTE: the generic datatype is only exercised with byte-array data,
     * since bytes are the receive completion granularity when
     * UCP_STREAM_RECV_FLAG_WAITALL is not given. */
    std::vector<uint8_t> context;
};
/**
 * Send messages of growing size (starting at 3 bytes, multiplied by
 * 2 * test_time_multiplier() each round while smaller than sbuf) with
 * 'datatype', then drain the receiver with ucp_stream_recv_data_nb() and
 * compare the concatenated stream against the expected pattern.
 *
 * For a generic datatype the expected pattern is bytes 0,1,2,... per message
 * (assumed to match what ucp::test_dt_uint8_ops packs — verify there);
 * otherwise it is the random content written into sbuf.
 */
void test_ucp_stream::do_send_recv_data_test(ucp_datatype_t datatype)
{
    size_t ssize = 0; /* total send size in bytes */
    std::vector<char> sbuf(16 * UCS_MBYTE, 's');
    std::vector<char> check_pattern;
    ucs_status_ptr_t sstatus;

    /* send all msg sizes*/
    for (size_t i = 3; i < sbuf.size();
         i *= (2 * ucs::test_time_multiplier())) {
        if (UCP_DT_IS_GENERIC(datatype)) {
            for (size_t j = 0; j < i; ++j) {
                check_pattern.push_back(char(j));
            }
        } else {
            ucs::fill_random(sbuf, i);
            check_pattern.insert(check_pattern.end(), sbuf.begin(),
                                 sbuf.begin() + i);
        }
        ucp::data_type_desc_t dt_desc(datatype, sbuf.data(), i);
        sstatus = stream_send_nb(dt_desc);
        EXPECT_FALSE(UCS_PTR_IS_ERR(sstatus));
        wait(sstatus);
        ssize += i;
    }

    std::vector<char> rbuf(ssize, 'r');
    size_t roffset = 0;
    ucs_status_ptr_t rdata;
    size_t length;
    do {
        progress();
        rdata = ucp_stream_recv_data_nb(receiver().ep(), &length);
        if (rdata == NULL) {
            continue;
        }

        /* An error pointer is non-NULL and must not be dereferenced; more
         * data than was sent must fail the test instead of overflowing rbuf */
        ASSERT_FALSE(UCS_PTR_IS_ERR(rdata));
        ASSERT_LE(roffset + length, rbuf.size());

        memcpy(&rbuf[roffset], rdata, length);
        roffset += length;
        ucp_stream_data_release(receiver().ep(), rdata);
    } while (roffset < ssize);

    EXPECT_EQ(roffset, ssize);
    EXPECT_EQ(check_pattern, rbuf);
}
/**
 * Send a byte stream of growing message sizes (3 bytes, doubling while
 * smaller than the send buffer), pad it to a multiple of the receive element
 * size, then receive it with 'datatype' into an array of T using recv_flags.
 *
 * Non-generic datatypes are always sent as DATATYPE with random content,
 * which is then compared with the received array. For a generic datatype the
 * expected bytes are appended to 'context', which the generic datatype was
 * created with (presumably validated by its unpack callback — verify in
 * ucp::test_dt_uint8_ops).
 *
 * @tparam T           Receive element type.
 * @tparam recv_flags  Flags passed to ucp_stream_recv_nb(); with
 *                     UCP_STREAM_RECV_FLAG_WAITALL a single receive must
 *                     complete the whole stream.
 */
template <typename T, unsigned recv_flags>
void test_ucp_stream::do_send_recv_test(ucp_datatype_t datatype)
{
    const size_t dt_elem_size = UCP_DT_IS_CONTIG(datatype) ?
                                ucp_contig_dt_elem_size(datatype) : 1;
    size_t ssize = 0; /* total send size */
    std::vector<char> sbuf(16 * UCS_MBYTE, 's');
    ucs_status_ptr_t sstatus;
    std::vector<char> check_pattern;

    /* send all msg sizes in bytes*/
    for (size_t i = 3; i < sbuf.size(); i *= 2) {
        ucp_datatype_t dt;
        if (UCP_DT_IS_GENERIC(datatype)) {
            dt = datatype;
            for (size_t j = 0; j < i; ++j) {
                context.push_back(uint8_t(j));
            }
        } else {
            dt = DATATYPE;
            ucs::fill_random(sbuf, i);
            check_pattern.insert(check_pattern.end(), sbuf.begin(),
                                 sbuf.begin() + i);
        }
        ucp::data_type_desc_t dt_desc(dt, sbuf.data(), i);
        sstatus = stream_send_nb(dt_desc);
        EXPECT_FALSE(UCS_PTR_IS_ERR(sstatus));
        wait(sstatus);
        ssize += i;
    }

    /* Pad the stream so its total length is a multiple of the receive element
     * size; the outer modulo makes the padding 0 when it is already aligned */
    const size_t align_tail = UCP_DT_IS_GENERIC(datatype) ? 0 :
                              ((dt_elem_size - (ssize % dt_elem_size)) %
                               dt_elem_size);
    if (align_tail != 0) {
        ucs::fill_random(sbuf, align_tail);
        check_pattern.insert(check_pattern.end(), sbuf.begin(),
                             sbuf.begin() + align_tail);
        ucp::data_type_desc_t dt_desc(ucp_dt_make_contig(align_tail),
                                      sbuf.data(), align_tail);
        sstatus = stream_send_nb(dt_desc);
        EXPECT_FALSE(UCS_PTR_IS_ERR(sstatus));
        wait(sstatus);
        ssize += align_tail;
    }

    EXPECT_EQ(size_t(0), (ssize % dt_elem_size));

    std::vector<T> rbuf(ssize / dt_elem_size, 'r');
    size_t roffset = 0;
    size_t counter = 0;
    do {
        ucp::data_type_desc_t dt_desc(datatype, &rbuf[roffset / dt_elem_size],
                                      ssize - roffset);
        size_t length;
        void *rreq = ucp_stream_recv_nb(receiver().ep(), dt_desc.buf(),
                                        dt_desc.count(), dt_desc.dt(),
                                        ucp_recv_cb, &length, recv_flags);
        ASSERT_TRUE(!UCS_PTR_IS_ERR(rreq));
        if (UCS_PTR_IS_PTR(rreq)) {
            length = wait_stream_recv(rreq);
        }
        EXPECT_EQ(size_t(0), length % dt_elem_size);
        roffset += length;
        counter++;
    } while (roffset < ssize);

    /* waitall flag requires completion by single request */
    if (recv_flags & UCP_STREAM_RECV_FLAG_WAITALL) {
        EXPECT_EQ(size_t(1), counter);
    }

    EXPECT_EQ(roffset, ssize);
    if (!UCP_DT_IS_GENERIC(datatype)) {
        const T *check_ptr = reinterpret_cast<const T *>(check_pattern.data());
        const size_t check_size = check_pattern.size() / dt_elem_size;
        EXPECT_EQ(std::vector<T>(check_ptr, check_ptr + check_size), rbuf);
    }
}
/**
 * "Expected" receive test: post n_msgs receives of msg_size bytes each
 * before anything is sent, then send n_msgs messages of the same size and
 * wait for the posted receives. Any data the posted receives did not consume
 * is drained with further receives. With UCP_STREAM_RECV_FLAG_WAITALL the
 * posted receives must have consumed everything.
 *
 * @tparam T           Receive element type.
 * @tparam recv_flags  Flags for the pre-posted ucp_stream_recv_nb() calls.
 */
template <typename T, unsigned recv_flags>
void test_ucp_stream::do_send_exp_recv_test(ucp_datatype_t datatype)
{
    const size_t dt_elem_size = UCP_DT_IS_CONTIG(datatype) ?
                                ucp_contig_dt_elem_size(datatype) : 1;
    const size_t msg_size     = dt_elem_size * UCS_MBYTE;
    const size_t n_msgs       = 10;

    std::vector<std::vector<T> > rbufs(n_msgs,
                                       std::vector<T>(msg_size / dt_elem_size,
                                                      'r'));
    std::vector<ucp::data_type_desc_t> dt_rdescs(n_msgs);
    std::vector<void *> rreqs;

    /* post recvs: nothing was sent yet, so each must be a pending request
     * (NULL or an error pointer would crash wait_stream_recv() below) */
    for (size_t i = 0; i < n_msgs; ++i) {
        ucp::data_type_desc_t &rdesc = dt_rdescs[i].make(datatype, &rbufs[i][0],
                                                         msg_size);
        size_t length;
        void *rreq = ucp_stream_recv_nb(receiver().ep(), rdesc.buf(),
                                        rdesc.count(), rdesc.dt(), ucp_recv_cb,
                                        &length, recv_flags);
        ASSERT_TRUE(UCS_PTR_IS_PTR(rreq));
        rreqs.push_back(rreq);
    }

    std::vector<char> sbuf(msg_size, 's');
    size_t scount = 0; /* total send size */
    ucp::data_type_desc_t dt_desc(datatype, sbuf.data(), sbuf.size());

    /* send all msgs */
    for (size_t i = 0; i < n_msgs; ++i) {
        void *sreq = stream_send_nb(dt_desc);
        EXPECT_FALSE(UCS_PTR_IS_ERR(sreq));
        wait(sreq);
        scount += sbuf.size();
    }

    size_t rcount = 0;
    for (size_t i = 0; i < rreqs.size(); ++i) {
        size_t length = wait_stream_recv(rreqs[i]);
        EXPECT_EQ(size_t(0), length % dt_elem_size);
        rcount += length;
    }

    /* drain whatever the posted receives did not consume */
    size_t counter = 0;
    while (rcount < scount) {
        size_t length = std::numeric_limits<size_t>::max();
        ucs_status_ptr_t rreq;
        rreq = ucp_stream_recv_nb(receiver().ep(), dt_rdescs[0].buf(),
                                  dt_rdescs[0].count(), dt_rdescs[0].dt(),
                                  ucp_recv_cb, &length, 0);
        ASSERT_FALSE(UCS_PTR_IS_ERR(rreq));
        if (UCS_PTR_IS_PTR(rreq)) {
            length = wait_stream_recv(rreq);
        }
        ASSERT_GT(length, 0ul);
        ASSERT_LE(length, msg_size);
        EXPECT_EQ(size_t(0), length % dt_elem_size);
        rcount += length;
        counter++;
    }

    EXPECT_EQ(scount, rcount);

    /* waitall flag requires completion by single request */
    if (recv_flags & UCP_STREAM_RECV_FLAG_WAITALL) {
        EXPECT_EQ(size_t(0), counter);
    }

    /* double check, no data should be here */
    while (progress());

    size_t s;
    void *p;
    while ((p = ucp_stream_recv_data_nb(receiver().ep(), &s)) != NULL) {
        rcount += s;
        ucp_stream_data_release(receiver().ep(), p);
        progress();
    }
    EXPECT_EQ(scount, rcount);
}
/**
 * Interleave sends and receives. Each loop iteration sends the next message
 * (sizes start at one element and double while smaller than sbuf). It then
 * either takes data with the zero-copy ucp_stream_recv_data_nb() (odd
 * iterations, or when less than one element is outstanding) or receives it
 * with ucp_stream_recv_nb() directly into rbuf. The received stream must
 * match what was sent.
 */
void test_ucp_stream::do_send_recv_data_recv_test(ucp_datatype_t datatype)
{
    const size_t dt_elem_size = UCP_DT_IS_CONTIG(datatype) ?
                                ucp_contig_dt_elem_size(datatype) : 1;
    size_t ssize   = 0; /* total send size */
    size_t roffset = 0;
    size_t send_i  = dt_elem_size;
    size_t recv_i  = 0;
    std::vector<char> sbuf(16 * UCS_MBYTE, 's');
    ucs_status_ptr_t sstatus;
    std::vector<char> check_pattern;
    std::vector<char> rbuf;
    ucs_status_ptr_t rdata;
    size_t length;

    do {
        if (send_i < sbuf.size()) {
            rbuf.resize(rbuf.size() + send_i, 'r');
            ucs::fill_random(sbuf, send_i);
            check_pattern.insert(check_pattern.end(), sbuf.begin(),
                                 sbuf.begin() + send_i);
            ucp::data_type_desc_t dt_desc(datatype, sbuf.data(), send_i);
            sstatus = stream_send_nb(dt_desc);
            EXPECT_FALSE(UCS_PTR_IS_ERR(sstatus));
            wait(sstatus);
            ssize  += send_i;
            send_i *= 2;
        }

        progress();

        if ((++recv_i % 2) || ((ssize - roffset) < dt_elem_size)) {
            rdata = ucp_stream_recv_data_nb(receiver().ep(), &length);
            if (rdata == NULL) {
                continue;
            }
            /* An error pointer is non-NULL and must not be dereferenced;
             * more data than was sent must not overflow rbuf */
            ASSERT_FALSE(UCS_PTR_IS_ERR(rdata));
            ASSERT_LE(roffset + length, rbuf.size());
            memcpy(&rbuf[roffset], rdata, length);
            ucp_stream_data_release(receiver().ep(), rdata);
        } else {
            ucp::data_type_desc_t dt_desc(datatype, &rbuf[roffset],
                                          ssize - roffset);
            void *rreq = ucp_stream_recv_nb(receiver().ep(), dt_desc.buf(),
                                            dt_desc.count(), dt_desc.dt(),
                                            ucp_recv_cb, &length, 0);
            ASSERT_TRUE(!UCS_PTR_IS_ERR(rreq));
            if (UCS_PTR_IS_PTR(rreq)) {
                length = wait_stream_recv(rreq);
            }
        }

        roffset += length;
    } while (roffset < ssize);

    EXPECT_EQ(roffset, ssize);
    EXPECT_EQ(check_pattern, rbuf);
}
/* Send with DATATYPE, DATATYPE_IOV and a generic datatype; receive with
 * ucp_stream_recv_data_nb() */
UCS_TEST_P(test_ucp_stream, send_recv_data) {
    const ucp_datatype_t send_dt = DATATYPE;
    do_send_recv_data_test(send_dt);
}

UCS_TEST_P(test_ucp_stream, send_iov_recv_data) {
    const ucp_datatype_t send_dt = DATATYPE_IOV;
    do_send_recv_data_test(send_dt);
}

UCS_TEST_P(test_ucp_stream, send_generic_recv_data) {
    ucp_datatype_t gen_dt;
    const ucs_status_t status = ucp_dt_create_generic(&ucp::test_dt_uint8_ops,
                                                      NULL, &gen_dt);
    ASSERT_UCS_OK(status);

    do_send_recv_data_test(gen_dt);
    ucp_dt_destroy(gen_dt);
}
/* Receive the byte stream as arrays of 8/16/32/64-bit contiguous elements and
 * as IOV, with and without UCP_STREAM_RECV_FLAG_WAITALL; the generic datatype
 * is exercised with WAITALL only */
UCS_TEST_P(test_ucp_stream, send_recv_8) {
    const ucp_datatype_t dt8 = ucp_dt_make_contig(sizeof(uint8_t));
    do_send_recv_test<uint8_t, 0>(dt8);
    do_send_recv_test<uint8_t, UCP_STREAM_RECV_FLAG_WAITALL>(dt8);
}

UCS_TEST_P(test_ucp_stream, send_recv_16) {
    const ucp_datatype_t dt16 = ucp_dt_make_contig(sizeof(uint16_t));
    do_send_recv_test<uint16_t, 0>(dt16);
    do_send_recv_test<uint16_t, UCP_STREAM_RECV_FLAG_WAITALL>(dt16);
}

UCS_TEST_P(test_ucp_stream, send_recv_32) {
    const ucp_datatype_t dt32 = ucp_dt_make_contig(sizeof(uint32_t));
    do_send_recv_test<uint32_t, 0>(dt32);
    do_send_recv_test<uint32_t, UCP_STREAM_RECV_FLAG_WAITALL>(dt32);
}

UCS_TEST_P(test_ucp_stream, send_recv_64) {
    const ucp_datatype_t dt64 = ucp_dt_make_contig(sizeof(uint64_t));
    do_send_recv_test<uint64_t, 0>(dt64);
    do_send_recv_test<uint64_t, UCP_STREAM_RECV_FLAG_WAITALL>(dt64);
}

UCS_TEST_P(test_ucp_stream, send_recv_iov) {
    const ucp_datatype_t iov_dt = DATATYPE_IOV;
    do_send_recv_test<uint8_t, 0>(iov_dt);
    do_send_recv_test<uint8_t, UCP_STREAM_RECV_FLAG_WAITALL>(iov_dt);
}

UCS_TEST_P(test_ucp_stream, send_recv_generic) {
    ucp_datatype_t gen_dt;
    const ucs_status_t status = ucp_dt_create_generic(&ucp::test_dt_uint8_ops,
                                                      &context, &gen_dt);
    ASSERT_UCS_OK(status);

    do_send_recv_test<uint8_t, UCP_STREAM_RECV_FLAG_WAITALL>(gen_dt);
    ucp_dt_destroy(gen_dt);
}
/* Pre-posted ("expected") receives for 8/16/32/64-bit contiguous elements and
 * IOV, with and without UCP_STREAM_RECV_FLAG_WAITALL */
UCS_TEST_P(test_ucp_stream, send_exp_recv_8) {
    const ucp_datatype_t dt8 = ucp_dt_make_contig(sizeof(uint8_t));
    do_send_exp_recv_test<uint8_t, 0>(dt8);
    do_send_exp_recv_test<uint8_t, UCP_STREAM_RECV_FLAG_WAITALL>(dt8);
}

UCS_TEST_P(test_ucp_stream, send_exp_recv_16) {
    const ucp_datatype_t dt16 = ucp_dt_make_contig(sizeof(uint16_t));
    do_send_exp_recv_test<uint16_t, 0>(dt16);
    do_send_exp_recv_test<uint16_t, UCP_STREAM_RECV_FLAG_WAITALL>(dt16);
}

UCS_TEST_P(test_ucp_stream, send_exp_recv_32) {
    const ucp_datatype_t dt32 = ucp_dt_make_contig(sizeof(uint32_t));
    do_send_exp_recv_test<uint32_t, 0>(dt32);
    do_send_exp_recv_test<uint32_t, UCP_STREAM_RECV_FLAG_WAITALL>(dt32);
}

UCS_TEST_P(test_ucp_stream, send_exp_recv_64) {
    const ucp_datatype_t dt64 = ucp_dt_make_contig(sizeof(uint64_t));
    do_send_exp_recv_test<uint64_t, 0>(dt64);
    do_send_exp_recv_test<uint64_t, UCP_STREAM_RECV_FLAG_WAITALL>(dt64);
}

UCS_TEST_P(test_ucp_stream, send_exp_recv_iov) {
    const ucp_datatype_t iov_dt = DATATYPE_IOV;
    do_send_exp_recv_test<uint8_t, 0>(iov_dt);
    do_send_exp_recv_test<uint8_t, UCP_STREAM_RECV_FLAG_WAITALL>(iov_dt);
}
/* Alternate ucp_stream_recv_data_nb() and ucp_stream_recv_nb() receives for
 * 8/16/32/64-bit contiguous elements and IOV */
UCS_TEST_P(test_ucp_stream, send_recv_data_recv_8) {
    const ucp_datatype_t dt8 = ucp_dt_make_contig(sizeof(uint8_t));
    do_send_recv_data_recv_test(dt8);
}

UCS_TEST_P(test_ucp_stream, send_recv_data_recv_16) {
    const ucp_datatype_t dt16 = ucp_dt_make_contig(sizeof(uint16_t));
    do_send_recv_data_recv_test(dt16);
}

UCS_TEST_P(test_ucp_stream, send_recv_data_recv_32) {
    const ucp_datatype_t dt32 = ucp_dt_make_contig(sizeof(uint32_t));
    do_send_recv_data_recv_test(dt32);
}

UCS_TEST_P(test_ucp_stream, send_recv_data_recv_64) {
    const ucp_datatype_t dt64 = ucp_dt_make_contig(sizeof(uint64_t));
    do_send_recv_data_recv_test(dt64);
}

UCS_TEST_P(test_ucp_stream, send_recv_data_recv_iov) {
    const ucp_datatype_t iov_dt = DATATYPE_IOV;
    do_send_recv_data_recv_test(iov_dt);
}
/*
 * Send IOV lists where every other entry is empty (NULL/0), including the
 * last one, for every total size in [min_size, max_size). Drain each send with
 * ucp_stream_recv_data_nb() until all non-empty bytes have arrived.
 */
UCS_TEST_P(test_ucp_stream, send_zero_ending_iov_recv_data) {
    const size_t min_size         = UCS_KBYTE;
    const size_t max_size         = min_size * 64;
    const size_t iov_num          = 8; /* must be divisible by 4 without a
                                        * remainder, caught on mlx5 based TLs
                                        * where max_iov = 3 for zcopy multi
                                        * protocol, where every posting includes:
                                        * 1 header + 2 nonempty IOVs */
    const size_t iov_num_nonempty = iov_num / 2;

    std::vector<uint8_t> buf(max_size * 2);
    ucs::fill_random(buf, buf.size());

    std::vector<ucp_dt_iov_t> v(iov_num);
    for (size_t size = min_size; size < max_size; ++size) {
        size_t slen = 0;
        for (size_t j = 0; j < iov_num; ++j) {
            if ((j % 2) == 0) {
                uint8_t *ptr = buf.data();
                v[j].buffer  = &(ptr[j * size / iov_num_nonempty]);
                v[j].length  = size / iov_num_nonempty;
                slen        += v[j].length;
            } else {
                v[j].buffer = NULL;
                v[j].length = 0;
            }
        }

        void *sreq = ucp_stream_send_nb(sender().ep(), &v[0], iov_num,
                                        DATATYPE_IOV, ucp_send_cb, 0);
        /* a failed send would make the receive loop below spin forever */
        ASSERT_FALSE(UCS_PTR_IS_ERR(sreq));

        size_t rlen = 0;
        while (rlen < slen) {
            progress();
            size_t length;
            void *rdata = ucp_stream_recv_data_nb(receiver().ep(), &length);
            /* an error pointer is non-NULL: stop rather than release it */
            ASSERT_FALSE(UCS_PTR_IS_ERR(rdata));
            if (rdata != NULL) {
                rlen += length;
                ucp_stream_data_release(receiver().ep(), rdata);
            }
        }
        wait(sreq);
    }
}

UCP_INSTANTIATE_TEST_CASE(test_ucp_stream)
/**
 * Many-to-one stream fixture. There are m_nsenders sender entities (indices
 * 0..m_nsenders-1) and a receiver entity at m_receiver_idx; the constructor
 * sets both counts to 3. m_msgs[i] is sender i's message, and m_recv_data[i]
 * collects what the receiver got from sender i.
 */
class test_ucp_stream_many2one : public test_ucp_stream_base {
protected:
    /* A pending request together with the datatype descriptor it uses */
    struct request_wrapper_t {
        request_wrapper_t(void *request, ucp::data_type_desc_t *dt_desc)
            : m_req(request), m_dt_desc(dt_desc) {
        }

        void                  *m_req;
        ucp::data_type_desc_t *m_dt_desc;
    };

public:
    test_ucp_stream_many2one() : m_receiver_idx(3), m_nsenders(3) {
        m_recv_data.resize(m_nsenders);
    }

    static ucp_params_t get_ctx_params() {
        return test_ucp_stream_base::get_ctx_params();
    }

    virtual void init();

    /* Completion callbacks do nothing: requests are polled instead */
    static void ucp_send_cb(void *request, ucs_status_t status) {
    }

    static void ucp_recv_cb(void *request, ucs_status_t status,
                            size_t length) {
    }

    void do_send_worker_poll_test(ucp_datatype_t dt);
    void do_send_recv_test(ucp_datatype_t dt);

protected:
    static void erase_completed_reqs(std::vector<request_wrapper_t> &reqs);
    ucs_status_ptr_t stream_send_nb(size_t sender_idx,
                                    const ucp::data_type_desc_t& dt_desc);
    size_t send_all_nb(ucp_datatype_t datatype, size_t n_iter,
                       std::vector<request_wrapper_t> &sreqs);
    size_t send_all(ucp_datatype_t datatype, size_t n_iter);
    void check_no_data();
    std::set<ucp_ep_h> check_no_data(entity &e);
    void check_recv_data(size_t n_iter, ucp_datatype_t dt);

    std::vector<std::string>        m_msgs;
    std::vector<std::vector<char> > m_recv_data;
    const size_t                    m_receiver_idx;
    const size_t                    m_nsenders;
};
void test_ucp_stream_many2one::init()
{
if (is_self()) {
UCS_TEST_SKIP_R("self");
}
/* Skip entities creation */
test_base::init();
for (size_t i = 0; i < m_nsenders + 1; ++i) {
create_entity();
}
for (size_t i = 0; i < m_nsenders; ++i) {
e(i).connect(&e(m_receiver_idx), get_ep_params(), i);
ucp_ep_params_t recv_ep_param = get_ep_params();
recv_ep_param.field_mask |= UCP_EP_PARAM_FIELD_USER_DATA;
recv_ep_param.user_data = (void *)uintptr_t(i);
e(m_receiver_idx).connect(&e(i), recv_ep_param, i);
}
for (size_t i = 0; i < m_nsenders; ++i) {
m_msgs.push_back(std::string("sender_") + ucs::to_string(i));
}
}
/* Send niter rounds of every sender's message via send_all_nb(). Then drain
 * the receiver with ucp_stream_worker_poll() + ucp_stream_recv_data_nb(),
 * appending each chunk to m_recv_data[user_data of the ready endpoint]. Stop
 * once all sends have completed and all sent bytes have been received, then
 * verify nothing is left and the content matches. */
void test_ucp_stream_many2one::do_send_worker_poll_test(ucp_datatype_t dt)
{
    const size_t niter = 2018;
    std::vector<request_wrapper_t> sreqs;
    size_t bytes_left = send_all_nb(dt, niter, sreqs);

    /* Recv and progress all data */
    do {
        ssize_t nready;
        do {
            const size_t max_eps = 10;
            ucp_stream_poll_ep_t ready_eps[max_eps];

            progress();
            nready = ucp_stream_worker_poll(e(m_receiver_idx).worker(),
                                            ready_eps, max_eps, 0);
            EXPECT_LE(0, nready);

            for (ssize_t k = 0; k < nready; ++k) {
                ucp_ep_h ep = ready_eps[k].ep;
                const size_t sender_idx = uintptr_t(ready_eps[k].user_data);
                char *chunk;
                size_t chunk_len;

                while ((chunk = static_cast<char *>(
                            ucp_stream_recv_data_nb(ep, &chunk_len))) != NULL) {
                    ASSERT_FALSE(UCS_PTR_IS_ERR(chunk));
                    std::vector<char> &dst = m_recv_data[sender_idx];
                    dst.insert(dst.end(), chunk, chunk + chunk_len);
                    bytes_left -= chunk_len;
                    ucp_stream_data_release(ep, chunk);
                }
            }
        } while (nready > 0);

        erase_completed_reqs(sreqs);
    } while (!sreqs.empty() || (bytes_left != 0));

    check_no_data();
    check_recv_data(niter, dt);
}
/**
 * Many-to-one test with ucp_stream_recv_nb():
 *  1. pre-post one receive per sender covering that sender's whole stream;
 *  2. send niter rounds of all messages;
 *  3. wait for the outstanding receives, then for every endpoint reported by
 *     ucp_stream_worker_poll(), post receives at the current per-sender
 *     offset until the endpoint is drained or a receive is left pending;
 *  4. repeat until all sends and receives completed and every sent byte was
 *     received, then verify.
 */
void test_ucp_stream_many2one::do_send_recv_test(ucp_datatype_t dt)
{
    const size_t niter = 2018;
    std::vector<size_t> roffsets(m_nsenders, 0);
    std::vector<ucp::data_type_desc_t> dt_rdescs(m_nsenders);
    std::vector<std::pair<size_t, request_wrapper_t> > rreqs;
    std::vector<request_wrapper_t> sreqs;
    size_t total_sdata;

    ASSERT_FALSE(m_msgs.empty());

    /* Do preposts: nothing was sent yet, so each receive must be a pending
     * request (NULL or an error pointer would crash wait_stream_recv()) */
    for (size_t i = 0; i < m_nsenders; ++i) {
        m_recv_data[i].resize(m_msgs[i].length() * niter + 1);
        ucp::data_type_desc_t &rdesc = dt_rdescs[i].make(dt,
                                                         &m_recv_data[i][roffsets[i]],
                                                         m_recv_data[i].size());
        size_t length;
        void *rreq = ucp_stream_recv_nb(e(m_receiver_idx).ep(0, i),
                                        rdesc.buf(), rdesc.count(), rdesc.dt(),
                                        ucp_recv_cb, &length, 0);
        ASSERT_TRUE(UCS_PTR_IS_PTR(rreq));
        rreqs.push_back(std::make_pair(i, request_wrapper_t(rreq, &rdesc)));
    }

    total_sdata = send_all_nb(dt, niter, sreqs);

    /* Recv and progress all the rest of data */
    do {
        ssize_t count;

        /* wait rreqs */
        for (size_t i = 0; i < rreqs.size(); ++i) {
            roffsets[rreqs[i].first] += wait_stream_recv(rreqs[i].second.m_req);
        }
        rreqs.clear();
        progress();

        const size_t max_eps = 10;
        ucp_stream_poll_ep_t poll_eps[max_eps];
        count = ucp_stream_worker_poll(e(m_receiver_idx).worker(),
                                       poll_eps, max_eps, 0);
        EXPECT_LE(0, count);
        EXPECT_LE(size_t(count), m_nsenders);

        for (ssize_t i = 0; i < count; ++i) {
            bool again = true;
            while (again) {
                size_t sender_idx = uintptr_t(poll_eps[i].user_data);
                size_t &roffset   = roffsets[sender_idx];
                ucp::data_type_desc_t &dt_desc =
                    dt_rdescs[sender_idx].forward_to(roffset);
                EXPECT_TRUE(dt_desc.is_valid());
                size_t length;
                void *rreq = ucp_stream_recv_nb(poll_eps[i].ep,
                                                dt_desc.buf(),
                                                dt_desc.count(),
                                                dt_desc.dt(),
                                                ucp_recv_cb, &length, 0);
                /* an error pointer is non-NULL and must not be queued for
                 * wait_stream_recv() */
                ASSERT_FALSE(UCS_PTR_IS_ERR(rreq));
                if (rreq == NULL) {
                    EXPECT_LT(size_t(0), length);
                    roffset += length;
                    if (ssize_t(length) < dt_desc.buf_length()) {
                        continue; /* Need to drain the EP */
                    }
                } else {
                    rreqs.push_back(std::make_pair(sender_idx,
                                                   request_wrapper_t(rreq,
                                                                     &dt_desc)));
                }
                again = false;
            }
        }

        erase_completed_reqs(sreqs);
    } while (!rreqs.empty() || !sreqs.empty() ||
             (total_sdata > std::accumulate(roffsets.begin(), roffsets.end(),
                                            size_t(0))));

    /* size_t(0) keeps the accumulator type equal to the element type */
    EXPECT_EQ(total_sdata, std::accumulate(roffsets.begin(), roffsets.end(),
                                           size_t(0)));
    check_no_data();
    check_recv_data(niter, dt);
}
/* Post a non-blocking stream send of dt_desc on the default endpoint of
 * sender entity 'sender_idx'; returns what ucp_stream_send_nb() returned. */
ucs_status_ptr_t
test_ucp_stream_many2one::stream_send_nb(size_t sender_idx,
                                         const ucp::data_type_desc_t& dt_desc)
{
    ucp_ep_h ep = m_entities.at(sender_idx).ep();
    return ucp_stream_send_nb(ep, dt_desc.buf(), dt_desc.count(),
                              dt_desc.dt(), ucp_send_cb, 0);
}
/* Send every sender's message n_iter times, round-robin over senders; the
 * last round also sends the terminating '\0' of the c_str(). Requests still in
 * flight are appended to sreqs with their heap-allocated descriptors (freed by
 * erase_completed_reqs()). Returns the total number of bytes posted. */
size_t
test_ucp_stream_many2one::send_all_nb(ucp_datatype_t datatype, size_t n_iter,
                                      std::vector<request_wrapper_t> &sreqs)
{
    size_t total = 0;

    for (size_t iter = 0; iter < n_iter; ++iter) {
        const bool last_iter = ((iter + 1) == n_iter);

        for (size_t s = 0; s < m_nsenders; ++s) {
            const std::string &msg = m_msgs[s];
            const void *buf        = msg.c_str();
            const size_t len       = msg.length() + (last_iter ? 1 : 0);

            ucp::data_type_desc_t *desc = new ucp::data_type_desc_t(datatype,
                                                                    buf, len);
            void *sreq = stream_send_nb(s, *desc);
            total += len;

            if (!UCS_PTR_IS_PTR(sreq)) {
                EXPECT_FALSE(UCS_PTR_IS_ERR(sreq));
                delete desc;
                continue;
            }

            sreqs.push_back(request_wrapper_t(sreq, desc));
        }
    }

    return total;
}
/* Blocking wrapper over send_all_nb(): progress until every posted send has
 * completed; returns the total number of bytes sent. */
size_t
test_ucp_stream_many2one::send_all(ucp_datatype_t datatype, size_t n_iter)
{
    std::vector<request_wrapper_t> pending;
    const size_t total = send_all_nb(datatype, n_iter, pending);

    for (; !pending.empty(); erase_completed_reqs(pending)) {
        progress();
    }

    return total;
}
void test_ucp_stream_many2one::check_no_data()
{
std::set<ucp_ep_h> check;
for (size_t i = 0; i <= m_receiver_idx; ++i) {
std::set<ucp_ep_h> check_e = check_no_data(e(i));
check.insert(check_e.begin(), check_e.end());
}
EXPECT_EQ(size_t(0), check.size());
}
/* Progress until idle, then poll the receiver's worker for endpoints with
 * pending stream data. Expect none of them to be an endpoint of 'e', and
 * return the full set of endpoints reported ready. */
std::set<ucp_ep_h> test_ucp_stream_many2one::check_no_data(entity &e)
{
    const size_t max_eps = 10;
    ucp_stream_poll_ep_t poll_eps[max_eps];
    std::set<ucp_ep_h> ready;

    while (progress()) {
    }

    const ssize_t count = ucp_stream_worker_poll(
                              m_entities.at(m_receiver_idx).worker(),
                              poll_eps, max_eps, 0);
    EXPECT_GE(count, ssize_t(0));
    for (ssize_t k = 0; k < count; ++k) {
        ready.insert(poll_eps[k].ep);
    }

    for (int w = 0; w < e.get_num_workers(); ++w) {
        for (int n = 0; n < e.get_num_eps(); ++n) {
            EXPECT_EQ(ready.end(), ready.find(e.ep(w, n)));
        }
    }

    return ready;
}
/* For each sender i, read m_recv_data[i] as a NUL-terminated string and
 * expect it to be exactly n_iter back-to-back copies of "sender_<i>".
 *
 * NOTE(review): for a generic datatype the expected message is built from the
 * bytes 0,1,2,... Its first byte is '\0', so the std::string built from
 * gen_msg.data() is empty, find() matches at every position, and the check
 * effectively only verifies that the received string is empty. Consider
 * comparing raw bytes instead. */
void test_ucp_stream_many2one::check_recv_data(size_t n_iter, ucp_datatype_t dt)
{
    for (size_t i = 0; i < m_nsenders; ++i) {
        std::string expected = "sender_" + ucs::to_string(i);
        const std::string received(&m_recv_data[i].front());

        if (UCP_DT_IS_GENERIC(dt)) {
            std::vector<char> gen_msg;
            for (size_t j = 0; j < expected.length(); ++j) {
                gen_msg.push_back(char(j));
            }
            gen_msg.push_back('\0');
            expected = std::string(gen_msg.data());
        }

        size_t pos = 0;
        for (size_t j = 0; j < n_iter; ++j) {
            const size_t match = received.find(expected, pos);
            EXPECT_NE(std::string::npos, match) << "failed on sender " << i
                                                << " iteration " << j;
            if (match == std::string::npos) {
                break;
            }
            EXPECT_EQ(pos, match);
            pos += expected.length();
        }

        EXPECT_EQ(pos, received.length()); /* nothing more */
    }
}
void
test_ucp_stream_many2one::erase_completed_reqs(std::vector<request_wrapper_t> &reqs)
{
std::vector<request_wrapper_t>::iterator i = reqs.begin();
while (i != reqs.end()) {
ucs_status_t status = ucp_request_check_status(i->m_req);
if (status != UCS_INPROGRESS) {
EXPECT_EQ(UCS_OK, status);
ucp_request_free(i->m_req);
delete i->m_dt_desc;
i = reqs.erase(i);
} else {
++i;
}
}
}
/*
 * Send and flush data from all senders, then destroy the connection between
 * sender 0 and the receiver. The receiver's worker must no longer report the
 * disconnected endpoint as having data. Finally reconnect sender 0 and send
 * again.
 */
UCS_TEST_P(test_ucp_stream_many2one, drop_data) {
    send_all(DATATYPE, 10);

    ASSERT_EQ(m_receiver_idx, m_nsenders);
    for (size_t i = 0; i <= m_receiver_idx; ++i) {
        flush_worker(e(i));
    }

    /* destroy 1 connection; each endpoint is passed with the entity that owns
     * it (ep_destructor presumably progresses that entity while the close
     * completes — verify in ucp_test) */
    entity::ep_destructor(m_entities.at(0).ep(),
                          &m_entities.at(0));
    entity::ep_destructor(m_entities.at(m_receiver_idx).ep(),
                          &m_entities.at(m_receiver_idx));
    m_entities.at(0).revoke_ep();
    m_entities.at(m_receiver_idx).revoke_ep(0, 0);

    /* wait for the 1st byte on the last EP to be sure the network packets
     * have arrived */
    uint8_t check;
    size_t check_length;
    ucp_ep_h last_ep = m_entities.at(m_receiver_idx).ep(0, m_nsenders - 1);
    void *check_req = ucp_stream_recv_nb(last_ep, &check, 1, DATATYPE,
                                         ucp_recv_cb, &check_length, 0);
    EXPECT_FALSE(UCS_PTR_IS_ERR(check_req));
    if (UCS_PTR_IS_PTR(check_req)) {
        wait_stream_recv(check_req);
    }

    /* data from disconnected EP should be dropped */
    std::set<ucp_ep_h> others = check_no_data(m_entities.at(0));

    /* since ordering between EPs is not guaranteed, some data may be still in
     * the network or buffered by transport */
    EXPECT_LE(others.size(), m_nsenders - 1);

    /* reconnect */
    m_entities.at(0).connect(&m_entities.at(m_receiver_idx), get_ep_params(), 0);
    ucp_ep_params_t recv_ep_param = get_ep_params();
    recv_ep_param.field_mask |= UCP_EP_PARAM_FIELD_USER_DATA;
    recv_ep_param.user_data   = (void *)uintptr_t(0xdeadbeef);
    e(m_receiver_idx).connect(&e(0), recv_ep_param, 0);

    /* send again */
    send_all(DATATYPE, 10);
    for (size_t i = 0; i <= m_receiver_idx; ++i) {
        flush_worker(e(i));
    }

    /* Need to poll out all incoming data from transport layer, see PR #2048 */
    while (progress() > 0);
}
/* Many-to-one tests: receive with ucp_stream_worker_poll() +
 * ucp_stream_recv_data_nb(), or with ucp_stream_recv_nb(), for DATATYPE,
 * DATATYPE_IOV and a generic datatype */
UCS_TEST_P(test_ucp_stream_many2one, send_worker_poll) {
    const ucp_datatype_t send_dt = DATATYPE;
    do_send_worker_poll_test(send_dt);
}

UCS_TEST_P(test_ucp_stream_many2one, send_worker_poll_iov) {
    const ucp_datatype_t send_dt = DATATYPE_IOV;
    do_send_worker_poll_test(send_dt);
}

UCS_TEST_P(test_ucp_stream_many2one, send_worker_poll_generic) {
    ucp_datatype_t gen_dt;
    const ucs_status_t status = ucp_dt_create_generic(&ucp::test_dt_uint8_ops,
                                                      NULL, &gen_dt);
    ASSERT_UCS_OK(status);

    do_send_worker_poll_test(gen_dt);
    ucp_dt_destroy(gen_dt);
}

UCS_TEST_P(test_ucp_stream_many2one, send_recv_nb) {
    const ucp_datatype_t send_dt = DATATYPE;
    do_send_recv_test(send_dt);
}

UCS_TEST_P(test_ucp_stream_many2one, send_recv_nb_iov) {
    const ucp_datatype_t send_dt = DATATYPE_IOV;
    do_send_recv_test(send_dt);
}

UCS_TEST_P(test_ucp_stream_many2one, send_recv_nb_generic) {
    ucp_datatype_t gen_dt;
    const ucs_status_t status = ucp_dt_create_generic(&ucp::test_dt_uint8_ops,
                                                      NULL, &gen_dt);
    ASSERT_UCS_OK(status);

    do_send_recv_test(gen_dt);
    ucp_dt_destroy(gen_dt);
}

UCP_INSTANTIATE_TEST_CASE(test_ucp_stream_many2one)