"""
==================
curvefits
==================
Defines :class:`CurveFits` to fit curves and display / plot results.
"""
import collections
import copy
import itertools
import math
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import pandas as pd
import neutcurve
from neutcurve.colorschemes import CBMARKERS, CBPALETTE
[docs]
class CurveFits:
"""Fit and display :class:`neutcurve.hillcurve.HillCurve` curves.
Args:
`data` (pandas DataFrame)
Tidy dataframe with data.
`conc_col` (str)
Column in `data` with concentrations of serum.
`fracinf_col` (str)
Column in `data` with fraction infectivity.
`serum_col` (str)
Column in `data` with serum name.
`virus_col` (str)
Column in `data` with name of virus being neutralized.
`replicate_col` (str)
Column in data with name of replicate of this measurement. Replicates can
**not** be named 'average' as we compute the average from the replicates.
`fixbottom` (`False` or float or 2-tuple)
Same meaning as for :class:`neutcurve.hillcurve.HillCurve`.
`fixtop` (`False` or float or 2-tuple)
Same meaning as for :class:`neutcurve.hillcurve.HillCurve`.
`fixslope` (`False` or float or 2-tuple)
Same meaning as for :class:`neutcurve.hillcurve.HillCurve`.
`infectivity_or_neutralized` ({'infectivity', 'neutralized'})
Same meaning as for :class:`neutcurve.hillcurve.HillCurve`.
`fix_slope_first` (bool)
Same meaning as for :class:`neutcurve.hillcurve.HillCurve`.
`init_slope` (float)
Same meaning as for :class:`neutcurve.hillcurve.HillCurve`.
`allow_reps_unequal_conc` (bool)
Allow replicates for the same serum/virus to have unequal concentrations;
otherwise all replicates for a serum/virus must have measurements at same
concentrations.
Attributes of a :class:`CurveFits` include all args except `data` plus:
`df` (pandas DataFrame)
Copy of `data` that only has relevant columns, has additional rows
with `replicate_col` of 'average' that hold replicate averages, and
added columns 'stderr' (standard error of fraction infectivity
for 'average' if multiple replicates, otherwise `nan`).
`sera` (list)
List of all serum names in `serum_col` of `data`, in order
they occur in `data`.
`viruses` (dict)
For each serum in `sera`, `viruses[serum]` gives all viruses
for that serum in the order they occur in `data`.
`replicates` (dict)
`replicates[(serum, virus)]` is list of all replicates for
that serum and virus in the order they occur in `data`.
`allviruses` (list)
List of all viruses.
"""
# Names commonly used for wildtype virus. This tuple is the default value of
# `all_subplots` in `plotSera`, i.e., the viruses repeated on every subplot
# when a serum is split over several subplots.
_WILDTYPE_NAMES = ("WT", "wt", "wildtype", "Wildtype", "wild type", "Wild type")
[docs]
@staticmethod
def combineCurveFits(
    curvefits_list,
    *,
    sera=None,
    viruses=None,
    serum_virus_replicates_to_drop=None,
):
    """Combine several :class:`CurveFits` objects into a single one.

    Args:
        `curvefits_list` (list)
            List of :class:`CurveFits` objects that are identical other than the
            data they contain and have unique virus/serum/replicate combinations.
            They can differ in `fixtop`, `fixbottom`, and `fixslope`, but then
            the ones that differ will be set to `None` in the returned object.
        `sera` (None or list)
            Only keep fits for sera in this list, or keep all sera if `None`.
        `viruses` (None or list)
            Only keep fits for viruses in this list, or keep all viruses if
            `None`.
        `serum_virus_replicates_to_drop` (None or list)
            If a list, should specify `(serum, virus, replicate)` tuples, and
            those particular fits are dropped.

    Returns:
        combined_fits (:class:`CurveFits`)
            A :class:`CurveFits` object that combines all the virus/serum/replicate
            combinations in `curvefits_list`.

    Raises:
        ValueError
            If `curvefits_list` is empty or holds something other than
            :class:`CurveFits` objects, if the objects differ in an attribute
            that must be shared, or if the same serum / virus / replicate /
            concentration occurs more than once in the combined data.
    """
    if not (
        len(curvefits_list) >= 1
        and all(isinstance(c, CurveFits) for c in curvefits_list)
    ):
        raise ValueError(f"{curvefits_list=} not list of at least one `CurveFits`")

    attrs_must_be_same = [  # attributes that must be same among all objects
        "conc_col",
        "fracinf_col",
        "serum_col",
        "virus_col",
        "replicate_col",
        "_infectivity_or_neutralized",
        "_fix_slope_first",
        "_init_slope",
    ]
    attrs_can_differ = [  # attributes that can differ among objects
        "fixbottom",
        "fixtop",
        "fixslope",
        "df",
        "sera",
        "allviruses",
        "viruses",
        "replicates",
        "_hillcurves",
        "_fitparams",
    ]
    all_attrs = set(attrs_must_be_same + attrs_can_differ)
    assert all(set(c.__dict__) == all_attrs for c in curvefits_list)

    # make new object we will then adjust attributes
    first = curvefits_list[0]
    combined_fits = copy.deepcopy(first)
    for attr in attrs_must_be_same:
        if any(
            getattr(combined_fits, attr) != getattr(c, attr) for c in curvefits_list
        ):
            raise ValueError(f"objects in `curvefits_list` differ in {attr}")
    # everything else is rebuilt below from the combined data
    for attr in attrs_can_differ:
        delattr(combined_fits, attr)

    # fixtop, fixbottom, fixslope are kept at shared value or None
    for attr in ["fixtop", "fixbottom", "fixslope"]:
        if any(getattr(first, attr) != getattr(c, attr) for c in curvefits_list):
            setattr(combined_fits, attr, None)
        else:
            setattr(combined_fits, attr, getattr(first, attr))

    # combine df: drop the derived 'stderr' column and 'average' rows, as both
    # are recomputed from the combined replicates below
    assert all(all(first.df.columns == c.df.columns) for c in curvefits_list)
    combined_fits.df = pd.concat(
        [c.df for c in curvefits_list],
        ignore_index=False,
        sort=False,
    ).drop(columns="stderr")
    combined_fits.df = combined_fits.df[
        combined_fits.df[combined_fits.replicate_col] != "average"
    ]
    for col, keeplist in [
        (combined_fits.serum_col, sera),
        (combined_fits.virus_col, viruses),
    ]:
        if keeplist is not None:
            combined_fits.df = combined_fits.df[combined_fits.df[col].isin(keeplist)]
    if serum_virus_replicates_to_drop:
        # temporary column of (serum, virus, replicate) tuples to filter on
        assert "tup" not in set(combined_fits.df.columns)
        combined_fits.df = (
            combined_fits.df.assign(
                tup=lambda x: list(
                    x[
                        [
                            combined_fits.serum_col,
                            combined_fits.virus_col,
                            combined_fits.replicate_col,
                        ]
                    ].itertuples(index=False, name=None)
                ),
            )
            .query("tup not in @serum_virus_replicates_to_drop")
            .drop(columns="tup")
        )
    combined_fits.df = combined_fits._get_avg_and_stderr_df(combined_fits.df)
    # each serum / virus / replicate / concentration must be a single row
    if len(combined_fits.df) != len(
        combined_fits.df.groupby(
            [
                combined_fits.serum_col,
                combined_fits.virus_col,
                combined_fits.replicate_col,
                combined_fits.conc_col,
            ]
        )
    ):
        raise ValueError("duplicated sera/virus/replicate in `curvefits_list`")

    # combine sera
    combined_fits.sera = combined_fits.df[combined_fits.serum_col].unique().tolist()

    # combine allviruses
    combined_fits.allviruses = (
        combined_fits.df[combined_fits.virus_col].unique().tolist()
    )

    # combine viruses and replicates
    assert combined_fits.serum_col != "viruses"
    combined_fits.viruses = (
        combined_fits.df.groupby(combined_fits.serum_col, sort=False)
        .aggregate(
            viruses=pd.NamedAgg(
                combined_fits.virus_col,
                lambda s: s.unique().tolist(),
            ),
        )["viruses"]
        .to_dict()
    )
    assert combined_fits.serum_col != "replicate"
    assert combined_fits.virus_col != "replicate"
    combined_fits.replicates = (
        combined_fits.df[combined_fits.df[combined_fits.replicate_col] != "average"]
        .groupby([combined_fits.serum_col, combined_fits.virus_col], sort=False)
        .aggregate(
            replicates=pd.NamedAgg(
                combined_fits.replicate_col,
                lambda s: s.unique().tolist(),
            ),
        )["replicates"]
        .to_dict()
    )
    # 'average' is always listed last, after the real replicates
    for serum_virus in combined_fits.replicates:
        combined_fits.replicates[serum_virus].append("average")

    # sanity checks that the attributes agree with the combined data frame
    serum_virus_rep_tups = [
        (serum, virus, rep)
        for (serum, virus), reps in combined_fits.replicates.items()
        for rep in reps
    ]
    assert len(serum_virus_rep_tups) == len(set(serum_virus_rep_tups))
    combined_fits_tups = set(
        combined_fits.df[
            [
                combined_fits.serum_col,
                combined_fits.virus_col,
                combined_fits.replicate_col,
            ]
        ].itertuples(index=False, name=None)
    )
    assert (
        set(serum_virus_rep_tups) == combined_fits_tups
    ), f"{combined_fits_tups=}\n\n{serum_virus_rep_tups=}"
    assert set(combined_fits.allviruses).issubset(
        v for f in curvefits_list for s in f.viruses.values() for v in s
    )

    # get all the HillCurve objects that have been pre-computed from prior
    # objects being concatenated, except any "average" ones
    combined_fits._hillcurves = {}
    for c in curvefits_list:
        for (serum, virus, replicate), curve in c._hillcurves.items():
            if replicate == "average":
                continue
            # `.get` because a serum and virus can each still be present while
            # this particular serum / virus pair was filtered out entirely;
            # indexing directly would then raise a KeyError
            if replicate in combined_fits.replicates.get((serum, virus), ()):
                combined_fits._hillcurves[(serum, virus, replicate)] = curve

    combined_fits._fitparams = {}  # clear this cache

    assert set(combined_fits.__dict__) == all_attrs
    return combined_fits
def __init__(
    self,
    data,
    *,
    conc_col="concentration",
    fracinf_col="fraction infectivity",
    serum_col="serum",
    virus_col="virus",
    replicate_col="replicate",
    infectivity_or_neutralized="infectivity",
    fix_slope_first=True,
    init_slope=1.5,
    fixbottom=0,
    fixtop=1,
    fixslope=False,
    allow_reps_unequal_conc=False,
):
    """See main class docstring.

    Raises:
        ValueError
            If column names are duplicated or missing from `data`, a replicate
            is named 'average', a replicate has duplicate concentrations,
            replicates for a serum / virus have different concentrations
            (unless `allow_reps_unequal_conc`), or a serum or virus name is NaN.
    """
    # make args into attributes
    self.conc_col = conc_col
    self.fracinf_col = fracinf_col
    self.serum_col = serum_col
    self.virus_col = virus_col
    self.replicate_col = replicate_col
    self.fixbottom = fixbottom
    self.fixtop = fixtop
    self.fixslope = fixslope
    self._infectivity_or_neutralized = infectivity_or_neutralized
    self._fix_slope_first = fix_slope_first
    self._init_slope = init_slope

    # check for required columns
    cols = [
        self.serum_col,
        self.virus_col,
        self.replicate_col,
        self.conc_col,
        self.fracinf_col,
    ]
    if len(cols) != len(set(cols)):
        raise ValueError("duplicate column names:\n\t" + "\n\t".join(cols))
    if not (set(cols) <= set(data.columns)):
        raise ValueError(
            "`data` lacks required columns, which are:\n\t" + "\n\t".join(cols)
        )

    # create `self.df`, ensure that replicates are str rather than number
    self.df = data[cols].assign(
        **{replicate_col: lambda x: (x[replicate_col].astype(str))}
    )

    # create sera / viruses / replicates attributes, error check them
    # (boolean masks rather than `DataFrame.query` so that column names that
    # are not valid Python identifiers, e.g. with spaces, also work)
    self.sera = self.df[self.serum_col].unique().tolist()
    self.viruses = {}
    self.replicates = {}
    for serum in self.sera:
        serum_data = self.df[self.df[self.serum_col] == serum]
        serum_viruses = serum_data[self.virus_col].unique().tolist()
        self.viruses[serum] = serum_viruses
        for virus in serum_viruses:
            virus_data = serum_data[serum_data[self.virus_col] == virus]
            virus_reps = virus_data[self.replicate_col].unique().tolist()
            if "average" in virus_reps:
                raise ValueError(
                    'A replicate is named "average". This is '
                    "not allowed as that name is used for "
                    "replicate averages."
                )
            self.replicates[(serum, virus)] = virus_reps + ["average"]

            # sorted concentrations of each replicate, computed once
            rep_concs = {
                rep: virus_data.loc[
                    virus_data[self.replicate_col] == rep, self.conc_col
                ]
                .sort_values()
                .tolist()
                for rep in virus_reps
            }
            for i, rep1 in enumerate(virus_reps):
                conc1 = rep_concs[rep1]
                if len(conc1) != len(set(conc1)):
                    raise ValueError(
                        f"duplicate concentrations for {serum=}, {virus=}, {rep1=}"
                    )
                if not allow_reps_unequal_conc:
                    for rep2 in virus_reps[i + 1 :]:
                        # compare against the concentrations of `rep2` itself
                        conc2 = rep_concs[rep2]
                        if conc1 != conc2:
                            raise ValueError(
                                f"{rep1=}, {rep2=} differ conc {serum=} {virus=}\n"
                                "Replicates for serum/virus must have same "
                                "concentrations unless allow_reps_unequal_conc=True"
                            )

    # all viruses across sera, de-duplicated, in order of first occurrence
    self.allviruses = list(
        dict.fromkeys(v for serum in self.sera for v in self.viruses[serum])
    )
    if pd.isnull(self.allviruses).any():
        raise ValueError(f"a virus has name NaN:\n{self.allviruses}")
    if pd.isnull(self.sera).any():
        raise ValueError(f"a serum has name NaN:\n{self.sera}")

    # compute replicate average and add 'stderr'
    self.df = self._get_avg_and_stderr_df(self.df)

    self._hillcurves = {}  # curves computed by `getCurve` cached here
    self._fitparams = {}  # cache data frame computed by `fitParams`
def _get_avg_and_stderr_df(self, df):
    """Append replicate-average rows to `df` and add a 'stderr' column.

    For each serum / virus / concentration, a row with replicate 'average'
    is added holding the mean fraction infectivity and its standard error.
    """
    if "stderr" in df.columns:
        raise ValueError('`data` has column "stderr"')

    group_cols = [self.serum_col, self.virus_col, self.conc_col]
    fracinf_by_group = df.groupby(group_cols, observed=True)[self.fracinf_col]
    # sem is sample stderr, evaluates to NaN when just 1 rep
    summary = fracinf_by_group.aggregate(["mean", "sem"])
    summary = summary.rename(columns={"mean": self.fracinf_col, "sem": "stderr"})
    averages = summary.reset_index().assign(**{self.replicate_col: "average"})

    return pd.concat([df, averages], ignore_index=True, sort=False)
[docs]
def getCurve(self, *, serum, virus, replicate):
    """Get the fitted curve for this sample.

    Curves are fit the first time they are requested and then cached, so
    repeated calls with the same arguments return the same object.

    Args:
        `serum` (str)
            Name of a valid serum.
        `virus` (str)
            Name of a valid virus for `serum`.
        `replicate` (str)
            Name of a valid replicate for `serum` and `virus`, or
            'average' for the average of all replicates.

    Returns:
        A :class:`neutcurve.hillcurve.HillCurve`.

    Raises:
        ValueError
            If `serum`, `virus`, or `replicate` is not valid, or if there
            are no data for the combination.
        neutcurve.HillCurveFittingError
            If the curve cannot be fit; the message includes the data.
    """
    key = (serum, virus, replicate)
    if key in self._hillcurves:
        return self._hillcurves[key]

    if serum not in self.sera:
        raise ValueError(f"invalid {serum=}")
    if virus not in self.viruses[serum]:
        raise ValueError(f"invalid {virus=} for {serum=}")
    if replicate not in self.replicates[(serum, virus)]:
        raise ValueError(f"invalid {replicate=} for {serum=} {virus=}")

    # boolean masks rather than `DataFrame.query` so that column names that
    # are not valid Python identifiers (e.g., contain spaces) also work
    is_sample = (
        (self.df[self.serum_col] == serum)
        & (self.df[self.virus_col] == virus)
        & (self.df[self.replicate_col] == replicate)
    )
    idata = self.df[is_sample]
    if len(idata) < 1:
        raise ValueError(f"no data for {serum=} {virus=}")

    if idata["stderr"].isna().any():
        fs_stderr = None  # cannot use stderr if any concentrations lack it
    else:
        fs_stderr = idata["stderr"]

    try:
        curve = neutcurve.HillCurve(
            cs=idata[self.conc_col],
            fs=idata[self.fracinf_col],
            fs_stderr=fs_stderr,
            fixbottom=self.fixbottom,
            fixtop=self.fixtop,
            fixslope=self.fixslope,
            infectivity_or_neutralized=self._infectivity_or_neutralized,
            fix_slope_first=self._fix_slope_first,
            init_slope=self._init_slope,
        )
    except neutcurve.HillCurveFittingError as e:
        # The failing data are reported in the message itself; nothing is
        # written to disk.
        # following here: https://stackoverflow.com/a/46091127
        raise neutcurve.HillCurveFittingError(
            f"Error fitting HillCurve for {serum=} {virus=} {replicate=}\n"
            f"Data are:\n{idata}"
        ) from e

    self._hillcurves[key] = curve
    return curve
[docs]
def fitParams(
    self,
    *,
    average_only=True,
    no_average=False,
    ics=(50,),
    ics_precision=0,
    ic50_error=None,
):
    """Get data frame with curve fitting parameters.

    The data frame for each combination of arguments is computed once and
    then cached.

    Args:
        `average_only` (bool)
            If `True`, only get parameters for average across replicates.
        `no_average` (bool)
            Do not include average across replicates. Mutually incompatible
            with `average_only`.
        `ics` (iterable)
            Include ICXX for each number in this list, where the number
            is the percent neutralized. So if `ics` only contains 50,
            we include the IC50. If it includes 95, we include the IC95.
        `ics_precision` (int)
            Include this many digits after decimal when creating the
            ICXX columns.
        ic50_error {`None`, 'fit_stdev'}
            Include estimated error on IC50 as standard deviation of fit
            parameter; note that we recommend instead just taking standard
            error of replicate IC50s.

    Returns:
        A pandas DataFrame with fit parameters for each serum / virus /
        replicate as defined for a :mod:`neutcurve.hillcurve.HillCurve`.
        Columns:
          - 'serum'
          - 'virus'
          - 'replicate'
          - 'nreplicates': number of replicates for average, NaN otherwise.
          - 'icXX': ICXX or its bound as a number, where `XX` is each
            number in `ics`.
          - 'icXX_bound': string indicating if ICXX interpolated from data,
            or is an upper or lower bound.
          - 'icXX_str': ICXX represented as string, with > or < indicating
            if it is an upper or lower bound.
          - 'ic50_error': only if `ic50_error` is 'fit_stdev'.
          - 'midpoint': midpoint of curve, same as IC50 only if bottom
            and top are 0 and 1.
          - 'midpoint_bound': midpoint bounded by range of fit concentrations
          - 'midpoint_bound_type': string indicating if midpoint is interpolated
            from data or is an upper or lower bound.
          - 'slope': Hill slope of curve.
          - 'top': top of curve.
          - 'bottom': bottom of curve.
          - 'r2': coefficient of determination of fit
          - 'rmsd': root-mean square deviation of fits
    """
    if ic50_error not in {None, "fit_stdev"}:
        raise ValueError(f"invalid {ic50_error=}")

    ics = tuple(ics)
    ic_colprefixes = [f"ic{ic:.{ics_precision}f}" for ic in ics]
    if len(set(ic_colprefixes)) != len(ic_colprefixes):
        raise ValueError(
            "column names for ICXX not unique.\n"
            "Either you have duplicate entries in `ics` "
            "or you need to increase `ics_precision`."
        )

    if average_only and no_average:
        raise ValueError("both `average_only` and `no_average` are `True`")

    key = (average_only, no_average, ics, ics_precision, ic50_error)
    if key in self._fitparams:
        return self._fitparams[key]

    with_ic50_error = ic50_error == "fit_stdev"
    params = [
        "midpoint",
        "midpoint_bound",
        "midpoint_bound_type",
        "slope",
        "top",
        "bottom",
        "r2",
        "rmsd",
    ]
    ic_cols = []
    for prefix in ic_colprefixes:
        ic_cols.extend([prefix, f"{prefix}_bound", f"{prefix}_str"])
    if with_ic50_error:
        ic_cols.append("ic50_error")
    columns = ["serum", "virus", "replicate", "nreplicates"] + ic_cols + params

    # one list of values per output column, filled a row at a time
    table = {col: [] for col in columns}
    for serum in self.sera:
        for virus in self.viruses[serum]:
            all_reps = self.replicates[(serum, virus)]
            nreplicates = sum(r != "average" for r in all_reps)
            assert nreplicates == len(all_reps) - 1
            if average_only:
                wanted_reps = ["average"]
            elif no_average:
                wanted_reps = [r for r in all_reps if r != "average"]
            else:
                wanted_reps = all_reps
            for replicate in wanted_reps:
                curve = self.getCurve(serum=serum, virus=virus, replicate=replicate)
                table["serum"].append(serum)
                table["virus"].append(virus)
                table["replicate"].append(replicate)
                table["nreplicates"].append(
                    nreplicates if replicate == "average" else float("nan")
                )
                for ic, prefix in zip(ics, ic_colprefixes):
                    frac = ic / 100
                    table[prefix].append(curve.icXX(frac, method="bound"))
                    table[f"{prefix}_bound"].append(curve.icXX_bound(frac))
                    table[f"{prefix}_str"].append(curve.icXX_str(frac))
                if with_ic50_error:
                    table["ic50_error"].append(curve.ic50_stdev())
                for param in params:
                    table[param].append(getattr(curve, param))

    if table["serum"]:
        fitparams = pd.DataFrame(table)[columns].assign(
            nreplicates=lambda x: (x["nreplicates"].astype("Int64"))
        )
    else:
        # no curves at all: empty frame that still has the expected columns
        fitparams = pd.DataFrame(columns=columns)
    self._fitparams[key] = fitparams
    return fitparams
[docs]
def plotSera(
    self,
    *,
    ncol=4,
    nrow=None,
    sera="all",
    viruses="all",
    ignore_serum_virus=None,
    colors=CBPALETTE,
    markers=CBMARKERS,
    virus_to_color_marker=None,
    max_viruses_per_subplot=5,
    multi_serum_subplots=True,
    all_subplots=_WILDTYPE_NAMES,
    titles=None,
    vlines=None,
    **kwargs,
):
    """Plot grid with replicate-average of viruses for each serum.

    Args:
        `ncol`, `nrow` (int or `None`)
            Specify one of these to set number of columns or rows,
            other should be `None`.
        `sera` ('all' or list)
            Sera to include on plot, in this order.
        `viruses` ('all' or list)
            Viruses to include on plot, in this order unless one
            is specified in `all_subplots`.
        `ignore_serum_virus` (`None` or dict)
            Specific serum / virus combinations to ignore (not plot). Key
            by serum, and then list viruses to ignore.
        `colors` (iterable)
            List of colors for different viruses.
        `markers` (iterable)
            List of markers for different viruses.
        `virus_to_color_marker` (dict or `None`)
            Optionally specify a specific color and marker for each virus as
            2-tuples `(color, marker)`. If you use this option, `colors`
            and `markers` are ignored.
        `max_viruses_per_subplot` (int)
            Maximum number of viruses to show on any subplot.
        `multi_serum_subplots` (bool)
            If a serum has more than `max_viruses_per_subplot` viruses,
            do we make multiple subplots for it or raise an error?
        `all_subplots` (iterable)
            If making multiple subplots for serum, which viruses
            do we show on all subplots? These are also shown first.
        `titles` (`None` or list)
            Specify custom titles for each subplot different than `sera`.
        `vlines` (`None` or dict)
            Add vertical lines to plots. Keyed by serum name, values
            are lists of dicts with a key 'x' giving x-location of vertical
            line, and optional keys 'linewidth', 'color', and 'linestyle'.
        `**kwargs`
            Other keyword arguments that can be passed to
            :meth:`CurveFits.plotGrid`.

    Returns:
        The 2-tuple `(fig, axes)` of matplotlib figure and 2D axes array.
    """
    sera, viruses = self._sera_viruses_lists(sera, viruses)
    viruses = list(collections.OrderedDict.fromkeys(viruses))

    if titles is None:
        titles = sera
    elif len(sera) != len(titles):
        raise ValueError(f"`titles`, `sera` != length:\n{titles=}\n{sera=}")

    if max_viruses_per_subplot < 1:
        raise ValueError(f"{max_viruses_per_subplot=} must be at least 1")

    # get color scheme for viruses
    if virus_to_color_marker:
        unstyled = set(viruses) - set(virus_to_color_marker.keys())
        if unstyled:
            raise ValueError("viruses not in `virus_to_color_marker`: " + str(unstyled))
    else:
        n_styles = min(len(colors), len(markers))
        if len(viruses) <= n_styles:
            # can share scheme among subplots; the stable sort puts viruses
            # in `all_subplots` first while otherwise keeping their order
            styled_order = sorted(viruses, key=lambda v: v not in all_subplots)
            virus_to_color_marker = {
                v: (c, m) for (v, c, m) in zip(styled_order, colors, markers)
            }
        elif n_styles < max_viruses_per_subplot:
            raise ValueError(
                "`max_viruses_per_subplot` larger than number of colors or markers"
            )
        else:
            # colors / markers are assigned by position within each subplot
            virus_to_color_marker = None

    def vlines_for(serum_name):
        """Vertical lines requested for this serum, or `None`."""
        if vlines and (serum_name in vlines):
            return vlines[serum_name]
        return None

    # Build a list of plots appropriate for `plotGrid`. A serum can span
    # several subplots, in which case the viruses in `all_subplots` are
    # repeated on each of them.
    plotlist = []
    vlines_list = []
    for serum, title in zip(sera, titles):
        if ignore_serum_virus and serum in ignore_serum_virus:
            ignored = ignore_serum_virus[serum]
        else:
            ignored = {}

        plottable = [
            v for v in self.viruses[serum] if (v in viruses) and (v not in ignored)
        ]
        serum_shared_viruses = [v for v in plottable if v in all_subplots]
        serum_unshared_viruses = [v for v in plottable if v not in all_subplots]

        # shared viruses must leave a slot free for an unshared one, if any
        slots_for_unshared = 1 if serum_unshared_viruses else 0
        if len(serum_shared_viruses) > max_viruses_per_subplot - slots_for_unshared:
            raise ValueError(
                f"{serum=} has too many subplot-shared "
                "viruses (in `all_subplots`) relative to "
                "value of `max_viruses_per_subplot`:\n"
                f"{serum_shared_viruses=} is more than "
                f"{max_viruses_per_subplot=} viruses."
            )

        curves = []  # curves on the subplot currently being filled
        shared_curves = []  # curves repeated at the start of each new subplot
        tagged_viruses = [(v, True) for v in serum_shared_viruses] + [
            (v, False) for v in serum_unshared_viruses
        ]
        for virus, on_all_subplots in tagged_viruses:
            if len(curves) >= max_viruses_per_subplot:
                if not multi_serum_subplots:
                    raise ValueError(
                        f"{serum=} has more than "
                        "`max_viruses_per_subplot` viruses "
                        "and `multi_serum_subplots` is False"
                    )
                # current subplot is full: store it and start the next one
                plotlist.append((title, curves))
                vlines_list.append(vlines_for(serum))
                curves = list(shared_curves)
                assert len(curves) < max_viruses_per_subplot

            if virus_to_color_marker:
                color, marker = virus_to_color_marker[virus]
            else:
                position = len(curves)
                color, marker = colors[position], markers[position]

            curve = {
                "serum": serum,
                "virus": virus,
                "replicate": "average",
                "label": virus,
                "color": color,
                "marker": marker,
            }
            curves.append(curve)
            if on_all_subplots:
                shared_curves.append(curve)

        if curves:
            plotlist.append((title, curves))
            vlines_list.append(vlines_for(serum))

    if not plotlist:
        raise ValueError("no curves for these sera / viruses")

    # get number of columns
    if (nrow is not None) and (ncol is not None):
        raise ValueError(f"either {ncol=} or {nrow=} must be `None`")
    elif isinstance(nrow, int) and nrow > 0:
        ncol = math.ceil(len(plotlist) / nrow)
    elif not (isinstance(ncol, int) and ncol > 0):
        raise ValueError(f"{nrow=} or {ncol=} must be integer > 0")

    # convert plotlist to plots dict for `plotGrid`, filling the grid row by row
    plots = {}
    vlines_axkey = {}
    assert len(plotlist) == len(vlines_list)
    for iplot, (plot, plot_vlines) in enumerate(zip(plotlist, vlines_list)):
        axkey = divmod(iplot, ncol)  # (row, column)
        plots[axkey] = plot
        if plot_vlines:
            vlines_axkey[axkey] = plot_vlines

    # a legend order given explicitly by the caller takes precedence
    if virus_to_color_marker and "orderlegend" not in kwargs:
        kwargs["orderlegend"] = virus_to_color_marker.keys()

    return self.plotGrid(
        plots,
        vlines=vlines_axkey,
        **kwargs,
    )
[docs]
def plotViruses(
    self,
    *,
    ncol=4,
    nrow=None,
    sera="all",
    viruses="all",
    ignore_virus_serum=None,
    colors=CBPALETTE,
    markers=CBMARKERS,
    serum_to_color_marker=None,
    max_sera_per_subplot=5,
    multi_virus_subplots=True,
    all_subplots=(),
    titles=None,
    vlines=None,
    **kwargs,
):
    """Plot grid with replicate-average of sera for each virus.

    Args:
        `ncol`, `nrow` (int or `None`)
            Specify one of these to set number of columns or rows,
            other should be `None`.
        `sera` ('all' or list)
            Sera to include on plot, in this order, unless one is
            specified in `all_subplots`.
        `viruses` ('all' or list)
            Viruses to include on plot, in this order.
        `ignore_virus_serum` (`None` or dict)
            Specific virus / serum combinations to ignore (not plot). Key
            by virus, and then list sera to ignore.
        `colors` (iterable)
            List of colors for different sera.
        `markers` (iterable)
            List of markers for different sera.
        `serum_to_color_marker` (dict or `None`)
            Optionally specify a specific color and marker for each serum as
            2-tuples `(color, marker)`. If you use this option, `colors`
            and `markers` are ignored.
        `max_sera_per_subplot` (int)
            Maximum number of sera to show on any subplot.
        `multi_virus_subplots` (bool)
            If a virus has more than `max_sera_per_subplot` sera,
            do we make multiple subplots for it or raise an error?
        `all_subplots` (iterable)
            If making multiple subplots for virus, which sera
            do we show on all subplots? These are also shown first.
        `titles` (`None` or list)
            Specify custom titles for each subplot different than
            `viruses`.
        `vlines` (`None` or dict)
            Add vertical lines to plots. Keyed by virus name, values
            are lists of dicts with a key 'x' giving x-location of vertical
            line, and optional keys 'linewidth', 'color', and 'linestyle'.
        `**kwargs`
            Other keyword arguments that can be passed to
            :meth:`CurveFits.plotGrid`. If `orderlegend` is given here it
            is used as-is rather than the order derived from the sera.

    Returns:
        The 2-tuple `(fig, axes)` of matplotlib figure and 2D axes array.
    """
    sera, viruses = self._sera_viruses_lists(sera, viruses)
    viruses = list(collections.OrderedDict.fromkeys(viruses))

    if titles is None:
        titles = viruses
    elif len(viruses) != len(titles):
        raise ValueError(f"`titles`, `viruses` != length:\n{titles}\n{viruses}")

    if max_sera_per_subplot < 1:
        raise ValueError(f"{max_sera_per_subplot=} must be at least 1")

    # get color scheme for sera
    if serum_to_color_marker:
        extra_sera = set(sera) - set(serum_to_color_marker.keys())
        if extra_sera:
            raise ValueError("sera not in `serum_to_color_marker`: " + str(extra_sera))
    elif len(sera) <= min(len(colors), len(markers)):
        # can share scheme among subplots
        ordered_sera = [s for s in sera if s in all_subplots] + [
            s for s in sera if s not in all_subplots
        ]
        serum_to_color_marker = {
            s: (c, m) for (s, c, m) in zip(ordered_sera, colors, markers)
        }
    elif min(len(colors), len(markers)) < max_sera_per_subplot:
        raise ValueError(
            "`max_sera_per_subplot` larger than number of colors or markers"
        )
    else:
        # colors / markers are assigned by position within each subplot
        serum_to_color_marker = None

    def vlines_for(virus_name):
        """Vertical lines requested for this virus, or `None`."""
        if vlines and (virus_name in vlines):
            return vlines[virus_name]
        return None

    # Build a list of plots appropriate for `plotGrid`.
    # Code is complicated because we could have several subplots
    # per virus, and in that case need to share sera in
    # `all_subplots` among subplots.
    virus_sera = {
        v: [s for s in self.sera if v in self.viruses[s]] for v in self.allviruses
    }
    plotlist = []
    vlines_list = []
    for virus, title in zip(viruses, titles):
        if ignore_virus_serum and virus in ignore_virus_serum:
            ignore_serum = ignore_virus_serum[virus]
        else:
            ignore_serum = {}
        curvelist = []
        iserum = 0
        virus_shared_sera = [
            s
            for s in virus_sera[virus]
            if (s in sera) and (s in all_subplots) and (s not in ignore_serum)
        ]
        virus_unshared_sera = [
            s
            for s in virus_sera[virus]
            if (s in sera) and (s not in all_subplots) and (s not in ignore_serum)
        ]
        # shared sera must leave a slot free for an unshared one, if any
        unshared = int(bool(len(virus_unshared_sera)))
        if len(virus_shared_sera) > max_sera_per_subplot - unshared:
            raise ValueError(
                f"{virus=} has too many subplot-shared "
                "sera (in `all_subplots`) relative to "
                "value of `max_sera_per_subplot`:\n"
                f"{virus_shared_sera=} is more than "
                f"{max_sera_per_subplot=} sera."
            )
        shared_curvelist = []
        for serum in virus_shared_sera + virus_unshared_sera:
            if iserum >= max_sera_per_subplot:
                if not multi_virus_subplots:
                    raise ValueError(
                        f"{virus=} has more than "
                        "`max_sera_per_subplot` sera "
                        "and `multi_virus_subplots` is False"
                    )
                # current subplot is full: store it and start the next one
                # with the shared sera repeated
                plotlist.append((title, curvelist))
                vlines_list.append(vlines_for(virus))
                curvelist = list(shared_curvelist)
                iserum = len(curvelist)
                assert iserum < max_sera_per_subplot
            if serum_to_color_marker:
                color, marker = serum_to_color_marker[serum]
            else:
                color = colors[iserum]
                marker = markers[iserum]
            curvelist.append(
                {
                    "serum": serum,
                    "virus": virus,
                    "replicate": "average",
                    "label": serum,
                    "color": color,
                    "marker": marker,
                }
            )
            if serum in virus_shared_sera:
                shared_curvelist.append(curvelist[-1])
            iserum += 1
        if curvelist:
            plotlist.append((title, curvelist))
            vlines_list.append(vlines_for(virus))
    if not plotlist:
        raise ValueError("no curves for these viruses / sera")

    # get number of columns
    if (nrow is not None) and (ncol is not None):
        raise ValueError(f"either {ncol=} or {nrow=} must be `None`")
    elif isinstance(nrow, int) and nrow > 0:
        ncol = math.ceil(len(plotlist) / nrow)
    elif not (isinstance(ncol, int) and ncol > 0):
        raise ValueError(f"{nrow=} or {ncol=} must be integer > 0")

    # convert plotlist to plots dict for `plotGrid`
    plots = {}
    vlines_axkey = {}
    assert len(plotlist) == len(vlines_list)
    for iplot, (plot, ivline) in enumerate(zip(plotlist, vlines_list)):
        irow = iplot // ncol
        icol = iplot % ncol
        plots[(irow, icol)] = plot
        if ivline:
            vlines_axkey[(irow, icol)] = ivline

    # Only fill in `orderlegend` when the caller did not give one; passing it
    # both explicitly and inside `**kwargs` would raise a TypeError.
    if serum_to_color_marker and "orderlegend" not in kwargs:
        kwargs["orderlegend"] = serum_to_color_marker.keys()

    return self.plotGrid(
        plots,
        vlines=vlines_axkey,
        **kwargs,
    )
[docs]
def plotAverages(
    self,
    *,
    color="black",
    marker="o",
    **kwargs,
):
    """Plot grid with one replicate-average curve per serum / virus pair.

    Args:
        `color` (str)
            Color the curves.
        `marker` (str)
            Marker for the curves.
        `**kwargs`
            Other keyword arguments that can be passed to
            :meth:`CurveFits.plotReplicates`.

    Returns:
        The 2-tuple `(fig, axes)` of matplotlib figure and 2D axes array.
    """
    # a single color / marker, as each subplot only shows the average
    average_style = {"average_only": True, "colors": [color], "markers": [marker]}
    return self.plotReplicates(**average_style, **kwargs)
[docs]
def plotReplicates(
    self,
    *,
    ncol=4,
    nrow=None,
    sera="all",
    viruses="all",
    colors=CBPALETTE,
    markers=CBMARKERS,
    subplot_titles="{serum} vs {virus}",
    show_average=False,
    average_only=False,
    attempt_shared_legend=True,
    **kwargs,
):
    """Plot grid with replicates for each serum / virus on same plot.

    Args:
        `ncol`, `nrow` (int or `None`)
            Specify one of these to set number of columns or rows.
        `sera` ('all' or list)
            Sera to include on plot, in this order.
        `viruses` ('all' or list)
            Viruses to include on plot, in this order.
        `colors` (iterable)
            List of colors for different replicates.
        `markers` (iterable)
            List of markers for different replicates.
        `subplot_titles` (str)
            Format string to build subplot titles from *serum* and *virus*.
        `show_average` (bool)
            Include the replicate-average as a "replicate" in plots.
        `average_only` (bool)
            Show **only** the replicate-average on each plot. No
            legend in this case.
        `attempt_shared_legend` (bool)
            Do we attempt to share the same replicate key for all panels or
            give each its own?
        `**kwargs`
            Other keyword arguments that can be passed to
            :meth:`CurveFits.plotGrid`.

    Returns:
        The 2-tuple `(fig, axes)` of matplotlib figure and 2D axes array.

    Raises:
        ValueError
            If `subplot_titles` uses fields other than serum and virus, there
            are no curves to plot, or `ncol` / `nrow` are invalid.
    """
    try:
        subplot_titles.format(virus="dummy", serum="dummy")
    except (KeyError, IndexError) as err:
        # KeyError for an unknown named field, IndexError for a positional one
        raise ValueError(
            f"{subplot_titles=} invalid. Should have keys only for virus and serum"
        ) from err

    sera, viruses = self._sera_viruses_lists(sera, viruses)
    serum_virus_pairs = [
        (serum, virus)
        for serum, virus in itertools.product(sera, viruses)
        if virus in self.viruses[serum]
    ]

    def replicates_to_show(serum, virus):
        """Replicates of this serum / virus, 'average' first if requested."""
        shown = ["average"] if show_average else []
        shown.extend(r for r in self.replicates[(serum, virus)] if r != "average")
        return shown

    # Replicate list shared by all panels, or `None` if each panel gets its
    # own. `average_only` always shares the single 'average' entry, also when
    # `attempt_shared_legend` is `False`.
    if average_only:
        shared_replicates = ["average"]
    elif attempt_shared_legend:
        # union over all panels in order of first occurrence, so a replicate
        # gets the same color / marker on every panel
        candidates = ["average"] if show_average else []
        for serum, virus in serum_virus_pairs:
            candidates.extend(
                r for r in self.replicates[(serum, virus)] if r != "average"
            )
        shared_replicates = list(dict.fromkeys(candidates))
    else:
        shared_replicates = None

    # build list of plots appropriate for `plotGrid`
    plotlist = []
    for serum, virus in serum_virus_pairs:
        title = subplot_titles.format(serum=serum, virus=virus)
        if shared_replicates is not None:
            rep_list = shared_replicates
        else:
            rep_list = replicates_to_show(serum, virus)
        assert len(rep_list)
        curvelist = [
            {
                "serum": serum,
                "virus": virus,
                "replicate": replicate,
                "label": None if average_only else replicate,
                "color": colors[i % len(colors)],
                "marker": markers[i % len(markers)],
            }
            for i, replicate in enumerate(rep_list)
            if replicate in self.replicates[(serum, virus)]
        ]
        if curvelist:
            plotlist.append((title, curvelist))
    if not plotlist:
        raise ValueError("no curves for these sera / viruses")

    # get number of columns
    if (nrow is not None) and (ncol is not None):
        raise ValueError("either `ncol` or `nrow` must be `None`")
    elif isinstance(nrow, int) and nrow > 0:
        ncol = math.ceil(len(plotlist) / nrow)
    elif not (isinstance(ncol, int) and ncol > 0):
        raise ValueError(f"{nrow=} or {ncol=} must be integer > 0")

    # convert plotlist to plots dict for `plotGrid`, filling the grid row by row
    plots = {
        (iplot // ncol, iplot % ncol): plot for iplot, plot in enumerate(plotlist)
    }

    return self.plotGrid(
        plots,
        attempt_shared_legend=attempt_shared_legend,
        **kwargs,
    )
def _sera_viruses_lists(self, sera, viruses):
    """Validate the requested `sera` / `viruses` and expand 'all'.

    Args:
        `sera` ('all' or list)
            Sera to use. The string 'all' selects every serum in
            `self.sera`; otherwise each entry must be in `self.sera`.
        `viruses` ('all' or list)
            Viruses to use. The string 'all' selects every virus listed in
            `self.viruses` for any of the chosen sera (unique, in order of
            first appearance); otherwise each entry must be one of those.

    Returns:
        The 2-tuple `(sera, viruses)` of checked lists.

    Raises:
        ValueError
            If `sera` has entries not in `self.sera`, or `viruses` has
            entries not available for any of the chosen sera.
    """
    want_all_sera = isinstance(sera, str) and sera == "all"
    if not want_all_sera:
        unknown_sera = set(sera).difference(self.sera)
        if unknown_sera:
            raise ValueError(f"unrecognized sera: {unknown_sera}")
    else:
        sera = self.sera

    # Unique viruses over the chosen sera; dict keys keep first-seen order.
    available = list(
        dict.fromkeys(
            virus for serum in sera for virus in self.viruses[serum]
        )
    )

    want_all_viruses = isinstance(viruses, str) and viruses == "all"
    if want_all_viruses:
        return sera, available

    unknown_viruses = set(viruses).difference(available)
    if unknown_viruses:
        raise ValueError(
            f"unrecognized viruses for specified sera: {unknown_viruses}"
        )
    return sera, viruses
[docs]
def plotGrid(
    self,
    plots,
    *,
    xlabel=None,
    ylabel=None,
    widthscale=1,
    heightscale=1,
    attempt_shared_legend=True,
    fix_lims=None,
    bound_ymin=0,
    bound_ymax=1,
    extend_lim=0.07,
    markersize=6,
    linewidth=1,
    linestyle="-",
    legendtitle=None,
    orderlegend=None,
    titlesize=14,
    labelsize=15,
    ticksize=12,
    legendfontsize=12,
    align_to_dmslogo_facet=False,
    despine=False,
    yticklocs=None,
    sharex=True,
    sharey=True,
    vlines=None,
    draw_in_bounds=False,
):
    """Plot arbitrary grid of curves.

    Args:
        `plots` (dict)
            Plots to draw on grid. Keyed by 2-tuples `(irow, icol)`, which
            give row and column (0, 1, ... numbering) where plot should be
            drawn. Values are the 2-tuples `(title, curvelist)` where
            `title` is title for this plot (or `None`) and `curvelist`
            is a non-empty list of dicts keyed by:

              - 'serum'
              - 'virus'
              - 'replicate'
              - 'label': label for this curve in legend, or `None`
              - 'color'
              - 'marker': https://matplotlib.org/api/markers_api.html

            Note that each of these dicts is modified in place: the curve
            returned by :meth:`CurveFits.getCurve` is stored under the
            key 'curve'.
        `xlabel`, `ylabel` (`None`, str, or list)
            Labels for x- and y-axes. If `None`, use `conc_col`
            and `fracinf_col`, respectively. If str, use this as a single
            label shared by all axes. If list, should be same length as
            `plots` and gives the axis label for each subplot, in the
            iteration order of `plots`.
        `widthscale`, `heightscale` (float)
            Scale width or height of figure by this much. `heightscale`
            is not used if `align_to_dmslogo_facet` is set.
        `attempt_shared_legend` (bool)
            Share a single legend among plots if they all share
            in common the same label assigned to the same color / marker.
        `fix_lims` (dict or `None`)
            To fix axis limits, specify any of 'xmin', 'xmax', 'ymin',
            or 'ymax' with specified limit.
        `bound_ymin`, `bound_ymax` (float or `None`)
            Make y-axis min and max at least this small / large.
            Ignored if using `fix_lims` for that axis limit.
        `extend_lim` (float)
            For all axis limits not in `fix_lims`, extend this fraction
            of range above and below bounds / data limits.
        `markersize` (float)
            Size of point marker.
        `linewidth` (float)
            Width of line.
        `linestyle` (str)
            Line style.
        `legendtitle` (str or `None`)
            Title of legend.
        `orderlegend` (`None` or list)
            If specified, place legend labels in this order.
        `titlesize` (float)
            Size of subplot title font.
        `labelsize` (float)
            Size of axis label font.
        `ticksize` (float)
            Size of axis tick fonts.
        `legendfontsize` (float)
            Size of legend fonts.
        `align_to_dmslogo_facet` (`False` or dict)
            Make plot vertically alignable to ``dmslogo.facet_plot``
            with same number of rows; dict should have keys for
            `height_per_ax`, `hspace`, `tmargin`, and `bmargin` with
            same meaning as ``dmslogo.facet_plot``. Also
            `right` and `left` for passing to ``subplots_adjust``.
        `despine` (bool)
            Remove top and right spines from plots.
        `yticklocs` (`None` or list)
            Same meaning as for :meth:`neutcurve.hillcurve.HillCurve.plot`.
        `sharex` (bool)
            Share x-axis scale among plots.
        `sharey` (bool)
            Share y-axis scale among plots.
        `vlines` (dict or `None`)
            Vertical lines to draw. Keyed by 2-tuples `(irow, icol)`, which
            give row and column of plot in grid (0, 1, ... numbering).
            Values are lists of dicts with a key 'x' giving the x-location
            of the vertical line, and optionally keys 'linewidth',
            'color', and 'linestyle'.
        `draw_in_bounds` (bool)
            Same meaning as for :meth:`neutcurve.hillcurve.HillCurve.plot`.

    Returns:
        The 2-tuple `(fig, axes)` of matplotlib figure and 2D axes array.

    Raises:
        ValueError
            If `plots` is empty, has a negative row / column index, or has
            a subplot with no curves; if the x-axis minimum is <= 0; if an
            axis has no positive extent; or if a legend label is missing
            from `orderlegend`.
    """
    vline_defaults = {
        "linewidth": 1.5,
        "color": "gray",
        "linestyle": "--",
    }

    if not plots:
        raise ValueError("empty `plots`")

    # get number of rows / cols, curves, and data limits
    nrows = ncols = 0
    if fix_lims is None:
        fix_lims = {}
    lims = {key: {} for key in plots}
    for (irow, icol), (_title, curvelist) in plots.items():
        if irow < 0:
            raise ValueError("invalid row index `irow` < 0")
        if icol < 0:
            raise ValueError("invalid column index `icol` < 0")
        if not curvelist:
            # without curves there are no data limits for this subplot, so
            # fail here rather than with a `KeyError` when limits are used
            raise ValueError(f"no curves for plot at row {irow}, column {icol}")
        nrows = max(nrows, irow + 1)
        ncols = max(ncols, icol + 1)
        plotlims = lims[(irow, icol)]
        for curvedict in curvelist:
            curve = self.getCurve(
                serum=curvedict["serum"],
                virus=curvedict["virus"],
                replicate=curvedict["replicate"],
            )
            # stash the curve so the drawing loop below can reuse it
            curvedict["curve"] = curve
            for lim, attr, f in [
                ("xmin", "cs", min),
                ("xmax", "cs", max),
                ("ymin", "fs", min),
                ("ymax", "fs", max),
            ]:
                if lim in fix_lims:
                    plotlims[lim] = fix_lims[lim]
                    continue
                # running min / max over all curves in this subplot
                val = f(getattr(curve, attr))
                if lim in plotlims:
                    val = f(val, plotlims[lim])
                if lim == "ymin" and (bound_ymin is not None):
                    val = min(val, bound_ymin)
                elif lim == "ymax" and (bound_ymax is not None):
                    val = max(val, bound_ymax)
                plotlims[lim] = val

    # shared axes get the most extreme limit found over all subplots
    for share, axtype in [(sharex, "x"), (sharey, "y")]:
        if share:
            for limtype, limfunc in [("min", min), ("max", max)]:
                lim = limfunc(lims[key][axtype + limtype] for key in lims)
                for key in lims:
                    lims[key][axtype + limtype] = lim

    # check and then extend limits
    for key in plots:
        if lims[key]["xmin"] <= 0:
            raise ValueError("xmin <= 0, which is not allowed")
        yextent = lims[key]["ymax"] - lims[key]["ymin"]
        if yextent <= 0:
            raise ValueError("no positive extent for y-axis")
        if "ymin" not in fix_lims:
            lims[key]["ymin"] -= yextent * extend_lim
        if "ymax" not in fix_lims:
            lims[key]["ymax"] += yextent * extend_lim
        # the x-axis extent is measured, and extended, in log space
        xextent = math.log(lims[key]["xmax"]) - math.log(lims[key]["xmin"])
        if xextent <= 0:
            raise ValueError("no positive extent for x-axis")
        if "xmin" not in fix_lims:
            lims[key]["xmin"] = math.exp(
                math.log(lims[key]["xmin"]) - xextent * extend_lim
            )
        if "xmax" not in fix_lims:
            lims[key]["xmax"] = math.exp(
                math.log(lims[key]["xmax"]) + xextent * extend_lim
            )

    if align_to_dmslogo_facet:
        # imported here so `dmslogo` is only needed when this option is used
        import dmslogo.facet

        hparams = dmslogo.facet.height_params(
            nrows,
            align_to_dmslogo_facet["height_per_ax"],
            align_to_dmslogo_facet["hspace"],
            align_to_dmslogo_facet["tmargin"],
            align_to_dmslogo_facet["bmargin"],
        )
        height = hparams["height"]
    else:
        height = (1 + 2.25 * nrows) * heightscale
    width = (1 + 3 * ncols) * widthscale

    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        sharex=sharex,
        sharey=sharey,
        squeeze=False,
        figsize=(width, height),
    )

    # set axis limits on every subplot that is drawn
    for irow, icol in plots:
        axes[irow, icol].set_xlim(
            lims[irow, icol]["xmin"], lims[irow, icol]["xmax"]
        )
        axes[irow, icol].set_ylim(
            lims[irow, icol]["ymin"], lims[irow, icol]["ymax"]
        )

    # make plots
    shared_legend = attempt_shared_legend
    kwargs_tup_to_label = {}  # used to determine if shared legend
    legend_handles = collections.defaultdict(list)
    shared_legend_handles = []  # handles if using shared legend
    for i, ((irow, icol), (title, curvelist)) in enumerate(plots.items()):
        ax = axes[irow, icol]
        ax.set_title(title, fontsize=titlesize)
        # per-subplot labels are only used when given as lists; otherwise a
        # single shared label is drawn on the figure further below
        ixlabel = xlabel[i] if isinstance(xlabel, list) else None
        iylabel = ylabel[i] if isinstance(ylabel, list) else None
        for curvedict in curvelist:
            kwargs = {
                "color": curvedict["color"],
                "marker": curvedict["marker"],
                "linestyle": linestyle,
                "linewidth": linewidth,
                "markersize": markersize,
            }
            curvedict["curve"].plot(
                ax=ax,
                xlabel=ixlabel,
                ylabel=iylabel,
                yticklocs=yticklocs,
                draw_in_bounds=draw_in_bounds,
                **kwargs,
            )
            label = curvedict["label"]
            if label:
                # empty proxy line that only serves as a legend entry
                handle = Line2D(
                    xdata=[],
                    ydata=[],
                    label=label,
                    **kwargs,
                )
                legend_handles[(irow, icol)].append(handle)
                if shared_legend:
                    # a shared legend is abandoned as soon as one style is
                    # found attached to two different labels
                    kwargs_tup = tuple(sorted(kwargs.items()))
                    if kwargs_tup in kwargs_tup_to_label:
                        if kwargs_tup_to_label[kwargs_tup] != label:
                            shared_legend = False
                    else:
                        kwargs_tup_to_label[kwargs_tup] = label
                        shared_legend_handles.append(handle)
        ax.tick_params(
            "both",
            labelsize=ticksize,
            bottom=True,
            left=True,
            right=False,
            top=False,
        )
        if despine:
            import dmslogo.utils

            dmslogo.utils.despine(ax=ax)
        if vlines and ((irow, icol) in vlines):
            for vline in vlines[(irow, icol)]:
                # user-specified keys override the defaults
                vline_d = {**vline_defaults, **vline}
                ax.axvline(
                    vline_d["x"],
                    linestyle=vline_d["linestyle"],
                    linewidth=vline_d["linewidth"],
                    color=vline_d["color"],
                )

    # draw legend(s)
    legend_kwargs = {
        "fontsize": legendfontsize,
        "numpoints": 1,
        "markerscale": 1,
        "handlelength": 1,
        "labelspacing": 0.1,
        "handletextpad": 0.4,
        "frameon": True,
        "borderaxespad": 0.1,
        "borderpad": 0.2,
        "title": legendtitle,
        "title_fontsize": legendfontsize + 1,
        "framealpha": 0.6,
    }

    def _ordered_legend(hs):
        """Get legend handles `hs` ordered to follow `orderlegend`."""
        if not orderlegend:
            return hs
        order_dict = {h: i for i, h in enumerate(orderlegend)}
        h_labels = [h.get_label() for h in hs]
        extra_hs = set(h_labels) - set(orderlegend)
        if extra_hs:
            raise ValueError(
                f"there are legend handles not in `orderlegend`: {extra_hs}"
            )
        return [
            h
            for _, h in sorted(
                zip(h_labels, hs), key=lambda x: order_dict[x[0]]
            )
        ]

    if shared_legend and shared_legend_handles:
        if align_to_dmslogo_facet:
            right = align_to_dmslogo_facet["right"]
            ranchor = right + 0.15 * (1 - right)
        else:
            ranchor = 1
        shared_legend_handles = _ordered_legend(shared_legend_handles)
        # shared legend as here: https://stackoverflow.com/a/17328230
        fig.legend(
            handles=shared_legend_handles,
            labels=[h.get_label() for h in shared_legend_handles],
            loc="center left",
            bbox_to_anchor=(ranchor, 0.5),
            bbox_transform=fig.transFigure,
            **legend_kwargs,
        )
    elif legend_handles:
        for (irow, icol), handles in legend_handles.items():
            ax = axes[irow, icol]
            handles = _ordered_legend(handles)
            ax.legend(
                handles=handles,
                labels=[h.get_label() for h in handles],
                loc="lower left",
                **legend_kwargs,
            )

    # hide unused axes
    for irow, icol in itertools.product(range(nrows), range(ncols)):
        if (irow, icol) not in plots:
            axes[irow, icol].set_axis_off()

    # common axis labels as here: https://stackoverflow.com/a/53172335
    bigax = fig.add_subplot(111, frameon=False)
    bigax.grid(False)
    bigax.tick_params(
        labelcolor="none",
        top=False,
        bottom=False,
        left=False,
        right=False,
        which="both",
    )
    if xlabel is None:
        bigax.set_xlabel(self.conc_col, fontsize=labelsize, labelpad=10)
    elif not isinstance(xlabel, list):
        bigax.set_xlabel(xlabel, fontsize=labelsize, labelpad=10)
    if ylabel is None:
        bigax.set_ylabel(self.fracinf_col, fontsize=labelsize, labelpad=10)
    elif not isinstance(ylabel, list):
        bigax.set_ylabel(ylabel, fontsize=labelsize, labelpad=10)

    if align_to_dmslogo_facet:
        fig.subplots_adjust(
            hspace=hparams["hspace"],
            top=hparams["top"],
            bottom=hparams["bottom"],
            left=align_to_dmslogo_facet["left"],
            right=align_to_dmslogo_facet["right"],
        )
    else:
        fig.tight_layout()

    return fig, axes
# Executing this module as a script runs the doctests in its docstrings.
if __name__ == "__main__":
    import doctest

    doctest.testmod()