Blob Blame History Raw
/* -*- Mode: C; c-basic-offset:4 ; indent-tabs-mode:nil ; -*- */
/*
 *  (C) 2017 by Argonne National Laboratory.
 *      See COPYRIGHT in top-level directory.
 */

#include "mpiimpl.h"

/* This is the machine-independent implementation of igather. The algorithm is:

   Algorithm: MPI_Igather

   We use a binomial tree algorithm for both short and long
   messages. At nodes other than leaf nodes we need to allocate a
   temporary buffer to store the incoming message. If the root is not
   rank 0, for very small messages, we pack it into a temporary
   contiguous buffer and reorder it to be placed in the right
   order. For small (but not very small) messages, we use a derived
   datatype to unpack the incoming data into non-contiguous buffers in
   the right order.

   Cost = lgp.alpha + n.((p-1)/p).beta
   where n is the total size of the data gathered at the root.

   Possible improvements:

   End Algorithm: MPI_Igather
*/
#undef FUNCNAME
#define FUNCNAME MPIR_Igather_sched_intra_binomial
#undef FCNAME
#define FCNAME MPL_QUOTE(FUNCNAME)
/*
 * Append the operations of a binomial-tree gather to the schedule 's'.
 *
 * Ranks are renumbered relative to 'root' (the root becomes relative rank 0).
 * Every rank receives from the children in its sub-tree and then forwards its
 * own block plus everything it collected to its parent.
 *
 *   sendbuf, sendcount, sendtype - this rank's contribution; at the root
 *                                  sendbuf may be MPI_IN_PLACE, in which case
 *                                  no local copy is performed
 *   recvbuf, recvcount, recvtype - destination, only accessed at the root
 *   root                         - rank that collects the data
 *   comm_ptr                     - communicator; asserted to be an
 *                                  intracommunicator
 *   s                            - schedule the operations are appended to
 *
 * Returns MPI_SUCCESS or the error code of the first failing call.  On
 * failure the memory registered for the schedule in this function is reaped
 * before returning.
 */
int MPIR_Igather_sched_intra_binomial(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                                      void *recvbuf, int recvcount, MPI_Datatype recvtype, int root,
                                      MPIR_Comm * comm_ptr, MPIR_Sched_t s)
{
    int mpi_errno = MPI_SUCCESS;
    int comm_size, rank;
    int relative_rank;
    int mask, src, dst, relative_src;
    MPI_Aint recvtype_size, sendtype_size, curr_cnt = 0, nbytes;
    int recvblks;
    int tmp_buf_size, missing;
    void *tmp_buf = NULL;
    int blocks[2];
    int displs[2];
    MPI_Aint struct_displs[2];
    MPI_Aint extent = 0;
    int copy_offset = 0, copy_blks = 0;
    MPI_Datatype types[2], tmp_type;
    MPIR_SCHED_CHKPMEM_DECL(1);

    comm_size = comm_ptr->local_size;
    rank = comm_ptr->rank;

    /* Nothing to schedule when this rank's relevant count is zero. */
    if (((rank == root) && (recvcount == 0)) || ((rank != root) && (sendcount == 0)))
        goto fn_exit;

    MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM);

    /* Use binomial tree algorithm. */

    relative_rank = (rank >= root) ? rank - root : rank - root + comm_size;

    if (rank == root) {
        MPIR_Datatype_get_extent_macro(recvtype, extent);
        MPIR_Ensure_Aint_fits_in_pointer(MPIR_VOID_PTR_CAST_TO_MPI_AINT recvbuf +
                                         (extent * recvcount * comm_size));
    }

    /* nbytes is the size in bytes of one rank's contribution, computed from
     * the receive signature at the root and the send signature elsewhere. */
    if (rank == root) {
        MPIR_Datatype_get_size_macro(recvtype, recvtype_size);
        nbytes = recvtype_size * recvcount;
    } else {
        MPIR_Datatype_get_size_macro(sendtype, sendtype_size);
        nbytes = sendtype_size * sendcount;
    }

    /* Find the number of missing nodes in my sub-tree compared to
     * a balanced tree */
    for (mask = 1; mask < comm_size; mask <<= 1);
    --mask;
    while (relative_rank & mask)
        mask >>= 1;
    missing = (relative_rank | mask) - comm_size + 1;
    if (missing < 0)
        missing = 0;
    tmp_buf_size = (mask - missing);

    /* If the message is smaller than the threshold, we will copy
     * our message in there too */
    if (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE)
        tmp_buf_size++;

    /* NOTE(review): tmp_buf_size is an int while nbytes is an MPI_Aint, so
     * this product can exceed the int range for large messages - confirm
     * whether callers bound the message size. */
    tmp_buf_size *= nbytes;

    /* The root needs no temporary buffer when it is rank 0 or when the
     * message is not "very small": in both cases it receives straight into
     * recvbuf (contiguously or through an indexed datatype). */
    if ((rank == root) && (!root || (nbytes >= MPIR_CVAR_GATHER_VSMALL_MSG_SIZE)))
        tmp_buf_size = 0;

    if (tmp_buf_size) {
        MPIR_SCHED_CHKPMEM_MALLOC(tmp_buf, void *, tmp_buf_size, mpi_errno, "tmp_buf",
                                  MPL_MEM_BUFFER);
    }

    if (rank == root) {
        if (sendbuf != MPI_IN_PLACE) {
            /* place the root's own contribution into its slot of recvbuf */
            mpi_errno = MPIR_Localcopy(sendbuf, sendcount, sendtype,
                                       ((char *) recvbuf + extent * recvcount * rank),
                                       recvcount, recvtype);
            if (mpi_errno)
                MPIR_ERR_POP(mpi_errno);
        }
    } else if (tmp_buf_size && (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE)) {
        /* copy from sendbuf into tmp_buf */
        mpi_errno = MPIR_Localcopy(sendbuf, sendcount, sendtype, tmp_buf, nbytes, MPI_BYTE);
        if (mpi_errno)
            MPIR_ERR_POP(mpi_errno);
    }
    /* curr_cnt tracks the number of bytes this rank will forward: its own
     * block plus whatever it receives from its children. */
    curr_cnt = nbytes;

    mask = 0x1;
    while (mask < comm_size) {
        if ((mask & relative_rank) == 0) {
            /* receive phase: the child at this level is relative_rank | mask */
            src = relative_rank | mask;
            if (src < comm_size) {
                src = (src + root) % comm_size;

                if (rank == root) {
                    recvblks = mask;
                    if ((2 * recvblks) > comm_size)
                        recvblks = comm_size - recvblks;

                    if ((rank + mask + recvblks == comm_size) ||
                        (((rank + mask) % comm_size) < ((rank + mask + recvblks) % comm_size))) {
                        /* If the data contiguously fits into the
                         * receive buffer, place it directly. This
                         * should cover the case where the root is
                         * rank 0. */
                        char *rp =
                            (char *) recvbuf + (((rank + mask) % comm_size) * recvcount * extent);
                        mpi_errno =
                            MPIR_Sched_recv(rp, (recvblks * recvcount), recvtype, src, comm_ptr, s);
                        if (mpi_errno)
                            MPIR_ERR_POP(mpi_errno);
                        mpi_errno = MPIR_Sched_barrier(s);
                        if (mpi_errno)
                            MPIR_ERR_POP(mpi_errno);
                    } else if (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE) {
                        /* The incoming blocks wrap around the end of recvbuf:
                         * stage them in tmp_buf and remember where they go so
                         * they can be reordered after the loop. */
                        mpi_errno =
                            MPIR_Sched_recv(tmp_buf, (recvblks * nbytes), MPI_BYTE, src,
                                            comm_ptr, s);
                        if (mpi_errno)
                            MPIR_ERR_POP(mpi_errno);
                        mpi_errno = MPIR_Sched_barrier(s);
                        if (mpi_errno)
                            MPIR_ERR_POP(mpi_errno);
                        copy_offset = rank + mask;
                        copy_blks = recvblks;
                    } else {
                        /* Wrapping blocks, larger message: receive directly
                         * into recvbuf through a two-piece indexed type (tail
                         * of the buffer first, then its beginning). */
                        blocks[0] = recvcount * (comm_size - root - mask);
                        displs[0] = recvcount * (root + mask);
                        blocks[1] = (recvcount * recvblks) - blocks[0];
                        displs[1] = 0;

                        mpi_errno = MPIR_Type_indexed_impl(2, blocks, displs, recvtype, &tmp_type);
                        if (mpi_errno)
                            MPIR_ERR_POP(mpi_errno);
                        mpi_errno = MPIR_Type_commit_impl(&tmp_type);
                        if (mpi_errno)
                            MPIR_ERR_POP(mpi_errno);

                        mpi_errno = MPIR_Sched_recv(recvbuf, 1, tmp_type, src, comm_ptr, s);
                        if (mpi_errno)
                            MPIR_ERR_POP(mpi_errno);
                        mpi_errno = MPIR_Sched_barrier(s);
                        if (mpi_errno)
                            MPIR_ERR_POP(mpi_errno);

                        /* this "premature" free is safe b/c the sched holds an actual ref to keep it alive */
                        MPIR_Type_free_impl(&tmp_type);
                    }
                } else {        /* Intermediate nodes store in temporary buffer */
                    MPI_Aint offset;

                    /* Estimate the amount of data that is going to come in */
                    recvblks = mask;
                    relative_src = ((src - root) < 0) ? (src - root + comm_size) : (src - root);
                    if (relative_src + mask > comm_size)
                        recvblks -= (relative_src + mask - comm_size);

                    /* For very small messages tmp_buf also holds this rank's
                     * own block in slot 0, so children start one slot later. */
                    if (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE)
                        offset = mask * nbytes;
                    else
                        offset = (mask - 1) * nbytes;
                    mpi_errno =
                        MPIR_Sched_recv(((char *) tmp_buf + offset), (recvblks * nbytes),
                                        MPI_BYTE, src, comm_ptr, s);
                    if (mpi_errno)
                        MPIR_ERR_POP(mpi_errno);
                    mpi_errno = MPIR_Sched_barrier(s);
                    if (mpi_errno)
                        MPIR_ERR_POP(mpi_errno);
                    curr_cnt += (recvblks * nbytes);
                }
            }
        } else {
            /* send phase: forward everything collected so far to the parent */
            dst = relative_rank ^ mask;
            dst = (dst + root) % comm_size;

            if (!tmp_buf_size) {
                /* leaf nodes send directly from sendbuf */
                mpi_errno = MPIR_Sched_send(sendbuf, sendcount, sendtype, dst, comm_ptr, s);
                if (mpi_errno)
                    MPIR_ERR_POP(mpi_errno);
                mpi_errno = MPIR_Sched_barrier(s);
                if (mpi_errno)
                    MPIR_ERR_POP(mpi_errno);
            } else if (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE) {
                /* own block and children's blocks are contiguous in tmp_buf */
                mpi_errno = MPIR_Sched_send(tmp_buf, curr_cnt, MPI_BYTE, dst, comm_ptr, s);
                if (mpi_errno)
                    MPIR_ERR_POP(mpi_errno);
                mpi_errno = MPIR_Sched_barrier(s);
                if (mpi_errno)
                    MPIR_ERR_POP(mpi_errno);
            } else {
                /* Send sendbuf followed by the children's data in tmp_buf as
                 * one message, using a struct type with absolute addresses. */
                blocks[0] = sendcount;
                struct_displs[0] = MPIR_VOID_PTR_CAST_TO_MPI_AINT sendbuf;
                types[0] = sendtype;
                /* check for overflow.  work around int limits if needed */
                if (curr_cnt - nbytes != (int) (curr_cnt - nbytes)) {
                    blocks[1] = 1;
                    mpi_errno =
                        MPIR_Type_contiguous_x_impl(curr_cnt - nbytes, MPI_BYTE, &(types[1]));
                    if (mpi_errno)
                        MPIR_ERR_POP(mpi_errno);
                } else {
                    MPIR_Assign_trunc(blocks[1], curr_cnt - nbytes, int);
                    types[1] = MPI_BYTE;
                }
                struct_displs[1] = MPIR_VOID_PTR_CAST_TO_MPI_AINT tmp_buf;

                mpi_errno =
                    MPIR_Type_create_struct_impl(2, blocks, struct_displs, types, &tmp_type);
                /* Release the helper type built for the large-count case so it
                 * is not leaked.  Per MPI datatype semantics, freeing a type
                 * does not affect types that were constructed from it, so this
                 * is done before the error check to cover the failure path
                 * as well. */
                if (types[1] != MPI_BYTE)
                    MPIR_Type_free_impl(&types[1]);
                if (mpi_errno)
                    MPIR_ERR_POP(mpi_errno);
                mpi_errno = MPIR_Type_commit_impl(&tmp_type);
                if (mpi_errno)
                    MPIR_ERR_POP(mpi_errno);

                mpi_errno = MPIR_Sched_send(MPI_BOTTOM, 1, tmp_type, dst, comm_ptr, s);
                if (mpi_errno)
                    MPIR_ERR_POP(mpi_errno);
                mpi_errno = MPIR_Sched_barrier(s);
                if (mpi_errno)
                    MPIR_ERR_POP(mpi_errno);

                /* this "premature" free is safe b/c the sched holds an actual ref to keep it alive */
                MPIR_Type_free_impl(&tmp_type);
            }

            /* a rank sends exactly once, to its parent, and is then done */
            break;
        }
        mask <<= 1;
    }

    if ((rank == root) && root && (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE) && copy_blks) {
        /* reorder and copy from tmp_buf into recvbuf */
        /* The staged blocks wrap around the end of recvbuf, hence two copies:
         * the first (comm_size - copy_offset) blocks go to the tail of
         * recvbuf starting at block copy_offset ... */
        mpi_errno = MPIR_Sched_copy(tmp_buf, nbytes * (comm_size - copy_offset), MPI_BYTE,
                                    ((char *) recvbuf + extent * recvcount * copy_offset),
                                    recvcount * (comm_size - copy_offset), recvtype, s);
        if (mpi_errno)
            MPIR_ERR_POP(mpi_errno);
        /* ... and the remaining blocks go to the beginning of recvbuf. */
        mpi_errno = MPIR_Sched_copy((char *) tmp_buf + nbytes * (comm_size - copy_offset),
                                    nbytes * (copy_blks - comm_size + copy_offset), MPI_BYTE,
                                    recvbuf, recvcount * (copy_blks - comm_size + copy_offset),
                                    recvtype, s);
        if (mpi_errno)
            MPIR_ERR_POP(mpi_errno);
    }

    MPIR_SCHED_CHKPMEM_COMMIT(s);
  fn_exit:
    return mpi_errno;
  fn_fail:
    MPIR_SCHED_CHKPMEM_REAP(s);
    goto fn_exit;
}