Skip to content

Commit 347f9dc

Browse files
committed
-- Added seed_everything
-- updated history
1 parent 5bc6f8e commit 347f9dc

File tree

3 files changed

+22
-1
lines changed

3 files changed

+22
-1
lines changed

docs/history.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,12 @@ History
2929
------------------
3030
- Added more documentation
3131
- Added Zenodo citation
32+
33+
0.6.0 (2021-06-21)
34+
------------------
35+
- Upgraded versions of PyTorch Lightning to 1.3.6
36+
- Changed the way `gpus` parameter is handled to avoid confusion. `None` is CPU, `-1` is all GPUs, `int` is number of GPUs
37+
- Added a few more Trainer Params like `deterministic`, `auto_select_gpus`
38+
- Some bug fixes and changes to docs
39+
- Added `seed_everything` to the fit method to ensure reproducibility
40+
- Refactored data_aware_initialization to be part of the BaseModel. Inherited Models can override the method to implement data aware initialization techniques

pytorch_tabular/config/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,12 @@ class ModelConfig:
576576
"help": "The range in which we should limit the output variable. Currently ignored for multi-target regression. Typically used for Regression problems. If left empty, will not apply any restrictions"
577577
},
578578
)
579+
seed: int = field(
580+
default=42,
581+
metadata={
582+
"help": "The seed for reproducibility. Defaults to 42"
583+
},
584+
)
579585

580586
def __post_init__(self):
581587
if self.task == "regression":

pytorch_tabular/tabular_model.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# For license information, see LICENSE.TXT
44
"""Tabular Model"""
55
from collections import defaultdict
6+
from pytorch_lightning.utilities.seed import seed_everything
67
import logging
78
import os
89
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
@@ -19,6 +20,7 @@
1920
from sklearn.base import TransformerMixin
2021
from torch import nn
2122
from tqdm.autonotebook import tqdm
23+
from pytorch_tabular import config
2224

2325
import pytorch_tabular.models as models
2426
from pytorch_tabular.config import (
@@ -335,7 +337,7 @@ def _pre_fit(
335337
target_transform: Optional[Union[TransformerMixin, Tuple]],
336338
max_epochs: int,
337339
min_epochs: int,
338-
reset: bool,
340+
reset: bool
339341
):
340342
"""Prepares the dataloaders, trainer, and model for the fit process"""
341343
if target_transform is not None:
@@ -381,6 +383,7 @@ def fit(
381383
max_epochs: Optional[int] = None,
382384
min_epochs: Optional[int] = None,
383385
reset: bool = False,
386+
seed: Optional[int] = None,
384387
) -> None:
385388
"""The fit method which takes in the data and triggers the training
386389
@@ -414,7 +417,10 @@ def fit(
414417
min_epochs (Optional[int]): Overwrite minimum number of epochs to be run
415418
416419
reset: (bool): Flag to reset the model and train again from scratch
420+
421+
seed: (int): If you have to override the default seed set as part of of ModelConfig
417422
"""
423+
seed_everything(seed if seed is not None else self.config.seed)
418424
train_loader, val_loader = self._pre_fit(
419425
train,
420426
validation,

0 commit comments

Comments
 (0)