/**
* Copyright (C) Mellanox Technologies Ltd. 2018. ALL RIGHTS RESERVED.
*
* See file LICENSE for terms.
*/
#include "sa_base.h"
#include <cerrno>
#include <climits>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>
#include <sys/epoll.h>
#include <getopt.h>
#include <netdb.h>
#include <unistd.h>
/*
 * Command-line driven test application: opens client connections to the
 * destinations listed in a host file, accepts connections from peers, and
 * performs one request/response exchange per connection.
 */
class application {
public:
    /* Raised for command-line mistakes; main() reacts by printing usage(). */
    class usage_exception : public error {
    public:
        usage_exception(const std::string& message = "");
    };

    /* Parses the command line into m_params. */
    application(int argc, char **argv);

    /* Runs the connection/event loop; returns 0 when it finishes. */
    int run();

    /* Prints the help text, preceded by 'error' when it is non-empty. */
    static void usage(const std::string& error);

private:
    /* One remote peer (host name and port) read from the host file. */
    struct dest_t {
        std::string hostname;
        int         port;
    };

    using dest_vec_t = std::vector<dest_t>;

    /* Role a connection plays in an exchange. */
    enum connection_type {
        CONNECTION_CLIENT, /* initiated locally: sends request, reads response */
        CONNECTION_SERVER  /* accepted from a peer: reads request, sends response */
    };

    /* Settings gathered from the command line; defaults are shown by usage(). */
    struct params {
        params() :
            port(0),
            total_conns(1000),
            conn_ratio(1.5),
            request_size(32),
            response_size(1024)
        {
        }

        std::string mode;
        int         port;
        int         total_conns;
        double      conn_ratio;
        size_t      request_size;
        size_t      response_size;
        dest_vec_t  dests;
    };

    /* Progress of the exchange running on one connection. */
    struct connection_state {
        conn_ptr_t      conn_ptr;
        connection_type conn_type;
        size_t          bytes_sent;
        size_t          bytes_recvd;
        std::string     send_data;
        std::string     recv_data;
    };

    using conn_state_ptr_t = std::shared_ptr<connection_state>;
    /* Active connections, keyed by connection id. */
    using conn_map_t       = std::map<uint64_t, conn_state_ptr_t>;

    void parse_hostfile(const std::string& filename);

    void initiate_connections();

    int max_conns_inflight() const;

    void create_worker();

    void add_connection(conn_ptr_t conn_ptr, connection_type conn_type);

    conn_ptr_t connect(const dest_t& dst);

    void advance_connection(conn_state_ptr_t s, uint32_t events);

    void connection_completed(conn_state_ptr_t s);

    static void pton(const dest_t& dst, struct sockaddr_storage& saddr,
                     socklen_t &addrlen);

    template <typename O>
    friend typename O::__basic_ostream& operator<<(O& os, connection_type conn_type);

    params                  m_params;
    std::shared_ptr<worker> m_worker;
    evpoll_set              m_evpoll;
    conn_map_t              m_connections;
    int                     m_num_conns_inflight; /* client connections started but not completed */
    int                     m_num_conns_started;  /* client connections started so far */
};
/*
 * Construct a usage error carrying 'message'. An empty message means "just
 * print the usage text" (see main()).
 */
application::usage_exception::usage_exception(const std::string& message) :
    error(message) {
}
application::application(int argc, char **argv) : m_num_conns_inflight(0),
m_num_conns_started(0) {
int c;
while ( (c = getopt(argc, argv, "p:f:m:r:n:S:s:vh")) != -1 ) {
switch (c) {
case 'p':
m_params.port = atoi(optarg);
break;
case 'f':
parse_hostfile(optarg);
break;
case 'm':
m_params.mode = optarg;
break;
case 'r':
m_params.conn_ratio = atof(optarg);
break;
case 'n':
m_params.total_conns = atoi(optarg);
break;
case 'S':
m_params.request_size = atoi(optarg);
break;
case 's':
m_params.response_size = atoi(optarg);
break;
case 'v':
log::more_verbose();
break;
default:
throw usage_exception();
}
}
if (m_params.mode.empty()) {
throw usage_exception("missing mode argument");
}
if (m_params.dests.empty()) {
throw usage_exception("no remote destinations specified");
}
if (m_params.port == 0) {
throw usage_exception("local port not specified");
}
}
/*
 * Main loop. Keep initiating client connections and dispatching events until
 * both of these hold:
 * - all total_conns client connections have been started;
 * - every tracked connection (client or accepted server) has completed.
 * Returns 0.
 */
int application::run() {
    LOG_INFO << "starting application with "
             << max_conns_inflight() << " simultaneous connections, "
             << m_params.total_conns << " total";

    create_worker();

    /* continue while connections remain to be started or are still active */
    while ((m_num_conns_started < m_params.total_conns) || !m_connections.empty()) {
        initiate_connections();
        m_worker->wait(m_evpoll,
                       [this](conn_ptr_t conn) {
                           LOG_DEBUG << "accepted new connection";
                           add_connection(conn, CONNECTION_SERVER);
                       },
                       [this](uint64_t conn_id, uint32_t events) {
                           LOG_DEBUG << "new event on connection id "
                                     << conn_id << " events "
                                     << ((events & EPOLLIN ) ? "i" : "")
                                     << ((events & EPOLLOUT) ? "o" : "")
                                     << ((events & EPOLLERR) ? "e" : "")
                                     ;
                           advance_connection(m_connections.at(conn_id), events);
                       },
                       -1);
    }

    LOG_INFO << "all connections completed";
    m_worker.reset();
    return 0;
}
/*
 * Create the worker for the selected mode (-m) and register it with the
 * event-poll set. The worker is given the wildcard IPv4 address with the local
 * port (-p); presumably it listens there for incoming connections.
 */
void application::create_worker() {
    struct sockaddr_in listen_addr;
    memset(&listen_addr, 0, sizeof(listen_addr));
    listen_addr.sin_addr.s_addr = INADDR_ANY;
    listen_addr.sin_port        = htons(m_params.port);
    listen_addr.sin_family      = AF_INET;

    struct sockaddr *bind_addr = reinterpret_cast<struct sockaddr*>(&listen_addr);
    m_worker = worker::make(m_params.mode, bind_addr, sizeof(listen_addr));

    m_worker->add_to_evpoll(m_evpoll);
}
/*
 * Resolve 'dst' with pton() and ask the worker to open a connection to the
 * resulting socket address. Returns the worker's connection object.
 */
std::shared_ptr<connection> application::connect(const dest_t& dst) {
    struct sockaddr_storage remote;
    socklen_t               remote_len;

    pton(dst, remote, remote_len);

    const struct sockaddr *peer = reinterpret_cast<const struct sockaddr*>(&remote);
    return m_worker->connect(peer, remote_len);
}
/*
 * Stream a connection role as "client" or "server". Any other value writes
 * nothing. This overload is only viable for stream types O that expose an
 * '__basic_ostream' member type (see the friend declaration in application).
 */
template <typename O>
typename O::__basic_ostream& operator<<(O& os, application::connection_type conn_type) {
    if (conn_type == application::CONNECTION_CLIENT) {
        return os << "client";
    }

    if (conn_type == application::CONNECTION_SERVER) {
        return os << "server";
    }

    return os;
}
/*
 * Start tracking a new connection in the role 'conn_type'.
 *
 * Sizes the send/receive buffers for that role and registers the connection
 * with the event-poll set. It stores the state in m_connections under the
 * connection id, then calls advance_connection() once with no events so a
 * client can start sending its request right away.
 */
void application::add_connection(conn_ptr_t conn_ptr, connection_type conn_type) {
    auto state = std::make_shared<connection_state>();
    state->conn_ptr    = conn_ptr;
    state->conn_type   = conn_type;
    state->bytes_recvd = 0;
    state->bytes_sent  = 0;

    if (conn_type == CONNECTION_CLIENT) {
        /* client sends the request and expects the response */
        state->send_data.assign(m_params.request_size, 'r');
        state->recv_data.resize(m_params.response_size);
    } else if (conn_type == CONNECTION_SERVER) {
        /* server expects the request and sends the response */
        state->send_data.resize(m_params.response_size);
        state->recv_data.resize(m_params.request_size);
    }

    LOG_DEBUG << "add " << conn_type << " connection with id " << conn_ptr->id();
    conn_ptr->add_to_evpoll(m_evpoll);

    conn_state_ptr_t& slot = m_connections[conn_ptr->id()];
    slot = state;

    advance_connection(state, 0);
}
void application::initiate_connections() {
int max = max_conns_inflight();
while ((m_num_conns_started < m_params.total_conns) && (m_num_conns_inflight < max)) {
/* coverity[dont_call] */
const dest_t& dest = m_params.dests[::rand() % m_params.dests.size()];
++m_num_conns_started;
++m_num_conns_inflight;
LOG_DEBUG << "connecting to " << dest.hostname << ":" << dest.port;
add_connection(connect(dest), CONNECTION_CLIENT);
}
}
/*
 * Limit on simultaneous client connections: conn_ratio times the number of
 * destinations, rounded to the nearest integer (+0.5, then truncated).
 */
int application::max_conns_inflight() const {
    const double scaled = m_params.conn_ratio * m_params.dests.size();
    return static_cast<int>(scaled + 0.5);
}
/*
 * Make progress on one connection after an event. 'events' is 0 when called
 * right after the connection was added.
 *
 * Client role: send the rest of the request, read response bytes on EPOLLIN,
 * and complete once the whole response has arrived.
 * Server role: read request bytes on EPOLLIN, send the response once the full
 * request is in, and complete when the connection reports it is closed.
 */
void application::advance_connection(conn_state_ptr_t s, uint32_t events) {
    connection_state& st = *s;

    LOG_DEBUG << "advance " << st.conn_type << " connection id " << st.conn_ptr->id()
              << " total sent " << st.bytes_sent << ", received " << st.bytes_recvd;

    const bool readable = (events & EPOLLIN) != 0;

    if (st.conn_type == CONNECTION_CLIENT) {
        const size_t request_len  = m_params.request_size;
        const size_t response_len = m_params.response_size;

        if (st.bytes_sent < request_len) {
            /* request not fully sent yet */
            const size_t nsent = st.conn_ptr->send(&st.send_data[st.bytes_sent],
                                                   request_len - st.bytes_sent);
            LOG_DEBUG << "sent " << nsent << " bytes on connection id "
                      << st.conn_ptr->id();
            st.bytes_sent += nsent;
        }

        if (readable) {
            const size_t nrecv = st.conn_ptr->recv(&st.recv_data[st.bytes_recvd],
                                                   response_len - st.bytes_recvd);
            LOG_DEBUG << "received " << nrecv << " bytes on connection id "
                      << st.conn_ptr->id();
            st.bytes_recvd += nrecv;
        }

        if (st.bytes_recvd == response_len) {
            connection_completed(s);
        }
    } else if (st.conn_type == CONNECTION_SERVER) {
        const size_t request_len  = m_params.request_size;
        const size_t response_len = m_params.response_size;

        if (readable) {
            const size_t nrecv = st.conn_ptr->recv(&st.recv_data[st.bytes_recvd],
                                                   request_len - st.bytes_recvd);
            LOG_DEBUG << "received " << nrecv << " bytes on connection id "
                      << st.conn_ptr->id();
            st.bytes_recvd += nrecv;
        }

        if ((st.bytes_recvd == request_len) && (st.bytes_sent < response_len)) {
            /* full request received; response not fully sent yet */
            const size_t nsent = st.conn_ptr->send(&st.send_data[st.bytes_sent],
                                                   response_len - st.bytes_sent);
            LOG_DEBUG << "sent " << nsent << " bytes on connection id "
                      << st.conn_ptr->id();
            st.bytes_sent += nsent;
        }

        if (st.conn_ptr->is_closed()) {
            connection_completed(s);
        }
    }
}
/*
 * Stop tracking a finished connection: drop it from m_connections.
 * For client connections, also release its slot in the in-flight counter.
 */
void application::connection_completed(conn_state_ptr_t s) {
    LOG_DEBUG << "completed " << s->conn_type << " connection id " << s->conn_ptr->id();
    m_connections.erase(s->conn_ptr->id());

    /*
     * m_num_conns_inflight is incremented only for client connections (in
     * initiate_connections()). Decrementing it for accepted server connections
     * would drive it negative and defeat the in-flight limit.
     */
    if (s->conn_type == CONNECTION_CLIENT) {
        --m_num_conns_inflight;
    }
}
void application::pton(const dest_t& dst, struct sockaddr_storage& saddr,
socklen_t &addrlen) {
struct hostent *he = gethostbyname(dst.hostname.c_str());
if (he == NULL || he->h_addr_list == NULL) {
throw error("host " + dst.hostname + " not found: "+ hstrerror(h_errno));
}
memset(&saddr, 0, sizeof(saddr));
saddr.ss_family = he->h_addrtype;
void *addr;
int addr_datalen = 0;
switch (saddr.ss_family) {
case AF_INET:
reinterpret_cast<struct sockaddr_in*>(&saddr)->sin_port =
htons(dst.port);
/* cppcheck-suppress internalAstError */
addr = &reinterpret_cast<struct sockaddr_in*>(&saddr)->sin_addr;
addrlen = sizeof(struct sockaddr_in);
addr_datalen = sizeof(struct in_addr);
break;
case AF_INET6:
reinterpret_cast<struct sockaddr_in6*>(&saddr)->sin6_port =
htons(dst.port);
addr = &reinterpret_cast<struct sockaddr_in6*>(&saddr)->sin6_addr;
addrlen = sizeof(struct sockaddr_in6);
addr_datalen = sizeof(struct in6_addr);
break;
default:
throw error("unsupported address family");
}
if (he->h_length != addr_datalen) {
throw error("mismatching address length");
}
memcpy(addr, he->h_addr_list[0], addr_datalen);
}
/*
 * Print the command-line help to stdout. When 'error' is non-empty, print it
 * first as "Error: ...". Option defaults come from a default-constructed
 * params object.
 */
void application::usage(const std::string& error) {
    if (!error.empty()) {
        std::cout << "Error: " << error << std::endl;
        std::cout << std::endl;
    }

    params defaults;

    std::cout << "Usage: ./sa [ options ]" << std::endl;
    std::cout << "Options:" << std::endl;
    std::cout << " -m <mode> Application mode (tcp)" << std::endl;
    std::cout << " -p <port> Local port number to listen on" << std::endl;
    std::cout << " -f <file> File with list of hosts and ports to connect to" << std::endl;
    std::cout << " Each line in the file is formatted as follows:" << std::endl;
    std::cout << " <address> <port>" << std::endl;
    std::cout << " -r <ratio> How many in-flight connections to hold as multiple" << std::endl;
    std::cout << " of number of possible destinations (" << defaults.conn_ratio << ")" << std::endl;
    std::cout << " -n <count> How many total exchanges to perform (" << defaults.total_conns << ")" << std::endl;
    std::cout << " -S <size> Request message size, in bytes (" << defaults.request_size << ")" << std::endl;
    std::cout << " -s <size> Response message size, in bytes (" << defaults.response_size << ")" << std::endl;
    std::cout << " -v Increase verbosity level (may be specified several times)" << std::endl;
    std::cout << " -h Print this help message" << std::endl;
}
/*
 * Append the destinations listed in 'filename' to m_params.dests.
 *
 * Each non-blank line holds 2 whitespace-separated tokens: host name and port
 * number (1..65535). Lines that are empty or contain only whitespace are
 * skipped. Throws error if the file cannot be opened or a line is malformed;
 * the error message reports the physical line number.
 */
void application::parse_hostfile(const std::string& filename) {
    std::ifstream f(filename.c_str());
    if (!f) {
        throw error("failed to open '" + filename + "'");
    }

    std::string line;
    int lineno = 0;
    while (std::getline(f, line)) {
        /* count every physical line, including the ones skipped below */
        ++lineno;

        /* blank lines (possibly just "\r" from CRLF files) carry no entry */
        if (line.find_first_not_of(" \t\r\n\v\f") == std::string::npos) {
            continue;
        }

        std::stringstream ss(line);
        dest_t dest;
        if ((ss >> dest.hostname) && (ss >> dest.port) &&
            (dest.port > 0) && (dest.port <= 65535)) {
            m_params.dests.push_back(dest);
        } else {
            std::stringstream errss;
            errss << "syntax error in file '" << filename << "' line " << lineno <<
                     " near `" << line << "'";
            throw error(errss.str());
        }
    }
}
/*
 * Entry point. Exit codes:
 * - the result of application::run() on success;
 * - -127 after printing usage for a usage_exception;
 * - -1 for any other error.
 */
int main(int argc, char **argv)
{
    try {
        application app(argc, argv);
        return app.run();
    } catch (application::usage_exception& e) {
        application::usage(e.what());
        return -127;
    } catch (error& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        /* falling off the end of main() would report success (0) */
        return -1;
    } catch (std::exception& e) {
        /* e.g. std::bad_alloc or std::out_of_range from container access */
        std::cerr << "Error: " << e.what() << std::endl;
        return -1;
    }
}