Best Python code snippet using autotest_python
test_sqlalchemy_store.py
Source: test_sqlalchemy_store.py
1import os2import unittest3import mock4import tempfile5import uuid6import mlflow7import mlflow.db8import mlflow.store.db.base_sql_model9from mlflow.entities.model_registry import RegisteredModel, ModelVersion, \10    RegisteredModelTag, ModelVersionTag11from mlflow.exceptions import MlflowException12from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST, \13    INVALID_PARAMETER_VALUE, RESOURCE_ALREADY_EXISTS14from mlflow.store.model_registry.sqlalchemy_store import SqlAlchemyStore15from tests.helper_functions import random_str16DB_URI = 'sqlite:///'17class TestSqlAlchemyStoreSqlite(unittest.TestCase):18    def _get_store(self, db_uri=''):19        return SqlAlchemyStore(db_uri)20    def setUp(self):21        self.maxDiff = None  # print all differences on assert failures22        fd, self.temp_dbfile = tempfile.mkstemp()23        # Close handle immediately so that we can remove the file later on in Windows24        os.close(fd)25        self.db_url = "%s%s" % (DB_URI, self.temp_dbfile)26        self.store = self._get_store(self.db_url)27    def tearDown(self):28        mlflow.store.db.base_sql_model.Base.metadata.drop_all(self.store.engine)29        os.remove(self.temp_dbfile)30    def _rm_maker(self, name, tags=None):31        return self.store.create_registered_model(name, tags)32    def _mv_maker(self, name, source="path/to/source", run_id=uuid.uuid4().hex, tags=None):33        return self.store.create_model_version(name, source, run_id, tags)34    def _extract_latest_by_stage(self, latest_versions):35        return {mvd.current_stage: mvd.version for mvd in latest_versions}36    def test_create_registered_model(self):37        name = random_str() + "abCD"38        rm1 = self._rm_maker(name)39        self.assertEqual(rm1.name, name)40        # error on duplicate41        with self.assertRaises(MlflowException) as exception_context:42            self._rm_maker(name)43        assert exception_context.exception.error_code == 
ErrorCode.Name(RESOURCE_ALREADY_EXISTS)44        # slightly different name is ok45        for name2 in [name + "extra", name.lower(), name.upper(), name + name]:46            rm2 = self._rm_maker(name2)47            self.assertEqual(rm2.name, name2)48        # test create model with tags49        name2 = random_str() + "tags"50        tags = [RegisteredModelTag("key", "value"),51                RegisteredModelTag("anotherKey", "some other value")]52        rm2 = self._rm_maker(name2, tags)53        rmd2 = self.store.get_registered_model(name2)54        self.assertEqual(rm2.name, name2)55        self.assertEqual(rm2.tags, {tag.key: tag.value for tag in tags})56        self.assertEqual(rmd2.name, name2)57        self.assertEqual(rmd2.tags, {tag.key: tag.value for tag in tags})58        # invalid model name will fail59        with self.assertRaises(MlflowException) as exception_context:60            self._rm_maker(None)61        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)62        with self.assertRaises(MlflowException) as exception_context:63            self._rm_maker("")64        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)65    def test_get_registered_model(self):66        name = "model_1"67        tags = [RegisteredModelTag("key", "value"),68                RegisteredModelTag("anotherKey", "some other value")]69        # use fake clock70        with mock.patch("time.time") as mock_time:71            mock_time.return_value = 123472            rm = self._rm_maker(name, tags)73            self.assertEqual(rm.name, name)74        rmd = self.store.get_registered_model(name=name)75        self.assertEqual(rmd.name, name)76        self.assertEqual(rmd.creation_timestamp, 1234000)77        self.assertEqual(rmd.last_updated_timestamp, 1234000)78        self.assertEqual(rmd.description, None)79        self.assertEqual(rmd.latest_versions, [])80        self.assertEqual(rmd.tags, 
{tag.key: tag.value for tag in tags})81    def test_update_registered_model(self):82        name = "model_for_update_RM"83        rm1 = self._rm_maker(name)84        rmd1 = self.store.get_registered_model(name=name)85        self.assertEqual(rm1.name, name)86        self.assertEqual(rmd1.description, None)87        # update description88        rm2 = self.store.update_registered_model(name=name, description="test model")89        rmd2 = self.store.get_registered_model(name=name)90        self.assertEqual(rm2.name, "model_for_update_RM")91        self.assertEqual(rmd2.name, "model_for_update_RM")92        self.assertEqual(rmd2.description, "test model")93    def test_rename_registered_model(self):94        original_name = "original name"95        new_name = "new name"96        self._rm_maker(original_name)97        self._mv_maker(original_name)98        self._mv_maker(original_name)99        rm = self.store.get_registered_model(original_name)100        mv1 = self.store.get_model_version(original_name, 1)101        mv2 = self.store.get_model_version(original_name, 2)102        self.assertEqual(rm.name, original_name)103        self.assertEqual(mv1.name, original_name)104        self.assertEqual(mv2.name, original_name)105        # test renaming registered model also updates its model versions106        self.store.rename_registered_model(original_name, new_name)107        rm = self.store.get_registered_model(new_name)108        mv1 = self.store.get_model_version(new_name, 1)109        mv2 = self.store.get_model_version(new_name, 2)110        self.assertEqual(rm.name, new_name)111        self.assertEqual(mv1.name, new_name)112        self.assertEqual(mv2.name, new_name)113        # test accessing the model with the old name will fail114        with self.assertRaises(MlflowException) as exception_context:115            self.store.get_registered_model(original_name)116        assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)117      
  # test name another model with the replaced name is ok118        self._rm_maker(original_name)119        # cannot rename model to conflict with an existing model120        with self.assertRaises(MlflowException) as exception_context:121            self.store.rename_registered_model(new_name, original_name)122        assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_ALREADY_EXISTS)123        # invalid model name will fail124        with self.assertRaises(MlflowException) as exception_context:125            self.store.rename_registered_model(original_name, None)126        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)127        with self.assertRaises(MlflowException) as exception_context:128            self.store.rename_registered_model(original_name, "")129        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)130    def test_delete_registered_model(self):131        name = "model_for_delete_RM"132        self._rm_maker(name)133        self._mv_maker(name)134        rm1 = self.store.get_registered_model(name=name)135        mv1 = self.store.get_model_version(name, 1)136        self.assertEqual(rm1.name, name)137        self.assertEqual(mv1.name, name)138        # delete model139        self.store.delete_registered_model(name=name)140        # cannot get model141        with self.assertRaises(MlflowException) as exception_context:142            self.store.get_registered_model(name=name)143        assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)144        # cannot update a delete model145        with self.assertRaises(MlflowException) as exception_context:146            self.store.update_registered_model(name=name, description="deleted")147        assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)148        # cannot delete it again149        with self.assertRaises(MlflowException) as 
exception_context:150            self.store.delete_registered_model(name=name)151        assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)152        # model versions are cascade deleted with the registered model153        with self.assertRaises(MlflowException) as exception_context:154            self.store.get_model_version(name, 1)155        assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)156    def _list_registered_models(self, page_token=None, max_results=10):157        result = self.store.list_registered_models(max_results, page_token)158        for idx in range(len(result)):159            result[idx] = result[idx].name160        return result161    def test_list_registered_model(self):162        self._rm_maker("A")163        registered_models = self.store.list_registered_models(max_results=10, page_token=None)164        self.assertEqual(len(registered_models), 1)165        self.assertEqual(registered_models[0].name, "A")166        self.assertIsInstance(registered_models[0], RegisteredModel)167        self._rm_maker("B")168        self.assertEqual(set(self._list_registered_models()),169                         set(["A", "B"]))170        self._rm_maker("BB")171        self._rm_maker("BA")172        self._rm_maker("AB")173        self._rm_maker("BBC")174        self.assertEqual(set(self._list_registered_models()),175                         set(["A", "B", "BB", "BA", "AB", "BBC"]))176        # list should not return deleted models177        self.store.delete_registered_model(name="BA")178        self.store.delete_registered_model(name="B")179        self.assertEqual(set(self._list_registered_models()),180                         set(["A", "BB", "AB", "BBC"]))181    def test_list_registered_model_paginated_last_page(self):182        rms = [self._rm_maker(f"RM{i:03}").name for i in range(50)]183        # test flow with fixed max_results184        returned_rms = []185        result = 
self._list_registered_models(page_token=None, max_results=25)186        returned_rms.extend(result)187        while result.token:188            result = self._list_registered_models(page_token=result.token, max_results=25)189            self.assertEqual(len(result), 25)190            returned_rms.extend(result)191        self.assertEqual(result.token, None)192        self.assertEqual(set(rms), set(returned_rms))193    def test_list_registered_model_paginated_returns_in_correct_order(self):194        rms = [self._rm_maker(f"RM{i:03}").name for i in range(50)]195        # test that pagination will return all valid results in sorted order196        # by name ascending197        result = self._list_registered_models(max_results=5)198        self.assertNotEqual(result.token, None)199        self.assertEqual(result, rms[0:5])200        result = self._list_registered_models(page_token=result.token, max_results=10)201        self.assertNotEqual(result.token, None)202        self.assertEqual(result, rms[5:15])203        result = self._list_registered_models(page_token=result.token, max_results=20)204        self.assertNotEqual(result.token, None)205        self.assertEqual(result, rms[15:35])206        result = self._list_registered_models(page_token=result.token, max_results=100)207        # assert that page token is None208        self.assertEqual(result.token, None)209        self.assertEqual(result, rms[35:])210    def test_list_registered_model_paginated_errors(self):211        rms = [self._rm_maker(f"RM{i:03}").name for i in range(50)]212        # test that providing a completely invalid page token throws213        with self.assertRaises(MlflowException) as exception_context:214            self._list_registered_models(page_token="evilhax", max_results=20)215        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)216        # test that providing too large of a max_results throws217        with self.assertRaises(MlflowException) 
as exception_context:218            self._list_registered_models(page_token="evilhax", max_results=1e15)219        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)220        self.assertIn("Invalid value for request parameter max_results",221                      exception_context.exception.message)222        # list should not return deleted models223        self.store.delete_registered_model(name=f"RM{0:03}")224        self.assertEqual(set(self._list_registered_models(max_results=100)),225                         set(rms[1:]))226    def test_get_latest_versions(self):227        name = "test_for_latest_versions"228        self._rm_maker(name)229        rmd1 = self.store.get_registered_model(name=name)230        self.assertEqual(rmd1.latest_versions, [])231        mv1 = self._mv_maker(name)232        self.assertEqual(mv1.version, 1)233        rmd2 = self.store.get_registered_model(name=name)234        self.assertEqual(self._extract_latest_by_stage(rmd2.latest_versions), {"None": 1})235        # add a bunch more236        mv2 = self._mv_maker(name)237        self.assertEqual(mv2.version, 2)238        self.store.transition_model_version_stage(239            name=mv2.name, version=mv2.version, stage="Production",240            archive_existing_versions=False)241        mv3 = self._mv_maker(name)242        self.assertEqual(mv3.version, 3)243        self.store.transition_model_version_stage(name=mv3.name, version=mv3.version,244                                                  stage="Production",245                                                  archive_existing_versions=False)246        mv4 = self._mv_maker(name)247        self.assertEqual(mv4.version, 4)248        self.store.transition_model_version_stage(249            name=mv4.name, version=mv4.version, stage="Staging",250            archive_existing_versions=False)251        # test that correct latest versions are returned for each stage252        rmd4 = 
self.store.get_registered_model(name=name)253        self.assertEqual(self._extract_latest_by_stage(rmd4.latest_versions),254                         {"None": 1, "Production": 3, "Staging": 4})255        # delete latest Production, and should point to previous one256        self.store.delete_model_version(name=mv3.name, version=mv3.version)257        rmd5 = self.store.get_registered_model(name=name)258        self.assertEqual(self._extract_latest_by_stage(rmd5.latest_versions),259                         {"None": 1, "Production": 2, "Staging": 4})260    def test_set_registered_model_tag(self):261        name1 = "SetRegisteredModelTag_TestMod"262        name2 = "SetRegisteredModelTag_TestMod 2"263        initial_tags = [RegisteredModelTag("key", "value"),264                        RegisteredModelTag("anotherKey", "some other value")]265        self._rm_maker(name1, initial_tags)266        self._rm_maker(name2, initial_tags)267        new_tag = RegisteredModelTag("randomTag", "not a random value")268        self.store.set_registered_model_tag(name1, new_tag)269        rm1 = self.store.get_registered_model(name=name1)270        all_tags = initial_tags + [new_tag]271        self.assertEqual(rm1.tags, {tag.key: tag.value for tag in all_tags})272        # test overriding a tag with the same key273        overriding_tag = RegisteredModelTag("key", "overriding")274        self.store.set_registered_model_tag(name1, overriding_tag)275        all_tags = [tag for tag in all_tags if tag.key != "key"] + [overriding_tag]276        rm1 = self.store.get_registered_model(name=name1)277        self.assertEqual(rm1.tags, {tag.key: tag.value for tag in all_tags})278        # does not affect other models with the same key279        rm2 = self.store.get_registered_model(name=name2)280        self.assertEqual(rm2.tags, {tag.key: tag.value for tag in initial_tags})281        # can not set tag on deleted (non-existed) registered model282        self.store.delete_registered_model(name1)283   
     with self.assertRaises(MlflowException) as exception_context:284            self.store.set_registered_model_tag(name1, overriding_tag)285        assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)286        # test cannot set tags that are too long287        long_tag = RegisteredModelTag("longTagKey", "a" * 5001)288        with self.assertRaises(MlflowException) as exception_context:289            self.store.set_registered_model_tag(name2, long_tag)290        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)291        # test can set tags that are somewhat long292        long_tag = RegisteredModelTag("longTagKey", "a" * 4999)293        self.store.set_registered_model_tag(name2, long_tag)294        # can not set invalid tag295        with self.assertRaises(MlflowException) as exception_context:296            self.store.set_registered_model_tag(name2, RegisteredModelTag(key=None, value=""))297        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)298        # can not use invalid model name299        with self.assertRaises(MlflowException) as exception_context:300            self.store.set_registered_model_tag(None, RegisteredModelTag(key="key", value="value"))301        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)302    def test_delete_registered_model_tag(self):303        name1 = "DeleteRegisteredModelTag_TestMod"304        name2 = "DeleteRegisteredModelTag_TestMod 2"305        initial_tags = [RegisteredModelTag("key", "value"),306                        RegisteredModelTag("anotherKey", "some other value")]307        self._rm_maker(name1, initial_tags)308        self._rm_maker(name2, initial_tags)309        new_tag = RegisteredModelTag("randomTag", "not a random value")310        self.store.set_registered_model_tag(name1, new_tag)311        self.store.delete_registered_model_tag(name1, "randomTag")312     
   rm1 = self.store.get_registered_model(name=name1)313        self.assertEqual(rm1.tags, {tag.key: tag.value for tag in initial_tags})314        # testing deleting a key does not affect other models with the same key315        self.store.delete_registered_model_tag(name1, "key")316        rm1 = self.store.get_registered_model(name=name1)317        rm2 = self.store.get_registered_model(name=name2)318        self.assertEqual(rm1.tags, {"anotherKey": "some other value"})319        self.assertEqual(rm2.tags, {tag.key: tag.value for tag in initial_tags})320        # delete tag that is already deleted does nothing321        self.store.delete_registered_model_tag(name1, "key")322        rm1 = self.store.get_registered_model(name=name1)323        self.assertEqual(rm1.tags, {"anotherKey": "some other value"})324        # can not delete tag on deleted (non-existed) registered model325        self.store.delete_registered_model(name1)326        with self.assertRaises(MlflowException) as exception_context:327            self.store.delete_registered_model_tag(name1, "anotherKey")328        assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)329        # can not delete tag with invalid key330        with self.assertRaises(MlflowException) as exception_context:331            self.store.delete_registered_model_tag(name2, None)332        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)333        # can not use invalid model name334        with self.assertRaises(MlflowException) as exception_context:335            self.store.delete_registered_model_tag(None, "key")336        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)337    def test_create_model_version(self):338        name = "test_for_update_MV"339        self._rm_maker(name)340        run_id = uuid.uuid4().hex341        with mock.patch("time.time") as mock_time:342            mock_time.return_value = 456778343   
         mv1 = self._mv_maker(name, "a/b/CD", run_id)344            self.assertEqual(mv1.name, name)345            self.assertEqual(mv1.version, 1)346        mvd1 = self.store.get_model_version(mv1.name, mv1.version)347        self.assertEqual(mvd1.name, name)348        self.assertEqual(mvd1.version, 1)349        self.assertEqual(mvd1.current_stage, "None")350        self.assertEqual(mvd1.creation_timestamp, 456778000)351        self.assertEqual(mvd1.last_updated_timestamp, 456778000)352        self.assertEqual(mvd1.description, None)353        self.assertEqual(mvd1.source, "a/b/CD")354        self.assertEqual(mvd1.run_id, run_id)355        self.assertEqual(mvd1.status, "READY")356        self.assertEqual(mvd1.status_message, None)357        self.assertEqual(mvd1.tags, {})358        # new model versions for same name autoincrement versions359        mv2 = self._mv_maker(name)360        mvd2 = self.store.get_model_version(name=mv2.name, version=mv2.version)361        self.assertEqual(mv2.version, 2)362        self.assertEqual(mvd2.version, 2)363        # create model version with tags return model version entity with tags364        tags = [ModelVersionTag("key", "value"),365                ModelVersionTag("anotherKey", "some other value")]366        mv3 = self._mv_maker(name, tags=tags)367        mvd3 = self.store.get_model_version(name=mv3.name, version=mv3.version)368        self.assertEqual(mv3.version, 3)369        self.assertEqual(mv3.tags, {tag.key: tag.value for tag in tags})370        self.assertEqual(mvd3.version, 3)371        self.assertEqual(mvd3.tags, {tag.key: tag.value for tag in tags})372    def test_update_model_version(self):373        name = "test_for_update_MV"374        self._rm_maker(name)375        mv1 = self._mv_maker(name)376        mvd1 = self.store.get_model_version(name=mv1.name, version=mv1.version)377        self.assertEqual(mvd1.name, name)378        self.assertEqual(mvd1.version, 1)379        self.assertEqual(mvd1.current_stage, 
"None")380        # update stage381        self.store.transition_model_version_stage(name=mv1.name, version=mv1.version,382                                                  stage="Production",383                                                  archive_existing_versions=False)384        mvd2 = self.store.get_model_version(name=mv1.name, version=mv1.version)385        self.assertEqual(mvd2.name, name)386        self.assertEqual(mvd2.version, 1)387        self.assertEqual(mvd2.current_stage, "Production")388        self.assertEqual(mvd2.description, None)389        # update description390        self.store.update_model_version(name=mv1.name, version=mv1.version,391                                        description="test model version")392        mvd3 = self.store.get_model_version(name=mv1.name, version=mv1.version)393        self.assertEqual(mvd3.name, name)394        self.assertEqual(mvd3.version, 1)395        self.assertEqual(mvd3.current_stage, "Production")396        self.assertEqual(mvd3.description, "test model version")397        # only valid stages can be set398        with self.assertRaises(MlflowException) as exception_context:399            self.store.transition_model_version_stage(mv1.name, mv1.version,400                                                      stage="unknown",401                                                      archive_existing_versions=False)402        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)403        # stages are case-insensitive and auto-corrected to system stage names404        for stage_name in ["STAGING", "staging", "StAgInG"]:405            self.store.transition_model_version_stage(406                name=mv1.name, version=mv1.version,407                stage=stage_name, archive_existing_versions=False)408            mvd5 = self.store.get_model_version(name=mv1.name, version=mv1.version)409            self.assertEqual(mvd5.current_stage, "Staging")410    def 
test_transition_model_version_stage_when_archive_existing_versions_is_false(self):411        name = "model"412        self._rm_maker(name)413        mv1 = self._mv_maker(name)414        mv2 = self._mv_maker(name)415        mv3 = self._mv_maker(name)416        # test that when `archive_existing_versions` is False, transitioning a model version417        # to the inactive stages ("Archived" and "None") does not throw.418        for stage in ["Archived", "None"]:419            self.store.transition_model_version_stage(name, mv1.version, stage, False)420        self.store.transition_model_version_stage(name, mv1.version, "Staging", False)421        self.store.transition_model_version_stage(name, mv2.version, "Production", False)422        self.store.transition_model_version_stage(name, mv3.version, "Staging", False)423        mvd1 = self.store.get_model_version(name=name, version=mv1.version)424        mvd2 = self.store.get_model_version(name=name, version=mv2.version)425        mvd3 = self.store.get_model_version(name=name, version=mv3.version)426        self.assertEqual(mvd1.current_stage, "Staging")427        self.assertEqual(mvd2.current_stage, "Production")428        self.assertEqual(mvd3.current_stage, "Staging")429        self.store.transition_model_version_stage(name, mv3.version, "Production", False)430        mvd1 = self.store.get_model_version(name=name, version=mv1.version)431        mvd2 = self.store.get_model_version(name=name, version=mv2.version)432        mvd3 = self.store.get_model_version(name=name, version=mv3.version)433        self.assertEqual(mvd1.current_stage, "Staging")434        self.assertEqual(mvd2.current_stage, "Production")435        self.assertEqual(mvd3.current_stage, "Production")436    def test_transition_model_version_stage_when_archive_existing_versions_is_true(self):437        name = "model"438        self._rm_maker(name)439        mv1 = self._mv_maker(name)440        mv2 = self._mv_maker(name)441        mv3 = 
self._mv_maker(name)442        msg = (r"Model version transition cannot archive existing model versions "443               r"because .+ is not an Active stage. Valid stages are .+")444        # test that when `archive_existing_versions` is True, transitioning a model version445        # to the inactive stages ("Archived" and "None") throws.446        for stage in ["Archived", "None"]:447            with self.assertRaisesRegex(MlflowException, msg):448                self.store.transition_model_version_stage(name, mv1.version, stage, True)449        self.store.transition_model_version_stage(name, mv1.version, "Staging", False)450        self.store.transition_model_version_stage(name, mv2.version, "Production", False)451        self.store.transition_model_version_stage(name, mv3.version, "Staging", True)452        mvd1 = self.store.get_model_version(name=name, version=mv1.version)453        mvd2 = self.store.get_model_version(name=name, version=mv2.version)454        mvd3 = self.store.get_model_version(name=name, version=mv3.version)455        self.assertEqual(mvd1.current_stage, "Archived")456        self.assertEqual(mvd2.current_stage, "Production")457        self.assertEqual(mvd3.current_stage, "Staging")458        self.assertEqual(mvd1.last_updated_timestamp, mvd3.last_updated_timestamp)459        self.store.transition_model_version_stage(name, mv3.version, "Production", True)460        mvd1 = self.store.get_model_version(name=name, version=mv1.version)461        mvd2 = self.store.get_model_version(name=name, version=mv2.version)462        mvd3 = self.store.get_model_version(name=name, version=mv3.version)463        self.assertEqual(mvd1.current_stage, "Archived")464        self.assertEqual(mvd2.current_stage, "Archived")465        self.assertEqual(mvd3.current_stage, "Production")466        self.assertEqual(mvd2.last_updated_timestamp, mvd3.last_updated_timestamp)467    def test_delete_model_version(self):468        name = "test_for_update_MV"469        
initial_tags = [ModelVersionTag("key", "value"),470                        ModelVersionTag("anotherKey", "some other value")]471        self._rm_maker(name)472        mv = self._mv_maker(name, tags=initial_tags)473        mvd = self.store.get_model_version(name=mv.name, version=mv.version)474        self.assertEqual(mvd.name, name)475        self.store.delete_model_version(name=mv.name, version=mv.version)476        # cannot get a deleted model version477        with self.assertRaises(MlflowException) as exception_context:478            self.store.get_model_version(name=mv.name, version=mv.version)479        assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)480        # cannot update a delete481        with self.assertRaises(MlflowException) as exception_context:482            self.store.update_model_version(mv.name, mv.version, description="deleted!")483        assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)484        # cannot delete it again485        with self.assertRaises(MlflowException) as exception_context:486            self.store.delete_model_version(name=mv.name, version=mv.version)487        assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)488    def test_get_model_version_download_uri(self):489        name = "test_for_update_MV"490        self._rm_maker(name)491        source_path = "path/to/source"492        mv = self._mv_maker(name, source=source_path, run_id=uuid.uuid4().hex)493        mvd1 = self.store.get_model_version(name=mv.name, version=mv.version)494        self.assertEqual(mvd1.name, name)495        self.assertEqual(mvd1.source, source_path)496        # download location points to source497        self.assertEqual(self.store.get_model_version_download_uri(name=mv.name,498                                                                   version=mv.version), source_path)499        # download URI does not change even if model 
version is updated500        self.store.transition_model_version_stage(501            name=mv.name, version=mv.version,502            stage="Production",503            archive_existing_versions=False)504        self.store.update_model_version(name=mv.name, version=mv.version,505                                        description="Test for Path")506        mvd2 = self.store.get_model_version(name=mv.name, version=mv.version)507        self.assertEqual(mvd2.source, source_path)508        self.assertEqual(self.store.get_model_version_download_uri(509            name=mv.name, version=mv.version), source_path)510        # cannot retrieve download URI for deleted model versions511        self.store.delete_model_version(name=mv.name, version=mv.version)512        with self.assertRaises(MlflowException) as exception_context:513            self.store.get_model_version_download_uri(name=mv.name, version=mv.version)514        assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)515    def test_search_model_versions(self):516        # create some model versions517        name = "test_for_search_MV"518        self._rm_maker(name)519        run_id_1 = uuid.uuid4().hex520        run_id_2 = uuid.uuid4().hex521        run_id_3 = uuid.uuid4().hex522        mv1 = self._mv_maker(name=name, source="A/B", run_id=run_id_1)523        self.assertEqual(mv1.version, 1)524        mv2 = self._mv_maker(name=name, source="A/C", run_id=run_id_2)525        self.assertEqual(mv2.version, 2)526        mv3 = self._mv_maker(name=name, source="A/D", run_id=run_id_2)527        self.assertEqual(mv3.version, 3)528        mv4 = self._mv_maker(name=name, source="A/D", run_id=run_id_3)529        self.assertEqual(mv4.version, 4)530        def search_versions(filter_string):531            return [mvd.version for mvd in self.store.search_model_versions(filter_string)]532        # search using name should return all 4 versions533        
self.assertEqual(set(search_versions("name='%s'" % name)), set([1, 2, 3, 4]))534        # search using run_id_1 should return version 1535        self.assertEqual(set(search_versions("run_id='%s'" % run_id_1)), set([1]))536        # search using run_id_2 should return versions 2 and 3537        self.assertEqual(set(search_versions("run_id='%s'" % run_id_2)), set([2, 3]))538        # search using source_path "A/D" should return version 3 and 4539        self.assertEqual(set(search_versions("source_path = 'A/D'")), set([3, 4]))540        # search using source_path "A" should not return anything541        self.assertEqual(len(search_versions("source_path = 'A'")), 0)542        self.assertEqual(len(search_versions("source_path = 'A/'")), 0)543        self.assertEqual(len(search_versions("source_path = ''")), 0)544        # delete mv4. search should not return version 4545        self.store.delete_model_version(name=mv4.name, version=mv4.version)546        self.assertEqual(set(search_versions("")), set([1, 2, 3]))547        self.assertEqual(set(search_versions(None)), set([1, 2, 3]))548        self.assertEqual(set(search_versions("name='%s'" % name)), set([1, 2, 3]))549        self.assertEqual(set(search_versions("source_path = 'A/D'")), set([3]))550        self.store.transition_model_version_stage(551            name=mv1.name, version=mv1.version, stage="production",552            archive_existing_versions=False553        )554        self.store.update_model_version(555            name=mv1.name, version=mv1.version, description="Online prediction model!")556        mvds = self.store.search_model_versions("run_id = '%s'" % run_id_1)557        assert 1 == len(mvds)558        assert isinstance(mvds[0], ModelVersion)559        assert mvds[0].current_stage == "Production"560        assert mvds[0].run_id == run_id_1561        assert mvds[0].source == "A/B"562        assert mvds[0].description == "Online prediction model!"563    def _search_registered_models(self,564           
                       filter_string,565                                  max_results=10,566                                  order_by=None,567                                  page_token=None):568        result = self.store.search_registered_models(filter_string=filter_string,569                                                     max_results=max_results,570                                                     order_by=order_by,571                                                     page_token=page_token)572        return [registered_model.name for registered_model in result], result.token573    def test_search_registered_models(self):574        # create some registered models575        prefix = "test_for_search_"576        names = [prefix + name for name in ["RM1", "RM2", "RM3", "RM4", "RM4A", "RM4a"]]577        [self._rm_maker(name) for name in names]578        # search with no filter should return all registered models579        rms, _ = self._search_registered_models(None)580        self.assertEqual(rms, names)581        # equality search using name should return exactly the 1 name582        rms, _ = self._search_registered_models(f"name='{names[0]}'")583        self.assertEqual(rms, [names[0]])584        # equality search using name that is not valid should return nothing585        rms, _ = self._search_registered_models(f"name='{names[0] + 'cats'}'")586        self.assertEqual(rms, [])587        # case-sensitive prefix search using LIKE should return all the RMs588        rms, _ = self._search_registered_models(f"name LIKE '{prefix}%'")589        self.assertEqual(rms, names)590        # case-sensitive prefix search using LIKE with surrounding % should return all the RMs591        rms, _ = self._search_registered_models(f"name LIKE '%RM%'")592        self.assertEqual(rms, names)593        # case-sensitive prefix search using LIKE with surrounding % should return all the RMs594        # _e% matches test_for_search_ , so all RMs should match595        rms, _ = 
self._search_registered_models(f"name LIKE '_e%'")596        self.assertEqual(rms, names)597        # case-sensitive prefix search using LIKE should return just rm4598        rms, _ = self._search_registered_models(f"name LIKE '{prefix + 'RM4A'}%'")599        self.assertEqual(rms, [names[4]])600        # case-sensitive prefix search using LIKE should return no models if no match601        rms, _ = self._search_registered_models(f"name LIKE '{prefix + 'cats'}%'")602        self.assertEqual(rms, [])603        # confirm that LIKE is not case-sensitive604        rms, _ = self._search_registered_models(f"name lIkE '%blah%'")605        self.assertEqual(rms, [])606        rms, _ = self._search_registered_models(f"name like '{prefix + 'RM4A'}%'")607        self.assertEqual(rms, [names[4]])608        # case-insensitive prefix search using ILIKE should return both rm5 and rm6609        rms, _ = self._search_registered_models(f"name ILIKE '{prefix + 'RM4A'}%'")610        self.assertEqual(rms, names[4:])611        # case-insensitive postfix search with ILIKE612        rms, _ = self._search_registered_models(f"name ILIKE '%RM4a'")613        self.assertEqual(rms, names[4:])614        # case-insensitive prefix search using ILIKE should return both rm5 and rm6615        rms, _ = self._search_registered_models(f"name ILIKE '{prefix + 'cats'}%'")616        self.assertEqual(rms, [])617        # confirm that ILIKE is not case-sensitive618        rms, _ = self._search_registered_models(f"name iLike '%blah%'")619        self.assertEqual(rms, [])620        # confirm that ILIKE works for empty query621        rms, _ = self._search_registered_models(f"name iLike '%%'")622        self.assertEqual(rms, names)623        rms, _ = self._search_registered_models(f"name ilike '%RM4a'")624        self.assertEqual(rms, names[4:])625        # cannot search by invalid comparator types626        with self.assertRaises(MlflowException) as exception_context:627            
self._search_registered_models("name!=something")628        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)629        # cannot search by run_id630        with self.assertRaises(MlflowException) as exception_context:631            self._search_registered_models("run_id='%s'" % "somerunID")632        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)633        # cannot search by source_path634        with self.assertRaises(MlflowException) as exception_context:635            self._search_registered_models("source_path = 'A/D'")636        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)637        # cannot search by other params638        with self.assertRaises(MlflowException) as exception_context:639            self._search_registered_models("evilhax = true")640        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)641        # delete last registered model. 
search should not return the first 5642        self.store.delete_registered_model(name=names[-1])643        self.assertEqual(self._search_registered_models(None, max_results=1000), (names[:-1], None))644        # equality search using name should return no names645        self.assertEqual(self._search_registered_models(f"name='{names[-1]}'"), ([], None))646        # case-sensitive prefix search using LIKE should return all the RMs647        self.assertEqual(self._search_registered_models(f"name LIKE '{prefix}%'"),648                         (names[0:5], None))649        # case-insensitive prefix search using ILIKE should return both rm5 and rm6650        self.assertEqual(self._search_registered_models(f"name ILIKE '{prefix + 'RM4A'}%'"),651                         ([names[4]], None))652    def test_search_registered_model_pagination(self):653        rms = [self._rm_maker(f"RM{i:03}").name for i in range(50)]654        # test flow with fixed max_results655        returned_rms = []656        query = "name LIKE 'RM%'"657        result, token = self._search_registered_models(query, page_token=None, max_results=5)658        returned_rms.extend(result)659        while token:660            result, token = self._search_registered_models(query, page_token=token, max_results=5)661            returned_rms.extend(result)662        self.assertEqual(rms, returned_rms)663        # test that pagination will return all valid results in sorted order664        # by name ascending665        result, token1 = self._search_registered_models(query, max_results=5)666        self.assertNotEqual(token1, None)667        self.assertEqual(result, rms[0:5])668        result, token2 = self._search_registered_models(query, page_token=token1, max_results=10)669        self.assertNotEqual(token2, None)670        self.assertEqual(result, rms[5:15])671        result, token3 = self._search_registered_models(query, page_token=token2, max_results=20)672        self.assertNotEqual(token3, None)673        
self.assertEqual(result, rms[15:35])674        result, token4 = self._search_registered_models(query, page_token=token3, max_results=100)675        # assert that page token is None676        self.assertEqual(token4, None)677        self.assertEqual(result, rms[35:])678        # test that providing a completely invalid page token throws679        with self.assertRaises(MlflowException) as exception_context:680            self._search_registered_models(query, page_token="evilhax", max_results=20)681        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)682        # test that providing too large of a max_results throws683        with self.assertRaises(MlflowException) as exception_context:684            self._search_registered_models(query, page_token="evilhax", max_results=1e15)685        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)686        self.assertIn("Invalid value for request parameter max_results",687                      exception_context.exception.message)688    def test_search_registered_model_order_by(self):689        rms = []690        # explicitly mock the creation_timestamps because timestamps seem to be unstable in Windows691        for i in range(50):692            with mock.patch("mlflow.store.model_registry.sqlalchemy_store.now", return_value=i):693                rms.append(self._rm_maker(f"RM{i:03}").name)694        # test flow with fixed max_results and order_by (test stable order across pages)695        returned_rms = []696        query = "name LIKE 'RM%'"697        result, token = self._search_registered_models(query,698                                                       page_token=None,699                                                       order_by=['name DESC'],700                                                       max_results=5)701        returned_rms.extend(result)702        while token:703            result, token = 
self._search_registered_models(query,704                                                           page_token=token,705                                                           order_by=['name DESC'],706                                                           max_results=5)707            returned_rms.extend(result)708        # name descending should be the opposite order of the current order709        self.assertEqual(rms[::-1], returned_rms)710        # last_updated_timestamp descending should have the newest RMs first711        result, _ = self._search_registered_models(query,712                                                   page_token=None,713                                                   order_by=['last_updated_timestamp DESC'],714                                                   max_results=100)715        self.assertEqual(rms[::-1], result)716        # timestamp returns same result as last_updated_timestamp717        result, _ = self._search_registered_models(query,718                                                   page_token=None,719                                                   order_by=['timestamp DESC'],720                                                   max_results=100)721        self.assertEqual(rms[::-1], result)722        # last_updated_timestamp ascending should have the oldest RMs first723        result, _ = self._search_registered_models(query,724                                                   page_token=None,725                                                   order_by=['last_updated_timestamp ASC'],726                                                   max_results=100)727        self.assertEqual(rms, result)728        # timestamp returns same result as last_updated_timestamp729        result, _ = self._search_registered_models(query,730                                                   page_token=None,731                                                   order_by=['timestamp ASC'],732                                         
          max_results=100)733        self.assertEqual(rms, result)734        # timestamp returns same result as last_updated_timestamp735        result, _ = self._search_registered_models(query,736                                                   page_token=None,737                                                   order_by=['timestamp'],738                                                   max_results=100)739        self.assertEqual(rms, result)740        # name ascending should have the original order741        result, _ = self._search_registered_models(query,742                                                   page_token=None,743                                                   order_by=['name ASC'],744                                                   max_results=100)745        self.assertEqual(rms, result)746        # test that no ASC/DESC defaults to ASC747        result, _ = self._search_registered_models(query,748                                                   page_token=None,749                                                   order_by=['last_updated_timestamp'],750                                                   max_results=100)751        self.assertEqual(rms, result)752        with mock.patch("mlflow.store.model_registry.sqlalchemy_store.now", return_value=1):753            rm1 = self._rm_maker("MR1").name754            rm2 = self._rm_maker("MR2").name755        with mock.patch("mlflow.store.model_registry.sqlalchemy_store.now", return_value=2):756            rm3 = self._rm_maker("MR3").name757            rm4 = self._rm_maker("MR4").name758        query = "name LIKE 'MR%'"759        # test with multiple clauses760        result, _ = self._search_registered_models(query,761                                                   page_token=None,762                                                   order_by=['last_updated_timestamp ASC',763                                                             'name DESC'],764                                       
            max_results=100)765        self.assertEqual([rm2, rm1, rm4, rm3], result)766        result, _ = self._search_registered_models(query,767                                                   page_token=None,768                                                   order_by=['timestamp ASC',769                                                             'name   DESC'],770                                                   max_results=100)771        self.assertEqual([rm2, rm1, rm4, rm3], result)772        # confirm that name ascending is the default, even if ties exist on other fields773        result, _ = self._search_registered_models(query,774                                                   page_token=None,775                                                   order_by=[],776                                                   max_results=100)777        self.assertEqual([rm1, rm2, rm3, rm4], result)778        # test default tiebreak with descending timestamps779        result, _ = self._search_registered_models(query,780                                                   page_token=None,781                                                   order_by=['last_updated_timestamp DESC'],782                                                   max_results=100)783        self.assertEqual([rm3, rm4, rm1, rm2], result)784        # test timestamp parsing785        result, _ = self._search_registered_models(query,786                                                   page_token=None,787                                                   order_by=['timestamp\tASC'],788                                                   max_results=100)789        self.assertEqual([rm1, rm2, rm3, rm4], result)790        result, _ = self._search_registered_models(query,791                                                   page_token=None,792                                                   order_by=['timestamp\r\rASC'],793                                                   max_results=100)794        
self.assertEqual([rm1, rm2, rm3, rm4], result)795        result, _ = self._search_registered_models(query,796                                                   page_token=None,797                                                   order_by=['timestamp\nASC'],798                                                   max_results=100)799        self.assertEqual([rm1, rm2, rm3, rm4], result)800        result, _ = self._search_registered_models(query,801                                                   page_token=None,802                                                   order_by=['timestamp  ASC'],803                                                   max_results=100)804        self.assertEqual([rm1, rm2, rm3, rm4], result)805        # validate order by key is case-insensitive806        result, _ = self._search_registered_models(query,807                                                   page_token=None,808                                                   order_by=['timestamp  asc'],809                                                   max_results=100)810        self.assertEqual([rm1, rm2, rm3, rm4], result)811        result, _ = self._search_registered_models(query,812                                                   page_token=None,813                                                   order_by=['timestamp  aSC'],814                                                   max_results=100)815        self.assertEqual([rm1, rm2, rm3, rm4], result)816        result, _ = self._search_registered_models(query,817                                                   page_token=None,818                                                   order_by=['timestamp  desc',819                                                             'name desc'],820                                                   max_results=100)821        self.assertEqual([rm4, rm3, rm2, rm1], result)822        result, _ = self._search_registered_models(query,823                                                   
page_token=None,824                                                   order_by=['timestamp  deSc',825                                                             'name deSc'],826                                                   max_results=100)827        self.assertEqual([rm4, rm3, rm2, rm1], result)828    def test_search_registered_model_order_by_errors(self):829        query = "name LIKE 'RM%'"830        # test that invalid columns throw even if they come after valid columns831        with self.assertRaises(MlflowException) as exception_context:832            self._search_registered_models(query,833                                           page_token=None,834                                           order_by=['name ASC', 'creation_timestamp DESC'],835                                           max_results=5)836        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)837        # test that invalid columns with random text throw even if they come after valid columns838        with self.assertRaises(MlflowException) as exception_context:839            self._search_registered_models(query,840                                           page_token=None,841                                           order_by=['name ASC',842                                                     'last_updated_timestamp DESC blah'],843                                           max_results=5)844        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)845    def test_set_model_version_tag(self):846        name1 = "SetModelVersionTag_TestMod"847        name2 = "SetModelVersionTag_TestMod 2"848        initial_tags = [ModelVersionTag("key", "value"),849                        ModelVersionTag("anotherKey", "some other value")]850        self._rm_maker(name1)851        self._rm_maker(name2)852        run_id_1 = uuid.uuid4().hex853        run_id_2 = uuid.uuid4().hex854        run_id_3 = uuid.uuid4().hex855        
self._mv_maker(name1, "A/B", run_id_1, initial_tags)856        self._mv_maker(name1, "A/C", run_id_2, initial_tags)857        self._mv_maker(name2, "A/D", run_id_3, initial_tags)858        new_tag = ModelVersionTag("randomTag", "not a random value")859        self.store.set_model_version_tag(name1, 1, new_tag)860        all_tags = initial_tags + [new_tag]861        rm1mv1 = self.store.get_model_version(name1, 1)862        self.assertEqual(rm1mv1.tags, {tag.key: tag.value for tag in all_tags})863        # test overriding a tag with the same key864        overriding_tag = ModelVersionTag("key", "overriding")865        self.store.set_model_version_tag(name1, 1, overriding_tag)866        all_tags = [tag for tag in all_tags if tag.key != "key"] + [overriding_tag]867        rm1mv1 = self.store.get_model_version(name1, 1)868        self.assertEqual(rm1mv1.tags, {tag.key: tag.value for tag in all_tags})869        # does not affect other model versions with the same key870        rm1mv2 = self.store.get_model_version(name1, 2)871        rm2mv1 = self.store.get_model_version(name2, 1)872        self.assertEqual(rm1mv2.tags, {tag.key: tag.value for tag in initial_tags})873        self.assertEqual(rm2mv1.tags, {tag.key: tag.value for tag in initial_tags})874        # can not set tag on deleted (non-existed) model version875        self.store.delete_model_version(name1, 2)876        with self.assertRaises(MlflowException) as exception_context:877            self.store.set_model_version_tag(name1, 2, overriding_tag)878        assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)879        # test cannot set tags that are too long880        long_tag = ModelVersionTag("longTagKey", "a" * 5001)881        with self.assertRaises(MlflowException) as exception_context:882            self.store.set_model_version_tag(name1, 1, long_tag)883        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)884        # test can 
set tags that are somewhat long885        long_tag = ModelVersionTag("longTagKey", "a" * 4999)886        self.store.set_model_version_tag(name1, 1, long_tag)887        # can not set invalid tag888        with self.assertRaises(MlflowException) as exception_context:889            self.store.set_model_version_tag(name2, 1, ModelVersionTag(key=None, value=""))890        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)891        # can not use invalid model name or version892        with self.assertRaises(MlflowException) as exception_context:893            self.store.set_model_version_tag(None, 1, ModelVersionTag(key="key", value="value"))894        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)895        with self.assertRaises(MlflowException) as exception_context:896            self.store.set_model_version_tag(name2, "I am not a version",897                                             ModelVersionTag(key="key", value="value"))898        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)899    def test_delete_model_version_tag(self):900        name1 = "DeleteModelVersionTag_TestMod"901        name2 = "DeleteModelVersionTag_TestMod 2"902        initial_tags = [ModelVersionTag("key", "value"),903                        ModelVersionTag("anotherKey", "some other value")]904        self._rm_maker(name1)905        self._rm_maker(name2)906        run_id_1 = uuid.uuid4().hex907        run_id_2 = uuid.uuid4().hex908        run_id_3 = uuid.uuid4().hex909        self._mv_maker(name1, "A/B", run_id_1, initial_tags)910        self._mv_maker(name1, "A/C", run_id_2, initial_tags)911        self._mv_maker(name2, "A/D", run_id_3, initial_tags)912        new_tag = ModelVersionTag("randomTag", "not a random value")913        self.store.set_model_version_tag(name1, 1, new_tag)914        self.store.delete_model_version_tag(name1, 1, "randomTag")915        rm1mv1 = 
self.store.get_model_version(name1, 1)916        self.assertEqual(rm1mv1.tags, {tag.key: tag.value for tag in initial_tags})917        # testing deleting a key does not affect other model versions with the same key918        self.store.delete_model_version_tag(name1, 1, "key")919        rm1mv1 = self.store.get_model_version(name1, 1)920        rm1mv2 = self.store.get_model_version(name1, 2)921        rm2mv1 = self.store.get_model_version(name2, 1)922        self.assertEqual(rm1mv1.tags, {"anotherKey": "some other value"})923        self.assertEqual(rm1mv2.tags, {tag.key: tag.value for tag in initial_tags})924        self.assertEqual(rm2mv1.tags, {tag.key: tag.value for tag in initial_tags})925        # delete tag that is already deleted does nothing926        self.store.delete_model_version_tag(name1, 1, "key")927        rm1mv1 = self.store.get_model_version(name1, 1)928        self.assertEqual(rm1mv1.tags, {"anotherKey": "some other value"})929        # can not delete tag on deleted (non-existed) model version930        self.store.delete_model_version(name2, 1)931        with self.assertRaises(MlflowException) as exception_context:932            self.store.delete_model_version_tag(name2, 1, "key")933        assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)934        # can not delete tag with invalid key935        with self.assertRaises(MlflowException) as exception_context:936            self.store.delete_model_version_tag(name1, 2, None)937        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)938        # can not use invalid model name or version939        with self.assertRaises(MlflowException) as exception_context:940            self.store.delete_model_version_tag(None, 2, "key")941        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)942        with self.assertRaises(MlflowException) as exception_context:943            
self.store.delete_model_version_tag(name1, "I am not a version", "key")..._metadata_code_details_test.py
Source:_metadata_code_details_test.py  
# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests application-provided metadata, status code, and details."""

import threading
import unittest

import grpc

from tests.unit import test_common
from tests.unit.framework.common import test_constants
from tests.unit.framework.common import test_control

# Fixed, opaque wire payloads. The (de)serializers below ignore their
# argument, so message contents never matter to these tests.
_SERIALIZED_REQUEST = b'\x46\x47\x48'
_SERIALIZED_RESPONSE = b'\x49\x50\x51'


def _serialize_request(unused_request):
    """Return the fixed request payload, ignoring the request object."""
    return _SERIALIZED_REQUEST


def _deserialize_request(unused_serialized_request):
    """Return a fresh placeholder object, ignoring the received bytes."""
    return object()


def _serialize_response(unused_response):
    """Return the fixed response payload, ignoring the response object."""
    return _SERIALIZED_RESPONSE


def _deserialize_response(unused_serialized_response):
    """Return a fresh placeholder object, ignoring the received bytes."""
    return object()


# Public-in-module names kept for the handler/stub wiring; bound to real
# functions instead of assigned lambdas (PEP 8, E731).
_REQUEST_SERIALIZER = _serialize_request
_REQUEST_DESERIALIZER = _deserialize_request
_RESPONSE_SERIALIZER = _serialize_response
_RESPONSE_DESERIALIZER = _deserialize_response

# Fully-qualified service name and its four method names, one per RPC
# cardinality.
_SERVICE = 'test.TestService'
_UNARY_UNARY = 'UnaryUnary'
_UNARY_STREAM = 'UnaryStream'
_STREAM_UNARY = 'StreamUnary'
_STREAM_STREAM = 'StreamStream'

# Metadata sent in each direction. Every set includes a key ending in
# '-bin' whose value is bytes (gRPC binary metadata).
_CLIENT_METADATA = (('client-md-key', 'client-md-key'),
                    ('client-md-key-bin', b'\x00\x01'))

_SERVER_INITIAL_METADATA = (('server-initial-md-key',
                             'server-initial-md-value'),
                            ('server-initial-md-key-bin', b'\x00\x02'))

_SERVER_TRAILING_METADATA = (('server-trailing-md-key',
                              'server-trailing-md-value'),
                             ('server-trailing-md-key-bin', b'\x00\x03'))

# Status code and details string the servicer can be configured to set.
_NON_OK_CODE = grpc.StatusCode.NOT_FOUND
_DETAILS = 'Test details!'
# calling abort should always fail an RPC,
even for "invalid" codes43_ABORT_CODES = (_NON_OK_CODE, 3, grpc.StatusCode.OK)44_EXPECTED_CLIENT_CODES = (_NON_OK_CODE, grpc.StatusCode.UNKNOWN,45                          grpc.StatusCode.UNKNOWN)46_EXPECTED_DETAILS = (_DETAILS, _DETAILS, '')47class _Servicer(object):48    def __init__(self):49        self._lock = threading.Lock()50        self._abort_call = False51        self._code = None52        self._details = None53        self._exception = False54        self._return_none = False55        self._received_client_metadata = None56    def unary_unary(self, request, context):57        with self._lock:58            self._received_client_metadata = context.invocation_metadata()59            context.send_initial_metadata(_SERVER_INITIAL_METADATA)60            context.set_trailing_metadata(_SERVER_TRAILING_METADATA)61            if self._abort_call:62                context.abort(self._code, self._details)63            else:64                if self._code is not None:65                    context.set_code(self._code)66                if self._details is not None:67                    context.set_details(self._details)68            if self._exception:69                raise test_control.Defect()70            else:71                return None if self._return_none else object()72    def unary_stream(self, request, context):73        with self._lock:74            self._received_client_metadata = context.invocation_metadata()75            context.send_initial_metadata(_SERVER_INITIAL_METADATA)76            context.set_trailing_metadata(_SERVER_TRAILING_METADATA)77            if self._abort_call:78                context.abort(self._code, self._details)79            else:80                if self._code is not None:81                    context.set_code(self._code)82                if self._details is not None:83                    context.set_details(self._details)84            for _ in range(test_constants.STREAM_LENGTH // 2):85                yield 
_SERIALIZED_RESPONSE86            if self._exception:87                raise test_control.Defect()88    def stream_unary(self, request_iterator, context):89        with self._lock:90            self._received_client_metadata = context.invocation_metadata()91            context.send_initial_metadata(_SERVER_INITIAL_METADATA)92            context.set_trailing_metadata(_SERVER_TRAILING_METADATA)93            # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the94            # request iterator.95            list(request_iterator)96            if self._abort_call:97                context.abort(self._code, self._details)98            else:99                if self._code is not None:100                    context.set_code(self._code)101                if self._details is not None:102                    context.set_details(self._details)103            if self._exception:104                raise test_control.Defect()105            else:106                return None if self._return_none else _SERIALIZED_RESPONSE107    def stream_stream(self, request_iterator, context):108        with self._lock:109            self._received_client_metadata = context.invocation_metadata()110            context.send_initial_metadata(_SERVER_INITIAL_METADATA)111            context.set_trailing_metadata(_SERVER_TRAILING_METADATA)112            # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the113            # request iterator.114            list(request_iterator)115            if self._abort_call:116                context.abort(self._code, self._details)117            else:118                if self._code is not None:119                    context.set_code(self._code)120                if self._details is not None:121                    context.set_details(self._details)122            for _ in range(test_constants.STREAM_LENGTH // 3):123                yield object()124            if self._exception:125                raise test_control.Defect()126    def 
set_abort_call(self):127        with self._lock:128            self._abort_call = True129    def set_code(self, code):130        with self._lock:131            self._code = code132    def set_details(self, details):133        with self._lock:134            self._details = details135    def set_exception(self):136        with self._lock:137            self._exception = True138    def set_return_none(self):139        with self._lock:140            self._return_none = True141    def received_client_metadata(self):142        with self._lock:143            return self._received_client_metadata144def _generic_handler(servicer):145    method_handlers = {146        _UNARY_UNARY:147        grpc.unary_unary_rpc_method_handler(148            servicer.unary_unary,149            request_deserializer=_REQUEST_DESERIALIZER,150            response_serializer=_RESPONSE_SERIALIZER),151        _UNARY_STREAM:152        grpc.unary_stream_rpc_method_handler(servicer.unary_stream),153        _STREAM_UNARY:154        grpc.stream_unary_rpc_method_handler(servicer.stream_unary),155        _STREAM_STREAM:156        grpc.stream_stream_rpc_method_handler(157            servicer.stream_stream,158            request_deserializer=_REQUEST_DESERIALIZER,159            response_serializer=_RESPONSE_SERIALIZER),160    }161    return grpc.method_handlers_generic_handler(_SERVICE, method_handlers)162class MetadataCodeDetailsTest(unittest.TestCase):163    def setUp(self):164        self._servicer = _Servicer()165        self._server = test_common.test_server()166        self._server.add_generic_rpc_handlers(167            (_generic_handler(self._servicer),))168        port = self._server.add_insecure_port('[::]:0')169        self._server.start()170        channel = grpc.insecure_channel('localhost:{}'.format(port))171        self._unary_unary = channel.unary_unary(172            '/'.join((173                '',174                _SERVICE,175                _UNARY_UNARY,176            )),177            
request_serializer=_REQUEST_SERIALIZER,178            response_deserializer=_RESPONSE_DESERIALIZER,179        )180        self._unary_stream = channel.unary_stream('/'.join((181            '',182            _SERVICE,183            _UNARY_STREAM,184        )),)185        self._stream_unary = channel.stream_unary('/'.join((186            '',187            _SERVICE,188            _STREAM_UNARY,189        )),)190        self._stream_stream = channel.stream_stream(191            '/'.join((192                '',193                _SERVICE,194                _STREAM_STREAM,195            )),196            request_serializer=_REQUEST_SERIALIZER,197            response_deserializer=_RESPONSE_DESERIALIZER,198        )199    def testSuccessfulUnaryUnary(self):200        self._servicer.set_details(_DETAILS)201        unused_response, call = self._unary_unary.with_call(202            object(), metadata=_CLIENT_METADATA)203        self.assertTrue(204            test_common.metadata_transmitted(205                _CLIENT_METADATA, self._servicer.received_client_metadata()))206        self.assertTrue(207            test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,208                                             call.initial_metadata()))209        self.assertTrue(210            test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,211                                             call.trailing_metadata()))212        self.assertIs(grpc.StatusCode.OK, call.code())213        self.assertEqual(_DETAILS, call.details())214    def testSuccessfulUnaryStream(self):215        self._servicer.set_details(_DETAILS)216        response_iterator_call = self._unary_stream(217            _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)218        received_initial_metadata = response_iterator_call.initial_metadata()219        list(response_iterator_call)220        self.assertTrue(221            test_common.metadata_transmitted(222                _CLIENT_METADATA, 
self._servicer.received_client_metadata()))223        self.assertTrue(224            test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,225                                             received_initial_metadata))226        self.assertTrue(227            test_common.metadata_transmitted(228                _SERVER_TRAILING_METADATA,229                response_iterator_call.trailing_metadata()))230        self.assertIs(grpc.StatusCode.OK, response_iterator_call.code())231        self.assertEqual(_DETAILS, response_iterator_call.details())232    def testSuccessfulStreamUnary(self):233        self._servicer.set_details(_DETAILS)234        unused_response, call = self._stream_unary.with_call(235            iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),236            metadata=_CLIENT_METADATA)237        self.assertTrue(238            test_common.metadata_transmitted(239                _CLIENT_METADATA, self._servicer.received_client_metadata()))240        self.assertTrue(241            test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,242                                             call.initial_metadata()))243        self.assertTrue(244            test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,245                                             call.trailing_metadata()))246        self.assertIs(grpc.StatusCode.OK, call.code())247        self.assertEqual(_DETAILS, call.details())248    def testSuccessfulStreamStream(self):249        self._servicer.set_details(_DETAILS)250        response_iterator_call = self._stream_stream(251            iter([object()] * test_constants.STREAM_LENGTH),252            metadata=_CLIENT_METADATA)253        received_initial_metadata = response_iterator_call.initial_metadata()254        list(response_iterator_call)255        self.assertTrue(256            test_common.metadata_transmitted(257                _CLIENT_METADATA, self._servicer.received_client_metadata()))258        self.assertTrue(259            
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,260                                             received_initial_metadata))261        self.assertTrue(262            test_common.metadata_transmitted(263                _SERVER_TRAILING_METADATA,264                response_iterator_call.trailing_metadata()))265        self.assertIs(grpc.StatusCode.OK, response_iterator_call.code())266        self.assertEqual(_DETAILS, response_iterator_call.details())267    def testAbortedUnaryUnary(self):268        test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES,269                         _EXPECTED_DETAILS)270        for abort_code, expected_code, expected_details in test_cases:271            self._servicer.set_code(abort_code)272            self._servicer.set_details(_DETAILS)273            self._servicer.set_abort_call()274            with self.assertRaises(grpc.RpcError) as exception_context:275                self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)276            self.assertTrue(277                test_common.metadata_transmitted(278                    _CLIENT_METADATA,279                    self._servicer.received_client_metadata()))280            self.assertTrue(281                test_common.metadata_transmitted(282                    _SERVER_INITIAL_METADATA,283                    exception_context.exception.initial_metadata()))284            self.assertTrue(285                test_common.metadata_transmitted(286                    _SERVER_TRAILING_METADATA,287                    exception_context.exception.trailing_metadata()))288            self.assertIs(expected_code, exception_context.exception.code())289            self.assertEqual(expected_details,290                             exception_context.exception.details())291    def testAbortedUnaryStream(self):292        test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES,293                         _EXPECTED_DETAILS)294        for abort_code, expected_code, expected_details in 
test_cases:295            self._servicer.set_code(abort_code)296            self._servicer.set_details(_DETAILS)297            self._servicer.set_abort_call()298            response_iterator_call = self._unary_stream(299                _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)300            received_initial_metadata = \301                response_iterator_call.initial_metadata()302            with self.assertRaises(grpc.RpcError):303                self.assertEqual(len(list(response_iterator_call)), 0)304            self.assertTrue(305                test_common.metadata_transmitted(306                    _CLIENT_METADATA,307                    self._servicer.received_client_metadata()))308            self.assertTrue(309                test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,310                                                 received_initial_metadata))311            self.assertTrue(312                test_common.metadata_transmitted(313                    _SERVER_TRAILING_METADATA,314                    response_iterator_call.trailing_metadata()))315            self.assertIs(expected_code, response_iterator_call.code())316            self.assertEqual(expected_details, response_iterator_call.details())317    def testAbortedStreamUnary(self):318        test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES,319                         _EXPECTED_DETAILS)320        for abort_code, expected_code, expected_details in test_cases:321            self._servicer.set_code(abort_code)322            self._servicer.set_details(_DETAILS)323            self._servicer.set_abort_call()324            with self.assertRaises(grpc.RpcError) as exception_context:325                self._stream_unary.with_call(326                    iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),327                    metadata=_CLIENT_METADATA)328            self.assertTrue(329                test_common.metadata_transmitted(330                    _CLIENT_METADATA,331     
               self._servicer.received_client_metadata()))332            self.assertTrue(333                test_common.metadata_transmitted(334                    _SERVER_INITIAL_METADATA,335                    exception_context.exception.initial_metadata()))336            self.assertTrue(337                test_common.metadata_transmitted(338                    _SERVER_TRAILING_METADATA,339                    exception_context.exception.trailing_metadata()))340            self.assertIs(expected_code, exception_context.exception.code())341            self.assertEqual(expected_details,342                             exception_context.exception.details())343    def testAbortedStreamStream(self):344        test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES,345                         _EXPECTED_DETAILS)346        for abort_code, expected_code, expected_details in test_cases:347            self._servicer.set_code(abort_code)348            self._servicer.set_details(_DETAILS)349            self._servicer.set_abort_call()350            response_iterator_call = self._stream_stream(351                iter([object()] * test_constants.STREAM_LENGTH),352                metadata=_CLIENT_METADATA)353            received_initial_metadata = \354                response_iterator_call.initial_metadata()355            with self.assertRaises(grpc.RpcError):356                self.assertEqual(len(list(response_iterator_call)), 0)357            self.assertTrue(358                test_common.metadata_transmitted(359                    _CLIENT_METADATA,360                    self._servicer.received_client_metadata()))361            self.assertTrue(362                test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,363                                                 received_initial_metadata))364            self.assertTrue(365                test_common.metadata_transmitted(366                    _SERVER_TRAILING_METADATA,367                    
response_iterator_call.trailing_metadata()))368            self.assertIs(expected_code, response_iterator_call.code())369            self.assertEqual(expected_details, response_iterator_call.details())370    def testCustomCodeUnaryUnary(self):371        self._servicer.set_code(_NON_OK_CODE)372        self._servicer.set_details(_DETAILS)373        with self.assertRaises(grpc.RpcError) as exception_context:374            self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)375        self.assertTrue(376            test_common.metadata_transmitted(377                _CLIENT_METADATA, self._servicer.received_client_metadata()))378        self.assertTrue(379            test_common.metadata_transmitted(380                _SERVER_INITIAL_METADATA,381                exception_context.exception.initial_metadata()))382        self.assertTrue(383            test_common.metadata_transmitted(384                _SERVER_TRAILING_METADATA,385                exception_context.exception.trailing_metadata()))386        self.assertIs(_NON_OK_CODE, exception_context.exception.code())387        self.assertEqual(_DETAILS, exception_context.exception.details())388    def testCustomCodeUnaryStream(self):389        self._servicer.set_code(_NON_OK_CODE)390        self._servicer.set_details(_DETAILS)391        response_iterator_call = self._unary_stream(392            _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)393        received_initial_metadata = response_iterator_call.initial_metadata()394        with self.assertRaises(grpc.RpcError):395            list(response_iterator_call)396        self.assertTrue(397            test_common.metadata_transmitted(398                _CLIENT_METADATA, self._servicer.received_client_metadata()))399        self.assertTrue(400            test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,401                                             received_initial_metadata))402        self.assertTrue(403            
test_common.metadata_transmitted(404                _SERVER_TRAILING_METADATA,405                response_iterator_call.trailing_metadata()))406        self.assertIs(_NON_OK_CODE, response_iterator_call.code())407        self.assertEqual(_DETAILS, response_iterator_call.details())408    def testCustomCodeStreamUnary(self):409        self._servicer.set_code(_NON_OK_CODE)410        self._servicer.set_details(_DETAILS)411        with self.assertRaises(grpc.RpcError) as exception_context:412            self._stream_unary.with_call(413                iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),414                metadata=_CLIENT_METADATA)415        self.assertTrue(416            test_common.metadata_transmitted(417                _CLIENT_METADATA, self._servicer.received_client_metadata()))418        self.assertTrue(419            test_common.metadata_transmitted(420                _SERVER_INITIAL_METADATA,421                exception_context.exception.initial_metadata()))422        self.assertTrue(423            test_common.metadata_transmitted(424                _SERVER_TRAILING_METADATA,425                exception_context.exception.trailing_metadata()))426        self.assertIs(_NON_OK_CODE, exception_context.exception.code())427        self.assertEqual(_DETAILS, exception_context.exception.details())428    def testCustomCodeStreamStream(self):429        self._servicer.set_code(_NON_OK_CODE)430        self._servicer.set_details(_DETAILS)431        response_iterator_call = self._stream_stream(432            iter([object()] * test_constants.STREAM_LENGTH),433            metadata=_CLIENT_METADATA)434        received_initial_metadata = response_iterator_call.initial_metadata()435        with self.assertRaises(grpc.RpcError) as exception_context:436            list(response_iterator_call)437        self.assertTrue(438            test_common.metadata_transmitted(439                _CLIENT_METADATA, self._servicer.received_client_metadata()))440        
self.assertTrue(441            test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,442                                             received_initial_metadata))443        self.assertTrue(444            test_common.metadata_transmitted(445                _SERVER_TRAILING_METADATA,446                exception_context.exception.trailing_metadata()))447        self.assertIs(_NON_OK_CODE, exception_context.exception.code())448        self.assertEqual(_DETAILS, exception_context.exception.details())449    def testCustomCodeExceptionUnaryUnary(self):450        self._servicer.set_code(_NON_OK_CODE)451        self._servicer.set_details(_DETAILS)452        self._servicer.set_exception()453        with self.assertRaises(grpc.RpcError) as exception_context:454            self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)455        self.assertTrue(456            test_common.metadata_transmitted(457                _CLIENT_METADATA, self._servicer.received_client_metadata()))458        self.assertTrue(459            test_common.metadata_transmitted(460                _SERVER_INITIAL_METADATA,461                exception_context.exception.initial_metadata()))462        self.assertTrue(463            test_common.metadata_transmitted(464                _SERVER_TRAILING_METADATA,465                exception_context.exception.trailing_metadata()))466        self.assertIs(_NON_OK_CODE, exception_context.exception.code())467        self.assertEqual(_DETAILS, exception_context.exception.details())468    def testCustomCodeExceptionUnaryStream(self):469        self._servicer.set_code(_NON_OK_CODE)470        self._servicer.set_details(_DETAILS)471        self._servicer.set_exception()472        response_iterator_call = self._unary_stream(473            _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)474        received_initial_metadata = response_iterator_call.initial_metadata()475        with self.assertRaises(grpc.RpcError):476            list(response_iterator_call)477   
     self.assertTrue(478            test_common.metadata_transmitted(479                _CLIENT_METADATA, self._servicer.received_client_metadata()))480        self.assertTrue(481            test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,482                                             received_initial_metadata))483        self.assertTrue(484            test_common.metadata_transmitted(485                _SERVER_TRAILING_METADATA,486                response_iterator_call.trailing_metadata()))487        self.assertIs(_NON_OK_CODE, response_iterator_call.code())488        self.assertEqual(_DETAILS, response_iterator_call.details())489    def testCustomCodeExceptionStreamUnary(self):490        self._servicer.set_code(_NON_OK_CODE)491        self._servicer.set_details(_DETAILS)492        self._servicer.set_exception()493        with self.assertRaises(grpc.RpcError) as exception_context:494            self._stream_unary.with_call(495                iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),496                metadata=_CLIENT_METADATA)497        self.assertTrue(498            test_common.metadata_transmitted(499                _CLIENT_METADATA, self._servicer.received_client_metadata()))500        self.assertTrue(501            test_common.metadata_transmitted(502                _SERVER_INITIAL_METADATA,503                exception_context.exception.initial_metadata()))504        self.assertTrue(505            test_common.metadata_transmitted(506                _SERVER_TRAILING_METADATA,507                exception_context.exception.trailing_metadata()))508        self.assertIs(_NON_OK_CODE, exception_context.exception.code())509        self.assertEqual(_DETAILS, exception_context.exception.details())510    def testCustomCodeExceptionStreamStream(self):511        self._servicer.set_code(_NON_OK_CODE)512        self._servicer.set_details(_DETAILS)513        self._servicer.set_exception()514        response_iterator_call = self._stream_stream(515 
           iter([object()] * test_constants.STREAM_LENGTH),516            metadata=_CLIENT_METADATA)517        received_initial_metadata = response_iterator_call.initial_metadata()518        with self.assertRaises(grpc.RpcError):519            list(response_iterator_call)520        self.assertTrue(521            test_common.metadata_transmitted(522                _CLIENT_METADATA, self._servicer.received_client_metadata()))523        self.assertTrue(524            test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,525                                             received_initial_metadata))526        self.assertTrue(527            test_common.metadata_transmitted(528                _SERVER_TRAILING_METADATA,529                response_iterator_call.trailing_metadata()))530        self.assertIs(_NON_OK_CODE, response_iterator_call.code())531        self.assertEqual(_DETAILS, response_iterator_call.details())532    def testCustomCodeReturnNoneUnaryUnary(self):533        self._servicer.set_code(_NON_OK_CODE)534        self._servicer.set_details(_DETAILS)535        self._servicer.set_return_none()536        with self.assertRaises(grpc.RpcError) as exception_context:537            self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)538        self.assertTrue(539            test_common.metadata_transmitted(540                _CLIENT_METADATA, self._servicer.received_client_metadata()))541        self.assertTrue(542            test_common.metadata_transmitted(543                _SERVER_INITIAL_METADATA,544                exception_context.exception.initial_metadata()))545        self.assertTrue(546            test_common.metadata_transmitted(547                _SERVER_TRAILING_METADATA,548                exception_context.exception.trailing_metadata()))549        self.assertIs(_NON_OK_CODE, exception_context.exception.code())550        self.assertEqual(_DETAILS, exception_context.exception.details())551    def testCustomCodeReturnNoneStreamUnary(self):552 
       self._servicer.set_code(_NON_OK_CODE)553        self._servicer.set_details(_DETAILS)554        self._servicer.set_return_none()555        with self.assertRaises(grpc.RpcError) as exception_context:556            self._stream_unary.with_call(557                iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),558                metadata=_CLIENT_METADATA)559        self.assertTrue(560            test_common.metadata_transmitted(561                _CLIENT_METADATA, self._servicer.received_client_metadata()))562        self.assertTrue(563            test_common.metadata_transmitted(564                _SERVER_INITIAL_METADATA,565                exception_context.exception.initial_metadata()))566        self.assertTrue(567            test_common.metadata_transmitted(568                _SERVER_TRAILING_METADATA,569                exception_context.exception.trailing_metadata()))570        self.assertIs(_NON_OK_CODE, exception_context.exception.code())571        self.assertEqual(_DETAILS, exception_context.exception.details())572if __name__ == '__main__':..._invalid_metadata_test.py
Source:_invalid_metadata_test.py  
1# Copyright 2016 gRPC authors.2#3# Licensed under the Apache License, Version 2.0 (the "License");4# you may not use this file except in compliance with the License.5# You may obtain a copy of the License at6#7#     http://www.apache.org/licenses/LICENSE-2.08#9# Unless required by applicable law or agreed to in writing, software10# distributed under the License is distributed on an "AS IS" BASIS,11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.12# See the License for the specific language governing permissions and13# limitations under the License.14"""Test of RPCs made against gRPC Python's application-layer API."""15import unittest16import grpc17from tests.unit.framework.common import test_constants18_SERIALIZE_REQUEST = lambda bytestring: bytestring * 219_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:]20_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 321_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3]22_UNARY_UNARY = '/test/UnaryUnary'23_UNARY_STREAM = '/test/UnaryStream'24_STREAM_UNARY = '/test/StreamUnary'25_STREAM_STREAM = '/test/StreamStream'26def _unary_unary_multi_callable(channel):27    return channel.unary_unary(_UNARY_UNARY)28def _unary_stream_multi_callable(channel):29    return channel.unary_stream(30        _UNARY_STREAM,31        request_serializer=_SERIALIZE_REQUEST,32        response_deserializer=_DESERIALIZE_RESPONSE)33def _stream_unary_multi_callable(channel):34    return channel.stream_unary(35        _STREAM_UNARY,36        request_serializer=_SERIALIZE_REQUEST,37        response_deserializer=_DESERIALIZE_RESPONSE)38def _stream_stream_multi_callable(channel):39    return channel.stream_stream(_STREAM_STREAM)40class InvalidMetadataTest(unittest.TestCase):41    def setUp(self):42        self._channel = grpc.insecure_channel('localhost:8080')43        self._unary_unary = _unary_unary_multi_callable(self._channel)44        self._unary_stream = 
_unary_stream_multi_callable(self._channel)45        self._stream_unary = _stream_unary_multi_callable(self._channel)46        self._stream_stream = _stream_stream_multi_callable(self._channel)47    def testUnaryRequestBlockingUnaryResponse(self):48        request = b'\x07\x08'49        metadata = (('InVaLiD', 'UnaryRequestBlockingUnaryResponse'),)50        expected_error_details = "metadata was invalid: %s" % metadata51        with self.assertRaises(ValueError) as exception_context:52            self._unary_unary(request, metadata=metadata)53        self.assertIn(expected_error_details, str(exception_context.exception))54    def testUnaryRequestBlockingUnaryResponseWithCall(self):55        request = b'\x07\x08'56        metadata = (('InVaLiD', 'UnaryRequestBlockingUnaryResponseWithCall'),)57        expected_error_details = "metadata was invalid: %s" % metadata58        with self.assertRaises(ValueError) as exception_context:59            self._unary_unary.with_call(request, metadata=metadata)60        self.assertIn(expected_error_details, str(exception_context.exception))61    def testUnaryRequestFutureUnaryResponse(self):62        request = b'\x07\x08'63        metadata = (('InVaLiD', 'UnaryRequestFutureUnaryResponse'),)64        expected_error_details = "metadata was invalid: %s" % metadata65        response_future = self._unary_unary.future(request, metadata=metadata)66        with self.assertRaises(grpc.RpcError) as exception_context:67            response_future.result()68        self.assertEqual(exception_context.exception.details(),69                         expected_error_details)70        self.assertEqual(exception_context.exception.code(),71                         grpc.StatusCode.INTERNAL)72        self.assertEqual(response_future.details(), expected_error_details)73        self.assertEqual(response_future.code(), grpc.StatusCode.INTERNAL)74    def testUnaryRequestStreamResponse(self):75        request = b'\x37\x58'76        metadata = (('InVaLiD', 
'UnaryRequestStreamResponse'),)77        expected_error_details = "metadata was invalid: %s" % metadata78        response_iterator = self._unary_stream(request, metadata=metadata)79        with self.assertRaises(grpc.RpcError) as exception_context:80            next(response_iterator)81        self.assertEqual(exception_context.exception.details(),82                         expected_error_details)83        self.assertEqual(exception_context.exception.code(),84                         grpc.StatusCode.INTERNAL)85        self.assertEqual(response_iterator.details(), expected_error_details)86        self.assertEqual(response_iterator.code(), grpc.StatusCode.INTERNAL)87    def testStreamRequestBlockingUnaryResponse(self):88        request_iterator = (89            b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))90        metadata = (('InVaLiD', 'StreamRequestBlockingUnaryResponse'),)91        expected_error_details = "metadata was invalid: %s" % metadata92        with self.assertRaises(ValueError) as exception_context:93            self._stream_unary(request_iterator, metadata=metadata)94        self.assertIn(expected_error_details, str(exception_context.exception))95    def testStreamRequestBlockingUnaryResponseWithCall(self):96        request_iterator = (97            b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))98        metadata = (('InVaLiD', 'StreamRequestBlockingUnaryResponseWithCall'),)99        expected_error_details = "metadata was invalid: %s" % metadata100        multi_callable = _stream_unary_multi_callable(self._channel)101        with self.assertRaises(ValueError) as exception_context:102            multi_callable.with_call(request_iterator, metadata=metadata)103        self.assertIn(expected_error_details, str(exception_context.exception))104    def testStreamRequestFutureUnaryResponse(self):105        request_iterator = (106            b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))107        metadata = (('InVaLiD', 
'StreamRequestFutureUnaryResponse'),)108        expected_error_details = "metadata was invalid: %s" % metadata109        response_future = self._stream_unary.future(110            request_iterator, metadata=metadata)111        with self.assertRaises(grpc.RpcError) as exception_context:112            response_future.result()113        self.assertEqual(exception_context.exception.details(),114                         expected_error_details)115        self.assertEqual(exception_context.exception.code(),116                         grpc.StatusCode.INTERNAL)117        self.assertEqual(response_future.details(), expected_error_details)118        self.assertEqual(response_future.code(), grpc.StatusCode.INTERNAL)119    def testStreamRequestStreamResponse(self):120        request_iterator = (121            b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))122        metadata = (('InVaLiD', 'StreamRequestStreamResponse'),)123        expected_error_details = "metadata was invalid: %s" % metadata124        response_iterator = self._stream_stream(125            request_iterator, metadata=metadata)126        with self.assertRaises(grpc.RpcError) as exception_context:127            next(response_iterator)128        self.assertEqual(exception_context.exception.details(),129                         expected_error_details)130        self.assertEqual(exception_context.exception.code(),131                         grpc.StatusCode.INTERNAL)132        self.assertEqual(response_iterator.details(), expected_error_details)133        self.assertEqual(response_iterator.code(), grpc.StatusCode.INTERNAL)134if __name__ == '__main__':...Learn to execute automation testing from scratch with LambdaTest Learning Hub. Right from setting up the prerequisites to run your first automation test, to following best practices and diving deeper into advanced test scenarios. 
LambdaTest Learning Hubs compile a list of step-by-step guides to help you become proficient with different test automation frameworks, such as Selenium, Cypress, and TestNG.
You can also refer to the video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.
Get 100 minutes of automation testing FREE!
