[mpich-devel] Load imbalance in MPIR_Reduce_redscat_gather

Kenneth Raffenetti raffenet at mcs.anl.gov
Thu May 25 10:05:52 CDT 2017


Hi Mikhail,

Sorry for the delayed response. I have created an issue on our GitHub
for your submission. Because this is a rather substantial change, we
should get a signed contributor agreement from you. Please take a look
at the individual form here:
http://www.mpich.org/documentation/contributor-docs/

Even better, if you could submit your patch as a GitHub PR, that would
speed up the acceptance process.

Thanks,
Ken

On 03/29/2017 11:46 PM, Mikhail Kurnosov wrote:
> /*
>   * An implementation of Rabenseifner's reduce algorithm [1, 2].
>   *
>   * This algorithm is a combination of a reduce-scatter implemented with
>   * recursive vector halving and recursive distance doubling, followed
>   * by a binomial tree gather [1].
>   *
>   * Step 1. If the number of processes is not a power of two, reduce it to
>   * the nearest lower power of two (p' = 2^{\lfloor\log_2 p\rfloor})
>   * by removing r = p - p' extra processes as follows. In the first 2r processes
>   * (ranks 0 to 2r - 1), all the even ranks send the second half of the input
>   * vector to their right neighbor (rank + 1), and all the odd ranks send
>   * the first half of the input vector to their left neighbor (rank - 1).
>   * The even ranks compute the reduction on the first half of the vector and
>   * the odd ranks compute the reduction on the second half. The odd ranks then
>   * send the result to their left neighbors (the even ranks). As a result,
>   * the even ranks among the first 2r processes now contain the reduction with
>   * the input vector on their right neighbors (the odd ranks). These odd ranks
>   * do not participate in the rest of the algorithm, which leaves behind
>   * a power-of-two number of processes. The first r even-ranked processes and
>   * the last p - 2r processes are now renumbered from 0 to p' - 1.
>   *
>   * Step 2. The remaining processes now perform a reduce-scatter by using
>   * recursive vector halving and recursive distance doubling. The even-ranked
>   * processes send the second half of their buffer to rank + 1 and the odd-ranked
>   * processes send the first half of their buffer to rank - 1. All processes
>   * then compute the reduction between the local buffer and the received buffer.
>   * In the next log_2(p') - 1 steps, the buffers are recursively halved, and the
>   * distance is doubled. At the end, each of the p' processes has 1/p' of the
>   * total reduction result.
>   *
>   * Step 3. A binomial tree gather is performed by using recursive vector
>   * doubling and distance halving. In the non-power-of-two case, if the root
>   * happens to be one of those odd-ranked processes that would normally
>   * be removed in the first step, then the roles of this process and process 0
>   * are interchanged.
>   *
>   * Limitations: commutative operations only, count >= 2^{\lfloor\log_2 p\rfloor}
>   * Recommendations: root = 0, otherwise additional steps are required
>   *                  in the root process.
>   *
>   * Memory consumption (per process):
>   * 1) rank != root: 2 * count * typesize + 4 * log2(p) * sizeof(int) = O(count)
>   * 2) rank == root: count * typesize + 4 * log2(p) * sizeof(int) = O(count)
>   *
>   * [1] Rajeev Thakur, Rolf Rabenseifner and William Gropp.
>   *     Optimization of Collective Communication Operations in MPICH //
>   *     The Int. Journal of High Performance Computing Applications. Vol 19,
>   *     Issue 1, pp. 49--66.
>   * [2] http://www.hlrs.de/mpi/myreduce.html.
>   */
> #undef FUNCNAME
> #define FUNCNAME MPIR_Reduce_redscat_gather
> #undef FCNAME
> #define FCNAME MPL_QUOTE(FUNCNAME)
> static int MPIR_Reduce_redscat_gather(
>      const void *sendbuf,
>      void *recvbuf,
>      int count,
>      MPI_Datatype datatype,
>      MPI_Op op,
>      int root,
>      MPID_Comm *comm_ptr,
>      MPIR_Errflag_t *errflag )
> {
>      int mpi_errno = MPI_SUCCESS;
>      int mpi_errno_ret = MPI_SUCCESS;
>      int comm_size, rank, type_size ATTRIBUTE((unused)), pof2, rem, newrank;
>      int mask, i, j, newdst, dst, nsteps, step, wsize;
>      int newroot, newdst_tree_root, newroot_tree_root;
>      MPI_Aint true_lb, true_extent, extent;
>      void *tmp_buf;
>      int *rindex, *rcount, *sindex, *scount, count_lhalf, count_rhalf;
> 
>      MPIU_CHKLMEM_DECL(6);
>      MPID_THREADPRIV_DECL;
> 
>      comm_size = comm_ptr->local_size;
>      rank = comm_ptr->rank;
> 
>      /* Set op_errno to 0. Stored in perthread structure */
>      MPID_THREADPRIV_GET;
>      MPID_THREADPRIV_FIELD(op_errno) = 0;
> 
>      /* Create a temporary buffer */
>      MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent);
>      MPID_Datatype_get_extent_macro(datatype, extent);
> 
>      /* I think this is the worst case, so we can avoid an assert()
>       * inside the for loop. Should this be buf+{this}?
>       */
>      MPIU_Ensure_Aint_fits_in_pointer(count * MPIR_MAX(extent, true_extent));
> 
>      MPIU_CHKLMEM_MALLOC(tmp_buf, void *, count*(MPIR_MAX(extent, true_extent)),
>                          mpi_errno, "temporary buffer");
>      /* Adjust for potential negative lower bound in datatype */
>      tmp_buf = (void *)((char*)tmp_buf - true_lb);
> 
>      /* If I'm not the root, then my recvbuf may not be valid, therefore
>       * I have to allocate a temporary one */
>      if (rank != root) {
>          MPIU_CHKLMEM_MALLOC(recvbuf, void *,
>                              count * (MPIR_MAX(extent, true_extent)),
>                              mpi_errno, "receive buffer");
>          recvbuf = (void *)((char*)recvbuf - true_lb);
>      }
> 
>      if ((rank != root) || (sendbuf != MPI_IN_PLACE)) {
>          mpi_errno = MPIR_Localcopy(sendbuf, count, datatype, recvbuf,
>                                     count, datatype);
>          if (mpi_errno) { MPIR_ERR_POP(mpi_errno); }
>      }
> 
>      MPID_Datatype_get_size_macro(datatype, type_size);
> 
>      /*
>       * Step 1. Reduce the number of processes to the nearest lower power of two
>       * (p' = 2^{\lfloor\log_2 p\rfloor}) by removing r = p - p' processes.
>       * 1. In the first 2r processes (ranks 0 to 2r - 1), all the even ranks send
>       *    the second half of the input vector to their right neighbor (rank + 1)
>       *    and all the odd ranks send the first half of the input vector to their
>       *    left neighbor (rank - 1).
>       * 2. All 2r processes compute the reduction on their half.
>       * 3. The odd ranks then send the result to their left neighbors
>       *    (the even ranks).
>       *
>       * The even ranks (0 to 2r - 1) now contain the reduction with the input
>       * vector on their right neighbors (the odd ranks). The first r even
>       * processes and the p - 2r last processes are renumbered from
>       * 0 to p' - 1. These odd ranks do not participate in the
>       * rest of the algorithm.
>       */
> 
>      /* Find nearest power-of-two less than or equal to comm_size */
>      pof2 = 1;
>      nsteps = -1;
>      while (pof2 <= comm_size) {  /* O(log(p)), FIXME: use flp2 and ilog2 */
>          pof2 <<= 1;
>          nsteps++;
>      }
>      pof2 >>= 1;
> 
>      rem = comm_size - pof2;
>      if (rank < 2 * rem) {
>          count_lhalf = count / 2;
>          count_rhalf = count - count_lhalf;
> 
>          if (rank % 2 != 0) { /* odd process -- exchange with rank - 1 */
>              /*
>               * Send the left half of the input vector to the left neighbor,
>               * Recv the right half of the input vector from the left neighbor
>               */
>              mpi_errno = MPIC_Sendrecv(recvbuf, count_lhalf, datatype,
>                                        rank - 1, MPIR_REDUCE_TAG,
>                                        (char *)tmp_buf + count_lhalf * extent,
>                                        count_rhalf, datatype, rank - 1,
>                                        MPIR_REDUCE_TAG, comm_ptr,
>                                        MPI_STATUS_IGNORE, errflag);
>              if (mpi_errno) {
>                  /* for communication errors, just record the error but continue */
>                  *errflag = MPIR_ERR_GET_CLASS(mpi_errno);
>                  MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
>                  MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
>              }
> 
>              /* Reduce on the right half of the buffers (result in recvbuf) */
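>              mpi_errno = MPIR_Reduce_local_impl((char *)tmp_buf +
>                                                 count_lhalf * extent,
>                                                 (char *)recvbuf +
>                                                 count_lhalf * extent,
>                                                 count_rhalf, datatype, op);
>              /* Check the local reduce before mpi_errno is overwritten */
>              if (mpi_errno) { MPIR_ERR_POP(mpi_errno); }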
>              
>              /* Send the right half to the left neighbor */
>              mpi_errno = MPIC_Send((char *)recvbuf + count_lhalf * extent,
>                                    count_rhalf, datatype, rank - 1,
>                                    MPIR_REDUCE_TAG, comm_ptr, errflag);
>              if (mpi_errno) {
>                  /* for communication errors, just record the error but continue */
>                  *errflag = MPIR_ERR_GET_CLASS(mpi_errno);
>                  MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
>                  MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
>              }
>              /* Temporarily set the rank to -1 so that this process does not
>                 participate in recursive doubling */
>              newrank = -1;
>              
>          } else { /* even process -- exchange with rank + 1 */
>              /*
>               * Send the right half of the input vector to the right neighbor,
>               * Recv the left half of the input vector from the right neighbor
>               */
>              mpi_errno = MPIC_Sendrecv((char *)recvbuf + count_lhalf * extent,
>                                        count_rhalf, datatype, rank + 1,
>                                        MPIR_REDUCE_TAG, tmp_buf, count_lhalf,
>                                        datatype, rank + 1, MPIR_REDUCE_TAG,
>                                        comm_ptr, MPI_STATUS_IGNORE, errflag);
>              if (mpi_errno) {
>                  /* for communication errors, just record the error but continue */
>                  *errflag = MPIR_ERR_GET_CLASS(mpi_errno);
>                  MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
>                  MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
>              }
> 
>              /* Reduce on the left half of the buffers (result in recvbuf) */
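>              mpi_errno = MPIR_Reduce_local_impl(tmp_buf, recvbuf, count_lhalf,
>                                                 datatype, op);
>              /* Check the local reduce before mpi_errno is overwritten */
>              if (mpi_errno) { MPIR_ERR_POP(mpi_errno); }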
> 
>              /* Recv the right half from the right neighbor */
>              mpi_errno = MPIC_Recv((char *)recvbuf + count_lhalf * extent,
>                                    count_rhalf, datatype, rank + 1,
>                                    MPIR_REDUCE_TAG, comm_ptr, MPI_STATUS_IGNORE,
>                                    errflag);
>              if (mpi_errno) {
>                  /* for communication errors, just record the error but continue */
>                  *errflag = MPIR_ERR_GET_CLASS(mpi_errno);
>                  MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
>                  MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
>              }
>              newrank = rank / 2;
>          }
>      } else { /* rank >= 2 * rem */
>          newrank = rank - rem;
>      }
> 
>      /*
>       * Step 2. Reduce-scatter implemented with recursive vector halving and
>       * recursive distance doubling. We have p' = 2^{\lfloor\log_2 p\rfloor}
>       * power-of-two number of processes with new ranks and result in recvbuf.
>       *
>       * The even-ranked processes send the right half of their buffer to rank + 1
>       * and the odd-ranked processes send the left half of their buffer to
>       * rank - 1. All processes then compute the reduction between the local
>       * buffer and the received buffer. In the next \log_2(p') - 1 steps, the
>       * buffers are recursively halved, and the distance is doubled. At the end,
>       * each of the p' processes has 1 / p' of the total reduction result.
>       */
>      MPIU_CHKLMEM_MALLOC(rindex, int *, nsteps * sizeof(*rindex), mpi_errno,
>                          "rindex buffer");
>      MPIU_CHKLMEM_MALLOC(rcount, int *, nsteps * sizeof(*rcount), mpi_errno,
>                          "rcount buffer");
>      MPIU_CHKLMEM_MALLOC(sindex, int *, nsteps * sizeof(*sindex), mpi_errno,
>                          "sindex buffer");
>      MPIU_CHKLMEM_MALLOC(scount, int *, nsteps * sizeof(*scount), mpi_errno,
>                          "scount buffer");
> 
>      if (newrank != -1) {
>          step = 0;
>          wsize = count;
>          sindex[0] = rindex[0] = 0;
> 
>          for (mask = 1; mask < pof2; mask <<= 1) {
>              /*
>               * On each iteration: rindex[step] = sindex[step] -- beginning of the
>               * current window. Length of the current window is stored in wsize.
>               */
>              newdst = newrank ^ mask;
>              /* Find real rank of dest */
>              dst = (newdst < rem) ? newdst * 2 : newdst + rem;
> 
>              if (rank < dst) {
>                  /* Recv into the left half of the current window, send the right
>                   * half of the window to the peer (perform reduce on the left
>                   * half of the current window)
>                   */
>                  rcount[step] = wsize / 2;
>                  scount[step] = wsize - rcount[step];
>                  sindex[step] = rindex[step] + rcount[step];
>              } else {
>                  /* Recv into the right half of the current window, send the left
>                   * half of the window to the peer (perform reduce on the right
>                   * half of the current window)
>                   */
>                  scount[step] = wsize / 2;
>                  rcount[step] = wsize - scount[step];
>                  rindex[step] = sindex[step] + scount[step];
>              }
> 
>              /* Send part of data from the recvbuf, recv into the tmp_buf */
>              mpi_errno = MPIC_Sendrecv((char *)recvbuf + sindex[step] * extent,
>                                        scount[step], datatype, dst,
>                                        MPIR_REDUCE_TAG,
>                                        (char *)tmp_buf + rindex[step] * extent,
>                                        rcount[step], datatype, dst,
>                                        MPIR_REDUCE_TAG, comm_ptr,
>                                        MPI_STATUS_IGNORE, errflag);
>              if (mpi_errno) {
>                  /* for communication errors, just record the error but continue */
>                  *errflag = MPIR_ERR_GET_CLASS(mpi_errno);
>                  MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
>                  MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
>              }
> 
>              /* Local reduce: recvbuf[] = tmp_buf[] <op> recvbuf[] */
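>              mpi_errno = MPIR_Reduce_local_impl((char *)tmp_buf +
>                                                 rindex[step] * extent,
>                                                 (char *)recvbuf +
>                                                 rindex[step] * extent,
>                                                 rcount[step], datatype, op);
>              /* Check the local reduce before mpi_errno is overwritten */
>              if (mpi_errno) { MPIR_ERR_POP(mpi_errno); }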
> 
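>              /* Move the current window to the received message.
>               * rindex/sindex hold only nsteps entries, so do not write
>               * past the end on the last iteration. */
>              if (step + 1 < nsteps) {
>                  rindex[step + 1] = rindex[step];
>                  sindex[step + 1] = rindex[step];
>              }
>              wsize = rcount[step];
>              step++;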
>          }
>      }
>      /*
>       * Assertion: each process has 1 / p' of the total reduction result:
>       *   rcount[nsteps - 1] elements in the recvbuf[rindex[nsteps - 1]...].
>       */
> 
>      /*
>       * Setup the root process for gather operation.
>       * Case 1: root < 2r and root is odd -- root process was excluded on step 1
>       *         Recv data from process 0, newroot = 0, newrank = 0
>       * Case 2: root < 2r and root is even: newroot = root / 2
>       * Case 3: root >= 2r: newroot = root - r
>       */
>      newroot = 0;
>      if (root < 2 * rem) {
>          if (root % 2 != 0) {
>              newroot = 0;
>              if (rank == root) {
>                  /* Case 1: root < 2r and root is odd -- root process was
>                   * excluded on step 1 (newrank == -1).
>                   * Recv data from process 0.
>                   */
>                  rindex[0] = 0;
>                  step = 0, wsize = count;
>                  for (mask = 1; mask < pof2; mask *= 2) {
>                      rcount[step] = wsize / 2;
>                      scount[step] = wsize - rcount[step];
>                      rindex[step] = 0;
>                      sindex[step] = rcount[step];
>                      step++;
>                      wsize /= 2;
>                  }
>                  mpi_errno = MPIC_Recv(recvbuf, rcount[nsteps - 1], datatype, 0,
>                                        MPIR_REDUCE_TAG, comm_ptr,
>                                        MPI_STATUS_IGNORE, errflag);
>                  if (mpi_errno) {
>                      /* for communication errors, just record the error but continue */
>                      *errflag = MPIR_ERR_GET_CLASS(mpi_errno);
>                      MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
>                      MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
>                  }
>                  newrank = 0;
>              } else if (newrank == 0) {
>                  /* Send data to the root */
>                  mpi_errno = MPIC_Send(recvbuf, rcount[nsteps - 1], datatype,
>                                        root, MPIR_REDUCE_TAG, comm_ptr, errflag);
>                  if (mpi_errno) {
>                      /* for communication errors, just record the error but continue */
>                      *errflag = MPIR_ERR_GET_CLASS(mpi_errno);
>                      MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
>                      MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
>                  }
>                  newrank = -1;
>              }
>          } else {
>              /* Case 2: root < 2r and root is even: newroot = root / 2 */
>              newroot = root / 2;
>          }
>      } else {
>          /* Case 3: root >= 2r: newroot = root - r */
>          newroot = root - rem;
>      }
> 
>      /*
>       * Step 3. Gather result at the newroot by the binomial tree algorithm.
>       * Each process has 1 / p' of the total reduction result:
>       *   rcount[nsteps - 1] elements in the recvbuf[rindex[nsteps - 1]...].
>       * All exchanges are executed in reverse order relative
>       * to recursive doubling (previous step).
>       */
>      if (newrank != -1) {
>          mask = pof2 >> 1;
>          step = nsteps - 1; /* step = ilog2(p') - 1 */
> 
>          while (mask > 0) {
>              newdst = newrank ^ mask;
>              /* Find real rank of dest */
>              dst = (newdst < rem) ? newdst * 2 : newdst + rem;
>              /* If root is playing the role of newdst=0, adjust for it */
>              if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
>                  dst = root;
> 
>              /* If the root of newdst's half of the tree is the
>                 same as the root of newroot's half of the tree,
>                 send to newdst and exit, else receive from newdst. */
>              newdst_tree_root = newdst >> step;
>              newdst_tree_root <<= step;
>              newroot_tree_root = newroot >> step;
>              newroot_tree_root <<= step;
>              
>              if (newdst_tree_root == newroot_tree_root) {
>                  /* Send data from recvbuf and exit */
>                  mpi_errno = MPIC_Send((char *)recvbuf + rindex[step] * extent,
>                                        rcount[step], datatype, dst,
>                                        MPIR_REDUCE_TAG, comm_ptr, errflag);
>                  if (mpi_errno) {
>                      /* for communication errors, just record the error but continue */
>                      *errflag = MPIR_ERR_GET_CLASS(mpi_errno);
>                      MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
>                      MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
>                  }
>                  break;
>              } else {
>                  /* Recv and continue */
>                  mpi_errno = MPIC_Recv((char *)recvbuf + sindex[step] * extent,
>                                        scount[step], datatype, dst,
>                                        MPIR_REDUCE_TAG, comm_ptr,
>                                        MPI_STATUS_IGNORE, errflag);
>                  if (mpi_errno) {
>                      /* for communication errors, just record the error but continue */
>                      *errflag = MPIR_ERR_GET_CLASS(mpi_errno);
>                      MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
>                      MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
>                  }
>              }
>              step--;
>              mask >>= 1;
>          }
>      }
> 
>      /* FIXME does this need to be checked after each uop invocation for
>         predefined operators? */
>      /* --BEGIN ERROR HANDLING-- */
>      if (MPID_THREADPRIV_FIELD(op_errno)) {
>          mpi_errno = MPID_THREADPRIV_FIELD(op_errno);
>          goto fn_fail;
>      }
>      /* --END ERROR HANDLING-- */
> 
> fn_exit:
>      MPIU_CHKLMEM_FREEALL();
>      if (mpi_errno_ret)
>          mpi_errno = mpi_errno_ret;
>      else if (*errflag != MPIR_ERR_NONE)
>          MPIR_ERR_SET(mpi_errno, *errflag, "**coll_fail");
>      return mpi_errno;
> fn_fail:
>      goto fn_exit;
> }

