@@ -12,8 +12,15 @@ kernelspec:
1212
1313+++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}, "slideshow": {"slide_type": "slide"}}
1414
15+ (Euler-Maruyama_and_SDEs)=
1516# Inferring parameters of SDEs using a Euler-Maruyama scheme
1617
18+ :::{post} July 2016
19+ :tags: time series
20+ :category: advanced, reference
21+ :author: @maedoc
22+ :::
23+
1724_This notebook is derived from a presentation prepared for the Theoretical Neuroscience Group, Institute of Systems Neuroscience at Aix-Marseille University._
1825
1926``` {code-cell} ipython3
@@ -25,12 +32,20 @@ run_control:
2532slideshow:
2633 slide_type: '-'
2734---
35+ import warnings
36+
2837import arviz as az
2938import matplotlib.pyplot as plt
3039import numpy as np
3140import pymc as pm
3241import pytensor.tensor as pt
3342import scipy as sp
43+
44+ # Ignore UserWarnings
45+ warnings.filterwarnings("ignore", category=UserWarning)
46+
47+ RANDOM_SEED = 8927
48+ np.random.seed(RANDOM_SEED)
3449```
3550
3651``` {code-cell} ipython3
@@ -40,13 +55,15 @@ az.style.use("arviz-darkgrid")
4055
4156+++ {"button": false, "nbpresent": {"id": "2325c7f9-37bd-4a65-aade-86bee1bff5e3"}, "new_sheet": false, "run_control": {"read_only": false}, "slideshow": {"slide_type": "slide"}}
4257
43- ## Toy model 1
58+ ## Example Model
4459
4560Here's a scalar linear SDE in symbolic form
4661
4762$ dX_t = \lambda X_t \, dt + \sigma^2 \, dW_t $
4863
49- discretized with the Euler-Maruyama scheme
64+ discretized with the Euler-Maruyama scheme.
65+
66+ We can simulate data from this process and then attempt to recover the parameters.
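
As a minimal sketch of what the Euler-Maruyama discretization does for this model, the state is advanced by `dt` times the drift plus `sqrt(dt)` times the diffusion coefficient times a standard normal draw. The function name `simulate_linear_sde` and the parameter values below are illustrative assumptions, not necessarily those used in the notebook's own simulation cell.

```python
import numpy as np


def simulate_linear_sde(lam, s2, x0, dt, n_steps, rng):
    """Simulate dX = lam * X dt + s2 dW with the Euler-Maruyama scheme."""
    x = np.empty(n_steps)
    x_prev = x0
    for i in range(n_steps):
        # Drift contributes dt * lam * x; diffusion contributes sqrt(dt) * s2 * N(0, 1)
        x_prev = x_prev + dt * lam * x_prev + np.sqrt(dt) * s2 * rng.standard_normal()
        x[i] = x_prev
    return x


rng = np.random.default_rng(0)
# Illustrative (assumed) values for the drift, diffusion, and step size
x_demo = simulate_linear_sde(lam=-0.8, s2=5e-3, x0=0.1, dt=0.1, n_steps=200, rng=rng)
# Noisy observation of the latent path, with an assumed noise scale of 5e-3
z_demo = x_demo + 5e-3 * rng.standard_normal(x_demo.size)
```

With `s2=0` the scheme reduces to the forward Euler method for the deterministic equation, which is a convenient sanity check.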
5067
5168``` {code-cell} ipython3
5269---
@@ -87,16 +104,19 @@ run_control:
87104slideshow:
88105 slide_type: subslide
89106---
90- plt.figure(figsize=(10, 3))
91- plt.plot(x_t[:30], "k", label="$x(t)$", alpha=0.5)
92- plt.plot(z_t[:30], "r", label="$z(t)$", alpha=0.5)
93- plt.title("Transient")
94- plt.legend()
95- plt.subplot(122)
96- plt.plot(x_t[30:], "k", label="$x(t)$", alpha=0.5)
97- plt.plot(z_t[30:], "r", label="$z(t)$", alpha=0.5)
98- plt.title("All time")
99- plt.legend();
107+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 3))
108+
109+ ax1.plot(x_t[:30], "k", label="$x(t)$", alpha=0.5)
110+ ax1.plot(z_t[:30], "r", label="$z(t)$", alpha=0.5)
111+ ax1.set_title("Transient")
112+ ax1.legend()
113+
114+ ax2.plot(x_t[30:], "k", label="$x(t)$", alpha=0.5)
115+ ax2.plot(z_t[30:], "r", label="$z(t)$", alpha=0.5)
116+ ax2.set_title("All time")
117+ ax2.legend()
118+
119+ plt.tight_layout()
100120```
101121
102122+++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}}
@@ -105,7 +125,7 @@ What is the inference we want to make? Since we've made a noisy observation of t
105125
106126+++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}, "slideshow": {"slide_type": "subslide"}}
107127
108- First, we rewrite our SDE as a function returning a tuple of the drift and diffusion coefficients
128+ We need to provide an SDE function that returns the drift and diffusion coefficients.
109129
110130``` {code-cell} ipython3
111131---
@@ -114,13 +134,13 @@ new_sheet: false
114134run_control:
115135 read_only: false
116136---
117- def lin_sde(x, lam):
137+ def lin_sde(x, lam, s2):
118138 return lam * x, s2
119139```
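
To see how the drift and diffusion returned by a function like `lin_sde` enter the likelihood, the sketch below writes out the Gaussian transition density implied by the Euler-Maruyama scheme: each state is normally distributed around the previous state plus `dt` times the drift, with standard deviation `sqrt(dt)` times the diffusion. This is a NumPy/SciPy illustration under that assumption, not PyMC's own code; `lin_sde_np` and `euler_maruyama_logp` are hypothetical helper names, and the contribution of the initial state is left out.

```python
import numpy as np
from scipy import stats


def lin_sde_np(x, lam, s2):
    # NumPy stand-in for the notebook's lin_sde: returns (drift, diffusion)
    return lam * x, s2


def euler_maruyama_logp(x, dt, sde_fn, sde_pars):
    """Sum of Gaussian transition log-densities of a path under Euler-Maruyama."""
    x = np.asarray(x, dtype=float)
    f, g = sde_fn(x[:-1], *sde_pars)
    mu = x[:-1] + dt * f  # mean of the next state
    sigma = np.sqrt(dt) * g  # standard deviation of the increment
    return stats.norm.logpdf(x[1:], loc=mu, scale=sigma).sum()


# A noise-free path generated with lam = -0.8 and dt = 0.1 (illustrative values)
path = 0.1 * 0.92 ** np.arange(10)
logp_true = euler_maruyama_logp(path, 0.1, lin_sde_np, (-0.8, 5e-3))
logp_wrong = euler_maruyama_logp(path, 0.1, lin_sde_np, (-0.2, 5e-3))
```

The path is assigned a higher log-density under the drift that generated it (`logp_true`) than under a mismatched drift (`logp_wrong`), which is the signal the sampler uses to infer `lam`.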
120140
121141+++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}, "slideshow": {"slide_type": "subslide"}}
122142
123- Next, we describe the probability model as a set of three stochastic variables, ` lam ` , ` xh ` , and ` zh ` :
143+ The probability model comprises a prior on the drift parameter `lam`, a prior on `s` (whose square is passed as the diffusion coefficient), the latent Euler-Maruyama process `xh`, and the likelihood `zh` describing the noisy observations. We assume that the observation noise is known.
124144
125145``` {code-cell} ipython3
126146---
@@ -135,19 +155,20 @@ slideshow:
135155---
136156with pm.Model() as model:
137157 # half-Cauchy prior on l = -lam, since we know lam must be negative
138- l = pm.Flat("l")
158+ l = pm.HalfCauchy("l", beta=1)
159+ s = pm.Uniform("s", 0.005, 0.5)
139160
140161 # "hidden states" following a linear SDE distribution
141162 # parametrized by time step (det. variable) and lam (random variable)
142- xh = pm.EulerMaruyama("xh", dt=dt, sde_fn=lin_sde, sde_pars=(l, ), shape=N)
163+ xh = pm.EulerMaruyama("xh", dt=dt, sde_fn=lin_sde, sde_pars=(-l, s**2), shape=N, initval=x_t)
143164
144165 # predicted observation
145166 zh = pm.Normal("zh", mu=xh, sigma=5e-3, observed=z_t)
146167```
147168
148169+++ {"button": false, "nbpresent": {"id": "287d10b5-0193-4ffe-92a7-362993c4b72e"}, "new_sheet": false, "run_control": {"read_only": false}, "slideshow": {"slide_type": "subslide"}}
149170
150- Once the model is constructed, we perform inference, i.e. sample from the posterior distribution, in the following steps:
171+ Once the model is constructed, we perform inference by sampling from the posterior with the NUTS algorithm as implemented in `nutpie`, which is typically very fast for a model of this size.
151172
152173``` {code-cell} ipython3
153174---
@@ -157,7 +178,7 @@ run_control:
157178 read_only: false
158179---
159180with model:
160- trace = pm.sample()
181+ trace = pm.sample(nuts_sampler="nutpie", random_seed=RANDOM_SEED, target_accept=0.99)
161182```
162183
163184+++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}, "slideshow": {"slide_type": "subslide"}}
@@ -176,7 +197,7 @@ plt.plot(x_t, "r", label="$x(t)$")
176197plt.legend()
177198
178199plt.subplot(122)
179- plt.hist(az.extract(trace.posterior)["l"], 30, label=r"$\hat{\lambda}$", alpha=0.5)
200+ plt.hist(-1 * az.extract(trace.posterior)["l"], 30, label=r"$\hat{\lambda}$", alpha=0.5)
180201plt.axvline(lam, color="r", label=r"$\lambda$", alpha=0.5)
181202plt.legend();
182203```
@@ -209,148 +230,24 @@ plt.legend();
209230
210231+++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}}
211232
212- Note that
213-
214- - inference also estimates the initial conditions
215- - the observed data $z(t)$ lies fully within the 95% interval of the PPC.
216- - there are many other ways of evaluating fit
217-
218- +++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}, "slideshow": {"slide_type": "slide"}}
219-
220- ### Toy model 2
221-
222- As the next model, let's use a 2D deterministic oscillator,
223- \begin{align}
224- \dot{x} &= \tau (x - x^3/3 + y) \\
225- \dot{y} &= \frac{1}{\tau} (a - x)
226- \end{align}
233+ Note that the initial conditions are also estimated, and that most of the observed data $z(t)$ lies within the 95% interval of the PPC.
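
One way to quantify "most of the observed data lies within the 95% interval" is to compute the empirical coverage directly. The sketch below does this with NumPy on synthetic stand-in arrays; `interval_coverage`, `fake_draws`, and `fake_obs` are hypothetical names, and in the notebook the draws would instead come from the posterior predictive samples of `zh`.

```python
import numpy as np


def interval_coverage(draws, observed, level=0.95):
    """Fraction of observations inside the central `level` interval of the draws.

    draws has shape (n_draws, n_obs); observed has shape (n_obs,).
    """
    tail = 100 * (1 - level) / 2
    lower, upper = np.percentile(draws, [tail, 100 - tail], axis=0)
    return np.mean((observed >= lower) & (observed <= upper))


# Synthetic stand-ins for posterior predictive draws and observations
rng = np.random.default_rng(0)
mu = np.linspace(0.1, 0.0, 200)
fake_draws = mu + 5e-3 * rng.standard_normal((2000, 200))
fake_obs = mu + 5e-3 * rng.standard_normal(200)
coverage = interval_coverage(fake_draws, fake_obs)
```

For a well-calibrated model the coverage should be close to the nominal level, so values far below 0.95 would suggest that the predictive intervals are too narrow.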
227234
228- with noisy observation $z(t) = m x + (1 - m) y + N(0, 0.05)$ .
235+ Another approach is to compare draws from the posterior predictive distribution with the observed data. This too shows a good fit across the range of observations: the posterior predictive mean tracks the data closely.
229236
230237``` {code-cell} ipython3
231- N, tau, a, m, s2 = 200, 3.0, 1.05, 0.2, 1e-1
232- xs, ys = [0.0], [1.0]
233- for i in range(N):
234- x, y = xs[-1], ys[-1]
235- dx = tau * (x - x**3.0 / 3.0 + y)
236- dy = (1.0 / tau) * (a - x)
237- xs.append(x + dt * dx + np.sqrt(dt) * s2 * np.random.randn())
238- ys.append(y + dt * dy + np.sqrt(dt) * s2 * np.random.randn())
239- xs, ys = np.array(xs), np.array(ys)
240- zs = m * xs + (1 - m) * ys + np.random.randn(xs.size) * 0.1
241-
242- plt.figure(figsize=(10, 2))
243- plt.plot(xs, label="$x(t)$")
244- plt.plot(ys, label="$y(t)$")
245- plt.plot(zs, label="$z(t)$")
246- plt.legend()
238+ az.plot_ppc(trace)
247239```
248240
249- +++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}, "slideshow": {"slide_type": "subslide"}}
241+ ## Authors
242+ - Authored by @maedoc in July 2016
243+ - Updated to PyMC v5 by @fonnesbeck in September 2024
250244
251- Now, estimate the hidden states $x(t)$ and $y(t)$, as well as parameters $\tau$, $a$ and $m$.
245+ +++
252246
253- As before, we rewrite our SDE as a function returned drift & diffusion coefficients:
254-
255- ``` {code-cell} ipython3
256- ---
257- button: false
258- new_sheet: false
259- run_control:
260- read_only: false
261- ---
262- def osc_sde(xy, tau, a):
263- x, y = xy[:, 0], xy[:, 1]
264- dx = tau * (x - x**3.0 / 3.0 + y)
265- dy = (1.0 / tau) * (a - x)
266- dxy = pt.stack([dx, dy], axis=0).T
267- return dxy, s2
268- ```
269-
270- +++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}}
271-
272- As before, the Euler-Maruyama discretization of the SDE is written as a prediction of the state at step $i+1$ based on the state at step $i$.
273-
274- +++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}, "slideshow": {"slide_type": "subslide"}}
275-
276- We can now write our statistical model as before, with uninformative priors on $\tau$, $a$ and $m$:
277-
278- ``` {code-cell} ipython3
279- ---
280- button: false
281- new_sheet: false
282- run_control:
283- read_only: false
284- ---
285- xys = np.c_[xs, ys]
286-
287- with pm.Model() as model:
288- tau_h = pm.Uniform("tau_h", lower=0.1, upper=5.0)
289- a_h = pm.Uniform("a_h", lower=0.5, upper=1.5)
290- m_h = pm.Uniform("m_h", lower=0.0, upper=1.0)
291- xy_h = pm.EulerMaruyama(
292- "xy_h", dt=dt, sde_fn=osc_sde, sde_pars=(tau_h, a_h), shape=xys.shape, initval=xys
293- )
294- zh = pm.Normal("zh", mu=m_h * xy_h[:, 0] + (1 - m_h) * xy_h[:, 1], sigma=0.1, observed=zs)
295- ```
296-
297- ``` {code-cell} ipython3
298- pm.__version__
299- ```
300-
301- ``` {code-cell} ipython3
302- ---
303- button: false
304- new_sheet: false
305- run_control:
306- read_only: false
307- ---
308- with model:
309- trace = pm.sample(2000, tune=1000)
310- ```
311-
312- +++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}, "slideshow": {"slide_type": "subslide"}}
313-
314- Again, the result is a set of samples from the posterior, including our parameters of interest but also the hidden states
315-
316- ``` {code-cell} ipython3
317- ---
318- button: false
319- new_sheet: false
320- run_control:
321- read_only: false
322- ---
323- figure(figsize=(10, 6))
324- subplot(211)
325- plot(percentile(trace[xyh][..., 0], [2.5, 97.5], axis=0).T, "k", label=r"$\hat{x}_{95\%}(t)$")
326- plot(xs, "r", label="$x(t)$")
327- legend(loc=0)
328- subplot(234), hist(trace["τh"]), axvline(τ), xlim([1.0, 4.0]), title("τ")
329- subplot(235), hist(trace["ah"]), axvline(a), xlim([0, 2.0]), title("a")
330- subplot(236), hist(trace["mh"]), axvline(m), xlim([0, 1]), title("m")
331- tight_layout()
332- ```
333-
334- +++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}, "slideshow": {"slide_type": "subslide"}}
335-
336- Again, we can perform a posterior predictive check, that our data are likely given the fit model
337-
338- ``` {code-cell} ipython3
339- ---
340- button: false
341- new_sheet: false
342- run_control:
343- read_only: false
344- ---
345- # generate trace from posterior
346- ppc_trace = pm.sample_posterior_predictive(trace, model=model)
347-
348- # plot with data
349- figure(figsize=(10, 3))
350- plot(percentile(ppc_trace["zh"], [2.5, 97.5], axis=0).T, "k", label=r"$z_{95\% PP}(t)$")
351- plot(zs, "r", label="$z(t)$")
352- legend()
353- ```
247+ ## References
248+ :::{bibliography}
249+ :filter: docname in docnames
250+ :::
354251
355252``` {code-cell} ipython3
356253%load_ext watermark