InĀ [1]:
import json

import altair as alt
import numpy as np
import pandas as pd
from natsort import natsorted, natsort_keygen
from scipy.stats import mannwhitneyu

import theme

# Lift Altair's default cap on the number of rows embedded in a chart spec so
# the full per-site tables below can be plotted.
alt.data_transformers.disable_max_rows()
Out[1]:
DataTransformerRegistry.enable('default')
InĀ [2]:
# H3 site-numbering map; annotate_jsd_df joins its reference_site, region and
# rbs_region columns onto each divergence table via struct_site.
sitemap = pd.read_csv('../data/h3_site_numbering_map.csv')
InĀ [3]:
# One per-site divergence table per pair of HA subtypes.
js_df_h3_h5, js_df_h3_h7, js_df_h5_h7 = (
    pd.read_csv(f'../results/divergence/{pair}_divergence.csv')
    for pair in ('h3_h5', 'h3_h7', 'h5_h7')
)
InĀ [4]:
def annotate_jsd_df(jsd_df, ha_x, ha_y, rsa_col='4o5n_aa_RSA', rsa_threshold=0.2,
                    last_ha1_site='329'):
    """Annotate a pairwise divergence table with site-level categories.

    ``jsd_df`` is de-duplicated and merged one-to-one with the global
    ``sitemap`` (``struct_site`` == ``reference_site`` as strings), which adds
    ``reference_site``, ``region`` and ``rbs_region``. The following columns
    are then added:

    - ``ha_region``: categorical 'HA1' / 'HA2'. A site is HA1 when its natural
      sort key is <= that of ``last_ha1_site``.
    - ``JSD``: a copy of ``JS_{ha_x}_vs_{ha_y}``.
    - ``same_wildtype``: 'Amino acid conserved' if the two wildtype residues
      are equal, otherwise 'Amino acid changed'.
    - ``mutant_type_changed``: 'AA type changed' if the two wildtype residues
      fall into different classes of the local ``aa_class`` map, otherwise
      'AA type conserved'.
    - ``surface_exposed``: 'Exposed' if ``rsa_col`` > ``rsa_threshold``,
      'Buried' if it is <= the threshold, and NaN if the RSA value is missing.

    Parameters
    ----------
    jsd_df : pandas.DataFrame
        Must contain ``struct_site``, ``{ha_x}_wt_aa``, ``{ha_y}_wt_aa``,
        ``JS_{ha_x}_vs_{ha_y}`` and ``rsa_col``.
    ha_x, ha_y : str
        Subtype labels used in the column names (e.g. 'h3', 'h5').
    rsa_col : str
        Relative-solvent-accessibility column used for ``surface_exposed``.
    rsa_threshold : float
        RSA cutoff above which a site is called 'Exposed'.
    last_ha1_site : str
        Last ``struct_site`` label assigned to HA1.

    Returns
    -------
    pandas.DataFrame
        The merged, annotated table.
    """
    # Coarse physicochemical class per wildtype residue. Residues not listed
    # here (gaps, NaN, ...) map to None via dict.get.
    aa_class = {
        'F': 'Aromatic', 'Y': 'Aromatic', 'W': 'Aromatic',
        'N': 'Hydrophilic', 'Q': 'Hydrophilic', 'S': 'Hydrophilic', 'T': 'Hydrophilic',
        'A': 'Hydrophobic', 'V': 'Hydrophobic', 'I': 'Hydrophobic', 'L': 'Hydrophobic', 'M': 'Hydrophobic',
        'D': 'Negative', 'E': 'Negative',
        'R': 'Positive', 'H': 'Positive', 'K': 'Positive',
        'C': 'Special', 'G': 'Special', 'P': 'Special'
    }

    # Natural-sort keys so that sites with letter suffixes order after their
    # integer base instead of lexicographically.
    nat_key = natsort_keygen()
    ha1_end_key = nat_key(str(last_ha1_site))
    wt_x, wt_y = f'{ha_x}_wt_aa', f'{ha_y}_wt_aa'

    def _surface_exposure(rsa):
        # Missing RSA must not be binned as 'Buried' (NaN > threshold is
        # False). Leave it NaN so downstream dropna()/query() filters exclude it.
        labels = pd.Series(
            np.where(rsa > rsa_threshold, 'Exposed', 'Buried'),
            index=rsa.index,
            dtype=object,
        )
        return labels.where(rsa.notna())

    site_annotations = sitemap[['reference_site', 'region', 'rbs_region']].assign(
        reference_site=lambda x: x['reference_site'].astype(str)
    )

    return pd.merge(
        jsd_df.drop_duplicates().reset_index(drop=True),
        site_annotations,
        left_on=['struct_site'],
        right_on=['reference_site'],
        validate='one_to_one'
    ).assign(
        ha_region=lambda x: pd.Categorical(
            np.where(
                x['struct_site'].map(lambda s: nat_key(str(s))) <= ha1_end_key,
                'HA1',
                'HA2'
            ),
            categories=['HA1', 'HA2']
        )
    ).assign(
        JSD=lambda x: x[f'JS_{ha_x}_vs_{ha_y}'],
        same_wildtype=lambda x: np.where(
            x[wt_x] == x[wt_y],
            'Amino acid conserved',
            'Amino acid changed'
        ),
        # Plain-Python per-row comparison: same semantics as a row-wise
        # apply (two unmapped residues give None == None -> conserved),
        # without the overhead of DataFrame.apply(axis=1).
        mutant_type_changed=lambda x: np.where(
            [aa_class.get(a) != aa_class.get(b) for a, b in zip(x[wt_x], x[wt_y])],
            'AA type changed',
            'AA type conserved'
        ),
        surface_exposed=lambda x: _surface_exposure(x[rsa_col]),
    )

# Annotate each pairwise table; the generator's loop names stay local.
js_df_h3_h5_ann, js_df_h3_h7_ann, js_df_h5_h7_ann = (
    annotate_jsd_df(raw_df, ha_x, ha_y)
    for raw_df, ha_x, ha_y in (
        (js_df_h3_h5, 'h3', 'h5'),
        (js_df_h3_h7, 'h3', 'h7'),
        (js_df_h5_h7, 'h5', 'h7'),
    )
)
InĀ [5]:
def format_pvalue(p, decimals=2, threshold=1e-3):
    """Render a p-value as a short text label for a plot.

    - ``p < 1e-5``: the fixed floor label ``"p < 1 Ɨ 10⁻⁵"``.
    - ``1e-5 <= p < threshold``: one-significant-figure scientific notation
      with a Unicode superscript exponent, e.g. ``"p = 3 Ɨ 10⁻⁓"``.
    - otherwise (including NaN): one significant figure in general format,
      e.g. ``"p = 0.04"``.

    ``decimals`` is accepted so existing call sites keep working, but it does
    not affect the output.
    """
    if p < 1e-5:
        return "p < 1 Ɨ 10⁻⁵"
    if not p < threshold:
        return f"p = {p:.1g}"

    mantissa, exponent = f"{p:.0e}".split("e")
    to_superscript = str.maketrans("0123456789-", "⁰¹²³⁓⁵⁶⁷⁸⁹⁻")
    superscript_exp = str(int(exponent)).translate(to_superscript)
    return f"p = {mantissa} Ɨ 10{superscript_exp}"

def plot_ha_region_jsd_boxplot(jsd_df, col, ha_x, ha_y, colors, box_size=30, label_angle=0, stats=True):
    """Plot a boxplot of per-site JS divergence between two HAs, split by ``col``.

    Parameters
    ----------
    jsd_df : pandas.DataFrame
        Annotated divergence table containing ``struct_site``,
        ``{ha_x}_wt_aa``, ``{ha_y}_wt_aa``, ``JS_{ha_x}_vs_{ha_y}`` and
        ``col``. Rows missing any of these values are dropped.
    col : str
        Column whose values define the boxes (e.g. 'ha_region').
    ha_x, ha_y : str
        Subtype labels used in column names and the title (e.g. 'h3', 'h5').
    colors : dict
        Maps each value of ``col`` to a box colour.
    box_size : int
        Box width passed to ``mark_boxplot(size=...)``.
    label_angle : int
        Rotation of the x-axis labels.
    stats : bool
        If true, overlay a two-sided Mann-Whitney U p-value label comparing
        the two groups of ``col``.

    Returns
    -------
    The boxplot chart, layered with the p-value label when ``stats`` is true.

    Raises
    ------
    ValueError
        If ``stats`` is true and ``col`` does not define exactly two groups.
    """
    jsd_col = f'JS_{ha_x}_vs_{ha_y}'
    # Only drop rows that are missing a value that is plotted or tested.
    # Unrelated annotations (e.g. 'region' when grouping by 'ha_region') must
    # not silently remove sites from the comparison.
    used_cols = list(dict.fromkeys(
        ['struct_site', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa', jsd_col, col]
    ))
    df = jsd_df[used_cols].dropna().reset_index(drop=True)

    boxplot = alt.Chart(df).mark_boxplot(
        size=box_size,
        outliers=False
    ).encode(
        x=alt.X(col, title=None, axis=alt.Axis(labelAngle=label_angle)),
        y=alt.Y(f'{jsd_col}:Q', title=['Divergence in amino-acid', 'preferences']),
        color=alt.Color(
            col,
            legend=None,
            scale=alt.Scale(domain=list(colors.keys()), range=list(colors.values())),
        )
    ).properties(
        width=125,
        height=150,
        title=alt.Title(f'{ha_x.upper()} vs. {ha_y.upper()}', fontSize=16, anchor='middle')
    )

    if not stats:
        return boxplot

    # mannwhitneyu takes exactly two samples. Star-unpacking a different
    # number of groups either raises an unrelated TypeError or, with three
    # groups, silently binds the third one to `use_continuity`.
    groups = [values.tolist() for _, values in df.groupby(col, observed=True)[jsd_col]]
    if len(groups) != 2:
        raise ValueError(
            f"Mann-Whitney U test needs exactly 2 groups in {col!r}, found {len(groups)}"
        )
    _, p = mannwhitneyu(*groups, alternative='two-sided')

    p_label = alt.Chart(pd.DataFrame({'text': [format_pvalue(p)]})).mark_text(
        align='left',
        baseline='top',
        fontSize=16,
        fontWeight='normal',
        color='black'
    ).encode(
        text='text:N',
        x=alt.value(40),
        y=alt.value(-15)
    )
    return boxplot + p_label

# HA1 vs. HA2 divergence for each pair of subtypes.
colors = dict(HA1='#8DAB8E', HA2='#E6C069')

(
    plot_ha_region_jsd_boxplot(js_df_h3_h5_ann, 'ha_region', 'h3', 'h5', colors)
    | plot_ha_region_jsd_boxplot(js_df_h3_h7_ann, 'ha_region', 'h3', 'h7', colors)
    | plot_ha_region_jsd_boxplot(js_df_h5_h7_ann, 'ha_region', 'h5', 'h7', colors)
)
Out[5]:
InĀ [6]:
def plot_combined_ha_region_jsd_boxplot(jsd_dfs, col, box_size=10, width=400):
    """
    Plot one boxplot per category of ``col``, with side-by-side boxes for each
    subtype comparison.

    Parameters
    ----------
    jsd_dfs : list of tuples
        ``(dataframe, ha_x, ha_y)`` per comparison. Each dataframe must contain
        ``col`` and ``JS_{ha_x}_vs_{ha_y}``.
    col : str
        Column to group by on the x axis (e.g. 'rbs_region', 'ha_region', 'region').
    box_size : int
        Size of the boxplot boxes.
    width : int
        Chart width in pixels.

    Returns
    -------
    altair.Chart
        Boxes are coloured by comparison ('H3 vs. H5', 'H3 vs. H7', 'H5 vs. H7'
        have fixed colours) and offset within each ``col`` category.

    Raises
    ------
    ValueError
        If ``jsd_dfs`` is empty.
    """
    if not jsd_dfs:
        raise ValueError("jsd_dfs must contain at least one (dataframe, ha_x, ha_y) tuple")

    # Stack all comparisons into one long table. Keep only the encoded
    # columns so the data embedded in the Vega-Lite spec stays small.
    combined_df = pd.concat(
        [
            df.assign(
                comparison=f'{ha_x.upper()} vs. {ha_y.upper()}',
                jsd=df[f'JS_{ha_x}_vs_{ha_y}'],
            )[[col, 'comparison', 'jsd']]
            for df, ha_x, ha_y in jsd_dfs
        ],
        ignore_index=True,
    )

    comparison_colors = {
        'H3 vs. H5': '#155F83',
        'H3 vs. H7': '#FFA319',
        'H5 vs. H7': '#767676'
    }

    return alt.Chart(combined_df).mark_boxplot(
        size=box_size,
        outliers=False
    ).encode(
        x=alt.X(f'{col}:N', title=None, axis=alt.Axis(labelAngle=-45)),
        y=alt.Y('jsd:Q', title=['Divergence in amino-acid', 'preferences']),
        color=alt.Color(
            'comparison:N',
            title='Comparison',
            scale=alt.Scale(
                domain=list(comparison_colors.keys()),
                range=list(comparison_colors.values())
            ),
            legend=alt.Legend(orient='top')
        ),
        # Place the comparisons side by side within each x category.
        xOffset='comparison:N'
    ).properties(
        width=width,
        height=200
    )
InĀ [7]:
# Receptor-binding-site subregions (excluding "RBS other") and epitope regions,
# with one box per subtype comparison.
rbs_combined_plot = plot_combined_ha_region_jsd_boxplot(
    [
        (ann_df.query('rbs_region != "RBS other"'), ha_x, ha_y)
        for ann_df, ha_x, ha_y in (
            (js_df_h3_h5_ann, 'h3', 'h5'),
            (js_df_h3_h7_ann, 'h3', 'h7'),
            (js_df_h5_h7_ann, 'h5', 'h7'),
        )
    ],
    col='rbs_region',
    box_size=10,
)

epitope_combined_plot = plot_combined_ha_region_jsd_boxplot(
    jsd_dfs=[
        (js_df_h3_h5_ann, 'h3', 'h5'),
        (js_df_h3_h7_ann, 'h3', 'h7'),
        (js_df_h5_h7_ann, 'h5', 'h7'),
    ],
    col='region',
    box_size=10,
    width=475,
)

rbs_combined_plot | epitope_combined_plot
Out[7]:

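Changed sites: does divergence differ between sites where the wildtype amino acid changed and sites where it is conserved?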

InĀ [8]:
def plot_jsd_boxplot(jsd_df, ha_x, ha_y, box_col, box_col_title, colors, box_size=30, stats=True, decimals=2, label_map=None):
    """Plot a boxplot of per-site divergence (``JSD``) split by ``box_col``.

    Parameters
    ----------
    jsd_df : pandas.DataFrame
        Annotated table containing ``struct_site``, ``{ha_x}_wt_aa``,
        ``{ha_y}_wt_aa``, ``JSD`` and ``box_col``. Rows missing any of these
        values are dropped.
    ha_x, ha_y : str
        Subtype labels used in column names and the title (e.g. 'h3', 'h5').
    box_col : str
        Column whose values define the boxes.
    box_col_title : str or None
        Passed as the x-axis ``title``.
    colors : dict
        Maps each value of ``box_col`` to a box colour.
    box_size : int
        Box width passed to ``mark_boxplot(size=...)``.
    stats : bool
        If true, overlay a two-sided Mann-Whitney U p-value label.
    decimals : int
        Forwarded to ``format_pvalue``.
    label_map : dict, optional
        Dictionary mapping category values to multi-line labels (as lists of strings).
        Example: {'Amino acid changed': ['Amino acid', 'changed'],
                  'Amino acid conserved': ['Amino acid', 'conserved']}
        An empty or missing map leaves the axis labels unchanged.

    Returns
    -------
    The boxplot chart, layered with the p-value label when ``stats`` is true.

    Raises
    ------
    ValueError
        If ``stats`` is true and ``box_col`` does not define exactly two groups.
    """
    df = jsd_df[[
        'struct_site', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa', 'JSD', box_col
    ]].dropna().reset_index(drop=True)

    axis_config = alt.Axis(labelAngle=0)
    if label_map:
        # Build a right-nested Vega ternary:
        #   datum.label == "a" ? ["x", "y"] : datum.label == "b" ? [...] : datum.label
        # json.dumps yields valid expression literals even if a category or
        # label contains quote characters.
        label_expr = 'datum.label'
        for key, label_lines in reversed(list(label_map.items())):
            condition = f'datum.label == {json.dumps(str(key), ensure_ascii=False)}'
            label_array = json.dumps([str(line) for line in label_lines], ensure_ascii=False)
            label_expr = f'{condition} ? {label_array} : {label_expr}'
        axis_config = alt.Axis(labelAngle=0, labelExpr=label_expr)

    boxplot = alt.Chart(df).mark_boxplot(
        size=box_size,
        outliers=False
    ).encode(
        x=alt.X(box_col, title=box_col_title, axis=axis_config),
        y=alt.Y('JSD:Q', title=['Divergence in amino-acid', 'preferences']),
        color=alt.Color(
            box_col,
            legend=None,
            scale=alt.Scale(domain=list(colors.keys()), range=list(colors.values())),
        )
    ).properties(
        width=175,
        height=150,
        title=alt.Title(f'{ha_x.upper()} vs. {ha_y.upper()}', fontSize=16, anchor='middle')
    )

    if not stats:
        return boxplot

    # mannwhitneyu takes exactly two samples. Star-unpacking another number
    # of groups either fails obscurely or binds a third group to
    # `use_continuity`.
    groups = [values.tolist() for _, values in df.groupby(box_col, observed=True)['JSD']]
    if len(groups) != 2:
        raise ValueError(
            f"Mann-Whitney U test needs exactly 2 groups in {box_col!r}, found {len(groups)}"
        )
    _, p = mannwhitneyu(*groups, alternative='two-sided')

    p_label = alt.Chart(
        pd.DataFrame({'text': [format_pvalue(p, decimals=decimals)]})
    ).mark_text(
        align='center',
        baseline='top',
        fontSize=16,
        fontWeight='normal',
        color='black'
    ).encode(
        text='text:N',
        y=alt.value(-15)
    )
    return boxplot + p_label

# Divergence at sites whose wildtype amino acid differs vs. is shared.
colors = dict(zip(
    ['Amino acid changed', 'Amino acid conserved'],
    ['#5484AF', '#E04948'],
))

# Two-line axis labels: 'Amino acid' / 'changed' and 'Amino acid' / 'conserved'.
label_map = {category: ['Amino acid', category.rsplit(' ', 1)[-1]] for category in colors}

(
    plot_jsd_boxplot(js_df_h3_h5_ann, 'h3', 'h5', 'same_wildtype', None, colors, label_map=label_map)
    | plot_jsd_boxplot(js_df_h3_h7_ann, 'h3', 'h7', 'same_wildtype', None, colors, label_map=label_map)
    | plot_jsd_boxplot(js_df_h5_h7_ann, 'h5', 'h7', 'same_wildtype', None, colors, label_map=label_map)
).resolve_scale(y='shared')
Out[8]:

RMSD: does divergence track structural (Cα RMSD) differences between HAs?

InĀ [9]:
def plot_jsd_vs_rmsd(jsd_df, ha_x, ha_y, colors):
    """Scatter per-site divergence against per-site Cα RMSD for one HA pair.

    Rows missing ``struct_site``, ``rmsd_{ha_x}{ha_y}``, ``JSD``, either
    wildtype residue or ``same_wildtype`` are dropped. Points are coloured by
    ``same_wildtype`` using ``colors``. The Pearson r between ``JSD`` and the
    RMSD column is written inside the panel. The result is wrapped in a
    single-panel ``alt.vconcat``.
    """
    rmsd_field = f'rmsd_{ha_x}{ha_y}'
    wt_fields = [f'{ha_x}_wt_aa', f'{ha_y}_wt_aa']

    plot_df = (
        jsd_df[['struct_site', rmsd_field, 'JSD', *wt_fields, 'same_wildtype']]
        .dropna()
        .reset_index(drop=True)
    )
    r_text = "r = {:.2f}".format(plot_df['JSD'].corr(plot_df[rmsd_field]))

    scatter = (
        alt.Chart(plot_df)
        .mark_circle(size=35, opacity=1, color='#899DA4', strokeWidth=0.5, stroke='black')
        .encode(
            x=alt.X(rmsd_field, title='Sitewise Cα RMSD (ƅ)'),
            y=alt.Y('JSD', title=['Divergence in amino-acid', 'preferences']),
            color=alt.Color(
                'same_wildtype:N',
                title=None,
                scale=alt.Scale(domain=list(colors.keys()), range=list(colors.values())),
            ),
            tooltip=['struct_site', rmsd_field, 'JSD', *wt_fields],
        )
        .properties(
            width=150,
            height=150,
            title=alt.Title(f'{ha_x.upper()} vs. {ha_y.upper()}', fontSize=16, anchor='middle'),
        )
    )

    r_label = (
        alt.Chart(pd.DataFrame({'text': [r_text]}))
        .mark_text(align='left', baseline='top', fontSize=16, fontWeight='normal', color='black')
        .encode(text='text:N', x=alt.value(100), y=alt.value(5))
    )

    return alt.vconcat(scatter + r_label)

# Divergence vs. structural difference, coloured by whether the wildtype changed.
colors = dict(zip(('Amino acid changed', 'Amino acid conserved'), ('#5484AF', '#E04948')))

(
    plot_jsd_vs_rmsd(js_df_h3_h5_ann, 'h3', 'h5', colors)
    | plot_jsd_vs_rmsd(js_df_h3_h7_ann, 'h3', 'h7', colors)
    | plot_jsd_vs_rmsd(js_df_h5_h7_ann, 'h5', 'h7', colors)
)
Out[9]:

Surface accessibility: does divergence depend on relative solvent accessibility?

InĀ [10]:
# Per-site summary table: wildtype residues, pairwise RMSDs, RSA per structure
# and avg_*_effect columns per subtype. Only struct_site and the avg_*_effect
# columns are merged in below.
site_effects = pd.read_csv('../results/combined_effects/combined_site_effects.csv')
site_effects.head()
Out[10]:
struct_site h3_wt_aa h5_wt_aa h7_wt_aa rmsd_h3h5 rmsd_h3h7 rmsd_h5h7 4o5n_aa_RSA 4kwm_aa_RSA 6ii9_aa_RSA avg_h3_effect avg_h5_effect avg_h7_effect
0 9 S K NaN 9.167400 NaN NaN 1.084277 1.140252 NaN -0.050776 -0.998095 NaN
1 10 T S NaN 8.157247 NaN NaN 0.150962 0.175962 NaN -0.697911 -3.348267 NaN
2 11 A D D 5.040040 2.984626 2.886615 0.050388 0.097927 0.624352 -3.138280 -3.951383 -2.962194
3 12 T Q K 3.937602 1.626754 3.384350 0.268605 0.216889 0.368644 -1.036219 -0.342761 -1.705403
4 13 L I I 3.687798 1.734039 2.549524 0.000000 0.000000 0.000000 -3.941050 -3.827571 -3.829644
InĀ [11]:
def plot_jsd_vs_rsa(jsd_df, ha_x, ha_y, rsa_col, effect_col='h3'):
    """Scatter per-site divergence against relative solvent accessibility.

    ``jsd_df`` needs ``struct_site``, ``JSD``, ``{rsa_col}_aa_RSA`` and
    ``avg_{effect_col}_effect``. Points are coloured by the mean mutation
    effect (viridis). The Pearson r between ``JSD`` and the RSA column is
    written inside the panel.
    """
    rsa_field = f'{rsa_col}_aa_RSA'
    effect_field = f'avg_{effect_col}_effect'
    r_text = "r = {:.2f}".format(jsd_df['JSD'].corr(jsd_df[rsa_field]))

    points = (
        alt.Chart(jsd_df)
        .mark_circle(size=30, opacity=1, color='#899DA4', strokeWidth=0.5, stroke='black')
        .encode(
            x=alt.X(rsa_field, title=['Relative solvent', f'accessibility ({rsa_col.upper()})']),
            y=alt.Y('JSD', title=['Divergence in amino-acid', 'preferences']),
            color=alt.Color(
                effect_field,
                title=['Mean mutation', f'effect ({effect_col.upper()})'],
                scale=alt.Scale(scheme='viridis'),
            ),
            tooltip=['struct_site', rsa_field, 'JSD', effect_field],
        )
        .properties(
            width=150,
            height=150,
            title=alt.Title(f'{ha_x.upper()} vs. {ha_y.upper()}', fontSize=16, anchor='middle'),
        )
    )

    r_label = (
        alt.Chart(pd.DataFrame({'text': [r_text]}))
        .mark_text(align='left', baseline='top', fontSize=16, fontWeight='normal', color='black')
        .encode(text='text:N', x=alt.value(100), y=alt.value(5))
    )

    return points + r_label

# RSA (from 4o5n) vs. divergence at wildtype-changed sites, coloured by the
# mean mutation effect of the first subtype in each pair.
(
    plot_jsd_vs_rsa(
        js_df_h3_h5_ann
        .merge(site_effects[['struct_site', 'avg_h3_effect']].drop_duplicates(), on='struct_site')
        .query('same_wildtype == "Amino acid changed"'),
        ha_x='h3', ha_y='h5', rsa_col='4o5n', effect_col='h3',
    )
    | plot_jsd_vs_rsa(
        js_df_h3_h7_ann
        .merge(site_effects[['struct_site', 'avg_h3_effect']].drop_duplicates(), on='struct_site')
        .query('same_wildtype == "Amino acid changed"'),
        ha_x='h3', ha_y='h7', rsa_col='4o5n', effect_col='h3',
    )
    | plot_jsd_vs_rsa(
        js_df_h5_h7_ann
        .merge(site_effects[['struct_site', 'avg_h5_effect']].drop_duplicates(), on='struct_site')
        .query('same_wildtype == "Amino acid changed"'),
        ha_x='h5', ha_y='h7', rsa_col='4o5n', effect_col='h5',
    )
).resolve_scale(color='independent')
Out[11]:
InĀ [12]:
# Buried vs. exposed divergence at sites whose wildtype amino acid differs.
colors = dict(Buried='#2E5A87', Exposed='#A1C8E3')

(
    plot_jsd_boxplot(
        js_df_h3_h5_ann.query('same_wildtype == "Amino acid changed"'),
        'h3', 'h5', box_col='surface_exposed', box_col_title=None, colors=colors, decimals=2,
    )
    | plot_jsd_boxplot(
        js_df_h3_h7_ann.query('same_wildtype == "Amino acid changed"'),
        'h3', 'h7', box_col='surface_exposed', box_col_title=None, colors=colors, decimals=2,
    )
    | plot_jsd_boxplot(
        js_df_h5_h7_ann.query('same_wildtype == "Amino acid changed"'),
        'h5', 'h7', box_col='surface_exposed', box_col_title=None, colors=colors,
    )
)
Out[12]:
InĀ [13]:
# Buried vs. exposed divergence at sites whose wildtype amino acid is shared.
colors = dict(Buried='#A90C38', Exposed='#FFB0A1')

(
    plot_jsd_boxplot(
        js_df_h3_h5_ann.query('same_wildtype == "Amino acid conserved"'),
        'h3', 'h5', box_col='surface_exposed', box_col_title=None, colors=colors, decimals=2,
    )
    | plot_jsd_boxplot(
        js_df_h3_h7_ann.query('same_wildtype == "Amino acid conserved"'),
        'h3', 'h7', box_col='surface_exposed', box_col_title=None, colors=colors, decimals=2,
    )
    | plot_jsd_boxplot(
        js_df_h5_h7_ann.query('same_wildtype == "Amino acid conserved"'),
        'h5', 'h7', box_col='surface_exposed', box_col_title=None, colors=colors,
    )
)
Out[13]:
InĀ [14]:
# Buried, wildtype-changed sites split by whether the residue's
# physicochemical class also changed. Both categories use the same colour.
colors = dict.fromkeys(['AA type changed', 'AA type conserved'], '#2E5A87')

# Two-line axis labels: 'AA type' / 'changed' and 'AA type' / 'conserved'.
label_map = {category: ['AA type', category.rsplit(' ', 1)[-1]] for category in colors}

(
    plot_jsd_boxplot(
        js_df_h3_h5_ann.query('same_wildtype == "Amino acid changed" and surface_exposed == "Buried"'),
        'h3', 'h5', box_col='mutant_type_changed', box_col_title=None, colors=colors,
        label_map=label_map,
    )
    | plot_jsd_boxplot(
        js_df_h3_h7_ann.query('same_wildtype == "Amino acid changed" and surface_exposed == "Buried"'),
        'h3', 'h7', box_col='mutant_type_changed', box_col_title=None, colors=colors,
        label_map=label_map,
    )
    | plot_jsd_boxplot(
        js_df_h5_h7_ann.query('same_wildtype == "Amino acid changed" and surface_exposed == "Buried"'),
        'h5', 'h7', box_col='mutant_type_changed', box_col_title=None, colors=colors,
        decimals=3, label_map=label_map,
    )
)
Out[14]: