Skip to content

Commit b14253f

Browse files
committed
implement marginals
1 parent 0350a17 commit b14253f

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

src/belief.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,4 +110,8 @@ end
110110
# if BP is exact and converged (e.g. tree like), the result should be the same as the tensor network contraction
111111
function contraction_results(state::BPState{T}) where T
112112
return [sum(reduce((x, y) -> x .* y, mi)) for mi in state.message_in]
113+
end
114+
115+
function marginals(state::BPState{T}) where T
116+
return Dict([v] => normalize!(reduce((x, y) -> x .* y, mi), 1) for (v, mi) in enumerate(state.message_in))
113117
end

test/belief.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,9 @@ end
4949
tnet = TensorNetworkModel(mps_uai)
5050
expected_result = probability(tnet)[]
5151
@test all(r -> isapprox(r, expected_result), contraction_res)
52+
mars = marginals(state)
53+
mars_tnet = marginals(tnet)
54+
for v in 1:TensorInference.num_variables(bp)
55+
@test mars[[v]] mars_tnet[[v]]
56+
end
5257
end

0 commit comments

Comments
 (0)