@@ -126,11 +126,11 @@ class CVMatrix:
126126 `weights` are provided and otherwise the row of column-wise sums of
127127 :math:`\mathbf{Y}\odot\mathbf{Y}`.
128128
129- Xw : np.ndarray or None
129+ WX : np.ndarray or None
130130 The total weighted predictor matrix `X` for the entire dataset. This is
131131 :math:`\mathbf{W}\mathbf{X}`.
132132
133- Yw : np.ndarray or None
133+ WY : np.ndarray or None
134134 The total weighted response matrix `Y` for the entire dataset. This is
135135 :math:`\mathbf{W}\mathbf{Y}`. This is computed only if `Y` is not `None`.
136136
@@ -176,8 +176,8 @@ def __init__(
176176 self .sum_Y = None
177177 self .sum_sq_X = None
178178 self .sum_sq_Y = None
179- self .Xw = None
180- self .Yw = None
179+ self .WX = None
180+ self .WY = None
181181 self .weights = None
182182 self .sum_w = None
183183 self .num_nonzero_w = None
@@ -593,39 +593,37 @@ def _compute_training_stats(
593593 )
594594 if return_X_mean or return_X_std :
595595 sum_X_val = np .sum (X_val , axis = 0 , keepdims = True )
596+ sum_X_train = self ._compute_train_mat_sum (sum_X_val , self .sum_X )
596597 X_train_mean = self ._compute_training_mat_mean (
597- sum_X_val ,
598- self .sum_X ,
598+ sum_X_train ,
599599 sum_w_train ,
600600 )
601601 if return_Y_mean or return_Y_std :
602602 sum_Y_val = np .sum (Y_val , axis = 0 , keepdims = True )
603+ sum_Y_train = self ._compute_train_mat_sum (sum_Y_val , self .sum_Y )
603604 Y_train_mean = self ._compute_training_mat_mean (
604- sum_Y_val ,
605- self .sum_Y ,
605+ sum_Y_train ,
606606 sum_w_train ,
607607 )
608608 if return_X_std or return_Y_std :
609609 divisor = self ._compute_std_divisor (sum_w_train , num_nonzero_w_train )
610610 if return_X_std :
611611 X_train_std = self ._compute_training_mat_std (
612- sum_X_val ,
613612 X_val ,
614613 X_val_unweighted ,
615614 X_train_mean ,
616- self .sum_X ,
617615 self .sum_sq_X ,
616+ sum_X_train ,
618617 sum_w_train ,
619618 divisor ,
620619 )
621620 if return_Y_std :
622621 Y_train_std = self ._compute_training_mat_std (
623- sum_Y_val ,
624622 Y_val ,
625623 Y_val_unweighted ,
626624 Y_train_mean ,
627- self .sum_Y ,
628625 self .sum_sq_Y ,
626+ sum_Y_train ,
629627 sum_w_train ,
630628 divisor ,
631629 )
@@ -812,7 +810,7 @@ def _get_val_matrices(
812810 variables `Y_unweighted`. If `return_XTY` is `False`, `Y` and
813811 `Y_unweighted` will be `None`.
814812 """
815- X_val = self .Xw [val_indices ]
813+ X_val = self .WX [val_indices ]
816814 if self .weights is None :
817815 X_val_unweighted = X_val
818816 else :
@@ -824,7 +822,7 @@ def _get_val_matrices(
824822 Y_val = self .Y [val_indices ]
825823 Y_val_unweighted = Y_val
826824 else :
827- Y_val = self .Yw [val_indices ]
825+ Y_val = self .WY [val_indices ]
828826 Y_val_unweighted = self .Y [val_indices ]
829827 else :
830828 Y_val = None
@@ -900,22 +898,28 @@ def _training_kernel_matrix(
900898 return XTmat2_train / mat2_train_std
901899 return XTmat2_train
902900
903- def _compute_training_mat_mean (
901+ def _compute_train_mat_sum (
904902 self ,
905903 sum_mat_val : np .ndarray ,
906904 sum_mat : np .ndarray ,
905+ ) -> np .ndarray :
906+ """
907+ Computes the row vector of column-wise sums of a matrix for a given fold.
908+ """
909+ return sum_mat - sum_mat_val
910+
911+ def _compute_training_mat_mean (
912+ self ,
913+ sum_mat_train : np .ndarray ,
907914 sum_w_train : float ,
908915 ) -> np .ndarray :
909916 """
910917 Computes the row of column-wise means of a matrix for a given fold.
911918
912919 Parameters
913920 ----------
914- sum_mat_val : Array of shape (1, K) or (1, M)
915- The row of column-wise sums of validation set of `Xw` or `Yw`.
916-
917- sum_mat : Array of shape (1, K) or (1, M)
918- The row of column-wise sums of the total `Xw` or `Yw`.
921+ sum_mat_train : Array of shape (1, K) or (1, M)
922+ The row of column-wise sums of the training set of `WX` or `WY`.
919923
920924 sum_w_train : float
921925 The sum of weights in the training set.
@@ -925,7 +929,7 @@ def _compute_training_mat_mean(
925929 Array of shape (1, K) or (1, M)
926930 The row of column-wise means of the training set matrix.
927931 """
928- return ( sum_mat - sum_mat_val ) / sum_w_train
932+ return sum_mat_train / sum_w_train
929933
930934 def _compute_std_divisor (
931935 self , sum_w_train : float , num_nonzero_w_train : int
@@ -956,12 +960,11 @@ def _compute_std_divisor(
956960
957961 def _compute_training_mat_std (
958962 self ,
959- sum_mat_val : np .ndarray ,
960963 mat_val : np .ndarray ,
961964 mat_val_unweighted : np .ndarray ,
962965 mat_train_mean : np .ndarray ,
963- sum_mat : np .ndarray ,
964966 sum_sq_mat : np .ndarray ,
967+ sum_mat_train : np .ndarray ,
965968 sum_w_train : float ,
966969 divisor : float ,
967970 ) -> np .ndarray :
@@ -971,25 +974,22 @@ def _compute_training_mat_std(
971974
972975 Parameters
973976 ----------
974- sum_mat_val : Array of shape (1, K) or (1, M)
975- The row of column-wise sums of validation set of `Xw` or `Yw`.
976-
977977 mat_val : Array of shape (N_val, K) or (N_val, M)
978- The validation set of `Xw ` or `Yw `.
978+ The validation set of `WX ` or `WY `.
979979
980980 mat_val_unweighted : Array of shape (N_val, K) or (N_val, M)
981981 The validation set of `X` or `Y`.
982982
983983 mat_train_mean : Array of shape (1, K) or (1, M)
984984 The row of column-wise weighted means of the training matrix.
985985
986- sum_mat : Array of shape (1, K) or (1, M)
987- The row of column-wise sums of the total weighted matrix.
988-
989986 sum_sq_mat : Array of shape (1, K) or (1, M)
990987 The row of column-wise sums of products between the total weighted matrix
991988 and the total unweighted matrix.
992989
990+ sum_mat_val : Array of shape (1, K) or (1, M)
991+ The row of column-wise sums of validation set of `WX` or `WY`.
992+
993993 sum_w_train : float
994994 The size of the training set.
995995
@@ -1002,12 +1002,11 @@ def _compute_training_mat_std(
10021002 Array of shape (1, K) or (1, M)
10031003 The row of column-wise standard deviations of the training set matrix.
10041004 """
1005- train_sum_mat = sum_mat - sum_mat_val
10061005 train_sum_sq_mat = sum_sq_mat - np .sum (
10071006 mat_val * mat_val_unweighted , axis = 0 , keepdims = True
10081007 )
10091008 mat_train_var = (
1010- - 2 * mat_train_mean * train_sum_mat
1009+ - 2 * mat_train_mean * sum_mat_train
10111010 + sum_w_train * mat_train_mean ** 2
10121011 + train_sum_sq_mat
10131012 ) / divisor
@@ -1047,7 +1046,7 @@ def _init_mats(
10471046 """
10481047 Initializes the matrices `X`, `Y`, and `weights` with the provided
10491048 data. If `Y` is `None`, then `Y` is not initialized. If `weights` is
1050- provided, it initializes the weighted matrices `Xw ` and `Yw `.
1049+ provided, it initializes the weighted matrices `WX ` and `WY `.
10511050
10521051 Parameters
10531052 ----------
@@ -1080,29 +1079,29 @@ def _init_mats(
10801079
10811080 def _init_weighted_mats (self ):
10821081 """
1083- Initializes the weighted matrices `Xw ` and `Yw ` if weights are
1082+ Initializes the weighted matrices `WX ` and `WY ` if weights are
10841083 provided. These matrices are computed as the product of the original matrices
1085- `X` and `Y` with the `weights`. If `Y` is `None`, then `Yw ` is not initialized.
1084+ `X` and `Y` with the `weights`. If `Y` is `None`, then `WY ` is not initialized.
10861085 If `w` is `None`, then this method does nothing.
10871086 """
10881087 if self .weights is None :
1089- self .Xw = self .X
1088+ self .WX = self .X
10901089 if self .Y is not None :
1091- self .Yw = self .Y
1090+ self .WY = self .Y
10921091 else :
1093- self .Xw = self .X * self .weights
1092+ self .WX = self .X * self .weights
10941093 if self .Y is not None and (self .center_X or self .center_Y or self .scale_Y ):
1095- self .Yw = self .Y * self .weights
1094+ self .WY = self .Y * self .weights
10961095
10971096 def _init_matrix_products (self ) -> None :
10981097 """
10991098 Initializes the global matrix products `XTX` and `XTY` for the
11001099 entire dataset. These are :math:`\mathbf{X}^{\mathbf{T}}\mathbf{W}\mathbf{X}`
11011100 and :math:`\mathbf{X}^{\mathbf{T}}\mathbf{W}\mathbf{Y}`, respectively.
11021101 """
1103- self .XTX = self .Xw .T @ self .X
1102+ self .XTX = self .WX .T @ self .X
11041103 if self .Y is not None :
1105- self .XTY = self .Xw .T @ self .Y
1104+ self .XTY = self .WX .T @ self .Y
11061105
11071106 def _init_stats (self ) -> None :
11081107 """
@@ -1116,14 +1115,14 @@ def _init_stats(self) -> None:
11161115 self .sum_w = self .N
11171116 self .num_nonzero_w = self .N
11181117 if self .center_X or self .center_Y or self .scale_X :
1119- self .sum_X = np .sum (self .Xw , axis = 0 , keepdims = True )
1118+ self .sum_X = np .sum (self .WX , axis = 0 , keepdims = True )
11201119 if (self .center_X or self .center_Y or self .scale_Y ) and self .Y is not None :
1121- self .sum_Y = np .sum (self .Yw , axis = 0 , keepdims = True )
1120+ self .sum_Y = np .sum (self .WY , axis = 0 , keepdims = True )
11221121 if self .scale_X :
1123- self .sum_sq_X = np .sum (self .Xw * self .X , axis = 0 , keepdims = True )
1122+ self .sum_sq_X = np .sum (self .WX * self .X , axis = 0 , keepdims = True )
11241123 else :
11251124 self .sum_sq_X = None
11261125 if self .scale_Y and self .Y is not None :
1127- self .sum_sq_Y = np .sum (self .Yw * self .Y , axis = 0 , keepdims = True )
1126+ self .sum_sq_Y = np .sum (self .WY * self .Y , axis = 0 , keepdims = True )
11281127 else :
11291128 self .sum_sq_Y = None
0 commit comments