In [1]:
import pandas as pd
import numpy as np
import altair as alt

from Bio import AlignIO
from natsort import natsort_keygen

import theme
In [2]:
def plot_identity_heatmap(df, title, width=100, height=100, order=None, highlight_pairs=None):
    """
    Plot a lower-triangle heatmap of sequence identity.
    
    Parameters:
    -----------
    df : DataFrame
        DataFrame with columns 'ha_x', 'ha_y', and 'percent_identity'
    title : str
        Title for the plot
    width : int
        Width of the plot
    height : int
        Height of the plot
    order : list, optional
        Order of HA subtypes for axes. Defaults to the natural-sorted union of
        the labels found in 'ha_x' and 'ha_y'.
    highlight_pairs : list of tuples, optional
        List of (ha_x, ha_y) pairs to highlight in red, e.g., [('H3', 'H5'), ('H3', 'H7')].
        Pairs may be given in either orientation.

    Returns:
    --------
    Layered Altair chart: normal cells, then highlighted cells, then text labels.

    Raises:
    -------
    KeyError
        If a label in `df` or in `highlight_pairs` is missing from `order`.
    """
    if order is None:
        # Collect labels from both columns so that input holding only one
        # triangle (where some label appears in just one column) still works
        labels = pd.unique(pd.concat([df['ha_x'], df['ha_y']], ignore_index=True))
        order = sorted(labels, key=natsort_keygen())

    # Filter to lower triangle only (excluding diagonal)
    # For lower triangle: index of ha_y > index of ha_x
    order_dict = {ha: i for i, ha in enumerate(order)}
    rank_x = np.array([order_dict[ha] for ha in df['ha_x']], dtype=int)
    rank_y = np.array([order_dict[ha] for ha in df['ha_y']], dtype=int)
    df_filtered = df.loc[rank_y > rank_x].copy()

    # Normalize requested pairs into lower-triangle orientation so that the
    # caller can pass them either way round
    normalized_pairs = set()
    for ha_x, ha_y in (highlight_pairs or []):
        if order_dict[ha_y] > order_dict[ha_x]:
            normalized_pairs.add((ha_x, ha_y))
        else:
            normalized_pairs.add((ha_y, ha_x))

    # Explicit boolean array keeps the column usable with `~` even when empty
    cell_pairs = zip(df_filtered['ha_x'], df_filtered['ha_y'])
    df_filtered['highlighted'] = np.array(
        [pair in normalized_pairs for pair in cell_pairs], dtype=bool
    )

    # Split data into highlighted and non-highlighted
    df_normal = df_filtered[~df_filtered['highlighted']]
    df_highlight = df_filtered[df_filtered['highlighted']]

    def _cell_layer(data, stroke, stroke_width):
        # Both cell layers share axes and color scale; only the outline differs
        return alt.Chart(data).mark_rect(
            opacity=0.95, stroke=stroke, strokeWidth=stroke_width
        ).encode(
            x=alt.X('ha_x:N', title=None, axis=alt.Axis(labelAngle=-90), sort=order),
            y=alt.Y('ha_y:N', title=None, sort=order),
            color=alt.Color(
                'percent_identity:Q',
                scale=alt.Scale(scheme='blues'),
                title=['Amino Acid', 'Identity (%)']
            )
        )

    # Base heatmap for non-highlighted cells; carries the size and title
    heatmap_normal = _cell_layer(df_normal, 'black', 1).properties(
        width=width,
        height=height,
        title=alt.Title(title, anchor='middle')
    )

    # Highlighted cells are drawn on top so their red outline is not covered
    heatmap_highlight = _cell_layer(df_highlight, 'red', 3)

    # Add text labels on the heatmap; white text on cells above 55% identity,
    # black text otherwise
    text = alt.Chart(df_filtered).mark_text(baseline='middle').encode(
        x=alt.X('ha_x:N', sort=order),
        y=alt.Y('ha_y:N', sort=order),
        text=alt.Text('percent_identity:Q', format='.0f'),
        color=alt.condition(
            alt.datum.percent_identity > 55,
            alt.value('white'),
            alt.value('black')
        )
    )

    # Layer them: normal cells first, then highlighted cells, then text on top
    chart = (heatmap_normal + heatmap_highlight + text)
    return chart
In [3]:
# Calculate pairwise amino acid identity for all 18 HA subtypes

# Read the alignment
alignment = AlignIO.read('../data/ha18_alignment_revised.fasta', 'fasta')

# Extract HA subtype names from sequence IDs
# Format: A/.../H#N# -> keep the part before the 'N' (e.g., H1N1 -> H1)
ha_subtypes = [record.id.split('/')[-1].split('N')[0] for record in alignment]

print(f"Found {len(ha_subtypes)} HA subtypes: {', '.join(ha_subtypes)}")

# Convert every aligned sequence to a plain string once, instead of
# re-converting both sequences on each pass of the pairwise loop below
aligned_seqs = [str(record.seq) for record in alignment]

# Calculate pairwise identity matrix (both orientations plus the diagonal)
n_seqs = len(alignment)
results = []

for i in range(n_seqs):
    for j in range(n_seqs):
        if i == j:
            # Self-comparison: identity is 100% by definition, counts are left empty
            results.append({
                'ha_x': ha_subtypes[i],
                'ha_y': ha_subtypes[j],
                'matches': np.nan,
                'alignable_residues': np.nan,
                'percent_identity': 100.0
            })
            continue

        # Count positions where both sequences have amino acids (not gaps)
        matches = 0
        alignable = 0

        for aa1, aa2 in zip(aligned_seqs[i], aligned_seqs[j]):
            # Skip positions where either sequence has a gap
            if aa1 != '-' and aa2 != '-':
                alignable += 1
                if aa1 == aa2:
                    matches += 1

        # Identity is reported as 0 when the two sequences share no ungapped column
        pct_identity = (matches / alignable * 100) if alignable > 0 else 0

        results.append({
            'ha_x': ha_subtypes[i],
            'ha_y': ha_subtypes[j],
            'matches': matches,
            'alignable_residues': alignable,
            'percent_identity': pct_identity
        })

ha18_identity_df = pd.DataFrame(results)

print(f"\nCalculated {len(ha18_identity_df)} pairwise comparisons")
ha18_identity_df.head()
Found 18 HA subtypes: H3, H5, H7, H1, H2, H4, H6, H8, H9, H10, H11, H12, H13, H14, H15, H16, H17, H18

Calculated 324 pairwise comparisons
Out[3]:
ha_x ha_y matches alignable_residues percent_identity
0 H3 H3 NaN NaN 100.000000
1 H3 H5 192.0 479.0 40.083507
2 H3 H7 229.0 483.0 47.412008
3 H3 H1 207.0 487.0 42.505133
4 H3 H2 202.0 487.0 41.478439
In [4]:
# Keep the three pairwise comparisons among H3, H5 and H7 and save them to CSV.
# Each pair is parenthesized so the filter does not depend on how
# DataFrame.query ranks `&` against `or`.
seqs_of_interest = ha18_identity_df.query(
    '(ha_x == "H3" and ha_y == "H5") '
    'or (ha_x == "H3" and ha_y == "H7") '
    'or (ha_x == "H5" and ha_y == "H7")'
)
seqs_of_interest.to_csv('../results/sequence_identity/ha_sequence_identity.csv', index=False)
seqs_of_interest
Out[4]:
ha_x ha_y matches alignable_residues percent_identity
1 H3 H5 192.0 479.0 40.083507
2 H3 H7 229.0 483.0 47.412008
20 H5 H7 202.0 473.0 42.706131
In [5]:
# Use hierarchical clustering to order HAs by similarity
from scipy.cluster.hierarchy import linkage, dendrogram
from scipy.spatial.distance import squareform

def cluster_ha(df):
    """
    Order HA labels by average-linkage hierarchical clustering.

    Parameters:
    -----------
    df : DataFrame
        Long-format table with columns 'ha_x', 'ha_y', and 'percent_identity'

    Returns:
    --------
    list of 'ha_y' labels in dendrogram leaf order
    """
    # Wide identity table: one row per ha_y, one column per ha_x
    identity_wide = df.pivot(index='ha_y', columns='ha_x', values='percent_identity')

    # Distance is 100 - identity; squareform turns the square table into the
    # condensed form that linkage takes
    condensed_distances = squareform(100 - identity_wide)
    tree = linkage(condensed_distances, method='average')

    # Leaf positions index into the rows of the wide table
    leaf_positions = dendrogram(tree, no_plot=True)['leaves']
    row_labels = identity_wide.index
    return [row_labels[pos] for pos in leaf_positions]

# Order the axes by clustering, then draw the heatmap with the
# H3/H5/H7 comparisons outlined
order = cluster_ha(ha18_identity_df)
plot_identity_heatmap(
    df=ha18_identity_df,
    title='',
    width=400,
    height=400,
    order=order,
    highlight_pairs=[('H3', 'H5'), ('H3', 'H7'), ('H5', 'H7')],
)
Out[5]:
In [6]:
# Sites that natural-sort at or before this one are labelled HA1; all later sites are HA2
HA1_LAST_SITE = '329'

# Build the natural-sort key function once and reuse it for every site
site_key = natsort_keygen()
ha1_boundary_key = site_key(HA1_LAST_SITE)

aln_df = pd.read_csv('../results/structural_alignment/structural_alignment.csv')[[
    'struct_site', 'h3_wt_aa', 'h5_wt_aa', 'h7_wt_aa'
]].assign(
    ha_region=lambda x: pd.Categorical(
        np.where(
            # Compare the key tuples one site at a time with plain Python
            # comparison instead of handing a tuple to a Series comparison
            np.array(
                [site_key(str(s)) <= ha1_boundary_key for s in x['struct_site']],
                dtype=bool
            ),
            'HA1',
            'HA2'
        ),
        categories=['HA1', 'HA2']
    )
)

aln_df.head()
Out[6]:
struct_site h3_wt_aa h5_wt_aa h7_wt_aa ha_region
0 9 S K NaN HA1
1 10 T S NaN HA1
2 11 A D D HA1
3 12 T Q K HA1
4 13 L I I HA1
In [7]:
# Calculate sliding window conservation across sites
def calculate_sliding_window_conservation(df, window_size=30):
    """
    Calculate sliding window conservation for pairwise HA comparisons.
    
    Parameters:
    -----------
    df : DataFrame
        Alignment dataframe with columns 'struct_site', 'h3_wt_aa', 'h5_wt_aa',
        'h7_wt_aa', and 'ha_region'
    window_size : int
        Nominal size of the sliding window. Each window covers
        `window_size // 2` rows on either side of the center row plus the
        center row itself, so an untruncated window spans
        `2 * (window_size // 2) + 1` rows (31 for the default of 30).
        Windows are truncated at both ends of the alignment.
        
    Returns:
    --------
    DataFrame with one row per (comparison, site) and columns:
    site, site_index, comparison, percent_conserved, window_size, ha_region.
    'window_size' is the number of rows actually inside that window, and
    'percent_conserved' is NaN when the window has no position where both
    sequences have an amino acid.
    """
    output_columns = [
        'site', 'site_index', 'comparison', 'percent_conserved', 'window_size', 'ha_region'
    ]

    # Sort by struct_site using natural sorting (sort_values returns a new
    # frame, so the caller's DataFrame is left untouched)
    df = df.sort_values('struct_site', key=natsort_keygen()).reset_index(drop=True)

    n_sites = len(df)
    half_window = window_size // 2
    sites = df['struct_site'].to_numpy()
    regions = df['ha_region'].to_numpy()
    
    # Define pairwise comparisons
    comparisons = [
        ('H3', 'H5', 'h3_wt_aa', 'h5_wt_aa'),
        ('H3', 'H7', 'h3_wt_aa', 'h7_wt_aa'),
        ('H5', 'H7', 'h5_wt_aa', 'h7_wt_aa')
    ]
    
    results = []
    
    for ha1, ha2, col1, col2 in comparisons:
        comparison_name = f'{ha1}-{ha2}'

        # Per-site flags are computed once per comparison and then summed per
        # window, rather than re-filtering the frame for every window
        aa1 = df[col1]
        aa2 = df[col2]
        # Alignable: both sequences have an amino acid at the site
        alignable = (aa1.notna() & aa2.notna()).to_numpy()
        # Identical: alignable and the two amino acids match
        identical = alignable & (aa1 == aa2).to_numpy()
        
        # Calculate conservation for each window; i is the positional index
        # of the window's center site along the sorted alignment
        for i in range(n_sites):
            # Define window boundaries, clipped to the alignment
            start_idx = max(0, i - half_window)
            end_idx = min(n_sites, i + half_window + 1)

            n_alignable = alignable[start_idx:end_idx].sum()
            
            if n_alignable > 0:
                # Calculate percent conserved (matching amino acids)
                matches = identical[start_idx:end_idx].sum()
                pct_conserved = (matches / n_alignable) * 100
            else:
                pct_conserved = np.nan
            
            results.append({
                'site': sites[i],
                'site_index': i,
                'comparison': comparison_name,
                'percent_conserved': pct_conserved,
                'window_size': end_idx - start_idx,
                'ha_region': regions[i]
            })
    
    # Naming the columns keeps the schema stable even when there are no rows
    return pd.DataFrame(results, columns=output_columns)

# Conservation profile for each pairwise comparison, using a nominal 30-site window
window_df = calculate_sliding_window_conservation(df=aln_df, window_size=30)
window_df.head()
Out[7]:
site site_index comparison percent_conserved window_size ha_region
0 9 0 H3-H5 25.000000 16 HA1
1 10 1 H3-H5 23.529412 17 HA1
2 11 2 H3-H5 27.777778 18 HA1
3 12 3 H3-H5 26.315789 19 HA1
4 13 4 H3-H5 30.000000 20 HA1
In [8]:
# Create sliding window conservation plot

# Line color for each pairwise comparison
comparison_colors = {
    'H3-H5': '#0099B4',
    'H3-H7': '#ED0000',
    'H5-H7': '#FFDC91'
}

# One entry per site, ordered by positional index, to fix the x-axis order
site_order = (
    window_df
    .drop_duplicates('site')
    .sort_values('site_index')['site']
    .tolist()
)

# Positional index of the first HA2 site (the HA1/HA2 boundary)
is_ha2 = window_df['ha_region'] == 'HA2'
boundary_idx = window_df.loc[is_ha2, 'site_index'].min()

# Encodings for the line chart, named separately to keep the chart call short
x_encoding = alt.X(
    'site:N',
    title='Site',
    sort=site_order,
    axis=alt.Axis(
        labelAngle=0,
        values=['100', '200', '300', '400', '500'],
        tickCount=5
    )
)
y_encoding = alt.Y(
    'percent_conserved:Q',
    title=['Amino Acid Identity (%)'],
    scale=alt.Scale(domain=[0, 100])
)
color_encoding = alt.Color(
    'comparison:N',
    title='Comparison',
    scale=alt.Scale(
        domain=list(comparison_colors.keys()),
        range=list(comparison_colors.values())
    ),
    legend=alt.Legend(
        orient='top',
        direction='horizontal',
    )
)

# Create the line chart
chart = (
    alt.Chart(window_df)
    .mark_line(size=2, opacity=1)
    .encode(
        x=x_encoding,
        y=y_encoding,
        color=color_encoding,
        tooltip=['window_size']
    )
    .properties(width=400, height=150)
)

# Without a usable boundary the line chart is shown on its own
final_chart = chart

# Add a vertical dashed line to separate HA1 and HA2 using site_index
if not (boundary_idx is None or pd.isna(boundary_idx)):
    boundary_site = site_order[int(boundary_idx)]
    separator = (
        alt.Chart(pd.DataFrame({'site': [boundary_site]}))
        .mark_rule(strokeDash=[5, 5], color='gray', size=1)
        .encode(x=alt.X('site:N', sort=site_order))
    )

    # Combine all layers
    final_chart = (chart + separator)

final_chart
Out[8]:
In [9]:
# Calculate pairwise amino acid identity between HA domains

columns = ['h3_wt_aa', 'h5_wt_aa', 'h7_wt_aa']
results = []

for region in ['HA1', 'HA2']:
    # Restrict the alignment to the sites of this region
    region_df = aln_df[aln_df['ha_region'] == region]

    # Visit every ordered pair of sequences, self-pairs included
    for col1 in columns:
        ha1 = col1.replace('_wt_aa', '').upper()

        for col2 in columns:
            ha2 = col2.replace('_wt_aa', '').upper()

            if col1 != col2:
                # Alignable residues: sites where both sequences have an amino acid
                shared = region_df[[col1, col2]].dropna()
                matches = (shared[col1] == shared[col2]).sum()
                total = len(shared)

                # Percent identity, reported as 0 when nothing is alignable
                pct_identity = (matches / total * 100) if total > 0 else 0
            else:
                # Self-comparison is 100% identity; the counts are left empty
                matches = np.nan
                total = np.nan
                pct_identity = 100.0

            results.append({
                'ha_region': region,
                'ha_x': ha1,
                'ha_y': ha2,
                'matches': matches,
                'alignable_residues': total,
                'percent_identity': pct_identity
            })

domain_identity_df = pd.DataFrame(results)
domain_identity_df
Out[9]:
ha_region ha_x ha_y matches alignable_residues percent_identity
0 HA1 H3 H3 NaN NaN 100.000000
1 HA1 H3 H5 108.0 315.0 34.285714
2 HA1 H3 H7 115.0 313.0 36.741214
3 HA1 H5 H3 108.0 315.0 34.285714
4 HA1 H5 H5 NaN NaN 100.000000
5 HA1 H5 H7 115.0 312.0 36.858974
6 HA1 H7 H3 115.0 313.0 36.741214
7 HA1 H7 H5 115.0 312.0 36.858974
8 HA1 H7 H7 NaN NaN 100.000000
9 HA2 H3 H3 NaN NaN 100.000000
10 HA2 H3 H5 84.0 164.0 51.219512
11 HA2 H3 H7 114.0 170.0 67.058824
12 HA2 H5 H3 84.0 164.0 51.219512
13 HA2 H5 H5 NaN NaN 100.000000
14 HA2 H5 H7 87.0 161.0 54.037267
15 HA2 H7 H3 114.0 170.0 67.058824
16 HA2 H7 H5 87.0 161.0 54.037267
17 HA2 H7 H7 NaN NaN 100.000000