Blame src/mpi/coll/igather/igather_intra_binomial.c

Packit Service c5cf8c
/* -*- Mode: C; c-basic-offset:4 ; indent-tabs-mode:nil ; -*- */
Packit Service c5cf8c
/*
Packit Service c5cf8c
 *  (C) 2017 by Argonne National Laboratory.
Packit Service c5cf8c
 *      See COPYRIGHT in top-level directory.
Packit Service c5cf8c
 */
Packit Service c5cf8c
Packit Service c5cf8c
#include "mpiimpl.h"
Packit Service c5cf8c
Packit Service c5cf8c
/* This is the machine-independent implementation of igather. The algorithm is:

   Algorithm: MPI_Igather

   We use a binomial tree algorithm for both short and long
   messages. At nodes other than leaf nodes we need to allocate a
   temporary buffer to store the incoming messages (for very small
   messages, every non-root rank, including leaves, also stages its own
   contribution in that buffer). If the root is not rank 0, for very
   small messages, we pack the data into a temporary contiguous buffer
   and reorder it so that it is placed in the right order. For messages
   at or above the very-small threshold (MPIR_CVAR_GATHER_VSMALL_MSG_SIZE),
   we use a derived datatype to unpack the incoming data into
   non-contiguous buffers in the right order.

   Cost = lg(p).alpha + n.((p-1)/p).beta
   where p is the communicator size and n is the total size of the data
   gathered at the root.

   Possible improvements: (none listed)

   End Algorithm: MPI_Igather
*/
Packit Service c5cf8c
#undef FUNCNAME
Packit Service c5cf8c
#define FUNCNAME MPIR_Igather_sched_intra_binomial
Packit Service c5cf8c
#undef FCNAME
Packit Service c5cf8c
#define FCNAME MPL_QUOTE(FUNCNAME)
Packit Service c5cf8c
int MPIR_Igather_sched_intra_binomial(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
Packit Service c5cf8c
                                      void *recvbuf, int recvcount, MPI_Datatype recvtype, int root,
Packit Service c5cf8c
                                      MPIR_Comm * comm_ptr, MPIR_Sched_t s)
Packit Service c5cf8c
{
Packit Service c5cf8c
    int mpi_errno = MPI_SUCCESS;
Packit Service c5cf8c
    int comm_size, rank;
Packit Service c5cf8c
    int relative_rank;
Packit Service c5cf8c
    int mask, src, dst, relative_src;
Packit Service c5cf8c
    MPI_Aint recvtype_size, sendtype_size, curr_cnt = 0, nbytes;
Packit Service c5cf8c
    int recvblks;
Packit Service c5cf8c
    int tmp_buf_size, missing;
Packit Service c5cf8c
    void *tmp_buf = NULL;
Packit Service c5cf8c
    int blocks[2];
Packit Service c5cf8c
    int displs[2];
Packit Service c5cf8c
    MPI_Aint struct_displs[2];
Packit Service c5cf8c
    MPI_Aint extent = 0;
Packit Service c5cf8c
    int copy_offset = 0, copy_blks = 0;
Packit Service c5cf8c
    MPI_Datatype types[2], tmp_type;
Packit Service c5cf8c
    MPIR_SCHED_CHKPMEM_DECL(1);
Packit Service c5cf8c
Packit Service c5cf8c
    comm_size = comm_ptr->local_size;
Packit Service c5cf8c
    rank = comm_ptr->rank;
Packit Service c5cf8c
Packit Service c5cf8c
    if (((rank == root) && (recvcount == 0)) || ((rank != root) && (sendcount == 0)))
Packit Service c5cf8c
        goto fn_exit;
Packit Service c5cf8c
Packit Service c5cf8c
    MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM);
Packit Service c5cf8c
Packit Service c5cf8c
    /* Use binomial tree algorithm. */
Packit Service c5cf8c
Packit Service c5cf8c
    relative_rank = (rank >= root) ? rank - root : rank - root + comm_size;
Packit Service c5cf8c
Packit Service c5cf8c
    if (rank == root) {
Packit Service c5cf8c
        MPIR_Datatype_get_extent_macro(recvtype, extent);
Packit Service c5cf8c
        MPIR_Ensure_Aint_fits_in_pointer(MPIR_VOID_PTR_CAST_TO_MPI_AINT recvbuf +
Packit Service c5cf8c
                                         (extent * recvcount * comm_size));
Packit Service c5cf8c
    }
Packit Service c5cf8c
Packit Service c5cf8c
    if (rank == root) {
Packit Service c5cf8c
        MPIR_Datatype_get_size_macro(recvtype, recvtype_size);
Packit Service c5cf8c
        nbytes = recvtype_size * recvcount;
Packit Service c5cf8c
    } else {
Packit Service c5cf8c
        MPIR_Datatype_get_size_macro(sendtype, sendtype_size);
Packit Service c5cf8c
        nbytes = sendtype_size * sendcount;
Packit Service c5cf8c
    }
Packit Service c5cf8c
Packit Service c5cf8c
    /* Find the number of missing nodes in my sub-tree compared to
Packit Service c5cf8c
     * a balanced tree */
Packit Service c5cf8c
    for (mask = 1; mask < comm_size; mask <<= 1);
Packit Service c5cf8c
    --mask;
Packit Service c5cf8c
    while (relative_rank & mask)
Packit Service c5cf8c
        mask >>= 1;
Packit Service c5cf8c
    missing = (relative_rank | mask) - comm_size + 1;
Packit Service c5cf8c
    if (missing < 0)
Packit Service c5cf8c
        missing = 0;
Packit Service c5cf8c
    tmp_buf_size = (mask - missing);
Packit Service c5cf8c
Packit Service c5cf8c
    /* If the message is smaller than the threshold, we will copy
Packit Service c5cf8c
     * our message in there too */
Packit Service c5cf8c
    if (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE)
Packit Service c5cf8c
        tmp_buf_size++;
Packit Service c5cf8c
Packit Service c5cf8c
    tmp_buf_size *= nbytes;
Packit Service c5cf8c
Packit Service c5cf8c
    /* For zero-ranked root, we don't need any temporary buffer */
Packit Service c5cf8c
    if ((rank == root) && (!root || (nbytes >= MPIR_CVAR_GATHER_VSMALL_MSG_SIZE)))
Packit Service c5cf8c
        tmp_buf_size = 0;
Packit Service c5cf8c
Packit Service c5cf8c
    if (tmp_buf_size) {
Packit Service c5cf8c
        MPIR_SCHED_CHKPMEM_MALLOC(tmp_buf, void *, tmp_buf_size, mpi_errno, "tmp_buf",
Packit Service c5cf8c
                                  MPL_MEM_BUFFER);
Packit Service c5cf8c
    }
Packit Service c5cf8c
Packit Service c5cf8c
    if (rank == root) {
Packit Service c5cf8c
        if (sendbuf != MPI_IN_PLACE) {
Packit Service c5cf8c
            mpi_errno = MPIR_Localcopy(sendbuf, sendcount, sendtype,
Packit Service c5cf8c
                                       ((char *) recvbuf + extent * recvcount * rank),
Packit Service c5cf8c
                                       recvcount, recvtype);
Packit Service c5cf8c
            if (mpi_errno)
Packit Service c5cf8c
                MPIR_ERR_POP(mpi_errno);
Packit Service c5cf8c
        }
Packit Service c5cf8c
    } else if (tmp_buf_size && (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE)) {
Packit Service c5cf8c
        /* copy from sendbuf into tmp_buf */
Packit Service c5cf8c
        mpi_errno = MPIR_Localcopy(sendbuf, sendcount, sendtype, tmp_buf, nbytes, MPI_BYTE);
Packit Service c5cf8c
        if (mpi_errno)
Packit Service c5cf8c
            MPIR_ERR_POP(mpi_errno);
Packit Service c5cf8c
    }
Packit Service c5cf8c
    curr_cnt = nbytes;
Packit Service c5cf8c
Packit Service c5cf8c
    mask = 0x1;
Packit Service c5cf8c
    while (mask < comm_size) {
Packit Service c5cf8c
        if ((mask & relative_rank) == 0) {
Packit Service c5cf8c
            src = relative_rank | mask;
Packit Service c5cf8c
            if (src < comm_size) {
Packit Service c5cf8c
                src = (src + root) % comm_size;
Packit Service c5cf8c
Packit Service c5cf8c
                if (rank == root) {
Packit Service c5cf8c
                    recvblks = mask;
Packit Service c5cf8c
                    if ((2 * recvblks) > comm_size)
Packit Service c5cf8c
                        recvblks = comm_size - recvblks;
Packit Service c5cf8c
Packit Service c5cf8c
                    if ((rank + mask + recvblks == comm_size) ||
Packit Service c5cf8c
                        (((rank + mask) % comm_size) < ((rank + mask + recvblks) % comm_size))) {
Packit Service c5cf8c
                        /* If the data contiguously fits into the
Packit Service c5cf8c
                         * receive buffer, place it directly. This
Packit Service c5cf8c
                         * should cover the case where the root is
Packit Service c5cf8c
                         * rank 0. */
Packit Service c5cf8c
                        char *rp =
Packit Service c5cf8c
                            (char *) recvbuf + (((rank + mask) % comm_size) * recvcount * extent);
Packit Service c5cf8c
                        mpi_errno =
Packit Service c5cf8c
                            MPIR_Sched_recv(rp, (recvblks * recvcount), recvtype, src, comm_ptr, s);
Packit Service c5cf8c
                        if (mpi_errno)
Packit Service c5cf8c
                            MPIR_ERR_POP(mpi_errno);
Packit Service c5cf8c
                        mpi_errno = MPIR_Sched_barrier(s);
Packit Service c5cf8c
                        if (mpi_errno)
Packit Service c5cf8c
                            MPIR_ERR_POP(mpi_errno);
Packit Service c5cf8c
                    } else if (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE) {
Packit Service c5cf8c
                        mpi_errno =
Packit Service c5cf8c
                            MPIR_Sched_recv(tmp_buf, (recvblks * nbytes), MPI_BYTE, src,
Packit Service c5cf8c
                                            comm_ptr, s);
Packit Service c5cf8c
                        if (mpi_errno)
Packit Service c5cf8c
                            MPIR_ERR_POP(mpi_errno);
Packit Service c5cf8c
                        mpi_errno = MPIR_Sched_barrier(s);
Packit Service c5cf8c
                        if (mpi_errno)
Packit Service c5cf8c
                            MPIR_ERR_POP(mpi_errno);
Packit Service c5cf8c
                        copy_offset = rank + mask;
Packit Service c5cf8c
                        copy_blks = recvblks;
Packit Service c5cf8c
                    } else {
Packit Service c5cf8c
                        blocks[0] = recvcount * (comm_size - root - mask);
Packit Service c5cf8c
                        displs[0] = recvcount * (root + mask);
Packit Service c5cf8c
                        blocks[1] = (recvcount * recvblks) - blocks[0];
Packit Service c5cf8c
                        displs[1] = 0;
Packit Service c5cf8c
Packit Service c5cf8c
                        mpi_errno = MPIR_Type_indexed_impl(2, blocks, displs, recvtype, &tmp_type);
Packit Service c5cf8c
                        if (mpi_errno)
Packit Service c5cf8c
                            MPIR_ERR_POP(mpi_errno);
Packit Service c5cf8c
                        mpi_errno = MPIR_Type_commit_impl(&tmp_type);
Packit Service c5cf8c
                        if (mpi_errno)
Packit Service c5cf8c
                            MPIR_ERR_POP(mpi_errno);
Packit Service c5cf8c
Packit Service c5cf8c
                        mpi_errno = MPIR_Sched_recv(recvbuf, 1, tmp_type, src, comm_ptr, s);
Packit Service c5cf8c
                        if (mpi_errno)
Packit Service c5cf8c
                            MPIR_ERR_POP(mpi_errno);
Packit Service c5cf8c
                        mpi_errno = MPIR_Sched_barrier(s);
Packit Service c5cf8c
                        if (mpi_errno)
Packit Service c5cf8c
                            MPIR_ERR_POP(mpi_errno);
Packit Service c5cf8c
Packit Service c5cf8c
                        /* this "premature" free is safe b/c the sched holds an actual ref to keep it alive */
Packit Service c5cf8c
                        MPIR_Type_free_impl(&tmp_type);
Packit Service c5cf8c
                    }
Packit Service c5cf8c
                } else {        /* Intermediate nodes store in temporary buffer */
Packit Service c5cf8c
                    MPI_Aint offset;
Packit Service c5cf8c
Packit Service c5cf8c
                    /* Estimate the amount of data that is going to come in */
Packit Service c5cf8c
                    recvblks = mask;
Packit Service c5cf8c
                    relative_src = ((src - root) < 0) ? (src - root + comm_size) : (src - root);
Packit Service c5cf8c
                    if (relative_src + mask > comm_size)
Packit Service c5cf8c
                        recvblks -= (relative_src + mask - comm_size);
Packit Service c5cf8c
Packit Service c5cf8c
                    if (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE)
Packit Service c5cf8c
                        offset = mask * nbytes;
Packit Service c5cf8c
                    else
Packit Service c5cf8c
                        offset = (mask - 1) * nbytes;
Packit Service c5cf8c
                    mpi_errno =
Packit Service c5cf8c
                        MPIR_Sched_recv(((char *) tmp_buf + offset), (recvblks * nbytes),
Packit Service c5cf8c
                                        MPI_BYTE, src, comm_ptr, s);
Packit Service c5cf8c
                    if (mpi_errno)
Packit Service c5cf8c
                        MPIR_ERR_POP(mpi_errno);
Packit Service c5cf8c
                    mpi_errno = MPIR_Sched_barrier(s);
Packit Service c5cf8c
                    if (mpi_errno)
Packit Service c5cf8c
                        MPIR_ERR_POP(mpi_errno);
Packit Service c5cf8c
                    curr_cnt += (recvblks * nbytes);
Packit Service c5cf8c
                }
Packit Service c5cf8c
            }
Packit Service c5cf8c
        } else {
Packit Service c5cf8c
            dst = relative_rank ^ mask;
Packit Service c5cf8c
            dst = (dst + root) % comm_size;
Packit Service c5cf8c
Packit Service c5cf8c
            if (!tmp_buf_size) {
Packit Service c5cf8c
                /* leaf nodes send directly from sendbuf */
Packit Service c5cf8c
                mpi_errno = MPIR_Sched_send(sendbuf, sendcount, sendtype, dst, comm_ptr, s);
Packit Service c5cf8c
                if (mpi_errno)
Packit Service c5cf8c
                    MPIR_ERR_POP(mpi_errno);
Packit Service c5cf8c
                mpi_errno = MPIR_Sched_barrier(s);
Packit Service c5cf8c
                if (mpi_errno)
Packit Service c5cf8c
                    MPIR_ERR_POP(mpi_errno);
Packit Service c5cf8c
            } else if (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE) {
Packit Service c5cf8c
                mpi_errno = MPIR_Sched_send(tmp_buf, curr_cnt, MPI_BYTE, dst, comm_ptr, s);
Packit Service c5cf8c
                if (mpi_errno)
Packit Service c5cf8c
                    MPIR_ERR_POP(mpi_errno);
Packit Service c5cf8c
                mpi_errno = MPIR_Sched_barrier(s);
Packit Service c5cf8c
                if (mpi_errno)
Packit Service c5cf8c
                    MPIR_ERR_POP(mpi_errno);
Packit Service c5cf8c
            } else {
Packit Service c5cf8c
                blocks[0] = sendcount;
Packit Service c5cf8c
                struct_displs[0] = MPIR_VOID_PTR_CAST_TO_MPI_AINT sendbuf;
Packit Service c5cf8c
                types[0] = sendtype;
Packit Service c5cf8c
                /* check for overflow.  work around int limits if needed */
Packit Service c5cf8c
                if (curr_cnt - nbytes != (int) (curr_cnt - nbytes)) {
Packit Service c5cf8c
                    blocks[1] = 1;
Packit Service c5cf8c
                    MPIR_Type_contiguous_x_impl(curr_cnt - nbytes, MPI_BYTE, &(types[1]));
Packit Service c5cf8c
                } else {
Packit Service c5cf8c
                    MPIR_Assign_trunc(blocks[1], curr_cnt - nbytes, int);
Packit Service c5cf8c
                    types[1] = MPI_BYTE;
Packit Service c5cf8c
                }
Packit Service c5cf8c
                struct_displs[1] = MPIR_VOID_PTR_CAST_TO_MPI_AINT tmp_buf;
Packit Service c5cf8c
Packit Service c5cf8c
                mpi_errno =
Packit Service c5cf8c
                    MPIR_Type_create_struct_impl(2, blocks, struct_displs, types, &tmp_type);
Packit Service c5cf8c
                if (mpi_errno)
Packit Service c5cf8c
                    MPIR_ERR_POP(mpi_errno);
Packit Service c5cf8c
                mpi_errno = MPIR_Type_commit_impl(&tmp_type);
Packit Service c5cf8c
                if (mpi_errno)
Packit Service c5cf8c
                    MPIR_ERR_POP(mpi_errno);
Packit Service c5cf8c
Packit Service c5cf8c
                mpi_errno = MPIR_Sched_send(MPI_BOTTOM, 1, tmp_type, dst, comm_ptr, s);
Packit Service c5cf8c
                if (mpi_errno)
Packit Service c5cf8c
                    MPIR_ERR_POP(mpi_errno);
Packit Service c5cf8c
                MPIR_SCHED_BARRIER(s);
Packit Service c5cf8c
Packit Service c5cf8c
                /* this "premature" free is safe b/c the sched holds an actual ref to keep it alive */
Packit Service c5cf8c
                MPIR_Type_free_impl(&tmp_type);
Packit Service c5cf8c
            }
Packit Service c5cf8c
Packit Service c5cf8c
            break;
Packit Service c5cf8c
        }
Packit Service c5cf8c
        mask <<= 1;
Packit Service c5cf8c
    }
Packit Service c5cf8c
Packit Service c5cf8c
    if ((rank == root) && root && (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE) && copy_blks) {
Packit Service c5cf8c
        /* reorder and copy from tmp_buf into recvbuf */
Packit Service c5cf8c
        /* FIXME why are there two copies here? */
Packit Service c5cf8c
        mpi_errno = MPIR_Sched_copy(tmp_buf, nbytes * (comm_size - copy_offset), MPI_BYTE,
Packit Service c5cf8c
                                    ((char *) recvbuf + extent * recvcount * copy_offset),
Packit Service c5cf8c
                                    recvcount * (comm_size - copy_offset), recvtype, s);
Packit Service c5cf8c
        if (mpi_errno)
Packit Service c5cf8c
            MPIR_ERR_POP(mpi_errno);
Packit Service c5cf8c
        mpi_errno = MPIR_Sched_copy((char *) tmp_buf + nbytes * (comm_size - copy_offset),
Packit Service c5cf8c
                                    nbytes * (copy_blks - comm_size + copy_offset), MPI_BYTE,
Packit Service c5cf8c
                                    recvbuf, recvcount * (copy_blks - comm_size + copy_offset),
Packit Service c5cf8c
                                    recvtype, s);
Packit Service c5cf8c
        if (mpi_errno)
Packit Service c5cf8c
            MPIR_ERR_POP(mpi_errno);
Packit Service c5cf8c
    }
Packit Service c5cf8c
Packit Service c5cf8c
    MPIR_SCHED_CHKPMEM_COMMIT(s);
Packit Service c5cf8c
  fn_exit:
Packit Service c5cf8c
    return mpi_errno;
Packit Service c5cf8c
  fn_fail:
Packit Service c5cf8c
    MPIR_SCHED_CHKPMEM_REAP(s);
Packit Service c5cf8c
    goto fn_exit;
Packit Service c5cf8c
}