diff --git a/openfl-workspace/torch/histology_s3/src/dataloader.py b/openfl-workspace/torch/histology_s3/src/dataloader.py index d928184fe4..8c925bee18 100644 --- a/openfl-workspace/torch/histology_s3/src/dataloader.py +++ b/openfl-workspace/torch/histology_s3/src/dataloader.py @@ -2,22 +2,20 @@ # SPDX-License-Identifier: Apache-2.0 """You may copy this file as the starting point of your own model.""" - -from collections.abc import Iterable -from logging import getLogger import os import sys +from collections.abc import Iterable +from logging import getLogger - -from openfl.federated import PyTorchDataLoader import numpy as np -from openfl.federated.data.sources.torch.verifiable_map_style_image_folder import VerifiableImageFolder -from openfl.federated.data.sources.data_sources_json_parser import DataSourcesJsonParser -from openfl.utilities.path_check import is_directory_traversal import torch from torch.utils.data import random_split from torchvision.transforms import ToTensor +from openfl.federated import PyTorchDataLoader +from openfl.federated.data.sources.data_sources_json_parser import DataSourcesJsonParser +from openfl.federated.data.sources.torch.verifiable_map_style_image_folder import VerifiableImageFolder +from openfl.utilities.path_check import is_directory_traversal logger = getLogger(__name__) @@ -25,11 +23,11 @@ class PyTorchHistologyVerifiableDataLoader(PyTorchDataLoader): """PyTorch data loader for Histology dataset.""" - def __init__(self, data_path, batch_size, **kwargs): + def __init__(self, data_path=None, batch_size=32, **kwargs): """Instantiate the data object. Args: - data_path: The file path to the data + data_path: The file path to the data. If None, initialize for model creation only. batch_size: The batch size of the data loader **kwargs: Additional arguments, passed to super init and load_mnist_shard @@ -61,17 +59,19 @@ def __init__(self, data_path, batch_size, **kwargs): else: logger.info("The dataset is valid.") - _, num_classes, X_train, y_train, X_valid, y_valid = load_histology_shard( - verifible_dataset_info=verifible_dataset_info, verify_dataset_items=verify_dataset_items, **kwargs) + X_train, y_train, X_valid, y_valid = load_histology_shard( + verifible_dataset_info=verifible_dataset_info, + verify_dataset_items=verify_dataset_items, + feature_shape=self.feature_shape, + num_classes=self.num_classes, + **kwargs + ) self.X_train = X_train self.y_train = y_train self.X_valid = X_valid self.y_valid = y_valid - self.num_classes = num_classes - - def get_feature_shape(self): """Returns the shape of an example feature array. @@ -101,7 +101,6 @@ def get_verifiable_dataset_info(self, data_path): Raises: SystemExit: If `data_path` is invalid or missing `datasources.json`. """ - """Return the verifiable dataset info object for the given data sources.""" if data_path and is_directory_traversal(data_path): logger.error("Data path is out of the openfl workspace scope.") if not os.path.isdir(data_path): @@ -152,7 +151,8 @@ def _load_raw_data(verifiable_dataset_info, verify_dataset_items=False, train_sp n_train = int(train_split_ratio * len(dataset)) n_valid = len(dataset) - n_train ds_train, ds_val = random_split( - dataset, lengths=[n_train, n_valid], generator=torch.manual_seed(0)) + dataset, lengths=[n_train, n_valid], generator=torch.manual_seed(0) + ) # create the shards X_train, y_train = list(zip(*ds_train)) @@ -164,14 +164,16 @@ def _load_raw_data(verifiable_dataset_info, verify_dataset_items=False, train_sp return (X_train, y_train), (X_valid, y_valid) - -def load_histology_shard(verifible_dataset_info, verify_dataset_items, +def load_histology_shard(verifible_dataset_info, verify_dataset_items, feature_shape=None, num_classes=None, categorical=False, channels_last=False, **kwargs): """ Load the Histology dataset. Args: - data_path (str): path to data directory + verifible_dataset_info (VerifiableDatasetInfo): The verifiable dataset info object. + verify_dataset_items (bool): True = verify the dataset items while loading data + feature_shape (list, optional): The shape of input features. + num_classes (int, optional): Number of classes. categorical (bool): True = convert the labels to one-hot encoded vectors (Default = True) channels_last (bool): True = The input images have the channels @@ -179,26 +181,23 @@ def load_histology_shard(verifible_dataset_info, verify_dataset_items, **kwargs: Additional parameters to pass to the function Returns: - list: The input shape - int: The number of classes numpy.ndarray: The training data numpy.ndarray: The training labels numpy.ndarray: The validation data numpy.ndarray: The validation labels """ - img_rows, img_cols = 150, 150 - num_classes = 8 + img_rows, img_cols = feature_shape[1], feature_shape[2] - (X_train, y_train), (X_valid, y_valid) = _load_raw_data(verifible_dataset_info, verify_dataset_items, **kwargs) + (X_train, y_train), (X_valid, y_valid) = _load_raw_data( + verifible_dataset_info, verify_dataset_items, **kwargs + ) if channels_last: X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 3) X_valid = X_valid.reshape(X_valid.shape[0], img_rows, img_cols, 3) - input_shape = (img_rows, img_cols, 3) else: X_train = X_train.reshape(X_train.shape[0], 3, img_rows, img_cols) X_valid = X_valid.reshape(X_valid.shape[0], 3, img_rows, img_cols) - input_shape = (3, img_rows, img_cols) logger.info(f'Histology > X_train Shape : {X_train.shape}') logger.info(f'Histology > y_train Shape : {y_train.shape}') @@ -210,4 +209,4 @@ def load_histology_shard(verifible_dataset_info, verify_dataset_items, y_train = np.eye(num_classes)[y_train] y_valid = np.eye(num_classes)[y_valid] - return input_shape, num_classes, X_train, y_train, X_valid, y_valid + return X_train, y_train, X_valid, y_valid diff --git a/openfl/federated/data/sources/data_sources_json_parser.py b/openfl/federated/data/sources/data_sources_json_parser.py index d86724b6da..d4f0f87c8f 100644 --- a/openfl/federated/data/sources/data_sources_json_parser.py +++ b/openfl/federated/data/sources/data_sources_json_parser.py @@ -15,7 +15,9 @@ class DataSourcesJsonParser: @staticmethod - def parse(json_string: str) -> VerifiableDatasetInfo: + def parse( + json_string: str, label="", metadata="", check_dir_traversal=False + ) -> VerifiableDatasetInfo: """ Parse a JSON string into a dictionary. @@ -31,48 +33,69 @@ def parse(json_string: str) -> VerifiableDatasetInfo: except json.JSONDecodeError as e: raise ValueError(f"Invalid JSON format: {e}") - datasources = DataSourcesJsonParser.process_data_sources(data) + datasources = DataSourcesJsonParser.process_data_sources(data, check_dir_traversal) if not datasources: raise ValueError("No data sources were found.") return VerifiableDatasetInfo( data_sources=datasources, - label="", + label=label, + metadata=metadata, ) @staticmethod - def process_data_sources(data): + def process_data_sources(data, check_dir_traversal=False): """Process and validate data sources.""" - cwd = os.getcwd() datasources = [] + local_datasources = {} for source_name, source_info in data.items(): source_type = source_info.get("type", None) if source_type is None: raise ValueError(f"Missing 'type' key in data source configuration: {source_info}") params = source_info.get("params", {}) - if source_type == "local": - datasources.append( - DataSourcesJsonParser.process_local_source(source_name, params, cwd) - ) + if source_type == "fs": + local_datasources[source_name] = params elif source_type == "s3": datasources.append(DataSourcesJsonParser.process_s3_source(source_name, params)) - elif source_type == "azure_blob": + elif source_type == "ab": datasources.append( DataSourcesJsonParser.process_azure_blob_source(source_name, params) ) + if local_datasources: + DataSourcesJsonParser.process_local_sources( + local_datasources, datasources, check_dir_traversal + ) return [ds for ds in datasources if ds] @staticmethod - def process_local_source(source_name, params, cwd): - """Process a local data source.""" - path = params.get("path", None) - if not path: - raise ValueError(f"Missing 'path' parameter for local data source '{source_name}'") - abs_path = os.path.abspath(path) - rel_path = os.path.relpath(abs_path, cwd) - if rel_path and not is_directory_traversal(rel_path): - return LocalDataSource(source_name, rel_path, base_path=Path(".")) - else: - raise ValueError(f"Invalid path for local data source '{source_name}': {path}.") + def process_local_sources(local_datasources, datasources, check_dir_traversal=False): + """Process and validate local data sources.""" + absolute_paths = {} + for source_name, params in local_datasources.items(): + if "path" not in params: + raise ValueError( + f"Missing required field 'path' for local data source '{source_name}'." + ) + absolute_paths[source_name] = os.path.realpath(params.get("path")) + + # The reason we use common base_dir and source_path relative to that base + # is to simplify path management in containerized environments, such as Docker. + # By using a common base_dir, we can ensure that paths remain consistent + # when mounting volumes, as only the base_dir needs to be adjusted to point + # to the mount path inside the container. + # This way, we only need to adjust the base_dir to point to the mount path. + base_dir = os.path.commonpath(absolute_paths.values()) + for source_name, data_path in absolute_paths.items(): + relative_path = os.path.relpath(data_path, base_dir) + if check_dir_traversal and is_directory_traversal(data_path): + raise ValueError( + f"Invalid path for local data source '{source_name}': {data_path}." + f" Data path is out of the openfl workspace scope." + ) + datasources.append( + LocalDataSource( + name=source_name, source_path=Path(relative_path), base_path=base_dir + ) + ) @staticmethod def process_s3_source(source_name, params): diff --git a/openfl/interface/collaborator.py b/openfl/interface/collaborator.py index ba0a84977c..b4607a4502 100644 --- a/openfl/interface/collaborator.py +++ b/openfl/interface/collaborator.py @@ -231,9 +231,11 @@ def register_data_path(collaborator_name, data_path=None, silent=False): type=ClickPath(exists=True), help=( "Path to directory containing sources.json file defining the data sources of the dataset. " - "This file should contain a JSON object with the data sources to be registered. For 'local'" - " type, 'params' must include: 'path'. For 's3' type, 'params' must include: 'uri', " - "'access_key_env_name', 'secret_key_env_name', 'secret_name', and optionally 'endpoint'." + "This file should contain a JSON object with the data sources to be registered. For local " + "data source, 'type' is 'fs', and 'params' must include: 'path'. For 's3' type, 'params' " + "must include: 'uri', 'access_key_env_name', 'secret_key_env_name', 'secret_name', and " + "optionally 'endpoint'. For azure_blob, 'type' is 'ab', and 'params' must include: " + "'connection_string', 'container_name', and optionally 'folder_prefix'." ), ) def calchash(data_path): @@ -258,7 +260,7 @@ def calchash(data_path): sys.exit(1) with open(datasources_json_path, "r", encoding="utf-8") as file: data = file.read() - vds = DataSourcesJsonParser.parse(data) + vds = DataSourcesJsonParser.parse(data, check_dir_traversal=True) root_hash = vds.create_dataset_hash() hash_file_path = os.path.join(data_path, "hash.txt") with open(hash_file_path, "w", encoding="utf-8") as hash_file: