Skip to content

Commit 3c877db

Browse files
committed
First version of BP.
#Many todos # To add, evidence / interventions
1 parent 0761c95 commit 3c877db

File tree

2 files changed

+19
-66
lines changed

2 files changed

+19
-66
lines changed

examples/utilization/1_pgm/2_concept_bottleneck_model_bp/2_concept_bottleneck_model_bp.py

Lines changed: 1 addition & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def main():
2727
torch.nn.Sigmoid()))
2828
b_cpd = ParametricCPD("b",
2929
parametrization=torch.nn.Sequential(torch.nn.Linear(emb_size, b.size),
30-
torch.nn.Sigmoid()))
30+
torch.nn.Softmax(dim=-1)))
3131
c_cpd = ParametricCPD("c",
3232
parametrization=torch.nn.Sequential(torch.nn.Linear(a.size + b.size, c.size),
3333
torch.nn.Sigmoid()))
@@ -48,58 +48,5 @@ def main():
4848
print(results)
4949
exit()
5050

51-
print("Genotype Predictions (first 5 samples):")
52-
print(results[:, 0][:5])
53-
print("Smoking Predictions (first 5 samples):")
54-
print(results[:, 1][:5])
55-
print("Tar Predictions (first 5 samples):")
56-
print(results[:, 2][:5])
57-
print("Cancer Predictions (first 5 samples):")
58-
print(results[:, 3][:5])
59-
60-
# Original predictions (observational)
61-
original_results = inference_engine.query(
62-
query_concepts=["genotype", "smoking", "tar", "cancer"],
63-
evidence=initial_input
64-
)
65-
66-
# Intervention: Force smoking to 0 (prevent smoking)
67-
smoking_strategy_0 = DoIntervention(
68-
model=concept_model.parametric_cpds,
69-
constants=0.0
70-
)
71-
with intervention(
72-
policies=UniformPolicy(out_features=1),
73-
strategies=smoking_strategy_0,
74-
target_concepts=["smoking"]
75-
):
76-
intervened_results = inference_engine.query(
77-
query_concepts=["genotype", "smoking", "tar", "cancer"],
78-
evidence=initial_input
79-
)
80-
cancer_do_smoking_0 = intervened_results[:, 3]
81-
82-
# Intervention: Force smoking to 1 (promote smoking)
83-
smoking_strategy_1 = DoIntervention(
84-
model=concept_model.parametric_cpds,
85-
constants=1.0
86-
)
87-
with intervention(
88-
policies=UniformPolicy(out_features=1),
89-
strategies=smoking_strategy_1,
90-
target_concepts=["smoking"]
91-
):
92-
intervened_results = inference_engine.query(
93-
query_concepts=["genotype", "smoking", "tar", "cancer"],
94-
evidence=initial_input
95-
)
96-
cancer_do_smoking_1 = intervened_results[:, 3]
97-
98-
ace_cancer_do_smoking = cace_score(cancer_do_smoking_0, cancer_do_smoking_1)
99-
print(f"ACE of smoking on cancer: {ace_cancer_do_smoking:.3f}")
100-
101-
return
102-
103-
10451
if __name__ == "__main__":
10552
main()

examples/utilization/1_pgm/2_concept_bottleneck_model_bp/bp_with_conditional.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -484,9 +484,10 @@ def compute_exact_marginals_bruteforce(
484484

485485
class BPInference(BaseInference):
486486

487-
def __init__(self, model):
487+
def __init__(self, model, iters = 5):
488488
super().__init__()
489489
self.model : ProbabilisticModel = model
490+
self.iters = iters
490491

491492

492493
variables = {}
@@ -565,7 +566,7 @@ def query(self, query, evidence):
565566
factor_eval_list.append(factor_eval)
566567
continue
567568
else:
568-
for i, p in enumerate(cpd.parents):
569+
for i, p in enumerate(cpd.variable.parents):
569570

570571
if p.distribution is Delta:
571572
emb = embeddings_dict[p.concepts[0]] # [B, emb_dim]
@@ -595,7 +596,7 @@ def query(self, query, evidence):
595596

596597
# turn into bidimentional tensor: [B * num_assignments, input_dim]
597598
input = input.view(batch_size * num_assignments, -1)
598-
evaluation = cpd.parameterization(input)
599+
evaluation = cpd.parametrization(input)
599600

600601
# reshape back to [B, num_assignments, output_dim]
601602
evaluation = evaluation.view(batch_size, num_assignments, -1)
@@ -604,40 +605,45 @@ def query(self, query, evidence):
604605
# TODO: We need to turn them into factor evaluations. In each factor, the target variable of the CPD is the first variable in the scope so we can do a simple reshape
605606
# TODO: check that this is the case
606607

607-
if cpd.distribution is RelaxedOneHotCategorical:
608+
if cpd.variable.distribution is RelaxedOneHotCategorical:
608609
#TODO: Check that it is concatenating the third dimension into the num_assignments dimension
609-
factor_eval = evaluation.view(batch_size, -1)
610610

611-
elif cpd.distribution is RelaxedBernoulli:
611+
# this is the tensorial equivalent to torch.cat([evaluation[:, :, i] for i in range(evaluation.shape[2])], dim=1)
612+
factor_eval = evaluation.permute(0, 2, 1).reshape(batch_size, -1)
613+
614+
elif cpd.variable.distribution is RelaxedBernoulli:
612615
# Bernoulli output: need to create a factor eval of size 2
613616
prob_1 = evaluation.view(batch_size, -1)
614617
prob_0 = 1.0 - prob_1
615618
factor_eval = torch.cat([prob_0, prob_1], dim=1)
616-
elif cpd.distribution is Delta:
619+
elif cpd.variable.distribution is Delta:
617620
factor_eval = torch.ones([batch_size,1], device=evaluation.device)
618621
else:
619622
raise NotImplementedError("Unknown CPD distribution in CPD2FactorWrapper.")
620623

621624
factor_eval_list.append(factor_eval)
622625

626+
B = batch_size
627+
S = self.metadata["total_edge_states"]
628+
E = self.metadata["E"]
623629
messages_f2v_init = torch.rand(B, S)
624630

625-
edge_id = md["edge_id_per_state"] # [S]
631+
edge_id = self.metadata["edge_id_per_state"] # [S]
626632
edge_id_b = edge_id.unsqueeze(0).expand(B, -1) # [B, S]
627633
sum_per_edge = torch.zeros(B, E)
628634
sum_per_edge.scatter_add_(1, edge_id_b, messages_f2v_init)
629635
messages_f2v_init = messages_f2v_init / (sum_per_edge.gather(1, edge_id_b) + 1e-20)
630636

631637
messages_f2v_uncond = messages_f2v_init.clone()
632-
for it in range(num_iters):
638+
for it in range(self.iters):
633639
messages_v2f_uncond = update_var_to_factor(
634-
messages_f2v_uncond, md, evidence_logmask_vs=None
640+
messages_f2v_uncond, self.metadata, evidence_logmask_vs=None
635641
)
636642
messages_f2v_uncond = update_factor_to_var(
637-
messages_v2f_uncond, factor_eval_list, md
643+
messages_v2f_uncond, factor_eval_list, self.metadata
638644
)
639645
bp_marginals_uncond = compute_var_marginals(
640-
messages_f2v_uncond, md, evidence_logmask_vs=None
646+
messages_f2v_uncond, self.metadata, evidence_logmask_vs=None
641647
)
642648

643649
return bp_marginals_uncond

0 commit comments

Comments
 (0)