Blob Blame History Raw
/**
* Copyright (C) Mellanox Technologies Ltd. 2001-2014.  ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/

#ifndef UCP_TEST_H_
#define UCP_TEST_H_

#include <ucp/api/ucp.h>
#include <ucs/time/time.h>
#include <common/mem_buffer.h>

/* ucp version compile time test */
#if (UCP_API_VERSION != UCP_VERSION(UCP_API_MAJOR,UCP_API_MINOR))
#error possible bug in UCP version
#endif

#include <common/test.h>

#include <queue>

#define MT_TEST_NUM_THREADS       4


namespace ucp {
extern const uint32_t MAGIC;
}


struct ucp_test_param {
    ucp_params_t              ctx_params;
    std::vector<std::string>  transports;
    int                       variant;
    int                       thread_type;
};

class ucp_test_base : public ucs::test_base {
public:
    enum {
        SINGLE_THREAD = 42,
        MULTI_THREAD_CONTEXT, /* workers are single-threaded, context is mt-shared */
        MULTI_THREAD_WORKER   /* workers are multi-threaded, cotnext is mt-single */
    };

    class entity {
        typedef std::vector<ucs::handle<ucp_ep_h, entity *> > ep_vec_t;
        typedef std::vector<std::pair<ucs::handle<ucp_worker_h>,
                                      ep_vec_t> > worker_vec_t;

    public:
        typedef enum {
            LISTEN_CB_EP,       /* User's callback accepts ucp_ep_h */
            LISTEN_CB_CONN,     /* User's callback accepts ucp_conn_request_h */
            LISTEN_CB_REJECT    /* User's callback rejects ucp_conn_request_h */
        } listen_cb_type_t;

        entity(const ucp_test_param& test_param, ucp_config_t* ucp_config,
               const ucp_worker_params_t& worker_params,
               const ucp_test_base* test_owner);

        ~entity();

        void connect(const entity* other, const ucp_ep_params_t& ep_params,
                     int ep_idx = 0, int do_set_ep = 1);

        ucp_ep_h accept(ucp_worker_h worker, ucp_conn_request_h conn_request);

        void* modify_ep(const ucp_ep_params_t& ep_params, int worker_idx = 0,
                       int ep_idx = 0);

        void* flush_ep_nb(int worker_index = 0, int ep_index = 0) const;

        void* flush_worker_nb(int worker_index = 0) const;

        void fence(int worker_index = 0) const;

        void* disconnect_nb(int worker_index = 0, int ep_index = 0,
                            enum ucp_ep_close_mode mode = UCP_EP_CLOSE_MODE_FLUSH) const;

        void destroy_worker(int worker_index = 0);

        ucs_status_t listen(listen_cb_type_t cb_type,
                            const struct sockaddr *saddr, socklen_t addrlen,
                            const ucp_ep_params_t& ep_params,
                            int worker_index = 0);

        ucp_ep_h ep(int worker_index = 0, int ep_index = 0) const;

        ucp_ep_h revoke_ep(int worker_index = 0, int ep_index = 0) const;

        ucp_worker_h worker(int worker_index = 0) const;

        ucp_context_h ucph() const;

        ucp_listener_h listenerh() const;

        unsigned progress(int worker_index = 0);

        int get_num_workers() const;

        int get_num_eps(int worker_index = 0) const;

        void add_err(ucs_status_t status);

        const size_t &get_err_num_rejected() const;

        const size_t &get_err_num() const;

        void warn_existing_eps() const;

        double set_ib_ud_timeout(double timeout_sec);

        void cleanup();

        static void ep_destructor(ucp_ep_h ep, entity *e);

    protected:
        ucs::handle<ucp_context_h>      m_ucph;
        worker_vec_t                    m_workers;
        ucs::handle<ucp_listener_h>     m_listener;
        std::queue<ucp_conn_request_h>  m_conn_reqs;
        size_t                          m_err_cntr;
        size_t                          m_rejected_cntr;
        ucs::handle<ucp_ep_params_t*>   m_server_ep_params;

    private:
        static void empty_send_completion(void *r, ucs_status_t status);
        static void accept_ep_cb(ucp_ep_h ep, void *arg);
        static void accept_conn_cb(ucp_conn_request_h conn_req, void *arg);
        static void reject_conn_cb(ucp_conn_request_h conn_req, void *arg);

        void set_ep(ucp_ep_h ep, int worker_index, int ep_index);
    };
};

/**
 * UCP test
 */
class ucp_test : public ucp_test_base,
                 public ::testing::TestWithParam<ucp_test_param>,
                 public ucs::entities_storage<ucp_test_base::entity> {

    friend class ucp_test_base::entity;

public:
    enum {
        DEFAULT_PARAM_VARIANT = 0
    };

    UCS_TEST_BASE_IMPL;

    ucp_test();
    virtual ~ucp_test();

    ucp_config_t* m_ucp_config;

    static std::vector<ucp_test_param>
    enum_test_params(const ucp_params_t& ctx_params,
                     const std::string& name,
                     const std::string& test_case_name,
                     const std::string& tls);

    static ucp_params_t get_ctx_params();
    virtual ucp_worker_params_t get_worker_params();
    virtual ucp_ep_params_t get_ep_params();

    static void
    generate_test_params_variant(const ucp_params_t& ctx_params,
                                 const std::string& name,
                                 const std::string& test_case_name,
                                 const std::string& tls,
                                 int variant,
                                 std::vector<ucp_test_param>& test_params,
                                 int thread_type = SINGLE_THREAD);

    virtual void modify_config(const std::string& name, const std::string& value,
                               bool optional = false);
    void stats_activate();
    void stats_restore();

private:
    static void set_ucp_config(ucp_config_t *config,
                               const ucp_test_param& test_param);
    static bool check_test_param(const std::string& name,
                                 const std::string& test_case_name,
                                 const ucp_test_param& test_param);

protected:
    virtual void init();
    bool is_self() const;
    virtual void cleanup();
    virtual bool has_transport(const std::string& tl_name) const;
    bool has_only_transports(const std::vector<std::string>& tl_names) const;
    entity* create_entity(bool add_in_front = false);
    entity* create_entity(bool add_in_front, const ucp_test_param& test_param);
    unsigned progress(int worker_index = 0) const;
    void short_progress_loop(int worker_index = 0) const;
    void flush_ep(const entity &e, int worker_index = 0, int ep_index = 0);
    void flush_worker(const entity &e, int worker_index = 0);
    void disconnect(const entity& entity);
    void wait(void *req, int worker_index = 0);
    void set_ucp_config(ucp_config_t *config);
    int max_connections();

    static void err_handler_cb(void *arg, ucp_ep_h ep, ucs_status_t status) {
        entity *e = reinterpret_cast<entity*>(arg);
        e->add_err(status);
    }

    template <typename T>
    void wait_for_flag(volatile T *flag, double timeout = 10.0) {
        ucs_time_t loop_end_limit = ucs_get_time() + ucs_time_from_sec(timeout);
        while ((ucs_get_time() < loop_end_limit) && (!(*flag))) {
            short_progress_loop();
        }
    }

    static const ucp_datatype_t DATATYPE;
    static const ucp_datatype_t DATATYPE_IOV;

protected:
    class mapped_buffer : public mem_buffer {
    public:
        mapped_buffer(size_t size, const entity& entity, int flags = 0,
                      ucs_memory_type_t mem_type = UCS_MEMORY_TYPE_HOST);
        virtual ~mapped_buffer();

        ucs::handle<ucp_rkey_h> rkey(const entity& entity) const;

        ucp_mem_h memh() const;

    private:
        const entity& m_entity;
        ucp_mem_h     m_memh;
        void*         m_rkey_buffer;
    };
};


std::ostream& operator<<(std::ostream& os, const ucp_test_param& test_param);

/**
 * Instantiate the parameterized test case a combination of transports.
 *
 * @param _test_case   Test case class, derived from ucp_test.
 * @param _name        Instantiation name.
 * @param ...          Transport names.
 */
#define UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, _name, _tls) \
    INSTANTIATE_TEST_CASE_P(_name,  _test_case, \
                            testing::ValuesIn(_test_case::enum_test_params(_test_case::get_ctx_params(), \
                                                                           #_name, \
                                                                           #_test_case, \
                                                                           _tls)));


/**
 * Instantiate the parameterized test case for all transport combinations.
 *
 * @param _test_case  Test case class, derived from ucp_test.
 */
#define UCP_INSTANTIATE_TEST_CASE(_test_case) \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, dcx,    "dc_x") \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, ud,     "ud_v") \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, udx,    "ud_x") \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, rc,     "rc_v") \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, rcx,    "rc_x") \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, shm_ib, "shm,ib") \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, ugni,   "ugni") \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, self,   "self") \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, tcp,    "tcp")


/**
 * The list of GPU copy TLs
 */
#define UCP_TEST_GPU_COPY_TLS "cuda_copy,rocm_copy"


/**
 * Instantiate the parameterized test case for all transport combinations
 * with GPU memory awareness
 *
 * @param _test_case  Test case class, derived from ucp_test.
 */
#define UCP_INSTANTIATE_TEST_CASE_GPU_AWARE(_test_case) \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, dcx,        "dc_x," UCP_TEST_GPU_COPY_TLS) \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, ud,         "ud_v," UCP_TEST_GPU_COPY_TLS) \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, udx,        "ud_x," UCP_TEST_GPU_COPY_TLS) \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, rc,         "rc_v," UCP_TEST_GPU_COPY_TLS) \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, rcx,        "rc_x," UCP_TEST_GPU_COPY_TLS) \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, shm_ib,     "shm,ib," UCP_TEST_GPU_COPY_TLS) \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, shm_ib_ipc, "shm,ib,cuda_ipc,rocm_ipc," \
                                                          UCP_TEST_GPU_COPY_TLS) \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, ugni,       "ugni," UCP_TEST_GPU_COPY_TLS) \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, tcp,        "tcp," UCP_TEST_GPU_COPY_TLS)

#endif