@@ -206,11 +206,12 @@ def warnings(self):
206206
207207
208208# A proposal for the next position
209- Proposal = namedtuple ("Proposal" , "q, q_grad, energy, log_p_accept , logp" )
209+ Proposal = namedtuple ("Proposal" , "q, q_grad, energy, log_p_accept_weighted , logp" )
210210
211211# A subtree of the binary tree built by nuts.
212212Subtree = namedtuple (
213- "Subtree" , "left, right, p_sum, proposal, log_size, log_accept_sum, n_proposals"
213+ "Subtree" ,
214+ "left, right, p_sum, proposal, log_size, log_weighted_accept_sum, n_proposals" ,
214215)
215216
216217
@@ -243,7 +244,7 @@ def __init__(self, ndim, integrator, start, step_size, Emax):
243244 )
244245 self .depth = 0
245246 self .log_size = 0
246- self .log_accept_sum = - np .inf
247+ self .log_weighted_accept_sum = - np .inf
247248 self .mean_tree_accept = 0.0
248249 self .n_proposals = 0
249250 self .p_sum = start .p .copy ()
@@ -291,7 +292,9 @@ def extend(self, direction):
291292 self .proposal = tree .proposal
292293
293294 self .log_size = np .logaddexp (self .log_size , tree .log_size )
294- self .log_accept_sum = np .logaddexp (self .log_accept_sum , tree .log_accept_sum )
295+ self .log_weighted_accept_sum = np .logaddexp (
296+ self .log_weighted_accept_sum , tree .log_weighted_accept_sum
297+ )
295298 self .p_sum [:] += tree .p_sum
296299
297300 # Additional turning check only when tree depth > 0 to avoid redundant work
@@ -331,13 +334,17 @@ def _single_step(self, left, epsilon):
331334 # e^{H(q_0, p_0) - H(q_n, p_n)} max(1, e^{H(q_0, p_0) - H(q_n, p_n)})
332335 # Saturated Metropolis accept probability with Boltzmann weight
333336 # if h - H0 < 0
334- log_p_accept = - energy_change + min (0.0 , - energy_change )
337+ log_p_accept_weighted = - energy_change + min (0.0 , - energy_change )
335338 log_size = - energy_change
336339 proposal = Proposal (
337- right .q , right .q_grad , right .energy , log_p_accept , right .model_logp
340+ right .q ,
341+ right .q_grad ,
342+ right .energy ,
343+ log_p_accept_weighted ,
344+ right .model_logp ,
338345 )
339346 tree = Subtree (
340- right , right , right .p , proposal , log_size , log_p_accept , 1
347+ right , right , right .p , proposal , log_size , log_p_accept_weighted , 1
341348 )
342349 return tree , None , False
343350 else :
@@ -377,21 +384,23 @@ def _build_subtree(self, left, depth, epsilon):
377384 turning = turning | turning1 | turning2
378385
379386 log_size = np .logaddexp (tree1 .log_size , tree2 .log_size )
380- log_accept_sum = np .logaddexp (tree1 .log_accept_sum , tree2 .log_accept_sum )
387+ log_weighted_accept_sum = np .logaddexp (
388+ tree1 .log_weighted_accept_sum , tree2 .log_weighted_accept_sum
389+ )
381390 if logbern (tree2 .log_size - log_size ):
382391 proposal = tree2 .proposal
383392 else :
384393 proposal = tree1 .proposal
385394 else :
386395 p_sum = tree1 .p_sum
387396 log_size = tree1 .log_size
388- log_accept_sum = tree1 .log_accept_sum
397+ log_weighted_accept_sum = tree1 .log_weighted_accept_sum
389398 proposal = tree1 .proposal
390399
391400 n_proposals = tree1 .n_proposals + tree2 .n_proposals
392401
393402 tree = Subtree (
394- left , right , p_sum , proposal , log_size , log_accept_sum , n_proposals
403+ left , right , p_sum , proposal , log_size , log_weighted_accept_sum , n_proposals
395404 )
396405 return tree , diverging , turning
397406
@@ -401,7 +410,9 @@ def stats(self):
401410 # Remove contribution from initial state which is always a perfect
402411 # accept
403412 log_sum_weight = logdiffexp_numpy (self .log_size , 0.0 )
404- self .mean_tree_accept = np .exp (self .log_accept_sum - log_sum_weight )
413+ self .mean_tree_accept = np .exp (
414+ self .log_weighted_accept_sum - log_sum_weight
415+ )
405416 return {
406417 "depth" : self .depth ,
407418 "mean_tree_accept" : self .mean_tree_accept ,
0 commit comments