
Commit 95b1fd5

code refactoring
1 parent 622fe48 commit 95b1fd5

16 files changed, +148 -152 lines


neunet/autograd.py

Lines changed: 92 additions & 96 deletions
Large diffs are not rendered by default.
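The autograd.py diff is not rendered, but the files below show a uniform rename of the per-tensor backward hook from _backward to grad_fn. For orientation only, here is a hypothetical, minimal reverse-mode sketch in plain NumPy of how an engine might invoke such a stored hook. This is not neunet's autograd.py: the Node class, backward(), and mul() are illustrative assumptions; only the grad_fn attribute name mirrors this commit.

# Hypothetical sketch of a reverse pass that calls a stored `grad_fn` hook.
# NOT neunet code; only the attribute name mirrors this commit.
import numpy as np

class Node:
    def __init__(self, data, parents=(), grad_fn=None):
        self.data = np.asarray(data, dtype=float)
        self.parents = parents          # upstream nodes this one depends on
        self.grad_fn = grad_fn          # closure that routes grad to parents
        self.grad = np.zeros_like(self.data)

    def backward(self, grad=None):
        # Seed with ones, then walk the graph; each node's grad_fn
        # distributes its accumulated gradient to its parents.
        self.grad = np.ones_like(self.data) if grad is None else grad
        stack = [self]
        while stack:
            node = stack.pop()
            if node.grad_fn is not None:
                node.grad_fn(node.grad)
            stack.extend(node.parents)

def mul(a: Node, b: Node) -> Node:
    out = Node(a.data * b.data, parents=(a, b))

    def grad_fn(grad):                  # product rule: d(ab) = b*da + a*db
        a.grad = a.grad + grad * b.data
        b.grad = b.grad + grad * a.data

    out.grad_fn = grad_fn
    return out

x, y = Node(2.0), Node(3.0)
z = mul(x, y)
z.backward()
print(x.grad, y.grad)   # 3.0, 2.0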

neunet/nn/activations.py

Lines changed: 28 additions & 28 deletions
@@ -8,10 +8,10 @@ class _SigmoidTensor(Tensor): # Static sigmoid tensor for backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(x: Tensor, f_x, grad):
+        def grad_fn(x: Tensor, f_x, grad):
             x._apply_grad(grad * f_x * (1 - f_x))

-        self._backward = _backward
+        self.grad_fn = grad_fn


 class Sigmoid(Module): # Static sigmoid computation
@@ -42,10 +42,10 @@ class _ReLUTensor(Tensor): # Static ReLU tensor for backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(t: Tensor, f_x, grad):
+        def grad_fn(t: Tensor, f_x, grad):
             t._apply_grad(grad * (f_x > 0))

-        self._backward = _backward
+        self.grad_fn = grad_fn


 class ReLU(Module): # Static ReLU computation
@@ -75,12 +75,12 @@ class _LeakyReLUTensor(Tensor): # Static LeakyReLU tensor for backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(t: Tensor, f_x, alpha, grad):
+        def grad_fn(t: Tensor, f_x, alpha, grad):
             t._apply_grad(
                 grad * t.xp.where(f_x <= 0, alpha, 1).astype(grad.dtype)
             )

-        self._backward = _backward
+        self.grad_fn = grad_fn

 class LeakyReLU(Module): # Static LeakyReLU computation
     def __init__(self, alpha=0.01):
@@ -109,10 +109,10 @@ class _TanhTensor(Tensor): # Static Tanh tensor for backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(t: Tensor, f_x, grad):
+        def grad_fn(t: Tensor, f_x, grad):
             t._apply_grad(grad * (1 - f_x ** 2))

-        self._backward = _backward
+        self.grad_fn = grad_fn


 class Tanh(Module): # Static Tanh computation
@@ -142,11 +142,11 @@ class _SoftplusTensor(Tensor): # Static Softplus tensor for backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(t: Tensor, grad):
+        def grad_fn(t: Tensor, grad):
             x = t.data
             t._apply_grad(grad * (1 / (1 + t.xp.exp(-x))))

-        self._backward = _backward
+        self.grad_fn = grad_fn


 class Softplus(Module): # Static Softplus computation
@@ -176,11 +176,11 @@ class _SoftsignTensor(Tensor): # Static Softsign tensor for backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(t: Tensor, grad):
+        def grad_fn(t: Tensor, grad):
             x = t.data
             t._apply_grad(grad * (1 / (1 + t.xp.abs(x)) ** 2))

-        self._backward = _backward
+        self.grad_fn = grad_fn


 class Softsign(Module): # Static Softsign computation
@@ -210,13 +210,13 @@ class _SwishTensorTensor(Tensor): # Static Swish tensor for backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(t: Tensor, f_x, beta, grad):
+        def grad_fn(t: Tensor, f_x, beta, grad):
             x = t.data
             sigmoid = lambda x: 1 / (1 + t.xp.exp(-x))

             t._apply_grad(grad * (beta * f_x + sigmoid(beta * x) * (1 - beta * f_x)))

-        self._backward = _backward
+        self.grad_fn = grad_fn


 class Swish(Module): # Static Swish computation
@@ -252,7 +252,7 @@ class _MishTensor(Tensor): # Static Mish tensor for backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(t: Tensor, grad):
+        def grad_fn(t: Tensor, grad):
             xp = t.xp
             x = t.data

@@ -264,7 +264,7 @@ def _backward(t: Tensor, grad):

             t._apply_grad(grad_x)

-        self._backward = _backward
+        self.grad_fn = grad_fn


 class Mish(Module): # Static Mish computation
@@ -295,15 +295,15 @@ class _TanhExpTensor(Tensor): # Static TanhExp tensor for backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(t: Tensor, grad):
+        def grad_fn(t: Tensor, grad):
             xp = t.xp
             x = t.data

             grad_x = grad * (xp.tanh(xp.exp(x)) - x * xp.exp(x) * (xp.power(xp.tanh(xp.exp(x)), 2) - 1))

             t._apply_grad(grad_x)

-        self._backward = _backward
+        self.grad_fn = grad_fn

 class TanhExp(Module): # Static TanhExp computation
     def __init__(self):
@@ -333,13 +333,13 @@ class _ELUTensor(Tensor): # Static ELU tensor for backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(t: Tensor, f_x, alpha, grad):
+        def grad_fn(t: Tensor, f_x, alpha, grad):
             x = t.data
             grad_x = grad * (t.xp.where(x <= 0, alpha + f_x, 1).astype(grad.dtype))

             t._apply_grad(grad_x)

-        self._backward = _backward
+        self.grad_fn = grad_fn


 class ELU(Module): # Static ELU computation
@@ -359,13 +359,13 @@ class _SELUTensor(Tensor): # Static SELU tensor for backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(t: Tensor, alpha, lmbda, grad):
+        def grad_fn(t: Tensor, alpha, lmbda, grad):
             x = t.data
             grad_x = grad * (lmbda * t.xp.where(x > 0, 1, alpha * t.xp.exp(x)).astype(grad.dtype))

             t._apply_grad(grad_x)

-        self._backward = _backward
+        self.grad_fn = grad_fn


 class SELU(Module): # Static SELU computation
@@ -388,7 +388,7 @@ class _GELUTensor(Tensor): # Static GELU tensor for backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(t: Tensor, grad):
+        def grad_fn(t: Tensor, grad):
             xp = t.xp
             x = t.data
             # sech = lambda z: 2 / (np.exp(z) + np.exp(-z))
@@ -403,7 +403,7 @@ def _backward(t: Tensor, grad):

             t._apply_grad(grad_x)

-        self._backward = _backward
+        self.grad_fn = grad_fn


 class GELU(Module): # Static GELU computation
@@ -439,9 +439,9 @@ def __call__(self, x):
 #     def __init__(self, data, args, op, device):
 #         super().__init__(data, args, op, device=device)

-#         self._backward = self.__backward
+#         self.grad_fn = self._grad_fn

-#     def __backward(self):
+#     def _grad_fn(self):
 #         x = self.args[0].data
 #         # f_x = self.args[1]
 #         f_x = self.data
@@ -478,7 +478,7 @@ class _LogSoftmax(Tensor): # Static LogSoftmax tensor for backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(t: Tensor, f_x, axis, grad):
+        def grad_fn(t: Tensor, f_x, axis, grad):
             x = t.data
             batch_size = x.shape[0]
             softmax = f_x
@@ -489,7 +489,7 @@ def _backward(t: Tensor, f_x, axis, grad):

             t._apply_grad(grad_x)

-        self._backward = _backward
+        self.grad_fn = grad_fn


 class LogSoftmax(Module):
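The renamed grad_fn closures above carry the usual activation derivatives, e.g. sigmoid: f_x * (1 - f_x), tanh: 1 - f_x ** 2, ReLU: f_x > 0. Below is a standalone NumPy sanity check of those expressions against central finite differences; it is independent of neunet, and the test points are arbitrary (chosen away from 0 so the ReLU kink is avoided).

# Standalone check of the derivative expressions used in the grad_fn
# closures above (sigmoid, tanh, ReLU). Not neunet code.
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

x = np.array([-3.0, -1.2, -0.3, 0.3, 1.2, 3.0])   # avoid x = 0 (ReLU kink)
eps = 1e-6

checks = {
    # name: (forward fn, analytic derivative written exactly as in the diff)
    "sigmoid": (sigmoid, lambda x: sigmoid(x) * (1 - sigmoid(x))),
    "tanh":    (np.tanh, lambda x: 1 - np.tanh(x) ** 2),
    "relu":    (lambda x: np.maximum(x, 0),
                lambda x: (np.maximum(x, 0) > 0).astype(float)),
}

for name, (f, df) in checks.items():
    numeric = (f(x + eps) - f(x - eps)) / (2 * eps)   # central difference
    assert np.allclose(df(x), numeric, atol=1e-4), name
    print(name, "derivative matches finite differences")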

neunet/nn/layers/avgpool2d.py

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@ class _AvgPool2dTensor(Tensor):
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(
+        def grad_fn(
             X: Tensor,
             kernel_size,
             stride,
@@ -45,7 +45,7 @@ def _backward(

             X._apply_grad(grad_X)

-        self._backward = _backward
+        self.grad_fn = grad_fn

 class AvgPool2d(Module):
     def __init__(self, kernel_size: Union[int, tuple[int, int]], stride: Optional[Union[int, tuple[int, int]]] = None, padding: Union[int, tuple[int, int]] = 0):

neunet/nn/layers/batchnorm1d.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@ class _BatchNorm1dTensor(Tensor): # tensor for static backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(X: Tensor, weight: Tensor, bias: Tensor, X_centered, stddev_inv, affine, grad):
+        def grad_fn(X: Tensor, weight: Tensor, bias: Tensor, X_centered, stddev_inv, affine, grad):
             X_hat = X_centered * stddev_inv
             batch_size = X.data.shape[0]

@@ -38,7 +38,7 @@ def _backward(X: Tensor, weight: Tensor, bias: Tensor, X_centered, stddev_inv, a
             weight._apply_grad(grad_weight)
             bias._apply_grad(grad_bias)

-        self._backward = _backward
+        self.grad_fn = grad_fn


 class BatchNorm1d(Module): # layer with static backpropagation
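The grad_fn above works from the saved X_centered and stddev_inv, so X_hat = X_centered * stddev_inv is rebuilt in the backward pass without renormalizing. A small standalone NumPy illustration of those saved quantities follows; the eps value and shapes are assumptions for the example, and the full gradient formula (elided in the hunk) is not reproduced here.

# Illustration of the quantities the BatchNorm1d grad_fn above relies on.
# eps and shapes are assumptions; this is not neunet code.
import numpy as np

eps = 1e-5
X = np.random.default_rng(1).normal(loc=2.0, scale=3.0, size=(8, 5))

mean = X.mean(axis=0)
var = X.var(axis=0)

X_centered = X - mean
stddev_inv = 1.0 / np.sqrt(var + eps)

X_hat = X_centered * stddev_inv        # same expression as the first line of grad_fn
print(X_hat.mean(axis=0).round(6))     # ~0 per feature
print(X_hat.std(axis=0).round(6))      # ~1 per feature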

neunet/nn/layers/batchnorm2d.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@ class _BatchNorm2dTensor(Tensor): # tensor for static backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(X: Tensor, weight: Tensor, bias: Tensor, X_centered, stddev_inv, affine, grad):
+        def grad_fn(X: Tensor, weight: Tensor, bias: Tensor, X_centered, stddev_inv, affine, grad):
             batch_size = X.data.shape[0] * X.data.shape[2] * X.data.shape[3]

             axis = (0, 2, 3)
@@ -49,7 +49,7 @@ def _backward(X: Tensor, weight: Tensor, bias: Tensor, X_centered, stddev_inv, a
             weight._apply_grad(grad_weight)
             bias._apply_grad(grad_bias)

-        self._backward = _backward
+        self.grad_fn = grad_fn


 class BatchNorm2d(Module): # layer with static backpropagation

neunet/nn/layers/bidirectional.py

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@ class _BidirectionalTensor(Tensor):
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(D_O: Tensor, R_O: Tensor, merge_mode, grad):
+        def grad_fn(D_O: Tensor, R_O: Tensor, merge_mode, grad):

             if merge_mode == "concat":
                 direct_grad, reverse_grad = D_O.xp.split(grad, 2, axis=-1)
@@ -25,7 +25,7 @@ def _backward(D_O: Tensor, R_O: Tensor, merge_mode, grad):
             D_O._apply_grad(direct_grad)
             R_O._apply_grad(reverse_grad)

-        self._backward = _backward
+        self.grad_fn = grad_fn


 class Bidirectional(Module):
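For the "concat" merge mode shown above, the upstream gradient is split in half along the last axis and routed to the forward and reverse directions. A standalone NumPy illustration of that split follows; the shapes are assumptions, and D_O.xp corresponds to the NumPy/CuPy module the library dispatches to.

# Illustration of the "concat" branch above: the gradient of a concatenated
# bidirectional output splits evenly along the last axis. Not neunet code.
import numpy as np

hidden = 4
grad = np.arange(2 * 3 * 2 * hidden, dtype=float).reshape(2, 3, 2 * hidden)

direct_grad, reverse_grad = np.split(grad, 2, axis=-1)

print(direct_grad.shape, reverse_grad.shape)   # (2, 3, 4) (2, 3, 4)
# The first half flows back into the forward pass, the second into the reversed one.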

neunet/nn/layers/conv2d.py

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,7 @@ class _Conv2dTensor(Tensor): # tensor for static backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(
+        def grad_fn(
             X: Tensor,
             weight: Tensor,
             bias: Tensor,
@@ -114,7 +114,7 @@ def _backward(
             if bias is not None:
                 bias._apply_grad(grad_bias)

-        self._backward = _backward
+        self.grad_fn = grad_fn

 class Conv2d(Module): # layer with static backpropagation
     """

neunet/nn/layers/convtranspose2d.py

Lines changed: 2 additions & 2 deletions
@@ -33,7 +33,7 @@ def prepare_grad(grad, padding, stride, dilated_kernel_size, output_padding):

             return unstrided_grad

-        def _backward(
+        def grad_fn(
             X: Tensor,
             weight: Tensor,
             bias: Tensor,
@@ -107,7 +107,7 @@ def _backward(
             if bias is not None:
                 bias._apply_grad(grad_bias)

-        self._backward = _backward
+        self.grad_fn = grad_fn


neunet/nn/layers/dropout.py

Lines changed: 2 additions & 2 deletions
@@ -6,10 +6,10 @@ class _DropoutTensor(Tensor): # tensor for static backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(X: Tensor, mask, grad):
+        def grad_fn(X: Tensor, mask, grad):
             X._apply_grad(grad * mask)

-        self._backward = _backward
+        self.grad_fn = grad_fn


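The dropout grad_fn above reuses the forward mask, so gradients vanish exactly where activations were dropped. A standalone NumPy illustration follows; the inverted-dropout mask construction is an assumption for the example, and only grad * mask comes from the diff.

# Illustration of the dropout backward rule above: grad flows only through
# units kept by the forward mask. The mask construction is an assumption.
import numpy as np

rng = np.random.default_rng(0)
p = 0.5                                            # drop probability (assumed)
x = rng.normal(size=(2, 5))

mask = (rng.random(x.shape) > p) / (1.0 - p)       # 0 for dropped, 1/(1-p) for kept
y = x * mask                                       # forward pass

grad_y = np.ones_like(y)                           # upstream gradient
grad_x = grad_y * mask                             # backward: same mask reused
print(grad_x)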

neunet/nn/layers/embedding.py

Lines changed: 2 additions & 2 deletions
@@ -10,14 +10,14 @@ class _EmbeddingTensor(Tensor): # tensor for static backpropagation
     def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

-        def _backward(X: np.ndarray, weight: Tensor, grad):
+        def grad_fn(X: np.ndarray, weight: Tensor, grad):
             axis = list(range(len(X.shape)))
             axis[-1], axis[-2] = axis[-2], axis[-1]

             weight_grad = weight.xp.matmul(X.transpose(*axis), grad)
             weight._apply_grad(weight_grad)

-        self._backward = _backward
+        self.grad_fn = grad_fn


 class Embedding(Module):
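The embedding grad_fn above computes weight_grad = matmul(X.transpose(...), grad), which for one-hot inputs amounts to scatter-adding each position's gradient into the selected embedding row. A standalone NumPy check of that equivalence follows; vocabulary size, shapes, and the one-hot encoding are illustrative assumptions, not the library's internal representation.

# Standalone check that X^T @ grad (as in the grad_fn above) equals a
# scatter-add of gradients into the rows selected by one-hot inputs.
import numpy as np

vocab, dim, seq = 6, 3, 4
ids = np.array([1, 4, 1, 0])                      # token indices (assumed)
X = np.eye(vocab)[ids]                            # (seq, vocab) one-hot input
grad = np.arange(seq * dim, dtype=float).reshape(seq, dim)  # upstream grad

weight_grad = X.T @ grad                          # same contraction as in grad_fn

expected = np.zeros((vocab, dim))
np.add.at(expected, ids, grad)                    # scatter-add reference

assert np.allclose(weight_grad, expected)
print(weight_grad)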
