In [1]:
import pandas as pd
import numpy as np
import altair as alt
import theme  # project-local module; presumably registers an Altair theme on import — TODO confirm
from natsort import natsorted, natsort_keygen
# Lift Altair's default 5,000-row embedding limit so full datasets can be plotted.
alt.data_transformers.disable_max_rows()
Out[1]:
DataTransformerRegistry.enable('default')
In [2]:
# Per-mutation effects on cell entry in each HA background (h3/h5/h7 effect and
# effect_std columns), plus structural annotations (pairwise RMSDs and per-structure RSA).
mutation_effects = pd.read_csv('../results/combined_effects/combined_mutation_effects.csv')
mutation_effects.head()
Out[2]:
| mutant | struct_site | h3_wt_aa | h5_wt_aa | h7_wt_aa | rmsd_h3h5 | rmsd_h3h7 | rmsd_h5h7 | 4o5n_aa_RSA | 4kwm_aa_RSA | 6ii9_aa_RSA | h3_effect | h3_effect_std | h5_effect | h5_effect_std | h7_effect | h7_effect_std | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | A | 9 | S | K | NaN | 9.1674 | NaN | NaN | 1.084277 | 1.140252 | NaN | 0.0151 | 0.7225 | 0.2049 | 0.2627 | NaN | NaN |
| 1 | C | 9 | S | K | NaN | 9.1674 | NaN | NaN | 1.084277 | 1.140252 | NaN | -0.4080 | 0.3850 | -0.3977 | 0.1072 | NaN | NaN |
| 2 | D | 9 | S | K | NaN | 9.1674 | NaN | NaN | 1.084277 | 1.140252 | NaN | 0.2361 | 0.2740 | 0.2383 | 0.2087 | NaN | NaN |
| 3 | E | 9 | S | K | NaN | 9.1674 | NaN | NaN | 1.084277 | 1.140252 | NaN | -0.2463 | 0.8478 | 0.3120 | 0.2815 | NaN | NaN |
| 4 | F | 9 | S | K | NaN | 9.1674 | NaN | NaN | 1.084277 | 1.140252 | NaN | 0.2061 | 0.3214 | -0.8917 | 1.2020 | NaN | NaN |
In [3]:
# Site-level summaries: one row per structural site with the average mutation
# effect in each HA background (avg_h3/h5/h7_effect columns).
site_effects = pd.read_csv('../results/combined_effects/combined_site_effects.csv')
site_effects.head()
Out[3]:
| struct_site | h3_wt_aa | h5_wt_aa | h7_wt_aa | rmsd_h3h5 | rmsd_h3h7 | rmsd_h5h7 | 4o5n_aa_RSA | 4kwm_aa_RSA | 6ii9_aa_RSA | avg_h3_effect | avg_h5_effect | avg_h7_effect | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 9 | S | K | NaN | 9.167400 | NaN | NaN | 1.084277 | 1.140252 | NaN | -0.050776 | -0.998095 | NaN |
| 1 | 10 | T | S | NaN | 8.157247 | NaN | NaN | 0.150962 | 0.175962 | NaN | -0.697911 | -3.348267 | NaN |
| 2 | 11 | A | D | D | 5.040040 | 2.984626 | 2.886615 | 0.050388 | 0.097927 | 0.624352 | -3.138280 | -3.951383 | -2.962194 |
| 3 | 12 | T | Q | K | 3.937602 | 1.626754 | 3.384350 | 0.268605 | 0.216889 | 0.368644 | -1.036219 | -0.342761 | -1.705403 |
| 4 | 13 | L | I | I | 3.687798 | 1.734039 | 2.549524 | 0.000000 | 0.000000 | 0.000000 | -3.941050 | -3.827571 | -3.829644 |
In [4]:
# Read in protein sequence identities
# (pairwise percent identity between the H3, H5, and H7 HA proteins).
seq_identity = pd.read_csv('../results/sequence_identity/ha_sequence_identity.csv')
seq_identity.head()
Out[4]:
| ha_x | ha_y | matches | alignable_residues | percent_identity | |
|---|---|---|---|---|---|
| 0 | H3 | H5 | 192.0 | 479.0 | 40.083507 |
| 1 | H3 | H7 | 229.0 | 483.0 | 47.412008 |
| 2 | H5 | H7 | 202.0 | 473.0 | 42.706131 |
In [5]:
def _pairwise_effect_scatter(df, ha_x, ha_y, x_title, y_title, title):
    """Scatter of per-mutation entry effects in one HA background vs. another.

    Parameters
    ----------
    df : pd.DataFrame
        Mutation-level data with ``{ha}_effect`` and ``{ha}_wt_aa`` columns.
    ha_x, ha_y : str
        Lowercase subtype names (e.g. 'h3', 'h7') for the x and y axes.
    x_title, y_title : list of str
        Multi-line axis titles.
    title : str
        Chart title.

    Returns
    -------
    alt.Chart
        200x200 scatter with tooltips for site, mutant, wildtypes, and effects.
    """
    return alt.Chart(df).mark_circle(
        size=25, opacity=0.3, color='#767676'
    ).encode(
        x=alt.X(f'{ha_x}_effect', title=x_title),
        y=alt.Y(f'{ha_y}_effect', title=y_title),
        tooltip=['struct_site', 'mutant', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa',
                 f'{ha_x}_effect', f'{ha_y}_effect']
    ).properties(
        width=200,
        height=200,
        title=title
    )

# One helper call per pairwise comparison instead of three copy-pasted chart specs.
h3_h7_scatter = _pairwise_effect_scatter(
    mutation_effects, 'h3', 'h7',
    ['Effect on MDCK-SIAT1 entry', 'in H3 background'],
    ['Effect on 293-a2,6 entry', 'in H7 background'],
    'H3 vs. H7',
)
h3_h5_scatter = _pairwise_effect_scatter(
    mutation_effects, 'h3', 'h5',
    ['Effect on MDCK-SIAT1 entry', 'in H3 background'],
    ['Effect on 293T entry', 'in H5 background'],
    'H3 vs. H5',
)
h5_h7_scatter = _pairwise_effect_scatter(
    mutation_effects, 'h5', 'h7',
    ['Effect on 293T entry', 'in H5 background'],
    ['Effect on 293-a2,6 entry', 'in H7 background'],
    'H5 vs. H7',
)
h3_h7_scatter | h3_h5_scatter | h5_h7_scatter
Out[5]:
In [6]:
def scatter_and_density_plot(df, ha_x, ha_y, colors):
    """Scatter of mean site effects in two HA backgrounds, with marginal densities.

    Points are colored by whether the wildtype amino acid at each site is the
    same in both backgrounds; marginal density panels (top and right) show the
    distribution of mean effects per category.

    Parameters
    ----------
    df : pd.DataFrame
        Site-level data with ``avg_{ha}_effect`` and ``{ha}_wt_aa`` columns
        for both subtypes, plus a ``struct_site`` column.
    ha_x, ha_y : str
        Lowercase HA subtype names (e.g. 'h3', 'h5') for the x and y axes.
    colors : dict
        Maps the labels 'Amino acid conserved' / 'Amino acid changed' to
        colors, shared by the scatter and both density layers.

    Returns
    -------
    alt.VConcatChart
        Top marginal density stacked above (scatter + y=x line + r label)
        concatenated with the right marginal density.
    """
    # Pearson correlation between the two backgrounds' mean site effects.
    r_value = df[f'avg_{ha_x}_effect'].corr(df[f'avg_{ha_y}_effect'])
    r_text = f"r = {r_value:.2f}"
    # Dashed y = x reference line; endpoints are hard-coded to span the data range.
    identity_line = alt.Chart(pd.DataFrame({'x': [-5, 0.3], 'y': [-5, 0.3]})).mark_line(
        strokeDash=[6, 6],
        color='black'
    ).encode(
        x='x',
        y='y'
    )
    # Label each site by whether the wildtype amino acid is conserved between backgrounds.
    df = df.assign(
        same_wildtype= lambda x: np.where(
            x[f'{ha_x}_wt_aa'] == x[f'{ha_y}_wt_aa'],
            'Amino acid conserved',
            'Amino acid changed'
        ),
    )
    scatter = alt.Chart(df).mark_circle(
        size=35, opacity=1, stroke='black', strokeWidth=0.5
    ).encode(
        x=alt.X(f'avg_{ha_x}_effect', title=['Mean effect on cell entry', f'in {ha_x.upper()} background']),
        y=alt.Y(f'avg_{ha_y}_effect', title=['Mean effect on cell entry', f'in {ha_y.upper()} background']),
        color=alt.Color(
            'same_wildtype:N',
            scale=alt.Scale(domain=list(colors.keys()), range=list(colors.values())),
        ),
        tooltip=['struct_site', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa', f'avg_{ha_x}_effect', f'avg_{ha_y}_effect']
    ).properties(
        width=175,
        height=175,
    )
    # Correlation annotation pinned to the top-left corner in pixel coordinates.
    r_label = alt.Chart(pd.DataFrame({'text': [r_text]})).mark_text(
        align='left',
        baseline='top',
        fontSize=16,
        fontWeight='normal',
        color='black'
    ).encode(
        text='text:N',
        x=alt.value(5),
        y=alt.value(5)
    )
    # Marginal density of x-axis effects (counts=True -> areas scale with group size).
    x_density = alt.Chart(df).transform_density(
        density=f'avg_{ha_x}_effect',
        bandwidth=0.3,
        groupby=['same_wildtype'],
        extent=[df[f'avg_{ha_x}_effect'].min(), df[f'avg_{ha_x}_effect'].max()],
        counts=True,
        steps=200
    ).mark_area(opacity=0.6, color='black', strokeWidth=1).encode(
        alt.X('value:Q', axis=alt.Axis(labels=False, title=None, ticks=False)),
        alt.Y('density:Q', title='Density').stack(None),
        color=alt.Color(
            'same_wildtype:N',
            title=None,
            scale=alt.Scale(domain=list(colors.keys()), range=list(colors.values())),
        ),
    ).properties(
        width=175,
        height=50
    )
    # Marginal density of y-axis effects, drawn sideways along the right edge.
    y_density = alt.Chart(df).transform_density(
        density=f'avg_{ha_y}_effect',
        bandwidth=0.3,
        groupby=['same_wildtype'],
        extent=[df[f'avg_{ha_y}_effect'].min(), df[f'avg_{ha_y}_effect'].max()],
        counts=True,
        steps=200
    ).mark_area(opacity=0.6, color='black', strokeWidth=1, orient='horizontal').encode(
        alt.Y('value:Q', axis=alt.Axis(labels=False, title=None, ticks=False)),
        alt.X('density:Q', title='Density').stack(None),
        color=alt.Color(
            'same_wildtype:N',
            title=None,
            scale=alt.Scale(domain=list(colors.keys()), range=list(colors.values())),
        ),
    ).properties(
        width=50,
        height=175
    )
    # Assemble: top density over (scatter + identity line + r label) | right density.
    marginal_plot = alt.vconcat(
        x_density,
        alt.hconcat(
            (scatter + identity_line + r_label),
            y_density
        )
    )
    return marginal_plot
# Blue = wildtype amino acid differs between the two backgrounds,
# red = wildtype amino acid conserved.
colors = {
    'Amino acid changed' : '#5484AF',
    'Amino acid conserved' : '#E04948'
}
p1 = scatter_and_density_plot(site_effects, 'h3', 'h5', colors=colors)
p2 = scatter_and_density_plot(site_effects, 'h3', 'h7', colors=colors)
p3 = scatter_and_density_plot(site_effects, 'h5', 'h7', colors=colors)
p1 | p2 | p3
Out[6]:
Calculate Jensen-Shannon Divergence
In [7]:
def kl_divergence(p, q):
    """Kullback-Leibler divergence KL(p || q) in nats.

    Parameters
    ----------
    p, q : array-like
        Probability vectors of equal length.

    Returns
    -------
    float
        ``sum(p * log(p / q))``, using the standard convention that terms
        with ``p == 0`` contribute 0 (the original expression produced NaN
        for such terms via ``0 * log(0)``).
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # Only evaluate the log where p > 0; 0 * log(0 / q) == 0 by convention.
    support = p > 0
    return float(np.sum(p[support] * np.log(p[support] / q[support])))
def compute_js_divergence_per_site(df, ha_x, ha_y, site_col="struct_site", min_mutations=15):
    """Compute JS divergence at each site and merge it back to the dataframe.

    For every site with at least `min_mutations` mutations measured in both
    backgrounds, the effects are exp-normalized into probability vectors and
    the Jensen-Shannon divergence between the two backgrounds is computed.
    Sites with too few measurements get NaN.
    """
    effect_x = f'{ha_x}_effect'
    effect_y = f'{ha_y}_effect'
    site_to_jsd = {}
    for site_id, site_df in df.groupby(site_col):
        measured = site_df.dropna(subset=[effect_x, effect_y])
        if len(measured) < min_mutations:
            site_to_jsd[site_id] = np.nan
            continue
        # Exp-normalize effects into probability distributions over mutations.
        p = np.exp(measured[effect_x].to_numpy())
        q = np.exp(measured[effect_y].to_numpy())
        p = p / p.sum()
        q = q / q.sum()
        mixture = 0.5 * (p + q)
        site_to_jsd[site_id] = 0.5 * (
            kl_divergence(p, mixture) + kl_divergence(q, mixture)
        )
    # Broadcast each site's JSD onto every row at that site.
    annotated = df.copy()
    annotated[f"JS_{ha_x}_vs_{ha_y}"] = annotated[site_col].map(site_to_jsd)
    return annotated
js_df_h3_h7 = compute_js_divergence_per_site(mutation_effects, 'h3', 'h7', min_mutations=10)
js_df_h3_h5 = compute_js_divergence_per_site(mutation_effects, 'h3', 'h5', min_mutations=10)
js_df_h5_h7 = compute_js_divergence_per_site(mutation_effects, 'h5', 'h7', min_mutations=10)
Are epistatic shifts significant?
In [8]:
def compute_jsd_with_null(
    df,
    ha_x,
    ha_y,
    site_col="struct_site",
    min_mutations=15,
    n_bootstrap=1000,
    random_seed=42,
    jsd_threshold=0.02,
    use_pooled_std=False
):
    """
    Compute JS divergence with simulated null distribution for significance testing.

    The null distribution represents: "What JSD would I observe from measurement noise alone?"
    The null is generated by simulating data under Gaussian measurement noise:
    1. ha_x null: Generate two replicate datasets from ha_x effects and stds, compute JSD
    2. ha_y null: Generate two replicate datasets from ha_y effects and stds, compute JSD
    3. Take the mean of the two null distributions

    This assumes measurement noise is normally distributed. A significant result means the
    observed JSD is larger than expected from measurement noise alone.
    Only sites with observed JSD > jsd_threshold are tested for significance.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe with mutation effects and effect_std columns
    ha_x, ha_y : str
        HA subtype names (e.g., 'h3', 'h5')
    site_col : str
        Column name for site identifier
    min_mutations : int
        Minimum number of mutations required at a site
    n_bootstrap : int
        Number of simulations for null distribution
    random_seed : int
        Random seed for reproducibility
    jsd_threshold : float
        Minimum JSD value for a site to be tested for significance.
        Sites with observed JSD <= threshold will have p_value = NaN.
        Default is 0.02.
    use_pooled_std : bool
        If True, use a per-site pooled standard deviation (RMS of per-mutation stds)
        instead of each mutation's individual effect_std when generating the null
        distribution. This gives a more stable variance estimate by pooling ~20
        estimates rather than relying on each mutation's 2-4 replicate estimate.
        Default is False (use per-mutation stds).

    Returns
    -------
    pd.DataFrame
        DataFrame with columns:
        - struct_site: site identifier
        - JS_observed: observed JSD value
        - JS_null_mean: mean of null distribution (NaN if below threshold)
        - JS_null_std: standard deviation of null distribution (NaN if below threshold)
        - p_value: empirical p-value (NaN if below threshold)
        - n_mutations: number of mutations at site
        Sorted by struct_site using natural sorting.
    """
    # NOTE(review): legacy global seeding; np.random.default_rng would be the
    # modern idiom but would change the simulated draws (and hence p-values).
    np.random.seed(random_seed)
    def compute_jsd_vectorized(effects, std, n_bootstrap):
        """Vectorized computation of null JSD distribution.

        Draws two independent noisy replicates of `effects` (Gaussian noise with
        per-mutation scale `std`), exp-normalizes each replicate row into a
        probability vector, and returns the JSD between the replicate pairs for
        all n_bootstrap simulations (shape (n_bootstrap,)).
        """
        n_mutations = len(effects)
        # Generate all simulated samples: shape (n_bootstrap, n_mutations)
        effects_1 = np.random.normal(
            loc=effects[np.newaxis, :],  # broadcast to (1, n_mutations)
            scale=std[np.newaxis, :],  # broadcast to (1, n_mutations)
            size=(n_bootstrap, n_mutations)
        )
        effects_2 = np.random.normal(
            loc=effects[np.newaxis, :],
            scale=std[np.newaxis, :],
            size=(n_bootstrap, n_mutations)
        )
        # Compute probabilities for all simulations
        p1 = np.exp(effects_1)
        p2 = np.exp(effects_2)
        # Normalize: divide each row by its sum
        p1 = p1 / p1.sum(axis=1, keepdims=True)
        p2 = p2 / p2.sum(axis=1, keepdims=True)
        # Compute mixture distribution
        m = 0.5 * (p1 + p2)
        # Compute KL divergences
        # KL(p||m) = sum(p * log(p/m))
        kl_p_m = np.sum(p1 * np.log(p1 / m), axis=1)
        kl_q_m = np.sum(p2 * np.log(p2 / m), axis=1)
        # JSD = 0.5 * (KL(p||m) + KL(q||m))
        jsd = 0.5 * (kl_p_m + kl_q_m)
        return jsd
    results = []
    for site, group in df.groupby(site_col):
        # Filter to valid mutations with both effects and stds
        valid = group.dropna(subset=[
            f'{ha_x}_effect', f'{ha_y}_effect',
            f'{ha_x}_effect_std', f'{ha_y}_effect_std'
        ])
        # Sites with too few fully-measured mutations are skipped entirely
        # (they do not appear in the returned DataFrame at all).
        if len(valid) < min_mutations:
            continue
        # Get observed effects
        effects_x = valid[f'{ha_x}_effect'].values
        effects_y = valid[f'{ha_y}_effect'].values
        # Get standard deviations
        std_x = valid[f'{ha_x}_effect_std'].values
        std_y = valid[f'{ha_y}_effect_std'].values
        # Compute observed JSD between ha_x and ha_y
        p_obs = np.exp(effects_x)
        q_obs = np.exp(effects_y)
        p_obs /= p_obs.sum()
        q_obs /= q_obs.sum()
        m_obs = 0.5 * (p_obs + q_obs)
        jsd_obs = 0.5 * (kl_divergence(p_obs, m_obs) + kl_divergence(q_obs, m_obs))
        # Only compute null distribution if JSD exceeds threshold
        if jsd_obs <= jsd_threshold:
            results.append({
                'struct_site': site,
                'JS_observed': jsd_obs,
                'JS_null_mean': np.nan,
                'JS_null_std': np.nan,
                'p_value': np.nan,
                'n_mutations': len(valid),
                'null_distribution': None
            })
            continue
        # Replace per-mutation stds with the per-site std (RMS)
        if use_pooled_std:
            pooled_std_x = np.sqrt(np.mean(std_x ** 2))
            pooled_std_y = np.sqrt(np.mean(std_y ** 2))
            std_x = np.full_like(std_x, pooled_std_x)
            std_y = np.full_like(std_y, pooled_std_y)
        # Simulated null distributions
        jsd_null_x = compute_jsd_vectorized(effects_x, std_x, n_bootstrap)
        jsd_null_y = compute_jsd_vectorized(effects_y, std_y, n_bootstrap)
        # Take the mean of the two nulls (balanced approach)
        jsd_null = (jsd_null_x + jsd_null_y) / 2
        # Compute empirical p-value (one-tailed test: is observed JSD greater than null?)
        # NOTE(review): this estimator can return exactly 0 when no null draw
        # exceeds the observed JSD; the (k + 1) / (n + 1) form avoids zero
        # p-values, which matters for downstream FDR — consider switching.
        p_value = np.mean(jsd_null >= jsd_obs)
        results.append({
            'struct_site': site,
            'JS_observed': jsd_obs,
            'JS_null_mean': jsd_null.mean(),
            'JS_null_std': jsd_null.std(),
            'p_value': p_value,
            'n_mutations': len(valid),
            'null_distribution': jsd_null
        })
    # Convert to DataFrame and sort by struct_site using natural sorting
    results_df = pd.DataFrame(results)
    results_df = results_df.sort_values('struct_site', key=natsort_keygen()).reset_index(drop=True)
    return results_df
In [9]:
# Compute JSD with null distributions for each comparison
# (min_mutations lowered from the function default of 15 to 10).
jsd_with_pvals_h3_h5 = compute_jsd_with_null(
    js_df_h3_h5,
    'h3', 'h5',
    min_mutations=10,
    n_bootstrap=1000,
    jsd_threshold=0.02
)
jsd_with_pvals_h3_h7 = compute_jsd_with_null(
    js_df_h3_h7,
    'h3', 'h7',
    min_mutations=10,
    n_bootstrap=1000,
    jsd_threshold=0.02
)
jsd_with_pvals_h5_h7 = compute_jsd_with_null(
    js_df_h5_h7,
    'h5', 'h7',
    min_mutations=10,
    n_bootstrap=1000,
    jsd_threshold=0.02
)
# Apply multiple testing correction (Benjamini-Hochberg FDR)
# Only apply FDR to sites that were tested (non-NaN p-values)
# NOTE(review): mid-notebook import — consider moving to the imports cell at the top.
from scipy.stats import false_discovery_control
def apply_fdr_with_threshold(df):
"""Apply FDR correction only to non-NaN p-values."""
# Initialize q_value column with NaN
df['q_value'] = np.nan
# Get indices of non-NaN p-values
tested_mask = df['p_value'].notna()
if tested_mask.sum() > 0:
# Apply FDR correction only to tested sites
df.loc[tested_mask, 'q_value'] = false_discovery_control(df.loc[tested_mask, 'p_value'])
return df
jsd_with_pvals_h3_h5 = apply_fdr_with_threshold(jsd_with_pvals_h3_h5)
jsd_with_pvals_h3_h7 = apply_fdr_with_threshold(jsd_with_pvals_h3_h7)
jsd_with_pvals_h5_h7 = apply_fdr_with_threshold(jsd_with_pvals_h5_h7)

# Report significant sites as fractions (out of ALL sites with JSD measurements).
# One helper instead of three copy-pasted print blocks; output is unchanged.
def _report_significant_sites(label, pvals_df, q_cutoff=0.1, leading_newline=False):
    """Print the count and fraction of sites with q_value below q_cutoff."""
    total = len(pvals_df)
    sig = (pvals_df['q_value'] < q_cutoff).sum()
    prefix = "\n" if leading_newline else ""
    print(f"{prefix}Significant sites ({label}, q < {q_cutoff}):")
    print(f" {sig} / {total} sites ({sig/total:.2%})")

_report_significant_sites("H3 vs H5", jsd_with_pvals_h3_h5)
_report_significant_sites("H3 vs H7", jsd_with_pvals_h3_h7, leading_newline=True)
_report_significant_sites("H5 vs H7", jsd_with_pvals_h5_h7, leading_newline=True)
Significant sites (H3 vs H5, q < 0.1): 270 / 468 sites (57.69%) Significant sites (H3 vs H7, q < 0.1): 253 / 467 sites (54.18%) Significant sites (H5 vs H7, q < 0.1): 212 / 432 sites (49.07%)
In [10]:
def plot_jsd_ecdf(jsd_pvals_df, ha_x, ha_y, alpha=0.1):
    """
    Plot empirical cumulative distribution of JSD values, separated by significance.

    Parameters
    ----------
    jsd_pvals_df : pd.DataFrame
        DataFrame with JSD values and q-values from compute_jsd_with_null
    ha_x, ha_y : str
        HA subtype names for title
    alpha : float
        Significance threshold for q-value (default 0.1)

    Returns
    -------
    alt.Chart : eCDF plot
    """
    # Prepare data with significance flag (NaN q-values compare False,
    # so untested sites fall into the "not significant" group).
    plot_df = jsd_pvals_df.copy()
    plot_df['significant'] = plot_df['q_value'] < alpha
    # Count sites in each category
    n_sig = plot_df['significant'].sum()
    n_nonsig = (~plot_df['significant']).sum()
    # Create eCDF plot: a running count within each significance group,
    # divided by that group's total, gives the cumulative probability.
    ecdf = alt.Chart(plot_df).transform_window(
        cumulative_count='count()',
        sort=[{'field': 'JS_observed'}],
        groupby=['significant']
    ).transform_joinaggregate(
        total='count()',
        groupby=['significant']
    ).transform_calculate(
        ecdf='datum.cumulative_count / datum.total'
    ).mark_line(size=3).encode(
        x=alt.X('JS_observed:Q',
                title=['Divergence in amino-acid', 'preferences'],
                scale=alt.Scale(domain=[0, 0.7])),
        y=alt.Y('ecdf:Q',
                title='Cumulative probability',
                scale=alt.Scale(domain=[0, 1])),
        color=alt.Color('significant:N',
                        title=['Significant', f'(FDR < {alpha})'],
                        scale=alt.Scale(domain=[True, False], range=['#E15759', '#BAB0AC']),
                        legend=alt.Legend(
                            titleFontSize=14,
                            labelFontSize=12
                        ))
    ).properties(
        width=175,
        height=175,
        title=alt.Title(
            f'{ha_x.upper()} vs. {ha_y.upper()}',
            subtitle=[f'N = {n_sig} significant, {n_nonsig} not significant'],
            fontSize=16,
            subtitleFontSize=12
        )
    )
    return ecdf
# Create eCDF plots for all comparisons (default significance cutoff alpha=0.1)
ecdf_h3_h5 = plot_jsd_ecdf(jsd_with_pvals_h3_h5, 'h3', 'h5')
ecdf_h3_h7 = plot_jsd_ecdf(jsd_with_pvals_h3_h7, 'h3', 'h7')
ecdf_h5_h7 = plot_jsd_ecdf(jsd_with_pvals_h5_h7, 'h5', 'h7')
# Display side by side
(ecdf_h3_h5 | ecdf_h3_h7 | ecdf_h5_h7).display()