Best Python code snippet using responses
head_test.py
Source:head_test.py  
...93    stirling_approx = z * np.log(z) - z + 0.5 * np.log(2. * np.pi * z)94    lpl += np.ma.masked_array(stirling_approx, mask=(z <= 1)).filled(0.)95    return sum(lpl)/len(lpl)96  def testPoissonWithLogits(self):97    head = head_lib.poisson_regression_head()98    labels = ((0.,), (1.,), (1.,))99    logits = ((0.,), (-1.,), (3.,))100    with ops.Graph().as_default(), session.Session():101      model_fn_ops = head.create_model_fn_ops(102          {},103          labels=labels,104          mode=model_fn.ModeKeys.TRAIN,105          train_op_fn=head_lib.no_op_train_fn,106          logits=logits)107      self._assert_output_alternatives(model_fn_ops)108      _assert_summary_tags(self, ["loss"])109      _assert_no_variables(self)110      loss = self._log_poisson_loss(logits, labels)111      _assert_metrics(self, loss, {"loss": loss}, model_fn_ops)112class RegressionHeadTest(test.TestCase):113  def _assert_output_alternatives(self, model_fn_ops):114    self.assertEquals({115        None: constants.ProblemType.LINEAR_REGRESSION116    }, {117        k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)118    })119  # TODO(zakaria): test multilabel regression.120  def testRegressionWithLogits(self):121    head = head_lib.regression_head()122    with ops.Graph().as_default(), session.Session():123      model_fn_ops = head.create_model_fn_ops(124          {},125          labels=((0.,), (1.,), (1.,)),126          mode=model_fn.ModeKeys.TRAIN,127          train_op_fn=head_lib.no_op_train_fn,128          logits=((1.,), (1.,), (3.,)))129      self._assert_output_alternatives(model_fn_ops)130      _assert_summary_tags(self, ["loss"])131      _assert_no_variables(self)132      _assert_metrics(self, 5. / 3, {"loss": 5. 
/ 3}, model_fn_ops)133  def testRegressionWithLogitFn(self):134    head = head_lib.regression_head(link_fn=math_ops.square)135    def _assert_preditions(test_case, expected_predictions, model_fn_ops):136      variables.initialize_local_variables().run()137      test_case.assertAllClose(expected_predictions,138                               model_fn_ops.predictions["scores"].eval())139    with ops.Graph().as_default(), session.Session():140      model_fn_ops = head.create_model_fn_ops(141          {},142          labels=((0.,), (1.,), (1.,)),143          mode=model_fn.ModeKeys.TRAIN,144          train_op_fn=head_lib.no_op_train_fn,145          logits=((1.,), (1.,), (3.,)))146      self._assert_output_alternatives(model_fn_ops)147      _assert_summary_tags(self, ["loss"])148      _assert_no_variables(self)149      _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)150      _assert_preditions(self, ([1.0, 1.0, 9.0]), model_fn_ops)151  def testRegressionWithInvalidLogits(self):152    head = head_lib.regression_head()153    with ops.Graph().as_default(), session.Session():154      with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):155        head.create_model_fn_ops(156            {},157            labels=((0.,), (1.,), (1.,)),158            mode=model_fn.ModeKeys.TRAIN,159            train_op_fn=head_lib.no_op_train_fn,160            logits=((1., 1.), (1., 1.), (3., 1.)))161  def testRegressionWithLogitsInput(self):162    head = head_lib.regression_head()163    with ops.Graph().as_default(), session.Session():164      model_fn_ops = head.create_model_fn_ops(165          {},166          labels=((0.,), (1.,), (1.,)),167          mode=model_fn.ModeKeys.TRAIN,168          train_op_fn=head_lib.no_op_train_fn,169          logits_input=((0., 0.), (0., 0.), (0., 0.)))170      self._assert_output_alternatives(model_fn_ops)171      w = ("regression_head/logits/weights:0",172           "regression_head/logits/biases:0")173      _assert_variables(174  
        self, expected_global=w, expected_model=w, expected_trainable=w)175      variables.global_variables_initializer().run()176      _assert_summary_tags(self, ["loss"])177      _assert_metrics(self, 2. / 3, {"loss": 2. / 3}, model_fn_ops)178  def testRegressionWithLogitsAndLogitsInput(self):179    head = head_lib.regression_head()180    with ops.Graph().as_default(), session.Session():181      with self.assertRaisesRegexp(182          ValueError, "Both logits and logits_input supplied"):183        head.create_model_fn_ops(184            {},185            labels=((0.,), (1.,), (1.,)),186            mode=model_fn.ModeKeys.TRAIN,187            train_op_fn=head_lib.no_op_train_fn,188            logits_input=((0., 0.), (0., 0.), (0., 0.)),189            logits=((1.,), (1.,), (3.,)))190  def testRegressionEvalMode(self):191    head = head_lib.regression_head()192    with ops.Graph().as_default(), session.Session():193      model_fn_ops = head.create_model_fn_ops(194          {},195          labels=((1.,), (1.,), (3.,)),196          mode=model_fn.ModeKeys.EVAL,197          train_op_fn=head_lib.no_op_train_fn,198          logits=((0.,), (1.,), (1.,)))199      self._assert_output_alternatives(model_fn_ops)200      self.assertIsNone(model_fn_ops.train_op)201      _assert_no_variables(self)202      _assert_summary_tags(self, ["loss"])203      _assert_metrics(self, 5. / 3, {"loss": 5. 
/ 3}, model_fn_ops)204  def testRegressionWithLabelName(self):205    label_name = "my_label"206    head = head_lib.regression_head(label_name=label_name)207    with ops.Graph().as_default(), session.Session():208      model_fn_ops = head.create_model_fn_ops(209          {},210          labels={label_name: ((0.,), (1.,), (1.,))},211          mode=model_fn.ModeKeys.TRAIN,212          train_op_fn=head_lib.no_op_train_fn,213          logits=((1.,), (1.,), (3.,)))214      self._assert_output_alternatives(model_fn_ops)215      _assert_no_variables(self)216      _assert_summary_tags(self, ["loss"])217      _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)218  def testRegressionWithScalarWeights(self):219    head = head_lib.regression_head(weight_column_name="label_weight")220    with ops.Graph().as_default(), session.Session():221      weights = 2.222      labels = ((0.,), (1.,), (1.,))223      model_fn_ops = head.create_model_fn_ops(224          features={"label_weight": weights},225          labels=labels,226          mode=model_fn.ModeKeys.TRAIN,227          train_op_fn=head_lib.no_op_train_fn,228          logits=((1.,), (1.,), (3.,)))229      self._assert_output_alternatives(model_fn_ops)230      _assert_no_variables(self)231      _assert_summary_tags(self, ["loss"])232      _assert_metrics(self, (weights * 5.) / len(labels), {233          "loss": (weights * 5.) 
/ (weights * len(labels))234      }, model_fn_ops)235  def testRegressionWith1DWeights(self):236    head = head_lib.regression_head(weight_column_name="label_weight")237    with ops.Graph().as_default(), session.Session():238      weights = (2., 5., 0.)239      labels = ((0.,), (1.,), (1.,))240      model_fn_ops = head.create_model_fn_ops(241          features={"label_weight": weights},242          labels=labels,243          mode=model_fn.ModeKeys.TRAIN,244          train_op_fn=head_lib.no_op_train_fn,245          logits=((1.,), (1.,), (3.,)))246      self._assert_output_alternatives(model_fn_ops)247      _assert_no_variables(self)248      _assert_summary_tags(self, ["loss"])249      _assert_metrics(self, 2. / len(labels), {"loss": 2. / np.sum(weights)},250                      model_fn_ops)251  def testRegressionWith2DWeights(self):252    head = head_lib.regression_head(weight_column_name="label_weight")253    with ops.Graph().as_default(), session.Session():254      weights = ((2.,), (5.,), (0.,))255      labels = ((0.,), (1.,), (1.,))256      model_fn_ops = head.create_model_fn_ops(257          features={"label_weight": weights},258          labels=labels,259          mode=model_fn.ModeKeys.TRAIN,260          train_op_fn=head_lib.no_op_train_fn,261          logits=((1.,), (1.,), (3.,)))262      self._assert_output_alternatives(model_fn_ops)263      _assert_no_variables(self)264      _assert_summary_tags(self, ["loss"])265      _assert_metrics(self, 2. / len(labels), {"loss": 2. 
/ np.sum(weights)},266                      model_fn_ops)267  def testRegressionWithCenteredBias(self):268    head = head_lib.regression_head(enable_centered_bias=True)269    with ops.Graph().as_default(), session.Session():270      model_fn_ops = head.create_model_fn_ops(271          {},272          labels=((0.,), (1.,), (1.,)),273          mode=model_fn.ModeKeys.TRAIN,274          train_op_fn=head_lib.no_op_train_fn,275          logits=((1.,), (1.,), (3.,)))276      self._assert_output_alternatives(model_fn_ops)277      _assert_variables(278          self,279          expected_global=(280              "regression_head/centered_bias_weight:0",281              "regression_head/regression_head/centered_bias_weight/Adagrad:0",282          ),283          expected_trainable=("regression_head/centered_bias_weight:0",))284      variables.global_variables_initializer().run()285      _assert_summary_tags(self, [286          "loss",287          "regression_head/centered_bias/bias_0"288      ])289      _assert_metrics(self, 5. / 3, {"loss": 5. 
/ 3}, model_fn_ops)290  def testRegressionErrorInSparseTensorLabels(self):291    head = head_lib.regression_head()292    with ops.Graph().as_default():293      labels = sparse_tensor.SparseTensorValue(294          indices=((0, 0), (1, 0), (2, 0)),295          values=(0., 1., 1.),296          dense_shape=(3, 1))297      with self.assertRaisesRegexp(ValueError,298                                   "SparseTensor is not supported"):299        head.create_model_fn_ops(300            {},301            labels=labels,302            mode=model_fn.ModeKeys.TRAIN,303            train_op_fn=head_lib.no_op_train_fn,304            logits=((1.,), (1.,), (3.,)))305class MultiLabelHeadTest(test.TestCase):306  def _assert_output_alternatives(self, model_fn_ops):307    self.assertEquals({308        None: constants.ProblemType.CLASSIFICATION309    }, {310        k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)311    })312  def setUp(self):313    self._logits = ((1., 0., 0.),)314    self._labels = ((0, 0, 1),)315  def _expected_eval_metrics(self, expected_loss):316    return {317        "accuracy": 1. / 3,318        "loss": expected_loss,319        "auc": 1. 
/ 4,320        "auc/class0": 1.,321        "auc/class1": 1.,322        "auc/class2": 0.,323        "auc_precision_recall": 0.166667,324        "auc_precision_recall/class0": 0,325        "auc_precision_recall/class1": 0.,326        "auc_precision_recall/class2": 1.,327        "labels/actual_label_mean/class0": self._labels[0][0],328        "labels/actual_label_mean/class1": self._labels[0][1],329        "labels/actual_label_mean/class2": self._labels[0][2],330        "labels/logits_mean/class0": self._logits[0][0],331        "labels/logits_mean/class1": self._logits[0][1],332        "labels/logits_mean/class2": self._logits[0][2],333        "labels/prediction_mean/class0": self._logits[0][0],334        "labels/prediction_mean/class1": self._logits[0][1],335        "labels/prediction_mean/class2": self._logits[0][2],336        "labels/probability_mean/class0": _sigmoid(self._logits[0][0]),337        "labels/probability_mean/class1": _sigmoid(self._logits[0][1]),338        "labels/probability_mean/class2": _sigmoid(self._logits[0][2]),339    }340  def testMultiLabelWithLogits(self):341    n_classes = 3342    head = head_lib.multi_label_head(343        n_classes=n_classes, metric_class_ids=range(n_classes))344    with ops.Graph().as_default(), session.Session():345      model_fn_ops = head.create_model_fn_ops(346          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,347          logits=self._logits)348      self._assert_output_alternatives(model_fn_ops)349      _assert_no_variables(self)350      _assert_summary_tags(self, ["loss"])351      expected_loss = .89985204352      _assert_metrics(self, expected_loss,353                      self._expected_eval_metrics(expected_loss), model_fn_ops)354  def testMultiLabelTwoClasses(self):355    n_classes = 2356    labels = ((0, 1),)357    logits = ((1., 0.),)358    head = head_lib.multi_label_head(359        n_classes=n_classes, metric_class_ids=range(n_classes))360    with ops.Graph().as_default(), 
session.Session():361      model_fn_ops = head.create_model_fn_ops(362          {}, model_fn.ModeKeys.TRAIN, labels=labels,363          train_op_fn=head_lib.no_op_train_fn, logits=logits)364      self._assert_output_alternatives(model_fn_ops)365      _assert_no_variables(self)366      _assert_summary_tags(self, ["loss"])367      expected_loss = 1.00320443368      _assert_metrics(self, expected_loss, {369          "accuracy": 0.,370          "auc": 0.,371          "loss": expected_loss,372          "auc/class0": 1.,373          "auc/class1": 0.,374          "labels/actual_label_mean/class0": labels[0][0],375          "labels/actual_label_mean/class1": labels[0][1],376          "labels/logits_mean/class0": logits[0][0],377          "labels/logits_mean/class1": logits[0][1],378          "labels/prediction_mean/class0": logits[0][0],379          "labels/prediction_mean/class1": logits[0][1],380          "labels/probability_mean/class0": _sigmoid(logits[0][0]),381          "labels/probability_mean/class1": _sigmoid(logits[0][1]),382      }, model_fn_ops)383  def testMultiLabelWithInvalidLogits(self):384    head = head_lib.multi_label_head(n_classes=len(self._labels[0]) + 1)385    with ops.Graph().as_default(), session.Session():386      with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):387        head.create_model_fn_ops(388            {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,389            logits=self._logits)390  def testMultiLabelWithLogitsInput(self):391    n_classes = 3392    head = head_lib.multi_label_head(393        n_classes=n_classes, metric_class_ids=range(n_classes))394    with ops.Graph().as_default(), session.Session():395      model_fn_ops = head.create_model_fn_ops(396          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,397          logits_input=((0., 0.),))398      self._assert_output_alternatives(model_fn_ops)399      w = ("multi_label_head/logits/weights:0",400           
"multi_label_head/logits/biases:0")401      _assert_variables(402          self, expected_global=w, expected_model=w, expected_trainable=w)403      variables.global_variables_initializer().run()404      _assert_summary_tags(self, ["loss"])405      expected_loss = .69314718406      _assert_metrics(self, expected_loss, {407          "accuracy": 2. / 3,408          "auc": 2. / 4,409          "loss": expected_loss,410          "auc/class0": 1.,411          "auc/class1": 1.,412          "auc/class2": 0.,413          "labels/actual_label_mean/class0": self._labels[0][0],414          "labels/actual_label_mean/class1": self._labels[0][1],415          "labels/actual_label_mean/class2": self._labels[0][2],416          "labels/logits_mean/class0": 0.,417          "labels/logits_mean/class1": 0.,418          "labels/logits_mean/class2": 0.,419          "labels/prediction_mean/class0": 0.,420          "labels/prediction_mean/class1": 0.,421          "labels/prediction_mean/class2": 0.,422          "labels/probability_mean/class0": .5,423          "labels/probability_mean/class1": .5,424          "labels/probability_mean/class2": .5,425      }, model_fn_ops)426  def testMultiLabelWithLogitsAndLogitsInput(self):427    n_classes = 3428    head = head_lib.multi_label_head(429        n_classes=n_classes, metric_class_ids=range(n_classes))430    with ops.Graph().as_default(), session.Session():431      with self.assertRaisesRegexp(432          ValueError, "Both logits and logits_input supplied"):433        head.create_model_fn_ops(434            {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,435            logits_input=((0., 0.),), logits=self._logits)436  def testMultiLabelEval(self):437    n_classes = 3438    head = head_lib.multi_label_head(439        n_classes=n_classes, metric_class_ids=range(n_classes))440    with ops.Graph().as_default(), session.Session():441      model_fn_ops = head.create_model_fn_ops(442          {}, model_fn.ModeKeys.EVAL, 
self._labels, head_lib.no_op_train_fn,443          logits=self._logits)444      self._assert_output_alternatives(model_fn_ops)445      self.assertIsNone(model_fn_ops.train_op)446      _assert_no_variables(self)447      _assert_summary_tags(self, ["loss"])448      expected_loss = .89985204449      _assert_metrics(self, expected_loss,450                      self._expected_eval_metrics(expected_loss), model_fn_ops)451  def testMultiClassEvalWithLargeLogits(self):452    n_classes = 3453    head = head_lib.multi_label_head(454        n_classes=n_classes, metric_class_ids=range(n_classes))455    logits = ((2., 0., -1),)456    with ops.Graph().as_default(), session.Session():457      # logloss: z:label, x:logit458      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))459      model_fn_ops = head.create_model_fn_ops(460          {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn,461          logits=logits)462      self._assert_output_alternatives(model_fn_ops)463      self.assertIsNone(model_fn_ops.train_op)464      _assert_no_variables(self)465      _assert_summary_tags(self, ["loss"])466      expected_loss = 1.377779467      expected_eval_metrics = {468          "accuracy": 1. / 3,469          "auc": 9.99999e-07,470          "loss": expected_loss,471          "auc/class0": 1.,472          "auc/class1": 1.,473          "auc/class2": 0.,474          "labels/actual_label_mean/class0": 0. / 1,475          "labels/actual_label_mean/class1": 0. / 1,476          "labels/actual_label_mean/class2": 1. 
/ 1,477          "labels/logits_mean/class0": logits[0][0],478          "labels/logits_mean/class1": logits[0][1],479          "labels/logits_mean/class2": logits[0][2],480          "labels/prediction_mean/class0": 1,481          "labels/prediction_mean/class1": 0,482          "labels/prediction_mean/class2": 0,483          "labels/probability_mean/class0": _sigmoid(logits[0][0]),484          "labels/probability_mean/class1": _sigmoid(logits[0][1]),485          "labels/probability_mean/class2": _sigmoid(logits[0][2]),486      }487      _assert_metrics(self, expected_loss,488                      expected_eval_metrics, model_fn_ops)489  def testMultiLabelInfer(self):490    n_classes = 3491    head = head_lib.multi_label_head(n_classes=n_classes, head_name="head_name")492    with ops.Graph().as_default(), session.Session():493      model_fn_ops = head.create_model_fn_ops(494          {}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn,495          logits=((1., 0., 0.), (0., 0., 1)))496      self.assertIsNone(model_fn_ops.train_op)497      _assert_no_variables(self)498      with session.Session():499        self.assertListEqual(500            [1, 0, 0], model_fn_ops.predictions["classes"].eval().tolist()[0])501        self.assertItemsEqual(502            ["head_name"], six.iterkeys(model_fn_ops.output_alternatives))503        self.assertEqual(504            constants.ProblemType.CLASSIFICATION,505            model_fn_ops.output_alternatives["head_name"][0])506        predictions_for_serving = (507            model_fn_ops.output_alternatives["head_name"][1])508        self.assertIn("classes", six.iterkeys(predictions_for_serving))509        self.assertAllEqual(510            [[b"0", b"1", b"2"], [b"0", b"1", b"2"]],511            predictions_for_serving["classes"].eval())512        self.assertIn("probabilities", six.iterkeys(predictions_for_serving))513        self.assertAllClose(514            [[0.731059, 0.5, 0.5],515             [0.5, 0.5, 
0.731059,]],516            predictions_for_serving["probabilities"].eval())517  def testMultiLabelWithLabelName(self):518    n_classes = 3519    label_name = "my_label"520    head = head_lib.multi_label_head(521        n_classes=n_classes,522        label_name=label_name,523        metric_class_ids=range(n_classes))524    with ops.Graph().as_default(), session.Session():525      model_fn_ops = head.create_model_fn_ops(526          {}, model_fn.ModeKeys.TRAIN, {label_name: self._labels},527          head_lib.no_op_train_fn, logits=self._logits)528      self._assert_output_alternatives(model_fn_ops)529      _assert_no_variables(self)530      _assert_summary_tags(self, ["loss"])531      expected_loss = .89985204532      _assert_metrics(self, expected_loss,533                      self._expected_eval_metrics(expected_loss), model_fn_ops)534  def testMultiLabelWithScalarWeight(self):535    n_classes = 3536    head = head_lib.multi_label_head(537        n_classes=n_classes,538        weight_column_name="label_weight",539        metric_class_ids=range(n_classes))540    with ops.Graph().as_default(), session.Session():541      model_fn_ops = head.create_model_fn_ops(542          features={"label_weight": .1},543          labels=self._labels,544          mode=model_fn.ModeKeys.TRAIN,545          train_op_fn=head_lib.no_op_train_fn,546          logits=self._logits)547      self._assert_output_alternatives(model_fn_ops)548      _assert_no_variables(self)549      _assert_summary_tags(self, ["loss"])550      _assert_metrics(self, .089985214,551                      self._expected_eval_metrics(.89985214), model_fn_ops)552  def testMultiLabelWith1DWeight(self):553    n_classes = 3554    head = head_lib.multi_label_head(555        n_classes=n_classes,556        weight_column_name="label_weight",557        metric_class_ids=range(n_classes))558    with ops.Graph().as_default(), session.Session():559      with self.assertRaisesRegexp(560          ValueError, "weights can not be 
broadcast to values"):561        head.create_model_fn_ops(562            features={"label_weight": (.1, .1, .1)},563            labels=self._labels,564            mode=model_fn.ModeKeys.TRAIN,565            train_op_fn=head_lib.no_op_train_fn,566            logits=self._logits)567  def testMultiLabelWith2DWeight(self):568    n_classes = 3569    head = head_lib.multi_label_head(570        n_classes=n_classes,571        weight_column_name="label_weight",572        metric_class_ids=range(n_classes))573    with ops.Graph().as_default(), session.Session():574      model_fn_ops = head.create_model_fn_ops(575          features={"label_weight": ((.1, .1, .1),)},576          labels=self._labels,577          mode=model_fn.ModeKeys.TRAIN,578          train_op_fn=head_lib.no_op_train_fn,579          logits=self._logits)580      self._assert_output_alternatives(model_fn_ops)581      _assert_no_variables(self)582      _assert_summary_tags(self, ["loss"])583      _assert_metrics(self, .089985214,584                      self._expected_eval_metrics(.89985214), model_fn_ops)585  def testMultiLabelWithCustomLoss(self):586    n_classes = 3587    head = head_lib.multi_label_head(588        n_classes=n_classes,589        weight_column_name="label_weight",590        metric_class_ids=range(n_classes),591        loss_fn=_sigmoid_cross_entropy)592    with ops.Graph().as_default(), session.Session():593      model_fn_ops = head.create_model_fn_ops(594          features={"label_weight": .1},595          labels=self._labels,596          mode=model_fn.ModeKeys.TRAIN,597          train_op_fn=head_lib.no_op_train_fn,598          logits=self._logits)599      self._assert_output_alternatives(model_fn_ops)600      _assert_no_variables(self)601      _assert_summary_tags(self, ["loss"])602      expected_loss = .089985214603      _assert_metrics(self, expected_loss,604                      self._expected_eval_metrics(expected_loss), model_fn_ops)605  def testMultiLabelWithCenteredBias(self):606    
n_classes = 3607    head = head_lib.multi_label_head(608        n_classes=n_classes,609        enable_centered_bias=True,610        metric_class_ids=range(n_classes))611    with ops.Graph().as_default(), session.Session():612      model_fn_ops = head.create_model_fn_ops(613          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,614          logits=self._logits)615      self._assert_output_alternatives(model_fn_ops)616      _assert_variables(617          self,618          expected_global=(619              "multi_label_head/centered_bias_weight:0",620              ("multi_label_head/multi_label_head/centered_bias_weight/"621               "Adagrad:0"),),622          expected_trainable=("multi_label_head/centered_bias_weight:0",))623      variables.global_variables_initializer().run()624      _assert_summary_tags(self, (625          "loss",626          "multi_label_head/centered_bias/bias_0",627          "multi_label_head/centered_bias/bias_1",628          "multi_label_head/centered_bias/bias_2"629      ))630      expected_loss = .89985204631      _assert_metrics(self, expected_loss,632                      self._expected_eval_metrics(expected_loss), model_fn_ops)633  def testMultiLabelSparseTensorLabels(self):634    n_classes = 3635    head = head_lib.multi_label_head(636        n_classes=n_classes, metric_class_ids=range(n_classes))637    with ops.Graph().as_default(), session.Session():638      labels = sparse_tensor.SparseTensorValue(639          indices=((0, 0),),640          values=(2,),641          dense_shape=(1, 1))642      model_fn_ops = head.create_model_fn_ops(643          features={},644          mode=model_fn.ModeKeys.TRAIN,645          labels=labels,646          train_op_fn=head_lib.no_op_train_fn,647          logits=self._logits)648      _assert_no_variables(self)649      _assert_summary_tags(self, ["loss"])650      expected_loss = .89985204651      _assert_metrics(self, expected_loss,652                      
self._expected_eval_metrics(expected_loss), model_fn_ops)653  def testMultiLabelSparseTensorLabelsTooFewClasses(self):654    n_classes = 3655    head = head_lib.multi_label_head(656        n_classes=n_classes, metric_class_ids=range(n_classes))657    # Set _logits_dimension (n_classes) to a lower value; if it's set to 1658    # upfront, the class throws an error during initialization.659    head._logits_dimension = 1660    with ops.Graph().as_default(), session.Session():661      labels = sparse_tensor.SparseTensorValue(662          indices=((0, 0),),663          values=(2,),664          dense_shape=(1, 1))665      with self.assertRaisesRegexp(ValueError,666                                   "Must set num_classes >= 2 when passing"):667        head.create_model_fn_ops(668            features={},669            labels=labels,670            mode=model_fn.ModeKeys.TRAIN,671            train_op_fn=head_lib.no_op_train_fn,672            logits=[0.])673class BinaryClassificationHeadTest(test.TestCase):674  def _assert_output_alternatives(self, model_fn_ops):675    self.assertEquals({676        None: constants.ProblemType.LOGISTIC_REGRESSION677    }, {678        k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)679    })680  def setUp(self):681    self._logits = ((1.,), (1.,))682    self._labels = ((1.,), (0.,))683  def _expected_eval_metrics(self, expected_loss):684    label_mean = np.mean(self._labels)685    return {686        "accuracy": 1. / 2,687        "accuracy/baseline_label_mean": label_mean,688        "accuracy/threshold_0.500000_mean": 1. / 2,689        "auc": 1. / 2,690        "auc_precision_recall": 0.749999,691        "labels/actual_label_mean": label_mean,692        "labels/prediction_mean": .731059,  # softmax693        "loss": expected_loss,694        "precision/positive_threshold_0.500000_mean": 1. / 2,695        "recall/positive_threshold_0.500000_mean": 1. 
/ 1,696    }697  def testBinaryClassificationWithLogits(self):698    n_classes = 2699    head = head_lib.multi_class_head(n_classes=n_classes)700    with ops.Graph().as_default(), session.Session():701      # logloss: z:label, x:logit702      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))703      model_fn_ops = head.create_model_fn_ops(704          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,705          logits=self._logits)706      self._assert_output_alternatives(model_fn_ops)707      _assert_no_variables(self)708      _assert_summary_tags(self, ["loss"])709      expected_loss = .81326175710      _assert_metrics(self, expected_loss,711                      self._expected_eval_metrics(expected_loss), model_fn_ops)712  def testBinaryClassificationWithInvalidLogits(self):713    head = head_lib.multi_class_head(n_classes=len(self._labels) + 1)714    with ops.Graph().as_default(), session.Session():715      with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):716        head.create_model_fn_ops(717            {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,718            logits=self._logits)719  def testBinaryClassificationWithLogitsInput(self):720    n_classes = 2721    head = head_lib.multi_class_head(n_classes=n_classes)722    with ops.Graph().as_default(), session.Session():723      # logloss: z:label, x:logit724      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))725      model_fn_ops = head.create_model_fn_ops(726          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,727          logits_input=((0., 0.), (0., 0.)))728      self._assert_output_alternatives(model_fn_ops)729      w = ("binary_logistic_head/logits/weights:0",730           "binary_logistic_head/logits/biases:0")731      _assert_variables(732          self, expected_global=w, expected_model=w, expected_trainable=w)733      variables.global_variables_initializer().run()734      _assert_summary_tags(self, 
["loss"])735      expected_loss = .69314718736      label_mean = np.mean(self._labels)737      _assert_metrics(self, expected_loss, {738          "accuracy": 1. / 2,739          "accuracy/baseline_label_mean": label_mean,740          "accuracy/threshold_0.500000_mean": 1. / 2,741          "auc": 1. / 2,742          "labels/actual_label_mean": label_mean,743          "labels/prediction_mean": .5,  # softmax744          "loss": expected_loss,745          "precision/positive_threshold_0.500000_mean": 0. / 2,746          "recall/positive_threshold_0.500000_mean": 0. / 1,747      }, model_fn_ops)748  def testBinaryClassificationWithLogitsAndLogitsInput(self):749    head = head_lib.multi_class_head(n_classes=2)750    with ops.Graph().as_default(), session.Session():751      with self.assertRaisesRegexp(752          ValueError, "Both logits and logits_input supplied"):753        head.create_model_fn_ops(754            {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,755            logits_input=((0., 0.), (0., 0.)), logits=self._logits)756  def testBinaryClassificationEval(self):757    n_classes = 2758    head = head_lib.multi_class_head(n_classes=n_classes)759    with ops.Graph().as_default(), session.Session():760      # logloss: z:label, x:logit761      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))762      model_fn_ops = head.create_model_fn_ops(763          {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn,764          logits=self._logits)765      self._assert_output_alternatives(model_fn_ops)766      self.assertIsNone(model_fn_ops.train_op)767      _assert_no_variables(self)768      _assert_summary_tags(self, ["loss"])769      expected_loss = .81326175770      _assert_metrics(self, expected_loss,771                      self._expected_eval_metrics(expected_loss), model_fn_ops)772  def testBinaryClassificationInfer(self):773    n_classes = 2774    head = head_lib.multi_class_head(n_classes=n_classes, head_name="head_name")775 
   with ops.Graph().as_default(), session.Session():776      # logloss: z:label, x:logit777      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))778      model_fn_ops = head.create_model_fn_ops(779          {}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn,780          logits=self._logits)781      self.assertIsNone(model_fn_ops.train_op)782      _assert_no_variables(self)783      with session.Session():784        self.assertListEqual(785            [1, 1], list(model_fn_ops.predictions["classes"].eval()))786        self.assertItemsEqual(787            ["head_name"], six.iterkeys(model_fn_ops.output_alternatives))788        self.assertEqual(789            constants.ProblemType.LOGISTIC_REGRESSION,790            model_fn_ops.output_alternatives["head_name"][0])791        predictions_for_serving = (792            model_fn_ops.output_alternatives["head_name"][1])793        self.assertIn("classes", six.iterkeys(predictions_for_serving))794        predicted_classes = predictions_for_serving["classes"].eval().tolist()795        self.assertListEqual(796            [b"0", b"1"], predicted_classes[0])797        self.assertIn("probabilities", six.iterkeys(predictions_for_serving))798  def testBinaryClassificationInferMode_withWeightColumn(self):799    n_classes = 2800    head = head_lib.multi_class_head(n_classes=n_classes,801                                     weight_column_name="label_weight")802    with ops.Graph().as_default(), session.Session():803      # logloss: z:label, x:logit804      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))805      model_fn_ops = head.create_model_fn_ops(806          # This is what is being tested, features should not have weight for807          # inference.808          {}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn,809          logits=self._logits)810      self._assert_output_alternatives(model_fn_ops)811      self.assertIsNone(model_fn_ops.train_op)812      _assert_no_variables(self)813  
def testErrorInSparseTensorLabels(self):814    n_classes = 2815    head = head_lib.multi_class_head(n_classes=n_classes)816    with ops.Graph().as_default():817      labels = sparse_tensor.SparseTensorValue(818          indices=((0, 0), (1, 0), (2, 0)),819          values=(0, 1, 1),820          dense_shape=(3, 1))821      with self.assertRaisesRegexp(ValueError,822                                   "SparseTensor is not supported"):823        head.create_model_fn_ops(824            {},825            model_fn.ModeKeys.TRAIN,826            labels,827            head_lib.no_op_train_fn,828            logits=((1.,), (1.,), (3.,)))829  def testBinaryClassificationWithLabelName(self):830    label_name = "my_label"831    head = head_lib.multi_class_head(n_classes=2, label_name=label_name)832    with ops.Graph().as_default(), session.Session():833      # logloss: z:label, x:logit834      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))835      model_fn_ops = head.create_model_fn_ops(836          {},837          labels={label_name: self._labels},838          mode=model_fn.ModeKeys.TRAIN,839          train_op_fn=head_lib.no_op_train_fn,840          logits=self._logits)841      self._assert_output_alternatives(model_fn_ops)842      _assert_no_variables(self)843      _assert_summary_tags(self, ["loss"])844      expected_loss = .81326175845      _assert_metrics(self, expected_loss,846                      self._expected_eval_metrics(expected_loss), model_fn_ops)847  def testBinaryClassificationWith1DWeights(self):848    n_classes = 2849    head = head_lib.multi_class_head(850        n_classes=n_classes, weight_column_name="label_weight")851    with ops.Graph().as_default(), session.Session():852      weights = (1., 0.)853      # logloss: z:label, x:logit854      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))855      model_fn_ops = head.create_model_fn_ops(856          features={"label_weight": weights},857          labels=self._labels,858          
mode=model_fn.ModeKeys.TRAIN,859          train_op_fn=head_lib.no_op_train_fn,860          logits=self._logits)861      self._assert_output_alternatives(model_fn_ops)862      _assert_no_variables(self)863      _assert_summary_tags(self, ["loss"])864      expected_total_loss = .31326166865      _assert_metrics(866          self,867          expected_total_loss / len(weights),868          {869              "accuracy": 1. / 1,870              "accuracy/baseline_label_mean": 1. / 1,871              "accuracy/threshold_0.500000_mean": 1. / 1,872              "auc": 0. / 1,873              "labels/actual_label_mean": 1. / 1,874              "labels/prediction_mean": .731059,  # softmax875              # eval loss is weighted loss divided by sum of weights.876              "loss": expected_total_loss,877              "precision/positive_threshold_0.500000_mean": 1. / 1,878              "recall/positive_threshold_0.500000_mean": 1. / 1,879          },880          model_fn_ops)881  def testBinaryClassificationWith2DWeights(self):882    n_classes = 2883    head = head_lib.multi_class_head(884        n_classes=n_classes, weight_column_name="label_weight")885    with ops.Graph().as_default(), session.Session():886      weights = ((1.,), (0.,))887      # logloss: z:label, x:logit888      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))889      model_fn_ops = head.create_model_fn_ops(890          features={"label_weight": weights},891          labels=self._labels,892          mode=model_fn.ModeKeys.TRAIN,893          train_op_fn=head_lib.no_op_train_fn,894          logits=self._logits)895      self._assert_output_alternatives(model_fn_ops)896      _assert_no_variables(self)897      _assert_summary_tags(self, ["loss"])898      expected_total_loss = .31326166899      _assert_metrics(900          self,901          expected_total_loss / len(weights),902          {903              "accuracy": 1. / 1,904              "accuracy/baseline_label_mean": 1. 
/ 1,905              "accuracy/threshold_0.500000_mean": 1. / 1,906              "auc": 0. / 1,907              "labels/actual_label_mean": 1. / 1,908              "labels/prediction_mean": .731059,  # softmax909              # eval loss is weighted loss divided by sum of weights.910              "loss": expected_total_loss,911              "precision/positive_threshold_0.500000_mean": 1. / 1,912              "recall/positive_threshold_0.500000_mean": 1. / 1,913          },914          model_fn_ops)915  def testBinaryClassificationWithCustomLoss(self):916    head = head_lib.multi_class_head(917        n_classes=2, weight_column_name="label_weight",918        loss_fn=_sigmoid_cross_entropy)919    with ops.Graph().as_default(), session.Session():920      weights = ((.2,), (0.,))921      model_fn_ops = head.create_model_fn_ops(922          features={"label_weight": weights},923          labels=self._labels,924          mode=model_fn.ModeKeys.TRAIN,925          train_op_fn=head_lib.no_op_train_fn,926          logits=self._logits)927      self._assert_output_alternatives(model_fn_ops)928      _assert_no_variables(self)929      _assert_summary_tags(self, ["loss"])930      # logloss: z:label, x:logit931      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))932      # expected_loss is (total_weighted_loss)/1 since there is 1 nonzero933      # weight.934      expected_loss = 0.062652342935      _assert_metrics(936          self,937          expected_loss,938          {939              "accuracy": 1. / 1,940              "accuracy/baseline_label_mean": 1. / 1,941              "accuracy/threshold_0.500000_mean": 1. / 1,942              "auc": 0. / 1,943              "labels/actual_label_mean": 1. / 1,944              "labels/prediction_mean": .731059,  # softmax945              "loss": expected_loss,946              "precision/positive_threshold_0.500000_mean": 1. / 1,947              "recall/positive_threshold_0.500000_mean": 1. 
/ 1,948          },949          model_fn_ops)950  def testBinaryClassificationWithCenteredBias(self):951    head = head_lib.multi_class_head(n_classes=2, enable_centered_bias=True)952    with ops.Graph().as_default(), session.Session():953      # logloss: z:label, x:logit954      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))955      model_fn_ops = head.create_model_fn_ops(956          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,957          logits=self._logits)958      self._assert_output_alternatives(model_fn_ops)959      _assert_variables(960          self,961          expected_global=(962              "binary_logistic_head/centered_bias_weight:0",963              ("binary_logistic_head/binary_logistic_head/centered_bias_weight/"964               "Adagrad:0"),),965          expected_trainable=("binary_logistic_head/centered_bias_weight:0",))966      variables.global_variables_initializer().run()967      _assert_summary_tags(self, [968          "loss",969          "binary_logistic_head/centered_bias/bias_0"970      ])971      expected_loss = .81326175972      _assert_metrics(self, expected_loss,973                      self._expected_eval_metrics(expected_loss), model_fn_ops)974class MultiClassHeadTest(test.TestCase):975  def _assert_output_alternatives(self, model_fn_ops):976    self.assertEquals({977        None: constants.ProblemType.CLASSIFICATION978    }, {979        k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)980    })981  def setUp(self):982    self._logits = ((1., 0., 0.),)983    self._labels = ((2,),)984  def _expected_eval_metrics(self, expected_loss):985    return {986        "accuracy": 0.,987        "loss": expected_loss,988        "labels/actual_label_mean/class0": 0. / 1,989        "labels/actual_label_mean/class1": 0. / 1,990        "labels/actual_label_mean/class2": 1. 
/ 1,991        "labels/logits_mean/class0": self._logits[0][0],992        "labels/logits_mean/class1": self._logits[0][1],993        "labels/logits_mean/class2": self._logits[0][2],994        "labels/prediction_mean/class0": self._logits[0][0],995        "labels/prediction_mean/class1": self._logits[0][1],996        "labels/prediction_mean/class2": self._logits[0][2],997        "labels/probability_mean/class0": 0.576117,  # softmax998        "labels/probability_mean/class1": 0.211942,  # softmax999        "labels/probability_mean/class2": 0.211942,  # softmax1000    }1001  def testMultiClassWithLogits(self):1002    n_classes = 31003    head = head_lib.multi_class_head(1004        n_classes=n_classes, metric_class_ids=range(n_classes))1005    with ops.Graph().as_default(), session.Session():1006      # logloss: z:label, x:logit1007      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1008      model_fn_ops = head.create_model_fn_ops(1009          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,1010          logits=self._logits)1011      self._assert_output_alternatives(model_fn_ops)1012      _assert_no_variables(self)1013      _assert_summary_tags(self, ["loss"])1014      expected_loss = 1.55144471015      _assert_metrics(self, expected_loss,1016                      self._expected_eval_metrics(expected_loss), model_fn_ops)1017  def testMultiClassWithInvalidLogits(self):1018    head = head_lib.multi_class_head(n_classes=len(self._logits[0]) + 1)1019    with ops.Graph().as_default(), session.Session():1020      with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):1021        head.create_model_fn_ops(1022            {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,1023            logits=self._logits)1024  def testMultiClassWithNoneTrainOpFnInTrain(self):1025    head = head_lib.multi_class_head(n_classes=3)1026    with ops.Graph().as_default(), session.Session():1027      with self.assertRaisesRegexp(1028    
      ValueError, "train_op_fn can not be None in TRAIN mode"):1029        head.create_model_fn_ops(1030            {}, model_fn.ModeKeys.TRAIN, self._labels,1031            train_op_fn=None,1032            logits=self._logits)1033  def testMultiClassWithLogitsInput(self):1034    n_classes = 31035    head = head_lib.multi_class_head(1036        n_classes=n_classes, metric_class_ids=range(n_classes))1037    with ops.Graph().as_default(), session.Session():1038      # logloss: z:label, x:logit1039      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1040      model_fn_ops = head.create_model_fn_ops(1041          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,1042          logits_input=((0., 0.),))1043      self._assert_output_alternatives(model_fn_ops)1044      w = ("multi_class_head/logits/weights:0",1045           "multi_class_head/logits/biases:0")1046      _assert_variables(1047          self, expected_global=w, expected_model=w, expected_trainable=w)1048      variables.global_variables_initializer().run()1049      _assert_summary_tags(self, ["loss"])1050      expected_loss = 1.09861231051      _assert_metrics(self, expected_loss, {1052          "accuracy": 0.,1053          "loss": expected_loss,1054          "labels/actual_label_mean/class0": 0. / 1,1055          "labels/actual_label_mean/class1": 0. / 1,1056          "labels/actual_label_mean/class2": 1. 
/ 1,1057          "labels/logits_mean/class0": 0.,1058          "labels/logits_mean/class1": 0.,1059          "labels/logits_mean/class2": 0.,1060          "labels/prediction_mean/class0": 1.,1061          "labels/prediction_mean/class1": 0.,1062          "labels/prediction_mean/class2": 0.,1063          "labels/probability_mean/class0": 0.333333,  # softmax1064          "labels/probability_mean/class1": 0.333333,  # softmax1065          "labels/probability_mean/class2": 0.333333,  # softmax1066      }, model_fn_ops)1067  def testMultiClassWithLogitsAndLogitsInput(self):1068    n_classes = 31069    head = head_lib.multi_class_head(1070        n_classes=n_classes, metric_class_ids=range(n_classes))1071    with ops.Graph().as_default(), session.Session():1072      with self.assertRaisesRegexp(1073          ValueError, "Both logits and logits_input supplied"):1074        head.create_model_fn_ops(1075            {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,1076            logits_input=((0., 0.),), logits=self._logits)1077  def testMultiClassEnableCenteredBias(self):1078    n_classes = 31079    head = head_lib.multi_class_head(1080        n_classes=n_classes, enable_centered_bias=True)1081    with ops.Graph().as_default(), session.Session():1082      # logloss: z:label, x:logit1083      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1084      model_fn_ops = head.create_model_fn_ops(1085          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,1086          logits=self._logits)1087      self._assert_output_alternatives(model_fn_ops)1088      _assert_variables(1089          self,1090          expected_global=(1091              "multi_class_head/centered_bias_weight:0",1092              ("multi_class_head/multi_class_head/centered_bias_weight/"1093               "Adagrad:0"),1094          ),1095          expected_trainable=("multi_class_head/centered_bias_weight:0",))1096      
variables.global_variables_initializer().run()1097      _assert_summary_tags(self,1098                           ["loss",1099                            "multi_class_head/centered_bias/bias_0",1100                            "multi_class_head/centered_bias/bias_1",1101                            "multi_class_head/centered_bias/bias_2"])1102  def testMultiClassEval(self):1103    n_classes = 31104    head = head_lib.multi_class_head(1105        n_classes=n_classes, metric_class_ids=range(n_classes))1106    with ops.Graph().as_default(), session.Session():1107      # logloss: z:label, x:logit1108      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1109      model_fn_ops = head.create_model_fn_ops(1110          {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn,1111          logits=self._logits)1112      self._assert_output_alternatives(model_fn_ops)1113      self.assertIsNone(model_fn_ops.train_op)1114      _assert_no_variables(self)1115      _assert_summary_tags(self, ["loss"])1116      expected_loss = 1.55144471117      _assert_metrics(self, expected_loss,1118                      self._expected_eval_metrics(expected_loss), model_fn_ops)1119  def testMultiClassEvalModeWithLargeLogits(self):1120    n_classes = 31121    head = head_lib.multi_class_head(1122        n_classes=n_classes, metric_class_ids=range(n_classes))1123    logits = ((2., 0., -1),)1124    with ops.Graph().as_default(), session.Session():1125      # logloss: z:label, x:logit1126      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1127      model_fn_ops = head.create_model_fn_ops(1128          {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn,1129          logits=logits)1130      self._assert_output_alternatives(model_fn_ops)1131      self.assertIsNone(model_fn_ops.train_op)1132      _assert_no_variables(self)1133      _assert_summary_tags(self, ["loss"])1134      expected_loss = 3.16984611135      expected_eval_metrics = {1136          "accuracy": 
0.,1137          "loss": expected_loss,1138          "labels/actual_label_mean/class0": 0. / 1,1139          "labels/actual_label_mean/class1": 0. / 1,1140          "labels/actual_label_mean/class2": 1. / 1,1141          "labels/logits_mean/class0": logits[0][0],1142          "labels/logits_mean/class1": logits[0][1],1143          "labels/logits_mean/class2": logits[0][2],1144          "labels/prediction_mean/class0": 1,1145          "labels/prediction_mean/class1": 0,1146          "labels/prediction_mean/class2": 0,1147          "labels/probability_mean/class0": 0.843795,  # softmax1148          "labels/probability_mean/class1": 0.114195,  # softmax1149          "labels/probability_mean/class2": 0.0420101,  # softmax1150      }1151      _assert_metrics(self, expected_loss,1152                      expected_eval_metrics, model_fn_ops)1153  def testMultiClassWithScalarWeight(self):1154    n_classes = 31155    head = head_lib.multi_class_head(1156        n_classes=n_classes,1157        weight_column_name="label_weight",1158        metric_class_ids=range(n_classes))1159    with ops.Graph().as_default(), session.Session():1160      weight = .11161      # logloss: z:label, x:logit1162      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1163      model_fn_ops = head.create_model_fn_ops(1164          features={"label_weight": weight},1165          labels=self._labels,1166          mode=model_fn.ModeKeys.TRAIN,1167          train_op_fn=head_lib.no_op_train_fn,1168          logits=self._logits)1169      self._assert_output_alternatives(model_fn_ops)1170      _assert_no_variables(self)1171      _assert_summary_tags(self, ["loss"])1172      expected_loss = 1.55144471173      _assert_metrics(self, expected_loss * weight,1174                      self._expected_eval_metrics(expected_loss), model_fn_ops)1175  def testMultiClassWith1DWeight(self):1176    n_classes = 31177    head = head_lib.multi_class_head(1178        n_classes=n_classes,1179        
weight_column_name="label_weight",1180        metric_class_ids=range(n_classes))1181    with ops.Graph().as_default(), session.Session():1182      weight = .11183      weights = (weight,)1184      # logloss: z:label, x:logit1185      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1186      model_fn_ops = head.create_model_fn_ops(1187          features={"label_weight": weights},1188          labels=self._labels,1189          mode=model_fn.ModeKeys.TRAIN,1190          train_op_fn=head_lib.no_op_train_fn,1191          logits=self._logits)1192      self._assert_output_alternatives(model_fn_ops)1193      _assert_no_variables(self)1194      _assert_summary_tags(self, ["loss"])1195      expected_loss = 1.55144471196      _assert_metrics(self, expected_loss * weight,1197                      self._expected_eval_metrics(expected_loss), model_fn_ops)1198  def testMultiClassWith2DWeight(self):1199    n_classes = 31200    head = head_lib.multi_class_head(1201        n_classes=n_classes,1202        weight_column_name="label_weight",1203        metric_class_ids=range(n_classes))1204    with ops.Graph().as_default(), session.Session():1205      weight = .11206      weights = ((weight,),)1207      # logloss: z:label, x:logit1208      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1209      model_fn_ops = head.create_model_fn_ops(1210          features={"label_weight": weights},1211          labels=self._labels,1212          mode=model_fn.ModeKeys.TRAIN,1213          train_op_fn=head_lib.no_op_train_fn,1214          logits=self._logits)1215      self._assert_output_alternatives(model_fn_ops)1216      _assert_no_variables(self)1217      _assert_summary_tags(self, ["loss"])1218      expected_loss = 1.55144471219      _assert_metrics(self, expected_loss * weight,1220                      self._expected_eval_metrics(expected_loss), model_fn_ops)1221  def testMultiClassWithCustomLoss(self):1222    n_classes = 31223    head = head_lib.multi_class_head(1224        
n_classes=n_classes,1225        weight_column_name="label_weight",1226        metric_class_ids=range(n_classes),1227        loss_fn=losses_lib.sparse_softmax_cross_entropy)1228    with ops.Graph().as_default(), session.Session():1229      weight = .11230      # logloss: z:label, x:logit1231      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1232      model_fn_ops = head.create_model_fn_ops(1233          features={"label_weight": weight},1234          labels=self._labels,1235          mode=model_fn.ModeKeys.TRAIN,1236          train_op_fn=head_lib.no_op_train_fn,1237          logits=self._logits)1238      self._assert_output_alternatives(model_fn_ops)1239      _assert_no_variables(self)1240      _assert_summary_tags(self, ["loss"])1241      expected_loss = 1.5514447 * weight1242      _assert_metrics(self, expected_loss,1243                      self._expected_eval_metrics(expected_loss), model_fn_ops)1244  def testMultiClassInfer(self):1245    n_classes = 31246    head = head_lib._multi_class_head(1247        n_classes=n_classes,1248        head_name="head_name")1249    with ops.Graph().as_default():1250      model_fn_ops = head.create_model_fn_ops(1251          features={},1252          mode=model_fn.ModeKeys.INFER,1253          train_op_fn=head_lib.no_op_train_fn,1254          logits=((1., 0., 0.), (0., 0., 1.),))1255      with session.Session():1256        lookup_ops.tables_initializer().run()1257        self.assertAllEqual(1258            [0, 2],1259            model_fn_ops.predictions["classes"].eval())1260        self.assertItemsEqual(1261            ["head_name"], six.iterkeys(model_fn_ops.output_alternatives))1262        self.assertEqual(1263            constants.ProblemType.CLASSIFICATION,1264            model_fn_ops.output_alternatives["head_name"][0])1265        predictions_for_serving = (1266            model_fn_ops.output_alternatives["head_name"][1])1267        self.assertIn("classes", six.iterkeys(predictions_for_serving))1268        
self.assertAllEqual(1269            [[b"0", b"1", b"2"], [b"0", b"1", b"2"]],1270            predictions_for_serving["classes"].eval())1271        self.assertIn("probabilities", six.iterkeys(predictions_for_serving))1272        self.assertAllClose(1273            [[0.576117, 0.2119416, 0.2119416],1274             [0.2119416, 0.2119416, 0.576117]],1275            predictions_for_serving["probabilities"].eval())1276  def testInvalidNClasses(self):1277    for n_classes in (None, -1, 0, 1):1278      with self.assertRaisesRegexp(ValueError, "n_classes must be > 1"):1279        head_lib.multi_class_head(n_classes=n_classes)1280  def testMultiClassWithLabelKeysInvalidShape(self):1281    with self.assertRaisesRegexp(1282        ValueError, "Length of label_keys must equal n_classes"):1283      head_lib._multi_class_head(1284          n_classes=3, label_keys=("key0", "key1"))1285  def testMultiClassWithLabelKeysTwoClasses(self):1286    with self.assertRaisesRegexp(1287        ValueError, "label_keys is not supported for n_classes=2"):1288      head_lib._multi_class_head(1289          n_classes=2, label_keys=("key0", "key1"))1290  def testMultiClassWithLabelKeysInfer(self):1291    n_classes = 31292    label_keys = ("key0", "key1", "key2")1293    head = head_lib._multi_class_head(1294        n_classes=n_classes, label_keys=label_keys,1295        metric_class_ids=range(n_classes),1296        head_name="head_name")1297    with ops.Graph().as_default():1298      model_fn_ops = head.create_model_fn_ops(1299          features={},1300          mode=model_fn.ModeKeys.INFER,1301          train_op_fn=head_lib.no_op_train_fn,1302          logits=((1., 0., 0.), (0., 0., 1.),))1303      with session.Session():1304        lookup_ops.tables_initializer().run()1305        self.assertAllEqual(1306            [b"key0", b"key2"],1307            model_fn_ops.predictions["classes"].eval())1308        self.assertItemsEqual(1309            ["head_name"], 
six.iterkeys(model_fn_ops.output_alternatives))1310        self.assertEqual(1311            constants.ProblemType.CLASSIFICATION,1312            model_fn_ops.output_alternatives["head_name"][0])1313        predictions_for_serving = (1314            model_fn_ops.output_alternatives["head_name"][1])1315        self.assertIn("classes", six.iterkeys(predictions_for_serving))1316        self.assertAllEqual(1317            [[b"key0", b"key1", b"key2"], [b"key0", b"key1", b"key2"]],1318            predictions_for_serving["classes"].eval())1319        self.assertIn("probabilities", six.iterkeys(predictions_for_serving))1320        self.assertAllClose(1321            [[0.576117, 0.2119416, 0.2119416],1322             [0.2119416, 0.2119416, 0.576117]],1323            predictions_for_serving["probabilities"].eval())1324  def testMultiClassWithLabelKeysEvalAccuracy0(self):1325    n_classes = 31326    label_keys = ("key0", "key1", "key2")1327    head = head_lib._multi_class_head(1328        n_classes=n_classes,1329        label_keys=label_keys)1330    with ops.Graph().as_default():1331      model_fn_ops = head.create_model_fn_ops(1332          features={},1333          mode=model_fn.ModeKeys.EVAL,1334          labels=("key2",),1335          train_op_fn=head_lib.no_op_train_fn,1336          logits=((1., 0., 0.),))1337      with session.Session():1338        lookup_ops.tables_initializer().run()1339        self.assertIsNone(model_fn_ops.train_op)1340        _assert_no_variables(self)1341        _assert_summary_tags(self, ["loss"])1342        expected_loss = 1.55144471343        expected_eval_metrics = {1344            "accuracy": 0.,1345            "loss": expected_loss,1346        }1347        _assert_metrics(self, expected_loss,1348                        expected_eval_metrics, model_fn_ops)1349  def testMultiClassWithLabelKeysEvalAccuracy1(self):1350    n_classes = 31351    label_keys = ("key0", "key1", "key2")1352    head = head_lib._multi_class_head(1353        
n_classes=n_classes,1354        label_keys=label_keys)1355    with ops.Graph().as_default():1356      model_fn_ops = head.create_model_fn_ops(1357          features={},1358          mode=model_fn.ModeKeys.EVAL,1359          labels=("key2",),1360          train_op_fn=head_lib.no_op_train_fn,1361          logits=((0., 0., 1.),))1362      with session.Session():1363        lookup_ops.tables_initializer().run()1364        self.assertIsNone(model_fn_ops.train_op)1365        _assert_no_variables(self)1366        _assert_summary_tags(self, ["loss"])1367        expected_loss = 0.55144471368        expected_eval_metrics = {1369            "accuracy": 1.,1370            "loss": expected_loss,1371        }1372        _assert_metrics(self, expected_loss,1373                        expected_eval_metrics, model_fn_ops)1374class BinarySvmHeadTest(test.TestCase):1375  def _assert_output_alternatives(self, model_fn_ops):1376    self.assertEquals({1377        None: constants.ProblemType.LOGISTIC_REGRESSION1378    }, {1379        k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)1380    })1381  def setUp(self):1382    # Prediction for first example is in the right side of the hyperplane1383    # (i.e., < 0) but it is within the [-1,1] margin. There is a 0.5 loss1384    # incurred by this example. 
The 2nd prediction is outside the margin so it1385    # incurs no loss at all.1386    self._predictions = ((-.5,), (1.2,))1387    self._labels = (0, 1)1388    self._expected_losses = (.5, 0.)1389  def testBinarySVMWithLogits(self):1390    head = head_lib.binary_svm_head()1391    with ops.Graph().as_default(), session.Session():1392      model_fn_ops = head.create_model_fn_ops(1393          {},1394          model_fn.ModeKeys.TRAIN,1395          self._labels,1396          head_lib.no_op_train_fn,1397          logits=self._predictions)1398      self._assert_output_alternatives(model_fn_ops)1399      _assert_no_variables(self)1400      _assert_summary_tags(self, ["loss"])1401      expected_loss = np.average(self._expected_losses)1402      _assert_metrics(self, expected_loss, {1403          "accuracy": 1.,1404          "loss": expected_loss,1405      }, model_fn_ops)1406  def testBinarySVMWithInvalidLogits(self):1407    head = head_lib.binary_svm_head()1408    with ops.Graph().as_default(), session.Session():1409      with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):1410        head.create_model_fn_ops(1411            {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,1412            logits=np.ones((2, 2)))1413  def testBinarySVMWithLogitsInput(self):1414    head = head_lib.binary_svm_head()1415    with ops.Graph().as_default(), session.Session():1416      model_fn_ops = head.create_model_fn_ops(1417          {},1418          model_fn.ModeKeys.TRAIN,1419          self._labels,1420          head_lib.no_op_train_fn,1421          logits_input=((0., 0.), (0., 0.)))1422      self._assert_output_alternatives(model_fn_ops)1423      w = ("binary_svm_head/logits/weights:0",1424           "binary_svm_head/logits/biases:0")1425      _assert_variables(1426          self, expected_global=w, expected_model=w, expected_trainable=w)1427      variables.global_variables_initializer().run()1428      _assert_summary_tags(self, ["loss"])1429      
expected_loss = 1.1430      _assert_metrics(self, expected_loss, {1431          "accuracy": .5,1432          "loss": expected_loss,1433      }, model_fn_ops)1434  def testBinarySVMWithLogitsAndLogitsInput(self):1435    head = head_lib.binary_svm_head()1436    with ops.Graph().as_default(), session.Session():1437      with self.assertRaisesRegexp(1438          ValueError, "Both logits and logits_input supplied"):1439        head.create_model_fn_ops(1440            {},1441            model_fn.ModeKeys.TRAIN,1442            self._labels,1443            head_lib.no_op_train_fn,1444            logits_input=((0., 0.), (0., 0.)),1445            logits=self._predictions)1446  def testBinarySVMEvalMode(self):1447    head = head_lib.binary_svm_head()1448    with ops.Graph().as_default(), session.Session():1449      model_fn_ops = head.create_model_fn_ops(1450          {},1451          model_fn.ModeKeys.EVAL,1452          self._labels,1453          head_lib.no_op_train_fn,1454          logits=self._predictions)1455      self._assert_output_alternatives(model_fn_ops)1456      self.assertIsNone(model_fn_ops.train_op)1457      _assert_no_variables(self)1458      _assert_summary_tags(self, ["loss"])1459      expected_loss = np.average(self._expected_losses)1460      _assert_metrics(self, expected_loss, {1461          "accuracy": 1.,1462          "loss": expected_loss,1463      }, model_fn_ops)1464  def testBinarySVMWithLabelName(self):1465    label_name = "my_label"1466    head = head_lib.binary_svm_head(label_name=label_name)1467    with ops.Graph().as_default(), session.Session():1468      model_fn_ops = head.create_model_fn_ops(1469          {},1470          model_fn.ModeKeys.TRAIN,1471          {label_name: self._labels},1472          head_lib.no_op_train_fn,1473          logits=self._predictions)1474      self._assert_output_alternatives(model_fn_ops)1475      _assert_no_variables(self)1476      _assert_summary_tags(self, ["loss"])1477      expected_loss = 
np.average(self._expected_losses)1478      _assert_metrics(self, expected_loss, {1479          "accuracy": 1.,1480          "loss": expected_loss,1481      }, model_fn_ops)1482  def testBinarySVMWith1DWeights(self):1483    head = head_lib.binary_svm_head(weight_column_name="weights")1484    with ops.Graph().as_default(), session.Session():1485      weights = (7., 11.)1486      model_fn_ops = head.create_model_fn_ops(1487          # We have to add an extra dim here for weights broadcasting to work.1488          features={"weights": weights},1489          mode=model_fn.ModeKeys.TRAIN,1490          labels=self._labels,1491          train_op_fn=head_lib.no_op_train_fn,1492          logits=self._predictions)1493      self._assert_output_alternatives(model_fn_ops)1494      _assert_no_variables(self)1495      _assert_summary_tags(self, ["loss"])1496      expected_weighted_losses = np.multiply(weights, self._expected_losses)1497      _assert_metrics(self, np.mean(expected_weighted_losses), {1498          "accuracy": 1.,1499          "loss": np.sum(expected_weighted_losses) / np.sum(weights),1500      }, model_fn_ops)1501  def testBinarySVMWith2DWeights(self):1502    head = head_lib.binary_svm_head(weight_column_name="weights")1503    with ops.Graph().as_default(), session.Session():1504      weights = (7., 11.)1505      model_fn_ops = head.create_model_fn_ops(1506          # We have to add an extra dim here for weights broadcasting to work.1507          features={"weights": tuple([(w,) for w in weights])},1508          mode=model_fn.ModeKeys.TRAIN,1509          labels=self._labels,1510          train_op_fn=head_lib.no_op_train_fn,1511          logits=self._predictions)1512      self._assert_output_alternatives(model_fn_ops)1513      _assert_no_variables(self)1514      _assert_summary_tags(self, ["loss"])1515      expected_weighted_losses = np.multiply(weights, self._expected_losses)1516      _assert_metrics(self, np.mean(expected_weighted_losses), {1517          
"accuracy": 1.,1518          "loss": np.sum(expected_weighted_losses) / np.sum(weights),1519      }, model_fn_ops)1520  def testBinarySVMWithCenteredBias(self):1521    head = head_lib.binary_svm_head(enable_centered_bias=True)1522    with ops.Graph().as_default(), session.Session():1523      model_fn_ops = head.create_model_fn_ops(1524          {},1525          model_fn.ModeKeys.TRAIN,1526          self._labels,1527          head_lib.no_op_train_fn,1528          logits=self._predictions)1529      self._assert_output_alternatives(model_fn_ops)1530      _assert_variables(1531          self,1532          expected_global=(1533              "binary_svm_head/centered_bias_weight:0",1534              ("binary_svm_head/binary_svm_head/centered_bias_weight/"1535               "Adagrad:0"),1536          ),1537          expected_trainable=("binary_svm_head/centered_bias_weight:0",))1538      variables.global_variables_initializer().run()1539      _assert_summary_tags(self, [1540          "loss",1541          "binary_svm_head/centered_bias/bias_0"1542      ])1543      expected_loss = np.average(self._expected_losses)1544      _assert_metrics(self, expected_loss, {1545          "accuracy": 1.,1546          "loss": expected_loss,1547      }, model_fn_ops)1548class LossOnlyHead(test.TestCase):1549  def testNoPredictionsAndNoMetrics(self):1550    head = head_lib.loss_only_head(lambda: 1, head_name="const")1551    model_fn_ops = head.create_model_fn_ops(1552        features={},1553        mode=model_fn.ModeKeys.TRAIN,1554        train_op_fn=head_lib.no_op_train_fn)1555    self.assertDictEqual(model_fn_ops.predictions, {})1556    self.assertDictEqual(model_fn_ops.eval_metric_ops, {})1557    self.assertIsNotNone(model_fn_ops.loss)1558    with session.Session() as sess:1559      self.assertEqual(1, sess.run(model_fn_ops.loss))1560class MultiHeadTest(test.TestCase):1561  def testInvalidHeads(self):1562    named_head = head_lib.multi_class_head(1563        n_classes=3, 
label_name="label", head_name="head1")1564    unnamed_head = head_lib.multi_class_head(1565        n_classes=4, label_name="label")1566    with self.assertRaisesRegexp(ValueError, "must have names"):1567      head_lib.multi_head((named_head, unnamed_head))1568  def testTrainWithNoneTrainOpFn(self):1569    head1 = head_lib.multi_class_head(1570        n_classes=3, label_name="label1", head_name="head1")1571    head2 = head_lib.multi_class_head(1572        n_classes=4, label_name="label2", head_name="head2")1573    head = head_lib.multi_head((head1, head2))1574    labels = {1575        "label1": (1,),1576        "label2": (1,)1577    }1578    with self.assertRaisesRegexp(1579        ValueError, "train_op_fn can not be None in TRAIN mode"):1580      head.create_model_fn_ops(1581          features={"weights": (2.0, 10.0)},1582          labels=labels,1583          mode=model_fn.ModeKeys.TRAIN,1584          train_op_fn=None,1585          logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))1586  def testTrain_withNoHeadWeights(self):1587    head1 = head_lib.multi_class_head(1588        n_classes=3, label_name="label1", head_name="head1")1589    head2 = head_lib.multi_class_head(1590        n_classes=4, label_name="label2", head_name="head2")1591    head3 = head_lib.loss_only_head(lambda: 1.0, head_name="const")1592    head = head_lib.multi_head((head1, head2, head3))1593    labels = {1594        "label1": (1,),1595        "label2": (1,)1596    }1597    model_fn_ops = head.create_model_fn_ops(1598        features={"weights": (2.0, 10.0)},1599        labels=labels,1600        mode=model_fn.ModeKeys.TRAIN,1601        train_op_fn=head_lib.no_op_train_fn,1602        logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))1603    self.assertIsNone(model_fn_ops.predictions)1604    self.assertIsNotNone(model_fn_ops.loss)1605    self.assertIsNotNone(model_fn_ops.train_op)1606    self.assertTrue(model_fn_ops.eval_metric_ops)1607    self.assertIsNone(model_fn_ops.output_alternatives)1608    with 
session.Session() as sess:1609      self.assertAlmostEqual(3.224, sess.run(model_fn_ops.loss), places=3)1610  def testTrain_withHeadWeights(self):1611    head1 = head_lib.multi_class_head(1612        n_classes=3, label_name="label1", head_name="head1")1613    head2 = head_lib.multi_class_head(1614        n_classes=4, label_name="label2", head_name="head2")1615    head = head_lib.multi_head((head1, head2), (1, .5))1616    labels = {1617        "label1": (1,),1618        "label2": (1,)1619    }1620    model_fn_ops = head.create_model_fn_ops(1621        features={"weights": (2.0, 10.0)},1622        labels=labels,1623        mode=model_fn.ModeKeys.TRAIN,1624        train_op_fn=head_lib.no_op_train_fn,1625        logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))1626    self.assertIsNone(model_fn_ops.predictions)1627    self.assertIsNotNone(model_fn_ops.loss)1628    self.assertIsNotNone(model_fn_ops.train_op)1629    self.assertTrue(model_fn_ops.eval_metric_ops)1630    self.assertIsNone(model_fn_ops.output_alternatives)1631    with session.Session() as sess:1632      self.assertAlmostEqual(1.531, sess.run(model_fn_ops.loss), places=3)1633  def testTrain_withDictLogits(self):1634    head1 = head_lib.multi_class_head(1635        n_classes=3, label_name="label1", head_name="head1")1636    head2 = head_lib.multi_class_head(1637        n_classes=4, label_name="label2", head_name="head2")1638    head = head_lib.multi_head((head1, head2))1639    labels = {1640        "label1": (1,),1641        "label2": (1,)1642    }1643    model_fn_ops = head.create_model_fn_ops(1644        features={"weights": (2.0, 10.0)},1645        labels=labels,1646        mode=model_fn.ModeKeys.TRAIN,1647        train_op_fn=head_lib.no_op_train_fn,1648        logits={head1.head_name: ((-0.7, 0.2, .1),),1649                head2.head_name: ((.1, .1, .1, .1),)})1650    self.assertIsNone(model_fn_ops.predictions)1651    self.assertIsNotNone(model_fn_ops.loss)1652    self.assertIsNotNone(model_fn_ops.train_op)1653   
 self.assertTrue(model_fn_ops.eval_metric_ops)1654    self.assertIsNone(model_fn_ops.output_alternatives)1655    with session.Session() as sess:1656      self.assertAlmostEqual(2.224, sess.run(model_fn_ops.loss), places=3)1657  def testInfer(self):1658    head1 = head_lib.multi_class_head(1659        n_classes=3, label_name="label1", head_name="head1")1660    head2 = head_lib.multi_class_head(1661        n_classes=4, label_name="label2", head_name="head2")1662    head = head_lib.multi_head((head1, head2), (1, .5))1663    labels = {1664        "label1": (1,),1665        "label2": (1,)1666    }1667    model_fn_ops = head.create_model_fn_ops(1668        features={"weights": (2.0, 10.0)},1669        labels=labels,1670        mode=model_fn.ModeKeys.INFER,1671        train_op_fn=head_lib.no_op_train_fn,1672        logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))1673    self.assertIsNotNone(model_fn_ops.predictions)1674    self.assertIsNone(model_fn_ops.loss)1675    self.assertIsNone(model_fn_ops.train_op)1676    self.assertFalse(model_fn_ops.eval_metric_ops)1677    # Tests predictions keys.1678    self.assertItemsEqual((1679        ("head1", prediction_key.PredictionKey.LOGITS),1680        ("head1", prediction_key.PredictionKey.PROBABILITIES),1681        ("head1", prediction_key.PredictionKey.CLASSES),1682        ("head2", prediction_key.PredictionKey.LOGITS),1683        ("head2", prediction_key.PredictionKey.PROBABILITIES),1684        ("head2", prediction_key.PredictionKey.CLASSES),1685    ), model_fn_ops.predictions.keys())1686    # Tests output alternative.1687    self.assertEquals({1688        "head1": constants.ProblemType.CLASSIFICATION,1689        "head2": constants.ProblemType.CLASSIFICATION,1690    }, {1691        k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)1692    })1693    self.assertItemsEqual((1694        prediction_key.PredictionKey.PROBABILITIES,1695        prediction_key.PredictionKey.CLASSES,1696    ), 
model_fn_ops.output_alternatives["head1"][1].keys())1697    self.assertItemsEqual((1698        prediction_key.PredictionKey.PROBABILITIES,1699        prediction_key.PredictionKey.CLASSES,1700    ), model_fn_ops.output_alternatives["head2"][1].keys())1701  def testEval(self):1702    head1 = head_lib.multi_class_head(1703        n_classes=3, label_name="label1", head_name="head1")1704    head2 = head_lib.multi_class_head(1705        n_classes=4, label_name="label2", head_name="head2")1706    head = head_lib.multi_head((head1, head2), (1, .5))1707    labels = {1708        "label1": (1,),1709        "label2": (1,)1710    }1711    model_fn_ops = head.create_model_fn_ops(1712        features={"weights": (2.0, 10.0)},1713        labels=labels,1714        mode=model_fn.ModeKeys.EVAL,1715        train_op_fn=head_lib.no_op_train_fn,1716        logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))1717    self.assertIsNotNone(model_fn_ops.predictions)1718    self.assertIsNotNone(model_fn_ops.loss)1719    self.assertIsNone(model_fn_ops.train_op)1720    self.assertIsNotNone(model_fn_ops.eval_metric_ops)...multi_head.py
Source:multi_head.py  
...11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.12# See the License for the specific language governing permissions and13# limitations under the License.14# ==============================================================================15"""Abstractions for the head(s) of a model.16"""17from __future__ import absolute_import18from __future__ import division19from __future__ import print_function20import six21from tensorflow.python.estimator import model_fn22from tensorflow.python.estimator.canned import head as head_lib23from tensorflow.python.estimator.canned import metric_keys24from tensorflow.python.estimator.export import export_output as export_output_lib25from tensorflow.python.framework import ops26from tensorflow.python.ops import array_ops27from tensorflow.python.ops import control_flow_ops28from tensorflow.python.ops import math_ops29from tensorflow.python.ops import metrics as metrics_lib30from tensorflow.python.saved_model import signature_constants31from tensorflow.python.summary import summary32from tensorflow.python.training import training_util33_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY34def multi_head(heads, head_weights=None):35  """Creates a `_Head` for multi-objective learning.36  This class merges the output of multiple `_Head` objects.37  Specifically:38  * For training, sums losses of each head, calls `train_op_fn` with this39    final loss.40  * For eval, merges metrics by adding `head.name` suffix to the keys in eval41    metrics, such as `precision/head1`, `precision/head2`.42  * For prediction, merges predictions and updates keys in prediction dict to a43    2-tuple, `(head.name, prediction_key)`. 
Merges `export_outputs` such that44    by default the first head is served.45  Usage:46  ```python47  # In `input_fn` specify labels as a dict keyed by head name:48  def input_fn():49    features = ...50    labels1 = ...51    labels2 = ...52    return features, {'head1': labels1, 'head2': labels2}53  # In `model_fn`, specify logits as a dict keyed by head name:54  def model_fn(features, labels, mode):55    # Create simple heads and specify head name.56    head1 = multi_class_head(n_classes=3, name='head1')57    head2 = binary_classification_head(name='head2')58    # Create multi-head from two simple heads.59    head = multi_head([head1, head2])60    # Create logits for each head, and combine them into a dict.61    logits1, logits2 = logit_fn()62    logits = {'head1': logits1, 'head2': logits2}63    # Return the merged EstimatorSpec64    return head.create_estimator_spec(..., logits=logits, ...)65  # Create an estimator with this model_fn.66  estimator = tf.estimator.Estimator(model_fn=model_fn)67  estimator.train(input_fn=input_fn, steps=100)68  ```69  Also supports `logits` as a `Tensor` of shape70  `[D0, D1, ... DN, logits_dimension]`. It will split the `Tensor` along the71  last dimension and distribute it appropriately among the heads. E.g.:72  ```python73  def model_fn(features, labels, mode):74    # Create simple heads and specify head name.75    head1 = multi_class_head(n_classes=3, name='head1')76    head2 = binary_classification_head(name='head2')77    # Create multi-head from two simple heads.78    head = multi_head([head1, head2])79    # Create logits for the multihead.80    logits = logit_fn(logits_dimension=head.logits_dimension)81    # Return the merged EstimatorSpec82    return head.create_estimator_spec(..., logits=logits, ...)83  ```84  Args:85    heads: List or tuple of `_Head` instances. All heads must have `name`86      specified. 
The first head in the list is the default used at serving time.87    head_weights: Optional list of weights, same length as `heads`. Used when88      merging losses to calculate the weighted sum of losses from each head. If89      `None`, all losses are weighted equally.90  Returns:91    A instance of `_Head` that merges multiple heads.92  Raises:...net_HBS_solo_resnet18_score_fuse_FFM4_c2_nonlocal_ly24_mout_rnn_ly1_h1_ly4321.py
Source:net_HBS_solo_resnet18_score_fuse_FFM4_c2_nonlocal_ly24_mout_rnn_ly1_h1_ly4321.py  
1import torch2import torch.nn as nn3import torch.nn.functional as F4from net.resnet import resnet34, resnet185import copy6from torch.nn import init7from IPython import embed8from auxi.module import FFM_v4, RNNModule9def weights_init_kaiming(m):10    classname = m.__class__.__name__11    if classname.find('Conv') != -1:12        nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')13    elif classname.find('Linear') != -1:14        nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')15        nn.init.constant_(m.bias.data, 0.0)16    elif classname.find('BatchNorm1d') != -1:17        nn.init.normal_(m.weight.data, 1.0, 0.02)18        nn.init.constant_(m.bias.data, 0.0)19class net(nn.Module):20    def __init__(self, config):21        super(net, self).__init__()22        self.net_head = resnet18(pretrained=True, num_classes=26)23        self.head_layer1 = nn.Sequential(*list(self.net_head.children())[:5])24        self.head_layer2 = list(self.net_head.children())[5]25        self.head_layer3 = list(self.net_head.children())[6]26        self.head_layer4 = list(self.net_head.children())[7]27        self.head_cls = list(self.net_head.children())[-1]28        self.net_body = resnet18(pretrained=True, num_classes=26)29        self.body_layer1 = nn.Sequential(*list(self.net_body.children())[:5])30        self.body_layer2 = list(self.net_body.children())[5]31        self.body_layer3 = list(self.net_body.children())[6]32        self.body_layer4 = list(self.net_body.children())[7]33        self.body_cls = list(self.net_body.children())[-1]34        self.net_scene = resnet18(pretrained=True, num_classes=26)35        self.scene_layer1 = nn.Sequential(*list(self.net_scene.children())[:5])36        self.scene_layer2 = list(self.net_scene.children())[5]37        self.scene_layer3 = list(self.net_scene.children())[6]38        self.scene_layer4 = list(self.net_scene.children())[7]39        self.scene_cls = list(self.net_scene.children())[-1]40        self.fc3 = 
nn.Sequential(nn.Linear(512*3 , 128), nn.Linear(128, 26))41        self.csr_ly1 = RNNModule(input_size=64, hidden_size=64, num_layers=1)42        self.csr_ly2 = RNNModule(input_size=128, hidden_size=128, num_layers=1)43        self.csr_ly3 = RNNModule(input_size=256, hidden_size=256, num_layers=1)44        self.csr_ly4 = RNNModule(input_size=512, hidden_size=512, num_layers=1)45        self.ffm_head_ly2 = FFM_v4(dimension=2, in_channel=128, inter_channel=128*2)46        self.ffm_body_ly2 = FFM_v4(dimension=2, in_channel=128, inter_channel=128*2)47        self.ffm_scene_ly2 = FFM_v4(dimension=2, in_channel=128, inter_channel=128*2)48        self.ffm_head_ly4 = FFM_v4(dimension=2, in_channel=512, inter_channel=512*2)49        self.ffm_body_ly4 = FFM_v4(dimension=2, in_channel=512, inter_channel=512*2)50        self.ffm_scene_ly4 = FFM_v4(dimension=2, in_channel=512, inter_channel=512*2)51    def forward(self, data, mode):52        head_ly1 = self.head_layer1(data['image_head'])53        body_ly1 = self.body_layer1(data['image_body'])54        scene_ly1 = self.scene_layer1(data['image_scene'])55        head_ly1_ = F.adaptive_avg_pool2d(head_ly1, (1,1)).view(head_ly1.size(0), -1)56        body_ly1_ = F.adaptive_avg_pool2d(body_ly1, (1,1)).view(body_ly1.size(0), -1)57        scene_ly1_ = F.adaptive_avg_pool2d(scene_ly1, (1,1)).view(scene_ly1.size(0), -1)58        feature = self.csr_ly1(torch.stack((head_ly1_, body_ly1_, scene_ly1_), dim=1))59        head_ly1 = head_ly1+feature[:,0,:].view(head_ly1.size(0),head_ly1.size(1),1,1).expand_as(head_ly1)60        body_ly1 = body_ly1+feature[:,1,:].view(body_ly1.size(0),body_ly1.size(1),1,1).expand_as(body_ly1)61        scene_ly1 = scene_ly1+feature[:,2,:].view(scene_ly1.size(0),scene_ly1.size(1),1,1).expand_as(scene_ly1)62        head_ly2 = self.head_layer2(head_ly1)63        body_ly2 = self.body_layer2(body_ly1)64        scene_ly2 = self.scene_layer2(scene_ly1)65        head_ly2_ = F.adaptive_avg_pool2d(head_ly2, 
(1,1)).view(head_ly2.size(0), -1)66        body_ly2_ = F.adaptive_avg_pool2d(body_ly2, (1,1)).view(body_ly2.size(0), -1)67        scene_ly2_ = F.adaptive_avg_pool2d(scene_ly2, (1,1)).view(scene_ly2.size(0), -1)68        feature = self.csr_ly2(torch.stack((head_ly2_, body_ly2_, scene_ly2_), dim=1))69        head_ly2 = head_ly2+feature[:,0,:].view(head_ly2.size(0),head_ly2.size(1),1,1).expand_as(head_ly2)70        body_ly2 = body_ly2+feature[:,1,:].view(body_ly2.size(0),body_ly2.size(1),1,1).expand_as(body_ly2)71        scene_ly2 = scene_ly2+feature[:,2,:].view(scene_ly2.size(0),scene_ly2.size(1),1,1).expand_as(scene_ly2)72        head_ly2 = self.ffm_head_ly2(head_ly2, body_ly2, scene_ly2) + head_ly273        body_ly2 = self.ffm_body_ly2(body_ly2, head_ly2, scene_ly2) + body_ly274        scene_ly2 = self.ffm_scene_ly2(scene_ly2, head_ly2, body_ly2) + scene_ly275        head_ly3 = self.head_layer3(head_ly2)76        body_ly3 = self.body_layer3(body_ly2)77        scene_ly3 = self.scene_layer3(scene_ly2)78        head_ly3_ = F.adaptive_avg_pool2d(head_ly3, (1,1)).view(head_ly3.size(0), -1)79        body_ly3_ = F.adaptive_avg_pool2d(body_ly3, (1,1)).view(body_ly3.size(0), -1)80        scene_ly3_ = F.adaptive_avg_pool2d(scene_ly3, (1,1)).view(scene_ly3.size(0), -1)81        feature = self.csr_ly3(torch.stack((head_ly3_, body_ly3_, scene_ly3_), dim=1))82        head_ly3 = head_ly3+feature[:,0,:].view(head_ly3.size(0),head_ly3.size(1),1,1).expand_as(head_ly3)83        body_ly3 = body_ly3+feature[:,1,:].view(body_ly3.size(0),body_ly3.size(1),1,1).expand_as(body_ly3)84        scene_ly3 = scene_ly3+feature[:,2,:].view(scene_ly3.size(0),scene_ly3.size(1),1,1).expand_as(scene_ly3)85        head_ly4 = self.head_layer4(head_ly3)86        body_ly4 = self.body_layer4(body_ly3)87        scene_ly4 = self.scene_layer4(scene_ly3)88        head_ly4_ = F.adaptive_avg_pool2d(head_ly4, (1,1)).view(head_ly4.size(0), -1)89        body_ly4_ = F.adaptive_avg_pool2d(body_ly4, 
(1,1)).view(body_ly4.size(0), -1)90        scene_ly4_ = F.adaptive_avg_pool2d(scene_ly4, (1,1)).view(scene_ly4.size(0), -1)91        feature = self.csr_ly4(torch.stack((head_ly4_, body_ly4_, scene_ly4_), dim=1))92        head_ly4 = head_ly4+feature[:,0,:].view(head_ly4.size(0),head_ly4.size(1),1,1).expand_as(head_ly4)93        body_ly4 = body_ly4+feature[:,1,:].view(body_ly4.size(0),body_ly4.size(1),1,1).expand_as(body_ly4)94        scene_ly4 = scene_ly4+feature[:,2,:].view(scene_ly4.size(0),scene_ly4.size(1),1,1).expand_as(scene_ly4)95        head_ly4 = self.ffm_head_ly4(head_ly4, body_ly4, scene_ly4) + head_ly496        body_ly4 = self.ffm_body_ly4(body_ly4, head_ly4, scene_ly4) + body_ly497        scene_ly4 = self.ffm_scene_ly4(scene_ly4, head_ly4, body_ly4) + scene_ly498        head_ly4 = F.adaptive_avg_pool2d(head_ly4, (1,1)).view(head_ly4.size(0), -1)99        body_ly4 = F.adaptive_avg_pool2d(body_ly4, (1,1)).view(body_ly4.size(0), -1)100        scene_ly4 = F.adaptive_avg_pool2d(scene_ly4, (1,1)).view(scene_ly4.size(0), -1)101        out_head = self.head_cls(head_ly4)102        out_body = self.body_cls(body_ly4)103        out_scene = self.scene_cls(scene_ly4)104        out = self.fc3(torch.cat((head_ly4, body_ly4, scene_ly4), dim=1))...head_tail.py
Source:head_tail.py  
"""Utilities to manage head and tail of elements.

The scope is to avoid losing part of the original text in the final tree.
"""
from .tree import Item


class TokenValue:
    """Value carried by a lexer token, with its position and surrounding text.

    ``pos``, ``size``, ``head`` and ``tail`` are filled in by
    :py:class:`HeadTailLexer` while tokens are handled.
    """

    def __init__(self, value):
        self.value = value
        # Unknown until HeadTailLexer.handle_token sets them.
        self.pos = None
        self.size = None
        # Separator text kept before / after the token; empty by default.
        self.head = ""
        self.tail = ""

    def __repr__(self):
        return "TokenValue(%s)" % self.value

    def __str__(self):
        return self.value


class HeadTailLexer:
    """Utility to handle head and tail at lexer time.
    """

    # Name of the attribute under which the instance is stored on the lexer.
    LEXER_ATTR = "_luqum_headtail"

    @classmethod
    def handle(cls, token, orig_value):
        """Handle a token.

        .. note::
          PLY does not give access to previous tokens,
          nor does it provide any infrastructure for handling specific state.

          So we use the strategy
          of putting a :py:class:`HeadTailLexer` instance as an attribute of the lexer
          each time we start a new tokenization.

        :param token: the token being processed; its ``lexer``, ``lexpos``,
          ``type`` and ``value`` attributes are used
        :param orig_value: the original matched value; only its length is used
        """
        if token.lexpos == 0:
            # First token of a tokenization: start from a fresh instance.
            instance = cls()
            setattr(token.lexer, cls.LEXER_ATTR, instance)
        else:
            # Raises AttributeError if no instance was stored at position 0.
            instance = getattr(token.lexer, cls.LEXER_ATTR)
        instance.handle_token(token, orig_value)

    def __init__(self):
        self.head = None
        """This will track the head of next element, useful only for first element
        """
        self.last_elt = None
        """This will track the last token, so we can use it to add the tail to it.
        """

    def handle_token(self, token, orig_value):
        """Handle head and tail for tokens.

        The scope is to avoid losing part of the original text and keep it in elements.

        :param token: the token being processed
        :param orig_value: the original matched value; its length becomes the
          ``size`` of the token value
        """
        if token.type == "SEPARATOR":
            if token.lexpos == 0:
                # Spaces at expression start: head for the next token.
                self.head = token.value
            elif self.last_elt is not None:
                # Otherwise the separator is the tail of the last processed token.
                self.last_elt.value.tail += token.value
        else:
            # If a head is pending, apply it to this token and clear it.
            if self.head is not None:
                token.value.head = self.head
                self.head = None
            # Keep track of the token, to apply a tail to it later.
            self.last_elt = token
        # Also set pos and size on values able to carry them.
        if isinstance(token.value, (Item, TokenValue)):
            token.value.pos = token.lexpos
            token.value.size = len(orig_value)


token_headtail = HeadTailLexer.handle


class HeadTailManager:
    """Utility to handle head and tail at expression parse time.
    """

    def pos(self, p, head_transfer=False, tail_transfer=False):
        """Compute pos and size of element 0 based on its parts (p[1:]).

        :param list p: the parser expression as in PLY
        :param bool head_transfer: True if head of first child will be transferred to p[0]
        :param bool tail_transfer: True if tail of last child will be transferred to p[0]
        """
        # pos
        if p[1].pos is not None:
            p[0].pos = p[1].pos
            if not head_transfer:
                # Head isn't transferred, so we start before it.
                p[0].pos -= len(p[1].head)
        # size: every part counts with its own head and tail
        p[0].size = sum(
            (elt.size or 0) + len(elt.head or "") + len(elt.tail or "") for elt in p[1:])
        if head_transfer and p[1].head:
            # We accounted for the head in size, remove it.
            p[0].size -= len(p[1].head)
        last_p = p[len(p) - 1]  # negative indexing not supported by PLY
        if tail_transfer and last_p.tail:
            # We accounted for the tail in size, remove it.
            p[0].size -= len(last_p.tail)

    def binary_operation(self, p, op_tail):
        """Set pos and size of p[0], then subtract ``len(op_tail)`` from its size."""
        self.pos(p, head_transfer=False, tail_transfer=False)
        # correct size
        p[0].size -= len(op_tail)

    def simple_term(self, p):
        """Single element: p[0] takes both head and tail of p[1]."""
        self.pos(p, head_transfer=True, tail_transfer=True)
        p[0].head = p[1].head
        p[0].tail = p[1].tail

    def unary(self, p):
        """OP expr"""
        self.pos(p, head_transfer=True, tail_transfer=False)
        p[0].head = p[1].head
        # Text after the operator becomes part of the operand's head.
        p[2].head = p[1].tail + p[2].head

    def post_unary(self, p):
        """expr OP"""
        self.pos(p, head_transfer=False, tail_transfer=True)
        # Text before the operator becomes part of the operand's tail.
        p[1].tail += p[2].head
        p[0].tail = p[2].tail

    def paren(self, p):
        """( expr )"""
        self.pos(p, head_transfer=True, tail_transfer=True)
        # p[0] is global element (Group or FieldGroup)
        # p[2] is content
        # p[1] is left parenthesis
        p[0].head = p[1].head
        p[2].head = p[1].tail + p[2].head
        # p[3] is right parenthesis
        p[2].tail += p[3].head
        p[0].tail = p[3].tail

    def range(self, p):
        """[ expr TO expr ]"""
        self.pos(p, head_transfer=True, tail_transfer=True)
        # p[0] is global element (Range)
        # p[2] is lower bound
        p[0].head = p[1].head
        p[2].head = p[1].tail + p[2].head
        # p[3] is TO
        # p[4] is upper bound
        p[2].tail += p[3].head
        p[4].head = p[3].tail + p[4].head
        # p[5] is upper bracket
        p[4].tail += p[5].head
        p[0].tail = p[5].tail

    def search_field(self, p):
        """name: expr"""
        self.pos(p, head_transfer=True, tail_transfer=False)
        # p[0] is global element (SearchField)
        # p[1] is search field name
        # p[2] is COLUMN
        p[0].head = p[1].head
        if p[1].tail or p[2].head:
            pass  # FIXME: add warning, or handle space between point and name in SearchField ?
        # p[3] is the expression
        p[3].head = p[2].tail + p[3].head


head_tail = HeadTailManager()
"""Singleton of HeadTailManager."""

# Learn to execute automation testing from scratch with LambdaTest Learning Hub.
# Right from setting up the prerequisites to run your first automation test, to
# following best practices and diving deeper into advanced test scenarios.
# LambdaTest Learning Hubs compile a list of step-by-step guides to help you be
# proficient with different test automation frameworks i.e. Selenium, Cypress,
# TestNG etc.
You can also refer to the video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.
Get 100 automation test minutes FREE!!
