/* -*- Mode: C; c-basic-offset:4 ; indent-tabs-mode:nil ; -*- */ /* * * (C) 2001 by Argonne National Laboratory. * See COPYRIGHT in top-level directory. */ #include "mpiimpl.h" #include "bcast.h" /* Algorithm: Broadcast based on a scatter followed by an allgather. * * We first scatter the buffer using a binomial tree algorithm. This costs * lgp.alpha + n.((p-1)/p).beta * If the datatype is contiguous, we treat the data as bytes and * divide (scatter) it among processes by using ceiling division. * For the noncontiguous, we first pack the data into a temporary * buffer by using MPI_Pack, scatter it as bytes, and unpack it * after the allgather. * * For the allgather, we use a recursive doubling algorithm for * medium-size messages and power-of-two number of processes. This * takes lgp steps. In each step pairs of processes exchange all the * data they have (we take care of non-power-of-two situations). This * costs approximately lgp.alpha + n.((p-1)/p).beta. (Approximately * because it may be slightly more in the non-power-of-two case, but * it's still a logarithmic algorithm.) Therefore, for long messages * Total Cost = 2.lgp.alpha + 2.n.((p-1)/p).beta */ #undef FUNCNAME #define FUNCNAME MPIR_Bcast_intra_scatter_recursive_doubling_allgather #undef FCNAME #define FCNAME MPL_QUOTE(FUNCNAME) int MPIR_Bcast_intra_scatter_recursive_doubling_allgather(void *buffer, int count, MPI_Datatype datatype, int root, MPIR_Comm * comm_ptr, MPIR_Errflag_t * errflag) { MPI_Status status; int rank, comm_size, dst; int relative_rank, mask; int mpi_errno = MPI_SUCCESS; int mpi_errno_ret = MPI_SUCCESS; int scatter_size; MPI_Aint curr_size, recv_size = 0; int j, k, i, tmp_mask, is_contig; MPI_Aint type_size, nbytes = 0; int relative_dst, dst_tree_root, my_tree_root, send_offset; int recv_offset, tree_root, nprocs_completed, offset; MPI_Aint position; MPIR_CHKLMEM_DECL(1); MPI_Aint true_extent, true_lb; void *tmp_buf; comm_size = comm_ptr->local_size; rank = comm_ptr->rank; relative_rank = (rank >= root) ? rank - root : rank - root + comm_size; /* If there is only one process, return */ if (comm_size == 1) goto fn_exit; #ifdef HAVE_ERROR_CHECKING /* This algorithm can currently handle only power of 2 cases, * non-power of 2 is still experimental */ MPIR_Assert(MPL_is_pof2(comm_size, NULL)); #endif /* HAVE_ERROR_CHECKING */ if (HANDLE_GET_KIND(datatype) == HANDLE_KIND_BUILTIN) is_contig = 1; else { MPIR_Datatype_is_contig(datatype, &is_contig); } MPIR_Datatype_get_size_macro(datatype, type_size); nbytes = type_size * count; if (nbytes == 0) goto fn_exit; /* nothing to do */ if (is_contig) { /* contiguous. no need to pack. */ MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); tmp_buf = (char *) buffer + true_lb; } else { MPIR_CHKLMEM_MALLOC(tmp_buf, void *, nbytes, mpi_errno, "tmp_buf", MPL_MEM_BUFFER); position = 0; if (rank == root) { mpi_errno = MPIR_Pack_impl(buffer, count, datatype, tmp_buf, nbytes, &position); if (mpi_errno) MPIR_ERR_POP(mpi_errno); } } scatter_size = (nbytes + comm_size - 1) / comm_size; /* ceiling division */ mpi_errno = MPII_Scatter_for_bcast(buffer, count, datatype, root, comm_ptr, nbytes, tmp_buf, is_contig, errflag); if (mpi_errno) { /* for communication errors, just record the error but continue */ *errflag = MPIX_ERR_PROC_FAILED == MPIR_ERR_GET_CLASS(mpi_errno) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER; MPIR_ERR_SET(mpi_errno, *errflag, "**fail"); MPIR_ERR_ADD(mpi_errno_ret, mpi_errno); } /* curr_size is the amount of data that this process now has stored in * buffer at byte offset (relative_rank*scatter_size) */ curr_size = MPL_MIN(scatter_size, (nbytes - (relative_rank * scatter_size))); if (curr_size < 0) curr_size = 0; /* medium size allgather and pof2 comm_size. use recurive doubling. */ mask = 0x1; i = 0; while (mask < comm_size) { relative_dst = relative_rank ^ mask; dst = (relative_dst + root) % comm_size; /* find offset into send and recv buffers. * zero out the least significant "i" bits of relative_rank and * relative_dst to find root of src and dst * subtrees. Use ranks of roots as index to send from * and recv into buffer */ dst_tree_root = relative_dst >> i; dst_tree_root <<= i; my_tree_root = relative_rank >> i; my_tree_root <<= i; send_offset = my_tree_root * scatter_size; recv_offset = dst_tree_root * scatter_size; if (relative_dst < comm_size) { mpi_errno = MPIC_Sendrecv(((char *) tmp_buf + send_offset), curr_size, MPI_BYTE, dst, MPIR_BCAST_TAG, ((char *) tmp_buf + recv_offset), (nbytes - recv_offset < 0 ? 0 : nbytes - recv_offset), MPI_BYTE, dst, MPIR_BCAST_TAG, comm_ptr, &status, errflag); if (mpi_errno) { /* --BEGIN ERROR HANDLING-- */ /* for communication errors, just record the error but continue */ *errflag = MPIX_ERR_PROC_FAILED == MPIR_ERR_GET_CLASS(mpi_errno) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER; MPIR_ERR_SET(mpi_errno, *errflag, "**fail"); MPIR_ERR_ADD(mpi_errno_ret, mpi_errno); recv_size = 0; /* --END ERROR HANDLING-- */ } else MPIR_Get_count_impl(&status, MPI_BYTE, &recv_size); curr_size += recv_size; } /* if some processes in this process's subtree in this step * did not have any destination process to communicate with * because of non-power-of-two, we need to send them the * data that they would normally have received from those * processes. That is, the haves in this subtree must send to * the havenots. We use a logarithmic recursive-halfing algorithm * for this. */ /* This part of the code will not currently be * executed because we are not using recursive * doubling for non power of two. Mark it as experimental * so that it doesn't show up as red in the coverage tests. */ /* --BEGIN EXPERIMENTAL-- */ if (dst_tree_root + mask > comm_size) { nprocs_completed = comm_size - my_tree_root - mask; /* nprocs_completed is the number of processes in this * subtree that have all the data. Send data to others * in a tree fashion. First find root of current tree * that is being divided into two. k is the number of * least-significant bits in this process's rank that * must be zeroed out to find the rank of the root */ j = mask; k = 0; while (j) { j >>= 1; k++; } k--; offset = scatter_size * (my_tree_root + mask); tmp_mask = mask >> 1; while (tmp_mask) { relative_dst = relative_rank ^ tmp_mask; dst = (relative_dst + root) % comm_size; tree_root = relative_rank >> k; tree_root <<= k; /* send only if this proc has data and destination * doesn't have data. */ /* if (rank == 3) { * printf("rank %d, dst %d, root %d, nprocs_completed %d\n", relative_rank, relative_dst, tree_root, nprocs_completed); * fflush(stdout); * } */ if ((relative_dst > relative_rank) && (relative_rank < tree_root + nprocs_completed) && (relative_dst >= tree_root + nprocs_completed)) { /* printf("Rank %d, send to %d, offset %d, size %d\n", rank, dst, offset, recv_size); * fflush(stdout); */ mpi_errno = MPIC_Send(((char *) tmp_buf + offset), recv_size, MPI_BYTE, dst, MPIR_BCAST_TAG, comm_ptr, errflag); /* recv_size was set in the previous * receive. that's the amount of data to be * sent now. */ if (mpi_errno) { /* for communication errors, just record the error but continue */ *errflag = MPIX_ERR_PROC_FAILED == MPIR_ERR_GET_CLASS(mpi_errno) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER; MPIR_ERR_SET(mpi_errno, *errflag, "**fail"); MPIR_ERR_ADD(mpi_errno_ret, mpi_errno); } } /* recv only if this proc. doesn't have data and sender * has data */ else if ((relative_dst < relative_rank) && (relative_dst < tree_root + nprocs_completed) && (relative_rank >= tree_root + nprocs_completed)) { /* printf("Rank %d waiting to recv from rank %d\n", * relative_rank, dst); */ mpi_errno = MPIC_Recv(((char *) tmp_buf + offset), nbytes - offset, MPI_BYTE, dst, MPIR_BCAST_TAG, comm_ptr, &status, errflag); /* nprocs_completed is also equal to the no. of processes * whose data we don't have */ if (mpi_errno) { /* --BEGIN ERROR HANDLING-- */ /* for communication errors, just record the error but continue */ *errflag = MPIX_ERR_PROC_FAILED == MPIR_ERR_GET_CLASS(mpi_errno) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER; MPIR_ERR_SET(mpi_errno, *errflag, "**fail"); MPIR_ERR_ADD(mpi_errno_ret, mpi_errno); recv_size = 0; /* --END ERROR HANDLING-- */ } else MPIR_Get_count_impl(&status, MPI_BYTE, &recv_size); curr_size += recv_size; /* printf("Rank %d, recv from %d, offset %d, size %d\n", rank, dst, offset, recv_size); * fflush(stdout); */ } tmp_mask >>= 1; k--; } } /* --END EXPERIMENTAL-- */ mask <<= 1; i++; } /* check that we received as much as we expected */ if (curr_size != nbytes) { if (*errflag == MPIR_ERR_NONE) *errflag = MPIR_ERR_OTHER; MPIR_ERR_SET2(mpi_errno, MPI_ERR_OTHER, "**collective_size_mismatch", "**collective_size_mismatch %d %d", curr_size, nbytes); MPIR_ERR_ADD(mpi_errno_ret, mpi_errno); } if (!is_contig) { if (rank != root) { position = 0; mpi_errno = MPIR_Unpack_impl(tmp_buf, nbytes, &position, buffer, count, datatype); if (mpi_errno) MPIR_ERR_POP(mpi_errno); } } fn_exit: MPIR_CHKLMEM_FREEALL(); /* --BEGIN ERROR HANDLING-- */ if (mpi_errno_ret) mpi_errno = mpi_errno_ret; else if (*errflag != MPIR_ERR_NONE) MPIR_ERR_SET(mpi_errno, *errflag, "**coll_fail"); /* --END ERROR HANDLING-- */ return mpi_errno; fn_fail: goto fn_exit; }