Best Python code snippet using yandex-tank
segmenter.py
Source: segmenter.py
1# -*- coding: utf-8 -*-2"""3@author: Yan Shao, yan.shao@lingfil.uu.se4"""5import reader6import toolbox7from model import Model8from transducer_model import Seq2seq9import sys10import argparse11import os12import codecs13import tensorflow as tf14import cPickle as pickle15from time import time16parser = argparse.ArgumentParser(description='A Universal Tokeniser. Written by Y. Shao, Uppsala University')17parser.add_argument('action', default='tag', choices=['train', 'test', 'tag'], help='train, test or tag')18parser.add_argument('-f', '--format', default='conll', help='Data format of different tasks, conll, mlp1 or mlp2')19parser.add_argument('-p', '--path', default=None, help='Path of the workstation')20parser.add_argument('-t', '--train', default=None, help='File for training')21parser.add_argument('-d', '--dev', default=None, help='File for validation')22parser.add_argument('-e', '--test', default=None, help='File for evaluation')23parser.add_argument('-r', '--raw', default=None, help='Raw file for tagging')24parser.add_argument('-m', '--model', default='trained_model', help='Name of the trained model')25parser.add_argument('-crf', '--crf', default=1, type=int, help='Using CRF interface')26parser.add_argument('-bt', '--bucket_size', default=50, type=int, help='Bucket size')27parser.add_argument('-sl', '--sent_limit', default=300, type=int, help='Long sentences will be chopped')28parser.add_argument('-tg', '--tags', default='BIES', help='Boundary Tagging, default is BIES')29parser.add_argument('-ed', '--emb_dimension', default=50, type=int, help='Dimension of the embeddings')30parser.add_argument('-emb', '--embeddings', default=None, help='Path and name of pre-trained char embeddings')31parser.add_argument('-ng', '--ngram', default=1, type=int, help='Using ngrams')32parser.add_argument('-cell', '--cell', default='gru', help='Use GRU as the recurrent cell', choices=['gru', 'lstm'])33parser.add_argument('-rnn', '--rnn_cell_dimension', default=200, type=int, help='Dimension of the RNN cells')34parser.add_argument('-layer', '--rnn_layer_number', default=1, type=int, help='Numbers of the RNN layers')35parser.add_argument('-dr', '--dropout_rate', default=0.5, type=float, help='Dropout rate')36parser.add_argument('-iter', '--epochs', default=30, type=int, help='Numbers of epochs')37parser.add_argument('-iter_trans', '--epochs_trans', default=50, type=int, help='Epochs for training the transducer')38parser.add_argument('-op', '--optimizer', default='adagrad', help='Optimizer')39parser.add_argument('-lr', '--learning_rate', default=0.2, type=float, help='Initial learning rate')40parser.add_argument('-lr_trans', '--learning_rate_trans', default=0.3, type=float, help='Initial learning rate')41parser.add_argument('-ld', '--decay_rate', default=0.05, type=float, help='Learning rate decay')42parser.add_argument('-mt', '--momentum', default=None, type=float, help='Momentum')43parser.add_argument('-ncp', '--no_clipping', default=False, action='store_true', help='Do not apply gradient clipping')44parser.add_argument("-tb","--train_batch", help="Training batch size", default=10, type=int)45parser.add_argument("-eb","--test_batch", help="Testing batch size", default=500, type=int)46parser.add_argument("-rb","--tag_batch", help="Tagging batch size", default=500, type=int)47parser.add_argument("-g","--gpu", help="the id of gpu, the default is 0", default=0, type=int)48parser.add_argument('-opth', '--output_path', default=None, help='Output path')49parser.add_argument('-sea', '--sea', help='Process languages 
like Vietamese', default=False, action='store_true')50parser.add_argument('-ss', '--sent_seg', help='Perform sentence seg', default=False, action='store_true')51parser.add_argument('-ens', '--ensemble', default=False, help='Ensemble several weights', action='store_true')52parser.add_argument('-sgl', '--segment_large', default=False, help='Segment (very) large file', action='store_true')53parser.add_argument('-lgs', '--large_size', default=10000, type=int, help='Segment (very) large file')54parser.add_argument('-ot', '--only_tokenised', default=False,55                    help='Only output the tokenised file when segment (very) large file', action='store_true')56parser.add_argument('-ts', '--train_size', default=-1, type=int, help='No. of sentences used for training')57parser.add_argument('-rs', '--reset', default=False, help='Delete and re-initialise the intermediate files',58                    action='store_true')59parser.add_argument('-rst', '--reset_trans', default=False, help='Retrain the transducers', action='store_true')60parser.add_argument('-isp', '--ignore_space', default=False, help='Ignore space delimiters', action='store_true')61parser.add_argument('-imt', '--ignore_mwt', default=False, help='Ignore multi-word tokens to be transcribed',62                    action='store_true')63parser.add_argument('-sb', '--segmentation_bias', default=-1, type=float,64                    help='Add segmentation bias to under(over)-splitting')65parser.add_argument('-tt', '--transduction_type', default='mix', choices=['mix', 'dict', 'trans', 'none'],66                    help='Different ways of transducing the non-segmental MWTs')67args = parser.parse_args()68sys = reload(sys)69sys.setdefaultencoding('utf-8')70print 'Encoding: ', sys.getdefaultencoding()71if args.action == 'train':72    assert args.path is not None73    path = args.path74    train_file = args.train75    dev_file = args.dev76    model_file = args.model77    print 'Reading data......'78    f_names = os.listdir(path)79    if train_file is None or dev_file is None:80        for f_n in f_names:81            if 'ud-train.conllu' in f_n or 'training.segd' in f_n or 'ud-sample.conllu' in f_n:82                train_file = f_n83            elif 'ud-dev.conllu' in f_n or 'development.segd' in f_n:84                dev_file = f_n85    assert train_file is not None86    is_space = True87    if 'Chinese' in path or 'Japanese' in path or args.format == 'mlp2':88        is_space = False89    if args.sea:90        is_space = 'sea'91    if args.reset or not os.path.isfile(path + '/raw_train.txt') or not os.path.isfile(path + '/raw_dev.txt'):92        cat = 'other'93        if 'Chinese' in path or 'Japanese' in path:94            cat = 'zh'95        for line in codecs.open(path + '/' + train_file, 'r', encoding='utf-8'):96            if len(line) < 2:97                break98            if '# sentence' in line or '# text' in line:99                cat = 'gold'100        if dev_file is None:101            reader.get_raw(path, train_file, '/raw_train.txt', cat, is_dev=False, form=args.format, is_space=is_space)102        else:103            reader.get_raw(path, train_file, '/raw_train.txt', cat, form=args.format, is_space=is_space)104            reader.get_raw(path, dev_file, '/raw_dev.txt', cat, form=args.format, is_space=is_space)105    if args.reset or not os.path.isfile(path + '/tag_train.txt') or not os.path.isfile(path + '/tag_dev.txt') or \106            not os.path.isfile(path + '/tag_dev_gold.txt'):107        if dev_file is None:108        
    raws_train = reader.raw(path + '/raw_train.txt')109            raws_dev = reader.raw(path + '/raw_dev.txt')110            sents_train, sents_dev = reader.gold(path + '/' + train_file, False, form=args.format, is_space=is_space)111        else:112            raws_train = reader.raw(path + '/raw_train.txt')113            sents_train = reader.gold(path + '/' + train_file, form=args.format, is_space=is_space)114            raws_dev = reader.raw(path + '/raw_dev.txt')115            sents_dev = reader.gold(path + '/' + dev_file, form=args.format, is_space=is_space)116        if is_space != 'sea':117            toolbox.raw2tags(raws_train, sents_train, path, 'tag_train.txt', ignore_space=args.ignore_space,118                             reset=args.reset, tag_scheme=args.tags, ignore_mwt=args.ignore_mwt)119            toolbox.raw2tags(raws_dev, sents_dev, path, 'tag_dev.txt', creat_dict=False, gold_path='tag_dev_gold.txt',120                             ignore_space=args.ignore_space, tag_scheme=args.tags, ignore_mwt=args.ignore_mwt)121        else:122            toolbox.raw2tags_sea(raws_train, sents_train, path, 'tag_train.txt', reset=args.reset, tag_scheme=args.tags)123            toolbox.raw2tags_sea(raws_dev, sents_dev, path, 'tag_dev.txt', gold_path='tag_dev_gold.txt',124                                 tag_scheme=args.tags)125    if args.reset or not os.path.isfile(path + '/chars.txt'):126        toolbox.get_chars(path, ['raw_train.txt', 'raw_dev.txt'], sea=is_space)127    char2idx, unk_chars_idx, idx2char, tag2idx, idx2tag, trans_dict = toolbox.get_dicts(path, args.sent_seg, args.tags,128                                                                                        args.crf)129    if args.embeddings is not None:130        print 'Reading embeddings...'131        short_emb = args.embeddings[args.embeddings.index('/') + 1: args.embeddings.index('.')]132        if args.reset or not os.path.isfile(path + '/' + short_emb + '_sub.txt'):133            toolbox.get_sample_embedding(path, args.embeddings, char2idx)134        emb_dim, emb, valid_chars = toolbox.read_sample_embedding(path, short_emb, char2idx)135        for vch in valid_chars:136            if char2idx[vch] in unk_chars_idx:137                unk_chars_idx.remove(char2idx[vch])138    else:139        emb_dim = args.emb_dimension140        emb = None141    train_x, train_y, max_len_train = toolbox.get_input_vec(path, 'tag_train.txt', char2idx, tag2idx,142                                                            limit=args.sent_limit, sent_seg=args.sent_seg,143                                                            is_space=is_space, train_size=args.train_size,144                                                            ignore_space=args.ignore_space)145    dev_x, max_len_dev = toolbox.get_input_vec_raw(path, 'raw_dev.txt', char2idx, limit=args.sent_limit,146                                                   sent_seg=args.sent_seg, is_space=is_space,147                                                   ignore_space=args.ignore_space)148    if args.sent_seg:149        print 'Joint sentence segmentation...'150    else:151        print 'Training set: %d instances; Dev set: %d instances.' 
% (len(train_x[0]), len(dev_x[0]))152    nums_grams = None153    ng_embs = None154    if args.ngram > 1 and (args.reset or  not os.path.isfile(path + '/' + str(args.ngram) + 'gram.txt')):155        toolbox.get_ngrams(path, args.ngram, is_space)156    ngram = toolbox.read_ngrams(path, args.ngram)157    if args.ngram > 1:158        gram2idx = toolbox.get_ngram_dic(ngram)159        train_gram = toolbox.get_gram_vec(path, 'tag_train.txt', gram2idx, limit=args.sent_limit,sent_seg=args.sent_seg,160                                          is_space=is_space, ignore_space=args.ignore_space)161        dev_gram = toolbox.get_gram_vec(path, 'raw_dev.txt', gram2idx, is_raw=True, limit=args.sent_limit,162                                        sent_seg=args.sent_seg, is_space=is_space, ignore_space=args.ignore_space)163        train_x += train_gram164        dev_x += dev_gram165        nums_grams = []166        for dic in gram2idx:167            nums_grams.append(len(dic.keys()))168    max_len = max(max_len_train, max_len_dev)169    b_train_x, b_train_y = toolbox.buckets(train_x, train_y, size=args.bucket_size)170    b_train_x, b_train_y, b_lens, b_count = toolbox.pad_bucket(b_train_x, b_train_y, max_len)171    b_dev_x = [toolbox.pad_zeros(dev_x_i, max_len) for dev_x_i in dev_x]172    b_dev_y_gold = [line.strip() for line in codecs.open(path + '/tag_dev_gold.txt', 'r', encoding='utf-8')]173    nums_tag = len(tag2idx)174    config = tf.ConfigProto(allow_soft_placement=True)175    gpu_config = "/gpu:" + str(args.gpu)176    transducer = None177    transducer_graph = None178    trans_model = None179    trans_init = None180    if len(trans_dict) > 200 and not args.ignore_mwt:181        transducer = toolbox.get_dict_vec(trans_dict, char2idx)182    t = time()183    initializer = tf.contrib.layers.xavier_initializer()184    if transducer is not None:185        transducer_graph = tf.Graph()186        with transducer_graph.as_default():187            with tf.variable_scope("transducer") as scope:188                trans_model = Seq2seq(path + '/' + model_file + '_transducer')189                print 'Defining transducer...'190                trans_model.define(char_num=len(char2idx), rnn_dim=args.rnn_cell_dimension, emb_dim=args.emb_dimension,191                                   max_x=len(transducer[0][0]), max_y=len(transducer[1][0]))192            trans_init = tf.global_variables_initializer()193        transducer_graph.finalize()194    print 'Initialization....'195    main_graph = tf.Graph()196    with main_graph.as_default():197        with tf.variable_scope("tagger") as scope:198            model = Model(nums_chars=len(char2idx) + 2, nums_tags=nums_tag, buckets_char=b_lens, counts=b_count,199                          crf=args.crf, ngram=nums_grams, batch_size=args.train_batch, sent_seg=args.sent_seg,200                          is_space=is_space, emb_path=args.embeddings, tag_scheme=args.tags)201            model.main_graph(trained_model=path + '/' + model_file + '_model', scope=scope,202                             emb_dim=emb_dim, cell=args.cell, rnn_dim=args.rnn_cell_dimension,203                             rnn_num=args.rnn_layer_number, drop_out=args.dropout_rate, emb=emb)204            t = time()205        model.config(optimizer=args.optimizer, decay=args.decay_rate, lr_v=args.learning_rate,206                     momentum=args.momentum, clipping=not args.no_clipping)207        init = tf.global_variables_initializer()208        print 'Done. 
Time consumed: %d seconds' % int(time() - t)209    main_graph.finalize()210    main_sess = tf.Session(config=config, graph=main_graph)211    if args.crf > 0:212        decode_graph = tf.Graph()213        with decode_graph.as_default():214            model.decode_graph()215        decode_graph.finalize()216        decode_sess = tf.Session(config=config, graph=decode_graph)217        sess = [main_sess, decode_sess]218    else:219        sess = [main_sess, None]220    with tf.device(gpu_config):221        if transducer is not None:222            print 'Building transducer...'223            t = time()224            trans_sess = tf.Session(config=config, graph=transducer_graph)225            trans_sess.run(trans_init)226            trans_model.train(transducer[0], transducer[1], transducer[2], transducer[3], args.learning_rate_trans,227                              char2idx, trans_sess, args.epochs_trans, batch_size=10, reset=args.reset_trans)228            sess.append(trans_sess)229            print 'Done. Time consumed: %d seconds' % int(time() - t)230            print 'Training the main segmenter..'231        main_sess.run(init)232        print 'Initialisation...'233        print 'Done. Time consumed: %d seconds' % int(time() - t)234        t = time()235        b_dev_raw = [line.strip() for line in codecs.open(path + '/raw_dev.txt', 'r', encoding='utf-8')]236        model.train(b_train_x, b_train_y, b_dev_x, b_dev_raw, b_dev_y_gold, idx2tag, idx2char, unk_chars_idx, trans_dict,237                    sess, args.epochs, path + '/' + model_file + '_weights', transducer=trans_model,238                    lr=args.learning_rate, decay=args.decay_rate, sent_seg=args.sent_seg, outpath=args.output_path)239else:240    assert args.path is not None241    assert args.model is not None242    path = args.path243    assert os.path.isfile(path + '/chars.txt')244    model_file = args.model245    if args.ensemble:246        if not os.path.isfile(path + '/' + model_file + '_1_model') or not os.path.isfile(path + '/' + model_file +247                                                                                          '_1_weights.index'):248            raise Exception('Not any model file or weights file under the name of ' + model_file + '.')249        fin = open(path + '/' + model_file + '_1_model', 'rb')250    else:251        if not os.path.isfile(path + '/' + model_file + '_model') or not os.path.isfile(path + '/' + model_file +252                                                                                        '_weights.index'):253            raise Exception('No model file or weights file under the name of ' + model_file + '.')254        fin = open(path + '/' + model_file + '_model', 'rb')255    weight_path = path + '/' + model_file256    param_dic = pickle.load(fin)257    fin.close()258    nums_chars = param_dic['nums_chars']259    nums_tags = param_dic['nums_tags']260    crf = param_dic['crf']261    emb_dim = param_dic['emb_dim']262    cell = param_dic['cell']263    rnn_dim = param_dic['rnn_dim']264    rnn_num = param_dic['rnn_num']265    drop_out = param_dic['drop_out']266    buckets_char = param_dic['buckets_char']267    nums_ngrams = param_dic['ngram']268    is_space = param_dic['is_space']269    sent_seg = param_dic['sent_seg']270    emb_path = param_dic['emb_path']271    tag_scheme = param_dic['tag_scheme']272    if args.embeddings is not None:273        emb_path = args.embeddings274    ngram = 1275    grams, gram2idx = None, None276    if nums_ngrams is not None:277        ngram = 
len(nums_ngrams) + 1278    char2idx, unk_chars_idx, idx2char, tag2idx, idx2tag, trans_dict = toolbox.get_dicts(path, sent_seg, tag_scheme, crf)279    trans_char_num = len(char2idx)280    if ngram > 1:281        grams = toolbox.read_ngrams(path, ngram)282    new_chars, new_grams = None, None283    test_x, test_y, raw_x, test_y_gold = None, None, None, None284    sub_dict = None285    max_step = None286    raw_file = None287    if args.action == 'test':288        test_file = args.test289        f_names = os.listdir(path)290        if test_file is None:291            for f_n in f_names:292                if 'ud-test.conllu' in f_n:293                    test_file = f_n294        assert test_file is not None295        cat = 'other'296        if 'Chinese' in path or 'Japanese' in path:297            cat = 'zh'298        for line in codecs.open(path + '/' + test_file, 'r', encoding='utf-8'):299            if len(line) < 2:300                break301            if '# sentence' in line or '# text' in line:302                cat = 'gold'303        reader.get_raw(path, test_file, 'raw_test.txt', cat, form=args.format)304        raws_test = reader.raw(path + '/raw_test.txt')305        test_y_gold = reader.test_gold(path + '/' + test_file, form=args.format, is_space=is_space,306                                       ignore_mwt=args.ignore_mwt)307        new_chars = toolbox.get_new_chars(path + '/raw_test.txt', char2idx, is_space)308        if emb_path is not None:309            valid_chars = toolbox.get_valid_chars(new_chars + char2idx.keys(), emb_path)310        else:311            valid_chars = None312        char2idx, idx2char, unk_chars_idx, sub_dict = toolbox.update_char_dict(char2idx, new_chars, unk_chars_idx, valid_chars)313        test_x, max_len_test = toolbox.get_input_vec_raw(path, 'raw_test.txt', char2idx, limit=args.sent_limit + 100,314                                                         sent_seg=sent_seg, is_space=is_space,315                                                         ignore_space=args.ignore_space)316        max_step = max_len_test317        if sent_seg:318            print 'Joint sentence segmentation...'319        else:320            print 'Test set: %d instances.' 
% len(test_x[0])321        if ngram > 1:322            gram2idx = toolbox.get_ngram_dic(grams)323            new_grams = toolbox.get_new_grams(path + '/' + test_file, gram2idx, is_space=is_space)324            test_grams = toolbox.get_gram_vec(path, 'raw_test.txt', gram2idx, is_raw=True, limit=args.sent_limit + 100,325                                              sent_seg=sent_seg, is_space=is_space, ignore_space=args.ignore_space)326            test_x += test_grams327        for k in range(len(test_x)):328            test_x[k] = toolbox.pad_zeros(test_x[k], max_step)329    elif args.action == 'tag':330        assert args.raw is not None331        raw_file = args.raw332        new_chars = toolbox.get_new_chars(raw_file, char2idx, is_space)333        if emb_path is not None:334            valid_chars = toolbox.get_valid_chars(new_chars, emb_path)335        else:336            valid_chars = None337        char2idx, idx2char, unk_chars_idx, sub_dict = toolbox.update_char_dict(char2idx, new_chars, unk_chars_idx,338                                                                               valid_chars)339        if not args.segment_large:340            if sent_seg:341                raw_x, raw_len = toolbox.get_input_vec_tag(None, raw_file, char2idx, limit=args.sent_limit + 100,342                                                           is_space=is_space)343            else:344                raw_x, raw_len = toolbox.get_input_vec_raw(None, raw_file, char2idx, limit=args.sent_limit + 100,345                                                           sent_seg=sent_seg, is_space=is_space)346            if sent_seg:347                print 'Joint sentence segmentation...'348            else:349                print 'Raw setences: %d instances.' % len(raw_x[0])350            max_step = raw_len351        else:352            max_step = args.sent_limit353        if ngram > 1:354            gram2idx = toolbox.get_ngram_dic(grams)355            new_grams = toolbox.get_new_grams(raw_file, gram2idx, is_raw=True, is_space=is_space)356            if not args.segment_large:357                if sent_seg:358                    raw_grams = toolbox.get_gram_vec_tag(None, raw_file, gram2idx, limit=args.sent_limit + 100,359                                                         is_space=is_space)360                else:361                    raw_grams = toolbox.get_gram_vec(None, raw_file, gram2idx, is_raw=True, limit=args.sent_limit + 100,362                                                     sent_seg=sent_seg, is_space=is_space)363                raw_x += raw_grams364        if not args.segment_large:365            for k in range(len(raw_x)):366                raw_x[k] = toolbox.pad_zeros(raw_x[k], max_step)367    config = tf.ConfigProto(allow_soft_placement=True)368    gpu_config = "/gpu:" + str(args.gpu)369    transducer = None370    transducer_graph = None371    trans_model = None372    trans_init = None373    if len(trans_dict) > 200:374        transducer = toolbox.get_dict_vec(trans_dict, char2idx)375    t = time()376    initializer = tf.contrib.layers.xavier_initializer()377    if transducer is not None:378        transducer_graph = tf.Graph()379        with transducer_graph.as_default():380            with tf.variable_scope("transducer") as scope:381                trans_model = Seq2seq(path + '/' + model_file + '_transducer')382                trans_fin = open(path + '/' + model_file + '_transducer_model', 'rb')383                trans_param_dic = pickle.load(trans_fin)384                
trans_fin.close()385                tr_char_num = trans_param_dic['char_num']386                tr_rnn_dim = trans_param_dic['rnn_dim']387                tr_emb_dim = trans_param_dic['emb_dim']388                tr_max_x = trans_param_dic['max_x']389                tr_max_y = trans_param_dic['max_y']390                print 'Defining transducer...'391                trans_model.define(char_num=tr_char_num, rnn_dim=tr_rnn_dim, emb_dim=tr_emb_dim,392                                   max_x=tr_max_x, max_y=tr_max_y, write_trans_model=False)393            trans_init = tf.global_variables_initializer()394        transducer_graph.finalize()395    print 'Initialization....'396    main_graph = tf.Graph()397    with main_graph.as_default():398        with tf.variable_scope("tagger") as scope:399            model = Model(nums_chars=nums_chars, nums_tags=nums_tags, buckets_char=[max_step], counts=[200],400                          crf=crf, ngram=nums_ngrams, batch_size=args.tag_batch, is_space=is_space)401            model.main_graph(trained_model=None, scope=scope, emb_dim=emb_dim, cell=cell,402                             rnn_dim=rnn_dim, rnn_num=rnn_num, drop_out=drop_out)403        model.define_updates(new_chars=new_chars, emb_path=emb_path, char2idx=char2idx)404        init = tf.global_variables_initializer()405        print 'Done. Time consumed: %d seconds' % int(time() - t)406    main_graph.finalize()407    idx=None408    if args.ensemble:409        idx = 1410        main_sess = []411        while os.path.isfile(path + '/' + model_file + '_' + str(idx) + '_weights.index'):412            main_sess.append(tf.Session(config=config, graph=main_graph))413            idx += 1414    else:415        main_sess = tf.Session(config=config, graph=main_graph)416    if crf:417        decode_graph = tf.Graph()418        with decode_graph.as_default():419            model.decode_graph()420        decode_graph.finalize()421        decode_sess = tf.Session(config=config, graph=decode_graph)422        sess = [main_sess, decode_sess]423    else:424        sess = [main_sess, None]425    with tf.device(gpu_config):426        ens_model = None427        print 'Loading weights....'428        if args.ensemble:429            for i in range(1, idx):430                print 'Ensemble: ' + str(i)431                main_sess[i - 1].run(init)432                model.run_updates(main_sess[i - 1], weight_path + '_' + str(i) + '_weights')433        else:434            main_sess.run(init)435            model.run_updates(main_sess, weight_path + '_weights')436        if transducer is not None:437            print 'Loading transducer...'438            t = time()439            trans_sess = tf.Session(config=config, graph=transducer_graph)440            trans_sess.run(trans_init)441            if os.path.isfile(path + '/' + model_file + '_transducer_weights'):442                trans_weight_path = path + '/' + model_file + '_transducer_weights'443                trans_weight_path = trans_weight_path.replace('//', '/')444                trans_model.saver.restore(trans_sess, trans_weight_path)445            sess.append(trans_sess)446        if args.action == 'test':447            test_y_raw = [line.strip() for line in codecs.open(path + '/raw_test.txt', 'rb', encoding='utf-8')]448            model.test(test_x, test_y_raw, test_y_gold, idx2tag, idx2char, unk_chars_idx, sub_dict, trans_dict, sess,449                       transducer=trans_model, ensemble=args.ensemble, batch_size=args.test_batch, sent_seg=sent_seg,450                    
   bias=args.segmentation_bias, outpath=args.output_path, trans_type=args.transduction_type)451        if args.action == 'tag':452            if not args.segment_large:453                raw_sents = []454                for line in codecs.open(raw_file, 'rb', encoding='utf-8'):455                    line = line.strip()456                    if len(line) > 0:457                        raw_sents.append(line)458                model.tag(raw_x, raw_sents, idx2tag, idx2char, unk_chars_idx, sub_dict, trans_dict, sess,459                          transducer=trans_model, outpath=args.output_path, ensemble=args.ensemble,460                          batch_size=args.tag_batch, sent_seg=sent_seg, seg_large=args.segment_large, form=args.format)461            else:462                count = 0463                c_line = 0464                l_writer = codecs.open(args.output_path, 'w', encoding='utf-8')465                out = []466                with codecs.open(raw_file, 'r', encoding='utf-8') as l_file:467                    lines = []468                    for line in l_file:469                        line = line.strip()470                        if len(line) > 0:471                            lines.append(line)472                        else:473                            c_line += 1474                        if c_line >= args.large_size:475                            count += len(lines)476                            c_line = 0477                            print count478                            if args.sent_seg:479                                raw_x, _ = toolbox.get_input_vec_tag(None, None, char2idx, lines=lines,480                                                                     limit=args.sent_limit, is_space=is_space)481                            else:482                                raw_x, _ = toolbox.get_input_vec_raw(None, None, char2idx, lines=lines,483                                                                     limit=args.sent_limit, sent_seg=sent_seg,484                                                                     is_space=is_space)485                            if ngram > 1:486                                if sent_seg:487                                    raw_grams = toolbox.get_gram_vec_tag(None, None, gram2idx, lines=lines,488                                                                         limit=args.sent_limit, is_space=is_space)489                                else:490                                    raw_grams = toolbox.get_gram_vec(None, None, gram2idx, lines=lines, is_raw=True,491                                                                     limit=args.sent_limit, sent_seg=sent_seg,492                                                                     is_space=is_space)493                                raw_x += raw_grams494                            for k in range(len(raw_x)):495                                raw_x[k] = toolbox.pad_zeros(raw_x[k], max_step)496                            predition, multi = model.tag(raw_x, lines, idx2tag, idx2char, unk_chars_idx, sub_dict,497                                                         trans_dict, sess, transducer=trans_model,498                                                         outpath=args.output_path, ensemble=args.ensemble,499                                                         batch_size=args.tag_batch, sent_seg=sent_seg,500                                                         seg_large=args.segment_large, form=args.format)501                            if 
args.only_tokenised:502                                for l_out in predition:503                                    if len(l_out.strip()) > 0:504                                        l_writer.write(l_out + '\n')505                            else:506                                for tagged_t, multi_t in zip(predition, multi):507                                    if len(tagged_t.strip()) > 0:508                                        l_writer.write('#sent_tok: ' + tagged_t + '\n')509                                        idx = 1510                                        tgs = multi_t.split('  ')511                                        pl = ''512                                        for _ in range(8):513                                            pl += '\t' + '_'514                                        for tg in tgs:515                                            if '!#!' in tg:516                                                segs = tg.split('!#!')517                                                l_writer.write(str(idx) + '-' + str(int(segs[1]) + idx - 1) + '\t' +518                                                               segs[0] + pl + '\n')519                                            else:520                                                l_writer.write(str(idx) + '\t' + tg + pl + '\n')521                                                idx += 1522                                        l_writer.write('\n')523                            lines = []524                    if len(lines) > 0:525                        if args.sent_seg:526                            raw_x, _ = toolbox.get_input_vec_tag(None, None, char2idx, lines=lines,527                                                                      limit=args.sent_limit, is_space=is_space)528                        else:529                            raw_x, _ = toolbox.get_input_vec_raw(None, None, char2idx, lines=lines,530                                                                      limit=args.sent_limit, sent_seg=sent_seg,531                                                                      is_space=is_space)532                        if ngram > 1:533                            if sent_seg:534                                raw_grams = toolbox.get_gram_vec_tag(None, None, gram2idx, lines=lines,535                                                                     limit=args.sent_limit, is_space=is_space)536                            else:537                                raw_grams = toolbox.get_gram_vec(None, None, gram2idx, lines=lines, is_raw=True,538                                                                 limit=args.sent_limit, sent_seg=sent_seg,539                                                                 is_space=is_space)540                            raw_x += raw_grams541                        for k in range(len(raw_x)):542                            raw_x[k] = toolbox.pad_zeros(raw_x[k], max_step)543                        prediction, multi = model.tag(raw_x, lines, idx2tag, idx2char, unk_chars_idx, sub_dict,544                                                      trans_dict, sess, transducer=trans_model,545                                                      outpath=args.output_path, ensemble=args.ensemble,546                                                      batch_size=args.tag_batch, sent_seg=sent_seg,547                                                      seg_large=args.segment_large, form=args.format)548                        if args.only_tokenised:549               
             for l_out in prediction:550                                if len(l_out.strip()) > 0:551                                    l_writer.write(l_out + '\n')552                        else:553                            for tagged_t, multi_t in zip(prediction, multi):554                                if len(tagged_t.strip()) > 0:555                                    l_writer.write('#sent_tok: ' + tagged_t + '\n')556                                    idx = 1557                                    tgs = multi_t.split('  ')558                                    pl = ''559                                    for _ in range(8):560                                        pl += '\t' + '_'561                                    for tg in tgs:562                                        if '!#!' in tg:563                                            segs = tg.split('!#!')564                                            l_writer.write(str(idx) + '-' + str(int(segs[1]) + idx - 1) + '\t' +565                                                           segs[0] + pl + '\n')566                                        else:567                                            l_writer.write(str(idx) + '\t' + tg + pl + '\n')568                                            idx += 1569                                    l_writer.write('\n')570                l_writer.close()...test_space_attachment.py
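The script is configured entirely through the argparse flags defined at the top. As a rough usage sketch (the treebank directory UD_English and the file names here are hypothetical, not taken from the listing), training and tagging would look something like:

python segmenter.py train -p UD_English -t en-ud-train.conllu -d en-ud-dev.conllu -m trained_model
python segmenter.py tag -p UD_English -m trained_model -r raw_input.txt -opth tagged_output.conllu

In train mode the script can also auto-detect files matching ud-train.conllu / ud-dev.conllu inside --path when -t/-d are omitted; tag and test mode require the chars.txt, *_model and *_weights files that a training run writes into the same directory.

test_space_attachment.py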
Source: test_space_attachment.py
# coding: utf-8
from __future__ import unicode_literals

import pytest

from spacy.tokens.doc import Doc

from ..util import get_doc, apply_transition_sequence


def test_parser_space_attachment(en_tokenizer):
    text = "This is a test.\nTo ensure  spaces are attached well."
    heads = [1, 0, 1, -2, -3, -1, 1, 4, -1, 2, 1, 0, -1, -2]
    tokens = en_tokenizer(text)
    doc = get_doc(tokens.vocab, words=[t.text for t in tokens], heads=heads)
    for sent in doc.sents:
        if len(sent) == 1:
            assert not sent[-1].is_space


def test_parser_sentence_space(en_tokenizer):
    # fmt: off
    text = "I look forward to using Thingamajig.  I've been told it will make my life easier..."
    heads = [1, 0, -1, -2, -1, -1, -5, -1, 3, 2, 1, 0, 2, 1, -3, 1, 1, -3, -7]
    deps = ["nsubj", "ROOT", "advmod", "prep", "pcomp", "dobj", "punct", "",
            "nsubjpass", "aux", "auxpass", "ROOT", "nsubj", "aux", "ccomp",
            "poss", "nsubj", "ccomp", "punct"]
    # fmt: on
    tokens = en_tokenizer(text)
    doc = get_doc(tokens.vocab, words=[t.text for t in tokens], heads=heads, deps=deps)
    assert len(list(doc.sents)) == 2


@pytest.mark.xfail
def test_parser_space_attachment_leading(en_tokenizer, en_parser):
    text = "\t \n This is a sentence ."
    heads = [1, 1, 0, 1, -2, -3]
    tokens = en_tokenizer(text)
    doc = get_doc(tokens.vocab, words=text.split(" "), heads=heads)
    assert doc[0].is_space
    assert doc[1].is_space
    assert doc[2].text == "This"
    with en_parser.step_through(doc) as stepwise:
        pass
    assert doc[0].head.i == 2
    assert doc[1].head.i == 2
    assert stepwise.stack == set([2])


@pytest.mark.xfail
def test_parser_space_attachment_intermediate_trailing(en_tokenizer, en_parser):
    text = "This is \t a \t\n \n sentence . \n\n \n"
    heads = [1, 0, -1, 2, -1, -4, -5, -1]
    transition = ["L-nsubj", "S", "L-det", "R-attr", "D", "R-punct"]
    tokens = en_tokenizer(text)
    doc = get_doc(tokens.vocab, words=text.split(" "), heads=heads)
    assert doc[2].is_space
    assert doc[4].is_space
    assert doc[5].is_space
    assert doc[8].is_space
    assert doc[9].is_space
    apply_transition_sequence(en_parser, doc, transition)
    for token in doc:
        assert token.dep != 0 or token.is_space
    assert [token.head.i for token in doc] == [1, 1, 1, 6, 3, 3, 1, 1, 7, 7]


@pytest.mark.parametrize("text,length", [(["\n"], 1), (["\n", "\t", "\n\n", "\t"], 4)])
@pytest.mark.xfail
def test_parser_space_attachment_space(en_tokenizer, en_parser, text, length):
    doc = Doc(en_parser.vocab, words=text)
    assert len(doc) == length
    with en_parser.step_through(doc) as _:  # noqa: F841
        pass
    assert doc[0].is_space
    for token in doc:
        ...
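These tests come from spaCy's regression suite for whitespace attachment in the dependency parser (the last one is truncated in the source). They rely on the en_tokenizer and en_parser pytest fixtures and on the get_doc / apply_transition_sequence helpers from the suite's shared util module (note the relative import), so they have to be run from inside a spaCy source checkout, for example (the exact path is an assumption based on the relative import):

python -m pytest spacy/tests/parser/test_space_attachment.py -v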
