Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/api/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ For applicable tasks, we provide the relevant metrics for model calibration, as
Among these we also provide metrics related to uncertainty quantification, for model calibration, as well as metrics that measure the quality of prediction sets
We also provide other metrics specically for healthcare
tasks, such as drug drug interaction (DDI) rate.
For synthetic (generative) EHR data, we provide privacy, utility, and statistical
fidelity metrics.


.. toctree::
Expand All @@ -19,3 +21,4 @@ tasks, such as drug drug interaction (DDI) rate.
metrics/pyhealth.metrics.prediction_set
metrics/pyhealth.metrics.fairness
metrics/pyhealth.metrics.interpretability
metrics/pyhealth.metrics.generative
25 changes: 25 additions & 0 deletions docs/api/metrics/pyhealth.metrics.generative.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
pyhealth.metrics.generative
===================================

Evaluation metrics for synthetic (generative) EHR data, covering privacy,
utility, and statistical fidelity.

.. currentmodule:: pyhealth.metrics.generative

.. autofunction:: evaluate_synthetic_ehr

Privacy metrics
-------------------------------------

.. autofunction:: calc_nnaar

.. autofunction:: calc_membership_inference

.. autofunction:: compute_discriminator_privacy

Utility and fidelity metrics
-------------------------------------

.. autofunction:: compute_mle

.. autofunction:: compute_prevalence_metrics
120 changes: 120 additions & 0 deletions examples/halo_mimic3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Example: train HALO on MIMIC-III and generate synthetic patients.

This example demonstrates:
1. Loading MIMIC-III data
2. Applying the EHRGenerationMIMIC3 task (per-visit ICD-9 code sequences)
3. Creating a SampleDataset with a NestedSequenceProcessor
4. Training the HALO generator with its custom training loop
5. Generating synthetic patients
6. Evaluating the synthetic data with the generative metrics suite
"""

import pandas as pd

from pyhealth.datasets import MIMIC3Dataset, split_by_patient
from pyhealth.metrics.generative import evaluate_synthetic_ehr
from pyhealth.models import HALO
from pyhealth.tasks import EHRGenerationMIMIC3

if __name__ == "__main__":
# STEP 1: Load MIMIC-III base dataset
base_dataset = MIMIC3Dataset(
root="/srv/local/data/MIMIC-III/mimic-iii-clinical-database-1.4",
tables=["diagnoses_icd"],
dev=True,
)

# STEP 2: Apply the EHR generation task (unconditional, no labels).
# This task is shared by all generators in pyhealth.models.generators.
sample_dataset = base_dataset.set_task(EHRGenerationMIMIC3())
print(f"Total samples: {len(sample_dataset)}")
print(f"Input schema: {sample_dataset.input_schema}")
print(f"Output schema: {sample_dataset.output_schema}")

sample = sample_dataset[0]
print("\nSample structure:")
print(f" Patient ID: {sample['patient_id']}")
print(f" Visits tensor shape: {tuple(sample['visits'].shape)}")

# STEP 3: Split dataset by patient
train_dataset, val_dataset, test_dataset = split_by_patient(
sample_dataset, [0.8, 0.1, 0.1]
)

# STEP 4: Initialize HALO (small config for the dev subset)
model = HALO(
dataset=sample_dataset,
embed_dim=128,
n_heads=4,
n_layers=4,
n_ctx=48,
batch_size=16,
epochs=5,
lr=1e-4,
save_dir="./halo_save",
)
num_params = sum(p.numel() for p in model.parameters())
print(f"\nModel initialized with {num_params} parameters")

# STEP 5: Train with HALO's custom loop (saves best checkpoint to save_dir)
model.train_model(train_dataset, val_dataset=val_dataset)

# STEP 6: Generate synthetic patients (one per real training patient).
synthetic = model.generate(num_samples=len(train_dataset), random_sampling=True)
print("\nGenerated synthetic patients (first 3):")
for patient in synthetic[:3]:
print(f" {patient['patient_id']}: {len(patient['visits'])} visits")
print(f" {patient['visits']}")

# STEP 7: Evaluate the synthetic data with the generative metrics suite.
# evaluate_synthetic_ehr expects flat dataframes with one row per code:
# columns = [id, time, visit_codes, labels]
# `labels` is a placeholder -- the utility metric overwrites it with the
# next-visit prediction target.
index_to_code = {
v: k for k, v in sample_dataset.input_processors["visits"].code_vocab.items()
}

def real_subset_to_records(subset):
for sample in subset:
pid = str(sample["patient_id"])
visits_tensor = sample["visits"]
for t, visit in enumerate(visits_tensor.tolist()):
for idx in visit:
code = index_to_code.get(int(idx))
if code in (None, "<pad>", "<unk>"):
continue
yield {"id": pid, "time": t, "visit_codes": code, "labels": 0}

def synthetic_to_records(patients):
for p in patients:
pid = str(p["patient_id"])
for t, visit in enumerate(p["visits"]):
for code in visit:
yield {"id": pid, "time": t, "visit_codes": code, "labels": 0}

schema = {"visit_codes": str, "labels": int, "time": int, "id": str}
train_df = pd.DataFrame(real_subset_to_records(train_dataset)).astype(schema)
test_df = pd.DataFrame(real_subset_to_records(test_dataset)).astype(schema)
syn_df = pd.DataFrame(synthetic_to_records(synthetic)).astype(schema)
print(
f"\nEval rows -- train: {len(train_df)}, test: {len(test_df)}, "
f"synthetic: {len(syn_df)}"
)

# sample_size / n_bootstraps / n_runs are kept small for the dev subset;
# raise them when running on the full MIMIC-III cohort.
results = evaluate_synthetic_ehr(
train_ehr=train_df,
test_ehr=test_df,
syn_ehr=syn_df,
sample_size=min(30, len(train_dataset), len(test_dataset)),
mode="lstm",
metrics="all",
lstm_params={"embed_dim": 16, "hidden_dim": 16, "batch_size": 16, "epochs": 3},
n_bootstraps=5,
n_runs=3,
)
print("\nGenerative metrics (mean +/- std):")
for name, (mean, std) in results.items():
print(f" {name:30s} {mean:.4f} +/- {std:.4f}")
14 changes: 14 additions & 0 deletions pyhealth/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from .binary import binary_metrics_fn
from .drug_recommendation import ddi_rate_score
from .generative import (
calc_membership_inference,
calc_nnaar,
compute_discriminator_privacy,
compute_mle,
compute_prevalence_metrics,
evaluate_synthetic_ehr,
)
from .interpretability import (
ComprehensivenessMetric,
Evaluator,
Expand All @@ -17,6 +25,12 @@
__all__ = [
"binary_metrics_fn",
"ddi_rate_score",
"calc_nnaar",
"calc_membership_inference",
"compute_discriminator_privacy",
"compute_mle",
"compute_prevalence_metrics",
"evaluate_synthetic_ehr",
"ComprehensivenessMetric",
"SufficiencyMetric",
"RemovalBasedMetric",
Expand Down
188 changes: 188 additions & 0 deletions pyhealth/metrics/generative/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
"""Evaluation metrics for synthetic (generative) EHR data.

This subpackage provides metrics for assessing synthetic electronic health
record (EHR) data along three axes:

- **Privacy** (:mod:`pyhealth.metrics.generative.privacy`): NNAAR,
membership inference, and discriminator-based adversarial accuracy.
- **Utility / fidelity** (:mod:`pyhealth.metrics.generative.utility`):
machine learning efficacy (TRTR vs TSTR) and code-prevalence similarity.

The convenience function :func:`evaluate_synthetic_ehr` runs the full suite
and returns a single merged dictionary of ``{metric_name: (mean, std)}``.

Note:
The MLE (utility) component is currently hard-coded to next-visit
prediction and is therefore only meaningful for sequential generators
(HALO, GPT2, PromptEHR). It will be expanded to support pluggable
downstream tasks so that bag-of-codes generators (MedGAN, CorGAN) can
be evaluated with a static-label task (e.g. mortality, readmission).
Until then, prefer the privacy and prevalence metrics when evaluating
MedGAN/CorGAN output.
"""

import logging
from typing import Dict, Optional, Tuple

import pandas as pd

from .privacy import (
calc_membership_inference,
calc_nnaar,
compute_discriminator_privacy,
)
from .utility import compute_mle, compute_prevalence_metrics
from .utils import train_lstm_model, train_sklearn_model

logger = logging.getLogger(__name__)

__all__ = [
"calc_nnaar",
"calc_membership_inference",
"compute_discriminator_privacy",
"compute_mle",
"compute_prevalence_metrics",
"evaluate_synthetic_ehr",
]


def evaluate_synthetic_ehr(
train_ehr: pd.DataFrame,
test_ehr: pd.DataFrame,
syn_ehr: pd.DataFrame,
subject_col: str = "id",
visit_col: str = "time",
code_col: str = "visit_codes",
label_col: str = "labels",
sample_size: int = 1000,
mode: str = "lstm",
metrics: str = "all",
lstm_params: Optional[Dict] = None,
sklearn_params: Optional[Dict] = None,
n_bootstraps: int = 100,
n_runs: int = 5,
) -> Dict[str, Tuple[float, float]]:
"""Runs the full synthetic-EHR evaluation suite.

Computes privacy and/or utility metrics comparing synthetic EHR data
against real train/test data, and returns a single merged dictionary.

Args:
train_ehr: Real training EHR dataframe.
test_ehr: Real held-out test EHR dataframe.
syn_ehr: Synthetic EHR dataframe.
subject_col: Column name for patient/subject identifiers.
visit_col: Column name for visit/timestep identifiers.
code_col: Column name for the medical codes.
label_col: Column name for the label.
sample_size: Number of patients sampled per dataset for the
privacy metrics.
mode: Predictive backbone for the utility metrics; ``"lstm"`` uses the
built-in LSTM classifier, ``"rf"`` uses a random forest.
metrics: Which metric group to compute: ``"all"``, ``"privacy"`` or
``"utility"``.
lstm_params: Optional overrides for the LSTM (``embed_dim``,
``hidden_dim``, ``batch_size``, ``epochs``).
sklearn_params: Optional overrides for the sklearn model (``model``).
n_bootstraps: Number of bootstrap resamples for the utility metrics.
n_runs: Number of sampling runs for the privacy metrics.

Returns:
Dictionary mapping each metric name to a ``(mean, std)`` tuple.

Raises:
ValueError: If ``metrics`` or ``mode`` is not a recognized value.
"""
if metrics not in ("all", "privacy", "utility"):
raise ValueError(
f"Unknown metrics group: {metrics!r}. "
"Expected 'all', 'privacy' or 'utility'."
)
if mode not in ("lstm", "rf"):
raise ValueError(f"Unknown mode: {mode!r}. Expected 'lstm' or 'rf'.")

lstm_params = lstm_params or {}
sklearn_params = sklearn_params or {}
final_output: Dict[str, Tuple[float, float]] = {}

if metrics in ("all", "privacy"):
final_output.update(
calc_nnaar(
train_ehr,
test_ehr,
syn_ehr,
subject_col=subject_col,
visit_col=visit_col,
code_col=code_col,
label_col=label_col,
sample_size=sample_size,
n_runs=n_runs,
)
)
final_output.update(
calc_membership_inference(
train_ehr,
test_ehr,
syn_ehr,
subject_col=subject_col,
visit_col=visit_col,
code_col=code_col,
label_col=label_col,
num_attack_samples=sample_size,
n_runs=n_runs,
)
)

if metrics in ("all", "utility"):
if mode == "lstm":
train_fn = train_lstm_model
train_kwargs = {
"embed_dim": lstm_params.get("embed_dim", 32),
"hidden_dim": lstm_params.get("hidden_dim", 32),
"batch_size": lstm_params.get("batch_size", 32),
"epochs": lstm_params.get("epochs", 5),
"verbose": False,
}
else:
train_fn = train_sklearn_model
train_kwargs = {"model": sklearn_params.get("model", "rf")}

final_output.update(
compute_mle(
train_fn=train_fn,
train_ehr=train_ehr,
test_ehr=test_ehr,
syn_ehr=syn_ehr,
subject_col=subject_col,
visit_col=visit_col,
code_col=code_col,
label_col=label_col,
n_bootstraps=n_bootstraps,
**train_kwargs,
)
)
final_output.update(
compute_discriminator_privacy(
train_fn=train_fn,
train_ehr=train_ehr,
test_ehr=test_ehr,
syn_ehr=syn_ehr,
subject_col=subject_col,
visit_col=visit_col,
code_col=code_col,
label_col=label_col,
n_bootstraps=n_bootstraps,
**train_kwargs,
)
)
final_output.update(
compute_prevalence_metrics(
train_ehr,
syn_ehr,
subject_col=subject_col,
code_col=code_col,
n_bootstraps=n_bootstraps,
)
)

return final_output
Loading
Loading