Skip to content

Commit 6b892b6

Browse files
committed
-- removed lambda functions in Node Utils
1 parent a41541d commit 6b892b6

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

pytorch_tabular/models/node/utils.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,14 @@ def _threshold_and_support(input, dim=-1):
9797
return tau, support_size
9898

9999

100-
sparsemax = lambda input, dim=-1: SparsemaxFunction.apply(input, dim)
101-
sparsemoid = lambda input: (0.5 * input + 0.5).clamp_(0, 1)
100+
def sparsemax(input, dim=-1):
101+
return SparsemaxFunction.apply(input, dim)
102+
103+
104+
def sparsemoid(input):
105+
return (0.5 * input + 0.5).clamp_(0, 1)
106+
# sparsemax = lambda input, dim=-1: SparsemaxFunction.apply(input, dim)
107+
# sparsemoid = lambda input: (0.5 * input + 0.5).clamp_(0, 1)
102108

103109

104110
class Entmax15Function(Function):
@@ -184,7 +190,9 @@ def _backward(output, grad_output):
184190
return grad_input
185191

186192

187-
entmax15 = lambda input, dim=-1: Entmax15Function.apply(input, dim)
193+
def entmax15(input, dim=-1):
194+
return Entmax15Function.apply(input, dim)
195+
# entmax15 = lambda input, dim=-1: Entmax15Function.apply(input, dim)
188196
entmoid15 = Entmoid15.apply
189197

190198

0 commit comments

Comments
 (0)