
Commit a8a4431

edit CrossEntropyLoss

1 parent: 7e85688

File tree

1 file changed (+10 −5)

neunet/nn/losses.py

Lines changed: 10 additions & 5 deletions
@@ -2,7 +2,7 @@
 
 import neunet as nnet
 from neunet.autograd import Tensor
-from neunet.nn.activations import Softmax
+from neunet.nn.activations import LogSoftmax
 from neunet.nn.modules import Module
 
 
@@ -62,7 +62,8 @@ def __init__(self, weight=None, ignore_index=-100, reduction="mean"):
         self.ignore_index = ignore_index
         self.reduction = reduction
 
-        self.softmax = Softmax(axis=1)
+        # self.softmax = Softmax(axis=1)
+        self.log_softmax = LogSoftmax(axis=1)
         self.nll_loss = NLLLoss(weight, ignore_index, reduction)
 
     def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
@@ -71,7 +72,8 @@ def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
         if y_pred.device != y_true.device:
             raise ValueError("Tensors must be on the same device")
 
-        y_pred = self.softmax(y_pred).log()
+        # y_pred = self.softmax(y_pred).log()
+        y_pred = self.log_softmax(y_pred)
         return self.nll_loss(y_pred, y_true)
 
     def __call__(self, y_pred, y_true):
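
Switching from self.softmax(y_pred).log() to a dedicated LogSoftmax is the usual numerical-stability move: exponentiating large-magnitude logits can overflow or underflow, and taking the log of a probability that underflowed to zero produces -inf, whereas log-softmax can be evaluated directly with the log-sum-exp shift. A minimal NumPy sketch of the idea (illustrative only, not the neunet LogSoftmax implementation):

import numpy as np

def naive_log_softmax(x, axis=1):
    # softmax first, then log: exp() overflows/underflows for extreme logits
    e = np.exp(x)
    return np.log(e / e.sum(axis=axis, keepdims=True))

def stable_log_softmax(x, axis=1):
    # log-sum-exp trick: shift by the row maximum before exponentiating
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

logits = np.array([[1000.0, 0.0, -1000.0]])
print(naive_log_softmax(logits))   # [[ nan -inf -inf]]
print(stable_log_softmax(logits))  # [[ 0. -1000. -2000.]]
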
@@ -91,7 +93,7 @@ def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
             raise ValueError("Tensors must be on the same device")
 
         if self.weight is None:
-            self.weight = y_pred.xp.ones((y_pred.data.shape[1]))
+            self.weight = y_pred.xp.ones((y_pred.data.shape[1]), dtype=y_pred.dtype)
 
         if self.weight.shape != (y_pred.data.shape[1],):
            raise ValueError("Weight shape must be equal to number of classes")
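
The dtype=y_pred.dtype argument keeps the default class weights in the same precision as the predictions: ones() in NumPy/CuPy defaults to float64, so multiplying float32 log-probabilities by the untyped weight vector would silently upcast the loss. A small NumPy illustration (assumes float32 predictions; not neunet-specific code):

import numpy as np

log_probs = np.zeros((4, 3), dtype=np.float32)    # stand-in for float32 predictions

w_default = np.ones((3,))                         # float64 by default
w_typed   = np.ones((3,), dtype=log_probs.dtype)  # stays float32

print((log_probs * w_default).dtype)  # float64 -- unintended upcast
print((log_probs * w_typed).dtype)    # float32
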
@@ -103,10 +105,13 @@ def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
         if y_true.data.ndim == 1:
             y_true = y_true[..., None]
 
-        ignore_mask = y_true.data != self.ignore_index
+        # TODO: if neg value in y_true != ignore_index, raise error, fix, negative ids in weight
+
+        ignore_mask = (y_true.data != self.ignore_index).astype(y_pred.dtype)
 
         idx = np.indices(y_true.data.shape, sparse=True)
         criterion = (idx[0], y_true.data, *idx[1:])
+        # criterion = (self.xp.arange(y_true.data.shape[0]), y_true.data.flatten())
         loss = -y_pred[criterion] * self.weight[y_true.data] * ignore_mask
 
         if self.reduction == "mean":
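
Casting ignore_mask to y_pred.dtype serves the same purpose in the loss product, and the np.indices(..., sparse=True) tuple performs the per-target gather while keeping any extra dimensions (the commented-out arange/flatten line is a 2-D-only alternative). A plain-NumPy sketch of the gather with hypothetical toy shapes (not the neunet code itself):

import numpy as np

# Hypothetical shapes: N=2 samples, C=3 classes, T=4 positions per sample.
rng = np.random.default_rng(0)
log_probs = rng.standard_normal((2, 3, 4)).astype(np.float32)  # stand-in for log-softmax output
targets = rng.integers(0, 3, size=(2, 4))                      # class id per (sample, position)
weight = np.ones((3,), dtype=log_probs.dtype)
ignore_index = -100

ignore_mask = (targets != ignore_index).astype(log_probs.dtype)

# Advanced indexing: picked[n, t] = log_probs[n, targets[n, t], t]
idx = np.indices(targets.shape, sparse=True)
criterion = (idx[0], targets, *idx[1:])
loss = -log_probs[criterion] * weight[targets] * ignore_mask

print(loss.mean())  # simplified "mean" reduction over all positions
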
