In [1]:
import pandas as pd
import numpy as np
import altair as alt

import theme
from natsort import natsorted, natsort_keygen
from scipy.stats import false_discovery_control

# Let Altair embed data frames of any size instead of raising once a chart's
# data exceeds its default max-rows limit; several charts below are built from
# per-mutation tables.
alt.data_transformers.disable_max_rows()
Out[1]:
DataTransformerRegistry.enable('default')
In [2]:
def read_and_filter_data(
    path, 
    effect_std_filter=2,
    times_seen_filter=2,
    n_selections_filter=2,
    clip_effect=-5 
):
    """Read a per-mutation cell-entry effects CSV, filter it, and add wildtype rows.

    Rows are kept only when ``effect_std <= effect_std_filter``,
    ``times_seen >= times_seen_filter`` and ``n_selections >= n_selections_filter``;
    stop codons (``*``) and indels (``-``) are then removed. ``site`` is cast to
    ``str`` and ``effect`` is clipped from below at ``clip_effect``. Finally one row
    per distinct (site, wildtype) pair is appended with ``mutant == wildtype``,
    ``effect == effect_std == 0.0`` and NaN ``times_seen`` / ``n_selections``.

    Parameters
    ----------
    path : str or path-like
        CSV with at least the columns site, wildtype, mutant, effect,
        effect_std, times_seen and n_selections.
    effect_std_filter : float
        Maximum allowed ``effect_std``.
    times_seen_filter : float
        Minimum required ``times_seen``.
    n_selections_filter : float
        Minimum required ``n_selections``.
    clip_effect : float or None
        Lower bound applied to ``effect``; ``None`` disables clipping.

    Returns
    -------
    pd.DataFrame
        Filtered rows plus wildtype rows, sorted by (site, mutant) with a
        fresh RangeIndex.
    """
    print(f'Reading data from {path}')
    print(
        f"Filtering for:\n"
        f"  effect_std <= {effect_std_filter}\n"
        f"  times_seen >= {times_seen_filter}\n"
        f"  n_selections >= {n_selections_filter}"
    )
    print(f"Clipping effect values at {clip_effect}")

    filtered = (
        pd.read_csv(path)
        .query(
            'effect_std <= @effect_std_filter and '
            'times_seen >= @times_seen_filter and '
            'n_selections >= @n_selections_filter'
        )
        .query('mutant not in ["*", "-"]')  # don't want stop codons/indels
        # assign() builds a new frame, so no chained assignment on the query result.
        .assign(
            site=lambda x: x['site'].astype(str),
            # Use the caller's clip bound (it was previously hard-coded to -5).
            effect=lambda x: x['effect'].clip(lower=clip_effect),
        )
    )

    # Wildtype "mutations" at every remaining site, with zero effect.
    wildtype_rows = (
        filtered[['site', 'wildtype']]
        .drop_duplicates()
        .assign(
            mutant=lambda x: x['wildtype'],
            effect=0.0,
            effect_std=0.0,
            times_seen=np.nan,
            n_selections=np.nan,
        )
    )

    return (
        pd.concat([filtered, wildtype_rows], ignore_index=True)
        .sort_values(['site', 'mutant'])
        .reset_index(drop=True)
    )
In [3]:
def kl_divergence(p, q):
    """Return the Kullback-Leibler divergence KL(p || q) in nats.

    Computed as the sum of ``p * log(p / q)`` over every element; ``p`` and
    ``q`` are expected to be equally shaped arrays (normalized distributions
    for a true KL divergence).
    """
    pointwise = p * np.log(p / q)
    return np.sum(pointwise)

def compute_js_divergence_per_site(df, ha_x, ha_y, site_col="struct_site", min_mutations=15):
    """Annotate every row with the Jensen-Shannon divergence of its site.

    Within each ``site_col`` group, mutations having both ``{ha_x}_effect`` and
    ``{ha_y}_effect`` are exponentiated and normalized into two probability
    vectors, and the JS divergence between them is computed with
    ``kl_divergence``. Sites with fewer than ``min_mutations`` such mutations
    get NaN. The per-site value is repeated on every row of that site in a new
    column ``JS_{ha_x}_vs_{ha_y}`` of a copy of ``df``; ``df`` is not modified.
    """
    x_col = f'{ha_x}_effect'
    y_col = f'{ha_y}_effect'

    def site_divergence(site_rows):
        complete = site_rows.dropna(subset=[x_col, y_col])
        if len(complete) < min_mutations:
            return np.nan
        p = np.exp(complete[x_col].values)
        q = np.exp(complete[y_col].values)
        p = p / p.sum()
        q = q / q.sum()
        midpoint = 0.5 * (p + q)
        return 0.5 * (kl_divergence(p, midpoint) + kl_divergence(q, midpoint))

    divergence_by_site = {
        site: site_divergence(site_rows)
        for site, site_rows in df.groupby(site_col)
    }

    return df.assign(
        **{f"JS_{ha_x}_vs_{ha_y}": df[site_col].map(divergence_by_site)}
    )
In [4]:
# Column order of the frame returned by compute_jsd_with_null.
_JSD_NULL_RESULT_COLUMNS = [
    'struct_site', 'JS_observed', 'JS_null_mean', 'JS_null_std',
    'p_value', 'n_mutations', 'null_distribution',
]


def _softmax_rows(effects):
    """Exponentiate ``effects`` and normalize along the last axis so each row sums to 1."""
    weights = np.exp(effects)
    return weights / weights.sum(axis=-1, keepdims=True)


def _jsd_rows(p, q):
    """Jensen-Shannon divergence (natural log) between matching rows of ``p`` and ``q``.

    1-D inputs give a scalar; 2-D inputs give one value per row.
    """
    m = 0.5 * (p + q)
    kl_p_m = np.sum(p * np.log(p / m), axis=-1)
    kl_q_m = np.sum(q * np.log(q / m), axis=-1)
    return 0.5 * (kl_p_m + kl_q_m)


def _simulate_null_jsd(effects, std, n_bootstrap, rng):
    """Null JSDs between two Gaussian-noise replicates of the same ``effects``.

    Each of the ``n_bootstrap`` rows draws two independent replicates with
    per-mutation mean ``effects`` and standard deviation ``std`` and returns the
    JSD between their normalized exponentiated values.
    """
    size = (n_bootstrap, len(effects))
    replicate_1 = rng.normal(loc=effects, scale=std, size=size)
    replicate_2 = rng.normal(loc=effects, scale=std, size=size)
    return _jsd_rows(_softmax_rows(replicate_1), _softmax_rows(replicate_2))


def compute_jsd_with_null(
    df,
    ha_x,
    ha_y,
    site_col="struct_site",
    min_mutations=15,
    n_bootstrap=1000,
    random_seed=42,
    jsd_threshold=0.02,
    use_pooled_std=False
):
    """
    Compute JS divergence with simulated null distribution for significance testing.

    The null distribution represents: "What JSD would I observe from measurement noise alone?"

    The null is generated by simulating data under Gaussian measurement noise:
    1. ha_x null: Generate two replicate datasets from ha_x effects and stds, compute JSD
    2. ha_y null: Generate two replicate datasets from ha_y effects and stds, compute JSD
    3. Take the element-wise mean of the two null distributions

    This assumes measurement noise is normally distributed. A significant result means the
    observed JSD is larger than expected from measurement noise alone.

    Only sites with observed JSD > jsd_threshold are tested for significance
    (a NaN observed JSD is never tested).

    The p-value uses the add-one estimator
    ``(1 + #{null >= observed}) / (n_bootstrap + 1)``, so it is never exactly 0
    (the smallest attainable value is ``1 / (n_bootstrap + 1)``).

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe with ``{ha}_effect`` and ``{ha}_effect_std`` columns for both HAs.
    ha_x, ha_y : str
        HA names used to build column names (e.g., 'h3', 'h5').
    site_col : str
        Column to group by. Its values are reported in the ``struct_site``
        output column regardless of this name.
    min_mutations : int
        Minimum number of mutations with complete effects and stds required
        at a site; sites with fewer are omitted from the output.
    n_bootstrap : int
        Number of simulations for the null distribution (must be >= 1).
    random_seed : int
        Seed for a local ``numpy.random.Generator``; NumPy's global random
        state is not touched.
    jsd_threshold : float
        Minimum JSD value for a site to be tested for significance.
        Sites with observed JSD <= threshold will have p_value = NaN.
        Default is 0.02.
    use_pooled_std : bool
        If True, use a per-site pooled standard deviation (RMS of per-mutation stds)
        instead of each mutation's individual effect_std when generating the null
        distribution. This pools all per-mutation estimates at the site instead of
        relying on each mutation's own estimate.
        Default is False (use per-mutation stds).

    Returns
    -------
    pd.DataFrame
        DataFrame with columns:
        - struct_site: site identifier
        - JS_observed: observed JSD value
        - JS_null_mean: mean of null distribution (NaN if not tested)
        - JS_null_std: standard deviation of null distribution (NaN if not tested)
        - p_value: empirical add-one p-value (NaN if not tested)
        - n_mutations: number of mutations at site
        - null_distribution: array of null JSDs (None if not tested)
        Sorted by struct_site using natural sorting. If no site has enough
        mutations, an empty frame with these columns is returned.

    Raises
    ------
    ValueError
        If ``n_bootstrap < 1``.
    """
    if n_bootstrap < 1:
        raise ValueError(f"n_bootstrap must be at least 1, got {n_bootstrap}")

    # A local Generator keeps the run reproducible without reseeding (and so
    # altering) NumPy's global random state for the rest of the notebook.
    rng = np.random.default_rng(random_seed)

    x_col, y_col = f'{ha_x}_effect', f'{ha_y}_effect'
    x_std_col, y_std_col = f'{ha_x}_effect_std', f'{ha_y}_effect_std'

    results = []
    for site, group in df.groupby(site_col):
        # Only mutations with both effects and both stds can be used.
        valid = group.dropna(subset=[x_col, y_col, x_std_col, y_std_col])
        if len(valid) < min_mutations:
            continue

        effects_x = valid[x_col].to_numpy()
        effects_y = valid[y_col].to_numpy()
        std_x = valid[x_std_col].to_numpy()
        std_y = valid[y_std_col].to_numpy()

        jsd_obs = _jsd_rows(_softmax_rows(effects_x), _softmax_rows(effects_y))

        row = {
            'struct_site': site,
            'JS_observed': jsd_obs,
            'JS_null_mean': np.nan,
            'JS_null_std': np.nan,
            'p_value': np.nan,
            'n_mutations': len(valid),
            'null_distribution': None,
        }

        # Comparison is False for NaN, so an undefined observed JSD is not tested.
        if jsd_obs > jsd_threshold:
            if use_pooled_std:
                # RMS of the per-mutation stds, applied to every mutation at the site.
                # np.full (not full_like) so an integer std column is not truncated.
                std_x = np.full(std_x.shape, np.sqrt(np.mean(std_x ** 2)))
                std_y = np.full(std_y.shape, np.sqrt(np.mean(std_y ** 2)))

            null_x = _simulate_null_jsd(effects_x, std_x, n_bootstrap, rng)
            null_y = _simulate_null_jsd(effects_y, std_y, n_bootstrap, rng)
            jsd_null = (null_x + null_y) / 2

            # One-tailed add-one estimate: a finite simulation cannot justify p = 0.
            n_exceed = np.count_nonzero(jsd_null >= jsd_obs)
            row.update(
                JS_null_mean=jsd_null.mean(),
                JS_null_std=jsd_null.std(),
                p_value=(n_exceed + 1) / (n_bootstrap + 1),
                null_distribution=jsd_null,
            )

        results.append(row)

    results_df = pd.DataFrame(results, columns=_JSD_NULL_RESULT_COLUMNS)
    if results_df.empty:
        # Nothing to sort; return the empty frame with its full column set.
        return results_df

    return (
        results_df
        .sort_values('struct_site', key=natsort_keygen())
        .reset_index(drop=True)
    )
In [5]:
def apply_fdr_with_threshold(df):
    """Apply FDR correction only to non-NaN p-values."""
    # Initialize q_value column with NaN
    df['q_value'] = np.nan
    
    # Get indices of non-NaN p-values
    tested_mask = df['p_value'].notna()
    
    if tested_mask.sum() > 0:
        # Apply FDR correction only to tested sites
        df.loc[tested_mask, 'q_value'] = false_discovery_control(df.loc[tested_mask, 'p_value'])
    
    return df
In [6]:
def plot_jsd(df, jsd_pvals_df, ha_x, ha_y, identity_df=None, alpha=0.1, only_lineplot=False, variant_selector=None): 
    """
    Plot per-site JSD with significance coloring, plus linked detail panels.

    The top panel is a line + point chart of ``JS_{ha_x}_vs_{ha_y}`` by site,
    with points colored by ``q_value < alpha``. When ``only_lineplot is False``
    it is stacked above (1) a scatter of ``{ha_x}_effect`` vs ``{ha_y}_effect``
    filtered to the site(s) chosen by ``variant_selector`` and (2) a density of
    the per-site JSD values. Neither input dataframe is modified.

    Parameters
    ----------
    df : pd.DataFrame
        Per-mutation dataframe with ``struct_site``, ``mutant``,
        ``{ha}_wt_aa``, ``{ha}_effect``, ``{ha}_effect_std``,
        ``JS_{ha_x}_vs_{ha_y}`` and ``rmsd_{ha_x}{ha_y}`` columns.
    jsd_pvals_df : pd.DataFrame
        Output of ``compute_jsd_with_null`` after ``apply_fdr_with_threshold``;
        needs ``struct_site`` and ``q_value``.
    ha_x, ha_y : str
        HA names used to build column names (e.g. 'h5_2-3').
    identity_df : pd.DataFrame, optional
        Table with ``ha_x``, ``ha_y`` (upper-case) and ``percent_identity``
        columns; if a matching row exists the identity is added to the title.
    alpha : float
        Significance threshold for q-value (default 0.1).
    only_lineplot : bool
        If anything other than ``False``, return only the line + points panel.
    variant_selector : alt.selection_point, optional
        Shared selection object for synchronized highlighting across plots. A
        mouseover selection on ``struct_site`` is created if omitted.

    Returns
    -------
    alt.LayerChart or alt.VConcatChart
        The titled chart.
    """
    if identity_df is not None:
        result = identity_df.query(
            f'ha_x=="{ha_x.upper()}" and ha_y=="{ha_y.upper()}"'
        )
        shared_aai = result['percent_identity'].values[0] if len(result) > 0 else None
    else:
        shared_aai = None

    amino_acid_classification = {
        'F': 'Aromatic', 'Y': 'Aromatic', 'W': 'Aromatic',
        'N': 'Hydrophilic', 'Q': 'Hydrophilic', 'S': 'Hydrophilic', 'T': 'Hydrophilic',
        'A': 'Hydrophobic', 'V': 'Hydrophobic', 'I': 'Hydrophobic', 'L': 'Hydrophobic', 'M': 'Hydrophobic',
        'D': 'Negative', 'E': 'Negative',
        'R': 'Positive', 'H': 'Positive', 'K': 'Positive',
        'C': 'Special', 'G': 'Special', 'P': 'Special'
    }

    js_col = f'JS_{ha_x}_vs_{ha_y}'
    rmsd_col = f'rmsd_{ha_x}{ha_y}'

    # assign() returns a new frame, so plotting never alters the caller's dataframe.
    df = df.assign(
        struct_site=lambda x: x['struct_site'].astype(str),
        mutant_type=lambda x: x['mutant'].map(amino_acid_classification),
    )

    # One row per site with its JSD and RMSD
    site_jsd_df = df[[
        'struct_site', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa', js_col, rmsd_col
    ]].dropna().drop_duplicates()

    # Cast the merge key to str on this side too, so an int-typed struct_site
    # in jsd_pvals_df cannot make the merge fail on mismatched key dtypes.
    site_q_values = jsd_pvals_df[['struct_site', 'q_value']].assign(
        struct_site=lambda x: x['struct_site'].astype(str)
    )
    site_jsd_df = site_jsd_df.merge(site_q_values, on='struct_site', how='left')

    # Sites without a q-value (untested) compare False and count as not significant
    site_jsd_df = site_jsd_df.assign(
        significant=lambda x: x['q_value'] < alpha
    )

    # Use provided selector or create a new one
    if variant_selector is None:
        # NOTE(review): struct_site is a str column here, so the int initial
        # value 1 may never match a point — confirm whether '1' was intended.
        variant_selector = alt.selection_point(
            on="mouseover", empty=False, nearest=True, fields=["struct_site"], value=1
        )

    sorted_sites = natsorted(df['struct_site'].unique())
    base = alt.Chart(site_jsd_df).encode(
        alt.X(
            "struct_site:O",
            sort=sorted_sites, 
            title='Site',
            axis=alt.Axis(
                labelAngle=0,
                values=['1', '50', '100', '150', '200', '250', '300', '350', '400', '450', '500'],
                tickCount=11,
            )
        ),
        alt.Y(
            f'{js_col}:Q', 
            title=['Divergence in amino-acid', 'preferences'],
            axis=alt.Axis(
                grid=False
            ),
            scale=alt.Scale(domain=[0, 0.7])
        ),
        tooltip=[
            'struct_site', 
            f'{ha_x}_wt_aa', 
            f'{ha_y}_wt_aa', 
            alt.Tooltip(js_col, format='.4f'),
            alt.Tooltip(rmsd_col, format='.2f'),
            alt.Tooltip('q_value', format='.4f'),
            'significant'
        ],
    ).properties(
        width=800,
        height=150
    )

    line = base.mark_line(opacity=0.5, stroke='#999999', size=1)
    
    # Points layer with conditional formatting based on hover and click
    points = base.mark_circle(filled=True).encode(
        size=alt.condition(
            variant_selector,
            alt.value(75),  # when selected
            alt.value(40)  # default
        ),
        color=alt.Color(
            'significant:N',
            title=['Significant', f'(FDR < {alpha})'],
            scale=alt.Scale(domain=[True, False], range=['#E15759', '#BAB0AC']),
            legend=alt.Legend(
                titleFontSize=14,
                labelFontSize=12
            )
        ),
        stroke=alt.condition(
            variant_selector,
            alt.value('black'),
            alt.value(None)
        ),
        strokeWidth=alt.condition(
            variant_selector,
            alt.value(1),
            alt.value(0)
        ),
        opacity=alt.condition(
            variant_selector,
            alt.value(1),
            alt.value(0.75)
        )
    ).add_params(
        variant_selector
    )

    # Correlation between cell entry effects plot (mutations at the selected site)
    base_corr_chart = (alt.Chart(df)
        .mark_text(size=20)
        .encode(
            alt.X(
                f"{ha_x}_effect", 
                title=["Effect on cell entry", f"in {ha_x.upper()}"], 
                scale=alt.Scale(domain=[-6,1.5])
            ),
            alt.Y(
                f"{ha_y}_effect", 
                title=["Effect on cell entry", f"in {ha_y.upper()}"], 
                scale=alt.Scale(domain=[-6,1.5])
            ),
            alt.Text('mutant'),
            alt.Color('mutant_type',
                    title='Mutant type',
                    scale=alt.Scale(
                        domain=['Aromatic', 'Hydrophilic', 'Hydrophobic','Negative', 'Positive', 'Special'],
                        range=["#4e79a7","#f28e2c","#e15759","#76b7b2","#59a14f","#edc949"]
                    ),
                    legend=alt.Legend(
                        titleFontSize=16,
                        labelFontSize=13
                    )
            ),
            tooltip=['struct_site', 'mutant', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa', 
                     f'{ha_x}_effect', f'{ha_x}_effect_std', 
                     f'{ha_y}_effect', f'{ha_y}_effect_std',
                    js_col],  
        )
        .transform_filter(
            variant_selector
        )
        .properties(
            height=150,
            width=150,
        )
    )

    # Vertical line at x = 0
    vline = alt.Chart(pd.DataFrame({'x': [0]})).mark_rule(color='gray',opacity=0.5,strokeDash=[2,4]).encode(x='x:Q')
    
    # Horizontal line at y = 0
    hline = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(color='gray',opacity=0.5,strokeDash=[2,4]).encode(y='y:Q')
    
    corr_chart = vline + hline + base_corr_chart

    # Density of the per-site JSD values
    density = alt.Chart(
        site_jsd_df
    ).transform_density(
        density=js_col,
        bandwidth=0.02,
        extent=[0,1],
        counts=True,
        steps=200
    ).mark_area(opacity=1, color='#CCEBC5', stroke='black', strokeWidth=1).encode(
        alt.X('value:Q', title=['Divergence in amino-acid', 'preferences']),
        alt.Y('density:Q', title='Density').stack(None),
    ).properties(
        width=200,
        height=60
    )

    if shared_aai is not None:
        title_text = f'{ha_x.upper()} vs. {ha_y.upper()} ({shared_aai:.1f}% Amino Acid Identity)'
    else:
        title_text = f'{ha_x.upper()} vs. {ha_y.upper()}'

    # Stack the line plot above the scatter and density panels
    if only_lineplot is False:
        combined_chart = alt.vconcat(
            (line + points), corr_chart, density
        ).resolve_scale(
            y='independent', 
            x='independent', 
            color='independent'
        )
    else:
        combined_chart = line + points
    
    combined_chart = combined_chart.properties(
        title=alt.Title(title_text, 
        offset=0,
        fontSize=18,
        subtitleFontSize=16,
        anchor='middle'
        )
    )

    return combined_chart

H7 α2,6 vs. H7 α2,3

In [7]:
# Load the two H7 cell-entry datasets (293_2-3 and 293_2-6 files), keep the
# per-mutation columns, and prefix the condition-specific ones.
h7_23_raw = read_and_filter_data('../data/cell_entry_effects/293_2-3_entry_func_effects.csv')
h7_23_df = h7_23_raw[['site', 'wildtype', 'mutant', 'effect', 'effect_std']].rename(
    columns={
        'site': 'struct_site',
        'wildtype': 'h7_2-3_wt_aa',
        'mutant': 'mutant',
        'effect': 'h7_2-3_effect',
        'effect_std': 'h7_2-3_effect_std',
    }
)

h7_26_raw = read_and_filter_data('../data/cell_entry_effects/293_2-6_entry_func_effects.csv')
h7_26_df = h7_26_raw[['site', 'wildtype', 'mutant', 'effect', 'effect_std']].rename(
    columns={
        'site': 'struct_site',
        'wildtype': 'h7_2-6_wt_aa',
        'mutant': 'mutant',
        'effect': 'h7_2-6_effect',
        'effect_std': 'h7_2-6_effect_std',
    }
)

# Inner join on site and mutant; matching the two wildtype columns also requires
# both datasets to agree on the wildtype residue. plot_jsd selects an
# rmsd_<x><y> column, so a constant 0 is supplied (both datasets are H7).
h7_23_26_df = (
    h7_23_df
    .merge(
        h7_26_df,
        left_on=['struct_site', 'h7_2-3_wt_aa', 'mutant'],
        right_on=['struct_site', 'h7_2-6_wt_aa', 'mutant'],
    )
    .assign(**{'rmsd_h7_2-3h7_2-6': 0})
)

h7_23_26_df.head()
Reading data from ../data/cell_entry_effects/293_2-3_entry_func_effects.csv
Filtering for:
  effect_std <= 2
  times_seen >= 2
  n_selections >= 2
Clipping effect values at -5
Reading data from ../data/cell_entry_effects/293_2-6_entry_func_effects.csv
Filtering for:
  effect_std <= 2
  times_seen >= 2
  n_selections >= 2
Clipping effect values at -5
Out[7]:
struct_site h7_2-3_wt_aa mutant h7_2-3_effect h7_2-3_effect_std h7_2-6_wt_aa h7_2-6_effect h7_2-6_effect_std rmsd_h7_2-3h7_2-6
0 100 G A -0.00515 0.85400 G -1.276 0.719 0
1 100 G C -3.90900 0.01169 G -4.422 0.000 0
2 100 G D -4.78700 0.00000 G -4.936 0.000 0
3 100 G G 0.00000 0.00000 G 0.000 0.000 0
4 100 G H -4.63900 0.00000 G -4.796 0.000 0
In [8]:
js_df_h7_23_26 = compute_js_divergence_per_site(
    h7_23_26_df, 'h7_2-3', 'h7_2-6', min_mutations=10
)

# Simulated-null test per site, then FDR correction across the tested sites
jsd_with_pvals_h7_23_26 = apply_fdr_with_threshold(
    compute_jsd_with_null(
        js_df_h7_23_26,
        'h7_2-3', 'h7_2-6',
        min_mutations=10,
        n_bootstrap=1000,
        jsd_threshold=0.02,
    )
)

# Fraction significant, out of every site that has a JSD row (tested or not)
total_h7_23_26 = len(jsd_with_pvals_h7_23_26)
sig_h7_23_26 = jsd_with_pvals_h7_23_26['q_value'].lt(0.1).sum()
print("Significant sites (H7 2-3 vs H7 2-6, q < 0.1):")
print(f"  {sig_h7_23_26} / {total_h7_23_26} sites ({sig_h7_23_26/total_h7_23_26:.2%})")
Significant sites (H7 2-3 vs H7 2-6, q < 0.1):
  0 / 492 sites (0.00%)
In [9]:
chart = plot_jsd(
    df=js_df_h7_23_26,
    jsd_pvals_df=jsd_with_pvals_h7_23_26,
    ha_x='h7_2-3',
    ha_y='h7_2-6',
)
chart.display()

H5 α2,6 vs. H5 α2,3

In [10]:
# Load the two H5 cell-entry datasets (293_SA23 and 293_SA26 files), keep the
# per-mutation columns, and prefix the condition-specific ones.
h5_23_raw = read_and_filter_data('../data/cell_entry_effects/293_SA23_entry_func_effects.csv')
h5_23_df = h5_23_raw[['site', 'wildtype', 'mutant', 'effect', 'effect_std']].rename(
    columns={
        'site': 'struct_site',
        'wildtype': 'h5_2-3_wt_aa',
        'mutant': 'mutant',
        'effect': 'h5_2-3_effect',
        'effect_std': 'h5_2-3_effect_std',
    }
)

h5_26_raw = read_and_filter_data('../data/cell_entry_effects/293_SA26_entry_func_effects.csv')
h5_26_df = h5_26_raw[['site', 'wildtype', 'mutant', 'effect', 'effect_std']].rename(
    columns={
        'site': 'struct_site',
        'wildtype': 'h5_2-6_wt_aa',
        'mutant': 'mutant',
        'effect': 'h5_2-6_effect',
        'effect_std': 'h5_2-6_effect_std',
    }
)

# Inner join on site and mutant; matching the two wildtype columns also requires
# both datasets to agree on the wildtype residue. plot_jsd selects an
# rmsd_<x><y> column, so a constant 0 is supplied (both datasets are H5).
h5_23_26_df = (
    h5_23_df
    .merge(
        h5_26_df,
        left_on=['struct_site', 'h5_2-3_wt_aa', 'mutant'],
        right_on=['struct_site', 'h5_2-6_wt_aa', 'mutant'],
    )
    .assign(**{'rmsd_h5_2-3h5_2-6': 0})
)

h5_23_26_df.head()
Reading data from ../data/cell_entry_effects/293_SA23_entry_func_effects.csv
Filtering for:
  effect_std <= 2
  times_seen >= 2
  n_selections >= 2
Clipping effect values at -5
Reading data from ../data/cell_entry_effects/293_SA26_entry_func_effects.csv
Filtering for:
  effect_std <= 2
  times_seen >= 2
  n_selections >= 2
Clipping effect values at -5
Out[10]:
struct_site h5_2-3_wt_aa mutant h5_2-3_effect h5_2-3_effect_std h5_2-6_wt_aa h5_2-6_effect h5_2-6_effect_std rmsd_h5_2-3h5_2-6
0 -1 L A -2.5730 1.0460 L -1.2890 1.0410 0
1 -1 L C 0.6595 0.1659 L 0.4317 0.3278 0
2 -1 L D -5.0000 0.0000 L -2.8930 1.6370 0
3 -1 L E -2.2900 0.2044 L -1.9370 0.7707 0
4 -1 L F 0.2525 0.2299 L -0.9021 1.4380 0
In [11]:
js_df_h5_23_26 = compute_js_divergence_per_site(
    h5_23_26_df, 'h5_2-3', 'h5_2-6', min_mutations=10
)

# Simulated-null test per site, then FDR correction across the tested sites
jsd_with_pvals_h5_23_26 = apply_fdr_with_threshold(
    compute_jsd_with_null(
        js_df_h5_23_26,
        'h5_2-3', 'h5_2-6',
        min_mutations=10,
        n_bootstrap=1000,
        jsd_threshold=0.02,
    )
)

# Fraction significant, out of every site that has a JSD row (tested or not)
total_h5_23_26 = len(jsd_with_pvals_h5_23_26)
sig_h5_23_26 = jsd_with_pvals_h5_23_26['q_value'].lt(0.1).sum()
print("Significant sites (H5 2-3 vs H5 2-6, q < 0.1):")
print(f"  {sig_h5_23_26} / {total_h5_23_26} sites ({sig_h5_23_26/total_h5_23_26:.2%})")
Significant sites (H5 2-3 vs H5 2-6, q < 0.1):
  15 / 549 sites (2.73%)
In [12]:
chart = plot_jsd(
    df=js_df_h5_23_26,
    jsd_pvals_df=jsd_with_pvals_h5_23_26,
    ha_x='h5_2-3',
    ha_y='h5_2-6',
)
chart.display()
In [13]:
# Precomputed cross-subtype divergence tables from ../results/divergence
js_df_h3_h5, js_df_h3_h7, js_df_h5_h7 = (
    pd.read_csv(csv_path)
    for csv_path in (
        '../results/divergence/h3_h5_divergence.csv',
        '../results/divergence/h3_h7_divergence.csv',
        '../results/divergence/h5_h7_divergence.csv',
    )
)
In [14]:
def plot_ridgeline_density(dfs_dict, x_col_template='JS_{ha_x}_vs_{ha_y}', 
                           bandwidth=0.02, extent=[0,0.65], 
                           colors=None, width=200, height=400,
                           overlap=2.5, label_mapping=None, site_col='struct_site'):
    """
    Plot ridgeline (joyplot) densities of per-site divergence for several comparisons.

    Parameters
    ----------
    dfs_dict : dict
        Dictionary where keys are comparison labels (e.g., 'h3-h5', 'h3-h7') 
        and values are tuples of (df, ha_x, ha_y)
        Example: {'h3-h5': (js_df_h3_h5, 'h3', 'h5'), 
                  'h3-h7': (js_df_h3_h7, 'h3', 'h7')}
    x_col_template : str
        Template for column name with {ha_x} and {ha_y} placeholders
    bandwidth : float
        Bandwidth for the density transform
    extent : list
        [min, max] for density calculation (not modified)
    colors : list or None
        Fill colors, one per comparison. If None, uses a 6-color default scheme
    width, height : int
        Width of each ridge, and total height used to size each ridge
    overlap : float
        How much the ridges overlap (higher = more overlap)
    label_mapping : dict or None
        Optional map from ``dfs_dict`` keys to display labels (strings or lists
        of strings). Keys missing from the map keep their original label.
    site_col : str
        When a dataframe has this column, it is reduced to one row per site
        before the density is estimated.

    Returns
    -------
    alt.Chart : Ridgeline density plot, faceted by comparison

    Raises
    ------
    ValueError
        If ``dfs_dict`` is empty.
    """
    if not dfs_dict:
        raise ValueError("dfs_dict must contain at least one comparison")

    # Default color scheme if none provided
    if colors is None:
        colors = ['#8DD3C7', '#FFFFB3', '#BEBADA', '#FB8072', '#80B1D3', '#FDB462']

    # Long-format table: one (value, comparison) row per site and comparison
    per_comparison = []
    for comparison, (df, ha_x, ha_y) in dfs_dict.items():
        col_name = x_col_template.format(ha_x=ha_x, ha_y=ha_y)
        # Per-mutation frames repeat a site's JS value on every row (see
        # compute_js_divergence_per_site), which would weight sites by their
        # mutation count; keep one row per site and drop missing values.
        site_values = df.dropna(subset=[col_name])
        if site_col in site_values.columns:
            site_values = site_values.drop_duplicates(subset=[site_col])
        per_comparison.append(pd.DataFrame({
            'value': site_values[col_name].to_numpy(),
            'comparison': comparison,
        }))

    combined_df = pd.concat(per_comparison, ignore_index=True)

    if label_mapping is not None:
        # Fall back to the raw key rather than producing NaN for unmapped comparisons
        combined_df['comparison'] = combined_df['comparison'].map(
            lambda key: label_mapping.get(key, key)
        )

    # Height of each ridge; negative facet spacing below makes ridges overlap
    step = height / (len(dfs_dict) * overlap)

    ridgeline = alt.Chart(combined_df).transform_density(
        density='value',
        bandwidth=bandwidth,
        extent=extent,
        groupby=['comparison'],
        steps=200
    ).mark_area(
        opacity=1,
        stroke='black',
        strokeWidth=1,
        interpolate='monotone'
    ).encode(
        alt.X('value:Q', title=['Divergence in amino-acid', 'preferences']),
        alt.Y('density:Q', 
              title='Density',
              axis=None),
        alt.Row('comparison:N',
                title=None,
                header=alt.Header(labelAngle=0, labelAlign='left')),
        alt.Fill('comparison:N',
                 legend=None,
                 scale=alt.Scale(range=colors[:len(dfs_dict)]))
    ).properties(
        width=width,
        height=step,
        bounds='flush'
    ).configure_facet(
        spacing=-(step * (overlap - 1))
    ).configure_view(
        stroke=None
    ).configure_header(
        labelFontSize=14
    )

    return ridgeline

# Ridgeline of per-site divergence: three cross-subtype comparisons plus the
# H7 a2,3 vs. a2,6 comparison computed above.
dfs_to_plot = {
    'h3-h5': (js_df_h3_h5, 'h3', 'h5'),
    'h3-h7': (js_df_h3_h7, 'h3', 'h7'),
    'h5-h7': (js_df_h5_h7, 'h5', 'h7'),
    'h7_2-3-h7_2-6': (js_df_h7_23_26, 'h7_2-3', 'h7_2-6'),
}

ridgeline_labels = {
    'h3-h5': 'H3 vs. H5',
    'h3-h7': 'H3 vs. H7',
    'h5-h7': 'H5 vs. H7',
    'h7_2-3-h7_2-6': ['H7 (a2,3) vs.', 'H7 (a2,6)'],
}

ridgeline_chart = plot_ridgeline_density(dfs_to_plot, label_mapping=ridgeline_labels)
ridgeline_chart.display()