|
1 | | -# # Quantum GAN |
2 | | -using Yao, YaoExtensions |
3 | | -using Yao.ConstGate: P0 |
4 | | -import QuAlgorithmZoo |
5 | | -using Test, Random |
| 1 | +using Yao |
| 2 | +using YaoExtensions: variational_circuit, Sequence, faithful_grad, numdiff |
| 3 | +using QuAlgorithmZoo: Adam, update! |
| 4 | +import Yao: tracedist |
6 | 5 |
|
7 | | -include("QuGANlib.jl") |
| 6 | +""" |
| 7 | +Quantum GAN. |
8 | 8 |
|
9 | | -# ## DATA: Target Wave Function |
10 | | -# here we learn a 3 qubit state |
11 | | -nbit = 3 |
12 | | -target_state = rand_state(nbit) |
| 9 | +Reference: |
| 10 | + Benedetti, M., Grant, E., Wossnig, L., & Severini, S. (2018). Adversarial quantum circuit learning for pure state approximation, 1–14. |
| 11 | +""" |
| 12 | +struct QuGAN{N} |
| 13 | + target::ArrayReg |
| 14 | + generator::AbstractBlock{N} |
| 15 | + discriminator::AbstractBlock |
| 16 | + reg0::ArrayReg |
| 17 | + witness_op::AbstractBlock |
| 18 | + circuit::AbstractBlock |
| 19 | + |
| 20 | + function QuGAN(target::ArrayReg, gen::AbstractBlock, dis::AbstractBlock) |
| 21 | + N = nqubits(target) |
| 22 | + c = Sequence([gen, addbits!(1), dis]) |
| 23 | + witness_op = put(N+1, (N+1)=>ConstGate.P0) |
| 24 | + new{N}(target, gen, dis, zero_state(N), witness_op, c) |
| 25 | + end |
| 26 | +end |
| 27 | + |
| 28 | +# INTERFACES |
| 29 | +circuit(qg::QuGAN) = qg.circuit |
| 30 | +loss(qg::QuGAN) = p0t(qg) - p0g(qg) |
| 31 | + |
| 32 | +function gradient(qg::QuGAN) |
| 33 | + grad_gen = faithful_grad(qg.witness_op, qg.reg0 => qg.circuit) |
| 34 | + grad_tar = faithful_grad(qg.witness_op, qg.target => qg.circuit[2:end]) |
| 35 | + ngen = nparameters(qg.generator) |
| 36 | + [-grad_gen[1:ngen]; grad_tar - grad_gen[ngen+1:end]] |
| 37 | +end |
| 38 | + |
| 39 | +"""probability to get evidense qubit 0 on generation set.""" |
| 40 | +p0g(qg::QuGAN) = expect(qg.witness_op, qg.reg0 => qg.circuit) |> real |
| 41 | +"""probability to get evidense qubit 0 on target set.""" |
| 42 | +p0t(qg::QuGAN) = expect(qg.witness_op, qg.target => qg.circuit[2:end]) |> real |
| 43 | +"""generated wave function""" |
| 44 | +outputψ(qg::QuGAN) = copy(qg.reg0) |> qg.generator |
13 | 45 |
|
14 | | -# ## MODEL: Quantum Circuit and Loss |
15 | | -# using a 4-layer random differential circuit for both generator and discriminator |
16 | | -# we build the qcgan setup. |
| 46 | +"""tracedistance between target and generated wave function""" |
| 47 | +tracedist(qg::QuGAN) = tracedist(qg.target, outputψ(qg))[] |
| 48 | + |
| 49 | +using Test, Random |
| 50 | +Random.seed!(2) |
| 51 | + |
| 52 | +nbit = 3 |
17 | 53 | depth_gen = 4 |
18 | | -generator = dispatch!(variational_circuit(nbit, depth_gen, pair_ring(nbit)), :random) |> autodiff(:QC); |
19 | | - |
20 | | -#------------------------------ |
21 | | -depth_disc = 4 |
22 | | -discriminator = dispatch!(variational_circuit(nbit+1, depth_disc, pair_ring(nbit+1)), :random) |> autodiff(:QC) |
23 | | -qg = QuGAN(target_state, generator, discriminator); |
24 | | - |
25 | | -# ## TRAINING: Gradient Descent |
26 | | -# using a proper learning parameters, we perform 1000 steps of training |
27 | | -g_learning_rate=0.2 |
28 | | -d_learning_rate=0.5 |
29 | | -niter=1000 |
30 | | -for info in QuGANGo!(qg, g_learning_rate, d_learning_rate, niter) |
31 | | - i = info["step"] |
32 | | - (i*20)%niter==0 && println("Step = $i, Trace Distance = $(tracedist(qg)), loss = $(qg |> loss)") |
| 54 | +depth_dis = 4 |
| 55 | + |
| 56 | +# define a QuGAN |
| 57 | +target = rand_state(nbit) |
| 58 | +generator = dispatch!(variational_circuit(nbit, depth_gen), :random) |
| 59 | +discriminator = dispatch!(variational_circuit(nbit+1, depth_dis), :random) |
| 60 | +qg = QuGAN(target, generator, discriminator) |
| 61 | + |
| 62 | +# check the gradient |
| 63 | +grad = gradient(qg) |
| 64 | +numgrad = numdiff(c->loss(qg), qg.circuit) |
| 65 | +@test isapprox(grad, numgrad, atol=1e-4) |
| 66 | + |
| 67 | +# learning rates for the generator and discriminator |
| 68 | +g_lr = 0.2 |
| 69 | +d_lr = 0.5 |
| 70 | +for i=1:300 |
| 71 | + ng = nparameters(qg.generator) |
| 72 | + grad = gradient(qg) |
| 73 | + dispatch!(-, qg.generator, grad[1:ng]*g_lr) |
| 74 | + dispatch!(-, qg.discriminator, -grad[ng+1:end]*d_lr) |
| 75 | + println("Step $i, trace distance = $(tracedist(qg))") |
33 | 76 | end |
| 77 | + |
| 78 | +@test qg |> loss < 0.1 |
0 commit comments