Introduction to LSTM-based Calibration (One-Line Usage)

Author

Sima Najafzadehkhoei

Published

April 2, 2026

Overview

This vignette demonstrates the core epiworldRcalibrate workflow: simulate a synthetic epidemic with known parameters, calibrate those parameters back from the incidence curve using calibrate_sir(), and assess how closely the calibrated model reproduces the original dynamics.

Using a synthetic (simulated) epidemic instead of real data is intentional: because the ground-truth parameters are known, we can measure calibration error directly rather than relying on goodness-of-fit alone.

Note

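No manual model setup is required: calibrate_sir() loads the pretrained BiLSTM model internally on its first call. The package does rely on a Python backend (the calibration step below checks for numpy, scikit-learn, joblib, and torch), so if you are running epiworldRcalibrate for the first time on this machine, run setup_python_deps(force = TRUE) once beforehand.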


Setup

library(tidyverse)
library(ggplot2)
library(patchwork)
library(epiworldR)
library(epiworldRcalibrate)

Simulate a ground-truth epidemic

We draw SIR parameters at random and simulate 60 days of epidemic dynamics. The transmission probability ptran is derived from a randomly drawn R0 via the identity \(\text{ptran} = R_0 \times \text{recov} / \text{crate}\), ensuring internal consistency among the parameters.

The incidence vector must have exactly 61 values (days 0 through 60) — this is the sequence length the BiLSTM was trained on.
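Because the network expects exactly 61 values, a small validation step before calling calibrate_sir() can catch malformed input early, which is useful when adapting this workflow to other data. The sketch below uses a hypothetical helper, check_incidence(), that is not part of epiworldRcalibrate; the non-negativity and no-NA checks are assumptions about what sensible daily case counts look like, not documented package requirements.

```r
# Hypothetical helper (not part of epiworldRcalibrate): fails early, with an
# informative message, if an incidence vector does not have the shape the
# BiLSTM expects (61 daily values, days 0 through 60).
check_incidence <- function(x, expected_len = 61L) {
  if (!is.numeric(x)) {
    stop("incidence must be numeric")
  }
  if (length(x) != expected_len) {
    stop(sprintf("incidence has %d values; expected %d (days 0 to %d)",
                 length(x), expected_len, expected_len - 1L))
  }
  # Assumption: daily case counts should be complete and non-negative
  if (anyNA(x) || any(x < 0)) {
    stop("incidence must be non-negative with no missing values")
  }
  invisible(x)
}

# Usage in this vignette, after incidence_ts is created:
# check_incidence(incidence_ts)
```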

set.seed(122)

n_value  <- sample(5000:10000, 1)
preval   <- runif(1, 0.007, 0.02)
crate    <- runif(1, 1, 5)
recov    <- runif(1, 0.071, 0.25)
R0_true  <- runif(1, 1.1, 5)
ptran    <- R0_true * recov / crate

true_params <- tibble(
  n                 = n_value,
  prevalence        = preval,
  contact_rate      = crate,
  transmission_rate = ptran,
  recovery_rate     = recov,
  R0                = R0_true
)

true_params
# A tibble: 1 × 6
      n prevalence contact_rate transmission_rate recovery_rate    R0
  <int>      <dbl>        <dbl>             <dbl>         <dbl> <dbl>
1  6967     0.0188         1.76             0.149        0.0783  3.36
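As a worked check of the identity, plugging in the rounded values printed above recovers the displayed transmission probability:

```latex
\[
\text{ptran} = \frac{R_0 \times \text{recov}}{\text{crate}}
             = \frac{3.36 \times 0.0783}{1.76}
             \approx 0.149
\]
```

Read the other way, \(\text{crate} \times \text{ptran}\) is the expected number of transmissions per infected individual per day, and \(1/\text{recov} \approx 12.8\) days is the mean infectious period, so \(R_0\) is roughly their product.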
ndays      <- 60
true_model <- ModelSIRCONN(
  name              = "true_simulation",
  n                 = true_params$n,
  prevalence        = true_params$prevalence,
  contact_rate      = true_params$contact_rate,
  transmission_rate = true_params$transmission_rate,
  recovery_rate     = true_params$recovery_rate
)
run(true_model, ndays = ndays)
incidence_ts <- plot_incidence(true_model, plot = FALSE)[, 1]
stopifnot(length(incidence_ts) == 61)

Calibrate with BiLSTM

calibrate_sir() passes the 61-value incidence vector (days 0 to 60) through the pretrained BiLSTM network and returns estimates for ptran, crate, and R0 in a single call. The recovery rate is supplied as a known input, here the same value used to simulate the data. Once recov is known, ptran, crate, and R0 are tied together by the identity \(R_0 = \text{ptran} \times \text{crate} / \text{recov}\), so only two of the three returned quantities are free.

lstm_predictions <- calibrate_sir(
  daily_cases     = incidence_ts,
  population_size = true_params$n,
  recovery_rate   = true_params$recovery_rate
)
Verifying Python packages...
[OK] numpy v2.0.2
[OK] sklearn v1.6.1
[OK] joblib v1.5.0
[OK] torch v2.7.0+cu126
All required Python packages are ready!
BiLSTM model loaded successfully.
lstm_predictions
    ptran     crate        R0 
0.1105276 2.4557295 3.4654543 
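Since the three estimates are linked by the identity above, a quick consistency check is to compare the returned R0 with the R0 implied by the returned ptran and crate. The sketch below uses a hypothetical helper, implied_r0(), and copies the printed values so it runs on its own; it does not assume anything about how calibrate_sir() computes R0 internally.

```r
# Hypothetical helper: R0 implied by the SIR identity R0 = crate * ptran / recov
implied_r0 <- function(ptran, crate, recov) {
  crate * ptran / recov
}

# Estimates as printed by lstm_predictions above; recov is the true recovery
# rate rounded to 4 decimals, so a small gap is expected from rounding alone.
est   <- c(ptran = 0.1105276, crate = 2.4557295, R0 = 3.4654543)
recov <- 0.0783

r0_identity <- implied_r0(est[["ptran"]], est[["crate"]], recov)
rel_gap     <- abs(r0_identity - est[["R0"]]) / est[["R0"]]
print(c(identity = r0_identity, returned = est[["R0"]], rel_gap = rel_gap))
```

In the vignette itself, pass lstm_predictions[["ptran"]], lstm_predictions[["crate"]], and true_params$recovery_rate instead of the copied values. With the printed numbers the two R0 values differ by about 0.03%, a gap small enough to be explained by the rounding of recov.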

We build a comparison table to see how far the estimates are from the ground truth:

lstm_params <- tibble(
  n                 = true_params$n,
  prevalence        = true_params$prevalence,
  contact_rate      = lstm_predictions[["crate"]],
  transmission_rate = lstm_predictions[["ptran"]],
  recovery_rate     = true_params$recovery_rate,
  R0                = lstm_predictions[["R0"]]
)

params_comparison <- bind_rows(
  true_params |> mutate(param_type = "true"),
  lstm_params  |> mutate(param_type = "lstm")
)

params_comparison
# A tibble: 2 × 7
      n prevalence contact_rate transmission_rate recovery_rate    R0 param_type
  <int>      <dbl>        <dbl>             <dbl>         <dbl> <dbl> <chr>     
1  6967     0.0188         1.76             0.149        0.0783  3.36 true      
2  6967     0.0188         2.46             0.111        0.0783  3.47 lstm      

Forward simulations: true vs calibrated

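To compare dynamics rather than just parameter values, we run 100 stochastic SIR simulations under each parameter set and summarize, for each day and SIR state, the mean count and the 2.5% and 97.5% quantiles across replicates (an empirical 95% simulation interval). The plot shades this interval for the LSTM-calibrated runs only. If the BiLSTM estimates are accurate, the mean trajectory under the true parameters should stay close to, and mostly inside, the calibrated band.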

n_reps                <- 100
all_simulation_results <- tibble()

for (i in seq_len(nrow(params_comparison))) {
  row <- params_comparison[i, ]

  fwd_model <- ModelSIRCONN(
    name              = paste0("forward_", row$param_type),
    n                 = row$n,
    prevalence        = row$prevalence,
    contact_rate      = row$contact_rate,
    transmission_rate = row$transmission_rate,
    recovery_rate     = row$recovery_rate
  )

  saver <- make_saver("total_hist")
  run_multiple(fwd_model, ndays = ndays, nsims = n_reps,
               saver = saver, nthreads = 8)
  results <- run_multiple_get_results(fwd_model, nthreads = 2)

  sim_data <- results$total_hist |>
    group_by(date, state) |>
    summarize(mean_count = mean(counts),
              ci_lower   = quantile(counts, 0.025),
              ci_upper   = quantile(counts, 0.975),
              .groups = "drop") |>
    mutate(param_type = row$param_type)

  all_simulation_results <- bind_rows(all_simulation_results, sim_data)
}
Starting multiple runs (100) using 8 thread(s)
Starting multiple runs (100) using 8 thread(s)
method_colors <- c("true" = "#440154FF", "lstm" = "#35B779FF")

sir_panel <- function(data, state_name, title) {
  pd <- filter(data, state == state_name)
  ggplot(pd, aes(x = date, color = param_type)) +
    geom_ribbon(data = filter(pd, param_type == "lstm"),
                aes(ymin = ci_lower, ymax = ci_upper, fill = param_type),
                alpha = 0.2, color = NA) +
    geom_line(aes(y = mean_count), linewidth = 1.1) +
    scale_color_manual(values = method_colors) +
    scale_fill_manual(values  = method_colors) +
    labs(title = title, x = "Day", y = "Count", color = "Method", fill = "Method") +
    theme_minimal() +
    theme(legend.position = "bottom",
          plot.title = element_text(size = 12, hjust = 0.5))
}

(sir_panel(all_simulation_results, "Susceptible", "Susceptible") /
 sir_panel(all_simulation_results, "Infected",    "Infected")    /
 sir_panel(all_simulation_results, "Recovered",   "Recovered")) +
  plot_layout(guides = "collect") +
  plot_annotation(
    title    = "SIR Dynamics: True vs LSTM-Calibrated Parameters",
    subtitle = paste0("Mean trajectories and LSTM 95% simulation interval across ", n_reps, " simulations")
  )
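The plot shows the overlap visually; one way to put a number on it is the share of days on which the mean trajectory under the true parameters falls inside the LSTM 2.5% to 97.5% band. The sketch below is a hypothetical metric (band_coverage() is not provided by epiworldR or epiworldRcalibrate), written in base R and assuming the column names produced by the summarize() step above (date, state, param_type, mean_count, ci_lower, ci_upper).

```r
# Hypothetical helper (not part of epiworldR or epiworldRcalibrate).
# For each SIR state, returns the share of days on which the mean trajectory
# under the true parameters lies inside the LSTM 2.5%-97.5% simulation band.
band_coverage <- function(sim_summary) {
  truth <- sim_summary[sim_summary$param_type == "true",
                       c("date", "state", "mean_count")]
  lstm  <- sim_summary[sim_summary$param_type == "lstm",
                       c("date", "state", "ci_lower", "ci_upper")]
  # Align the two parameter sets day by day and state by state
  merged <- merge(truth, lstm, by = c("date", "state"))
  inside <- merged$mean_count >= merged$ci_lower &
            merged$mean_count <= merged$ci_upper
  states <- sort(unique(as.character(merged$state)))
  # Named numeric vector: one coverage fraction per state
  vapply(states, function(s) mean(inside[merged$state == s]), numeric(1))
}

# Usage in this vignette: band_coverage(all_simulation_results)
```

Values near 1 for every state mean the true-parameter trajectory is essentially indistinguishable from the calibrated ensemble; note that a band computed from 100 replicates is itself noisy, so occasional excursions are expected.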


Parameter bias

The table below shows the ground-truth values, the BiLSTM estimates, the bias (estimate minus truth), and the percentage error for each estimated parameter. A bias near zero means the network recovered the true value well; larger biases indicate where the model struggled, which may reflect the inherent non-identifiability of the SIR system: different combinations of ptran and crate can produce similar incidence curves.

tibble(
  Parameter = c("Contact Rate", "Transmission Rate", "R0"),
  True      = c(true_params$contact_rate,
                true_params$transmission_rate,
                true_params$R0),
  LSTM      = c(lstm_params$contact_rate,
                lstm_params$transmission_rate,
                lstm_params$R0)
) |>
  mutate(Bias    = round(LSTM - True, 4),
         Pct_Err = round((LSTM - True) / True * 100, 1)) |>
  knitr::kable(digits = 4,
               caption = "BiLSTM parameter estimates vs ground truth",
               col.names = c("Parameter", "True", "LSTM",
                             "Bias", "% Error"))
BiLSTM parameter estimates vs ground truth
Parameter True LSTM Bias % Error
Contact Rate 1.7622 2.4557 0.6935 39.4
Transmission Rate 0.1493 0.1105 -0.0387 -26.0
R0 3.3584 3.4655 0.1071 3.2
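The non-identifiability mentioned above can be seen directly in these numbers. If the incidence curve mainly constrains the product crate × ptran (the expected transmissions per infected individual per day, written beta here purely as a label), then errors in the two factors can offset each other. The sketch below copies the rounded values from the table and compares percentage errors; beta and pct_err() are names introduced only for this illustration.

```r
# Estimates copied from the bias table above (rounded to 4 decimals).
true_vals <- c(crate = 1.7622, ptran = 0.1493)
lstm_vals <- c(crate = 2.4557, ptran = 0.1105)

# Illustrative summary: crate * ptran, the expected number of transmissions
# per infected individual per day (labelled beta here).
beta_true <- true_vals[["crate"]] * true_vals[["ptran"]]
beta_lstm <- lstm_vals[["crate"]] * lstm_vals[["ptran"]]

# Percentage error of an estimate relative to the true value
pct_err <- function(est, truth) 100 * (est - truth) / truth

errors <- c(
  crate = pct_err(lstm_vals[["crate"]], true_vals[["crate"]]),
  ptran = pct_err(lstm_vals[["ptran"]], true_vals[["ptran"]]),
  beta  = pct_err(beta_lstm, beta_true)
)
print(round(errors, 1))
```

Here the +39% error in crate and the -26% error in ptran largely cancel: their product is off by roughly 3%, close to the 3.2% error in R0 (which, with recov fixed, is proportional to that product).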