Best Python code snippet using sure_python
rbm_layer.py
Source: rbm_layer.py
# Implementation of RBM Layer
# It can be either fully connected or convolutional
import logging
from functools import reduce
from operator import mul

import numpy as np
from neon import NervanaObject
from neon.layers.layer import ParameterLayer, Convolution
from neon.transforms import Logistic

logger = logging.getLogger(__name__)


def _calc_optree(optree, be):
    """
    Materialize an operation tree (or tensor expression) into a new Tensor.

    Arguments:
        optree (OpTreeNode or Tensor): expression to evaluate; must expose ``.shape``
        be (Backend): backend used to allocate the result

    Returns:
        Tensor: newly allocated tensor of shape ``optree.shape`` filled via ``_assign``
    """
    result = be.empty(optree.shape)
    result._assign(optree)
    return result


class RBMLayer(ParameterLayer):
    """
    Fully connected RBM layer.

    Works with volumetric data: a tuple input shape is flattened into
    ``n_visible = prod(in_shape)`` visible units.
    Some methods take an extra ``labels`` argument that is never used; it only
    keeps the API uniform with RBMLayerWithLabels.

    Arguments:
        n_hidden (int): number of hidden units
        init (Optional[Initializer]): Initializer object to use for
            initializing layer weights
        name (Optional[str]): layer name. Defaults to "RBMLayer"
        parallelism (Optional[str]): forwarded to ParameterLayer. Defaults to "Unknown"
        persistant (bool): whether to use persistent CD (PCD)
        kPCD (int): number of Gibbs steps in the negative phase of CD (CD-k)
        use_fast_weights (bool): whether to use the fast-weights CD algorithm.
            Stored but not used by this class.
        sparse_target (float): desired mean activation of the hidden units
        sparse_damping (float): damping factor of the running mean of hidden activations
        sparse_cost (float): weight of the sparsity penalty; 0 disables it
        collect_zero_signal (bool): whether to use units with 0 signal during learning.
            Only True is supported.
        **kwargs: accepted only for multiple inheritance (see ConvolutionalRBMLayer);
            they are not used or forwarded.
    """

    def __init__(self, n_hidden, init=None, name="RBMLayer", parallelism="Unknown",
                 persistant=False, kPCD=1, use_fast_weights=False,
                 sparse_target=0, sparse_damping=0, sparse_cost=0,
                 collect_zero_signal=True, **kwargs):
        super(RBMLayer, self).__init__(init, name, parallelism)
        self.persistant = persistant
        self.kPCD = kPCD
        self.use_fast_weights = use_fast_weights
        self.sparse_target = sparse_target
        self.sparse_damping = sparse_damping
        self.sparse_cost = sparse_cost
        self.collect_zero_signal = collect_zero_signal
        self.b_hidden = None
        self.b_visible = None
        self.sigmoid = Logistic()
        # persistent negative-phase chain (hidden states), created lazily in update()
        self.chain = None
        self.n_hidden = n_hidden
        self.n_visible = None

    def allocate(self, shared_outputs=None):
        """
        Allocate layer buffers via ParameterLayer, then the visible and hidden biases
        (and their gradient buffers) using the shapes computed in configure().
        """
        super(RBMLayer, self).allocate(shared_outputs)
        # self.W must be already initialized by the parent allocate()
        self.init_params(None, self.b_vis_shape, self.b_hid_shape)

    def configure(self, in_obj):
        """
        Set shape-based parameters of this layer given an input tuple, int,
        or input layer.

        Arguments:
            in_obj (int, tuple, Layer or Tensor or dataset): object that provides shape
                                                             information for layer

        Returns:
            RBMLayer: this layer (``self``)
        """
        super(RBMLayer, self).configure(in_obj)
        # TODO: self.in_shape must be an int. Check this
        self.n_visible = self.in_shape
        if isinstance(self.in_shape, tuple):
            # flatten volumetric input into a single visible dimension
            self.n_visible = reduce(mul, self.in_shape)
        self.out_shape = (self.n_hidden,)
        self.weight_shape = (self.n_visible, self.n_hidden)
        # bias for visible units
        self.b_vis_shape = (self.n_visible, 1)
        # bias for hidden units
        self.b_hid_shape = (self.n_hidden, 1)
        return self

    def init_params(self, shape=None, b_vis_shape=None, b_hid_shape=None):
        """
        Allocate layer parameter buffers.

        Arguments:
            shape (int, tuple): shape of the weights; when given, the weights are
                allocated and initialized by ParameterLayer.init_params.
            b_vis_shape (tuple): shape of the visible bias; when given, the bias and
                its gradient are allocated as zeros.
            b_hid_shape (tuple): shape of the hidden bias; when given, the bias and
                its gradient are allocated as zeros.
        """
        # initialize self.W
        if shape is not None:
            super(RBMLayer, self).init_params(shape)
        parallel, distributed = self.get_param_attrs()
        if b_hid_shape is not None:
            self.b_hidden = self.be.zeros(b_hid_shape, parallel=parallel, distributed=distributed)
            self.db_hidden = self.be.zeros_like(self.b_hidden)
        if b_vis_shape is not None:
            self.b_visible = self.be.zeros(b_vis_shape, parallel=parallel, distributed=distributed)
            self.db_visible = self.be.zeros_like(self.b_visible)

    def get_params(self):
        """
        Get layer parameters, gradients, and states for optimization.

        Returns:
            tuple: (((W, dW), (b_hidden, db_hidden), (b_visible, db_visible)), states)
        """
        return (((self.W, self.dW), (self.b_hidden, self.db_hidden),
                 (self.b_visible, self.db_visible)), self.states)

    def get_params_serialize(self, keep_states=True):
        """
        Get layer parameters for serialization.

        Weights and both biases are stored under ``'params'``, together with the
        layer name.

        Arguments:
            keep_states (bool): whether to also include the optimizer states.
                Defaults to True.

        Returns:
            dict: serialization dictionary
        """
        serial_dict = {'params': {'W': self.W.asnumpyarray(),
                                  'b_hidden': self.b_hidden.asnumpyarray(),
                                  'b_visible': self.b_visible.asnumpyarray(),
                                  'name': self.name}}
        if keep_states:
            serial_dict['states'] = [s.asnumpyarray() for s in self.states]
        return serial_dict

    def set_params(self, pdict):
        """
        Set layer parameters (weights), then convert the biases to backend tensors
        and allocate (uninitialized) buffers for their gradients.

        Arguments:
            pdict (dict): dictionary or ndarray with layer parameters
        """
        # load pdict, convert self.W to Tensor
        super(RBMLayer, self).set_params(pdict)
        self.b_hidden = self.be.array(self.b_hidden)
        self.db_hidden = self.be.empty_like(self.b_hidden)
        self.b_visible = self.be.array(self.b_visible)
        self.db_visible = self.be.empty_like(self.b_visible)

    def fprop(self, inputs, inference=False, labels=None, weights=None):
        """
        Forward propagation.

        Returns:
            Tensor: P(h | v) for the given visible units (evaluated)
        """
        return _calc_optree(self.hidden_probability(inputs, weights=weights), self.be)

    def bprop(self, hidden_units, alpha=None, beta=None, weights=None):
        """
        CD1 backward pass (negative phase).

        Returns:
            Tensor: P(v | h) for the given hidden units (evaluated)
        """
        return _calc_optree(self.visible_probability(hidden_units, weights=weights), self.be)

    def _grad_W(self, visible_units, hidden_units):
        """
        Calculate the positive or negative phase term of the weight gradient.

        Arguments:
            visible_units (Tensor): visible units, (n_visible, batch_size)
            hidden_units (Tensor): hidden units, (n_hidden, batch_size)

        Returns:
            OPTree.node: visible_units . hidden_units^T
        """
        return self.be.dot(visible_units, hidden_units.T)

    def _grad_b_hidden(self, h_pos, h_neg):
        """
        Gradient of b_hidden: sum over the batch of (h_pos - h_neg).
        """
        return self.be.sum(h_pos - h_neg, axis=-1)

    def _grad_b_visible(self, v_pos, v_neg):
        """
        Gradient of b_visible: sum over the batch of (v_pos - v_neg).
        """
        return self.be.sum(v_pos - v_neg, axis=-1)

    def update(self, v_pos, labels=None):
        """
        Compute CD-k (or PCD-k) gradients for one minibatch.

        The results are written (negated) into self.dW, self.db_visible and
        self.db_hidden; nothing is returned.

        Arguments:
            v_pos (Tensor): input units (typically given input sample X) of size
                (n_visibles, batch_size)
            labels (Tensor): unused; accepted for API uniformity
        """
        # positive phase
        h_pos = self.hidden_probability(v_pos)
        # negative phase
        if self.persistant:
            if self.chain is None:
                self.chain = self.be.zeros(h_pos.shape)
            chain = self.chain
        else:
            chain = h_pos > self.be.array(self.be.rng.uniform(size=h_pos.shape))
        for _ in range(self.kPCD):
            if self.persistant:
                v_neg = self.sample_visibles(chain)
            else:
                v_neg = self.visible_probability(chain)
            h_neg = self.hidden_probability(v_neg)
            chain = h_neg > self.be.array(self.be.rng.uniform(size=h_neg.shape))
        # calculate chain explicitly
        if self.persistant:
            self.chain._assign(chain)
        if not self.collect_zero_signal:
            # NOTE(review): self.hidden_preacts is not assigned anywhere in the visible
            # code, so this branch fails; collect_zero_signal=False is unsupported.
            zero_signal_mask = self.hidden_preacts.get() == 0
            h_pos[zero_signal_mask] = self.sparse_target
            sparsegrads_b_hidden = self.get_sparse_grads_b_hidden(h_pos)
            h_pos[zero_signal_mask] = 0
            h_neg[zero_signal_mask] = 0
        else:
            sparsegrads_b_hidden = self.get_sparse_grads_b_hidden(h_pos)
        self.dW[:] = -(self._grad_W(v_pos, h_pos) - self._grad_W(v_neg, h_neg))
        self.db_visible[:] = -self._grad_b_visible(v_pos, v_neg)
        self.db_hidden[:] = -(self._grad_b_hidden(h_pos, h_neg) - sparsegrads_b_hidden)

    def hidden_probability(self, inputs, labels=None, weights=None):
        """
        Calculate P(h | v) = sigmoid(W^T v + b_hidden).

        Arguments:
            inputs (Tensor): visible units, (n_visible, batch_size)
            labels: unused; accepted for API uniformity
            weights (Tensor): weights to use instead of self.W

        Returns:
            the (possibly unevaluated) result of self.sigmoid on the pre-activations
        """
        if weights is None:
            weights = self.W
        hidden_preacts = self.be.dot(weights.T, inputs) + self.b_hidden
        hidden_proba = self.sigmoid(hidden_preacts)
        return hidden_proba

    def visible_probability(self, hidden_units, weights=None):
        """
        Calculate P(v | h) = sigmoid(W h + b_visible).

        Arguments:
            hidden_units (Tensor): hidden units, (n_hidden, batch_size)
            weights (Tensor): weights to use instead of self.W
        """
        if weights is None:
            weights = self.W
        visible_preacts = self.be.dot(weights, hidden_units) + self.b_visible
        visible_proba = self.sigmoid(visible_preacts)
        return visible_proba

    def sample_hiddens(self, visible_units, labels=None):
        """
        Sample binary hidden units from P(h | v).
        """
        h_probability = self.hidden_probability(visible_units)
        return h_probability > self.be.array(self.be.rng.uniform(size=h_probability.shape))

    def sample_visibles(self, hidden_units):
        """
        Sample binary visible units from P(v | h).
        """
        v_probability = self.visible_probability(hidden_units)
        return v_probability > self.be.array(self.be.rng.uniform(size=v_probability.shape))

    def get_sparse_grads_b_hidden(self, h_proba):
        """
        Sparsity term of the hidden-bias gradient.

        Maintains a damped running mean of the hidden activations (self.hidmeans,
        initialized to sparse_target) and returns
        sparse_cost * (hidmeans - sparse_target). Returns zeros when sparse_cost == 0.
        """
        if self.sparse_cost == 0:
            return self.be.zeros_like(self.b_hidden)
        if not hasattr(self, 'hidmeans'):
            self.hidmeans = self.be.empty((self.n_hidden, 1))
            self.hidmeans[:] = self.sparse_target * self.be.ones((self.n_hidden, 1))
        hidden_probability_mean = self.be.mean(h_proba, axis=-1)
        self.hidmeans[:] = (self.sparse_damping * self.hidmeans +
                            (1 - self.sparse_damping) * hidden_probability_mean)
        sparsegrads_b_hidden = self.sparse_cost * (self.hidmeans - self.sparse_target)
        return sparsegrads_b_hidden

    def free_energy(self, inputs):
        """
        Free energy of each sample:
        F(v) = -b_visible^T v - sum_j log(1 + exp((W^T v + b_hidden)_j)).

        Returns:
            Tensor: (1, batch_size) free energies
        """
        Wv_b = self.be.dot(self.W.T, inputs) + self.b_hidden
        energy = self.be.empty((1, self.be.bsz))
        energy[:] = (-self.be.dot(self.b_visible.T, inputs) -
                     self.be.sum(self.be.log(1 + self.be.exp(Wv_b)), axis=0))
        return energy

    def get_pseudolikelihood(self, inputs):
        """
        Stochastic approximation to the pseudo-likelihood.

        Each call binarizes the inputs, flips one visible bit (the bit index cycles
        through all n_visible bits over successive calls) and compares free energies.

        Returns:
            float: n_visible * mean(log(sigmoid(FE(x_flipped) - FE(x))))
        """
        # index of bit i in expression p(x_i | x_{\i})
        if not hasattr(self, 'bit_i_idx'):
            self.bit_i_idx = 0
        else:
            self.bit_i_idx = (self.bit_i_idx + 1) % self.n_visible
        # binarize the input image by rounding to nearest integer
        xi = self.be.array(inputs.get() >= 0.5)
        # calculate free energy for the given bit configuration
        fe_xi = self.free_energy(xi)
        # flip bit x_i in an independent copy and preserve all other bits x_{\i}
        # (rows are visible units, so this flips bit i for every sample)
        xi_flip = self.be.array(xi.get())
        xi_flip[self.bit_i_idx] = 1 - xi_flip[self.bit_i_idx]
        # calculate free energy with bit flipped
        fe_xi_flip = self.free_energy(xi_flip)
        # equivalent to e^(-FE(x_i)) / (e^(-FE(x_i)) + e^(-FE(x_{\i})))
        cost = self.be.empty((1, 1))
        cost[:] = self.be.mean(self.n_visible * self.be.log(self.sigmoid(fe_xi_flip - fe_xi)))
        return cost.get()[0, 0]


class RBMLayerWithLabels(RBMLayer):
    """
    RBM layer whose visible vector combines multinomial label variables with
    Bernoulli feature variables.

    The first ``n_classes`` visible units hold the (one-hot) label and the
    remaining units hold the features.

    Arguments:
        n_hidden (int): number of hidden Bernoulli variables
        n_classes (int): number of classes
        n_duplicates (int): number of multinomial label variables.
            Each class is represented by n identical variables (Bernoulli);
            hidden_probability multiplies the label weights by this factor.
        init, name, persistant, kPCD, use_fast_weights, sparse_target,
        sparse_damping, sparse_cost, collect_zero_signal: see RBMLayer.
    """

    def __init__(self, n_hidden, n_classes, n_duplicates=1, init=None, name="RBMLayerWithLabels",
                 persistant=False, kPCD=1, use_fast_weights=False,
                 sparse_target=0, sparse_damping=0, sparse_cost=0,
                 collect_zero_signal=True):
        super(RBMLayerWithLabels, self).__init__(n_hidden, init=init, name=name,
                                                 persistant=persistant, kPCD=kPCD,
                                                 use_fast_weights=use_fast_weights,
                                                 sparse_target=sparse_target,
                                                 sparse_cost=sparse_cost,
                                                 sparse_damping=sparse_damping,
                                                 collect_zero_signal=collect_zero_signal)
        self.n_classes = n_classes
        self.n_duplicates = n_duplicates
        self.fast_W = None
        self.fast_b_hidden = 0
        self.fast_b_visible = 0
        self.fast_states = []

    def allocate(self, shared_outputs=None):
        """
        Allocate parameters; fast weights are allocated as zero tensors when
        use_fast_weights is set, otherwise they are the scalar 0.
        """
        super(RBMLayerWithLabels, self).allocate(shared_outputs)
        parallel, distributed = self.get_param_attrs()
        if self.use_fast_weights:
            self.fast_W = self.be.zeros(self.weight_shape, parallel=parallel, distributed=distributed)
            self.fast_dW = self.be.zeros_like(self.fast_W)
            self.fast_b_hidden = self.be.zeros(self.b_hid_shape, parallel=parallel,
                                               distributed=distributed)
            self.fast_db_hidden = self.be.zeros_like(self.fast_b_hidden)
            self.fast_b_visible = self.be.zeros(self.b_vis_shape, parallel=parallel,
                                                distributed=distributed)
            self.fast_db_visible = self.be.zeros_like(self.fast_b_visible)
        else:
            self.fast_W = 0
            self.fast_b_hidden = 0
            self.fast_b_visible = 0

    def configure(self, in_obj):
        """
        Configure as RBMLayer, then add n_classes label rows to the weights and
        visible bias.
        """
        super(RBMLayerWithLabels, self).configure(in_obj)
        self.weight_shape = (self.weight_shape[0] + self.n_classes, self.weight_shape[1])
        self.b_vis_shape = (self.b_vis_shape[0] + self.n_classes, self.b_vis_shape[1])
        return self

    def fprop(self, inputs, inference=False, labels=None, weights=None):
        """
        Calculate hidden units.

        The label rows are filled with one-hot labels when ``labels`` is given,
        otherwise with zeros.
        """
        if labels is None:
            ohe_labels = np.zeros((self.n_classes, self.be.bsz))
        else:
            ohe_labels = label2binary(labels, self.n_classes)
        v_units = self.be.array(np.vstack((ohe_labels, inputs.get())))
        return super(RBMLayerWithLabels, self).fprop(v_units, inference=inference, weights=weights)

    def bprop(self, hidden_units, alpha=None, beta=None, weights=None):
        """
        CD1 backward pass (negative phase).

        Returns:
            Tensor: probabilities of the feature (non-label) visible units only
        """
        visible_proba = super(RBMLayerWithLabels, self).bprop(hidden_units, alpha, beta, weights)
        return visible_proba[self.n_classes:]

    def hidden_probability(self, inputs, labels=None, fast_weights=0, fast_b_hidden=0, weights=None):
        """
        Calculate P(h | v).

        - If ``inputs`` lacks the label rows and ``labels`` is given, one-hot labels
          are stacked on top of the inputs.
        - If the visible vector includes the label rows, the label weights are
          scaled by n_duplicates; otherwise the label rows of the weights are dropped.
        """
        if weights is None:
            weights = self.W
        # evaluated copy, so the in-place scaling below does not touch self.W
        weights = _calc_optree(weights + fast_weights, self.be)
        v_units = inputs
        if inputs.shape[0] != self.W.shape[0]:
            if labels is not None:
                ohe_labels = label2binary(labels, self.n_classes)
                v_units = self.be.array(np.vstack((ohe_labels, inputs.get())))
        if v_units.shape[0] == self.W.shape[0]:
            weights[:self.n_classes] *= self.n_duplicates
        else:
            weights = weights[self.n_classes:]
        hidden_preacts = self.be.dot(weights.T, v_units) + self.b_hidden + fast_b_hidden
        hidden_proba = self.sigmoid(hidden_preacts)
        return hidden_proba

    def visible_probability(self, hidden_units, fast_weights=0, fast_b_visible=0, weights=None):
        """
        Calculate P(v | h).

        Label rows get a softmax over classes (stabilized by subtracting the
        per-sample maximum); feature rows get a sigmoid.
        """
        if weights is None:
            weights = self.W
        weights = weights + fast_weights
        visible_preacts = _calc_optree(self.be.dot(weights, hidden_units) + self.b_visible +
                                       fast_b_visible, self.be)
        visible_probability = self.be.empty_like(visible_preacts)
        visible_probability[self.n_classes:] = self.sigmoid(visible_preacts[self.n_classes:])
        # TODO: check the axis
        temp_exponential = self.be.exp(visible_preacts[:self.n_classes] -
                                       self.be.max(visible_preacts[:self.n_classes], axis=0))
        visible_probability[:self.n_classes] = temp_exponential / self.be.sum(temp_exponential, axis=0)
        return visible_probability

    def sample_hiddens(self, visible_units, labels=None, fast_weights=0, fast_b_hidden=0):
        """
        Sample binary hidden units from P(h | v).
        """
        h_probability = self.hidden_probability(visible_units, labels, fast_weights, fast_b_hidden)
        return h_probability > self.be.array(self.be.rng.uniform(size=h_probability.shape))

    def sample_visibles(self, hidden_units, fast_weights=0, fast_b_visible=0):
        """
        Sample visible units.

        Feature units are Bernoulli samples of their probabilities; the label units
        are a single multinomial draw (n = 1 trial) per sample.
        """
        v_units = self.visible_probability(hidden_units, fast_weights, fast_b_visible)
        v_units[self.n_classes:] = (v_units[self.n_classes:] >
                                    self.be.array(self.be.rng.uniform(size=v_units[self.n_classes:].shape)))
        v_units_tensor = _calc_optree(v_units, self.be)
        # multinomial distribution with n = 1 (number of trials): class i is chosen
        # when the uniform draw falls in [cdf[i - 1], cdf[i])
        random_numbers = self.be.rng.uniform(size=self.be.bsz)
        label_probabilities = v_units_tensor[:self.n_classes].get()
        cumulative = label_probabilities.cumsum(axis=0)
        label_samples = np.zeros_like(label_probabilities)
        lower = 0
        for i in range(self.n_classes):
            label_samples[i] = (random_numbers >= lower) & (random_numbers < cumulative[i])
            lower = cumulative[i]
        # write the samples back into the tensor explicitly; assigning into the array
        # returned by .get() does not update the tensor on backends where .get()
        # returns a host copy
        v_units_tensor[:self.n_classes] = self.be.array(label_samples)
        return v_units_tensor

    def update(self, v_pos, labels=None):
        """
        Compute CD-k (or PCD-k) gradients for one minibatch with labels.

        The results are written (negated) into self.dW, self.db_visible and
        self.db_hidden; nothing is returned.

        Arguments:
            v_pos (Tensor): input units (typically given input sample X) of size
                (n_visibles, batch_size)
            labels (Tensor): either OHE labels of shape (n_classes, batch_size) or just
                labels of shape (1, batch_size), which are converted to OHE labels.

        Raises:
            Exception: if ``labels`` is None
        """
        if labels is None:
            raise Exception('"labels" must be provided!')
        ohe_labels = label2binary(labels, self.n_classes)
        v_pos = _calc_optree(v_pos, self.be)
        v_pos_with_labels = self.be.array(np.vstack((ohe_labels, v_pos.get())))
        h_pos = self.hidden_probability(v_pos_with_labels)
        # sparsity
        sparsegrads_b_hidden = self.get_sparse_grads_b_hidden(h_pos)
        # negative phase
        if self.persistant:
            if self.chain is None:
                self.chain = self.be.zeros(h_pos.shape)
            chain = self.chain
        else:
            chain = h_pos > self.be.array(self.be.rng.uniform(size=h_pos.shape))
        for _ in range(self.kPCD):
            if self.persistant:
                v_neg = self.sample_visibles(chain, self.fast_W, self.fast_b_visible)
            else:
                v_neg = self.visible_probability(chain, self.fast_W, self.fast_b_visible)
            h_neg = self.hidden_probability(v_neg, fast_weights=self.fast_W,
                                            fast_b_hidden=self.fast_b_hidden)
            chain = h_neg > self.be.array(self.be.rng.uniform(size=h_neg.shape))
        if self.persistant:
            self.chain = _calc_optree(chain, self.be)
        self.dW[:] = -(self._grad_W(v_pos_with_labels, h_pos) - self._grad_W(v_neg, h_neg))
        self.db_visible[:] = -(self._grad_b_visible(v_pos_with_labels, v_neg))
        self.db_hidden[:] = -(self._grad_b_hidden(h_pos, h_neg) - sparsegrads_b_hidden)

    def update_for_wake_sleep(self, v_pos, labels=None, persistant=False, kPCD=1):
        """
        Calculate gradients during wake-sleep.

        Returns:
            tuple: (update_W / batch_size, update_b_hidden, update_b_visible,
                    v_neg, v_neg_sample, h_neg)
        """
        # positive phase
        ohe_labels = label2binary(labels, self.n_classes)
        v_pos = _calc_optree(v_pos, self.be)
        v_pos_with_labels = self.be.array(np.vstack((ohe_labels, v_pos.get())))
        h_pos = self.hidden_probability(v_pos_with_labels)
        h_pos_sample = h_pos > self.be.array(self.be.rng.uniform(size=h_pos.shape))
        # negative phase
        if persistant:
            if self.chain is None:
                self.chain = self.be.zeros(h_pos.shape)
            chain = self.chain
        else:
            chain = h_pos_sample
        for _ in range(kPCD):
            v_neg = self.visible_probability(chain)
            # independent copy: v_neg.copy(v_neg) may return v_neg itself, which made
            # the sampling below overwrite the probabilities used for h_neg
            v_neg_sample = self.be.array(v_neg.get())
            v_neg_sample[self.n_classes:] = (v_neg[self.n_classes:] >
                                             self.be.array(self.be.rng.uniform(size=v_neg[self.n_classes:].shape)))
            h_neg = self.hidden_probability(v_neg)
            chain = h_neg > self.be.array(self.be.rng.uniform(size=h_neg.shape))
        if persistant:
            self.chain = _calc_optree(chain, self.be)
        update_W = self._grad_W(v_pos_with_labels, h_pos_sample) - self._grad_W(v_neg_sample, h_neg)
        update_b_visible = self.be.mean(v_pos_with_labels - v_neg_sample, axis=-1)
        update_b_hidden = self.be.mean(h_pos_sample - h_neg, axis=-1)
        return update_W / float(self.be.bsz), update_b_hidden, update_b_visible, v_neg, v_neg_sample, h_neg


def label2binary(label, n_classes):
    """
    Convert class labels to a one-hot (binary) matrix.

    Labels should be integers from {0, 1, ..., n_classes - 1}.

    Arguments:
        label (Tensor): (1, batch_size) tensor of class indices, or an
            (n_classes, batch_size) tensor, which is returned as is (via .get())
        n_classes (int): number of classes

    Returns:
        numpy.ndarray: (n_classes, batch_size) array with binary[label[j], j] == 1

    Raises:
        Exception: if ``label`` has more than one row but not n_classes rows
    """
    if label.shape[0] == n_classes:
        # already one-hot encoded
        return label.get()
    if label.shape[0] > 1:
        raise Exception('"label" must 1 x N array!')
    n_samples = label.shape[1]
    class_indices = np.asarray(label.get()).ravel().astype(np.intp)
    binary = np.zeros((n_classes, n_samples), dtype=np.int32)
    # exactly one 1 per column (sample), in the row of its class; indexing with
    # binary[:, labels] would instead set every row of the selected columns
    binary[class_indices, np.arange(n_samples)] = 1
    return binary


class ConvolutionalRBMLayer(RBMLayer):
    """
    Convolutional RBM layer implementation.

    Works with volumetric data.

    Arguments:
        fshape (tuple(int) or dict): shape of the convolution filters, either
            (R, S, K) or (T, R, S, K), where K is the number of output feature maps
        strides (Optional[Union[int, dict]]): strides to apply convolution
            window over. An int applies to all dimensions, or a dict with
            str_d, str_h and str_w applies to d, h and w dimensions distinctly.
            Defaults to str_d = str_w = str_h = 1
        padding (Optional[Union[int, dict]]): padding to apply to edges of
            input. An int applies to all dimensions, or a dict with pad_d, pad_h
            and pad_w applies to d, h and w dimensions distinctly.
            Defaults to pad_d = pad_w = pad_h = 0
        init (Optional[Initializer]): Initializer object to use for
            initializing layer weights
        bsum (bool): stored in the convolution parameters
        name (Optional[str]): layer name. Defaults to "ConvolutionalRBMLayer"
        parallelism, persistant, kPCD, use_fast_weights, sparse_target,
        sparse_damping, sparse_cost, collect_zero_signal: see RBMLayer.
    """

    def __init__(self, fshape, strides=None, padding=None, init=None, bsum=False,
                 name="ConvolutionalRBMLayer", parallelism="Data",
                 persistant=False, kPCD=1, use_fast_weights=False,
                 sparse_target=0, sparse_damping=0, sparse_cost=0,
                 collect_zero_signal=True):
        super(ConvolutionalRBMLayer, self).__init__(0, init=init, name=name, parallelism=parallelism,
                                                    persistant=persistant, kPCD=kPCD,
                                                    use_fast_weights=use_fast_weights,
                                                    sparse_target=sparse_target,
                                                    sparse_cost=sparse_cost,
                                                    sparse_damping=sparse_damping,
                                                    collect_zero_signal=collect_zero_signal)
        # None defaults avoid sharing one mutable dict between instances
        strides = {} if strides is None else strides
        padding = {} if padding is None else padding
        self.nglayer = None
        # passed as bsum to bprop_conv in visible_probability; it was never assigned
        # in the visible code, which made that call fail with AttributeError
        self.batch_sum = None
        self.convparams = {'str_h': 1, 'str_w': 1, 'str_d': 1,
                           'pad_h': 0, 'pad_w': 0, 'pad_d': 0,
                           'T': 1, 'D': 1, 'bsum': bsum}  # 3D paramaters
        # keep around args in __dict__ for get_description.
        self.fshape = fshape
        self.strides = strides
        self.padding = padding
        if isinstance(fshape, tuple):
            fkeys = ('R', 'S', 'K') if len(fshape) == 3 else ('T', 'R', 'S', 'K')
            fshape = {k: x for k, x in zip(fkeys, fshape)}
        if isinstance(strides, int):
            strides = {'str_d': strides, 'str_h': strides, 'str_w': strides}
        if isinstance(padding, int):
            padding = {'pad_d': padding, 'pad_h': padding, 'pad_w': padding}
        for d in [fshape, strides, padding]:
            self.convparams.update(d)
        self.b_vis_shape = None
        self.b_hid_shape = None

    def configure(self, in_obj):
        """
        Build the backend convolution layer (once) from the input shape and derive
        the output, weight and bias shapes.
        """
        super(ConvolutionalRBMLayer, self).configure(in_obj)
        if self.nglayer is None:
            assert isinstance(self.in_shape, tuple)
            ikeys = ('C', 'H', 'W') if len(self.in_shape) == 3 else ('C', 'D', 'H', 'W')
            shapedict = {k: x for k, x in zip(ikeys, self.in_shape)}
            shapedict['N'] = self.be.bsz
            self.convparams.update(shapedict)
            self.nglayer = self.be.conv_layer(self.be.default_dtype, **self.convparams)
            (K, M, P, Q, N) = self.nglayer.dimO
            self.out_shape = (K, P, Q) if M == 1 else (K, M, P, Q)
            self.n_hidden = self.nglayer.dimO2[0]
            self.weight_shape = self.nglayer.dimF2  # (C * R * S, K)
        if self.convparams['bsum']:
            self.batch_sum_shape = (self.nglayer.K, 1)
        self.b_vis_shape = (reduce(mul, self.in_shape), 1)
        # K x 1 x 1 x 1 - one bias per output feature map
        self.b_hid_shape = (self.convparams['K'], 1)
        return self

    def _grad_W(self, visible_units, hidden_units):
        """
        Calculate one phase term of grad_W with the backend's update_conv.

        Arguments:
            visible_units (Tensor): visible units
            hidden_units (Tensor): hidden units (or their probabilities)
        """
        result = self.be.empty_like(self.W)
        visible_units_tensor = _calc_optree(visible_units, self.be)
        hidden_units_tensor = _calc_optree(hidden_units, self.be)
        self.be.update_conv(self.nglayer, visible_units_tensor, hidden_units_tensor, result)
        # TODO: check that division by the batch size is needed. It maybe not because of self.batch_sum
        return result

    def _grad_b_hidden(self, h_pos, h_neg):
        """
        Gradient of b_hidden: batch mean of (h_pos - h_neg), summed over the
        positions of each of the K feature maps.
        """
        update_b_hidden = _calc_optree(self.be.mean(h_pos - h_neg, axis=-1), self.be)
        update_b_hidden = self.be.sum(update_b_hidden.reshape(self.convparams['K'], -1), axis=-1)
        return update_b_hidden

    def _grad_b_visible(self, v_pos, v_neg):
        """
        Gradient of b_visible: sum over the batch of (v_pos - v_neg).
        """
        return self.be.sum(v_pos - v_neg, axis=-1)

    def update(self, v_pos, labels=None):
        """
        Compute CD-k (or PCD-k) gradients for one minibatch.

        The results are written (negated) into self.dW, self.db_visible and
        self.db_hidden; nothing is returned.

        Arguments:
            v_pos (Tensor): input units (typically given input sample X) of shape
                (n_visibles, batch_size)
            labels: unused; accepted for API uniformity
        """
        # positive phase
        h_pos = self.hidden_probability(v_pos)
        # negative phase
        if self.persistant:
            if self.chain is None:
                self.chain = self.be.zeros(h_pos.shape)
            chain = self.chain
        else:
            chain = h_pos > self.be.array(self.be.rng.uniform(size=h_pos.shape))
        for _ in range(self.kPCD):
            if self.persistant:
                v_neg = self.sample_visibles(chain)
            else:
                v_neg = self.visible_probability(chain)
            h_neg = self.hidden_probability(v_neg)
            chain = h_neg > self.be.array(self.be.rng.uniform(size=h_neg.shape))
        if self.persistant:
            self.chain._assign(chain)
        if not self.collect_zero_signal:
            # NOTE(review): self.hidden_preacts is not assigned anywhere in the visible
            # code, so this branch fails; collect_zero_signal=False is unsupported.
            zero_signal_mask = self.hidden_preacts.get() == 0
            h_pos[zero_signal_mask] = self.sparse_target
            sparsegrads_b_hidden = self.get_sparse_grads_b_hidden(h_pos)
            h_pos[zero_signal_mask] = 0
            h_neg[zero_signal_mask] = 0
            # n_hidden_units / (n_hidden_units - n_zero); computed but not returned
            zeros_ratio = float(zero_signal_mask.size) / (zero_signal_mask.size - np.sum(zero_signal_mask))
        else:
            sparsegrads_b_hidden = self.get_sparse_grads_b_hidden(h_pos)
            zeros_ratio = 1
        self.dW[:] = -(self._grad_W(v_pos, h_pos) - self._grad_W(v_neg, h_neg))
        self.db_visible[:] = -self._grad_b_visible(v_pos, v_neg)
        # TODO: maybe this should be mean(mean(mean(...))) like in crbm.m?
        self.db_hidden[:] = -(self._grad_b_hidden(h_pos, h_neg) - sparsegrads_b_hidden)

    def hidden_probability(self, inputs, weights=None):
        """
        Calculate P(h | v).

        Arguments:
            inputs (Tensor): visible units of size (n_visible x batch_size)
            weights (Tensor): weights (optional) of shape self.weight_shape

        Returns:
            hidden_probability: probability of hidden units, (n_hidden, batch_size)
        """
        # initialization
        if weights is None:
            weights = self.W
        # calculate operation tree for inputs
        inputs_tensor = _calc_optree(inputs, self.be)
        hidden_conv = self.be.empty((self.n_hidden, self.be.bsz))
        self.be.fprop_conv(self.nglayer, inputs_tensor, weights, hidden_conv)
        # replicate each feature-map bias over that map's nOut // K output positions
        # (integer division: the result is used as a tensor shape)
        b_hidden = self.be.ones((self.convparams['K'], self.nglayer.nOut // self.convparams['K']))
        b_hidden[:] = b_hidden * self.b_hidden
        hidden_preacts = hidden_conv + b_hidden.reshape(-1, 1)
        hidden_proba = self.sigmoid(hidden_preacts)
        return hidden_proba

    def visible_probability(self, hidden_units, weights=None):
        """
        Calculate P(v | h) with the backend's bprop_conv.
        """
        if weights is None:
            weights = self.W
        # calculate operation tree for hidden_units
        hidden_units_tensor = _calc_optree(hidden_units, self.be)
        visible_conv = self.be.empty((self.n_visible, self.be.bsz))
        # TODO: maybe we can interchange loop by one convolution? check this.
        self.be.bprop_conv(layer=self.nglayer, F=weights, E=hidden_units_tensor,
                           grad_I=visible_conv, bsum=self.batch_sum)
        visible_preacts = visible_conv + self.b_visible
        visible_proba = self.sigmoid(visible_preacts)
        return visible_proba

    def get_sparse_grads_b_hidden(self, h_probability):
        """
        Sparsity term of the hidden-bias gradient, computed per feature map.
        """
        if self.sparse_cost == 0:
            return self.be.zeros_like(self.b_hidden)
        if not hasattr(self, 'hidmeans'):
            self.hidmeans = self.be.ones(self.nglayer.dimO[:-1]) * self.sparse_target
        h_probability = h_probability.reshape(self.nglayer.dimO)
        hidden_probability_mean = self.be.empty(self.hidmeans.shape + (1,))
        hidden_probability_mean[:] = self.be.mean(h_probability, axis=-1)
        self.hidmeans[:] = (self.sparse_damping * self.hidmeans +
                            (1 - self.sparse_damping) * hidden_probability_mean[:, :, :, 0])
        sparsegrads_b_hidden = self.sparse_cost * np.squeeze(
            np.mean(np.mean(np.mean(self.hidmeans.get() - self.sparse_target,
                                    axis=1), axis=1), axis=1))
        # NOTE(review): the scraped source is truncated here ("...test_rhs_analyser.py");
        # the rest of this method and of ConvolutionalRBMLayer is not visible.
Source:test_rhs_analyser.py  
# RhsAnalyser tests

import sys
import unittest

# The project sources live one directory up; this must run before the project imports below.
sys.path.append("../src")
from common.architecture_support import whosdaddy, whosgranddaddy
from parsing.keywords import pythonbuiltinfunctions
from parsing.parse_rhs_analyser import RhsAnalyser


# ---------------------------------------------------------------------------
#  TEST - BASE CLASS
# ---------------------------------------------------------------------------


class MockQuickParse:
    """Stand-in for the quick-parse results: names of classes, defs and attrs found in a module."""

    def __init__(self):
        self.quick_found_classes = []
        self.quick_found_module_defs = []
        self.quick_found_module_attrs = []


class MockOldParseModel:
    """Stand-in for the parse model: a class list and the module-level methods."""

    def __init__(self):
        self.classlist = {}
        self.modulemethods = []


class MockVisitor:
    """Mock visitor carrying the lhs/rhs parsing state that is handed to RhsAnalyser."""

    def __init__(self, quick_parse):
        self.model = MockOldParseModel()
        self.stack_classes = []
        self.stack_module_functions = [False]
        self.quick_parse = quick_parse
        self.init_lhs_rhs()
        self.imports_encountered = []

    def init_lhs_rhs(self):
        """Reset the per-statement lhs/rhs state to its initial values."""
        self.lhs = []
        self.rhs = []
        self.lhs_recording = True
        self.made_rhs_call = False
        self.made_assignment = False
        self.made_append_call = False
        self.made_import = False
        self.pos_rhs_call_pre_first_bracket = None


class TestCaseBase(unittest.TestCase):
    """Shared fixture: every scenario is checked once per mode in ``self.modes``."""

    def setUp(self):
        self.qp = MockQuickParse()
        self.v = MockVisitor(self.qp)
        self.v.imports_encountered = ['a']
        self.v.model.modulemethods = []

        self.modes = ['assign', 'append']

    def get_visitor(self, rhs, assign_or_append):
        """Load *rhs* into the shared visitor and flag it as an assignment or an append call."""
        self.v.rhs = rhs
        if assign_or_append == 'assign':
            self.v.made_append_call = False
            self.v.made_assignment = True
        else:
            self.v.made_append_call = True
            self.v.made_assignment = False
        return self.v

    def do(self, rhs, made_rhs_call, call_pos, quick_classes, quick_defs, result_should_be, rhs_ref_to_class_should_be, imports):
        """Run RhsAnalyser on one scenario in every mode and check both of its outputs.

        Arguments:
            rhs: tokens of the right-hand side, e.g. ['a', 'Blah'] for ``a.Blah``.
            made_rhs_call: whether the rhs contains a call, i.e. ``(...)``.
            call_pos: index in *rhs* of the token just before the first bracket.
            quick_classes / quick_defs: class and module-def names the quick parse found.
            result_should_be: expected ``is_rhs_reference_to_a_class()`` result.
            rhs_ref_to_class_should_be: expected ``rhs_ref_to_class`` value.
            imports: import names the visitor has encountered.
        """
        for mode in self.modes:
            v = self.get_visitor(rhs, mode)
            v.made_rhs_call = made_rhs_call
            v.quick_parse.quick_found_classes = quick_classes
            v.quick_parse.quick_found_module_defs = quick_defs
            v.pos_rhs_call_pre_first_bracket = call_pos
            v.imports_encountered = imports

            ra = RhsAnalyser(v)
            # Name the mode in the failure message so the failing variant is obvious.
            context = 'mode=%s rhs=%s' % (mode, rhs)
            self.assertEqual(ra.is_rhs_reference_to_a_class(), result_should_be, context)
            self.assertEqual(ra.rhs_ref_to_class, rhs_ref_to_class_should_be, context)


# ---------------------------------------------------------------------------
#  TESTS
# ---------------------------------------------------------------------------


class TestCase_A_Classic(TestCaseBase):
    # self.w = Blah()
    # self.w.append(Blah())

    def test_1(self):
        """
        if Blah class exists and it is NOT in module methods
        Blah() where: class Blah - T Blah
        """
        self.do(rhs=['Blah'], made_rhs_call=True, call_pos=0, quick_classes=['Blah'], quick_defs=[], imports=[],
                result_should_be=True, rhs_ref_to_class_should_be='Blah')

    def test_2(self):
        """
        if Blah class does NOT exist and it IS in module methods
        Blah() where: def Blah() - F
        """
        self.do(rhs=['Blah'], made_rhs_call=True, call_pos=0, quick_classes=[], quick_defs=['Blah'], imports=[],
                result_should_be=False, rhs_ref_to_class_should_be=None)

    def test_3(self):
        """
        if Blah class exists and it IS in module methods - CONTRADICTION, assume Blah is class
        Blah() where: class Blah, def Blah() - T Blah
        """
        self.do(rhs=['Blah'], made_rhs_call=True, call_pos=0, quick_classes=['Blah'], quick_defs=['Blah'], imports=[],
                result_should_be=True, rhs_ref_to_class_should_be='Blah')

    def test_4(self):
        """
        if Blah class does NOT exist and it is NOT in module methods - GUESS yes, if starts with uppercase letter
        Blah() where: - T Blah
        """
        self.do(rhs=['Blah'], made_rhs_call=True, call_pos=0, quick_classes=[], quick_defs=[], imports=[],
                result_should_be=True, rhs_ref_to_class_should_be='Blah')

    def test_5(self):
        """
        if blah class does NOT exist and it is NOT in module methods - lowercase, so guess no
        blah() where: - F
        """
        self.do(rhs=['blah'], made_rhs_call=True, call_pos=0, quick_classes=[], quick_defs=[], imports=[],
                result_should_be=False, rhs_ref_to_class_should_be=None)


class TestCase_B_RhsIsInstance(TestCaseBase):
    # self.w = blah
    # self.w.append(blah)

    def test_1(self):
        """
        if Blah class exists and blah is NOT in module methods ==> (token transmogrified)
        blah where: class Blah - T Blah
        """
        self.do(rhs=['blah'], made_rhs_call=False, call_pos=0, quick_classes=['Blah'], quick_defs=[], imports=[],
                result_should_be=True, rhs_ref_to_class_should_be='Blah')

    def test_2(self):
        """
        if Blah class exists and blah IS in module methods (clear intent that blah is a function ref)
        blah where: class Blah, def blah() - F
        """
        self.do(rhs=['blah'], made_rhs_call=False, call_pos=0, quick_classes=['Blah'], quick_defs=['blah'], imports=[],
                result_should_be=False, rhs_ref_to_class_should_be=None)

    def test_3(self):
        """
        if Blah class exists and blah is NOT in module methods but Blah IS in module methods
            (if Blah is a function this is unrelated to blah instance)  ==> (token transmogrified)
        blah where: class Blah, def Blah() - T Blah
        """
        self.do(rhs=['blah'], made_rhs_call=False, call_pos=0, quick_classes=['Blah'], quick_defs=['Blah'], imports=[],
                result_should_be=True, rhs_ref_to_class_should_be='Blah')

    def test_4(self):
        """
        if Blah class doesn't exist
        blah where: - F
        """
        self.do(rhs=['blah'], made_rhs_call=False, call_pos=0, quick_classes=[], quick_defs=[], imports=[],
                result_should_be=False, rhs_ref_to_class_should_be=None)

        # the state of module methods doesn't matter:
        # blah where: def Blah() - F
        # blah where: def blah() - F

        self.do(rhs=['blah'], made_rhs_call=False, call_pos=0, quick_classes=[], quick_defs=['Blah'], imports=[],
                result_should_be=False, rhs_ref_to_class_should_be=None)

        self.do(rhs=['blah'], made_rhs_call=False, call_pos=0, quick_classes=[], quick_defs=['blah'], imports=[],
                result_should_be=False, rhs_ref_to_class_should_be=None)

    def test_5(self):
        """
        if blah class exists - syntax is just plain wrong for class creation.
        blah where: class blah - F
        """
        self.do(rhs=['blah'], made_rhs_call=False, call_pos=0, quick_classes=['blah'], quick_defs=[], imports=[],
                result_should_be=False, rhs_ref_to_class_should_be=None)

        # the state of module methods doesn't matter:
        # blah where: class blah, def Blah() - F

        self.do(rhs=['blah'], made_rhs_call=False, call_pos=0, quick_classes=['blah'], quick_defs=['Blah'], imports=[],
                result_should_be=False, rhs_ref_to_class_should_be=None)


class TestCase_C_AttrBeforeClassic(TestCaseBase):
    # self.w = a.Blah()
    # self.w.append(a.Blah())

    def test_1(self):
        """
        a.Blah() where: class Blah, import a - T a.Blah
        """
        self.do(rhs=['a', 'Blah'], made_rhs_call=True, call_pos=1, quick_classes=['Blah'], quick_defs=[], imports=['a'],
                result_should_be=True, rhs_ref_to_class_should_be='a.Blah')

    def test_2(self):
        """
        Most common case, e.g. self.popupmenu = wx.Menu() where all you know about is that you imported wx.
        a.Blah() where: import a - T a.Blah
        """
        self.do(rhs=['a', 'Blah'], made_rhs_call=True, call_pos=1, quick_classes=[], quick_defs=[], imports=['a'],
                result_should_be=True, rhs_ref_to_class_should_be='a.Blah')

    def test_3(self):
        """
        a.Blah() where: class Blah - F
        """
        self.do(rhs=['a', 'Blah'], made_rhs_call=True, call_pos=1, quick_classes=['Blah'], quick_defs=[], imports=[],
                result_should_be=False, rhs_ref_to_class_should_be=None)

    def test_4(self):
        """
        a.Blah() where: - F
        """
        self.do(rhs=['a', 'Blah'], made_rhs_call=True, call_pos=1, quick_classes=[], quick_defs=[], imports=[],
                result_should_be=False, rhs_ref_to_class_should_be=None)


class TestCase_D_MultipleAttrBeforeClassic(TestCaseBase):
    # self.w = a.b.Blah()
    # self.w.append(a.b.Blah())

    def test_1(self):
        """
        a.b.Blah() where: class Blah, import a.b - T a.b.Blah
        """
        self.do(rhs=['a', 'b', 'Blah'], made_rhs_call=True, call_pos=2, quick_classes=['Blah'], quick_defs=[], imports=['a.b'],
                result_should_be=True, rhs_ref_to_class_should_be='a.b.Blah')

    def test_2(self):
        """
        a.b.Blah() where: import a.b - T a.b.Blah
        """
        self.do(rhs=['a', 'b', 'Blah'], made_rhs_call=True, call_pos=2, quick_classes=[], quick_defs=[], imports=['a.b'],
                result_should_be=True, rhs_ref_to_class_should_be='a.b.Blah')

    def test_3(self):
        """
        a.b.Blah() where: class Blah, import b - F
        """
        self.do(rhs=['a', 'b', 'Blah'], made_rhs_call=True, call_pos=2, quick_classes=['Blah'], quick_defs=[], imports=['b'],
                result_should_be=False, rhs_ref_to_class_should_be=None)

    def test_4(self):
        """
        a.b.Blah() where: class Blah, import a - F
        """
        self.do(rhs=['a', 'b', 'Blah'], made_rhs_call=True, call_pos=2, quick_classes=['Blah'], quick_defs=[], imports=['a'],
                result_should_be=False, rhs_ref_to_class_should_be=None)


class TestCase_E_DoubleCall(TestCaseBase):
    # self.w = a().Blah()
    # self.w.append(a().Blah())

    def test_1(self):
        """
        a().Blah() where: class Blah - F
        """
        self.do(rhs=['a', 'Blah'], made_rhs_call=True, call_pos=0, quick_classes=['Blah'], quick_defs=[], imports=[],
                result_should_be=False, rhs_ref_to_class_should_be=None)

    def test_2(self):
        """
        a().Blah() where: class a - F
        """
        self.do(rhs=['a', 'Blah'], made_rhs_call=True, call_pos=0, quick_classes=['a'], quick_defs=[], imports=[],
                result_should_be=False, rhs_ref_to_class_should_be=None)


class TestCase_F_CallThenTrailingInstance(TestCaseBase):
    # self.w = a().blah
    # self.w.append(a().blah)

    def test_1(self):
        """
        a().blah where: class Blah - F
        """
        self.do(rhs=['a', 'blah'], made_rhs_call=True, call_pos=0, quick_classes=['Blah'], quick_defs=[], imports=[],
                result_should_be=False, rhs_ref_to_class_should_be=None)


class TestCase_G_AttrBeforeRhsInstance(TestCaseBase):
    # self.w = a.blah
    # self.w.append(a.blah)

    def test_1(self):
        """
        a.blah where: - F
        """
        self.do(rhs=['a', 'blah'], made_rhs_call=False, call_pos=0, quick_classes=[], quick_defs=[], imports=[],
                result_should_be=False, rhs_ref_to_class_should_be=None)

    def test_2(self):
        """
        a.blah where: class a - F
        """
        self.do(rhs=['a', 'blah'], made_rhs_call=False, call_pos=0, quick_classes=['a'], quick_defs=[], imports=[],
                result_should_be=False, rhs_ref_to_class_should_be=None)

    def test_3(self):
        """
        a.blah where: class blah - F
        """
        self.do(rhs=['a', 'blah'], made_rhs_call=False, call_pos=0, quick_classes=['blah'], quick_defs=[], imports=[],
                result_should_be=False, rhs_ref_to_class_should_be=None)

    def test_4(self):
        """
        a.blah where: class Blah - F
        """
        self.do(rhs=['a', 'blah'], made_rhs_call=False, call_pos=0, quick_classes=['Blah'], quick_defs=[], imports=[],
                result_should_be=False, rhs_ref_to_class_should_be=None)

    def test_5(self):
        """
        a.blah where: class Blah, imports a - F
        """
        self.do(rhs=['a', 'blah'], made_rhs_call=False, call_pos=0, quick_classes=['Blah'], quick_defs=[], imports=['a'],
                result_should_be=False, rhs_ref_to_class_should_be=None)


class TestCase_AddHoc(TestCaseBase):

    def test_1(self):
        """
        self.flageditor = FlagEditor(gamestatusstate=self)
        FlagEditor() where: - T FlagEditor
        """
        self.do(rhs=['FlagEditor'], made_rhs_call=True, call_pos=0, quick_classes=[], quick_defs=[], imports=[],
                result_should_be=True, rhs_ref_to_class_should_be='FlagEditor')


def suite():
    """Return one TestSuite holding every scenario class in this module.

    ``unittest.makeSuite`` is deprecated (removed in Python 3.13); the default
    loader's ``loadTestsFromTestCase`` collects the same ``test*`` methods.
    """
    load = unittest.defaultTestLoader.loadTestsFromTestCase
    test_classes = (
        TestCase_A_Classic,
        TestCase_B_RhsIsInstance,
        TestCase_C_AttrBeforeClassic,
        TestCase_D_MultipleAttrBeforeClassic,
        TestCase_E_DoubleCall,
        TestCase_F_CallThenTrailingInstance,
        TestCase_G_AttrBeforeRhsInstance,
        TestCase_AddHoc,  # previously built but never added to the suite, so it never ran
    )
    return unittest.TestSuite([load(test_class) for test_class in test_classes])


def main():
    """Run the full suite with verbose, description-free output."""
    runner = unittest.TextTestRunner(descriptions=0, verbosity=2)  # default is descriptions=1, verbosity=1
    runner.run(suite())


if __name__ == '__main__':
    main()
...routing_forest.py
Source:routing_forest.py  
"""
Tools to specify functions through trees and forests.

Whaaa?!?

Well, you see, often -- especially when writing transformers -- you have a series of
if/then conditions nested into each other, in code, where it gets ugly and un-reusable.

This module explores ways to objectify this: that is, to give us the means to create
such nested conditions in a way that we can define the parts as reusable, operable
components.

Think of the relationship between the for loop (code) and the iterator (object), along
with iterator tools (itertools).
This is what we're trying to explore, but for if/then conditions.

I said explore. Some more work is needed here to make it easily usable.

Let's look at an example involving the three main actors of our play.
Each of these is ``Iterable`` and ``Callable`` (``Generator`` to be precise).

- ``CondNode``: implements the if/then (no else) logic
- ``FinalNode``: Final -- yields (both with call and iter) its single `.val` attribute.
- ``RoutingForest``: An Iterable of ``CondNode``

>>> import inspect
>>>
>>> def could_be_int(obj):
...     if isinstance(obj, int):
...         b = True
...     else:
...         try:
...             int(obj)
...             b = True
...         except ValueError:
...             b = False
...     if b:
...         print(f'{inspect.currentframe().f_code.co_name}')
...     return b
...
>>> def could_be_float(obj):
...     if isinstance(obj, float):
...         b = True
...     else:
...         try:
...             float(obj)
...             b = True
...         except ValueError:
...             b = False
...     if b:
...         print(f'{inspect.currentframe().f_code.co_name}')
...     return b
...
>>> print(
...     could_be_int(30),
...     could_be_int(30.3),
...     could_be_int('30.2'),
...     could_be_int('nope'),
... )
could_be_int
could_be_int
True True False False
>>> print(
...     could_be_float(30),
...     could_be_float(30.3),
...     could_be_float('30.2'),
...     could_be_float('nope'),
... )
could_be_float
could_be_float
could_be_float
True True True False
>>> assert could_be_int('30.2') is False
>>> assert could_be_float('30.2') is True
could_be_float
>>>
>>> st = RoutingForest(
...     [
...         CondNode(
...             cond=could_be_int,
...             then=RoutingForest(
...                 [
...                     CondNode(
...                         cond=lambda x: int(x) >= 10,
...                         then=FinalNode('More than a digit'),
...                     ),
...                     CondNode(
...                         cond=lambda x: (int(x) % 2) == 1,
...                         then=FinalNode("That's odd!"),
...                     ),
...                 ]
...             ),
...         ),
...         CondNode(cond=could_be_float, then=FinalNode('could be seen as a float')),
...     ]
... )
>>> assert list(st('nothing I can do with that')) == []
>>> assert list(st(8)) == ['could be seen as a float']
could_be_int
could_be_float
>>> assert list(st(9)) == ["That's odd!", 'could be seen as a float']
could_be_int
could_be_float
>>> assert list(st(10)) == ['More than a digit', 'could be seen as a float']
could_be_int
could_be_float
>>> assert list(st(11)) == [
...     'More than a digit',
...     "That's odd!",
...     'could be seen as a float',
... ]
could_be_int
could_be_float
>>>
>>> print(
...     '### RoutingForest ########################################################################################'
... )
### RoutingForest ########################################################################################
>>> rf = RoutingForest(
...     [
...         SwitchCaseNode(
...             switch=lambda x: x % 5,
...             cases={0: FinalNode('zero_mod_5'), 1: FinalNode('one_mod_5')},
...             default=FinalNode('default_mod_5'),
...         ),
...         SwitchCaseNode(
...             switch=lambda x: x % 2,
...             cases={0: FinalNode('even'), 1: FinalNode('odd')},
...             default=FinalNode('that is not an int'),
...         ),
...     ]
... )
>>>
>>> assert list(rf(5)) == ['zero_mod_5', 'odd']
>>> assert list(rf(6)) == ['one_mod_5', 'even']
>>> assert list(rf(7)) == ['default_mod_5', 'odd']
>>> assert list(rf(8)) == ['default_mod_5', 'even']
>>> assert list(rf(10)) == ['zero_mod_5', 'even']
>>>
"""
from itertools import chain
from dataclasses import dataclass
from typing import Any, Iterable, Callable, Mapping, Tuple


class RoutingNode:
    """A RoutingNode instance needs to be callable on a single object, yielding an iterable or a final value"""

    def __call__(self, obj):
        raise NotImplementedError('You should implement this.')


@dataclass
class FinalNode(RoutingNode):
    """A RoutingNode that is final.

    It yields (both with call and iter) its single `.val` attribute."""

    val: Any

    def __call__(self, obj=None):
        # The routed object is ignored: a final node always yields its value.
        yield self.val

    def __iter__(self):
        yield self.val


@dataclass
class CondNode(RoutingNode):
    """A RoutingNode that implements the if/then (no else) logic.

    Calling it with ``obj`` yields whatever ``then(obj)`` yields when ``cond(obj)`` is truthy,
    and nothing otherwise. Iterating it iterates ``then``.
    """

    cond: Callable[[Any], bool]
    then: Any

    def __call__(self, obj):
        if self.cond(obj):
            yield from self.then(obj)

    def __iter__(self):
        yield from self.then


@dataclass
class RoutingForest(RoutingNode):
    """A RoutingNode that routes an object through every node in ``cond_nodes``, in order,
    chaining everything they yield.

    >>> rf = RoutingForest([
    ...     CondNode(cond=lambda x: isinstance(x, int),
    ...              then=RoutingForest([
    ...                  CondNode(cond=lambda x: int(x) >= 10, then=FinalNode('More than a digit')),
    ...                  CondNode(cond=lambda x: (int(x) % 2) == 1, then=FinalNode("That's odd!"))])
    ...             ),
    ...     CondNode(cond=lambda x: isinstance(x, (int, float)),
    ...              then=FinalNode('could be seen as a float')),
    ... ])
    >>> assert list(rf('nothing I can do with that')) == []
    >>> assert list(rf(8)) == ['could be seen as a float']
    >>> assert list(rf(9)) == ["That's odd!", 'could be seen as a float']
    >>> assert list(rf(10)) == ['More than a digit', 'could be seen as a float']
    >>> assert list(rf(11)) == ['More than a digit', "That's odd!", 'could be seen as a float']
    """

    cond_nodes: Iterable

    def __call__(self, obj):
        # Lazily route obj through each node in turn (equivalent to a nested `yield from` loop).
        yield from chain.from_iterable(cond_node(obj) for cond_node in self.cond_nodes)

    def __iter__(self):
        yield from chain.from_iterable(self.cond_nodes)


FeatCondThens = Iterable[Tuple[Callable, Callable]]


@dataclass
class FeatCondNode(RoutingNode):
    """A RoutingNode that yields multiple routes, one for each of several conditions met.

    ``feat`` computes a feature of the object once. Then, for every ``(cond, then)`` pair of
    ``feat_cond_thens`` (in order) whose ``cond`` is truthy on that feature, the object is routed
    through ``then``. A ``{cond: then}`` mapping is accepted in place of the pairs.
    """

    feat: Callable
    feat_cond_thens: FeatCondThens

    def _cond_then_pairs(self):
        """Return the (cond, then) pairs, accepting a ``{cond: then}`` mapping as well."""
        if isinstance(self.feat_cond_thens, Mapping):
            return self.feat_cond_thens.items()
        return self.feat_cond_thens

    def __call__(self, obj):
        feature = self.feat(obj)
        for cond, then in self._cond_then_pairs():
            if cond(feature):
                yield from then(obj)

    def __iter__(self):
        # feat_cond_thens holds (cond, then) pairs, so `.values()` is not available on it:
        # iterate the `then` half of every pair.
        yield from chain.from_iterable(then for _, then in self._cond_then_pairs())


NoDefault = type('NoDefault', (object,), {})
NO_DFLT = NoDefault()


@dataclass
class SwitchCaseNode(RoutingNode):
    """A RoutingNode that implements the switch/case/else logic.

    It's just a specialization (enhanced with a "default" option) of the FeatCondNode class to a situation
    where the cond function of feat_cond_thens is equality, therefore the routing can be
    implemented with a {value_to_compare_to_feature: then_node} map.

    :param switch: A function returning the feature of an object we want to switch on
    :param cases: The mapping from feature to RoutingNode that should be yield for that feature.
        Often is a dict, but only requirement is that it implements the cases.get(val, default) method.
    :param default: Default RoutingNode to yield if no case matches the feature.
        If no default is given, a feature that matches no case raises ``KeyError``.

    >>> rf = RoutingForest([
    ...     SwitchCaseNode(switch=lambda x: x % 5,
    ...                    cases={0: FinalNode('zero_mod_5'), 1: FinalNode('one_mod_5')},
    ...                    default=FinalNode('default_mod_5')),
    ...     SwitchCaseNode(switch=lambda x: x % 2,
    ...                    cases={0: FinalNode('even'), 1: FinalNode('odd')},
    ...                    default=FinalNode('that is not an int')),
    ... ])
    >>>
    >>> assert(list(rf(5)) == ['zero_mod_5', 'odd'])
    >>> assert(list(rf(6)) == ['one_mod_5', 'even'])
    >>> assert(list(rf(7)) == ['default_mod_5', 'odd'])
    >>> assert(list(rf(8)) == ['default_mod_5', 'even'])
    >>> assert(list(rf(10)) == ['zero_mod_5', 'even'])
    """

    switch: Callable
    cases: Mapping
    default: Any = NO_DFLT

    def __call__(self, obj):
        feature = self.switch(obj)
        # NO_DFLT doubles as the "missing" marker, so only `cases.get` is required of `cases`.
        then = self.cases.get(feature, self.default)
        if then is NO_DFLT:
            # Previously this called `None(obj)` and failed with an opaque TypeError.
            raise KeyError(feature)
        yield from then(obj)

    def __iter__(self):
        yield from chain.from_iterable(self.cases.values())
        # NO_DFLT is a truthy object, so a plain truthiness test would yield the sentinel itself.
        # Like the cases above (and every other node), the default is iterated, not yielded as is.
        if self.default is not NO_DFLT and self.default is not None:
            yield from self.default


def wrap_leafs_with_final_node(x):
    """Yield every item of *x*, wrapping any item that is not a RoutingNode in a FinalNode."""
    for xx in x:
        yield xx if isinstance(xx, RoutingNode) else FinalNode(xx)


if __name__ == '__main__':
    print(
        '##########################################################################################################'
    )
    import inspect

    def could_be_int(obj):
        if isinstance(obj, int):
            b = True
        else:
            try:
                int(obj)
                b = True
            except ValueError:
                b = False
        if b:
            print(f'{inspect.currentframe().f_code.co_name}')
        return b

    def could_be_float(obj):
        if isinstance(obj, float):
            b = True
        else:
            try:
                float(obj)
                b = True
            except ValueError:
                b = False
        if b:
            print(f'{inspect.currentframe().f_code.co_name}')
        return b

    print(
        could_be_int(30),
        could_be_int(30.3),
        could_be_int('30.2'),
        could_be_int('nope'),
    )
    print(
        could_be_float(30),
        could_be_float(30.3),
        could_be_float('30.2'),
        could_be_float('nope'),
    )
    assert could_be_int('30.2') is False
    assert could_be_float('30.2') is True
    st = RoutingForest(
        [
            CondNode(
                cond=could_be_int,
                then=RoutingForest(
                    [
                        CondNode(
                            cond=lambda x: int(x) >= 10,
                            then=FinalNode('More than a digit'),
                        ),
                        CondNode(
                            cond=lambda x: (int(x) % 2) == 1,
                            then=FinalNode("That's odd!"),
                        ),
                    ]
                ),
            ),
            CondNode(cond=could_be_float, then=FinalNode('could be seen as a float')),
        ]
    )
    assert list(st('nothing I can do with that')) == []
    assert list(st(8)) == ['could be seen as a float']
    assert list(st(9)) == ["That's odd!", 'could be seen as a float']
    assert list(st(10)) == ['More than a digit', 'could be seen as a float']
    assert list(st(11)) == [
        'More than a digit',
        "That's odd!",
        'could be seen as a float',
    ]
    print(
        '### RoutingForest ########################################################################################'
    )
    rf = RoutingForest(
        [
            SwitchCaseNode(
                switch=lambda x: x % 5,
                cases={0: FinalNode('zero_mod_5'), 1: FinalNode('one_mod_5')},
                default=FinalNode('default_mod_5'),
            ),
            SwitchCaseNode(
                switch=lambda x: x % 2,
                cases={0: FinalNode('even'), 1: FinalNode('odd')},
                default=FinalNode('that is not an int'),
            ),
        ]
    )
    assert list(rf(5)) == ['zero_mod_5', 'odd']
    assert list(rf(6)) == ['one_mod_5', 'even']
    assert list(rf(7)) == ['default_mod_5', 'odd']
    assert list(rf(8)) == ['default_mod_5', 'even']
    # ... (the rest of this listing is cut off in the source page)

# Learn to execute automation testing from scratch with LambdaTest Learning Hub. Right from
# setting up the prerequisites to run your first automation test, to following best practices
# and diving deeper into advanced test scenarios. LambdaTest Learning Hubs compile a list of
# step-by-step guides to help you be proficient with different test automation frameworks
# i.e. Selenium, Cypress, TestNG etc.
You can also refer to video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.
Get 100 minutes of automation testing FREE!
