Skip to content

Commit 62918c1

Browse files
committed
Minor fixes in fake_classifiers nb
1 parent 37d5512 commit 62918c1

File tree

1 file changed

+28
-35
lines changed

1 file changed

+28
-35
lines changed

notebooks/fake_classifiers.ipynb

Lines changed: 28 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
"metadata": {},
77
"outputs": [],
88
"source": [
9+
"import numpy as np\n",
10+
"\n",
911
"%load_ext autoreload\n",
1012
"%autoreload 2"
1113
]
@@ -18,13 +20,10 @@
1820
"source": [
1921
"import logging\n",
2022
"\n",
21-
"from kyle.evaluation import (\n",
22-
" EvalStats,\n",
23-
" compute_accuracy,\n",
24-
" compute_ECE,\n",
25-
")\n",
23+
"from kyle.evaluation import EvalStats\n",
2624
"from kyle.sampling.fake_clf import DirichletFC\n",
2725
"from kyle.transformations import *\n",
26+
"import matplotlib.pyplot as plt\n",
2827
"\n",
2928
"logging.basicConfig(level=logging.INFO)"
3029
]
@@ -71,22 +70,27 @@
7170
"outputs": [],
7271
"source": [
7372
"print(\n",
74-
" \"mostly underestimating all classes (starting at 1/n_classes) with PowerLawSimplexAut\"\n",
73+
" \"mostly overestimating all classes (starting at 1/n_classes) with PowerLawSimplexAut\"\n",
7574
")\n",
7675
"transform = PowerLawSimplexAut(np.array([2, 2, 2]))\n",
7776
"dirichlet_fc.set_simplex_automorphism(transform)\n",
77+
"\n",
78+
"\n",
7879
"eval_stats = EvalStats(*dirichlet_fc.get_sample_arrays(n_samples))\n",
7980
"\n",
8081
"print(f\"Accuracy is {eval_stats.accuracy()}\")\n",
8182
"print(f\"ECE is {eval_stats.expected_calibration_error(n_bins=200)}\")\n",
82-
"ece_approx = -eval_stats.expected_confidence() + eval_stats.accuracy()\n",
83+
"ece_approx = eval_stats.expected_confidence() - eval_stats.accuracy()\n",
8384
"print(f\"{ece_approx=}\")\n",
84-
"eval_stats.plot_reliability_curves([0, 1, \"top_class\"], display_weights=True)\n",
85+
"eval_stats.plot_reliability_curves(\n",
86+
" [0, 1, \"top_class\"], display_weights=True, n_bins=200\n",
87+
")\n",
88+
"plt.show()\n",
8589
"\n",
8690
"\n",
87-
"theoretical_acc = compute_accuracy(dirichlet_fc)[0]\n",
88-
"theoretical_ece = compute_ECE(dirichlet_fc)[0]\n",
89-
"print(f\"{theoretical_acc=} , {theoretical_ece=}\")"
91+
"# theoretical_acc = compute_accuracy(dirichlet_fc)[0]\n",
92+
"# theoretical_ece = compute_ECE(dirichlet_fc)[0]\n",
93+
"# print(f\"{theoretical_acc=} , {theoretical_ece=}\")"
9094
]
9195
},
9296
{
@@ -96,7 +100,7 @@
96100
"outputs": [],
97101
"source": [
98102
"print(\n",
99-
" \"mostly overestimating all classes (starting at 1/n_classes) with PowerLawSimplexAut\"\n",
103+
" \"mostly underestimating all classes (starting at 1/n_classes) with PowerLawSimplexAut\"\n",
100104
")\n",
101105
"print(\"Note the variance and the resulting sensitivity to binning\")\n",
102106
"\n",
@@ -106,9 +110,10 @@
106110
"\n",
107111
"print(f\"Accuracy is {eval_stats.accuracy()}\")\n",
108112
"print(f\"ECE is {eval_stats.expected_calibration_error()}\")\n",
109-
"ece_approx = eval_stats.expected_confidence() - eval_stats.accuracy()\n",
113+
"ece_approx = -eval_stats.expected_confidence() + eval_stats.accuracy()\n",
110114
"print(f\"{ece_approx=}\")\n",
111115
"eval_stats.plot_reliability_curves([0, 1, \"top_class\"], display_weights=True)\n",
116+
"plt.show()\n",
112117
"\n",
113118
"\n",
114119
"# theoretical_acc = compute_accuracy(dirichlet_fc)[0]\n",
@@ -122,7 +127,7 @@
122127
"metadata": {},
123128
"outputs": [],
124129
"source": [
125-
"print(\"Overestimating predictions with MaxComponent\")\n",
130+
"print(\"Underestimating predictions with MaxComponent\")\n",
126131
"\n",
127132
"\n",
128133
"def overestimating_max(x: np.ndarray):\n",
@@ -139,6 +144,7 @@
139144
"print(f\"Accuracy is {eval_stats.accuracy()}\")\n",
140145
"print(f\"ECE is {eval_stats.expected_calibration_error()}\")\n",
141146
"eval_stats.plot_reliability_curves([0, 1, \"top_class\"], display_weights=True)\n",
147+
"plt.show()\n",
142148
"\n",
143149
"# Integrals converge pretty slowly, this takes time\n",
144150
"# theoretical_acc = compute_accuracy(dirichlet_fc, opts={\"limit\": 75})[0]\n",
@@ -188,7 +194,7 @@
188194
"metadata": {},
189195
"outputs": [],
190196
"source": [
191-
"print(\"mostly underestimating first two classes with RestrictedPowerSimplexAut\")\n",
197+
"print(\"mostly overestimating first two classes with RestrictedPowerSimplexAut\")\n",
192198
"\n",
193199
"transform = RestrictedPowerSimplexAut(np.array([2, 4]))\n",
194200
"dirichlet_fc.set_simplex_automorphism(transform)\n",
@@ -199,30 +205,14 @@
199205
"print(\"Theoretical approximation of ECE\")\n",
200206
"print(eval_stats.expected_confidence() - eval_stats.accuracy())\n",
201207
"eval_stats.plot_reliability_curves([0, 1, 2, \"top_class\"], display_weights=True)\n",
208+
"plt.show()\n",
202209
"\n",
203210
"\n",
204211
"# theoretical_acc = compute_accuracy(dirichlet_fc)[0]\n",
205212
"# theoretical_ece = compute_ECE(dirichlet_fc)[0]\n",
206213
"# print(f\"{theoretical_acc=} , {theoretical_ece=}\")"
207214
]
208215
},
209-
{
210-
"cell_type": "code",
211-
"execution_count": null,
212-
"metadata": {},
213-
"outputs": [],
214-
"source": [
215-
"print(\n",
216-
" f\"\"\"\n",
217-
"NOTE: here the ECE completely fails to converge to its true, continuous value.\n",
218-
"This is probably due to the binning-variance, see plots below with 500 bins.\n",
219-
"The sharp peak in weights at the end certainly does not help convergence either.\n",
220-
"\"\"\"\n",
221-
")\n",
222-
"\n",
223-
"eval_stats.plot_reliability_curves([\"top_class\"], n_bins=500, display_weights=True)"
224-
]
225-
},
226216
{
227217
"cell_type": "markdown",
228218
"metadata": {},
@@ -265,7 +255,8 @@
265255
"\n",
266256
"print(f\"Accuracy is {eval_stats.accuracy()}\")\n",
267257
"print(f\"ECE is {eval_stats.expected_calibration_error(n_bins=200)}\")\n",
268-
"eval_stats.plot_reliability_curves([0, \"top_class\"], display_weights=True)"
258+
"eval_stats.plot_reliability_curves([0, \"top_class\"], display_weights=True)\n",
259+
"plt.show()"
269260
]
270261
},
271262
{
@@ -290,7 +281,8 @@
290281
"\n",
291282
"print(f\"Accuracy is {eval_stats.accuracy()}\")\n",
292283
"print(f\"ECE is {eval_stats.expected_calibration_error(n_bins=200)}\")\n",
293-
"eval_stats.plot_reliability_curves([4, \"top_class\"], display_weights=True)"
284+
"eval_stats.plot_reliability_curves([4, \"top_class\"], display_weights=True)\n",
285+
"plt.show()"
294286
]
295287
},
296288
{
@@ -313,7 +305,8 @@
313305
"\n",
314306
"print(f\"Accuracy is {eval_stats.accuracy()}\")\n",
315307
"print(f\"ECE is {eval_stats.expected_calibration_error()}\")\n",
316-
"eval_stats.plot_reliability_curves([4, \"top_class\"], display_weights=True)"
308+
"eval_stats.plot_reliability_curves([4, \"top_class\"], display_weights=True)\n",
309+
"plt.show()"
317310
]
318311
}
319312
],

0 commit comments

Comments
 (0)