Skip to content

Commit a41541d

Browse files
committed
-- moved nested function out for pickling
1 parent 10dbb12 commit a41541d

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

pytorch_tabular/models/mixture_density/mdn.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,9 @@ def unpack_input(self, x: Dict):
378378
class NODEMDN(BaseMDN):
379379
def __init__(self, config: DictConfig, **kwargs):
380380
super().__init__(config, **kwargs)
381+
382+
def subset(self, x):
383+
return x[..., :].mean(dim=-2)
381384

382385
def _build_network(self):
383386
self.hparams.node_input_dim = (
@@ -387,10 +390,7 @@ def _build_network(self):
387390
# average first n channels of every tree, where n is the number of output targets for regression
388391
# and number of classes for classification
389392

390-
def subset(x):
391-
return x[..., :].mean(dim=-2)
392-
393-
output_response = utils.Lambda(subset)
393+
output_response = utils.Lambda(self.subset)
394394
self.backbone = nn.Sequential(backbone, output_response)
395395
# Adding the last layer
396396
self.hparams.mdn_config.input_dim = backbone.output_dim

pytorch_tabular/models/node/node_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ def __init__(self, config: DictConfig, **kwargs):
6060
if config.embed_categorical:
6161
self.embedding_cat_dim = sum([y for x, y in config.embedding_dims])
6262
super().__init__(config, **kwargs)
63+
64+
def subset(self, x):
65+
return x[..., : self.hparams.output_dim].mean(dim=-2)
6366

6467
def _build_network(self):
6568
if self.hparams.embed_categorical:
@@ -79,10 +82,7 @@ def _build_network(self):
7982
# average first n channels of every tree, where n is the number of output targets for regression
8083
# and number of classes for classification
8184

82-
def subset(x):
83-
return x[..., : self.hparams.output_dim].mean(dim=-2)
84-
85-
self.output_response = utils.Lambda(subset)
85+
self.output_response = utils.Lambda(self.subset)
8686

8787
def unpack_input(self, x: Dict):
8888
if self.hparams.embed_categorical:

0 commit comments

Comments
 (0)