Skip to content

Commit 0734554

Browse files
committed
new, port zygote
1 parent 343fdcc commit 0734554

File tree

3 files changed

+90
-0
lines changed

3 files changed

+90
-0
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
using YaoExtensions, Yao
2+
using Test, Random
3+
using QuAlgorithmZoo: Adam, update!
4+
5+
include("zygote_patch.jl")
6+
7+
function loss(u, ansatz)
8+
m = Matrix(ansatz)
9+
sum(abs.(u .- m))
10+
end
11+
12+
function learn_su4(u::AbstractMatrix; optimizer=Adam(lr=0.1), niter=100)
13+
ansatz = general_U4() * put(2, 1=>phase(0.0)) # initial values are 0, here, we attach a global phase.
14+
params = parameters(ansatz)
15+
for i=1:1000
16+
println("Step = $i, loss = $(loss(u,ansatz))")
17+
grad = gradient(ansatz->loss(u, ansatz), ansatz)[1]
18+
update!(params, grad, optimizer)
19+
dispatch!(ansatz, params)
20+
end
21+
return ansatz
22+
end
23+
24+
using Random
25+
Random.seed!(2)
26+
u = rand_unitary(4)
27+
using LinearAlgebra
28+
#u[:,1] .*= -conj(det(u))
29+
#@show det(u)
30+
c = learn_su4(u; optimizer=Adam(lr=0.005))
31+
det(mat(c))
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
include("zygote_patch.jl")
2+
3+
import YaoExtensions, Random
4+
5+
c = YaoExtensions.variational_circuit(5)
6+
dispatch!(c, :random)
7+
8+
function loss(reg::AbstractRegister, circuit::AbstractBlock{N}) where N
9+
#copy(reg) |> circuit
10+
reg = apply!(copy(reg), circuit)
11+
st = state(reg)
12+
sum(real(st.*st))
13+
end
14+
15+
reg0 = zero_state(5)
16+
paramsδ = gradient(c->loss(reg0, c), c)[1]
17+
regδ = gradient(reg->loss(reg, c), reg0)[1]
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
using Zygote
2+
using Zygote: @adjoint
3+
using Yao, Yao.AD
4+
5+
@adjoint function apply!(reg::ArrayReg, block::AbstractBlock)
6+
out = apply!(reg, block)
7+
out, function (outδ)
8+
(in, inδ), paramsδ = apply_back((out, outδ), block)
9+
return (inδ, paramsδ)
10+
end
11+
end
12+
13+
@adjoint function Matrix(block::AbstractBlock)
14+
out = Matrix(block)
15+
out, function (outδ)
16+
paramsδ = mat_back(block, outδ)
17+
return (paramsδ,)
18+
end
19+
end
20+
21+
@adjoint function ArrayReg{B}(raw::AbstractArray) where B
22+
ArrayReg{B}(raw), adjy->(reshape(adjy.state, size(raw)),)
23+
end
24+
25+
@adjoint function ArrayReg{B}(raw::ArrayReg) where B
26+
ArrayReg{B}(raw), adjy->(adjy,)
27+
end
28+
29+
@adjoint function ArrayReg(raw::AbstractArray)
30+
ArrayReg(raw), adjy->(reshape(adjy.state, size(raw)),)
31+
end
32+
33+
@adjoint function copy(reg::ArrayReg) where B
34+
copy(reg), adjy->(adjy,)
35+
end
36+
37+
@adjoint state(reg::ArrayReg) = state(reg), adjy->(ArrayReg(adjy),)
38+
@adjoint statevec(reg::ArrayReg) = statevec(reg), adjy->(ArrayReg(adjy),)
39+
@adjoint state(reg::AdjointArrayReg) = state(reg), adjy->(ArrayReg(adjy')',)
40+
@adjoint statevec(reg::AdjointArrayReg) = statevec(reg), adjy->(ArrayReg(adjy')',)
41+
@adjoint parent(reg::AdjointArrayReg) = parent(reg), adjy->(adjy',)
42+
@adjoint Base.adjoint(reg::ArrayReg) = Base.adjoint(reg), adjy->(parent(adjy),)

0 commit comments

Comments
 (0)