
Commit bd11371

autograd, layernorm bug fix
1 parent bb09024 commit bd11371

File tree:
  neunet/autograd.py
  neunet/nn/layers/batchnorm1d.py
  neunet/nn/layers/batchnorm2d.py
  neunet/nn/layers/layernorm.py

4 files changed: +41 -42 lines

neunet/autograd.py

Lines changed: 21 additions & 29 deletions
@@ -1000,7 +1000,7 @@ def _reverse_broadcast(self, grad):

     def backward(
         self, grad=None
-    ): # grad=self.xp.array(1) # TODO: ASSERT GRAD SHAPE == DATA SHAPE, assert grad.device == self.device
+    ):
         if not self.requires_grad:
             return

@@ -1012,36 +1012,28 @@ def backward(

         self._apply_grad(grad)
         # Perform a topological sort to ensure gradients are calculated in the correct order
-        def toposort(v):
-            tape = []
-            visited_ids = set()
-            stack = [v]
-
-            while stack:
-                node: Tensor = stack.pop()
-                node_id = id(node)
-
-                if node_id in visited_ids:
-                    continue
-
-                visited_ids.add(node_id)
-
-                if node.args is not None:
-                    for child in node.args:
-                        if not isinstance(child, Tensor):
-                            continue
-                        if child.requires_grad is False:
-                            continue
-                        stack.append(child)
-
-                tape.append(node)
+        tape = []
+        visited_ids = set()
+
+        def toposort(v, tape: list, visited_ids: set):
+            # Topological Sort Using DFS
+            if id(v) not in visited_ids:
+                visited_ids.add(id(v))
+                if v.args is None:
+                    return
+                for child in v.args:
+                    if not isinstance(child, Tensor):
+                        continue
+                    if child.requires_grad is False:
+                        continue
+
+                    toposort(child, tape, visited_ids)
+                tape.append(v)

             return tape
-
+
+        tape = toposort(self, tape, visited_ids)
         # Apply the backward function in reverse order
-        for v in toposort(self):
+        for v in reversed(tape):
             v.grad_fn(*v.args, grad=v.grad)

-        # BUGS:
-        # grad X - mean not correct with pytorch; maybe NOT BUG becase small numbers manipulation (Numerical stability issues)
-        # softmax not equals grads with pytorch; place: div; maybe NOT BUG becase small numbers manipulation (Numerical stability issues)????
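
Read as a whole, the new toposort appends a tensor to the tape only after recursing into every tensor in its args, so reversed(tape) visits each node strictly after all of the nodes that consume it, which is exactly the order in which grad_fn must be applied. A minimal sketch of that ordering argument, using a stand-in Node class rather than the repository's Tensor:

    class Node:
        def __init__(self, args=()):
            self.args = args   # tensors this node was computed from
            self.grad = 0.0    # gradient accumulated from its consumers

    def toposort(v, tape, visited_ids):
        # Post-order DFS: children are appended before the nodes that use them.
        if id(v) in visited_ids:
            return tape
        visited_ids.add(id(v))
        for child in v.args:
            toposort(child, tape, visited_ids)
        tape.append(v)
        return tape

    # Diamond graph: d uses b and c, both of which use a.
    a = Node()
    b, c = Node(args=(a,)), Node(args=(a,))
    d = Node(args=(b, c))

    tape = toposort(d, [], set())
    # a comes first in the tape, so in reversed(tape) it is processed last,
    # i.e. only after both b and c have contributed their gradients to it.
    assert tape.index(a) < min(tape.index(b), tape.index(c), tape.index(d))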

neunet/nn/layers/batchnorm1d.py

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,8 @@ def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

     def grad_fn(X: Tensor, weight: Tensor, bias: Tensor, X_centered, stddev_inv, affine, grad):
+        # The method of calculating the derivative is similar to BatchNorm.
+        # https://chrisyeh96.github.io/2017/08/28/deriving-batchnorm-backprop.html
         X_hat = X_centered * stddev_inv
         batch_size = X.data.shape[0]

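For reference, the compact result of the derivation linked in the new comment, written with the names used in grad_fn (m is batch_size, X_hat = X_centered * stddev_inv, dX_hat = weight.data * grad, and both sums run over the batch axis):

    dX = (stddev_inv / m) * (m * dX_hat - sum(dX_hat) - X_hat * sum(dX_hat * X_hat))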
neunet/nn/layers/batchnorm2d.py

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,8 @@ def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

     def grad_fn(X: Tensor, weight: Tensor, bias: Tensor, X_centered, stddev_inv, affine, grad):
+        # https://math.stackexchange.com/questions/2359981/batch-normalization-equation-derivation
+        # https://arxiv.org/pdf/1502.03167
         batch_size = X.data.shape[0] * X.data.shape[2] * X.data.shape[3]

         axis = (0, 2, 3)
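
Those references reduce the backward pass to the same compact expression; below is a hedged NumPy sketch of it for the 2d case, with the reduction axes and batch_size defined as above (the function name and signature are illustrative only, not this repository's API):

    import numpy as np

    def batchnorm2d_input_grad(grad, gamma, X_centered, stddev_inv, axis=(0, 2, 3)):
        # Compact BatchNorm backward from the linked derivations:
        # dX = (stddev_inv / m) * (m * dX_hat - sum(dX_hat) - X_hat * sum(dX_hat * X_hat))
        # gamma is assumed broadcastable against grad, e.g. shaped (1, C, 1, 1).
        m = np.prod([grad.shape[i] for i in axis])
        X_hat = X_centered * stddev_inv
        dX_hat = gamma * grad
        return (stddev_inv / m) * (
            m * dX_hat
            - np.sum(dX_hat, axis=axis, keepdims=True)
            - X_hat * np.sum(dX_hat * X_hat, axis=axis, keepdims=True)
        )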

neunet/nn/layers/layernorm.py

Lines changed: 16 additions & 13 deletions
@@ -7,20 +7,20 @@
 from neunet.nn.modules import Module
 from neunet.nn.parameter import Parameter

-# class LayerNorm(): #layer with dynamic backpropagation
+# class LayerNorm(Module): #layer with dynamic backpropagation
 #     def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True):
 #         self.normalized_shape = (normalized_shape, ) if isinstance(normalized_shape, int) else normalized_shape
 #         self.eps = eps
 #         self.elementwise_affine = elementwise_affine

 #         if elementwise_affine:
-#             self.weight = Tensor(self.xp.ones((normalized_shape)))
-#             self.bias = Tensor(self.xp.zeros((normalized_shape)))
+#             self.weight: Union[Tensor, None] = Parameter(neunet.tensor(np.ones((normalized_shape)), dtype=np.float32))
+#             self.bias: Union[Tensor, None] = Parameter(neunet.tensor(np.zeros((normalized_shape)), dtype=np.float32))
 #         else:
 #             self.weight = None
 #             self.bias = None

-#     def forward(self, X):
+#     def forward(self, X: Tensor):
 #         axis = tuple(range(-len(self.normalized_shape), 0))

 #         mean = X.mean(axis = axis, keepdims=True)

@@ -46,11 +46,14 @@ def __init__(self, data, args, op, device):
         super().__init__(data, args, op, device=device)

     def grad_fn(X: Tensor, weight: Tensor, bias: Tensor, X_centered, stddev_inv, axis, elementwise_affine, grad):
-        # _axis = list(axis) if isinstance(axis, tuple) else axis
+        # The method of calculating the derivative is similar to BatchNorm.
+        _axis = list(axis) if isinstance(axis, tuple) else axis
         X_hat = X_centered * stddev_inv

         weight_data = weight.data if elementwise_affine else 1
-        weight_size = weight.size if elementwise_affine else 1
+        # N = X.xp.prod(X.xp.array(X.shape)[_axis]) # Takes up a lot of GPU memory
+        N = np.prod(np.array(X.shape)[_axis])
+

         dX_hat = weight_data * grad
         dstddev_inv = (

@@ -59,26 +62,26 @@ def grad_fn(X: Tensor, weight: Tensor, bias: Tensor, X_centered, stddev_inv, axi
             * X.xp.sum(dX_hat * X_centered, axis=axis, keepdims=True)
         )
         dvar = (
-            X.xp.ones_like(X.data) * dstddev_inv * 2 * X_centered / weight_size
+            X.xp.ones_like(X.data) * dstddev_inv * 2 * X_centered / N
         ) # X.xp.prod(X.xp.array(X.shape)[_axis])
         dmean = (
             X.xp.ones_like(X.data)
             * X.xp.sum(dX_hat * stddev_inv, axis=axis, keepdims=True)
             * (-1)
-            / weight_size
+            / N
         ) # X.xp.prod(X.xp.array(X.shape)[_axis])
         grad_X = dX_hat * stddev_inv + dvar + dmean

-        # grad_X = (1 / weight_size) * weight_data * stddev_inv * (
-        #     weight_size * grad
+        # grad_X = (1 / N) * weight_data * stddev_inv * (
+        #     N * grad
         #     - X.xp.sum(grad, axis = axis, keepdims = True)
         #     - X_centered * X.xp.power(stddev_inv, 2) * X.xp.sum(grad * X_centered, axis = axis, keepdims = True)
         # )

         # dX_hat = weight_data * grad
-        # dvar = X.xp.sum(dX_hat * X_centered, axis = axis, keepdims = True) * (-0.5) * X.xp.power(stddev_inv, 3) * 2 * X_centered / weight_size
-        # dmean = (X.xp.sum(dX_hat * (-stddev_inv), axis = axis, keepdims = True) + dvar * X.xp.mean(-2.0 * X_centered, axis = axis, keepdims = True)) * X.xp.ones_like(X.data) / weight_size
-        # grad_X = dX_hat * stddev_inv + dvar + dmean
+        # dvar = X.xp.sum(dX_hat * X_centered, axis = axis, keepdims = True) * (-0.5) * X.xp.power(stddev_inv, 3)
+        # dmean = (X.xp.sum(dX_hat * (-stddev_inv), axis = axis, keepdims = True) + dvar * X.xp.mean(-2.0 * X_centered, axis = axis, keepdims = True)) * X.xp.ones_like(X.data) / N
+        # grad_X = dX_hat * stddev_inv + dvar * 2 * X_centered / N + dmean / N

         if elementwise_affine:
             grad_weight = X.xp.sum(grad * X_hat, axis=0)
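
The substantive fix here is the normalization constant: N is now the element count over the normalized axes, computed with plain NumPy, rather than weight.size, which only equaled that count when elementwise_affine was enabled (the non-affine branch fell back to 1). A hedged NumPy check of the dvar/dmean path, assuming the hidden factor of dstddev_inv is -0.5 * stddev_inv**3 as the surrounding code suggests; names mirror grad_fn but this is not the repository's test code:

    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.standard_normal((4, 5))
    grad = rng.standard_normal(X.shape)
    weight = rng.standard_normal(5)        # per-feature affine weight
    axis, eps = (-1,), 1e-5
    _axis = list(axis)
    N = np.prod(np.array(X.shape)[_axis])  # element count over the normalized axes

    X_centered = X - X.mean(axis=axis, keepdims=True)
    stddev_inv = 1.0 / np.sqrt(X.var(axis=axis, keepdims=True) + eps)
    X_hat = X_centered * stddev_inv

    # Three-term path, dividing by N as in the fixed grad_fn
    dX_hat = weight * grad
    dstddev_inv = -0.5 * stddev_inv**3 * np.sum(dX_hat * X_centered, axis=axis, keepdims=True)
    dvar = np.ones_like(X) * dstddev_inv * 2 * X_centered / N
    dmean = np.ones_like(X) * np.sum(dX_hat * stddev_inv, axis=axis, keepdims=True) * (-1) / N
    grad_X = dX_hat * stddev_inv + dvar + dmean

    # Standard compact LayerNorm input gradient, for comparison
    expected = stddev_inv * (
        dX_hat
        - np.mean(dX_hat, axis=axis, keepdims=True)
        - X_hat * np.mean(dX_hat * X_hat, axis=axis, keepdims=True)
    )
    assert np.allclose(grad_X, expected)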
