 from neunet.nn.modules import Module
 from neunet.nn.parameter import Parameter
 
-# class LayerNorm(): #layer with dynamic backpropagation
+# class LayerNorm(Module): #layer with dynamic backpropagation
 #     def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True):
 #         self.normalized_shape = (normalized_shape, ) if isinstance(normalized_shape, int) else normalized_shape
 #         self.eps = eps
 #         self.elementwise_affine = elementwise_affine
 
 #         if elementwise_affine:
-#             self.weight = Tensor(self.xp.ones((normalized_shape)))
-#             self.bias = Tensor(self.xp.zeros((normalized_shape)))
+#             self.weight: Union[Tensor, None] = Parameter(neunet.tensor(np.ones((normalized_shape)), dtype=np.float32))
+#             self.bias: Union[Tensor, None] = Parameter(neunet.tensor(np.zeros((normalized_shape)), dtype=np.float32))
 #         else:
 #             self.weight = None
 #             self.bias = None
 
-#     def forward(self, X):
+#     def forward(self, X: Tensor):
 #         axis = tuple(range(-len(self.normalized_shape), 0))
 
 #         mean = X.mean(axis = axis, keepdims=True)
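For orientation before the grad_fn changes below: the commented-out class is the old "dynamic backpropagation" LayerNorm, which normalizes the input over the trailing normalized_shape axes and then optionally applies the learned weight and bias. A minimal NumPy sketch of that forward pass, using illustrative names that are not part of the repository:

    import numpy as np

    # Hypothetical reference implementation of the forward pass described above.
    def layer_norm_forward(X, weight=None, bias=None, normalized_shape=(4,), eps=1e-05):
        axis = tuple(range(-len(normalized_shape), 0))   # trailing axes to normalize over
        mean = X.mean(axis=axis, keepdims=True)
        var = X.var(axis=axis, keepdims=True)
        X_hat = (X - mean) / np.sqrt(var + eps)          # zero mean, unit variance over `axis`
        if weight is not None:                           # optional elementwise affine transform
            return weight * X_hat + bias
        return X_hat

    # Example: normalize a (2, 4) batch over its last axis.
    out = layer_norm_forward(np.random.randn(2, 4).astype(np.float32))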
@@ -46,11 +46,14 @@ def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)
 
         def grad_fn(X: Tensor, weight: Tensor, bias: Tensor, X_centered, stddev_inv, axis, elementwise_affine, grad):
-            # _axis = list(axis) if isinstance(axis, tuple) else axis
+            # The method of calculating the derivative is similar to BatchNorm.
+            _axis = list(axis) if isinstance(axis, tuple) else axis
             X_hat = X_centered * stddev_inv
 
             weight_data = weight.data if elementwise_affine else 1
-            weight_size = weight.size if elementwise_affine else 1
+            # N = X.xp.prod(X.xp.array(X.shape)[_axis])  # Takes up a lot of GPU memory
+            N = np.prod(np.array(X.shape)[_axis])
+
 
             dX_hat = weight_data * grad
             dstddev_inv = (
@@ -59,26 +62,26 @@ def grad_fn(X: Tensor, weight: Tensor, bias: Tensor, X_centered, stddev_inv, axis, elementwise_affine, grad):
                 * X.xp.sum(dX_hat * X_centered, axis=axis, keepdims=True)
             )
             dvar = (
-                X.xp.ones_like(X.data) * dstddev_inv * 2 * X_centered / weight_size
+                X.xp.ones_like(X.data) * dstddev_inv * 2 * X_centered / N
             )  # X.xp.prod(X.xp.array(X.shape)[_axis])
             dmean = (
                 X.xp.ones_like(X.data)
                 * X.xp.sum(dX_hat * stddev_inv, axis=axis, keepdims=True)
                 * (-1)
-                / weight_size
+                / N
             )  # X.xp.prod(X.xp.array(X.shape)[_axis])
             grad_X = dX_hat * stddev_inv + dvar + dmean
 
-            # grad_X = (1 / weight_size) * weight_data * stddev_inv * (
-            #     weight_size * grad
+            # grad_X = (1 / N) * weight_data * stddev_inv * (
+            #     N * grad
             #     - X.xp.sum(grad, axis = axis, keepdims = True)
             #     - X_centered * X.xp.power(stddev_inv, 2) * X.xp.sum(grad * X_centered, axis = axis, keepdims = True)
             # )
 
             # dX_hat = weight_data * grad
-            # dvar = X.xp.sum(dX_hat * X_centered, axis = axis, keepdims = True) * (-0.5) * X.xp.power(stddev_inv, 3) * 2 * X_centered / weight_size
-            # dmean = (X.xp.sum(dX_hat * (-stddev_inv), axis = axis, keepdims = True) + dvar * X.xp.mean(-2.0 * X_centered, axis = axis, keepdims = True)) * X.xp.ones_like(X.data) / weight_size
-            # grad_X = dX_hat * stddev_inv + dvar + dmean
+            # dvar = X.xp.sum(dX_hat * X_centered, axis = axis, keepdims = True) * (-0.5) * X.xp.power(stddev_inv, 3)
+            # dmean = (X.xp.sum(dX_hat * (-stddev_inv), axis = axis, keepdims = True) + dvar * X.xp.mean(-2.0 * X_centered, axis = axis, keepdims = True)) * X.xp.ones_like(X.data) / N
+            # grad_X = dX_hat * stddev_inv + dvar * 2 * X_centered / N + dmean / N
 
             if elementwise_affine:
                 grad_weight = X.xp.sum(grad * X_hat, axis=0)
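The grad_fn above propagates the gradient through dvar and dmean explicitly; since weight is applied inside the reductions (dX_hat = weight_data * grad), this should match the usual closed-form LayerNorm input gradient. A self-contained NumPy sketch of that closed form, useful as a numerical cross-check, with illustrative names that are not part of the repository:

    import numpy as np

    # Hypothetical reference: closed-form dL/dX for y = weight * (X - mean) / sqrt(var + eps) + bias.
    def layer_norm_backward_X(grad, X, weight, eps=1e-05, axis=(-1,)):
        N = np.prod(np.array(X.shape)[list(axis)])        # number of normalized elements
        mean = X.mean(axis=axis, keepdims=True)
        var = X.var(axis=axis, keepdims=True)
        stddev_inv = 1.0 / np.sqrt(var + eps)
        X_centered = X - mean
        dX_hat = weight * grad                            # gradient w.r.t. the normalized activations
        return (1.0 / N) * stddev_inv * (
            N * dX_hat
            - dX_hat.sum(axis=axis, keepdims=True)
            - X_centered * stddev_inv ** 2 * (dX_hat * X_centered).sum(axis=axis, keepdims=True)
        )

    # Example: gradient of a sum() loss for a (2, 4) input normalized over its last axis.
    X = np.random.randn(2, 4)
    dX = layer_norm_backward_X(np.ones_like(X), X, np.random.randn(4))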