How to use the _mock_create method in autotest

Best Python code snippets using autotest_python

runner_test.py

Source:runner_test.py Github

copy

Full Screen

1# Lint as: python2, python32# Copyright 2019 Google LLC. All Rights Reserved.3#4# Licensed under the Apache License, Version 2.0 (the "License");5# you may not use this file except in compliance with the License.6# You may obtain a copy of the License at7#8# http://www.apache.org/licenses/LICENSE-2.09#10# Unless required by applicable law or agreed to in writing, software11# distributed under the License is distributed on an "AS IS" BASIS,12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.13# See the License for the specific language governing permissions and14# limitations under the License.15"""Tests for tfx.extensions.google_cloud_ai_platform.runner."""16from __future__ import absolute_import17from __future__ import division18from __future__ import print_function19import copy20import os21import sys22from typing import Any, Dict, Text23# Standard Imports24import mock25import tensorflow as tf26from tfx import version27from tfx.extensions.google_cloud_ai_platform import runner28from tfx.extensions.google_cloud_ai_platform.trainer import executor29from tfx.utils import json_utils30from tfx.utils import telemetry_utils31class RunnerTest(tf.test.TestCase):32 def setUp(self):33 super(RunnerTest, self).setUp()34 self._output_data_dir = os.path.join(35 os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),36 self._testMethodName)37 self._project_id = '12345'38 self._mock_api_client = mock.Mock()39 self._inputs = {}40 self._outputs = {}41 self._training_inputs = {42 'project': self._project_id,43 }44 self._job_id = 'my_jobid'45 # Dict format of exec_properties. 
custom_config needs to be serialized46 # before being passed into start_aip_training function.47 self._exec_properties = {48 'custom_config': {49 executor.TRAINING_ARGS_KEY: self._training_inputs,50 },51 }52 self._model_name = 'model_name'53 self._ai_platform_serving_args = {54 'model_name': self._model_name,55 'project_id': self._project_id,56 }57 self._executor_class_path = 'my.executor.Executor'58 def _setUpTrainingMocks(self):59 self._mock_create = mock.Mock()60 self._mock_api_client.projects().jobs().create = self._mock_create61 self._mock_get = mock.Mock()62 self._mock_api_client.projects().jobs().get.return_value = self._mock_get63 self._mock_get.execute.return_value = {64 'state': 'SUCCEEDED',65 }66 def _serialize_custom_config_under_test(self) -> Dict[Text, Any]:67 """Converts self._exec_properties['custom_config'] to string."""68 result = copy.deepcopy(self._exec_properties)69 result['custom_config'] = json_utils.dumps(result['custom_config'])70 return result71 @mock.patch(72 'tfx.extensions.google_cloud_ai_platform.runner.discovery'73 )74 def testStartAIPTraining(self, mock_discovery):75 mock_discovery.build.return_value = self._mock_api_client76 self._setUpTrainingMocks()77 class_path = 'foo.bar.class'78 runner.start_aip_training(self._inputs, self._outputs,79 self._serialize_custom_config_under_test(),80 class_path,81 self._training_inputs, None)82 self._mock_create.assert_called_with(83 body=mock.ANY, parent='projects/{}'.format(self._project_id))84 (_, kwargs) = self._mock_create.call_args85 body = kwargs['body']86 default_image = 'gcr.io/tfx-oss-public/tfx:{}'.format(version.__version__)87 self.assertDictContainsSubset(88 {89 'masterConfig': {90 'imageUri': default_image,91 },92 'args': [93 '--executor_class_path', class_path, '--inputs', '{}',94 '--outputs', '{}', '--exec-properties', '{"custom_config": '95 '"{\\"ai_platform_training_args\\": {\\"project\\": \\"12345\\"'96 '}}"}'97 ],98 }, body['trainingInput'])99 
self.assertStartsWith(body['jobId'], 'tfx_')100 self._mock_get.execute.assert_called_with()101 @mock.patch(102 'tfx.extensions.google_cloud_ai_platform.runner.discovery'103 )104 def testStartAIPTrainingWithUserContainer(self, mock_discovery):105 mock_discovery.build.return_value = self._mock_api_client106 self._setUpTrainingMocks()107 class_path = 'foo.bar.class'108 self._training_inputs['masterConfig'] = {'imageUri': 'my-custom-image'}109 self._exec_properties['custom_config'][executor.JOB_ID_KEY] = self._job_id110 runner.start_aip_training(self._inputs, self._outputs,111 self._serialize_custom_config_under_test(),112 class_path,113 self._training_inputs, self._job_id)114 self._mock_create.assert_called_with(115 body=mock.ANY, parent='projects/{}'.format(self._project_id))116 (_, kwargs) = self._mock_create.call_args117 body = kwargs['body']118 self.assertDictContainsSubset(119 {120 'masterConfig': {121 'imageUri': 'my-custom-image',122 },123 'args': [124 '--executor_class_path', class_path, '--inputs', '{}',125 '--outputs', '{}', '--exec-properties', '{"custom_config": '126 '"{\\"ai_platform_training_args\\": '127 '{\\"masterConfig\\": {\\"imageUri\\": \\"my-custom-image\\"}, '128 '\\"project\\": \\"12345\\"}, '129 '\\"ai_platform_training_job_id\\": \\"my_jobid\\"}"}'130 ],131 }, body['trainingInput'])132 self.assertEqual(body['jobId'], 'my_jobid')133 self._mock_get.execute.assert_called_with()134 def _setUpPredictionMocks(self):135 self._serving_path = os.path.join(self._output_data_dir, 'serving_path')136 self._model_version = 'model_version'137 self._mock_models_create = mock.Mock()138 self._mock_api_client.projects().models().create = self._mock_models_create139 self._mock_versions_create = mock.Mock()140 self._mock_versions_create.return_value.execute.return_value = {141 'name': 'versions_create_op_name'142 }143 self._mock_api_client.projects().models().versions(144 ).create = self._mock_versions_create145 self._mock_get = mock.Mock()146 
self._mock_api_client.projects().operations().get = self._mock_get147 self._mock_set_default = mock.Mock()148 self._mock_api_client.projects().models().versions(149 ).setDefault = self._mock_set_default150 self._mock_set_default_execute = mock.Mock()151 self._mock_api_client.projects().models().versions().setDefault(152 ).execute = self._mock_set_default_execute153 self._mock_get.return_value.execute.return_value = {154 'done': True,155 'response': {156 'name': self._model_version,157 },158 }159 def _assertDeployModelMockCalls(self,160 expected_models_create_body=None,161 expected_versions_create_body=None,162 expect_set_default=True):163 if not expected_models_create_body:164 expected_models_create_body = {165 'name':166 self._model_name,167 'regions':168 [],169 }170 if not expected_versions_create_body:171 with telemetry_utils.scoped_labels(172 {telemetry_utils.LABEL_TFX_EXECUTOR: self._executor_class_path}):173 labels = telemetry_utils.get_labels_dict()174 expected_versions_create_body = {175 'name':176 self._model_version,177 'deployment_uri':178 self._serving_path,179 'runtime_version':180 runner._get_tf_runtime_version(tf.__version__),181 'python_version':182 runner._get_caip_python_version(183 runner._get_tf_runtime_version(tf.__version__)),184 'labels': labels185 }186 self._mock_models_create.assert_called_with(187 body=mock.ANY,188 parent='projects/{}'.format(self._project_id),189 )190 (_, models_create_kwargs) = self._mock_models_create.call_args191 self.assertDictEqual(expected_models_create_body,192 models_create_kwargs['body'])193 self._mock_versions_create.assert_called_with(194 body=mock.ANY,195 parent='projects/{}/models/{}'.format(self._project_id,196 self._model_name))197 (_, versions_create_kwargs) = self._mock_versions_create.call_args198 self.assertDictEqual(expected_versions_create_body,199 versions_create_kwargs['body'])200 if not expect_set_default:201 return202 self._mock_set_default.assert_called_with(203 
name='projects/{}/models/{}/versions/{}'.format(204 self._project_id, self._model_name, self._model_version))205 self._mock_set_default_execute.assert_called_with()206 @mock.patch(207 'tfx.extensions.google_cloud_ai_platform.runner.discovery'208 )209 def testDeployModelForAIPPrediction(self, mock_discovery):210 mock_discovery.build.return_value = self._mock_api_client211 self._setUpPredictionMocks()212 runner.deploy_model_for_aip_prediction(self._serving_path,213 self._model_version,214 self._ai_platform_serving_args,215 self._executor_class_path)216 expected_models_create_body = {217 'name': self._model_name,218 'regions': []219 }220 self._assertDeployModelMockCalls(221 expected_models_create_body=expected_models_create_body)222 @mock.patch(223 'tfx.extensions.google_cloud_ai_platform.runner.discovery'224 )225 def testDeployModelForAIPPredictionError(self, mock_discovery):226 mock_discovery.build.return_value = self._mock_api_client227 self._setUpPredictionMocks()228 self._mock_get.return_value.execute.return_value = {229 'done': True,230 'error': {231 'code': 999,232 'message': 'it was an error.'233 },234 }235 with self.assertRaises(RuntimeError):236 runner.deploy_model_for_aip_prediction(self._serving_path,237 self._model_version,238 self._ai_platform_serving_args,239 self._executor_class_path)240 expected_models_create_body = {241 'name': self._model_name,242 'regions': []243 }244 self._assertDeployModelMockCalls(245 expected_models_create_body=expected_models_create_body,246 expect_set_default=False)247 @mock.patch(248 'tfx.extensions.google_cloud_ai_platform.runner.discovery'249 )250 def testDeployModelForAIPPredictionWithCustomRegion(self, mock_discovery):251 mock_discovery.build.return_value = self._mock_api_client252 self._setUpPredictionMocks()253 self._ai_platform_serving_args['regions'] = ['custom-region']254 runner.deploy_model_for_aip_prediction(self._serving_path,255 self._model_version,256 self._ai_platform_serving_args,257 
self._executor_class_path)258 expected_models_create_body = {259 'name': self._model_name,260 'regions': ['custom-region'],261 }262 self._assertDeployModelMockCalls(263 expected_models_create_body=expected_models_create_body)264 @mock.patch(265 'tfx.extensions.google_cloud_ai_platform.runner.discovery'266 )267 def testDeployModelForAIPPredictionWithCustomRuntime(self, mock_discovery):268 mock_discovery.build.return_value = self._mock_api_client269 self._setUpPredictionMocks()270 self._ai_platform_serving_args['runtime_version'] = '1.23.45'271 runner.deploy_model_for_aip_prediction(self._serving_path,272 self._model_version,273 self._ai_platform_serving_args,274 self._executor_class_path)275 with telemetry_utils.scoped_labels(276 {telemetry_utils.LABEL_TFX_EXECUTOR: self._executor_class_path}):277 labels = telemetry_utils.get_labels_dict()278 expected_versions_create_body = {279 'name': self._model_version,280 'deployment_uri': self._serving_path,281 'runtime_version': '1.23.45',282 'python_version': runner._get_caip_python_version('1.23.45'),283 'labels': labels,284 }285 self._assertDeployModelMockCalls(286 expected_versions_create_body=expected_versions_create_body)287 def testGetTensorflowRuntime(self):288 self.assertEqual('1.14', runner._get_tf_runtime_version('1.14'))289 self.assertEqual('1.15', runner._get_tf_runtime_version('1.15.0'))290 self.assertEqual('1.15', runner._get_tf_runtime_version('1.15.1'))291 self.assertEqual('1.15', runner._get_tf_runtime_version('2.0.0'))292 self.assertEqual('1.15', runner._get_tf_runtime_version('2.0.1'))293 self.assertEqual('2.1', runner._get_tf_runtime_version('2.1.0'))294 # TODO(b/157039850) Remove this once CAIP model support TF 2.2 runtime.295 self.assertEqual('2.1', runner._get_tf_runtime_version('2.2.0'))296 def testGetCaipPythonVersion(self):297 if sys.version_info.major == 2:298 self.assertEqual('2.7', runner._get_caip_python_version('1.14'))299 self.assertEqual('2.7', runner._get_caip_python_version('1.15'))300 
else: # 3.x301 self.assertEqual('3.5', runner._get_caip_python_version('1.14'))302 self.assertEqual('3.7', runner._get_caip_python_version('1.15'))303 self.assertEqual('3.7', runner._get_caip_python_version('2.1'))304if __name__ == '__main__':...

Full Screen

Full Screen

client_test.py

Source:client_test.py Github

copy

Full Screen

# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for cloud_fit.client."""

import os
import shutil
import tempfile
from unittest import mock

import cloudpickle
from googleapiclient import discovery
import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow_enterprise_addons.cloud_fit import client
from tensorflow_enterprise_addons.cloud_fit import utils

# Can only export Datasets which were created executing eagerly
utils.enable_eager_for_tf_1()

MIRRORED_STRATEGY_NAME = utils.MIRRORED_STRATEGY_NAME
MULTI_WORKER_MIRRORED_STRATEGY_NAME = utils.MULTI_WORKER_MIRRORED_STRATEGY_NAME


class CloudFitClientTest(tf.test.TestCase):
  """Tests for the cloud_fit client with the AI Platform API mocked out."""

  def setUp(self):
    super(CloudFitClientTest, self).setUp()
    self._image_uri = 'gcr.io/some_test_image:latest'
    self._project_id = 'test_project_id'
    self._region = 'test_region'
    self._mock_apiclient = mock.Mock()
    self._remote_dir = tempfile.mkdtemp()
    # mkdtemp() directories are not removed automatically; clean up after
    # each test so repeated runs do not accumulate assets on disk.
    self.addCleanup(shutil.rmtree, self._remote_dir, ignore_errors=True)
    self._job_spec = client._default_job_spec(self._region, self._image_uri, [
        '--remote_dir', self._remote_dir, '--distribution_strategy',
        MULTI_WORKER_MIRRORED_STRATEGY_NAME
    ])
    self._model = self._build_model()
    self._x = [[9.], [10.], [11.]] * 10
    self._y = [[xi[0] / 2. + 6] for xi in self._x]
    self._dataset = tf.data.Dataset.from_tensor_slices((self._x, self._y))
    self._scalar_fit_kwargs = {'batch_size': 1, 'epochs': 2, 'verbose': 3}

  def _set_up_training_mocks(self):
    """Mocks jobs().create; executing the request reports SUCCEEDED."""
    self._mock_create = mock.Mock()
    self._mock_apiclient.projects().jobs().create = self._mock_create
    self._mock_get = mock.Mock()
    self._mock_create.return_value = self._mock_get
    self._mock_get.execute.return_value = {
        'state': 'SUCCEEDED',
    }

  def _build_model(self):
    """Builds a compiled Keras model computing y = wx + 1, with w trainable.

    Named `_build_model` (not `_model`) so the `self._model` attribute set in
    setUp does not shadow the method.

    Returns:
      A `tf.keras.Model` compiled with SGD and mean squared error, running
      eagerly.
    """
    inp = tf.keras.layers.Input(shape=(1,), dtype=tf.float32)
    times_w = tf.keras.layers.Dense(
        units=1,
        kernel_initializer=tf.keras.initializers.Constant([[0.5]]),
        kernel_regularizer=tf.keras.regularizers.l2(0.01),
        use_bias=False)
    # Frozen layer with kernel 1 and bias 1, i.e. adds 1 to its input.
    plus_1 = tf.keras.layers.Dense(
        units=1,
        kernel_initializer=tf.keras.initializers.Constant([[1.0]]),
        bias_initializer=tf.keras.initializers.Constant([1.0]),
        trainable=False)
    outp = plus_1(times_w(inp))
    model = tf.keras.Model(inp, outp)

    model.compile(
        tf.keras.optimizers.SGD(0.002), 'mean_squared_error', run_eagerly=True)
    return model

  def _check_tf1_unsupported(self):
    """Asserts cloud_fit rejects TF 1.x.

    Returns:
      True when running under TF 1.x (after asserting that cloud_fit raises
      RuntimeError), in which case the caller should return early; False
      otherwise.
    """
    if not utils.is_tf_v1():
      return False
    with self.assertRaises(RuntimeError):
      client.cloud_fit(
          self._model,
          x=self._dataset,
          validation_data=self._dataset,
          remote_dir=self._remote_dir,
          job_spec=self._job_spec,
          batch_size=1,
          epochs=2,
          verbose=3)
    return True

  def test_default_job_spec(self):
    self.assertStartsWith(self._job_spec['job_id'], 'cloud_fit_')
    self.assertDictContainsSubset(
        {
            'masterConfig': {
                'imageUri': self._image_uri,
            },
            'args': [
                '--remote_dir', self._remote_dir, '--distribution_strategy',
                MULTI_WORKER_MIRRORED_STRATEGY_NAME
            ],
        }, self._job_spec['trainingInput'])

  @mock.patch.object(discovery, 'build', autospec=True)
  def test_submit_job(self, mock_discovery_build):
    mock_discovery_build.return_value = self._mock_apiclient
    self._set_up_training_mocks()

    client._submit_job(self._job_spec, self._project_id)

    mock_discovery_build.assert_called_once_with(
        'ml', 'v1', cache_discovery=False)

    self._mock_create.assert_called_with(
        body=mock.ANY, parent='projects/{}'.format(self._project_id))

    _, fit_kwargs = list(self._mock_create.call_args)
    body = fit_kwargs['body']
    self.assertDictContainsSubset(
        {
            'masterConfig': {
                'imageUri': self._image_uri,
            },
            'args': [
                '--remote_dir', self._remote_dir, '--distribution_strategy',
                MULTI_WORKER_MIRRORED_STRATEGY_NAME
            ],
        }, body['trainingInput'])
    self.assertStartsWith(body['job_id'], 'cloud_fit_')
    self._mock_get.execute.assert_called_with()

  def test_serialize_assets(self):
    # TF 1.x is not supported
    if self._check_tf1_unsupported():
      return

    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=self._remote_dir)
    # Copy so the shared fixture dict is not mutated.
    args = dict(self._scalar_fit_kwargs)
    args['callbacks'] = [tensorboard_callback]

    client._serialize_assets(self._remote_dir, self._model, **args)
    self.assertGreaterEqual(
        len(
            tf.io.gfile.listdir(
                os.path.join(self._remote_dir, 'training_assets'))), 1)
    self.assertGreaterEqual(
        len(tf.io.gfile.listdir(os.path.join(self._remote_dir, 'model'))), 1)

    training_assets_graph = tf.saved_model.load(
        os.path.join(self._remote_dir, 'training_assets'))

    # The pickle was produced by this test itself, so loading it is safe.
    pickled_callbacks = tfds.as_numpy(training_assets_graph.callbacks_fn())
    unpickled_callbacks = cloudpickle.loads(pickled_callbacks)
    self.assertIsInstance(unpickled_callbacks[0],
                          tf.keras.callbacks.TensorBoard)

  @mock.patch.object(client, '_submit_job', autospec=True)
  def test_fit_kwargs(self, mock_submit_job):
    # TF 1.x is not supported
    if self._check_tf1_unsupported():
      return

    job_id = client.cloud_fit(
        self._model,
        x=self._dataset,
        validation_data=self._dataset,
        remote_dir=self._remote_dir,
        region=self._region,
        project_id=self._project_id,
        image_uri=self._image_uri,
        batch_size=1,
        epochs=2,
        verbose=3)

    kargs, _ = mock_submit_job.call_args
    body, _ = kargs
    self.assertEqual(body['job_id'], job_id)
    remote_dir = body['trainingInput']['args'][1]

    training_assets_graph = tf.saved_model.load(
        os.path.join(remote_dir, 'training_assets'))
    elements = training_assets_graph.fit_kwargs_fn()
    # assertDictContainsSubset(subset, dictionary): the expected kwargs must
    # be the subset; the reversed order would pass for an empty result.
    self.assertDictContainsSubset(
        {
            'batch_size': 1,
            'epochs': 2,
            'verbose': 3,
        }, tfds.as_numpy(elements))

  @mock.patch.object(client, '_submit_job', autospec=True)
  def test_custom_job_spec(self, mock_submit_job):
    # TF 1.x is not supported
    if self._check_tf1_unsupported():
      return

    client.cloud_fit(
        self._model,
        x=self._dataset,
        validation_data=self._dataset,
        remote_dir=self._remote_dir,
        job_spec=self._job_spec,
        batch_size=1,
        epochs=2,
        verbose=3)

    kargs, _ = mock_submit_job.call_args
    body, _ = kargs
    self.assertDictContainsSubset(
        {
            'masterConfig': {
                'imageUri': self._image_uri,
            },
            'args': [
                '--remote_dir', self._remote_dir, '--distribution_strategy',
                MULTI_WORKER_MIRRORED_STRATEGY_NAME
            ],
        }, body['trainingInput'])

  @mock.patch.object(client, '_submit_job', autospec=True)
  @mock.patch.object(client, '_serialize_assets', autospec=True)
  def test_distribution_strategy(self, mock_serialize_assets, mock_submit_job):
    del mock_serialize_assets  # Unused; patched only to skip serialization.
    # TF 1.x is not supported
    if utils.is_tf_v1():
      with self.assertRaises(RuntimeError):
        client.cloud_fit(
            self._model, x=self._dataset, remote_dir=self._remote_dir)
      return

    client.cloud_fit(self._model, x=self._dataset, remote_dir=self._remote_dir)

    kargs, _ = mock_submit_job.call_args
    body, _ = kargs
    self.assertDictContainsSubset(
        {
            'args': [
                '--remote_dir', self._remote_dir, '--distribution_strategy',
                MULTI_WORKER_MIRRORED_STRATEGY_NAME
            ],
        }, body['trainingInput'])

    client.cloud_fit(
        self._model,
        x=self._dataset,
        remote_dir=self._remote_dir,
        distribution_strategy=MIRRORED_STRATEGY_NAME,
        job_spec=self._job_spec)

    kargs, _ = mock_submit_job.call_args
    body, _ = kargs
    self.assertDictContainsSubset(
        {
            'args': [
                '--remote_dir', self._remote_dir, '--distribution_strategy',
                MIRRORED_STRATEGY_NAME
            ],
        }, body['trainingInput'])

    with self.assertRaises(ValueError):
      client.cloud_fit(
          self._model,
          x=self._dataset,
          remote_dir=self._remote_dir,
          distribution_strategy='not_implemented_strategy',
          job_spec=self._job_spec)

  @mock.patch.object(client, '_submit_job', autospec=True)
  @mock.patch.object(client, '_serialize_assets', autospec=True)
  def test_job_id(self, mock_serialize_assets, mock_submit_job):
    del mock_serialize_assets  # Unused; patched only to skip serialization.
    # TF 1.x is not supported
    if self._check_tf1_unsupported():
      return

    test_job_id = 'test_job_id'
    client.cloud_fit(
        self._model,
        x=self._dataset,
        validation_data=self._dataset,
        remote_dir=self._remote_dir,
        job_spec=self._job_spec,
        job_id=test_job_id,
        batch_size=1,
        epochs=2,
        verbose=3)

    kargs, _ = mock_submit_job.call_args
    body, _ = kargs
    self.assertDictContainsSubset({
        'job_id': test_job_id,
    }, body)


if __name__ == '__main__':
  # NOTE(review): the scraped snippet was truncated after this guard; TF test
  # modules conventionally call tf.test.main() here — confirm upstream.
  tf.test.main()

Full Screen

Full Screen

ai_platform_training_executor_test.py

Source:ai_platform_training_executor_test.py Github

copy

Full Screen

# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for AI Platform Training component executor."""

import copy
from unittest import mock

from googleapiclient import discovery
import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
from tfx.dsl.component.experimental import placeholders
from tfx.orchestration.kubeflow.v2.components.experimental import ai_platform_training_executor
from tfx.types import artifact_utils
from tfx.types import standard_artifacts
from tfx.utils import json_utils
from tfx.utils import test_case_utils

_EXAMPLE_LOCATION = 'root/ExampleGen/1/examples/'
_MODEL_LOCATION = 'root/Training/2/model/'


class AiPlatformTrainingExecutorTest(test_case_utils.TfxTest):
  """Tests AiPlatformTrainingExecutor with `discovery.build` mocked out."""

  def setUp(self):
    super().setUp()
    self._project_id = 'my-project'
    self._job_id = 'my-job-123'
    self._labels = ['label1', 'label2']

    examples_artifact = standard_artifacts.Examples()
    examples_artifact.split_names = artifact_utils.encode_split_names(
        ['train', 'eval'])
    examples_artifact.uri = _EXAMPLE_LOCATION
    self._inputs = {'examples': [examples_artifact]}
    model_artifact = standard_artifacts.Model()
    model_artifact.uri = _MODEL_LOCATION
    self._outputs = {'model': [model_artifact]}

    training_job = {
        'training_input': {
            'scaleTier': 'CUSTOM',
            'region': 'us-central1',
            'masterType': 'n1-standard-8',
            'masterConfig': {
                'imageUri': 'gcr.io/my-project/caip-training-test:latest'
            },
            'workerType': 'n1-standard-8',
            'workerCount': 8,
            'workerConfig': {
                'imageUri': 'gcr.io/my-project/caip-training-test:latest'
            },
            'args': [
                '--examples',
                placeholders.InputUriPlaceholder('examples'), '--n-steps',
                placeholders.InputValuePlaceholder('n_step'), '--model-dir',
                placeholders.OutputUriPlaceholder('model')
            ]
        },
        'labels': self._labels,
    }

    aip_training_config = {
        ai_platform_training_executor.PROJECT_CONFIG_KEY: self._project_id,
        ai_platform_training_executor.TRAINING_JOB_CONFIG_KEY: training_job,
        ai_platform_training_executor.JOB_ID_CONFIG_KEY: self._job_id,
        ai_platform_training_executor.LABELS_CONFIG_KEY: self._labels,
    }

    self._exec_properties = {
        ai_platform_training_executor.CONFIG_KEY:
            json_utils.dumps(aip_training_config),
        'n_step': 100,
    }

    # Expected job spec once the placeholders in 'args' are resolved against
    # the artifacts and exec properties above.
    resolved_training_input = copy.deepcopy(training_job['training_input'])
    resolved_training_input['args'] = [
        '--examples', _EXAMPLE_LOCATION, '--n-steps', '100', '--model-dir',
        _MODEL_LOCATION
    ]

    self._expected_job_spec = {
        'job_id': self._job_id,
        'training_input': resolved_training_input,
        'labels': self._labels,
    }

    self._mock_api_client = mock.Mock()
    mock_discovery = self.enter_context(
        mock.patch.object(discovery, 'build', autospec=True))
    mock_discovery.return_value = self._mock_api_client
    self._setUpTrainingMocks()

  def _setUpTrainingMocks(self):
    """Mocks jobs().create and jobs().get; the polled job reports SUCCEEDED."""
    self._mock_create = mock.Mock()
    self._mock_api_client.projects().jobs().create = self._mock_create
    self._mock_get = mock.Mock()
    self._mock_api_client.projects().jobs().get.return_value = self._mock_get
    self._mock_get.execute.return_value = {
        'state': 'SUCCEEDED',
    }

  def testRunAipTraining(self):
    aip_executor = ai_platform_training_executor.AiPlatformTrainingExecutor()
    aip_executor.Do(
        input_dict=self._inputs,
        output_dict=self._outputs,
        exec_properties=self._exec_properties)

    self._mock_create.assert_called_once_with(
        body=self._expected_job_spec,
        parent='projects/{}'.format(self._project_id))

  def testRunAipTrainingWithDefaultJobId(self):
    aip_executor = ai_platform_training_executor.AiPlatformTrainingExecutor()
    # Clear job_id (set it to None) in the serialized config so the executor
    # has to choose a job id itself.
    training_config = json_utils.loads(
        self._exec_properties[ai_platform_training_executor.CONFIG_KEY])
    training_config[ai_platform_training_executor.JOB_ID_CONFIG_KEY] = None
    self._exec_properties[ai_platform_training_executor.CONFIG_KEY] = (
        json_utils.dumps(training_config))

    aip_executor.Do(
        input_dict=self._inputs,
        output_dict=self._outputs,
        exec_properties=self._exec_properties)

    self._mock_create.assert_called_once()
    job_id = self._mock_create.call_args[1]['body']['job_id']
    self.assertStartsWith(job_id, 'tfx_')


if __name__ == '__main__':
  # NOTE(review): the scraped snippet was truncated after this guard; TF test
  # modules conventionally call tf.test.main() here — confirm upstream.
  tf.test.main()

Full Screen

Full Screen

Automation Testing Tutorials

Learn to execute automation testing from scratch with the LambdaTest Learning Hub, from setting up the prerequisites and running your first automation test to following best practices and diving deeper into advanced test scenarios. LambdaTest Learning Hubs compile step-by-step guides to help you become proficient with different test automation frameworks, e.g. Selenium, Cypress, and TestNG.

LambdaTest Learning Hubs:

YouTube

You can also refer to the video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.

Run autotest automation tests on LambdaTest cloud grid

Perform automation testing on 3000+ real desktop and mobile devices online.

Try LambdaTest Now!!

Get 100 minutes of automation testing FREE!!

Next-Gen App & Browser Testing Cloud

Was this article helpful?

Helpful

Not Helpful