
Commit 6da24a8

Embedding layer (#91)
* work in progress
* in progress
* Working network embedding
* ADD tests for network embedding
* Removed ordinal encoder
* Add seed for test_losses for reproducibility
* Addressed comments
* fix flake
* fix test import training
* ADD_109
* No print allow
* Fix tests and move to boston
* Debug issue with python 3.6
* Debug for python3.6
* Run only debug file
* print paths of parent dir
* Trying to run examples
* Add success model
* Added parent directory for printing paths
* Try no autouse
* print log file to see if backend is saving num run
* Setup logger in backend
* handle nans in categorical columns (#118)
* Fixed error in self dtypes
* Addressed comments from francisco
* Forgot to commit
* try without embeddings
* no embedding for python 3.6
* Deleted debug example
* Fix test for evaluation
* Deleted utils file

Co-authored-by: chico <francisco.rivera.valverde@gmail.com>
Parent: 55ec853

File tree

26 files changed: +608 -139 lines


.github/workflows/examples.yml

Lines changed: 2 additions & 1 deletion
@@ -31,4 +31,5 @@ jobs:
     - name: Run tests
       run: |
         python examples/example_tabular_classification.py
-        python examples/example_image_classification.py
+        python examples/example_tabular_regression.py
+        python examples/example_image_classification.py

autoPyTorch/evaluation/abstract_evaluator.py

Lines changed: 3 additions & 0 deletions
@@ -331,6 +331,8 @@ def __init__(self, backend: Backend,
             name=logger_name,
             port=logger_port,
         )
+        self.backend.setup_logger(name=logger_name, port=logger_port)
+
         self.Y_optimization: Optional[np.ndarray] = None
         self.Y_actual_train: Optional[np.ndarray] = None
         self.pipelines: Optional[List[BaseEstimator]] = None
@@ -538,6 +540,7 @@ def file_output(
         else:
             pipeline = None

+        self.logger.debug("Saving directory {}, {}, {}".format(self.seed, self.num_run, self.budget))
        self.backend.save_numrun_to_dir(
            seed=int(self.seed),
            idx=int(self.num_run),
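
The added setup_logger call points the backend at the same named, port-based logger the evaluator uses, so backend-side messages (such as the new save-directory debug line) reach the log server. A minimal sketch of that pattern using only Python's standard logging module; Backend.setup_logger itself is autoPyTorch internals and may differ:

import logging
import logging.handlers

def setup_logger(name: str, port: int) -> logging.Logger:
    # Forward records to a log server listening on localhost:port.
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.addHandler(logging.handlers.SocketHandler('localhost', port))
    return logger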

autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py

Lines changed: 4 additions & 4 deletions
@@ -2,7 +2,7 @@

 import numpy as np

-from sklearn.compose import ColumnTransformer, make_column_transformer
+from sklearn.compose import ColumnTransformer
 from sklearn.pipeline import make_pipeline

 import torch
@@ -57,9 +57,9 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> "TabularColumnTransformer":
         if len(X['dataset_properties']['categorical_columns']):
             categorical_pipeline = make_pipeline(*preprocessors['categorical'])

-        self.preprocessor = make_column_transformer(
-            (numerical_pipeline, X['dataset_properties']['numerical_columns']),
-            (categorical_pipeline, X['dataset_properties']['categorical_columns']),
+        self.preprocessor = ColumnTransformer([
+            ('numerical_pipeline', numerical_pipeline, X['dataset_properties']['numerical_columns']),
+            ('categorical_pipeline', categorical_pipeline, X['dataset_properties']['categorical_columns'])],
             remainder='passthrough'
         )
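
Switching from make_column_transformer to an explicit ColumnTransformer pins down the transformer names and makes the output layout explicit: numerical columns first, then the encoded categorical block, which matches the column order the embedding module assumes. A minimal standalone sketch of the pattern (column indices and pipeline steps here are illustrative, not the repository's):

import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

numerical_pipeline = make_pipeline(SimpleImputer(strategy='mean'), StandardScaler())
categorical_pipeline = make_pipeline(SimpleImputer(strategy='most_frequent'),
                                     OneHotEncoder(handle_unknown='ignore'))

preprocessor = ColumnTransformer([
    ('numerical_pipeline', numerical_pipeline, [0, 1]),    # numerical column indices
    ('categorical_pipeline', categorical_pipeline, [2]),   # categorical column indices
], remainder='passthrough')

X = np.array([[0.5, 1.0, 0], [1.5, np.nan, 1], [2.5, 3.0, 0]])
print(preprocessor.fit_transform(X).shape)  # numerical columns first, then one-hot block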

autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/OrdinalEncoder.py

Lines changed: 0 additions & 33 deletions
This file was deleted.

autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/base_encoder_choice.py

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ def get_hyperparameter_search_space(self,
             raise ValueError("no encoders found, please add a encoder")

         if default is None:
-            defaults = ['OneHotEncoder', 'OrdinalEncoder', 'NoEncoder']
+            defaults = ['OneHotEncoder', 'NoEncoder']
             for default_ in defaults:
                 if default_ in available_preprocessors:
                     if include is not None and default_ not in include:

autoPyTorch/pipeline/components/setup/network/base_network.py

Lines changed: 2 additions & 1 deletion
@@ -29,6 +29,7 @@ def __init__(
         self.add_fit_requirements([
             FitRequirement("network_head", (torch.nn.Module,), user_defined=False, dataset_property=False),
             FitRequirement("network_backbone", (torch.nn.Module,), user_defined=False, dataset_property=False),
+            FitRequirement("network_embedding", (torch.nn.Module,), user_defined=False, dataset_property=False),
         ])
         self.final_activation = None

@@ -47,7 +48,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchTrainingComponent:
         # information to fit this stage
         self.check_requirements(X, y)

-        self.network = torch.nn.Sequential(X['network_backbone'], X['network_head'])
+        self.network = torch.nn.Sequential(X['network_embedding'], X['network_backbone'], X['network_head'])

         # Properly set the network training device
         if self.device is None:
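
After this change the assembled model is embedding -> backbone -> head rather than backbone -> head. A minimal sketch with illustrative modules and shapes (none of these names come from the pipeline):

import torch
from torch import nn

embedding = nn.Identity()                        # stand-in for a NoEmbedding module
backbone = nn.Sequential(nn.Linear(10, 32), nn.ReLU())
head = nn.Linear(32, 2)

network = nn.Sequential(embedding, backbone, head)
print(network(torch.randn(4, 10)).shape)         # torch.Size([4, 2])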

autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone.py

Lines changed: 5 additions & 1 deletion
@@ -14,6 +14,7 @@
 from autoPyTorch.pipeline.components.base_component import (
     autoPyTorchComponent,
 )
+from autoPyTorch.pipeline.components.setup.network_backbone.utils import get_output_shape
 from autoPyTorch.utils.common import FitRequirement


@@ -31,7 +32,9 @@ def __init__(self,
             FitRequirement('X_train', (np.ndarray, pd.DataFrame, csr_matrix), user_defined=True,
                            dataset_property=False),
             FitRequirement('input_shape', (Iterable,), user_defined=True, dataset_property=True),
-            FitRequirement('tabular_transformer', (BaseEstimator,), user_defined=False, dataset_property=False)])
+            FitRequirement('tabular_transformer', (BaseEstimator,), user_defined=False, dataset_property=False),
+            FitRequirement('network_embedding', (nn.Module,), user_defined=False, dataset_property=False)
+        ])
         self.backbone: nn.Module = None
         self.config = kwargs
         self.input_shape: Optional[Iterable] = None
@@ -56,6 +59,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator:
             column_transformer = X['tabular_transformer'].preprocessor
             input_shape = column_transformer.transform(X_train[:1]).shape[1:]

+        input_shape = get_output_shape(X['network_embedding'], input_shape=input_shape)
         self.input_shape = input_shape

         self.backbone = self.build_backbone(
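
Because the embedding can change the feature count, the backbone's input shape is now computed after the embedding. A hedged sketch of how a get_output_shape helper is typically implemented, by pushing a dummy batch through the module; the real utility in network_backbone.utils may differ:

from typing import Tuple

import torch
from torch import nn

def get_output_shape(network: nn.Module, input_shape: Tuple[int, ...]) -> Tuple[int, ...]:
    # Run a small dummy batch through the module and read off the per-sample shape.
    placeholder = torch.randn((2, *input_shape), dtype=torch.float)
    with torch.no_grad():
        output = network(placeholder)
    return tuple(output.shape[1:])

print(get_output_shape(nn.Linear(8, 3), (8,)))   # (3,)
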
autoPyTorch/pipeline/components/setup/network_embedding/LearnedEntityEmbedding.py

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
from typing import Any, Dict, Optional, Tuple, Union

from ConfigSpace.configuration_space import ConfigurationSpace
from ConfigSpace.hyperparameters import (
    UniformFloatHyperparameter,
    UniformIntegerHyperparameter
)

import numpy as np

import torch
from torch import nn

from autoPyTorch.pipeline.components.setup.network_embedding.base_network_embedding import NetworkEmbeddingComponent


class _LearnedEntityEmbedding(nn.Module):
    """ Learned entity embedding module for categorical features"""

    def __init__(self, config: Dict[str, Any], num_input_features: np.ndarray, num_numerical_features: int):
        """
        Arguments:
            config (Dict[str, Any]): The configuration sampled by the hyperparameter optimizer
            num_input_features (np.ndarray): column-wise count of output columns after transformation
                for each categorical column, and 0 for numerical columns
            num_numerical_features (int): number of numerical features in X
        """
        super().__init__()
        self.config = config

        self.num_numerical = num_numerical_features
        # list of number of categories of categorical data
        # or 0 for numerical data
        self.num_input_features = num_input_features
        categorical_features = self.num_input_features > 0

        self.num_categorical_features = self.num_input_features[categorical_features]

        self.embed_features = [num_in >= config["min_unique_values_for_embedding"] for num_in in
                               self.num_input_features]
        self.num_output_dimensions = [0] * num_numerical_features
        self.num_output_dimensions.extend([config["dimension_reduction_" + str(i)] * num_in for i, num_in in
                                           enumerate(self.num_categorical_features)])
        self.num_output_dimensions = [int(np.clip(num_out, 1, num_in - 1)) for num_out, num_in in
                                      zip(self.num_output_dimensions, self.num_input_features)]
        self.num_output_dimensions = [num_out if embed else num_in for num_out, embed, num_in in
                                      zip(self.num_output_dimensions, self.embed_features,
                                          self.num_input_features)]
        self.num_out_feats = self.num_numerical + sum(self.num_output_dimensions)

        self.ee_layers = self._create_ee_layers()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # pass the columns of each categorical feature through an entity embedding layer
        # before passing it through the model
        concat_seq = []
        last_concat = 0
        x_pointer = 0
        layer_pointer = 0
        for num_in, embed in zip(self.num_input_features, self.embed_features):
            if not embed:
                x_pointer += 1
                continue
            if x_pointer > last_concat:
                concat_seq.append(x[:, last_concat: x_pointer])
            categorical_feature_slice = x[:, x_pointer: x_pointer + num_in]
            concat_seq.append(self.ee_layers[layer_pointer](categorical_feature_slice))
            layer_pointer += 1
            x_pointer += num_in
            last_concat = x_pointer

        concat_seq.append(x[:, last_concat:])
        return torch.cat(concat_seq, dim=1)

    def _create_ee_layers(self) -> nn.ModuleList:
        # entity embedding layers are Linear Layers
        layers = nn.ModuleList()
        for i, (num_in, embed, num_out) in enumerate(zip(self.num_input_features, self.embed_features,
                                                         self.num_output_dimensions)):
            if not embed:
                continue
            layers.append(nn.Linear(num_in, num_out))
        return layers


class LearnedEntityEmbedding(NetworkEmbeddingComponent):
    """
    Class to learn an embedding for categorical features.
    """

    def __init__(self, random_state: Optional[Union[np.random.RandomState, int]] = None, **kwargs: Any):
        super().__init__(random_state=random_state)
        self.config = kwargs

    def build_embedding(self, num_input_features: np.ndarray, num_numerical_features: int) -> nn.Module:
        return _LearnedEntityEmbedding(config=self.config,
                                       num_input_features=num_input_features,
                                       num_numerical_features=num_numerical_features)

    @staticmethod
    def get_hyperparameter_search_space(
        dataset_properties: Optional[Dict[str, str]] = None,
        min_unique_values_for_embedding: Tuple[Tuple, int, bool] = ((3, 7), 5, True),
        dimension_reduction: Tuple[Tuple, float] = ((0, 1), 0.5),
    ) -> ConfigurationSpace:
        cs = ConfigurationSpace()
        min_hp = UniformIntegerHyperparameter("min_unique_values_for_embedding",
                                              lower=min_unique_values_for_embedding[0][0],
                                              upper=min_unique_values_for_embedding[0][1],
                                              default_value=min_unique_values_for_embedding[1],
                                              log=min_unique_values_for_embedding[2]
                                              )
        cs.add_hyperparameter(min_hp)
        if dataset_properties is not None:
            for i in range(len(dataset_properties['categorical_columns'])):
                ee_dimensions_hp = UniformFloatHyperparameter("dimension_reduction_" + str(i),
                                                              lower=dimension_reduction[0][0],
                                                              upper=dimension_reduction[0][1],
                                                              default_value=dimension_reduction[1]
                                                              )
                cs.add_hyperparameter(ee_dimensions_hp)
        return cs

    @staticmethod
    def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, Union[str, bool]]:
        return {
            'shortname': 'embedding',
            'name': 'LearnedEntityEmbedding',
            'handles_tabular': True,
            'handles_image': False,
            'handles_time_series': False,
        }
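
The constructor above decides, per column, whether to embed and how many output columns to use: a column is embedded only if it has at least min_unique_values_for_embedding categories, and its width is dimension_reduction_i * num_categories, clipped to [1, num_categories - 1]. A standalone rerun of that arithmetic with illustrative numbers (two numerical columns, plus categoricals with 4 and 10 categories):

import numpy as np

config = {"min_unique_values_for_embedding": 5,
          "dimension_reduction_0": 0.5, "dimension_reduction_1": 0.5}
num_input_features = np.array([0, 0, 4, 10])  # 0 marks a numerical column
num_numerical = 2

num_categorical = num_input_features[num_input_features > 0]
embed = [n >= config["min_unique_values_for_embedding"] for n in num_input_features]
dims = [0] * num_numerical + [config["dimension_reduction_" + str(i)] * n
                              for i, n in enumerate(num_categorical)]
dims = [int(np.clip(d, 1, n - 1)) for d, n in zip(dims, num_input_features)]
dims = [d if e else n for d, e, n in zip(dims, embed, num_input_features)]
print(dims)                       # [0, 0, 4, 5]: only the 10-category column is embedded
print(num_numerical + sum(dims))  # 11 output features in total
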
autoPyTorch/pipeline/components/setup/network_embedding/NoEmbedding.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
from typing import Any, Dict, Optional, Union

from ConfigSpace.configuration_space import ConfigurationSpace

import numpy as np

import torch
from torch import nn

from autoPyTorch.pipeline.components.setup.network_embedding.base_network_embedding import NetworkEmbeddingComponent


class _NoEmbedding(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


class NoEmbedding(NetworkEmbeddingComponent):
    """
    Class that applies no embedding; the input passes through unchanged.
    """

    def __init__(self, random_state: Optional[Union[np.random.RandomState, int]] = None):
        super().__init__(random_state=random_state)

    def build_embedding(self, num_input_features: np.ndarray, num_numerical_features: int) -> nn.Module:
        return _NoEmbedding()

    @staticmethod
    def get_hyperparameter_search_space(
        dataset_properties: Optional[Dict[str, str]] = None,
    ) -> ConfigurationSpace:
        cs = ConfigurationSpace()
        return cs

    @staticmethod
    def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, Union[str, bool]]:
        return {
            'shortname': 'no embedding',
            'name': 'NoEmbedding',
            'handles_tabular': True,
            'handles_image': False,
            'handles_time_series': False,
        }
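
_NoEmbedding is a plain pass-through, so selecting this component leaves the network as backbone -> head in effect. A toy check (not repository code):

import torch
from torch import nn

class _NoEmbedding(nn.Module):  # same identity module as in the file above
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

x = torch.randn(4, 7)
assert torch.equal(_NoEmbedding()(x), x)  # input is returned unchanged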

autoPyTorch/pipeline/components/setup/network_embedding/__init__.py

Whitespace-only changes.
