Best Python code snippets using lemoncheesecake
test_tag_matcher.py
Source: test_tag_matcher.py
...
# NOTE: The lines above this point are elided in the source listing; what
# follows is the tail of a commented-out pytest-parametrized variant of the
# unittest-based tests below.
#          ["foo", traits.category1_not_enabled_tag]),
#         ("case: Two normal tags", 0,
#          ["foo", "bar"]),
#     ])
#     def test_select_active_tags__with_two_tags(self, case, expected_len, tags):
#         tag_matcher = self.make_tag_matcher()
#         selected = tag_matcher.select_active_tags(tags)
#         selected = list(selected)
#         assert len(selected) == expected_len, case
#
#     @pytest.mark.parametrize("case, expected, tags", [
#         # -- GROUP: With positive logic (non-negated tags)
#         ("case P00: 2 disabled tags", True,
#          [ traits.category1_disabled_tag, traits.category2_disabled_tag]),
#         ("case P01: disabled and enabled tag", True,
#          [ traits.category1_disabled_tag, traits.category2_enabled_tag]),
#         ("case P10: enabled and disabled tag", True,
#          [ traits.category1_enabled_tag, traits.category2_disabled_tag]),
#         ("case P11: 2 enabled tags", False,  # -- SHOULD-RUN
#          [ traits.category1_enabled_tag, traits.category2_enabled_tag]),
#         # -- GROUP: With negated tag
#         ("case N00: not-enabled and disabled tag", True,
#          [ traits.category1_not_enabled_tag, traits.category2_disabled_tag]),
#         ("case N01: not-enabled and enabled tag", True,
#          [ traits.category1_not_enabled_tag, traits.category2_enabled_tag]),
#         ("case N10: not-disabled and disabled tag", True,
#          [ traits.category1_not_disabled_tag, traits.category2_disabled_tag]),
#         ("case N11: not-disabled and enabled tag", False, # -- SHOULD-RUN
#          [ traits.category1_not_disabled_tag, traits.category2_enabled_tag]),
#         # -- GROUP: With unknown category
#         ("case U0x: disabled and unknown tag", True,
#          [ traits.category1_disabled_tag, traits.unknown_category_tag]),
#         ("case U1x: enabled and unknown tag", False,  # -- SHOULD-RUN
#          [ traits.category1_enabled_tag, traits.unknown_category_tag]),
#     ])
#     def test_should_exclude_with__combinations_of_2_categories(self, case, expected, tags):
#         tag_matcher = self.make_tag_matcher()
#         actual_result = tag_matcher.should_exclude_with(tags)
#         assert expected == actual_result, case
#
#     @pytest.mark.parametrize("case, expected, tags", [
#         # -- GROUP: With positive logic (non-negated tags)
#         ("case P00: 2 disabled tags", True,
#          [ traits.category1_disabled_tag, traits.category1_disabled_tag2]),
#         ("case P01: disabled and enabled tag", True,
#          [ traits.category1_disabled_tag, traits.category1_enabled_tag]),
#         ("case P10: enabled and disabled tag", True,
#          [ traits.category1_enabled_tag, traits.category1_disabled_tag]),
#         ("case P11: 2 enabled tags (same)", False,  # -- SHOULD-RUN
#          [ traits.category1_enabled_tag, traits.category1_enabled_tag]),
#         # -- GROUP: With negated tag
#         ("case N00: not-enabled and disabled tag", True,
#          [ traits.category1_not_enabled_tag, traits.category1_disabled_tag]),
#         ("case N01: not-enabled and enabled tag", True,
#          [ traits.category1_not_enabled_tag, traits.category1_enabled_tag]),
#         ("case N10: not-disabled and disabled tag", True,
#          [ traits.category1_not_disabled_tag, traits.category1_disabled_tag]),
#         ("case N11: not-disabled and enabled tag", False, # -- SHOULD-RUN
#          [ traits.category1_not_disabled_tag, traits.category1_enabled_tag]),
#     ])
#     def test_should_exclude_with__combinations_with_same_category(self,
#                                                         case, expected, tags):
#         tag_matcher = self.make_tag_matcher()
#         actual_result = tag_matcher.should_exclude_with(tags)
#         assert expected == actual_result, case


class TestActiveTagMatcher1(TestCase):
    """unittest-based tests for ActiveTagMatcher.

    Test data (enabled/disabled/similar/negated/unknown-category tags and the
    value provider) comes from the Traits4ActiveTagMatcher class.
    """
    TagMatcher = ActiveTagMatcher
    traits = Traits4ActiveTagMatcher

    @classmethod
    def make_tag_matcher(cls):
        """Create a tag matcher that uses the traits' value_provider."""
        tag_matcher = cls.TagMatcher(cls.traits.value_provider)
        return tag_matcher

    def setUp(self):
        self.tag_matcher = self.make_tag_matcher()

    def test_select_active_tags__basics(self):
        active_tag = "active.with_CATEGORY=VALUE"
        tags = ["foo", active_tag, "bar"]
        selected = list(self.tag_matcher.select_active_tags(tags))
        self.assertEqual(len(selected), 1)
        # -- Each selected item is a (tag, match) pair; only the tag is checked.
        selected_tag, _ = selected[0]
        self.assertEqual(selected_tag, active_tag)

    def test_select_active_tags__matches_tag_parts(self):
        tags = ["active.with_CATEGORY=VALUE"]
        selected = list(self.tag_matcher.select_active_tags(tags))
        self.assertEqual(len(selected), 1)
        # -- The match object exposes named groups for each tag part.
        _, selected_match = selected[0]
        self.assertEqual(selected_match.group("prefix"), "active")
        self.assertEqual(selected_match.group("category"), "CATEGORY")
        self.assertEqual(selected_match.group("value"), "VALUE")

    def test_select_active_tags__finds_tag_with_any_valid_tag_prefix(self):
        TagMatcher = self.TagMatcher
        for tag_prefix in TagMatcher.tag_prefixes:
            tag = TagMatcher.make_category_tag("foo", "alice", tag_prefix)
            tags = [ tag ]
            selected = self.tag_matcher.select_active_tags(tags)
            selected = list(selected)
            self.assertEqual(len(selected), 1)
            selected_tag0 = selected[0][0]
            self.assertEqual(selected_tag0, tag)
            self.assertTrue(selected_tag0.startswith(tag_prefix))

    def test_select_active_tags__ignores_invalid_active_tags(self):
        invalid_active_tags = [
            ("foo.alice",               "case: Normal tag"),
            ("with_foo=alice",          "case: Subset of an active tag"),
            ("ACTIVE.with_foo.alice",   "case: Wrong tag_prefix (uppercase)"),
            ("only.with_foo.alice",     "case: Wrong value_separator"),
        ]
        for invalid_tag, case in invalid_active_tags:
            tags = [ invalid_tag ]
            selected = self.tag_matcher.select_active_tags(tags)
            selected = list(selected)
            self.assertEqual(len(selected), 0, case)

    def test_select_active_tags__with_two_tags(self):
        # XXX-JE-DUPLICATED:
        traits = self.traits
        test_patterns = [
            ("case: Two enabled tags",
             [traits.category1_enabled_tag, traits.category2_enabled_tag]),
            ("case: Active enabled and normal tag",
             [traits.category1_enabled_tag,  "foo"]),
            ("case: Active disabled and normal tag",
             [traits.category1_disabled_tag, "foo"]),
            ("case: Active negated and normal tag",
             [traits.category1_not_enabled_tag, "foo"]),
        ]
        for case, tags in test_patterns:
            selected = self.tag_matcher.select_active_tags(tags)
            selected = list(selected)
            # -- assertGreaterEqual reports the actual count on failure.
            self.assertGreaterEqual(len(selected), 1, case)

    def test_should_exclude_with__returns_false_with_enabled_tag(self):
        traits = self.traits
        tags1 = [ traits.category1_enabled_tag ]
        tags2 = [ traits.category2_enabled_tag ]
        self.assertEqual(False, self.tag_matcher.should_exclude_with(tags1))
        self.assertEqual(False, self.tag_matcher.should_exclude_with(tags2))

    def test_should_exclude_with__returns_false_with_disabled_tag_and_more(self):
        # -- NOTE: Need 1+ enabled active-tags of same category => ENABLED
        traits = self.traits
        test_patterns = [
            ([ traits.category1_enabled_tag, traits.category1_disabled_tag ], "case: first"),
            ([ traits.category1_disabled_tag, traits.category1_enabled_tag ], "case: last"),
            ([ "foo", traits.category1_enabled_tag, traits.category1_disabled_tag, "bar" ], "case: middle"),
        ]
        enabled = True  # EXPECTED
        for tags, case in test_patterns:
            self.assertEqual(not enabled, self.tag_matcher.should_exclude_with(tags),
                             "%s: tags=%s" % (case, tags))

    def test_should_exclude_with__returns_true_with_other_tag(self):
        traits = self.traits
        tags = [ traits.category1_disabled_tag ]
        self.assertEqual(True, self.tag_matcher.should_exclude_with(tags))

    def test_should_exclude_with__returns_true_with_other_tag_and_more(self):
        traits = self.traits
        test_patterns = [
            ([ traits.category1_disabled_tag, "foo" ], "case: first"),
            ([ "foo", traits.category1_disabled_tag ], "case: last"),
            ([ "foo", traits.category1_disabled_tag, "bar" ], "case: middle"),
        ]
        for tags, case in test_patterns:
            self.assertEqual(True, self.tag_matcher.should_exclude_with(tags),
                             "%s: tags=%s" % (case, tags))

    def test_should_exclude_with__returns_true_with_similar_tag(self):
        traits = self.traits
        tags = [ traits.category1_similar_tag ]
        self.assertEqual(True, self.tag_matcher.should_exclude_with(tags))

    def test_should_exclude_with__returns_true_with_similar_and_more(self):
        traits = self.traits
        test_patterns = [
            ([ traits.category1_similar_tag, "foo" ], "case: first"),
            ([ "foo", traits.category1_similar_tag ], "case: last"),
            ([ "foo", traits.category1_similar_tag, "bar" ], "case: middle"),
        ]
        for tags, case in test_patterns:
            self.assertEqual(True, self.tag_matcher.should_exclude_with(tags),
                             "%s: tags=%s" % (case, tags))

    def test_should_exclude_with__returns_false_without_category_tag(self):
        test_patterns = [
            ([ ],           "case: No tags"),
            ([ "foo" ],     "case: One tag"),
            ([ "foo", "bar" ], "case: Two tags"),
        ]
        for tags, case in test_patterns:
            self.assertEqual(False, self.tag_matcher.should_exclude_with(tags),
                             "%s: tags=%s" % (case, tags))

    def test_should_exclude_with__returns_false_with_unknown_category_tag(self):
        """Tags from unknown categories, not supported by value_provider,
        should not be excluded.
        """
        traits = self.traits
        tags = [ traits.unknown_category_tag ]
        self.assertEqual("active.with_UNKNOWN=one", traits.unknown_category_tag)
        self.assertEqual(None, self.tag_matcher.value_provider.get("UNKNOWN"))
        self.assertEqual(False, self.tag_matcher.should_exclude_with(tags))

    def test_should_exclude_with__combinations_of_2_categories(self):
        # XXX-JE-DUPLICATED:
        traits = self.traits
        test_patterns = [
            ("case P00: 2 disabled category tags", True,
             [ traits.category1_disabled_tag, traits.category2_disabled_tag]),
            ("case P01: disabled and enabled category tags", True,
             [ traits.category1_disabled_tag, traits.category2_enabled_tag]),
            ("case P10: enabled and disabled category tags", True,
             [ traits.category1_enabled_tag, traits.category2_disabled_tag]),
            ("case P11: 2 enabled category tags", False,  # -- SHOULD-RUN
             [ traits.category1_enabled_tag, traits.category2_enabled_tag]),
            # -- SPECIAL CASE: With negated category
            ("case N00: not-enabled and disabled category tags", True,
             [ traits.category1_not_enabled_tag, traits.category2_disabled_tag]),
            ("case N01: not-enabled and enabled category tags", True,
             [ traits.category1_not_enabled_tag, traits.category2_enabled_tag]),
            ("case N10: not-disabled and disabled category tags", True,
             [ traits.category1_not_disabled_tag, traits.category2_disabled_tag]),
            # -- Label fixed: this case uses the NOT-DISABLED tag (was "not-enabled").
            ("case N11: not-disabled and enabled category tags", False,  # -- SHOULD-RUN
             [ traits.category1_not_disabled_tag, traits.category2_enabled_tag]),
            # -- SPECIAL CASE: With unknown category
            ("case U0x: disabled and unknown category tags", True,
             [ traits.category1_disabled_tag, traits.unknown_category_tag]),
            ("case U1x: enabled and unknown category tags", False,  # SHOULD-RUN
             [ traits.category1_enabled_tag, traits.unknown_category_tag]),
        ]
        for case, expected, tags in test_patterns:
            actual_result = self.tag_matcher.should_exclude_with(tags)
            self.assertEqual(expected, actual_result,
                             "%s: tags=%s" % (case, tags))

    def test_should_run_with__negates_result_of_should_exclude_with(self):
        traits = self.traits
        test_patterns = [
            ([ ],                   "case: No tags"),
            ([ "foo" ],             "case: One non-category tag"),
            ([ "foo", "bar" ],      "case: Two non-category tags"),
            ([ traits.category1_enabled_tag ],   "case: enabled tag"),
            ([ traits.category1_enabled_tag, traits.category1_disabled_tag ],  "case: enabled and other tag"),
            ([ traits.category1_enabled_tag, "foo" ],    "case: enabled and foo tag"),
            ([ traits.category1_disabled_tag ],            "case: other tag"),
            ([ traits.category1_disabled_tag, "foo" ],     "case: other and foo tag"),
            ([ traits.category1_similar_tag ],          "case: similar tag"),
            ([ "foo", traits.category1_similar_tag ],   "case: foo and similar tag"),
        ]
        for tags, case in test_patterns:
            result1 = self.tag_matcher.should_run_with(tags)
            result2 = self.tag_matcher.should_exclude_with(tags)
            self.assertEqual(result1, not result2, "%s: tags=%s" % (case, tags))
            self.assertEqual(not result1, result2, "%s: tags=%s" % (case, tags))


# NOTE: This class was previously defined twice with identical bodies; the
# second definition silently shadowed the first. It is now defined once.
class TestPredicateTagMatcher(TestCase):
    """Tests for PredicateTagMatcher, which delegates to a predicate(tags)."""

    def test_exclude_with__mechanics(self):
        predicate_function_blueprint = lambda tags: False
        # -- Mock's first positional argument is the spec object.
        predicate_function = Mock(predicate_function_blueprint)
        predicate_function.return_value = True
        tag_matcher = PredicateTagMatcher(predicate_function)
        tags = [ "foo", "bar" ]
        self.assertEqual(True, tag_matcher.should_exclude_with(tags))
        predicate_function.assert_called_once_with(tags)
        self.assertEqual(True, predicate_function(tags))

    def test_should_exclude_with__returns_true_when_predicate_is_true(self):
        predicate_always_true = lambda tags: True
        tag_matcher1 = PredicateTagMatcher(predicate_always_true)
        tags = [ "foo", "bar" ]
        self.assertEqual(True, tag_matcher1.should_exclude_with(tags))
        self.assertEqual(True, predicate_always_true(tags))

    def test_should_exclude_with__returns_true_when_predicate_is_true2(self):
        # -- CASE: Use predicate function instead of lambda.
        def predicate_contains_foo(tags):
            return any(x == "foo" for x in tags)
        tag_matcher2 = PredicateTagMatcher(predicate_contains_foo)
        tags = [ "foo", "bar" ]
        self.assertEqual(True, tag_matcher2.should_exclude_with(tags))
        self.assertEqual(True, predicate_contains_foo(tags))

    def test_should_exclude_with__returns_false_when_predicate_is_false(self):
        predicate_always_false = lambda tags: False
        tag_matcher1 = PredicateTagMatcher(predicate_always_false)
        tags = [ "foo", "bar" ]
        self.assertEqual(False, tag_matcher1.should_exclude_with(tags))
        self.assertEqual(False, predicate_always_false(tags))


class TestCompositeTagMatcher(TestCase):
    """Tests for CompositeTagMatcher (excludes if any inner matcher excludes)."""

    @staticmethod
    def count_tag_matcher_with_result(tag_matchers, tags, result_value):
        """Count the tag matchers whose should_exclude_with(tags) equals result_value."""
        return sum(1 for tag_matcher in tag_matchers
                   if tag_matcher.should_exclude_with(tags) == result_value)

    def setUp(self):
        predicate_false = lambda tags: False
        predicate_contains_foo = lambda tags: any(x == "foo" for x in tags)
        self.tag_matcher_false = PredicateTagMatcher(predicate_false)
        self.tag_matcher_foo = PredicateTagMatcher(predicate_contains_foo)
        tag_matchers = [
            self.tag_matcher_foo,
            self.tag_matcher_false
        ]
        self.ctag_matcher = CompositeTagMatcher(tag_matchers)

    def test_should_exclude_with__returns_true_when_any_tag_matcher_returns_true(self):
        test_patterns = [
            ("case: with foo",  ["foo", "bar"]),
            ("case: with foo2", ["foozy", "foo", "bar"]),
        ]
        for case, tags in test_patterns:
            message = "%s: tags=%s" % (case, tags)
            actual_result = self.ctag_matcher.should_exclude_with(tags)
            self.assertEqual(True, actual_result, message)
            actual_true_count = self.count_tag_matcher_with_result(
                                self.ctag_matcher.tag_matchers, tags, True)
            self.assertEqual(1, actual_true_count, message)

    def test_should_exclude_with__returns_false_when_no_tag_matcher_return_true(self):
        test_patterns = [
            ("case: without foo",   ["fool", "bar"]),
            ("case: without foo2",  ["foozy", "bar"]),
        ]
        for case, tags in test_patterns:
            message = "%s: tags=%s" % (case, tags)
            actual_result = self.ctag_matcher.should_exclude_with(tags)
            self.assertEqual(False, actual_result, message)
            actual_true_count = self.count_tag_matcher_with_result(
                                self.ctag_matcher.tag_matchers, tags, True)
            self.assertEqual(0, actual_true_count, message)


# -----------------------------------------------------------------------------
# PROTOTYPING CLASSES (deprecating)
# -----------------------------------------------------------------------------
class TestOnlyWithCategoryTagMatcher(TestCase):
    """Tests for the deprecated OnlyWithCategoryTagMatcher (category "xxx")."""
    TagMatcher = OnlyWithCategoryTagMatcher

    def setUp(self):
        category = "xxx"
        # -- The class is being deprecated: silence its DeprecationWarning.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            self.tag_matcher = self.TagMatcher(category, "alice")
        self.enabled_tag = self.TagMatcher.make_category_tag(category, "alice")
        self.similar_tag = self.TagMatcher.make_category_tag(category, "alice2")
        self.other_tag = self.TagMatcher.make_category_tag(category, "other")
        self.category = category

    def test_should_exclude_with__returns_false_with_enabled_tag(self):
        tags = [ self.enabled_tag ]
        self.assertEqual(False, self.tag_matcher.should_exclude_with(tags))

    def test_should_exclude_with__returns_false_with_enabled_tag_and_more(self):
        test_patterns = [
            ([ self.enabled_tag, self.other_tag ], "case: first"),
            ([ self.other_tag, self.enabled_tag ], "case: last"),
            ([ "foo", self.enabled_tag, self.other_tag, "bar" ], "case: middle"),
        ]
        for tags, case in test_patterns:
            self.assertEqual(False, self.tag_matcher.should_exclude_with(tags),
                             "%s: tags=%s" % (case, tags))

    def test_should_exclude_with__returns_true_with_other_tag(self):
        tags = [ self.other_tag ]
        self.assertEqual(True, self.tag_matcher.should_exclude_with(tags))

    def test_should_exclude_with__returns_true_with_other_tag_and_more(self):
        test_patterns = [
            ([ self.other_tag, "foo" ], "case: first"),
            ([ "foo", self.other_tag ], "case: last"),
            ([ "foo", self.other_tag, "bar" ], "case: middle"),
        ]
        for tags, case in test_patterns:
            self.assertEqual(True, self.tag_matcher.should_exclude_with(tags),
                             "%s: tags=%s" % (case, tags))

    def test_should_exclude_with__returns_true_with_similar_tag(self):
        tags = [ self.similar_tag ]
        self.assertEqual(True, self.tag_matcher.should_exclude_with(tags))

    def test_should_exclude_with__returns_true_with_similar_and_more(self):
        test_patterns = [
            ([ self.similar_tag, "foo" ], "case: first"),
            ([ "foo", self.similar_tag ], "case: last"),
            ([ "foo", self.similar_tag, "bar" ], "case: middle"),
        ]
        for tags, case in test_patterns:
            self.assertEqual(True, self.tag_matcher.should_exclude_with(tags),
                             "%s: tags=%s" % (case, tags))

    def test_should_exclude_with__returns_false_without_category_tag(self):
        test_patterns = [
            ([ ],           "case: No tags"),
            ([ "foo" ],     "case: One tag"),
            ([ "foo", "bar" ], "case: Two tags"),
        ]
        for tags, case in test_patterns:
            self.assertEqual(False, self.tag_matcher.should_exclude_with(tags),
                             "%s: tags=%s" % (case, tags))

    def test_should_run_with__negates_result_of_should_exclude_with(self):
        test_patterns = [
            ([ ],                   "case: No tags"),
            ([ "foo" ],             "case: One non-category tag"),
            ([ "foo", "bar" ],      "case: Two non-category tags"),
            ([ self.enabled_tag ],   "case: enabled tag"),
            ([ self.enabled_tag, self.other_tag ],  "case: enabled and other tag"),
            ([ self.enabled_tag, "foo" ],    "case: enabled and foo tag"),
            ([ self.other_tag ],            "case: other tag"),
            ([ self.other_tag, "foo" ],     "case: other and foo tag"),
            ([ self.similar_tag ],          "case: similar tag"),
            ([ "foo", self.similar_tag ],   "case: foo and similar tag"),
        ]
        for tags, case in test_patterns:
            result1 = self.tag_matcher.should_run_with(tags)
            result2 = self.tag_matcher.should_exclude_with(tags)
            self.assertEqual(result1, not result2, "%s: tags=%s" % (case, tags))
            self.assertEqual(not result1, result2, "%s: tags=%s" % (case, tags))

    def test_make_category_tag__returns_category_tag_prefix_without_value(self):
        category = "xxx"
        tag1 = OnlyWithCategoryTagMatcher.make_category_tag(category)
        tag2 = OnlyWithCategoryTagMatcher.make_category_tag(category, None)
        tag3 = OnlyWithCategoryTagMatcher.make_category_tag(category, value=None)
        self.assertEqual("only.with_xxx=", tag1)
        self.assertEqual("only.with_xxx=", tag2)
        self.assertEqual("only.with_xxx=", tag3)
        self.assertTrue(tag1.startswith(OnlyWithCategoryTagMatcher.tag_prefix))

    def test_make_category_tag__returns_category_tag_with_value(self):
        category = "xxx"
        tag1 = OnlyWithCategoryTagMatcher.make_category_tag(category, "alice")
        tag2 = OnlyWithCategoryTagMatcher.make_category_tag(category, "bob")
        self.assertEqual("only.with_xxx=alice", tag1)
        self.assertEqual("only.with_xxx=bob", tag2)

    def test_make_category_tag__returns_category_tag_with_tag_prefix(self):
        my_tag_prefix = "ONLY_WITH."
        category = "xxx"
        TagMatcher = OnlyWithCategoryTagMatcher
        tag0 = TagMatcher.make_category_tag(category, tag_prefix=my_tag_prefix)
        tag1 = TagMatcher.make_category_tag(category, "alice", my_tag_prefix)
        tag2 = TagMatcher.make_category_tag(category, "bob", tag_prefix=my_tag_prefix)
        self.assertEqual("ONLY_WITH.xxx=", tag0)
        self.assertEqual("ONLY_WITH.xxx=alice", tag1)
        self.assertEqual("ONLY_WITH.xxx=bob", tag2)
        self.assertTrue(tag1.startswith(my_tag_prefix))

    def test_ctor__with_tag_prefix(self):
        tag_prefix = "ONLY_WITH."
        # -- Same DeprecationWarning suppression as in setUp().
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            tag_matcher = OnlyWithCategoryTagMatcher("xxx", "alice", tag_prefix)
        tags = ["foo", "ONLY_WITH.xxx=foo", "only.with_xxx=bar", "bar"]
        actual_tags = tag_matcher.select_category_tags(tags)
        self.assertEqual(["ONLY_WITH.xxx=foo"], actual_tags)


class Traits4OnlyWithAnyCategoryTagMatcher(object):
    """Test data for OnlyWithAnyCategoryTagMatcher."""
    TagMatcher0 = OnlyWithCategoryTagMatcher
    TagMatcher = OnlyWithAnyCategoryTagMatcher
    category1_enabled_tag = TagMatcher0.make_category_tag("foo", "alice")
    category1_similar_tag = TagMatcher0.make_category_tag("foo", "alice2")
    category1_disabled_tag = TagMatcher0.make_category_tag("foo", "bob")
    category2_enabled_tag = TagMatcher0.make_category_tag("bar", "BOB")
    category2_similar_tag = TagMatcher0.make_category_tag("bar", "BOB2")
    category2_disabled_tag = TagMatcher0.make_category_tag("bar", "CHARLY")
    unknown_category_tag = TagMatcher0.make_category_tag("UNKNOWN", "one")


class TestOnlyWithAnyCategoryTagMatcher(TestCase):
    TagMatcher = OnlyWithAnyCategoryTagMatcher
    # NOTE: The remainder of this class is elided in the source listing.
    ...

# test_tag_expression2.py
Source: test_tag_expression2.py
1# -*- coding: utf-8 -*-2"""3Alternative approach to test TagExpression by testing all possible combinations.4REQUIRES: Python >= 2.6, because itertools.combinations() is used.5"""6from __future__ import absolute_import7from behave.tag_expression import TagExpression8from nose import tools9import itertools10from six.moves import range11has_combinations = hasattr(itertools, "combinations")12if has_combinations:13    # -- REQUIRE: itertools.combinations14    # SINCE: Python 2.615    def all_combinations(items):16        variants = []17        for n in range(len(items)+1):18            variants.extend(itertools.combinations(items, n))19        return variants20    NO_TAGS = "__NO_TAGS__"21    def make_tags_line(tags):22        """23        Convert into tags-line as in feature file.24        """25        if tags:26            return "@" + " @".join(tags)27        return NO_TAGS28    TestCase = object29    # ----------------------------------------------------------------------------30    # TEST: all_combinations() test helper31    # ----------------------------------------------------------------------------32    class TestAllCombinations(TestCase):33        def test_all_combinations_with_2values(self):34            items = "@one @two".split()35            expected = [36                (),37                ('@one',),38                ('@two',),39                ('@one', '@two'),40            ]41            actual = all_combinations(items)42            tools.eq_(actual, expected)43            tools.eq_(len(actual), 4)44        def test_all_combinations_with_3values(self):45            items = "@one @two @three".split()46            expected = [47                (),48                ('@one',),49                ('@two',),50                ('@three',),51                ('@one', '@two'),52                ('@one', '@three'),53                ('@two', '@three'),54                ('@one', '@two', '@three'),55            ]56            actual = all_combinations(items)57         
   tools.eq_(actual, expected)58            tools.eq_(len(actual), 8)59    # ----------------------------------------------------------------------------60    # COMPLICATED TESTS FOR: TagExpression logic61    # ----------------------------------------------------------------------------62    class TagExpressionTestCase(TestCase):63        def assert_tag_expression_matches(self, tag_expression,64                                          tag_combinations, expected):65            matched = [ make_tags_line(c) for c in tag_combinations66                                if tag_expression.check(c) ]67            tools.eq_(matched, expected)68        def assert_tag_expression_mismatches(self, tag_expression,69                                            tag_combinations, expected):70            mismatched = [ make_tags_line(c) for c in tag_combinations71                                if not tag_expression.check(c) ]72            tools.eq_(mismatched, expected)73    class TestTagExpressionWith1Term(TagExpressionTestCase):74        """75        ALL_COMBINATIONS[4] with: @foo @other76            self.NO_TAGS,77            "@foo", "@other",78            "@foo @other",79        """80        tags = ("foo", "other")81        tag_combinations = all_combinations(tags)82        def test_matches__foo(self):83            tag_expression = TagExpression(["@foo"])84            expected = [85                # -- WITH 0 tags: None86                "@foo",87                "@foo @other",88            ]89            self.assert_tag_expression_matches(tag_expression,90                                               self.tag_combinations, expected)91        def test_matches__not_foo(self):92            tag_expression = TagExpression(["-@foo"])93            expected = [94                NO_TAGS,95                "@other",96            ]97            self.assert_tag_expression_matches(tag_expression,98                                               self.tag_combinations, expected)99    class 
TestTagExpressionWith2Terms(TagExpressionTestCase):100        """101        ALL_COMBINATIONS[8] with: @foo @bar @other102            self.NO_TAGS,103            "@foo", "@bar", "@other",104            "@foo @bar", "@foo @other", "@bar @other",105            "@foo @bar @other",106        """107        tags = ("foo", "bar", "other")108        tag_combinations = all_combinations(tags)109        # -- LOGICAL-OR CASES:110        def test_matches__foo_or_bar(self):111            tag_expression = TagExpression(["@foo,@bar"])112            expected = [113                # -- WITH 0 tags: None114                "@foo", "@bar",115                "@foo @bar", "@foo @other", "@bar @other",116                "@foo @bar @other",117            ]118            self.assert_tag_expression_matches(tag_expression,119                                               self.tag_combinations, expected)120        def test_matches__foo_or_not_bar(self):121            tag_expression = TagExpression(["@foo,-@bar"])122            expected = [123                NO_TAGS,124                "@foo", "@other",125                "@foo @bar", "@foo @other",126                "@foo @bar @other",127            ]128            self.assert_tag_expression_matches(tag_expression,129                                               self.tag_combinations, expected)130        def test_matches__not_foo_or_not_bar(self):131            tag_expression = TagExpression(["-@foo,-@bar"])132            expected = [133                NO_TAGS,134                "@foo", "@bar", "@other",135                "@foo @other", "@bar @other",136            ]137            self.assert_tag_expression_matches(tag_expression,138                                               self.tag_combinations, expected)139        # -- LOGICAL-AND CASES:140        def test_matches__foo_and_bar(self):141            tag_expression = TagExpression(["@foo", "@bar"])142            expected = [143                # -- WITH 0 tags: None144                # -- 
WITH 1 tag:  None145                "@foo @bar",146                "@foo @bar @other",147            ]148            self.assert_tag_expression_matches(tag_expression,149                                               self.tag_combinations, expected)150        def test_matches__foo_and_not_bar(self):151            tag_expression = TagExpression(["@foo", "-@bar"])152            expected = [153                # -- WITH 0 tags: None154                # -- WITH 1 tag:  None155                "@foo",156                "@foo @other",157                # -- WITH 3 tag:  None158            ]159            self.assert_tag_expression_matches(tag_expression,160                                               self.tag_combinations, expected)161        def test_matches__not_foo_and_not_bar(self):162            tag_expression = TagExpression(["-@foo", "-@bar"])163            expected = [164                NO_TAGS,165                "@other",166                # -- WITH 2 tag:  None167                # -- WITH 3 tag:  None168            ]169            self.assert_tag_expression_matches(tag_expression,170                                               self.tag_combinations, expected)171    class TestTagExpressionWith3Terms(TagExpressionTestCase):172        """173        ALL_COMBINATIONS[16] with: @foo @bar @zap @other174            self.NO_TAGS,175            "@foo", "@bar", "@zap", "@other",176            "@foo @bar", "@foo @zap", "@foo @other",177            "@bar @zap", "@bar @other",178            "@zap @other",179            "@foo @bar @zap", "@foo @bar @other", "@foo @zap @other",180            "@bar @zap @other",181            "@foo @bar @zap @other",182        """183        tags = ("foo", "bar", "zap", "other")184        tag_combinations = all_combinations(tags)185        # -- LOGICAL-OR CASES:186        def test_matches__foo_or_bar_or_zap(self):187            tag_expression = TagExpression(["@foo,@bar,@zap"])188            matched = [189                # -- WITH 0 tags: 
None190                # -- WITH 1 tag:191                "@foo", "@bar", "@zap",192                # -- WITH 2 tags:193                "@foo @bar", "@foo @zap", "@foo @other",194                "@bar @zap", "@bar @other",195                "@zap @other",196                # -- WITH 3 tags:197                "@foo @bar @zap", "@foo @bar @other", "@foo @zap @other",198                "@bar @zap @other",199                # -- WITH 4 tags:200                "@foo @bar @zap @other",201            ]202            self.assert_tag_expression_matches(tag_expression,203                                               self.tag_combinations, matched)204            mismatched = [205                # -- WITH 0 tags:206                NO_TAGS,207                # -- WITH 1 tag:208                "@other",209                # -- WITH 2 tags: None210                # -- WITH 3 tags: None211                # -- WITH 4 tags: None212            ]213            self.assert_tag_expression_mismatches(tag_expression,214                                               self.tag_combinations, mismatched)215        def test_matches__foo_or_not_bar_or_zap(self):216            tag_expression = TagExpression(["@foo,-@bar,@zap"])217            matched = [218                # -- WITH 0 tags:219                NO_TAGS,220                # -- WITH 1 tag:221                "@foo", "@zap", "@other",222                # -- WITH 2 tags:223                "@foo @bar", "@foo @zap", "@foo @other",224                "@bar @zap",225                "@zap @other",226                # -- WITH 3 tags:227                "@foo @bar @zap", "@foo @bar @other", "@foo @zap @other",228                "@bar @zap @other",229                # -- WITH 4 tags:230                "@foo @bar @zap @other",231            ]232            self.assert_tag_expression_matches(tag_expression,233                                               self.tag_combinations, matched)234            mismatched = [235                # -- WITH 0 tags: 
None236                # -- WITH 1 tag:237                "@bar",238                # -- WITH 2 tags:239                "@bar @other",240                # -- WITH 3 tags: None241                # -- WITH 4 tags: None242            ]243            self.assert_tag_expression_mismatches(tag_expression,244                                               self.tag_combinations, mismatched)245        def test_matches__foo_or_not_bar_or_not_zap(self):246            tag_expression = TagExpression(["foo,-@bar,-@zap"])247            matched = [248                # -- WITH 0 tags:249                NO_TAGS,250                # -- WITH 1 tag:251                "@foo", "@bar", "@zap", "@other",252                # -- WITH 2 tags:253                "@foo @bar", "@foo @zap", "@foo @other",254                "@bar @other",255                "@zap @other",256                # -- WITH 3 tags:257                "@foo @bar @zap", "@foo @bar @other", "@foo @zap @other",258                # -- WITH 4 tags:259                "@foo @bar @zap @other",260            ]261            self.assert_tag_expression_matches(tag_expression,262                                               self.tag_combinations, matched)263            mismatched = [264                # -- WITH 0 tags: None265                # -- WITH 1 tag: None266                # -- WITH 2 tags:267                "@bar @zap",268                # -- WITH 3 tags: None269                "@bar @zap @other",270                # -- WITH 4 tags: None271            ]272            self.assert_tag_expression_mismatches(tag_expression,273                                               self.tag_combinations, mismatched)274        def test_matches__not_foo_or_not_bar_or_not_zap(self):275            tag_expression = TagExpression(["-@foo,-@bar,-@zap"])276            matched = [277                # -- WITH 0 tags:278                NO_TAGS,279                # -- WITH 1 tag:280                "@foo", "@bar", "@zap", "@other",281                # -- 
WITH 2 tags:282                "@foo @bar", "@foo @zap", "@foo @other",283                "@bar @zap", "@bar @other",284                "@zap @other",285                # -- WITH 3 tags:286                "@foo @bar @other", "@foo @zap @other",287                "@bar @zap @other",288                # -- WITH 4 tags: None289            ]290            self.assert_tag_expression_matches(tag_expression,291                                               self.tag_combinations, matched)292            mismatched = [293                # -- WITH 0 tags: None294                # -- WITH 1 tag: None295                # -- WITH 2 tags:296                # -- WITH 3 tags:297                "@foo @bar @zap",298                # -- WITH 4 tags:299                "@foo @bar @zap @other",300            ]301            self.assert_tag_expression_mismatches(tag_expression,302                                               self.tag_combinations, mismatched)303        def test_matches__foo_and_bar_or_zap(self):304            tag_expression = TagExpression(["@foo", "@bar,@zap"])305            matched = [306                # -- WITH 0 tags:307                # -- WITH 1 tag:308                # -- WITH 2 tags:309                "@foo @bar", "@foo @zap",310                # -- WITH 3 tags:311                "@foo @bar @zap", "@foo @bar @other", "@foo @zap @other",312                # -- WITH 4 tags: None313                "@foo @bar @zap @other",314            ]315            self.assert_tag_expression_matches(tag_expression,316                                               self.tag_combinations, matched)317            mismatched = [318                # -- WITH 0 tags:319                NO_TAGS,320                # -- WITH 1 tag:321                "@foo", "@bar", "@zap", "@other",322                # -- WITH 2 tags:323                "@foo @other",324                "@bar @zap", "@bar @other",325                "@zap @other",326                # -- WITH 3 tags:327                "@bar @zap 
@other",328                # -- WITH 4 tags: None329            ]330            self.assert_tag_expression_mismatches(tag_expression,331                                               self.tag_combinations, mismatched)332        def test_matches__foo_and_bar_or_not_zap(self):333            tag_expression = TagExpression(["@foo", "@bar,-@zap"])334            matched = [335                # -- WITH 0 tags:336                # -- WITH 1 tag:337                "@foo",338                # -- WITH 2 tags:339                "@foo @bar", "@foo @other",340                # -- WITH 3 tags:341                "@foo @bar @zap", "@foo @bar @other",342                # -- WITH 4 tags: None343                "@foo @bar @zap @other",344            ]345            self.assert_tag_expression_matches(tag_expression,346                                               self.tag_combinations, matched)347            mismatched = [348                # -- WITH 0 tags:349                NO_TAGS,350                # -- WITH 1 tag:351                "@bar", "@zap", "@other",352                # -- WITH 2 tags:353                "@foo @zap",354                "@bar @zap", "@bar @other",355                "@zap @other",356                # -- WITH 3 tags:357                "@foo @zap @other",358                "@bar @zap @other",359                # -- WITH 4 tags: None360            ]361            self.assert_tag_expression_mismatches(tag_expression,362                                               self.tag_combinations, mismatched)363        def test_matches__foo_and_bar_and_zap(self):364            tag_expression = TagExpression(["@foo", "@bar", "@zap"])365            matched = [366                # -- WITH 0 tags:367                # -- WITH 1 tag:368                # -- WITH 2 tags:369                # -- WITH 3 tags:370                "@foo @bar @zap",371                # -- WITH 4 tags: None372                "@foo @bar @zap @other",373            ]374            
self.assert_tag_expression_matches(tag_expression,375                                               self.tag_combinations, matched)376            mismatched = [377                # -- WITH 0 tags:378                NO_TAGS,379                # -- WITH 1 tag:380                "@foo", "@bar", "@zap", "@other",381                # -- WITH 2 tags:382                "@foo @bar", "@foo @zap", "@foo @other",383                "@bar @zap", "@bar @other",384                "@zap @other",385                # -- WITH 3 tags:386                "@foo @bar @other", "@foo @zap @other",387                "@bar @zap @other",388                # -- WITH 4 tags: None389            ]390            self.assert_tag_expression_mismatches(tag_expression,391                                               self.tag_combinations, mismatched)392        def test_matches__not_foo_and_not_bar_and_not_zap(self):393            tag_expression = TagExpression(["-@foo", "-@bar", "-@zap"])394            matched = [395                # -- WITH 0 tags:396                NO_TAGS,397                # -- WITH 1 tag:398                "@other",399                # -- WITH 2 tags:400                # -- WITH 3 tags:401                # -- WITH 4 tags: None402            ]403            self.assert_tag_expression_matches(tag_expression,404                                               self.tag_combinations, matched)405            mismatched = [406                # -- WITH 0 tags:407                # -- WITH 1 tag:408                "@foo", "@bar", "@zap",409                # -- WITH 2 tags:410                "@foo @bar", "@foo @zap", "@foo @other",411                "@bar @zap", "@bar @other",412                "@zap @other",413                # -- WITH 3 tags:414                "@foo @bar @zap",415                "@foo @bar @other", "@foo @zap @other",416                "@bar @zap @other",417                # -- WITH 4 tags: None418                "@foo @bar @zap @other",419            ]420            
self.assert_tag_expression_mismatches(tag_expression,...crf.py
Source:crf.py  
import torch
from torch import nn
from typing import Optional, List


class CRF(nn.Module):
    """Linear-chain conditional random field (CRF) layer.

    Learns a start-transition score per tag, an end-transition score per tag
    and a (num_tags, num_tags) tag-to-tag transition matrix. ``forward``
    returns the log likelihood of gold tag sequences; ``decode`` returns the
    highest-scoring tag sequences found with the Viterbi algorithm.
    """

    def __init__(self, num_tags: int,
                 batch_first: Optional[bool] = False) -> None:
        """Initialize the CRF layer.

        Args:
            num_tags (int): number of tags (not counting start and end).
            batch_first (bool, optional): if True, ``emissions``, ``tags``
                and ``mask`` have their first two dimensions swapped relative
                to the internal (seq_length, batch_size) layout.
                Defaults to False.

        Raises:
            ValueError: if num_tags <= 0
        """
        if num_tags <= 0:
            raise ValueError(
                f"invalid number of tags: {num_tags}, " +
                "it should be greater than 0")
        super(CRF, self).__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        # shape: (num_tags,) -- score of starting a sequence in each tag
        self.start_transitions = nn.Parameter(torch.empty(self.num_tags))
        # shape: (num_tags,) -- score of ending a sequence in each tag
        self.end_transitions = nn.Parameter(torch.empty(self.num_tags))
        # shape: (num_tags, num_tags) -- [i, j] scores moving from tag i to tag j
        self.transitions = nn.Parameter(
            torch.empty(self.num_tags, self.num_tags))
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Initialize all transition scores uniformly in [-0.1, 0.1]."""
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        nn.init.uniform_(self.transitions, -0.1, 0.1)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_tags={self.num_tags})'

    def forward(self, emissions: torch.Tensor,
                tags: torch.LongTensor,
                mask: Optional[torch.ByteTensor] = None,
                reduction: str = "mean") -> torch.Tensor:
        r"""Compute the log likelihood of ``tags`` given ``emissions``.

        .. math::

            P(y|x) = \frac{\exp \sum_k w_k f_k(y_{i-1}, y_i, x, i)}
                          {\sum_{y'} \exp \sum_k w_k f_k(y'_{i-1}, y'_i, x, i)}

            \log P(y|x) = \sum_k w_k f_k(y_{i-1}, y_i, x, i)
                - \log \sum_{y'} \exp \sum_k w_k f_k(y'_{i-1}, y'_i, x, i)

        The returned value is the log likelihood itself; for maximum
        likelihood training, minimize its negation.

        Args:
            emissions: emission scores, (seq_length, batch_size, num_tags),
                or (batch_size, seq_length, num_tags) if ``batch_first``.
            tags: gold tags with the same first two dimensions as
                ``emissions``.
            mask: mask with the same first two dimensions as ``emissions``;
                defaults to all ones. The first timestep must be on for
                every sequence.
            reduction: ``"none"`` (one value per sequence), ``"sum"``,
                ``"mean"`` (over sequences) or ``"token_mean"`` (sum divided
                by the number of unmasked tokens).

        Returns:
            torch.Tensor: shape (batch_size,) for ``"none"``, else a scalar.

        Raises:
            ValueError: on an unknown ``reduction`` or inconsistent shapes.
        """
        if reduction not in ["none", "sum", "mean", "token_mean"]:
            raise ValueError(
                'reduction expected  "none", "sum", "mean", "token_mean",' +
                f'but got {reduction}')
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.uint8, device=tags.device)
        if mask.dtype != torch.uint8:
            mask = mask.byte()
        self._validation(emissions, tags=tags, mask=mask)
        if self.batch_first:
            # Internally everything is (seq_length, batch_size, ...).
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)
            mask = mask.transpose(0, 1)
        numerator = self._compute_score(emissions, tags, mask)
        denominator = self._compute_normalizer(emissions, mask)
        llh = numerator - denominator  # log likelihood, shape (batch_size,)
        if reduction == 'none':
            return llh
        if reduction == 'sum':
            return llh.sum()
        if reduction == 'mean':
            return llh.mean()
        return llh.sum() / mask.float().sum()  # token level mean

    def _validation(self, emissions: torch.Tensor,
                    tags: Optional[torch.LongTensor] = None,
                    mask: Optional[torch.ByteTensor] = None) -> None:
        """Check the ranks and shapes of the inputs (in caller layout).

        Raises:
            ValueError: if ``emissions`` is not 3-D, its last dimension is
                not ``num_tags``, ``tags``/``mask`` do not match its first two
                dimensions, or the first timestep of ``mask`` is not all on.
        """
        if emissions.dim() != 3:
            raise ValueError('emissions must have dimension of 3, ' +
                             f'got {emissions.dim()}')
        if emissions.size(2) != self.num_tags:
            raise ValueError(
                f'expected last dimension of emissions is {self.num_tags}, '
                f'got {emissions.size(2)}')
        if tags is not None:
            if emissions.shape[:2] != tags.shape:
                raise ValueError(
                    'the first two dimensions of emissions and' +
                    ' tags must match, ' +
                    f'got {tuple(emissions.shape[:2])}' +
                    f' and {tuple(tags.shape)}')
        if mask is not None:
            if emissions.shape[:2] != mask.shape:
                raise ValueError(
                    'the first two dimensions of emissions and' +
                    ' mask must match, ' +
                    f'got {tuple(emissions.shape[:2])}' +
                    f' and {tuple(mask.shape)}')
            no_empty_seq = not self.batch_first and mask[0].all()
            no_empty_seq_bf = self.batch_first and mask[:, 0].all()
            if not no_empty_seq and not no_empty_seq_bf:
                raise ValueError('mask of the first timestep must all be on')

    def _compute_score(self, emissions: torch.Tensor,
                       tags: torch.LongTensor,
                       mask: torch.ByteTensor) -> torch.Tensor:
        """Score each sentence in the batch along its given tag path.

        Args:
            emissions (torch.Tensor): (seq_length, batch_size, num_tags)
                emission scores.
            tags (torch.LongTensor): (seq_length, batch_size) tag of every
                token.
            mask (torch.ByteTensor): (seq_length, batch_size) mask.

        Returns:
            torch.Tensor: (batch_size,)
        """
        seq_length, batch_size = tags.shape
        mask = mask.float()
        batch_idx = torch.arange(batch_size)
        # shape: (batch_size,) -- transition from start to tags[0]
        score = self.start_transitions[tags[0]]
        # plus the emission score of tags[0] at the first token
        score = score + emissions[0, batch_idx, tags[0]]
        for i in range(1, seq_length):
            # transition tags[i-1] -> tags[i]; zeroed at masked steps
            score += self.transitions[tags[i - 1], tags[i]] * mask[i]
            # emission score of tags[i] at token i; zeroed at masked steps
            score += emissions[i, batch_idx, tags[i]] * mask[i]
        # Last unmasked tag of every sequence gets the end transition.
        seq_ends = mask.long().sum(dim=0) - 1
        last_tags = tags[seq_ends, batch_idx]
        score += self.end_transitions[last_tags]
        return score

    def _compute_normalizer(self, emissions: torch.Tensor,
                            mask: torch.ByteTensor) -> torch.Tensor:
        """Log partition function (log-sum-exp over all tag paths).

        Args:
            emissions (torch.Tensor): (seq_length, batch_size, num_tags)
                emission scores.
            mask (torch.ByteTensor): (seq_length, batch_size) mask.

        Returns:
            torch.Tensor: (batch_size,)
        """
        seq_length = emissions.size(0)
        # torch.where requires a boolean condition.
        mask = mask.bool()
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        for i in range(1, seq_length):
            # shape: (batch_size, num_tags, 1) -- indexed by previous tag
            broadcast_score = score.unsqueeze(2)
            # shape: (batch_size, 1, num_tags) -- indexed by next tag
            broadcast_emissions = emissions[i].unsqueeze(1)
            # shape: (batch_size, num_tags, num_tags) -- [b, prev, next]
            next_score = self.transitions + broadcast_score
            next_score = next_score + broadcast_emissions
            # Sum out the previous tag in log space -> (batch_size, num_tags)
            next_score = torch.logsumexp(next_score, dim=1)
            # Only advance sequences whose step i is unmasked.
            score = torch.where(mask[i].unsqueeze(1), next_score, score)
        score = score + self.end_transitions
        return torch.logsumexp(score, dim=1)

    def decode(self, emissions: torch.Tensor,
               mask: Optional[torch.ByteTensor] = None,
               pad_tag: Optional[int] = None) -> torch.Tensor:
        """Find the highest-scoring tag sequence for each batch element.

        Args:
            emissions (torch.Tensor): (seq_length, batch_size, num_tags), or
                (batch_size, seq_length, num_tags) if ``batch_first``.
            mask (Optional[torch.ByteTensor], optional): same first two
                dimensions as ``emissions``. Defaults to all ones.
            pad_tag (Optional[int], optional): tag written at masked
                positions. Defaults to None (meaning 0).

        Returns:
            torch.Tensor: long tensor of shape (1, batch_size, seq_length).
        """
        if mask is None:
            mask = torch.ones(emissions.shape[:2], device=emissions.device,
                              dtype=torch.uint8)
        if mask.dtype != torch.uint8:
            mask = mask.byte()
        self._validation(emissions, mask=mask)
        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            mask = mask.transpose(0, 1)
        return self._viterbi_decode(emissions, mask, pad_tag).unsqueeze(0)

    def _viterbi_decode(self, emissions: torch.FloatTensor,
                        mask: torch.ByteTensor,
                        pad_tag: Optional[int] = None) -> torch.Tensor:
        """Decode the best path with the Viterbi (dynamic programming) algorithm.

        Args:
            emissions (torch.FloatTensor): (seq_length, batch_size, num_tags)
                emission scores.
            mask (torch.ByteTensor): (seq_length, batch_size) mask.
            pad_tag (Optional[int], optional): tag written at masked
                positions. Defaults to None (meaning 0).

        Returns:
            torch.Tensor: (batch_size, seq_length) long tensor of tags.
        """
        if pad_tag is None:
            pad_tag = 0
        device = emissions.device
        seq_length, batch_size = mask.shape
        # torch.where requires a boolean condition.
        mask = mask.bool()
        # shape: (batch_size, num_tags) -- best score of a path ending in each tag
        score = self.start_transitions + emissions[0]
        # history_idx[i - 1][b, j]: best previous tag for tag j at step i
        history_idx = torch.zeros((seq_length, batch_size, self.num_tags),
                                  dtype=torch.long, device=device)
        # back-pointers stored at masked (out-of-range) steps
        oor_idx = torch.zeros((batch_size, self.num_tags),
                              dtype=torch.long, device=device)
        # output tags at masked (out-of-range) steps
        oor_tag = torch.full((seq_length, batch_size), fill_value=pad_tag,
                             dtype=torch.long, device=device)
        # Step 0 is already folded into `score`; recurse from step 1.
        for i in range(1, seq_length):
            # shape: (batch_size, num_tags, 1) -- indexed by previous tag
            broadcast_score = score.unsqueeze(2)
            # shape: (batch_size, 1, num_tags) -- indexed by next tag
            broadcast_emissions = emissions[i].unsqueeze(1)
            next_score = broadcast_score + self.transitions
            next_score = next_score + broadcast_emissions
            # Maximize over the previous tag.
            # next_score, indices: (batch_size, num_tags)
            next_score, indices = next_score.max(dim=1)
            score = torch.where(mask[i].unsqueeze(-1), next_score, score)
            indices = torch.where(mask[i].unsqueeze(-1), indices, oor_idx)
            history_idx[i - 1] = indices
        end_score = score + self.end_transitions
        _, end_tag = end_score.max(dim=1)
        # shape: (batch_size,) -- index of each sequence's last unmasked step
        seq_ends = mask.long().sum(dim=0) - 1
        # shape: (batch_size, seq_length, num_tags)
        history_idx = history_idx.transpose(0, 1).contiguous()
        # At each sequence's last unmasked step, overwrite every back-pointer
        # with that sequence's best end tag so the backtrace starts from it.
        history_idx.scatter_(1, seq_ends.view(-1, 1, 1).expand(-1, 1,
                                                               self.num_tags),
                             end_tag.view(-1, 1, 1).expand(-1, 1,
                                                           self.num_tags))
        history_idx = history_idx.transpose(0, 1).contiguous()
        # Follow the back-pointers from the last step to the first.
        best_tags_arr = torch.zeros((seq_length, batch_size),
                                    dtype=torch.long, device=device)
        best_tags = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
        for idx in range(seq_length - 1, -1, -1):
            best_tags = torch.gather(history_idx[idx], 1, best_tags)
            best_tags_arr[idx] = best_tags.view(batch_size)
        return torch.where(mask, best_tags_arr, oor_tag).transpose(0, 1)


if __name__ == "__main__":
    batch_size = 128
    seq_length = 64
    num_tags = 7
    crf = CRF(num_tags, batch_first=False)
    # randint's upper bound is exclusive: draws every tag in [0, num_tags).
    tags = torch.randint(0, num_tags, size=(seq_length, batch_size))
    mask = torch.ones_like(tags, dtype=torch.uint8, device=tags.device)
    emissions = torch.randn(seq_length, batch_size, num_tags)
    print(crf)
    score = crf._compute_score(emissions, tags, mask)
    print(score.shape)
    normalizer = crf._compute_normalizer(emissions, mask)
    print(normalizer.shape)
    llh = crf(emissions, tags)
    print(llh)
    best_paths = crf.decode(emissions, mask)
    print(best_paths.shape)
Source:test_tests_tags.py  
...27            pass28        fc = FakeClass()29        self.assertEqual(fc.test_tags, {'at_install', 'standard', 'slow'})30        self.assertEqual(fc.test_module, 'base')31    def test_set_tags_multiple_tags(self):32        """Test the set_tags decorator with multiple tags"""33        @tagged('slow', 'nightly')34        class FakeClass(TransactionCase):35            pass36        fc = FakeClass()37        self.assertEqual(fc.test_tags, {'at_install', 'standard', 'slow', 'nightly'})38        self.assertEqual(fc.test_module, 'base')39    def test_inheritance(self):40        """Test inheritance when using the 'tagged' decorator"""41        @tagged('slow')42        class FakeClassA(TransactionCase):43            pass44        @tagged('nightly')45        class FakeClassB(FakeClassA):...Learn to execute automation testing from scratch with LambdaTest Learning Hub. Right from setting up the prerequisites to run your first automation test, to following best practices and diving deeper into advanced test scenarios. LambdaTest Learning Hubs compile a list of step-by-step guides to help you be proficient with different test automation frameworks i.e. Selenium, Cypress, TestNG etc.
You can also refer to the video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.
Get 100 minutes of automation testing FREE!
