Skip to content

Commit ec2e5e9

Browse files
committed
MDS plotting code
1 parent c17bc36 commit ec2e5e9

File tree

1 file changed

+142
-0
lines changed

1 file changed

+142
-0
lines changed

bin/mds_structures.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
"""
2+
Run MDS on structures to create an embedding visualization
3+
4+
Coloring options:
5+
* training TM similarity
6+
* scTM
7+
* helix/beta strand annotations
8+
* length
9+
"""
10+
11+
import os
12+
import json
13+
import logging
14+
from glob import glob
15+
import argparse
16+
17+
18+
import pandas as pd
19+
from sklearn.manifold import MDS
20+
from matplotlib import pyplot as plt
21+
22+
from hclust_structures import get_pairwise_tmscores, int_getter
23+
from annot_secondary_structures import count_structures_in_pdb
24+
25+
# :)
26+
# Fixed random seed for reproducible embeddings. The hex digits below are the
# ASCII bytes of a short quoted sentence; they are parsed as one (very large)
# hexadecimal float, reduced modulo 10000 and truncated to an int.
_SEED_HEX = "2254616977616e2069732061206672656520636f756e74727922"
SEED = int(float.fromhex(_SEED_HEX) % 10000)
29+
30+
31+
def len_pdb_structure(fname: str) -> int:
    """Return the integer length of the PDB structure

    The length is taken from the sixth whitespace-separated token of the last
    ``ATOM`` line and cross-checked against the number of ``ATOM`` lines
    divided by three.

    Args:
        fname: Path of the PDB file to read.

    Returns:
        The integer parsed from the last ``ATOM`` line.

    Raises:
        ValueError: If the file contains no ``ATOM`` lines, or if the parsed
            value disagrees with ``(number of ATOM lines) // 3``.
    """
    with open(fname) as source:
        atom_lines = [line.strip() for line in source if line.startswith("ATOM")]

    # Fail with a clear message instead of an IndexError on atom_lines[-1].
    if not atom_lines:
        raise ValueError(f"No ATOM records found in {fname}")

    # NOTE(review): token 5 is presumably the residue sequence number, which
    # requires a chain ID column to be present and fields not to run
    # together - confirm against the files this is used on.
    last_line_tokens = atom_lines[-1].split()
    last_line_l = int(last_line_tokens[5])

    # Explicit check rather than `assert`, so it still runs under `python -O`.
    # NOTE(review): the factor of three presumably reflects three atoms per
    # residue (backbone only) - confirm.
    expected_l = len(atom_lines) // 3
    if expected_l != last_line_l:
        raise ValueError(
            f"{fname}: last ATOM line reports {last_line_l} but "
            f"{len(atom_lines)} ATOM lines imply {expected_l}"
        )
    return last_line_l
39+
40+
41+
def build_parser():
    """Construct the command-line parser for this script.

    The module docstring is used as the usage text, and defaults are shown in
    the help output.

    Returns:
        An ``argparse.ArgumentParser`` with one positional argument
        (``dirname``) and the options ``--sctm``, ``--trainingtm`` and
        ``-o/--output``.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        usage=__doc__,
    )
    cli.add_argument("dirname", help="Directory containing PDB files", type=str)

    # Both JSON inputs are optional string paths; "" means "not provided".
    json_options = (
        ("--sctm", "scTM scores JSON file"),
        ("--trainingtm", "Training TM score JSON"),
    )
    for flag, description in json_options:
        cli.add_argument(flag, type=str, default="", help=description)

    cli.add_argument(
        "-o",
        "--output",
        help="PDF file prefix to write output to",
        default="tmscore_mds",
        type=str,
    )
    return cli
58+
59+
60+
def main():
    """Run script

    Parses the command line, collects the ``*.pdb`` files in ``args.dirname``,
    and embeds them in two dimensions with MDS, using the matrix returned by
    ``get_pairwise_tmscores`` as a precomputed dissimilarity. One scatter plot
    per coloring is then written to ``{output}_mds_{coloring}.pdf``.

    Colorings backed by a JSON file (``--trainingtm``, ``--sctm``) are skipped
    when the corresponding option is left empty.

    Raises:
        ValueError: If a JSON-backed coloring is given a non-empty value that
            is not an existing file.
    """
    parser = build_parser()
    args = parser.parse_args()

    # Get files
    fnames = sorted(
        glob(os.path.join(args.dirname, "*.pdb")),
        key=lambda x: int_getter(os.path.basename(x)),
    )
    logging.info(f"Computing TMscore on {len(fnames)} structures")

    pdist_df = get_pairwise_tmscores(fnames, sctm_scores_json=args.sctm)
    mds = MDS(n_components=2, dissimilarity="precomputed", n_jobs=-1, random_state=SEED)
    embedding = pd.DataFrame(mds.fit_transform(pdist_df.values), index=pdist_df.index)

    def fname_to_key(f: str) -> str:
        """Return the basename of ``f`` up to its first dot."""
        return os.path.basename(f).split(".")[0]

    # Colorbar tick format overrides, keyed by coloring name.
    format_strings = {
        "Number helices": "{x:.1f}",
    }
    # For a variety of coloring keys, compute/read the scores and color scatter
    # plot by the scores. Values are None (no coloring), a path to a JSON file
    # of scores, or a callable mapping a PDB file name to a score.
    colorings = {
        "null": None,
        "Max training TM": args.trainingtm,
        "scTM": args.sctm,
        "length": len_pdb_structure,
        "Number helices": lambda x: count_structures_in_pdb(x, "psea")[0],
        "Number sheets": lambda x: count_structures_in_pdb(x, "psea")[1],
    }
    for k, v in colorings.items():
        # An empty string means the optional JSON file was not supplied.
        if v is not None and not v:
            continue

        logging.info(f"Coloring by {k} scores")
        if v is None:
            scores = None
        elif callable(v):
            # Only score files whose key appears in the embedding index, then
            # align the scores to the embedding's row order.
            scores = {
                fname_to_key(f): v(f)
                for f in fnames
                if fname_to_key(f) in embedding.index
            }
            scores = embedding.index.map(scores)
        elif os.path.isfile(v):
            with open(v) as source:
                scores = embedding.index.map(json.load(source))
        else:
            raise ValueError(f"Invalid value for {k}: {v}")

        fig, ax = plt.subplots(dpi=300)
        points = ax.scatter(
            embedding.iloc[:, 0],
            embedding.iloc[:, 1],
            s=8,
            c=scores,
            cmap="RdYlBu",
            alpha=0.9,
        )
        ax.set(
            xlabel="MDS 1",
            ylabel="MDS 2",
        )
        if k != "null":
            ax.set(
                xticks=[],
                yticks=[],
                title=k,
            )
        if scores is not None:
            cbar = plt.colorbar(
                points,
                ax=ax,
                fraction=0.08,
                pad=0.04,
                location="right",
                format=format_strings.get(k, None),
            )
            cbar.ax.set_ylabel(k, fontsize=12)

        fig.savefig(f"{args.output}_mds_{k}.pdf", bbox_inches="tight")
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so figures do not accumulate across colorings.
        plt.close(fig)
138+
139+
140+
if __name__ == "__main__":
    # Show INFO-level progress messages when executed as a script.
    logging.basicConfig(level=logging.INFO)
    main()

0 commit comments

Comments
 (0)