# To profile with MPE, uncomment the two lines below; mpi4py.profile() has to
# be called before ``from mpi4py import MPI`` for the profiler to take effect.
#import mpi4py
#mpi4py.profile('mpe')
from mpi4py import MPI
import numpy as np  # NOTE(review): unused in this script; kept in case other code relies on it
import time

# Set to False to benchmark with the plain MPI.SUM operation instead of the
# artificially slowed-down one (_slow_sum) below.
USE_SLOW_OP = True


def make_intercomm(basecomm):
    """Split *basecomm* into two halves and connect them with an intercommunicator.

    Ranks ``[0, size // 2)`` of *basecomm* form group 0 and ranks
    ``[size // 2, size)`` form group 1, keeping their relative order. Each
    group's leader is its local rank 0.

    Returns ``(intercomm, intracomm, low_group)``. ``intracomm`` is the calling
    process's half. ``low_group`` equals the group colour (0 or 1).

    NOTE: despite its name, ``low_group`` is 0 for the lower half and 1 for the
    upper half. ``allreduce_inter_mpich`` only relies on the two halves
    receiving different values.
    """
    rank = basecomm.Get_rank()
    size = basecomm.Get_size()
    assert size > 1, "requires at least two processes"
    half = size // 2
    if rank < half:
        color = 0
        remote_leader = half  # rank of group 1's leader in basecomm
    else:
        color = 1
        remote_leader = 0     # rank of group 0's leader in basecomm
    local_leader = 0
    low_group = color
    intracomm = basecomm.Split(color, rank)
    intercomm = intracomm.Create_intercomm(local_leader, basecomm, remote_leader)
    return intercomm, intracomm, low_group


def reduce_binomial(obj, op, root, comm, tag):
    """Reduce *obj* over intracommunicator *comm* with a binomial tree.

    Values are combined towards rank 0 as ``op(lower_ranks, higher_ranks)``, so
    operand order follows rank order. If *root* is not 0, rank 0 forwards the
    result to *root*.

    Returns the reduced value on *root* and ``None`` on every other rank.
    """
    size = comm.size
    rank = comm.rank
    assert comm.is_intra
    assert 0 <= root < size
    result = obj
    mask = 1
    while mask < size:
        if rank & mask:
            # Hand the partial result to the parent and leave the tree.
            # Without this break a rank would keep sending and receiving at
            # higher masks, which produces redundant messages and op calls.
            comm.send(result, rank & ~mask, tag)
            break
        source = rank | mask
        if source < size:
            tmp = comm.recv(None, source, tag)
            result = op(result, tmp)
        mask <<= 1
    if root != 0:
        if rank == 0:
            comm.send(result, root, tag)
        elif rank == root:
            result = comm.recv(None, 0, tag)
    if rank != root:
        result = None
    return result


def reduce_inter(obj, op, root, comm, tag, localcomm):
    """Reduce over intercommunicator *comm* using the MPI intercomm root convention.

    - ``root == MPI.ROOT``: this process is the root. It receives the remote
      group's reduced value from remote rank 0 and returns it.
    - ``root == MPI.PROC_NULL``: this process is in the root's group but is not
      the root. It does nothing and returns ``None``.
    - Otherwise *root* is the root's rank in the remote group. The local group
      reduces *obj* onto its local rank 0 over *localcomm*, and that process
      sends the value to the root. Returns ``None``.

    Assumes *localcomm* is the local group of *comm*, so local rank 0 is the
    same process in both.
    """
    if root == MPI.PROC_NULL:
        return None
    if root == MPI.ROOT:
        return comm.recv(None, 0, tag)
    result = reduce_binomial(obj, op, 0, localcomm, tag)
    if comm.rank == 0:
        comm.send(result, root, tag)
    return None


def allreduce_inter_mpich(obj, op, comm, tag, localcomm, low_group):
    """Intercommunicator allreduce built from two intercomm reduces (MPICH-style).

    Each group reduces once onto rank 0 of the other group and once receives
    at its own rank 0. The two groups do these steps in opposite order
    (selected by *low_group*) so every send meets a matching receive. Local
    rank 0 then broadcasts the remote group's reduced value over *localcomm*.
    """
    root = MPI.ROOT if comm.rank == 0 else MPI.PROC_NULL
    remote_root = 0  # rank 0 of the other group
    if low_group:
        reduce_inter(obj, op, remote_root, comm, tag, localcomm)
        result = reduce_inter(obj, op, root, comm, tag, localcomm)
    else:
        result = reduce_inter(obj, op, root, comm, tag, localcomm)
        reduce_inter(obj, op, remote_root, comm, tag, localcomm)
    return localcomm.bcast(result, 0)


def allreduce_inter_dalcinl(obj, op, comm, tag, localcomm):
    """Intercommunicator allreduce: local reduce, leader exchange, local broadcast.

    Each group reduces onto its local rank 0. The two leaders then swap partial
    results with a single ``sendrecv`` over *comm*, and each leader broadcasts
    the remote group's value within *localcomm*.
    """
    result = reduce_binomial(obj, op, 0, localcomm, tag)
    if comm.rank == 0:
        result = comm.sendrecv(result, 0, tag, None, 0, tag)
    return localcomm.bcast(result, 0)


def _slow_sum(a, b):
    """Return ``MPI.SUM(a, b)`` after sleeping 1 s, so each op call shows in the timings."""
    time.sleep(1.0)
    return MPI.SUM(a, b)


def _benchmark(label, run, expected):
    """Time ``run()`` on every process of COMM_WORLD; rank 0 prints min/max.

    Asserts that ``run()`` returned *expected* before reporting.
    """
    world = MPI.COMM_WORLD
    world.Barrier()  # start all processes together so the timings are comparable
    tic = MPI.Wtime()
    result = run()
    toc = MPI.Wtime()
    assert result == expected
    elapsed = world.gather(toc - tic, root=0)
    if world.rank == 0:
        print("[%s] time: min=%e max=%e" % (label, min(elapsed), max(elapsed)))


def main():
    intercomm, localcomm, low_group = make_intercomm(MPI.COMM_WORLD)
    tag_ub = MPI.COMM_WORLD.Get_attr(MPI.TAG_UB)
    op = _slow_sum if USE_SLOW_OP else MPI.SUM
    obj = 1
    # Every process contributes 1, so the remote group's sum is its size.
    expected = intercomm.remote_size

    _benchmark(
        "mpich",
        lambda: allreduce_inter_mpich(obj, op, intercomm, tag_ub - 1,
                                      localcomm, low_group),
        expected,
    )
    _benchmark(
        "dalcinl",
        lambda: allreduce_inter_dalcinl(obj, op, intercomm, tag_ub - 2,
                                        localcomm),
        expected,
    )

    intercomm.Free()
    localcomm.Free()


if __name__ == "__main__":
    main()