Efficient Filtering and Fitting of Models Derived from Integro-Difference Equations

Author

Evan Tate Paterson Hughes


1 Fitting IDEM using jax_idem

The primary use of the jax_idem package is to fit integro-difference equation models (IDEMs) to data.

Currently, the only supported way to do this is maximum-likelihood estimation, using the Kalman or information filter to evaluate the likelihood and Optax to optimise it.
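As a rough illustration of what filter-based maximum-likelihood estimation involves, the sketch below computes the log-likelihood of a generic linear Gaussian state-space model with `jax.lax.scan`. It is not the jax_idem implementation. The function `kalman_loglik`, its arguments, the isotropic noise covariances, and the assumption that the model has already been reduced to a linear form in basis coefficients are all hypothetical choices made for this example.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def kalman_loglik(M, Phi, sigma2_eta, sigma2_eps, m0, P0, zs):
    """Log-likelihood of zs (shape (T, n)) under the assumed model
    alpha_t = M alpha_{t-1} + eta_t,  z_t = Phi alpha_t + eps_t."""
    r, n = M.shape[0], Phi.shape[0]
    Sigma_eta = sigma2_eta * jnp.eye(r)  # process noise covariance (assumed isotropic)
    Sigma_eps = sigma2_eps * jnp.eye(n)  # observation noise covariance (assumed isotropic)

    def step(carry, z):
        m, P, ll = carry
        # Predict the state one step ahead
        m_pred = M @ m
        P_pred = M @ P @ M.T + Sigma_eta
        # Innovation and its covariance
        e = z - Phi @ m_pred
        S = Phi @ P_pred @ Phi.T + Sigma_eps
        # Add the Gaussian log-density of the innovation
        _, logdet = jnp.linalg.slogdet(S)
        ll = ll - 0.5 * (n * jnp.log(2 * jnp.pi) + logdet + e @ jnp.linalg.solve(S, e))
        # Update with the Kalman gain K = P_pred Phi^T S^{-1}
        K = jnp.linalg.solve(S, Phi @ P_pred).T
        m_new = m_pred + K @ e
        P_new = P_pred - K @ Phi @ P_pred
        return (m_new, P_new, ll), m_new

    (_, _, ll), ms = jax.lax.scan(step, (m0, P0, jnp.zeros(())), zs)
    return ll, ms


# One-dimensional toy example with a single observation time
M = jnp.array([[0.5]])
Phi = jnp.array([[1.0]])
zs = jnp.array([[0.3]])
ll, ms = kalman_loglik(M, Phi, 0.1, 0.2, jnp.array([1.0]), jnp.eye(1), zs)
```

Because the whole computation is written in JAX, the log-likelihood can be differentiated with `jax.grad` with respect to any of its inputs, which is what makes gradient-based optimisation of the parameters possible.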

2 Simple example; synthetic simple data

We will start by simulating from a simple IDEM over a small number of time steps. A model can be created quickly using gen_example_idem.

Code
import jax
import jax.random as rand
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)
import matplotlib.pyplot as plt

import jaxidem.idem as idem
import jaxidem.utils as utils
import jaxidem.filters as filts

Code
key = jax.random.PRNGKey(2)
keys = rand.split(key, 3)

process_basis = utils.place_basis(nres=2, min_knot_num=3)
#process_basis = utils.place_cosine_basis(N=9)
nbasis = process_basis.nbasis
truemodel = idem.gen_example_idem(
    keys[0], k_spat_inv=True,
    process_basis=process_basis,
    sigma2_eta=0.02**2,
    sigma2_eps=0.05**2,
    beta = jnp.array([2.0,2.0,2.0]),
)

alpha_0 = jnp.zeros(nbasis).at[81].set(10)

process_data, obs_data = truemodel.simulate(nobs=200, T=9, key=keys[1], alpha_0 = alpha_0)

process_data.save_plot("true_process.png")
obs_data.save_plot("synthetic_data.png")

Figure 1: An example target simulation, with the underlying process (left panel, "Process") and noisy observations (right panel, "Observations").

Note that one process time point, \(t=0\), is not plotted here. In the version of the model used, data are only taken at \(t=1,2,\dots,T\), while the process is assumed to exist from time 0.
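The time indexing described above can be sketched as follows. This is a minimal illustration under assumed names (`simulate_states`, a generic transition matrix `M`), not the simulate method of jax_idem: the state is stored for times 0 to T, while observations would only be attached to times 1 to T.

```python
import jax
import jax.numpy as jnp


def simulate_states(key, M, alpha_0, sigma2_eta, T):
    """Return the states alpha_0, ..., alpha_T as an array of shape (T + 1, r)."""
    r = alpha_0.shape[0]
    # One process noise draw per transition, so T draws for T + 1 states
    noise = jnp.sqrt(sigma2_eta) * jax.random.normal(key, (T, r))

    def step(alpha, eta):
        alpha_next = M @ alpha + eta
        return alpha_next, alpha_next

    _, alphas = jax.lax.scan(step, alpha_0, noise)
    # Prepend the initial state, which exists but has no data attached
    return jnp.vstack([alpha_0[None, :], alphas])


key = jax.random.PRNGKey(0)
M = 0.9 * jnp.eye(4)                      # hypothetical transition matrix
alpha_0 = jnp.zeros(4).at[1].set(10.0)    # initial state at t = 0
states = simulate_states(key, M, alpha_0, sigma2_eta=0.02**2, T=3)

# Observation times start at 1; t = 0 is part of the process only
observed_times = jnp.arange(1, states.shape[0])
```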

2.1 Fitting

Now, after we initialise a 'guess' baseline model, we can use idem.Model.fit_kalman_filter (recommended when the observation locations are fixed over time) or idem.Model.fit_information_filter to fit the model to the synthetic data.
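The two filters compute the same posterior in different parametrisations. The sketch below, written with hypothetical helper names (`kalman_update`, `information_update`) and an assumed isotropic observation noise, compares one measurement update in the covariance form (mean and covariance) with the same update in the information form (information vector and precision matrix). It is an illustration of the general technique, not the package's code.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def kalman_update(m_pred, P_pred, Phi, sigma2_eps, z):
    """Covariance-form update: returns the posterior mean and covariance."""
    n = Phi.shape[0]
    S = Phi @ P_pred @ Phi.T + sigma2_eps * jnp.eye(n)
    K = jnp.linalg.solve(S, Phi @ P_pred).T  # K = P_pred Phi^T S^{-1}
    m = m_pred + K @ (z - Phi @ m_pred)
    P = P_pred - K @ Phi @ P_pred
    return m, P


def information_update(nu_pred, Q_pred, Phi, sigma2_eps, z):
    """Information-form update: returns the information vector and precision."""
    Q = Q_pred + Phi.T @ Phi / sigma2_eps
    nu = nu_pred + Phi.T @ z / sigma2_eps
    return nu, Q


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
Phi = jax.random.normal(k1, (5, 3))   # 5 observations of a 3-dimensional state
z = jax.random.normal(k2, (5,))
m_pred = jnp.array([0.5, -1.0, 2.0])
P_pred = jnp.diag(jnp.array([1.0, 0.5, 2.0]))
sigma2_eps = 0.05**2

m, P = kalman_update(m_pred, P_pred, Phi, sigma2_eps, z)

Q_pred = jnp.linalg.inv(P_pred)
nu, Q = information_update(Q_pred @ m_pred, Q_pred, Phi, sigma2_eps, z)
m_info = jnp.linalg.solve(Q, nu)      # recover the mean from the information form
```

In the information form each observation contributes an additive term, so the update has the same shape however many observations arrive at a given time. The covariance form instead solves a system whose size is the number of observations.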

Code
K_basis = truemodel.kernel.basis

# Initial guess for the kernel parameters
k = (
    jnp.array([100]),
    jnp.array([0.001]),
    jnp.zeros(K_basis[2].nbasis),
    jnp.zeros(K_basis[2].nbasis),
)

# This is the kind of kernel used by gen_example_idem
kernel = idem.param_exp_kernel(K_basis, k)

process_basis0 = utils.place_cosine_basis(N=9)

model0 = idem.Model(
    process_basis=process_basis0,
    kernel=kernel,
    process_grid=utils.create_grid(
        jnp.array([[0, 1], [0, 1]]), jnp.array([41, 41])
    ),
    sigma2_eta=0.01**2,
    sigma2_eps=0.01**2,
    beta=jnp.array([0.0, 0.0, 0.0]),
)

For context, the true values of the kernel parameters are

Code
print(truemodel.kernel.params)

So we’ve chosen a model with high prior variance and no flow, with inaccurate guesses for the spread, diffusion, and variances.
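To make the role of these parameters concrete, the sketch below evaluates a kernel of the form commonly used in the integro-difference literature: an amplitude, a scale controlling the spread, and a shift controlling the flow. Whether param_exp_kernel uses exactly this parametrisation or sign convention is an assumption, and `exp_kernel` is a hypothetical helper written only for this illustration.

```python
import jax.numpy as jnp


def exp_kernel(s, r, amplitude, scale, shift):
    """Assumed kernel form: amplitude * exp(-||r - s - shift||^2 / scale)."""
    d = r - s - shift
    return amplitude * jnp.exp(-jnp.sum(d**2, axis=-1) / scale)


s = jnp.array([0.5, 0.5])

# With zero shift ("no flow") the kernel peaks at r = s
no_flow = exp_kernel(s, s, 100.0, 0.001, jnp.zeros(2))

# With a non-zero shift the peak moves to r = s + shift
shift = jnp.array([0.1, 0.0])
peak_with_flow = exp_kernel(s, s + shift, 100.0, 0.001, shift)
at_origin_with_flow = exp_kernel(s, s, 100.0, 0.001, shift)
```

Under this assumed form, the initial guess above corresponds to a narrow kernel centred on each location with no displacement between time steps.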

The fitting functions return new idem.Model objects, whose parameters are found by using Optax to maximise the likelihood.

Code
# OUT OF DATE

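obs_data_wide = obs_data.as_wide()
# Design matrix: an intercept plus the x and y coordinates of each observation
X_obs = jnp.column_stack([jnp.ones(obs_data_wide['x'].shape[0]), obs_data_wide['x'], obs_data_wide['y']])
# One copy of the design matrix per observation time (not used below)
X_obs_tuple = [X_obs for _ in range(len(obs_data.z))]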

import optax

model1, params = model0.fit_kalman_filter(
    obs_data=obs_data,
    X_obs=X_obs,
    optimizer=optax.adamax(1e-1),
    debug=False,
    max_its=200,
    eps=1e-5,
)
print(model1.kernel.params)
print(truemodel.kernel.params)

Code
ll, ms, Ps, _, _ = model1.kalman_filter(obs_data, X_obs=X_obs)
#ms = jnp.linalg.solve(Qs, nus[..., None]).squeeze(-1)

data = idem.basis_params_to_st_data(ms, model0.process_basis, model0.process_grid)

data.show_plot()
print(ll)
print(model1.beta)