Skip to content

Commit 1fd57f7

Browse files
add keras3 example
1 parent d8c5d62 commit 1fd57f7

File tree

2 files changed

+97
-0
lines changed

2 files changed

+97
-0
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
- Add benchmark example showcasing new way of implementing matrix product using vmap
1212

13+
- Add keras3 example showcasing integration with tc
14+
1315
## 0.10.0
1416

1517
### Added

examples/keras3_tc_integration.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""
2+
keras3 is excellent to use together with tc, we will have unique features including:
3+
1. turn OO paradigm to functional paradigm, i.e. reuse keras layer function in functional programming
4+
2. batch on neural network weights
5+
"""
6+
7+
import os
8+
9+
os.environ["KERAS_BACKEND"] = "jax"
10+
import keras_core as keras
11+
import numpy as np
12+
import optax
13+
import tensorcircuit as tc
14+
15+
K = tc.set_backend("jax")
16+
17+
batch = 8
18+
n = 6
19+
layer = keras.layers.Dense(1, activation="sigmoid")
20+
layer.build([batch, n])
21+
22+
data_x = np.random.choice([0, 1], size=batch * n).reshape([batch, n])
23+
# data_y = np.sum(data_x, axis=-1) % 2
24+
data_y = data_x[:, 0]
25+
data_y = data_y.reshape([batch, 1])
26+
data_x = data_x.astype(np.float32)
27+
data_y = data_y.astype(np.float32)
28+
29+
30+
print("data", data_x, data_y)
31+
32+
33+
def loss(xs, ys, params, weights):
34+
c = tc.Circuit(n)
35+
c.rx(range(n), theta=xs)
36+
c.cx(range(n - 1), range(1, n))
37+
c.rz(range(n), theta=params)
38+
outputs = K.stack([K.real(c.expectation_ps(z=[i])) for i in range(n)])
39+
ypred, _ = layer.stateless_call(weights, [], outputs)
40+
return keras.losses.binary_crossentropy(ypred, ys), ypred
41+
42+
43+
# common data batch practice
44+
vgf = K.jit(
45+
K.vectorized_value_and_grad(
46+
loss, argnums=(2, 3), vectorized_argnums=(0, 1), has_aux=True
47+
)
48+
)
49+
50+
params = K.implicit_randn(shape=[n])
51+
w = K.implicit_randn(shape=[n, 1])
52+
b = K.implicit_randn(shape=[1])
53+
opt = K.optimizer(optax.adam(1e-2))
54+
# seems that currently keras3'optimizer doesn't support nested list of variables
55+
56+
for i in range(100):
57+
(v, yp), gs = vgf(data_x, data_y, params, [w, b])
58+
params, [w, b] = opt.update(gs, (params, [w, b]))
59+
if i % 10 == 0:
60+
print(K.mean(v))
61+
62+
m = keras.metrics.BinaryAccuracy()
63+
m.update_state(data_y, yp[:, None])
64+
print("acc", m.result())
65+
66+
67+
# data batch with batched and quantum neural weights
68+
69+
vgf2 = K.jit(
70+
K.vmap(
71+
K.vectorized_value_and_grad(
72+
loss, argnums=(2, 3), vectorized_argnums=(0, 1), has_aux=True
73+
),
74+
vectorized_argnums=(2, 3),
75+
)
76+
)
77+
78+
wbatch = 4
79+
params = K.implicit_randn(shape=[wbatch, n])
80+
w = K.implicit_randn(shape=[wbatch, n, 1])
81+
b = K.implicit_randn(shape=[wbatch, 1])
82+
opt = K.optimizer(optax.adam(1e-2))
83+
# seems that currently keras3'optimizer doesn't support nested list of variables
84+
85+
for i in range(100):
86+
(v, yp), gs = vgf2(data_x, data_y, params, [w, b])
87+
params, [w, b] = opt.update(gs, (params, [w, b]))
88+
if i % 10 == 0:
89+
print(K.mean(v, axis=-1))
90+
91+
for i in range(wbatch):
92+
m = keras.metrics.BinaryAccuracy()
93+
m.update_state(data_y, yp[0, :, None])
94+
print("acc", m.result())
95+
m.reset_state()

0 commit comments

Comments
 (0)