In [1]:
import pandas as pd
import numpy as np
import re
import string
from pathlib import Path

from Bio import SeqIO
from Bio.PDB.MMCIF2Dict import MMCIF2Dict
from natsort import natsort_keygen

import altair as alt
import theme
alt.data_transformers.disable_max_rows()
Out[1]:
DataTransformerRegistry.enable('default')

Format structural alignment

In [2]:
def _alignment_column_labels(ref_seq, other_seqs, start_site):
    """Return one struct_site label per alignment column (None for all-gap columns).

    Columns where the reference has a residue are numbered consecutively starting
    at ``start_site``. Columns where the reference has a gap but at least one other
    sequence has a residue are insertions: they get the preceding reference number
    plus a letter ('a', 'b', ...). Letters are counted per alignment column, not per
    sequence, so a given label always denotes the same column for every sequence.
    """
    letters = string.ascii_lowercase
    labels = []
    last_numeric = start_site - 1
    n_insertions = 0  # insertion columns seen since the last reference residue
    for col, ref_aa in enumerate(ref_seq):
        if ref_aa != "-":
            last_numeric += 1
            n_insertions = 0
            labels.append(str(last_numeric))
        elif any(seq[col] != "-" for seq in other_seqs):
            if n_insertions >= len(letters):
                raise ValueError(
                    f"Insertion index {n_insertions} exceeds available letters ({len(letters)})"
                )
            labels.append(f"{last_numeric}{letters[n_insertions]}")
            n_insertions += 1
        else:
            # gap in every sequence: no sequence will emit a row for this column
            labels.append(None)
    return labels


def alignment_to_df(fasta_path, ref_id, start_site=1):
    """
    Convert a multiple structural alignment (FASTA) into a per-site table.

    Parameters
    ----------
    fasta_path : str
        Path to the aligned FASTA file. Record IDs are truncated at the first '_';
        a later record with the same truncated ID replaces an earlier one.
    ref_id : str
        Truncated ID of the reference sequence that defines the numbering.
    start_site : int, default 1
        Number given to the first reference residue in the alignment.

    Returns
    -------
    pd.DataFrame
        Columns ``struct_site``, ``{ref_id}_aa`` and one ``{id}_aa`` column per other
        sequence (FASTA order), naturally sorted by ``struct_site``. Cells hold the
        aligned residue, '-' for a gap, or NaN where a sequence has no row for an
        insertion site (introduced by the outer merge).

    Raises
    ------
    ValueError
        If ``ref_id`` is absent, the aligned sequences differ in length, or a run
        of insertion columns is longer than 26.
    """
    seq_dict = {}
    for rec in SeqIO.parse(fasta_path, "fasta"):
        seq_dict[rec.id.split("_")[0]] = str(rec.seq)

    if ref_id not in seq_dict:
        raise ValueError(f"ref_id '{ref_id}' not found. Available IDs: {list(seq_dict.keys())}")

    ref_seq = seq_dict[ref_id]
    other_ids = [rid for rid in seq_dict if rid != ref_id]
    if any(len(seq_dict[pdb]) != len(ref_seq) for pdb in other_ids):
        raise ValueError("Aligned sequences must be the same length.")

    # Labels are computed once over the whole MSA, so the outer merge below joins
    # rows from different sequences only when they come from the same column.
    site_labels = _alignment_column_labels(
        ref_seq, [seq_dict[pdb] for pdb in other_ids], start_site
    )
    ref_col = f"{ref_id}_aa"

    final_df = None
    for pdb in other_ids:
        rows = [
            (site, ref_aa, aa)
            for site, ref_aa, aa in zip(site_labels, ref_seq, seq_dict[pdb])
            if ref_aa != "-" or aa != "-"  # skip columns gapped in both
        ]
        # Built once per sequence (not once per residue).
        df = pd.DataFrame(rows, columns=["struct_site", ref_col, f"{pdb}_aa"])
        if final_df is None:
            final_df = df
        else:
            final_df = pd.merge(final_df, df, on=["struct_site", ref_col], how="outer")

    if final_df is None:
        # Only the reference is present: return its numbered residues alone.
        final_df = pd.DataFrame(
            [(site, ref_aa) for site, ref_aa in zip(site_labels, ref_seq) if ref_aa != "-"],
            columns=["struct_site", ref_col],
        )

    return final_df.sort_values("struct_site", key=natsort_keygen()).reset_index(drop=True)

# Chain A (HA1) and chain B (HA2) are aligned separately against 4o5n;
# start_site is the number assigned to the first aligned 4o5n residue.
FOLDMASON_DIR = '../results/foldmason_alignment'

ha1_aln = alignment_to_df(
    f'{FOLDMASON_DIR}/chain_A/result_aa.fa',
    ref_id='4o5n',
    start_site=9,
)
ha2_aln = alignment_to_df(
    f'{FOLDMASON_DIR}/chain_B/result_aa.fa',
    ref_id='4o5n',
    start_site=330,
)

ha1_aln.head()
Out[2]:
struct_site 4o5n_aa 6ii9_aa 4kwm_aa
0 9 P - P
1 10 G - G
2 11 A D D
3 12 T K Q
4 13 L I I

Add relative solvent accessibility (RSA)

In [3]:
# Three-letter residue name -> one-letter code. MSE (selenomethionine) is read
# as M; SEC/PYL map to U/O, which have no entry in MAX_ASA_TIEN.
AA3_TO_1 = dict(
    ALA='A', ARG='R', ASN='N', ASP='D', CYS='C', GLN='Q', GLU='E', GLY='G',
    HIS='H', ILE='I', LEU='L', LYS='K', MET='M', MSE='M', PHE='F', PRO='P',
    SER='S', THR='T', TRP='W', TYR='Y', VAL='V', SEC='U', PYL='O',
)

# Maximum accessible surface area per residue (Tien et al. 2013), used as the
# denominator when converting ASA to relative solvent accessibility (RSA).
MAX_ASA_TIEN = dict(
    A=129.0, C=167.0, D=193.0, E=223.0, F=240.0, G=104.0,
    H=224.0, I=197.0, K=236.0, L=201.0, M=224.0, N=195.0,
    P=159.0, Q=225.0, R=274.0, S=155.0, T=172.0, V=174.0,
    W=285.0, Y=263.0,
)


def process_dssp_mmcif_output(mmcif_path, chain=None, max_asa=MAX_ASA_TIEN, prefer='auth'):
    """
    Parse an mkdssp mmCIF output file into a per-residue DataFrame.

    Parameters
    ----------
    mmcif_path : str
        Path to the mmCIF written by mkdssp. It must contain the
        ``_dssp_struct_summary`` category and the ``_atom_site`` table.
    chain : str or None
        Chain ID to keep. It is matched against ``auth_asym_id`` when
        ``prefer='auth'`` (falling back to ``label_asym_id`` for residues with no
        author mapping), otherwise against ``label_asym_id``. If None, the file
        must contain exactly one chain.
    max_asa : dict
        One-letter amino acid -> maximum ASA used to compute RSA. Residues with no
        entry use the value for 'A'.
    prefer : {'auth', 'label'}
        Which chain/numbering scheme to expose in the output and to filter by.
        Any value other than 'auth' behaves like 'label'.

    Returns
    -------
    pd.DataFrame
        Columns pdb_site, amino_acid, ASA, RSA, SS, SS_class, chain, in the row
        order of ``_dssp_struct_summary``. The columns exist even when no residue
        matches ``chain``.

    Raises
    ------
    ValueError
        If a DSSP secondary-structure code is not recognized.
    AssertionError
        If ``chain`` is None and more than one chain is present.
    """
    mm = MMCIF2Dict(mmcif_path)

    def _as_list(value):
        # Defensive: treat a bare scalar as a one-element list so zip() never
        # iterates over the characters of a string.
        return value if isinstance(value, list) else [value]

    # 1) Per-residue DSSP summary written by mkdssp in mmCIF mode.
    lab_asym = _as_list(mm['_dssp_struct_summary.label_asym_id'])
    lab_seq = _as_list(mm['_dssp_struct_summary.label_seq_id'])
    comp_id = _as_list(mm['_dssp_struct_summary.label_comp_id'])
    ss_list = _as_list(mm['_dssp_struct_summary.secondary_structure'])
    asa_raw = mm.get('_dssp_struct_summary.accessibility')  # may be omitted in some runs
    asa_list = _as_list(asa_raw) if asa_raw else ['.'] * len(ss_list)

    # 2) (label_asym_id, label_seq_id) -> (auth_asym_id, auth_seq_id, ins_code),
    #    taken from the first atom of each residue in _atom_site.
    lat_asym = _as_list(mm['_atom_site.label_asym_id'])
    lat_seq = _as_list(mm['_atom_site.label_seq_id'])

    def _optional_atom_col(key):
        # Pad an absent optional column with None (not []): zip() would otherwise
        # truncate to zero pairs and silently discard all author numbering.
        values = mm.get(key)
        return _as_list(values) if values is not None else [None] * len(lat_asym)

    label2auth = {}
    for key_asym, key_seq, a_asym, a_seq, a_ins in zip(
        lat_asym,
        lat_seq,
        _optional_atom_col('_atom_site.auth_asym_id'),
        _optional_atom_col('_atom_site.auth_seq_id'),
        _optional_atom_col('_atom_site.pdbx_PDB_ins_code'),
    ):
        label2auth.setdefault((key_asym, key_seq), (a_asym, a_seq, a_ins))

    # DSSP code -> 3-class secondary structure ('P' = polyproline helix).
    ss_to_class = {
        'G': 'helix', 'H': 'helix', 'I': 'helix', 'P': 'helix',
        'B': 'strand', 'E': 'strand',
        'T': 'loop', 'S': 'loop', '-': 'loop',
    }

    # 3) Collect rows, mapping to the preferred chain scheme and computing RSA/SS_class.
    rows = []
    chains_present = set()
    for la, ls, res3, ss, asa in zip(lab_asym, lab_seq, comp_id, ss_list, asa_list):
        auth_chain, auth_seq, auth_ins = label2auth.get((la, ls), (None, None, None))

        chosen_chain = auth_chain if (prefer == 'auth' and auth_chain is not None) else la
        chains_present.add(chosen_chain)
        if chain is not None and chosen_chain != chain:
            continue

        aa1 = AA3_TO_1.get(res3.upper(), 'X')

        # DSSP mmCIF writes '.' (or blank) for coil; expose it as '-'.
        if ss is None or ss.strip() in ('', '.'):
            ss = '-'
        try:
            ss_class = ss_to_class[ss]
        except KeyError:
            raise ValueError(f"invalid SS of {ss}") from None

        try:
            asa_val = float(asa) if asa not in (None, '.', '?') else float('nan')
        except ValueError:
            asa_val = float('nan')
        rsa = asa_val / float(max_asa.get(aa1, max_asa.get('A'))) if pd.notna(asa_val) else float('nan')

        # PDB-style site: author seq number + insertion code when available.
        # '?' and '.' are the two CIF null markers, so neither is a real insertion code.
        if prefer == 'auth' and auth_seq is not None:
            ins = '' if auth_ins is None else str(auth_ins).strip()
            pdb_site = f"{auth_seq}{ins}" if ins not in ('', '?', '.') else str(auth_seq)
        else:
            pdb_site = str(ls)

        rows.append({
            'pdb_site': pdb_site,
            'amino_acid': aa1,
            'ASA': asa_val,
            'RSA': rsa,
            'SS': ss,
            'SS_class': ss_class,
        })

    if chain is None:
        assert len(chains_present) == 1, "chain is None, but multiple chains"

    columns = ['pdb_site', 'amino_acid', 'ASA', 'RSA', 'SS', 'SS_class']
    return pd.DataFrame(rows, columns=columns).assign(chain=chain)


# Per-chain DSSP summaries for each structure: chain 'A' feeds the *_ha1_ss
# tables and chain 'B' the *_ha2_ss tables.
_dssp_cols = ['pdb_site', 'amino_acid', 'RSA', 'SS', 'chain']

h3_ha1_ss, h3_ha2_ss = (
    process_dssp_mmcif_output("../results/dssp/4o5n_dssp.mmcif", chain=c)[_dssp_cols]
    for c in ('A', 'B')
)
h5_ha1_ss, h5_ha2_ss = (
    process_dssp_mmcif_output("../results/dssp/4kwm_dssp.mmcif", chain=c)[_dssp_cols]
    for c in ('A', 'B')
)
h7_ha1_ss, h7_ha2_ss = (
    process_dssp_mmcif_output("../results/dssp/6ii9_dssp.mmcif", chain=c)[_dssp_cols]
    for c in ('A', 'B')
)
In [4]:
def merge_rsa_by_alignment(aln_df, rsa_df, aa_col, prefix=True):
    """
    Attach per-residue structure annotations to an alignment by walking residues in order.

    The k-th letter (non-gap) in ``aln_df[aa_col]`` is paired with the k-th row of
    ``rsa_df``. Gap ('-') and missing entries receive NaN. Residue identities are
    checked case-insensitively, so a shifted pairing raises instead of silently
    misplacing annotations.

    Parameters
    ----------
    aln_df : pd.DataFrame
        Alignment table whose ``aa_col`` holds single letters, '-' or NaN.
    rsa_df : pd.DataFrame
        Per-residue table with an 'amino_acid' column plus the columns to copy,
        ordered like the residues in ``aln_df[aa_col]``. Extra trailing rows are ignored.
    aa_col : str
        Column of ``aln_df`` to traverse (e.g. '4o5n_aa'). Required; there is no default.
    prefix : bool, default True
        If True, prefix copied columns with f"{aa_col}_" to avoid collisions.

    Returns
    -------
    pd.DataFrame
        ``aln_df`` joined with the copied columns. Numeric columns (e.g. RSA) get a
        numeric dtype rather than object.

    Raises
    ------
    ValueError
        If 'amino_acid' is missing, ``rsa_df`` has fewer rows than there are
        letters, or residue identities disagree.
    """
    if 'amino_acid' not in rsa_df.columns:
        raise ValueError("rsa_df must contain 'amino_acid' column.")

    aa = aln_df[aa_col].astype(str)  # NaN becomes 'nan', which is not a single letter
    is_letter = aa.str.fullmatch(r"[A-Za-z]")

    n_letters = int(is_letter.sum())
    if len(rsa_df) < n_letters:
        raise ValueError(
            f"rsa_df has {len(rsa_df)} rows but df[{aa_col}] has {n_letters} letters."
        )

    # The first n_letters rsa rows, in order, are the ones paired with letters.
    rsa_used = rsa_df.iloc[:n_letters].reset_index(drop=True)

    letter_idx = aa[is_letter].index
    aa_letters = aa[is_letter].str.upper().to_numpy()
    rsa_letters = rsa_used['amino_acid'].astype(str).str.upper().to_numpy()
    mismatch_pos = np.flatnonzero(aa_letters != rsa_letters)
    if mismatch_pos.size:
        # Report up to five mismatches using aln_df row labels.
        examples = [
            f"{letter_idx[j]}: df[{aa_col}]={aln_df.loc[letter_idx[j], aa_col]!r} "
            f"vs rsa={rsa_used.loc[j, 'amino_acid']!r}"
            for j in mismatch_pos[:5]
        ]
        raise ValueError("Amino-acid mismatch at rows [" + ", ".join(examples) + "].")

    cols_to_copy = [c for c in rsa_df.columns if c != 'amino_acid']

    # NaN everywhere, then fill only the letter rows from rsa_used (in order).
    to_add = pd.DataFrame(index=aln_df.index, columns=cols_to_copy, dtype='object')
    to_add.loc[is_letter, cols_to_copy] = rsa_used[cols_to_copy].to_numpy()
    # The frame was created as object dtype; recover numeric dtypes (e.g. RSA -> float).
    to_add = to_add.infer_objects()

    if prefix:
        to_add = to_add.add_prefix(f"{aa_col}_")

    return aln_df.join(to_add)

# Attach H3 (4o5n), H5 (4kwm) and H7 (6ii9) RSA/SS annotations to each alignment.
ha1_aln_ss = (
    ha1_aln
    .pipe(merge_rsa_by_alignment, h3_ha1_ss, '4o5n_aa')
    .pipe(merge_rsa_by_alignment, h5_ha1_ss, '4kwm_aa')
    .pipe(merge_rsa_by_alignment, h7_ha1_ss, '6ii9_aa')
)

ha2_aln_ss = (
    ha2_aln
    .pipe(merge_rsa_by_alignment, h3_ha2_ss, '4o5n_aa')
    .pipe(merge_rsa_by_alignment, h5_ha2_ss, '4kwm_aa')
    .pipe(merge_rsa_by_alignment, h7_ha2_ss, '6ii9_aa')
)

ha1_aln_ss.head()
Out[4]:
struct_site 4o5n_aa 6ii9_aa 4kwm_aa 4o5n_aa_pdb_site 4o5n_aa_RSA 4o5n_aa_SS 4o5n_aa_chain 4kwm_aa_pdb_site 4kwm_aa_RSA 4kwm_aa_SS 4kwm_aa_chain 6ii9_aa_pdb_site 6ii9_aa_RSA 6ii9_aa_SS 6ii9_aa_chain
0 9 P - P 9 1.084277 - A -1 1.140252 - A NaN NaN NaN NaN
1 10 G - G 10 0.150962 - A 0 0.175962 - A NaN NaN NaN NaN
2 11 A D D 11 0.050388 E A 1 0.097927 - A 1 0.624352 - A
3 12 T K Q 12 0.268605 E A 2 0.216889 E A 2 0.368644 E A
4 13 L I I 13 0.0 E A 3 0.0 E A 3 0.0 E A

Add per-position RMSD between structures

In [5]:
def parse_rmsd_txt(path, id=None):
    """
    Read a per-position RMSD text file into a two-column DataFrame.

    The first line is skipped as a header. Every other line shaped like
    ``<int>: <value>`` yields one row (aln_idx, value); the literal ``None``
    becomes a missing value. Lines that do not match are ignored.

    Returns
    -------
    pd.DataFrame with columns ``aln_idx`` and ``rmsd_{id}``.
    """
    line_re = re.compile(r"\s*(\d+):\s*(.*)\s*$")
    records = []
    for line in Path(path).read_text().splitlines()[1:]:
        hit = line_re.match(line)
        if hit is None:
            continue
        raw_value = hit.group(2).strip()
        records.append(
            (int(hit.group(1)), None if raw_value == "None" else float(raw_value))
        )
    return pd.DataFrame(records, columns=["aln_idx", f"rmsd_{id}"])

# Pairwise per-position RMSD for each subtype pair, separately for HA1 and HA2.
RMSD_DIR = '../data/rmsd'

h3_h5_ha1_rmsd = parse_rmsd_txt(f'{RMSD_DIR}/h3_h5_ha1_rmsd.txt', id='h3h5')
h3_h5_ha2_rmsd = parse_rmsd_txt(f'{RMSD_DIR}/h3_h5_ha2_rmsd.txt', id='h3h5')

h3_h7_ha1_rmsd = parse_rmsd_txt(f'{RMSD_DIR}/h3_h7_ha1_rmsd.txt', id='h3h7')
h3_h7_ha2_rmsd = parse_rmsd_txt(f'{RMSD_DIR}/h3_h7_ha2_rmsd.txt', id='h3h7')

h5_h7_ha1_rmsd = parse_rmsd_txt(f'{RMSD_DIR}/h5_h7_ha1_rmsd.txt', id='h5h7')
h5_h7_ha2_rmsd = parse_rmsd_txt(f'{RMSD_DIR}/h5_h7_ha2_rmsd.txt', id='h5h7')
In [6]:
def merge_aln_rmsd(aln_df, rmsd_df):
    """
    Append RMSD columns to an alignment table by row position.

    Row k of ``rmsd_df`` is attached to row k of ``aln_df`` regardless of either
    frame's index labels. The ``aln_idx`` column of ``rmsd_df`` is dropped and the
    index of ``aln_df`` is kept.

    NOTE(review): positional pairing assumes the RMSD file lists alignment rows in
    the same order as ``aln_df`` — confirm against the RMSD generator.

    Raises
    ------
    AssertionError
        If the two frames have different lengths.
    """
    assert len(aln_df) == len(rmsd_df), (
        f"alignment has {len(aln_df)} rows but RMSD table has {len(rmsd_df)}"
    )
    # pd.concat(axis=1) aligns on index labels. Relabel the RMSD rows with the
    # alignment's index so pairing is positional even for a non-RangeIndex.
    rmsd_cols = rmsd_df.drop(columns=['aln_idx']).set_axis(aln_df.index, axis=0)
    return pd.concat([aln_df, rmsd_cols], axis=1)

# Add the three pairwise RMSD tracks to HA1 and HA2, then stack both subunits
# into one table ordered naturally by structural site.
struct_align_df = (
    pd.concat(
        [
            ha1_aln_ss
            .pipe(merge_aln_rmsd, h3_h5_ha1_rmsd)
            .pipe(merge_aln_rmsd, h3_h7_ha1_rmsd)
            .pipe(merge_aln_rmsd, h5_h7_ha1_rmsd),
            ha2_aln_ss
            .pipe(merge_aln_rmsd, h3_h5_ha2_rmsd)
            .pipe(merge_aln_rmsd, h3_h7_ha2_rmsd)
            .pipe(merge_aln_rmsd, h5_h7_ha2_rmsd),
        ],
        ignore_index=True,
    )
    .sort_values("struct_site", key=natsort_keygen())
    .reset_index(drop=True)
)

struct_align_df.head()
Out[6]:
struct_site 4o5n_aa 6ii9_aa 4kwm_aa 4o5n_aa_pdb_site 4o5n_aa_RSA 4o5n_aa_SS 4o5n_aa_chain 4kwm_aa_pdb_site 4kwm_aa_RSA 4kwm_aa_SS 4kwm_aa_chain 6ii9_aa_pdb_site 6ii9_aa_RSA 6ii9_aa_SS 6ii9_aa_chain rmsd_h3h5 rmsd_h3h7 rmsd_h5h7
0 9 P - P 9 1.084277 - A -1 1.140252 - A NaN NaN NaN NaN 9.167400 NaN NaN
1 10 G - G 10 0.150962 - A 0 0.175962 - A NaN NaN NaN NaN 8.157247 NaN NaN
2 11 A D D 11 0.050388 E A 1 0.097927 - A 1 0.624352 - A 5.040040 2.984626 2.886615
3 12 T K Q 12 0.268605 E A 2 0.216889 E A 2 0.368644 E A 3.937602 1.626754 3.384350
4 13 L I I 13 0.0 E A 3 0.0 E A 3 0.0 E A 3.687798 1.734039 2.549524

Add DMS wildtype (background) residues and site numbering

In [7]:
# Remove sites that are missing from the structural alignment, so the remaining
# DMS wildtype residues can be paired in order with the aligned residues.
# NOTE(review): h3_missing holds ints while h5/h7 hold strings; presumably this
# mirrors the dtype pandas infers for each CSV's 'site' column — confirm.

def _load_wildtype(csv_path, missing_sites, drop_hyphenated=False):
    """Return unique (site, wildtype) rows from a DMS CSV, excluding ``missing_sites``.

    If ``drop_hyphenated`` is True, sites whose label contains '-' are also dropped.
    """
    wt = (
        pd.read_csv(csv_path)[['site', 'wildtype']]
        .drop_duplicates()
        .reset_index(drop=True)
    )
    keep = ~wt['site'].isin(missing_sites)
    if drop_hyphenated:
        keep = keep & ~wt['site'].str.contains('-', na=False)
    return wt[keep]


h3_missing = [*range(1, 9), *range(326, 330), *range(503, 505)]
h3_wt = _load_wildtype(
    '../data/cell_entry_effects/MDCKSIAT1_entry_func_effects.csv', h3_missing
)

h5_missing = [*map(str, range(1, 9)),
              *map(str, range(325, 339)),
              *map(str, range(503, 552)),
              "328a", "328b", "328c", "510a"]
h5_wt = _load_wildtype(
    '../data/cell_entry_effects/293T_entry_func_effects.csv', h5_missing, drop_hyphenated=True
)

h7_missing = [*map(str, range(1, 11)),
              *map(str, range(326, 330)),
              *map(str, range(500, 515))]
h7_wt = _load_wildtype(
    '../data/cell_entry_effects/293_mix_entry_func_effects.csv', h7_missing
)
In [8]:
def add_wt_cols(
    aln_df, wt_df,
    ref_col="4o5n_aa",      # col to align against with letters or '-'
    site_col="site",        # col in wt_df with site numbering
    aa_col="wildtype",      # column in wt_df with wt residue
    out_site_col="wt_site",
    out_aa_col="wt_aa",
):
    """
    Attach DMS wildtype site labels and residues to an alignment, pairing them by order.

    WT rows are sorted naturally by ``site_col``. The k-th sorted row is assigned
    to the k-th letter (non-gap) in ``aln_df[ref_col]``. Gaps, missing entries and
    letters beyond the number of WT rows receive NaN.

    Parameters
    ----------
    aln_df : pd.DataFrame
        Alignment table containing ``ref_col``.
    wt_df : pd.DataFrame
        WT table with columns ``site_col`` and ``aa_col``.
    ref_col : str
        Alignment column (letters or '-') to walk.
    site_col, aa_col : str
        Site-label and WT-residue columns of ``wt_df``.
    out_site_col, out_aa_col : str
        Names of the two columns added to the returned frame.

    Returns
    -------
    (pd.DataFrame, dict)
        ``aln_df`` with the two new columns, and summary stats:
        ``n_compared`` (letters that received a WT residue), ``n_match`` (how many
        of those equal the ``ref_col`` residue, case-insensitively) and
        ``pct_match`` (percentage rounded to 2 decimals; NaN if nothing compared).
    """
    # Sort on the caller-specified site column (previously hard-coded to "site").
    wt = wt_df.sort_values(site_col, key=natsort_keygen()).reset_index(drop=True)
    is_letter = aln_df[ref_col].astype(str).str.fullmatch(r"[A-Za-z]")

    # 0-based rank of each letter among the letters of ref_col.
    letter_pos = is_letter.cumsum() - 1
    take = is_letter & (letter_pos < len(wt))

    out_site = pd.Series(np.nan, index=aln_df.index, dtype="object")
    out_aa = pd.Series(np.nan, index=aln_df.index, dtype="object")

    pos = letter_pos[take].to_numpy()
    out_site.loc[take] = wt[site_col].to_numpy()[pos]
    out_aa.loc[take] = wt[aa_col].to_numpy()[pos]

    n_compared = int(take.sum())
    if n_compared:
        ref_up = aln_df.loc[take, ref_col].astype("string").str.upper()
        wt_up = out_aa.loc[take].astype("string").str.upper()
        # A missing WT residue compares as NA; count it as a non-match.
        n_match = int(ref_up.eq(wt_up).fillna(False).sum())
        pct_match = float(np.round(100.0 * n_match / n_compared, 2))
    else:
        n_match = 0
        pct_match = np.nan

    out_df = aln_df.assign(**{out_site_col: out_site, out_aa_col: out_aa})
    return out_df, {"n_compared": n_compared, "n_match": n_match, "pct_match": pct_match}

# Add DMS wildtype residues and site numbers for each subtype, aligning every
# DMS dataset against the residues of its own structure (H3 = 4o5n,
# H5 = 4kwm, H7 = 6ii9), and report how often they agree.
aln_out = struct_align_df
for wt_table, structure_col, subtype in (
    (h3_wt, "4o5n_aa", "h3"),
    (h5_wt, "4kwm_aa", "h5"),
    (h7_wt, "6ii9_aa", "h7"),
):
    aln_out, stats = add_wt_cols(
        aln_out,
        wt_table,
        ref_col=structure_col,
        out_aa_col=f"{subtype}_wt_aa",
        out_site_col=f"{subtype}_site",
    )
    print(stats)
{'n_compared': 490, 'n_match': 457, 'pct_match': 93.27}
{'n_compared': 487, 'n_match': 453, 'pct_match': 93.02}
{'n_compared': 487, 'n_match': 487, 'pct_match': 100.0}
In [9]:
# Write the full merged table and a compact per-site summary.
OUT_DIR = Path('../results/structural_alignment')
# This directory is output-only; create it so Restart & Run All works on a fresh checkout.
OUT_DIR.mkdir(parents=True, exist_ok=True)

aln_out.to_csv(OUT_DIR / 'structural_alignment_detailed.csv', index=False)

summary_cols = [
    'struct_site', 'h3_site', 'h5_site', 'h7_site', 'h3_wt_aa', 'h5_wt_aa', 'h7_wt_aa',
    'rmsd_h3h5', 'rmsd_h3h7', 'rmsd_h5h7', '4o5n_aa_RSA', '4kwm_aa_RSA', '6ii9_aa_RSA',
]
aln_out[summary_cols].to_csv(OUT_DIR / 'structural_alignment.csv', index=False)
aln_out.head()
Out[9]:
struct_site 4o5n_aa 6ii9_aa 4kwm_aa 4o5n_aa_pdb_site 4o5n_aa_RSA 4o5n_aa_SS 4o5n_aa_chain 4kwm_aa_pdb_site 4kwm_aa_RSA ... 6ii9_aa_chain rmsd_h3h5 rmsd_h3h7 rmsd_h5h7 h3_site h3_wt_aa h5_site h5_wt_aa h7_site h7_wt_aa
0 9 P - P 9 1.084277 - A -1 1.140252 ... NaN 9.167400 NaN NaN 9 S 9 K NaN NaN
1 10 G - G 10 0.150962 - A 0 0.175962 ... NaN 8.157247 NaN NaN 10 T 10 S NaN NaN
2 11 A D D 11 0.050388 E A 1 0.097927 ... A 5.040040 2.984626 2.886615 11 A 11 D 11 D
3 12 T K Q 12 0.268605 E A 2 0.216889 ... A 3.937602 1.626754 3.384350 12 T 12 Q 12 K
4 13 L I I 13 0.0 E A 3 0.0 ... A 3.687798 1.734039 2.549524 13 L 13 I 13 I

5 rows × 25 columns

Sanity check: RSA correlation between structures

In [10]:
def plot_rsa_correlation(aln_df, pdb_x, pdb_y, color_col, colors, color_title=None):
    """
    Scatter the RSA of two structures against each other, annotated with Pearson r.

    ``{pdb_x}_aa_RSA`` goes on the y axis and ``{pdb_y}_aa_RSA`` on the x axis. A
    ``same_wildtype`` column is added: 'Amino acid conserved' where
    ``{pdb_x}_aa`` equals ``{pdb_y}_aa``, else 'Amino acid changed'. Points are
    colored by ``color_col`` using the ``colors`` mapping (domain -> hex color).
    The correlation is computed on the input rows, before duplicate rows are dropped.
    """
    rsa_x, rsa_y = f'{pdb_x}_aa_RSA', f'{pdb_y}_aa_RSA'
    aa_x, aa_y = f'{pdb_x}_aa', f'{pdb_y}_aa'

    r_value = aln_df[rsa_x].corr(aln_df[rsa_y])
    r_text = f"r = {r_value:.2f}"

    plot_df = aln_df.assign(
        same_wildtype=lambda d: np.where(
            d[aa_x] == d[aa_y], 'Amino acid conserved', 'Amino acid changed'
        ),
    ).drop_duplicates()

    color_scale = alt.Scale(domain=list(colors.keys()), range=list(colors.values()))
    scatter = (
        alt.Chart(plot_df)
        .mark_circle(size=35, opacity=1, stroke='black', strokeWidth=0.5)
        .encode(
            y=alt.Y(rsa_x, title=f'RSA in PDB {pdb_x.upper()}'),
            x=alt.X(rsa_y, title=f'RSA in PDB {pdb_y.upper()}'),
            color=alt.Color(color_col, title=color_title, scale=color_scale),
            tooltip=['struct_site', aa_x, aa_y, rsa_x, rsa_y, color_col],
        )
        .properties(width=175, height=175)
    )

    # Fixed-pixel text mark in the top-left corner showing r.
    label = (
        alt.Chart(pd.DataFrame({'text': [r_text]}))
        .mark_text(align='left', baseline='top', fontSize=16, fontWeight='normal', color='black')
        .encode(text='text:N', x=alt.value(5), y=alt.value(5))
    )

    return scatter + label

# Legend colors for whether the two structures share the residue at a site.
colors = {
    'Amino acid changed': '#5484AF',
    'Amino acid conserved': '#E04948',
}

rsa_panels = [
    plot_rsa_correlation(aln_out, left, right, 'same_wildtype', colors)
    for left, right in (('4o5n', '4kwm'), ('4o5n', '6ii9'), ('4kwm', '6ii9'))
]
rsa_panels[0] | rsa_panels[1] | rsa_panels[2]
Out[10]: