How to use the command_eval method in play_selenium

Best Python code snippets using play_selenium_python

eval_tools.py

Source: eval_tools.py (GitHub)



1"""2Evaluate system performance.3"""4from __future__ import absolute_import5from __future__ import division6from __future__ import print_function7import collections8import csv9import nltk10import numpy as np11import os12import sys13import random14if sys.version_info > (3, 0):15 from six.moves import xrange16from ..bashlint import data_tools17from ..encoder_decoder import data_utils, graph_utils18from ..eval import token_based, tree_dist19from ..nlp_tools import constants, tokenizer20def manual_eval(prediction_path, dataset, FLAGS, top_k, num_examples=-1, interactive=True, verbose=True):21 """22 Conduct dev/test set evaluation.23 Evaluation metrics:24 1) full command accuracy;25 2) command template accuracy. 26 :param interactive:27 - If set, prompt the user to enter judgement if a prediction does not28 match any of the groundtruths and the correctness of the prediction29 has not been pre-determined;30 Otherwise, all predictions that does not match any of the groundtruths are counted as wrong.31 """32 # Group dataset33 grouped_dataset = data_utils.group_parallel_data(dataset)34 # Load model prediction35 prediction_list = load_predictions(prediction_path, top_k)36 metrics = get_manual_evaluation_metrics(37 grouped_dataset, prediction_list, FLAGS, num_examples=num_examples, interactive=interactive, verbose=verbose)38 return metrics39def gen_manual_evaluation_table(dataset, FLAGS, num_examples=-1, interactive=True):40 """41 Conduct dev/test set evaluation. The results of multiple pre-specified models are tabulated in the same table.42 Evaluation metrics:43 1) full command accuracy;44 2) command template accuracy.45 :param interactive:46 - If set, prompt the user to enter judgement if a prediction does not47 match any of the groundtruths and the correctness of the prediction48 has not been pre-determined;49 Otherwise, all predictions that does not match any of the groundtruths are counted as wrong.50 """51 # Group dataset52 grouped_dataset = data_utils.group_parallel_data(dataset)53 # Load all model predictions54 model_names, model_predictions = load_all_model_predictions(55 grouped_dataset, FLAGS, top_k=3)56 manual_eval_metrics = {}57 for model_id, model_name in enumerate(model_names):58 prediction_list = model_predictions[model_names]59 M = get_manual_evaluation_metrics(60 grouped_dataset, prediction_list, FLAGS, num_examples=num_examples, interactive=interactive, verbose=False)61 manual_eval_metrics[model_name] = [M['acc_f'][0],62 M['acc_f'[1]], M['acc_t'][0], M['acc_t'][1]]63 metrics_names = ['Acc_F_1', 'Acc_F_3', 'Acc_T_1', 'Acc_T_3']64 print_eval_table(model_names, metrics_names, manual_eval_metrics)65def get_manual_evaluation_metrics(grouped_dataset, prediction_list, FLAGS, num_examples=-1, interactive=True,66 verbose=True):67 if len(grouped_dataset) != len(prediction_list):68 raise ValueError("ground truth and predictions length must be equal: "69 "{} vs. 
{}".format(len(grouped_dataset), len(prediction_list)))70 # Get dev set samples (fixed)71 random.seed(100)72 example_ids = list(range(len(grouped_dataset)))73 random.shuffle(example_ids)74 if num_examples > 0:75 sample_ids = example_ids[:num_examples]76 else:77 sample_ids = example_ids78 # Load cached evaluation results79 structure_eval_cache, command_eval_cache = \80 load_cached_evaluations(81 os.path.join(FLAGS.data_dir, 'manual_judgements'), verbose=True)82 eval_bash = FLAGS.dataset.startswith("bash")83 cmd_parser = data_tools.bash_parser if eval_bash \84 else data_tools.paren_parser85 # Interactive manual evaluation86 num_t_top_1_correct = 0.087 num_f_top_1_correct = 0.088 num_t_top_3_correct = 0.089 num_f_top_3_correct = 0.090 for exam_id, example_id in enumerate(sample_ids):91 data_group = grouped_dataset[example_id][1]92 sc_txt = data_group[0].sc_txt.strip()93 sc_key = get_example_nl_key(sc_txt)94 command_gts = [dp.tg_txt for dp in data_group]95 command_gt_asts = [data_tools.bash_parser(gt) for gt in command_gts]96 predictions = prediction_list[example_id]97 top_3_s_correct_marked = False98 top_3_f_correct_marked = False99 for i in xrange(min(3, len(predictions))):100 pred_cmd = predictions[i]101 pred_ast = cmd_parser(pred_cmd)102 pred_temp = data_tools.ast2template(103 pred_ast, loose_constraints=True)104 temp_match = tree_dist.one_match(105 command_gt_asts, pred_ast, ignore_arg_value=True)106 str_match = tree_dist.one_match(107 command_gt_asts, pred_ast, ignore_arg_value=False)108 # Match ground truths & exisitng judgements109 command_example_key = '{}<NL_PREDICTION>{}'.format(110 sc_key, pred_cmd)111 structure_example_key = '{}<NL_PREDICTION>{}'.format(112 sc_key, pred_temp)113 command_eval, structure_eval = '', ''114 if str_match:115 command_eval = 'y'116 structure_eval = 'y'117 elif temp_match:118 structure_eval = 'y'119 if command_eval_cache and command_example_key in command_eval_cache:120 command_eval = command_eval_cache[command_example_key]121 if structure_eval_cache and structure_example_key in structure_eval_cache:122 structure_eval = structure_eval_cache[structure_example_key]123 # Prompt for new judgements124 if command_eval != 'y':125 if structure_eval == 'y':126 if not command_eval and interactive:127 print('#{}. {}'.format(exam_id, sc_txt))128 for j, gt in enumerate(command_gts):129 print('- GT{}: {}'.format(j, gt))130 print('> {}'.format(pred_cmd))131 command_eval = input(132 'CORRECT COMMAND? [y/reason] ')133 add_judgement(FLAGS.data_dir, sc_txt, pred_cmd,134 structure_eval, command_eval)135 print()136 else:137 if not structure_eval and interactive:138 print('#{}. {}'.format(exam_id, sc_txt))139 for j, gt in enumerate(command_gts):140 print('- GT{}: {}'.format(j, gt))141 print('> {}'.format(pred_cmd))142 structure_eval = input(143 'CORRECT STRUCTURE? [y/reason] ')144 if structure_eval == 'y':145 command_eval = input(146 'CORRECT COMMAND? 
[y/reason] ')147 add_judgement(FLAGS.data_dir, sc_txt, pred_cmd,148 structure_eval, command_eval)149 print()150 structure_eval_cache[structure_example_key] = structure_eval151 command_eval_cache[command_example_key] = command_eval152 if structure_eval == 'y':153 if i == 0:154 num_t_top_1_correct += 1155 if not top_3_s_correct_marked:156 num_t_top_3_correct += 1157 top_3_s_correct_marked = True158 if command_eval == 'y':159 if i == 0:160 num_f_top_1_correct += 1161 if not top_3_f_correct_marked:162 num_f_top_3_correct += 1163 top_3_f_correct_marked = True164 metrics = {}165 acc_f_1 = num_f_top_1_correct / len(sample_ids)166 acc_f_3 = num_f_top_3_correct / len(sample_ids)167 acc_t_1 = num_t_top_1_correct / len(sample_ids)168 acc_t_3 = num_t_top_3_correct / len(sample_ids)169 metrics['acc_f'] = [acc_f_1, acc_f_3]170 metrics['acc_t'] = [acc_t_1, acc_t_3]171 if verbose:172 print('{} examples evaluated'.format(len(sample_ids)))173 print('Top 1 Command Acc = {:.3f}'.format(acc_f_1))174 print('Top 3 Command Acc = {:.3f}'.format(acc_f_3))175 print('Top 1 Template Acc = {:.3f}'.format(acc_t_1))176 print('Top 3 Template Acc = {:.3f}'.format(acc_t_3))177 return metrics178def add_judgement(data_dir, nl, command, correct_template='', correct_command=''):179 """180 Append a new judgement181 """182 data_dir = os.path.join(data_dir, 'manual_judgements')183 manual_judgement_path = os.path.join(184 data_dir, 'manual.evaluations.author')185 if not os.path.exists(manual_judgement_path):186 with open(manual_judgement_path, 'w') as o_f:187 o_f.write(188 'description,prediction,template,correct template,correct command\n')189 with open(manual_judgement_path, 'a') as o_f:190 temp = data_tools.cmd2template(command, loose_constraints=True)191 if not correct_template:192 correct_template = 'n'193 if not correct_command:194 correct_command = 'n'195 o_f.write('"{}","{}","{}","{}","{}"\n'.format(196 nl.replace('"', '""'), command.replace('"', '""'),197 temp.replace('"', '""'), correct_template.replace('"', '""'),198 correct_command.replace('"', '""')))199 print('new judgement added to {}'.format(manual_judgement_path))200def automatic_eval(prediction_path, dataset, FLAGS, top_k, num_samples=-1, verbose=False):201 """202 Generate automatic evaluation metrics on dev/test set.203 The following metrics are computed:204 Top 1,3,5,10205 1. Structure accuracy206 2. Full command accuracy207 3. Command keyword overlap208 4. BLEU209 """210 grouped_dataset = data_utils.group_parallel_data(dataset)211 try:212 vocabs = data_utils.load_vocabulary(FLAGS)213 except ValueError:214 vocabs = None215 # Load predictions216 prediction_list = load_predictions(prediction_path, top_k)217 if len(grouped_dataset) != len(prediction_list):218 raise ValueError("ground truth and predictions length must be equal: "219 "{} vs. 
{}".format(len(grouped_dataset), len(prediction_list)))220 metrics = get_automatic_evaluation_metrics(grouped_dataset, prediction_list, vocabs, FLAGS,221 top_k, num_samples, verbose)222 return metrics223def gen_automatic_evaluation_table(dataset, FLAGS):224 # Group dataset225 grouped_dataset = data_utils.group_parallel_data(dataset)226 vocabs = data_utils.load_vocabulary(FLAGS)227 model_names, model_predictions = load_all_model_predictions(228 grouped_dataset, FLAGS, top_k=3)229 auto_eval_metrics = {}230 for model_id, model_name in enumerate(model_names):231 prediction_list = model_predictions[model_id]232 if prediction_list is not None:233 M = get_automatic_evaluation_metrics(234 grouped_dataset, prediction_list, vocabs, FLAGS, top_k=3)235 auto_eval_metrics[model_name] = [M['bleu'][0],236 M['bleu'][1], M['cms'][0], M['cms'][1]]237 else:238 print('Model {} skipped in evaluation'.format(model_name))239 metrics_names = ['BLEU1', 'BLEU3', 'TM1', 'TM3']240 print_eval_table(model_names, metrics_names, auto_eval_metrics)241def get_automatic_evaluation_metrics(grouped_dataset, prediction_list, vocabs, FLAGS, top_k,242 num_samples=-1, verbose=False):243 cmd_parser = data_tools.bash_parser244 rev_sc_vocab = vocabs.rev_sc_vocab if vocabs is not None else None245 # Load cached evaluation results246 structure_eval_cache, command_eval_cache = \247 load_cached_evaluations(248 os.path.join(FLAGS.data_dir, 'manual_judgements'))249 # Compute manual evaluation scores on a subset of examples250 if num_samples > 0:251 # Get FIXED dev set samples252 random.seed(100)253 example_ids = list(range(len(grouped_dataset)))254 random.shuffle(example_ids)255 sample_ids = example_ids[:100]256 grouped_dataset = [grouped_dataset[i] for i in sample_ids]257 prediction_list = [prediction_list[i] for i in sample_ids]258 num_eval = 0259 top_k_temp_correct = np.zeros([len(grouped_dataset), top_k])260 top_k_str_correct = np.zeros([len(grouped_dataset), top_k])261 top_k_cms = np.zeros([len(grouped_dataset), top_k])262 top_k_bleu = np.zeros([len(grouped_dataset), top_k])263 command_gt_asts_list, pred_ast_list = [], []264 for data_id in xrange(len(grouped_dataset)):265 _, data_group = grouped_dataset[data_id]266 sc_str = data_group[0].sc_txt.strip()267 sc_key = get_example_nl_key(sc_str)268 if vocabs is not None:269 sc_tokens = [rev_sc_vocab[i] for i in data_group[0].sc_ids]270 if FLAGS.channel == 'char':271 sc_features = ''.join(sc_tokens)272 sc_features = sc_features.replace(constants._SPACE, ' ')273 else:274 sc_features = ' '.join(sc_tokens)275 command_gts = [dp.tg_txt.strip() for dp in data_group]276 command_gt_asts = [cmd_parser(cmd) for cmd in command_gts]277 command_gt_asts_list.append(command_gt_asts)278 template_gts = [data_tools.cmd2template(279 cmd, loose_constraints=True) for cmd in command_gts]280 template_gt_asts = [cmd_parser(temp) for temp in template_gts]281 if verbose:282 print("Example {}".format(data_id))283 print("Original Source: {}".format(sc_str.encode('utf-8')))284 if vocabs is not None:285 print("Source: {}".format(286 [x.encode('utf-8') for x in sc_features]))287 for j, command_gt in enumerate(command_gts):288 print("GT Target {}: {}".format(289 j + 1, command_gt.strip().encode('utf-8')))290 num_eval += 1291 predictions = prediction_list[data_id]292 for i in xrange(len(predictions)):293 pred_cmd = predictions[i]294 pred_ast = cmd_parser(pred_cmd)295 if i == 0:296 pred_ast_list.append(pred_ast)297 pred_temp = data_tools.cmd2template(298 pred_cmd, loose_constraints=True)299 # A) Exact match with ground 
truths & exisitng judgements300 command_example_key = '{}<NL_PREDICTION>{}'.format(301 sc_key, pred_cmd)302 structure_example_key = '{}<NL_PREDICTION>{}'.format(303 sc_key, pred_temp)304 # B) Match ignoring flag orders305 temp_match = tree_dist.one_match(306 template_gt_asts, pred_ast, ignore_arg_value=True)307 str_match = tree_dist.one_match(308 command_gt_asts, pred_ast, ignore_arg_value=False)309 if command_eval_cache and command_example_key in command_eval_cache:310 str_match = normalize_judgement(311 command_eval_cache[command_example_key]) == 'y'312 if structure_eval_cache and structure_example_key in structure_eval_cache:313 temp_match = normalize_judgement(314 structure_eval_cache[structure_example_key]) == 'y'315 if temp_match:316 top_k_temp_correct[data_id, i] = 1317 if str_match:318 top_k_str_correct[data_id, i] = 1319 cms = token_based.command_match_score(command_gt_asts, pred_ast)320 # if pred_cmd.strip():321 # bleu = token_based.sentence_bleu_score(command_gt_asts, pred_ast)322 # else:323 # bleu = 0324 bleu = nltk.translate.bleu_score.sentence_bleu(325 command_gts, pred_cmd)326 top_k_cms[data_id, i] = cms327 top_k_bleu[data_id, i] = bleu328 if verbose:329 print("Prediction {}: {} ({}, {})".format(330 i + 1, pred_cmd, cms, bleu))331 if verbose:332 print()333 bleu = token_based.corpus_bleu_score(command_gt_asts_list, pred_ast_list)334 top_temp_acc = [-1 for _ in [1, 3, 5, 10]]335 top_cmd_acc = [-1 for _ in [1, 3, 5, 10]]336 top_cms = [-1 for _ in [1, 3, 5, 10]]337 top_bleu = [-1 for _ in [1, 3, 5, 10]]338 top_temp_acc[0] = top_k_temp_correct[:, 0].mean()339 top_cmd_acc[0] = top_k_str_correct[:, 0].mean()340 top_cms[0] = top_k_cms[:, 0].mean()341 top_bleu[0] = top_k_bleu[:, 0].mean()342 print("{} examples evaluated".format(num_eval))343 print("Top 1 Template Acc = %.3f" % top_temp_acc[0])344 print("Top 1 Command Acc = %.3f" % top_cmd_acc[0])345 print("Average top 1 Template Match Score = %.3f" % top_cms[0])346 print("Average top 1 BLEU Score = %.3f" % top_bleu[0])347 if len(predictions) > 1:348 top_temp_acc[1] = np.max(top_k_temp_correct[:, :3], 1).mean()349 top_cmd_acc[1] = np.max(top_k_str_correct[:, :3], 1).mean()350 top_cms[1] = np.max(top_k_cms[:, :3], 1).mean()351 top_bleu[1] = np.max(top_k_bleu[:, :3], 1).mean()352 print("Top 3 Template Acc = %.3f" % top_temp_acc[1])353 print("Top 3 Command Acc = %.3f" % top_cmd_acc[1])354 print("Average top 3 Template Match Score = %.3f" % top_cms[1])355 print("Average top 3 BLEU Score = %.3f" % top_bleu[1])356 if len(predictions) > 3:357 top_temp_acc[2] = np.max(top_k_temp_correct[:, :5], 1).mean()358 top_cmd_acc[2] = np.max(top_k_str_correct[:, :5], 1).mean()359 top_cms[2] = np.max(top_k_cms[:, :5], 1).mean()360 top_bleu[2] = np.max(top_k_bleu[:, :5], 1).mean()361 print("Top 5 Template Acc = %.3f" % top_temp_acc[2])362 print("Top 5 Command Acc = %.3f" % top_cmd_acc[2])363 print("Average top 5 Template Match Score = %.3f" % top_cms[2])364 print("Average top 5 BLEU Score = %.3f" % top_bleu[2])365 if len(predictions) > 5:366 top_temp_acc[3] = np.max(top_k_temp_correct[:, :10], 1).mean()367 top_cmd_acc[3] = np.max(top_k_str_correct[:, :10], 1).mean()368 top_cms[3] = np.max(top_k_cms[:, :10], 1).mean()369 top_bleu[3] = np.max(top_k_bleu[:, :10], 1).mean()370 print("Top 10 Template Acc = %.3f" % top_temp_acc[3])371 print("Top 10 Command Acc = %.3f" % top_cmd_acc[3])372 print("Average top 10 Template Match Score = %.3f" % top_cms[3])373 print("Average top 10 BLEU Score = %.3f" % top_bleu[3])374 print('Corpus BLEU = %.3f' % bleu)375 print()376 
metrics = {}377 metrics['acc_f'] = top_cmd_acc378 metrics['acc_t'] = top_temp_acc379 metrics['cms'] = top_cms380 metrics['bleu'] = top_bleu381 return metrics382def print_eval_table(model_names, metrics_names, model_metrics):383 def pad_spaces(s, max_len):384 return s + ' ' * (max_len - len(s))385 # print evaluation table386 # pad model names with spaces to create alignment387 max_len = len(max(model_names, key=len))388 max_len_with_offset = (int(max_len / 4) + 1) * 4389 first_row = pad_spaces('Model', max_len_with_offset)390 for metrics_name in metrics_names:391 first_row += '{} '.format(metrics_name)392 print(first_row.strip())393 print('-' * len(first_row))394 for i, model_name in enumerate(model_names):395 row = pad_spaces(model_name, max_len_with_offset)396 if model_name in model_metrics:397 for metrics in model_metrics[model_name]:398 row += '{:.2f} '.format(metrics)399 print(row.strip())400 print('-' * len(first_row))401def load_all_model_predictions(grouped_dataset, FLAGS, top_k=1, model_names=('token_seq2seq',402 'tellina',403 'token_copynet',404 'partial_token_seq2seq',405 'partial_token_copynet',406 'char_seq2seq',407 'char_copynet')):408 """409 Load predictions of multiple models (specified with "model_names").410 :return model_predictions: List of model predictions.411 """412 def load_model_predictions():413 model_subdir, decode_sig = graph_utils.get_decode_signature(FLAGS)414 model_dir = os.path.join(FLAGS.model_root_dir, model_subdir)415 prediction_path = os.path.join(416 model_dir, 'predictions.{}.latest'.format(decode_sig))417 prediction_list = load_predictions(prediction_path, top_k)418 if prediction_list is not None and len(grouped_dataset) != len(prediction_list):419 raise ValueError("ground truth list and prediction list length must "420 "be equal: {} vs. 
{}".format(len(grouped_dataset),421 len(prediction_list)))422 return prediction_list423 # Load model predictions424 model_predictions = []425 # -- Token426 FLAGS.channel = 'token'427 FLAGS.normalized = False428 FLAGS.fill_argument_slots = False429 FLAGS.use_copy = False430 # --- Seq2Seq431 if 'token_seq2seq' in model_names:432 model_predictions.append(load_model_predictions())433 # --- Tellina434 if 'tellina' in model_names:435 FLAGS.normalized = True436 FLAGS.fill_argument_slots = True437 model_predictions.append(load_model_predictions())438 FLAGS.normalized = False439 FLAGS.fill_argument_slots = False440 # --- CopyNet441 if 'token_copynet' in model_names:442 FLAGS.use_copy = True443 FLAGS.copy_fun = 'copynet'444 model_predictions.append(load_model_predictions())445 FLAGS.use_copy = False446 # -- Parital token447 FLAGS.channel = 'partial.token'448 # --- Seq2Seq449 if 'partial_token_seq2seq' in model_names:450 model_predictions.append(load_model_predictions())451 # --- CopyNet452 if 'partial_token_copynet' in model_names:453 FLAGS.use_copy = True454 FLAGS.copy_fun = 'copynet'455 model_predictions.append(load_model_predictions())456 FLAGS.use_copy = False457 # -- Character458 FLAGS.channel = 'char'459 FLAGS.batch_size = 32460 FLAGS.min_vocab_frequency = 1461 # --- Seq2Seq462 if 'char_seq2seq' in model_names:463 model_predictions.append(load_model_predictions())464 # --= CopyNet465 if 'char_copynet' in model_names:466 FLAGS.use_copy = True467 FLAGS.copy_fun = 'copynet'468 model_predictions.append(load_model_predictions())469 FLAGS.use_copy = False470 return model_names, model_predictions471def load_predictions(prediction_path, top_k, verbose=True):472 """473 Load model predictions (top_k per example) from disk.474 :param prediction_path: path to the decoding output475 We assume the file is of the format:476 1. The i-th line of the file contains predictions for example i in the dataset'477 2. 
Each line contains top-k predictions separated by "|||".478 :param top_k: Maximum number of predictions to read per example.479 :return: List of top k predictions.480 """481 if os.path.exists(prediction_path):482 with open(prediction_path) as f:483 prediction_list = []484 for line in f:485 predictions = line.split('|||')486 prediction_list.append(predictions[:top_k])487 else:488 if verbose:489 print('Warning: file not found: {}'.format(prediction_path))490 return None491 if verbose:492 print('{} predictions loaded from {}'.format(493 len(prediction_list), prediction_path))494 return prediction_list495def load_cached_correct_translations(data_dir, treat_empty_as_correct=False, verbose=False):496 """497 Load cached correct translations from disk.498 :return: nl -> template translation map, nl -> command translation map499 """500 command_translations = collections.defaultdict(set)501 template_translations = collections.defaultdict(set)502 eval_files = []503 for file_name in os.listdir(data_dir):504 if 'evaluations' in file_name and not file_name.endswith('base'):505 eval_files.append(file_name)506 for file_name in sorted(eval_files):507 manual_judgement_path = os.path.join(data_dir, file_name)508 with open(manual_judgement_path) as f:509 if verbose:510 print('reading cached evaluations from {}'.format(511 manual_judgement_path))512 reader = csv.DictReader(f)513 current_nl_key = ''514 for row in reader:515 if row['description']:516 current_nl_key = get_example_nl_key(row['description'])517 pred_cmd = row['prediction']518 if 'template' in row:519 pred_temp = row['template']520 else:521 pred_temp = data_tools.cmd2template(522 pred_cmd, loose_constraints=True)523 structure_eval = row['correct template']524 if treat_empty_as_correct:525 structure_eval = normalize_judgement(structure_eval)526 command_eval = row['correct command']527 if treat_empty_as_correct:528 command_eval = normalize_judgement(command_eval)529 if structure_eval == 'y':530 template_translations[current_nl_key].add(pred_temp)531 if command_eval == 'y':532 command_translations[current_nl_key].add(pred_cmd)533 print('{} template translations loaded'.format(len(template_translations)))534 print('{} command translations loaded'.format(len(command_translations)))535 return template_translations, command_translations536def load_cached_evaluations(model_dir, verbose=True):537 """538 Load cached evaluation results from disk.539 :param model_dir: Directory where the evaluation result file is stored.540 :param decode_sig: The decoding signature of the model being evaluated.541 :return: dictionaries storing the evaluation results.542 """543 structure_eval_results = {}544 command_eval_results = {}545 eval_files = []546 for file_name in os.listdir(model_dir):547 if 'evaluations' in file_name and not file_name.endswith('base'):548 eval_files.append(file_name)549 for file_name in sorted(eval_files):550 manual_judgement_path = os.path.join(model_dir, file_name)551 ser, cer = load_cached_evaluations_from_file(552 manual_judgement_path, verbose=verbose)553 for key in ser:554 structure_eval_results[key] = ser[key]555 for key in cer:556 command_eval_results[key] = cer[key]557 if verbose:558 print('{} structure evaluation results loaded'.format(559 len(structure_eval_results)))560 print('{} command evaluation results loaded'.format(561 len(command_eval_results)))562 return structure_eval_results, command_eval_results563def load_cached_evaluations_from_file(input_file, treat_empty_as_correct=False, verbose=True):564 structure_eval_results = {}565 
command_eval_results = {}566 with open(input_file, encoding='utf-8') as f:567 if verbose:568 print('reading cached evaluations from {}'.format(input_file))569 reader = csv.DictReader(f)570 current_nl_key = ''571 for row in reader:572 if row['description']:573 current_nl_key = get_example_nl_key(row['description'])574 pred_cmd = row['prediction']575 if 'template' in row:576 pred_temp = row['template']577 else:578 pred_temp = data_tools.cmd2template(579 pred_cmd, loose_constraints=True)580 command_eval = row['correct command']581 if treat_empty_as_correct:582 command_eval = normalize_judgement(command_eval)583 command_example_key = '{}<NL_PREDICTION>{}'.format(584 current_nl_key, pred_cmd)585 if command_eval:586 command_eval_results[command_example_key] = command_eval587 structure_eval = row['correct template']588 if treat_empty_as_correct:589 structure_eval = normalize_judgement(structure_eval)590 structure_example_key = '{}<NL_PREDICTION>{}'.format(591 current_nl_key, pred_temp)592 if structure_eval:593 structure_eval_results[structure_example_key] = structure_eval594 return structure_eval_results, command_eval_results595def get_example_nl_key(nl):596 """597 Get the natural language description in an example with nuances removed.598 """599 tokens, _ = tokenizer.basic_tokenizer(nl)600 return ' '.join(tokens)601def get_example_cm_key(cm):602 """603 TODO: implement command normalization604 1. flag order normalization605 2. flag format normalization (long flag vs. short flag)606 3. remove flags whose effect does not matter607 """608 return cm609def normalize_judgement(x):610 if not x or x.lower() == 'y':611 return 'y'612 else:...
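
In eval_tools.py, command_eval starts out as an automatic verdict (exact or structure-only match against the ground-truth ASTs) and is then overridden by cached human judgements keyed as '<normalized description><NL_PREDICTION><predicted command>'. The standalone sketch below is a minimal illustration of that cache-and-lookup convention using only the standard library; the file name and example strings are invented for this sketch, and it skips the get_example_nl_key normalization the real code applies to descriptions.

import csv

# Hypothetical judgement file following the header written by add_judgement() above.
JUDGEMENT_FILE = 'manual.evaluations.example'  # placeholder path


def load_command_judgements(path):
    """Map '<description><NL_PREDICTION><command>' -> judgement ('y' or a reason string)."""
    cache = {}
    with open(path, newline='') as f:
        for row in csv.DictReader(f):
            key = '{}<NL_PREDICTION>{}'.format(row['description'], row['prediction'])
            if row['correct command']:
                cache[key] = row['correct command']
    return cache


if __name__ == '__main__':
    # Write a tiny judgement file so the sketch runs end to end.
    with open(JUDGEMENT_FILE, 'w', newline='') as f:
        f.write('description,prediction,template,correct template,correct command\n')
        f.write('"list all text files","find . -name \'*.txt\'","find Path -name Regex",y,y\n')

    cache = load_command_judgements(JUDGEMENT_FILE)
    key = "list all text files<NL_PREDICTION>find . -name '*.txt'"
    command_eval = cache.get(key, '')  # '' means "not judged yet", as in eval_tools.py
    print('command_eval =', command_eval)  # -> y

Any value other than 'y' (empty or a reason string) counts as "not yet judged correct", which is why the interactive loop above only prompts the annotator when the cached or automatic command_eval is not 'y'.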


inter_annotator_agreement.py

Source: inter_annotator_agreement.py (GitHub)



1"""2Compute the inter-annotator agreement.3"""4from __future__ import absolute_import5from __future__ import division6from __future__ import print_function7import collections8import csv9import os10import sys11from bashlint import data_tools12from eval.eval_tools import load_cached_evaluations_from_file13from eval.eval_tools import get_example_nl_key, get_example_cm_key14from eval.eval_tools import normalize_judgement15def iaa(a1, a2):16 assert(len(a1) == len(a2))17 num_agree = 018 for i in range(len(a1)):19 if a1[i].lower() == a2[i].lower():20 num_agree += 121 return float(num_agree) / len(a1)22def read_annotations(input_file):23 command_judgements, template_judgements = [], []24 with open(input_file) as f:25 reader = csv.DictReader(f)26 for row in reader:27 command_eval = normalize_judgement(row['correct command'].strip())28 template_eval = normalize_judgement(row['correct template'].strip())29 command_judgements.append(command_eval)30 template_judgements.append(template_eval)31 return command_judgements, template_judgements32def inter_annotator_agreement(input_files1, input_files2):33 command_judgements1, template_judgements1 = [], []34 command_judgements2, template_judgements2 = [], []35 for input_file in input_files1:36 cj, tj = read_annotations(input_file)37 command_judgements1.extend(cj)38 template_judgements1.extend(tj)39 for input_file in input_files2:40 cj, tj = read_annotations(input_file)41 command_judgements2.extend(cj)42 template_judgements2.extend(tj)43 print('IAA-F: {}'.format(iaa(command_judgements1, command_judgements2)))44 print('IAA-T: {}'.format(iaa(template_judgements1, template_judgements2)))45def combine_annotations_multi_files():46 """47 Combine multiple annotations files and discard the annotations that has a conflict.48 """49 input_dir = sys.argv[1]50 template_evals = {}51 command_evals = {}52 discarded_keys = set({})53 for in_csv in os.listdir(input_dir):54 in_csv_path = os.path.join(input_dir, in_csv)55 with open(in_csv_path) as f:56 reader = csv.DictReader(f)57 current_description = ''58 for row in reader:59 template_eval = normalize_judgement(row['correct template'])60 command_eval = normalize_judgement(row['correct command'])61 description = get_example_nl_key(row['description'])62 if description.strip():63 current_description = description64 else:65 description = current_description66 prediction = row['prediction']67 example_key = '{}<NL_PREDICTION>{}'.format(description, prediction)68 if example_key in template_evals and template_evals[example_key] != template_eval:69 discarded_keys.add(example_key)70 continue71 if example_key in command_evals and command_evals[example_key] != command_eval:72 discarded_keys.add(example_key)73 continue74 template_evals[example_key] = template_eval75 command_evals[example_key] = command_eval76 print('{} read ({} manually annotated examples, {} discarded)'.format(in_csv_path, len(template_evals), len(discarded_keys)))77 # Write to new file78 assert(len(template_evals) == len(command_evals))79 with open('manual_annotations.additional', 'w') as o_f:80 o_f.write('description,prediction,template,correct template,correct comand\n')81 for key in sorted(template_evals.keys()):82 if key in discarded_keys:83 continue84 description, prediction = key.split('<NL_PREDICTION>')85 template_eval = template_evals[example_key]86 command_eval = command_evals[example_key]87 pred_tree = data_tools.bash_parser(prediction)88 pred_temp = data_tools.ast2template(pred_tree, loose_constraints=True)89 o_f.write('"{}","{}","{}",{},{}\n'.format(90 
description.replace('"', '""'),91 prediction.replace('"', '""'),92 pred_temp.replace('"', '""'),93 template_eval,94 command_eval95 ))96 97def combine_annotations_multi_annotators():98 """99 Combine the annotations input by three annotators.100 :param input_file1: main annotation file 1.101 :param input_file2: main annotation file 2 (should contain the same number of102 lines as input_file1).103 :param input_file3: supplementary annotation file which contains annotations104 of lines in input_file1 and input_file2 that contain a disagreement.105 :param output_file: file that contains the combined annotations.106 """107 input_file1 = sys.argv[1]108 input_file2 = sys.argv[2]109 input_file3 = sys.argv[3]110 output_file = sys.argv[4]111 o_f = open(output_file, 'w')112 o_f.write('description,prediction,template,correct template,correct command,'113 'correct template A,correct command A,'114 'correct template B,correct command B,'115 'correct template C,correct command C\n')116 sup_structure_eval, sup_command_eval = load_cached_evaluations_from_file(117 input_file3, treat_empty_as_correct=True)118 with open(input_file1) as f1:119 with open(input_file2) as f2:120 reader1 = csv.DictReader(f1)121 reader2 = csv.DictReader(f2)122 current_desp = ''123 for row1, row2 in zip(reader1, reader2):124 row1_template_eval = normalize_judgement(row1['correct template'].strip())125 row1_command_eval = normalize_judgement(row1['correct command'].strip())126 row2_template_eval = normalize_judgement(row2['correct template'].strip())127 row2_command_eval = normalize_judgement(row2['correct command'].strip())128 if row1['description']:129 current_desp = row1['description'].strip()130 sc_key = get_example_nl_key(current_desp)131 pred_cmd = row1['prediction'].strip()132 if not pred_cmd:133 row1_template_eval, row1_command_eval = 'n', 'n'134 row2_template_eval, row2_command_eval = 'n', 'n'135 pred_temp = data_tools.cmd2template(pred_cmd, loose_constraints=True)136 structure_example_key = '{}<NL_PREDICTION>{}'.format(sc_key, pred_temp)137 command_example_key = '{}<NL_PREDICTION>{}'.format(sc_key, pred_cmd)138 row3_template_eval, row3_command_eval = None, None139 if structure_example_key in sup_structure_eval:140 row3_template_eval = sup_structure_eval[structure_example_key]141 if command_example_key in sup_command_eval:142 row3_command_eval = sup_command_eval[command_example_key]143 if row1_template_eval != row2_template_eval or row1_command_eval != row2_command_eval:144 if row1_template_eval != row2_template_eval:145 if row3_template_eval is None:146 print(structure_example_key)147 assert(row3_template_eval is not None)148 template_eval = row3_template_eval149 else:150 template_eval = row1_template_eval151 if row1_command_eval != row2_command_eval:152 # if row3_command_eval is None:153 # print(command_example_key)154 assert(row3_command_eval is not None)155 command_eval = row3_command_eval156 else:157 command_eval = row1_command_eval158 else:159 template_eval = row1_template_eval160 command_eval = row1_command_eval161 if row3_template_eval is None:162 row3_template_eval = ''163 if row3_command_eval is None:164 row3_command_eval = ''165 o_f.write('"{}","{}","{}",{},{},{},{},{},{},{},{}\n'.format(166 current_desp.replace('"', '""'), pred_cmd.replace('"', '""'), pred_temp.replace('"', '""'),167 template_eval, command_eval,168 row1_template_eval, row1_command_eval,169 row2_template_eval, row2_command_eval,170 row3_template_eval, row3_command_eval))171 o_f.close()172def print_error_analysis_sheet():173 input_file1 = 
sys.argv[1]174 input_file2 = sys.argv[2]175 input_file3 = sys.argv[3]176 output_file = sys.argv[4]177 o_f = open(output_file, 'w')178 o_f.write('description,model,prediction,correct template,correct command,'179 'correct template A,correct command A,'180 'correct template B,correct command B,'181 'correct template C,correct command C\n')182 sup_structure_eval, sup_command_eval = load_cached_evaluations_from_file(183 input_file3, treat_empty_as_correct=True)184 # for key in sup_structure_eval:185 # print(key)186 # print('------------------')187 with open(input_file1) as f1:188 with open(input_file2) as f2:189 reader1 = csv.DictReader(f1)190 reader2 = csv.DictReader(f2)191 current_desp = ''192 for row_id, (row1, row2) in enumerate(zip(reader1, reader2)):193 if row1['description']:194 current_desp = row1['description'].strip()195 model_name = row2['model']196 if not model_name in ['partial.token-copynet', 'tellina']:197 continue198 if row_id % 3 != 0:199 continue200 row1_template_eval = normalize_judgement(row1['correct template'].strip())201 row1_command_eval = normalize_judgement(row1['correct command'].strip())202 row2_template_eval = normalize_judgement(row2['correct template'].strip())203 row2_command_eval = normalize_judgement(row2['correct command'].strip())204 sc_key = get_example_nl_key(current_desp)205 pred_cmd = row1['prediction'].strip()206 if not pred_cmd:207 row1_template_eval, row1_command_eval = 'n', 'n'208 row2_template_eval, row2_command_eval = 'n', 'n'209 pred_temp = data_tools.cmd2template(pred_cmd, loose_constraints=True)210 structure_example_key = '{}<NL_PREDICTION>{}'.format(sc_key, pred_temp)211 command_example_key = '{}<NL_PREDICTION>{}'.format(sc_key, pred_cmd)212 row3_template_eval, row3_command_eval = None, None213 if structure_example_key in sup_structure_eval:214 row3_template_eval = sup_structure_eval[structure_example_key]215 if command_example_key in sup_command_eval:216 row3_command_eval = sup_command_eval[command_example_key]217 if row1_template_eval != row2_template_eval or row1_command_eval != row2_command_eval:218 if row1_template_eval != row2_template_eval:219 if row3_template_eval is None:220 print(pred_cmd, structure_example_key)221 assert (row3_template_eval is not None)222 template_eval = row3_template_eval223 else:224 template_eval = row1_template_eval225 if row1_command_eval != row2_command_eval:226 # if row3_command_eval is None:227 # print(command_example_key)228 assert (row3_command_eval is not None)229 command_eval = row3_command_eval230 else:231 command_eval = row1_command_eval232 else:233 template_eval = row1_template_eval234 command_eval = row1_command_eval235 if row3_template_eval is None:236 row3_template_eval = ''237 if row3_command_eval is None:238 row3_command_eval = ''239 o_f.write('"{}","{}","{}",{},{},{},{},{},{},{},{}\n'.format(240 current_desp.replace('"', '""'), model_name, pred_cmd.replace('"', '""'),241 template_eval, command_eval,242 row1_template_eval, row1_command_eval,243 row2_template_eval, row2_command_eval,244 row3_template_eval, row3_command_eval))245 o_f.close()246def compute_error_overlap():247 input_file = sys.argv[1]248 template_judgements = []249 command_judgements = []250 with open(input_file) as f:251 reader = csv.DictReader(f)252 for row in reader:253 template_eval = row['correct template']254 command_eval = row['correct command']255 if row['model'] == 'tellina':256 template_judgements.append([template_eval])257 command_judgements.append([command_eval])258 else:259 template_judgements[-1].append(template_eval)260 
command_judgements[-1].append(command_eval)261 temp_error_hist = [0, 0, 0, 0]262 for t1, t2 in template_judgements:263 if t1 == 'y' and t2 == 'y':264 temp_error_hist[0] += 1265 elif t1 == 'y' and t2 == 'n':266 temp_error_hist[1] += 1267 elif t1 == 'n' and t2 == 'y':268 temp_error_hist[2] += 1269 elif t1 == 'n' and t2 == 'n':270 temp_error_hist[3] += 1271 print('Template Judgements:')272 print('\ty y: {}'.format(temp_error_hist[0]))273 print('\ty n: {}'.format(temp_error_hist[1]))274 print('\tn y: {}'.format(temp_error_hist[2]))275 print('\tn n: {}'.format(temp_error_hist[3]))276 command_error_hist = [0, 0, 0, 0]277 for c1, c2 in command_judgements:278 if c1 == 'y' and c2 == 'y':279 command_error_hist[0] += 1280 elif c1 == 'y' and c2 == 'n':281 command_error_hist[1] += 1282 elif c1 == 'n' and c2 == 'y':283 command_error_hist[2] += 1284 elif c1 == 'n' and c2 == 'n':285 command_error_hist[3] += 1286 print('Command Judgements"')287 print('\ty y: {}'.format(command_error_hist[0]))288 print('\ty n: {}'.format(command_error_hist[1]))289 print('\tn y: {}'.format(command_error_hist[2]))290 print('\tn n: {}'.format(command_error_hist[3]))291def compute_error_category():292 input_file = sys.argv[1]293 tellina_error_hist = collections.defaultdict(int)294 pc_error_hist = collections.defaultdict(int)295 with open(input_file) as f:296 reader = csv.DictReader(f)297 for row in reader:298 error_cat = row['error category']299 if error_cat:300 if row['model'] == 'tellina':301 tellina_error_hist[error_cat] += 1302 elif row['model'] == 'partial.token-copynet':303 pc_error_hist[error_cat] += 1304 else:305 raise ValueError306 print('Tellina errors:')307 for ec, freq in sorted(tellina_error_hist.items(), key=lambda x:x[1], reverse=True):308 print(ec, freq)309 print()310 print('Sub-token CopyNet errors:')311 for ec, freq in sorted(pc_error_hist.items(), key=lambda x:x[1], reverse=True):312 print(ec, freq)313def export_annotation_differences(input_file1, input_file2, output_file, command_header):314 o_f = open(output_file, 'w')315 o_f.write('description,{},correct template A,correct command A,correct template B,correct command B\n'.format(316 command_header))317 with open(input_file1) as f1:318 with open(input_file2) as f2:319 reader1 = csv.DictReader(f1)320 reader2 = csv.DictReader(f2)321 current_desp = ''322 desp_written = False323 for row1, row2 in zip(reader1, reader2):324 if row1['description']:325 current_desp = row1['description']326 desp_written = False327 if not row1[command_header]:328 continue329 row1_template_eval = normalize_judgement(row1['correct template'].strip())330 row1_command_eval = normalize_judgement(row1['correct command'].strip())331 row2_template_eval = normalize_judgement(row2['correct template'].strip())332 row2_command_eval = normalize_judgement(row2['correct command'].strip())333 if (row1_template_eval != row2_template_eval) or \334 (row1_command_eval != row2_command_eval):335 if not desp_written:336 o_f.write('"{}","{}",{},{},{},{}\n'.format(337 current_desp.replace('"', '""'), row1[command_header].replace('"', '""'),338 row1_template_eval, row1_command_eval, row2_template_eval, row2_command_eval))339 desp_written = True340 else:341 o_f.write(',"{}",{},{},{},{}\n'.format(row1[command_header].replace('"', '""'),342 row1_template_eval, row1_command_eval, row2_template_eval, row2_command_eval))343 o_f.close()344def main():345 # print_error_analysis_sheet()346 # combine_annotations_multi_annotators()347 # input_files1 = ['unreleased_files/manual.evaluations.test.stc.annotator.1.csv', 
'unreleased_files/manual.evaluations.test.tellina.annotator.1.csv']348 # input_files2 = ['unreleased_files/manual.evaluations.test.stc.annotator.2.csv', 'unreleased_files/manual.evaluations.test.tellina.annotator.2.csv']349 # input_files1 = ['unreleased_files/NL-Cmd Judgement (Hamid) - pc.csv', 'unreleased_files/NL-Cmd Judgement (Hamid) - tellina.csv']350 # input_files2 = ['unreleased_files/NL-Cmd Judgement (Shridhar) - pc.csv', 'unreleased_files/NL-Cmd Judgement (Shridhar) - tellina.csv']351 # input_files1 = ['unreleased_files/manual.evaluations.dev.samples.annotator.1.csv']352 # input_files2 = ['unreleased_files/manual.evaluations.dev.samples.annotator.2.csv']353 # inter_annotator_agreement(input_files1, input_files2)354 # compute_error_overlap()355 # compute_error_category()356 combine_annotations_multi_files() 357if __name__ == '__main__':...
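
inter_annotator_agreement.py reduces each annotation file to parallel lists of 'y'/'n' command and template judgements and scores agreement with iaa(), which is plain percent agreement (case-insensitive) over two equal-length lists. A minimal, self-contained illustration with made-up judgement values (not taken from the released annotation files):

def iaa(a1, a2):
    """Fraction of positions where two annotators gave the same judgement (case-insensitive)."""
    assert len(a1) == len(a2)
    num_agree = sum(1 for x, y in zip(a1, a2) if x.lower() == y.lower())
    return float(num_agree) / len(a1)


# Toy judgement lists for illustration only.
annotator_a = ['y', 'y', 'n', 'y']
annotator_b = ['y', 'n', 'n', 'Y']
print('IAA: {}'.format(iaa(annotator_a, annotator_b)))  # 3 of 4 positions agree -> 0.75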


Automation Testing Tutorials

Learn to execute automation testing from scratch with the LambdaTest Learning Hub, from setting up the prerequisites and running your first automation test to following best practices and diving into advanced test scenarios. The LambdaTest Learning Hub compiles step-by-step guides to help you become proficient with different test automation frameworks such as Selenium, Cypress, and TestNG.


YouTube

You can also refer to the video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.

Run play_selenium automation tests on LambdaTest cloud grid

Perform automation testing on 3000+ real desktop and mobile devices online.

