Comparing BiLSTM and ABC Calibration for SIR Models

Utah COVID-19 Case Prediction with Uncertainty Quantification

Author

Sima Najafzadehkhoei

Published

January 27, 2026

Introduction

This analysis compares two approaches to epidemic model calibration:

  1. BiLSTM (Deep Learning): Fast neural network-based parameter estimation
  2. ABC (Approximate Bayesian Computation): Simulation-based Bayesian inference

Both methods calibrate SIR model parameters from observed COVID-19 incidence data, but differ dramatically in computational cost and uncertainty quantification.
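
Both methods estimate the same SIR quantities: a contact rate $c$, a per-contact transmission probability $p$, and a recovery rate $\gamma$. The basic reproduction number follows as

$$
R_0 = \frac{c \, p}{\gamma},
$$

which is the relationship used later in this analysis to derive R0 from the ABC posterior samples.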

Workflow Overview

  1. Load the most recent 61 days of Utah COVID-19 case data from the packaged dataset
  2. BiLSTM calibration: Estimate parameters using trained neural network
  3. ABC calibration: Load pre-computed parameters from package data
  4. Run 2,000 stochastic SIR simulations for each method
  5. Compare predictions, uncertainty, and computational time

Setup

Code
suppressPackageStartupMessages({
  library(tidyverse)
  library(ggplot2)
  library(epiworldR)
  library(epiworldRcalibrate)
  library(tictoc)
  library(data.table)
  library(knitr)
})

Data Acquisition

Load Packaged Data

Code
data("utah_covid_data")
data("abc_calibration_results")
Warning in data("abc_calibration_results"): data set 'abc_calibration_results'
not found
Code
# Preview the COVID data
head(utah_covid_data)
        Date Daily.Cases X3.Day.Moving.Average Smoothed.3.Day.Moving.Average
1 2024-05-07          34                  0.91                          0.84
2 2024-05-08          27                  0.97                          0.86
3 2024-05-09          30                  0.93                          0.87
4 2024-05-10          24                  0.83                          0.88
5 2024-05-11          22                  0.78                          0.89
6 2024-05-12          14                  0.62                          0.91
             Status
1 Incidence Plateau
2 Incidence Plateau
3 Incidence Plateau
4 Incidence Plateau
5 Incidence Plateau
6 Incidence Plateau
Code
summary(utah_covid_data)
      Date             Daily.Cases     X3.Day.Moving.Average
 Min.   :2024-05-07   Min.   : 14.00   Min.   :0.620        
 1st Qu.:2024-08-06   1st Qu.: 39.00   1st Qu.:1.260        
 Median :2024-11-05   Median : 54.00   Median :1.770        
 Mean   :2024-11-05   Mean   : 74.22   Mean   :2.283        
 3rd Qu.:2025-02-04   3rd Qu.: 97.00   3rd Qu.:2.920        
 Max.   :2025-05-06   Max.   :367.00   Max.   :6.970        
 Smoothed.3.Day.Moving.Average    Status         
 Min.   :0.780                 Length:365        
 1st Qu.:1.420                 Class :character  
 Median :1.900                 Mode  :character  
 Mean   :2.282                                   
 3rd Qu.:2.860                                   
 Max.   :4.960                                   

Load Recent Case Data

Code
get_covid_data <- function(n_days) {
  
  # Use packaged data
  covid_data <- utah_covid_data
  
  # Filter to last n_days
  last_date <- max(covid_data$Date, na.rm = TRUE)
  covid_data <- subset(covid_data, Date > (last_date - n_days)) %>% 
    dplyr::arrange(Date)
  
  stopifnot("Daily.Cases" %in% names(covid_data))
  covid_data
}

n_days <- 61
covid_data <- get_covid_data(n_days)
incidence <- covid_data$Daily.Cases
incidence_vec <- incidence

stopifnot(length(incidence) == n_days)

cat("Date range:", as.character(min(covid_data$Date)), "to", 
    as.character(max(covid_data$Date)), "\n")
Date range: 2025-03-07 to 2025-05-06 
Code
cat("Total cases observed:", sum(incidence), "\n")
Total cases observed: 2034 
Code
cat("Mean daily cases:", round(mean(incidence), 1), "\n")
Mean daily cases: 33.3 
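
Before calibrating, it helps to see the series being fit. A quick plot (a sketch using the columns shown above):

Code
# Sketch: the 61-day observed incidence series used for calibration
ggplot(covid_data, aes(x = Date, y = Daily.Cases)) +
  geom_line(color = "grey30") +
  geom_point(size = 1) +
  labs(title = "Observed Daily COVID-19 Cases, Utah",
       x = "Date", y = "Daily cases") +
  theme_minimal()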

Configuration

Code
N <- 20000     # Population size (BiLSTM)
recov <- 1 / 7  # Recovery rate (7-day infectious period)
model_ndays <- 60  # Simulation days

# The pre-computed ABC calibration used the same population size (N = 20000; see below)

Method 1: BiLSTM Calibration

Parameter Estimation

Code
cat("=== BiLSTM Calibration ===\n")
=== BiLSTM Calibration ===
Code
tic("BiLSTM calibration")
lstm_predictions <- calibrate_sir(
  daily_cases = incidence_vec,
  population_size = N,
  recovery_rate = recov
)
Using existing virtual environment: epiworldRcalibrate
Package 'numpy' found
Package 'sklearn' found
Package 'joblib' found
Package 'torch' found
Verifying package installation...
✓ numpy v2.2.6 loaded successfully
✓ sklearn v1.7.2 loaded successfully
✓ joblib v1.5.2 loaded successfully
✓ torch v2.9.1+cpu loaded successfully
All required Python packages are ready!
BiLSTM model loaded successfully.
Code
bilstm_time <- toc()
BiLSTM calibration: 4.026 sec elapsed
Code
cat("\nCalibrated Parameters:\n")

Calibrated Parameters:
Code
print(lstm_predictions)
     ptran      crate         R0 
0.06059097 2.83891812 1.20408964 
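
As a quick consistency check (a sketch; element names are taken from the printout above), the returned R0 should match contact rate × transmission probability / recovery rate:

Code
# Sketch: recompute R0 from the point estimates; it should agree with
# the R0 element returned by calibrate_sir()
implied_R0 <- lstm_predictions[["crate"]] * lstm_predictions[["ptran"]] / recov
cat("Returned R0:", lstm_predictions[["R0"]], "| Implied R0:", implied_R0, "\n")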

Stochastic Simulations (BiLSTM)

Code
cat("\n=== Running 2,000 Stochastic Simulations (BiLSTM) ===\n")

=== Running 2,000 Stochastic Simulations (BiLSTM) ===
Code
init_infected <- incidence[1]
prev <- init_infected / N

bilstm_model <- ModelSIRCONN(
  name              = "BiLSTM SIR",
  n                 = N,
  prevalence        = prev,
  contact_rate      = lstm_predictions[["crate"]],
  transmission_rate = lstm_predictions[["ptran"]],
  recovery_rate     = recov
)

saver_bilstm <- make_saver("transition")

tic("BiLSTM simulations")
run_multiple(
  bilstm_model,
  ndays    = model_ndays,
  nsims    = 2000,
  saver    = saver_bilstm,
  nthreads = 8
)
Starting multiple runs (2000) using 8 thread(s)
_________________________________________________________________________
_________________________________________________________________________
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||| done.
Code
bilstm_sim_time <- toc()
BiLSTM simulations: 18.684 sec elapsed
Code
bilstm_results <- run_multiple_get_results(
  bilstm_model,
  nthreads = 8,
  freader = data.table::fread
)

Extract BiLSTM Confidence Intervals

Code
# Daily Susceptible -> Infected transition counts are the simulated
# incident cases, matching what the observed series measures
bilstm_transitions <- bilstm_results$transition %>%
  filter(from == "Susceptible", to == "Infected") %>%
  arrange(sim_num, date)

# Pointwise 95% interval, median, and mean across the 2,000 runs
bilstm_quantiles <- bilstm_transitions %>%
  group_by(date) %>%
  summarize(
    lower_ci = quantile(counts, 0.025),
    upper_ci = quantile(counts, 0.975),
    median   = quantile(counts, 0.5),
    mean     = mean(counts),
    .groups  = "drop"
  )

Method 2: ABC Calibration

Load Pre-Calibrated ABC Results

Code
cat("\n=== ABC Calibration (Pre-computed) ===\n")

=== ABC Calibration (Pre-computed) ===
Code
# Load the calibration data
data("abc_calibration_params")

# Extract calibrated parameters
abc_crate <- abc_calibration_params$contact_rate
abc_recov <- abc_calibration_params$recovery_rate
abc_ptran <- abc_calibration_params$transmission_prob
abc_R0 <- abc_calibration_params$R0

# Extract confidence intervals
abc_lower <- c(
  abc_calibration_params$contact_rate_ci["lower"],
  abc_calibration_params$recovery_rate_ci["lower"],
  abc_calibration_params$transmission_prob_ci["lower"]
)
abc_upper <- c(
  abc_calibration_params$contact_rate_ci["upper"],
  abc_calibration_params$recovery_rate_ci["upper"],
  abc_calibration_params$transmission_prob_ci["upper"]
)
abc_median <- c(abc_crate, abc_recov, abc_ptran)

# Get timing information
abc_time_minutes <- abc_calibration_params$calibration_time_seconds / 60

# Get configuration
N_abc <- 20000  # population size used in the pre-computed ABC calibration

n_samples_abc <- abc_calibration_params$n_samples
burnin_abc <- abc_calibration_params$burnin

cat("ABC Configuration:\n")
ABC Configuration:
Code
cat("Population size:", N_abc, "\n")
Population size: 20000 
Code
cat("MCMC samples:", n_samples_abc, "(", burnin_abc, "burn-in)\n")
MCMC samples: 3000 ( 1500 burn-in)
Code
cat("Epsilon:", round(abc_calibration_params$epsilon, 2), "\n")
Epsilon: 10 
Code
cat("Original computation time:", round(abc_time_minutes, 2), "minutes\n\n")
Original computation time: 3.72 minutes
Code
cat("Calibrated Parameters (Median with 95% CI):\n")
Calibrated Parameters (Median with 95% CI):
Code
cat("Contact Rate:      ", sprintf("%.4f", abc_crate), 
    " [", sprintf("%.4f", abc_lower[1]), ", ", sprintf("%.4f", abc_upper[1]), "]\n", sep = "")
Contact Rate:      1.1789 [0.8912, 1.5443]
Code
cat("Recovery Rate:     ", sprintf("%.4f", abc_recov), 
    " [", sprintf("%.4f", abc_lower[2]), ", ", sprintf("%.4f", abc_upper[2]), "]\n", sep = "")
Recovery Rate:     0.0937 [0.0749, 0.1496]
Code
cat("Transmission Prob: ", sprintf("%.4f", abc_ptran), 
    " [", sprintf("%.4f", abc_lower[3]), ", ", sprintf("%.4f", abc_upper[3]), "]\n", sep = "")
Transmission Prob: 0.1184 [0.1006, 0.1554]
Code
cat("R0:                ", sprintf("%.4f", abc_R0), "\n", sep = "")
R0:                1.4898

ABC Posterior Distributions

Code
# Extract posterior samples
posterior <- abc_calibration_params$posterior_samples

# Calculate R0 from posterior samples
R0_samples <- (posterior[, 1] * posterior[, 3]) / posterior[, 2]

# Create data frame for plotting
posterior_df <- data.frame(
  contact_rate = posterior[, 1],
  recovery_rate = posterior[, 2],
  transmission_prob = posterior[, 3],
  R0 = R0_samples
)

# Reshape for faceted plot
posterior_long <- posterior_df %>%
  pivot_longer(cols = everything(), names_to = "parameter", values_to = "value") %>%
  mutate(parameter = factor(parameter, 
    levels = c("contact_rate", "recovery_rate", "transmission_prob", "R0"),
    labels = c("Contact Rate", "Recovery Rate", "Transmission Prob", "R0")
  ))

# Plot posterior distributions
ggplot(posterior_long, aes(x = value)) +
  geom_histogram(bins = 30, fill = "#2E86DE", alpha = 0.7, color = "white") +
  facet_wrap(~parameter, scales = "free", ncol = 4) +
  labs(
    title = "ABC Posterior Distributions",
    subtitle = paste0("Based on ", nrow(posterior), " MCMC samples (post burn-in)"),
    x = "Parameter Value",
    y = "Frequency"
  ) +
  theme_minimal(base_size = 12) +
  theme(
    plot.title = element_text(face = "bold", hjust = 0.5),
    plot.subtitle = element_text(hjust = 0.5),
    strip.text = element_text(face = "bold"),
    panel.grid.minor = element_blank()
  )
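
Beyond the marginal histograms, the joint posterior is informative: because contact rate and transmission probability enter R0 as a product, a negative correlation between them would indicate that the data identify their product better than either parameter alone. A quick check (a sketch using the posterior_df built above):

Code
# Sketch: pairwise posterior correlations; a strong negative correlation
# between contact_rate and transmission_prob suggests only their product
# is well identified
round(cor(posterior_df), 2)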

Stochastic Simulations (ABC)

Code
cat("\n=== Running 2,000 Stochastic Simulations (ABC) ===\n")

=== Running 2,000 Stochastic Simulations (ABC) ===
Code
initial_infected <- incidence_vec[1]
initial_prevalence <- initial_infected / N

abc_model <- ModelSIRCONN(
  name              = "ABC SIR",
  n                 = N,
  prevalence        = initial_prevalence,
  contact_rate      = abc_crate,
  transmission_rate = abc_ptran,
  recovery_rate     = abc_recov
)

saver_abc <- make_saver("transition")

tic("ABC simulations")
run_multiple(
  abc_model,
  ndays    = model_ndays,
  nsims    = 2000,
  saver    = saver_abc,
  nthreads = 8
)
Starting multiple runs (2000) using 8 thread(s)
_________________________________________________________________________
_________________________________________________________________________
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||| done.
Code
abc_sim_time <- toc()
ABC simulations: 18.365 sec elapsed
Code
abc_sim_results <- run_multiple_get_results(
  abc_model,
  nthreads = 8,
  freader = data.table::fread
)

Extract ABC Confidence Intervals

Code
# As with BiLSTM: daily Susceptible -> Infected transitions are the
# simulated incident cases
abc_transitions <- abc_sim_results$transition %>%
  filter(from == "Susceptible", to == "Infected") %>%
  arrange(sim_num, date)

# Pointwise 95% interval, median, and mean across the 2,000 runs
abc_quantiles <- abc_transitions %>%
  group_by(date) %>%
  summarize(
    lower_ci = quantile(counts, 0.025),
    upper_ci = quantile(counts, 0.975),
    median   = quantile(counts, 0.5),
    mean     = mean(counts),
    .groups  = "drop"
  )

Comparison

Timing Summary

Code
bilstm_total_mins <- (bilstm_time$toc - bilstm_time$tic) / 60
speedup <- abc_time_minutes / bilstm_total_mins

cat("=== Computational Time Comparison ===\n\n")
=== Computational Time Comparison ===
Code
cat("BiLSTM calibration:", round(bilstm_total_mins, 2), "minutes\n")
BiLSTM calibration: 0.07 minutes
Code
cat("ABC calibration:", round(abc_time_minutes, 2), "minutes (pre-computed)\n")
ABC calibration: 3.72 minutes (pre-computed)
Code
cat("\nSpeedup: BiLSTM is", round(speedup, 1), "x FASTER than ABC\n")

Speedup: BiLSTM is 55.5 x FASTER than ABC

Parameter Comparison

Code
comparison_df <- data.frame(
  Parameter = c("Contact Rate", "Recovery Rate", "Trans. Prob", "R0"),
  BiLSTM = c(
    lstm_predictions[["crate"]],
    recov,
    lstm_predictions[["ptran"]],
    lstm_predictions[["R0"]]
  ),
  ABC = c(abc_crate, abc_recov, abc_ptran, abc_R0),
  Difference = c(
    lstm_predictions[["crate"]] - abc_crate,
    recov - abc_recov,
    lstm_predictions[["ptran"]] - abc_ptran,
    lstm_predictions[["R0"]] - abc_R0
  ),
  Pct_Diff = c(
    (lstm_predictions[["crate"]] - abc_crate) / abc_crate * 100,
    (recov - abc_recov) / abc_recov * 100,
    (lstm_predictions[["ptran"]] - abc_ptran) / abc_ptran * 100,
    (lstm_predictions[["R0"]] - abc_R0) / abc_R0 * 100
  )
)

knitr::kable(comparison_df, digits = 4, 
             caption = "Parameter Estimates: BiLSTM vs ABC")
Parameter Estimates: BiLSTM vs ABC

Parameter       BiLSTM     ABC  Difference   Pct_Diff
Contact Rate    2.8389  1.1789      1.6600   140.8026
Recovery Rate   0.1429  0.0937      0.0491    52.4241
Trans. Prob     0.0606  0.1184     -0.0578   -48.8402
R0              1.2041  1.4898     -0.2857   -19.1768
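
The large per-parameter discrepancies partly offset one another: contact rate and transmission probability enter the force of infection approximately through their product, so the two calibrations disagree far less on the effective transmission rate than on either factor, which is why the R0 estimates end up within ~20% of each other. A quick check (a sketch reusing the objects above):

Code
# Sketch: effective transmission rate (contact rate x transmission prob);
# the methods agree more closely on this product than on either factor
beta_bilstm <- lstm_predictions[["crate"]] * lstm_predictions[["ptran"]]
beta_abc    <- abc_crate * abc_ptran
cat("Effective beta - BiLSTM:", round(beta_bilstm, 4),
    "| ABC:", round(beta_abc, 4), "\n")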

Parameter Uncertainty Analysis

Code
# Calculate R0 statistics from posterior
R0_samples_calc <- (posterior[, 1] * posterior[, 3]) / posterior[, 2]
abc_R0_lower <- quantile(R0_samples_calc, 0.025)
abc_R0_upper <- quantile(R0_samples_calc, 0.975)
abc_R0_sd <- sd(R0_samples_calc)

# Calculate SD for other parameters
abc_sd <- apply(posterior, 2, sd)

# Create summary table with uncertainties
uncertainty_df <- data.frame(
  Parameter = c("Contact Rate", "Recovery Rate", "Transmission Prob", "R0"),
  ABC_Median = c(
    abc_median[1], 
    abc_median[2], 
    abc_median[3], 
    abc_R0
  ),
  ABC_Lower = c(
    abc_lower[1], 
    abc_lower[2], 
    abc_lower[3], 
    abc_R0_lower
  ),
  ABC_Upper = c(
    abc_upper[1], 
    abc_upper[2], 
    abc_upper[3], 
    abc_R0_upper
  ),
  ABC_SD = c(
    abc_sd[1],
    abc_sd[2],
    abc_sd[3],
    abc_R0_sd
  ),
  BiLSTM_Point = c(
    lstm_predictions[["crate"]],
    recov,
    lstm_predictions[["ptran"]],
    lstm_predictions[["R0"]]
  )
)

knitr::kable(uncertainty_df, digits = 4,
             caption = "Parameter Uncertainty: ABC provides full posterior distributions")
Parameter Uncertainty: ABC provides full posterior distributions

Parameter          ABC_Median  ABC_Lower  ABC_Upper  ABC_SD  BiLSTM_Point
Contact Rate           1.1789     0.8912     1.5443  0.1705        2.8389
Recovery Rate          0.0937     0.0749     0.1496  0.0182        0.1429
Transmission Prob      0.1184     0.1006     0.1554  0.0141        0.0606
R0                     1.4898     1.2160     1.7656  0.1387        1.2041

Combined Visualization

Code
# Prepare data
plot_df <- data.frame(
  Date = covid_data$Date,
  Observed = incidence,
  BiLSTM_median = bilstm_quantiles$median,
  BiLSTM_lower = bilstm_quantiles$lower_ci,
  BiLSTM_upper = bilstm_quantiles$upper_ci,
  ABC_median = abc_quantiles$median,
  ABC_lower = abc_quantiles$lower_ci,
  ABC_upper = abc_quantiles$upper_ci
)

# Main comparison plot with distinct colors and boundary lines
p_comparison <- ggplot(plot_df, aes(x = Date)) +
  # ABC CI ribbon (bright blue)
  geom_ribbon(
    aes(ymin = ABC_lower, ymax = ABC_upper, fill = "ABC 95% CI"),
    alpha = 0.5
  ) +
  # BiLSTM CI ribbon (bright coral/red)
  geom_ribbon(
    aes(ymin = BiLSTM_lower, ymax = BiLSTM_upper, fill = "BiLSTM 95% CI"),
    alpha = 0.5
  ) +
  # ABC boundary lines (dashed blue)
  geom_line(
    aes(y = ABC_lower),
    color = "#2E86DE",
    linetype = "dashed",
    linewidth = 0.8
  ) +
  geom_line(
    aes(y = ABC_upper),
    color = "#2E86DE",
    linetype = "dashed",
    linewidth = 0.8
  ) +
  # BiLSTM boundary lines (dashed red)
  geom_line(
    aes(y = BiLSTM_lower),
    color = "#FF6B6B",
    linetype = "dashed",
    linewidth = 0.8
  ) +
  geom_line(
    aes(y = BiLSTM_upper),
    color = "#FF6B6B",
    linetype = "dashed",
    linewidth = 0.8
  ) +
  # Observed data (thick black line and points)
  geom_line(
    aes(y = Observed, color = "Observed"),
    linewidth = 1.5
  ) +
  geom_point(
    aes(y = Observed, color = "Observed"),
    size = 3
  ) +
  scale_color_manual(
    name = "",
    values = c("Observed" = "black"),
    labels = c("Observed Data")
  ) +
  scale_fill_manual(
    name = "",
    values = c(
      "ABC 95% CI" = "#2E86DE",      # Bright blue
      "BiLSTM 95% CI" = "#FF6B6B"    # Bright coral/red
    ),
    labels = c("ABC 95% CI", "BiLSTM 95% CI")
  ) +
  labs(
    title = "Comparison: BiLSTM vs ABC Calibration",
    subtitle = paste0(
      "BiLSTM: ", round(bilstm_total_mins, 1), " min | ",
      "ABC: ", round(abc_time_minutes, 1), " min (pre-computed) | ",
      "Speedup: ", round(speedup, 1), "x"
    ),
    x = "Date",
    y = "Daily Infected Count"
  ) +
  theme_minimal(base_size = 14) +
  theme(
    legend.position = "bottom",
    plot.title = element_text(size = 16, face = "bold", hjust = 0.5),
    plot.subtitle = element_text(hjust = 0.5),
    legend.box = "vertical",
    panel.grid.minor = element_blank()
  ) +
  guides(
    color = guide_legend(order = 1),
    fill = guide_legend(order = 2)
  )

print(p_comparison)

Model Performance

Code
# BiLSTM metrics
bilstm_coverage <- plot_df %>%
  mutate(in_ci = Observed >= BiLSTM_lower & Observed <= BiLSTM_upper) %>%
  summarize(coverage = mean(in_ci) * 100) %>%
  pull(coverage)

bilstm_rmse <- sqrt(mean((plot_df$BiLSTM_median - plot_df$Observed)^2))
bilstm_mae <- mean(abs(plot_df$BiLSTM_median - plot_df$Observed))

# ABC metrics
abc_coverage <- plot_df %>%
  mutate(in_ci = Observed >= ABC_lower & Observed <= ABC_upper) %>%
  summarize(coverage = mean(in_ci) * 100) %>%
  pull(coverage)

abc_rmse <- sqrt(mean((plot_df$ABC_median - plot_df$Observed)^2))
abc_mae <- mean(abs(plot_df$ABC_median - plot_df$Observed))

# Create comparison table
metrics_df <- data.frame(
  Metric = c("95% CI Coverage (%)", "RMSE", "MAE", "Calibration Time (min)"),
  BiLSTM = c(
    round(bilstm_coverage, 1),
    round(bilstm_rmse, 2),
    round(bilstm_mae, 2),
    round(bilstm_total_mins, 2)
  ),
  ABC = c(
    round(abc_coverage, 1),
    round(abc_rmse, 2),
    round(abc_mae, 2),
    round(abc_time_minutes, 2)
  )
)

knitr::kable(metrics_df, 
             caption = "Model Performance Comparison")
Model Performance Comparison

Metric                  BiLSTM    ABC
95% CI Coverage (%)      62.30  49.20
RMSE                     21.68  25.48
MAE                      15.75  20.36
Calibration Time (min)    0.07   3.72
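
Coverage alone does not tell the whole story: interval width (sharpness) matters too, since a wider band will trivially cover more points. A supplementary check (a sketch using the plot_df columns defined above):

Code
# Sketch: mean 95% interval width per method; narrower is sharper,
# but must be weighed against the coverage reported above
mean(plot_df$BiLSTM_upper - plot_df$BiLSTM_lower)
mean(plot_df$ABC_upper - plot_df$ABC_lower)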

Summary

This analysis demonstrates that:

  1. BiLSTM is dramatically faster: calibration took about 4 seconds versus 3.7 minutes for ABC (~55x speedup)
  2. Point accuracy is comparable, with BiLSTM slightly ahead on both RMSE (21.7 vs 25.5) and MAE (15.8 vs 20.4)
  3. Both methods' 95% intervals undercover on this series (62.3% for BiLSTM, 49.2% for ABC), well below the nominal level
  4. BiLSTM delivers near-instant predictions, making it well suited to real-time applications
  5. ABC provides full Bayesian inference, with credible intervals for every parameter

Key Advantages

BiLSTM:

  - Extremely fast calibration (<1 minute)
  - Point estimates for rapid decision-making
  - Suitable for operational deployment
  - Real-time epidemic monitoring

ABC:

  - Complete posterior distributions
  - Full uncertainty quantification
  - Credible intervals for all parameters
  - Ideal for research and policy analysis
  - Better for understanding parameter uncertainty

Recommendation

The choice between methods depends on the use case:

  - BiLSTM for speed and operational deployment
  - ABC for comprehensive uncertainty quantification and parameter inference
  - Hybrid approach: use BiLSTM for initial screening, then ABC for detailed analysis

Package Data Note

This vignette uses pre-computed ABC calibration results stored in abc_calibration_params. The ABC calibration was performed using LFMCMC with 3000 samples (1500 burn-in) and took approximately 4 minutes to complete.

To re-run the ABC calibration yourself, see the script in data-raw/abc_calibration_results.R.
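
For intuition, the sketch below shows a generic ABC rejection scheme. It is illustrative only, not the LFMCMC routine used to produce the packaged results; simulate_incidence is a hypothetical stand-in for a function that runs one stochastic SIR simulation and returns a daily incidence series, and the prior ranges are assumptions.

Code
# Generic ABC rejection sketch (illustrative; NOT the packaged LFMCMC):
# draw parameters from priors, simulate, keep draws whose simulated
# incidence lies within epsilon (RMSE) of the observed series
abc_rejection <- function(observed, n_draws, epsilon, simulate_incidence) {
  kept <- list()
  for (i in seq_len(n_draws)) {
    theta <- c(
      crate = runif(1, 0.5, 5),      # assumed prior ranges, for illustration
      ptran = runif(1, 0.01, 0.3),
      recov = runif(1, 1 / 14, 1 / 3)
    )
    sim <- simulate_incidence(theta)
    if (sqrt(mean((sim - observed)^2)) < epsilon) {
      kept[[length(kept) + 1]] <- theta
    }
  }
  do.call(rbind, kept)  # one row per accepted draw
}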