Best Python code snippet using playwright-python
head_test.py
Source: head_test.py
...93    stirling_approx = z * np.log(z) - z + 0.5 * np.log(2. * np.pi * z)94    lpl += np.ma.masked_array(stirling_approx, mask=(z <= 1)).filled(0.)95    return sum(lpl)/len(lpl)96  def testPoissonWithLogits(self):97    head = head_lib.poisson_regression_head()98    labels = ((0.,), (1.,), (1.,))99    logits = ((0.,), (-1.,), (3.,))100    with ops.Graph().as_default(), session.Session():101      model_fn_ops = head.create_model_fn_ops(102          {},103          labels=labels,104          mode=model_fn.ModeKeys.TRAIN,105          train_op_fn=head_lib.no_op_train_fn,106          logits=logits)107      self._assert_output_alternatives(model_fn_ops)108      _assert_summary_tags(self, ["loss"])109      _assert_no_variables(self)110      loss = self._log_poisson_loss(logits, labels)111      _assert_metrics(self, loss, {"loss": loss}, model_fn_ops)112class RegressionHeadTest(test.TestCase):113  def _assert_output_alternatives(self, model_fn_ops):114    self.assertEquals({115        None: constants.ProblemType.LINEAR_REGRESSION116    }, {117        k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)118    })119  # TODO(zakaria): test multilabel regression.120  def testRegressionWithLogits(self):121    head = head_lib.regression_head()122    with ops.Graph().as_default(), session.Session():123      model_fn_ops = head.create_model_fn_ops(124          {},125          labels=((0.,), (1.,), (1.,)),126          mode=model_fn.ModeKeys.TRAIN,127          train_op_fn=head_lib.no_op_train_fn,128          logits=((1.,), (1.,), (3.,)))129      self._assert_output_alternatives(model_fn_ops)130      _assert_summary_tags(self, ["loss"])131      _assert_no_variables(self)132      _assert_metrics(self, 5. / 3, {"loss": 5. 
/ 3}, model_fn_ops)133  def testRegressionWithLogitFn(self):134    head = head_lib.regression_head(link_fn=math_ops.square)135    def _assert_preditions(test_case, expected_predictions, model_fn_ops):136      variables.initialize_local_variables().run()137      test_case.assertAllClose(expected_predictions,138                               model_fn_ops.predictions["scores"].eval())139    with ops.Graph().as_default(), session.Session():140      model_fn_ops = head.create_model_fn_ops(141          {},142          labels=((0.,), (1.,), (1.,)),143          mode=model_fn.ModeKeys.TRAIN,144          train_op_fn=head_lib.no_op_train_fn,145          logits=((1.,), (1.,), (3.,)))146      self._assert_output_alternatives(model_fn_ops)147      _assert_summary_tags(self, ["loss"])148      _assert_no_variables(self)149      _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)150      _assert_preditions(self, ([1.0, 1.0, 9.0]), model_fn_ops)151  def testRegressionWithInvalidLogits(self):152    head = head_lib.regression_head()153    with ops.Graph().as_default(), session.Session():154      with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):155        head.create_model_fn_ops(156            {},157            labels=((0.,), (1.,), (1.,)),158            mode=model_fn.ModeKeys.TRAIN,159            train_op_fn=head_lib.no_op_train_fn,160            logits=((1., 1.), (1., 1.), (3., 1.)))161  def testRegressionWithLogitsInput(self):162    head = head_lib.regression_head()163    with ops.Graph().as_default(), session.Session():164      model_fn_ops = head.create_model_fn_ops(165          {},166          labels=((0.,), (1.,), (1.,)),167          mode=model_fn.ModeKeys.TRAIN,168          train_op_fn=head_lib.no_op_train_fn,169          logits_input=((0., 0.), (0., 0.), (0., 0.)))170      self._assert_output_alternatives(model_fn_ops)171      w = ("regression_head/logits/weights:0",172           "regression_head/logits/biases:0")173      _assert_variables(174  
        self, expected_global=w, expected_model=w, expected_trainable=w)175      variables.global_variables_initializer().run()176      _assert_summary_tags(self, ["loss"])177      _assert_metrics(self, 2. / 3, {"loss": 2. / 3}, model_fn_ops)178  def testRegressionWithLogitsAndLogitsInput(self):179    head = head_lib.regression_head()180    with ops.Graph().as_default(), session.Session():181      with self.assertRaisesRegexp(182          ValueError, "Both logits and logits_input supplied"):183        head.create_model_fn_ops(184            {},185            labels=((0.,), (1.,), (1.,)),186            mode=model_fn.ModeKeys.TRAIN,187            train_op_fn=head_lib.no_op_train_fn,188            logits_input=((0., 0.), (0., 0.), (0., 0.)),189            logits=((1.,), (1.,), (3.,)))190  def testRegressionEvalMode(self):191    head = head_lib.regression_head()192    with ops.Graph().as_default(), session.Session():193      model_fn_ops = head.create_model_fn_ops(194          {},195          labels=((1.,), (1.,), (3.,)),196          mode=model_fn.ModeKeys.EVAL,197          train_op_fn=head_lib.no_op_train_fn,198          logits=((0.,), (1.,), (1.,)))199      self._assert_output_alternatives(model_fn_ops)200      self.assertIsNone(model_fn_ops.train_op)201      _assert_no_variables(self)202      _assert_summary_tags(self, ["loss"])203      _assert_metrics(self, 5. / 3, {"loss": 5. 
/ 3}, model_fn_ops)204  def testRegressionWithLabelName(self):205    label_name = "my_label"206    head = head_lib.regression_head(label_name=label_name)207    with ops.Graph().as_default(), session.Session():208      model_fn_ops = head.create_model_fn_ops(209          {},210          labels={label_name: ((0.,), (1.,), (1.,))},211          mode=model_fn.ModeKeys.TRAIN,212          train_op_fn=head_lib.no_op_train_fn,213          logits=((1.,), (1.,), (3.,)))214      self._assert_output_alternatives(model_fn_ops)215      _assert_no_variables(self)216      _assert_summary_tags(self, ["loss"])217      _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)218  def testRegressionWithScalarWeights(self):219    head = head_lib.regression_head(weight_column_name="label_weight")220    with ops.Graph().as_default(), session.Session():221      weights = 2.222      labels = ((0.,), (1.,), (1.,))223      model_fn_ops = head.create_model_fn_ops(224          features={"label_weight": weights},225          labels=labels,226          mode=model_fn.ModeKeys.TRAIN,227          train_op_fn=head_lib.no_op_train_fn,228          logits=((1.,), (1.,), (3.,)))229      self._assert_output_alternatives(model_fn_ops)230      _assert_no_variables(self)231      _assert_summary_tags(self, ["loss"])232      _assert_metrics(self, (weights * 5.) / len(labels), {233          "loss": (weights * 5.) 
/ (weights * len(labels))234      }, model_fn_ops)235  def testRegressionWith1DWeights(self):236    head = head_lib.regression_head(weight_column_name="label_weight")237    with ops.Graph().as_default(), session.Session():238      weights = (2., 5., 0.)239      labels = ((0.,), (1.,), (1.,))240      model_fn_ops = head.create_model_fn_ops(241          features={"label_weight": weights},242          labels=labels,243          mode=model_fn.ModeKeys.TRAIN,244          train_op_fn=head_lib.no_op_train_fn,245          logits=((1.,), (1.,), (3.,)))246      self._assert_output_alternatives(model_fn_ops)247      _assert_no_variables(self)248      _assert_summary_tags(self, ["loss"])249      _assert_metrics(self, 2. / len(labels), {"loss": 2. / np.sum(weights)},250                      model_fn_ops)251  def testRegressionWith2DWeights(self):252    head = head_lib.regression_head(weight_column_name="label_weight")253    with ops.Graph().as_default(), session.Session():254      weights = ((2.,), (5.,), (0.,))255      labels = ((0.,), (1.,), (1.,))256      model_fn_ops = head.create_model_fn_ops(257          features={"label_weight": weights},258          labels=labels,259          mode=model_fn.ModeKeys.TRAIN,260          train_op_fn=head_lib.no_op_train_fn,261          logits=((1.,), (1.,), (3.,)))262      self._assert_output_alternatives(model_fn_ops)263      _assert_no_variables(self)264      _assert_summary_tags(self, ["loss"])265      _assert_metrics(self, 2. / len(labels), {"loss": 2. 
/ np.sum(weights)},266                      model_fn_ops)267  def testRegressionWithCenteredBias(self):268    head = head_lib.regression_head(enable_centered_bias=True)269    with ops.Graph().as_default(), session.Session():270      model_fn_ops = head.create_model_fn_ops(271          {},272          labels=((0.,), (1.,), (1.,)),273          mode=model_fn.ModeKeys.TRAIN,274          train_op_fn=head_lib.no_op_train_fn,275          logits=((1.,), (1.,), (3.,)))276      self._assert_output_alternatives(model_fn_ops)277      _assert_variables(278          self,279          expected_global=(280              "regression_head/centered_bias_weight:0",281              "regression_head/regression_head/centered_bias_weight/Adagrad:0",282          ),283          expected_trainable=("regression_head/centered_bias_weight:0",))284      variables.global_variables_initializer().run()285      _assert_summary_tags(self, [286          "loss",287          "regression_head/centered_bias/bias_0"288      ])289      _assert_metrics(self, 5. / 3, {"loss": 5. 
/ 3}, model_fn_ops)290  def testRegressionErrorInSparseTensorLabels(self):291    head = head_lib.regression_head()292    with ops.Graph().as_default():293      labels = sparse_tensor.SparseTensorValue(294          indices=((0, 0), (1, 0), (2, 0)),295          values=(0., 1., 1.),296          dense_shape=(3, 1))297      with self.assertRaisesRegexp(ValueError,298                                   "SparseTensor is not supported"):299        head.create_model_fn_ops(300            {},301            labels=labels,302            mode=model_fn.ModeKeys.TRAIN,303            train_op_fn=head_lib.no_op_train_fn,304            logits=((1.,), (1.,), (3.,)))305class MultiLabelHeadTest(test.TestCase):306  def _assert_output_alternatives(self, model_fn_ops):307    self.assertEquals({308        None: constants.ProblemType.CLASSIFICATION309    }, {310        k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)311    })312  def setUp(self):313    self._logits = ((1., 0., 0.),)314    self._labels = ((0, 0, 1),)315  def _expected_eval_metrics(self, expected_loss):316    return {317        "accuracy": 1. / 3,318        "loss": expected_loss,319        "auc": 1. 
/ 4,320        "auc/class0": 1.,321        "auc/class1": 1.,322        "auc/class2": 0.,323        "auc_precision_recall": 0.166667,324        "auc_precision_recall/class0": 0,325        "auc_precision_recall/class1": 0.,326        "auc_precision_recall/class2": 1.,327        "labels/actual_label_mean/class0": self._labels[0][0],328        "labels/actual_label_mean/class1": self._labels[0][1],329        "labels/actual_label_mean/class2": self._labels[0][2],330        "labels/logits_mean/class0": self._logits[0][0],331        "labels/logits_mean/class1": self._logits[0][1],332        "labels/logits_mean/class2": self._logits[0][2],333        "labels/prediction_mean/class0": self._logits[0][0],334        "labels/prediction_mean/class1": self._logits[0][1],335        "labels/prediction_mean/class2": self._logits[0][2],336        "labels/probability_mean/class0": _sigmoid(self._logits[0][0]),337        "labels/probability_mean/class1": _sigmoid(self._logits[0][1]),338        "labels/probability_mean/class2": _sigmoid(self._logits[0][2]),339    }340  def testMultiLabelWithLogits(self):341    n_classes = 3342    head = head_lib.multi_label_head(343        n_classes=n_classes, metric_class_ids=range(n_classes))344    with ops.Graph().as_default(), session.Session():345      model_fn_ops = head.create_model_fn_ops(346          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,347          logits=self._logits)348      self._assert_output_alternatives(model_fn_ops)349      _assert_no_variables(self)350      _assert_summary_tags(self, ["loss"])351      expected_loss = .89985204352      _assert_metrics(self, expected_loss,353                      self._expected_eval_metrics(expected_loss), model_fn_ops)354  def testMultiLabelTwoClasses(self):355    n_classes = 2356    labels = ((0, 1),)357    logits = ((1., 0.),)358    head = head_lib.multi_label_head(359        n_classes=n_classes, metric_class_ids=range(n_classes))360    with ops.Graph().as_default(), 
session.Session():361      model_fn_ops = head.create_model_fn_ops(362          {}, model_fn.ModeKeys.TRAIN, labels=labels,363          train_op_fn=head_lib.no_op_train_fn, logits=logits)364      self._assert_output_alternatives(model_fn_ops)365      _assert_no_variables(self)366      _assert_summary_tags(self, ["loss"])367      expected_loss = 1.00320443368      _assert_metrics(self, expected_loss, {369          "accuracy": 0.,370          "auc": 0.,371          "loss": expected_loss,372          "auc/class0": 1.,373          "auc/class1": 0.,374          "labels/actual_label_mean/class0": labels[0][0],375          "labels/actual_label_mean/class1": labels[0][1],376          "labels/logits_mean/class0": logits[0][0],377          "labels/logits_mean/class1": logits[0][1],378          "labels/prediction_mean/class0": logits[0][0],379          "labels/prediction_mean/class1": logits[0][1],380          "labels/probability_mean/class0": _sigmoid(logits[0][0]),381          "labels/probability_mean/class1": _sigmoid(logits[0][1]),382      }, model_fn_ops)383  def testMultiLabelWithInvalidLogits(self):384    head = head_lib.multi_label_head(n_classes=len(self._labels[0]) + 1)385    with ops.Graph().as_default(), session.Session():386      with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):387        head.create_model_fn_ops(388            {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,389            logits=self._logits)390  def testMultiLabelWithLogitsInput(self):391    n_classes = 3392    head = head_lib.multi_label_head(393        n_classes=n_classes, metric_class_ids=range(n_classes))394    with ops.Graph().as_default(), session.Session():395      model_fn_ops = head.create_model_fn_ops(396          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,397          logits_input=((0., 0.),))398      self._assert_output_alternatives(model_fn_ops)399      w = ("multi_label_head/logits/weights:0",400           
"multi_label_head/logits/biases:0")401      _assert_variables(402          self, expected_global=w, expected_model=w, expected_trainable=w)403      variables.global_variables_initializer().run()404      _assert_summary_tags(self, ["loss"])405      expected_loss = .69314718406      _assert_metrics(self, expected_loss, {407          "accuracy": 2. / 3,408          "auc": 2. / 4,409          "loss": expected_loss,410          "auc/class0": 1.,411          "auc/class1": 1.,412          "auc/class2": 0.,413          "labels/actual_label_mean/class0": self._labels[0][0],414          "labels/actual_label_mean/class1": self._labels[0][1],415          "labels/actual_label_mean/class2": self._labels[0][2],416          "labels/logits_mean/class0": 0.,417          "labels/logits_mean/class1": 0.,418          "labels/logits_mean/class2": 0.,419          "labels/prediction_mean/class0": 0.,420          "labels/prediction_mean/class1": 0.,421          "labels/prediction_mean/class2": 0.,422          "labels/probability_mean/class0": .5,423          "labels/probability_mean/class1": .5,424          "labels/probability_mean/class2": .5,425      }, model_fn_ops)426  def testMultiLabelWithLogitsAndLogitsInput(self):427    n_classes = 3428    head = head_lib.multi_label_head(429        n_classes=n_classes, metric_class_ids=range(n_classes))430    with ops.Graph().as_default(), session.Session():431      with self.assertRaisesRegexp(432          ValueError, "Both logits and logits_input supplied"):433        head.create_model_fn_ops(434            {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,435            logits_input=((0., 0.),), logits=self._logits)436  def testMultiLabelEval(self):437    n_classes = 3438    head = head_lib.multi_label_head(439        n_classes=n_classes, metric_class_ids=range(n_classes))440    with ops.Graph().as_default(), session.Session():441      model_fn_ops = head.create_model_fn_ops(442          {}, model_fn.ModeKeys.EVAL, 
self._labels, head_lib.no_op_train_fn,443          logits=self._logits)444      self._assert_output_alternatives(model_fn_ops)445      self.assertIsNone(model_fn_ops.train_op)446      _assert_no_variables(self)447      _assert_summary_tags(self, ["loss"])448      expected_loss = .89985204449      _assert_metrics(self, expected_loss,450                      self._expected_eval_metrics(expected_loss), model_fn_ops)451  def testMultiClassEvalWithLargeLogits(self):452    n_classes = 3453    head = head_lib.multi_label_head(454        n_classes=n_classes, metric_class_ids=range(n_classes))455    logits = ((2., 0., -1),)456    with ops.Graph().as_default(), session.Session():457      # logloss: z:label, x:logit458      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))459      model_fn_ops = head.create_model_fn_ops(460          {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn,461          logits=logits)462      self._assert_output_alternatives(model_fn_ops)463      self.assertIsNone(model_fn_ops.train_op)464      _assert_no_variables(self)465      _assert_summary_tags(self, ["loss"])466      expected_loss = 1.377779467      expected_eval_metrics = {468          "accuracy": 1. / 3,469          "auc": 9.99999e-07,470          "loss": expected_loss,471          "auc/class0": 1.,472          "auc/class1": 1.,473          "auc/class2": 0.,474          "labels/actual_label_mean/class0": 0. / 1,475          "labels/actual_label_mean/class1": 0. / 1,476          "labels/actual_label_mean/class2": 1. 
/ 1,477          "labels/logits_mean/class0": logits[0][0],478          "labels/logits_mean/class1": logits[0][1],479          "labels/logits_mean/class2": logits[0][2],480          "labels/prediction_mean/class0": 1,481          "labels/prediction_mean/class1": 0,482          "labels/prediction_mean/class2": 0,483          "labels/probability_mean/class0": _sigmoid(logits[0][0]),484          "labels/probability_mean/class1": _sigmoid(logits[0][1]),485          "labels/probability_mean/class2": _sigmoid(logits[0][2]),486      }487      _assert_metrics(self, expected_loss,488                      expected_eval_metrics, model_fn_ops)489  def testMultiLabelInfer(self):490    n_classes = 3491    head = head_lib.multi_label_head(n_classes=n_classes, head_name="head_name")492    with ops.Graph().as_default(), session.Session():493      model_fn_ops = head.create_model_fn_ops(494          {}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn,495          logits=((1., 0., 0.), (0., 0., 1)))496      self.assertIsNone(model_fn_ops.train_op)497      _assert_no_variables(self)498      with session.Session():499        self.assertListEqual(500            [1, 0, 0], model_fn_ops.predictions["classes"].eval().tolist()[0])501        self.assertItemsEqual(502            ["head_name"], six.iterkeys(model_fn_ops.output_alternatives))503        self.assertEqual(504            constants.ProblemType.CLASSIFICATION,505            model_fn_ops.output_alternatives["head_name"][0])506        predictions_for_serving = (507            model_fn_ops.output_alternatives["head_name"][1])508        self.assertIn("classes", six.iterkeys(predictions_for_serving))509        self.assertAllEqual(510            [[b"0", b"1", b"2"], [b"0", b"1", b"2"]],511            predictions_for_serving["classes"].eval())512        self.assertIn("probabilities", six.iterkeys(predictions_for_serving))513        self.assertAllClose(514            [[0.731059, 0.5, 0.5],515             [0.5, 0.5, 
0.731059,]],516            predictions_for_serving["probabilities"].eval())517  def testMultiLabelWithLabelName(self):518    n_classes = 3519    label_name = "my_label"520    head = head_lib.multi_label_head(521        n_classes=n_classes,522        label_name=label_name,523        metric_class_ids=range(n_classes))524    with ops.Graph().as_default(), session.Session():525      model_fn_ops = head.create_model_fn_ops(526          {}, model_fn.ModeKeys.TRAIN, {label_name: self._labels},527          head_lib.no_op_train_fn, logits=self._logits)528      self._assert_output_alternatives(model_fn_ops)529      _assert_no_variables(self)530      _assert_summary_tags(self, ["loss"])531      expected_loss = .89985204532      _assert_metrics(self, expected_loss,533                      self._expected_eval_metrics(expected_loss), model_fn_ops)534  def testMultiLabelWithScalarWeight(self):535    n_classes = 3536    head = head_lib.multi_label_head(537        n_classes=n_classes,538        weight_column_name="label_weight",539        metric_class_ids=range(n_classes))540    with ops.Graph().as_default(), session.Session():541      model_fn_ops = head.create_model_fn_ops(542          features={"label_weight": .1},543          labels=self._labels,544          mode=model_fn.ModeKeys.TRAIN,545          train_op_fn=head_lib.no_op_train_fn,546          logits=self._logits)547      self._assert_output_alternatives(model_fn_ops)548      _assert_no_variables(self)549      _assert_summary_tags(self, ["loss"])550      _assert_metrics(self, .089985214,551                      self._expected_eval_metrics(.89985214), model_fn_ops)552  def testMultiLabelWith1DWeight(self):553    n_classes = 3554    head = head_lib.multi_label_head(555        n_classes=n_classes,556        weight_column_name="label_weight",557        metric_class_ids=range(n_classes))558    with ops.Graph().as_default(), session.Session():559      with self.assertRaisesRegexp(560          ValueError, "weights can not be 
broadcast to values"):561        head.create_model_fn_ops(562            features={"label_weight": (.1, .1, .1)},563            labels=self._labels,564            mode=model_fn.ModeKeys.TRAIN,565            train_op_fn=head_lib.no_op_train_fn,566            logits=self._logits)567  def testMultiLabelWith2DWeight(self):568    n_classes = 3569    head = head_lib.multi_label_head(570        n_classes=n_classes,571        weight_column_name="label_weight",572        metric_class_ids=range(n_classes))573    with ops.Graph().as_default(), session.Session():574      model_fn_ops = head.create_model_fn_ops(575          features={"label_weight": ((.1, .1, .1),)},576          labels=self._labels,577          mode=model_fn.ModeKeys.TRAIN,578          train_op_fn=head_lib.no_op_train_fn,579          logits=self._logits)580      self._assert_output_alternatives(model_fn_ops)581      _assert_no_variables(self)582      _assert_summary_tags(self, ["loss"])583      _assert_metrics(self, .089985214,584                      self._expected_eval_metrics(.89985214), model_fn_ops)585  def testMultiLabelWithCustomLoss(self):586    n_classes = 3587    head = head_lib.multi_label_head(588        n_classes=n_classes,589        weight_column_name="label_weight",590        metric_class_ids=range(n_classes),591        loss_fn=_sigmoid_cross_entropy)592    with ops.Graph().as_default(), session.Session():593      model_fn_ops = head.create_model_fn_ops(594          features={"label_weight": .1},595          labels=self._labels,596          mode=model_fn.ModeKeys.TRAIN,597          train_op_fn=head_lib.no_op_train_fn,598          logits=self._logits)599      self._assert_output_alternatives(model_fn_ops)600      _assert_no_variables(self)601      _assert_summary_tags(self, ["loss"])602      expected_loss = .089985214603      _assert_metrics(self, expected_loss,604                      self._expected_eval_metrics(expected_loss), model_fn_ops)605  def testMultiLabelWithCenteredBias(self):606    
n_classes = 3607    head = head_lib.multi_label_head(608        n_classes=n_classes,609        enable_centered_bias=True,610        metric_class_ids=range(n_classes))611    with ops.Graph().as_default(), session.Session():612      model_fn_ops = head.create_model_fn_ops(613          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,614          logits=self._logits)615      self._assert_output_alternatives(model_fn_ops)616      _assert_variables(617          self,618          expected_global=(619              "multi_label_head/centered_bias_weight:0",620              ("multi_label_head/multi_label_head/centered_bias_weight/"621               "Adagrad:0"),),622          expected_trainable=("multi_label_head/centered_bias_weight:0",))623      variables.global_variables_initializer().run()624      _assert_summary_tags(self, (625          "loss",626          "multi_label_head/centered_bias/bias_0",627          "multi_label_head/centered_bias/bias_1",628          "multi_label_head/centered_bias/bias_2"629      ))630      expected_loss = .89985204631      _assert_metrics(self, expected_loss,632                      self._expected_eval_metrics(expected_loss), model_fn_ops)633  def testMultiLabelSparseTensorLabels(self):634    n_classes = 3635    head = head_lib.multi_label_head(636        n_classes=n_classes, metric_class_ids=range(n_classes))637    with ops.Graph().as_default(), session.Session():638      labels = sparse_tensor.SparseTensorValue(639          indices=((0, 0),),640          values=(2,),641          dense_shape=(1, 1))642      model_fn_ops = head.create_model_fn_ops(643          features={},644          mode=model_fn.ModeKeys.TRAIN,645          labels=labels,646          train_op_fn=head_lib.no_op_train_fn,647          logits=self._logits)648      _assert_no_variables(self)649      _assert_summary_tags(self, ["loss"])650      expected_loss = .89985204651      _assert_metrics(self, expected_loss,652                      
self._expected_eval_metrics(expected_loss), model_fn_ops)653  def testMultiLabelSparseTensorLabelsTooFewClasses(self):654    n_classes = 3655    head = head_lib.multi_label_head(656        n_classes=n_classes, metric_class_ids=range(n_classes))657    # Set _logits_dimension (n_classes) to a lower value; if it's set to 1658    # upfront, the class throws an error during initialization.659    head._logits_dimension = 1660    with ops.Graph().as_default(), session.Session():661      labels = sparse_tensor.SparseTensorValue(662          indices=((0, 0),),663          values=(2,),664          dense_shape=(1, 1))665      with self.assertRaisesRegexp(ValueError,666                                   "Must set num_classes >= 2 when passing"):667        head.create_model_fn_ops(668            features={},669            labels=labels,670            mode=model_fn.ModeKeys.TRAIN,671            train_op_fn=head_lib.no_op_train_fn,672            logits=[0.])673class BinaryClassificationHeadTest(test.TestCase):674  def _assert_output_alternatives(self, model_fn_ops):675    self.assertEquals({676        None: constants.ProblemType.LOGISTIC_REGRESSION677    }, {678        k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)679    })680  def setUp(self):681    self._logits = ((1.,), (1.,))682    self._labels = ((1.,), (0.,))683  def _expected_eval_metrics(self, expected_loss):684    label_mean = np.mean(self._labels)685    return {686        "accuracy": 1. / 2,687        "accuracy/baseline_label_mean": label_mean,688        "accuracy/threshold_0.500000_mean": 1. / 2,689        "auc": 1. / 2,690        "auc_precision_recall": 0.749999,691        "labels/actual_label_mean": label_mean,692        "labels/prediction_mean": .731059,  # softmax693        "loss": expected_loss,694        "precision/positive_threshold_0.500000_mean": 1. / 2,695        "recall/positive_threshold_0.500000_mean": 1. 
/ 1,696    }697  def testBinaryClassificationWithLogits(self):698    n_classes = 2699    head = head_lib.multi_class_head(n_classes=n_classes)700    with ops.Graph().as_default(), session.Session():701      # logloss: z:label, x:logit702      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))703      model_fn_ops = head.create_model_fn_ops(704          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,705          logits=self._logits)706      self._assert_output_alternatives(model_fn_ops)707      _assert_no_variables(self)708      _assert_summary_tags(self, ["loss"])709      expected_loss = .81326175710      _assert_metrics(self, expected_loss,711                      self._expected_eval_metrics(expected_loss), model_fn_ops)712  def testBinaryClassificationWithInvalidLogits(self):713    head = head_lib.multi_class_head(n_classes=len(self._labels) + 1)714    with ops.Graph().as_default(), session.Session():715      with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):716        head.create_model_fn_ops(717            {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,718            logits=self._logits)719  def testBinaryClassificationWithLogitsInput(self):720    n_classes = 2721    head = head_lib.multi_class_head(n_classes=n_classes)722    with ops.Graph().as_default(), session.Session():723      # logloss: z:label, x:logit724      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))725      model_fn_ops = head.create_model_fn_ops(726          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,727          logits_input=((0., 0.), (0., 0.)))728      self._assert_output_alternatives(model_fn_ops)729      w = ("binary_logistic_head/logits/weights:0",730           "binary_logistic_head/logits/biases:0")731      _assert_variables(732          self, expected_global=w, expected_model=w, expected_trainable=w)733      variables.global_variables_initializer().run()734      _assert_summary_tags(self, 
["loss"])735      expected_loss = .69314718736      label_mean = np.mean(self._labels)737      _assert_metrics(self, expected_loss, {738          "accuracy": 1. / 2,739          "accuracy/baseline_label_mean": label_mean,740          "accuracy/threshold_0.500000_mean": 1. / 2,741          "auc": 1. / 2,742          "labels/actual_label_mean": label_mean,743          "labels/prediction_mean": .5,  # softmax744          "loss": expected_loss,745          "precision/positive_threshold_0.500000_mean": 0. / 2,746          "recall/positive_threshold_0.500000_mean": 0. / 1,747      }, model_fn_ops)748  def testBinaryClassificationWithLogitsAndLogitsInput(self):749    head = head_lib.multi_class_head(n_classes=2)750    with ops.Graph().as_default(), session.Session():751      with self.assertRaisesRegexp(752          ValueError, "Both logits and logits_input supplied"):753        head.create_model_fn_ops(754            {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,755            logits_input=((0., 0.), (0., 0.)), logits=self._logits)756  def testBinaryClassificationEval(self):757    n_classes = 2758    head = head_lib.multi_class_head(n_classes=n_classes)759    with ops.Graph().as_default(), session.Session():760      # logloss: z:label, x:logit761      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))762      model_fn_ops = head.create_model_fn_ops(763          {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn,764          logits=self._logits)765      self._assert_output_alternatives(model_fn_ops)766      self.assertIsNone(model_fn_ops.train_op)767      _assert_no_variables(self)768      _assert_summary_tags(self, ["loss"])769      expected_loss = .81326175770      _assert_metrics(self, expected_loss,771                      self._expected_eval_metrics(expected_loss), model_fn_ops)772  def testBinaryClassificationInfer(self):773    n_classes = 2774    head = head_lib.multi_class_head(n_classes=n_classes, head_name="head_name")775 
   with ops.Graph().as_default(), session.Session():776      # logloss: z:label, x:logit777      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))778      model_fn_ops = head.create_model_fn_ops(779          {}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn,780          logits=self._logits)781      self.assertIsNone(model_fn_ops.train_op)782      _assert_no_variables(self)783      with session.Session():784        self.assertListEqual(785            [1, 1], list(model_fn_ops.predictions["classes"].eval()))786        self.assertItemsEqual(787            ["head_name"], six.iterkeys(model_fn_ops.output_alternatives))788        self.assertEqual(789            constants.ProblemType.LOGISTIC_REGRESSION,790            model_fn_ops.output_alternatives["head_name"][0])791        predictions_for_serving = (792            model_fn_ops.output_alternatives["head_name"][1])793        self.assertIn("classes", six.iterkeys(predictions_for_serving))794        predicted_classes = predictions_for_serving["classes"].eval().tolist()795        self.assertListEqual(796            [b"0", b"1"], predicted_classes[0])797        self.assertIn("probabilities", six.iterkeys(predictions_for_serving))798  def testBinaryClassificationInferMode_withWeightColumn(self):799    n_classes = 2800    head = head_lib.multi_class_head(n_classes=n_classes,801                                     weight_column_name="label_weight")802    with ops.Graph().as_default(), session.Session():803      # logloss: z:label, x:logit804      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))805      model_fn_ops = head.create_model_fn_ops(806          # This is what is being tested, features should not have weight for807          # inference.808          {}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn,809          logits=self._logits)810      self._assert_output_alternatives(model_fn_ops)811      self.assertIsNone(model_fn_ops.train_op)812      _assert_no_variables(self)813  
def testErrorInSparseTensorLabels(self):814    n_classes = 2815    head = head_lib.multi_class_head(n_classes=n_classes)816    with ops.Graph().as_default():817      labels = sparse_tensor.SparseTensorValue(818          indices=((0, 0), (1, 0), (2, 0)),819          values=(0, 1, 1),820          dense_shape=(3, 1))821      with self.assertRaisesRegexp(ValueError,822                                   "SparseTensor is not supported"):823        head.create_model_fn_ops(824            {},825            model_fn.ModeKeys.TRAIN,826            labels,827            head_lib.no_op_train_fn,828            logits=((1.,), (1.,), (3.,)))829  def testBinaryClassificationWithLabelName(self):830    label_name = "my_label"831    head = head_lib.multi_class_head(n_classes=2, label_name=label_name)832    with ops.Graph().as_default(), session.Session():833      # logloss: z:label, x:logit834      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))835      model_fn_ops = head.create_model_fn_ops(836          {},837          labels={label_name: self._labels},838          mode=model_fn.ModeKeys.TRAIN,839          train_op_fn=head_lib.no_op_train_fn,840          logits=self._logits)841      self._assert_output_alternatives(model_fn_ops)842      _assert_no_variables(self)843      _assert_summary_tags(self, ["loss"])844      expected_loss = .81326175845      _assert_metrics(self, expected_loss,846                      self._expected_eval_metrics(expected_loss), model_fn_ops)847  def testBinaryClassificationWith1DWeights(self):848    n_classes = 2849    head = head_lib.multi_class_head(850        n_classes=n_classes, weight_column_name="label_weight")851    with ops.Graph().as_default(), session.Session():852      weights = (1., 0.)853      # logloss: z:label, x:logit854      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))855      model_fn_ops = head.create_model_fn_ops(856          features={"label_weight": weights},857          labels=self._labels,858          
mode=model_fn.ModeKeys.TRAIN,859          train_op_fn=head_lib.no_op_train_fn,860          logits=self._logits)861      self._assert_output_alternatives(model_fn_ops)862      _assert_no_variables(self)863      _assert_summary_tags(self, ["loss"])864      expected_total_loss = .31326166865      _assert_metrics(866          self,867          expected_total_loss / len(weights),868          {869              "accuracy": 1. / 1,870              "accuracy/baseline_label_mean": 1. / 1,871              "accuracy/threshold_0.500000_mean": 1. / 1,872              "auc": 0. / 1,873              "labels/actual_label_mean": 1. / 1,874              "labels/prediction_mean": .731059,  # softmax875              # eval loss is weighted loss divided by sum of weights.876              "loss": expected_total_loss,877              "precision/positive_threshold_0.500000_mean": 1. / 1,878              "recall/positive_threshold_0.500000_mean": 1. / 1,879          },880          model_fn_ops)881  def testBinaryClassificationWith2DWeights(self):882    n_classes = 2883    head = head_lib.multi_class_head(884        n_classes=n_classes, weight_column_name="label_weight")885    with ops.Graph().as_default(), session.Session():886      weights = ((1.,), (0.,))887      # logloss: z:label, x:logit888      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))889      model_fn_ops = head.create_model_fn_ops(890          features={"label_weight": weights},891          labels=self._labels,892          mode=model_fn.ModeKeys.TRAIN,893          train_op_fn=head_lib.no_op_train_fn,894          logits=self._logits)895      self._assert_output_alternatives(model_fn_ops)896      _assert_no_variables(self)897      _assert_summary_tags(self, ["loss"])898      expected_total_loss = .31326166899      _assert_metrics(900          self,901          expected_total_loss / len(weights),902          {903              "accuracy": 1. / 1,904              "accuracy/baseline_label_mean": 1. 
/ 1,905              "accuracy/threshold_0.500000_mean": 1. / 1,906              "auc": 0. / 1,907              "labels/actual_label_mean": 1. / 1,908              "labels/prediction_mean": .731059,  # softmax909              # eval loss is weighted loss divided by sum of weights.910              "loss": expected_total_loss,911              "precision/positive_threshold_0.500000_mean": 1. / 1,912              "recall/positive_threshold_0.500000_mean": 1. / 1,913          },914          model_fn_ops)915  def testBinaryClassificationWithCustomLoss(self):916    head = head_lib.multi_class_head(917        n_classes=2, weight_column_name="label_weight",918        loss_fn=_sigmoid_cross_entropy)919    with ops.Graph().as_default(), session.Session():920      weights = ((.2,), (0.,))921      model_fn_ops = head.create_model_fn_ops(922          features={"label_weight": weights},923          labels=self._labels,924          mode=model_fn.ModeKeys.TRAIN,925          train_op_fn=head_lib.no_op_train_fn,926          logits=self._logits)927      self._assert_output_alternatives(model_fn_ops)928      _assert_no_variables(self)929      _assert_summary_tags(self, ["loss"])930      # logloss: z:label, x:logit931      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))932      # expected_loss is (total_weighted_loss)/1 since there is 1 nonzero933      # weight.934      expected_loss = 0.062652342935      _assert_metrics(936          self,937          expected_loss,938          {939              "accuracy": 1. / 1,940              "accuracy/baseline_label_mean": 1. / 1,941              "accuracy/threshold_0.500000_mean": 1. / 1,942              "auc": 0. / 1,943              "labels/actual_label_mean": 1. / 1,944              "labels/prediction_mean": .731059,  # softmax945              "loss": expected_loss,946              "precision/positive_threshold_0.500000_mean": 1. / 1,947              "recall/positive_threshold_0.500000_mean": 1. 
/ 1,948          },949          model_fn_ops)950  def testBinaryClassificationWithCenteredBias(self):951    head = head_lib.multi_class_head(n_classes=2, enable_centered_bias=True)952    with ops.Graph().as_default(), session.Session():953      # logloss: z:label, x:logit954      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))955      model_fn_ops = head.create_model_fn_ops(956          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,957          logits=self._logits)958      self._assert_output_alternatives(model_fn_ops)959      _assert_variables(960          self,961          expected_global=(962              "binary_logistic_head/centered_bias_weight:0",963              ("binary_logistic_head/binary_logistic_head/centered_bias_weight/"964               "Adagrad:0"),),965          expected_trainable=("binary_logistic_head/centered_bias_weight:0",))966      variables.global_variables_initializer().run()967      _assert_summary_tags(self, [968          "loss",969          "binary_logistic_head/centered_bias/bias_0"970      ])971      expected_loss = .81326175972      _assert_metrics(self, expected_loss,973                      self._expected_eval_metrics(expected_loss), model_fn_ops)974class MultiClassHeadTest(test.TestCase):975  def _assert_output_alternatives(self, model_fn_ops):976    self.assertEquals({977        None: constants.ProblemType.CLASSIFICATION978    }, {979        k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)980    })981  def setUp(self):982    self._logits = ((1., 0., 0.),)983    self._labels = ((2,),)984  def _expected_eval_metrics(self, expected_loss):985    return {986        "accuracy": 0.,987        "loss": expected_loss,988        "labels/actual_label_mean/class0": 0. / 1,989        "labels/actual_label_mean/class1": 0. / 1,990        "labels/actual_label_mean/class2": 1. 
/ 1,991        "labels/logits_mean/class0": self._logits[0][0],992        "labels/logits_mean/class1": self._logits[0][1],993        "labels/logits_mean/class2": self._logits[0][2],994        "labels/prediction_mean/class0": self._logits[0][0],995        "labels/prediction_mean/class1": self._logits[0][1],996        "labels/prediction_mean/class2": self._logits[0][2],997        "labels/probability_mean/class0": 0.576117,  # softmax998        "labels/probability_mean/class1": 0.211942,  # softmax999        "labels/probability_mean/class2": 0.211942,  # softmax1000    }1001  def testMultiClassWithLogits(self):1002    n_classes = 31003    head = head_lib.multi_class_head(1004        n_classes=n_classes, metric_class_ids=range(n_classes))1005    with ops.Graph().as_default(), session.Session():1006      # logloss: z:label, x:logit1007      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1008      model_fn_ops = head.create_model_fn_ops(1009          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,1010          logits=self._logits)1011      self._assert_output_alternatives(model_fn_ops)1012      _assert_no_variables(self)1013      _assert_summary_tags(self, ["loss"])1014      expected_loss = 1.55144471015      _assert_metrics(self, expected_loss,1016                      self._expected_eval_metrics(expected_loss), model_fn_ops)1017  def testMultiClassWithInvalidLogits(self):1018    head = head_lib.multi_class_head(n_classes=len(self._logits[0]) + 1)1019    with ops.Graph().as_default(), session.Session():1020      with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):1021        head.create_model_fn_ops(1022            {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,1023            logits=self._logits)1024  def testMultiClassWithNoneTrainOpFnInTrain(self):1025    head = head_lib.multi_class_head(n_classes=3)1026    with ops.Graph().as_default(), session.Session():1027      with self.assertRaisesRegexp(1028    
      ValueError, "train_op_fn can not be None in TRAIN mode"):1029        head.create_model_fn_ops(1030            {}, model_fn.ModeKeys.TRAIN, self._labels,1031            train_op_fn=None,1032            logits=self._logits)1033  def testMultiClassWithLogitsInput(self):1034    n_classes = 31035    head = head_lib.multi_class_head(1036        n_classes=n_classes, metric_class_ids=range(n_classes))1037    with ops.Graph().as_default(), session.Session():1038      # logloss: z:label, x:logit1039      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1040      model_fn_ops = head.create_model_fn_ops(1041          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,1042          logits_input=((0., 0.),))1043      self._assert_output_alternatives(model_fn_ops)1044      w = ("multi_class_head/logits/weights:0",1045           "multi_class_head/logits/biases:0")1046      _assert_variables(1047          self, expected_global=w, expected_model=w, expected_trainable=w)1048      variables.global_variables_initializer().run()1049      _assert_summary_tags(self, ["loss"])1050      expected_loss = 1.09861231051      _assert_metrics(self, expected_loss, {1052          "accuracy": 0.,1053          "loss": expected_loss,1054          "labels/actual_label_mean/class0": 0. / 1,1055          "labels/actual_label_mean/class1": 0. / 1,1056          "labels/actual_label_mean/class2": 1. 
/ 1,1057          "labels/logits_mean/class0": 0.,1058          "labels/logits_mean/class1": 0.,1059          "labels/logits_mean/class2": 0.,1060          "labels/prediction_mean/class0": 1.,1061          "labels/prediction_mean/class1": 0.,1062          "labels/prediction_mean/class2": 0.,1063          "labels/probability_mean/class0": 0.333333,  # softmax1064          "labels/probability_mean/class1": 0.333333,  # softmax1065          "labels/probability_mean/class2": 0.333333,  # softmax1066      }, model_fn_ops)1067  def testMultiClassWithLogitsAndLogitsInput(self):1068    n_classes = 31069    head = head_lib.multi_class_head(1070        n_classes=n_classes, metric_class_ids=range(n_classes))1071    with ops.Graph().as_default(), session.Session():1072      with self.assertRaisesRegexp(1073          ValueError, "Both logits and logits_input supplied"):1074        head.create_model_fn_ops(1075            {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,1076            logits_input=((0., 0.),), logits=self._logits)1077  def testMultiClassEnableCenteredBias(self):1078    n_classes = 31079    head = head_lib.multi_class_head(1080        n_classes=n_classes, enable_centered_bias=True)1081    with ops.Graph().as_default(), session.Session():1082      # logloss: z:label, x:logit1083      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1084      model_fn_ops = head.create_model_fn_ops(1085          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,1086          logits=self._logits)1087      self._assert_output_alternatives(model_fn_ops)1088      _assert_variables(1089          self,1090          expected_global=(1091              "multi_class_head/centered_bias_weight:0",1092              ("multi_class_head/multi_class_head/centered_bias_weight/"1093               "Adagrad:0"),1094          ),1095          expected_trainable=("multi_class_head/centered_bias_weight:0",))1096      
variables.global_variables_initializer().run()1097      _assert_summary_tags(self,1098                           ["loss",1099                            "multi_class_head/centered_bias/bias_0",1100                            "multi_class_head/centered_bias/bias_1",1101                            "multi_class_head/centered_bias/bias_2"])1102  def testMultiClassEval(self):1103    n_classes = 31104    head = head_lib.multi_class_head(1105        n_classes=n_classes, metric_class_ids=range(n_classes))1106    with ops.Graph().as_default(), session.Session():1107      # logloss: z:label, x:logit1108      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1109      model_fn_ops = head.create_model_fn_ops(1110          {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn,1111          logits=self._logits)1112      self._assert_output_alternatives(model_fn_ops)1113      self.assertIsNone(model_fn_ops.train_op)1114      _assert_no_variables(self)1115      _assert_summary_tags(self, ["loss"])1116      expected_loss = 1.55144471117      _assert_metrics(self, expected_loss,1118                      self._expected_eval_metrics(expected_loss), model_fn_ops)1119  def testMultiClassEvalModeWithLargeLogits(self):1120    n_classes = 31121    head = head_lib.multi_class_head(1122        n_classes=n_classes, metric_class_ids=range(n_classes))1123    logits = ((2., 0., -1),)1124    with ops.Graph().as_default(), session.Session():1125      # logloss: z:label, x:logit1126      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1127      model_fn_ops = head.create_model_fn_ops(1128          {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn,1129          logits=logits)1130      self._assert_output_alternatives(model_fn_ops)1131      self.assertIsNone(model_fn_ops.train_op)1132      _assert_no_variables(self)1133      _assert_summary_tags(self, ["loss"])1134      expected_loss = 3.16984611135      expected_eval_metrics = {1136          "accuracy": 
0.,1137          "loss": expected_loss,1138          "labels/actual_label_mean/class0": 0. / 1,1139          "labels/actual_label_mean/class1": 0. / 1,1140          "labels/actual_label_mean/class2": 1. / 1,1141          "labels/logits_mean/class0": logits[0][0],1142          "labels/logits_mean/class1": logits[0][1],1143          "labels/logits_mean/class2": logits[0][2],1144          "labels/prediction_mean/class0": 1,1145          "labels/prediction_mean/class1": 0,1146          "labels/prediction_mean/class2": 0,1147          "labels/probability_mean/class0": 0.843795,  # softmax1148          "labels/probability_mean/class1": 0.114195,  # softmax1149          "labels/probability_mean/class2": 0.0420101,  # softmax1150      }1151      _assert_metrics(self, expected_loss,1152                      expected_eval_metrics, model_fn_ops)1153  def testMultiClassWithScalarWeight(self):1154    n_classes = 31155    head = head_lib.multi_class_head(1156        n_classes=n_classes,1157        weight_column_name="label_weight",1158        metric_class_ids=range(n_classes))1159    with ops.Graph().as_default(), session.Session():1160      weight = .11161      # logloss: z:label, x:logit1162      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1163      model_fn_ops = head.create_model_fn_ops(1164          features={"label_weight": weight},1165          labels=self._labels,1166          mode=model_fn.ModeKeys.TRAIN,1167          train_op_fn=head_lib.no_op_train_fn,1168          logits=self._logits)1169      self._assert_output_alternatives(model_fn_ops)1170      _assert_no_variables(self)1171      _assert_summary_tags(self, ["loss"])1172      expected_loss = 1.55144471173      _assert_metrics(self, expected_loss * weight,1174                      self._expected_eval_metrics(expected_loss), model_fn_ops)1175  def testMultiClassWith1DWeight(self):1176    n_classes = 31177    head = head_lib.multi_class_head(1178        n_classes=n_classes,1179        
weight_column_name="label_weight",1180        metric_class_ids=range(n_classes))1181    with ops.Graph().as_default(), session.Session():1182      weight = .11183      weights = (weight,)1184      # logloss: z:label, x:logit1185      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1186      model_fn_ops = head.create_model_fn_ops(1187          features={"label_weight": weights},1188          labels=self._labels,1189          mode=model_fn.ModeKeys.TRAIN,1190          train_op_fn=head_lib.no_op_train_fn,1191          logits=self._logits)1192      self._assert_output_alternatives(model_fn_ops)1193      _assert_no_variables(self)1194      _assert_summary_tags(self, ["loss"])1195      expected_loss = 1.55144471196      _assert_metrics(self, expected_loss * weight,1197                      self._expected_eval_metrics(expected_loss), model_fn_ops)1198  def testMultiClassWith2DWeight(self):1199    n_classes = 31200    head = head_lib.multi_class_head(1201        n_classes=n_classes,1202        weight_column_name="label_weight",1203        metric_class_ids=range(n_classes))1204    with ops.Graph().as_default(), session.Session():1205      weight = .11206      weights = ((weight,),)1207      # logloss: z:label, x:logit1208      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1209      model_fn_ops = head.create_model_fn_ops(1210          features={"label_weight": weights},1211          labels=self._labels,1212          mode=model_fn.ModeKeys.TRAIN,1213          train_op_fn=head_lib.no_op_train_fn,1214          logits=self._logits)1215      self._assert_output_alternatives(model_fn_ops)1216      _assert_no_variables(self)1217      _assert_summary_tags(self, ["loss"])1218      expected_loss = 1.55144471219      _assert_metrics(self, expected_loss * weight,1220                      self._expected_eval_metrics(expected_loss), model_fn_ops)1221  def testMultiClassWithCustomLoss(self):1222    n_classes = 31223    head = head_lib.multi_class_head(1224        
n_classes=n_classes,1225        weight_column_name="label_weight",1226        metric_class_ids=range(n_classes),1227        loss_fn=losses_lib.sparse_softmax_cross_entropy)1228    with ops.Graph().as_default(), session.Session():1229      weight = .11230      # logloss: z:label, x:logit1231      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1232      model_fn_ops = head.create_model_fn_ops(1233          features={"label_weight": weight},1234          labels=self._labels,1235          mode=model_fn.ModeKeys.TRAIN,1236          train_op_fn=head_lib.no_op_train_fn,1237          logits=self._logits)1238      self._assert_output_alternatives(model_fn_ops)1239      _assert_no_variables(self)1240      _assert_summary_tags(self, ["loss"])1241      expected_loss = 1.5514447 * weight1242      _assert_metrics(self, expected_loss,1243                      self._expected_eval_metrics(expected_loss), model_fn_ops)1244  def testMultiClassInfer(self):1245    n_classes = 31246    head = head_lib._multi_class_head(1247        n_classes=n_classes,1248        head_name="head_name")1249    with ops.Graph().as_default():1250      model_fn_ops = head.create_model_fn_ops(1251          features={},1252          mode=model_fn.ModeKeys.INFER,1253          train_op_fn=head_lib.no_op_train_fn,1254          logits=((1., 0., 0.), (0., 0., 1.),))1255      with session.Session():1256        lookup_ops.tables_initializer().run()1257        self.assertAllEqual(1258            [0, 2],1259            model_fn_ops.predictions["classes"].eval())1260        self.assertItemsEqual(1261            ["head_name"], six.iterkeys(model_fn_ops.output_alternatives))1262        self.assertEqual(1263            constants.ProblemType.CLASSIFICATION,1264            model_fn_ops.output_alternatives["head_name"][0])1265        predictions_for_serving = (1266            model_fn_ops.output_alternatives["head_name"][1])1267        self.assertIn("classes", six.iterkeys(predictions_for_serving))1268        
self.assertAllEqual(1269            [[b"0", b"1", b"2"], [b"0", b"1", b"2"]],1270            predictions_for_serving["classes"].eval())1271        self.assertIn("probabilities", six.iterkeys(predictions_for_serving))1272        self.assertAllClose(1273            [[0.576117, 0.2119416, 0.2119416],1274             [0.2119416, 0.2119416, 0.576117]],1275            predictions_for_serving["probabilities"].eval())1276  def testInvalidNClasses(self):1277    for n_classes in (None, -1, 0, 1):1278      with self.assertRaisesRegexp(ValueError, "n_classes must be > 1"):1279        head_lib.multi_class_head(n_classes=n_classes)1280  def testMultiClassWithLabelKeysInvalidShape(self):1281    with self.assertRaisesRegexp(1282        ValueError, "Length of label_keys must equal n_classes"):1283      head_lib._multi_class_head(1284          n_classes=3, label_keys=("key0", "key1"))1285  def testMultiClassWithLabelKeysTwoClasses(self):1286    with self.assertRaisesRegexp(1287        ValueError, "label_keys is not supported for n_classes=2"):1288      head_lib._multi_class_head(1289          n_classes=2, label_keys=("key0", "key1"))1290  def testMultiClassWithLabelKeysInfer(self):1291    n_classes = 31292    label_keys = ("key0", "key1", "key2")1293    head = head_lib._multi_class_head(1294        n_classes=n_classes, label_keys=label_keys,1295        metric_class_ids=range(n_classes),1296        head_name="head_name")1297    with ops.Graph().as_default():1298      model_fn_ops = head.create_model_fn_ops(1299          features={},1300          mode=model_fn.ModeKeys.INFER,1301          train_op_fn=head_lib.no_op_train_fn,1302          logits=((1., 0., 0.), (0., 0., 1.),))1303      with session.Session():1304        lookup_ops.tables_initializer().run()1305        self.assertAllEqual(1306            [b"key0", b"key2"],1307            model_fn_ops.predictions["classes"].eval())1308        self.assertItemsEqual(1309            ["head_name"], 
six.iterkeys(model_fn_ops.output_alternatives))1310        self.assertEqual(1311            constants.ProblemType.CLASSIFICATION,1312            model_fn_ops.output_alternatives["head_name"][0])1313        predictions_for_serving = (1314            model_fn_ops.output_alternatives["head_name"][1])1315        self.assertIn("classes", six.iterkeys(predictions_for_serving))1316        self.assertAllEqual(1317            [[b"key0", b"key1", b"key2"], [b"key0", b"key1", b"key2"]],1318            predictions_for_serving["classes"].eval())1319        self.assertIn("probabilities", six.iterkeys(predictions_for_serving))1320        self.assertAllClose(1321            [[0.576117, 0.2119416, 0.2119416],1322             [0.2119416, 0.2119416, 0.576117]],1323            predictions_for_serving["probabilities"].eval())1324  def testMultiClassWithLabelKeysEvalAccuracy0(self):1325    n_classes = 31326    label_keys = ("key0", "key1", "key2")1327    head = head_lib._multi_class_head(1328        n_classes=n_classes,1329        label_keys=label_keys)1330    with ops.Graph().as_default():1331      model_fn_ops = head.create_model_fn_ops(1332          features={},1333          mode=model_fn.ModeKeys.EVAL,1334          labels=("key2",),1335          train_op_fn=head_lib.no_op_train_fn,1336          logits=((1., 0., 0.),))1337      with session.Session():1338        lookup_ops.tables_initializer().run()1339        self.assertIsNone(model_fn_ops.train_op)1340        _assert_no_variables(self)1341        _assert_summary_tags(self, ["loss"])1342        expected_loss = 1.55144471343        expected_eval_metrics = {1344            "accuracy": 0.,1345            "loss": expected_loss,1346        }1347        _assert_metrics(self, expected_loss,1348                        expected_eval_metrics, model_fn_ops)1349  def testMultiClassWithLabelKeysEvalAccuracy1(self):1350    n_classes = 31351    label_keys = ("key0", "key1", "key2")1352    head = head_lib._multi_class_head(1353        
n_classes=n_classes,1354        label_keys=label_keys)1355    with ops.Graph().as_default():1356      model_fn_ops = head.create_model_fn_ops(1357          features={},1358          mode=model_fn.ModeKeys.EVAL,1359          labels=("key2",),1360          train_op_fn=head_lib.no_op_train_fn,1361          logits=((0., 0., 1.),))1362      with session.Session():1363        lookup_ops.tables_initializer().run()1364        self.assertIsNone(model_fn_ops.train_op)1365        _assert_no_variables(self)1366        _assert_summary_tags(self, ["loss"])1367        expected_loss = 0.55144471368        expected_eval_metrics = {1369            "accuracy": 1.,1370            "loss": expected_loss,1371        }1372        _assert_metrics(self, expected_loss,1373                        expected_eval_metrics, model_fn_ops)1374class BinarySvmHeadTest(test.TestCase):1375  def _assert_output_alternatives(self, model_fn_ops):1376    self.assertEquals({1377        None: constants.ProblemType.LOGISTIC_REGRESSION1378    }, {1379        k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)1380    })1381  def setUp(self):1382    # Prediction for first example is in the right side of the hyperplane1383    # (i.e., < 0) but it is within the [-1,1] margin. There is a 0.5 loss1384    # incurred by this example. 
The 2nd prediction is outside the margin so it1385    # incurs no loss at all.1386    self._predictions = ((-.5,), (1.2,))1387    self._labels = (0, 1)1388    self._expected_losses = (.5, 0.)1389  def testBinarySVMWithLogits(self):1390    head = head_lib.binary_svm_head()1391    with ops.Graph().as_default(), session.Session():1392      model_fn_ops = head.create_model_fn_ops(1393          {},1394          model_fn.ModeKeys.TRAIN,1395          self._labels,1396          head_lib.no_op_train_fn,1397          logits=self._predictions)1398      self._assert_output_alternatives(model_fn_ops)1399      _assert_no_variables(self)1400      _assert_summary_tags(self, ["loss"])1401      expected_loss = np.average(self._expected_losses)1402      _assert_metrics(self, expected_loss, {1403          "accuracy": 1.,1404          "loss": expected_loss,1405      }, model_fn_ops)1406  def testBinarySVMWithInvalidLogits(self):1407    head = head_lib.binary_svm_head()1408    with ops.Graph().as_default(), session.Session():1409      with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):1410        head.create_model_fn_ops(1411            {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,1412            logits=np.ones((2, 2)))1413  def testBinarySVMWithLogitsInput(self):1414    head = head_lib.binary_svm_head()1415    with ops.Graph().as_default(), session.Session():1416      model_fn_ops = head.create_model_fn_ops(1417          {},1418          model_fn.ModeKeys.TRAIN,1419          self._labels,1420          head_lib.no_op_train_fn,1421          logits_input=((0., 0.), (0., 0.)))1422      self._assert_output_alternatives(model_fn_ops)1423      w = ("binary_svm_head/logits/weights:0",1424           "binary_svm_head/logits/biases:0")1425      _assert_variables(1426          self, expected_global=w, expected_model=w, expected_trainable=w)1427      variables.global_variables_initializer().run()1428      _assert_summary_tags(self, ["loss"])1429      
expected_loss = 1.1430      _assert_metrics(self, expected_loss, {1431          "accuracy": .5,1432          "loss": expected_loss,1433      }, model_fn_ops)1434  def testBinarySVMWithLogitsAndLogitsInput(self):1435    head = head_lib.binary_svm_head()1436    with ops.Graph().as_default(), session.Session():1437      with self.assertRaisesRegexp(1438          ValueError, "Both logits and logits_input supplied"):1439        head.create_model_fn_ops(1440            {},1441            model_fn.ModeKeys.TRAIN,1442            self._labels,1443            head_lib.no_op_train_fn,1444            logits_input=((0., 0.), (0., 0.)),1445            logits=self._predictions)1446  def testBinarySVMEvalMode(self):1447    head = head_lib.binary_svm_head()1448    with ops.Graph().as_default(), session.Session():1449      model_fn_ops = head.create_model_fn_ops(1450          {},1451          model_fn.ModeKeys.EVAL,1452          self._labels,1453          head_lib.no_op_train_fn,1454          logits=self._predictions)1455      self._assert_output_alternatives(model_fn_ops)1456      self.assertIsNone(model_fn_ops.train_op)1457      _assert_no_variables(self)1458      _assert_summary_tags(self, ["loss"])1459      expected_loss = np.average(self._expected_losses)1460      _assert_metrics(self, expected_loss, {1461          "accuracy": 1.,1462          "loss": expected_loss,1463      }, model_fn_ops)1464  def testBinarySVMWithLabelName(self):1465    label_name = "my_label"1466    head = head_lib.binary_svm_head(label_name=label_name)1467    with ops.Graph().as_default(), session.Session():1468      model_fn_ops = head.create_model_fn_ops(1469          {},1470          model_fn.ModeKeys.TRAIN,1471          {label_name: self._labels},1472          head_lib.no_op_train_fn,1473          logits=self._predictions)1474      self._assert_output_alternatives(model_fn_ops)1475      _assert_no_variables(self)1476      _assert_summary_tags(self, ["loss"])1477      expected_loss = 
np.average(self._expected_losses)1478      _assert_metrics(self, expected_loss, {1479          "accuracy": 1.,1480          "loss": expected_loss,1481      }, model_fn_ops)1482  def testBinarySVMWith1DWeights(self):1483    head = head_lib.binary_svm_head(weight_column_name="weights")1484    with ops.Graph().as_default(), session.Session():1485      weights = (7., 11.)1486      model_fn_ops = head.create_model_fn_ops(1487          # We have to add an extra dim here for weights broadcasting to work.1488          features={"weights": weights},1489          mode=model_fn.ModeKeys.TRAIN,1490          labels=self._labels,1491          train_op_fn=head_lib.no_op_train_fn,1492          logits=self._predictions)1493      self._assert_output_alternatives(model_fn_ops)1494      _assert_no_variables(self)1495      _assert_summary_tags(self, ["loss"])1496      expected_weighted_losses = np.multiply(weights, self._expected_losses)1497      _assert_metrics(self, np.mean(expected_weighted_losses), {1498          "accuracy": 1.,1499          "loss": np.sum(expected_weighted_losses) / np.sum(weights),1500      }, model_fn_ops)1501  def testBinarySVMWith2DWeights(self):1502    head = head_lib.binary_svm_head(weight_column_name="weights")1503    with ops.Graph().as_default(), session.Session():1504      weights = (7., 11.)1505      model_fn_ops = head.create_model_fn_ops(1506          # We have to add an extra dim here for weights broadcasting to work.1507          features={"weights": tuple([(w,) for w in weights])},1508          mode=model_fn.ModeKeys.TRAIN,1509          labels=self._labels,1510          train_op_fn=head_lib.no_op_train_fn,1511          logits=self._predictions)1512      self._assert_output_alternatives(model_fn_ops)1513      _assert_no_variables(self)1514      _assert_summary_tags(self, ["loss"])1515      expected_weighted_losses = np.multiply(weights, self._expected_losses)1516      _assert_metrics(self, np.mean(expected_weighted_losses), {1517          
"accuracy": 1.,1518          "loss": np.sum(expected_weighted_losses) / np.sum(weights),1519      }, model_fn_ops)1520  def testBinarySVMWithCenteredBias(self):1521    head = head_lib.binary_svm_head(enable_centered_bias=True)1522    with ops.Graph().as_default(), session.Session():1523      model_fn_ops = head.create_model_fn_ops(1524          {},1525          model_fn.ModeKeys.TRAIN,1526          self._labels,1527          head_lib.no_op_train_fn,1528          logits=self._predictions)1529      self._assert_output_alternatives(model_fn_ops)1530      _assert_variables(1531          self,1532          expected_global=(1533              "binary_svm_head/centered_bias_weight:0",1534              ("binary_svm_head/binary_svm_head/centered_bias_weight/"1535               "Adagrad:0"),1536          ),1537          expected_trainable=("binary_svm_head/centered_bias_weight:0",))1538      variables.global_variables_initializer().run()1539      _assert_summary_tags(self, [1540          "loss",1541          "binary_svm_head/centered_bias/bias_0"1542      ])1543      expected_loss = np.average(self._expected_losses)1544      _assert_metrics(self, expected_loss, {1545          "accuracy": 1.,1546          "loss": expected_loss,1547      }, model_fn_ops)1548class LossOnlyHead(test.TestCase):1549  def testNoPredictionsAndNoMetrics(self):1550    head = head_lib.loss_only_head(lambda: 1, head_name="const")1551    model_fn_ops = head.create_model_fn_ops(1552        features={},1553        mode=model_fn.ModeKeys.TRAIN,1554        train_op_fn=head_lib.no_op_train_fn)1555    self.assertDictEqual(model_fn_ops.predictions, {})1556    self.assertDictEqual(model_fn_ops.eval_metric_ops, {})1557    self.assertIsNotNone(model_fn_ops.loss)1558    with session.Session() as sess:1559      self.assertEqual(1, sess.run(model_fn_ops.loss))1560class MultiHeadTest(test.TestCase):1561  def testInvalidHeads(self):1562    named_head = head_lib.multi_class_head(1563        n_classes=3, 
label_name="label", head_name="head1")1564    unnamed_head = head_lib.multi_class_head(1565        n_classes=4, label_name="label")1566    with self.assertRaisesRegexp(ValueError, "must have names"):1567      head_lib.multi_head((named_head, unnamed_head))1568  def testTrainWithNoneTrainOpFn(self):1569    head1 = head_lib.multi_class_head(1570        n_classes=3, label_name="label1", head_name="head1")1571    head2 = head_lib.multi_class_head(1572        n_classes=4, label_name="label2", head_name="head2")1573    head = head_lib.multi_head((head1, head2))1574    labels = {1575        "label1": (1,),1576        "label2": (1,)1577    }1578    with self.assertRaisesRegexp(1579        ValueError, "train_op_fn can not be None in TRAIN mode"):1580      head.create_model_fn_ops(1581          features={"weights": (2.0, 10.0)},1582          labels=labels,1583          mode=model_fn.ModeKeys.TRAIN,1584          train_op_fn=None,1585          logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))1586  def testTrain_withNoHeadWeights(self):1587    head1 = head_lib.multi_class_head(1588        n_classes=3, label_name="label1", head_name="head1")1589    head2 = head_lib.multi_class_head(1590        n_classes=4, label_name="label2", head_name="head2")1591    head3 = head_lib.loss_only_head(lambda: 1.0, head_name="const")1592    head = head_lib.multi_head((head1, head2, head3))1593    labels = {1594        "label1": (1,),1595        "label2": (1,)1596    }1597    model_fn_ops = head.create_model_fn_ops(1598        features={"weights": (2.0, 10.0)},1599        labels=labels,1600        mode=model_fn.ModeKeys.TRAIN,1601        train_op_fn=head_lib.no_op_train_fn,1602        logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))1603    self.assertIsNone(model_fn_ops.predictions)1604    self.assertIsNotNone(model_fn_ops.loss)1605    self.assertIsNotNone(model_fn_ops.train_op)1606    self.assertTrue(model_fn_ops.eval_metric_ops)1607    self.assertIsNone(model_fn_ops.output_alternatives)1608    with 
session.Session() as sess:1609      self.assertAlmostEqual(3.224, sess.run(model_fn_ops.loss), places=3)1610  def testTrain_withHeadWeights(self):1611    head1 = head_lib.multi_class_head(1612        n_classes=3, label_name="label1", head_name="head1")1613    head2 = head_lib.multi_class_head(1614        n_classes=4, label_name="label2", head_name="head2")1615    head = head_lib.multi_head((head1, head2), (1, .5))1616    labels = {1617        "label1": (1,),1618        "label2": (1,)1619    }1620    model_fn_ops = head.create_model_fn_ops(1621        features={"weights": (2.0, 10.0)},1622        labels=labels,1623        mode=model_fn.ModeKeys.TRAIN,1624        train_op_fn=head_lib.no_op_train_fn,1625        logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))1626    self.assertIsNone(model_fn_ops.predictions)1627    self.assertIsNotNone(model_fn_ops.loss)1628    self.assertIsNotNone(model_fn_ops.train_op)1629    self.assertTrue(model_fn_ops.eval_metric_ops)1630    self.assertIsNone(model_fn_ops.output_alternatives)1631    with session.Session() as sess:1632      self.assertAlmostEqual(1.531, sess.run(model_fn_ops.loss), places=3)1633  def testTrain_withDictLogits(self):1634    head1 = head_lib.multi_class_head(1635        n_classes=3, label_name="label1", head_name="head1")1636    head2 = head_lib.multi_class_head(1637        n_classes=4, label_name="label2", head_name="head2")1638    head = head_lib.multi_head((head1, head2))1639    labels = {1640        "label1": (1,),1641        "label2": (1,)1642    }1643    model_fn_ops = head.create_model_fn_ops(1644        features={"weights": (2.0, 10.0)},1645        labels=labels,1646        mode=model_fn.ModeKeys.TRAIN,1647        train_op_fn=head_lib.no_op_train_fn,1648        logits={head1.head_name: ((-0.7, 0.2, .1),),1649                head2.head_name: ((.1, .1, .1, .1),)})1650    self.assertIsNone(model_fn_ops.predictions)1651    self.assertIsNotNone(model_fn_ops.loss)1652    self.assertIsNotNone(model_fn_ops.train_op)1653   
 self.assertTrue(model_fn_ops.eval_metric_ops)1654    self.assertIsNone(model_fn_ops.output_alternatives)1655    with session.Session() as sess:1656      self.assertAlmostEqual(2.224, sess.run(model_fn_ops.loss), places=3)1657  def testInfer(self):1658    head1 = head_lib.multi_class_head(1659        n_classes=3, label_name="label1", head_name="head1")1660    head2 = head_lib.multi_class_head(1661        n_classes=4, label_name="label2", head_name="head2")1662    head = head_lib.multi_head((head1, head2), (1, .5))1663    labels = {1664        "label1": (1,),1665        "label2": (1,)1666    }1667    model_fn_ops = head.create_model_fn_ops(1668        features={"weights": (2.0, 10.0)},1669        labels=labels,1670        mode=model_fn.ModeKeys.INFER,1671        train_op_fn=head_lib.no_op_train_fn,1672        logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))1673    self.assertIsNotNone(model_fn_ops.predictions)1674    self.assertIsNone(model_fn_ops.loss)1675    self.assertIsNone(model_fn_ops.train_op)1676    self.assertFalse(model_fn_ops.eval_metric_ops)1677    # Tests predictions keys.1678    self.assertItemsEqual((1679        ("head1", prediction_key.PredictionKey.LOGITS),1680        ("head1", prediction_key.PredictionKey.PROBABILITIES),1681        ("head1", prediction_key.PredictionKey.CLASSES),1682        ("head2", prediction_key.PredictionKey.LOGITS),1683        ("head2", prediction_key.PredictionKey.PROBABILITIES),1684        ("head2", prediction_key.PredictionKey.CLASSES),1685    ), model_fn_ops.predictions.keys())1686    # Tests output alternative.1687    self.assertEquals({1688        "head1": constants.ProblemType.CLASSIFICATION,1689        "head2": constants.ProblemType.CLASSIFICATION,1690    }, {1691        k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)1692    })1693    self.assertItemsEqual((1694        prediction_key.PredictionKey.PROBABILITIES,1695        prediction_key.PredictionKey.CLASSES,1696    ), 
model_fn_ops.output_alternatives["head1"][1].keys())1697    self.assertItemsEqual((1698        prediction_key.PredictionKey.PROBABILITIES,1699        prediction_key.PredictionKey.CLASSES,1700    ), model_fn_ops.output_alternatives["head2"][1].keys())1701  def testEval(self):1702    head1 = head_lib.multi_class_head(1703        n_classes=3, label_name="label1", head_name="head1")1704    head2 = head_lib.multi_class_head(1705        n_classes=4, label_name="label2", head_name="head2")1706    head = head_lib.multi_head((head1, head2), (1, .5))1707    labels = {1708        "label1": (1,),1709        "label2": (1,)1710    }1711    model_fn_ops = head.create_model_fn_ops(1712        features={"weights": (2.0, 10.0)},1713        labels=labels,1714        mode=model_fn.ModeKeys.EVAL,1715        train_op_fn=head_lib.no_op_train_fn,1716        logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))1717    self.assertIsNotNone(model_fn_ops.predictions)1718    self.assertIsNotNone(model_fn_ops.loss)1719    self.assertIsNone(model_fn_ops.train_op)1720    self.assertIsNotNone(model_fn_ops.eval_metric_ops)...multi_head.py
Source:multi_head.py  
...11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.12# See the License for the specific language governing permissions and13# limitations under the License.14# ==============================================================================15"""Abstractions for the head(s) of a model.16"""17from __future__ import absolute_import18from __future__ import division19from __future__ import print_function20import six21from tensorflow.python.estimator import model_fn22from tensorflow.python.estimator.canned import head as head_lib23from tensorflow.python.estimator.canned import metric_keys24from tensorflow.python.estimator.export import export_output as export_output_lib25from tensorflow.python.framework import ops26from tensorflow.python.ops import array_ops27from tensorflow.python.ops import control_flow_ops28from tensorflow.python.ops import math_ops29from tensorflow.python.ops import metrics as metrics_lib30from tensorflow.python.saved_model import signature_constants31from tensorflow.python.summary import summary32from tensorflow.python.training import training_util33_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY34def multi_head(heads, head_weights=None):35  """Creates a `_Head` for multi-objective learning.36  This class merges the output of multiple `_Head` objects.37  Specifically:38  * For training, sums losses of each head, calls `train_op_fn` with this39    final loss.40  * For eval, merges metrics by adding `head.name` suffix to the keys in eval41    metrics, such as `precision/head1`, `precision/head2`.42  * For prediction, merges predictions and updates keys in prediction dict to a43    2-tuple, `(head.name, prediction_key)`. 
Merges `export_outputs` such that44    by default the first head is served.45  Usage:46  ```python47  # In `input_fn` specify labels as a dict keyed by head name:48  def input_fn():49    features = ...50    labels1 = ...51    labels2 = ...52    return features, {'head1': labels1, 'head2': labels2}53  # In `model_fn`, specify logits as a dict keyed by head name:54  def model_fn(features, labels, mode):55    # Create simple heads and specify head name.56    head1 = multi_class_head(n_classes=3, name='head1')57    head2 = binary_classification_head(name='head2')58    # Create multi-head from two simple heads.59    head = multi_head([head1, head2])60    # Create logits for each head, and combine them into a dict.61    logits1, logits2 = logit_fn()62    logits = {'head1': logits1, 'head2': logits2}63    # Return the merged EstimatorSpec64    return head.create_estimator_spec(..., logits=logits, ...)65  # Create an estimator with this model_fn.66  estimator = tf.estimator.Estimator(model_fn=model_fn)67  estimator.train(input_fn=input_fn, steps=100)68  ```69  Also supports `logits` as a `Tensor` of shape70  `[D0, D1, ... DN, logits_dimension]`. It will split the `Tensor` along the71  last dimension and distribute it appropriately among the heads. E.g.:72  ```python73  def model_fn(features, labels, mode):74    # Create simple heads and specify head name.75    head1 = multi_class_head(n_classes=3, name='head1')76    head2 = binary_classification_head(name='head2')77    # Create multi-head from two simple heads.78    head = multi_head([head1, head2])79    # Create logits for the multihead.80    logits = logit_fn(logits_dimension=head.logits_dimension)81    # Return the merged EstimatorSpec82    return head.create_estimator_spec(..., logits=logits, ...)83  ```84  Args:85    heads: List or tuple of `_Head` instances. All heads must have `name`86      specified. 
The first head in the list is the default used at serving time.87    head_weights: Optional list of weights, same length as `heads`. Used when88      merging losses to calculate the weighted sum of losses from each head. If89      `None`, all losses are weighted equally.90  Returns:91    A instance of `_Head` that merges multiple heads.92  Raises:...net_HBS_solo_resnet18_score_fuse_FFM4_c2_nonlocal_ly24_mout_rnn_ly1_h1_ly4321.py
Source:net_HBS_solo_resnet18_score_fuse_FFM4_c2_nonlocal_ly24_mout_rnn_ly1_h1_ly4321.py  
1import torch2import torch.nn as nn3import torch.nn.functional as F4from net.resnet import resnet34, resnet185import copy6from torch.nn import init7from IPython import embed8from auxi.module import FFM_v4, RNNModule9def weights_init_kaiming(m):10    classname = m.__class__.__name__11    if classname.find('Conv') != -1:12        nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')13    elif classname.find('Linear') != -1:14        nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')15        nn.init.constant_(m.bias.data, 0.0)16    elif classname.find('BatchNorm1d') != -1:17        nn.init.normal_(m.weight.data, 1.0, 0.02)18        nn.init.constant_(m.bias.data, 0.0)19class net(nn.Module):20    def __init__(self, config):21        super(net, self).__init__()22        self.net_head = resnet18(pretrained=True, num_classes=26)23        self.head_layer1 = nn.Sequential(*list(self.net_head.children())[:5])24        self.head_layer2 = list(self.net_head.children())[5]25        self.head_layer3 = list(self.net_head.children())[6]26        self.head_layer4 = list(self.net_head.children())[7]27        self.head_cls = list(self.net_head.children())[-1]28        self.net_body = resnet18(pretrained=True, num_classes=26)29        self.body_layer1 = nn.Sequential(*list(self.net_body.children())[:5])30        self.body_layer2 = list(self.net_body.children())[5]31        self.body_layer3 = list(self.net_body.children())[6]32        self.body_layer4 = list(self.net_body.children())[7]33        self.body_cls = list(self.net_body.children())[-1]34        self.net_scene = resnet18(pretrained=True, num_classes=26)35        self.scene_layer1 = nn.Sequential(*list(self.net_scene.children())[:5])36        self.scene_layer2 = list(self.net_scene.children())[5]37        self.scene_layer3 = list(self.net_scene.children())[6]38        self.scene_layer4 = list(self.net_scene.children())[7]39        self.scene_cls = list(self.net_scene.children())[-1]40        self.fc3 = 
nn.Sequential(nn.Linear(512*3 , 128), nn.Linear(128, 26))41        self.csr_ly1 = RNNModule(input_size=64, hidden_size=64, num_layers=1)42        self.csr_ly2 = RNNModule(input_size=128, hidden_size=128, num_layers=1)43        self.csr_ly3 = RNNModule(input_size=256, hidden_size=256, num_layers=1)44        self.csr_ly4 = RNNModule(input_size=512, hidden_size=512, num_layers=1)45        self.ffm_head_ly2 = FFM_v4(dimension=2, in_channel=128, inter_channel=128*2)46        self.ffm_body_ly2 = FFM_v4(dimension=2, in_channel=128, inter_channel=128*2)47        self.ffm_scene_ly2 = FFM_v4(dimension=2, in_channel=128, inter_channel=128*2)48        self.ffm_head_ly4 = FFM_v4(dimension=2, in_channel=512, inter_channel=512*2)49        self.ffm_body_ly4 = FFM_v4(dimension=2, in_channel=512, inter_channel=512*2)50        self.ffm_scene_ly4 = FFM_v4(dimension=2, in_channel=512, inter_channel=512*2)51    def forward(self, data, mode):52        head_ly1 = self.head_layer1(data['image_head'])53        body_ly1 = self.body_layer1(data['image_body'])54        scene_ly1 = self.scene_layer1(data['image_scene'])55        head_ly1_ = F.adaptive_avg_pool2d(head_ly1, (1,1)).view(head_ly1.size(0), -1)56        body_ly1_ = F.adaptive_avg_pool2d(body_ly1, (1,1)).view(body_ly1.size(0), -1)57        scene_ly1_ = F.adaptive_avg_pool2d(scene_ly1, (1,1)).view(scene_ly1.size(0), -1)58        feature = self.csr_ly1(torch.stack((head_ly1_, body_ly1_, scene_ly1_), dim=1))59        head_ly1 = head_ly1+feature[:,0,:].view(head_ly1.size(0),head_ly1.size(1),1,1).expand_as(head_ly1)60        body_ly1 = body_ly1+feature[:,1,:].view(body_ly1.size(0),body_ly1.size(1),1,1).expand_as(body_ly1)61        scene_ly1 = scene_ly1+feature[:,2,:].view(scene_ly1.size(0),scene_ly1.size(1),1,1).expand_as(scene_ly1)62        head_ly2 = self.head_layer2(head_ly1)63        body_ly2 = self.body_layer2(body_ly1)64        scene_ly2 = self.scene_layer2(scene_ly1)65        head_ly2_ = F.adaptive_avg_pool2d(head_ly2, 
(1,1)).view(head_ly2.size(0), -1)66        body_ly2_ = F.adaptive_avg_pool2d(body_ly2, (1,1)).view(body_ly2.size(0), -1)67        scene_ly2_ = F.adaptive_avg_pool2d(scene_ly2, (1,1)).view(scene_ly2.size(0), -1)68        feature = self.csr_ly2(torch.stack((head_ly2_, body_ly2_, scene_ly2_), dim=1))69        head_ly2 = head_ly2+feature[:,0,:].view(head_ly2.size(0),head_ly2.size(1),1,1).expand_as(head_ly2)70        body_ly2 = body_ly2+feature[:,1,:].view(body_ly2.size(0),body_ly2.size(1),1,1).expand_as(body_ly2)71        scene_ly2 = scene_ly2+feature[:,2,:].view(scene_ly2.size(0),scene_ly2.size(1),1,1).expand_as(scene_ly2)72        head_ly2 = self.ffm_head_ly2(head_ly2, body_ly2, scene_ly2) + head_ly273        body_ly2 = self.ffm_body_ly2(body_ly2, head_ly2, scene_ly2) + body_ly274        scene_ly2 = self.ffm_scene_ly2(scene_ly2, head_ly2, body_ly2) + scene_ly275        head_ly3 = self.head_layer3(head_ly2)76        body_ly3 = self.body_layer3(body_ly2)77        scene_ly3 = self.scene_layer3(scene_ly2)78        head_ly3_ = F.adaptive_avg_pool2d(head_ly3, (1,1)).view(head_ly3.size(0), -1)79        body_ly3_ = F.adaptive_avg_pool2d(body_ly3, (1,1)).view(body_ly3.size(0), -1)80        scene_ly3_ = F.adaptive_avg_pool2d(scene_ly3, (1,1)).view(scene_ly3.size(0), -1)81        feature = self.csr_ly3(torch.stack((head_ly3_, body_ly3_, scene_ly3_), dim=1))82        head_ly3 = head_ly3+feature[:,0,:].view(head_ly3.size(0),head_ly3.size(1),1,1).expand_as(head_ly3)83        body_ly3 = body_ly3+feature[:,1,:].view(body_ly3.size(0),body_ly3.size(1),1,1).expand_as(body_ly3)84        scene_ly3 = scene_ly3+feature[:,2,:].view(scene_ly3.size(0),scene_ly3.size(1),1,1).expand_as(scene_ly3)85        head_ly4 = self.head_layer4(head_ly3)86        body_ly4 = self.body_layer4(body_ly3)87        scene_ly4 = self.scene_layer4(scene_ly3)88        head_ly4_ = F.adaptive_avg_pool2d(head_ly4, (1,1)).view(head_ly4.size(0), -1)89        body_ly4_ = F.adaptive_avg_pool2d(body_ly4, 
(1,1)).view(body_ly4.size(0), -1)90        scene_ly4_ = F.adaptive_avg_pool2d(scene_ly4, (1,1)).view(scene_ly4.size(0), -1)91        feature = self.csr_ly4(torch.stack((head_ly4_, body_ly4_, scene_ly4_), dim=1))92        head_ly4 = head_ly4+feature[:,0,:].view(head_ly4.size(0),head_ly4.size(1),1,1).expand_as(head_ly4)93        body_ly4 = body_ly4+feature[:,1,:].view(body_ly4.size(0),body_ly4.size(1),1,1).expand_as(body_ly4)94        scene_ly4 = scene_ly4+feature[:,2,:].view(scene_ly4.size(0),scene_ly4.size(1),1,1).expand_as(scene_ly4)95        head_ly4 = self.ffm_head_ly4(head_ly4, body_ly4, scene_ly4) + head_ly496        body_ly4 = self.ffm_body_ly4(body_ly4, head_ly4, scene_ly4) + body_ly497        scene_ly4 = self.ffm_scene_ly4(scene_ly4, head_ly4, body_ly4) + scene_ly498        head_ly4 = F.adaptive_avg_pool2d(head_ly4, (1,1)).view(head_ly4.size(0), -1)99        body_ly4 = F.adaptive_avg_pool2d(body_ly4, (1,1)).view(body_ly4.size(0), -1)100        scene_ly4 = F.adaptive_avg_pool2d(scene_ly4, (1,1)).view(scene_ly4.size(0), -1)101        out_head = self.head_cls(head_ly4)102        out_body = self.body_cls(body_ly4)103        out_scene = self.scene_cls(scene_ly4)104        out = self.fc3(torch.cat((head_ly4, body_ly4, scene_ly4), dim=1))...head_tail.py
Source:head_tail.py  
"""Utilities to manage head and tail of elements.

The scope is to avoid losing part of the original text in the final tree.
"""
from .tree import Item


class TokenValue:
    """Value carried by a lexer token, together with its surrounding text.

    ``head`` and ``tail`` hold text found before / after the token (separator
    text is appended to ``tail`` by :py:class:`HeadTailLexer`), while ``pos``
    and ``size`` locate the token in the original string.
    """

    def __init__(self, value):
        self.value = value
        self.pos = None
        self.size = None
        self.head = ""
        self.tail = ""

    def __repr__(self):
        return "TokenValue(%s)" % self.value

    def __str__(self):
        return self.value


class HeadTailLexer:
    """Utility to handle head and tail at lexer time."""

    LEXER_ATTR = "_luqum_headtail"

    @classmethod
    def handle(cls, token, orig_value):
        """Handle a token.

        .. note::
          PLY does not give access to previous tokens, nor does it provide any
          infrastructure for keeping state between tokens. So we use the
          strategy of putting a :py:class:`HeadTailLexer` instance as an
          attribute of the lexer each time we start a new tokenization.

        :param token: the token being produced by the lexer
        :param orig_value: the original matched text, whose length becomes the
            token value's ``size``
        """
        instance = getattr(token.lexer, cls.LEXER_ATTR, None)
        # A token at position 0 starts a new tokenization: reset the state.
        # Also create the instance if it is missing (e.g. first token not at 0),
        # instead of failing with AttributeError.
        if token.lexpos == 0 or instance is None:
            instance = cls()
            setattr(token.lexer, cls.LEXER_ATTR, instance)
        instance.handle_token(token, orig_value)

    def __init__(self):
        # head of the next element, useful only for the first element
        self.head = None
        # last non-separator token, so we can add the tail to it
        self.last_elt = None

    def handle_token(self, token, orig_value):
        """Handle head and tail for tokens.

        The scope is to avoid losing part of the original text and keep it in
        elements: a SEPARATOR token becomes the tail of the previous token, or
        the head of the next one when no token precedes it. Token values that
        are :py:class:`Item` or :py:class:`TokenValue` also get ``pos``/``size``.
        """
        if token.type == "SEPARATOR":
            if token.lexpos == 0:
                # spaces at expression start, head for next token
                self.head = token.value
            elif self.last_elt is not None:
                # tail of last processed token
                self.last_elt.value.tail += token.value
            else:
                # no previous token to attach to: keep the text as head of the
                # next token rather than silently dropping it
                self.head = (self.head or "") + token.value
        else:
            # if there is a pending head, apply it
            if self.head is not None:
                token.value.head = self.head
                self.head = None
            # keep track of token, to apply tail later
            self.last_elt = token
        # also set pos and size
        if isinstance(token.value, (Item, TokenValue)):
            token.value.pos = token.lexpos
            token.value.size = len(orig_value)


# shortcut to HeadTailLexer.handle
token_headtail = HeadTailLexer.handle


class HeadTailManager:
    """Utility to handle head and tail at expression parse time."""

    def pos(self, p, head_transfer=False, tail_transfer=False):
        """Compute pos and size of element 0 based on its parts (p[1:]).

        :param list p: the parser expression as in PLY
        :param bool head_transfer: True if head of first child will be transferred to p[0]
        :param bool tail_transfer: True if tail of last child will be transferred to p[0]
        """
        # pos
        if p[1].pos is not None:
            p[0].pos = p[1].pos
            if not head_transfer:
                # head isn't transferred, so we are before it
                # (guarded like the size computation below, in case head is None)
                p[0].pos -= len(p[1].head or "")
        # size
        p[0].size = sum(
            (elt.size or 0) + len(elt.head or "") + len(elt.tail or "") for elt in p[1:])
        if head_transfer and p[1].head:
            # we accounted head in size, remove it
            p[0].size -= len(p[1].head)
        last_p = p[len(p) - 1]  # negative indexing not supported by PLY
        if tail_transfer and last_p.tail:
            # we accounted tail in size, remove it
            p[0].size -= len(last_p.tail)

    def binary_operation(self, p, op_tail):
        """Compute pos and size of a binary operation.

        The length of ``op_tail`` is removed from the computed size.
        """
        self.pos(p, head_transfer=False, tail_transfer=False)
        # correct size
        p[0].size -= len(op_tail)

    def simple_term(self, p):
        """Transfer head and tail of the single child p[1] to p[0]."""
        self.pos(p, head_transfer=True, tail_transfer=True)
        p[0].head = p[1].head
        p[0].tail = p[1].tail

    def unary(self, p):
        """OP expr"""
        self.pos(p, head_transfer=True, tail_transfer=False)
        p[0].head = p[1].head
        p[2].head = p[1].tail + p[2].head

    def post_unary(self, p):
        """expr OP"""
        self.pos(p, head_transfer=False, tail_transfer=True)
        p[1].tail += p[2].head
        p[0].tail = p[2].tail

    def paren(self, p):
        """( expr )"""
        self.pos(p, head_transfer=True, tail_transfer=True)
        # p[0] is global element (Group or FieldGroup)
        # p[2] is content
        # p[1] is left parenthesis
        p[0].head = p[1].head
        p[2].head = p[1].tail + p[2].head
        # p[3] is right parenthesis
        p[2].tail += p[3].head
        p[0].tail = p[3].tail

    def range(self, p):
        """[ expr TO expr ]"""
        self.pos(p, head_transfer=True, tail_transfer=True)
        # p[0] is global element (Range)
        # p[2] is lower bound
        p[0].head = p[1].head
        p[2].head = p[1].tail + p[2].head
        # p[3] is TO
        # p[4] is upper bound
        p[2].tail += p[3].head
        p[4].head = p[3].tail + p[4].head
        # p[5] is upper bracket
        p[4].tail += p[5].head
        p[0].tail = p[5].tail

    def search_field(self, p):
        """name: expr"""
        self.pos(p, head_transfer=True, tail_transfer=False)
        # p[0] is global element (SearchField)
        # p[1] is search field name
        # p[2] is COLUMN
        p[0].head = p[1].head
        if p[1].tail or p[2].head:
            pass  # FIXME: add warning, or handle space between point and name in SearchField ?
        # p[3] is the expression
        p[3].head = p[2].tail + p[3].head


#: singleton of HeadTailManager
head_tail = HeadTailManager()
Get 100 minutes of automation testing FREE!
