Best JavaScript code snippet using istanbul
op_def_library_test.py
Source:op_def_library_test.py
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.2#3# Licensed under the Apache License, Version 2.0 (the "License");4# you may not use this file except in compliance with the License.5# You may obtain a copy of the License at6#7# http://www.apache.org/licenses/LICENSE-2.08#9# Unless required by applicable law or agreed to in writing, software10# distributed under the License is distributed on an "AS IS" BASIS,11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.12# See the License for the specific language governing permissions and13# limitations under the License.14# ==============================================================================15"""Tests for tensorflow.python.ops.op_def_library."""16from __future__ import absolute_import17from __future__ import division18from __future__ import print_function19from google.protobuf import text_format20from tensorflow.core.framework import op_def_pb221from tensorflow.core.framework import tensor_shape_pb222from tensorflow.python.framework import dtypes23from tensorflow.python.framework import ops24from tensorflow.python.framework import tensor_shape25from tensorflow.python.framework import test_util26from tensorflow.python.framework.op_def_library import OpDefLibrary27from tensorflow.python.platform import googletest28def _unknown_shape(op):29 """Shape function for use with ops whose output shapes are unknown."""30 return [tensor_shape.unknown_shape() for _ in op.outputs]31# NOTE(mrry): Dummy shape registrations for ops used in the tests, since they32# don't have C++ op registrations on which to attach C++ shape 
fns.33ops.RegisterShape("Attr")(_unknown_shape)34ops.RegisterShape("AttrBool")(_unknown_shape)35ops.RegisterShape("AttrBoolList")(_unknown_shape)36ops.RegisterShape("AttrDefault")(_unknown_shape)37ops.RegisterShape("AttrEmptyListDefault")(_unknown_shape)38ops.RegisterShape("AttrEnum")(_unknown_shape)39ops.RegisterShape("AttrEnumList")(_unknown_shape)40ops.RegisterShape("AttrFloat")(_unknown_shape)41ops.RegisterShape("AttrListDefault")(_unknown_shape)42ops.RegisterShape("AttrListMin")(_unknown_shape)43ops.RegisterShape("AttrMin")(_unknown_shape)44ops.RegisterShape("AttrShape")(_unknown_shape)45ops.RegisterShape("AttrShapeList")(_unknown_shape)46ops.RegisterShape("AttrPartialShape")(_unknown_shape)47ops.RegisterShape("AttrPartialShapeList")(_unknown_shape)48ops.RegisterShape("AttrTypeDefault")(_unknown_shape)49ops.RegisterShape("AttrListTypeDefault")(_unknown_shape)50ops.RegisterShape("Binary")(_unknown_shape)51ops.RegisterShape("ComplexStruct")(_unknown_shape)52ops.RegisterShape("InPolymorphicTwice")(_unknown_shape)53ops.RegisterShape("MixedStruct")(_unknown_shape)54ops.RegisterShape("NInPolymorphicTwice")(_unknown_shape)55ops.RegisterShape("NInTwice")(_unknown_shape)56ops.RegisterShape("NInTwoTypeVariables")(_unknown_shape)57ops.RegisterShape("NIntsIn")(_unknown_shape)58ops.RegisterShape("NIntsOut")(_unknown_shape)59ops.RegisterShape("NIntsOutDefault")(_unknown_shape)60ops.RegisterShape("NPolymorphicIn")(_unknown_shape)61ops.RegisterShape("NPolymorphicOut")(_unknown_shape)62ops.RegisterShape("NPolymorphicOutDefault")(_unknown_shape)63ops.RegisterShape("NPolymorphicRestrictIn")(_unknown_shape)64ops.RegisterShape("NPolymorphicRestrictOut")(_unknown_shape)65ops.RegisterShape("OutT")(_unknown_shape)66ops.RegisterShape("OutTypeList")(_unknown_shape)67ops.RegisterShape("OutTypeListRestrict")(_unknown_shape)68ops.RegisterShape("Polymorphic")(_unknown_shape)69ops.RegisterShape("PolymorphicDefaultOut")(_unknown_shape)70ops.RegisterShape("PolymorphicOut")(_unknown_shape)71ops
.RegisterShape("RefIn")(_unknown_shape)72ops.RegisterShape("RefOut")(_unknown_shape)73ops.RegisterShape("ReservedAttr")(_unknown_shape)74ops.RegisterShape("ReservedInput")(_unknown_shape)75ops.RegisterShape("Restrict")(_unknown_shape)76ops.RegisterShape("Simple")(_unknown_shape)77ops.RegisterShape("SimpleStruct")(_unknown_shape)78ops.RegisterShape("TwoRefsIn")(_unknown_shape)79ops.RegisterShape("TypeList")(_unknown_shape)80ops.RegisterShape("TypeListRestrict")(_unknown_shape)81ops.RegisterShape("TypeListTwice")(_unknown_shape)82class OpDefLibraryTest(test_util.TensorFlowTestCase):83 def setUp(self):84 self._lib = OpDefLibrary()85 self._g = ops.Graph()86 self._default_graph_controller = self._g.as_default()87 self._default_graph_controller.__enter__()88 self._add_op("name: 'Simple' input_arg { name: 'a' type: DT_INT32 } "89 "output_arg { name: 'out' type: DT_FLOAT }")90 self._add_op("name: 'OutT' output_arg { name: 'a' type_attr: 'T' } "91 "attr { name: 'T' type: 'type' }")92 def tearDown(self):93 self._default_graph_controller.__exit__(None, None, None)94 def _add_op(self, ascii):95 op_def = op_def_pb2.OpDef()96 text_format.Merge(ascii, op_def)97 self._lib.add_op(op_def)98 def Tensor(self, t, name="in"):99 return self._lib.apply_op("OutT", T=t, name=name)100 def testNoRegisteredOpFails(self):101 with self.assertRaises(RuntimeError) as cm:102 self._lib.apply_op("unknown")103 self.assertEqual(str(cm.exception), "Unrecognized Op name unknown")104 def testAddOpValidation(self):105 with self.assertRaises(TypeError) as cm:106 self._add_op("name: 'MissingTypeAttr' "107 "input_arg { name: 'a' type_attr: 'T' } ")108 self.assertEqual(str(cm.exception),109 "Inconsistent OpDef for 'MissingTypeAttr', "110 "missing attr 'T'")111 with self.assertRaises(TypeError) as cm:112 self._add_op("name: 'BadTypeAttr' "113 "output_arg { name: 'a' type_attr: 'T' } "114 "attr { name: 'T' type: 'int' }")115 self.assertEqual(116 str(cm.exception),117 "Attr 'T' of 'BadTypeAttr' used as a 
type_attr but has type int")118 with self.assertRaises(TypeError) as cm:119 self._add_op("name: 'MissingNumberAttr' "120 "input_arg { name: 'a' type: DT_INT32 number_attr: 'N' } ")121 self.assertEqual(str(cm.exception),122 "Inconsistent OpDef for 'MissingNumberAttr', "123 "missing attr 'N'")124 with self.assertRaises(TypeError) as cm:125 self._add_op("name: 'BadNumberAttr' "126 "output_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "127 "attr { name: 'N' type: 'type' }")128 self.assertEqual(129 str(cm.exception),130 "Attr 'N' of 'BadNumberAttr' used as a number_attr but has type type")131 with self.assertRaises(TypeError) as cm:132 self._add_op("name: 'TwoTypesA' "133 "input_arg { name: 'a' type: DT_INT32 type_attr: 'T' } "134 "attr { name: 'T' type: 'type' }")135 self.assertEqual(str(cm.exception),136 "Arg 'a' of 'TwoTypesA' must have one type field not 2")137 with self.assertRaises(TypeError) as cm:138 self._add_op("name: 'TwoTypesB' "139 "input_arg { name: 'a' type: DT_INT32 type_list_attr: 'T' } "140 "attr { name: 'T' type: 'list(type)' }")141 self.assertEqual(str(cm.exception),142 "Arg 'a' of 'TwoTypesB' must have one type field not 2")143 with self.assertRaises(TypeError) as cm:144 self._add_op("name: 'ThreeTypes' "145 "input_arg { name: 'a' type: DT_INT32 type_attr: 'T' "146 "type_list_attr: 'U' } "147 "attr { name: 'T' type: 'type' } "148 "attr { name: 'U' type: 'list(type)' }")149 self.assertEqual(str(cm.exception),150 "Arg 'a' of 'ThreeTypes' must have one type field not 3")151 with self.assertRaises(TypeError) as cm:152 self._add_op("name: 'NoTypes' output_arg { name: 'a' } ")153 self.assertEqual(str(cm.exception),154 "Arg 'a' of 'NoTypes' must have one type field not 0")155 def testSimple(self):156 out = self._lib.apply_op("Simple", a=3)157 self.assertEqual(dtypes.float32, out.dtype)158 self.assertProtoEquals("""159 name: 'Simple' op: 'Simple' input: 'Simple/a'160 """, out.op.node_def)161 out = self._lib.apply_op("Simple", a=4)162 
self.assertProtoEquals("""163 name: 'Simple_1' op: 'Simple' input: 'Simple_1/a'164 """, out.op.node_def)165 out = self._lib.apply_op("Simple", a=5, name="named")166 self.assertProtoEquals("""167 name: 'named' op: 'Simple' input: 'named/a'168 """, out.op.node_def)169 out = self._lib.apply_op("Simple", a=[[1, 2, 3], [4, 5, 6]], name="two_d")170 self.assertProtoEquals("""171 name: 'two_d' op: 'Simple' input: 'two_d/a'172 """, out.op.node_def)173 def testSimpleFailures(self):174 with self.assertRaises(TypeError) as cm:175 self._lib.apply_op("Simple", a="Bad string")176 self.assertEqual(str(cm.exception),177 "Expected int32 passed to parameter 'a' of op 'Simple', "178 "got 'Bad string' of type 'str' instead.")179 with self.assertRaises(TypeError) as cm:180 self._lib.apply_op("Simple", a=self.Tensor(dtypes.string))181 self.assertEqual(str(cm.exception),182 "Input 'a' of 'Simple' Op has type string "183 "that does not match expected type of int32.")184 with self.assertRaises(TypeError) as cm:185 self._lib.apply_op("Simple", a=6, extra="bogus")186 self.assertEqual(str(cm.exception),187 "apply_op() got unexpected keyword arguments: extra")188 with self.assertRaises(TypeError) as cm:189 self._lib.apply_op("Simple", a=6, extra1="bogus", extra2="also_bogus")190 self.assertEqual(str(cm.exception),191 "apply_op() got unexpected keyword arguments: extra1, "192 "extra2")193 with self.assertRaises(TypeError) as cm:194 self._lib.apply_op("Simple")195 self.assertEqual(str(cm.exception), "No argument for input a")196 with self.assertRaises(TypeError) as cm:197 self._lib.apply_op("Simple", wrong=7)198 self.assertEqual(str(cm.exception), "No argument for input a")199 with self.assertRaises(TypeError) as cm:200 self._lib.apply_op("Simple", a={"label": 1})201 self.assertEqual(str(cm.exception),202 "Expected int32 passed to parameter 'a' of op 'Simple', "203 "got {'label': 1} of type 'dict' instead.")204 def testReservedInput(self):205 self._add_op("name: 'ReservedInput' "206 "input_arg { 
name: 'input' type: DT_INT32 } ")207 op = self._lib.apply_op("ReservedInput", input_=7, name="x")208 self.assertProtoEquals("""209 name: 'x' op: 'ReservedInput' input: 'x/input'210 """, op.node_def)211 def testPolymorphic(self):212 self._add_op("name: 'Polymorphic' "213 "input_arg { name: 'a' type_attr: 'T' } "214 "output_arg { name: 'out' type_attr: 'T' } "215 "attr { name: 'T' type: 'type' }")216 out = self._lib.apply_op("Polymorphic", a=7, name="p")217 self.assertEqual(dtypes.int32, out.dtype)218 self.assertProtoEquals("""219 name: 'p' op: 'Polymorphic' input: 'p/a'220 attr { key: 'T' value { type: DT_INT32 } }221 """, out.op.node_def)222 out = self._lib.apply_op("Polymorphic", a="s", name="q")223 self.assertEqual(dtypes.string, out.dtype)224 self.assertProtoEquals("""225 name: 'q' op: 'Polymorphic' input: 'q/a'226 attr { key: 'T' value { type: DT_STRING } }227 """, out.op.node_def)228 out = self._lib.apply_op("Polymorphic", a=["s", "t", "u"], name="r")229 self.assertEqual(dtypes.string, out.dtype)230 self.assertProtoEquals("""231 name: 'r' op: 'Polymorphic' input: 'r/a'232 attr { key: 'T' value { type: DT_STRING } }233 """, out.op.node_def)234 with self.assertRaises(TypeError) as cm:235 self._lib.apply_op("Polymorphic", a="s", T=dtypes.string)236 self.assertEqual(str(cm.exception),237 "Should not specify value for inferred attr 'T'.")238 def testPolymorphicOut(self):239 self._add_op("name: 'PolymorphicOut' "240 "output_arg { name: 'out' type_attr: 'T' } "241 "attr { name: 'T' type: 'type' }")242 out = self._lib.apply_op("PolymorphicOut", T=dtypes.int32, name="p")243 self.assertEqual(dtypes.int32, out.dtype)244 self.assertProtoEquals("""245 name: 'p' op: 'PolymorphicOut'246 attr { key: 'T' value { type: DT_INT32 } }247 """, out.op.node_def)248 out = self._lib.apply_op("PolymorphicOut", T=dtypes.bool, name="q")249 self.assertEqual(dtypes.bool, out.dtype)250 self.assertProtoEquals("""251 name: 'q' op: 'PolymorphicOut'252 attr { key: 'T' value { type: DT_BOOL } 
}253 """, out.op.node_def)254 with self.assertRaises(TypeError) as cm:255 self._lib.apply_op("PolymorphicOut")256 self.assertEqual(str(cm.exception),257 "No argument for attr T")258 with self.assertRaises(TypeError) as cm:259 self._lib.apply_op("PolymorphicOut", T=None)260 self.assertEqual(str(cm.exception),261 "Expected DataType for argument 'T' not None.")262 def testPolymorphicDefaultOut(self):263 self._add_op("name: 'PolymorphicDefaultOut' "264 "output_arg { name: 'out' type_attr: 'T' } "265 "attr { name: 'T' type: 'type' "266 " default_value { type: DT_STRING } }")267 out = self._lib.apply_op("PolymorphicDefaultOut", T=None, name="p")268 self.assertEqual(dtypes.string, out.dtype)269 self.assertProtoEquals("""270 name: 'p' op: 'PolymorphicDefaultOut'271 attr { key: 'T' value { type: DT_STRING } }272 """, out.op.node_def)273 out = self._lib.apply_op("PolymorphicDefaultOut", T=dtypes.bool, name="q")274 self.assertEqual(dtypes.bool, out.dtype)275 self.assertProtoEquals("""276 name: 'q' op: 'PolymorphicDefaultOut'277 attr { key: 'T' value { type: DT_BOOL } }278 """, out.op.node_def)279 def testBinary(self):280 self._add_op("name: 'Binary' "281 "input_arg { name: 'a' type_attr: 'T' } "282 "input_arg { name: 'b' type_attr: 'T' } "283 "output_arg { name: 'out' type_attr: 'T' } "284 "attr { name: 'T' type: 'type' }")285 out = self._lib.apply_op("Binary", a=8, b=9, name="b")286 self.assertEqual(dtypes.int32, out.dtype)287 self.assertProtoEquals("""288 name: 'b' op: 'Binary' input: 'b/a' input: 'b/b'289 attr { key: 'T' value { type: DT_INT32 } }290 """, out.op.node_def)291 out = self._lib.apply_op("Binary", a="left", b="right", name="c")292 self.assertEqual(dtypes.string, out.dtype)293 self.assertProtoEquals("""294 name: 'c' op: 'Binary' input: 'c/a' input: 'c/b'295 attr { key: 'T' value { type: DT_STRING } }296 """, out.op.node_def)297 with self.assertRaises(TypeError) as cm:298 self._lib.apply_op("Binary", a="left", b=12)299 self.assertEqual(str(cm.exception),300 
"Expected string passed to parameter 'b' of op 'Binary', "301 "got 12 of type 'int' instead.")302 with self.assertRaises(TypeError) as cm:303 self._lib.apply_op("Binary",304 a=self.Tensor(dtypes.string),305 b=self.Tensor(dtypes.int32))306 self.assertEqual(str(cm.exception),307 "Input 'b' of 'Binary' Op has type int32 "308 "that does not match type string of argument 'a'.")309 def testRestrict(self):310 self._add_op("name: 'Restrict' "311 "input_arg { name: 'a' type_attr: 'T' } "312 "output_arg { name: 'out' type_attr: 'T' } "313 "attr { name: 'T' type: 'type' allowed_values { list { "314 " type: DT_STRING type: DT_BOOL } } }")315 out = self._lib.apply_op("Restrict", a="foo", name="g")316 self.assertEqual(dtypes.string, out.dtype)317 self.assertProtoEquals("""318 name: 'g' op: 'Restrict' input: 'g/a'319 attr { key: 'T' value { type: DT_STRING } }320 """, out.op.node_def)321 out = self._lib.apply_op("Restrict", a=True, name="h")322 self.assertEqual(dtypes.bool, out.dtype)323 self.assertProtoEquals("""324 name: 'h' op: 'Restrict' input: 'h/a'325 attr { key: 'T' value { type: DT_BOOL } }326 """, out.op.node_def)327 with self.assertRaises(TypeError) as cm:328 self._lib.apply_op("Restrict", a=17)329 self.assertEqual(str(cm.exception),330 "Value passed to parameter 'a' has DataType int32 "331 "not in list of allowed values: string, bool")332 def testTypeList(self):333 self._add_op("name: 'TypeList' "334 "input_arg { name: 'a' type_list_attr: 'T' } "335 "attr { name: 'T' type: 'list(type)' }")336 op = self._lib.apply_op("TypeList", a=["foo"], name="z")337 self.assertProtoEquals("""338 name: 'z' op: 'TypeList' input: 'z/a_0'339 attr { key: 'T' value { list { type: DT_STRING } } }340 """, op.node_def)341 op = self._lib.apply_op("TypeList", a=[True, 12], name="y")342 self.assertProtoEquals("""343 name: 'y' op: 'TypeList' input: 'y/a_0' input: 'y/a_1'344 attr { key: 'T' value { list { type: DT_BOOL type: DT_INT32 } } }345 """, op.node_def)346 op = 
self._lib.apply_op("TypeList", a=[], name="empty")347 self.assertProtoEquals("""348 name: 'empty' op: 'TypeList' attr { key: 'T' value { list { } } }349 """, op.node_def)350 with self.assertRaises(TypeError) as cm:351 self._lib.apply_op("TypeList", a=17)352 self.assertStartsWith(str(cm.exception),353 "Expected list for 'a' "354 "argument to 'TypeList' Op, not ")355 with self.assertRaises(TypeError) as cm:356 self._lib.apply_op("TypeList", a=[self.Tensor(dtypes.int32), None])357 self.assertStartsWith(str(cm.exception),358 "Tensors in list passed to 'a' of 'TypeList' Op "359 "have types [int32, <NOT CONVERTIBLE TO TENSOR>]")360 def testTypeListTwice(self):361 self._add_op("name: 'TypeListTwice' "362 "input_arg { name: 'a' type_list_attr: 'T' } "363 "input_arg { name: 'b' type_list_attr: 'T' } "364 "attr { name: 'T' type: 'list(type)' }")365 op = self._lib.apply_op("TypeListTwice",366 a=["foo", True],367 b=["bar", False],368 name="z")369 self.assertProtoEquals("""370 name: 'z' op: 'TypeListTwice'371 input: 'z/a_0' input: 'z/a_1' input: 'z/b_0' input: 'z/b_1'372 attr { key: 'T' value { list { type: DT_STRING type: DT_BOOL } } }373 """, op.node_def)374 op = self._lib.apply_op("TypeListTwice", a=[], b=[], name="empty")375 self.assertProtoEquals("""376 name: 'empty' op: 'TypeListTwice' attr { key: 'T' value { list { } } }377 """, op.node_def)378 with self.assertRaises(TypeError) as cm:379 self._lib.apply_op("TypeListTwice", a=["foo", True], b=["bar", 6])380 self.assertEqual(str(cm.exception),381 "Input 'b' of 'TypeListTwice' Op has type list of "382 "string, int32 that does not match type list "383 "string, bool of argument 'a'.")384 def testOutTypeList(self):385 self._add_op("name: 'OutTypeList' "386 "output_arg { name: 'out' type_list_attr: 'T' } "387 "attr { name: 'T' type: 'list(type)' }")388 out, = self._lib.apply_op("OutTypeList", T=[dtypes.float32], name="x")389 self.assertEqual(dtypes.float32, out.dtype)390 self.assertProtoEquals("""391 name: 'x' op: 
'OutTypeList'392 attr { key: 'T' value { list { type: DT_FLOAT } } }393 """, out.op.node_def)394 out1, out2 = self._lib.apply_op("OutTypeList",395 T=[dtypes.int32, dtypes.bool],396 name="w")397 self.assertEqual(dtypes.int32, out1.dtype)398 self.assertEqual(dtypes.bool, out2.dtype)399 self.assertProtoEquals("""400 name: 'w' op: 'OutTypeList'401 attr { key: 'T' value { list { type: DT_INT32 type: DT_BOOL } } }402 """, out1.op.node_def)403 out = self._lib.apply_op("OutTypeList", T=[], name="empty")404 self.assertEqual([], out)405 with self.assertRaises(TypeError) as cm:406 self._lib.apply_op("OutTypeList", T=dtypes.int32)407 self.assertEqual(str(cm.exception), "Expected list for attr T")408 def testTypeListRestrict(self):409 self._add_op("name: 'TypeListRestrict' "410 "input_arg { name: 'a' type_list_attr: 'T' } "411 "attr { name: 'T' type: 'list(type)' allowed_values { list { "412 " type: DT_STRING type: DT_BOOL } } }")413 op = self._lib.apply_op("TypeListRestrict", a=["foo", False], name="v")414 self.assertProtoEquals("""415 name: 'v' op: 'TypeListRestrict' input: 'v/a_0' input: 'v/a_1'416 attr { key: 'T' value { list { type: DT_STRING type: DT_BOOL } } }417 """, op.node_def)418 with self.assertRaises(TypeError) as cm:419 self._lib.apply_op("TypeListRestrict", a=[True, 12])420 self.assertEqual(str(cm.exception),421 "Value passed to parameter 'a' has DataType int32 "422 "not in list of allowed values: string, bool")423 def testOutTypeListRestrict(self):424 self._add_op("name: 'OutTypeListRestrict' "425 "output_arg { name: 'out' type_list_attr: 't' } "426 "attr { name: 't' type: 'list(type)' allowed_values { list { "427 " type: DT_STRING type: DT_BOOL } } }")428 out1, out2 = self._lib.apply_op("OutTypeListRestrict",429 t=[dtypes.bool, dtypes.string],430 name="u")431 self.assertEqual(dtypes.bool, out1.dtype)432 self.assertEqual(dtypes.string, out2.dtype)433 self.assertProtoEquals("""434 name: 'u' op: 'OutTypeListRestrict'435 attr { key: 't' value { list { type: DT_BOOL 
type: DT_STRING } } }436 """, out1.op.node_def)437 with self.assertRaises(TypeError) as cm:438 self._lib.apply_op("OutTypeListRestrict", t=[dtypes.string, dtypes.int32])439 self.assertEqual(str(cm.exception),440 "Value passed to parameter 't' has DataType int32 "441 "not in list of allowed values: string, bool")442 def testAttr(self):443 self._add_op("name: 'Attr' attr { name: 'a' type: 'int' }")444 op = self._lib.apply_op("Attr", a=12, name="t")445 self.assertProtoEquals("""446 name: 't' op: 'Attr' attr { key: 'a' value { i: 12 } }447 """, op.node_def)448 op = self._lib.apply_op("Attr", a=tensor_shape.Dimension(13), name="u")449 self.assertProtoEquals("""450 name: 'u' op: 'Attr' attr { key: 'a' value { i: 13 } }451 """, op.node_def)452 with self.assertRaises(TypeError) as cm:453 self._lib.apply_op("Attr", a="bad")454 self.assertEqual(str(cm.exception),455 "Expected int for argument 'a' not 'bad'.")456 with self.assertRaises(TypeError) as cm:457 self._lib.apply_op("Attr", a=[12])458 self.assertEqual(str(cm.exception),459 "Expected int for argument 'a' not [12].")460 with self.assertRaises(TypeError) as cm:461 self._lib.apply_op("Attr", a=None)462 self.assertEqual(str(cm.exception),463 "Expected int for argument 'a' not None.")464 with self.assertRaises(TypeError) as cm:465 self._lib.apply_op("Attr")466 self.assertEqual(str(cm.exception), "No argument for attr a")467 def testAttrFloat(self):468 self._add_op("name: 'AttrFloat' attr { name: 'a' type: 'float' }")469 op = self._lib.apply_op("AttrFloat", a=1.2, name="t")470 self.assertProtoEquals("""471 name: 't' op: 'AttrFloat' attr { key: 'a' value { f: 1.2 } }472 """, op.node_def)473 op = self._lib.apply_op("AttrFloat", a=12, name="u")474 self.assertProtoEquals("""475 name: 'u' op: 'AttrFloat' attr { key: 'a' value { f: 12 } }476 """, op.node_def)477 with self.assertRaises(TypeError) as cm:478 self._lib.apply_op("AttrFloat", a="bad")479 self.assertEqual(str(cm.exception),480 "Expected float for argument 'a' not 
'bad'.")481 def testAttrBool(self):482 self._add_op("name: 'AttrBool' attr { name: 'a' type: 'bool' }")483 op = self._lib.apply_op("AttrBool", a=True, name="t")484 self.assertProtoEquals("""485 name: 't' op: 'AttrBool' attr { key: 'a' value { b: true } }486 """, op.node_def)487 op = self._lib.apply_op("AttrBool", a=False, name="u")488 self.assertProtoEquals("""489 name: 'u' op: 'AttrBool' attr { key: 'a' value { b: false } }490 """, op.node_def)491 with self.assertRaises(TypeError) as cm:492 self._lib.apply_op("AttrBool", a=0)493 self.assertEqual(str(cm.exception),494 "Expected bool for argument 'a' not 0.")495 with self.assertRaises(TypeError) as cm:496 self._lib.apply_op("AttrBool", a=1)497 self.assertEqual(str(cm.exception),498 "Expected bool for argument 'a' not 1.")499 with self.assertRaises(TypeError) as cm:500 self._lib.apply_op("AttrBool", a=[])501 self.assertEqual(str(cm.exception),502 "Expected bool for argument 'a' not [].")503 def testAttrBoolList(self):504 self._add_op("name: 'AttrBoolList' attr { name: 'a' type: 'list(bool)' }")505 op = self._lib.apply_op("AttrBoolList", a=[True, False, True], name="t")506 self.assertProtoEquals("""507 name: 't' op: 'AttrBoolList'508 attr { key: 'a' value { list { b: true b: false b:true } } }509 """, op.node_def)510 op = self._lib.apply_op("AttrBoolList", a=[], name="u")511 self.assertProtoEquals("""512 name: 'u' op: 'AttrBoolList' attr { key: 'a' value { list { } } }513 """, op.node_def)514 with self.assertRaises(TypeError) as cm:515 self._lib.apply_op("AttrBoolList", a=[0])516 self.assertEqual(str(cm.exception),517 "Expected bool for argument 'a' not 0.")518 def testAttrMin(self):519 self._add_op("name: 'AttrMin' attr { name: 'a' type: 'int' "520 "has_minimum: true minimum: 5 }")521 op = self._lib.apply_op("AttrMin", a=12, name="s")522 self.assertProtoEquals("""523 name: 's' op: 'AttrMin' attr { key: 'a' value { i: 12 } }524 """, op.node_def)525 with self.assertRaises(ValueError) as cm:526 
self._lib.apply_op("AttrMin", a=2)527 self.assertEqual(str(cm.exception),528 "Attr 'a' of 'AttrMin' Op passed 2 less than minimum 5.")529 def testAttrListMin(self):530 self._add_op("name: 'AttrListMin' attr { name: 'a' type: 'list(int)' "531 "has_minimum: true minimum: 2 }")532 op = self._lib.apply_op("AttrListMin", a=[1, 2], name="r")533 self.assertProtoEquals("""534 name: 'r' op: 'AttrListMin'535 attr { key: 'a' value { list { i: 1 i: 2 } } }536 """, op.node_def)537 with self.assertRaises(ValueError) as cm:538 self._lib.apply_op("AttrListMin", a=[17])539 self.assertEqual(str(cm.exception),540 "Attr 'a' of 'AttrListMin' Op "541 "passed list of length 1 less than minimum 2.")542 def testAttrEnum(self):543 self._add_op("name: 'AttrEnum' "544 "attr { name: 'a' type: 'string' "545 " allowed_values { list { s: 'apples' s: 'oranges' } } }")546 op = self._lib.apply_op("AttrEnum", a="oranges", name="e")547 self.assertProtoEquals("""548 name: 'e' op: 'AttrEnum' attr { key: 'a' value { s: 'oranges' } }549 """, op.node_def)550 with self.assertRaises(ValueError) as cm:551 self._lib.apply_op("AttrEnum", a="invalid")552 self.assertEqual(str(cm.exception),553 'Attr \'a\' of \'AttrEnum\' Op '554 'passed string \'invalid\' not in: '555 '"apples", "oranges".')556 def testAttrEnumList(self):557 self._add_op("name: 'AttrEnumList' "558 "attr { name: 'a' type: 'list(string)' "559 " allowed_values { list { s: 'apples' s: 'oranges' } } }")560 op = self._lib.apply_op("AttrEnumList", a=["oranges", "apples"], name="f")561 self.assertProtoEquals("""562 name: 'f' op: 'AttrEnumList'563 attr { key: 'a' value { list { s: 'oranges' s: 'apples' } } }564 """, op.node_def)565 with self.assertRaises(ValueError) as cm:566 self._lib.apply_op("AttrEnumList", a=["apples", "invalid", "oranges"])567 self.assertEqual(str(cm.exception),568 'Attr \'a\' of \'AttrEnumList\' Op '569 'passed string \'invalid\' not '570 'in: "apples", "oranges".')571 def testAttrShape(self):572 self._add_op("name: 'AttrShape' attr 
{ name: 'a' type: 'shape' }")573 op = self._lib.apply_op("AttrShape", a=[5], name="s1")574 self.assertProtoEquals("""575 name: 's1' op: 'AttrShape'576 attr { key: 'a' value { shape { dim { size: 5 } } } }577 """, op.node_def)578 op = self._lib.apply_op("AttrShape", a=(4, 3, 2), name="s2")579 self.assertProtoEquals("""580 name: 's2' op: 'AttrShape'581 attr { key: 'a' value {582 shape { dim { size: 4 } dim { size: 3 } dim { size: 2 } } } }583 """, op.node_def)584 op = self._lib.apply_op(585 "AttrShape", a=tensor_shape.TensorShape([3, 2]), name="s3")586 self.assertProtoEquals("""587 name: 's3' op: 'AttrShape'588 attr { key: 'a' value {589 shape { dim { size: 3 } dim { size: 2 } } } }590 """, op.node_def)591 op = self._lib.apply_op("AttrShape", a=[], name="s4")592 self.assertProtoEquals("""593 name: 's4' op: 'AttrShape' attr { key: 'a' value { shape { } } }594 """, op.node_def)595 shape = tensor_shape_pb2.TensorShapeProto()596 shape.dim.add().size = 6597 shape.dim.add().size = 3598 op = self._lib.apply_op("AttrShape", a=shape, name="s5")599 self.assertProtoEquals("""600 name: 's5' op: 'AttrShape'601 attr { key: 'a' value { shape { dim { size: 6 } dim { size: 3 } } } }602 """, op.node_def)603 # TODO(josh11b): Re-enable this test once we stop promoting scalars to shapes.604 # with self.assertRaises(TypeError) as cm:605 # self._lib.apply_op("AttrShape", a=5)606 # self.assertEqual(str(cm.exception),607 # "Don't know how to convert 5 to a TensorShapeProto for "608 # "argument 'a'")609 with self.assertRaises(TypeError):610 self._lib.apply_op("AttrShape", a="ABC")611 def testAttrShapeList(self):612 self._add_op("name: 'AttrShapeList' attr { name: 'a' type: 'list(shape)' }")613 op = self._lib.apply_op("AttrShapeList", a=[[3, 2], [6, 5, 4]], name="sl")614 self.assertProtoEquals("""615 name: 'sl' op: 'AttrShapeList'616 attr { key: 'a' value { list {617 shape { dim { size: 3 } dim { size: 2 } }618 shape { dim { size: 6 } dim { size: 5 } dim { size: 4 } } } } }619 """, 
op.node_def)620 op = self._lib.apply_op("AttrShapeList", a=[], name="esl")621 self.assertProtoEquals("""622 name: 'esl' op: 'AttrShapeList' attr { key: 'a' value { list { } } }623 """, op.node_def)624 def testAttrPartialShape(self):625 self._add_op(626 "name: 'AttrPartialShape' attr { name: 'a' type: 'shape' }")627 op = self._lib.apply_op("AttrPartialShape", a=[5], name="s1")628 self.assertProtoEquals("""629 name: 's1' op: 'AttrPartialShape'630 attr { key: 'a' value { shape { dim { size: 5 } } } }631 """, op.node_def)632 op = self._lib.apply_op("AttrPartialShape", a=(4, None, 2), name="s2")633 self.assertProtoEquals("""634 name: 's2' op: 'AttrPartialShape'635 attr { key: 'a' value {636 shape { dim { size: 4 } dim { size: -1 } dim { size: 2 } } } }637 """, op.node_def)638 op = self._lib.apply_op(639 "AttrPartialShape", a=tensor_shape.TensorShape([3, None]), name="s3")640 self.assertProtoEquals("""641 name: 's3' op: 'AttrPartialShape'642 attr { key: 'a' value {643 shape { dim { size: 3 } dim { size: -1 } } } }644 """, op.node_def)645 op = self._lib.apply_op("AttrPartialShape", a=[], name="s4")646 self.assertProtoEquals("""647 name: 's4' op: 'AttrPartialShape'648 attr { key: 'a' value { shape { } } }649 """, op.node_def)650 shape = tensor_shape_pb2.TensorShapeProto()651 shape.dim.add().size = -1652 shape.dim.add().size = 3653 op = self._lib.apply_op("AttrPartialShape", a=shape, name="s5")654 self.assertProtoEquals("""655 name: 's5' op: 'AttrPartialShape'656 attr { key: 'a' value {657 shape { dim { size: -1 } dim { size: 3 } } } }658 """, op.node_def)659 # TODO(ebrevdo): Re-enable once we stop promoting scalars to shapes.660 # with self.assertRaises(TypeError) as cm:661 # self._lib.apply_op("AttrPartialShape", a=5)662 # self.assertEqual(str(cm.exception),663 # "Don't know how to convert 5 to a TensorShapeProto for "664 # "argument 'a'")665 with self.assertRaises(TypeError):666 self._lib.apply_op("AttrPartialShape", a="ABC")667 def testAttrPartialShapeList(self):668 
self._add_op("""669 name: 'AttrPartialShapeList'670 attr { name: 'a' type: 'list(shape)' }671 """)672 op = self._lib.apply_op(673 "AttrPartialShapeList", a=[[3, 2], [6, None, 4]], name="sl")674 self.assertProtoEquals("""675 name: 'sl' op: 'AttrPartialShapeList'676 attr { key: 'a' value { list {677 shape { dim { size: 3 } dim { size: 2 } }678 shape { dim { size: 6 } dim { size: -1 } dim { size: 4 } } } } }679 """, op.node_def)680 op = self._lib.apply_op("AttrPartialShapeList", a=[], name="esl")681 self.assertProtoEquals("""682 name: 'esl' op: 'AttrPartialShapeList' attr {683 key: 'a' value { list { } } }684 """, op.node_def)685 def testAttrDefault(self):686 self._add_op("name: 'AttrDefault' "687 "attr { name: 'a' type: 'string' "688 " default_value { s: 'banana' } }")689 op = self._lib.apply_op("AttrDefault", a=None, name="d")690 self.assertProtoEquals("""691 name: 'd' op: 'AttrDefault' attr { key: 'a' value { s: 'banana' } }692 """, op.node_def)693 op = self._lib.apply_op("AttrDefault", a="kiwi", name="c")694 self.assertProtoEquals("""695 name: 'c' op: 'AttrDefault' attr { key: 'a' value { s: 'kiwi' } }696 """, op.node_def)697 def testAttrListDefault(self):698 self._add_op("name: 'AttrListDefault' "699 "attr { name: 'a' type: 'list(int)' "700 " default_value { list { i: 5 i: 15 } } }")701 op = self._lib.apply_op("AttrListDefault", a=None, name="b")702 self.assertProtoEquals("""703 name: 'b' op: 'AttrListDefault'704 attr { key: 'a' value { list { i: 5 i: 15 } } }705 """, op.node_def)706 op = self._lib.apply_op("AttrListDefault", a=[3], name="a")707 self.assertProtoEquals("""708 name: 'a' op: 'AttrListDefault'709 attr { key: 'a' value { list { i: 3 } } }710 """, op.node_def)711 op = self._lib.apply_op("AttrListDefault", a=[], name="empty")712 self.assertProtoEquals("""713 name: 'empty' op: 'AttrListDefault'714 attr { key: 'a' value { list { } } }715 """, op.node_def)716 def testAttrEmptyListDefault(self):717 self._add_op("name: 'AttrEmptyListDefault' "718 "attr { 
name: 'a' type: 'list(float)' "719 " default_value { list { } } }")720 op = self._lib.apply_op("AttrEmptyListDefault", a=None, name="b")721 self.assertProtoEquals("""722 name: 'b' op: 'AttrEmptyListDefault'723 attr { key: 'a' value { list { } } }724 """, op.node_def)725 op = self._lib.apply_op("AttrEmptyListDefault", a=[3], name="a")726 self.assertProtoEquals("""727 name: 'a' op: 'AttrEmptyListDefault'728 attr { key: 'a' value { list { f: 3 } } }729 """, op.node_def)730 op = self._lib.apply_op("AttrEmptyListDefault", a=[], name="empty")731 self.assertProtoEquals("""732 name: 'empty' op: 'AttrEmptyListDefault'733 attr { key: 'a' value { list { } } }734 """, op.node_def)735 def testReservedAttr(self):736 self._add_op("name: 'ReservedAttr' "737 "attr { name: 'range' type: 'int' } ")738 op = self._lib.apply_op("ReservedAttr", range_=7, name="x")739 self.assertProtoEquals("""740 name: 'x' op: 'ReservedAttr' attr { key: 'range' value { i: 7 } }741 """, op.node_def)742 def testDefaultAttrType(self):743 self._add_op("name: 'AttrTypeDefault' "744 "input_arg { name: 'a' type_attr: 'T' } "745 "attr { name: 'T' type: 'type' "746 " default_value { type: DT_INT32 } }")747 # Give an input whose type has no obvious output type.748 op = self._lib.apply_op("AttrTypeDefault", a=[], name="n")749 self.assertProtoEquals("""750 name: 'n' op: 'AttrTypeDefault' input: 'n/a'751 attr { key: 'T' value { type: DT_INT32 } }752 """, op.node_def)753 # Give an input whose type can be inferred as different754 # than the default.755 op = self._lib.apply_op("AttrTypeDefault", a=[1.0], name="f")756 self.assertProtoEquals("""757 name: 'f' op: 'AttrTypeDefault' input: 'f/a'758 attr { key: 'T' value { type: DT_FLOAT } }759 """, op.node_def)760 def testDefaultListAttrType(self):761 self._add_op("name: 'AttrListTypeDefault' "762 "input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "763 "input_arg { name: 'b' type_attr: 'T' number_attr: 'N' } "764 "attr { name: 'T' type: 'type' "765 " default_value { 
type: DT_INT32 } }"766 "attr { name: 'N' type: 'int' }")767 # Give an input whose type can be inferred as different768 # than the default.769 op = self._lib.apply_op("AttrListTypeDefault", a=[1.0], b=[2.0], name="n")770 self.assertProtoEquals("""771 name: 'n' op: 'AttrListTypeDefault' input: 'n/a_0' input: 'n/b_0'772 attr { key: 'T' value { type: DT_FLOAT } }773 attr { key: 'N' value { i: 1 } }774 """, op.node_def)775 def testNIntsIn(self):776 self._add_op("name: 'NIntsIn' "777 "input_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "778 "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")779 op = self._lib.apply_op("NIntsIn", a=[1, 2], name="n")780 self.assertProtoEquals("""781 name: 'n' op: 'NIntsIn' input: 'n/a_0' input: 'n/a_1'782 attr { key: 'N' value { i: 2 } }783 """, op.node_def)784 op = self._lib.apply_op("NIntsIn", a=[5, 4, 3, 2, 1], name="o")785 self.assertProtoEquals("""786 name: 'o' op: 'NIntsIn'787 input: 'o/a_0' input: 'o/a_1' input: 'o/a_2' input: 'o/a_3' input: 'o/a_4'788 attr { key: 'N' value { i: 5 } }789 """, op.node_def)790 with self.assertRaises(TypeError) as cm:791 self._lib.apply_op("NIntsIn", a=["foo", "bar"])792 self.assertEqual(str(cm.exception),793 "Tensors in list passed to 'a' of 'NIntsIn' Op have types "794 "[string, string] that do not match expected type int32.")795 with self.assertRaises(TypeError) as cm:796 self._lib.apply_op("NIntsIn",797 a=[self.Tensor(dtypes.string),798 self.Tensor(dtypes.string)])799 self.assertEqual(str(cm.exception),800 "Tensors in list passed to 'a' of 'NIntsIn' Op have "801 "types [string, string] that do not match expected type "802 "int32.")803 with self.assertRaises(ValueError) as cm:804 self._lib.apply_op("NIntsIn", a=[99])805 self.assertEqual(str(cm.exception),806 "List argument 'a' to 'NIntsIn' Op "807 "with length 1 shorter than "808 "minimum length 2.")809 with self.assertRaises(TypeError) as cm:810 self._lib.apply_op("NIntsIn", a=[38, "bar"])811 self.assertEqual(str(cm.exception),812 
"Tensors in list passed to 'a' of 'NIntsIn' Op have types "813 "[int32, string] that do not match expected type int32.")814 with self.assertRaises(TypeError) as cm:815 self._lib.apply_op("NIntsIn",816 a=[self.Tensor(dtypes.int32),817 self.Tensor(dtypes.string)])818 self.assertEqual(str(cm.exception),819 "Tensors in list passed to 'a' of 'NIntsIn' Op "820 "have types [int32, string] that do not match expected "821 "type int32.")822 with self.assertRaises(TypeError) as cm:823 self._lib.apply_op("NIntsIn", a=17)824 self.assertStartsWith(str(cm.exception),825 "Expected list for 'a' argument "826 "to 'NIntsIn' Op, not ")827 def testNPolymorphicIn(self):828 self._add_op("name: 'NPolymorphicIn' "829 "input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "830 "attr { name: 'T' type: 'type' } "831 "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")832 op = self._lib.apply_op("NPolymorphicIn", a=[1, 2], name="n")833 self.assertProtoEquals("""834 name: 'n' op: 'NPolymorphicIn' input: 'n/a_0' input: 'n/a_1'835 attr { key: 'T' value { type: DT_INT32 } }836 attr { key: 'N' value { i: 2 } }837 """, op.node_def)838 op = self._lib.apply_op("NPolymorphicIn", a=[5, 4, 3, 2, 1], name="o")839 self.assertProtoEquals("""840 name: 'o' op: 'NPolymorphicIn'841 input: 'o/a_0' input: 'o/a_1' input: 'o/a_2' input: 'o/a_3' input: 'o/a_4'842 attr { key: 'T' value { type: DT_INT32 } }843 attr { key: 'N' value { i: 5 } }844 """, op.node_def)845 op = self._lib.apply_op("NPolymorphicIn", a=["foo", "bar"], name="p")846 self.assertProtoEquals("""847 name: 'p' op: 'NPolymorphicIn' input: 'p/a_0' input: 'p/a_1'848 attr { key: 'T' value { type: DT_STRING } }849 attr { key: 'N' value { i: 2 } }850 """, op.node_def)851 op = self._lib.apply_op("NPolymorphicIn",852 a=[1, self.Tensor(dtypes.float32, name="x")],853 name="q")854 self.assertProtoEquals("""855 name: 'q' op: 'NPolymorphicIn' input: 'q/a_0' input: 'x'856 attr { key: 'T' value { type: DT_FLOAT } }857 attr { key: 'N' value { i: 2 } }858 
""", op.node_def)859 op = self._lib.apply_op("NPolymorphicIn",860 a=[self.Tensor(dtypes.float32, name="y"),861 self.Tensor(dtypes.float32_ref, name="z")],862 name="r")863 self.assertProtoEquals("""864 name: 'r' op: 'NPolymorphicIn' input: 'y' input: 'z'865 attr { key: 'T' value { type: DT_FLOAT } }866 attr { key: 'N' value { i: 2 } }867 """, op.node_def)868 with self.assertRaises(ValueError) as cm:869 self._lib.apply_op("NPolymorphicIn", a=[99])870 self.assertEqual(str(cm.exception),871 "List argument 'a' to 'NPolymorphicIn' Op with length 1 "872 "shorter than minimum length 2.")873 with self.assertRaises(TypeError) as cm:874 self._lib.apply_op("NPolymorphicIn", a=[38, "bar"])875 self.assertEqual(str(cm.exception),876 "Tensors in list passed to 'a' of 'NPolymorphicIn' Op "877 "have types [int32, string] that don't all match.")878 with self.assertRaises(TypeError) as cm:879 self._lib.apply_op("NPolymorphicIn", a=[38, self.Tensor(dtypes.string)])880 self.assertEqual(str(cm.exception),881 "Tensors in list passed to 'a' of 'NPolymorphicIn' Op "882 "have types [int32, string] that don't all match.")883 with self.assertRaises(TypeError) as cm:884 self._lib.apply_op("NPolymorphicIn", a=[38, None])885 self.assertEqual(str(cm.exception),886 "Tensors in list passed to 'a' of 'NPolymorphicIn' Op "887 "have types [int32, <NOT CONVERTIBLE TO TENSOR>] that "888 "don't all match.")889 with self.assertRaises(TypeError) as cm:890 self._lib.apply_op("NPolymorphicIn",891 a=["abcd", self.Tensor(dtypes.int32)])892 self.assertEqual(str(cm.exception),893 "Tensors in list passed to 'a' of 'NPolymorphicIn' Op "894 "have types [string, int32] that don't all match.")895 with self.assertRaises(TypeError) as cm:896 self._lib.apply_op("NPolymorphicIn", a=17)897 self.assertStartsWith(str(cm.exception),898 "Expected list for 'a' argument "899 "to 'NPolymorphicIn' Op, not ")900 def testNPolymorphicRestrictIn(self):901 self._add_op("name: 'NPolymorphicRestrictIn' "902 "input_arg { name: 'a' 
type_attr: 'T' number_attr: 'N' } "903 "attr { name: 'T' type: 'type' allowed_values { "904 " list { type: DT_STRING type: DT_BOOL } } } "905 "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")906 op = self._lib.apply_op("NPolymorphicRestrictIn", a=["foo", "bar"],907 name="p")908 self.assertProtoEquals("""909 name: 'p' op: 'NPolymorphicRestrictIn' input: 'p/a_0' input: 'p/a_1'910 attr { key: 'T' value { type: DT_STRING } }911 attr { key: 'N' value { i: 2 } }912 """, op.node_def)913 op = self._lib.apply_op("NPolymorphicRestrictIn",914 a=[False, True, False],915 name="b")916 self.assertProtoEquals("""917 name: 'b' op: 'NPolymorphicRestrictIn'918 input: 'b/a_0' input: 'b/a_1' input: 'b/a_2'919 attr { key: 'T' value { type: DT_BOOL } }920 attr { key: 'N' value { i: 3 } }921 """, op.node_def)922 with self.assertRaises(TypeError) as cm:923 self._lib.apply_op("NPolymorphicRestrictIn", a=[1, 2])924 self.assertEqual(str(cm.exception),925 "Value passed to parameter 'a' has DataType int32 not in "926 "list of allowed values: string, bool")927 def testNInTwice(self):928 self._add_op("name: 'NInTwice' "929 "input_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "930 "input_arg { name: 'b' type: DT_STRING number_attr: 'N' } "931 "attr { name: 'N' type: 'int' has_minimum: true minimum: 0 }")932 op = self._lib.apply_op("NInTwice", a=[1, 2], b=["one", "two"], name="n")933 self.assertProtoEquals("""934 name: 'n' op: 'NInTwice'935 input: 'n/a_0' input: 'n/a_1' input: 'n/b_0' input: 'n/b_1'936 attr { key: 'N' value { i: 2 } }937 """, op.node_def)938 op = self._lib.apply_op("NInTwice", a=[], b=[], name="o")939 self.assertProtoEquals("""940 name: 'o' op: 'NInTwice' attr { key: 'N' value { i: 0 } }941 """, op.node_def)942 with self.assertRaises(ValueError) as cm:943 self._lib.apply_op("NInTwice", a=[1, 2, 3], b=["too short"])944 self.assertEqual(str(cm.exception),945 "List argument 'b' to 'NInTwice' Op "946 "with length 1 must match "947 "length 3 of argument 'a'.")948 def 
testNInPolymorphicTwice(self):949 self._add_op("name: 'NInPolymorphicTwice' "950 "input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "951 "input_arg { name: 'b' type_attr: 'T' number_attr: 'N' } "952 "attr { name: 'T' type: 'type' } "953 "attr { name: 'N' type: 'int' has_minimum: true minimum: 0 }")954 op = self._lib.apply_op("NInPolymorphicTwice", a=[1, 2], b=[3, 4], name="n")955 self.assertProtoEquals("""956 name: 'n' op: 'NInPolymorphicTwice'957 input: 'n/a_0' input: 'n/a_1' input: 'n/b_0' input: 'n/b_1'958 attr { key: 'T' value { type: DT_INT32 } }959 attr { key: 'N' value { i: 2 } }960 """, op.node_def)961 with self.assertRaises(ValueError) as cm:962 self._lib.apply_op("NInPolymorphicTwice", a=[1, 2, 3], b=[5])963 self.assertEqual(str(cm.exception),964 "List argument 'b' to 'NInPolymorphicTwice' Op "965 "with length 1 "966 "must match length 3 of argument 'a'.")967 with self.assertRaises(TypeError) as cm:968 self._lib.apply_op("NInPolymorphicTwice", a=[1, 2], b=["one", "two"])969 self.assertEqual(str(cm.exception),970 "Tensors in list passed to 'b' of 'NInPolymorphicTwice' "971 "Op have types [string, string] that do not match type "972 "int32 inferred from earlier arguments.")973 with self.assertRaises(TypeError) as cm:974 self._lib.apply_op("NInPolymorphicTwice",975 a=[self.Tensor(dtypes.int32)],976 b=[self.Tensor(dtypes.string)])977 self.assertEqual(str(cm.exception),978 "Tensors in list passed to 'b' of "979 "'NInPolymorphicTwice' Op have types [string] that do not "980 "match type int32 inferred from earlier arguments.")981 def testNInTwoTypeVariables(self):982 self._add_op("name: 'NInTwoTypeVariables' "983 "input_arg { name: 'a' type_attr: 'S' number_attr: 'N' } "984 "input_arg { name: 'b' type_attr: 'T' number_attr: 'N' } "985 "attr { name: 'S' type: 'type' } "986 "attr { name: 'T' type: 'type' } "987 "attr { name: 'N' type: 'int' has_minimum: true minimum: 0 }")988 op = self._lib.apply_op("NInTwoTypeVariables",989 a=[1, 2],990 b=[True, False],991 
name="n")992 self.assertProtoEquals("""993 name: 'n' op: 'NInTwoTypeVariables'994 input: 'n/a_0' input: 'n/a_1' input: 'n/b_0' input: 'n/b_1'995 attr { key: 'S' value { type: DT_INT32 } }996 attr { key: 'T' value { type: DT_BOOL } }997 attr { key: 'N' value { i: 2 } }998 """, op.node_def)999 op = self._lib.apply_op("NInTwoTypeVariables", a=[1, 2], b=[3, 4], name="o")1000 self.assertProtoEquals("""1001 name: 'o' op: 'NInTwoTypeVariables'1002 input: 'o/a_0' input: 'o/a_1' input: 'o/b_0' input: 'o/b_1'1003 attr { key: 'S' value { type: DT_INT32 } }1004 attr { key: 'T' value { type: DT_INT32 } }1005 attr { key: 'N' value { i: 2 } }1006 """, op.node_def)1007 op = self._lib.apply_op("NInTwoTypeVariables",1008 a=[self.Tensor(dtypes.int32, name="q")],1009 b=[self.Tensor(dtypes.string, name="r")],1010 name="p")1011 self.assertProtoEquals("""1012 name: 'p' op: 'NInTwoTypeVariables' input: 'q' input: 'r'1013 attr { key: 'S' value { type: DT_INT32 } }1014 attr { key: 'T' value { type: DT_STRING } }1015 attr { key: 'N' value { i: 1 } }1016 """, op.node_def)1017 with self.assertRaises(ValueError) as cm:1018 self._lib.apply_op("NInTwoTypeVariables", a=[1, 2, 3], b=["5"])1019 self.assertEqual(str(cm.exception),1020 "List argument 'b' to 'NInTwoTypeVariables' Op "1021 "with length 1 "1022 "must match length 3 of argument 'a'.")1023 def testInPolymorphicTwice(self):1024 self._add_op("name: 'InPolymorphicTwice' "1025 "input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "1026 "input_arg { name: 'b' type_attr: 'T' number_attr: 'M' } "1027 "attr { name: 'T' type: 'type' } "1028 "attr { name: 'N' type: 'int' has_minimum: true minimum: 0 } "1029 "attr { name: 'M' type: 'int' has_minimum: true minimum: 0 } ")1030 op = self._lib.apply_op("InPolymorphicTwice", a=[8], b=[3, 4, 5], name="n")1031 self.assertProtoEquals("""1032 name: 'n' op: 'InPolymorphicTwice'1033 input: 'n/a_0' input: 'n/b_0' input: 'n/b_1' input: 'n/b_2'1034 attr { key: 'T' value { type: DT_INT32 } }1035 attr { key: 'N' 
value { i: 1 } }1036 attr { key: 'M' value { i: 3 } }1037 """, op.node_def)1038 op = self._lib.apply_op("InPolymorphicTwice", a=[8], b=[], name="o")1039 self.assertProtoEquals("""1040 name: 'o' op: 'InPolymorphicTwice' input: 'o/a_0'1041 attr { key: 'T' value { type: DT_INT32 } }1042 attr { key: 'N' value { i: 1 } }1043 attr { key: 'M' value { i: 0 } }1044 """, op.node_def)1045 with self.assertRaises(TypeError) as cm:1046 self._lib.apply_op("InPolymorphicTwice", a=[], b=[3, 4, 5])1047 self.assertEqual(str(cm.exception),1048 "Don't know how to infer type variable from empty input "1049 "list passed to input 'a' of 'InPolymorphicTwice' Op.")1050 with self.assertRaises(TypeError) as cm:1051 self._lib.apply_op("InPolymorphicTwice", a=[1, 2], b=["one", "two"])1052 self.assertEqual(str(cm.exception),1053 "Tensors in list passed to 'b' of 'InPolymorphicTwice' Op "1054 "have types [string, string] that do not match type int32 "1055 "inferred from earlier arguments.")1056 with self.assertRaises(TypeError) as cm:1057 self._lib.apply_op("InPolymorphicTwice",1058 a=[self.Tensor(dtypes.int32)],1059 b=[self.Tensor(dtypes.string)])1060 self.assertEqual(str(cm.exception),1061 "Tensors in list passed to 'b' of 'InPolymorphicTwice' "1062 "Op have types [string] that do not match type int32 "1063 "inferred from earlier arguments.")1064 def testNIntsOut(self):1065 self._add_op("name: 'NIntsOut' "1066 "output_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "1067 "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")1068 out1, out2 = self._lib.apply_op("NIntsOut", N=2, name="n")1069 self.assertEqual(dtypes.int32, out1.dtype)1070 self.assertEqual(dtypes.int32, out2.dtype)1071 self.assertProtoEquals("""1072 name: 'n' op: 'NIntsOut' attr { key: 'N' value { i: 2 } }1073 """, out1.op.node_def)1074 out1, out2, out3, out4, out5 = self._lib.apply_op(1075 "NIntsOut", N=5, name="o")1076 self.assertEqual(dtypes.int32, out1.dtype)1077 self.assertEqual(dtypes.int32, out2.dtype)1078 
self.assertEqual(dtypes.int32, out3.dtype)1079 self.assertEqual(dtypes.int32, out4.dtype)1080 self.assertEqual(dtypes.int32, out5.dtype)1081 self.assertProtoEquals("""1082 name: 'o' op: 'NIntsOut' attr { key: 'N' value { i: 5 } }1083 """, out5.op.node_def)1084 with self.assertRaises(ValueError) as cm:1085 self._lib.apply_op("NIntsOut", N=1)1086 self.assertEqual(str(cm.exception),1087 "Attr 'N' of 'NIntsOut' Op passed 1 less than minimum 2.")1088 with self.assertRaises(TypeError) as cm:1089 self._lib.apply_op("NIntsOut", N=[3])1090 self.assertEqual(str(cm.exception),1091 "Expected int for argument 'N' not [3].")1092 def testNIntsOutDefault(self):1093 self._add_op("name: 'NIntsOutDefault' "1094 "output_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "1095 "attr { name: 'N' type: 'int' has_minimum: true minimum: 2"1096 " default_value { i:3 } }")1097 out1, out2, out3 = self._lib.apply_op(1098 "NIntsOutDefault", N=None, name="z")1099 self.assertEqual(dtypes.int32, out1.dtype)1100 self.assertEqual(dtypes.int32, out2.dtype)1101 self.assertEqual(dtypes.int32, out3.dtype)1102 self.assertProtoEquals("""1103 name: 'z' op: 'NIntsOutDefault' attr { key: 'N' value { i: 3 } }1104 """, out1.op.node_def)1105 out1, out2 = self._lib.apply_op("NIntsOutDefault", N=2, name="y")1106 self.assertEqual(dtypes.int32, out1.dtype)1107 self.assertEqual(dtypes.int32, out2.dtype)1108 self.assertProtoEquals("""1109 name: 'y' op: 'NIntsOutDefault' attr { key: 'N' value { i: 2 } }1110 """, out2.op.node_def)1111 def testNPolymorphicOut(self):1112 self._add_op("name: 'NPolymorphicOut' "1113 "output_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "1114 "attr { name: 'T' type: 'type' } "1115 "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")1116 out1, out2 = self._lib.apply_op("NPolymorphicOut",1117 N=2,1118 T=dtypes.int32,1119 name="n")1120 self.assertEqual(dtypes.int32, out1.dtype)1121 self.assertEqual(dtypes.int32, out2.dtype)1122 self.assertProtoEquals("""1123 name: 'n' op: 
'NPolymorphicOut'1124 attr { key: 'T' value { type: DT_INT32 } }1125 attr { key: 'N' value { i: 2 } }1126 """, out1.op.node_def)1127 out1, out2, out3 = self._lib.apply_op(1128 "NPolymorphicOut", T=dtypes.string, N=3, name="o")1129 self.assertEqual(dtypes.string, out1.dtype)1130 self.assertEqual(dtypes.string, out2.dtype)1131 self.assertEqual(dtypes.string, out3.dtype)1132 self.assertProtoEquals("""1133 name: 'o' op: 'NPolymorphicOut'1134 attr { key: 'T' value { type: DT_STRING } }1135 attr { key: 'N' value { i: 3 } }1136 """, out3.op.node_def)1137 with self.assertRaises(ValueError) as cm:1138 self._lib.apply_op("NPolymorphicOut", N=1, T=dtypes.string)1139 self.assertEqual(str(cm.exception),1140 "Attr 'N' of 'NPolymorphicOut' Op "1141 "passed 1 less than minimum 2.")1142 with self.assertRaises(TypeError) as cm:1143 self._lib.apply_op("NPolymorphicOut", N=3, T=[dtypes.string])1144 self.assertEqual(1145 str(cm.exception),1146 "Expected DataType for argument 'T' not [tf.string].")1147 def testNPolymorphicOutDefault(self):1148 self._add_op("name: 'NPolymorphicOutDefault' "1149 "output_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "1150 "attr { name: 'T' type: 'type'"1151 " default_value { type: DT_BOOL } } "1152 "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 "1153 " default_value { i: 2 } }")1154 out1, out2 = self._lib.apply_op(1155 "NPolymorphicOutDefault", N=None, T=None, name="r")1156 self.assertEqual(dtypes.bool, out1.dtype)1157 self.assertEqual(dtypes.bool, out2.dtype)1158 self.assertProtoEquals("""1159 name: 'r' op: 'NPolymorphicOutDefault'1160 attr { key: 'T' value { type: DT_BOOL } }1161 attr { key: 'N' value { i: 2 } }1162 """, out1.op.node_def)1163 out1, out2, out3 = self._lib.apply_op(1164 "NPolymorphicOutDefault", N=3, T=None, name="s")1165 self.assertEqual(dtypes.bool, out1.dtype)1166 self.assertEqual(dtypes.bool, out2.dtype)1167 self.assertEqual(dtypes.bool, out3.dtype)1168 self.assertProtoEquals("""1169 name: 's' op: 
'NPolymorphicOutDefault'1170 attr { key: 'T' value { type: DT_BOOL } }1171 attr { key: 'N' value { i: 3 } }1172 """, out1.op.node_def)1173 out1, out2 = self._lib.apply_op(1174 "NPolymorphicOutDefault", N=None, T=dtypes.int32, name="t")1175 self.assertEqual(dtypes.int32, out1.dtype)1176 self.assertEqual(dtypes.int32, out2.dtype)1177 self.assertProtoEquals("""1178 name: 't' op: 'NPolymorphicOutDefault'1179 attr { key: 'T' value { type: DT_INT32 } }1180 attr { key: 'N' value { i: 2 } }1181 """, out1.op.node_def)1182 out1, out2, out3 = self._lib.apply_op(1183 "NPolymorphicOutDefault", N=3, T=dtypes.int32, name="u")1184 self.assertEqual(dtypes.int32, out1.dtype)1185 self.assertEqual(dtypes.int32, out2.dtype)1186 self.assertEqual(dtypes.int32, out3.dtype)1187 self.assertProtoEquals("""1188 name: 'u' op: 'NPolymorphicOutDefault'1189 attr { key: 'T' value { type: DT_INT32 } }1190 attr { key: 'N' value { i: 3 } }1191 """, out1.op.node_def)1192 def testNPolymorphicRestrictOut(self):1193 self._add_op("name: 'NPolymorphicRestrictOut' "1194 "output_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "1195 "attr { name: 'T' type: 'type' allowed_values { "1196 " list { type: DT_STRING type: DT_BOOL } } } "1197 "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")1198 out1, out2, out3 = self._lib.apply_op(1199 "NPolymorphicRestrictOut", N=3, T=dtypes.bool, name="u")1200 self.assertEqual(dtypes.bool, out1.dtype)1201 self.assertEqual(dtypes.bool, out2.dtype)1202 self.assertEqual(dtypes.bool, out3.dtype)1203 self.assertProtoEquals("""1204 name: 'u' op: 'NPolymorphicRestrictOut'1205 attr { key: 'T' value { type: DT_BOOL } }1206 attr { key: 'N' value { i: 3 } }1207 """, out1.op.node_def)1208 with self.assertRaises(TypeError) as cm:1209 self._lib.apply_op("NPolymorphicRestrictOut", N=2, T=dtypes.int32)1210 self.assertEqual(str(cm.exception),1211 "Value passed to parameter 'T' has DataType int32 "1212 "not in list of allowed values: string, bool")1213 def testRef(self):1214 
self._add_op("name: 'RefIn' "1215 "input_arg { name: 'a' type_attr: 'T' is_ref: true } "1216 "attr { name: 'T' type: 'type' } ")1217 self._add_op("name: 'TwoRefsIn' "1218 "input_arg { name: 'a' type_attr: 'T' is_ref: true } "1219 "input_arg { name: 'b' type_attr: 'T' is_ref: true } "1220 "attr { name: 'T' type: 'type' } ")1221 self._add_op("name: 'RefOut' "1222 "output_arg { name: 'a' type_attr: 'T' is_ref: true } "1223 "attr { name: 'T' type: 'type' } ")1224 out = self._lib.apply_op("RefOut", T=dtypes.bool, name="o")1225 self.assertEqual(dtypes.bool_ref, out.dtype)1226 self.assertProtoEquals("""1227 name: 'o' op: 'RefOut'1228 attr { key: 'T' value { type: DT_BOOL } }1229 """, out.op.node_def)1230 op = self._lib.apply_op("RefIn", a=out, name="i")1231 self.assertProtoEquals("""1232 name: 'i' op: 'RefIn' input: 'o'1233 attr { key: 'T' value { type: DT_BOOL } }1234 attr { key: "_class" value { list { s: "loc:@o" } } }1235 """, op.node_def)1236 # Can pass ref to non-ref input.1237 out = self._lib.apply_op("RefOut", T=dtypes.int32, name="r")1238 out = self._lib.apply_op("Simple", a=out, name="s")1239 self.assertProtoEquals("""1240 name: 's' op: 'Simple' input: 'r'1241 """, out.op.node_def)1242 # Can't pass non-ref to ref input.1243 with self.assertRaises(TypeError) as cm:1244 self._lib.apply_op("RefIn", a=2)1245 self.assertEqual(str(cm.exception),1246 "'RefIn' Op requires that input 'a' be a mutable tensor " +1247 "(e.g.: a tf.Variable)")1248 input_a = self._lib.apply_op("RefOut", T=dtypes.int32, name="t")1249 input_b = self._lib.apply_op("RefOut", T=dtypes.int32, name="u")1250 op = self._lib.apply_op("TwoRefsIn", a=input_a, b=input_b, name="v")1251 # NOTE(mrry): The order of colocation constraints is an implementation1252 # detail.1253 self.assertProtoEquals("""1254 name: 'v' op: 'TwoRefsIn' input: 't' input: 'u'1255 attr { key: 'T' value { type: DT_INT32 } }1256 attr { key: "_class" value { list { s: "loc:@t" s: "loc:@u" } } }1257 """, op.node_def)1258 def 
testSpecifyDevice(self):1259 with self._g.device("/job:ADevice"):1260 self._lib.apply_op("Simple", a=3)1261 # We look at the whole graph here to make sure the Const op is also given1262 # the specified device.1263 graph_def = self._g.as_graph_def()1264 self.assertEqual(len(graph_def.node), 2)1265 for node in graph_def.node:1266 self.assertDeviceEqual(node.device, "/job:ADevice")1267 def testStructuredOutputSingleList(self):1268 self._add_op("name: 'SimpleStruct' "1269 "output_arg { name: 'a' type: DT_INT32 number_attr: 'n_a' } "1270 "attr { name: 'n_a' type: 'int' }")1271 for n_a in [0, 1, 3]:1272 a = self._lib.apply_op("SimpleStruct", n_a=n_a)1273 self.assertTrue(isinstance(a, list))1274 self.assertEqual(n_a, len(a))1275 def testStructuredOutputListAndSingle(self):1276 self._add_op("name: 'MixedStruct' "1277 "output_arg { name: 'a' type: DT_INT32 number_attr: 'n_a' } "1278 "output_arg { name: 'b' type: DT_FLOAT } "1279 "attr { name: 'n_a' type: 'int' }")1280 for n_a in [0, 1, 3]:1281 a, b = self._lib.apply_op("MixedStruct", n_a=n_a)1282 self.assertTrue(isinstance(a, list))1283 self.assertEqual(n_a, len(a))1284 self.assertTrue(all(x.dtype == dtypes.int32 for x in a))1285 self.assertTrue(isinstance(b, ops.Tensor))1286 self.assertEqual(dtypes.float32, b.dtype)1287 def testStructuredOutputMultipleLists(self):1288 self._add_op("name: 'ComplexStruct' "1289 "output_arg { name: 'a' type: DT_INT32 number_attr: 'n_a' } "1290 "output_arg { name: 'b' type: DT_INT64 number_attr: 'n_b' } "1291 "output_arg { name: 'c' type_list_attr: 't_c' } "1292 "attr { name: 'n_a' type: 'int' } "1293 "attr { name: 'n_b' type: 'int' } "1294 "attr { name: 't_c' type: 'list(type)' }")1295 for n_a in [0, 1, 3]:1296 for n_b in [0, 1, 3]:1297 for t_c in [[],1298 [dtypes.int32],1299 [dtypes.int32, dtypes.float32]]:1300 a, b, c = self._lib.apply_op("ComplexStruct",1301 n_a=n_a,1302 n_b=n_b,1303 t_c=t_c)1304 self.assertEqual(n_a, len(a))1305 self.assertTrue(all(x.dtype == dtypes.int32 for x in a))1306 
self.assertEqual(n_b, len(b))1307 self.assertTrue(all(x.dtype == dtypes.int64 for x in b))1308 self.assertEqual(t_c, [x.dtype for x in c])1309class OpDefLibraryGraphTest(test_util.TensorFlowTestCase):1310 def setUp(self):1311 self._lib = OpDefLibrary()1312 self._g = ops.Graph()1313 self._add_op("name: 'Simple' input_arg { name: 'a' type: DT_INT32 } "1314 "output_arg { name: 'out' type: DT_FLOAT }")1315 self._add_op("name: 'Binary' "1316 "input_arg { name: 'a' type_attr: 'T' } "1317 "input_arg { name: 'b' type_attr: 'T' } "1318 "output_arg { name: 'out' type_attr: 'T' } "1319 "attr { name: 'T' type: 'type' }")1320 def _add_op(self, ascii):1321 op_def = op_def_pb2.OpDef()1322 text_format.Merge(ascii, op_def)1323 self._lib.add_op(op_def)1324 def testNoGraph(self):1325 out = self._lib.apply_op("Simple", a=3)1326 self.assertEqual(out.graph, ops.get_default_graph())1327 def testDefaultGraph(self):1328 with self._g.as_default():1329 out = self._lib.apply_op("Simple", a=3)1330 self.assertEqual(out.graph, self._g)1331 def testDifferentGraphFails(self):1332 with self._g.as_default():1333 a = self._lib.apply_op("Simple", a=3)1334 other_g = ops.Graph()1335 with other_g.as_default():1336 b = self._lib.apply_op("Simple", a=4)1337 with self.assertRaises(ValueError) as cm:1338 self._lib.apply_op("Binary", a=a, b=b)1339 self.assertTrue("must be from the same graph" in str(cm.exception))1340if __name__ == "__main__":...
op_def_library.py
Source:op_def_library.py
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.2#3# Licensed under the Apache License, Version 2.0 (the "License");4# you may not use this file except in compliance with the License.5# You may obtain a copy of the License at6#7# http://www.apache.org/licenses/LICENSE-2.08#9# Unless required by applicable law or agreed to in writing, software10# distributed under the License is distributed on an "AS IS" BASIS,11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.12# See the License for the specific language governing permissions and13# limitations under the License.14# ==============================================================================15"""Class to hold a library of OpDefs and use it to create Brain operations."""16from __future__ import absolute_import17from __future__ import division18from __future__ import print_function19import six20from tensorflow.core.framework import attr_value_pb221from tensorflow.core.framework import op_def_pb222from tensorflow.core.framework import tensor_pb223from tensorflow.core.framework import tensor_shape_pb224from tensorflow.core.framework import types_pb225from tensorflow.python.framework import dtypes26from tensorflow.python.framework import ops27from tensorflow.python.framework import tensor_shape28from tensorflow.python.platform import tf_logging as logging29from tensorflow.python.util import compat30from tensorflow.python.util import tf_contextlib31def _Attr(op_def, name):32 for attr in op_def.attr:33 if attr.name == name:34 return attr35 raise TypeError("Inconsistent OpDef for '%s', missing attr '%s'" %36 (op_def.name, name))37def _AttrValue(attr_protos, name):38 if name in attr_protos:39 return attr_protos[name]40 raise TypeError("Inconsistent OpDef, missing attr '%s' from '%s'." 
%41 (name, attr_protos))42def _SatisfiesTypeConstraint(dtype, attr_def, param_name):43 if attr_def.HasField("allowed_values"):44 allowed_list = attr_def.allowed_values.list.type45 if dtype not in allowed_list:46 raise TypeError(47 "Value passed to parameter '%s' has DataType %s not in list of "48 "allowed values: %s" %49 (param_name, dtypes.as_dtype(dtype).name,50 ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))51def _IsListParameter(arg):52 if arg.number_attr:53 return True54 elif arg.type_list_attr:55 return True56 return False57def _NumTypeFields(arg):58 num = 059 if arg.type != types_pb2.DT_INVALID: num += 160 if arg.type_attr: num += 161 if arg.type_list_attr: num += 162 return num63def _IsListValue(v):64 return isinstance(v, (list, tuple))65def _Flatten(l):66 """Converts [1, 2, [3, 4], [5]] to [1, 2, 3, 4, 5]."""67 # [1, 2, [3, 4], [5]] -> [[1], [2], [3, 4], [5]]68 l_of_l = [x if _IsListValue(x) else [x] for x in l]69 # [[1], [2], [3, 4], [5]] -> [1, 2, 3, 4, 5]70 return [item for sublist in l_of_l for item in sublist]71def _Restructure(l, structure):72 """Returns the elements of list l structured according to the given structure.73 A structure is represented by a list whose elements are either74 `None` or a non-negative integer. `None` corresponds to a single75 element in the output list, and an integer N corresponds to a nested76 list of length N.77 The function returns a data structure whose shape is given by78 `structure`, and whose elements are taken from `l`. If `structure`79 is a singleton, the function returns the single data structure80 implied by the 0th element of `structure`. 
For example:81 _Restructure(["foo", "bar", "baz", "qux"], [None, 2, None])82 -> ["foo", ["bar", "baz"], "qux"]83 _Restructure(["foo"], [None]) -> "foo"84 _Restructure(["foo"], [1]) -> ["foo"]85 _Restructure([], [0]) -> []86 Args:87 l: A list.88 structure: A list whose elements are either `None` or a non-negative89 integer.90 Returns:91 The elements of `l`, restructured according to `structure`. If92 `structure` is a list of length 1, this function returns the93 single data structure implied by `structure[0]`.94 """95 result = []96 current_index = 097 for element in structure:98 if element is None:99 result.append(l[current_index])100 current_index += 1101 else:102 result.append(l[current_index:current_index+element])103 current_index += element104 if len(result) == 1:105 return result[0]106 else:107 return tuple(result)108def _MakeFloat(v, arg_name):109 if not isinstance(v, compat.real_types):110 raise TypeError("Expected float for argument '%s' not %s." %111 (arg_name, repr(v)))112 return float(v)113def _MakeInt(v, arg_name):114 if isinstance(v, six.string_types):115 raise TypeError("Expected int for argument '%s' not %s." %116 (arg_name, repr(v)))117 try:118 return int(v)119 except (ValueError, TypeError):120 raise TypeError("Expected int for argument '%s' not %s." %121 (arg_name, repr(v)))122def _MakeStr(v, arg_name):123 if not isinstance(v, compat.bytes_or_text_types):124 raise TypeError("Expected string for argument '%s' not %s." %125 (arg_name, repr(v)))126 return compat.as_bytes(v) # Convert unicode strings to bytes.127def _MakeBool(v, arg_name):128 if not isinstance(v, bool):129 raise TypeError("Expected bool for argument '%s' not %s." %130 (arg_name, repr(v)))131 return v132def _MakeType(v, attr_def):133 try:134 v = dtypes.as_dtype(v).base_dtype135 except TypeError:136 raise TypeError("Expected DataType for argument '%s' not %s." 
%137 (attr_def.name, repr(v)))138 i = v.as_datatype_enum139 _SatisfiesTypeConstraint(i, attr_def, param_name=attr_def.name)140 return i141def _MakeShape(v, arg_name):142 """Convert v into a TensorShapeProto."""143 # Args:144 # v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.145 # arg_name: String, for error messages.146 # Returns:147 # A TensorShapeProto.148 if isinstance(v, tensor_shape_pb2.TensorShapeProto):149 for d in v.dim:150 if d.name:151 logging.warning("Warning: TensorShapeProto with a named dimension: %s",152 str(v))153 break154 return v155 try:156 return tensor_shape.as_shape(v).as_proto()157 except TypeError as e:158 raise TypeError("Error converting %s to a TensorShape: %s" % (arg_name, e))159 except ValueError as e:160 raise ValueError("Error converting %s to a TensorShape: %s" % (arg_name, e))161def _MakeTensor(v, arg_name):162 """Ensure v is a TensorProto."""163 if isinstance(v, tensor_pb2.TensorProto):164 return v165 raise TypeError(166 "Don't know how to convert %s to a TensorProto for argument '%s'" %167 (repr(v), arg_name))168class _OpInfo(object):169 """All per-Op state we would like to precompute/validate."""170 def __init__(self, op_def):171 self.op_def = op_def172 # TODO(josh11b): SWIG the ValidateOpDef() function from C++ and call it173 # here, instead of these checks.174 for arg in list(op_def.input_arg) + list(op_def.output_arg):175 num_type_fields = _NumTypeFields(arg)176 if num_type_fields != 1:177 raise TypeError("Arg '%s' of '%s' must have one type field not %d" %178 (arg.name, op_def.name, num_type_fields))179 if arg.type_attr:180 attr_type = _Attr(op_def, arg.type_attr).type181 if attr_type != "type":182 raise TypeError("Attr '%s' of '%s' used as a type_attr "183 "but has type %s" %184 (arg.type_attr, op_def.name, attr_type))185 if arg.type_list_attr:186 attr_type = _Attr(op_def, arg.type_list_attr).type187 if attr_type != "list(type)":188 raise TypeError(189 "Attr '%s' of '%s' used as a type_list_attr but has 
type %s" %190 (arg.type_attr, op_def.name, attr_type))191 if arg.number_attr:192 attr_type = _Attr(op_def, arg.number_attr).type193 if attr_type != "int":194 raise TypeError(195 "Attr '%s' of '%s' used as a number_attr but has type %s" %196 (arg.number_attr, op_def.name, attr_type))197# pylint: disable=g-doc-return-or-yield198@tf_contextlib.contextmanager199def _MaybeColocateWith(inputs):200 """A context manager for (maybe) colocating with a list of input tensors.201 Args:202 inputs: A list of `Tensor` or `Operation` objects.203 Returns:204 A context manager.205 """206 if not inputs:207 yield208 else:209 # NOTE(mrry): The `ops.colocate_with()` function accepts only a single210 # op or tensor, so we create one context manager per element in the list.211 with ops.colocate_with(inputs[0]), _MaybeColocateWith(inputs[1:]):212 yield213# pylint: enable=g-doc-return-or-yield214class OpDefLibrary(object):215 """Holds a collection of OpDefs, can add the corresponding Ops to a graph."""216 def __init__(self):217 self._ops = {}218 # pylint: disable=invalid-name219 def add_op(self, op_def):220 """Register an OpDef. May call apply_op with the name afterwards."""221 if not isinstance(op_def, op_def_pb2.OpDef):222 raise TypeError("%s is %s, not an op_def_pb2.OpDef" %223 (op_def, type(op_def)))224 if not op_def.name:225 raise ValueError("%s missing name." % op_def)226 if op_def.name in self._ops:227 raise RuntimeError("Op name %s registered twice." 
% op_def.name)228 self._ops[op_def.name] = _OpInfo(op_def)229 def add_op_list(self, op_list):230 """Register the OpDefs from an OpList."""231 if not isinstance(op_list, op_def_pb2.OpList):232 raise TypeError("%s is %s, not an op_def_pb2.OpList" %233 (op_list, type(op_list)))234 for op_def in op_list.op:235 self.add_op(op_def)236 def apply_op(self, op_type_name, name=None, **keywords):237 # pylint: disable=g-doc-args238 """Add a node invoking a registered Op to a graph.239 Example usage:240 # input1 and input2 can be Tensors or anything ops.convert_to_tensor()241 # will convert to a Tensor.242 op_def_library.apply_op("op", input1=input1, input2=input2)243 # Can specify a node name.244 op_def_library.apply_op("op", input1=input1, name="node_name")245 # Must use keyword arguments, with the names specified in the OpDef.246 op_def_library.apply_op("op", input_name=input, attr_name=attr)247 All attrs must either be inferred from an input or specified.248 (If inferred, the attr must not be specified.) If an attr has a default249 value specified in the Op's OpDef, then you may pass None as the value250 of that attr to get the default.251 Args:252 op_type_name: string. Must match the name field of a registered Op.253 name: string. 
Optional name of the created op.254 **keywords: input Tensor and attr arguments specified by name,255 and optional parameters to pass when constructing the Operation.256 Returns:257 The Tensor(s) representing the output of the operation, or the Operation258 itself if there are no outputs.259 Raises:260 RuntimeError: On some errors.261 TypeError: On some errors.262 ValueError: On some errors.263 """264 output_structure, is_stateful, op = self._apply_op_helper(265 op_type_name, name, **keywords)266 if output_structure:267 outputs = op.outputs268 res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure)269 if isinstance(res, list) and not res and is_stateful:270 return op271 else:272 return res273 else:274 return op275 def _apply_op_helper(self, op_type_name, name=None, **keywords):276 """Implementation of apply_op that returns output_structure, op."""277 op_info = self._ops.get(op_type_name, None)278 if op_info is None:279 raise RuntimeError("Unrecognized Op name " + op_type_name)280 op_def = op_info.op_def281 # Determine the graph context.282 try:283 # Need to flatten all the arguments into a list.284 # pylint: disable=protected-access285 g = ops._get_graph_from_inputs(_Flatten(keywords.values()))286 # pylint: enable=protected-access287 except AssertionError as e:288 raise RuntimeError(289 "Cannot determine graph for Op '%s' due to: %s"290 % (op_type_name, e.message))291 # Default name if not specified.292 if name is None:293 name = op_type_name294 # Check for deprecation295 deprecation_version = op_def.deprecation.version296 if deprecation_version:297 producer = g.graph_def_versions.producer298 if producer >= deprecation_version:299 raise NotImplementedError(300 ("Op %s is not available in GraphDef version %d. "301 "It has been removed in version %d. %s.") %302 (op_type_name, producer, deprecation_version,303 op_def.deprecation.explanation))304 # Fill in the list of default types for all "type" attrs. 
This305 # will be used to choose a preferred dtype to convert to in the306 # absence of input type information.307 #308 # TODO(b/31302892): Currently the defaults don't work in the right309 # way if you have two inputs, one of whose type resolution depends310 # on the other. Handling this will require restructuring this code311 # significantly.312 default_type_attr_map = {}313 for attr_def in op_def.attr:314 if attr_def.type != "type":315 continue316 key = attr_def.name317 if attr_def.HasField("default_value"):318 default_type_attr_map[key] = dtypes.as_dtype(319 attr_def.default_value.type)320 # Requires that op_def has passed validation (using the C++321 # ValidateOpDef() from ../framework/op_def_util.h).322 attrs = {}323 inputs = []324 input_types = []325 with g.as_default(), ops.name_scope(name) as scope:326 # Perform input type inference327 inferred_from = {}328 for input_arg in op_def.input_arg:329 input_name = input_arg.name330 if input_name in keywords:331 values = keywords.pop(input_name)332 elif input_name + "_" in keywords:333 # Handle the case where the name is a keyword or built-in334 # for Python so we use the name + _ instead.335 input_name += "_"336 values = keywords.pop(input_name)337 else:338 raise TypeError("No argument for input " + input_name)339 # Goals:340 # * Convert values to Tensors if it contains constants.341 # * Verify that values is a list if that matches the input_arg's342 # type.343 # * If the input_arg's type is determined by attrs, either set344 # those attrs and validate those attr values are legal (if345 # they have not yet been set) or validate the input matches346 # the type indicated by the attrs (if they have already been347 # inferred via an earlier input).348 # * If the input_arg has an explicit type, make sure the input349 # conforms.350 if _IsListParameter(input_arg):351 if not _IsListValue(values):352 raise TypeError(353 "Expected list for '%s' argument to '%s' Op, not %s." 
%354 (input_name, op_type_name, values))355 # In cases where we expect all elements of the list to have the356 # same dtype, try to cast non-Tensor elements to that type.357 dtype = None358 default_dtype = None359 if input_arg.type != types_pb2.DT_INVALID:360 dtype = input_arg.type361 elif input_arg.number_attr:362 if input_arg.type_attr in attrs:363 dtype = attrs[input_arg.type_attr]364 else:365 for t in values:366 if isinstance(t, ops.Tensor):367 dtype = t.dtype368 break369 # dtype still not found, prefer using the default dtype370 # from the attr.371 if dtype is None and input_arg.type_attr in default_type_attr_map:372 default_dtype = default_type_attr_map[input_arg.type_attr]373 try:374 if not input_arg.is_ref and dtype:375 dtype = dtypes.as_dtype(dtype).base_dtype376 values = ops.internal_convert_n_to_tensor(377 values,378 name=input_arg.name,379 dtype=dtype if dtype else None,380 preferred_dtype=default_dtype,381 as_ref=input_arg.is_ref)382 if input_arg.number_attr and len(383 set(v.dtype.base_dtype for v in values)) > 1:384 raise TypeError() # All types should match.385 except (TypeError, ValueError):386 # What types does the conversion function think values have?387 observed_types = []388 for value in values:389 try:390 converted_value = ops.internal_convert_to_tensor(391 value, as_ref=input_arg.is_ref)392 observed_types.append(converted_value.dtype.base_dtype.name)393 except (TypeError, ValueError):394 observed_types.append("<NOT CONVERTIBLE TO TENSOR>")395 observed = ", ".join(observed_types)396 prefix = (397 "Tensors in list passed to '%s' of '%s' Op have types [%s]" %398 (input_name, op_type_name, observed))399 if input_arg.number_attr:400 if input_arg.type != types_pb2.DT_INVALID:401 raise TypeError("%s that do not match expected type %s." %402 (prefix, dtype.name))403 elif input_arg.type_attr in attrs:404 raise TypeError("%s that do not match type %s inferred from "405 "earlier arguments." 
%406 (prefix, dtype.name))407 else:408 raise TypeError("%s that don't all match." % prefix)409 else:410 raise TypeError("%s that are invalid." % prefix)411 types = [x.dtype for x in values]412 inputs.extend(values)413 else:414 # In cases where we have an expected type, try to convert non-Tensor415 # arguments to that type.416 dtype = None417 default_dtype = None418 if input_arg.type != types_pb2.DT_INVALID:419 dtype = input_arg.type420 elif input_arg.type_attr in attrs:421 dtype = attrs[input_arg.type_attr]422 elif input_arg.type_attr in default_type_attr_map:423 # The dtype could not be inferred solely from the inputs,424 # so we prefer the attr's default, so code that adds a new attr425 # with a default is backwards compatible.426 default_dtype = default_type_attr_map[input_arg.type_attr]427 try:428 values = ops.internal_convert_to_tensor(429 values,430 name=input_arg.name,431 dtype=dtype,432 as_ref=input_arg.is_ref,433 preferred_dtype=default_dtype)434 except TypeError as err:435 if dtype is None:436 raise err437 else:438 raise TypeError(439 "Expected %s passed to parameter '%s' of op '%s', got %s of "440 "type '%s' instead." %441 (dtypes.as_dtype(dtype).name, input_arg.name, op_type_name,442 repr(values), type(values).__name__))443 except ValueError:444 # What type does convert_to_tensor think it has?445 try:446 observed = ops.internal_convert_to_tensor(447 values, as_ref=input_arg.is_ref).dtype.name448 except ValueError as err:449 raise ValueError(450 "Tried to convert '%s' to a tensor and failed. Error: %s" %451 (input_name, err))452 prefix = ("Input '%s' of '%s' Op has type %s that does not match" %453 (input_name, op_type_name, observed))454 if input_arg.type != types_pb2.DT_INVALID:455 raise TypeError("%s expected type of %s." 
%456 (prefix, dtypes.as_dtype(input_arg.type).name))457 else:458 # Update the maps with the default, if needed.459 k = input_arg.type_attr460 if k in default_type_attr_map:461 if k not in attrs:462 attrs[k] = default_type_attr_map[k]463 if k not in inferred_from:464 inferred_from[k] = "Default in OpDef"465 raise TypeError(466 "%s type %s of argument '%s'." %467 (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name,468 inferred_from[input_arg.type_attr]))469 types = [values.dtype]470 inputs.append(values)471 base_types = [x.base_dtype for x in types]472 if input_arg.number_attr:473 # <number-attr> * <type> or <number-attr> * <type-attr>474 if input_arg.number_attr in attrs:475 if len(values) != attrs[input_arg.number_attr]:476 raise ValueError(477 "List argument '%s' to '%s' Op with length %d must match "478 "length %d of argument '%s'." %479 (input_name, op_type_name, len(values),480 attrs[input_arg.number_attr],481 inferred_from[input_arg.number_attr]))482 else:483 attrs[input_arg.number_attr] = len(values)484 inferred_from[input_arg.number_attr] = input_name485 num_attr = _Attr(op_def, input_arg.number_attr)486 if num_attr.has_minimum and len(values) < num_attr.minimum:487 raise ValueError(488 "List argument '%s' to '%s' Op with length %d shorter "489 "than minimum length %d." %490 (input_name, op_type_name, len(values), num_attr.minimum))491 # All tensors must have the same base type.492 if any([bt != base_types[0] for bt in base_types]):493 raise TypeError(494 "All tensors passed to '%s' of '%s' Op "495 "must have the same type." 
%496 (input_name, op_type_name))497 if input_arg.type != types_pb2.DT_INVALID:498 # <number-attr> * <type> case499 if base_types and base_types[0] != input_arg.type:500 assert False, "Unreachable"501 elif input_arg.type_attr in attrs:502 # <number-attr> * <type-attr> case, where <type-attr> already503 # has an inferred value.504 if base_types and base_types[0] != attrs[input_arg.type_attr]:505 assert False, "Unreachable"506 else:507 # <number-attr> * <type-attr> case, where we are now setting508 # the <type-attr> based on this input509 if not base_types:510 raise TypeError(511 "Don't know how to infer type variable from empty input "512 "list passed to input '%s' of '%s' Op." %513 (input_name, op_type_name))514 attrs[input_arg.type_attr] = base_types[0]515 inferred_from[input_arg.type_attr] = input_name516 type_attr = _Attr(op_def, input_arg.type_attr)517 _SatisfiesTypeConstraint(base_types[0], type_attr,518 param_name=input_name)519 elif input_arg.type_attr:520 # <type-attr>521 attr_value = base_types[0]522 if input_arg.type_attr in attrs:523 if attrs[input_arg.type_attr] != attr_value:524 assert False, "Unreachable"525 else:526 for base_type in base_types:527 _SatisfiesTypeConstraint(base_type,528 _Attr(op_def, input_arg.type_attr),529 param_name=input_name)530 attrs[input_arg.type_attr] = attr_value531 inferred_from[input_arg.type_attr] = input_name532 elif input_arg.type_list_attr:533 # <type-list-attr>534 attr_value = base_types535 if input_arg.type_list_attr in attrs:536 if attrs[input_arg.type_list_attr] != attr_value:537 raise TypeError(538 "Input '%s' of '%s' Op has type list of %s that does not "539 "match type list %s of argument '%s'." 
%540 (input_name, op_type_name,541 ", ".join(dtypes.as_dtype(x).name for x in attr_value),542 ", ".join(dtypes.as_dtype(x).name543 for x in attrs[input_arg.type_list_attr]),544 inferred_from[input_arg.type_list_attr]))545 else:546 for base_type in base_types:547 _SatisfiesTypeConstraint(base_type,548 _Attr(op_def, input_arg.type_list_attr),549 param_name=input_name)550 attrs[input_arg.type_list_attr] = attr_value551 inferred_from[input_arg.type_list_attr] = input_name552 else:553 # single Tensor with specified type554 if base_types[0] != input_arg.type:555 assert False, "Unreachable"556 if input_arg.is_ref:557 if not all(x._is_ref_dtype for x in types): # pylint: disable=protected-access558 raise TypeError(559 ("'%s' Op requires that input '%s' be a mutable tensor "560 "(e.g.: a tf.Variable)") % (op_type_name, input_name))561 input_types.extend(types)562 else:563 input_types.extend(base_types)564 # Process remaining attrs565 for attr in op_def.attr:566 # Skip attrs that have already had their values inferred567 if attr.name in attrs:568 if attr.name in keywords:569 raise TypeError(570 "Should not specify value for inferred attr '%s'." 
% attr.name)571 continue572 if attr.name in keywords:573 attrs[attr.name] = keywords.pop(attr.name)574 elif attr.name + "_" in keywords:575 # Attrs whose names match Python keywords have an extra '_'576 # appended, so we must check for that as well.577 attrs[attr.name] = keywords.pop(attr.name + "_")578 else:579 raise TypeError("No argument for attr " + attr.name)580 # Convert attr values to AttrValue protos.581 attr_protos = {}582 for attr_def in op_def.attr:583 key = attr_def.name584 value = attrs[key]585 attr_value = attr_value_pb2.AttrValue()586 if attr_def.HasField("default_value") and value is None:587 attr_value.CopyFrom(attr_def.default_value)588 attr_protos[key] = attr_value589 continue590 if attr_def.type.startswith("list("):591 if not _IsListValue(value):592 raise TypeError("Expected list for attr " + key)593 if attr_def.has_minimum:594 if len(value) < attr_def.minimum:595 raise ValueError("Attr '%s' of '%s' Op passed list of length %d "596 "less than minimum %d." %597 (key, op_type_name, len(value),598 attr_def.minimum))599 attr_value.list.SetInParent()600 if attr_def.type == "string":601 attr_value.s = _MakeStr(value, key)602 if attr_def.HasField("allowed_values"):603 if attr_value.s not in attr_def.allowed_values.list.s:604 raise ValueError(605 "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %606 (key, op_type_name, compat.as_text(attr_value.s),607 '", "'.join(map(compat.as_text,608 attr_def.allowed_values.list.s))))609 elif attr_def.type == "list(string)":610 attr_value.list.s.extend([_MakeStr(x, key) for x in value])611 if attr_def.HasField("allowed_values"):612 for x in attr_value.list.s:613 if x not in attr_def.allowed_values.list.s:614 raise ValueError(615 "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." 
%616 (key, op_type_name, compat.as_text(x),617 '", "'.join(map(compat.as_text,618 attr_def.allowed_values.list.s))))619 elif attr_def.type == "int":620 attr_value.i = _MakeInt(value, key)621 if attr_def.has_minimum:622 if attr_value.i < attr_def.minimum:623 raise ValueError(624 "Attr '%s' of '%s' Op passed %d less than minimum %d." %625 (key, op_type_name, attr_value.i, attr_def.minimum))626 elif attr_def.type == "list(int)":627 attr_value.list.i.extend([_MakeInt(x, key) for x in value])628 elif attr_def.type == "float":629 attr_value.f = _MakeFloat(value, key)630 elif attr_def.type == "list(float)":631 attr_value.list.f.extend([_MakeFloat(x, key) for x in value])632 elif attr_def.type == "bool":633 attr_value.b = _MakeBool(value, key)634 elif attr_def.type == "list(bool)":635 attr_value.list.b.extend([_MakeBool(x, key) for x in value])636 elif attr_def.type == "type":637 attr_value.type = _MakeType(value, attr_def)638 elif attr_def.type == "list(type)":639 attr_value.list.type.extend(640 [_MakeType(x, attr_def) for x in value])641 elif attr_def.type == "shape":642 attr_value.shape.CopyFrom(_MakeShape(value, key))643 elif attr_def.type == "list(shape)":644 attr_value.list.shape.extend(645 [_MakeShape(x, key) for x in value])646 elif attr_def.type == "tensor":647 attr_value.tensor.CopyFrom(_MakeTensor(value, key))648 elif attr_def.type == "list(tensor)":649 attr_value.list.tensor.extend(650 [_MakeTensor(x, key) for x in value])651 elif attr_def.type == "func":652 if isinstance(value, attr_value_pb2.NameAttrList):653 attr_value.func.CopyFrom(value)654 elif isinstance(value, compat.bytes_or_text_types):655 attr_value.func.name = value656 else:657 value.add_to_graph(ops.get_default_graph())658 attr_value.func.name = value.name659 else:660 raise TypeError("Unrecognized Attr type " + attr_def.type)661 attr_protos[key] = attr_value662 del attrs # attrs is no longer authoritative, use attr_protos instead663 # Determine output types (possibly using attrs)664 output_types = 
[]665 output_structure = []666 for arg in op_def.output_arg:667 types = []668 if arg.number_attr:669 n = _AttrValue(attr_protos, arg.number_attr).i670 if arg.type_attr:671 types = [_AttrValue(attr_protos, arg.type_attr).type] * n672 else:673 types = [arg.type] * n674 output_structure.append(n)675 elif arg.type_attr:676 t = _AttrValue(attr_protos, arg.type_attr)677 types = [t.type]678 output_structure.append(None)679 elif arg.type_list_attr:680 t = _AttrValue(attr_protos, arg.type_list_attr)681 types = t.list.type682 output_structure.append(len(types))683 else:684 types = [arg.type]685 output_structure.append(None)686 if arg.is_ref:687 types = [dtypes.as_dtype(x)._as_ref for x in types] # pylint: disable=protected-access688 output_types.extend(types)689 if keywords:690 raise TypeError("apply_op() got unexpected keyword arguments: " +691 ", ".join(sorted(keywords.keys())))692 # NOTE(mrry): We add an explicit colocation constraint between693 # the newly created op and any of its reference-typed inputs.694 must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)695 if arg.is_ref]696 with _MaybeColocateWith(must_colocate_inputs):697 # Add Op to graph698 op = g.create_op(op_type_name, inputs, output_types, name=scope,699 input_types=input_types, attrs=attr_protos,700 op_def=op_def)701 return output_structure, op_def.is_stateful, op...
convert_timeline_to_html.py
Source:convert_timeline_to_html.py
...50 key_diffs = [[style_key for style_key in x if x[style_key] != y[style_key]] for x, y in pairs if x != y]51 if not key_diffs:52 equality = True53 return equality54def have_same_attr(soup_tag, graph_node):55 ignore_style = True56 equality = False57 if 'node_attributes' in graph_node:58 if len(soup_tag.attrs) == 0 and ('node_attributes' not in graph_node or ('node_attributes' in graph_node and len(graph_node['node_attributes']) == 0)):59 equality = True60 else:61 for attr_name, value in soup_tag.attrs.iteritems():62 if type(value) is list:63 attr_value = " ".join(value).strip()64 else:65 attr_value = "".join(value).strip()66 if attr_name.strip() == 'style':67 # create a list of style dicts68 soup_style_dicts = create_style_dict(attr_value) 69 if any(_['attr_name'] == 'style' for _ in graph_node['node_attributes']):70 # get list of style dicts71 graph_style_dicts = [_['attr_value'] for _ in graph_node['node_attributes'] if _['attr_name'] == 'style'][0]72 # iterate through zip of list of style dicts that are sorted by style_name73 if areTwoDictListEqual(soup_style_dicts, graph_style_dicts):74 equality = True75 break76 elif ignore_style:77 equality = True78 break79 else:80 printing(p=print_mode, what="[Attribute Discrepancy] \n --[Soup (%s)] attribute name: [%s], attribute value: [%s] \n --[Graph (%s)] attribute: %s"%(soup_tag.name, attr_name, attr_value, graph_node['node_id'], str(graph_node['node_attributes'])))81 equality = False82 break83 elif any(_['attr_name'] == attr_name and _['attr_value'] == attr_value for _ in graph_node['node_attributes']):84 equality = True85 else:86 printing(p=print_mode, what="[Attribute Discrepancy] \n --[Soup (%s)] attribute name: [%s], attribute value: [%s] \n --[Graph (%s)] attribute: %s"%(soup_tag.name, attr_name, attr_value, graph_node['node_id'], str(graph_node['node_attributes'])))87 equality = False88 break 89 elif (('node_attributes' in graph_node and len(graph_node['node_attributes']) == 0) or 'node_attributes' not in 
graph_node) and len(soup_tag.attrs) == 0:90 equality = True91 if not equality and not ignore_style:92 printing(p=print_mode, what="[Attribute Discrepancy] \n --[Soup (%s)] attributes: [%s] \n --[Graph (%s)] attribute: %s"%(soup_tag.name, str(soup_tag.attrs), graph_node['node_id'], str(graph_node['node_attributes']) if 'node_attributes' in graph_node else None))93 printing(p=print_mode, what='-'*80)94 return equality95def have_same_children(graph, soup_tag, graph_node, graph_node_id):96 equality = False97 children_soup = [_ for _ in soup_tag.children if isinstance(_, Tag)]98 # ex. graph_node -> G['20']99 children_graph = [_ for _ in graph[graph_node_id]]100 if len(children_soup) == len(children_graph):101 if [child.name.lower() for child in children_soup] == [graph.nodes[child]['tag_name'].lower() for child in children_graph]:102 equal_child_counter = 0103 for pair in itertools.izip_longest(children_soup, children_graph):104 s_child = pair[0]105 g_child = pair[1]106 if s_child.name.lower() == graph.nodes[g_child]['tag_name'].lower() and have_same_attr(s_child, graph.nodes[g_child]):107 equal_child_counter += 1108 else:109 printing(p=print_mode, what="[Children Discrepancy] [Attributes are not the same] \n --[Soup (%s)] \n --[Graph (%s)] %s"%(s_child.name, graph.nodes[g_child]['node_id'], graph.nodes[g_child]['tag_name'].lower()))110 111 if len(children_soup) == len(children_graph) == equal_child_counter:112 equality = True113 else:114 printing(p=print_mode, what="[Children Discrepancy] Length is not equal: [Soup (%s)], [Graph (%s)] %s:"%(soup_tag.name, graph_node['node_id'], graph.nodes[graph_node_id]['tag_name'].lower()))115 if not equality:116 print("Graph %s"%[graph.node[_]['node_id'] for _ in children_graph])117 print("Soup %s"%[_.name for _ in children_soup])118 printing(p=print_mode, what="[Children Discrepancy] [Soup (%s)] [Graph (%s)] %s"%(soup_tag.name, graph_node['node_id'], graph_node['tag_name']))119 for c in itertools.izip_longest([_.name for _ in 
children_soup],[(graph.node[_]['node_id'], graph.node[_]['tag_name']) for _ in children_graph if 'tag_name' in graph.node[_]]):120 #if c[0].lower() != c[1][1].lower():121 if c[0] and c[1]:122 printing(p=print_mode, what="--[child] [Soup] %s, [Graph (%s)] %s"%(c[0], c[1][0], c[1][1]))123 # elif child is None 124 else:125 printing(p=print_mode, what="--[child] [Soup] %s, [Graph (%s)] [None child]"%(c[0], c[1]))126 printing(p=print_mode, what='-'*80)127 return equality128def have_same_text(soup_tag, graph_node, node_id):129 equality = False130 if 'textContent' in graph_node and " ".join(map(lambda _:_.strip(), soup_tag.findAll(text=True, recursive=False))).strip() != None: #soup_tag.find(text=True, recursive=False) != None:131 #if " ".join(map(lambda _:_.strip(), soup_tag.findAll(text=True, recursive=False))).strip().split() == graph_node['textContent'].strip().encode('latin1').decode('utf8').split():132 if html_parser.unescape(" ".join(map(lambda _:_.strip(), soup_tag.findAll(text=True, recursive=False))).strip()).split() == ftfy.fix_text_encoding(graph_node['textContent'].strip()).split():133 equality = True134 else:135 printing(p=print_mode, what="[TEXT] \n --[Soup (%s)]: text: %s \n --[Graph (%s)]: text: %s"%(soup_tag.name, html_parser.unescape(" ".join(map(lambda _:_.strip(), soup_tag.findAll(text=True, recursive=False))).strip()), graph_node['tag_name'], graph_node['textContent']))136 elif 'textContent' not in graph_node:137 if soup_tag.find(text=True, recursive=False) is None: 138 equality = True139 elif not soup_tag.find(text=True, recursive=False).strip():140 equality = True141 if not equality:142 printing(p=print_mode, what="[TEXT] \n --[Soup (%s)]: text: %s \n --[Graph (%s)]: text: %s"%(soup_tag.name, html_parser.unescape(" ".join(map(lambda _:_.strip(), soup_tag.findAll(text=True, recursive=False))).strip()) if " ".join(map(lambda _:_.strip(), soup_tag.findAll(text=True, recursive=False))).strip() else None, graph_node['node_id'], 
ftfy.fix_encoding(graph_node['textContent']) if 'textContent' in graph_node else None))143 printing(p=print_mode, what='-'*80)144 return equality145 146def is_equal(graph, soup_tag, graph_node, graph_node_id):147 equality = False148 try:149 if 'tag_name' in graph_node and soup_tag.name:150 if soup_tag.name.lower() == graph_node['tag_name'].lower():151 #if have_same_text(soup_tag, graph_node, graph_node_id):152 if have_same_attr(soup_tag, graph_node): 153 if have_same_children(graph, soup_tag, graph_node, graph_node_id):154 equality = True155 else:156 printing(p=print_mode, what="[Not Equal Children] [Soup]: %s, [Graph (%s)]: %s"%(soup_tag.name, graph_node['node_id'], graph_node['tag_name']))157 #printing(print_mode=print_mode, soup_tag.name, soup_tag.attrs, graph_node['tag_name'], graph_node['node_attributes'] if 'node_attributes' in graph_node else None)158 else:159 printing(p=print_mode, what="[Not Equal Attribute] [Soup]: %s, [Graph (%s)]: %s"%(soup_tag.name, graph_node['node_id'], graph_node['tag_name']))160 #else:161 # printing(p=print_mode, what="[Not Equal TEXT] [Soup]: %s, [Graph (%s)]: %s"%(soup_tag.name, graph_node['node_id'], graph_node['tag_name']))162 else:163 printing(p=print_mode, what="[Not Equal TagName] [Soup]: %s, [Graph (%s)]: %s"%(soup_tag.name, graph_node['node_id'], graph_node['tag_name']))164 except:165 print(soup_tag)166 return equality...
auxiliary.py
Source:auxiliary.py
def get_print(self, session, idmachine):
    """Print and return a table of every column of one compliance row.

    Args:
        session: Database session used to run the query.
        idmachine: Value matched against ``Compliance_attr.id``.

    Returns:
        The ``PrettyTable`` that was printed (one row per column of the
        record), or ``False`` when no record matches ``idmachine``.
    """
    foundmachine = session.query(Compliance_attr)\
        .filter_by(id=idmachine).first()
    if not foundmachine:
        # Same "not found" convention as the sibling methods.
        return False
    machinedict = {col: getattr(foundmachine, col)
                   for col in foundmachine.__table__.columns.keys()}
    # Build ONE table holding every column. The original iterated an
    # undefined name and re-created the table for each row, so at best
    # only the last column would have been reported.
    table = PrettyTable(['Nome compliance', 'Valor'])
    for key, value in machinedict.items():
        table.add_row([key, value])
    print(table)
    return table

def delete(self, session, idmachine):
    """Delete the compliance row whose id is ``idmachine`` and commit."""
    session.query(Compliance_attr).filter_by(id=idmachine).delete()
    session.commit()
    session.flush()

def get(self, session, idmachine):
    """Return the compliance row whose id is ``idmachine``.

    Returns:
        The matching ``Compliance_attr`` row, or ``False`` when there is
        no such row.
    """
    foundcompliance = session.query(Compliance_attr)\
        .filter_by(id=idmachine).first()
    if not foundcompliance:
        return False
    # Kept from the original: guards against an id that matched in the
    # database but does not compare equal in Python.
    if foundcompliance.id == idmachine:
        return foundcompliance
    return False

def update_attr(self, session, idmachine, attr):
    """Toggle the boolean column named ``attr`` on one compliance row.

    Args:
        session: Database session used to query and commit.
        idmachine: Value matched against ``Compliance_attr.id``.
        attr: Name of the mapped column to flip (True <-> False).

    Returns:
        The updated row, or ``False`` when the row does not exist or
        ``attr`` is not a mapped column of ``Compliance_attr``.
    """
    foundmachine = session.query(Compliance_attr)\
        .filter_by(id=idmachine).first()
    if not foundmachine:
        return False
    if foundmachine.id != idmachine:
        return False
    inst = inspect(Compliance_attr)
    attr_names = [c_attr.key for c_attr in inst.mapper.column_attrs]
    if attr not in attr_names:
        return False
    # Read and write the column selected by `attr`, not a literal
    # attribute called "attr".
    current = getattr(foundmachine, attr)
    # Flip exactly once. The original used two sequential `if`s, so a
    # True -> False toggle was immediately flipped back to True.
    # A value that is neither True nor False (e.g. None) is left alone.
    if current in (True, False):
        setattr(foundmachine, attr, not current)
        session.add(foundmachine)
        session.commit()
        session.flush()
    return foundmachine

def cria(self, session, idmaquina):
    """Create, persist and return a compliance row for ``idmaquina``."""
    newcompliance = dbmodel.Compliance_attr(machineid=idmaquina)
    session.add(newcompliance)
    session.commit()
    session.flush()
    return newcompliance

def update_obs(self, session, idmaquina, newobs):
    """Replace the ``observacoes`` text of one compliance row.

    Returns:
        The updated row, or ``False`` when no row matches ``idmaquina``.
    """
    compliancefound = session.query(Compliance_attr).\
        filter_by(id=idmaquina).first()
    if not compliancefound:
        return False
    compliancefound.observacoes = newobs
    session.commit()
    session.flush()
    return compliancefound
# Handlers for creating machines and compliances
jobs.py
Source:jobs.py
1table_config = [2 {3 'q': None, # ç¨äºæ°æ®åºæ¥è¯¢çå段ï¼å³Model.Tb.objects.filter(*[])4 'title': "é项", # åæ®µè¡¨æ ¼ä¸æ¾ç¤ºçæ é¢5 'display': 1, # æ¯å¦å¨å段æ¾ç¤ºï¼0表示å¨å端ä¸æ¾ç¤º, 1表示å¨å端éè, 2表示å¨å段æ¾ç¤º6 'text': {'content': "<input type='checkbox' />", 'kwargs': {}}, # ä¸ä¸ª@符å·è¡¨ç¤ºåæ°æ®åºå
çæ°æ®ï¼ä¸¤ä¸ª @ 符å·è¡¨ç¤ºåå
¨å±åéä¸ä¸èªèº«ç¸ççææ¬ä¿¡æ¯7 'attr': {} # èªå®ä¹å±æ§8 }, {9 'q': 'id', # ç¨äºæ°æ®åºæ¥è¯¢çå段ï¼å³Model.Tb.objects.filter(*[])10 'title': "ID", # åæ®µè¡¨æ ¼ä¸æ¾ç¤ºçæ é¢11 'display': 1, # æ¯å¦å¨å段æ¾ç¤ºï¼0表示å¨å端ä¸æ¾ç¤º, 1表示å¨å端éè, 2表示å¨å段æ¾ç¤º12 'text': {'content': "{id}", 'kwargs': {'id': '@id'}}, # ä¸ä¸ª@符å·è¡¨ç¤ºåæ°æ®åºå
çæ°æ®ï¼ä¸¤ä¸ª @ 符å·è¡¨ç¤ºåå
¨å±åéä¸ä¸èªèº«ç¸ççææ¬ä¿¡æ¯13 'attr': {} # èªå®ä¹å±æ§14 },15 {16 'q': 'department_demand__name',17 'title': "éæ±é¨é¨",18 'display': 1,19 'text': {'content': "{n}", 'kwargs': {'n': '@department_demand__name'}},20 'attr': {}21 }, {22 'q': 'department_head__head',23 'title': "é¨é¨è´è´£äºº",24 'display': 1,25 'text': {'content': "{n}", 'kwargs': {'n': '@department_head__head'}},26 'attr': {}27 },28 # {29 # 'q': 'create_at',30 # 'title': "éæ±æ交æ¥æ",31 # 'display': 1,32 # 'text': {'content': "{n}", 'kwargs': {'n': '@create_at'}},33 # 'attr': {}34 # },35 # {36 # 'q': 'apply_information',37 # 'title': "åºèåå ",38 # 'display': 1,39 # 'text': {'content': "{n}", 'kwargs': {'n': '@apply_information'}},40 # 'attr': {}41 # }, {42 # 'q': 'customer_name__name',43 # 'title': "客æ·å称",44 # 'display': 1,45 # 'text': {'content': "{n}", 'kwargs': {'n': '@customer_name__name'}},46 # 'attr': {}47 # }, {48 # 'q': 'customer_id__id',49 # 'title': "客æ·ç¼å·",50 # 'display': 1,51 # 'text': {'content': "{n}", 'kwargs': {'n': '@customer_id__id'}},52 # 'attr': {}53 # }, {54 # 'q': 'projects_name__name',55 # 'title': "项ç®å称",56 # 'display': 1,57 # 'text': {'content': "{n}", 'kwargs': {'n': '@projects_name__name'}},58 # 'attr': {}59 # }, {60 # 'q': 'projects_id__id',61 # 'title': "项ç®ç¼å·",62 # 'display': 1,63 # 'text': {'content': "{n}", 'kwargs': {'n': '@projects_id__id'}},64 # 'attr': {}65 # },66 {67 'q': 'jobs',68 'title': "å²ä½å称",69 'display': 1,70 'text': {'content': "{n}", 'kwargs': {'n': '@jobs'}},71 'attr': {72 'edit-enable': 'true',73 'edit-type': 'input'74 },75 },76 {77 'q': 'state',78 'title': "ç´§æ¥ç¨åº¦",79 'display': 1,80 'text': {'content': "{n}", 'kwargs': {'n': '@@jobState_type_choices'}},81 'attr': {82 'edit-enable': 'true',83 'edit-type': 'select',84 'global-name': "jobState_type_choices",85 'origin': '@state'86 },87 },88 # {89 # 'q': 'personnel_type',90 # 'title': "æ£å¼äººå",91 # 'display': 1,92 # 'text': {'content': "{n}", 'kwargs': {'n': '@personnel_type'}},93 # 'attr': {}94 # 
}, {95 # 'q': 'personnel_attr',96 # 'title': "人åå±æ§",97 # 'display': 1,98 # 'text': {'content': "{n}", 'kwargs': {'n': '@personnel_attr'}},99 # 'attr': {}100 # }, {101 # 'q': 'customer_level__level',102 # 'title': "客æ·çº§å«",103 # 'display': 1,104 # 'text': {'content': "{n}", 'kwargs': {'n': '@customer_level__level'}},105 # 'attr': {}106 # }, {107 # 'q': 'onsite',108 # 'title': "Onsite",109 # 'display': 1,110 # 'text': {'content': "{n}", 'kwargs': {'n': '@onsite'}},111 # 'attr': {}112 # }, {113 # 'q': 'on_business_trip',114 # 'title': "åºå·®",115 # 'display': 1,116 # 'text': {'content': "{n}", 'kwargs': {'n': '@on_business_trip'}},117 # 'attr': {}118 # }, {119 # 'q': 'degree',120 # 'title': "å¦åè¦æ±",121 # 'display': 1,122 # 'text': {'content': "{n}", 'kwargs': {'n': '@degree'}},123 # 'attr': {}124 # }, {125 # 'q': 'work_time',126 # 'title': "å·¥ä½å¹´é",127 # 'display': 1,128 # 'text': {'content': "{n}", 'kwargs': {'n': '@work_time'}},129 # 'attr': {}130 # }, {131 # 'q': 'equipped_computer',132 # 'title': "é
å¤çµè",133 # 'display': 1,134 # 'text': {'content': "{n}", 'kwargs': {'n': '@equipped_computer'}},135 # 'attr': {}136 # }, {137 # 'q': 'personnel_costs',138 # 'title': "人åææ¬ç±»å",139 # 'display': 1,140 # 'text': {'content': "{n}", 'kwargs': {'n': '@personnel_costs'}},141 # 'attr': {}142 # },143 {144 'q': 'number',145 'title': "éæ±äººæ°",146 'display': 1,147 'text': {'content': "{n}", 'kwargs': {'n': '@number'}},148 'attr': {149 'name': 'number',150 'edit-enable': 'true',151 'edit-type': 'input'152 },153 },154 {155 'q': 'work_place__name',156 'title': "å·¥ä½å°ç¹",157 'display': 1,158 'text': {'content': "{n}", 'kwargs': {'n': '@work_place__name'}},159 'attr': {160 'name': "work_place_name",161 "edit-enable": "true",162 'edit-type': 'select',163 'global-name': "work_place_choices",164 'origin': '@work_place__id'165 },166 },167 {168 'q': 'work_place__id',169 'title': "",170 'display': 2,171 'text': {},172 'attr': {},173 },174 # {175 # 'q': 'salary',176 # 'title': "èªèµèå´",177 # 'display': 1,178 # 'text': {'content': "{n}", 'kwargs': {'n': '@salary'}},179 # 'attr': {}180 # }, {181 # 'q': 'position_requirements',182 # 'title': "å²ä½éæ±ä¿¡æ¯",183 # 'display': 1,184 # 'text': {'content': "{n}", 'kwargs': {'n': '@position_requirements'}},185 # 'attr': {}186 # }, {187 # 'q': 'search_key',188 # 'title': "å²ä½å
³é®å",189 # 'display': 1,190 # 'text': {'content': "{n}", 'kwargs': {'n': '@search_key'}},191 # 'attr': {}192 # }, {193 # 'q': 'joblevel',194 # 'title': "å²ä½çº§å«",195 # 'display': 1,196 # 'text': {'content': "{n}", 'kwargs': {'n': '@joblevel'}},197 # 'attr': {}198 # }, {199 # 'q': 'jobs_highlight',200 # 'title': "å²ä½äº®ç¹",201 # 'display': 1,202 # 'text': {'content': "{n}", 'kwargs': {'n': '@jobs_highlight'}},203 # 'attr': {}204 # }, {205 # 'q': 'project_size',206 # 'title': "å²ä½è§æ¨¡",207 # 'display': 1,208 # 'text': {'content': "{n}", 'kwargs': {'n': '@project_size'}},209 # 'attr': {}210 # }, {211 # 'q': 'referral_bonus',212 # 'title': "æ¨èå¥é",213 # 'display': 1,214 # 'text': {'content': "{n}", 'kwargs': {'n': '@referral_bonus'}},215 # 'attr': {}216 # }, {217 # 'q': 'customers',218 # 'title': "客æ·ä»ç»",219 # 'display': 1,220 # 'text': {'content': "{n}", 'kwargs': {'n': '@customers'}},221 # 'attr': {}222 # }, {223 # 'q': 'describe',224 # 'title': "å¤æ³¨",225 # 'display': 1,226 # 'text': {'content': "{n}", 'kwargs': {'n': '@describe'}},227 # 'attr': {}228 # },229 {230 'q': 'candidate__username',231 'title': "åé人",232 'display': 1,233 'text': {'content': "{n}", 'kwargs': {'n': '@candidate__username'}},234 'attr': {}235 },236 {237 'q': None,238 'title': "é项",239 'display': 1,240 'text': {241 'content': '''<td class="actions">242 <a href="#" class="on-default edit-row" data-toggle="tooltip" data-placement="top" title="" data-original-title="Edit"><i class="fa fa-pencil"></i></a>243 <a href="#" class="on-default remove-row" data-toggle="tooltip" data-placement="top" title="" data-original-title="Delete"><i class="fa fa-trash-o"></i></a>244 <a href="#" class="hidden on-editing save-row" data-toggle="tooltip" data-placement="top" title="" data-original-title="Save"><i class="fa fa-save"></i></a>245 <a href="#" class="hidden on-editing cancel-row" data-toggle="tooltip" data-placement="top" title="" data-original-title="Cancel"><i class="fa 
fa-times"></i></a>246</td>''',247 'kwargs': {'device_type_id': '@device_type_id', 'nid': '@id'}},248 'attr': {}249 },...
data_pickle.py
Source:data_pickle.py
...67 tclass = global_vars.get(path, None)68 if tclass is not None:69 return tclass70 paths = path.split('.')71 global_vars[path] = lib = getattr(__import__(72 '.'.join(paths[0:-1]), globals(), locals(), fromlist=[paths[-1]], level=0), paths[-1])73 return lib74 @classmethod75 def _get_attrs_by_class(cls, instance_class: str, attr_type=None) -> Optional[Dict[str, Any]]:76 info = cls.class_info.get(instance_class, None)77 if info is None:78 return None79 if attr_type is None:80 output = {}81 output.update(info.get('init_args', {}))82 output.update(info.get('instance_attrs', {}))83 return output84 return info.get(attr_type, None)85 @classmethod86 def _to_data_object(cls, obj):87 instance_class = cls.get_class_path(type(obj))88 result = {89 '__instance__': instance_class90 }91 attribute_action = cls._get_attrs_by_class(instance_class)92 if attribute_action is None:93 attribute_action = {}94 default_action = attribute_action.get('*', DefaultValue)95 for attr_name, attr_value in obj.__dict__.items():96 attr_action = attribute_action.get(attr_name, default_action)97 if not isinstance(attr_action, PicklerAction):98 continue99 result[attr_name] = cls.to_data(attr_value)100 return result101 @classmethod102 def _to_data_dict(cls, obj: dict) -> dict:103 output = {}104 for attr_name, attr_value in obj.items():105 output[attr_name] = cls.to_data(attr_value)106 return output107 @classmethod108 def _to_data_list(cls, obj: list) -> list:109 output = []110 for attr_value in obj:111 output.append(cls.to_data(attr_value))112 return output113 @classmethod114 def to_data(cls, obj):115 if isinstance(obj, (str, type(None), int, float)):116 return obj117 if isinstance(obj, dict):118 return cls._to_data_dict(obj)119 if isinstance(obj, list):120 return cls._to_data_list(obj)121 if isinstance(obj, tuple):122 return tuple(cls._to_data_list(obj))123 if isinstance(obj, type):124 return {'__class__': obj.__name__}125 return cls._to_data_object(obj)126 @classmethod127 def from_data(cls, data):128 
if isinstance(data, (str, type(None), int, float)):129 return data130 if isinstance(data, dict):131 return cls._from_data_dict(data)132 if isinstance(data, list):133 return cls._from_data_list(data)134 raise RuntimeError(f"Wrong data type '{type(data)}' ")135 @classmethod136 def _from_data_iterate_attrs(cls, instance_type, obj: dict, args_bucket_name: str, default_action=None):137 action_map = cls._get_attrs_by_class(instance_type, args_bucket_name)138 if not action_map:139 raise RuntimeError("Unknown class, please, add this class to 'class_info'")140 default_action = action_map.get('*', default_action)141 if not isinstance(obj, dict):142 obj = obj.__dict__143 for attr_name, attr_value in obj.items():144 attr_action = action_map.get(attr_name, default_action)145 if attr_action is PicklerAction.PASSTHROUGH:146 yield attr_name, cls.from_data(attr_value)147 continue148 for attr_name, attr_action in action_map.items():149 if attr_name != '*' and not isinstance(attr_action, PicklerAction):150 yield attr_name, attr_action151 @classmethod152 def _from_data_dict(cls, obj: dict) -> dict:153 obj = obj.copy()154 tmp = obj.pop('__class__', None)155 if tmp is not None:156 return globals()[tmp]157 instance_type = obj.pop('__instance__', None)158 if instance_type is None:159 return {attr_name: cls.from_data(attr_value) for attr_name, attr_value in obj.items()}160 init_args = {}161 for attr_name, attr_value in cls._from_data_iterate_attrs(instance_type, obj, 'init_args', None):162 init_args[attr_name] = attr_value163 instance_class = cls.import_and_return(instance_type)164 instance = instance_class(**init_args)165 for attr_name, attr_value in cls._from_data_iterate_attrs(instance_type, obj, 'init_args', PicklerAction.PASSTHROUGH):166 setattr(instance, attr_name, attr_value)167 return instance168 @classmethod169 def _from_data_list(cls, obj: list) -> list:170 output = []171 for attr_value in obj:172 output.append(cls.from_data(attr_value))173 return output174 @classmethod175 def 
load_from_file(cls, filepath):176 with open(filepath, encoding="utf-8") as file:177 return cls.from_data(json.load(file))178 @classmethod179 def load_data_from_file(cls, filepath):180 with open(filepath, encoding="utf-8") as file:...
const.py
Source:const.py
1"""Predefined values for Moscow PGU integration"""2from typing import Final3DEVICE_CLASS_PGU_INDICATIONS: Final = "pgu_indications"4UNIT_CURRENCY_RUSSIAN_ROUBLES: Final = "RUB"5ATTR_ADDRESS: Final = "address"6ATTR_AMOUNT: Final = "amount"7ATTR_AMOUNT_WITH_INSURANCE: Final = "amount_with_insurance"8ATTR_ARTICLE_TITLE: Final = "article_title"9ATTR_BAILIFF_NAME: Final = "bailiff_name"10ATTR_BAILIFF_PHONE: Final = "bailiff_phone"11ATTR_BALANCE_MESSAGE: Final = "balance_message"12ATTR_BIRTH_DATE: Final = "birth_date"13ATTR_CERTIFICATE_SERIES: Final = "certificate_series"14ATTR_CHARGES_AMOUNT: Final = "charges_amount"15ATTR_CHECKUP_DATE: Final = "checkup_date"16ATTR_CLASS: Final = "class"17ATTR_CODES: Final = "codes"18ATTR_COMMITTED_AT: Final = "committed_at"19ATTR_COUNTER_ID: Final = "counter_id"20ATTR_COUNTER_IDS: Final = "counter_ids"21ATTR_CREATE_DATETIME: Final = "create_datetime"22ATTR_DEBTS: Final = "debts"23ATTR_DEBT_AMOUNT: Final = "debt_amount"24ATTR_DECIMAL_PART_LENGTH: Final = "decimal_part_length"25ATTR_DESCRIPTION: Final = "description"26ATTR_DEVICE: Final = "device"27ATTR_DISCOUNT_DATE: Final = "discount_date"28ATTR_DOCUMENT_SERIES: Final = "document_series"29ATTR_DOCUMENT_TYPE: Final = "document_type"30ATTR_DRIVING_LICENSE_ISSUE_DATE: Final = "driving_license_issue_date"31ATTR_DRIVING_LICENSE_NUMBER: Final = "driving_license_number"32ATTR_DRY_RUN: Final = "dry_run"33ATTR_EMAIL: Final = "email"34ATTR_EMAIL_CONFIRMED: Final = "email_confirmed"35ATTR_END: Final = "end"36ATTR_ENTERPRENEUR_ID: Final = "enterpreneur_id"37ATTR_ENTRANCE_NUMBER: Final = "entrance_number"38ATTR_EPDS: Final = "epds"39ATTR_EPD_ACCOUNT: Final = "epd_account"40ATTR_FIRST_NAME: Final = "first_name"41ATTR_FLAT_ID: Final = "flat_id"42ATTR_FLAT_NUMBER: Final = "flat_number"43ATTR_FLOOR: Final = "floor"44ATTR_FORCE: Final = "force"45ATTR_INDICATION: Final = "indication"46ATTR_INDICATIONS: Final = "indications"47ATTR_INITIATOR: Final = "initiator"48ATTR_INSURANCE_AMOUNT: Final = 
"insurance_amount"49ATTR_INTERCOM: Final = "intercom"50ATTR_ISSUE_DATE: Final = "issue_date"51ATTR_IS_AT_SCHOOL: Final = "is_at_school"52ATTR_IS_EVACUATED: Final = "is_evacuated"53ATTR_KLADR_MAIN_NAME: Final = "kladr_main_name"54ATTR_KLADR_STREET_NAME: Final = "kladr_street_name"55ATTR_LAST_INDICATION_PERIOD: Final = "last_indication_period"56ATTR_LAST_INDICATION_VALUE: Final = "last_indication_value"57ATTR_LAST_NAME: Final = "last_name"58ATTR_LAST_PAYMENT_DATE: Final = "last_payment_date"59ATTR_LAST_PAYMENT_AMOUNT: Final = "last_payment_amount"60ATTR_LAST_UPDATE_DATE: Final = "last_update_date"61ATTR_LICENSE_PLATE: Final = "license_plate"62ATTR_LOCATION: Final = "location"63ATTR_MIDDLE_NAME: Final = "middle_name"64ATTR_NUMBER: Final = "number"65ATTR_OFFENSES: Final = "offenses"66ATTR_ORIGINAL_INDICATIONS: Final = "original_indications"67ATTR_PAYMENTS: Final = "payments"68ATTR_PAYMENTS_AMOUNT: Final = "payments_amount"69ATTR_PAYMENT_AMOUNT: Final = "payment_amount"70ATTR_PAYMENT_DATE: Final = "payment_date"71ATTR_PAYMENT_STATUS: Final = "payment_status"72ATTR_PAY_LIMIT: Final = "pay_limit"73ATTR_PENALTY: Final = "penalty"74ATTR_PENALTY_AMOUNT: Final = "penalty_amount"75ATTR_PERIOD: Final = "period"76ATTR_PERIODS: Final = "periods"77ATTR_PHONE_NUMBER: Final = "phone_number"78ATTR_PHOTO_URL: Final = "photo_url"79ATTR_POLICE_UNIT_CODE: Final = "police_unit_code"80ATTR_POLICE_UNIT_NAME: Final = "police_unit_name"81ATTR_REASON: Final = "reason"82ATTR_RETURNS_AMOUNT: Final = "returns_amount"83ATTR_RISE_DATE: Final = "rise_date"84ATTR_SCHOOL: Final = "school"85ATTR_SERVICE_TYPE: Final = "service_type"86ATTR_SETTLEMENT_DATE: Final = "settlement_date"87ATTR_START: Final = "start"88ATTR_STATUS: Final = "status"89ATTR_STATUS_RNIP: Final = "status_rnip"90ATTR_STATUS_TEXT: Final = "status_text"91ATTR_SUBMIT_AVAILABLE: Final = "submit_available"92ATTR_SUBMIT_BEGIN_DATE: Final = "submit_begin_date"93ATTR_SUBMIT_END_DATE: Final = "submit_end_date"94ATTR_SUCCESS: Final = 
"success"95ATTR_TARIFF: Final = "tariff"96ATTR_TOTAL: Final = "total"97ATTR_TRANSFER_AMOUNT: Final = "transfer_amount"98ATTR_TYPE: Final = "type"99ATTR_TYPES: Final = "types"100ATTR_UNLOAD_DATE: Final = "unload_date"101ATTR_UNLOAD_STATUS: Final = "unload_status"102ATTR_UNPAID_AMOUNT: Final = "unpaid_amount"103ATTR_UNPAID_BAILIFF: Final = "unpaid_bailiff"104ATTR_UNPAID_ENTERPRENEUR: Final = "unpaid_enterpreneur"105ATTR_WHOLE_PART_LENGTH: Final = "whole_part_length"106ATTR_ZONES: Final = "zones"107ATTR_ZONE_ID: Final = "zone_id"108ATTR_ZONE_NAME: Final = "zone_name"109ATTR_SERVICE_IS_OFFLINE: Final = "service_is_offline"110TYPE_ELECTRIC: Final = "electric"111TYPE_WATER: Final = "water"112DOMAIN: Final = "moscow_pgu"113DATA_YAML_CONFIG: Final = DOMAIN + "_yaml_config"114DATA_FINAL_CONFIG: Final = DOMAIN + "_final_config"115DATA_ENTITIES: Final = DOMAIN + "_entities"116DATA_SESSION_LOCK: Final = DOMAIN + "_session_lock"117DATA_UPDATERS: Final = DOMAIN + "_updaters"118DATA_UPDATE_LISTENERS: Final = DOMAIN + "_update_listeners"119CONF_APP_VERSION: Final = "app_version"120CONF_BIRTH_DATE: Final = "birth_date"121CONF_DEVICE_AGENT: Final = "device_agent"122CONF_DEVICE_INFO: Final = "device_info"123CONF_DEVICE_OS: Final = "device_os"124CONF_DRIVING_LICENSES: Final = "driving_licenses"125CONF_FILTER: Final = "filter"126CONF_FIRST_NAME: Final = "first_name"127CONF_GUID: Final = "guid"128CONF_ISSUE_DATE: Final = "issue_date"129CONF_LAST_NAME: Final = "last_name"130CONF_MIDDLE_NAME: Final = "middle_name"131CONF_NAME_FORMAT: Final = "name_format"132CONF_NUMBER: Final = "number"133CONF_SERIES: Final = "series"134CONF_TOKEN: Final = "token"135CONF_TRACK_FSSP_PROFILES: Final = "track_fssp_profiles"136CONF_USER_AGENT: Final = "user_agent"137CONF_ROOT_UPDATE_INTERVAL: Final = "root_update_interval"138SUPPORTED_PLATFORMS: Final = ("sensor",) # "binary_sensor") # This will be changed later...
ex40-14.py
Source:ex40-14.py
1# ë©íí´ëì¤ vs ìí¼í´ëì¤2class A(type): attr = 13class B(metaclass=A): pass # Bë ë©íí´ëì¤ì ì¸ì¤í´ì¤. ë©íì ìì±ì ì»ì4I = B() # Ië ë©íê° ìëë¼ í´ëì¤ë¡ë¶í° ìì!5B.attr6# 17I.attr8# AttributeError: 'B' object has no attribute 'attr'9'attr' in B.__dict__, 'attr' in A.__dict__10# (False, True)11class A: attr = 112class B(A): pass # Ië í´ëì¤ì ìí¼í´ëì¤ë¡ë¶í° ììë°ì13I = B()14B.attr15# 116I.attr17# 118'attr' in B.__dict__, 'attr' in A.__dict__19# (False, True)20class M(type): attr = 121class A: attr = 222class B(A, metaclass=M): pass # ìí¼í´ëì¤ë ë©íí´ëì¤ì ì°ì í¨23I = B()24B.attr, I.attr25# (2, 2)26'attr' in B.__dict__, 'attr' in A.__dict__, 'attr' in M.__dict__27# (False, True, True)28class M(type): attr = 129class A: attr = 230class B(A): pass31class C(B, metaclass=M): pass # Superë metaë³´ë¤ 2ë 벨 ì: ì¬ì í ì´ê¹!32I = C()33I.attr, C.attr34# (2, 2)35[x.__name__ for x in C.__mro__] # MROì ëí 모ë ê²ì 32ì¥ ì°¸ì¡°36# ['C', 'B', 'A', 'object']37I.__class__ # ììì ìí´ ë°ë¦: ì¸ì¤í´ì¤ì í´ëì¤38# <class '__main__.C'>39C.__bases__ # ììì ìí´ ë°ë¦: í´ëì¤ì ìí¼í´ëì¤40# (<class '__main__.B'>,)41C.__class__ # ì¸ì¤í´ì¤ íëì ìí´ ë°ë¦: ë©íí´ëì¤42# <class '__main__.M'>43C.__class__.attr # ë©íí´ëì¤ ìì±ì ê°ì ¸ì¤ë ë¤ë¥¸ ë°©ì...
Using AI Code Generation
1var istanbul = require('istanbul');2var collector = new istanbul.Collector();3var reporter = new istanbul.Reporter();4reporter.addAll(['lcov', 'html']);5reporter.write(collector, sync, function () {6console.log('All reports generated');7});
Using AI Code Generation
1var istanbul = require('istanbul');2var instrumenter = new istanbul.Instrumenter();3var fs = require('fs');4var code = fs.readFileSync('test.js', 'utf-8');5var instrumentedCode = instrumenter.instrumentSync(code, 'test.js');6fs.writeFileSync('test-instrumented.js', instrumentedCode);7- [Istanbul](
Using AI Code Generation
1var istanbul = require('istanbul');2var instrumenter = new istanbul.Instrumenter();3var fs = require('fs');4var code = fs.readFileSync('test.js', 'utf-8');5var instrumentedCode = instrumenter.instrumentSync(code, 'test.js');6fs.writeFileSync('instrumentedTest.js', instrumentedCode);7var istanbul = require('istanbul');8var instrumenter = new istanbul.Instrumenter();9var fs = require('fs');10var code = fs.readFileSync('test.js', 'utf-8');11var instrumentedCode = instrumenter.instrumentSync(code, 'test.js');12fs.writeFileSync('instrumentedTest.js', instrumentedCode);
Using AI Code Generation
1var istanbul = require('istanbul');2var instrumenter = new istanbul.Instrumenter();3var code = 'var x = 1;';4var instrumentedCode = instrumenter.instrumentSync(code, 'test.js');5console.log(instrumentedCode);6var istanbul = require('istanbul');7var instrumenter = new istanbul.Instrumenter();8var code = 'var x = 1;';9instrumenter.instrument(code, 'test.js', function(err, instrumentedCode) {10 console.log(instrumentedCode);11});12var istanbul = require('istanbul');13var instrumenter = new istanbul.Instrumenter();14var code = 'var x = 1;';15var instrumentedCode = instrumenter.instrumentSync(code, 'test.js');16console.log(instrumentedCode);17var istanbul = require('istanbul');18var instrumenter = new istanbul.Instrumenter();19var code = 'var x = 1;';20var instrumentedCode = instrumenter.instrumentSync(code, 'test.js');21console.log(instrumentedCode);22var istanbul = require('istanbul');23var instrumenter = new istanbul.Instrumenter();24var code = 'var x = 1;';25var instrumentedCode = instrumenter.instrumentSync(code, 'test.js');26console.log(instrumentedCode);27var istanbul = require('istanbul');28var instrumenter = new istanbul.Instrumenter();29var code = 'var x = 1;';30var instrumentedCode = instrumenter.instrumentSync(code, 'test.js');31console.log(instrumentedCode);32var istanbul = require('istanbul');33var instrumenter = new istanbul.Instrumenter();34var code = 'var x = 1;';35var instrumentedCode = instrumenter.instrumentSync(code, 'test.js');36console.log(instrumentedCode);
Using AI Code Generation
1var istanbul = require('istanbul');2var instrumenter = new istanbul.Instrumenter({ coverageVariable: '__coverage__' });3var fs = require('fs');4var code = fs.readFileSync('./test.js', 'utf8');5var instrumentedCode = instrumenter.instrumentSync(code, './test.js');6fs.writeFileSync('./test-istanbul.js', instrumentedCode);7var jscover = require('jscover');8var instrumentedCode = jscover.instrument(code, './test.js');9fs.writeFileSync('./test-jscover.js', instrumentedCode);10var blanket = require('blanket');11var instrumentedCode = blanket.instrument({12});13fs.writeFileSync('./test-blanket.js', instrumentedCode);14var jscoverage = require('jscoverage');15var instrumentedCode = jscoverage.instrument(code, './test.js');16fs.writeFileSync('./test-jscoverage.js', instrumentedCode);17var jscoverage = require('jscoverage');18var instrumentedCode = jscoverage.instrument(code, './test.js');19fs.writeFileSync('./test-jscoverage.js', instrumentedCode);20var nodeJscoverage = require('node-jscoverage');21var instrumentedCode = nodeJscoverage.instrument(code, './test.js');22fs.writeFileSync('./test-node-jscoverage.js', instrumentedCode);23var nodeJscoverage = require('node-jscoverage');24var instrumentedCode = nodeJscoverage.instrument(code, './test.js');25fs.writeFileSync('./test-node-jscoverage.js', instrumentedCode);26var nodeJscoverage = require('node-jscoverage');27var instrumentedCode = nodeJscoverage.instrument(code, './test.js');28fs.writeFileSync('./test-node-jscoverage.js', instrumentedCode);29var nodeJscoverage = require('node-jscoverage');30var instrumentedCode = nodeJscoverage.instrument(code, './test.js');31fs.writeFileSync('./test-node-jscoverage.js', instrumentedCode);32var nodeJscoverage = require('node-jscoverage');33var instrumentedCode = nodeJscoverage.instrument(code,
Using AI Code Generation
1if (__coverage__) {2 __coverage__['test.js'].s['1'] = 0;3 __coverage__['test.js'].s['2'] = 0;4 __coverage__['test.js'].s['3'] = 0;5 __coverage__['test.js'].s['4'] = 0;6 __coverage__['test.js'].s['5'] = 0;7 __coverage__['test.js'].s['6'] = 0;8 __coverage__['test.js'].s['7'] = 0;9 __coverage__['test.js'].s['8'] = 0;10 __coverage__['test.js'].s['9'] = 0;11 __coverage__['test.js'].s['10'] = 0;12 __coverage__['test.js'].s['11'] = 0;13 __coverage__['test.js'].s['12'] = 0;14 __coverage__['test.js'].s['13'] = 0;15 __coverage__['test.js'].s['14'] = 0;16 __coverage__['test.js'].s['15'] = 0;17 __coverage__['test.js'].s['16'] = 0;18 __coverage__['test.js'].s['17'] = 0;19 __coverage__['test.js'].s['18'] = 0;20 __coverage__['test.js'].s['19'] = 0;21 __coverage__['test.js'].s['20'] = 0;22 __coverage__['test.js'].s['21'] = 0;23 __coverage__['test.js'].s['22'] = 0;24 __coverage__['test.js'].s['23'] = 0;25 __coverage__['test.js'].s['24'] = 0;26 __coverage__['test.js'].s['25'] = 0;27 __coverage__['test.js'].s['26'] = 0;28 __coverage__['test.js'].s['27'] = 0;29 __coverage__['test.js'].s['28'] = 0;30 __coverage__['test.js'].s['29'] = 0;31 __coverage__['test.js'].s['30'] = 0;
Using AI Code Generation
1const istanbul = require('istanbul-api');2const libCoverage = istanbul.libCoverage;3const report = istanbul.reports.create('json');4const map = libCoverage.createCoverageMap({});5map.addFileCoverage({6 statementMap: { 0: { start: { line: 1, column: 0 }, end: { line: 1, column: 1 } } },7 s: { 0: 1 },8 fnMap: {},9 f: {},10 branchMap: {},11 b: {}12});13report.on('done', function () {14 console.log('done');15});16report.writeReport(map, { dir: './' });17const libCoverage = require('istanbul-lib-coverage');18const map = libCoverage.createCoverageMap({});19map.addFileCoverage({20 statementMap: { 0: { start: { line: 1, column: 0 }, end: { line: 1, column: 1 } } },21 s: { 0: 1 },22 fnMap: {},23 f: {},24 branchMap: {},25 b: {}26});27console.log(map.toJSON());28{ foo.js:29 { path: 'foo.js',30 statementMap: { '0': [Object] },31 s: { '0': 1 },32 fnMap: {},33 f: {},34 branchMap: {},35 b: {} } }36{ foo.js:37 { path: 'foo.js',38 statementMap: { '0': [Object] },39 s: { '0': 1 },40 fnMap: {},41 f: {},42 branchMap: {},43 b: {} },44 statementMap: { '0': { start: [Object], end: [Object] } },45 s: { '0': 1 },46 fnMap: {},47 f: {},48 branchMap: {},49 b: {} }
Using AI Code Generation
1var path = require('path');2var fs = require('fs');3var istanbul = require('istanbul');4var instrumenter = new istanbul.Instrumenter();5var coverageVariable = '__coverage__';6var instrumentedCode = instrumenter.instrumentSync(fs.readFileSync(path.resolve('test.js'), 'utf8'), path.resolve('test.js'));7eval(instrumentedCode);8var coverage = global[coverageVariable];9console.log(coverage);10{ '/Users/xyz/Documents/abc/test.js':11 { path: '/Users/xyz/Documents/abc/test.js',12 { '0': { start: [Object], end: [Object] },13 '1': { start: [Object], end: [Object] },14 '2': { start: [Object], end: [Object] },15 '3': { start: [Object], end: [Object] },16 '4': { start: [Object], end: [Object] },17 '5': { start: [Object], end: [Object] },18 '6': { start: [Object], end: [Object] },19 '7': { start: [Object], end: [Object] },20 '8': { start: [Object], end: [Object] },21 '9': { start: [Object], end: [Object] },22 '10': { start: [Object], end: [Object] },23 '11': { start: [Object], end: [Object] },24 '12': { start: [Object], end: [Object] },25 '13': { start: [Object], end: [Object] },26 '14': { start: [Object], end: [Object] },27 '15': { start: [Object], end: [Object] },28 '16': { start: [Object], end: [Object] },29 '17': { start: [Object], end: [Object] },30 '18': { start: [Object], end: [Object] },31 '19': { start: [Object], end: [Object] },32 '20': { start: [Object], end: [Object] },33 '21': { start: [Object], end: [Object] },34 '22': { start: [Object], end: [Object] },
Using AI Code Generation
1var $ = require('jquery');2var $el = $('<div></div>');3$el.attr('data-foo', 'bar');4{5 "s": {6 },7 "b": {},8 "f": {9 },10 "fnMap": {11 "1": {12 "name": "(anonymous_1)",13 "loc": {14 "start": {15 },16 "end": {17 }18 }19 }20 },21 "statementMap": {22 "1": {23 "start": {24 },25 "end": {26 }27 },28 "2": {29 "start": {30 },31 "end": {32 }33 },34 "3": {35 "start": {36 },37 "end": {38 }39 },40 "4": {41 "start": {42 },43 "end": {44 }45 }46 },47 "branchMap": {}48}49var Instrumenter = require('istanbul').Instrumenter
Using AI Code Generation
1function add(num1, num2) {2 return num1 + num2;3}4test('should add two numbers', () => {5 const result = add(3, 4);6 expect(result).toBe(7);7});8test('should add two numbers', () => {9 const result = add(3, 4);10 expect(result).toBe(7);11});12test('should add two numbers', () => {13 const result = add(3, 4);14 expect(result).toBe(7);15});16test('should add two numbers', () => {17 const result = add(3, 4);18 expect(result).toBe(7);19});20test('should add two numbers', () => {21 const result = add(3, 4);22 expect(result).toBe(7);23});24test('should add two numbers', () => {25 const result = add(3, 4);26 expect(result).toBe(7);27});28test('should add two numbers', () => {29 const result = add(3, 4);30 expect(result).toBe(7);31});32test('should add two numbers', () => {33 const result = add(3, 4);34 expect(result).toBe(7);35});36test('should add two numbers', () => {37 const result = add(3, 4);38 expect(result).toBe(7);39});40test('should add two numbers', () => {41 const result = add(3, 4);42 expect(result).toBe(7);43});44test('should add two numbers', () => {45 const result = add(3, 4);46 expect(result).toBe(7);47});48 ✓ should add two numbers (4ms)
Learn to execute automation testing from scratch with the LambdaTest Learning Hub: from setting up the prerequisites and running your first automation test, to following best practices and diving deeper into advanced test scenarios. The LambdaTest Learning Hub compiles a list of step-by-step guides to help you become proficient with different test automation frameworks, e.g. Selenium, Cypress, and TestNG.
You can also refer to the video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.
Get 100 minutes of automation testing FREE!!