Skip to content

Commit 8f1bcd4

Browse files
authored
new, port zygote (#14)
* new, port zygote * switch to LBFGS
1 parent 343fdcc commit 8f1bcd4

File tree

3 files changed

+90
-0
lines changed

3 files changed

+90
-0
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
using YaoExtensions, Yao
2+
using Test, Random
3+
using Optim: LBFGS, optimize
4+
5+
# port the `Matrix` function to Yao's AD.
6+
include("zygote_patch.jl")
7+
8+
function loss(u, ansatz)
9+
m = Matrix(ansatz)
10+
sum(abs.(u .- m))
11+
end
12+
13+
"""
14+
learn_u4(u::AbstractMatrix; niter=100)
15+
16+
Learn a general U4 gate. The optimizer is LBFGS.
17+
"""
18+
function learn_u4(u::AbstractMatrix; niter=100)
19+
ansatz = general_U4() * put(2, 1=>phase(0.0)) # initial values are 0, here, we attach a global phase.
20+
params = parameters(ansatz)
21+
g!(G, x) = (dispatch!(ansatz, x); G .= gradient(ansatz->loss(u, ansatz), ansatz)[1])
22+
optimize(x->(dispatch!(ansatz, x); loss(u, ansatz)), g!, parameters(ansatz),
23+
LBFGS(), Optim.Options(iterations=niter))
24+
println("final loss = $(loss(u,ansatz))")
25+
return ansatz
26+
end
27+
28+
using Random
29+
Random.seed!(2)
30+
u = rand_unitary(4)
31+
c = learn_u4(u)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
include("zygote_patch.jl")
2+
3+
import YaoExtensions, Random
4+
5+
c = YaoExtensions.variational_circuit(5)
6+
dispatch!(c, :random)
7+
8+
function loss(reg::AbstractRegister, circuit::AbstractBlock{N}) where N
9+
#copy(reg) |> circuit
10+
reg = apply!(copy(reg), circuit)
11+
st = state(reg)
12+
sum(real(st.*st))
13+
end
14+
15+
reg0 = zero_state(5)
16+
paramsδ = gradient(c->loss(reg0, c), c)[1]
17+
regδ = gradient(reg->loss(reg, c), reg0)[1]
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
using Zygote
2+
using Zygote: @adjoint
3+
using Yao, Yao.AD
4+
5+
@adjoint function apply!(reg::ArrayReg, block::AbstractBlock)
6+
out = apply!(reg, block)
7+
out, function (outδ)
8+
(in, inδ), paramsδ = apply_back((out, outδ), block)
9+
return (inδ, paramsδ)
10+
end
11+
end
12+
13+
@adjoint function Matrix(block::AbstractBlock)
14+
out = Matrix(block)
15+
out, function (outδ)
16+
paramsδ = mat_back(block, outδ)
17+
return (paramsδ,)
18+
end
19+
end
20+
21+
@adjoint function ArrayReg{B}(raw::AbstractArray) where B
22+
ArrayReg{B}(raw), adjy->(reshape(adjy.state, size(raw)),)
23+
end
24+
25+
@adjoint function ArrayReg{B}(raw::ArrayReg) where B
26+
ArrayReg{B}(raw), adjy->(adjy,)
27+
end
28+
29+
@adjoint function ArrayReg(raw::AbstractArray)
30+
ArrayReg(raw), adjy->(reshape(adjy.state, size(raw)),)
31+
end
32+
33+
@adjoint function copy(reg::ArrayReg) where B
34+
copy(reg), adjy->(adjy,)
35+
end
36+
37+
@adjoint state(reg::ArrayReg) = state(reg), adjy->(ArrayReg(adjy),)
38+
@adjoint statevec(reg::ArrayReg) = statevec(reg), adjy->(ArrayReg(adjy),)
39+
@adjoint state(reg::AdjointArrayReg) = state(reg), adjy->(ArrayReg(adjy')',)
40+
@adjoint statevec(reg::AdjointArrayReg) = statevec(reg), adjy->(ArrayReg(adjy')',)
41+
@adjoint parent(reg::AdjointArrayReg) = parent(reg), adjy->(adjy',)
42+
@adjoint Base.adjoint(reg::ArrayReg) = Base.adjoint(reg), adjy->(parent(adjy),)

0 commit comments

Comments
 (0)