@@ -484,9 +484,10 @@ def compute_exact_marginals_bruteforce(
484484
485485class BPInference (BaseInference ):
486486
487- def __init__ (self , model ):
487+ def __init__ (self , model , iters = 5 ):
488488 super ().__init__ ()
489489 self .model : ProbabilisticModel = model
490+ self .iters = iters
490491
491492
492493 variables = {}
@@ -565,7 +566,7 @@ def query(self, query, evidence):
565566 factor_eval_list .append (factor_eval )
566567 continue
567568 else :
568- for i , p in enumerate (cpd .parents ):
569+ for i , p in enumerate (cpd .variable . parents ):
569570
570571 if p .distribution is Delta :
571572 emb = embeddings_dict [p .concepts [0 ]] # [B, emb_dim]
@@ -595,7 +596,7 @@ def query(self, query, evidence):
595596
596597 # turn into bidimentional tensor: [B * num_assignments, input_dim]
597598 input = input .view (batch_size * num_assignments , - 1 )
598- evaluation = cpd .parameterization (input )
599+ evaluation = cpd .parametrization (input )
599600
600601 # reshape back to [B, num_assignments, output_dim]
601602 evaluation = evaluation .view (batch_size , num_assignments , - 1 )
@@ -604,40 +605,45 @@ def query(self, query, evidence):
604605 # TODO: We need to turn them into factor evaluations. In each factor, the target variable of the CPD is the first variable in the scope so we can do a simple reshape
605606 # TODO: check that this is the case
606607
607- if cpd .distribution is RelaxedOneHotCategorical :
608+ if cpd .variable . distribution is RelaxedOneHotCategorical :
608609 #TODO: Check that it is concatenating the third dimension into the num_assignments dimension
609- factor_eval = evaluation .view (batch_size , - 1 )
610610
611- elif cpd .distribution is RelaxedBernoulli :
611+ # this is the tensorial equivalent to torch.cat([evaluation[:, :, i] for i in range(evaluation.shape[2])], dim=1)
612+ factor_eval = evaluation .permute (0 , 2 , 1 ).reshape (batch_size , - 1 )
613+
614+ elif cpd .variable .distribution is RelaxedBernoulli :
612615 # Bernoulli output: need to create a factor eval of size 2
613616 prob_1 = evaluation .view (batch_size , - 1 )
614617 prob_0 = 1.0 - prob_1
615618 factor_eval = torch .cat ([prob_0 , prob_1 ], dim = 1 )
616- elif cpd .distribution is Delta :
619+ elif cpd .variable . distribution is Delta :
617620 factor_eval = torch .ones ([batch_size ,1 ], device = evaluation .device )
618621 else :
619622 raise NotImplementedError ("Unknown CPD distribution in CPD2FactorWrapper." )
620623
621624 factor_eval_list .append (factor_eval )
622625
626+ B = batch_size
627+ S = self .metadata ["total_edge_states" ]
628+ E = self .metadata ["E" ]
623629 messages_f2v_init = torch .rand (B , S )
624630
625- edge_id = md ["edge_id_per_state" ] # [S]
631+ edge_id = self . metadata ["edge_id_per_state" ] # [S]
626632 edge_id_b = edge_id .unsqueeze (0 ).expand (B , - 1 ) # [B, S]
627633 sum_per_edge = torch .zeros (B , E )
628634 sum_per_edge .scatter_add_ (1 , edge_id_b , messages_f2v_init )
629635 messages_f2v_init = messages_f2v_init / (sum_per_edge .gather (1 , edge_id_b ) + 1e-20 )
630636
631637 messages_f2v_uncond = messages_f2v_init .clone ()
632- for it in range (num_iters ):
638+ for it in range (self . iters ):
633639 messages_v2f_uncond = update_var_to_factor (
634- messages_f2v_uncond , md , evidence_logmask_vs = None
640+ messages_f2v_uncond , self . metadata , evidence_logmask_vs = None
635641 )
636642 messages_f2v_uncond = update_factor_to_var (
637- messages_v2f_uncond , factor_eval_list , md
643+ messages_v2f_uncond , factor_eval_list , self . metadata
638644 )
639645 bp_marginals_uncond = compute_var_marginals (
640- messages_f2v_uncond , md , evidence_logmask_vs = None
646+ messages_f2v_uncond , self . metadata , evidence_logmask_vs = None
641647 )
642648
643649 return bp_marginals_uncond
0 commit comments