Best Python code snippet using autotest_python
runner_test.py
Source: runner_test.py
1# Lint as: python2, python32# Copyright 2019 Google LLC. All Rights Reserved.3#4# Licensed under the Apache License, Version 2.0 (the "License");5# you may not use this file except in compliance with the License.6# You may obtain a copy of the License at7#8#     http://www.apache.org/licenses/LICENSE-2.09#10# Unless required by applicable law or agreed to in writing, software11# distributed under the License is distributed on an "AS IS" BASIS,12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.13# See the License for the specific language governing permissions and14# limitations under the License.15"""Tests for tfx.extensions.google_cloud_ai_platform.runner."""16from __future__ import absolute_import17from __future__ import division18from __future__ import print_function19import copy20import os21import sys22from typing import Any, Dict, Text23# Standard Imports24import mock25import tensorflow as tf26from tfx import version27from tfx.extensions.google_cloud_ai_platform import runner28from tfx.extensions.google_cloud_ai_platform.trainer import executor29from tfx.utils import json_utils30from tfx.utils import telemetry_utils31class RunnerTest(tf.test.TestCase):32  def setUp(self):33    super(RunnerTest, self).setUp()34    self._output_data_dir = os.path.join(35        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),36        self._testMethodName)37    self._project_id = '12345'38    self._mock_api_client = mock.Mock()39    self._inputs = {}40    self._outputs = {}41    self._training_inputs = {42        'project': self._project_id,43    }44    self._job_id = 'my_jobid'45    # Dict format of exec_properties. 
custom_config needs to be serialized46    # before being passed into start_aip_training function.47    self._exec_properties = {48        'custom_config': {49            executor.TRAINING_ARGS_KEY: self._training_inputs,50        },51    }52    self._model_name = 'model_name'53    self._ai_platform_serving_args = {54        'model_name': self._model_name,55        'project_id': self._project_id,56    }57    self._executor_class_path = 'my.executor.Executor'58  def _setUpTrainingMocks(self):59    self._mock_create = mock.Mock()60    self._mock_api_client.projects().jobs().create = self._mock_create61    self._mock_get = mock.Mock()62    self._mock_api_client.projects().jobs().get.return_value = self._mock_get63    self._mock_get.execute.return_value = {64        'state': 'SUCCEEDED',65    }66  def _serialize_custom_config_under_test(self) -> Dict[Text, Any]:67    """Converts self._exec_properties['custom_config'] to string."""68    result = copy.deepcopy(self._exec_properties)69    result['custom_config'] = json_utils.dumps(result['custom_config'])70    return result71  @mock.patch(72      'tfx.extensions.google_cloud_ai_platform.runner.discovery'73  )74  def testStartAIPTraining(self, mock_discovery):75    mock_discovery.build.return_value = self._mock_api_client76    self._setUpTrainingMocks()77    class_path = 'foo.bar.class'78    runner.start_aip_training(self._inputs, self._outputs,79                              self._serialize_custom_config_under_test(),80                              class_path,81                              self._training_inputs, None)82    self._mock_create.assert_called_with(83        body=mock.ANY, parent='projects/{}'.format(self._project_id))84    (_, kwargs) = self._mock_create.call_args85    body = kwargs['body']86    default_image = 'gcr.io/tfx-oss-public/tfx:{}'.format(version.__version__)87    self.assertDictContainsSubset(88        {89            'masterConfig': {90                'imageUri': default_image,91            },92     
       'args': [93                '--executor_class_path', class_path, '--inputs', '{}',94                '--outputs', '{}', '--exec-properties', '{"custom_config": '95                '"{\\"ai_platform_training_args\\": {\\"project\\": \\"12345\\"'96                '}}"}'97            ],98        }, body['trainingInput'])99    self.assertStartsWith(body['jobId'], 'tfx_')100    self._mock_get.execute.assert_called_with()101  @mock.patch(102      'tfx.extensions.google_cloud_ai_platform.runner.discovery'103  )104  def testStartAIPTrainingWithUserContainer(self, mock_discovery):105    mock_discovery.build.return_value = self._mock_api_client106    self._setUpTrainingMocks()107    class_path = 'foo.bar.class'108    self._training_inputs['masterConfig'] = {'imageUri': 'my-custom-image'}109    self._exec_properties['custom_config'][executor.JOB_ID_KEY] = self._job_id110    runner.start_aip_training(self._inputs, self._outputs,111                              self._serialize_custom_config_under_test(),112                              class_path,113                              self._training_inputs, self._job_id)114    self._mock_create.assert_called_with(115        body=mock.ANY, parent='projects/{}'.format(self._project_id))116    (_, kwargs) = self._mock_create.call_args117    body = kwargs['body']118    self.assertDictContainsSubset(119        {120            'masterConfig': {121                'imageUri': 'my-custom-image',122            },123            'args': [124                '--executor_class_path', class_path, '--inputs', '{}',125                '--outputs', '{}', '--exec-properties', '{"custom_config": '126                '"{\\"ai_platform_training_args\\": '127                '{\\"masterConfig\\": {\\"imageUri\\": \\"my-custom-image\\"}, '128                '\\"project\\": \\"12345\\"}, '129                '\\"ai_platform_training_job_id\\": \\"my_jobid\\"}"}'130            ],131        }, body['trainingInput'])132    self.assertEqual(body['jobId'], 
'my_jobid')133    self._mock_get.execute.assert_called_with()134  def _setUpPredictionMocks(self):135    self._serving_path = os.path.join(self._output_data_dir, 'serving_path')136    self._model_version = 'model_version'137    self._mock_models_create = mock.Mock()138    self._mock_api_client.projects().models().create = self._mock_models_create139    self._mock_versions_create = mock.Mock()140    self._mock_versions_create.return_value.execute.return_value = {141        'name': 'versions_create_op_name'142    }143    self._mock_api_client.projects().models().versions(144    ).create = self._mock_versions_create145    self._mock_get = mock.Mock()146    self._mock_api_client.projects().operations().get = self._mock_get147    self._mock_set_default = mock.Mock()148    self._mock_api_client.projects().models().versions(149    ).setDefault = self._mock_set_default150    self._mock_set_default_execute = mock.Mock()151    self._mock_api_client.projects().models().versions().setDefault(152    ).execute = self._mock_set_default_execute153    self._mock_get.return_value.execute.return_value = {154        'done': True,155        'response': {156            'name': self._model_version,157        },158    }159  def _assertDeployModelMockCalls(self,160                                  expected_models_create_body=None,161                                  expected_versions_create_body=None,162                                  expect_set_default=True):163    if not expected_models_create_body:164      expected_models_create_body = {165          'name':166              self._model_name,167          'regions':168              [],169      }170    if not expected_versions_create_body:171      with telemetry_utils.scoped_labels(172          {telemetry_utils.LABEL_TFX_EXECUTOR: self._executor_class_path}):173        labels = telemetry_utils.get_labels_dict()174      expected_versions_create_body = {175          'name':176              self._model_version,177          
'deployment_uri':178              self._serving_path,179          'runtime_version':180              runner._get_tf_runtime_version(tf.__version__),181          'python_version':182              runner._get_caip_python_version(183                  runner._get_tf_runtime_version(tf.__version__)),184          'labels': labels185      }186    self._mock_models_create.assert_called_with(187        body=mock.ANY,188        parent='projects/{}'.format(self._project_id),189    )190    (_, models_create_kwargs) = self._mock_models_create.call_args191    self.assertDictEqual(expected_models_create_body,192                         models_create_kwargs['body'])193    self._mock_versions_create.assert_called_with(194        body=mock.ANY,195        parent='projects/{}/models/{}'.format(self._project_id,196                                              self._model_name))197    (_, versions_create_kwargs) = self._mock_versions_create.call_args198    self.assertDictEqual(expected_versions_create_body,199                         versions_create_kwargs['body'])200    if not expect_set_default:201      return202    self._mock_set_default.assert_called_with(203        name='projects/{}/models/{}/versions/{}'.format(204            self._project_id, self._model_name, self._model_version))205    self._mock_set_default_execute.assert_called_with()206  @mock.patch(207      'tfx.extensions.google_cloud_ai_platform.runner.discovery'208  )209  def testDeployModelForAIPPrediction(self, mock_discovery):210    mock_discovery.build.return_value = self._mock_api_client211    self._setUpPredictionMocks()212    runner.deploy_model_for_aip_prediction(self._serving_path,213                                           self._model_version,214                                           self._ai_platform_serving_args,215                                           self._executor_class_path)216    expected_models_create_body = {217        'name': self._model_name,218        'regions': []219    }220    
self._assertDeployModelMockCalls(221        expected_models_create_body=expected_models_create_body)222  @mock.patch(223      'tfx.extensions.google_cloud_ai_platform.runner.discovery'224  )225  def testDeployModelForAIPPredictionError(self, mock_discovery):226    mock_discovery.build.return_value = self._mock_api_client227    self._setUpPredictionMocks()228    self._mock_get.return_value.execute.return_value = {229        'done': True,230        'error': {231            'code': 999,232            'message': 'it was an error.'233        },234    }235    with self.assertRaises(RuntimeError):236      runner.deploy_model_for_aip_prediction(self._serving_path,237                                             self._model_version,238                                             self._ai_platform_serving_args,239                                             self._executor_class_path)240    expected_models_create_body = {241        'name': self._model_name,242        'regions': []243    }244    self._assertDeployModelMockCalls(245        expected_models_create_body=expected_models_create_body,246        expect_set_default=False)247  @mock.patch(248      'tfx.extensions.google_cloud_ai_platform.runner.discovery'249  )250  def testDeployModelForAIPPredictionWithCustomRegion(self, mock_discovery):251    mock_discovery.build.return_value = self._mock_api_client252    self._setUpPredictionMocks()253    self._ai_platform_serving_args['regions'] = ['custom-region']254    runner.deploy_model_for_aip_prediction(self._serving_path,255                                           self._model_version,256                                           self._ai_platform_serving_args,257                                           self._executor_class_path)258    expected_models_create_body = {259        'name': self._model_name,260        'regions': ['custom-region'],261    }262    self._assertDeployModelMockCalls(263        expected_models_create_body=expected_models_create_body)264  @mock.patch(265 
     'tfx.extensions.google_cloud_ai_platform.runner.discovery'266  )267  def testDeployModelForAIPPredictionWithCustomRuntime(self, mock_discovery):268    mock_discovery.build.return_value = self._mock_api_client269    self._setUpPredictionMocks()270    self._ai_platform_serving_args['runtime_version'] = '1.23.45'271    runner.deploy_model_for_aip_prediction(self._serving_path,272                                           self._model_version,273                                           self._ai_platform_serving_args,274                                           self._executor_class_path)275    with telemetry_utils.scoped_labels(276        {telemetry_utils.LABEL_TFX_EXECUTOR: self._executor_class_path}):277      labels = telemetry_utils.get_labels_dict()278    expected_versions_create_body = {279        'name': self._model_version,280        'deployment_uri': self._serving_path,281        'runtime_version': '1.23.45',282        'python_version': runner._get_caip_python_version('1.23.45'),283        'labels': labels,284    }285    self._assertDeployModelMockCalls(286        expected_versions_create_body=expected_versions_create_body)287  def testGetTensorflowRuntime(self):288    self.assertEqual('1.14', runner._get_tf_runtime_version('1.14'))289    self.assertEqual('1.15', runner._get_tf_runtime_version('1.15.0'))290    self.assertEqual('1.15', runner._get_tf_runtime_version('1.15.1'))291    self.assertEqual('1.15', runner._get_tf_runtime_version('2.0.0'))292    self.assertEqual('1.15', runner._get_tf_runtime_version('2.0.1'))293    self.assertEqual('2.1', runner._get_tf_runtime_version('2.1.0'))294    # TODO(b/157039850) Remove this once CAIP model support TF 2.2 runtime.295    self.assertEqual('2.1', runner._get_tf_runtime_version('2.2.0'))296  def testGetCaipPythonVersion(self):297    if sys.version_info.major == 2:298      self.assertEqual('2.7', runner._get_caip_python_version('1.14'))299      self.assertEqual('2.7', runner._get_caip_python_version('1.15'))300 
   else:  # 3.x301      self.assertEqual('3.5', runner._get_caip_python_version('1.14'))302      self.assertEqual('3.7', runner._get_caip_python_version('1.15'))303      self.assertEqual('3.7', runner._get_caip_python_version('2.1'))304if __name__ == '__main__':...client_test.py
Source: client_test.py
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for cloud_fit.client."""

import os
import shutil
import tempfile
from unittest import mock

import cloudpickle
from googleapiclient import discovery
import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow_enterprise_addons.cloud_fit import client
from tensorflow_enterprise_addons.cloud_fit import utils

# Can only export Datasets which were created executing eagerly
utils.enable_eager_for_tf_1()

MIRRORED_STRATEGY_NAME = utils.MIRRORED_STRATEGY_NAME
MULTI_WORKER_MIRRORED_STRATEGY_NAME = utils.MULTI_WORKER_MIRRORED_STRATEGY_NAME


class CloudFitClientTest(tf.test.TestCase):

  def setUp(self):
    super(CloudFitClientTest, self).setUp()
    self._image_uri = 'gcr.io/some_test_image:latest'
    self._project_id = 'test_project_id'
    self._region = 'test_region'
    self._mock_apiclient = mock.Mock()
    self._remote_dir = tempfile.mkdtemp()
    # mkdtemp() never removes its directory; delete it after each test so
    # serialized assets do not accumulate on disk.
    self.addCleanup(shutil.rmtree, self._remote_dir, ignore_errors=True)
    self._job_spec = client._default_job_spec(self._region, self._image_uri, [
        '--remote_dir', self._remote_dir, '--distribution_strategy',
        MULTI_WORKER_MIRRORED_STRATEGY_NAME
    ])
    self._model = self._build_model()
    self._x = [[9.], [10.], [11.]] * 10
    self._y = [[xi[0] / 2. + 6] for xi in self._x]
    self._dataset = tf.data.Dataset.from_tensor_slices((self._x, self._y))
    self._scalar_fit_kwargs = {'batch_size': 1, 'epochs': 2, 'verbose': 3}

  def _set_up_training_mocks(self):
    """Stubs jobs().create so that executing the request reports SUCCEEDED."""
    self._mock_create = mock.Mock()
    self._mock_apiclient.projects().jobs().create = self._mock_create
    self._mock_get = mock.Mock()
    self._mock_create.return_value = self._mock_get
    self._mock_get.execute.return_value = {
        'state': 'SUCCEEDED',
    }

  def _build_model(self):
    """Builds a compiled Keras model computing y = w * x + 1, w trainable.

    w is initialized to 0.5 with an L2 regularizer; the "+ 1" layer is frozen.
    """
    inp = tf.keras.layers.Input(shape=(1,), dtype=tf.float32)
    times_w = tf.keras.layers.Dense(
        units=1,
        kernel_initializer=tf.keras.initializers.Constant([[0.5]]),
        kernel_regularizer=tf.keras.regularizers.l2(0.01),
        use_bias=False)
    plus_1 = tf.keras.layers.Dense(
        units=1,
        kernel_initializer=tf.keras.initializers.Constant([[1.0]]),
        bias_initializer=tf.keras.initializers.Constant([1.0]),
        trainable=False)
    outp = plus_1(times_w(inp))
    model = tf.keras.Model(inp, outp)

    model.compile(
        tf.keras.optimizers.SGD(0.002), 'mean_squared_error', run_eagerly=True)
    return model

  def _assert_dict_contains_subset(self, expected_subset, actual):
    """Asserts that every key/value pair of `expected_subset` is in `actual`.

    Replacement for unittest's `assertDictContainsSubset`, which is deprecated
    since Python 3.2 and removed in Python 3.12.

    Args:
      expected_subset: Dict whose items must all be present in `actual`.
      actual: Dict under test; it may contain additional keys.
    """
    for key, expected_value in expected_subset.items():
      self.assertIn(key, actual)
      self.assertEqual(expected_value, actual[key], msg='key: %r' % (key,))

  def test_default_job_spec(self):
    self.assertStartsWith(self._job_spec['job_id'], 'cloud_fit_')
    self._assert_dict_contains_subset(
        {
            'masterConfig': {
                'imageUri': self._image_uri,
            },
            'args': [
                '--remote_dir', self._remote_dir, '--distribution_strategy',
                MULTI_WORKER_MIRRORED_STRATEGY_NAME
            ],
        }, self._job_spec['trainingInput'])

  @mock.patch.object(discovery, 'build', autospec=True)
  def test_submit_job(self, mock_discovery_build):
    mock_discovery_build.return_value = self._mock_apiclient
    self._set_up_training_mocks()

    client._submit_job(self._job_spec, self._project_id)

    mock_discovery_build.assert_called_once_with(
        'ml', 'v1', cache_discovery=False)

    self._mock_create.assert_called_with(
        body=mock.ANY, parent='projects/{}'.format(self._project_id))

    _, fit_kwargs = list(self._mock_create.call_args)
    body = fit_kwargs['body']
    self._assert_dict_contains_subset(
        {
            'masterConfig': {
                'imageUri': self._image_uri,
            },
            'args': [
                '--remote_dir', self._remote_dir, '--distribution_strategy',
                MULTI_WORKER_MIRRORED_STRATEGY_NAME
            ],
        }, body['trainingInput'])
    self.assertStartsWith(body['job_id'], 'cloud_fit_')
    self._mock_get.execute.assert_called_with()

  def test_serialize_assets(self):
    # TF 1.x is not supported
    if utils.is_tf_v1():
      with self.assertRaises(RuntimeError):
        client.cloud_fit(
            self._model,
            x=self._dataset,
            validation_data=self._dataset,
            remote_dir=self._remote_dir,
            job_spec=self._job_spec,
            batch_size=1,
            epochs=2,
            verbose=3)
      return

    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=self._remote_dir)
    # Copy so the fixture dict set up in setUp is not mutated by this test.
    args = dict(self._scalar_fit_kwargs)
    args['callbacks'] = [tensorboard_callback]

    client._serialize_assets(self._remote_dir, self._model, **args)
    self.assertGreaterEqual(
        len(
            tf.io.gfile.listdir(
                os.path.join(self._remote_dir, 'training_assets'))), 1)
    self.assertGreaterEqual(
        len(tf.io.gfile.listdir(os.path.join(self._remote_dir, 'model'))), 1)

    # The callbacks are stored as a cloudpickled blob inside the
    # training_assets SavedModel; round-trip them to check the type survives.
    training_assets_graph = tf.saved_model.load(
        os.path.join(self._remote_dir, 'training_assets'))
    pickled_callbacks = tfds.as_numpy(training_assets_graph.callbacks_fn())
    unpickled_callbacks = cloudpickle.loads(pickled_callbacks)
    self.assertIsInstance(unpickled_callbacks[0],
                          tf.keras.callbacks.TensorBoard)

  @mock.patch.object(client, '_submit_job', autospec=True)
  def test_fit_kwargs(self, mock_submit_job):
    # TF 1.x is not supported
    if utils.is_tf_v1():
      with self.assertRaises(RuntimeError):
        client.cloud_fit(
            self._model,
            x=self._dataset,
            validation_data=self._dataset,
            remote_dir=self._remote_dir,
            job_spec=self._job_spec,
            batch_size=1,
            epochs=2,
            verbose=3)
      return

    job_id = client.cloud_fit(
        self._model,
        x=self._dataset,
        validation_data=self._dataset,
        remote_dir=self._remote_dir,
        region=self._region,
        project_id=self._project_id,
        image_uri=self._image_uri,
        batch_size=1,
        epochs=2,
        verbose=3)

    kargs, _ = mock_submit_job.call_args
    body, _ = kargs
    self.assertEqual(body['job_id'], job_id)
    remote_dir = body['trainingInput']['args'][1]

    training_assets_graph = tf.saved_model.load(
        os.path.join(remote_dir, 'training_assets'))
    fit_kwargs = tfds.as_numpy(training_assets_graph.fit_kwargs_fn())
    # The expected kwargs must be a subset of the serialized ones. The
    # previous argument order checked the reverse, which passed vacuously
    # when fit_kwargs_fn() returned an empty dict.
    self._assert_dict_contains_subset(
        {
            'batch_size': 1,
            'epochs': 2,
            'verbose': 3
        }, fit_kwargs)

  @mock.patch.object(client, '_submit_job', autospec=True)
  def test_custom_job_spec(self, mock_submit_job):
    # TF 1.x is not supported
    if utils.is_tf_v1():
      with self.assertRaises(RuntimeError):
        client.cloud_fit(
            self._model,
            x=self._dataset,
            validation_data=self._dataset,
            remote_dir=self._remote_dir,
            job_spec=self._job_spec,
            batch_size=1,
            epochs=2,
            verbose=3)
      return

    client.cloud_fit(
        self._model,
        x=self._dataset,
        validation_data=self._dataset,
        remote_dir=self._remote_dir,
        job_spec=self._job_spec,
        batch_size=1,
        epochs=2,
        verbose=3)

    kargs, _ = mock_submit_job.call_args
    body, _ = kargs
    self._assert_dict_contains_subset(
        {
            'masterConfig': {
                'imageUri': self._image_uri,
            },
            'args': [
                '--remote_dir', self._remote_dir, '--distribution_strategy',
                MULTI_WORKER_MIRRORED_STRATEGY_NAME
            ],
        }, body['trainingInput'])

  @mock.patch.object(client, '_submit_job', autospec=True)
  @mock.patch.object(client, '_serialize_assets', autospec=True)
  def test_distribution_strategy(self, mock_serialize_assets, mock_submit_job):
    # TF 1.x is not supported
    if utils.is_tf_v1():
      with self.assertRaises(RuntimeError):
        client.cloud_fit(
            self._model, x=self._dataset, remote_dir=self._remote_dir)
      return

    # Without an explicit strategy, the multi-worker mirrored strategy is used.
    client.cloud_fit(self._model, x=self._dataset, remote_dir=self._remote_dir)

    kargs, _ = mock_submit_job.call_args
    body, _ = kargs
    self._assert_dict_contains_subset(
        {
            'args': [
                '--remote_dir', self._remote_dir, '--distribution_strategy',
                MULTI_WORKER_MIRRORED_STRATEGY_NAME
            ],
        }, body['trainingInput'])

    client.cloud_fit(
        self._model,
        x=self._dataset,
        remote_dir=self._remote_dir,
        distribution_strategy=MIRRORED_STRATEGY_NAME,
        job_spec=self._job_spec)

    kargs, _ = mock_submit_job.call_args
    body, _ = kargs
    self._assert_dict_contains_subset(
        {
            'args': [
                '--remote_dir', self._remote_dir, '--distribution_strategy',
                MIRRORED_STRATEGY_NAME
            ],
        }, body['trainingInput'])

    with self.assertRaises(ValueError):
      client.cloud_fit(
          self._model,
          x=self._dataset,
          remote_dir=self._remote_dir,
          distribution_strategy='not_implemented_strategy',
          job_spec=self._job_spec)

  @mock.patch.object(client, '_submit_job', autospec=True)
  @mock.patch.object(client, '_serialize_assets', autospec=True)
  def test_job_id(self, mock_serialize_assets, mock_submit_job):
    # TF 1.x is not supported
    if utils.is_tf_v1():
      with self.assertRaises(RuntimeError):
        client.cloud_fit(
            self._model,
            x=self._dataset,
            validation_data=self._dataset,
            remote_dir=self._remote_dir,
            job_spec=self._job_spec,
            batch_size=1,
            epochs=2,
            verbose=3)
      return

    test_job_id = 'test_job_id'
    client.cloud_fit(
        self._model,
        x=self._dataset,
        validation_data=self._dataset,
        remote_dir=self._remote_dir,
        job_spec=self._job_spec,
        job_id=test_job_id,
        batch_size=1,
        epochs=2,
        verbose=3)

    kargs, _ = mock_submit_job.call_args
    body, _ = kargs
    self._assert_dict_contains_subset({
        'job_id': test_job_id,
    }, body)


if __name__ == '__main__':
  # NOTE(review): this body was truncated in the scraped listing;
  # tf.test.main() is assumed (standard for tf.test.TestCase suites) - confirm.
  tf.test.main()
Source: ai_platform_training_executor_test.py
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for AI Platform Training component executor."""

import copy
from unittest import mock

from googleapiclient import discovery
import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
from tfx.dsl.component.experimental import placeholders
from tfx.orchestration.kubeflow.v2.components.experimental import ai_platform_training_executor
from tfx.types import artifact_utils
from tfx.types import standard_artifacts
from tfx.utils import json_utils
from tfx.utils import test_case_utils

_EXAMPLE_LOCATION = 'root/ExampleGen/1/examples/'
_MODEL_LOCATION = 'root/Training/2/model/'


class AiPlatformTrainingExecutorTest(test_case_utils.TfxTest):

  def setUp(self):
    super().setUp()
    self._project_id = 'my-project'
    self._job_id = 'my-job-123'
    self._labels = ['label1', 'label2']

    examples_artifact = standard_artifacts.Examples()
    examples_artifact.split_names = artifact_utils.encode_split_names(
        ['train', 'eval'])
    examples_artifact.uri = _EXAMPLE_LOCATION
    self._inputs = {'examples': [examples_artifact]}
    model_artifact = standard_artifacts.Model()
    model_artifact.uri = _MODEL_LOCATION
    self._outputs = {'model': [model_artifact]}

    # Job template whose 'args' contain placeholders; the executor is
    # expected to resolve them against the artifacts and exec properties.
    training_job = {
        'training_input': {
            'scaleTier':
                'CUSTOM',
            'region':
                'us-central1',
            'masterType':
                'n1-standard-8',
            'masterConfig': {
                'imageUri': 'gcr.io/my-project/caip-training-test:latest'
            },
            'workerType':
                'n1-standard-8',
            'workerCount':
                8,
            'workerConfig': {
                'imageUri': 'gcr.io/my-project/caip-training-test:latest'
            },
            'args': [
                '--examples',
                placeholders.InputUriPlaceholder('examples'), '--n-steps',
                placeholders.InputValuePlaceholder('n_step'), '--model-dir',
                placeholders.OutputUriPlaceholder('model')
            ]
        },
        'labels': self._labels,
    }

    aip_training_config = {
        ai_platform_training_executor.PROJECT_CONFIG_KEY: self._project_id,
        ai_platform_training_executor.TRAINING_JOB_CONFIG_KEY: training_job,
        ai_platform_training_executor.JOB_ID_CONFIG_KEY: self._job_id,
        ai_platform_training_executor.LABELS_CONFIG_KEY: self._labels,
    }

    self._exec_properties = {
        ai_platform_training_executor.CONFIG_KEY:
            json_utils.dumps(aip_training_config),
        'n_step':
            100
    }

    # Expected job spec: same training input, with every placeholder replaced
    # by the concrete URI or value it refers to.
    resolved_training_input = copy.deepcopy(training_job['training_input'])
    resolved_training_input['args'] = [
        '--examples', _EXAMPLE_LOCATION, '--n-steps', '100', '--model-dir',
        _MODEL_LOCATION
    ]
    self._expected_job_spec = {
        'job_id': self._job_id,
        'training_input': resolved_training_input,
        'labels': self._labels,
    }

    self._mock_api_client = mock.Mock()
    mock_discovery = self.enter_context(
        mock.patch.object(
            discovery,
            'build',
            autospec=True))
    mock_discovery.return_value = self._mock_api_client
    self._setUpTrainingMocks()

  def _setUpTrainingMocks(self):
    """Stubs jobs().create and makes jobs().get report a SUCCEEDED job."""
    self._mock_create = mock.Mock()
    self._mock_api_client.projects().jobs().create = self._mock_create
    self._mock_get = mock.Mock()
    self._mock_api_client.projects().jobs().get.return_value = self._mock_get
    self._mock_get.execute.return_value = {
        'state': 'SUCCEEDED',
    }

  def testRunAipTraining(self):
    aip_executor = ai_platform_training_executor.AiPlatformTrainingExecutor()
    aip_executor.Do(
        input_dict=self._inputs,
        output_dict=self._outputs,
        exec_properties=self._exec_properties)

    self._mock_create.assert_called_once_with(
        body=self._expected_job_spec,
        parent='projects/{}'.format(self._project_id))

  def testRunAipTrainingWithDefaultJobId(self):
    aip_executor = ai_platform_training_executor.AiPlatformTrainingExecutor()
    # Set job_id to None in the serialized training config so the executor
    # has to fall back to a generated default job id.
    training_config = json_utils.loads(
        self._exec_properties[ai_platform_training_executor.CONFIG_KEY])
    training_config[ai_platform_training_executor.JOB_ID_CONFIG_KEY] = None
    self._exec_properties[ai_platform_training_executor.CONFIG_KEY] = (
        json_utils.dumps(training_config))

    aip_executor.Do(
        input_dict=self._inputs,
        output_dict=self._outputs,
        exec_properties=self._exec_properties)

    self._mock_create.assert_called_once()
    job_id = self._mock_create.call_args[1]['body']['job_id']
    self.assertStartsWith(job_id, 'tfx_')


if __name__ == '__main__':
  # NOTE(review): this body was truncated in the scraped listing;
  # tf.test.main() is assumed (standard for tf.test.TestCase suites) - confirm.
  tf.test.main()
You can also refer to the video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.
Get 100 minutes of automation testing FREE!
