# Installed scikit-learn version in parsed form (presumably compared against
# elsewhere in this module to gate version-specific behaviour -- confirm).
sklearn_version = parse_version(sklearn.__version__)
6060
6161
def sample_dataset_generator(n_samples=1000):
    """Build the imbalanced 3-class toy dataset shared by the sampler checks.

    Despite its name, this is a plain function and not a Python generator:
    it returns the dataset directly. The data come from
    ``make_classification`` with a fixed ``random_state`` so that every
    check calling this helper works on the same samples.

    Parameters
    ----------
    n_samples : int, default=1000
        Number of samples to draw. The default reproduces the dataset used
        by the checks that call this helper without arguments.

    Returns
    -------
    X : array-like
        The generated feature matrix, as returned by ``make_classification``.
    y : array-like
        The generated class labels for the 3 classes, drawn with class
        weights ``[0.2, 0.3, 0.5]``.
    """
    X, y = make_classification(
        n_samples=n_samples,
        n_classes=3,
        n_informative=4,
        weights=[0.2, 0.3, 0.5],
        random_state=0,
    )
    return X, y
71+
72+
@pytest.fixture(name="sample_dataset_generator")
def sample_dataset_generator_fixture():
    """Expose the shared sample dataset as a pytest fixture.

    The fixture is registered under the name ``sample_dataset_generator``
    (not under this function's own name), so a test requests it by declaring
    an argument called ``sample_dataset_generator``.

    Returns
    -------
    tuple
        The ``(X, y)`` pair produced by calling the module-level
        ``sample_dataset_generator()`` with its default arguments.
    """
    return sample_dataset_generator()
76+
77+
6278def _set_checking_parameters (estimator ):
6379 params = estimator .get_params ()
6480 name = estimator .__class__ .__name__
@@ -233,13 +249,7 @@ def check_samplers_fit(name, sampler_orig):
233249
234250def check_samplers_fit_resample (name , sampler_orig ):
235251 sampler = clone (sampler_orig )
236- X , y = make_classification (
237- n_samples = 1000 ,
238- n_classes = 3 ,
239- n_informative = 4 ,
240- weights = [0.2 , 0.3 , 0.5 ],
241- random_state = 0 ,
242- )
252+ X , y = sample_dataset_generator ()
243253 target_stats = Counter (y )
244254 X_res , y_res = sampler .fit_resample (X , y )
245255 if isinstance (sampler , BaseOverSampler ):
@@ -269,13 +279,7 @@ def check_samplers_fit_resample(name, sampler_orig):
269279def check_samplers_sampling_strategy_fit_resample (name , sampler_orig ):
270280 sampler = clone (sampler_orig )
271281 # in this test we will force all samplers to not change the class 1
272- X , y = make_classification (
273- n_samples = 1000 ,
274- n_classes = 3 ,
275- n_informative = 4 ,
276- weights = [0.2 , 0.3 , 0.5 ],
277- random_state = 0 ,
278- )
282+ X , y = sample_dataset_generator ()
279283 expected_stat = Counter (y )[1 ]
280284 if isinstance (sampler , BaseOverSampler ):
281285 sampling_strategy = {2 : 498 , 0 : 498 }
@@ -298,13 +302,7 @@ def check_samplers_sparse(name, sampler_orig):
298302 sampler = clone (sampler_orig )
299303 # check that sparse matrices can be passed through the sampler leading to
300304 # the same results as dense
301- X , y = make_classification (
302- n_samples = 1000 ,
303- n_classes = 3 ,
304- n_informative = 4 ,
305- weights = [0.2 , 0.3 , 0.5 ],
306- random_state = 0 ,
307- )
305+ X , y = sample_dataset_generator ()
308306 X_sparse = sparse .csr_matrix (X )
309307 X_res_sparse , y_res_sparse = sampler .fit_resample (X_sparse , y )
310308 sampler = clone (sampler )
@@ -318,13 +316,7 @@ def check_samplers_pandas(name, sampler_orig):
318316 pd = pytest .importorskip ("pandas" )
319317 sampler = clone (sampler_orig )
320318 # Check that the samplers handle pandas dataframe and pandas series
321- X , y = make_classification (
322- n_samples = 1000 ,
323- n_classes = 3 ,
324- n_informative = 4 ,
325- weights = [0.2 , 0.3 , 0.5 ],
326- random_state = 0 ,
327- )
319+ X , y = sample_dataset_generator ()
328320 X_df = pd .DataFrame (X , columns = [str (i ) for i in range (X .shape [1 ])])
329321 y_df = pd .DataFrame (y )
330322 y_s = pd .Series (y , name = "class" )
@@ -351,13 +343,7 @@ def check_samplers_pandas(name, sampler_orig):
351343def check_samplers_list (name , sampler_orig ):
352344 sampler = clone (sampler_orig )
353345 # Check that the samplers can handle simple lists
354- X , y = make_classification (
355- n_samples = 1000 ,
356- n_classes = 3 ,
357- n_informative = 4 ,
358- weights = [0.2 , 0.3 , 0.5 ],
359- random_state = 0 ,
360- )
346+ X , y = sample_dataset_generator ()
361347 X_list = X .tolist ()
362348 y_list = y .tolist ()
363349
@@ -374,13 +360,7 @@ def check_samplers_list(name, sampler_orig):
374360def check_samplers_multiclass_ova (name , sampler_orig ):
375361 sampler = clone (sampler_orig )
376362 # Check that a multiclass target leads to the same results as OVA encoding
377- X , y = make_classification (
378- n_samples = 1000 ,
379- n_classes = 3 ,
380- n_informative = 4 ,
381- weights = [0.2 , 0.3 , 0.5 ],
382- random_state = 0 ,
383- )
363+ X , y = sample_dataset_generator ()
384364 y_ova = label_binarize (y , classes = np .unique (y ))
385365 X_res , y_res = sampler .fit_resample (X , y )
386366 X_res_ova , y_res_ova = sampler .fit_resample (X , y_ova )
@@ -391,27 +371,15 @@ def check_samplers_multiclass_ova(name, sampler_orig):
391371
def check_samplers_2d_target(name, sampler_orig):
    """Check that a sampler accepts a target passed as a 2D column vector.

    The check makes no assertion on the resampled output; it passes as long
    as ``fit_resample`` does not raise when ``y`` has shape ``(n_samples, 1)``.

    Parameters
    ----------
    name : str
        Name of the sampler under test (not used in the body of this check).
    sampler_orig : estimator
        Sampler instance to check. It is cloned first, so the instance
        passed in is not the one that gets fitted.
    """
    sampler = clone(sampler_orig)
    # NOTE(review): before switching to the shared helper, this check built
    # its own dataset with n_samples=100; the helper's default yields 1000.
    X, y = sample_dataset_generator()

    y = y.reshape(-1, 1)  # Make the target 2d
    sampler.fit_resample(X, y)
404378
405379
406380def check_samplers_preserve_dtype (name , sampler_orig ):
407381 sampler = clone (sampler_orig )
408- X , y = make_classification (
409- n_samples = 1000 ,
410- n_classes = 3 ,
411- n_informative = 4 ,
412- weights = [0.2 , 0.3 , 0.5 ],
413- random_state = 0 ,
414- )
382+ X , y = sample_dataset_generator ()
415383 # Cast X and y to non-default dtypes
416384 X = X .astype (np .float32 )
417385 y = y .astype (np .int32 )
@@ -422,13 +390,7 @@ def check_samplers_preserve_dtype(name, sampler_orig):
422390
423391def check_samplers_sample_indices (name , sampler_orig ):
424392 sampler = clone (sampler_orig )
425- X , y = make_classification (
426- n_samples = 1000 ,
427- n_classes = 3 ,
428- n_informative = 4 ,
429- weights = [0.2 , 0.3 , 0.5 ],
430- random_state = 0 ,
431- )
393+ X , y = sample_dataset_generator ()
432394 sampler .fit_resample (X , y )
433395 sample_indices = sampler ._get_tags ().get ("sample_indices" , None )
434396 if sample_indices :
0 commit comments