import jax
import jax.random as rand
import jax.numpy as jnp
jax.config.update('jax_enable_x64', True)
import matplotlib.pyplot as plt
import jaxidem.idem as idem
import jaxidem.utils as utils
import jaxidem.filters as filts
Evan Tate Paterson Hughes
jax_idem
The primary use of the jax_idem package is to fit integro-difference equation models (IDEMs) to data.
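For orientation, a common way of writing an IDEM is sketched below. This is the standard textbook formulation and is only assumed to match what jax_idem implements; the symbols \(k\), \(\theta\), \(\phi\), \(\alpha_t\) and \(M\) are notation chosen here, while \(\beta\), \(\sigma^2_\eta\) and \(\sigma^2_\epsilon\) are meant to correspond to the `beta`, `sigma2_eta` and `sigma2_eps` arguments used in the code on this page.

\[
Y_{t+1}(s) = \int_{D} k(s, r; \theta)\, Y_t(r)\, \mathrm{d}r + \eta_t(s),
\qquad
Z_t(s) = x(s)^\top \beta + Y_t(s) + \epsilon_t(s),
\]

with \(\eta_t\) a process noise term of variance \(\sigma^2_\eta\) and \(\epsilon_t\) a measurement error of variance \(\sigma^2_\epsilon\). Expanding the process in basis functions, \(Y_t(s) \approx \phi(s)^\top \alpha_t\), turns the integral equation into a linear Gaussian state space model for the coefficients:

\[
\alpha_{t+1} = M(\theta)\, \alpha_t + \eta_t .
\]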
Currently, the only supported way to do this is maximum-likelihood estimation, using the Kalman or information filter to evaluate the likelihood and Optax to optimise it.
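As a rough illustration of what maximum-likelihood estimation with a Kalman filter involves, the toy sketch below evaluates the log-likelihood of a scalar linear Gaussian state space model and picks the best transition coefficient on a coarse grid. It is plain Python rather than jax_idem code: the function names, the scalar model, and the grid search (standing in for a gradient-based Optax optimiser) are all assumptions made for illustration.

```python
import math
import random


def kalman_loglik(zs, m, q, r, a0=0.0, p0=1.0):
    """Log-likelihood of observations zs under the scalar model
        alpha_t = m * alpha_{t-1} + eta_t,   eta_t ~ N(0, q)
        z_t     = alpha_t + eps_t,           eps_t ~ N(0, r)
    accumulated from the Kalman filter's one-step prediction errors."""
    a, p = a0, p0  # filtered mean and variance at time 0
    loglik = 0.0
    for z in zs:
        # Predict the state one step ahead
        a_pred = m * a
        p_pred = m * m * p + q
        # Innovation (prediction error) and its variance
        innov = z - a_pred
        s = p_pred + r
        loglik -= 0.5 * (math.log(2.0 * math.pi * s) + innov * innov / s)
        # Update the state estimate with the new observation
        gain = p_pred / s
        a = a_pred + gain * innov
        p = (1.0 - gain) * p_pred
    return loglik


def simulate(m, q, r, n_steps, seed=0):
    """Draw synthetic observations from the same scalar model."""
    rng = random.Random(seed)
    alpha = rng.gauss(0.0, 1.0)  # alpha_0 ~ N(0, 1), matching a0 and p0 above
    zs = []
    for _ in range(n_steps):
        alpha = m * alpha + rng.gauss(0.0, math.sqrt(q))
        zs.append(alpha + rng.gauss(0.0, math.sqrt(r)))
    return zs


zs = simulate(m=0.8, q=0.1, r=0.05, n_steps=200)
grid = [i / 20 for i in range(20)]  # candidate values 0.00, 0.05, ..., 0.95
m_hat = max(grid, key=lambda m: kalman_loglik(zs, m, q=0.1, r=0.05))
print("grid maximum-likelihood estimate of m:", m_hat)
```

In an IDEM the state is a vector of basis coefficients and the transition is a matrix, so the scalar operations above become matrix operations, but the pattern of filtering forward, summing the innovation log-densities, and maximising over the parameters stays the same.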
We will start by simulating from a simple IDEM with only three time steps. We can quickly make a model using gen_example_idem.
key = jax.random.PRNGKey(2)
keys = rand.split(key, 3)
process_basis = utils.place_basis(nres=2, min_knot_num=3)
#process_basis = utils.place_cosine_basis(N=9)
nbasis = process_basis.nbasis
truemodel = idem.gen_example_idem(
    keys[0],
    k_spat_inv=True,
    process_basis=process_basis,
    sigma2_eta=0.02**2,
    sigma2_eps=0.05**2,
    beta=jnp.array([2.0, 2.0, 2.0]),
)
alpha_0 = jnp.zeros(nbasis).at[81].set(10)
process_data, obs_data = truemodel.simulate(nobs=200, T=9, key=keys[1], alpha_0 = alpha_0)
process_data.save_plot("true_process.png")
obs_data.save_plot("synthetic_data.png")
Note that there is one missing process time point that is not plotted here; \(t=0\). In the version of the model used, data is only taken at \(t=1,2,3\), while it is assumed that the process exists from time 0.
Now, after initialising a ‘guess’ baseline model, we can use idem.IDEM_model.fit_kalman_filter (recommended when the observation locations are fixed) or fit_information_filter to fit the model to the synthetic data.
K_basis = truemodel.kernel.basis
k = (
    jnp.array([100]),
    jnp.array([0.001]),
    jnp.zeros(truemodel.kernel.basis[2].nbasis),
    jnp.zeros(truemodel.kernel.basis[2].nbasis),
)
# This is the kind of kernel used by `gen_example_idem`
kernel = idem.param_exp_kernel(K_basis, k)
process_basis0 = utils.place_cosine_basis(N=9)
model0 = idem.Model(
    process_basis=process_basis0,
    kernel=kernel,
    process_grid=utils.create_grid(jnp.array([[0, 1], [0, 1]]), jnp.array([41, 41])),
    sigma2_eta=0.01**2,
    sigma2_eps=0.01**2,
    beta=jnp.array([0.0, 0.0, 0.0]),
)
For context, the true values of the kernel parameters are
So we’ve chosen a model with high prior variance and no flow, with inaccurate guesses for the spread, diffusion, and variances.
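To give a feel for what the four kernel parameters do, here is a small plain-Python sketch of a shifted squared-exponential kernel. The functional form, the function name `exp_kernel`, and the sign convention of the shift are assumptions borrowed from common IDEM formulations; the parameterisation inside `idem.param_exp_kernel` may differ.

```python
import math


def exp_kernel(s, r, amplitude, scale, shift=(0.0, 0.0)):
    """Assumed form: k(s, r) = amplitude * exp(-|r - shift - s|^2 / scale).

    amplitude scales the kernel, scale sets how quickly it decays with
    distance, and shift moves its centre (a non-zero shift acts as a flow).
    """
    dx = r[0] - shift[0] - s[0]
    dy = r[1] - shift[1] - s[1]
    return amplitude * math.exp(-(dx * dx + dy * dy) / scale)


# Initial guess used above: amplitude 100, scale 0.001, no shift
s = (0.5, 0.5)
k_same = exp_kernel(s, (0.5, 0.5), 100.0, 0.001)   # kernel at its centre
k_near = exp_kernel(s, (0.52, 0.5), 100.0, 0.001)  # 0.02 units away
k_far = exp_kernel(s, (0.7, 0.5), 100.0, 0.001)    # 0.2 units away
print(k_same, k_near, k_far)

# With a shift of 0.1 in x, the kernel is centred at r = s + (0.1, 0)
k_shifted = exp_kernel(s, (0.6, 0.5), 100.0, 0.001, shift=(0.1, 0.0))
print(k_shifted)
```

With such a small scale the kernel is very localised: it falls to a negligible value within a fraction of the unit square, and a zero shift keeps it centred on the source location.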
The fitting functions return new IDEM_Model objects whose parameters are found by using Optax to maximise the likelihood.
# OUT OF DATE
obs_data_wide = obs_data.as_wide()
X_obs = jnp.column_stack([jnp.ones(obs_data_wide['x'].shape[0]), obs_data_wide['x'], obs_data_wide['y']])
X_obs_tuple = [X_obs for _ in range(len(obs_data.z))]
import optax
model1, params = model0.fit_kalman_filter(
    obs_data=obs_data,
    X_obs=X_obs,
    optimizer=optax.adamax(1e-1),
    debug=False,
    max_its=200,
    eps=1e-5,
)
print(model1.kernel.params)
print(truemodel.kernel.params)