From c8a60f3ae23232700ee2e834688c931c35b747fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=91=A8=E5=94=A4=E6=B5=B7?= Date: Wed, 2 Jul 2025 10:15:28 +0800 Subject: [PATCH 1/2] Add conj with cost_and_gradient MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 周唤海 --- src/belief.jl | 2 +- src/mar.jl | 1 + test/belief.jl | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/belief.jl b/src/belief.jl index eede996..be23506 100644 --- a/src/belief.jl +++ b/src/belief.jl @@ -80,7 +80,7 @@ function _collect_message!(vectors_out::Vector, t::AbstractArray, vectors_in::Ve # TODO: speed up if needed! code = star_code(length(vectors_in)) cost, gradient = cost_and_gradient(code, (t, vectors_in...)) - for (o, g) in zip(vectors_out, gradient[2:end]) + for (o, g) in zip(vectors_out, conj.(gradient[2:end])) o .= g end return cost[] diff --git a/src/mar.jl b/src/mar.jl index 3e399b4..0eaa154 100644 --- a/src/mar.jl +++ b/src/mar.jl @@ -78,6 +78,7 @@ probabilities of the queried variables, represented by tensors. function marginals(tn::TensorNetworkModel; usecuda = false, rescale = true)::Dict{Vector{Int}} # sometimes, the cost can overflow, then we need to rescale the tensors during contraction. cost, grads = cost_and_gradient(tn.code, (adapt_tensors(tn; usecuda, rescale)...,)) + grads = conj.(grads) @debug "cost = $cost" ixs = OMEinsum.getixsv(tn.code) queryvars = ixs[tn.unity_tensors_idx] diff --git a/test/belief.jl b/test/belief.jl index 150c302..ca20fdb 100644 --- a/test/belief.jl +++ b/test/belief.jl @@ -46,7 +46,7 @@ end @testset "belief propagation" begin n = 5 chi = 3 - mps_uai = TensorInference.random_tensor_train_uai(Float64, n, chi) + mps_uai = TensorInference.random_tensor_train_uai(ComplexF64, n, chi) bp = BeliefPropgation(mps_uai) @test TensorInference.initial_state(bp) isa TensorInference.BPState state, info = belief_propagate(bp) @@ -63,7 +63,7 @@ end @testset "belief propagation on circle" begin n = 10 chi = 3 - mps_uai = TensorInference.random_tensor_train_uai(Float64, n, chi; periodic=true) + mps_uai = TensorInference.random_tensor_train_uai(ComplexF64, n, chi; periodic=true) # FIXME: fail to converge bp = BeliefPropgation(mps_uai) @test TensorInference.initial_state(bp) isa TensorInference.BPState state, info = belief_propagate(bp; max_iter=100, tol=1e-6) From 04b72156abaa01b99ca8570872ffac7de2d6254c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=91=A8=E5=94=A4=E6=B5=B7?= Date: Wed, 2 Jul 2025 10:55:45 +0800 Subject: [PATCH 2/2] Fix test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 周唤海 --- test/belief.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/belief.jl b/test/belief.jl index ca20fdb..96ad974 100644 --- a/test/belief.jl +++ b/test/belief.jl @@ -74,7 +74,9 @@ end mars = marginals(state) mars_tnet = marginals(tnet) for v in 1:TensorInference.num_variables(bp) - @test mars[[v]] ≈ mars_tnet[[v]] atol=1e-4 + gauge = mars[[v]] ./ mars_tnet[[v]] + @test all(gauge .≈ gauge[1]) + @test mars[[v]] ≈ gauge[1] .* mars_tnet[[v]] atol=1e-4 end end