Skip to content

Commit 3090498

Browse files
update nested_vmap example
1 parent 1ad13b1 commit 3090498

File tree

1 file changed

+34
-16
lines changed

1 file changed

+34
-16
lines changed

examples/nested_vmap_grad.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
with tc.runtime_backend(backend) as K:
1111
L = 2
1212
inputs = K.cast(K.ones([3, 2]), tc.rdtypestr)
13-
weights = K.cast(K.ones([2]), tc.rdtypestr)
13+
weights = K.cast(K.ones([3, 2]), tc.rdtypestr)
1414

1515
def ansatz(thetas, alpha):
1616
c = tc.Circuit(L)
@@ -27,24 +27,42 @@ def f(thetas, alpha):
2727
observables = K.stack([K.real(c.expectation_ps(z=[i])) for i in range(L)])
2828
return K.mean(observables)
2929

30-
# f_vmap = K.vmap(f, vectorized_argnums=0)
31-
32-
print("grad", K.grad(f)(inputs[0], weights))
33-
print("vmap", K.vmap(f)(inputs, weights))
34-
print("vmap over grad", K.vmap(K.grad(f))(inputs, weights))
30+
print("grad_0", K.grad(f)(inputs[0], weights[0]))
31+
print("grad_1", K.grad(f, argnums=1)(inputs[0], weights[0]))
32+
print("vmap_0", K.vmap(f)(inputs, weights[0]))
33+
print("vmap_1", K.vmap(f, vectorized_argnums=1)(inputs[0], weights))
34+
print("vmap over grad_0", K.vmap(K.grad(f))(inputs, weights[0]))
3535
# wrong in tf due to https://github.com/google/TensorNetwork/issues/940
3636
# https://github.com/tensorflow/tensorflow/issues/52148
37-
print("vmap over jacfwd", K.vmap(K.jacfwd(f))(inputs, weights))
38-
print("jacfwd over vmap", K.jacfwd(K.vmap(f))(inputs, weights))
39-
r = K.vmap(K.jacrev(f))(inputs, weights)
40-
print("vmap over jacrev", r)
37+
print("vmap over grad_1", K.vmap(K.grad(f, argnums=1))(inputs, weights[0]))
38+
# wrong in tf
39+
print("vmap over jacfwd_0", K.vmap(K.jacfwd(f))(inputs, weights[0]))
40+
print("jacfwd_0 over vmap", K.jacfwd(K.vmap(f))(inputs, weights[0]))
41+
print("vmap over jacfwd_1", K.vmap(K.jacfwd(f, argnums=1))(inputs, weights[0]))
42+
print("jacfwd_1 over vmap", K.jacfwd(K.vmap(f), argnums=1)(inputs, weights[0]))
43+
r = K.vmap(K.jacrev(f))(inputs, weights[0])
44+
print("vmap over jacrev0", r)
45+
# wrong in tf
46+
r = K.jacrev(K.vmap(f))(inputs, weights[0])
47+
print("jacrev0 over vmap", r)
48+
r = K.vmap(K.jacrev(f, argnums=1))(inputs, weights[0])
49+
print("vmap over jacrev1", r)
50+
# wrong in tf
51+
r = K.jacrev(K.vmap(f), argnums=1)(inputs, weights[0])
52+
print("jacrev1 over vmap", r)
53+
r = K.vmap(K.jacrev(f, argnums=1), vectorized_argnums=1)(inputs[0], weights)
54+
print("vmap1 over jacrev1", r)
55+
r = K.jacrev(K.vmap(f, vectorized_argnums=1), argnums=1)(inputs[0], weights)
56+
print("jacrev1 over vmap1", r)
57+
r = K.vmap(K.hessian(f))(inputs, weights[0])
58+
print("vmap over hess0", r)
4159
# wrong in tf
42-
r = K.jacrev(K.vmap(f))(inputs, weights)
43-
print("jacrev over vmap", r)
44-
r = K.vmap(K.hessian(f))(inputs, weights)
45-
print("vmap over hess", r)
60+
r = K.hessian(K.vmap(f))(inputs, weights[0])
61+
print("hess0 over vmap", r)
62+
r = K.vmap(K.hessian(f, argnums=1))(inputs, weights[0])
63+
print("vmap over hess1", r)
4664
# wrong in tf
47-
r = K.hessian(K.vmap(f))(inputs, weights)
48-
print("hess over vmap", r)
65+
r = K.hessian(K.vmap(f), argnums=1)(inputs, weights[0])
66+
print("hess1 over vmap", r)
4967

5068
# lessons: never put vmap outside gradient function in tf

0 commit comments

Comments
 (0)