@@ -61,6 +61,7 @@ class SoftplusTransform(Transform):
6161 Transform via the mapping :math:`\text{Softplus}(x) = \log(1 + \exp(x))`.
6262 The implementation reverts to the linear function when :math:`x > 20`.
6363 """
64+
6465 domain = constraints .real
6566 codomain = constraints .positive
6667 bijective = True
@@ -93,6 +94,7 @@ class MinusOneTransform(Transform):
9394 r"""
9495 Transform x -> x - 1.
9596 """
97+
9698 domain = constraints .real
9799 codomain = constraints .real
98100 sign : int = 1
@@ -112,6 +114,7 @@ class ReLuTransform(Transform):
112114 r"""
113115 Transform x -> max(0, x).
114116 """
117+
115118 domain = constraints .real
116119 codomain = constraints .nonnegative
117120 sign : int = 1
@@ -364,7 +367,7 @@ def inverse_transform(self, y: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
364367 decoded = self .classes_vector_ [y ]
365368 return decoded
366369
367- def __call__ (self , data : ( Dict [str , torch .Tensor ]) ) -> torch .Tensor :
370+ def __call__ (self , data : Dict [str , torch .Tensor ]) -> torch .Tensor :
368371 """
369372 Extract prediction from network output. Does not map back to input
370373 categories as this would require a numpy tensor without grad-abilities.
@@ -1189,17 +1192,23 @@ def func(*args, **kwargs):
11891192 results = []
11901193 for idx , norm in enumerate (self .normalizers ):
11911194 new_args = [
1192- arg [idx ]
1193- if isinstance (arg , (list , tuple ))
1194- and not isinstance (arg , rnn .PackedSequence )
1195- and len (arg ) == n
1196- else arg
1195+ (
1196+ arg [idx ]
1197+ if isinstance (arg , (list , tuple ))
1198+ and not isinstance (arg , rnn .PackedSequence )
1199+ and len (arg ) == n
1200+ else arg
1201+ )
11971202 for arg in args
11981203 ]
11991204 new_kwargs = {
1200- key : val [idx ]
1201- if isinstance (val , list ) and not isinstance (val , rnn .PackedSequence ) and len (val ) == n
1202- else val
1205+ key : (
1206+ val [idx ]
1207+ if isinstance (val , list )
1208+ and not isinstance (val , rnn .PackedSequence )
1209+ and len (val ) == n
1210+ else val
1211+ )
12031212 for key , val in kwargs .items ()
12041213 }
12051214 results .append (getattr (norm , name )(* new_args , ** new_kwargs ))
0 commit comments