Add synthetic-EHR generative evaluation metrics#1148

Open
chufangao wants to merge 3 commits into sunlabuiuc:master from chufangao:chil26_evals2

Collaborator

@chufangao chufangao commented May 18, 2026

Summary

Adds a complete synthetic-EHR pipeline to PyHealth: a shared generation
task, five generator models, a generative metrics subpackage, an
end-to-end example, and tests for all of the above.

What's new

1. Shared task — pyhealth/tasks/generate_ehr.py

  • EHRGeneration — base class that emits one sample per patient,
    {"patient_id", "visits": [[code, ...], ...]}, processed by
    NestedSequenceProcessor. Unconditional (no output labels).
  • EHRGenerationMIMIC3 / EHRGenerationMIMIC4 — dataset-specific
    subclasses; only the event type and code attribute differ.
  • Helpers decode_dataset and to_evaluation_dataframe convert real
    SampleDatasets and any generator's generate() output into the
    long-form (id, time, visit_codes, labels) dataframe that the metrics
    consume.
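
To make the two shapes concrete, here is a minimal sketch of flattening the nested per-patient sample into long-form rows. It is not the PR's `to_evaluation_dataframe`: the helper name `visits_to_rows`, the use of the visit index as `time`, and the `None` label are assumptions, and plain dictionaries stand in for a pandas dataframe to keep the sketch dependency-free.

```python
from typing import Any, Dict, List

# Shape described in the PR: {"patient_id": ..., "visits": [[code, ...], ...]}
Sample = Dict[str, Any]


def visits_to_rows(samples: List[Sample]) -> List[Dict[str, Any]]:
    """Flatten nested samples into one row per (patient, visit).

    Hypothetical helper: the column names follow the PR description, but
    the row layout (visit index as `time`, `labels` set to None) is assumed.
    """
    rows = []
    for sample in samples:
        for visit_index, codes in enumerate(sample["visits"]):
            rows.append(
                {
                    "id": sample["patient_id"],
                    "time": visit_index,
                    "visit_codes": list(codes),  # copy so the input is not aliased
                    "labels": None,  # the task is unconditional: no labels
                }
            )
    return rows


samples = [
    {"patient_id": "p1", "visits": [["250.00", "401.9"], ["428.0"]]},
    {"patient_id": "p2", "visits": [["584.9"]]},
]
rows = visits_to_rows(samples)
for row in rows:
    print(row)
```

Because every generator emits the same nested structure, one conversion of this kind is enough to feed all of them into the same metrics.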

2. Generator models — pyhealth/models/generators/

All five share the EHRGeneration task and expose the same
train_model(train_dataset, val_dataset=None) / generate(num_samples) API.

| Model | File | Family |
| --- | --- | --- |
| HALO | halo.py | Autoregressive multi-hot transformer (Theodorou et al., 2023) |
| GPT2 | gpt2.py | Causal-LM sequence model over flattened code streams |
| PromptEHR | promptehr.py | BART denoising autoencoder with soft prompts (Wang & Sun, EMNLP'22) |
| MedGAN | medgan.py | Bag-of-codes VAE + GAN (Choi et al., MLHC'17) |
| CorGAN | corgan.py | Bag-of-codes correlation-aware GAN (Torfi & Fox, FLAIRS'20) |

Each model is registered as a sub-module so .parameters()/.to(device)
work, owns its own training loop (best-checkpoint saving included), and
returns decoded code strings from generate().
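
The shared API can be illustrated with a structural type and a toy generator. This is only a sketch under stated assumptions: `EHRGenerator`, `UnigramGenerator`, and `fit_and_sample` are hypothetical names that do not appear in the PR, and the exact structure returned by the real `generate()` is assumed to mirror the task's `{"patient_id", "visits"}` sample.

```python
import random
from collections import Counter
from typing import Any, Dict, List, Protocol


class EHRGenerator(Protocol):
    """Hypothetical structural type for the shared two-method API."""

    def train_model(self, train_dataset, val_dataset=None) -> None: ...

    def generate(self, num_samples: int) -> List[Dict[str, Any]]: ...


class UnigramGenerator:
    """Toy stand-in (not one of the five models): samples codes by frequency."""

    def __init__(self, seed: int = 0) -> None:
        self._rng = random.Random(seed)
        self._codes: List[str] = []
        self._weights: List[int] = []

    def train_model(self, train_dataset, val_dataset=None) -> None:
        # "Training" here is just counting how often each code occurs.
        counts = Counter(
            code
            for sample in train_dataset
            for visit in sample["visits"]
            for code in visit
        )
        self._codes = sorted(counts)
        self._weights = [counts[code] for code in self._codes]

    def generate(self, num_samples: int) -> List[Dict[str, Any]]:
        patients = []
        for i in range(num_samples):
            # Two visits of two codes each, drawn with replacement.
            visits = [
                self._rng.choices(self._codes, weights=self._weights, k=2)
                for _ in range(2)
            ]
            patients.append({"patient_id": f"synthetic_{i}", "visits": visits})
        return patients


def fit_and_sample(model: EHRGenerator, train, num_samples: int):
    """Works for any object exposing train_model() and generate()."""
    model.train_model(train)
    return model.generate(num_samples)


train = [
    {"patient_id": "p1", "visits": [["A", "B"], ["A"]]},
    {"patient_id": "p2", "visits": [["C"]]},
]
synthetic = fit_and_sample(UnigramGenerator(seed=0), train, num_samples=3)
print(synthetic)
```

Keeping the interface this small means downstream code such as `fit_and_sample` never needs to know which model family it is driving.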

3. Metrics subpackage — pyhealth/metrics/generative/

Evaluation along three axes:

  • privacy.py: calc_nnaar (Nearest Neighbor Adversarial Accuracy
    Risk), calc_membership_inference (membership inference attack), and
    compute_discriminator_privacy (real-vs-synthetic discriminator score).
  • utility.py: compute_mle (machine learning efficacy, comparing
    train-on-real/test-on-real, TRTR, against train-on-synthetic/test-on-real,
    TSTR) and compute_prevalence_metrics (code-prevalence similarity: R²,
    Pearson, RMSE).
  • utils.py: shared data prep, a self-contained LSTM classifier, and a
    random-forest baseline.
  • evaluate_synthetic_ehr(): convenience orchestrator that runs the
    full suite and returns one merged {metric: (mean, std)} dict.

Metrics operate on flat dataframes (id, time, visit_codes, labels),
so they work for any generator. Public functions are re-exported from
pyhealth.metrics.
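
As an illustration of the code-prevalence comparison, the sketch below computes RMSE, R², and Pearson's r between real and synthetic prevalence vectors. It is not the PR's `compute_prevalence_metrics`: the function names are hypothetical, prevalence is assumed to be the fraction of patients with a code at least once, R² is assumed to be the coefficient of determination against the real prevalences, and plain dictionaries stand in for dataframe rows.

```python
import math
from typing import Any, Dict, List, Set

Row = Dict[str, Any]  # assumed keys: "id" and "visit_codes"


def code_prevalence(rows: List[Row]) -> Dict[str, float]:
    """Fraction of patients that have each code at least once (assumed definition)."""
    codes_by_patient: Dict[Any, Set[str]] = {}
    for row in rows:
        codes_by_patient.setdefault(row["id"], set()).update(row["visit_codes"])
    counts: Dict[str, int] = {}
    for codes in codes_by_patient.values():
        for code in codes:
            counts[code] = counts.get(code, 0) + 1
    n_patients = len(codes_by_patient)
    return {code: count / n_patients for code, count in counts.items()}


def compare_prevalence(real_rows: List[Row], synth_rows: List[Row]) -> Dict[str, float]:
    """RMSE, R^2, and Pearson r over the union vocabulary.

    Assumes prevalences are not all identical (otherwise the variances are zero).
    """
    real, synth = code_prevalence(real_rows), code_prevalence(synth_rows)
    vocab = sorted(set(real) | set(synth))
    x = [real.get(code, 0.0) for code in vocab]
    y = [synth.get(code, 0.0) for code in vocab]
    n = len(vocab)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    ss_res = sum((a - b) ** 2 for a, b in zip(x, y))
    ss_x = sum((a - mean_x) ** 2 for a in x)
    ss_y = sum((b - mean_y) ** 2 for b in y)
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    return {
        "rmse": math.sqrt(ss_res / n),
        "r2": 1.0 - ss_res / ss_x,
        "pearson": cov / math.sqrt(ss_x * ss_y),
    }


real_rows = [
    {"id": "p1", "visit_codes": ["A", "B"]},
    {"id": "p2", "visit_codes": ["A"]},
    {"id": "p3", "visit_codes": ["A", "C"]},
]
shifted_rows = [
    {"id": "s1", "visit_codes": ["A"]},
    {"id": "s2", "visit_codes": ["A", "B"]},
    {"id": "s3", "visit_codes": ["B"]},
]
exact = compare_prevalence(real_rows, real_rows)
shifted = compare_prevalence(real_rows, shifted_rows)
print(exact)
print(shifted)
```

An exact copy gives zero RMSE and perfect R² and Pearson, while the shifted set scores worse on all three.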

Port cleanups

  • logging instead of bare print calls.
  • Fixed a latent CUDA crash in the LSTM eval loop: tensors are now moved
    to the CPU before conversion (.cpu().numpy()).
  • Replaced scipy.stats.pearsonr with numpy.corrcoef to avoid an
    undeclared scipy dependency.
  • Input dataframes are copied instead of mutated in place.
  • Google-style docstrings, type hints, PEP8 (≤88 chars).

4. End-to-end example — examples/halo_mimic3.py

Loads MIMIC-III → applies EHRGenerationMIMIC3 → trains HALO → generates
synthetic patients → runs the full evaluate_synthetic_ehr suite and prints
each (mean, std) metric. Verified to run end-to-end on the dev=True
subset (NNAAR, MIA, MLE TRTR/TSTR, discriminator privacy, prevalence all
produced).

5. Tests

All passing.

| File | Tests | Covers |
| --- | --- | --- |
| tests/core/test_generative_metrics.py | 18 | per-metric + orchestrator + behavioral sanity |
| tests/core/test_halo.py | 7 | init, train, generate, encoding |
| tests/core/test_gpt2.py | 7 | init, train, generate |
| tests/core/test_promptehr.py | 7 | init, train, generate, prompt handling |
| tests/core/test_medgan.py | 9 | VAE pretrain + GAN train + generate |
| tests/core/test_corgan.py | 12 | conv autoencoder + GAN + generate |

The metrics suite includes 5 behavioral tests that verify each metric
responds sensibly across three synthetic datasets — an exact copy of the
training data, a similar set (~15% of codes perturbed), and a different
set (disjoint code vocabulary):

| Metric | Verified behavior |
| --- | --- |
| Prevalence | RMSE 0 → 0.03 → 0.26; exact copy → RMSE 0, R²/Pearson = 1 |
| NNAAR | Flags memorization: 1.0 → 0.1 → 0.0 |
| Membership inference | Attack accuracy 1.0 → 0.94 → 0.46 (chance for unrelated data) |
| Discriminator privacy | Disjoint-vocabulary data trivially flagged; real-derived data is not |
| MLE (utility) | Exact copy reproduces real utility exactly; ratio degrades 1.0 → 0.98 → 0.81 |
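
The idea behind these behavioral checks can be sketched with a toy nearest-neighbor membership attack. This is not the PR's `calc_membership_inference` or its test fixtures: the record construction, the Jaccard distance, the 0.4 threshold, and all function names are assumptions chosen so the example is deterministic. Each patient is reduced to a bag of codes, and a record is flagged as a training member when some synthetic record lies within the threshold.

```python
from typing import FrozenSet, List

Record = FrozenSet[str]


def jaccard_distance(a: Record, b: Record) -> float:
    return 1.0 - len(a & b) / len(a | b)


def attack_accuracy(
    members: List[Record],
    non_members: List[Record],
    synthetic: List[Record],
    threshold: float = 0.4,
) -> float:
    """Accuracy of flagging a record as a member when a synthetic record is close."""

    def flagged(record: Record) -> bool:
        return min(jaccard_distance(record, s) for s in synthetic) <= threshold

    true_positives = sum(flagged(r) for r in members)
    true_negatives = sum(not flagged(r) for r in non_members)
    return (true_positives + true_negatives) / (len(members) + len(non_members))


def record(prefix: str, start: int, size: int = 6) -> Record:
    return frozenset(f"{prefix}{start + k}" for k in range(size))


def perturbed(i: int, n_replaced: int) -> Record:
    """Training record i with its first n_replaced codes swapped for new codes."""
    kept = {f"C{10 * i + k}" for k in range(n_replaced, 6)}
    new = {f"X{10 * i + k}" for k in range(n_replaced)}
    return frozenset(kept | new)


train = [record("C", 10 * i) for i in range(5)]        # members
holdout = [record("C", 10 * i + 3) for i in range(5)]  # non-members, partial overlap

exact_copy = list(train)
similar = [perturbed(i, i) for i in range(5)]          # record i has i codes replaced
different = [record("D", 10 * i) for i in range(5)]    # disjoint vocabulary

results = {
    "exact": attack_accuracy(train, holdout, exact_copy),
    "similar": attack_accuracy(train, holdout, similar),
    "different": attack_accuracy(train, holdout, different),
}
print(results)
```

In this toy setup the attack accuracy falls from 1.0 for the exact copy to 0.7 for the perturbed set and 0.5 (chance) for the disjoint set, the same direction the table reports for the real metric.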

Notes

  • MLE currently hard-codes next-visit prediction, which is degenerate for
    bag-of-codes generators (MedGAN, CorGAN) — those should be evaluated with
    metrics="privacy" plus the prevalence metrics. A future revision will
    let callers plug in a static-label task (mortality, readmission,
    "ever diagnosed with X") so MLE is meaningful for both families.
  • The discriminator-privacy score is degenerate for exact copies (the
    model predicts a constant on identical features, so the score reflects
    test-split balance rather than 0.5). The behavioral test asserts the
    robust direction — disjoint synthetic data is cleanly flagged while
    real-derived data is not.

chufangao and others added 2 commits May 17, 2026 23:46
Adds pyhealth/metrics/generative/, a subpackage for evaluating synthetic
EHR data along privacy, utility, and statistical-fidelity axes:

- privacy.py: NNAAR, membership inference attack, discriminator privacy
- utility.py: machine learning efficacy (TRTR vs TSTR), code-prevalence
  similarity (R2, Pearson, RMSE)
- utils.py: shared data prep, an LSTM classifier, and a random-forest
  baseline
- evaluate_synthetic_ehr(): convenience orchestrator for the full suite

These functions are ported from a standalone evaluation script. The
MIMIC-specific data-loading/CLI glue is dropped; the metrics work on any
flat EHR dataframe. Public functions are re-exported from
pyhealth.metrics. Adds unit tests in tests/core/test_generative_metrics.py
and Sphinx docs.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>


def calc_nnaar(
    train_ehr: pd.DataFrame,
Collaborator

Is there a reason why we do this in dataframes? Maybe we can chat about this.

Collaborator Author

@chufangao chufangao May 21, 2026

Yeah, we could actually change it to the nested sequence, since for this task it all reduces to a sequence of sequences. Edit: actually, maybe we should keep the dataframes for the utility calculation.
