Best Python code snippet using yandex-tank
analysis.py
Source: analysis.py
1"""Functions to plot and compute results for the literature review.2"""3import os4import logging5import logging.config6from collections import OrderedDict7import subprocess8import pandas as pd9import geopandas as gpd10import matplotlib.pyplot as plt11from matplotlib.ticker import MaxNLocator12import matplotlib.patches as patches13import matplotlib as mpl14import seaborn as sns15import numpy as np16from PIL import Image17from wordcloud import WordCloud, STOPWORDS18from graphviz import Digraph19import config as cfg20import utils as ut21# Set style, context and palette22sns.set_style(rc=cfg.axes_styles)23sns.set_context(rc=cfg.plotting_context)24sns.set_palette(cfg.palette)25for key, val in cfg.axes_styles.items():26    mpl.rcParams[key] = val27for key, val in cfg.plotting_context.items():28    mpl.rcParams[key] = val29# Initialize logger for saving results and stats. Use `logger.info('message')`30# to log results.31logging.config.dictConfig({32    'version': 1,33    'disable_existing_loggers': True,34})35logger = logging.getLogger()36log_savename = os.path.join(cfg.saving_config['savepath'], 'results') + '.log'37handler = logging.FileHandler(log_savename, mode='w')38formatter = logging.Formatter(39        '%(asctime)s %(name)-12s %(levelname)-8s %(message)s')40handler.setFormatter(formatter)41logger.addHandler(handler)42logger.setLevel(logging.INFO)43def plot_prisma_diagram(save_cfg=cfg.saving_config):44    """Plot diagram showing the number of selected articles.45    TODO:46    - Use first two colors of colormap instead of gray47    - Reduce white space48    - Reduce arrow width49    """50    # save_format = save_cfg['format'] if isinstance(save_cfg, dict) else 'svg'51    save_format = 'pdf'52    # save_format = 'eps'53    size = '{},{}!'.format(0.5 * save_cfg['page_width'], 0.2 * save_cfg['page_height'])54    dot = Digraph(format=save_format)55    dot.attr('graph', rankdir='TB', overlap='false', size=size, margin='0')56    dot.attr('node', fontname='Liberation 
Sans', fontsize=str(9), shape='box', 57             style='filled', margin='0.15,0.07', penwidth='0.1')58    # dot.attr('edge', arrowsize=0.5)59    fillcolor = 'gray98'60    dot.node('A', 'PubMed (n=39)\nGoogle Scholar (n=409)\narXiv (n=105)', 61             fillcolor='gray95')62    dot.node('B', 'Articles identified\nthrough database\nsearching\n(n=553)', 63             fillcolor=fillcolor)64    # dot.node('B2', 'Excluded\n(n=446)', fillcolor=fillcolor)65    dot.node('C', 'Articles after content\nscreening and\nduplicate removal\n(n=105) ', 66             fillcolor=fillcolor)67    dot.node('D', 'Articles included in\nthe analysis\n(n=154)', 68             fillcolor=fillcolor)69    dot.node('E', 'Additional articles\nidentified through\nbibliography search\n(n=49)', 70             fillcolor=fillcolor)71    dot.edge('B', 'C')72    # dot.edge('B', 'B2')73    dot.edge('C', 'D')74    dot.edge('E', 'D')75    if save_cfg is not None:76        fname = os.path.join(save_cfg['savepath'], 'prisma_diagram')77        dot.render(filename=fname, view=False, cleanup=False)78                79    return dot80def plot_domain_tree(df, first_box='DL + EEG studies', min_font_size=10, 81                     max_font_size=14, max_char=16, min_n_items=2, 82                     postprocess=True, save_cfg=cfg.saving_config):83    """Plot tree graph showing the breakdown of study domains.84    Args:85        df (pd.DataFrame): data items table86    Keyword Args:87        first_box (str): text of the first box88        min_font_size (int): minimum font size89        max_font_size (int): maximum font size90        max_char (int): maximum number of characters per line91        min_n_items (int): if a node has less than this number of elements, 92            put it inside a node called "Others".93        postpocess (bool): if True, convert PNG to EPS using inkscape in a 94            system call.95        save_cfg (dict or None):96    97    Returns:98        (graphviz.Digraph): graphviz 
object99    NOTES:100    - To unflatten automatically, apply the following on the .dot file:101        >> unflatten -l 3 -c 10 dom_domains_tree | dot -Teps -o domains_unflattened.eps102    - To produce a circular version instead (uses space more efficiently):103        >> neato -Tpdf dom_domains_tree -o domains_neato.pdf104    """105    df = df[['Domain 1', 'Domain 2', 'Domain 3', 'Domain 4']].copy()106    df = df[~df['Domain 1'].isnull()]107    df[df == ' '] = None108    n_samples, n_levels = df.shape109    format = save_cfg['format'] if isinstance(save_cfg, dict) else 'svg'110    size = '{},{}!'.format(save_cfg['page_width'], 0.7 * save_cfg['page_height'])111    112    dot = Digraph(format=format)113    dot.attr('graph', rankdir='TB', overlap='false', ratio='fill', size=size)  # LR (left to right), TB (top to bottom)114    dot.attr('node', fontname='Liberation Sans', fontsize=str(max_font_size), 115             shape='box', style='filled, rounded',  margin='0.2,0.01', 116             penwidth='0.5')117    dot.node('A', '{}\n({})'.format(first_box, len(df)), 118             fillcolor='azure')119    120    min_sat, max_sat = 0.05, 0.4121    122    sub_df = df['Domain 1'].value_counts()123    n_categories = len(sub_df)124    for i, (d1, count1) in enumerate(sub_df.iteritems()):125        node1, hue = ut.make_box(126            dot, d1, max_char, count1, n_samples, 0, n_levels, min_sat, max_sat, 127            min_font_size, max_font_size, 'A', counter=i, 128            n_categories=n_categories)129        130        for d2, count2 in df[df['Domain 1'] == d1]['Domain 2'].value_counts().iteritems():131            node2, _ = ut.make_box(132                dot, d2, max_char, count2, n_samples, 1, n_levels, min_sat, 133                max_sat, min_font_size, max_font_size, node1, hue=hue)134            135            n_others3 = 0136            for d3, count3 in df[df['Domain 2'] == d2]['Domain 3'].value_counts().iteritems():137                if isinstance(d3, str) and 
d3 != 'TBD':138                    if count3 < min_n_items:139                        n_others3 += 1140                    else:141                        node3, _ = ut.make_box(142                            dot, d3, max_char, count3, n_samples, 2, n_levels,143                            min_sat, max_sat, min_font_size, max_font_size, 144                            node2, hue=hue)145                        n_others4 = 0146                        for d4, count4 in df[df['Domain 3'] == d3]['Domain 4'].value_counts().iteritems():147                            if isinstance(d4, str) and d4 != 'TBD':148                                if count4 < min_n_items:149                                    n_others4 += 1150                                else:151                                    ut.make_box(152                                        dot, d4, max_char, count4, n_samples, 3, 153                                        n_levels, min_sat, max_sat, min_font_size, 154                                        max_font_size, node3, hue=hue)155                        if n_others4 > 0:156                            ut.make_box(157                                dot, 'Others', max_char, n_others4, n_samples, 158                                3, n_levels, min_sat, max_sat, min_font_size, 159                                max_font_size, node3, hue=hue, 160                                node_name=node3+'others')161            if n_others3 > 0:162                ut.make_box(163                    dot, 'Others', max_char, n_others3, n_samples, 2, n_levels,164                    min_sat, max_sat, min_font_size, max_font_size, node2, hue=hue, 165                    node_name=node2+'others')166    if save_cfg is not None:167        fname = os.path.join(save_cfg['savepath'], 'dom_domains_tree')168        dot.render(filename=fname, cleanup=False)169        if postprocess:170            subprocess.call(171                ['neato', '-Tpdf', fname, '-o', fname + '.pdf'])172            
    173    return dot174def plot_model_comparison(df, save_cfg=cfg.saving_config):175    """Plot bar graph showing the types of baseline models used.176    """177    fig, ax = plt.subplots(figsize=(save_cfg['text_width'] / 4 * 2, 178                                    save_cfg['text_height'] / 5))179    sns.countplot(y=df['Baseline model type'].dropna(axis=0), ax=ax)180    ax.set_xlabel('Number of papers')181    ax.set_ylabel('')182    plt.tight_layout()183    model_prcts = df['Baseline model type'].value_counts() / df.shape[0] * 100184    logger.info('% of studies that used at least one traditional baseline: {}'.format(185        model_prcts['Traditional pipeline'] + model_prcts['DL & Trad.']))186    logger.info('% of studies that used at least one deep learning baseline: {}'.format(187        model_prcts['DL'] + model_prcts['DL & Trad.']))188    logger.info('% of studies that did not report baseline comparisons: {}'.format(189        model_prcts['None']))190    if save_cfg is not None:191        fname = os.path.join(save_cfg['savepath'], 'model_comparison')192        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)193    return ax194def plot_performance_metrics(df, cutoff=3, eeg_clf=None, 195                             save_cfg=cfg.saving_config):196    """Plot bar graph showing the types of performance metrics used.197    Args:198        df (DataFrame)199    Keyword Args:200        cutoff (int): Metrics with less than this number of papers will be cut201            off from the bar graph.202        eeg_clf (bool): If True, only use studies that focus on EEG 203            classification. If False, only use studies that did not focus on 204            EEG classification. 
If None, use all studies.205        save_cfg (dict)206    Assumptions, simplifications:207    - Rates have been simplified (e.g., "false positive rate" -> "false positives")208    - RMSE and MSE have been merged under MSE209    - Training/testing times have been simplified to "time"210    - Macro f1-score === f1=score211    """212    if eeg_clf is True:213        metrics = df[df['Domain 1'] == 'Classification of EEG signals'][214            'Performance metrics (clean)']215    elif eeg_clf is False:216        metrics = df[df['Domain 1'] != 'Classification of EEG signals'][217            'Performance metrics (clean)']218    elif eeg_clf is None:219        metrics = df['Performance metrics (clean)']220    metrics = metrics.str.split(',').apply(ut.lstrip)221    metric_per_article = list()222    for i, metric_list in metrics.iteritems():223        for m in metric_list:224            metric_per_article.append([i, m])225    metrics_df = pd.DataFrame(metric_per_article, columns=['paper nb', 'metric'])226    # Replace equivalent terms by standardized term227    equivalences = {'selectivity': 'specificity',228                    'true negative rate': 'specificity',229                    'sensitivitiy': 'sensitivity',230                    'sensitivy': 'sensitivity',231                    'recall': 'sensitivity',232                    'hit rate': 'sensitivity', 233                    'true positive rate': 'sensitivity',234                    'sensibility': 'sensitivity',235                    'positive predictive value': 'precision',236                    'f-measure': 'f1-score',237                    'f-score': 'f1-score',238                    'f1-measure': 'f1-score',239                    'macro f1-score': 'f1-score',240                    'macro-averaging f1-score': 'f1-score',241                    'kappa': 'cohen\'s kappa',242                    'mae': 'mean absolute error',243                    'false negative rate': 'false negatives',244                    'fpr': 
'false positives',245                    'false positive rate': 'false positives',246                    'false prediction rate': 'false positives',247                    'roc': 'ROC curves',248                    'roc auc': 'ROC AUC',249                    'rmse': 'mean squared error',250                    'mse': 'mean squared error',251                    'training time': 'time',252                    'testing time': 'time',253                    'test error': 'error'}254    metrics_df = metrics_df.replace(equivalences)255    metrics_df['metric'] = metrics_df['metric'].apply(lambda x: x[0].upper() + x[1:])256    # Removing low count categories257    metrics_counts = metrics_df['metric'].value_counts()258    metrics_df = metrics_df[metrics_df['metric'].isin(259        metrics_counts[(metrics_counts >= cutoff)].index)]260    fig, ax = plt.subplots(figsize=(save_cfg['text_width'] / 2, 261                                    save_cfg['text_height'] / 5))262    ax = sns.countplot(y='metric', data=metrics_df, 263                       order=metrics_df['metric'].value_counts().index)264    ax.set_xlabel('Number of papers')265    ax.set_ylabel('')266    plt.tight_layout()267    if save_cfg is not None:268        savename = 'performance_metrics'269        if eeg_clf is True:270            savename += '_eeg_clf'271        elif eeg_clf is False:272            savename += '_not_eeg_clf'273        fname = os.path.join(save_cfg['savepath'], savename)274        fig.savefig(fname + '.' 
+ save_cfg['format'], **save_cfg)275    return ax276def plot_reported_results(df, data_items_df=None, save_cfg=cfg.saving_config):277    """Plot figures to described the reported results in the studies.278    Args:279        df (DataFrame): contains reported results (second tab in spreadsheet)280    Keyword Args:281        data_items_df (DataFrame): contains data items (first tab in spreadsheet)282        save_cfg (dict)283    Returns:284        (list): list of axes to created figures285    TODO:286    - This function is starting to be a bit too big. Should probably split it up.287    """288    acc_df = df[df['Metric'] == 'accuracy']  # Extract accuracy rows only289    # Create new column that contains both citation and task information290    acc_df.loc[:, 'citation_task'] = acc_df[['Citation', 'Task']].apply(291        lambda x: ' ['.join(x) + ']', axis=1)292    # Create a new column with the year293    acc_df.loc[:, 'year'] = acc_df['Citation'].apply(294        lambda x: int(x[x.find('2'):x.find('2') + 4]))295    # Order by average proposed model accuracy296    acc_ind = acc_df[acc_df['model_type'] == 'Proposed'].groupby(297        'Citation').mean().sort_values(by='Result').index298    acc_df.loc[:, 'Citation'] = acc_df['Citation'].astype('category')299    acc_df['Citation'].cat.set_categories(acc_ind, inplace=True)300    acc_df = acc_df.sort_values(['Citation'])301    # Only keep 2 best per task and model type302    acc2_df = acc_df.sort_values(303        ['Citation', 'Task', 'model_type', 'Result'], ascending=True).groupby(304            ['Citation', 'Task', 'model_type']).tail(2)305    axes = list()306    axes.append(_plot_results_per_citation_task(acc2_df, save_cfg))307    # Only keep the maximum accuracy per citation & task308    best_df = acc_df.groupby(309        ['Citation', 'Task', 'model_type'])[310            'Result'].max().reset_index()311    # Only keep citations/tasks that have a traditional baseline312    best_df = best_df.groupby(['Citation', 
'Task']).filter(313        lambda x: 'Baseline (traditional)' in x.values).reset_index()314    # Add back architecture315    best_df = pd.merge(316        best_df, acc_df[['Citation', 'Task', 'model_type', 'Result', 'Architecture']], 317        how='inner').drop_duplicates()  # XXX: why are there duplicates?318    # Compute difference between proposed and traditional baseline319    def acc_diff_and_arch(x):320        diff = x[x['model_type'] == 'Proposed']['Result'].iloc[0] - \321               x[x['model_type'] == 'Baseline (traditional)']['Result'].iloc[0]322        arch = x[x['model_type'] == 'Proposed']['Architecture']323        return pd.Series(diff, arch)324    diff_df = best_df.groupby(['Citation', 'Task']).apply(325        acc_diff_and_arch).reset_index()326    diff_df = diff_df.rename(columns={0: 'acc_diff'})327    axes.append(_plot_results_accuracy_diff_scatter(diff_df, save_cfg))328    axes.append(_plot_results_accuracy_diff_distr(diff_df, save_cfg))329    # Pivot dataframe to plot proposed vs. 
baseline accuracy as a scatterplot330    best_df['citation_task'] = best_df[['Citation', 'Task']].apply(331        lambda x: ' ['.join(x) + ']', axis=1)332    acc_comparison_df = best_df.pivot_table(333        index='citation_task', columns='model_type', values='Result')334    axes.append(_plot_results_accuracy_comparison(acc_comparison_df, save_cfg))335    if data_items_df is not None:336        domains_df = data_items_df.filter(337            regex='(?=Domain*|Citation|Main domain|Journal / Origin|Dataset name|'338                    'Data - samples|Data - time|Data - subjects|Preprocessing \(clean\)|'339                    'Artefact handling \(clean\)|Features \(clean\)|Architecture \(clean\)|'340                    'Layers \(clean\)|Regularization \(clean\)|Optimizer \(clean\)|'341                    'Intra/Inter subject|Training procedure)')342        # Concatenate domains into one string343        def concat_domains(x):344            domain = ''345            for i in x[1:]:346                if isinstance(i, str):347                    domain += i + '/'348            return domain[:-1]349        domains_df.loc[:, 'domain'] = data_items_df.filter(350            regex='(?=Domain*)').apply(concat_domains, axis=1)351        diff_domain_df = diff_df.merge(domains_df, on='Citation', how='left')352        diff_domain_df = diff_domain_df.sort_values(by='domain')353        diff_domain_df.loc[:, 'arxiv'] = diff_domain_df['Journal / Origin'] == 'Arxiv'354        axes.append(_plot_results_accuracy_per_domain(355            diff_domain_df, diff_df, save_cfg))356        axes.append(_plot_results_stats_impact_on_acc_diff(357            diff_domain_df, save_cfg))358        axes.append(_compute_acc_diff_for_preprints(diff_domain_df, save_cfg))359        360    return axes361def _plot_results_per_citation_task(results_df, save_cfg):362    """Plot scatter plot of accuracy for each condition and task.363    """364    fig, ax = plt.subplots(figsize=(save_cfg['text_width'], 365   
                                 save_cfg['text_height'] * 1.3))366    # figsize = plt.rcParams.get('figure.figsize')367    # fig, ax = plt.subplots(figsize=(figsize[0], figsize[1] * 4))368    # Need to make the graph taller otherwise the y axis labels are on top of369    # each other.370    sns.catplot(y='citation_task', x='Result', hue='model_type', data=results_df, 371                ax=ax)372    ax.set_xlabel('accuracy')373    ax.set_ylabel('')374    plt.tight_layout()375    if save_cfg is not None:376        savename = 'reported_results'377        fname = os.path.join(save_cfg['savepath'], savename)378        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)379    return ax380def _plot_results_accuracy_diff_scatter(results_df, save_cfg):381    """Plot difference in accuracy for each condition/task as a scatter plot.382    """383    fig, ax = plt.subplots(figsize=(save_cfg['text_width'], 384                                    save_cfg['text_height'] * 1.3))385    # figsize = plt.rcParams.get('figure.figsize')386    # fig, ax = plt.subplots(figsize=(figsize[0], figsize[1] * 2))387    sns.catplot(y='Task', x='acc_diff', data=results_df, ax=ax)388    ax.set_xlabel('Accuracy difference')389    ax.set_ylabel('')390    ax.axvline(0, c='k', alpha=0.2)391    plt.tight_layout()392    if save_cfg is not None:393        savename = 'reported_accuracy_diff_scatter'394        fname = os.path.join(save_cfg['savepath'], savename)395        fig.savefig(fname + '.' 
+ save_cfg['format'], **save_cfg)396    return ax397def _plot_results_accuracy_diff_distr(results_df, save_cfg):398    """Plot the distribution of difference in accuracy.399    """400    fig, ax = plt.subplots(figsize=(save_cfg['text_width'], 401                                    save_cfg['text_height'] * 0.5))402    sns.distplot(results_df['acc_diff'], kde=False, rug=True, ax=ax)403    ax.set_xlabel('Accuracy difference')404    ax.set_ylabel('Number of studies')405    plt.tight_layout()406    if save_cfg is not None:407        savename = 'reported_accuracy_diff_distr'408        fname = os.path.join(save_cfg['savepath'], savename)409        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)410    return ax411def _plot_results_accuracy_comparison(results_df, save_cfg):412    """Plot the comparison between the best model and best baseline.413    """414    fig, ax = plt.subplots(figsize=(save_cfg['text_width'], 415                                    save_cfg['text_height'] * 0.5))416    sns.scatterplot(data=results_df, x='Baseline (traditional)', y='Proposed', 417                    ax=ax)418    ax.plot([0, 1.1], [0, 1.1], c='k', alpha=0.2)419    plt.axis('square')420    ax.set_xlim([0, 1.1])421    ax.set_ylim([0, 1.1])422    plt.tight_layout()423    if save_cfg is not None:424        savename = 'reported_accuracy_comparison'425        fname = os.path.join(save_cfg['savepath'], savename)426        fig.savefig(fname + '.' 
+ save_cfg['format'], **save_cfg)427    return ax428def _plot_results_accuracy_per_domain(results_df, diff_df, save_cfg):429    """Make scatterplot + boxplot to show accuracy difference by domain.430    """431    fig, axes = plt.subplots(432        nrows=2, ncols=1, sharex=True, 433        figsize=(save_cfg['text_width'], save_cfg['text_height'] / 3), 434        gridspec_kw = {'height_ratios':[5, 1]})435    results_df['Main domain'] = results_df['Main domain'].apply(436        ut.wrap_text, max_char=20)437    sns.catplot(y='Main domain', x='acc_diff', s=3, jitter=True, 438                data=results_df, ax=axes[0])439    axes[0].set_xlabel('')440    axes[0].set_ylabel('')441    axes[0].axvline(0, c='k', alpha=0.2)442    sns.boxplot(x='acc_diff', data=diff_df, ax=axes[1])443    sns.swarmplot(x='acc_diff', data=diff_df, color="0", size=2, ax=axes[1])444    axes[1].axvline(0, c='k', alpha=0.2)445    axes[1].set_xlabel('Accuracy difference')446    fig.subplots_adjust(wspace=0, hspace=0.02)447    plt.tight_layout()448    logger.info('Number of studies included in the accuracy improvement analysis: {}'.format(449        results_df.shape[0]))450    median = diff_df['acc_diff'].median()451    iqr = diff_df['acc_diff'].quantile(.75) - diff_df['acc_diff'].quantile(.25)452    logger.info('Median gain in accuracy: {:.6f}'.format(median))453    logger.info('Interquartile range of the gain in accuracy: {:.6f}'.format(iqr))454    best_improvement = diff_df.nlargest(3, 'acc_diff')455    logger.info('Best improvement in accuracy: {}, in {}'.format(456        best_improvement['acc_diff'].values[0], 457        best_improvement['Citation'].values[0]))458    logger.info('Second best improvement in accuracy: {}, in {}'.format(459        best_improvement['acc_diff'].values[1], 460        best_improvement['Citation'].values[1]))461    logger.info('Third best improvement in accuracy: {}, in {}'.format(462        best_improvement['acc_diff'].values[2], 463        
best_improvement['Citation'].values[2]))464    if save_cfg is not None:465        savename = 'reported_accuracy_per_domain'466        fname = os.path.join(save_cfg['savepath'], savename)467        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)468    return axes469def _plot_results_stats_impact_on_acc_diff(results_df, save_cfg):470    """Run statistical analysis to see which data items correlate with acc diff.471    NOTE: This analysis is not perfectly accurate as there are several papers 472        which contrasted results based on data items (e.g., testing the impact473        of number of layers on performance), but our summaries are not at this474        level of granularity. Therefore the results are not to be taken at face475        value.476    """477    binary_data_items = {'Preprocessing (clean)': ['Yes', 'No'],478                         'Artefact handling (clean)': ['Yes', 'No'],479                         'Features (clean)': ['Raw EEG', 'Frequency-domain'],480                         'Regularization (clean)': ['Yes', 'N/M'],481                         'Intra/Inter subject': ['Intra', 'Inter']}482    multiclass_data_items = ['Architecture',  # Architecture (clean)',483                             'Optimizer (clean)']484    continuous_data_items = {'Layers (clean)': False,485                             'Data - subjects': True,486                             'Data - time': True,487                             'Data - samples': True}488    results = dict()489    for key, val in binary_data_items.items():490        results[key] = ut.run_mannwhitneyu(results_df, key, val, plot=True)491    for i in multiclass_data_items:492        results[i] = ut.run_kruskal(results_df, i, plot=True)493    for i in continuous_data_items:494        single_df = ut.keep_single_valued_rows(results_df, i)495        single_df = single_df[single_df[i] != 'N/M']496        single_df[i] = single_df[i].astype(float)497        results[i] = ut.run_spearmanr(single_df, i, 
log=val, plot=True)498    499    stats_df =  pd.DataFrame(results).T500    logger.info('Results of statistical tests on impact of data items:\n{}'.format(501        stats_df))502    # Categorical plot for each "significant" data item503    significant_items = stats_df[stats_df['pvalue'] < 0.05].index504    if save_cfg is not None and len(significant_items) > 0:505        for i in significant_items:506            savename = 'stat_impact_{}_on_acc_diff'.format(507                i.replace(' ', '_').replace('/', '_'))508            fname = os.path.join(save_cfg['savepath'], savename)509            stats_df.loc[i, 'fig'].savefig(510                fname + '.' + save_cfg['format'], **save_cfg)511    return None512def _compute_acc_diff_for_preprints(results_df, save_cfg):513    """Analyze the acc diff for preprints vs. peer-reviewed articles.514    """515    results_df['preprint'] = results_df['Journal / Origin'].isin(['Arxiv', 'BioarXiv'])516    preprints = results_df.groupby('Citation').first()['preprint'].value_counts()517    logger.info(518        'Number of preprints included in the accuracy difference comparison: '519        '{}/{} papers'.format(preprints[True], sum(preprints)))520    logger.info('Median acc diff for preprints vs. non-preprint:\n{}'.format(521        results_df.groupby('preprint').median()['acc_diff']))522    results = ut.run_mannwhitneyu(results_df, 'preprint', [True, False])523    logger.info('Mann-Whitney test on preprint vs. 
not preprint: {:0.3f}'.format(524        results['pvalue']))525    return results526def generate_wordcloud(df, save_cfg=cfg.saving_config):527    brain_mask = np.array(Image.open("./img/brain_stencil.png"))528    def transform_format(val):529        if val == 0:530            return 255531        else:532            return val533    text = (df['Title']).to_string()534    stopwords = set(STOPWORDS)535    stopwords.add("using")536    stopwords.add("based")537    wc = WordCloud(538        background_color="white", max_words=2000, max_font_size=50, mask=brain_mask,539        stopwords=stopwords, contour_width=1, contour_color='steelblue')540    wc.generate(text)541    # store to file542    if save_cfg is not None:543        fname = os.path.join(save_cfg['savepath'], 'DL-EEG_WordCloud')544        wc.to_file(fname + '.' + save_cfg['format']) #, **save_cfg)545def plot_model_inspection_and_table(df, cutoff=1, save_cfg=cfg.saving_config):546    """Make bar graph and table listing method inspection techniques.547    Args:548        df (DataFrame)549    Keyword Args:550        cutoff (int): Metrics with less than this number of papers will be cut551            off from the bar graph.552        save_cfg (dict)553    """554    df['inspection_list'] = df[555        'Model inspection (clean)'].str.split(',').apply(ut.lstrip)556    inspection_per_article = list()557    for i, items in df[['Citation', 'inspection_list']].iterrows():558        for m in items['inspection_list']:559            inspection_per_article.append([i, items['Citation'], m])560            561    inspection_df = pd.DataFrame(562        inspection_per_article, 563        columns=['paper nb', 'Citation', 'inspection method'])564    # Remove "no" entries, because they make it really hard to see the 565    # actual distribution566    n_nos = inspection_df['inspection method'].value_counts()['no']567    n_papers = inspection_df.shape[0]568    logger.info('Number of papers without model inspection method: 
{}'.format(n_nos))569    inspection_df = inspection_df[inspection_df['inspection method'] != 'no']570    # # Replace "no" by "None"571    # inspection_df['inspection method'][572    #     inspection_df['inspection method'] == 'no'] = 'None'573    # Removing low count categories574    inspection_counts = inspection_df['inspection method'].value_counts()575    inspection_df = inspection_df[inspection_df['inspection method'].isin(576        inspection_counts[(inspection_counts >= cutoff)].index)]577    578    inspection_df['inspection method'] = inspection_df['inspection method'].apply(579        lambda x: x.capitalize())580    print(inspection_df['inspection method'])581    # Making table582    inspection_table = inspection_df.groupby(['inspection method'])[583        'Citation'].apply(list)584    order = inspection_df['inspection method'].value_counts().index585    inspection_table = inspection_table.reindex(order)586    inspection_table = inspection_table.apply(lambda x: r'\cite{' + ', '.join(x) + '}')587    with open(os.path.join(save_cfg['table_savepath'], 'inspection_methods.tex'), 'w') as f:588        with pd.option_context("max_colwidth", 1000):589            f.write(inspection_table.to_latex(escape=False))590    fig, ax = plt.subplots(figsize=(save_cfg['text_width'] / 4 * 3, 591                                    save_cfg['text_height'] / 2))592    ax = sns.countplot(y='inspection method', data=inspection_df, 593                    order=inspection_df['inspection method'].value_counts().index)594    ax.set_xlabel('Number of papers')595    ax.set_ylabel('')596    ax.xaxis.set_major_locator(MaxNLocator(integer=True))597    plt.tight_layout()598    logger.info('% of studies that used model inspection techniques: {}'.format(599        100 - 100 * (n_nos / n_papers)))600    if save_cfg is not None:601        savename = 'model_inspection'602        fname = os.path.join(save_cfg['savepath'], savename)603        fig.savefig(fname + '.' 
+ save_cfg['format'], **save_cfg)604    return ax605def plot_type_of_paper(df, save_cfg=cfg.saving_config):606    """Plot bar graph showing the type of each paper (journal, conference, etc.).607    """608    # Move supplements to journal paper category for the plot (a value of one is609    # not visible on a bar graph).610    df_plot = df.copy()611    df_plot.loc[df['Type of paper'] == 'Supplement', :] = 'Journal'612    fig, ax = plt.subplots(figsize=(save_cfg['text_width'] / 4, 613                                    save_cfg['text_height'] / 5))614    sns.countplot(x=df_plot['Type of paper'], ax=ax)615    ax.set_xlabel('')616    ax.set_ylabel('Number of papers')617    ax.set_xticklabels(ax.get_xticklabels(), rotation=90)618    plt.tight_layout()619    counts = df['Type of paper'].value_counts()620    logger.info('Number of journal papers: {}'.format(counts['Journal']))621    logger.info('Number of conference papers: {}'.format(counts['Conference']))622    logger.info('Number of preprints: {}'.format(counts['Preprint']))623    logger.info('Number of papers that were initially published as preprints: '624                '{}'.format(df[df['Type of paper'] != 'Preprint'][625                    'Preprint first'].value_counts()['Yes']))626    if save_cfg is not None:627        fname = os.path.join(save_cfg['savepath'], 'type_of_paper')628        fig.savefig(fname + '.' 
+ save_cfg['format'], **save_cfg)629    return ax630def plot_country(df, save_cfg=cfg.saving_config):631    """Plot bar graph showing the country of the first author's affiliation.632    """633    fig, ax = plt.subplots(figsize=(save_cfg['text_width'] / 4 * 3, 634                                    save_cfg['text_height'] / 5))635    sns.countplot(x=df['Country'], ax=ax,636                order=df['Country'].value_counts().index)637    ax.set_ylabel('Number of papers')638    ax.set_xlabel('')639    ax.set_xticklabels(ax.get_xticklabels(), rotation=90)640    plt.tight_layout()641    top3 = df['Country'].value_counts().index[:3]642    logger.info('Top 3 countries of first author affiliation: {}'.format(top3.values))643    if save_cfg is not None:644        fname = os.path.join(save_cfg['savepath'], 'country')645        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)646    return ax647def plot_countrymap(dfx, postprocess=True, save_cfg=cfg.saving_config):648    """Plot world map with colour indicating number of papers.649    Plot a world map where the colour of each country indicates how many papers650    were published in which the first author's affiliation was from that country.651    When saved as .eps this figure is well over the 6 MB limit allowed by arXiv.652    To solve this, we first save it as a .png (with high enough dpi), then use653    inkscape to convert it to .eps (leading to a file of ~1.6 MB):654    >> inkscape countrymap.png --export-eps=countrymap.eps655    Keyword Args:656        postpocess (bool): if True, convert PNG to EPS using inkscape in a 657            system call.658    """659    dirname = os.path.dirname(__file__)660    shapefile = os.path.join(dirname, '../img/countries/ne_10m_admin_0_countries.shp')661    gdf = gpd.read_file(shapefile)[['ADMIN', 'geometry']] #.to_crs('+proj=robin')662    # gdf = gdf.to_crs(epsg=4326)663    gdf.crs = '+init=epsg:4326'664    dfx = dfx.Country.value_counts().reset_index().rename(665        
columns={'index': 'Country', 'Country': 'Count'})666    #print("Renaming Exceptions!")667    #print(dfx.loc[~dfx['Country'].isin(gdf['ADMIN'])])668    # Exception #1 - USA: United States of America669    dfx.loc[dfx['Country'] == 'USA', 'Country'] = 'United States of America'670    # Exception #2 - UK: United Kingdom671    dfx.loc[dfx['Country'] == 'UK', 'Country'] = 'United Kingdom'672    # Exception #3 - Bosnia: Bosnia and Herzegovina673    dfx.loc[dfx['Country'] == 'Bosnia', 'Country'] = 'Bosnia and Herzegovina'674    if len(dfx.loc[~dfx['Country'].isin(gdf['ADMIN'])]) > 0:675        print("## ERROR ## - Unhandled Countries!")676    # Adding 0 to all other countries!677    gdf['Count'] = 0678    for c in gdf['ADMIN']:679        if any(dfx['Country'].str.contains(c)):680            gdf.loc[gdf['ADMIN'] == c, 'Count'] = dfx[681                dfx['Country'].str.contains(c)]['Count'].values[0]682        else:683            gdf.loc[gdf['ADMIN'] == c, 'Count'] = 0684    # figsize = (16, 10)685    figsize = (save_cfg['text_width'], save_cfg['text_height'] / 2)686    ax = gdf.plot(column='Count', figsize=figsize, cmap='Blues', 687                  scheme='Fisher_Jenks', k=10, legend=True, edgecolor='k',688                  linewidth=0.3, categorical=False, vmin=0,689                  legend_kwds={'loc': 'lower left', 'title': 'Number of studies',690                               'framealpha': 1},691                  rasterized=False)692    # Remove floating points in legend693    leg = ax.get_legend()694    for t in leg.get_texts():695        t.set_text(t.get_text().replace('.00', ''))696    ax.set_axis_off()697    fig = ax.get_figure()698    plt.tight_layout()699    700    if save_cfg is not None:701        fname = os.path.join(save_cfg['savepath'], 'countrymap')702        save_cfg2 = save_cfg.copy()703        save_cfg2['dpi'] = 1000704        save_cfg2['format'] = 'png'705        fig.savefig(fname + '.png', **save_cfg2)706        if postprocess:707            
subprocess.call(708                ['inkscape', fname + '.png', '--export-eps=' + fname + '.eps'])709    return ax710def compute_prct_statistical_tests(df):711    """Compute the number of studies that used statistical tests.712    """713    prct = 100 - 100 * df['Statistical analysis of performance'].value_counts(714        )['No'] / df.shape[0]715    logger.info('% of studies that used statistical test: {}'.format(prct))716def make_domain_table(df, save_cfg=cfg.saving_config):717    """Make domain table that contains every reference.718    """719    # Replace NaNs by ' ' in 'Domain 3' and 'Domain 4' columns720    df = ut.replace_nans_in_column(df, 'Domain 3', replace_by=' ')721    df = ut.replace_nans_in_column(df, 'Domain 4', replace_by=' ')722    cols = ['Domain 1', 'Domain 2', 'Domain 3', 'Domain 4', 'Architecture (clean)']723    df[cols] = df[cols].applymap(ut.tex_escape)724    # Make tuple of first 2 domain levels725    domains_df = df.groupby(cols)['Citation'].apply(list).apply(726        lambda x: '\cite{' + ', '.join(x) + '}').unstack()727    domains_df = domains_df.applymap(728        lambda x: ' ' if isinstance(x, float) and np.isnan(x) else x)729    fname = os.path.join(save_cfg['table_savepath'], 'domains_architecture_table.tex')730    with open(fname, 'w') as f:731        with pd.option_context("max_colwidth", 1000):732            f.write(domains_df.to_latex(733                escape=False, 734                column_format='p{1.5cm}' * 4 + 'p{0.6cm}' * domains_df.shape[1]))735def plot_preprocessing_proportions(df, save_cfg=cfg.saving_config):736    """Plot proportions for preprocessing-related data items.737    """738    data = dict()739    data['(a) Preprocessing of EEG data'] = df[740         'Preprocessing (clean)'].value_counts().to_dict()741    data['(b) Artifact handling'] = df[742         'Artefact handling (clean)'].value_counts().to_dict()743    data['(c) Extracted features'] = df[744         'Features (clean)'].value_counts().to_dict()745    
fig, ax = ut.plot_multiple_proportions(746        data, print_count=5, respect_order=['Yes', 'No', 'Other', 'N/M'],747        figsize=(save_cfg['text_width'] / 4 * 4, save_cfg['text_height'] / 7 * 2))748    749    if save_cfg is not None:750        fname = os.path.join(save_cfg['savepath'], 'preprocessing')751        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)752    return ax753def plot_hyperparams_proportions(df, save_cfg=cfg.saving_config):754    """Plot proportions for hyperparameter-related data items.755    """756    data = dict()757    data['(a) Training procedure'] = df[758         'Training procedure (clean)'].value_counts().to_dict()759    data['(b) Regularization'] = df[760         'Regularization (clean)'].value_counts().to_dict()761    data['(c) Optimizer'] = df[762         'Optimizer (clean)'].value_counts().to_dict()763    fig, ax = ut.plot_multiple_proportions(764        data, print_count=5, respect_order=['Yes', 'No', 'Other', 'N/M'],765        figsize=(save_cfg['text_width'] / 4 * 4, save_cfg['text_height'] / 7 * 2))766    767    if save_cfg is not None:768        fname = os.path.join(save_cfg['savepath'], 'hyperparams')769        fig.savefig(fname + '.' 
+ save_cfg['format'], **save_cfg)770    return ax771def plot_reproducibility_proportions(df, save_cfg=cfg.saving_config):772    """Plot proportions for reproducibility-related data items.773    """774    df['Code hosted on'] = df['Code hosted on'].replace(np.nan, 'N/M', regex=True)775    df['Limited data'] = df['Limited data'].replace(np.nan, 'N/M', regex=True)776    df['Code available'] = df['Code available'].replace(np.nan, 'N/M', regex=True)777    data = dict()778    data['(a) Dataset availability'] = df[779         'Dataset accessibility'].value_counts().to_dict()780    data['(b) Code availability'] = df[781         'Code hosted on'].value_counts().to_dict()782    data['(c) Type of baseline'] = df[783         'Baseline model type'].value_counts().to_dict()784    df['reproducibility'] = 'Hard'785    df.loc[(df['Code available'] == 'Yes') & 786           (df['Dataset accessibility'] == 'Public'), 'reproducibility'] = 'Easy' 787    df.loc[(df['Code available'] == 'Yes') & 788           (df['Dataset accessibility'] == 'Both'), 'reproducibility'] = 'Medium' 789    df.loc[(df['Code available'] == 'No') & 790           (df['Dataset accessibility'] == 'Private'), 'reproducibility'] = 'Impossible' 791    data['(d) Reproducibility'] = df[792         'reproducibility'].value_counts().to_dict()793    logger.info('Stats on reproducibility - Dataset Accessibility: {}'.format(794        data['(a) Dataset availability']))795    logger.info('Stats on reproducibility - Code Accessibility: {}'.format(796        df['Code available'].value_counts().to_dict()))797    logger.info('Stats on reproducibility - Code Hosted On: {}'.format(798        data['(b) Code availability']))799    logger.info('Stats on reproducibility - Baseline: {}'.format(800        data['(c) Type of baseline']))801    logger.info('Stats on reproducibility - Reproducibility Level: {}'.format(802        data['(d) Reproducibility']))803    logger.info('Stats on reproducibility - Limited data: {}'.format(804        
df['Limited data'].value_counts().to_dict()))805    logger.info('Stats on reproducibility - Shared their Code: {}'.format(806        df[df['Code available'] == 'Yes']['Citation'].to_dict()))807    fig, ax = ut.plot_multiple_proportions(808        data, print_count=5, respect_order=['Easy', 'Medium', 'Hard', 'Impossible'],809        figsize=(save_cfg['text_width'] / 4 * 4, save_cfg['text_height'] * 0.4))810    811    if save_cfg is not None:812        fname = os.path.join(save_cfg['savepath'], 'reproducibility')813        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)814    return ax815def plot_domains_per_year(df, save_cfg=cfg.saving_config):816    """Plot stacked bar graph of domains per year.817    """818    fig, ax = plt.subplots(819        figsize=(save_cfg['text_width'] / 4 * 2, save_cfg['text_height'] / 4))820    df['Year'] = df['Year'].astype('int32')821    main_domains = ['Epilepsy', 'Sleep', 'BCI', 'Affective', 'Cognitive', 822                    'Improvement of processing tools', 'Generation of data']823    domains_df = df[['Domain 1', 'Domain 2', 'Domain 3', 'Domain 4']]824    df['Main domain'] = [row[row.isin(main_domains)].values[0] 825        if any(row.isin(main_domains)) else 'Others' 826        for ind, row in domains_df.iterrows()]827    df.groupby(['Year', 'Main domain']).size().unstack('Main domain').plot(828        kind='bar', stacked=True, title='', ax=ax)829    ax.set_ylabel('Number of papers')830    ax.set_xlabel('')831    legend = plt.legend()832    for l in legend.get_texts():833        l.set_text(ut.wrap_text(l.get_text(), max_char=14))834    if save_cfg is not None:835        fname = os.path.join(save_cfg['savepath'], 'domains_per_year')836        fig.savefig(fname + '.' 
+ save_cfg['format'], **save_cfg)837    return ax838def plot_hardware(df, save_cfg=cfg.saving_config):839    """Plot bar graph showing the hardware used in the study.840    """841    col = 'EEG Hardware'842    hardware_df = ut.split_column_with_multiple_entries(843        df, col, ref_col='Citation', sep=',', lower=False)844    # Remove N/Ms because they make it hard to see anything845    hardware_df = hardware_df[hardware_df[col] != 'N/M']846    847    # Add low cost column848    hardware_df['Low-cost'] = False849    low_cost_devices = ['EPOC (Emotiv)', 'OpenBCI (OpenBCI)', 'Muse (InteraXon)', 850                        'Mindwave Mobile (Neurosky)', 'Mindset (NeuroSky)']851    hardware_df.loc[hardware_df[col].isin(low_cost_devices), 852                    'Low-cost'] = True853    fig, ax = plt.subplots(figsize=(save_cfg['text_width'] / 4 * 2, 854                                    save_cfg['text_height'] / 5 * 2))855    sns.countplot(hue=hardware_df['Low-cost'], y=hardware_df[col], ax=ax,856                  order=hardware_df[col].value_counts().index, 857                  dodge=False)858    # sns.catplot(row=hardware_df['low_cost'], y=hardware_df['hardware'])859    ax.set_xlabel('Number of papers')860    ax.set_ylabel('')861    plt.tight_layout()862    if save_cfg is not None:863        fname = os.path.join(save_cfg['savepath'], 'hardware')864        fig.savefig(fname + '.' 
+ save_cfg['format'], **save_cfg)865    return ax866def plot_architectures(df, save_cfg=cfg.saving_config):867    """Plot bar graph showing the architectures used in the study.868    """869    fig, ax = plt.subplots(figsize=(save_cfg['text_width'] / 3, 870                                    save_cfg['text_width'] / 3))871    colors = sns.color_palette()872    counts = df['Architecture (clean)'].value_counts()873    _, _, pct = ax.pie(counts.values, labels=counts.index, autopct='%1.1f%%',874           wedgeprops=dict(width=0.3, edgecolor='w'), colors=colors,875           pctdistance=0.55)876    for i in pct:877        i.set_fontsize(5)878    ax.axis('equal')879    plt.tight_layout()880    if save_cfg is not None:881        fname = os.path.join(save_cfg['savepath'], 'architectures')882        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)883    return ax884    885def plot_architectures_per_year(df, save_cfg=cfg.saving_config):886    """Plot stacked bar graph of architectures per year.887    """888    fig, ax = plt.subplots(889        figsize=(save_cfg['text_width'] / 3 * 2, save_cfg['text_width'] / 3))890    colors = sns.color_palette()891    df['Year'] = df['Year'].astype('int32')892    col_name = 'Architecture (clean)'893    df['Arch'] = df[col_name]894    order = df[col_name].value_counts().index895    counts = df.groupby(['Year', 'Arch']).size().unstack('Arch')896    counts = counts[order]897    counts.plot(kind='bar', stacked=True, title='', ax=ax, color=colors)898    ax.legend(loc='upper left', bbox_to_anchor=(1, 1))899    ax.set_ylabel('Number of papers')900    ax.set_xlabel('')901    plt.tight_layout()902    if save_cfg is not None:903        fname = os.path.join(save_cfg['savepath'], 'architectures_per_year')904        fig.savefig(fname + '.' 
+ save_cfg['format'], **save_cfg)905    return ax906def plot_architectures_vs_input(df, save_cfg=cfg.saving_config):907    """Plot stacked bar graph of architectures vs input type.908    """909    fig, ax = plt.subplots(910        figsize=(save_cfg['text_width'] / 4 * 2, save_cfg['text_width'] / 3))911    df['Input'] = df['Features (clean)']912    col_name = 'Architecture (clean)'913    df['Arch'] = df[col_name]914    order = df[col_name].value_counts().index915    counts = df.groupby(['Input', 'Arch']).size().unstack('Input')916    counts = counts.loc[order, :]917    # To reduce the height of the figure, wrap long xticklabels918    counts = counts.rename({'CNN+RNN': 'CNN+\nRNN'}, axis='index')919    counts.plot(kind='bar', stacked=True, title='', ax=ax)920    # ax.legend(loc='upper left', bbox_to_anchor=(1, 1))921    ax.set_ylabel('Number of papers')922    ax.set_xlabel('')923    plt.tight_layout()924    if save_cfg is not None:925        fname = os.path.join(save_cfg['savepath'], 'architectures_vs_input')926        fig.savefig(fname + '.' 
+ save_cfg['format'], **save_cfg)927        save_cfg2 = save_cfg.copy()928        save_cfg2['format'] = 'png'929        fig.savefig(fname + '.png', **save_cfg2)930    return ax931def plot_optimizers_per_year(df, save_cfg=cfg.saving_config):932    """Plot stacked bar graph of optimizers per year.933    """934    fig, ax = plt.subplots(935        figsize=(save_cfg['text_width'] / 4 * 2, save_cfg['text_width'] / 5 * 2))936    df['Input'] = df['Features (clean)']937    col_name = 'Optimizer (clean)'938    df['Opt'] = df[col_name]939    order = df[col_name].value_counts().index940    counts = df.groupby(['Year', 'Opt']).size().unstack('Opt')941    counts = counts[order]942    counts.plot(kind='bar', stacked=True, title='', ax=ax)943    ax.legend(loc='upper left', bbox_to_anchor=(1, 1))944    ax.set_ylabel('Number of papers')945    ax.set_xlabel('')946    plt.tight_layout()947    if save_cfg is not None:948        fname = os.path.join(save_cfg['savepath'], 'optimizers_per_year')949        fig.savefig(fname + '.' 
+ save_cfg['format'], **save_cfg)950    return ax951def plot_intra_inter_per_year(df, save_cfg=cfg.saving_config):952    """Plot stacked bar graph of intra-/intersubject studies per year.953    """954    fig, ax = plt.subplots(955        figsize=(save_cfg['text_width'] / 4 * 2, save_cfg['text_height'] / 4))956    df['Year'] = df['Year'].astype(int)957    col_name = 'Intra/Inter subject'958    order = df[col_name].value_counts().index959    counts = df.groupby(['Year', col_name]).size().unstack(col_name)960    counts = counts[order]961    logger.info('Stats on inter/intra subjects: {}'.format(962        df[col_name].value_counts() / df.shape[0] * 100))963    counts.plot(kind='bar', stacked=True, title='', ax=ax)964    # ax.legend(loc='upper left', bbox_to_anchor=(1, 1))965    ax.set_ylabel('Number of papers')966    ax.set_xlabel('')967    plt.tight_layout()968    if save_cfg is not None:969        fname = os.path.join(save_cfg['savepath'], 'intra_inter_per_year')970        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)971    return ax972def plot_number_layers(df, save_cfg=cfg.saving_config):973    """Plot histogram of number of layers.974    """975    fig, ax = plt.subplots(976        figsize=(save_cfg['text_width'] / 4 * 2, save_cfg['text_width'] / 3))977    n_layers_df = df['Layers (clean)'].value_counts().reindex(978        [str(i) for i in range(1, 32)] + ['N/M'])979    n_layers_df = n_layers_df.dropna().astype(int)980    from matplotlib.colors import ListedColormap981    cmap = ListedColormap(sns.color_palette(None).as_hex())982    n_layers_df.plot(kind='bar', width=0.8, rot=0, colormap=cmap, ax=ax)983    ax.set_xlabel('Number of layers')984    ax.set_ylabel('Number of papers')985    plt.tight_layout()986    if save_cfg is not None:987        fname = os.path.join(save_cfg['savepath'], 'number_layers')988        fig.savefig(fname + '.' 
+ save_cfg['format'], **save_cfg)989        save_cfg2 = save_cfg.copy()990        save_cfg2['format'] = 'png'991        save_cfg2['dpi'] = 300992        fig.savefig(fname + '.png', **save_cfg2)993    return ax   994def plot_number_subjects_by_domain(df, save_cfg=cfg.saving_config):995    """Plot number of subjects in studies by domain.996    """997    # Split values into separate rows and remove invalid values998    col = 'Data - subjects'999    nb_subj_df = ut.split_column_with_multiple_entries(1000        df, col, ref_col='Main domain')1001    nb_subj_df = nb_subj_df.loc[~nb_subj_df[col].isin(['n/m', 'tbd'])]1002    nb_subj_df[col] = nb_subj_df[col].astype(int)1003    nb_subj_df = nb_subj_df.loc[nb_subj_df[col] > 0, :]1004    nb_subj_df['Main domain'] = nb_subj_df['Main domain'].apply(1005        ut.wrap_text, max_char=13)1006    fig, ax = plt.subplots(1007        figsize=(save_cfg['text_width'] / 3 * 2, save_cfg['text_height'] / 3))1008    ax.set(xscale='log', yscale='linear')1009    sns.swarmplot(1010        y='Main domain', x=col, data=nb_subj_df, 1011        ax=ax, size=3, order=nb_subj_df.groupby(['Main domain'])[1012            col].median().sort_values().index)1013    ax.set_xlabel('Number of subjects')1014    ax.set_ylabel('')1015    1016    logger.info('Stats on number of subjects per model: {}'.format(1017        nb_subj_df[col].describe()))1018    plt.tight_layout()1019    if save_cfg is not None:1020        fname = os.path.join(save_cfg['savepath'], 'nb_subject_per_domain')1021        fig.savefig(fname + '.' 
+ save_cfg['format'], **save_cfg)1022    return ax 1023def plot_number_channels(df, save_cfg=cfg.saving_config):1024    """Plot histogram of number of channels.1025    """1026    nb_channels_df = ut.split_column_with_multiple_entries(1027        df, 'Nb Channels', ref_col='Citation', sep=';\n', lower=False)1028    nb_channels_df['Nb Channels'] = nb_channels_df['Nb Channels'].astype(int)1029    nb_channels_df = nb_channels_df.loc[nb_channels_df['Nb Channels'] > 0, :]1030    fig, ax = plt.subplots(1031        figsize=(save_cfg['text_width'] / 2, save_cfg['text_height'] / 4))1032    sns.distplot(nb_channels_df['Nb Channels'], kde=False, norm_hist=False, ax=ax)1033    ax.set_xlabel('Number of EEG channels')1034    ax.set_ylabel('Number of papers')1035    logger.info('Stats on number of channels per model: {}'.format(1036        nb_channels_df['Nb Channels'].describe()))1037    plt.tight_layout()1038    if save_cfg is not None:1039        fname = os.path.join(save_cfg['savepath'], 'nb_channels')1040        fig.savefig(fname + '.' 
+ save_cfg['format'], **save_cfg)1041    return ax1042def compute_stats_sampling_rate(df):1043    """Compute the statistics for hardware sampling rate.1044    """1045    fs_df = ut.split_column_with_multiple_entries(1046        df, 'Sampling rate', ref_col='Citation', sep=';\n', lower=False)1047    fs_df['Sampling rate'] = fs_df['Sampling rate'].astype(float)1048    fs_df = fs_df.loc[fs_df['Sampling rate'] > 0, :]1049    logger.info('Stats on sampling rate per model: {}'.format(1050        fs_df['Sampling rate'].describe()))1051def plot_cross_validation(df, save_cfg=cfg.saving_config):1052    """Plot bar graph of cross validation approaches.1053    """1054    col = 'Cross validation (clean)'1055    df[col] = df[col].fillna('N/M')1056    cv_df = ut.split_column_with_multiple_entries(1057        df, col, ref_col='Citation', sep=';\n', lower=False)1058    1059    fig, ax = plt.subplots(1060        figsize=(save_cfg['text_width'] / 2, save_cfg['text_height'] / 5))1061    sns.countplot(y=cv_df[col], order=cv_df[col].value_counts().index, ax=ax)1062    ax.set_xlabel('Number of papers')1063    ax.set_ylabel('')1064    1065    plt.tight_layout()1066    if save_cfg is not None:1067        fname = os.path.join(save_cfg['savepath'], 'cross_validation')1068        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)1069    return ax1070def make_dataset_table(df, min_n_articles=2, save_cfg=cfg.saving_config):1071    """Make table that reports most used datasets.1072    Args:1073        df1074    Keyword Args:1075        min_n_articles (int): minimum number of times a dataset must have been1076            used to be listed in the table. 
If under that number, will appear as1077            'Other' in the table.1078        save_cfg (dict)1079    """1080    def merge_dataset_names(s):1081        if 'bci comp' in s.lower():1082            s = 'BCI Competition'1083        elif 'tuh' in s.lower():1084            s = 'TUH'1085        elif 'mahnob' in s.lower():1086            s = 'MAHNOB'1087        return s1088    col = 'Dataset name'1089    datasets_df = ut.split_column_with_multiple_entries(1090        df, col, ref_col=['Main domain', 'Citation'], sep=';\n', lower=False)1091    # Remove not mentioned and internal recordings, as readers won't be able to 1092    # use these datasets anyway1093    datasets_df = datasets_df.loc[~datasets_df[col].isin(1094        ['N/M', 'Internal Recordings', 'TBD'])]1095    datasets_df['Dataset'] = datasets_df[col].apply(merge_dataset_names).apply(1096        ut.tex_escape)1097    # Replace datasets that were used rarely by 'Other'1098    counts = datasets_df['Dataset'].value_counts()1099    datasets_df.loc[datasets_df['Dataset'].isin(1100        counts[counts < min_n_articles].index), 'Dataset'] = 'Other'1101    # Remove duplicates (due to grouping of Others and BCI Comp)1102    datasets_df = datasets_df.drop(labels=col, axis=1)1103    datasets_df = datasets_df.drop_duplicates()1104    # Group by dataset and order by number of articles1105    dataset_table = datasets_df.groupby(1106        ['Main domain', 'Dataset'], as_index=True)['Citation'].apply(list)1107    dataset_table = pd.concat([dataset_table.apply(len), dataset_table], axis=1)1108    dataset_table.columns = [r'\# articles', 'References']1109    dataset_table = dataset_table.sort_values(1110        by=['Main domain', r'\# articles'], ascending=[True, False])1111    dataset_table['References'] = dataset_table['References'].apply(1112        lambda x: r'\cite{' + ', '.join(x) + '}')1113    with open(os.path.join(save_cfg['table_savepath'], 'dataset_table.tex'), 'w') as f:1114        with 
pd.option_context("max_colwidth", 1000):1115            f.write(dataset_table.to_latex(escape=False, multicolumn=False))1116def plot_data_quantity(df, save_cfg=cfg.saving_config):1117    """Plot the quantity of data used by domain.1118    """1119    data_df = ut.split_column_with_multiple_entries(1120        df, ['Data - samples', 'Data - time'], ref_col=['Citation', 'Main domain'], 1121        sep=';\n', lower=False)1122    # Remove N/M and TBD1123    col = 'Data - samples'1124    data_df.loc[data_df[col].isin(['N/M', 'TBD', '[TBD]']), col] = np.nan1125    data_df[col] = data_df[col].astype(float)1126    col2 = 'Data - time'1127    data_df.loc[data_df[col2].isin(['N/M', 'TBD', '[TBD]']), col2] = np.nan1128    data_df[col2] = data_df[col2].astype(float)1129    # Wrap main domain text1130    data_df['Main domain'] = data_df['Main domain'].apply(1131        ut.wrap_text, max_char=13)1132    # Extract ratio1133    data_df['data_ratio'] = data_df['Data - samples'] / data_df['Data - time']1134    data_df = data_df.sort_values(['Main domain', 'data_ratio'])1135    # Plot1136    fig, axes = plt.subplots(1137        ncols=3, 1138        figsize=(save_cfg['text_width'], save_cfg['text_height'] / 3))1139    axes[0].set(xscale='log', yscale='linear')1140    sns.swarmplot(y='Main domain', x=col2, data=data_df, ax=axes[0], size=3)1141    axes[0].set_xlabel('Recording time (min)')1142    axes[0].set_ylabel('')1143    max_val = int(np.ceil(np.log10(data_df[col2].max())))1144    axes[0].set_xticks(np.power(10, range(0, max_val + 1)))1145    axes[1].set(xscale='log', yscale='linear')1146    sns.swarmplot(y='Main domain', x=col, data=data_df, ax=axes[1], size=3)1147    axes[1].set_xlabel('Number of examples')1148    axes[1].set_yticklabels('')1149    axes[1].set_ylabel('')1150    min_val = int(np.floor(np.log10(data_df[col].min())))1151    max_val = int(np.ceil(np.log10(data_df[col].max())))1152    axes[1].set_xticks(np.power(10, range(min_val, max_val + 1)))1153    
axes[2].set(xscale='log', yscale='linear')1154    sns.swarmplot(y='Main domain', x='data_ratio', data=data_df, ax=axes[2], 1155                  size=3)1156    axes[2].set_xlabel('Ratio (examples/min)')1157    axes[2].set_ylabel('')1158    axes[2].set_yticklabels('')1159    min_val = int(np.floor(np.log10(data_df['data_ratio'].min())))1160    max_val = int(np.ceil(np.log10(data_df['data_ratio'].max())))1161    axes[2].set_xticks(np.power(10, np.arange(min_val, max_val + 1, dtype=float)))1162    plt.tight_layout()1163    if save_cfg is not None:1164        fname = os.path.join(save_cfg['savepath'], 'data_quantity')1165        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)1166    return axes1167def plot_eeg_intro(save_cfg=cfg.saving_config):1168    """Plot a figure that shows basic EEG notions such as epochs and samples.1169    """1170    # Visualization parameters1171    win_len = 1  # in s1172    step = 0.5  # in s1173    first_epoch = 11174    data, t, fs = ut.get_real_eeg_data(start=30, stop=34, chans=[0, 10, 20, 30])1175    t = t - t[0]1176    # Offset data for visualization1177    data -= data.mean(axis=0)1178    max_std = np.max(data.std(axis=0))1179    offsets = np.arange(data.shape[1])[::-1] * 4 * max_std1180    data += offsets1181    rect_y_border = 0.6 * max_std1182    min_y = data.min() - rect_y_border1183    max_y = data.max() + rect_y_border1184    # Make figure1185    fig, ax = plt.subplots(1186        figsize=(save_cfg['text_width'] / 4 * 3, save_cfg['text_height'] / 3))1187    ax.plot(t, data)1188    ax.set_xlabel('Time (s)')1189    ax.set_ylabel(r'Amplitude (e.g., $\mu$V)')1190    ax.set_yticks(offsets)1191    ax.set_yticklabels(['channel {}'.format(i + 1) for i in range(data.shape[1])])1192    ax.spines['top'].set_visible(False)1193    ax.spines['right'].set_visible(False)1194    # Display epochs as dashed line rectangles1195    rect1 = patches.Rectangle((first_epoch, min_y + rect_y_border / 4), 1196                            win_len, 
max_y - min_y, 1197                            linewidth=1, linestyle='--', edgecolor='k',1198                            facecolor='none')1199    rect2 = patches.Rectangle((first_epoch + step, min_y - rect_y_border / 4), 1200                            win_len, max_y - min_y, 1201                            linewidth=1, linestyle='--', edgecolor='k',1202                            facecolor='none')1203    ax.add_patch(rect1)1204    ax.add_patch(rect2)1205    # Annotate epochs1206    ax.annotate(1207        r'$\bf{Window}$ or $\bf{epoch}$ or $\bf{trial}$' +1208        '\n({:.0f} points in a \n1-s window at {:.0f} Hz)'.format(fs, fs), #fontsize=14, 1209        xy=(first_epoch, min_y), 1210        arrowprops=dict(facecolor='black', shrink=0.05, width=2, headwidth=6),1211        xytext=(0, min_y - 3.5 * max_std),1212        xycoords='data', ha='center', va='top')1213    1214    # Annotate input1215    ax.annotate(r'Neural network input' + '\n'1216        r'$X_i \in \mathbb{R}^{c \times l}$', #fontsize=14,1217        xy=(first_epoch+1.5, min_y),1218        arrowprops=dict(facecolor='black', shrink=0.05, width=2),1219        xytext=(4, min_y - 5.3 * max_std),1220        xycoords='data', ha='right', va='bottom')1221    # Annotate sample1222    special_ind = np.where((t >= 2.4) & (t < 2.5))[0][0]1223    special_point = data[special_ind, 0]1224    ax.plot(t[special_ind], special_point, '.', c='k')1225    ax.annotate(1226        r'$\bf{Point}$ or $\bf{sample}$', #fontsize=14, 1227        xy=(t[special_ind], special_point), 1228        arrowprops=dict(facecolor='black', shrink=0.05, width=2, headwidth=6),1229        xytext=(3, max_y),1230        xycoords='data', ha='left', va='bottom')1231    # Annotate overlap1232    ut.draw_brace(ax, (first_epoch + step, first_epoch + step * 2), 1233            r'0.5-s $\bf{overlap}$' + '\nbetween windows', 1234            beta_factor=300, y_offset=max_y)1235    plt.tight_layout()1236    if save_cfg is not None:1237        fname = 
os.path.join(save_cfg['savepath'], 'eeg_intro')1238        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)...utils.py
Source:utils.py  
...119        }120        'optimizer_state_dict': xxx121    }122    Args:123        save_cfg(namedtuple): The configuration object of the saving process. It must contain the follows:124            save_cfg.output_dir: the output directory of the model or checkpoint125            save_cfg.low: the lowest value of the metric to save126            save_cfg.cfg_name: the name of the configuration127            save_cfg.dataset_name: the name of the dataset128            save_cfg.model_name: the name of the model129            save_cfg.time_stamp: the time when the engine start130        cfg_name: the name of the configuration file  # will be deprecated in the future131        current_step(int): Which step stores the ckpt or model132        saved_stuff(dict): The saving things incules the model, epoch, loss, optimizer. At least, it contains the model. For example:133            {134                'step': 0,135                'loss':0,136                $model_name: xxxx,137                $optimizer_name: xxx
...save_img.py
Source:save_img.py  
1import glob2import os3import SimpleITK as sitk4import imageio5import matplotlib6import matplotlib.pyplot as plt7import numpy as np8matplotlib.use('Agg')9import torch10def plot_save_img(input_img, pred_img, target_img, save_cfg):11    input_img = input_img.squeeze()12    pred_img = pred_img.squeeze()13    target_img = target_img.squeeze()14    input_img = torch.clamp(input_img, min=0, max=1.)15    pred_img = torch.clamp(pred_img, min=0, max=1.)16    target_img = torch.clamp(target_img, min=0, max=1.)17    if len(input_img) == 3:18        input_img = input_img.squeeze().permute(1, 2, 0).cpu().numpy()19        pred_img = pred_img.squeeze().permute(1, 2, 0).cpu().numpy()20        target_img = target_img.squeeze().permute(1, 2, 0).cpu().numpy()21    else:22        input_img = input_img.squeeze().cpu().numpy()23        pred_img = pred_img.squeeze().cpu().numpy()24        target_img = target_img.squeeze().cpu().numpy()25    fig, ax = plt.subplots(1, 3, figsize=(30, 10))26    fig.suptitle(save_cfg['name'], fontsize=20)27    ax[0].imshow(input_img, cmap='gray')28    ax[0].set_title("Input", fontsize=20)29    ax[0].grid(False)30    ax[0].set_xlabel("PSNR: {:.4f}\nSSIM: {:.4f}\n".format(save_cfg['org_psnr'], save_cfg['org_ssim']), fontsize=20)31    ax[1].imshow(pred_img, cmap='gray')32    ax[1].set_title("Prediction", fontsize=20)33    ax[1].grid(False)34    ax[1].set_xlabel("PSNR: {:.4f}\nSSIM: {:.4f}\n".format(save_cfg['pred_psnr'], save_cfg['pred_ssim']), fontsize=20)35    ax[2].imshow(target_img, cmap='gray')36    ax[2].set_title("Target", fontsize=20)37    ax[2].grid(False)38    fig.savefig(os.path.join(save_cfg['save_path'], 'result_{}.png'.format(save_cfg['name'])))39    plt.close()40def plot_save_ct(input_img, pred_img, target_img, save_cfg):41    input_img = torch.clamp(input_img, min=0, max=1.)42    pred_img = torch.clamp(pred_img, min=0, max=1.)43    target_img = torch.clamp(target_img, min=0, max=1.)44    if len(input_img) == 3:45        input_img = 
input_img.squeeze().permute(1, 2, 0).cpu().numpy()46        pred_img = pred_img.squeeze().permute(1, 2, 0).cpu().numpy()47        target_img = target_img.squeeze().permute(1, 2, 0).cpu().numpy()48    else:49        input_img = input_img.squeeze().cpu().numpy()50        pred_img = pred_img.squeeze().cpu().numpy()51        target_img = target_img.squeeze().cpu().numpy()52    input_img = trunc_denorm(input_img, trunc_max=240, trunc_min=-160)53    pred_img = trunc_denorm(pred_img, trunc_max=240, trunc_min=-160)54    target_img = trunc_denorm(target_img, trunc_max=240, trunc_min=-160)55    fig, ax = plt.subplots(1, 3, figsize=(30, 10))56    fig.suptitle(save_cfg['name'], fontsize=20)57    ax[0].imshow(input_img, cmap='gray')58    ax[0].set_title("Input", fontsize=20)59    ax[0].grid(False)60    ax[0].set_xlabel("PSNR: {:.4f}\nSSIM: {:.4f}\n".format(save_cfg['org_psnr'], save_cfg['org_ssim']), fontsize=20)61    ax[1].imshow(pred_img, cmap='gray')62    ax[1].set_title("Prediction", fontsize=20)63    ax[1].grid(False)64    ax[1].set_xlabel("PSNR: {:.4f}\nSSIM: {:.4f}\n".format(save_cfg['pred_psnr'], save_cfg['pred_ssim']), fontsize=20)65    ax[2].imshow(target_img, cmap='gray')66    ax[2].set_title("Target", fontsize=20)67    ax[2].grid(False)68    fig.savefig(os.path.join(save_cfg['save_path'], 'plot_{}.png'.format(save_cfg['name'])))69    plt.close()70def save_single_slice(pred_img, save_cfg):71    pred_img = torch.clamp(pred_img, min=0, max=1.)72    if len(pred_img) == 3:73        pred_img = pred_img.squeeze().permute(1, 2, 0).cpu().numpy()74    else:75        pred_img = pred_img.squeeze().cpu().numpy()76    pred_img = trunc_denorm(pred_img) + 102477    pred_img = ((pred_img / 4096) * 65535).astype(np.uint16)78    imageio.imwrite(os.path.join(save_cfg['save_path'], 'single_{}.png'.format(save_cfg['name'])), pred_img)79def trunc_denorm(image, trunc_max=3072.0, trunc_min=-1024.0, norm_range_max=3072.0, norm_range_min=-1024.0):80    image = denormalize(image, 
norm_range_max, norm_range_min)81    image = trunc(image, trunc_max, trunc_min)82    return image83def denormalize(image, norm_range_max=3072.0, norm_range_min=-1024.0):84    image = image * (norm_range_max - norm_range_min) + norm_range_min85    return image86def trunc(image, trunc_max, trunc_min):87    image[image <= trunc_min] = trunc_min88    image[image >= trunc_max] = trunc_max89    return image90def save_nifti(arrays, save_cfg, params):91    nifti_org = sorted(glob.glob(os.path.join(params.data_dir, params.subject, '*.nii.gz')))92    image_org = sitk.ReadImage(nifti_org[0])93    images = sitk.GetImageFromArray(arrays)94    images.CopyInformation(image_org)...Learn to execute automation testing from scratch with LambdaTest Learning Hub. Right from setting up the prerequisites to run your first automation test, to following best practices and diving deeper into advanced test scenarios. LambdaTest Learning Hubs compile a list of step-by-step guides to help you be proficient with different test automation frameworks i.e. Selenium, Cypress, TestNG etc.
You could also refer to the video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.
Get 100 minutes of automation testing FREE!
