File tree Expand file tree Collapse file tree 1 file changed +11
-3
lines changed
pytorch_tabular/models/node Expand file tree Collapse file tree 1 file changed +11
-3
lines changed Original file line number Diff line number Diff line change @@ -97,8 +97,14 @@ def _threshold_and_support(input, dim=-1):
9797 return tau , support_size
9898
9999
100- sparsemax = lambda input , dim = - 1 : SparsemaxFunction .apply (input , dim )
101- sparsemoid = lambda input : (0.5 * input + 0.5 ).clamp_ (0 , 1 )
100+ def sparsemax (input , dim = - 1 ):
101+ return SparsemaxFunction .apply (input , dim )
102+
103+
104+ def sparsemoid (input ):
105+ return (0.5 * input + 0.5 ).clamp_ (0 , 1 )
106+ # sparsemax = lambda input, dim=-1: SparsemaxFunction.apply(input, dim)
107+ # sparsemoid = lambda input: (0.5 * input + 0.5).clamp_(0, 1)
102108
103109
104110class Entmax15Function (Function ):
@@ -184,7 +190,9 @@ def _backward(output, grad_output):
184190 return grad_input
185191
186192
187- entmax15 = lambda input , dim = - 1 : Entmax15Function .apply (input , dim )
193+ def entmax15 (input , dim = - 1 ):
194+ return Entmax15Function .apply (input , dim )
195+ # entmax15 = lambda input, dim=-1: Entmax15Function.apply(input, dim)
188196entmoid15 = Entmoid15 .apply
189197
190198
You can’t perform that action at this time.
0 commit comments