/* -*- Mode: C; c-basic-offset:4 ; indent-tabs-mode:nil ; -*- */
/*
 *  (C) 2010 by Argonne National Laboratory.
 *      See COPYRIGHT in top-level directory.
 */

#include "mpiimpl.h"

/* helper callbacks and associated state structures */
struct shared_state {
    int sendcount;               /* the root's sendcount argument (sendtype elements per block) */
    MPI_Aint curr_count;         /* data held for this process's subtree: sendtype elements on a
                                  * rank-0 root, bytes everywhere else */
    MPI_Aint send_subtree_count; /* portion of curr_count handed off to the current destination
                                  * (same units as curr_count) */
    int nbytes;                  /* size in bytes of a single process's block of data */
    MPI_Status status;           /* status of the deferred receive into tmp_buf; get_count
                                  * extracts the received byte count from it */
};
static int get_count(MPIR_Comm * comm, int tag, void *state)
{
    struct shared_state *ss = state;
    MPIR_Get_count_impl(&ss->status, MPI_BYTE, &ss->curr_count);
    return MPI_SUCCESS;
}

static int calc_send_count_root(MPIR_Comm * comm, int tag, void *state, void *state2)
{
    struct shared_state *ss = state;
    int mask = (int) (size_t) state2;
    ss->send_subtree_count = ss->curr_count - ss->sendcount * mask;
    return MPI_SUCCESS;
}

static int calc_send_count_non_root(MPIR_Comm * comm, int tag, void *state, void *state2)
{
    struct shared_state *ss = state;
    int mask = (int) (size_t) state2;
    ss->send_subtree_count = ss->curr_count - ss->nbytes * mask;
    return MPI_SUCCESS;
}

static int calc_curr_count(MPIR_Comm * comm, int tag, void *state)
{
    struct shared_state *ss = state;
    ss->curr_count -= ss->send_subtree_count;
    return MPI_SUCCESS;
}
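
/* On interior (non-leaf, non-root) nodes the number of bytes that will
 * arrive is not known when the schedule is built, so the per-step counts
 * are computed at schedule-execution time by chaining the callbacks above
 * between the scheduled operations, roughly as follows:
 *
 *   MPIR_Sched_recv_status()            receive into tmp_buf, record the status
 *   get_count()                         curr_count = bytes actually received
 *   calc_send_count_{root,non_root}()   send_subtree_count = curr_count - <block> * mask
 *   MPIR_Sched_send_defer()             reads send_subtree_count only when the send is issued
 *   calc_curr_count()                   curr_count -= send_subtree_count
 */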

/* any non-MPI functions go here, especially non-static ones */

/* This is the machine-independent implementation of scatter, used here to
   build the schedule for the nonblocking variant. The algorithm is:

   Algorithm: MPI_Scatter

   We use a binomial tree algorithm for both short and
   long messages. At nodes other than leaf nodes we need to allocate
   a temporary buffer to store the incoming message. If the root is
   not rank 0, we reorder the sendbuf in order of relative ranks by
   copying it into a temporary buffer, so that all the sends from the
   root are contiguous and in the right order.

   Cost = lgp.alpha + n.((p-1)/p).beta
   where n is the total size of the data to be scattered from the root.
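
   For example, with p = 8 and root = 0 the schedule produces the following
   pattern (numbers are relative ranks; each message carries the data for
   the destination's entire subtree):

     step 1:  0 -> 4                             (4 blocks, n/2 bytes)
     step 2:  0 -> 2,  4 -> 6                    (2 blocks each)
     step 3:  0 -> 1,  2 -> 3,  4 -> 5,  6 -> 7  (1 block each)

   i.e. lg(8) = 3 steps with (7/8).n bytes sent by the root, matching the
   cost above.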

   Possible improvements:

   End Algorithm: MPI_Scatter
*/
#undef FUNCNAME
#define FUNCNAME MPIR_Iscatter_sched_intra_binomial
#undef FCNAME
#define FCNAME MPL_QUOTE(FUNCNAME)
int MPIR_Iscatter_sched_intra_binomial(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                                       void *recvbuf, int recvcount, MPI_Datatype recvtype,
                                       int root, MPIR_Comm * comm_ptr, MPIR_Sched_t s)
{
    int mpi_errno = MPI_SUCCESS;
    MPI_Aint extent = 0;
    int rank, comm_size, sendtype_size;
    int relative_rank;
    int mask, recvtype_size = 0, src, dst;
    int tmp_buf_size = 0;
    void *tmp_buf = NULL;
    struct shared_state *ss = NULL;
    MPIR_SCHED_CHKPMEM_DECL(4);

    comm_size = comm_ptr->local_size;
    rank = comm_ptr->rank;

    if (((rank == root) && (sendcount == 0)) || ((rank != root) && (recvcount == 0)))
        goto fn_exit;

    MPIR_SCHED_CHKPMEM_MALLOC(ss, struct shared_state *, sizeof(struct shared_state), mpi_errno,
                              "shared_state", MPL_MEM_BUFFER);
    ss->sendcount = sendcount;

    if (rank == root)
        MPIR_Datatype_get_extent_macro(sendtype, extent);

    relative_rank = (rank >= root) ? rank - root : rank - root + comm_size;

    if (rank == root) {
        /* We separate the two cases (root and non-root) because
         * in the event of recvbuf=MPI_IN_PLACE on the root,
         * recvcount and recvtype are not valid */
        MPIR_Datatype_get_size_macro(sendtype, sendtype_size);
        MPIR_Ensure_Aint_fits_in_pointer(MPIR_VOID_PTR_CAST_TO_MPI_AINT sendbuf +
                                         extent * sendcount * comm_size);

        ss->nbytes = sendtype_size * sendcount;
    } else {
        MPIR_Datatype_get_size_macro(recvtype, recvtype_size);
        MPIR_Ensure_Aint_fits_in_pointer(extent * recvcount * comm_size);
        ss->nbytes = recvtype_size * recvcount;
    }

    ss->curr_count = 0;

    /* every process with a nonzero even relative rank is an interior node
     * of the tree and needs a temporary buffer large enough for its whole
     * subtree, which never spans more than comm_size/2 processes, i.e. at
     * most (ss->nbytes*comm_size)/2 bytes */
    if (relative_rank && !(relative_rank % 2)) {
        tmp_buf_size = (ss->nbytes * comm_size) / 2;
        MPIR_SCHED_CHKPMEM_MALLOC(tmp_buf, void *, tmp_buf_size, mpi_errno, "tmp_buf",
                                  MPL_MEM_BUFFER);
    }

    /* if the root is not rank 0, we reorder the sendbuf in order of
     * relative ranks and copy it into a temporary buffer, so that
     * all the sends from the root are contiguous and in the right
     * order. */
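    /* e.g. with root = 2 and comm_size = 5, tmp_buf ends up holding the
     * blocks for ranks 2, 3, 4, 0, 1 in that order, i.e. ordered by
     * relative rank (with MPI_IN_PLACE the root's own block is not copied) */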
    if (rank == root) {
        if (root != 0) {
            tmp_buf_size = ss->nbytes * comm_size;
            MPIR_SCHED_CHKPMEM_MALLOC(tmp_buf, void *, tmp_buf_size, mpi_errno, "tmp_buf",
                                      MPL_MEM_BUFFER);

            if (recvbuf != MPI_IN_PLACE)
                mpi_errno = MPIR_Sched_copy(((char *) sendbuf + extent * sendcount * rank),
                                            sendcount * (comm_size - rank), sendtype,
                                            tmp_buf, ss->nbytes * (comm_size - rank), MPI_BYTE, s);
            else
                mpi_errno =
                    MPIR_Sched_copy(((char *) sendbuf + extent * sendcount * (rank + 1)),
                                    sendcount * (comm_size - rank - 1), sendtype,
                                    ((char *) tmp_buf + ss->nbytes),
                                    ss->nbytes * (comm_size - rank - 1), MPI_BYTE, s);
            if (mpi_errno)
                MPIR_ERR_POP(mpi_errno);

            mpi_errno = MPIR_Sched_copy(sendbuf, sendcount * rank, sendtype,
                                        ((char *) tmp_buf + ss->nbytes * (comm_size - rank)),
                                        ss->nbytes * rank, MPI_BYTE, s);
            if (mpi_errno)
                MPIR_ERR_POP(mpi_errno);

            MPIR_SCHED_BARRIER(s);
            ss->curr_count = ss->nbytes * comm_size;
        } else
            ss->curr_count = sendcount * comm_size;
    }

    /* root has all the data; others have zero so far */
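
    /* Receive phase: each non-root process obtains the data for its entire
     * subtree from the process whose relative rank equals its own with the
     * lowest set bit cleared (e.g. relative rank 6 receives from relative
     * rank 4, and relative rank 4 receives from the root). */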

    mask = 0x1;
    while (mask < comm_size) {
        if (relative_rank & mask) {
            src = rank - mask;
            if (src < 0)
                src += comm_size;

            /* The leaf nodes receive directly into recvbuf because
             * they don't have to forward data to anyone. Others
             * receive data into a temporary buffer. */
            if (relative_rank % 2) {
                mpi_errno = MPIR_Sched_recv(recvbuf, recvcount, recvtype, src, comm_ptr, s);
                if (mpi_errno)
                    MPIR_ERR_POP(mpi_errno);
                MPIR_SCHED_BARRIER(s);
            } else {

                /* the posted receive may be larger than what the sender
                 * actually sends, so query the number of bytes actually
                 * received once the receive completes */
                mpi_errno =
                    MPIR_Sched_recv_status(tmp_buf, tmp_buf_size, MPI_BYTE, src, comm_ptr,
                                           &ss->status, s);
                if (mpi_errno)
                    MPIR_ERR_POP(mpi_errno);
                MPIR_SCHED_BARRIER(s);
                mpi_errno = MPIR_Sched_cb(&get_count, ss, s);
                if (mpi_errno)
                    MPIR_ERR_POP(mpi_errno);
                MPIR_SCHED_BARRIER(s);
            }
            break;
        }
        mask <<= 1;
    }

    /* This process is responsible for all processes that have bits
     * set from the LSB upto (but not including) mask.  Because of
     * the "not including", we start by shifting mask back down
     * one. */
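
    /* For example, with comm_size = 8, relative rank 4 leaves the loop above
     * with mask == 4; it forwards the blocks for relative ranks 6-7 to
     * relative rank 6, then the block for relative rank 5 to relative rank 5,
     * keeping only its own block. */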

    mask >>= 1;
    while (mask > 0) {
        if (relative_rank + mask < comm_size) {
            dst = rank + mask;
            if (dst >= comm_size)
                dst -= comm_size;

            if ((rank == root) && (root == 0)) {
#if 0
                /* FIXME how can this be right? shouldn't (sendcount*mask)
                 * be the amount sent and curr_cnt be reduced by that?  Or
                 * is it always true the (curr_cnt/2==sendcount*mask)? */
                send_subtree_cnt = curr_cnt - sendcount * mask;
#endif
                mpi_errno = MPIR_Sched_cb2(&calc_send_count_root, ss, ((void *) (size_t) mask), s);
                if (mpi_errno)
                    MPIR_ERR_POP(mpi_errno);
                MPIR_SCHED_BARRIER(s);

                /* mask is also the size of this process's subtree */
                mpi_errno =
                    MPIR_Sched_send_defer(((char *) sendbuf + extent * sendcount * mask),
                                          &ss->send_subtree_count, sendtype, dst, comm_ptr, s);
                if (mpi_errno)
                    MPIR_ERR_POP(mpi_errno);
                MPIR_SCHED_BARRIER(s);
            } else {
                /* non-zero root and others */
                mpi_errno =
                    MPIR_Sched_cb2(&calc_send_count_non_root, ss, ((void *) (size_t) mask), s);
                if (mpi_errno)
                    MPIR_ERR_POP(mpi_errno);
                MPIR_SCHED_BARRIER(s);

                /* mask is also the size of this process's subtree */
                mpi_errno = MPIR_Sched_send_defer(((char *) tmp_buf + ss->nbytes * mask),
                                                  &ss->send_subtree_count, MPI_BYTE, dst,
                                                  comm_ptr, s);
                if (mpi_errno)
                    MPIR_ERR_POP(mpi_errno);
                MPIR_SCHED_BARRIER(s);
            }
            mpi_errno = MPIR_Sched_cb(&calc_curr_count, ss, s);
            if (mpi_errno)
                MPIR_ERR_POP(mpi_errno);
            MPIR_SCHED_BARRIER(s);
        }
        mask >>= 1;
    }

    if ((rank == root) && (root == 0) && (recvbuf != MPI_IN_PLACE)) {
        /* for root=0, put root's data in recvbuf if not MPI_IN_PLACE */
        mpi_errno = MPIR_Sched_copy(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, s);
        if (mpi_errno)
            MPIR_ERR_POP(mpi_errno);
        MPIR_SCHED_BARRIER(s);
    } else if (!(relative_rank % 2) && (recvbuf != MPI_IN_PLACE)) {
        /* for a non-zero root and for the other interior (even relative
         * rank) processes, copy this process's own block from tmp_buf
         * into recvbuf */
        mpi_errno = MPIR_Sched_copy(tmp_buf, ss->nbytes, MPI_BYTE, recvbuf, recvcount, recvtype, s);
        if (mpi_errno)
            MPIR_ERR_POP(mpi_errno);
        MPIR_SCHED_BARRIER(s);
    }

    MPIR_SCHED_CHKPMEM_COMMIT(s);
  fn_exit:
    return mpi_errno;
  fn_fail:
    MPIR_SCHED_CHKPMEM_REAP(s);
    goto fn_exit;
}