How to use the mktest method in Nose

Best Python code snippets using nose

others.py

Source: others.py (GitHub)

copy

Full Screen

1#!/usr/bin/env python 2# -*- coding: utf-8 -*-3# ==============================================================================4# \file others.py5# \author chenghuige 6# \date 2020-07-22 15:11:13.1362067# \Description 8# ==============================================================================9from __future__ import absolute_import10from __future__ import division11from __future__ import print_function12import sys 13import os14import gezi15import melt as mt16import tensorflow as tf17from absl import flags18FLAGS = flags.FLAGS19from tensorflow import keras20from tensorflow.keras import backend as K21import numpy as np22from projects.feed.rank.src.config import *23from projects.feed.rank.src import util 24from projects.feed.rank.src.history import *25from projects.feed.rank.src import history26from projects.feed.rank.src.keywords import *27import gezi 28logging = gezi.logging29class Others(mt.Model):30 def __init__(self):31 super(Others, self).__init__()32 # self.regularizer = keras.regularizers.l1_l2(l2=FLAGS.l2_reg)33 self.regularizer = None34 Embedding = keras.layers.Embedding35 SimpleEmbedding = mt.layers.SimpleEmbedding36 HashEmbedding, HashEmbeddingUD = util.get_hash_embedding_type()37 kwargs = dict(num_buckets=FLAGS.num_feature_buckets, combiner=FLAGS.hash_combiner, 38 embeddings_regularizer=self.regularizer, num_shards=FLAGS.num_shards)39 self.kwargs = kwargs40 self.HashEmbedding = HashEmbedding41 if FLAGS.use_user_emb:42 self.user_emb = HashEmbeddingUD(int(FLAGS.feature_dict_size * FLAGS.user_emb_factor), FLAGS.other_emb_dim, name='user_emb', **kwargs)43 if FLAGS.use_doc_emb:44 self.doc_emb = HashEmbeddingUD(int(FLAGS.feature_dict_size * FLAGS.doc_emb_factor), FLAGS.other_emb_dim, name='doc_emb', **kwargs)45 self.topic_emb = None46 self.kw_emb = None47 self.mktest_kw_emb = SimpleEmbedding(FLAGS.keyword_dict_size, FLAGS.other_emb_dim, name='mktest_kw_emb')48 # mkyuwen kw49 if FLAGS.use_merge_kw_emb:50 if not FLAGS.use_w2v_kw_emb:51 # self.mktest_kw_emb 
= SimpleEmbedding(FLAGS.keyword_dict_size, FLAGS.other_emb_dim, name='mktest_kw_emb')52 # 这部分统一,下面的好写一些53 self.mktest_user_kw_emb = self.mktest_kw_emb54 self.mktest_doc_kw_emb = self.mktest_kw_emb55 else:56 # mkyuwen 061257 # https://www.cnblogs.com/weiyinfu/p/9873001.html58 ini_w2v_emb_weights, ini_user_w2v_emb_weights, ini_doc_w2v_emb_weights = util.load_pretrained_w2v_emb()59 if not FLAGS.use_split_w2v_kw_emb: # 公用kw-emb60 # self.mktest_kw_emb = keras.layers.Embedding(input_dim=FLAGS.keyword_dict_size + 1, output_dim=FLAGS.other_emb_dim,61 # weights=[ini_w2v_emb_weights.reshape(FLAGS.keyword_dict_size + 1, FLAGS.other_emb_dim)], 62 # name='mktest_0_w2v_kw_emb') # not work63 self.mktest_kw_emb = keras.layers.Embedding(input_dim=FLAGS.keyword_dict_size + 1, output_dim=FLAGS.other_emb_dim,64 embeddings_initializer=keras.initializers.constant(ini_w2v_emb_weights.reshape((FLAGS.keyword_dict_size + 1, FLAGS.other_emb_dim))), 65 # trainable=trainable_,66 name='mktest_uniform_UnD_w2v_kw_emb')67 # 这部分统一,下面的好写一些68 self.mktest_user_kw_emb = self.mktest_kw_emb69 self.mktest_doc_kw_emb = self.mktest_kw_emb70 else: # doc/user分开kw_emb71 self.mktest_user_kw_emb = keras.layers.Embedding(input_dim=FLAGS.keyword_dict_size + 1, output_dim=FLAGS.other_emb_dim,72 embeddings_initializer=keras.initializers.constant(ini_user_w2v_emb_weights.reshape((FLAGS.keyword_dict_size + 1, FLAGS.other_emb_dim))), 73 # trainable=trainable_,74 name='mktest_uniform_User_w2v_kw_emb')75 self.mktest_doc_kw_emb = keras.layers.Embedding(input_dim=FLAGS.keyword_dict_size + 1, output_dim=FLAGS.other_emb_dim,76 embeddings_initializer=keras.initializers.constant(ini_doc_w2v_emb_weights.reshape((FLAGS.keyword_dict_size + 1, FLAGS.other_emb_dim))), 77 # trainable=trainable_,78 name='mktest_uniform_Doc_w2v_kw_emb')79 if FLAGS.use_kw_emb or FLAGS.history_attention:80 # TODO81 # self.kw_emb = SimpleEmbedding(FLAGS.keyword_dict_size, FLAGS.other_emb_dim, name='kw_emb')82 self.kw_emb = self.mktest_kw_emb83 84 if 
FLAGS.use_topic_emb or FLAGS.history_attention:85 # 1000086 self.topic_emb = SimpleEmbedding(FLAGS.topic_dict_size, FLAGS.other_emb_dim, name='topic_emb')87 # self.topic_emb = self.kw_emb88 if FLAGS.use_time_emb:89 self.time_emb = Embedding(500, FLAGS.other_emb_dim, name='time_emb')90 self.weekday_emb = Embedding(10, FLAGS.other_emb_dim, name='weekday_emb')91 if FLAGS.use_timespan_emb:92 self.timespan_emb = Embedding(300, FLAGS.other_emb_dim, name='timespan_emb')93 if FLAGS.use_deep_position_emb:94 self.pos_emb = Embedding(FLAGS.num_positions, FLAGS.other_emb_dim, name='pos_emb')95 if FLAGS.use_product_emb:96 self.product_emb = Embedding(10, FLAGS.other_emb_dim, name='product_emb')97 if FLAGS.use_cold_emb:98 self.cold_emb = Embedding(10, FLAGS.other_emb_dim, name='cold_emb')99 100 if FLAGS.use_title_emb:101 self.title_emb = SimpleEmbedding(100000, FLAGS.other_emb_dim, name='title_emb') if not FLAGS.title_share_kw_emb else self.kw_emb102 if FLAGS.title_encoder in ['gru', 'lstm']:103 return_sequences = FLAGS.title_pooling is not None104 # self.title_encoder = tf.keras.layers.GRU(FLAGS.other_emb_dim, return_sequences=return_sequences, 105 # dropout=FLAGS.title_drop, recurrent_dropout=FLAGS.title_drop_rec)106 self.title_encoder = mt.layers.CudnnRnn(num_layers=1, 107 num_units=int(FLAGS.hidden_size / 2), 108 keep_prob=1.,109 share_dropout=False,110 recurrent_dropout=False,111 concat_layers=True,112 bw_dropout=False,113 residual_connect=False,114 train_init_state=False,115 cell=FLAGS.title_encoder)116 else:117 self.title_encoder = lambda x, y: x118 if FLAGS.title_pooling:119 self.title_pooling = mt.layers.Pooling(FLAGS.title_pooling)120 121 if FLAGS.use_refresh_emb:122 self.refresh_coldstart_emb = Embedding(1001, FLAGS.other_emb_dim, name='refresh_coldstart_emb')123 self.refresh_today_emb = Embedding(1001, FLAGS.other_emb_dim, name='refresh_today_emb')124 # mkyuwen125 if FLAGS.use_distribution_emb: # 0430126 self.distribution_emb = Embedding(1000, FLAGS.other_emb_dim, 
name='distribution_id_emb')127 if FLAGS.use_network_emb:128 self.network_emb = Embedding(10, FLAGS.other_emb_dim, name='network_emb')129 if FLAGS.use_activity_emb:130 self.activity_emb = Embedding(10, FLAGS.other_emb_dim, name='activity_emb') 131 if FLAGS.use_type_emb:132 self.type_emb = Embedding(10, FLAGS.other_emb_dim, name='type_emb') 133 self.pooling = mt.layers.Pooling(FLAGS.pooling)134 self.sum_pooling = mt.layers.Pooling('sum')135 if FLAGS.use_history_emb:136 HistoryEncoder = getattr(history, 'History' + FLAGS.history_strategy)137 self.history_encoder = HistoryEncoder(self.doc_emb, self.topic_emb, self.kw_emb)138 def call(self, input):139 add, adds = self.add, self.adds140 self.clear()141 142 if FLAGS.use_user_emb:143 with mt.device(FLAGS.emb_device):144 x_user = self.user_emb(input['uid'])145 self.x_user = x_user146 add(x_user, 'uid') 147 if FLAGS.use_doc_emb:148 with mt.device(FLAGS.emb_device):149 x_doc = self.doc_emb(input['did'])150 x_doc_ = x_doc151 self.x_doc = x_doc152 add(x_doc, 'did')153 # mkyuwen 0624 mv down[1]154 if FLAGS.use_kw_emb or FLAGS.history_attention:155 doc_kw = input["doc_keyword"] 156 doc_kw_emb = self.pooling(self.kw_emb(doc_kw), mt.length(doc_kw))157 self.doc_kw_emb = doc_kw_emb158 159 if FLAGS.use_topic_emb or FLAGS.history_attention:160 doc_topic_emb = self.topic_emb(input['doc_topic'])161 self.doc_topic_emb = doc_topic_emb162 if FLAGS.use_topic_emb:163 add(doc_topic_emb, 'doc_topic')164 # mkyuwen 0624 mv down[1]165 if FLAGS.use_kw_emb:166 add(doc_kw_emb, 'doc_kw')167 # mkyuwen kw168 # ---------- mkyuwen 0504169 if FLAGS.use_merge_kw_emb:170 mktest_tw_history_kw = input['mktest_tw_history_kw_feed']171 mktest_vd_history_kw = input['mktest_vd_history_kw_feed'] # 0521172 mktest_rel_vd_history_kw = input['mktest_rel_vd_history_kw_feed'] 173 mktest_doc_kw = input['mktest_doc_kw_feed']174 mktest_doc_kw_secondary = input['mktest_doc_kw_secondary_feed']175 mktest_tw_long_term_kw = input['mktest_tw_long_term_kw_feed']176 
mktest_vd_long_term_kw = input['mktest_vd_long_term_kw_feed'] 177 mktest_new_search_kw = input['mktest_new_search_kw_feed']178 mktest_long_search_kw = input['mktest_long_search_kw_feed']179 mktest_user_kw = input['mktest_user_kw_feed'] 180 if FLAGS.use_w2v_kw_emb: # mkyuwen w2v 参考x = (x + 1) * mask, 做emb的时候已经word_index = hash mod 100w +1181 # ------------ case1:只有非0的index+1,0保持不变。与mt.length逻辑保持一致182 mktest_tw_history_kw_mask0 = tf.cast(mktest_tw_history_kw > 0, tf.int64) # >0的非padding部分,为1183 mktest_tw_history_kw = (mktest_tw_history_kw + mktest_tw_history_kw_mask0) # use_w2v_kw_emb, index+1 手动184 mktest_vd_history_kw = (mktest_vd_history_kw + tf.cast(mktest_vd_history_kw > 0, tf.int64)) # 185 mktest_rel_vd_history_kw = (mktest_rel_vd_history_kw + tf.cast(mktest_rel_vd_history_kw > 0, tf.int64)) # 186 mktest_doc_kw = (mktest_doc_kw + tf.cast(mktest_doc_kw > 0, tf.int64)) # 187 mktest_doc_kw_secondary = (mktest_doc_kw_secondary + tf.cast(mktest_doc_kw_secondary > 0, tf.int64)) # 188 mktest_tw_long_term_kw = (mktest_tw_long_term_kw + tf.cast(mktest_tw_long_term_kw > 0, tf.int64)) # 189 mktest_vd_long_term_kw = (mktest_vd_long_term_kw + tf.cast(mktest_vd_long_term_kw > 0, tf.int64)) # 190 mktest_new_search_kw = (mktest_new_search_kw + tf.cast(mktest_new_search_kw > 0, tf.int64)) # 191 mktest_long_search_kw = (mktest_long_search_kw + tf.cast(mktest_long_search_kw > 0, tf.int64)) # 192 mktest_user_kw = (mktest_user_kw + tf.cast(mktest_user_kw > 0, tf.int64)) # 193 194 # mkyuwen 0612 195 # 分开user/doc, 两者内容是否公用一个在上面判断,对这里透明196 # self.sum_pooling = mt.layers.Pooling('sum')197 self.mktest_kw_pooling = mt.layers.Pooling(FLAGS.merge_kw_emb_pooling)198 if FLAGS.use_tw_history_kw_merge_emb: # user199 add(self.mktest_kw_pooling(self.mktest_user_kw_emb(mktest_tw_history_kw), mt.length(mktest_tw_history_kw)), 'mktest_tw_history_kw')200 if FLAGS.use_vd_history_kw_merge_emb: # user201 add(self.mktest_kw_pooling(self.mktest_user_kw_emb(mktest_vd_history_kw), 
mt.length(mktest_vd_history_kw)), 'mktest_vd_history_kw')202 if FLAGS.use_rel_vd_history_kw_merge_emb: # user203 add(self.mktest_kw_pooling(self.mktest_user_kw_emb(mktest_rel_vd_history_kw), mt.length(mktest_rel_vd_history_kw)), 'mktest_vd_history_kw')204 if FLAGS.use_doc_kw_merge_emb: # doc205 if FLAGS.use_kw_merge_score:206 mktest_doc_kw_ = self.mktest_doc_kw_emb(mktest_doc_kw)207 mktest_doc_kw_score = input['mktest_doc_kw_score_feed'] 208 mktest_doc_kw_score = K.expand_dims(mktest_doc_kw_score, -1)209 # print ("mktest check mktest_doc_kw_",mktest_doc_kw_.get_shape())210 # print ("mktest check mktest_doc_kw_score",mktest_doc_kw_score.get_shape())211 add(self.mktest_kw_pooling(mktest_doc_kw_ * mktest_doc_kw_score, mt.length(mktest_doc_kw)), 'mktest_doc_kw')212 else:213 add(self.mktest_kw_pooling(self.mktest_doc_kw_emb(mktest_doc_kw), mt.length(mktest_doc_kw)), 'mktest_doc_kw')214 if FLAGS.use_doc_kw_secondary_merge_emb: # doc215 if FLAGS.use_kw_secondary_merge_score:216 mktest_doc_kw_secondary_ = self.mktest_doc_kw_emb(mktest_doc_kw_secondary)217 mktest_doc_kw_secondary_score = input['mktest_doc_kw_secondary_score_feed'] 218 mktest_doc_kw_secondary_score = K.expand_dims(mktest_doc_kw_secondary_score, -1)219 add(self.mktest_kw_pooling(mktest_doc_kw_secondary_ * mktest_doc_kw_secondary_score, mt.length(mktest_doc_kw_secondary)), 'mktest_doc_kw_secondary')220 else:221 add(self.mktest_kw_pooling(self.mktest_doc_kw_emb(mktest_doc_kw_secondary), mt.length(mktest_doc_kw_secondary)), 'mktest_doc_kw_secondary')222 if FLAGS.use_tw_long_term_kw_merge_emb: # user223 add(self.mktest_kw_pooling(self.mktest_user_kw_emb(mktest_tw_long_term_kw), mt.length(mktest_tw_long_term_kw)), 'mktest_tw_long_term_kw')224 if FLAGS.use_vd_long_term_kw_merge_emb: # user225 add(self.mktest_kw_pooling(self.mktest_user_kw_emb(mktest_vd_long_term_kw), mt.length(mktest_vd_long_term_kw)), 'mktest_vd_long_term_kw')226 if FLAGS.use_new_search_kw_merge_emb: # user227 if 
FLAGS.use_new_search_kw_merge_score:228 mktest_new_search_kw_ = self.mktest_user_kw_emb(mktest_new_search_kw)229 mktest_new_search_kw_score = input['mktest_new_search_kw_score_feed'] 230 mktest_new_search_kw_score = K.expand_dims(mktest_new_search_kw_score, -1)231 add(self.mktest_kw_pooling(mktest_new_search_kw_ * mktest_new_search_kw_score, mt.length(mktest_new_search_kw)), 'mktest_new_search_kw')232 else:233 add(self.mktest_kw_pooling(self.mktest_user_kw_emb(mktest_new_search_kw), mt.length(mktest_new_search_kw)), 'mktest_new_search_kw')234 if FLAGS.use_long_search_kw_merge_emb: # user235 add(self.mktest_kw_pooling(self.mktest_user_kw_emb(mktest_long_search_kw), mt.length(mktest_long_search_kw)), 'mktest_long_search_kw')236 if FLAGS.use_user_kw_merge_emb: # user237 if FLAGS.use_user_kw_merge_score:238 mktest_user_kw_ = self.mktest_user_kw_emb(mktest_user_kw)239 mktest_user_kw_score = input['mktest_user_kw_score_feed'] 240 mktest_user_kw_score = K.expand_dims(mktest_user_kw_score, -1)241 add(self.mktest_kw_pooling(mktest_user_kw_ * mktest_user_kw_score, mt.length(mktest_user_kw)), 'mktest_user_kw')242 else:243 add(self.mktest_kw_pooling(self.mktest_user_kw_emb(mktest_user_kw), mt.length(mktest_user_kw)), 'mktest_user_kw')244 245 # other_embs += [(other_embs[-1] + other_embs[-2] + other_embs[-3]) / 3.]246 # ----------# ----------# ----------# ----------# ----------# ----------247 # # mkyuwen 0624 mv [1] here248 # if FLAGS.use_topic_emb or FLAGS.history_attention:249 # doc_kw = input["doc_keyword"] % FLAGS.keyword_dict_size250 # # doc_kw_emb = self.pooling(self.kw_emb(doc_kw), mt.length(doc_kw))251 # # -------- mkyuwen 0624252 # if FLAGS.use_total_samekw_lbwnmktest and FLAGS.use_merge_kw_emb: # 必须同时满足253 # doc_kw_emb = self.mktest_kw_pooling(self.mktest_doc_kw_emb(doc_kw), mt.length(doc_kw))254 # else:255 # doc_kw_emb = self.pooling(self.kw_emb(doc_kw), mt.length(doc_kw))256 257 # # mkyuwen 0624 mv [1] here258 # if FLAGS.use_kw_emb:259 # other_embs += 
[doc_kw_emb]260 # --------# --------# --------# --------# --------# --------261 if FLAGS.use_history_emb:262 self.history_encoder(input, x_doc, doc_topic_emb, doc_kw_emb)263 self.merge(self.history_encoder.feats)264 if FLAGS.use_time_emb:265 if FLAGS.use_time_so:266 time_module = tf.load_op_library('./ops/time.so')267 get_time_intervals = time_module.time268 else:269 def get_time_intervals(x):270 res = tf.numpy_function(util.get_time_intervals, [x], x.dtype)271 res.set_shape(x.get_shape())272 return res273 time_interval = input['time_interval']274 if FLAGS.time_smoothing:275 x_time = self.time_emb(time_interval)276 num_bins = FLAGS.time_bins_per_hour * 24277 tmask = tf.cast(time_interval > 1, x_time.dtype)278 tbase = time_interval * (1 - tmask)279 time_pre = (time_interval - 2 -1 * FLAGS.time_bins_per_hour) % num_bins + 2 280 time_pre = tbase + time_pre * tmask281 time_pre2 = (time_interval - 2 -2 * FLAGS.time_bins_per_hour) % num_bins + 2282 time_pre2 = tbase + time_pre2 * tmask283 time_after = (time_interval - 2 + 1 * FLAGS.time_bins_per_hour) % num_bins + 2284 time_after = tbase + time_after * tmask285 time_after2 = (time_interval - 2 + 2 * FLAGS.time_bins_per_hour) % num_bins + 2286 time_after2 = tbase + time_after2 * tmask287 x_time_pre = self.time_emb(time_pre)288 x_time_pre2 = self.time_emb(time_pre2)289 x_time_after = self.time_emb(time_after)290 x_time_after2 = self.time_emb(time_after2)291 x_time = (0.4 * x_time + 0.2 * x_time_pre + 0.1 * x_time_pre2 + 0.2 * x_time_after + 0.1 * x_time_after2) / 5.292 # print('x_time2', x_time)293 elif FLAGS.time_bins_per_day:294 num_bins = FLAGS.time_bins_per_hour * 24295 num_large_bins = FLAGS.time_bins_per_day296 intervals_per_large_bin = tf.cast(num_bins / num_large_bins, time_interval.dtype)297 tmask = tf.cast(time_interval > 1, time_interval.dtype)298 tbase = time_interval * (1 - tmask)299 time_interval_large = tf.cast(((time_interval - 2 - FLAGS.time_bin_shift_hours * FLAGS.time_bins_per_hour) % num_bins)/ 
intervals_per_large_bin, time_interval.dtype) + 2300 time_interval_large = tbase + time_interval_large * tmask301 x_time = self.time_emb(time_interval_large)302 else:303 x_time = self.time_emb(time_interval)304 time_weekday = input['time_weekday'] 305 x_weekday = self.weekday_emb(time_weekday)306 307 adds([308 [x_time, 'time'],309 [x_weekday, 'weekday']310 ])311 # if FLAGS.use_dense_feats:312 # # TODO remove dense feats of time as 23 and 00... 313 # s_time = tf.cast(time_interval, tf.float32) / (24 * FLAGS.time_bins_per_hour + 10.)314 # s_time = tf.zeros_like(time_interval)315 # s_time = mt.youtube_scalar_features(s_time)316 # s_weekday = tf.cast(time_weekday, tf.float32) / 10.317 # s_weekday = mt.youtube_scalar_features(s_weekday)318 # dense_feats = [s_time, s_weekday]319 if FLAGS.use_timespan_emb:320 if FLAGS.use_time_so:321 get_timespan_intervals = time_module.timespan322 else:323 def get_timespan_intervals(x, y): 324 res = tf.numpy_function(util.get_timespan_intervals, [x, y], x.dtype)325 res.set_shape(x.get_shape())326 return res327 timespan_interval = input['timespan_interval']328 x_timespan = self.timespan_emb(timespan_interval)329 add(x_timespan, 'timespan')330 # if FLAGS.use_dense_feats:331 # s_timespan = tf.cast(timespan_interval, tf.float32) / 200. 
332 # s_timespan = mt.youtube_scalar_features(s_timespan)333 # # dense_feats += [s_timespan]334 # s_timespan2 = input['impression_time'] - input['article_page_time']335 # max_delta = 3000000336 # s_timespan2 = tf.math.minimum(s_timespan2, max_delta)337 # # s_timespan2 = tf.math.maximum(s_timespan2, -10)338 # s_timespan2 = tf.math.maximum(s_timespan2, 0)339 # s_timespan2 = tf.cast(s_timespan2, tf.float32) / float(max_delta)340 # s_timespan2 = mt.youtube_scalar_features(s_timespan2)341 # dense_feats += [s_timespan2]342 if FLAGS.use_product_emb:343 x_product = self.product_emb(util.get_product_id(input['product']))344 add(x_product, 'product')345 if FLAGS.use_cold_emb:346 cold = input['cold'] if not FLAGS.is_infer else tf.cast(util.is_cb_user(input['rea']), input['index'].dtype)347 x_cold = self.cold_emb(cold) 348 add(x_cold, 'cold')349 if FLAGS.use_title_emb:350 x_title = self.title_emb(input['title'])351 # x_title = self.title_encoder(x_title, mask=tf.sequence_mask(mt.length(input['title'])))352 x_title = self.title_encoder(x_title, mt.length(input['title']))353 x_title = self.title_pooling(x_title, mt.length(input['title']))354 add(x_title, 'title')355 if FLAGS.use_refresh_emb:356 x_refresh1 = self.refresh_coldstart_emb(tf.math.minimum(input['coldstart_refresh_num'], 1000))357 x_refresh2 = self.refresh_today_emb(tf.math.minimum(input['today_refresh_num'], 1000))358 adds([359 [x_refresh1, 'coldstart_refresh'],360 [x_refresh2, 'today_refresh']361 ])362 # mkyuwen 0520(本身input自带feed)363 if FLAGS.use_distribution_emb: # 0430364 input['mktest_distribution_id_feed'] = input['mktest_distribution_id_feed'] % 1000365 x_disid = self.distribution_emb(input['mktest_distribution_id_feed'])366 add(x_disid, 'distribution')367 if FLAGS.use_network_emb:368 x_network = self.network_emb(input['network'])369 add(x_network, 'network')370 if FLAGS.use_activity_emb:371 x_activity = self.activity_emb(input['user_active'] + 1)372 add(x_activity, 'user_active')373 if FLAGS.use_type_emb:374 
x_type = self.type_emb(input['type'])375 add(x_type, 'type')376 other_embs = self.embs377 other_embs = [x if len(x.get_shape()) == 2 else tf.squeeze(x, 1) for x in other_embs]378 return other_embs379 def init_predict(self, input, dummy):380 # if FLAGS.use_user_emb:381 input['uid'] = tf.compat.v1.placeholder_with_default(tf.constant([[0]], dtype=tf.int64), [None, 1], 'uid_feed')382 tf.compat.v1.add_to_collection('uid_feed', input['uid'])383 input['uid'] += dummy384 # if FLAGS.use_doc_emb:385 input['did'] = tf.compat.v1.placeholder_with_default(tf.constant([[0]], dtype=tf.int64), [None, 1], 'did_feed')386 tf.compat.v1.add_to_collection('did_feed', input['did'])387 input['did'] += dummy388 if FLAGS.use_title_emb:389 title_feed = tf.compat.v1.placeholder_with_default(tf.constant([[0]], dtype=tf.int64), [None, None], 'title_feed')390 tf.compat.v1.add_to_collection('title_feed', title_feed)391 input['title'] = title_feed392 input['history'] = tf.compat.v1.placeholder_with_default(tf.constant([[0]], dtype=tf.int64), [None, None], 'doc_idx_feed')393 tf.compat.v1.add_to_collection('doc_idx_feed', input['history'])394 input['history'] += dummy395 input['keyword'] = tf.compat.v1.placeholder_with_default(tf.constant([[0]], dtype=tf.int64), [None, None], 'kw_idx_feed')396 tf.compat.v1.add_to_collection('kw_idx_feed', input['keyword'])397 input['keyword'] += dummy398 input['topic'] = tf.compat.v1.placeholder_with_default(tf.constant([[0]], dtype=tf.int64), [None, None], 'topic_idx_feed')399 tf.compat.v1.add_to_collection('topic_idx_feed', input['topic'])400 input['topic'] += dummy401 input['doc_keyword'] = tf.compat.v1.placeholder_with_default(tf.constant([[0]], dtype=tf.int64), [None, None],'doc_kw_idx_feed')402 tf.compat.v1.add_to_collection('doc_kw_idx_feed', input['doc_keyword'])403 input['doc_keyword'] += dummy404 input['doc_topic'] = tf.compat.v1.placeholder_with_default(tf.constant([[0]], dtype=tf.int64), [None, 1], 'doc_topic_idx_feed')405 
tf.compat.v1.add_to_collection('doc_topic_idx_feed', input['doc_topic'])406 input['doc_topic'] += dummy407 input['impression_time'] = tf.compat.v1.placeholder_with_default(tf.constant([[0]], dtype=tf.int64), [None, 1], 'time_feed')408 tf.compat.v1.add_to_collection('time_feed', input['impression_time'])409 input['impression_time'] += dummy410 input['impression_time'] = tf.squeeze(input['impression_time'], 1)411 input['article_page_time'] = tf.compat.v1.placeholder_with_default(tf.constant([[0]], dtype=tf.int64), [None, 1], 'ptime_feed')412 tf.compat.v1.add_to_collection('ptime_feed', input['article_page_time'])413 input['article_page_time'] += dummy414 input['article_page_time'] = tf.squeeze(input['article_page_time'], 1)415 # mkyuwen416 input['mktest_tw_history_kw_feed'] = tf.compat.v1.placeholder_with_default(tf.constant([[0]], dtype=tf.int64), [None, None], 'mktest_tw_history_kw_feed')417 tf.compat.v1.add_to_collection('mktest_tw_history_kw_feed', input['mktest_tw_history_kw_feed'])418 input['mktest_tw_history_kw_feed'] += dummy419 input['mktest_vd_history_kw_feed'] = tf.compat.v1.placeholder_with_default(tf.constant([[0]], dtype=tf.int64), [None, None], 'mktest_vd_history_kw_feed')420 tf.compat.v1.add_to_collection('mktest_vd_history_kw_feed', input['mktest_vd_history_kw_feed'])421 input['mktest_vd_history_kw_feed'] += dummy422 input['mktest_rel_vd_history_kw_feed'] = tf.compat.v1.placeholder_with_default(tf.constant([[0]], dtype=tf.int64), [None, None], 'mktest_rel_vd_history_kw_feed')423 tf.compat.v1.add_to_collection('mktest_rel_vd_history_kw_feed', input['mktest_rel_vd_history_kw_feed'])424 input['mktest_rel_vd_history_kw_feed'] += dummy425 input['mktest_doc_kw_feed'] = tf.compat.v1.placeholder_with_default(tf.constant([[0]], dtype=tf.int64), [None, None], 'mktest_doc_kw_feed')426 tf.compat.v1.add_to_collection('mktest_doc_kw_feed', input['mktest_doc_kw_feed'])427 input['mktest_doc_kw_feed'] += dummy428 input['mktest_doc_kw_secondary_feed'] = 
tf.compat.v1.placeholder_with_default(tf.constant([[0]], dtype=tf.int64), [None, None], 'mktest_doc_kw_secondary_feed')429 tf.compat.v1.add_to_collection('mktest_doc_kw_secondary_feed', input['mktest_doc_kw_secondary_feed'])430 input['mktest_doc_kw_secondary_feed'] += dummy431 input['mktest_tw_long_term_kw_feed'] = tf.compat.v1.placeholder_with_default(tf.constant([[0]], dtype=tf.int64), [None, None], 'mktest_tw_long_term_kw_feed')432 tf.compat.v1.add_to_collection('mktest_tw_long_term_kw_feed', input['mktest_tw_long_term_kw_feed'])433 input['mktest_tw_long_term_kw_feed'] += dummy434 input['mktest_vd_long_term_kw_feed'] = tf.compat.v1.placeholder_with_default(tf.constant([[0]], dtype=tf.int64), [None, None], 'mktest_vd_long_term_kw_feed')435 tf.compat.v1.add_to_collection('mktest_vd_long_term_kw_feed', input['mktest_vd_long_term_kw_feed'])436 input['mktest_vd_long_term_kw_feed'] += dummy437 input['mktest_long_search_kw_feed'] = tf.compat.v1.placeholder_with_default(tf.constant([[0]], dtype=tf.int64), [None, None], 'mktest_long_search_kw_feed')438 tf.compat.v1.add_to_collection('mktest_long_search_kw_feed', input['mktest_long_search_kw_feed'])439 input['mktest_long_search_kw_feed'] += dummy440 input['mktest_new_search_kw_feed'] = tf.compat.v1.placeholder_with_default(tf.constant([[0]], dtype=tf.int64), [None, None], 'mktest_new_search_kw_feed')441 tf.compat.v1.add_to_collection('mktest_new_search_kw_feed', input['mktest_new_search_kw_feed'])442 input['mktest_new_search_kw_feed'] += dummy443 input['mktest_user_kw_feed'] = tf.compat.v1.placeholder_with_default(tf.constant([[0]], dtype=tf.int64), [None, None], 'mktest_user_kw_feed')444 tf.compat.v1.add_to_collection('mktest_user_kw_feed', input['mktest_user_kw_feed'])445 input['mktest_user_kw_feed'] += dummy446 # mkyuwen add here447 input['mktest_distribution_id_feed'] = tf.compat.v1.placeholder_with_default(tf.constant([[0]], dtype=tf.int64), [None, 1], 'mktest_distribution_id_feed')448 
tf.compat.v1.add_to_collection('mktest_distribution_id_feed', input['mktest_distribution_id_feed'])449 input['mktest_distribution_id_feed'] += dummy450 input['mktest_distribution_id_feed'] = tf.squeeze(input['mktest_distribution_id_feed'], 1)451 # mkyuwen 0525452 dummy_float = tf.cast(dummy, tf.float32)453 input['mktest_new_search_kw_score_feed'] = tf.compat.v1.placeholder_with_default(tf.constant([[0]], dtype=tf.float32), [None, None], 'mktest_new_search_kw_score_feed')454 tf.compat.v1.add_to_collection('mktest_new_search_kw_score_feed', input['mktest_new_search_kw_score_feed'])455 input['mktest_new_search_kw_score_feed'] += dummy_float456 input['mktest_user_kw_score_feed'] = tf.compat.v1.placeholder_with_default(tf.constant([[0]], dtype=tf.float32), [None, None], 'mktest_user_kw_score_feed')457 tf.compat.v1.add_to_collection('mktest_user_kw_score_feed', input['mktest_user_kw_score_feed'])458 input['mktest_user_kw_score_feed'] += dummy_float459 input['mktest_doc_kw_score_feed'] = tf.compat.v1.placeholder_with_default(tf.constant([[0]], dtype=tf.float32), [None, None], 'mktest_doc_kw_score_feed')460 tf.compat.v1.add_to_collection('mktest_doc_kw_score_feed', input['mktest_doc_kw_score_feed'])461 input['mktest_doc_kw_score_feed'] += dummy_float462 input['mktest_doc_kw_secondary_score_feed'] = tf.compat.v1.placeholder_with_default(tf.constant([[0]], dtype=tf.float32), [None, None], 'mktest_doc_kw_secondary_score_feed')463 tf.compat.v1.add_to_collection('mktest_doc_kw_secondary_score_feed', input['mktest_doc_kw_secondary_score_feed'])464 input['mktest_doc_kw_secondary_score_feed'] += dummy_float465 def _add(name, dtype=tf.int64, dims=None):466 if dims is None:467 dims = [None, 1]468 input[name] = tf.compat.v1.placeholder_with_default(tf.constant([[0]], dtype=dtype), dims, '%s_feed' % name)469 tf.compat.v1.add_to_collection('%s_feed' % name, input[name])470 input[name] += dummy471 if dims[-1] == 1:472 input[name] = tf.squeeze(input[name], 1)473 474 def _adds(names, 
dtype=tf.int64, dims=None):475 for name in names:476 _add(name, dtype, dims)477 names_for_history = ['tw_history', 'tw_history_topic', 'tw_history_rec', 'tw_history_kw', 'vd_history', 'vd_history_topic']...

Full Screen

Full Screen

test_etxrd.py

Source: test_etxrd.py (GitHub)

copy

Full Screen

import os.path
import unittest

from openid.yadis import services, etxrd, xri

try:
    # Python 2: canonical IDs may be given as byte strings or unicode.
    _STRING_TYPES = (str, unicode)
except NameError:
    # Python 3 has no ``unicode`` builtin; all text is ``str``.
    _STRING_TYPES = (str,)


def datapath(filename):
    """Return the path of *filename* inside this module's
    ``data/test_etxrd`` fixture directory."""
    module_directory = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(module_directory, 'data', 'test_etxrd', filename)


def _readfile(path):
    """Return the raw contents of *path*, closing the file afterwards.

    Binary mode yields ``str`` on Python 2 and ``bytes`` on Python 3, so
    the document is handed to the parser exactly as stored on disk.
    """
    with open(path, 'rb') as f:
        return f.read()


XRD_FILE = datapath('valid-populated-xrds.xml')
NOXRDS_FILE = datapath('not-xrds.xml')
NOXRD_FILE = datapath('no-xrd.xml')

# None of the namespaces or service URIs below are official (or even
# sanctioned by the owners of that piece of URL-space)
LID_2_0 = "http://lid.netmesh.org/sso/2.0b5"
TYPEKEY_1_0 = "http://typekey.com/services/1.0"


def simpleOpenIDTransformer(endpoint):
    """Extract ``(server_url, delegate)`` from an OpenID 1.0 service element.

    Returns None for endpoints that do not advertise the OpenID 1.0 type.
    """
    if 'http://openid.net/signon/1.0' not in endpoint.type_uris:
        return None

    delegates = list(endpoint.service_element.findall(
        '{http://openid.net/xmlns/1.0}Delegate'))
    assert len(delegates) == 1
    delegate = delegates[0].text
    return (endpoint.uri, delegate)


class TestServiceParser(unittest.TestCase):
    def setUp(self):
        self.xmldoc = _readfile(XRD_FILE)
        self.yadis_url = 'http://unittest.url/'

    def _getServices(self, flt=None):
        """Return every service from ``self.xmldoc`` as a list, optionally
        passed through the filter *flt*."""
        return list(services.applyFilter(self.yadis_url, self.xmldoc, flt))

    def testParse(self):
        """Make sure that parsing succeeds at all"""
        self._getServices()

    def testParseOpenID(self):
        """Parse for OpenID services with a transformer function"""
        found = self._getServices(simpleOpenIDTransformer)

        expectedServices = [
            ("http://www.myopenid.com/server", "http://josh.myopenid.com/"),
            ("http://www.schtuff.com/openid", "http://users.schtuff.com/josh"),
            ("http://www.livejournal.com/openid/server.bml",
             "http://www.livejournal.com/users/nedthealpaca/"),
            ]

        # A single shared iterator enforces that the services appear in
        # the expected order.
        it = iter(found)
        for (server_url, delegate) in expectedServices:
            for (actual_url, actual_delegate) in it:
                self.assertEqual(server_url, actual_url)
                self.assertEqual(delegate, actual_delegate)
                break
            else:
                self.fail('Not enough services found')

    def _checkServices(self, expectedServices):
        """Check to make sure that the expected services are found in
        that order in the parsed document."""
        it = iter(self._getServices())
        for (type_uri, uri) in expectedServices:
            for service in it:
                if type_uri in service.type_uris:
                    self.assertEqual(service.uri, uri)
                    break
            else:
                self.fail('Did not find %r service' % (type_uri,))

    def testGetSeveral(self):
        """Get some services in order"""
        expectedServices = [
            # type, URL
            (TYPEKEY_1_0, None),
            (LID_2_0, "http://mylid.net/josh"),
            ]

        self._checkServices(expectedServices)

    def testGetSeveralForOne(self):
        """Getting services for one Service with several Type elements."""
        types = ['http://lid.netmesh.org/sso/2.0b5',
                 'http://lid.netmesh.org/2.0b5',
                 ]

        uri = "http://mylid.net/josh"

        for service in self._getServices():
            if service.uri == uri:
                found_types = service.matchTypes(types)
                if found_types == types:
                    break
        else:
            self.fail('Did not find service with expected types and uris')

    def testNoXRDS(self):
        """Make sure that we get an exception when an XRDS element is
        not present"""
        self.xmldoc = _readfile(NOXRDS_FILE)
        self.assertRaises(
            etxrd.XRDSError,
            services.applyFilter, self.yadis_url, self.xmldoc, None)

    def testEmpty(self):
        """Make sure that we get an exception when the document is
        empty"""
        self.xmldoc = ''
        self.assertRaises(
            etxrd.XRDSError,
            services.applyFilter, self.yadis_url, self.xmldoc, None)

    def testNoXRD(self):
        """Make sure that we get an exception when there is no XRD
        element present."""
        self.xmldoc = _readfile(NOXRD_FILE)
        self.assertRaises(
            etxrd.XRDSError,
            services.applyFilter, self.yadis_url, self.xmldoc, None)


class TestCanonicalID(unittest.TestCase):

    def mkTest(iname, filename, expectedID):
        """Build a test method that checks the canonical ID resolved for
        *iname* from the fixture *filename*.

        *expectedID* is a string, None, or an ``etxrd.XRDSError``
        subclass that the lookup must raise.
        """
        filename = datapath(filename)

        def test(self):
            xrds = etxrd.parseXRDS(_readfile(filename))
            self._getCanonicalID(iname, xrds, expectedID)
        return test

    test_delegated = mkTest(
        "@ootao*test1", "delegated-20060809.xrds",
        "@!5BAD.2AA.3C72.AF46!0000.0000.3B9A.CA01")

    test_delegated_r1 = mkTest(
        "@ootao*test1", "delegated-20060809-r1.xrds",
        "@!5BAD.2AA.3C72.AF46!0000.0000.3B9A.CA01")

    test_delegated_r2 = mkTest(
        "@ootao*test1", "delegated-20060809-r2.xrds",
        "@!5BAD.2AA.3C72.AF46!0000.0000.3B9A.CA01")

    test_sometimesprefix = mkTest(
        "@ootao*test1", "sometimesprefix.xrds",
        "@!5BAD.2AA.3C72.AF46!0000.0000.3B9A.CA01")

    test_prefixsometimes = mkTest(
        "@ootao*test1", "prefixsometimes.xrds",
        "@!5BAD.2AA.3C72.AF46!0000.0000.3B9A.CA01")

    test_spoof1 = mkTest("=keturn*isDrummond", "spoof1.xrds", etxrd.XRDSFraud)

    test_spoof2 = mkTest("=keturn*isDrummond", "spoof2.xrds", etxrd.XRDSFraud)

    test_spoof3 = mkTest("@keturn*is*drummond", "spoof3.xrds", etxrd.XRDSFraud)

    test_status222 = mkTest("=x", "status222.xrds", None)

    test_multisegment_xri = mkTest('xri://=nishitani*masaki',
                                   'subsegments.xrds',
                                   '=!E117.EF2F.454B.C707!0000.0000.3B9A.CA01')

    test_iri_auth_not_allowed = mkTest(
        "phreak.example.com", "delegated-20060809-r2.xrds", etxrd.XRDSFraud)
    test_iri_auth_not_allowed.__doc__ = \
        "Don't let IRI authorities be canonical for the GCS."

    # mkTest is only a class-body factory; remove it so it does not
    # linger as a bogus method on the test case.
    del mkTest

    # TODO: Refs
    # test_ref = mkTest("@ootao*test.ref", "ref.xrds", "@!BAE.A650.823B.2475")

    # TODO: Add a IRI authority with an IRI canonicalID.
    # TODO: Add test cases with real examples of multiple CanonicalIDs
    #   somewhere in the resolution chain.

    def _getCanonicalID(self, iname, xrds, expectedID):
        """Assert that the canonical ID of *iname* in *xrds* matches
        *expectedID*.

        A string or None is compared with ``etxrd.getCanonicalID``'s
        result; an ``etxrd.XRDSError`` subclass must be raised instead.
        """
        if expectedID is None or isinstance(expectedID, _STRING_TYPES):
            cid = etxrd.getCanonicalID(iname, xrds)
            # ``expectedID and ...`` leaves a falsy expectation (None)
            # unwrapped instead of turning it into an XRI.
            self.assertEqual(cid, expectedID and xri.XRI(expectedID))
        elif (isinstance(expectedID, type)
              and issubclass(expectedID, etxrd.XRDSError)):
            # The isinstance guard keeps non-class values from raising
            # TypeError here instead of reaching the explicit failure below.
            self.assertRaises(expectedID, etxrd.getCanonicalID,
                              iname, xrds)
        else:
            self.fail("Don't know how to test for expected value %r"
                      % (expectedID,))


if __name__ == '__main__':
    unittest.main()

Full Screen

Full Screen

test_xri.py

Source:test_xri.py Github

copy

Full Screen

import sys
from unittest import TestCase

from openid.yadis import xri


class XriDiscoveryTestCase(TestCase):
    def test_isXRI(self):
        i = xri.identifierScheme
        self.assertEqual(i('=john.smith'), 'XRI')
        self.assertEqual(i('@smiths/john'), 'XRI')
        self.assertEqual(i('smoker.myopenid.com'), 'URI')
        self.assertEqual(i('xri://=john'), 'XRI')
        self.assertEqual(i(''), 'URI')


class XriEscapingTestCase(TestCase):
    def test_escaping_percents(self):
        self.assertEqual(xri.escapeForIRI('@example/abc%2Fd/ef'),
                         '@example/abc%252Fd/ef')

    def test_escaping_xref(self):
        # no escapes
        esc = xri.escapeForIRI
        self.assertEqual('@example/foo/(@bar)', esc('@example/foo/(@bar)'))
        # escape slashes
        self.assertEqual('@example/foo/(@bar%2Fbaz)',
                         esc('@example/foo/(@bar/baz)'))
        self.assertEqual('@example/foo/(@bar%2Fbaz)/(+a%2Fb)',
                         esc('@example/foo/(@bar/baz)/(+a/b)'))
        # escape query ? and fragment #
        self.assertEqual('@example/foo/(@baz%3Fp=q%23r)?i=j#k',
                         esc('@example/foo/(@baz?p=q#r)?i=j#k'))


class XriTransformationTestCase(TestCase):
    def test_to_iri_normal(self):
        self.assertEqual(xri.toIRINormal('@example'), 'xri://@example')

    # Narrow (UCS-2) Python builds cannot hold code points above U+FFFF as
    # one character, so the astral-plane case only runs on wide builds.
    # sys.maxunicode is available on Python 2 and 3, unlike ``unichr``.
    if sys.maxunicode > 0xFFFF:
        def test_iri_to_url(self):
            s = u'l\xa1m\U00101010n'
            expected = 'l%C2%A1m%F4%81%80%90n'
            self.assertEqual(xri.iriToURI(s), expected)
    else:
        def test_iri_to_url(self):
            s = u'l\xa1m'
            expected = 'l%C2%A1m'
            self.assertEqual(xri.iriToURI(s), expected)


class CanonicalIDTest(TestCase):
    def mkTest(providerID, canonicalID, isAuthoritative):
        """Build a test method asserting that
        ``xri.providerIsAuthoritative(providerID, canonicalID)`` equals
        *isAuthoritative*."""
        def test(self):
            result = xri.providerIsAuthoritative(providerID, canonicalID)
            template = "%s providing %s, expected %s"
            message = template % (providerID, canonicalID, isAuthoritative)
            self.assertEqual(isAuthoritative, result, message)

        return test

    test_equals = mkTest('=', '=!698.74D1.A1F2.86C7', True)
    test_atOne = mkTest('@!1234', '@!1234!ABCD', True)
    test_atTwo = mkTest('@!1234!5678', '@!1234!5678!ABCD', True)

    test_atEqualsFails = mkTest('@!1234', '=!1234!ABCD', False)
    test_tooDeepFails = mkTest('@!1234', '@!1234!ABCD!9765', False)
    test_atEqualsAndTooDeepFails = mkTest('@!1234!ABCD', '=!1234', False)
    test_differentBeginningFails = mkTest('=!BABE', '=!D00D', False)

    # Class-body factory only; do not leave it behind as a method.
    del mkTest


class TestGetRootAuthority(TestCase):
    def mkTest(the_xri, expected_root):
        """Build a test method asserting that ``xri.rootAuthority(the_xri)``
        equals ``xri.XRI(expected_root)``."""
        def test(self):
            actual_root = xri.rootAuthority(the_xri)
            self.assertEqual(actual_root, xri.XRI(expected_root))
        return test

    test_at = mkTest("@foo", "@")
    test_atStar = mkTest("@foo*bar", "@")
    test_atStarStar = mkTest("@*foo*bar", "@")
    test_atWithPath = mkTest("@foo/bar", "@")
    test_bangBang = mkTest("!!990!991", "!")
    test_bang = mkTest("!1001!02", "!")
    test_equalsStar = mkTest("=foo*bar", "=")
    test_xrefPath = mkTest("(example.com)/foo", "(example.com)")
    test_xrefStar = mkTest("(example.com)*bar/foo", "(example.com)")
    test_uriAuth = mkTest("baz.example.com/foo", "baz.example.com")
    test_uriAuthPort = mkTest("baz.example.com:8080/foo",
                              "baz.example.com:8080")

    # Looking at the ABNF in XRI Syntax 2.0, I don't think you can
    # have example.com*bar.  You can do (example.com)*bar, but that
    # would mean something else.
    ##("example.com*bar/(=baz)", "example.com*bar"),
    ##("baz.example.com!01/foo", "baz.example.com!01"),

    # Class-body factory only; do not leave it behind as a method.
    del mkTest


if __name__ == '__main__':
    import unittest
    unittest.main()

Full Screen

Full Screen

Automation Testing Tutorials

Learn to execute automation testing from scratch with the LambdaTest Learning Hub, from setting up the prerequisites and running your first automation test to following best practices and diving deeper into advanced test scenarios. LambdaTest Learning Hubs compile step-by-step guides to help you become proficient with different test automation frameworks, such as Selenium, Cypress, and TestNG.

LambdaTest Learning Hubs:

YouTube

You can also watch video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.

Run Nose automation tests on the LambdaTest cloud grid

Perform automation testing on 3000+ real desktop and mobile devices online.

Try LambdaTest Now!

Get 100 minutes of automation testing FREE!

Next-Gen App & Browser Testing Cloud

Was this article helpful?

Helpful

Not Helpful