@@ -30,8 +30,7 @@ def __init__(self,
3030 self .add_fit_requirements ([
3131 FitRequirement ('X_train' , (np .ndarray , pd .DataFrame , spmatrix ), user_defined = True ,
3232 dataset_property = False ),
33- FitRequirement ('input_shape' , (Iterable ,), user_defined = True , dataset_property = True ),
34- FitRequirement ('tabular_transformer' , (BaseEstimator ,), user_defined = False , dataset_property = False ),
33+ FitRequirement ('shape_after_preprocessing' , (Iterable ,), user_defined = False , dataset_property = False ),
3534 FitRequirement ('network_embedding' , (nn .Module ,), user_defined = False , dataset_property = False )
3635 ])
3736 self .backbone : nn .Module = None
@@ -49,9 +48,8 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator:
4948 Self
5049 """
5150 self .check_requirements (X , y )
52- X_train = X ['X_train' ]
5351
54- input_shape = X_train . shape [ 1 : ]
52+ input_shape = X [ 'shape_after_preprocessing' ]
5553
5654 input_shape = get_output_shape (X ['network_embedding' ], input_shape = input_shape )
5755 self .input_shape = input_shape
0 commit comments