Comparing BiLSTM and ABC Calibration for SIR Models

Utah COVID-19 Case Prediction with Uncertainty Quantification

Author

Sima Najafzadehkhoei

Published

April 2, 2026

Introduction

Calibrating an SIR epidemic model means finding the transmission probability and contact rate that make the model’s simulated incidence match observed case counts. This vignette compares two ways to do that:

  • BiLSTM: a pretrained neural network that reads a 61-day incidence curve and returns parameter estimates in seconds. It produces point estimates only.
  • ABC (Approximate Bayesian Computation): a simulation-based method that accepts parameter draws whose simulated incidence is close to the observed data, building up a posterior distribution over plausible values. It is far slower but quantifies parameter uncertainty explicitly.

Both methods are applied to the same Utah COVID-19 incidence window. After calibration, 2,000 stochastic SIR simulations are run for each method to produce 95% prediction intervals, and the results are compared on coverage, accuracy, and computation time.


Setup

calibrate_sir() runs the BiLSTM model through a managed Python environment. If this is your first time using epiworldRcalibrate on this machine, run the one-time setup before loading the package:

Code
# Run once per machine, in a fresh R session
epiworldRcalibrate::setup_python_deps(force = TRUE)

In all subsequent sessions, force = FALSE (the default) is sufficient to verify that the environment is intact:

Code
epiworldRcalibrate::setup_python_deps()
Code
suppressPackageStartupMessages({
  library(tidyverse)
  library(ggplot2)
  library(epiworldR)
  library(epiworldRcalibrate)
  library(tictoc)
  library(data.table)
  library(knitr)
})

Data

We use utah_covid_data (shipped with the package) restricted to the most recent 61 days — the exact sequence length the BiLSTM model was trained on. Pre-computed ABC results are loaded from abc_calibration_params.

Code
data("utah_covid_data")
data("abc_calibration_params")
Code
n_days     <- 61
last_date  <- max(utah_covid_data$Date, na.rm = TRUE)
covid_data <- utah_covid_data |>
  filter(Date > (last_date - n_days)) |>
  arrange(Date)

incidence_vec <- covid_data$Daily.Cases
stopifnot(length(incidence_vec) == n_days)

cat("Period:", as.character(min(covid_data$Date)),
    "to", as.character(max(covid_data$Date)), "\n")
Period: 2025-03-07 to 2025-05-06 
Code
cat("Total cases:", sum(incidence_vec),
    "| Mean daily:", round(mean(incidence_vec), 1), "\n")
Total cases: 2034 | Mean daily: 33.3 

The recovery rate is fixed at \(1/7\) (7-day mean infectious period) for both methods, consistent with early COVID-19 literature. This reduces calibration to two unknowns: transmission probability and contact rate.

Code
N           <- 20000   # Population size
recov       <- 1 / 7   # Fixed recovery rate
model_ndays <- 60
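
With the recovery rate fixed, \(R_0\) is fully determined by the two calibrated parameters. As a quick sanity check, the relationship can be written as a small helper (illustrative only, not part of the package):

Code
# R0 for the connected SIR model: expected secondary cases per infection
r0_from_params <- function(ptran, crate, recov = 1 / 7) {
  (ptran * crate) / recov
}
r0_from_params(ptran = 0.06, crate = 2.84)  # roughly 1.19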

BiLSTM calibration

calibrate_sir() passes the incidence vector through the pretrained BiLSTM network and returns point estimates for ptran, crate, and R0. The whole process takes only a few seconds because no simulation loop is needed: parameter estimation is a single neural network forward pass.

Code
tic("BiLSTM")
lstm_params <- calibrate_sir(
  daily_cases     = incidence_vec,
  population_size = N,
  recovery_rate   = recov
)
Verifying Python packages...
[OK] numpy v2.0.2
[OK] sklearn v1.6.1
[OK] joblib v1.5.0
[OK] torch v2.7.0+cu126
All required Python packages are ready!
BiLSTM model loaded successfully.
Code
bilstm_time <- toc()
BiLSTM: 5.288 sec elapsed
Code
print(lstm_params)
     ptran      crate         R0 
0.06059097 2.83891812 1.20408964 

We then run 2,000 stochastic SIR simulations at the estimated parameters. The resulting spread across runs captures simulation variance — the randomness inherent in who contacts whom each day — but not parameter uncertainty, since BiLSTM returns a single point estimate.

Code
bilstm_model <- ModelSIRCONN(
  name              = "BiLSTM SIR",
  n                 = N,
  prevalence        = incidence_vec[1] / N,
  contact_rate      = lstm_params[["crate"]],
  transmission_rate = lstm_params[["ptran"]],
  recovery_rate     = recov
)
saver_bilstm <- make_saver("transition")

tic("BiLSTM sims")
run_multiple(bilstm_model, ndays = model_ndays,
             nsims = 2000, saver = saver_bilstm, nthreads = 8)
Starting multiple runs (2000) using 8 thread(s)
_________________________________________________________________________
_________________________________________________________________________
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||| done.
Code
bilstm_sim_time <- toc()
BiLSTM sims: 23.471 sec elapsed
Code
bilstm_results <- run_multiple_get_results(
  bilstm_model, nthreads = 8, freader = data.table::fread)

bilstm_quantiles <- bilstm_results$transition |>
  filter(from == "Susceptible", to == "Infected") |>
  group_by(date) |>
  summarize(lower_ci = quantile(counts, 0.025),
            upper_ci = quantile(counts, 0.975),
            median   = quantile(counts, 0.5), .groups = "drop")

ABC calibration

ABC was run offline using LFMCMC (Likelihood-Free MCMC): at each iteration, candidate parameters are drawn, a full SIR simulation is run, and the candidates are accepted if simulated incidence is within \(\epsilon\) of the observed data. The accepted draws approximate the posterior distribution \(p(\text{ptran}, \text{crate} \mid \text{data})\). This took approximately 4 minutes and is stored in abc_calibration_params to avoid re-running it on every render.
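
The stored run uses epiworldR’s likelihood-free MCMC machinery, but the accept/reject logic is easiest to see in plain rejection ABC. The sketch below is conceptual only: simulate_incidence() is a hypothetical helper standing in for one stochastic SIR run, and the priors and tolerance are illustrative.

Code
# Conceptual rejection-ABC sketch (not the LFMCMC code behind the
# stored results). simulate_incidence() is a hypothetical helper that
# returns one simulated 61-day incidence curve.
abc_reject <- function(observed, n_draws = 5000, eps = 50) {
  accepted <- list()
  for (i in seq_len(n_draws)) {
    ptran <- runif(1, 0, 0.3)  # prior on transmission probability
    crate <- runif(1, 0, 10)   # prior on contact rate
    sim   <- simulate_incidence(ptran, crate)
    # Keep the draw if simulated incidence is within eps of the data
    if (sqrt(mean((sim - observed)^2)) < eps) {
      accepted[[length(accepted) + 1]] <- c(ptran = ptran, crate = crate)
    }
  }
  do.call(rbind, accepted)  # rows are approximate posterior draws
}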

Code
data("abc_calibration_params")

abc_crate        <- abc_calibration_params$contact_rate
abc_ptran        <- abc_calibration_params$transmission_prob
abc_R0           <- abc_calibration_params$R0
abc_time_minutes <- abc_calibration_params$calibration_time_seconds / 60

cat("ABC posterior medians — contact rate:", round(abc_crate, 4),
    "| transmission prob:", round(abc_ptran, 4),
    "| R0:", round(abc_R0, 4), "\n")
ABC posterior medians — contact rate: 1.1789 | transmission prob: 0.1184 | R0: 1.4898 
Code
cat("Computation time:", round(abc_time_minutes, 1), "minutes\n")
Computation time: 3.7 minutes

The histograms below show the marginal posterior for each parameter. Narrow distributions mean the data strongly constrain that parameter; wide distributions indicate remaining uncertainty. \(R_0\) is derived per sample as \((\text{ptran} \times \text{crate}) / \text{recov}\).

Code
posterior  <- abc_calibration_params$posterior_samples
R0_samples <- (posterior[, 1] * posterior[, 3]) / posterior[, 2]

data.frame(contact_rate      = posterior[, 1],
           recovery_rate     = posterior[, 2],
           transmission_prob = posterior[, 3],
           R0                = R0_samples) |>
  pivot_longer(everything(), names_to = "parameter", values_to = "value") |>
  mutate(parameter = factor(parameter,
    levels = c("contact_rate", "recovery_rate", "transmission_prob", "R0"),
    labels = c("Contact Rate", "Recovery Rate", "Transmission Prob", "R0"))) |>
  ggplot(aes(x = value)) +
  geom_histogram(bins = 30, fill = "#2E86DE", alpha = 0.7, color = "white") +
  facet_wrap(~parameter, scales = "free", ncol = 4) +
  labs(title    = "ABC Posterior Distributions",
       subtitle = paste0(nrow(posterior), " MCMC samples (post burn-in)"),
       x = "Parameter Value", y = "Frequency") +
  theme_minimal(base_size = 12) +
  theme(plot.title    = element_text(face = "bold", hjust = 0.5),
        plot.subtitle = element_text(hjust = 0.5),
        strip.text    = element_text(face = "bold"),
        panel.grid.minor = element_blank())

Code
abc_model <- ModelSIRCONN(
  name              = "ABC SIR",
  n                 = N,
  prevalence        = incidence_vec[1] / N,
  contact_rate      = abc_crate,
  transmission_rate = abc_ptran,
  recovery_rate     = recov
)
saver_abc <- make_saver("transition")

tic("ABC sims")
run_multiple(abc_model, ndays = model_ndays,
             nsims = 2000, saver = saver_abc, nthreads = 8)
Starting multiple runs (2000) using 8 thread(s)
_________________________________________________________________________
_________________________________________________________________________
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||| done.
Code
abc_sim_time <- toc()
ABC sims: 22.106 sec elapsed
Code
abc_sim_results <- run_multiple_get_results(
  abc_model, nthreads = 8, freader = data.table::fread)

abc_quantiles <- abc_sim_results$transition |>
  filter(from == "Susceptible", to == "Infected") |>
  group_by(date) |>
  summarize(lower_ci = quantile(counts, 0.025),
            upper_ci = quantile(counts, 0.975),
            median   = quantile(counts, 0.5), .groups = "drop")
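
Note that the interval above is built from a single parameter vector (the posterior medians), so, like the BiLSTM interval, it reflects simulation variance only. To fold the posterior’s parameter uncertainty into the interval, one could instead re-simulate across posterior draws. A sketch, reusing the posterior matrix loaded earlier (column order as in the histogram code: contact rate, recovery rate, transmission probability):

Code
# Sketch: propagate ABC parameter uncertainty by simulating at a sample
# of posterior draws rather than only at the posterior medians.
set.seed(1)
draw_ids <- sample(nrow(posterior), 100)
draw_models <- lapply(draw_ids, function(i) {
  m <- ModelSIRCONN(
    name              = "ABC posterior draw",
    n                 = N,
    prevalence        = incidence_vec[1] / N,
    contact_rate      = posterior[i, 1],
    transmission_rate = posterior[i, 3],
    recovery_rate     = recov
  )
  run(m, ndays = model_ndays)
  m
})
# Pooling incidence across these runs mixes parameter and simulation
# uncertainty, typically widening the interval.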

Results

Parameter estimates and speed

Code
bilstm_total_mins <- (bilstm_time$toc - bilstm_time$tic) / 60
speedup           <- abc_time_minutes / bilstm_total_mins

data.frame(
  Parameter = c("Contact Rate", "Transmission Prob", "R0",
                "Calibration time (min)"),
  BiLSTM    = c(round(lstm_params[["crate"]], 4),
                round(lstm_params[["ptran"]], 4),
                round(lstm_params[["R0"]],   4),
                round(bilstm_total_mins,      2)),
  ABC       = c(round(abc_crate,        4),
                round(abc_ptran,        4),
                round(abc_R0,           4),
                round(abc_time_minutes, 2))
) |> knitr::kable(caption = paste0(
  "BiLSTM is ", round(speedup, 0), "x faster than ABC"))
BiLSTM is 42x faster than ABC
Parameter                 BiLSTM     ABC
Contact Rate              2.8389  1.1789
Transmission Prob         0.0606  0.1184
R0                        1.2041  1.4898
Calibration time (min)    0.0900  3.7200

Prediction intervals

The ribbons show 95% prediction intervals from 2,000 stochastic simulations for each method, with observed counts overlaid. Both intervals should contain roughly 95% of observed points if the model is well calibrated.

Code
plot_df <- data.frame(
  Date          = covid_data$Date,
  Observed      = incidence_vec,
  BiLSTM_lower  = bilstm_quantiles$lower_ci,
  BiLSTM_upper  = bilstm_quantiles$upper_ci,
  BiLSTM_median = bilstm_quantiles$median,
  ABC_lower     = abc_quantiles$lower_ci,
  ABC_upper     = abc_quantiles$upper_ci,
  ABC_median    = abc_quantiles$median
)

ggplot(plot_df, aes(x = Date)) +
  geom_ribbon(aes(ymin = ABC_lower,    ymax = ABC_upper,
                  fill = "ABC 95% CI"),    alpha = 0.5) +
  geom_ribbon(aes(ymin = BiLSTM_lower, ymax = BiLSTM_upper,
                  fill = "BiLSTM 95% CI"), alpha = 0.5) +
  geom_line(aes(y = ABC_lower),    color = "#2E86DE",
            linetype = "dashed", linewidth = 0.7) +
  geom_line(aes(y = ABC_upper),    color = "#2E86DE",
            linetype = "dashed", linewidth = 0.7) +
  geom_line(aes(y = BiLSTM_lower), color = "#FF6B6B",
            linetype = "dashed", linewidth = 0.7) +
  geom_line(aes(y = BiLSTM_upper), color = "#FF6B6B",
            linetype = "dashed", linewidth = 0.7) +
  geom_line( aes(y = Observed, color = "Observed"), linewidth = 1.5) +
  geom_point(aes(y = Observed, color = "Observed"), size = 2.5) +
  scale_fill_manual(name = "",
    values = c("ABC 95% CI" = "#2E86DE", "BiLSTM 95% CI" = "#FF6B6B")) +
  scale_color_manual(name = "", values = c("Observed" = "black")) +
  labs(title    = "BiLSTM vs ABC: 95% Prediction Intervals",
       subtitle = paste0("BiLSTM: ", round(bilstm_total_mins, 1),
                         " min | ABC: ", round(abc_time_minutes, 1),
                         " min (pre-computed) | Speedup: ",
                         round(speedup, 1), "x"),
       x = "Date", y = "Daily cases") +
  theme_minimal(base_size = 14) +
  theme(legend.position  = "bottom",
        plot.title       = element_text(size = 15, face = "bold", hjust = 0.5),
        plot.subtitle    = element_text(hjust = 0.5),
        panel.grid.minor = element_blank())

Performance metrics

Code
coverage <- function(obs, lo, hi) mean(obs >= lo & obs <= hi) * 100
rmse     <- function(pred, obs)   sqrt(mean((pred - obs)^2))
mae_fn   <- function(pred, obs)   mean(abs(pred - obs))

data.frame(
  Metric = c("95% CI Coverage (%)", "RMSE", "MAE"),
  BiLSTM = c(
    round(coverage(plot_df$Observed, plot_df$BiLSTM_lower, plot_df$BiLSTM_upper), 1),
    round(rmse(plot_df$BiLSTM_median, plot_df$Observed), 2),
    round(mae_fn(plot_df$BiLSTM_median, plot_df$Observed), 2)),
  ABC = c(
    round(coverage(plot_df$Observed, plot_df$ABC_lower, plot_df$ABC_upper), 1),
    round(rmse(plot_df$ABC_median, plot_df$Observed), 2),
    round(mae_fn(plot_df$ABC_median, plot_df$Observed), 2))
) |> knitr::kable(caption = "Predictive performance: BiLSTM vs ABC")
Predictive performance: BiLSTM vs ABC
Metric                BiLSTM    ABC
95% CI Coverage (%)    59.00   3.30
RMSE                   21.86  29.79
MAE                    15.92  26.82

Summary

On this window the two methods differ markedly in predictive accuracy and interval coverage (see the table above), and they also differ in what they offer beyond a point estimate:

  • Use BiLSTM when speed matters — operational surveillance, real-time monitoring, or rapid initial screening. Calibration takes under a minute.
  • Use ABC when parameter uncertainty is the quantity of interest — research, policy analysis, or when credible intervals on \(R_0\) are needed.

A practical hybrid is to initialise the ABC MCMC chain near the BiLSTM estimate, which cuts burn-in time substantially.
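
In code, the hybrid amounts to choosing the chain’s starting point; schematically (the argument name below is illustrative, not the exact LFMCMC interface):

Code
# Schematic hybrid: start the ABC/LFMCMC chain at the BiLSTM estimate
# so the sampler begins near the posterior mode.
init <- c(lstm_params[["ptran"]], lstm_params[["crate"]])
# run_lfmcmc(..., params_init = init, ...)  # illustrative argument name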

Note

The ABC results here were produced with 3,000 MCMC samples (1,500 burn-in) and took roughly 4 minutes. The generating script is in data-raw/abc_calibration_results.R.