How to use the d_h parameter in fMBT

Best Python code snippets using fMBT_python

modules.py

Source: modules.py (GitHub)

copy

Full Screen

1# -*- coding: utf-8 -*-2# !/usr/bin/python34"""5# @Time : 2020/8/16# @Author : Yongrui Chen7# @File : modules.py8# @Software: PyCharm9"""10import sys11import torch12import torch.nn as nn13sys.path.append("..")14from src.models.nn_utils import encode_question, encode_header, build_mask151617class LSTM(nn.Module):1819 def __init__(self, d_input, d_h, n_layers=1, batch_first=True, birnn=True, dropout=0.3):20 super(LSTM, self).__init__()2122 n_dir = 2 if birnn else 123 self.init_h = nn.Parameter(torch.Tensor(n_layers * n_dir, d_h))24 self.init_c = nn.Parameter(torch.Tensor(n_layers * n_dir, d_h))2526 INI = 1e-227 torch.nn.init.uniform_(self.init_h, -INI, INI)28 torch.nn.init.uniform_(self.init_c, -INI, INI)2930 self.lstm = nn.LSTM(31 input_size=d_input,32 hidden_size=d_h,33 num_layers=n_layers,34 bidirectional=birnn,35 batch_first=not batch_first36 )37 self.dropout = nn.Dropout(dropout)3839 def forward(self, seqs, seq_lens=None, init_states=None):4041 bs = seqs.size(0)42 bf = self.lstm.batch_first4344 if not bf:45 seqs = seqs.transpose(0, 1)4647 seqs = self.dropout(seqs)484950 size = (self.init_h.size(0), bs, self.init_h.size(1))51 if init_states is None:52 init_states = (self.init_h.unsqueeze(1).expand(*size).contiguous(),53 self.init_c.unsqueeze(1).expand(*size).contiguous())5455 if seq_lens is not None:56 assert bs == len(seq_lens)57 sort_ind = sorted(range(len(seq_lens)), key=lambda i: seq_lens[i], reverse=True)58 seq_lens = [seq_lens[i] for i in sort_ind]59 seqs = self.reorder_sequence(seqs, sort_ind, bf)60 init_states = self.reorder_init_states(init_states, sort_ind)6162 packed_seq = nn.utils.rnn.pack_padded_sequence(seqs, seq_lens)63 packed_out, final_states = self.lstm(packed_seq, init_states)64 lstm_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out)6566 back_map = {ind: i for i, ind in enumerate(sort_ind)}67 reorder_ind = [back_map[i] for i in range(len(seq_lens))]68 lstm_out = self.reorder_sequence(lstm_out, reorder_ind, bf)69 final_states = 
self.reorder_init_states(final_states, reorder_ind)70 else:71 lstm_out, final_states = self.lstm(seqs)72 return lstm_out.transpose(0, 1), final_states7374 def reorder_sequence(self, seqs, order, batch_first=False):75 """76 seqs: [T, B, D] if not batch_first77 order: list of sequence length78 """79 batch_dim = 0 if batch_first else 180 assert len(order) == seqs.size()[batch_dim]81 order = torch.LongTensor(order).to(seqs.device)82 sorted_seqs = seqs.index_select(index=order, dim=batch_dim)83 return sorted_seqs8485 def reorder_init_states(self, states, order):86 """87 lstm_states: (H, C) of tensor [layer, batch, hidden]88 order: list of sequence length89 """90 assert isinstance(states, tuple)91 assert len(states) == 292 assert states[0].size() == states[1].size()93 assert len(order) == states[0].size()[1]9495 order = torch.LongTensor(order).to(states[0].device)96 sorted_states = (states[0].index_select(index=order, dim=1),97 states[1].index_select(index=order, dim=1))98 return sorted_states99100class SelectNumber(nn.Module):101 def __init__(self, d_in, d_h, n_layers, dropout_prob, pooling_type, max_select_num):102 super(SelectNumber, self).__init__()103104 self.d_in = d_in105 self.d_h = d_h106 self.n_layers = n_layers107 self.dropout_prob = dropout_prob108 self.pooling_type = pooling_type109 self.max_select_num = max_select_num110111 self.q_encoder = LSTM(d_input=d_in, d_h=int(d_h/2), n_layers=n_layers, batch_first=True,112 dropout=dropout_prob, birnn=True)113 self.h_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,114 dropout=dropout_prob, birnn=True)115116 self.W_att_q = nn.Linear(d_h, 1)117 self.W_att_h = nn.Linear(d_h, 1)118 self.W_hidden = nn.Linear(d_h, d_h * n_layers)119 self.W_cell = nn.Linear(d_h, d_h * n_layers)120 self.W_out = nn.Sequential(121 nn.Linear(d_h, d_h),122 nn.Tanh(),123 nn.Linear(d_h, self.max_select_num + 1)124 )125126 self.softmax = nn.Softmax(dim=-1)127128 def forward(self, q_emb, q_lens, h_emb, h_lens, 
h_nums):129130 # [bs, max_h_num, d_h]131 h_pooling = encode_header(self.h_encoder, h_emb, h_lens, h_nums, pooling_type=self.pooling_type)132133 bs = len(q_lens)134135 # self-attention for header136 # [bs, max_h_num]137 att_weights_h = self.W_att_h(h_pooling).squeeze(2)138 att_mask_h = build_mask(att_weights_h, q_lens, dim=-2)139 att_weights_h = self.softmax(att_weights_h.masked_fill(att_mask_h == 0, -float("inf")))140141 # [bs, d_h]142 h_context = torch.mul(h_pooling, att_weights_h.unsqueeze(2)).sum(1)143144 # [bs, d_h] -> [bs, 2 * d_h]145 # enlarge because there are two layers.146 hidden = self.W_hidden(h_context)147 hidden = hidden.view(bs, self.n_layers * 2, int(self.d_h / 2))148 hidden = hidden.transpose(0, 1).contiguous()149150 cell = self.W_cell(h_context)151 cell = cell.view(bs, self.n_layers * 2, int(self.d_h / 2))152 cell = cell.transpose(0, 1).contiguous()153154 # [bs, max_q_len, d_h]155 q_enc = encode_question(self.q_encoder, q_emb, q_lens, init_states=(hidden, cell))156157 # self-attention for question158 # [bs, max_q_len]159 att_weights_q = self.W_att_q(q_enc).squeeze(2)160 att_mask_q = build_mask(att_weights_q, q_lens, dim=-2)161 att_weights_q = self.softmax(att_weights_q.masked_fill(att_mask_q == 0, -float("inf")))162163 q_context = torch.mul(q_enc, att_weights_q.unsqueeze(2).expand_as(q_enc)).sum(dim=1)164165 # [bs, max_select_num + 1]166 score_sel_num = self.W_out(q_context)167 return score_sel_num168169170class SelectColumn(nn.Module):171 def __init__(self, d_in, d_h, n_layers, dropout_prob, pooling_type):172 super(SelectColumn, self).__init__()173174 self.d_in = d_in175 self.d_h = d_h176 self.n_layers = n_layers177 self.dropout_prob = dropout_prob178 self.pooling_type = pooling_type179180 self.q_encoder = LSTM(d_input=d_in, d_h=int(d_h/2), n_layers=n_layers, batch_first=True,181 dropout=dropout_prob, birnn=True)182 self.h_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,183 dropout=dropout_prob, birnn=True)184185 
self.W_att = nn.Linear(d_h, d_h)186 self.W_q = nn.Linear(d_h, d_h)187 self.W_h = nn.Linear(d_h, d_h)188 self.W_out = nn.Sequential(nn.Tanh(), nn.Linear(2 * d_h, 1))189190 self.softmax = nn.Softmax(dim=-1)191192 def forward(self, q_emb, q_lens, h_emb, h_lens, h_nums):193 # [bs, max_q_len, d_h]194 q_enc = encode_question(self.q_encoder, q_emb, q_lens)195196 # [bs, max_h_num, d_h]197 h_pooling = encode_header(self.h_encoder, h_emb, h_lens, h_nums, pooling_type=self.pooling_type)198199 # [bs, max_h_num, max_q_len]200 att_weights = torch.bmm(h_pooling, self.W_att(q_enc).transpose(1, 2))201 att_mask = build_mask(att_weights, q_lens, dim=-1)202 att_weights = self.softmax(att_weights.masked_fill(att_mask==0, -float("inf")))203204 # att_weights: -> [bs, max_h_num, max_q_len, 1]205 # q_enc: -> [bs, 1, max_q_len, d_h]206 # [bs, max_h_num, d_h]207 q_context = torch.mul(att_weights.unsqueeze(3), q_enc.unsqueeze(1)).sum(dim=2)208209 # [bs, max_h_num, d_h * 2]210 comb_context = torch.cat([self.W_q(q_context), self.W_h(h_pooling)], dim=-1)211212 # [bs, max_h_num]213 score_sel_col = self.W_out(comb_context).squeeze(2)214215 # mask216 for b, h_num in enumerate(h_nums):217 score_sel_col[b, h_num:] = -float("inf")218 return score_sel_col219220221class SelectAggregation(nn.Module):222 def __init__(self, d_in, d_h, n_layers, dropout_prob, n_agg, pooling_type, max_sel_num):223 super(SelectAggregation, self).__init__()224225 self.d_in = d_in226 self.d_h = d_h227 self.n_layers = n_layers228 self.dropout_prob = dropout_prob229 self.n_agg = n_agg230 self.pooling_type = pooling_type231 self.max_sel_num = max_sel_num232233 self.q_encoder = LSTM(d_input=d_in, d_h=int(d_h/2), n_layers=n_layers, batch_first=True,234 dropout=dropout_prob, birnn=True)235 self.h_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,236 dropout=dropout_prob, birnn=True)237238 self.W_att = nn.Linear(d_h, d_h)239 self.W_out = nn.Sequential(240 nn.Linear(d_h, d_h),241 nn.Tanh(),242 
nn.Linear(d_h, self.n_agg)243 )244245 self.softmax = nn.Softmax(dim=-1)246247 def forward(self, q_emb, q_lens, h_emb, h_lens, h_nums, sel_col):248 # [bs, max_q_len, d_h]249 q_enc = encode_question(self.q_encoder, q_emb, q_lens)250251 # [bs, max_h_num, d_h]252 h_pooling = encode_header(self.h_encoder, h_emb, h_lens, h_nums, pooling_type=self.pooling_type)253254 bs = len(q_emb)255 h_pooling_sel = h_pooling[list(range(bs)), sel_col]256257 att_weights = torch.bmm(self.W_att(q_enc), h_pooling_sel.unsqueeze(2)).squeeze(2)258 att_mask = build_mask(att_weights, q_lens, dim=-2)259 att_weights = self.softmax(att_weights.masked_fill(att_mask == 0, -float("inf")))260261 # att_weights: [bs, max_sel_num, max_q_len] -> [bs, max_sel_num, max_q_len, 1]262 # q_enc: [bs, max_q_len, d_h] -> [bs, 1, max_q_len, d_h]263 q_context = torch.mul(q_enc, att_weights.unsqueeze(2).expand_as(q_enc)).sum(dim=1)264265 # [bs, max_sel_num, n_agg]266 score_sel_agg = self.W_out(q_context)267 return score_sel_agg268269270class SelectMultipleAggregation(nn.Module):271 def __init__(self, d_in, d_h, n_layers, dropout_prob, n_agg, pooling_type, max_sel_num):272 super(SelectMultipleAggregation, self).__init__()273274 self.d_in = d_in275 self.d_h = d_h276 self.n_layers = n_layers277 self.dropout_prob = dropout_prob278 self.n_agg = n_agg279 self.pooling_type = pooling_type280 self.max_sel_num = max_sel_num281282 self.q_encoder = LSTM(d_input=d_in, d_h=int(d_h/2), n_layers=n_layers, batch_first=True,283 dropout=dropout_prob, birnn=True)284 self.h_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,285 dropout=dropout_prob, birnn=True)286287 self.W_att = nn.Linear(d_h, d_h)288 self.W_q = nn.Linear(d_h, d_h)289 self.W_h = nn.Linear(d_h, d_h)290 self.W_out = nn.Sequential(291 nn.Linear(2 * d_h, d_h),292 nn.Tanh(),293 nn.Linear(d_h, self.n_agg)294 )295296 self.softmax = nn.Softmax(dim=-1)297298 def forward(self, q_emb, q_lens, h_emb, h_lens, h_nums, sel_cols):299 # [bs, max_q_len, 
d_h]300 q_enc = encode_question(self.q_encoder, q_emb, q_lens)301302 # [bs, max_h_num, d_h]303 h_pooling = encode_header(self.h_encoder, h_emb, h_lens, h_nums, pooling_type=self.pooling_type)304305 padding_t = torch.zeros_like(h_pooling[0][0]).unsqueeze(0)306307 h_pooling_sel = []308 for b, cols in enumerate(sel_cols):309 if len(cols) > 0:310 h_tmp = [h_pooling[b][cols, :]]311 else:312 h_tmp = []313 h_tmp += [padding_t] * (self.max_sel_num - len(cols))314 h_tmp = torch.cat(h_tmp, dim=0)315 h_pooling_sel.append(h_tmp)316 # [bs, max_sel_num, d_h]317 h_pooling_sel = torch.stack(h_pooling_sel)318319 # q_enc: [bs, max_q_len, d_h] -> [bs, 1, max_q_len, d_h]320 # h_pooling_sel: [bs, max_sel_num, d_h] -> [bs, max_sel_num, d_h, 1]321 # [bs, max_sel_num, max_q_len]322 att_weights = torch.matmul(323 self.W_att(q_enc).unsqueeze(1),324 h_pooling_sel.unsqueeze(3)325 ).squeeze(3)326 att_mask = build_mask(att_weights, q_lens, dim=-1)327 att_weights = self.softmax(att_weights.masked_fill(att_mask == 0, -float("inf")))328329 # att_weights: [bs, max_sel_num, max_q_len] -> [bs, max_sel_num, max_q_len, 1]330 # q_enc: [bs, max_q_len, d_h] -> [bs, 1, max_q_len, d_h]331 q_context = torch.mul(att_weights.unsqueeze(3), q_enc.unsqueeze(1)).sum(dim=2)332333 # [bs, max_sel_num, n_agg]334 score_sel_agg = self.W_out(torch.cat([self.W_q(q_context), self.W_h(h_pooling_sel)], dim=2))335 return score_sel_agg336337338class WhereJoiner(nn.Module):339 def __init__(self, d_in, d_h, n_layers, dropout_prob):340 super(WhereJoiner, self).__init__()341 self.d_in = d_in342 self.d_h = d_h343 self.n_layers = n_layers344 self.dropout_prob = dropout_prob345346 self.q_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,347 dropout=dropout_prob, birnn=True)348 self.h_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,349 dropout=dropout_prob, birnn=True)350351 self.W_att = nn.Linear(d_h, 1)352 self.W_out = nn.Sequential(nn.Linear(d_h, d_h),353 
nn.Tanh(),354 nn.Linear(d_h, 3))355356 self.softmax = nn.Softmax(dim=1)357358 def forward(self, q_emb, q_lens):359360 q_enc = encode_question(self.q_encoder, q_emb, q_lens)361362 # self-atttention for question363 # [bs, max_q_len]364 att_weights_q = self.W_att(q_enc).squeeze(2)365 att_mask_q = build_mask(att_weights_q, q_lens, dim=-2)366 att_weights_q = att_weights_q.masked_fill(att_mask_q == 0, -float('inf'))367 att_weights_q = self.softmax(att_weights_q)368369 # [bs, d_h]370 q_context = torch.mul(371 q_enc,372 att_weights_q.unsqueeze(2).expand_as(q_enc)373 ).sum(dim=1)374 where_op_logits = self.W_out(q_context)375 return where_op_logits376377378class WhereNumber(nn.Module):379 def __init__(self, d_in, d_in_ch, d_h, n_layers, dropout_prob, pooling_type, max_where_num):380 super(WhereNumber, self).__init__()381382 self.d_in = d_in383 self.d_h = d_h384 self.n_layers = n_layers385 self.dropout_prob = dropout_prob386 self.pooling_type = pooling_type387 self.max_where_num = max_where_num388389 self.q_encoder = LSTM(d_input=d_in, d_h=int(d_h/2), n_layers=n_layers, batch_first=True,390 dropout=dropout_prob, birnn=True)391 self.h_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,392 dropout=dropout_prob, birnn=True)393394 self.q_encoder_ch = LSTM(d_input=d_in_ch, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,395 dropout=dropout_prob, birnn=True)396 self.h_encoder_ch = LSTM(d_input=d_in_ch, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,397 dropout=dropout_prob, birnn=True)398399 self.W_att_q = nn.Linear(d_h, 1)400 self.W_att_h = nn.Linear(d_h, 1)401 self.W_hidden = nn.Linear(d_h, d_h * n_layers)402 self.W_cell = nn.Linear(d_h, d_h * n_layers)403404 self.W_att_q_ch = nn.Linear(d_h, 1)405 self.W_att_h_ch = nn.Linear(d_h, 1)406 self.W_hidden_ch = nn.Linear(d_h, d_h * n_layers)407 self.W_cell_ch = nn.Linear(d_h, d_h * n_layers)408409 self.W_out = nn.Sequential(410 nn.Linear(d_h * 2, d_h),411 nn.Tanh(),412 nn.Linear(d_h, 
self.max_where_num + 1)413 )414415 self.softmax = nn.Softmax(dim=-1)416417 def get_context(self, q_emb, q_lens, h_emb, h_lens, h_nums,418 q_encoder, h_encoder, W_att_q, W_att_h, W_hidden, W_cell):419 # [bs, max_h_num, d_h]420 h_pooling = encode_header(h_encoder, h_emb, h_lens, h_nums, pooling_type=self.pooling_type)421422 bs = len(q_lens)423424 # self-attention for header425 # [bs, max_h_num]426 att_weights_h = W_att_h(h_pooling).squeeze(2)427 att_mask_h = build_mask(att_weights_h, q_lens, dim=-2)428 att_weights_h = self.softmax(att_weights_h.masked_fill(att_mask_h == 0, -float("inf")))429430 # [bs, d_h]431 h_context = torch.mul(h_pooling, att_weights_h.unsqueeze(2)).sum(1)432433 # [bs, d_h] -> [bs, 2 * d_h]434 # enlarge because there are two layers.435 hidden = W_hidden(h_context)436 hidden = hidden.view(bs, self.n_layers * 2, int(self.d_h / 2))437 hidden = hidden.transpose(0, 1).contiguous()438439 cell = W_cell(h_context)440 cell = cell.view(bs, self.n_layers * 2, int(self.d_h / 2))441 cell = cell.transpose(0, 1).contiguous()442443 # [bs, max_q_len, d_h]444 q_enc = encode_question(q_encoder, q_emb, q_lens, init_states=(hidden, cell))445446 # self-attention for question447 # [bs, max_q_len]448 att_weights_q = W_att_q(q_enc).squeeze(2)449 att_mask_q = build_mask(att_weights_q, q_lens, dim=-2)450 att_weights_q = self.softmax(att_weights_q.masked_fill(att_mask_q == 0, -float("inf")))451452 q_context = torch.mul(q_enc, att_weights_q.unsqueeze(2).expand_as(q_enc)).sum(dim=1)453 return q_context454455 def forward(self, q_emb, q_lens, h_emb, h_lens, h_nums,456 q_emb_ch, q_lens_ch, h_emb_ch, h_lens_ch):457458 q_context = self.get_context(q_emb, q_lens, h_emb, h_lens, h_nums,459 self.q_encoder, self.h_encoder, self.W_att_q,460 self.W_att_h, self.W_hidden, self.W_cell)461 q_context_ch = self.get_context(q_emb_ch, q_lens_ch, h_emb_ch, h_lens_ch, h_nums,462 self.q_encoder_ch, self.h_encoder_ch, self.W_att_q_ch,463 self.W_att_h_ch, self.W_hidden_ch, self.W_cell_ch)464 # [bs, 
max_where_num + 1]465 score_sel_num = self.W_out(torch.cat([q_context, q_context_ch], dim=-1))466 return score_sel_num467468469class WhereColumn(nn.Module):470 def __init__(self, d_in, d_in_ch, d_h, n_layers, dropout_prob, pooling_type):471 super(WhereColumn, self).__init__()472473 self.d_in = d_in474 self.d_h = d_h475 self.n_layers = n_layers476 self.dropout_prob = dropout_prob477 self.pooling_type = pooling_type478479 self.q_encoder = LSTM(d_input=d_in, d_h=int(d_h/2), n_layers=n_layers, batch_first=True,480 dropout=dropout_prob, birnn=True)481 self.h_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,482 dropout=dropout_prob, birnn=True)483484 self.W_att = nn.Linear(d_h, d_h)485 self.W_q = nn.Linear(d_h, d_h)486 self.W_h = nn.Linear(d_h, d_h)487488 self.q_encoder_ch = LSTM(d_input=d_in_ch, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,489 dropout=dropout_prob, birnn=True)490 self.h_encoder_ch = LSTM(d_input=d_in_ch, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,491 dropout=dropout_prob, birnn=True)492493 self.W_att_ch = nn.Linear(d_h, d_h)494 self.W_q_ch = nn.Linear(d_h, d_h)495 self.W_h_ch = nn.Linear(d_h, d_h)496497 self.W_out = nn.Sequential(nn.Tanh(), nn.Linear(4 * d_h, 1))498499 self.softmax = nn.Softmax(dim=-1)500501 def get_context(self, q_emb, q_lens, h_emb, h_lens, h_nums,502 q_encoder, h_encoder, W_att, W_q, W_h):503 # [bs, max_q_len, d_h]504 q_enc = encode_question(q_encoder, q_emb, q_lens)505506 # [bs, max_h_num, d_h]507 h_pooling = encode_header(h_encoder, h_emb, h_lens, h_nums, pooling_type=self.pooling_type)508509 # [bs, max_h_num, max_q_len]510 att_weights = torch.bmm(h_pooling, W_att(q_enc).transpose(1, 2))511 att_mask = build_mask(att_weights, q_lens, dim=-1)512 att_weights = self.softmax(att_weights.masked_fill(att_mask == 0, -float("inf")))513514 # att_weights: -> [bs, max_h_num, max_q_len, 1]515 # q_enc: -> [bs, 1, max_q_len, d_h]516 # [bs, max_h_num, d_h]517 q_context = 
torch.mul(att_weights.unsqueeze(3), q_enc.unsqueeze(1)).sum(dim=2)518519 return W_q(q_context), W_h(h_pooling)520521 def forward(self, q_emb, q_lens, h_emb, h_lens, h_nums,522 q_emb_ch, q_lens_ch, h_emb_ch, h_lens_ch):523 q_context, h_pooling = self.get_context(q_emb, q_lens, h_emb, h_lens, h_nums,524 self.q_encoder, self.h_encoder, self.W_att,525 self.W_q, self.W_h)526527 q_context_ch, h_pooling_ch = self.get_context(q_emb_ch, q_lens_ch, h_emb_ch, h_lens_ch, h_nums,528 self.q_encoder_ch, self.h_encoder_ch, self.W_att_ch,529 self.W_q_ch, self.W_h_ch)530531 # [bs, max_h_num, d_h * 2]532 comb_context = torch.cat([q_context, q_context_ch, h_pooling, h_pooling_ch], dim=-1)533534 # [bs, max_h_num]535 score_where_col = self.W_out(comb_context).squeeze(2)536 for b, h_num in enumerate(h_nums):537 score_where_col[b, h_num:] = -float("inf")538 return score_where_col539540541class WhereOperator(nn.Module):542 def __init__(self, d_in, d_h, n_layers, dropout_prob, n_op, pooling_type, max_where_num):543 super(WhereOperator, self).__init__()544545 self.d_in = d_in546 self.d_h = d_h547 self.n_layers = n_layers548 self.dropout_prob = dropout_prob549 self.n_op = n_op550 self.pooling_type = pooling_type551 self.max_where_num = max_where_num552553 self.q_encoder = LSTM(d_input=d_in, d_h=int(d_h/2), n_layers=n_layers, batch_first=True,554 dropout=dropout_prob, birnn=True)555 self.h_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,556 dropout=dropout_prob, birnn=True)557558 self.W_att = nn.Linear(d_h, d_h)559 self.W_q = nn.Linear(d_h, d_h)560 self.W_h = nn.Linear(d_h, d_h)561 self.W_out = nn.Sequential(562 nn.Linear(2 * d_h, d_h),563 nn.Tanh(),564 nn.Linear(d_h, self.n_op)565 )566567 self.softmax = nn.Softmax(dim=-1)568569 def forward(self, q_emb, q_lens, h_emb, h_lens, h_nums, where_cols):570 # [bs, max_q_len, d_h]571 q_enc = encode_question(self.q_encoder, q_emb, q_lens)572573 # [bs, max_h_num, d_h]574 h_pooling = encode_header(self.h_encoder, h_emb, 
h_lens, h_nums, pooling_type=self.pooling_type)575576 padding_t = torch.zeros_like(h_pooling[0][0]).unsqueeze(0)577578 h_pooling_where = []579 for b, cols in enumerate(where_cols):580 if len(cols) > 0:581 h_tmp = [h_pooling[b][cols, :]]582 else:583 h_tmp = []584 h_tmp += [padding_t] * (self.max_where_num - len(cols))585 h_tmp = torch.cat(h_tmp, dim=0)586 h_pooling_where.append(h_tmp)587 # [bs, max_where_num, d_h]588 h_pooling_where = torch.stack(h_pooling_where)589590 # q_enc: [bs, max_q_len, d_h] -> [bs, 1, max_q_len, d_h]591 # h_pooling_where: [bs, max_where_num, d_h] -> [bs, max_where_num, d_h, 1]592 # [bs, max_where_num, max_q_len]593 att_weights = torch.matmul(594 self.W_att(q_enc).unsqueeze(1),595 h_pooling_where.unsqueeze(3)596 ).squeeze(3)597 att_mask = build_mask(att_weights, q_lens, dim=-1)598 att_weights = self.softmax(att_weights.masked_fill(att_mask == 0, -float("inf")))599600 # att_weights: [bs, max_where_num, max_q_len] -> [bs, max_where_num, max_q_len, 1]601 # q_enc: [bs, max_q_len, d_h] -> [bs, 1, max_q_len, d_h]602 q_context = torch.mul(att_weights.unsqueeze(3), q_enc.unsqueeze(1)).sum(dim=2)603604 # [bs, max_where_num, n_agg]605 score_where_op = self.W_out(torch.cat([self.W_q(q_context), self.W_h(h_pooling_where)], dim=2))606 return score_where_op607608609class WhereAggregation(nn.Module):610 def __init__(self, d_in, d_h, n_layers, dropout_prob, n_agg, pooling_type, max_where_num):611 super(WhereAggregation, self).__init__()612613 self.d_in = d_in614 self.d_h = d_h615 self.n_layers = n_layers616 self.dropout_prob = dropout_prob617 self.n_agg = n_agg618 self.pooling_type = pooling_type619 self.max_where_num = max_where_num620621 self.q_encoder = LSTM(d_input=d_in, d_h=int(d_h/2), n_layers=n_layers, batch_first=True,622 dropout=dropout_prob, birnn=True)623 self.h_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,624 dropout=dropout_prob, birnn=True)625626 self.W_att = nn.Linear(d_h, d_h)627 self.W_q = nn.Linear(d_h, 
d_h)628 self.W_h = nn.Linear(d_h, d_h)629 self.W_out = nn.Sequential(630 nn.Linear(2 * d_h, d_h),631 nn.Tanh(),632 nn.Linear(d_h, self.n_agg)633 )634635 self.softmax = nn.Softmax(dim=-1)636637 def forward(self, q_emb, q_lens, h_emb, h_lens, h_nums, where_cols):638 # [bs, max_q_len, d_h]639 q_enc = encode_question(self.q_encoder, q_emb, q_lens)640641 # [bs, max_h_num, d_h]642 h_pooling = encode_header(self.h_encoder, h_emb, h_lens, h_nums, pooling_type=self.pooling_type)643644 padding_t = torch.zeros_like(h_pooling[0][0]).unsqueeze(0)645646 h_pooling_where = []647 for b, cols in enumerate(where_cols):648 if len(cols) > 0:649 h_tmp = [h_pooling[b][cols, :]]650 else:651 h_tmp = []652 h_tmp += [padding_t] * (self.max_where_num - len(cols))653 h_tmp = torch.cat(h_tmp, dim=0)654 h_pooling_where.append(h_tmp)655 # [bs, max_where_num, d_h]656 h_pooling_where = torch.stack(h_pooling_where)657658 # q_enc: [bs, max_q_len, d_h] -> [bs, 1, max_q_len, d_h]659 # h_pooling_where: [bs, max_where_num, d_h] -> [bs, max_where_num, d_h, 1]660 # [bs, max_where_num, max_q_len]661 att_weights = torch.matmul(662 self.W_att(q_enc).unsqueeze(1),663 h_pooling_where.unsqueeze(3)664 ).squeeze(3)665 att_mask = build_mask(att_weights, q_lens, dim=-1)666 att_weights = self.softmax(att_weights.masked_fill(att_mask == 0, -float("inf")))667668 # att_weights: [bs, max_where_num, max_q_len] -> [bs, max_where_num, max_q_len, 1]669 # q_enc: [bs, max_q_len, d_h] -> [bs, 1, max_q_len, d_h]670 q_context = torch.mul(att_weights.unsqueeze(3), q_enc.unsqueeze(1)).sum(dim=2)671672 # [bs, max_where_num, n_agg]673 score_where_agg = self.W_out(torch.cat([self.W_q(q_context), self.W_h(h_pooling_where)], dim=2))674 return score_where_agg675676677class WhereValue(nn.Module):678 def __init__(self, d_in, d_in_ch, d_h, d_f, n_layers, dropout_prob, n_op, pooling_type, max_where_num):679 super(WhereValue, self).__init__()680681 self.d_in = d_in682 self.d_h = d_h683 self.d_in_ch = d_in_ch684 self.n_layers = n_layers685 
self.dropout_prob = dropout_prob686 self.n_op = n_op687 self.pooling_type = pooling_type688 self.max_where_num = max_where_num689690 self.d_f = d_f691 self.q_feature_embed = nn.Embedding(2, self.d_f)692693 self.q_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,694 dropout=dropout_prob, birnn=True)695 self.h_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,696 dropout=dropout_prob, birnn=True)697698 self.W_att = nn.Linear(d_h + self.d_f, d_h)699 self.W_q = nn.Linear(d_h + self.d_f, d_h)700 self.W_h = nn.Linear(d_h, d_h)701702 self.q_encoder_ch = LSTM(d_input=d_in_ch, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,703 dropout=dropout_prob, birnn=True)704 self.h_encoder_ch = LSTM(d_input=d_in_ch, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,705 dropout=dropout_prob, birnn=True)706707 self.W_att_ch = nn.Linear(d_h, d_h)708 self.W_q_ch = nn.Linear(d_h, d_h)709 self.W_h_ch = nn.Linear(d_h, d_h)710711 self.W_op = nn.Linear(n_op, d_h)712713 self.W_out = nn.Sequential(714 nn.Linear(6 * d_h + self.d_f, d_h),715 nn.Tanh(),716 nn.Linear(d_h, 2)717 )718719 self.softmax = nn.Softmax(dim=-1)720721 def forward(self, q_emb, q_lens, h_emb, h_lens, h_nums,722 q_emb_ch, q_lens_ch, h_emb_ch, h_lens_ch, where_cols, where_ops,723 q_feature):724 bs = len(q_emb)725 max_q_len = max(q_lens)726727 # [bs, max_q_len, d_h]728 q_enc = encode_question(self.q_encoder, q_emb, q_lens)729 for b, f in enumerate(q_feature):730 while len(f) < max_q_len:731 q_feature[b].append(0)732 q_feature = torch.tensor(q_feature)733 if q_enc.is_cuda:734 q_feature = q_feature.to(q_enc.device)735736 q_feature_enc = self.q_feature_embed(q_feature)737738 q_enc = torch.cat([q_enc, q_feature_enc], -1)739740 # [bs, max_h_num, d_h]741 h_pooling = encode_header(self.h_encoder, h_emb, h_lens, h_nums, pooling_type=self.pooling_type)742743 padding_t = torch.zeros_like(h_pooling[0][0]).unsqueeze(0)744 h_pooling_where = []745 for b, cols in 
enumerate(where_cols):746 if len(cols) > 0:747 h_tmp = [h_pooling[b][cols, :]]748 else:749 h_tmp = []750 h_tmp += [padding_t] * (self.max_where_num - len(cols))751 h_tmp = torch.cat(h_tmp, dim=0)752 h_pooling_where.append(h_tmp)753 # [bs, max_where_num, d_h]754 h_pooling_where = torch.stack(h_pooling_where)755756 # q_enc: [bs, max_q_len, d_h] -> [bs, 1, max_q_len, d_h]757 # h_pooling_where: [bs, max_where_num, d_h] -> [bs, max_where_num, d_h, 1]758 # [bs, max_where_num, max_q_len]759 att_weights = torch.matmul(760 self.W_att(q_enc).unsqueeze(1),761 h_pooling_where.unsqueeze(3)762 ).squeeze(3)763 att_mask = build_mask(att_weights, q_lens, dim=-1)764 att_weights = self.softmax(att_weights.masked_fill(att_mask == 0, -float("inf")))765766 # att_weights: [bs, max_where_num, max_q_len] -> [bs, max_where_num, max_q_len, 1]767 # q_enc: [bs, max_q_len, d_h] -> [bs, 1, max_q_len, d_h]768 # [bs, max_where_num, d_h]769 q_context = torch.mul(att_weights.unsqueeze(3), q_enc.unsqueeze(1)).sum(dim=2)770771 q_enc_ch = encode_question(self.q_encoder_ch, q_emb_ch, q_lens_ch)772773 # [bs, max_h_num, d_h]774 h_pooling_ch = encode_header(self.h_encoder_ch, h_emb_ch, h_lens_ch,775 h_nums, pooling_type=self.pooling_type)776777 padding_t_ch = torch.zeros_like(h_pooling_ch[0][0]).unsqueeze(0)778779 h_pooling_where_ch = []780 for b, cols in enumerate(where_cols):781 if len(cols) > 0:782 h_tmp = [h_pooling_ch[b][cols, :]]783 else:784 h_tmp = []785 h_tmp += [padding_t_ch] * (self.max_where_num - len(cols))786 h_tmp = torch.cat(h_tmp, dim=0)787 h_pooling_where_ch.append(h_tmp)788 h_pooling_where_ch = torch.stack(h_pooling_where_ch)789790 att_weights_ch = torch.matmul(791 self.W_att_ch(q_enc_ch).unsqueeze(1),792 h_pooling_where_ch.unsqueeze(3)793 ).squeeze(3)794 att_mask_ch = build_mask(att_weights_ch, q_lens_ch, dim=-1)795 att_weights_ch = self.softmax(att_weights_ch.masked_fill(att_mask_ch == 0, -float("inf")))796 q_context_ch = torch.mul(att_weights_ch.unsqueeze(3), 
q_enc_ch.unsqueeze(1)).sum(dim=2)797798 op_enc = []799 for b in range(bs):800 op_enc_tmp = torch.zeros(self.max_where_num, self.n_op)801 op = where_ops[b]802 idx_scatter = []803 op_len = len(op)804 for i in range(self.max_where_num):805 if i < op_len:806 idx_scatter.append([op[i]])807 else:808 idx_scatter.append([0])809 op_enc_tmp = op_enc_tmp.scatter(1, torch.tensor(idx_scatter), 1)810 op_enc.append(op_enc_tmp)811 op_enc = torch.stack(op_enc)812 if q_context.is_cuda:813 op_enc = op_enc.to(q_context.device)814815 comb_context = torch.cat(816 [self.W_q(q_context),817 self.W_h(h_pooling_where),818 self.W_q_ch(q_context_ch),819 self.W_h_ch(h_pooling_where_ch),820 self.W_op(op_enc)],821 dim=2822 )823 comb_context = comb_context.unsqueeze(2).expand(-1, -1, q_enc.size(1), -1)824 q_enc = q_enc.unsqueeze(1).expand(-1, comb_context.size(1), -1, -1)825826 # [bs, max_where_num, max_q_num, 2]827 score_where_val = self.W_out(torch.cat([comb_context, q_enc], dim=3))828829 for b, l in enumerate(q_lens):830 if l < max_q_len:831 score_where_val[b, :, l:, :] = -float("inf")832 return score_where_val833834835class OrderByColumn(nn.Module):836 def __init__(self, d_in, d_h, n_layers, dropout_prob, pooling_type):837 super(OrderByColumn, self).__init__()838 self.d_in = d_in839 self.d_h = d_h840 self.n_layers = n_layers841 self.dropout_prob = dropout_prob842 self.pooling_type = pooling_type843844 self.q_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2),845 n_layers=n_layers, batch_first=True846 , dropout=dropout_prob, birnn=True)847 self.h_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2),848 n_layers=n_layers, batch_first=True,849 dropout=dropout_prob, birnn=True)850851 self.W_att = nn.Linear(d_h, d_h)852 self.W_question = nn.Linear(d_h, d_h)853 self.W_header = nn.Linear(d_h, d_h)854 self.W_out = nn.Sequential(nn.Tanh(), nn.Linear(2 * d_h, 1))855 self.softmax = nn.Softmax(dim=2)856857 def forward(self, q_emb, q_lens, h_emb, h_lens, h_nums, mask=True):858859 # [bs, max_q_len, d_h]860 q_enc = 
encode_question(self.q_encoder, q_emb, q_lens)861 # [bs, max_h_num, d_h]862 h_pooling = encode_header(self.h_encoder,863 h_emb, h_lens, h_nums, pooling_type=self.pooling_type)864865 # [bs, max_h_num, max_q_len]866 # torch.bmm: bs * ([max_header_len, d_h], [d_h, max_q_len])867 att_weights = torch.bmm(h_pooling, self.W_att(q_enc).transpose(1, 2))868 att_mask = build_mask(att_weights, h_nums, dim=-1)869 att_weights = att_weights.masked_fill(att_mask == 0, -float('inf'))870 att_weights = self.softmax(att_weights)871872 # attention_weights: -> [bs, max_h_num, max_q_len, 1]873 # q_enc: -> [bs, 1, max_q_len, d_h]874 # [bs, max_h_num, d_h]875 q_context = torch.mul(att_weights.unsqueeze(3), q_enc.unsqueeze(1)).sum(dim=2)876 comb_context = torch.cat([self.W_question(q_context), self.W_header(h_pooling)], dim=-1)877878 score_ord_col = self.W_out(comb_context).squeeze(2)879880 if mask:881 for b, h_num in enumerate(h_nums):882 score_ord_col[b, h_num:] = -float('inf')883 return score_ord_col884885886class OrderByOrder(nn.Module):887 def __init__(self, d_in=300, d_h=100, n_layers=2, dropout_prob=0.3):888 super(OrderByOrder, self).__init__()889 self.d_in = d_in890 self.d_h = d_h891 self.n_layers = n_layers892 self.dropout_prob = dropout_prob893894 self.max_order = 3895896 self.q_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers,897 batch_first=True, dropout=dropout_prob, birnn=True)898899 self.W_att_question = nn.Linear(d_h, d_h)900 self.W_att = nn.Linear(d_h, 1)901902 self.W_out = nn.Sequential(nn.Linear(d_h, d_h), nn.Tanh(),903 nn.Linear(d_h, self.max_order))904 self.softmax = nn.Softmax(dim=1)905906 def forward(self, q_emb, q_lens):907 # [bs, max_q_len, d_h]908 q_enc = encode_question(self.q_encoder, q_emb, q_lens)909910 # [bs, max_q_len, 1]911 att_weights = self.W_att(q_enc)912 # print(att_weights.shape)913 att_mask = build_mask(att_weights, q_lens, dim=-2)914 att_weights = att_weights.masked_fill(att_mask == 0, -float('inf'))915 # [bs, max_q_len, 1]916 
att_weights = self.softmax(att_weights)917 # print(att_weights.shape)918919 # [bs, d_h]920 q_context = torch.bmm(q_enc.transpose(1, 2), att_weights).squeeze(2)921922 score_order = self.W_out(q_context)923 return score_order924925926class OrderByLimit(nn.Module):927 def __init__(self, d_in, d_h, n_layers, dropout_prob, n_limit):928 super(OrderByLimit, self).__init__()929 self.d_in = d_in930 self.d_h = d_h931 self.n_layers = n_layers932 self.dropout_prob = dropout_prob933934 # limit = 0 -> no limit limit <- [1, 9]935 self.n_limit = n_limit936937 self.q_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers,938 batch_first=True, dropout=dropout_prob, birnn=True)939940 self.W_att_question = nn.Linear(d_h, d_h)941 self.W_att = nn.Linear(d_h, 1)942943 self.W_out = nn.Sequential(nn.Linear(d_h, d_h), nn.Tanh(),944 nn.Linear(d_h, self.n_limit))945 self.softmax = nn.Softmax(dim=1)946947 def forward(self, q_emb, q_lens):948 # [bs, max_q_len, d_h]949 q_enc = encode_question(self.q_encoder, q_emb, q_lens)950951 # [bs, max_q_len, 1]952 att_weights = self.W_att(q_enc)953 att_mask = build_mask(att_weights, q_lens, dim=-2)954 att_weights = att_weights.masked_fill(att_mask == 0, -float('inf'))955 # [bs, max_q_len, 1]956 att_weights = self.softmax(att_weights)957958 # [bs, d_h]959 q_context = torch.bmm(q_enc.transpose(1, 2), att_weights).squeeze(2)960 score_limit = self.W_out(q_context) ...

Full Screen

Full Screen

GRUsim.py

Source: GRUsim.py (GitHub)

copy

Full Screen

1import torch as th2import torch.nn as nn3import numpy as np4npa = np.asarray5def tonumpy(ls):6 return [t.data.numpy() for t in ls]7class Erf(nn.Module):8 def forward(self, x):9 return x.erf()10class ErfSigmoid(nn.Module):11 def forward(self, x):12 return (1 + x.erf())/213def simgru(inpseq, d_h, time=None, varUz=1, varUr=1, varUh=1,14 varWz=1, varWr=1, varWh=1,15 varbz=1, varbr=1, varbh=1,16 mubz=0, mubr=0, mubh=0,17 bias=True, nonlin=None, sigmoid=None, wt_tie=True,18 # get_h=False, get_update=False, get_reset=False, get_htilde=False,19 h_init=0):20 r'''Simulate a GRU on a sequence and obtain data21 A GRU evolves according to the equations22 23 \tilde z^t = W_z x^t + U_z h^{t-1} + b_z24 z^t = sigmoid(\tilde z^t)25 \tilde r^t = W_r x^t + U_r h^{t-1} + b_r26 r^t = sigmoid(\tilde r^t)27 \tilde h^t = W_h x^t + U_h(h^{t-1} \odot r^t) + b_h28 h^t = (1 - z^t) \odot h^{t-1} + z^t \odot nonlin(\tilde h^t)29 30 where31 32 h^t is state at time t33 x^t is input at time t34 z^t is ``update gate``: 1 means do update/forget previous h^{t-1}35 0 means h^t = h^{t-1}36 r^t is ``reset gate'':37 the smaller, the easier to make proposed update not depend on h^{t-1}38 W_z, W_r, W_h are weights converting input to hidden states39 U_z, U_r, U_h are weights converting state to state40 b_z, b_r, b_h are biases41 This function simulates a randomly initialized GRU42 on the sequence `inpseq` and returns a record of43 several data.44 Inputs:45 inpseq: a matrix `seqlen x inputdim` where each row46 is a token47 d_h: dimension of state h^t48 time: run simulation for t up to `time`.49 If is None, then set of `seqlen`.50 varUz: each element of U_z has variance `varUz`/d51 varUr: each element of U_r has variance `varUr`/d52 varUh: each element of U_h has variance `varUh`/d53 varWz: each element of W_z has variance `varWz`/d54 varWr: each element of W_r has variance `varWr`/d55 varWh: each element of W_h has variance `varWh`/d56 varbz: each element of b_z has variance `varbz`57 varbr: each element of 
b_r has variance `varbr`58 varbh: each element of b_h has variance `varbh`59 mubz: each element of b_z has mean `mubz`60 mubr: each element of b_r has mean `mubr`61 mubh: each element of b_h has mean `mubh`62 bias: whether to turn on bias63 (if False, then `varb_`s and `mub_`s have no effect)64 nonlin: nonlinearity; has to be pytorch module or65 has the same ducktype66 sigmoid: sigmoid function; has to pytorch module67 or has the same ducktype68 wt_tie: whether to tie the weights69 h_init: the magnitude of the initial state (which70 is randomly initialized)71 Outputs:72 a dictionary with the following keys73 h: `time+1 x d_h` matrix containing all h^t74 th: `time x d_h` matrix containing all \tilde h^t75 tz: `time x d_h` matrix containing all \tilde z^t76 tr: `time x d_h` matrix containing all \tilde r^t77 hcov: `time+1 x time+1` matrix of 2nd moments of h^t78 thcov: `time x time` matrix of 2nd moments of \tilde h^t79 tzcov: `time x time` matrix of 2nd moments of \tilde z^t80 trcov: `time x time` matrix of 2nd moments of \tilde r^t81 '''82 d_i = inpseq.shape[1]83 if time is None:84 time = inpseq.shape[0]85 updates = []86 resets = []87 htildes = []88 def makelayer(nonlin=nonlin, sigmoid=sigmoid):89 W_update = nn.Linear(d_i, d_h, bias=False)90 update = nn.Linear(d_h, d_h, bias=bias)91 W_reset = nn.Linear(d_i, d_h, bias=False)92 reset = nn.Linear(d_h, d_h, bias=bias)93 W_htilde = nn.Linear(d_i, d_h, bias=False)94 htilde = nn.Linear(d_h, d_h, bias=bias)95 W_update.weight.data.normal_(0, np.sqrt(varWz/d_i))96 update.weight.data.normal_(0, np.sqrt(varUz/d_h))97 W_reset.weight.data.normal_(0, np.sqrt(varWr/d_i))98 reset.weight.data.normal_(0, np.sqrt(varUr/d_h))99 W_htilde.weight.data.normal_(0, np.sqrt(varWh/d_i))100 htilde.weight.data.normal_(0, np.sqrt(varUh/d_h))101 if bias:102 update.bias.data.normal_(mubz, np.sqrt(varbz))103 reset.bias.data.normal_(mubr, np.sqrt(varbr))104 htilde.bias.data.normal_(mubh, np.sqrt(varbh))105 if nonlin is None:106 nonlin = lambda: 
lambda x: x107 if sigmoid is None:108 sigmoid = nn.Sigmoid109 def r(h, inp):110 _u = update(h) + W_update(inp)111 # if get_update:112 updates.append(_u)113 u = sigmoid()(_u)114 _r = reset(h) + W_reset(inp)115 # if get_reset:116 resets.append(_r)117 r = sigmoid()(_r)118 _h = htilde(h * r) + W_htilde(inp)119 # if get_htilde:120 htildes.append(_h)121 ht = nonlin()(_h)122 return (1 - u) * h + u * ht123 return r124 if wt_tie:125 mylayer = makelayer()126# if h_init is None:127# x = 0128# else:129 x = th.randn(1, d_h) * h_init130 xs = [x]131 for i in range(time):132 if not wt_tie:133 mylayer = makelayer()134 xx = mylayer(xs[-1], inpseq[i:i+1])135 xs.append(xx)136 ret = {}137 # if get_h:138 ret['h'] = npa(tonumpy(xs)).squeeze()139 # if get_update:140 ret['tz'] = npa(tonumpy(updates)).squeeze()141 # if get_reset:142 ret['tr'] = npa(tonumpy(resets)).squeeze()143 # if get_htilde:144 ret['th'] = npa(tonumpy(htildes)).squeeze()145# print(ret['h'].shape)146 ret['tzcov'] = ret['tz'] @ ret['tz'].T / d_h147 ret['trcov'] = ret['tr'] @ ret['tr'].T / d_h148 ret['thcov'] = ret['th'] @ ret['th'].T / d_h149 ret['hcov'] = ret['h'] @ ret['h'].T / d_h150 # ret['hnorms'] = [th.mean(u**2).data.item() for u in xs]151 return ret152def simgru2(inpseq1, inpseq2, d_h, varUz=1, varUr=1, varUh=1,153 varWz=1, varWr=1, varWh=1,154 varbz=1, varbr=1, varbh=1,155 mubz=0, mubr=0, mubh=0,156 bias=True, nonlin=None, sigmoid=None, wt_tie=True,157 # get_h=False, get_update=False, get_reset=False, get_htilde=False,158 h_init=0):159 d_i = inpseq1.shape[1]160 updates = []161 resets = []162 htildes = []163 def makelayer(nonlin=nonlin, sigmoid=sigmoid):164 W_update = nn.Linear(d_i, d_h, bias=False)165 update = nn.Linear(d_h, d_h, bias=bias)166 W_reset = nn.Linear(d_i, d_h, bias=False)167 reset = nn.Linear(d_h, d_h, bias=bias)168 W_htilde = nn.Linear(d_i, d_h, bias=False)169 htilde = nn.Linear(d_h, d_h, bias=bias)170 W_update.weight.data.normal_(0, np.sqrt(varWz/d_i))171 update.weight.data.normal_(0, 
np.sqrt(varUz/d_h))172 W_reset.weight.data.normal_(0, np.sqrt(varWr/d_i))173 reset.weight.data.normal_(0, np.sqrt(varUr/d_h))174 W_htilde.weight.data.normal_(0, np.sqrt(varWh/d_i))175 htilde.weight.data.normal_(0, np.sqrt(varUh/d_h))176 if bias:177 update.bias.data.normal_(mubz, np.sqrt(varbz))178 reset.bias.data.normal_(mubr, np.sqrt(varbr))179 htilde.bias.data.normal_(mubh, np.sqrt(varbh))180 if nonlin is None:181 nonlin = lambda: lambda x: x182 if sigmoid is None:183 sigmoid = nn.Sigmoid184 def r(h, inp):185 _u = update(h) + W_update(inp)186 # if get_update:187 updates.append(_u)188 u = sigmoid()(_u)189 _r = reset(h) + W_reset(inp)190 # if get_reset:191 resets.append(_r)192 r = sigmoid()(_r)193 _h = htilde(h * r) + W_htilde(inp)194 # if get_htilde:195 htildes.append(_h)196 ht = nonlin()(_h)197 return (1 - u) * h + u * ht198 return r199 if wt_tie:200 mylayer = makelayer()201# if h_init is None:202# x = 0203# else:204 rets = {1: {}, 2: {}}205 for i, seq in enumerate([inpseq1, inpseq2]):206 x = th.randn(1, d_h) * h_init207 xs = [x]208 for tok in seq:209 if not wt_tie:210 mylayer = makelayer()211 xx = mylayer(xs[-1], tok)212 xs.append(xx)213 ret = rets[i+1]214 # if get_h:215 ret['h'] = npa(tonumpy(xs)).squeeze()216 # if get_update:217 ret['tz'] = npa(tonumpy(updates)).squeeze()218 # if get_reset:219 ret['tr'] = npa(tonumpy(resets)).squeeze()220 # if get_htilde:221 ret['th'] = npa(tonumpy(htildes)).squeeze()222 # print(ret['h'].shape)223 ret['tzcov'] = ret['tz'] @ ret['tz'].T / d_h224 ret['trcov'] = ret['tr'] @ ret['tr'].T / d_h225 ret['thcov'] = ret['th'] @ ret['th'].T / d_h226 ret['hcov'] = ret['h'] @ ret['h'].T / d_h227 ret = rets['x'] = {}228 ret['tzcov'] = rets[1]['tz'] @ rets[2]['tz'].T / d_h229 ret['trcov'] = rets[1]['tr'] @ rets[2]['tr'].T / d_h230 ret['thcov'] = rets[1]['th'] @ rets[2]['th'].T / d_h231 ret['hcov'] = rets[1]['h'] @ rets[2]['h'].T / d_h232 # ret['hnorms'] = [th.mean(u**2).data.item() for u in xs]233 234 rets['hcov'] = np.block(235 
[[rets[1]['hcov'][1:, 1:], rets['x']['hcov'][1:, 1:]],236 [rets['x']['hcov'][1:, 1:].T, rets[2]['hcov'][1:, 1:]]]237 )...

Full Screen

Full Screen

UFlow.py

Source: UFlow.py (GitHub)

copy

Full Screen

1import torch2from ..Step.NormalizingFlow import FCNormalizingFlow3class UFlow(FCNormalizingFlow):4 def __init__(self, enc_steps, dec_steps, z_log_density, dropping_factors):5 super(UFlow, self).__init__(enc_steps + dec_steps, z_log_density)6 self.enc_dropping_factors = dropping_factors7 self.dec_gathering_factors = dropping_factors[::-1]8 self.enc_steps = enc_steps9 self.dec_steps = dec_steps10 def forward(self, x, context=None):11 b_size = x.shape[0]12 jac_tot = 0.13 z_all = []14 for step, drop_factors in zip(self.enc_steps, self.enc_dropping_factors):15 z, jac = step(x.contiguous().view(b_size, -1), context)16 d_c, d_h, d_w = drop_factors17 C, H, W = step.img_sizes18 c, h, w = int(C/d_c), int(H/d_h), int(W/d_w)19 z_reshaped = z.view(-1, C, H, W).unfold(1, d_c, d_c).unfold(2, d_h, d_h) \20 .unfold(3, d_w, d_w).contiguous().view(b_size, c, h, w, -1)21 z_all += [z_reshaped[:, :, :, :, 1:]]22 x = z.view(-1, C, H, W).unfold(1, d_c, d_c).unfold(2, d_h, d_h) \23 .unfold(3, d_w, d_w).contiguous().view(b_size, c, h, w, -1)[:, :, :, :, 0]24 jac_tot += jac25 for step, gath_factors, z in zip(self.dec_steps[1:], self.dec_gathering_factors[1:], z_all[::-1][1:]):26 d_c, d_h, d_w = gath_factors27 C, H, W = step.img_sizes28 c, h, w = int(C / d_c), int(H / d_h), int(W / d_w)29 z = torch.cat((x.unsqueeze(4), z), 4)30 z = z.view(b_size, c, h, w, d_c, d_h, d_w)31 z = z.permute(0, 1, 2, 3, 6, 4, 5).contiguous().view(b_size, c, h, W, d_c, d_h)32 z = z.permute(0, 1, 2, 5, 3, 4).contiguous().view(b_size, c, H, W, d_c)33 x = z.permute(0, 1, 4, 2, 3).contiguous().view(b_size, C, H, W)34 x, jac = step(x.contiguous().view(b_size, -1), context)35 x = x.contiguous().view(b_size, C, H, W)36 jac_tot += jac37 z = x.view(b_size, -1)38 return z, jac_tot39 def invert(self, z, context=None):40 z_prev = z41 b_size = z.shape[0]42 z_all = []43 for step, drop_factors in zip(self.dec_steps[::-1][:-1], self.enc_dropping_factors[:-1]):44 x = step.invert(z.contiguous().view(b_size, -1), context)45 d_c, d_h, 
d_w = drop_factors46 C, H, W = step.img_sizes47 c, h, w = int(C / d_c), int(H / d_h), int(W / d_w)48 z_reshaped = x.view(-1, C, H, W).unfold(1, d_c, d_c).unfold(2, d_h, d_h) \49 .unfold(3, d_w, d_w).contiguous().view(b_size, c, h, w, -1)50 z_all += [z_reshaped[:, :, :, :, 1:]]51 x = x.view(-1, C, H, W).unfold(1, d_c, d_c).unfold(2, d_h, d_h) \52 .unfold(3, d_w, d_w).contiguous().view(b_size, c, h, w, -1)[:, :, :, :, 0]53 z = x54 x = self.enc_steps[-1].invert(z.view(b_size, -1), context).view(z.shape)55 for step, gath_factors, z in zip(self.enc_steps[::-1][1:], self.dec_gathering_factors[1:], z_all[::-1]):56 d_c, d_h, d_w = gath_factors57 C, H, W = step.img_sizes58 c, h, w = int(C / d_c), int(H / d_h), int(W / d_w)59 z = torch.cat((x.unsqueeze(4), z), 4)60 z = z.view(b_size, c, h, w, d_c, d_h, d_w)61 z = z.permute(0, 1, 2, 3, 6, 4, 5).contiguous().view(b_size, c, h, W, d_c, d_h)62 z = z.permute(0, 1, 2, 5, 3, 4).contiguous().view(b_size, c, H, W, d_c)63 x = z.permute(0, 1, 4, 2, 3).contiguous().view(b_size, C, H, W)64 x = step.invert(x.contiguous().view(b_size, -1), context).contiguous().view(b_size, C, H, W)65 return x.view(b_size, -1)66class ImprovedUFlow(FCNormalizingFlow):67 def __init__(self, enc_steps, dec_steps, z_log_density, dropping_factors, conditioner_nets):68 super(ImprovedUFlow, self).__init__(enc_steps + dec_steps, z_log_density)69 self.enc_dropping_factors = dropping_factors70 self.dec_gathering_factors = dropping_factors[::-1]71 self.enc_steps = enc_steps72 self.dec_steps = dec_steps73 self.conditioner_nets = conditioner_nets74 def forward(self, x, context=None):75 b_size = x.shape[0]76 jac_tot = 0.77 z_all = []78 full_context = context79 for i, (step, drop_factors) in enumerate(zip(self.enc_steps, self.enc_dropping_factors)):80 z, jac = step(x.contiguous().view(b_size, -1), full_context)81 d_c, d_h, d_w = drop_factors82 C, H, W = step.img_sizes83 c, h, w = int(C/d_c), int(H/d_h), int(W/d_w)84 z_reshaped = z.view(-1, C, H, W).unfold(1, d_c, 
d_c).unfold(2, d_h, d_h) \85 .unfold(3, d_w, d_w).contiguous().view(b_size, c, h, w, -1)86 z_all += [z_reshaped[:, :, :, :, 1:]]87 x = z.view(-1, C, H, W).unfold(1, d_c, d_c).unfold(2, d_h, d_h) \88 .unfold(3, d_w, d_w).contiguous().view(b_size, c, h, w, -1)[:, :, :, :, 0]89 jac_tot += jac90 if i < len(self.enc_steps) - 1:91 z_for_context = torch.cat((torch.zeros_like(z_reshaped[:, :, :, :, [0]]), z_reshaped[:, :, :, :, 1:]), 4)92 z_for_context = z_for_context.view(b_size, c, h, w, d_c, d_h, d_w)93 z_for_context = z_for_context.permute(0, 1, 2, 3, 6, 4, 5).contiguous().view(b_size, c, h, W, d_c, d_h)94 z_for_context = z_for_context.permute(0, 1, 2, 5, 3, 4).contiguous().view(b_size, c, H, W, d_c)95 z_for_context = z_for_context.permute(0, 1, 4, 2, 3).contiguous().view(b_size, C, H, W)96 local_context = self.conditioner_nets[i](z_for_context)97 full_context = torch.cat((local_context, context), 1) if context is not None else local_context98 for step, gath_factors, z in zip(self.dec_steps[1:], self.dec_gathering_factors[1:], z_all[::-1][1:]):99 d_c, d_h, d_w = gath_factors100 C, H, W = step.img_sizes101 c, h, w = int(C / d_c), int(H / d_h), int(W / d_w)102 z = torch.cat((x.unsqueeze(4), z), 4)103 z = z.view(b_size, c, h, w, d_c, d_h, d_w)104 z = z.permute(0, 1, 2, 3, 6, 4, 5).contiguous().view(b_size, c, h, W, d_c, d_h)105 z = z.permute(0, 1, 2, 5, 3, 4).contiguous().view(b_size, c, H, W, d_c)106 x = z.permute(0, 1, 4, 2, 3).contiguous().view(b_size, C, H, W)107 x, jac = step(x.contiguous().view(b_size, -1), context)108 x = x.contiguous().view(b_size, C, H, W)109 jac_tot += jac110 z = x.view(b_size, -1)111 return z, jac_tot112 def invert(self, z, context=None):113 b_size = z.shape[0]114 z_all = []115 full_contexts = [context]116 for i, (step, drop_factors) in enumerate(zip(self.dec_steps[::-1][:-1], self.enc_dropping_factors[:-1])):117 x = step.invert(z.contiguous().view(b_size, -1), context)118 d_c, d_h, d_w = drop_factors119 C, H, W = step.img_sizes120 c, h, w = 
int(C / d_c), int(H / d_h), int(W / d_w)121 z_reshaped = x.view(-1, C, H, W).unfold(1, d_c, d_c).unfold(2, d_h, d_h) \122 .unfold(3, d_w, d_w).contiguous().view(b_size, c, h, w, -1)123 z_all += [z_reshaped[:, :, :, :, 1:]]124 x = x.view(-1, C, H, W).unfold(1, d_c, d_c).unfold(2, d_h, d_h) \125 .unfold(3, d_w, d_w).contiguous().view(b_size, c, h, w, -1)[:, :, :, :, 0]126 z = x127 z_for_context = torch.cat((torch.zeros_like(z_reshaped[:, :, :, :, [0]]), z_reshaped[:, :, :, :, 1:]), 4)128 z_for_context = z_for_context.view(b_size, c, h, w, d_c, d_h, d_w)129 z_for_context = z_for_context.permute(0, 1, 2, 3, 6, 4, 5).contiguous().view(b_size, c, h, W, d_c, d_h)130 z_for_context = z_for_context.permute(0, 1, 2, 5, 3, 4).contiguous().view(b_size, c, H, W, d_c)131 z_for_context = z_for_context.permute(0, 1, 4, 2, 3).contiguous().view(b_size, C, H, W)132 local_context = self.conditioner_nets[i](z_for_context)133 full_contexts += [torch.cat((local_context, context), 1) if context is not None else local_context]134 x = self.enc_steps[-1].invert(z.view(b_size, -1), full_contexts[-1]).view(z.shape)135 for step, gath_factors, z, full_context in zip(self.enc_steps[::-1][1:], self.dec_gathering_factors[1:], z_all[::-1], full_contexts[::-1][1:]):136 d_c, d_h, d_w = gath_factors137 C, H, W = step.img_sizes138 c, h, w = int(C / d_c), int(H / d_h), int(W / d_w)139 z = torch.cat((x.unsqueeze(4), z), 4)140 z = z.view(b_size, c, h, w, d_c, d_h, d_w)141 z = z.permute(0, 1, 2, 3, 6, 4, 5).contiguous().view(b_size, c, h, W, d_c, d_h)142 z = z.permute(0, 1, 2, 5, 3, 4).contiguous().view(b_size, c, H, W, d_c)143 x = z.permute(0, 1, 4, 2, 3).contiguous().view(b_size, C, H, W)144 x = step.invert(x.contiguous().view(b_size, -1), full_context).contiguous().view(b_size, C, H, W)...

Full Screen

Full Screen

Automation Testing Tutorials

Learn to perform automation testing from scratch with the LambdaTest Learning Hub: from setting up the prerequisites and running your first automation test to following best practices and diving deeper into advanced test scenarios. The LambdaTest Learning Hubs compile step-by-step guides to help you become proficient with different test automation frameworks, e.g. Selenium, Cypress, and TestNG.

LambdaTest Learning Hubs:

YouTube

You can also refer to the video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.

Run fMBT automation tests on LambdaTest cloud grid

Perform automation testing on 3000+ real desktop and mobile devices online.

Try LambdaTest now!

Get 100 minutes of automation testing FREE!

Next-Gen App & Browser Testing Cloud

Was this article helpful?

Helpful

Not Helpful