77
88
@njit
def primal(alpha, y, X, w, weights):
    """Weighted Lasso primal objective.

    Evaluates ``(r @ r) / (2 * len(y)) + alpha * sum_j |w[j] * weights[j]|``
    with ``r = y - X @ w``.

    Parameters
    ----------
    alpha : float
        Regularization strength multiplying the penalty term.
    y : array
        Target vector.
    X : array
        Design matrix, such that ``X @ w`` can be subtracted from ``y``.
    w : array
        Coefficient vector.
    weights : array
        Per-coefficient penalty weights, indexed like ``w``.

    Returns
    -------
    float
        Objective value at ``w``.
    """
    r = y - X @ w
    p_obj = (r @ r) / (2 * len(y))
    # Sum the penalty over nonzero coefficients only. A zero coefficient
    # paired with an infinite weight would otherwise contribute
    # 0 * inf = nan and poison the objective (and every convergence test
    # that compares against it).
    penalty = 0.
    for j in range(len(w)):
        if w[j] != 0.:
            penalty += abs(w[j] * weights[j])
    return p_obj + alpha * penalty
1414
1515
@njit
def primal_grp(alpha, y, X, w, grp_ptr, grp_indices, weights):
    """Weighted group Lasso primal objective.

    Evaluates ``(r @ r) / (2 * len(y))`` with ``r = y - X @ w`` plus, for
    each group ``g``, ``alpha * ||w_g * weights[g]||_2``.

    Parameters
    ----------
    alpha : float
        Regularization strength multiplying the penalty term.
    y : array
        Target vector.
    X : array
        Design matrix, such that ``X @ w`` can be subtracted from ``y``.
    w : array
        Coefficient vector.
    grp_ptr : array of int
        Group boundaries: group ``g`` owns
        ``grp_indices[grp_ptr[g]:grp_ptr[g + 1]]``.
    grp_indices : array of int
        Indices into ``w``, laid out group after group.
    weights : array
        One penalty weight per group.

    Returns
    -------
    float
        Objective value at ``w``.
    """
    r = y - X @ w
    p_obj = (r @ r) / (2 * len(y))
    for g in range(len(grp_ptr) - 1):
        w_g = w[grp_indices[grp_ptr[g]:grp_ptr[g + 1]]]
        # ||w_g * c||_2 == |c| * ||w_g||_2 for a scalar c. Computing the
        # unweighted norm first lets an all-zero group be skipped, so an
        # infinite group weight does not produce 0 * inf = nan.
        norm_g = norm(w_g, ord=2)
        if norm_g != 0.:
            p_obj += alpha * abs(weights[g]) * norm_g
    return p_obj
2424
2525
26- def gram_lasso (X , y , alpha , max_iter , tol , check_freq = 10 ):
26+ def gram_lasso (X , y , alpha , max_iter , tol , w_init = None , weights = None , check_freq = 10 ):
2727 p_obj_prev = np .inf
2828 n_features = X .shape [1 ]
2929 grads = X .T @ y / len (y )
3030 G = X .T @ X
3131 lipschitz = np .zeros (n_features , dtype = X .dtype )
3232 for j in range (n_features ):
3333 lipschitz [j ] = (X [:, j ] ** 2 ).sum () / len (y )
34- w = np .zeros (n_features )
34+ w = w_init if w_init is not None else np .zeros (n_features )
35+ weights = weights if weights is not None else np .ones (n_features )
3536 # CD
3637 for n_iter in range (max_iter ):
37- cd_epoch (X , G , grads , w , alpha , lipschitz )
38+ cd_epoch (X , G , grads , w , alpha , lipschitz , weights )
3839 if n_iter % check_freq == 0 :
39- p_obj = primal (alpha , y , X , w )
40+ p_obj = primal (alpha , y , X , w , weights )
4041 if p_obj_prev - p_obj < tol :
4142 print ("Convergence reached!" )
4243 break
@@ -45,7 +46,8 @@ def gram_lasso(X, y, alpha, max_iter, tol, check_freq=10):
4546 return w
4647
4748
48- def gram_group_lasso (X , y , alpha , groups , max_iter , tol , check_freq = 50 ):
49+ def gram_group_lasso (X , y , alpha , groups , max_iter , tol , w_init = None , weights = None ,
50+ check_freq = 50 ):
4951 p_obj_prev = np .inf
5052 n_features = X .shape [1 ]
5153 grp_ptr , grp_indices = _grp_converter (groups , X .shape [1 ])
@@ -56,12 +58,13 @@ def gram_group_lasso(X, y, alpha, groups, max_iter, tol, check_freq=50):
5658 for g in range (n_groups ):
5759 X_g = X [:, grp_indices [grp_ptr [g ]:grp_ptr [g + 1 ]]]
5860 lipschitz [g ] = norm (X_g , ord = 2 ) ** 2 / len (y )
59- w = np .zeros (n_features )
61+ w = w_init if w_init is not None else np .zeros (n_features )
62+ weights = weights if weights is not None else np .ones (n_groups )
6063 # BCD
6164 for n_iter in range (max_iter ):
62- bcd_epoch (X , G , grads , w , alpha , lipschitz , grp_indices , grp_ptr )
65+ bcd_epoch (X , G , grads , w , alpha , lipschitz , grp_indices , grp_ptr , weights )
6366 if n_iter % check_freq == 0 :
64- p_obj = primal_grp (alpha , y , X , w , grp_ptr , grp_indices )
67+ p_obj = primal_grp (alpha , y , X , w , grp_ptr , grp_indices , weights )
6568 if p_obj_prev - p_obj < tol :
6669 print ("Convergence reached!" )
6770 break
@@ -71,26 +74,27 @@ def gram_group_lasso(X, y, alpha, groups, max_iter, tol, check_freq=50):
7174
7275
@njit
def cd_epoch(X, G, grads, w, alpha, lipschitz, weights):
    """Run one pass of weighted coordinate descent over all features.

    ``w`` and ``grads`` are modified in place; nothing is returned.

    Parameters
    ----------
    X : array
        Design matrix; only ``X.shape[1]`` and ``len(X)`` are read here.
    G : array
        Square matrix whose row ``j`` is used to refresh ``grads`` after
        coordinate ``j`` moves (the caller builds it as ``X.T @ X``).
    grads : array
        Per-feature quantity kept in sync with ``w``; updated in place.
    w : array
        Coefficient vector; updated in place.
    alpha : float
        Regularization strength.
    lipschitz : array
        Per-feature step-size constants.
    weights : array
        Per-feature penalty weights.
    """
    n_features = X.shape[1]
    for j in range(n_features):
        # A zero constant would divide by zero below; an infinite weight
        # marks a feature that is left untouched.
        if lipschitz[j] == 0. or weights[j] == np.inf:
            continue
        old_w_j = w[j]
        # ST is defined elsewhere in the file -- presumably scalar
        # soft-thresholding; its threshold is scaled by the feature weight.
        w[j] = ST(w[j] + grads[j] / lipschitz[j],
                  alpha / lipschitz[j] * weights[j])
        if old_w_j != w[j]:
            # Only pay for the update when the coefficient actually moved.
            grads += G[j, :] * (old_w_j - w[j]) / len(X)
8386
8487
@njit
def bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr, weights):
    """Run one pass of weighted block coordinate descent over all groups.

    ``w`` and ``grads`` are modified in place; nothing is returned.

    Parameters
    ----------
    X : array
        Design matrix; only ``len(X)`` is read here.
    G : array
        Square matrix whose rows ``idx`` are used to refresh ``grads``
        after group ``g`` moves.
    grads : array
        Per-feature quantity kept in sync with ``w``; updated in place.
    w : array
        Coefficient vector; updated in place.
    alpha : float
        Regularization strength.
    lipschitz : array
        Per-group step-size constants.
    grp_indices : array of int
        Indices into ``w``, laid out group after group.
    grp_ptr : array of int
        Group boundaries: group ``g`` owns
        ``grp_indices[grp_ptr[g]:grp_ptr[g + 1]]``.
    weights : array
        One penalty weight per group.
    """
    n_groups = len(grp_ptr) - 1
    for g in range(n_groups):
        # Skip when EITHER condition holds (as cd_epoch does): a zero
        # constant would divide by zero below, and an infinite weight marks
        # a group that is left untouched.
        if lipschitz[g] == 0. or weights[g] == np.inf:
            continue
        idx = grp_indices[grp_ptr[g]:grp_ptr[g + 1]]
        old_w_g = w[idx].copy()
        # BST is defined elsewhere in the file -- presumably block
        # soft-thresholding; its threshold is scaled by the group weight.
        w[idx] = BST(w[idx] + grads[idx] / lipschitz[g],
                     alpha / lipschitz[g] * weights[g])
        diff = old_w_g - w[idx]
        if np.any(diff != 0.):
            # Only pay for the update when the group actually moved.
            grads += diff @ G[idx, :] / len(X)
0 commit comments