Best Python code snippets using pytest-play (Python)
data_io.py
Source:data_io.py  
"""Kaldi-style speaker dataset loading utilities.

Provides small text-file parsing helpers, async feature loading, and the
``SpeakerDataset`` class, which serves fixed-length feature batches with
optional auxiliary labels (nationality, gender, age, recording, genre).
"""
import asyncio
import os
from collections import Counter, OrderedDict
from multiprocessing import Pool

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from kaldi_io import read_mat, read_vec_flt
from sklearn.preprocessing import KBinsDiscretizer, LabelEncoder, StandardScaler
from torch.utils.data import Dataset


class MissingClassMapError(Exception):
    """Raised when a required class-encoder mapping is unavailable."""

    pass


def load_n_col(file, numpy=False):
    """Read a space-delimited text file and return it as columns.

    :param file: path to a text file whose lines all have the same column count
    :param numpy: when True, return each column as an np.array of strings
    :return: list of columns (lists, or np.arrays when ``numpy`` is True)
    """
    data = []
    with open(file) as fp:
        for line in fp:
            data.append(line.strip().split(" "))
    columns = list(zip(*data))
    if numpy:
        columns = [np.array(list(i)) for i in columns]
    else:
        columns = [list(i) for i in columns]
    return columns


def odict_from_2_col(file, numpy=False):
    """Build an OrderedDict mapping column 0 -> column 1 of a 2-column file."""
    col0, col1 = load_n_col(file, numpy=numpy)
    return OrderedDict({c0: c1 for c0, c1 in zip(col0, col1)})


def load_one_tomany(file, numpy=False):
    """Read a one-to-many mapping file (e.g. Kaldi ``spk2utt``).

    Each line is ``<key> <val1> <val2> ...``.

    :return: (keys, list of value-lists)
    """
    one = []
    many = []
    with open(file) as fp:
        for line in fp:
            line = line.strip().split(" ", 1)
            one.append(line[0])
            m = line[1].split(" ")
            many.append(np.array(m) if numpy else m)
    if numpy:
        one = np.array(one)
    return one, many


def train_transform(feats, seqlen):
    """Randomly crop, or zero-pad, a feature matrix to exactly ``seqlen`` frames.

    :param feats: (num_frames, feat_dim) array
    :param seqlen: target number of frames
    :return: FloatTensor of shape (seqlen, feat_dim)
    """
    leeway = feats.shape[0] - seqlen
    startslice = np.random.randint(0, int(leeway)) if leeway > 0 else 0
    feats = (
        feats[startslice : startslice + seqlen]
        if leeway > 0
        else np.pad(feats, [(0, -leeway), (0, 0)], "constant")
    )
    return torch.FloatTensor(feats)


async def get_item_train(instructions):
    """Load one feature matrix and crop/pad it to the requested length.

    :param instructions: (filepath, seqlen) pair
    """
    fpath = instructions[0]
    seqlen = instructions[1]
    raw_feats = read_mat(fpath)
    feats = train_transform(raw_feats, seqlen)
    return feats


async def get_item_test(filepath):
    """Load one full (uncropped) feature matrix as a FloatTensor."""
    raw_feats = read_mat(filepath)
    return torch.FloatTensor(raw_feats)


def async_map(coroutine_func, iterable):
    """Apply an async function to every item of ``iterable`` concurrently.

    FIX: a bare ``asyncio.get_event_loop()`` is deprecated when no running
    loop exists (and raises on recent Python versions); fall back to
    creating a loop explicitly.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    future = asyncio.gather(*(coroutine_func(param) for param in iterable))
    return loop.run_until_complete(future)


class SpeakerDataset(Dataset):
    """Speaker dataset over a Kaldi-style data directory.

    Reads ``utt2spk``/``spk2utt``/``feats.scp`` and any optional label files
    present (``spk2nat``, ``spk2gender``, ``utt2age``, ``utt2rec``,
    ``utt2genre``) and exposes several batch iterators.
    """

    def __init__(
        self,
        data_base_path,
        real_speaker_labels=True,
        asynchr=True,
        num_workers=3,
        test_mode=False,
        class_enc_dict=None,
        **kwargs
    ):
        # FIX: mutable default argument ({}) replaced with a None sentinel;
        # passing {} explicitly still behaves identically.
        class_enc_dict = class_enc_dict if class_enc_dict is not None else {}
        self.data_base_path = data_base_path
        self.num_workers = num_workers
        self.test_mode = test_mode
        self.real_speaker_labels = real_speaker_labels
        # self.label_types = label_types
        if self.test_mode:
            self.label_types = []
        else:
            self.label_types = ["speaker"] if self.real_speaker_labels else []
        # Optional label types are enabled by the presence of their data file.
        if os.path.isfile(os.path.join(data_base_path, "spk2nat")):
            self.label_types.append("nationality")
        if os.path.isfile(os.path.join(data_base_path, "spk2gender")):
            self.label_types.append("gender")
        if os.path.isfile(os.path.join(data_base_path, "utt2age")):
            self.label_types.append("age_regression")
            self.label_types.append("age")
        if os.path.isfile(os.path.join(data_base_path, "utt2rec")):
            self.label_types.append("rec")
        if os.path.isfile(os.path.join(data_base_path, "utt2genre")):
            self.label_types.append("genre")
        if self.test_mode and self.label_types:
            assert class_enc_dict, "Class mapping must be passed to test mode dataset"
        self.class_enc_dict = class_enc_dict
        utt2spk_path = os.path.join(data_base_path, "utt2spk")
        spk2utt_path = os.path.join(data_base_path, "spk2utt")
        feats_scp_path = os.path.join(data_base_path, "feats.scp")
        assert os.path.isfile(utt2spk_path)
        assert os.path.isfile(feats_scp_path)
        assert os.path.isfile(spk2utt_path)
        verilist_path = os.path.join(data_base_path, "veri_pairs")
        if self.test_mode:
            if os.path.isfile(verilist_path):
                self.veri_labs, self.veri_0, self.veri_1 = load_n_col(
                    verilist_path, numpy=True
                )
                self.veri_labs = self.veri_labs.astype(int)
                self.veripairs = True
            else:
                self.veripairs = False
        self.utts, self.uspkrs = load_n_col(utt2spk_path)
        self.utt_fpath_dict = odict_from_2_col(feats_scp_path)
        self.label_enc = LabelEncoder()
        self.original_spkrs, self.spkutts = load_one_tomany(spk2utt_path)
        self.spkrs = self.label_enc.fit_transform(self.original_spkrs)
        self.spk_utt_dict = OrderedDict(
            {k: v for k, v in zip(self.spkrs, self.spkutts)}
        )
        self.spk_original_spk_dict = {
            k: v for k, v in zip(self.spkrs, self.original_spkrs)
        }
        self.uspkrs = self.label_enc.transform(self.uspkrs)
        self.utt_spkr_dict = OrderedDict({k: v for k, v in zip(self.utts, self.uspkrs)})
        self.utt_list = list(self.utt_fpath_dict.keys())
        self.first_batch = True
        self.num_classes = (
            {"speaker": len(self.label_enc.classes_)}
            if self.real_speaker_labels
            else {}
        )
        self.asynchr = asynchr
        if "nationality" in self.label_types:
            self.natspkrs, self.nats = load_n_col(
                os.path.join(data_base_path, "spk2nat")
            )
            self.nats = [n.lower().strip() for n in self.nats]
            self.natspkrs = self.label_enc.transform(self.natspkrs)
            self.nat_label_enc = LabelEncoder()
            if not self.test_mode:
                self.nats = self.nat_label_enc.fit_transform(self.nats)
            else:
                self.nat_label_enc = self.class_enc_dict["nationality"]
                self.nats = self.nat_label_enc.transform(self.nats)
            self.spk_nat_dict = OrderedDict(
                {k: v for k, v in zip(self.natspkrs, self.nats)}
            )
            self.num_classes["nationality"] = len(self.nat_label_enc.classes_)
        if "gender" in self.label_types:
            self.genspkrs, self.genders = load_n_col(
                os.path.join(data_base_path, "spk2gender")
            )
            self.genspkrs = self.label_enc.transform(self.genspkrs)
            self.gen_label_enc = LabelEncoder()
            if not self.test_mode:
                self.genders = self.gen_label_enc.fit_transform(self.genders)
            else:
                self.gen_label_enc = self.class_enc_dict["gender"]
                self.genders = self.gen_label_enc.transform(self.genders)
            self.spk_gen_dict = OrderedDict(
                {k: v for k, v in zip(self.genspkrs, self.genders)}
            )
            self.num_classes["gender"] = len(self.gen_label_enc.classes_)
        if "age" in self.label_types:
            self.num_age_bins = (
                kwargs["num_age_bins"] if "num_age_bins" in kwargs else 10
            )
            self.ageutts, self.ages = load_n_col(
                os.path.join(data_base_path, "utt2age")
            )
            # FIX: np.float alias was removed in NumPy 1.20 — use builtin float.
            self.ages = np.array(self.ages).astype(float)
            self.age_label_enc = KBinsDiscretizer(
                n_bins=self.num_age_bins, encode="ordinal", strategy="uniform"
            )
            if not self.test_mode or "age" not in self.class_enc_dict:
                self.age_classes = self.age_label_enc.fit_transform(
                    np.array(self.ages).reshape(-1, 1)
                ).flatten()
            else:
                self.age_label_enc = self.class_enc_dict["age"]
                self.age_classes = self.age_label_enc.transform(
                    np.array(self.ages).reshape(-1, 1)
                ).flatten()
            self.utt_age_class_dict = OrderedDict(
                {k: v for k, v in zip(self.ageutts, self.age_classes)}
            )
            self.num_classes["age"] = self.num_age_bins
        if "age_regression" in self.label_types:
            self.ageutts, self.ages = load_n_col(
                os.path.join(data_base_path, "utt2age")
            )
            # FIX: np.float alias was removed in NumPy 1.20 — use builtin float.
            self.ages = np.array(self.ages).astype(float)
            self.age_reg_enc = StandardScaler()
            if not self.test_mode or "age_regression" not in self.class_enc_dict:
                self.ages = self.age_reg_enc.fit_transform(
                    np.array(self.ages).reshape(-1, 1)
                ).flatten()
            else:
                self.age_reg_enc = self.class_enc_dict["age_regression"]
                self.ages = self.age_reg_enc.transform(
                    np.array(self.ages).reshape(-1, 1)
                ).flatten()
            self.utt_age_dict = OrderedDict(
                {k: v for k, v in zip(self.ageutts, self.ages)}
            )
            self.num_classes["age_regression"] = 1
        if "rec" in self.label_types:
            self.recutts, self.recs = load_n_col(
                os.path.join(data_base_path, "utt2rec")
            )
            self.recs = np.array(self.recs)
            self.rec_label_enc = LabelEncoder()
            if not self.test_mode:
                self.recs = self.rec_label_enc.fit_transform(self.recs)
            else:
                self.rec_label_enc = self.class_enc_dict["rec"]
                self.recs = self.rec_label_enc.transform(self.recs)
            self.utt_rec_dict = OrderedDict(
                {k: v for k, v in zip(self.recutts, self.recs)}
            )
            self.num_classes["rec"] = len(self.rec_label_enc.classes_)
        if "genre" in self.label_types:
            self.genreutts, self.genres = load_n_col(
                os.path.join(data_base_path, "utt2genre")
            )
            self.genres = np.array(self.genres)
            self.genre_label_enc = LabelEncoder()
            if not self.test_mode:
                self.genres = self.genre_label_enc.fit_transform(self.genres)
                self.utt_genre_dict = OrderedDict(
                    {k: v for k, v in zip(self.genreutts, self.genres)}
                )
                self.num_classes["genre"] = len(self.genre_label_enc.classes_)
            else:
                # TODO: add this check to other attributes
                if "genre" in self.class_enc_dict:
                    self.genre_label_enc = self.class_enc_dict["genre"]
                    self.genres = self.genre_label_enc.transform(self.genres)
                    self.utt_genre_dict = OrderedDict(
                        {k: v for k, v in zip(self.genreutts, self.genres)}
                    )
                    self.num_classes["genre"] = len(self.genre_label_enc.classes_)
                else:
                    self.label_types.remove("genre")
        self.class_enc_dict = self.get_class_encs()

    def __len__(self):
        return len(self.utt_list)

    def get_class_encs(self):
        """Collect the fitted encoder for every active label type."""
        class_enc_dict = {}
        if "speaker" in self.label_types:
            class_enc_dict["speaker"] = self.label_enc
        if "age" in self.label_types:
            class_enc_dict["age"] = self.age_label_enc
        if "age_regression" in self.label_types:
            class_enc_dict["age_regression"] = self.age_reg_enc
        if "nationality" in self.label_types:
            class_enc_dict["nationality"] = self.nat_label_enc
        if "gender" in self.label_types:
            class_enc_dict["gender"] = self.gen_label_enc
        if "rec" in self.label_types:
            class_enc_dict["rec"] = self.rec_label_enc
        if "genre" in self.label_types:
            class_enc_dict["genre"] = self.genre_label_enc
        self.class_enc_dict = class_enc_dict
        return class_enc_dict

    @staticmethod
    def get_item(instructions):
        """Synchronous counterpart of ``get_item_train``.

        :param instructions: (filepath, seqlen) pair
        """
        fpath = instructions[0]
        seqlen = instructions[1]
        feats = read_mat(fpath)
        feats = train_transform(feats, seqlen)
        return feats

    def get_item_test(self, idx):
        """Return the full features and label dict for one utterance index."""
        utt = self.utt_list[idx]
        fpath = self.utt_fpath_dict[utt]
        feats = read_mat(fpath)
        feats = torch.FloatTensor(feats)
        label_dict = {}
        speaker = self.utt_spkr_dict[utt]
        if "speaker" in self.label_types:
            label_dict["speaker"] = torch.LongTensor([speaker])
        if "gender" in self.label_types:
            label_dict["gender"] = torch.LongTensor([self.spk_gen_dict[speaker]])
        if "nationality" in self.label_types:
            label_dict["nationality"] = torch.LongTensor([self.spk_nat_dict[speaker]])
        if "age" in self.label_types:
            label_dict["age"] = torch.LongTensor([self.utt_age_class_dict[utt]])
        if "age_regression" in self.label_types:
            label_dict["age_regression"] = torch.FloatTensor([self.utt_age_dict[utt]])
        if "genre" in self.label_types:
            label_dict["genre"] = torch.LongTensor([self.utt_genre_dict[utt]])
        return feats, label_dict

    def get_test_items(self, num_items=-1, exclude_speakers=None, use_async=True):
        """Load (optionally a random subset of) all utterances with labels.

        :param num_items: if >= 1, sample that many utterances (with
            replacement only when the corpus is too small)
        :param exclude_speakers: iterable of original speaker IDs to drop
        :param use_async: load features concurrently via async_map
        :return: (feats list, label_dict of np.arrays, utts)
        """
        utts = self.utt_list
        if num_items >= 1:
            replace = len(utts) <= num_items
            utts = np.random.choice(utts, size=num_items, replace=replace)
        utts = np.array(utts)
        spkrs = np.array([self.utt_spkr_dict[utt] for utt in utts])
        original_spkrs = np.array(
            [self.spk_original_spk_dict[spkr] for spkr in spkrs]
        )
        if exclude_speakers:
            mask = np.array(
                [False if s in exclude_speakers else True for s in original_spkrs]
            )
            utts = utts[mask]
            spkrs = spkrs[mask]
            original_spkrs = original_spkrs[mask]
        fpaths = [self.utt_fpath_dict[utt] for utt in utts]
        if use_async:
            feats = async_map(get_item_test, fpaths)
        else:
            feats = [torch.FloatTensor(read_mat(f)) for f in fpaths]
        label_dict = {}
        label_dict["speaker"] = np.array(spkrs)
        label_dict["original_speaker"] = np.array(original_spkrs)
        if "nationality" in self.label_types:
            label_dict["nationality"] = np.array([self.spk_nat_dict[s] for s in spkrs])
        if "gender" in self.label_types:
            label_dict["gender"] = np.array([self.spk_gen_dict[s] for s in spkrs])
        if "age" in self.label_types:
            label_dict["age"] = np.array([self.utt_age_class_dict[utt] for utt in utts])
        if "age_regression" in self.label_types:
            label_dict["age_regression"] = np.array(
                [self.utt_age_dict[utt] for utt in utts]
            )
        if "genre" in self.label_types:
            label_dict["genre"] = np.array([self.utt_genre_dict[utt] for utt in utts])
        return feats, label_dict, utts

    def get_batches(self, batch_size=256, max_seq_len=400, sp_tensor=True):
        """
        Main data iterator, specify batch_size and max_seq_len
        sp_tensor determines whether speaker labels are returned as Tensor object or not
        """
        # with Parallel(n_jobs=self.num_workers) as parallel:
        self.idpool = self.spkrs.copy()
        assert batch_size < len(
            self.idpool
        )  # Metric learning assumption large num classes
        lens = [max_seq_len for _ in range(batch_size)]
        while True:
            # Sample speakers without replacement; refill the pool when it
            # can no longer cover a full batch.
            if len(self.idpool) <= batch_size:
                batch_ids = np.array(self.idpool)
                self.idpool = self.spkrs.copy()
                rem_ids = np.random.choice(
                    self.idpool, size=batch_size - len(batch_ids), replace=False
                )
                batch_ids = np.concatenate([batch_ids, rem_ids])
                self.idpool = list(set(self.idpool) - set(rem_ids))
            else:
                batch_ids = np.random.choice(
                    self.idpool, size=batch_size, replace=False
                )
                self.idpool = list(set(self.idpool) - set(batch_ids))
            batch_fpaths = []
            batch_utts = []
            for i in batch_ids:
                utt = np.random.choice(self.spk_utt_dict[i])
                batch_utts.append(utt)
                batch_fpaths.append(self.utt_fpath_dict[utt])
            if self.asynchr:
                batch_feats = async_map(get_item_train, zip(batch_fpaths, lens))
            else:
                batch_feats = [self.get_item(a) for a in zip(batch_fpaths, lens)]
            # batch_feats = parallel(delayed(self.get_item)(a) for a in zip(batch_fpaths, lens))
            label_dict = {}
            if "speaker" in self.label_types:
                label_dict["speaker"] = (
                    torch.LongTensor(batch_ids) if sp_tensor else batch_ids
                )
            if "nationality" in self.label_types:
                label_dict["nationality"] = torch.LongTensor(
                    [self.spk_nat_dict[s] for s in batch_ids]
                )
            if "gender" in self.label_types:
                label_dict["gender"] = torch.LongTensor(
                    [self.spk_gen_dict[s] for s in batch_ids]
                )
            if "age" in self.label_types:
                label_dict["age"] = torch.LongTensor(
                    [self.utt_age_class_dict[u] for u in batch_utts]
                )
            if "age_regression" in self.label_types:
                label_dict["age_regression"] = torch.FloatTensor(
                    [self.utt_age_dict[u] for u in batch_utts]
                )
            if "rec" in self.label_types:
                label_dict["rec"] = torch.LongTensor(
                    [self.utt_rec_dict[u] for u in batch_utts]
                )
            if "genre" in self.label_types:
                label_dict["genre"] = torch.LongTensor(
                    [self.utt_genre_dict[u] for u in batch_utts]
                )
            yield torch.stack(batch_feats), label_dict

    def get_batches_naive(self, batch_size=256, max_seq_len=400, sp_tensor=True):
        """
        Main data iterator, specify batch_size and max_seq_len
        sp_tensor determines whether speaker labels are returned as Tensor object or not
        """
        self.idpool = self.spkrs.copy()
        # assert batch_size < len(self.idpool) #Metric learning assumption large num classes
        lens = [max_seq_len for _ in range(batch_size)]
        while True:
            # Naive variant: sample speakers WITH replacement every batch.
            batch_ids = np.random.choice(self.idpool, size=batch_size)
            batch_fpaths = []
            batch_utts = []
            for i in batch_ids:
                utt = np.random.choice(self.spk_utt_dict[i])
                batch_utts.append(utt)
                batch_fpaths.append(self.utt_fpath_dict[utt])
            if self.asynchr:
                batch_feats = async_map(get_item_train, zip(batch_fpaths, lens))
            else:
                batch_feats = [self.get_item(a) for a in zip(batch_fpaths, lens)]
            # batch_feats = parallel(delayed(self.get_item)(a) for a in zip(batch_fpaths, lens))
            label_dict = {}
            if "speaker" in self.label_types:
                label_dict["speaker"] = (
                    torch.LongTensor(batch_ids) if sp_tensor else batch_ids
                )
            if "nationality" in self.label_types:
                label_dict["nationality"] = torch.LongTensor(
                    [self.spk_nat_dict[s] for s in batch_ids]
                )
            if "gender" in self.label_types:
                label_dict["gender"] = torch.LongTensor(
                    [self.spk_gen_dict[s] for s in batch_ids]
                )
            if "age" in self.label_types:
                label_dict["age"] = torch.LongTensor(
                    [self.utt_age_class_dict[u] for u in batch_utts]
                )
            if "age_regression" in self.label_types:
                label_dict["age_regression"] = torch.FloatTensor(
                    [self.utt_age_dict[u] for u in batch_utts]
                )
            if "rec" in self.label_types:
                label_dict["rec"] = torch.LongTensor(
                    [self.utt_rec_dict[u] for u in batch_utts]
                )
            if "genre" in self.label_types:
                label_dict["genre"] = torch.LongTensor(
                    [self.utt_genre_dict[u] for u in batch_utts]
                )
            yield torch.stack(batch_feats), label_dict

    def get_batches_balance(
        self, balance_attribute="speaker", batch_size=256, max_seq_len=400
    ):
        """
        Main data iterator, specify batch_size and max_seq_len
        Specify which attribute to balance
        """
        assert balance_attribute in self.label_types
        if balance_attribute == "speaker":
            self.anchorpool = self.spkrs.copy()
            self.get_utt_method = lambda x: np.random.choice(self.spk_utt_dict[x])
        if balance_attribute == "nationality":
            self.anchorpool = sorted(list(set(self.nats)))
            self.nat_utt_dict = OrderedDict({k: [] for k in self.anchorpool})
            for u in self.utt_list:
                spk = self.utt_spkr_dict[u]
                nat = self.spk_nat_dict[spk]
                self.nat_utt_dict[nat].append(u)
            for n in self.nat_utt_dict:
                # FIX: original indexed with stale loop variable `u` instead
                # of `n`, corrupting/crashing the nationality-balanced path.
                self.nat_utt_dict[n] = np.array(self.nat_utt_dict[n])
            self.get_utt_method = lambda x: np.random.choice(self.nat_utt_dict[x])
        if balance_attribute == "age":
            self.anchorpool = sorted(list(set(self.age_classes)))
            self.age_utt_dict = OrderedDict({k: [] for k in self.anchorpool})
            for u in self.utt_age_class_dict:
                nat_class = self.utt_age_class_dict[u]
                self.age_utt_dict[nat_class].append(u)
            for a in self.age_utt_dict:
                self.age_utt_dict[a] = np.array(self.age_utt_dict[a])
            self.get_utt_method = lambda x: np.random.choice(self.age_utt_dict[x])
        lens = [max_seq_len for _ in range(batch_size)]
        while True:
            anchors = np.random.choice(self.anchorpool, size=batch_size)
            batch_utts = [self.get_utt_method(a) for a in anchors]
            batch_fpaths = []
            batch_ids = []
            for utt in batch_utts:
                batch_fpaths.append(self.utt_fpath_dict[utt])
                batch_ids.append(self.utt_spkr_dict[utt])
            if self.asynchr:
                batch_feats = async_map(get_item_train, zip(batch_fpaths, lens))
            else:
                batch_feats = [self.get_item(a) for a in zip(batch_fpaths, lens)]
            label_dict = {}
            if "speaker" in self.label_types:
                label_dict["speaker"] = torch.LongTensor(batch_ids)
            if "nationality" in self.label_types:
                label_dict["nationality"] = torch.LongTensor(
                    [self.spk_nat_dict[s] for s in batch_ids]
                )
            if "gender" in self.label_types:
                label_dict["gender"] = torch.LongTensor(
                    [self.spk_gen_dict[s] for s in batch_ids]
                )
            if "age" in self.label_types:
                label_dict["age"] = torch.LongTensor(
                    [self.utt_age_class_dict[u] for u in batch_utts]
                )
            if "age_regression" in self.label_types:
                label_dict["age_regression"] = torch.FloatTensor(
                    [self.utt_age_dict[u] for u in batch_utts]
                )
            if "rec" in self.label_types:
                label_dict["rec"] = torch.LongTensor(
                    [self.utt_rec_dict[u] for u in batch_utts]
                )
            if "genre" in self.label_types:
                label_dict["genre"] = torch.LongTensor(
                    [self.utt_genre_dict[u] for u in batch_utts]
                )
            yield torch.stack(batch_feats), label_dict

    def get_alldata_batches(self, batch_size=256, max_seq_len=400):
        """Iterate over every utterance in corpus order, in fixed-size batches."""
        utt_list = self.utt_list
        start_index = 0
        lens = [max_seq_len for _ in range(batch_size)]
        # FIX: original used `<=`, which produced an empty final batch
        # (torch.stack([]) raises) once start_index reached len(utt_list).
        while start_index < len(utt_list):
            batch_utts = utt_list[start_index : start_index + batch_size]
            batch_fpaths = []
            batch_ids = []
            for utt in batch_utts:
                batch_fpaths.append(self.utt_fpath_dict[utt])
                batch_ids.append(self.utt_spkr_dict[utt])
            if self.asynchr:
                batch_feats = async_map(get_item_train, zip(batch_fpaths, lens))
            else:
                batch_feats = [self.get_item(a) for a in zip(batch_fpaths, lens)]
            # NOTE(review): the original source is truncated here ("..."); the
            # yield and index advance below are the evident completion,
            # mirroring the other batch iterators — confirm against the full file.
            yield torch.stack(batch_feats), torch.LongTensor(batch_ids), batch_utts
            start_index += batch_size
Source:load_dataset.py  
import csv
import numpy as np
import os
import sys
import pickle

# BASE PATH DEFINITIONS
DATA_BASE_PATH = '.'
OUTPUT_BASE_PATH = '.'

default_FileName_PositiveInstancesDictionnary = os.path.join(DATA_BASE_PATH, "dictionnaries_and_lists/SmallMolMWFilter_UniprotHumanProt_DrugBank_Dictionary.csv")
default_FileName_ListProt = os.path.join(DATA_BASE_PATH, "dictionnaries_and_lists/list_MWFilter_UniprotHumanProt.txt")
default_FileName_ListMol = os.path.join(DATA_BASE_PATH, "dictionnaries_and_lists/list_MWFilter_mol.txt")
default_FileName_MolKernel = os.path.join(DATA_BASE_PATH, "kernels/kernels.data/Tanimoto_d=8_DrugBankSmallMolMWFilterHuman.data")
default_FileName_DicoMolKernel_indice2instance = os.path.join(DATA_BASE_PATH, "kernels/dict/dico_indice2mol_InMolKernel.data")
default_FileName_DicoMolKernel_instance2indice = os.path.join(DATA_BASE_PATH, "kernels/dict/dico_mol2indice_InMolKernel.data")


def load_dataset(FileName_PositiveInstancesDictionnary=default_FileName_PositiveInstancesDictionnary, FileName_ListProt=default_FileName_ListProt, FileName_ListMol=default_FileName_ListMol, FileName_MolKernel=default_FileName_MolKernel, FileName_DicoMolKernel_indice2instance=default_FileName_DicoMolKernel_indice2instance, FileName_DicoMolKernel_instance2indice=default_FileName_DicoMolKernel_instance2indice):
    """
    Load the dataset and the molecule kernel.

    :param FileName_PositiveInstancesDictionnary: (string) tsv file name: each line corresponds to a molecule;
        1st column: DrugBank ID of the molecule;
        2nd column: number of targets of the molecule;
        remaining columns: UniprotIDs of the molecule's targets (one per column)
    :param FileName_ListProt: (string) txt file name: each line gives the UniprotID of a protein of the dataset
    :param FileName_ListMol: (string) txt file name: each line gives the DrugBankID of a molecule of the dataset
    :param FileName_MolKernel: (string) pickle file name: contains the molecule kernel (np.array)
    :param FileName_DicoMolKernel_indice2instance: (string) pickle file name: dictionary mapping kernel indices
        to their corresponding molecule IDs
    :param FileName_DicoMolKernel_instance2indice: (string) pickle file name: dictionary mapping molecule IDs
        to their indices in the molecule kernel

    :return K_mol: (np.array: number_of_mol^2) molecule kernel
    :return DicoMolKernel_ind2mol: (dict) kernel index -> DrugBank ID
    :return DicoMolKernel_mol2ind: (dict) DrugBank ID -> kernel index
    :return interaction_matrix: (np.array: number_of_mol x number_of_prot) 1.0 where the
        molecule/protein couple interacts, 0.0 otherwise
    """
    # Load the molecule kernel and its associated dictionaries.
    # FIX: files are now closed deterministically via `with` (the originals
    # used open/close without try/finally) and pickle.load replaces the
    # redundant pickle.Unpickler boilerplate.
    # SECURITY NOTE: pickle.load executes arbitrary code — only load kernel
    # files from trusted sources.
    with open(FileName_MolKernel, 'rb') as fichier:
        K_mol = pickle.load(fichier).astype(np.float32)
    with open(FileName_DicoMolKernel_indice2instance, 'rb') as fichier:
        DicoMolKernel_ind2mol = pickle.load(fichier)
    with open(FileName_DicoMolKernel_instance2indice, 'rb') as fichier:
        DicoMolKernel_mol2ind = pickle.load(fichier)

    # Load the protein list of the dataset.
    with open(FileName_ListProt, 'r') as f_in:
        list_prot_of_dataset = [line.rstrip() for line in f_in]

    # Load the molecule list of the dataset.
    with open(FileName_ListMol, 'r') as f_in:
        list_mol_of_dataset = [line.rstrip() for line in f_in]

    # Collect the list of targets per molecule of the dataset.
    dico_targets_per_mol = {mol: [] for mol in list_mol_of_dataset}
    with open(FileName_PositiveInstancesDictionnary, 'r') as f_in:
        reader = csv.reader(f_in, delimiter='\t')
        for row in reader:
            nb_prot = int(row[1])
            for j in range(nb_prot):
                dico_targets_per_mol[row[0]].append(row[2 + j])

    # Build the interaction matrix.
    # FIX: membership is tested against a set (O(1)) instead of a list,
    # turning the original O(mol * prot * targets) loop into O(mol * prot).
    interaction_matrix = np.zeros((len(list_mol_of_dataset), len(list_prot_of_dataset)), dtype=np.float32)
    for i, mol in enumerate(list_mol_of_dataset):
        targets = set(dico_targets_per_mol[mol])
        for j, prot in enumerate(list_prot_of_dataset):
            if prot in targets:
                interaction_matrix[i, j] = 1

    return K_mol, DicoMolKernel_ind2mol, DicoMolKernel_mol2ind, interaction_matrix
Source:config.py  
import os

# Paths for the CoNLL++ dataset files and model checkpoints.
# FIX: os.path.sep.join replaced with the stdlib idiom os.path.join,
# which produces identical paths for these relative components.
DATA_BASE_PATH = "dataset"

# Raw CoNLL++ splits.
FILE_NAME1 = os.path.join(DATA_BASE_PATH, "conllpp_dev.txt")
FILE_NAME2 = os.path.join(DATA_BASE_PATH, "conllpp_test.txt")
FILE_NAME3 = os.path.join(DATA_BASE_PATH, "conllpp_train.txt")

# CSV-converted splits.
NEW_FILE_NAME1 = os.path.join(DATA_BASE_PATH, "conllpp_dev.csv")
NEW_FILE_NAME2 = os.path.join(DATA_BASE_PATH, "conllpp_test.csv")
NEW_FILE_NAME3 = os.path.join(DATA_BASE_PATH, "conllpp_train.csv")

# "Up" variants of the CSV splits (presumably post-processed; confirm against callers).
UP_FILE_NAME1 = os.path.join(DATA_BASE_PATH, "conllpp_up_dev.csv")
UP_FILE_NAME2 = os.path.join(DATA_BASE_PATH, "conllpp_up_test.csv")
UP_FILE_NAME3 = os.path.join(DATA_BASE_PATH, "conllpp_up_train.csv")

CHECKPOINT_PATH = "checkpoint"

# Model checkpoint artifacts.
CHECKPOINT1 = os.path.join(CHECKPOINT_PATH, "checkpoint_saved.pth")
CHECKPOINT2 = os.path.join(CHECKPOINT_PATH, "checkpoint.pth")
CHECKPOINT3 = os.path.join(CHECKPOINT_PATH, "model_scripted.pt")
CHECKPOINT4 = os.path.join(CHECKPOINT_PATH, "model_last.pt")
...Learn to execute automation testing from scratch with the LambdaTest Learning Hub — right from setting up the prerequisites and running your first automation test, through following best practices and diving deeper into advanced test scenarios. The LambdaTest Learning Hub compiles step-by-step guides to help you become proficient with different test automation frameworks, e.g. Selenium, Cypress, and TestNG.
You can also refer to the video tutorials on the LambdaTest YouTube channel to get step-by-step demonstrations from industry experts.
Get 100 automation test minutes FREE!
