/**
* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/
#include <common/test.h>
#include <uct/uct_test.h>
extern "C" {
#include <uct/api/uct.h>
#include <uct/tcp/tcp.h>
}
class test_uct_tcp : public uct_test {
public:
void init() {
if (RUNNING_ON_VALGRIND) {
modify_config("TX_SEG_SIZE", "1kb");
modify_config("RX_SEG_SIZE", "1kb");
}
uct_test::init();
m_ent = uct_test::create_entity(0);
m_entities.push_back(m_ent);
m_tcp_iface = (uct_tcp_iface*)m_ent->iface();
}
size_t get_accepted_conn_num(entity& ent) {
size_t num = 0;
uct_tcp_ep_t *ep;
ucs_list_for_each(ep, &m_tcp_iface->ep_list, list) {
num += (ep->conn_state == UCT_TCP_EP_CONN_STATE_RECV_MAGIC_NUMBER);
}
return num;
}
ucs_status_t post_recv(int fd, bool nb = false) {
uint8_t msg;
size_t msg_size = sizeof(msg);
ucs_status_t status;
scoped_log_handler slh(wrap_errors_logger);
if (nb) {
status = ucs_socket_recv_nb(fd, &msg, &msg_size, NULL, NULL);
} else {
status = ucs_socket_recv(fd, &msg, msg_size, NULL, NULL);
}
return status;
}
void post_send(int fd, const std::vector<char> &buf) {
scoped_log_handler slh(wrap_errors_logger);
ucs_status_t status = ucs_socket_send(fd, &buf[0],
buf.size(), NULL, NULL);
// send can be OK or fail when a connection was closed by a peer
// before all data were sent
ASSERT_TRUE((status == UCS_OK) ||
(status == UCS_ERR_IO_ERROR));
}
void detect_conn_reset(int fd) {
// Try to receive something on this socket fd - it has to be failed
ucs_status_t status = post_recv(fd);
ASSERT_TRUE((status == UCS_ERR_IO_ERROR) ||
(status == UCS_ERR_CANCELED));
EXPECT_EQ(0, ucs_socket_is_connected(fd));
}
void test_listener_flood(entity& test_entity, size_t max_conn,
size_t msg_size) {
std::vector<int> fds;
std::vector<char> buf;
if (msg_size > 0) {
buf.resize(msg_size + sizeof(uct_tcp_am_hdr_t));
std::fill(buf.begin(), buf.end(), 0);
init_data(&buf[0], buf.size());
}
setup_conns_to_entity(test_entity, max_conn, fds);
size_t handled = 0;
for (std::vector<int>::const_iterator iter = fds.begin();
iter != fds.end(); ++iter) {
size_t sent_length = 0;
do {
if (msg_size > 0) {
post_send(*iter, buf);
sent_length += buf.size();
} else {
close(*iter);
}
// If it was sent >= the length of the magic number or sending
// is not required by the current test, wait until connection
// is destroyed. Otherwise, need to send more data
if ((msg_size == 0) || (sent_length >= sizeof(uint64_t))) {
handled++;
while (get_accepted_conn_num(test_entity) != (max_conn - handled)) {
sched_yield();
progress();
}
} else {
// Peers still have to be connected
ucs_status_t status = post_recv(*iter, true);
EXPECT_TRUE((status == UCS_OK) ||
(status == UCS_ERR_NO_PROGRESS));
EXPECT_EQ(1, ucs_socket_is_connected(*iter));
}
} while ((msg_size != 0) && (sent_length < sizeof(uint64_t)));
}
// give a chance to close all connections
while (!ucs_list_is_empty(&m_tcp_iface->ep_list)) {
sched_yield();
progress();
}
// TCP has to reject all connections and forget EPs that were
// created after accept():
// - EP list has to be empty
EXPECT_EQ(1, ucs_list_is_empty(&m_tcp_iface->ep_list));
// - all connections have to be destroyed (if wasn't closed
// yet by the clients)
if (msg_size > 0) {
// if we sent data during the test, close socket fd here
while (!fds.empty()) {
int fd = fds.back();
fds.pop_back();
detect_conn_reset(fd);
close(fd);
}
}
}
void setup_conns_to_entity(entity& to, size_t max_conn,
std::vector<int> &fds) {
for (size_t i = 0; i < max_conn; i++) {
int fd = setup_conn_to_entity(to, i + 1lu);
fds.push_back(fd);
// give a chance to finish all connections
while (get_accepted_conn_num(to) != (i + 1lu)) {
sched_yield();
progress();
}
EXPECT_EQ(1, ucs_socket_is_connected(fd));
}
}
private:
void init_data(void *buf, size_t msg_size) {
uct_tcp_am_hdr_t *tcp_am_hdr;
ASSERT_TRUE(msg_size >= sizeof(*tcp_am_hdr));
tcp_am_hdr = static_cast<uct_tcp_am_hdr_t*>(buf);
tcp_am_hdr->am_id = std::numeric_limits<uint8_t>::max();
tcp_am_hdr->length = msg_size;
}
int connect_to_entity(entity& to) {
uct_device_addr_t *dev_addr;
uct_iface_addr_t *iface_addr;
ucs_status_t status;
dev_addr = (uct_device_addr_t*)malloc(to.iface_attr().device_addr_len);
iface_addr = (uct_iface_addr_t*)malloc(to.iface_attr().iface_addr_len);
status = uct_iface_get_device_address(to.iface(), dev_addr);
ASSERT_UCS_OK(status);
status = uct_iface_get_address(to.iface(), iface_addr);
ASSERT_UCS_OK(status);
struct sockaddr_in dest_addr;
dest_addr.sin_family = AF_INET;
dest_addr.sin_port = *(in_port_t*)iface_addr;
dest_addr.sin_addr = *(struct in_addr*)dev_addr;
int fd;
status = ucs_socket_create(AF_INET, SOCK_STREAM, &fd);
ASSERT_UCS_OK(status);
status = ucs_socket_connect(fd, (const struct sockaddr*)&dest_addr);
ASSERT_UCS_OK(status);
status = ucs_sys_fcntl_modfl(fd, O_NONBLOCK, 0);
ASSERT_UCS_OK(status);
free(iface_addr);
free(dev_addr);
return fd;
}
int setup_conn_to_entity(entity &to, size_t sn = 1) {
int fd = -1;
do {
if (fd != -1) {
close(fd);
}
fd = connect_to_entity(to);
EXPECT_NE(-1, fd);
// give a chance to finish the connection
while (get_accepted_conn_num(to) != sn) {
sched_yield();
progress();
ucs_status_t status = post_recv(fd, true);
if ((status != UCS_OK) &&
(status != UCS_ERR_NO_PROGRESS)) {
break;
}
}
} while (!ucs_socket_is_connected(fd));
EXPECT_EQ(1, ucs_socket_is_connected(fd));
return fd;
}
protected:
uct_tcp_iface *m_tcp_iface;
entity *m_ent;
};
UCS_TEST_P(test_uct_tcp, listener_flood_connect_and_send_large) {
const size_t max_conn =
ucs_min(static_cast<size_t>(max_connections()), 128lu) /
ucs::test_time_multiplier();
const size_t msg_size = m_tcp_iface->config.rx_seg_size * 4;
test_listener_flood(*m_ent, max_conn, msg_size);
}
UCS_TEST_P(test_uct_tcp, listener_flood_connect_and_send_small) {
const size_t max_conn =
ucs_min(static_cast<size_t>(max_connections()), 128lu) /
ucs::test_time_multiplier();
// It should be less than length of the expected magic number by TCP
const size_t msg_size = 1;
test_listener_flood(*m_ent, max_conn, msg_size);
}
UCS_TEST_P(test_uct_tcp, listener_flood_connect_and_close) {
const size_t max_conn =
ucs_min(static_cast<size_t>(max_connections()), 128lu) /
ucs::test_time_multiplier();
test_listener_flood(*m_ent, max_conn, 0);
}
_UCT_INSTANTIATE_TEST_CASE(test_uct_tcp, tcp)