Skip to content

Commit 5858392

Browse files
authored
Exposed get_validation_indices and added an LOOCV test case (#8)
Co-authored-by: Ole-Christian Galbo Engstrøm <ocge@foss.dk>
1 parent 9aae748 commit 5858392

File tree

4 files changed

+75
-23
lines changed

4 files changed

+75
-23
lines changed

cvmatrix/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
1-
__version__ = "2.1.0.post1"
1+
__version__ = "2.1.1"

cvmatrix/cvmatrix.py

Lines changed: 30 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,8 @@ class CVMatrix:
2727
----------
2828
folds : Iterable of Hashable with N elements
2929
An iterable defining cross-validation splits. Each unique value in
30-
`folds` corresponds to a different fold.
30+
`folds` corresponds to a different fold. The validation indices for each fold
31+
can be accessed using the `get_validation_indices` method.
3132
3233
center_X : bool, optional, default=True
3334
Whether to center `X` before computation of
@@ -348,7 +349,7 @@ def training_statistics(self, fold: Hashable) -> Tuple[
348349
If `fold` was not provided as a cross-validation split in the
349350
`folds` parameter of the constructor.
350351
"""
351-
val_indices = self._get_val_indices(fold)
352+
val_indices = self.get_validation_indices(fold)
352353
X_val, X_val_unweighted, Y_val, Y_val_unweighted = self._get_val_matrices(
353354
val_indices=val_indices, return_XTY=self.Y_total is not None
354355
)
@@ -366,6 +367,32 @@ def training_statistics(self, fold: Hashable) -> Tuple[
366367
:-1
367368
] # Exclude the sum of training weights from the return tuple
368369

370+
def get_validation_indices(self, fold: Hashable) -> npt.NDArray[np.int_]:
371+
"""
372+
Returns the indices of the validation set samples for a given fold.
373+
374+
Parameters
375+
----------
376+
fold : Hashable
377+
The fold for which to return the validation set indices.
378+
379+
Returns
380+
-------
381+
Array of shape (N_val,)
382+
The indices of the validation set samples for the given fold.
383+
384+
Raises
385+
------
386+
ValueError
387+
If `fold` was not provided as a cross-validation split in the
388+
`folds` parameter of the constructor.
389+
"""
390+
try:
391+
val_indices = self.folds_dict[fold]
392+
except KeyError as e:
393+
raise ValueError(f"Fold {fold} not found.") from e
394+
return val_indices
395+
369396
def _get_sum_w_train_and_num_nonzero_w_train(
370397
self, val_indices: npt.NDArray[np.int_]
371398
) -> Tuple[float, float]:
@@ -589,7 +616,7 @@ def _training_matrices(
589616
)
590617
if return_XTY and self.Y_total is None:
591618
raise ValueError("Response variables `Y` are not provided.")
592-
val_indices = self._get_val_indices(fold)
619+
val_indices = self.get_validation_indices(fold)
593620
X_val, X_val_unweighted, Y_val, Y_val_unweighted = self._get_val_matrices(
594621
val_indices=val_indices, return_XTY=return_XTY
595622
)
@@ -680,24 +707,6 @@ def _training_matrices(
680707
stats_tuple,
681708
)
682709

683-
def _get_val_indices(self, fold: Hashable) -> npt.NDArray[np.int_]:
684-
"""
685-
Returns the indices of the validation set samples for a given fold.
686-
Parameters
687-
----------
688-
fold : Hashable
689-
The fold for which to return the validation set indices.
690-
Returns
691-
-------
692-
Array of shape (N_val,)
693-
The indices of the validation set samples for the given fold.
694-
"""
695-
try:
696-
val_indices = self.folds_dict[fold]
697-
except KeyError as e:
698-
raise ValueError(f"Fold {fold} not found.") from e
699-
return val_indices
700-
701710
def _get_val_matrices(
702711
self, val_indices: npt.NDArray[np.int_], return_XTY: bool
703712
) -> Tuple[

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "cvmatrix"
3-
version = "2.1.0.post1"
3+
version = "2.1.1"
44
description = "Fast computation of possibly weighted and possibly centered/scaled training set kernel matrices in a cross-validation setting."
55
authors = ["Sm00thix <oleemail@icloud.com>"]
66
maintainers = ["Sm00thix <oleemail@icloud.com>"]

tests/test_cvmatrix.py

Lines changed: 43 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1139,3 +1139,46 @@ def test_statistics_cvmatrix_methods(self):
11391139
err_msg="Statistics from training_statistics and "
11401140
"training_XTX methods are not equivalent." + diagnostic_msg,
11411141
)
1142+
1143+
def test_loocv(self):
1144+
"""
1145+
Tests if the matrices computed by the NaiveCVMatrix and CVMatrix models are
1146+
equivalent when using Leave-One-Out Cross-Validation (LOOCV).
1147+
"""
1148+
X = self.load_X()[:, :5]
1149+
Ys = [None, self.load_Y(["Protein", "Moisture"])]
1150+
folds = np.arange(X.shape[0])
1151+
center_Xs = [False, True]
1152+
center_Ys = [False, True]
1153+
scale_Xs = [False, True]
1154+
scale_Ys = [False, True]
1155+
ddofs = [0, 1]
1156+
use_weights = [False, True]
1157+
for center_X, center_Y, scale_X, scale_Y, use_w, ddof, Y in product(
1158+
center_Xs, center_Ys, scale_Xs, scale_Ys, use_weights, ddofs, Ys
1159+
):
1160+
diagnostic_msg = (
1161+
f"center_X: {center_X}, center_Y: {center_Y}, "
1162+
f"scale_X: {scale_X}, scale_Y: {scale_Y}, "
1163+
f"ddof: {ddof}, use_weights: {use_w}, use_Y: {Y is not None}"
1164+
)
1165+
if use_w:
1166+
weights = self.randomly_zero_weights(self.load_weights(random=True))
1167+
else:
1168+
weights = None
1169+
naive, fast = self.fit_models(
1170+
X,
1171+
Y,
1172+
weights,
1173+
folds,
1174+
center_X,
1175+
center_Y,
1176+
scale_X,
1177+
scale_Y,
1178+
ddof,
1179+
np.float64,
1180+
)
1181+
print(diagnostic_msg)
1182+
# Extract 20 unique folds from the folds array.
1183+
subset_folds = np.random.choice(np.unique(folds), size=20, replace=False)
1184+
self.check_equivalent_matrices(naive, fast, subset_folds)

0 commit comments

Comments
 (0)