@@ -120,51 +120,52 @@ def group_lasso(X, y, alpha, groups, max_iter, tol, check_freq=50):
120120 return w
121121
122122
123- n_samples , n_features = 1_000_000 , 300
124- X , y , w_star = make_correlated_data (
125- n_samples = n_samples , n_features = n_features , random_state = 0 )
126- alpha_max = norm (X .T @ y , ord = np .inf )
127-
128- # Hyperparameters
129- max_iter = 1000
130- tol = 1e-8
131- reg = 0.1
132- group_size = 3
133-
134- alpha = alpha_max * reg / n_samples
135-
136- # Lasso
137- print ("#" * 15 )
138- print ("Lasso" )
139- print ("#" * 15 )
140- start = time ()
141- w = lasso (X , y , alpha , max_iter , tol )
142- gram_lasso_time = time () - start
143- clf_sk = Lasso (alpha , tol = tol , fit_intercept = False )
144- start = time ()
145- clf_sk .fit (X , y )
146- celer_lasso_time = time () - start
147- np .testing .assert_allclose (w , clf_sk .coef_ , rtol = 1e-5 )
148-
149- print ("\n " )
150- print ("Celer: %.2f" % celer_lasso_time )
151- print ("Gram: %.2f" % gram_lasso_time )
152- print ("\n " )
153-
154- # Group Lasso
155- print ("#" * 15 )
156- print ("Group Lasso" )
157- print ("#" * 15 )
158- start = time ()
159- w = group_lasso (X , y , alpha , group_size , max_iter , tol )
160- gram_group_lasso_time = time () - start
161- clf_celer = GroupLasso (group_size , alpha , tol = tol , fit_intercept = False )
162- start = time ()
163- clf_celer .fit (X , y )
164- celer_group_lasso_time = time () - start
165- np .testing .assert_allclose (w , clf_celer .coef_ , rtol = 1e-1 )
166-
167- print ("\n " )
168- print ("Celer: %.2f" % celer_group_lasso_time )
169- print ("Gram: %.2f" % gram_group_lasso_time )
170- print ("\n " )
123+ if __name__ == "__main__" :
124+ n_samples , n_features = 1_000_000 , 300
125+ X , y , w_star = make_correlated_data (
126+ n_samples = n_samples , n_features = n_features , random_state = 0 )
127+ alpha_max = norm (X .T @ y , ord = np .inf )
128+
129+ # Hyperparameters
130+ max_iter = 1000
131+ tol = 1e-8
132+ reg = 0.1
133+ group_size = 3
134+
135+ alpha = alpha_max * reg / n_samples
136+
137+ # Lasso
138+ print ("#" * 15 )
139+ print ("Lasso" )
140+ print ("#" * 15 )
141+ start = time ()
142+ w = lasso (X , y , alpha , max_iter , tol )
143+ gram_lasso_time = time () - start
144+ clf_sk = Lasso (alpha , tol = tol , fit_intercept = False )
145+ start = time ()
146+ clf_sk .fit (X , y )
147+ celer_lasso_time = time () - start
148+ np .testing .assert_allclose (w , clf_sk .coef_ , rtol = 1e-5 )
149+
150+ print ("\n " )
151+ print ("Celer: %.2f" % celer_lasso_time )
152+ print ("Gram: %.2f" % gram_lasso_time )
153+ print ("\n " )
154+
155+ # Group Lasso
156+ print ("#" * 15 )
157+ print ("Group Lasso" )
158+ print ("#" * 15 )
159+ start = time ()
160+ w = group_lasso (X , y , alpha , group_size , max_iter , tol )
161+ gram_group_lasso_time = time () - start
162+ clf_celer = GroupLasso (group_size , alpha , tol = tol , fit_intercept = False )
163+ start = time ()
164+ clf_celer .fit (X , y )
165+ celer_group_lasso_time = time () - start
166+ np .testing .assert_allclose (w , clf_celer .coef_ , rtol = 1e-1 )
167+
168+ print ("\n " )
169+ print ("Celer: %.2f" % celer_group_lasso_time )
170+ print ("Gram: %.2f" % gram_group_lasso_time )
171+ print ("\n " )
0 commit comments