diff --git a/docs/api/metrics.rst b/docs/api/metrics.rst index 1767e0026..9e6bc160a 100644 --- a/docs/api/metrics.rst +++ b/docs/api/metrics.rst @@ -7,6 +7,8 @@ For applicable tasks, we provide the relevant metrics for model calibration, as Among these we also provide metrics related to uncertainty quantification, for model calibration, as well as metrics that measure the quality of prediction sets We also provide other metrics specically for healthcare tasks, such as drug drug interaction (DDI) rate. +For synthetic (generative) EHR data, we provide privacy, utility, and statistical +fidelity metrics. .. toctree:: @@ -19,3 +21,4 @@ tasks, such as drug drug interaction (DDI) rate. metrics/pyhealth.metrics.prediction_set metrics/pyhealth.metrics.fairness metrics/pyhealth.metrics.interpretability + metrics/pyhealth.metrics.generative diff --git a/docs/api/metrics/pyhealth.metrics.generative.rst b/docs/api/metrics/pyhealth.metrics.generative.rst new file mode 100644 index 000000000..85e448a52 --- /dev/null +++ b/docs/api/metrics/pyhealth.metrics.generative.rst @@ -0,0 +1,25 @@ +pyhealth.metrics.generative +=================================== + +Evaluation metrics for synthetic (generative) EHR data, covering privacy, +utility, and statistical fidelity. + +.. currentmodule:: pyhealth.metrics.generative + +.. autofunction:: evaluate_synthetic_ehr + +Privacy metrics +------------------------------------- + +.. autofunction:: calc_nnaar + +.. autofunction:: calc_membership_inference + +.. autofunction:: compute_discriminator_privacy + +Utility and fidelity metrics +------------------------------------- + +.. autofunction:: compute_mle + +.. autofunction:: compute_prevalence_metrics diff --git a/examples/halo_mimic3.py b/examples/halo_mimic3.py new file mode 100644 index 000000000..b6642a2cf --- /dev/null +++ b/examples/halo_mimic3.py @@ -0,0 +1,120 @@ +"""Example: train HALO on MIMIC-III and generate synthetic patients. + +This example demonstrates: +1. Loading MIMIC-III data +2. Applying the EHRGenerationMIMIC3 task (per-visit ICD-9 code sequences) +3. Creating a SampleDataset with a NestedSequenceProcessor +4. Training the HALO generator with its custom training loop +5. Generating synthetic patients +6. Evaluating the synthetic data with the generative metrics suite +""" + +import pandas as pd + +from pyhealth.datasets import MIMIC3Dataset, split_by_patient +from pyhealth.metrics.generative import evaluate_synthetic_ehr +from pyhealth.models import HALO +from pyhealth.tasks import EHRGenerationMIMIC3 + +if __name__ == "__main__": + # STEP 1: Load MIMIC-III base dataset + base_dataset = MIMIC3Dataset( + root="/srv/local/data/MIMIC-III/mimic-iii-clinical-database-1.4", + tables=["diagnoses_icd"], + dev=True, + ) + + # STEP 2: Apply the EHR generation task (unconditional, no labels). + # This task is shared by all generators in pyhealth.models.generators. + sample_dataset = base_dataset.set_task(EHRGenerationMIMIC3()) + print(f"Total samples: {len(sample_dataset)}") + print(f"Input schema: {sample_dataset.input_schema}") + print(f"Output schema: {sample_dataset.output_schema}") + + sample = sample_dataset[0] + print("\nSample structure:") + print(f" Patient ID: {sample['patient_id']}") + print(f" Visits tensor shape: {tuple(sample['visits'].shape)}") + + # STEP 3: Split dataset by patient + train_dataset, val_dataset, test_dataset = split_by_patient( + sample_dataset, [0.8, 0.1, 0.1] + ) + + # STEP 4: Initialize HALO (small config for the dev subset) + model = HALO( + dataset=sample_dataset, + embed_dim=128, + n_heads=4, + n_layers=4, + n_ctx=48, + batch_size=16, + epochs=5, + lr=1e-4, + save_dir="./halo_save", + ) + num_params = sum(p.numel() for p in model.parameters()) + print(f"\nModel initialized with {num_params} parameters") + + # STEP 5: Train with HALO's custom loop (saves best checkpoint to save_dir) + model.train_model(train_dataset, val_dataset=val_dataset) + + # STEP 6: Generate synthetic patients (one per real training patient). + synthetic = model.generate(num_samples=len(train_dataset), random_sampling=True) + print("\nGenerated synthetic patients (first 3):") + for patient in synthetic[:3]: + print(f" {patient['patient_id']}: {len(patient['visits'])} visits") + print(f" {patient['visits']}") + + # STEP 7: Evaluate the synthetic data with the generative metrics suite. + # evaluate_synthetic_ehr expects flat dataframes with one row per code: + # columns = [id, time, visit_codes, labels] + # `labels` is a placeholder -- the utility metric overwrites it with the + # next-visit prediction target. + index_to_code = { + v: k for k, v in sample_dataset.input_processors["visits"].code_vocab.items() + } + + def real_subset_to_records(subset): + for sample in subset: + pid = str(sample["patient_id"]) + visits_tensor = sample["visits"] + for t, visit in enumerate(visits_tensor.tolist()): + for idx in visit: + code = index_to_code.get(int(idx)) + if code in (None, "", ""): + continue + yield {"id": pid, "time": t, "visit_codes": code, "labels": 0} + + def synthetic_to_records(patients): + for p in patients: + pid = str(p["patient_id"]) + for t, visit in enumerate(p["visits"]): + for code in visit: + yield {"id": pid, "time": t, "visit_codes": code, "labels": 0} + + schema = {"visit_codes": str, "labels": int, "time": int, "id": str} + train_df = pd.DataFrame(real_subset_to_records(train_dataset)).astype(schema) + test_df = pd.DataFrame(real_subset_to_records(test_dataset)).astype(schema) + syn_df = pd.DataFrame(synthetic_to_records(synthetic)).astype(schema) + print( + f"\nEval rows -- train: {len(train_df)}, test: {len(test_df)}, " + f"synthetic: {len(syn_df)}" + ) + + # sample_size / n_bootstraps / n_runs are kept small for the dev subset; + # raise them when running on the full MIMIC-III cohort. + results = evaluate_synthetic_ehr( + train_ehr=train_df, + test_ehr=test_df, + syn_ehr=syn_df, + sample_size=min(30, len(train_dataset), len(test_dataset)), + mode="lstm", + metrics="all", + lstm_params={"embed_dim": 16, "hidden_dim": 16, "batch_size": 16, "epochs": 3}, + n_bootstraps=5, + n_runs=3, + ) + print("\nGenerative metrics (mean +/- std):") + for name, (mean, std) in results.items(): + print(f" {name:30s} {mean:.4f} +/- {std:.4f}") diff --git a/pyhealth/metrics/__init__.py b/pyhealth/metrics/__init__.py index da8da0f5b..f04b6ba6a 100644 --- a/pyhealth/metrics/__init__.py +++ b/pyhealth/metrics/__init__.py @@ -1,5 +1,13 @@ from .binary import binary_metrics_fn from .drug_recommendation import ddi_rate_score +from .generative import ( + calc_membership_inference, + calc_nnaar, + compute_discriminator_privacy, + compute_mle, + compute_prevalence_metrics, + evaluate_synthetic_ehr, +) from .interpretability import ( ComprehensivenessMetric, Evaluator, @@ -17,6 +25,12 @@ __all__ = [ "binary_metrics_fn", "ddi_rate_score", + "calc_nnaar", + "calc_membership_inference", + "compute_discriminator_privacy", + "compute_mle", + "compute_prevalence_metrics", + "evaluate_synthetic_ehr", "ComprehensivenessMetric", "SufficiencyMetric", "RemovalBasedMetric", diff --git a/pyhealth/metrics/generative/__init__.py b/pyhealth/metrics/generative/__init__.py new file mode 100644 index 000000000..4711417b0 --- /dev/null +++ b/pyhealth/metrics/generative/__init__.py @@ -0,0 +1,188 @@ +"""Evaluation metrics for synthetic (generative) EHR data. + +This subpackage provides metrics for assessing synthetic electronic health +record (EHR) data along three axes: + + - **Privacy** (:mod:`pyhealth.metrics.generative.privacy`): NNAAR, + membership inference, and discriminator-based adversarial accuracy. + - **Utility / fidelity** (:mod:`pyhealth.metrics.generative.utility`): + machine learning efficacy (TRTR vs TSTR) and code-prevalence similarity. + +The convenience function :func:`evaluate_synthetic_ehr` runs the full suite +and returns a single merged dictionary of ``{metric_name: (mean, std)}``. + +Note: + The MLE (utility) component is currently hard-coded to next-visit + prediction and is therefore only meaningful for sequential generators + (HALO, GPT2, PromptEHR). It will be expanded to support pluggable + downstream tasks so that bag-of-codes generators (MedGAN, CorGAN) can + be evaluated with a static-label task (e.g. mortality, readmission). + Until then, prefer the privacy and prevalence metrics when evaluating + MedGAN/CorGAN output. +""" + +import logging +from typing import Dict, Optional, Tuple + +import pandas as pd + +from .privacy import ( + calc_membership_inference, + calc_nnaar, + compute_discriminator_privacy, +) +from .utility import compute_mle, compute_prevalence_metrics +from .utils import train_lstm_model, train_sklearn_model + +logger = logging.getLogger(__name__) + +__all__ = [ + "calc_nnaar", + "calc_membership_inference", + "compute_discriminator_privacy", + "compute_mle", + "compute_prevalence_metrics", + "evaluate_synthetic_ehr", +] + + +def evaluate_synthetic_ehr( + train_ehr: pd.DataFrame, + test_ehr: pd.DataFrame, + syn_ehr: pd.DataFrame, + subject_col: str = "id", + visit_col: str = "time", + code_col: str = "visit_codes", + label_col: str = "labels", + sample_size: int = 1000, + mode: str = "lstm", + metrics: str = "all", + lstm_params: Optional[Dict] = None, + sklearn_params: Optional[Dict] = None, + n_bootstraps: int = 100, + n_runs: int = 5, +) -> Dict[str, Tuple[float, float]]: + """Runs the full synthetic-EHR evaluation suite. + + Computes privacy and/or utility metrics comparing synthetic EHR data + against real train/test data, and returns a single merged dictionary. + + Args: + train_ehr: Real training EHR dataframe. + test_ehr: Real held-out test EHR dataframe. + syn_ehr: Synthetic EHR dataframe. + subject_col: Column name for patient/subject identifiers. + visit_col: Column name for visit/timestep identifiers. + code_col: Column name for the medical codes. + label_col: Column name for the label. + sample_size: Number of patients sampled per dataset for the + privacy metrics. + mode: Predictive backbone for the utility metrics; ``"lstm"`` uses the + built-in LSTM classifier, ``"rf"`` uses a random forest. + metrics: Which metric group to compute: ``"all"``, ``"privacy"`` or + ``"utility"``. + lstm_params: Optional overrides for the LSTM (``embed_dim``, + ``hidden_dim``, ``batch_size``, ``epochs``). + sklearn_params: Optional overrides for the sklearn model (``model``). + n_bootstraps: Number of bootstrap resamples for the utility metrics. + n_runs: Number of sampling runs for the privacy metrics. + + Returns: + Dictionary mapping each metric name to a ``(mean, std)`` tuple. + + Raises: + ValueError: If ``metrics`` or ``mode`` is not a recognized value. + """ + if metrics not in ("all", "privacy", "utility"): + raise ValueError( + f"Unknown metrics group: {metrics!r}. " + "Expected 'all', 'privacy' or 'utility'." + ) + if mode not in ("lstm", "rf"): + raise ValueError(f"Unknown mode: {mode!r}. Expected 'lstm' or 'rf'.") + + lstm_params = lstm_params or {} + sklearn_params = sklearn_params or {} + final_output: Dict[str, Tuple[float, float]] = {} + + if metrics in ("all", "privacy"): + final_output.update( + calc_nnaar( + train_ehr, + test_ehr, + syn_ehr, + subject_col=subject_col, + visit_col=visit_col, + code_col=code_col, + label_col=label_col, + sample_size=sample_size, + n_runs=n_runs, + ) + ) + final_output.update( + calc_membership_inference( + train_ehr, + test_ehr, + syn_ehr, + subject_col=subject_col, + visit_col=visit_col, + code_col=code_col, + label_col=label_col, + num_attack_samples=sample_size, + n_runs=n_runs, + ) + ) + + if metrics in ("all", "utility"): + if mode == "lstm": + train_fn = train_lstm_model + train_kwargs = { + "embed_dim": lstm_params.get("embed_dim", 32), + "hidden_dim": lstm_params.get("hidden_dim", 32), + "batch_size": lstm_params.get("batch_size", 32), + "epochs": lstm_params.get("epochs", 5), + "verbose": False, + } + else: + train_fn = train_sklearn_model + train_kwargs = {"model": sklearn_params.get("model", "rf")} + + final_output.update( + compute_mle( + train_fn=train_fn, + train_ehr=train_ehr, + test_ehr=test_ehr, + syn_ehr=syn_ehr, + subject_col=subject_col, + visit_col=visit_col, + code_col=code_col, + label_col=label_col, + n_bootstraps=n_bootstraps, + **train_kwargs, + ) + ) + final_output.update( + compute_discriminator_privacy( + train_fn=train_fn, + train_ehr=train_ehr, + test_ehr=test_ehr, + syn_ehr=syn_ehr, + subject_col=subject_col, + visit_col=visit_col, + code_col=code_col, + label_col=label_col, + n_bootstraps=n_bootstraps, + **train_kwargs, + ) + ) + final_output.update( + compute_prevalence_metrics( + train_ehr, + syn_ehr, + subject_col=subject_col, + code_col=code_col, + n_bootstraps=n_bootstraps, + ) + ) + + return final_output diff --git a/pyhealth/metrics/generative/privacy.py b/pyhealth/metrics/generative/privacy.py new file mode 100644 index 000000000..c553233ab --- /dev/null +++ b/pyhealth/metrics/generative/privacy.py @@ -0,0 +1,335 @@ +"""Privacy metrics for synthetic EHR data. + +These metrics quantify how much a synthetic EHR dataset leaks about the real +records it was trained on. They include: + + - Nearest Neighbor Adversarial Accuracy Risk (NNAAR) + - Membership Inference Attack (MIA) metrics + - A discriminator-based adversarial-accuracy privacy score + +All functions take flat EHR dataframes (one row per patient/visit/code event) +and return ``{metric_name: (mean, std)}`` summaries computed over multiple runs +or bootstrap resamples. +""" + +import copy +import logging +from typing import Callable, Dict, Tuple + +import numpy as np +import pandas as pd +from sklearn import metrics as sklearn_metrics +from sklearn.model_selection import train_test_split +from tqdm import tqdm + +from .utils import ( + convert_visits_to_sets, + find_nearest_neighbor_dist, + summarize_metric_runs, +) + +logger = logging.getLogger(__name__) + +__all__ = [ + "calc_nnaar", + "calc_membership_inference", + "compute_discriminator_privacy", +] + + +def calc_nnaar( + train_ehr: pd.DataFrame, + test_ehr: pd.DataFrame, + syn_ehr: pd.DataFrame, + subject_col: str = "id", + visit_col: str = "time", + code_col: str = "visit_codes", + label_col: str = "labels", + sample_size: int = 1000, + n_runs: int = 5, + verbose: bool = False, +) -> Dict[str, Tuple[float, float]]: + """Computes the Nearest Neighbor Adversarial Accuracy Risk (NNAAR). + + NNAAR measures whether the synthetic data sits closer to the real training + data than to held-out test data, which would indicate memorization:: + + NNAAR = AA_ES - AA_TS + + where ``AA_ES`` is the adversarial accuracy between test and synthetic data + and ``AA_TS`` is the adversarial accuracy between train and synthetic data. + Values near 0 indicate low privacy risk. + + Args: + train_ehr: Real training EHR dataframe. + test_ehr: Real held-out test EHR dataframe. + syn_ehr: Synthetic EHR dataframe. + subject_col: Column name for patient/subject identifiers. + visit_col: Column name for visit/timestep identifiers. + code_col: Column name for the medical codes. + label_col: Column name for the label (unused, kept for a uniform API). + sample_size: Number of patients to sample per dataset per run. + n_runs: Number of independent sampling runs. + verbose: Whether to show per-run progress bars. + + Returns: + Dictionary mapping ``"nnaar"``, ``"aa_es"`` and ``"aa_ts"`` to their + ``(mean, std)`` across runs. + """ + logger.info( + "Calculating NNAAR (sample_size=%d, n_runs=%d)", sample_size, n_runs + ) + train = convert_visits_to_sets(train_ehr, subject_col, visit_col, code_col) + test = convert_visits_to_sets(test_ehr, subject_col, visit_col, code_col) + synthetic = convert_visits_to_sets(syn_ehr, subject_col, visit_col, code_col) + + metrics_runs = [] + n = min(sample_size, len(train), len(test), len(synthetic)) + + for _ in range(n_runs): + if len(train) > n: + inds = np.random.choice(len(train), n, replace=False) + s_train = [train[i] for i in inds] + else: + s_train = list(train) + if len(test) > n: + inds = np.random.choice(len(test), n, replace=False) + s_test = [test[i] for i in inds] + else: + s_test = list(test) + if len(synthetic) > n: + inds = np.random.choice(len(synthetic), n, replace=False) + s_syn = [synthetic[i] for i in inds] + else: + s_syn = list(synthetic) + + # AA_ES (test vs synthetic). + val1 = sum( + 1 + for p in tqdm(s_test, desc="Test vs Syn", disable=not verbose) + if find_nearest_neighbor_dist(p, s_syn) + > find_nearest_neighbor_dist(p, s_test) + ) + val2 = sum( + 1 + for p in tqdm(s_syn, desc="Syn vs Test", disable=not verbose) + if find_nearest_neighbor_dist(p, s_test) + > find_nearest_neighbor_dist(p, s_syn) + ) + # AA_TS (train vs synthetic). + val3 = sum( + 1 + for p in tqdm(s_train, desc="Train vs Syn", disable=not verbose) + if find_nearest_neighbor_dist(p, s_syn) + > find_nearest_neighbor_dist(p, s_train) + ) + val4 = sum( + 1 + for p in tqdm(s_syn, desc="Syn vs Train", disable=not verbose) + if find_nearest_neighbor_dist(p, s_train) + > find_nearest_neighbor_dist(p, s_syn) + ) + + aa_es = 0.5 * (val1 / n + val2 / n) + aa_ts = 0.5 * (val3 / n + val4 / n) + metrics_runs.append( + {"nnaar": aa_es - aa_ts, "aa_es": aa_es, "aa_ts": aa_ts} + ) + + return summarize_metric_runs(metrics_runs) + + +def calc_membership_inference( + train_ehr: pd.DataFrame, + test_ehr: pd.DataFrame, + syn_ehr: pd.DataFrame, + subject_col: str = "id", + visit_col: str = "time", + code_col: str = "visit_codes", + label_col: str = "labels", + num_attack_samples: int = 1000, + n_runs: int = 5, + verbose: bool = False, +) -> Dict[str, Tuple[float, float]]: + """Computes Membership Inference Attack (MIA) metrics. + + An attacker tries to tell members (training patients) from non-members + (test patients) using proximity to the synthetic data: members are expected + to be closer to synthetic records. Predictions are made by thresholding the + nearest-neighbor distance at its median; F1, precision, recall and accuracy + near 0.5 indicate low membership-inference risk. + + Args: + train_ehr: Real training EHR dataframe (members). + test_ehr: Real held-out test EHR dataframe (non-members). + syn_ehr: Synthetic EHR dataframe. + subject_col: Column name for patient/subject identifiers. + visit_col: Column name for visit/timestep identifiers. + code_col: Column name for the medical codes. + label_col: Column name for the label (unused, kept for a uniform API). + num_attack_samples: Total attack-set size (half members, half not). + n_runs: Number of independent sampling runs. + verbose: Whether to show per-run progress bars. + + Returns: + Dictionary mapping ``"MIA_F1"``, ``"MIA_Precision"``, ``"MIA_Recall"`` + and ``"MIA_Accuracy"`` to their ``(mean, std)`` across runs. + """ + logger.info( + "Calculating Membership Inference (attack_size=%d, n_runs=%d)", + num_attack_samples, + n_runs, + ) + train = convert_visits_to_sets(train_ehr, subject_col, visit_col, code_col) + test = convert_visits_to_sets(test_ehr, subject_col, visit_col, code_col) + synthetic = convert_visits_to_sets(syn_ehr, subject_col, visit_col, code_col) + + metrics_runs = [] + for _ in range(n_runs): + # Build a balanced attack set: 50% members, 50% non-members. + n_half = min(len(train), len(test), num_attack_samples) // 2 + if n_half == 0: + continue + + pos_inds = np.random.choice(len(train), n_half, replace=False) + pos_samples = [train[i] for i in pos_inds] + neg_inds = np.random.choice(len(test), n_half, replace=False) + neg_samples = [test[i] for i in neg_inds] + + attack_data = pos_samples + neg_samples + attack_labels = [1] * len(pos_samples) + [0] * len(neg_samples) + + distances = [ + find_nearest_neighbor_dist(record, synthetic) + for record in tqdm( + attack_data, desc="Calculating Distances", disable=not verbose + ) + ] + if len(distances) == 0: + continue + + # Members are expected to be closer (smaller distance) to synthetic. + median_dist = np.median(distances) + predictions = [1 if d < median_dist else 0 for d in distances] + + metrics_runs.append( + { + "MIA_F1": sklearn_metrics.f1_score(attack_labels, predictions), + "MIA_Precision": sklearn_metrics.precision_score( + attack_labels, predictions, zero_division=0 + ), + "MIA_Recall": sklearn_metrics.recall_score( + attack_labels, predictions, zero_division=0 + ), + "MIA_Accuracy": sklearn_metrics.accuracy_score( + attack_labels, predictions + ), + } + ) + + summary = summarize_metric_runs(metrics_runs) + logger.info("MIA results: %s", summary) + return summary + + +def compute_discriminator_privacy( + train_fn: Callable, + train_ehr: pd.DataFrame, + test_ehr: pd.DataFrame, + syn_ehr: pd.DataFrame, + subject_col: str = "id", + visit_col: str = "time", + code_col: str = "visit_codes", + label_col: str = "labels", + n_bootstraps: int = 5, + seed: int = 4, + **kwargs, +) -> Dict[str, Tuple[float, float]]: + """Computes a discriminator-based adversarial-accuracy privacy score. + + A classifier is trained to predict whether a record is real (1) or + synthetic (0). An accuracy near 0.5 means real and synthetic data are + indistinguishable (good privacy); accuracy well above 0.5 means the + synthetic data is easy to tell apart (poor privacy). The ``Privacy_Score`` + rescales accuracy so 1.0 is perfect privacy and 0.0 is none. + + Args: + train_fn: A training function such as + :func:`pyhealth.metrics.generative.utils.train_lstm_model` or + ``train_sklearn_model``. It must accept ``train_ehr``, ``test_ehr``, + the four column-name arguments and return ``(model, y_true, + y_pred)``. + train_ehr: Real training EHR dataframe. + test_ehr: Real held-out test EHR dataframe (unused; kept for a uniform + API with the other metrics). + syn_ehr: Synthetic EHR dataframe. + subject_col: Column name for patient/subject identifiers. + visit_col: Column name for visit/timestep identifiers. + code_col: Column name for the medical codes. + label_col: Column name for the original label (unused; the + discriminator target replaces it). + n_bootstraps: Number of bootstrap resamples of the predictions. + seed: Random seed for the patient-level train/test split. + **kwargs: Extra keyword arguments forwarded to ``train_fn``. + + Returns: + Dictionary mapping ``"Privacy_Discriminator_Accuracy"`` and + ``"Privacy_Score"`` to their ``(mean, std)`` across bootstraps. + """ + logger.info("Computing discriminator privacy") + + # Label data: real = 1, synthetic = 0. + real_df = copy.deepcopy(train_ehr) + syn_df = copy.deepcopy(syn_ehr) + disc_label = "is_real" + real_df[disc_label] = 1 + syn_df[disc_label] = 0 + + # Disambiguate subject IDs so real/synthetic patients never collide. + real_df[subject_col] = real_df[subject_col].astype(str) + "_real" + syn_df[subject_col] = syn_df[subject_col].astype(str) + "_syn" + + combined_df = pd.concat([real_df, syn_df]) + unique_patients = combined_df[subject_col].unique() + train_ids, test_ids = train_test_split( + unique_patients, test_size=0.2, random_state=seed + ) + disc_train = combined_df[combined_df[subject_col].isin(train_ids)] + disc_test = combined_df[combined_df[subject_col].isin(test_ids)] + + logger.info( + "Discriminator train size=%d, test size=%d", + len(disc_train), + len(disc_test), + ) + _, y_true, y_pred = train_fn( + train_ehr=disc_train, + test_ehr=disc_test, + subject_col=subject_col, + visit_col=visit_col, + code_col=code_col, + label_col=disc_label, + **kwargs, + ) + + metrics_runs = [] + n_samples = len(y_true) + for _ in range(n_bootstraps): + if n_samples > 0: + indices = np.random.choice(n_samples, n_samples, replace=True) + y_t, y_p = y_true[indices], y_pred[indices] + else: + y_t, y_p = y_true, y_pred + + acc = sklearn_metrics.accuracy_score(y_t, y_p) if len(y_t) > 0 else 0.0 + metrics_runs.append( + { + "Privacy_Discriminator_Accuracy": acc, + # 1.0 = perfect privacy (acc 0.5); 0.0 = no privacy. + "Privacy_Score": 1.0 - 2 * abs(0.5 - acc), + } + ) + + summary = summarize_metric_runs(metrics_runs) + logger.info("Discriminator privacy results: %s", summary) + return summary diff --git a/pyhealth/metrics/generative/utility.py b/pyhealth/metrics/generative/utility.py new file mode 100644 index 000000000..4d65b0e6b --- /dev/null +++ b/pyhealth/metrics/generative/utility.py @@ -0,0 +1,256 @@ +"""Utility and statistical-fidelity metrics for synthetic EHR data. + +These metrics quantify how *useful* synthetic EHR data is as a stand-in for +real data: + + - Machine Learning Efficacy (MLE): compares a model trained on real data + against one trained on synthetic data, both evaluated on real data. + - Code-prevalence similarity: compares per-code patient-level prevalence + between real and synthetic data (R-squared, Pearson correlation, RMSE). + +All functions take flat EHR dataframes (one row per patient/visit/code event) +and return ``{metric_name: (mean, std)}`` summaries over bootstrap resamples. +""" + +import copy +import logging +from typing import Callable, Dict, Tuple + +import numpy as np +import pandas as pd +from sklearn import metrics as sklearn_metrics + +from .utils import build_next_visit_prediction_dataset, summarize_metric_runs + +logger = logging.getLogger(__name__) + +__all__ = [ + "compute_mle", + "compute_prevalence_metrics", +] + + +def compute_mle( + train_fn: Callable, + train_ehr: pd.DataFrame, + test_ehr: pd.DataFrame, + syn_ehr: pd.DataFrame, + subject_col: str = "id", + visit_col: str = "time", + code_col: str = "visit_codes", + label_col: str = "labels", + n_bootstraps: int = 5, + **kwargs, +) -> Dict[str, Tuple[float, float]]: + """Computes Machine Learning Efficacy (utility) for synthetic data. + + Two classifiers are trained on a next-visit prediction task: one on real + training data (Train-Real-Test-Real, TRTR) and one on synthetic data + (Train-Synthetic-Test-Real, TSTR). Both are evaluated on the same real test + set. Synthetic accuracy/F1 close to real accuracy/F1 indicates high utility. + + Note: + The current implementation hard-codes the downstream task to + next-visit prediction (built via + :func:`build_next_visit_prediction_dataset`). This is degenerate for + bag-of-codes generators such as MedGAN and CorGAN, which emit a + single aggregate visit per patient and so always get label=0. A + future revision will let callers plug in static-label tasks + (mortality, readmission, "ever diagnosed with X", ...) so MLE is + meaningful for both sequential (HALO, GPT2, PromptEHR) and + bag-of-codes (MedGAN, CorGAN) generators. + + Args: + train_fn: A training function such as + :func:`pyhealth.metrics.generative.utils.train_lstm_model` or + ``train_sklearn_model``, returning ``(model, y_true, y_pred)``. + train_ehr: Real training EHR dataframe. + test_ehr: Real held-out test EHR dataframe. + syn_ehr: Synthetic EHR dataframe. + subject_col: Column name for patient/subject identifiers. + visit_col: Column name for visit/timestep identifiers. + code_col: Column name for the medical codes. + label_col: Column name for the label (overwritten by the next-visit + prediction label). + n_bootstraps: Number of bootstrap resamples of the predictions. + **kwargs: Extra keyword arguments forwarded to ``train_fn``. + + Returns: + Dictionary mapping the MLE metrics (real/synthetic accuracy and F1, + their difference and ratio) to their ``(mean, std)`` across + bootstraps. + """ + logger.info("Computing MLE (utility)") + + train_task = build_next_visit_prediction_dataset( + train_ehr, subject_col, visit_col, label_col + ) + test_task = build_next_visit_prediction_dataset( + test_ehr, subject_col, visit_col, label_col + ) + syn_task = build_next_visit_prediction_dataset( + syn_ehr, subject_col, visit_col, label_col + ) + + # Train on Real, test on Real (TRTR). + _, real_y_true, real_y_pred = train_fn( + copy.deepcopy(train_task), + copy.deepcopy(test_task), + subject_col=subject_col, + visit_col=visit_col, + code_col=code_col, + label_col=label_col, + **kwargs, + ) + # Train on Synthetic, test on Real (TSTR). + _, syn_y_true, syn_y_pred = train_fn( + copy.deepcopy(syn_task), + copy.deepcopy(test_task), + subject_col=subject_col, + visit_col=visit_col, + code_col=code_col, + label_col=label_col, + **kwargs, + ) + + metrics_runs = [] + n_samples = len(real_y_true) + for _ in range(n_bootstraps): + if n_samples > 0: + indices = np.random.choice(n_samples, n_samples, replace=True) + r_true, r_pred = real_y_true[indices], real_y_pred[indices] + s_true, s_pred = syn_y_true[indices], syn_y_pred[indices] + else: + r_true, r_pred = real_y_true, real_y_pred + s_true, s_pred = syn_y_true, syn_y_pred + + real_acc = ( + sklearn_metrics.accuracy_score(r_true, r_pred) + if len(r_true) > 0 + else 0.0 + ) + syn_acc = ( + sklearn_metrics.accuracy_score(s_true, s_pred) + if len(s_true) > 0 + else 0.0 + ) + real_f1 = ( + sklearn_metrics.f1_score(r_true, r_pred, average="macro") + if len(r_true) > 0 + else 0.0 + ) + syn_f1 = ( + sklearn_metrics.f1_score(s_true, s_pred, average="macro") + if len(s_true) > 0 + else 0.0 + ) + + metrics_runs.append( + { + "MLE_Real_Accuracy": real_acc, + "MLE_Synth_Accuracy": syn_acc, + "MLE_Difference": real_acc - syn_acc, + "MLE_Ratio": syn_acc / real_acc if real_acc > 0 else 0.0, + "MLE_Real_F1": real_f1, + "MLE_Synth_F1": syn_f1, + } + ) + + summary = summarize_metric_runs(metrics_runs) + logger.info("MLE results: %s", summary) + return summary + + +def compute_prevalence_metrics( + train_ehr: pd.DataFrame, + syn_ehr: pd.DataFrame, + subject_col: str = "id", + code_col: str = "visit_codes", + n_bootstraps: int = 5, +) -> Dict[str, Tuple[float, float]]: + """Compares per-code patient-level prevalence of real vs synthetic data. + + For every code, prevalence is the fraction of unique patients who have that + code at least once. The real and synthetic prevalence vectors are compared + with R-squared, Pearson correlation and RMSE; bootstrap resampling is over + codes. + + Args: + train_ehr: Real training EHR dataframe. + syn_ehr: Synthetic EHR dataframe. + subject_col: Column name for patient/subject identifiers. + code_col: Column name for the medical codes. + n_bootstraps: Number of bootstrap resamples over codes. + + Returns: + Dictionary mapping ``"Prevalence_R2"``, ``"Prevalence_Pearson"`` and + ``"Prevalence_RMSE"`` to their ``(mean, std)`` across bootstraps. + """ + logger.info("Computing prevalence metrics") + + all_codes = set() + all_codes.update(train_ehr[code_col].unique().tolist()) + all_codes.update(syn_ehr[code_col].unique().tolist()) + + n_train = train_ehr[subject_col].nunique() + n_syn = syn_ehr[subject_col].nunique() + if n_train == 0 or n_syn == 0: + return { + "Prevalence_R2": (0.0, 0.0), + "Prevalence_Pearson": (0.0, 0.0), + "Prevalence_RMSE": (0.0, 0.0), + } + + # Count unique patients per code. + train_counts = train_ehr.groupby(code_col)[subject_col].nunique() + syn_counts = syn_ehr.groupby(code_col)[subject_col].nunique() + for code in all_codes: + if code not in train_counts.index: + train_counts.loc[code] = 0 + if code not in syn_counts.index: + syn_counts.loc[code] = 0 + + train_probs = train_counts / n_train + syn_probs = syn_counts / n_syn + df_compare = pd.DataFrame( + {"real": train_probs, "syn": syn_probs} + ).fillna(0) + + metrics_runs = [] + n_samples = len(df_compare) + for _ in range(n_bootstraps): + if n_samples > 0: + df_sampled = df_compare.sample(n=n_samples, replace=True) + real_vec = df_sampled["real"].values + syn_vec = df_sampled["syn"].values + else: + real_vec = df_compare["real"].values + syn_vec = df_compare["syn"].values + + r2 = ( + sklearn_metrics.r2_score(real_vec, syn_vec) + if n_samples > 1 + else 0.0 + ) + # Pearson correlation via numpy (avoids a hard scipy dependency). + if len(np.unique(real_vec)) > 1 and len(np.unique(syn_vec)) > 1: + rho = float(np.corrcoef(real_vec, syn_vec)[0, 1]) + else: + rho = 0.0 + rmse = ( + float(np.sqrt(sklearn_metrics.mean_squared_error(real_vec, syn_vec))) + if n_samples > 0 + else 0.0 + ) + + metrics_runs.append( + { + "Prevalence_R2": r2, + "Prevalence_Pearson": rho, + "Prevalence_RMSE": rmse, + } + ) + + summary = summarize_metric_runs(metrics_runs) + logger.info("Prevalence results: %s", summary) + return summary diff --git a/pyhealth/metrics/generative/utils.py b/pyhealth/metrics/generative/utils.py new file mode 100644 index 000000000..2e2123586 --- /dev/null +++ b/pyhealth/metrics/generative/utils.py @@ -0,0 +1,584 @@ +"""Shared utilities for synthetic-EHR generative evaluation metrics. + +This module contains the data-preparation helpers, distance functions, and +the lightweight predictive models (an LSTM classifier and a random-forest +baseline) that the privacy and utility metrics build on. It is not intended +to be used directly; see :mod:`pyhealth.metrics.generative.privacy` and +:mod:`pyhealth.metrics.generative.utility` for the public metric functions. +""" + +from typing import Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn + +__all__ = [ + "summarize_metric_runs", + "convert_visits_to_sets", + "calculate_hamming_distance_cutoff", + "find_nearest_neighbor_dist", + "process_patient_data_for_lstm", + "collate_fn", + "EHRDataset", + "EHR_LSTM_Classifier", + "train_lstm_model", + "aggregate_patient_visits", + "train_sklearn_model", + "build_next_visit_prediction_dataset", + "convert_cols_to_multihot", +] + + +def summarize_metric_runs( + metrics_list: List[Dict[str, float]] +) -> Dict[str, Tuple[float, float]]: + """Summarizes a list of per-run metric dicts into (mean, std) tuples. + + Args: + metrics_list: List of dicts, one per run, mapping metric name to value. + + Returns: + Dictionary mapping each metric name to a ``(mean, std)`` tuple computed + across the runs. Returns an empty dict if ``metrics_list`` is empty. + """ + if not metrics_list: + return {} + summary: Dict[str, Tuple[float, float]] = {} + for key in metrics_list[0].keys(): + values = [run[key] for run in metrics_list if key in run] + summary[key] = (float(np.mean(values)), float(np.std(values))) + return summary + + +# --- Privacy distance helpers --------------------------------------------- + + +def convert_visits_to_sets( + df: pd.DataFrame, + subject_col: str = "id", + visit_col: str = "time", + code_col: str = "visit_codes", +) -> List[List[set]]: + """Converts a flat EHR dataframe into per-patient lists of code sets. + + Each patient becomes a list of visits, and each visit is a ``set`` of the + codes recorded at that timestep. + + Args: + df: Input dataframe with one row per (patient, visit, code) event. + subject_col: Column name for patient/subject identifiers. + visit_col: Column name for visit/timestep identifiers. + code_col: Column name for the medical codes. + + Returns: + List of patients, where each patient is a list of code sets. + """ + records = ( + df.groupby(subject_col)[[visit_col, code_col]] + .apply(lambda x: x.groupby(visit_col)[code_col].apply(set).tolist()) + .tolist() + ) + return records + + +def calculate_hamming_distance_cutoff( + v1: List[set], v2: List[set], cutoff: float +) -> float: + """Computes a set-based Hamming distance between two patients, with cutoff. + + The distance accumulates the symmetric-difference size of aligned visits + plus a penalty for differing sequence lengths. Computation stops early once + the running distance reaches ``cutoff``. + + Args: + v1: First patient as a list of code sets. + v2: Second patient as a list of code sets. + cutoff: Distance value at which to stop early. + + Returns: + The distance between ``v1`` and ``v2``, capped at ``cutoff``. + """ + len1, len2 = len(v1), len(v2) + dist = 0 if len1 == len2 else 1 + if dist >= cutoff: + return cutoff + + min_len = min(len1, len2) + for i in range(min_len): + dist += len(v1[i] ^ v2[i]) + if dist >= cutoff: + return cutoff + + if len1 > min_len: + dist += sum(len(v) for v in v1[min_len:]) + elif len2 > min_len: + dist += sum(len(v) for v in v2[min_len:]) + return dist + + +def find_nearest_neighbor_dist( + query: List[set], reference_dataset: List[List[set]] +) -> float: + """Finds the distance from a query patient to its nearest neighbor. + + Args: + query: Query patient as a list of code sets. + reference_dataset: Patients to search over. + + Returns: + The smallest :func:`calculate_hamming_distance_cutoff` distance between + ``query`` and any patient in ``reference_dataset``. + """ + best = float("inf") + for ref in reference_dataset: + d = calculate_hamming_distance_cutoff(query, ref, best) + if d == 0: + return 0 + if d < best: + best = d + return best + + +# --- LSTM classifier ------------------------------------------------------- + + +def process_patient_data_for_lstm( + df: pd.DataFrame, + subject_col: str = "id", + visit_col: str = "time", + code_col: str = "visit_codes", + label_col: str = "labels", + code_to_idx: Optional[Dict] = None, +) -> Tuple[List[Tuple[torch.Tensor, int]], Dict]: + """Transforms a flat EHR dataframe into multi-hot visit sequences. + + Each patient is converted into a ``(seq_len, vocab_size)`` tensor of + multi-hot visit vectors, paired with a single static label (the per-patient + max of ``label_col``). + + Args: + df: Input dataframe with one row per (patient, visit, code) event. + subject_col: Column name for patient/subject identifiers. + visit_col: Column name for visit/timestep identifiers. + code_col: Column name for the medical codes. + label_col: Column name for the binary label. + code_to_idx: Optional precomputed mapping from code to integer index. + If ``None``, one is built from ``df``. + + Returns: + A tuple ``(patients, code_to_idx)`` where ``patients`` is a list of + ``(sequence_tensor, label)`` tuples. + """ + assert label_col in df.columns, f"Label column '{label_col}' not found." + assert subject_col in df.columns, f"Subject column '{subject_col}' not found." + assert visit_col in df.columns, f"Visit column '{visit_col}' not found." + + df = df.copy() + if code_to_idx is None: + vocab_size = df[code_col].nunique() + 1 + code_to_idx = { + code: idx for idx, code in enumerate(df[code_col].unique(), start=0) + } + else: + vocab_size = len(code_to_idx) + 1 + df[code_col] = df[code_col].map(code_to_idx) + + patients = [] + for _, group in df.groupby(subject_col): + # Static per-patient label: the max over visits (e.g. "ever diagnosed"). + label = group[label_col].max() + visits = group.sort_values(visit_col).groupby(visit_col) + patient_seq = [] + for _, visit_data in visits: + multi_hot = torch.zeros(vocab_size) + codes = visit_data[code_col].values + multi_hot[codes] = 1.0 + patient_seq.append(multi_hot) + patient_seq_tensor = torch.stack(patient_seq) + patients.append((patient_seq_tensor, label)) + + return patients, code_to_idx + + +def collate_fn(batch): + """Pads variable-length visit sequences for batched LSTM training. + + Args: + batch: List of ``(sequence_tensor, label)`` tuples. + + Returns: + A tuple ``(padded_seqs, lengths, labels)``. + """ + sequences, labels = zip(*batch) + lengths = torch.tensor([len(seq) for seq in sequences]) + padded_seqs = torch.nn.utils.rnn.pad_sequence( + sequences, batch_first=True, padding_value=0 + ) + labels = torch.tensor(labels, dtype=torch.float32) + return padded_seqs, lengths, labels + + +class EHRDataset(torch.utils.data.Dataset): + """A minimal :class:`torch.utils.data.Dataset` wrapper over a list.""" + + def __init__(self, data): + self.data = data + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + +class EHR_LSTM_Classifier(nn.Module): + """A simple LSTM classifier over multi-hot EHR visit sequences. + + The model embeds each multi-hot visit vector, encodes the sequence with an + LSTM, and classifies using the final hidden state. + + Args: + vocab_size: Size of the code vocabulary (input dimension per visit). + embed_dim: Dimension of the dense visit embedding. + hidden_dim: Hidden dimension of the LSTM. + num_layers: Number of stacked LSTM layers. + """ + + def __init__( + self, + vocab_size: int, + embed_dim: int, + hidden_dim: int, + num_layers: int = 1, + ): + super().__init__() + self.embedding = nn.Linear(vocab_size, embed_dim) + self.lstm = nn.LSTM( + input_size=embed_dim, + hidden_size=hidden_dim, + num_layers=num_layers, + batch_first=True, + ) + self.fc = nn.Linear(hidden_dim, 1) + self.sigmoid = nn.Sigmoid() + + def forward(self, x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor: + x = self.embedding(x) + packed_x = torch.nn.utils.rnn.pack_padded_sequence( + x, lengths.cpu(), batch_first=True, enforce_sorted=False + ) + _, (h_n, _) = self.lstm(packed_x) + final_encoding = h_n[-1] + logits = self.fc(final_encoding) + probs = self.sigmoid(logits) + return probs.squeeze(-1) + + +def train_lstm_model( + train_ehr: pd.DataFrame, + test_ehr: pd.DataFrame, + subject_col: str, + visit_col: str, + code_col: str, + label_col: str, + embed_dim: int = 32, + hidden_dim: int = 32, + batch_size: int = 32, + epochs: int = 5, + verbose: bool = True, + seed: int = 4, +) -> Tuple[nn.Module, np.ndarray, np.ndarray]: + """Trains :class:`EHR_LSTM_Classifier` and evaluates it on a test set. + + Args: + train_ehr: Training EHR dataframe. + test_ehr: Test EHR dataframe. + subject_col: Column name for patient/subject identifiers. + visit_col: Column name for visit/timestep identifiers. + code_col: Column name for the medical codes. + label_col: Column name for the binary label. + embed_dim: Visit embedding dimension. + hidden_dim: LSTM hidden dimension. + batch_size: Training/eval batch size. + epochs: Number of training epochs. + verbose: Whether to print per-epoch loss. + seed: Random seed for reproducibility. + + Returns: + A tuple ``(model, y_true, y_pred)`` where ``y_true`` and ``y_pred`` are + numpy arrays of test labels and binary predictions. + """ + torch.manual_seed(seed) + all_codes = set() + all_codes.update(train_ehr[code_col].unique().tolist()) + all_codes.update(test_ehr[code_col].unique().tolist()) + # Start indices at 1 to reserve 0 for padding. + code_to_idx = {code: idx for idx, code in enumerate(all_codes, start=1)} + + train_data, _ = process_patient_data_for_lstm( + train_ehr, subject_col, visit_col, code_col, label_col, code_to_idx + ) + test_data, _ = process_patient_data_for_lstm( + test_ehr, subject_col, visit_col, code_col, label_col, code_to_idx + ) + train_dataloader = torch.utils.data.DataLoader( + dataset=EHRDataset(train_data), + batch_size=batch_size, + collate_fn=collate_fn, + shuffle=True, + ) + test_dataloader = torch.utils.data.DataLoader( + dataset=EHRDataset(test_data), + batch_size=batch_size, + collate_fn=collate_fn, + shuffle=False, + ) + + model = EHR_LSTM_Classifier( + vocab_size=len(code_to_idx) + 1, + embed_dim=embed_dim, + hidden_dim=hidden_dim, + ) + criterion = nn.BCELoss() + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + + use_cuda = torch.cuda.is_available() + if use_cuda: + model = model.cuda() + + model.train() + for epoch in range(epochs): + total_loss = 0.0 + for batch_x, batch_lens, batch_y in train_dataloader: + optimizer.zero_grad() + if use_cuda: + batch_x, batch_y = batch_x.cuda(), batch_y.cuda() + predictions = model(batch_x, batch_lens) + loss = criterion(predictions, batch_y) + loss.backward() + optimizer.step() + total_loss += loss.item() + if verbose: + avg_loss = total_loss / max(len(train_dataloader), 1) + print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}") + + model.eval() + all_preds: List[float] = [] + all_labels: List[float] = [] + with torch.no_grad(): + for batch_x, batch_lens, batch_y in test_dataloader: + if use_cuda: + batch_x, batch_y = batch_x.cuda(), batch_y.cuda() + predictions = model(batch_x, batch_lens) + all_preds.extend(predictions.cpu().numpy()) + all_labels.extend(batch_y.cpu().numpy()) + + y_true = np.array(all_labels) + y_pred = np.array([1 if p >= 0.5 else 0 for p in all_preds]) + return model, y_true, y_pred + + +# --- Random-forest baseline ------------------------------------------------ + + +def aggregate_patient_visits( + df: pd.DataFrame, + subject_col: str, + code_col: str, + label_col: str, + code_to_idx: Dict, +) -> Tuple[np.ndarray, np.ndarray]: + """Aggregates each patient's visits into a single multi-hot vector. + + Args: + df: Input dataframe with integer-encoded codes in ``code_col``. + subject_col: Column name for patient/subject identifiers. + code_col: Column name for the (integer-encoded) medical codes. + label_col: Column name for the binary label. + code_to_idx: Mapping from code to index (used to size the vector). + + Returns: + A tuple ``(patient_vectors, patient_labels)`` of numpy arrays. + """ + patient_vectors = [] + patient_labels = [] + for _, group in df.groupby(subject_col): + codes = group[code_col].unique() + multi_hot = np.zeros(len(code_to_idx) + 1) + multi_hot[codes] = 1 + patient_vectors.append(multi_hot) + patient_labels.append(group[label_col].max()) + return np.array(patient_vectors), np.array(patient_labels) + + +def train_sklearn_model( + train_ehr: pd.DataFrame, + test_ehr: pd.DataFrame, + subject_col: str, + visit_col: str, + code_col: str, + label_col: str, + model: str = "rf", + seed: int = 4, +) -> Tuple[object, np.ndarray, np.ndarray]: + """Trains an sklearn classifier on aggregated patient-level multi-hot data. + + Args: + train_ehr: Training EHR dataframe. + test_ehr: Test EHR dataframe. + subject_col: Column name for patient/subject identifiers. + visit_col: Column name for visit/timestep identifiers (unused, kept for + a uniform signature with :func:`train_lstm_model`). + code_col: Column name for the medical codes. + label_col: Column name for the binary label. + model: Which model to train. Only ``"rf"`` (random forest) is supported. + seed: Random seed for reproducibility. + + Returns: + A tuple ``(model, y_true, y_pred)``. + """ + train_ehr = train_ehr.copy() + test_ehr = test_ehr.copy() + + all_codes = set() + all_codes.update(train_ehr[code_col].unique().tolist()) + all_codes.update(test_ehr[code_col].unique().tolist()) + # Start indices at 1 to reserve 0 for padding. + code_to_idx = {code: idx for idx, code in enumerate(all_codes, start=1)} + train_ehr[code_col] = train_ehr[code_col].map(code_to_idx) + test_ehr[code_col] = test_ehr[code_col].map(code_to_idx) + + X_train, y_train = aggregate_patient_visits( + train_ehr, subject_col, code_col, label_col, code_to_idx + ) + X_test, y_test = aggregate_patient_visits( + test_ehr, subject_col, code_col, label_col, code_to_idx + ) + + if model == "rf": + from sklearn.ensemble import RandomForestClassifier + + clf = RandomForestClassifier(n_estimators=100, random_state=seed) + else: + raise NotImplementedError(f"Model '{model}' not implemented.") + clf.fit(X_train, y_train) + + y_pred = clf.predict(X_test) + return clf, y_test, y_pred + + +# --- Task / feature construction ------------------------------------------ + + +def build_next_visit_prediction_dataset( + df: pd.DataFrame, + subject_col: str, + visit_col: str, + label_col: str, + multi_visit_sample_frac: float = 0.5, + seed: int = 4, +) -> pd.DataFrame: + """Builds a next-visit prediction task from an EHR dataframe. + + For patients with multiple visits, a fraction is sampled and their last + visit is dropped; these patients are labeled 1 (has a next visit). The + remaining multi-visit patients are kept intact and labeled 0. Single-visit + patients are labeled 0 by definition. + + Args: + df: Input EHR dataframe. + subject_col: Column name for patient/subject identifiers. + visit_col: Column name for visit/timestep identifiers. + label_col: Column name to overwrite with the next-visit label. + multi_visit_sample_frac: Fraction of multi-visit patients to truncate. + seed: Random seed for reproducibility. + + Returns: + A new dataframe with ``label_col`` set to the next-visit label. + """ + assert 0.0 <= multi_visit_sample_frac <= 1.0, ( + "multi_visit_sample_frac must be in [0, 1]." + ) + + rng = np.random.default_rng(seed) + transformed_groups = [] + + for _, group in df.groupby(subject_col): + group_sorted = group.sort_values(visit_col) + unique_visits = np.sort(group_sorted[visit_col].unique()) + n_visits = len(unique_visits) + + if n_visits <= 1: + g = group_sorted.copy() + g[label_col] = 0 + transformed_groups.append(g) + continue + + should_truncate = rng.random() < multi_visit_sample_frac + if should_truncate: + last_visit = unique_visits[-1] + g = group_sorted[group_sorted[visit_col] != last_visit].copy() + if g.empty: + # Defensive fallback for unexpected edge cases. + g = group_sorted.copy() + g[label_col] = 0 + else: + g[label_col] = 1 + else: + g = group_sorted.copy() + g[label_col] = 0 + transformed_groups.append(g) + + if len(transformed_groups) == 0: + return df.copy() + return pd.concat(transformed_groups, ignore_index=True) + + +def convert_cols_to_multihot( + df: pd.DataFrame, + code_col: str, + visit_col: str, + cat_cols: List[str], + num_cols: List[str], + bins_per_num: int = 5, +) -> pd.DataFrame: + """Folds categorical and numeric columns into per-visit multi-hot codes. + + Categorical columns are prefixed with their column name; numeric columns + are quantile-binned and likewise prefixed. All values are combined with the + original code into a single comma-separated ``combined_codes`` column. + + Args: + df: Input dataframe. + code_col: Column name for the existing medical codes. + visit_col: Column name for visit/timestep identifiers (kept for a + uniform signature; not modified). + cat_cols: Categorical column names to fold in. + num_cols: Numeric column names to bin and fold in. + bins_per_num: Number of quantile bins per numeric column. + + Returns: + A copy of ``df`` with an added ``combined_codes`` column. + """ + df = df.copy() + for col in cat_cols: + df[col] = col + "_" + df[col].astype(str) + + for col in num_cols: + df[col + "_binned"] = pd.qcut( + df[col], q=bins_per_num, duplicates="drop" + ).astype(str) + df[col + "_binned"] = col + "_" + df[col + "_binned"] + + def combine_codes(row): + codes = [str(row[code_col])] + for col in cat_cols: + codes.append(str(row[col])) + for col in num_cols: + codes.append(str(row[col + "_binned"])) + return ",".join(codes) + + df["combined_codes"] = df.apply(combine_codes, axis=1) + return df diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 4c168d3e3..18500b9c0 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -45,4 +45,9 @@ from .sdoh import SdohClassifier from .medlink import MedLink from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding -from .califorest import CaliForest \ No newline at end of file +from .califorest import CaliForest +from .generators.halo import HALO +from .generators.gpt2 import GPT2 +from .generators.promptehr import PromptEHR +from .generators.medgan import MedGAN +from .generators.corgan import CorGAN \ No newline at end of file diff --git a/pyhealth/models/generators/__init__.py b/pyhealth/models/generators/__init__.py new file mode 100644 index 000000000..a65bbd975 --- /dev/null +++ b/pyhealth/models/generators/__init__.py @@ -0,0 +1,7 @@ +from .halo import HALO +from .gpt2 import GPT2 +from .promptehr import PromptEHR +from .medgan import MedGAN +from .corgan import CorGAN + +__all__ = ["HALO", "GPT2", "PromptEHR", "MedGAN", "CorGAN"] diff --git a/pyhealth/models/generators/corgan.py b/pyhealth/models/generators/corgan.py new file mode 100644 index 000000000..8652c983a --- /dev/null +++ b/pyhealth/models/generators/corgan.py @@ -0,0 +1,683 @@ +"""CorGAN: Correlation-capturing GAN for synthetic EHR generation. + +This is a port of the reference implementation +(https://github.com/astorfi/cor-gan, specifically +``reference/cor-gan/Generative/corGAN/pytorch/CNN/MIMIC/wgancnnmimic.py``) +wrapped as a PyHealth ``BaseModel`` so it consumes the standard +``dataset -> SampleDataset -> model`` pipeline. + +CorGAN treats each patient as a flat bag-of-codes (no visit structure), so it +expects an input feature named ``visits`` backed by a ``MultiHotProcessor``. +Training has two phases (mirroring the reference): + +* a **convolutional autoencoder** is pre-trained with a sparse-friendly BCE + reconstruction loss, then +* a **WGAN** adversarial phase runs the generator + decoder against a + Lipschitz-clipped critic (no sigmoid; weight clipping in + ``[clamp_lower, clamp_upper]``). + +For tiny vocabularies that can't survive the 6-layer convolutional chain we +automatically fall back to the linear-autoencoder variant noted in the +reference's commented-out alternative. The public ``CorGAN`` class follows the +same API style as :class:`pyhealth.models.generators.HALO` +(``train_model`` / ``generate`` / ``save_model`` / ``load_model``). +""" + +import os +from typing import Dict, List, Optional + +import numpy as np +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, Dataset, RandomSampler +from tqdm import tqdm + +from pyhealth.models import BaseModel + + +# ---------------------------------------------------------------------------- +# Building blocks (ported from reference wgancnnmimic.py) +# ---------------------------------------------------------------------------- +class _MultiHotDataset(Dataset): + """Tiny ``torch.utils.data.Dataset`` over a multi-hot numpy matrix.""" + + def __init__(self, data: np.ndarray): + self.data = np.clip(data.astype(np.float32), 0.0, 1.0) + + def __len__(self) -> int: + return len(self.data) + + def __getitem__(self, idx): + return torch.from_numpy(self.data[idx]) + + +class CorGANCNNAutoencoder(nn.Module): + """1D-CNN autoencoder from the reference CorGAN paper. + + Six 1D conv layers compress the multi-hot vector down to a tiny latent; + six transposed-conv layers project back. When ``use_adaptive_pooling`` is + True we tack on an ``AdaptiveAvgPool1d`` so the decoder hits the exact + vocabulary size for any input dim (the original CNN was hard-coded for + MIMIC's vocabulary). + + Args: + feature_size: Vocabulary size. + use_adaptive_pooling: If True, force decoder output to ``feature_size`` + via adaptive average pooling. Default: True. + """ + + def __init__(self, feature_size: int, use_adaptive_pooling: bool = True): + super().__init__() + self.feature_size = feature_size + self.use_adaptive_pooling = use_adaptive_pooling + c = 4 # n_channels_base, per reference + + # Encoder: kernels (5,5,5,5,5,8), strides (2,2,3,3,3,1). + self.encoder = nn.Sequential( + nn.Conv1d(1, c, kernel_size=5, stride=2), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv1d(c, 2 * c, kernel_size=5, stride=2), + nn.BatchNorm1d(2 * c), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv1d(2 * c, 4 * c, kernel_size=5, stride=3), + nn.BatchNorm1d(4 * c), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv1d(4 * c, 8 * c, kernel_size=5, stride=3), + nn.BatchNorm1d(8 * c), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv1d(8 * c, 16 * c, kernel_size=5, stride=3), + nn.BatchNorm1d(16 * c), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv1d(16 * c, 32 * c, kernel_size=8, stride=1), + nn.Tanh(), + ) + + # Decoder: kernels (5,5,7,7,7,3), strides (1,4,4,3,2,2). First layer + # has NO BatchNorm, matching the reference. The original was hard-coded + # for MIMIC's vocabulary; adaptive pooling rescales the output to any + # feature_size for downstream Sigmoid binarisation. + decoder_layers = [ + nn.ConvTranspose1d(32 * c, 16 * c, kernel_size=5, stride=1), + nn.ReLU(), + nn.ConvTranspose1d(16 * c, 8 * c, kernel_size=5, stride=4), + nn.BatchNorm1d(8 * c), + nn.ReLU(), + nn.ConvTranspose1d(8 * c, 4 * c, kernel_size=7, stride=4), + nn.BatchNorm1d(4 * c), + nn.ReLU(), + nn.ConvTranspose1d(4 * c, 2 * c, kernel_size=7, stride=3), + nn.BatchNorm1d(2 * c), + nn.ReLU(), + nn.ConvTranspose1d(2 * c, c, kernel_size=7, stride=2), + nn.BatchNorm1d(c), + nn.ReLU(), + nn.ConvTranspose1d(c, 1, kernel_size=3, stride=2), + ] + if use_adaptive_pooling: + decoder_layers.append(nn.AdaptiveAvgPool1d(output_size=feature_size)) + decoder_layers.append(nn.Sigmoid()) + self.decoder = nn.Sequential(*decoder_layers) + + def forward(self, x): + # Allow either (B, F) or (B, 1, F) input. + if x.dim() == 2: + x = x.unsqueeze(1) + decoded = self.decoder(self.encoder(x)) + if decoded.dim() == 3 and decoded.shape[1] == 1: + decoded = decoded.squeeze(1) + return decoded + + def decode(self, latent): + """Decode a latent emitted by the generator. + + The generator outputs ``(B, hidden_dim)``; we add a length-1 spatial + axis so the transposed-conv stack accepts it. + """ + if latent.dim() == 2: + latent = latent.unsqueeze(2) + decoded = self.decoder(latent) + if decoded.dim() == 3 and decoded.shape[1] == 1: + decoded = decoded.squeeze(1) + return decoded + + +class CorGANLinearAutoencoder(nn.Module): + """Linear autoencoder, the reference's commented-out alternative. + + Used as a fallback for small vocabularies where the 6-layer CNN can't + physically compress the input (its smallest viable input is ~500 + features). For unordered code spaces this is often a stronger baseline + anyway, since 1D conv assumes spatial locality. + """ + + def __init__(self, feature_size: int, latent_dim: int = 128): + super().__init__() + self.feature_size = feature_size + self.encoder = nn.Sequential( + nn.Linear(feature_size, latent_dim), + nn.ReLU(), + nn.BatchNorm1d(latent_dim), + ) + self.decoder = nn.Sequential( + nn.Linear(latent_dim, feature_size), + nn.Sigmoid(), + ) + + def forward(self, x): + return self.decoder(self.encoder(x)) + + def decode(self, latent): + return self.decoder(latent) + + +class CorGANGenerator(nn.Module): + """Two-layer MLP generator with residual connections (per reference).""" + + def __init__(self, latent_dim: int = 128, hidden_dim: int = 128): + super().__init__() + self.linear1 = nn.Linear(latent_dim, hidden_dim) + self.bn1 = nn.BatchNorm1d(hidden_dim, eps=0.001, momentum=0.01) + self.act1 = nn.ReLU() + + self.linear2 = nn.Linear(hidden_dim, hidden_dim) + self.bn2 = nn.BatchNorm1d(hidden_dim, eps=0.001, momentum=0.01) + self.act2 = nn.Tanh() + + def forward(self, x): + residual = x + out = self.act1(self.bn1(self.linear1(x))) + residual + + residual = out + out = self.act2(self.bn2(self.linear2(out))) + residual + return out + + +class CorGANCritic(nn.Module): + """4-layer MLP critic with optional minibatch averaging. + + No sigmoid at the output: this is a Wasserstein critic (not a + classifier), per the reference WGAN training loop. + """ + + def __init__( + self, + input_dim: int, + hidden_dim: int = 256, + minibatch_averaging: bool = True, + ): + super().__init__() + self.minibatch_averaging = minibatch_averaging + model_input_dim = input_dim * 2 if minibatch_averaging else input_dim + + self.model = nn.Sequential( + nn.Linear(model_input_dim, hidden_dim), + nn.ReLU(True), + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(True), + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(True), + nn.Linear(hidden_dim, 1), + ) + + def forward(self, x): + if self.minibatch_averaging: + x_mean = torch.mean(x, dim=0).repeat(x.shape[0], 1) + x = torch.cat((x, x_mean), dim=1) + return self.model(x) + + +def _weights_init(m): + """Reference initialization scheme (wgancnnmimic.py).""" + name = m.__class__.__name__ + if "Conv" in name: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif "BatchNorm" in name: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + m.bias.data.fill_(0.01) + + +def _autoencoder_loss(x_output, y_target): + """Sparse-friendly BCE used by the reference CorGAN autoencoder. + + Sum over features, then mean over batch -- equivalent to + ``BCELoss(reduction='sum') / batch_size``. ``BCELoss(reduction='mean')`` + additionally means over features and dilutes the signal for sparse + multi-hot targets. + """ + epsilon = 1e-12 + term = y_target * torch.log(x_output + epsilon) + ( + 1.0 - y_target + ) * torch.log(1.0 - x_output + epsilon) + return torch.mean(-torch.sum(term, dim=1), dim=0) + + +# Minimum feature size the 6-layer CNN encoder can survive without producing a +# non-positive spatial dimension. The reference targets MIMIC's ~7k vocabulary; +# the conv chain (kernels 5,5,5,5,5,8 / strides 2,2,3,3,3,1) only produces a +# positive output for inputs >= ~1000. We pick a slightly conservative floor. +_CNN_MIN_FEATURES = 1000 + +# The reference CNN autoencoder's bottleneck has 32 * n_channels_base (= 32*4) +# output channels. The generator output is fed in as those channels, so this +# is fixed by architecture. +_CNN_BOTTLENECK_DIM = 32 * 4 + + +# ---------------------------------------------------------------------------- +# PyHealth BaseModel wrapper +# ---------------------------------------------------------------------------- +class CorGAN(BaseModel): + """CorGAN synthetic-EHR generator, wrapped as a PyHealth ``BaseModel``. + + Trains a 1D-convolutional autoencoder + WGAN generator/critic on multi-hot + patient vectors and generates new synthetic patients by sampling noise, + pushing it through the generator, and decoding back with the (jointly + trained) decoder. + + Generation is **unconditional**: each synthetic patient is a flat bag of + codes (no visit structure), matching the ``multi_hot`` input schema. + + Args: + dataset: A fitted ``SampleDataset`` whose ``input_schema`` contains + ``{"visits": "multi_hot"}`` and whose ``output_schema`` is empty. + latent_dim: Generator noise dimensionality. Default: 128. + hidden_dim: Generator hidden width. Default: 128. + discriminator_hidden_dim: Critic hidden width. Default: 256. + minibatch_averaging: Concatenate per-batch mean to each critic input. + Default: True. + autoencoder_type: One of ``"cnn"`` (the reference) or ``"linear"`` + (the reference's commented-out alternative). For vocabularies + smaller than ~500 the CNN cannot compress the input, so we + silently fall back to ``"linear"``. Default: ``"cnn"``. + use_adaptive_pooling: When using the CNN autoencoder, add an + ``AdaptiveAvgPool1d`` so the decoder matches any vocabulary size. + Ignored for the linear variant. Default: True. + batch_size: Training batch size. Default: 512. + ae_epochs: Autoencoder pre-training epochs. Default: 100. + gan_epochs: Adversarial training epochs. Default: 200. + lr: Learning rate for all optimizers. Default: 1e-3. + weight_decay: L2 regularisation for all Adam optimizers. Default: 1e-4. + b1: Adam beta1. Default: 0.9. + b2: Adam beta2. Default: 0.999. + n_iter_D: Critic updates per generator update (reference: 5). + clamp_lower: WGAN critic weight-clip lower bound. Default: -0.01. + clamp_upper: WGAN critic weight-clip upper bound. Default: 0.01. + save_dir: Checkpoint directory used by ``train_model``. + Default: ``"./save/corgan/"``. + + Examples: + >>> from pyhealth.datasets import create_sample_dataset + >>> samples = [ + ... {"patient_id": "p1", "visits": ["A", "B", "C"]}, + ... {"patient_id": "p2", "visits": ["A", "C", "D"]}, + ... ] + >>> dataset = create_sample_dataset( + ... samples=samples, + ... input_schema={"visits": "multi_hot"}, + ... output_schema={}, + ... ) + >>> model = CorGAN(dataset, latent_dim=16, hidden_dim=16, batch_size=2) + >>> isinstance(model, CorGAN) + True + """ + + def __init__( + self, + dataset, + latent_dim: int = 128, + hidden_dim: int = 128, + discriminator_hidden_dim: int = 256, + minibatch_averaging: bool = True, + autoencoder_type: str = "cnn", + use_adaptive_pooling: bool = True, + batch_size: int = 512, + ae_epochs: int = 100, + gan_epochs: int = 200, + lr: float = 1e-3, + weight_decay: float = 1e-4, + b1: float = 0.9, + b2: float = 0.999, + n_iter_D: int = 5, + clamp_lower: float = -0.01, + clamp_upper: float = 0.01, + save_dir: str = "./save/corgan/", + ) -> None: + super().__init__(dataset) + + if "visits" not in dataset.input_processors: + raise ValueError( + "CorGAN expects an input feature named 'visits' backed by a " + "MultiHotProcessor." + ) + + self._batch_size = batch_size + self._ae_epochs = ae_epochs + self._gan_epochs = gan_epochs + self._lr = lr + self._weight_decay = weight_decay + self._betas = (b1, b2) + self._n_iter_D = n_iter_D + self._clamp_lower = clamp_lower + self._clamp_upper = clamp_upper + self.save_dir = save_dir + + # Code vocab from the MultiHotProcessor's label_vocab. + self.visits_processor = dataset.input_processors["visits"] + self.input_dim = self.visits_processor.size() + self._idx_to_code: List[Optional[str]] = [None] * self.input_dim + for code, idx in self.visits_processor.label_vocab.items(): + self._idx_to_code[idx] = code + + # CNN can't compress small vocabularies; fall back to linear. + if autoencoder_type == "cnn" and self.input_dim < _CNN_MIN_FEATURES: + autoencoder_type = "linear" + self.autoencoder_type = autoencoder_type + + if autoencoder_type == "linear": + # Linear AE: bottleneck = generator hidden dim (user-controlled). + self.autoencoder = CorGANLinearAutoencoder( + feature_size=self.input_dim, latent_dim=hidden_dim + ) + elif autoencoder_type == "cnn": + # CNN AE: bottleneck is fixed at 128 by the conv-channel ladder. + # The generator must emit that many features so the transposed-conv + # decoder accepts its output. We silently align hidden_dim to the + # CNN bottleneck (the reference always uses 128). + if hidden_dim != _CNN_BOTTLENECK_DIM: + hidden_dim = _CNN_BOTTLENECK_DIM + self.autoencoder = CorGANCNNAutoencoder( + feature_size=self.input_dim, + use_adaptive_pooling=use_adaptive_pooling, + ) + else: + raise ValueError( + f"Unknown autoencoder_type={autoencoder_type!r}; " + "expected 'cnn' or 'linear'." + ) + + # The generator's residual connection requires latent_dim == hidden_dim + # (per the reference). Align silently if the user mismatched. + if latent_dim != hidden_dim: + latent_dim = hidden_dim + self.latent_dim = latent_dim + self.hidden_dim = hidden_dim + + self.generator = CorGANGenerator( + latent_dim=latent_dim, hidden_dim=hidden_dim + ) + self.critic = CorGANCritic( + input_dim=self.input_dim, + hidden_dim=discriminator_hidden_dim, + minibatch_averaging=minibatch_averaging, + ) + + self.autoencoder.apply(_weights_init) + self.generator.apply(_weights_init) + self.critic.apply(_weights_init) + + # ------------------------------------------------------------------ + # forward -- required by BaseModel + # ------------------------------------------------------------------ + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + """CorGAN does not have a single supervised forward pass. + + Use :meth:`train_model` for training and :meth:`generate` for + synthesis. ``forward`` is implemented only to satisfy the + ``BaseModel`` abstract contract. + """ + raise NotImplementedError( + "CorGAN is a GAN: use train_model() and generate() instead of " + "forward()." + ) + + # ------------------------------------------------------------------ + # Custom training loop + # ------------------------------------------------------------------ + @staticmethod + def _resolve_device(device=None) -> torch.device: + """Resolve a user-supplied device, defaulting to CUDA when available.""" + if device is None: + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + return torch.device(device) + + def _build_dataloader(self, dataset) -> DataLoader: + """Stack the multi-hot tensors of ``dataset`` into a DataLoader.""" + tensors = [dataset[i]["visits"] for i in range(len(dataset))] + matrix = torch.stack(tensors).numpy() + wrapped = _MultiHotDataset(matrix) + sampler = RandomSampler(wrapped, replacement=True) + return DataLoader( + wrapped, + batch_size=self._batch_size, + shuffle=False, + num_workers=0, + drop_last=True, + sampler=sampler, + ) + + def train_model(self, train_dataset, val_dataset=None, device=None) -> Dict: + """Train CorGAN with a custom two-phase loop. + + Named ``train_model`` (not ``train``) to avoid shadowing + ``nn.Module.train()``. Phase 1 pre-trains the autoencoder with + sparse BCE reconstruction loss; phase 2 runs WGAN adversarial + training (weight-clipped critic, joint generator + decoder). + + Args: + train_dataset: ``SampleDataset`` for training. + val_dataset: Unused; accepted for API symmetry. + device: Device to train on. If ``None``, uses CUDA when available. + + Returns: + Dict with keys ``"autoencoder_loss"``, ``"critic_loss"``, + ``"generator_loss"`` -- one float per epoch in each list. + """ + device = self._resolve_device(device) + self.to(device) + print(f"Training CorGAN on: {device}") + + os.makedirs(self.save_dir, exist_ok=True) + dataloader = self._build_dataloader(train_dataset) + history: Dict[str, List[float]] = { + "autoencoder_loss": [], + "critic_loss": [], + "generator_loss": [], + } + + # ---- Phase 1: Autoencoder pretraining ---- + optimizer_ae = torch.optim.Adam( + self.autoencoder.parameters(), + lr=self._lr, + betas=self._betas, + weight_decay=self._weight_decay, + ) + for epoch in tqdm(range(self._ae_epochs), desc="AE pretrain"): + self.autoencoder.train() + last_loss = 0.0 + for batch in dataloader: + real = batch.to(self.device) + recon = self.autoencoder(real) + loss = _autoencoder_loss(recon, real) + + optimizer_ae.zero_grad() + loss.backward() + optimizer_ae.step() + last_loss = loss.item() + history["autoencoder_loss"].append(last_loss) + + # ---- Phase 2: WGAN adversarial training ---- + # The reference jointly optimises the generator and the autoencoder's + # decoder, with a smaller LR on the decoder. We reuse that scheme. + g_params = [ + {"params": self.generator.parameters()}, + {"params": self.autoencoder.decoder.parameters(), "lr": 1e-4}, + ] + optimizer_g = torch.optim.Adam( + g_params, + lr=self._lr, + betas=self._betas, + weight_decay=self._weight_decay, + ) + optimizer_d = torch.optim.Adam( + self.critic.parameters(), + lr=self._lr, + betas=self._betas, + weight_decay=self._weight_decay, + ) + one = torch.tensor(1.0, device=self.device) + mone = torch.tensor(-1.0, device=self.device) + gen_iters = 0 + + for epoch in tqdm(range(self._gan_epochs), desc="GAN train"): + self.generator.train() + self.critic.train() + self.autoencoder.eval() + self.autoencoder.decoder.train() + + last_d, last_g = 0.0, 0.0 + for real in dataloader: + real = real.to(self.device) + bs = real.size(0) + + # --- Train critic --- + for p in self.critic.parameters(): + p.requires_grad = True + # Reference: ramp up critic iterations at the start and at + # periodic intervals to keep the Wasserstein estimate tight. + n_iter_D = ( + 100 if (gen_iters < 25 or gen_iters % 500 == 0) + else self._n_iter_D + ) + for _ in range(n_iter_D): + for p in self.critic.parameters(): + p.data.clamp_(self._clamp_lower, self._clamp_upper) + + optimizer_d.zero_grad() + errD_real = torch.mean(self.critic(real)).squeeze() + errD_real.backward(one) + + z = torch.randn(bs, self.latent_dim, device=self.device) + fake = self.autoencoder.decode(self.generator(z)) + errD_fake = torch.mean(self.critic(fake.detach())).squeeze() + errD_fake.backward(mone) + last_d = (errD_real - errD_fake).item() + + optimizer_d.step() + + # --- Train generator --- + for p in self.critic.parameters(): + p.requires_grad = False + optimizer_g.zero_grad() + z = torch.randn(bs, self.latent_dim, device=self.device) + fake = self.autoencoder.decode(self.generator(z)) + errG = torch.mean(self.critic(fake)).squeeze() + errG.backward(one) + optimizer_g.step() + last_g = errG.item() + gen_iters += 1 + + history["critic_loss"].append(last_d) + history["generator_loss"].append(last_g) + + self.save_model(os.path.join(self.save_dir, "final.pt")) + return history + + # ------------------------------------------------------------------ + # Synthesis + # ------------------------------------------------------------------ + def generate( + self, + num_samples: int, + random_sampling: bool = False, + device=None, + ) -> List[Dict]: + """Generate synthetic patient records. + + Each synthetic patient is decoded from a generated multi-hot vector + by thresholding (or, optionally, Bernoulli sampling) at 0.5 and + mapping the indices back to code strings. + + Args: + num_samples: Number of synthetic patients to generate. + random_sampling: If True, Bernoulli-sample the decoder output; + otherwise threshold at 0.5 (the reference's behaviour). + Default: False. + device: Device to generate on. If ``None``, uses CUDA when + available. + + Returns: + List of dicts + ``{"patient_id": "synthetic_i", "visits": [[code, ...]]}``. + ``visits`` is a list containing a **single** visit (matching + HALO's nested-list output structure). CorGAN is a bag-of-codes + model -- following the reference preprocessing, each patient is + represented by the union of codes across all of their + historical visits -- so the single inner list is that aggregate + bag. The inner list may be empty if the generator produced an + all-zero vector. + """ + device = self._resolve_device(device) + self.to(device) + + self.generator.eval() + self.autoencoder.eval() + + bs = min(self._batch_size, max(num_samples, 1)) + rows = np.zeros((num_samples, self.input_dim), dtype=np.float32) + pbar = tqdm(total=num_samples, desc="Generating patients") + with torch.no_grad(): + i = 0 + while i < num_samples: + cur = min(bs, num_samples - i) + z = torch.randn(cur, self.latent_dim, device=self.device) + probs = self.autoencoder.decode(self.generator(z)) + if random_sampling: + sample = torch.bernoulli(probs) + else: + sample = (probs >= 0.5).float() + rows[i : i + cur] = sample.cpu().numpy() + i += cur + pbar.update(cur) + pbar.close() + + results: List[Dict] = [] + for i in range(num_samples): + codes = [ + self._idx_to_code[idx] + for idx in np.nonzero(rows[i])[0] + if self._idx_to_code[idx] not in (None, "", "") + ] + # Wrap in a single-visit list to mirror HALO's nested output. + # CorGAN models the patient as one aggregate bag of codes. + results.append({"patient_id": f"synthetic_{i}", "visits": [codes]}) + return results + + # ------------------------------------------------------------------ + # Checkpoint I/O + # ------------------------------------------------------------------ + def save_model(self, path: str) -> None: + """Save weights, vocabulary, and architecture metadata.""" + torch.save( + { + "autoencoder": self.autoencoder.state_dict(), + "generator": self.generator.state_dict(), + "critic": self.critic.state_dict(), + "autoencoder_type": self.autoencoder_type, + "input_dim": self.input_dim, + "latent_dim": self.latent_dim, + "idx_to_code": self._idx_to_code, + }, + path, + ) + + def load_model(self, path: str) -> None: + """Load weights and vocabulary from a checkpoint.""" + ckpt = torch.load(path, map_location=self.device) + self.autoencoder.load_state_dict(ckpt["autoencoder"]) + self.generator.load_state_dict(ckpt["generator"]) + self.critic.load_state_dict(ckpt["critic"]) + if "idx_to_code" in ckpt: + self._idx_to_code = ckpt["idx_to_code"] diff --git a/pyhealth/models/generators/gpt2.py b/pyhealth/models/generators/gpt2.py new file mode 100644 index 000000000..99b3bf92a --- /dev/null +++ b/pyhealth/models/generators/gpt2.py @@ -0,0 +1,366 @@ +"""GPT-2 baseline for unconditional synthetic EHR generation. + +A simple decoder-only baseline that mirrors the standalone reference script +``generate_synthetic_mimic3_gpt2.py`` (``--mode transformer_baseline``) but +plugged into the standard PyHealth ``dataset -> set_task -> SampleDataset -> +model`` pipeline. It consumes the same :class:`~pyhealth.tasks.EHRGeneration` +task as :class:`~pyhealth.models.HALO`. + +Each patient's visits are flattened into a single token stream:: + + [BOS] [VISIT_DELIM] ... [EOS] + +A small :class:`transformers.GPT2LMHeadModel` is trained on these streams with +causal language modeling. Generation autoregressively samples a token stream +(``do_sample`` + top-k/top-p) and decodes it back into per-visit code lists, +splitting on the ``[VISIT_DELIM]`` token. + +The code vocabulary is taken from the dataset's ``NestedSequenceProcessor`` +(which already reserves index 0 for ```` and index 1 for ````); three +special tokens (BOS, EOS, VISIT_DELIM) are appended, and ```` (index 0) is +reused as the padding token. +""" + +import os +from typing import Dict, List, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from tqdm import tqdm +from transformers import GPT2Config, GPT2LMHeadModel + +from pyhealth.datasets import get_dataloader +from pyhealth.models import BaseModel + + +class GPT2(BaseModel): + """GPT-2 baseline synthetic-EHR generator, wrapped as a PyHealth ``BaseModel``. + + Args: + dataset: A fitted ``SampleDataset`` whose ``input_schema`` contains + ``{"visits": NestedSequenceProcessor}`` and whose ``output_schema`` + is empty. + embed_dim: GPT-2 embedding dimension (``n_embd``). Must be divisible by + ``n_heads``. Default: 512. + n_heads: Number of attention heads. Default: 8. + n_layers: Number of transformer layers. Default: 8. + max_len: Maximum token-stream length (``n_positions``); streams are + truncated to this length. Default: 512. + batch_size: Training batch size. Default: 64. + epochs: Number of training epochs. Default: 50. + lr: Learning rate for the Adam optimizer. Default: 1e-4. + save_dir: Directory for checkpoints written by ``train_model``. + Default: ``"./save/"``. + + Examples: + >>> from pyhealth.datasets import create_sample_dataset + >>> samples = [ + ... {"patient_id": "p1", "visits": [["A", "B"], ["C"]]}, + ... {"patient_id": "p2", "visits": [["A"], ["B", "C"]]}, + ... ] + >>> dataset = create_sample_dataset( + ... samples=samples, + ... input_schema={"visits": "nested_sequence"}, + ... output_schema={}, + ... ) + >>> model = GPT2(dataset, embed_dim=16, n_heads=2, n_layers=2, max_len=64) + >>> isinstance(model, GPT2) + True + """ + + def __init__( + self, + dataset, + embed_dim: int = 512, + n_heads: int = 8, + n_layers: int = 8, + max_len: int = 512, + batch_size: int = 64, + epochs: int = 50, + lr: float = 1e-4, + save_dir: str = "./save/", + ) -> None: + super(GPT2, self).__init__(dataset) + + if "visits" not in dataset.input_processors: + raise ValueError( + "GPT2 expects an input feature named 'visits' backed by a " + "NestedSequenceProcessor." + ) + + self.save_dir = save_dir + self._batch_size = batch_size + self._epochs = epochs + self._lr = lr + self.max_len = max_len + + # Code vocab from the NestedSequenceProcessor (includes =0, =1). + self.visits_processor = dataset.input_processors["visits"] + self.code_vocab_size = self.visits_processor.vocab_size() + # Append three special tokens after the code vocab; reuse =0 as PAD. + self.bos_id = self.code_vocab_size + self.eos_id = self.code_vocab_size + 1 + self.delim_id = self.code_vocab_size + 2 + self.pad_id = 0 + total_vocab_size = self.code_vocab_size + 3 + + config = GPT2Config( + vocab_size=total_vocab_size, + n_positions=max_len, + n_embd=embed_dim, + n_layer=n_layers, + n_head=n_heads, + bos_token_id=self.bos_id, + eos_token_id=self.eos_id, + ) + # Registered as a sub-module so .parameters()/.to() work. + self.gpt2 = GPT2LMHeadModel(config) + + # ------------------------------------------------------------------ + @staticmethod + def _resolve_device(device=None) -> torch.device: + """Resolve a user-supplied device, defaulting to CUDA when available.""" + if device is None: + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + return torch.device(device) + + # ------------------------------------------------------------------ + # Visit index tensor -> flat causal-LM token stream + # ------------------------------------------------------------------ + def _encode_visits(self, visits: torch.Tensor): + """Flatten the padded visit-index tensor into causal-LM token streams. + + Args: + visits: LongTensor ``(batch, max_visits, max_codes_per_visit)`` from + the ``NestedSequenceProcessor``. Index 0 is ```` and is + skipped. + + Returns: + input_ids: LongTensor ``(batch, L)`` token streams, right-padded. + attention_mask: LongTensor ``(batch, L)`` (1 for real tokens). + labels: ``input_ids`` with padding positions set to ``-100`` so they + are ignored by the cross-entropy loss. + """ + batch_seqs: List[List[int]] = [] + for i in range(visits.shape[0]): + n_visits = int((visits[i].sum(dim=-1) > 0).sum().item()) + seq: List[int] = [self.bos_id] + for j in range(n_visits): + codes = [int(c) for c in visits[i, j].tolist() if c > 0] + seq.extend(codes) + if j < n_visits - 1: + seq.append(self.delim_id) + seq.append(self.eos_id) + batch_seqs.append(seq[: self.max_len]) + + length = max(len(s) for s in batch_seqs) + input_ids = torch.full( + (len(batch_seqs), length), self.pad_id, dtype=torch.long, device=self.device + ) + attention_mask = torch.zeros( + (len(batch_seqs), length), dtype=torch.long, device=self.device + ) + for i, seq in enumerate(batch_seqs): + input_ids[i, : len(seq)] = torch.tensor(seq, device=self.device) + attention_mask[i, : len(seq)] = 1 + + labels = input_ids.clone() + labels[attention_mask == 0] = -100 + return input_ids, attention_mask, labels + + # ------------------------------------------------------------------ + # forward -- required by BaseModel + # ------------------------------------------------------------------ + def forward(self, visits: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: + """Forward pass. + + Args: + visits: LongTensor ``(batch, max_visits, max_codes_per_visit)`` from + the ``NestedSequenceProcessor``. + **kwargs: Any other batch keys are ignored. + + Returns: + Dict with ``loss`` (scalar causal-LM cross-entropy) and ``y_prob`` + (next-token probabilities, shape ``(batch, L, vocab_size)``). + """ + visits = visits.to(self.device) + input_ids, attention_mask, labels = self._encode_visits(visits) + out = self.gpt2( + input_ids=input_ids, attention_mask=attention_mask, labels=labels + ) + return {"loss": out.loss, "y_prob": F.softmax(out.logits, dim=-1)} + + # ------------------------------------------------------------------ + # Custom training loop + # ------------------------------------------------------------------ + def train_model(self, train_dataset, val_dataset=None, device=None) -> None: + """Train the GPT-2 baseline with a custom loop. + + Named ``train_model`` (not ``train``) to avoid shadowing + ``nn.Module.train()``. Uses the standard ``get_dataloader``, an Adam + optimizer, and causal-LM loss. When ``val_dataset`` is given, validation + loss is computed after each epoch and the best checkpoint is saved to + ``self.save_dir``. + + Args: + train_dataset: ``SampleDataset`` for training. + val_dataset: Optional ``SampleDataset`` for validation. + device: Device to train on, e.g. ``"cuda"``, ``"cuda:1"``, or + ``"cpu"``. If ``None`` (default), uses CUDA when available and + falls back to CPU. + """ + device = self._resolve_device(device) + self.to(device) + print(f"Training on: {device}") + + os.makedirs(self.save_dir, exist_ok=True) + optimizer = torch.optim.Adam(self.gpt2.parameters(), lr=self._lr) + + checkpoint_path = os.path.join(self.save_dir, "gpt2_model") + if os.path.exists(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location=self.device) + self.gpt2.load_state_dict(checkpoint["model"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + + train_loader = get_dataloader( + train_dataset, batch_size=self._batch_size, shuffle=True + ) + + global_loss = 1e10 + for epoch in tqdm(range(self._epochs), desc="Epochs"): + self.gpt2.train() + batch_iter = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False) + for batch in batch_iter: + visits = batch["visits"].to(self.device) + input_ids, attention_mask, labels = self._encode_visits(visits) + + optimizer.zero_grad() + out = self.gpt2( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + ) + out.loss.backward() + optimizer.step() + batch_iter.set_postfix(loss=f"{out.loss.item():.4f}") + + if val_dataset is not None: + self.gpt2.eval() + val_loader = get_dataloader( + val_dataset, batch_size=self._batch_size, shuffle=False + ) + val_losses = [] + with torch.no_grad(): + for val_batch in val_loader: + visits = val_batch["visits"].to(self.device) + input_ids, attention_mask, labels = self._encode_visits(visits) + out = self.gpt2( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + ) + val_losses.append(out.loss.item()) + + cur_val_loss = float(np.mean(val_losses)) + print(f"Epoch {epoch} Validation Loss: {cur_val_loss:.7f}") + if cur_val_loss < global_loss: + global_loss = cur_val_loss + state = { + "model": self.gpt2.state_dict(), + "optimizer": optimizer.state_dict(), + "epoch": epoch, + } + torch.save(state, checkpoint_path) + print("------------ Save best model ------------") + + # ------------------------------------------------------------------ + # Synthesis + # ------------------------------------------------------------------ + def _decode_ids(self, ids: List[int], index_to_code: Dict[int, str]) -> List[List[str]]: + """Decode a generated token stream into per-visit code lists.""" + visits_out: List[List[str]] = [] + current: List[str] = [] + for tid in ids: + if tid in (self.bos_id, self.pad_id): + continue + if tid == self.eos_id: + break + if tid == self.delim_id: + if current: + visits_out.append(current) + current = [] + continue + if tid < self.code_vocab_size: + code = index_to_code.get(int(tid)) + if code not in (None, "", ""): + current.append(code) + if current: + visits_out.append(current) + return visits_out + + def generate( + self, + num_samples: int, + device=None, + top_k: int = 50, + top_p: float = 0.95, + ) -> List[Dict]: + """Generate synthetic patients with the trained GPT-2 baseline. + + Feeds a ``[BOS]`` token and autoregressively samples a token stream with + ``top_k``/``top_p`` sampling, then decodes it into per-visit code lists. + + Args: + num_samples: Number of synthetic patients to generate. + device: Device to generate on, e.g. ``"cuda"``, ``"cuda:1"``, or + ``"cpu"``. If ``None`` (default), uses CUDA when available and + falls back to CPU. + top_k: Top-k sampling cutoff. Default: 50. + top_p: Nucleus (top-p) sampling cutoff. Default: 0.95. + + Returns: + List of dicts, each ``{"patient_id": "synthetic_i", + "visits": [[code, ...], ...]}`` with decoded code strings. + """ + device = self._resolve_device(device) + self.to(device) + + index_to_code = {v: k for k, v in self.visits_processor.code_vocab.items()} + + self.gpt2.eval() + synthetic_dataset: List[Dict] = [] + sample_batch_size = min(num_samples, 256) + generated = 0 + pbar = tqdm(total=num_samples, desc="Generating patients") + + with torch.no_grad(): + while generated < num_samples: + bs = min(sample_batch_size, num_samples - generated) + input_ids = torch.full( + (bs, 1), self.bos_id, dtype=torch.long, device=self.device + ) + out_ids = self.gpt2.generate( + input_ids, + max_length=self.max_len, + do_sample=True, + top_k=top_k, + top_p=top_p, + pad_token_id=self.pad_id, + eos_token_id=self.eos_id, + ) + for i in range(bs): + visits_out = self._decode_ids( + out_ids[i].tolist(), index_to_code + ) + synthetic_dataset.append( + { + "patient_id": f"synthetic_{generated + i}", + "visits": visits_out, + } + ) + generated += bs + pbar.update(bs) + pbar.close() + + return synthetic_dataset diff --git a/pyhealth/models/generators/halo.py b/pyhealth/models/generators/halo.py new file mode 100644 index 000000000..374c14000 --- /dev/null +++ b/pyhealth/models/generators/halo.py @@ -0,0 +1,724 @@ +"""HALO: Hierarchical Autoregressive Language mOdel for synthetic EHR generation. + +This is a faithful port of the reference implementation +(https://github.com/Brandon-Theodorou/HALO_Inpatient) wrapped as a PyHealth +``BaseModel`` so it consumes the standard +``dataset -> set_task -> SampleDataset -> model`` pipeline. + +HALO is a two-level model: + +* a GPT-2-style **coarse** transformer operates over visit-level multi-hot + vectors, and +* a **fine** autoregressive head predicts the (multi-label) set of codes within + each visit. + +The transformer/head classes below (``LayerNorm``, ``Conv1D``, ``Attention``, +``MLP``, ``Block``, ``CoarseTransformerModel``, ``AutoregressiveLinear``, +``FineAutoregressiveHead``, ``HALOModel``) are ported verbatim from the +reference ``model.py``. The only behavioural change is that PyHealth's HALO is +**unconditional** (``label_vocab_size = 0``): it generates visit-code sequences +without conditioning on CCS labels. +""" + +import copy +import math +import os +from typing import Dict, List, Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm + +from pyhealth.datasets import get_dataloader +from pyhealth.models import BaseModel + + +# ---------------------------------------------------------------------------- +# Configuration (plain class, not a dataclass; mirrors reference config.py) +# ---------------------------------------------------------------------------- +class HALOConfig: + """Hyperparameter container for the HALO transformer. + + Kept as a plain class with explicit ``__init__`` assignments (matching the + reference ``config.py``) so the low-level modules can read attributes such + as ``config.n_embd``. + """ + + def __init__( + self, + total_vocab_size: int, + code_vocab_size: int, + label_vocab_size: int = 0, + special_vocab_size: int = 3, + n_positions: int = 56, + n_ctx: int = 48, + n_embd: int = 768, + n_layer: int = 12, + n_head: int = 12, + layer_norm_epsilon: float = 1e-5, + initializer_range: float = 0.02, + batch_size: int = 48, + epoch: int = 50, + pos_loss_weight: Optional[float] = None, + lr: float = 1e-4, + ) -> None: + self.total_vocab_size = total_vocab_size + self.code_vocab_size = code_vocab_size + self.label_vocab_size = label_vocab_size + self.special_vocab_size = special_vocab_size + self.n_positions = n_positions + self.n_ctx = n_ctx + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.batch_size = batch_size + self.epoch = epoch + self.pos_loss_weight = pos_loss_weight + self.lr = lr + + +# ---------------------------------------------------------------------------- +# Transformer building blocks (ported verbatim from reference model.py) +# ---------------------------------------------------------------------------- +class LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-12): + """Construct a layernorm module in the TF style (epsilon inside sqrt).""" + super(LayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x + self.bias + + +class Conv1D(nn.Module): + def __init__(self, nf, nx): + super(Conv1D, self).__init__() + self.nf = nf + w = torch.empty(nx, nf) + nn.init.normal_(w, std=0.02) + self.weight = nn.Parameter(w) + self.bias = nn.Parameter(torch.zeros(nf)) + + def forward(self, x): + size_out = x.size()[:-1] + (self.nf,) + x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) + x = x.view(*size_out) + return x + + +class Attention(nn.Module): + def __init__(self, nx, n_ctx, config, scale=False): + super(Attention, self).__init__() + n_state = nx # in Attention: n_state=n_embd (nx=n_embd) + assert n_state % config.n_head == 0 + self.register_buffer( + "bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx) + ) + self.n_head = config.n_head + self.split_size = n_state + self.scale = scale + self.c_attn = Conv1D(n_state * 3, nx) + self.c_proj = Conv1D(n_state, nx) + + def _attn(self, q, k, v): + w = torch.matmul(q, k) + if self.scale: + w = w / math.sqrt(v.size(-1)) + nd, ns = w.size(-2), w.size(-1) + b = self.bias[:, :, ns - nd:ns, :ns] + w = w * b - 1e10 * (1 - b) + w = nn.Softmax(dim=-1)(w) + return torch.matmul(w, v) + + def merge_heads(self, x): + x = x.permute(0, 2, 1, 3).contiguous() + new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) + return x.view(*new_x_shape) + + def split_heads(self, x, k=False): + new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) + x = x.view(*new_x_shape) + if k: + return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length) + else: + return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def forward(self, x, layer_past=None): + x = self.c_attn(x) + query, key, value = x.split(self.split_size, dim=2) + query = self.split_heads(query) + key = self.split_heads(key, k=True) + value = self.split_heads(value) + if layer_past is not None: + past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] + key = torch.cat((past_key, key), dim=-1) + value = torch.cat((past_value, value), dim=-2) + present = torch.stack((key.transpose(-2, -1), value)) + a = self._attn(query, key, value) + a = self.merge_heads(a) + a = self.c_proj(a) + return a, present + + +class MLP(nn.Module): + def __init__(self, n_state, config): # in MLP: n_state=4 * n_embd + super(MLP, self).__init__() + nx = config.n_embd + self.c_fc = Conv1D(n_state, nx) + self.c_proj = Conv1D(nx, n_state) + + def forward(self, x): + # tanh-approximate GELU, matching the reference HALO implementation. + h = F.gelu(self.c_fc(x), approximate="tanh") + h2 = self.c_proj(h) + return h2 + + +class Block(nn.Module): + def __init__(self, n_ctx, config, scale=False): + super(Block, self).__init__() + nx = config.n_embd + self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon) + self.attn = Attention(nx, n_ctx, config, scale) + self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon) + self.mlp = MLP(4 * nx, config) + + def forward(self, x, layer_past=None): + a, present = self.attn(self.ln_1(x), layer_past=layer_past) + x = x + a + m = self.mlp(self.ln_2(x)) + x = x + m + return x, present + + +class CoarseTransformerModel(nn.Module): + def __init__(self, config): + super(CoarseTransformerModel, self).__init__() + self.n_layer = config.n_layer + self.n_embd = config.n_embd + self.n_vocab = config.total_vocab_size + + self.vis_embed_mat = nn.Linear( + config.total_vocab_size, config.n_embd, bias=False + ) + self.pos_embed_mat = nn.Embedding(config.n_positions, config.n_embd) + block = Block(config.n_ctx, config, scale=True) + self.h = nn.ModuleList( + [copy.deepcopy(block) for _ in range(config.n_layer)] + ) + self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + + def forward(self, input_visits, position_ids=None, past=None): + if past is None: + past_length = 0 + past = [None] * len(self.h) + else: + past_length = past[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange( + past_length, + input_visits.size(1) + past_length, + dtype=torch.long, + device=input_visits.device, + ) + position_ids = position_ids.unsqueeze(0).expand( + input_visits.size(0), input_visits.size(1) + ) + + inputs_embeds = self.vis_embed_mat(input_visits) + position_embeds = self.pos_embed_mat(position_ids) + hidden_states = inputs_embeds + position_embeds + for block, layer_past in zip(self.h, past): + hidden_states, _ = block(hidden_states, layer_past) + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class AutoregressiveLinear(nn.Linear): + """Same as Linear except it has a configurable mask on the weights.""" + + def __init__(self, in_features, out_features, bias=True): + super().__init__(in_features, out_features, bias) + self.register_buffer( + "mask", torch.tril(torch.ones(in_features, out_features)).int() + ) + + def forward(self, input): + return F.linear(input, self.mask * self.weight, self.bias) + + +class FineAutoregressiveHead(nn.Module): + def __init__(self, config): + super(FineAutoregressiveHead, self).__init__() + self.auto1 = AutoregressiveLinear( + config.n_embd + config.total_vocab_size, + config.n_embd + config.total_vocab_size, + ) + self.auto2 = AutoregressiveLinear( + config.n_embd + config.total_vocab_size, + config.n_embd + config.total_vocab_size, + ) + self.n_embd = config.n_embd + self.tot_vocab = config.total_vocab_size + + def forward(self, history, input_visits): + history = history[:, :-1, :] + input_visits = input_visits[:, 1:, :] + code_logits = self.auto2( + torch.relu(self.auto1(torch.cat((history, input_visits), dim=2))) + )[:, :, self.n_embd - 1:-1] + return code_logits + + def sample(self, history, input_visits): + history = history[:, :-1, :] + input_visits = input_visits[:, 1:, :] + currVisit = torch.cat((history, input_visits), dim=2)[:, -1, :].unsqueeze(1) + code_logits = self.auto2(torch.relu(self.auto1(currVisit)))[ + :, :, self.n_embd - 1:-1 + ] + return code_logits + + +class HALOModel(nn.Module): + """Low-level HALO transformer + autoregressive head (ported verbatim).""" + + def __init__(self, config): + super(HALOModel, self).__init__() + self.transformer = CoarseTransformerModel(config) + self.ehr_head = FineAutoregressiveHead(config) + + def forward( + self, + input_visits, + position_ids=None, + ehr_labels=None, + ehr_masks=None, + past=None, + pos_loss_weight=None, + ): + hidden_states = self.transformer(input_visits, position_ids, past) + code_logits = self.ehr_head(hidden_states, input_visits) + sig = nn.Sigmoid() + code_probs = sig(code_logits) + if ehr_labels is not None: + shift_labels = ehr_labels[..., 1:, :].contiguous() + loss_weights = None + if pos_loss_weight is not None: + loss_weights = torch.ones( + code_probs.shape, device=code_probs.device + ) + loss_weights = loss_weights + (pos_loss_weight - 1) * shift_labels + if ehr_masks is not None: + code_probs = code_probs * ehr_masks + shift_labels = shift_labels * ehr_masks + if pos_loss_weight is not None: + loss_weights = loss_weights * ehr_masks + + bce = nn.BCELoss(weight=loss_weights) + loss = bce(code_probs, shift_labels) + return loss, code_probs, shift_labels + + return code_probs + + def sample(self, input_visits, random=True): + sig = nn.Sigmoid() + hidden_states = self.transformer(input_visits) + i = 0 + while i < self.ehr_head.tot_vocab: + next_logits = self.ehr_head.sample(hidden_states, input_visits) + next_probs = sig(next_logits) + if random: + visit = torch.bernoulli(next_probs) + else: + visit = torch.round(next_probs) + + remaining_visit = visit[:, 0, i:] + nonzero = torch.nonzero(remaining_visit, as_tuple=True)[1] + if nonzero.numel() == 0: + break + + first_nonzero = nonzero.min() + input_visits[:, -1, i + first_nonzero] = visit[:, 0, i + first_nonzero] + i = i + first_nonzero + 1 + + return input_visits + + +# ---------------------------------------------------------------------------- +# PyHealth BaseModel wrapper +# ---------------------------------------------------------------------------- +class HALO(BaseModel): + """HALO synthetic-EHR generator, wrapped as a PyHealth ``BaseModel``. + + Trains a GPT-2-style transformer on patient visit-code sequences and + generates synthetic patients by autoregressive sampling. Generation is + **unconditional** (no label conditioning). + + The model infers its code vocabulary from the fitted ``SampleDataset``: + ``code_vocab_size = dataset.input_processors["visits"].vocab_size()`` + (the ``NestedSequenceProcessor`` vocab, which already reserves index 0 for + ```` and index 1 for ````). Three special tokens are appended for + start-of-sequence, end-of-sequence, and pad-visit. + + Args: + dataset: A fitted ``SampleDataset`` whose ``input_schema`` contains + ``{"visits": NestedSequenceProcessor}`` and whose ``output_schema`` + is empty. + embed_dim: Transformer embedding dimension (``n_embd``). Default: 768. + n_heads: Number of attention heads. Must divide ``embed_dim``. + Default: 12. + n_layers: Number of transformer layers. Default: 12. + n_ctx: Maximum number of visit positions (context length). Default: 48. + batch_size: Training batch size. Default: 48. + epochs: Number of training epochs. Default: 50. + pos_loss_weight: Positive-class weight for the BCE loss. ``None`` means + no weighting. Default: None. + lr: Learning rate for the Adam optimizer. Default: 1e-4. + save_dir: Directory for checkpoints written by ``train_model``. + Default: ``"./save/"``. + + Examples: + >>> from pyhealth.datasets import create_sample_dataset + >>> samples = [ + ... {"patient_id": "p1", "visits": [["A", "B"], ["C"]]}, + ... {"patient_id": "p2", "visits": [["A"], ["B", "C"]]}, + ... ] + >>> dataset = create_sample_dataset( + ... samples=samples, + ... input_schema={"visits": "nested_sequence"}, + ... output_schema={}, + ... ) + >>> model = HALO(dataset, embed_dim=16, n_heads=2, n_layers=2, n_ctx=8) + >>> isinstance(model, HALO) + True + """ + + def __init__( + self, + dataset, + embed_dim: int = 768, + n_heads: int = 12, + n_layers: int = 12, + n_ctx: int = 48, + batch_size: int = 48, + epochs: int = 50, + pos_loss_weight: Optional[float] = None, + lr: float = 1e-4, + save_dir: str = "./save/", + ) -> None: + super(HALO, self).__init__(dataset) + + if "visits" not in dataset.input_processors: + raise ValueError( + "HALO expects an input feature named 'visits' backed by a " + "NestedSequenceProcessor." + ) + + self.save_dir = save_dir + self._batch_size = batch_size + self._epochs = epochs + self._lr = lr + + # Code vocab from the NestedSequenceProcessor (includes , ). + self.visits_processor = dataset.input_processors["visits"] + code_vocab_size = self.visits_processor.vocab_size() + label_vocab_size = 0 # unconditional generation -- no output labels + # +3 special tokens: start-of-sequence, end-of-sequence, pad-visit. + total_vocab_size = code_vocab_size + label_vocab_size + 3 + + self.config = HALOConfig( + total_vocab_size=total_vocab_size, + code_vocab_size=code_vocab_size, + label_vocab_size=label_vocab_size, + special_vocab_size=3, + n_positions=n_ctx + 8, # position table needs a little slack + n_ctx=n_ctx, + n_embd=embed_dim, + n_layer=n_layers, + n_head=n_heads, + batch_size=batch_size, + epoch=epochs, + pos_loss_weight=pos_loss_weight, + lr=lr, + ) + + # Registered as a sub-module so .parameters()/.to() work. + self.halo_model = HALOModel(self.config) + + # ------------------------------------------------------------------ + # Multi-hot encoding helper + # ------------------------------------------------------------------ + def _encode_visits(self, visits: torch.Tensor): + """Convert a padded index tensor to HALO multi-hot format. + + ``NestedSequenceProcessor`` returns code indices; the transformer + expects multi-hot vectors of shape ``(batch, n_ctx, total_vocab_size)`` + with special tokens. Layout (mirrors the reference): position 0 is the + start token, visits occupy positions 2+, the end token is placed on the + last visit's row, and the pad token fills the remaining positions. + + Args: + visits: LongTensor ``(batch, max_visits, max_codes_per_visit)``. + Index 0 is ```` and is skipped. + + Returns: + batch_ehr: FloatTensor ``(batch, n_ctx, total_vocab_size)``. + batch_mask: FloatTensor ``(batch, n_ctx - 1, 1)``, shifted to align + with the autoregressive prediction targets. + """ + cfg = self.config + batch_size = visits.shape[0] + + batch_ehr = torch.zeros( + batch_size, cfg.n_ctx, cfg.total_vocab_size, device=self.device + ) + batch_mask = torch.zeros(batch_size, cfg.n_ctx, 1, device=self.device) + + start_idx = cfg.code_vocab_size + cfg.label_vocab_size + end_idx = start_idx + 1 + pad_idx = start_idx + 2 + + for i in range(batch_size): + # Count actual (non-padding) visits for this patient. + n_visits = int((visits[i].sum(dim=-1) > 0).sum().item()) + n_visits = min(n_visits, cfg.n_ctx - 2) + for j in range(n_visits): + for code_idx in visits[i, j]: + if code_idx > 0: # skip (index 0) + batch_ehr[i, j + 2, code_idx] = 1 + batch_mask[i, j + 2] = 1 + + batch_ehr[i, 0, start_idx] = 1 # start token + batch_ehr[i, n_visits + 1, end_idx] = 1 # end token (on last visit) + batch_ehr[i, n_visits + 2:, pad_idx] = 1 # pad visits + + batch_mask = batch_mask[:, 1:, :] # shift to align with shifted targets + return batch_ehr, batch_mask + + # ------------------------------------------------------------------ + # forward -- required by BaseModel + # ------------------------------------------------------------------ + def forward(self, visits: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: + """Forward pass. + + Args: + visits: LongTensor ``(batch, max_visits, max_codes_per_visit)`` from + the ``NestedSequenceProcessor``. + **kwargs: Any other batch keys are ignored. + + Returns: + Dict with ``loss`` (scalar BCE) and ``y_prob`` (code probabilities, + shape ``(batch, n_ctx - 1, total_vocab_size)``). + """ + visits = visits.to(self.device) + batch_ehr, batch_mask = self._encode_visits(visits) + + loss, code_probs, _ = self.halo_model( + batch_ehr, + position_ids=None, + ehr_labels=batch_ehr, + ehr_masks=batch_mask, + pos_loss_weight=self.config.pos_loss_weight, + ) + return {"loss": loss, "y_prob": code_probs} + + # ------------------------------------------------------------------ + # Custom training loop + # ------------------------------------------------------------------ + @staticmethod + def _resolve_device(device=None) -> torch.device: + """Resolve a user-supplied device, defaulting to CUDA when available. + + Args: + device: ``None``, a device string (e.g. ``"cuda"``, ``"cuda:1"``, + ``"cpu"``), or a ``torch.device``. When ``None``, CUDA is used + if available, otherwise CPU. + """ + if device is None: + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + return torch.device(device) + + def train_model(self, train_dataset, val_dataset=None, device=None) -> None: + """Train the HALO model with a custom loop. + + Named ``train_model`` (not ``train``) to avoid shadowing + ``nn.Module.train()``. Uses the standard ``get_dataloader`` (which pads + the variable visit dimension for us), an Adam optimizer, and BCE loss. + When ``val_dataset`` is given, validation loss is computed after each + epoch and the best checkpoint is saved to ``self.save_dir``. + + Args: + train_dataset: ``SampleDataset`` for training. + val_dataset: Optional ``SampleDataset`` for validation. + device: Device to train on, e.g. ``"cuda"``, ``"cuda:1"``, or + ``"cpu"``. If ``None`` (default), uses CUDA when available and + falls back to CPU. + """ + device = self._resolve_device(device) + self.to(device) + print(f"Training on: {device}") + + os.makedirs(self.save_dir, exist_ok=True) + optimizer = torch.optim.Adam(self.halo_model.parameters(), lr=self._lr) + + checkpoint_path = os.path.join(self.save_dir, "halo_model") + if os.path.exists(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location=self.device) + self.halo_model.load_state_dict(checkpoint["model"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + + train_loader = get_dataloader( + train_dataset, batch_size=self._batch_size, shuffle=True + ) + + global_loss = 1e10 + for epoch in tqdm(range(self._epochs), desc="Epochs"): + self.halo_model.train() + batch_iter = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False) + for batch in batch_iter: + visits = batch["visits"].to(self.device) + batch_ehr, batch_mask = self._encode_visits(visits) + + optimizer.zero_grad() + loss, _, _ = self.halo_model( + batch_ehr, + position_ids=None, + ehr_labels=batch_ehr, + ehr_masks=batch_mask, + pos_loss_weight=self.config.pos_loss_weight, + ) + loss.backward() + optimizer.step() + batch_iter.set_postfix(loss=f"{loss.item():.4f}") + + if val_dataset is not None: + self.halo_model.eval() + val_loader = get_dataloader( + val_dataset, batch_size=self._batch_size, shuffle=False + ) + val_losses = [] + with torch.no_grad(): + for val_batch in val_loader: + visits = val_batch["visits"].to(self.device) + batch_ehr, batch_mask = self._encode_visits(visits) + val_loss, _, _ = self.halo_model( + batch_ehr, + position_ids=None, + ehr_labels=batch_ehr, + ehr_masks=batch_mask, + pos_loss_weight=self.config.pos_loss_weight, + ) + val_losses.append(val_loss.item()) + + cur_val_loss = float(np.mean(val_losses)) + print(f"Epoch {epoch} Validation Loss: {cur_val_loss:.7f}") + if cur_val_loss < global_loss: + global_loss = cur_val_loss + state = { + "model": self.halo_model.state_dict(), + "optimizer": optimizer.state_dict(), + "epoch": epoch, + } + torch.save(state, checkpoint_path) + print("------------ Save best model ------------") + + # ------------------------------------------------------------------ + # Synthesis + # ------------------------------------------------------------------ + def generate( + self, num_samples: int, random_sampling: bool = True, device=None + ) -> List[Dict]: + """Generate synthetic patients using the trained HALO model. + + Autoregressive sampling: feed a start token and repeatedly call + ``halo_model.sample`` until an end token is produced or ``n_ctx`` steps + are reached, then decode code indices back to code strings. + + Args: + num_samples: Number of synthetic patients to generate. + random_sampling: If True, Bernoulli sampling (stochastic). If False, + rounding (deterministic). Default: True. + device: Device to generate on, e.g. ``"cuda"``, ``"cuda:1"``, or + ``"cpu"``. If ``None`` (default), uses CUDA when available and + falls back to CPU. + + Returns: + List of dicts, each ``{"patient_id": "synthetic_i", + "visits": [[code, ...], ...]}`` with decoded code strings. + """ + device = self._resolve_device(device) + self.to(device) + + cfg = self.config + index_to_code = {v: k for k, v in self.visits_processor.code_vocab.items()} + end_token_idx = cfg.code_vocab_size + cfg.label_vocab_size + 1 + start_token_idx = cfg.code_vocab_size + cfg.label_vocab_size + + self.halo_model.eval() + synthetic_dataset: List[Dict] = [] + sample_batch_size = min(num_samples, 256) + generated = 0 + pbar = tqdm(total=num_samples, desc="Generating patients") + + with torch.no_grad(): + while generated < num_samples: + bs = min(sample_batch_size, num_samples - generated) + stoken = torch.zeros( + cfg.total_vocab_size, device=self.device, dtype=torch.float32 + ) + stoken[start_token_idx] = 1 + prev = stoken.unsqueeze(0).unsqueeze(0).repeat(bs, 1, 1) + empty = torch.zeros( + bs, 1, cfg.total_vocab_size, + device=self.device, dtype=torch.float32, + ) + + for _ in range(cfg.n_ctx - 1): + prev = self.halo_model.sample( + torch.cat((prev, empty), dim=1), random_sampling + ) + has_end = prev[:, :, end_token_idx].sum(dim=1).bool() + if has_end.all(): + break + + batch_ehrs = prev.cpu().detach().numpy() + for i in range(bs): + ehr = batch_ehrs[i] # (seq_len, total_vocab_size) + visits_out: List[List[str]] = [] + # Position 0 is the start token; visits occupy positions 1+. + for j in range(1, len(ehr)): + indices = np.nonzero(ehr[j])[0] + visit_codes: List[str] = [] + hit_end = False + for idx in indices: + if idx < cfg.code_vocab_size: + code = index_to_code.get(int(idx)) + if code not in (None, "", ""): + visit_codes.append(code) + elif idx == end_token_idx: + hit_end = True + if visit_codes: + visits_out.append(visit_codes) + if hit_end: + break + + synthetic_dataset.append( + { + "patient_id": f"synthetic_{generated + i}", + "visits": visits_out, + } + ) + generated += bs + pbar.update(bs) + pbar.close() + + return synthetic_dataset diff --git a/pyhealth/models/generators/medgan.py b/pyhealth/models/generators/medgan.py new file mode 100644 index 000000000..15b1595e7 --- /dev/null +++ b/pyhealth/models/generators/medgan.py @@ -0,0 +1,509 @@ +"""MedGAN: Medical Generative Adversarial Network for synthetic EHR generation. + +This is a port of the reference implementation +(https://github.com/mp2893/medgan and the PyTorch reimplementation under +``reference/cor-gan/Generative/medGAN/MIMIC/pytorch/MLP/medGAN.py``) wrapped +as a PyHealth ``BaseModel`` so it consumes the standard +``dataset -> SampleDataset -> model`` pipeline. + +MedGAN treats each patient as a flat bag-of-codes (no visit structure), so it +expects an input feature named ``visits`` backed by a ``MultiHotProcessor``. +The training procedure has two phases (mirroring the reference): + +* a **linear autoencoder** is pre-trained with binary cross-entropy + reconstruction loss, and +* an **adversarial training** phase where the generator emits latent codes, + the autoencoder's decoder projects them back to a multi-hot patient vector, + and a discriminator with optional minibatch averaging tries to distinguish + real from synthetic. + +The ``MedGANAutoencoder``, ``MedGANGenerator`` and ``MedGANDiscriminator`` +modules below mirror the reference. The public ``MedGAN`` class follows the +same API style as :class:`pyhealth.models.generators.HALO` +(``train_model`` / ``generate`` / ``save_model`` / ``load_model``). +""" + +import os +from typing import Dict, List, Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader, Dataset, RandomSampler +from tqdm import tqdm + +from pyhealth.models import BaseModel + + +# ---------------------------------------------------------------------------- +# Building blocks (ported from reference medgan.py / PyTorch reimplementation) +# ---------------------------------------------------------------------------- +class _MultiHotDataset(Dataset): + """Tiny ``torch.utils.data.Dataset`` over a multi-hot numpy matrix.""" + + def __init__(self, data: np.ndarray): + self.data = data.astype(np.float32) + + def __len__(self) -> int: + return len(self.data) + + def __getitem__(self, idx): + return torch.from_numpy(self.data[idx]) + + +class MedGANAutoencoder(nn.Module): + """Linear autoencoder for MedGAN pretraining. + + Mirrors the reference single-layer encoder/decoder + (``Linear -> Tanh`` and ``Linear -> Sigmoid``). + + Args: + input_dim: Vocabulary size (number of distinct codes). + embedding_dim: Latent dimensionality. Default: 128. + """ + + def __init__(self, input_dim: int, embedding_dim: int = 128): + super().__init__() + self.encoder = nn.Sequential( + nn.Linear(input_dim, embedding_dim), + nn.Tanh(), + ) + self.decoder = nn.Sequential( + nn.Linear(embedding_dim, input_dim), + nn.Sigmoid(), + ) + + def forward(self, x): + return self.decoder(self.encoder(x)) + + def encode(self, x): + return self.encoder(x) + + def decode(self, x): + return self.decoder(x) + + +class MedGANGenerator(nn.Module): + """Two-layer MLP generator with residual connections (per reference).""" + + def __init__(self, latent_dim: int = 128, hidden_dim: int = 128): + super().__init__() + self.linear1 = nn.Linear(latent_dim, hidden_dim) + self.bn1 = nn.BatchNorm1d(hidden_dim, eps=0.001, momentum=0.01) + self.act1 = nn.ReLU() + + self.linear2 = nn.Linear(hidden_dim, hidden_dim) + self.bn2 = nn.BatchNorm1d(hidden_dim, eps=0.001, momentum=0.01) + self.act2 = nn.Tanh() + + def forward(self, x): + residual = x + out = self.act1(self.bn1(self.linear1(x))) + residual + + residual = out + out = self.act2(self.bn2(self.linear2(out))) + residual + return out + + +class MedGANDiscriminator(nn.Module): + """MLP discriminator with optional minibatch averaging (per reference).""" + + def __init__( + self, + input_dim: int, + hidden_dim: int = 256, + minibatch_averaging: bool = True, + ): + super().__init__() + self.minibatch_averaging = minibatch_averaging + model_input_dim = input_dim * 2 if minibatch_averaging else input_dim + + self.model = nn.Sequential( + nn.Linear(model_input_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim // 2), + nn.ReLU(), + nn.Linear(hidden_dim // 2, 1), + nn.Sigmoid(), + ) + + def forward(self, x): + if self.minibatch_averaging: + # Average over the batch and concatenate to each sample, exactly + # as in the reference (medGAN.py). + x_mean = torch.mean(x, dim=0).repeat(x.shape[0], 1) + x = torch.cat((x, x_mean), dim=1) + return self.model(x) + + +def _weights_init(m): + """Xavier-uniform for Linear, N(1, 0.02) gamma / 0 beta for BatchNorm.""" + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm1d): + nn.init.normal_(m.weight, mean=1.0, std=0.02) + nn.init.constant_(m.bias, 0) + + +def _autoencoder_loss(x_output, y_target): + """Sparse-friendly BCE: sum over features, mean over batch. + + Equivalent to ``BCELoss(reduction='sum') / batch_size`` and matches the + reference; ``BCELoss(reduction='mean')`` would also mean over features + which dilutes the signal for sparse code vectors. + """ + epsilon = 1e-12 + term = y_target * torch.log(x_output + epsilon) + ( + 1.0 - y_target + ) * torch.log(1.0 - x_output + epsilon) + return torch.mean(-torch.sum(term, dim=1), dim=0) + + +# ---------------------------------------------------------------------------- +# PyHealth BaseModel wrapper +# ---------------------------------------------------------------------------- +class MedGAN(BaseModel): + """MedGAN synthetic-EHR generator, wrapped as a PyHealth ``BaseModel``. + + Generates synthetic binary EHR records via the two-phase procedure from + Choi et al. (MLHC 2017): pretrain a linear autoencoder, then run BCE-GAN + adversarial training where the generator maps noise to the autoencoder's + latent space and the decoder projects back to a multi-hot patient vector. + + Generation is **unconditional**: each synthetic patient is a flat bag of + codes (no visit structure), matching the ``multi_hot`` input schema. + + Args: + dataset: A fitted ``SampleDataset`` whose ``input_schema`` contains + ``{"visits": "multi_hot"}`` and whose ``output_schema`` is empty. + latent_dim: Generator noise dimensionality. Default: 128. + hidden_dim: Generator hidden width (also the autoencoder embedding + dimension). Default: 128. + discriminator_hidden_dim: Discriminator hidden width. Default: 256. + minibatch_averaging: Concatenate per-batch mean to each discriminator + input. Default: True. + batch_size: Training batch size. Default: 512. + ae_epochs: Autoencoder pre-training epochs. Default: 100. + gan_epochs: Adversarial training epochs. Default: 200. + ae_lr: Autoencoder learning rate. Default: 1e-3. + gan_lr: GAN learning rate. Default: 1e-3. + save_dir: Checkpoint directory used by ``train_model``. + Default: ``"./save/medgan/"``. + + Examples: + >>> from pyhealth.datasets import create_sample_dataset + >>> samples = [ + ... {"patient_id": "p1", "visits": ["A", "B", "C"]}, + ... {"patient_id": "p2", "visits": ["A", "C", "D"]}, + ... ] + >>> dataset = create_sample_dataset( + ... samples=samples, + ... input_schema={"visits": "multi_hot"}, + ... output_schema={}, + ... ) + >>> model = MedGAN(dataset, latent_dim=16, hidden_dim=16, batch_size=2) + >>> isinstance(model, MedGAN) + True + """ + + def __init__( + self, + dataset, + latent_dim: int = 128, + hidden_dim: int = 128, + discriminator_hidden_dim: int = 256, + minibatch_averaging: bool = True, + batch_size: int = 512, + ae_epochs: int = 100, + gan_epochs: int = 200, + ae_lr: float = 1e-3, + gan_lr: float = 1e-3, + save_dir: str = "./save/medgan/", + ) -> None: + super().__init__(dataset) + + if "visits" not in dataset.input_processors: + raise ValueError( + "MedGAN expects an input feature named 'visits' backed by a " + "MultiHotProcessor." + ) + + self.latent_dim = latent_dim + self.hidden_dim = hidden_dim + self._batch_size = batch_size + self._ae_epochs = ae_epochs + self._gan_epochs = gan_epochs + self._ae_lr = ae_lr + self._gan_lr = gan_lr + self.save_dir = save_dir + + # Code vocab from the MultiHotProcessor's label_vocab. + self.visits_processor = dataset.input_processors["visits"] + self.input_dim = self.visits_processor.size() + self._idx_to_code: List[Optional[str]] = [None] * self.input_dim + for code, idx in self.visits_processor.label_vocab.items(): + self._idx_to_code[idx] = code + + self.autoencoder = MedGANAutoencoder( + input_dim=self.input_dim, + embedding_dim=hidden_dim, + ) + self.generator = MedGANGenerator( + latent_dim=latent_dim, + hidden_dim=hidden_dim, + ) + self.discriminator = MedGANDiscriminator( + input_dim=self.input_dim, + hidden_dim=discriminator_hidden_dim, + minibatch_averaging=minibatch_averaging, + ) + + self.autoencoder.apply(_weights_init) + self.generator.apply(_weights_init) + self.discriminator.apply(_weights_init) + + # ------------------------------------------------------------------ + # forward -- required by BaseModel + # ------------------------------------------------------------------ + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + """MedGAN does not have a single supervised forward pass. + + Use :meth:`train_model` for training and :meth:`generate` for + synthesis. ``forward`` is implemented only to satisfy the + ``BaseModel`` abstract contract. + """ + raise NotImplementedError( + "MedGAN is a GAN: use train_model() and generate() instead of " + "forward()." + ) + + # ------------------------------------------------------------------ + # Custom training loop + # ------------------------------------------------------------------ + @staticmethod + def _resolve_device(device=None) -> torch.device: + """Resolve a user-supplied device, defaulting to CUDA when available.""" + if device is None: + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + return torch.device(device) + + def _build_dataloader(self, dataset) -> DataLoader: + """Stack the multi-hot tensors of ``dataset`` into a DataLoader. + + The fitted ``MultiHotProcessor`` has already converted each patient's + ``visits`` field into a ``(input_dim,)`` float32 tensor, so we can + simply stack and wrap. + """ + tensors = [dataset[i]["visits"] for i in range(len(dataset))] + matrix = torch.stack(tensors).numpy() + wrapped = _MultiHotDataset(matrix) + sampler = RandomSampler(wrapped, replacement=True) + return DataLoader( + wrapped, + batch_size=self._batch_size, + shuffle=False, + num_workers=0, + drop_last=True, + sampler=sampler, + ) + + def train_model(self, train_dataset, val_dataset=None, device=None) -> None: + """Train MedGAN with a custom two-phase loop. + + Named ``train_model`` (not ``train``) to avoid shadowing + ``nn.Module.train()``. Phase 1 pre-trains the autoencoder with + sparse-friendly BCE reconstruction loss; phase 2 runs standard + BCE-GAN adversarial training where the generator+decoder are + optimised against a binary discriminator. + + Args: + train_dataset: ``SampleDataset`` for training. + val_dataset: Unused; accepted for API symmetry with other PyHealth + trainers. + device: Device to train on (``"cuda"``, ``"cpu"``, etc.). If + ``None``, uses CUDA when available. + """ + device = self._resolve_device(device) + self.to(device) + print(f"Training MedGAN on: {device}") + + os.makedirs(self.save_dir, exist_ok=True) + dataloader = self._build_dataloader(train_dataset) + + # ---- Phase 1: Autoencoder pretraining ---- + optimizer_ae = torch.optim.Adam( + self.autoencoder.parameters(), lr=self._ae_lr + ) + for epoch in tqdm(range(self._ae_epochs), desc="AE pretrain"): + self.autoencoder.train() + total_loss, n_batches = 0.0, 0 + for batch in dataloader: + real = batch.to(self.device) + recon = self.autoencoder(real) + loss = _autoencoder_loss(recon, real) + + optimizer_ae.zero_grad() + loss.backward() + optimizer_ae.step() + + total_loss += loss.item() + n_batches += 1 + + # ---- Phase 2: Adversarial training ---- + # Generator + the autoencoder's decoder are trained jointly, matching + # the reference (the decoder is what makes synthetic samples valid). + optimizer_g = torch.optim.Adam( + list(self.generator.parameters()) + + list(self.autoencoder.decoder.parameters()), + lr=self._gan_lr, + ) + optimizer_d = torch.optim.Adam( + self.discriminator.parameters(), lr=self._gan_lr + ) + + best_d_loss = float("inf") + for epoch in tqdm(range(self._gan_epochs), desc="GAN train"): + self.generator.train() + self.discriminator.train() + self.autoencoder.eval() + self.autoencoder.decoder.train() + + epoch_d_loss, epoch_g_loss, n_batches = 0.0, 0.0, 0 + for batch in dataloader: + real = batch.to(self.device) + bs = real.size(0) + + # --- Train Discriminator --- + optimizer_d.zero_grad() + noise = torch.randn(bs, self.latent_dim, device=self.device) + fake = self.autoencoder.decode(self.generator(noise)) + + real_pred = self.discriminator(real) + fake_pred = self.discriminator(fake.detach()) + + d_loss = F.binary_cross_entropy( + real_pred, torch.ones_like(real_pred) + ) + F.binary_cross_entropy( + fake_pred, torch.zeros_like(fake_pred) + ) + d_loss.backward() + optimizer_d.step() + + # --- Train Generator (+ decoder) --- + optimizer_g.zero_grad() + fake_pred = self.discriminator(fake) + g_loss = F.binary_cross_entropy( + fake_pred, torch.ones_like(fake_pred) + ) + g_loss.backward() + optimizer_g.step() + + epoch_d_loss += d_loss.item() + epoch_g_loss += g_loss.item() + n_batches += 1 + + avg_d = epoch_d_loss / max(n_batches, 1) + if avg_d < best_d_loss: + best_d_loss = avg_d + self.save_model(os.path.join(self.save_dir, "best.pt")) + + self.save_model(os.path.join(self.save_dir, "final.pt")) + + # ------------------------------------------------------------------ + # Synthesis + # ------------------------------------------------------------------ + def generate( + self, + num_samples: int, + random_sampling: bool = False, + device=None, + ) -> List[Dict]: + """Generate synthetic patient records. + + Each synthetic patient is decoded from a generated multi-hot vector + by thresholding (or, optionally, Bernoulli sampling) at 0.5 and + mapping the indices back to code strings. + + Args: + num_samples: Number of synthetic patients to generate. + random_sampling: If True, Bernoulli-sample the decoder output; + otherwise threshold at 0.5 (the reference's behaviour). + Default: False. + device: Device to generate on. If ``None``, uses CUDA when + available. + + Returns: + List of dicts + ``{"patient_id": "synthetic_i", "visits": [[code, ...]]}``. + ``visits`` is a list containing a **single** visit (matching + HALO's nested-list output structure). MedGAN is a bag-of-codes + model -- following the reference ``process_mimic.py``, each + patient is represented by the union of codes across all of + their historical visits -- so the single inner list is that + aggregate bag. The inner list may be empty if the generator + produced an all-zero vector. + """ + device = self._resolve_device(device) + self.to(device) + + self.generator.eval() + self.autoencoder.eval() + + bs = min(self._batch_size, max(num_samples, 1)) + rows = np.zeros((num_samples, self.input_dim), dtype=np.float32) + pbar = tqdm(total=num_samples, desc="Generating patients") + with torch.no_grad(): + i = 0 + while i < num_samples: + cur = min(bs, num_samples - i) + z = torch.randn(cur, self.latent_dim, device=self.device) + probs = self.autoencoder.decode(self.generator(z)) + if random_sampling: + sample = torch.bernoulli(probs) + else: + sample = (probs >= 0.5).float() + rows[i : i + cur] = sample.cpu().numpy() + i += cur + pbar.update(cur) + pbar.close() + + results: List[Dict] = [] + for i in range(num_samples): + codes = [ + self._idx_to_code[idx] + for idx in np.nonzero(rows[i])[0] + if self._idx_to_code[idx] not in (None, "", "") + ] + # Wrap in a single-visit list to mirror HALO's nested output. + # MedGAN models the patient as one aggregate bag of codes. + results.append({"patient_id": f"synthetic_{i}", "visits": [codes]}) + return results + + # ------------------------------------------------------------------ + # Checkpoint I/O + # ------------------------------------------------------------------ + def save_model(self, path: str) -> None: + """Save weights and the code vocabulary needed for decoding.""" + torch.save( + { + "autoencoder": self.autoencoder.state_dict(), + "generator": self.generator.state_dict(), + "discriminator": self.discriminator.state_dict(), + "input_dim": self.input_dim, + "latent_dim": self.latent_dim, + "idx_to_code": self._idx_to_code, + }, + path, + ) + + def load_model(self, path: str) -> None: + """Load weights and the code vocabulary from a checkpoint.""" + ckpt = torch.load(path, map_location=self.device) + self.autoencoder.load_state_dict(ckpt["autoencoder"]) + self.generator.load_state_dict(ckpt["generator"]) + self.discriminator.load_state_dict(ckpt["discriminator"]) + if "idx_to_code" in ckpt: + self._idx_to_code = ckpt["idx_to_code"] diff --git a/pyhealth/models/generators/promptehr.py b/pyhealth/models/generators/promptehr.py new file mode 100644 index 000000000..4622e72c6 --- /dev/null +++ b/pyhealth/models/generators/promptehr.py @@ -0,0 +1,517 @@ +"""PromptEHR: prompt-learning BART for synthetic EHR generation. + +This is a PyHealth ``BaseModel`` port of PromptEHR (Wang & Sun, EMNLP'22, +https://github.com/RyanWangZf/PromptEHR), wrapped so it consumes the standard +``dataset -> set_task -> SampleDataset -> model`` pipeline and shares the same +:class:`~pyhealth.tasks.EHRGeneration` task as +:class:`~pyhealth.models.HALO` and :class:`~pyhealth.models.GPT2`. + +PromptEHR treats sequential EHRs as a *neural database* and learns to fill in +patient records with a conditional **BART** (sequence-to-sequence denoising +autoencoder) trained with **prompt learning**. The three ideas that define the +reference implementation are preserved here: + +* **BART seq2seq core.** Generation is encoder-decoder, not decoder-only. The + reference subclasses ``BartForEHRSimulation`` from ``BartPretrainedModel``; + this port wraps :class:`transformers.BartForConditionalGeneration`, mirroring + the way :class:`~pyhealth.models.GPT2` wraps ``GPT2LMHeadModel``. +* **Prompt learning.** The reference reparameterizes a learnable prompt from + patient baseline demographics and prepends it to the encoder/decoder + (``ConditionalPrompt``). PyHealth's :class:`~pyhealth.tasks.EHRGeneration` + task is *unconditional* (only ``visits``, no baseline features -- exactly like + HALO/GPT2), so the prompt reduces to a learnable continuous **soft prefix** + prepended to the encoder. This is the prompt-tuning core without the + demographic reparameterization. +* **Span-infilling objective.** The reference learns by masking spans of codes + and reconstructing them. Here the encoder sees a BART-style span-infilled + copy of the patient's code stream -- random non-overlapping spans with + lengths drawn from ``Poisson(mean_span_len)`` are each replaced by a single + ``[MASK]`` sentinel until roughly ``mask_prob`` of the stream is covered -- + and the decoder reconstructs the full stream. This matches the original BART + text-infilling objective used by PromptEHR rather than per-token masking. + +Each patient's visits are serialized into a single code stream:: + + [CODE_PROMPT] [VISIT_DELIM] ... [EOS] + +The reference handles several code types (diagnosis / procedure / drug / lab) +each with its own modality prompt token; the PyHealth ``EHRGeneration`` task +exposes a single ``visits`` modality, so a single ``[CODE_PROMPT]`` token marks +it. The code vocabulary is taken from the dataset's +``NestedSequenceProcessor`` (which already reserves index 0 for ```` and +index 1 for ````); five special tokens (BOS, EOS, VISIT_DELIM, MASK, +CODE_PROMPT) are appended, and ```` (index 0) is reused as the pad token. +""" + +import os +from typing import Dict, List, Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm +from transformers import BartConfig, BartForConditionalGeneration + +from pyhealth.datasets import get_dataloader +from pyhealth.models import BaseModel + + +class PromptEHR(BaseModel): + """PromptEHR synthetic-EHR generator, wrapped as a PyHealth ``BaseModel``. + + Trains a BART denoising autoencoder with a learnable soft prompt on patient + visit-code streams, then generates synthetic patients by prompt-conditioned + encoder-decoder sampling. Generation is **unconditional** (no demographic + conditioning), matching the :class:`~pyhealth.tasks.EHRGeneration` task. + + Args: + dataset: A fitted ``SampleDataset`` whose ``input_schema`` contains + ``{"visits": NestedSequenceProcessor}`` and whose ``output_schema`` + is empty. + embed_dim: BART model dimension (``d_model``). Must be divisible by + ``n_heads``. Default: 256. + n_heads: Number of attention heads (encoder and decoder). Default: 8. + n_layers: Number of encoder and decoder layers each. Default: 6. + ffn_dim: Feed-forward dimension. Default: 4 * ``embed_dim``. + prompt_length: Number of learnable soft-prompt positions prepended to + the encoder. Default: 8. + max_len: Maximum code-stream length (``max_position_embeddings``); + streams are truncated to this length. Default: 512. + mask_prob: Target fraction of the (non-sentinel) code stream covered + by masked spans in the encoder input. Default: 0.15. + mean_span_len: Mean of the Poisson distribution used to sample span + lengths for BART-style span infilling. Default: 3.0. + batch_size: Training batch size. Default: 16. + epochs: Number of training epochs. Default: 50. + lr: Learning rate for the Adam optimizer. Default: 1e-4. + save_dir: Directory for checkpoints written by ``train_model``. + Default: ``"./save/"``. + + Examples: + >>> from pyhealth.datasets import create_sample_dataset + >>> samples = [ + ... {"patient_id": "p1", "visits": [["A", "B"], ["C"]]}, + ... {"patient_id": "p2", "visits": [["A"], ["B", "C"]]}, + ... ] + >>> dataset = create_sample_dataset( + ... samples=samples, + ... input_schema={"visits": "nested_sequence"}, + ... output_schema={}, + ... ) + >>> model = PromptEHR( + ... dataset, embed_dim=16, n_heads=2, n_layers=2, max_len=64 + ... ) + >>> isinstance(model, PromptEHR) + True + """ + + def __init__( + self, + dataset, + embed_dim: int = 256, + n_heads: int = 8, + n_layers: int = 6, + ffn_dim: Optional[int] = None, + prompt_length: int = 8, + max_len: int = 512, + mask_prob: float = 0.15, + mean_span_len: float = 3.0, + batch_size: int = 16, + epochs: int = 50, + lr: float = 1e-4, + save_dir: str = "./save/", + ) -> None: + super(PromptEHR, self).__init__(dataset) + + if "visits" not in dataset.input_processors: + raise ValueError( + "PromptEHR expects an input feature named 'visits' backed by a " + "NestedSequenceProcessor." + ) + + self.save_dir = save_dir + self._batch_size = batch_size + self._epochs = epochs + self._lr = lr + self.max_len = max_len + self.mask_prob = mask_prob + self.mean_span_len = mean_span_len + self.prompt_length = prompt_length + + # Code vocab from the NestedSequenceProcessor (includes =0, =1). + self.visits_processor = dataset.input_processors["visits"] + self.code_vocab_size = self.visits_processor.vocab_size() + # Append five special tokens after the code vocab; reuse =0 as PAD. + self.bos_id = self.code_vocab_size + self.eos_id = self.code_vocab_size + 1 + self.delim_id = self.code_vocab_size + 2 # visit separator + self.mask_id = self.code_vocab_size + 3 # denoising mask token + self.code_prompt_id = self.code_vocab_size + 4 # modality prompt token + self.pad_id = 0 + total_vocab_size = self.code_vocab_size + 5 + + ffn_dim = ffn_dim if ffn_dim is not None else 4 * embed_dim + config = BartConfig( + vocab_size=total_vocab_size, + max_position_embeddings=max_len, + d_model=embed_dim, + encoder_layers=n_layers, + decoder_layers=n_layers, + encoder_attention_heads=n_heads, + decoder_attention_heads=n_heads, + encoder_ffn_dim=ffn_dim, + decoder_ffn_dim=ffn_dim, + pad_token_id=self.pad_id, + bos_token_id=self.bos_id, + eos_token_id=self.eos_id, + decoder_start_token_id=self.eos_id, # BART convention + forced_bos_token_id=None, + forced_eos_token_id=None, + ) + # Registered as sub-modules so .parameters()/.to() work. + self.bart = BartForConditionalGeneration(config) + # Learnable soft prompt (prompt learning), prepended to the encoder. + self.soft_prompt = nn.Parameter(torch.zeros(prompt_length, embed_dim)) + nn.init.normal_(self.soft_prompt, std=config.init_std) + + # ------------------------------------------------------------------ + @staticmethod + def _resolve_device(device=None) -> torch.device: + """Resolve a user-supplied device, defaulting to CUDA when available.""" + if device is None: + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + return torch.device(device) + + # ------------------------------------------------------------------ + # Visit index tensor -> denoising seq2seq tensors + # ------------------------------------------------------------------ + def _serialize(self, visits: torch.Tensor) -> List[List[int]]: + """Serialize each patient's visits into a flat code stream. + + Layout (decoder target): ``[CODE_PROMPT] codes_v1 [VS] codes_v2 ... + [EOS]``. Index 0 (````) is skipped. + """ + streams: List[List[int]] = [] + for i in range(visits.shape[0]): + n_visits = int((visits[i].sum(dim=-1) > 0).sum().item()) + seq: List[int] = [self.code_prompt_id] + for j in range(n_visits): + codes = [int(c) for c in visits[i, j].tolist() if c > 0] + seq.extend(codes) + if j < n_visits - 1: + seq.append(self.delim_id) + seq.append(self.eos_id) + # Truncate but always keep the trailing EOS. + if len(seq) > self.max_len: + seq = seq[: self.max_len - 1] + [self.eos_id] + streams.append(seq) + return streams + + def _corrupt(self, stream: List[int]) -> List[int]: + """Build the encoder input: BOS + a BART span-infilled copy of the stream. + + Selects random non-overlapping spans inside the stream (excluding the + leading ``[CODE_PROMPT]`` modality marker and the trailing ``[EOS]``) + with lengths drawn from ``Poisson(mean_span_len)`` until roughly + ``mask_prob`` of the stream is covered, then replaces each span with a + single ``[MASK]`` sentinel. Visit separators and code tokens are both + eligible for masking, matching the original BART text-infilling + objective used by PromptEHR. + """ + bos = [self.bos_id] + n = len(stream) + # Inner range excludes the leading [CODE_PROMPT] and trailing [EOS]. + inner_lo, inner_hi = 1, n - 1 + inner_len = inner_hi - inner_lo + if inner_len <= 0: + return bos + list(stream) + + target_masked = int(round(self.mask_prob * inner_len)) + if target_masked <= 0: + return bos + list(stream) + + spans: List[tuple] = [] # (start, end), half-open, in stream coords + masked = 0 + # Cap attempts to avoid pathological loops on tiny / fully-packed streams. + for _ in range(4 * inner_len): + if masked >= target_masked: + break + span_len = max(1, int(np.random.poisson(self.mean_span_len))) + span_len = min(span_len, inner_len) + start = int(np.random.randint(inner_lo, inner_hi - span_len + 1)) + end = start + span_len + if any(start < e and end > s for (s, e) in spans): + continue + spans.append((start, end)) + masked += span_len + + if not spans: + return bos + list(stream) + + spans.sort() + out: List[int] = bos + list(stream[:inner_lo]) + cur = inner_lo + for (s, e) in spans: + out.extend(stream[cur:s]) + out.append(self.mask_id) + cur = e + out.extend(stream[cur:inner_hi]) + out.extend(stream[inner_hi:]) + return out + + def _encode_batch(self, visits: torch.Tensor): + """Convert padded visit indices to encoder inputs and decoder labels. + + Returns: + enc_input_ids: LongTensor ``(batch, L_enc)`` corrupted streams. + enc_attention_mask: LongTensor ``(batch, L_enc)``. + labels: LongTensor ``(batch, L_dec)`` full streams, padding -> -100. + """ + streams = self._serialize(visits) + enc_streams = [self._corrupt(s) for s in streams] + + enc_input_ids = self._pad_stack(enc_streams, self.pad_id) + enc_attention_mask = (enc_input_ids != self.pad_id).long() + # Position 0 is BOS, never masked out by the pad check; force it on. + enc_attention_mask[:, 0] = 1 + + labels = self._pad_stack(streams, self.pad_id) + labels[labels == self.pad_id] = -100 + return enc_input_ids, enc_attention_mask, labels + + def _pad_stack(self, seqs: List[List[int]], pad_value: int) -> torch.Tensor: + """Right-pad a list of int lists into a 2D LongTensor on ``self.device``.""" + length = max(len(s) for s in seqs) + out = torch.full( + (len(seqs), length), pad_value, dtype=torch.long, device=self.device + ) + for i, s in enumerate(seqs): + out[i, : len(s)] = torch.tensor(s, device=self.device) + return out + + def _encoder_inputs_embeds(self, input_ids: torch.Tensor, attention_mask: torch.Tensor): + """Prepend the learnable soft prompt to the encoder token embeddings. + + Returns the prompt-augmented ``inputs_embeds`` and the matching + attention mask (soft-prompt positions are always attended to). + """ + token_embeds = self.bart.get_input_embeddings()(input_ids) + bsz = input_ids.shape[0] + prompt = self.soft_prompt.unsqueeze(0).expand(bsz, -1, -1) + inputs_embeds = torch.cat([prompt, token_embeds], dim=1) + prompt_mask = torch.ones( + bsz, self.prompt_length, dtype=attention_mask.dtype, device=self.device + ) + attention_mask = torch.cat([prompt_mask, attention_mask], dim=1) + return inputs_embeds, attention_mask + + # ------------------------------------------------------------------ + # forward -- required by BaseModel + # ------------------------------------------------------------------ + def forward(self, visits: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: + """Forward pass (denoising seq2seq reconstruction). + + Args: + visits: LongTensor ``(batch, max_visits, max_codes_per_visit)`` from + the ``NestedSequenceProcessor``. + **kwargs: Any other batch keys are ignored. + + Returns: + Dict with ``loss`` (scalar seq2seq cross-entropy) and ``y_prob`` + (decoder next-token probabilities, shape ``(batch, L_dec, vocab)``). + """ + visits = visits.to(self.device) + enc_input_ids, enc_attention_mask, labels = self._encode_batch(visits) + inputs_embeds, enc_attention_mask = self._encoder_inputs_embeds( + enc_input_ids, enc_attention_mask + ) + out = self.bart( + inputs_embeds=inputs_embeds, + attention_mask=enc_attention_mask, + labels=labels, + ) + return {"loss": out.loss, "y_prob": F.softmax(out.logits, dim=-1)} + + # ------------------------------------------------------------------ + # Custom training loop + # ------------------------------------------------------------------ + def train_model(self, train_dataset, val_dataset=None, device=None) -> None: + """Train PromptEHR with a custom loop. + + Named ``train_model`` (not ``train``) to avoid shadowing + ``nn.Module.train()``. Uses the standard ``get_dataloader``, an Adam + optimizer, and the BART denoising loss. When ``val_dataset`` is given, + validation loss is computed after each epoch and the best checkpoint is + saved to ``self.save_dir``. + + Args: + train_dataset: ``SampleDataset`` for training. + val_dataset: Optional ``SampleDataset`` for validation. + device: Device to train on, e.g. ``"cuda"``, ``"cuda:1"``, or + ``"cpu"``. If ``None`` (default), uses CUDA when available and + falls back to CPU. + """ + device = self._resolve_device(device) + self.to(device) + print(f"Training on: {device}") + + os.makedirs(self.save_dir, exist_ok=True) + optimizer = torch.optim.Adam(self.parameters(), lr=self._lr) + + checkpoint_path = os.path.join(self.save_dir, "promptehr_model") + if os.path.exists(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location=self.device) + self.load_state_dict(checkpoint["model"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + + train_loader = get_dataloader( + train_dataset, batch_size=self._batch_size, shuffle=True + ) + + global_loss = 1e10 + for epoch in tqdm(range(self._epochs), desc="Epochs"): + self.bart.train() + batch_iter = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False) + for batch in batch_iter: + visits = batch["visits"].to(self.device) + + optimizer.zero_grad() + ret = self.forward(visits=visits) + loss = ret["loss"] + loss.backward() + optimizer.step() + batch_iter.set_postfix(loss=f"{loss.item():.4f}") + + if val_dataset is not None: + self.bart.eval() + val_loader = get_dataloader( + val_dataset, batch_size=self._batch_size, shuffle=False + ) + val_losses = [] + with torch.no_grad(): + for val_batch in val_loader: + visits = val_batch["visits"].to(self.device) + val_losses.append(self.forward(visits=visits)["loss"].item()) + + cur_val_loss = float(np.mean(val_losses)) + print(f"Epoch {epoch} Validation Loss: {cur_val_loss:.7f}") + if cur_val_loss < global_loss: + global_loss = cur_val_loss + state = { + "model": self.state_dict(), + "optimizer": optimizer.state_dict(), + "epoch": epoch, + } + torch.save(state, checkpoint_path) + print("------------ Save best model ------------") + + # ------------------------------------------------------------------ + # Synthesis + # ------------------------------------------------------------------ + def _decode_ids(self, ids: List[int], index_to_code: Dict[int, str]) -> List[List[str]]: + """Decode a generated decoder token stream into per-visit code lists.""" + visits_out: List[List[str]] = [] + current: List[str] = [] + for tid in ids: + if tid in (self.bos_id, self.pad_id, self.code_prompt_id, self.mask_id): + continue + if tid == self.eos_id: + break + if tid == self.delim_id: + if current: + visits_out.append(current) + current = [] + continue + if tid < self.code_vocab_size: + code = index_to_code.get(int(tid)) + if code not in (None, "", ""): + current.append(code) + if current: + visits_out.append(current) + return visits_out + + def generate( + self, + num_samples: int, + device=None, + top_k: int = 50, + top_p: float = 0.95, + ) -> List[Dict]: + """Generate synthetic patients with the trained PromptEHR model. + + Feeds the encoder a fully-masked seed stream (so generation is driven by + the learned soft prompt), precomputes the prompt-augmented encoder + states, and autoregressively samples a decoder stream with + ``top_k``/``top_p`` sampling, then decodes it into per-visit code lists. + + Args: + num_samples: Number of synthetic patients to generate. + device: Device to generate on, e.g. ``"cuda"``, ``"cuda:1"``, or + ``"cpu"``. If ``None`` (default), uses CUDA when available and + falls back to CPU. + top_k: Top-k sampling cutoff. Default: 50. + top_p: Nucleus (top-p) sampling cutoff. Default: 0.95. + + Returns: + List of dicts, each ``{"patient_id": "synthetic_i", + "visits": [[code, ...], ...]}`` with decoded code strings. + """ + device = self._resolve_device(device) + self.to(device) + + index_to_code = {v: k for k, v in self.visits_processor.code_vocab.items()} + + self.bart.eval() + synthetic_dataset: List[Dict] = [] + sample_batch_size = min(num_samples, 256) + generated = 0 + pbar = tqdm(total=num_samples, desc="Generating patients") + + # Fully-masked seed: [BOS] [CODE_PROMPT] [MASK] [EOS]. + seed = [self.bos_id, self.code_prompt_id, self.mask_id, self.eos_id] + + with torch.no_grad(): + while generated < num_samples: + bs = min(sample_batch_size, num_samples - generated) + enc_input_ids = torch.tensor( + [seed] * bs, dtype=torch.long, device=self.device + ) + enc_attention_mask = torch.ones_like(enc_input_ids) + inputs_embeds, enc_attention_mask = self._encoder_inputs_embeds( + enc_input_ids, enc_attention_mask + ) + encoder_outputs = self.bart.get_encoder()( + inputs_embeds=inputs_embeds, + attention_mask=enc_attention_mask, + return_dict=True, + ) + out_ids = self.bart.generate( + encoder_outputs=encoder_outputs, + attention_mask=enc_attention_mask, + max_length=self.max_len, + do_sample=True, + top_k=top_k, + top_p=top_p, + num_beams=1, + pad_token_id=self.pad_id, + eos_token_id=self.eos_id, + decoder_start_token_id=self.eos_id, + ) + for i in range(bs): + # BART's generate prepends decoder_start_token_id (= eos_id) + # at position 0; skip it so the real eos terminates decoding. + visits_out = self._decode_ids( + out_ids[i].tolist()[1:], index_to_code + ) + synthetic_dataset.append( + { + "patient_id": f"synthetic_{generated + i}", + "visits": visits_out, + } + ) + generated += bs + pbar.update(bs) + pbar.close() + + return synthetic_dataset diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index a32618f9c..2140f23ed 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -45,6 +45,13 @@ from .mortality_prediction_stagenet_mimic4 import ( MortalityPredictionStageNetMIMIC4, ) +from .generate_ehr import ( + EHRGeneration, + EHRGenerationMIMIC3, + EHRGenerationMIMIC4, + decode_dataset, + to_evaluation_dataframe, +) from .patient_linkage import patient_linkage_mimic3_fn from .readmission_prediction import ( ReadmissionPredictionEICU, diff --git a/pyhealth/tasks/generate_ehr.py b/pyhealth/tasks/generate_ehr.py new file mode 100644 index 000000000..6fb23da9d --- /dev/null +++ b/pyhealth/tasks/generate_ehr.py @@ -0,0 +1,253 @@ +"""EHR sequence-generation tasks for PyHealth generative models. + +This is the shared task for every generator in +:mod:`pyhealth.models.generators` (HALO, MedGAN, CorGAN, PromptEHR, ...). It +extracts, for each patient, the ordered list of visits where each visit is the +list of medical codes recorded in that admission. The single input feature +``visits`` is processed by :class:`~pyhealth.processors.NestedSequenceProcessor`; +there is no prediction label, so ``output_schema`` is empty. + +:class:`EHRGeneration` holds all the extraction logic; dataset-specific +subclasses only declare which event type and code attribute to read. + +Evaluating generated data +------------------------- +The privacy/utility metrics in :mod:`pyhealth.metrics.generative` (``utils.py``, +``privacy.py``, ``utility.py`` -- exposed through ``evaluate_synthetic_ehr``) +consume **long-form** dataframes: one row per ``(patient, visit, code)`` with +columns ``id`` / ``time`` / ``visit_codes`` / ``labels``. ``id`` is the patient +identifier, ``time`` the (integer) visit index, ``visit_codes`` a single code +string, and ``labels`` a patient-level binary label (reduced via ``max`` over +the patient's rows). + +Both the real task samples and a generator's ``generate()`` output use the same +``{"visits": [[code, ...], ...]}`` record shape, so +:func:`to_evaluation_dataframe` converts either into that long-form table. A +processed ``SampleDataset`` can be turned back into records with +:func:`decode_dataset`. Subjects are renumbered sequentially (0, 1, 2, ...) in +the ``id`` column -- synthetic patients do not correspond to real ones, so any +original ``patient_id`` is ignored. + +.. code-block:: python + + from pyhealth.tasks.generate_ehr import decode_dataset, to_evaluation_dataframe + from pyhealth.metrics.generative import evaluate_synthetic_ehr + + # Real train/test EHR come from the processed SampleDataset(s): + train_df = to_evaluation_dataframe(decode_dataset(train_dataset)) + test_df = to_evaluation_dataframe(decode_dataset(test_dataset)) + + # Synthetic EHR comes straight from the trained generator (HALO, GPT2, ...): + synthetic = model.generate(num_samples=len(train_dataset)) + syn_df = to_evaluation_dataframe(synthetic) + + # Privacy metrics need no labels: + results = evaluate_synthetic_ehr(train_df, test_df, syn_df, metrics="privacy") + +The **utility** metrics (machine-learning efficacy, next-visit prediction) +additionally require a meaningful binary ``labels`` column. Since this task is +unconditional (no labels), pass a ``label_fn`` to derive one per patient -- e.g. +``label_fn=lambda r: any("250" in c for v in r["visits"] for c in v)`` for a +diabetes flag -- and the same ``label_fn`` must be applied to the real and +synthetic frames. With no label available, restrict to ``metrics="privacy"``. + +Note: + The MLE component currently hard-codes the downstream task to + next-visit prediction, which is degenerate for bag-of-codes + generators (MedGAN, CorGAN) that emit a single aggregate visit per + patient. A future revision will let callers plug in static-label + tasks (e.g. mortality, readmission, "ever diagnosed with X") so MLE + is meaningful for both sequential (HALO, GPT2, PromptEHR) and + bag-of-codes generators. Until then, restrict bag-of-codes + evaluation to ``metrics="privacy"`` plus the prevalence metrics. +""" + +import logging +from typing import Callable, Dict, List, Optional, Type, Union + +from pyhealth.data.data import Patient +from pyhealth.processors import NestedSequenceProcessor + +from .base_task import BaseTask + +logger = logging.getLogger(__name__) + + +class EHRGeneration(BaseTask): + """Generic per-visit code-sequence task for unconditional EHR generators. + + Builds one sample per qualifying patient: the ordered list of visits, each + visit being the list of codes (read from ``code_attr`` on ``event_type`` + events) recorded in that admission. Patients with fewer than ``min_visits`` + qualifying visits are skipped. + + Subclass and override the class attributes for a specific dataset, or set + them on an instance. The defaults read MIMIC-III ICD-9 diagnosis codes. + + Args: + task_name: Name of the task. + input_schema: ``{"visits": NestedSequenceProcessor}``. + output_schema: empty (generative task, no labels). + event_type: Event type to pull per admission. Default + ``"diagnoses_icd"``. + code_attr: Event attribute holding the code string. Default + ``"icd9_code"``. + min_visits: Minimum qualifying visits to keep a patient. Default 2. + """ + + task_name: str = "ehr_generation" + input_schema: Dict[str, Union[str, Type]] = {"visits": NestedSequenceProcessor} + output_schema: Dict[str, Union[str, Type]] = {} + + event_type: str = "diagnoses_icd" + code_attr: str = "icd9_code" + min_visits: int = 2 + + def __call__(self, patient: Patient) -> List[Dict]: + """Extract the per-visit code sequence for a patient.""" + visits: List[List[str]] = [] + admissions = patient.get_events(event_type="admissions") + for admission in admissions: + events = patient.get_events( + event_type=self.event_type, + filters=[("hadm_id", "==", admission.hadm_id)], + ) + codes = [ + getattr(event, self.code_attr) + for event in events + if getattr(event, self.code_attr, None) + ] + if codes: + visits.append(codes) + + if len(visits) < self.min_visits: + return [] + + return [{"patient_id": patient.patient_id, "visits": visits}] + + +class EHRGenerationMIMIC3(EHRGeneration): + """EHR generation task for MIMIC-III (ICD-9 diagnosis codes). + + Examples: + >>> from pyhealth.datasets import MIMIC3Dataset + >>> from pyhealth.tasks import EHRGenerationMIMIC3 + >>> dataset = MIMIC3Dataset( + ... root="/path/to/mimic-iii/1.4", + ... tables=["diagnoses_icd"], + ... ) + >>> samples = dataset.set_task(EHRGenerationMIMIC3()) + """ + + task_name: str = "ehr_generation_mimic3" + event_type: str = "diagnoses_icd" + code_attr: str = "icd9_code" + + +class EHRGenerationMIMIC4(EHRGeneration): + """EHR generation task for MIMIC-IV (ICD diagnosis codes). + + Examples: + >>> from pyhealth.datasets import MIMIC4Dataset + >>> from pyhealth.tasks import EHRGenerationMIMIC4 + >>> dataset = MIMIC4Dataset( + ... ehr_root="/path/to/mimiciv/2.2/", + ... ehr_tables=["patients", "admissions", "diagnoses_icd"], + ... ) + >>> samples = dataset.set_task(EHRGenerationMIMIC4()) + """ + + task_name: str = "ehr_generation_mimic4" + event_type: str = "diagnoses_icd" + code_attr: str = "icd_code" + + +# ---------------------------------------------------------------------------- +# Conversion helpers for pyhealth.metrics.generative.evaluate_synthetic_ehr +# ---------------------------------------------------------------------------- +def to_evaluation_dataframe( + records, + label_fn: Optional[Callable[[Dict], int]] = None, + subject_col: str = "id", + visit_col: str = "time", + code_col: str = "visit_codes", + label_col: str = "labels", +): + """Flatten EHR-generation records into the long-form evaluation dataframe. + + Produces the one-row-per-``(patient, visit, code)`` table consumed by + :func:`pyhealth.metrics.generative.evaluate_synthetic_ehr` (and the + ``utils.py`` / ``privacy.py`` / ``utility.py`` functions beneath it). + + Subjects are numbered **sequentially** (0, 1, 2, ...) in ``subject_col``; + any ``"patient_id"`` on the records is ignored, since synthetic patients do + not correspond to real ones. + + Args: + records: Iterable of ``{"visits": [[code, ...], ...]}`` dicts. Both the + :class:`EHRGeneration` task output and a generator's ``generate()`` + output have this shape. + label_fn: Optional callable mapping a record to a binary patient label + (0/1) used by the utility metrics. Defaults to all-zeros. + subject_col: Output patient-id column. Default ``"id"``. + visit_col: Output visit-index column. Default ``"time"``. + code_col: Output single-code column. Default ``"visit_codes"``. + label_col: Output binary-label column. Default ``"labels"``. + + Returns: + ``pandas.DataFrame`` with columns + ``[subject_col, visit_col, code_col, label_col]``. + """ + import pandas as pd + + rows = [] + for subject_id, record in enumerate(records): + label = 0 if label_fn is None else int(label_fn(record)) + for visit_idx, visit in enumerate(record["visits"]): + for code in visit: + rows.append( + { + subject_col: subject_id, + visit_col: visit_idx, + code_col: code, + label_col: label, + } + ) + return pd.DataFrame( + rows, columns=[subject_col, visit_col, code_col, label_col] + ) + + +def decode_dataset(sample_dataset, feature_key: str = "visits") -> List[Dict]: + """Decode a processed EHRGeneration ``SampleDataset`` back into records. + + Inverts the :class:`~pyhealth.processors.NestedSequenceProcessor` encoding + using its vocabulary (skipping ````/````), yielding one + ``{"visits": [[code_str, ...], ...]}`` record per sample. Use this to build + the real train/test frames that ``evaluate_synthetic_ehr`` compares against. + + Args: + sample_dataset: A ``SampleDataset`` produced by :class:`EHRGeneration`. + feature_key: Input feature key holding the nested code sequence. + Default ``"visits"``. + + Returns: + List of ``{"visits": [[code_str, ...], ...]}`` records. + """ + processor = sample_dataset.input_processors[feature_key] + index_to_code = {idx: code for code, idx in processor.code_vocab.items()} + + records: List[Dict] = [] + for i in range(len(sample_dataset)): + sample = sample_dataset[i] + visits: List[List[str]] = [] + for row in sample[feature_key].tolist(): + codes = [ + index_to_code[int(idx)] + for idx in row + if index_to_code.get(int(idx)) not in (None, "", "") + ] + if codes: + visits.append(codes) + records.append({"visits": visits}) + return records diff --git a/tests/core/test_corgan.py b/tests/core/test_corgan.py new file mode 100644 index 000000000..7c5f54963 --- /dev/null +++ b/tests/core/test_corgan.py @@ -0,0 +1,228 @@ +import tempfile +import unittest + +import torch + +from pyhealth.datasets import create_sample_dataset +from pyhealth.models import CorGAN + + +class TestCorGAN(unittest.TestCase): + """Test cases for the CorGAN synthetic-EHR generator.""" + + def setUp(self): + """Bag-of-codes generative dataset (no labels) and a tiny model.""" + self.samples = [ + {"patient_id": "patient-0", "visits": ["A05B", "A05C", "A11D", "C129"]}, + {"patient_id": "patient-1", "visits": ["A05B", "A04A", "B035"]}, + {"patient_id": "patient-2", "visits": ["C129", "A11D", "A05C", "A04A"]}, + {"patient_id": "patient-3", "visits": ["B035", "A05B", "C129"]}, + ] + self.input_schema = {"visits": "multi_hot"} + self.output_schema = {} + + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name="test_corgan", + ) + + # Small vocab -> linear autoencoder is auto-selected. + self.model = CorGAN( + dataset=self.dataset, + latent_dim=8, + hidden_dim=8, + discriminator_hidden_dim=16, + batch_size=2, + ae_epochs=1, + gan_epochs=1, + n_iter_D=1, + ) + + def test_model_initialization(self): + """Vocab is derived from MultiHotProcessor; generation is unconditional.""" + self.assertIsInstance(self.model, CorGAN) + self.assertEqual(self.model.feature_keys, ["visits"]) + self.assertEqual(self.model.label_keys, []) + + proc_vocab = self.dataset.input_processors["visits"].size() + self.assertEqual(self.model.input_dim, proc_vocab) + + def test_small_vocab_falls_back_to_linear(self): + """For tiny vocabs the 6-layer CNN can't compress -- expect linear AE.""" + self.assertEqual(self.model.autoencoder_type, "linear") + + def test_components_present(self): + """Autoencoder, generator, and critic are all registered submodules.""" + self.assertTrue(hasattr(self.model, "autoencoder")) + self.assertTrue(hasattr(self.model, "generator")) + self.assertTrue(hasattr(self.model, "critic")) + + # Critic is a Wasserstein critic -> unbounded scalar (no sigmoid). + x = torch.zeros(2, self.model.input_dim) + with torch.no_grad(): + c_out = self.model.critic(x) + self.assertEqual(c_out.shape, (2, 1)) + + def test_forward_raises(self): + """CorGAN's BaseModel forward intentionally errors out.""" + with self.assertRaises(NotImplementedError): + self.model.forward() + + def test_train_model_runs(self): + """train_model completes a tiny two-phase loop on CPU and returns history.""" + with tempfile.TemporaryDirectory() as tmp: + model = CorGAN( + dataset=self.dataset, + latent_dim=8, + hidden_dim=8, + discriminator_hidden_dim=16, + batch_size=2, + ae_epochs=1, + gan_epochs=1, + n_iter_D=1, + save_dir=tmp, + ) + history = model.train_model(self.dataset, device="cpu") + + self.assertIn("autoencoder_loss", history) + self.assertIn("critic_loss", history) + self.assertIn("generator_loss", history) + self.assertEqual(len(history["autoencoder_loss"]), 1) + self.assertEqual(len(history["critic_loss"]), 1) + self.assertEqual(len(history["generator_loss"]), 1) + self.assertEqual(next(model.parameters()).device.type, "cpu") + + def test_generate(self): + """generate() returns the requested number of decoded synthetic patients. + + CorGAN is a bag-of-codes model, so each patient gets a single visit + containing the aggregate set of codes; the outer ``visits`` list + wraps that single visit to match HALO's nested format. + """ + synthetic = self.model.generate(num_samples=4, device="cpu") + self.assertEqual(len(synthetic), 4) + for i, patient in enumerate(synthetic): + self.assertEqual(patient["patient_id"], f"synthetic_{i}") + self.assertIsInstance(patient["visits"], list) + # Exactly one aggregate visit per patient. + self.assertEqual(len(patient["visits"]), 1) + visit = patient["visits"][0] + self.assertIsInstance(visit, list) + for code in visit: + self.assertIsInstance(code, str) + self.assertNotIn(code, ("", "")) + + def test_generate_random_sampling(self): + """random_sampling=True still produces well-formed patients.""" + synthetic = self.model.generate( + num_samples=3, random_sampling=True, device="cpu" + ) + self.assertEqual(len(synthetic), 3) + for patient in synthetic: + self.assertIn("patient_id", patient) + self.assertIn("visits", patient) + + def test_save_and_load_roundtrip(self): + """save_model + load_model preserves weights and vocabulary.""" + with tempfile.TemporaryDirectory() as tmp: + path = f"{tmp}/corgan.pt" + self.model.save_model(path) + + other = CorGAN( + dataset=self.dataset, + latent_dim=8, + hidden_dim=8, + discriminator_hidden_dim=16, + batch_size=2, + ae_epochs=1, + gan_epochs=1, + n_iter_D=1, + save_dir=tmp, + ) + other.load_model(path) + + for p1, p2 in zip( + self.model.generator.parameters(), + other.generator.parameters(), + ): + self.assertTrue(torch.allclose(p1, p2)) + self.assertEqual(other._idx_to_code, self.model._idx_to_code) + + def test_missing_visits_processor_raises(self): + """A dataset without a 'visits' feature should be rejected.""" + bad = create_sample_dataset( + samples=[ + {"patient_id": "p1", "codes": ["A", "B"]}, + {"patient_id": "p2", "codes": ["B"]}, + ], + input_schema={"codes": "multi_hot"}, + output_schema={}, + ) + with self.assertRaises(ValueError): + CorGAN(bad, latent_dim=8, hidden_dim=8) + + def test_unknown_autoencoder_type_raises(self): + """An unknown autoencoder_type is rejected.""" + with self.assertRaises(ValueError): + CorGAN(self.dataset, autoencoder_type="rnn") + + def test_cnn_path_with_large_vocab(self): + """A vocabulary big enough survives the 6-layer CNN chain end-to-end.""" + all_codes = [f"C{i:04d}" for i in range(1200)] + samples = [ + { + "patient_id": f"p{i}", + "visits": all_codes[i * 200 : (i + 1) * 200] + + all_codes[1100:1200], + } + for i in range(6) + ] + dataset = create_sample_dataset( + samples=samples, + input_schema={"visits": "multi_hot"}, + output_schema={}, + dataset_name="cnn_corgan", + ) + with tempfile.TemporaryDirectory() as tmp: + model = CorGAN( + dataset=dataset, + batch_size=2, + ae_epochs=1, + gan_epochs=1, + n_iter_D=1, + save_dir=tmp, + ) + self.assertEqual(model.autoencoder_type, "cnn") + # CNN bottleneck is fixed at 128 -- generator should track it. + self.assertEqual(model.latent_dim, 128) + self.assertEqual(model.hidden_dim, 128) + model.train_model(dataset, device="cpu") + out = model.generate(num_samples=2, device="cpu") + self.assertEqual(len(out), 2) + + @unittest.skipUnless(torch.cuda.is_available(), "CUDA not available") + def test_train_and_generate_on_cuda(self): + """When CUDA is available, the device arg moves training to GPU.""" + with tempfile.TemporaryDirectory() as tmp: + model = CorGAN( + dataset=self.dataset, + latent_dim=8, + hidden_dim=8, + discriminator_hidden_dim=16, + batch_size=2, + ae_epochs=1, + gan_epochs=1, + n_iter_D=1, + save_dir=tmp, + ) + model.train_model(self.dataset, device="cuda") + self.assertTrue(next(model.parameters()).is_cuda) + + synthetic = model.generate(num_samples=2, device="cuda") + self.assertEqual(len(synthetic), 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_generative_metrics.py b/tests/core/test_generative_metrics.py new file mode 100644 index 000000000..7926d2b88 --- /dev/null +++ b/tests/core/test_generative_metrics.py @@ -0,0 +1,465 @@ +"""Unit tests for pyhealth.metrics.generative (synthetic-EHR metrics). + +Run with:: + + python -m unittest tests.core.test_generative_metrics -v +""" + +import unittest + +import numpy as np +import pandas as pd + +from pyhealth.metrics.generative import ( + calc_membership_inference, + calc_nnaar, + compute_discriminator_privacy, + compute_mle, + compute_prevalence_metrics, + evaluate_synthetic_ehr, +) +from pyhealth.metrics.generative.utils import ( + convert_cols_to_multihot, + train_lstm_model, + train_sklearn_model, +) + +SUBJECT_COL, VISIT_COL, CODE_COL, LABEL_COL = "id", "time", "visit_codes", "labels" + + +def _make_dataframes(): + """Builds small synthetic train/test/synthetic EHR dataframes.""" + train_ehr = pd.DataFrame( + { + "visit_codes": [0, 1, 3, 4, 1, 2, 0, 3, 2, 4, 1, 0, 2, 3, 4, + 1, 0, 2, 3, 4, 1, 0, 2, 3, 4], + "labels": [0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, + 1, 0, 0, 1, 0, 1, 0, 0, 1, 0], + "time": [0, 0, 1, 1, 0, 1, 2, 2, 3, 3, 1, 2, 3, 4, 4, + 0, 1, 2, 3, 4, 1, 2, 3, 4, 5], + "id": [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, + 3, 3, 3, 3, 3, 4, 4, 4, 4, 4], + } + ).astype({"visit_codes": str, "labels": int, "time": int, "id": str}) + + test_ehr = pd.DataFrame( + { + "visit_codes": [1, 2, 0, 3, 4, 2, 1, 0, 3, 4, 1, 2, 3, 0, 4, + 2, 1, 3, 0, 4], + "labels": [0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, + 0, 0, 1, 0, 1], + "time": [0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 0, 1, 1, 2, 2, + 3, 3, 3, 4, 4], + "id": [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2], + } + ).astype({"visit_codes": str, "labels": int, "time": int, "id": str}) + + syn_ehr = pd.DataFrame( + { + "visit_codes": [2, 3, 1, 4, 0, 2, 3, 1, 0, 4, 1, 2, 3, 4, 0, + 2, 1, 3, 4, 0, 2, 1, 3, 4, 0, 1, 2, 3, 4, 0], + "labels": [0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, + 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1], + "time": [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, + 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5], + "id": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], + } + ).astype({"visit_codes": str, "labels": int, "time": int, "id": str}) + + return train_ehr, test_ehr, syn_ehr + + +def _generate_ehr( + n_patients, vocab, seed, id_offset=0, + n_visits_range=(2, 7), n_codes_range=(2, 6), +): + """Generates a random EHR dataframe with patients drawn from ``vocab``.""" + rng = np.random.default_rng(seed) + rows = [] + for i in range(n_patients): + pid = str(id_offset + i) + n_visits = int(rng.integers(*n_visits_range)) + label = int(rng.integers(0, 2)) + for t in range(n_visits): + n_codes = int(rng.integers(*n_codes_range)) + codes = rng.choice( + vocab, size=min(n_codes, len(vocab)), replace=False + ) + for code in codes: + rows.append( + {"id": pid, "time": t, + "visit_codes": str(code), "labels": label} + ) + return pd.DataFrame(rows).astype( + {"visit_codes": str, "labels": int, "time": int, "id": str} + ) + + +def _perturb_ehr(df, vocab, frac, seed): + """Returns a copy of ``df`` with a fraction of codes randomly replaced.""" + rng = np.random.default_rng(seed) + df = df.copy().reset_index(drop=True) + mask = rng.random(len(df)) < frac + new_codes = rng.choice(vocab, size=int(mask.sum())) + df.loc[mask, "visit_codes"] = [str(c) for c in new_codes] + return df + + +class GenerativeMetricsTestCase(unittest.TestCase): + """Shared fixtures and assertion helpers for the generative metrics.""" + + def setUp(self): + np.random.seed(0) + self.train_ehr, self.test_ehr, self.syn_ehr = _make_dataframes() + self.cols = dict( + subject_col=SUBJECT_COL, + visit_col=VISIT_COL, + code_col=CODE_COL, + label_col=LABEL_COL, + ) + + def assertSummary(self, summary, expected_keys): + """Asserts a metrics summary has the expected (mean, std) structure.""" + self.assertIsInstance(summary, dict) + for key in expected_keys: + self.assertIn(key, summary) + value = summary[key] + self.assertIsInstance(value, tuple) + self.assertEqual(len(value), 2) + mean, std = value + self.assertTrue(np.isfinite(mean), f"{key} mean not finite") + self.assertTrue(np.isfinite(std), f"{key} std not finite") + self.assertGreaterEqual(std, 0.0) + + +class TestNNAAR(GenerativeMetricsTestCase): + def test_calc_nnaar(self): + summary = calc_nnaar( + self.train_ehr, self.test_ehr, self.syn_ehr, + **self.cols, sample_size=10, n_runs=3, + ) + self.assertSummary(summary, ["nnaar", "aa_es", "aa_ts"]) + for key in ("aa_es", "aa_ts"): + self.assertGreaterEqual(summary[key][0], 0.0) + self.assertLessEqual(summary[key][0], 1.0) + self.assertGreaterEqual(summary["nnaar"][0], -1.0) + self.assertLessEqual(summary["nnaar"][0], 1.0) + + +class TestMembershipInference(GenerativeMetricsTestCase): + def test_calc_membership_inference(self): + summary = calc_membership_inference( + self.train_ehr, self.test_ehr, self.syn_ehr, + **self.cols, num_attack_samples=10, n_runs=3, + ) + keys = ["MIA_F1", "MIA_Precision", "MIA_Recall", "MIA_Accuracy"] + self.assertSummary(summary, keys) + for key in keys: + self.assertGreaterEqual(summary[key][0], 0.0) + self.assertLessEqual(summary[key][0], 1.0) + + +class TestDiscriminatorPrivacy(GenerativeMetricsTestCase): + def test_discriminator_privacy_lstm(self): + summary = compute_discriminator_privacy( + train_fn=train_lstm_model, + train_ehr=self.train_ehr, test_ehr=self.test_ehr, + syn_ehr=self.syn_ehr, **self.cols, n_bootstraps=3, + embed_dim=8, hidden_dim=8, batch_size=8, epochs=2, verbose=False, + ) + keys = ["Privacy_Discriminator_Accuracy", "Privacy_Score"] + self.assertSummary(summary, keys) + self.assertGreaterEqual(summary["Privacy_Score"][0], 0.0) + self.assertLessEqual(summary["Privacy_Score"][0], 1.0) + + def test_discriminator_privacy_rf(self): + summary = compute_discriminator_privacy( + train_fn=train_sklearn_model, + train_ehr=self.train_ehr, test_ehr=self.test_ehr, + syn_ehr=self.syn_ehr, **self.cols, n_bootstraps=3, model="rf", + ) + self.assertSummary( + summary, ["Privacy_Discriminator_Accuracy", "Privacy_Score"] + ) + + +class TestMLE(GenerativeMetricsTestCase): + def test_compute_mle_lstm(self): + summary = compute_mle( + train_fn=train_lstm_model, + train_ehr=self.train_ehr, test_ehr=self.test_ehr, + syn_ehr=self.syn_ehr, **self.cols, n_bootstraps=3, + embed_dim=8, hidden_dim=8, batch_size=8, epochs=2, verbose=False, + ) + keys = [ + "MLE_Real_Accuracy", "MLE_Synth_Accuracy", "MLE_Difference", + "MLE_Ratio", "MLE_Real_F1", "MLE_Synth_F1", + ] + self.assertSummary(summary, keys) + for key in ("MLE_Real_Accuracy", "MLE_Synth_Accuracy"): + self.assertGreaterEqual(summary[key][0], 0.0) + self.assertLessEqual(summary[key][0], 1.0) + + def test_compute_mle_rf(self): + summary = compute_mle( + train_fn=train_sklearn_model, + train_ehr=self.train_ehr, test_ehr=self.test_ehr, + syn_ehr=self.syn_ehr, **self.cols, n_bootstraps=3, model="rf", + ) + self.assertSummary(summary, ["MLE_Real_Accuracy", "MLE_Synth_Accuracy"]) + + +class TestPrevalenceMetrics(GenerativeMetricsTestCase): + def test_compute_prevalence_metrics(self): + summary = compute_prevalence_metrics( + self.train_ehr, self.syn_ehr, + subject_col=SUBJECT_COL, code_col=CODE_COL, n_bootstraps=3, + ) + keys = ["Prevalence_R2", "Prevalence_Pearson", "Prevalence_RMSE"] + self.assertSummary(summary, keys) + self.assertGreaterEqual(summary["Prevalence_Pearson"][0], -1.0) + self.assertLessEqual(summary["Prevalence_Pearson"][0], 1.0) + self.assertGreaterEqual(summary["Prevalence_RMSE"][0], 0.0) + + +class TestConvertColsToMultihot(GenerativeMetricsTestCase): + def test_convert_cols_to_multihot(self): + df = self.train_ehr.copy() + df["gender"] = ["M", "F"] * 12 + ["M"] + df["age"] = np.arange(len(df), dtype=float) + out = convert_cols_to_multihot( + df, code_col=CODE_COL, visit_col=VISIT_COL, + cat_cols=["gender"], num_cols=["age"], bins_per_num=2, + ) + self.assertIn("combined_codes", out.columns) + self.assertEqual(len(out), len(df)) + # Each combined code should fold in the code, the category and the bin. + first = out["combined_codes"].iloc[0] + self.assertIn("gender_", first) + self.assertIn("age_", first) + # The original dataframe must not be mutated. + self.assertNotIn("combined_codes", df.columns) + + +class TestEvaluateSyntheticEHR(GenerativeMetricsTestCase): + def test_evaluate_all_lstm(self): + out = evaluate_synthetic_ehr( + self.train_ehr, self.test_ehr, self.syn_ehr, **self.cols, + sample_size=10, mode="lstm", metrics="all", + lstm_params={"embed_dim": 8, "hidden_dim": 8, + "batch_size": 8, "epochs": 2}, + n_bootstraps=3, n_runs=3, + ) + for key in ("nnaar", "MIA_F1", "MLE_Real_Accuracy", + "Privacy_Score", "Prevalence_RMSE"): + self.assertIn(key, out) + + def test_evaluate_privacy_only_rf(self): + out = evaluate_synthetic_ehr( + self.train_ehr, self.test_ehr, self.syn_ehr, **self.cols, + sample_size=10, mode="rf", metrics="privacy", + n_bootstraps=3, n_runs=3, + ) + self.assertIn("nnaar", out) + self.assertNotIn("MLE_Real_Accuracy", out) + + def test_evaluate_utility_only_rf(self): + out = evaluate_synthetic_ehr( + self.train_ehr, self.test_ehr, self.syn_ehr, **self.cols, + mode="rf", metrics="utility", n_bootstraps=3, + ) + self.assertIn("MLE_Real_Accuracy", out) + self.assertNotIn("nnaar", out) + + def test_invalid_mode_raises(self): + with self.assertRaises(ValueError): + evaluate_synthetic_ehr( + self.train_ehr, self.test_ehr, self.syn_ehr, **self.cols, + mode="bad", + ) + + def test_invalid_metrics_raises(self): + with self.assertRaises(ValueError): + evaluate_synthetic_ehr( + self.train_ehr, self.test_ehr, self.syn_ehr, **self.cols, + metrics="bad", + ) + + +class TestMetricsBehavior(unittest.TestCase): + """Sanity checks: metrics should respond to how close synthetic data is. + + Three synthetic datasets are compared against the same real data: + + - ``exact``: an exact copy of the real training data, + - ``similar``: the training data with ~15% of codes randomly changed, + - ``different``: independent data over a disjoint code vocabulary. + + A well-behaved metric should rank these consistently (e.g. an exact copy + is the worst case for privacy and the best case for fidelity). + """ + + VOCAB_REAL = list(range(50)) + VOCAB_DIFF = list(range(100, 150)) + + @classmethod + def setUpClass(cls): + cls.train_ehr = _generate_ehr(60, cls.VOCAB_REAL, seed=1, id_offset=0) + cls.test_ehr = _generate_ehr( + 60, cls.VOCAB_REAL, seed=2, id_offset=10000 + ) + cls.syn_exact = cls.train_ehr.copy() + cls.syn_similar = _perturb_ehr( + cls.train_ehr, cls.VOCAB_REAL, frac=0.15, seed=3 + ) + cls.syn_different = _generate_ehr( + 60, cls.VOCAB_DIFF, seed=4, id_offset=20000 + ) + cls.cols = dict( + subject_col=SUBJECT_COL, + visit_col=VISIT_COL, + code_col=CODE_COL, + label_col=LABEL_COL, + ) + + def test_prevalence_orders_by_similarity(self): + # Prevalence similarity should degrade monotonically: exact > similar + # > different. + results = {} + for name, syn in [ + ("exact", self.syn_exact), + ("similar", self.syn_similar), + ("different", self.syn_different), + ]: + np.random.seed(0) + results[name] = compute_prevalence_metrics( + self.train_ehr, syn, + subject_col=SUBJECT_COL, code_col=CODE_COL, n_bootstraps=10, + ) + + rmse = {k: v["Prevalence_RMSE"][0] for k, v in results.items()} + r2 = {k: v["Prevalence_R2"][0] for k, v in results.items()} + pearson = {k: v["Prevalence_Pearson"][0] for k, v in results.items()} + + # An exact copy has identical code prevalence. + self.assertAlmostEqual(rmse["exact"], 0.0, places=9) + self.assertAlmostEqual(r2["exact"], 1.0, places=6) + self.assertAlmostEqual(pearson["exact"], 1.0, places=6) + + # Error grows / agreement shrinks as synthetic data drifts away. + self.assertLess(rmse["exact"], rmse["similar"]) + self.assertLess(rmse["similar"], rmse["different"]) + self.assertGreater(r2["exact"], r2["similar"]) + self.assertGreater(r2["similar"], r2["different"]) + self.assertGreaterEqual(pearson["exact"], pearson["similar"]) + self.assertGreater(pearson["similar"], pearson["different"]) + + def test_nnaar_flags_exact_copies(self): + # NNAAR should be high when synthetic data memorizes the training set + # and near zero otherwise. + nnaar = {} + for name, syn in [ + ("exact", self.syn_exact), + ("similar", self.syn_similar), + ("different", self.syn_different), + ]: + np.random.seed(0) + nnaar[name] = calc_nnaar( + self.train_ehr, self.test_ehr, syn, + **self.cols, sample_size=1000, n_runs=3, + )["nnaar"][0] + + self.assertGreater(nnaar["exact"], 0.5) + self.assertGreater(nnaar["exact"], nnaar["similar"]) + self.assertGreater(nnaar["exact"], nnaar["different"]) + self.assertLess(nnaar["similar"], 0.3) + self.assertLess(nnaar["different"], 0.3) + + def test_membership_inference_detects_training_data(self): + # The attack should succeed when synthetic data is derived from the + # training set and be near chance when it is unrelated. + acc = {} + for name, syn in [ + ("exact", self.syn_exact), + ("similar", self.syn_similar), + ("different", self.syn_different), + ]: + np.random.seed(0) + acc[name] = calc_membership_inference( + self.train_ehr, self.test_ehr, syn, + **self.cols, num_attack_samples=1000, n_runs=5, + )["MIA_Accuracy"][0] + + self.assertGreater(acc["exact"], 0.8) + self.assertGreater(acc["exact"], acc["different"]) + self.assertGreater(acc["similar"], acc["different"]) + self.assertLess(acc["different"], 0.7) + + def test_discriminator_privacy_orders_by_similarity(self): + # A discriminator easily separates a disjoint-vocabulary synthetic set + # (accuracy ~1, privacy score ~0) but not data derived from the real + # data (lower accuracy, higher privacy score). + score, acc = {}, {} + for name, syn in [ + ("exact", self.syn_exact), + ("similar", self.syn_similar), + ("different", self.syn_different), + ]: + np.random.seed(0) + result = compute_discriminator_privacy( + train_fn=train_sklearn_model, + train_ehr=self.train_ehr, test_ehr=self.test_ehr, + syn_ehr=syn, **self.cols, n_bootstraps=10, model="rf", + ) + score[name] = result["Privacy_Score"][0] + acc[name] = result["Privacy_Discriminator_Accuracy"][0] + + # The disjoint-vocabulary set is trivially detected. + self.assertGreater(acc["different"], 0.8) + self.assertLess(score["different"], 0.1) + # Data derived from the real data is harder to flag. + self.assertGreater(acc["different"], acc["exact"]) + self.assertGreater(acc["different"], acc["similar"]) + self.assertGreater(score["exact"], score["different"]) + self.assertGreater(score["similar"], score["different"]) + + def test_mle_orders_by_similarity(self): + # Utility should be highest for an exact copy and degrade as the + # synthetic data drifts away from the real data. + mle = {} + for name, syn in [ + ("exact", self.syn_exact), + ("similar", self.syn_similar), + ("different", self.syn_different), + ]: + np.random.seed(0) + mle[name] = compute_mle( + train_fn=train_sklearn_model, + train_ehr=self.train_ehr, test_ehr=self.test_ehr, + syn_ehr=syn, **self.cols, n_bootstraps=10, model="rf", + ) + + # An exact copy reproduces real utility exactly. + exact = mle["exact"] + self.assertAlmostEqual(exact["MLE_Difference"][0], 0.0, places=9) + self.assertAlmostEqual(exact["MLE_Difference"][1], 0.0, places=9) + self.assertAlmostEqual(exact["MLE_Ratio"][0], 1.0, places=9) + self.assertAlmostEqual( + exact["MLE_Synth_Accuracy"][0], exact["MLE_Real_Accuracy"][0], + places=9, + ) + + # Synthetic-trained accuracy degrades monotonically. + diff = {k: abs(v["MLE_Difference"][0]) for k, v in mle.items()} + ratio = {k: v["MLE_Ratio"][0] for k, v in mle.items()} + self.assertLessEqual(diff["exact"], diff["similar"]) + self.assertLess(diff["similar"], diff["different"]) + self.assertGreaterEqual(ratio["exact"], ratio["similar"]) + self.assertGreater(ratio["similar"], ratio["different"]) + self.assertLess(ratio["different"], 1.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_gpt2.py b/tests/core/test_gpt2.py new file mode 100644 index 000000000..4e9519214 --- /dev/null +++ b/tests/core/test_gpt2.py @@ -0,0 +1,151 @@ +import tempfile +import unittest + +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import GPT2 + + +class TestGPT2(unittest.TestCase): + """Test cases for the GPT-2 baseline synthetic-EHR generator.""" + + def setUp(self): + """Set up a synthetic generative dataset (no labels) and a tiny model.""" + self.samples = [ + {"patient_id": "patient-0", "visits": [["A05B", "A05C"], ["A11D"], ["C129"]]}, + {"patient_id": "patient-1", "visits": [["A05B"], ["A04A", "B035"]]}, + {"patient_id": "patient-2", "visits": [["C129", "A11D"], ["A05C"], ["A04A"]]}, + {"patient_id": "patient-3", "visits": [["B035"], ["A05B", "C129"]]}, + ] + + # Generative task: one nested-sequence input feature, no output labels. + self.input_schema = {"visits": "nested_sequence"} + self.output_schema = {} + + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name="test_gpt2", + ) + + # Small model; embed_dim must be divisible by n_heads. + self.model = GPT2( + dataset=self.dataset, + embed_dim=16, + n_heads=2, + n_layers=2, + max_len=64, + batch_size=2, + epochs=1, + ) + + def test_model_initialization(self): + """Vocab/special-token ids are derived from the processor.""" + self.assertIsInstance(self.model, GPT2) + self.assertEqual(self.model.feature_keys, ["visits"]) + self.assertEqual(self.model.label_keys, []) + + proc_vocab = self.dataset.input_processors["visits"].vocab_size() + self.assertEqual(self.model.code_vocab_size, proc_vocab) + self.assertEqual(self.model.bos_id, proc_vocab) + self.assertEqual(self.model.eos_id, proc_vocab + 1) + self.assertEqual(self.model.delim_id, proc_vocab + 2) + self.assertEqual(self.model.gpt2.config.vocab_size, proc_vocab + 3) + + def test_forward_input_format(self): + """The standard dataloader pads the visit dimension into a 3D tensor.""" + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + self.assertIsInstance(batch["visits"], torch.Tensor) + self.assertEqual(batch["visits"].dim(), 3) # (B, max_visits, max_codes) + + def test_model_forward(self): + """Forward returns a finite scalar loss and a probability tensor.""" + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + ret = self.model(**batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + self.assertEqual(ret["loss"].dim(), 0) + self.assertTrue(torch.isfinite(ret["loss"]).all()) + # y_prob: (B, L, vocab_size) + self.assertEqual(ret["y_prob"].shape[0], 2) + self.assertEqual(ret["y_prob"].shape[2], self.model.gpt2.config.vocab_size) + + def test_model_backward(self): + """Backward populates gradients on model parameters.""" + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + + ret = self.model(**batch) + ret["loss"].backward() + + has_gradient = any( + param.requires_grad and param.grad is not None + for param in self.model.parameters() + ) + self.assertTrue(has_gradient, "No parameters have gradients after backward") + + def test_generate(self): + """generate() returns the requested number of decoded synthetic patients.""" + synthetic = self.model.generate(num_samples=4) + + self.assertEqual(len(synthetic), 4) + for i, patient in enumerate(synthetic): + self.assertEqual(patient["patient_id"], f"synthetic_{i}") + self.assertIsInstance(patient["visits"], list) + for visit in patient["visits"]: + self.assertIsInstance(visit, list) + for code in visit: + self.assertIsInstance(code, str) + self.assertNotIn(code, ("", "")) + + def _build_model(self, save_dir): + return GPT2( + dataset=self.dataset, + embed_dim=16, + n_heads=2, + n_layers=2, + max_len=64, + batch_size=2, + epochs=1, + save_dir=save_dir, + ) + + def test_train_and_generate_accept_device_arg(self): + """train_model/generate accept an explicit device arg (CPU always works).""" + with tempfile.TemporaryDirectory() as tmp: + model = self._build_model(tmp) + model.train_model(self.dataset, device="cpu") + self.assertEqual(next(model.parameters()).device.type, "cpu") + + synthetic = model.generate(num_samples=2, device="cpu") + self.assertEqual(len(synthetic), 2) + + @unittest.skipUnless(torch.cuda.is_available(), "CUDA not available") + def test_train_and_generate_on_cuda(self): + """When CUDA is available, the device arg moves training/generation to GPU.""" + with tempfile.TemporaryDirectory() as tmp: + model = self._build_model(tmp) + model.train_model(self.dataset, device="cuda") + self.assertTrue(next(model.parameters()).is_cuda) + + # forward should now run on CUDA without an explicit move. + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + with torch.no_grad(): + ret = model(**batch) + self.assertTrue(ret["y_prob"].is_cuda) + self.assertTrue(torch.isfinite(ret["loss"]).all()) + + synthetic = model.generate(num_samples=2, device="cuda") + self.assertEqual(len(synthetic), 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_halo.py b/tests/core/test_halo.py new file mode 100644 index 000000000..be9788d62 --- /dev/null +++ b/tests/core/test_halo.py @@ -0,0 +1,152 @@ +import tempfile +import unittest + +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import HALO + + +class TestHALO(unittest.TestCase): + """Test cases for the HALO synthetic-EHR generator.""" + + def setUp(self): + """Set up a synthetic generative dataset (no labels) and a tiny model.""" + self.samples = [ + {"patient_id": "patient-0", "visits": [["A05B", "A05C"], ["A11D"], ["C129"]]}, + {"patient_id": "patient-1", "visits": [["A05B"], ["A04A", "B035"]]}, + {"patient_id": "patient-2", "visits": [["C129", "A11D"], ["A05C"], ["A04A"]]}, + {"patient_id": "patient-3", "visits": [["B035"], ["A05B", "C129"]]}, + ] + + # Generative task: one nested-sequence input feature, no output labels. + self.input_schema = {"visits": "nested_sequence"} + self.output_schema = {} + + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name="test_halo", + ) + + # Small model; embed_dim must be divisible by n_heads. + self.model = HALO( + dataset=self.dataset, + embed_dim=16, + n_heads=2, + n_layers=2, + n_ctx=8, + batch_size=2, + epochs=1, + ) + + def test_model_initialization(self): + """Vocab sizes are derived from the processor; generation is unconditional.""" + self.assertIsInstance(self.model, HALO) + self.assertEqual(self.model.feature_keys, ["visits"]) + self.assertEqual(self.model.label_keys, []) + + proc_vocab = self.dataset.input_processors["visits"].vocab_size() + self.assertEqual(self.model.config.code_vocab_size, proc_vocab) + self.assertEqual(self.model.config.label_vocab_size, 0) + self.assertEqual(self.model.config.total_vocab_size, proc_vocab + 3) + + def test_forward_input_format(self): + """The standard dataloader pads the visit dimension into a 3D tensor.""" + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + self.assertIsInstance(batch["visits"], torch.Tensor) + self.assertEqual(batch["visits"].dim(), 3) # (B, max_visits, max_codes) + + def test_model_forward(self): + """Forward returns a finite scalar loss and a probability tensor.""" + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + ret = self.model(**batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + self.assertEqual(ret["loss"].dim(), 0) + self.assertTrue(torch.isfinite(ret["loss"]).all()) + # y_prob: (B, n_ctx - 1, total_vocab_size) + self.assertEqual(ret["y_prob"].shape[0], 2) + self.assertEqual(ret["y_prob"].shape[1], self.model.config.n_ctx - 1) + self.assertEqual( + ret["y_prob"].shape[2], self.model.config.total_vocab_size + ) + + def test_model_backward(self): + """Backward populates gradients on model parameters.""" + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + + ret = self.model(**batch) + ret["loss"].backward() + + has_gradient = any( + param.requires_grad and param.grad is not None + for param in self.model.parameters() + ) + self.assertTrue(has_gradient, "No parameters have gradients after backward") + + def test_generate(self): + """generate() returns the requested number of decoded synthetic patients.""" + synthetic = self.model.generate(num_samples=4, random_sampling=True) + + self.assertEqual(len(synthetic), 4) + for i, patient in enumerate(synthetic): + self.assertEqual(patient["patient_id"], f"synthetic_{i}") + self.assertIsInstance(patient["visits"], list) + for visit in patient["visits"]: + self.assertIsInstance(visit, list) + for code in visit: + self.assertIsInstance(code, str) + self.assertNotIn(code, ("", "")) + + def _build_model(self, save_dir): + return HALO( + dataset=self.dataset, + embed_dim=16, + n_heads=2, + n_layers=2, + n_ctx=8, + batch_size=2, + epochs=1, + save_dir=save_dir, + ) + + def test_train_and_generate_accept_device_arg(self): + """train_model/generate accept an explicit device arg (CPU always works).""" + with tempfile.TemporaryDirectory() as tmp: + model = self._build_model(tmp) + model.train_model(self.dataset, device="cpu") + self.assertEqual(next(model.parameters()).device.type, "cpu") + + synthetic = model.generate(num_samples=2, device="cpu") + self.assertEqual(len(synthetic), 2) + + @unittest.skipUnless(torch.cuda.is_available(), "CUDA not available") + def test_train_and_generate_on_cuda(self): + """When CUDA is available, the device arg moves training/generation to GPU.""" + with tempfile.TemporaryDirectory() as tmp: + model = self._build_model(tmp) + model.train_model(self.dataset, device="cuda") + self.assertTrue(next(model.parameters()).is_cuda) + + # forward should now run on CUDA without an explicit move. + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + with torch.no_grad(): + ret = model(**batch) + self.assertTrue(ret["y_prob"].is_cuda) + self.assertTrue(torch.isfinite(ret["loss"]).all()) + + synthetic = model.generate(num_samples=2, device="cuda") + self.assertEqual(len(synthetic), 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_medgan.py b/tests/core/test_medgan.py new file mode 100644 index 000000000..f55ec9692 --- /dev/null +++ b/tests/core/test_medgan.py @@ -0,0 +1,176 @@ +import tempfile +import unittest + +import torch + +from pyhealth.datasets import create_sample_dataset +from pyhealth.models import MedGAN + + +class TestMedGAN(unittest.TestCase): + """Test cases for the MedGAN synthetic-EHR generator.""" + + def setUp(self): + """Bag-of-codes generative dataset (no labels) and a tiny model.""" + self.samples = [ + {"patient_id": "patient-0", "visits": ["A05B", "A05C", "A11D", "C129"]}, + {"patient_id": "patient-1", "visits": ["A05B", "A04A", "B035"]}, + {"patient_id": "patient-2", "visits": ["C129", "A11D", "A05C", "A04A"]}, + {"patient_id": "patient-3", "visits": ["B035", "A05B", "C129"]}, + ] + self.input_schema = {"visits": "multi_hot"} + self.output_schema = {} + + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name="test_medgan", + ) + + self.model = MedGAN( + dataset=self.dataset, + latent_dim=8, + hidden_dim=8, + discriminator_hidden_dim=16, + batch_size=2, + ae_epochs=1, + gan_epochs=1, + ) + + def test_model_initialization(self): + """Vocab size is derived from MultiHotProcessor; generation is unconditional.""" + self.assertIsInstance(self.model, MedGAN) + self.assertEqual(self.model.feature_keys, ["visits"]) + self.assertEqual(self.model.label_keys, []) + + proc_vocab = self.dataset.input_processors["visits"].size() + self.assertEqual(self.model.input_dim, proc_vocab) + + def test_components_present(self): + """Autoencoder, generator, and discriminator are all registered submodules.""" + self.assertTrue(hasattr(self.model, "autoencoder")) + self.assertTrue(hasattr(self.model, "generator")) + self.assertTrue(hasattr(self.model, "discriminator")) + + # Discriminator output is a probability (sigmoid). + x = torch.zeros(2, self.model.input_dim) + with torch.no_grad(): + d_out = self.model.discriminator(x) + self.assertEqual(d_out.shape, (2, 1)) + self.assertTrue(((d_out >= 0) & (d_out <= 1)).all()) + + def test_forward_raises(self): + """MedGAN's BaseModel forward intentionally errors out.""" + with self.assertRaises(NotImplementedError): + self.model.forward() + + def test_train_model_runs(self): + """train_model completes a tiny two-phase loop on CPU.""" + with tempfile.TemporaryDirectory() as tmp: + model = MedGAN( + dataset=self.dataset, + latent_dim=8, + hidden_dim=8, + discriminator_hidden_dim=16, + batch_size=2, + ae_epochs=1, + gan_epochs=1, + save_dir=tmp, + ) + model.train_model(self.dataset, device="cpu") + self.assertEqual(next(model.parameters()).device.type, "cpu") + + def test_generate(self): + """generate() returns the requested number of decoded synthetic patients. + + MedGAN is a bag-of-codes model, so each patient gets a single visit + containing the aggregate set of codes; the outer ``visits`` list + wraps that single visit to match HALO's nested format. + """ + synthetic = self.model.generate(num_samples=4, device="cpu") + self.assertEqual(len(synthetic), 4) + for i, patient in enumerate(synthetic): + self.assertEqual(patient["patient_id"], f"synthetic_{i}") + self.assertIsInstance(patient["visits"], list) + # Exactly one aggregate visit per patient. + self.assertEqual(len(patient["visits"]), 1) + visit = patient["visits"][0] + self.assertIsInstance(visit, list) + for code in visit: + self.assertIsInstance(code, str) + self.assertNotIn(code, ("", "")) + + def test_generate_random_sampling(self): + """random_sampling=True still produces well-formed patients.""" + synthetic = self.model.generate( + num_samples=3, random_sampling=True, device="cpu" + ) + self.assertEqual(len(synthetic), 3) + for patient in synthetic: + self.assertIn("patient_id", patient) + self.assertIn("visits", patient) + + def test_save_and_load_roundtrip(self): + """save_model + load_model preserves weights and vocabulary.""" + with tempfile.TemporaryDirectory() as tmp: + path = f"{tmp}/medgan.pt" + self.model.save_model(path) + + # Build a fresh model and overwrite from disk; weights should match. + other = MedGAN( + dataset=self.dataset, + latent_dim=8, + hidden_dim=8, + discriminator_hidden_dim=16, + batch_size=2, + ae_epochs=1, + gan_epochs=1, + save_dir=tmp, + ) + other.load_model(path) + + for p1, p2 in zip( + self.model.generator.parameters(), + other.generator.parameters(), + ): + self.assertTrue(torch.allclose(p1, p2)) + self.assertEqual(other._idx_to_code, self.model._idx_to_code) + + def test_missing_visits_processor_raises(self): + """A dataset without a 'visits' feature should be rejected.""" + # Build a dataset with a different input feature name. + bad = create_sample_dataset( + samples=[ + {"patient_id": "p1", "codes": ["A", "B"]}, + {"patient_id": "p2", "codes": ["B"]}, + ], + input_schema={"codes": "multi_hot"}, + output_schema={}, + ) + with self.assertRaises(ValueError): + MedGAN(bad, latent_dim=8, hidden_dim=8) + + @unittest.skipUnless(torch.cuda.is_available(), "CUDA not available") + def test_train_and_generate_on_cuda(self): + """When CUDA is available, the device arg moves training to GPU.""" + with tempfile.TemporaryDirectory() as tmp: + model = MedGAN( + dataset=self.dataset, + latent_dim=8, + hidden_dim=8, + discriminator_hidden_dim=16, + batch_size=2, + ae_epochs=1, + gan_epochs=1, + save_dir=tmp, + ) + model.train_model(self.dataset, device="cuda") + self.assertTrue(next(model.parameters()).is_cuda) + + synthetic = model.generate(num_samples=2, device="cuda") + self.assertEqual(len(synthetic), 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_promptehr.py b/tests/core/test_promptehr.py new file mode 100644 index 000000000..4323e3313 --- /dev/null +++ b/tests/core/test_promptehr.py @@ -0,0 +1,161 @@ +import tempfile +import unittest + +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import PromptEHR + + +class TestPromptEHR(unittest.TestCase): + """Test cases for the PromptEHR synthetic-EHR generator.""" + + def setUp(self): + """Set up a synthetic generative dataset (no labels) and a tiny model.""" + self.samples = [ + {"patient_id": "patient-0", "visits": [["A05B", "A05C"], ["A11D"], ["C129"]]}, + {"patient_id": "patient-1", "visits": [["A05B"], ["A04A", "B035"]]}, + {"patient_id": "patient-2", "visits": [["C129", "A11D"], ["A05C"], ["A04A"]]}, + {"patient_id": "patient-3", "visits": [["B035"], ["A05B", "C129"]]}, + ] + + # Generative task: one nested-sequence input feature, no output labels. + self.input_schema = {"visits": "nested_sequence"} + self.output_schema = {} + + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name="test_promptehr", + ) + + # Small model; embed_dim must be divisible by n_heads. + self.model = PromptEHR( + dataset=self.dataset, + embed_dim=16, + n_heads=2, + n_layers=2, + prompt_length=4, + max_len=64, + batch_size=2, + epochs=1, + ) + + def test_model_initialization(self): + """Vocab/special-token ids are derived from the processor.""" + self.assertIsInstance(self.model, PromptEHR) + self.assertEqual(self.model.feature_keys, ["visits"]) + self.assertEqual(self.model.label_keys, []) + + proc_vocab = self.dataset.input_processors["visits"].vocab_size() + self.assertEqual(self.model.code_vocab_size, proc_vocab) + self.assertEqual(self.model.bos_id, proc_vocab) + self.assertEqual(self.model.eos_id, proc_vocab + 1) + self.assertEqual(self.model.delim_id, proc_vocab + 2) + self.assertEqual(self.model.mask_id, proc_vocab + 3) + self.assertEqual(self.model.code_prompt_id, proc_vocab + 4) + self.assertEqual(self.model.bart.config.vocab_size, proc_vocab + 5) + # The learnable soft prompt is part of the module parameters. + self.assertEqual( + tuple(self.model.soft_prompt.shape), (4, 16) + ) + + def test_forward_input_format(self): + """The standard dataloader pads the visit dimension into a 3D tensor.""" + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + self.assertIsInstance(batch["visits"], torch.Tensor) + self.assertEqual(batch["visits"].dim(), 3) # (B, max_visits, max_codes) + + def test_model_forward(self): + """Forward returns a finite scalar loss and a probability tensor.""" + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + ret = self.model(**batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + self.assertEqual(ret["loss"].dim(), 0) + self.assertTrue(torch.isfinite(ret["loss"]).all()) + # y_prob: (B, L_dec, vocab_size) + self.assertEqual(ret["y_prob"].shape[0], 2) + self.assertEqual(ret["y_prob"].shape[2], self.model.bart.config.vocab_size) + + def test_model_backward(self): + """Backward populates gradients on model parameters (incl. soft prompt).""" + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + + ret = self.model(**batch) + ret["loss"].backward() + + has_gradient = any( + param.requires_grad and param.grad is not None + for param in self.model.parameters() + ) + self.assertTrue(has_gradient, "No parameters have gradients after backward") + # The soft prompt specifically should receive a gradient. + self.assertIsNotNone(self.model.soft_prompt.grad) + + def test_generate(self): + """generate() returns the requested number of decoded synthetic patients.""" + synthetic = self.model.generate(num_samples=4) + + self.assertEqual(len(synthetic), 4) + for i, patient in enumerate(synthetic): + self.assertEqual(patient["patient_id"], f"synthetic_{i}") + self.assertIsInstance(patient["visits"], list) + for visit in patient["visits"]: + self.assertIsInstance(visit, list) + for code in visit: + self.assertIsInstance(code, str) + self.assertNotIn(code, ("", "")) + + def _build_model(self, save_dir): + return PromptEHR( + dataset=self.dataset, + embed_dim=16, + n_heads=2, + n_layers=2, + prompt_length=4, + max_len=64, + batch_size=2, + epochs=1, + save_dir=save_dir, + ) + + def test_train_and_generate_accept_device_arg(self): + """train_model/generate accept an explicit device arg (CPU always works).""" + with tempfile.TemporaryDirectory() as tmp: + model = self._build_model(tmp) + model.train_model(self.dataset, device="cpu") + self.assertEqual(next(model.parameters()).device.type, "cpu") + + synthetic = model.generate(num_samples=2, device="cpu") + self.assertEqual(len(synthetic), 2) + + @unittest.skipUnless(torch.cuda.is_available(), "CUDA not available") + def test_train_and_generate_on_cuda(self): + """When CUDA is available, the device arg moves training/generation to GPU.""" + with tempfile.TemporaryDirectory() as tmp: + model = self._build_model(tmp) + model.train_model(self.dataset, device="cuda") + self.assertTrue(next(model.parameters()).is_cuda) + + # forward should now run on CUDA without an explicit move. + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + with torch.no_grad(): + ret = model(**batch) + self.assertTrue(ret["y_prob"].is_cuda) + self.assertTrue(torch.isfinite(ret["loss"]).all()) + + synthetic = model.generate(num_samples=2, device="cuda") + self.assertEqual(len(synthetic), 2) + + +if __name__ == "__main__": + unittest.main()