How to use the exhausted method in hypothesis

Best Python code snippets using hypothesis. The examples below are drawn from open-source test suites on GitHub (TensorFlow's tf.data serialization tests and claripy's model cache mixin) that exercise iterator-exhaustion and cache-exhaustion logic.

dataset_serialization_test_base.py

Source: dataset_serialization_test_base.py (GitHub)


# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for testing serializable datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np

from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
from tensorflow.python.data.experimental.ops.optimization_options import OptimizationOptions
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.util import nest


def remove_variants(get_next_op):
  # TODO(b/72408568): Remove this once session.run can get
  # variant tensors.
  """Remove variants from a nest structure, so sess.run will execute."""

  def _remove_variant(x):
    if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant:
      return ()
    else:
      return x

  return nest.map_structure(_remove_variant, get_next_op)


class DatasetSerializationTestBase(test.TestCase):
  """Base class for testing serializable datasets."""

  def tearDown(self):
    self._delete_ckpt()

  # TODO(b/72657739): Remove sparse_tensor argument, which is to test the
  # (deprecated) saveable `SparseTensorSliceDataset`, once the API
  # `from_sparse_tensor_slices()` and related tests are deleted.
  def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False):
    """Runs the core tests.

    Args:
      ds_fn1: 0-argument function that returns a Dataset.
      ds_fn2: 0-argument function that returns a Dataset different from
        ds_fn1. If None, verify_restore_in_modified_graph test is not run.
      num_outputs: Total number of outputs expected from this Dataset.
      sparse_tensors: Whether dataset is built from SparseTensor(s).

    Raises:
      AssertionError if any test fails.
    """
    # NOTE: We disable all default optimizations in serialization tests in
    # order to test the actual dataset in question.
    options = dataset_ops.Options()
    options.experimental_optimization = OptimizationOptions()
    options.experimental_optimization.apply_default_optimizations = False

    def ds_fn1_no_opt():
      return ds_fn1().with_options(options)

    self.verify_unused_iterator(
        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_fully_used_iterator(
        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_exhausted_iterator(
        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_init_before_restore(
        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_multiple_breaks(
        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_reset_restored_iterator(
        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_restore_in_empty_graph(
        ds_fn1_no_opt, num_outputs, sparse_tensors=sparse_tensors)
    if ds_fn2:

      def ds_fn2_no_opt():
        return ds_fn2().with_options(options)

      self.verify_restore_in_modified_graph(
          ds_fn1_no_opt,
          ds_fn2_no_opt,
          num_outputs,
          sparse_tensors=sparse_tensors)

  def verify_unused_iterator(self,
                             ds_fn,
                             num_outputs,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Verifies that saving and restoring an unused iterator works.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    self.verify_run_with_breaks(
        ds_fn, [0],
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

  def verify_fully_used_iterator(self, ds_fn, num_outputs,
                                 sparse_tensors=False):
    """Verifies that saving and restoring a fully used iterator works.

    Note that this only checks saving and restoring an iterator from which
    `num_outputs` items have been produced but does not check for an
    exhausted iterator, i.e., one from which an OutOfRange error has been
    returned.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.

    Raises:
      AssertionError if test fails.
    """
    self.verify_run_with_breaks(
        ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors)

  def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False):
    """Verifies that saving and restoring an exhausted iterator works.

    An exhausted iterator is one which has returned an OutOfRange error.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.

    Raises:
      AssertionError if any test fails.
    """
    self.gen_outputs(
        ds_fn, [],
        num_outputs,
        verify_exhausted=True,
        sparse_tensors=sparse_tensors)
    actual = self.gen_outputs(
        ds_fn, [],
        0,
        ckpt_saved=True,
        verify_exhausted=True,
        sparse_tensors=sparse_tensors)
    self.assertEqual(len(actual), 0)

  def verify_init_before_restore(self,
                                 ds_fn,
                                 num_outputs,
                                 sparse_tensors=False,
                                 verify_exhausted=True):
    """Verifies that restoring into an already initialized iterator works.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    self.verify_run_with_breaks(
        ds_fn,
        self.gen_break_points(num_outputs),
        num_outputs,
        init_before_restore=True,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

  def verify_multiple_breaks(self,
                             ds_fn,
                             num_outputs,
                             num_breaks=10,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Attempts to save/restore at multiple break points.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      num_breaks: The number of break points. These are uniformly spread in
        [0, num_outputs] both inclusive.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    self.verify_run_with_breaks(
        ds_fn,
        self.gen_break_points(num_outputs, num_breaks),
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

  def verify_reset_restored_iterator(self,
                                     ds_fn,
                                     num_outputs,
                                     break_point=None,
                                     sparse_tensors=False,
                                     verify_exhausted=True):
    """Attempts to re-initialize a restored iterator.

    This is useful when restoring a training checkpoint during validation.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      break_point: Break point. Optional. Defaults to num_outputs/2.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    break_point = num_outputs // 2 if not break_point else break_point

    # Collect ground truth containing all outputs.
    expected = self.gen_outputs(
        ds_fn, [],
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    # Skip some items and save checkpoint.
    self.gen_outputs(
        ds_fn, [],
        break_point,
        sparse_tensors=sparse_tensors,
        verify_exhausted=False)

    actual = []
    # Restore from checkpoint and then run init_op.
    with ops.Graph().as_default() as g:
      saver = self._import_meta_graph()
      init_op, get_next_op = self._get_iterator_ops_from_collection(
          ds_fn, sparse_tensors=sparse_tensors)
      get_next_op = remove_variants(get_next_op)
      with self.session(graph=g) as sess:
        self._restore(saver, sess)
        self._initialize(init_op, sess)
        for _ in range(num_outputs):
          actual.append(sess.run(get_next_op))
        if verify_exhausted:
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)
    self.match(expected, actual)

  def verify_restore_in_modified_graph(self,
                                       ds_fn1,
                                       ds_fn2,
                                       num_outputs,
                                       break_point=None,
                                       sparse_tensors=False,
                                       verify_exhausted=True):
    """Attempts to restore an iterator in a modified graph.

    Builds an input pipeline using ds_fn1, runs it for `break_point` steps
    and saves a checkpoint. Then builds a new graph using ds_fn2, restores
    the checkpoint from ds_fn1 and verifies that the restore is successful.

    Args:
      ds_fn1: See `run_core_tests`.
      ds_fn2: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      break_point: Break point. Optional. Defaults to num_outputs/2.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    break_point = num_outputs // 2 if not break_point else break_point

    # Skip `break_point` items and store the remaining produced from ds_fn1
    # in `expected`.
    self.gen_outputs(
        ds_fn1, [],
        break_point,
        sparse_tensors=sparse_tensors,
        verify_exhausted=False)
    expected = self.gen_outputs(
        ds_fn1, [],
        num_outputs - break_point,
        ckpt_saved=True,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    # Generate `break_point` items from ds_fn1 and save checkpoint.
    self.gen_outputs(
        ds_fn1, [],
        break_point,
        sparse_tensors=sparse_tensors,
        verify_exhausted=False)

    actual = []
    # Build graph for ds_fn2 but load checkpoint for ds_fn1.
    with ops.Graph().as_default() as g:
      _, get_next_op, saver = self._build_graph(
          ds_fn2, sparse_tensors=sparse_tensors)
      get_next_op = remove_variants(get_next_op)
      with self.session(graph=g) as sess:
        self._restore(saver, sess)
        for _ in range(num_outputs - break_point):
          actual.append(sess.run(get_next_op))
        if verify_exhausted:
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)

    self.match(expected, actual)

  def verify_restore_in_empty_graph(self,
                                    ds_fn,
                                    num_outputs,
                                    break_point=None,
                                    sparse_tensors=False,
                                    verify_exhausted=True):
    """Attempts to restore an iterator in an empty graph.

    Builds an input pipeline using ds_fn, runs it for `break_point` steps
    and saves a checkpoint. Then builds a new empty graph, restores
    the checkpoint from ds_fn and verifies that the restore is successful.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      break_point: Break point. Optional. Defaults to num_outputs/2.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    break_point = num_outputs // 2 if not break_point else break_point

    # Skip `break_point` items and store the remaining produced from ds_fn
    # in `expected`.
    self.gen_outputs(
        ds_fn, [],
        break_point,
        sparse_tensors=sparse_tensors,
        verify_exhausted=False)
    expected = self.gen_outputs(
        ds_fn, [],
        num_outputs - break_point,
        ckpt_saved=True,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    # Generate `break_point` items from ds_fn and save checkpoint.
    self.gen_outputs(
        ds_fn, [],
        break_point,
        sparse_tensors=sparse_tensors,
        verify_exhausted=False)

    actual = []
    # Build an empty graph but load checkpoint for ds_fn.
    with ops.Graph().as_default() as g:
      get_next_op, saver = self._build_empty_graph(
          ds_fn, sparse_tensors=sparse_tensors)
      get_next_op = remove_variants(get_next_op)
      with self.session(graph=g) as sess:
        self._restore(saver, sess)
        for _ in range(num_outputs - break_point):
          actual.append(sess.run(get_next_op))
        if verify_exhausted:
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)

    self.match(expected, actual)

  def verify_error_on_save(self,
                           ds_fn,
                           num_outputs,
                           error,
                           break_point=None,
                           sparse_tensors=False):
    """Attempts to save a non-saveable iterator.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      error: Declared error when trying to save iterator.
      break_point: Break point. Optional. Defaults to num_outputs/2.
      sparse_tensors: See `run_core_tests`.

    Raises:
      AssertionError if any test fails.
    """
    break_point = num_outputs // 2 if not break_point else break_point
    with ops.Graph().as_default() as g:
      init_op, get_next_op, saver = self._build_graph(
          ds_fn, sparse_tensors=sparse_tensors)
      get_next_op = remove_variants(get_next_op)
      with self.session(graph=g) as sess:
        self._initialize(init_op, sess)
        for _ in range(break_point):
          sess.run(get_next_op)
        with self.assertRaises(error):
          self._save(sess, saver)

  def verify_run_with_breaks(self,
                             ds_fn,
                             break_points,
                             num_outputs,
                             init_before_restore=False,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Verifies that ds_fn() produces the same outputs with and without breaks.

    1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
       *without* stopping at break points.
    2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
       with stopping at break points.

    Deep matches outputs from 1 and 2.

    Args:
      ds_fn: See `gen_outputs`.
      break_points: See `gen_outputs`.
      num_outputs: See `gen_outputs`.
      init_before_restore: See `gen_outputs`.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    expected = self.gen_outputs(
        ds_fn, [],
        num_outputs,
        init_before_restore=init_before_restore,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)
    actual = self.gen_outputs(
        ds_fn,
        break_points,
        num_outputs,
        init_before_restore=init_before_restore,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)
    self.match(expected, actual)

  def gen_outputs(self,
                  ds_fn,
                  break_points,
                  num_outputs,
                  ckpt_saved=False,
                  init_before_restore=False,
                  sparse_tensors=False,
                  verify_exhausted=True,
                  save_checkpoint_at_end=True):
    """Generates elements from input dataset while stopping at break points.

    Produces `num_outputs` outputs and saves the state of the iterator in the
    Saver checkpoint.

    Args:
      ds_fn: 0-argument function that returns the dataset.
      break_points: A list of integers. For each `break_point` in
        `break_points`, we produce outputs till `break_point` number of items
        have been produced and then checkpoint the state. The current graph
        and session are destroyed and a new graph and session are used to
        produce outputs till next checkpoint or till `num_outputs` elements
        have been produced. `break_point` must be <= `num_outputs`.
      num_outputs: The total number of outputs to produce from the iterator.
      ckpt_saved: Whether a checkpoint already exists. If False, we build the
        graph from ds_fn.
      init_before_restore: Whether init should be called before saver.restore.
        This is just so that we can verify that restoring an already
        initialized iterator works.
      sparse_tensors: Whether dataset is built from SparseTensor(s).
      verify_exhausted: Whether to verify that the iterator has been exhausted
        after producing `num_outputs` elements.
      save_checkpoint_at_end: Whether to save a checkpoint after producing all
        outputs. If False, checkpoints are saved each break point but not at
        the end. Note that checkpoints overwrite each other so there is always
        only a single checkpoint available. Defaults to True.

    Returns:
      A list of `num_outputs` items.
    """
    outputs = []

    def get_ops():
      if ckpt_saved:
        saver = self._import_meta_graph()
        init_op, get_next_op = self._get_iterator_ops_from_collection(
            ds_fn, sparse_tensors=sparse_tensors)
      else:
        init_op, get_next_op, saver = self._build_graph(
            ds_fn, sparse_tensors=sparse_tensors)
      return init_op, get_next_op, saver

    for i in range(len(break_points) + 1):
      with ops.Graph().as_default() as g:
        init_op, get_next_op, saver = get_ops()
        get_next_op = remove_variants(get_next_op)
        with self.session(graph=g) as sess:
          if ckpt_saved:
            if init_before_restore:
              self._initialize(init_op, sess)
            self._restore(saver, sess)
          else:
            self._initialize(init_op, sess)
          start = break_points[i - 1] if i > 0 else 0
          end = break_points[i] if i < len(break_points) else num_outputs
          num_iters = end - start
          for _ in range(num_iters):
            outputs.append(sess.run(get_next_op))
          if i == len(break_points) and verify_exhausted:
            with self.assertRaises(errors.OutOfRangeError):
              sess.run(get_next_op)
          if save_checkpoint_at_end or i < len(break_points):
            self._save(sess, saver)
            ckpt_saved = True

    return outputs

  def match(self, expected, actual):
    """Matches nested structures.

    Recursively matches shape and values of `expected` and `actual`.
    Handles scalars, numpy arrays and other python sequence containers
    e.g. list, dict.

    Args:
      expected: Nested structure 1.
      actual: Nested structure 2.

    Raises:
      AssertionError if matching fails.
    """
    if isinstance(expected, np.ndarray):
      expected = expected.tolist()
    if isinstance(actual, np.ndarray):
      actual = actual.tolist()
    self.assertEqual(type(expected), type(actual))

    if nest.is_sequence(expected):
      self.assertEqual(len(expected), len(actual))
      if isinstance(expected, dict):
        for key1, key2 in zip(sorted(expected), sorted(actual)):
          self.assertEqual(key1, key2)
          self.match(expected[key1], actual[key2])
      else:
        for item1, item2 in zip(expected, actual):
          self.match(item1, item2)
    else:
      self.assertEqual(expected, actual)

  def does_not_match(self, expected, actual):
    with self.assertRaises(AssertionError):
      self.match(expected, actual)

  def gen_break_points(self, num_outputs, num_samples=10):
    """Generates `num_samples` break points in [0, num_outputs]."""
    return np.linspace(0, num_outputs, num_samples, dtype=int)

  def _build_graph(self, ds_fn, sparse_tensors=False):
    iterator = dataset_ops.make_initializable_iterator(ds_fn())
    saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
    init_op = iterator.initializer
    if sparse_tensors:
      get_next = sparse_tensor.SparseTensor(*iterator.get_next())
    else:
      get_next = iterator.get_next()
    self._add_iterator_ops_to_collection(init_op, get_next, ds_fn,
                                         sparse_tensors)
    saver = saver_lib.Saver(allow_empty=True)
    return init_op, get_next, saver

  def _build_empty_graph(self, ds_fn, sparse_tensors=False):
    iterator = iterator_ops.Iterator.from_structure(
        self._get_output_types(ds_fn),
        output_shapes=self._get_output_shapes(ds_fn),
        output_classes=self._get_output_classes(ds_fn))
    saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
    if sparse_tensors:
      get_next = sparse_tensor.SparseTensor(*iterator.get_next())
    else:
      get_next = iterator.get_next()
    saver = saver_lib.Saver(allow_empty=True)
    return get_next, saver

  def _add_iterator_ops_to_collection(self,
                                      init_op,
                                      get_next,
                                      ds_fn,
                                      sparse_tensors=False):
    ops.add_to_collection("iterator_ops", init_op)
    # `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections
    # do not support tuples we flatten the tensors and restore the shape in
    # `_get_iterator_ops_from_collection`.
    if sparse_tensors:  # specific for deprecated `from_sparse_tensor_slices`.
      ops.add_to_collection("iterator_ops", get_next.indices)
      ops.add_to_collection("iterator_ops", get_next.values)
      ops.add_to_collection("iterator_ops", get_next.dense_shape)
      return

    get_next_list = nest.flatten(get_next)
    for i, output_class in enumerate(
        nest.flatten(self._get_output_classes(ds_fn))):
      if output_class is sparse_tensor.SparseTensor:
        ops.add_to_collection("iterator_ops", get_next_list[i].indices)
        ops.add_to_collection("iterator_ops", get_next_list[i].values)
        ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape)
      else:
        ops.add_to_collection("iterator_ops", get_next_list[i])

  def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
    all_ops = ops.get_collection("iterator_ops")
    if sparse_tensors:  # specific for deprecated `from_sparse_tensor_slices`.
      init_op, indices, values, dense_shape = all_ops
      return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
    get_next_list = []
    i = 1
    for output_class in nest.flatten(self._get_output_classes(ds_fn)):
      if output_class is sparse_tensor.SparseTensor:
        indices, values, dense_shape = all_ops[i:i + 3]
        i += 3
        get_next_list.append(
            sparse_tensor.SparseTensor(indices, values, dense_shape))
      else:
        get_next_list.append(all_ops[i])
        i += 1
    return all_ops[0], nest.pack_sequence_as(
        self._get_output_types(ds_fn), get_next_list)

  def _get_output_types(self, ds_fn):
    with ops.Graph().as_default():
      return ds_fn().output_types

  def _get_output_shapes(self, ds_fn):
    with ops.Graph().as_default():
      return ds_fn().output_shapes

  def _get_output_classes(self, ds_fn):
    with ops.Graph().as_default():
      return ds_fn().output_classes

  def _ckpt_path(self):
    return os.path.join(self.get_temp_dir(), "iterator")

  def _latest_ckpt(self):
    return checkpoint_management.latest_checkpoint(self.get_temp_dir())

  def _save(self, sess, saver):
    saver.save(sess, self._ckpt_path())

  def _restore(self, saver, sess):
    sess.run(lookup_ops.tables_initializer())
    saver.restore(sess, self._latest_ckpt())

  def _initialize(self, init_op, sess):
    sess.run(variables.global_variables_initializer())
    sess.run(lookup_ops.tables_initializer())
    sess.run(init_op)

  def _import_meta_graph(self):
    meta_file_path = self._ckpt_path() + ".meta"
    return saver_lib.import_meta_graph(meta_file_path)

  def _delete_ckpt(self):
    # Remove all checkpoint files.
    prefix = self._ckpt_path()
    pattern = prefix + "*"
    files = gfile.Glob(pattern)
    ...


model_cache_mixin.py

Source: model_cache_mixin.py (GitHub)


import weakref
import itertools

from claripy import errors


class ModelCache(object):
    _defaults = { 0, 0.0, True }

    def __init__(self, model):
        self.model = model
        self.replacements = weakref.WeakKeyDictionary()

    def __hash__(self):
        if not hasattr(self, '_hash'):
            self._hash = hash(frozenset(self.model.items())) #pylint:disable=attribute-defined-outside-init
        return self._hash

    def __eq__(self, other):
        return self.model == other.model

    def __getstate__(self):
        return (self.model,)

    def __setstate__(self, s):
        self.model = s[0]
        self.replacements = weakref.WeakKeyDictionary()

    #
    # Splitting support
    #

    def filter(self, variables):
        return ModelCache({ k:self.model[k] for k in self.model if k in variables })

    @staticmethod
    def combine(*models):
        return ModelCache(dict(itertools.chain.from_iterable(m.model.items() for m in models)))

    #
    # Model-driven evaluation
    #

    def _leaf_op(self, a):
        return (
            all_operations.BVV(self.model.get(a.args[0], 0), a.length) if a.op == 'BVS' else
            all_operations.BoolV(self.model.get(a.args[0], True)) if a.op == 'BoolS' else
            all_operations.FPV(self.model.get(a.args[0], 0.0), a.args[1]) if a.op == 'FPS' else
            a
        )

    def eval_ast(self, ast):
        """Eval the ast, replacing symbols by their last value in the model.
        """
        # If there was no last value, it was not constrained, so we can use
        # anything.
        new_ast = ast._replace(self.replacements, leaf_operation=self._leaf_op)
        return backends.concrete.eval(new_ast, 1)[0]

    def eval_constraints(self, constraints):
        """Returns whether the constraints are satisfied trivially by using the
        last model."""
        # eval_ast is concretizing symbols and evaluating them, this can raise
        # exceptions.
        try:
            return all(self.eval_ast(c) for c in constraints)
        except errors.ClaripyZeroDivisionError:
            return False

    def eval_list(self, asts):
        return tuple(self.eval_ast(c) for c in asts)


class ModelCacheMixin(object):
    def __init__(self, *args, **kwargs):
        super(ModelCacheMixin, self).__init__(*args, **kwargs)
        self._models = set()
        self._exhausted = False
        self._eval_exhausted = weakref.WeakSet()
        self._max_exhausted = weakref.WeakSet()
        self._min_exhausted = weakref.WeakSet()

    def _blank_copy(self, c):
        super(ModelCacheMixin, self)._blank_copy(c)
        c._models = set()
        c._exhausted = False
        c._eval_exhausted = weakref.WeakSet()
        c._max_exhausted = weakref.WeakSet()
        c._min_exhausted = weakref.WeakSet()

    def _copy(self, c):
        super(ModelCacheMixin, self)._copy(c)
        c._models = set(self._models)
        c._exhausted = self._exhausted
        c._eval_exhausted = weakref.WeakSet(self._eval_exhausted)
        c._max_exhausted = weakref.WeakSet(self._max_exhausted)
        c._min_exhausted = weakref.WeakSet(self._min_exhausted)

    def _ana_getstate(self):
        return (
            self._models,
            self._exhausted,
            tuple(self._eval_exhausted),
            tuple(self._max_exhausted),
            tuple(self._min_exhausted),
            super(ModelCacheMixin, self)._ana_getstate()
        )

    def _ana_setstate(self, s):
        (
            self._models,
            self._exhausted,
            _eval_exhausted,
            _max_exhausted,
            _min_exhausted,
            base_state
        ) = s
        super(ModelCacheMixin, self)._ana_setstate(base_state)
        self._eval_exhausted = weakref.WeakSet(_eval_exhausted)
        self._max_exhausted = weakref.WeakSet(_max_exhausted)
        self._min_exhausted = weakref.WeakSet(_min_exhausted)

    #
    # Model cleaning
    #

    def simplify(self, *args, **kwargs):
        results = super(ModelCacheMixin, self).simplify(*args, **kwargs)
        if len(results) > 0 and any(c is false for c in results):
            self._models.clear()
        return results

    def add(self, constraints, invalidate_cache=True, **kwargs):
        if len(constraints) == 0:
            return constraints

        old_vars = frozenset(self.variables)
        added = super(ModelCacheMixin, self).add(constraints, **kwargs)
        if len(added) == 0:
            return added

        new_vars = any(a.variables - old_vars for a in added)
        if new_vars or invalidate_cache:
            # shortcut for unsat
            if any(c is false for c in constraints):
                self._models.clear()

            still_valid = set(self._get_models(extra_constraints=added))
            if len(still_valid) != len(self._models):
                self._exhausted = False
                self._eval_exhausted.clear()
                self._max_exhausted.clear()
                self._min_exhausted.clear()
                self._models = still_valid

        return added

    def split(self):
        results = super(ModelCacheMixin, self).split()
        for r in results:
            r._models = { m.filter(r.variables) for m in self._models }
        return results

    def combine(self, others):
        combined = super(ModelCacheMixin, self).combine(others)

        if any(len(o._models) == 0 for o in others) or len(self._models) == 0:
            # this would need a solve anyways, so screw it
            return combined

        vars_count = len(self.variables) + sum(len(s.variables) for s in others)
        all_vars = self.variables.union(*[s.variables for s in others])
        if vars_count != len(all_vars):
            # this is the case where there are variables missing from the models.
            # We'll need more intelligence here to handle it
            return combined

        model_lists = [ self._models ]
        model_lists.extend(o._models for o in others)
        combined._models.update(
            ModelCache.combine(*product) for product in
            itertools.islice(itertools.product(*model_lists), len(self._models))
        )
        return combined

    def update(self, other):
        """
        Updates this cache mixin with results discovered by the other, split-off one.
        """
        acceptable_models = [ m for m in other._models if set(m.model.keys()) == self.variables ]
        self._models.update(acceptable_models)
        self._eval_exhausted.update(other._eval_exhausted)
        self._max_exhausted.update(other._max_exhausted)
        self._min_exhausted.update(other._min_exhausted)

    #
    # Cache retrieval
    #

    def _model_hook(self, m):
        self._models.add(ModelCache(m))

    def _get_models(self, extra_constraints=()):
        for m in self._models:
            if m.eval_constraints(extra_constraints):
                yield m

    def _get_batch_solutions(self, asts, n=None, extra_constraints=()):
        results = set()
        for m in self._get_models(extra_constraints):
            try:
                results.add(m.eval_list(asts))
            except ZeroDivisionError:
                continue
            if len(results) == n:
                break
        return results

    def _get_solutions(self, e, n=None, extra_constraints=()):
        return tuple(v[0] for v in self._get_batch_solutions(
            [e], n=n, extra_constraints=extra_constraints
        ))

    #
    # Cached functions
    #

    def satisfiable(self, extra_constraints=(), **kwargs):
        for _ in self._get_models(extra_constraints=extra_constraints):
            return True
        return super(ModelCacheMixin, self).satisfiable(extra_constraints=extra_constraints, **kwargs)

    def batch_eval(self, asts, n, extra_constraints=(), **kwargs):
        results = self._get_batch_solutions(asts, n=n, extra_constraints=extra_constraints)

        if len(results) == n or (len(asts) == 1 and asts[0].cache_key in self._eval_exhausted):
            return results

        remaining = n - len(results)

        # TODO: faster to concat?
        if len(results) != 0:
            constraints = (all_operations.And(*[
                all_operations.Or(*[a!=v for a,v in zip(asts, r)]) for r in results
            ]),) + extra_constraints
        else:
            constraints = extra_constraints

        try:
            results.update(super(ModelCacheMixin, self).batch_eval(
                asts, remaining, extra_constraints=constraints, **kwargs
            ))
        except UnsatError:
            if len(results) == 0:
                raise

        if len(extra_constraints) == 0 and len(results) < n:
            self._eval_exhausted.update(e.cache_key for e in asts)

        return results

    def eval(self, e, n, **kwargs):
        return tuple( r[0] for r in ModelCacheMixin.batch_eval(self, [e], n=n, **kwargs) )

    def min(self, e, extra_constraints=(), **kwargs):
        cached = [ ]
        if e.cache_key in self._eval_exhausted or e.cache_key in self._min_exhausted:
            cached = self._get_solutions(e, extra_constraints=extra_constraints)

        if len(cached) > 0:
            return min(cached)
        else:
            m = super(ModelCacheMixin, self).min(e, extra_constraints=extra_constraints, **kwargs)
            self._min_exhausted.add(e.cache_key)
            return m

    def max(self, e, extra_constraints=(), **kwargs):
        cached = [ ]
        if e.cache_key in self._eval_exhausted or e.cache_key in self._max_exhausted:
            cached = self._get_solutions(e, extra_constraints=extra_constraints)

        if len(cached) > 0:
            return max(cached)
        else:
            m = super(ModelCacheMixin, self).max(e, extra_constraints=extra_constraints, **kwargs)
            self._max_exhausted.add(e.cache_key)
            return m

    def solution(self, e, v, extra_constraints=(), **kwargs):
        if isinstance(v, Base):
            cached = self._get_batch_solutions([e,v], extra_constraints=extra_constraints)
            if any(ec == vc for ec,vc in cached):
                return True
        else:
            cached = self._get_solutions(e, extra_constraints=extra_constraints)
            if v in cached:
                return True

        return super(ModelCacheMixin, self).solution(e, v, extra_constraints=extra_constraints, **kwargs)

from .. import backends, false
from ..errors import UnsatError
...
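Here "exhausted" is a caching notion rather than an iterator one: `_eval_exhausted`, `_min_exhausted`, and `_max_exhausted` record expressions for which the solver has already enumerated every solution (note how `batch_eval` marks an AST exhausted when the solver returns fewer than `n` values), so later queries can be answered from cached models without invoking the solver. A simplified, self-contained sketch of that exhaustion-flag pattern (a hypothetical `CachedEnumerator`, not claripy's API):

class CachedEnumerator:
    def __init__(self, universe):
        self._universe = universe  # stands in for the solver's search space
        self._cache = {}           # query name -> solutions found so far
        self._exhausted = set()    # queries known to be fully enumerated

    def eval(self, name, predicate, n):
        cached = self._cache.setdefault(name, set())
        if name in self._exhausted or len(cached) >= n:
            # Fast path: answer entirely from the cache.
            return sorted(cached)[:n]
        # Slow path: "solve" by scanning the universe for more solutions.
        for x in self._universe:
            if predicate(x):
                cached.add(x)
                if len(cached) == n:
                    return sorted(cached)
        # Fewer than n solutions exist, so the query is exhausted.
        self._exhausted.add(name)
        return sorted(cached)

enum = CachedEnumerator(range(100))
print(enum.eval('small_evens', lambda x: x < 8 and x % 2 == 0, 10))  # [0, 2, 4, 6]
print('small_evens' in enum._exhausted)  # True: later calls skip the slow path

The pay-off mirrors `ModelCacheMixin.eval`: once a query is known to be exhausted, repeat calls never touch the expensive path again.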


cache_dataset_serialization_test.py

Source: cache_dataset_serialization_test.py (GitHub)


# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the CacheDataset serialization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl.testing import parameterized

from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.platform import test


class CacheDatasetSerializationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase,
    parameterized.TestCase):

  def setUp(self):
    self.range_size = 10
    self.num_repeats = 3
    self.num_outputs = self.range_size * self.num_repeats
    self.cache_file_prefix = 'test'

  def make_dataset_fn(self, is_memory):
    if is_memory:
      filename = ''
    else:
      filename = os.path.join(self.get_temp_dir(), self.cache_file_prefix)

    def ds_fn():
      return dataset_ops.Dataset.range(self.range_size).cache(filename).repeat(
          self.num_repeats)

    return ds_fn

  def expected_outputs(self):
    return list(range(self.range_size)) * self.num_repeats

  @parameterized.named_parameters(
      ('Memory', True),
      ('File', False),
  )
  def testCheckpointBeforeOneEpoch(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Generate 5 entries from iterator and save checkpoint.
    outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
    self.assertSequenceEqual(outputs, range(5))

    # Restore from checkpoint and produce the rest of the elements from the
    # iterator.
    outputs.extend(
        self.gen_outputs(
            ds_fn, [],
            self.num_outputs - 5,
            ckpt_saved=True,
            verify_exhausted=False))
    self.assertSequenceEqual(outputs, self.expected_outputs())

  @parameterized.named_parameters(
      ('Memory', True),
      ('File', False),
  )
  def testCheckpointBeforeOneEpochThenRunFewSteps(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Generate 8 entries from iterator but save checkpoint after producing 5.
    outputs = self.gen_outputs(
        ds_fn, [5], 8, verify_exhausted=False, save_checkpoint_at_end=False)
    self.assertSequenceEqual(outputs, range(8))

    if is_memory:
      outputs = outputs[:5]
      outputs.extend(
          self.gen_outputs(
              ds_fn, [],
              self.num_outputs - 5,
              ckpt_saved=True,
              verify_exhausted=False))
      self.assertSequenceEqual(outputs, self.expected_outputs())
    else:
      # Restoring from checkpoint and running GetNext should return
      # `AlreadyExistsError` now because the lockfile already exists.
      with self.assertRaises(errors.AlreadyExistsError):
        self.gen_outputs(
            ds_fn, [],
            self.num_outputs - 5,
            ckpt_saved=True,
            verify_exhausted=False)

  @parameterized.named_parameters(
      ('Memory', True),
      ('File', False),
  )
  def testCheckpointAfterOneEpoch(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Generate 15 entries from iterator and save checkpoint.
    outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False)
    self.assertSequenceEqual(outputs, list(range(10)) + list(range(5)))

    # Restore from checkpoint and produce the rest of the elements from the
    # iterator.
    outputs.extend(
        self.gen_outputs(
            ds_fn, [],
            self.num_outputs - 15,
            ckpt_saved=True,
            verify_exhausted=False))
    self.assertSequenceEqual(outputs, self.expected_outputs())

  @parameterized.named_parameters(
      ('Memory', True),
      ('File', False),
  )
  def testCheckpointAfterOneEpochThenRunFewSteps(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Generate 18 entries from iterator but save checkpoint after producing 15.
    outputs = self.gen_outputs(
        ds_fn, [15], 18, verify_exhausted=False, save_checkpoint_at_end=False)
    self.assertSequenceEqual(outputs, list(range(10)) + list(range(8)))

    outputs = list(range(10)) + list(range(5)) + self.gen_outputs(
        ds_fn, [],
        self.num_outputs - 15,
        ckpt_saved=True,
        verify_exhausted=False)
    self.assertSequenceEqual(outputs, list(range(10)) * 3)

  @parameterized.named_parameters(
      ('Memory', True),
      ('File', False),
  )
  def testCheckpointBeforeOneEpochButRunCompleteEpoch(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Generate 13 entries from iterator but save checkpoint after producing 5.
    outputs = self.gen_outputs(
        ds_fn, [5], 13, verify_exhausted=False, save_checkpoint_at_end=False)
    self.assertSequenceEqual(outputs, list(range(10)) + list(range(3)))

    # Since we ran for more than one epoch, the cache was completely written.
    # The ckpt was saved when the iterator was in cache-write mode. Test that
    # the iterator falls back to read mode after restoring if the cache has
    # been completely written.
    outputs = list(range(5)) + self.gen_outputs(
        ds_fn, [],
        self.num_outputs - 5,
        ckpt_saved=True,
        verify_exhausted=False)
    self.assertSequenceEqual(outputs, list(range(10)) * 3)

  @parameterized.named_parameters(
      ('Memory', True),
      ('File', False),
  )
  def testCheckpointUnusedWriterIterator(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Checkpoint before get_next is called even once.
    outputs = self.gen_outputs(ds_fn, [], 0, verify_exhausted=False)
    self.assertSequenceEqual(outputs, [])

    outputs = self.gen_outputs(
        ds_fn, [], self.num_outputs, ckpt_saved=True, verify_exhausted=False)
    self.assertSequenceEqual(outputs, list(range(10)) * 3)

  @parameterized.named_parameters(
      ('Memory', True),
      ('File', False),
  )
  def testCheckpointUnusedMidwayWriterIterator(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Produce 5 elements and checkpoint.
    outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
    self.assertSequenceEqual(outputs, range(5))

    # Restore from checkpoint, then produce no elements and checkpoint.
    outputs.extend(
        self.gen_outputs(ds_fn, [], 0, ckpt_saved=True, verify_exhausted=False))
    self.assertSequenceEqual(outputs, range(5))

    # Restore from checkpoint and produce rest of the elements.
    outputs.extend(
        self.gen_outputs(
            ds_fn, [],
            self.num_outputs - 5,
            ckpt_saved=True,
            verify_exhausted=False))
    self.assertSequenceEqual(outputs, list(range(10)) * 3)

  @parameterized.named_parameters(
      ('Memory', True),
      ('File', False),
  )
  def testUnusedCheckpointError(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Produce 5 elements and save ckpt.
    outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
    self.assertSequenceEqual(outputs, range(5))

    if is_memory:
      outputs = self.gen_outputs(
          ds_fn, [], self.num_outputs, verify_exhausted=False)
      self.assertSequenceEqual(outputs, self.expected_outputs())
    else:
      # Since the complete cache has not been written, a new iterator which
      # does not restore the checkpoint will throw an error since there is a
      # partial cache shard.
      with self.assertRaises(errors.AlreadyExistsError):
        outputs = self.gen_outputs(
            ds_fn, [], self.num_outputs, verify_exhausted=False)

  @parameterized.named_parameters(
      ('Memory', True),
      ('File', False),
  )
  def testIgnoreCheckpointIfCacheWritten(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Produce 15 elements and save ckpt. This will write the complete cache.
    outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False)
    self.assertSequenceEqual(outputs, list(range(10)) + list(range(5)))

    # Build the iterator again but do not restore from ckpt. Since the cache
    # has already been written we should be able to use it.
    outputs = self.gen_outputs(
        ds_fn, [], self.num_outputs, verify_exhausted=False)
    self.assertSequenceEqual(outputs, list(range(10)) * 3)


if __name__ == '__main__':
...
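As a usage sketch, a concrete serialization test only needs to supply a dataset factory and call the base-class helpers. The subclass below is illustrative (a hypothetical range-dataset test, not a file from TensorFlow):

from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test


class RangeDatasetSerializationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase):

  def _build_range_dataset(self):
    return dataset_ops.Dataset.range(20)

  def testCore(self):
    # run_core_tests includes verify_exhausted_iterator, which checkpoints
    # *after* OutOfRangeError and asserts the restored iterator yields
    # nothing. Passing None as ds_fn2 skips the modified-graph variant.
    self.run_core_tests(self._build_range_dataset, None, num_outputs=20)


if __name__ == '__main__':
  test.main()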

