In [1]:
import pandas as pd
import numpy as np
import altair as alt
import theme
from natsort import natsorted, natsort_keygen
from scipy.stats import false_discovery_control
alt.data_transformers.disable_max_rows()
Out[1]:
DataTransformerRegistry.enable('default')
In [2]:
def read_and_filter_data(
    path,
    effect_std_filter=2,
    times_seen_filter=2,
    n_selections_filter=2,
    clip_effect=-5
):
    """Read a mutation-effects CSV and apply quality filters.

    Parameters
    ----------
    path : str or file-like
        CSV with columns site, wildtype, mutant, effect, effect_std,
        times_seen, n_selections.
    effect_std_filter : numeric
        Keep rows with effect_std <= this value.
    times_seen_filter : numeric
        Keep rows with times_seen >= this value.
    n_selections_filter : numeric
        Keep rows with n_selections >= this value.
    clip_effect : numeric
        Lower bound at which effect values are clipped.

    Returns
    -------
    pd.DataFrame
        Filtered effects with stop codons ("*") and indels ("-") removed,
        site cast to str, one zero-effect wildtype row added per site, and
        rows sorted by site and mutant.
    """
    print(f'Reading data from {path}')
    print(
        f"Filtering for:\n"
        f" effect_std <= {effect_std_filter}\n"
        f" times_seen >= {times_seen_filter}\n"
        f" n_selections >= {n_selections_filter}"
    )
    print(f"Clipping effect values at {clip_effect}")
    df = pd.read_csv(path).query(
        'effect_std <= @effect_std_filter and \
        times_seen >= @times_seen_filter and \
        n_selections >= @n_selections_filter'
    ).query(
        'mutant not in ["*", "-"]'  # don't want stop codons/indels
    )
    df['site'] = df['site'].astype(str)
    # BUG FIX: previously clipped at a hard-coded -5, silently ignoring the
    # clip_effect parameter that the message above claims is being applied.
    df['effect'] = df['effect'].clip(lower=clip_effect)
    df = pd.concat([
        df,
        df[['site', 'wildtype']].drop_duplicates().assign(
            mutant=lambda x: x['wildtype'],
            effect=0.0,
            effect_std=0.0,
            times_seen=np.nan,
            n_selections=np.nan
        )  # add wildtype sites with zero effect
    ], ignore_index=True).sort_values(['site', 'mutant']).reset_index(drop=True)
    return df
In [3]:
def kl_divergence(p, q):
    """Kullback-Leibler divergence KL(p || q) between two discrete
    probability distributions given as numpy arrays."""
    ratio = p / q
    return (p * np.log(ratio)).sum()
def compute_js_divergence_per_site(df, ha_x, ha_y, site_col="struct_site", min_mutations=15):
    """Compute the Jensen-Shannon divergence between the two subtypes'
    amino-acid preferences at each site and attach it as a new column.

    Effects are exponentiated and normalized into per-site probability
    distributions. Sites with fewer than ``min_mutations`` mutations measured
    in both subtypes get NaN. Returns a copy of ``df`` with a column named
    ``JS_{ha_x}_vs_{ha_y}`` broadcast to every row of the same site.
    """
    x_col = f'{ha_x}_effect'
    y_col = f'{ha_y}_effect'
    site_to_jsd = {}
    for site, site_df in df.groupby(site_col):
        # Only mutations with a measured effect in both subtypes count.
        complete = site_df.dropna(subset=[x_col, y_col])
        if len(complete) < min_mutations:
            site_to_jsd[site] = np.nan
            continue
        # Exponentiate effects and normalize into probability vectors.
        p = np.exp(complete[x_col].values)
        q = np.exp(complete[y_col].values)
        p /= p.sum()
        q /= q.sum()
        m = 0.5 * (p + q)
        site_to_jsd[site] = 0.5 * (kl_divergence(p, m) + kl_divergence(q, m))
    # Broadcast each site's JSD to all rows at that site, without mutating
    # the caller's dataframe.
    result = df.copy()
    result[f"JS_{ha_x}_vs_{ha_y}"] = result[site_col].map(site_to_jsd)
    return result
In [4]:
def compute_jsd_with_null(
    df,
    ha_x,
    ha_y,
    site_col="struct_site",
    min_mutations=15,
    n_bootstrap=1000,
    random_seed=42,
    jsd_threshold=0.02,
    use_pooled_std=False
):
    """
    Compute JS divergence with simulated null distribution for significance testing.

    The null distribution represents: "What JSD would I observe from measurement noise alone?"

    The null is generated by simulating data under Gaussian measurement noise:
    1. ha_x null: Generate two replicate datasets from ha_x effects and stds, compute JSD
    2. ha_y null: Generate two replicate datasets from ha_y effects and stds, compute JSD
    3. Take the mean of the two null distributions

    This assumes measurement noise is normally distributed. A significant result means the
    observed JSD is larger than expected from measurement noise alone.

    Only sites with observed JSD > jsd_threshold are tested for significance.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe with mutation effects and effect_std columns
    ha_x, ha_y : str
        HA subtype names (e.g., 'h3', 'h5')
    site_col : str
        Column name for site identifier
    min_mutations : int
        Minimum number of mutations required at a site
    n_bootstrap : int
        Number of simulations for null distribution
    random_seed : int
        Random seed for reproducibility
    jsd_threshold : float
        Minimum JSD value for a site to be tested for significance.
        Sites with observed JSD <= threshold will have p_value = NaN.
        Default is 0.02.
    use_pooled_std : bool
        If True, use a per-site pooled standard deviation (RMS of per-mutation stds)
        instead of each mutation's individual effect_std when generating the null
        distribution. This gives a more stable variance estimate by pooling ~20
        estimates rather than relying on each mutation's 2-4 replicate estimate.
        Default is False (use per-mutation stds).

    Returns
    -------
    pd.DataFrame
        DataFrame with columns:
        - struct_site: site identifier
        - JS_observed: observed JSD value
        - JS_null_mean: mean of null distribution (NaN if below threshold)
        - JS_null_std: standard deviation of null distribution (NaN if below threshold)
        - p_value: empirical p-value (NaN if below threshold)
        - n_mutations: number of mutations at site
        - null_distribution: array of null JSD samples (None if below threshold)
        Sorted by struct_site using natural sorting.
        Sites with fewer than min_mutations valid mutations are dropped entirely
        (they do not appear in the output at all).
    """
    # Legacy global seeding; all np.random.normal draws below depend on it.
    np.random.seed(random_seed)

    def compute_jsd_vectorized(effects, std, n_bootstrap):
        """Vectorized computation of null JSD distribution.

        Draws two independent noisy "replicates" of the observed effects and
        returns an array of n_bootstrap JSD values between them.
        """
        n_mutations = len(effects)
        # Generate all simulated samples: shape (n_bootstrap, n_mutations)
        effects_1 = np.random.normal(
            loc=effects[np.newaxis, :],   # broadcast to (1, n_mutations)
            scale=std[np.newaxis, :],     # broadcast to (1, n_mutations)
            size=(n_bootstrap, n_mutations)
        )
        effects_2 = np.random.normal(
            loc=effects[np.newaxis, :],
            scale=std[np.newaxis, :],
            size=(n_bootstrap, n_mutations)
        )
        # Compute probabilities for all simulations
        p1 = np.exp(effects_1)
        p2 = np.exp(effects_2)
        # Normalize: divide each row by its sum
        p1 = p1 / p1.sum(axis=1, keepdims=True)
        p2 = p2 / p2.sum(axis=1, keepdims=True)
        # Compute mixture distribution
        m = 0.5 * (p1 + p2)
        # Compute KL divergences
        # KL(p||m) = sum(p * log(p/m))
        kl_p_m = np.sum(p1 * np.log(p1 / m), axis=1)
        kl_q_m = np.sum(p2 * np.log(p2 / m), axis=1)
        # JSD = 0.5 * (KL(p||m) + KL(q||m))
        jsd = 0.5 * (kl_p_m + kl_q_m)
        return jsd

    results = []
    for site, group in df.groupby(site_col):
        # Filter to valid mutations with both effects and stds
        valid = group.dropna(subset=[
            f'{ha_x}_effect', f'{ha_y}_effect',
            f'{ha_x}_effect_std', f'{ha_y}_effect_std'
        ])
        if len(valid) < min_mutations:
            # Too few mutations: site is omitted from the results entirely.
            continue
        # Get observed effects
        effects_x = valid[f'{ha_x}_effect'].values
        effects_y = valid[f'{ha_y}_effect'].values
        # Get standard deviations
        std_x = valid[f'{ha_x}_effect_std'].values
        std_y = valid[f'{ha_y}_effect_std'].values
        # Compute observed JSD between ha_x and ha_y
        p_obs = np.exp(effects_x)
        q_obs = np.exp(effects_y)
        p_obs /= p_obs.sum()
        q_obs /= q_obs.sum()
        m_obs = 0.5 * (p_obs + q_obs)
        jsd_obs = 0.5 * (kl_divergence(p_obs, m_obs) + kl_divergence(q_obs, m_obs))
        # Only compute null distribution if JSD exceeds threshold
        if jsd_obs <= jsd_threshold:
            results.append({
                'struct_site': site,
                'JS_observed': jsd_obs,
                'JS_null_mean': np.nan,
                'JS_null_std': np.nan,
                'p_value': np.nan,
                'n_mutations': len(valid),
                'null_distribution': None
            })
            continue
        # Replace per-mutation stds with the per-site std (RMS)
        if use_pooled_std:
            pooled_std_x = np.sqrt(np.mean(std_x ** 2))
            pooled_std_y = np.sqrt(np.mean(std_y ** 2))
            std_x = np.full_like(std_x, pooled_std_x)
            std_y = np.full_like(std_y, pooled_std_y)
        # Simulated null distributions
        jsd_null_x = compute_jsd_vectorized(effects_x, std_x, n_bootstrap)
        jsd_null_y = compute_jsd_vectorized(effects_y, std_y, n_bootstrap)
        # Take the mean of the two nulls (balanced approach)
        jsd_null = (jsd_null_x + jsd_null_y) / 2
        # Compute empirical p-value (one-tailed test: is observed JSD greater than null?)
        # NOTE(review): with a finite number of simulations this empirical
        # p-value can be exactly 0; the (1 + count) / (1 + n_bootstrap)
        # estimator would avoid that — confirm before changing, as it would
        # shift downstream q-values.
        p_value = np.mean(jsd_null >= jsd_obs)
        results.append({
            'struct_site': site,
            'JS_observed': jsd_obs,
            'JS_null_mean': jsd_null.mean(),
            'JS_null_std': jsd_null.std(),
            'p_value': p_value,
            'n_mutations': len(valid),
            'null_distribution': jsd_null  # full null sample kept per site
        })
    # Convert to DataFrame and sort by struct_site using natural sorting
    results_df = pd.DataFrame(results)
    results_df = results_df.sort_values('struct_site', key=natsort_keygen()).reset_index(drop=True)
    return results_df
In [5]:
def apply_fdr_with_threshold(df):
"""Apply FDR correction only to non-NaN p-values."""
# Initialize q_value column with NaN
df['q_value'] = np.nan
# Get indices of non-NaN p-values
tested_mask = df['p_value'].notna()
if tested_mask.sum() > 0:
# Apply FDR correction only to tested sites
df.loc[tested_mask, 'q_value'] = false_discovery_control(df.loc[tested_mask, 'p_value'])
return df
In [6]:
def plot_jsd(df, jsd_pvals_df, ha_x, ha_y, identity_df=None, alpha=0.1, only_lineplot=False, variant_selector=None):
    """
    Plot JSD values with significance coloring.

    Builds (1) a per-site JSD line+point track colored by FDR significance,
    (2) a per-mutation cell-entry-effect scatter filtered to the hovered
    site, and (3) a density of the per-site JSD values.

    Parameters
    ----------
    df : pd.DataFrame
        Main dataframe with mutation effects
    jsd_pvals_df : pd.DataFrame
        DataFrame with JSD p-values and q-values from compute_jsd_with_null
    identity_df : pd.DataFrame
        DataFrame with sequence identity information
    ha_x, ha_y : str
        HA subtype names
    alpha : float
        Significance threshold for q-value (default 0.1)
    only_lineplot : bool
        If True, return only the line+point track and skip the scatter and
        density panels.
    variant_selector : alt.selection_point, optional
        Shared selection object for synchronized highlighting across plots
    """
    # Percent amino-acid identity for the title, if an identity table is given.
    if identity_df is not None:
        result = identity_df.query(
            f'ha_x=="{ha_x.upper()}" and ha_y=="{ha_y.upper()}"'
        )
        shared_aai = result['percent_identity'].values[0] if len(result) > 0 else None
    else:
        shared_aai = None
    # Coarse physicochemical classes used to color mutant letters in the scatter.
    amino_acid_classification = {
        'F': 'Aromatic', 'Y': 'Aromatic', 'W': 'Aromatic',
        'N': 'Hydrophilic', 'Q': 'Hydrophilic', 'S': 'Hydrophilic', 'T': 'Hydrophilic',
        'A': 'Hydrophobic', 'V': 'Hydrophobic', 'I': 'Hydrophobic', 'L': 'Hydrophobic', 'M': 'Hydrophobic',
        'D': 'Negative', 'E': 'Negative',
        'R': 'Positive', 'H': 'Positive', 'K': 'Positive',
        'C': 'Special', 'G': 'Special', 'P': 'Special'
    }
    # NOTE(review): this line mutates the caller's dataframe in place (casts
    # 'struct_site' to str) before the .assign below makes a copy — confirm
    # whether that side effect is intended.
    df['struct_site'] = df['struct_site'].astype(str)
    df = df.assign(
        mutant_type=lambda x: x['mutant'].map(amino_acid_classification)
    )
    # Merge significance data with site-level JSD data
    site_jsd_df = df[[
        'struct_site', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa',
        f'JS_{ha_x}_vs_{ha_y}', f'rmsd_{ha_x}{ha_y}'
    ]].dropna().drop_duplicates()
    # Merge q-values
    site_jsd_df = site_jsd_df.merge(
        jsd_pvals_df[['struct_site', 'q_value']],
        on='struct_site',
        how='left'
    )
    # Add significance flag (NaN q-values compare False, i.e. not significant)
    site_jsd_df = site_jsd_df.assign(
        significant=lambda x: x['q_value'] < alpha
    )
    # Use provided selector or create a new one
    if variant_selector is None:
        variant_selector = alt.selection_point(
            on="mouseover", empty=False, nearest=True, fields=["struct_site"], value=1
        )
    # Natural sort so e.g. site '2' precedes '10' on the x axis.
    sorted_sites = natsorted(df['struct_site'].unique())
    # Shared x/y/tooltip encodings for the line and point layers.
    base = alt.Chart(site_jsd_df).encode(
        alt.X(
            "struct_site:O",
            sort=sorted_sites,
            title='Site',
            axis=alt.Axis(
                labelAngle=0,
                values=['1', '50', '100', '150', '200', '250', '300', '350', '400', '450', '500'],
                tickCount=11,
            )
        ),
        alt.Y(
            f'JS_{ha_x}_vs_{ha_y}:Q',
            title=['Divergence in amino-acid', 'preferences'],
            axis=alt.Axis(
                grid=False
            ),
            scale=alt.Scale(domain=[0, 0.7])
        ),
        tooltip=[
            'struct_site',
            f'{ha_x}_wt_aa',
            f'{ha_y}_wt_aa',
            alt.Tooltip(f'JS_{ha_x}_vs_{ha_y}', format='.4f'),
            alt.Tooltip(f'rmsd_{ha_x}{ha_y}', format='.2f'),
            alt.Tooltip('q_value', format='.4f'),
            'significant'
        ],
    ).properties(
        width=800,
        height=150
    )
    line = base.mark_line(opacity=0.5, stroke='#999999', size=1)
    # Points layer with conditional formatting based on hover and click
    points = base.mark_circle(filled=True).encode(
        size=alt.condition(
            variant_selector,
            alt.value(75),  # when selected
            alt.value(40)   # default
        ),
        color=alt.Color(
            'significant:N',
            title=['Significant', f'(FDR < {alpha})'],
            scale=alt.Scale(domain=[True, False], range=['#E15759', '#BAB0AC']),
            legend=alt.Legend(
                titleFontSize=14,
                labelFontSize=12
            )
        ),
        stroke=alt.condition(
            variant_selector,
            alt.value('black'),
            alt.value(None)
        ),
        strokeWidth=alt.condition(
            variant_selector,
            alt.value(1),
            alt.value(0)
        ),
        opacity=alt.condition(
            variant_selector,
            alt.value(1),
            alt.value(0.75)
        )
    ).add_params(
        variant_selector
    )
    # Correlation between cell entry effects plot; text marks show the mutant
    # amino-acid letter, filtered to the site currently selected by hover.
    base_corr_chart = (alt.Chart(df)
        .mark_text(size=20)
        .encode(
            alt.X(
                f"{ha_x}_effect",
                title=["Effect on cell entry", f"in {ha_x.upper()}"],
                scale=alt.Scale(domain=[-6,1.5])
            ),
            alt.Y(
                f"{ha_y}_effect",
                title=["Effect on cell entry", f"in {ha_y.upper()}"],
                scale=alt.Scale(domain=[-6,1.5])
            ),
            alt.Text('mutant'),
            alt.Color('mutant_type',
                title='Mutant type',
                scale=alt.Scale(
                    domain=['Aromatic', 'Hydrophilic', 'Hydrophobic','Negative', 'Positive', 'Special'],
                    range=["#4e79a7","#f28e2c","#e15759","#76b7b2","#59a14f","#edc949"]
                ),
                legend=alt.Legend(
                    titleFontSize=16,
                    labelFontSize=13
                )
            ),
            tooltip=['struct_site', 'mutant', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa',
                     f'{ha_x}_effect', f'{ha_x}_effect_std',
                     f'{ha_y}_effect', f'{ha_y}_effect_std',
                     f'JS_{ha_x}_vs_{ha_y}'],
        )
        .transform_filter(
            variant_selector
        )
        .properties(
            height=150,
            width=150,
        )
    )
    # Vertical line at x = 0
    vline = alt.Chart(pd.DataFrame({'x': [0]})).mark_rule(color='gray',opacity=0.5,strokeDash=[2,4]).encode(x='x:Q')
    # Horizontal line at y = 0
    hline = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(color='gray',opacity=0.5,strokeDash=[2,4]).encode(y='y:Q')
    corr_chart = vline + hline + base_corr_chart
    # Density plot of the per-site JSD values.
    density = alt.Chart(
        site_jsd_df
    ).transform_density(
        density=f'JS_{ha_x}_vs_{ha_y}',
        bandwidth=0.02,
        extent=[0,1],
        counts=True,
        steps=200
    ).mark_area(opacity=1, color='#CCEBC5', stroke='black', strokeWidth=1).encode(
        alt.X('value:Q', title=['Divergence in amino-acid', 'preferences']),
        alt.Y('density:Q', title='Density').stack(None),
    ).properties(
        width=200,
        height=60
    )
    if shared_aai is not None:
        title_text = f'{ha_x.upper()} vs. {ha_y.upper()} ({shared_aai:.1f}% Amino Acid Identity)'
    else:
        title_text = f'{ha_x.upper()} vs. {ha_y.upper()}'
    # Combine the bar and heatmaps
    if only_lineplot is False:
        combined_chart = alt.vconcat(
            (line + points), corr_chart, density
        ).resolve_scale(
            y='independent',
            x='independent',
            color='independent'
        )
    else:
        combined_chart = line + points
    combined_chart = combined_chart.properties(
        title=alt.Title(title_text,
            offset=0,
            fontSize=18,
            subtitleFontSize=16,
            anchor='middle'
        )
    )
    return combined_chart
H7 α2,6 vs. H7 α2,3
In [7]:
# Load H7 cell-entry effects from the "2-3" selection and rename columns to
# the {ha}_{field} convention the JSD helpers expect.
h7_23_df = read_and_filter_data('../data/cell_entry_effects/293_2-3_entry_func_effects.csv')[[
    'site', 'wildtype', 'mutant', 'effect', 'effect_std'
]].rename(
    columns={
        'site': 'struct_site',
        'wildtype': 'h7_2-3_wt_aa',
        'mutant': 'mutant',
        'effect': 'h7_2-3_effect',
        'effect_std': 'h7_2-3_effect_std'
    }
)
# Same for the "2-6" selection.
h7_26_df = read_and_filter_data('../data/cell_entry_effects/293_2-6_entry_func_effects.csv')[[
    'site', 'wildtype', 'mutant', 'effect', 'effect_std'
]].rename(
    columns={
        'site': 'struct_site',
        'wildtype': 'h7_2-6_wt_aa',
        'mutant': 'mutant',
        'effect': 'h7_2-6_effect',
        'effect_std': 'h7_2-6_effect_std'
    }
)
# Inner merge on site / wildtype / mutant keeps only mutations measured in
# both selections. The rmsd column is set to a constant 0 here — presumably
# because both datasets come from the same protein, so there is no
# structural comparison (confirm).
h7_23_26_df = pd.merge(
    h7_23_df,
    h7_26_df,
    left_on=['struct_site', 'h7_2-3_wt_aa', 'mutant'],
    right_on=['struct_site', 'h7_2-6_wt_aa', 'mutant'],
).assign(
    **{'rmsd_h7_2-3h7_2-6': 0}
)
h7_23_26_df.head()
Reading data from ../data/cell_entry_effects/293_2-3_entry_func_effects.csv Filtering for: effect_std <= 2 times_seen >= 2 n_selections >= 2 Clipping effect values at -5 Reading data from ../data/cell_entry_effects/293_2-6_entry_func_effects.csv Filtering for: effect_std <= 2 times_seen >= 2 n_selections >= 2 Clipping effect values at -5
Out[7]:
| struct_site | h7_2-3_wt_aa | mutant | h7_2-3_effect | h7_2-3_effect_std | h7_2-6_wt_aa | h7_2-6_effect | h7_2-6_effect_std | rmsd_h7_2-3h7_2-6 | |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 100 | G | A | -0.00515 | 0.85400 | G | -1.276 | 0.719 | 0 |
| 1 | 100 | G | C | -3.90900 | 0.01169 | G | -4.422 | 0.000 | 0 |
| 2 | 100 | G | D | -4.78700 | 0.00000 | G | -4.936 | 0.000 | 0 |
| 3 | 100 | G | G | 0.00000 | 0.00000 | G | 0.000 | 0.000 | 0 |
| 4 | 100 | G | H | -4.63900 | 0.00000 | G | -4.796 | 0.000 | 0 |
In [8]:
# Per-site divergence between the two H7 selections.
js_df_h7_23_26 = compute_js_divergence_per_site(h7_23_26_df, 'h7_2-3', 'h7_2-6', min_mutations=10)
# Compute JSD with null distributions for each comparison
jsd_with_pvals_h7_23_26 = compute_jsd_with_null(
    js_df_h7_23_26,
    'h7_2-3', 'h7_2-6',
    min_mutations=10,
    n_bootstrap=1000,
    jsd_threshold=0.02
)
# Benjamini-Hochberg correction over the tested (non-NaN p-value) sites.
jsd_with_pvals_h7_23_26 = apply_fdr_with_threshold(jsd_with_pvals_h7_23_26)
# Report significant sites as fractions (out of ALL sites with JSD measurements)
print("Significant sites (H7 2-3 vs H7 2-6, q < 0.1):")
total_h7_23_26 = len(jsd_with_pvals_h7_23_26)
sig_h7_23_26 = (jsd_with_pvals_h7_23_26['q_value'] < 0.1).sum()
print(f" {sig_h7_23_26} / {total_h7_23_26} sites ({sig_h7_23_26/total_h7_23_26:.2%})")
Significant sites (H7 2-3 vs H7 2-6, q < 0.1): 0 / 492 sites (0.00%)
In [9]:
# Interactive per-site JSD track + linked scatter for the H7 comparison.
chart = plot_jsd(
    js_df_h7_23_26,
    jsd_with_pvals_h7_23_26,
    'h7_2-3', 'h7_2-6'
)
chart.display()
H5 α2,6 vs. H5 α2,3
In [10]:
# Load H5 cell-entry effects from the SA23 selection; same column-renaming
# convention as the H7 cells above.
h5_23_df = read_and_filter_data('../data/cell_entry_effects/293_SA23_entry_func_effects.csv')[[
    'site', 'wildtype', 'mutant', 'effect', 'effect_std'
]].rename(
    columns={
        'site': 'struct_site',
        'wildtype': 'h5_2-3_wt_aa',
        'mutant': 'mutant',
        'effect': 'h5_2-3_effect',
        'effect_std': 'h5_2-3_effect_std'
    }
)
# Same for the SA26 selection.
h5_26_df = read_and_filter_data('../data/cell_entry_effects/293_SA26_entry_func_effects.csv')[[
    'site', 'wildtype', 'mutant', 'effect', 'effect_std'
]].rename(
    columns={
        'site': 'struct_site',
        'wildtype': 'h5_2-6_wt_aa',
        'mutant': 'mutant',
        'effect': 'h5_2-6_effect',
        'effect_std': 'h5_2-6_effect_std'
    }
)
# Inner merge keeps only mutations measured in both selections; constant 0
# rmsd, as for H7 above.
h5_23_26_df = pd.merge(
    h5_23_df,
    h5_26_df,
    left_on=['struct_site', 'h5_2-3_wt_aa', 'mutant'],
    right_on=['struct_site', 'h5_2-6_wt_aa', 'mutant'],
).assign(
    **{'rmsd_h5_2-3h5_2-6': 0}
)
h5_23_26_df.head()
Reading data from ../data/cell_entry_effects/293_SA23_entry_func_effects.csv Filtering for: effect_std <= 2 times_seen >= 2 n_selections >= 2 Clipping effect values at -5 Reading data from ../data/cell_entry_effects/293_SA26_entry_func_effects.csv Filtering for: effect_std <= 2 times_seen >= 2 n_selections >= 2 Clipping effect values at -5
Out[10]:
| struct_site | h5_2-3_wt_aa | mutant | h5_2-3_effect | h5_2-3_effect_std | h5_2-6_wt_aa | h5_2-6_effect | h5_2-6_effect_std | rmsd_h5_2-3h5_2-6 | |
|---|---|---|---|---|---|---|---|---|---|
| 0 | -1 | L | A | -2.5730 | 1.0460 | L | -1.2890 | 1.0410 | 0 |
| 1 | -1 | L | C | 0.6595 | 0.1659 | L | 0.4317 | 0.3278 | 0 |
| 2 | -1 | L | D | -5.0000 | 0.0000 | L | -2.8930 | 1.6370 | 0 |
| 3 | -1 | L | E | -2.2900 | 0.2044 | L | -1.9370 | 0.7707 | 0 |
| 4 | -1 | L | F | 0.2525 | 0.2299 | L | -0.9021 | 1.4380 | 0 |
In [11]:
# Per-site divergence between the two H5 selections.
js_df_h5_23_26 = compute_js_divergence_per_site(h5_23_26_df, 'h5_2-3', 'h5_2-6', min_mutations=10)
# Compute JSD with null distributions for each comparison
jsd_with_pvals_h5_23_26 = compute_jsd_with_null(
    js_df_h5_23_26,
    'h5_2-3', 'h5_2-6',
    min_mutations=10,
    n_bootstrap=1000,
    jsd_threshold=0.02
)
# Benjamini-Hochberg correction over the tested (non-NaN p-value) sites.
jsd_with_pvals_h5_23_26 = apply_fdr_with_threshold(jsd_with_pvals_h5_23_26)
# Report significant sites as fractions (out of ALL sites with JSD measurements)
print("Significant sites (H5 2-3 vs H5 2-6, q < 0.1):")
total_h5_23_26 = len(jsd_with_pvals_h5_23_26)
sig_h5_23_26 = (jsd_with_pvals_h5_23_26['q_value'] < 0.1).sum()
print(f" {sig_h5_23_26} / {total_h5_23_26} sites ({sig_h5_23_26/total_h5_23_26:.2%})")
Significant sites (H5 2-3 vs H5 2-6, q < 0.1): 15 / 549 sites (2.73%)
In [12]:
# Interactive per-site JSD track + linked scatter for the H5 comparison.
chart = plot_jsd(
    js_df_h5_23_26,
    jsd_with_pvals_h5_23_26,
    'h5_2-3', 'h5_2-6'
)
chart.display()
In [13]:
# Precomputed between-subtype divergence tables (produced elsewhere in the
# pipeline), used for the ridgeline comparison below.
js_df_h3_h5 = pd.read_csv('../results/divergence/h3_h5_divergence.csv')
js_df_h3_h7 = pd.read_csv('../results/divergence/h3_h7_divergence.csv')
js_df_h5_h7 = pd.read_csv('../results/divergence/h5_h7_divergence.csv')
In [14]:
def plot_ridgeline_density(dfs_dict, x_col_template='JS_{ha_x}_vs_{ha_y}',
                           bandwidth=0.02, extent=[0,0.65],
                           colors=None, width=200, height=400,
                           overlap=2.5, label_mapping=None):
    """
    Plot ridgeline (joyplot) density plots for multiple dataframes.

    Parameters:
    -----------
    dfs_dict : dict
        Dictionary where keys are comparison labels (e.g., 'h3-h5', 'h3-h7')
        and values are tuples of (df, ha_x, ha_y)
        Example: {'h3-h5': (js_df_h3_h5, 'h3', 'h5'),
                  'h3-h7': (js_df_h3_h7, 'h3', 'h7')}
    x_col_template : str
        Template for column name with {ha_x} and {ha_y} placeholders
    bandwidth : float
        Bandwidth for density estimation
    extent : list
        [min, max] for density calculation
    colors : list or None
        List of colors for each comparison. If None, uses default color scheme
    width, height : int
        Dimensions of the plot
    overlap : float
        How much the ridges overlap (higher = more overlap)
    label_mapping : dict or None
        Optional mapping from dfs_dict keys to display labels; a value may be
        a list of strings for a multi-line label.

    Returns:
    --------
    alt.Chart : Ridgeline density plot
    """
    # Default color scheme if none provided
    if colors is None:
        colors = ['#8DD3C7', '#FFFFB3', '#BEBADA', '#FB8072', '#80B1D3', '#FDB462']
    # Stack all comparisons into one long frame labeled by comparison.
    # (Removed the redundant function-local pandas/altair imports — both are
    # already imported at the top of the notebook — and the unused enumerate
    # index.)
    combined_data = []
    for comparison, (df, ha_x, ha_y) in dfs_dict.items():
        col_name = x_col_template.format(ha_x=ha_x, ha_y=ha_y)
        temp_df = df[[col_name]].copy()
        temp_df['comparison'] = comparison
        temp_df['value'] = temp_df[col_name]
        combined_data.append(temp_df[['value', 'comparison']])
    combined_df = pd.concat(combined_data, ignore_index=True)
    if label_mapping is not None:
        combined_df['comparison'] = combined_df['comparison'].map(label_mapping)
    # Calculate step size for ridgeline spacing
    step = height / (len(dfs_dict) * overlap)
    # Create the ridgeline plot: one density row per comparison, with
    # negative facet spacing so the rows overlap. (Removed a dead
    # transform_calculate that computed a 'yvalue' field no encoding used.)
    ridgeline = alt.Chart(combined_df).transform_density(
        density='value',
        bandwidth=bandwidth,
        extent=extent,
        groupby=['comparison'],
        steps=200
    ).mark_area(
        opacity=1,
        stroke='black',
        strokeWidth=1,
        interpolate='monotone'
    ).encode(
        alt.X('value:Q', title=['Divergence in amino-acid', 'preferences']),
        alt.Y('density:Q',
              title='Density',
              axis=None),
        alt.Row('comparison:N',
                title=None,
                header=alt.Header(labelAngle=0, labelAlign='left')),
        alt.Fill('comparison:N',
                 legend=None,
                 scale=alt.Scale(range=colors[:len(dfs_dict)]))
    ).properties(
        width=width,
        height=step,
        bounds='flush'
    ).configure_facet(
        spacing=-(step * (overlap - 1))
    ).configure_view(
        stroke=None
    ).configure_header(
        labelFontSize=14
    )
    return ridgeline
# Example usage: compare per-site divergence distributions across the three
# between-subtype comparisons and the within-H7 receptor comparison.
dfs_to_plot = {
    'h3-h5': (js_df_h3_h5, 'h3', 'h5'),
    'h3-h7': (js_df_h3_h7, 'h3', 'h7'),
    'h5-h7': (js_df_h5_h7, 'h5', 'h7'),
    'h7_2-3-h7_2-6': (js_df_h7_23_26, 'h7_2-3', 'h7_2-6')
}
plot_ridgeline_density(
    dfs_to_plot,
    label_mapping={
        'h3-h5': 'H3 vs. H5',
        'h3-h7': 'H3 vs. H7',
        'h5-h7': 'H5 vs. H7',
        'h7_2-3-h7_2-6': ['H7 (a2,3) vs.', 'H7 (a2,6)']  # list -> two-line row label
    }
).display()