In [1]:
import pandas as pd
import numpy as np
import altair as alt

import theme
from natsort import natsorted, natsort_keygen

alt.data_transformers.disable_max_rows()
Out[1]:
DataTransformerRegistry.enable('default')
In [2]:
# Per-mutation effects on cell entry in the H3, H5, and H7 backgrounds:
# one row per (structural site, mutant amino acid), with effect and std columns.
mutation_effects = pd.read_csv('../results/combined_effects/combined_mutation_effects.csv')
mutation_effects.head()
Out[2]:
mutant struct_site h3_wt_aa h5_wt_aa h7_wt_aa rmsd_h3h5 rmsd_h3h7 rmsd_h5h7 4o5n_aa_RSA 4kwm_aa_RSA 6ii9_aa_RSA h3_effect h3_effect_std h5_effect h5_effect_std h7_effect h7_effect_std
0 A 9 S K NaN 9.1674 NaN NaN 1.084277 1.140252 NaN 0.0151 0.7225 0.2049 0.2627 NaN NaN
1 C 9 S K NaN 9.1674 NaN NaN 1.084277 1.140252 NaN -0.4080 0.3850 -0.3977 0.1072 NaN NaN
2 D 9 S K NaN 9.1674 NaN NaN 1.084277 1.140252 NaN 0.2361 0.2740 0.2383 0.2087 NaN NaN
3 E 9 S K NaN 9.1674 NaN NaN 1.084277 1.140252 NaN -0.2463 0.8478 0.3120 0.2815 NaN NaN
4 F 9 S K NaN 9.1674 NaN NaN 1.084277 1.140252 NaN 0.2061 0.3214 -0.8917 1.2020 NaN NaN
In [3]:
# Site-level summaries: mean mutation effect per site in each background,
# plus structural annotations (pairwise RMSDs and per-structure RSAs).
site_effects = pd.read_csv('../results/combined_effects/combined_site_effects.csv')
site_effects.head()
Out[3]:
struct_site h3_wt_aa h5_wt_aa h7_wt_aa rmsd_h3h5 rmsd_h3h7 rmsd_h5h7 4o5n_aa_RSA 4kwm_aa_RSA 6ii9_aa_RSA avg_h3_effect avg_h5_effect avg_h7_effect
0 9 S K NaN 9.167400 NaN NaN 1.084277 1.140252 NaN -0.050776 -0.998095 NaN
1 10 T S NaN 8.157247 NaN NaN 0.150962 0.175962 NaN -0.697911 -3.348267 NaN
2 11 A D D 5.040040 2.984626 2.886615 0.050388 0.097927 0.624352 -3.138280 -3.951383 -2.962194
3 12 T Q K 3.937602 1.626754 3.384350 0.268605 0.216889 0.368644 -1.036219 -0.342761 -1.705403
4 13 L I I 3.687798 1.734039 2.549524 0.000000 0.000000 0.000000 -3.941050 -3.827571 -3.829644
In [4]:
# Read in protein sequence identities (pairwise percent identity between HAs);
# used later to annotate plot titles.
seq_identity = pd.read_csv('../results/sequence_identity/ha_sequence_identity.csv')
seq_identity.head()
Out[4]:
ha_x ha_y matches alignable_residues percent_identity
0 H3 H5 192.0 479.0 40.083507
1 H3 H7 229.0 483.0 47.412008
2 H5 H7 202.0 473.0 42.706131
In [5]:
# One helper replaces three copy-pasted chart definitions; each panel is a
# per-mutation scatter of entry effects in two HA backgrounds.
def _entry_scatter(ha_x, x_title, ha_y, y_title, panel_title):
    """Scatter of per-mutation entry effects, `ha_x` on x vs. `ha_y` on y.

    `ha_x`/`ha_y` are column prefixes ('h3', 'h5', or 'h7'); the axis titles
    are passed straight through to Altair (lists render as multi-line titles).
    """
    return alt.Chart(mutation_effects).mark_circle(
        size=25, opacity=0.3, color='#767676'
    ).encode(
        x=alt.X(f'{ha_x}_effect', title=x_title),
        y=alt.Y(f'{ha_y}_effect', title=y_title),
        tooltip=[
            'struct_site', 'mutant',
            f'{ha_x}_wt_aa', f'{ha_y}_wt_aa',
            f'{ha_x}_effect', f'{ha_y}_effect',
        ]
    ).properties(width=200, height=200, title=panel_title)

h3_h7_scatter = _entry_scatter(
    'h3', ['Effect on MDCK-SIAT1 entry', 'in H3 background'],
    'h7', ['Effect on 293-a2,6 entry', 'in H7 background'],
    'H3 vs. H7',
)

h3_h5_scatter = _entry_scatter(
    'h3', ['Effect on MDCK-SIAT1 entry', 'in H3 background'],
    'h5', ['Effect on 293T entry', 'in H5 background'],
    'H3 vs. H5',
)

h5_h7_scatter = _entry_scatter(
    'h5', ['Effect on 293T entry', 'in H5 background'],
    'h7', ['Effect on 293-a2,6 entry', 'in H7 background'],
    'H5 vs. H7',
)

h3_h7_scatter | h3_h5_scatter | h5_h7_scatter
Out[5]:
In [6]:
def scatter_and_density_plot(df, ha_x, ha_y, colors):
    """Scatter of mean per-site entry effects in two HA backgrounds,
    with marginal density plots and a y = x reference line.

    Parameters
    ----------
    df : pd.DataFrame
        Site-level dataframe with 'avg_{ha}_effect' and '{ha}_wt_aa' columns.
    ha_x, ha_y : str
        Column prefixes for the two backgrounds (e.g. 'h3', 'h5').
    colors : dict
        Maps 'Amino acid conserved' / 'Amino acid changed' to hex colors.

    Returns
    -------
    alt.VConcatChart
        Top marginal density over x, above (scatter + identity line +
        r label) placed next to the y marginal density.
    """
    # Pearson correlation between the two backgrounds' mean site effects.
    r_value = df[f'avg_{ha_x}_effect'].corr(df[f'avg_{ha_y}_effect'])
    r_text = f"r = {r_value:.2f}"

    # Dashed y = x line marking equal effects in both backgrounds.
    identity_line = alt.Chart(pd.DataFrame({'x': [-5, 0.3], 'y': [-5, 0.3]})).mark_line(
        strokeDash=[6, 6],
        color='black'
    ).encode(
        x='x',
        y='y'
    )

    # Label each site by whether the wildtype amino acid is conserved
    # between the two backgrounds.
    df = df.assign(
        same_wildtype= lambda x: np.where(
            x[f'{ha_x}_wt_aa'] == x[f'{ha_y}_wt_aa'],
            'Amino acid conserved',
            'Amino acid changed'
        ),
    )

    scatter = alt.Chart(df).mark_circle(
        size=35, opacity=1, stroke='black', strokeWidth=0.5
    ).encode(
        x=alt.X(f'avg_{ha_x}_effect', title=['Mean effect on cell entry', f'in {ha_x.upper()} background']),
        y=alt.Y(f'avg_{ha_y}_effect', title=['Mean effect on cell entry', f'in {ha_y.upper()} background']),
        color=alt.Color(
            'same_wildtype:N', 
            scale=alt.Scale(domain=list(colors.keys()), range=list(colors.values())),
        ),
        tooltip=['struct_site', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa', f'avg_{ha_x}_effect', f'avg_{ha_y}_effect']
    ).properties(
        width=175,
        height=175,
    )

    # Correlation coefficient annotation pinned to the top-left corner
    # (pixel-valued x/y rather than data coordinates).
    r_label = alt.Chart(pd.DataFrame({'text': [r_text]})).mark_text(
        align='left',
        baseline='top',
        fontSize=16,
        fontWeight='normal',
        color='black'
    ).encode(
        text='text:N',
        x=alt.value(5), 
        y=alt.value(5)
    )

    # Marginal density of x-axis effects, one curve per conservation class.
    x_density = alt.Chart(df).transform_density(
        density=f'avg_{ha_x}_effect',
        bandwidth=0.3,
        groupby=['same_wildtype'],
        extent=[df[f'avg_{ha_x}_effect'].min(), df[f'avg_{ha_x}_effect'].max()],
        counts=True,
        steps=200
    ).mark_area(opacity=0.6, color='black', strokeWidth=1).encode(
        alt.X('value:Q', axis=alt.Axis(labels=False, title=None, ticks=False)),
        alt.Y('density:Q', title='Density').stack(None),
        color=alt.Color(
            'same_wildtype:N', 
            title=None,
            scale=alt.Scale(domain=list(colors.keys()), range=list(colors.values())),
        ),
    ).properties(
        width=175,
        height=50
    )

    # Marginal density of y-axis effects, drawn sideways next to the scatter.
    y_density = alt.Chart(df).transform_density(
        density=f'avg_{ha_y}_effect',
        bandwidth=0.3,
        groupby=['same_wildtype'],
        extent=[df[f'avg_{ha_y}_effect'].min(), df[f'avg_{ha_y}_effect'].max()],
        counts=True,
        steps=200
    ).mark_area(opacity=0.6, color='black', strokeWidth=1, orient='horizontal').encode(
        alt.Y('value:Q', axis=alt.Axis(labels=False, title=None, ticks=False)),
        alt.X('density:Q', title='Density').stack(None),
        color=alt.Color(
            'same_wildtype:N', 
            title=None,
            scale=alt.Scale(domain=list(colors.keys()), range=list(colors.values())),
        ),
    ).properties(
        width=50,
        height=175
    )

    # Assemble: x marginal on top, scatter (with overlays) beside y marginal.
    marginal_plot = alt.vconcat(
        x_density,
        alt.hconcat(
            (scatter + identity_line + r_label),
            y_density
        )
    )
    return marginal_plot

# Shared color scheme: sites where the wildtype amino acid differs vs. matches.
colors = {
    'Amino acid changed' : '#5484AF',
    'Amino acid conserved' : '#E04948'
}
# One panel per pairwise comparison of HA backgrounds.
p1 = scatter_and_density_plot(site_effects, 'h3', 'h5', colors=colors)
p2 = scatter_and_density_plot(site_effects, 'h3', 'h7', colors=colors)
p3 = scatter_and_density_plot(site_effects, 'h5', 'h7', colors=colors)

p1 | p2 | p3
Out[6]:

Calculate Jensen-Shannon Divergence¶

In [7]:
def kl_divergence(p, q):
    """Kullback-Leibler divergence KL(p || q) = sum_i p_i * log(p_i / q_i).

    Assumes `p` and `q` are aligned arrays of strictly positive
    probabilities (zero entries would yield -inf/nan from the log).
    """
    log_ratio = np.log(p) - np.log(q)
    return np.sum(p * log_ratio)

def compute_js_divergence_per_site(df, ha_x, ha_y, site_col="struct_site", min_mutations=15):
    """Compute JS divergence at each site and merge it back to the dataframe.

    For every site, the mutation effects in each background are converted
    to probability distributions (exp-transform, then normalize) and the
    Jensen-Shannon divergence between the two distributions is computed.
    Sites with fewer than `min_mutations` mutations measured in both
    backgrounds get NaN.
    """
    effect_x_col = f'{ha_x}_effect'
    effect_y_col = f'{ha_y}_effect'

    def site_jsd(group):
        # Keep only mutations measured in both backgrounds.
        measured = group.dropna(subset=[effect_x_col, effect_y_col])
        if len(measured) < min_mutations:
            return np.nan

        # Softmax-style normalization turns effects into distributions.
        p = np.exp(measured[effect_x_col].values)
        q = np.exp(measured[effect_y_col].values)
        p = p / p.sum()
        q = q / q.sum()

        mixture = 0.5 * (p + q)
        return 0.5 * (kl_divergence(p, mixture) + kl_divergence(q, mixture))

    js_per_site = {site: site_jsd(group) for site, group in df.groupby(site_col)}

    # Broadcast the per-site JSD back onto every row at that site.
    out = df.copy()
    out[f"JS_{ha_x}_vs_{ha_y}"] = out[site_col].map(js_per_site)
    return out

# Per-site JSD between amino-acid preference distributions, for each HA pair.
# min_mutations=10 requires at least 10 mutations measured in both backgrounds.
js_df_h3_h7 = compute_js_divergence_per_site(mutation_effects, 'h3', 'h7', min_mutations=10)
js_df_h3_h5 = compute_js_divergence_per_site(mutation_effects, 'h3', 'h5', min_mutations=10)
js_df_h5_h7 = compute_js_divergence_per_site(mutation_effects, 'h5', 'h7', min_mutations=10)

Are epistatic shifts significant?¶

In [8]:
def compute_jsd_with_null(
    df,
    ha_x,
    ha_y,
    site_col="struct_site",
    min_mutations=15,
    n_bootstrap=1000,
    random_seed=42,
    jsd_threshold=0.02,
    use_pooled_std=False
):
    """
    Compute JS divergence with simulated null distribution for significance testing.

    The null distribution represents: "What JSD would I observe from measurement noise alone?"

    The null is generated by simulating data under Gaussian measurement noise:
    1. ha_x null: Generate two replicate datasets from ha_x effects and stds, compute JSD
    2. ha_y null: Generate two replicate datasets from ha_y effects and stds, compute JSD
    3. Take the mean of the two null distributions

    This assumes measurement noise is normally distributed. A significant result means the
    observed JSD is larger than expected from measurement noise alone.

    Only sites with observed JSD > jsd_threshold are tested for significance.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe with mutation effects and effect_std columns
    ha_x, ha_y : str
        HA subtype names (e.g., 'h3', 'h5')
    site_col : str
        Column name for site identifier
    min_mutations : int
        Minimum number of mutations required at a site
    n_bootstrap : int
        Number of simulations for null distribution
    random_seed : int
        Random seed for reproducibility
    jsd_threshold : float
        Minimum JSD value for a site to be tested for significance.
        Sites with observed JSD <= threshold will have p_value = NaN.
        Default is 0.02.
    use_pooled_std : bool
        If True, use a per-site pooled standard deviation (RMS of per-mutation stds)
        instead of each mutation's individual effect_std when generating the null
        distribution. This gives a more stable variance estimate by pooling ~20
        estimates rather than relying on each mutation's 2-4 replicate estimate.
        Default is False (use per-mutation stds).

    Returns
    -------
    pd.DataFrame
        DataFrame with columns:
        - struct_site: site identifier
        - JS_observed: observed JSD value
        - JS_null_mean: mean of null distribution (NaN if below threshold)
        - JS_null_std: standard deviation of null distribution (NaN if below threshold)
        - p_value: empirical p-value (NaN if below threshold)
        - n_mutations: number of mutations at site
        - null_distribution: array of simulated null JSD values
          (None if below threshold)
        Sorted by struct_site using natural sorting.
    """
    # NOTE: legacy global seeding — every draw below goes through the global
    # np.random state, so results are reproducible for a fixed seed but the
    # stream is shared with any other np.random use in the session.
    np.random.seed(random_seed)

    def compute_jsd_vectorized(effects, std, n_bootstrap):
        """Vectorized computation of null JSD distribution.

        Simulates two noisy replicates of `effects` (Gaussian noise with
        per-mutation scale `std`) and returns the JSD between them for each
        of `n_bootstrap` simulations, shape (n_bootstrap,).
        """
        n_mutations = len(effects)
        
        # Generate all simulated samples: shape (n_bootstrap, n_mutations)
        effects_1 = np.random.normal(
            loc=effects[np.newaxis, :],  # broadcast to (1, n_mutations)
            scale=std[np.newaxis, :],     # broadcast to (1, n_mutations)
            size=(n_bootstrap, n_mutations)
        )
        effects_2 = np.random.normal(
            loc=effects[np.newaxis, :],
            scale=std[np.newaxis, :],
            size=(n_bootstrap, n_mutations)
        )
        
        # Compute probabilities for all simulations
        p1 = np.exp(effects_1)
        p2 = np.exp(effects_2)
        
        # Normalize: divide each row by its sum
        p1 = p1 / p1.sum(axis=1, keepdims=True)
        p2 = p2 / p2.sum(axis=1, keepdims=True)
        
        # Compute mixture distribution
        m = 0.5 * (p1 + p2)
        
        # Compute KL divergences
        # KL(p||m) = sum(p * log(p/m))
        kl_p_m = np.sum(p1 * np.log(p1 / m), axis=1)
        kl_q_m = np.sum(p2 * np.log(p2 / m), axis=1)
        
        # JSD = 0.5 * (KL(p||m) + KL(q||m))
        jsd = 0.5 * (kl_p_m + kl_q_m)
        
        return jsd

    results = []

    for site, group in df.groupby(site_col):
        # Filter to valid mutations with both effects and stds
        valid = group.dropna(subset=[
            f'{ha_x}_effect', f'{ha_y}_effect',
            f'{ha_x}_effect_std', f'{ha_y}_effect_std'
        ])

        if len(valid) < min_mutations:
            continue

        # Get observed effects
        effects_x = valid[f'{ha_x}_effect'].values
        effects_y = valid[f'{ha_y}_effect'].values

        # Get standard deviations
        std_x = valid[f'{ha_x}_effect_std'].values
        std_y = valid[f'{ha_y}_effect_std'].values

        # Compute observed JSD between ha_x and ha_y
        p_obs = np.exp(effects_x)
        q_obs = np.exp(effects_y)
        p_obs /= p_obs.sum()
        q_obs /= q_obs.sum()
        m_obs = 0.5 * (p_obs + q_obs)
        jsd_obs = 0.5 * (kl_divergence(p_obs, m_obs) + kl_divergence(q_obs, m_obs))

        # Only compute null distribution if JSD exceeds threshold;
        # below-threshold sites are recorded but not tested (NaN p-value).
        if jsd_obs <= jsd_threshold:
            results.append({
                'struct_site': site,
                'JS_observed': jsd_obs,
                'JS_null_mean': np.nan,
                'JS_null_std': np.nan,
                'p_value': np.nan,
                'n_mutations': len(valid),
                'null_distribution': None
            })
            continue

        # Replace per-mutation stds with the per-site std (RMS)
        if use_pooled_std:
            pooled_std_x = np.sqrt(np.mean(std_x ** 2))
            pooled_std_y = np.sqrt(np.mean(std_y ** 2))
            std_x = np.full_like(std_x, pooled_std_x)
            std_y = np.full_like(std_y, pooled_std_y)

        # Simulated null distributions
        jsd_null_x = compute_jsd_vectorized(effects_x, std_x, n_bootstrap)
        jsd_null_y = compute_jsd_vectorized(effects_y, std_y, n_bootstrap)
        
        # Take the mean of the two nulls (balanced approach)
        jsd_null = (jsd_null_x + jsd_null_y) / 2

        # Compute empirical p-value (one-tailed test: is observed JSD greater than null?)
        p_value = np.mean(jsd_null >= jsd_obs)

        results.append({
            'struct_site': site,
            'JS_observed': jsd_obs,
            'JS_null_mean': jsd_null.mean(),
            'JS_null_std': jsd_null.std(),
            'p_value': p_value,
            'n_mutations': len(valid),
            'null_distribution': jsd_null
        })

    # Convert to DataFrame and sort by struct_site using natural sorting
    results_df = pd.DataFrame(results)
    results_df = results_df.sort_values('struct_site', key=natsort_keygen()).reset_index(drop=True)
    
    return results_df
In [9]:
# Compute JSD with null distributions for each comparison.
# min_mutations / jsd_threshold match the values used for js_df_* above.
jsd_with_pvals_h3_h5 = compute_jsd_with_null(
    js_df_h3_h5,
    'h3', 'h5',
    min_mutations=10,
    n_bootstrap=1000,
    jsd_threshold=0.02
)

jsd_with_pvals_h3_h7 = compute_jsd_with_null(
    js_df_h3_h7,
    'h3', 'h7',
    min_mutations=10,
    n_bootstrap=1000,
    jsd_threshold=0.02
)

jsd_with_pvals_h5_h7 = compute_jsd_with_null(
    js_df_h5_h7,
    'h5', 'h7',
    min_mutations=10,
    n_bootstrap=1000,
    jsd_threshold=0.02
)

# Apply multiple testing correction (Benjamini-Hochberg FDR)
# Only apply FDR to sites that were tested (non-NaN p-values)
from scipy.stats import false_discovery_control

def apply_fdr_with_threshold(df):
    """Apply FDR correction only to non-NaN p-values."""
    # Initialize q_value column with NaN
    df['q_value'] = np.nan
    
    # Get indices of non-NaN p-values
    tested_mask = df['p_value'].notna()
    
    if tested_mask.sum() > 0:
        # Apply FDR correction only to tested sites
        df.loc[tested_mask, 'q_value'] = false_discovery_control(df.loc[tested_mask, 'p_value'])
    
    return df

# Apply the FDR correction to each pairwise comparison.
jsd_with_pvals_h3_h5 = apply_fdr_with_threshold(jsd_with_pvals_h3_h5)
jsd_with_pvals_h3_h7 = apply_fdr_with_threshold(jsd_with_pvals_h3_h7)
jsd_with_pvals_h5_h7 = apply_fdr_with_threshold(jsd_with_pvals_h5_h7)

# Report significant sites as fractions (out of ALL sites with JSD measurements).
# One loop replaces three copy-pasted report blocks; printed output is unchanged.
for i, (label, results) in enumerate([
    ('H3 vs H5', jsd_with_pvals_h3_h5),
    ('H3 vs H7', jsd_with_pvals_h3_h7),
    ('H5 vs H7', jsd_with_pvals_h5_h7),
]):
    total = len(results)
    n_sig = (results['q_value'] < 0.1).sum()
    prefix = '' if i == 0 else '\n'  # blank line between comparisons
    print(f"{prefix}Significant sites ({label}, q < 0.1):")
    print(f"  {n_sig} / {total} sites ({n_sig/total:.2%})")
Significant sites (H3 vs H5, q < 0.1):
  270 / 468 sites (57.69%)

Significant sites (H3 vs H7, q < 0.1):
  253 / 467 sites (54.18%)

Significant sites (H5 vs H7, q < 0.1):
  212 / 432 sites (49.07%)
In [10]:
def plot_jsd_ecdf(jsd_pvals_df, ha_x, ha_y, alpha=0.1):
    """
    Plot empirical cumulative distribution of JSD values, separated by significance.

    Parameters
    ----------
    jsd_pvals_df : pd.DataFrame
        DataFrame with JSD values and q-values from compute_jsd_with_null
    ha_x, ha_y : str
        HA subtype names for title
    alpha : float
        Significance threshold for q-value (default 0.1)

    Returns
    -------
    alt.Chart : eCDF plot
    """
    # Flag each site as significant / not significant at the chosen FDR.
    data = jsd_pvals_df.copy()
    data['significant'] = data['q_value'] < alpha

    # Counts for the subtitle.
    n_sig = data['significant'].sum()
    n_nonsig = (~data['significant']).sum()

    # Encodings are built up-front so the chain below stays readable.
    x_axis = alt.X(
        'JS_observed:Q',
        title=['Divergence in amino-acid', 'preferences'],
        scale=alt.Scale(domain=[0, 0.7]),
    )
    y_axis = alt.Y(
        'ecdf:Q',
        title='Cumulative probability',
        scale=alt.Scale(domain=[0, 1]),
    )
    color_enc = alt.Color(
        'significant:N',
        title=['Significant', f'(FDR < {alpha})'],
        scale=alt.Scale(domain=[True, False], range=['#E15759', '#BAB0AC']),
        legend=alt.Legend(titleFontSize=14, labelFontSize=12),
    )

    # eCDF: cumulative count within each significance group, sorted by
    # JS_observed and normalized by the group's total count.
    return (
        alt.Chart(data)
        .transform_window(
            cumulative_count='count()',
            sort=[{'field': 'JS_observed'}],
            groupby=['significant'],
        )
        .transform_joinaggregate(total='count()', groupby=['significant'])
        .transform_calculate(ecdf='datum.cumulative_count / datum.total')
        .mark_line(size=3)
        .encode(x=x_axis, y=y_axis, color=color_enc)
        .properties(
            width=175,
            height=175,
            title=alt.Title(
                f'{ha_x.upper()} vs. {ha_y.upper()}',
                subtitle=[f'N = {n_sig} significant, {n_nonsig} not significant'],
                fontSize=16,
                subtitleFontSize=12,
            ),
        )
    )

# Create eCDF plots for all comparisons (default FDR threshold alpha=0.1).
ecdf_h3_h5 = plot_jsd_ecdf(jsd_with_pvals_h3_h5, 'h3', 'h5')
ecdf_h3_h7 = plot_jsd_ecdf(jsd_with_pvals_h3_h7, 'h3', 'h7')
ecdf_h5_h7 = plot_jsd_ecdf(jsd_with_pvals_h5_h7, 'h5', 'h7')

# Display side by side
(ecdf_h3_h5 | ecdf_h3_h7 | ecdf_h5_h7).display()
In [11]:
def plot_jsd(df, jsd_pvals_df, ha_x, ha_y, identity_df=None, alpha=0.1, only_lineplot=False, variant_selector=None): 
    """
    Plot JSD values with significance coloring.
    
    Parameters
    ----------
    df : pd.DataFrame
        Main dataframe with mutation effects
    jsd_pvals_df : pd.DataFrame
        DataFrame with JSD p-values and q-values from compute_jsd_with_null
    ha_x, ha_y : str
        HA subtype names
    identity_df : pd.DataFrame, optional
        DataFrame with sequence identity information; if given, the percent
        identity for this HA pair is added to the chart title
    alpha : float
        Significance threshold for q-value (default 0.1)
    only_lineplot : bool
        If True, return only the per-site line/point panel, omitting the
        correlation scatter and density panels (default False)
    variant_selector : alt.selection_point, optional
        Shared selection object for synchronized highlighting across plots

    Returns
    -------
    alt.Chart
        The combined chart, or just the line+points panel if only_lineplot.
    """
    # Look up the percent amino-acid identity for the title, if available.
    if identity_df is not None:
        result = identity_df.query(
            f'ha_x=="{ha_x.upper()}" and ha_y=="{ha_y.upper()}"'
        )
        shared_aai = result['percent_identity'].values[0] if len(result) > 0 else None
    else:
        shared_aai = None

    amino_acid_classification = {
        'F': 'Aromatic', 'Y': 'Aromatic', 'W': 'Aromatic',
        'N': 'Hydrophilic', 'Q': 'Hydrophilic', 'S': 'Hydrophilic', 'T': 'Hydrophilic',
        'A': 'Hydrophobic', 'V': 'Hydrophobic', 'I': 'Hydrophobic', 'L': 'Hydrophobic', 'M': 'Hydrophobic',
        'D': 'Negative', 'E': 'Negative',
        'R': 'Positive', 'H': 'Positive', 'K': 'Positive',
        'C': 'Special', 'G': 'Special', 'P': 'Special'
    }
    # Work on a copy: the astype below previously mutated the caller's
    # dataframe in place.
    df = df.copy()
    df['struct_site'] = df['struct_site'].astype(str)

    df = df.assign(
        mutant_type=lambda x: x['mutant'].map(amino_acid_classification)
    )

    # Merge significance data with site-level JSD data
    site_jsd_df = df[[
        'struct_site', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa', 
        f'JS_{ha_x}_vs_{ha_y}', f'rmsd_{ha_x}{ha_y}'
    ]].dropna().drop_duplicates()
    
    # Merge q-values
    site_jsd_df = site_jsd_df.merge(
        jsd_pvals_df[['struct_site', 'q_value']], 
        on='struct_site', 
        how='left'
    )
    
    # Add significance flag
    site_jsd_df = site_jsd_df.assign(
        significant=lambda x: x['q_value'] < alpha
    )

    # Use provided selector or create a new one
    if variant_selector is None:
        variant_selector = alt.selection_point(
            on="mouseover", empty=False, nearest=True, fields=["struct_site"], value=1
        )

    # Natural sort so '10' follows '9' rather than '1'.
    sorted_sites = natsorted(df['struct_site'].unique())
    base = alt.Chart(site_jsd_df).encode(
        alt.X(
            "struct_site:O",
            sort=sorted_sites, 
            title='Site',
            axis=alt.Axis(
                labelAngle=0,
                values=['1', '50', '100', '150', '200', '250', '300', '350', '400', '450', '500'],
                tickCount=11,
            )
        ),
        alt.Y(
            f'JS_{ha_x}_vs_{ha_y}:Q', 
            title=['Divergence in amino-acid', 'preferences'],
            axis=alt.Axis(
                grid=False
            ),
            scale=alt.Scale(domain=[0, 0.7])
        ),
        tooltip=[
            'struct_site', 
            f'{ha_x}_wt_aa', 
            f'{ha_y}_wt_aa', 
            alt.Tooltip(f'JS_{ha_x}_vs_{ha_y}', format='.4f'),
            alt.Tooltip(f'rmsd_{ha_x}{ha_y}', format='.2f'),
            alt.Tooltip('q_value', format='.4f'),
            'significant'
        ],
    ).properties(
        width=800,
        height=150
    )

    line = base.mark_line(opacity=0.5, stroke='#999999', size=1)
    
    # Points layer with conditional formatting based on hover and click
    points = base.mark_circle(filled=True).encode(
        size=alt.condition(
            variant_selector,
            alt.value(75),  # when selected
            alt.value(40)  # default
        ),
        color=alt.Color(
            'significant:N',
            title=['Significant', f'(FDR < {alpha})'],
            scale=alt.Scale(domain=[True, False], range=['#E15759', '#BAB0AC']),
            legend=alt.Legend(
                titleFontSize=14,
                labelFontSize=12
            )
        ),
        stroke=alt.condition(
            variant_selector,
            alt.value('black'),
            alt.value(None)
        ),
        strokeWidth=alt.condition(
            variant_selector,
            alt.value(1),
            alt.value(0)
        ),
        opacity=alt.condition(
            variant_selector,
            alt.value(1),
            alt.value(0.75)
        )
    ).add_params(
        variant_selector
    )

    # Correlation between cell entry effects plot; only the hovered site's
    # mutations are shown (via the transform_filter below).
    base_corr_chart = (alt.Chart(df)
        .mark_text(size=20)
        .encode(
            alt.X(
                f"{ha_x}_effect", 
                title=["Effect on cell entry", f"in {ha_x.upper()}"], 
                scale=alt.Scale(domain=[-6,1.5])
            ),
            alt.Y(
                f"{ha_y}_effect", 
                title=["Effect on cell entry", f"in {ha_y.upper()}"], 
                scale=alt.Scale(domain=[-6,1.5])
            ),
            alt.Text('mutant'),
            alt.Color('mutant_type',
                    title='Mutant type',
                    scale=alt.Scale(
                        domain=['Aromatic', 'Hydrophilic', 'Hydrophobic','Negative', 'Positive', 'Special'],
                        range=["#4e79a7","#f28e2c","#e15759","#76b7b2","#59a14f","#edc949"]
                    ),
                    legend=alt.Legend(
                        titleFontSize=16,
                        labelFontSize=13
                    )
            ),
            tooltip=['struct_site', 'mutant', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa', 
                     f'{ha_x}_effect', f'{ha_x}_effect_std', 
                     f'{ha_y}_effect', f'{ha_y}_effect_std',
                    f'JS_{ha_x}_vs_{ha_y}'],  
        )
        .transform_filter(
            variant_selector
        )
        .properties(
            height=150,
            width=150,
        )
    )

    # Vertical line at x = 0
    vline = alt.Chart(pd.DataFrame({'x': [0]})).mark_rule(color='gray',opacity=0.5,strokeDash=[2,4]).encode(x='x:Q')
    
    # Horizontal line at y = 0
    hline = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(color='gray',opacity=0.5,strokeDash=[2,4]).encode(y='y:Q')
    
    corr_chart = vline + hline + base_corr_chart

    # Density of the per-site JSD values.
    density = alt.Chart(
        site_jsd_df
    ).transform_density(
        density=f'JS_{ha_x}_vs_{ha_y}',
        bandwidth=0.02,
        extent=[0,1],
        counts=True,
        steps=200
    ).mark_area(opacity=1, color='#CCEBC5', stroke='black', strokeWidth=1).encode(
        alt.X('value:Q', title=['Divergence in amino-acid', 'preferences']),
        alt.Y('density:Q', title='Density').stack(None),
    ).properties(
        width=200,
        height=60
    )

    if shared_aai is not None:
        title_text = f'{ha_x.upper()} vs. {ha_y.upper()} ({shared_aai:.1f}% Amino Acid Identity)'
    else:
        title_text = f'{ha_x.upper()} vs. {ha_y.upper()}'

    # Combine the panels (line plot on top, correlation scatter, density)
    if only_lineplot is False:
        combined_chart = alt.vconcat(
            (line + points), corr_chart, density
        ).resolve_scale(
            y='independent', 
            x='independent', 
            color='independent'
        )
    else:
        combined_chart = line + points
    
    combined_chart = combined_chart.properties(
        title=alt.Title(title_text, 
        offset=0,
        fontSize=18,
        subtitleFontSize=16,
        anchor='middle'
        )
    )

    return combined_chart

# H3 vs. H5: per-site divergence line plot + correlation + density panels.
chart = plot_jsd(
    js_df_h3_h5,
    jsd_with_pvals_h3_h5,
    'h3', 'h5',
    seq_identity
)
chart.display()
In [12]:
# H3 vs. H7: per-site divergence line plot + correlation + density panels.
chart = plot_jsd(
    js_df_h3_h7,
    jsd_with_pvals_h3_h7,
    'h3', 'h7',
    seq_identity
)
chart.display()
In [13]:
# H5 vs. H7: per-site divergence line plot + correlation + density panels.
chart = plot_jsd(
    js_df_h5_h7,
    jsd_with_pvals_h5_h7,
    'h5', 'h7',
    seq_identity
)
chart.display()
In [14]:
# Create a shared selection for synchronized hovering across all plots
# (hovering a site in one panel highlights the same site in the others).
shared_selection = alt.selection_point(
    on="mouseover", empty=False, nearest=True, fields=["struct_site"], value=1
)

# Stack the three line-only panels vertically, all driven by the same selector.
combined_interactive_lineplots = (
    plot_jsd(
        js_df_h3_h5, jsd_with_pvals_h3_h5,
        'h3', 'h5', seq_identity, only_lineplot=True,
        variant_selector=shared_selection
    ) & plot_jsd(
        js_df_h5_h7, jsd_with_pvals_h5_h7,
        'h5', 'h7', seq_identity, only_lineplot=True,
        variant_selector=shared_selection
    ) & plot_jsd(
        js_df_h3_h7, jsd_with_pvals_h3_h7,
        'h3', 'h7', seq_identity, only_lineplot=True,
        variant_selector=shared_selection
    )
)
combined_interactive_lineplots.save('combined_interactive_lineplots.html')
combined_interactive_lineplots.display()
In [15]:
# Write the per-site divergence tables, one CSV per HA pair.
# A single loop replaces three copy-pasted export blocks; the selected
# columns and output paths are unchanged.
for js_df, wt_cols, rmsd_col, js_col, out_path in [
    (js_df_h3_h5, ['h3_wt_aa', 'h5_wt_aa'], 'rmsd_h3h5', 'JS_h3_vs_h5',
     '../results/divergence/h3_h5_divergence.csv'),
    (js_df_h3_h7, ['h3_wt_aa', 'h7_wt_aa'], 'rmsd_h3h7', 'JS_h3_vs_h7',
     '../results/divergence/h3_h7_divergence.csv'),
    (js_df_h5_h7, ['h5_wt_aa', 'h7_wt_aa'], 'rmsd_h5h7', 'JS_h5_vs_h7',
     '../results/divergence/h5_h7_divergence.csv'),
]:
    columns = ['struct_site', *wt_cols, rmsd_col, '4o5n_aa_RSA', js_col]
    js_df[columns].drop_duplicates().reset_index(drop=True).to_csv(
        out_path, index=False
    )
In [16]:
# Compare distribution of per-mutation vs per-site std for H3, H5, H7
std_rows = []

for ha, std_col in [('H3', 'h3_effect_std'), ('H5', 'h5_effect_std'), ('H7', 'h7_effect_std')]:
    valid = mutation_effects[['struct_site', std_col]].dropna()

    # Per-mutation stds, excluding wildtype entries (std == 0).
    std_rows.extend(
        {'std': v, 'type': 'Per-mutation', 'ha': ha}
        for v in valid.loc[valid[std_col] > 0, std_col]
    )

    # Per-site pooled stds: root-mean-square across all mutations at a site.
    per_site_pooled = valid.groupby('struct_site')[std_col].apply(
        lambda x: np.sqrt(np.mean(x ** 2))
    )
    std_rows.extend(
        {'std': v, 'type': 'Per-site', 'ha': ha} for v in per_site_pooled
    )

std_df = pd.DataFrame(std_rows)

# Overlaid density curves of per-mutation vs. per-site (pooled) stds,
# one column (facet) per HA background.
std_chart = alt.Chart(std_df).transform_density(
    density='std',
    bandwidth=0.03,
    groupby=['ha', 'type'],
    extent=[0, 1.5],
    steps=300
).mark_area(opacity=0.5, strokeWidth=1.5).encode(
    alt.X('value:Q', title='Standard deviation of effect'),
    alt.Y('density:Q', title='Density').stack(None),
    alt.Color('type:N',
              title=None,
              scale=alt.Scale(
                  domain=['Per-mutation', 'Per-site'],
                  range=['#4C78A8', '#E45756']
              ),
              legend=alt.Legend(labelFontSize=14, titleFontSize=16)),
    # Matching stroke color so the area outline follows the fill.
    alt.Stroke('type:N',
               scale=alt.Scale(
                   domain=['Per-mutation', 'Per-site'],
                   range=['#4C78A8', '#E45756']
               ),
               legend=None),
    alt.Column('ha:N', title=None, header=alt.Header(labelFontSize=16))
).properties(
    width=200,
    height=150
).resolve_scale(y='independent')

std_chart.display()
In [17]:
# Run significance analysis with per-site pooled std for all four comparisons
# (same parameters as the per-mutation run, but use_pooled_std=True).
jsd_pooled_h3_h5 = compute_jsd_with_null(
    js_df_h3_h5, 'h3', 'h5',
    min_mutations=10, n_bootstrap=1000, jsd_threshold=0.02,
    use_pooled_std=True
)
jsd_pooled_h3_h7 = compute_jsd_with_null(
    js_df_h3_h7, 'h3', 'h7',
    min_mutations=10, n_bootstrap=1000, jsd_threshold=0.02,
    use_pooled_std=True
)
jsd_pooled_h5_h7 = compute_jsd_with_null(
    js_df_h5_h7, 'h5', 'h7',
    min_mutations=10, n_bootstrap=1000, jsd_threshold=0.02,
    use_pooled_std=True
)

jsd_pooled_h3_h5 = apply_fdr_with_threshold(jsd_pooled_h3_h5)
jsd_pooled_h3_h7 = apply_fdr_with_threshold(jsd_pooled_h3_h7)
jsd_pooled_h5_h7 = apply_fdr_with_threshold(jsd_pooled_h5_h7)

# Side-by-side counts of significant sites under the two std choices.
for label, orig, pooled in [
    ("H3 vs H5",       jsd_with_pvals_h3_h5,    jsd_pooled_h3_h5),
    ("H3 vs H7",       jsd_with_pvals_h3_h7,    jsd_pooled_h3_h7),
    ("H5 vs H7",       jsd_with_pvals_h5_h7,    jsd_pooled_h5_h7),
]:
    n = len(orig)
    sig_orig   = (orig['q_value']   < 0.1).sum()
    sig_pooled = (pooled['q_value'] < 0.1).sum()
    print(f"{label} (N={n} sites):")
    print(f"  Per-mutation std:    {sig_orig}   significant ({sig_orig/n:.1%})")
    print(f"  Per-site std:  {sig_pooled} significant ({sig_pooled/n:.1%})")
H3 vs H5 (N=468 sites):
  Per-mutation std:    270   significant (57.7%)
  Per-site std:  317 significant (67.7%)
H3 vs H7 (N=467 sites):
  Per-mutation std:    253   significant (54.2%)
  Per-site std:  301 significant (64.5%)
H5 vs H7 (N=432 sites):
  Per-mutation std:    212   significant (49.1%)
  Per-site std:  261 significant (60.4%)
In [18]:
# Per-site lineplot comparisons (pooled std): for each HA pair, stack the
# per-mutation-std panel over the pooled-std panel so significance calls
# can be compared visually.
for (js_df, pvals_orig, pvals_pooled, ha_x, ha_y) in [
    (js_df_h3_h5,    jsd_with_pvals_h3_h5,    jsd_pooled_h3_h5,    'h3',     'h5'),
    (js_df_h3_h7,    jsd_with_pvals_h3_h7,    jsd_pooled_h3_h7,    'h3',     'h7'),
    (js_df_h5_h7,    jsd_with_pvals_h5_h7,    jsd_pooled_h5_h7,    'h5',     'h7'),
]:
    orig_chart   = plot_jsd(js_df, pvals_orig,   ha_x, ha_y, only_lineplot=True)
    pooled_chart = plot_jsd(js_df, pvals_pooled, ha_x, ha_y, only_lineplot=True)
    (
        orig_chart.properties(title=f'{ha_x.upper()} vs {ha_y.upper()} – Per-mutation std') &
        pooled_chart.properties(title=f'{ha_x.upper()} vs {ha_y.upper()} – Per-site std')
    ).display()

Examples of mutation effect correlations¶

In [19]:
def plot_correlation(df, ha_x, ha_y, site, decimal_places=2):
    """Scatter plot of per-mutation effects at one site in two HA backgrounds.

    Each mutant amino acid is drawn as its one-letter code, positioned by its
    measured effect on cell entry in ``ha_x`` (x axis) vs ``ha_y`` (y axis)
    and colored by physicochemical class. The site's divergence (from the
    ``JS_{ha_x}_vs_{ha_y}`` column) is shown in the title.

    Parameters
    ----------
    df : pd.DataFrame
        Mutation-effect table with 'struct_site', 'mutant', the
        ``{ha}_effect`` / ``{ha}_effect_std`` / ``{ha}_wt_aa`` columns, and
        the pairwise JS divergence column.
    ha_x, ha_y : str
        Lowercase HA labels (e.g. 'h3', 'h5') selecting the effect columns.
    site : str
        Structural site to plot; compared against 'struct_site' as a string.
    decimal_places : int, optional
        Decimal places used when printing the divergence in the title.

    Returns
    -------
    alt.LayeredChart
        Mutant-letter scatter layered over dashed reference lines at 0.
    """
    amino_acid_classification = {
        'F': 'Aromatic', 'Y': 'Aromatic', 'W': 'Aromatic',
        'N': 'Hydrophilic', 'Q': 'Hydrophilic', 'S': 'Hydrophilic', 'T': 'Hydrophilic',
        'A': 'Hydrophobic', 'V': 'Hydrophobic', 'I': 'Hydrophobic', 'L': 'Hydrophobic', 'M': 'Hydrophobic',
        'D': 'Negative', 'E': 'Negative',
        'R': 'Positive', 'H': 'Positive', 'K': 'Positive',
        'C': 'Special', 'G': 'Special', 'P': 'Special'
    }
    # NOTE(review): this casts the caller's column in place (the passed
    # DataFrame is mutated). Later cells appear to rely on 'struct_site'
    # being a string, so the side effect is preserved here — confirm before
    # changing this to operate on a copy.
    df['struct_site'] = df['struct_site'].astype(str)

    df = df.assign(
        mutant_type=lambda x: x['mutant'].map(amino_acid_classification)
    ).query(f'struct_site == "{site}"')

    # Every row at a given site carries the same divergence value.
    jsd = df[f'JS_{ha_x}_vs_{ha_y}'].unique()[0]

    # df is already restricted to `site`, so no further filtering is needed.
    base_corr_chart = (alt.Chart(df)
        .mark_text(size=20)
        .encode(
            alt.X(
                f"{ha_x}_effect", 
                title=["Effect on cell entry", f"in {ha_x.upper()}"], 
                scale=alt.Scale(domain=[-6,1.5])
            ),
            alt.Y(
                f"{ha_y}_effect", 
                title=["Effect on cell entry", f"in {ha_y.upper()}"], 
                scale=alt.Scale(domain=[-6,1.5])
            ),
            alt.Text('mutant'),
            alt.Color('mutant_type',
                    title='Mutant type',
                    scale=alt.Scale(
                        domain=['Aromatic', 'Hydrophilic', 'Hydrophobic','Negative', 'Positive', 'Special'],
                        range=["#4e79a7","#f28e2c","#e15759","#76b7b2","#59a14f","#edc949"]
                    ),
                    legend=alt.Legend(
                        titleFontSize=16,
                        labelFontSize=13
                    )
            ),
            tooltip=['struct_site', 'mutant', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa', 
                     f'{ha_x}_effect', f'{ha_x}_effect_std', 
                     f'{ha_y}_effect', f'{ha_y}_effect_std',
                    f'JS_{ha_x}_vs_{ha_y}'],  
        ).properties(
            height=125,
            width=125,
        )
    )

    # Dashed gray reference line at x = 0 (no effect in ha_x).
    vline = alt.Chart(pd.DataFrame({'x': [0]})).mark_rule(color='gray',opacity=0.5,strokeDash=[2,4]).encode(x='x:Q')
    
    # Dashed gray reference line at y = 0 (no effect in ha_y).
    hline = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(color='gray',opacity=0.5,strokeDash=[2,4]).encode(y='y:Q')
    
    corr_chart = (vline + hline + base_corr_chart).properties(
        title=alt.Title([f'Site {site}', f'Divergence = {jsd:.{decimal_places}f}'], 
        offset=0,
        fontSize=16,
        anchor='middle'
        )
    )
    return corr_chart
In [20]:
# Four example sites spanning low-to-high divergence between H3 and H5.
h3_h5_examples = alt.hconcat(
    plot_correlation(js_df_h3_h5, 'h3', 'h5', site='86'),
    plot_correlation(js_df_h3_h5, 'h3', 'h5', site='97', decimal_places=3),
    plot_correlation(js_df_h3_h5, 'h3', 'h5', site='198'),
    plot_correlation(js_df_h3_h5, 'h3', 'h5', site='241'),
)
h3_h5_examples.display()
In [21]:
# Site 86 compared across all three pairwise HA combinations.
site_86_row = alt.hconcat(
    plot_correlation(js_df_h3_h5, 'h3', 'h5', site='86'),
    plot_correlation(js_df_h3_h7, 'h3', 'h7', site='86'),
    plot_correlation(js_df_h5_h7, 'h5', 'h7', site='86'),
)
site_86_row.display()
In [22]:
# Site 173 compared across all three pairwise HA combinations.
site_173_row = alt.hconcat(
    plot_correlation(js_df_h3_h5, 'h3', 'h5', site='173'),
    plot_correlation(js_df_h3_h7, 'h3', 'h7', site='173'),
    plot_correlation(js_df_h5_h7, 'h5', 'h7', site='173'),
)
site_173_row.display()
In [23]:
# Same three pairwise comparisons at each H-bond site of interest.
# (Loop replaces three copy-pasted chart cells; display order is unchanged.)
for hbond_site in ['178', '123', '176']:
    (
        plot_correlation(js_df_h3_h5, 'h3', 'h5', site=hbond_site) |
        plot_correlation(js_df_h3_h7, 'h3', 'h7', site=hbond_site) |
        plot_correlation(js_df_h5_h7, 'h5', 'h7', site=hbond_site)
    ).display()

# H3 forms H bonds at 178, 123, 176, and 211. 
# H5 and H7 do not form any H bonds in this region, and therefore tolerate many more amino acids.

Logo plots¶

In [24]:
from IPython.display import display, Image
import matplotlib.pyplot as plt
import dmslogo
from dmslogo.colorschemes import CBPALETTE, ValueToColorMap
/home/tyu2/.conda/envs/ha-epistasis/lib/python3.12/site-packages/dmslogo/logo.py:27: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  import pkg_resources
In [25]:
# Amino acids grouped by physicochemical class, flattened into the
# per-residue lookup used to color mutant letters in the logo plots.
_aa_groups = {
    'Aromatic': 'FYW',
    'Hydrophilic': 'NQST',
    'Hydrophobic': 'AVILM',
    'Negative': 'DE',
    'Positive': 'RHK',
    'Special': 'CGP',
}
amino_acid_classification = {
    aa: aa_class for aa_class, letters in _aa_groups.items() for aa in letters
}

# One color per amino-acid class (matches the palette used in plot_correlation).
color_map = {
    'Aromatic': '#4e79a7', 'Hydrophilic': '#f28e2c', 'Hydrophobic': '#e15759',
    'Negative': '#76b7b2', 'Positive': '#59a14f', 'Special': '#edc949'
}

def compute_aa_prefs(df, ha, site_col='struct_site'):
    """Compute per-site amino acid preferences via softmax of effects.

    Parameters
    ----------
    df : pd.DataFrame
        Table with a ``{ha}_effect`` column and a site column.
    ha : str
        HA label (e.g. 'h3') selecting the effect column; the result is
        written to a new ``{ha}_aa_pref`` column.
    site_col : str, optional
        Column defining the per-site groups (default 'struct_site').

    Returns
    -------
    pd.DataFrame
        Copy of ``df`` with the preference column added. Rows with a
        missing effect (and sites with no measured effects) keep NaN.
    """
    effect_col = f'{ha}_effect'
    pref_col = f'{ha}_aa_pref'
    df = df.copy()  # never mutate the caller's frame
    df[pref_col] = np.nan
    for site, group in df.groupby(site_col):
        valid_mask = group[effect_col].notna()
        if not valid_mask.any():
            continue  # no measured effects at this site; leave prefs as NaN
        effects = group.loc[valid_mask, effect_col].values
        # Subtract the max before exponentiating for numerical stability;
        # softmax is invariant to a constant shift, so results are unchanged
        # except that large positive effects no longer overflow to inf.
        prefs = np.exp(effects - effects.max())
        prefs /= prefs.sum()
        df.loc[group.index[valid_mask], pref_col] = prefs
    return df

# Preferences at the three H-bond sites of interest, one HA at a time.
site_prefs = compute_aa_prefs(mutation_effects.query('struct_site in ["123", "176", "178"]'), 'h3')
site_prefs = compute_aa_prefs(site_prefs, 'h5')
site_prefs = compute_aa_prefs(site_prefs, 'h7')[['struct_site', 'mutant', 'h3_wt_aa', 'h5_wt_aa', 'h7_wt_aa', 'h3_aa_pref', 'h5_aa_pref', 'h7_aa_pref']]

# Reshape to long format (one row per site x mutant x HA) and annotate for
# dmslogo: letter color by amino-acid class, show_site flag for facet_plot.
site_prefs = (
    site_prefs
    .melt(
        id_vars=['struct_site', 'mutant', 'h3_wt_aa', 'h5_wt_aa', 'h7_wt_aa'],
        value_vars=['h3_aa_pref', 'h5_aa_pref', 'h7_aa_pref'],
        var_name='ha',
        value_name='aa_pref',
    )
    .assign(
        ha=lambda x: x['ha'].str.replace('_aa_pref', '', regex=False),
        # struct_site is still a string at this point, so site_label keeps
        # the textual site label for the x-axis ticks. (Direct column copy
        # replaces a needless row-wise apply.)
        site_label=lambda x: x['struct_site'],
        # Later kwargs in .assign may reference columns created by earlier
        # ones, so color can map the just-created mutant_type column.
        mutant_type=lambda x: x['mutant'].map(amino_acid_classification),
        color=lambda x: x['mutant_type'].map(color_map),
        show_site=True,
    )
    .drop(columns=['h3_wt_aa', 'h5_wt_aa', 'h7_wt_aa'])
)

# Integer site positions for dmslogo's x_col ordering.
site_prefs['struct_site'] = site_prefs['struct_site'].astype(int)
site_prefs.head()
Out[25]:
struct_site mutant ha aa_pref site_label mutant_type color show_site
0 123 A h3 0.005977 123 Hydrophobic #e15759 True
1 123 C h3 0.005888 123 Special #edc949 True
2 123 D h3 0.006110 123 Negative #76b7b2 True
3 123 E h3 0.805119 123 Negative #76b7b2 True
4 123 F h3 0.006514 123 Aromatic #4e79a7 True
In [27]:
def generate_facet_logo_plot(df, output_file_name, clip_negative=True):
    """Draw a faceted sequence-logo plot (one grid row per HA) and save as SVG.

    Parameters
    ----------
    df : pd.DataFrame
        Long-format preferences with columns 'struct_site', 'mutant',
        'site_label', 'aa_pref', 'color', 'ha', and boolean 'show_site'.
    output_file_name : str
        Path of the SVG file to write.
    clip_negative : bool, optional
        Passed to dmslogo; clip negative letter heights (default True).
    """
    draw_logo_kwargs = {
        "letter_col": "mutant",
        "color_col": "color",
        "xtick_col": "site_label",
        "letter_height_col": "aa_pref",
        "xlabel": "",
        "clip_negative_heights": clip_negative,
    }
    # One grid row per HA; only sites flagged show_site are drawn.
    fig, ax = dmslogo.facet_plot(
        data=df,
        x_col="struct_site",
        gridrow_col="ha",
        share_ylim_across_rows=True,
        show_col="show_site",
        draw_logo_kwargs=draw_logo_kwargs,
    )
    # (Removed a dead bare `fig` statement — inside a function it has no
    # display effect; the figure still renders inline because it stays open.)
    fig.savefig(output_file_name, bbox_inches="tight", format="svg")

generate_facet_logo_plot(site_prefs, 'figures/rewired_site.svg')
No description has been provided for this image