@@ -2,7 +2,7 @@
 
 import neunet as nnet
 from neunet.autograd import Tensor
-from neunet.nn.activations import Softmax
+from neunet.nn.activations import LogSoftmax
 from neunet.nn.modules import Module
 
 
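The import swap above is the usual numerical-stability fix: computing `Softmax(...)` first and taking `.log()` afterwards overflows or underflows for large-magnitude logits, while a fused log-softmax can use the log-sum-exp trick. A minimal plain-NumPy sketch of the failure mode (an illustration, not the neunet API):

    import numpy as np

    logits = np.array([[1000.0, 0.0, -1000.0]])

    # Naive composition: exp(1000.0) overflows to inf, so the softmax row
    # degenerates to [nan, 0, 0] and its log to [nan, -inf, -inf].
    p = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    print(np.log(p))          # [[ nan -inf -inf]] plus runtime warnings

    # Fused log-softmax (log-sum-exp trick): shift by the row max first.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_p = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    print(log_p)              # [[    0. -1000. -2000.]] -- finite and exact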
@@ -62,7 +62,8 @@ def __init__(self, weight=None, ignore_index=-100, reduction="mean"):
         self.ignore_index = ignore_index
         self.reduction = reduction
 
-        self.softmax = Softmax(axis=1)
+        # self.softmax = Softmax(axis=1)
+        self.log_softmax = LogSoftmax(axis=1)
         self.nll_loss = NLLLoss(weight, ignore_index, reduction)
 
     def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
@@ -71,7 +72,8 @@ def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
         if y_pred.device != y_true.device:
             raise ValueError("Tensors must be on the same device")
 
-        y_pred = self.softmax(y_pred).log()
+        # y_pred = self.softmax(y_pred).log()
+        y_pred = self.log_softmax(y_pred)
         return self.nll_loss(y_pred, y_true)
 
     def __call__(self, y_pred, y_true):
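With the forward change, `CrossEntropyLoss` keeps the textbook decomposition CE(x, y) = NLL(log_softmax(x), y), just computed in one stable step. A quick NumPy check of that identity (all names below are local to the sketch):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(size=(4, 5))             # 4 samples, 5 classes
    y = np.array([0, 3, 1, 4])              # true class indices

    def log_softmax(x):
        s = x - x.max(axis=1, keepdims=True)
        return s - np.log(np.exp(s).sum(axis=1, keepdims=True))

    # NLL over log-probabilities: negate the log-prob of the true class.
    ce = -log_softmax(x)[np.arange(len(y)), y].mean()

    # Matches -log(softmax(x)[y]) computed the long way round.
    p = np.exp(x) / np.exp(x).sum(axis=1, keepdims=True)
    assert np.allclose(ce, -np.log(p[np.arange(len(y)), y]).mean())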
@@ -91,7 +93,7 @@ def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
             raise ValueError("Tensors must be on the same device")
 
         if self.weight is None:
-            self.weight = y_pred.xp.ones((y_pred.data.shape[1]))
+            self.weight = y_pred.xp.ones((y_pred.data.shape[1]), dtype=y_pred.dtype)
 
         if self.weight.shape != (y_pred.data.shape[1],):
             raise ValueError("Weight shape must be equal to number of classes")
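The added `dtype=y_pred.dtype` matters because `ones` defaults to float64 in NumPy (and in CuPy, which is presumably what `xp` dispatches to on GPU): multiplying a float32 loss by a float64 weight silently promotes the whole computation to float64. Sketch:

    import numpy as np

    y_pred = np.zeros((2, 3), dtype=np.float32)

    w_default = np.ones(y_pred.shape[1])                      # float64 by default
    w_matched = np.ones(y_pred.shape[1], dtype=y_pred.dtype)  # float32

    print((y_pred * w_default).dtype)    # float64 -- silent upcast
    print((y_pred * w_matched).dtype)    # float32 -- dtype preserved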
@@ -103,10 +105,13 @@ def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
         if y_true.data.ndim == 1:
             y_true = y_true[..., None]
 
-        ignore_mask = y_true.data != self.ignore_index
+        # TODO: raise an error on negative values in y_true other than ignore_index; fix negative ids in weight
+
+        ignore_mask = (y_true.data != self.ignore_index).astype(y_pred.dtype)
 
         idx = np.indices(y_true.data.shape, sparse=True)
         criterion = (idx[0], y_true.data, *idx[1:])
+        # criterion = (self.xp.arange(y_true.data.shape[0]), y_true.data.flatten())
         loss = -y_pred[criterion] * self.weight[y_true.data] * ignore_mask
 
         if self.reduction == "mean":
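Two things happen in this last hunk. Casting `ignore_mask` to `y_pred.dtype` turns the boolean mask into 0.0/1.0 factors, so ignored positions contribute exactly zero without a bool-to-float promotion at multiply time. And the `np.indices(..., sparse=True)` tuple generalizes the commented-out 1-D `arange` gather to targets with extra dimensions (e.g. per-pixel labels). A hedged NumPy sketch of that gather; the in-range `ignore_index` below sidesteps the out-of-range crash the TODO refers to:

    import numpy as np

    ignore_index = -1                     # in-range here so the gather stays legal;
                                          # the TODO notes out-of-range ids would crash

    # log-probs with an extra spatial axis: (batch=2, classes=3, width=4)
    log_p = np.log(np.full((2, 3, 4), 1.0 / 3.0))
    y_true = np.array([[0, 1, 2, ignore_index],
                       [2, 2, 0, 1]])

    # np.indices supplies the coordinates for every non-class axis, so the
    # tuple gathers log_p[n, y_true[n, w], w] for each (n, w) -- the N-D
    # version of log_p[np.arange(N), y_true] in the commented-out line.
    idx = np.indices(y_true.shape, sparse=True)
    criterion = (idx[0], y_true, *idx[1:])

    ignore_mask = (y_true != ignore_index).astype(log_p.dtype)
    loss = -log_p[criterion] * ignore_mask   # shape (2, 4); masked entry is 0.0

    # one common "mean" convention: average over the non-ignored targets
    print(loss.sum() / ignore_mask.sum())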