LSTM-Calibrated SIR Model for COVID-19 Dynamics

Utah COVID-19 Case Prediction with Uncertainty Quantification

Author

Sima Najafzadehkhoei

Published

April 2, 2026

Introduction

Calibrating an SIR model from real surveillance data requires estimating transmission probability and contact rate from a noisy incidence curve — a task that is computationally expensive when done with simulation-based methods. This vignette shows how calibrate_sir() solves this in seconds using a pretrained BiLSTM network, then quantifies prediction uncertainty by running 2,000 stochastic SIR simulations with the calibrated parameters.


Setup

calibrate_sir() runs the BiLSTM model through a managed Python environment. If this is your first time using epiworldRcalibrate on this machine, run the one-time setup before loading the package:

Code
# Run once per machine, in a fresh R session
epiworldRcalibrate::setup_python_deps(force = TRUE)

In all subsequent sessions, force = FALSE (the default) is sufficient to verify that the environment is intact:

Code
epiworldRcalibrate::setup_python_deps()
Code
suppressPackageStartupMessages({
  library(tidyverse)
  library(ggplot2)
  library(epiworldR)
  library(epiworldRcalibrate)
})

Data

We use utah_covid_data (shipped with the package), restricted to the most recent 61 days — the sequence length the BiLSTM was trained on. Using fewer or more days would produce incorrect predictions.

Code
data("utah_covid_data")

last_date  <- max(utah_covid_data$Date, na.rm = TRUE)
covid_data <- utah_covid_data |>
  filter(Date > (last_date - 61)) |>
  arrange(Date)

incidence_vec <- covid_data$Daily.Cases
stopifnot(length(incidence_vec) == 61)

cat("Period:", as.character(min(covid_data$Date)),
    "to", as.character(max(covid_data$Date)), "\n")
Period: 2025-03-07 to 2025-05-06 
Code
cat("Total cases:", sum(incidence_vec),
    "| Mean daily:", round(mean(incidence_vec), 1), "\n")
Total cases: 2034 | Mean daily: 33.3 

BiLSTM calibration

The recovery rate is fixed at \(1/7\) (7-day mean infectious period), which is well established for COVID-19 and reduces the calibration problem to two unknowns: ptran and crate. The contact rate returned by the network is further constrained by the identity \(\text{crate} = R_0 \times \text{recov} / \text{ptran}\), ensuring the three outputs are internally consistent.

Code
N     <- 5000
recov <- 1 / 7

lstm_predictions <- calibrate_sir(
  daily_cases     = incidence_vec,
  population_size = N,
  recovery_rate   = recov
)
Verifying Python packages...
[OK] numpy v2.0.2
[OK] sklearn v1.6.1
[OK] joblib v1.5.0
[OK] torch v2.7.0+cu126
All required Python packages are ready!
BiLSTM model loaded successfully.
Code
print(lstm_predictions)
     ptran      crate         R0 
0.08302574 2.39933270 1.39444458 
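As a quick sanity check, we can confirm that the three outputs satisfy the identity stated above. The values below are copied from the print() output; in a live session you would index lstm_predictions directly, and your numbers will differ with different data.

```r
# Check the internal-consistency identity: crate = R0 * recov / ptran.
# Values taken from the printed calibration output above.
ptran <- 0.08302574
crate <- 2.39933270
R0    <- 1.39444458
recov <- 1 / 7

implied_crate <- R0 * recov / ptran
all.equal(implied_crate, crate, tolerance = 1e-6)
```

If the identity ever fails, it usually means one of the three values was transcribed or rounded too aggressively, since calibrate_sir() constrains them jointly.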

Stochastic SIR simulation

With parameters in hand, we run 2,000 independent stochastic realisations of the SIR model. Each run represents one possible epidemic trajectory under the estimated transmission dynamics. Variation across runs reflects the randomness inherent in disease spread — who contacts whom on any given day — rather than uncertainty about the parameters themselves.

Code
model <- ModelSIRCONN(
  name              = "Utah LSTM-calibrated SIR",
  n                 = N,
  prevalence        = incidence_vec[1] / N,
  contact_rate      = lstm_predictions[["crate"]],
  transmission_rate = lstm_predictions[["ptran"]],
  recovery_rate     = recov
)

saver <- make_saver("transition")
run_multiple(model, ndays = 60, nsims = 2000,
             saver = saver, nthreads = 12)
Starting multiple runs (2000) using 12 thread(s)
_________________________________________________________________________
_________________________________________________________________________
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||| done.
Code
sim_results <- run_multiple_get_results(
  model, nthreads = 12, freader = data.table::fread)

We extract Susceptible → Infected transitions (new daily cases) and summarise across simulations to get the median and 95% prediction interval at each time step.

Code
quantiles_df <- sim_results$transition |>
  filter(from == "Susceptible", to == "Infected") |>
  group_by(date) |>
  summarize(lower_ci = quantile(counts, 0.025),
            upper_ci = quantile(counts, 0.975),
            median   = quantile(counts, 0.5),
            .groups  = "drop")

Results

Predictions vs observed cases

Code
plot_df <- quantiles_df |>
  mutate(Date           = covid_data$Date,
         observed_cases = incidence_vec)

ggplot(plot_df, aes(x = Date)) +
  geom_ribbon(aes(ymin = lower_ci, ymax = upper_ci),
              fill = "red", alpha = 0.3) +
  geom_line(aes(y = median, color = "Model median"),
            linewidth = 1.2) +
  geom_line(aes(y = observed_cases, color = "Observed"),
            linewidth = 1.4) +
  geom_point(aes(y = observed_cases, color = "Observed"), size = 2) +
  scale_color_manual(values = c("Model median" = "red",
                                "Observed"     = "blue")) +
  labs(
    title    = "Daily Infections: Observed vs LSTM-Calibrated SIR (95% CI)",
    subtitle = paste0(
      "Contact rate: ", round(lstm_predictions[["crate"]], 3),
      " | Trans. prob: ", round(lstm_predictions[["ptran"]], 3),
      " | R\u2080: ", round(lstm_predictions[["R0"]], 2)
    ),
    x = "Date", y = "Daily cases", color = ""
  ) +
  theme_minimal(base_size = 14) +
  theme(legend.position  = "bottom",
        plot.title       = element_text(size = 15, face = "bold", hjust = 0.5),
        plot.subtitle    = element_text(hjust = 0.5))

Performance metrics

Three metrics assess how well the calibrated model fits the observed data. Coverage measures whether the prediction interval is appropriately wide — a well-calibrated 95% interval should contain roughly 95% of observed points. RMSE and MAE measure point prediction error on the median trajectory, with RMSE penalising large individual errors more heavily.

Code
coverage <- plot_df |>
  mutate(in_ci = observed_cases >= lower_ci & observed_cases <= upper_ci) |>
  summarize(rate = mean(in_ci) * 100, n_in = sum(in_ci), n = n())

rmse <- sqrt(mean((plot_df$median - plot_df$observed_cases)^2))
mae  <- mean(abs(plot_df$median - plot_df$observed_cases))

data.frame(
  Metric = c("95% CI Coverage (%)", "RMSE (cases/day)", "MAE (cases/day)"),
  Value  = c(round(coverage$rate, 1), round(rmse, 2), round(mae, 2))
) |> knitr::kable(caption = paste0(
  coverage$n_in, " of ", coverage$n, " observed days fall within the 95% CI"))
41 of 61 observed days fall within the 95% CI
Metric                 Value
95% CI Coverage (%)    67.20
RMSE (cases/day)       19.21
MAE (cases/day)        13.52

Discussion

The BiLSTM estimated a transmission probability of 0.083, a contact rate of 2.399 contacts per person per day, and \(R_0 =\) 1.39, with 67.2% of observed days falling inside the 95% prediction interval and a mean absolute error of 13.52 cases per day.

The prediction interval reflects simulation variance — stochastic variation in who infects whom — rather than parameter uncertainty, since the BiLSTM produces point estimates. If parameter uncertainty is also needed, the ABC approach in vignette("bilstm-abc-comparison") propagates a full posterior distribution through the simulations.

Key assumptions of this model worth noting: the population is treated as homogeneously mixed, parameters are held constant over the 60-day window, and the recovery rate is fixed rather than estimated. Each of these could be relaxed in extensions — for example by fitting a time-varying \(\beta(t)\) or moving to an SEIR structure to account for a latent period.