# Source code for dmslogo.facet

"""
======
facet
======

Facet multiple plots on the same figure.
"""


import collections
import operator

import matplotlib.pyplot as plt

import numpy

import dmslogo


def facet_plot(
    data,
    *,
    x_col,
    show_col,
    height_per_ax=2.5,
    gridrow_col=None,
    gridcol_col=None,
    draw_line_kwargs=None,
    draw_logo_kwargs=None,
    line_titlesuffix="",
    logo_titlesuffix="",
    hspace=0.8,
    wspace=1.1,
    lmargin=1,
    rmargin=0.2,
    tmargin=0.4,
    bmargin=1.3,
    share_xlabel=False,
    share_ylabel=False,
    share_ylim_across_rows=True,
    set_ylims=False,
):
    """Facet together plots of different types on same figure.

    Useful for combining multiple instances of the plots you could create
    with :py:mod:`dmslogo.logo.draw_logo` and :py:mod:`dmslogo.line.draw_line`.

    Args:
        `data` (pandas DataFrame)
            The data to plot.
        `x_col` (str)
            Column in `data` with x-axis values, as for
            :py:mod:`dmslogo.logo.draw_logo` and
            :py:mod:`dmslogo.line.draw_line`.
        `show_col` (str or `None`)
            Column in `data` with x-axis values to highlight in line plots
            and to show in logo plots.
        `height_per_ax` (float)
            Height of each axis in the faceted plot.
        `gridrow_col` (str or `None`)
            Column in `data` to facet over for rows of plot.
        `gridcol_col` (str or `None`)
            Column in `data` to facet over for columns of plot.
        `draw_line_kwargs` (dict or `None`)
            All arguments to be passed to :py:mod:`dmslogo.line.draw_line`
            **except** `x_col`, `show_col`, and `title`. These are passed
            separately or come from faceting variables. The dict you pass
            is copied, never modified.
        `draw_logo_kwargs` (dict or `None`)
            Like `draw_line_kwargs` but for :py:mod:`dmslogo.logo.draw_logo`.
        `line_titlesuffix` (str or `None`)
            String suffixed to titles for line plots (`None` means none).
        `logo_titlesuffix` (str or `None`)
            String suffixed to titles for logo plots (`None` means none).
        `hspace` (float)
            Vertical space between axes in same units as `height_per_ax`.
        `wspace` (float)
            Horizontal space between axes.
        `lmargin` (float)
            Left margin in same units as `height_per_ax`.
        `rmargin` (float)
            Right margin in same units as `height_per_ax`.
        `tmargin` (float)
            Top margin in same units as `height_per_ax`.
        `bmargin` (float)
            Bottom margin in same units as `height_per_ax`.
        `share_xlabel` (bool)
            Share the x-labels across the line and logo plots.
        `share_ylabel` (bool)
            Share the y-labels across the line and logo plots.
        `share_ylim_across_rows` (bool)
            Do we share y-limits across rows?
        `set_ylims` (`False` or 2-tuple or dict)
            To set y-limits for all axes, specify the 2-tuple (or 2-element
            list) `(ymin, ymax)`. To set y-limits differently for each row,
            specify a dict keyed by the possible values of `gridrow_col`
            with the values being 2-tuples `(ymin, ymax)`. Limits may not
            clip the data: each `ymin` must be at most, and each `ymax` at
            least, the corresponding limit needed by the data.

    Returns:
        The 2-tuple `fig, axes` where `fig` is the matplotlib Figure and
        `axes` is a numpy ndarray of the figure axes.

    Raises:
        ValueError
            If `data` is empty, lacks a needed column, has facets that differ
            in their `x_col` / `show_col` entries, if neither plotting kwargs
            dict is given, or if any plotting option is invalid.

    `x_col` and `show_col` must have the same unique entries in `data`
    for all groups being faceted over.

    """
    # Without a faceting column, add a constant dummy column so the rest of
    # the code can always group by a row and a column variable.
    if gridrow_col is None:
        gridrow_col = "_gridrow_col_"
        if gridrow_col in data.columns:
            raise ValueError(f"`data` already has column {gridrow_col}")
        data = data.assign(**{gridrow_col: ""})
    if gridcol_col is None:
        gridcol_col = "_gridcol_col_"
        if gridcol_col in data.columns:
            raise ValueError(f"`data` already has column {gridcol_col}")
        data = data.assign(**{gridcol_col: ""})

    cols = [gridrow_col, gridcol_col, x_col]
    if show_col is not None:
        cols.append(show_col)
    for col in cols:
        if col not in data.columns:
            raise ValueError(f"no {col} column in `data`")

    _check_facet_groups_match(data, cols, x_col, gridrow_col, gridcol_col)

    # determine which draw_funcs are being used
    draw_funcs = collections.OrderedDict()
    possible_funcs = [
        ("draw_line", draw_line_kwargs, line_titlesuffix),
        ("draw_logo", draw_logo_kwargs, logo_titlesuffix),
    ]
    for name, kwargs, titlesuffix in possible_funcs:
        if kwargs is None:
            continue
        for col in ["ax", "title"]:
            if col in kwargs:
                raise ValueError(f"{name}_kwargs can't have {col}")
        if "heightscale" in kwargs:
            raise ValueError(
                f"do not set `heightscale` in {name}_kwargs; use `height_per_ax`"
            )
        func_data = data
        if name == "draw_logo" and show_col is not None:
            # Boolean mask via column lookup (same result as
            # `data.query(show_col)` for a boolean column) so column names
            # that are not valid Python identifiers also work.
            func_data = data.loc[data[show_col]]
        draw_funcs[name] = {
            # Copy so filling in `x_col` / `show_col` below never mutates
            # the caller's dict (which would break reusing that dict).
            "kwargs": dict(kwargs),
            "func": getattr(dmslogo, name),
            "data": func_data,
            "titlesuffix": titlesuffix or "",
        }
        if not len(func_data):
            raise ValueError(
                f"no data for {name}. You passed empty `data` "
                f"or `show_col` ({show_col}) is all False."
            )
    if not draw_funcs:
        raise ValueError(
            "set at least one of: "
            + ", ".join(tup[0] + "_kwargs" for tup in possible_funcs)
        )

    # harmonize top-function arg columns with plotting kwargs
    for name, name_d in draw_funcs.items():
        for colname, col in [("x_col", x_col), ("show_col", show_col)]:
            if colname in name_d["kwargs"] and name_d["kwargs"][colname] != col:
                raise ValueError(
                    f"{colname} is in {name}_kwargs; "
                    "should only be specified via the top "
                    f"function-level argument {colname}"
                )
            # draw_logo is not given `show_col`: its data were already
            # subset to the shown rows above.
            if colname != "show_col" or name == "draw_line":
                name_d["kwargs"][colname] = col

    nrows = len(data[gridrow_col].unique())
    nfuncs = len(draw_funcs)
    ncols_per_func = len(data[gridcol_col].unique())
    ncols = nfuncs * ncols_per_func

    # get sizes of fig, axis limits of plots for each func
    fixed_ylims = _measure_facets(draw_funcs, gridrow_col, gridcol_col)
    if share_ylim_across_rows:
        for ltype, lfunc in [("min", min), ("max", max)]:
            lim = lfunc(fixed_ylims[ltype].values())
            fixed_ylims[ltype] = dict.fromkeys(fixed_ylims[ltype], lim)
    if set_ylims:
        _apply_set_ylims(fixed_ylims, set_ylims)

    # make figure: each func gets `ncols_per_func` adjacent columns whose
    # relative widths are the widths measured on the scratch figures
    func_widths = [d["width"] for d in draw_funcs.values()]
    fig, axes = plt.subplots(
        nrows,
        ncols,
        squeeze=False,
        gridspec_kw={
            "width_ratios": [w for w in func_widths for _ in range(ncols_per_func)]
        },
    )
    width = lmargin + rmargin + ncols_per_func * sum(func_widths)
    hparams = height_params(nrows, height_per_ax, hspace, tmargin, bmargin)
    fig.set_size_inches(width, hparams["height"])
    fig.subplots_adjust(
        wspace=wspace * ncols / width,
        hspace=hparams["hspace"],
        top=hparams["top"],
        bottom=hparams["bottom"],
        right=1 - rmargin / width,
        left=lmargin / width,
    )

    # Add plots, adjust to tight layout
    axes_has_plot = _draw_facet_plots(
        axes, draw_funcs, ncols_per_func, gridrow_col, gridcol_col, nrows, fixed_ylims
    )
    # draw now so axis-label window extents reflect the final layout
    fig.canvas.draw()

    # only show one label for aligned axes
    assert axes.shape == (nrows, ncols)
    bottom_axes_by_func = [
        [axes[nrows - 1, ifunc * ncols_per_func + icol] for icol in range(ncols_per_func)]
        for ifunc in range(nfuncs)
    ]
    left_axes_by_func = [
        [axes[irow, ifunc * ncols_per_func] for irow in range(nrows)]
        for ifunc in range(nfuncs)
    ]
    for axistype, ax_groups, share in [
        ("x", bottom_axes_by_func, share_xlabel),
        ("y", left_axes_by_func, share_ylabel),
    ]:
        if share:
            # one label spanning all funcs instead of one per func
            ax_groups = [[ax for group in ax_groups for ax in group]]
        for axlist in ax_groups:
            _axes_to_centered_fig_label(fig, axlist, axistype)

    # hide empty axes (they hold dummy plots used only for formatting)
    assert axes.shape == axes_has_plot.shape
    for ax, has_plot in zip(axes.ravel(), axes_has_plot.ravel()):
        if not has_plot:
            ax.clear()
            ax.set_axis_off()

    return fig, axes


def _check_facet_groups_match(data, cols, x_col, gridrow_col, gridcol_col):
    """Raise `ValueError` unless all facets share the same `cols[2:]` entries.

    `cols` is `[gridrow_col, gridcol_col, x_col]` optionally followed by
    `show_col`. The unique rows of those columns, sorted by `x_col`, must be
    identical for every (`gridrow_col`, `gridcol_col`) facet. Also raises
    `ValueError` if `data` is empty or any facet group is empty.
    """
    groups = (
        data[cols]
        .drop_duplicates()
        .sort_values(x_col)
        .groupby([gridrow_col, gridcol_col])
    )
    firstgroupname = firstgroup = None
    for groupname, group in groups:
        if not len(group):
            raise ValueError(f"empty group {groupname}")
        group = group.reset_index(drop=True)
        if firstgroup is None:
            firstgroupname, firstgroup = groupname, group
            continue
        for col in cols[2:]:
            # compare lengths first: `!=` on differently sized Series raises
            if len(firstgroup[col]) != len(group[col]) or (
                firstgroup[col] != group[col]
            ).any():
                raise ValueError(
                    f"different entries for {col} "
                    f"in `data`, differs between {firstgroupname} "
                    f"and {groupname}:\n"
                    f"{firstgroup[col]}\n{group[col]}"
                )
    if firstgroup is None:
        raise ValueError("`data` is empty, so there is nothing to plot")


def _measure_facets(draw_funcs, gridrow_col, gridcol_col):
    """Draw every facet on a scratch figure to record sizes and y-limits.

    Adds the keys `width`, `xticks`, and `xticklabels` to each value of
    `draw_funcs` (raising `ValueError` if they differ between facets of the
    same func), and returns a dict keyed by 'min' / 'max' whose values map
    each row name to the most extreme y-limit seen in that row.
    """
    fixed_ylims = {"min": {}, "max": {}}
    for name, name_d in draw_funcs.items():
        for (row, _), idata in name_d["data"].groupby([gridrow_col, gridcol_col]):
            fig, ax = name_d["func"](idata, **name_d["kwargs"])
            fig.tight_layout()
            measured = {
                "width": fig.get_size_inches()[0],
                "xticks": list(ax.get_xticks()),
                "xticklabels": [t.get_text() for t in ax.get_xticklabels()],
            }
            ymin, ymax = ax.get_ylim()
            plt.close(fig)
            for key, val in measured.items():
                if key not in name_d:
                    name_d[key] = val
                elif name_d[key] != val:
                    raise ValueError(
                        f"inconsistent {key} for {name}: {val} {name_d[key]}"
                    )
            for ltype, lfunc, val in [("min", min, ymin), ("max", max, ymax)]:
                lims = fixed_ylims[ltype]
                lims[row] = lfunc(val, lims[row]) if row in lims else val
    return fixed_ylims


def _apply_set_ylims(fixed_ylims, set_ylims):
    """Overwrite `fixed_ylims` in place with user-requested `set_ylims`.

    `set_ylims` is a 2-tuple / 2-element list `(ymin, ymax)` applied to all
    rows, or a dict mapping every row name to `(ymin, ymax)`. Raises
    `ValueError` if `set_ylims` is malformed or would clip the data.
    """
    rows = list(fixed_ylims["min"])
    if isinstance(set_ylims, (tuple, list)) and len(set_ylims) == 2:
        setmin = {row: set_ylims[0] for row in rows}
        setmax = {row: set_ylims[1] for row in rows}
    elif isinstance(set_ylims, dict):
        missing = [row for row in rows if row not in set_ylims]
        if missing:
            raise ValueError(f"`set_ylims` lacks entries for rows {missing}")
        setmin = {row: ymin for row, (ymin, _) in set_ylims.items()}
        setmax = {row: ymax for row, (_, ymax) in set_ylims.items()}
    else:
        raise ValueError(f"invalid `set_ylims`: {set_ylims}")
    # a requested ymin above the data minimum (or ymax below the data
    # maximum) would clip the plotted data
    for ltype, op, setlim, bound in [
        ("min", operator.lt, setmin, "most"),
        ("max", operator.gt, setmax, "least"),
    ]:
        for row, lim in list(fixed_ylims[ltype].items()):
            if op(lim, setlim[row]):
                raise ValueError(
                    f"invalid y{ltype} in `set_ylims` for row {row!r}, "
                    f"must be at {bound} {lim}."
                )
            fixed_ylims[ltype][row] = setlim[row]
def height_params(nrows, height_per_ax, hspace, tmargin, bmargin):
    """Values to set vertical figure subplots parameters.

    Args:
        `nrows` (int)
            Number of rows of axes.
        `height_per_ax`, `hspace`, `tmargin`, `bmargin`
            Same meaning as for :func:`facet_plot`: the height of each axis,
            the vertical space between axes, and the top / bottom margins,
            all in the same units.

    Returns:
        A dict keyed by:
            - `height`: height of figure (rows of axes plus both margins);
            - `hspace`, `top`, `bottom`: values for `subplots_adjust`, with
              `top` / `bottom` as fractions of the figure height and
              `hspace` as a fraction of `height_per_ax`.

    """
    height = nrows * height_per_ax + tmargin + bmargin
    return {
        "height": height,
        "top": 1 - tmargin / height,
        "bottom": bmargin / height,
        "hspace": hspace / height_per_ax,
    }
def _axes_to_centered_fig_label(fig, axlist, axistype): """Replace axes labels with one figure label. Args: `fig` (matplotlib Figure) The figure. `axlist` (list) List of the Axes with the labels to replace. `axistype` (str) Either 'x' or 'y'. """ if axistype not in ["x", "y"]: raise ValueError(f"invalid `axistype` of {axistype}") if not len(axlist): raise ValueError("empty `axlist`") loclists = collections.defaultdict(list) label_props = collections.defaultdict(set) for ax in axlist: axis = getattr(ax, axistype + "axis") label = axis.get_label() bbox = label.get_window_extent().transformed( transform=fig.transFigure.inverted() ) for loc in ["x0", "x1", "y0", "y1"]: loclists[loc].append(getattr(bbox, loc)) label_props["fontproperties"].add(label.get_fontproperties()) label_props["rotation"].add(label.get_rotation()) label_props["text"].add(label.get_text()) label.set_visible(False) for propname, propset in list(label_props.items()): if len(propset) != 1: raise ValueError(f"multiple {propname} among axes: " f"{propset}") label_props[propname] = list(propset)[0] if axistype == "x": x = (min(loclists["x0"]) + max(loclists["x1"])) / 2 y = (min(loclists["y0"]) + min(loclists["y1"])) / 2 elif axistype == "y": x = (min(loclists["x0"]) + min(loclists["x1"])) / 2 y = (min(loclists["y0"]) + max(loclists["y1"])) / 2 fig.text( x, y, label_props["text"], ha="center", va="center", rotation=label_props["rotation"], fontproperties=label_props["fontproperties"], ) def _draw_facet_plots( axes, draw_funcs, ncols_per_func, gridrow_col, gridcol_col, nrows, fixed_ylims ): """Draws plots on axes for :func:`facet_plots`. Returns array of same shape as axes indicating whether a plot was drawn on each axis. If the value in this array is `False`, then the plot should be hidden by the calling function as it may have incorrect data purely for the purpose of axes formatting. 
""" axes_has_plot = numpy.ndarray(axes.shape, dtype="bool") for ifunc, func_d in enumerate(draw_funcs.values()): groups = [ (row_name, row_data) for row_name, row_data in func_d["data"].groupby(gridrow_col) if len(row_data) ] assert len(groups) == nrows for irow, (row_name, row_data) in enumerate(groups): assert 0 <= irow < nrows, ( f"irow out of bound\nirow: {irow}\nnrows: {nrows}\n" f"row_name: {row_name}\nngroups: {len(groups)}" ) row_groups = [ (col_name, col_data) for col_name, col_data in row_data.groupby(gridcol_col) if len(col_data) ] assert len(row_groups) <= ncols_per_func if len(row_groups) == 0: raise ValueError(f"no data for row {row_name}") for icol in range(ncols_per_func): colnum = ifunc * ncols_per_func + icol ax = axes[irow, colnum] if icol < len(row_groups): col_name, col_data = row_groups[icol] axes_has_plot[irow, colnum] = True title = ( row_name + (" " if row_name and col_name else "") + col_name + (" " if func_d["titlesuffix"] else "") + func_d["titlesuffix"] ) else: col_name, col_data = row_groups[0] # dummy data axes_has_plot[irow, colnum] = False title = "dummy data (error if you see this)" func_d["func"]( col_data, ax=ax, title=title, fixed_ymin=fixed_ylims["min"][row_name], fixed_ymax=fixed_ylims["max"][row_name], **func_d["kwargs"], ) if irow != nrows - 1: ax.set_xlabel("") ax.set_xticklabels([]) if icol != 0: ax.set_ylabel("") ax.set_yticklabels([]) return axes_has_plot if __name__ == "__main__": import doctest doctest.testmod()