Skip to content

Commit 3918d97

Browse files
committed
use shape after preprocessing in base network backbone
1 parent 7f32ada commit 3918d97

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ def __init__(self,
3030
self.add_fit_requirements([
3131
FitRequirement('X_train', (np.ndarray, pd.DataFrame, spmatrix), user_defined=True,
3232
dataset_property=False),
33-
FitRequirement('input_shape', (Iterable,), user_defined=True, dataset_property=True),
34-
FitRequirement('tabular_transformer', (BaseEstimator,), user_defined=False, dataset_property=False),
33+
FitRequirement('shape_after_preprocessing', (Iterable,), user_defined=False, dataset_property=False),
3534
FitRequirement('network_embedding', (nn.Module,), user_defined=False, dataset_property=False)
3635
])
3736
self.backbone: nn.Module = None
@@ -49,9 +48,8 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator:
4948
Self
5049
"""
5150
self.check_requirements(X, y)
52-
X_train = X['X_train']
5351

54-
input_shape = X_train.shape[1:]
52+
input_shape = X['shape_after_preprocessing']
5553

5654
input_shape = get_output_shape(X['network_embedding'], input_shape=input_shape)
5755
self.input_shape = input_shape

0 commit comments

Comments
 (0)