Filtering in JAX-IDEM

Author

Evan Tate Paterson Hughes

Index

1 A Simple Example

Although the project is primarily for Integro-Difference equation models, the file filter_and_smoother_functions provides functions applicable to any discrete dynamical system. This page provides a very simple example to test the Kalman filter, the information filter, and fitting with the corresponding likelihoods.

1.1 The simple model

Consider the simple system, for \(t=1,\dots,T\)

\[\begin{split} \vec\alpha_{t+1} &= M\vec\alpha_t + \vec\eta_t,\\ \end{split} \tag{1}\]

where \(\alpha_0 = (1,1)^\intercal\) and

\[\begin{split} M = \left[\begin{matrix} \cos(0.3) & -\sin(0.3)\\ \sin(0.3) & \cos(0.3) \end{matrix}\right]. \end{split} \tag{2}\]

The error terms are mutually independent with variances \(\sigma^{2}_\epsilon=0.01\) and \(\sigma^{2}_{\eta}=0.001\), and \(\vec z_t\) are linearly transformed ‘observations’ of \(\vec\alpha_t\)

\[\begin{split} \vec z_t &= \Phi \vec\alpha_t + \vec\epsilon_t,\\ \Phi &= \left[\begin{matrix} 1 & 0 \\ 0.6 & 0.4\\ 0.4 & 0.6 \end{matrix}\right] \end{split}. \tag{3}\]

The process, \(\alpha\), simply spins in a circle with some noise. Let’s simulate from this system:

Code
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64",False)

import jaxidem.filters as filt

import matplotlib.pyplot as plt
import plotly.express as px
import plotly.io as pio
import plotly.graph_objs as go

import pandas as pd
import pickle
Code
# Master PRNG key for the whole simulation, split into 2*T subkeys.
key = jax.random.PRNGKey(1)
T = 50
keys = jax.random.split(key, 2 * T)

# Initial state; two-dimensional so the trajectory is easy to plot.
alpha_0 = jnp.ones(2)

# Propagation matrix: a planar rotation by 0.3 radians per step.
M = jnp.array([
    [jnp.cos(0.3), -jnp.sin(0.3)],
    [jnp.sin(0.3), jnp.cos(0.3)],
])

# Observation matrix taking the 2D state to a 3D observation.
PHI = jnp.array([
    [1, 0],
    [0.6, 0.4],
    [0.4, 0.6],
])

# Process (eta) and observation (eps) noise variances, as 0-d arrays.
sigma2_eta = jnp.array(0.001)
sigma2_eps = jnp.array(0.01)

# Containers filled by the simulation: states start at alpha_0,
# observations start empty.
alphas = [alpha_0]
zs = []
Code
# Simulate T steps. Each step propagates the latest state with process
# noise (even-indexed subkey), then observes the new state with
# observation noise (odd-indexed subkey).
for i in range(T):
    eta = jnp.sqrt(sigma2_eta) * jax.random.normal(keys[2*i], shape=(2,))
    alphas.append(M @ alphas[i] + eta)

    eps = jnp.sqrt(sigma2_eps) * jax.random.normal(keys[2*i + 1], shape=(3,))
    zs.append(PHI @ alphas[i + 1] + eps)

# States are stacked into a single array; observations stay a list of arrays.
alphas = jnp.array(alphas)
zs_tree = zs

# Tabular copies of the states and observations for plotting.
alphas_df = pd.DataFrame(alphas, columns=["x", "y"])
zs_df = pd.DataFrame(zs, columns=["x", "y", "z"])

# Persist the simulation so it can be reloaded without re-running it.
with open('./pickles/alphas.pkl', 'wb') as out_file:
    pickle.dump((alphas, zs_tree, alphas_df, zs_df), out_file)
Code
# Reload the pickled simulation (states, observation list and both DataFrames).
# NOTE: pickle should only ever be used on trusted, locally written files.
with open('./pickles/alphas.pkl', 'rb') as in_file:
    alphas, zs_tree, alphas_df, zs_df = pickle.load(in_file)
Code
# Line plot of the true state trajectory, with the two axes locked to the
# same scale so the rotation is not visually distorted.
equal_axes = dict(
    xaxis=dict(scaleanchor="y", scaleratio=1, title='X Axis'),
    yaxis=dict(scaleanchor="x", scaleratio=1, title='Y Axis'),
)
fig1 = px.line(alphas_df, x='x', y='y', height=200)
fig1.update_layout(**equal_axes)
(a)
(b)
Figure 1
Code
# 3D scatter of the observations, drawn as small cross markers.
obs_marker = dict(symbol='cross', size=5)
obs_trace = go.Scatter3d(
    x=zs_df['x'],
    y=zs_df['y'],
    z=zs_df['z'],
    mode='markers',
    marker=obs_marker,
)

fig2 = go.Figure(data=[obs_trace])
fig2.update_layout(height=200)
fig2.show()
Figure 2

We can see how the process is an odd random spiral, and the observations are skewed noisy observations of this in 3D space.

With filtering, we aim to recover the process (Figure 1) from the observations (Figure 2). We do this with two ‘forms’ of the filter, which should be equivalent.

2 Kalman filter

Code
# Prior on the initial state: zero mean, large isotropic covariance.
m_0 = jnp.zeros(2)
P_0 = 100*jnp.eye(2)

# The filter is passed the array rank (ndim) of each variance argument.
# Both variances are 0-d arrays here, so both ranks are 0.
sigma2_eps_dim = sigma2_eps.ndim
sigma2_eta_dim = sigma2_eta.ndim

kalman_options = dict(
    sigma2_eps_dim=sigma2_eps_dim,
    sigma2_eta_dim=sigma2_eta_dim,
    forecast=0,
    likelihood='full',
)
filt_results = filt.kalman_filter(
    m_0, P_0, M, PHI, sigma2_eta, sigma2_eps, zs_tree, **kalman_options
)

# Long-format table holding the true path and the filtered means, tagged by
# a 'line' column so plotly can colour them separately.
ms_df = pd.DataFrame(list(filt_results['ms']), columns=["x", "y"])
truth_rows = alphas_df.assign(line='True Process')
filtered_rows = ms_df.assign(line='Filtered Process Means')
combined_df = pd.concat([truth_rows, filtered_rows])

# Truth in blue, filtered means in red, on equally scaled axes.
fig3 = px.line(
    combined_df,
    x='x',
    y='y',
    color='line',
    color_discrete_sequence=['blue', 'red'],
    height=200,
)
fig3.update_layout(
    xaxis=dict(scaleanchor="y", scaleratio=1, title='X Axis'),
    yaxis=dict(scaleanchor="x", scaleratio=1, title='Y Axis'),
)
fig3.show()
Figure 3

2.1 Pseudo-information filter?

This is a test.

Code
# Pseudo-observations: each observation z_t is mapped to
#     i_t = PHI^T z_t / sigma2_eps.
# With iid observation noise of variance sigma2_eps, i_t has observation
# matrix I = PHI^T PHI / sigma2_eps and noise covariance I as well, so the
# ordinary Kalman filter can be run on (i_t, I, I).
#
# The observations are a Python list of per-time arrays, which JAX will not
# multiply as a whole (and the shapes would not conform), so the transform is
# applied one time step at a time. This keeps the list-of-arrays layout that
# the plain Kalman filter call is given.
I = PHI.T @ PHI / sigma2_eps
i = [PHI.T @ z / sigma2_eps for z in zs_tree]

filt_results_alt = filt.kalman_filter(m_0,
                                  P_0,
                                  M,
                                  I,
                                  sigma2_eta,
                                  I,
                                  i,
                                  sigma2_eps_dim = 2,  # I is a full 2x2 covariance matrix
                                  sigma2_eta_dim = sigma2_eta_dim,
                                  forecast=0,
                                  likelihood='full'
    )
ms_df_alt = pd.DataFrame(list(filt_results_alt['ms']), columns = ["x", "y"])

combined_df = pd.concat([alphas_df.assign(line='True Process'), ms_df_alt.assign(line='Filtered Process Means')])

# Truth in blue, filtered means in red, on equally scaled axes.
fig_alt = px.line(combined_df, x='x', y='y', color='line',
               color_discrete_sequence=['blue', 'red'], height=200)

fig_alt.update_layout(xaxis=dict(scaleanchor="y", scaleratio=1, title='X Axis'), yaxis=dict(scaleanchor="x", scaleratio=1, title='Y Axis'))
fig_alt.show()
print(filt_results_alt["ll"])
Figure 4

Let’s check if they really are the same;

Code
# Element-wise comparisons against the pseudo-information run (disabled):
#print(jnp.allclose(filt_results['ms'], filt_results_alt['ms'], atol=1e-04))
#print(jnp.allclose(filt_results['Ps'], filt_results_alt['Ps'], atol=1e-04))

# Show the first five filtered means from the Kalman filter.
first_means = filt_results['ms'][:5]
print(first_means)
#print(filt_results_alt['ms'][0:5])
[[ 0.6675799   1.157503  ]
 [ 0.2237205   1.3271735 ]
 [-0.21794711  1.3617302 ]
 [-0.49495745  1.2157357 ]
 [-0.8105154   1.0334284 ]]

3 Information Filter

Similarly, we can apply the information filter. Since the information filter function has support for time-varying observation shapes, there is a little more work to do.

Code
# Information-form prior: information vector nu_0 and matrix Q_0
# (the filtered means are recovered below by solving Q m = nu).
nu_0 = jnp.zeros(2)
Q_0 = 0.01 * jnp.eye(2)

# The information filter takes one observation matrix, one variance and one
# observation vector per time step.
PHI_tuple = (PHI,) * T
zs_tuple = zs_tree
sigma2_eps_per_step = [sigma2_eps] * T

filt_results2 = filt.information_filter(
    nu_0,
    Q_0,
    M,
    PHI_tuple,
    sigma2_eta,
    sigma2_eps_per_step,
    zs_tuple,
    sigma2_eps_dim=sigma2_eps_dim,
    sigma2_eta_dim=sigma2_eta_dim,
    forecast=0,
    likelihood='full',
)

# Batched solve Q_t m_t = nu_t to turn the information output into means.
nu_columns = filt_results2['nus'][..., None]
ms2 = jnp.linalg.solve(filt_results2['Qs'], nu_columns)[..., 0]

ms2_df = pd.DataFrame(list(ms2), columns=["x", "y"])

truth_rows_2 = alphas_df.assign(line='True Process')
filtered_rows_2 = ms2_df.assign(line='Filtered Process Means')
combined_df_2 = pd.concat([truth_rows_2, filtered_rows_2])

# Truth in blue, information-filter means in green, on equally scaled axes.
fig4 = px.line(
    combined_df_2,
    x='x',
    y='y',
    color='line',
    color_discrete_sequence=['blue', 'green'],
    height=200,
)
fig4.update_layout(
    xaxis=dict(scaleanchor="y", scaleratio=1, title='X Axis'),
    yaxis=dict(scaleanchor="x", scaleratio=1, title='Y Axis'),
)
fig4.show()
Figure 5

We can clearly see that the green and red lines are the same; if plotted over each other, they overlap. They only differ slightly due to numerical error.

We still need to look at the log-likelihoods that were outputted;

Code
# Log-likelihoods: Kalman filter first, then information filter.
for res in (filt_results, filt_results2):
    print(res['ll'])
97.97394
97.974464

4 Square Root Kalman filter

Code
# Prior for the square-root filter: zero mean and the factor U_0.
# NOTE(review): U_0 = 10*I looks like a square root of P_0 = 100*I; confirm
# the factor convention against filt.sqrt_filter.
m_0 = jnp.zeros(2)
U_0 = 10*jnp.eye(2)

sqrt_options = dict(
    sigma2_eps_dim=sigma2_eps_dim,
    sigma2_eta_dim=sigma2_eta_dim,
    forecast=0,
    likelihood='full',
)
filt_results3 = filt.sqrt_filter(
    m_0, U_0, M, PHI, sigma2_eta, sigma2_eps, zs_tree, **sqrt_options
)

ms3_df = pd.DataFrame(list(filt_results3['ms']), columns=["x", "y"])

truth_rows_3 = alphas_df.assign(line='True Process')
filtered_rows_3 = ms3_df.assign(line='Filtered Process Means')
combined3_df = pd.concat([truth_rows_3, filtered_rows_3])

# Truth in blue, square-root filter means in red, on equally scaled axes.
fig5 = px.line(
    combined3_df,
    x='x',
    y='y',
    color='line',
    color_discrete_sequence=['blue', 'red'],
    height=200,
)
fig5.update_layout(
    xaxis=dict(scaleanchor="y", scaleratio=1, title='X Axis'),
    yaxis=dict(scaleanchor="x", scaleratio=1, title='Y Axis'),
)
fig5.show()
Figure 6

5 Square Root Information filter

Code
# Prior factor for the square-root information filter.
# NOTE(review): R_0 = 0.1*I looks like a square root of Q_0 = 0.01*I; confirm
# the factor convention against filt.sqrt_information_filter.
R_0 = 0.1*jnp.eye(2)

filt_results4 = filt.sqrt_information_filter(
    nu_0,
    R_0,
    M,
    PHI_tuple,
    sigma2_eta,
    [sigma2_eps] * T,
    zs_tuple,
    sigma2_eps_dim=sigma2_eps_dim,
    sigma2_eta_dim=sigma2_eta_dim,
    forecast=0,
    likelihood='full',
)

# Rebuild the information matrices as R_t^T R_t, one per time step, then
# solve Q_t m_t = nu_t for the filtered means.
sqrt_factors = filt_results4['Rs']
Qs2 = jnp.swapaxes(sqrt_factors, 1, 2) @ sqrt_factors
ms4 = jnp.linalg.solve(Qs2, filt_results4['nus'][..., None])[..., 0]

ms4_df = pd.DataFrame(list(ms4), columns=["x", "y"])

truth_rows_4 = alphas_df.assign(line='True Process')
filtered_rows_4 = ms4_df.assign(line='Filtered Process Means')
combined_df_4 = pd.concat([truth_rows_4, filtered_rows_4])

# Truth in blue, square-root information filter means in green.
fig4 = px.line(
    combined_df_4,
    x='x',
    y='y',
    color='line',
    color_discrete_sequence=['blue', 'green'],
    height=200,
)
fig4.update_layout(
    xaxis=dict(scaleanchor="y", scaleratio=1, title='X Axis'),
    yaxis=dict(scaleanchor="x", scaleratio=1, title='Y Axis'),
)
fig4.show()
Figure 7
Code
# Log-likelihoods in order: Kalman, information, square-root Kalman,
# square-root information.
for res in (filt_results, filt_results2, filt_results3, filt_results4):
    print(res['ll'])
97.97394
97.974464
97.97452
97.97438

6 Fitting

Let’s ‘forget’ some of the details of the original model: the propagation matrix \(M\) and the variances \(\sigma^2_\epsilon\) and \(\sigma^2_\eta\) (though we keep the assumption that everything is independent).

Code
# initial guesses

def objective_kalman(params):
    """Negative Kalman-filter log-likelihood of the observations.

    Parameters
    ----------
    params : dict
        Must contain ``'M'`` (propagation matrix), ``'log_sigma2_eps'`` and
        ``'log_sigma2_eta'`` (log-variances; exponentiated here so the
        optimiser works on an unconstrained scale).

    Returns
    -------
    The negated ``'ll'`` entry of the filter results, suitable for
    minimisation.

    Notes
    -----
    Reads ``m_0``, ``P_0``, ``PHI`` and ``zs_tree`` from module scope. The
    call mirrors the plain Kalman filter call used earlier on this page:
    the module is bound as ``filt``, the observations are the list
    ``zs_tree``, the variances are passed as scalars together with their
    rank, and the result is a dict.
    """
    sigma2_eps = jnp.exp(params['log_sigma2_eps'])
    sigma2_eta = jnp.exp(params['log_sigma2_eta'])

    results = filt.kalman_filter(m_0,
                                 P_0,
                                 params['M'],
                                 PHI,
                                 sigma2_eta,
                                 sigma2_eps,
                                 zs_tree,
                                 sigma2_eps_dim=sigma2_eps.ndim,
                                 sigma2_eta_dim=sigma2_eta.ndim,
                                 forecast=0,
                                 likelihood='full')
    return -results['ll']

# Gradient of the negative log-likelihood with respect to the parameter dict.
objk_grad = jax.grad(objective_kalman)

import optax

# Initial guesses: identity propagation and both variances at 0.01
# (held on the log scale).
param0 = {"M": jnp.eye(2),
          "log_sigma2_eps": jnp.log(jnp.array(0.01)),
          "log_sigma2_eta": jnp.log(jnp.array(0.01))}
start_learning_rate = 1e-1
optimizer = optax.adam(start_learning_rate)
param_ad = param0
opt_state = optimizer.init(param_ad)
import time
start_time = time.time()

# 50 Adam steps on the negative log-likelihood.
for i in range(50):
    grad = objk_grad(param_ad)
    updates, opt_state = optimizer.update(grad, opt_state)
    param_ad = optax.apply_updates(param_ad, updates)
    nll = objective_kalman(param_ad)
#    print(nll)

# Filter once more with the fitted parameters. The call mirrors the plain
# Kalman filter call used earlier on this page (module bound as `filt`,
# observations as the list `zs_tree`, dict-valued result).
sigma2_eta_hat = jnp.exp(param_ad['log_sigma2_eta'])
sigma2_eps_hat = jnp.exp(param_ad['log_sigma2_eps'])
fit_results = filt.kalman_filter(m_0,
                                 P_0,
                                 param_ad['M'],
                                 PHI,
                                 sigma2_eta_hat,
                                 sigma2_eps_hat,
                                 zs_tree,
                                 sigma2_eps_dim=sigma2_eps_hat.ndim,
                                 sigma2_eta_dim=sigma2_eta_hat.ndim,
                                 forecast=0,
                                 likelihood='full')
ll = fit_results['ll']
ms3 = fit_results['ms']
end_time = time.time()

print(f"Time Elapsed is {end_time - start_time}")
print(param_ad)
print("true M is ", M)

ms3_df = pd.DataFrame(list(ms3), columns = ["x", "y"])

combined_df = pd.concat([alphas_df.assign(line='True Process'), ms3_df.assign(line='Filtered Process Means')])

# Truth in blue, means filtered under the fitted parameters in red.
fig4 = px.line(combined_df, x='x', y='y', color='line',
               color_discrete_sequence=['blue', 'red'], height=200)

fig4.update_layout(xaxis=dict(scaleanchor="y", scaleratio=1, title='X Axis'), yaxis=dict(scaleanchor="x", scaleratio=1, title='Y Axis'))
fig4.show()
Figure 8

These results are great. Let’s confirm that the information filter is also working, using the same method.

Code
@jax.jit
def objective_info(params):
    """Negative information-filter log-likelihood of the observations.

    Parameters
    ----------
    params : dict
        Must contain ``'M'`` (propagation matrix), ``'log_sigma2_eps'`` and
        ``'log_sigma2_eta'`` (log-variances, exponentiated here).

    Returns
    -------
    The negated ``'ll'`` entry of the filter results, suitable for
    minimisation.

    Notes
    -----
    Reads ``nu_0``, ``Q_0``, ``PHI_tuple`` and ``zs_tuple`` from module
    scope. The call mirrors the information filter call used earlier on
    this page: the module is bound as ``filt``, one observation variance is
    supplied per time step, and the result is a dict.
    """
    sigma2_eps = jnp.exp(params['log_sigma2_eps'])
    sigma2_eta = jnp.exp(params['log_sigma2_eta'])

    results = filt.information_filter(nu_0,
                                      Q_0,
                                      params['M'],
                                      PHI_tuple,
                                      sigma2_eta,
                                      [sigma2_eps] * len(PHI_tuple),
                                      zs_tuple,
                                      sigma2_eps_dim=sigma2_eps.ndim,
                                      sigma2_eta_dim=sigma2_eta.ndim,
                                      forecast=0,
                                      likelihood='full')
    return -results['ll']

# Gradient of the negative log-likelihood with respect to the parameter dict.
obji_grad = jax.grad(objective_info)

# Initial guesses: identity propagation and both variances at 0.01
# (held on the log scale).
param0 = {"M": jnp.eye(2),
          "log_sigma2_eps": jnp.log(jnp.array(0.01)),
          "log_sigma2_eta": jnp.log(jnp.array(0.01))}
start_learning_rate = 1e-1
optimizer = optax.adam(start_learning_rate)
param_ad = param0
opt_state = optimizer.init(param_ad)

start_time = time.time()

# Only 5 Adam steps are taken here.
for i in range(5):
    grad = obji_grad(param_ad)
    updates, opt_state = optimizer.update(grad, opt_state)
    param_ad = optax.apply_updates(param_ad, updates)
    nll = objective_info(param_ad)
    #print(nll)

# Filter once more with the fitted parameters. The call mirrors the
# information filter call used earlier on this page (module bound as `filt`,
# one observation variance per time step, dict-valued result).
sigma2_eta_hat = jnp.exp(param_ad['log_sigma2_eta'])
sigma2_eps_hat = jnp.exp(param_ad['log_sigma2_eps'])
fit_results_info = filt.information_filter(nu_0,
                                           Q_0,
                                           param_ad['M'],
                                           PHI_tuple,
                                           sigma2_eta_hat,
                                           [sigma2_eps_hat] * len(PHI_tuple),
                                           zs_tuple,
                                           sigma2_eps_dim=sigma2_eps_hat.ndim,
                                           sigma2_eta_dim=sigma2_eta_hat.ndim,
                                           forecast=0,
                                           likelihood='full')
ll = fit_results_info['ll']
nus = fit_results_info['nus']
Qs = fit_results_info['Qs']

# Batched solve Q_t m_t = nu_t to recover the filtered means.
ms4 = jnp.linalg.solve(Qs, nus[..., None]).squeeze(-1)

end_time = time.time()

print(f"Time Elapsed is {end_time - start_time}")
print(param_ad)
print("true M is ", M)

ms4_df = pd.DataFrame(list(ms4), columns = ["x", "y"])

combined_df = pd.concat([alphas_df.assign(line='True Process'), ms4_df.assign(line='Filtered Process Means')])

# Truth in blue, means filtered under the fitted parameters in green.
fig5 = px.line(combined_df, x='x', y='y', color='line',
               color_discrete_sequence=['blue', 'green'], height=200)

fig5.update_layout(xaxis=dict(scaleanchor="y", scaleratio=1, title='X Axis'), yaxis=dict(scaleanchor="x", scaleratio=1, title='Y Axis'))
fig5.show()
Figure 9