
Commit 5995d9a

logsoftmax bug fix

1 parent 95b1fd5 commit 5995d9a

File tree

1 file changed: +2 -6 lines changed


neunet/nn/activations.py

Lines changed: 2 additions & 6 deletions
@@ -427,7 +427,7 @@ class Softmax(Module): # Dynamic Softmax computation
     def __init__(self, axis=1):
         self.axis = axis
 
-    def forward(self, x):
+    def forward(self, x: Tensor):
         e_x = x.sub(x.max(axis=self.axis, keepdims=True)).exp()
         return e_x.div(e_x.sum(axis=self.axis, keepdims=True))

@@ -479,14 +479,10 @@ def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)
 
         def grad_fn(t: Tensor, f_x, axis, grad):
-            x = t.data
-            batch_size = x.shape[0]
-            softmax = f_x
+            softmax = t.xp.exp(f_x) # e^(loge_softmax) = softmax
 
             grad_x = grad - softmax * grad.sum(axis = axis, keepdims=True)
 
-            grad_x = grad_x / batch_size
-
             t._apply_grad(grad_x)
 
         self.grad_fn = grad_fn
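For context, the fixed grad_fn matches the standard log-softmax backward: for y = x - logsumexp(x), the gradient is dL/dx = g - softmax(x) * sum(g) along the softmax axis. Since the forward caches f_x = log_softmax(x), softmax is recovered as exp(f_x), and the old division by batch_size is dropped (batch reduction presumably belongs to the loss, not the activation gradient). Below is a minimal NumPy sketch, independent of neunet's Tensor API; log_softmax and log_softmax_backward are illustrative helper names, not part of the library. It checks the formula against a finite-difference gradient.

import numpy as np

def log_softmax(x, axis=1):
    # Numerically stable log-softmax, mirroring the shifted-exp forward above.
    x_shifted = x - x.max(axis=axis, keepdims=True)
    return x_shifted - np.log(np.exp(x_shifted).sum(axis=axis, keepdims=True))

def log_softmax_backward(f_x, grad, axis=1):
    # Mirrors the fixed grad_fn: recover softmax from the cached log-softmax
    # output, then grad_x = grad - softmax * sum(grad) along the axis.
    softmax = np.exp(f_x)
    return grad - softmax * grad.sum(axis=axis, keepdims=True)

# Verify against a numerical gradient of sum(upstream * log_softmax(x)).
rng = np.random.default_rng(0)
x = rng.standard_normal((3, 5))
upstream = rng.standard_normal((3, 5))

analytic = log_softmax_backward(log_softmax(x), upstream)

eps = 1e-6
numeric = np.zeros_like(x)
for i in np.ndindex(x.shape):
    x_plus, x_minus = x.copy(), x.copy()
    x_plus[i] += eps
    x_minus[i] -= eps
    numeric[i] = ((upstream * log_softmax(x_plus)).sum()
                  - (upstream * log_softmax(x_minus)).sum()) / (2 * eps)

print(np.allclose(analytic, numeric, atol=1e-5))  # True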
