
Accelerating Spatiotemporal Neural Process Simulator with Deep Reinforcement Learning


In this paper, we propose a deep reinforcement learning approach to improve the efficiency of COVID-19 pandemic simulation. The pandemic simulation, a crucial tool for evaluating the impact of policy measures, is modeled with the spatiotemporal neural process (STNP), a member of the Neural Process (NP) family. To reduce the computational cost, Bayesian active learning is employed to query the most informative data. However, the brute-force search used to find the best parameter set is inefficient. Our proposed framework replaces it with a Deep-Q Network (DQN), in which the reward of a parameter set is estimated with Q-values. The model outputs the Q-value of choosing each parameter set, allowing it to explore the parameter space efficiently and identify the parameter set with the best reward. To allow the DQN to adapt to varying distributions, we pre-train it on data points generated during the first epoch of STNP training. The relatively converged DQN then replaces the STNP brute-force parameter search. Through experiments and analysis, we report the progress and limitations of our approach.


During the pandemic, researchers have attempted to construct COVID-19 simulations as a means to better understand the underlying factors of COVID-19 spread and to determine preventive measures that lower the death rate, the rate of disease spread, economic loss, and so on. An accurate simulation is crucial for evaluating the impact of external factors, such as the cost of policy enforcement on the population. To model disease transmission in general, several simulators such as GLEAM (Balcan et al.) and SEIR (Mwalili et al.) have been introduced, in which the movement of the population across infection states is modeled as a stochastic process. These simulators vary in their settings, such as the modeled equations, demographic data, parameters, and heuristics. Increasing a simulator's complexity better captures both fine-grained and macro-scale impacts, but comes at the cost of expensive computation and long run times, which is impractical under urgent circumstances. To speed up simulation while preserving the behavior of the data generated by the simulator, deep surrogate models are used to imitate the simulator's dynamics. The Neural Process (NP) family (Garnelo et al.) was introduced to this task because it models stochastic processes and uncertainty within a deep learning framework. To encode both spatial and temporal dynamics based on historical states, Wu et al. propose a spatiotemporal neural process (STNP) to accelerate the simulations, in which the hidden states are learned from a mapping of the simulation parameter inputs.
Compared to other NP models, the STNP model queries the simulator through Bayesian active learning to alleviate the computational cost of querying all parameters (Wu et al., 2022). Starting from an initial training set of simulator-generated data and one simulator parameter, the STNP model determines the set of simulation parameters with the largest reward for the model's learning, augments the training set, and then re-trains the model, repeating this process iteratively. While the acquisition function in active learning provides a good estimate of the reward for a given parameter set, the current brute-force search for the best parameter set is comparatively inefficient, given the large domain of parameter combinations. Here, we propose to replace the brute-force search over the acquisition function with a Deep-Q reinforcement learning network, in which the reward of a parameter set is estimated with a Q-value (Barto et al.). Given an input of data predictions, parameters, and latent variables, the network outputs the Q-value of choosing each parameter set. In one forward pass, the model identifies the parameter set with the best estimated reward, which explores the parameter space efficiently. In this paper, we aim to integrate Deep Reinforcement Learning (DRL) into the STNP simulator to find optimal policies in COVID-19 simulations. By replacing the brute-force search on the acquisition function with a Deep-Q Network, the proposed framework explores the parameter space efficiently while allowing different reward functions to be considered when evaluating the model. This research aims to contribute to more accurate and effective COVID-19 simulations, helping to guide decision-makers in implementing preventive measures that reduce the impact of the pandemic.


The proposed idea aims to integrate Deep Reinforcement Learning (DRL) into a spatiotemporal neural process (STNP) simulator in order to find optimal training scenarios in COVID-19 simulations. We propose to replace the current brute-force search over the acquisition function with a Deep-Q Network (DQN) that outputs a Q-value for each set of simulation parameters. The STNP model, on which the DQN is trained, learns the behavior of the SEIR simulator under different scenarios of COVID latent period and transmissibility rate. Our goal is to evaluate how efficiently and effectively the DQN explores the parameter space.

Given the training data points, consisting of context and target points, the DQN model performs the action of choosing the most suitable parameter in the parameter pool, which corresponds to the action with the largest estimated Q-value. Each action receives a reward from the environment, where the reward is equal to the value of the acquisition function in STNP.

RL Environment: To apply the RL framework to the task, we design the environment state as a vector formed by concatenating the context points, target points, and latent variables. As in the original neural process, we concatenate x with y, leading to a feature vector of size dim(x_c) + dim(y_c), where x_c and y_c represent the parameters and the infection statistics respectively. We then concatenate the latent variable z along the last dimension, leading to a feature of size dim(x_c) + dim(y_c) + z_dim.
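The concatenation described above can be sketched in a few lines. This is a minimal illustration rather than the project's code: the function name `build_state`, the toy dimensions, and the choice to share one latent vector z across all points are assumptions made for the example.

```python
from typing import List, Sequence


def build_state(xs: Sequence[Sequence[float]],
                ys: Sequence[Sequence[float]],
                z: Sequence[float]) -> List[List[float]]:
    """Concatenate each (x, y) pair with the latent vector z.

    Returns one row per point, each of length dim(x) + dim(y) + z_dim.
    Assumption: the same z is appended to every point.
    """
    if len(xs) != len(ys):
        raise ValueError("xs and ys must contain the same number of points")
    return [list(x) + list(y) + list(z) for x, y in zip(xs, ys)]


# Toy sizes (assumed for illustration): dim(x) = 2, dim(y) = 3, z_dim = 4.
x_c = [[0.1, 0.2], [0.3, 0.4]]
y_c = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
z = [0.0, 0.5, -0.5, 1.0]
state = build_state(x_c, y_c, z)  # 2 rows, each with 2 + 3 + 4 = 9 features
```

With the sizes reported later in the text, the same construction would give rows of 110 features.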

Deep-Q Network: A naive Deep-Q Network is based on Q-learning, in which an agent learns the value of each action in a particular state, called the Q-value. Q-learning allows the agent to find an optimal policy in a Markov Decision Process by maximizing the expected total reward over future steps. A naive DQN uses a neural network that receives a state as input and outputs a value for each action. The DQN is then trained with reinforcement learning, where the network minimizes the loss between the predicted Q-values and the target Q-values. The new Q-value is computed by equation (3).

$$Q_{new}(s_t, a_t) \leftarrow (1 - \alpha)\, Q_{curr}(s_t, a_t) + \alpha \left( r_t + \gamma \max_{a} Q(s_{t+1}, a) \right) \qquad (3)$$

In equation (3), $Q_{new}$ is the new Q-value. It is computed by taking the current Q-value at state $s_t$ and action $a_t$ and subtracting that Q-value multiplied by the learning rate $\alpha$. To this we add the learning rate $\alpha$ multiplied by the sum of the reward $r_t$ and the discounted maximum Q-value over actions at the next state, $\gamma \max_{a} Q(s_{t+1}, a)$, where $\gamma$ is the discount factor.
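The update in equation (3) can be checked numerically with a small sketch. The function name `q_update` and the sample numbers are illustrative assumptions; `gamma` is the usual Q-learning discount factor (set it to 1.0 for no discounting).

```python
def q_update(q_curr: float, reward: float, next_q_values, alpha: float, gamma: float) -> float:
    """One tabular Q-learning update in the form of equation (3)."""
    # Target: immediate reward plus the discounted best value at the next state.
    target = reward + gamma * max(next_q_values)
    # Blend the old estimate and the target using the learning rate alpha.
    return (1.0 - alpha) * q_curr + alpha * target


# target = 1.0 + 0.9 * 3.0 = 3.7, so q_new = 0.9 * 2.0 + 0.1 * 3.7 = 2.17
q_new = q_update(q_curr=2.0, reward=1.0, next_q_values=[0.5, 3.0, 1.5], alpha=0.1, gamma=0.9)
```

A small `alpha` moves the estimate only slightly toward the target, which is why many updates are needed before the Q-values settle.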

Algorithm 0 describes a typical DQN algorithm, and Algorithm 1 describes the STNP algorithm. We improve the model by replacing the brute-force search in Algorithm 1, as shown in Algorithm 2. Instead of searching by brute force for the maximum reward, the DQN uses one forward pass to determine the action with the highest value (Q-value/reward), while internalizing past mistakes and successes through experience replay (Algorithm 0) to inform its reward estimates. The definitions of the variables used in the algorithms can be seen in table 1.

Algorithm 0: Deep Q-Learning with Experience Replay
Initialize replay memory $D$ to capacity $N$
Initialize action-value function $Q$ with random weights
For episode $= 1, M$:
  Initialize sequence $s_1 = \{x_1\}$ and preprocessed sequence $\phi_1 = \phi(s_1)$
  For $t = 1, T$:
    With probability $\epsilon$ select a random action $a_t$; otherwise select $a_t = \arg\max_a Q^*(\phi(s_t), a; \theta)$
    Execute action $a_t$ in the emulator and observe reward $r_t$ and image $x_{t+1}$
    Set $s_{t+1} = s_t, a_t, x_{t+1}$ and preprocess $\phi_{t+1} = \phi(s_{t+1})$
    Store transition $(\phi_t, a_t, r_t, \phi_{t+1})$ in $D$
    Sample a random minibatch of transitions $(\phi_j, a_j, r_j, \phi_{j+1})$ from $D$
    Set $y_j = r_j$ for terminal $\phi_{j+1}$, or $y_j = r_j + \gamma \max_{a'} Q(\phi_{j+1}, a'; \theta)$ for non-terminal $\phi_{j+1}$
    Perform a gradient descent step on $(y_j - Q(\phi_j, a_j; \theta))^2$
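The three building blocks of Algorithm 0 that do not depend on the network itself (replay memory, epsilon-greedy action selection, and the regression target) can be sketched with the standard library alone. The class and function names here are assumptions for illustration, and the gradient step on the Q-network is deliberately left out.

```python
import random
from collections import deque


class ReplayBuffer:
    """Fixed-capacity memory D; the oldest transitions are dropped first."""

    def __init__(self, capacity: int):
        self.memory = deque(maxlen=capacity)

    def store(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def sample(self, batch_size: int, rng: random.Random):
        # Uniform sampling without replacement from the stored transitions.
        return rng.sample(list(self.memory), min(batch_size, len(self.memory)))


def epsilon_greedy(q_values, epsilon: float, rng: random.Random) -> int:
    """Random action with probability epsilon, otherwise the greedy action."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])


def td_target(reward: float, next_q_values, done: bool, gamma: float) -> float:
    """y_j = r_j if terminal, else r_j + gamma * max over next-state Q-values."""
    return reward if done else reward + gamma * max(next_q_values)


rng = random.Random(0)
buffer = ReplayBuffer(capacity=3)
for step in range(5):
    buffer.store(state=step, action=0, reward=1.0, next_state=step + 1, done=(step == 4))
batch = buffer.sample(2, rng)
greedy = epsilon_greedy([0.1, 0.7, 0.3], epsilon=0.0, rng=rng)  # epsilon = 0 is purely greedy
```

Sampling minibatches from the buffer, rather than training on consecutive transitions, is what lets the agent reuse earlier experience during later updates.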

Algorithm 1: Interactive Neural Process
Input: Initial simulation dataset $S_1$, DQN parameters
Train the model $NP^{(1)}(S_1)$;
For $i = 1, 2, \dots$:
  Learn $(z_1, z_2, \dots, z_T) \sim q_i(z_{1:T} \mid x_{1:T}, S_i)$;
  Predict $(x_1, x_2, \dots, x_T) \sim p_i(x_{1:T} \mid z_{1:T}, S_i)$;
  Select a batch $\{\theta^{(i+1)}\} \leftarrow \arg\max_{\theta} E_{p(x_{1:T} \mid z_{1:T}, \theta)}[r(x_{1:T} \mid z_{1:T}, \theta)]$;
  Simulate $\{x^{(i+1)}_{1:T}\} = F(\theta^{(i+1)}, u)$;
  Augment training set $S_{i+1} = S_i \cup \{\theta^{(i+1)}, x^{(i+1)}_{1:T}\}$;
  Update the model $NP^{(i+1)}(S_{i+1})$, DQN;

Algorithm 2: Interactive Neural Process with DQN
Input: Initial simulation dataset $S_1$, DQN parameters
1. Train the STNP model for 1 epoch and collect data points (x_c, y_c, x_t, y_t, y_t_pred) every 2000 iterations (10 sets of tuples in total).
2. Using the data points from step 1, generate 10 environments. Train the DQN model for 150 epochs on each environment, where the tuples collected later are trained on earlier.
3. After pre-training the DQN model in steps 1 and 2, use the DQN to actively query the unselected dataset for the STNP model (the steps below repeat Algorithm 1):
  (a) Train the model $NP^{(1)}(S_1)$;
  (b) Learn $(z_1, z_2, \dots, z_T) \sim q_i(z_{1:T} \mid x_{1:T}, S_i)$;
  (c) Predict $(x_1, x_2, \dots, x_T) \sim p_i(x_{1:T} \mid z_{1:T}, S_i)$;
  (d) Select the action with the DQN's largest Q-value and find the corresponding $\theta^{(i+1)}$;
  (e) Augment training set $S_{i+1} = S_i \cup \{\theta^{(i+1)}, x^{(i+1)}_{1:T}\}$;
  (f) Update the model $NP^{(i+1)}(S_{i+1})$;
The definitions of variables can be found in the original paper by Wu et al.
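The difference between the selection step of Algorithm 1 and that of Algorithm 2 can be illustrated as follows. This is a sketch under assumptions: the function names are hypothetical, `acquisition` stands in for the reward function, and already selected parameters are excluded because Algorithm 2 queries the unselected dataset.

```python
def brute_force_select(acquisition, num_params, already_selected):
    """Baseline: call the acquisition function once per unselected parameter."""
    candidates = [i for i in range(num_params) if i not in already_selected]
    if not candidates:
        raise ValueError("every parameter has already been selected")
    return max(candidates, key=acquisition)


def dqn_select(q_values, already_selected):
    """DQN variant: one forward pass gives q_values; take the best unselected index."""
    candidates = [i for i in range(len(q_values)) if i not in already_selected]
    if not candidates:
        raise ValueError("every parameter has already been selected")
    return max(candidates, key=lambda i: q_values[i])


q_values = [0.2, 0.9, 0.4, 0.7]   # stand-in for the DQN output
selected = {1}                    # parameter 1 was queried in an earlier round
choice = dqn_select(q_values, selected)
```

The brute-force version calls the acquisition function once for every candidate, while the DQN version only reads values that a single forward pass has already produced.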

Our reward function is computed with the STNP reward function, namely the Latent Information Gain (LIG) acquisition function. Given the state information and parameter information, the model will ideally learn the mapping defined by the acquisition function. The current implementation is constrained by the fact that the DQN converges significantly more slowly than the STNP model.

Data & Evaluation Metric: The SEIR simulator generates data points consisting of 300 parameter scenarios split into 270 for training, 15 for validation, and 15 for testing.

To compare the performance of the DQN models in selecting informative parameters, we calculate the ground-truth acquisition value for each parameter. The evaluation metric for DQN performance is the number of selected parameters that fall within the top-20 (or top-10) highest acquisition values. Given a fixed number of actions for an agent to take in an environment (e.g., 25 actions), the agent should learn to prioritize parameters with higher acquisition values, and these metrics quantify that requirement. From a reinforcement learning perspective, we also track the total reward per episode, which indicates the proficiency of the DQN model in its environment.
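The top-k metric described above amounts to counting how many selected parameters fall in the k highest ground-truth acquisition values. A minimal sketch, with a hypothetical function name and toy values:

```python
def top_k_hits(selected, acquisition_values, k):
    """Count the selected indices that are among the k largest acquisition values."""
    ranked = sorted(range(len(acquisition_values)),
                    key=lambda i: acquisition_values[i],
                    reverse=True)
    top_k = set(ranked[:k])
    return sum(1 for i in selected if i in top_k)


acq = [0.1, 0.9, 0.5, 0.7, 0.3]   # toy ground-truth acquisition values
hits = top_k_hits(selected=[1, 4, 3], acquisition_values=acq, k=2)  # top-2 is {1, 3}
```

In the setting of the text, `k` would be 20 or 10 and `acquisition_values` would hold one entry per candidate parameter.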

Experimental Setup: To collect data points for training our proposed method, we generate data points using the STNP model in the first epoch. In particular, every 1000 iterations (each STNP model has 20000 iterations in one epoch), we retrieve the context points, the target points, and the latent variables generated from both context and training points.

To train the DQN, we first preprocess the simulation parameters and flatten the states into a feature vector (given that RL states are usually represented as vectors). Then, within the DQN model, we reshape it back to N x d, where N is the number of context and target points and d is the feature dimension, equal to dim(x) + dim(y) + z_dim = 2 + 100 + 8 = 110. We then feed the vectors into 3 linear layers, each followed by an activation function (e.g., sigmoid or LeakyReLU). The output sizes of the layers are 192, 256, and 270 respectively, where the 270 nodes correspond to the number of possible parameters to choose from. We then take the mean of the feature values across the data points, leading to a 270x1 vector. To avoid exploding gradients, we clip the gradient norm and apply a softmax to scale down the magnitudes, and output the result as the Q-value vector.
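The forward pass described above can be sketched in pure Python to make the shapes concrete. This is an untrained illustration with random weights, not the project's network: the helper names are assumptions, LeakyReLU is used as the example activation, and gradient clipping is omitted because no training happens here.

```python
import math
import random


def make_linear(n_in, n_out, rng):
    """Random weights and zero biases for one fully connected layer."""
    scale = 1.0 / math.sqrt(n_in)
    weights = [[rng.uniform(-scale, scale) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out


def linear(layer, x):
    weights, bias = layer
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weights, bias)]


def leaky_relu(x, slope=0.01):
    return [v if v > 0 else slope * v for v in x]


def softmax(x):
    m = max(x)  # subtract the maximum for numerical stability
    exps = [math.exp(v - m) for v in x]
    total = sum(exps)
    return [e / total for e in exps]


def forward(layers, points):
    """Map an N x d state to one score per candidate parameter."""
    outputs = []
    for x in points:
        for layer in layers:
            x = leaky_relu(linear(layer, x))  # each linear layer is followed by an activation
        outputs.append(x)
    n = len(outputs)
    pooled = [sum(col) / n for col in zip(*outputs)]  # mean over the N points
    return softmax(pooled)


rng = random.Random(0)
sizes = [110, 192, 256, 270]  # input dimension, then the three layer widths from the text
layers = [make_linear(a, b, rng) for a, b in zip(sizes, sizes[1:])]
state = [[rng.gauss(0.0, 1.0) for _ in range(110)] for _ in range(4)]  # N = 4 toy points
q_values = forward(layers, state)
best_action = max(range(len(q_values)), key=lambda a: q_values[a])
```

Averaging over the N points makes the output size independent of how many context and target points are supplied, so the same network can score all 270 candidates for any N.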

To compare the DQN with the current brute-force search approach, we run the SEIR simulator with the same set of parameters used for the DQN and compare the results. We also compare how efficiently the DQN and the brute-force search explore the parameter space and find optimal policies.

Model setup

Data Preparation: Our data is generated from the SEIR model, a widely used compartmental model for infectious diseases. The model uses the following data:
“Seir_data”: contains the beta and epsilon values for the SEIR model.
“x_all”, “y_all”: contain the full dataset for all beta and epsilon values.
“x_val”, “y_val”: validation set.
“X_test”, “y_test”: test set.

Environment and temporal model training: A temporary spatiotemporal neural process (STNP) model is trained on the initial dataset. The model is trained for a specified number of epochs, and the results are saved in a file named “dcrcnn{}”. The environment setup is:
State = (x_c, y_c, x_t, y_t, y_t_pred, z) -> dim = (n_c+n_t) x 110 (2+100+z_dim). Here x_c is the x variable for the context points and y_c is the y variable for the context points. x_t is the input feature vector for the target points, which are the points the model wants to predict. y_t is the true output value corresponding to the target points (x_t), and y_t_pred is the output value predicted by the model for the target points (x_t). z is the latent variable representation in the model, which captures the underlying structure of the data, and its dimension is z_dim. The state has dimensions (n_c + n_t) x 110, where n_c is the number of context points and n_t is the number of target points. The number 110 is the sum of the dimensions of x, y, and z.
Action: all possible parameters.
Reward = acquisition value of the selected parameter.

Agent (DQN): A DQN is trained using a game environment that simulates the spread of infectious diseases. The DQN learns to select the best data points for training the STNP model. The training process is iterative, and the DQN’s local Q-network is saved periodically in a file named “results/dqn_ckpt{}.pth”. The DQN structure is a stack of linear layers, each with an activation function (sigmoid/relu), followed by a softmax that keeps the Q-values from growing under exploding gradients.

STNP model training with DQN-based Data Selection: The main STNP model is trained iteratively with the data selected by the DQN. In each iteration:
The STNP model is trained on the current dataset.
The model is evaluated on the test and full datasets.
Scores are calculated for each data point using the current STNP model.
The DQN selects a new set of data points based on the scores.
The selected data points are added to the training dataset for the next iteration.

Result Analysis: At each iteration, we evaluated the performance of the DQN model with respect to the ground-truth reward. Since the DQN model selects parameters that maximize reward, the performance is measured by the number of selected parameters that are in the top 20 by reward.
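The state, action, and reward definitions in the environment setup above can be collected into a small environment class. This is a hypothetical sketch: the class name, the `reset`/`step` interface, the fixed per-episode action budget, and the assumption that the state stays unchanged within an episode are all illustrative choices, not details taken from the implementation.

```python
class ParameterSelectionEnv:
    """Toy environment: each action picks a parameter, the reward is its acquisition value."""

    def __init__(self, state, acquisition_values, max_actions):
        self.state = state                              # (n_c + n_t) x d feature rows
        self.acquisition_values = list(acquisition_values)
        self.max_actions = max_actions                  # assumed action budget per episode
        self.selected = []

    def reset(self):
        self.selected = []
        return self.state

    def step(self, action):
        if not 0 <= action < len(self.acquisition_values):
            raise IndexError("action is outside the parameter pool")
        reward = self.acquisition_values[action]
        self.selected.append(action)
        done = len(self.selected) >= self.max_actions
        # Assumption: the observation does not change within an episode.
        return self.state, reward, done


env = ParameterSelectionEnv(state=[[0.0] * 110], acquisition_values=[0.2, 0.8, 0.5], max_actions=2)
env.reset()
_, r1, d1 = env.step(1)
_, r2, d2 = env.step(2)
```

Summing the rewards returned by `step` over one episode gives the total reward per episode used as the reinforcement learning metric.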


Graph 1 shows the total rewards over episodes. The rewards have been normalized through z-scoring (subtracting the mean from the rewards, then dividing by the standard deviation). For each training scenario/tuple, we train for 150 epochs. Through the first 300 episodes of training, the total reward trends upward. There is a sharp change after 300 episodes, and the total reward does not reach the same peak again. Moreover, the total rewards within the same environment/iteration have a similar magnitude. This suggests that the DQN model is unable to adapt quickly to a completely new scenario, possibly due to a completely different latent variable distribution. At the peak of the training procedure (at approximately epoch 300), out of 15 selected actions, the model retrieves 6 parameters in the top-20 highest acquisition values, 3 of which are in the top-10. By comparison, under random selection the expected numbers of top-20 and top-10 parameters are approximately 1.11 and 0.55. This suggests that DQN and RL can be a feasible approach for choosing suitable parameters/training data given a fixed scenario.
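The random-selection baseline quoted above can be reproduced with a short calculation, assuming the 15 actions are drawn uniformly without replacement from the pool of 270 training parameters. The function name is hypothetical.

```python
def expected_random_hits(num_selected, top_k, pool_size):
    """Expected count of top-k items among uniformly random picks (hypergeometric mean)."""
    return num_selected * top_k / pool_size


expected_top20 = expected_random_hits(15, 20, 270)  # 15 * 20 / 270, about 1.11
expected_top10 = expected_random_hits(15, 10, 270)  # 15 * 10 / 270, about 0.56
```

Against this baseline, 6 top-20 hits and 3 top-10 hits are each roughly five times what random selection would be expected to produce.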

Discussion/Conclusion:

In this study, we proposed to integrate deep reinforcement learning into a spatiotemporal neural process simulator for finding optimal training scenarios in COVID-19 simulations. Our approach replaces the brute-force search with a Deep-Q Network that outputs a Q-value for each set of simulation parameters. We trained the DQN on the STNP model, which learns the behavior of the SEIR simulator under different scenarios of COVID latent period and transmissibility rate. Our goal was to evaluate how efficiently and effectively the DQN explores the parameter space. To guarantee that the DQN model selects parameters with higher acquisition values, we can also modify step 3(d) in Algorithm 2: select the top-10 actions from the DQN model, evaluate them with the acquisition function, and acquire the training data with the maximum reward. In terms of runtime complexity, this is still constant time, since the set of parameter candidates is narrowed down to a fixed size.
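The top-10 variant of the selection step mentioned above can be sketched as a shortlist-then-rerank routine. The function name and the toy values are assumptions, and `acquisition` is a stand-in for the real acquisition function.

```python
def rerank_select(q_values, acquisition, shortlist_size=10):
    """Shortlist the highest-Q actions, then pick the one with the best acquisition value."""
    ranked = sorted(range(len(q_values)), key=lambda a: q_values[a], reverse=True)
    shortlist = ranked[:shortlist_size]
    # Only shortlist_size acquisition evaluations are needed, regardless of pool size.
    return max(shortlist, key=acquisition)


q_values = [0.9, 0.8, 0.1, 0.7]      # stand-in for the DQN output
true_acq = [0.2, 0.6, 0.99, 0.4]     # toy acquisition values
choice = rerank_select(q_values, lambda a: true_acq[a], shortlist_size=3)
```

In this toy case the shortlist is actions 0, 1, and 3, so the routine returns action 1. Action 2 has the best acquisition value but is missed because its Q-value is low, which shows that the guarantee only holds within the shortlist.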

The magnitude and distribution of the latent variables vary across STNP training iterations. Therefore, to prevent the DQN from learning the magnitude instead of the representations, we normalized the observation states. Our results showed that the DQN model was able to retrieve several of the parameters with the highest acquisition values, especially during the initial training epochs. The model’s total reward increased in the initial training epochs, indicating that the DQN was able to explore the parameter space efficiently. However, the performance of the model decreased after approximately 300 episodes, which may be due to the varying magnitude and distribution of the latent variables across different iterations.
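The normalization mentioned above can be illustrated with z-scoring, the same transformation the text applies to rewards. This is a generic sketch: the function name and the small `eps` guard against zero variance are assumptions, and the exact normalization applied to the states may differ.

```python
import statistics


def z_score(values, eps=1e-8):
    """Subtract the mean and divide by the (population) standard deviation."""
    mean = statistics.fmean(values)
    std = statistics.pstdev(values)
    return [(v - mean) / (std + eps) for v in values]


normalized = z_score([2.0, 4.0, 6.0])  # mean 4.0, so the middle value maps to 0.0
```

After this step, inputs from different STNP iterations share zero mean and unit scale, so differences in raw magnitude are no longer visible to the DQN.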

Overall, this study demonstrates the potential of using deep reinforcement learning in spatiotemporal neural process simulations for optimal data selection. However, further research is needed to improve the efficiency and robustness of the proposed method in handling the dynamic and complex nature of infectious disease simulations.

To further improve the performance of our proposed method, several areas can be explored. First, adjusting the learning rate of the DQN could help the model converge faster and reach higher rewards. Reducing the replay buffer size could help prevent the model from overfitting and improve its generalization ability. Another potential improvement is to adjust the number of actions per episode to prioritize actions with higher rewards. By assigning higher weights to actions with higher rewards, the model can focus on the most informative parameters and avoid wasting time exploring less useful ones. Normalizing state and reward values could also reduce the variation in magnitude between scenarios and prevent the DQN from overfitting to specific scenarios. Adding more randomness to the model’s decision-making process could improve its exploration ability and prevent it from getting stuck in local minima. Finally, tuning the model architecture and adjusting the hyperparameters could lead to better performance. This includes experimenting with different activation functions, hidden layer sizes, and regularization techniques.