Averaging models
Probably the best way to ensure robust inferences and estimate errors is to have multiple experimental replicates, ideally on different libraries.
Here we describe how to average model fits across libraries and/or replicates.
Split data into replicates
We will use the RBD data from the earlier examples, but split it into several libraries / replicates.
Specifically, we will fit two different libraries, avg2muts and avg3muts, which have different barcodes and also different mutation rates (although in real life you might sometimes want to average results from different libraries with the same mutation rates). We will also simulate having two replicates for each library by randomly subsampling 30% of the data for each library. To make this example faster, we’ll use just one concentration:
import pandas as pd
import polyclonal.polyclonal
import polyclonal.polyclonal_collection
# read data, keeping empty strings as-is rather than converting them to NaN
all_data = pd.read_csv("RBD_variants_escape_noisy.csv", na_filter=False)
# split by library and replicates
libraries = ["avg2muts", "avg3muts"]  # the two libraries to use
concentrations = [1]  # use just this concentration
n_replicates = 2  # number of replicates per library
# simulate each replicate by randomly sampling 30% of a library's rows,
# using a different random seed for each replicate
data_by_replicate = {
    (library, replicate + 1): (
        all_data.query("library == @library")
        .query("concentration in @concentrations")
        .sample(frac=0.3, random_state=replicate)
    )
    for library in libraries
    for replicate in range(n_replicates)
}
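The replicate-splitting step above relies on pandas `DataFrame.sample` with a fixed `random_state` per replicate. The minimal sketch below uses a small made-up table (the `toy` data frame and its values are hypothetical, not the RBD data) to show that each simulated replicate holds 30% of its library's rows, and that the replicates of one library are drawn independently with different seeds:

```python
import pandas as pd

# hypothetical toy data: 20 variants in each of two libraries
toy = pd.DataFrame(
    {
        "library": ["avg2muts"] * 20 + ["avg3muts"] * 20,
        "variant": list(range(40)),
    }
)

libraries = ["avg2muts", "avg3muts"]
n_replicates = 2

# same dictionary-comprehension pattern as the notebook cell
toy_by_replicate = {
    (library, replicate + 1): (
        toy.query("library == @library").sample(frac=0.3, random_state=replicate)
    )
    for library in libraries
    for replicate in range(n_replicates)
}

for key, df in toy_by_replicate.items():
    print(key, len(df))  # each subsample has round(0.3 * 20) rows

# replicates of one library are sampled independently, so they may share rows
rep1 = set(toy_by_replicate[("avg2muts", 1)]["variant"])
rep2 = set(toy_by_replicate[("avg2muts", 2)]["variant"])
print("shared variants:", len(rep1 & rep2))
```

Because the two replicates of a library are independent subsamples of the same data, they are not truly independent experiments; this is only a convenient way to simulate replicates for the example.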
Fit models to each replicate
We now fit a Polyclonal model to each replicate using just two epitopes, since the data do not seem sufficient to accurately fit all three epitopes. We then arrange the models in a data frame:
# first create the (unfitted) models and arrange them in a data frame
models_by_replicate = {}
for (library, replicate), data in data_by_replicate.items():
    model = polyclonal.Polyclonal(data_to_fit=data, n_epitopes=2)
    models_by_replicate[(library, replicate)] = model
models_df = (
    pd.Series(models_by_replicate, name="model")
    .rename_axis(["library", "replicate"])
    .reset_index()
)
# now fit the models
n_fit, n_failed, models_df["model"] = polyclonal.polyclonal_collection.fit_models(
    models_df["model"],
    n_threads=2,
    reg_escape_weight=0.01,
    reg_uniqueness2_weight=0,
)
Note how the models are arranged in a data frame:
# NBVAL_IGNORE_OUTPUT
models_df
| | library | replicate | model |
---|---|---|---|
0 | avg2muts | 1 | <polyclonal.polyclonal.Polyclonal object at 0x... |
1 | avg2muts | 2 | <polyclonal.polyclonal.Polyclonal object at 0x... |
2 | avg3muts | 1 | <polyclonal.polyclonal.Polyclonal object at 0x... |
3 | avg3muts | 2 | <polyclonal.polyclonal.Polyclonal object at 0x... |
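The `pd.Series(...).rename_axis(...).reset_index()` idiom used to build `models_df` turns a dictionary keyed by `(library, replicate)` tuples into a data frame with one row per model. This minimal sketch uses placeholder strings instead of fitted `Polyclonal` objects (purely an illustrative assumption) to show the resulting layout:

```python
import pandas as pd

# placeholder "models" keyed by (library, replicate), as in the notebook
models_by_replicate = {
    ("avg2muts", 1): "model A",
    ("avg2muts", 2): "model B",
    ("avg3muts", 1): "model C",
    ("avg3muts", 2): "model D",
}

models_df = (
    pd.Series(models_by_replicate, name="model")  # tuple keys become a MultiIndex
    .rename_axis(["library", "replicate"])  # name the two index levels
    .reset_index()  # move the index levels into ordinary columns
)
print(models_df)
```

Keeping the descriptors (`library`, `replicate`) as ordinary columns is what later lets the per-replicate results be labeled by library and replicate.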
Average the models
Now we create a PolyclonalAverage model with the models to average. Note that by default the “average” used by PolyclonalAverage is the median across models rather than the mean, although this can also be set to the mean.
If your epitopes are too different or poorly defined (e.g., you are trying to fit more epitopes than can be consistently inferred from the data), then you may get an epitope harmonization error at this step:
model_avg = polyclonal.PolyclonalAverage(models_df)
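To see why the median can be a more robust default than the mean, consider one mutation's escape values at one epitope across four replicate models, where one model is an outlier. The numbers below are made up for illustration and are not from this dataset:

```python
import statistics

# hypothetical escape values for one mutation at one epitope in four models;
# the last model is an outlier
escape_by_model = [0.10, 0.12, 0.08, 1.50]

mean = statistics.mean(escape_by_model)
median = statistics.median(escape_by_model)
print(f"mean={mean:.2f}, median={median:.2f}")
```

The mean is pulled strongly toward the single outlying model, whereas the median stays close to the value shared by most models.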
Let’s look at the correlation of the escape values at each epitope across models:
# NBVAL_IGNORE_OUTPUT
model_avg.mut_escape_corr_heatmap()
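The heatmap summarizes, for each epitope, how well the per-mutation escape values agree between pairs of models. A minimal pandas sketch of that kind of computation (Pearson correlation on a wide table with one column per model) follows; the toy values are invented, and it is an assumption that `mut_escape_corr_heatmap` uses exactly this correlation measure:

```python
import pandas as pd

# hypothetical per-replicate escape values in the same long layout as
# `mut_escape_df_replicates` (a few mutations, one epitope)
long_df = pd.DataFrame(
    {
        "epitope": [1] * 8,
        "mutation": ["N331A", "N331D", "N331E", "N331F"] * 2,
        "replicate": [1] * 4 + [2] * 4,
        "escape": [0.1, 0.5, 0.0, 1.2, 0.2, 0.4, 0.1, 1.0],
    }
)

corr_by_epitope = {}
for epitope, df in long_df.groupby("epitope"):
    # one row per mutation, one column per replicate model
    wide = df.pivot(index="mutation", columns="replicate", values="escape")
    corr_by_epitope[epitope] = wide.corr()  # pairwise Pearson correlations

print(corr_by_epitope[1])
```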
Next, look at the epitope activities and the rest of the curves. The average is shown as a dark line, and the individual replicates as thin lines. In general, the epitope with greater activity (shifted further left in the curves plot) should also be better correlated among replicates (in the correlation heatmap), since it can be inferred more reliably:
# NBVAL_IGNORE_OUTPUT
model_avg.curves_plot()
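Why does greater activity shift a curve to the left? As a simplified sketch, assume the fraction of virions not neutralized by a single epitope at antibody concentration c is 1 / (1 + c * exp(a)), where a is the epitope activity. This single-epitope form without a Hill coefficient is an assumption for illustration and may not match every detail of the Polyclonal model. Under it, the concentration giving 50% neutralization is exp(-a), so each unit increase in activity lowers that concentration by a factor of e:

```python
import math


def frac_unbound(c: float, activity: float) -> float:
    """Assumed single-epitope curve: fraction not neutralized at concentration c."""
    return 1 / (1 + c * math.exp(activity))


def half_max_conc(activity: float) -> float:
    """Concentration where frac_unbound == 0.5 under the assumed curve."""
    return math.exp(-activity)


for activity in [1.0, 2.0, 3.0]:
    c50 = half_max_conc(activity)
    print(
        f"activity={activity}: 50% point at c={c50:.4f}, "
        f"check={frac_unbound(c50, activity):.2f}"
    )
```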
We can access the averaged escape values. Note that there is also an escape_min_magnitude column, which gives the escape value with the lowest magnitude (smallest absolute value) across models for each mutation:
# NBVAL_IGNORE_OUTPUT
model_avg.mut_escape_df
| | epitope | site | wildtype | mutant | mutation | escape_mean | escape_median | escape_min_magnitude | escape_std | n_models | times_seen | frac_models |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 331 | N | A | N331A | -0.070506 | -0.029126 | 0.025048 | 0.330401 | 4 | 17.75 | 1.0 |
1 | 1 | 331 | N | D | N331D | -0.106079 | -0.106459 | -0.074615 | 0.025421 | 4 | 11.25 | 1.0 |
2 | 1 | 331 | N | E | N331E | -0.075338 | -0.008291 | 0.002352 | 0.141307 | 4 | 10.25 | 1.0 |
3 | 1 | 331 | N | F | N331F | 0.278933 | 0.173117 | 0.030457 | 0.334194 | 4 | 10.00 | 1.0 |
4 | 1 | 331 | N | G | N331G | 0.187362 | 0.161299 | -0.039176 | 0.228411 | 4 | 25.00 | 1.0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
3859 | 2 | 531 | T | R | T531R | -0.180296 | -0.114042 | -0.041709 | 0.187469 | 4 | 27.00 | 1.0 |
3860 | 2 | 531 | T | S | T531S | -0.077768 | -0.020680 | 0.012301 | 0.137118 | 4 | 31.75 | 1.0 |
3861 | 2 | 531 | T | V | T531V | -0.037153 | -0.009067 | 0.011169 | 0.095150 | 4 | 19.50 | 1.0 |
3862 | 2 | 531 | T | W | T531W | 0.171903 | 0.101605 | -0.004711 | 0.276231 | 4 | 5.25 | 1.0 |
3863 | 2 | 531 | T | Y | T531Y | -0.058269 | -0.055309 | -0.038848 | 0.021822 | 4 | 11.75 | 1.0 |
3864 rows × 12 columns
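To make the summary columns concrete, the sketch below computes a mean, median, lowest-magnitude value, standard deviation, and number of models for each mutation from a long table of per-model values using pandas named aggregation. The toy numbers are invented, and details such as the standard deviation using ddof=1 (pandas' default, used here) are assumptions rather than a description of how PolyclonalAverage computes these columns:

```python
import pandas as pd

# hypothetical escape values for two mutations in four models
per_model = pd.DataFrame(
    {
        "epitope": [1] * 8,
        "mutation": ["N331A"] * 4 + ["N331D"] * 4,
        "escape": [-0.08, 0.03, -0.02, 0.60, -0.10, -0.11, -0.07, -0.13],
    }
)


def min_magnitude(values: pd.Series) -> float:
    """Value with the smallest absolute value, keeping its sign."""
    return values.loc[values.abs().idxmin()]


summary = (
    per_model.groupby(["epitope", "mutation"])["escape"]
    .agg(
        escape_mean="mean",
        escape_median="median",
        escape_min_magnitude=min_magnitude,
        escape_std="std",  # sample standard deviation (ddof=1)
        n_models="count",
    )
    .reset_index()
)
print(summary)
```

Note that the lowest-magnitude value can have the opposite sign from the mean, as for N331A here and in the table above.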
We can do the same for the effects of mutations on the ICXX, here the IC90 (x=0.9):
# NBVAL_IGNORE_OUTPUT
model_avg.mut_icXX_df(
    x=0.9,
    icXX_col="IC90",
    log_fold_change_icXX_col="log2 fold change IC90",
)
| | site | wildtype | mutant | log2 fold change IC90 mean | log2 fold change IC90 median | log2 fold change IC90 min_magnitude | log2 fold change IC90 std | n_models | times_seen | frac_models |
---|---|---|---|---|---|---|---|---|---|---|
0 | 331 | N | A | -0.158794 | -0.192897 | -0.115146 | 0.270364 | 4 | 17.75 | 1.0 |
1 | 331 | N | D | -0.096712 | -0.058956 | -0.016521 | 0.106236 | 4 | 11.25 | 1.0 |
2 | 331 | N | E | 0.036998 | 0.018060 | 0.017496 | 0.073099 | 4 | 10.25 | 1.0 |
3 | 331 | N | F | 0.104497 | 0.142641 | 0.135784 | 0.288862 | 4 | 10.00 | 1.0 |
4 | 331 | N | G | 0.286305 | 0.279946 | 0.103283 | 0.160085 | 4 | 25.00 | 1.0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
2100 | 531 | T | S | -0.072452 | -0.022248 | -0.012782 | 0.131153 | 4 | 31.75 | 1.0 |
2101 | 531 | T | T | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 4 | NaN | 1.0 |
2102 | 531 | T | V | 0.011190 | -0.008491 | 0.035284 | 0.230686 | 4 | 19.50 | 1.0 |
2103 | 531 | T | W | 0.198453 | 0.121785 | -0.009031 | 0.290432 | 4 | 5.25 | 1.0 |
2104 | 531 | T | Y | -0.022972 | -0.054006 | -0.049928 | 0.135754 | 4 | 11.75 | 1.0 |
2105 rows × 10 columns
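As a rough sketch of what an ICXX fold change means, assume (as a simplification of the Polyclonal model, with Hill coefficients fixed at 1) that the fraction of virions not neutralized at concentration c is the product over epitopes of 1 / (1 + c * exp(a_e - beta_e)), where a_e is the activity of epitope e and beta_e is the variant's summed escape at that epitope. The ICXX is then the concentration at which this fraction drops to 1 - x. The hypothetical helpers `frac_escape` and `icxx` below find it by bisection on log concentration, and the log2 fold change compares mutant to wildtype. All parameter values are invented:

```python
import math


# hypothetical helpers under the assumed model; not part of the polyclonal package
def frac_escape(c, activities, escapes):
    """Assumed fraction of virions not neutralized at concentration c."""
    p = 1.0
    for a, beta in zip(activities, escapes):
        p *= 1 / (1 + c * math.exp(a - beta))
    return p


def icxx(x, activities, escapes, lo=1e-8, hi=1e8, n_iter=200):
    """Concentration where the neutralized fraction equals x (bisection)."""
    target = 1 - x  # fraction not neutralized at the ICXX
    log_lo, log_hi = math.log(lo), math.log(hi)
    for _ in range(n_iter):
        mid = (log_lo + log_hi) / 2
        # frac_escape decreases with concentration, so move toward the target
        if frac_escape(math.exp(mid), activities, escapes) > target:
            log_lo = mid
        else:
            log_hi = mid
    return math.exp((log_lo + log_hi) / 2)


activities = [3.0, 1.0]  # hypothetical activities of epitopes 1 and 2
wt_ic90 = icxx(0.9, activities, escapes=[0.0, 0.0])
mut_ic90 = icxx(0.9, activities, escapes=[2.0, 0.0])  # escapes epitope 1
print(f"log2 fold change IC90 = {math.log2(mut_ic90 / wt_ic90):.3f}")
```

Under this sketch a wildtype "mutation" has a fold change of 1 and therefore a log fold change of 0, consistent with the T531T row above.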
We can also get the escape values for each individual replicate model:
# NBVAL_IGNORE_OUTPUT
model_avg.mut_escape_df_replicates
| | epitope | site | wildtype | mutant | mutation | escape | times_seen | library | replicate |
---|---|---|---|---|---|---|---|---|---|
0 | 1 | 331 | N | A | N331A | -0.083300 | 19 | avg2muts | 1 |
1 | 1 | 331 | N | D | N331D | -0.104757 | 10 | avg2muts | 1 |
2 | 1 | 331 | N | E | N331E | -0.004766 | 11 | avg2muts | 1 |
3 | 1 | 331 | N | F | N331F | 0.033077 | 10 | avg2muts | 1 |
4 | 1 | 331 | N | G | N331G | 0.046883 | 18 | avg2muts | 1 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
15429 | 2 | 531 | T | R | T531R | -0.041709 | 24 | avg3muts | 2 |
15430 | 2 | 531 | T | S | T531S | -0.015296 | 42 | avg3muts | 2 |
15431 | 2 | 531 | T | V | T531V | -0.029302 | 24 | avg3muts | 2 |
15432 | 2 | 531 | T | W | T531W | 0.547600 | 6 | avg3muts | 2 |
15433 | 2 | 531 | T | Y | T531Y | -0.083607 | 9 | avg3muts | 2 |
15434 rows × 9 columns
Now let’s plot the escape. Note that you can filter mutations not only by how many times they are seen (averaged over all models in the average), but also by the number of models in which they are seen:
# NBVAL_IGNORE_OUTPUT
model_avg.mut_escape_plot(addtl_slider_stats={"times_seen": 2})
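The sliders in the plot act as filters on the averaged data. Outside the plot, a similar filter can be applied directly to the averaged escape data frame with pandas. The sketch below uses a tiny invented stand-in for `model_avg.mut_escape_df` and hypothetical thresholds (seen at least twice on average, present in all models):

```python
import pandas as pd

# hypothetical stand-in for `model_avg.mut_escape_df`
mut_escape_df = pd.DataFrame(
    {
        "mutation": ["N331A", "N331D", "T531W", "T531Y"],
        "escape_median": [-0.03, -0.11, 0.10, -0.06],
        "times_seen": [17.75, 1.5, 5.25, 11.75],
        "frac_models": [1.0, 1.0, 0.5, 1.0],
    }
)

min_times_seen = 2  # hypothetical: seen at least this often on average
min_frac_models = 1.0  # hypothetical: present in this fraction of models

filtered = mut_escape_df.query(
    "times_seen >= @min_times_seen and frac_models >= @min_frac_models"
)
print(filtered)
```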
Here is the same plot, but showing the lowest-magnitude escape value across models for each mutation:
# NBVAL_IGNORE_OUTPUT
model_avg.mut_escape_plot(
    addtl_slider_stats={"times_seen": 2}, avg_type="min_magnitude"
)
We can also plot the log fold change in IC90 caused by each mutation:
# NBVAL_IGNORE_OUTPUT
model_avg.mut_icXX_plot(
    addtl_slider_stats={"times_seen": 2},
    avg_type="median",
)
Escape values by region
In some cases, you may want to use the escape values from only a specific region of the protein for each model being averaged. For instance, this may be the case if you covered half the protein in one library and the other half in the other library.
In that case, add a column to the models data frame that lists the sites to keep for each model, and pass the name of that column as region_col when initializing PolyclonalAverage:
from IPython.display import display
# keep all sites for both replicates of the `avg2muts` library, but only
# sites <= 460 for replicate 1 and sites >= 450 for replicate 2 of `avg3muts`
regions_df = pd.DataFrame(
    [
        ("avg2muts", 1, model_avg.sites),
        ("avg2muts", 2, model_avg.sites),
        ("avg3muts", 1, [r for r in model_avg.sites if r <= 460]),
        ("avg3muts", 2, [r for r in model_avg.sites if r >= 450]),
    ],
    columns=["library", "replicate", "sites_to_keep"],
)
# add the sites to keep to the models data frame, matching on library and replicate
models_region_df = models_df.merge(regions_df)
print("Here is the input dataframe specifying sites to keep for each model:")
display(models_region_df)
model_region_avg = polyclonal.PolyclonalAverage(
    models_region_df, region_col="sites_to_keep"
)
Here is the input dataframe specifying sites to keep for each model:
| | library | replicate | model | sites_to_keep |
---|---|---|---|---|
0 | avg2muts | 1 | <polyclonal.polyclonal.Polyclonal object at 0x... | (331, 332, 333, 334, 335, 336, 337, 338, 339, ... |
1 | avg2muts | 2 | <polyclonal.polyclonal.Polyclonal object at 0x... | (331, 332, 333, 334, 335, 336, 337, 338, 339, ... |
2 | avg3muts | 1 | <polyclonal.polyclonal.Polyclonal object at 0x... | [331, 332, 333, 334, 335, 336, 337, 338, 339, ... |
3 | avg3muts | 2 | <polyclonal.polyclonal.Polyclonal object at 0x... | [450, 451, 452, 453, 455, 456, 458, 459, 460, ... |
We can see the number of sites in each region:
for desc, sites in zip(model_region_avg.model_descriptors, model_region_avg.regions):
    print(f"{desc=}, {len(sites)=}, {min(sites)=}, {max(sites)=}")
desc={'library': 'avg2muts', 'replicate': 1}, len(sites)=173, min(sites)=331, max(sites)=531
desc={'library': 'avg2muts', 'replicate': 2}, len(sites)=173, min(sites)=331, max(sites)=531
desc={'library': 'avg3muts', 'replicate': 1}, len(sites)=112, min(sites)=331, max(sites)=460
desc={'library': 'avg3muts', 'replicate': 2}, len(sites)=70, min(sites)=450, max(sites)=531
We can also get the number of models per site. Based on how we initialized, this is 3 for all sites except those between 450 and 460 where it is 4:
assert model_region_avg.sites == model_avg.sites
# use all(...): asserting a bare generator expression always passes
assert all(
    model_region_avg.n_models_by_site[r] == 3 + (450 <= r <= 460)
    for r in model_region_avg.sites
)
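The per-site model counts follow directly from the region definitions, since each model contributes at every site in its region. Here is a standalone sketch with `collections.Counter`, using a hypothetical contiguous list of sites rather than the real `model_avg.sites` (which skips some site numbers):

```python
from collections import Counter

sites = list(range(331, 532))  # hypothetical: every site from 331 to 531

# hypothetical regions mirroring the notebook's `sites_to_keep` column
regions = {
    ("avg2muts", 1): sites,
    ("avg2muts", 2): sites,
    ("avg3muts", 1): [r for r in sites if r <= 460],
    ("avg3muts", 2): [r for r in sites if r >= 450],
}

# count how many model regions include each site
n_models_by_site = Counter(r for region in regions.values() for r in region)

assert all(n_models_by_site[r] == 3 + (450 <= r <= 460) for r in sites)
print(n_models_by_site[331], n_models_by_site[455], n_models_by_site[531])
```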
Now look at the mutation-escape data frame:
# NBVAL_IGNORE_OUTPUT
model_region_avg.mut_escape_df
| | epitope | site | wildtype | mutant | mutation | escape_mean | escape_median | escape_min_magnitude | escape_std | n_models | times_seen | frac_models |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 331 | N | A | N331A | -0.188953 | -0.083300 | 0.025048 | 0.282079 | 3 | 16.666667 | 1.0 |
1 | 1 | 331 | N | D | N331D | -0.095844 | -0.104757 | -0.074615 | 0.018464 | 3 | 10.666667 | 1.0 |
2 | 1 | 331 | N | E | N331E | -0.004743 | -0.004766 | 0.002352 | 0.007084 | 3 | 9.333333 | 1.0 |
3 | 1 | 331 | N | F | N331F | 0.361758 | 0.313157 | 0.033077 | 0.355483 | 3 | 9.000000 | 1.0 |
4 | 1 | 331 | N | G | N331G | 0.094474 | 0.046883 | -0.039176 | 0.162750 | 3 | 25.666667 | 1.0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
3859 | 2 | 531 | T | R | T531R | -0.089931 | -0.069566 | -0.041709 | 0.061009 | 3 | 25.666667 | 1.0 |
3860 | 2 | 531 | T | S | T531S | -0.009686 | -0.015296 | 0.012301 | 0.019788 | 3 | 29.000000 | 1.0 |
3861 | 2 | 531 | T | V | T531V | -0.063681 | -0.029302 | 0.011169 | 0.096735 | 3 | 19.333333 | 1.0 |
3862 | 2 | 531 | T | W | T531W | 0.250270 | 0.207922 | -0.004711 | 0.278580 | 3 | 4.333333 | 1.0 |
3863 | 2 | 531 | T | Y | T531Y | -0.054589 | -0.041310 | -0.038848 | 0.025161 | 3 | 11.333333 | 1.0 |
3864 rows × 12 columns
For the sites covered by all four models (450 to 460), the averaged escape values are the same as for the average without regions:
assert (
    model_avg.mut_escape_df.query("(site >= 450) and (site <= 460)").equals(
        model_region_avg.mut_escape_df.query("(site >= 450) and (site <= 460)")
    )
    is True
)
But they differ at other sites:
assert model_avg.mut_escape_df.equals(model_region_avg.mut_escape_df) is False
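One way to think about region-restricted averaging is to treat each model's escape at sites outside its region as missing, then average only the remaining values. This pandas sketch uses toy values and hypothetical model names, and is not necessarily how PolyclonalAverage implements it; it shows why the averages agree with the unrestricted ones at sites covered by all models and differ elsewhere:

```python
import pandas as pd

# hypothetical escape values: rows are sites, columns are models
escape = pd.DataFrame(
    {
        "m1": [0.1, 0.2, 0.3],
        "m2": [0.2, 0.1, 0.4],
        "m3": [0.9, 0.3, 0.2],
    },
    index=pd.Index([449, 455, 461], name="site"),
)

# hypothetical regions: m3 only covers sites >= 450
regions = {"m1": [449, 455, 461], "m2": [449, 455, 461], "m3": [455, 461]}

# mask each model's values outside its region as NaN
masked = escape.copy()
for model, sites_to_keep in regions.items():
    masked.loc[~masked.index.isin(sites_to_keep), model] = float("nan")

unrestricted = escape.median(axis=1)
restricted = masked.median(axis=1)  # NaN values are skipped by default
print(pd.DataFrame({"unrestricted": unrestricted, "restricted": restricted}))
```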
The same holds for the IC50 values of the mutations:
# NBVAL_IGNORE_OUTPUT
region_ic50_df = model_region_avg.mut_icXX_df(
    x=0.5, icXX_col="IC50", log_fold_change_icXX_col="log_fold_change_IC50"
)
region_ic50_df
| | site | wildtype | mutant | log_fold_change_IC50 mean | log_fold_change_IC50 median | log_fold_change_IC50 min_magnitude | log_fold_change_IC50 std | n_models | times_seen | frac_models |
---|---|---|---|---|---|---|---|---|---|---|
0 | 331 | N | A | -0.278667 | -0.192977 | -0.142477 | 0.193805 | 3 | 16.666667 | 1.0 |
1 | 331 | N | D | -0.110483 | -0.066345 | 0.005980 | 0.143709 | 3 | 10.666667 | 1.0 |
2 | 331 | N | E | 0.003540 | 0.019666 | 0.019666 | 0.032347 | 3 | 9.333333 | 1.0 |
3 | 331 | N | F | 0.191809 | 0.168699 | 0.076206 | 0.128724 | 3 | 9.000000 | 1.0 |
4 | 331 | N | G | 0.333271 | 0.393855 | 0.109961 | 0.200021 | 3 | 25.666667 | 1.0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
2100 | 531 | T | S | -0.009580 | 0.002936 | 0.002936 | 0.022277 | 3 | 29.000000 | 1.0 |
2101 | 531 | T | T | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 3 | NaN | 1.0 |
2102 | 531 | T | V | -0.090553 | -0.048369 | 0.025148 | 0.141588 | 3 | 19.333333 | 1.0 |
2103 | 531 | T | W | 0.309571 | 0.280769 | -0.007863 | 0.332770 | 3 | 4.333333 | 1.0 |
2104 | 531 | T | Y | -0.084084 | -0.059005 | -0.053140 | 0.048605 | 3 | 11.333333 | 1.0 |
2105 rows × 10 columns
ic50_df = model_avg.mut_icXX_df(
    x=0.5, icXX_col="IC50", log_fold_change_icXX_col="log_fold_change_IC50"
)
assert (
    region_ic50_df.query("(site >= 450) and (site <= 460)").equals(
        ic50_df.query("(site >= 450) and (site <= 460)")
    )
    is True
)
assert region_ic50_df.equals(ic50_df) is False
Finally, we can plot the escape values for the region-restricted average:
# NBVAL_IGNORE_OUTPUT
model_region_avg.mut_escape_plot(addtl_slider_stats={"times_seen": 2})