|
1 | 1 | from time import time |
2 | 2 | import numpy as np |
3 | 3 | from numpy.linalg import norm |
4 | | -from numba import njit |
5 | 4 | from celer import Lasso, GroupLasso |
6 | 5 | from benchopt.datasets.simulated import make_correlated_data |
7 | | -from skglm.utils import BST, ST |
8 | | - |
9 | | - |
10 | | -def _grp_converter(groups, n_features): |
11 | | - if isinstance(groups, int): |
12 | | - grp_size = groups |
13 | | - if n_features % grp_size != 0: |
14 | | - raise ValueError("n_features (%d) is not a multiple of the desired" |
15 | | - " group size (%d)" % (n_features, grp_size)) |
16 | | - n_groups = n_features // grp_size |
17 | | - grp_ptr = grp_size * np.arange(n_groups + 1) |
18 | | - grp_indices = np.arange(n_features) |
19 | | - elif isinstance(groups, list) and isinstance(groups[0], int): |
20 | | - grp_indices = np.arange(n_features).astype(np.int32) |
21 | | - grp_ptr = np.cumsum(np.hstack([[0], groups])) |
22 | | - elif isinstance(groups, list) and isinstance(groups[0], list): |
23 | | - grp_sizes = np.array([len(ls) for ls in groups]) |
24 | | - grp_ptr = np.cumsum(np.hstack([[0], grp_sizes])) |
25 | | - grp_indices = np.array([idx for grp in groups for idx in grp]) |
26 | | - else: |
27 | | - raise ValueError("Unsupported group format.") |
28 | | - return grp_ptr.astype(np.int32), grp_indices.astype(np.int32) |
29 | | - |
30 | | - |
31 | | -@njit |
32 | | -def primal(alpha, y, X, w): |
33 | | - r = y - X @ w |
34 | | - p_obj = (r @ r) / (2 * len(y)) |
35 | | - return p_obj + alpha * np.sum(np.abs(w)) |
36 | | - |
37 | | - |
38 | | -@njit |
39 | | -def primal_grp(alpha, y, X, w, grp_ptr, grp_indices): |
40 | | - r = y - X @ w |
41 | | - p_obj = (r @ r) / (2 * len(y)) |
42 | | - for g in range(len(grp_ptr) - 1): |
43 | | - w_g = w[grp_indices[grp_ptr[g]:grp_ptr[g + 1]]] |
44 | | - p_obj += alpha * norm(w_g, ord=2) |
45 | | - return p_obj |
46 | | - |
47 | | - |
48 | | -@njit |
49 | | -def cd_epoch(X, G, grads, w, alpha, lipschitz): |
50 | | - n_features = X.shape[1] |
51 | | - for j in range(n_features): |
52 | | - if lipschitz[j] == 0.: |
53 | | - continue |
54 | | - old_w_j = w[j] |
55 | | - w[j] = ST(w[j] + grads[j] / lipschitz[j], alpha / lipschitz[j]) |
56 | | - if old_w_j != w[j]: |
57 | | - grads += G[j, :] * (old_w_j - w[j]) / len(X) |
58 | | - |
59 | | - |
60 | | -@njit |
61 | | -def bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr): |
62 | | - n_groups = len(grp_ptr) - 1 |
63 | | - for g in range(n_groups): |
64 | | - if lipschitz[g] == 0.: |
65 | | - continue |
66 | | - idx = grp_indices[grp_ptr[g]:grp_ptr[g + 1]] |
67 | | - old_w_g = w[idx].copy() |
68 | | - w[idx] = BST(w[idx] + grads[idx] / lipschitz[g], alpha / lipschitz[g]) |
69 | | - diff = old_w_g - w[idx] |
70 | | - if np.any(diff != 0.): |
71 | | - grads += diff @ G[idx, :] / len(X) |
72 | | - |
73 | | - |
74 | | -def lasso(X, y, alpha, max_iter, tol, check_freq=10): |
75 | | - p_obj_prev = np.inf |
76 | | - n_features = X.shape[1] |
77 | | - # Initialization |
78 | | - grads = X.T @ y / len(y) |
79 | | - G = X.T @ X |
80 | | - lipschitz = np.zeros(n_features, dtype=X.dtype) |
81 | | - for j in range(n_features): |
82 | | - lipschitz[j] = (X[:, j] ** 2).sum() / len(y) |
83 | | - w = np.zeros(n_features) |
84 | | - # CD |
85 | | - for n_iter in range(max_iter): |
86 | | - cd_epoch(X, G, grads, w, alpha, lipschitz) |
87 | | - if n_iter % check_freq == 0: |
88 | | - p_obj = primal(alpha, y, X, w) |
89 | | - if p_obj_prev - p_obj < tol: |
90 | | - print("Convergence reached!") |
91 | | - break |
92 | | - print(f"iter {n_iter} :: p_obj {p_obj}") |
93 | | - p_obj_prev = p_obj |
94 | | - return w |
95 | | - |
96 | | - |
97 | | -def group_lasso(X, y, alpha, groups, max_iter, tol, check_freq=50): |
98 | | - p_obj_prev = np.inf |
99 | | - n_features = X.shape[1] |
100 | | - grp_ptr, grp_indices = _grp_converter(groups, X.shape[1]) |
101 | | - n_groups = len(grp_ptr) - 1 |
102 | | - # Initialization |
103 | | - grads = X.T @ y / len(y) |
104 | | - G = X.T @ X |
105 | | - lipschitz = np.zeros(n_groups, dtype=X.dtype) |
106 | | - for g in range(n_groups): |
107 | | - X_g = X[:, grp_indices[grp_ptr[g]:grp_ptr[g + 1]]] |
108 | | - lipschitz[g] = norm(X_g, ord=2) ** 2 / len(y) |
109 | | - w = np.zeros(n_features) |
110 | | - # BCD |
111 | | - for n_iter in range(max_iter): |
112 | | - bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr) |
113 | | - if n_iter % check_freq == 0: |
114 | | - p_obj = primal_grp(alpha, y, X, w, grp_ptr, grp_indices) |
115 | | - if p_obj_prev - p_obj < tol: |
116 | | - print("Convergence reached!") |
117 | | - break |
118 | | - print(f"iter {n_iter} :: p_obj {p_obj}") |
119 | | - p_obj_prev = p_obj |
120 | | - return w |
121 | | - |
122 | | - |
123 | | -if __name__ == "__main__": |
124 | | - n_samples, n_features = 1_000_000, 300 |
125 | | - X, y, w_star = make_correlated_data( |
126 | | - n_samples=n_samples, n_features=n_features, random_state=0) |
127 | | - alpha_max = norm(X.T @ y, ord=np.inf) |
128 | | - |
129 | | - # Hyperparameters |
130 | | - max_iter = 1000 |
131 | | - tol = 1e-8 |
132 | | - reg = 0.1 |
133 | | - group_size = 3 |
134 | | - |
135 | | - alpha = alpha_max * reg / n_samples |
136 | | - |
137 | | - # Lasso |
138 | | - print("#" * 15) |
139 | | - print("Lasso") |
140 | | - print("#" * 15) |
141 | | - start = time() |
142 | | - w = lasso(X, y, alpha, max_iter, tol) |
143 | | - gram_lasso_time = time() - start |
144 | | - clf_sk = Lasso(alpha, tol=tol, fit_intercept=False) |
145 | | - start = time() |
146 | | - clf_sk.fit(X, y) |
147 | | - celer_lasso_time = time() - start |
148 | | - np.testing.assert_allclose(w, clf_sk.coef_, rtol=1e-5) |
149 | | - |
150 | | - print("\n") |
151 | | - print("Celer: %.2f" % celer_lasso_time) |
152 | | - print("Gram: %.2f" % gram_lasso_time) |
153 | | - print("\n") |
154 | | - |
155 | | - # Group Lasso |
156 | | - print("#" * 15) |
157 | | - print("Group Lasso") |
158 | | - print("#" * 15) |
159 | | - start = time() |
160 | | - w = group_lasso(X, y, alpha, group_size, max_iter, tol) |
161 | | - gram_group_lasso_time = time() - start |
162 | | - clf_celer = GroupLasso(group_size, alpha, tol=tol, fit_intercept=False) |
163 | | - start = time() |
164 | | - clf_celer.fit(X, y) |
165 | | - celer_group_lasso_time = time() - start |
166 | | - np.testing.assert_allclose(w, clf_celer.coef_, rtol=1e-1) |
167 | | - |
168 | | - print("\n") |
169 | | - print("Celer: %.2f" % celer_group_lasso_time) |
170 | | - print("Gram: %.2f" % gram_group_lasso_time) |
171 | | - print("\n") |
| 6 | +from skglm.solvers.gram import gram_lasso, gram_group_lasso |
| 7 | + |
| 8 | + |
| 9 | +n_samples, n_features = 1_000_000, 300 |
| 10 | +X, y, w_star = make_correlated_data( |
| 11 | + n_samples=n_samples, n_features=n_features, random_state=0) |
| 12 | +alpha_max = norm(X.T @ y, ord=np.inf) |
| 13 | + |
| 14 | +# Hyperparameters |
| 15 | +max_iter = 1000 |
| 16 | +tol = 1e-8 |
| 17 | +reg = 0.1 |
| 18 | +group_size = 3 |
| 19 | + |
| 20 | +alpha = alpha_max * reg / n_samples |
| 21 | + |
| 22 | +# Lasso |
| 23 | +print("#" * 15) |
| 24 | +print("Lasso") |
| 25 | +print("#" * 15) |
| 26 | +start = time() |
| 27 | +w = gram_lasso(X, y, alpha, max_iter, tol) |
| 28 | +gram_lasso_time = time() - start |
| 29 | +clf_sk = Lasso(alpha, tol=tol, fit_intercept=False) |
| 30 | +start = time() |
| 31 | +clf_sk.fit(X, y) |
| 32 | +celer_lasso_time = time() - start |
| 33 | +np.testing.assert_allclose(w, clf_sk.coef_, rtol=1e-5) |
| 34 | + |
| 35 | +print("\n") |
| 36 | +print("Celer: %.2f" % celer_lasso_time) |
| 37 | +print("Gram: %.2f" % gram_lasso_time) |
| 38 | +print("\n") |
| 39 | + |
| 40 | +# Group Lasso |
| 41 | +print("#" * 15) |
| 42 | +print("Group Lasso") |
| 43 | +print("#" * 15) |
| 44 | +start = time() |
| 45 | +w = gram_group_lasso(X, y, alpha, group_size, max_iter, tol) |
| 46 | +gram_group_lasso_time = time() - start |
| 47 | +clf_celer = GroupLasso(group_size, alpha, tol=tol, fit_intercept=False) |
| 48 | +start = time() |
| 49 | +clf_celer.fit(X, y) |
| 50 | +celer_group_lasso_time = time() - start |
| 51 | +np.testing.assert_allclose(w, clf_celer.coef_, rtol=1e-1) |
| 52 | + |
| 53 | +print("\n") |
| 54 | +print("Celer: %.2f" % celer_group_lasso_time) |
| 55 | +print("Gram: %.2f" % gram_group_lasso_time) |
| 56 | +print("\n") |
0 commit comments