11{
22 "cells" : [
3+ {
4+ "cell_type" : " code" ,
5+ "execution_count" : null ,
6+ "metadata" : {},
7+ "outputs" : [],
8+ "source" : [
9+ " %load_ext autoreload\n " ,
10+ " %autoreload 2"
11+ ]
12+ },
313 {
414 "cell_type" : " code" ,
515 "execution_count" : null ,
5060 "outputs" : [],
5161 "source" : [
5262 " n_samples = 2000\n " ,
53- " n_classes = 2 "
63+ " n_classes = 3 "
5464 ]
5565 },
5666 {
6878 " random_state=42,\n " ,
6979 " )\n " ,
7080 " X_train, X_test, y_train, y_test = train_test_split(\n " ,
71- " X, y, test_size=0.2 , random_state=42\n " ,
81+ " X, y, test_size=0.5 , random_state=42\n " ,
7282 " )"
7383 ]
7484 },
8595 "metadata" : {},
8696 "outputs" : [],
8797 "source" : [
88- " model = MLPClassifier(hidden_layer_sizes=(50, 50, 50 ))\n " ,
98+ " model = MLPClassifier(hidden_layer_sizes=(20, 20, 10 ))\n " ,
8999 " model.fit(X_train, y_train)"
90100 ]
91101 },
181191 "metadata" : {},
182192 "outputs" : [],
183193 "source" : [
184- " ece = ECE(bins=10 )"
194+ " ece = ECE(bins=12 )"
185195 ]
186196 },
187197 {
191201 "outputs" : [],
192202 "source" : [
193203 " # Evaluate uncalibrated predictions\n " ,
194- " uncalibrated_confidences = model.predict_proba(X_test)\n " ,
204+ " y_pred = model.predict_proba(X_test)\n " ,
195205 " \n " ,
196- " pre_calibration_ece = ece.compute(uncalibrated_confidences , y_test)\n " ,
206+ " pre_calibration_ece = ece.compute(y_pred , y_test)\n " ,
197207 " \n " ,
198208 " f\" ECE before calibration: {pre_calibration_ece}\" "
199209 ]
212222 "metadata" : {},
213223 "outputs" : [],
214224 "source" : [
215- " eval_stats = EvalStats(y_test, uncalibrated_confidences)\n " ,
216- " class_labels = [i for i in range(n_classes)]\n " ,
217- " \n " ,
218- " eval_stats.plot_reliability_curves(class_labels)"
225+ " eval_stats = EvalStats(y_test, y_pred)\n " ,
226+ " class_labels = range(n_classes)"
227+ ]
228+ },
229+ {
230+ "cell_type" : " code" ,
231+ "execution_count" : null ,
232+ "metadata" : {},
233+ "outputs" : [],
234+ "source" : [
235+ " fig = eval_stats.plot_reliability_curves(\n " ,
236+ " [\" top_class\" , 0], display_weights=True, strategy=\" uniform\" , n_bins=8\n " ,
237+ " )"
238+ ]
239+ },
240+ {
241+ "cell_type" : " markdown" ,
242+ "metadata" : {},
243+ "source" : [
244+ " The density of predictions is distributed highly inhomogeneously on the unit interval: some bins have\n " ,
245+ " few members, so the estimate of the reliability in those bins has high variance. This can be mitigated by employing\n " ,
246+ " the \" quantile\" binning strategy, also called adaptive binning."
247+ ]
248+ },
249+ {
250+ "cell_type" : " code" ,
251+ "execution_count" : null ,
252+ "metadata" : {},
253+ "outputs" : [],
254+ "source" : [
255+ " fig = eval_stats.plot_reliability_curves(\n " ,
256+ " [0, \" top_class\" ], display_weights=True, n_bins=8, strategy=\" quantile\"\n " ,
257+ " )"
258+ ]
259+ },
260+ {
261+ "cell_type" : " markdown" ,
262+ "metadata" : {},
263+ "source" : [
264+ " Now all bins have the same weight but different widths. The pointwise reliability estimates\n " ,
265+ " have lower variance, but there are wide gaps between them, thus requiring more interpolation.\n " ,
266+ " Both binning strategies have their advantages and disadvantages."
219267 ]
220268 },
221269 {
455503 "source" : [
456504 " ece.compute(confidences, ground_truth)"
457505 ]
458- },
459- {
460- "cell_type" : " markdown" ,
461- "metadata" : {},
462- "source" : [
463- " Once again, to verify that miscalibration will indeed increase with more samples, let's sample *5x* as many samples as\n " ,
464- " before and measure $ECE$ again:"
465- ]
466- },
467- {
468- "cell_type" : " code" ,
469- "execution_count" : null ,
470- "metadata" : {},
471- "outputs" : [],
472- "source" : [
473- " uncalibrated_samples = shifted_sampler.get_sample_arrays(1000)\n " ,
474- " ground_truth, confidences = uncalibrated_samples\n " ,
475- " \n " ,
476- " ece.compute(confidences, ground_truth)"
477- ]
478- },
479- {
480- "cell_type" : " markdown" ,
481- "metadata" : {},
482- "source" : [
483- " Great! Calibration error goes up as we sample more instances."
484- ]
485506 }
486507 ],
487508 "metadata" : {
488509 "kernelspec" : {
489- "display_name" : " Python 3" ,
510+ "display_name" : " Python 3 (ipykernel) " ,
490511 "language" : " python" ,
491512 "name" : " python3"
492513 },
500521 "name" : " python" ,
501522 "nbconvert_exporter" : " python" ,
502523 "pygments_lexer" : " ipython3" ,
503- "version" : " 3.8.5 "
524+ "version" : " 3.8.13 "
504525 }
505526 },
506527 "nbformat" : 4 ,
507528 "nbformat_minor" : 1
508- }
529+ }
0 commit comments