Blob Blame History Raw
/* -*- Mode: C; c-basic-offset:4 ; indent-tabs-mode:nil ; -*- */
/*
 *  (C) 2014 by Argonne National Laboratory.
 *      See COPYRIGHT in top-level directory.
 */

#include "mpiimpl.h"
#include "mpidimpl.h"
#include "hcoll.h"
#include "hcoll/api/hcoll_dte.h"
#include "hcoll_dtypes.h"

/* Forward declarations of the callbacks that init_module_fns() installs into
 * hcoll_rte_functions.  The definitions follow later in this file. */

/* Post a non-blocking receive; the resulting request is handed back in req. */
static int recv_nb(dte_data_representation_t data,
                   uint32_t count,
                   void *buffer,
                   rte_ec_handle_t, rte_grp_handle_t, uint32_t tag, rte_request_handle_t * req);

/* Post a non-blocking send; the resulting request is handed back in req. */
static int send_nb(dte_data_representation_t data,
                   uint32_t count,
                   void *buffer,
                   rte_ec_handle_t ec_h,
                   rte_grp_handle_t grp_h, uint32_t tag, rte_request_handle_t * req);

/* Query whether a request produced by send_nb()/recv_nb() has completed. */
static int test(rte_request_handle_t * request, int *completed);

/* Compare two endpoint handles for equality. */
static int ec_handle_compare(rte_ec_handle_t handle_1,
                             rte_grp_handle_t
                             group_handle_1,
                             rte_ec_handle_t handle_2, rte_grp_handle_t group_handle_2);

/* Fill ec_handles[] for the num_ec ranks listed in ec_indexes[]. */
static int get_ec_handles(int num_ec,
                          int *ec_indexes, rte_grp_handle_t, rte_ec_handle_t * ec_handles);

/* Group (communicator) queries. */
static int group_size(rte_grp_handle_t group);
static int my_rank(rte_grp_handle_t grp_h);
static int ec_on_local_node(rte_ec_handle_t ec, rte_grp_handle_t group);
static rte_grp_handle_t get_world_group_handle(void);
static uint32_t jobid(void);

/* Collective-handle management; the handles are MPIR_Request objects passed
 * around as void pointers. */
static void *get_coll_handle(void);
static int coll_handle_test(void *handle);
static void coll_handle_free(void *handle);
static void coll_handle_complete(void *handle);
static int group_id(rte_grp_handle_t group);

/* Translate an endpoint of a group to a world-level rank. */
static int world_rank(rte_grp_handle_t grp_h, rte_ec_handle_t ec);

#undef FUNCNAME
#define FUNCNAME progress
#undef FCNAME
#define FCNAME MPL_QUOTE(FUNCNAME)
/* Progress callback handed to hcoll.
 *
 * While world_comm_destroying is zero, the MPI progress engine is poked via
 * MPID_Progress_test().  Otherwise hcoll_do_progress() is invoked directly and
 * its return code is asserted to be MPI_SUCCESS. */
static void progress(void)
{
    if (world_comm_destroying != 0) {
        int made_progress;
        int rc;

        /* FIXME: The hcoll library needs to be updated to return
         * error codes.  The progress function pointer right now
         * expects that the function returns void. */
        rc = hcoll_do_progress(&made_progress);
        MPIR_Assert(rc == MPI_SUCCESS);
        return;
    }

    MPID_Progress_test();
}

#if HCOLL_API >= HCOLL_VERSION(3,6)
/* Forward declarations of the datatype-related callbacks.  They are compiled,
 * and installed by init_module_fns(), only when HCOLL_API is at least
 * HCOLL_VERSION(3,6). */

/* Report the envelope (counts and combiner) of an MPI datatype. */
static int get_mpi_type_envelope(void *mpi_type, int *num_integers,
                                 int *num_addresses, int *num_datatypes,
                                 hcoll_mpi_type_combiner_t * combiner);
/* Retrieve the contents (integers, addresses, datatypes) of an MPI datatype. */
static int get_mpi_type_contents(void *mpi_type, int max_integers, int max_addresses,
                                 int max_datatypes, int *array_of_integers,
                                 void *array_of_addresses, void *array_of_datatypes);
/* Look up / record the hcoll representation associated with an MPI datatype. */
static int get_hcoll_type(void *mpi_type, dte_data_representation_t * hcoll_type);
static int set_hcoll_type(void *mpi_type, dte_data_representation_t hcoll_type);
/* Export the MPI constants listed in the parameter list to the caller. */
static int get_mpi_constants(size_t * mpi_datatype_size,
                             int *mpi_order_c, int *mpi_order_fortran,
                             int *mpi_distribute_block,
                             int *mpi_distribute_cyclic,
                             int *mpi_distribute_none, int *mpi_distribute_dflt_darg);
#endif

#undef FUNCNAME
#define FUNCNAME init_module_fns
#undef FCNAME
#define FCNAME MPL_QUOTE(FUNCNAME)
/* Populate the global hcoll_rte_functions table with the callbacks
 * implemented in this file.  The assignments are independent of each other;
 * they are grouped by purpose below. */
static void init_module_fns(void)
{
    /* point-to-point transfers and request testing */
    hcoll_rte_functions.recv_fn = recv_nb;
    hcoll_rte_functions.send_fn = send_nb;
    hcoll_rte_functions.test_fn = test;

    /* endpoint handles */
    hcoll_rte_functions.get_ec_handles_fn = get_ec_handles;
    hcoll_rte_functions.ec_cmp_fn = ec_handle_compare;
    hcoll_rte_functions.rte_ec_on_local_node_fn = ec_on_local_node;

    /* group queries */
    hcoll_rte_functions.rte_world_group_fn = get_world_group_handle;
    hcoll_rte_functions.rte_group_size_fn = group_size;
    hcoll_rte_functions.rte_group_id_fn = group_id;
    hcoll_rte_functions.rte_my_rank_fn = my_rank;
    hcoll_rte_functions.rte_world_rank_fn = world_rank;
    hcoll_rte_functions.rte_jobid_fn = jobid;

    /* progress */
    hcoll_rte_functions.rte_progress_fn = progress;

    /* collective handles */
    hcoll_rte_functions.rte_get_coll_handle_fn = get_coll_handle;
    hcoll_rte_functions.rte_coll_handle_test_fn = coll_handle_test;
    hcoll_rte_functions.rte_coll_handle_complete_fn = coll_handle_complete;
    hcoll_rte_functions.rte_coll_handle_free_fn = coll_handle_free;

#if HCOLL_API >= HCOLL_VERSION(3,6)
    /* datatype introspection, only present from HCOLL_VERSION(3,6) on */
    hcoll_rte_functions.rte_get_mpi_constants_fn = get_mpi_constants;
    hcoll_rte_functions.rte_get_mpi_type_envelope_fn = get_mpi_type_envelope;
    hcoll_rte_functions.rte_get_mpi_type_contents_fn = get_mpi_type_contents;
    hcoll_rte_functions.rte_get_hcoll_type_fn = get_hcoll_type;
    hcoll_rte_functions.rte_set_hcoll_type_fn = set_hcoll_type;
#endif
}

#undef FUNCNAME
#define FUNCNAME hcoll_rte_fns_setup
#undef FCNAME
#define FCNAME MPL_QUOTE(FUNCNAME)
/* Externally visible entry point: installs this file's callbacks by
 * delegating to init_module_fns(). */
void hcoll_rte_fns_setup(void)
{
    init_module_fns();
}

#undef FUNCNAME
#define FUNCNAME recv_nb
#undef FCNAME
#define FCNAME MPL_QUOTE(FUNCNAME)
/* Post a non-blocking receive on behalf of hcoll.
 *
 * data   - hcoll datatype descriptor; asserted to be an inline representation
 * count  - number of elements described by data
 * buffer - destination buffer; may be NULL only when data is the zero type
 * ec_h   - peer endpoint; ec_h.rank is used as the source rank and
 *          ec_h.handle must be non-NULL
 * grp_h  - group handle, cast to the MPIR_Comm the receive is posted on
 * tag    - message tag
 * req    - out: on success holds the MPIR_Request and is marked
 *          HCOLRTE_REQUEST_ACTIVE; left untouched on failure
 *
 * Returns mpi_errno (MPI_SUCCESS) on success and HCOLL_ERROR on any failure,
 * including a failure reported by MPIC_Irecv(). */
static int recv_nb(struct dte_data_representation_t data,
                   uint32_t count,
                   void *buffer,
                   rte_ec_handle_t ec_h,
                   rte_grp_handle_t grp_h, uint32_t tag, rte_request_handle_t * req)
{
    int mpi_errno = MPI_SUCCESS;
    MPIR_Request *request = NULL;
    MPIR_Comm *comm = (MPIR_Comm *) grp_h;
    size_t size;

    if (!ec_h.handle) {
        MPIR_ERR_SETANDJUMP2(mpi_errno, MPI_ERR_OTHER, "**hcoll_wrong_arg",
                             "**hcoll_wrong_arg %p %d", ec_h.handle, ec_h.rank);
    }

    MPIR_Assert(HCOL_DTE_IS_INLINE(data));
    if (!buffer && !HCOL_DTE_IS_ZERO(data)) {
        MPIR_ERR_SETANDJUMP(mpi_errno, MPI_ERR_OTHER, "**null_buff_ptr");
    }

    /* The payload is transferred as raw MPI_CHAR bytes; packed_size * count
     * is divided by 8 (presumably packed_size is in bits -- confirm against
     * the hcoll headers).  The cast keeps the product in size_t. */
    size = (size_t) data.rep.in_line_rep.data_handle.in_line.packed_size * count / 8;

    mpi_errno = MPIC_Irecv(buffer, size, MPI_CHAR, ec_h.rank, tag, comm, &request);
    if (mpi_errno != MPI_SUCCESS) {
        /* Do not publish a request as active when posting it failed. */
        goto fn_fail;
    }
    MPIR_Assert(request);

    req->data = (void *) request;
    req->status = HCOLRTE_REQUEST_ACTIVE;
  fn_exit:
    return mpi_errno;
  fn_fail:
    return HCOLL_ERROR;
}

#undef FUNCNAME
#define FUNCNAME send_nb
#undef FCNAME
#define FCNAME MPL_QUOTE(FUNCNAME)
/* Post a non-blocking send on behalf of hcoll.
 *
 * data   - hcoll datatype descriptor; asserted to be an inline representation
 * count  - number of elements described by data
 * buffer - source buffer; may be NULL only when data is the zero type
 * ec_h   - peer endpoint; ec_h.rank is used as the destination rank and
 *          ec_h.handle must be non-NULL
 * grp_h  - group handle, cast to the MPIR_Comm the send is posted on
 * tag    - message tag
 * req    - out: on success holds the MPIR_Request and is marked
 *          HCOLRTE_REQUEST_ACTIVE; left untouched on failure
 *
 * Returns mpi_errno (MPI_SUCCESS) on success and HCOLL_ERROR on any failure,
 * including a failure reported by MPIC_Isend(). */
static int send_nb(dte_data_representation_t data,
                   uint32_t count,
                   void *buffer,
                   rte_ec_handle_t ec_h,
                   rte_grp_handle_t grp_h, uint32_t tag, rte_request_handle_t * req)
{
    int mpi_errno = MPI_SUCCESS;
    MPIR_Request *request = NULL;
    MPIR_Comm *comm = (MPIR_Comm *) grp_h;
    MPIR_Errflag_t err = MPIR_ERR_NONE;
    size_t size;

    if (!ec_h.handle) {
        MPIR_ERR_SETANDJUMP2(mpi_errno, MPI_ERR_OTHER, "**hcoll_wrong_arg",
                             "**hcoll_wrong_arg %p %d", ec_h.handle, ec_h.rank);
    }

    MPIR_Assert(HCOL_DTE_IS_INLINE(data));
    if (!buffer && !HCOL_DTE_IS_ZERO(data)) {
        MPIR_ERR_SETANDJUMP(mpi_errno, MPI_ERR_OTHER, "**null_buff_ptr");
    }

    /* The payload is transferred as raw MPI_CHAR bytes; packed_size * count
     * is divided by 8 (presumably packed_size is in bits -- confirm against
     * the hcoll headers).  The cast keeps the product in size_t. */
    size = (size_t) data.rep.in_line_rep.data_handle.in_line.packed_size * count / 8;

    mpi_errno = MPIC_Isend(buffer, size, MPI_CHAR, ec_h.rank, tag, comm, &request, &err);
    if (mpi_errno != MPI_SUCCESS) {
        /* Do not publish a request as active when posting it failed. */
        goto fn_fail;
    }
    MPIR_Assert(request);

    req->data = (void *) request;
    req->status = HCOLRTE_REQUEST_ACTIVE;
  fn_exit:
    return mpi_errno;
  fn_fail:
    return HCOLL_ERROR;
}

#undef FUNCNAME
#define FUNCNAME test
#undef FCNAME
#define FCNAME MPL_QUOTE(FUNCNAME)
/* Report whether the request behind an RTE request handle has completed.
 *
 * A handle whose status is not HCOLRTE_REQUEST_ACTIVE is reported as
 * completed without touching the underlying MPIR_Request.  For an active
 * handle the MPIR_Request is queried; once it is complete it is freed and the
 * handle is moved to HCOLRTE_REQUEST_DONE.
 *
 * *completed receives the result; the function always returns HCOLL_SUCCESS. */
static int test(rte_request_handle_t * request, int *completed)
{
    MPIR_Request *mpi_req = (MPIR_Request *) request->data;

    if (request->status == HCOLRTE_REQUEST_ACTIVE) {
        *completed = (int) MPIR_Request_is_complete(mpi_req);
        if (*completed) {
            MPIR_Request_free(mpi_req);
            request->status = HCOLRTE_REQUEST_DONE;
        }
    } else {
        *completed = true;
    }

    return HCOLL_SUCCESS;
}

#undef FUNCNAME
#define FUNCNAME ec_handle_compare
#undef FCNAME
#define FCNAME MPL_QUOTE(FUNCNAME)
/* Compare two endpoint handles.
 *
 * Only the handle members are compared; the group handles are ignored.
 * Returns 1 when both endpoints carry the same handle pointer, 0 otherwise. */
static int ec_handle_compare(rte_ec_handle_t handle_1,
                             rte_grp_handle_t
                             group_handle_1,
                             rte_ec_handle_t handle_2, rte_grp_handle_t group_handle_2)
{
    (void) group_handle_1;
    (void) group_handle_2;

    if (handle_1.handle == handle_2.handle) {
        return 1;
    }
    return 0;
}

#undef FUNCNAME
#define FUNCNAME get_ec_handles
#undef FCNAME
#define FCNAME MPL_QUOTE(FUNCNAME)
/* Fill in endpoint handles for a list of ranks of a group.
 *
 * num_ec     - number of entries in ec_indexes and ec_handles
 * ec_indexes - ranks within the communicator behind grp_h
 * grp_h      - group handle, cast to MPIR_Comm
 * ec_handles - out: entry i gets rank = ec_indexes[i] and handle set to the
 *              result of MPIDIU_comm_rank_to_av() when MPIDCH4_H_INCLUDED is
 *              defined, or to the communicator's vcr_table entry otherwise
 *
 * Always returns HCOLL_SUCCESS. */
static int get_ec_handles(int num_ec,
                          int *ec_indexes, rte_grp_handle_t grp_h, rte_ec_handle_t * ec_handles)
{
    MPIR_Comm *comm = (MPIR_Comm *) grp_h;
    int idx;

    for (idx = 0; idx < num_ec; ++idx) {
        rte_ec_handle_t *out = &ec_handles[idx];

        out->rank = ec_indexes[idx];
#ifdef MPIDCH4_H_INCLUDED
        out->handle = (void *) (MPIDIU_comm_rank_to_av(comm, ec_indexes[idx]));
#else
        out->handle = (void *) (comm->dev.vcrt->vcr_table[ec_indexes[idx]]);
#endif
    }

    return HCOLL_SUCCESS;
}

#undef FUNCNAME
#define FUNCNAME group_size
#undef FCNAME
#define FCNAME MPL_QUOTE(FUNCNAME)
/* Return MPIR_Comm_size() of the communicator behind the group handle. */
static int group_size(rte_grp_handle_t grp_h)
{
    MPIR_Comm *comm = (MPIR_Comm *) grp_h;

    return MPIR_Comm_size(comm);
}

#undef FUNCNAME
#define FUNCNAME my_rank
#undef FCNAME
#define FCNAME MPL_QUOTE(FUNCNAME)
/* Return MPIR_Comm_rank() of the communicator behind the group handle. */
static int my_rank(rte_grp_handle_t grp_h)
{
    MPIR_Comm *comm = (MPIR_Comm *) grp_h;

    return MPIR_Comm_rank(comm);
}

#undef FUNCNAME
#define FUNCNAME ec_on_local_node
#undef FCNAME
#define FCNAME MPL_QUOTE(FUNCNAME)
/* Tell whether an endpoint shares a node with the calling process.
 *
 * The node id of ec.rank and the node id of the caller's own rank are both
 * obtained through MPID_Get_node_id() on the communicator behind group.
 * Returns non-zero when the two ids are equal, zero otherwise.  The return
 * codes of MPID_Get_node_id() are not inspected. */
static int ec_on_local_node(rte_ec_handle_t ec, rte_grp_handle_t group)
{
    MPIR_Comm *comm = (MPIR_Comm *) group;
    int peer_node;
    int self_node;
    int self_rank;

    MPID_Get_node_id(comm, ec.rank, &peer_node);

    self_rank = MPIR_Comm_rank(comm);
    MPID_Get_node_id(comm, self_rank, &self_node);

    return (peer_node == self_node);
}


#undef FUNCNAME
#define FUNCNAME get_world_group_handle
#undef FCNAME
#define FCNAME MPL_QUOTE(FUNCNAME)
/* Return MPIR_Process.comm_world converted to an RTE group handle. */
static rte_grp_handle_t get_world_group_handle(void)
{
    rte_grp_handle_t world = (rte_grp_handle_t) (MPIR_Process.comm_world);

    return world;
}

#undef FUNCNAME
#define FUNCNAME jobid
#undef FCNAME
#define FCNAME MPL_QUOTE(FUNCNAME)
/* Job id callback.  The value is not used currently, so a constant zero is
 * reported. */
static uint32_t jobid(void)
{
    const uint32_t unused_job_id = 0;

    return unused_job_id;
}

#undef FUNCNAME
#define FUNCNAME group_id
#undef FCNAME
#define FCNAME MPL_QUOTE(FUNCNAME)
/* Return the context_id of the communicator behind the group handle; it
 * serves as the group's identifier. */
static int group_id(rte_grp_handle_t group)
{
    const MPIR_Comm *comm = (const MPIR_Comm *) group;

    return comm->context_id;
}

#undef FUNCNAME
#define FUNCNAME get_coll_handle
#undef FCNAME
#define FCNAME MPL_QUOTE(FUNCNAME)
/* Create a collective handle.
 *
 * The handle is an MPIR_Request of kind MPIR_REQUEST_KIND__COLL on which one
 * additional reference is taken before it is returned as a void pointer. */
static void *get_coll_handle(void)
{
    MPIR_Request *coll_req = MPIR_Request_create(MPIR_REQUEST_KIND__COLL);

    MPIR_Request_add_ref(coll_req);

    return (void *) coll_req;
}

#undef FUNCNAME
#define FUNCNAME coll_handle_test
#undef FCNAME
#define FCNAME MPL_QUOTE(FUNCNAME)
/* Return the result of MPIR_Request_is_complete() for the MPIR_Request behind
 * a collective handle, converted to int.  The handle is not checked for
 * NULL. */
static int coll_handle_test(void *handle)
{
    MPIR_Request *coll_req = (MPIR_Request *) handle;

    return (int) MPIR_Request_is_complete(coll_req);
}

#undef FUNCNAME
#define FUNCNAME coll_handle_free
#undef FCNAME
#define FCNAME MPL_QUOTE(FUNCNAME)
/* Release a collective handle through MPIR_Request_free().  A NULL handle is
 * ignored. */
static void coll_handle_free(void *handle)
{
    if (NULL == handle) {
        return;
    }

    MPIR_Request_free((MPIR_Request *) handle);
}

#undef FUNCNAME
#define FUNCNAME coll_handle_complete
#undef FCNAME
#define FCNAME MPL_QUOTE(FUNCNAME)
/* Mark the MPIR_Request behind a collective handle as complete through
 * MPID_Request_complete().  A NULL handle is ignored. */
static void coll_handle_complete(void *handle)
{
    if (NULL == handle) {
        return;
    }

    MPID_Request_complete((MPIR_Request *) handle);
}

#undef FUNCNAME
#define FUNCNAME world_rank
#undef FCNAME
#define FCNAME MPL_QUOTE(FUNCNAME)
/* Translate an endpoint into a world-level rank.
 *
 * When MPIDCH4_H_INCLUDED is defined the result is
 * MPIDI_CH4U_rank_to_lpid(ec.rank, comm); otherwise ec.handle is treated as a
 * struct MPIDI_VC pointer and its pg_rank is returned. */
static int world_rank(rte_grp_handle_t grp_h, rte_ec_handle_t ec)
{
#ifdef MPIDCH4_H_INCLUDED
    MPIR_Comm *comm = (MPIR_Comm *) grp_h;

    return MPIDI_CH4U_rank_to_lpid(ec.rank, comm);
#else
    struct MPIDI_VC *vc = (struct MPIDI_VC *) ec.handle;

    return vc->pg_rank;
#endif
}

#if HCOLL_API >= HCOLL_VERSION(3,6)
/* Map an MPI_COMBINER_* value to the matching HCOLL_MPI_COMBINER_* value.
 *
 * MPI_COMBINER_HINDEXED_INTEGER shares the HINDEXED mapping and
 * MPI_COMBINER_STRUCT_INTEGER shares the STRUCT mapping.  Any combiner not
 * listed below yields HCOLL_MPI_COMBINER_LAST. */
hcoll_mpi_type_combiner_t mpi_combiner_2_hcoll_combiner(int combiner)
{
    hcoll_mpi_type_combiner_t hcoll_combiner = HCOLL_MPI_COMBINER_LAST;

    switch (combiner) {
        case MPI_COMBINER_CONTIGUOUS:
            hcoll_combiner = HCOLL_MPI_COMBINER_CONTIGUOUS;
            break;
        case MPI_COMBINER_VECTOR:
            hcoll_combiner = HCOLL_MPI_COMBINER_VECTOR;
            break;
        case MPI_COMBINER_HVECTOR:
            hcoll_combiner = HCOLL_MPI_COMBINER_HVECTOR;
            break;
        case MPI_COMBINER_INDEXED:
            hcoll_combiner = HCOLL_MPI_COMBINER_INDEXED;
            break;
        case MPI_COMBINER_HINDEXED_INTEGER:
        case MPI_COMBINER_HINDEXED:
            hcoll_combiner = HCOLL_MPI_COMBINER_HINDEXED;
            break;
        case MPI_COMBINER_DUP:
            hcoll_combiner = HCOLL_MPI_COMBINER_DUP;
            break;
        case MPI_COMBINER_INDEXED_BLOCK:
            hcoll_combiner = HCOLL_MPI_COMBINER_INDEXED_BLOCK;
            break;
        case MPI_COMBINER_HINDEXED_BLOCK:
            hcoll_combiner = HCOLL_MPI_COMBINER_HINDEXED_BLOCK;
            break;
        case MPI_COMBINER_SUBARRAY:
            hcoll_combiner = HCOLL_MPI_COMBINER_SUBARRAY;
            break;
        case MPI_COMBINER_DARRAY:
            hcoll_combiner = HCOLL_MPI_COMBINER_DARRAY;
            break;
        case MPI_COMBINER_F90_REAL:
            hcoll_combiner = HCOLL_MPI_COMBINER_F90_REAL;
            break;
        case MPI_COMBINER_F90_COMPLEX:
            hcoll_combiner = HCOLL_MPI_COMBINER_F90_COMPLEX;
            break;
        case MPI_COMBINER_F90_INTEGER:
            hcoll_combiner = HCOLL_MPI_COMBINER_F90_INTEGER;
            break;
        case MPI_COMBINER_RESIZED:
            hcoll_combiner = HCOLL_MPI_COMBINER_RESIZED;
            break;
        case MPI_COMBINER_STRUCT:
        case MPI_COMBINER_STRUCT_INTEGER:
            hcoll_combiner = HCOLL_MPI_COMBINER_STRUCT;
            break;
        default:
            /* unknown combiner: keep HCOLL_MPI_COMBINER_LAST */
            break;
    }

    return hcoll_combiner;
}

/* Report the envelope of an MPI datatype to hcoll.
 *
 * mpi_type       - datatype handle carried in a pointer-sized value
 * num_integers   - out: filled in by MPIR_Type_get_envelope()
 * num_addresses  - out: filled in by MPIR_Type_get_envelope()
 * num_datatypes  - out: filled in by MPIR_Type_get_envelope()
 * combiner       - out: MPI combiner translated with
 *                  mpi_combiner_2_hcoll_combiner()
 *
 * Always returns HCOLL_SUCCESS. */
static int get_mpi_type_envelope(void *mpi_type, int *num_integers,
                                 int *num_addresses, int *num_datatypes,
                                 hcoll_mpi_type_combiner_t * combiner)
{
    int mpi_combiner;
    /* Convert through intptr_t so the pointer-to-handle conversion is
     * explicit rather than a direct cast between a pointer and a possibly
     * narrower integer type. */
    MPI_Datatype dt_handle = (MPI_Datatype) (intptr_t) mpi_type;

    MPIR_Type_get_envelope(dt_handle, num_integers, num_addresses, num_datatypes, &mpi_combiner);

    *combiner = mpi_combiner_2_hcoll_combiner(mpi_combiner);

    return HCOLL_SUCCESS;
}

static int get_mpi_type_contents(void *mpi_type, int max_integers, int max_addresses,
                                 int max_datatypes, int *array_of_integers,
                                 void *array_of_addresses, void *array_of_datatypes)
{
    int ret;
    MPI_Datatype dt_handle = (MPI_Datatype) mpi_type;

    ret = MPIR_Type_get_contents(dt_handle,
                                 max_integers, max_addresses, max_datatypes,
                                 array_of_integers,
                                 (MPI_Aint *) array_of_addresses,
                                 (MPI_Datatype *) array_of_datatypes);

    return ret == MPI_SUCCESS ? HCOLL_SUCCESS : HCOLL_ERROR;
}

/* Look up the hcoll representation of an MPI datatype.
 *
 * mpi_type   - datatype handle carried in a pointer-sized value
 * hcoll_type - out: result of mpi_dtype_2_hcoll_dtype() called with a count
 *              of -1 and mode TRY_FIND_DERIVED
 *
 * Returns HCOLL_ERR_NOT_FOUND when the result is the zero type, and
 * HCOLL_SUCCESS otherwise. */
static int get_hcoll_type(void *mpi_type, dte_data_representation_t * hcoll_type)
{
    /* Convert through intptr_t so the pointer-to-handle conversion is
     * explicit rather than a direct cast between a pointer and a possibly
     * narrower integer type. */
    MPI_Datatype dt_handle = (MPI_Datatype) (intptr_t) mpi_type;

    *hcoll_type = mpi_dtype_2_hcoll_dtype(dt_handle, -1, TRY_FIND_DERIVED);

    return HCOL_DTE_IS_ZERO((*hcoll_type)) ? HCOLL_ERR_NOT_FOUND : HCOLL_SUCCESS;
}

/* Callback for recording an hcoll representation on an MPI datatype.
 * Nothing is stored here: both arguments are ignored and HCOLL_SUCCESS is
 * returned unconditionally. */
static int set_hcoll_type(void *mpi_type, dte_data_representation_t hcoll_type)
{
    (void) mpi_type;
    (void) hcoll_type;

    return HCOLL_SUCCESS;
}

/* Export the MPI constants hcoll asks for.
 *
 * Every output parameter receives the like-named MPI value:
 * sizeof(MPI_Datatype), the two array-order constants and the four
 * distribution constants.  Always returns HCOLL_SUCCESS. */
static int get_mpi_constants(size_t * mpi_datatype_size,
                             int *mpi_order_c, int *mpi_order_fortran,
                             int *mpi_distribute_block,
                             int *mpi_distribute_cyclic,
                             int *mpi_distribute_none, int *mpi_distribute_dflt_darg)
{
    /* size of a datatype handle */
    *mpi_datatype_size = sizeof(MPI_Datatype);

    /* array storage orders */
    *mpi_order_c = MPI_ORDER_C;
    *mpi_order_fortran = MPI_ORDER_FORTRAN;

    /* distribution kinds and default distribution argument */
    *mpi_distribute_block = MPI_DISTRIBUTE_BLOCK;
    *mpi_distribute_cyclic = MPI_DISTRIBUTE_CYCLIC;
    *mpi_distribute_none = MPI_DISTRIBUTE_NONE;
    *mpi_distribute_dflt_darg = MPI_DISTRIBUTE_DFLT_DARG;

    return HCOLL_SUCCESS;
}

#endif