Skip to content

Commit 1506fc7

Browse files
committed
blacked
1 parent 953c652 commit 1506fc7

File tree

27 files changed

+89
-48
lines changed

27 files changed

+89
-48
lines changed

.github/workflows/lint.yml

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,19 @@ on:
88

99
jobs:
1010
linter-black:
11-
name: Check code formatting with Black
12-
runs-on: ubuntu-latest
13-
steps:
14-
- name: Checkout
15-
uses: actions/checkout@v2
16-
- name: Set up Python 3.8
17-
uses: actions/setup-python@v2
18-
with:
19-
python-version: 3.8
20-
- name: Install Black
21-
run: pip install black[jupyter]
22-
- name: Run Black
23-
run: black --check .
11+
name: Check code formatting with Black
12+
runs-on: ubuntu-latest
13+
steps:
14+
- name: Checkout
15+
uses: actions/checkout@v2
16+
- name: Set up Python 3.8
17+
uses: actions/setup-python@v2
18+
with:
19+
python-version: 3.8
20+
- name: Install Black
21+
run: pip install black[jupyter]
22+
- name: Run Black
23+
run: black --check .
2424

2525
imports-check-isort:
2626
name: Check valid import formatting with isort
@@ -53,7 +53,6 @@ jobs:
5353
- name: Run checks
5454
run: flake8
5555

56-
5756
pre-commit-hooks:
5857
name: Check that pre-commit hooks pass
5958
runs-on: ubuntu-latest

pytorch_forecasting/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
PyTorch Forecasting package for timeseries forecasting with PyTorch.
33
"""
4+
45
from pytorch_forecasting.data import (
56
EncoderNormalizer,
67
GroupNormalizer,

pytorch_forecasting/data/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
Handling timeseries data is not trivial. It requires special treatment. This sub-package provides the necessary tools
55
to abstracts the necessary work.
66
"""
7+
78
from pytorch_forecasting.data.encoders import (
89
EncoderNormalizer,
910
GroupNormalizer,

pytorch_forecasting/data/encoders.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class SoftplusTransform(Transform):
6161
Transform via the mapping :math:`\text{Softplus}(x) = \log(1 + \exp(x))`.
6262
The implementation reverts to the linear function when :math:`x > 20`.
6363
"""
64+
6465
domain = constraints.real
6566
codomain = constraints.positive
6667
bijective = True
@@ -93,6 +94,7 @@ class MinusOneTransform(Transform):
9394
r"""
9495
Transform x -> x - 1.
9596
"""
97+
9698
domain = constraints.real
9799
codomain = constraints.real
98100
sign: int = 1
@@ -112,6 +114,7 @@ class ReLuTransform(Transform):
112114
r"""
113115
Transform x -> max(0, x).
114116
"""
117+
115118
domain = constraints.real
116119
codomain = constraints.nonnegative
117120
sign: int = 1
@@ -364,7 +367,7 @@ def inverse_transform(self, y: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
364367
decoded = self.classes_vector_[y]
365368
return decoded
366369

367-
def __call__(self, data: (Dict[str, torch.Tensor])) -> torch.Tensor:
370+
def __call__(self, data: Dict[str, torch.Tensor]) -> torch.Tensor:
368371
"""
369372
Extract prediction from network output. Does not map back to input
370373
categories as this would require a numpy tensor without grad-abilities.
@@ -1189,17 +1192,23 @@ def func(*args, **kwargs):
11891192
results = []
11901193
for idx, norm in enumerate(self.normalizers):
11911194
new_args = [
1192-
arg[idx]
1193-
if isinstance(arg, (list, tuple))
1194-
and not isinstance(arg, rnn.PackedSequence)
1195-
and len(arg) == n
1196-
else arg
1195+
(
1196+
arg[idx]
1197+
if isinstance(arg, (list, tuple))
1198+
and not isinstance(arg, rnn.PackedSequence)
1199+
and len(arg) == n
1200+
else arg
1201+
)
11971202
for arg in args
11981203
]
11991204
new_kwargs = {
1200-
key: val[idx]
1201-
if isinstance(val, list) and not isinstance(val, rnn.PackedSequence) and len(val) == n
1202-
else val
1205+
key: (
1206+
val[idx]
1207+
if isinstance(val, list)
1208+
and not isinstance(val, rnn.PackedSequence)
1209+
and len(val) == n
1210+
else val
1211+
)
12031212
for key, val in kwargs.items()
12041213
}
12051214
results.append(getattr(norm, name)(*new_args, **new_kwargs))

pytorch_forecasting/data/examples.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Example datasets for tutorials and testing.
33
"""
4+
45
from pathlib import Path
56

67
import numpy as np

pytorch_forecasting/data/samplers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Samplers for sampling time series from the :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`
33
"""
4+
45
import warnings
56

67
import numpy as np

pytorch_forecasting/data/timeseries.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
Timeseries data is special and has to be processed and fed to algorithms in a special way. This module
55
defines a class that is able to handle a wide variety of timeseries data problems.
66
"""
7+
78
from copy import copy as _copy, deepcopy
89
from functools import lru_cache
910
import inspect

pytorch_forecasting/metrics/_mqf2_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Classes and functions for the MQF2 metric."""
2+
23
from typing import List, Optional, Tuple
34

45
from cpflows.flows import DeepConvexFlow, SequentialFlow

pytorch_forecasting/metrics/base_metrics.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Base classes for metrics - only for inheritance.
33
"""
4+
45
import inspect
56
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
67
import warnings
@@ -454,17 +455,23 @@ def func(*args, **kwargs):
454455
results = []
455456
for idx, m in enumerate(self.metrics):
456457
new_args = [
457-
arg[idx]
458-
if isinstance(arg, (list, tuple))
459-
and not isinstance(arg, rnn.PackedSequence)
460-
and len(arg) == n
461-
else arg
458+
(
459+
arg[idx]
460+
if isinstance(arg, (list, tuple))
461+
and not isinstance(arg, rnn.PackedSequence)
462+
and len(arg) == n
463+
else arg
464+
)
462465
for arg in args
463466
]
464467
new_kwargs = {
465-
key: val[idx]
466-
if isinstance(val, list) and not isinstance(val, rnn.PackedSequence) and len(val) == n
467-
else val
468+
key: (
469+
val[idx]
470+
if isinstance(val, list)
471+
and not isinstance(val, rnn.PackedSequence)
472+
and len(val) == n
473+
else val
474+
)
468475
for key, val in kwargs.items()
469476
}
470477
results.append(getattr(m, name)(*new_args, **new_kwargs))

pytorch_forecasting/metrics/distributions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Metrics that allow the parametric forecast of parameters of uni- and multivariate distributions."""
2+
23
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
34

45
import numpy as np

0 commit comments

Comments
 (0)