Best Python code snippets using pandera_python
debugging_primitives_test.py
Source: debugging_primitives_test.py
...36except ModuleNotFoundError:37  rich = None38config.parse_flags_with_absl()39debug_print = debugging.debug_print40def _format_multiline(text):41  return textwrap.dedent(text).lstrip()42prev_xla_flags = None43def setUpModule():44  global prev_xla_flags45  # This will control the CPU devices. On TPU we always have 2 devices46  prev_xla_flags = jtu.set_host_platform_device_count(2)47# Reset to previous configuration in case other test modules will be run.48def tearDownModule():49  prev_xla_flags()50# TODO(sharadmv): remove jaxlib guards for TPU tests when jaxlib minimum51#                 version is >= 0.3.1552disabled_backends = []53if jaxlib.version < (0, 3, 15):54  disabled_backends.append("tpu")55class DummyDevice:56  def __init__(self, platform, id):57    self.platform = platform58    self.id = id59class DebugPrintTest(jtu.JaxTestCase):60  def tearDown(self):61    super().tearDown()62    dispatch.runtime_tokens.clear()63  @jtu.skip_on_devices(*disabled_backends)64  def test_simple_debug_print_works_in_eager_mode(self):65    def f(x):66      debug_print('x: {}', x)67    with jtu.capture_stdout() as output:68      f(2)69      jax.effects_barrier()70    self.assertEqual(output(), "x: 2\n")71  @jtu.skip_on_devices(*disabled_backends)72  def test_debug_print_works_with_named_format_strings(self):73    def f(x):74      debug_print('x: {x}', x=x)75    with jtu.capture_stdout() as output:76      f(2)77      jax.effects_barrier()78    self.assertEqual(output(), "x: 2\n")79  @jtu.skip_on_devices(*disabled_backends)80  def test_multiple_debug_prints_should_print_multiple_values(self):81    def f(x):82      debug_print('x: {x}', x=x)83      debug_print('y: {y}', y=x + 1)84    with jtu.capture_stdout() as output:85      f(2)86      jax.effects_barrier()87    self.assertEqual(output(), "x: 2\ny: 3\n")88  @jtu.skip_on_devices(*disabled_backends)89  def test_can_stage_out_debug_print(self):90    @jax.jit91    def f(x):92      debug_print('x: {x}', x=x)93    with 
jtu.capture_stdout() as output:94      f(2)95      jax.effects_barrier()96    self.assertEqual(output(), "x: 2\n")97  @jtu.skip_on_devices(*disabled_backends)98  def test_can_stage_out_debug_print_with_donate_argnums(self):99    if jax.default_backend() not in {"gpu", "tpu"}:100      raise unittest.SkipTest("Donate argnums not supported.")101    def f(x, y):102      debug_print('x: {x}', x=x)103      return x + y104    f = jax.jit(f, donate_argnums=0)105    with jtu.capture_stdout() as output:106      f(2, 3)107      jax.effects_barrier()108    self.assertEqual(output(), "x: 2\n")109  @jtu.skip_on_devices(*disabled_backends)110  def test_can_stage_out_ordered_print(self):111    @jax.jit112    def f(x):113      debug_print('x: {x}', x=x, ordered=True)114    with jtu.capture_stdout() as output:115      f(2)116      jax.effects_barrier()117    self.assertEqual(output(), "x: 2\n")118  @jtu.skip_on_devices(*disabled_backends)119  def test_can_stage_out_ordered_print_with_donate_argnums(self):120    if jax.default_backend() not in {"gpu", "tpu"}:121      raise unittest.SkipTest("Donate argnums not supported.")122    def f(x, y):123      debug_print('x: {x}', x=x, ordered=True)124      return x + y125    f = jax.jit(f, donate_argnums=0)126    with jtu.capture_stdout() as output:127      f(2, 3)128      jax.effects_barrier()129    self.assertEqual(output(), "x: 2\n")130  @jtu.skip_on_devices(*disabled_backends)131  def test_can_stage_out_prints_with_donate_argnums(self):132    if jax.default_backend() not in {"gpu", "tpu"}:133      raise unittest.SkipTest("Donate argnums not supported.")134    def f(x, y):135      debug_print('x: {x}', x=x, ordered=True)136      debug_print('x: {x}', x=x)137      return x + y138    f = jax.jit(f, donate_argnums=0)139    with jtu.capture_stdout() as output:140      f(2, 3)141      jax.effects_barrier()142    self.assertEqual(output(), "x: 2\nx: 2\n")143  @jtu.skip_on_devices(*disabled_backends)144  def 
test_can_double_stage_out_ordered_print(self):145    @jax.jit146    @jax.jit147    def f(x):148      debug_print('x: {x}', x=x, ordered=True)149    with jtu.capture_stdout() as output:150      f(2)151      jax.effects_barrier()152    self.assertEqual(output(), "x: 2\n")153  @jtu.skip_on_devices(*disabled_backends)154  def test_can_stage_out_ordered_print_with_pytree(self):155    @jax.jit156    def f(x):157      struct = dict(foo=x)158      debug_print('x: {}', struct, ordered=True)159    with jtu.capture_stdout() as output:160      f(np.array(2, np.int32))161      jax.effects_barrier()162    self.assertEqual(output(), f"x: {str(dict(foo=np.array(2, np.int32)))}\n")163class DebugPrintTransformationTest(jtu.JaxTestCase):164  def test_debug_print_batching(self):165    @jax.vmap166    def f(x):167      debug_print('hello: {}', x)168    with jtu.capture_stdout() as output:169      f(jnp.arange(2))170      jax.effects_barrier()171    self.assertEqual(output(), "hello: 0\nhello: 1\n")172  def test_debug_print_batching_with_diff_axes(self):173    @functools.partial(jax.vmap, in_axes=(0, 1))174    def f(x, y):175      debug_print('hello: {} {}', x, y)176    with jtu.capture_stdout() as output:177      f(jnp.arange(2), jnp.arange(2)[None])178      jax.effects_barrier()179    self.assertEqual(output(), "hello: 0 [0]\nhello: 1 [1]\n")180  def tested_debug_print_with_nested_vmap(self):181    def f(x):182      debug_print('hello: {}', x)183    # Call with184    # [[0, 1],185    #  [2, 3],186    #  [4, 5]]187    with jtu.capture_stdout() as output:188      # Should print over 0-axis then 1-axis189      jax.vmap(jax.vmap(f))(jnp.arange(6).reshape((3, 2)))190      jax.effects_barrier()191    self.assertEqual(192        output(),193        "hello: 0\nhello: 2\nhello: 4\nhello: 1\nhello: 3\nhello: 5\n")194    with jtu.capture_stdout() as output:195      # Should print over 1-axis then 0-axis196      jax.vmap(jax.vmap(f, in_axes=0), in_axes=1)(jnp.arange(6).reshape((3, 2)))197      
jax.effects_barrier()198    self.assertEqual(199        output(),200        "hello: 0\nhello: 1\nhello: 2\nhello: 3\nhello: 4\nhello: 5\n")201  def test_debug_print_jvp_rule(self):202    def f(x):203      debug_print('x: {}', x)204    with jtu.capture_stdout() as output:205      jax.jvp(f, (1.,), (1.,))206      jax.effects_barrier()207    self.assertEqual(output(), "x: 1.0\n")208  def test_debug_print_vjp_rule(self):209    def f(x):210      debug_print('x: {}', x)211    with jtu.capture_stdout() as output:212      jax.vjp(f, 1.)213      jax.effects_barrier()214    self.assertEqual(output(), "x: 1.0\n")215  def test_debug_print_in_custom_jvp(self):216    @jax.custom_jvp217    def print_tangent(x):218      return x219    @print_tangent.defjvp220    def _(primals, tangents):221      (x,), (t,) = primals, tangents222      debug_print("x_tangent: {}", t)223      return x, t224    def f(x):225      x = jnp.sin(x)226      x = print_tangent(x)227      return x228    with jtu.capture_stdout() as output:229      x = jnp.array(1., jnp.float32)230      jax.jvp(f, (x,), (x,))231      jax.effects_barrier()232    expected = jnp.cos(jnp.array(1., jnp.float32))233    self.assertEqual(output(), f"x_tangent: {expected}\n")234  @unittest.skip("doesn't work yet!")  # TODO(mattjj,sharadmv)235  def test_debug_print_in_custom_jvp_linearize(self):236    @jax.custom_jvp237    def print_tangent(x):238      return x239    @print_tangent.defjvp240    def _(primals, tangents):241      (x,), (t,) = primals, tangents242      debug_print("x_tangent: {}", t)243      return x, t244    def f(x):245      x = jnp.sin(x)246      x = print_tangent(x)247      return x248    with jtu.capture_stdout() as output:249      x = jnp.array(1., jnp.float32)250      y, f_lin = jax.linearize(f, x)251      jax.effects_barrier()252    self.assertEqual(output(), "")253    with jtu.capture_stdout() as output:254      _ = f_lin(x)255      jax.effects_barrier()256    expected = jnp.cos(jnp.array(1., jnp.float32))257    
self.assertEqual(output(), f"x_tangent: {expected}\n")258  def test_debug_print_grad_with_custom_vjp_rule(self):259    @jax.custom_vjp260    def print_grad(x):261      return x262    def print_grad_fwd(x):263      return x, None264    def print_grad_bwd(_, x_grad):265      debug_print("x_grad: {}", x_grad)266      return (x_grad,)267    print_grad.defvjp(print_grad_fwd, print_grad_bwd)268    def f(x):269      debug_print("x: {}", x)270      x = print_grad(x)271      return jnp.sin(x)272    with jtu.capture_stdout() as output:273      jax.grad(f)(jnp.array(1., jnp.float32))274      jax.effects_barrier()275    expected = jnp.cos(jnp.array(1., jnp.float32))276    self.assertEqual(output(), f"x: 1.0\nx_grad: {expected}\n")277  def test_debug_print_transpose_rule(self):278    def f(x):279      debug_print('should never be called: {}', x)280      return x281    with jtu.capture_stdout() as output:282      jax.linear_transpose(f, 1.)(1.)283      jax.effects_barrier()284    # `debug_print` should be dropped by `partial_eval` because of no285    # output data-dependence.286    self.assertEqual(output(), "")287  @parameterized.named_parameters(jtu.cases_from_list(288    dict(testcase_name="_ordered" if ordered else "", ordered=ordered)289         for ordered in [False, True]))290  def test_remat_of_debug_print(self, ordered):291    def f_(x):292      y = ad_checkpoint.checkpoint_name(x + 1., "y")293      z = ad_checkpoint.checkpoint_name(y * 2., "z")294      debug_print('y: {}, z: {}', y, z, ordered=ordered)295      return ad_checkpoint.checkpoint_name(jnp.exp(z), "w")296    # Policy that saves everything so the debug callback will be saved297    f = ad_checkpoint.checkpoint(f_, policy=ad_checkpoint.everything_saveable)298    with jtu.capture_stdout() as output:299      jax.grad(f)(2.)300      jax.effects_barrier()301    # We expect the print to happen once since it gets saved and isn't302    # rematerialized.303    self.assertEqual(output(), "y: 3.0, z: 6.0\n")304    # 
Policy that saves nothing so everything gets rematerialized, including the305    # debug callback306    f = ad_checkpoint.checkpoint(f_, policy=ad_checkpoint.nothing_saveable)307    with jtu.capture_stdout() as output:308      jax.grad(f)(2.)309      jax.effects_barrier()310    # We expect the print to happen twice since it is rematerialized.311    self.assertEqual(output(), "y: 3.0, z: 6.0\n" * 2)312    # Policy that does not save `z` so we will need to rematerialize the print313    f = ad_checkpoint.checkpoint(314        f_, policy=ad_checkpoint.save_any_names_but_these("z"))315    with jtu.capture_stdout() as output:316      jax.grad(f)(2.)317      jax.effects_barrier()318    # We expect the print to happen twice since it is rematerialized.319    self.assertEqual(output(), "y: 3.0, z: 6.0\n" * 2)320    def save_everything_but_these_names(*names_not_to_save):321      names_not_to_save = frozenset(names_not_to_save)322      def policy(prim, *_, **params):323        if prim is ad_checkpoint.name_p:324          return params['name'] not in names_not_to_save325        return True # Save everything else326      return policy327    # Policy that saves everything but `y`328    f = ad_checkpoint.checkpoint(329        f_, policy=save_everything_but_these_names("y"))330    with jtu.capture_stdout() as output:331      jax.grad(f)(2.)332      jax.effects_barrier()333    # We expect the print to happen once because `y` is not rematerialized and334    # we won't do extra materialization.335    self.assertEqual(output(), "y: 3.0, z: 6.0\n")336    # Policy that saves everything but `y` and `z`337    f = ad_checkpoint.checkpoint(338        f_, policy=save_everything_but_these_names("y", "z"))339    with jtu.capture_stdout() as output:340      jax.grad(f)(2.)341      jax.effects_barrier()342    # We expect the print to happen twice because both `y` and `z` have been343    # rematerialized and we don't have to do any extra rematerialization to344    # print.345    
self.assertEqual(output(), "y: 3.0, z: 6.0\n" * 2)346  @jtu.skip_on_devices(*disabled_backends)347  def test_debug_print_in_staged_out_custom_jvp(self):348    @jax.jit349    def f(x):350      @jax.custom_jvp351      def g(x):352        debug_print("hello: {x}", x=x)353        return x354      def g_jvp(primals, tangents):355        (x,), (t,) = primals, tangents356        debug_print("goodbye: {x} {t}", x=x, t=t)357        return x, t358      g.defjvp(g_jvp)359      return g(x)360    with jtu.capture_stdout() as output:361      f(2.)362      jax.effects_barrier()363    self.assertEqual(output(), "hello: 2.0\n")364    with jtu.capture_stdout() as output:365      jax.jvp(f, (2.,), (3.,))366      jax.effects_barrier()367    self.assertEqual(output(), "goodbye: 2.0 3.0\n")368  @jtu.skip_on_devices(*disabled_backends)369  def test_debug_print_in_staged_out_custom_vjp(self):370    @jax.jit371    def f(x):372      @jax.custom_vjp373      def g(x):374        debug_print("hello: {x}", x=x)375        return x376      def g_fwd(x):377        debug_print("hello fwd: {x}", x=x)378        return x, x379      def g_bwd(x, g):380        debug_print("hello bwd: {x} {g}", x=x, g=g)381        return (g,)382      g.defvjp(fwd=g_fwd, bwd=g_bwd)383      return g(x)384    with jtu.capture_stdout() as output:385      f(2.)386      jax.effects_barrier()387    self.assertEqual(output(), "hello: 2.0\n")388    with jtu.capture_stdout() as output:389      _, f_vjp = jax.vjp(f, 2.)390      jax.effects_barrier()391    self.assertEqual(output(), "hello fwd: 2.0\n")392    with jtu.capture_stdout() as output:393      f_vjp(3.0)394      jax.effects_barrier()395    self.assertEqual(output(), "hello bwd: 2.0 3.0\n")396class DebugPrintControlFlowTest(jtu.JaxTestCase):397  def _assertLinesEqual(self, text1, text2):398    def _count(lines):399      return collections.Counter(lines)400    self.assertDictEqual(_count(text1.split("\n")), _count(text2.split("\n")))401  
@parameterized.named_parameters(jtu.cases_from_list(402    dict(testcase_name="_ordered" if ordered else "", ordered=ordered)403         for ordered in [False, True]))404  @jtu.skip_on_devices(*disabled_backends)405  def test_can_print_inside_scan(self, ordered):406    def f(xs):407      def _body(carry, x):408        debug_print("carry: {carry}, x: {x}", carry=carry, x=x, ordered=ordered)409        return carry + 1, x + 1410      return lax.scan(_body, 2, xs)411    with jtu.capture_stdout() as output:412      f(jnp.arange(2))413      jax.effects_barrier()414    self.assertEqual(415        output(),416        _format_multiline("""417      carry: 2, x: 0418      carry: 3, x: 1419      """))420  @parameterized.named_parameters(jtu.cases_from_list(421    dict(testcase_name="_ordered" if ordered else "", ordered=ordered)422         for ordered in [False, True]))423  @jtu.skip_on_devices(*disabled_backends)424  def test_can_print_inside_for_loop(self, ordered):425    def f(x):426      def _body(i, x):427        debug_print("i: {i}", i=i, ordered=ordered)428        debug_print("x: {x}", x=x, ordered=ordered)429        return x + 1430      return lax.fori_loop(0, 5, _body, x)431    with jtu.capture_stdout() as output:432      f(2)433      jax.effects_barrier()434    expected = _format_multiline("""435      i: 0436      x: 2437      i: 1438      x: 3439      i: 2440      x: 4441      i: 3442      x: 5443      i: 4444      x: 6445      """)446    if ordered:447      self.assertEqual(output(), expected)448    else:449      self._assertLinesEqual(output(), expected)450  @parameterized.named_parameters(jtu.cases_from_list(451    dict(testcase_name="_ordered" if ordered else "", ordered=ordered)452         for ordered in [False, True]))453  @jtu.skip_on_devices(*disabled_backends)454  def test_can_print_inside_while_loop_body(self, ordered):455    def f(x):456      def _cond(x):457        return x < 10458      def _body(x):459        debug_print("x: {x}", x=x, 
ordered=ordered)460        return x + 1461      return lax.while_loop(_cond, _body, x)462    with jtu.capture_stdout() as output:463      f(5)464      jax.effects_barrier()465    self.assertEqual(output(), _format_multiline("""466      x: 5467      x: 6468      x: 7469      x: 8470      x: 9471      """))472  @parameterized.named_parameters(jtu.cases_from_list(473    dict(testcase_name="_ordered" if ordered else "", ordered=ordered)474         for ordered in [False, True]))475  @jtu.skip_on_devices(*disabled_backends)476  def test_can_print_inside_while_loop_cond(self, ordered):477    def f(x):478      def _cond(x):479        debug_print("x: {x}", x=x, ordered=ordered)480        return x < 10481      def _body(x):482        return x + 1483      return lax.while_loop(_cond, _body, x)484    with jtu.capture_stdout() as output:485      f(5)486      jax.effects_barrier()487    self.assertEqual(output(), _format_multiline("""488      x: 5489      x: 6490      x: 7491      x: 8492      x: 9493      x: 10494      """))495    with jtu.capture_stdout() as output:496      f(10)497      jax.effects_barrier()498    # Should run the cond once499    self.assertEqual(output(), _format_multiline("""500      x: 10501      """))502  @parameterized.named_parameters(jtu.cases_from_list(503    dict(testcase_name="_ordered" if ordered else "", ordered=ordered)504         for ordered in [False, True]))505  @jtu.skip_on_devices(*disabled_backends)506  def test_can_print_in_batched_while_cond(self, ordered):507    def f(x):508      def _cond(x):509        debug_print("x: {x}", x=x, ordered=ordered)510        return x < 5511      def _body(x):512        return x + 1513      return lax.while_loop(_cond, _body, x)514    with jtu.capture_stdout() as output:515      jax.vmap(f)(jnp.arange(2))516      jax.effects_barrier()517    if ordered:518      expected = _format_multiline("""519      x: 0520      x: 1521      x: 1522      x: 2523      x: 2524      x: 3525      x: 3526      x: 4527      x: 
4528      x: 5529      x: 5530      x: 6531      """)532      self.assertEqual(output(), expected)533    else:534      # When the print is unordered, the `cond` is called an additional time535      # after the `_body` runs, so we get more prints.536      expected = _format_multiline("""537      x: 0538      x: 1539      x: 0540      x: 1541      x: 1542      x: 2543      x: 1544      x: 2545      x: 2546      x: 3547      x: 2548      x: 3549      x: 3550      x: 4551      x: 3552      x: 4553      x: 4554      x: 5555      x: 4556      x: 5557      x: 5558      x: 5559      """)560      self._assertLinesEqual(output(), expected)561  @parameterized.named_parameters(jtu.cases_from_list(562    dict(testcase_name="_ordered" if ordered else "", ordered=ordered)563         for ordered in [False, True]))564  @jtu.skip_on_devices(*disabled_backends)565  def test_can_print_inside_cond(self, ordered):566    def f(x):567      def true_fun(x):568        debug_print("true: {}", x, ordered=ordered)569        return x570      def false_fun(x):571        debug_print("false: {}", x, ordered=ordered)572        return x573      return lax.cond(x < 5, true_fun, false_fun, x)574    with jtu.capture_stdout() as output:575      f(5)576      jax.effects_barrier()577    self.assertEqual(output(), _format_multiline("""578      false: 5579      """))580    with jtu.capture_stdout() as output:581      f(4)582      jax.effects_barrier()583    self.assertEqual(output(), _format_multiline("""584      true: 4585      """))586  @parameterized.named_parameters(jtu.cases_from_list(587    dict(testcase_name="_ordered" if ordered else "", ordered=ordered)588         for ordered in [False, True]))589  @jtu.skip_on_devices(*disabled_backends)590  def test_can_print_inside_switch(self, ordered):591    def f(x):592      def b1(x):593        debug_print("b1: {}", x, ordered=ordered)594        return x595      def b2(x):596        debug_print("b2: {}", x, ordered=ordered)597        return x598      def 
b3(x):599        debug_print("b3: {}", x, ordered=ordered)600        return x601      return lax.switch(x, (b1, b2, b3), x)602    with jtu.capture_stdout() as output:603      f(0)604      jax.effects_barrier()605    self.assertEqual(output(), _format_multiline("""606      b1: 0607      """))608    with jtu.capture_stdout() as output:609      f(1)610      jax.effects_barrier()611    self.assertEqual(output(), _format_multiline("""612      b2: 1613      """))614    with jtu.capture_stdout() as output:615      f(2)616      jax.effects_barrier()617    self.assertEqual(output(), _format_multiline("""618      b3: 2619      """))620class DebugPrintParallelTest(jtu.JaxTestCase):621  def _assertLinesEqual(self, text1, text2):622    def _count(lines):623      return collections.Counter(lines)624    self.assertDictEqual(_count(text1.split("\n")), _count(text2.split("\n")))625  @jtu.skip_on_devices(*disabled_backends)626  def test_ordered_print_not_supported_in_pmap(self):627    @jax.pmap628    def f(x):629      debug_print("{}", x, ordered=True)630    with self.assertRaisesRegex(631        ValueError, "Ordered effects not supported in `pmap`."):632      f(jnp.arange(jax.local_device_count()))633  @jtu.skip_on_devices(*disabled_backends)634  def test_unordered_print_works_in_pmap(self):635    if jax.device_count() < 2:636      raise unittest.SkipTest("Test requires >= 2 devices.")637    @jax.pmap638    def f(x):639      debug_print("hello: {}", x, ordered=False)640    with jtu.capture_stdout() as output:641      f(jnp.arange(jax.local_device_count()))642      jax.effects_barrier()643    lines = [f"hello: {i}\n" for i in range(jax.local_device_count())]644    self._assertLinesEqual(output(), "".join(lines))645    @jax.pmap646    def f2(x):647      debug_print('hello: {}', x)648      debug_print('hello: {}', x + 2)649    with jtu.capture_stdout() as output:650      f2(jnp.arange(2))651      jax.effects_barrier()652    self._assertLinesEqual(output(), "hello: 0\nhello: 1\nhello: 
2\nhello: 3\n")653  @jtu.skip_on_devices(*disabled_backends)654  def test_unordered_print_with_pjit(self):655    if jax.default_backend() in {"cpu", "gpu"} and jaxlib.version < (0, 3, 16):656      raise unittest.SkipTest("`pjit` of callback not supported.")657    def f(x):658      debug_print("{}", x, ordered=False)659      return x660    mesh = maps.Mesh(np.array(jax.devices()), ['dev'])661    if config.jax_array:662      spec = sharding.MeshPspecSharding(mesh, pjit.PartitionSpec('dev'))663      out_spec = sharding.MeshPspecSharding(mesh, pjit.PartitionSpec())664    else:665      spec = pjit.PartitionSpec('dev')666      out_spec = pjit.PartitionSpec()667    f = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=out_spec)668    with mesh:669      with jtu.capture_stdout() as output:670        f(np.arange(8, dtype=jnp.int32))671        jax.effects_barrier()672      self.assertEqual(output(), "[0 1 2 3 4 5 6 7]\n")673    def f2(x):674      y = x.dot(x)675      debug_print("{}", y, ordered=False)676      return y677    f2 = pjit.pjit(f2, in_axis_resources=spec, out_axis_resources=out_spec)678    with maps.Mesh(np.array(jax.devices()), ['dev']):679      with jtu.capture_stdout() as output:680        f2(np.arange(8, dtype=jnp.int32))681        jax.effects_barrier()682      self.assertEqual(output(), "140\n")683  @jtu.skip_on_devices(*disabled_backends)684  def test_unordered_print_of_pjit_of_while(self):685    if (jax.default_backend() in {"cpu", "gpu"}686        and jaxlib.xla_extension_version < 81):687      raise unittest.SkipTest("`pjit` of callback not supported.")688    def f(x):689      def cond(carry):690        i, *_ = carry691        return i < 5692      def body(carry):693        i, x = carry694        debug_print("{}", x, ordered=False)695        x = x + 1696        return (i + 1, x)697      return lax.while_loop(cond, body, (0, x))[1]698    mesh = maps.Mesh(np.array(jax.devices()), ['dev'])699    if config.jax_array:700      spec = 
sharding.MeshPspecSharding(mesh, pjit.PartitionSpec('dev'))701    else:702      spec = pjit.PartitionSpec('dev')703    f = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=spec)704    with mesh:705      with jtu.capture_stdout() as output:706        f(np.arange(8, dtype=jnp.int32))707        jax.effects_barrier()708      self.assertEqual(output(),709          "[0 1 2 3 4 5 6 7]\n"710          "[1 2 3 4 5 6 7 8]\n"711          "[2 3 4 5 6 7 8 9]\n"712          "[ 3  4  5  6  7  8  9 10]\n"713          "[ 4  5  6  7  8  9 10 11]\n")714  @jtu.skip_on_devices(*disabled_backends)715  def test_unordered_print_of_pjit_of_xmap(self):716    if (jax.default_backend() in {"cpu", "gpu"}717        and jaxlib.xla_extension_version < 81):718      raise unittest.SkipTest("`pjit` of callback not supported.")719    def f(x):720      def foo(x):721        idx = lax.axis_index('foo')722        debug_print("{idx}: {x}", idx=idx, x=x)723        return jnp.mean(x, axis=['foo'])724      out = maps.xmap(foo, in_axes=['foo'], out_axes=[...])(x)725      debug_print("Out: {}", out)726      return out727    mesh = maps.Mesh(np.array(jax.devices()), ['dev'])728    if config.jax_array:729      in_spec = sharding.MeshPspecSharding(mesh, pjit.PartitionSpec('dev'))730      out_spec = sharding.MeshPspecSharding(mesh, pjit.PartitionSpec())731    else:732      in_spec = pjit.PartitionSpec('dev')733      out_spec = pjit.PartitionSpec()734    f = pjit.pjit(f, in_axis_resources=in_spec, out_axis_resources=out_spec)735    with mesh:736      with jtu.capture_stdout() as output:737        f(jnp.arange(8, dtype=jnp.int32) * 2)738        lines = ["0: 0", "1: 2", "2: 4", "3: 6", "4: 8", "5: 10", "6: 12",739                 "7: 14", "Out: 7.0", ""]740        jax.effects_barrier()741        self._assertLinesEqual(output(), "\n".join(lines))742  @jtu.skip_on_devices(*disabled_backends)743  def test_unordered_print_with_xmap(self):744    def f(x):745      debug_print("{}", x, ordered=False)746    f = 
maps.xmap(f, in_axes=['a'], out_axes=None, backend='cpu',747                  axis_resources={'a': 'dev'})748    with maps.Mesh(np.array(jax.devices()), ['dev']):749      with jtu.capture_stdout() as output:750        f(np.arange(40))751        jax.effects_barrier()752      lines = [f"{i}\n" for i in range(40)]753      self._assertLinesEqual(output(), "".join(lines))754  @jtu.skip_on_devices(*disabled_backends)755  def test_unordered_print_works_in_pmap_of_while(self):756    if jax.device_count() < 2:757      raise unittest.SkipTest("Test requires >= 2 devices.")758    @jax.pmap759    def f(x):760      def cond(x):761        return x < 3762      def body(x):763        debug_print("hello: {}", x, ordered=False)764        return x + 1765      return lax.while_loop(cond, body, x)766    with jtu.capture_stdout() as output:767      f(jnp.arange(2))768      jax.effects_barrier()769    self._assertLinesEqual(770        output(), "hello: 0\nhello: 1\nhello: 2\n"771        "hello: 1\nhello: 2\n")772  @jtu.skip_on_devices(*disabled_backends)773  def test_incorrectly_formatted_string(self):774    @jax.jit775    def f(x):776      debug_print("hello: {x}", x)777      return x778    with self.assertRaises(KeyError):779      f(jnp.arange(2))780      jax.effects_barrier()781    @jax.jit782    def f(x):783      debug_print("hello: {}", x=x)784      return x785    with self.assertRaises(IndexError):786      f(jnp.arange(2))787      jax.effects_barrier()788  @jtu.skip_on_devices(*disabled_backends)789  def test_format_string_errors_with_unused_args(self):790    @jax.jit791    def f(x):792      debug_print("hello: {x}", x=x, y=x)793      return x794    with self.assertRaisesRegex(ValueError, "Unused keyword arguments"):795      f(jnp.arange(2))796      jax.effects_barrier()797    @jax.jit798    def g(x):799      debug_print("hello", x)800      return x801    with self.assertRaisesRegex(ValueError, "Unused positional arguments"):802      g(jnp.arange(2))803      
jax.effects_barrier()804  @jtu.skip_on_devices(*disabled_backends)805  def test_accidental_fstring(self):806    @jax.jit807    def f(x):808      debug_print(f"hello: {x}", x=x)809      return x810    with self.assertRaisesRegex(ValueError, "You may be passing an f-string"):811      f(jnp.arange(2))812      jax.effects_barrier()813class VisualizeShardingTest(jtu.JaxTestCase):814  def _create_devices(self, shape):815    num_devices = np.prod(shape)816    devices = [DummyDevice("CPU", i) for i in range(num_devices)]817    return np.array(devices).reshape(shape)818  def test_trivial_sharding(self):819    mesh = maps.Mesh(self._create_devices(1), ['x'])820    pspec = pjit.PartitionSpec('x')821    sd = sharding.MeshPspecSharding(mesh, pspec)822    shape = (5,)823    with jtu.capture_stdout() as output:824      debugging.visualize_sharding(shape, sd)825    self.assertEqual(output(), _format_multiline("""826    âââââââââ827    â CPU 0 â828    âââââââââ829    """))830  def test_trivial_sharding_with_scale(self):831    mesh = maps.Mesh(self._create_devices(1), ['x'])832    pspec = pjit.PartitionSpec('x')833    sd = sharding.MeshPspecSharding(mesh, pspec)834    shape = (5,)835    with jtu.capture_stdout() as output:836      debugging.visualize_sharding(shape, sd, scale=8.)837    self.assertEqual(output(), _format_multiline("""838    ââââââââââââââââ839    â    CPU 0     â840    ââââââââââââââââ841    """))842  def test_full_sharding(self):843    mesh = maps.Mesh(self._create_devices((8, 4)), ['x', 'y'])844    pspec = pjit.PartitionSpec('x', 'y')845    sd = sharding.MeshPspecSharding(mesh, pspec)846    shape = (8, 8)847    with jtu.capture_stdout() as output:848      debugging.visualize_sharding(shape, sd)849    expected = _format_multiline("""850    âââââââââ¬ââââââââ¬ââââââââ¬ââââââââ851    â CPU 0 â CPU 1 â CPU 2 â CPU 3 â852    âââââââââ¼ââââââââ¼ââââââââ¼ââââââââ¤853    â CPU 4 â CPU 5 â CPU 6 â CPU 7 â854    âââââââââ¼ââââââââ¼ââââââââ¼ââââââââ¤855    â CPU 8 â CPU 9 
âCPU 10 âCPU 11 â856    âââââââââ¼ââââââââ¼ââââââââ¼ââââââââ¤857    âCPU 12 âCPU 13 âCPU 14 âCPU 15 â858    âââââââââ¼ââââââââ¼ââââââââ¼ââââââââ¤859    âCPU 16 âCPU 17 âCPU 18 âCPU 19 â860    âââââââââ¼ââââââââ¼ââââââââ¼ââââââââ¤861    âCPU 20 âCPU 21 âCPU 22 âCPU 23 â862    âââââââââ¼ââââââââ¼ââââââââ¼ââââââââ¤863    âCPU 24 âCPU 25 âCPU 26 âCPU 27 â864    âââââââââ¼ââââââââ¼ââââââââ¼ââââââââ¤865    âCPU 28 âCPU 29 âCPU 30 âCPU 31 â866    âââââââââ´ââââââââ´ââââââââ´ââââââââ867    """)868    self.assertEqual(output(), expected)869  def test_sharding_with_replication(self):870    shape = (8, 8)871    mesh = maps.Mesh(self._create_devices((8, 4)), ['x', 'y'])872    pspec = pjit.PartitionSpec('x', None)873    sd = sharding.MeshPspecSharding(mesh, pspec)874    with jtu.capture_stdout() as output:875      debugging.visualize_sharding(shape, sd)876    expected = _format_multiline("""877    âââââââââââââââââââââââââ878    â      CPU 0,1,2,3      â879    âââââââââââââââââââââââââ¤880    â      CPU 4,5,6,7      â881    âââââââââââââââââââââââââ¤882    â     CPU 8,9,10,11     â883    âââââââââââââââââââââââââ¤884    â    CPU 12,13,14,15    â885    âââââââââââââââââââââââââ¤886    â    CPU 16,17,18,19    â887    âââââââââââââââââââââââââ¤888    â    CPU 20,21,22,23    â889    âââââââââââââââââââââââââ¤890    â    CPU 24,25,26,27    â891    âââââââââââââââââââââââââ¤892    â    CPU 28,29,30,31    â893    âââââââââââââââââââââââââ894    """)895    self.assertEqual(output(), expected)896    mesh = maps.Mesh(self._create_devices((4, 2)), ['x', 'y'])897    pspec = pjit.PartitionSpec(None, 'y')898    sd = sharding.MeshPspecSharding(mesh, pspec)899    with jtu.capture_stdout() as output:900      debugging.visualize_sharding(shape, sd)901    expected = _format_multiline("""902    âââââââââââââ¬ââââââââââââ903    â           â           â904    â           â           â905    â           â           â906    â           â           â907    âCPU 0,2,4,6âCPU 1,3,5,7â908    â           â 
          â909    â           â           â910    â           â           â911    â           â           â912    âââââââââââââ´ââââââââââââ913    """)914    self.assertEqual(output(), expected)915  def test_visualize_wide_array(self):916    shape = (128, 10000)917    mesh = maps.Mesh(self._create_devices((8, 4)), ['x', 'y'])918    pspec = pjit.PartitionSpec('x', None)919    sd = sharding.MeshPspecSharding(mesh, pspec)920    with jtu.capture_stdout() as output:921      debugging.visualize_sharding(shape, sd)922    expected = _format_multiline("""923    ââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââ924    â                                 CPU 0,1,2,3                                  â925    ââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââ¤926    â                                 CPU 4,5,6,7                                  â927    ââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââ¤928    â                                CPU 8,9,10,11                                 â929    ââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââ¤930    â                               CPU 12,13,14,15                                â931    ââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââ¤932    â                               CPU 16,17,18,19                                â933    ââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââ¤934    â                               CPU 20,21,22,23                                â935    ââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââ¤936    â                               CPU 24,25,26,27                                â937    ââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââ¤938    â                               CPU 28,29,30,31                                â939    
ââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââ940    """)941    self.assertEqual(output(), expected)942  def test_visualize_pmap_sharding(self):943    ss = pxla.ShardingSpec(944        sharding=(pxla.Unstacked(8),),945        mesh_mapping=(pxla.ShardedAxis(0),))946    sd = sharding.PmapSharding(self._create_devices(8), ss)947    shape = (8,)948    with jtu.capture_stdout() as output:949      debugging.visualize_sharding(shape, sd)950    expected = _format_multiline("""951    âââââââââ¬ââââââââ¬ââââââââ¬ââââââââ¬ââââââââ¬ââââââââ¬ââââââââ¬ââââââââ952    â CPU 0 â CPU 1 â CPU 2 â CPU 3 â CPU 4 â CPU 5 â CPU 6 â CPU 7 â953    âââââââââ´ââââââââ´ââââââââ´ââââââââ´ââââââââ´ââââââââ´ââââââââ´ââââââââ954    """)955    self.assertEqual(output(), expected)956    ss = pxla.ShardingSpec(957        sharding=(pxla.Unstacked(8), pxla.NoSharding()),958        mesh_mapping=(pxla.ShardedAxis(0),))959    sd = sharding.PmapSharding(self._create_devices(8), ss)960    shape = (8, 2)961    with jtu.capture_stdout() as output:962      debugging.visualize_sharding(shape, sd)963    expected = _format_multiline("""964    âââââââââ965    â CPU 0 â966    âââââââââ¤967    â CPU 1 â968    âââââââââ¤969    â CPU 2 â970    âââââââââ¤971    â CPU 3 â972    âââââââââ¤973    â CPU 4 â974    âââââââââ¤975    â CPU 5 â976    âââââââââ¤977    â CPU 6 â...debugger_test.py
Source:debugger_test.py  
...33  for command in commands:34    fake_stdin.write(command + "\n")35  fake_stdin.seek(0)36  return fake_stdin, io.StringIO()37def _format_multiline(text):38  return textwrap.dedent(text).lstrip()39prev_xla_flags = None40def setUpModule():41  global prev_xla_flags42  # This will control the CPU devices. On TPU we always have 2 devices43  prev_xla_flags = jtu.set_host_platform_device_count(2)44# Reset to previous configuration in case other test modules will be run.45def tearDownModule():46  prev_xla_flags()47# TODO(sharadmv): remove jaxlib guards for TPU tests when jaxlib minimum48#                 version is >= 0.3.1549disabled_backends = []50if jaxlib.version < (0, 3, 15):51  disabled_backends.append("tpu")52class CliDebuggerTest(jtu.JaxTestCase):53  @jtu.skip_on_devices(*disabled_backends)54  def test_debugger_eof(self):55    stdin, stdout = make_fake_stdin_stdout([])56    def f(x):57      y = jnp.sin(x)58      debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")59      return y60    with self.assertRaises(SystemExit):61      f(2.)62      jax.effects_barrier()63  @jtu.skip_on_devices(*disabled_backends)64  def test_debugger_can_continue(self):65    stdin, stdout = make_fake_stdin_stdout(["c"])66    def f(x):67      y = jnp.sin(x)68      debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")69      return y70    f(2.)71    jax.effects_barrier()72    expected = _format_multiline(r"""73    Entering jdb:74    (jdb) """)75    self.assertEqual(stdout.getvalue(), expected)76  @jtu.skip_on_devices(*disabled_backends)77  def test_debugger_can_print_value(self):78    stdin, stdout = make_fake_stdin_stdout(["p x", "c"])79    def f(x):80      y = jnp.sin(x)81      debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")82      return y83    expected = _format_multiline(r"""84    Entering jdb:85    (jdb) DeviceArray(2., dtype=float32)86    (jdb) """)87    f(jnp.array(2., jnp.float32))88    jax.effects_barrier()89    
self.assertEqual(stdout.getvalue(), expected)90  @jtu.skip_on_devices(*disabled_backends)91  def test_debugger_can_print_value_in_jit(self):92    stdin, stdout = make_fake_stdin_stdout(["p x", "c"])93    @jax.jit94    def f(x):95      y = jnp.sin(x)96      debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")97      return y98    expected = _format_multiline(r"""99    Entering jdb:100    (jdb) array(2., dtype=float32)101    (jdb) """)102    f(jnp.array(2., jnp.float32))103    jax.effects_barrier()104    self.assertEqual(stdout.getvalue(), expected)105  @jtu.skip_on_devices(*disabled_backends)106  def test_debugger_can_print_multiple_values(self):107    stdin, stdout = make_fake_stdin_stdout(["p x, y", "c"])108    @jax.jit109    def f(x):110      y = x + 1.111      debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")112      return y113    expected = _format_multiline(r"""114    Entering jdb:115    (jdb) (array(2., dtype=float32), array(3., dtype=float32))116    (jdb) """)117    f(jnp.array(2., jnp.float32))118    jax.effects_barrier()119    self.assertEqual(stdout.getvalue(), expected)120  @jtu.skip_on_devices(*disabled_backends)121  def test_debugger_can_print_context(self):122    stdin, stdout = make_fake_stdin_stdout(["l", "c"])123    @jax.jit124    def f(x):125      y = jnp.sin(x)126      debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")127      return y128    f(2.)129    jax.effects_barrier()130    expected = _format_multiline(r"""131    Entering jdb:132    \(jdb\) > .*debugger_test\.py\([0-9]+\)133            @jax\.jit134            def f\(x\):135              y = jnp\.sin\(x\)136    ->        debugger\.breakpoint\(stdin=stdin, stdout=stdout, backend="cli"\)137              return y138    .*139    \(jdb\) """)140    self.assertRegex(stdout.getvalue(), expected)141  @jtu.skip_on_devices(*disabled_backends)142  def test_debugger_can_print_backtrace(self):143    stdin, stdout = make_fake_stdin_stdout(["bt", "c"])144    @jax.jit145 
   def f(x):146      y = jnp.sin(x)147      debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")148      return y149    expected = _format_multiline(r"""150    Entering jdb:.*151    \(jdb\) Traceback:.*152    """)153    f(2.)154    jax.effects_barrier()155    self.assertRegex(stdout.getvalue(), expected)156  @jtu.skip_on_devices(*disabled_backends)157  def test_debugger_can_work_with_multiple_stack_frames(self):158    stdin, stdout = make_fake_stdin_stdout(["l", "u", "p x", "d", "c"])159    def f(x):160      y = jnp.sin(x)161      debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")162      return y163    @jax.jit164    def g(x):165      y = f(x)166      return jnp.exp(y)167    expected = _format_multiline(r"""168    Entering jdb:169    \(jdb\) > .*debugger_test\.py\([0-9]+\)170            def f\(x\):171              y = jnp\.sin\(x\)172    ->        debugger\.breakpoint\(stdin=stdin, stdout=stdout, backend="cli"\)173              return y174    .*175    \(jdb\) > .*debugger_test\.py\([0-9]+\).*176            @jax\.jit177            def g\(x\):178    ->        y = f\(x\)179              return jnp\.exp\(y\)180    .*181    \(jdb\) array\(2\., dtype=float32\)182    \(jdb\) > .*debugger_test\.py\([0-9]+\)183            def f\(x\):184              y = jnp\.sin\(x\)185    ->        debugger\.breakpoint\(stdin=stdin, stdout=stdout, backend="cli"\)186              return y187    .*188    \(jdb\) """)189    g(jnp.array(2., jnp.float32))190    jax.effects_barrier()191    self.assertRegex(stdout.getvalue(), expected)192  @jtu.skip_on_devices(*disabled_backends)193  def test_can_use_multiple_breakpoints(self):194    stdin, stdout = make_fake_stdin_stdout(["p y", "c", "p y", "c"])195    def f(x):196      y = x + 1.197      debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=True,198          backend="cli")199      return y200    @jax.jit201    def g(x):202      y = f(x) * 2.203      debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=True,204        
  backend="cli")205      return jnp.exp(y)206    expected = _format_multiline(r"""207    Entering jdb:208    (jdb) array(3., dtype=float32)209    (jdb) Entering jdb:210    (jdb) array(6., dtype=float32)211    (jdb) """)212    g(jnp.array(2., jnp.float32))213    jax.effects_barrier()214    self.assertEqual(stdout.getvalue(), expected)215  @jtu.skip_on_devices(*disabled_backends)216  def test_debugger_works_with_vmap(self):217    stdin, stdout = make_fake_stdin_stdout(["p y", "c", "p y", "c"])218    # On TPU, the breakpoints can be reordered inside of vmap but can be fixed219    # by ordering sends.220    # TODO(sharadmv): change back to ordered = False when sends are ordered221    ordered = jax.default_backend() == "tpu"222    def f(x):223      y = x + 1.224      debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=ordered,225          backend="cli")226      return 2. * y227    @jax.jit228    @jax.vmap229    def g(x):230      y = f(x)231      return jnp.exp(y)232    expected = _format_multiline(r"""233    Entering jdb:234    (jdb) array(1., dtype=float32)235    (jdb) Entering jdb:236    (jdb) array(2., dtype=float32)237    (jdb) """)238    g(jnp.arange(2., dtype=jnp.float32))239    jax.effects_barrier()240    self.assertEqual(stdout.getvalue(), expected)241  @jtu.skip_on_devices(*disabled_backends)242  def test_debugger_works_with_pmap(self):243    if jax.local_device_count() < 2:244      raise unittest.SkipTest("Test requires >= 2 devices.")245    stdin, stdout = make_fake_stdin_stdout(["p y", "c", "p y", "c"])246    def f(x):247      y = jnp.sin(x)248      debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")249      return y250    @jax.pmap251    def g(x):252      y = f(x)253      return jnp.exp(y)254    expected = _format_multiline(r"""255    Entering jdb:256    \(jdb\) array\(.*, dtype=float32\)257    \(jdb\) Entering jdb:258    \(jdb\) array\(.*, dtype=float32\)259    \(jdb\) """)260    g(jnp.arange(2., dtype=jnp.float32))261    
jax.effects_barrier()262    self.assertRegex(stdout.getvalue(), expected)263  @jtu.skip_on_devices(*disabled_backends)264  def test_debugger_works_with_pjit(self):265    if jax.default_backend() != "tpu":266      raise unittest.SkipTest("`pjit` doesn't work with CustomCall.")267    stdin, stdout = make_fake_stdin_stdout(["p y", "c"])268    def f(x):269      y = x + 1270      debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")271      return y272    def g(x):273      y = f(x)274      return jnp.exp(y)275    g = pjit.pjit(g, in_axis_resources=pjit.PartitionSpec("dev"),276                  out_axis_resources=pjit.PartitionSpec("dev"))277    with maps.Mesh(np.array(jax.devices()), ["dev"]):278      arr = (1 + np.arange(8)).astype(np.int32)279      expected = _format_multiline(r"""280      Entering jdb:281      \(jdb\) {}282      \(jdb\) """.format(re.escape(repr(arr))))283      g(jnp.arange(8, dtype=jnp.int32))284      jax.effects_barrier()285      print(stdout.getvalue())286      print(expected)287      self.assertRegex(stdout.getvalue(), expected)288  @jtu.skip_on_devices(*disabled_backends)289  def test_debugger_uses_local_before_global_scope(self):290    stdin, stdout = make_fake_stdin_stdout(["p foo", "c"])291    foo = "outer"292    def f(x):293      foo = "inner"294      debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")295      del foo296      return x297    del foo298    expected = _format_multiline(r"""299    Entering jdb:300    \(jdb\) 'inner'301    \(jdb\) """)302    f(2.)303    jax.effects_barrier()304    print(stdout.getvalue())305    print(expected)306    self.assertRegex(stdout.getvalue(), expected)307if __name__ == '__main__':...Learn to execute automation testing from scratch with LambdaTest Learning Hub. Right from setting up the prerequisites to run your first automation test, to following best practices and diving deeper into advanced test scenarios. 
LambdaTest Learning Hubs compile a list of step-by-step guides to help you become proficient with different test automation frameworks, e.g., Selenium, Cypress, and TestNG.
You can also refer to the video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.
Get 100 minutes of automation testing FREE!
