|
27 | 27 | "import numpy as np\n", |
28 | 28 | "import pandas as pd\n", |
29 | 29 | "import seaborn as sns\n", |
| 30 | + "from matplotlib.ticker import FuncFormatter\n", |
30 | 31 | "\n", |
31 | 32 | "from posebench.analysis.inference_analysis_casp import (\n", |
32 | 33 | " CASP_BUST_TEST_COLUMNS,\n", |
|
304 | 305 | " :param method: Method name.\n", |
305 | 306 | " :return: Method category.\n", |
306 | 307 | " \"\"\"\n", |
307 | | - " return method_category_mapping.get(method, \"DL-based blind\")" |
| 308 | + " return method_category_mapping.get(method, \"DL-based blind\")\n", |
| 309 | + "\n", |
| 310 | + "\n", |
| 311 | + "def percent_angstrom_formatter(x, pos):\n", |
| 312 | + " \"\"\"\n", |
| 313 | + " Format a tick value as a whole-number percent / Å label for the shared percentage/RMSD y-axis.\n", |
| 314 | + "\n", |
| 315 | + " :param x: Tick value to format.\n", |
| 316 | + " :param pos: Tick position (unused; part of the FuncFormatter callback signature).\n", |
| 317 | + " :return: Formatted tick label string.\n", |
| 318 | + " \"\"\"\n", |
| 319 | + " return f\"{x:.0f}% / Å\"" |
308 | 320 | ] |
309 | 321 | }, |
310 | 322 | { |
|
799 | 811 | " width=bar_width,\n", |
800 | 812 | " )\n", |
801 | 813 | "\n", |
| 814 | + " # collect per-target mean RMSD values for each method, both unrelaxed and relaxed\n", |
| 815 | + " for method_idx, method in enumerate(method_mapping.values()):\n", |
| 816 | + " # per-target lists of RMSD values, filled across repeats (unrelaxed and relaxed)\n", |
| 817 | + " unrelaxed_rmsd_by_target = {}\n", |
| 818 | + " relaxed_rmsd_by_target = {}\n", |
| 819 | + "\n", |
| 820 | + " for repeat_index in range(1, max_num_repeats_per_method + 1):\n", |
| 821 | + " # unrelaxed data\n", |
| 822 | + " casp15_unrelaxed = (\n", |
| 823 | + " globals()[f\"scoring_results_table_{repeat_index}\"][\n", |
| 824 | + " (\n", |
| 825 | + " globals()[f\"scoring_results_table_{repeat_index}\"][\"dataset\"]\n", |
| 826 | + " == \"CASP15 set\"\n", |
| 827 | + " )\n", |
| 828 | + " & (\n", |
| 829 | + " globals()[f\"scoring_results_table_{repeat_index}\"][\"post-processing\"]\n", |
| 830 | + " == \"none\"\n", |
| 831 | + " )\n", |
| 832 | + " & (globals()[f\"scoring_results_table_{repeat_index}\"][\"method\"] == method)\n", |
| 833 | + " ]\n", |
| 834 | + " .groupby(\"target\")\n", |
| 835 | + " .agg({\"rmsd\": \"mean\"})\n", |
| 836 | + " )\n", |
| 837 | + "\n", |
| 838 | + " # relaxed data\n", |
| 839 | + " casp15_relaxed = (\n", |
| 840 | + " globals()[f\"scoring_results_table_{repeat_index}\"][\n", |
| 841 | + " (\n", |
| 842 | + " globals()[f\"scoring_results_table_{repeat_index}\"][\"dataset\"]\n", |
| 843 | + " == \"CASP15 set\"\n", |
| 844 | + " )\n", |
| 845 | + " & (\n", |
| 846 | + " globals()[f\"scoring_results_table_{repeat_index}\"][\"post-processing\"]\n", |
| 847 | + " == \"energy minimization\"\n", |
| 848 | + " )\n", |
| 849 | + " & (globals()[f\"scoring_results_table_{repeat_index}\"][\"method\"] == method)\n", |
| 850 | + " ]\n", |
| 851 | + " .groupby(\"target\")\n", |
| 852 | + " .agg({\"rmsd\": \"mean\"})\n", |
| 853 | + " )\n", |
| 854 | + "\n", |
| 855 | + " # accumulate values by target\n", |
| 856 | + " for target, rmsd_value in casp15_unrelaxed.iterrows():\n", |
| 857 | + " if target not in unrelaxed_rmsd_by_target:\n", |
| 858 | + " unrelaxed_rmsd_by_target[target] = []\n", |
| 859 | + " unrelaxed_rmsd_by_target[target].append(rmsd_value[\"rmsd\"])\n", |
| 860 | + "\n", |
| 861 | + " for target, rmsd_value in casp15_relaxed.iterrows():\n", |
| 862 | + " if target not in relaxed_rmsd_by_target:\n", |
| 863 | + " relaxed_rmsd_by_target[target] = []\n", |
| 864 | + " relaxed_rmsd_by_target[target].append(rmsd_value[\"rmsd\"])\n", |
| 865 | + "\n", |
| 866 | + " # calculate average RMSD across repeats for each target\n", |
| 867 | + " unrelaxed_rmsd_averages = [\n", |
| 868 | + " np.mean(values) for values in unrelaxed_rmsd_by_target.values()\n", |
| 869 | + " ]\n", |
| 870 | + " relaxed_rmsd_averages = [np.mean(values) for values in relaxed_rmsd_by_target.values()]\n", |
| 871 | + "\n", |
| 872 | + " # overlay unrelaxed RMSD points (averaged per target)\n", |
| 873 | + " if len(unrelaxed_rmsd_averages) > 0:\n", |
| 874 | + " # add small random jitter for better visibility when points overlap\n", |
| 875 | + " x_positions = np.random.normal(r1[method_idx], 0.05, len(unrelaxed_rmsd_averages))\n", |
| 876 | + " # cap RMSD values at 100 (the top y-tick) so large outliers remain visible on the axis\n", |
| 877 | + " clamped_rmsd = [min(val, 100) for val in unrelaxed_rmsd_averages]\n", |
| 878 | + " axis.scatter(\n", |
| 879 | + " x_positions,\n", |
| 880 | + " clamped_rmsd,\n", |
| 881 | + " alpha=0.6,\n", |
| 882 | + " s=20,\n", |
| 883 | + " color=\"darkred\",\n", |
| 884 | + " edgecolors=\"black\",\n", |
| 885 | + " linewidth=0.5,\n", |
| 886 | + " zorder=10,\n", |
| 887 | + " ) # higher zorder to appear on top\n", |
| 888 | + "\n", |
| 889 | + " # overlay relaxed RMSD points (averaged per target)\n", |
| 890 | + " if len(relaxed_rmsd_averages) > 0:\n", |
| 891 | + " x_positions = np.random.normal(r2[method_idx], 0.05, len(relaxed_rmsd_averages))\n", |
| 892 | + " # cap RMSD values at 100 (the top y-tick) so large outliers remain visible on the axis\n", |
| 893 | + " clamped_rmsd = [min(val, 100) for val in relaxed_rmsd_averages]\n", |
| 894 | + " axis.scatter(\n", |
| 895 | + " x_positions,\n", |
| 896 | + " clamped_rmsd,\n", |
| 897 | + " alpha=0.6,\n", |
| 898 | + " s=20,\n", |
| 899 | + " color=\"purple\",\n", |
| 900 | + " edgecolors=\"black\",\n", |
| 901 | + " linewidth=0.5,\n", |
| 902 | + " zorder=10,\n", |
| 903 | + " )\n", |
| 904 | + "\n", |
802 | 905 | " # add labels, titles, ticks, etc.\n", |
803 | 906 | " axis.set_xlabel(f\"{complex_type.title()}-ligand blind docking ({complex_license})\")\n", |
804 | | - " axis.set_ylabel(\"Percentage of predictions\")\n", |
| 907 | + " axis.set_ylabel(\"Percentage of predictions / RMSD (Å)\")\n", |
805 | 908 | " axis.set_xlim(1, 23 + 0.1)\n", |
806 | 909 | " axis.set_ylim(0, 125)\n", |
807 | 910 | "\n", |
808 | 911 | " axis.bar_label(casp15_rmsd_lt2_bar, fmt=\"{:,.1f}\", label_type=\"center\")\n", |
809 | 912 | " axis.bar_label(casp15_relaxed_rmsd_lt_2_bar, fmt=\"{:,.1f}\", label_type=\"center\")\n", |
810 | 913 | " axis.bar_label(casp15_plif_wm_bar, fmt=\"{:,.1f}\", label_type=\"center\")\n", |
811 | 914 | "\n", |
812 | | - " axis.yaxis.set_major_formatter(mtick.PercentFormatter())\n", |
| 915 | + " axis.yaxis.set_major_formatter(FuncFormatter(percent_angstrom_formatter))\n", |
813 | 916 | "\n", |
814 | 917 | " axis.set_yticks([0, 20, 40, 60, 80, 100])\n", |
815 | 918 | " axis.axhline(y=0, color=\"#EAEFF8\")\n", |
|
0 commit comments