diff --git a/doc/model_selection.rst b/doc/model_selection.rst new file mode 100644 index 000000000..31fad3ebb --- /dev/null +++ b/doc/model_selection.rst @@ -0,0 +1,128 @@ +.. _cross_validation: + +================ +Cross validation +================ + +.. currentmodule:: imblearn.model_selection + + +.. _instance_hardness_threshold_cv: + +The term instance hardness is used in literature to express the difficulty to correctly +classify an instance. An instance for which the predicted probability of the true class +is low, has large instance hardness. The way these hard-to-classify instances are +distributed over train and test sets in cross validation, has significant effect on the +test set performance metrics. The :class:`~imblearn.model_selection.InstanceHardnessCV` +splitter distributes samples with large instance hardness equally over the folds, +resulting in more robust cross validation. + +We will discuss instance hardness in this document and explain how to use the +:class:`~imblearn.model_selection.InstanceHardnessCV` splitter. + +Instance hardness and average precision +======================================= + +Instance hardness is defined as 1 minus the probability of the most probable class: + +.. math:: + + H(x) = 1 - P(\hat{y}|x) + +In this equation :math:`H(x)` is the instance hardness for a sample with features +:math:`x` and :math:`P(\hat{y}|x)` the probability of predicted label :math:`\hat{y}` +given the features. If the model predicts label 0 and gives a `predict_proba` output +of [0.9, 0.1], the probability of the most probable class (0) is 0.9 and the +instance hardness is `1-0.9=0.1`. + +Samples with large instance hardness have significant effect on the area under +precision-recall curve, or average precision. Especially samples with label 0 +with large instance hardness (so the model predicts label 1) reduce the average +precision a lot as these points affect the precision-recall curve in the left +where the area is largest; the precision is lowered in the range of low recall +and high thresholds. When doing cross validation, e.g. in case of hyperparameter +tuning or recursive feature elimination, random gathering of these points in +some folds introduce variance in CV results that deteriorates robustness of the +cross validation task. The :class:`~imblearn.model_selection.InstanceHardnessCV` +splitter aims to distribute the samples with large instance hardness over the +folds in order to reduce undesired variance. Note that one should use this +splitter to make model *selection* tasks robust like hyperparameter tuning and +feature selection but not for model *performance estimation* for which you also +want to know the variance of performance to be expected in production. + + +Create imbalanced dataset with samples with large instance hardness +=================================================================== + +Let's start by creating a dataset to work with. We create a dataset with 5% class +imbalance using scikit-learn's :func:`~sklearn.datasets.make_blobs` function. + + >>> import numpy as np + >>> from matplotlib import pyplot as plt + >>> from sklearn.datasets import make_blobs + >>> from imblearn.datasets import make_imbalance + >>> random_state = 10 + >>> X, y = make_blobs(n_samples=[950, 50], centers=((-3, 0), (3, 0)), + ... random_state=random_state) + >>> plt.scatter(X[:, 0], X[:, 1], c=y) + >>> plt.show() + +.. image:: ./auto_examples/model_selection/images/sphx_glr_plot_instance_hardness_cv_001.png + :target: ./auto_examples/model_selection/plot_instance_hardness_cv.html + :align: center + +Now we add some samples with large instance hardness + + >>> X_hard, y_hard = make_blobs(n_samples=10, centers=((3, 0), (-3, 0)), + ... cluster_std=1, + ... random_state=random_state) + >>> X = np.vstack((X, X_hard)) + >>> y = np.hstack((y, y_hard)) + >>> plt.scatter(X[:, 0], X[:, 1], c=y) + >>> plt.show() + +.. image:: ./auto_examples/model_selection/images/sphx_glr_plot_instance_hardness_cv_002.png + :target: ./auto_examples/model_selection/plot_instance_hardness_cv.html + :align: center + +Assess cross validation performance variance using `InstanceHardnessCV` splitter +================================================================================ + +Then we take a :class:`~sklearn.linear_model.LogisticRegression` and assess the +cross validation performance using a :class:`~sklearn.model_selection.StratifiedKFold` +cv splitter and the :func:`~sklearn.model_selection.cross_validate` function. + + >>> from sklearn.ensemble import LogisticRegressionClassifier + >>> clf = LogisticRegressionClassifier(random_state=random_state) + >>> skf_cv = StratifiedKFold(n_splits=5, shuffle=True, + ... random_state=random_state) + >>> skf_result = cross_validate(clf, X, y, cv=skf_cv, scoring="average_precision") + +Now, we do the same using an :class:`~imblearn.model_selection.InstanceHardnessCV` +splitter. We use provide our classifier to the splitter to calculate instance hardness +and distribute samples with large instance hardness equally over the folds. + + >>> ih_cv = InstanceHardnessCV(estimator=clf, n_splits=5, + ... random_state=random_state) + >>> ih_result = cross_validate(clf, X, y, cv=ih_cv, scoring="average_precision") + +When we plot the test scores for both cv splitters, we see that the variance using the +:class:`~imblearn.model_selection.InstanceHardnessCV` splitter is lower than for the +:class:`~sklearn.model_selection.StratifiedKFold` splitter. + + >>> plt.boxplot([skf_result['test_score'], ih_result['test_score']], + ... tick_labels=["StratifiedKFold", "InstanceHardnessCV"], + ... vert=False) + >>> plt.xlabel('Average precision') + >>> plt.tight_layout() + +.. image:: ./auto_examples/model_selection/images/sphx_glr_plot_instance_hardness_cv_003.png + :target: ./auto_examples/model_selection/plot_instance_hardness_cv.html + :align: center + +Be aware that the most important part of cross-validation splitters is to simulate the +conditions that one will encounter in production. Therefore, if it is likely to get +difficult samples in production, one should use a cross-validation splitter that +emulates this situation. In our case, the +:class:`~sklearn.model_selection.StratifiedKFold` splitter did not allow to distribute +the difficult samples over the folds and thus it was likely a problem for our use case. diff --git a/doc/references/index.rst b/doc/references/index.rst index f5fe3bf53..6b22f63a9 100644 --- a/doc/references/index.rst +++ b/doc/references/index.rst @@ -18,5 +18,6 @@ This is the full API documentation of the `imbalanced-learn` toolbox. miscellaneous pipeline metrics + model_selection datasets utils diff --git a/doc/references/model_selection.rst b/doc/references/model_selection.rst new file mode 100644 index 000000000..713781dd7 --- /dev/null +++ b/doc/references/model_selection.rst @@ -0,0 +1,23 @@ +.. _model_selection_ref: + +Model selection methods +======================= + +.. automodule:: imblearn.model_selection + :no-members: + :no-inherited-members: + +Cross-validation splitters +-------------------------- + +.. automodule:: imblearn.model_selection._split + :no-members: + :no-inherited-members: + +.. currentmodule:: imblearn.model_selection + +.. autosummary:: + :toctree: generated/ + :template: class.rst + + InstanceHardnessCV diff --git a/doc/user_guide.rst b/doc/user_guide.rst index bfa8c00f9..9db06ca22 100644 --- a/doc/user_guide.rst +++ b/doc/user_guide.rst @@ -19,6 +19,7 @@ User Guide ensemble.rst miscellaneous.rst metrics.rst + model_selection.rst common_pitfalls.rst Dataset loading utilities developers_utils.rst diff --git a/doc/whats_new/0.14.rst b/doc/whats_new/0.14.rst index 2db60138e..2afdbf9bb 100644 --- a/doc/whats_new/0.14.rst +++ b/doc/whats_new/0.14.rst @@ -14,6 +14,10 @@ Bug fixes Enhancements ............ +- Add :class:`~imblearn.model_selection.InstanceHardnessCV` to split data and ensure + that samples are distributed in folds based on their instance hardness. + :pr:`1125` by :user:`Frits Hermans `. + Compatibility ............. diff --git a/examples/model_selection/plot_instance_hardness_cv.py b/examples/model_selection/plot_instance_hardness_cv.py new file mode 100644 index 000000000..2990bd6b2 --- /dev/null +++ b/examples/model_selection/plot_instance_hardness_cv.py @@ -0,0 +1,97 @@ +""" +==================================================== +Distribute hard-to-classify datapoints over CV folds +==================================================== + +'Instance hardness' refers to the difficulty to classify an instance. The way +hard-to-classify instances are distributed over train and test sets has +significant effect on the test set performance metrics. In this example we +show how to deal with this problem. We are making the comparison with normal +:class:`~sklearn.model_selection.StratifiedKFold` cross-validation splitter. +""" + +# Authors: Frits Hermans, https://fritshermans.github.io +# License: MIT + +# %% +print(__doc__) + +# %% +# Create an imbalanced dataset with instance hardness +# --------------------------------------------------- +# +# We create an imbalanced dataset with using scikit-learn's +# :func:`~sklearn.datasets.make_blobs` function and set the class imbalance ratio to +# 5%. +import numpy as np +from matplotlib import pyplot as plt +from sklearn.datasets import make_blobs + +X, y = make_blobs(n_samples=[950, 50], centers=((-3, 0), (3, 0)), random_state=10) +plt.scatter(X[:, 0], X[:, 1], c=y) + +# %% +# To introduce instance hardness in our dataset, we add some hard to classify samples: +X_hard, y_hard = make_blobs( + n_samples=10, centers=((3, 0), (-3, 0)), cluster_std=1, random_state=10 +) +X, y = np.vstack((X, X_hard)), np.hstack((y, y_hard)) +plt.scatter(X[:, 0], X[:, 1], c=y) + +# %% +# Compare cross validation scores using `StratifiedKFold` and `InstanceHardnessCV` +# -------------------------------------------------------------------------------- +# +# Now, we want to assess a linear predictive model. Therefore, we should use +# cross-validation. The most important concept with cross-validation is to create +# training and test splits that are representative of the the data in production to have +# statistical results that one can expect in production. +# +# By applying a standard :class:`~sklearn.model_selection.StratifiedKFold` +# cross-validation splitter, we do not control in which fold the hard-to-classify +# samples will be. +# +# The :class:`~imblearn.model_selection.InstanceHardnessCV` splitter allows to +# control the distribution of the hard-to-classify samples over the folds. +# +# Let's make an experiment to compare the results that we get with both splitters. +# We use a :class:`~sklearn.linear_model.LogisticRegression` classifier and +# :func:`~sklearn.model_selection.cross_validate` to calculate the cross validation +# scores. We use average precision for scoring. +import pandas as pd +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import StratifiedKFold, cross_validate + +from imblearn.model_selection import InstanceHardnessCV + +logistic_regression = LogisticRegression() + +results = {} +for cv in ( + StratifiedKFold(n_splits=5, shuffle=True, random_state=10), + InstanceHardnessCV(estimator=LogisticRegression(), n_splits=5, random_state=10), +): + result = cross_validate( + logistic_regression, + X, + y, + cv=cv, + scoring="average_precision", + ) + results[cv.__class__.__name__] = result["test_score"] +results = pd.DataFrame(results) + +# %% +ax = results.plot.box(vert=False, whis=[0, 100]) +ax.set( + xlabel="Average precision", + title="Cross validation scores with different splitters", + xlim=(0, 1), +) + +# %% +# The boxplot shows that the :class:`~imblearn.model_selection.InstanceHardnessCV` +# splitter results in less variation of average precision than +# :class:`~sklearn.model_selection.StratifiedKFold` splitter. When doing +# hyperparameter tuning or feature selection using a wrapper method (like +# :class:`~sklearn.feature_selection.RFECV`) this will give more stable results. diff --git a/imblearn/__init__.py b/imblearn/__init__.py index 8a8e7ee2d..df05d4b34 100644 --- a/imblearn/__init__.py +++ b/imblearn/__init__.py @@ -11,7 +11,7 @@ Module which provides methods generating an ensemble of under-sampled subsets. exceptions - Module including custom warnings and error clases used across + Module including custom warnings and error classes used across imbalanced-learn. keras Module which provides custom generator, layers for deep learning using @@ -19,6 +19,8 @@ metrics Module which provides metrics to quantified the classification performance with imbalanced dataset. +model_selection + Module which provides methods to split the dataset into training and test sets. over_sampling Module which provides methods to over-sample a dataset. tensorflow @@ -54,6 +56,7 @@ ensemble, exceptions, metrics, + model_selection, over_sampling, pipeline, tensorflow, @@ -113,6 +116,7 @@ def __dir__(self): "exceptions", "keras", "metrics", + "model_selection", "over_sampling", "tensorflow", "under_sampling", diff --git a/imblearn/model_selection/__init__.py b/imblearn/model_selection/__init__.py new file mode 100644 index 000000000..aa47b21a3 --- /dev/null +++ b/imblearn/model_selection/__init__.py @@ -0,0 +1,8 @@ +""" +The :mod:`imblearn.model_selection` provides methods to split the dataset into +training and test sets. +""" + +from ._split import InstanceHardnessCV + +__all__ = ["InstanceHardnessCV"] diff --git a/imblearn/model_selection/_split.py b/imblearn/model_selection/_split.py new file mode 100644 index 000000000..fc28cbcb4 --- /dev/null +++ b/imblearn/model_selection/_split.py @@ -0,0 +1,122 @@ +import warnings + +import numpy as np +from sklearn.base import clone +from sklearn.model_selection import LeaveOneGroupOut, cross_val_predict +from sklearn.model_selection._split import BaseCrossValidator +from sklearn.utils.multiclass import type_of_target +from sklearn.utils.validation import _num_samples + + +class InstanceHardnessCV(BaseCrossValidator): + """Instance-hardness cross-validation splitter. + + Cross-validation splitter that distributes samples with large instance hardness + equally over the folds. The instance hardness is internally estimated by using + `estimator` and stratified cross-validation. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + estimator : estimator object + Classifier to be used to estimate instance hardness of the samples. + This classifier should implement `predict_proba`. + + n_splits : int, default=5 + Number of folds. Must be at least 2. + + pos_label : int, float, bool or str, default=None + The class considered the positive class when selecting the probability + representing the instance hardness. If None, the positive class is + automatically inferred by the estimator as `estimator.classes_[1]`. + + Examples + -------- + >>> from imblearn.model_selection import InstanceHardnessCV + >>> from sklearn.datasets import make_classification + >>> from sklearn.model_selection import cross_validate + >>> from sklearn.linear_model import LogisticRegression + >>> X, y = make_classification(weights=[0.9, 0.1], class_sep=2, + ... n_informative=3, n_redundant=1, flip_y=0.05, n_samples=1000, random_state=10) + >>> estimator = LogisticRegression() + >>> ih_cv = InstanceHardnessCV(estimator) + >>> cv_result = cross_validate(estimator, X, y, cv=ih_cv) + >>> print(f"Standard deviation of test_scores: {cv_result['test_score'].std():.3f}") + Standard deviation of test_scores: 0.00... + """ + + def __init__(self, estimator, *, n_splits=5, pos_label=None): + self.estimator = estimator + self.n_splits = n_splits + self.pos_label = pos_label + + def split(self, X, y, groups=None): + """Generate indices to split data into training and test set. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + Training data, where `n_samples` is the number of samples + and `n_features` is the number of features. + + y : array-like of shape (n_samples,) + The target variable for supervised learning problems. + + groups : object + Always ignored, exists for compatibility. + + Yields + ------ + train : ndarray + The training set indices for that split. + + test : ndarray + The testing set indices for that split. + """ + if groups is not None: + warnings.warn( + f"The groups parameter is ignored by {self.__class__.__name__}", + UserWarning, + ) + + classes = np.unique(y) + y_type = type_of_target(y) + if y_type != "binary": + raise ValueError("InstanceHardnessCV only supports binary classification.") + if self.pos_label is None: + pos_label = 1 + else: + pos_label = np.flatnonzero(classes == self.pos_label)[0] + + y_proba = cross_val_predict( + clone(self.estimator), X, y, cv=self.n_splits, method="predict_proba" + ) + # sorting first on y and then by the instance hardness + sorted_indices = np.lexsort((y_proba[:, pos_label], y)) + groups = np.empty(_num_samples(X), dtype=int) + groups[sorted_indices] = np.arange(_num_samples(X)) % self.n_splits + cv = LeaveOneGroupOut() + for train_index, test_index in cv.split(X, y, groups): + yield train_index, test_index + + def get_n_splits(self, X=None, y=None, groups=None): + """Returns the number of splitting iterations in the cross-validator. + + Parameters + ---------- + X: object + Always ignored, exists for compatibility. + + y: object + Always ignored, exists for compatibility. + + groups: object + Always ignored, exists for compatibility. + + Returns + ------- + n_splits: int + Returns the number of splitting iterations in the cross-validator. + """ + return self.n_splits diff --git a/imblearn/model_selection/tests/__init__.py b/imblearn/model_selection/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/imblearn/model_selection/tests/test_split.py b/imblearn/model_selection/tests/test_split.py new file mode 100644 index 000000000..1b1d94e91 --- /dev/null +++ b/imblearn/model_selection/tests/test_split.py @@ -0,0 +1,99 @@ +import numpy as np +import pytest +from sklearn.datasets import make_classification +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import make_scorer, precision_score +from sklearn.model_selection import cross_validate +from sklearn.utils._testing import assert_allclose + +from imblearn.model_selection import InstanceHardnessCV + + +@pytest.fixture +def data(): + return make_classification( + weights=[0.5, 0.5], + class_sep=0.5, + n_informative=3, + n_redundant=1, + flip_y=0.05, + n_samples=50, + random_state=10, + ) + + +def test_groups_parameter_warning(data): + """Test that a warning is raised when groups parameter is provided.""" + X, y = data + ih_cv = InstanceHardnessCV(estimator=LogisticRegression(), n_splits=3) + + warning_msg = "The groups parameter is ignored by InstanceHardnessCV" + with pytest.warns(UserWarning, match=warning_msg): + list(ih_cv.split(X, y, groups=np.ones_like(y))) + + +def test_error_on_multiclass(): + """Test that an error is raised when the target is not binary.""" + X, y = make_classification(n_classes=3, n_clusters_per_class=1) + err_msg = "InstanceHardnessCV only supports binary classification." + with pytest.raises(ValueError, match=err_msg): + next(InstanceHardnessCV(estimator=LogisticRegression()).split(X, y)) + + +def test_default_params(data): + """Test that the default parameters are used.""" + X, y = data + ih_cv = InstanceHardnessCV(estimator=LogisticRegression(), n_splits=3) + cv_result = cross_validate( + LogisticRegression(), X, y, cv=ih_cv, scoring="precision" + ) + assert_allclose(cv_result["test_score"], [0.625, 0.6, 0.625], atol=1e-6, rtol=1e-6) + + +@pytest.mark.parametrize("dtype_target", [None, object]) +def test_target_string_labels(data, dtype_target): + """Test that the target can be a string array.""" + X, y = data + labels = np.array(["a", "b"], dtype=dtype_target) + y = labels[y] + ih_cv = InstanceHardnessCV(estimator=LogisticRegression(), n_splits=3) + cv_result = cross_validate( + LogisticRegression(), + X, + y, + cv=ih_cv, + scoring=make_scorer(precision_score, pos_label="b"), + ) + assert_allclose(cv_result["test_score"], [0.625, 0.6, 0.625], atol=1e-6, rtol=1e-6) + + +@pytest.mark.parametrize("dtype_target", [None, object]) +def test_target_string_pos_label(data, dtype_target): + """Test that the `pos_label` parameter can be used to select the positive class. + + Here, changing the `pos_label` will change the instance hardness and thus the + `cv_result`. + """ + X, y = data + labels = np.array(["a", "b"], dtype=dtype_target) + y = labels[y] + ih_cv = InstanceHardnessCV( + estimator=LogisticRegression(), pos_label="a", n_splits=3 + ) + cv_result = cross_validate( + LogisticRegression(), + X, + y, + cv=ih_cv, + scoring=make_scorer(precision_score, pos_label="a"), + ) + assert_allclose( + cv_result["test_score"], [0.666667, 0.666667, 0.4], atol=1e-6, rtol=1e-6 + ) + + +@pytest.mark.parametrize("n_splits", [2, 3, 4]) +def test_n_splits(n_splits): + """Test that the number of splits is correctly set.""" + ih_cv = InstanceHardnessCV(estimator=LogisticRegression(), n_splits=n_splits) + assert ih_cv.get_n_splits() == n_splits