From 18ac2ff8919d60108644f2d8538127f3e654218b Mon Sep 17 00:00:00 2001
From: William Lindskog-Munzing
Date: Mon, 15 Jun 2026 20:33:17 -0400
Subject: [PATCH] fix(examples): avoid torch dataset formatter

---
Reviewer notes (free text between the three-dash line and the first
"diff --git" line is not part of the commit message or of the patch):

  * custom-mods and quickstart-pytorch, load_centralized_dataset():
    the `.with_format("torch")` call in front of
    `.with_transform(apply_transforms)` is dropped; the transform is now
    applied directly to the loaded test split.
  * whisper_example/dataset.py gains two helpers and an `import torch`:
      _apply_torch_transform(batch)  converts batch["data"] to a float32
                                     tensor and batch["targets"] to a long
                                     tensor via torch.as_tensor, each only
                                     when the key is present in the batch.
      with_torch_transform(dataset)  returns
                                     dataset.with_transform(
                                         _apply_torch_transform,
                                         columns=["data", "targets"]).
  * centralized.py, client_app.py and server_app.py replace every
    `.with_format("torch", columns=["data", "targets"])` call with
    `with_torch_transform(...)`.
  * client_app.py, train(): the length check and
    construct_balanced_sampler() now receive the untransformed `partition`;
    `trainset` is created from `partition` afterwards.
  * NOTE(review): this copy of the patch was re-wrapped from a
    whitespace-collapsed form. Blank context lines and leading indentation
    were reconstructed from the hunk line counts and are an assumption;
    confirm with `git apply --check` against the target tree (nesting depth
    of global_evaluate() in server_app.py in particular) before applying.

 examples/custom-mods/custom_mods/task.py        |  2 +-
 .../quickstart-pytorch/pytorchexample/task.py   |  2 +-
 .../whisper-federated-finetuning/centralized.py | 12 ++++++++----
 .../whisper_example/client_app.py               |  8 ++++----
 .../whisper_example/dataset.py                  | 15 +++++++++++++++
 .../whisper_example/server_app.py               |  4 ++--
 6 files changed, 31 insertions(+), 12 deletions(-)

diff --git a/examples/custom-mods/custom_mods/task.py b/examples/custom-mods/custom_mods/task.py
index 5184de6f488b..7f50e862af5b 100644
--- a/examples/custom-mods/custom_mods/task.py
+++ b/examples/custom-mods/custom_mods/task.py
@@ -68,7 +68,7 @@ def load_centralized_dataset():
     """Load test set and return dataloader."""
     # Load entire test set
     test_dataset = load_dataset("uoft-cs/cifar10", split="test")
-    dataset = test_dataset.with_format("torch").with_transform(apply_transforms)
+    dataset = test_dataset.with_transform(apply_transforms)
     return DataLoader(dataset, batch_size=128)
 
 
diff --git a/examples/quickstart-pytorch/pytorchexample/task.py b/examples/quickstart-pytorch/pytorchexample/task.py
index f701ceb889c5..01cc561c7c43 100644
--- a/examples/quickstart-pytorch/pytorchexample/task.py
+++ b/examples/quickstart-pytorch/pytorchexample/task.py
@@ -68,7 +68,7 @@ def load_centralized_dataset():
     """Load test set and return dataloader."""
     # Load entire test set
     test_dataset = load_dataset("uoft-cs/cifar10", split="test")
-    dataset = test_dataset.with_format("torch").with_transform(apply_transforms)
+    dataset = test_dataset.with_transform(apply_transforms)
     return DataLoader(dataset, batch_size=128)
 
 
diff --git a/examples/whisper-federated-finetuning/centralized.py b/examples/whisper-federated-finetuning/centralized.py
index 046ee52c7c9c..4cd4b135564d 100644
--- a/examples/whisper-federated-finetuning/centralized.py
+++ b/examples/whisper-federated-finetuning/centralized.py
@@ -6,7 +6,11 @@
 from torch.utils.data import DataLoader
 from transformers import WhisperProcessor
 
-from whisper_example.dataset import get_encoding_fn, prepare_silences_dataset
+from whisper_example.dataset import (
+    get_encoding_fn,
+    prepare_silences_dataset,
+    with_torch_transform,
+)
 from whisper_example.model import (
     construct_balanced_sampler,
     eval_model,
@@ -80,13 +84,13 @@ def main():
     sampler = construct_balanced_sampler(full_train_dataset)
 
     # Prepare dataloaders
-    train_dataset = full_train_dataset.with_format("torch", columns=["data", "targets"])
+    train_dataset = with_torch_transform(full_train_dataset)
     train_loader = DataLoader(
         train_dataset, batch_size=64, shuffle=False, num_workers=4, sampler=sampler
     )
-    val_encoded = val_encoded.with_format("torch", columns=["data", "targets"])
+    val_encoded = with_torch_transform(val_encoded)
     val_loader = DataLoader(val_encoded, batch_size=64, num_workers=4)
-    test_dataset = test_encoded.with_format("torch", columns=["data", "targets"])
+    test_dataset = with_torch_transform(test_encoded)
     test_loader = DataLoader(test_dataset, batch_size=64, num_workers=4)
 
     # Model to cuda, set criterion, classification layer to train and optimiser
diff --git a/examples/whisper-federated-finetuning/whisper_example/client_app.py b/examples/whisper-federated-finetuning/whisper_example/client_app.py
index 0f3d392dc8a1..5cadcdd19891 100644
--- a/examples/whisper-federated-finetuning/whisper_example/client_app.py
+++ b/examples/whisper-federated-finetuning/whisper_example/client_app.py
@@ -15,7 +15,7 @@
 from flwr.clientapp import ClientApp
 from torch.utils.data import DataLoader
 
-from whisper_example.dataset import load_data
+from whisper_example.dataset import load_data, with_torch_transform
 from whisper_example.model import construct_balanced_sampler, get_model, train_one_epoch
 
 torch.set_float32_matmul_precision(
@@ -58,13 +58,13 @@ def train(msg: Message, context: Context):
         partition_id=partition_id,
         remove_cols=context.run_config["remove-cols"],
     )
-    trainset = partition.with_format("torch", columns=["data", "targets"])
     torch.set_num_threads(og_threads)
 
     # construct sampler in order to have balanced batches
     sampler = None
-    if len(trainset) > batch_size:
-        sampler = construct_balanced_sampler(trainset)
+    if len(partition) > batch_size:
+        sampler = construct_balanced_sampler(partition)
+    trainset = with_torch_transform(partition)
 
     # Construct dataloader
     train_loader = DataLoader(
diff --git a/examples/whisper-federated-finetuning/whisper_example/dataset.py b/examples/whisper-federated-finetuning/whisper_example/dataset.py
index f2615ac53fab..77ed57944ed1 100644
--- a/examples/whisper-federated-finetuning/whisper_example/dataset.py
+++ b/examples/whisper-federated-finetuning/whisper_example/dataset.py
@@ -2,6 +2,7 @@
 
 import random
 
+import torch
 from datasets import Dataset, concatenate_datasets, load_from_disk
 from flwr_datasets import FederatedDataset
 from flwr_datasets.partitioner import GroupedNaturalIdPartitioner
@@ -53,6 +54,20 @@ def load_data_from_disk(data_path):
     return load_from_disk(data_path)
 
 
+def _apply_torch_transform(batch):
+    """Convert encoded columns to torch tensors."""
+    if "data" in batch:
+        batch["data"] = torch.as_tensor(batch["data"], dtype=torch.float32)
+    if "targets" in batch:
+        batch["targets"] = torch.as_tensor(batch["targets"], dtype=torch.long)
+    return batch
+
+
+def with_torch_transform(dataset: Dataset) -> Dataset:
+    """Return a dataset that lazily converts encoded columns to torch tensors."""
+    return dataset.with_transform(_apply_torch_transform, columns=["data", "targets"])
+
+
 def get_encoding_fn(processor):
     """Return a function to use to pre-process/encode the SpeechCommands dataset.
 
diff --git a/examples/whisper-federated-finetuning/whisper_example/server_app.py b/examples/whisper-federated-finetuning/whisper_example/server_app.py
index 22679b4cdf59..011675fdec6c 100644
--- a/examples/whisper-federated-finetuning/whisper_example/server_app.py
+++ b/examples/whisper-federated-finetuning/whisper_example/server_app.py
@@ -12,7 +12,7 @@
 from torch.utils.data import DataLoader
 from transformers import WhisperProcessor
 
-from whisper_example.dataset import get_encoding_fn
+from whisper_example.dataset import get_encoding_fn, with_torch_transform
 from whisper_example.model import eval_model, get_model
 
 # Create ServerApp
@@ -97,7 +97,7 @@ def global_evaluate(server_round: int, arrays: ArrayRecord) -> MetricRecord:
     encoded = val_set.map(encoding_fn, num_proc=4, remove_columns=remove_cols)
     torch.set_num_threads(og_threads)
 
-    val_encoded = encoded.with_format("torch", columns=["data", "targets"])
+    val_encoded = with_torch_transform(encoded)
     val_loader = DataLoader(val_encoded, batch_size=64, num_workers=4)
 
     # Run global evaluation