Skip to content

Commit 39e0170

Browse files
authored
Minor optimization (#4)
1 parent 70b2a0e commit 39e0170

File tree

4 files changed

+21
-55
lines changed

4 files changed

+21
-55
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ Guidelines](https://github.com/Sm00thix/CVMatrix/blob/main/CONTRIBUTING.md).
9696
9797
1. [Engstrøm, O.-C. G. and Jensen, M. H. (2025). Fast partition-based cross-validation with centering and scaling for $\mathbf{X}^\mathbf{T}\mathbf{X}$ and $\mathbf{X}^\mathbf{T}\mathbf{Y}$. *Journal of Chemometrics*, 39(3).](https://doi.org/10.1002/cem.70008)
9898
2. [Dayal, B. S. and MacGregor, J. F. (1997). Improved PLS algorithms. *Journal of Chemometrics*, 11(1), 73-85.](https://doi.org/10.1002/(SICI)1099-128X(199701)11:1%3C73::AID-CEM435%3E3.0.CO;2-%23?)
99-
3. [Engstrøm, O.-C. G. and Dreier, E. S. and Jespersen, B. M. and Pedersen, K. S. IKPLS: Improved Kernel Partial Least Squares and Fast Cross-Validation Algorithms for Python with CPU and GPU Implementations Using NumPy and JAX. *Journal of Open Source Software*, 9(99).](https://doi.org/10.21105/joss.06533)
99+
3. [Engstrøm, O.-C. G. and Dreier, E. S. and Jespersen, B. M. and Pedersen, K. S. (2024). IKPLS: Improved Kernel Partial Least Squares and Fast Cross-Validation Algorithms for Python with CPU and GPU Implementations Using NumPy and JAX. *Journal of Open Source Software*, 9(99).](https://doi.org/10.21105/joss.06533)
100100
101101
## Funding
102102
- Up until May 31st 2025, this work has been carried out as part of an industrial Ph. D. project receiving funding from [FOSS Analytical A/S](https://www.fossanalytics.com/) and [The Innovation Fund Denmark](https://innovationsfonden.dk/en). Grant number 1044-00108B.

cvmatrix/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "2.0.0"
1+
__version__ = "2.0.1"

cvmatrix/cvmatrix.py

Lines changed: 18 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,6 @@ def __init__(
9999
self.N = None
100100
self.K = None
101101
self.M = None
102-
self.X_total_mean = None
103-
self.Y_total_mean = None
104102
self.XTX_total = None
105103
self.XTY_total = None
106104
self.sum_X_total = None
@@ -146,7 +144,6 @@ def fit(
146144
ValueError
147145
If `weights` is provided and contains negative values.
148146
"""
149-
150147
self._init_mats(X, Y, weights)
151148
self._init_weighted_mats()
152149
self._init_matrix_products()
@@ -353,26 +350,19 @@ def _training_matrices(
353350
"The number of non-zero weights in the training set must be "
354351
"greater than zero."
355352
)
356-
sum_w_total_over_sum_w_train = self.sum_w_total / sum_w_train
357-
sum_w_val_over_sum_w_train = sum_w_val / sum_w_train
358-
359353
if self.center_X or self.scale_X or (return_XTY and self.center_Y):
360354
sum_X_val = np.sum(X_val, axis=0, keepdims=True)
361355
X_train_mean = self._compute_training_mat_mean(
362356
sum_X_val,
363-
sum_w_val,
364-
self.X_total_mean,
365-
sum_w_total_over_sum_w_train,
366-
sum_w_val_over_sum_w_train,
357+
self.sum_X_total,
358+
sum_w_train,
367359
)
368360
if return_XTY and (self.center_X or self.center_Y or self.scale_Y):
369361
sum_Y_val = np.sum(Y_val, axis=0, keepdims=True)
370362
Y_train_mean = self._compute_training_mat_mean(
371363
sum_Y_val,
372-
sum_w_val,
373-
self.Y_total_mean,
374-
sum_w_total_over_sum_w_train,
375-
sum_w_val_over_sum_w_train,
364+
self.sum_Y_total,
365+
sum_w_train,
376366
)
377367
if self.scale_X or (self.scale_Y and return_XTY):
378368
divisor = self._compute_std_divisor(sum_w_train, num_nonzero_w_train)
@@ -519,10 +509,8 @@ def _training_kernel_matrix(
519509
def _compute_training_mat_mean(
520510
self,
521511
sum_mat_val: np.ndarray,
522-
sum_w_val: float,
523-
mat_total_mean: np.ndarray,
524-
sum_w_total_over_sum_w_train: float,
525-
sum_w_val_over_sum_w_train: float,
512+
sum_mat_total: np.ndarray,
513+
sum_w_train: float,
526514
) -> np.ndarray:
527515
"""
528516
Computes the row of column-wise means of a matrix for a given fold.
@@ -532,34 +520,18 @@ def _compute_training_mat_mean(
532520
sum_mat_val : Array of shape (1, K) or (1, M)
533521
The row of column-wise sums of validation set of `Xw` or `Yw`.
534522
535-
sum_w_val : float
536-
The sum of weights in the validation set.
537-
538-
mat_total_mean : Array of shape (1, K) or (1, M)
539-
The row of column-wise weighted means of the total matrix.
540-
541-
sum_w_total_over_sum_w_train : float
542-
The ratio of the sum of weights in the entire dataset to the sum of weights
543-
in the training set.
544-
545-
sum_w_val_over_sum_w_train : float
546-
The ratio of the sum of weights in the validation set to the sum of weights
547-
in the training set.
523+
sum_mat_total : Array of shape (1, K) or (1, M)
524+
The row of column-wise sums of the total `Xw` or `Yw`.
548525
549-
sum_w_val : float
550-
The sum of weights in the validation set.
526+
sum_w_train : float
527+
The sum of weights in the training set.
551528
552529
Returns
553530
-------
554531
Array of shape (1, K) or (1, M)
555532
The row of column-wise means of the training set matrix.
556533
"""
557-
train_part_contribution = sum_w_total_over_sum_w_train * mat_total_mean
558-
if sum_w_val <= self.eps:
559-
return train_part_contribution
560-
return train_part_contribution - sum_w_val_over_sum_w_train * (
561-
sum_mat_val / sum_w_val
562-
)
534+
return (sum_mat_total - sum_mat_val) / sum_w_train
563535

564536
def _compute_std_divisor(
565537
self, sum_w_train: float, num_nonzero_w_train: int
@@ -745,24 +717,19 @@ def _init_total_stats(self) -> None:
745717
"""
746718
Initializes the global statistics for `X` and `Y`.
747719
"""
748-
if self.w_total is not None:
749-
self.sum_w_total = np.sum(self.w_total)
750-
self.num_nonzero_w_total = np.count_nonzero(self.w_total)
751-
else:
752-
self.sum_w_total = self.N
753-
self.num_nonzero_w_total = self.N
720+
if self.center_X or self.center_Y or self.scale_X or self.scale_Y:
721+
if self.w_total is not None:
722+
self.sum_w_total = np.sum(self.w_total)
723+
self.num_nonzero_w_total = np.count_nonzero(self.w_total)
724+
else:
725+
self.sum_w_total = self.N
726+
self.num_nonzero_w_total = self.N
754727
if self.center_X or self.center_Y or self.scale_X:
755728
self.sum_X_total = np.sum(self.Xw_total, axis=0, keepdims=True)
756-
self.X_total_mean = self.sum_X_total / self.sum_w_total
757-
else:
758-
self.X_total_mean = None
759729
if (
760730
self.center_X or self.center_Y or self.scale_Y
761731
) and self.Y_total is not None:
762732
self.sum_Y_total = np.sum(self.Yw_total, axis=0, keepdims=True)
763-
self.Y_total_mean = self.sum_Y_total / self.sum_w_total
764-
else:
765-
self.Y_total_mean = None
766733
if self.scale_X:
767734
self.sum_sq_X_total = np.expand_dims(
768735
np.einsum("ij, ij -> j", self.Xw_total, self.X_total), axis=0
@@ -788,7 +755,6 @@ def _init_folds_dict(self, folds: Iterable[Hashable]) -> None:
788755
An iterable defining cross-validation splits. Each unique value in
789756
`folds` corresponds to a different fold.
790757
"""
791-
792758
folds_dict: "defaultdict[Hashable, list[int]]" = defaultdict(list)
793759
for i, num in enumerate(folds):
794760
folds_dict[num].append(i)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "cvmatrix"
3-
version = "2.0.0"
3+
version = "2.0.1"
44
description = "Fast computation of possibly weighted and possibly centered/scaled training set kernel matrices in a cross-validation setting."
55
authors = ["Sm00thix <oleemail@icloud.com>"]
66
maintainers = ["Sm00thix <oleemail@icloud.com>"]

0 commit comments

Comments
 (0)