import jax
import os
import jax.numpy as jnp
import jaxidem.utils as utils
import jaxidem.idem as idem
import matplotlib.pyplot as plt
from PIL import Image, ImageSequence
# Reproducible randomness: a single root key split into ten subkeys.
seed = 4
key = jax.random.PRNGKey(seed)
keys = jax.random.split(key, 10)

# Basis handed to the model as ``process_basis``.
# Alternative kept from the original script: utils.place_basis()
process_basis = utils.place_cosine_basis(N=10)

# Diagonal matrix handed to the model as ``sigma2_eta``: 0.01 on every
# diagonal entry except three entries that are overridden with larger values.
# Alternative kept from the original script: sigma2_eta = 0.01
eta_diagonal = 0.01 * jnp.ones(process_basis.nbasis)
for basis_index, variance in ((1, 40.0), (30, 80.0), (31, 60.0)):
    eta_diagonal = eta_diagonal.at[basis_index].set(variance)
sigma2_eta = jnp.diag(eta_diagonal)

covariate_labels = ['Intercept', 'x', 'y']

# Build the example model on a 40 x 40 grid.
model = idem.gen_example_idem(
    keys[0],
    k_spat_inv=False,
    ngrid=jnp.array([40, 40]),
    process_basis=process_basis,
    sigma2_eta=sigma2_eta,
    covariate_labels=covariate_labels,
)
# Simulation
T = 35  # number of time points (times run 1..T)
nobs = 50  # number of observation locations

# Draw the observation locations uniformly on the unit square.
# keys[0] was already consumed by idem.gen_example_idem and keys[1] is used
# by model.simulate below; JAX keys must not be reused across separate random
# draws, so a fresh subkey is used for the locations.
coords = jax.random.uniform(
    keys[2],
    shape=(nobs, 2),
    minval=0,
    maxval=1,
)

# Every location is observed at every time: each time value is repeated once
# per location and the coordinate table is stacked T times so rows line up.
times = jnp.repeat(jnp.arange(1, T + 1), coords.shape[0])
rep_coords = jnp.tile(coords, (T, 1))
x = rep_coords[:, 0]
y = rep_coords[:, 1]

# The x and y coordinates are also supplied as the covariates.
process_data, obs_data = model.simulate(
    keys[1],
    x,
    y,
    times,
    covariates=jnp.column_stack([x, y]),
)
# Figure dimensions: pixel targets (576 x 480) divided by dpi.
# NOTE(review): presumably inches, as matplotlib expects - confirm against the
# jaxidem plotting helpers; dpi itself is not passed to them here.
dpi = 200
width, height = 576 / dpi, 480 / dpi
figure_size = {"width": width, "height": height}

# Write the process animation, the observation animation and the kernel plot.
utils.gif_st_grid(process_data, "site/figure/process.gif", **figure_size)
utils.gif_st_pts(obs_data, "site/figure/obs.gif", **figure_size)
model.kernel.save_plot("site/figure/kernel.png", **figure_size)
# Overlay site/figure/tardis.gif on top of site/figure/process.gif, frame by
# frame, and write the result back over process.gif.
def _load_rgba_frames(path, size=None):
    """Decode every frame of the GIF at *path* into an RGBA image.

    Args:
        path: Location of the GIF file.
        size: Optional ``(width, height)`` in pixels. When given, every frame
            is resized to it with Lanczos resampling.

    Returns:
        A ``(frames, size, duration)`` tuple: the list of RGBA frames, the
        GIF's own pixel size, and the ``duration`` entry of its ``info`` dict
        as it stands once the last frame has been read.

    Raises:
        KeyError: If the GIF's ``info`` dict has no ``duration`` entry.
    """
    # The context manager closes the file once every frame has been copied
    # out, so nothing keeps the file open when it is overwritten later.
    with Image.open(path) as gif:
        rgba_frames = []
        for frame in ImageSequence.Iterator(gif):
            rgba = frame.convert("RGBA")  # convert() returns a new image
            if size is not None:
                rgba = rgba.resize(size, Image.LANCZOS)
            rgba_frames.append(rgba)
        return rgba_frames, gif.size, gif.info['duration']


# The process animation defines the canvas size and the frame duration.
process_frames, (width, height), frame_duration = _load_rgba_frames(
    'site/figure/process.gif'
)
# The overlay is resized once per frame here rather than on every composite.
overlay_frames, _, _ = _load_rgba_frames(
    'site/figure/tardis.gif', size=(width, height)
)

# Run for as many frames as the longer animation; the shorter one loops.
max_frames = max(len(process_frames), len(overlay_frames))
frames = [
    Image.alpha_composite(
        process_frames[i % len(process_frames)],
        overlay_frames[i % len(overlay_frames)],
    )
    for i in range(max_frames)
]

# Both inputs are closed by now, so overwriting process.gif is safe.
frames[0].save(
    'site/figure/process.gif',
    save_all=True,
    append_images=frames[1:],
    duration=frame_duration,
    loop=0,
)