"""
=======
logo
=======
Core logo-drawing functions of `dmslogo`.
Some of this code is borrowed and modified from
`pyseqlogo <https://github.com/saketkc/pyseqlogo>`_.
"""
import functools
import glob
import os
import warnings

import matplotlib.font_manager
import matplotlib.patheffects
import matplotlib.pyplot as plt
import matplotlib.ticker
import matplotlib.transforms
import numpy
import pandas as pd
import pkg_resources

import dmslogo.colorschemes
import dmslogo.utils
# default font
_DEFAULT_FONT = "DejaVuSansMonoBold_SeqLogo"

# add fonts to font manager
_FONT_PATH = pkg_resources.resource_filename("dmslogo", "ttf_fonts/")
if not os.path.isdir(_FONT_PATH):
    raise RuntimeError(f"Cannot find font directory {_FONT_PATH}")


def _register_fonts(font_path):
    """Add the fonts in `font_path` and all system fonts to matplotlib.

    Fonts found in `font_path` must load (any error propagates). System
    fonts are added on a best-effort basis: a `TypeError` triggers a
    warning and a `RuntimeError` is silently skipped.

    Args:
        `font_path` (str)
            Directory holding the fonts bundled with `dmslogo`.
    """
    # loop variables live in this function's scope, so nothing has to be
    # deleted from the module namespace afterwards (the previous
    # module-level `del` failed with NameError when no fonts were found)
    for fontfile in matplotlib.font_manager.findSystemFonts(font_path):
        matplotlib.font_manager.fontManager.addfont(fontfile)
    for fontfile in matplotlib.font_manager.findSystemFonts(None):
        try:
            matplotlib.font_manager.fontManager.addfont(fontfile)
        except TypeError:
            warnings.warn(f"Cannot load font {fontfile}", RuntimeWarning)
        except RuntimeError:
            # problem with loading emoji fonts; solution here is just to
            # skip any fonts that cause problems
            pass


_register_fonts(_FONT_PATH)

# names of all fonts now known to matplotlib; used to confirm that the
# default font and every bundled `.ttf` file were registered
_fontlist = {f.name for f in matplotlib.font_manager.fontManager.ttflist}
if _DEFAULT_FONT not in _fontlist:
    raise RuntimeError(f"Could not find default font {_DEFAULT_FONT}")
for _fontfile in glob.glob(os.path.join(_FONT_PATH, "*.ttf")):
    # each bundled file is expected to register under its file stem
    _font = os.path.splitext(os.path.basename(_fontfile))[0]
    if _font not in _fontlist:
        raise RuntimeError(f"Could not find font {_font} in file {_fontfile}")
class Scale(matplotlib.patheffects.RendererBase):
    """Scale letters using affine transformation.

    From here: https://www.python-forum.de/viewtopic.php?t=30856

    Instances are passed to `Text.set_path_effects`. Each path drawn through
    the effect is first scaled by `sx` horizontally and `sy` vertically, and
    then transformed by the affine it was originally drawn with.

    Args:
        `sx` (float)
            Horizontal scale factor.
        `sy` (`None` or float)
            Vertical scale factor, passed straight to
            `matplotlib.transforms.Affine2D.scale` (which applies `sx` to
            both directions when this is `None`).
    """

    def __init__(self, sx, sy=None):
        """See main class docstring."""
        self._sx = sx
        self._sy = sy

    def draw_path(self, renderer, gc, tpath, affine, rgbFace=None):
        """Draw the letters.

        Prepends the scaling to `affine` and forwards the call to
        `renderer.draw_path`.
        """
        affine = matplotlib.transforms.Affine2D().scale(self._sx, self._sy) + affine
        renderer.draw_path(gc, tpath, affine, rgbFace)
class Memoize:
    """Memoize function from https://stackoverflow.com/a/1988826

    Decorator that caches the return value of the wrapped function for each
    distinct set of arguments, so all arguments must be hashable. The cache
    (`memo`) is never cleared. Cached objects are returned without copying,
    so callers should not mutate them.
    """

    # separates positional from keyword arguments in cache keys so a call
    # using keywords can never collide with a purely positional call
    _KWARGS_MARKER = object()

    def __init__(self, f):
        """See main class docstring."""
        # copy `__name__`, `__doc__`, etc. from `f` so the decorated function
        # keeps its own name and docstring; done first so that it cannot
        # overwrite the attributes set below
        functools.update_wrapper(self, f)
        self.f = f
        self.memo = {}

    def __call__(self, *args, **kwargs):
        """Return the cached result of `f(*args, **kwargs)`, computing it once."""
        # purely positional calls are keyed by the bare `args` tuple
        if kwargs:
            key = args + (self._KWARGS_MARKER,) + tuple(sorted(kwargs.items()))
        else:
            key = args
        if key not in self.memo:
            self.memo[key] = self.f(*args, **kwargs)
        return self.memo[key]
@Memoize
def _setup_font(fontfamily, fontsize):
    """Return bold `FontProperties` for a font family and size.

    Args:
        `fontfamily` (str)
            Name of the font family.
        `fontsize` (float)
            Font size passed to `FontProperties`.

    Because of `Memoize`, each `(fontfamily, fontsize)` pair builds its
    `FontProperties` only once, and later calls return that same object.
    """
    properties = dict(family=fontfamily, size=fontsize, weight="bold")
    return matplotlib.font_manager.FontProperties(**properties)
@Memoize
def _frac_above_baseline(font):
    """Fraction of font height that is above baseline.

    Draws the letter "A" twice at the same y position on a scratch figure,
    once baseline-aligned and once bottom-aligned, and compares the
    rendered bounding boxes. The scratch figure is always closed.

    Args:
        `font` (FontProperties)
            Font for which we are computing fraction.

    Returns:
        float
            Distance from the baseline to the top of the text box, divided
            by the total height of the text box.
    """
    fig, ax = plt.subplots()
    try:
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        txt_baseline = ax.text(
            0, 0, "A", fontproperties=font, va="baseline", bbox={"pad": 0}
        )
        txt_bottom = ax.text(
            0, 0, "A", fontproperties=font, va="bottom", bbox={"pad": 0}
        )
        fig.canvas.draw()
        bbox_baseline = txt_baseline.get_window_extent()
        bbox_bottom = txt_bottom.get_window_extent()
        height_baseline = bbox_baseline.y1 - bbox_baseline.y0
        height_bottom = bbox_bottom.y1 - bbox_bottom.y0
        # both copies should have the same box, only shifted vertically
        assert numpy.allclose(height_baseline, height_bottom)
        # both texts are anchored at y=0, so `bbox_bottom.y0` is where the
        # baseline of the baseline-aligned copy sits
        return (bbox_baseline.y1 - bbox_bottom.y0) / height_bottom
    finally:
        # release the scratch figure even if drawing or the check fails
        plt.close(fig)
def _draw_text_data_coord(
    height_matrix,
    ystarts,
    ax,
    fontfamily,
    fontaspect,
    letterpad,
    letterheightscale,
    xpad,
):
    """Draws logo letters.

    Args:
        `height_matrix` (list of lists)
            Gives letter heights. In the main list, there is a list
            for each site, with the entries being 4-tuples giving
            the letter, its height, its color, and 'pad_below' or
            'pad_above' indicating where vertical padding is added.
        `ystarts` (list)
            Gives y position of bottom of first letter for each site.
        `ax` (matplotlib Axes)
            Axis on which we draw logo letters. Its y-limits are read
            here, so they must already be set.
        `fontfamily` (str)
            Name of font to use.
        `fontaspect` (float)
            Value to use for font aspect ratio (height to width).
        `letterpad` (float)
            Add this much vertical padding between letters, as a fraction
            of the y-axis extent.
        `letterheightscale` (float)
            Scale height of letters by this much.
        `xpad` (float)
            x-axis is padded by this many data units on each side.

    Raises:
        ValueError
            If `height_matrix` is empty, if `ystarts` and `height_matrix`
            differ in length, if the y-limits cannot hold the tallest stack
            or the lowest start, or if an entry has an invalid pad location.
    """
    # validate before using the inputs; an empty `height_matrix` would
    # otherwise fail inside `max()` with an unhelpful message
    if not height_matrix:
        raise ValueError("`height_matrix` is empty")
    if len(ystarts) != len(height_matrix):
        raise ValueError(
            "`ystarts` and `height_matrix` different lengths\n"
            f"ystarts={ystarts}\nheight_matrix={height_matrix}"
        )
    nsites = len(height_matrix)

    fig = ax.get_figure()
    # get bbox in **inches**
    bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    # the x-axis spans `nsites + 2 * xpad` data units; keep only the width
    # taken up by the sites themselves
    width = bbox.width * nsites / (2 * xpad + nsites)
    height = bbox.height

    max_stack_height = max(sum(abs(tup[1]) for tup in row) for row in height_matrix)
    ymin, ymax = ax.get_ylim()
    yextent = ymax - ymin
    if max_stack_height > yextent:
        raise ValueError("`max_stack_height` exceeds `yextent`")
    if ymin > 0:
        raise ValueError("`ymin` > 0")
    if min(ystarts) < ymin:
        raise ValueError("`ymin` exceeds smallest `ystarts`")

    letterpadheight = yextent * letterpad
    # letters are drawn at this fixed size and then resized by the `Scale`
    # path effect below; presumably 72 pt (one inch) is chosen so that the
    # inch-based scale factors apply directly -- TODO confirm
    fontsize = 72
    font = _setup_font(fontfamily, fontsize)
    frac_above_baseline = _frac_above_baseline(font)
    fontwidthscale = width / (fontaspect * nsites)
    for xindex, (xcol, ystart) in enumerate(zip(height_matrix, ystarts)):
        # letters at a site are stacked upward starting from `ystart`
        ypos = ystart
        for letter, letterheight, lettercolor, pad_loc in xcol:
            adj_letterheight = letterheightscale * letterheight
            padding = min(letterheight / 2, letterpadheight)
            if pad_loc == "pad_below":
                ypad = padding + letterheight - adj_letterheight
            elif pad_loc == "pad_above":
                ypad = 0
            else:
                raise ValueError(f"invalid `pad_loc` {pad_loc}")
            txt = ax.text(
                xindex,
                ypos + ypad,
                letter,
                fontsize=fontsize,
                color=lettercolor,
                ha="left",
                va="baseline",
                fontproperties=font,
                bbox={"pad": 0, "edgecolor": "none", "facecolor": "none"},
            )
            # text is baseline-aligned, so divide by the fraction of the
            # glyph box above the baseline to get full-box heights
            scaled_height = adj_letterheight / frac_above_baseline
            scaled_padding = padding / frac_above_baseline
            txt.set_path_effects(
                [
                    Scale(
                        fontwidthscale,
                        ((scaled_height - scaled_padding) * height / yextent),
                    )
                ]
            )
            ypos += letterheight
def draw_logo(
    data,
    *,
    x_col,
    letter_col,
    letter_height_col,
    xtick_col=None,
    color_col=None,
    shade_color_col=None,
    shade_alpha_col=None,
    heatmap_overlays=None,
    xlabel=None,
    ylabel=None,
    title=None,
    colorscheme=dmslogo.colorschemes.AA_FUNCTIONAL_GROUP,
    missing_color="gray",
    addbreaks=True,
    widthscale=1,
    heightscale=1,
    heatmap_overlay_height=0.15,
    axisfontscale=1,
    hide_axis=False,
    fontfamily=_DEFAULT_FONT,
    fontaspect=0.58,
    letterpad=0.0105,
    letterheightscale=0.96,
    ax=None,
    ylim_setter=None,
    fixed_ymin=None,
    fixed_ymax=None,
    clip_negative_heights=False,
    drop_na_letter_heights=True,
    draw_line_at_zero="if_negative",
):
    """Draw sequence logo from specified letter heights.

    Args:
        `data` (pandas DataFrame)
            Holds data to plot.
        `letter_height_col` (str)
            Column in `data` with letter heights.
        `letter_col` (str)
            Column in `data` with letter identities, each a single
            uppercase character.
        `x_col` (str)
            Column in `data` with integer site numbers.
        `xtick_col` (`None` or str)
            Column in `data` used to label sites if not using `x_col`.
        `color_col` (`None` or str)
            Column in data with colors for each letter; set to `None`
            to define colors via `colorscheme` and `missing_color`.
        `shade_color_col` (`None` or str)
            Column in `data` indicating color to shade each site.
            Must be same color for all letters at site. If a site
            should not be shaded, set to something that evaluates
            to `False` or `NaN`.
        `shade_alpha_col` (`None` or str)
            Column in `data` giving transparency of shading at each site.
        `heatmap_overlays` (`None` or list)
            List of columns in `data` giving colors for each overlay.
            Each column must have a single value per site.
        `xlabel` (`None` or str)
            Label for x-axis if not using `xtick_col` or `x_col`.
        `ylabel` (`None` or str)
            Label for y-axis if not using `letter_height_col`.
        `title` (`None` or str)
            Title to place above plot.
        `colorscheme` (dict)
            Color for each letter. Ignored if `color_col` is not `None`.
            See :py:mod:`dmslogo.colorschemes` for some color schemes.
        `missing_color` (`None` or str)
            Color for letters not assigned in `colorscheme`,
            or `None` to raise an error for unassigned letters.
        `addbreaks` (bool)
            Anywhere there is a gap in sequential numbering of
            `x_col`, add break consisting of space and dashed line.
        `widthscale` (float)
            Scale width by this much.
        `heightscale` (float)
            Scale height by this much.
        `heatmap_overlay_height` (float)
            Height of heatmap overlays relative to logo.
        `axisfontscale` (float)
            Scale size of font for axis ticks and labels by this much.
        `hide_axis` (bool)
            Do we hide the axis and tick labels?
        `fontfamily` (str)
            Font to use (for logo letters).
        `fontaspect` (float)
            Aspect ratio of logo letter font (height to width). If letters are
            too crowded, increase this.
        `letterpad` (float)
            Add this much fixed vertical padding between letters
            as fraction of total stack height.
        `letterheightscale` (float)
            Scale height of all letters by this much.
        `ax` (`None` or matplotlib axes.Axes object or list of Axes)
            Use to plot on an existing axis. If using `heatmap_overlays`
            then must be a list (or array) of axes of length
            `len(heatmap_overlays) + 1`: the overlays are drawn on the
            first axes and the logo on the last one.
        `ylim_setter` (`None` or :class:`dmslogo.utils.AxLimSetter`)
            Object used to set y-limits. If `None`, a
            :class:`dmslogo.utils.AxLimSetter` is created using
            default parameters. If `fixed_ymin` and/or `fixed_ymax`
            are set, they override the limits from this setter.
        `fixed_ymin` (`None` or float)
            If not `None`, then fixed y-axis minimum.
        `fixed_ymax` (`None` or float)
            If not `None`, then fixed y-axis maximum.
        `clip_negative_heights` (bool)
            Set to 0 any value in `letter_height_col` that is < 0.
        `drop_na_letter_heights` (bool)
            Drop any rows in `data` where `letter_height_col` is NaN.
        `draw_line_at_zero` (str)
            Draw a horizontal line at the value of zero? Can have following
            values: 'if_negative' to only draw line if there are negative
            letter heights, 'always' to always draw line, and 'never' to
            never draw line.

    Returns:
        The 2-tuple `(fig, ax)` giving the figure and axis with the logo plots.
        If using `heatmap_overlays`, then `ax` will be an array of all axes
        (overlay axes first, logo axis last).

    Raises:
        ValueError
            For missing columns, invalid letters, heights, colors, shading,
            overlays, or `draw_line_at_zero` values.
        TypeError
            If `ax` is given without overlays and is not an Axes.
    """
    # set default values of arguments that can be None
    if xtick_col is None:
        xtick_col = x_col
    if xlabel is None:
        xlabel = xtick_col
    if ylabel is None:
        ylabel = letter_height_col

    # checks on input data, done before any column is read so that a
    # missing column raises this error rather than a `KeyError`
    for col in [letter_height_col, letter_col, x_col, xtick_col]:
        if col not in data.columns:
            raise ValueError(f"`data` lacks column {col}")
    if (color_col is not None) and (color_col not in data.columns):
        raise ValueError(f"`data` lacks column {color_col}")
    # validate shading arguments up front so no figure is created for
    # arguments that are going to be rejected anyway
    if shade_color_col is not None:
        if shade_alpha_col is None:
            raise ValueError("`shade_color_col` without `shade_alpha_col`")
        if shade_color_col not in data.columns:
            raise ValueError(f"data lacks `shade_color_col` {shade_color_col}")
        if shade_alpha_col not in data.columns:
            raise ValueError(f"data lacks `shade_alpha_col` {shade_alpha_col}")
    elif shade_alpha_col is not None:
        raise ValueError("`shade_alpha_col` without `shade_color_col`")

    # check letters are all upper case; non-string letters are rejected
    # with a more specific error when the height matrix is built below
    if any(
        isinstance(letter, str) and letter.upper() != letter
        for letter in data[letter_col].unique()
    ):
        raise ValueError("letters in `letter_col` must be uppercase")

    if drop_na_letter_heights:
        data = data[~data[letter_height_col].isna()]
        if len(data) == 0:
            raise ValueError("no data after dropping nan heights")
    if clip_negative_heights:
        data = data.assign(
            **{letter_height_col: lambda x: numpy.clip(x[letter_height_col], 0, None)}
        )
    if any(data[x_col] != data[x_col].astype(int)):
        raise ValueError("`x_col` does not have integer values")
    if any(len(set(g[xtick_col])) != 1 for _, g in data.groupby(x_col)):
        raise ValueError("not unique mapping of `x_col` to `xtick_col`")

    # construct height_matrix: list of lists of (letter, height, color, pad_loc)
    height_matrix = []
    min_by_site = []
    max_by_site = []
    min_by_site_nonempty = []
    max_by_site_nonempty = []
    xticklabels = []
    xticks = []
    lastx = None
    breaks = []
    x_to_xtick = {}
    xtick = 0.5
    for x, xdata in data.sort_values([x_col, letter_height_col]).groupby(x_col):
        if addbreaks and (lastx is not None) and (x != lastx + 1):
            # empty pseudo-site that holds the dashed break line
            breaks.append(len(height_matrix))
            height_matrix.append([])
            min_by_site.append(0)
            max_by_site.append(0)
            xtick += 1
        lastx = x
        if len(xdata[letter_col]) != xdata[letter_col].nunique():
            raise ValueError(f"duplicate letters for `x_col` {x}")
        min_by_site.append(xdata[letter_height_col].clip(None, 0).sum())
        max_by_site.append(xdata[letter_height_col].clip(0, None).sum())
        min_by_site_nonempty.append(min_by_site[-1])
        max_by_site_nonempty.append(max_by_site[-1])
        # iterate the columns directly rather than via `itertuples`, which
        # renames columns whose names are not valid Python identifiers
        if color_col is not None:
            site_colors = xdata[color_col]
        else:
            site_colors = [None] * len(xdata)
        row = []
        for letter, letter_height, color in zip(
            xdata[letter_col], xdata[letter_height_col], site_colors
        ):
            if not (isinstance(letter, str) and len(letter) == 1):
                raise ValueError(f"invalid letter of {letter}")
            if color_col is None:
                try:
                    color = colorscheme[letter]
                except KeyError:
                    if not missing_color:
                        raise ValueError(f"no color for {letter}") from None
                    color = missing_color
            row.append(
                (
                    letter,
                    abs(letter_height),
                    color,
                    # negative letters hang below zero, so pad above them
                    "pad_below" if letter_height >= 0 else "pad_above",
                )
            )
        height_matrix.append(row)
        # `x_col` -> `xtick_col` mapping was checked to be unique above
        xticklabels.append(str(xdata[xtick_col].values[0]))
        xticks.append(xtick)
        x_to_xtick[x] = xtick
        xtick += 1
    assert len(xticklabels) == len(xticks)

    if draw_line_at_zero == "always":
        line_at_zero = True
    elif draw_line_at_zero == "never":
        line_at_zero = False
    elif draw_line_at_zero == "if_negative":
        line_at_zero = bool(min(min_by_site) < 0)
    else:
        raise ValueError(f"invalid `draw_line_at_zero` {draw_line_at_zero}")

    # do we have overlays?
    if heatmap_overlays:
        heatmap_overlays = list(heatmap_overlays)
        noverlays = len(heatmap_overlays)
        for overlay in heatmap_overlays:
            if overlay not in data.columns:
                raise ValueError(f"`data` lacks `heatmap_overlay` {overlay}")
        overlay_data = data[[x_col] + heatmap_overlays].drop_duplicates()
        # after dropping duplicate rows, any site still appearing more than
        # once has differing overlay values
        if overlay_data[x_col].duplicated().any():
            raise ValueError(f"Overlay not unique per site:\n{overlay_data}")
    else:
        heatmap_overlays = []
        noverlays = 0

    # setup axis for plotting: overlay axes come first, logo axis is last
    # (`is None` rather than truthiness so arrays of axes are accepted)
    if ax is None or (isinstance(ax, (list, tuple)) and not ax):
        fig, axes = plt.subplots(
            nrows=1 + noverlays,
            ncols=1,
            sharex=True,
            sharey=False,
            squeeze=False,
            gridspec_kw={
                "height_ratios": [heatmap_overlay_height] * noverlays + [1],
            },
        )
        axes = axes.ravel()
        assert len(axes) == 1 + noverlays, axes
        fig.set_size_inches(
            (
                widthscale * 0.35 * (len(height_matrix) + int(not hide_axis)),
                heightscale
                * (
                    2
                    + 0.5 * int(not hide_axis)
                    + 2 * noverlays * heatmap_overlay_height
                    + 0.5 * int(bool(title))
                ),
            )
        )
        ax = axes[-1]
    else:
        if noverlays:
            if len(ax) != noverlays + 1:
                raise ValueError(f"`ax` not axes for {noverlays} overlays")
            axes = ax
            # same layout as a newly created figure: the overlays take the
            # first axes, so the logo must go on the last one
            ax = axes[-1]
        else:
            if not isinstance(ax, plt.Axes):
                raise TypeError(f"`ax` is not an Axis: {ax}")
            axes = [ax]
        fig = ax.get_figure()

    # draw overlays
    for overlay, overlay_ax in zip(heatmap_overlays, axes):
        overlay_ax.set_yticks([0.5])
        overlay_ax.set_yticklabels([overlay])
        overlay_ax.tick_params("y", labelsize=14 * axisfontscale, length=0)
        dmslogo.utils.despine(
            ax=overlay_ax,
            top=True,
            right=True,
            left=True,
            bottom=True,
        )
        overlay_ax.get_xaxis().set_visible(False)
        for x, color in overlay_data.set_index(x_col)[overlay].items():
            xtick = x_to_xtick[x]
            overlay_ax.add_patch(
                plt.Rectangle(
                    xy=(xtick - 0.5, 0),
                    width=1,
                    height=1,
                    facecolor=color,
                    edgecolor="black",
                    linewidth=1,
                    clip_on=False,
                )
            )

    if title:
        axes[0].set_title(title, fontsize=17 * axisfontscale)

    xpad = 0.2
    ax.set_xlim(-xpad, len(height_matrix) + xpad)

    # set y-limits
    if ylim_setter is None:
        ylim_setter = dmslogo.utils.AxLimSetter()
    ymin1, ymax1 = ylim_setter.get_lims(min_by_site_nonempty)
    ymin2, ymax2 = ylim_setter.get_lims(max_by_site_nonempty)
    ymin = min(ymin1, ymin2)
    ymax = max(ymax1, ymax2)
    if fixed_ymin is not None:
        ymin = fixed_ymin
    if fixed_ymax is not None:
        ymax = fixed_ymax
    ax.set_ylim(ymin, ymax)

    if not hide_axis:
        ax.set_xticks(xticks)
        ax.tick_params(length=5, width=1)
        ax.set_xticklabels(xticklabels, rotation=90, ha="center", va="top")
        ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(4))
        ax.tick_params("both", labelsize=12 * axisfontscale)
        ax.set_xlabel(xlabel, fontsize=17 * axisfontscale)
        ax.set_ylabel(ylabel, fontsize=17 * axisfontscale)
        dmslogo.utils.despine(ax=ax, trim=False, top=True, right=True)
    else:
        ax.axis("off")

    # draw the letters (needs the y-limits set above)
    _draw_text_data_coord(
        height_matrix,
        min_by_site,
        ax,
        fontfamily,
        fontaspect,
        letterpad,
        letterheightscale,
        xpad,
    )

    # draw the breaks
    for x in breaks:
        # loosely dotted line:
        # https://matplotlib.org/gallery/lines_bars_and_markers/linestyles.html
        ax.axvline(x=x + 0.5, ls=(0, (2, 5)), color="black", lw=1)

    # draw line at zero
    if line_at_zero:
        ax.axhline(y=0, ls="-", color="black", lw=1, zorder=4)

    # draw the shading (column arguments were validated above)
    if shade_color_col is not None:
        for x, xdata in data.groupby(x_col):
            shade_color = xdata[shade_color_col].unique()
            if len(shade_color) != 1:
                raise ValueError(f"not exactly one shade color for {x}")
            shade_color = shade_color[0]
            shade_alpha = xdata[shade_alpha_col].unique()
            if len(shade_alpha) != 1:
                raise ValueError(f"not exactly one shade alpha for {x}")
            shade_alpha = shade_alpha[0]
            if pd.isnull(shade_color) or not shade_color:
                continue
            if not (0 <= shade_alpha <= 1):
                raise ValueError(f"shade alpha not between 0 and 1 for {x}")
            xtick = x_to_xtick[x]
            ax.axvspan(
                xmin=xtick - 0.5,
                xmax=xtick + 0.5,
                edgecolor=None,
                facecolor=shade_color,
                alpha=shade_alpha,
            )

    if len(axes) == 1:
        return fig, ax
    return fig, axes
if __name__ == "__main__":
    # When run directly as a script, execute any doctest examples found in
    # this module's docstrings.
    from doctest import testmod

    testmod()