Best Python code snippet using pytest
basic_session_run_hooks_test.py
Source:basic_session_run_hooks_test.py  
1# pylint: disable=g-bad-file-header2# Copyright 2016 The TensorFlow Authors. All Rights Reserved.3#4# Licensed under the Apache License, Version 2.0 (the "License");5# you may not use this file except in compliance with the License.6# You may obtain a copy of the License at7#8#     http://www.apache.org/licenses/LICENSE-2.09#10# Unless required by applicable law or agreed to in writing, software11# distributed under the License is distributed on an "AS IS" BASIS,12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.13# See the License for the specific language governing permissions and14# limitations under the License.15# ==============================================================================16"""Tests for basic_session_run_hooks."""17from __future__ import absolute_import18from __future__ import division19from __future__ import print_function20import shutil21import tempfile22import threading23import time24from tensorflow.contrib.framework.python.framework import checkpoint_utils25from tensorflow.contrib.framework.python.ops import variables26from tensorflow.contrib.testing.python.framework import fake_summary_writer27from tensorflow.python.client import session as session_lib28from tensorflow.python.framework import constant_op29from tensorflow.python.framework import dtypes30from tensorflow.python.framework import meta_graph31from tensorflow.python.framework import ops32from tensorflow.python.ops import array_ops33from tensorflow.python.ops import control_flow_ops34from tensorflow.python.ops import state_ops35from tensorflow.python.ops import variable_scope36from tensorflow.python.ops import variables as variables_lib37import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import38from tensorflow.python.platform import test39from tensorflow.python.platform import tf_logging40from tensorflow.python.summary import summary as summary_lib41from tensorflow.python.summary.writer import writer_cache42from tensorflow.python.training 
import basic_session_run_hooks43from tensorflow.python.training import monitored_session44from tensorflow.python.training import session_run_hook45class MockCheckpointSaverListener(46    basic_session_run_hooks.CheckpointSaverListener):47  def __init__(self):48    self.begin_count = 049    self.before_save_count = 050    self.after_save_count = 051    self.end_count = 052  def begin(self):53    self.begin_count += 154  def before_save(self, session, global_step):55    self.before_save_count += 156  def after_save(self, session, global_step):57    self.after_save_count += 158  def end(self, session, global_step):59    self.end_count += 160  def get_counts(self):61    return {62        'begin': self.begin_count,63        'before_save': self.before_save_count,64        'after_save': self.after_save_count,65        'end': self.end_count66    }67class SecondOrStepTimerTest(test.TestCase):68  def test_raise_in_both_secs_and_steps(self):69    with self.assertRaises(ValueError):70      basic_session_run_hooks.SecondOrStepTimer(every_secs=2.0, every_steps=10)71  def test_raise_in_none_secs_and_steps(self):72    with self.assertRaises(ValueError):73      basic_session_run_hooks.SecondOrStepTimer()74  def test_every_secs(self):75    timer = basic_session_run_hooks.SecondOrStepTimer(every_secs=1.0)76    self.assertTrue(timer.should_trigger_for_step(1))77    timer.update_last_triggered_step(1)78    self.assertFalse(timer.should_trigger_for_step(1))79    self.assertFalse(timer.should_trigger_for_step(2))80    time.sleep(1.0)81    self.assertFalse(timer.should_trigger_for_step(1))82    self.assertTrue(timer.should_trigger_for_step(2))83  def test_every_steps(self):84    timer = basic_session_run_hooks.SecondOrStepTimer(every_steps=3)85    self.assertTrue(timer.should_trigger_for_step(1))86    timer.update_last_triggered_step(1)87    self.assertFalse(timer.should_trigger_for_step(1))88    self.assertFalse(timer.should_trigger_for_step(2))89    
self.assertFalse(timer.should_trigger_for_step(3))90    self.assertTrue(timer.should_trigger_for_step(4))91  def test_update_last_triggered_step(self):92    timer = basic_session_run_hooks.SecondOrStepTimer(every_steps=1)93    elapsed_secs, elapsed_steps = timer.update_last_triggered_step(1)94    self.assertEqual(None, elapsed_secs)95    self.assertEqual(None, elapsed_steps)96    elapsed_secs, elapsed_steps = timer.update_last_triggered_step(5)97    self.assertLess(0, elapsed_secs)98    self.assertEqual(4, elapsed_steps)99    elapsed_secs, elapsed_steps = timer.update_last_triggered_step(7)100    self.assertLess(0, elapsed_secs)101    self.assertEqual(2, elapsed_steps)102class StopAtStepTest(test.TestCase):103  def test_raise_in_both_last_step_and_num_steps(self):104    with self.assertRaises(ValueError):105      basic_session_run_hooks.StopAtStepHook(num_steps=10, last_step=20)106  def test_stop_based_on_last_step(self):107    h = basic_session_run_hooks.StopAtStepHook(last_step=10)108    with ops.Graph().as_default():109      global_step = variables.get_or_create_global_step()110      no_op = control_flow_ops.no_op()111      h.begin()112      with session_lib.Session() as sess:113        mon_sess = monitored_session._HookedSession(sess, [h])114        sess.run(state_ops.assign(global_step, 5))115        h.after_create_session(sess, None)116        mon_sess.run(no_op)117        self.assertFalse(mon_sess.should_stop())118        sess.run(state_ops.assign(global_step, 9))119        mon_sess.run(no_op)120        self.assertFalse(mon_sess.should_stop())121        sess.run(state_ops.assign(global_step, 10))122        mon_sess.run(no_op)123        self.assertTrue(mon_sess.should_stop())124        sess.run(state_ops.assign(global_step, 11))125        mon_sess._should_stop = False126        mon_sess.run(no_op)127        self.assertTrue(mon_sess.should_stop())128  def test_stop_based_on_num_step(self):129    h = basic_session_run_hooks.StopAtStepHook(num_steps=10)130    
with ops.Graph().as_default():131      global_step = variables.get_or_create_global_step()132      no_op = control_flow_ops.no_op()133      h.begin()134      with session_lib.Session() as sess:135        mon_sess = monitored_session._HookedSession(sess, [h])136        sess.run(state_ops.assign(global_step, 5))137        h.after_create_session(sess, None)138        mon_sess.run(no_op)139        self.assertFalse(mon_sess.should_stop())140        sess.run(state_ops.assign(global_step, 13))141        mon_sess.run(no_op)142        self.assertFalse(mon_sess.should_stop())143        sess.run(state_ops.assign(global_step, 14))144        mon_sess.run(no_op)145        self.assertFalse(mon_sess.should_stop())146        sess.run(state_ops.assign(global_step, 15))147        mon_sess.run(no_op)148        self.assertTrue(mon_sess.should_stop())149        sess.run(state_ops.assign(global_step, 16))150        mon_sess._should_stop = False151        mon_sess.run(no_op)152        self.assertTrue(mon_sess.should_stop())153  def test_stop_based_with_multiple_steps(self):154    h = basic_session_run_hooks.StopAtStepHook(num_steps=10)155    with ops.Graph().as_default():156      global_step = variables.get_or_create_global_step()157      no_op = control_flow_ops.no_op()158      h.begin()159      with session_lib.Session() as sess:160        mon_sess = monitored_session._HookedSession(sess, [h])161        sess.run(state_ops.assign(global_step, 5))162        h.after_create_session(sess, None)163        mon_sess.run(no_op)164        self.assertFalse(mon_sess.should_stop())165        sess.run(state_ops.assign(global_step, 15))166        mon_sess.run(no_op)167        self.assertTrue(mon_sess.should_stop())168class LoggingTensorHookTest(test.TestCase):169  def setUp(self):170    # Mock out logging calls so we can verify whether correct tensors are being171    # monitored.172    self._actual_log = tf_logging.info173    self.logged_message = None174    def mock_log(*args, **kwargs):175      
self.logged_message = args176      self._actual_log(*args, **kwargs)177    tf_logging.info = mock_log178  def tearDown(self):179    tf_logging.info = self._actual_log180  def test_illegal_args(self):181    with self.assertRaisesRegexp(ValueError, 'nvalid every_n_iter'):182      basic_session_run_hooks.LoggingTensorHook(tensors=['t'], every_n_iter=0)183    with self.assertRaisesRegexp(ValueError, 'nvalid every_n_iter'):184      basic_session_run_hooks.LoggingTensorHook(tensors=['t'], every_n_iter=-10)185    with self.assertRaisesRegexp(ValueError, 'xactly one of'):186      basic_session_run_hooks.LoggingTensorHook(187          tensors=['t'], every_n_iter=5, every_n_secs=5)188    with self.assertRaisesRegexp(ValueError, 'xactly one of'):189      basic_session_run_hooks.LoggingTensorHook(tensors=['t'])190  def test_print_at_end_only(self):191    with ops.Graph().as_default(), session_lib.Session() as sess:192      t = constant_op.constant(42.0, name='foo')193      train_op = constant_op.constant(3)194      hook = basic_session_run_hooks.LoggingTensorHook(195          tensors=[t.name], at_end=True)196      hook.begin()197      mon_sess = monitored_session._HookedSession(sess, [hook])198      sess.run(variables_lib.global_variables_initializer())199      self.logged_message = ''200      for _ in range(3):201        mon_sess.run(train_op)202        # assertNotRegexpMatches is not supported by python 3.1 and later203        self.assertEqual(str(self.logged_message).find(t.name), -1)204      hook.end(sess)205      self.assertRegexpMatches(str(self.logged_message), t.name)206  def _validate_print_every_n_steps(self, sess, at_end):207    t = constant_op.constant(42.0, name='foo')208    train_op = constant_op.constant(3)209    hook = basic_session_run_hooks.LoggingTensorHook(210        tensors=[t.name], every_n_iter=10, at_end=at_end)211    hook.begin()212    mon_sess = monitored_session._HookedSession(sess, [hook])213    
sess.run(variables_lib.global_variables_initializer())214    mon_sess.run(train_op)215    self.assertRegexpMatches(str(self.logged_message), t.name)216    for _ in range(3):217      self.logged_message = ''218      for _ in range(9):219        mon_sess.run(train_op)220        # assertNotRegexpMatches is not supported by python 3.1 and later221        self.assertEqual(str(self.logged_message).find(t.name), -1)222      mon_sess.run(train_op)223      self.assertRegexpMatches(str(self.logged_message), t.name)224    # Add additional run to verify proper reset when called multiple times.225    self.logged_message = ''226    mon_sess.run(train_op)227    # assertNotRegexpMatches is not supported by python 3.1 and later228    self.assertEqual(str(self.logged_message).find(t.name), -1)229    self.logged_message = ''230    hook.end(sess)231    if at_end:232      self.assertRegexpMatches(str(self.logged_message), t.name)233    else:234      # assertNotRegexpMatches is not supported by python 3.1 and later235      self.assertEqual(str(self.logged_message).find(t.name), -1)236  def test_print_every_n_steps(self):237    with ops.Graph().as_default(), session_lib.Session() as sess:238      self._validate_print_every_n_steps(sess, at_end=False)239      # Verify proper reset.240      self._validate_print_every_n_steps(sess, at_end=False)241  def test_print_every_n_steps_and_end(self):242    with ops.Graph().as_default(), session_lib.Session() as sess:243      self._validate_print_every_n_steps(sess, at_end=True)244      # Verify proper reset.245      self._validate_print_every_n_steps(sess, at_end=True)246  def test_print_first_step(self):247    # if it runs every iteration, first iteration has None duration.248    with ops.Graph().as_default(), session_lib.Session() as sess:249      t = constant_op.constant(42.0, name='foo')250      train_op = constant_op.constant(3)251      hook = basic_session_run_hooks.LoggingTensorHook(252          tensors={'foo': t}, every_n_iter=1)253      
hook.begin()254      mon_sess = monitored_session._HookedSession(sess, [hook])255      sess.run(variables_lib.global_variables_initializer())256      mon_sess.run(train_op)257      self.assertRegexpMatches(str(self.logged_message), 'foo')258      # in first run, elapsed time is None.259      self.assertEqual(str(self.logged_message).find('sec'), -1)260  def _validate_print_every_n_secs(self, sess, at_end):261    t = constant_op.constant(42.0, name='foo')262    train_op = constant_op.constant(3)263    hook = basic_session_run_hooks.LoggingTensorHook(264        tensors=[t.name], every_n_secs=1.0, at_end=at_end)265    hook.begin()266    mon_sess = monitored_session._HookedSession(sess, [hook])267    sess.run(variables_lib.global_variables_initializer())268    mon_sess.run(train_op)269    self.assertRegexpMatches(str(self.logged_message), t.name)270    # assertNotRegexpMatches is not supported by python 3.1 and later271    self.logged_message = ''272    mon_sess.run(train_op)273    self.assertEqual(str(self.logged_message).find(t.name), -1)274    time.sleep(1.0)275    self.logged_message = ''276    mon_sess.run(train_op)277    self.assertRegexpMatches(str(self.logged_message), t.name)278    self.logged_message = ''279    hook.end(sess)280    if at_end:281      self.assertRegexpMatches(str(self.logged_message), t.name)282    else:283      # assertNotRegexpMatches is not supported by python 3.1 and later284      self.assertEqual(str(self.logged_message).find(t.name), -1)285  def test_print_every_n_secs(self):286    with ops.Graph().as_default(), session_lib.Session() as sess:287      self._validate_print_every_n_secs(sess, at_end=False)288      # Verify proper reset.289      self._validate_print_every_n_secs(sess, at_end=False)290  def test_print_every_n_secs_and_end(self):291    with ops.Graph().as_default(), session_lib.Session() as sess:292      self._validate_print_every_n_secs(sess, at_end=True)293      # Verify proper reset.294      
self._validate_print_every_n_secs(sess, at_end=True)295  def test_print_formatter(self):296    with ops.Graph().as_default(), session_lib.Session() as sess:297      t = constant_op.constant(42.0, name='foo')298      train_op = constant_op.constant(3)299      hook = basic_session_run_hooks.LoggingTensorHook(300          tensors=[t.name], every_n_iter=10,301          formatter=lambda items: 'qqq=%s' % items[t.name])302      hook.begin()303      mon_sess = monitored_session._HookedSession(sess, [hook])304      sess.run(variables_lib.global_variables_initializer())305      mon_sess.run(train_op)306      self.assertEqual(self.logged_message[0], 'qqq=42.0')307class CheckpointSaverHookTest(test.TestCase):308  def setUp(self):309    self.model_dir = tempfile.mkdtemp()310    self.graph = ops.Graph()311    with self.graph.as_default():312      self.scaffold = monitored_session.Scaffold()313      self.global_step = variables.get_or_create_global_step()314      self.train_op = state_ops.assign_add(self.global_step, 1)315  def tearDown(self):316    shutil.rmtree(self.model_dir, ignore_errors=True)317  def test_raise_when_saver_and_scaffold_both_missing(self):318    with self.assertRaises(ValueError):319      basic_session_run_hooks.CheckpointSaverHook(self.model_dir)320  def test_raise_when_saver_and_scaffold_both_present(self):321    with self.assertRaises(ValueError):322      basic_session_run_hooks.CheckpointSaverHook(323          self.model_dir, saver=self.scaffold.saver, scaffold=self.scaffold)324  def test_raise_in_both_secs_and_steps(self):325    with self.assertRaises(ValueError):326      basic_session_run_hooks.CheckpointSaverHook(327          self.model_dir, save_secs=10, save_steps=20)328  def test_raise_in_none_secs_and_steps(self):329    with self.assertRaises(ValueError):330      basic_session_run_hooks.CheckpointSaverHook(self.model_dir)331  def test_save_secs_saves_in_first_step(self):332    with self.graph.as_default():333      hook = 
basic_session_run_hooks.CheckpointSaverHook(334          self.model_dir, save_secs=2, scaffold=self.scaffold)335      hook.begin()336      self.scaffold.finalize()337      with session_lib.Session() as sess:338        sess.run(self.scaffold.init_op)339        mon_sess = monitored_session._HookedSession(sess, [hook])340        mon_sess.run(self.train_op)341        self.assertEqual(1,342                         checkpoint_utils.load_variable(self.model_dir,343                                                        self.global_step.name))344  def test_save_secs_calls_listeners_at_begin_and_end(self):345    with self.graph.as_default():346      listener = MockCheckpointSaverListener()347      hook = basic_session_run_hooks.CheckpointSaverHook(348          self.model_dir,349          save_secs=2,350          scaffold=self.scaffold,351          listeners=[listener])352      hook.begin()353      self.scaffold.finalize()354      with session_lib.Session() as sess:355        sess.run(self.scaffold.init_op)356        mon_sess = monitored_session._HookedSession(sess, [hook])357        mon_sess.run(self.train_op)  # hook runs here358        mon_sess.run(self.train_op)  # hook won't run here, so it does at end359        hook.end(sess)  # hook runs here360      self.assertEqual({361          'begin': 1,362          'before_save': 2,363          'after_save': 2,364          'end': 1365      }, listener.get_counts())366  def test_listener_with_monitored_session(self):367    with ops.Graph().as_default():368      scaffold = monitored_session.Scaffold()369      global_step = variables.get_or_create_global_step()370      train_op = state_ops.assign_add(global_step, 1)371      listener = MockCheckpointSaverListener()372      hook = basic_session_run_hooks.CheckpointSaverHook(373          self.model_dir,374          save_steps=1,375          scaffold=scaffold,376          listeners=[listener])377      with monitored_session.SingularMonitoredSession(378          hooks=[hook],379         
 scaffold=scaffold,380          checkpoint_dir=self.model_dir) as sess:381        sess.run(train_op)382        sess.run(train_op)383        global_step_val = sess.run(global_step)384      listener_counts = listener.get_counts()385    self.assertEqual(2, global_step_val)386    self.assertEqual({387        'begin': 1,388        'before_save': 2,389        'after_save': 2,390        'end': 1391    }, listener_counts)392  def test_listener_with_default_saver(self):393    with ops.Graph().as_default():394      global_step = variables.get_or_create_global_step()395      train_op = state_ops.assign_add(global_step, 1)396      listener = MockCheckpointSaverListener()397      hook = basic_session_run_hooks.CheckpointSaverHook(398          self.model_dir,399          save_steps=1,400          listeners=[listener])401      with monitored_session.SingularMonitoredSession(402          hooks=[hook],403          checkpoint_dir=self.model_dir) as sess:404        sess.run(train_op)405        sess.run(train_op)406        global_step_val = sess.run(global_step)407      listener_counts = listener.get_counts()408    self.assertEqual(2, global_step_val)409    self.assertEqual({410        'begin': 1,411        'before_save': 2,412        'after_save': 2,413        'end': 1414    }, listener_counts)415    with ops.Graph().as_default():416      global_step = variables.get_or_create_global_step()417      with monitored_session.SingularMonitoredSession(418          checkpoint_dir=self.model_dir) as sess2:419        global_step_saved_val = sess2.run(global_step)420    self.assertEqual(2, global_step_saved_val)421  def test_two_listeners_with_default_saver(self):422    with ops.Graph().as_default():423      global_step = variables.get_or_create_global_step()424      train_op = state_ops.assign_add(global_step, 1)425      listener1 = MockCheckpointSaverListener()426      listener2 = MockCheckpointSaverListener()427      hook = basic_session_run_hooks.CheckpointSaverHook(428          
self.model_dir,429          save_steps=1,430          listeners=[listener1, listener2])431      with monitored_session.SingularMonitoredSession(432          hooks=[hook],433          checkpoint_dir=self.model_dir) as sess:434        sess.run(train_op)435        sess.run(train_op)436        global_step_val = sess.run(global_step)437      listener1_counts = listener1.get_counts()438      listener2_counts = listener2.get_counts()439    self.assertEqual(2, global_step_val)440    self.assertEqual({441        'begin': 1,442        'before_save': 2,443        'after_save': 2,444        'end': 1445    }, listener1_counts)446    self.assertEqual(listener1_counts, listener2_counts)447    with ops.Graph().as_default():448      global_step = variables.get_or_create_global_step()449      with monitored_session.SingularMonitoredSession(450          checkpoint_dir=self.model_dir) as sess2:451        global_step_saved_val = sess2.run(global_step)452    self.assertEqual(2, global_step_saved_val)453  @test.mock.patch('time.time')454  def test_save_secs_saves_periodically(self, mock_time):455    # Let's have a realistic start time456    current_time = 1484695987.209386457    with self.graph.as_default():458      mock_time.return_value = current_time459      hook = basic_session_run_hooks.CheckpointSaverHook(460          self.model_dir, save_secs=2, scaffold=self.scaffold)461      hook.begin()462      self.scaffold.finalize()463      with session_lib.Session() as sess:464        sess.run(self.scaffold.init_op)465        mon_sess = monitored_session._HookedSession(sess, [hook])466        mock_time.return_value = current_time467        mon_sess.run(self.train_op)  # Saved.468        mock_time.return_value = current_time + 0.5469        mon_sess.run(self.train_op)  # Not saved.470        self.assertEqual(1,471                         checkpoint_utils.load_variable(self.model_dir,472                                                        self.global_step.name))473        # Simulate 2.5 
seconds of sleep.474        mock_time.return_value = current_time + 2.5475        mon_sess.run(self.train_op)  # Saved.476        mock_time.return_value = current_time + 2.6477        mon_sess.run(self.train_op)  # Not saved.478        mock_time.return_value = current_time + 2.7479        mon_sess.run(self.train_op)  # Not saved.480        self.assertEqual(3,481                         checkpoint_utils.load_variable(self.model_dir,482                                                        self.global_step.name))483        # Simulate 7.5 more seconds of sleep (10 seconds from start.484        mock_time.return_value = current_time + 10485        mon_sess.run(self.train_op)  # Saved.486        self.assertEqual(6,487                         checkpoint_utils.load_variable(self.model_dir,488                                                        self.global_step.name))489  # Flaky because of time.sleep()490  def DISABLED_test_save_secs_calls_listeners_periodically(self):491    with self.graph.as_default():492      listener = MockCheckpointSaverListener()493      hook = basic_session_run_hooks.CheckpointSaverHook(494          self.model_dir,495          save_secs=2,496          scaffold=self.scaffold,497          listeners=[listener])498      hook.begin()499      self.scaffold.finalize()500      with session_lib.Session() as sess:501        sess.run(self.scaffold.init_op)502        mon_sess = monitored_session._HookedSession(sess, [hook])503        mon_sess.run(self.train_op)  # hook runs here504        mon_sess.run(self.train_op)505        time.sleep(2.5)506        mon_sess.run(self.train_op)  # hook runs here507        mon_sess.run(self.train_op)508        mon_sess.run(self.train_op)509        time.sleep(2.5)510        mon_sess.run(self.train_op)  # hook runs here511        mon_sess.run(self.train_op)  # hook won't run here, so it does at end512        hook.end(sess)  # hook runs here513      self.assertEqual({514          'begin': 1,515          'before_save': 4,516    
      'after_save': 4,517          'end': 1518      }, listener.get_counts())519  def test_save_steps_saves_in_first_step(self):520    with self.graph.as_default():521      hook = basic_session_run_hooks.CheckpointSaverHook(522          self.model_dir, save_steps=2, scaffold=self.scaffold)523      hook.begin()524      self.scaffold.finalize()525      with session_lib.Session() as sess:526        sess.run(self.scaffold.init_op)527        mon_sess = monitored_session._HookedSession(sess, [hook])528        mon_sess.run(self.train_op)529        self.assertEqual(1,530                         checkpoint_utils.load_variable(self.model_dir,531                                                        self.global_step.name))532  def test_save_steps_saves_periodically(self):533    with self.graph.as_default():534      hook = basic_session_run_hooks.CheckpointSaverHook(535          self.model_dir, save_steps=2, scaffold=self.scaffold)536      hook.begin()537      self.scaffold.finalize()538      with session_lib.Session() as sess:539        sess.run(self.scaffold.init_op)540        mon_sess = monitored_session._HookedSession(sess, [hook])541        mon_sess.run(self.train_op)542        mon_sess.run(self.train_op)543        # Not saved544        self.assertEqual(1,545                         checkpoint_utils.load_variable(self.model_dir,546                                                        self.global_step.name))547        mon_sess.run(self.train_op)548        # saved549        self.assertEqual(3,550                         checkpoint_utils.load_variable(self.model_dir,551                                                        self.global_step.name))552        mon_sess.run(self.train_op)553        # Not saved554        self.assertEqual(3,555                         checkpoint_utils.load_variable(self.model_dir,556                                                        self.global_step.name))557        mon_sess.run(self.train_op)558        # saved559        
self.assertEqual(5,560                         checkpoint_utils.load_variable(self.model_dir,561                                                        self.global_step.name))562  def test_save_saves_at_end(self):563    with self.graph.as_default():564      hook = basic_session_run_hooks.CheckpointSaverHook(565          self.model_dir, save_secs=2, scaffold=self.scaffold)566      hook.begin()567      self.scaffold.finalize()568      with session_lib.Session() as sess:569        sess.run(self.scaffold.init_op)570        mon_sess = monitored_session._HookedSession(sess, [hook])571        mon_sess.run(self.train_op)572        mon_sess.run(self.train_op)573        hook.end(sess)574        self.assertEqual(2,575                         checkpoint_utils.load_variable(self.model_dir,576                                                        self.global_step.name))577  def test_summary_writer_defs(self):578    fake_summary_writer.FakeSummaryWriter.install()579    writer_cache.FileWriterCache.clear()580    summary_writer = writer_cache.FileWriterCache.get(self.model_dir)581    with self.graph.as_default():582      hook = basic_session_run_hooks.CheckpointSaverHook(583          self.model_dir, save_steps=2, scaffold=self.scaffold)584      hook.begin()585      self.scaffold.finalize()586      with session_lib.Session() as sess:587        sess.run(self.scaffold.init_op)588        mon_sess = monitored_session._HookedSession(sess, [hook])589        mon_sess.run(self.train_op)590      summary_writer.assert_summaries(591          test_case=self,592          expected_logdir=self.model_dir,593          expected_added_meta_graphs=[594              meta_graph.create_meta_graph_def(595                  graph_def=self.graph.as_graph_def(add_shapes=True),596                  saver_def=self.scaffold.saver.saver_def)597          ])598    fake_summary_writer.FakeSummaryWriter.uninstall()599class ResourceCheckpointSaverHookTest(test.TestCase):600  def setUp(self):601    self.model_dir = 
tempfile.mkdtemp()602    self.graph = ops.Graph()603    with self.graph.as_default():604      self.scaffold = monitored_session.Scaffold()605      with variable_scope.variable_scope('foo', use_resource=True):606        self.global_step = variables.get_or_create_global_step()607      self.train_op = state_ops.assign_add(self.global_step, 1)608  def test_save_steps_saves_periodically(self):609    with self.graph.as_default():610      hook = basic_session_run_hooks.CheckpointSaverHook(611          self.model_dir, save_steps=2, scaffold=self.scaffold)612      hook.begin()613      self.scaffold.finalize()614      with session_lib.Session() as sess:615        sess.run(self.scaffold.init_op)616        mon_sess = monitored_session._HookedSession(sess, [hook])617        mon_sess.run(self.train_op)618        mon_sess.run(self.train_op)619        # Not saved620        self.assertEqual(1,621                         checkpoint_utils.load_variable(self.model_dir,622                                                        self.global_step.name))623        mon_sess.run(self.train_op)624        # saved625        self.assertEqual(3,626                         checkpoint_utils.load_variable(self.model_dir,627                                                        self.global_step.name))628        mon_sess.run(self.train_op)629        # Not saved630        self.assertEqual(3,631                         checkpoint_utils.load_variable(self.model_dir,632                                                        self.global_step.name))633        mon_sess.run(self.train_op)634        # saved635        self.assertEqual(5,636                         checkpoint_utils.load_variable(self.model_dir,637                                                        self.global_step.name))638class StepCounterHookTest(test.TestCase):639  def setUp(self):640    self.log_dir = tempfile.mkdtemp()641  def tearDown(self):642    shutil.rmtree(self.log_dir, ignore_errors=True)643  def 
test_step_counter_every_n_steps(self):644    with ops.Graph().as_default() as g, session_lib.Session() as sess:645      global_step = variables.get_or_create_global_step()646      train_op = state_ops.assign_add(global_step, 1)647      summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)648      hook = basic_session_run_hooks.StepCounterHook(649          summary_writer=summary_writer, every_n_steps=10)650      hook.begin()651      sess.run(variables_lib.global_variables_initializer())652      mon_sess = monitored_session._HookedSession(sess, [hook])653      for _ in range(30):654        time.sleep(0.01)655        mon_sess.run(train_op)656      hook.end(sess)657      summary_writer.assert_summaries(658          test_case=self,659          expected_logdir=self.log_dir,660          expected_graph=g,661          expected_summaries={})662      self.assertItemsEqual([11, 21], summary_writer.summaries.keys())663      for step in [11, 21]:664        summary_value = summary_writer.summaries[step][0].value[0]665        self.assertEqual('global_step/sec', summary_value.tag)666        self.assertGreater(summary_value.simple_value, 0)667  def test_step_counter_every_n_secs(self):668    with ops.Graph().as_default() as g, session_lib.Session() as sess:669      global_step = variables.get_or_create_global_step()670      train_op = state_ops.assign_add(global_step, 1)671      summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)672      hook = basic_session_run_hooks.StepCounterHook(673          summary_writer=summary_writer, every_n_steps=None, every_n_secs=0.1)674      hook.begin()675      sess.run(variables_lib.global_variables_initializer())676      mon_sess = monitored_session._HookedSession(sess, [hook])677      mon_sess.run(train_op)678      time.sleep(0.2)679      mon_sess.run(train_op)680      time.sleep(0.2)681      mon_sess.run(train_op)682      hook.end(sess)683      summary_writer.assert_summaries(684          test_case=self,685    
      expected_logdir=self.log_dir,686          expected_graph=g,687          expected_summaries={})688      self.assertTrue(summary_writer.summaries, 'No summaries were created.')689      self.assertItemsEqual([2, 3], summary_writer.summaries.keys())690      for summary in summary_writer.summaries.values():691        summary_value = summary[0].value[0]692        self.assertEqual('global_step/sec', summary_value.tag)693        self.assertGreater(summary_value.simple_value, 0)694  def test_global_step_name(self):695    with ops.Graph().as_default() as g, session_lib.Session() as sess:696      with variable_scope.variable_scope('bar'):697        foo_step = variable_scope.get_variable(698            'foo',699            initializer=0,700            trainable=False,701            collections=[702                ops.GraphKeys.GLOBAL_STEP, ops.GraphKeys.GLOBAL_VARIABLES703            ])704      train_op = state_ops.assign_add(foo_step, 1)705      summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)706      hook = basic_session_run_hooks.StepCounterHook(707          summary_writer=summary_writer, every_n_steps=1, every_n_secs=None)708      hook.begin()709      sess.run(variables_lib.global_variables_initializer())710      mon_sess = monitored_session._HookedSession(sess, [hook])711      mon_sess.run(train_op)712      mon_sess.run(train_op)713      hook.end(sess)714      summary_writer.assert_summaries(715          test_case=self,716          expected_logdir=self.log_dir,717          expected_graph=g,718          expected_summaries={})719      self.assertTrue(summary_writer.summaries, 'No summaries were created.')720      self.assertItemsEqual([2], summary_writer.summaries.keys())721      summary_value = summary_writer.summaries[2][0].value[0]722      self.assertEqual('bar/foo/sec', summary_value.tag)723class SummarySaverHookTest(test.TestCase):724  def setUp(self):725    test.TestCase.setUp(self)726    self.log_dir = 'log/dir'727    self.summary_writer 
= fake_summary_writer.FakeSummaryWriter(self.log_dir)728    var = variables_lib.Variable(0.0)729    tensor = state_ops.assign_add(var, 1.0)730    tensor2 = tensor * 2731    self.summary_op = summary_lib.scalar('my_summary', tensor)732    self.summary_op2 = summary_lib.scalar('my_summary2', tensor2)733    global_step = variables.get_or_create_global_step()734    self.train_op = state_ops.assign_add(global_step, 1)735  def test_raise_when_scaffold_and_summary_op_both_missing(self):736    with self.assertRaises(ValueError):737      basic_session_run_hooks.SummarySaverHook()738  def test_raise_when_scaffold_and_summary_op_both_present(self):739    with self.assertRaises(ValueError):740      basic_session_run_hooks.SummarySaverHook(741          scaffold=monitored_session.Scaffold(), summary_op=self.summary_op)742  def test_raise_in_both_secs_and_steps(self):743    with self.assertRaises(ValueError):744      basic_session_run_hooks.SummarySaverHook(745          save_secs=10, save_steps=20, summary_writer=self.summary_writer)746  def test_raise_in_none_secs_and_steps(self):747    with self.assertRaises(ValueError):748      basic_session_run_hooks.SummarySaverHook(749          save_secs=None, save_steps=None, summary_writer=self.summary_writer)750  def test_save_steps(self):751    hook = basic_session_run_hooks.SummarySaverHook(752        save_steps=8,753        summary_writer=self.summary_writer,754        summary_op=self.summary_op)755    with self.test_session() as sess:756      hook.begin()757      sess.run(variables_lib.global_variables_initializer())758      mon_sess = monitored_session._HookedSession(sess, [hook])759      for _ in range(30):760        mon_sess.run(self.train_op)761      hook.end(sess)762    self.summary_writer.assert_summaries(763        test_case=self,764        expected_logdir=self.log_dir,765        expected_summaries={766            1: {767                'my_summary': 1.0768            },769            9: {770                'my_summary': 
2.0771            },772            17: {773                'my_summary': 3.0774            },775            25: {776                'my_summary': 4.0777            },778        })779  def test_multiple_summaries(self):780    hook = basic_session_run_hooks.SummarySaverHook(781        save_steps=8,782        summary_writer=self.summary_writer,783        summary_op=[self.summary_op, self.summary_op2])784    with self.test_session() as sess:785      hook.begin()786      sess.run(variables_lib.global_variables_initializer())787      mon_sess = monitored_session._HookedSession(sess, [hook])788      for _ in range(10):789        mon_sess.run(self.train_op)790      hook.end(sess)791    self.summary_writer.assert_summaries(792        test_case=self,793        expected_logdir=self.log_dir,794        expected_summaries={795            1: {796                'my_summary': 1.0,797                'my_summary2': 2.0798            },799            9: {800                'my_summary': 2.0,801                'my_summary2': 4.0802            },803        })804  def test_save_secs_saving_once_every_step(self):805    hook = basic_session_run_hooks.SummarySaverHook(806        save_secs=0.5,807        summary_writer=self.summary_writer,808        summary_op=self.summary_op)809    with self.test_session() as sess:810      hook.begin()811      sess.run(variables_lib.global_variables_initializer())812      mon_sess = monitored_session._HookedSession(sess, [hook])813      for _ in range(4):814        mon_sess.run(self.train_op)815        time.sleep(0.5)816      hook.end(sess)817    self.summary_writer.assert_summaries(818        test_case=self,819        expected_logdir=self.log_dir,820        expected_summaries={821            1: {822                'my_summary': 1.0823            },824            2: {825                'my_summary': 2.0826            },827            3: {828                'my_summary': 3.0829            },830            4: {831                'my_summary': 4.0832          
  },833        })834  def test_save_secs_saving_once_every_three_steps(self):835    hook = basic_session_run_hooks.SummarySaverHook(836        save_secs=0.9,837        summary_writer=self.summary_writer,838        summary_op=self.summary_op)839    with self.test_session() as sess:840      hook.begin()841      sess.run(variables_lib.global_variables_initializer())842      mon_sess = monitored_session._HookedSession(sess, [hook])843      for _ in range(8):844        mon_sess.run(self.train_op)845        time.sleep(0.3)846      hook.end(sess)847    self.summary_writer.assert_summaries(848        test_case=self,849        expected_logdir=self.log_dir,850        expected_summaries={851            1: {852                'my_summary': 1.0853            },854            4: {855                'my_summary': 2.0856            },857            7: {858                'my_summary': 3.0859            },860        })861class GlobalStepWaiterHookTest(test.TestCase):862  def test_not_wait_for_step_zero(self):863    with ops.Graph().as_default():864      variables.get_or_create_global_step()865      hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=0)866      hook.begin()867      with session_lib.Session() as sess:868        # Before run should return without waiting gstep increment.869        hook.before_run(870            session_run_hook.SessionRunContext(871                original_args=None, session=sess))872  def test_wait_for_step(self):873    with ops.Graph().as_default():874      gstep = variables.get_or_create_global_step()875      hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=1000)876      hook.begin()877      with session_lib.Session() as sess:878        sess.run(variables_lib.global_variables_initializer())879        waiter = threading.Thread(880            target=hook.before_run,881            args=(session_run_hook.SessionRunContext(882                original_args=None, session=sess),))883        waiter.daemon = True884        
waiter.start()885        time.sleep(1.0)886        self.assertTrue(waiter.is_alive())887        sess.run(state_ops.assign(gstep, 500))888        time.sleep(1.0)889        self.assertTrue(waiter.is_alive())890        sess.run(state_ops.assign(gstep, 1100))891        time.sleep(1.2)892        self.assertFalse(waiter.is_alive())893class FinalOpsHookTest(test.TestCase):894  def test_final_ops_is_scalar_tensor(self):895    with ops.Graph().as_default():896      expected_value = 4897      final_ops = constant_op.constant(expected_value)898      hook = basic_session_run_hooks.FinalOpsHook(final_ops)899      hook.begin()900      with session_lib.Session() as session:901        hook.end(session)902        self.assertEqual(expected_value,903                         hook.final_ops_values)904  def test_final_ops_is_tensor(self):905    with ops.Graph().as_default():906      expected_values = [1, 6, 3, 5, 2, 4]907      final_ops = constant_op.constant(expected_values)908      hook = basic_session_run_hooks.FinalOpsHook(final_ops)909      hook.begin()910      with session_lib.Session() as session:911        hook.end(session)912        self.assertListEqual(expected_values,913                             hook.final_ops_values.tolist())914  def test_final_ops_with_dictionary(self):915    with ops.Graph().as_default():916      expected_values = [4, -3]917      final_ops = array_ops.placeholder(dtype=dtypes.float32)918      final_ops_feed_dict = {final_ops: expected_values}919      hook = basic_session_run_hooks.FinalOpsHook(920          final_ops, final_ops_feed_dict)921      hook.begin()922      with session_lib.Session() as session:923        hook.end(session)924        self.assertListEqual(expected_values,925                             hook.final_ops_values.tolist())926class ResourceSummarySaverHookTest(test.TestCase):927  def setUp(self):928    test.TestCase.setUp(self)929    self.log_dir = 'log/dir'930    self.summary_writer = 
fake_summary_writer.FakeSummaryWriter(self.log_dir)931    var = variable_scope.get_variable('var', initializer=0.0, use_resource=True)932    tensor = state_ops.assign_add(var, 1.0)933    self.summary_op = summary_lib.scalar('my_summary', tensor)934    with variable_scope.variable_scope('foo', use_resource=True):935      global_step = variables.get_or_create_global_step()936    self.train_op = state_ops.assign_add(global_step, 1)937  def test_save_steps(self):938    hook = basic_session_run_hooks.SummarySaverHook(939        save_steps=8,940        summary_writer=self.summary_writer,941        summary_op=self.summary_op)942    with self.test_session() as sess:943      hook.begin()944      sess.run(variables_lib.global_variables_initializer())945      mon_sess = monitored_session._HookedSession(sess, [hook])946      for _ in range(30):947        mon_sess.run(self.train_op)948      hook.end(sess)949    self.summary_writer.assert_summaries(950        test_case=self,951        expected_logdir=self.log_dir,952        expected_summaries={953            1: {954                'my_summary': 1.0955            },956            9: {957                'my_summary': 2.0958            },959            17: {960                'my_summary': 3.0961            },962            25: {963                'my_summary': 4.0964            },965        })966class FeedFnHookTest(test.TestCase):967  def test_feeding_placeholder(self):968    with ops.Graph().as_default(), session_lib.Session() as sess:969      x = array_ops.placeholder(dtype=dtypes.float32)970      y = x + 1971      hook = basic_session_run_hooks.FeedFnHook(972          feed_fn=lambda: {x: 1.0})973      hook.begin()974      mon_sess = monitored_session._HookedSession(sess, [hook])975      self.assertEqual(mon_sess.run(y), 2)976if __name__ == '__main__':...local_cli_wrapper_test.py
Source:local_cli_wrapper_test.py  
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.2#3# Licensed under the Apache License, Version 2.0 (the "License");4# you may not use this file except in compliance with the License.5# You may obtain a copy of the License at6#7#     http://www.apache.org/licenses/LICENSE-2.08#9# Unless required by applicable law or agreed to in writing, software10# distributed under the License is distributed on an "AS IS" BASIS,11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.12# See the License for the specific language governing permissions and13# limitations under the License.14# ==============================================================================15"""Unit tests for local command-line-interface debug wrapper session."""16from __future__ import absolute_import17from __future__ import division18from __future__ import print_function19import os20import shutil21import tempfile22from tensorflow.core.protobuf import config_pb223from tensorflow.python.client import session24from tensorflow.python.debug.cli import cli_shared25from tensorflow.python.debug.cli import debugger_cli_common26from tensorflow.python.debug.wrappers import local_cli_wrapper27from tensorflow.python.framework import constant_op28from tensorflow.python.framework import dtypes29from tensorflow.python.framework import errors30from tensorflow.python.framework import ops31from tensorflow.python.framework import test_util32from tensorflow.python.ops import array_ops33from tensorflow.python.ops import control_flow_ops34from tensorflow.python.ops import math_ops35# Import resource_variable_ops for the variables-to-tensor implicit conversion.36from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import37from tensorflow.python.ops import state_ops38from tensorflow.python.ops import variables39from tensorflow.python.platform import googletest40class LocalCLIDebuggerWrapperSessionForTest(41    local_cli_wrapper.LocalCLIDebugWrapperSession):42  
"""Subclasses the wrapper class for testing.43  Overrides its CLI-related methods for headless testing environments.44  Inserts observer variables for assertions.45  """46  def __init__(self,47               command_args_sequence,48               sess,49               dump_root=None):50    """Constructor of the for-test subclass.51    Args:52      command_args_sequence: (list of list of str) A list of arguments for the53        "run" command.54      sess: See the doc string of LocalCLIDebugWrapperSession.__init__.55      dump_root: See the doc string of LocalCLIDebugWrapperSession.__init__.56    """57    local_cli_wrapper.LocalCLIDebugWrapperSession.__init__(58        self, sess, dump_root=dump_root, log_usage=False)59    self._command_args_sequence = command_args_sequence60    self._response_pointer = 061    # Observer variables.62    self.observers = {63        "debug_dumps": [],64        "tf_errors": [],65        "run_start_cli_run_numbers": [],66        "run_end_cli_run_numbers": [],67        "profiler_py_graphs": [],68        "profiler_run_metadata": [],69    }70  def _prep_cli_for_run_start(self):71    pass72  def _prep_debug_cli_for_run_end(self, debug_dump, tf_error, passed_filter):73    self.observers["debug_dumps"].append(debug_dump)74    self.observers["tf_errors"].append(tf_error)75  def _prep_profile_cli_for_run_end(self, py_graph, run_metadata):76    self.observers["profiler_py_graphs"].append(py_graph)77    self.observers["profiler_run_metadata"].append(run_metadata)78  def _launch_cli(self):79    if self._is_run_start:80      self.observers["run_start_cli_run_numbers"].append(self._run_call_count)81    else:82      self.observers["run_end_cli_run_numbers"].append(self._run_call_count)83    command_args = self._command_args_sequence[self._response_pointer]84    self._response_pointer += 185    try:86      self._run_handler(command_args)87    except debugger_cli_common.CommandLineExit as e:88      response = e.exit_token89    return response90class 
LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):91  def setUp(self):92    self._tmp_dir = tempfile.mktemp()93    self.v = variables.Variable(10.0, name="v")94    self.w = variables.Variable(21.0, name="w")95    self.delta = constant_op.constant(1.0, name="delta")96    self.inc_v = state_ops.assign_add(self.v, self.delta, name="inc_v")97    self.w_int = control_flow_ops.with_dependencies(98        [self.inc_v],99        math_ops.cast(self.w, dtypes.int32, name="w_int_inner"),100        name="w_int_outer")101    self.ph = array_ops.placeholder(dtypes.float32, name="ph")102    self.xph = array_ops.transpose(self.ph, name="xph")103    self.m = constant_op.constant(104        [[0.0, 1.0, 2.0], [-4.0, -1.0, 0.0]], dtype=dtypes.float32, name="m")105    self.y = math_ops.matmul(self.m, self.xph, name="y")106    self.sess = session.Session()107    # Initialize variable.108    self.sess.run(variables.global_variables_initializer())109  def tearDown(self):110    ops.reset_default_graph()111    if os.path.isdir(self._tmp_dir):112      shutil.rmtree(self._tmp_dir)113  def testConstructWrapper(self):114    local_cli_wrapper.LocalCLIDebugWrapperSession(115        session.Session(), log_usage=False)116  def testConstructWrapperWithExistingEmptyDumpRoot(self):117    os.mkdir(self._tmp_dir)118    self.assertTrue(os.path.isdir(self._tmp_dir))119    local_cli_wrapper.LocalCLIDebugWrapperSession(120        session.Session(), dump_root=self._tmp_dir, log_usage=False)121  def testConstructWrapperWithExistingNonEmptyDumpRoot(self):122    os.mkdir(self._tmp_dir)123    dir_path = os.path.join(self._tmp_dir, "foo")124    os.mkdir(dir_path)125    self.assertTrue(os.path.isdir(dir_path))126    with self.assertRaisesRegexp(127        ValueError, "dump_root path points to a non-empty directory"):128      local_cli_wrapper.LocalCLIDebugWrapperSession(129          session.Session(), dump_root=self._tmp_dir, log_usage=False)130  def 
testConstructWrapperWithExistingFileDumpRoot(self):131    os.mkdir(self._tmp_dir)132    file_path = os.path.join(self._tmp_dir, "foo")133    open(file_path, "a").close()  # Create the file134    self.assertTrue(os.path.isfile(file_path))135    with self.assertRaisesRegexp(ValueError, "dump_root path points to a file"):136      local_cli_wrapper.LocalCLIDebugWrapperSession(137          session.Session(), dump_root=file_path, log_usage=False)138  def testRunsUnderDebugMode(self):139    # Test command sequence: run; run; run;140    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(141        [[], [], []], self.sess, dump_root=self._tmp_dir)142    # run under debug mode twice.143    wrapped_sess.run(self.inc_v)144    wrapped_sess.run(self.inc_v)145    # Verify that the assign_add op did take effect.146    self.assertAllClose(12.0, self.sess.run(self.v))147    # Assert correct run call numbers for which the CLI has been launched at148    # run-start and run-end.149    self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])150    self.assertEqual([1, 2], wrapped_sess.observers["run_end_cli_run_numbers"])151    # Verify that the dumps have been generated and picked up during run-end.152    self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"]))153    # Verify that the TensorFlow runtime errors are picked up and in this case,154    # they should be both None.155    self.assertEqual([None, None], wrapped_sess.observers["tf_errors"])156  def testRunsWithEmptyStringDumpRootWorks(self):157    # Test command sequence: run, run158    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(159        [[], []], self.sess, dump_root="")160    # run under debug mode.161    wrapped_sess.run(self.inc_v)162    self.assertAllClose(11.0, self.sess.run(self.v))163  def testRunInfoOutputAtRunEndIsCorrect(self):164    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(165        [[], [], []], self.sess, dump_root=self._tmp_dir)166    
wrapped_sess.run(self.inc_v)167    run_info_output = wrapped_sess._run_info_handler([])168    tfdbg_logo = cli_shared.get_tfdbg_logo()169    # The run_info output in the first run() call should contain the tfdbg logo.170    self.assertEqual(tfdbg_logo.lines,171                     run_info_output.lines[:len(tfdbg_logo.lines)])172    menu = run_info_output.annotations[debugger_cli_common.MAIN_MENU_KEY]173    self.assertIn("list_tensors", menu.captions())174    wrapped_sess.run(self.inc_v)175    run_info_output = wrapped_sess._run_info_handler([])176    # The run_info output in the second run() call should NOT contain the logo.177    self.assertNotEqual(tfdbg_logo.lines,178                        run_info_output.lines[:len(tfdbg_logo.lines)])179    menu = run_info_output.annotations[debugger_cli_common.MAIN_MENU_KEY]180    self.assertIn("list_tensors", menu.captions())181  def testRunsUnderNonDebugMode(self):182    # Test command sequence: run -n; run -n; run -n;183    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(184        [["-n"], ["-n"], ["-n"]], self.sess, dump_root=self._tmp_dir)185    # run three times.186    wrapped_sess.run(self.inc_v)187    wrapped_sess.run(self.inc_v)188    wrapped_sess.run(self.inc_v)189    self.assertAllClose(13.0, self.sess.run(self.v))190    self.assertEqual([1, 2, 3],191                     wrapped_sess.observers["run_start_cli_run_numbers"])192    self.assertEqual([], wrapped_sess.observers["run_end_cli_run_numbers"])193  def testRunsUnderNonDebugThenDebugMode(self):194    # Test command sequence: run -n; run -n; run; run;195    # Do two NON_DEBUG_RUNs, followed by DEBUG_RUNs.196    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(197        [["-n"], ["-n"], [], []], self.sess, dump_root=self._tmp_dir)198    # run three times.199    wrapped_sess.run(self.inc_v)200    wrapped_sess.run(self.inc_v)201    wrapped_sess.run(self.inc_v)202    self.assertAllClose(13.0, self.sess.run(self.v))203    self.assertEqual([1, 2, 3],204    
                 wrapped_sess.observers["run_start_cli_run_numbers"])205    # Here, the CLI should have been launched only under the third run,206    # because the first and second runs are NON_DEBUG.207    self.assertEqual([3], wrapped_sess.observers["run_end_cli_run_numbers"])208    self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))209    self.assertEqual([None], wrapped_sess.observers["tf_errors"])210  def testRunMultipleTimesWithinLimit(self):211    # Test command sequence: run -t 3; run;212    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(213        [["-t", "3"], []], self.sess, dump_root=self._tmp_dir)214    # run three times.215    wrapped_sess.run(self.inc_v)216    wrapped_sess.run(self.inc_v)217    wrapped_sess.run(self.inc_v)218    self.assertAllClose(13.0, self.sess.run(self.v))219    self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])220    self.assertEqual([3], wrapped_sess.observers["run_end_cli_run_numbers"])221    self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))222    self.assertEqual([None], wrapped_sess.observers["tf_errors"])223  def testRunMultipleTimesOverLimit(self):224    # Test command sequence: run -t 3;225    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(226        [["-t", "3"]], self.sess, dump_root=self._tmp_dir)227    # run twice, which is less than the number of times specified by the228    # command.229    wrapped_sess.run(self.inc_v)230    wrapped_sess.run(self.inc_v)231    self.assertAllClose(12.0, self.sess.run(self.v))232    self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])233    self.assertEqual([], wrapped_sess.observers["run_end_cli_run_numbers"])234    self.assertEqual(0, len(wrapped_sess.observers["debug_dumps"]))235    self.assertEqual([], wrapped_sess.observers["tf_errors"])236  def testRunMixingDebugModeAndMultpleTimes(self):237    # Test command sequence: run -n; run -t 2; run; run;238    wrapped_sess = 
LocalCLIDebuggerWrapperSessionForTest(239        [["-n"], ["-t", "2"], [], []], self.sess, dump_root=self._tmp_dir)240    # run four times.241    wrapped_sess.run(self.inc_v)242    wrapped_sess.run(self.inc_v)243    wrapped_sess.run(self.inc_v)244    wrapped_sess.run(self.inc_v)245    self.assertAllClose(14.0, self.sess.run(self.v))246    self.assertEqual([1, 2],247                     wrapped_sess.observers["run_start_cli_run_numbers"])248    self.assertEqual([3, 4], wrapped_sess.observers["run_end_cli_run_numbers"])249    self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"]))250    self.assertEqual([None, None], wrapped_sess.observers["tf_errors"])251  def testDebuggingMakeCallableTensorRunnerWorks(self):252    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(253        [[], []], self.sess, dump_root=self._tmp_dir)254    v = variables.Variable(42)255    tensor_runner = wrapped_sess.make_callable(v)256    self.sess.run(v.initializer)257    self.assertAllClose(42, tensor_runner())258    self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))259  def testDebuggingMakeCallableTensorRunnerWithCustomRunOptionsWorks(self):260    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(261        [[], []], self.sess, dump_root=self._tmp_dir)262    a = constant_op.constant(42)263    tensor_runner = wrapped_sess.make_callable(a)264    run_options = config_pb2.RunOptions(265        trace_level=config_pb2.RunOptions.FULL_TRACE)266    run_metadata = config_pb2.RunMetadata()267    self.assertAllClose(268        42, tensor_runner(options=run_options, run_metadata=run_metadata))269    self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))270    self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)271  def testDebuggingMakeCallableOperationRunnerWorks(self):272    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(273        [[], []], self.sess, dump_root=self._tmp_dir)274    v = variables.Variable(10.0)275    inc_v = state_ops.assign_add(v, 
1.0)276    op_runner = wrapped_sess.make_callable(inc_v.op)277    self.sess.run(v.initializer)278    op_runner()279    self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))280    self.assertEqual(11.0, self.sess.run(v))281  def testDebuggingMakeCallableRunnerWithFeedListWorks(self):282    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(283        [[], []], self.sess, dump_root=self._tmp_dir)284    ph1 = array_ops.placeholder(dtypes.float32)285    ph2 = array_ops.placeholder(dtypes.float32)286    a = math_ops.add(ph1, ph2)287    tensor_runner = wrapped_sess.make_callable(a, feed_list=[ph1, ph2])288    self.assertAllClose(42.0, tensor_runner(41.0, 1.0))289    self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))290  def testRuntimeErrorShouldBeCaught(self):291    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(292        [[], []], self.sess, dump_root=self._tmp_dir)293    # Do a run that should lead to an TensorFlow runtime error.294    wrapped_sess.run(self.y, feed_dict={self.ph: [[0.0], [1.0], [2.0]]})295    self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])296    self.assertEqual([1], wrapped_sess.observers["run_end_cli_run_numbers"])297    self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))298    # Verify that the runtime error is caught by the wrapped session properly.299    self.assertEqual(1, len(wrapped_sess.observers["tf_errors"]))300    tf_error = wrapped_sess.observers["tf_errors"][0]301    self.assertEqual("y", tf_error.op.name)302  def testRuntimeErrorBeforeGraphExecutionIsRaised(self):303    # Use an impossible device name to cause an error before graph execution.304    with ops.device("/gpu:1337"):305      w = variables.Variable([1.0] * 10, name="w")306    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(307        [[]], self.sess, dump_root=self._tmp_dir)308    with self.assertRaisesRegexp(errors.OpError, r".*[Dd]evice.*1337.*"):309      wrapped_sess.run(w)310  def 
testRunTillFilterPassesShouldLaunchCLIAtCorrectRun(self):311    # Test command sequence:312    #   run -f greater_than_twelve; run -f greater_than_twelve; run;313    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(314        [["-f", "v_greater_than_twelve"], ["-f", "v_greater_than_twelve"], []],315        self.sess,316        dump_root=self._tmp_dir)317    def v_greater_than_twelve(datum, tensor):318      return datum.node_name == "v" and tensor > 12.0319    wrapped_sess.add_tensor_filter("v_greater_than_twelve",320                                   v_greater_than_twelve)321    # run five times.322    wrapped_sess.run(self.inc_v)323    wrapped_sess.run(self.inc_v)324    wrapped_sess.run(self.inc_v)325    wrapped_sess.run(self.inc_v)326    wrapped_sess.run(self.inc_v)327    self.assertAllClose(15.0, self.sess.run(self.v))328    self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])329    # run-end CLI should NOT have been launched for run #2 and #3, because only330    # starting from run #4 v becomes greater than 12.0.331    self.assertEqual([4, 5], wrapped_sess.observers["run_end_cli_run_numbers"])332    self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"]))333    self.assertEqual([None, None], wrapped_sess.observers["tf_errors"])334  def testRunsUnderDebugModeWithWatchFnFilteringNodeNames(self):335    # Test command sequence:336    #   run --node_name_filter inc.*337    #   run --node_name_filter delta338    #   run339    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(340        [["--node_name_filter", "inc.*"], ["--node_name_filter", "delta"], []],341        self.sess, dump_root=self._tmp_dir)342    # run under debug mode twice.343    wrapped_sess.run(self.inc_v)344    wrapped_sess.run(self.inc_v)345    # Verify that the assign_add op did take effect.346    self.assertAllClose(12.0, self.sess.run(self.v))347    # Verify that the dumps have been generated and picked up during run-end.348    self.assertEqual(2, 
len(wrapped_sess.observers["debug_dumps"]))349    dumps = wrapped_sess.observers["debug_dumps"][0]350    self.assertEqual(1, dumps.size)351    self.assertEqual("inc_v", dumps.dumped_tensor_data[0].node_name)352    dumps = wrapped_sess.observers["debug_dumps"][1]353    self.assertEqual(1, dumps.size)354    self.assertEqual("delta", dumps.dumped_tensor_data[0].node_name)355  def testRunsUnderDebugModeWithWatchFnFilteringOpTypes(self):356    # Test command sequence:357    #   run --node_name_filter delta358    #   run --op_type_filter AssignAdd359    #   run360    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(361        [["--node_name_filter", "delta"],362         ["--op_type_filter", "AssignAdd"],363         []],364        self.sess, dump_root=self._tmp_dir)365    # run under debug mode twice.366    wrapped_sess.run(self.inc_v)367    wrapped_sess.run(self.inc_v)368    # Verify that the assign_add op did take effect.369    self.assertAllClose(12.0, self.sess.run(self.v))370    # Verify that the dumps have been generated and picked up during run-end.371    self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"]))372    dumps = wrapped_sess.observers["debug_dumps"][0]373    self.assertEqual(1, dumps.size)374    self.assertEqual("delta", dumps.dumped_tensor_data[0].node_name)375    dumps = wrapped_sess.observers["debug_dumps"][1]376    self.assertEqual(1, dumps.size)377    self.assertEqual("inc_v", dumps.dumped_tensor_data[0].node_name)378  def testRunsUnderDebugModeWithWatchFnFilteringTensorDTypes(self):379    # Test command sequence:380    #   run --op_type_filter Variable.*381    #   run --dtype_filter int32382    #   run383    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(384        [["--op_type_filter", "Variable.*"],385         ["--tensor_dtype_filter", "int32"], []],386        self.sess, dump_root=self._tmp_dir)387    # run under debug mode twice.388    wrapped_sess.run(self.w_int)389    wrapped_sess.run(self.w_int)390    # Verify that the 
dumps have been generated and picked up during run-end.391    self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"]))392    dumps = wrapped_sess.observers["debug_dumps"][0]393    self.assertEqual(2, dumps.size)394    self.assertItemsEqual(395        ["v", "w"], [dumps.dumped_tensor_data[i].node_name for i in [0, 1]])396    dumps = wrapped_sess.observers["debug_dumps"][1]397    self.assertEqual(2, dumps.size)398    self.assertEqual(399        ["w_int_inner", "w_int_outer"],400        [dumps.dumped_tensor_data[i].node_name for i in [0, 1]])401  def testRunsUnderDebugModeWithWatchFnFilteringOpTypesAndTensorDTypes(self):402    # Test command sequence:403    #   run --op_type_filter Cast --dtype_filter int32404    #   run405    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(406        [["--op_type_filter", "Cast", "--tensor_dtype_filter", "int32"], []],407        self.sess, dump_root=self._tmp_dir)408    # run under debug mode twice.409    wrapped_sess.run(self.w_int)410    # Verify that the dumps have been generated and picked up during run-end.411    self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))412    dumps = wrapped_sess.observers["debug_dumps"][0]413    self.assertEqual(1, dumps.size)414    self.assertEqual("w_int_inner", dumps.dumped_tensor_data[0].node_name)415  def testRunUnderProfilerModeWorks(self):416    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(417        [["-p"], []], self.sess)418    wrapped_sess.run(self.w_int)419    self.assertEqual(1, len(wrapped_sess.observers["profiler_run_metadata"]))420    self.assertTrue(421        wrapped_sess.observers["profiler_run_metadata"][0].step_stats)422    self.assertEqual(1, len(wrapped_sess.observers["profiler_py_graphs"]))423    self.assertIsInstance(424        wrapped_sess.observers["profiler_py_graphs"][0], ops.Graph)425if __name__ == "__main__":...backend.py
Source:backend.py  
"""Core backend: logins, contact lists, groups, presence notifications, chats."""
import asyncio
import time
import traceback
from collections import defaultdict
from enum import IntFlag

from util.misc import gen_uuid, EMPTY_SET, run_loop

from .user import UserService
from .auth import AuthService
from .stats import Stats
from .models import User, Group, Lst, Contact, UserStatus
from . import error, event


class Ack(IntFlag):
    """Acknowledgement flags; `Full` is the combination of `NAK` and `ACK`."""
    Zero = 0
    NAK = 1
    ACK = 2
    Full = 3


class Backend:
    """Holds logged-in sessions, cached users and open chats.

    Three background tasks are scheduled on `loop` at construction time:
    database sync, session cleanup and stats flushing.
    """

    def __init__(self, loop, *, user_service=None, auth_service=None):
        """Create the backend and schedule its background tasks on `loop`.

        `user_service` / `auth_service` default to freshly constructed
        `UserService()` / `AuthService()` when not supplied.
        """
        self._loop = loop
        self._user_service = user_service or UserService()
        self._auth_service = auth_service or AuthService()
        self._stats = Stats()

        self._sc = _SessionCollection()
        # Dict[User.uuid, User]
        self._user_by_uuid = {}
        # Dict[User, UserDetail]
        self._unsynced_db = {}

        # Dict[chatid, Chat]
        self._chats = {}

        self._runners = []

        loop.create_task(self._sync_db())
        loop.create_task(self._clean_sessions())
        loop.create_task(self._sync_stats())

    def add_runner(self, runner):
        """Register `runner` to be passed to `run_loop` by `run_forever`."""
        self._runners.append(runner)

    def run_forever(self):
        """Hand the loop and all registered runners to `run_loop`."""
        run_loop(self._loop, self._runners)

    def on_leave(self, sess):
        """Unregister `sess`; if it was the user's last session, notify contacts."""
        user = sess.user
        if user is None:
            return
        self._stats.on_logout()
        self._sc.remove_session(sess)
        if self._sc.get_sessions_by_user(user):
            # There are still other people logged in as this user,
            # so don't send offline notifications.
            return
        # User is offline, send notifications
        user.detail = None
        self._sync_contact_statuses()
        self._generic_notify(sess)

    def login_md5_get_salt(self, email):
        """Return the MD5 salt the user service holds for `email`."""
        return self._user_service.get_md5_salt(email)

    def login_md5_verify(self, sess, email, md5_hash):
        """Log `sess` in via MD5 hash; return the `User`, or None on failure."""
        uuid = self._user_service.login_md5(email, md5_hash)
        return self._login_common(sess, uuid, email)

    def login_twn_start(self, email, password):
        """Check email/password; return an 'nb/login' token, or None on failure."""
        uuid = self._user_service.login(email, password)
        if uuid is None:
            return None
        return self._auth_service.create_token('nb/login', uuid)

    def login_twn_verify(self, sess, email, token):
        """Redeem an 'nb/login' token and log `sess` in; None on failure."""
        uuid = self._auth_service.pop_token('nb/login', token)
        return self._login_common(sess, uuid, email)

    def login_IKWIAD(self, sess, email):
        """Log `sess` in by email alone; no credential check is performed here."""
        uuid = self.util_get_uuid_from_email(email)
        return self._login_common(sess, uuid, email)

    def _login_common(self, sess, uuid, email):
        """Attach the user identified by `uuid` to `sess` and register the session.

        Returns the `User`, or None when `uuid` is None or no user record
        can be loaded for it.
        """
        if uuid is None:
            return None
        self._user_service.update_date_login(uuid)
        user = self._load_user_record(uuid)
        if user is None:
            # `_load_user_record` returns None for an unknown uuid; without this
            # guard the code below would fail on `user.detail`.
            return None
        sess.user = user
        self._stats.on_login()
        self._stats.on_user_active(user, sess.client)
        self._sc.add_session(sess)
        user.detail = self._load_detail(user)
        return user

    def _load_user_record(self, uuid):
        """Return the cached `User` for `uuid`, fetching it on first use."""
        if uuid not in self._user_by_uuid:
            user = self._user_service.get(uuid)
            if user is None:
                return None
            self._user_by_uuid[uuid] = user
        return self._user_by_uuid[uuid]

    def _load_detail(self, user):
        """Return `user.detail` if set, else fetch it from the user service."""
        if user.detail:
            return user.detail
        return self._user_service.get_detail(user.uuid)

    def _generic_notify(self, sess):
        """Send a presence notification about `sess.user` to other sessions."""
        # Notify relevant `Session`s of status, name, message, media
        user = sess.user
        if user is None:
            return
        # TODO: This does a lot of work, iterating through _every_ session.
        # If RL is set up properly, could iterate through `user.detail.contacts`.
        for sess_other in self._sc.iter_sessions():
            if sess_other == sess:
                continue
            user_other = sess_other.user
            if user_other is None:
                continue
            if user_other.detail is None:
                continue
            ctc = user_other.detail.contacts.get(user.uuid)
            if ctc is None:
                continue
            sess_other.send_event(event.PresenceNotificationEvent(ctc))

    def _sync_contact_statuses(self):
        """Recompute the visible status of every contact of every cached user."""
        # Recompute all `Contact.status`'s
        for user in self._user_by_uuid.values():
            detail = user.detail
            if detail is None:
                continue
            for ctc in detail.contacts.values():
                ctc.compute_visible_status(user)

    def _mark_modified(self, user, *, detail=None):
        """Queue `user`'s detail for the next database sync.

        `detail` must be passed when `user.detail` is not set; if both are
        present they must be the same object.
        """
        ud = user.detail or detail
        if detail:
            assert ud is detail
        assert ud is not None
        self._unsynced_db[user] = ud

    def sb_token_create(self, sess, *, extra_data=None):
        """Create an 'sb/xfr' token carrying the user's uuid and `extra_data`.

        Note: the session's client is stored under `extra_data['client']`,
        so a dict passed in by the caller is modified.
        """
        if extra_data is None:
            extra_data = {}
        extra_data['client'] = sess.client
        return self._auth_service.create_token(
            'sb/xfr', {'uuid': sess.user.uuid, 'extra_data': extra_data},
        )

    def me_update(self, sess, fields):
        """Apply the recognised keys of `fields` to the user, then notify others."""
        user = sess.user

        if 'message' in fields:
            user.status.message = fields['message']
        if 'media' in fields:
            user.status.media = fields['media']
        if 'name' in fields:
            user.status.name = fields['name']
        if 'gtc' in fields:
            user.detail.settings['gtc'] = fields['gtc']
        if 'blp' in fields:
            user.detail.settings['blp'] = fields['blp']
        if 'substatus' in fields:
            user.status.substatus = fields['substatus']

        self._mark_modified(user)
        self._sync_contact_statuses()
        self._generic_notify(sess)

    def me_group_add(self, sess, name, *, is_favorite=None):
        """Create a group called `name` for the user and return it.

        Raises `error.GroupNameTooLong` when `name` exceeds the limit.
        """
        # NOTE(review): MAX_GROUP_NAME_LENGTH is neither defined nor imported in
        # the visible part of this module -- confirm where it comes from.
        if len(name) > MAX_GROUP_NAME_LENGTH:
            raise error.GroupNameTooLong()
        user = sess.user
        group = Group(_gen_group_id(user.detail), name, is_favorite=is_favorite)
        user.detail.groups[group.id] = group
        self._mark_modified(user)
        return group

    def me_group_remove(self, sess, group_id):
        """Delete a group and drop it from every contact's group set.

        Raises `error.CannotRemoveSpecialGroup` for group '0' and
        `error.GroupDoesNotExist` for an unknown id.
        """
        if group_id == '0':
            raise error.CannotRemoveSpecialGroup()
        user = sess.user
        try:
            del user.detail.groups[group_id]
        except KeyError:
            raise error.GroupDoesNotExist()
        for ctc in user.detail.contacts.values():
            ctc.groups.discard(group_id)
        self._mark_modified(user)

    def me_group_edit(self, sess, group_id, new_name, *, is_favorite=None):
        """Rename a group and/or change its favourite flag.

        Raises `error.GroupDoesNotExist` or `error.GroupNameTooLong`.
        """
        user = sess.user
        g = user.detail.groups.get(group_id)
        if g is None:
            raise error.GroupDoesNotExist()
        if new_name is not None:
            # NOTE(review): see `me_group_add` regarding MAX_GROUP_NAME_LENGTH.
            if len(new_name) > MAX_GROUP_NAME_LENGTH:
                raise error.GroupNameTooLong()
            # `Group` is constructed with a `name`; assigning `new_name` would
            # only create an unrelated attribute and leave the group unrenamed.
            g.name = new_name
        if is_favorite is not None:
            g.is_favorite = is_favorite
        self._mark_modified(user)

    def me_group_contact_add(self, sess, group_id, contact_uuid):
        """Put a contact into a group; group '0' is silently ignored.

        Raises `error.GroupDoesNotExist`, `error.ContactDoesNotExist` or
        `error.ContactAlreadyOnList`.
        """
        if group_id == '0':
            return
        user = sess.user
        detail = user.detail
        if group_id not in detail.groups:
            raise error.GroupDoesNotExist()
        ctc = detail.contacts.get(contact_uuid)
        if ctc is None:
            raise error.ContactDoesNotExist()
        if group_id in ctc.groups:
            raise error.ContactAlreadyOnList()
        ctc.groups.add(group_id)
        self._mark_modified(user)

    def me_group_contact_remove(self, sess, group_id, contact_uuid):
        """Take a contact out of a group.

        Raises `error.ContactDoesNotExist` or `error.GroupDoesNotExist`;
        raises `error.ContactNotOnList` only when `group_id` is '0' and the
        contact is not in it. For any other group a missing membership is
        ignored.
        """
        user = sess.user
        detail = user.detail
        ctc = detail.contacts.get(contact_uuid)
        if ctc is None:
            raise error.ContactDoesNotExist()
        if group_id not in detail.groups and group_id != '0':
            raise error.GroupDoesNotExist()
        try:
            ctc.groups.remove(group_id)
        except KeyError:
            if group_id == '0':
                raise error.ContactNotOnList()
        self._mark_modified(user)

    def me_contact_add(self, sess, contact_uuid, lst, name):
        """Add a user to one of the caller's lists.

        Returns `(contact, contact_user)`. Raises `error.UserDoesNotExist`
        when no user record exists for `contact_uuid`.
        """
        ctc_head = self._load_user_record(contact_uuid)
        if ctc_head is None:
            raise error.UserDoesNotExist()
        user = sess.user
        ctc = self._add_to_list(user, ctc_head, lst, name)
        if lst is Lst.FL:
            # FL needs a matching RL on the contact
            self._add_to_list(ctc_head, user, Lst.RL, user.status.name)
            self._notify_reverse_add(sess, ctc_head)
        self._sync_contact_statuses()
        self._generic_notify(sess)
        return ctc, ctc_head

    def _notify_reverse_add(self, sess, user_added):
        """Tell `user_added`'s sessions that `sess.user` was added to their RL."""
        user_adder = sess.user
        # `user_added` was added to `user_adder`'s RL
        for sess_added in self._sc.get_sessions_by_user(user_added):
            if sess_added == sess:
                continue
            sess_added.send_event(event.AddedToListEvent(Lst.RL, user_adder))

    def me_contact_edit(self, sess, contact_uuid, *, is_messenger_user=None):
        """Update a contact's `is_messenger_user` flag when one is given.

        Raises `error.ContactDoesNotExist` for an unknown contact.
        """
        user = sess.user
        ctc = user.detail.contacts.get(contact_uuid)
        if ctc is None:
            raise error.ContactDoesNotExist()
        if is_messenger_user is not None:
            ctc.is_messenger_user = is_messenger_user
        self._mark_modified(user)

    def me_contact_remove(self, sess, contact_uuid, lst):
        """Remove a contact from `lst`; removing from FL also removes the matching RL.

        Raises `error.ContactDoesNotExist` for an unknown contact.
        """
        user = sess.user
        ctc = user.detail.contacts.get(contact_uuid)
        if ctc is None:
            raise error.ContactDoesNotExist()
        if lst is Lst.FL:
            # Remove from FL
            self._remove_from_list(user, ctc.head, Lst.FL)
            # Remove matching RL
            self._remove_from_list(ctc.head, user, Lst.RL)
        else:
            assert lst is not Lst.RL
            ctc.lists &= ~lst
        self._mark_modified(user)
        self._sync_contact_statuses()

    def _add_to_list(self, user, ctc_head, lst, name):
        """Add `ctc_head` to `user`'s `lst`, creating the contact if needed."""
        detail = self._load_detail(user)
        contacts = detail.contacts
        if ctc_head.uuid not in contacts:
            contacts[ctc_head.uuid] = Contact(ctc_head, set(), 0, UserStatus(name))
        ctc = contacts.get(ctc_head.uuid)
        if ctc.status.name is None:
            ctc.status.name = name
        ctc.lists |= lst
        self._mark_modified(user, detail=detail)
        return ctc

    def _remove_from_list(self, user, ctc_head, lst):
        """Remove `ctc_head` from `user`'s `lst`; drop the contact once on no list."""
        detail = self._load_detail(user)
        contacts = detail.contacts
        ctc = contacts.get(ctc_head.uuid)
        if ctc is None:
            return
        ctc.lists &= ~lst
        if not ctc.lists:
            del contacts[ctc_head.uuid]
        self._mark_modified(user, detail=detail)

    def me_pop_boot_others(self, sess):
        """Send `POPBootEvent` to the user's other sessions."""
        for sess_other in self._sc.get_sessions_by_user(sess.user):
            if sess is sess_other:
                continue
            sess_other.send_event(event.POPBootEvent())

    def me_pop_notify_others(self, sess):
        """Send `POPNotifyEvent` to the user's other sessions."""
        for sess_other in self._sc.get_sessions_by_user(sess.user):
            if sess is sess_other:
                continue
            sess_other.send_event(event.POPNotifyEvent())

    def login_xfr(self, sess, email, token):
        """Redeem an 'sb/xfr' token, create a new chat and join `sess` to it.

        Returns `(chat, extra_data)`, or None if the token is invalid or
        belongs to a different email.
        """
        (user, extra_data) = self._load_user('sb/xfr', token)
        if user is None:
            return None
        if user.email != email:
            return None
        sess.user = user
        sess.client = extra_data['client']
        chat = Chat(self._stats)
        self._chats[chat.id] = chat
        chat.add_session(sess)
        return chat, extra_data

    def login_cal(self, sess, email, token, chatid):
        """Redeem an 'sb/cal' token and join `sess` to the existing chat `chatid`.

        Returns `(chat, extra_data)`, or None if the token is invalid, belongs
        to a different email, or the chat does not exist.
        """
        (user, extra_data) = self._load_user('sb/cal', token)
        if user is None:
            return None
        if user.email != email:
            return None
        sess.user = user
        sess.client = extra_data['client']
        chat = self._chats.get(chatid)
        if chat is None:
            return None
        chat.add_session(sess)
        return chat, extra_data

    def _load_user(self, purpose, token):
        """Pop `token` and return `(user, extra_data)`, or `(None, None)`."""
        data = self._auth_service.pop_token(purpose, token)
        if data is None:
            return (None, None)
        return (self._user_service.get(data['uuid']), data['extra_data'])

    def util_get_uuid_from_email(self, email):
        """Return the uuid the user service maps `email` to."""
        return self._user_service.get_uuid(email)

    def util_set_sess_token(self, sess, token):
        """Associate `token` with `sess` in the session collection."""
        self._sc.set_nc_by_token(sess, token)

    def util_get_sess_by_token(self, token):
        """Return the session registered under `token`, or None."""
        return self._sc.get_nc_by_token(token)

    def util_get_sessions_by_user(self, user):
        """Return the set of sessions currently registered for `user`."""
        return self._sc.get_sessions_by_user(user)

    def notify_call(self, caller_uuid, callee_email, chatid):
        """Invite every session of the callee to chat `chatid`.

        Raises `error.ServerError` if the caller is not cached or has no
        detail, `error.UserDoesNotExist` for an unknown callee email,
        `error.ContactDoesNotExist` if the callee is neither a contact nor the
        caller, and `error.ContactNotOnline` if the callee is offline.
        """
        caller = self._user_by_uuid.get(caller_uuid)
        if caller is None:
            raise error.ServerError()
        if caller.detail is None:
            raise error.ServerError()
        callee_uuid = self._user_service.get_uuid(callee_email)
        if callee_uuid is None:
            raise error.UserDoesNotExist()
        ctc = caller.detail.contacts.get(callee_uuid)
        if ctc is None:
            # Calling yourself is allowed even though you are not your own contact.
            if callee_uuid != caller_uuid:
                raise error.ContactDoesNotExist()
            ctc_user = caller
        else:
            if ctc.status.is_offlineish():
                raise error.ContactNotOnline()
            ctc_user = ctc.head
        ctc_sessions = self._sc.get_sessions_by_user(ctc_user)
        if not ctc_sessions:
            raise error.ContactNotOnline()

        for ctc_sess in ctc_sessions:
            extra_data = ctc_sess.state.get_sb_extra_data() or {}
            extra_data['client'] = ctc_sess.client
            token = self._auth_service.create_token(
                'sb/cal', {'uuid': ctc_user.uuid, 'extra_data': extra_data},
            )
            ctc_sess.send_event(event.InvitedToChatEvent(chatid, token, caller))

    async def _sync_db(self):
        """Background task: flush modified user details once per second."""
        while True:
            await asyncio.sleep(1)
            self._sync_db_impl()

    def _sync_db_impl(self):
        """Save up to 100 queued (user, detail) pairs in a single batch."""
        if not self._unsynced_db:
            return
        try:
            users = list(self._unsynced_db)[:100]
            batch = []
            for user in users:
                detail = self._unsynced_db.pop(user, None)
                if not detail:
                    continue
                batch.append((user, detail))
            self._user_service.save_batch(batch)
        except Exception:
            # Best effort: report the failure and keep the sync task alive.
            traceback.print_exc()

    async def _clean_sessions(self):
        """Background task: every 10 s drop closed and timed-out polling sessions."""
        # Imported here rather than at module level, presumably to avoid a
        # circular import -- confirm before moving.
        from .session import PollingSession
        while True:
            await asyncio.sleep(10)
            now = time.time()
            closed = []

            try:
                # Iterate over a snapshot: if `sess.close()` ends up removing the
                # session from the collection, iterating the live set would
                # raise RuntimeError and abort this sweep.
                for sess in list(self._sc.iter_sessions()):
                    if sess.closed:
                        closed.append(sess)
                        continue
                    if isinstance(sess, PollingSession):
                        if now >= sess.time_last_connect + sess.timeout:
                            sess.close()
                            closed.append(sess)
            except Exception:
                traceback.print_exc()

            for sess in closed:
                self._sc.remove_session(sess)

    async def _sync_stats(self):
        """Background task: flush stats every 60 seconds."""
        while True:
            await asyncio.sleep(60)
            try:
                self._stats.flush()
            except Exception:
                traceback.print_exc()


class _SessionCollection:
    """Index of live sessions by user and by token."""

    def __init__(self):
        # Set[Session]
        self._sessions = set()
        # Dict[User, Set[Session]]
        self._sessions_by_user = defaultdict(set)
        # Dict[str, Session]
        self._sess_by_token = {}
        # Dict[Session, Set[str]]
        self._tokens_by_sess = defaultdict(set)

    def get_sessions_by_user(self, user):
        """Return the set of sessions for `user` (the shared empty set if none)."""
        # Membership test first so the defaultdict does not grow on lookups.
        if user not in self._sessions_by_user:
            return EMPTY_SET
        return self._sessions_by_user[user]

    def iter_sessions(self):
        """Yield every registered session."""
        yield from self._sessions

    def set_nc_by_token(self, sess, token: str):
        """Register `sess` under `token`."""
        self._sess_by_token[token] = sess
        # Record the *token* so `remove_session` can purge `_sess_by_token`;
        # storing the session here left stale tokens behind forever.
        self._tokens_by_sess[sess].add(token)
        self._sessions.add(sess)

    def get_nc_by_token(self, token: str):
        """Return the session registered under `token`, or None."""
        return self._sess_by_token.get(token)

    def add_session(self, sess):
        """Register `sess`, indexing it by user when it has one."""
        if sess.user:
            self._sessions_by_user[sess.user].add(sess)
        self._sessions.add(sess)

    def remove_session(self, sess):
        """Forget `sess` together with all tokens registered for it."""
        if sess in self._tokens_by_sess:
            tokens = self._tokens_by_sess.pop(sess)
            for token in tokens:
                self._sess_by_token.pop(token, None)
        self._sessions.discard(sess)
        if sess.user in self._sessions_by_user:
            self._sessions_by_user[sess.user].discard(sess)


class Chat:
    """A conversation identified by a generated uuid, with its member sessions."""

    def __init__(self, stats):
        self.id = gen_uuid()
        # Dict[Session, User]
        self._users_by_sess = {}
        self._stats = stats

    def add_session(self, sess):
        """Add `sess` to the chat, remembering its current user."""
        self._users_by_sess[sess] = sess.user

    def send_message_to_everyone(self, sess_sender, data):
        """Deliver `data` to every session except the sender and record stats."""
        self._stats.on_message_sent(sess_sender.user, sess_sender.client)
        self._stats.on_user_active(sess_sender.user, sess_sender.client)
        su_sender = self._users_by_sess[sess_sender]
        for sess in self._users_by_sess:
            if sess == sess_sender:
                continue
            sess.send_event(event.ChatMessage(su_sender, data))
            self._stats.on_message_received(sess.user, sess.client)

    def get_roster(self, sess):
        """Return `(session, user)` pairs for every member other than `sess`."""
        return [
            (sess1, su1)
            for sess1, su1 in self._users_by_sess.items()
            if sess1 != sess
        ]

    def send_participant_joined(self, sess):
        """Tell every other member that `sess` has joined."""
        # Exclude the joining session itself; passing `self` here excluded
        # nothing, so the joiner was notified about its own arrival.
        for sc, _ in self.get_roster(sess):
            sc.send_event(event.ChatParticipantJoined(sess))

    def on_leave(self, sess):
        """Remove `sess` from the chat and notify the remaining members."""
        su = self._users_by_sess.pop(sess, None)
        if su is None:
            return
        # Notify others that `sess` has left
        for sess1, su1 in self._users_by_sess.items():
            if sess1 == sess:
                continue
            sess1.send_event(event.ChatParticipantLeft(su))


def _gen_group_id(detail):
    """Return the smallest positive integer, as a string, unused in `detail.groups`."""
    candidate = 1
    s = str(candidate)
    while s in detail.groups:
        candidate += 1
        s = str(candidate)
    return s

# Next snippet on the page: gen_dsin_input.py
Source:gen_dsin_input.py  
"""Build and pickle the DSIN model inputs from the sampled data files."""
import os

import numpy as np
import pandas as pd
from deepctr.utils import SingleFeat
from sklearn.preprocessing import LabelEncoder, StandardScaler
from tensorflow.python.keras.preprocessing.sequence import pad_sequences
from tqdm import tqdm

from config import DSIN_SESS_COUNT, DSIN_SESS_MAX_LEN, FRAC

SESS_COUNT = DSIN_SESS_COUNT


def gen_sess_feature_dsin(row):
    """Build the per-session feature lists for one sample.

    `row` is an `(index, record)` pair as produced by `DataFrame.iterrows()`;
    the record must provide 'user' and 'time_stamp'. Reads the module-level
    `user_hist_session` mapping, which the `__main__` section populates.
    Each history entry is indexed as `e[0]` (cate_id), `e[1]` (brand) and
    `e[2]` (compared against `time_stamp`).

    Returns `(sess_input_dict, sess_input_length_dict, sess_length)`:
    per-session 'cate_id'/'brand' lists, per-session lengths, and the number
    of sessions that were filled from history.
    """
    sess_count = DSIN_SESS_COUNT
    sess_max_len = DSIN_SESS_MAX_LEN
    sess_names = ['sess_' + str(i) for i in range(sess_count)]
    sess_input_dict = {name: {'cate_id': [], 'brand': []} for name in sess_names}
    sess_input_length_dict = {name: 0 for name in sess_names}

    user, time_stamp = row[1]['user'], row[1]['time_stamp']

    if user not in user_hist_session:
        # No history at all: every session is a single padding id, length 0.
        for name in sess_names:
            sess_input_dict[name]['cate_id'] = [0]
            sess_input_dict[name]['brand'] = [0]
        return sess_input_dict, sess_input_length_dict, 0

    sessions = user_hist_session[user]
    valid_sess_count = 0
    last_sess_idx = len(sessions) - 1

    # Walk back from the newest session to the first one that started before
    # the sample and has more than two events before it; that becomes sess_0.
    for i in reversed(range(len(sessions))):
        cur_sess = sessions[i]
        if cur_sess[0][2] < time_stamp:
            in_sess_count = 1 + sum(1 for e in cur_sess[1:] if e[2] < time_stamp)
            if in_sess_count > 2:
                # Keep at most `sess_max_len` of the first `in_sess_count` events.
                window = cur_sess[max(0, in_sess_count - sess_max_len):in_sess_count]
                sess_input_dict['sess_0']['cate_id'] = [e[0] for e in window]
                sess_input_dict['sess_0']['brand'] = [e[1] for e in window]
                sess_input_length_dict['sess_0'] = min(sess_max_len, in_sess_count)
                last_sess_idx = i
                valid_sess_count += 1
                break

    # The remaining slots take the sessions immediately preceding sess_0.
    for i in range(1, sess_count):
        name = 'sess_' + str(i)
        if last_sess_idx - i >= 0:
            cur_sess = sessions[last_sess_idx - i]
            tail = cur_sess[-sess_max_len:]
            sess_input_dict[name]['cate_id'] = [e[0] for e in tail]
            sess_input_dict[name]['brand'] = [e[1] for e in tail]
            sess_input_length_dict[name] = min(sess_max_len, len(cur_sess))
            valid_sess_count += 1
        else:
            sess_input_dict[name]['cate_id'] = [0]
            sess_input_dict[name]['brand'] = [0]
            sess_input_length_dict[name] = 0

    return sess_input_dict, sess_input_length_dict, valid_sess_count


if __name__ == "__main__":
    # Read by `gen_sess_feature_dsin` as a module-level name.
    user_hist_session = {}
    FILE_NUM = len(
        list(filter(lambda x: x.startswith('user_hist_session_' + str(FRAC) + '_dsin_'),
                    os.listdir('../sampled_data/'))))
    print('total', FILE_NUM, 'files')

    for i in range(FILE_NUM):
        user_hist_session_ = pd.read_pickle(
            '../sampled_data/user_hist_session_' + str(FRAC) + '_dsin_' + str(i) + '.pkl')  # 19,34
        user_hist_session.update(user_hist_session_)
        del user_hist_session_

    sample_sub = pd.read_pickle(
        '../sampled_data/raw_sample_' + str(FRAC) + '.pkl')

    sess_input_dict = {}
    sess_input_length_dict = {}
    for i in range(SESS_COUNT):
        sess_input_dict['sess_' + str(i)] = {'cate_id': [], 'brand': []}
        sess_input_length_dict['sess_' + str(i)] = []
    sess_length_list = []

    for row in tqdm(sample_sub[['user', 'time_stamp']].iterrows()):
        sess_input_dict_, sess_input_length_dict_, sess_length = gen_sess_feature_dsin(
            row)
        for i in range(SESS_COUNT):
            sess_name = 'sess_' + str(i)
            sess_input_dict[sess_name]['cate_id'].append(
                sess_input_dict_[sess_name]['cate_id'])
            sess_input_dict[sess_name]['brand'].append(
                sess_input_dict_[sess_name]['brand'])
            sess_input_length_dict[sess_name].append(
                sess_input_length_dict_[sess_name])
        sess_length_list.append(sess_length)
    print('done')

    user = pd.read_pickle('../sampled_data/user_profile_' + str(FRAC) + '.pkl')
    ad = pd.read_pickle('../sampled_data/ad_feature_enc_' + str(FRAC) + '.pkl')
    user = user.fillna(-1)
    # The raw column name carries a trailing space; strip it.
    user.rename(
        columns={'new_user_class_level ': 'new_user_class_level'}, inplace=True)

    # `sample_sub` was only read (never modified) above, so it is reused here
    # instead of loading the same pickle a second time.
    sample_sub.rename(columns={'user': 'userid'}, inplace=True)
    data = pd.merge(sample_sub, user, how='left', on='userid', )
    data = pd.merge(data, ad, how='left', on='adgroup_id')

    sparse_features = ['userid', 'adgroup_id', 'pid', 'cms_segid', 'cms_group_id', 'final_gender_code', 'age_level',
                       'pvalue_level', 'shopping_level', 'occupation', 'new_user_class_level', 'campaign_id',
                       'customer']  # sparse feature for user and ads
    dense_features = ['price']  # dense feature for user and ads

    for feat in tqdm(sparse_features):
        lbe = LabelEncoder()  # or Hash
        # Map each distinct value to an integer index.
        data[feat] = lbe.fit_transform(data[feat])
    mms = StandardScaler()
    data[dense_features] = mms.fit_transform(data[dense_features])

    # class SingleFeat(namedtuple('SingleFeat', ['name', 'dimension', 'hash_flag', 'dtype'])):
    sparse_feature_list = [SingleFeat(feat, data[feat].nunique(
    ) + 1) for feat in sparse_features + ['cate_id', 'brand']]
    dense_feature_list = [SingleFeat(feat, 1) for feat in dense_features]
    sess_feature = ['cate_id', 'brand']  # sess feature for ad

    sess_input = []
    # NOTE(review): `sess_input_length` is collected but never written to
    # `model_input` below -- confirm whether that is intended.
    sess_input_length = []
    for i in tqdm(range(SESS_COUNT)):
        sess_name = 'sess_' + str(i)
        for feat in sess_feature:
            # Pad to the per-session item limit. Padding to SESS_COUNT (the
            # number of sessions) would cut every session down to that many
            # items, contradicting the DSIN_SESS_MAX_LEN truncation above.
            sess_input.append(pad_sequences(
                sess_input_dict[sess_name][feat], maxlen=DSIN_SESS_MAX_LEN, padding='post'))
        sess_input_length.append(sess_input_length_dict[sess_name])

    model_input = [data[feat.name].values for feat in sparse_feature_list] + \
                  [data[feat.name].values for feat in dense_feature_list]
    sess_lists = sess_input + [np.array(sess_length_list)]
    model_input += sess_lists

    os.makedirs('../model_input/', exist_ok=True)
    pd.to_pickle(model_input, '../model_input/dsin_input_' +
                 str(FRAC) + '_' + str(SESS_COUNT) + '.pkl')
    pd.to_pickle(data['clk'].values, '../model_input/dsin_label_' +
                 str(FRAC) + '_' + str(SESS_COUNT) + '.pkl')
    pd.to_pickle({'sparse': sparse_feature_list, 'dense': dense_feature_list},
                 '../model_input/dsin_fd_' + str(FRAC) + '_' + str(SESS_COUNT) + '.pkl')

# Looking for an in-depth pytest tutorial? LambdaTest's detailed pytest tutorial
# covers everything from setting up the pytest framework to automation testing.
# Delve deeper into pytest by exploring advanced use cases such as parallel
# testing, pytest fixtures, parameterization, executing multiple test cases
# from a single file, and more.
Skim our pytest tutorial playlist below to get started with automation testing using the pytest framework.
https://www.youtube.com/playlist?list=PLZMWkkQEwOPlcGgDmHl8KkXKeLF83XlrP
Get 100 automation test minutes FREE!!
