How to use the cond method in Hypothesis

Best Python code snippets using Hypothesis
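Hypothesis itself does not expose a public method named cond; conditional logic inside a property-based test is normally expressed with assume() or a strategy's .filter(). The snippets below are drawn from open-source projects whose own code defines cond-style helpers (TensorFlow's cond_v2, asyncio.Condition, and NumPy boolean condition masks). As a starting point, here is a minimal sketch of conditional filtering in a Hypothesis test; the property and the integer bounds are illustrative, not taken from the snippets below:

from hypothesis import assume, given, strategies as st

@given(st.integers(min_value=-100, max_value=100).filter(lambda n: n % 2 == 0))
def test_doubling_keeps_even_numbers_even(n):
    # assume() discards inputs that fail a condition instead of branching on them.
    assume(n != 0)
    assert (2 * n) % 2 == 0

Run it with pytest like any other test; .filter() shapes the generated values up front, while assume() rejects unwanted cases at run time.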

cond_v2_test.py

Source: cond_v2_test.py (GitHub)


# ... (excerpt: the file's imports and the opening of the test-case class are
# elided; the lines below are the body of the _testCond helper used throughout)
    if not feed_dict:
      feed_dict = {}
    with self.session(graph=ops.get_default_graph()) as sess:
      pred = array_ops.placeholder(dtypes.bool, name="pred")
      expected = control_flow_ops.cond(pred, true_fn, false_fn, name="expected")
      actual = cond_v2.cond_v2(pred, true_fn, false_fn, name="actual")
      expected_grad = gradients_impl.gradients(expected, train_vals)
      actual_grad = gradients_impl.gradients(actual, train_vals)
      sess_run_args = {pred: True}
      sess_run_args.update(feed_dict)
      expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run(
          (expected, actual, expected_grad, actual_grad), sess_run_args)
      self.assertEqual(expected_val, actual_val)
      self.assertEqual(expected_grad_val, actual_grad_val)
      sess_run_args = {pred: False}
      sess_run_args.update(feed_dict)
      expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run(
          (expected, actual, expected_grad, actual_grad), sess_run_args)
      self.assertEqual(expected_val, actual_val)
      self.assertEqual(expected_grad_val, actual_grad_val)

  @test_util.run_deprecated_v1
  def testBasic(self):
    x = constant_op.constant(1.0, name="x")
    y = constant_op.constant(2.0, name="y")
    def true_fn():
      return x * 2.0
    def false_fn():
      return y * 3.0
    self._testCond(true_fn, false_fn, [x])
    self._testCond(true_fn, false_fn, [x, y])
    self._testCond(true_fn, false_fn, [y])

  def testExternalControlDependencies(self):
    with ops.Graph().as_default(), self.test_session():
      v = variables.Variable(1.0)
      v.initializer.run()
      op = v.assign_add(1.0)
      def true_branch():
        with ops.control_dependencies([op]):
          return 1.0
      cond_v2.cond_v2(array_ops.placeholder_with_default(False, None),
                      true_branch,
                      lambda: 2.0).eval()
      self.assertAllEqual(self.evaluate(v), 2.0)

  @test_util.run_deprecated_v1
  def testMultipleOutputs(self):
    x = constant_op.constant(1.0, name="x")
    y = constant_op.constant(3.0, name="y")
    def true_fn():
      return x * y, y
    def false_fn():
      return x, y * 3.0
    self._testCond(true_fn, false_fn, [x])
    self._testCond(true_fn, false_fn, [x, y])
    self._testCond(true_fn, false_fn, [y])

  @test_util.run_deprecated_v1
  def testBasic2(self):
    x = constant_op.constant(1.0, name="x")
    y = constant_op.constant(2.0, name="y")
    def true_fn():
      return x * y * 2.0
    def false_fn():
      return 2.0
    self._testCond(true_fn, false_fn, [x])
    self._testCond(true_fn, false_fn, [x, y])
    self._testCond(true_fn, false_fn, [y])

  @test_util.run_deprecated_v1
  def testNoInputs(self):
    with self.cached_session() as sess:
      pred = array_ops.placeholder(dtypes.bool, name="pred")
      def true_fn():
        return constant_op.constant(1.0)
      def false_fn():
        return constant_op.constant(2.0)
      out = cond_v2.cond_v2(pred, true_fn, false_fn)
      self.assertEqual(sess.run(out, {pred: True}), (1.0,))
      self.assertEqual(sess.run(out, {pred: False}), (2.0,))

  def _createCond(self, name):
    """Creates a cond_v2 call and returns the output tensor and the cond op."""
    pred = constant_op.constant(True, name="pred")
    x = constant_op.constant(1.0, name="x")
    def true_fn():
      return x
    def false_fn():
      return x + 1
    output = cond_v2.cond_v2(pred, true_fn, false_fn, name=name)
    cond_op = output.op.inputs[0].op
    self.assertEqual(cond_op.type, "If")
    return output, cond_op

  def _createNestedCond(self, name):
    """Like _createCond but creates a nested cond_v2 call as well."""
    pred = constant_op.constant(True, name="pred")
    x = constant_op.constant(1.0, name="x")
    def true_fn():
      return cond_v2.cond_v2(pred, lambda: x, lambda: x + 1)
    def false_fn():
      return x + 2
    output = cond_v2.cond_v2(pred, true_fn, false_fn, name=name)
    cond_op = output.op.inputs[0].op
    self.assertEqual(cond_op.type, "If")
    return output, cond_op

  def testDefaultName(self):
    with ops.Graph().as_default():
      _, cond_op = self._createCond(None)
      self.assertEqual(cond_op.name, "cond")
      self.assertRegexpMatches(
          cond_op.get_attr("then_branch").name, r"cond_true_\d*")
      self.assertRegexpMatches(
          cond_op.get_attr("else_branch").name, r"cond_false_\d*")
    with ops.Graph().as_default():
      with ops.name_scope("foo"):
        _, cond1_op = self._createCond("")
        self.assertEqual(cond1_op.name, "foo/cond")
        self.assertRegexpMatches(
            cond1_op.get_attr("then_branch").name, r"foo_cond_true_\d*")
        self.assertRegexpMatches(
            cond1_op.get_attr("else_branch").name, r"foo_cond_false_\d*")
        _, cond2_op = self._createCond(None)
        self.assertEqual(cond2_op.name, "foo/cond_1")
        self.assertRegexpMatches(
            cond2_op.get_attr("then_branch").name, r"foo_cond_1_true_\d*")
        self.assertRegexpMatches(
            cond2_op.get_attr("else_branch").name, r"foo_cond_1_false_\d*")

  @test_util.run_v1_only("b/120545219")
  def testDefunInCond(self):
    x = constant_op.constant(1.0, name="x")
    y = constant_op.constant(2.0, name="y")
    def true_fn():
      @function.defun
      def fn():
        return x * y * 2.0
      return fn()
    def false_fn():
      return 2.0
    self._testCond(true_fn, false_fn, [x])
    self._testCond(true_fn, false_fn, [x, y])
    self._testCond(true_fn, false_fn, [y])

  @test_util.run_deprecated_v1
  def testNestedDefunInCond(self):
    x = constant_op.constant(1.0, name="x")
    y = constant_op.constant(2.0, name="y")
    def true_fn():
      return 2.0
    def false_fn():
      @function.defun
      def fn():
        @function.defun
        def nested_fn():
          return x * y * 2.0
        return nested_fn()
      return fn()
    self._testCond(true_fn, false_fn, [x])
    self._testCond(true_fn, false_fn, [x, y])
    self._testCond(true_fn, false_fn, [y])

  @test_util.run_deprecated_v1
  def testDoubleNestedDefunInCond(self):
    x = constant_op.constant(1.0, name="x")
    y = constant_op.constant(2.0, name="y")
    def true_fn():
      @function.defun
      def fn():
        @function.defun
        def nested_fn():
          @function.defun
          def nested_nested_fn():
            return x * y * 2.0
          return nested_nested_fn()
        return nested_fn()
      return fn()
    def false_fn():
      return 2.0
    self._testCond(true_fn, false_fn, [x])
    self._testCond(true_fn, false_fn, [x, y])
    self._testCond(true_fn, false_fn, [y])

  def testNestedCond(self):
    def run_test(pred_value):
      def build_graph():
        pred = array_ops.placeholder(dtypes.bool, name="pred")
        x = constant_op.constant(1.0, name="x")
        y = constant_op.constant(2.0, name="y")
        def true_fn():
          return 2.0
        def false_fn():
          def false_true_fn():
            return x * y * 2.0
          def false_false_fn():
            return x * 5.0
          return _cond(pred, false_true_fn, false_false_fn, "inside_false_fn")
        return x, y, pred, true_fn, false_fn
      with ops.Graph().as_default():
        x, y, pred, true_fn, false_fn = build_graph()
        self._testCond(true_fn, false_fn, [x, y], {pred: pred_value})
        self._testCond(true_fn, false_fn, [x], {pred: pred_value})
        self._testCond(true_fn, false_fn, [y], {pred: pred_value})
    run_test(True)
    run_test(False)

  def testNestedCondBothBranches(self):
    def run_test(pred_value):
      def build_graph():
        pred = array_ops.placeholder(dtypes.bool, name="pred")
        x = constant_op.constant(1.0, name="x")
        y = constant_op.constant(2.0, name="y")
        def true_fn():
          return _cond(pred, lambda: x + y, lambda: x * x, name=None)
        def false_fn():
          return _cond(pred, lambda: x - y, lambda: y * y, name=None)
        return x, y, pred, true_fn, false_fn
      with ops.Graph().as_default():
        x, y, pred, true_fn, false_fn = build_graph()
        self._testCond(true_fn, false_fn, [x, y], {pred: pred_value})
        self._testCond(true_fn, false_fn, [x], {pred: pred_value})
        self._testCond(true_fn, false_fn, [y], {pred: pred_value})
    run_test(True)
    run_test(False)

  def testDoubleNestedCond(self):
    def run_test(pred1_value, pred2_value):
      def build_graph():
        pred1 = array_ops.placeholder(dtypes.bool, name="pred1")
        pred2 = array_ops.placeholder(dtypes.bool, name="pred2")
        x = constant_op.constant(1.0, name="x")
        y = constant_op.constant(2.0, name="y")
        def true_fn():
          return 2.0
        def false_fn():
          def false_true_fn():
            def false_true_true_fn():
              return x * y * 2.0
            def false_true_false_fn():
              return x * 10.0
            return _cond(
                pred1,
                false_true_true_fn,
                false_true_false_fn,
                name="inside_false_true_fn")
          def false_false_fn():
            return x * 5.0
          return _cond(
              pred2, false_true_fn, false_false_fn, name="inside_false_fn")
        return x, y, pred1, pred2, true_fn, false_fn
      with ops.Graph().as_default():
        x, y, pred1, pred2, true_fn, false_fn = build_graph()
        self._testCond(true_fn, false_fn, [x, y], {
            pred1: pred1_value,
            pred2: pred2_value
        })
        x, y, pred1, pred2, true_fn, false_fn = build_graph()
        self._testCond(true_fn, false_fn, [x], {
            pred1: pred1_value,
            pred2: pred2_value
        })
        x, y, pred1, pred2, true_fn, false_fn = build_graph()
        self._testCond(true_fn, false_fn, [y], {
            pred1: pred1_value,
            pred2: pred2_value
        })
    run_test(True, True)
    run_test(True, False)
    run_test(False, False)
    run_test(False, True)

  def testGradientFromInsideDefun(self):
    def build_graph():
      pred_outer = array_ops.placeholder(dtypes.bool, name="pred_outer")
      pred_inner = array_ops.placeholder(dtypes.bool, name="pred_inner")
      x = constant_op.constant(1.0, name="x")
      y = constant_op.constant(2.0, name="y")
      def true_fn():
        return 2.0
      def false_fn():
        def inner_true_fn():
          return x * y * 2.0
        def inner_false_fn():
          return x * 5.0
        return cond_v2.cond_v2(
            pred_inner, inner_true_fn, inner_false_fn, name="inner_cond")
      cond_outer = cond_v2.cond_v2(
          pred_outer, true_fn, false_fn, name="outer_cond")
      # Compute grads inside a Defun.
      @function.defun
      def nesting_fn():
        return gradients_impl.gradients(cond_outer, [x, y])
      grads = nesting_fn()
      return grads, pred_outer, pred_inner
    with ops.Graph().as_default():
      grads, pred_outer, pred_inner = build_graph()
      with self.session(graph=ops.get_default_graph()) as sess:
        self.assertSequenceEqual(
            sess.run(grads, {
                pred_outer: True,
                pred_inner: True
            }), [0., 0.])
        self.assertSequenceEqual(
            sess.run(grads, {
                pred_outer: True,
                pred_inner: False
            }), [0., 0.])
        self.assertSequenceEqual(
            sess.run(grads, {
                pred_outer: False,
                pred_inner: True
            }), [4., 2.])
        self.assertSequenceEqual(
            sess.run(grads, {
                pred_outer: False,
                pred_inner: False
            }), [5., 0.])

  def testGradientFromInsideNestedDefun(self):
    def build_graph():
      pred_outer = array_ops.placeholder(dtypes.bool, name="pred_outer")
      pred_inner = array_ops.placeholder(dtypes.bool, name="pred_inner")
      x = constant_op.constant(1.0, name="x")
      y = constant_op.constant(2.0, name="y")
      def true_fn():
        return 2.0
      def false_fn():
        def inner_true_fn():
          return x * y * 2.0
        def inner_false_fn():
          return x * 5.0
        return cond_v2.cond_v2(
            pred_inner, inner_true_fn, inner_false_fn, name="inner_cond")
      cond_outer = cond_v2.cond_v2(
          pred_outer, true_fn, false_fn, name="outer_cond")
      # Compute grads inside a Defun.
      @function.defun
      def nesting_fn():
        @function.defun
        def inner_nesting_fn():
          return gradients_impl.gradients(cond_outer, [x, y])
        return inner_nesting_fn()
      grads = nesting_fn()
      return grads, pred_outer, pred_inner
    with ops.Graph().as_default():
      grads, pred_outer, pred_inner = build_graph()
      with self.session(graph=ops.get_default_graph()) as sess:
        self.assertSequenceEqual(
            sess.run(grads, {
                pred_outer: True,
                pred_inner: True
            }), [0., 0.])
        self.assertSequenceEqual(
            sess.run(grads, {
                pred_outer: True,
                pred_inner: False
            }), [0., 0.])
        self.assertSequenceEqual(
            sess.run(grads, {
                pred_outer: False,
                pred_inner: True
            }), [4., 2.])
        self.assertSequenceEqual(
            sess.run(grads, {
                pred_outer: False,
                pred_inner: False
            }), [5., 0.])

  def testBuildCondAndGradientInsideDefun(self):
    def build_graph():
      pred_outer = array_ops.placeholder(dtypes.bool, name="pred_outer")
      pred_inner = array_ops.placeholder(dtypes.bool, name="pred_inner")
      x = constant_op.constant(1.0, name="x")
      y = constant_op.constant(2.0, name="y")
      # Build cond and its gradient inside a Defun.
      @function.defun
      def fn():
        def true_fn():
          return 2.0
        def false_fn():
          def inner_true_fn():
            return x * y * 2.0
          def inner_false_fn():
            return x * 5.0
          return cond_v2.cond_v2(
              pred_inner, inner_true_fn, inner_false_fn, name="inner_cond")
        cond_outer = cond_v2.cond_v2(
            pred_outer, true_fn, false_fn, name="outer_cond")
        return gradients_impl.gradients(cond_outer, [x, y])
      grads = fn()
      return grads, pred_outer, pred_inner
    with ops.Graph().as_default(), self.session(
        graph=ops.get_default_graph()) as sess:
      grads, pred_outer, pred_inner = build_graph()
      self.assertSequenceEqual(
          sess.run(grads, {
              pred_outer: True,
              pred_inner: True
          }), [0., 0.])
      self.assertSequenceEqual(
          sess.run(grads, {
              pred_outer: True,
              pred_inner: False
          }), [0., 0.])
      self.assertSequenceEqual(
          sess.run(grads, {
              pred_outer: False,
              pred_inner: True
          }), [4., 2.])
      self.assertSequenceEqual(
          sess.run(grads, {
              pred_outer: False,
              pred_inner: False
          }), [5., 0.])

  @test_util.run_deprecated_v1
  def testSecondDerivative(self):
    with self.cached_session() as sess:
      pred = array_ops.placeholder(dtypes.bool, name="pred")
      x = constant_op.constant(3.0, name="x")
      def true_fn():
        return math_ops.pow(x, 3)
      def false_fn():
        return x
      cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
      cond_grad = gradients_impl.gradients(cond, [x])
      cond_grad_grad = gradients_impl.gradients(cond_grad, [x])
      # d[x^3]/dx = 3x^2
      true_val = sess.run(cond_grad, {pred: True})
      self.assertEqual(true_val, [27.0])
      # d[x]/dx = 1
      false_val = sess.run(cond_grad, {pred: False})
      self.assertEqual(false_val, [1.0])
      true_val = sess.run(cond_grad_grad, {pred: True})
      # d2[x^3]/dx2 = 6x
      self.assertEqual(true_val, [18.0])
      false_val = sess.run(cond_grad_grad, {pred: False})
      # d2[x]/dx2 = 0
      self.assertEqual(false_val, [0.0])

  def testGradientOfDeserializedCond(self):
    with ops.Graph().as_default():
      pred = array_ops.placeholder(dtypes.bool, name="pred")
      x = constant_op.constant(3.0, name="x")
      ops.add_to_collection("x", x)
      def true_fn():
        return math_ops.pow(x, 3)
      def false_fn():
        return x
      ops.add_to_collection("pred", pred)
      cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
      ops.add_to_collection("cond", cond)
      meta_graph = saver.export_meta_graph()
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        saver.import_meta_graph(meta_graph)
        x = ops.get_collection("x")[0]
        pred = ops.get_collection("pred")[0]
        cond = ops.get_collection("cond")
        cond_grad = gradients_impl.gradients(cond, [x], name="cond_grad")
        cond_grad_grad = gradients_impl.gradients(
            cond_grad, [x], name="cond_grad_grad")
        # d[x^3]/dx = 3x^2
        true_val = sess.run(cond_grad, {pred: True})
        self.assertEqual(true_val, [27.0])
        # d[x]/dx = 1
        false_val = sess.run(cond_grad, {pred: False})
        self.assertEqual(false_val, [1.0])
        true_val = sess.run(cond_grad_grad, {pred: True})
        # d2[x^3]/dx2 = 6x
        self.assertEqual(true_val, [18.0])
        false_val = sess.run(cond_grad_grad, {pred: False})
        # d2[x]/dx2 = 0
        self.assertEqual(false_val, [0.0])

  def testGradientTapeOfCondWithResourceVariableInFunction(self):
    with context.eager_mode():
      v = variables.Variable(2.)
      @def_function.function
      def fnWithCond():  # pylint: disable=invalid-name
        with backprop.GradientTape() as tape:
          pred = constant_op.constant(True, dtype=dtypes.bool)
          def true_fn():
            return math_ops.pow(v, 3)
          def false_fn():
            return v
          cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
        return tape.gradient(cond, v)
      self.assertAllEqual(fnWithCond(), 12.0)

  def testLowering(self):
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        cond_output, _ = self._createCond("cond")
        run_options = config_pb2.RunOptions(output_partition_graphs=True)
        run_metadata = config_pb2.RunMetadata()
        sess.run(cond_output, options=run_options, run_metadata=run_metadata)
        # If lowering was enabled, there should be a `Switch` node
        switch_found = any(
            any(node.op == "Switch" for node in graph.node)
            for graph in run_metadata.partition_graphs
        )
        self.assertTrue(switch_found,
                        "A `Switch` op should exist if the graph was lowered.")
        # If lowering was enabled, there should be no `If` node
        if_found = any(
            any(node.op == "If" for node in graph.node)
            for graph in run_metadata.partition_graphs
        )
        self.assertFalse(if_found,
                         "An `If` op was found, but it should be lowered.")

  @test_util.run_deprecated_v1
  def testLoweringDisabledInXLA(self):
    with self.session(graph=ops.Graph()) as sess:
      # Build the cond_v2 in an XLA context
      xla_context = control_flow_ops.XLAControlFlowContext()
      xla_context.Enter()
      cond_output, cond_op = self._createCond("cond")
      xla_context.Exit()
      # Check lowering attr is not set.
      with self.assertRaises(ValueError):
        cond_op.get_attr("_lower_using_switch_merge")
      # Check the actual graph that is run.
      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      run_metadata = config_pb2.RunMetadata()
      sess.run(cond_output, options=run_options, run_metadata=run_metadata)
      # Lowering disabled in XLA, there should be no `Switch` node
      switch_found = any(
          any(node.op == "Switch" for node in graph.node)
          for graph in run_metadata.partition_graphs
      )
      self.assertFalse(
          switch_found,
          "A `Switch` op exists, but the graph should not be lowered.")
      # Lowering disabled in XLA, there should still be an `If` node
      if_found = any(
          any(node.op == "If" for node in graph.node)
          for graph in run_metadata.partition_graphs
      )
      self.assertTrue(
          if_found,
          "An `If` op was not found, but the graph should not be lowered.")

  @test_util.run_deprecated_v1
  def testNestedLoweringDisabledInXLA(self):
    # Build the cond_v2 in an XLA context
    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()
    _, cond_op = self._createNestedCond("cond")
    xla_context.Exit()
    # Check lowering attr is not set for either If node.
    with self.assertRaises(ValueError):
      cond_op.get_attr("_lower_using_switch_merge")
    nested_if_ops = []
    for func in ops.get_default_graph()._functions.values():
      nested_if_ops.extend(op for op in func.graph.get_operations()
                           if op.type == "If")
    self.assertEqual(len(nested_if_ops), 1)
    with self.assertRaises(ValueError):
      nested_if_ops[0].get_attr("_lower_using_switch_merge")
    # TODO(skyewm): check the actual graphs that are run once we have a way to
    # programmatically access those graphs.

  # b/131355614
  @test_util.run_deprecated_v1
  def testNoOptionalsInXla(self):
    @def_function.function
    def func_with_cond():
      pred = constant_op.constant(True, name="pred")
      x = constant_op.constant(1.0, name="x")
      def true_fn():
        intermediate = x + 1
        return intermediate * x
      def false_fn():
        return x + 1
      output = cond_v2.cond_v2(pred, true_fn, false_fn)
      grad = gradients_impl.gradients(output, x)[0]
      forward_if_op = output.op.inputs[0].op
      gradient_if_op = grad.op.inputs[0].op
      def verify_no_optional_ops(op, branch_name):
        branch_function = ops.get_default_graph()._get_function(
            op.get_attr(branch_name).name)
        function_def = branch_function.definition
        for node_def in function_def.node_def:
          self.assertNotIn(node_def.op, _OPTIONAL_OPS)
      verify_no_optional_ops(forward_if_op, "then_branch")
      verify_no_optional_ops(forward_if_op, "else_branch")
      verify_no_optional_ops(gradient_if_op, "then_branch")
      verify_no_optional_ops(gradient_if_op, "else_branch")
      return grad
    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()
    func_with_cond()
    xla_context.Exit()

  @test_util.run_deprecated_v1
  def testLoweringDisabledWithSingleThreadedExecutorContext(self):
    with self.session(graph=ops.Graph()) as sess:
      @function.defun
      def _add_cond(x):
        return cond_v2.cond_v2(
            constant_op.constant(True, name="pred"),
            lambda: x,
            lambda: x + 1)
      x = array_ops.placeholder(shape=None, dtype=dtypes.float32)
      with context.function_executor_type("SINGLE_THREADED_EXECUTOR"):
        out_cond = _add_cond(x)
      # The fact that sess.run() succeeds means lowering is disabled, because
      # the single threaded executor does not support cond v1 ops.
      sess.run(out_cond, feed_dict={x: 1.0})

  @test_util.enable_control_flow_v2
  def testStructuredOutputs(self):
    x = constant_op.constant(1.0, name="x")
    y = constant_op.constant(3.0, name="y")
    def true_fn():
      return ((x * y,), y)
    def false_fn():
      return ((x,), y * 3.0)
    output = control_flow_ops.cond(
        constant_op.constant(False), true_fn, false_fn)
    self.assertEqual(self.evaluate(output[0][0]), 1.)
    self.assertEqual(self.evaluate(output[1]), 9.)

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  def testRaisesOutputStructuresMismatch(self):
    x = constant_op.constant(1.0, name="x")
    y = constant_op.constant(3.0, name="y")
    def true_fn():
      return x * y, y
    def false_fn():
      return ((x,), y * 3.0)
    with self.assertRaisesRegexp(
        TypeError, "true_fn and false_fn arguments to tf.cond must have the "
        "same number, type, and overall structure of return values."):
      control_flow_ops.cond(constant_op.constant(False), true_fn, false_fn)

  @test_util.enable_control_flow_v2
  def testCondAndTensorArray(self):
    x = math_ops.range(-5, 5)
    output = tensor_array_ops.TensorArray(dtype=dtypes.int32, size=x.shape[0])
    def loop_body(i, output):
      def if_true():
        return output.write(i, x[i]**2)
      def if_false():
        return output.write(i, x[i])
      output = control_flow_ops.cond(x[i] > 0, if_true, if_false)
      return i + 1, output
    _, output = control_flow_ops.while_loop(
        lambda i, arr: i < x.shape[0],
        loop_body,
        loop_vars=(constant_op.constant(0), output))
    output_t = output.stack()
    self.assertAllEqual(
        self.evaluate(output_t), [-5, -4, -3, -2, -1, 0, 1, 4, 9, 16])

  @test_util.enable_control_flow_v2
  def testCondAndTensorArrayInDefun(self):
    @function.defun
    def f():
      x = math_ops.range(-5, 5)
      output = tensor_array_ops.TensorArray(dtype=dtypes.int32, size=x.shape[0])
      def loop_body(i, output):
        def if_true():
          return output.write(i, x[i]**2)
        def if_false():
          return output.write(i, x[i])
        output = control_flow_ops.cond(x[i] > 0, if_true, if_false)
        return i + 1, output
      _, output = control_flow_ops.while_loop(
          lambda i, arr: i < x.shape[0],
          loop_body,
          loop_vars=(constant_op.constant(0), output))
      return output.stack()
    output_t = f()
    self.assertAllEqual(
        self.evaluate(output_t), [-5, -4, -3, -2, -1, 0, 1, 4, 9, 16])

  @test_util.run_deprecated_v1
  def testForwardPassRewrite(self):
    x = constant_op.constant(1.0, name="x")
    output = cond_v2.cond_v2(constant_op.constant(True),
                             lambda: x * 2.0,
                             lambda: x)
    if_op = output.op.inputs[0].op
    self.assertEqual(if_op.type, "If")
    # pylint: disable=g-deprecated-assert
    self.assertEqual(len(if_op.outputs), 1)
    gradients_impl.gradients(output, x)
    # if_op should have been rewritten to output 2.0 intermediate.
    self.assertEqual(len(if_op.outputs), 2)
    gradients_impl.gradients(output, x)
    # Computing the gradient again shouldn't rewrite if_op again.
    self.assertEqual(len(if_op.outputs), 2)
    # pylint: enable=g-deprecated-assert


class CondV2CollectionTest(test.TestCase):

  def testCollectionIntValueAccessInCond(self):
    """Read values from graph collections inside of cond_v2."""
    with ops.Graph().as_default() as g:
      with self.session(graph=g):
        x = 2
        y = 5
        ops.add_to_collection("x", x)
        ops.add_to_collection("y", y)
        def fn():
          x_const = constant_op.constant(ops.get_collection("x")[0])
          y_const = constant_op.constant(ops.get_collection("y")[0])
          return math_ops.add(x_const, y_const)
        cnd = cond_v2.cond_v2(constant_op.constant(True), fn, fn)
        self.assertEquals(cnd.eval(), 7)

  def testCollectionTensorValueAccessInCond(self):
    """Read tensors from collections inside of cond_v2 & use them."""
    with ops.Graph().as_default() as g:
      with self.session(graph=g):
        x = constant_op.constant(2)
        y = constant_op.constant(5)
        ops.add_to_collection("x", x)
        ops.add_to_collection("y", y)
        def fn():
          x_read = ops.get_collection("x")[0]
          y_read = ops.get_collection("y")[0]
          return math_ops.add(x_read, y_read)
        cnd = cond_v2.cond_v2(math_ops.less(x, y), fn, fn)
        self.assertEquals(cnd.eval(), 7)

  def testCollectionIntValueWriteInCond(self):
    """Make sure Int writes to collections work inside of cond_v2."""
    with ops.Graph().as_default() as g:
      with self.session(graph=g):
        x = constant_op.constant(2)
        y = constant_op.constant(5)
        def true_fn():
          z = math_ops.add(x, y)
          ops.add_to_collection("z", 7)
          return math_ops.mul(x, z)
        def false_fn():
          z = math_ops.add(x, y)
          return math_ops.mul(x, z)
        cnd = cond_v2.cond_v2(constant_op.constant(True), true_fn, false_fn)
        self.assertEquals(cnd.eval(), 14)
        read_z_collection = ops.get_collection("z")
        self.assertEquals(read_z_collection, [7])


class CondV2ContainerTest(test.TestCase):

  def testContainer(self):
    """Set containers outside & inside of cond_v2.
    Make sure the containers are set correctly for both variable creation
    (tested by variables.Variable) and for stateful ops (tested by FIFOQueue)
    """
    self.skipTest("b/113048653")
    with ops.Graph().as_default() as g:
      with self.session(graph=g):
        v0 = variables.Variable([0])
        q0 = data_flow_ops.FIFOQueue(1, dtypes.float32)
        def container(node):
          return node.op.get_attr("container")
        self.assertEqual(compat.as_bytes(""), container(v0))
        self.assertEqual(compat.as_bytes(""), container(q0.queue_ref))
        def true_fn():
          # When this branch is created in cond below,
          # the container should begin with 'l1'
          v1 = variables.Variable([1])
          q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)
          with ops.container("l2t"):
            v2 = variables.Variable([2])
            q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)
          v3 = variables.Variable([1])
          q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)
          self.assertEqual(compat.as_bytes("l1"), container(v1))
          self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref))
          self.assertEqual(compat.as_bytes("l2t"), container(v2))
          self.assertEqual(compat.as_bytes("l2t"), container(q2.queue_ref))
          self.assertEqual(compat.as_bytes("l1"), container(v3))
          self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref))
          return constant_op.constant(2.0)
        def false_fn():
          # When this branch is created in cond below,
          # the container should begin with 'l1'
          v1 = variables.Variable([1])
          q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)
          with ops.container("l2f"):
            v2 = variables.Variable([2])
            q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)
          v3 = variables.Variable([1])
          q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)
          self.assertEqual(compat.as_bytes("l1"), container(v1))
          self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref))
          self.assertEqual(compat.as_bytes("l2f"), container(v2))
          self.assertEqual(compat.as_bytes("l2f"), container(q2.queue_ref))
          self.assertEqual(compat.as_bytes("l1"), container(v3))
          self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref))
          return constant_op.constant(6.0)
        with ops.container("l1"):
          cnd_true = cond_v2.cond_v2(
              constant_op.constant(True), true_fn, false_fn)
          self.assertEquals(cnd_true.eval(), 2)
          cnd_false = cond_v2.cond_v2(
              constant_op.constant(False), true_fn, false_fn)
          self.assertEquals(cnd_false.eval(), 6)
          v4 = variables.Variable([3])
          q4 = data_flow_ops.FIFOQueue(1, dtypes.float32)
        v5 = variables.Variable([4])
        q5 = data_flow_ops.FIFOQueue(1, dtypes.float32)
        self.assertEqual(compat.as_bytes("l1"), container(v4))
        self.assertEqual(compat.as_bytes("l1"), container(q4.queue_ref))
        self.assertEqual(compat.as_bytes(""), container(v5))
        self.assertEqual(compat.as_bytes(""), container(q5.queue_ref))


class CondV2ColocationGroupAndDeviceTest(test.TestCase):

  def testColocateWithBeforeCond(self):
    with ops.Graph().as_default() as g:
      with self.session(graph=g):
        a = constant_op.constant([2.0], name="a")
        b = constant_op.constant([2.0], name="b")
        def fn():
          c = constant_op.constant(3.0)
          self.assertEqual([b"loc:@a"], c.op.colocation_groups())
          return c
        with ops.colocate_with(a.op):
          self.assertEquals(
              cond_v2.cond_v2(constant_op.constant(True), fn, fn).eval(), 3)
        def fn2():
          c = constant_op.constant(3.0)
          self.assertEqual([b"loc:@a", b"loc:@b"], c.op.colocation_groups())
          return c
        with ops.colocate_with(a.op):
          with ops.colocate_with(b.op):
            self.assertEquals(
                cond_v2.cond_v2(constant_op.constant(True), fn2, fn2).eval(), 3)

  def testColocateWithInAndOutOfCond(self):
    with ops.Graph().as_default() as g:
      with self.session(graph=g):
        a = constant_op.constant([2.0], name="a")
        b = constant_op.constant([2.0], name="b")
        def fn2():
          with ops.colocate_with(b.op):
            c = constant_op.constant(3.0)
            self.assertEqual([b"loc:@a", b"loc:@b"], c.op.colocation_groups())
            return c
        with ops.colocate_with(a.op):
          self.assertEquals(
              cond_v2.cond_v2(constant_op.constant(True), fn2, fn2).eval(), 3)
          d = constant_op.constant([2.0], name="d")
          self.assertEqual([b"loc:@a"], d.op.colocation_groups())

  def testColocateWithInCondGraphPartitioning(self):
    with ops.Graph().as_default() as g:
      with self.session(
          graph=g,
          config=config_pb2.ConfigProto(device_count={"CPU": 2})
      ) as sess:
        with ops.device("/device:CPU:0"):
          a = constant_op.constant([2.0], name="a")
        with ops.device("/device:CPU:1"):
          b = constant_op.constant([2.0], name="b")
        def fn():
          with ops.colocate_with(b.op):
            c = math_ops.add(a, a, name="c")
          return c
        out_cond_2 = cond_v2.cond_v2(constant_op.constant(True), fn, fn)
        run_options = config_pb2.RunOptions(output_partition_graphs=True)
        run_metadata = config_pb2.RunMetadata()
        sess.run(out_cond_2, options=run_options, run_metadata=run_metadata)
        # We expect there to be two partitions because of the
        # colocate_with. We are only running the cond, which has a data
        # dependency on `a` but not on `b`. So, without the colocate_with
        # we would expect execution on just one device.
        self.assertTrue(len(run_metadata.partition_graphs) >= 2)

  def testDeviceBeforeCond(self):
    with ops.Graph().as_default() as g:
      with self.session(graph=g):
        def fn():
          self.assertEqual("", constant_op.constant(3.0).op.device)
          return test_ops.device_placement_op()
        with ops.device("/device:CPU:0"):
          self.assertIn(
              compat.as_bytes("CPU:0"),
              self.evaluate(cond_v2.cond_v2(constant_op.constant(True),
                                            fn, fn)))
        def fn2():
          self.assertEqual("", constant_op.constant(3.0).op.device)
          return test_ops.device_placement_op()
        if test_util.is_gpu_available():
          with ops.device("/device:GPU:0"):
            self.assertIn(
                compat.as_bytes("GPU:0"),
                self.evaluate(cond_v2.cond_v2(constant_op.constant(True),
                                              fn2, fn2)))
        else:
          self.skipTest("Test requires a GPU to check GPU device placement.")

  def testDeviceInAndOutOfCond(self):
    with ops.Graph().as_default() as g:
      with self.session(
          graph=g, config=config_pb2.ConfigProto(device_count={"CPU": 2})):
        def fn2():
          with ops.device("/device:CPU:1"):
            c = constant_op.constant(3.0)
            self.assertEqual("/device:CPU:1", c.op.device)
            return c
        with ops.device("/device:CPU:0"):
          self.assertEquals(
              cond_v2.cond_v2(constant_op.constant(True), fn2, fn2).eval(), 3)
          d = constant_op.constant(4.0)
          self.assertEqual("/device:CPU:0", d.op.device)

  def testDeviceInCondGraphPartitioning(self):
    with ops.Graph().as_default() as g:
      with self.session(
          graph=g,
          config=config_pb2.ConfigProto(device_count={"CPU": 2})
      ) as sess:
        def fn():
          with ops.device("/device:CPU:1"):
            c = math_ops.add(a, a, name="c")
          return c
        with ops.device("/device:CPU:0"):
          a = constant_op.constant([2.0], name="a")
        out_cond_2 = cond_v2.cond_v2(constant_op.constant(True), fn, fn)
        run_options = config_pb2.RunOptions(output_partition_graphs=True)
        run_metadata = config_pb2.RunMetadata()
        sess.run(out_cond_2, options=run_options, run_metadata=run_metadata)
        self.assertTrue(len(run_metadata.partition_graphs) >= 2)


def _cond(pred, true_fn, false_fn, name):
  if _is_old_cond():
    return control_flow_ops.cond(pred, true_fn, false_fn, name=name)
  else:
    return cond_v2.cond_v2(pred, true_fn, false_fn, name=name)


def _is_old_cond():
  return isinstance(ops.get_default_graph()._get_control_flow_context(),
                    control_flow_ops.CondContext)


if __name__ == "__main__":
  ...  # remainder of the file is elided in this excerpt
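The tests above exercise TensorFlow's cond_v2/If op through internal modules (cond_v2, control_flow_ops, gradients_impl). In application code the same conditional is normally reached through the public tf.cond API; a minimal sketch using TF 2.x (the function name branchy and the constant values are illustrative, not taken from the test file):

import tensorflow as tf

x = tf.constant(1.0)
y = tf.constant(2.0)

@tf.function
def branchy(pred):
    # Both branches are traced into the graph; only the selected one runs.
    return tf.cond(pred, true_fn=lambda: x * 2.0, false_fn=lambda: y * 3.0)

print(branchy(tf.constant(True)).numpy())   # 2.0
print(branchy(tf.constant(False)).numpy())  # 6.0

Note that tf.cond requires both branch callables to return values with the same structure and dtypes, which is exactly what testRaisesOutputStructuresMismatch above checks.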


test_locks.py

Source: test_locks.py (GitHub)


# ... (excerpt from the asyncio Condition tests; earlier lines are elided)
        self.assertTrue(RGX_REPR.match(repr(cond)))

    def test_context_manager(self):
        cond = asyncio.Condition(loop=self.loop)

        @asyncio.coroutine
        def acquire_cond():
            return (yield from cond)

        with self.loop.run_until_complete(acquire_cond()):
            self.assertTrue(cond.locked())
        self.assertFalse(cond.locked())

    def test_context_manager_no_yield(self):
        cond = asyncio.Condition(loop=self.loop)
        try:
            with cond:
                self.fail('RuntimeError is not raised in with expression')
        except RuntimeError as err:
            self.assertEqual(
                str(err),
                '"yield from" should be used as context manager expression')
        self.assertFalse(cond.locked())

    def test_explicit_lock(self):
        lock = asyncio.Lock(loop=self.loop)
        # ...
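This excerpt targets the legacy "yield from" acquisition style and the loop= argument, both of which have since been removed from asyncio. On current Python, asyncio.Condition is acquired with async with and coordinated with wait_for()/notify_all(); a minimal sketch (the ready flag and coroutine names are illustrative):

import asyncio

async def main():
    cond = asyncio.Condition()
    ready = False

    async def waiter():
        async with cond:
            # Releases the lock while waiting; reacquires it once notified.
            await cond.wait_for(lambda: ready)
            return "done"

    async def notifier():
        nonlocal ready
        await asyncio.sleep(0.01)
        async with cond:
            ready = True
            cond.notify_all()

    result, _ = await asyncio.gather(waiter(), notifier())
    print(result)  # done

asyncio.run(main())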


mk_star_flag.py

Source: mk_star_flag.py (GitHub)


import os

import numpy as np
import astropy.io.fits as fitsio
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from useful import calc_sn


def get_mags(catalog):
    flx_b, err_b, cov_b = catalog["FLUX_APER_supcam_b"][:, 2], catalog["FLUXERR_APER_supcam_b"][:, 2], catalog["COVERAGE_FLAG_supcam_b"]
    flx_z, err_z, cov_z = catalog["FLUX_APER_supcam_z"][:, 2], catalog["FLUXERR_APER_supcam_z"][:, 2], catalog["COVERAGE_FLAG_supcam_z"]
    flx_k, err_k, cov_k = catalog["FLUX_APER_video_ks"][:, 2], catalog["FLUXERR_APER_video_ks"][:, 2], catalog["COVERAGE_FLAG_video_ks"]
    sn_b = flx_b / err_b
    sn_z = flx_z / err_z
    sn_k = flx_k / err_k
    cond_cov = (cov_b == 1) & (cov_z == 1) & (cov_k == 1)
    cond_det = (sn_b > 1) & (sn_z > 3) & (sn_k > 3) & cond_cov
    cond_non = (sn_b < 1) & (sn_z > 3) & (sn_k > 3) & cond_cov
    mag_b, mag_z, mag_k = np.zeros((3, len(catalog))) - 99.
    mag_b[cond_det] = -2.5 * np.log10(flx_b[cond_det]) + 23.93
    mag_z[cond_det] = -2.5 * np.log10(flx_z[cond_det]) + 23.93
    mag_k[cond_det] = -2.5 * np.log10(flx_k[cond_det]) + 23.93
    mag_b[cond_non] = -2.5 * np.log10(err_b[cond_non]) + 23.93
    mag_z[cond_non] = -2.5 * np.log10(flx_z[cond_non]) + 23.93
    mag_k[cond_non] = -2.5 * np.log10(flx_k[cond_non]) + 23.93
    return mag_b, mag_z, mag_k, catalog["LPH_Z_BEST"], cond_det, cond_non, cond_cov


def get_star_flag(catalog):
    mag_b, mag_z, mag_k, z, cond_det, cond_non, cond_cov = get_mags(catalog)
    cond_chi = (catalog["LPH_CHI_BEST"] - catalog["LPH_CHI_STAR"]) > 0
    cond_bzk = (mag_z - mag_k) < (mag_b - mag_z) * 0.3 - 0.25
    cond_star = cond_cov & (cond_chi & cond_bzk)
    cond_gal = cond_cov & ~(cond_chi & cond_bzk)
    print("Star/Galaxy classification: %i stars and %i gals (no classification for %i)" % (np.sum(cond_star), np.sum(cond_gal), np.sum(~cond_cov)))
    return cond_star, cond_gal, cond_cov


def add_sg_classification(catalog):
    cond_star, cond_gal, cond_cov = get_star_flag(catalog)
    catalog["STAR_FLAG"][cond_star] = 1
    catalog["STAR_FLAG"][cond_gal] = 0
    catalog["STAR_FLAG"][~cond_cov] = -99
    catalog["ZPHOT"][cond_star] = 0.00
    return catalog


def mk_plot():
    catalog = fitsio.getdata("final_cats/final_catalog_errfix_zphot.fits")
    mag_b, mag_z, mag_k, z, cond_det, cond_non, cond_cov = get_mags(catalog)
    cond_star, cond_gal, cond_cov = get_star_flag(catalog)
    fig = plt.figure(figsize=(10, 9), dpi=75)
    fig.subplots_adjust(left=0.09, right=0.91, top=0.98, bottom=0.08, wspace=0.02, hspace=0.14)
    ogs = gridspec.GridSpec(1, 2, width_ratios=[25, 1])
    ax = fig.add_subplot(ogs[0, 0])
    cax = fig.add_subplot(ogs[0, 1])
    # ax2 = fig.add_subplot(ogs[1,:])
    vmax = 4.5
    cmap = matplotlib.cm.get_cmap('RdYlBu_r')
    norm = matplotlib.colors.Normalize(vmin=0, vmax=vmax)
    zz = np.arange(-1, 7, 0.1)
    ss = np.linspace(0.5, 1.5, len(zz) - 1)**2
    for z0, z1, s in zip(zz[:-1], zz[1:], ss):
        cond_z = (z0 <= z) & (z < z1)
        im = ax.scatter(-99, -99, s=0.5, c=0, cmap=plt.cm.RdYlBu_r, vmin=0, vmax=vmax)
        ax.scatter((mag_b - mag_z)[cond_det & cond_z], (mag_z - mag_k)[cond_det & cond_z], c=cmap(norm(z[cond_det & cond_z])), s=s, alpha=0.4)
        ax.errorbar((mag_b - mag_z)[cond_non & cond_z], (mag_z - mag_k)[cond_non & cond_z], xerr=0.15, xlolims=True, ecolor=cmap(norm(0.5 * (z0 + z1))), linestyle='', marker='', alpha=0.4)
    ax.scatter((mag_b - mag_z)[cond_det & cond_star], (mag_z - mag_k)[cond_det & cond_star], s=2, c='k', alpha=0.4)
    ax.errorbar((mag_b - mag_z)[cond_non & cond_star], (mag_z - mag_k)[cond_non & cond_star], xerr=0.15, xlolims=True, ecolor='k', linestyle='', marker='', alpha=0.4)
    plt.colorbar(im, cax=cax)
    # ax2.hist(catalog["LPH_Z_BEST"],bins=np.arange(-1,7,0.05),color='k',alpha=0.4)
    # ax2.hist(catalog["LPH_Z_BEST"][cond_star],bins=np.arange(-1,7,0.05),color='k',alpha=0.4)
    ax.set_xlabel("$B-z$", fontsize=24)
    ax.set_ylabel("$z-K$", fontsize=24)
    ax.set_xlim(-0.8, 7.8)
    ax.set_ylim(-1.5, 4.0)
    cax.set_ylabel("photo-z", fontsize=24)
    # ax2.set_xlabel('z',fontsize=18)
    # ax2.set_xlim(0-0.1,6.1)
    _ = [label.set_fontsize(16) for label in ax.get_yticklabels() + ax.get_xticklabels() + cax.get_yticklabels()]
    # _ = [label.set_visible(False) for label in ax2.get_yticklabels()]
    fig.savefig("final_cats/plots/star_flag_bzk.png")


if __name__ == '__main__':
    mk_plot()
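The cond_* variables in this script are plain NumPy boolean masks built with elementwise comparisons and combined with &, | and ~, then used for fancy indexing. A minimal sketch of the same pattern with made-up numbers:

import numpy as np

flux = np.array([5.0, -1.0, 12.0, 0.5])
err = np.array([1.0, 1.0, 2.0, 1.0])
sn = flux / err

cond_det = sn > 3          # detections
cond_non = ~cond_det       # everything else

mag = np.full(flux.shape, -99.0)           # sentinel for unclassified entries
mag[cond_det] = -2.5 * np.log10(flux[cond_det]) + 23.93
print(mag)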


Automation Testing Tutorials

Learn to execute automation testing from scratch with the LambdaTest Learning Hub, from setting up the prerequisites and running your first automation test to following best practices and diving into advanced test scenarios. The LambdaTest Learning Hubs compile step-by-step guides to help you become proficient with different test automation frameworks such as Selenium, Cypress, and TestNG.

LambdaTest Learning Hubs:

YouTube

You can also refer to the video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.

Run Hypothesis automation tests on the LambdaTest cloud grid

Perform automation testing on 3000+ real desktop and mobile devices online.

