MBP Protein NMR Example

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import maxent

# Global plotting look: seaborn "paper" context on a white grid, with tick
# marks enabled on the bottom and left axes and drawn in dark gray.
sns.set_context("paper")
sns.set_style(
    "whitegrid",
    rc={
        "xtick.bottom": True,
        "ytick.left": True,
        "xtick.color": "#333333",
        "ytick.color": "#333333",
    },
)
# plt.rcParams["font.family"] = "serif"
plt.rcParams.update({"mathtext.fontset": "dejavuserif"})
# Palette of five hex colors.
colors = ["#1b9e77", "#d95f02", "#7570b3", "#e7298a", "#66a61e"]
import pynmrstar
from functools import partialmethod
from tqdm import tqdm

# Silence tqdm progress bars: rebind tqdm.__init__ so that every bar is
# constructed with disable=True.
tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)
# Fetch BMRB entry 20062 and pull the chemical-shift columns of interest out
# of every "Atom_chem_shift" loop it contains.
bmrb = pynmrstar.Entry.from_database(20062, convert_data_types=True)
cs_result_sets = [
    shift_loop.get_tag(
        ["Comp_index_ID", "Comp_ID", "Atom_ID", "Atom_type", "Val", "Val_err"]
    )
    for shift_loop in bmrb.get_loops_by_category("Atom_chem_shift")
]
# Only the first loop is used as the reference table.
ref_data = pd.DataFrame(
    data=cs_result_sets[0],
    columns=["id", "res", "atom", "type", "shift", "error"],
)

# Residue ids of the rows whose atom name is "H", followed by a preview of
# those rows (the trailing expression is what the notebook displays).
ref_resids = ref_data.loc[ref_data["atom"] == "H", "id"].values
ref_data.loc[ref_data["atom"] == "H"].head(25)
id res atom type shift error
0 1 GLU H H 8.131 0.001
6 2 ASN H H 8.763 0.001
19 4 VAL H H 7.801 0.003
28 5 VAL H H 7.584 0.001
37 6 HIS H H 8.130 0.004
43 7 PHE H H 7.922 0.001
50 8 PHE H H 8.286 0.002
57 9 LYS H H 8.160 0.001
70 10 ASN H H 8.264 0.001
76 11 ILE H H 7.735 0.001
87 12 VAL H H 7.807 0.004
96 13 THR H H 7.951 0.002
109 15 ARG H H 8.013 0.001
118 16 THR H H 7.888 0.001
# Set to True to rebuild mbp_files/mbp_cs.npz from ./cs.csv; when False the
# existing .npz file is loaded as-is.
HAVE_MD_FILE = False

ref_h_rows = ref_data[ref_data.atom == "H"]
# cut GLU because proton type mismatch (drops the first "H" row)
ref_hdata = ref_h_rows["shift"].values[1:].astype(float)
resnames = ref_h_rows.res[1:]
if HAVE_MD_FILE:
    data = pd.read_csv("./cs.csv")
    n_frames = len(data.frame.unique())
    # only need weights, so we extract only shifts that will be biased
    hdata_df = data[data.names == "HN"]
    hdata_df = hdata_df[hdata_df["resids"].isin(ref_resids)]
    # Reshape to (n_frames, -1): one row per frame.
    # NOTE(review): assumes the CSV rows are grouped by frame — confirm.
    hdata_c = hdata_df.confident.values.reshape(n_frames, -1)
    hdata = hdata_df.peaks.values.reshape(n_frames, -1)
    # Explicit raise instead of assert so the check survives `python -O`.
    if hdata.shape[-1] != ref_hdata.shape[0]:
        raise ValueError(
            f"cs.csv yields {hdata.shape[-1]} shifts per frame but the "
            f"reference has {ref_hdata.shape[0]}"
        )
    np.savez("mbp_files/mbp_cs.npz", hdata=hdata, hdata_c=hdata_c)
# np.load returns an NpzFile that holds the archive open; the context manager
# closes it once both arrays have been read into memory.
with np.load("mbp_files/mbp_cs.npz") as data:
    hdata, hdata_c = data["hdata"], data["hdata_c"]
# Unweighted average of hdata over its first axis, drawn against the
# reference shifts.
prior_mean = hdata.mean(axis=0)
plt.plot(prior_mean, "o-")
plt.plot(ref_hdata, "o-")
plt.show()
../_images/MBP_4_0.png
# fill in unconfident peaks with the per-column mean of the confident ones
# Coerce to bool so that ~ is a logical NOT even if the mask was stored as 0/1.
confident = hdata_c.astype(bool)
# where= skips unconfident entries entirely, so a NaN placeholder in such a
# cell cannot leak into the mean. A column with no confident entry still
# evaluates to 0/0 = NaN.
hdata_m = np.sum(hdata, axis=0, where=confident) / np.sum(confident, axis=0)
unconfident = ~confident
# In-place overwrite of hdata; hdata_m broadcasts along the first axis.
np.copyto(hdata, hdata_m, where=unconfident)
total_fill = int(np.count_nonzero(unconfident))
print("Filled", total_fill)
Filled 0
# Same comparison as before the fill step: column-wise mean of hdata against
# the reference shifts.
filled_mean = hdata.mean(axis=0)
plt.plot(filled_mean, "o-")
plt.plot(ref_hdata, "o-")
plt.show()
../_images/MBP_6_0.png
# One restraint per index in the first half of ref_hdata; the remaining
# indices are left unrestrained.
do_restrain = range(len(ref_hdata) // 2)
restraints = [
    maxent.Restraint(
        lambda h, i=i: h[i],  # i=i freezes the index for this restraint
        ref_hdata[i],
        prior=maxent.Laplace(0.05),
    )
    for i in do_restrain
]
# Fit the model on the per-frame shift table.
model = maxent.MaxentModel(restraints)
model.compile(tf.keras.optimizers.Adam(0.1), "mean_squared_error")
history = model.fit(hdata, epochs=500, verbose=0)
# Loss recorded during fitting, followed by its final value.
loss_curve = history.history["loss"]
plt.plot(loss_curve)
print(loss_curve[-1])
3.694961427347073e-15
../_images/MBP_9_1.png
# Mean absolute difference between the weighted column sums of hdata and the
# reference shifts, taken over every column (restrained or not).
reweighted = np.sum(hdata * model.traj_weights[..., np.newaxis], axis=0)
np.mean(np.abs(reweighted - ref_hdata))
0.1292930309589092
# Display the fitted lambda values held by the model (the recorded output is
# a float32 variable of shape (6,)).
model.lambdas
<tf.Variable 'maxent-lambda:0' shape=(6,) dtype=float32, numpy=
array([-3.4798522 ,  4.554573  ,  0.15842113,  1.4714801 , -0.72251976,
        0.86921126], dtype=float32)>
# Plot the fitted weights in model.traj_weights against their index.
plt.plot(model.traj_weights)
[<matplotlib.lines.Line2D at 0x7f8b9888d3a0>]
../_images/MBP_12_1.png
# Summary figure: weighted ("Posterior") and unweighted ("Prior") column
# averages of hdata against the reference shifts, saved to protein.pdf.
plt.figure(figsize=(3, 2), dpi=300)
# Three-letter residue name -> one-letter code, used for the x tick labels.
seq_dict = {
    "CYS": "C",
    "ASP": "D",
    "SER": "S",
    "GLN": "Q",
    "LYS": "K",
    "ILE": "I",
    "PRO": "P",
    "THR": "T",
    "PHE": "F",
    "ASN": "N",
    "GLY": "G",
    "HIS": "H",
    "LEU": "L",
    "ARG": "R",
    "TRP": "W",
    "ALA": "A",
    "VAL": "V",
    "GLU": "E",
    "TYR": "Y",
    "MET": "M",
}
# Index of the first unrestrained column (same split as the restraint range).
n_biased = len(ref_hdata) // 2
plt.plot(
    np.sum(hdata * model.traj_weights[..., np.newaxis], axis=0), "o-", label="Posterior"
)
plt.plot(np.mean(hdata, axis=0), "o-", label="Prior")
plt.plot(ref_hdata, "*", label="Experiment")
# Dashed divider halfway between the last biased and first unbiased index.
plt.axvline(x=n_biased - 0.5, color="gray", linestyle="--")
plt.xticks(range(len(ref_hdata)), [seq_dict[r] for r in resnames])
plt.legend(loc="center left", bbox_to_anchor=(1.0, 0.8))
plt.text(len(ref_hdata) // 5, 8.55, "Biased")
plt.text(n_biased, 8.55, "Unbiased")
plt.xlabel("Sequence")
plt.ylabel("Chemical Shift [ppm]")
# The legend is anchored outside the axes; a tight bounding box keeps it from
# being cropped out of the saved file.
plt.savefig("protein.pdf", bbox_inches="tight")
../_images/MBP_13_0.png
# np.argsort sorts ascending, so the last three indices point at the three
# largest entries of model.traj_weights, listed from third-largest to largest.
print("most favored clusters", np.argsort(model.traj_weights)[-3:])
most favored clusters [2158 2221 2147]