Skip to content

Commit 2f3c512

Browse files
committed
Add dot plots to bar plots for CASP15 docking success rates
1 parent 683d7e7 commit 2f3c512

5 files changed

+106
-3
lines changed
267 KB
Loading
268 KB
Loading

notebooks/casp15_inference_results_plotting.ipynb

Lines changed: 106 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
"import numpy as np\n",
2828
"import pandas as pd\n",
2929
"import seaborn as sns\n",
30+
"from matplotlib.ticker import FuncFormatter\n",
3031
"\n",
3132
"from posebench.analysis.inference_analysis_casp import (\n",
3233
" CASP_BUST_TEST_COLUMNS,\n",
@@ -304,7 +305,18 @@
304305
" :param method: Method name.\n",
305306
" :return: Method category.\n",
306307
" \"\"\"\n",
307-
" return method_category_mapping.get(method, \"DL-based blind\")"
308+
" return method_category_mapping.get(method, \"DL-based blind\")\n",
309+
"\n",
310+
"\n",
311+
"def percent_angstrom_formatter(x, pos):\n",
312+
" \"\"\"\n",
313+
" Format function for percent/angstrom axis.\n",
314+
"\n",
315+
" :param x: Value.\n",
316+
" :param pos: Position.\n",
317+
" :return: Formatted string.\n",
318+
" \"\"\"\n",
319+
" return f\"{x:.0f}% / Å\""
308320
]
309321
},
310322
{
@@ -799,17 +811,108 @@
799811
" width=bar_width,\n",
800812
" )\n",
801813
"\n",
814+
" # extract raw RMSD values for each method and condition\n",
815+
" for method_idx, method in enumerate(method_mapping.values()):\n",
816+
" # get unrelaxed RMSD values grouped by target\n",
817+
" unrelaxed_rmsd_by_target = {}\n",
818+
" relaxed_rmsd_by_target = {}\n",
819+
"\n",
820+
" for repeat_index in range(1, max_num_repeats_per_method + 1):\n",
821+
" # unrelaxed data\n",
822+
" casp15_unrelaxed = (\n",
823+
" globals()[f\"scoring_results_table_{repeat_index}\"][\n",
824+
" (\n",
825+
" globals()[f\"scoring_results_table_{repeat_index}\"][\"dataset\"]\n",
826+
" == \"CASP15 set\"\n",
827+
" )\n",
828+
" & (\n",
829+
" globals()[f\"scoring_results_table_{repeat_index}\"][\"post-processing\"]\n",
830+
" == \"none\"\n",
831+
" )\n",
832+
" & (globals()[f\"scoring_results_table_{repeat_index}\"][\"method\"] == method)\n",
833+
" ]\n",
834+
" .groupby(\"target\")\n",
835+
" .agg({\"rmsd\": \"mean\"})\n",
836+
" )\n",
837+
"\n",
838+
" # relaxed data\n",
839+
" casp15_relaxed = (\n",
840+
" globals()[f\"scoring_results_table_{repeat_index}\"][\n",
841+
" (\n",
842+
" globals()[f\"scoring_results_table_{repeat_index}\"][\"dataset\"]\n",
843+
" == \"CASP15 set\"\n",
844+
" )\n",
845+
" & (\n",
846+
" globals()[f\"scoring_results_table_{repeat_index}\"][\"post-processing\"]\n",
847+
" == \"energy minimization\"\n",
848+
" )\n",
849+
" & (globals()[f\"scoring_results_table_{repeat_index}\"][\"method\"] == method)\n",
850+
" ]\n",
851+
" .groupby(\"target\")\n",
852+
" .agg({\"rmsd\": \"mean\"})\n",
853+
" )\n",
854+
"\n",
855+
" # accumulate values by target\n",
856+
" for target, rmsd_value in casp15_unrelaxed.iterrows():\n",
857+
" if target not in unrelaxed_rmsd_by_target:\n",
858+
" unrelaxed_rmsd_by_target[target] = []\n",
859+
" unrelaxed_rmsd_by_target[target].append(rmsd_value[\"rmsd\"])\n",
860+
"\n",
861+
" for target, rmsd_value in casp15_relaxed.iterrows():\n",
862+
" if target not in relaxed_rmsd_by_target:\n",
863+
" relaxed_rmsd_by_target[target] = []\n",
864+
" relaxed_rmsd_by_target[target].append(rmsd_value[\"rmsd\"])\n",
865+
"\n",
866+
" # calculate average RMSD across repeats for each target\n",
867+
" unrelaxed_rmsd_averages = [\n",
868+
" np.mean(values) for values in unrelaxed_rmsd_by_target.values()\n",
869+
" ]\n",
870+
" relaxed_rmsd_averages = [np.mean(values) for values in relaxed_rmsd_by_target.values()]\n",
871+
"\n",
872+
" # overlay unrelaxed RMSD points (averaged per target)\n",
873+
" if len(unrelaxed_rmsd_averages) > 0:\n",
874+
" # add small random jitter for better visibility when points overlap\n",
875+
" x_positions = np.random.normal(r1[method_idx], 0.05, len(unrelaxed_rmsd_averages))\n",
876+
" # clamp RMSD values at 100\n",
877+
" clamped_rmsd = [min(val, 100) for val in unrelaxed_rmsd_averages]\n",
878+
" axis.scatter(\n",
879+
" x_positions,\n",
880+
" clamped_rmsd,\n",
881+
" alpha=0.6,\n",
882+
" s=20,\n",
883+
" color=\"darkred\",\n",
884+
" edgecolors=\"black\",\n",
885+
" linewidth=0.5,\n",
886+
" zorder=10,\n",
887+
" ) # higher zorder to appear on top\n",
888+
"\n",
889+
" # overlay relaxed RMSD points (averaged per target)\n",
890+
" if len(relaxed_rmsd_averages) > 0:\n",
891+
" x_positions = np.random.normal(r2[method_idx], 0.05, len(relaxed_rmsd_averages))\n",
892+
" # clamp RMSD values at 100\n",
893+
" clamped_rmsd = [min(val, 100) for val in relaxed_rmsd_averages]\n",
894+
" axis.scatter(\n",
895+
" x_positions,\n",
896+
" clamped_rmsd,\n",
897+
" alpha=0.6,\n",
898+
" s=20,\n",
899+
" color=\"purple\",\n",
900+
" edgecolors=\"black\",\n",
901+
" linewidth=0.5,\n",
902+
" zorder=10,\n",
903+
" )\n",
904+
"\n",
802905
" # add labels, titles, ticks, etc.\n",
803906
" axis.set_xlabel(f\"{complex_type.title()}-ligand blind docking ({complex_license})\")\n",
804-
" axis.set_ylabel(\"Percentage of predictions\")\n",
907+
" axis.set_ylabel(\"Percentage of predictions / RMSD (Å)\")\n",
805908
" axis.set_xlim(1, 23 + 0.1)\n",
806909
" axis.set_ylim(0, 125)\n",
807910
"\n",
808911
" axis.bar_label(casp15_rmsd_lt2_bar, fmt=\"{:,.1f}\", label_type=\"center\")\n",
809912
" axis.bar_label(casp15_relaxed_rmsd_lt_2_bar, fmt=\"{:,.1f}\", label_type=\"center\")\n",
810913
" axis.bar_label(casp15_plif_wm_bar, fmt=\"{:,.1f}\", label_type=\"center\")\n",
811914
"\n",
812-
" axis.yaxis.set_major_formatter(mtick.PercentFormatter())\n",
915+
" axis.yaxis.set_major_formatter(FuncFormatter(percent_angstrom_formatter))\n",
813916
"\n",
814917
" axis.set_yticks([0, 20, 40, 60, 80, 100])\n",
815918
" axis.axhline(y=0, color=\"#EAEFF8\")\n",
264 KB
Loading
276 KB
Loading

0 commit comments

Comments
 (0)