# ... (excerpt starts here; the other module-level globals referenced below,
# e.g. _DATA_PARALLEL_GROUP and _MPU_TOPOLOGY, are presumably defined in the
# part of the file that is not shown -- TODO confirm)
_FP32_ALLREDUCE = None


def is_unitialized():
    """Useful for code segments that may be accessed with or without mpu initialization"""
    return _DATA_PARALLEL_GROUP is None


def initialize_model_parallel(model_parallel_size, topology=None, fp32_allreduce=False):
    """
    Initialize model data parallel groups.

    Arguments:
        model_parallel_size: number of GPUs used to parallelize model.
        topology: optional object exposing ``get_axis_comm_lists(axis)``,
            ``get_dim(axis)`` and ``filter_match(...)``. When given, the
            data/pipe/model (and IO) groups are built from it instead of
            from the plain rank arithmetic described below.
        fp32_allreduce: value stored in the module-level ``_FP32_ALLREDUCE``.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model. The present function will
    create 4 model parallel groups and 2 data parallel groups as:
        4 model parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 data parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.

    Raises:
        AssertionError: if torch.distributed is not initialized, or if the
            data/model parallel groups or fp32_allreduce were already set.
        ValueError: if the world size is smaller than model_parallel_size.
    """
    global _MPU_TOPOLOGY
    global _DATA_PARALLEL_GROUP
    global _PIPE_PARALLEL_GROUP
    global _IO_PARALLEL_GROUP
    global _MODEL_PARALLEL_GROUP
    global _FP32_ALLREDUCE

    # Check the backend before querying it: get_rank() is only meaningful
    # once the default process group exists.
    assert torch.distributed.is_initialized()
    rank = torch.distributed.get_rank()
    if rank == 0:
        print("> initializing model parallel with size {}".format(model_parallel_size))

    # Get world size. Ensure some consistencies.
    world_size = torch.distributed.get_world_size()
    if world_size < model_parallel_size:
        raise ValueError("world size cannot be smaller than model parallel size")
    ensure_divisibility(world_size, model_parallel_size)

    if topology:
        _MPU_TOPOLOGY = topology

    # Build the data parallel groups.
    # NOTE: every rank calls new_group() for every group, including the ones
    # it does not belong to; only the matching handle is kept.
    assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
    if topology:
        for dp_group in topology.get_axis_comm_lists("data"):
            group = torch.distributed.new_group(ranks=dp_group)
            if rank == 0:
                print("MPU DP:", dp_group)
            if rank in dp_group:
                _DATA_PARALLEL_GROUP = group
    else:
        for i in range(model_parallel_size):
            ranks = range(i, world_size, model_parallel_size)
            group = torch.distributed.new_group(ranks)
            if i == (rank % model_parallel_size):
                _DATA_PARALLEL_GROUP = group

    # Build pipeline parallel group
    if topology is not None:
        for pp_group in topology.get_axis_comm_lists("pipe"):
            group = torch.distributed.new_group(ranks=pp_group)
            if rank == 0:
                print("MPU PP:", pp_group)
            if rank in pp_group:
                _PIPE_PARALLEL_GROUP = group

    # Build IO group
    if topology and topology.get_dim("pipe") > 1:
        # First and last pipeline stages, restricted to model coordinate 0.
        io_stages = [0, topology.get_dim("pipe") - 1]
        io_group = []
        for stage in io_stages:
            io_group.extend(topology.filter_match(pipe=stage, model=0))
        if rank == 0:
            print("MPU IO:", io_group)
        group = torch.distributed.new_group(ranks=io_group)
        if rank in io_group:
            _IO_PARALLEL_GROUP = group
    else:
        _IO_PARALLEL_GROUP = get_data_parallel_group()

    # Build the model parallel groups.
    assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
    if topology:
        if model_parallel_size == 1:
            # Case without model parallelism: one single-rank group per rank.
            # TODO: it would be nice to avoid this branching case?
            # This used to `return` here, which skipped the _FP32_ALLREDUCE
            # assignment at the end of the function; it now falls through.
            for group_rank in range(world_size):
                group = torch.distributed.new_group(ranks=[group_rank])
                if rank == 0:
                    print("MPU MP:", [group_rank])
                if rank == group_rank:
                    _MODEL_PARALLEL_GROUP = group
        else:
            for mp_group in topology.get_axis_comm_lists("model"):
                group = torch.distributed.new_group(ranks=mp_group)
                if rank == 0:
                    print("MPU MP:", mp_group)
                if rank in mp_group:
                    _MODEL_PARALLEL_GROUP = group
    else:
        for i in range(world_size // model_parallel_size):
            ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size)
            group = torch.distributed.new_group(ranks)
            if i == (rank // model_parallel_size):
                _MODEL_PARALLEL_GROUP = group

    assert _FP32_ALLREDUCE is None, "fp32_allreduce is already initialized"
    _FP32_ALLREDUCE = fp32_allreduce


def model_parallel_is_initialized():
    """Check if model and data parallel groups are initialized."""
    return _MODEL_PARALLEL_GROUP is not None and _DATA_PARALLEL_GROUP is not None


def get_model_parallel_group():
    """Get the model parallel group the caller rank belongs to."""
    assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized"
    return _MODEL_PARALLEL_GROUP


def get_data_parallel_group():
    """Get the data parallel group the caller rank belongs to."""
    assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized"
    return _DATA_PARALLEL_GROUP


def get_io_parallel_group():
    """Get the IO parallel group the caller rank belongs to."""
    assert _IO_PARALLEL_GROUP is not None, "IO parallel group is not initialized"
    return _IO_PARALLEL_GROUP


def set_model_parallel_world_size(world_size):
    """Set the model parallel size"""
    global _MPU_WORLD_SIZE
    _MPU_WORLD_SIZE = world_size


def get_model_parallel_world_size():
    """Return world size for the model parallel group.

    An explicitly set override (see set_model_parallel_world_size) wins over
    the size reported by torch.distributed for the model parallel group.
    """
    if _MPU_WORLD_SIZE is not None:
        return _MPU_WORLD_SIZE
    return torch.distributed.get_world_size(group=get_model_parallel_group())


def set_model_parallel_rank(rank):
    """Set model parallel rank."""
    global _MPU_RANK
    _MPU_RANK = rank


def get_model_parallel_rank():
    """Return my rank for the model parallel group.

    An explicitly set override (see set_model_parallel_rank) wins over the
    rank reported by torch.distributed for the model parallel group.
    """
    if _MPU_RANK is not None:
        return _MPU_RANK
    return torch.distributed.get_rank(group=get_model_parallel_group())


def get_model_parallel_src_rank():
    """Calculate the global rank corresponding to a local rank zero
    in the model parallel group."""
    global_rank = torch.distributed.get_rank()
    local_world_size = get_model_parallel_world_size()
    return (global_rank // local_world_size) * local_world_size


def get_data_parallel_src_rank():
    """Calculate the global rank corresponding to a local rank zero
    in the data parallel group.

    Returns None when a topology is set but the calling rank appears in none
    of its "data" axis lists.
    """
    global_rank = torch.distributed.get_rank()
    topo = get_topology()
    if topo is None:
        # we are just using model parallel
        return global_rank % get_model_parallel_world_size()
    # We are using pipeline parallel
    for ranks in topo.get_axis_comm_lists("data"):
        if global_rank in ranks:
            return ranks[0]
    return None


def get_data_parallel_world_size():
    """Return world size for the data parallel group."""
    return torch.distributed.get_world_size(group=get_data_parallel_group())


def get_data_parallel_rank():
    """Return my rank for the data parallel group."""
    return torch.distributed.get_rank(group=get_data_parallel_group())


def get_topology():
    """Return the topology object recorded by initialize_model_parallel."""
    return _MPU_TOPOLOGY


def get_pipe_parallel_group():
    """Get the pipe parallel group the caller rank belongs to."""
    assert _PIPE_PARALLEL_GROUP is not None, "pipe parallel group is not initialized"
    return _PIPE_PARALLEL_GROUP


def get_pipe_parallel_rank():
    """Return my rank for the pipe parallel group."""
    return torch.distributed.get_rank(group=get_pipe_parallel_group())


def get_pipe_parallel_world_size():
    """Return world size for the pipe parallel group."""
    return torch.distributed.get_world_size(group=get_pipe_parallel_group())


def destroy_model_parallel():
    """Set the groups to none."""
    global _MODEL_PARALLEL_GROUP
    _MODEL_PARALLEL_GROUP = None
    global _DATA_PARALLEL_GROUP
    _DATA_PARALLEL_GROUP = None
    global _PIPE_PARALLEL_GROUP
    _PIPE_PARALLEL_GROUP = None
    global _IO_PARALLEL_GROUP
    _IO_PARALLEL_GROUP = None
    global _MPU_WORLD_SIZE
    global _MPU_RANK
    _MPU_WORLD_SIZE = None
    _MPU_RANK = None
    global _MPU_TOPOLOGY
    # NOTE(review): the excerpt is truncated at this point; the remaining
    # statements of this function are not visible and are left as-is.
    ...

from mpi4py.MPI import Request
import numpy as np


# Time evolution for the inner part of the grid
def exchange_init(u, parallel):
    """Start the eight boundary transfers of ``u`` with the four neighbours.

    Four ``Isend`` and four ``Irecv`` calls are issued on ``parallel.comm``
    and the returned request objects are stored in ``parallel.requests[0]``
    through ``parallel.requests[7]``.

    Row transfers use ``parallel.rowtype`` and column transfers use
    ``parallel.columntype``; both are presumably MPI derived datatypes
    created elsewhere (TODO confirm). Column buffers are addressed through
    ``u.ravel()``, which only aliases ``u`` if ``u`` is contiguous --
    assumed here, verify against the caller.
    """
    row = parallel.rowtype
    col = parallel.columntype
    last = u.shape[1] - 1  # ny + 1
    inner_last = u.shape[1] - 2  # ny

    # (direction, buffer, datatype, neighbour) in request-slot order.
    transfers = (
        # Send to the up, receive from down
        ("send", u[1, :], row, parallel.nup),
        ("recv", u[-1, :], row, parallel.ndown),
        # Send to the down, receive from up
        ("send", u[-2, :], row, parallel.ndown),
        ("recv", u[0, :], row, parallel.nup),
        # Send to the left, receive from right
        ("send", u.ravel()[1:], col, parallel.nleft),
        ("recv", u.ravel()[last:], col, parallel.nright),
        # Send to the right, receive from left
        ("send", u.ravel()[inner_last:], col, parallel.nright),
        ("recv", u, col, parallel.nleft),
    )

    for slot, (direction, buf, datatype, peer) in enumerate(transfers):
        if direction == "send":
            parallel.requests[slot] = parallel.comm.Isend((buf, 1, datatype),
                                                          dest=peer)
        else:
            parallel.requests[slot] = parallel.comm.Irecv((buf, 1, datatype),
                                                          source=peer)


def exchange_finalize(parallel):
    # MPI.Request.Waitall(parallel.requests)
    # NOTE(review): the excerpt is truncated at this point; the rest of this
    # function is not visible and is left as-is.
    ...

