Best Python code snippet using fMBT_python
modules.py
Source: modules.py
1# -*- coding: utf-8 -*-2# !/usr/bin/python34"""5# @Time    : 2020/8/16# @Author  : Yongrui Chen7# @File    : modules.py8# @Software: PyCharm9"""10import sys11import torch12import torch.nn as nn13sys.path.append("..")14from src.models.nn_utils import encode_question, encode_header, build_mask151617class LSTM(nn.Module):1819    def __init__(self, d_input, d_h, n_layers=1, batch_first=True, birnn=True, dropout=0.3):20        super(LSTM, self).__init__()2122        n_dir = 2 if birnn else 123        self.init_h = nn.Parameter(torch.Tensor(n_layers * n_dir, d_h))24        self.init_c = nn.Parameter(torch.Tensor(n_layers * n_dir, d_h))2526        INI = 1e-227        torch.nn.init.uniform_(self.init_h, -INI, INI)28        torch.nn.init.uniform_(self.init_c, -INI, INI)2930        self.lstm = nn.LSTM(31            input_size=d_input,32            hidden_size=d_h,33            num_layers=n_layers,34            bidirectional=birnn,35            batch_first=not batch_first36        )37        self.dropout = nn.Dropout(dropout)3839    def forward(self, seqs, seq_lens=None, init_states=None):4041        bs = seqs.size(0)42        bf = self.lstm.batch_first4344        if not bf:45            seqs = seqs.transpose(0, 1)4647        seqs = self.dropout(seqs)484950        size = (self.init_h.size(0), bs, self.init_h.size(1))51        if init_states is None:52            init_states = (self.init_h.unsqueeze(1).expand(*size).contiguous(),53                           self.init_c.unsqueeze(1).expand(*size).contiguous())5455        if seq_lens is not None:56            assert bs == len(seq_lens)57            sort_ind = sorted(range(len(seq_lens)), key=lambda i: seq_lens[i], reverse=True)58            seq_lens = [seq_lens[i] for i in sort_ind]59            seqs = self.reorder_sequence(seqs, sort_ind, bf)60            init_states = self.reorder_init_states(init_states, sort_ind)6162            packed_seq = nn.utils.rnn.pack_padded_sequence(seqs, seq_lens)63            packed_out, 
final_states = self.lstm(packed_seq, init_states)64            lstm_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out)6566            back_map = {ind: i for i, ind in enumerate(sort_ind)}67            reorder_ind = [back_map[i] for i in range(len(seq_lens))]68            lstm_out = self.reorder_sequence(lstm_out, reorder_ind, bf)69            final_states = self.reorder_init_states(final_states, reorder_ind)70        else:71            lstm_out, final_states = self.lstm(seqs)72        return lstm_out.transpose(0, 1), final_states7374    def reorder_sequence(self, seqs, order, batch_first=False):75        """76        seqs: [T, B, D] if not batch_first77        order: list of sequence length78        """79        batch_dim = 0 if batch_first else 180        assert len(order) == seqs.size()[batch_dim]81        order = torch.LongTensor(order).to(seqs.device)82        sorted_seqs = seqs.index_select(index=order, dim=batch_dim)83        return sorted_seqs8485    def reorder_init_states(self, states, order):86        """87        lstm_states: (H, C) of tensor [layer, batch, hidden]88        order: list of sequence length89        """90        assert isinstance(states, tuple)91        assert len(states) == 292        assert states[0].size() == states[1].size()93        assert len(order) == states[0].size()[1]9495        order = torch.LongTensor(order).to(states[0].device)96        sorted_states = (states[0].index_select(index=order, dim=1),97                         states[1].index_select(index=order, dim=1))98        return sorted_states99100class SelectNumber(nn.Module):101    def __init__(self, d_in, d_h, n_layers, dropout_prob, pooling_type, max_select_num):102        super(SelectNumber, self).__init__()103104        self.d_in = d_in105        self.d_h = d_h106        self.n_layers = n_layers107        self.dropout_prob = dropout_prob108        self.pooling_type = pooling_type109        self.max_select_num = max_select_num110111        self.q_encoder = 
LSTM(d_input=d_in, d_h=int(d_h/2), n_layers=n_layers, batch_first=True,112                              dropout=dropout_prob, birnn=True)113        self.h_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,114                              dropout=dropout_prob, birnn=True)115116        self.W_att_q = nn.Linear(d_h, 1)117        self.W_att_h = nn.Linear(d_h, 1)118        self.W_hidden = nn.Linear(d_h, d_h * n_layers)119        self.W_cell = nn.Linear(d_h, d_h * n_layers)120        self.W_out = nn.Sequential(121            nn.Linear(d_h, d_h),122            nn.Tanh(),123            nn.Linear(d_h, self.max_select_num + 1)124        )125126        self.softmax = nn.Softmax(dim=-1)127128    def forward(self, q_emb, q_lens, h_emb, h_lens, h_nums):129130        # [bs, max_h_num, d_h]131        h_pooling = encode_header(self.h_encoder, h_emb, h_lens, h_nums, pooling_type=self.pooling_type)132133        bs = len(q_lens)134135        # self-attention for header136        # [bs, max_h_num]137        att_weights_h = self.W_att_h(h_pooling).squeeze(2)138        att_mask_h = build_mask(att_weights_h, q_lens, dim=-2)139        att_weights_h = self.softmax(att_weights_h.masked_fill(att_mask_h == 0, -float("inf")))140141        # [bs, d_h]142        h_context = torch.mul(h_pooling, att_weights_h.unsqueeze(2)).sum(1)143144        # [bs, d_h] -> [bs, 2 * d_h]145        # enlarge because there are two layers.146        hidden = self.W_hidden(h_context)147        hidden = hidden.view(bs, self.n_layers * 2, int(self.d_h / 2))148        hidden = hidden.transpose(0, 1).contiguous()149150        cell = self.W_cell(h_context)151        cell = cell.view(bs, self.n_layers * 2, int(self.d_h / 2))152        cell = cell.transpose(0, 1).contiguous()153154        # [bs, max_q_len, d_h]155        q_enc = encode_question(self.q_encoder, q_emb, q_lens, init_states=(hidden, cell))156157        # self-attention for question158        # [bs, max_q_len]159        
att_weights_q = self.W_att_q(q_enc).squeeze(2)160        att_mask_q = build_mask(att_weights_q, q_lens, dim=-2)161        att_weights_q = self.softmax(att_weights_q.masked_fill(att_mask_q == 0, -float("inf")))162163        q_context = torch.mul(q_enc, att_weights_q.unsqueeze(2).expand_as(q_enc)).sum(dim=1)164165        # [bs, max_select_num + 1]166        score_sel_num = self.W_out(q_context)167        return score_sel_num168169170class SelectColumn(nn.Module):171    def __init__(self, d_in, d_h, n_layers, dropout_prob, pooling_type):172        super(SelectColumn, self).__init__()173174        self.d_in = d_in175        self.d_h = d_h176        self.n_layers = n_layers177        self.dropout_prob = dropout_prob178        self.pooling_type = pooling_type179180        self.q_encoder = LSTM(d_input=d_in, d_h=int(d_h/2), n_layers=n_layers, batch_first=True,181                              dropout=dropout_prob, birnn=True)182        self.h_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,183                              dropout=dropout_prob, birnn=True)184185        self.W_att = nn.Linear(d_h, d_h)186        self.W_q = nn.Linear(d_h, d_h)187        self.W_h = nn.Linear(d_h, d_h)188        self.W_out = nn.Sequential(nn.Tanh(), nn.Linear(2 * d_h, 1))189190        self.softmax = nn.Softmax(dim=-1)191192    def forward(self, q_emb, q_lens, h_emb, h_lens, h_nums):193        # [bs, max_q_len, d_h]194        q_enc = encode_question(self.q_encoder, q_emb, q_lens)195196        # [bs, max_h_num, d_h]197        h_pooling = encode_header(self.h_encoder, h_emb, h_lens, h_nums, pooling_type=self.pooling_type)198199        # [bs, max_h_num, max_q_len]200        att_weights = torch.bmm(h_pooling, self.W_att(q_enc).transpose(1, 2))201        att_mask = build_mask(att_weights, q_lens, dim=-1)202        att_weights = self.softmax(att_weights.masked_fill(att_mask==0, -float("inf")))203204        # att_weights: -> [bs, max_h_num, max_q_len, 1]205        # 
q_enc: -> [bs, 1, max_q_len, d_h]206        # [bs, max_h_num, d_h]207        q_context = torch.mul(att_weights.unsqueeze(3), q_enc.unsqueeze(1)).sum(dim=2)208209        # [bs, max_h_num, d_h * 2]210        comb_context = torch.cat([self.W_q(q_context), self.W_h(h_pooling)], dim=-1)211212        # [bs, max_h_num]213        score_sel_col = self.W_out(comb_context).squeeze(2)214215        # mask216        for b, h_num in enumerate(h_nums):217            score_sel_col[b, h_num:] = -float("inf")218        return score_sel_col219220221class SelectAggregation(nn.Module):222    def __init__(self, d_in, d_h, n_layers, dropout_prob, n_agg, pooling_type, max_sel_num):223        super(SelectAggregation, self).__init__()224225        self.d_in = d_in226        self.d_h = d_h227        self.n_layers = n_layers228        self.dropout_prob = dropout_prob229        self.n_agg = n_agg230        self.pooling_type = pooling_type231        self.max_sel_num = max_sel_num232233        self.q_encoder = LSTM(d_input=d_in, d_h=int(d_h/2), n_layers=n_layers, batch_first=True,234                              dropout=dropout_prob, birnn=True)235        self.h_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,236                              dropout=dropout_prob, birnn=True)237238        self.W_att = nn.Linear(d_h, d_h)239        self.W_out = nn.Sequential(240            nn.Linear(d_h, d_h),241            nn.Tanh(),242            nn.Linear(d_h, self.n_agg)243        )244245        self.softmax = nn.Softmax(dim=-1)246247    def forward(self, q_emb, q_lens, h_emb, h_lens, h_nums, sel_col):248        # [bs, max_q_len, d_h]249        q_enc = encode_question(self.q_encoder, q_emb, q_lens)250251        # [bs, max_h_num, d_h]252        h_pooling = encode_header(self.h_encoder, h_emb, h_lens, h_nums, pooling_type=self.pooling_type)253254        bs = len(q_emb)255        h_pooling_sel = h_pooling[list(range(bs)), sel_col]256257        att_weights = 
torch.bmm(self.W_att(q_enc), h_pooling_sel.unsqueeze(2)).squeeze(2)258        att_mask = build_mask(att_weights, q_lens, dim=-2)259        att_weights = self.softmax(att_weights.masked_fill(att_mask == 0, -float("inf")))260261        # att_weights: [bs, max_sel_num, max_q_len] -> [bs, max_sel_num, max_q_len, 1]262        # q_enc: [bs, max_q_len, d_h] -> [bs, 1, max_q_len, d_h]263        q_context = torch.mul(q_enc, att_weights.unsqueeze(2).expand_as(q_enc)).sum(dim=1)264265        # [bs, max_sel_num, n_agg]266        score_sel_agg = self.W_out(q_context)267        return score_sel_agg268269270class SelectMultipleAggregation(nn.Module):271    def __init__(self, d_in, d_h, n_layers, dropout_prob, n_agg, pooling_type, max_sel_num):272        super(SelectMultipleAggregation, self).__init__()273274        self.d_in = d_in275        self.d_h = d_h276        self.n_layers = n_layers277        self.dropout_prob = dropout_prob278        self.n_agg = n_agg279        self.pooling_type = pooling_type280        self.max_sel_num = max_sel_num281282        self.q_encoder = LSTM(d_input=d_in, d_h=int(d_h/2), n_layers=n_layers, batch_first=True,283                              dropout=dropout_prob, birnn=True)284        self.h_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,285                              dropout=dropout_prob, birnn=True)286287        self.W_att = nn.Linear(d_h, d_h)288        self.W_q = nn.Linear(d_h, d_h)289        self.W_h = nn.Linear(d_h, d_h)290        self.W_out = nn.Sequential(291            nn.Linear(2 * d_h, d_h),292            nn.Tanh(),293            nn.Linear(d_h, self.n_agg)294        )295296        self.softmax = nn.Softmax(dim=-1)297298    def forward(self, q_emb, q_lens, h_emb, h_lens, h_nums, sel_cols):299        # [bs, max_q_len, d_h]300        q_enc = encode_question(self.q_encoder, q_emb, q_lens)301302        # [bs, max_h_num, d_h]303        h_pooling = encode_header(self.h_encoder, h_emb, h_lens, h_nums, 
pooling_type=self.pooling_type)304305        padding_t = torch.zeros_like(h_pooling[0][0]).unsqueeze(0)306307        h_pooling_sel = []308        for b, cols in enumerate(sel_cols):309            if len(cols) > 0:310                h_tmp = [h_pooling[b][cols, :]]311            else:312                h_tmp = []313            h_tmp += [padding_t] * (self.max_sel_num - len(cols))314            h_tmp = torch.cat(h_tmp, dim=0)315            h_pooling_sel.append(h_tmp)316        # [bs, max_sel_num, d_h]317        h_pooling_sel = torch.stack(h_pooling_sel)318319        # q_enc: [bs, max_q_len, d_h] -> [bs, 1, max_q_len, d_h]320        # h_pooling_sel: [bs, max_sel_num, d_h] -> [bs, max_sel_num, d_h, 1]321        # [bs, max_sel_num, max_q_len]322        att_weights = torch.matmul(323            self.W_att(q_enc).unsqueeze(1),324            h_pooling_sel.unsqueeze(3)325        ).squeeze(3)326        att_mask = build_mask(att_weights, q_lens, dim=-1)327        att_weights = self.softmax(att_weights.masked_fill(att_mask == 0, -float("inf")))328329        # att_weights: [bs, max_sel_num, max_q_len] -> [bs, max_sel_num, max_q_len, 1]330        # q_enc: [bs, max_q_len, d_h] -> [bs, 1, max_q_len, d_h]331        q_context = torch.mul(att_weights.unsqueeze(3), q_enc.unsqueeze(1)).sum(dim=2)332333        # [bs, max_sel_num, n_agg]334        score_sel_agg = self.W_out(torch.cat([self.W_q(q_context), self.W_h(h_pooling_sel)], dim=2))335        return score_sel_agg336337338class WhereJoiner(nn.Module):339    def __init__(self, d_in, d_h, n_layers, dropout_prob):340        super(WhereJoiner, self).__init__()341        self.d_in = d_in342        self.d_h = d_h343        self.n_layers = n_layers344        self.dropout_prob = dropout_prob345346        self.q_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,347                              dropout=dropout_prob, birnn=True)348        self.h_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers, 
batch_first=True,349                              dropout=dropout_prob, birnn=True)350351        self.W_att = nn.Linear(d_h, 1)352        self.W_out = nn.Sequential(nn.Linear(d_h, d_h),353                                   nn.Tanh(),354                                   nn.Linear(d_h, 3))355356        self.softmax = nn.Softmax(dim=1)357358    def forward(self, q_emb, q_lens):359360        q_enc = encode_question(self.q_encoder, q_emb, q_lens)361362        #  self-atttention for question363        #  [bs, max_q_len]364        att_weights_q = self.W_att(q_enc).squeeze(2)365        att_mask_q = build_mask(att_weights_q, q_lens, dim=-2)366        att_weights_q = att_weights_q.masked_fill(att_mask_q == 0, -float('inf'))367        att_weights_q = self.softmax(att_weights_q)368369        #  [bs, d_h]370        q_context = torch.mul(371            q_enc,372            att_weights_q.unsqueeze(2).expand_as(q_enc)373        ).sum(dim=1)374        where_op_logits = self.W_out(q_context)375        return where_op_logits376377378class WhereNumber(nn.Module):379    def __init__(self, d_in, d_in_ch, d_h, n_layers, dropout_prob, pooling_type, max_where_num):380        super(WhereNumber, self).__init__()381382        self.d_in = d_in383        self.d_h = d_h384        self.n_layers = n_layers385        self.dropout_prob = dropout_prob386        self.pooling_type = pooling_type387        self.max_where_num = max_where_num388389        self.q_encoder = LSTM(d_input=d_in, d_h=int(d_h/2), n_layers=n_layers, batch_first=True,390                              dropout=dropout_prob, birnn=True)391        self.h_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,392                              dropout=dropout_prob, birnn=True)393394        self.q_encoder_ch = LSTM(d_input=d_in_ch, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,395                              dropout=dropout_prob, birnn=True)396        self.h_encoder_ch = LSTM(d_input=d_in_ch, 
d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,397                              dropout=dropout_prob, birnn=True)398399        self.W_att_q = nn.Linear(d_h, 1)400        self.W_att_h = nn.Linear(d_h, 1)401        self.W_hidden = nn.Linear(d_h, d_h * n_layers)402        self.W_cell = nn.Linear(d_h, d_h * n_layers)403404        self.W_att_q_ch = nn.Linear(d_h, 1)405        self.W_att_h_ch = nn.Linear(d_h, 1)406        self.W_hidden_ch = nn.Linear(d_h, d_h * n_layers)407        self.W_cell_ch = nn.Linear(d_h, d_h * n_layers)408409        self.W_out = nn.Sequential(410            nn.Linear(d_h * 2, d_h),411            nn.Tanh(),412            nn.Linear(d_h, self.max_where_num + 1)413        )414415        self.softmax = nn.Softmax(dim=-1)416417    def get_context(self, q_emb, q_lens, h_emb, h_lens, h_nums,418                    q_encoder, h_encoder, W_att_q, W_att_h, W_hidden, W_cell):419        # [bs, max_h_num, d_h]420        h_pooling = encode_header(h_encoder, h_emb, h_lens, h_nums, pooling_type=self.pooling_type)421422        bs = len(q_lens)423424        # self-attention for header425        # [bs, max_h_num]426        att_weights_h = W_att_h(h_pooling).squeeze(2)427        att_mask_h = build_mask(att_weights_h, q_lens, dim=-2)428        att_weights_h = self.softmax(att_weights_h.masked_fill(att_mask_h == 0, -float("inf")))429430        # [bs, d_h]431        h_context = torch.mul(h_pooling, att_weights_h.unsqueeze(2)).sum(1)432433        # [bs, d_h] -> [bs, 2 * d_h]434        # enlarge because there are two layers.435        hidden = W_hidden(h_context)436        hidden = hidden.view(bs, self.n_layers * 2, int(self.d_h / 2))437        hidden = hidden.transpose(0, 1).contiguous()438439        cell = W_cell(h_context)440        cell = cell.view(bs, self.n_layers * 2, int(self.d_h / 2))441        cell = cell.transpose(0, 1).contiguous()442443        # [bs, max_q_len, d_h]444        q_enc = encode_question(q_encoder, q_emb, q_lens, init_states=(hidden, 
cell))445446        # self-attention for question447        # [bs, max_q_len]448        att_weights_q = W_att_q(q_enc).squeeze(2)449        att_mask_q = build_mask(att_weights_q, q_lens, dim=-2)450        att_weights_q = self.softmax(att_weights_q.masked_fill(att_mask_q == 0, -float("inf")))451452        q_context = torch.mul(q_enc, att_weights_q.unsqueeze(2).expand_as(q_enc)).sum(dim=1)453        return q_context454455    def forward(self, q_emb, q_lens, h_emb, h_lens, h_nums,456                    q_emb_ch, q_lens_ch, h_emb_ch, h_lens_ch):457458        q_context = self.get_context(q_emb, q_lens, h_emb, h_lens, h_nums,459                                     self.q_encoder, self.h_encoder, self.W_att_q,460                                     self.W_att_h, self.W_hidden, self.W_cell)461        q_context_ch = self.get_context(q_emb_ch, q_lens_ch, h_emb_ch, h_lens_ch, h_nums,462                                     self.q_encoder_ch, self.h_encoder_ch, self.W_att_q_ch,463                                     self.W_att_h_ch, self.W_hidden_ch, self.W_cell_ch)464        # [bs, max_where_num + 1]465        score_sel_num = self.W_out(torch.cat([q_context, q_context_ch], dim=-1))466        return score_sel_num467468469class WhereColumn(nn.Module):470    def __init__(self, d_in, d_in_ch, d_h, n_layers, dropout_prob, pooling_type):471        super(WhereColumn, self).__init__()472473        self.d_in = d_in474        self.d_h = d_h475        self.n_layers = n_layers476        self.dropout_prob = dropout_prob477        self.pooling_type = pooling_type478479        self.q_encoder = LSTM(d_input=d_in, d_h=int(d_h/2), n_layers=n_layers, batch_first=True,480                              dropout=dropout_prob, birnn=True)481        self.h_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,482                              dropout=dropout_prob, birnn=True)483484        self.W_att = nn.Linear(d_h, d_h)485        self.W_q = nn.Linear(d_h, d_h)486        
self.W_h = nn.Linear(d_h, d_h)487488        self.q_encoder_ch = LSTM(d_input=d_in_ch, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,489                              dropout=dropout_prob, birnn=True)490        self.h_encoder_ch = LSTM(d_input=d_in_ch, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,491                              dropout=dropout_prob, birnn=True)492493        self.W_att_ch = nn.Linear(d_h, d_h)494        self.W_q_ch = nn.Linear(d_h, d_h)495        self.W_h_ch = nn.Linear(d_h, d_h)496497        self.W_out = nn.Sequential(nn.Tanh(), nn.Linear(4 * d_h, 1))498499        self.softmax = nn.Softmax(dim=-1)500501    def get_context(self, q_emb, q_lens, h_emb, h_lens, h_nums,502                    q_encoder, h_encoder, W_att, W_q, W_h):503        # [bs, max_q_len, d_h]504        q_enc = encode_question(q_encoder, q_emb, q_lens)505506        # [bs, max_h_num, d_h]507        h_pooling = encode_header(h_encoder, h_emb, h_lens, h_nums, pooling_type=self.pooling_type)508509        # [bs, max_h_num, max_q_len]510        att_weights = torch.bmm(h_pooling, W_att(q_enc).transpose(1, 2))511        att_mask = build_mask(att_weights, q_lens, dim=-1)512        att_weights = self.softmax(att_weights.masked_fill(att_mask == 0, -float("inf")))513514        # att_weights: -> [bs, max_h_num, max_q_len, 1]515        # q_enc: -> [bs, 1, max_q_len, d_h]516        # [bs, max_h_num, d_h]517        q_context = torch.mul(att_weights.unsqueeze(3), q_enc.unsqueeze(1)).sum(dim=2)518519        return W_q(q_context), W_h(h_pooling)520521    def forward(self, q_emb, q_lens, h_emb, h_lens, h_nums,522                q_emb_ch, q_lens_ch, h_emb_ch, h_lens_ch):523        q_context, h_pooling = self.get_context(q_emb, q_lens, h_emb, h_lens, h_nums,524                                                self.q_encoder, self.h_encoder, self.W_att,525                                                self.W_q, self.W_h)526527        q_context_ch, h_pooling_ch = self.get_context(q_emb_ch, 
q_lens_ch, h_emb_ch, h_lens_ch, h_nums,528                                                    self.q_encoder_ch, self.h_encoder_ch, self.W_att_ch,529                                                    self.W_q_ch, self.W_h_ch)530531        # [bs, max_h_num, d_h * 2]532        comb_context = torch.cat([q_context, q_context_ch, h_pooling, h_pooling_ch], dim=-1)533534        # [bs, max_h_num]535        score_where_col = self.W_out(comb_context).squeeze(2)536        for b, h_num in enumerate(h_nums):537            score_where_col[b, h_num:] = -float("inf")538        return score_where_col539540541class WhereOperator(nn.Module):542    def __init__(self, d_in, d_h, n_layers, dropout_prob, n_op, pooling_type, max_where_num):543        super(WhereOperator, self).__init__()544545        self.d_in = d_in546        self.d_h = d_h547        self.n_layers = n_layers548        self.dropout_prob = dropout_prob549        self.n_op = n_op550        self.pooling_type = pooling_type551        self.max_where_num = max_where_num552553        self.q_encoder = LSTM(d_input=d_in, d_h=int(d_h/2), n_layers=n_layers, batch_first=True,554                              dropout=dropout_prob, birnn=True)555        self.h_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,556                              dropout=dropout_prob, birnn=True)557558        self.W_att = nn.Linear(d_h, d_h)559        self.W_q = nn.Linear(d_h, d_h)560        self.W_h = nn.Linear(d_h, d_h)561        self.W_out = nn.Sequential(562            nn.Linear(2 * d_h, d_h),563            nn.Tanh(),564            nn.Linear(d_h, self.n_op)565        )566567        self.softmax = nn.Softmax(dim=-1)568569    def forward(self, q_emb, q_lens, h_emb, h_lens, h_nums, where_cols):570        # [bs, max_q_len, d_h]571        q_enc = encode_question(self.q_encoder, q_emb, q_lens)572573        # [bs, max_h_num, d_h]574        h_pooling = encode_header(self.h_encoder, h_emb, h_lens, h_nums, 
pooling_type=self.pooling_type)575576        padding_t = torch.zeros_like(h_pooling[0][0]).unsqueeze(0)577578        h_pooling_where = []579        for b, cols in enumerate(where_cols):580            if len(cols) > 0:581                h_tmp = [h_pooling[b][cols, :]]582            else:583                h_tmp = []584            h_tmp += [padding_t] * (self.max_where_num - len(cols))585            h_tmp = torch.cat(h_tmp, dim=0)586            h_pooling_where.append(h_tmp)587        # [bs, max_where_num, d_h]588        h_pooling_where = torch.stack(h_pooling_where)589590        # q_enc: [bs, max_q_len, d_h] -> [bs, 1, max_q_len, d_h]591        # h_pooling_where: [bs, max_where_num, d_h] -> [bs, max_where_num, d_h, 1]592        # [bs, max_where_num, max_q_len]593        att_weights = torch.matmul(594            self.W_att(q_enc).unsqueeze(1),595            h_pooling_where.unsqueeze(3)596        ).squeeze(3)597        att_mask = build_mask(att_weights, q_lens, dim=-1)598        att_weights = self.softmax(att_weights.masked_fill(att_mask == 0, -float("inf")))599600        # att_weights: [bs, max_where_num, max_q_len] -> [bs, max_where_num, max_q_len, 1]601        # q_enc: [bs, max_q_len, d_h] -> [bs, 1, max_q_len, d_h]602        q_context = torch.mul(att_weights.unsqueeze(3), q_enc.unsqueeze(1)).sum(dim=2)603604        # [bs, max_where_num, n_agg]605        score_where_op = self.W_out(torch.cat([self.W_q(q_context), self.W_h(h_pooling_where)], dim=2))606        return score_where_op607608609class WhereAggregation(nn.Module):610    def __init__(self, d_in, d_h, n_layers, dropout_prob, n_agg, pooling_type, max_where_num):611        super(WhereAggregation, self).__init__()612613        self.d_in = d_in614        self.d_h = d_h615        self.n_layers = n_layers616        self.dropout_prob = dropout_prob617        self.n_agg = n_agg618        self.pooling_type = pooling_type619        self.max_where_num = max_where_num620621        self.q_encoder = LSTM(d_input=d_in, 
d_h=int(d_h/2), n_layers=n_layers, batch_first=True,622                              dropout=dropout_prob, birnn=True)623        self.h_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,624                              dropout=dropout_prob, birnn=True)625626        self.W_att = nn.Linear(d_h, d_h)627        self.W_q = nn.Linear(d_h, d_h)628        self.W_h = nn.Linear(d_h, d_h)629        self.W_out = nn.Sequential(630            nn.Linear(2 * d_h, d_h),631            nn.Tanh(),632            nn.Linear(d_h, self.n_agg)633        )634635        self.softmax = nn.Softmax(dim=-1)636637    def forward(self, q_emb, q_lens, h_emb, h_lens, h_nums, where_cols):638        # [bs, max_q_len, d_h]639        q_enc = encode_question(self.q_encoder, q_emb, q_lens)640641        # [bs, max_h_num, d_h]642        h_pooling = encode_header(self.h_encoder, h_emb, h_lens, h_nums, pooling_type=self.pooling_type)643644        padding_t = torch.zeros_like(h_pooling[0][0]).unsqueeze(0)645646        h_pooling_where = []647        for b, cols in enumerate(where_cols):648            if len(cols) > 0:649                h_tmp = [h_pooling[b][cols, :]]650            else:651                h_tmp = []652            h_tmp += [padding_t] * (self.max_where_num - len(cols))653            h_tmp = torch.cat(h_tmp, dim=0)654            h_pooling_where.append(h_tmp)655        # [bs, max_where_num, d_h]656        h_pooling_where = torch.stack(h_pooling_where)657658        # q_enc: [bs, max_q_len, d_h] -> [bs, 1, max_q_len, d_h]659        # h_pooling_where: [bs, max_where_num, d_h] -> [bs, max_where_num, d_h, 1]660        # [bs, max_where_num, max_q_len]661        att_weights = torch.matmul(662            self.W_att(q_enc).unsqueeze(1),663            h_pooling_where.unsqueeze(3)664        ).squeeze(3)665        att_mask = build_mask(att_weights, q_lens, dim=-1)666        att_weights = self.softmax(att_weights.masked_fill(att_mask == 0, -float("inf")))667668        # 
att_weights: [bs, max_where_num, max_q_len] -> [bs, max_where_num, max_q_len, 1]669        # q_enc: [bs, max_q_len, d_h] -> [bs, 1, max_q_len, d_h]670        q_context = torch.mul(att_weights.unsqueeze(3), q_enc.unsqueeze(1)).sum(dim=2)671672        # [bs, max_where_num, n_agg]673        score_where_agg = self.W_out(torch.cat([self.W_q(q_context), self.W_h(h_pooling_where)], dim=2))674        return score_where_agg675676677class WhereValue(nn.Module):678    def __init__(self, d_in, d_in_ch, d_h, d_f, n_layers, dropout_prob, n_op, pooling_type, max_where_num):679        super(WhereValue, self).__init__()680681        self.d_in = d_in682        self.d_h = d_h683        self.d_in_ch = d_in_ch684        self.n_layers = n_layers685        self.dropout_prob = dropout_prob686        self.n_op = n_op687        self.pooling_type = pooling_type688        self.max_where_num = max_where_num689690        self.d_f = d_f691        self.q_feature_embed = nn.Embedding(2, self.d_f)692693        self.q_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,694                              dropout=dropout_prob, birnn=True)695        self.h_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,696                              dropout=dropout_prob, birnn=True)697698        self.W_att = nn.Linear(d_h + self.d_f, d_h)699        self.W_q = nn.Linear(d_h + self.d_f, d_h)700        self.W_h = nn.Linear(d_h, d_h)701702        self.q_encoder_ch = LSTM(d_input=d_in_ch, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,703                              dropout=dropout_prob, birnn=True)704        self.h_encoder_ch = LSTM(d_input=d_in_ch, d_h=int(d_h / 2), n_layers=n_layers, batch_first=True,705                              dropout=dropout_prob, birnn=True)706707        self.W_att_ch = nn.Linear(d_h, d_h)708        self.W_q_ch = nn.Linear(d_h, d_h)709        self.W_h_ch = nn.Linear(d_h, d_h)710711        self.W_op = nn.Linear(n_op, 
d_h)712713        self.W_out = nn.Sequential(714            nn.Linear(6 * d_h + self.d_f, d_h),715            nn.Tanh(),716            nn.Linear(d_h, 2)717        )718719        self.softmax = nn.Softmax(dim=-1)720721    def forward(self, q_emb, q_lens, h_emb, h_lens, h_nums,722                q_emb_ch, q_lens_ch, h_emb_ch, h_lens_ch, where_cols, where_ops,723                q_feature):724        bs = len(q_emb)725        max_q_len = max(q_lens)726727        # [bs, max_q_len, d_h]728        q_enc = encode_question(self.q_encoder, q_emb, q_lens)729        for b, f in enumerate(q_feature):730            while len(f) < max_q_len:731                q_feature[b].append(0)732        q_feature = torch.tensor(q_feature)733        if q_enc.is_cuda:734            q_feature = q_feature.to(q_enc.device)735736        q_feature_enc = self.q_feature_embed(q_feature)737738        q_enc = torch.cat([q_enc, q_feature_enc], -1)739740        # [bs, max_h_num, d_h]741        h_pooling = encode_header(self.h_encoder, h_emb, h_lens, h_nums, pooling_type=self.pooling_type)742743        padding_t = torch.zeros_like(h_pooling[0][0]).unsqueeze(0)744        h_pooling_where = []745        for b, cols in enumerate(where_cols):746            if len(cols) > 0:747                h_tmp = [h_pooling[b][cols, :]]748            else:749                h_tmp = []750            h_tmp += [padding_t] * (self.max_where_num - len(cols))751            h_tmp = torch.cat(h_tmp, dim=0)752            h_pooling_where.append(h_tmp)753        # [bs, max_where_num, d_h]754        h_pooling_where = torch.stack(h_pooling_where)755756        # q_enc: [bs, max_q_len, d_h] -> [bs, 1, max_q_len, d_h]757        # h_pooling_where: [bs, max_where_num, d_h] -> [bs, max_where_num, d_h, 1]758        # [bs, max_where_num, max_q_len]759        att_weights = torch.matmul(760            self.W_att(q_enc).unsqueeze(1),761            h_pooling_where.unsqueeze(3)762        ).squeeze(3)763        att_mask = build_mask(att_weights, 
q_lens, dim=-1)764        att_weights = self.softmax(att_weights.masked_fill(att_mask == 0, -float("inf")))765766        # att_weights: [bs, max_where_num, max_q_len] -> [bs, max_where_num, max_q_len, 1]767        # q_enc: [bs, max_q_len, d_h] -> [bs, 1, max_q_len, d_h]768        # [bs, max_where_num, d_h]769        q_context = torch.mul(att_weights.unsqueeze(3), q_enc.unsqueeze(1)).sum(dim=2)770771        q_enc_ch = encode_question(self.q_encoder_ch, q_emb_ch, q_lens_ch)772773        # [bs, max_h_num, d_h]774        h_pooling_ch = encode_header(self.h_encoder_ch, h_emb_ch, h_lens_ch,775                                     h_nums, pooling_type=self.pooling_type)776777        padding_t_ch = torch.zeros_like(h_pooling_ch[0][0]).unsqueeze(0)778779        h_pooling_where_ch = []780        for b, cols in enumerate(where_cols):781            if len(cols) > 0:782                h_tmp = [h_pooling_ch[b][cols, :]]783            else:784                h_tmp = []785            h_tmp += [padding_t_ch] * (self.max_where_num - len(cols))786            h_tmp = torch.cat(h_tmp, dim=0)787            h_pooling_where_ch.append(h_tmp)788        h_pooling_where_ch = torch.stack(h_pooling_where_ch)789790        att_weights_ch = torch.matmul(791            self.W_att_ch(q_enc_ch).unsqueeze(1),792            h_pooling_where_ch.unsqueeze(3)793        ).squeeze(3)794        att_mask_ch = build_mask(att_weights_ch, q_lens_ch, dim=-1)795        att_weights_ch = self.softmax(att_weights_ch.masked_fill(att_mask_ch == 0, -float("inf")))796        q_context_ch = torch.mul(att_weights_ch.unsqueeze(3), q_enc_ch.unsqueeze(1)).sum(dim=2)797798        op_enc = []799        for b in range(bs):800            op_enc_tmp = torch.zeros(self.max_where_num, self.n_op)801            op = where_ops[b]802            idx_scatter = []803            op_len = len(op)804            for i in range(self.max_where_num):805                if i < op_len:806                    idx_scatter.append([op[i]])807               
 else:808                    idx_scatter.append([0])809            op_enc_tmp = op_enc_tmp.scatter(1, torch.tensor(idx_scatter), 1)810            op_enc.append(op_enc_tmp)811        op_enc = torch.stack(op_enc)812        if q_context.is_cuda:813            op_enc = op_enc.to(q_context.device)814815        comb_context = torch.cat(816            [self.W_q(q_context),817             self.W_h(h_pooling_where),818             self.W_q_ch(q_context_ch),819             self.W_h_ch(h_pooling_where_ch),820             self.W_op(op_enc)],821            dim=2822        )823        comb_context = comb_context.unsqueeze(2).expand(-1, -1, q_enc.size(1), -1)824        q_enc = q_enc.unsqueeze(1).expand(-1, comb_context.size(1), -1, -1)825826        # [bs, max_where_num, max_q_num, 2]827        score_where_val = self.W_out(torch.cat([comb_context, q_enc], dim=3))828829        for b, l in enumerate(q_lens):830            if l < max_q_len:831                score_where_val[b, :, l:, :] = -float("inf")832        return score_where_val833834835class OrderByColumn(nn.Module):836    def __init__(self, d_in, d_h, n_layers, dropout_prob, pooling_type):837        super(OrderByColumn, self).__init__()838        self.d_in = d_in839        self.d_h = d_h840        self.n_layers = n_layers841        self.dropout_prob = dropout_prob842        self.pooling_type = pooling_type843844        self.q_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2),845                              n_layers=n_layers, batch_first=True846                              , dropout=dropout_prob, birnn=True)847        self.h_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2),848                              n_layers=n_layers, batch_first=True,849                              dropout=dropout_prob, birnn=True)850851        self.W_att = nn.Linear(d_h, d_h)852        self.W_question = nn.Linear(d_h, d_h)853        self.W_header = nn.Linear(d_h, d_h)854        self.W_out = nn.Sequential(nn.Tanh(), nn.Linear(2 * d_h, 1))855        
self.softmax = nn.Softmax(dim=2)856857    def forward(self, q_emb, q_lens, h_emb, h_lens, h_nums, mask=True):858859        # [bs, max_q_len, d_h]860        q_enc = encode_question(self.q_encoder, q_emb, q_lens)861        # [bs, max_h_num, d_h]862        h_pooling = encode_header(self.h_encoder,863                                  h_emb, h_lens, h_nums, pooling_type=self.pooling_type)864865        # [bs, max_h_num, max_q_len]866        # torch.bmm: bs * ([max_header_len, d_h], [d_h, max_q_len])867        att_weights = torch.bmm(h_pooling, self.W_att(q_enc).transpose(1, 2))868        att_mask = build_mask(att_weights, h_nums, dim=-1)869        att_weights = att_weights.masked_fill(att_mask == 0, -float('inf'))870        att_weights = self.softmax(att_weights)871872        # attention_weights: -> [bs, max_h_num, max_q_len, 1]873        # q_enc: -> [bs, 1, max_q_len, d_h]874        # [bs, max_h_num, d_h]875        q_context = torch.mul(att_weights.unsqueeze(3), q_enc.unsqueeze(1)).sum(dim=2)876        comb_context = torch.cat([self.W_question(q_context), self.W_header(h_pooling)], dim=-1)877878        score_ord_col = self.W_out(comb_context).squeeze(2)879880        if mask:881            for b, h_num in enumerate(h_nums):882                score_ord_col[b, h_num:] = -float('inf')883        return score_ord_col884885886class OrderByOrder(nn.Module):887    def __init__(self, d_in=300, d_h=100, n_layers=2, dropout_prob=0.3):888        super(OrderByOrder, self).__init__()889        self.d_in = d_in890        self.d_h = d_h891        self.n_layers = n_layers892        self.dropout_prob = dropout_prob893894        self.max_order = 3895896        self.q_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers,897                              batch_first=True, dropout=dropout_prob, birnn=True)898899        self.W_att_question = nn.Linear(d_h, d_h)900        self.W_att = nn.Linear(d_h, 1)901902        self.W_out = nn.Sequential(nn.Linear(d_h, d_h), nn.Tanh(),903         
                          nn.Linear(d_h, self.max_order))904        self.softmax = nn.Softmax(dim=1)905906    def forward(self, q_emb, q_lens):907        # [bs, max_q_len, d_h]908        q_enc = encode_question(self.q_encoder, q_emb, q_lens)909910        # [bs, max_q_len, 1]911        att_weights = self.W_att(q_enc)912        # print(att_weights.shape)913        att_mask = build_mask(att_weights, q_lens, dim=-2)914        att_weights = att_weights.masked_fill(att_mask == 0, -float('inf'))915        # [bs, max_q_len, 1]916        att_weights = self.softmax(att_weights)917        # print(att_weights.shape)918919        # [bs, d_h]920        q_context = torch.bmm(q_enc.transpose(1, 2), att_weights).squeeze(2)921922        score_order = self.W_out(q_context)923        return score_order924925926class OrderByLimit(nn.Module):927    def __init__(self, d_in, d_h, n_layers, dropout_prob, n_limit):928        super(OrderByLimit, self).__init__()929        self.d_in = d_in930        self.d_h = d_h931        self.n_layers = n_layers932        self.dropout_prob = dropout_prob933934        # limit = 0 -> no limit        limit <- [1, 9]935        self.n_limit = n_limit936937        self.q_encoder = LSTM(d_input=d_in, d_h=int(d_h / 2), n_layers=n_layers,938                              batch_first=True, dropout=dropout_prob, birnn=True)939940        self.W_att_question = nn.Linear(d_h, d_h)941        self.W_att = nn.Linear(d_h, 1)942943        self.W_out = nn.Sequential(nn.Linear(d_h, d_h), nn.Tanh(),944                                   nn.Linear(d_h, self.n_limit))945        self.softmax = nn.Softmax(dim=1)946947    def forward(self, q_emb, q_lens):948        # [bs, max_q_len, d_h]949        q_enc = encode_question(self.q_encoder, q_emb, q_lens)950951        # [bs, max_q_len, 1]952        att_weights = self.W_att(q_enc)953        att_mask = build_mask(att_weights, q_lens, dim=-2)954        att_weights = att_weights.masked_fill(att_mask == 0, -float('inf'))955        # [bs, 
max_q_len, 1]956        att_weights = self.softmax(att_weights)957958        # [bs, d_h]959        q_context = torch.bmm(q_enc.transpose(1, 2), att_weights).squeeze(2)960        score_limit = self.W_out(q_context)
...GRUsim.py
Source:GRUsim.py  
import torch as th
import torch.nn as nn
import numpy as np

npa = np.asarray


def tonumpy(ls):
    """Return the numpy arrays behind a list of torch tensors.

    ``.data`` is used so tensors that require grad convert without error.
    ``.numpy()`` only accepts CPU tensors.
    """
    return [t.data.numpy() for t in ls]


class Erf(nn.Module):
    """Elementwise error function; can be passed as ``nonlin`` to simgru."""

    def forward(self, x):
        return x.erf()


class ErfSigmoid(nn.Module):
    """Sigmoid-shaped squashing ``(1 + erf(x)) / 2`` with values in (0, 1)."""

    def forward(self, x):
        return (1 + x.erf()) / 2


def _stack_records(tensors, d_h):
    """Stack per-step ``(1, d_h)`` tensors into a ``(steps, d_h)`` numpy matrix.

    ``.squeeze()`` would collapse the result to 1-D (or 0-D) when there is a
    single step or ``d_h == 1``. The ``M @ M.T`` second-moment products
    would then become scalars instead of matrices. An explicit reshape keeps
    the documented 2-D shape, including the empty ``(0, d_h)`` case.
    """
    return npa(tonumpy(tensors)).reshape(-1, d_h)


def _make_gru_step(d_i, d_h, *, bias, nonlin, sigmoid,
                   var_U, var_W, var_b, mu_b, records):
    """Build one randomly initialised GRU transition ``step(h, x) -> h'``.

    ``var_U``, ``var_W``, ``var_b`` and ``mu_b`` are (z, r, h) triples of the
    hyper-parameters documented on :func:`simgru`. ``records`` is an
    ``(updates, resets, htildes)`` triple of lists. Every call of the
    returned step appends its pre-activations \\tilde z, \\tilde r and
    \\tilde h to them.
    """
    varUz, varUr, varUh = var_U
    varWz, varWr, varWh = var_W
    varbz, varbr, varbh = var_b
    mubz, mubr, mubh = mu_b
    updates, resets, htildes = records

    # Layer creation order is kept fixed so the RNG stream is reproducible.
    W_update = nn.Linear(d_i, d_h, bias=False)
    update = nn.Linear(d_h, d_h, bias=bias)
    W_reset = nn.Linear(d_i, d_h, bias=False)
    reset = nn.Linear(d_h, d_h, bias=bias)
    W_htilde = nn.Linear(d_i, d_h, bias=False)
    htilde = nn.Linear(d_h, d_h, bias=bias)
    # normal_ takes a standard deviation, so the variances are var/d_i and var/d_h.
    W_update.weight.data.normal_(0, np.sqrt(varWz / d_i))
    update.weight.data.normal_(0, np.sqrt(varUz / d_h))
    W_reset.weight.data.normal_(0, np.sqrt(varWr / d_i))
    reset.weight.data.normal_(0, np.sqrt(varUr / d_h))
    W_htilde.weight.data.normal_(0, np.sqrt(varWh / d_i))
    htilde.weight.data.normal_(0, np.sqrt(varUh / d_h))
    if bias:
        update.bias.data.normal_(mubz, np.sqrt(varbz))
        reset.bias.data.normal_(mubr, np.sqrt(varbr))
        htilde.bias.data.normal_(mubh, np.sqrt(varbh))
    if nonlin is None:
        nonlin = lambda: lambda x: x  # factory for the identity map
    if sigmoid is None:
        sigmoid = nn.Sigmoid

    def step(h, inp):
        pre_z = update(h) + W_update(inp)
        updates.append(pre_z)
        z = sigmoid()(pre_z)
        pre_r = reset(h) + W_reset(inp)
        resets.append(pre_r)
        r = sigmoid()(pre_r)
        pre_h = htilde(h * r) + W_htilde(inp)
        htildes.append(pre_h)
        return (1 - z) * h + z * nonlin()(pre_h)

    return step


def simgru(inpseq, d_h, time=None, varUz=1, varUr=1, varUh=1,
           varWz=1, varWr=1, varWh=1,
           varbz=1, varbr=1, varbh=1,
           mubz=0, mubr=0, mubh=0,
           bias=True, nonlin=None, sigmoid=None, wt_tie=True,
           h_init=0):
    r'''Simulate a randomly initialised GRU on a sequence and record its states.

    A GRU evolves according to the equations

        \tilde z^t = W_z x^t + U_z h^{t-1} + b_z
        z^t = sigmoid(\tilde z^t)
        \tilde r^t = W_r x^t + U_r h^{t-1} + b_r
        r^t = sigmoid(\tilde r^t)
        \tilde h^t = W_h x^t + U_h(h^{t-1} \odot r^t) + b_h
        h^t = (1 - z^t) \odot h^{t-1} + z^t \odot nonlin(\tilde h^t)

    where

        h^t is the state at time t
        x^t is the input at time t
        z^t is the ``update gate``: 1 means overwrite h^{t-1},
                                    0 means h^t = h^{t-1}
        r^t is the ``reset gate``: the smaller it is, the less the
            proposed update \tilde h^t depends on h^{t-1}
        W_z, W_r, W_h are the input-to-state weights
        U_z, U_r, U_h are the state-to-state weights
        b_z, b_r, b_h are the biases

    Inputs:
        inpseq: a 2-D tensor `seqlen x inputdim`; each row is a token
            (row i is fed as ``inpseq[i:i+1]``)
        d_h: dimension of the state h^t
        time: number of steps to simulate; if None, `seqlen` is used
        varUz, varUr, varUh: each element of U_z / U_r / U_h is drawn
            from N(0, var/d_h)
        varWz, varWr, varWh: each element of W_z / W_r / W_h is drawn
            from N(0, var/inputdim)
        varbz, varbr, varbh: variance of each element of b_z / b_r / b_h
        mubz, mubr, mubh: mean of each element of b_z / b_r / b_h
        bias: whether to use biases
            (if False, the `varb_` and `mub_` arguments have no effect)
        nonlin: zero-argument factory returning the nonlinearity, e.g. a
            module class such as ``nn.Tanh`` or ``Erf``; None means identity
        sigmoid: zero-argument factory returning the gate squashing
            function; None means ``nn.Sigmoid``
        wt_tie: if True, one set of random weights is shared by all steps;
            if False, fresh random weights are drawn at every step
        h_init: scale of the random initial state, h^0 = h_init * N(0, I)

    Outputs:
        a dictionary of numpy arrays with the following keys
        h: `time+1 x d_h` matrix containing all h^t (including h^0)
        th: `time x d_h` matrix containing all \tilde h^t
        tz: `time x d_h` matrix containing all \tilde z^t
        tr: `time x d_h` matrix containing all \tilde r^t
        hcov: `time+1 x time+1` matrix of 2nd moments of h^t
        thcov: `time x time` matrix of 2nd moments of \tilde h^t
        tzcov: `time x time` matrix of 2nd moments of \tilde z^t
        trcov: `time x time` matrix of 2nd moments of \tilde r^t
    '''
    d_i = inpseq.shape[1]
    if time is None:
        time = inpseq.shape[0]
    updates, resets, htildes = [], [], []

    def makelayer():
        return _make_gru_step(
            d_i, d_h, bias=bias, nonlin=nonlin, sigmoid=sigmoid,
            var_U=(varUz, varUr, varUh), var_W=(varWz, varWr, varWh),
            var_b=(varbz, varbr, varbh), mu_b=(mubz, mubr, mubh),
            records=(updates, resets, htildes))

    if wt_tie:
        mylayer = makelayer()
    xs = [th.randn(1, d_h) * h_init]
    for i in range(time):
        if not wt_tie:
            mylayer = makelayer()
        xs.append(mylayer(xs[-1], inpseq[i:i + 1]))

    ret = {
        'h': _stack_records(xs, d_h),
        'tz': _stack_records(updates, d_h),
        'tr': _stack_records(resets, d_h),
        'th': _stack_records(htildes, d_h),
    }
    for key in ('tz', 'tr', 'th', 'h'):
        ret[key + 'cov'] = ret[key] @ ret[key].T / d_h
    return ret


def simgru2(inpseq1, inpseq2, d_h, varUz=1, varUr=1, varUh=1,
            varWz=1, varWr=1, varWh=1,
            varbz=1, varbr=1, varbh=1,
            mubz=0, mubr=0, mubh=0,
            bias=True, nonlin=None, sigmoid=None, wt_tie=True,
            h_init=0):
    """Run one random GRU over two sequences and collect joint 2nd moments.

    The GRU and all hyper-parameters are as in :func:`simgru`. With
    ``wt_tie`` the same weights are used for both sequences. Each sequence
    starts from its own random initial state. Sequences are iterated token
    by token; presumably each token is a 1-D tensor of length inputdim
    (TODO confirm against callers).

    Populates ``rets`` as follows:
        rets[1], rets[2]: per-sequence records and second moments, with the
            same keys as simgru's result
        rets['x']: cross second moments between sequence 1 and sequence 2
        rets['hcov']: joint block matrix of h second moments, excluding the
            initial states h^0
    """
    d_i = inpseq1.shape[1]
    updates, resets, htildes = [], [], []

    def makelayer():
        return _make_gru_step(
            d_i, d_h, bias=bias, nonlin=nonlin, sigmoid=sigmoid,
            var_U=(varUz, varUr, varUh), var_W=(varWz, varWr, varWh),
            var_b=(varbz, varbr, varbh), mu_b=(mubz, mubr, mubh),
            records=(updates, resets, htildes))

    if wt_tie:
        mylayer = makelayer()
    rets = {1: {}, 2: {}}
    for i, seq in enumerate([inpseq1, inpseq2]):
        # The gate records must be per-sequence, like 'h'. Without clearing,
        # sequence 2's tz/tr/th would also contain sequence 1's records.
        updates.clear()
        resets.clear()
        htildes.clear()
        xs = [th.randn(1, d_h) * h_init]
        for tok in seq:
            if not wt_tie:
                mylayer = makelayer()
            xs.append(mylayer(xs[-1], tok))
        ret = rets[i + 1]
        ret['h'] = _stack_records(xs, d_h)
        ret['tz'] = _stack_records(updates, d_h)
        ret['tr'] = _stack_records(resets, d_h)
        ret['th'] = _stack_records(htildes, d_h)
        for key in ('tz', 'tr', 'th', 'h'):
            ret[key + 'cov'] = ret[key] @ ret[key].T / d_h
    ret = rets['x'] = {}
    for key in ('tz', 'tr', 'th', 'h'):
        ret[key + 'cov'] = rets[1][key] @ rets[2][key].T / d_h

    # [1:, 1:] drops the rows/columns belonging to the initial states h^0.
    rets['hcov'] = np.block(
        [[rets[1]['hcov'][1:, 1:], rets['x']['hcov'][1:, 1:]],
         [rets['x']['hcov'][1:, 1:].T, rets[2]['hcov'][1:, 1:]]]
    )
    # NOTE(review): the source listing truncates simgru2 at this point ("...");
    # the rest of the function, including what it returns, is not visible here.
Source:UFlow.py  
1import torch2from ..Step.NormalizingFlow import FCNormalizingFlow3class UFlow(FCNormalizingFlow):4    def __init__(self, enc_steps, dec_steps, z_log_density, dropping_factors):5        super(UFlow, self).__init__(enc_steps + dec_steps, z_log_density)6        self.enc_dropping_factors = dropping_factors7        self.dec_gathering_factors = dropping_factors[::-1]8        self.enc_steps = enc_steps9        self.dec_steps = dec_steps10    def forward(self, x, context=None):11        b_size = x.shape[0]12        jac_tot = 0.13        z_all = []14        for step, drop_factors in zip(self.enc_steps, self.enc_dropping_factors):15            z, jac = step(x.contiguous().view(b_size, -1), context)16            d_c, d_h, d_w = drop_factors17            C, H, W = step.img_sizes18            c, h, w = int(C/d_c), int(H/d_h), int(W/d_w)19            z_reshaped = z.view(-1, C, H, W).unfold(1, d_c, d_c).unfold(2, d_h, d_h) \20                    .unfold(3, d_w, d_w).contiguous().view(b_size, c, h, w, -1)21            z_all += [z_reshaped[:, :, :, :, 1:]]22            x = z.view(-1, C, H, W).unfold(1, d_c, d_c).unfold(2, d_h, d_h) \23                    .unfold(3, d_w, d_w).contiguous().view(b_size, c, h, w, -1)[:, :, :, :, 0]24            jac_tot += jac25        for step, gath_factors, z in zip(self.dec_steps[1:], self.dec_gathering_factors[1:], z_all[::-1][1:]):26            d_c, d_h, d_w = gath_factors27            C, H, W = step.img_sizes28            c, h, w = int(C / d_c), int(H / d_h), int(W / d_w)29            z = torch.cat((x.unsqueeze(4), z), 4)30            z = z.view(b_size, c, h, w, d_c, d_h, d_w)31            z = z.permute(0, 1, 2, 3, 6, 4, 5).contiguous().view(b_size, c, h, W, d_c, d_h)32            z = z.permute(0, 1, 2, 5, 3, 4).contiguous().view(b_size, c, H, W, d_c)33            x = z.permute(0, 1, 4, 2, 3).contiguous().view(b_size, C, H, W)34            x, jac = step(x.contiguous().view(b_size, -1), context)35            x = x.contiguous().view(b_size, C, H, 
W)36            jac_tot += jac37        z = x.view(b_size, -1)38        return z, jac_tot39    def invert(self, z, context=None):40        z_prev = z41        b_size = z.shape[0]42        z_all = []43        for step, drop_factors in zip(self.dec_steps[::-1][:-1], self.enc_dropping_factors[:-1]):44            x = step.invert(z.contiguous().view(b_size, -1), context)45            d_c, d_h, d_w = drop_factors46            C, H, W = step.img_sizes47            c, h, w = int(C / d_c), int(H / d_h), int(W / d_w)48            z_reshaped = x.view(-1, C, H, W).unfold(1, d_c, d_c).unfold(2, d_h, d_h) \49                .unfold(3, d_w, d_w).contiguous().view(b_size, c, h, w, -1)50            z_all += [z_reshaped[:, :, :, :, 1:]]51            x = x.view(-1, C, H, W).unfold(1, d_c, d_c).unfold(2, d_h, d_h) \52                    .unfold(3, d_w, d_w).contiguous().view(b_size, c, h, w, -1)[:, :, :, :, 0]53            z = x54        x = self.enc_steps[-1].invert(z.view(b_size, -1), context).view(z.shape)55        for step, gath_factors, z in zip(self.enc_steps[::-1][1:], self.dec_gathering_factors[1:], z_all[::-1]):56            d_c, d_h, d_w = gath_factors57            C, H, W = step.img_sizes58            c, h, w = int(C / d_c), int(H / d_h), int(W / d_w)59            z = torch.cat((x.unsqueeze(4), z), 4)60            z = z.view(b_size, c, h, w, d_c, d_h, d_w)61            z = z.permute(0, 1, 2, 3, 6, 4, 5).contiguous().view(b_size, c, h, W, d_c, d_h)62            z = z.permute(0, 1, 2, 5, 3, 4).contiguous().view(b_size, c, H, W, d_c)63            x = z.permute(0, 1, 4, 2, 3).contiguous().view(b_size, C, H, W)64            x = step.invert(x.contiguous().view(b_size, -1), context).contiguous().view(b_size, C, H, W)65        return x.view(b_size, -1)66class ImprovedUFlow(FCNormalizingFlow):67    def __init__(self, enc_steps, dec_steps, z_log_density, dropping_factors, conditioner_nets):68        super(ImprovedUFlow, self).__init__(enc_steps + dec_steps, z_log_density)69        
self.enc_dropping_factors = dropping_factors70        self.dec_gathering_factors = dropping_factors[::-1]71        self.enc_steps = enc_steps72        self.dec_steps = dec_steps73        self.conditioner_nets = conditioner_nets74    def forward(self, x, context=None):75        b_size = x.shape[0]76        jac_tot = 0.77        z_all = []78        full_context = context79        for i, (step, drop_factors) in enumerate(zip(self.enc_steps, self.enc_dropping_factors)):80            z, jac = step(x.contiguous().view(b_size, -1), full_context)81            d_c, d_h, d_w = drop_factors82            C, H, W = step.img_sizes83            c, h, w = int(C/d_c), int(H/d_h), int(W/d_w)84            z_reshaped = z.view(-1, C, H, W).unfold(1, d_c, d_c).unfold(2, d_h, d_h) \85                    .unfold(3, d_w, d_w).contiguous().view(b_size, c, h, w, -1)86            z_all += [z_reshaped[:, :, :, :, 1:]]87            x = z.view(-1, C, H, W).unfold(1, d_c, d_c).unfold(2, d_h, d_h) \88                    .unfold(3, d_w, d_w).contiguous().view(b_size, c, h, w, -1)[:, :, :, :, 0]89            jac_tot += jac90            if i < len(self.enc_steps) - 1:91                z_for_context = torch.cat((torch.zeros_like(z_reshaped[:, :, :, :, [0]]), z_reshaped[:, :, :, :, 1:]), 4)92                z_for_context = z_for_context.view(b_size, c, h, w, d_c, d_h, d_w)93                z_for_context = z_for_context.permute(0, 1, 2, 3, 6, 4, 5).contiguous().view(b_size, c, h, W, d_c, d_h)94                z_for_context = z_for_context.permute(0, 1, 2, 5, 3, 4).contiguous().view(b_size, c, H, W, d_c)95                z_for_context = z_for_context.permute(0, 1, 4, 2, 3).contiguous().view(b_size, C, H, W)96                local_context = self.conditioner_nets[i](z_for_context)97                full_context = torch.cat((local_context, context), 1) if context is not None else local_context98        for step, gath_factors, z in zip(self.dec_steps[1:], self.dec_gathering_factors[1:], z_all[::-1][1:]):99    
        d_c, d_h, d_w = gath_factors100            C, H, W = step.img_sizes101            c, h, w = int(C / d_c), int(H / d_h), int(W / d_w)102            z = torch.cat((x.unsqueeze(4), z), 4)103            z = z.view(b_size, c, h, w, d_c, d_h, d_w)104            z = z.permute(0, 1, 2, 3, 6, 4, 5).contiguous().view(b_size, c, h, W, d_c, d_h)105            z = z.permute(0, 1, 2, 5, 3, 4).contiguous().view(b_size, c, H, W, d_c)106            x = z.permute(0, 1, 4, 2, 3).contiguous().view(b_size, C, H, W)107            x, jac = step(x.contiguous().view(b_size, -1), context)108            x = x.contiguous().view(b_size, C, H, W)109            jac_tot += jac110        z = x.view(b_size, -1)111        return z, jac_tot112    def invert(self, z, context=None):113        b_size = z.shape[0]114        z_all = []115        full_contexts = [context]116        for i, (step, drop_factors) in enumerate(zip(self.dec_steps[::-1][:-1], self.enc_dropping_factors[:-1])):117            x = step.invert(z.contiguous().view(b_size, -1), context)118            d_c, d_h, d_w = drop_factors119            C, H, W = step.img_sizes120            c, h, w = int(C / d_c), int(H / d_h), int(W / d_w)121            z_reshaped = x.view(-1, C, H, W).unfold(1, d_c, d_c).unfold(2, d_h, d_h) \122                .unfold(3, d_w, d_w).contiguous().view(b_size, c, h, w, -1)123            z_all += [z_reshaped[:, :, :, :, 1:]]124            x = x.view(-1, C, H, W).unfold(1, d_c, d_c).unfold(2, d_h, d_h) \125                    .unfold(3, d_w, d_w).contiguous().view(b_size, c, h, w, -1)[:, :, :, :, 0]126            z = x127            z_for_context = torch.cat((torch.zeros_like(z_reshaped[:, :, :, :, [0]]), z_reshaped[:, :, :, :, 1:]), 4)128            z_for_context = z_for_context.view(b_size, c, h, w, d_c, d_h, d_w)129            z_for_context = z_for_context.permute(0, 1, 2, 3, 6, 4, 5).contiguous().view(b_size, c, h, W, d_c, d_h)130            z_for_context = z_for_context.permute(0, 1, 2, 5, 3, 
4).contiguous().view(b_size, c, H, W, d_c)131            z_for_context = z_for_context.permute(0, 1, 4, 2, 3).contiguous().view(b_size, C, H, W)132            local_context = self.conditioner_nets[i](z_for_context)133            full_contexts += [torch.cat((local_context, context), 1) if context is not None else local_context]134        x = self.enc_steps[-1].invert(z.view(b_size, -1), full_contexts[-1]).view(z.shape)135        for step, gath_factors, z, full_context in zip(self.enc_steps[::-1][1:], self.dec_gathering_factors[1:], z_all[::-1], full_contexts[::-1][1:]):136            d_c, d_h, d_w = gath_factors137            C, H, W = step.img_sizes138            c, h, w = int(C / d_c), int(H / d_h), int(W / d_w)139            z = torch.cat((x.unsqueeze(4), z), 4)140            z = z.view(b_size, c, h, w, d_c, d_h, d_w)141            z = z.permute(0, 1, 2, 3, 6, 4, 5).contiguous().view(b_size, c, h, W, d_c, d_h)142            z = z.permute(0, 1, 2, 5, 3, 4).contiguous().view(b_size, c, H, W, d_c)143            x = z.permute(0, 1, 4, 2, 3).contiguous().view(b_size, C, H, W)144            x = step.invert(x.contiguous().view(b_size, -1), full_context).contiguous().view(b_size, C, H, W)...Learn to execute automation testing from scratch with LambdaTest Learning Hub. Right from setting up the prerequisites to run your first automation test, to following best practices and diving deeper into advanced test scenarios. LambdaTest Learning Hubs compile a list of step-by-step guides to help you be proficient with different test automation frameworks i.e. Selenium, Cypress, TestNG etc.
You can also refer to the video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.
Get 100 minutes of automation testing FREE!
