|
6 | 6 | "metadata": {}, |
7 | 7 | "outputs": [], |
8 | 8 | "source": [ |
| 9 | + "import numpy as np\n", |
| 10 | + "\n", |
9 | 11 | "%load_ext autoreload\n", |
10 | 12 | "%autoreload 2" |
11 | 13 | ] |
|
18 | 20 | "source": [ |
19 | 21 | "import logging\n", |
20 | 22 | "\n", |
21 | | - "from kyle.evaluation import (\n", |
22 | | - " EvalStats,\n", |
23 | | - " compute_accuracy,\n", |
24 | | - " compute_ECE,\n", |
25 | | - ")\n", |
| 23 | + "from kyle.evaluation import EvalStats\n", |
26 | 24 | "from kyle.sampling.fake_clf import DirichletFC\n", |
27 | 25 | "from kyle.transformations import *\n", |
| 26 | + "import matplotlib.pyplot as plt\n", |
28 | 27 | "\n", |
29 | 28 | "logging.basicConfig(level=logging.INFO)" |
30 | 29 | ] |
|
71 | 70 | "outputs": [], |
72 | 71 | "source": [ |
73 | 72 | "print(\n", |
74 | | - " \"mostly underestimating all classes (starting at 1/n_classes) with PowerLawSimplexAut\"\n", |
| 73 | + " \"mostly overestimating all classes (starting at 1/n_classes) with PowerLawSimplexAut\"\n", |
75 | 74 | ")\n", |
76 | 75 | "transform = PowerLawSimplexAut(np.array([2, 2, 2]))\n", |
77 | 76 | "dirichlet_fc.set_simplex_automorphism(transform)\n", |
| 77 | + "\n", |
| 78 | + "\n", |
78 | 79 | "eval_stats = EvalStats(*dirichlet_fc.get_sample_arrays(n_samples))\n", |
79 | 80 | "\n", |
80 | 81 | "print(f\"Accuracy is {eval_stats.accuracy()}\")\n", |
81 | 82 | "print(f\"ECE is {eval_stats.expected_calibration_error(n_bins=200)}\")\n", |
82 | | - "ece_approx = -eval_stats.expected_confidence() + eval_stats.accuracy()\n", |
| 83 | + "ece_approx = eval_stats.expected_confidence() - eval_stats.accuracy()\n", |
83 | 84 | "print(f\"{ece_approx=}\")\n", |
84 | | - "eval_stats.plot_reliability_curves([0, 1, \"top_class\"], display_weights=True)\n", |
| 85 | + "eval_stats.plot_reliability_curves(\n", |
| 86 | + " [0, 1, \"top_class\"], display_weights=True, n_bins=200\n", |
| 87 | + ")\n", |
| 88 | + "plt.show()\n", |
85 | 89 | "\n", |
86 | 90 | "\n", |
87 | | - "theoretical_acc = compute_accuracy(dirichlet_fc)[0]\n", |
88 | | - "theoretical_ece = compute_ECE(dirichlet_fc)[0]\n", |
89 | | - "print(f\"{theoretical_acc=} , {theoretical_ece=}\")" |
| 91 | + "# theoretical_acc = compute_accuracy(dirichlet_fc)[0]\n", |
| 92 | + "# theoretical_ece = compute_ECE(dirichlet_fc)[0]\n", |
| 93 | + "# print(f\"{theoretical_acc=} , {theoretical_ece=}\")" |
90 | 94 | ] |
91 | 95 | }, |
92 | 96 | { |
|
96 | 100 | "outputs": [], |
97 | 101 | "source": [ |
98 | 102 | "print(\n", |
99 | | - " \"mostly overestimating all classes (starting at 1/n_classes) with PowerLawSimplexAut\"\n", |
| 103 | + " \"mostly underestimating all classes (starting at 1/n_classes) with PowerLawSimplexAut\"\n", |
100 | 104 | ")\n", |
101 | 105 | "print(\"Note the variance and the resulting sensitivity to binning\")\n", |
102 | 106 | "\n", |
|
106 | 110 | "\n", |
107 | 111 | "print(f\"Accuracy is {eval_stats.accuracy()}\")\n", |
108 | 112 | "print(f\"ECE is {eval_stats.expected_calibration_error()}\")\n", |
109 | | - "ece_approx = eval_stats.expected_confidence() - eval_stats.accuracy()\n", |
| 113 | + "ece_approx = -eval_stats.expected_confidence() + eval_stats.accuracy()\n", |
110 | 114 | "print(f\"{ece_approx=}\")\n", |
111 | 115 | "eval_stats.plot_reliability_curves([0, 1, \"top_class\"], display_weights=True)\n", |
| 116 | + "plt.show()\n", |
112 | 117 | "\n", |
113 | 118 | "\n", |
114 | 119 | "# theoretical_acc = compute_accuracy(dirichlet_fc)[0]\n", |
|
122 | 127 | "metadata": {}, |
123 | 128 | "outputs": [], |
124 | 129 | "source": [ |
125 | | - "print(\"Overestimating predictions with MaxComponent\")\n", |
| 130 | + "print(\"Underestimating predictions with MaxComponent\")\n", |
126 | 131 | "\n", |
127 | 132 | "\n", |
128 | 133 | "def overestimating_max(x: np.ndarray):\n", |
|
139 | 144 | "print(f\"Accuracy is {eval_stats.accuracy()}\")\n", |
140 | 145 | "print(f\"ECE is {eval_stats.expected_calibration_error()}\")\n", |
141 | 146 | "eval_stats.plot_reliability_curves([0, 1, \"top_class\"], display_weights=True)\n", |
| 147 | + "plt.show()\n", |
142 | 148 | "\n", |
143 | 149 | "# Integrals converge pretty slowly, this takes time\n", |
144 | 150 | "# theoretical_acc = compute_accuracy(dirichlet_fc, opts={\"limit\": 75})[0]\n", |
|
188 | 194 | "metadata": {}, |
189 | 195 | "outputs": [], |
190 | 196 | "source": [ |
191 | | - "print(\"mostly underestimating first two classes with RestrictedPowerSimplexAut\")\n", |
| 197 | + "print(\"mostly overestimating first two classes with RestrictedPowerSimplexAut\")\n", |
192 | 198 | "\n", |
193 | 199 | "transform = RestrictedPowerSimplexAut(np.array([2, 4]))\n", |
194 | 200 | "dirichlet_fc.set_simplex_automorphism(transform)\n", |
|
199 | 205 | "print(\"Theoretical approximation of ECE\")\n", |
200 | 206 | "print(eval_stats.expected_confidence() - eval_stats.accuracy())\n", |
201 | 207 | "eval_stats.plot_reliability_curves([0, 1, 2, \"top_class\"], display_weights=True)\n", |
| 208 | + "plt.show()\n", |
202 | 209 | "\n", |
203 | 210 | "\n", |
204 | 211 | "# theoretical_acc = compute_accuracy(dirichlet_fc)[0]\n", |
205 | 212 | "# theoretical_ece = compute_ECE(dirichlet_fc)[0]\n", |
206 | 213 | "# print(f\"{theoretical_acc=} , {theoretical_ece=}\")" |
207 | 214 | ] |
208 | 215 | }, |
209 | | - { |
210 | | - "cell_type": "code", |
211 | | - "execution_count": null, |
212 | | - "metadata": {}, |
213 | | - "outputs": [], |
214 | | - "source": [ |
215 | | - "print(\n", |
216 | | - " f\"\"\"\n", |
217 | | - "NOTE: here the ECE completely fails to converge to it's true, continuous value.\n", |
218 | | - "This is probably due to the binning-variance, see plots below with 500 bins.\n", |
219 | | - "The sharp peak in weights at the end certainly does not help convergence either.\n", |
220 | | - "\"\"\"\n", |
221 | | - ")\n", |
222 | | - "\n", |
223 | | - "eval_stats.plot_reliability_curves([\"top_class\"], n_bins=500, display_weights=True)" |
224 | | - ] |
225 | | - }, |
226 | 216 | { |
227 | 217 | "cell_type": "markdown", |
228 | 218 | "metadata": {}, |
|
265 | 255 | "\n", |
266 | 256 | "print(f\"Accuracy is {eval_stats.accuracy()}\")\n", |
267 | 257 | "print(f\"ECE is {eval_stats.expected_calibration_error(n_bins=200)}\")\n", |
268 | | - "eval_stats.plot_reliability_curves([0, \"top_class\"], display_weights=True)" |
| 258 | + "eval_stats.plot_reliability_curves([0, \"top_class\"], display_weights=True)\n", |
| 259 | + "plt.show()" |
269 | 260 | ] |
270 | 261 | }, |
271 | 262 | { |
|
290 | 281 | "\n", |
291 | 282 | "print(f\"Accuracy is {eval_stats.accuracy()}\")\n", |
292 | 283 | "print(f\"ECE is {eval_stats.expected_calibration_error(n_bins=200)}\")\n", |
293 | | - "eval_stats.plot_reliability_curves([4, \"top_class\"], display_weights=True)" |
| 284 | + "eval_stats.plot_reliability_curves([4, \"top_class\"], display_weights=True)\n", |
| 285 | + "plt.show()" |
294 | 286 | ] |
295 | 287 | }, |
296 | 288 | { |
|
313 | 305 | "\n", |
314 | 306 | "print(f\"Accuracy is {eval_stats.accuracy()}\")\n", |
315 | 307 | "print(f\"ECE is {eval_stats.expected_calibration_error()}\")\n", |
316 | | - "eval_stats.plot_reliability_curves([4, \"top_class\"], display_weights=True)" |
| 308 | + "eval_stats.plot_reliability_curves([4, \"top_class\"], display_weights=True)\n", |
| 309 | + "plt.show()" |
317 | 310 | ] |
318 | 311 | } |
319 | 312 | ], |
|
0 commit comments