/* -*- Mode: C; c-basic-offset:4 ; indent-tabs-mode:nil ; -*- */
/*
* (C) 2017 by Argonne National Laboratory.
* See COPYRIGHT in top-level directory.
*/
#include "mpiimpl.h"
/* This is the machine-independent implementation of igather. The algorithm is:
Algorithm: MPI_Igather
We use a binomial tree algorithm for both short and long
messages. Non-leaf nodes (and, for very small messages, every non-root
node) allocate a temporary buffer to store the incoming messages. If
the root is not rank 0, for very small messages, we pack the data into
a temporary contiguous buffer and then reorder it so that each block
is placed in the right position. For larger messages, we use a derived
datatype to unpack the incoming data directly into non-contiguous
locations of the receive buffer in the right order.
Cost = \lg p \cdot \alpha + n \cdot \frac{p-1}{p} \cdot \beta
where n is the total size of the data gathered at the root and p is the
number of processes.
Possible improvements: none recorded.
End Algorithm: MPI_Igather
*/
#undef FUNCNAME
#define FUNCNAME MPIR_Igather_sched_intra_binomial
#undef FCNAME
#define FCNAME MPL_QUOTE(FUNCNAME)
/*
 * Append to schedule `s` the operations of a binomial-tree gather of every
 * rank's contribution into `recvbuf` at `root`.
 *
 * Parameters:
 *   sendbuf, sendcount, sendtype - this rank's contribution; at the root,
 *                                  sendbuf may be MPI_IN_PLACE
 *   recvbuf, recvcount, recvtype - receive buffer (significant at root only);
 *                                  block i holds rank i's contribution
 *   root                         - rank that gathers the data
 *   comm_ptr                     - intracommunicator (asserted below)
 *   s                            - schedule the operations are appended to
 *
 * Returns MPI_SUCCESS or an MPI error code. On failure the temporary buffer
 * (if any) is released via MPIR_SCHED_CHKPMEM_REAP; on success it is handed
 * to MPIR_SCHED_CHKPMEM_COMMIT (presumably freed when the schedule
 * completes -- confirm against the macro definition).
 */
int MPIR_Igather_sched_intra_binomial(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                                      void *recvbuf, int recvcount, MPI_Datatype recvtype, int root,
                                      MPIR_Comm * comm_ptr, MPIR_Sched_t s)
{
    int mpi_errno = MPI_SUCCESS;
    int comm_size, rank;
    int relative_rank;
    int mask, src, dst, relative_src;
    MPI_Aint recvtype_size, sendtype_size, curr_cnt = 0, nbytes;
    int recvblks;
    int missing;
    /* MPI_Aint, not int: (subtree blocks * nbytes) can exceed INT_MAX for
     * large messages and would otherwise truncate the allocation size. */
    MPI_Aint tmp_buf_size;
    void *tmp_buf = NULL;
    int blocks[2];
    int displs[2];
    MPI_Aint struct_displs[2];
    MPI_Aint extent = 0;
    int copy_offset = 0, copy_blks = 0;
    MPI_Datatype types[2], tmp_type;
    MPIR_SCHED_CHKPMEM_DECL(1);

    comm_size = comm_ptr->local_size;
    rank = comm_ptr->rank;

    /* Nothing to send or receive on this rank. */
    if (((rank == root) && (recvcount == 0)) || ((rank != root) && (sendcount == 0)))
        goto fn_exit;

    MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM);

    /* Use binomial tree algorithm; ranks are renumbered so root is 0. */
    relative_rank = (rank >= root) ? rank - root : rank - root + comm_size;

    /* nbytes is the size in bytes of one rank's contribution. */
    if (rank == root) {
        MPIR_Datatype_get_extent_macro(recvtype, extent);
        MPIR_Ensure_Aint_fits_in_pointer(MPIR_VOID_PTR_CAST_TO_MPI_AINT recvbuf +
                                         (extent * recvcount * comm_size));
        MPIR_Datatype_get_size_macro(recvtype, recvtype_size);
        nbytes = recvtype_size * recvcount;
    } else {
        MPIR_Datatype_get_size_macro(sendtype, sendtype_size);
        nbytes = sendtype_size * sendcount;
    }

    /* Find the number of missing nodes in my sub-tree compared to
     * a balanced tree. Start from (smallest power of two >= comm_size) - 1. */
    mask = 1;
    while (mask < comm_size)
        mask <<= 1;
    --mask;
    while (relative_rank & mask)
        mask >>= 1;
    missing = (relative_rank | mask) - comm_size + 1;
    if (missing < 0)
        missing = 0;
    tmp_buf_size = (MPI_Aint) (mask - missing);

    /* If the message is smaller than the threshold, we will copy
     * our own message into tmp_buf too, so the whole subtree can be
     * forwarded as one contiguous byte message. */
    if (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE)
        tmp_buf_size++;
    tmp_buf_size *= nbytes;

    /* The root needs no temporary buffer when it is rank 0, or when the
     * message is not very small (a derived datatype places the data
     * directly into recvbuf). */
    if ((rank == root) && (!root || (nbytes >= MPIR_CVAR_GATHER_VSMALL_MSG_SIZE)))
        tmp_buf_size = 0;

    if (tmp_buf_size) {
        MPIR_SCHED_CHKPMEM_MALLOC(tmp_buf, void *, tmp_buf_size, mpi_errno, "tmp_buf",
                                  MPL_MEM_BUFFER);
    }

    if (rank == root) {
        if (sendbuf != MPI_IN_PLACE) {
            mpi_errno = MPIR_Localcopy(sendbuf, sendcount, sendtype,
                                       ((char *) recvbuf + extent * recvcount * rank),
                                       recvcount, recvtype);
            if (mpi_errno)
                MPIR_ERR_POP(mpi_errno);
        }
    } else if (tmp_buf_size && (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE)) {
        /* copy from sendbuf into tmp_buf */
        mpi_errno = MPIR_Localcopy(sendbuf, sendcount, sendtype, tmp_buf, nbytes, MPI_BYTE);
        if (mpi_errno)
            MPIR_ERR_POP(mpi_errno);
    }

    /* curr_cnt: bytes of this subtree's data held locally so far. */
    curr_cnt = nbytes;

    mask = 0x1;
    while (mask < comm_size) {
        if ((mask & relative_rank) == 0) {
            /* Receive phase: gather the subtree rooted at relative_rank|mask. */
            src = relative_rank | mask;
            if (src < comm_size) {
                src = (src + root) % comm_size;

                if (rank == root) {
                    recvblks = mask;
                    if ((2 * recvblks) > comm_size)
                        recvblks = comm_size - recvblks;

                    if ((rank + mask + recvblks == comm_size) ||
                        (((rank + mask) % comm_size) < ((rank + mask + recvblks) % comm_size))) {
                        /* If the data contiguously fits into the
                         * receive buffer, place it directly. This
                         * should cover the case where the root is
                         * rank 0. Start the offset with the MPI_Aint
                         * extent so the product is not computed in int. */
                        char *rp = (char *) recvbuf +
                            extent * recvcount * ((rank + mask) % comm_size);
                        mpi_errno =
                            MPIR_Sched_recv(rp, (recvblks * recvcount), recvtype, src, comm_ptr, s);
                        if (mpi_errno)
                            MPIR_ERR_POP(mpi_errno);
                        mpi_errno = MPIR_Sched_barrier(s);
                        if (mpi_errno)
                            MPIR_ERR_POP(mpi_errno);
                    } else if (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE) {
                        /* Wrapping run of blocks: receive into tmp_buf and
                         * reorder into recvbuf after the loop. */
                        mpi_errno =
                            MPIR_Sched_recv(tmp_buf, (recvblks * nbytes), MPI_BYTE, src,
                                            comm_ptr, s);
                        if (mpi_errno)
                            MPIR_ERR_POP(mpi_errno);
                        mpi_errno = MPIR_Sched_barrier(s);
                        if (mpi_errno)
                            MPIR_ERR_POP(mpi_errno);
                        copy_offset = rank + mask;
                        copy_blks = recvblks;
                    } else {
                        /* Wrapping run of blocks: describe the two pieces
                         * (tail and head of recvbuf) with an indexed type. */
                        blocks[0] = recvcount * (comm_size - root - mask);
                        displs[0] = recvcount * (root + mask);
                        blocks[1] = (recvcount * recvblks) - blocks[0];
                        displs[1] = 0;

                        mpi_errno = MPIR_Type_indexed_impl(2, blocks, displs, recvtype, &tmp_type);
                        if (mpi_errno)
                            MPIR_ERR_POP(mpi_errno);

                        mpi_errno = MPIR_Type_commit_impl(&tmp_type);
                        if (mpi_errno == MPI_SUCCESS)
                            mpi_errno = MPIR_Sched_recv(recvbuf, 1, tmp_type, src, comm_ptr, s);
                        /* this "premature" free is safe b/c the sched holds an
                         * actual ref to keep it alive; freeing before the error
                         * check also avoids leaking tmp_type on failure */
                        MPIR_Type_free_impl(&tmp_type);
                        if (mpi_errno)
                            MPIR_ERR_POP(mpi_errno);

                        mpi_errno = MPIR_Sched_barrier(s);
                        if (mpi_errno)
                            MPIR_ERR_POP(mpi_errno);
                    }
                } else {        /* Intermediate nodes store in temporary buffer */
                    MPI_Aint offset;

                    /* Estimate the amount of data that is going to come in */
                    recvblks = mask;
                    relative_src = ((src - root) < 0) ? (src - root + comm_size) : (src - root);
                    if (relative_src + mask > comm_size)
                        recvblks -= (relative_src + mask - comm_size);

                    /* Very small messages keep our own block at the front
                     * of tmp_buf; larger ones keep it in sendbuf. */
                    if (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE)
                        offset = mask * nbytes;
                    else
                        offset = (mask - 1) * nbytes;

                    mpi_errno =
                        MPIR_Sched_recv(((char *) tmp_buf + offset), (recvblks * nbytes),
                                        MPI_BYTE, src, comm_ptr, s);
                    if (mpi_errno)
                        MPIR_ERR_POP(mpi_errno);
                    mpi_errno = MPIR_Sched_barrier(s);
                    if (mpi_errno)
                        MPIR_ERR_POP(mpi_errno);
                    curr_cnt += (recvblks * nbytes);
                }
            }
        } else {
            /* Send phase: forward this subtree's data to the parent, then stop. */
            dst = relative_rank ^ mask;
            dst = (dst + root) % comm_size;

            if (!tmp_buf_size) {
                /* leaf nodes send directly from sendbuf */
                mpi_errno = MPIR_Sched_send(sendbuf, sendcount, sendtype, dst, comm_ptr, s);
                if (mpi_errno)
                    MPIR_ERR_POP(mpi_errno);
                mpi_errno = MPIR_Sched_barrier(s);
                if (mpi_errno)
                    MPIR_ERR_POP(mpi_errno);
            } else if (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE) {
                mpi_errno = MPIR_Sched_send(tmp_buf, curr_cnt, MPI_BYTE, dst, comm_ptr, s);
                if (mpi_errno)
                    MPIR_ERR_POP(mpi_errno);
                mpi_errno = MPIR_Sched_barrier(s);
                if (mpi_errno)
                    MPIR_ERR_POP(mpi_errno);
            } else {
                /* Send own data (sendbuf) plus gathered data (tmp_buf) as one
                 * message, described by a struct type over absolute addresses. */
                blocks[0] = sendcount;
                struct_displs[0] = MPIR_VOID_PTR_CAST_TO_MPI_AINT sendbuf;
                types[0] = sendtype;
                types[1] = MPI_BYTE;
                /* check for overflow. work around int limits if needed */
                if (curr_cnt - nbytes != (int) (curr_cnt - nbytes)) {
                    blocks[1] = 1;
                    mpi_errno = MPIR_Type_contiguous_x_impl(curr_cnt - nbytes, MPI_BYTE,
                                                            &(types[1]));
                    if (mpi_errno)
                        MPIR_ERR_POP(mpi_errno);
                } else {
                    MPIR_Assign_trunc(blocks[1], curr_cnt - nbytes, int);
                }
                struct_displs[1] = MPIR_VOID_PTR_CAST_TO_MPI_AINT tmp_buf;

                mpi_errno =
                    MPIR_Type_create_struct_impl(2, blocks, struct_displs, types, &tmp_type);
                /* Release the contiguous helper type if one was created;
                 * per MPI semantics, freeing a datatype does not affect
                 * types already derived from it. */
                if (types[1] != MPI_BYTE)
                    MPIR_Type_free_impl(&(types[1]));
                if (mpi_errno)
                    MPIR_ERR_POP(mpi_errno);

                mpi_errno = MPIR_Type_commit_impl(&tmp_type);
                if (mpi_errno == MPI_SUCCESS)
                    mpi_errno = MPIR_Sched_send(MPI_BOTTOM, 1, tmp_type, dst, comm_ptr, s);
                /* this "premature" free is safe b/c the sched holds an actual
                 * ref to keep it alive; freeing before the error check also
                 * avoids leaking tmp_type on failure */
                MPIR_Type_free_impl(&tmp_type);
                if (mpi_errno)
                    MPIR_ERR_POP(mpi_errno);

                mpi_errno = MPIR_Sched_barrier(s);
                if (mpi_errno)
                    MPIR_ERR_POP(mpi_errno);
            }
            break;
        }
        mask <<= 1;
    }

    if ((rank == root) && root && (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE) && copy_blks) {
        /* Reorder and copy from tmp_buf into recvbuf. tmp_buf holds copy_blks
         * consecutive blocks for ranks copy_offset, copy_offset+1, ..., and
         * this run wraps past rank comm_size-1 (otherwise it would have been
         * received directly into recvbuf). Hence two copies: the first
         * (comm_size - copy_offset) blocks go to the tail of recvbuf, and the
         * remaining blocks go to the head of recvbuf (ranks 0, 1, ...). */
        mpi_errno = MPIR_Sched_copy(tmp_buf, nbytes * (comm_size - copy_offset), MPI_BYTE,
                                    ((char *) recvbuf + extent * recvcount * copy_offset),
                                    (MPI_Aint) recvcount * (comm_size - copy_offset), recvtype, s);
        if (mpi_errno)
            MPIR_ERR_POP(mpi_errno);
        mpi_errno = MPIR_Sched_copy((char *) tmp_buf + nbytes * (comm_size - copy_offset),
                                    nbytes * (copy_blks - comm_size + copy_offset), MPI_BYTE,
                                    recvbuf,
                                    (MPI_Aint) recvcount * (copy_blks - comm_size + copy_offset),
                                    recvtype, s);
        if (mpi_errno)
            MPIR_ERR_POP(mpi_errno);
    }

    MPIR_SCHED_CHKPMEM_COMMIT(s);
  fn_exit:
    return mpi_errno;
  fn_fail:
    MPIR_SCHED_CHKPMEM_REAP(s);
    goto fn_exit;
}