11from .quadpotential import quad_potential
2- from .arraystep import ArrayStepShared , SamplerHist , Competence
2+ from .arraystep import ArrayStepShared , ArrayStep , SamplerHist , Competence
33from ..model import modelcontext , Point
44from ..vartypes import continuous_types
5- from .hmc import leapfrog , Hamiltonian , energy , bern
5+ from numpy import exp , log , array
6+ from numpy .random import uniform
7+ from .hmc import leapfrog , Hamiltonian , bern , energy
68from ..tuning import guess_scaling
7- import numpy as np
8- import numpy .random as nr
99import theano
1010from ..theanof import (make_shared_replacements , join_nonshared_inputs , CallableTensor ,
1111 gradient , inputvars )
@@ -26,7 +26,7 @@ class NUTS(ArrayStepShared):
2626 default_blocked = True
2727
2828 def __init__ (self , vars = None , scaling = None , step_scale = 0.25 , is_cov = False , state = None ,
29- max_energy = 1000 ,
29+ Emax = 1000 ,
3030 target_accept = 0.8 ,
3131 gamma = 0.05 ,
3232 k = 0.75 ,
@@ -42,11 +42,10 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False, state
4242 step_scale : float, default=.25
4343 Size of steps to take, automatically scaled down by 1/n**(1/4)
4444 is_cov : bool, default=False
45- Treat C as a covariance matrix/vector if True, else treat it as a
46- precision matrix/vector
45+ Treat C as a covariance matrix/vector if True, else treat it as a precision matrix/vector
4746 state
4847 state to start from
49- max_energy : float, default 1000
48+ Emax : float, default 1000
5049 maximum energy
5150 target_accept : float (0,1) default .8
5251 target for avg accept probability between final branch and initial position
@@ -69,19 +68,27 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False, state
6968 scaling = model .test_point
7069
7170 if isinstance (scaling , dict ):
72- scaling = guess_scaling (Point (scaling , model = model ), model = model , vars = vars )
73- self .step_size = step_scale / scaling .shape [0 ]** 0.25
71+ scaling = guess_scaling (
72+ Point (scaling , model = model ), model = model , vars = vars )
73+
74+ n = scaling .shape [0 ]
75+
76+ self .step_size = step_scale / n ** (1 / 4. )
77+
7478 self .potential = quad_potential (scaling , is_cov , as_cov = False )
79+
7580 if state is None :
7681 state = SamplerHist ()
7782 self .state = state
78- self .max_energy = max_energy
83+ self .Emax = Emax
84+
7985 self .target_accept = target_accept
8086 self .gamma = gamma
8187 self .t0 = t0
8288 self .k = k
83- self .h_bar = 0
84- self .u = np .log (self .step_size * 10 )
89+
90+ self .Hbar = 0
91+ self .u = log (self .step_size * 10 )
8592 self .m = 1
8693
8794 shared = make_shared_replacements (vars , model )
@@ -90,80 +97,97 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False, state
9097
9198 super (NUTS , self ).__init__ (vars , shared , ** kwargs )
9299
93- @ staticmethod
94- def competence ( var ):
95- if var . dtype in continuous_types :
96- return Competence . IDEAL
97- return Competence . INCOMPATIBLE
100+ def astep ( self , q0 ):
101+ # Hamiltonian(self.logp, self.dlogp, self.potential)
102+ H = self . leapfrog1_dE
103+ Emax = self . Emax
104+ e = self . step_size
98105
99- def astep (self , initial_position ):
100- log_slice_var = np .log (nr .uniform ())
101- initial_momentum = self .potential .random ()
102- position = back_position = forward_position = initial_position
103- back_momentum = forward_momentum = initial_momentum
104- should_continue = True
105- trials = 1
106- depth = 0
107- while should_continue :
108- direction = nr .choice ((- 1 , 1 ))
109- step = np .array (direction * self .step_size )
110- new_trials = 0
111- metropolis_acceptance = 0
112- steps = 0
113- for _ in range (2 ** depth ):
114- if not should_continue :
115- break
116- if direction == 1 :
117- forward_position , forward_momentum , energy_change = self .leapfrog1_dE (
118- forward_position , forward_momentum , step ,
119- initial_position , initial_momentum )
120- else :
121- back_position , back_momentum , energy_change = self .leapfrog1_dE (
122- back_position , back_momentum , step , initial_position , initial_momentum )
123- new_trials += int (log_slice_var + energy_change <= 0 )
124- if should_update_position (new_trials , trials ):
125- if direction == 1 :
126- position = forward_position
127- else :
128- position = back_position
129-
130- should_continue = (self ._energy_is_bounded (log_slice_var , energy_change ) and
131- no_u_turns (forward_position , forward_momentum ,
132- back_position , back_momentum ))
133- metropolis_acceptance += min (1. , np .exp (- energy_change ))
134- steps += 1
135- trials += new_trials
136- depth += 1
137- w = 1. / (self .m + self .t0 )
138- self .h_bar = (1 - w ) * self .h_bar + w * (self .target_accept - metropolis_acceptance / steps )
139- self .step_size = np .exp (self .u - (self .m ** 0.5 / self .gamma ) * self .h_bar )
140- self .m += 1
141- return position
106+ p0 = self .potential .random ()
107+ u = uniform ()
108+ q = qn = qp = q0
109+ p = pn = pp = p0
142110
143- def _energy_is_bounded (self , log_slice_var , energy_change ):
144- return log_slice_var + energy_change < self .max_energy
111+ n , s , j = 1 , 1 , 0
145112
113+ while s == 1 :
114+ v = bern (.5 ) * 2 - 1
146115
147- def no_u_turns (forward_position , forward_momentum , back_position , back_momentum ):
148- span = forward_position - back_position
149- return span .dot (back_momentum ) >= 0 and span .dot (forward_momentum ) >= 0
116+ if v == - 1 :
117+ qn , pn , _ , _ , q1 , n1 , s1 , a , na = buildtree (
118+ H , qn , pn , u , v , j , e , Emax , q0 , p0 )
119+ else :
120+ _ , _ , qp , pp , q1 , n1 , s1 , a , na = buildtree (
121+ H , qp , pp , u , v , j , e , Emax , q0 , p0 )
150122
123+ if s1 == 1 and bern (min (1 , n1 * 1. / n )):
124+ q = q1
151125
152- def should_update_position (new_trials , trials ):
153- return bern (float (new_trials ) / max (trials , 1. ))
126+ n = n + n1
154127
128+ span = qp - qn
129+ s = s1 * (span .dot (pn ) >= 0 ) * (span .dot (pp ) >= 0 )
130+ j = j + 1
131+
132+ p = - p
133+
134+ w = 1. / (self .m + self .t0 )
135+ self .Hbar = (1 - w ) * self .Hbar + w * \
136+ (self .target_accept - a * 1. / na )
137+
138+ self .step_size = exp (self .u - (self .m ** .5 / self .gamma ) * self .Hbar )
139+ self .m += 1
140+
141+ return q
142+
@staticmethod
def competence(var):
    """Rate how suitable this step method is for ``var``.

    Returns ``Competence.IDEAL`` when ``var.dtype`` is one of
    ``continuous_types`` and ``Competence.INCOMPATIBLE`` otherwise.
    """
    is_continuous = var.dtype in continuous_types
    return Competence.IDEAL if is_continuous else Competence.INCOMPATIBLE
155148
156- def leapfrog1_dE (logp , vars , shared , quad_potential , profile ):
157- """Computes a theano function that computes one leapfrog step and the energy
158- difference between the beginning and end of the trajectory.
159149
def buildtree(H, q, p, u, v, j, e, Emax, q0, p0):
    """Recursively build a trajectory subtree of depth ``j`` in direction ``v``.

    Parameters
    ----------
    H : callable
        Called as ``H(q, p, array(v * e), q0, p0)``; must return the triple
        ``(q1, p1, dE)``.
    q, p :
        Position and momentum the subtree is grown from.
    u : float
        Slice variable; only ``log(u)`` is used.
    v : int
        Direction; ``-1`` extends from the ``qn, pn`` end, any other value
        extends from the ``qp, pp`` end. The step passed to ``H`` is ``v * e``.
    j : int
        Depth. ``j == 0`` takes a single step; otherwise up to two subtrees
        of depth ``j - 1`` are built.
    e : float
        Step size.
    Emax : float
        The subtree is flagged as stopped once ``log(u) + dE >= Emax``.
    q0, p0 :
        Forwarded unchanged to ``H``.

    Returns
    -------
    tuple
        ``(qn, pn, qp, pp, q1, n1, s1, a1, na1)``: the two end states of the
        subtree, the selected position ``q1``, the count ``n1`` of steps with
        ``log(u) + dE <= 0``, the continue flag ``s1``, the summed
        ``min(1, exp(-dE))`` values ``a1`` and the number of steps ``na1``
        contributing to that sum.
    """
    if j == 0:
        # Base case: one step; both ends of the subtree are the new state.
        q1, p1, dE = H(q, p, array(v * e), q0, p0)
        slice_energy = log(u) + dE
        n1 = int(slice_energy <= 0)
        s1 = int(slice_energy < Emax)
        return q1, p1, q1, p1, q1, n1, s1, min(1, exp(-dE)), 1

    qn, pn, qp, pp, q1, n1, s1, a1, na1 = buildtree(
        H, q, p, u, v, j - 1, e, Emax, q0, p0)
    if s1 == 1:
        # The first half is still valid: grow the second half from the end
        # that lies in direction v.
        if v == -1:
            qn, pn, _, _, q11, n11, s11, a11, na11 = buildtree(
                H, qn, pn, u, v, j - 1, e, Emax, q0, p0)
        else:
            _, _, qp, pp, q11, n11, s11, a11, na11 = buildtree(
                H, qp, pp, u, v, j - 1, e, Emax, q0, p0)

        # max(..., 1) guards the division when neither half counted a step.
        if bern(n11 * 1. / max(n1 + n11, 1)):
            q1 = q11

        a1 = a1 + a11
        na1 = na1 + na11

        # Stop when either end momentum points against the span of the tree.
        span = qp - qn
        s1 = s11 * (span.dot(pn) >= 0) * (span.dot(pp) >= 0)
        n1 = n1 + n11
    return qn, pn, qp, pp, q1, n1, s1, a1, na1
180+
181+
182+ def leapfrog1_dE (logp , vars , shared , pot , profile ):
183+ """Computes a theano function that computes one leapfrog step and the energy difference between the beginning and end of the trajectory.
160184 Parameters
161185 ----------
162186 logp : TensorVariable
163187 vars : list of tensor variables
164188 shared : list of shared variables not to compute leapfrog over
165- quad_potential : quadpotential
166- profile : Boolean
189+ pot : quadpotential
190+ profile : Boolean
167191
168192 Returns
169193 -------
@@ -175,7 +199,7 @@ def leapfrog1_dE(logp, vars, shared, quad_potential, profile):
175199 logp = CallableTensor (logp )
176200 dlogp = CallableTensor (dlogp )
177201
178- hamiltonian = Hamiltonian (logp , dlogp , quad_potential )
202+ H = Hamiltonian (logp , dlogp , pot )
179203
180204 p = tt .dvector ('p' )
181205 p .tag .test_value = q .tag .test_value
@@ -188,9 +212,11 @@ def leapfrog1_dE(logp, vars, shared, quad_potential, profile):
188212 e = tt .dscalar ('e' )
189213 e .tag .test_value = 1
190214
191- q1 , p1 = leapfrog (hamiltonian , q , p , 1 , e )
192- energy_change = energy (hamiltonian , q1 , p1 ) - energy (hamiltonian , q0 , p0 )
215+ q1 , p1 = leapfrog (H , q , p , 1 , e )
216+ E = energy (H , q1 , p1 )
217+ E0 = energy (H , q0 , p0 )
218+ dE = E - E0
193219
194- f = theano .function ([q , p , e , q0 , p0 ], [q1 , p1 , energy_change ], profile = profile )
220+ f = theano .function ([q , p , e , q0 , p0 ], [q1 , p1 , dE ], profile = profile )
195221 f .trust_input = True
196222 return f
0 commit comments