How to use the head method in fMBT

Best Python code snippet using fMBT_python

head_test.py

Source: head_test.py (GitHub)

copy

Full Screen

...93 stirling_approx = z * np.log(z) - z + 0.5 * np.log(2. * np.pi * z)94 lpl += np.ma.masked_array(stirling_approx, mask=(z <= 1)).filled(0.)95 return sum(lpl)/len(lpl)96 def testPoissonWithLogits(self):97 head = head_lib.poisson_regression_head()98 labels = ((0.,), (1.,), (1.,))99 logits = ((0.,), (-1.,), (3.,))100 with ops.Graph().as_default(), session.Session():101 model_fn_ops = head.create_model_fn_ops(102 {},103 labels=labels,104 mode=model_fn.ModeKeys.TRAIN,105 train_op_fn=head_lib.no_op_train_fn,106 logits=logits)107 self._assert_output_alternatives(model_fn_ops)108 _assert_summary_tags(self, ["loss"])109 _assert_no_variables(self)110 loss = self._log_poisson_loss(logits, labels)111 _assert_metrics(self, loss, {"loss": loss}, model_fn_ops)112class RegressionHeadTest(test.TestCase):113 def _assert_output_alternatives(self, model_fn_ops):114 self.assertEquals({115 None: constants.ProblemType.LINEAR_REGRESSION116 }, {117 k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)118 })119 # TODO(zakaria): test multilabel regression.120 def testRegressionWithLogits(self):121 head = head_lib.regression_head()122 with ops.Graph().as_default(), session.Session():123 model_fn_ops = head.create_model_fn_ops(124 {},125 labels=((0.,), (1.,), (1.,)),126 mode=model_fn.ModeKeys.TRAIN,127 train_op_fn=head_lib.no_op_train_fn,128 logits=((1.,), (1.,), (3.,)))129 self._assert_output_alternatives(model_fn_ops)130 _assert_summary_tags(self, ["loss"])131 _assert_no_variables(self)132 _assert_metrics(self, 5. / 3, {"loss": 5. 
/ 3}, model_fn_ops)133 def testRegressionWithLogitFn(self):134 head = head_lib.regression_head(link_fn=math_ops.square)135 def _assert_preditions(test_case, expected_predictions, model_fn_ops):136 variables.initialize_local_variables().run()137 test_case.assertAllClose(expected_predictions,138 model_fn_ops.predictions["scores"].eval())139 with ops.Graph().as_default(), session.Session():140 model_fn_ops = head.create_model_fn_ops(141 {},142 labels=((0.,), (1.,), (1.,)),143 mode=model_fn.ModeKeys.TRAIN,144 train_op_fn=head_lib.no_op_train_fn,145 logits=((1.,), (1.,), (3.,)))146 self._assert_output_alternatives(model_fn_ops)147 _assert_summary_tags(self, ["loss"])148 _assert_no_variables(self)149 _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)150 _assert_preditions(self, ([1.0, 1.0, 9.0]), model_fn_ops)151 def testRegressionWithInvalidLogits(self):152 head = head_lib.regression_head()153 with ops.Graph().as_default(), session.Session():154 with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):155 head.create_model_fn_ops(156 {},157 labels=((0.,), (1.,), (1.,)),158 mode=model_fn.ModeKeys.TRAIN,159 train_op_fn=head_lib.no_op_train_fn,160 logits=((1., 1.), (1., 1.), (3., 1.)))161 def testRegressionWithLogitsInput(self):162 head = head_lib.regression_head()163 with ops.Graph().as_default(), session.Session():164 model_fn_ops = head.create_model_fn_ops(165 {},166 labels=((0.,), (1.,), (1.,)),167 mode=model_fn.ModeKeys.TRAIN,168 train_op_fn=head_lib.no_op_train_fn,169 logits_input=((0., 0.), (0., 0.), (0., 0.)))170 self._assert_output_alternatives(model_fn_ops)171 w = ("regression_head/logits/weights:0",172 "regression_head/logits/biases:0")173 _assert_variables(174 self, expected_global=w, expected_model=w, expected_trainable=w)175 variables.global_variables_initializer().run()176 _assert_summary_tags(self, ["loss"])177 _assert_metrics(self, 2. / 3, {"loss": 2. 
/ 3}, model_fn_ops)178 def testRegressionWithLogitsAndLogitsInput(self):179 head = head_lib.regression_head()180 with ops.Graph().as_default(), session.Session():181 with self.assertRaisesRegexp(182 ValueError, "Both logits and logits_input supplied"):183 head.create_model_fn_ops(184 {},185 labels=((0.,), (1.,), (1.,)),186 mode=model_fn.ModeKeys.TRAIN,187 train_op_fn=head_lib.no_op_train_fn,188 logits_input=((0., 0.), (0., 0.), (0., 0.)),189 logits=((1.,), (1.,), (3.,)))190 def testRegressionEvalMode(self):191 head = head_lib.regression_head()192 with ops.Graph().as_default(), session.Session():193 model_fn_ops = head.create_model_fn_ops(194 {},195 labels=((1.,), (1.,), (3.,)),196 mode=model_fn.ModeKeys.EVAL,197 train_op_fn=head_lib.no_op_train_fn,198 logits=((0.,), (1.,), (1.,)))199 self._assert_output_alternatives(model_fn_ops)200 self.assertIsNone(model_fn_ops.train_op)201 _assert_no_variables(self)202 _assert_summary_tags(self, ["loss"])203 _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)204 def testRegressionWithLabelName(self):205 label_name = "my_label"206 head = head_lib.regression_head(label_name=label_name)207 with ops.Graph().as_default(), session.Session():208 model_fn_ops = head.create_model_fn_ops(209 {},210 labels={label_name: ((0.,), (1.,), (1.,))},211 mode=model_fn.ModeKeys.TRAIN,212 train_op_fn=head_lib.no_op_train_fn,213 logits=((1.,), (1.,), (3.,)))214 self._assert_output_alternatives(model_fn_ops)215 _assert_no_variables(self)216 _assert_summary_tags(self, ["loss"])217 _assert_metrics(self, 5. / 3, {"loss": 5. 
/ 3}, model_fn_ops)218 def testRegressionWithScalarWeights(self):219 head = head_lib.regression_head(weight_column_name="label_weight")220 with ops.Graph().as_default(), session.Session():221 weights = 2.222 labels = ((0.,), (1.,), (1.,))223 model_fn_ops = head.create_model_fn_ops(224 features={"label_weight": weights},225 labels=labels,226 mode=model_fn.ModeKeys.TRAIN,227 train_op_fn=head_lib.no_op_train_fn,228 logits=((1.,), (1.,), (3.,)))229 self._assert_output_alternatives(model_fn_ops)230 _assert_no_variables(self)231 _assert_summary_tags(self, ["loss"])232 _assert_metrics(self, (weights * 5.) / len(labels), {233 "loss": (weights * 5.) / (weights * len(labels))234 }, model_fn_ops)235 def testRegressionWith1DWeights(self):236 head = head_lib.regression_head(weight_column_name="label_weight")237 with ops.Graph().as_default(), session.Session():238 weights = (2., 5., 0.)239 labels = ((0.,), (1.,), (1.,))240 model_fn_ops = head.create_model_fn_ops(241 features={"label_weight": weights},242 labels=labels,243 mode=model_fn.ModeKeys.TRAIN,244 train_op_fn=head_lib.no_op_train_fn,245 logits=((1.,), (1.,), (3.,)))246 self._assert_output_alternatives(model_fn_ops)247 _assert_no_variables(self)248 _assert_summary_tags(self, ["loss"])249 _assert_metrics(self, 2. / len(labels), {"loss": 2. / np.sum(weights)},250 model_fn_ops)251 def testRegressionWith2DWeights(self):252 head = head_lib.regression_head(weight_column_name="label_weight")253 with ops.Graph().as_default(), session.Session():254 weights = ((2.,), (5.,), (0.,))255 labels = ((0.,), (1.,), (1.,))256 model_fn_ops = head.create_model_fn_ops(257 features={"label_weight": weights},258 labels=labels,259 mode=model_fn.ModeKeys.TRAIN,260 train_op_fn=head_lib.no_op_train_fn,261 logits=((1.,), (1.,), (3.,)))262 self._assert_output_alternatives(model_fn_ops)263 _assert_no_variables(self)264 _assert_summary_tags(self, ["loss"])265 _assert_metrics(self, 2. / len(labels), {"loss": 2. 
/ np.sum(weights)},266 model_fn_ops)267 def testRegressionWithCenteredBias(self):268 head = head_lib.regression_head(enable_centered_bias=True)269 with ops.Graph().as_default(), session.Session():270 model_fn_ops = head.create_model_fn_ops(271 {},272 labels=((0.,), (1.,), (1.,)),273 mode=model_fn.ModeKeys.TRAIN,274 train_op_fn=head_lib.no_op_train_fn,275 logits=((1.,), (1.,), (3.,)))276 self._assert_output_alternatives(model_fn_ops)277 _assert_variables(278 self,279 expected_global=(280 "regression_head/centered_bias_weight:0",281 "regression_head/regression_head/centered_bias_weight/Adagrad:0",282 ),283 expected_trainable=("regression_head/centered_bias_weight:0",))284 variables.global_variables_initializer().run()285 _assert_summary_tags(self, [286 "loss",287 "regression_head/centered_bias/bias_0"288 ])289 _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)290 def testRegressionErrorInSparseTensorLabels(self):291 head = head_lib.regression_head()292 with ops.Graph().as_default():293 labels = sparse_tensor.SparseTensorValue(294 indices=((0, 0), (1, 0), (2, 0)),295 values=(0., 1., 1.),296 dense_shape=(3, 1))297 with self.assertRaisesRegexp(ValueError,298 "SparseTensor is not supported"):299 head.create_model_fn_ops(300 {},301 labels=labels,302 mode=model_fn.ModeKeys.TRAIN,303 train_op_fn=head_lib.no_op_train_fn,304 logits=((1.,), (1.,), (3.,)))305class MultiLabelHeadTest(test.TestCase):306 def _assert_output_alternatives(self, model_fn_ops):307 self.assertEquals({308 None: constants.ProblemType.CLASSIFICATION309 }, {310 k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)311 })312 def setUp(self):313 self._logits = ((1., 0., 0.),)314 self._labels = ((0, 0, 1),)315 def _expected_eval_metrics(self, expected_loss):316 return {317 "accuracy": 1. / 3,318 "loss": expected_loss,319 "auc": 1. 
/ 4,320 "auc/class0": 1.,321 "auc/class1": 1.,322 "auc/class2": 0.,323 "auc_precision_recall": 0.166667,324 "auc_precision_recall/class0": 0,325 "auc_precision_recall/class1": 0.,326 "auc_precision_recall/class2": 1.,327 "labels/actual_label_mean/class0": self._labels[0][0],328 "labels/actual_label_mean/class1": self._labels[0][1],329 "labels/actual_label_mean/class2": self._labels[0][2],330 "labels/logits_mean/class0": self._logits[0][0],331 "labels/logits_mean/class1": self._logits[0][1],332 "labels/logits_mean/class2": self._logits[0][2],333 "labels/prediction_mean/class0": self._logits[0][0],334 "labels/prediction_mean/class1": self._logits[0][1],335 "labels/prediction_mean/class2": self._logits[0][2],336 "labels/probability_mean/class0": _sigmoid(self._logits[0][0]),337 "labels/probability_mean/class1": _sigmoid(self._logits[0][1]),338 "labels/probability_mean/class2": _sigmoid(self._logits[0][2]),339 }340 def testMultiLabelWithLogits(self):341 n_classes = 3342 head = head_lib.multi_label_head(343 n_classes=n_classes, metric_class_ids=range(n_classes))344 with ops.Graph().as_default(), session.Session():345 model_fn_ops = head.create_model_fn_ops(346 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,347 logits=self._logits)348 self._assert_output_alternatives(model_fn_ops)349 _assert_no_variables(self)350 _assert_summary_tags(self, ["loss"])351 expected_loss = .89985204352 _assert_metrics(self, expected_loss,353 self._expected_eval_metrics(expected_loss), model_fn_ops)354 def testMultiLabelTwoClasses(self):355 n_classes = 2356 labels = ((0, 1),)357 logits = ((1., 0.),)358 head = head_lib.multi_label_head(359 n_classes=n_classes, metric_class_ids=range(n_classes))360 with ops.Graph().as_default(), session.Session():361 model_fn_ops = head.create_model_fn_ops(362 {}, model_fn.ModeKeys.TRAIN, labels=labels,363 train_op_fn=head_lib.no_op_train_fn, logits=logits)364 self._assert_output_alternatives(model_fn_ops)365 _assert_no_variables(self)366 
_assert_summary_tags(self, ["loss"])367 expected_loss = 1.00320443368 _assert_metrics(self, expected_loss, {369 "accuracy": 0.,370 "auc": 0.,371 "loss": expected_loss,372 "auc/class0": 1.,373 "auc/class1": 0.,374 "labels/actual_label_mean/class0": labels[0][0],375 "labels/actual_label_mean/class1": labels[0][1],376 "labels/logits_mean/class0": logits[0][0],377 "labels/logits_mean/class1": logits[0][1],378 "labels/prediction_mean/class0": logits[0][0],379 "labels/prediction_mean/class1": logits[0][1],380 "labels/probability_mean/class0": _sigmoid(logits[0][0]),381 "labels/probability_mean/class1": _sigmoid(logits[0][1]),382 }, model_fn_ops)383 def testMultiLabelWithInvalidLogits(self):384 head = head_lib.multi_label_head(n_classes=len(self._labels[0]) + 1)385 with ops.Graph().as_default(), session.Session():386 with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):387 head.create_model_fn_ops(388 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,389 logits=self._logits)390 def testMultiLabelWithLogitsInput(self):391 n_classes = 3392 head = head_lib.multi_label_head(393 n_classes=n_classes, metric_class_ids=range(n_classes))394 with ops.Graph().as_default(), session.Session():395 model_fn_ops = head.create_model_fn_ops(396 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,397 logits_input=((0., 0.),))398 self._assert_output_alternatives(model_fn_ops)399 w = ("multi_label_head/logits/weights:0",400 "multi_label_head/logits/biases:0")401 _assert_variables(402 self, expected_global=w, expected_model=w, expected_trainable=w)403 variables.global_variables_initializer().run()404 _assert_summary_tags(self, ["loss"])405 expected_loss = .69314718406 _assert_metrics(self, expected_loss, {407 "accuracy": 2. / 3,408 "auc": 2. 
/ 4,409 "loss": expected_loss,410 "auc/class0": 1.,411 "auc/class1": 1.,412 "auc/class2": 0.,413 "labels/actual_label_mean/class0": self._labels[0][0],414 "labels/actual_label_mean/class1": self._labels[0][1],415 "labels/actual_label_mean/class2": self._labels[0][2],416 "labels/logits_mean/class0": 0.,417 "labels/logits_mean/class1": 0.,418 "labels/logits_mean/class2": 0.,419 "labels/prediction_mean/class0": 0.,420 "labels/prediction_mean/class1": 0.,421 "labels/prediction_mean/class2": 0.,422 "labels/probability_mean/class0": .5,423 "labels/probability_mean/class1": .5,424 "labels/probability_mean/class2": .5,425 }, model_fn_ops)426 def testMultiLabelWithLogitsAndLogitsInput(self):427 n_classes = 3428 head = head_lib.multi_label_head(429 n_classes=n_classes, metric_class_ids=range(n_classes))430 with ops.Graph().as_default(), session.Session():431 with self.assertRaisesRegexp(432 ValueError, "Both logits and logits_input supplied"):433 head.create_model_fn_ops(434 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,435 logits_input=((0., 0.),), logits=self._logits)436 def testMultiLabelEval(self):437 n_classes = 3438 head = head_lib.multi_label_head(439 n_classes=n_classes, metric_class_ids=range(n_classes))440 with ops.Graph().as_default(), session.Session():441 model_fn_ops = head.create_model_fn_ops(442 {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn,443 logits=self._logits)444 self._assert_output_alternatives(model_fn_ops)445 self.assertIsNone(model_fn_ops.train_op)446 _assert_no_variables(self)447 _assert_summary_tags(self, ["loss"])448 expected_loss = .89985204449 _assert_metrics(self, expected_loss,450 self._expected_eval_metrics(expected_loss), model_fn_ops)451 def testMultiClassEvalWithLargeLogits(self):452 n_classes = 3453 head = head_lib.multi_label_head(454 n_classes=n_classes, metric_class_ids=range(n_classes))455 logits = ((2., 0., -1),)456 with ops.Graph().as_default(), session.Session():457 # logloss: z:label, 
x:logit458 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))459 model_fn_ops = head.create_model_fn_ops(460 {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn,461 logits=logits)462 self._assert_output_alternatives(model_fn_ops)463 self.assertIsNone(model_fn_ops.train_op)464 _assert_no_variables(self)465 _assert_summary_tags(self, ["loss"])466 expected_loss = 1.377779467 expected_eval_metrics = {468 "accuracy": 1. / 3,469 "auc": 9.99999e-07,470 "loss": expected_loss,471 "auc/class0": 1.,472 "auc/class1": 1.,473 "auc/class2": 0.,474 "labels/actual_label_mean/class0": 0. / 1,475 "labels/actual_label_mean/class1": 0. / 1,476 "labels/actual_label_mean/class2": 1. / 1,477 "labels/logits_mean/class0": logits[0][0],478 "labels/logits_mean/class1": logits[0][1],479 "labels/logits_mean/class2": logits[0][2],480 "labels/prediction_mean/class0": 1,481 "labels/prediction_mean/class1": 0,482 "labels/prediction_mean/class2": 0,483 "labels/probability_mean/class0": _sigmoid(logits[0][0]),484 "labels/probability_mean/class1": _sigmoid(logits[0][1]),485 "labels/probability_mean/class2": _sigmoid(logits[0][2]),486 }487 _assert_metrics(self, expected_loss,488 expected_eval_metrics, model_fn_ops)489 def testMultiLabelInfer(self):490 n_classes = 3491 head = head_lib.multi_label_head(n_classes=n_classes, head_name="head_name")492 with ops.Graph().as_default(), session.Session():493 model_fn_ops = head.create_model_fn_ops(494 {}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn,495 logits=((1., 0., 0.), (0., 0., 1)))496 self.assertIsNone(model_fn_ops.train_op)497 _assert_no_variables(self)498 with session.Session():499 self.assertListEqual(500 [1, 0, 0], model_fn_ops.predictions["classes"].eval().tolist()[0])501 self.assertItemsEqual(502 ["head_name"], six.iterkeys(model_fn_ops.output_alternatives))503 self.assertEqual(504 constants.ProblemType.CLASSIFICATION,505 model_fn_ops.output_alternatives["head_name"][0])506 predictions_for_serving = (507 
model_fn_ops.output_alternatives["head_name"][1])508 self.assertIn("classes", six.iterkeys(predictions_for_serving))509 self.assertAllEqual(510 [[b"0", b"1", b"2"], [b"0", b"1", b"2"]],511 predictions_for_serving["classes"].eval())512 self.assertIn("probabilities", six.iterkeys(predictions_for_serving))513 self.assertAllClose(514 [[0.731059, 0.5, 0.5],515 [0.5, 0.5, 0.731059,]],516 predictions_for_serving["probabilities"].eval())517 def testMultiLabelWithLabelName(self):518 n_classes = 3519 label_name = "my_label"520 head = head_lib.multi_label_head(521 n_classes=n_classes,522 label_name=label_name,523 metric_class_ids=range(n_classes))524 with ops.Graph().as_default(), session.Session():525 model_fn_ops = head.create_model_fn_ops(526 {}, model_fn.ModeKeys.TRAIN, {label_name: self._labels},527 head_lib.no_op_train_fn, logits=self._logits)528 self._assert_output_alternatives(model_fn_ops)529 _assert_no_variables(self)530 _assert_summary_tags(self, ["loss"])531 expected_loss = .89985204532 _assert_metrics(self, expected_loss,533 self._expected_eval_metrics(expected_loss), model_fn_ops)534 def testMultiLabelWithScalarWeight(self):535 n_classes = 3536 head = head_lib.multi_label_head(537 n_classes=n_classes,538 weight_column_name="label_weight",539 metric_class_ids=range(n_classes))540 with ops.Graph().as_default(), session.Session():541 model_fn_ops = head.create_model_fn_ops(542 features={"label_weight": .1},543 labels=self._labels,544 mode=model_fn.ModeKeys.TRAIN,545 train_op_fn=head_lib.no_op_train_fn,546 logits=self._logits)547 self._assert_output_alternatives(model_fn_ops)548 _assert_no_variables(self)549 _assert_summary_tags(self, ["loss"])550 _assert_metrics(self, .089985214,551 self._expected_eval_metrics(.89985214), model_fn_ops)552 def testMultiLabelWith1DWeight(self):553 n_classes = 3554 head = head_lib.multi_label_head(555 n_classes=n_classes,556 weight_column_name="label_weight",557 metric_class_ids=range(n_classes))558 with ops.Graph().as_default(), 
session.Session():559 with self.assertRaisesRegexp(560 ValueError, "weights can not be broadcast to values"):561 head.create_model_fn_ops(562 features={"label_weight": (.1, .1, .1)},563 labels=self._labels,564 mode=model_fn.ModeKeys.TRAIN,565 train_op_fn=head_lib.no_op_train_fn,566 logits=self._logits)567 def testMultiLabelWith2DWeight(self):568 n_classes = 3569 head = head_lib.multi_label_head(570 n_classes=n_classes,571 weight_column_name="label_weight",572 metric_class_ids=range(n_classes))573 with ops.Graph().as_default(), session.Session():574 model_fn_ops = head.create_model_fn_ops(575 features={"label_weight": ((.1, .1, .1),)},576 labels=self._labels,577 mode=model_fn.ModeKeys.TRAIN,578 train_op_fn=head_lib.no_op_train_fn,579 logits=self._logits)580 self._assert_output_alternatives(model_fn_ops)581 _assert_no_variables(self)582 _assert_summary_tags(self, ["loss"])583 _assert_metrics(self, .089985214,584 self._expected_eval_metrics(.89985214), model_fn_ops)585 def testMultiLabelWithCustomLoss(self):586 n_classes = 3587 head = head_lib.multi_label_head(588 n_classes=n_classes,589 weight_column_name="label_weight",590 metric_class_ids=range(n_classes),591 loss_fn=_sigmoid_cross_entropy)592 with ops.Graph().as_default(), session.Session():593 model_fn_ops = head.create_model_fn_ops(594 features={"label_weight": .1},595 labels=self._labels,596 mode=model_fn.ModeKeys.TRAIN,597 train_op_fn=head_lib.no_op_train_fn,598 logits=self._logits)599 self._assert_output_alternatives(model_fn_ops)600 _assert_no_variables(self)601 _assert_summary_tags(self, ["loss"])602 expected_loss = .089985214603 _assert_metrics(self, expected_loss,604 self._expected_eval_metrics(expected_loss), model_fn_ops)605 def testMultiLabelWithCenteredBias(self):606 n_classes = 3607 head = head_lib.multi_label_head(608 n_classes=n_classes,609 enable_centered_bias=True,610 metric_class_ids=range(n_classes))611 with ops.Graph().as_default(), session.Session():612 model_fn_ops = 
head.create_model_fn_ops(613 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,614 logits=self._logits)615 self._assert_output_alternatives(model_fn_ops)616 _assert_variables(617 self,618 expected_global=(619 "multi_label_head/centered_bias_weight:0",620 ("multi_label_head/multi_label_head/centered_bias_weight/"621 "Adagrad:0"),),622 expected_trainable=("multi_label_head/centered_bias_weight:0",))623 variables.global_variables_initializer().run()624 _assert_summary_tags(self, (625 "loss",626 "multi_label_head/centered_bias/bias_0",627 "multi_label_head/centered_bias/bias_1",628 "multi_label_head/centered_bias/bias_2"629 ))630 expected_loss = .89985204631 _assert_metrics(self, expected_loss,632 self._expected_eval_metrics(expected_loss), model_fn_ops)633 def testMultiLabelSparseTensorLabels(self):634 n_classes = 3635 head = head_lib.multi_label_head(636 n_classes=n_classes, metric_class_ids=range(n_classes))637 with ops.Graph().as_default(), session.Session():638 labels = sparse_tensor.SparseTensorValue(639 indices=((0, 0),),640 values=(2,),641 dense_shape=(1, 1))642 model_fn_ops = head.create_model_fn_ops(643 features={},644 mode=model_fn.ModeKeys.TRAIN,645 labels=labels,646 train_op_fn=head_lib.no_op_train_fn,647 logits=self._logits)648 _assert_no_variables(self)649 _assert_summary_tags(self, ["loss"])650 expected_loss = .89985204651 _assert_metrics(self, expected_loss,652 self._expected_eval_metrics(expected_loss), model_fn_ops)653 def testMultiLabelSparseTensorLabelsTooFewClasses(self):654 n_classes = 3655 head = head_lib.multi_label_head(656 n_classes=n_classes, metric_class_ids=range(n_classes))657 # Set _logits_dimension (n_classes) to a lower value; if it's set to 1658 # upfront, the class throws an error during initialization.659 head._logits_dimension = 1660 with ops.Graph().as_default(), session.Session():661 labels = sparse_tensor.SparseTensorValue(662 indices=((0, 0),),663 values=(2,),664 dense_shape=(1, 1))665 with 
self.assertRaisesRegexp(ValueError,666 "Must set num_classes >= 2 when passing"):667 head.create_model_fn_ops(668 features={},669 labels=labels,670 mode=model_fn.ModeKeys.TRAIN,671 train_op_fn=head_lib.no_op_train_fn,672 logits=[0.])673class BinaryClassificationHeadTest(test.TestCase):674 def _assert_output_alternatives(self, model_fn_ops):675 self.assertEquals({676 None: constants.ProblemType.LOGISTIC_REGRESSION677 }, {678 k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)679 })680 def setUp(self):681 self._logits = ((1.,), (1.,))682 self._labels = ((1.,), (0.,))683 def _expected_eval_metrics(self, expected_loss):684 label_mean = np.mean(self._labels)685 return {686 "accuracy": 1. / 2,687 "accuracy/baseline_label_mean": label_mean,688 "accuracy/threshold_0.500000_mean": 1. / 2,689 "auc": 1. / 2,690 "auc_precision_recall": 0.749999,691 "labels/actual_label_mean": label_mean,692 "labels/prediction_mean": .731059, # softmax693 "loss": expected_loss,694 "precision/positive_threshold_0.500000_mean": 1. / 2,695 "recall/positive_threshold_0.500000_mean": 1. 
/ 1,696 }697 def testBinaryClassificationWithLogits(self):698 n_classes = 2699 head = head_lib.multi_class_head(n_classes=n_classes)700 with ops.Graph().as_default(), session.Session():701 # logloss: z:label, x:logit702 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))703 model_fn_ops = head.create_model_fn_ops(704 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,705 logits=self._logits)706 self._assert_output_alternatives(model_fn_ops)707 _assert_no_variables(self)708 _assert_summary_tags(self, ["loss"])709 expected_loss = .81326175710 _assert_metrics(self, expected_loss,711 self._expected_eval_metrics(expected_loss), model_fn_ops)712 def testBinaryClassificationWithInvalidLogits(self):713 head = head_lib.multi_class_head(n_classes=len(self._labels) + 1)714 with ops.Graph().as_default(), session.Session():715 with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):716 head.create_model_fn_ops(717 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,718 logits=self._logits)719 def testBinaryClassificationWithLogitsInput(self):720 n_classes = 2721 head = head_lib.multi_class_head(n_classes=n_classes)722 with ops.Graph().as_default(), session.Session():723 # logloss: z:label, x:logit724 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))725 model_fn_ops = head.create_model_fn_ops(726 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,727 logits_input=((0., 0.), (0., 0.)))728 self._assert_output_alternatives(model_fn_ops)729 w = ("binary_logistic_head/logits/weights:0",730 "binary_logistic_head/logits/biases:0")731 _assert_variables(732 self, expected_global=w, expected_model=w, expected_trainable=w)733 variables.global_variables_initializer().run()734 _assert_summary_tags(self, ["loss"])735 expected_loss = .69314718736 label_mean = np.mean(self._labels)737 _assert_metrics(self, expected_loss, {738 "accuracy": 1. 
/ 2,739 "accuracy/baseline_label_mean": label_mean,740 "accuracy/threshold_0.500000_mean": 1. / 2,741 "auc": 1. / 2,742 "labels/actual_label_mean": label_mean,743 "labels/prediction_mean": .5, # softmax744 "loss": expected_loss,745 "precision/positive_threshold_0.500000_mean": 0. / 2,746 "recall/positive_threshold_0.500000_mean": 0. / 1,747 }, model_fn_ops)748 def testBinaryClassificationWithLogitsAndLogitsInput(self):749 head = head_lib.multi_class_head(n_classes=2)750 with ops.Graph().as_default(), session.Session():751 with self.assertRaisesRegexp(752 ValueError, "Both logits and logits_input supplied"):753 head.create_model_fn_ops(754 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,755 logits_input=((0., 0.), (0., 0.)), logits=self._logits)756 def testBinaryClassificationEval(self):757 n_classes = 2758 head = head_lib.multi_class_head(n_classes=n_classes)759 with ops.Graph().as_default(), session.Session():760 # logloss: z:label, x:logit761 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))762 model_fn_ops = head.create_model_fn_ops(763 {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn,764 logits=self._logits)765 self._assert_output_alternatives(model_fn_ops)766 self.assertIsNone(model_fn_ops.train_op)767 _assert_no_variables(self)768 _assert_summary_tags(self, ["loss"])769 expected_loss = .81326175770 _assert_metrics(self, expected_loss,771 self._expected_eval_metrics(expected_loss), model_fn_ops)772 def testBinaryClassificationInfer(self):773 n_classes = 2774 head = head_lib.multi_class_head(n_classes=n_classes, head_name="head_name")775 with ops.Graph().as_default(), session.Session():776 # logloss: z:label, x:logit777 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))778 model_fn_ops = head.create_model_fn_ops(779 {}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn,780 logits=self._logits)781 self.assertIsNone(model_fn_ops.train_op)782 _assert_no_variables(self)783 with session.Session():784 
self.assertListEqual(785 [1, 1], list(model_fn_ops.predictions["classes"].eval()))786 self.assertItemsEqual(787 ["head_name"], six.iterkeys(model_fn_ops.output_alternatives))788 self.assertEqual(789 constants.ProblemType.LOGISTIC_REGRESSION,790 model_fn_ops.output_alternatives["head_name"][0])791 predictions_for_serving = (792 model_fn_ops.output_alternatives["head_name"][1])793 self.assertIn("classes", six.iterkeys(predictions_for_serving))794 predicted_classes = predictions_for_serving["classes"].eval().tolist()795 self.assertListEqual(796 [b"0", b"1"], predicted_classes[0])797 self.assertIn("probabilities", six.iterkeys(predictions_for_serving))798 def testBinaryClassificationInferMode_withWeightColumn(self):799 n_classes = 2800 head = head_lib.multi_class_head(n_classes=n_classes,801 weight_column_name="label_weight")802 with ops.Graph().as_default(), session.Session():803 # logloss: z:label, x:logit804 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))805 model_fn_ops = head.create_model_fn_ops(806 # This is what is being tested, features should not have weight for807 # inference.808 {}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn,809 logits=self._logits)810 self._assert_output_alternatives(model_fn_ops)811 self.assertIsNone(model_fn_ops.train_op)812 _assert_no_variables(self)813 def testErrorInSparseTensorLabels(self):814 n_classes = 2815 head = head_lib.multi_class_head(n_classes=n_classes)816 with ops.Graph().as_default():817 labels = sparse_tensor.SparseTensorValue(818 indices=((0, 0), (1, 0), (2, 0)),819 values=(0, 1, 1),820 dense_shape=(3, 1))821 with self.assertRaisesRegexp(ValueError,822 "SparseTensor is not supported"):823 head.create_model_fn_ops(824 {},825 model_fn.ModeKeys.TRAIN,826 labels,827 head_lib.no_op_train_fn,828 logits=((1.,), (1.,), (3.,)))829 def testBinaryClassificationWithLabelName(self):830 label_name = "my_label"831 head = head_lib.multi_class_head(n_classes=2, label_name=label_name)832 with 
ops.Graph().as_default(), session.Session():833 # logloss: z:label, x:logit834 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))835 model_fn_ops = head.create_model_fn_ops(836 {},837 labels={label_name: self._labels},838 mode=model_fn.ModeKeys.TRAIN,839 train_op_fn=head_lib.no_op_train_fn,840 logits=self._logits)841 self._assert_output_alternatives(model_fn_ops)842 _assert_no_variables(self)843 _assert_summary_tags(self, ["loss"])844 expected_loss = .81326175845 _assert_metrics(self, expected_loss,846 self._expected_eval_metrics(expected_loss), model_fn_ops)847 def testBinaryClassificationWith1DWeights(self):848 n_classes = 2849 head = head_lib.multi_class_head(850 n_classes=n_classes, weight_column_name="label_weight")851 with ops.Graph().as_default(), session.Session():852 weights = (1., 0.)853 # logloss: z:label, x:logit854 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))855 model_fn_ops = head.create_model_fn_ops(856 features={"label_weight": weights},857 labels=self._labels,858 mode=model_fn.ModeKeys.TRAIN,859 train_op_fn=head_lib.no_op_train_fn,860 logits=self._logits)861 self._assert_output_alternatives(model_fn_ops)862 _assert_no_variables(self)863 _assert_summary_tags(self, ["loss"])864 expected_total_loss = .31326166865 _assert_metrics(866 self,867 expected_total_loss / len(weights),868 {869 "accuracy": 1. / 1,870 "accuracy/baseline_label_mean": 1. / 1,871 "accuracy/threshold_0.500000_mean": 1. / 1,872 "auc": 0. / 1,873 "labels/actual_label_mean": 1. / 1,874 "labels/prediction_mean": .731059, # softmax875 # eval loss is weighted loss divided by sum of weights.876 "loss": expected_total_loss,877 "precision/positive_threshold_0.500000_mean": 1. / 1,878 "recall/positive_threshold_0.500000_mean": 1. 
/ 1,879 },880 model_fn_ops)881 def testBinaryClassificationWith2DWeights(self):882 n_classes = 2883 head = head_lib.multi_class_head(884 n_classes=n_classes, weight_column_name="label_weight")885 with ops.Graph().as_default(), session.Session():886 weights = ((1.,), (0.,))887 # logloss: z:label, x:logit888 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))889 model_fn_ops = head.create_model_fn_ops(890 features={"label_weight": weights},891 labels=self._labels,892 mode=model_fn.ModeKeys.TRAIN,893 train_op_fn=head_lib.no_op_train_fn,894 logits=self._logits)895 self._assert_output_alternatives(model_fn_ops)896 _assert_no_variables(self)897 _assert_summary_tags(self, ["loss"])898 expected_total_loss = .31326166899 _assert_metrics(900 self,901 expected_total_loss / len(weights),902 {903 "accuracy": 1. / 1,904 "accuracy/baseline_label_mean": 1. / 1,905 "accuracy/threshold_0.500000_mean": 1. / 1,906 "auc": 0. / 1,907 "labels/actual_label_mean": 1. / 1,908 "labels/prediction_mean": .731059, # softmax909 # eval loss is weighted loss divided by sum of weights.910 "loss": expected_total_loss,911 "precision/positive_threshold_0.500000_mean": 1. / 1,912 "recall/positive_threshold_0.500000_mean": 1. 
/ 1,913 },914 model_fn_ops)915 def testBinaryClassificationWithCustomLoss(self):916 head = head_lib.multi_class_head(917 n_classes=2, weight_column_name="label_weight",918 loss_fn=_sigmoid_cross_entropy)919 with ops.Graph().as_default(), session.Session():920 weights = ((.2,), (0.,))921 model_fn_ops = head.create_model_fn_ops(922 features={"label_weight": weights},923 labels=self._labels,924 mode=model_fn.ModeKeys.TRAIN,925 train_op_fn=head_lib.no_op_train_fn,926 logits=self._logits)927 self._assert_output_alternatives(model_fn_ops)928 _assert_no_variables(self)929 _assert_summary_tags(self, ["loss"])930 # logloss: z:label, x:logit931 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))932 # expected_loss is (total_weighted_loss)/1 since there is 1 nonzero933 # weight.934 expected_loss = 0.062652342935 _assert_metrics(936 self,937 expected_loss,938 {939 "accuracy": 1. / 1,940 "accuracy/baseline_label_mean": 1. / 1,941 "accuracy/threshold_0.500000_mean": 1. / 1,942 "auc": 0. / 1,943 "labels/actual_label_mean": 1. / 1,944 "labels/prediction_mean": .731059, # softmax945 "loss": expected_loss,946 "precision/positive_threshold_0.500000_mean": 1. / 1,947 "recall/positive_threshold_0.500000_mean": 1. 
/ 1,948 },949 model_fn_ops)950 def testBinaryClassificationWithCenteredBias(self):951 head = head_lib.multi_class_head(n_classes=2, enable_centered_bias=True)952 with ops.Graph().as_default(), session.Session():953 # logloss: z:label, x:logit954 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))955 model_fn_ops = head.create_model_fn_ops(956 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,957 logits=self._logits)958 self._assert_output_alternatives(model_fn_ops)959 _assert_variables(960 self,961 expected_global=(962 "binary_logistic_head/centered_bias_weight:0",963 ("binary_logistic_head/binary_logistic_head/centered_bias_weight/"964 "Adagrad:0"),),965 expected_trainable=("binary_logistic_head/centered_bias_weight:0",))966 variables.global_variables_initializer().run()967 _assert_summary_tags(self, [968 "loss",969 "binary_logistic_head/centered_bias/bias_0"970 ])971 expected_loss = .81326175972 _assert_metrics(self, expected_loss,973 self._expected_eval_metrics(expected_loss), model_fn_ops)974class MultiClassHeadTest(test.TestCase):975 def _assert_output_alternatives(self, model_fn_ops):976 self.assertEquals({977 None: constants.ProblemType.CLASSIFICATION978 }, {979 k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)980 })981 def setUp(self):982 self._logits = ((1., 0., 0.),)983 self._labels = ((2,),)984 def _expected_eval_metrics(self, expected_loss):985 return {986 "accuracy": 0.,987 "loss": expected_loss,988 "labels/actual_label_mean/class0": 0. / 1,989 "labels/actual_label_mean/class1": 0. / 1,990 "labels/actual_label_mean/class2": 1. 
/ 1,991 "labels/logits_mean/class0": self._logits[0][0],992 "labels/logits_mean/class1": self._logits[0][1],993 "labels/logits_mean/class2": self._logits[0][2],994 "labels/prediction_mean/class0": self._logits[0][0],995 "labels/prediction_mean/class1": self._logits[0][1],996 "labels/prediction_mean/class2": self._logits[0][2],997 "labels/probability_mean/class0": 0.576117, # softmax998 "labels/probability_mean/class1": 0.211942, # softmax999 "labels/probability_mean/class2": 0.211942, # softmax1000 }1001 def testMultiClassWithLogits(self):1002 n_classes = 31003 head = head_lib.multi_class_head(1004 n_classes=n_classes, metric_class_ids=range(n_classes))1005 with ops.Graph().as_default(), session.Session():1006 # logloss: z:label, x:logit1007 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1008 model_fn_ops = head.create_model_fn_ops(1009 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,1010 logits=self._logits)1011 self._assert_output_alternatives(model_fn_ops)1012 _assert_no_variables(self)1013 _assert_summary_tags(self, ["loss"])1014 expected_loss = 1.55144471015 _assert_metrics(self, expected_loss,1016 self._expected_eval_metrics(expected_loss), model_fn_ops)1017 def testMultiClassWithInvalidLogits(self):1018 head = head_lib.multi_class_head(n_classes=len(self._logits[0]) + 1)1019 with ops.Graph().as_default(), session.Session():1020 with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):1021 head.create_model_fn_ops(1022 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,1023 logits=self._logits)1024 def testMultiClassWithNoneTrainOpFnInTrain(self):1025 head = head_lib.multi_class_head(n_classes=3)1026 with ops.Graph().as_default(), session.Session():1027 with self.assertRaisesRegexp(1028 ValueError, "train_op_fn can not be None in TRAIN mode"):1029 head.create_model_fn_ops(1030 {}, model_fn.ModeKeys.TRAIN, self._labels,1031 train_op_fn=None,1032 logits=self._logits)1033 def 
testMultiClassWithLogitsInput(self):1034 n_classes = 31035 head = head_lib.multi_class_head(1036 n_classes=n_classes, metric_class_ids=range(n_classes))1037 with ops.Graph().as_default(), session.Session():1038 # logloss: z:label, x:logit1039 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1040 model_fn_ops = head.create_model_fn_ops(1041 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,1042 logits_input=((0., 0.),))1043 self._assert_output_alternatives(model_fn_ops)1044 w = ("multi_class_head/logits/weights:0",1045 "multi_class_head/logits/biases:0")1046 _assert_variables(1047 self, expected_global=w, expected_model=w, expected_trainable=w)1048 variables.global_variables_initializer().run()1049 _assert_summary_tags(self, ["loss"])1050 expected_loss = 1.09861231051 _assert_metrics(self, expected_loss, {1052 "accuracy": 0.,1053 "loss": expected_loss,1054 "labels/actual_label_mean/class0": 0. / 1,1055 "labels/actual_label_mean/class1": 0. / 1,1056 "labels/actual_label_mean/class2": 1. 
/ 1,1057 "labels/logits_mean/class0": 0.,1058 "labels/logits_mean/class1": 0.,1059 "labels/logits_mean/class2": 0.,1060 "labels/prediction_mean/class0": 1.,1061 "labels/prediction_mean/class1": 0.,1062 "labels/prediction_mean/class2": 0.,1063 "labels/probability_mean/class0": 0.333333, # softmax1064 "labels/probability_mean/class1": 0.333333, # softmax1065 "labels/probability_mean/class2": 0.333333, # softmax1066 }, model_fn_ops)1067 def testMultiClassWithLogitsAndLogitsInput(self):1068 n_classes = 31069 head = head_lib.multi_class_head(1070 n_classes=n_classes, metric_class_ids=range(n_classes))1071 with ops.Graph().as_default(), session.Session():1072 with self.assertRaisesRegexp(1073 ValueError, "Both logits and logits_input supplied"):1074 head.create_model_fn_ops(1075 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,1076 logits_input=((0., 0.),), logits=self._logits)1077 def testMultiClassEnableCenteredBias(self):1078 n_classes = 31079 head = head_lib.multi_class_head(1080 n_classes=n_classes, enable_centered_bias=True)1081 with ops.Graph().as_default(), session.Session():1082 # logloss: z:label, x:logit1083 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1084 model_fn_ops = head.create_model_fn_ops(1085 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,1086 logits=self._logits)1087 self._assert_output_alternatives(model_fn_ops)1088 _assert_variables(1089 self,1090 expected_global=(1091 "multi_class_head/centered_bias_weight:0",1092 ("multi_class_head/multi_class_head/centered_bias_weight/"1093 "Adagrad:0"),1094 ),1095 expected_trainable=("multi_class_head/centered_bias_weight:0",))1096 variables.global_variables_initializer().run()1097 _assert_summary_tags(self,1098 ["loss",1099 "multi_class_head/centered_bias/bias_0",1100 "multi_class_head/centered_bias/bias_1",1101 "multi_class_head/centered_bias/bias_2"])1102 def testMultiClassEval(self):1103 n_classes = 31104 head = head_lib.multi_class_head(1105 
n_classes=n_classes, metric_class_ids=range(n_classes))1106 with ops.Graph().as_default(), session.Session():1107 # logloss: z:label, x:logit1108 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1109 model_fn_ops = head.create_model_fn_ops(1110 {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn,1111 logits=self._logits)1112 self._assert_output_alternatives(model_fn_ops)1113 self.assertIsNone(model_fn_ops.train_op)1114 _assert_no_variables(self)1115 _assert_summary_tags(self, ["loss"])1116 expected_loss = 1.55144471117 _assert_metrics(self, expected_loss,1118 self._expected_eval_metrics(expected_loss), model_fn_ops)1119 def testMultiClassEvalModeWithLargeLogits(self):1120 n_classes = 31121 head = head_lib.multi_class_head(1122 n_classes=n_classes, metric_class_ids=range(n_classes))1123 logits = ((2., 0., -1),)1124 with ops.Graph().as_default(), session.Session():1125 # logloss: z:label, x:logit1126 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1127 model_fn_ops = head.create_model_fn_ops(1128 {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn,1129 logits=logits)1130 self._assert_output_alternatives(model_fn_ops)1131 self.assertIsNone(model_fn_ops.train_op)1132 _assert_no_variables(self)1133 _assert_summary_tags(self, ["loss"])1134 expected_loss = 3.16984611135 expected_eval_metrics = {1136 "accuracy": 0.,1137 "loss": expected_loss,1138 "labels/actual_label_mean/class0": 0. / 1,1139 "labels/actual_label_mean/class1": 0. / 1,1140 "labels/actual_label_mean/class2": 1. 
/ 1,1141 "labels/logits_mean/class0": logits[0][0],1142 "labels/logits_mean/class1": logits[0][1],1143 "labels/logits_mean/class2": logits[0][2],1144 "labels/prediction_mean/class0": 1,1145 "labels/prediction_mean/class1": 0,1146 "labels/prediction_mean/class2": 0,1147 "labels/probability_mean/class0": 0.843795, # softmax1148 "labels/probability_mean/class1": 0.114195, # softmax1149 "labels/probability_mean/class2": 0.0420101, # softmax1150 }1151 _assert_metrics(self, expected_loss,1152 expected_eval_metrics, model_fn_ops)1153 def testMultiClassWithScalarWeight(self):1154 n_classes = 31155 head = head_lib.multi_class_head(1156 n_classes=n_classes,1157 weight_column_name="label_weight",1158 metric_class_ids=range(n_classes))1159 with ops.Graph().as_default(), session.Session():1160 weight = .11161 # logloss: z:label, x:logit1162 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1163 model_fn_ops = head.create_model_fn_ops(1164 features={"label_weight": weight},1165 labels=self._labels,1166 mode=model_fn.ModeKeys.TRAIN,1167 train_op_fn=head_lib.no_op_train_fn,1168 logits=self._logits)1169 self._assert_output_alternatives(model_fn_ops)1170 _assert_no_variables(self)1171 _assert_summary_tags(self, ["loss"])1172 expected_loss = 1.55144471173 _assert_metrics(self, expected_loss * weight,1174 self._expected_eval_metrics(expected_loss), model_fn_ops)1175 def testMultiClassWith1DWeight(self):1176 n_classes = 31177 head = head_lib.multi_class_head(1178 n_classes=n_classes,1179 weight_column_name="label_weight",1180 metric_class_ids=range(n_classes))1181 with ops.Graph().as_default(), session.Session():1182 weight = .11183 weights = (weight,)1184 # logloss: z:label, x:logit1185 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1186 model_fn_ops = head.create_model_fn_ops(1187 features={"label_weight": weights},1188 labels=self._labels,1189 mode=model_fn.ModeKeys.TRAIN,1190 train_op_fn=head_lib.no_op_train_fn,1191 logits=self._logits)1192 
self._assert_output_alternatives(model_fn_ops)1193 _assert_no_variables(self)1194 _assert_summary_tags(self, ["loss"])1195 expected_loss = 1.55144471196 _assert_metrics(self, expected_loss * weight,1197 self._expected_eval_metrics(expected_loss), model_fn_ops)1198 def testMultiClassWith2DWeight(self):1199 n_classes = 31200 head = head_lib.multi_class_head(1201 n_classes=n_classes,1202 weight_column_name="label_weight",1203 metric_class_ids=range(n_classes))1204 with ops.Graph().as_default(), session.Session():1205 weight = .11206 weights = ((weight,),)1207 # logloss: z:label, x:logit1208 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1209 model_fn_ops = head.create_model_fn_ops(1210 features={"label_weight": weights},1211 labels=self._labels,1212 mode=model_fn.ModeKeys.TRAIN,1213 train_op_fn=head_lib.no_op_train_fn,1214 logits=self._logits)1215 self._assert_output_alternatives(model_fn_ops)1216 _assert_no_variables(self)1217 _assert_summary_tags(self, ["loss"])1218 expected_loss = 1.55144471219 _assert_metrics(self, expected_loss * weight,1220 self._expected_eval_metrics(expected_loss), model_fn_ops)1221 def testMultiClassWithCustomLoss(self):1222 n_classes = 31223 head = head_lib.multi_class_head(1224 n_classes=n_classes,1225 weight_column_name="label_weight",1226 metric_class_ids=range(n_classes),1227 loss_fn=losses_lib.sparse_softmax_cross_entropy)1228 with ops.Graph().as_default(), session.Session():1229 weight = .11230 # logloss: z:label, x:logit1231 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1232 model_fn_ops = head.create_model_fn_ops(1233 features={"label_weight": weight},1234 labels=self._labels,1235 mode=model_fn.ModeKeys.TRAIN,1236 train_op_fn=head_lib.no_op_train_fn,1237 logits=self._logits)1238 self._assert_output_alternatives(model_fn_ops)1239 _assert_no_variables(self)1240 _assert_summary_tags(self, ["loss"])1241 expected_loss = 1.5514447 * weight1242 _assert_metrics(self, expected_loss,1243 
self._expected_eval_metrics(expected_loss), model_fn_ops)1244 def testMultiClassInfer(self):1245 n_classes = 31246 head = head_lib._multi_class_head(1247 n_classes=n_classes,1248 head_name="head_name")1249 with ops.Graph().as_default():1250 model_fn_ops = head.create_model_fn_ops(1251 features={},1252 mode=model_fn.ModeKeys.INFER,1253 train_op_fn=head_lib.no_op_train_fn,1254 logits=((1., 0., 0.), (0., 0., 1.),))1255 with session.Session():1256 lookup_ops.tables_initializer().run()1257 self.assertAllEqual(1258 [0, 2],1259 model_fn_ops.predictions["classes"].eval())1260 self.assertItemsEqual(1261 ["head_name"], six.iterkeys(model_fn_ops.output_alternatives))1262 self.assertEqual(1263 constants.ProblemType.CLASSIFICATION,1264 model_fn_ops.output_alternatives["head_name"][0])1265 predictions_for_serving = (1266 model_fn_ops.output_alternatives["head_name"][1])1267 self.assertIn("classes", six.iterkeys(predictions_for_serving))1268 self.assertAllEqual(1269 [[b"0", b"1", b"2"], [b"0", b"1", b"2"]],1270 predictions_for_serving["classes"].eval())1271 self.assertIn("probabilities", six.iterkeys(predictions_for_serving))1272 self.assertAllClose(1273 [[0.576117, 0.2119416, 0.2119416],1274 [0.2119416, 0.2119416, 0.576117]],1275 predictions_for_serving["probabilities"].eval())1276 def testInvalidNClasses(self):1277 for n_classes in (None, -1, 0, 1):1278 with self.assertRaisesRegexp(ValueError, "n_classes must be > 1"):1279 head_lib.multi_class_head(n_classes=n_classes)1280 def testMultiClassWithLabelKeysInvalidShape(self):1281 with self.assertRaisesRegexp(1282 ValueError, "Length of label_keys must equal n_classes"):1283 head_lib._multi_class_head(1284 n_classes=3, label_keys=("key0", "key1"))1285 def testMultiClassWithLabelKeysTwoClasses(self):1286 with self.assertRaisesRegexp(1287 ValueError, "label_keys is not supported for n_classes=2"):1288 head_lib._multi_class_head(1289 n_classes=2, label_keys=("key0", "key1"))1290 def testMultiClassWithLabelKeysInfer(self):1291 
n_classes = 31292 label_keys = ("key0", "key1", "key2")1293 head = head_lib._multi_class_head(1294 n_classes=n_classes, label_keys=label_keys,1295 metric_class_ids=range(n_classes),1296 head_name="head_name")1297 with ops.Graph().as_default():1298 model_fn_ops = head.create_model_fn_ops(1299 features={},1300 mode=model_fn.ModeKeys.INFER,1301 train_op_fn=head_lib.no_op_train_fn,1302 logits=((1., 0., 0.), (0., 0., 1.),))1303 with session.Session():1304 lookup_ops.tables_initializer().run()1305 self.assertAllEqual(1306 [b"key0", b"key2"],1307 model_fn_ops.predictions["classes"].eval())1308 self.assertItemsEqual(1309 ["head_name"], six.iterkeys(model_fn_ops.output_alternatives))1310 self.assertEqual(1311 constants.ProblemType.CLASSIFICATION,1312 model_fn_ops.output_alternatives["head_name"][0])1313 predictions_for_serving = (1314 model_fn_ops.output_alternatives["head_name"][1])1315 self.assertIn("classes", six.iterkeys(predictions_for_serving))1316 self.assertAllEqual(1317 [[b"key0", b"key1", b"key2"], [b"key0", b"key1", b"key2"]],1318 predictions_for_serving["classes"].eval())1319 self.assertIn("probabilities", six.iterkeys(predictions_for_serving))1320 self.assertAllClose(1321 [[0.576117, 0.2119416, 0.2119416],1322 [0.2119416, 0.2119416, 0.576117]],1323 predictions_for_serving["probabilities"].eval())1324 def testMultiClassWithLabelKeysEvalAccuracy0(self):1325 n_classes = 31326 label_keys = ("key0", "key1", "key2")1327 head = head_lib._multi_class_head(1328 n_classes=n_classes,1329 label_keys=label_keys)1330 with ops.Graph().as_default():1331 model_fn_ops = head.create_model_fn_ops(1332 features={},1333 mode=model_fn.ModeKeys.EVAL,1334 labels=("key2",),1335 train_op_fn=head_lib.no_op_train_fn,1336 logits=((1., 0., 0.),))1337 with session.Session():1338 lookup_ops.tables_initializer().run()1339 self.assertIsNone(model_fn_ops.train_op)1340 _assert_no_variables(self)1341 _assert_summary_tags(self, ["loss"])1342 expected_loss = 1.55144471343 expected_eval_metrics = 
{1344 "accuracy": 0.,1345 "loss": expected_loss,1346 }1347 _assert_metrics(self, expected_loss,1348 expected_eval_metrics, model_fn_ops)1349 def testMultiClassWithLabelKeysEvalAccuracy1(self):1350 n_classes = 31351 label_keys = ("key0", "key1", "key2")1352 head = head_lib._multi_class_head(1353 n_classes=n_classes,1354 label_keys=label_keys)1355 with ops.Graph().as_default():1356 model_fn_ops = head.create_model_fn_ops(1357 features={},1358 mode=model_fn.ModeKeys.EVAL,1359 labels=("key2",),1360 train_op_fn=head_lib.no_op_train_fn,1361 logits=((0., 0., 1.),))1362 with session.Session():1363 lookup_ops.tables_initializer().run()1364 self.assertIsNone(model_fn_ops.train_op)1365 _assert_no_variables(self)1366 _assert_summary_tags(self, ["loss"])1367 expected_loss = 0.55144471368 expected_eval_metrics = {1369 "accuracy": 1.,1370 "loss": expected_loss,1371 }1372 _assert_metrics(self, expected_loss,1373 expected_eval_metrics, model_fn_ops)1374class BinarySvmHeadTest(test.TestCase):1375 def _assert_output_alternatives(self, model_fn_ops):1376 self.assertEquals({1377 None: constants.ProblemType.LOGISTIC_REGRESSION1378 }, {1379 k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)1380 })1381 def setUp(self):1382 # Prediction for first example is in the right side of the hyperplane1383 # (i.e., < 0) but it is within the [-1,1] margin. There is a 0.5 loss1384 # incurred by this example. 
The 2nd prediction is outside the margin so it1385 # incurs no loss at all.1386 self._predictions = ((-.5,), (1.2,))1387 self._labels = (0, 1)1388 self._expected_losses = (.5, 0.)1389 def testBinarySVMWithLogits(self):1390 head = head_lib.binary_svm_head()1391 with ops.Graph().as_default(), session.Session():1392 model_fn_ops = head.create_model_fn_ops(1393 {},1394 model_fn.ModeKeys.TRAIN,1395 self._labels,1396 head_lib.no_op_train_fn,1397 logits=self._predictions)1398 self._assert_output_alternatives(model_fn_ops)1399 _assert_no_variables(self)1400 _assert_summary_tags(self, ["loss"])1401 expected_loss = np.average(self._expected_losses)1402 _assert_metrics(self, expected_loss, {1403 "accuracy": 1.,1404 "loss": expected_loss,1405 }, model_fn_ops)1406 def testBinarySVMWithInvalidLogits(self):1407 head = head_lib.binary_svm_head()1408 with ops.Graph().as_default(), session.Session():1409 with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):1410 head.create_model_fn_ops(1411 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,1412 logits=np.ones((2, 2)))1413 def testBinarySVMWithLogitsInput(self):1414 head = head_lib.binary_svm_head()1415 with ops.Graph().as_default(), session.Session():1416 model_fn_ops = head.create_model_fn_ops(1417 {},1418 model_fn.ModeKeys.TRAIN,1419 self._labels,1420 head_lib.no_op_train_fn,1421 logits_input=((0., 0.), (0., 0.)))1422 self._assert_output_alternatives(model_fn_ops)1423 w = ("binary_svm_head/logits/weights:0",1424 "binary_svm_head/logits/biases:0")1425 _assert_variables(1426 self, expected_global=w, expected_model=w, expected_trainable=w)1427 variables.global_variables_initializer().run()1428 _assert_summary_tags(self, ["loss"])1429 expected_loss = 1.1430 _assert_metrics(self, expected_loss, {1431 "accuracy": .5,1432 "loss": expected_loss,1433 }, model_fn_ops)1434 def testBinarySVMWithLogitsAndLogitsInput(self):1435 head = head_lib.binary_svm_head()1436 with ops.Graph().as_default(), 
session.Session():1437 with self.assertRaisesRegexp(1438 ValueError, "Both logits and logits_input supplied"):1439 head.create_model_fn_ops(1440 {},1441 model_fn.ModeKeys.TRAIN,1442 self._labels,1443 head_lib.no_op_train_fn,1444 logits_input=((0., 0.), (0., 0.)),1445 logits=self._predictions)1446 def testBinarySVMEvalMode(self):1447 head = head_lib.binary_svm_head()1448 with ops.Graph().as_default(), session.Session():1449 model_fn_ops = head.create_model_fn_ops(1450 {},1451 model_fn.ModeKeys.EVAL,1452 self._labels,1453 head_lib.no_op_train_fn,1454 logits=self._predictions)1455 self._assert_output_alternatives(model_fn_ops)1456 self.assertIsNone(model_fn_ops.train_op)1457 _assert_no_variables(self)1458 _assert_summary_tags(self, ["loss"])1459 expected_loss = np.average(self._expected_losses)1460 _assert_metrics(self, expected_loss, {1461 "accuracy": 1.,1462 "loss": expected_loss,1463 }, model_fn_ops)1464 def testBinarySVMWithLabelName(self):1465 label_name = "my_label"1466 head = head_lib.binary_svm_head(label_name=label_name)1467 with ops.Graph().as_default(), session.Session():1468 model_fn_ops = head.create_model_fn_ops(1469 {},1470 model_fn.ModeKeys.TRAIN,1471 {label_name: self._labels},1472 head_lib.no_op_train_fn,1473 logits=self._predictions)1474 self._assert_output_alternatives(model_fn_ops)1475 _assert_no_variables(self)1476 _assert_summary_tags(self, ["loss"])1477 expected_loss = np.average(self._expected_losses)1478 _assert_metrics(self, expected_loss, {1479 "accuracy": 1.,1480 "loss": expected_loss,1481 }, model_fn_ops)1482 def testBinarySVMWith1DWeights(self):1483 head = head_lib.binary_svm_head(weight_column_name="weights")1484 with ops.Graph().as_default(), session.Session():1485 weights = (7., 11.)1486 model_fn_ops = head.create_model_fn_ops(1487 # We have to add an extra dim here for weights broadcasting to work.1488 features={"weights": weights},1489 mode=model_fn.ModeKeys.TRAIN,1490 labels=self._labels,1491 
train_op_fn=head_lib.no_op_train_fn,1492 logits=self._predictions)1493 self._assert_output_alternatives(model_fn_ops)1494 _assert_no_variables(self)1495 _assert_summary_tags(self, ["loss"])1496 expected_weighted_losses = np.multiply(weights, self._expected_losses)1497 _assert_metrics(self, np.mean(expected_weighted_losses), {1498 "accuracy": 1.,1499 "loss": np.sum(expected_weighted_losses) / np.sum(weights),1500 }, model_fn_ops)1501 def testBinarySVMWith2DWeights(self):1502 head = head_lib.binary_svm_head(weight_column_name="weights")1503 with ops.Graph().as_default(), session.Session():1504 weights = (7., 11.)1505 model_fn_ops = head.create_model_fn_ops(1506 # We have to add an extra dim here for weights broadcasting to work.1507 features={"weights": tuple([(w,) for w in weights])},1508 mode=model_fn.ModeKeys.TRAIN,1509 labels=self._labels,1510 train_op_fn=head_lib.no_op_train_fn,1511 logits=self._predictions)1512 self._assert_output_alternatives(model_fn_ops)1513 _assert_no_variables(self)1514 _assert_summary_tags(self, ["loss"])1515 expected_weighted_losses = np.multiply(weights, self._expected_losses)1516 _assert_metrics(self, np.mean(expected_weighted_losses), {1517 "accuracy": 1.,1518 "loss": np.sum(expected_weighted_losses) / np.sum(weights),1519 }, model_fn_ops)1520 def testBinarySVMWithCenteredBias(self):1521 head = head_lib.binary_svm_head(enable_centered_bias=True)1522 with ops.Graph().as_default(), session.Session():1523 model_fn_ops = head.create_model_fn_ops(1524 {},1525 model_fn.ModeKeys.TRAIN,1526 self._labels,1527 head_lib.no_op_train_fn,1528 logits=self._predictions)1529 self._assert_output_alternatives(model_fn_ops)1530 _assert_variables(1531 self,1532 expected_global=(1533 "binary_svm_head/centered_bias_weight:0",1534 ("binary_svm_head/binary_svm_head/centered_bias_weight/"1535 "Adagrad:0"),1536 ),1537 expected_trainable=("binary_svm_head/centered_bias_weight:0",))1538 variables.global_variables_initializer().run()1539 _assert_summary_tags(self, 
[1540 "loss",1541 "binary_svm_head/centered_bias/bias_0"1542 ])1543 expected_loss = np.average(self._expected_losses)1544 _assert_metrics(self, expected_loss, {1545 "accuracy": 1.,1546 "loss": expected_loss,1547 }, model_fn_ops)1548class LossOnlyHead(test.TestCase):1549 def testNoPredictionsAndNoMetrics(self):1550 head = head_lib.loss_only_head(lambda: 1, head_name="const")1551 model_fn_ops = head.create_model_fn_ops(1552 features={},1553 mode=model_fn.ModeKeys.TRAIN,1554 train_op_fn=head_lib.no_op_train_fn)1555 self.assertDictEqual(model_fn_ops.predictions, {})1556 self.assertDictEqual(model_fn_ops.eval_metric_ops, {})1557 self.assertIsNotNone(model_fn_ops.loss)1558 with session.Session() as sess:1559 self.assertEqual(1, sess.run(model_fn_ops.loss))1560class MultiHeadTest(test.TestCase):1561 def testInvalidHeads(self):1562 named_head = head_lib.multi_class_head(1563 n_classes=3, label_name="label", head_name="head1")1564 unnamed_head = head_lib.multi_class_head(1565 n_classes=4, label_name="label")1566 with self.assertRaisesRegexp(ValueError, "must have names"):1567 head_lib.multi_head((named_head, unnamed_head))1568 def testTrainWithNoneTrainOpFn(self):1569 head1 = head_lib.multi_class_head(1570 n_classes=3, label_name="label1", head_name="head1")1571 head2 = head_lib.multi_class_head(1572 n_classes=4, label_name="label2", head_name="head2")1573 head = head_lib.multi_head((head1, head2))1574 labels = {1575 "label1": (1,),1576 "label2": (1,)1577 }1578 with self.assertRaisesRegexp(1579 ValueError, "train_op_fn can not be None in TRAIN mode"):1580 head.create_model_fn_ops(1581 features={"weights": (2.0, 10.0)},1582 labels=labels,1583 mode=model_fn.ModeKeys.TRAIN,1584 train_op_fn=None,1585 logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))1586 def testTrain_withNoHeadWeights(self):1587 head1 = head_lib.multi_class_head(1588 n_classes=3, label_name="label1", head_name="head1")1589 head2 = head_lib.multi_class_head(1590 n_classes=4, label_name="label2", head_name="head2")1591 
head3 = head_lib.loss_only_head(lambda: 1.0, head_name="const")1592 head = head_lib.multi_head((head1, head2, head3))1593 labels = {1594 "label1": (1,),1595 "label2": (1,)1596 }1597 model_fn_ops = head.create_model_fn_ops(1598 features={"weights": (2.0, 10.0)},1599 labels=labels,1600 mode=model_fn.ModeKeys.TRAIN,1601 train_op_fn=head_lib.no_op_train_fn,1602 logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))1603 self.assertIsNone(model_fn_ops.predictions)1604 self.assertIsNotNone(model_fn_ops.loss)1605 self.assertIsNotNone(model_fn_ops.train_op)1606 self.assertTrue(model_fn_ops.eval_metric_ops)1607 self.assertIsNone(model_fn_ops.output_alternatives)1608 with session.Session() as sess:1609 self.assertAlmostEqual(3.224, sess.run(model_fn_ops.loss), places=3)1610 def testTrain_withHeadWeights(self):1611 head1 = head_lib.multi_class_head(1612 n_classes=3, label_name="label1", head_name="head1")1613 head2 = head_lib.multi_class_head(1614 n_classes=4, label_name="label2", head_name="head2")1615 head = head_lib.multi_head((head1, head2), (1, .5))1616 labels = {1617 "label1": (1,),1618 "label2": (1,)1619 }1620 model_fn_ops = head.create_model_fn_ops(1621 features={"weights": (2.0, 10.0)},1622 labels=labels,1623 mode=model_fn.ModeKeys.TRAIN,1624 train_op_fn=head_lib.no_op_train_fn,1625 logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))1626 self.assertIsNone(model_fn_ops.predictions)1627 self.assertIsNotNone(model_fn_ops.loss)1628 self.assertIsNotNone(model_fn_ops.train_op)1629 self.assertTrue(model_fn_ops.eval_metric_ops)1630 self.assertIsNone(model_fn_ops.output_alternatives)1631 with session.Session() as sess:1632 self.assertAlmostEqual(1.531, sess.run(model_fn_ops.loss), places=3)1633 def testTrain_withDictLogits(self):1634 head1 = head_lib.multi_class_head(1635 n_classes=3, label_name="label1", head_name="head1")1636 head2 = head_lib.multi_class_head(1637 n_classes=4, label_name="label2", head_name="head2")1638 head = head_lib.multi_head((head1, head2))1639 labels = {1640 "label1": 
(1,),1641 "label2": (1,)1642 }1643 model_fn_ops = head.create_model_fn_ops(1644 features={"weights": (2.0, 10.0)},1645 labels=labels,1646 mode=model_fn.ModeKeys.TRAIN,1647 train_op_fn=head_lib.no_op_train_fn,1648 logits={head1.head_name: ((-0.7, 0.2, .1),),1649 head2.head_name: ((.1, .1, .1, .1),)})1650 self.assertIsNone(model_fn_ops.predictions)1651 self.assertIsNotNone(model_fn_ops.loss)1652 self.assertIsNotNone(model_fn_ops.train_op)1653 self.assertTrue(model_fn_ops.eval_metric_ops)1654 self.assertIsNone(model_fn_ops.output_alternatives)1655 with session.Session() as sess:1656 self.assertAlmostEqual(2.224, sess.run(model_fn_ops.loss), places=3)1657 def testInfer(self):1658 head1 = head_lib.multi_class_head(1659 n_classes=3, label_name="label1", head_name="head1")1660 head2 = head_lib.multi_class_head(1661 n_classes=4, label_name="label2", head_name="head2")1662 head = head_lib.multi_head((head1, head2), (1, .5))1663 labels = {1664 "label1": (1,),1665 "label2": (1,)1666 }1667 model_fn_ops = head.create_model_fn_ops(1668 features={"weights": (2.0, 10.0)},1669 labels=labels,1670 mode=model_fn.ModeKeys.INFER,1671 train_op_fn=head_lib.no_op_train_fn,1672 logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))1673 self.assertIsNotNone(model_fn_ops.predictions)1674 self.assertIsNone(model_fn_ops.loss)1675 self.assertIsNone(model_fn_ops.train_op)1676 self.assertFalse(model_fn_ops.eval_metric_ops)1677 # Tests predictions keys.1678 self.assertItemsEqual((1679 ("head1", prediction_key.PredictionKey.LOGITS),1680 ("head1", prediction_key.PredictionKey.PROBABILITIES),1681 ("head1", prediction_key.PredictionKey.CLASSES),1682 ("head2", prediction_key.PredictionKey.LOGITS),1683 ("head2", prediction_key.PredictionKey.PROBABILITIES),1684 ("head2", prediction_key.PredictionKey.CLASSES),1685 ), model_fn_ops.predictions.keys())1686 # Tests output alternative.1687 self.assertEquals({1688 "head1": constants.ProblemType.CLASSIFICATION,1689 "head2": constants.ProblemType.CLASSIFICATION,1690 }, {1691 
k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)1692 })1693 self.assertItemsEqual((1694 prediction_key.PredictionKey.PROBABILITIES,1695 prediction_key.PredictionKey.CLASSES,1696 ), model_fn_ops.output_alternatives["head1"][1].keys())1697 self.assertItemsEqual((1698 prediction_key.PredictionKey.PROBABILITIES,1699 prediction_key.PredictionKey.CLASSES,1700 ), model_fn_ops.output_alternatives["head2"][1].keys())1701 def testEval(self):1702 head1 = head_lib.multi_class_head(1703 n_classes=3, label_name="label1", head_name="head1")1704 head2 = head_lib.multi_class_head(1705 n_classes=4, label_name="label2", head_name="head2")1706 head = head_lib.multi_head((head1, head2), (1, .5))1707 labels = {1708 "label1": (1,),1709 "label2": (1,)1710 }1711 model_fn_ops = head.create_model_fn_ops(1712 features={"weights": (2.0, 10.0)},1713 labels=labels,1714 mode=model_fn.ModeKeys.EVAL,1715 train_op_fn=head_lib.no_op_train_fn,1716 logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))1717 self.assertIsNotNone(model_fn_ops.predictions)1718 self.assertIsNotNone(model_fn_ops.loss)1719 self.assertIsNone(model_fn_ops.train_op)1720 self.assertIsNotNone(model_fn_ops.eval_metric_ops)...

Full Screen

Full Screen

multi_head.py

Source: multi_head.py (GitHub)

copy

Full Screen

...11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.12# See the License for the specific language governing permissions and13# limitations under the License.14# ==============================================================================15"""Abstractions for the head(s) of a model.16"""17from __future__ import absolute_import18from __future__ import division19from __future__ import print_function20import six21from tensorflow.python.estimator import model_fn22from tensorflow.python.estimator.canned import head as head_lib23from tensorflow.python.estimator.canned import metric_keys24from tensorflow.python.estimator.export import export_output as export_output_lib25from tensorflow.python.framework import ops26from tensorflow.python.ops import array_ops27from tensorflow.python.ops import control_flow_ops28from tensorflow.python.ops import math_ops29from tensorflow.python.ops import metrics as metrics_lib30from tensorflow.python.saved_model import signature_constants31from tensorflow.python.summary import summary32from tensorflow.python.training import training_util33_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY34def multi_head(heads, head_weights=None):35 """Creates a `_Head` for multi-objective learning.36 This class merges the output of multiple `_Head` objects.37 Specifically:38 * For training, sums losses of each head, calls `train_op_fn` with this39 final loss.40 * For eval, merges metrics by adding `head.name` suffix to the keys in eval41 metrics, such as `precision/head1`, `precision/head2`.42 * For prediction, merges predictions and updates keys in prediction dict to a43 2-tuple, `(head.name, prediction_key)`. 
Merges `export_outputs` such that44 by default the first head is served.45 Usage:46 ```python47 # In `input_fn` specify labels as a dict keyed by head name:48 def input_fn():49 features = ...50 labels1 = ...51 labels2 = ...52 return features, {'head1': labels1, 'head2': labels2}53 # In `model_fn`, specify logits as a dict keyed by head name:54 def model_fn(features, labels, mode):55 # Create simple heads and specify head name.56 head1 = multi_class_head(n_classes=3, name='head1')57 head2 = binary_classification_head(name='head2')58 # Create multi-head from two simple heads.59 head = multi_head([head1, head2])60 # Create logits for each head, and combine them into a dict.61 logits1, logits2 = logit_fn()62 logits = {'head1': logits1, 'head2': logits2}63 # Return the merged EstimatorSpec64 return head.create_estimator_spec(..., logits=logits, ...)65 # Create an estimator with this model_fn.66 estimator = tf.estimator.Estimator(model_fn=model_fn)67 estimator.train(input_fn=input_fn, steps=100)68 ```69 Also supports `logits` as a `Tensor` of shape70 `[D0, D1, ... DN, logits_dimension]`. It will split the `Tensor` along the71 last dimension and distribute it appropriately among the heads. E.g.:72 ```python73 def model_fn(features, labels, mode):74 # Create simple heads and specify head name.75 head1 = multi_class_head(n_classes=3, name='head1')76 head2 = binary_classification_head(name='head2')77 # Create multi-head from two simple heads.78 head = multi_head([head1, head2])79 # Create logits for the multihead.80 logits = logit_fn(logits_dimension=head.logits_dimension)81 # Return the merged EstimatorSpec82 return head.create_estimator_spec(..., logits=logits, ...)83 ```84 Args:85 heads: List or tuple of `_Head` instances. All heads must have `name`86 specified. The first head in the list is the default used at serving time.87 head_weights: Optional list of weights, same length as `heads`. Used when88 merging losses to calculate the weighted sum of losses from each head. 
If89 `None`, all losses are weighted equally.90 Returns:91 A instance of `_Head` that merges multiple heads.92 Raises:...

Full Screen

Full Screen

net_HBS_solo_resnet18_score_fuse_FFM4_c2_nonlocal_ly24_mout_rnn_ly1_h1_ly4321.py

Source:net_HBS_solo_resnet18_score_fuse_FFM4_c2_nonlocal_ly24_mout_rnn_ly1_h1_ly4321.py Github

copy

Full Screen

1import torch2import torch.nn as nn3import torch.nn.functional as F4from net.resnet import resnet34, resnet185import copy6from torch.nn import init7from IPython import embed8from auxi.module import FFM_v4, RNNModule9def weights_init_kaiming(m):10 classname = m.__class__.__name__11 if classname.find('Conv') != -1:12 nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')13 elif classname.find('Linear') != -1:14 nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')15 nn.init.constant_(m.bias.data, 0.0)16 elif classname.find('BatchNorm1d') != -1:17 nn.init.normal_(m.weight.data, 1.0, 0.02)18 nn.init.constant_(m.bias.data, 0.0)19class net(nn.Module):20 def __init__(self, config):21 super(net, self).__init__()22 self.net_head = resnet18(pretrained=True, num_classes=26)23 self.head_layer1 = nn.Sequential(*list(self.net_head.children())[:5])24 self.head_layer2 = list(self.net_head.children())[5]25 self.head_layer3 = list(self.net_head.children())[6]26 self.head_layer4 = list(self.net_head.children())[7]27 self.head_cls = list(self.net_head.children())[-1]28 self.net_body = resnet18(pretrained=True, num_classes=26)29 self.body_layer1 = nn.Sequential(*list(self.net_body.children())[:5])30 self.body_layer2 = list(self.net_body.children())[5]31 self.body_layer3 = list(self.net_body.children())[6]32 self.body_layer4 = list(self.net_body.children())[7]33 self.body_cls = list(self.net_body.children())[-1]34 self.net_scene = resnet18(pretrained=True, num_classes=26)35 self.scene_layer1 = nn.Sequential(*list(self.net_scene.children())[:5])36 self.scene_layer2 = list(self.net_scene.children())[5]37 self.scene_layer3 = list(self.net_scene.children())[6]38 self.scene_layer4 = list(self.net_scene.children())[7]39 self.scene_cls = list(self.net_scene.children())[-1]40 self.fc3 = nn.Sequential(nn.Linear(512*3 , 128), nn.Linear(128, 26))41 self.csr_ly1 = RNNModule(input_size=64, hidden_size=64, num_layers=1)42 self.csr_ly2 = RNNModule(input_size=128, hidden_size=128, 
num_layers=1)43 self.csr_ly3 = RNNModule(input_size=256, hidden_size=256, num_layers=1)44 self.csr_ly4 = RNNModule(input_size=512, hidden_size=512, num_layers=1)45 self.ffm_head_ly2 = FFM_v4(dimension=2, in_channel=128, inter_channel=128*2)46 self.ffm_body_ly2 = FFM_v4(dimension=2, in_channel=128, inter_channel=128*2)47 self.ffm_scene_ly2 = FFM_v4(dimension=2, in_channel=128, inter_channel=128*2)48 self.ffm_head_ly4 = FFM_v4(dimension=2, in_channel=512, inter_channel=512*2)49 self.ffm_body_ly4 = FFM_v4(dimension=2, in_channel=512, inter_channel=512*2)50 self.ffm_scene_ly4 = FFM_v4(dimension=2, in_channel=512, inter_channel=512*2)51 def forward(self, data, mode):52 head_ly1 = self.head_layer1(data['image_head'])53 body_ly1 = self.body_layer1(data['image_body'])54 scene_ly1 = self.scene_layer1(data['image_scene'])55 head_ly1_ = F.adaptive_avg_pool2d(head_ly1, (1,1)).view(head_ly1.size(0), -1)56 body_ly1_ = F.adaptive_avg_pool2d(body_ly1, (1,1)).view(body_ly1.size(0), -1)57 scene_ly1_ = F.adaptive_avg_pool2d(scene_ly1, (1,1)).view(scene_ly1.size(0), -1)58 feature = self.csr_ly1(torch.stack((head_ly1_, body_ly1_, scene_ly1_), dim=1))59 head_ly1 = head_ly1+feature[:,0,:].view(head_ly1.size(0),head_ly1.size(1),1,1).expand_as(head_ly1)60 body_ly1 = body_ly1+feature[:,1,:].view(body_ly1.size(0),body_ly1.size(1),1,1).expand_as(body_ly1)61 scene_ly1 = scene_ly1+feature[:,2,:].view(scene_ly1.size(0),scene_ly1.size(1),1,1).expand_as(scene_ly1)62 head_ly2 = self.head_layer2(head_ly1)63 body_ly2 = self.body_layer2(body_ly1)64 scene_ly2 = self.scene_layer2(scene_ly1)65 head_ly2_ = F.adaptive_avg_pool2d(head_ly2, (1,1)).view(head_ly2.size(0), -1)66 body_ly2_ = F.adaptive_avg_pool2d(body_ly2, (1,1)).view(body_ly2.size(0), -1)67 scene_ly2_ = F.adaptive_avg_pool2d(scene_ly2, (1,1)).view(scene_ly2.size(0), -1)68 feature = self.csr_ly2(torch.stack((head_ly2_, body_ly2_, scene_ly2_), dim=1))69 head_ly2 = 
head_ly2+feature[:,0,:].view(head_ly2.size(0),head_ly2.size(1),1,1).expand_as(head_ly2)70 body_ly2 = body_ly2+feature[:,1,:].view(body_ly2.size(0),body_ly2.size(1),1,1).expand_as(body_ly2)71 scene_ly2 = scene_ly2+feature[:,2,:].view(scene_ly2.size(0),scene_ly2.size(1),1,1).expand_as(scene_ly2)72 head_ly2 = self.ffm_head_ly2(head_ly2, body_ly2, scene_ly2) + head_ly273 body_ly2 = self.ffm_body_ly2(body_ly2, head_ly2, scene_ly2) + body_ly274 scene_ly2 = self.ffm_scene_ly2(scene_ly2, head_ly2, body_ly2) + scene_ly275 head_ly3 = self.head_layer3(head_ly2)76 body_ly3 = self.body_layer3(body_ly2)77 scene_ly3 = self.scene_layer3(scene_ly2)78 head_ly3_ = F.adaptive_avg_pool2d(head_ly3, (1,1)).view(head_ly3.size(0), -1)79 body_ly3_ = F.adaptive_avg_pool2d(body_ly3, (1,1)).view(body_ly3.size(0), -1)80 scene_ly3_ = F.adaptive_avg_pool2d(scene_ly3, (1,1)).view(scene_ly3.size(0), -1)81 feature = self.csr_ly3(torch.stack((head_ly3_, body_ly3_, scene_ly3_), dim=1))82 head_ly3 = head_ly3+feature[:,0,:].view(head_ly3.size(0),head_ly3.size(1),1,1).expand_as(head_ly3)83 body_ly3 = body_ly3+feature[:,1,:].view(body_ly3.size(0),body_ly3.size(1),1,1).expand_as(body_ly3)84 scene_ly3 = scene_ly3+feature[:,2,:].view(scene_ly3.size(0),scene_ly3.size(1),1,1).expand_as(scene_ly3)85 head_ly4 = self.head_layer4(head_ly3)86 body_ly4 = self.body_layer4(body_ly3)87 scene_ly4 = self.scene_layer4(scene_ly3)88 head_ly4_ = F.adaptive_avg_pool2d(head_ly4, (1,1)).view(head_ly4.size(0), -1)89 body_ly4_ = F.adaptive_avg_pool2d(body_ly4, (1,1)).view(body_ly4.size(0), -1)90 scene_ly4_ = F.adaptive_avg_pool2d(scene_ly4, (1,1)).view(scene_ly4.size(0), -1)91 feature = self.csr_ly4(torch.stack((head_ly4_, body_ly4_, scene_ly4_), dim=1))92 head_ly4 = head_ly4+feature[:,0,:].view(head_ly4.size(0),head_ly4.size(1),1,1).expand_as(head_ly4)93 body_ly4 = body_ly4+feature[:,1,:].view(body_ly4.size(0),body_ly4.size(1),1,1).expand_as(body_ly4)94 scene_ly4 = 
scene_ly4+feature[:,2,:].view(scene_ly4.size(0),scene_ly4.size(1),1,1).expand_as(scene_ly4)95 head_ly4 = self.ffm_head_ly4(head_ly4, body_ly4, scene_ly4) + head_ly496 body_ly4 = self.ffm_body_ly4(body_ly4, head_ly4, scene_ly4) + body_ly497 scene_ly4 = self.ffm_scene_ly4(scene_ly4, head_ly4, body_ly4) + scene_ly498 head_ly4 = F.adaptive_avg_pool2d(head_ly4, (1,1)).view(head_ly4.size(0), -1)99 body_ly4 = F.adaptive_avg_pool2d(body_ly4, (1,1)).view(body_ly4.size(0), -1)100 scene_ly4 = F.adaptive_avg_pool2d(scene_ly4, (1,1)).view(scene_ly4.size(0), -1)101 out_head = self.head_cls(head_ly4)102 out_body = self.body_cls(body_ly4)103 out_scene = self.scene_cls(scene_ly4)104 out = self.fc3(torch.cat((head_ly4, body_ly4, scene_ly4), dim=1))...

Full Screen

Full Screen

head_tail.py

Source:head_tail.py Github

copy

Full Screen

"""Utilities to manage head and tail of elements.

The goal is to avoid losing any part of the original text in the final tree.
"""
from .tree import Item


class TokenValue:
    """Wrapper around a raw token value that also records position and surrounding text.

    ``head`` and ``tail`` receive text found just before / just after the token
    (see :py:meth:`HeadTailLexer.handle_token`); ``pos`` and ``size`` are filled
    in by the same method.
    """

    def __init__(self, value):
        self.value = value
        self.pos = None
        self.size = None
        self.head = ""
        self.tail = ""

    def __repr__(self):
        return "TokenValue(%s)" % self.value

    def __str__(self):
        return self.value


class HeadTailLexer:
    """Utility to handle head and tail at lexer time."""

    LEXER_ATTR = "_luqum_headtail"

    @classmethod
    def handle(cls, token, orig_value):
        """Handle a token.

        .. note::
          PLY does not give access to previous tokens,
          nor does it provide any infrastructure for keeping specific state.

          So we use the strategy
          of putting a :py:class:`HeadTailLexer` instance as an attribute of the lexer
          each time we start a new tokenization.

        :param token: the token being produced by the lexer
        :param orig_value: the original matched text, used to compute the token size
        """
        if token.lexpos == 0:
            # first token of a new tokenization: make a fresh instance
            instance = cls()
            setattr(token.lexer, cls.LEXER_ATTR, instance)
        else:
            instance = getattr(token.lexer, cls.LEXER_ATTR)
        instance.handle_token(token, orig_value)

    def __init__(self):
        #: pending head for the next element; only set by a separator found at
        #: the very start of the expression (lexpos 0)
        self.head = None
        #: last non-separator token seen, so that following separators can be
        #: appended to its tail
        self.last_elt = None

    def handle_token(self, token, orig_value):
        """Handle head and tail for tokens.

        Separator text is stored either as the head of the next token (when it
        is at the start of the expression) or appended to the tail of the last
        processed token, so no part of the original text is lost.
        """
        if token.type == "SEPARATOR":
            if token.lexpos == 0:
                # spaces at expression start, head for next token
                self.head = token.value
            elif self.last_elt is not None:
                # tail of last processed token
                self.last_elt.value.tail += token.value
        else:
            # if there is a pending head, apply it
            if self.head is not None:
                token.value.head = self.head
                self.head = None
            # keep track of token, to apply tail later
            self.last_elt = token
        # also set pos and size
        if isinstance(token.value, (Item, TokenValue)):
            token.value.pos = token.lexpos
            token.value.size = len(orig_value)


# module-level alias for HeadTailLexer.handle
token_headtail = HeadTailLexer.handle


class HeadTailManager:
    """Utility to handle head and tail at expression parse time."""

    def pos(self, p, head_transfer=False, tail_transfer=False):
        """Compute pos and size of element 0 based on its parts (p[1:]).

        :param list p: the parser expression as in PLY
        :param bool head_transfer: True if head of first child will be transferred to p[0]
        :param bool tail_transfer: True if tail of last child will be transferred to p[0]
        """
        first = p[1]
        # pos
        if first.pos is not None:
            p[0].pos = first.pos
            if not head_transfer:
                # head isn't transferred, so p[0] starts before it
                # (guard against a None head, as the size computation below does)
                p[0].pos -= len(first.head or "")
        # size: every part, counting its own head and tail
        p[0].size = sum(
            (elt.size or 0) + len(elt.head or "") + len(elt.tail or "") for elt in p[1:])
        if head_transfer and first.head:
            # the caller moves this head onto p[0].head, so it is not part of the size
            p[0].size -= len(first.head)
        last_p = p[len(p) - 1]  # negative indexing not supported by PLY
        if tail_transfer and last_p.tail:
            # the caller moves this tail onto p[0].tail, so it is not part of the size
            p[0].size -= len(last_p.tail)

    def binary_operation(self, p, op_tail):
        """Set pos and size for a binary operation.

        :param str op_tail: text whose length is removed from the computed size
        """
        self.pos(p, head_transfer=False, tail_transfer=False)
        # correct size
        p[0].size -= len(op_tail)

    def simple_term(self, p):
        """Single term: p[0] takes over p[1]'s head and tail."""
        self.pos(p, head_transfer=True, tail_transfer=True)
        p[0].head = p[1].head
        p[0].tail = p[1].tail

    def unary(self, p):
        """OP expr

        The operator's head becomes p[0]'s head; the operator's tail is
        prepended to the expression's head.
        """
        self.pos(p, head_transfer=True, tail_transfer=False)
        p[0].head = p[1].head
        p[2].head = p[1].tail + p[2].head

    def post_unary(self, p):
        """expr OP

        The operator's head is appended to the expression's tail; the
        operator's tail becomes p[0]'s tail.
        """
        self.pos(p, head_transfer=False, tail_transfer=True)
        p[1].tail += p[2].head
        p[0].tail = p[2].tail

    def paren(self, p):
        """( expr )"""
        self.pos(p, head_transfer=True, tail_transfer=True)
        # p[0] is global element (Group or FieldGroup)
        # p[1] is left parenthesis, p[2] is content, p[3] is right parenthesis
        p[0].head = p[1].head
        p[2].head = p[1].tail + p[2].head
        p[2].tail += p[3].head
        p[0].tail = p[3].tail

    def range(self, p):
        """[ expr TO expr ]"""
        self.pos(p, head_transfer=True, tail_transfer=True)
        # p[0] is global element (Range)
        # p[1] is lower bracket, p[2] lower bound, p[3] TO,
        # p[4] upper bound, p[5] upper bracket
        p[0].head = p[1].head
        p[2].head = p[1].tail + p[2].head
        p[2].tail += p[3].head
        p[4].head = p[3].tail + p[4].head
        p[4].tail += p[5].head
        p[0].tail = p[5].tail

    def search_field(self, p):
        """name: expr"""
        self.pos(p, head_transfer=True, tail_transfer=False)
        # p[0] is global element (SearchField)
        # p[1] is search field name, p[2] is COLUMN, p[3] is the expression
        p[0].head = p[1].head
        # FIXME: text between the field name and the colon (p[1].tail, p[2].head)
        # is not propagated here; add a warning, or handle it in SearchField?
        p[3].head = p[2].tail + p[3].head


head_tail = HeadTailManager()
"""Singleton of :py:class:`HeadTailManager`."""

Full Screen

Full Screen

Automation Testing Tutorials

Learn to execute automation testing from scratch with the LambdaTest Learning Hub: from setting up the prerequisites and running your first automation test, to following best practices and diving deeper into advanced test scenarios. LambdaTest Learning Hubs compile step-by-step guides to help you become proficient with different test automation frameworks, e.g., Selenium, Cypress, and TestNG.

LambdaTest Learning Hubs:

YouTube

You can also refer to video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.

Run fMBT automation tests on the LambdaTest cloud grid

Perform automation testing on 3000+ real desktop and mobile devices online.

Try LambdaTest Now!

Get 100 minutes of automation testing FREE!

Next-Gen App & Browser Testing Cloud

Was this article helpful?

Helpful

Not Helpful