1010E-mail: ocge@foss.dk
1111"""
1212
13- from typing import Tuple , Union
13+ from typing import Optional , Tuple , Union
1414
1515import numpy as np
1616from numpy import typing as npt
@@ -185,8 +185,8 @@ def __init__(
185185 def fit (
186186 self ,
187187 X : npt .ArrayLike ,
188- Y : Union [ None , npt .ArrayLike ] = None ,
189- weights : Union [ None , npt .ArrayLike ] = None ,
188+ Y : Optional [ npt .ArrayLike ] = None ,
189+ weights : Optional [ npt .ArrayLike ] = None ,
190190 ) -> None :
191191 """
192192 Loads and stores `X`, `Y`, and "weights", for cross-validation. Computes
@@ -208,6 +208,10 @@ def fit(
208208 Weights for each sample in `X` and `Y`. If `None`, no weights are used in
209209 the computations. If provided, the weights must be non-negative.
210210
211+ Returns
212+ -------
213+ None
214+
211215 Raises
212216 ------
213217 ValueError
@@ -221,7 +225,7 @@ def fit(
221225 def training_XTX (
222226 self , validation_indices : npt .NDArray [np .int_ ]
223227 ) -> Tuple [
224- np .ndarray , Tuple [Union [ None , np .ndarray ], Union [ None , np .ndarray ], None , None ]
228+ np .ndarray , Tuple [Optional [ np .ndarray ], Optional [ np .ndarray ], None , None ]
225229 ]:
226230 """
227231 Computes the training set :math:`\mathbf{X}^{\mathbf{T}}\mathbf{W}\mathbf{X}`
@@ -240,11 +244,11 @@ def training_XTX(
240244
241245 Returns
242246 -------
243- Tuple of two elements. The first element is an array of shape (K, K)
244- corresponding to the training set
245- :math:`\mathbf{X}^{\mathbf{T}}\mathbf{W}\mathbf{X}`. The second element is
246- a tuple containing the row of column-wise weighted means for `X`, the row
247- of column-wise weighted standard deviations for `X`, and two `None`
247+ Tuple of two elements.
248+ The first element is an array of shape (K, K) corresponding to the training
249+ set :math:`\mathbf{X}^{\mathbf{T}}\mathbf{W}\mathbf{X}`. The second element
250+ is a tuple containing the row of column-wise weighted means for `X`, the
251+ row of column-wise weighted standard deviations for `X`, and two `None`
248252 corresponding to the non-computed rows of column-wise weighted means and
249253 standard deviations for `Y`. If a statistic is not computed, it is `None`.
250254
@@ -280,10 +284,10 @@ def training_XTX(
280284 def training_XTY (self , validation_indices : npt .NDArray [np .int_ ]) -> Tuple [
281285 np .ndarray ,
282286 Tuple [
283- Union [ None , np .ndarray ],
284- Union [ None , np .ndarray ],
285- Union [ None , np .ndarray ],
286- Union [ None , np .ndarray ],
287+ Optional [ np .ndarray ],
288+ Optional [ np .ndarray ],
289+ Optional [ np .ndarray ],
290+ Optional [ np .ndarray ],
287291 ],
288292 ]:
289293 """
@@ -305,9 +309,9 @@ def training_XTY(self, validation_indices: npt.NDArray[np.int_]) -> Tuple[
305309
306310 Returns
307311 -------
308- Tuple of two elements. The first element is an array of shape (K, M)
309- corresponding to the training set
310- :math:`\mathbf{X}^{\mathbf{T}}\mathbf{W}\mathbf{Y}`. The second element
312+ Tuple of two elements.
313+ The first element is an array of shape (K, M) corresponding to the training
314+ set :math:`\mathbf{X}^{\mathbf{T}}\mathbf{W}\mathbf{Y}`. The second element
311315 is a tuple containing the row of column-wise weighted means for `X`, the
312316 row of column-wise weighted standard deviations for `X`, the row of
313317 column-wise weighted means for `Y`, and the row of column-wise weighted
@@ -344,10 +348,10 @@ def training_XTY(self, validation_indices: npt.NDArray[np.int_]) -> Tuple[
344348 def training_XTX_XTY (self , validation_indices : npt .NDArray [np .int_ ]) -> Tuple [
345349 Tuple [np .ndarray , np .ndarray ],
346350 Tuple [
347- Union [ None , np .ndarray ],
348- Union [ None , np .ndarray ],
349- Union [ None , np .ndarray ],
350- Union [ None , np .ndarray ],
351+ Optional [ np .ndarray ],
352+ Optional [ np .ndarray ],
353+ Optional [ np .ndarray ],
354+ Optional [ np .ndarray ],
351355 ],
352356 ]:
353357 """
@@ -372,9 +376,9 @@ def training_XTX_XTY(self, validation_indices: npt.NDArray[np.int_]) -> Tuple[
372376
373377 Returns
374378 -------
375- Tuple of two tuples. The first tuple contains arrays of shapes (K, K) and
376- (K, M). These are the training set
377- :math:`\mathbf{X}^{\mathbf{T}}\mathbf{W}\mathbf{X}` and
379+ Tuple of two tuples.
380+ The first tuple contains arrays of shapes (K, K) and (K, M). These are the
381+ training set :math:`\mathbf{X}^{\mathbf{T}}\mathbf{W}\mathbf{X}` and
378382 :math:`\mathbf{X}^{\mathbf{T}}\mathbf{W}\mathbf{Y}`. The second tuple
379383 contains the row of column-wise weighted means for `X`, the row of
380384 column-wise weighted standard deviations for `X`, the row of column-wise
@@ -408,10 +412,10 @@ def training_XTX_XTY(self, validation_indices: npt.NDArray[np.int_]) -> Tuple[
408412 return self ._training_matrices (True , True , validation_indices )
409413
410414 def training_statistics (self , validation_indices : npt .NDArray [np .int_ ]) -> Tuple [
411- Union [ None , np .ndarray ],
412- Union [ None , np .ndarray ],
413- Union [ None , np .ndarray ],
414- Union [ None , np .ndarray ],
415+ Optional [ np .ndarray ],
416+ Optional [ np .ndarray ],
417+ Optional [ np .ndarray ],
418+ Optional [ np .ndarray ],
415419 ]:
416420 """
417421 Computes the row of column-wise weighted means and standard deviations for `X`
@@ -430,7 +434,7 @@ def training_statistics(self, validation_indices: npt.NDArray[np.int_]) -> Tuple
430434
431435 Returns
432436 -------
433- Tuple of four elements of Union[None, np.ndarray]
437+ Tuple of four elements of Optional[ np.ndarray]
434438 A tuple containing the row of column-wise weighted means for `X`, the row
435439 of column-wise weighted standard deviations for `X`, the row of column-wise
436440 weighted means for `Y`, and the row of column-wise weighted standard
@@ -509,20 +513,20 @@ def _get_sum_w_train_and_num_nonzero_w_train(
509513 def _compute_training_stats (
510514 self ,
511515 val_indices : npt .NDArray [np .int_ ],
512- X_val : Union [ None , np .ndarray ],
513- X_val_unweighted : Union [ None , np .ndarray ],
514- Y_val : Union [ None , np .ndarray ],
515- Y_val_unweighted : Union [ None , np .ndarray ],
516+ X_val : Optional [ np .ndarray ],
517+ X_val_unweighted : Optional [ np .ndarray ],
518+ Y_val : Optional [ np .ndarray ],
519+ Y_val_unweighted : Optional [ np .ndarray ],
516520 return_X_mean : bool ,
517521 return_X_std : bool ,
518522 return_Y_mean : bool ,
519523 return_Y_std : bool ,
520524 ) -> Tuple [
521- Union [ None , np .ndarray ],
522- Union [ None , np .ndarray ],
523- Union [ None , np .ndarray ],
524- Union [ None , np .ndarray ],
525- Union [ None , float ],
525+ Optional [ np .ndarray ],
526+ Optional [ np .ndarray ],
527+ Optional [ np .ndarray ],
528+ Optional [ np .ndarray ],
529+ Optional [ float ],
526530 ]:
527531 """
528532 Computes the training set statistics. The training set corresponds
@@ -570,7 +574,7 @@ def _compute_training_stats(
570574
571575 Returns
572576 -------
573- Tuple of Union[None, np.ndarray]
577+ Tuple of Optional[ np.ndarray]
574578 A tuple containing the row of column-wise weighted means for `X`, the row
575579 of column-wise weighted standard deviations for `X`, the row of column-wise
576580 weighted means for `Y`, the row of column-wise weighted standard deviations
@@ -638,10 +642,10 @@ def _training_matrices(
638642 ) -> Tuple [
639643 Union [np .ndarray , Tuple [np .ndarray , np .ndarray ]],
640644 Tuple [
641- Union [ None , np .ndarray ],
642- Union [ None , np .ndarray ],
643- Union [ None , np .ndarray ],
644- Union [ None , np .ndarray ],
645+ Optional [ np .ndarray ],
646+ Optional [ np .ndarray ],
647+ Optional [ np .ndarray ],
648+ Optional [ np .ndarray ],
645649 ],
646650 ]:
647651 """
@@ -669,15 +673,15 @@ def _training_matrices(
669673
670674 Returns
671675 -------
672- Tuple of two elements. The first element is an array of shape (K, K) or (K, M)
673- or a tuple of arrays of shapes (K, K) and (K, M). These are the training
674- set :math:`\mathbf{X}^{\mathbf{T}}\mathbf{W}\mathbf{X}` and/or
675- training set :math:`\mathbf{X}^{\mathbf{T}}\mathbf{W}\mathbf{Y}`. The
676- second element is a tuple containing the row of column-wise weighted means
677- for `X`, the row of column-wise weighted standard deviations for `X`, the
678- row of column-wise weighted means for `Y`, and the row of column-wise
679- weighted standard deviations for `Y`. If a statistic is not computed, it is
680- `None`.
676+ Tuple of two elements.
677+ The first element is an array of shape (K, K) or (K, M) or a tuple of
678+ arrays of shapes (K, K) and (K, M). These are the training set
679+ :math:`\mathbf{X}^{\mathbf{T}}\mathbf{W}\mathbf{X}` and/or training set
680+ :math:`\mathbf{X}^{\mathbf{T}}\mathbf{W}\mathbf{Y}`. The second element is
681+ a tuple containing the row of column-wise weighted means for `X`, the row
682+ of column-wise weighted standard deviations for `X`, the row of column-wise
683+ weighted means for `Y`, and the row of column-wise weighted standard
684+ deviations for `Y`. If a statistic is not computed, it is `None`.
681685
682686 Raises
683687 ------
@@ -787,8 +791,8 @@ def _get_val_matrices(
787791 np .ndarray ,
788792 np .ndarray ,
789793 np .ndarray ,
790- Union [ None , np .ndarray ],
791- Union [ None , np .ndarray ],
794+ Optional [ np .ndarray ],
795+ Optional [ np .ndarray ],
792796 ]:
793797 """
794798 Returns the validation set matrices for a given fold.
@@ -832,11 +836,11 @@ def _training_kernel_matrix(
832836 total_kernel_mat : np .ndarray ,
833837 X_val : np .ndarray ,
834838 mat2_val : np .ndarray ,
835- X_train_mean : Union [ None , np .ndarray ] = None ,
836- mat2_train_mean : Union [ None , np .ndarray ] = None ,
837- X_train_std : Union [ None , np .ndarray ] = None ,
838- mat2_train_std : Union [ None , np .ndarray ] = None ,
839- sum_w_train : Union [ None , float ] = None ,
839+ X_train_mean : Optional [ np .ndarray ] = None ,
840+ mat2_train_mean : Optional [ np .ndarray ] = None ,
841+ X_train_std : Optional [ np .ndarray ] = None ,
842+ mat2_train_std : Optional [ np .ndarray ] = None ,
843+ sum_w_train : Optional [ float ] = None ,
840844 center : bool = False ,
841845 ) -> np .ndarray :
842846 """
@@ -1037,8 +1041,8 @@ def _init_mat(self, mat: np.ndarray) -> np.ndarray:
10371041 def _init_mats (
10381042 self ,
10391043 X : npt .ArrayLike ,
1040- Y : Union [ None , npt .ArrayLike ],
1041- weights : Union [ None , npt .ArrayLike ],
1044+ Y : Optional [ npt .ArrayLike ],
1045+ weights : Optional [ npt .ArrayLike ],
10421046 ) -> None :
10431047 """
10441048 Initializes the matrices `X`, `Y`, and `weights` with the provided
0 commit comments