@@ -8,10 +8,10 @@ class _SigmoidTensor(Tensor): # Static sigmoid tensor for backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(x: Tensor, f_x, grad):
+        def grad_fn(x: Tensor, f_x, grad):
             x._apply_grad(grad * f_x * (1 - f_x))

-        self._backward = _backward
+        self.grad_fn = grad_fn


 class Sigmoid(Module): # Static sigmoid computation
@@ -42,10 +42,10 @@ class _ReLUTensor(Tensor): # Static ReLU tensor for backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(t: Tensor, f_x, grad):
+        def grad_fn(t: Tensor, f_x, grad):
             t._apply_grad(grad * (f_x > 0))

-        self._backward = _backward
+        self.grad_fn = grad_fn


 class ReLU(Module): # Static ReLU computation
@@ -75,12 +75,12 @@ class _LeakyReLUTensor(Tensor): # Static LeakyReLU tensor for backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(t: Tensor, f_x, alpha, grad):
+        def grad_fn(t: Tensor, f_x, alpha, grad):
             t._apply_grad(
                 grad * t.xp.where(f_x <= 0, alpha, 1).astype(grad.dtype)
             )

-        self._backward = _backward
+        self.grad_fn = grad_fn

 class LeakyReLU(Module): # Static LeakyReLU computation
     def __init__(self, alpha=0.01):
@@ -109,10 +109,10 @@ class _TanhTensor(Tensor): # Static Tanh tensor for backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(t: Tensor, f_x, grad):
+        def grad_fn(t: Tensor, f_x, grad):
             t._apply_grad(grad * (1 - f_x ** 2))

-        self._backward = _backward
+        self.grad_fn = grad_fn


 class Tanh(Module): # Static Tanh computation
@@ -142,11 +142,11 @@ class _SoftplusTensor(Tensor): # Static Softplus tensor for backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(t: Tensor, grad):
+        def grad_fn(t: Tensor, grad):
             x = t.data
             t._apply_grad(grad * (1 / (1 + t.xp.exp(-x))))

-        self._backward = _backward
+        self.grad_fn = grad_fn


 class Softplus(Module): # Static Softplus computation
@@ -176,11 +176,11 @@ class _SoftsignTensor(Tensor): # Static Softsign tensor for backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(t: Tensor, grad):
+        def grad_fn(t: Tensor, grad):
             x = t.data
             t._apply_grad(grad * (1 / (1 + t.xp.abs(x)) ** 2))

-        self._backward = _backward
+        self.grad_fn = grad_fn


 class Softsign(Module): # Static Softsign computation
@@ -210,13 +210,13 @@ class _SwishTensorTensor(Tensor): # Static Swish tensor for backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(t: Tensor, f_x, beta, grad):
+        def grad_fn(t: Tensor, f_x, beta, grad):
             x = t.data
             sigmoid = lambda x: 1 / (1 + t.xp.exp(-x))

             t._apply_grad(grad * (beta * f_x + sigmoid(beta * x) * (1 - beta * f_x)))

-        self._backward = _backward
+        self.grad_fn = grad_fn


 class Swish(Module): # Static Swish computation
@@ -252,7 +252,7 @@ class _MishTensor(Tensor): # Static Mish tensor for backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(t: Tensor, grad):
+        def grad_fn(t: Tensor, grad):
             xp = t.xp
             x = t.data

@@ -264,7 +264,7 @@ def _backward(t: Tensor, grad):

             t._apply_grad(grad_x)

-        self._backward = _backward
+        self.grad_fn = grad_fn


 class Mish(Module): # Static Mish computation
@@ -295,15 +295,15 @@ class _TanhExpTensor(Tensor): # Static TanhExp tensor for backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(t: Tensor, grad):
+        def grad_fn(t: Tensor, grad):
             xp = t.xp
             x = t.data

             grad_x = grad * (xp.tanh(xp.exp(x)) - x * xp.exp(x) * (xp.power(xp.tanh(xp.exp(x)), 2) - 1))

             t._apply_grad(grad_x)

-        self._backward = _backward
+        self.grad_fn = grad_fn

 class TanhExp(Module): # Static TanhExp computation
     def __init__(self):
@@ -333,13 +333,13 @@ class _ELUTensor(Tensor): # Static ELU tensor for backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(t: Tensor, f_x, alpha, grad):
+        def grad_fn(t: Tensor, f_x, alpha, grad):
             x = t.data
             grad_x = grad * (t.xp.where(x <= 0, alpha + f_x, 1).astype(grad.dtype))

             t._apply_grad(grad_x)

-        self._backward = _backward
+        self.grad_fn = grad_fn


 class ELU(Module): # Static ELU computation
@@ -359,13 +359,13 @@ class _SELUTensor(Tensor): # Static SELU tensor for backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(t: Tensor, alpha, lmbda, grad):
+        def grad_fn(t: Tensor, alpha, lmbda, grad):
             x = t.data
             grad_x = grad * (lmbda * t.xp.where(x > 0, 1, alpha * t.xp.exp(x)).astype(grad.dtype))

             t._apply_grad(grad_x)

-        self._backward = _backward
+        self.grad_fn = grad_fn


 class SELU(Module): # Static SELU computation
@@ -388,7 +388,7 @@ class _GELUTensor(Tensor): # Static GELU tensor for backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(t: Tensor, grad):
+        def grad_fn(t: Tensor, grad):
             xp = t.xp
             x = t.data
             # sech = lambda z: 2 / (np.exp(z) + np.exp(-z))
@@ -403,7 +403,7 @@ def _backward(t: Tensor, grad):

             t._apply_grad(grad_x)

-        self._backward = _backward
+        self.grad_fn = grad_fn


 class GELU(Module): # Static GELU computation
@@ -439,9 +439,9 @@ def __call__(self, x):
 # def __init__(self, data, args, op, device):
 #     super().__init__(data, args, op, device=device)

-#     self._backward = self.__backward
+#     self.grad_fn = self._grad_fn

-# def __backward(self):
+# def _grad_fn(self):
 #     x = self.args[0].data
 #     # f_x = self.args[1]
 #     f_x = self.data
@@ -478,7 +478,7 @@ class _LogSoftmax(Tensor): # Static LogSoftmax tensor for backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(t: Tensor, f_x, axis, grad):
+        def grad_fn(t: Tensor, f_x, axis, grad):
             x = t.data
             batch_size = x.shape[0]
             softmax = f_x
@@ -489,7 +489,7 @@ def _backward(t: Tensor, f_x, axis, grad):

             t._apply_grad(grad_x)

-        self._backward = _backward
+        self.grad_fn = grad_fn


 class LogSoftmax(Module):
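
Not part of the patch: a minimal, standalone NumPy sketch that spot-checks two of the closed-form derivatives the renamed grad_fn closures rely on (sigmoid's f_x * (1 - f_x) and softsign's 1 / (1 + |x|) ** 2) against central finite differences. It deliberately avoids the library's Tensor/xp machinery, so nothing here reflects the project's actual API.

# Standalone sanity check (NumPy only; does not touch the library's Tensor/xp API).
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def softsign(x):
    return x / (1 + np.abs(x))

x = np.linspace(-4.0, 4.0, 9)
eps = 1e-5

# Sigmoid: grad_fn scales the incoming gradient by f_x * (1 - f_x).
f_x = sigmoid(x)
analytic = f_x * (1 - f_x)
numeric = (sigmoid(x + eps) - sigmoid(x - eps)) / (2 * eps)
assert np.allclose(analytic, numeric, atol=1e-4)

# Softsign: grad_fn scales the incoming gradient by 1 / (1 + |x|) ** 2.
analytic = 1 / (1 + np.abs(x)) ** 2
numeric = (softsign(x + eps) - softsign(x - eps)) / (2 * eps)
assert np.allclose(analytic, numeric, atol=1e-4)

print("sigmoid and softsign derivative formulas match finite differences")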