Source: test_pass_enforce_sync.py
...
        mod = EnforceSync()(mod)

        def expected():
            """
            fn (%x: Tensor[(64, 128), float32]) {
                let %set_stream_comp = raf.op.set_stream(int64(0), int64(1));
                let %a1 = raf.op.atan(%x);
                let %a2 = (%a1,);
                let %add_event_comp = raf.op.add_event(int64(1), int64(1));
                let %set_stream_comp1 = raf.op.set_stream(int64(0), int64(1));
                let %wait_for_comp = raf.op.wait_event(int64(1), int64(4));
                let %a3 = raf.op._allreduce(%a2, str"sum", nullptr);
                let %add_event_comm = raf.op.add_event(int64(2), int64(4));
                let %set_stream_comp2 = raf.op.set_stream(int64(0), int64(1));
                let %wait_for_comm = raf.op.wait_event(int64(2), int64(1));
                let %a4 = raf.op.atan(%a3);
                %a4
            }
            """
            builder = ANFBuilder()
            x = extended_var("x", shape=shape)
            builder.set_stream(0, comp_stream)
            x_1 = builder.call("atan", [x])
            x_2 = builder.make_tuple([x_1])
            builder.add_event(1, comp_stream)
            builder.set_stream(0, comm_stream)
            builder.wait_event(1, comm_stream)
            x_2 = builder.call("_allreduce", [x_2, raf.ir.const("sum"), raf.ir.const(None)])
            builder.add_event(2, comm_stream)
            builder.set_stream(0, comp_stream)
            builder.wait_event(2, comp_stream)
            x_3 = builder.call("atan", [x_2])
            return tvm.relay.Function([x], builder.ret(x_3))

    assert tvm.ir.structural_equal(mod["main"], expected())
    dcfg.enable_data_parallel = False


@pytest.mark.skipif(not raf.build.with_cuda(), reason="CUDA is not enabled")
@pytest.mark.parametrize("shape,comp_stream,comm_stream", [[(64, 128), 1, 5]])
def test_parallel_allreduce(shape, comp_stream, comm_stream):
    dcfg = dist.get_config()
    dcfg.enable_data_parallel = True
    with Device("cuda(0)"):

        class Model(raf.Model):
            #    /-> atan -> allreduce -> atan -\
            # atan                               concat
            #    \-> relu -> allreduce -> atan -/
            def build(self):
                pass

            @raf.model.trace
            def forward(self, x):
                a0 = raf.atan(x)
                a1_a = raf.atan(a0)
                a1_b = raf.relu(a0)
                a2_a = raf.allreduce(a1_a)
                a2_b = raf.allreduce(a1_b)
                a3_a = raf.atan(a2_a)
                a3_b = raf.atan(a2_b)
                a4 = raf.concatenate([a3_a, a3_b])
                return a4

        model = Model()
        x, _ = randn(shape)
        mod = model._internal(x).mod
        mod = RAFSequential([EnforceSync()])(mod)

        def expected():
            """
            fn (%x: Tensor[(64, 128), float32]) {
                let %set_stream_comp = raf.op.set_stream(int64(0), int64(1));
                let %a1 = raf.op.atan(%x);
                let %a2 = raf.op.atan(%a1);
                let %a3 = (%a2,);
                let %add_event_comp = raf.op.add_event(int64(1), int64(1));
                let %set_stream_comm = raf.op.set_stream(int64(0), int64(4));
                let %wait_for_comp = raf.op.wait_event(int64(1), int64(4));
                let %a4 = raf.op._allreduce(%a3, str"sum", nullptr);
                let %add_event_comm = raf.op.add_event(int64(3), int64(4));
                let %set_stream_comp1 = raf.op.set_stream(int64(0), int64(1));
                let %wait_for_comm = raf.op.wait_event(int64(3), int64(1));
                let %a5 = raf.op.atan(%a4);
                let %a6 = raf.op.relu(%a1);
                let %a7 = (%a6,);
                let %add_event_comp1 = raf.op.add_event(int64(2), int64(1));
                let %set_stream_comm1 = raf.op.set_stream(int64(0), int64(4));
                let %wait_for_comp1 = raf.op.wait_event(int64(2), int64(4));
                let %a8 = raf.op._allreduce(%a7, str"sum", nullptr);
                let %add_event_comm1 = raf.op.add_event(int64(4), int64(4));
                let %set_stream_comp2 = raf.op.set_stream(int64(0), int64(1));
                let %wait_for_comm1 = raf.op.wait_event(int64(4), int64(1));
                let %a9 = raf.op.atan(%a8);
                let %a10 = (%a5, %a9);
                let %a11 = raf.op.concatenate(%a10, int64(0));
                %a11
            }
            """
            builder = ANFBuilder()
            x = extended_var("x", shape=shape)
            builder.set_stream(0, comp_stream)
            x_1 = builder.call("atan", [x])
            x_2 = builder.call("atan", [x_1])
            x_3i = builder.make_tuple([x_2])
            builder.add_event(1, comp_stream)
            builder.set_stream(0, comm_stream)
            builder.wait_event(1, comm_stream)
            x_3 = builder.call("_allreduce", [x_3i, raf.ir.const("sum"), raf.ir.const(None)])
            builder.add_event(2, comm_stream)
            builder.set_stream(0, comp_stream)
            builder.wait_event(2, comp_stream)
            x_4 = builder.call("atan", [x_3])
            x_5 = builder.call("relu", [x_1])
            x_6i = builder.make_tuple([x_5])
            builder.add_event(3, comp_stream)
            builder.set_stream(0, comm_stream)
            builder.wait_event(3, comm_stream)
            x_6 = builder.call("_allreduce", [x_6i, raf.ir.const("sum"), raf.ir.const(None)])
            builder.add_event(4, comm_stream)
            builder.set_stream(0, comp_stream)
            builder.wait_event(4, comp_stream)
            x_7 = builder.call("atan", [x_6])
            x_8i = builder.make_tuple([x_4, x_7])
            x_8 = builder.call("concatenate", [x_8i, raf.ir.const(0)])
            return tvm.relay.Function([x], builder.ret(x_8))

    assert tvm.ir.structural_equal(mod["main"], expected())
    dcfg.enable_data_parallel = False


@pytest.mark.skipif(not raf.build.with_cuda(), reason="CUDA is not enabled")
@pytest.mark.parametrize("shape,comp_stream,comm_stream", [[(64, 128), 1, 5]])
def test_redundant_comm_to_comp_sync(shape, comp_stream, comm_stream):
    dcfg = dist.get_config()
    dcfg.enable_data_parallel = True
    with Device("cuda(0)"):

        def construct_model_func():
            # order in ANF is shown in parentheses, synchronization (2)->(6) is not needed
            # atan (1) -> atan (3) -> allreduce (4) -> atan (5) -> concat (6)
            #    \                                           /
            #     -------->  allreduce (2) ---------------->
            builder = ANFBuilder()
            x = extended_var("x", shape=shape)
            x_1 = builder.call("atan", [x])
            x_2i = builder.make_tuple([x_1])
            x_2 = builder.call("_allreduce", [x_2i, raf.ir.const("sum"), raf.ir.const(None)])
            x_3 = builder.call("atan", [x_1])
            x_4i = builder.make_tuple([x_3])
            x_4 = builder.call("_allreduce", [x_4i, raf.ir.const("sum"), raf.ir.const(None)])
            x_5 = builder.call("atan", [x_4])
            x_6i = builder.make_tuple([x_5, x_2])
            x_6 = builder.call("concatenate", [x_6i, raf.ir.const(0)])
            return tvm.relay.Function([x], builder.ret(x_6))

        mod = tvm.IRModule()
        mod["main"] = construct_model_func()
        mod = RAFSequential([EnforceSync()])(mod)

        def expected():
            """
            fn (%x: Tensor[(64, 128), float32]) {
                let %set_stream_comp = raf.op.set_stream(int64(0), int64(1));
                let %v = raf.op.atan(%x);
                let %v1 = (%v,);
                let %add_event_comp = raf.op.add_event(int64(1), int64(1));
                let %set_stream_comm = raf.op.set_stream(int64(0), int64(4));
                let %wait_for_comp = raf.op.wait_event(int64(1), int64(4));
                let %v2 = raf.op._allreduce(%v1, str"sum", nullptr);
                let %set_stream_comp1 = raf.op.set_stream(int64(0), int64(1));
                let %v3 = raf.op.atan(%v);
                let %v4 = (%v3,);
                let %add_event_comp1 = raf.op.add_event(int64(2), int64(1));
                let %set_stream_comm1 = raf.op.set_stream(int64(0), int64(4));
                let %wait_for_comp1 = raf.op.wait_event(int64(2), int64(4));
                let %v5 = raf.op._allreduce(%v4, str"sum", nullptr);
                let %add_event_comm = raf.op.add_event(int64(3), int64(4));
                let %set_stream_comp2 = raf.op.set_stream(int64(0), int64(1));
                let %wait_for_comm = raf.op.wait_event(int64(3), int64(1));
                let %v6 = raf.op.atan(%v5);
                let %v7 = (%v6, %v2);
                let %v8 = raf.op.concatenate(%v7, int64(0));
                %v8
            }
            """
            builder = ANFBuilder()
            x = extended_var("x", shape=shape)
            builder.set_stream(0, comp_stream)
            x_1 = builder.call("atan", [x])
            x_2i = builder.make_tuple([x_1])
            builder.add_event(1, comp_stream)
            builder.set_stream(0, comm_stream)
            builder.wait_event(1, comm_stream)
            x_2 = builder.call("_allreduce", [x_2i, raf.ir.const("sum"), raf.ir.const(None)])
            builder.set_stream(0, comp_stream)
            x_3 = builder.call("atan", [x_1])
            x_4i = builder.make_tuple([x_3])
            builder.add_event(2, comp_stream)
            builder.set_stream(0, comm_stream)
            builder.wait_event(2, comm_stream)
            x_4 = builder.call("_allreduce", [x_4i, raf.ir.const("sum"), raf.ir.const(None)])
            builder.add_event(3, comm_stream)
            builder.set_stream(0, comp_stream)
            builder.wait_event(3, comp_stream)
            x_5 = builder.call("atan", [x_4])
            x_6i = builder.make_tuple([x_5, x_2])
            x_6 = builder.call("concatenate", [x_6i, raf.ir.const(0)])
            return tvm.relay.Function([x], builder.ret(x_6))

    assert tvm.ir.structural_equal(mod["main"], expected())
    dcfg.enable_data_parallel = False


@pytest.mark.skipif(not raf.build.with_cuda(), reason="CUDA is not enabled")
@pytest.mark.parametrize("shape,comp_stream,comm_stream", [[(64, 128), 1, 5]])
def test_redundant_comp_to_comm_sync(shape, comp_stream, comm_stream):
    dcfg = dist.get_config()
    dcfg.enable_data_parallel = True
    with Device("cuda(0)"):

        def construct_model_func():
            # order in ANF is shown in parentheses, synchronization (1)->(4) is not needed
            # atan (1) -> atan (2) -> allreduce (3) -> atan (5) -> concat (6)
            #    \                                           /
            #     -------->  allreduce (4) ---------------->
            builder = ANFBuilder()
            x = extended_var("x", shape=shape)
            x_1 = builder.call("atan", [x])
            x_4i = builder.make_tuple([x_1])
            x_2 = builder.call("atan", [x_1])
            x_3i = builder.make_tuple([x_2])
            x_3 = builder.call("_allreduce", [x_3i, raf.ir.const("sum"), raf.ir.const(None)])
            x_4 = builder.call("_allreduce", [x_4i, raf.ir.const("sum"), raf.ir.const(None)])
            x_5 = builder.call("atan", [x_3])
            x_6i = builder.make_tuple([x_5, x_4])
            x_6 = builder.call("concatenate", [x_6i, raf.ir.const(0)])
            return tvm.relay.Function([x], builder.ret(x_6))

        mod = tvm.IRModule()
        mod["main"] = construct_model_func()
        mod = RAFSequential([EnforceSync()])(mod)

        def expected():
            """
            fn (%x: Tensor[(64, 128), float32]) {
                let %set_stream_comp = raf.op.set_stream(int64(0), int64(1));
                let %v = raf.op.atan(%x);
                let %v1 = (%v,);
                let %v2 = raf.op.atan(%v);
                let %v3 = (%v2,);
                let %add_event_comp = raf.op.add_event(int64(1), int64(1));
                let %set_stream_comm = raf.op.set_stream(int64(0), int64(4));
                let %wait_for_comp = raf.op.wait_event(int64(1), int64(4));
                let %v4 = raf.op._allreduce(%v3, str"sum", nullptr);
                let %add_event_comm = raf.op.add_event(int64(2), int64(4));
                let %v5 = raf.op._allreduce(%v1, str"sum", nullptr);
                let %add_event_comm1 = raf.op.add_event(int64(3), int64(4));
                let %set_stream_comp1 = raf.op.set_stream(int64(0), int64(1));
                let %wait_for_comm = raf.op.wait_event(int64(2), int64(1));
                let %v6 = raf.op.atan(%v4);
                let %wait_for_comm1 = raf.op.wait_event(int64(3), int64(1));
                let %v7 = (%v6, %v5);
                let %v8 = raf.op.concatenate(%v7, int64(0));
                %v8
            }
            """
            builder = ANFBuilder()
            x = extended_var("x", shape=shape)
            builder.set_stream(0, comp_stream)
            x_1 = builder.call("atan", [x])
            x_4i = builder.make_tuple([x_1])
            x_2 = builder.call("atan", [x_1])
            x_3i = builder.make_tuple([x_2])
            builder.add_event(1, comp_stream)
            builder.set_stream(0, comm_stream)
            builder.wait_event(1, comm_stream)
            x_3 = builder.call("_allreduce", [x_3i, raf.ir.const("sum"), raf.ir.const(None)])
            builder.add_event(2, comm_stream)
            x_4 = builder.call("_allreduce", [x_4i, raf.ir.const("sum"), raf.ir.const(None)])
            builder.add_event(3, comm_stream)
            builder.set_stream(0, comp_stream)
            builder.wait_event(3, comp_stream)
            x_5 = builder.call("atan", [x_3])
            builder.wait_event(2, comp_stream)
            x_6i = builder.make_tuple([x_5, x_4])
            x_6 = builder.call("concatenate", [x_6i, raf.ir.const(0)])
            return tvm.relay.Function([x], builder.ret(x_6))

    assert tvm.ir.structural_equal(mod["main"], expected())
    dcfg.enable_data_parallel = False
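# Aside (illustration only, not part of the original test file): the two
# "redundant sync" tests above rest on one observation: an add_event/wait_event
# pair can be skipped whenever the producer already reaches the consumer
# through existing happens-before edges, i.e. per-stream program order plus
# previously inserted events. A minimal sketch of that reachability check,
# with op numbers taken from the comment in test_redundant_comp_to_comm_sync:
def needs_sync(edges, producer, consumer):
    """Return False when `consumer` is already ordered after `producer`,
    making an extra add_event/wait_event pair redundant."""
    seen, stack = set(), [producer]
    while stack:
        node = stack.pop()
        if node == consumer:
            return False
        if node not in seen:
            seen.add(node)
            stack.extend(edges.get(node, ()))
    return True


# After the (2)->(3) sync is in place, happens-before already contains comp
# program order 1->2, the event edge 2->3, and comm program order 3->4, so
# no extra event is needed for the atan (1) -> allreduce (4) dependency.
assert not needs_sync({1: [2], 2: [3], 3: [4]}, 1, 4)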
@pytest.mark.skipif(not raf.build.with_cuda(), reason="CUDA is not enabled")
@pytest.mark.parametrize("shape,comp_stream,comm_stream", [[(64, 128), 1, 5]])
def test_multi_input_allreduce(shape, comp_stream, comm_stream):
    dcfg = dist.get_config()
    dcfg.enable_data_parallel = True
    with Device("cuda(0)"):

        class Model(raf.Model):
            # x -> atan -> allreduce -> mul
            #    \          /
            #     -> atan ->
            def build(self):
                pass

            @raf.model.trace
            def forward(self, x):
                a1_a = raf.atan(x)
                a1_b = raf.atan(x)
                a2 = raf.allreduce([a1_a, a1_b])
                a3 = raf.multiply(a2[0], a2[1])
                return a3

        model = Model()
        x, _ = randn(shape)
        mod = model._internal(x).mod
        mod = RAFSequential([EnforceSync()])(mod)

        def expected():
            builder = ANFBuilder()
            x = extended_var("x", shape=shape)
            builder.set_stream(0, comp_stream)
            x_1_a = builder.call("atan", [x])
            x_1_b = builder.call("atan", [x])
            x_2 = builder.make_tuple([x_1_a, x_1_b])
            builder.add_event(1, comp_stream)
            builder.set_stream(0, comm_stream)
            builder.wait_event(1, comm_stream)
            x_3 = builder.call("_allreduce", [x_2, raf.ir.const("sum"), raf.ir.const(None)])
            builder.add_event(2, comm_stream)
            builder.set_stream(0, comp_stream)
            builder.wait_event(2, comp_stream)
            x_4_a = builder.get_tuple_item(x_3, 0)
            x_4_b = builder.get_tuple_item(x_3, 1)
            x_5 = builder.call("multiply", [x_4_a, x_4_b])
            return tvm.relay.Function([x], builder.ret(x_5))

    assert tvm.ir.structural_equal(mod["main"], expected())
    dcfg.enable_data_parallel = False


@pytest.mark.skipif(not raf.build.with_cuda(), reason="CUDA is not enabled")
@pytest.mark.parametrize("shape,comp_stream,comm_stream", [[(64, 128), 1, 5]])
def test_multi_user_allreduce(shape, comp_stream, comm_stream):
    dcfg = dist.get_config()
    dcfg.enable_data_parallel = True
    with Device("cuda(0)"):

        class Model(raf.Model):
            # atan --> allreduce -> atan ---> mul
            #                   \-> atan --/
            def build(self):
                pass

            @raf.model.trace
            def forward(self, x):
                a = raf.atan(x)
                b = raf.allreduce([a])
                c = raf.atan(b)
                d = raf.atan(b)
                e = raf.multiply(c, d)
                return e

        model = Model()
        x, _ = randn(shape)
        mod = model._internal(x).mod
        mod = RAFSequential([EnforceSync()])(mod)

        def expected():
            """
            fn (%x: Tensor[(64, 128), float32]) {
                let %set_stream_comp = raf.op.set_stream(int64(0), int64(1));
                let %a1 = raf.op.atan(%x);
                let %a2 = (%a1,);
                let %add_event_comp = raf.op.add_event(int64(1), int64(1));
                let %set_stream_comm = raf.op.set_stream(int64(0), int64(4));
                let %wait_for_comp = raf.op.wait_event(int64(1), int64(4));
                let %a3 = raf.op._allreduce(%a2, str"sum", nullptr);
                let %add_event_comm = raf.op.add_event(int64(2), int64(4));
                let %set_stream_comp1 = raf.op.set_stream(int64(0), int64(1));
                let %wait_for_comm = raf.op.wait_event(int64(2), int64(1));
                let %a4 = raf.op.atan(%a3);
                let %a5 = raf.op.atan(%a3);
                let %a6 = raf.op.multiply(%a4, %a5);
                %a6
            }
            """
            builder = ANFBuilder()
            x = extended_var("x", shape=shape)
            builder.set_stream(0, comp_stream)
            x_1 = builder.call("atan", [x])
            x_2 = builder.make_tuple([x_1])
            builder.add_event(1, comp_stream)
            builder.set_stream(0, comm_stream)
            builder.wait_event(1, comm_stream)
            x_3 = builder.call("_allreduce", [x_2, raf.ir.const("sum"), raf.ir.const(None)])
            builder.add_event(4, comm_stream)
            builder.set_stream(0, comp_stream)
            builder.wait_event(4, comp_stream)
            x_4_a = builder.call("atan", [x_3])
            x_4_b = builder.call("atan", [x_3])
            x_5 = builder.call("multiply", [x_4_a, x_4_b])
            return tvm.relay.Function([x], builder.ret(x_5))

    assert tvm.ir.structural_equal(mod["main"], expected())
    dcfg.enable_data_parallel = False


@pytest.mark.skipif(not raf.build.with_cuda(), reason="CUDA is not enabled")
@pytest.mark.parametrize(
    "shape,comp_stream,fuse_tensor_stream,defuse_tensor_stream", [[(64, 128), 1, 5, 6]]
)
def test_memory_copy_ops(shape, comp_stream, fuse_tensor_stream, defuse_tensor_stream):
    dcfg = dist.get_config()
    dcfg.enable_data_parallel = True
    size = 1
    for axis in shape:
        size *= axis
    sizes = [size, size]
    tuple_shape = shape + shape
    indices = [len(shape), 2 * len(shape)]
    with Device("cuda(0)"):

        def construct_model_func():
            # x -> atan -> fuse_tensor -> defuse_tensor -> mul
            #    \          /
            #     -> atan ->
            builder = ANFBuilder()
            x = extended_var("x", shape=shape)
            x_0 = builder.call("atan", [x])
            x_1 = builder.call("atan", [x])
            x_2i = builder.make_tuple([x_0, x_1])
            x_2 = builder.call("fuse_tensor", [x_2i])
            x_3 = builder.call(
                "defuse_tensor",
                [x_2, raf.ir.const(sizes), raf.ir.const(tuple_shape), raf.ir.const(indices)],
            )
            x_3_a = builder.get_tuple_item(x_3, 0)
            x_3_b = builder.get_tuple_item(x_3, 1)
            x_4 = builder.call("multiply", [x_3_a, x_3_b])
            return tvm.relay.Function([x], builder.ret(x_4))

        def expected():
            builder = ANFBuilder()
            x = extended_var("x", shape=shape)
            builder.set_stream(0, comp_stream)
            x_0 = builder.call("atan", [x])
            x_1 = builder.call("atan", [x])
            x_2i = builder.make_tuple([x_0, x_1])
            builder.add_event(0, comp_stream)
            builder.set_stream(0, fuse_tensor_stream)
            builder.wait_event(0, fuse_tensor_stream)
            x_2 = builder.call("fuse_tensor", [x_2i])
            builder.add_event(1, fuse_tensor_stream)
            builder.set_stream(0, defuse_tensor_stream)
            builder.wait_event(1, defuse_tensor_stream)
            x_3 = builder.call(
                "defuse_tensor",
                [x_2, raf.ir.const(sizes), raf.ir.const(tuple_shape), raf.ir.const(indices)],
            )
            builder.add_event(2, defuse_tensor_stream)
            builder.set_stream(0, comp_stream)
            builder.wait_event(2, comp_stream)
            x_3_a = builder.get_tuple_item(x_3, 0)
            x_3_b = builder.get_tuple_item(x_3, 1)
            x_4 = builder.call("multiply", [x_3_a, x_3_b])
            return tvm.relay.Function([x], builder.ret(x_4))

        mod = tvm.IRModule()
        mod["main"] = construct_model_func()
        mod = RAFSequential([EnforceSync()])(mod)

    assert tvm.ir.structural_equal(mod["main"], expected())
    dcfg.enable_data_parallel = False


@pytest.mark.skipif(not raf.build.with_cuda(), reason="CUDA is not enabled")
@pytest.mark.parametrize(
    "shape,comp_stream,comm_stream,fuse_tensor_stream,defuse_tensor_stream",
    [[(64, 128), 1, 4, 5, 6]],
)
def test_dependency_analysis(
    shape, comp_stream, comm_stream, fuse_tensor_stream, defuse_tensor_stream
):
    dcfg = dist.get_config()
    dcfg.enable_data_parallel = True
    size = 1
    for axis in shape:
        size *= axis
    sizes = [size, size]
    tuple_shape = shape + shape
    indices = [len(shape), 2 * len(shape)]
    with Device("cuda(0)"):

        def construct_model_func():
            # order in ANF is shown in parentheses, synchronization (2)->(6) is not needed
            # without memcpy pipeline, but is required if memcpy pipeline is enabled
            # atan(0) -> atan (1) -> atan (3) -> allreduce (4) -> atan (5) -> concat (7)
            #        \            \                                           /
            #         \-----------> allreduce (2) ------> concat(6) -------> /
            builder = ANFBuilder()
            x = extended_var("x", shape=shape)
            x_0 = builder.call("atan", [x])
            x_1 = builder.call("atan", [x])
            x_2i = builder.make_tuple([x_0, x_1])
            x_2 = builder.call("_allreduce", [x_2i, raf.ir.const("sum")])
            x_3 = builder.call("atan", [x_1])
            x_4i = builder.make_tuple([x_3])
            x_4 = builder.call("_allreduce", [x_4i, raf.ir.const("sum")])
            x_5 = builder.call("atan", [x_4])
            x_6 = builder.call("concatenate", [x_2, raf.ir.const(0)])
            x_6i = builder.make_tuple([x_5, x_6])
            x_7 = builder.call("concatenate", [x_6i, raf.ir.const(0)])
            return tvm.relay.Function([x], builder.ret(x_7))

        def expected():
            builder = ANFBuilder()
            x = extended_var("x", shape=shape)
            builder.set_stream(0, comp_stream)
            x_0 = builder.call("atan", [x])
            x_1 = builder.call("atan", [x])
            x_2i = builder.make_tuple([x_0, x_1])
            builder.add_event(0, comp_stream)
            builder.set_stream(0, fuse_tensor_stream)
            builder.wait_event(0, fuse_tensor_stream)
            x2_fused = builder.call("fuse_tensor", [x_2i])
            builder.add_event(1, fuse_tensor_stream)
            builder.set_stream(0, comm_stream)
            builder.wait_event(1, comm_stream)
            x2_fused_tuple = builder.make_tuple([x2_fused])
            builder.add_event(2, comm_stream)
            builder.set_stream(0, comm_stream)
            builder.wait_event(2, comm_stream)
            x_2_to_defuse = builder.call("_allreduce", [x2_fused_tuple, raf.ir.const("sum")])
            builder.add_event(3, comm_stream)
            builder.set_stream(0, defuse_tensor_stream)
            builder.wait_event(3, defuse_tensor_stream)
            x_2 = builder.call(
                "defuse_tensor",
                [
                    x_2_to_defuse,
                    raf.ir.const(sizes),
                    raf.ir.const(tuple_shape),
                    raf.ir.const(indices),
                ],
            )
            builder.add_event(4, defuse_tensor_stream)
            builder.set_stream(0, comp_stream)
            x_3 = builder.call("atan", [x_1])
            x_4i = builder.make_tuple([x_3])
            builder.add_event(5, comp_stream)
            builder.set_stream(0, comm_stream)
            builder.wait_event(5, comm_stream)
            x_4 = builder.call("_allreduce", [x_4i, raf.ir.const("sum")])
            builder.add_event(6, comm_stream)
            builder.set_stream(0, comp_stream)
            builder.wait_event(6, comp_stream)
            x_5 = builder.call("atan", [x_4])
            builder.wait_event(4, comp_stream)
            x_6 = builder.call("concatenate", [x_2, raf.ir.const(0)])
            x_6i = builder.make_tuple([x_5, x_6])
            x_7 = builder.call("concatenate", [x_6i, raf.ir.const(0)])
            return tvm.relay.Function([x], builder.ret(x_7))

        with PassContext(config={"raf.annotate_collective_ops.use_memory_copy_ops": True}):
            mod = tvm.IRModule()
            mod["main"] = construct_model_func()
            mod = RAFSequential([AnnotateCollectiveOps(), EnforceSync()])(mod)

    assert tvm.ir.structural_equal(mod["main"], expected())
    dcfg.enable_data_parallel = False


if __name__ == "__main__":
    ...
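The add_event/wait_event pairs that EnforceSync inserts above follow the usual CUDA event protocol: the producer stream records an event after its last write, and the consumer stream waits on that event before its first read, so the two streams only serialize at the handoff points. Below is a minimal sketch of the same handshake using plain Python threads; the Stream and Event classes are stand-ins invented for illustration, not RAF's runtime API.

import threading
from queue import Queue


class Event:
    """Stand-in for a CUDA event: record() marks it done, wait() blocks until then."""

    def __init__(self):
        self._done = threading.Event()

    def record(self):
        self._done.set()

    def wait(self):
        self._done.wait()


class Stream:
    """Stand-in for an in-order CUDA stream: a FIFO worker thread."""

    def __init__(self):
        self._queue = Queue()
        threading.Thread(target=self._run, daemon=True).start()

    def _run(self):
        while True:
            self._queue.get()()  # run enqueued work strictly in FIFO order

    def launch(self, fn):
        self._queue.put(fn)


comp, comm, trace = Stream(), Stream(), []
e1, e2, done = Event(), Event(), Event()

comp.launch(lambda: trace.append("atan"))        # produce on the comp stream
comp.launch(e1.record)                           # add_event(1, comp_stream)
comm.launch(e1.wait)                             # wait_event(1, comm_stream)
comm.launch(lambda: trace.append("allreduce"))   # the producer's write is visible
comm.launch(e2.record)                           # add_event(2, comm_stream)
comp.launch(e2.wait)                             # wait_event(2, comp_stream)
comp.launch(lambda: trace.append("atan"))        # consume the allreduce result
comp.launch(done.record)

done.wait()
print(trace)  # deterministically ['atan', 'allreduce', 'atan']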
Source: test_pass_stream_schedule_wavefront.py
...
    def make_tuple(self, fields: List[tvm.relay.Expr]) -> tvm.relay.Var:
        return self.scope_builder.let("", tvm.relay.Tuple(fields))

    def call(self, op_name: str, args: List[tvm.relay.Expr]) -> tvm.relay.Var:
        return self.scope_builder.let("", tvm.relay.Call(self.get_operator(op_name), args))

    def set_stream(self, device_id: int, stream_id: int) -> tvm.relay.Var:
        return self.call("set_stream", [raf.ir.const(device_id), raf.ir.const(stream_id)])

    def add_event(self, event_id: int) -> tvm.relay.Var:
        return self.call("add_event", [raf.ir.const(event_id)])

    def wait_event(self, event_id: int) -> tvm.relay.Var:
        return self.call("wait_event", [raf.ir.const(event_id)])

    def stream_barrier(self) -> tvm.relay.Var:
        return self.call("stream_barrier", [])

    def atan(self, x: tvm.ir.RelayExpr) -> tvm.relay.Var:
        return self.call("atan", [x])

    def concatenate(self, x: tvm.ir.RelayExpr, axis: int) -> tvm.relay.Var:
        return self.call("concatenate", [x, raf.ir.const(axis)])

    def ret(self, body: tvm.relay.Expr) -> tvm.relay.Expr:
        self.scope_builder.ret(body)
        return self.scope_builder.get()


@pytest.mark.skipif(not raf.build.with_cuda(), reason="CUDA is not enabled")
def test_wavefront_schedule_three_simple_branches():
    class Model(raf.Model):
        # wavefront schedule:
        # wave 1
        #   chain 1: op 1
        #   chain 2: op 2, op 3
        #   chain 3: op 4, op 5, op 6
        # wave 2
        #   chain 1: op 7
        def build(self):
            pass

        @raf.model.trace
        def forward(self, x):
            p_0 = raf.atan(x)  # op 1
            p_1 = raf.atan(x)  # op 2
            p_1 = raf.atan(p_1)  # op 3
            p_2 = raf.atan(x)  # op 4
            p_2 = raf.atan(p_2)  # op 5
            p_2 = raf.atan(p_2)  # op 6
            return raf.concatenate([p_0, p_1, p_2])  # op 7

    model = Model()
    input_shape = [2, 2]
    x, _ = randn(input_shape)
    mod = model._internal(x).mod
    with raf.ir.PassContext(opt_level=2, config={"raf.stream_schedule.policy": "wavefront"}):
        mod = RAFSequential([ToGraphNormalForm(), WavefrontStreamSchedule()])(mod)

    def expected():
        """
        def @main(%x: Tensor[(2, 2), float32]) {
          let %x_0 = raf.op.set_stream(int64(0), int64(0));
          let %x_1 = raf.op.atan(%x);
          let %x_2 = raf.op.set_stream(int64(0), int64(1));
          let %x_3 = raf.op.atan(%x);
          let %x_4 = raf.op.atan(%x_3);
          let %x_5 = raf.op.set_stream(int64(0), int64(2));
          let %x_6 = raf.op.atan(%x);
          let %x_7 = raf.op.atan(%x_6);
          let %x_8 = raf.op.atan(%x_7);
          let %x_9 = raf.op.stream_barrier();
          let %x_10 = raf.op.set_stream(int64(0), int64(0));
          let %x_11 = (%x_1, %x_4, %x_8);
          let %x_12 = raf.op.concatenate(%x_11, int64(0));
          %x_12
        }
        """
        sb = ANFBuilder()
        x = extended_var("x", shape=input_shape)
        x_0 = sb.set_stream(0, 0)
        x_1 = sb.atan(x)
        x_2 = sb.set_stream(0, 1)
        x_3 = sb.atan(x)
        x_4 = sb.atan(x_3)
        x_5 = sb.set_stream(0, 2)
        x_6 = sb.atan(x)
        x_7 = sb.atan(x_6)
        x_8 = sb.atan(x_7)
        x_9 = sb.stream_barrier()
        x_10 = sb.set_stream(0, 0)
        x_11 = sb.make_tuple([x_1, x_4, x_8])
        x_12 = sb.concatenate(x_11, 0)
        return tvm.relay.Function([x], sb.ret(x_12))

    # We verify the correctness of the pass by structural_equal here, but it does not check the
    # equivalence of raf's extended constant. See issue #700.
    print(raf.ir.AsText(mod))
    assert tvm.ir.structural_equal(mod["main"], expected())


@pytest.mark.skipif(not raf.build.with_cuda(), reason="CUDA is not enabled")
def test_wavefront_schedule_branch_in_branch():
    class Model(raf.Model):
        # wavefront schedule
        # wave 1
        #   chain 1: op 1
        #   chain 2: op 2, op 3
        #   chain 3: op 7, op 8, op 9
        # wave 2:
        #   chain 1: op 4
        #   chain 2: op 5
        # wave 3:
        #   chain 1: op 6, op 10
        def build(self):
            pass

        @raf.model.trace
        def forward(self, x):
            p_0 = raf.atan(x)  # op 1
            p_1 = raf.atan(x)  # op 2
            p_1 = raf.atan(p_1)  # op 3
            p_1a = raf.atan(p_1)  # op 4
            p_1b = raf.atan(p_1)  # op 5
            p_1 = raf.concatenate([p_1a, p_1b])  # op 6
            p_2 = raf.atan(x)  # op 7
            p_2 = raf.atan(p_2)  # op 8
            p_2 = raf.atan(p_2)  # op 9
            return raf.concatenate([p_0, p_1, p_2])  # op 10

    model = Model()
    input_shape = [2, 2]
    x, _ = randn(input_shape)
    mod = model._internal(x).mod
    with raf.ir.PassContext(opt_level=2, config={"raf.stream_schedule.policy": "wavefront"}):
        mod = RAFSequential([ToGraphNormalForm(), WavefrontStreamSchedule()])(mod)

    def expected():
        """
        def @main(%x: Tensor[(2, 2), float32]) {
          let %x_0 = raf.op.set_stream(int64(0), int64(0));
          let %x_1 = raf.op.atan(%x);
          let %x_2 = raf.op.set_stream(int64(0), int64(1));
          let %x_3 = raf.op.atan(%x);
          let %x_4 = raf.op.atan(%x_3);
          let %x_5 = raf.op.set_stream(int64(0), int64(2));
          let %x_6 = raf.op.atan(%x);
          let %x_7 = raf.op.atan(%x_6);
          let %x_8 = raf.op.atan(%x_7);
          let %x_9 = raf.op.stream_barrier();
          let %x_10 = raf.op.set_stream(int64(0), int64(0));
          let %x_11 = raf.op.atan(%x_4);
          let %x_12 = raf.op.set_stream(int64(0), int64(1));
          let %x_13 = raf.op.atan(%x_4);
          let %x_14 = raf.op.stream_barrier();
          let %x_15 = raf.op.set_stream(int64(0), int64(0));
          let %x_16 = (%x_11, %x_13);
          let %x_17 = raf.op.concatenate(%x_16, int64(0));
          let %x_18 = (%x_1, %x_17, %x_8);
          let %x_19 = raf.op.concatenate(%x_18, int64(0));
          %x_19
        }
        """
        sb = ANFBuilder()
        x = extended_var("x", shape=input_shape)
        x_0 = sb.set_stream(0, 0)
        x_1 = sb.atan(x)
        x_2 = sb.set_stream(0, 1)
        x_3 = sb.atan(x)
        x_4 = sb.atan(x_3)
        x_5 = sb.set_stream(0, 2)
        x_6 = sb.atan(x)
        x_7 = sb.atan(x_6)
        x_8 = sb.atan(x_7)
        x_9 = sb.stream_barrier()
        x_10 = sb.set_stream(0, 0)
        x_11 = sb.atan(x_4)
        x_12 = sb.set_stream(0, 1)
        x_13 = sb.atan(x_4)
        x_14 = sb.stream_barrier()
        x_15 = sb.set_stream(0, 0)
        x_16 = sb.make_tuple([x_11, x_13])
        x_17 = sb.concatenate(x_16, 0)
        x_18 = sb.make_tuple([x_1, x_17, x_8])
        x_19 = sb.concatenate(x_18, 0)
        return tvm.relay.Function([x], sb.ret(x_19))

    assert tvm.ir.structural_equal(mod["main"], expected())


@pytest.mark.skipif(not raf.build.with_cuda(), reason="CUDA is not enabled")
def test_wavefront_schedule_stacked_blocks():
    class Model(raf.Model):
        # wavefront schedule
        # wave 1
        #   chain 1: op 1
        #   chain 2: op 2,
        #   chain 3: op 3, op 4
        # wave 2:
        #   chain 1: op 5
        # wave 3:
        #   chain 1: op 6
        #   chain 2: op 7
        #   chain 3: op 8, op 9
        # wave 4:
        #   chain 1: op 10
        def build(self):
            pass

        @raf.model.trace
        def forward(self, x):
            p_0 = raf.atan(x)  # op 1
            p_1 = raf.atan(x)  # op 2
            p_2 = raf.atan(x)  # op 3
            p_2 = raf.atan(p_2)  # op 4
            x = raf.concatenate([p_0, p_1, p_2])  # op 5
            p_0 = raf.atan(x)  # op 6
            p_1 = raf.atan(x)  # op 7
            p_2 = raf.atan(x)  # op 8
            p_2 = raf.atan(p_2)  # op 9
            return raf.concatenate([p_0, p_1, p_2])  # op 10

    model = Model()
    input_shape = [2, 2]
    x, _ = randn(input_shape)
    mod = model._internal(x).mod
    with raf.ir.PassContext(opt_level=2, config={"raf.stream_schedule.policy": "wavefront"}):
        mod = RAFSequential([ToGraphNormalForm(), WavefrontStreamSchedule()])(mod)

    def expected():
        """
        def @main(%x: Tensor[(2, 2), float32]) {
          let %x_0 = raf.op.set_stream(int64(0), int64(0));
          let %x_1 = raf.op.atan(%x);
          let %x_2 = raf.op.set_stream(int64(0), int64(1));
          let %x_3 = raf.op.atan(%x);
          let %x_4 = raf.op.set_stream(int64(0), int64(2));
          let %x_5 = raf.op.atan(%x);
          let %x_6 = raf.op.atan(%x_5);
          let %x_7 = raf.op.stream_barrier();
          let %x_8 = raf.op.set_stream(int64(0), int64(0));
          let %x_9 = (%x_1, %x_3, %x_6);
          let %x_10 = raf.op.concatenate(%x_9, int64(0));
          let %x_11 = raf.op.stream_barrier();
          let %x_12 = raf.op.set_stream(int64(0), int64(0));
          let %x_13 = raf.op.atan(%x_10);
          let %x_14 = raf.op.set_stream(int64(0), int64(1));
          let %x_15 = raf.op.atan(%x_10);
          let %x_16 = raf.op.set_stream(int64(0), int64(2));
          let %x_17 = raf.op.atan(%x_10);
          let %x_18 = raf.op.atan(%x_17);
          let %x_19 = raf.op.stream_barrier();
          let %x_20 = raf.op.set_stream(int64(0), int64(0));
          let %x_21 = (%x_13, %x_15, %x_18);
          let %x_22 = raf.op.concatenate(%x_21, int64(0));
          %x_22
        }
        """
        sb = ANFBuilder()
        x = extended_var("x", shape=input_shape)
        x_0 = sb.set_stream(0, 0)
        x_1 = sb.atan(x)
        x_2 = sb.set_stream(0, 1)
        x_3 = sb.atan(x)
        x_4 = sb.set_stream(0, 2)
        x_5 = sb.atan(x)
        x_6 = sb.atan(x_5)
        x_7 = sb.stream_barrier()
        x_8 = sb.set_stream(0, 0)
        x_9 = sb.make_tuple([x_1, x_3, x_6])
        x_10 = sb.concatenate(x_9, 0)
        x_11 = sb.stream_barrier()
        x_12 = sb.set_stream(0, 0)
        x_13 = sb.atan(x_10)
        x_14 = sb.set_stream(0, 1)
        x_15 = sb.atan(x_10)
        x_16 = sb.set_stream(0, 2)
        x_17 = sb.atan(x_10)
        x_18 = sb.atan(x_17)
        x_19 = sb.stream_barrier()
        x_20 = sb.set_stream(0, 0)
        x_21 = sb.make_tuple([x_13, x_15, x_18])
        x_22 = sb.concatenate(x_21, 0)
        return tvm.relay.Function([x], sb.ret(x_22))

    assert tvm.ir.structural_equal(mod["main"], expected())


if __name__ == "__main__":
    ...
Source: test_schedule_verifier.py
...
    def make_tuple(self, fields: List[tvm.relay.Expr]) -> tvm.relay.Var:
        return self.scope_builder.let("", tvm.relay.Tuple(fields))

    def call(self, op_name: str, args: List[tvm.relay.Expr]) -> tvm.relay.Var:
        return self.scope_builder.let("", tvm.relay.Call(self.get_operator(op_name), args))

    def set_stream(self, device_id: int, stream_id: int) -> tvm.relay.Var:
        return self.call("set_stream", [raf.ir.const(device_id), raf.ir.const(stream_id)])

    def add_event(self, event_id: int) -> tvm.relay.Var:
        return self.call("add_event", [raf.ir.const(event_id)])

    def wait_event(self, event_id: int) -> tvm.relay.Var:
        return self.call("wait_event", [raf.ir.const(event_id)])

    def stream_barrier(self) -> tvm.relay.Var:
        return self.call("stream_barrier", [])

    def atan(self, x: tvm.ir.RelayExpr) -> tvm.relay.Var:
        return self.call("atan", [x])

    def concatenate(self, x: tvm.ir.RelayExpr, axis: int) -> tvm.relay.Var:
        return self.call("concatenate", [x, raf.ir.const(axis)])

    def ret(self, body: tvm.relay.Expr) -> tvm.relay.Expr:
        self.scope_builder.ret(body)
        return self.scope_builder.get()


@pytest.mark.parametrize("removed_events", [[], [0], [1], [0, 1]])
def test_simple_branches_event(removed_events):
    def scheduled_func():
        sb = ANFBuilder()
        x = extended_var("x", shape=[2, 2])
        x_0 = sb.set_stream(0, 0)
        x_1 = sb.atan(x)
        x_2 = sb.atan(x_1)
        x_3 = sb.atan(x_2)
        if 0 not in removed_events:
            x_4 = sb.add_event(0)
        x_5 = sb.set_stream(0, 1)
        x_6 = sb.atan(x)
        x_7 = sb.atan(x_6)
        if 1 not in removed_events:
            x_8 = sb.add_event(1)
        x_9 = sb.set_stream(0, 2)
        x_10 = sb.atan(x)
        if 1 not in removed_events:
            x_11 = sb.wait_event(1)
        if 0 not in removed_events:
            x_12 = sb.wait_event(0)
        x_13 = sb.make_tuple([x_10, x_7, x_3])
        x_14 = sb.concatenate(x_13, 0)
        return tvm.relay.Function([x], sb.ret(x_14))

    func = scheduled_func()
    # Please uncomment the following line to draw the scheduled dataflow graph
    # draw_dataflow_graph(func, f"./graphs/simple_branches_remove_events_{removed_events}.png",
    #                     draw_event_nodes=True)
    if len(removed_events) > 0:
        with pytest.raises(ExecutionOrderError):
            verify_schedule(func)
    else:
        verify_schedule(func)


@pytest.mark.parametrize("removed_events", [[], [0], [1], [2], [3], [4], [0, 1, 2, 3, 4]])
def test_stacked_blocks_event(removed_events):
    def scheduled_func():
        sb = ANFBuilder()
        x = extended_var("x", shape=[2, 2])
        x_0 = sb.set_stream(0, 0)
        x_1 = sb.atan(x)  # atan 2
        x_2 = sb.atan(x_1)  # atan 3
        if 0 not in removed_events:
            x_3 = sb.add_event(0)
        x_4 = sb.set_stream(0, 1)
        x_5 = sb.atan(x)  # atan 1
        if 1 not in removed_events:
            x_6 = sb.add_event(1)
        x_7 = sb.set_stream(0, 2)
        x_8 = sb.atan(x)  # atan 0
        if 1 not in removed_events:
            x_9 = sb.wait_event(1)
        if 0 not in removed_events:
            x_10 = sb.wait_event(0)
        x_11 = sb.make_tuple([x_8, x_5, x_2])
        x_12 = sb.concatenate(x_11, 0)  # concat 0
        if 2 not in removed_events:
            x_13 = sb.add_event(2)
        x_14 = sb.atan(x_12)  # atan 6
        x_15 = sb.atan(x_14)  # atan 7
        if 3 not in removed_events:
            x_16 = sb.add_event(3)
        x_17 = sb.set_stream(0, 1)
        if 2 not in removed_events:
            x_18 = sb.wait_event(2)
        x_19 = sb.atan(x_12)  # atan 5
        if 4 not in removed_events:
            x_20 = sb.add_event(4)
        x_21 = sb.set_stream(0, 0)
        if 2 not in removed_events:
            x_22 = sb.wait_event(2)
        x_23 = sb.atan(x_12)  # atan 4
        if 4 not in removed_events:
            x_24 = sb.wait_event(4)
        if 3 not in removed_events:
            x_25 = sb.wait_event(3)
        x_26 = sb.make_tuple([x_23, x_19, x_15])
        x_27 = sb.concatenate(x_26, 0)  # concat 1
        return tvm.relay.Function([x], sb.ret(x_27))

    func = scheduled_func()
    # Please uncomment the following line to draw the scheduled dataflow graph
    # draw_dataflow_graph(func, f"./graphs/stacked_blocks_remove_events_{removed_events}.png",
    #                     draw_event_nodes=True)
    if len(removed_events) > 0:
        with pytest.raises(ExecutionOrderError):
            verify_schedule(func)
    else:
        verify_schedule(func)


@pytest.mark.parametrize("removed_barriers", [[], [0]])
def test_simple_branches_barrier(removed_barriers):
    def scheduled_func():
        sb = ANFBuilder()
        x = extended_var("x", shape=[2, 2])
        x_0 = sb.set_stream(0, 2)
        x_1 = sb.atan(x)
        x_2 = sb.atan(x_1)
        x_3 = sb.atan(x_2)
        x_4 = sb.set_stream(0, 1)
        x_5 = sb.atan(x)
        x_6 = sb.atan(x_5)
        x_7 = sb.set_stream(0, 0)
        x_8 = sb.atan(x)
        if 0 not in removed_barriers:
            x_9 = sb.stream_barrier()
        x_10 = sb.make_tuple([x_8, x_6, x_3])
        x_11 = sb.concatenate(x_10, 0)
        return tvm.relay.Function([x], sb.ret(x_11))

    func = scheduled_func()
    if len(removed_barriers) > 0:
        with pytest.raises(ExecutionOrderError):
            verify_schedule(func)
    else:
        verify_schedule(func)


@pytest.mark.parametrize("removed_barriers", [[], [0], [1], [2], [0, 1, 2]])
def test_stacked_blocks_barrier(removed_barriers):
    def scheduled_func():
        sb = ANFBuilder()
        x = extended_var("x", shape=[2, 2])
        x_0 = sb.set_stream(0, 2)
        x_1 = sb.atan(x)
        x_2 = sb.atan(x_1)
        x_3 = sb.set_stream(0, 1)
        x_4 = sb.atan(x)
        x_5 = sb.set_stream(0, 0)
        x_6 = sb.atan(x)
        if 0 not in removed_barriers:
            x_7 = sb.stream_barrier()
        x_8 = sb.make_tuple([x_6, x_4, x_2])
        x_9 = sb.concatenate(x_8, 0)
        if 1 not in removed_barriers:
            x_10 = sb.stream_barrier()
        x_11 = sb.set_stream(0, 2)
        x_12 = sb.atan(x_9)
        x_13 = sb.atan(x_12)
        x_14 = sb.set_stream(0, 1)
        x_15 = sb.atan(x_9)
        x_16 = sb.set_stream(0, 0)
        x_17 = sb.atan(x_9)
        if 2 not in removed_barriers:
            x_18 = sb.stream_barrier()
        x_19 = sb.make_tuple([x_17, x_15, x_13])
        x_20 = sb.concatenate(x_19, 0)
        return tvm.relay.Function([x], sb.ret(x_20))

    func = scheduled_func()
    if len(removed_barriers) > 0:
        with pytest.raises(ExecutionOrderError):
            verify_schedule(func)
    else:
        verify_schedule(func)


@pytest.mark.parametrize("removed_barriers", [[], [0]])
def test_chain_to_another_chain_barrier(removed_barriers):
    def scheduled_func():
        sb = ANFBuilder()
        x = extended_var("x", shape=[2, 2])
        x_0 = sb.set_stream(0, 0)
        x_1 = sb.atan(x)
        x_2 = sb.atan(x_1)
        if 0 not in removed_barriers:
            x_3 = sb.stream_barrier()
        x_4 = sb.set_stream(0, 1)
        x_5 = sb.atan(x_2)
        return tvm.relay.Function([x], sb.ret(x_5))

    func = scheduled_func()
    if len(removed_barriers) > 0:
        with pytest.raises(ExecutionOrderError):
            verify_schedule(func)
    else:
        verify_schedule(func)


if __name__ == "__main__":
    ...
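What verify_schedule has to establish, at its core, is a happens-before relation generated by three sources: program order within a single stream, add_event(e) before a matching wait_event(e), and a stream_barrier joining everything issued before it to everything issued after. The sketch below re-implements that check over a toy instruction encoding; the tuple format is invented for illustration, while RAF's verifier works on the Relay IR directly.

from collections import defaultdict


def check_schedule(prog):
    """prog: a list of instructions, each one of
         ("op", stream, reads, writes)
         ("add_event", stream, event_id)
         ("wait_event", stream, event_id)
         ("barrier",)
    Raises AssertionError when some read is not provably ordered after
    the write that produced its value."""
    n = len(prog)
    succs = defaultdict(set)  # happens-before edges (always forward in index)
    last_on_stream, records = {}, {}
    for i, ins in enumerate(prog):
        if ins[0] == "barrier":  # joins everything before to everything after
            succs[i].update(range(i + 1, n))
            for j in range(i):
                succs[j].add(i)
            continue
        stream = ins[1]
        if stream in last_on_stream:  # program order within one stream
            succs[last_on_stream[stream]].add(i)
        last_on_stream[stream] = i
        if ins[0] == "add_event":
            records[ins[2]] = i
        elif ins[0] == "wait_event":  # the record happens before the wait
            succs[records[ins[2]]].add(i)

    def reaches(src, dst):
        seen, stack = set(), [src]
        while stack:
            node = stack.pop()
            if node == dst:
                return True
            if node not in seen:
                seen.add(node)
                stack.extend(succs[node])
        return False

    writer = {}
    for i, ins in enumerate(prog):
        if ins[0] != "op":
            continue
        for v in ins[2]:
            if v in writer:  # function inputs have no producer to check
                assert reaches(writer[v], i), f"unordered use of {v!r}"
        for v in ins[3]:
            writer[v] = i


# A comp -> comm handoff guarded by event 0 verifies; without the event,
# the cross-stream read of "y" is unordered and the check fails.
good = [
    ("op", 1, ["x"], ["y"]),
    ("add_event", 1, 0),
    ("wait_event", 4, 0),
    ("op", 4, ["y"], ["z"]),
]
check_schedule(good)
bad = [ins for ins in good if ins[0] == "op"]
try:
    check_schedule(bad)
except AssertionError as err:
    print("rejected:", err)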