Skip to content

Commit 0350a17

Browse files
committed
update
1 parent 0806a5f commit 0350a17

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

src/belief.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ end
4343

4444
function collect_message!(bp::BeliefPropgation, state::BPState)
4545
for it in 1:num_tensors(bp)
46-
_collect_message!(vectors_on_tensor(state.message_out, bp, it), bp.tensors[it], vectors_on_tensor(state.message_in, bp, it))
46+
_collect_message!(vectors_on_tensor(state.message_in, bp, it), bp.tensors[it], vectors_on_tensor(state.message_out, bp, it))
4747
end
4848
end
4949
# collect the vectors associated with the target tensor
@@ -94,13 +94,15 @@ struct BPInfo
9494
iterations::Int
9595
end
9696
function belief_propagate!(bp::BeliefPropgation, state::BPState{T}; max_iter::Int=100, tol::Float64=1e-6) where T
97+
pre_message_in = deepcopy(state.message_in)
9798
for i in 1:max_iter
98-
process_message!(state)
9999
collect_message!(bp, state)
100+
process_message!(state)
100101
# check convergence
101-
if all(iv -> all(it -> isapprox(state.message_out[iv][it], state.message_in[iv][it], atol=tol), 1:length(bp.v2t[iv])), 1:num_variables(bp))
102+
if all(iv -> all(it -> isapprox(state.message_in[iv][it], pre_message_in[iv][it], atol=tol), 1:length(bp.v2t[iv])), 1:num_variables(bp))
102103
return BPInfo(true, i)
103104
end
105+
pre_message_in = deepcopy(state.message_in)
104106
end
105107
return BPInfo(false, max_iter)
106108
end

test/belief.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,10 @@ end
4343
bp = BeliefPropgation(mps_uai)
4444
@test TensorInference.initial_state(bp) isa TensorInference.BPState
4545
state, info = belief_propagate(bp)
46-
@show TensorInference.contraction_results(state)
4746
@test info.converged
47+
@test info.iterations < 10
48+
contraction_res = TensorInference.contraction_results(state)
4849
tnet = TensorNetworkModel(mps_uai)
49-
@show probability(tnet)[]
50+
expected_result = probability(tnet)[]
51+
@test all(r -> isapprox(r, expected_result), contraction_res)
5052
end

0 commit comments

Comments
 (0)