Best Python code snippet using avocado_python
saver_test.py
Source: saver_test.py
...90              "v0": v0,91              "v1": v1,92              "v2": v2.saveable93          }, restore_sequentially=True)94      val = save.save(sess, save_path)95      self.assertTrue(isinstance(val, six.string_types))96      self.assertEqual(save_path, val)97    # Start a second session.  In that session the parameter nodes98    # have not been initialized either.99    with self.test_session(graph=ops_lib.Graph()) as sess:100      v0 = variable_op(-1.0, name="v0")101      v1 = variable_op(-1.0, name="v1")102      v2 = saver_test_utils.CheckpointedOp(name="v2")103      # Assert that the variables are not initialized.104      if context.in_graph_mode():105        self.assertEqual(106            len(variables.report_uninitialized_variables().eval()), 2)107        self.assertEqual(0, len(v2.keys().eval()))108        self.assertEqual(0, len(v2.values().eval()))109      # Restore the saved values in the parameter nodes.110      save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable})111      save.restore(sess, save_path)112      # Check that the parameter nodes have been restored.113      self.assertEqual(10.0, self.evaluate(v0))114      self.assertEqual(20.0, self.evaluate(v1))115      self.assertEqual(b"k1", self.evaluate(v2.keys()))116      self.assertEqual(30.0, self.evaluate(v2.values()))117    # Build another graph with 2 nodes, initialized118    # differently, and a Restore node for them.119    with self.test_session(graph=ops_lib.Graph()) as sess:120      v0_2 = variable_op(1000.0, name="v0")121      v1_2 = variable_op(2000.0, name="v1")122      v2_2 = saver_test_utils.CheckpointedOp(name="v2")123      v2_init = v2_2.insert("k1000", 3000.0)124      # Check that the parameter nodes have been initialized.125      if context.in_graph_mode():126        init_all_op = [variables.global_variables_initializer(), v2_init]127        self.evaluate(init_all_op)128        # TODO(xpan): Why _mutable_hash_table_v2 doesn't create empty129        # table as it 
claims in eager mode?130        self.assertEqual(b"k1000", self.evaluate(v2_2.keys()))131        self.assertEqual(3000.0, self.evaluate(v2_2.values()))132      self.assertEqual(1000.0, self.evaluate(v0_2))133      self.assertEqual(2000.0, self.evaluate(v1_2))134      # Restore the values saved earlier in the parameter nodes.135      save2 = saver_module.Saver({"v0": v0_2, "v1": v1_2, "v2": v2_2.saveable})136      save2.restore(sess, save_path)137      # Check that the parameter nodes have been restored.138      self.assertEqual(10.0, self.evaluate(v0_2))139      self.assertEqual(20.0, self.evaluate(v1_2))140      self.assertEqual(b"k1", self.evaluate(v2_2.keys()))141      self.assertEqual(30.0, self.evaluate(v2_2.values()))142  def testBasic(self):143    self.basicSaveRestore(variables.Variable)144  @test_util.run_in_graph_and_eager_modes()145  def testResourceBasic(self):146    self.basicSaveRestore(resource_variable_ops.ResourceVariable)147  def testEagerBasic(self):148    with context.eager_mode():149      ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")150      v1 = resource_variable_ops.ResourceVariable(3.14, name="v1")151      v2 = resource_variable_ops.ResourceVariable([1, 2], name="v2")152      save = saver_module.Saver([v1, v2])153      save.save(None, ckpt_prefix)154      v1.assign(0.0)155      v2.assign([0, 0])156      self.assertNear(0.0, self.evaluate(v1), 1e-5)157      self.assertAllEqual([0, 0], self.evaluate(v2))158      save.restore(None, ckpt_prefix)159      self.assertNear(3.14, self.evaluate(v1), 1e-5)160      self.assertAllEqual([1, 2], self.evaluate(v2))161  def testEagerGraphCompatibility(self):162    # Save from graph mode and restore from eager mode.163    graph_ckpt_prefix = os.path.join(self.get_temp_dir(), "graph_ckpt")164    with context.graph_mode():165      with self.test_session(graph=ops_lib.Graph()) as sess:166        # Create a graph model and save the checkpoint.167        w1 = resource_variable_ops.ResourceVariable(1.0, 
name="w1")168        w2 = resource_variable_ops.ResourceVariable(2.0, name="w2")169        graph_saver = saver_module.Saver([w1, w2])170        sess.run(variables.global_variables_initializer())171        graph_saver.save(sess, graph_ckpt_prefix)172    with context.eager_mode():173      ops_lib._default_graph_stack.reset()  # pylint: disable=protected-access174      ops_lib.reset_default_graph()175      w1 = resource_variable_ops.ResourceVariable(0.0, name="w1")176      w2 = resource_variable_ops.ResourceVariable(0.0, name="w2")177      graph_saver = saver_module.Saver([w1, w2])178      graph_saver.restore(None, graph_ckpt_prefix)179      self.assertAllEqual(self.evaluate(w1), 1.0)180      self.assertAllEqual(self.evaluate(w2), 2.0)181    # Save from eager mode and restore from graph mode.182    eager_ckpt_prefix = os.path.join(self.get_temp_dir(), "eager_ckpt")183    with context.eager_mode():184      ops_lib._default_graph_stack.reset()  # pylint: disable=protected-access185      ops_lib.reset_default_graph()186      w3 = resource_variable_ops.ResourceVariable(3.0, name="w3")187      w4 = resource_variable_ops.ResourceVariable(4.0, name="w4")188      graph_saver = saver_module.Saver([w3, w4])189      graph_saver.save(None, eager_ckpt_prefix)190    with context.graph_mode():191      with self.test_session(graph=ops_lib.Graph()) as sess:192        w3 = resource_variable_ops.ResourceVariable(0.0, name="w3")193        w4 = resource_variable_ops.ResourceVariable(0.0, name="w4")194        graph_saver = saver_module.Saver([w3, w4])195        sess.run(variables.global_variables_initializer())196        graph_saver.restore(sess, eager_ckpt_prefix)197        self.assertAllEqual(w3.eval(), 3.0)198        self.assertAllEqual(w4.eval(), 4.0)199  @test_util.run_in_graph_and_eager_modes()200  def testResourceSaveRestoreCachingDevice(self):201    save_path = os.path.join(self.get_temp_dir(), "resource_cache")202    with self.test_session(graph=ops_lib.Graph()) as sess:203      v 
= resource_variable_ops.ResourceVariable([1], caching_device="/cpu:0",204                                                 name="v")205      if context.in_graph_mode():206        self.evaluate(variables.global_variables_initializer())207      else:208        sess = None209      save = saver_module.Saver([v])210      save.save(sess, save_path)211      save2 = saver_module.Saver([v])212      save2.restore(sess, save_path)213      self.assertEquals(self.evaluate(v), [1])214  def testSaveCopyRestoreWithSaveRelativePaths(self):215    """Save, copy checkpoint dir and restore from copied dir.216    This only works for save_relative_paths=True.217    """218    save_dir1 = os.path.join(self.get_temp_dir(), "save_dir1")219    os.mkdir(save_dir1)220    save_path1 = os.path.join(save_dir1, "save_copy_restore")221    # Build a graph with 2 parameter nodes, and Save and222    # Restore nodes for them.223    v0 = variables.Variable(10.0, name="v0")224    v1 = variables.Variable(20.0, name="v1")225    v2 = saver_test_utils.CheckpointedOp(name="v2")226    v2_init = v2.insert("k1", 30.0)227    save = saver_module.Saver(228        var_list={229            "v0": v0,230            "v1": v1,231            "v2": v2.saveable},232        restore_sequentially=True,233        save_relative_paths=True)234    init_all_op = [variables.global_variables_initializer(), v2_init]235    with self.test_session() as sess:236      # Initialize all variables237      sess.run(init_all_op)238      # Check that the parameter nodes have been initialized.239      self.assertEqual(10.0, v0.eval())240      self.assertEqual(20.0, v1.eval())241      self.assertEqual(b"k1", v2.keys().eval())242      self.assertEqual(30.0, v2.values().eval())243      # Save the initialized values in the file at "save_path"244      val = save.save(sess, save_path1)245      self.assertTrue(isinstance(val, six.string_types))246      self.assertEqual(save_path1, val)247    self.assertEqual(saver_module.latest_checkpoint(save_dir1), 
save_path1)248    save_dir2 = os.path.join(self.get_temp_dir(), "save_dir2")249    os.renames(save_dir1, save_dir2)250    save_path2 = os.path.join(save_dir2, "save_copy_restore")251    self.assertEqual(saver_module.latest_checkpoint(save_dir2), save_path2)252    # Start a second session.  In that session the parameter nodes253    # have not been initialized either.254    with self.test_session() as sess:255      v0 = variables.Variable(-1.0, name="v0")256      v1 = variables.Variable(-1.0, name="v1")257      v2 = saver_test_utils.CheckpointedOp(name="v2")258      save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable})259      # Assert that the variables are not initialized.260      self.assertEqual(261          len(variables.report_uninitialized_variables().eval()), 2)262      self.assertEqual(0, len(v2.keys().eval()))263      self.assertEqual(0, len(v2.values().eval()))264      # Restore the saved values in the parameter nodes.265      save.restore(sess, save_path2)266      # Check that the parameter nodes have been restored.267      self.assertEqual(10.0, v0.eval())268      self.assertEqual(20.0, v1.eval())269      self.assertEqual(b"k1", v2.keys().eval())270      self.assertEqual(30.0, v2.values().eval())271  def testFilenameTensor(self):272    v0 = variables.Variable(0, name="v0")273    filename = b"somerandomfilename"274    save = saver_module.Saver({"v0": v0}, filename=filename)275    with self.test_session() as sess:276      tensor = sess.graph.get_tensor_by_name(277          save.saver_def.filename_tensor_name)278      self.assertEqual(sess.run(tensor), filename)279  def testInvalidPath(self):280    v0 = variables.Variable(0, name="v0")281    for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2):282      with self.test_session() as sess:283        save = saver_module.Saver({"v0": v0}, write_version=ver)284        with self.assertRaisesRegexp(errors.NotFoundError,285                                     "Failed to find any matching files 
for"):286          save.restore(sess, "invalid path")287  def testInt64(self):288    save_path = os.path.join(self.get_temp_dir(), "int64")289    with self.test_session() as sess:290      # Build a graph with 1 node, and save and restore for them.291      v = variables.Variable(np.int64(15), name="v")292      save = saver_module.Saver({"v": v}, restore_sequentially=True)293      variables.global_variables_initializer().run()294      # Save the initialized values in the file at "save_path"295      val = save.save(sess, save_path)296      self.assertTrue(isinstance(val, six.string_types))297      self.assertEqual(save_path, val)298      with self.test_session() as sess:299        v = variables.Variable(np.int64(-1), name="v")300        save = saver_module.Saver({"v": v})301      with self.assertRaisesWithPredicateMatch(302          errors_impl.OpError, lambda e: "uninitialized value v" in e.message):303        sess.run(v)304      # Restore the saved values in the parameter nodes.305      save.restore(sess, save_path)306      # Check that the parameter nodes have been restored.307      self.assertEqual(np.int64(15), v.eval())308  def testSomeErrors(self):309    with ops_lib.Graph().as_default():310      v0 = variables.Variable([10.0], name="v0")311      v1 = variables.Variable([20.0], name="v1")312      v2 = variables.Variable([20.0], name="v2")313      v2._set_save_slice_info(314          variables.Variable.SaveSliceInfo("v1", [1], [0], [1]))315      # By default the name used for "v2" will be "v1" and raise an error.316      with self.assertRaisesRegexp(ValueError, "same name: v1"):317        saver_module.Saver([v0, v1, v2])318      # The names are different and will work.319      saver_module.Saver({"vee1": v1, "other": [v2]})320      # Partitioned variables also cause name conflicts.321      p_v1 = variable_scope.get_variable(322          "p_v1",323          shape=[4, 5],324          partitioner=partitioned_variables.fixed_size_partitioner(325              
num_shards=2))326      p_v2 = variable_scope.get_variable(327          "p_v2",328          shape=[4, 5],329          partitioner=partitioned_variables.fixed_size_partitioner(330              num_shards=2))331      p_v2._name = "p_v1"332      with self.assertRaisesRegexp(ValueError, "same name: p_v1"):333        saver_module.Saver([p_v1, p_v2])334  def testSameName(self):335    with ops_lib.Graph().as_default():336      v0 = variables.Variable([10.0], name="v0")337      v2 = saver_test_utils.CheckpointedOp(name="v2")338      # Saving one variable under two names raises an error.339      with self.assertRaisesRegexp(340          ValueError, "The same saveable will be restored with two names: v0"):341        saver_module.Saver({"v0": v0, "v0too": v0})342      # Ditto for custom saveables.343      with self.assertRaisesRegexp(344          ValueError, "The same saveable will be restored with two names: v2"):345        saver_module.Saver({"v2": v2.saveable, "v2too": v2.saveable})346      # Verify non-duplicate names work.347      saver_module.Saver({"v0": v0, "v2": v2.saveable})348  def testBasicsWithListOfVariables(self):349    save_path = os.path.join(self.get_temp_dir(), "basics_with_list")350    with self.test_session(graph=ops_lib.Graph()) as sess:351      # Build a graph with 2 parameter nodes, and Save and352      # Restore nodes for them.353      v0 = variables.Variable(10.0, name="v0")354      v1 = variables.Variable(20.0, name="v1")355      v2 = saver_test_utils.CheckpointedOp(name="v2")356      v2_init = v2.insert("k1", 30.0)357      save = saver_module.Saver([v0, v1, v2.saveable])358      variables.global_variables_initializer().run()359      v2_init.run()360      # Check that the parameter nodes have been initialized.361      self.assertEqual(10.0, v0.eval())362      self.assertEqual(20.0, v1.eval())363      self.assertEqual(b"k1", v2.keys().eval())364      self.assertEqual(30.0, v2.values().eval())365      # Save the initialized values in the file at 
"save_path"366      val = save.save(sess, save_path)367      self.assertTrue(isinstance(val, six.string_types))368      self.assertEqual(save_path, val)369    # Start a second session.  In that session the variables370    # have not been initialized either.371    with self.test_session(graph=ops_lib.Graph()) as sess:372      v0 = variables.Variable(-1.0, name="v0")373      v1 = variables.Variable(-1.0, name="v1")374      v2 = saver_test_utils.CheckpointedOp(name="v2")375      save = saver_module.Saver([v0, v1, v2.saveable])376      with self.assertRaisesWithPredicateMatch(377          errors_impl.OpError, lambda e: "uninitialized value v0" in e.message):378        sess.run(v0)379      with self.assertRaisesWithPredicateMatch(380          errors_impl.OpError, lambda e: "uninitialized value v1" in e.message):381        sess.run(v1)382      self.assertEqual(0, len(v2.keys().eval()))383      self.assertEqual(0, len(v2.values().eval()))384      # Restore the saved values in the parameter nodes.385      save.restore(sess, save_path)386      # Check that the parameter nodes have been restored.387      self.assertEqual(10.0, v0.eval())388      self.assertEqual(20.0, v1.eval())389      self.assertEqual(b"k1", v2.keys().eval())390      self.assertEqual(30.0, v2.values().eval())391    # Build another graph with 2 nodes, initialized392    # differently, and a Restore node for them.393    with self.test_session(graph=ops_lib.Graph()) as sess:394      v0_2 = variables.Variable(1000.0, name="v0")395      v1_2 = variables.Variable(2000.0, name="v1")396      v2_2 = saver_test_utils.CheckpointedOp(name="v2")397      save2 = saver_module.Saver([v0_2, v1_2, v2_2.saveable])398      v2_2.insert("k1000", 3000.0).run()399      variables.global_variables_initializer().run()400      # Check that the parameter nodes have been initialized.401      self.assertEqual(1000.0, v0_2.eval())402      self.assertEqual(2000.0, v1_2.eval())403      self.assertEqual(b"k1000", v2_2.keys().eval())404      
self.assertEqual(3000.0, v2_2.values().eval())405      # Restore the values saved earlier in the parameter nodes.406      save2.restore(sess, save_path)407      # Check that the parameter nodes have been restored.408      self.assertEqual(10.0, v0_2.eval())409      self.assertEqual(20.0, v1_2.eval())410      self.assertEqual(b"k1", v2_2.keys().eval())411      self.assertEqual(30.0, v2_2.values().eval())412  def _SaveAndLoad(self, var_name, var_value, other_value, save_path):413    with self.test_session(graph=ops_lib.Graph()) as sess:414      var = resource_variable_ops.ResourceVariable(var_value, name=var_name)415      save = saver_module.Saver({var_name: var})416      if context.in_graph_mode():417        self.evaluate(var.initializer)418      val = save.save(sess, save_path)419      self.assertEqual(save_path, val)420    with self.test_session(graph=ops_lib.Graph()) as sess:421      var = resource_variable_ops.ResourceVariable(other_value, name=var_name)422      save = saver_module.Saver({var_name: var})423      save.restore(sess, save_path)424      self.assertAllClose(var_value, self.evaluate(var))425  def testCacheRereadsFile(self):426    save_path = os.path.join(self.get_temp_dir(), "cache_rereads")427    # Save and reload one Variable named "var0".428    self._SaveAndLoad("var0", 0.0, 1.0, save_path)429    # Save and reload one Variable named "var1" in the same file.430    # The cached readers should know to re-read the file.431    self._SaveAndLoad("var1", 1.1, 2.2, save_path)432  def testAllowEmpty(self):433    save_path = os.path.join(self.get_temp_dir(), "allow_empty")434    with self.test_session() as sess:435      _ = constant_op.constant(1)436      save = saver_module.Saver(allow_empty=True)437      val = save.save(sess, save_path)438      self.assertIsNone(val)439    with self.test_session() as sess:440      save = saver_module.Saver(allow_empty=True)441      save.restore(sess, save_path)442  def testGPU(self):443    if not 
test.is_gpu_available():444      return445    save_path = os.path.join(self.get_temp_dir(), "gpu")446    with session.Session("", graph=ops_lib.Graph()) as sess:447      with sess.graph.device(test.gpu_device_name()):448        v0_1 = variables.Variable(123.45)449      save = saver_module.Saver({"v0": v0_1})450      variables.global_variables_initializer().run()451      save.save(sess, save_path)452    with session.Session("", graph=ops_lib.Graph()) as sess:453      with sess.graph.device(test.gpu_device_name()):454        v0_2 = variables.Variable(543.21)455      save = saver_module.Saver({"v0": v0_2})456      variables.global_variables_initializer().run()457  def testVariables(self):458    save_path = os.path.join(self.get_temp_dir(), "variables")459    with session.Session("", graph=ops_lib.Graph()) as sess:460      one = variables.Variable(1.0)461      twos = variables.Variable([2.0, 2.0, 2.0])462      v2 = saver_test_utils.CheckpointedOp(name="v2")463      init = variables.global_variables_initializer()464      save = saver_module.Saver()465      init.run()466      v2.insert("k1", 3.0).run()467      save.save(sess, save_path)468    with session.Session("", graph=ops_lib.Graph()) as sess:469      one = variables.Variable(0.0)470      twos = variables.Variable([0.0, 0.0, 0.0])471      v2 = saver_test_utils.CheckpointedOp(name="v2")472      # Saver with no arg, defaults to 'all variables'.473      save = saver_module.Saver()474      save.restore(sess, save_path)475      self.assertAllClose(1.0, one.eval())476      self.assertAllClose([2.0, 2.0, 2.0], twos.eval())477      self.assertEqual(b"k1", v2.keys().eval())478      self.assertEqual(3.0, v2.values().eval())479  def testVarListShouldBeEmptyInDeferredBuild(self):480    with ops_lib.Graph().as_default():481      v = variables.Variable(1.0)482      with self.assertRaisesRegexp(ValueError, "defer_build"):483        saver_module.Saver([v], defer_build=True)484  def 
testBuildShouldBeCalledBeforeSaveInCaseOfDeferBuild(self):485    save_path = os.path.join(self.get_temp_dir(), "error_deferred_build")486    with ops_lib.Graph().as_default(), session.Session() as sess:487      variables.Variable(1.0)488      saver = saver_module.Saver(defer_build=True)489      with self.assertRaisesRegexp(RuntimeError, "build"):490        saver.save(sess, save_path)491  def testDeferredBuild(self):492    save_path = os.path.join(self.get_temp_dir(), "deferred_build")493    with session.Session("", graph=ops_lib.Graph()) as sess:494      one = variables.Variable(1.0)495      save = saver_module.Saver(defer_build=True)496      # if build is not deferred, saver cannot save the `twos`.497      twos = variables.Variable([2.0, 2.0, 2.0])498      init = variables.global_variables_initializer()499      save.build()500      init.run()501      save.save(sess, save_path)502    with session.Session("", graph=ops_lib.Graph()) as sess:503      one = variables.Variable(0.0)504      twos = variables.Variable([0.0, 0.0, 0.0])505      # Saver with no arg, defaults to 'all variables'.506      save = saver_module.Saver()507      save.restore(sess, save_path)508      self.assertAllClose(1.0, one.eval())509      self.assertAllClose([2.0, 2.0, 2.0], twos.eval())510  def testReshape(self):511    save_path = os.path.join(self.get_temp_dir(), "variables_reshape")512    with session.Session("", graph=ops_lib.Graph()) as sess:513      var = variables.Variable([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])514      init = variables.global_variables_initializer()515      save = saver_module.Saver()516      init.run()517      save.save(sess, save_path)518    # Error when restoring with default reshape=False519    with session.Session("", graph=ops_lib.Graph()) as sess:520      var = variables.Variable([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])521      save = saver_module.Saver()522      with self.assertRaisesRegexp(523          errors_impl.InvalidArgumentError,524          "Assign requires 
shapes of both tensors to match."):525        save.restore(sess, save_path)526    # Restored to new shape with reshape=True527    with session.Session("", graph=ops_lib.Graph()) as sess:528      var = variables.Variable([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])529      save = saver_module.Saver(reshape=True)530      save.restore(sess, save_path)531      self.assertAllClose([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], var.eval())532  @test_util.run_in_graph_and_eager_modes()533  def testSaveWithGlobalStep(self, pad_step_number=False):534    save_path = os.path.join(self.get_temp_dir(), "ckpt_with_global_step")535    global_step_int = 5536    # Save and reload one Variable named "var0".537    self._SaveAndLoad("var0", 0.0, 1.0, save_path)538    for use_tensor in [True, False]:539      with self.test_session(graph=ops_lib.Graph()):540        var = resource_variable_ops.ResourceVariable(1.0, name="var0")541        save = saver_module.Saver(542            {543                var._shared_name: var544            }, pad_step_number=pad_step_number)545        if context.in_graph_mode():546          self.evaluate(var.initializer)547          sess = ops_lib.get_default_session()548        else:549          sess = None550        if use_tensor:551          global_step = constant_op.constant(global_step_int)552          val = save.save(sess, save_path, global_step=global_step)553        else:554          val = save.save(sess, save_path, global_step=global_step_int)555        if pad_step_number:556          expected_save_path = "%s-%s" % (save_path,557                                          "{:08d}".format(global_step_int))558        else:559          expected_save_path = "%s-%d" % (save_path, global_step_int)560        self.assertEqual(expected_save_path, val)561  def testSaveWithGlobalStepWithPadding(self):562    self.testSaveWithGlobalStep(pad_step_number=True)563  def testSaveToNonexistingPath(self):564    file_io.write_string_to_file(565        os.path.join(self.get_temp_dir(), 
"actually_a_file"), "")566    paths = [567        os.path.join(self.get_temp_dir(), "nonexisting_dir/path"),568        os.path.join(self.get_temp_dir(), "other_nonexisting_dir/path1/path2"),569        os.path.join(self.get_temp_dir(), "actually_a_file/path"),570    ]571    for save_path in paths:572      # Build a graph with 2 parameter nodes, and Save and573      # Restore nodes for them.574      v0 = variables.Variable(10.0, name="v0")575      v1 = variables.Variable(20.0, name="v1")576      save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True)577      init_all_op = variables.global_variables_initializer()578      # In the case where the parent directory doesn't exist, whether or not the579      # save succeeds or fails is implementation dependent.  Therefore we allow580      # both cases.581      try:582        with self.test_session() as sess:583          # Initialize all variables584          sess.run(init_all_op)585          # Check that the parameter nodes have been initialized.586          self.assertEqual(10.0, v0.eval())587          self.assertEqual(20.0, v1.eval())588          # Save the graph.589          save.save(sess, save_path)590        with self.test_session() as sess:591          # Restore the saved values in the parameter nodes.592          save.restore(sess, save_path)593          # Check that the parameter nodes have been restored.594          self.assertEqual(10.0, v0.eval())595          self.assertEqual(20.0, v1.eval())596      except ValueError as exc:597        error_msg_template = "Parent directory of {} doesn't exist, can't save."598        self.assertEqual(error_msg_template.format(save_path), str(exc))599  def testSaveToURI(self):600    # ParseURI functions don't work on Windows yet.601    # TODO(jhseu): Remove this check when it works.602    if os.name == "nt":603      self.skipTest("Local URI support doesn't work on Windows")604    save_path = "file://" + os.path.join(self.get_temp_dir(), "uri")605    # Build a 
graph with 2 parameter nodes, and Save and606    # Restore nodes for them.607    v0 = variables.Variable(10.0, name="v0")608    v1 = variables.Variable(20.0, name="v1")609    save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True)610    init_all_op = variables.global_variables_initializer()611    with self.test_session() as sess:612      # Initialize all variables613      sess.run(init_all_op)614      # Check that the parameter nodes have been initialized.615      self.assertEqual(10.0, v0.eval())616      self.assertEqual(20.0, v1.eval())617      save.save(sess, save_path)618class SaveRestoreShardedTest(test.TestCase):619  def _get_test_dir(self, dirname):620    test_dir = os.path.join(self.get_temp_dir(), dirname)621    gfile.MakeDirs(test_dir)622    return test_dir623  def testBasics(self):624    save_path = os.path.join(self.get_temp_dir(), "sharded_basics")625    # Build a graph with 2 parameter nodes on different devices.626    with session.Session(627        target="",628        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:629      with sess.graph.device("/cpu:0"):630        v0 = variables.Variable(10, name="v0")631        t0 = saver_test_utils.CheckpointedOp(name="t0")632      with sess.graph.device("/cpu:1"):633        v1 = variables.Variable(20, name="v1")634        t1 = saver_test_utils.CheckpointedOp(name="t1")635      save = saver_module.Saver(636          {637              "v0": v0,638              "v1": v1,639              "t0": t0.saveable,640              "t1": t1.saveable641          },642          sharded=True)643      variables.global_variables_initializer().run()644      t0.insert("k1", 30.0).run()645      t1.insert("k2", 40.0).run()646      val = save.save(sess, save_path)647      if save._write_version is saver_pb2.SaverDef.V1:648        self.assertEqual(save_path + "-?????-of-00002", val)649      else:650        self.assertEqual(save_path, val)651      meta_graph_filename = save._MetaGraphFilename(val)652 
     self.assertEqual(save_path + ".meta", meta_graph_filename)653    if save._write_version is saver_pb2.SaverDef.V1:654      # Restore different ops from shard 0 of the saved files.655      with session.Session(656          target="",657          config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:658        with sess.graph.device("/cpu:0"):659          v0 = variables.Variable(111, name="v0")660          t0 = saver_test_utils.CheckpointedOp(name="t0")661        save = saver_module.Saver({"v0": v0, "t0": t0.saveable}, sharded=True)662        variables.global_variables_initializer().run()663        t0.insert("k11", 33.0).run()664        self.assertEqual(111, v0.eval())665        self.assertEqual(b"k11", t0.keys().eval())666        self.assertEqual(33.0, t0.values().eval())667        save.restore(sess, save_path + "-00000-of-00002")668        self.assertEqual(10, v0.eval())669        self.assertEqual(b"k1", t0.keys().eval())670        self.assertEqual(30.0, t0.values().eval())671      # Restore different ops from shard 1 of the saved files.672      with session.Session(673          target="",674          config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:675        with sess.graph.device("/cpu:0"):676          v1 = variables.Variable(222)677          t1 = saver_test_utils.CheckpointedOp(name="t1")678        save = saver_module.Saver({"v1": v1, "t1": t1.saveable}, sharded=True)679        variables.global_variables_initializer().run()680        t1.insert("k22", 44.0).run()681        self.assertEqual(222, v1.eval())682        self.assertEqual(b"k22", t1.keys().eval())683        self.assertEqual(44.0, t1.values().eval())684        save.restore(sess, save_path + "-00001-of-00002")685        self.assertEqual(20, v1.eval())686        self.assertEqual(b"k2", t1.keys().eval())687        self.assertEqual(40.0, t1.values().eval())688    # Now try a restore with the sharded filename.689    with session.Session(690        target="",691        
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:692      with sess.graph.device("/cpu:0"):693        v0 = variables.Variable(111, name="v0")694        t0 = saver_test_utils.CheckpointedOp(name="t0")695      with sess.graph.device("/cpu:1"):696        v1 = variables.Variable(222, name="v1")697        t1 = saver_test_utils.CheckpointedOp(name="t1")698      save = saver_module.Saver(699          {700              "v0": v0,701              "v1": v1,702              "t0": t0.saveable,703              "t1": t1.saveable704          },705          sharded=True)706      variables.global_variables_initializer().run()707      t0.insert("k11", 33.0).run()708      t1.insert("k22", 44.0).run()709      self.assertEqual(111, v0.eval())710      self.assertEqual(222, v1.eval())711      self.assertEqual(b"k11", t0.keys().eval())712      self.assertEqual(33.0, t0.values().eval())713      self.assertEqual(b"k22", t1.keys().eval())714      self.assertEqual(44.0, t1.values().eval())715      save_path = os.path.join(self.get_temp_dir(), "sharded_basics")716      if save._write_version is saver_pb2.SaverDef.V1:717        save.restore(sess, save_path + "-?????-of-?????")718      else:719        save.restore(sess, save_path)720      self.assertEqual(10, v0.eval())721      self.assertEqual(20, v1.eval())722      self.assertEqual(b"k1", t0.keys().eval())723      self.assertEqual(30.0, t0.values().eval())724      self.assertEqual(b"k2", t1.keys().eval())725      self.assertEqual(40.0, t1.values().eval())726    if save._write_version is saver_pb2.SaverDef.V1:727      self.assertEqual(728          saver_module.latest_checkpoint(self.get_temp_dir()),729          os.path.join(self.get_temp_dir(), "sharded_basics-?????-of-00002"))730    else:731      self.assertEqual(732          saver_module.latest_checkpoint(self.get_temp_dir()),733          os.path.join(self.get_temp_dir(), "sharded_basics"))734  def testSaverDef(self):735    with self.test_session():736      v0 = 
variables.Variable(123, name="v0")737      save = saver_module.Saver({"v0": v0}, sharded=True)738      sd = save.as_saver_def()739      self.assertTrue(sd.sharded)740  def _testPartitionedVariables(self, use_resource):741    var_full_shape = [10, 3]742    # Allows save/restore mechanism to work w/ different slicings.743    var_name = "my_var"744    saved_dir = self._get_test_dir("partitioned_variables")745    saved_path = os.path.join(saved_dir, "ckpt")746    call_saver_with_dict = False  # updated by test loop below747    def _save(slices=None, partitioner=None):748      with self.test_session(graph=ops_lib.Graph()) as sess:749        # Calls .eval() to return the ndarray that makes up the full variable.750        rnd = random_ops.random_uniform(var_full_shape).eval()751        if slices:752          assert not partitioner753          # TODO(apassos): make create_partitioned_variables take use_resource754          # option to make this test passable without creating a named755          # variable_scope.756          vs = partitioned_variables.create_partitioned_variables(757              var_full_shape, slices, rnd, name=var_name)758        elif partitioner:759          vs = [760              variable_scope.get_variable(761                  var_name,762                  shape=var_full_shape,763                  initializer=rnd,764                  partitioner=partitioner,765                  use_resource=use_resource)766          ]767        else:768          if use_resource:769            vs = [resource_variable_ops.ResourceVariable(rnd, name=var_name)]770          else:771            vs = [variables.Variable(rnd, name=var_name)]772        variables.global_variables_initializer().run()773        if call_saver_with_dict:774          saver = saver_module.Saver({var_name: (vs if slices else vs[0])})775        else:776          saver = saver_module.Saver(vs)777        actual_path = saver.save(sess, saved_path)778        self.assertEqual(saved_path, actual_path)779     
   return rnd780    def _restore(slices=None, partitioner=None):781      with self.test_session(graph=ops_lib.Graph()) as sess:782        if slices:783          assert not partitioner784          new_vs = partitioned_variables.create_partitioned_variables(785              var_full_shape,786              slices,787              array_ops.zeros(var_full_shape),  # != original contents.788              name=var_name)789        elif partitioner:790          new_vs = [791              variable_scope.get_variable(792                  var_name,793                  shape=var_full_shape,794                  initializer=array_ops.zeros(var_full_shape),795                  partitioner=partitioner)796          ]797        else:798          new_vs = [799              variables.Variable(800                  array_ops.zeros(801                      shape=var_full_shape),  # != original contents.802                  name=var_name)803          ]804        variables.global_variables_initializer().run()805        if call_saver_with_dict:806          saver = saver_module.Saver({807              var_name: (new_vs if slices else new_vs[0])808          })809        else:810          saver = saver_module.Saver(new_vs)811        saver.restore(sess, saved_path)812        if partitioner:813          return new_vs[0].as_tensor().eval()814        elif slices and slices[0] != 1:815          return array_ops.concat(new_vs, 0).eval()816        elif slices and slices[1] != 1:817          return array_ops.concat(new_vs, 1).eval()818        else:  # Non-sliced.819          return new_vs[0].eval()820    for call_saver_with_dict in {False, True}:821      # Save PartitionedVariable and restore into full variable.822      saved_full = _save(823          partitioner=partitioned_variables.fixed_size_partitioner(824              num_shards=2))825      restored_full = _restore()826      self.assertAllEqual(saved_full, restored_full)827      # Saves 10 horizontal parts of a partitioned variable.828      # 
Restores into a full variable, non-sliced.829      saved_full = _save(slices=[10, 1])830      restored_full = _restore()831      self.assertAllEqual(saved_full, restored_full)832      # Restores into a different number/orientation of slices.833      restored_full = _restore(slices=[2, 1])  # 2 horizon parts.834      self.assertAllEqual(saved_full, restored_full)835      restored_full = _restore(slices=[1, 3])  # 3 vertical parts.836      self.assertAllEqual(saved_full, restored_full)837      # Restores into a PartitionedVariable838      restored_full = _restore(839          partitioner=partitioned_variables.fixed_size_partitioner(840              num_shards=2))841      self.assertAllEqual(saved_full, restored_full)842      # Now, saves a full variable and restores in slices.843      saved_full = _save()844      restored_full = _restore(slices=[1, 3])845      self.assertAllEqual(saved_full, restored_full)846  def testPartitionedVariable(self):847    self._testPartitionedVariables(use_resource=False)848  def testPartitionedResourceVariable(self):849    self._testPartitionedVariables(use_resource=True)850class MaxToKeepTest(test.TestCase):851  def _get_test_dir(self, dirname):852    test_dir = os.path.join(self.get_temp_dir(), dirname)853    gfile.MakeDirs(test_dir)854    return test_dir855  def assertCheckpointState(self, model_checkpoint_path,856                            all_model_checkpoint_paths, save_dir):857    checkpoint_state = saver_module.get_checkpoint_state(save_dir)858    self.assertEqual(checkpoint_state.model_checkpoint_path,859                     model_checkpoint_path)860    self.assertEqual(checkpoint_state.all_model_checkpoint_paths,861                     all_model_checkpoint_paths)862  def testNonSharded(self):863    save_dir = self._get_test_dir("max_to_keep_non_sharded")864    with self.test_session() as sess:865      v = variables.Variable(10.0, name="v")866      save = saver_module.Saver({"v": v}, max_to_keep=2)867      
variables.global_variables_initializer().run()868      self.assertEqual([], save.last_checkpoints)869      s1 = save.save(sess, os.path.join(save_dir, "s1"))870      self.assertEqual([s1], save.last_checkpoints)871      self.assertTrue(saver_module.checkpoint_exists(s1))872      self.assertCheckpointState(873          model_checkpoint_path=s1,874          all_model_checkpoint_paths=[s1],875          save_dir=save_dir)876      s2 = save.save(sess, os.path.join(save_dir, "s2"))877      self.assertEqual([s1, s2], save.last_checkpoints)878      self.assertTrue(saver_module.checkpoint_exists(s1))879      self.assertTrue(saver_module.checkpoint_exists(s2))880      self.assertCheckpointState(881          model_checkpoint_path=s2,882          all_model_checkpoint_paths=[s1, s2],883          save_dir=save_dir)884      s3 = save.save(sess, os.path.join(save_dir, "s3"))885      self.assertEqual([s2, s3], save.last_checkpoints)886      self.assertFalse(saver_module.checkpoint_exists(s1))887      self.assertTrue(saver_module.checkpoint_exists(s2))888      self.assertTrue(saver_module.checkpoint_exists(s3))889      self.assertCheckpointState(890          model_checkpoint_path=s3,891          all_model_checkpoint_paths=[s2, s3],892          save_dir=save_dir)893      # Create a second helper, identical to the first.894      save2 = saver_module.Saver(saver_def=save.as_saver_def())895      save2.set_last_checkpoints(save.last_checkpoints)896      # Create a third helper, with the same configuration but no knowledge of897      # previous checkpoints.898      save3 = saver_module.Saver(saver_def=save.as_saver_def())899      # Exercise the first helper.900      # Adding s2 again (old s2 is removed first, then new s2 appended)901      s2 = save.save(sess, os.path.join(save_dir, "s2"))902      self.assertEqual([s3, s2], save.last_checkpoints)903      self.assertFalse(saver_module.checkpoint_exists(s1))904      self.assertFalse(905          
saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))906      self.assertTrue(saver_module.checkpoint_exists(s3))907      self.assertTrue(908          saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))909      self.assertTrue(saver_module.checkpoint_exists(s2))910      self.assertTrue(911          saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))912      self.assertCheckpointState(913          model_checkpoint_path=s2,914          all_model_checkpoint_paths=[s3, s2],915          save_dir=save_dir)916      # Adding s1 (s3 should now be deleted as oldest in list)917      s1 = save.save(sess, os.path.join(save_dir, "s1"))918      self.assertEqual([s2, s1], save.last_checkpoints)919      self.assertFalse(saver_module.checkpoint_exists(s3))920      self.assertFalse(921          saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))922      self.assertTrue(saver_module.checkpoint_exists(s2))923      self.assertTrue(924          saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))925      self.assertTrue(saver_module.checkpoint_exists(s1))926      self.assertTrue(927          saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))928      self.assertCheckpointState(929          model_checkpoint_path=s1,930          all_model_checkpoint_paths=[s2, s1],931          save_dir=save_dir)932      # Exercise the second helper.933      # Adding s2 again (old s2 is removed first, then new s2 appended)934      s2 = save2.save(sess, os.path.join(save_dir, "s2"))935      self.assertEqual([s3, s2], save2.last_checkpoints)936      # Created by the first helper.937      self.assertTrue(saver_module.checkpoint_exists(s1))938      self.assertTrue(939          saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))940      # Deleted by the first helper.941      self.assertFalse(saver_module.checkpoint_exists(s3))942      self.assertFalse(943          saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))944      
self.assertTrue(saver_module.checkpoint_exists(s2))945      self.assertTrue(946          saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))947      self.assertCheckpointState(948          model_checkpoint_path=s2,949          all_model_checkpoint_paths=[s3, s2],950          save_dir=save_dir)951      # Adding s1 (s3 should now be deleted as oldest in list)952      s1 = save2.save(sess, os.path.join(save_dir, "s1"))953      self.assertEqual([s2, s1], save2.last_checkpoints)954      self.assertFalse(saver_module.checkpoint_exists(s3))955      self.assertFalse(956          saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))957      self.assertTrue(saver_module.checkpoint_exists(s2))958      self.assertTrue(959          saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))960      self.assertTrue(saver_module.checkpoint_exists(s1))961      self.assertTrue(962          saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))963      self.assertCheckpointState(964          model_checkpoint_path=s1,965          all_model_checkpoint_paths=[s2, s1],966          save_dir=save_dir)967      # Exercise the third helper.968      # Adding s2 again (but helper is unaware of previous s2)969      s2 = save3.save(sess, os.path.join(save_dir, "s2"))970      self.assertEqual([s2], save3.last_checkpoints)971      # Created by the first helper.972      self.assertTrue(saver_module.checkpoint_exists(s1))973      self.assertTrue(974          saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))975      # Deleted by the first helper.976      self.assertFalse(saver_module.checkpoint_exists(s3))977      self.assertFalse(978          saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))979      self.assertTrue(saver_module.checkpoint_exists(s2))980      self.assertTrue(981          saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))982      # Even though the file for s1 exists, this saver isn't aware of it, which983      # is why it doesn't 
end up in the checkpoint state.984      self.assertCheckpointState(985          model_checkpoint_path=s2,986          all_model_checkpoint_paths=[s2],987          save_dir=save_dir)988      # Adding s1 (s3 should not be deleted because helper is unaware of it)989      s1 = save3.save(sess, os.path.join(save_dir, "s1"))990      self.assertEqual([s2, s1], save3.last_checkpoints)991      self.assertFalse(saver_module.checkpoint_exists(s3))992      self.assertFalse(993          saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))994      self.assertTrue(saver_module.checkpoint_exists(s2))995      self.assertTrue(996          saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))997      self.assertTrue(saver_module.checkpoint_exists(s1))998      self.assertTrue(999          saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))1000      self.assertCheckpointState(1001          model_checkpoint_path=s1,1002          all_model_checkpoint_paths=[s2, s1],1003          save_dir=save_dir)1004  def testSharded(self):1005    save_dir = self._get_test_dir("max_to_keep_sharded")1006    with session.Session(1007        target="",1008        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:1009      with sess.graph.device("/cpu:0"):1010        v0 = variables.Variable(111, name="v0")1011      with sess.graph.device("/cpu:1"):1012        v1 = variables.Variable(222, name="v1")1013      save = saver_module.Saver(1014          {1015              "v0": v0,1016              "v1": v11017          }, sharded=True, max_to_keep=2)1018      variables.global_variables_initializer().run()1019      self.assertEqual([], save.last_checkpoints)1020      s1 = save.save(sess, os.path.join(save_dir, "s1"))1021      self.assertEqual([s1], save.last_checkpoints)1022      if save._write_version is saver_pb2.SaverDef.V1:1023        self.assertEqual(2, len(gfile.Glob(s1)))1024      else:1025        self.assertEqual(4, len(gfile.Glob(s1 + "*")))1026      
self.assertTrue(gfile.Exists(save._MetaGraphFilename(s1)))1027      s2 = save.save(sess, os.path.join(save_dir, "s2"))1028      self.assertEqual([s1, s2], save.last_checkpoints)1029      if save._write_version is saver_pb2.SaverDef.V1:1030        self.assertEqual(2, len(gfile.Glob(s1)))1031      else:1032        self.assertEqual(4, len(gfile.Glob(s1 + "*")))1033      self.assertTrue(gfile.Exists(save._MetaGraphFilename(s1)))1034      if save._write_version is saver_pb2.SaverDef.V1:1035        self.assertEqual(2, len(gfile.Glob(s2)))1036      else:1037        self.assertEqual(4, len(gfile.Glob(s2 + "*")))1038      self.assertTrue(gfile.Exists(save._MetaGraphFilename(s2)))1039      s3 = save.save(sess, os.path.join(save_dir, "s3"))1040      self.assertEqual([s2, s3], save.last_checkpoints)1041      self.assertEqual(0, len(gfile.Glob(s1 + "*")))1042      self.assertFalse(gfile.Exists(save._MetaGraphFilename(s1)))1043      if save._write_version is saver_pb2.SaverDef.V1:1044        self.assertEqual(2, len(gfile.Glob(s2)))1045      else:1046        self.assertEqual(4, len(gfile.Glob(s2 + "*")))1047      self.assertTrue(gfile.Exists(save._MetaGraphFilename(s2)))1048      if save._write_version is saver_pb2.SaverDef.V1:1049        self.assertEqual(2, len(gfile.Glob(s3)))1050      else:1051        self.assertEqual(4, len(gfile.Glob(s3 + "*")))1052      self.assertTrue(gfile.Exists(save._MetaGraphFilename(s3)))1053  def testNoMaxToKeep(self):1054    save_dir = self._get_test_dir("no_max_to_keep")1055    save_dir2 = self._get_test_dir("max_to_keep_0")1056    with self.test_session() as sess:1057      v = variables.Variable(10.0, name="v")1058      variables.global_variables_initializer().run()1059      # Test max_to_keep being None.1060      save = saver_module.Saver({"v": v}, max_to_keep=None)1061      self.assertEqual([], save.last_checkpoints)1062      s1 = save.save(sess, os.path.join(save_dir, "s1"))1063      self.assertEqual([], save.last_checkpoints)1064      
self.assertTrue(saver_module.checkpoint_exists(s1))1065      s2 = save.save(sess, os.path.join(save_dir, "s2"))1066      self.assertEqual([], save.last_checkpoints)1067      self.assertTrue(saver_module.checkpoint_exists(s2))1068      # Test max_to_keep being 0.1069      save2 = saver_module.Saver({"v": v}, max_to_keep=0)1070      self.assertEqual([], save2.last_checkpoints)1071      s1 = save2.save(sess, os.path.join(save_dir2, "s1"))1072      self.assertEqual([], save2.last_checkpoints)1073      self.assertTrue(saver_module.checkpoint_exists(s1))1074      s2 = save2.save(sess, os.path.join(save_dir2, "s2"))1075      self.assertEqual([], save2.last_checkpoints)1076      self.assertTrue(saver_module.checkpoint_exists(s2))1077  def testNoMetaGraph(self):1078    save_dir = self._get_test_dir("no_meta_graph")1079    with self.test_session() as sess:1080      v = variables.Variable(10.0, name="v")1081      save = saver_module.Saver({"v": v})1082      variables.global_variables_initializer().run()1083      s1 = save.save(sess, os.path.join(save_dir, "s1"), write_meta_graph=False)1084      self.assertTrue(saver_module.checkpoint_exists(s1))1085      self.assertFalse(gfile.Exists(save._MetaGraphFilename(s1)))1086class KeepCheckpointEveryNHoursTest(test.TestCase):1087  def _get_test_dir(self, dirname):1088    test_dir = os.path.join(self.get_temp_dir(), dirname)1089    gfile.MakeDirs(test_dir)1090    return test_dir1091  @test.mock.patch.object(saver_module, "time")1092  def testNonSharded(self, mock_time):1093    save_dir = self._get_test_dir("keep_checkpoint_every_n_hours")1094    with self.test_session() as sess:1095      v = variables.Variable([10.0], name="v")1096      # Run the initializer NOW to avoid the 0.5s overhead of the first Run()1097      # call, which throws the test timing off in fastbuild mode.1098      variables.global_variables_initializer().run()1099      # Create a saver that will keep the last 2 checkpoints plus one every 0.71100      # seconds.1101  
    start_time = time.time()1102      mock_time.time.return_value = start_time1103      save = saver_module.Saver(1104          {1105              "v": v1106          }, max_to_keep=2, keep_checkpoint_every_n_hours=0.7 / 3600)1107      self.assertEqual([], save.last_checkpoints)1108      # Wait till 1 seconds have elapsed so s1 will be old enough to keep.1109      # sleep may return early, don't trust it.1110      mock_time.time.return_value = start_time + 1.01111      s1 = save.save(sess, os.path.join(save_dir, "s1"))1112      self.assertEqual([s1], save.last_checkpoints)1113      s2 = save.save(sess, os.path.join(save_dir, "s2"))1114      self.assertEqual([s1, s2], save.last_checkpoints)1115      # We now have 2 'last_checkpoints': [s1, s2].  The next call to Save(),1116      # would normally delete s1, because max_to_keep is 2.  However, s1 is1117      # older than 0.7s so we must keep it.1118      s3 = save.save(sess, os.path.join(save_dir, "s3"))1119      self.assertEqual([s2, s3], save.last_checkpoints)1120      # s1 should still be here, we are Not checking now to reduce time1121      # variance in the test.1122      # We now have 2 'last_checkpoints': [s2, s3], and s1 on disk.  The next1123      # call to Save(), will delete s2, because max_to_keep is 2, and because1124      # we already kept the old s1. 
s2 is very close in time to s1 so it gets1125      # deleted.1126      s4 = save.save(sess, os.path.join(save_dir, "s4"))1127      self.assertEqual([s3, s4], save.last_checkpoints)1128      # Check that s1 is still here, but s2 is gone.1129      self.assertTrue(saver_module.checkpoint_exists(s1))1130      self.assertFalse(saver_module.checkpoint_exists(s2))1131      self.assertTrue(saver_module.checkpoint_exists(s3))1132      self.assertTrue(saver_module.checkpoint_exists(s4))1133class SaveRestoreWithVariableNameMap(test.TestCase):1134  def _testNonReshape(self, variable_op):1135    save_path = os.path.join(self.get_temp_dir(), "non_reshape")1136    with self.test_session(graph=ops_lib.Graph()) as sess:1137      # Build a graph with 2 parameter nodes, and Save and1138      # Restore nodes for them.1139      v0 = variable_op(10.0, name="v0")1140      v1 = variable_op(20.0, name="v1")1141      save = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1})1142      self.evaluate(variables.global_variables_initializer())1143      # Check that the parameter nodes have been initialized.1144      self.assertEqual(10.0, self.evaluate(v0))1145      self.assertEqual(20.0, self.evaluate(v1))1146      # Save the initialized values in the file at "save_path"1147      # Use a variable name map to set the saved tensor names1148      val = save.save(sess, save_path)1149      self.assertTrue(isinstance(val, six.string_types))1150      self.assertEqual(save_path, val)1151      # Verify that the original names are not in the Saved file1152      save = saver_module.Saver({"v0": v0, "v1": v1})1153      with self.assertRaisesOpError("not found in checkpoint"):1154        save.restore(sess, save_path)1155    # Verify that the mapped names are present in the Saved file and can be1156    # Restored using remapped names.1157    with self.test_session(graph=ops_lib.Graph()) as sess:1158      v0 = variable_op(-1.0, name="v0")1159      v1 = variable_op(-1.0, name="v1")1160      if 
context.in_graph_mode():1161        with self.assertRaisesOpError("uninitialized"):1162          self.evaluate(v0)1163        with self.assertRaisesOpError("uninitialized"):1164          self.evaluate(v1)1165      save = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1})1166      save.restore(sess, save_path)1167      # Check that the parameter nodes have been restored.1168      if context.in_graph_mode():1169        self.assertEqual(10.0, self.evaluate(v0))1170        self.assertEqual(20.0, self.evaluate(v1))1171    # Add a prefix to the node names in the current graph and Restore using1172    # remapped names.1173    with self.test_session(graph=ops_lib.Graph()) as sess:1174      v0 = variable_op(-1.0, name="restore_prefix/v0")1175      v1 = variable_op(-1.0, name="restore_prefix/v1")1176      if context.in_graph_mode():1177        with self.assertRaisesOpError("uninitialized"):1178          self.evaluate(v0)1179        with self.assertRaisesOpError("uninitialized"):1180          self.evaluate(v1)1181      # Restore the saved values in the parameter nodes.1182      save = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1})1183      save.restore(sess, save_path)1184      # Check that the parameter nodes have been restored.1185      self.assertEqual(10.0, self.evaluate(v0))1186      self.assertEqual(20.0, self.evaluate(v1))1187  @test_util.run_in_graph_and_eager_modes()1188  def testNonReshapeResourceVariable(self):1189    self._testNonReshape(resource_variable_ops.ResourceVariable)1190  def testNonReshapeVariable(self):1191    self._testNonReshape(variables.Variable)1192class LatestCheckpointWithRelativePaths(test.TestCase):1193  @staticmethod1194  @contextlib.contextmanager1195  def tempWorkingDir(temppath):1196    cwd = os.getcwd()1197    os.chdir(temppath)1198    try:1199      yield1200    finally:1201      os.chdir(cwd)1202  @staticmethod1203  @contextlib.contextmanager1204  def tempDir():1205    tempdir = tempfile.mkdtemp()1206    
try:1207      yield tempdir1208    finally:1209      shutil.rmtree(tempdir)1210  def testNameCollision(self):1211    # Make sure we have a clean directory to work in.1212    with self.tempDir() as tempdir:1213      # Jump to that directory until this test is done.1214      with self.tempWorkingDir(tempdir):1215        # Save training snapshots to a relative path.1216        traindir = "train/"1217        os.mkdir(traindir)1218        # Collides with the default name of the checkpoint state file.1219        filepath = os.path.join(traindir, "checkpoint")1220        with self.test_session() as sess:1221          unused_a = variables.Variable(0.0)  # So that Saver saves something.1222          variables.global_variables_initializer().run()1223          # Should fail.1224          saver = saver_module.Saver(sharded=False)1225          with self.assertRaisesRegexp(ValueError, "collides with"):1226            saver.save(sess, filepath)1227          # Succeeds: the file will be named "checkpoint-<step>".1228          saver.save(sess, filepath, global_step=1)1229          self.assertIsNotNone(saver_module.latest_checkpoint(traindir))1230          # Succeeds: the file will be named "checkpoint-<i>-of-<n>".1231          saver = saver_module.Saver(sharded=True)1232          saver.save(sess, filepath)1233          self.assertIsNotNone(saver_module.latest_checkpoint(traindir))1234          # Succeeds: the file will be named "checkpoint-<step>-<i>-of-<n>".1235          saver = saver_module.Saver(sharded=True)1236          saver.save(sess, filepath, global_step=1)1237          self.assertIsNotNone(saver_module.latest_checkpoint(traindir))1238  def testRelativePath(self):1239    # Make sure we have a clean directory to work in.1240    with self.tempDir() as tempdir:1241      # Jump to that directory until this test is done.1242      with self.tempWorkingDir(tempdir):1243        # Save training snapshots to a relative path.1244        traindir = "train/"1245        
os.mkdir(traindir)1246        filename = "snapshot"1247        filepath = os.path.join(traindir, filename)1248        with self.test_session() as sess:1249          # Build a simple graph.1250          v0 = variables.Variable(0.0)1251          inc = v0.assign_add(1.0)1252          save = saver_module.Saver({"v0": v0})1253          # Record a short training history.1254          variables.global_variables_initializer().run()1255          save.save(sess, filepath, global_step=0)1256          inc.eval()1257          save.save(sess, filepath, global_step=1)1258          inc.eval()1259          save.save(sess, filepath, global_step=2)1260        with self.test_session() as sess:1261          # Build a new graph with different initialization.1262          v0 = variables.Variable(-1.0)1263          # Create a new saver.1264          save = saver_module.Saver({"v0": v0})1265          variables.global_variables_initializer().run()1266          # Get the most recent checkpoint name from the training history file.1267          name = saver_module.latest_checkpoint(traindir)1268          self.assertIsNotNone(name)1269          # Restore "v0" from that checkpoint.1270          save.restore(sess, name)1271          self.assertEqual(v0.eval(), 2.0)1272class CheckpointStateTest(test.TestCase):1273  def _get_test_dir(self, dirname):1274    test_dir = os.path.join(self.get_temp_dir(), dirname)1275    gfile.MakeDirs(test_dir)1276    return test_dir1277  def testAbsPath(self):1278    save_dir = self._get_test_dir("abs_paths")1279    abs_path = os.path.join(save_dir, "model-0")1280    ckpt = saver_module.generate_checkpoint_state_proto(save_dir, abs_path)1281    self.assertEqual(ckpt.model_checkpoint_path, abs_path)1282    self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))1283    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)1284    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)1285  def testRelPath(self):1286    train_dir = "train"1287    model = 
os.path.join(train_dir, "model-0")1288    # model_checkpoint_path should have no "train" directory part.1289    new_rel_path = "model-0"1290    ckpt = saver_module.generate_checkpoint_state_proto(train_dir, model)1291    self.assertEqual(ckpt.model_checkpoint_path, new_rel_path)1292    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)1293    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path)1294  def testAllModelCheckpointPaths(self):1295    save_dir = self._get_test_dir("all_models_test")1296    abs_path = os.path.join(save_dir, "model-0")1297    for paths in [None, [], ["model-2"]]:1298      ckpt = saver_module.generate_checkpoint_state_proto(1299          save_dir, abs_path, all_model_checkpoint_paths=paths)1300      self.assertEqual(ckpt.model_checkpoint_path, abs_path)1301      self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))1302      self.assertEqual(1303          len(ckpt.all_model_checkpoint_paths), len(paths) if paths else 1)1304      self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)1305  def testUpdateCheckpointState(self):1306    save_dir = self._get_test_dir("update_checkpoint_state")1307    os.chdir(save_dir)1308    # Make a temporary train directory.1309    train_dir = "train"1310    os.mkdir(train_dir)1311    abs_path = os.path.join(save_dir, "model-0")1312    rel_path = os.path.join("train", "model-2")1313    saver_module.update_checkpoint_state(1314        train_dir, rel_path, all_model_checkpoint_paths=[abs_path, rel_path])1315    ckpt = saver_module.get_checkpoint_state(train_dir)1316    self.assertEqual(ckpt.model_checkpoint_path, rel_path)1317    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)1318    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path)1319    self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path)1320  def testUpdateCheckpointStateSaveRelativePaths(self):1321    save_dir = self._get_test_dir("update_checkpoint_state")1322    os.chdir(save_dir)1323    
abs_path2 = os.path.join(save_dir, "model-2")1324    rel_path2 = "model-2"1325    abs_path0 = os.path.join(save_dir, "model-0")1326    rel_path0 = "model-0"1327    saver_module._update_checkpoint_state(  # pylint: disable=protected-access1328        save_dir=save_dir,1329        model_checkpoint_path=abs_path2,1330        all_model_checkpoint_paths=[rel_path0, abs_path2],1331        save_relative_paths=True)1332    # File should contain relative paths.1333    file_content = file_io.read_file_to_string(1334        os.path.join(save_dir, "checkpoint"))1335    ckpt = CheckpointState()1336    text_format.Merge(file_content, ckpt)1337    self.assertEqual(ckpt.model_checkpoint_path, rel_path2)1338    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)1339    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path2)1340    self.assertEqual(ckpt.all_model_checkpoint_paths[0], rel_path0)1341    # get_checkpoint_state should return absolute paths.1342    ckpt = saver_module.get_checkpoint_state(save_dir)1343    self.assertEqual(ckpt.model_checkpoint_path, abs_path2)1344    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)1345    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path2)1346    self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path0)1347  def testCheckPointStateFailsWhenIncomplete(self):1348    save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete")1349    os.chdir(save_dir)1350    ckpt_path = os.path.join(save_dir, "checkpoint")1351    ckpt_file = open(ckpt_path, "w")1352    ckpt_file.write("")1353    ckpt_file.close()1354    with self.assertRaises(ValueError):1355      saver_module.get_checkpoint_state(save_dir)1356  def testCheckPointCompletesRelativePaths(self):1357    save_dir = self._get_test_dir("checkpoint_completes_relative_paths")1358    os.chdir(save_dir)1359    ckpt_path = os.path.join(save_dir, "checkpoint")1360    ckpt_file = open(ckpt_path, "w")1361    ckpt_file.write("""1362        
model_checkpoint_path: "./model.ckpt-687529"1363        all_model_checkpoint_paths: "./model.ckpt-687500"1364        all_model_checkpoint_paths: "./model.ckpt-687529"1365        """)1366    ckpt_file.close()1367    ckpt = saver_module.get_checkpoint_state(save_dir)1368    self.assertEqual(ckpt.model_checkpoint_path,1369                     os.path.join(save_dir, "./model.ckpt-687529"))1370    self.assertEqual(ckpt.all_model_checkpoint_paths[0],1371                     os.path.join(save_dir, "./model.ckpt-687500"))1372    self.assertEqual(ckpt.all_model_checkpoint_paths[1],1373                     os.path.join(save_dir, "./model.ckpt-687529"))1374class MetaGraphTest(test.TestCase):1375  def _get_test_dir(self, dirname):1376    test_dir = os.path.join(self.get_temp_dir(), dirname)1377    gfile.MakeDirs(test_dir)1378    return test_dir1379  def testAddCollectionDef(self):1380    test_dir = self._get_test_dir("good_collection")1381    filename = os.path.join(test_dir, "metafile")1382    with self.test_session():1383      # Creates a graph.1384      v0 = variables.Variable(1.0, name="v0")1385      control_flow_ops.cond(1386          math_ops.less(v0, 10), lambda: math_ops.add(v0, 1),1387          lambda: math_ops.subtract(v0, 1))1388      control_flow_ops.while_loop(lambda i: math_ops.less(i, 10),1389                                  lambda i: math_ops.add(i, 1), [v0])1390      var = variables.Variable(constant_op.constant(0, dtype=dtypes.int64))1391      count_up_to = var.count_up_to(3)1392      input_queue = data_flow_ops.FIFOQueue(1393          30, dtypes.float32, shared_name="collection_queue")1394      qr = queue_runner_impl.QueueRunner(input_queue, [count_up_to])1395      variables.global_variables_initializer()1396      # Creates a saver.1397      save = saver_module.Saver({"v0": v0})1398      # Adds a set of collections.1399      ops_lib.add_to_collection("int_collection", 3)1400      ops_lib.add_to_collection("float_collection", 3.5)1401      
ops_lib.add_to_collection("string_collection", "hello")1402      ops_lib.add_to_collection("variable_collection", v0)1403      # Add QueueRunners.1404      queue_runner_impl.add_queue_runner(qr)1405      # Adds user_defined proto in three formats: string, bytes and Any.1406      queue_runner = queue_runner_pb2.QueueRunnerDef(queue_name="test_queue")1407      ops_lib.add_to_collection("user_defined_string_collection",1408                                str(queue_runner))1409      ops_lib.add_to_collection("user_defined_bytes_collection",1410                                queue_runner.SerializeToString())1411      any_buf = Any()1412      any_buf.Pack(queue_runner)1413      ops_lib.add_to_collection("user_defined_any_collection", any_buf)1414      # Generates MetaGraphDef.1415      meta_graph_def = save.export_meta_graph(filename)1416      self.assertTrue(meta_graph_def.HasField("saver_def"))1417      self.assertTrue(meta_graph_def.HasField("graph_def"))1418      self.assertTrue(meta_graph_def.HasField("meta_info_def"))1419      self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_version, "")1420      self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_git_version,1421                          "")1422      collection_def = meta_graph_def.collection_def1423      self.assertEqual(len(collection_def), 12)1424    with ops_lib.Graph().as_default():1425      # Restores from MetaGraphDef.1426      new_saver = saver_module.import_meta_graph(filename)1427      # Generates a new MetaGraphDef.1428      new_meta_graph_def = new_saver.export_meta_graph()1429      # It should be the same as the original.1430    test_util.assert_meta_graph_protos_equal(1431        self, meta_graph_def, new_meta_graph_def)1432  def testAddCollectionDefFails(self):1433    with self.test_session():1434      # Creates a graph.1435      v0 = variables.Variable(10.0, name="v0")1436      # Creates a saver.1437      save = saver_module.Saver({"v0": v0})1438      # Generates 
MetaGraphDef.1439      meta_graph_def = meta_graph_pb2.MetaGraphDef()1440      # Verifies that collection with unsupported key will not be added.1441      ops_lib.add_to_collection(save, 3)1442      save._add_collection_def(meta_graph_def, save)1443      self.assertEqual(len(meta_graph_def.collection_def), 0)1444      # Verifies that collection where item type does not match expected1445      # type will not be added.1446      ops_lib.add_to_collection("int_collection", 3)1447      ops_lib.add_to_collection("int_collection", 3.5)1448      save._add_collection_def(meta_graph_def, "int_collection")1449      self.assertEqual(len(meta_graph_def.collection_def), 0)1450  def _testMultiSaverCollectionSave(self, test_dir):1451    filename = os.path.join(test_dir, "metafile")1452    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")1453    saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")1454    with self.test_session(graph=ops_lib.Graph()) as sess:1455      # Creates a graph.1456      v0 = variables.Variable([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0")1457      v1 = variables.Variable(11.0, name="v1")1458      # Creates 2 savers.1459      saver0 = saver_module.Saver({"v0": v0}, name="saver0")1460      saver1 = saver_module.Saver({"v1": v1}, name="saver1")1461      ops_lib.add_to_collection("savers", saver0)1462      ops_lib.add_to_collection("savers", saver1)1463      variables.global_variables_initializer().run()1464      # Saves to different checkpoints.1465      saver0.save(sess, saver0_ckpt)1466      saver1.save(sess, saver1_ckpt)1467      # Generates MetaGraphDef.1468      meta_graph_def = saver_module.export_meta_graph(filename)1469      meta_graph_def0 = saver0.export_meta_graph()1470      meta_graph_def1 = saver1.export_meta_graph()1471      # Verifies that there is no saver_def in meta_graph_def.1472      self.assertFalse(meta_graph_def.HasField("saver_def"))1473      # Verifies that there is saver_def in meta_graph_def0 and 1.1474      
self.assertTrue(meta_graph_def0.HasField("saver_def"))1475      self.assertTrue(meta_graph_def1.HasField("saver_def"))1476      # Verifies SAVERS is saved as bytes_list for meta_graph_def.1477      collection_def = meta_graph_def.collection_def["savers"]1478      kind = collection_def.WhichOneof("kind")1479      self.assertEqual(kind, "bytes_list")1480      # Verifies that there are 2 entries in SAVERS collection.1481      savers = getattr(collection_def, kind)1482      self.assertEqual(2, len(savers.value))1483      # Verifies SAVERS collection is saved as bytes_list for meta_graph_def0.1484      collection_def = meta_graph_def0.collection_def["savers"]1485      kind = collection_def.WhichOneof("kind")1486      self.assertEqual(kind, "bytes_list")1487      # Verifies that there are 2 entries in SAVERS collection.1488      savers = getattr(collection_def, kind)1489      self.assertEqual(2, len(savers.value))1490  def _testMultiSaverCollectionRestore(self, test_dir):1491    filename = os.path.join(test_dir, "metafile")1492    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")1493    saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")1494    with self.test_session(graph=ops_lib.Graph()) as sess:1495      # Imports from meta_graph.1496      saver_module.import_meta_graph(filename)1497      # Retrieves SAVERS collection. Verifies there are 2 entries.1498      savers = ops_lib.get_collection("savers")1499      self.assertEqual(2, len(savers))1500      # Retrieves saver0. 
Verifies that new_saver0 can restore v0, but not v1.1501      new_saver0 = savers[0]1502      new_saver0.restore(sess, saver0_ckpt)1503      v0 = sess.graph.get_tensor_by_name("v0:0")1504      v1 = sess.graph.get_tensor_by_name("v1:0")1505      self.assertAllEqual([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], v0.eval())1506      self.assertEqual([3, 2], v0.get_shape())1507      self.assertEqual([], v1.get_shape())1508      with self.assertRaisesWithPredicateMatch(1509          errors_impl.OpError, lambda e: "uninitialized value v1" in e.message):1510        sess.run(v1)1511      # Retrieves saver1. Verifies that new_saver1 can restore v1.1512      new_saver1 = savers[1]1513      new_saver1.restore(sess, saver1_ckpt)1514      v1 = sess.graph.get_tensor_by_name("v1:0")1515      self.assertEqual(11.0, v1.eval())1516  def testMultiSaverCollection(self):1517    test_dir = self._get_test_dir("saver_collection")1518    self._testMultiSaverCollectionSave(test_dir)1519    self._testMultiSaverCollectionRestore(test_dir)1520  def testClearExtraneousSavers(self):1521    test_dir = self._get_test_dir("clear_extraneous_savers")1522    filename = os.path.join(test_dir, "metafile")1523    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")1524    saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")1525    with self.test_session(graph=ops_lib.Graph()) as sess:1526      # Creates a graph.1527      v0 = variables.Variable([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0")1528      v1 = variables.Variable(11.0, name="v1")1529      # Creates 2 savers.1530      saver0 = saver_module.Saver({"v0": v0}, name="saver0")1531      saver1 = saver_module.Saver({"v1": v1}, name="saver1")1532      ops_lib.add_to_collection("savers", saver0)1533      ops_lib.add_to_collection("savers", saver1)1534      variables.global_variables_initializer().run()1535      # Saves to different checkpoints.1536      saver0.save(sess, saver0_ckpt)1537      saver1.save(sess, saver1_ckpt)1538      # Generates MetaGraphDef.1539   
   meta_graph_def = saver_module.export_meta_graph(filename)1540      meta_graph_def0 = saver0.export_meta_graph()1541      meta_graph_def1 = saver1.export_meta_graph(clear_extraneous_savers=True)1542      # Verifies that there is no saver_def in meta_graph_def.1543      self.assertFalse(meta_graph_def.HasField("saver_def"))1544      # Verifies that there is saver_def in meta_graph_def0 and 1.1545      self.assertTrue(meta_graph_def0.HasField("saver_def"))1546      self.assertTrue(meta_graph_def1.HasField("saver_def"))1547      # Verifies SAVERS is saved as bytes_list for meta_graph_def.1548      collection_def = meta_graph_def.collection_def["savers"]1549      kind = collection_def.WhichOneof("kind")1550      self.assertEqual(kind, "bytes_list")1551      # Verifies that there are 2 entries in SAVERS collection.1552      savers = getattr(collection_def, kind)1553      self.assertEqual(2, len(savers.value))1554      # Verifies SAVERS collection is saved as bytes_list for meta_graph_def1.1555      collection_def = meta_graph_def1.collection_def["savers"]1556      kind = collection_def.WhichOneof("kind")1557      self.assertEqual(kind, "bytes_list")1558      # Verifies that there is 1 entry in SAVERS collection.1559      savers = getattr(collection_def, kind)1560      self.assertEqual(1, len(savers.value))1561      # Verifies that saver0 graph nodes are omitted from the saver1 export1562      self.assertEqual(29, len(meta_graph_def0.graph_def.node))1563      self.assertEqual(19, len(meta_graph_def1.graph_def.node))1564  def testBinaryAndTextFormat(self):1565    test_dir = self._get_test_dir("binary_and_text")1566    filename = os.path.join(test_dir, "metafile")1567    with self.test_session(graph=ops_lib.Graph()):1568      # Creates a graph.1569      variables.Variable(10.0, name="v0")1570      # Exports the graph as binary format.1571      saver_module.export_meta_graph(filename, as_text=False)1572    with self.test_session(graph=ops_lib.Graph()):1573      # Imports 
the binary format graph.1574      saver = saver_module.import_meta_graph(filename)1575      self.assertIsNotNone(saver)1576      # Exports the graph as text format.1577      saver.export_meta_graph(filename, as_text=True)1578    with self.test_session(graph=ops_lib.Graph()):1579      # Imports the text format graph.1580      saver_module.import_meta_graph(filename)1581      # Writes wrong contents to the file.1582      graph_io.write_graph(saver.as_saver_def(),1583                           os.path.dirname(filename),1584                           os.path.basename(filename))1585    with self.test_session(graph=ops_lib.Graph()):1586      # Import should fail.1587      with self.assertRaisesWithPredicateMatch(IOError,1588                                               lambda e: "Cannot parse file"):1589        saver_module.import_meta_graph(filename)1590      # Deletes the file1591      gfile.Remove(filename)1592      with self.assertRaisesWithPredicateMatch(IOError,1593                                               lambda e: "does not exist"):1594        saver_module.import_meta_graph(filename)1595  def testSliceVariable(self):1596    test_dir = self._get_test_dir("slice_saver")1597    filename = os.path.join(test_dir, "metafile")1598    with self.test_session():1599      v1 = variables.Variable([20.0], name="v1")1600      v2 = variables.Variable([20.0], name="v2")1601      v2._set_save_slice_info(1602          variables.Variable.SaveSliceInfo("v1", [1], [0], [1]))1603      # The names are different and will work.1604      slice_saver = saver_module.Saver({"first": v1, "second": v2})1605      variables.global_variables_initializer().run()1606      # Exports to meta_graph1607      meta_graph_def = slice_saver.export_meta_graph(filename)1608    with ops_lib.Graph().as_default():1609      # Restores from MetaGraphDef.1610      new_saver = saver_module.import_meta_graph(filename)1611      self.assertIsNotNone(new_saver)1612      # Generates a new MetaGraphDef.1613      
new_meta_graph_def = new_saver.export_meta_graph()1614      # It should be the same as the original.1615      self.assertProtoEquals(meta_graph_def, new_meta_graph_def)1616  def _testGraphExtensionSave(self, test_dir):1617    filename = os.path.join(test_dir, "metafile")1618    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")1619    # Creates an inference graph.1620    # Hidden 11621    images = constant_op.constant(1.2, dtypes.float32, shape=[100, 28])1622    with ops_lib.name_scope("hidden1"):1623      weights = variables.Variable(1624          random_ops.truncated_normal(1625              [28, 128], stddev=1.0 / math.sqrt(float(28))),1626          name="weights")1627      # The use of control_flow_ops.cond here is purely for adding test coverage1628      # the save and restore of control flow context (which doesn't make any1629      # sense here from a machine learning perspective).  The typical biases is1630      # a simple Variable without the conditions.1631      biases = variables.Variable(1632          control_flow_ops.cond(1633              math_ops.less(random.random(), 0.5),1634              lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])),1635          name="biases")1636      hidden1 = nn_ops.relu(math_ops.matmul(images, weights) + biases)1637    # Hidden 21638    with ops_lib.name_scope("hidden2"):1639      weights = variables.Variable(1640          random_ops.truncated_normal(1641              [128, 32], stddev=1.0 / math.sqrt(float(128))),1642          name="weights")1643      # The use of control_flow_ops.while_loop here is purely for adding test1644      # coverage the save and restore of control flow context (which doesn't1645      # make any sense here from a machine learning perspective).  
The typical1646      # biases is a simple Variable without the conditions.1647      def loop_cond(it, _):1648        return it < 21649      def loop_body(it, biases):1650        biases += constant_op.constant(0.1, shape=[32])1651        return it + 1, biases1652      _, biases = control_flow_ops.while_loop(1653          loop_cond, loop_body,1654          [constant_op.constant(0), variables.Variable(array_ops.zeros([32]))])1655      hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights) + biases)1656    # Linear1657    with ops_lib.name_scope("softmax_linear"):1658      weights = variables.Variable(1659          random_ops.truncated_normal(1660              [32, 10], stddev=1.0 / math.sqrt(float(32))),1661          name="weights")1662      biases = variables.Variable(array_ops.zeros([10]), name="biases")1663      logits = math_ops.matmul(hidden2, weights) + biases1664      ops_lib.add_to_collection("logits", logits)1665    init_all_op = variables.global_variables_initializer()1666    with self.test_session() as sess:1667      # Initializes all the variables.1668      sess.run(init_all_op)1669      # Runs to logit.1670      sess.run(logits)1671      # Creates a saver.1672      saver0 = saver_module.Saver()1673      saver0.save(sess, saver0_ckpt)1674      # Generates MetaGraphDef.1675      saver0.export_meta_graph(filename)1676  def _testGraphExtensionRestore(self, test_dir):1677    filename = os.path.join(test_dir, "metafile")1678    train_filename = os.path.join(test_dir, "train_metafile")1679    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")1680    with self.test_session(graph=ops_lib.Graph()) as sess:1681      # Restores from MetaGraphDef.1682      new_saver = saver_module.import_meta_graph(filename)1683      # Generates a new MetaGraphDef.1684      new_saver.export_meta_graph()1685      # Restores from checkpoint.1686      new_saver.restore(sess, saver0_ckpt)1687      # Adds loss and train.1688      labels = constant_op.constant(0, dtypes.int32, 
shape=[100], name="labels")1689      batch_size = array_ops.size(labels)1690      labels = array_ops.expand_dims(labels, 1)1691      indices = array_ops.expand_dims(math_ops.range(0, batch_size), 1)1692      concated = array_ops.concat([indices, labels], 1)1693      onehot_labels = sparse_ops.sparse_to_dense(1694          concated, array_ops.stack([batch_size, 10]), 1.0, 0.0)1695      logits = ops_lib.get_collection("logits")[0]1696      cross_entropy = nn_ops.softmax_cross_entropy_with_logits(1697          labels=onehot_labels, logits=logits, name="xentropy")1698      loss = math_ops.reduce_mean(cross_entropy, name="xentropy_mean")1699      summary.scalar("loss", loss)1700      # Creates the gradient descent optimizer with the given learning rate.1701      optimizer = gradient_descent.GradientDescentOptimizer(0.01)1702      # Runs train_op.1703      train_op = optimizer.minimize(loss)1704      ops_lib.add_to_collection("train_op", train_op)1705      # Runs train_op.1706      sess.run(train_op)1707      # Generates MetaGraphDef.1708      saver_module.export_meta_graph(train_filename)1709  def _testRestoreFromTrainGraphWithControlContext(self, test_dir):1710    train_filename = os.path.join(test_dir, "train_metafile")1711    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")1712    with self.test_session(graph=ops_lib.Graph()) as sess:1713      # Restores from MetaGraphDef.1714      new_saver = saver_module.import_meta_graph(train_filename)1715      # Restores from checkpoint.1716      new_saver.restore(sess, saver0_ckpt)1717      train_op = ops_lib.get_collection("train_op")[0]1718      sess.run(train_op)1719  def testGraphExtension(self):1720    test_dir = self._get_test_dir("graph_extension")1721    self._testGraphExtensionSave(test_dir)1722    self._testGraphExtensionRestore(test_dir)1723    self._testRestoreFromTrainGraphWithControlContext(test_dir)1724  def testStrippedOpListDef(self):1725    with self.test_session():1726      # Creates a graph.1727      v0 = 
variables.Variable(0.0)1728      var = variables.Variable(10.0)1729      math_ops.add(v0, var)1730      @function.Defun(dtypes.float32)1731      def minus_one(x):1732        return x - 11733      minus_one(array_ops.identity(v0))1734      save = saver_module.Saver({"v0": v0})1735      variables.global_variables_initializer()1736      # Generates MetaGraphDef.1737      meta_graph_def = save.export_meta_graph()1738      ops = [o.name for o in meta_graph_def.meta_info_def.stripped_op_list.op]1739      if save._write_version is saver_pb2.SaverDef.V1:1740        self.assertEqual(ops, [1741            "Add", "Assign", "Const", "Identity", "NoOp", "RestoreV2",1742            "SaveSlices", "Sub", "VariableV2"1743        ])1744      else:1745        self.assertEqual(ops, [1746            "Add", "Assign", "Const", "Identity", "NoOp", "RestoreV2", "SaveV2",1747            "Sub", "VariableV2"1748        ])1749      # Test calling stripped_op_list_for_graph directly1750      op_list = meta_graph.stripped_op_list_for_graph(meta_graph_def.graph_def)1751      self.assertEqual(ops, [o.name for o in op_list.op])1752      for o in op_list.op:1753        self.assertEqual(o.summary, "")1754        self.assertEqual(o.description, "")1755  def testImportIntoNamescope(self):1756    # Test that we can import a meta graph into a namescope.1757    test_dir = self._get_test_dir("import_into_namescope")1758    filename = os.path.join(test_dir, "ckpt")1759    image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")1760    label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")1761    with session.Session() as sess:1762      weights = variables.Variable(1763          random_ops.random_uniform([784, 10]), name="weights")1764      bias = variables.Variable(array_ops.zeros([10]), name="bias")1765      logit = nn_ops.relu(math_ops.matmul(image, weights) + bias, name="logits")1766      nn_ops.softmax(logit, name="prediction")1767      cost = 
nn_ops.softmax_cross_entropy_with_logits(labels=label,1768                                                      logits=logit, name="cost")1769      adam.AdamOptimizer().minimize(cost, name="optimize")1770      saver = saver_module.Saver()1771      sess.run(variables.global_variables_initializer())1772      saver.save(sess, filename)1773    graph = ops_lib.Graph()1774    with session.Session(graph=graph) as sess:1775      new_saver = saver_module.import_meta_graph(1776          filename + ".meta", graph=graph, import_scope="new_model")1777      new_saver.restore(sess, filename)1778      sess.run(["new_model/optimize"], {1779          "new_model/image:0": np.random.random([1, 784]),1780          "new_model/label:0": np.random.randint(1781              10, size=[1, 10])1782      })1783  def testClearDevicesOnImport(self):1784    # Test that we import a graph without its devices and run successfully.1785    with ops_lib.Graph().as_default():1786      with ops_lib.device("/job:ps/replica:0/task:0/device:GPU:0"):1787        image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")1788        label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")1789        weights = variables.Variable(1790            random_ops.random_uniform([784, 10]), name="weights")1791        bias = variables.Variable(array_ops.zeros([10]), name="bias")1792        logit = nn_ops.relu(math_ops.matmul(image, weights) + bias)1793        nn_ops.softmax(logit, name="prediction")1794        cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,1795                                                        logits=logit)1796        adam.AdamOptimizer().minimize(cost, name="optimize")1797      meta_graph_def = saver_module.export_meta_graph()1798    with session.Session(graph=ops_lib.Graph()) as sess:1799      saver_module.import_meta_graph(1800          meta_graph_def, clear_devices=False, import_scope="new_model")1801      # Device refers to GPU, which is not available 
here.1802      with self.assertRaises(errors_impl.InvalidArgumentError):1803        sess.run(variables.global_variables_initializer())1804    with session.Session(graph=ops_lib.Graph()) as sess:1805      saver_module.import_meta_graph(1806          meta_graph_def, clear_devices=True, import_scope="new_model")1807      sess.run(variables.global_variables_initializer())1808      sess.run(["new_model/optimize"], {1809          "new_model/image:0": np.random.random([1, 784]),1810          "new_model/label:0": np.random.randint(1811              10, size=[1, 10])1812      })1813  def testClearDevicesOnExport(self):1814    # Test that we export a graph without its devices and run successfully.1815    with ops_lib.Graph().as_default():1816      with ops_lib.device("/job:ps/replica:0/task:0/device:GPU:0"):1817        image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")1818        label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")1819        weights = variables.Variable(1820            random_ops.random_uniform([784, 10]), name="weights")1821        bias = variables.Variable(array_ops.zeros([10]), name="bias")1822        logit = nn_ops.relu(math_ops.matmul(image, weights) + bias)1823        nn_ops.softmax(logit, name="prediction")1824        cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,1825                                                        logits=logit)1826        adam.AdamOptimizer().minimize(cost, name="optimize")1827      meta_graph_def = saver_module.export_meta_graph(clear_devices=True)1828      graph_io.write_graph(meta_graph_def, self.get_temp_dir(),1829                           "meta_graph.pbtxt")1830    with session.Session(graph=ops_lib.Graph()) as sess:1831      saver_module.import_meta_graph(meta_graph_def, import_scope="new_model")1832      sess.run(variables.global_variables_initializer())1833      sess.run(["new_model/optimize"], {1834          "new_model/image:0": np.random.random([1, 784]),1835  
        "new_model/label:0": np.random.randint(1836              10, size=[1, 10])1837      })1838class CheckpointReaderTest(test.TestCase):1839  _WRITE_VERSION = saver_pb2.SaverDef.V11840  def testDebugString(self):1841    # Builds a graph.1842    v0 = variables.Variable(1843        [[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")1844    v1 = variables.Variable(1845        [[[1], [2]], [[3], [4]], [[5], [6]]], dtype=dtypes.float32, name="v1")1846    init_all_op = variables.global_variables_initializer()1847    save = saver_module.Saver(1848        {1849            "v0": v0,1850            "v1": v11851        }, write_version=self._WRITE_VERSION)1852    save_path = os.path.join(self.get_temp_dir(),1853                             "ckpt_for_debug_string" + str(self._WRITE_VERSION))1854    with self.test_session() as sess:1855      sess.run(init_all_op)1856      # Saves a checkpoint.1857      save.save(sess, save_path)1858      # Creates a reader.1859      reader = pywrap_tensorflow.NewCheckpointReader(save_path)1860      # Verifies that the tensors exist.1861      self.assertTrue(reader.has_tensor("v0"))1862      self.assertTrue(reader.has_tensor("v1"))1863      debug_string = reader.debug_string()1864      # Verifies that debug string contains the right strings.1865      self.assertTrue(compat.as_bytes("v0 (DT_FLOAT) [2,3]") in debug_string)1866      self.assertTrue(compat.as_bytes("v1 (DT_FLOAT) [3,2,1]") in debug_string)1867      # Verifies get_variable_to_shape_map() returns the correct information.1868      var_map = reader.get_variable_to_shape_map()1869      self.assertEqual([2, 3], var_map["v0"])1870      self.assertEqual([3, 2, 1], var_map["v1"])1871      # Verifies get_tensor() returns the tensor value.1872      v0_tensor = reader.get_tensor("v0")1873      v1_tensor = reader.get_tensor("v1")1874      self.assertAllEqual(v0.eval(), v0_tensor)1875      self.assertAllEqual(v1.eval(), v1_tensor)1876      # Verifies get_tensor() fails for non-existent 
tensors.1877      with self.assertRaisesRegexp(errors.NotFoundError,1878                                   "v3 not found in checkpoint"):1879        reader.get_tensor("v3")1880  def testNonexistentPath(self):1881    with self.assertRaisesRegexp(errors.NotFoundError,1882                                 "Unsuccessful TensorSliceReader"):1883      pywrap_tensorflow.NewCheckpointReader("non-existent")1884class CheckpointReaderForV2Test(CheckpointReaderTest):1885  _WRITE_VERSION = saver_pb2.SaverDef.V21886class WriteGraphTest(test.TestCase):1887  def _get_test_dir(self, dirname):1888    test_dir = os.path.join(self.get_temp_dir(), dirname)1889    gfile.MakeDirs(test_dir)1890    return test_dir1891  def testWriteGraph(self):1892    test_dir = self._get_test_dir("write_graph_dir")1893    variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")1894    path = graph_io.write_graph(ops_lib.get_default_graph(),1895                                os.path.join(test_dir, "l1"), "graph.pbtxt")1896    truth = os.path.join(test_dir, "l1", "graph.pbtxt")1897    self.assertEqual(path, truth)1898    self.assertTrue(os.path.exists(path))1899  def testRecursiveCreate(self):1900    test_dir = self._get_test_dir("deep_dir")1901    variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")1902    path = graph_io.write_graph(ops_lib.get_default_graph().as_graph_def(),1903                                os.path.join(test_dir, "l1", "l2", "l3"),1904                                "graph.pbtxt")1905    truth = os.path.join(test_dir, "l1", "l2", "l3", "graph.pbtxt")1906    self.assertEqual(path, truth)1907    self.assertTrue(os.path.exists(path))1908class SaverUtilsTest(test.TestCase):1909  def setUp(self):1910    self._base_dir = os.path.join(self.get_temp_dir(), "saver_utils_test")1911    gfile.MakeDirs(self._base_dir)1912  def tearDown(self):1913    gfile.DeleteRecursively(self._base_dir)1914  def testCheckpointExists(self):1915    for sharded in (False, 
True):1916      for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):1917        with self.test_session(graph=ops_lib.Graph()) as sess:1918          unused_v = variables.Variable(1.0, name="v")1919          variables.global_variables_initializer().run()1920          saver = saver_module.Saver(sharded=sharded, write_version=version)1921          path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))1922          self.assertFalse(1923              saver_module.checkpoint_exists(path))  # Not saved yet.1924          ckpt_prefix = saver.save(sess, path)1925          self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix))1926          ckpt_prefix = saver_module.latest_checkpoint(self._base_dir)1927          self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix))1928  def testGetCheckpointMtimes(self):1929    prefixes = []1930    for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):1931      with self.test_session(graph=ops_lib.Graph()) as sess:1932        unused_v = variables.Variable(1.0, name="v")1933        variables.global_variables_initializer().run()1934        saver = saver_module.Saver(write_version=version)1935        prefixes.append(1936            saver.save(sess, os.path.join(self._base_dir, str(version))))1937    mtimes = saver_module.get_checkpoint_mtimes(prefixes)1938    self.assertEqual(2, len(mtimes))1939    self.assertTrue(mtimes[1] >= mtimes[0])1940class ScopedGraphTest(test.TestCase):1941  def _get_test_dir(self, dirname):1942    test_dir = os.path.join(self.get_temp_dir(), dirname)1943    gfile.MakeDirs(test_dir)1944    return test_dir1945  def _testScopedSave(self, test_dir, exported_filename, ckpt_filename):1946    graph = ops_lib.Graph()1947    with graph.as_default():1948      # Creates an inference graph.1949      # Hidden 11950      images = constant_op.constant(1951          1.2, dtypes.float32, shape=[100, 28], name="images")1952      with ops_lib.name_scope("hidden1"):1953        weights1 = 
variables.Variable(1954            random_ops.truncated_normal(1955                [28, 128], stddev=1.0 / math.sqrt(float(28))),1956            name="weights")1957        # The use of control_flow_ops.cond here is purely for adding test1958        # coverage the save and restore of control flow context (which doesn't1959        # make any sense here from a machine learning perspective).  The typical1960        # biases is a simple Variable without the conditions.1961        biases1 = variables.Variable(1962            control_flow_ops.cond(1963                math_ops.less(random.random(), 0.5),1964                lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])),1965            name="biases")1966        hidden1 = nn_ops.relu(math_ops.matmul(images, weights1) + biases1)1967      # Hidden 21968      with ops_lib.name_scope("hidden2"):1969        weights2 = variables.Variable(1970            random_ops.truncated_normal(1971                [128, 32], stddev=1.0 / math.sqrt(float(128))),1972            name="weights")1973        # The use of control_flow_ops.while_loop here is purely for adding test1974        # coverage the save and restore of control flow context (which doesn't1975        # make any sense here from a machine learning perspective).  
The typical1976        # biases is a simple Variable without the conditions.1977        def loop_cond(it, _):1978          return it < 21979        def loop_body(it, biases2):1980          biases2 += constant_op.constant(0.1, shape=[32])1981          return it + 1, biases21982        _, biases2 = control_flow_ops.while_loop(loop_cond, loop_body, [1983            constant_op.constant(0), variables.Variable(array_ops.zeros([32]))1984        ])1985        hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights2) + biases2)1986      # Linear1987      with ops_lib.name_scope("softmax_linear"):1988        weights3 = variables.Variable(1989            random_ops.truncated_normal(1990                [32, 10], stddev=1.0 / math.sqrt(float(32))),1991            name="weights")1992        biases3 = variables.Variable(array_ops.zeros([10]), name="biases")1993        logits = math_ops.matmul(hidden2, weights3) + biases31994        ops_lib.add_to_collection("logits", logits)1995        # Adds user_defined proto in three formats: string, bytes and Any.1996        # Any proto should just pass through.1997        queue_runner = queue_runner_pb2.QueueRunnerDef(queue_name="test_queue")1998        ops_lib.add_to_collection("user_defined_string_collection",1999                                  str(queue_runner))2000        ops_lib.add_to_collection("user_defined_bytes_collection",2001                                  queue_runner.SerializeToString())2002        any_buf = Any()2003        any_buf.Pack(queue_runner)2004        ops_lib.add_to_collection("user_defined_any_collection", any_buf)2005      _, var_list = meta_graph.export_scoped_meta_graph(2006          filename=os.path.join(test_dir, exported_filename),2007          graph=ops_lib.get_default_graph(),2008          export_scope="hidden1")2009      self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))2010    with self.test_session(graph=graph) as sess:2011      sess.run(variables.global_variables_initializer())2012  
    saver = saver_module.Saver(var_list=var_list, max_to_keep=1)2013      saver.save(sess, os.path.join(test_dir, ckpt_filename), write_state=False)2014  def _testScopedRestore(self, test_dir, exported_filename,2015                         new_exported_filename, ckpt_filename):2016    graph = ops_lib.Graph()2017    # Create all the missing inputs.2018    with graph.as_default():2019      new_image = constant_op.constant(2020          1.2, dtypes.float32, shape=[100, 28], name="images")2021    var_list = meta_graph.import_scoped_meta_graph(2022        os.path.join(test_dir, exported_filename),2023        graph=graph,2024        input_map={"$unbound_inputs_images": new_image},2025        import_scope="new_hidden1")2026    self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))2027    hidden1 = graph.as_graph_element("new_hidden1/Relu:0")2028    weights1 = graph.as_graph_element("new_hidden1/weights:0")2029    biases1 = graph.as_graph_element("new_hidden1/biases:0")2030    with graph.as_default():2031      # Hidden 22032      with ops_lib.name_scope("hidden2"):2033        weights = variables.Variable(2034            random_ops.truncated_normal(2035                [128, 32], stddev=1.0 / math.sqrt(float(128))),2036            name="weights")2037        # The use of control_flow_ops.while_loop here is purely for adding test2038        # coverage the save and restore of control flow context (which doesn't2039        # make any sense here from a machine learning perspective).  
The typical2040        # biases is a simple Variable without the conditions.2041        def loop_cond(it, _):2042          return it < 22043        def loop_body(it, biases):2044          biases += constant_op.constant(0.1, shape=[32])2045          return it + 1, biases2046        _, biases = control_flow_ops.while_loop(loop_cond, loop_body, [2047            constant_op.constant(0), variables.Variable(array_ops.zeros([32]))2048        ])2049        hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights) + biases)2050      # Linear2051      with ops_lib.name_scope("softmax_linear"):2052        weights = variables.Variable(2053            random_ops.truncated_normal(2054                [32, 10], stddev=1.0 / math.sqrt(float(32))),2055            name="weights")2056        biases = variables.Variable(array_ops.zeros([10]), name="biases")2057        logits = math_ops.matmul(hidden2, weights) + biases2058        ops_lib.add_to_collection("logits", logits)2059      # The rest of the variables.2060      rest_variables = list(2061          set(variables.global_variables()) - set(var_list.keys()))2062      init_rest_op = variables.initialize_variables(rest_variables)2063    with self.test_session(graph=graph) as sess:2064      saver = saver_module.Saver(var_list=var_list, max_to_keep=1)2065      saver.restore(sess, os.path.join(test_dir, ckpt_filename))2066      # Verify that we have restored weights1 and biases1.2067      sess.run([weights1, biases1])2068      # Initialize the rest of the variables and run logits.2069      sess.run(init_rest_op)2070      sess.run(logits)2071  # Verifies that we can save the subgraph under "hidden1" and restore it2072  # into "new_hidden1" in the new graph.2073  def testScopedSaveAndRestore(self):2074    test_dir = self._get_test_dir("scoped_export_import")2075    ckpt_filename = "ckpt"2076    self._testScopedSave(test_dir, "exported_hidden1.pbtxt", ckpt_filename)2077    self._testScopedRestore(test_dir, "exported_hidden1.pbtxt",2078        
                    "exported_new_hidden1.pbtxt", ckpt_filename)2079  # Verifies that we can copy the subgraph under "hidden1" and copy it2080  # to different name scope in the same graph or different graph.2081  def testCopyScopedGraph(self):2082    test_dir = self._get_test_dir("scoped_copy")2083    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")2084    graph1 = ops_lib.Graph()2085    with graph1.as_default():2086      with ops_lib.name_scope("hidden1"):2087        images = constant_op.constant(2088            1.0, dtypes.float32, shape=[3, 2], name="images")2089        weights1 = variables.Variable(2090            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")2091        biases1 = variables.Variable([0.1] * 3, name="biases")2092        nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")2093    # Run the graph and save scoped checkpoint.2094    with self.test_session(graph=graph1) as sess:2095      sess.run(variables.global_variables_initializer())2096      _, var_list_1 = meta_graph.export_scoped_meta_graph(2097          export_scope="hidden1")2098      saver = saver_module.Saver(var_list=var_list_1, max_to_keep=1)2099      saver.save(sess, saver0_ckpt, write_state=False)2100    expected = np.reshape([[5.0999999, 7.0999999, 9.10000038] * 3], (3, 3))2101    # Verifies copy to the same graph with the same name fails.2102    with graph1.as_default():2103      with self.assertRaisesWithPredicateMatch(2104          ValueError, lambda e: "need to be different" in str(e)):2105        meta_graph.copy_scoped_meta_graph(2106            from_scope="hidden1", to_scope="hidden1")2107    # Verifies copy to the same graph.2108    with graph1.as_default():2109      var_list_2 = meta_graph.copy_scoped_meta_graph(2110          from_scope="hidden1", to_scope="hidden2")2111    with self.test_session(graph=graph1) as sess:2112      saver1 = saver_module.Saver(var_list=var_list_1, max_to_keep=1)2113      saver1.restore(sess, saver0_ckpt)2114      saver2 = 
saver_module.Saver(var_list=var_list_2, max_to_keep=1)2115      saver2.restore(sess, saver0_ckpt)2116      self.assertAllClose(expected, sess.run("hidden1/relu:0"))2117      self.assertAllClose(expected, sess.run("hidden2/relu:0"))2118    # Verifies copy to differen graph.2119    graph2 = ops_lib.Graph()2120    new_var_list_1 = meta_graph.copy_scoped_meta_graph(2121        from_scope="hidden1",2122        to_scope="new_hidden1",2123        from_graph=graph1,2124        to_graph=graph2)2125    with self.test_session(graph=graph2) as sess:2126      saver3 = saver_module.Saver(var_list=new_var_list_1, max_to_keep=1)2127      saver3.restore(sess, saver0_ckpt)2128      self.assertAllClose(expected, sess.run("new_hidden1/relu:0"))2129  def testExportGraphDefWithScope(self):2130    test_dir = self._get_test_dir("export_graph_def")2131    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")2132    graph1 = ops_lib.Graph()2133    with graph1.as_default():2134      with ops_lib.name_scope("hidden1"):2135        images = constant_op.constant(2136            1.0, dtypes.float32, shape=[3, 2], name="images")2137        weights1 = variables.Variable(2138            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")2139        biases1 = variables.Variable([0.1] * 3, name="biases")2140        nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")2141    # Run the graph and save scoped checkpoint.2142    with self.test_session(graph=graph1) as sess:2143      sess.run(variables.global_variables_initializer())2144      _, var_list_1 = meta_graph.export_scoped_meta_graph(2145          graph_def=graph1.as_graph_def(), export_scope="hidden1")2146      saver = saver_module.Saver(var_list=var_list_1, max_to_keep=1)2147      saver.save(sess, saver0_ckpt, write_state=False)2148    expected = np.reshape([[5.0999999, 7.0999999, 9.10000038] * 3], (3, 3))2149    # Verifies that we can run successfully after restoring.2150    graph2 = ops_lib.Graph()2151    new_var_list_1 = 
meta_graph.copy_scoped_meta_graph(2152        from_scope="hidden1",2153        to_scope="new_hidden1",2154        from_graph=graph1,2155        to_graph=graph2)2156    with self.test_session(graph=graph2) as sess:2157      saver3 = saver_module.Saver(var_list=new_var_list_1, max_to_keep=1)2158      saver3.restore(sess, saver0_ckpt)2159      self.assertAllClose(expected, sess.run("new_hidden1/relu:0"))2160  def testSerializeSaverWithScope(self):2161    test_dir = self._get_test_dir("export_graph_def")2162    saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")2163    saver2_ckpt = os.path.join(test_dir, "saver2.ckpt")2164    graph = ops_lib.Graph()2165    with graph.as_default():2166      with ops_lib.name_scope("hidden1"):2167        variable1 = variables.Variable([1.0], name="variable1")2168        saver1 = saver_module.Saver(var_list=[variable1])2169        graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver1)2170      with ops_lib.name_scope("hidden2"):2171        variable2 = variables.Variable([2.0], name="variable2")2172      saver2 = saver_module.Saver(var_list=[variable2], name="hidden2/")2173      graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver2)2174    with self.test_session(graph=graph) as sess:2175      variables.global_variables_initializer().run()2176      saver1.save(sess, saver1_ckpt, write_state=False)2177      saver2.save(sess, saver2_ckpt, write_state=False)2178    graph1 = ops_lib.Graph()2179    var_dict1 = meta_graph.copy_scoped_meta_graph(2180        from_scope="hidden1",2181        to_scope="new_hidden1",2182        from_graph=graph,2183        to_graph=graph1)2184    self.assertEqual(1, len(var_dict1))2185    saver_list1 = graph1.get_collection(ops_lib.GraphKeys.SAVERS)2186    self.assertEqual(1, len(saver_list1))2187    with self.test_session(graph=graph1) as sess:2188      saver_list1[0].restore(sess, saver1_ckpt)2189      self.assertEqual(1.0, var_dict1["variable1:0"].eval())2190    graph2 = ops_lib.Graph()2191    var_dict2 = 
meta_graph.copy_scoped_meta_graph(2192        from_scope="hidden2",2193        to_scope="new_hidden2",2194        from_graph=graph,2195        to_graph=graph2)2196    self.assertEqual(1, len(var_dict2))2197    saver_list2 = graph2.get_collection(ops_lib.GraphKeys.SAVERS)2198    self.assertEqual(1, len(saver_list2))2199    with self.test_session(graph=graph2) as sess:2200      saver_list2[0].restore(sess, saver2_ckpt)2201      self.assertEqual(2.0, var_dict2["variable2:0"].eval())2202# TODO(b/64763924): Remove after Jan 1st 2018.2203class LenientNamesTest(test.TestCase):2204  def setUp(self):2205    super(LenientNamesTest, self).setUp()2206    os.putenv("TF_SAVER_LENIENT_NAMES", "True")2207  def tearDown(self):2208    os.putenv("TF_SAVER_LENIENT_NAMES", "")2209    super(LenientNamesTest, self).tearDown()2210  def testSaveRestore(self):2211    save_path = os.path.join(self.get_temp_dir(), "basic_save_restore")2212    # Build a graph with 2 parameter nodes, and Save and2213    # Restore nodes for them.2214    v0 = variables.Variable(10.0, name="v0")2215    v1 = variables.Variable(20.0, name="v1")2216    v2 = saver_test_utils.CheckpointedOp(name="v2")2217    v2_init = v2.insert("k1", 30.0)2218    save = saver_module.Saver(2219        {2220            "v0:0": v0,2221            "v1": v1,2222            "v2": v2.saveable2223        }, restore_sequentially=True)2224    init_all_op = [variables.global_variables_initializer(), v2_init]2225    with self.test_session() as sess:2226      sess.run(init_all_op)2227      save.save(sess, save_path)2228    with self.test_session() as sess:2229      v0 = variables.Variable(-1.0, name="v0")2230      v1 = variables.Variable(-1.0, name="v1")2231      v2 = saver_test_utils.CheckpointedOp(name="v2")2232      save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable})2233      save.restore(sess, save_path)2234      # Check that the parameter nodes have been restored.2235      self.assertEqual(10.0, v0.eval())2236      
self.assertEqual(20.0, v1.eval())2237      self.assertEqual(b"k1", v2.keys().eval())2238      self.assertEqual(30.0, v2.values().eval())2239if __name__ == "__main__":...whole_game_work_in_progress.py
Source:whole_game_work_in_progress.py  
"""Work-in-progress pygame game.

The player buys land tiles on a 4x4 grid, builds factories, cleaning
stations and wind turbines on them, and upgrades those buildings. Money
accrues every frame from the buildings' income; pollution accrues from
factories and is reduced by cleaning stations. The game ends when pollution
exceeds ``_GAME_OVER_POLLUTION`` or the window is closed.
"""
import pygame
from buildings import Factory
from buildings import Windturbine
from buildings import Cleaning_station

pygame.init()

screen = pygame.display.set_mode((800, 600))
background_color = (101, 56, 24)
screen.fill(background_color)
clock = pygame.time.Clock()
counter = 0
income = 0

# Prices that are not stored on the building objects.
_LAND_COST = 1500
_UPGRADE_COST = 10000
# Pollution level at which the game is over.
_GAME_OVER_POLLUTION = 100000000

# Tile states stored in the ``bought`` grid inside ``menu``.
_EMPTY = 0
_LAND = 10
_FACTORY, _CLEANING, _WINDTURBINE = 11, 12, 13
_FACTORY_UP, _CLEANING_UP, _WINDTURBINE_UP = 21, 22, 23

# Text colours accepted by ``button2``.
_TEXT_COLORS = {
    "blue": (0, 0, 255),
    "green": (0, 255, 0),
    "red": (255, 0, 0),
    "pink": (255, 182, 193),
}


def _draw_button(surface, position, text_render, fill):
    """Draw a shaded button behind ``text_render`` at ``position``.

    Returns the rect of the blitted text, which callers use for hit-testing.
    """
    _, _, w, h = text_render.get_rect()
    x, y = position
    # Dark lines along the bottom and right edges act as a drop shadow.
    pygame.draw.line(surface, (50, 50, 50), (x, y + h), (x + w, y + h), 5)
    pygame.draw.line(surface, (50, 50, 50), (x + w, y + h), [x + w, y], 5)
    pygame.draw.rect(surface, fill, (x, y, w, h))
    return surface.blit(text_render, (x, y))


def button(screen, position, text):
    """Draw a large grey button with black text; return its rect."""
    font = pygame.font.SysFont("Arial", 50)
    text_render = font.render(text, 1, (0, 0, 0))
    return _draw_button(screen, position, text_render, (128, 128, 128))


def button2(screen, position, text, color):
    """Draw a smaller dark-grey shop button with coloured text; return its rect.

    Args:
        color: one of "blue", "green", "red" or "pink".

    Raises:
        ValueError: if ``color`` is not a supported colour name (previously
            the last rendered text of another button was silently reused).
    """
    try:
        rgb = _TEXT_COLORS[color]
    except KeyError:
        raise ValueError(f"unsupported button colour: {color!r}") from None
    font = pygame.font.SysFont("Arial", 24)
    text_render = font.render(text, 1, rgb)
    return _draw_button(screen, position, text_render, (100, 100, 100))


def buildings_menu():
    """Draw the building shop buttons and publish their rects as globals."""
    global buy_factory, buy_cleaning, buy_windturbine, upgrade
    buy_factory = button2(screen, (500, 100), "     BUY FACTORY: 1000   ", "red")
    buy_cleaning = button2(screen, (500, 150), " BUY CL. STATION: 3500  ", "blue")
    buy_windturbine = button2(screen, (500, 200), "     BUY WINDMILL: 2000  ", "green")
    upgrade = button2(screen, (500, 300), "       UPGRADE: 10000       ", "red")


def buy_menu():
    """Draw the "buy land" button and publish its rect as a global."""
    global buy_land
    buy_land = button2(screen, (500, 250), "               BUY: 1500             ", "pink")


def menu():
    """Run the main game loop until the window is closed or pollution overflows.

    Draws the 4x4 land grid, then once per frame (60 FPS): adds the current
    income to the player's money, updates pollution, redraws the HUD, and
    handles clicks on grid tiles and shop buttons. The module-level
    ``income``, ``save_i`` and ``save_j`` are updated as side effects.
    """
    global income, save_i, save_j
    factory = Factory(1000, 0.03, 0, 1.5)
    windturbine = Windturbine(2000, 0, 0, 0.5)
    cleaning_station = Cleaning_station(3500, 0.055, 0, 0)
    money = 70000
    counter_factory = 0
    counter_windturbine = 0
    counter_cleaning_station = 0
    counter_factory_upgrade = 0
    counter_cleaning_station_upgrade = 0
    counter_windturbine_upgrade = 0
    pollution = 0
    save_i = save_j = None  # no tile selected yet

    def display_text():
        """Redraw the money counter and the pollution bar."""
        font = pygame.font.SysFont("Stencil", 40)
        # Clear the previous money digits and bar so both can also shrink.
        pygame.draw.rect(screen, background_color, pygame.Rect(610, 35, 170, 45))
        pygame.draw.rect(screen, background_color, pygame.Rect(80, 30, 370, 30))
        score_display = font.render(f"MONEY : {int(money)}", 1, (8, 255, 8))
        screen.blit(score_display, (450, 35))
        pygame.draw.rect(screen, (119, 136, 153), [80, 30, pollution * 3, 30])
        pygame.display.update()

    def paint_tile(fill, image_path=None):
        """Fill the selected tile with ``fill`` and optionally draw an image on it."""
        tile = tiles[save_i][save_j]
        pygame.draw.rect(screen, fill, tile)
        if image_path is not None:
            # Blit at the tile's own corner so every row and column lines up.
            screen.blit(pygame.image.load(image_path), tile.topleft)

    # Row 0 is drawn at the bottom (y=400); rows 1-3 at y=100, 200, 300.
    tiles = [
        [button(screen, (90 * (col + 1), 100 * row if row else 400), "      ")
         for col in range(4)]
        for row in range(4)
    ]
    bought = [[_EMPTY] * 4 for _ in range(4)]

    while True:
        # Money is kept unrounded: rounding every frame (half-to-even)
        # used to swallow fractional income such as a wind turbine's 0.5.
        money += income
        produced = factory.polution * (counter_factory + counter_factory_upgrade)
        cleaned = cleaning_station.polution * (
            counter_cleaning_station + 3 * counter_cleaning_station_upgrade)
        # Pollution never drops below zero, however many cleaners exist.
        pollution = max(0, pollution + produced - cleaned)

        display_text()

        if pollution > _GAME_OVER_POLLUTION:
            income = 0
            return

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                # Drawing after pygame.quit() would raise pygame.error.
                return
            if event.type != pygame.MOUSEBUTTONDOWN:
                continue
            mx, my = event.pos

            for i, row in enumerate(tiles):
                for j, tile in enumerate(row):
                    if tile.collidepoint(mx, my):
                        buy_menu()
                        buildings_menu()
                        save_i, save_j = i, j

            if save_i is None:
                # Shop buttons are only drawn (and defined) once a tile is selected.
                continue

            # The shop buttons do not overlap, so at most one branch can match.
            state = bought[save_i][save_j]
            if buy_land.collidepoint(mx, my):
                if state == _EMPTY and money >= _LAND_COST:
                    paint_tile((255, 182, 193))
                    money -= _LAND_COST
                    bought[save_i][save_j] = _LAND
            elif buy_factory.collidepoint(mx, my):
                if state == _LAND and money >= factory.cost:
                    paint_tile((255, 0, 0), "buildings/factory.png")
                    money -= factory.cost
                    counter_factory += 1
                    bought[save_i][save_j] = _FACTORY
            elif buy_cleaning.collidepoint(mx, my):
                if state == _LAND and money >= cleaning_station.cost:
                    paint_tile((173, 216, 230), "buildings/cleaning.png")
                    money -= cleaning_station.cost
                    counter_cleaning_station += 1
                    bought[save_i][save_j] = _CLEANING
            elif buy_windturbine.collidepoint(mx, my):
                if state == _LAND and money >= windturbine.cost:
                    paint_tile((0, 255, 0), "buildings/windturbine.jpeg")
                    money -= windturbine.cost
                    counter_windturbine += 1
                    bought[save_i][save_j] = _WINDTURBINE
            elif upgrade.collidepoint(mx, my) and money >= _UPGRADE_COST:
                if state == _FACTORY:
                    paint_tile((139, 0, 0), "buildings/factory_upgrade.png")
                    money -= _UPGRADE_COST
                    counter_factory -= 1
                    counter_factory_upgrade += 1
                    bought[save_i][save_j] = _FACTORY_UP
                elif state == _CLEANING:
                    paint_tile((0, 0, 255))
                    money -= _UPGRADE_COST
                    counter_cleaning_station -= 1
                    counter_cleaning_station_upgrade += 1
                    bought[save_i][save_j] = _CLEANING_UP
                elif state == _WINDTURBINE:
                    paint_tile((149, 255, 128), "buildings/windturbine_upgrade.jpg")
                    money -= _UPGRADE_COST
                    counter_windturbine -= 1
                    counter_windturbine_upgrade += 1
                    bought[save_i][save_j] = _WINDTURBINE_UP

        # Upgraded buildings earn 1.5x the base income.
        income = (factory.income * (counter_factory + 1.5 * counter_factory_upgrade)
                  + windturbine.income * (counter_windturbine + 1.5 * counter_windturbine_upgrade))
        clock.tick(60)
...populate_database.py
Source:populate_database.py  
...33                 python_module='tools.Setup_workdir.setupWorkDir',34                 hidden=True,35                 tool_folder_name="Setup_workdir"36                 )37module1.save()38module2 = Module(name="Untar",39                 type='0',40                 filter='.*(\.tar)',41                 form=[{"type":"checkbox", "label":"Verbose", "identifier":"verbose", "value": "-v"}],42                 description="Untar all .tar fieles into the current folder",43                 tool_folder_name="Untar_archive",44                 command="tar xf #file -C #workdir #verbose ",45                 hidden=True46                 )47module2.save()48module3 = Module(name="ClamAV",49                 type='0',50                 form=[{"type":"checkbox", "label":"Only show infected files", "identifier":"only_found", "value":"-i"},{"type":"checkbox", "label":"Remove infected files", "identifier":"remove", "value":"--remove"}],51                 # command='[{"value":"clamscan","type":"text"},{"type":"text","value":"-r"},{"value":"-i","type":"var","name":"only_found"},{"value":"--remove","type":"var","name":"remove"},{"type":"var","name":"workdir"}]',52                 command="clamscan -r #only_found #remove #workdir",53                 tool_folder_name="ClamAV"54                 )55module3.save()56module8 = Module(name="DROID",57                 type='3',58                 form=[],59                command="/run.sh \"#file\"",60                filter='.*',61                tool_folder_name="SMART_DROID",62                docker_mount_point="/workdir",63                resultFilter=[{"type": "Containing","value": "[\\w\\W]*Missmatch: \"false\"[\\w\\W]*"},{"type": "Not containing","value": "[\\w\\W]*Missmatch: \"true\"[\\w\\W]*"}]64                 )65module8.save()66module11 = Module(name="Unoconv",67                 type='3',68                 form=[],69                command="UNOPATH=/usr/lib/Libreoffice /usr/bin/python3 /usr/bin/unoconv -f pdf -e SelectPdfVersion=1 
\"#file\"",70                # command="seq -s= 100000|tr -d '[:digit:]'",71                filter='.*(\\.(doc|docx))$',72                tool_folder_name="SMART_UNOCONV",73                docker_mount_point="/workdir",74                parallell_jobs=675                 )76module11.save()77module12 = Module(name="Verapdf",78                 type='3',79                 form=[],80                command="verapdf -f 1a \"#file\"",81                filter='.*(\\.pdf)',82                tool_folder_name="SMART_VERAPDF",83                docker_mount_point="/workdir",84                resultFilter=[{"type":"Containing", "value": "[\\w\\W]*compliant=\"1\"[\\w\\W]*"}],85                parallell_jobs=686                 )87module12.save()88# setup default templates89template1 = Template(name="Default Start", hidden=True)90template1.save()91process2 = Process(order=1,92                   template=template1,93                   module=module2,94                   value={"verbose": True}95                   )96process2.save()97template2 = Template(name="Default Done", hidden=True)98template2.save()99template3 = Template(name="Empty template")100template3.save()101template4 = Template(name="Convert pdf")102template4.save()103process = Process(order=0,104                   template=template4,105                   module=module11) # unoconv106process.save()107process = Process(order=1,108                   template=template4,109                   module=module12) # verapdf110process.save()111# create default variables112var = Variable(name="total_number_of_files", data="0")113var.save()114var = Variable(name="total_size", data="0")115var.save()116var = Variable(name="total_number_of_packages", data="0")117var.save()118var = Variable(name="total_number_of_errors", data="0")119var.save()120# system variables121var = Variable(name="work_dir_path", data="/code/workdir")122var.save()123var = Variable(name="packages_path", data="/code/packages")124var.save()125var = 
Variable(name="tools_path", data="/code/tools")126var.save()127var = Variable(name="premis_file_name", data="log/app_log.xml")128var.save()129var = Variable(name="work_dir_path_host", data="/Users/axenu/Projects/Sydarkivera/APP/workdir")130var.save()131var = Variable(name="premis_template_path", data="/code/templates/premis.json")132var.save()133var = Variable(name="premis_event_template_path", data="/code/templates/premisEvent.json")134var.save()135# default test data, TODO remove in production136# ftype = FileType(name="PDF", errors=3, total=100, size=1203000)137# ftype.save()138# ftype = FileType(name="JPG", errors=33, total=10, size=12033400)139# ftype.save()140# ftype = FileType(name="XML", errors=0, total=43, size=120340)141# ftype.save()142# ftype = FileType(name="XSD", errors=0, total=12, size=120300123)143# ftype.save()144# ftype = FileType(name="sdf", errors=0, total=12, size=120300123)145# ftype.save()146# ftype = FileType(name="Xdh4eSD", errors=0, total=12, size=120300123)147# ftype.save()148# ftype = FileType(name="fgdh", errors=0, total=12, size=120300123)149# ftype.save()150# ftype = FileType(name="wert", errors=0, total=12, size=120300123)151# ftype.save()152# graph = GraphData(date=(datetime.date.today() - datetime.timedelta(days=21)), size=1110000000, count=37954)153# graph.save()154# graph = GraphData(date=(datetime.date.today() - datetime.timedelta(days=14)), size=5834400000, count=9754)155# graph.save()156# graph = GraphData(date=(datetime.date.today() - datetime.timedelta(days=7)), size=2340000000, count=751)157# graph.save()158# graph = GraphData(date=datetime.date.today(), size=300000000, count=3452)159# graph.save()160#create default docker images161# image = DockerImage(name="vera_pdf", mountpoint="/workdir", label="verapdf")162# image.save()163# module12.dockerImage = image164# module12.save()165image = DockerImage(name="axenu/app-worker-droid", mountpoint="/workdir", label="Droid")166image.save()167module8.dockerImage = 
image168module8.save()169image = DockerImage(name="axenu/app-worker-unoconv", mountpoint="/workdir", label="Unoconv")170image.save()171module11.dockerImage = image172module11.save()173image = DockerImage(name="axenu/app-worker-verapdf", mountpoint="/workdir", label="VeraPDF")174image.save()175module12.dockerImage = image176module12.save()177# create default admin users178User.objects.all().delete()179user = User.objects.create_user('admin', 'simon@axenu.com', 'admin')180user.is_superuser = True181user.is_staff = True182# user.role = 2183user.save()184print('populate_databse finished')185#create new package.186# package1 = Package(name="demo paket 1", path="/Users/axenu/Sydarkivera/toolbox/paket/af268c33-5ba8-4af5-9a44-039b10126835.tar", file_name="af268c33-5ba8-4af5-9a44-039b10126835.tar", status=0)187# package1.save()188# create some processes189# process1 = Process(order=1, package=package1, module=module1, value='{}')190# process1.save()191# process2 = Process(order=2, package=package1, module=module2, value='{}')...urls.py
Source:urls.py  
1from django.urls import path, include2from . import views3from .import HodViews, StaffViews, StudentViews4urlpatterns = [5    path('', views.loginPage, name="login"),6    # path('accounts/', include('django.contrib.auth.urls')),7    path('doLogin/', views.doLogin, name="doLogin"),8    path('get_user_details/', views.get_user_details, name="get_user_details"),9    path('logout_user/', views.logout_user, name="logout_user"),10    path('admin_home/', HodViews.admin_home, name="admin_home"),11    path('add_staff/', HodViews.add_staff, name="add_staff"),12    path('add_staff_save/', HodViews.add_staff_save, name="add_staff_save"),13    path('manage_staff/', HodViews.manage_staff, name="manage_staff"),14    path('edit_staff/<staff_id>/', HodViews.edit_staff, name="edit_staff"),15    path('edit_staff_save/', HodViews.edit_staff_save, name="edit_staff_save"),16    path('delete_staff/<staff_id>/', HodViews.delete_staff, name="delete_staff"),17    path('add_course/', HodViews.add_course, name="add_course"),18    path('add_course_save/', HodViews.add_course_save, name="add_course_save"),19    path('manage_course/', HodViews.manage_course, name="manage_course"),20    path('edit_course/<course_id>/', HodViews.edit_course, name="edit_course"),21    path('edit_course_save/', HodViews.edit_course_save, name="edit_course_save"),22    path('delete_course/<course_id>/', HodViews.delete_course, name="delete_course"),23    path('manage_session/', HodViews.manage_session, name="manage_session"),24    path('add_session/', HodViews.add_session, name="add_session"),25    path('add_session_save/', HodViews.add_session_save, name="add_session_save"),26    path('edit_session/<session_id>', HodViews.edit_session, name="edit_session"),27    path('edit_session_save/', HodViews.edit_session_save, name="edit_session_save"),28    path('delete_session/<session_id>/', HodViews.delete_session, name="delete_session"),29    path('add_student/', HodViews.add_student, name="add_student"),30    
path('add_student_save/', HodViews.add_student_save, name="add_student_save"),31    path('edit_student/<student_id>', HodViews.edit_student, name="edit_student"),32    path('edit_student_save/', HodViews.edit_student_save, name="edit_student_save"),33    path('manage_student/', HodViews.manage_student, name="manage_student"),34    path('delete_student/<student_id>/', HodViews.delete_student, name="delete_student"),35    path('add_subject/', HodViews.add_subject, name="add_subject"),36    path('add_subject_save/', HodViews.add_subject_save, name="add_subject_save"),37    path('manage_subject/', HodViews.manage_subject, name="manage_subject"),38    path('edit_subject/<subject_id>/', HodViews.edit_subject, name="edit_subject"),39    path('edit_subject_save/', HodViews.edit_subject_save, name="edit_subject_save"),40    path('delete_subject/<subject_id>/', HodViews.delete_subject, name="delete_subject"),41    path('check_email_exist/', HodViews.check_email_exist, name="check_email_exist"),42    path('check_username_exist/', HodViews.check_username_exist, name="check_username_exist"),43    path('student_feedback_message/', HodViews.student_feedback_message, name="student_feedback_message"),44    path('student_feedback_message_reply/', HodViews.student_feedback_message_reply, name="student_feedback_message_reply"),45    path('staff_feedback_message/', HodViews.staff_feedback_message, name="staff_feedback_message"),46    path('staff_feedback_message_reply/', HodViews.staff_feedback_message_reply, name="staff_feedback_message_reply"),47    path('student_leave_view/', HodViews.student_leave_view, name="student_leave_view"),48    path('student_leave_approve/<leave_id>/', HodViews.student_leave_approve, name="student_leave_approve"),49    path('student_leave_reject/<leave_id>/', HodViews.student_leave_reject, name="student_leave_reject"),50    path('staff_leave_view/', HodViews.staff_leave_view, name="staff_leave_view"),51    path('staff_leave_approve/<leave_id>/', 
HodViews.staff_leave_approve, name="staff_leave_approve"),52    path('staff_leave_reject/<leave_id>/', HodViews.staff_leave_reject, name="staff_leave_reject"),53    path('admin_view_attendance/', HodViews.admin_view_attendance, name="admin_view_attendance"),54    path('admin_get_attendance_dates/', HodViews.admin_get_attendance_dates, name="admin_get_attendance_dates"),55    path('admin_get_attendance_student/', HodViews.admin_get_attendance_student, name="admin_get_attendance_student"),56    path('admin_profile/', HodViews.admin_profile, name="admin_profile"),57    path('admin_profile_update/', HodViews.admin_profile_update, name="admin_profile_update"),58    path('add_project/', HodViews.add_project, name="add_project"),59    path('add_project_save/', HodViews.add_project_save, name="add_project_save"),60    path('manage_project/', HodViews.manage_project, name="manage_project"),61    path('edit_project/<project_id>/', HodViews.edit_project, name="edit_project"),62    path('edit_project_save/', HodViews.edit_project_save, name="edit_project_save"),63    path('delete_project/<project_id>/', HodViews.delete_project, name="delete_project"),64    65    # URLS for Staff66    path('staff_home/', StaffViews.staff_home, name="staff_home"),67    path('staff_take_attendance/', StaffViews.staff_take_attendance, name="staff_take_attendance"),68    path('get_students/', StaffViews.get_students, name="get_students"),69    path('save_attendance_data/', StaffViews.save_attendance_data, name="save_attendance_data"),70    path('staff_update_attendance/', StaffViews.staff_update_attendance, name="staff_update_attendance"),71    path('get_attendance_dates/', StaffViews.get_attendance_dates, name="get_attendance_dates"),72    path('get_attendance_student/', StaffViews.get_attendance_student, name="get_attendance_student"),73    path('update_attendance_data/', StaffViews.update_attendance_data, name="update_attendance_data"),74    path('staff_apply_leave/', 
StaffViews.staff_apply_leave, name="staff_apply_leave"),75    path('staff_apply_leave_save/', StaffViews.staff_apply_leave_save, name="staff_apply_leave_save"),76    path('staff_feedback/', StaffViews.staff_feedback, name="staff_feedback"),77    path('staff_feedback_save/', StaffViews.staff_feedback_save, name="staff_feedback_save"),78    path('staff_profile/', StaffViews.staff_profile, name="staff_profile"),79    path('staff_profile_update/', StaffViews.staff_profile_update, name="staff_profile_update"),80    path('staff_add_result/', StaffViews.staff_add_result, name="staff_add_result"),81    path('staff_add_result_save/', StaffViews.staff_add_result_save, name="staff_add_result_save"),82    path('staff_project_view/', StaffViews.project_view, name="project_view"),83    path('staff_eachproject_view/<subject_id>/', StaffViews.eachproject_view, name="eachproject_view"),84    # URSL for Student85    path('student_home/', StudentViews.student_home, name="student_home"),86    path('student_view_attendance/', StudentViews.student_view_attendance, name="student_view_attendance"),87    path('student_view_attendance_post/', StudentViews.student_view_attendance_post, name="student_view_attendance_post"),88    path('student_apply_leave/', StudentViews.student_apply_leave, name="student_apply_leave"),89    path('student_apply_leave_save/', StudentViews.student_apply_leave_save, name="student_apply_leave_save"),90    path('student_feedback/', StudentViews.student_feedback, name="student_feedback"),91    path('student_feedback_save/', StudentViews.student_feedback_save, name="student_feedback_save"),92    path('student_profile/', StudentViews.student_profile, name="student_profile"),93    path('student_profile_update/', StudentViews.student_profile_update, name="student_profile_update"),94    path('student_view_result/', StudentViews.student_view_result, name="student_view_result"),...Learn to execute automation testing from scratch with LambdaTest Learning Hub. 
It covers everything from setting up the prerequisites and running your first automation test to following best practices and diving deeper into advanced test scenarios. LambdaTest Learning Hubs compile step-by-step guides to help you become proficient with different test automation frameworks, e.g. Selenium, Cypress, and TestNG.
You can also watch the video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.
Get 100 minutes of automation testing FREE!
