/* -*- Mode: C; c-basic-offset:4 ; indent-tabs-mode:nil ; -*- */
/*
 *  (C) 2012 by Argonne National Laboratory.
 *      See COPYRIGHT in top-level directory.
 */

#include "mpiimpl.h"

/* Local utility macro: takes two args and sets lvalue cr_ equal to the rank
 * in comm_ptr (i.e. in comm_ptr->local_group) corresponding to rvalue gr_,
 * which is a rank in group_ptr.
 *
 * NOTE: this macro is not self-contained.  It reads `group_ptr` and
 * `comm_ptr`, writes `mpi_errno`, and uses MPIR_ERR_POP, so it may only be
 * used inside a function that has those variables and an fn_fail error path
 * (presumably MPIR_ERR_POP jumps to fn_fail -- confirm against the macro
 * definition).  A translation result of MPI_UNDEFINED is caught only by
 * MPIR_Assert. */
#define to_comm_rank(cr_, gr_) \
    do { \
        int gr_tmp_ = (gr_); \
        mpi_errno = MPIR_Group_translate_ranks_impl(group_ptr, 1, &(gr_tmp_), comm_ptr->local_group, &(cr_)); \
        if (mpi_errno) MPIR_ERR_POP(mpi_errno); \
        MPIR_Assert((cr_) != MPI_UNDEFINED); \
    } while (0)

/*
 * MPII_Allreduce_group_intra - allreduce over the members of a group
 *
 * Combines `count` elements of `datatype` with `op` across every process in
 * group_ptr and leaves the result in recvbuf on each of them.  All traffic
 * is point-to-point on comm_ptr with the given `tag`.  Group ranks are mapped
 * to ranks in comm_ptr->local_group with to_comm_rank().
 *
 * Algorithm:
 *   1. If the group size is not a power of two (rem = group_size - pof2 > 0),
 *      the first 2*rem ranks pair up.  Each even rank sends its data to
 *      rank+1 and then sits out (newrank = -1).  The odd rank reduces that
 *      data into its own and joins the remaining processes, which together
 *      form a power-of-two set indexed by newrank.
 *   2. The pof2 participants use either recursive doubling or a
 *      reduce-scatter followed by an allgather.  Recursive doubling is used
 *      for short messages, non-builtin ops, or count < pof2.
 *   3. Each odd rank from step 1 sends the final result back to its even
 *      partner.
 *
 * Parameters:
 *   sendbuf   - input buffer, or MPI_IN_PLACE to take the input from recvbuf
 *   recvbuf   - working buffer; holds the reduced result on return
 *   count     - number of elements
 *   datatype  - element datatype
 *   op        - reduction operation; when it is not commutative, the
 *               recursive-doubling path orders operands by group rank
 *   comm_ptr  - communicator carrying all messages
 *   group_ptr - participating group; the caller must be a member
 *               (group_ptr->rank != MPI_UNDEFINED), otherwise a "**rank"
 *               error is raised
 *   tag       - tag used for every message
 *   errflag   - in/out.  When a send/receive fails, it is set to
 *               MPIR_ERR_PROC_FAILED or MPIR_ERR_OTHER.  The error is
 *               recorded in mpi_errno_ret and the algorithm continues rather
 *               than returning early.
 *
 * Returns:
 *   - any recorded communication error(s), if there were any; these take
 *     precedence over a later local error;
 *   - otherwise the local error, if one occurred (rank translation,
 *     allocation, copy or reduction failure);
 *   - a "**coll_fail" error if nothing was recorded here but *errflag is
 *     not MPIR_ERR_NONE on exit (e.g. because it was already set on entry);
 *   - MPI_SUCCESS otherwise.
 *
 * NOTE(review): presumably every member of group_ptr must call this with
 * matching count, datatype, op and tag, as for any collective -- confirm
 * with callers.
 */
#undef FUNCNAME
#define FUNCNAME MPII_Allreduce_group_intra
#undef FCNAME
#define FCNAME MPL_QUOTE(FUNCNAME)
int MPII_Allreduce_group_intra(void *sendbuf, void *recvbuf, int count,
                               MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr,
                               MPIR_Group * group_ptr, int tag, MPIR_Errflag_t * errflag)
{
    MPI_Aint type_size;
    int mpi_errno = MPI_SUCCESS;
    int mpi_errno_ret = MPI_SUCCESS;
    /* newrank is this process's rank within the power-of-two subset of
     * group_ptr (see step 1 above), or -1 if this process sits out */
    int mask, dst, is_commutative, pof2, newrank, rem, newdst, i,
        send_idx, recv_idx, last_idx, send_cnt, recv_cnt, *cnts, *disps;
    MPI_Aint true_extent, true_lb, extent;
    void *tmp_buf;
    int group_rank, group_size;
    /* cdst/csrc: peer ranks translated into comm_ptr */
    int cdst, csrc;
    /* three temporaries: tmp_buf, cnts, disps */
    MPIR_CHKLMEM_DECL(3);

    group_rank = group_ptr->rank;
    group_size = group_ptr->size;
    MPIR_ERR_CHKANDJUMP(group_rank == MPI_UNDEFINED, mpi_errno, MPI_ERR_OTHER, "**rank");

    is_commutative = MPIR_Op_is_commutative(op);

    /* need to allocate temporary buffer to store incoming data */
    MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent);
    MPIR_Datatype_get_extent_macro(datatype, extent);

    MPIR_Ensure_Aint_fits_in_pointer(count * MPL_MAX(extent, true_extent));
    MPIR_CHKLMEM_MALLOC(tmp_buf, void *, count * (MPL_MAX(extent, true_extent)), mpi_errno,
                        "temporary buffer", MPL_MEM_BUFFER);

    /* adjust for potential negative lower bound in datatype */
    tmp_buf = (void *) ((char *) tmp_buf - true_lb);

    /* copy local data into recvbuf; all further work happens in recvbuf */
    if (sendbuf != MPI_IN_PLACE) {
        mpi_errno = MPIR_Localcopy(sendbuf, count, datatype, recvbuf, count, datatype);
        if (mpi_errno) MPIR_ERR_POP(mpi_errno);
    }

    MPIR_Datatype_get_size_macro(datatype, type_size);

    /* get nearest power-of-two less than or equal to group_size */
    pof2 = MPL_pof2(group_size);

    /* number of "extra" processes beyond the power of two */
    rem = group_size - pof2;

    /* In the non-power-of-two case, all even-numbered
     * processes of rank < 2*rem send their data to
     * (rank+1). These even-numbered processes no longer
     * participate in the algorithm until the very end. The
     * remaining processes form a nice power-of-two. */

    if (group_rank < 2 * rem) {
        if (group_rank % 2 == 0) {      /* even */
            to_comm_rank(cdst, group_rank + 1);
            mpi_errno = MPIC_Send(recvbuf, count, datatype, cdst, tag, comm_ptr, errflag);
            if (mpi_errno) {
                /* for communication errors, just record the error but continue */
                *errflag =
                    MPIX_ERR_PROC_FAILED ==
                    MPIR_ERR_GET_CLASS(mpi_errno) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER;
                MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
                MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
            }

            /* temporarily set the rank to -1 so that this
             * process does not participate in recursive
             * doubling */
            newrank = -1;
        } else {        /* odd */
            to_comm_rank(csrc, group_rank - 1);
            mpi_errno = MPIC_Recv(tmp_buf, count, datatype, csrc, tag, comm_ptr,
                                  MPI_STATUS_IGNORE, errflag);
            if (mpi_errno) {
                /* for communication errors, just record the error but continue */
                *errflag =
                    MPIX_ERR_PROC_FAILED ==
                    MPIR_ERR_GET_CLASS(mpi_errno) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER;
                MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
                MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
            }

            /* do the reduction on received data. since the
             * ordering is right, it doesn't matter whether
             * the operation is commutative or not. */
            /* NOTE(review): this runs even if the receive above failed, in
             * which case tmp_buf holds data that was never received; the
             * failure is already reported through *errflag/mpi_errno_ret. */
            mpi_errno = MPIR_Reduce_local(tmp_buf, recvbuf, count, datatype, op);
            if (mpi_errno) MPIR_ERR_POP(mpi_errno);

            /* change the rank */
            newrank = group_rank / 2;
        }
    } else      /* rank >= 2*rem */
        newrank = group_rank - rem;

    /* Use the recursive doubling algorithm if any of these holds:
     *   - the message is short (count * type_size <=
     *     MPIR_CVAR_ALLREDUCE_SHORT_MSG_SIZE),
     *   - op is user-defined (not a builtin handle), or
     *   - count is less than pof2.
     * Otherwise do a reduce-scatter followed by an allgather.
     * (If op is user-defined, derived datatypes are allowed, and the user
     * could pass basic datatypes on one process and derived on another as
     * long as the type maps are the same. Breaking up derived datatypes to do
     * the reduce-scatter is tricky, so recursive doubling is used in that
     * case.) */

    if (newrank != -1) {
        if ((count * type_size <= MPIR_CVAR_ALLREDUCE_SHORT_MSG_SIZE) ||
            (HANDLE_GET_KIND(op) != HANDLE_KIND_BUILTIN) || (count < pof2)) {
            /* use recursive doubling */
            mask = 0x1;
            while (mask < pof2) {
                /* partner in the power-of-two numbering */
                newdst = newrank ^ mask;
                /* find real rank of dest: newranks below rem are the odd
                 * group ranks 1, 3, ..., the rest are shifted up by rem */
                dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
                to_comm_rank(cdst, dst);

                /* Send the most current data, which is in recvbuf. Recv
                 * into tmp_buf */
                mpi_errno = MPIC_Sendrecv(recvbuf, count, datatype,
                                          cdst, tag, tmp_buf,
                                          count, datatype, cdst,
                                          tag, comm_ptr, MPI_STATUS_IGNORE, errflag);
                if (mpi_errno) {
                    /* for communication errors, just record the error but continue */
                    *errflag =
                        MPIX_ERR_PROC_FAILED ==
                        MPIR_ERR_GET_CLASS(mpi_errno) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER;
                    MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
                    MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
                } else {
                    /* tmp_buf contains data received in this step.
                     * recvbuf contains data accumulated so far */

                    if (is_commutative || (dst < group_rank)) {
                        /* op is commutative OR the order is already right */
                        mpi_errno = MPIR_Reduce_local(tmp_buf, recvbuf, count, datatype, op);
                        if (mpi_errno) MPIR_ERR_POP(mpi_errno);
                    } else {
                        /* op is noncommutative and the order is not right */
                        mpi_errno = MPIR_Reduce_local(recvbuf, tmp_buf, count, datatype, op);
                        if (mpi_errno) MPIR_ERR_POP(mpi_errno);

                        /* copy result back into recvbuf */
                        mpi_errno = MPIR_Localcopy(tmp_buf, count, datatype,
                                                   recvbuf, count, datatype);
                        if (mpi_errno) MPIR_ERR_POP(mpi_errno);
                    }
                }
                mask <<= 1;
            }
        } else {

            /* do a reduce-scatter followed by allgather */

            /* for the reduce-scatter, calculate the count that
             * each process receives and the displacement within
             * the buffer */

            MPIR_CHKLMEM_MALLOC(cnts, int *, pof2 * sizeof(int), mpi_errno, "counts",
                                MPL_MEM_BUFFER);
            MPIR_CHKLMEM_MALLOC(disps, int *, pof2 * sizeof(int), mpi_errno, "displacements",
                                MPL_MEM_BUFFER);

            /* the data is split into pof2 blocks of count/pof2 elements;
             * the last block also takes the remainder */
            for (i = 0; i < (pof2 - 1); i++)
                cnts[i] = count / pof2;
            cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);

            /* disps[i] is the element offset of block i (prefix sum of cnts) */
            if (pof2)
                disps[0] = 0;
            for (i = 1; i < pof2; i++)
                disps[i] = disps[i - 1] + cnts[i - 1];

            /* Reduce-scatter.  At each step this process exchanges half of its
             * current block range with its partner.  The lower newrank keeps
             * the lower half and the higher newrank keeps the upper half.  The
             * kept half is received into tmp_buf and reduced into recvbuf;
             * the other half is sent from recvbuf.  After the loop, the blocks
             * starting at recv_idx hold fully reduced data. */
            mask = 0x1;
            send_idx = recv_idx = 0;
            last_idx = pof2;
            while (mask < pof2) {
                newdst = newrank ^ mask;
                /* find real rank of dest */
                dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
                to_comm_rank(cdst, dst);

                send_cnt = recv_cnt = 0;
                if (newrank < newdst) {
                    send_idx = recv_idx + pof2 / (mask * 2);
                    for (i = send_idx; i < last_idx; i++)
                        send_cnt += cnts[i];
                    for (i = recv_idx; i < send_idx; i++)
                        recv_cnt += cnts[i];
                } else {
                    recv_idx = send_idx + pof2 / (mask * 2);
                    for (i = send_idx; i < recv_idx; i++)
                        send_cnt += cnts[i];
                    for (i = recv_idx; i < last_idx; i++)
                        recv_cnt += cnts[i];
                }

                /* Send data from recvbuf. Recv into tmp_buf */
                mpi_errno = MPIC_Sendrecv((char *) recvbuf +
                                          disps[send_idx] * extent,
                                          send_cnt, datatype,
                                          cdst, tag,
                                          (char *) tmp_buf +
                                          disps[recv_idx] * extent,
                                          recv_cnt, datatype, cdst,
                                          tag, comm_ptr, MPI_STATUS_IGNORE, errflag);
                if (mpi_errno) {
                    /* for communication errors, just record the error but continue */
                    *errflag =
                        MPIX_ERR_PROC_FAILED ==
                        MPIR_ERR_GET_CLASS(mpi_errno) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER;
                    MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
                    MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
                }

                /* tmp_buf contains data received in this step.
                 * recvbuf contains data accumulated so far */

                /* This algorithm is used only for predefined ops
                 * and predefined ops are always commutative. */
                /* NOTE(review): unlike the recursive-doubling path, this
                 * reduction runs even when the Sendrecv above failed. */
                mpi_errno = MPIR_Reduce_local(((char *) tmp_buf + disps[recv_idx] * extent),
                                              ((char *) recvbuf + disps[recv_idx] * extent),
                                              recv_cnt, datatype, op);
                if (mpi_errno) MPIR_ERR_POP(mpi_errno);

                /* update send_idx for next iteration */
                send_idx = recv_idx;
                mask <<= 1;

                /* update last_idx, but not in last iteration
                 * because the value is needed in the allgather
                 * step below. */
                if (mask < pof2)
                    last_idx = recv_idx + pof2 / mask;
            }

            /* now do the allgather */

            /* Allgather: replay the reduce-scatter steps in reverse (mask
             * going down).  Each step sends this process's reduced range and
             * receives the partner's, directly into recvbuf, until every
             * block is present. */
            mask >>= 1;
            while (mask > 0) {
                newdst = newrank ^ mask;
                /* find real rank of dest */
                dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
                to_comm_rank(cdst, dst);

                send_cnt = recv_cnt = 0;
                if (newrank < newdst) {
                    /* update last_idx except on first iteration */
                    if (mask != pof2 / 2)
                        last_idx = last_idx + pof2 / (mask * 2);

                    recv_idx = send_idx + pof2 / (mask * 2);
                    for (i = send_idx; i < recv_idx; i++)
                        send_cnt += cnts[i];
                    for (i = recv_idx; i < last_idx; i++)
                        recv_cnt += cnts[i];
                } else {
                    recv_idx = send_idx - pof2 / (mask * 2);
                    for (i = send_idx; i < last_idx; i++)
                        send_cnt += cnts[i];
                    for (i = recv_idx; i < send_idx; i++)
                        recv_cnt += cnts[i];
                }

                mpi_errno = MPIC_Sendrecv((char *) recvbuf +
                                          disps[send_idx] * extent,
                                          send_cnt, datatype,
                                          cdst, tag,
                                          (char *) recvbuf +
                                          disps[recv_idx] * extent,
                                          recv_cnt, datatype, cdst,
                                          tag, comm_ptr, MPI_STATUS_IGNORE, errflag);
                if (mpi_errno) {
                    /* for communication errors, just record the error but continue */
                    *errflag =
                        MPIX_ERR_PROC_FAILED ==
                        MPIR_ERR_GET_CLASS(mpi_errno) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER;
                    MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
                    MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
                }

                /* the higher rank's owned range now starts at the block it
                 * just received */
                if (newrank > newdst)
                    send_idx = recv_idx;

                mask >>= 1;
            }
        }
    }

    /* In the non-power-of-two case, all odd-numbered
     * processes of rank < 2*rem send the result to
     * (rank-1), the ranks who didn't participate above. */
    if (group_rank < 2 * rem) {
        if (group_rank % 2) {   /* odd */
            to_comm_rank(cdst, group_rank - 1);
            mpi_errno = MPIC_Send(recvbuf, count, datatype, cdst, tag, comm_ptr, errflag);
        } else {        /* even */
            to_comm_rank(csrc, group_rank + 1);
            mpi_errno = MPIC_Recv(recvbuf, count, datatype, csrc, tag, comm_ptr,
                                  MPI_STATUS_IGNORE, errflag);
        }
        if (mpi_errno) {
            /* for communication errors, just record the error but continue */
            *errflag =
                MPIX_ERR_PROC_FAILED ==
                MPIR_ERR_GET_CLASS(mpi_errno) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER;
            MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
            MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
        }
    }

  fn_exit:
    /* release tmp_buf, cnts and disps on every path */
    MPIR_CHKLMEM_FREEALL();
    /* recorded communication errors take precedence over mpi_errno */
    if (mpi_errno_ret)
        mpi_errno = mpi_errno_ret;
    else if (*errflag != MPIR_ERR_NONE)
        MPIR_ERR_SET(mpi_errno, *errflag, "**coll_fail");
    return (mpi_errno);
  fn_fail:
    goto fn_exit;
}

/*
 * MPII_Allreduce_group - entry point for the group-based allreduce
 *
 * Rejects anything other than an intracommunicator with a "**commnotintra"
 * error.  Otherwise it forwards all arguments unchanged to
 * MPII_Allreduce_group_intra() and returns that function's error code.
 */
#undef FUNCNAME
#define FUNCNAME MPII_Allreduce_group
#undef FCNAME
#define FCNAME MPL_QUOTE(FUNCNAME)
int MPII_Allreduce_group(void *sendbuf, void *recvbuf, int count,
                         MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr,
                         MPIR_Group * group_ptr, int tag, MPIR_Errflag_t * errflag)
{
    int mpi_errno = MPI_SUCCESS;

    MPIR_ERR_CHKANDJUMP(comm_ptr->comm_kind != MPIR_COMM_KIND__INTRACOMM,
                        mpi_errno, MPI_ERR_OTHER, "**commnotintra");

    mpi_errno = MPII_Allreduce_group_intra(sendbuf, recvbuf, count, datatype,
                                           op, comm_ptr, group_ptr, tag, errflag);
    if (mpi_errno) MPIR_ERR_POP(mpi_errno);

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}