@@ -27,7 +27,8 @@ class CVMatrix:
2727 ----------
2828 folds : Iterable of Hashable with N elements
2929 An iterable defining cross-validation splits. Each unique value in
30- `folds` corresponds to a different fold.
30+ `folds` corresponds to a different fold. The validation indices for each fold
31+ can be accessed using the `get_validation_indices` method.
3132
3233 center_X : bool, optional, default=True
3334 Whether to center `X` before computation of
@@ -348,7 +349,7 @@ def training_statistics(self, fold: Hashable) -> Tuple[
348349 If `fold` was not provided as a cross-validation split in the
349350 `folds` parameter of the constructor.
350351 """
351- val_indices = self ._get_val_indices (fold )
352+ val_indices = self .get_validation_indices (fold )
352353 X_val , X_val_unweighted , Y_val , Y_val_unweighted = self ._get_val_matrices (
353354 val_indices = val_indices , return_XTY = self .Y_total is not None
354355 )
@@ -366,6 +367,32 @@ def training_statistics(self, fold: Hashable) -> Tuple[
366367 :- 1
367368 ] # Exclude the sum of training weights from the return tuple
368369
370+ def get_validation_indices (self , fold : Hashable ) -> npt .NDArray [np .int_ ]:
371+ """
372+ Returns the indices of the validation set samples for a given fold.
373+
374+ Parameters
375+ ----------
376+ fold : Hashable
377+ The fold for which to return the validation set indices.
378+
379+ Returns
380+ -------
381+ Array of shape (N_val,)
382+ The indices of the validation set samples for the given fold.
383+
384+ Raises
385+ ------
386+ ValueError
387+ If `fold` was not provided as a cross-validation split in the
388+ `folds` parameter of the constructor.
389+ """
390+ try :
391+ val_indices = self .folds_dict [fold ]
392+ except KeyError as e :
393+ raise ValueError (f"Fold { fold } not found." ) from e
394+ return val_indices
395+
369396 def _get_sum_w_train_and_num_nonzero_w_train (
370397 self , val_indices : npt .NDArray [np .int_ ]
371398 ) -> Tuple [float , float ]:
@@ -589,7 +616,7 @@ def _training_matrices(
589616 )
590617 if return_XTY and self .Y_total is None :
591618 raise ValueError ("Response variables `Y` are not provided." )
592- val_indices = self ._get_val_indices (fold )
619+ val_indices = self .get_validation_indices (fold )
593620 X_val , X_val_unweighted , Y_val , Y_val_unweighted = self ._get_val_matrices (
594621 val_indices = val_indices , return_XTY = return_XTY
595622 )
@@ -680,24 +707,6 @@ def _training_matrices(
680707 stats_tuple ,
681708 )
682709
683- def _get_val_indices (self , fold : Hashable ) -> npt .NDArray [np .int_ ]:
684- """
685- Returns the indices of the validation set samples for a given fold.
686- Parameters
687- ----------
688- fold : Hashable
689- The fold for which to return the validation set indices.
690- Returns
691- -------
692- Array of shape (N_val,)
693- The indices of the validation set samples for the given fold.
694- """
695- try :
696- val_indices = self .folds_dict [fold ]
697- except KeyError as e :
698- raise ValueError (f"Fold { fold } not found." ) from e
699- return val_indices
700-
701710 def _get_val_matrices (
702711 self , val_indices : npt .NDArray [np .int_ ], return_XTY : bool
703712 ) -> Tuple [
0 commit comments