@@ -337,12 +337,11 @@ def verify_patching(stream: io.StringIO, function_name) -> bool:
337337def create_online_function (
338338 estimator_instance , method_instance , data_args , num_batches , batch_size
339339):
340- n_batches = data_args [0 ].shape [0 ] // batch_size
341340
342341 if "y" in list (inspect .signature (method_instance ).parameters ):
343342
344343 def ndarray_function (x , y ):
345- for i in range (n_batches ):
344+ for i in range (num_batches ):
346345 method_instance (
347346 x [i * batch_size : (i + 1 ) * batch_size ],
348347 y [i * batch_size : (i + 1 ) * batch_size ],
@@ -351,7 +350,7 @@ def ndarray_function(x, y):
351350 estimator_instance ._onedal_finalize_fit ()
352351
353352 def dataframe_function (x , y ):
354- for i in range (n_batches ):
353+ for i in range (num_batches ):
355354 method_instance (
356355 x .iloc [i * batch_size : (i + 1 ) * batch_size ],
357356 y .iloc [i * batch_size : (i + 1 ) * batch_size ],
@@ -362,13 +361,13 @@ def dataframe_function(x, y):
362361 else :
363362
364363 def ndarray_function (x ):
365- for i in range (n_batches ):
364+ for i in range (num_batches ):
366365 method_instance (x [i * batch_size : (i + 1 ) * batch_size ])
367366 if hasattr (estimator_instance , "_onedal_finalize_fit" ):
368367 estimator_instance ._onedal_finalize_fit ()
369368
370369 def dataframe_function (x ):
371- for i in range (n_batches ):
370+ for i in range (num_batches ):
372371 method_instance (x .iloc [i * batch_size : (i + 1 ) * batch_size ])
373372 if hasattr (estimator_instance , "_onedal_finalize_fit" ):
374373 estimator_instance ._onedal_finalize_fit ()
0 commit comments