
"""
===================
plot
===================

Plotting related functions for ``dms_tools2``.

Uses `plotnine <https://plotnine.readthedocs.io/en/stable>`_
and `seaborn <https://seaborn.pydata.org/index.html>`_.
"""


import re
import os
import math
import numbers
import random
import collections

import natsort
import pandas
import numpy
import scipy.stats
import scipy.optimize
from statsmodels.sandbox.stats.multicomp import multipletests

# complicated backend setting: we typically want PDF,
# but this causes problems when loading from iPython / Jupyter
import matplotlib
backend = matplotlib.get_backend()
try:
    matplotlib.use('pdf')
    import matplotlib.pyplot as plt
except Exception:
    matplotlib.use(backend, warn=False, force=True)
    import matplotlib.pyplot as plt

from plotnine import *
# set ggplot theme
theme_set(theme_bw(base_size=12)) 

import seaborn
seaborn.set(context='talk',
            style='white',
            rc={
                'xtick.labelsize':15,
                'ytick.labelsize':15,
                'axes.labelsize':19,
                'legend.fontsize':17,
                'font.family':'sans-serif',
                'font.sans-serif':['DejaVu Sans'],
                }
           )

from dms_tools2 import CODONS, AAS, AAS_WITHSTOP, NTS
import dms_tools2.utils

#: `color-blind safe palette <http://bconnelly.net/2013/10/creating-colorblind-friendly-figures/>`_.
#: Use by adding `scale_fill_manual(COLOR_BLIND_PALETTE)` or
#: `scale_color_manual(COLOR_BLIND_PALETTE)` to your plots.
COLOR_BLIND_PALETTE = ["#000000", "#E69F00", "#56B4E9", "#009E73",
                       "#F0E442", "#0072B2", "#D55E00", "#CC79A7"]

#: `color-blind safe palette <http://bconnelly.net/2013/10/creating-colorblind-friendly-figures/>`_
#: that differs from `COLOR_BLIND_PALETTE` in that the first
#: color is gray rather than black.
COLOR_BLIND_PALETTE_GRAY = ["#999999", "#E69F00", "#56B4E9", "#009E73",
                            "#F0E442", "#0072B2", "#D55E00", "#CC79A7"]


def breaksAndLabels(xi, x, n):
    """Get breaks and labels for an axis.

    Useful when you would like to re-label a numeric x-axis
    with string labels.

    Uses `matplotlib.ticker.MaxNLocator` to choose pretty breaks.

    Args:
        `xi` (list or array)
            Integer values actually assigned to axis points.
        `x` (list)
            Strings corresponding to each numeric value in `xi`.
        `n` (int)
            Approximate number of ticks to use.

    Returns:
        The tuple `(breaks, labels)` where `breaks` gives the
        locations of breaks taken from `xi`, and `labels` is
        the label for each break.

    >>> xi = list(range(213))
    >>> x = [str(i + 1) for i in xi]
    >>> (breaks, labels) = breaksAndLabels(xi, x, 5)
    >>> breaks
    [0, 50, 100, 150, 200]
    >>> labels
    ['1', '51', '101', '151', '201']
    """
    assert len(xi) == len(x)
    assert all([isinstance(i, (int, numpy.integer)) for i in xi]), \
            "xi not integer values:\n{0}".format(xi)
    xi = list(xi)
    assert sorted(set(xi)) == xi, "xi not unique and ordered"
    breaks = matplotlib.ticker.MaxNLocator(n).tick_values(xi[0], xi[-1])
    breaks = [int(i) for i in breaks if xi[0] <= i <= xi[-1]]
    labels = [x[xi.index(i)] for i in breaks]
    return (breaks, labels)
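
# A minimal sketch of feeding the output of `breaksAndLabels` into an
# existing plotnine plot `p` (the `sites` list of string labels is
# hypothetical, for illustration only):
#
#     xi = list(range(len(sites)))
#     (breaks, labels) = breaksAndLabels(xi, [str(s) for s in sites], n=5)
#     p = p + scale_x_continuous(breaks=breaks, labels=labels)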
def latexSciNot(xlist):
    """Converts list of numbers to LaTeX scientific notation.

    Useful for nice axis-tick formatting.

    Args:
        `xlist` (list or number)
            Numbers to format.

    Returns:
        List of LaTeX scientific notation formatted strings, or
        single string if `xlist` is a number.

    >>> latexSciNot([0, 3, 3120, -0.0000927])
    ['$0$', '$3$', '$3.1 \\\\times 10^{3}$', '$-9.3 \\\\times 10^{-5}$']

    >>> latexSciNot([0.001, 1, 1000, 1e6])
    ['$0.001$', '$1$', '$10^{3}$', '$10^{6}$']

    >>> latexSciNot([-0.002, 0.003, 0.000011])
    ['$-0.002$', '$0.003$', '$1.1 \\\\times 10^{-5}$']

    >>> latexSciNot([-0.1, 0.0, 0.1, 0.2])
    ['$-0.1$', '$0$', '$0.1$', '$0.2$']

    >>> latexSciNot([0, 1, 2])
    ['$0$', '$1$', '$2$']
    """
    if isinstance(xlist, numbers.Number):
        isnum = True
        xlist = [xlist]
    else:
        isnum = False
    formatlist = []
    for x in xlist:
        xf = "{0:.2g}".format(x)
        if xf[ : 2] == '1e':
            xf = "$10^{{{0}}}$".format(int(xf[2 : ]))
        elif xf[ : 3] == '-1e':
            xf = "$-10^{{{0}}}$".format(int(xf[3 : ]))
        elif 'e' in xf:
            (d, exp) = xf.split('e')
            xf = '${0} \\times 10^{{{1}}}$'.format(d, int(exp))
        else:
            xf = '${0}$'.format(xf)
        formatlist.append(xf)
    if isnum:
        assert len(formatlist) == 1
        formatlist = formatlist[0]
    return formatlist
def plotReadStats(names, readstatfiles, plotfile):
    """Plots ``dms2_bcsubamp`` read statistics for a set of samples.

    Args:
        `names` (list or series)
            Names of the samples for which we are plotting statistics.
        `readstatfiles` (list or series)
            Names of ``*_readstats.csv`` files created by
            ``dms2_bcsubamp``.
        `plotfile` (str)
            Name of PDF plot file to create.
    """
    assert len(names) == len(readstatfiles) == len(set(names))
    assert os.path.splitext(plotfile)[1].lower() == '.pdf'
    readstats = pandas.concat([pandas.read_csv(f).assign(name=name) for
            (name, f) in zip(names, readstatfiles)], ignore_index=True)
    readstats['retained'] = (readstats['total'] -
                             readstats['fail filter'] -
                             readstats['low Q barcode'])
    readstats_melt = readstats.melt(id_vars='name',
            value_vars=['retained', 'fail filter', 'low Q barcode'],
            value_name='number of reads', var_name='read fate')
    p = (ggplot(readstats_melt)
         + geom_col(aes(x='name', y='number of reads', fill='read fate'),
                    position='stack')
         + theme(axis_text_x=element_text(angle=90, vjust=1, hjust=0.5),
                 axis_title_x=element_blank())
         + scale_y_continuous(labels=latexSciNot)
         + scale_fill_manual(COLOR_BLIND_PALETTE)
         )
    p.save(plotfile, height=2.7, width=(1.2 + 0.25 * len(names)),
           verbose=False, limitsize=False)
    plt.close()
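
# A minimal usage sketch for `plotReadStats` (file names here are
# hypothetical; assumes ``dms2_bcsubamp`` has been run for each sample):
#
#     names = ['replicate-1', 'replicate-2']
#     readstatfiles = ['replicate-1_readstats.csv',
#                      'replicate-2_readstats.csv']
#     plotReadStats(names, readstatfiles, 'readstats.pdf')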
def plotBCStats(names, bcstatsfiles, plotfile):
    """Plots ``dms2_bcsubamp`` barcode statistics for a set of samples.

    Args:
        `names` (list or series)
            Names of the samples for which we are plotting statistics.
        `bcstatsfiles` (list or series)
            Names of ``*_bcstats.csv`` files created by
            ``dms2_bcsubamp``.
        `plotfile` (str)
            Name of PDF plot file to create.
    """
    assert len(names) == len(bcstatsfiles) == len(set(names))
    assert os.path.splitext(plotfile)[1].lower() == '.pdf'
    bcstats = pandas.concat([pandas.read_csv(f).assign(name=name) for
            (name, f) in zip(names, bcstatsfiles)], ignore_index=True)
    bcstats_melt = bcstats.melt(id_vars='name',
            value_vars=['too few reads', 'not alignable', 'aligned'],
            value_name='number of barcodes', var_name='barcode fate')
    p = (ggplot(bcstats_melt)
         + geom_col(aes(x='name', y='number of barcodes',
                        fill='barcode fate'),
                    position=position_stack(reverse=True))
         + theme(axis_text_x=element_text(angle=90, vjust=1, hjust=0.5),
                 axis_title_x=element_blank())
         + scale_y_continuous(labels=latexSciNot)
         + scale_fill_manual(COLOR_BLIND_PALETTE)
         )
    p.save(plotfile, height=2.7, width=(1.2 + 0.25 * len(names)),
           verbose=False, limitsize=False)
    plt.close()
def plotReadsPerBC(names, readsperbcfiles, plotfile,
                   maxreads=10, maxcol=6):
    """Plots ``dms2_bcsubamp`` reads-per-barcode stats for a set of samples.

    Args:
        `names` (list or series)
            Names of samples for which we plot statistics.
        `readsperbcfiles` (list or series)
            Names of ``*_readsperbc.csv`` files created by
            ``dms2_bcsubamp``.
        `plotfile` (str)
            Name of PDF plot file to create.
        `maxreads` (int)
            For any barcodes with > this many reads, just make a
            category of >= this.
        `maxcol` (int)
            Number of columns in faceted plot.
    """
    assert len(names) == len(readsperbcfiles) == len(set(names))
    assert os.path.splitext(plotfile)[1].lower() == '.pdf'

    # read data frames, ensure 'number of reads' from 1 to >= maxreads
    dfs = []
    for (name, f) in zip(names, readsperbcfiles):
        df = pandas.read_csv(f)
        # make 'number of reads' of maxreads hold all barcodes
        # with >= maxreads reads
        n_ge = (df[df['number of reads'] >= maxreads]
                ['number of barcodes'].sum())
        df = df.append(pandas.DataFrame({'number of reads':[maxreads],
                                         'number of barcodes':[n_ge]}))
        # add zero-count rows for any missing read numbers
        for nreads in range(1, maxreads):
            if nreads not in df['number of reads'].values:
                df = df.append(pandas.DataFrame(
                        {'number of reads':[nreads],
                         'number of barcodes':[0]}))
        df = df[df['number of reads'] <= maxreads]
        df = df.assign(name=name)
        dfs.append(df)
    df = pandas.concat(dfs, ignore_index=True)

    # make name a category to preserve order
    df['name'] = df['name'].astype(
            pandas.api.types.CategoricalDtype(categories=names))

    ncol = min(maxcol, len(names))
    nrow = math.ceil(len(names) / float(ncol))
    p = (ggplot(df)
         + geom_col(aes(x='number of reads', y='number of barcodes'),
                    position='stack')
         + scale_x_continuous(breaks=[1, maxreads // 2, maxreads],
                labels=['$1$', '${0}$'.format(maxreads // 2),
                        r'$\geq {0}$'.format(maxreads)])
         + scale_y_continuous(labels=latexSciNot)
         + facet_wrap('~name', ncol=ncol)
         + theme(figure_size=(1.9 * (0.8 + ncol), 1.3 * (0.4 + nrow)))
         )
    p.save(plotfile, verbose=False, limitsize=False)
    plt.close()
def plotDepth(names, countsfiles, plotfile, maxcol=4, charlist=CODONS):
    """Plot sequencing depth along primary sequence.

    Args:
        `names` (list or series)
            Names of samples for which we plot statistics.
        `countsfiles` (list or series)
            Files containing character counts at each site. Should
            have a column named `site` and a column for each
            character in `charlist`.
        `plotfile` (str)
            Name of created PDF plot file containing count depth.
        `maxcol` (int)
            Number of columns in faceted plot.
        `charlist` (list)
            Characters contained in `countsfiles`. For instance,
            list of codons or amino acids.
    """
    assert len(names) == len(countsfiles) == len(set(names))
    assert os.path.splitext(plotfile)[1].lower() == '.pdf'
    counts = pandas.concat(
            [pandas.read_csv(f)
                   .assign(name=name)
                   .assign(ncounts=lambda x: x[charlist].sum(axis=1))
                   .rename(columns={'ncounts':'number of counts'})
             for (name, f) in zip(names, countsfiles)],
            ignore_index=True)

    ncol = min(maxcol, len(names))
    nrow = math.ceil(len(names) / float(ncol))

    # make name a category to preserve order
    counts['name'] = counts['name'].astype(
            pandas.api.types.CategoricalDtype(categories=names))

    p = (ggplot(counts, aes(x='site', y='number of counts'))
         + geom_step(size=0.4)
         + scale_y_continuous(labels=latexSciNot,
                limits=(0, counts['number of counts'].max()))
         + scale_x_continuous(limits=(counts['site'].min(),
                                      counts['site'].max()))
         + facet_wrap('~name', ncol=ncol)
         + theme(figure_size=(2.25 * (0.6 + ncol), 1.3 * (0.3 + nrow)))
         )
    p.save(plotfile, verbose=False, limitsize=False)
    plt.close()
def plotMutFreq(names, countsfiles, plotfile, maxcol=4):
    """Plot mutation frequency along primary sequence.

    Args:
        `names` (list or series)
            Names of samples for which we plot statistics.
        `countsfiles` (list or series)
            ``*_codoncounts.csv`` files of type created by
            ``dms2_bcsubamp``.
        `plotfile` (str)
            Name of created PDF plot file.
        `maxcol` (int)
            Number of columns in faceted plot.
    """
    assert len(names) == len(countsfiles) == len(set(names))
    assert os.path.splitext(plotfile)[1].lower() == '.pdf'
    counts = pandas.concat([dms_tools2.utils.annotateCodonCounts(f)
            .assign(name=name) for (name, f) in
            zip(names, countsfiles)], ignore_index=True).rename(
            columns={'mutfreq':'mutation frequency'})

    ncol = min(maxcol, len(names))
    nrow = math.ceil(len(names) / float(ncol))

    # make name a category to preserve order
    counts['name'] = counts['name'].astype(
            pandas.api.types.CategoricalDtype(categories=names))

    p = (ggplot(counts, aes(x='site', y='mutation frequency'))
         + geom_step(size=0.4)
         + scale_y_continuous(labels=latexSciNot,
                limits=(0, counts['mutation frequency'].max()))
         + scale_x_continuous(limits=(counts['site'].min(),
                                      counts['site'].max()))
         + facet_wrap('~name', ncol=ncol)
         + theme(figure_size=(2.25 * (0.6 + ncol), 1.3 * (0.3 + nrow)))
         )
    p.save(plotfile, verbose=False, limitsize=False)
    plt.close()
def plotCumulMutCounts(names, countsfiles, plotfile, chartype,
                       nmax=15, maxcol=4):
    """Plot fraction of mutations seen <= some number of times.

    For each set of counts in `countsfiles`, plot the fraction of
    mutations seen less than or equal to some number of times.
    This is essentially a cumulative fraction plot.

    Args:
        `names` (list or series)
            Names of samples for which we plot statistics.
        `countsfiles` (list or series)
            ``*_codoncounts.csv`` files of type created by
            ``dms2_bcsubamp``.
        `plotfile` (str)
            Name of created PDF plot file.
        `chartype` (str)
            The type of character in `countsfiles`.

                - `codon`

        `nmax` (int)
            Plot out to this number of mutation occurrences.
        `maxcol` (int)
            Number of columns in faceted plot.
    """
    assert len(names) == len(countsfiles) == len(set(names))
    assert os.path.splitext(plotfile)[1].lower() == '.pdf'

    if chartype != 'codon':
        raise ValueError("invalid chartype of {0}".format(chartype))

    codoncounts = pandas.concat([pandas.read_csv(f).assign(name=name)
            for (name, f) in zip(names, countsfiles)],
            ignore_index=True).assign(character='codons')
    assert set(CODONS) <= set(codoncounts.columns)
    codonmelt = codoncounts.melt(
            id_vars=['name', 'wildtype', 'character'],
            value_vars=CODONS, value_name='counts', var_name='codon')
    codonmelt = codonmelt[codonmelt['codon'] != codonmelt['wildtype']]

    aacounts = pandas.concat([dms_tools2.utils.codonToAACounts(
            pandas.read_csv(f)).assign(name=name)
            for (name, f) in zip(names, countsfiles)],
            ignore_index=True).assign(character='amino acids')
    assert set(AAS_WITHSTOP) <= set(aacounts.columns)
    aamelt = aacounts.melt(id_vars=['name', 'character', 'wildtype'],
            value_vars=AAS_WITHSTOP, value_name='counts', var_name='aa')
    aamelt = aamelt[aamelt['aa'] != aamelt['wildtype']]

    df = pandas.concat([codonmelt, aamelt], ignore_index=True, sort=True)

    # make name a category to preserve order
    df['name'] = df['name'].astype(
            pandas.api.types.CategoricalDtype(categories=names))

    ncol = min(maxcol, len(names))
    nrow = math.ceil(len(names) / float(ncol))
    p = (ggplot(df, aes('counts', color='character',
                        linestyle='character'))
         + stat_ecdf(geom='step', size=1)
         + coord_cartesian(xlim=(0, nmax))
         + facet_wrap('~name', ncol=ncol)
         + theme(figure_size=(2.25 * (0.6 + ncol), 1.3 * (0.5 + nrow)),
                 legend_position='top', legend_direction='horizontal')
         + labs(color="")
         + guides(color=guide_legend(title_position='left'))
         + ylab(r'fraction $\leq$ this many counts')
         + scale_color_manual(COLOR_BLIND_PALETTE)
         )
    p.save(plotfile, verbose=False, limitsize=False)
    plt.close()
def plotCodonMutTypes(names, countsfiles, plotfile,
                      classification='aachange', csvfile=None):
    """Plot average frequency of codon mutation types.

    The averages are determined by summing counts for all sites.

    Args:
        `names` (list or series)
            Names of samples for which we plot statistics.
        `countsfiles` (list or series)
            ``*_codoncounts.csv`` files of type created by
            ``dms2_bcsubamp``.
        `plotfile` (str)
            Name of created PDF plot file.
        `classification` (str)
            The method used to classify the mutation types. Can be:

                - `aachange` : stop, synonymous, nonsynonymous

                - `n_ntchanges` : number of nucleotide changes per codon

                - `singlentchanges` : nucleotide change in 1-nt mutations

        `csvfile` (str or `None`)
            `None` or name of CSV file to which numerical data
            are written.
    """
    assert len(names) == len(countsfiles) == len(set(names))
    assert os.path.splitext(plotfile)[1].lower() == '.pdf'
    counts = pandas.concat([dms_tools2.utils.annotateCodonCounts(f)
            .assign(name=name) for (name, f) in
            zip(names, countsfiles)], ignore_index=True)

    # make name a category to preserve order
    counts['name'] = counts['name'].astype(
            pandas.api.types.CategoricalDtype(categories=names))

    if classification == 'aachange':
        muttypes = {'stop':'nstop', 'synonymous':'nsyn',
                    'nonsynonymous':'nnonsyn'}
    elif classification == 'n_ntchanges':
        muttypes = dict([('{0} nucleotide'.format(n), 'n{0}nt'.format(n))
                         for n in [1, 2, 3]])
    elif classification == 'singlentchanges':
        muttypes = dict([(ntchange, ntchange) for ntchange in [
                '{0}to{1}'.format(nt1, nt2) for nt1 in NTS
                for nt2 in NTS if nt1 != nt2]])
    else:
        raise ValueError("Invalid classification {0}".format(
                classification))

    # sum counts over all sites for each sample
    df = (counts[list(muttypes.values()) + ['ncounts', 'name']]
          .groupby('name', as_index=False)
          .sum()
          .assign(ncounts=lambda x: x['ncounts'].astype('float'))
          )
    for (newcol, n) in muttypes.items():
        df[newcol] = (df[n] / df['ncounts']).fillna(0)

    if csvfile:
        df[['name'] + list(muttypes.keys())].to_csv(csvfile, index=False)

    df = df.melt(id_vars='name', var_name='mutation type',
                 value_vars=list(muttypes.keys()),
                 value_name='per-codon frequency')

    p = (ggplot(df)
         + geom_col(aes(x='name', y='per-codon frequency',
                        fill='mutation type'), position='stack')
         + theme(axis_text_x=element_text(angle=90, vjust=1, hjust=0.5),
                 axis_title_x=element_blank())
         + scale_y_continuous(labels=latexSciNot)
         )
    if len(muttypes) <= len(COLOR_BLIND_PALETTE):
        p = p + scale_fill_manual(COLOR_BLIND_PALETTE)
    else:
        p = p + guides(fill=guide_legend(ncol=2))

    p.save(plotfile, height=2.7, width=(1.2 + 0.25 * len(names)),
           verbose=False, limitsize=False)
    plt.close()
def plotCorrMatrix(names, infiles, plotfile, datatype,
                   trim_unshared=True, title='', colors='black',
                   contour=False, ncontours=10):
    """Plots correlations among replicates.

    Args:
        `names` (list or series)
            Names of samples for which we plot statistics.
        `infiles` (list or series)
            CSV files containing data. Format depends on `datatype`.
        `plotfile` (str)
            Name of created PDF plot file.
        `datatype` (str)
            Type of data for which we are plotting correlations:

                - `prefs`: in format returned by ``dms2_prefs``
                - `mutdiffsel`: mutdiffsel from ``dms2_diffsel``
                - `abs_diffsel`: sitediffsel from ``dms2_diffsel``
                - `positive_diffsel`: sitediffsel from ``dms2_diffsel``
                - `max_diffsel`: sitediffsel from ``dms2_diffsel``
                - `mutfracsurvive`: from ``dms2_fracsurvive``

        `trim_unshared` (bool)
            What if files in `infiles` don't have the same sites /
            mutations? If `True`, trim the unshared ones and just
            analyze those shared among all files. If `False`,
            raise an error.
        `title` (str)
            Title to place above plot.
        `colors` (str or list)
            Color(s) to color scatter points. If a string, should
            specify one color for all plots. Otherwise should be a
            list of length `len(names) * (len(names) - 1) // 2`
            giving the color for each plot from top to bottom,
            left to right.
        `contour` (bool)
            Show contour lines from KDE rather than points.
        `ncontours` (int)
            Number of contour lines if using `contour`.
    """
    assert len(names) == len(infiles) == len(set(names))
    assert os.path.splitext(plotfile)[1].lower() == '.pdf'

    if datatype == 'prefs':
        # read prefs into dataframe, ensuring all have same characters
        prefs = [pandas.read_csv(f).assign(name=name) for (name, f)
                 in zip(names, infiles)]
        chars = set(prefs[0].columns)
        sites = set(prefs[0]['site'].values)
        for p in prefs:
            unsharedchars = chars.symmetric_difference(set(p.columns))
            if unsharedchars:
                raise ValueError("infiles don't have same characters: {0}"
                                 .format(unsharedchars))
            unsharedsites = sites.symmetric_difference(set(p['site']))
            if trim_unshared:
                sites -= unsharedsites
            elif unsharedsites:
                raise ValueError("infiles don't have same sites: {0}"
                                 .format(unsharedsites))
        assert {'site', 'name'} < chars
        chars = list(chars - {'site', 'name'})
        # get measurements for each replicate in its own column
        df = (pandas.concat(prefs, ignore_index=True)
              .query('site in @sites')  # only keep shared sites
              .melt(id_vars=['name', 'site'], var_name='char')
              .pivot_table(index=['site', 'char'], columns='name')
              )
        df.columns = df.columns.get_level_values(1)

    elif datatype in ['mutdiffsel', 'mutfracsurvive']:
        mut_df = [pandas.read_csv(f)
                        .assign(name=name)
                        .assign(mutname=lambda x: x.wildtype +
                                x.site.map(str) + x.mutation)
                        .sort_values('mutname')
                        [['name', 'mutname', datatype]]
                  for (name, f) in zip(names, infiles)]
        muts = set(mut_df[0]['mutname'].values)
        for m in mut_df:
            unsharedmuts = muts.symmetric_difference(set(m['mutname']))
            if trim_unshared:
                muts -= unsharedmuts
            elif unsharedmuts:
                raise ValueError("infiles don't have same muts: {0}"
                                 .format(unsharedmuts))
        df = (pandas.concat(mut_df, ignore_index=True)
              .query('mutname in @muts')  # only keep shared muts
              .pivot_table(index='mutname', columns='name')
              .dropna()
              )
        df.columns = df.columns.get_level_values(1)

    elif datatype in ['abs_diffsel', 'positive_diffsel', 'max_diffsel',
                      'avgfracsurvive', 'maxfracsurvive']:
        site_df = [pandas.read_csv(f)
                         .assign(name=name)
                         .sort_values('site')
                         [['name', 'site', datatype]]
                   for (name, f) in zip(names, infiles)]
        sites = set(site_df[0]['site'].values)
        for s in site_df:
            unsharedsites = sites.symmetric_difference(set(s['site']))
            if trim_unshared:
                sites -= unsharedsites
            elif unsharedsites:
                raise ValueError("infiles don't have same sites: {0}"
                                 .format(unsharedsites))
        df = (pandas.concat(site_df, ignore_index=True)
              .query('site in @sites')  # only keep shared sites
              .pivot_table(index='site', columns='name')
              .dropna()
              )
        df.columns = df.columns.get_level_values(1)

    else:
        raise ValueError("Invalid datatype {0}".format(datatype))

    ncolors = len(names) * (len(names) - 1) // 2
    if isinstance(colors, str):
        colors = [colors] * ncolors
    else:
        assert len(colors) == ncolors, "not {0} colors".format(ncolors)

    def corrfunc(x, y, contour, **kws):
        r, _ = scipy.stats.pearsonr(x, y)
        ax = plt.gca()
        ax.annotate('R = {0:.2f}'.format(r), xy=(0.05, 0.9),
                    xycoords=ax.transAxes, fontsize=19,
                    fontstyle='oblique')
        color = colors.pop(0)
        if contour:
            seaborn.kdeplot(x, y, shade=True, n_levels=ncontours)
        else:
            plt.scatter(x, y, s=22, alpha=0.35, color=color,
                        marker='o', edgecolor='none', rasterized=True)

    # map lower / upper / diagonal as here:
    # https://stackoverflow.com/a/30942817 for plot
    p = seaborn.PairGrid(df)
    p.map_lower(corrfunc, colors=colors, contour=contour)
    if datatype == 'prefs':
        p.set(xlim=(0, 1), ylim=(0, 1),
              xticks=[0, 0.5, 1], yticks=[0, 0.5, 1],
              xticklabels=['0', '0.5', '1'],
              yticklabels=['0', '0.5', '1'],
              )
        # hide upper, diag: https://stackoverflow.com/a/34091733
        for (i, j) in zip(*numpy.triu_indices_from(p.axes, 1)):
            p.axes[i, j].set_visible(False)
        for (i, j) in zip(*numpy.diag_indices_from(p.axes)):
            p.axes[i, j].set_visible(False)
    elif datatype in ['mutdiffsel', 'abs_diffsel', 'positive_diffsel',
                      'max_diffsel', 'mutfracsurvive', 'avgfracsurvive',
                      'maxfracsurvive']:
        p.map_diag(seaborn.distplot, color='black', kde=True, hist=False)
        (lowlim, highlim) = (df.values.min(), df.values.max())
        highlim += (highlim - lowlim) * 0.05
        lowlim -= (highlim - lowlim) * 0.05
        p.set(xlim=(lowlim, highlim), ylim=(lowlim, highlim))
        # hide upper
        for (i, j) in zip(*numpy.triu_indices_from(p.axes, 1)):
            p.axes[i, j].set_visible(False)
    else:
        raise ValueError("invalid datatype")

    p.map_upper(seaborn.kdeplot, n_levels=20, cmap='Blues_d')

    if title:
        # following here: https://stackoverflow.com/a/29814281
        p.fig.suptitle(title)

    p.savefig(plotfile)
    plt.close()
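
# A minimal usage sketch for `plotCorrMatrix` on preferences from
# three replicates (file names here are hypothetical):
#
#     plotCorrMatrix(['rep-1', 'rep-2', 'rep-3'],
#                    ['rep-1_prefs.csv', 'rep-2_prefs.csv',
#                     'rep-3_prefs.csv'],
#                    'prefs_corr.pdf', datatype='prefs')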
def plotSiteDiffSel(names, diffselfiles, plotfile,
                    diffseltype, maxcol=2, white_bg=False,
                    highlighted_sites=[]):
    """Plot site diffsel or fracsurvive along sequence.

    Despite the function name, this function can be used to
    plot either differential selection or fraction surviving.

    Args:
        `names` (list or series)
            Names of samples for which we plot statistics.
        `diffselfiles` (list or series)
            ``*sitediffsel.csv`` files from ``dms2_diffsel`` or
            ``*sitefracsurvive.csv`` files from ``dms2_fracsurvive``.
        `plotfile` (str)
            Name of created PDF plot file.
        `diffseltype` (str)
            Type of diffsel or fracsurvive to plot:

                - `positive`: positive sitediffsel
                - `total`: positive and negative sitediffsel
                - `max`: maximum mutdiffsel
                - `minmax`: minimum and maximum mutdiffsel
                - `avgfracsurvive`: total site fracsurvive
                - `maxfracsurvive`: max mutfracsurvive at site

        `maxcol` (int)
            Number of columns in faceted plot.
        `white_bg` (bool)
            Plots will have a white background with limited
            other formatting.
        `highlighted_sites` (list)
            Highlight sites of interest (passed in string format)
            in grey.
    """
    assert len(names) == len(diffselfiles) == len(set(names)) > 0
    assert os.path.splitext(plotfile)[1].lower() == '.pdf'

    diffsels = [pandas.read_csv(f).assign(name=name) for
                (name, f) in zip(names, diffselfiles)]
    assert all([set(diffsels[0]['site']) == set(df['site']) for df
                in diffsels]), "diffselfiles not all for same sites"
    diffsel = pandas.concat(diffsels, ignore_index=True)

    ylabel = 'differential selection'
    if diffseltype == 'positive':
        rename = {'positive_diffsel':'above'}
    elif diffseltype == 'total':
        rename = {'positive_diffsel':'above',
                  'negative_diffsel':'below'}
    elif diffseltype == 'max':
        rename = {'max_diffsel':'above'}
    elif diffseltype == 'minmax':
        rename = {'max_diffsel':'above',
                  'min_diffsel':'below'}
    elif diffseltype in ['avgfracsurvive', 'maxfracsurvive']:
        ylabel = 'fraction surviving'
        rename = {diffseltype:'above'}
    else:
        raise ValueError("invalid diffseltype {0}".format(diffseltype))
    diffsel = (diffsel.rename(columns=rename)
               .melt(id_vars=['site', 'name'],
                     value_vars=list(rename.values()),
                     value_name='diffsel', var_name='direction')
               )

    # get max value, used to plot the highlight overlay
    y_lim = diffsel['diffsel'].max()

    # natural sort by site: https://stackoverflow.com/a/29582718
    diffsel = diffsel.reindex(index=natsort.order_by_index(
            diffsel.index, natsort.index_realsorted(diffsel.site)))
    # now some manipulations to make site str while siteindex is int
    diffsel['site'] = diffsel['site'].apply(str)
    diffsel['siteindex'] = pandas.Categorical(diffsel['site'],
            diffsel['site'].unique()).codes

    ncol = min(maxcol, len(names))
    nrow = math.ceil(len(names) / float(ncol))

    # make name a category to preserve order
    diffsel['name'] = diffsel['name'].astype(
            pandas.api.types.CategoricalDtype(categories=names))

    (xbreaks, xlabels) = breaksAndLabels(diffsel['siteindex'].unique(),
            diffsel['site'].unique(), n=6)

    if highlighted_sites is None:
        highlighted_sites = []
    diffsel['highlight'] = diffsel['site'].isin(highlighted_sites)
    diffsel['highlight'] = numpy.where(diffsel['highlight'], y_lim, 0)

    if white_bg:
        p = (ggplot(diffsel, aes(x='siteindex', y='diffsel',
                                 color='direction', fill='direction'))
             + geom_bar(aes(y='highlight'), alpha=0.5, stat="identity",
                        color="#d9d9d9", size=0.3, show_legend=False)
             + geom_step(size=0.3)
             + xlab('site')
             + ylab(ylabel)
             + scale_x_continuous(breaks=xbreaks, labels=xlabels)
             + scale_color_manual(COLOR_BLIND_PALETTE)
             + scale_fill_manual(COLOR_BLIND_PALETTE)
             + guides(color=False)
             + theme(panel_background=element_rect(fill='white'),
                     axis_line_x=element_line(color='black'),
                     axis_line_y=element_line(color='black'),
                     panel_grid=element_blank(),
                     panel_border=element_blank(),
                     strip_background=element_blank()
                     )
             )
    else:
        p = (ggplot(diffsel, aes(x='siteindex', y='diffsel',
                                 color='direction'))
             + geom_bar(aes(y='highlight'), alpha=0.5, stat="identity",
                        color="#d9d9d9", size=0.3, show_legend=False)
             + geom_step(size=0.4)
             + xlab('site')
             + ylab(ylabel)
             + scale_x_continuous(breaks=xbreaks, labels=xlabels)
             + scale_color_manual(COLOR_BLIND_PALETTE)
             + guides(color=False)
             )

    if not ((len(names) == 1) and ((not names[0]) or names[0].isspace())):
        p = p + facet_wrap('~name', ncol=ncol)
    p = p + theme(figure_size=(4.6 * (0.3 + ncol), 1.9 * (0.2 + nrow)))
    p.save(plotfile, verbose=False, limitsize=False)
    plt.close()
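
# A minimal usage sketch for `plotSiteDiffSel` showing positive and
# negative differential selection (file names and highlighted sites
# are hypothetical; sites are passed as strings):
#
#     plotSiteDiffSel(['rep-1', 'rep-2'],
#                     ['rep-1_sitediffsel.csv', 'rep-2_sitediffsel.csv'],
#                     'sitediffsel.pdf', diffseltype='total',
#                     highlighted_sites=['156', '158'])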
def plotFacetedNeutCurves(neutdata, plotfile, xlabel, ylabel, maxcol=3):
    """Faceted neutralization curves with points and fit line.

    Args:
        `neutdata` (pandas DataFrame)
            Should have the following columns: `concentration`,
            `sample`, `fit`, `points`. The plot is faceted on
            `sample`. The line smoothly connects all points in
            column `fit`, and points are drawn anywhere that
            `points` is not `NaN`.
        `plotfile` (str)
            Name of created plot.
        `xlabel` (str)
            x-axis label
        `ylabel` (str)
            y-axis label
        `maxcol` (int)
            Number of columns in facets.
    """
    cols = {'concentration', 'sample', 'fit', 'points'}
    assert set(neutdata.columns) >= cols, ("missing cols:\n"
            "required: {0}\nactual: {1}".format(cols, neutdata.columns))

    # make sample a category to preserve order
    neutdata = neutdata.copy()
    samples = neutdata['sample'].unique()
    neutdata['sample'] = neutdata['sample'].astype(
            pandas.api.types.CategoricalDtype(categories=samples))

    ncol = min(maxcol, len(samples))
    nrow = math.ceil(len(samples) / float(ncol))

    ymin = min(neutdata['fit'].min(), neutdata['points'].min(), 0)
    ymax = max(neutdata['fit'].max(), neutdata['points'].max(), 0)

    p = (ggplot(neutdata)
         + geom_point(aes(x='concentration', y='points'))
         + geom_line(aes(x='concentration', y='fit'))
         + scale_x_log10(labels=latexSciNot)
         + scale_y_continuous(limits=(ymin, ymax))
         + xlab(xlabel)
         + ylab(ylabel)
         + facet_wrap('~sample', ncol=ncol)
         + theme(figure_size=(2.4 * (0.25 + ncol), 1.45 * (0.25 + nrow)))
         )
    p.save(plotfile, verbose=False)
    plt.close()
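
# A minimal sketch of building the tidy `neutdata` frame that
# `plotFacetedNeutCurves` expects. The two-parameter Hill curve and
# all numbers here are hypothetical, just to illustrate the columns:
#
#     concs = numpy.logspace(-2, 1, 8)           # measured concentrations
#     fracinfect = 1 / (1 + (concs / 0.5)**2)    # simulated measurements
#     neutdata = pandas.DataFrame({
#             'concentration':concs,
#             'sample':'serum-1',
#             'fit':1 / (1 + (concs / 0.5)**2),  # smooth fit line
#             'points':fracinfect,               # drawn as points
#             })
#     plotFacetedNeutCurves(neutdata, 'neut.pdf',
#                           xlabel='concentration',
#                           ylabel='fraction infectivity')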
def findSigSel(df, valcol, plotfile, fdr=0.05, title=None,
               method='robust_hist', mle_frac_censor=0.005,
               returnplot=False):
    """Finds "significant" selection at sites / mutations.

    Designed for the case where most sites / mutations are not
    under selection, but a few may be. It tries to find those
    few that are under selection.

    It does not use a mechanistic statistical model, but rather
    fits a gamma distribution. The rationale for a gamma
    distribution is that it is the negative binomial's continuous
    `analog <http://www.nehalemlabs.net/prototype/blog/2013/12/01/gamma-distribution-approximation-to-the-negative-binomial-distribution/>`_.
    It then identifies sites that clearly have **larger** values
    than expected under this distribution. It currently does not
    identify sites with **smaller** (or more negative) than
    expected values.

    Args:
        `df` (pandas DataFrame)
            Contains data to analyze
        `valcol` (string)
            Column in `df` with values (e.g., `fracsurvive`)
        `plotfile` (string)
            Name of file to which we plot fit.
        `fdr` (float)
            Find sites that are significant at this `fdr` given
            fitted distribution.
        `title` (string or `None`)
            Title for plot.
        `method` (str)
            Specifies how to fit gamma distribution, can have
            following values:

            - 'robust_hist': bin the data, then use `robust
              regression <http://scipy-cookbook.readthedocs.io/items/robust_regression.html>`_
              (soft L1 loss) to fit a gamma distribution to the
              histogram.

            - 'mle': fit the gamma distribution to the points by
              maximum likelihood; see also `mle_frac_censor`.

        `mle_frac_censor` (float)
            Only meaningful if `method` is 'mle'. In this case,
            before fitting, censor the data by setting the top
            `mle_frac_censor` largest values to the
            `mle_frac_censor` largest value. This shrinks very
            large outliers that affect fit.
        `returnplot` (bool)
            Return the matplotlib figure.

    Returns:
        Creates the plot in `plotfile`. Also returns the 3-tuple
        `(df_sigsel, cutoff, gamma_params)` where:

        - `df_sigsel` is a copy of `df` with new columns `P`,
          `Q`, and `sig`. These give the P value, Q value, and
          whether the site meets the `fdr` cutoff for significance.

        - `cutoff` is the maximum value that is **not** called
          significant. Because FDR is a property of a distribution,
          this value cannot be interpreted as meaning a new data
          point would be called based on this cutoff, as the cutoff
          would change. But `cutoff` is useful for plotting to see
          significant / non-significant.

        - `gamma_params` is a `numpy.ndarray` of length 3 that
          gives the shape, scale, and location parameter of the
          fit gamma distribution.

        If `returnplot` is `True`, then return
        `((df_sigsel, cutoff, gamma_params), fig)` where `fig`
        is the matplotlib figure.

    An example: First, simulate points from a gamma distribution:

    .. nbplot::

        >>> import pandas
        >>> import scipy
        >>> from dms_tools2.plot import findSigSel
        >>>
        >>> shape_sim = 1.5
        >>> scale_sim = 0.005
        >>> loc_sim = 0.0
        >>> gamma_sim = scipy.stats.gamma(shape_sim, scale=scale_sim,
        ...         loc=loc_sim)
        >>> nsites = 1000
        >>> scipy.random.seed(0)
        >>> df = pandas.DataFrame.from_dict({
        ...         'site':[r for r in range(nsites)],
        ...         'fracsurvive':gamma_sim.rvs(nsites)})

    Now make two sites have "significantly" higher values:

    .. nbplot::

        >>> sigsites = [100, 200]
        >>> df.loc[sigsites, 'fracsurvive'] = 0.08

    Now find the significant sites:

    .. nbplot::

        >>> plotfile = '_findSigSel.png'
        >>> (df_sigsel, cutoff, gamma_params) = findSigSel(
        ...         df, 'fracsurvive', plotfile, title='example')

    Make sure the fitted params are close to the ones used to
    simulate the data:

    .. nbplot::

        >>> numpy.allclose(shape_sim, gamma_params[0], rtol=0.1, atol=1e-3)
        True
        >>> numpy.allclose(scale_sim, gamma_params[1], rtol=0.1, atol=1e-3)
        True
        >>> numpy.allclose(loc_sim, gamma_params[2], rtol=0.1, atol=1e-3)
        True

    Check that we find the correct significant sites:

    .. nbplot::

        >>> set(sigsites) == set(df_sigsel.query('sig').site)
        True

    Make sure that sites above cutoff are significant:

    .. nbplot::

        >>> df_sigsel.query('sig').equals(df_sigsel.query('fracsurvive > @cutoff'))
        True

    Now repeat, getting and showing the plot:

    .. nbplot::

        >>> _, fig = findSigSel(
        ...         df, 'fracsurvive', plotfile, title='example',
        ...         returnplot=True)

    Now use the 'mle' `method`:

    .. nbplot::

        >>> (df_sigsel_mle, _, _), fig_mle = findSigSel(
        ...         df, 'fracsurvive', plotfile, title='example',
        ...         method='mle', returnplot=True)
        >>> set(df_sigsel_mle.query('sig').site) == set(sigsites)
        True
    """
    assert valcol in df.columns, "no `valcol` {0}".format(valcol)
    newcols = {'P', 'Q', 'sig'}
    assert not (newcols & set(df.columns)), \
            "`df` already has {0}".format(newcols)

    # We fit curves to histogram. First we need to get bins.
    try:
        # try with Freedman Diaconis Estimator
        binedges = numpy.histogram(df[valcol], bins='fd')[1]
    except ValueError:
        # fd will fail if there are lots of identical points
        binedges = numpy.histogram(df[valcol], bins='doane')[1]
    # get bin centers
    bins = (binedges[ : -1] + binedges[1 : ]) / 2

    # plot the histogram
    plt.figure(figsize=(5.5, 4))
    (heights, binedges, patches) = plt.hist(df[valcol], bins=binedges,
            density=True, histtype='stepfilled',
            color=COLOR_BLIND_PALETTE[2])

    # initial guess gives correct mean and variance for
    # gamma distribution with loc of 0
    scale = df[valcol].var() / df[valcol].mean()
    shape = df[valcol].mean() / scale
    x0 = numpy.array([shape, scale, 0.0])

    if method == 'robust_hist':

        def _f(x, bins, heights):
            """Gamma distribution least squares fitting function.

            Zero when distribution perfectly fits histogram.
            `x` is `(shape, scale, loc)`.
            """
            return (scipy.stats.gamma.pdf(bins, x[0], scale=x[1],
                    loc=x[2]) - heights)

        # fit using soft L1 loss for robust regression
        # http://scipy-cookbook.readthedocs.io/items/robust_regression.html
        fit = scipy.optimize.least_squares(_f, x0,
                args=(bins, heights), loss='soft_l1')
        gamma_params = fit.x
        gamma_fit = scipy.stats.gamma(fit.x[0], scale=fit.x[1],
                loc=fit.x[2])

    elif method == 'mle':
        # censor the top `mle_frac_censor` of values before fitting
        vals = numpy.sort(df[valcol].values)
        mle_n_censor = round(len(vals) * mle_frac_censor)
        maxval = vals[-1 - mle_n_censor]
        vals[vals > maxval] = maxval
        fit_shape, fit_loc, fit_scale = scipy.stats.gamma.fit(
                vals, shape, scale=scale, loc=0)
        gamma_params = numpy.array([fit_shape, fit_scale, fit_loc])
        gamma_fit = scipy.stats.gamma(fit_shape, scale=fit_scale,
                loc=fit_loc)

    else:
        raise ValueError(f"invalid `method` {method}")

    # add fit gamma distribution to plot
    nfitbins = 500
    if nfitbins > len(bins):
        fitbins = numpy.linspace(bins[0], bins[-1], nfitbins)
    else:
        fitbins = bins
    plt.plot(fitbins, gamma_fit.pdf(fitbins),
             color=COLOR_BLIND_PALETTE[1])

    # compute P and Q values
    df_sigsel = (df.assign(P=lambda x: gamma_fit.sf(x[valcol]))
                 .assign(Q=lambda x: multipletests(x.P, fdr,
                         'fdr_bh')[1])
                 .assign(sig=lambda x: multipletests(x.P, fdr,
                         'fdr_bh')[0])
                 )

    # compute cutoff
    cutoff = df_sigsel.query('not sig')[valcol].max()

    # plot cutoff: find first bin boundary greater than cutoff
    if (binedges > cutoff).any():
        bincutoff = binedges[binedges > cutoff][0]
    else:
        bincutoff = binedges[-1]
    # now annotate plot
    plt.axvline(bincutoff, color=COLOR_BLIND_PALETTE[3], ls='--',
                lw=0.75)
    text_y = 0.95 * plt.ylim()[1]
    (xmin, xmax) = plt.xlim()
    if (bincutoff - xmin) < 0.75 * (xmax - xmin):
        text_x = bincutoff + 0.01 * (xmax - xmin)
        ha = 'left'
    else:
        text_x = bincutoff - 0.01 * (xmax - xmin)
        ha = 'right'
    if len(df_sigsel.query('sig')):
        text = '{0} values\nsignificant\n($>${1})'.format(
                len(df_sigsel.query('sig')),
                latexSciNot([cutoff])[0])
    else:
        text = 'no values\nsignificant'
    plt.text(text_x, text_y, text, horizontalalignment=ha,
             verticalalignment='top', color=COLOR_BLIND_PALETTE[3],
             size='small')

    # put labels on plot
    plt.xlabel(valcol.replace('_', ' '))
    plt.ylabel('density')
    if title:
        plt.title(title.replace('_', ' '))

    # save plot
    plt.tight_layout()
    plt.savefig(plotfile)

    if returnplot:
        return ((df_sigsel, cutoff, gamma_params), plt.gcf())
    else:
        plt.close()
        return (df_sigsel, cutoff, gamma_params)
def plotColCorrs(df, plotfile, cols, *, lower_filter=None,
                 title=None, shrink_threshold=25):
    """Plots correlation among columns in pandas Data Frame.

    Plots distribution of each variable and pairwise correlations.

    Args:
        `df` (pandas DataFrame)
            Data frame with data to plot.
        `plotfile` (str or `None`)
            Name of created plot, or `None` if you want plot returned.
        `cols` (list)
            List of columns in `df` to plot.
        `lower_filter` (`None` or str)
            Can be any string that can be passed to the `query`
            function of `df`. In this case, on the lower diagonal
            only plot data for which this query is `True`.
        `title` (`None` or str)
            Title of plot.
        `shrink_threshold` (float)
            See argument of same name to :meth:`hist_bins_intsafe`.

    Returns:
        If `plotfile` is a string, makes the plot and does not
        return anything. If `plotfile` is `None`, returns the plot.
    """
    if not set(cols).issubset(set(df.columns)):
        raise ValueError("`cols` specifies columns not in `df`")

    if lower_filter is not None:
        filter_indices = df.query(lower_filter).index
        color_all = COLOR_BLIND_PALETTE_GRAY[0]
        color_filter = COLOR_BLIND_PALETTE_GRAY[1]
    else:
        color_all = COLOR_BLIND_PALETTE[0]

    def hist1d(x, color, **kwargs):
        """1D histogram for diagonal elements."""
        bins = dms_tools2.plot.hist_bins_intsafe(x,
                shrink_threshold=shrink_threshold)
        plt.hist(x, color=color_all, bins=bins, **kwargs)
        if lower_filter:
            plt.hist(x.loc[filter_indices], color=color_filter,
                     bins=bins, **kwargs)

    def hist2d(x, y, color, filterdata, **kwargs):
        """2D histogram for off-diagonal elements."""
        bins = [dms_tools2.plot.hist_bins_intsafe(a,
                shrink_threshold=shrink_threshold) for a in [x, y]]
        if filterdata:
            color = color_filter
            x = x.loc[filter_indices]
            y = y.loc[filter_indices]
        else:
            color = color_all
        cmap = dms_tools2.plot.from_white_cmap(color)
        plt.hist2d(x, y, bins=bins, cmap=cmap, **kwargs)

    g = (dms_tools2.plot.AugmentedPairGrid(df, vars=cols,
                                           diag_sharey=False, height=3)
         .map_diag(hist1d)
         .map_upper(hist2d, filterdata=False)
         .map_lower(hist2d, filterdata=(lower_filter is not None))
         .ax_lims_clip_outliers()
         )

    if lower_filter is not None:
        label_order = ['all\n({0})'.format(
                            dms_tools2.plot.latexSciNot(len(df))),
                       '{0}\n({1})'.format(lower_filter,
                            dms_tools2.plot.latexSciNot(
                                len(df.query(lower_filter)))),
                       ]
        label_data = {lab:plt.Line2D([0], [0], color=c, lw=10,
                                     solid_capstyle='butt')
                      for (lab, c) in zip(label_order,
                                          [color_all, color_filter])
                      }
        g.add_legend(label_data, label_order=label_order,
                     labelspacing=2, handlelength=1.5)

    if title is not None:
        g.fig.suptitle(title, va='bottom')

    if plotfile is None:
        return g
    else:
        g.savefig(plotfile)
        plt.close()
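
# A minimal usage sketch for `plotColCorrs` (the data frame `df` and
# its column names are hypothetical; `lower_filter` is any valid
# `DataFrame.query` string):
#
#     plotColCorrs(df, 'corrs.pdf', ['count_1', 'count_2'],
#                  lower_filter='count_1 >= 10')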
def plotRarefactionCurves(df, rarefy_col, plotfile, *,
                          facet_col=None, nrow=1, xlabel='reads',
                          ylabel=None, facet_scales='free'):
    """Plots rarefaction curves.

    The rarefaction curves are calculated analytically using
    :py:func:`dms_tools2.utils.rarefactionCurve`.

    Args:
        `df` (pandas DataFrame)
            Data frame containing data. In tidy form if faceting.
        `rarefy_col` (str)
            Name of column in `df` that contains the variable that
            we rarefy. For instance, these might be strings giving
            barcodes.
        `plotfile` (str)
            Name of created plot.
        `facet_col` (str or `None`)
            If not `None`, should be name of a column in `df` that
            contains a variable we facet in the plot.
        `nrow` (int)
            If faceting, the number of rows.
        `xlabel` (str)
            X-axis label.
        `ylabel` (str or `None`)
            Y-axis label. If `None`, defaults to value of `rarefy_col`.
        `facet_scales` (str)
            Scales for faceting. Can be "free", "free_x",
            "free_y", or "fixed".

    Here is an example. First, we simulate two sets of barcodes.
    For ease of fast simulation, the barcodes are just numbers
    here. One samples a large set; the other samples a set a
    quarter that size with half as many reads:

    >>> nbc = 40000
    >>> bclen = 10
    >>> nreads = 200000
    >>> barcodes = list(range(nbc))
    >>> numpy.random.seed(1)
    >>> large_set = numpy.random.choice(barcodes, size=nreads)
    >>> small_set = numpy.random.choice(barcodes[ : nbc // 4],
    ...         size=nreads // 2)

    Now we put these in a tidy data frame where one column is
    named "barcodes" and the other is named "sample":

    >>> df = pandas.DataFrame({
    ...         "barcodes":list(small_set) + list(large_set),
    ...         "sample":['small_set'] * len(small_set) +
    ...                  ['large_set'] * len(large_set)})

    Finally, plot the rarefaction curves:

    >>> plotfile = '_plotRarefactionCurves.png'
    >>> plotRarefactionCurves(df, 'barcodes', plotfile, facet_col='sample')

    Here is the resulting plot:

    .. image:: _static/_plotRarefactionCurves.png
       :width: 6in
       :align: center

    """
    if rarefy_col not in df.columns:
        raise ValueError("`df` does not have `rarefy_col` {0}"
                         .format(rarefy_col))
    if (facet_col is not None) and (facet_col not in df.columns):
        raise ValueError("`df` does not have `facet_col` {0}"
                         .format(facet_col))
    if ylabel is None:
        ylabel = rarefy_col

    # get iterator over groups or dummy iterator
    categories = None
    if facet_col is not None:
        if df[facet_col].dtype.name == 'category':
            categories = df[facet_col].cat.categories
        df_iterator = df.groupby(facet_col)[rarefy_col]
        nfacets = len(df[facet_col].unique())
    else:
        df_iterator = [('dummy', df[rarefy_col])]
        nfacets = 1

    d = collections.defaultdict(list)
    for name, group in df_iterator:
        xs, ys = dms_tools2.utils.rarefactionCurve(group)
        assert len(xs) == len(ys)
        d[ylabel] += ys
        d[xlabel] += xs
        d['_facet_var'] += [name] * len(xs)
    rarefied = pandas.DataFrame(d)
    if categories is not None:
        rarefied['_facet_var'] = pandas.Categorical(
                rarefied['_facet_var'], categories)

    # label axes with plain integers for small integral values,
    # and scientific notation for large ones
    ident = lambda x: x.astype('int') if all(x.astype('int') == x) else x
    if rarefied[xlabel].max() >= 1e4:
        xlabeler = latexSciNot
    else:
        xlabeler = ident
    if rarefied[ylabel].max() >= 1e4:
        ylabeler = latexSciNot
    else:
        ylabeler = ident

    p = (ggplot(rarefied, aes(xlabel, ylabel))
         + geom_line()
         + xlab(xlabel)
         + ylab(ylabel)
         + scale_x_continuous(labels=xlabeler,
                breaks=lambda x: matplotlib.ticker.MaxNLocator(3)
                                 .tick_values(min(x), max(x)))
         + scale_y_continuous(labels=ylabeler)
         )

    x_panel_spacing = {"free":0.75, "free_x":0.1,
                       "fixed":0.1, "free_y":0.75}[facet_scales]
    y_panel_spacing = {"free":0.4, "free_x":0.4,
                       "fixed":0.1, "free_y":0.1}[facet_scales]
    if facet_col is not None:
        p = p + facet_wrap('~ _facet_var', nrow=nrow,
                           scales=facet_scales)
        p = p + theme(panel_spacing_x=x_panel_spacing,
                      panel_spacing_y=y_panel_spacing)

    ncol = math.ceil(nfacets / nrow)
    p.save(plotfile,
           height=0.5 + 1.75 * nrow + y_panel_spacing * (nrow - 1),
           width=(1.25 + 2 * ncol + x_panel_spacing * (ncol - 1)),
           verbose=False, limitsize=False)
    plt.close()
def hist_bins_intsafe(x, method='fd', shrink_threshold=None,
                      maxbins=100):
    """Histogram bins that work for integer data.

    You can auto-choose bins using `numpy.histogram`. However,
    if the data are integer, these bins may be non-integer and
    so some bins will capture more integers. This function
    fixes that.

    Args:
        `x` (numpy array)
            The data to bin.
        `method` (str)
            The binning method. Can be anything acceptable to
            `numpy.histogram` as a `bins` argument.
        `shrink_threshold` (`None` or int)
            If set to a value other than `None`, apply a heuristic
            threshold to slow the growth in number of bins if they
            exceed this number.
        `maxbins` (int)
            Maximum number of bins.

    Returns:
        The bin edges as returned in the second element of
        `numpy.histogram`, but adjusted to be of integer width
        if the data are all integers.

    Just like `numpy.histogram` for non-int data:

    >>> numpy.random.seed(1)
    >>> x = 100 * numpy.random.random(500)
    >>> bin_edges = numpy.histogram(x, bins='fd')[1]
    >>> bin_edges_intsafe = hist_bins_intsafe(x)[ : len(bin_edges)]
    >>> numpy.allclose(bin_edges, bin_edges_intsafe)
    True
    >>> numpy.allclose(bin_edges, bin_edges.astype('int'))
    False

    But gives integer bins for int data:

    >>> x = x.astype('int')
    >>> bin_edges = numpy.histogram(x, bins='fd')[1]
    >>> bin_edges_intsafe = hist_bins_intsafe(x)
    >>> numpy.allclose(bin_edges, bin_edges_intsafe)
    False
    >>> numpy.allclose(bin_edges, bin_edges.astype('int'))
    False
    >>> numpy.allclose(bin_edges_intsafe, bin_edges_intsafe.astype('int'))
    True
    """
    bin_edges = numpy.histogram(x, bins=method)[1]
    if len(bin_edges) > maxbins:
        bin_edges = numpy.histogram(x, min(maxbins, len(bin_edges) - 1))[1]

    if shrink_threshold is None:
        corr = 1
    else:
        assert shrink_threshold > 1
        corr = max(1, math.sqrt(len(bin_edges) / float(shrink_threshold)))
    binwidth = (bin_edges[1] - bin_edges[0]) * corr

    if (x.astype('int') == x).all():
        binwidth = math.ceil(binwidth)
    bin_edges = numpy.arange(x.min(), x.max() + binwidth, binwidth)
    return bin_edges
def from_white_cmap(color):
    """Get matplotlib color map from white to `color`."""
    light = seaborn.set_hls_values(color, l=1)
    return seaborn.blend_palette([light, color], as_cmap=True)
class AugmentedPairGrid(seaborn.PairGrid):
    """Augmented version of `seaborn.PairGrid`."""
    def ax_lims_clip_outliers(self, frac_clip=0.001, extend=0.03):
        """Sets axis limits to clip outliers in data.

        Useful if there are a few data points far outside the
        range of most of the data.

        Args:
            `frac_clip` (float)
                Set upper and lower limits so that this fraction
                of data is outside limits at both ends. Done
                **before** adding `extend`.
            `extend` (float)
                Extend the limits determined by `frac_clip` by
                this fraction of the data range.
        """
        assert 0 <= frac_clip < 0.5

        def _get_lims(s):
            """Gets limits for data in pandas.Series `s`."""
            s_range = s.max() - s.min()
            if s_range == 0:
                s_extend = max(extend, extend * s.max())
            else:
                s_extend = s_range * extend
            s_min = s.quantile(frac_clip) - s_extend
            s_max = s.quantile(1 - frac_clip) + s_extend
            return (s_min, s_max)

        xlims = [_get_lims(self.data[x]) for x in self.x_vars]
        ylims = [_get_lims(self.data[y]) for y in self.y_vars]
        for icol, xlim in enumerate(xlims):
            for irow, ylim in enumerate(ylims):
                self.axes[irow, icol].set_xlim(*xlim)
                if icol != irow:
                    self.axes[irow, icol].set_ylim(*ylim)
        return self
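
# A minimal usage sketch for `AugmentedPairGrid` (data frame `df` is
# hypothetical); it works just like `seaborn.PairGrid` but with the
# extra outlier-clipping method:
#
#     g = (AugmentedPairGrid(df, vars=['x', 'y'], height=3)
#          .map_diag(plt.hist)
#          .map_offdiag(plt.scatter)
#          .ax_lims_clip_outliers())
#     g.savefig('pairgrid.pdf')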
if __name__ == '__main__':
    import doctest
    doctest.testmod()