From 0081add474d1210759f3f1214340623375ccfb74 Mon Sep 17 00:00:00 2001
From: Mahdi Ben Jelloul
Date: Fri, 27 Feb 2026 22:28:54 +0100
Subject: [PATCH 01/18] refactor: io/hdf, processing/harmonization, core
 modules, logging, migration doc (v6.4.0)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- io/hdf: extract write_table_to_hdf5 to io.hdf, writers re-exports
- processing/harmonization: harmonize_data_frame_columns, used in Survey.get_values
- core: add core.table, core.survey, core.dataset; tables/surveys/survey_collections re-export
- logging: add to configuration.models, google_colab, statshelpers; fix typo in core.table
- docs: MIGRATION_IMPORTS.md (import mapping when re-exports removed), REFACTORING_PLAN §3.4
- CHANGELOG + bump to 6.4.0

Made-with: Cursor
---
 CHANGELOG.md                                  |   9 +
 docs/MIGRATION_IMPORTS.md                     | 154 +++++++
 docs/REFACTORING_PLAN.md                      |   4 +-
 .../configuration/models.py                   |   5 +
 openfisca_survey_manager/core/__init__.py     |  10 +-
 openfisca_survey_manager/core/dataset.py      | 152 +++++++
 openfisca_survey_manager/core/survey.py       | 314 ++++++++++++++
 openfisca_survey_manager/core/table.py        | 313 ++++++++++++++
 openfisca_survey_manager/google_colab.py      |   7 +-
 openfisca_survey_manager/io/hdf.py            |  45 ++
 openfisca_survey_manager/io/writers.py        |  38 +-
 .../processing/__init__.py                    |   9 +-
 .../processing/harmonization.py               |  34 ++
 openfisca_survey_manager/statshelpers.py      |   4 +
 .../survey_collections.py                     | 154 +------
 openfisca_survey_manager/surveys.py           | 390 +-----------------
 openfisca_survey_manager/tables.py            | 367 +---------------
 pyproject.toml                                |   2 +-
 18 files changed, 1072 insertions(+), 939 deletions(-)
 create mode 100644 docs/MIGRATION_IMPORTS.md
 create mode 100644 openfisca_survey_manager/core/dataset.py
 create mode 100644 openfisca_survey_manager/core/survey.py
 create mode 100644 openfisca_survey_manager/core/table.py
 create mode 100644 openfisca_survey_manager/io/hdf.py
 create mode 100644 openfisca_survey_manager/processing/harmonization.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8bf85a4b..14371e39 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,14 @@
 # Changelog
 
+# 6.4.0
+
+* Refactor (no breaking API changes)
+  - **io/hdf**: Extract HDF5 write logic into `io.hdf` (`write_table_to_hdf5`); `io.writers` re-exports it for compatibility
+  - **processing/harmonization**: Add `harmonize_data_frame_columns` (lowercase, rename ident); used in `Survey.get_values`; exported from `processing`
+  - **core**: Add `core.table`, `core.survey`, `core.dataset` (Table, Survey, NoMoreDataError, SurveyCollection); root `tables.py`, `surveys.py`, `survey_collections.py` re-export for compatibility
+  - **Logging**: Extend to all modules — add logger to `configuration.models`, `google_colab`, `statshelpers`; fix typo "folloging" → "following" in `core.table`
+  - **Docs**: Add `docs/MIGRATION_IMPORTS.md` (import mapping and steps for when the re-exports are removed, with a breaking-change warning); update `REFACTORING_PLAN.md` (§3.4 Logging done)
+
 # 6.3.1
 
 * Technical changes
diff --git a/docs/MIGRATION_IMPORTS.md b/docs/MIGRATION_IMPORTS.md
new file mode 100644
index 00000000..d7202ee5
--- /dev/null
+++ b/docs/MIGRATION_IMPORTS.md
@@ -0,0 +1,154 @@
+# Import migration after the re-exports are removed
+
+This document describes the changes to make **when the re-exports** (compatibility files) at the package root are removed: update all imports to the new paths, then delete the legacy modules.
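+
+As a quick illustration (paths taken from the mapping in section 1 below), both of the following forms currently resolve to the same class through the compatibility re-export; only the canonical one will keep working after the removal:
+
+```python
+# Legacy path, kept working today by the tables.py compatibility re-export
+from openfisca_survey_manager.tables import Table
+
+# Canonical path, the only one that survives the removal of the re-exports
+from openfisca_survey_manager.core.table import Table
+```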
+
+**Reference**: `docs/REFACTORING_PLAN.md`.
+
+---
+
+## Warning
+
+Removing the re-exports is a **breaking change**: any code (internal or external) that imports from the old paths (`config`, `paths`, `tables`, `surveys`, `survey_collections`, `read_sas`, `read_spss`, `read_dbf`, `calibration`, `calmar`, `utils`) will see its imports **fail** (`ModuleNotFoundError`). All imports must be migrated **before** deleting the files listed in section 3, and the change must be documented in the CHANGELOG for dependent projects (e.g. openfisca-france-data).
+
+---
+
+## 1. Old → new import mapping
+
+| Old import (to remove) | New import (to use) |
+|------------------------|---------------------|
+| `from openfisca_survey_manager.config import Config` | `from openfisca_survey_manager.configuration.models import Config` |
+| `from openfisca_survey_manager.paths import ...` | `from openfisca_survey_manager.configuration.paths import ...` |
+| `from openfisca_survey_manager.tables import Table` | `from openfisca_survey_manager.core.table import Table` |
+| `from openfisca_survey_manager.surveys import Survey` | `from openfisca_survey_manager.core.survey import Survey` |
+| `from openfisca_survey_manager.surveys import NoMoreDataError` | `from openfisca_survey_manager.core.survey import NoMoreDataError` |
+| `from openfisca_survey_manager.survey_collections import SurveyCollection` | `from openfisca_survey_manager.core.dataset import SurveyCollection` |
+| `from openfisca_survey_manager.read_sas import read_sas` | `from openfisca_survey_manager.io.readers import read_sas` |
+| `from openfisca_survey_manager.read_spss import read_spss` | `from openfisca_survey_manager.io.readers import read_spss` |
+| `from openfisca_survey_manager.read_dbf import read_dbf` | `from openfisca_survey_manager.io.readers import read_dbf` |
+| `from openfisca_survey_manager.calibration import Calibration` | `from openfisca_survey_manager.processing.weights import Calibration` |
+| `from openfisca_survey_manager.calmar import calmar` | `from openfisca_survey_manager.processing.weights import calmar` |
+| `from openfisca_survey_manager.calmar import check_calmar` | `from openfisca_survey_manager.processing.weights import check_calmar` |
+| `from openfisca_survey_manager.utils import do_nothing, load_table, ...` | See section 2 (utils) |
+
+**Symbols exported by `paths`** (same names in `configuration.paths`):
+`config_ini`, `default_config_files_directory`, `is_in_ci`, `openfisca_survey_manager_location`, `private_run_with_data`, `test_config_files_directory`.
+
+**Symbols exported by `utils`**:
+- From `common.misc`: `asof`, `do_nothing`, `inflate_parameter_leaf`, `inflate_parameters`, `parameters_asof`, `stata_files_to_data_frames`, `variables_asof`.
+- Defined in `utils.py`: `load_table` (to be moved to a suitable module, e.g. `core` or `io`, before `utils.py` is deleted).
+
+---
+
+## 2. Files to update when the re-exports are removed
+
+Before (or at the same time as) deleting the files listed in section 3, update the imports in the following files; a typical edit is sketched in the example below.
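+
+For instance (paths taken from the tables below; the `simulations.py` case comes from section 2.5, and `load_table`'s final home is still to be decided):
+
+```python
+# Before, in simulations.py:
+from openfisca_survey_manager.utils import do_nothing
+
+# After: do_nothing now lives in common.misc
+from openfisca_survey_manager.common.misc import do_nothing
+```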
+
+### 2.1 Imports from `config`, `paths`
+
+| File | Replace |
+|------|---------|
+| `input_dataframe_generator.py` | `paths` → `configuration.paths` |
+| `scripts/build_collection.py` | `paths` → `configuration.paths` |
+| `temporary.py` | `paths` → `configuration.paths` |
+| `google_colab.py` | `paths` → `configuration.paths` |
+| `coicop.py` | `paths` → `configuration.paths` |
+| `matching.py` | `paths` → `configuration.paths` |
+| `tests/test_read_sas.py` | `paths` → `configuration.paths`; `read_sas` → `io.readers` |
+| `tests/test_quantile.py` | `paths` → `configuration.paths` |
+| `tests/test_scenario.py` | `paths` → `configuration.paths` |
+
+### 2.2 Imports from `survey_collections`, `surveys`, `tables`
+
+| File | Replace |
+|------|---------|
+| `input_dataframe_generator.py` | `survey_collections`, `surveys` → `core.dataset`, `core.survey` |
+| `simulations.py` | `survey_collections` → `core.dataset`; `utils` → `common.misc` + the module hosting `load_table` |
+| `utils.py` | `survey_collections` → `core.dataset` (for `load_table`) |
+| `scripts/build_collection.py` | `survey_collections`, `surveys` → `core.dataset`, `core.survey` |
+| `scenarios/abstract_scenario.py` | `calibration`, `surveys` → `processing.weights`, `core.survey` |
+| `tests/test_surveys.py` | `survey_collections`, `surveys` → `core.dataset`, `core.survey` |
+| `tests/test_coverage_boost.py` | `survey_collections`, `surveys`, `utils` → same as above |
+| `tests/test_add_survey_to_collection.py` | `survey_collections` → `core.dataset` |
+| `tests/test_parquet.py` | `survey_collections` → `core.dataset`; `surveys` (NoMoreDataError) → `core.survey` |
+
+### 2.3 Imports from `read_sas`, `read_spss`, `read_dbf`
+
+| File | Replace |
+|------|---------|
+| `core/table.py` | `from openfisca_survey_manager import read_sas` → `from openfisca_survey_manager.io.readers import read_sas`; `read_sas.read_sas` → `read_sas` in `reader_by_source_format`. Then `from openfisca_survey_manager.read_spss import read_spss` → `from openfisca_survey_manager.io.readers import read_spss` (in the try/except). |
+| `tests/test_read_sas.py` | `from ...paths import ...` → `configuration.paths`; `from ...read_sas import read_sas` → `from ...io.readers import read_sas` |
+
+### 2.4 Imports from `calibration`, `calmar`
+
+| File | Replace |
+|------|---------|
+| `scenarios/abstract_scenario.py` | `calibration` → `processing.weights` |
+| `tests/test_calibration.py` | `calibration` → `processing.weights` |
+| `tests/test_calmar.py` | `calmar` → `processing.weights` |
+
+### 2.5 Imports from `utils`
+
+| File | Replace |
+|------|---------|
+| `simulations.py` | `utils.do_nothing`, `utils.load_table` → `common.misc.do_nothing` + the module hosting `load_table` |
+| `tests/test_coverage_boost.py` | `utils.do_nothing` → `common.misc.do_nothing` |
+| `tests/test_legislation_inflator.py` | `utils.inflate_parameters`, `parameters_asof` → `common.misc` |
+| `tests/test_tax_benefit_system_asof.py` | `utils.parameters_asof`, `variables_asof` → `common.misc` |
+
+**Note**: `load_table` depends on `SurveyCollection`; it must either live in a module that imports `core.dataset`, or be moved (e.g. to `core.dataset` or an `io.loaders` module) before `utils.py` is deleted.
+
+---
+
+## 3. Files to delete (re-exports)
+
+Once all imports have been updated following sections 1 and 2, the following files can be deleted (they contain nothing but re-exports):
+
+- `config.py`
+- `paths.py`
+- `tables.py`
+- `surveys.py`
+- `survey_collections.py`
+- `read_sas.py`
+- `read_spss.py`
+- `read_dbf.py`
+- `calibration.py`
+- `calmar.py`
+- `utils.py` (after moving `load_table` and updating the imports listed in 2.5)
+
+---
+
+## 4. Modules without a re-export (canonical imports)
+
+These modules have no re-export file at the package root; internal code already uses them. For external code or documentation, the canonical imports are:
+
+| Symbol | Canonical import |
+|--------|------------------|
+| `harmonize_data_frame_columns` | `from openfisca_survey_manager.processing.harmonization import harmonize_data_frame_columns` (or `from openfisca_survey_manager.processing import harmonize_data_frame_columns`) |
+| `write_table_to_hdf5` | `from openfisca_survey_manager.io.hdf import write_table_to_hdf5` (or `from openfisca_survey_manager.io.writers import write_table_to_hdf5`) |
+| `write_table_to_parquet` | `from openfisca_survey_manager.io.writers import write_table_to_parquet` |
+
+---
+
+## 5. Root package `openfisca_survey_manager`
+
+Today the package's `__init__.py` only exposes the exceptions. If external code does, for example, `from openfisca_survey_manager import read_sas`, it relies on the `read_sas.py` submodule. **After the re-exports are removed**, these import paths will no longer be valid (they will fail at import time); migrate them to `from openfisca_survey_manager.io.readers import read_sas` (see section 1).
+
+To do before or after the migration: check, in this repository and in dependent projects (openfisca-france-data, etc.), for imports from the package root or from the legacy modules listed in section 3.
+
+---
+
+## 6. Recommended migration order
+
+1. **Move `load_table`** to its final module (e.g. `core.dataset` or `io.loaders`) and update the call sites (section 2.5).
+2. **Update all internal imports** (section 2) to the new paths, file by file.
+3. **Run the test suite**: `pytest`; fix any leftovers until there are 0 failures.
+4. **Delete the re-export files** listed in section 3.
+5. **Check external usages** (section 5) and document the changes in the CHANGELOG (breaking changes).
+
+---
+
+## 7. Optional follow-ups
+
+- Rename the `common/` directory to `utils/` once `utils.py` is deleted (as planned in the refactoring plan).
+- Rename `configuration/` to `config/` if a shorter name is wanted (consistent with the plan).
+- These renames will require another round of import updates (configuration → config, common → utils).
diff --git a/docs/REFACTORING_PLAN.md b/docs/REFACTORING_PLAN.md
index bfd2494d..d5d8d63e 100644
--- a/docs/REFACTORING_PLAN.md
+++ b/docs/REFACTORING_PLAN.md
@@ -99,8 +99,8 @@ Today these layers are intertwined (e.g. reading + cleaning in `tabl
 
 ### 3.4 Logging
 
-- Replace `print()` calls with structured `logging` (already started in matching, calmar).
-- Extend it to all modules (readers, writers, calibration, etc.).
+- **Done**: `print()` calls replaced with structured `logging` (matching, calmar, scenarios, scripts/build_collection, simulations, readers, writers, calibration, core, processing, etc.).
+- **Done**: logging extended to all business modules (configuration/models, google_colab, statshelpers, and the other affected modules).
 
 ### 3.5 Centralized error handling
 
diff --git a/openfisca_survey_manager/configuration/models.py b/openfisca_survey_manager/configuration/models.py
index 571a6ac8..9ab3841a 100644
--- a/openfisca_survey_manager/configuration/models.py
+++ b/openfisca_survey_manager/configuration/models.py
@@ -1,8 +1,11 @@
 """Configuration model (Config class from config.ini)."""
 
 import configparser
+import logging
 from pathlib import Path
 
+log = logging.getLogger(__name__)
+
 
 class Config(configparser.ConfigParser):
     """Parser for config.ini; used by SurveyCollection and build scripts."""
@@ -16,6 +19,7 @@ def __init__(self, config_files_directory=None):
         assert config_ini.exists(), f"{config_ini} is not a valid path"
         self.config_ini = config_ini
         self.read([config_ini])
+        log.debug("Loaded config from %s", config_ini)
 
     def save(self):
         assert self.config_ini, "configuration file path is not defined"
@@ -23,3 +27,4 @@ def save(self):
         config_file = self.config_ini.open("w")
         self.write(config_file)
         config_file.close()
+        log.debug("Saved config to %s", self.config_ini)
diff --git a/openfisca_survey_manager/core/__init__.py b/openfisca_survey_manager/core/__init__.py
index c7be67cc..af07d7cf 100644
--- a/openfisca_survey_manager/core/__init__.py
+++ b/openfisca_survey_manager/core/__init__.py
@@ -1,2 +1,8 @@
-# Target: Survey (surveys.py), SurveyCollection, dataset orchestration.
-# See docs/REFACTORING_PLAN.md for migration steps.
+# Survey, Table, SurveyCollection. Legacy modules re-export for compatibility.
+# See docs/REFACTORING_PLAN.md.
+
+from openfisca_survey_manager.core.dataset import SurveyCollection
+from openfisca_survey_manager.core.survey import NoMoreDataError, Survey
+from openfisca_survey_manager.core.table import Table
+
+__all__ = ["NoMoreDataError", "Survey", "SurveyCollection", "Table"]
diff --git a/openfisca_survey_manager/core/dataset.py b/openfisca_survey_manager/core/dataset.py
new file mode 100644
index 00000000..f56072f0
--- /dev/null
+++ b/openfisca_survey_manager/core/dataset.py
@@ -0,0 +1,152 @@
+"""SurveyCollection: collection of surveys (dataset orchestration)."""
+
+import codecs
+import collections
+import configparser
+import json
+import logging
+from pathlib import Path
+
+from openfisca_survey_manager.configuration.models import Config
+from openfisca_survey_manager.configuration.paths import default_config_files_directory
+from openfisca_survey_manager.core.survey import Survey
+from openfisca_survey_manager.exceptions import SurveyConfigError
+
+log = logging.getLogger(__name__)
+
+
+class SurveyCollection:
+    """A collection of Surveys."""
+
+    def __init__(
+        self, config_files_directory=default_config_files_directory, label=None, name=None, json_file_path=None
+    ):
+        self.name = name
+        self.label = label
+        self.json_file_path = json_file_path
+        self.surveys = []
+        log.debug(f"Initializing SurveyCollection from config file found in {config_files_directory}...")
+        config = Config(config_files_directory=config_files_directory)
+        if label is not None:
+            self.label = label
+        if name is not None:
+            self.name = name
+        if json_file_path is not None:
+            self.json_file_path = json_file_path
+            if "collections" not in config.sections():
+                config["collections"] = {}
+            config.set("collections", self.name, str(self.json_file_path))
+            config.save()
+        elif config is not None:
+            if config.has_option("collections", self.name):
+                self.json_file_path = config.get("collections", self.name)
+            elif config.get("collections", "collections_directory") is not None:
+                self.json_file_path = str(Path(config.get("collections", "collections_directory")) / (name + ".json"))
+
+        self.config = config
+
+    def __repr__(self):
+        header = f"""{self.name}
+Survey collection of {self.label}
+Contains the following surveys :
+"""
+        surveys = [f"       {survey.name} : {survey.label} \n" for survey in self.surveys]
+        return header + "".join(surveys)
+
+    def dump(self, config_files_directory=None, json_file_path=None):
+        if self.config is not None:
+            config = self.config
+        else:
+            if config_files_directory is None:
+                config_files_directory = default_config_files_directory
+            self.config = Config(config_files_directory=config_files_directory)
+            config = self.config
+
+        if json_file_path is None:
+            assert self.json_file_path is not None, "A json_file_path should be provided"
+        else:
+            self.json_file_path = json_file_path
+
+        config.set("collections", self.name, str(self.json_file_path))
+        config.save()
+        with codecs.open(str(self.json_file_path), "w", encoding="utf-8") as _file:
+            json.dump(self.to_json(), _file, ensure_ascii=False, indent=2)
+
+    def fill_store(
+        self,
+        source_format=None,
+        surveys=None,
+        tables=None,
+        overwrite=False,
+        keep_original_parquet_file=False,
+        encoding=None,
+        store_format="hdf5",
+        categorical_strategy="unique_labels",
+    ):
+        if surveys is None:
+            surveys = self.surveys
+        for survey in surveys:
+            survey.fill_store(
+                source_format=source_format,
+                tables=tables,
+                overwrite=overwrite,
+                keep_original_parquet_file=keep_original_parquet_file,
+                encoding=encoding,
+                store_format=store_format,
+                categorical_strategy=categorical_strategy,
+            )
+        self.dump()
+
+    def get_survey(self, survey_name):
+        available_surveys_names = [survey.name for survey in self.surveys]
+        assert survey_name in available_surveys_names, (
+            f"Survey {survey_name} cannot be found for survey collection {self.name}.\n"
+            f"Available surveys are: {available_surveys_names}"
+        )
+        return [survey for survey in self.surveys if survey.name == survey_name].pop()
+
+    @classmethod
+    def load(cls, json_file_path=None, collection=None, config_files_directory=default_config_files_directory):
+        assert Path(config_files_directory).exists()
+        config = Config(config_files_directory=config_files_directory)
+        if json_file_path is None:
+            assert collection is not None, "A collection is needed"
+            try:
+                json_file_path = config.get("collections", collection)
+            except (configparser.NoOptionError, configparser.NoSectionError) as error:
+                msg = f"Looking for config file in {config_files_directory}"
+                log.debug(msg)
+                log.error(error)
+                raise error
+            except Exception as error:
+                msg = f"Looking for config file in {config_files_directory}"
+                log.debug(msg)
+                log.error(error)
+                raise SurveyConfigError(msg) from error
+
+        with Path(json_file_path).open("r") as _file:
+            self_json = json.load(_file)
+            name = self_json["name"]
+
+        self = cls(config_files_directory=config_files_directory, name=name)
+        self.config = config
+        with Path(json_file_path).open("r") as _file:
+            self_json = json.load(_file)
+        self.json_file_path = json_file_path
+        self.label = self_json.get("label")
+        self.name = self_json.get("name")
+
+        surveys = self_json["surveys"]
+        for survey_name, survey_json in surveys.items():
+            survey = Survey(name=survey_name)
+            self.surveys.append(survey.create_from_json(survey_json))
+        return self
+
+    def to_json(self):
+        self_json = collections.OrderedDict(())
+        self_json["name"] = self.name
+        self_json["surveys"] = collections.OrderedDict(())
+        for survey in self.surveys:
+            self_json["surveys"][survey.name] = survey.to_json()
+        return self_json
diff --git a/openfisca_survey_manager/core/survey.py b/openfisca_survey_manager/core/survey.py
new file mode 100644
index 00000000..8cba10b6
--- /dev/null
+++ b/openfisca_survey_manager/core/survey.py
@@ -0,0 +1,314 @@
+"""Survey: describes survey data and tables."""
+
+import collections
+import logging
+import re
+from pathlib import Path
+
+import pandas
+import pyarrow as pa
+import pyarrow.parquet as pq
+import yaml
+
+from openfisca_survey_manager.core.table import Table
+from openfisca_survey_manager.exceptions import SurveyIOError, SurveyManagerError
+from openfisca_survey_manager.processing.harmonization import harmonize_data_frame_columns
+
+log = logging.getLogger(__name__)
+
+source_format_by_extension = {
+    "csv": "csv",
+    "sas7bdat": "sas",
+    "dta": "stata",
+    "Rdata": "Rdata",
+    "spss": "sav",
+    "parquet": "parquet",
+}
+
+admissible_source_formats = list(source_format_by_extension.values())
+
+
+class NoMoreDataError(Exception):
+    """Raised when the user asks for more data than available in file."""
+
+    pass
+
+
+class Survey:
+    """An object to describe survey data."""
+
+    hdf5_file_path = None
+    parquet_file_path = None
+    label = None
+    name = None
+    survey_collection = None
+
+    def __init__(
+        self, name=None, label=None, hdf5_file_path=None, parquet_file_path=None, survey_collection=None, **kwargs
+    ):
+        assert name is not None, "A survey should have a name"
+        self.name = name
+        self.tables = collections.OrderedDict()
+        self.informations = {}
+        self.tables_index = {}
+
+        if label is not None:
+            self.label = label
+
+        if hdf5_file_path is not None:
+            self.hdf5_file_path = hdf5_file_path
+
+        if parquet_file_path is not None:
+            self.parquet_file_path = parquet_file_path
+
+        if survey_collection is not None:
+            self.survey_collection = survey_collection
+
+        self.informations = kwargs
+
+    def __repr__(self):
+        header = f"""{self.name} : survey data {self.label}
+Contains the following tables : \n"""
+        tables = yaml.safe_dump(list(self.tables.keys()), default_flow_style=False)
+        informations = yaml.safe_dump(self.informations, default_flow_style=False)
+        return header + tables + informations
+
+    @classmethod
+    def create_from_json(cls, survey_json):
+        self = cls(
+            name=survey_json.get("name"),
+            label=survey_json.get("label"),
+            hdf5_file_path=survey_json.get("hdf5_file_path"),
+            parquet_file_path=survey_json.get("parquet_file_path"),
+            **survey_json.get("informations", {}),
+        )
+        self.tables = survey_json.get("tables")
+        return self
+
+    def dump(self):
+        assert self.survey_collection is not None
+        self.survey_collection.dump()
+
+    def fill_store(
+        self,
+        source_format=None,
+        tables=None,
+        overwrite=True,
+        keep_original_parquet_file=False,
+        encoding=None,
+        store_format="hdf5",
+        categorical_strategy="unique_labels",
+    ):
+        assert self.survey_collection is not None
+        assert isinstance(overwrite, (bool, list))
+        survey = self
+        config = survey.survey_collection.config
+        directory_path = config.get("data", "output_directory")
+        if not Path(directory_path).is_dir():
+            log.warning(
+                f"{directory_path}, which should be the store data directory, does not exist: creating it"
+            )
+            Path(directory_path).mkdir(parents=True)
+
+        if source_format == "parquet":
+            store_format = "parquet"
+
+        if store_format == "hdf5" and survey.hdf5_file_path is None:
+            survey.hdf5_file_path = str(Path(directory_path) / (survey.name + ".h5"))
+
+        if store_format == "parquet" and survey.parquet_file_path is None:
+            survey.parquet_file_path = str(Path(directory_path) / survey.name)
+
+        self.store_format = store_format
+
+        if source_format is not None:
+            assert source_format in admissible_source_formats, f"Data source format {source_format} is unknown"
+            source_formats = [source_format]
+        else:
+            source_formats = admissible_source_formats
+
+        for source_format in source_formats:
+            files = f"{source_format}_files"
+            for data_file in survey.informations.get(files, []):
+                name = Path(data_file).stem
+                extension = Path(data_file).suffix
+                if tables is None or name in tables:
+                    if keep_original_parquet_file:
+                        if re.match(r".*-\d$", name):
+                            name = name.split("-")[0]
+                            parquet_file = str(Path(data_file).parent)
+                            survey.parquet_file_path = str(Path(data_file).parent.parent)
+                        else:
+                            parquet_file = data_file
+                            survey.parquet_file_path = str(Path(data_file).parent)
+                        table = Table(
+                            label=name,
+                            name=name,
+                            source_format=source_format_by_extension[extension[1:]],
+                            survey=survey,
+                            parquet_file=parquet_file,
+                        )
+                        table.read_parquet_columns(data_file)
+
+                    else:
+                        table = Table(
+                            label=name,
+                            name=name,
+                            source_format=source_format_by_extension[extension[1:]],
+                            survey=survey,
+                        )
+                        table.fill_store(
+                            data_file,
+                            clean=True,
+                            overwrite=overwrite if isinstance(overwrite, bool) else table.name in overwrite,
+                            encoding=encoding,
+                            categorical_strategy=categorical_strategy,
+                        )
+        self.dump()
+
+    def get_value(self, variable, table, lowercase=False, ignorecase=False):
+        return self.get_values([variable], table)
+
+    def get_values(
+        self,
+        variables=None,
+        table=None,
+        lowercase=False,
+        ignorecase=False,
+        rename_ident=True,
+        batch_size=None,
+        batch_index=0,
+        filter_by=None,
+    ) -> pandas.DataFrame:
+        if self.parquet_file_path is None and self.hdf5_file_path is None:
+            raise SurveyIOError(f"No data file found for survey {self.name}")
+        if self.hdf5_file_path is not None:
+            assert Path(self.hdf5_file_path).exists(), (
+                f"{self.hdf5_file_path} is not a valid path. This could happen because "
+                "your data were not built yet. Please consider using a rebuild option in your code."
+            )
+            store = pandas.HDFStore(self.hdf5_file_path, "r")
+            if ignorecase:
+                keys = store.keys()
+                eligible_tables = []
+                for string in keys:
+                    match = re.findall(table, string, re.IGNORECASE)
+                    if match:
+                        eligible_tables.append(match[0])
+                if len(eligible_tables) > 1:
+                    raise SurveyManagerError(
+                        f"{table} is ambiguous since the following tables are available: {eligible_tables}"
+                    )
+                elif len(eligible_tables) == 0:
+                    raise SurveyIOError(f"No eligible available table in {keys}")
+                else:
+                    table = eligible_tables[0]
+            try:
+                df = store.select(table)
+            except KeyError:
+                log.error(f"No table {table} in the file {self.hdf5_file_path}")
+                log.error(
+                    f"This could happen because your data were not built yet. Available tables are: {store.keys()}"
+                )
+                store.close()
+                raise
+
+            store.close()
+
+        elif self.parquet_file_path is not None:
+            if table is None:
+                raise SurveyIOError("A table name is needed to retrieve data from a parquet file")
+            for table_name, table_content in self.tables.items():
+                if table == table_name:
+                    parquet_file = table_content.get("parquet_file")
+                    if Path(parquet_file).is_dir():
+                        for file in Path(parquet_file).iterdir():
+                            if file.suffix == ".parquet":
+                                one_parquet_file = str(Path(parquet_file) / file)
+                                break
+                        else:
+                            raise SurveyIOError(f"No parquet file found in {parquet_file}")
+                    else:
+                        one_parquet_file = parquet_file
+                    parquet_schema = pq.read_schema(one_parquet_file)
+                    assert len(parquet_schema.names) >= 1, (
+                        f"The parquet file {table_content.get('parquet_file')} is empty"
+                    )
+                    if variables is None:
+                        variables = table_content.get("variables")
+                    if filter_by:
+                        df = pq.ParquetDataset(parquet_file, filters=filter_by).read(columns=variables).to_pandas()
+                    elif batch_size:
+                        if Path(parquet_file).is_dir():
+                            parquet_file = [str(p) for p in Path(parquet_file).glob("*.parquet")]
+                        else:
+                            parquet_file = [parquet_file]
+                        tables = []
+                        for file_path in parquet_file:
+                            table = pq.read_table(file_path, columns=variables)
+                            tables.append(table)
+
+                        final_table = pa.concat_tables(tables) if len(tables) > 1 else tables[0]
+                        record_batches = final_table.to_batches(max_chunksize=batch_size)
+                        if len(record_batches) <= batch_index:
+                            raise NoMoreDataError(
+                                f"Batch {batch_index} not found in {table_name}. Max index is {len(record_batches)}"
+                            )
+                        df = record_batches[batch_index].to_pandas()
+                    else:
+                        df = pq.ParquetDataset(parquet_file).read(columns=variables).to_pandas()
+                    break
+            else:
+                raise SurveyIOError(f"No table {table} found in {self.parquet_file_path}")
+
+        harmonize_data_frame_columns(df, lowercase=lowercase, rename_ident=rename_ident)
+
+        if variables is None:
+            return df
+        else:
+            diff = set(variables) - set(df.columns)
+            if diff:
+                raise SurveyIOError(f"The following variable(s) {diff} are missing")
+            variables = list(set(variables).intersection(df.columns))
+            df = df[variables]
+            return df
+
+    def insert_table(self, label=None, name=None, **kwargs):
+        parquet_file = kwargs.pop("parquet_file", None)
+        data_frame = kwargs.pop("data_frame", None)
+        if data_frame is None:
+            data_frame = kwargs.pop("dataframe", None)
+
+        if data_frame is not None:
+            assert isinstance(data_frame, pandas.DataFrame)
+            variables = kwargs.pop("variables", None)
+            if variables is not None:
+                assert set(variables) < set(data_frame.columns)
+            else:
+                variables = list(data_frame.columns)
+            if label is None:
+                label = name
+            table = Table(label=label, name=name, survey=self, variables=variables, parquet_file=parquet_file)
+            assert (table.survey.hdf5_file_path is not None) or (table.survey.parquet_file_path is not None)
+            if parquet_file is not None:
+                log.debug(f"Saving table {name} in {table.survey.parquet_file_path}")
+                data_frame.to_parquet(parquet_file)
+            else:
+                log.debug(f"Saving table {name} in {table.survey.hdf5_file_path}")
+                to_hdf_kwargs = kwargs.pop("to_hdf_kwargs", {})
+                table.save_data_frame_to_hdf5(data_frame, **to_hdf_kwargs)
+
+        if name not in self.tables:
+            self.tables[name] = {}
+        for key, val in kwargs.items():
+            self.tables[name][key] = val
+
+    def to_json(self):
+        self_json = collections.OrderedDict(())
+        self_json["hdf5_file_path"] = str(self.hdf5_file_path) if self.hdf5_file_path else None
+        self_json["parquet_file_path"] = str(self.parquet_file_path) if self.parquet_file_path else None
+        self_json["label"] = self.label
+        self_json["name"] = self.name
+        self_json["tables"] = self.tables
+        self_json["informations"] = collections.OrderedDict(sorted(self.informations.items()))
+        return self_json
diff --git a/openfisca_survey_manager/core/table.py b/openfisca_survey_manager/core/table.py
new file mode 100644
index 00000000..3b98c222
--- /dev/null
+++ b/openfisca_survey_manager/core/table.py
@@ -0,0 +1,313 @@
+"""Table: a table of a survey (core I/O and storage)."""
+
+import collections
+import csv
+import datetime
+import errno
+import gc
+import logging
+import os
+from pathlib import Path
+
+import pandas
+from chardet.universaldetector import UniversalDetector
+from pyarrow import parquet as pq
+
+from openfisca_survey_manager import read_sas
+from openfisca_survey_manager.exceptions import SurveyIOError
+from openfisca_survey_manager.io.writers import write_table_to_hdf5, write_table_to_parquet
+from openfisca_survey_manager.processing.cleaning import clean_data_frame
+
+try:
+    from openfisca_survey_manager.read_spss import read_spss
+except ImportError:
+    read_spss = None
+
+log = logging.getLogger(__name__)
+
+reader_by_source_format = {
+    "csv": pandas.read_csv,
+    "sas": read_sas.read_sas,
+    "spss": read_spss,
+    "stata": pandas.read_stata,
+    "parquet": pandas.read_parquet,
+}
+
+
+class Table:
+    """A table of a survey."""
+
+    label = None
+    name = None
+    source_format = None
+    survey = None
+    variables = None
+    parquet_file = None
+
+    def __init__(
+        self, survey=None, name=None, label=None, source_format=None, variables=None, parquet_file=None, **kwargs
+    ):
+        assert name is not None, "A table should have a name"
+        self.name = name
+        self.label = label
+        self.source_format = source_format
+        self.variables = variables
+        self.parquet_file = parquet_file
+        self.informations = kwargs
+
+        from openfisca_survey_manager.core.survey import Survey
+
+        assert isinstance(survey, Survey), f"survey is of type {type(survey)} and not {Survey}"
+        self.survey = survey
+        if not survey.tables:
+            survey.tables = collections.OrderedDict()
+
+        survey.tables[name] = collections.OrderedDict(
+            source_format=source_format,
+            variables=variables,
+            parquet_file=parquet_file,
+        )
+
+    def _check_and_log(self, data_file_path, store_file_path):
+        assert store_file_path is not None, "Store file path cannot be None"
+        if not Path(data_file_path).is_file():
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), data_file_path)
+
+        log.info(
+            f"Inserting table {self.name} from file {data_file_path} in store file {store_file_path} "
+            f"at point {self.name}"
+        )
+
+    def _is_stored(self):
+        if self.survey.hdf5_file_path is not None:
+            store = pandas.HDFStore(self.survey.hdf5_file_path)
+            if self.name in store:
+                log.info(f"Exiting without overwriting {self.name} in {self.survey.hdf5_file_path}")
+                store.close()
+                return True
+
+            store.close()
+            return False
+        else:
+            return False
+
+    def _save(self, data_frame: pandas.DataFrame = None, store_format="hdf5"):
+        assert data_frame is not None
+        variables = self.variables
+
+        if variables:
+            stored_variables = list(set(variables).intersection(set(data_frame.columns)))
+            log.info("The following variables are stored: %s", stored_variables)
+            if set(stored_variables) != set(variables):
+                log.info(
+                    "variables wanted by the user that were not available: "
+                    f"{list(set(variables) - set(stored_variables))}"
+                )
+            data_frame = data_frame[stored_variables].copy()
+
+        assert store_format in ["hdf5", "parquet"], f"invalid store_format: {store_format}"
+        if store_format == "hdf5":
+            import warnings
+
+            warnings.warn(
+                "HDF5 will no longer be the default format in a future version. Please use parquet format instead.",
+                DeprecationWarning,
+                stacklevel=3,
+            )
+            log.warning(
+                "HDF5 will no longer be the default format in a future version. Please use parquet format instead."
+            )
+            self.save_data_frame_to_hdf5(data_frame)
+        else:
+            parquet_file_path = self.survey.parquet_file_path
+            log.info(f"Inserting table {self.name} in Parquet file {parquet_file_path}")
+            self.save_data_frame_to_parquet(data_frame)
+        gc.collect()
+
+    def fill_store(self, data_file, overwrite: bool = False, clean: bool = False, **kwargs):
+        if not overwrite and self._is_stored():
+            log.info(f"Exiting without overwriting {self.name} in {self.survey.hdf5_file_path}")
+            return
+
+        start_table_time = datetime.datetime.now()
+        if self.source_format in ["sas", "parquet"] and "encoding" in kwargs:
+            del kwargs["encoding"]
+        data_frame = self.read_source(data_file, **kwargs)
+        try:
+            if clean:
+                clean_data_frame(data_frame)
+            self._save(data_frame=data_frame, store_format=self.survey.store_format)
+            log.info(f"File {data_file} has been processed in {datetime.datetime.now() - start_table_time}")
+        except Exception as e:
+            log.error(f"Failed to process file {data_file} because of the following error:\n{e}")
+            raise e
+
+    def read_parquet_columns(self, parquet_file=None) -> list:
+        if parquet_file is None:
+            parquet_file = self.parquet_file
+        log.info(f"Initializing table {self.name} from parquet file {parquet_file}")
+        self.source_format = "parquet"
+        parquet_schema = pq.read_schema(parquet_file)
+        self.variables = parquet_schema.names
+        self.survey.tables[self.name]["variables"] = self.variables
+        return self.variables
+
+    def read_source(self, data_file, **kwargs):
+        source_format = self.source_format
+        store_file_path = (
+            self.survey.hdf5_file_path if self.survey.store_format == "hdf5" else self.survey.parquet_file_path
+        )
+
+        self._check_and_log(data_file, store_file_path=store_file_path)
+        reader = reader_by_source_format[source_format]
+        categorical_strategy = (
+            kwargs.pop("categorical_strategy", "unique_labels")
+            if source_format == "stata"
+            else kwargs.pop("categorical_strategy", None)
+        )
+        try:
+            if source_format == "csv":
+                try:
+                    data_frame = reader(data_file, **kwargs)
+
+                    if len(data_frame.columns) == 1 and ";" in data_frame.columns[0]:
+                        raise SurveyIOError(
+                            "A ';' is present in the unique column name. Looks like we got the wrong separator."
+                        )
+
+                except Exception:
+                    log.debug(f"Failed to read {data_file}; trying to infer encoding and dialect/separator")
+
+                    detector = UniversalDetector()
+                    with Path(data_file).open("rb") as csvfile:
+                        for line in csvfile:
+                            detector.feed(line)
+                            if detector.done:
+                                break
+                        detector.close()
+
+                    encoding = detector.result["encoding"]
+                    confidence = detector.result["confidence"]
+
+                    try:
+                        with Path(data_file).open("r", newline="", encoding=encoding) as csvfile:
+                            dialect = csv.Sniffer().sniff(csvfile.read(1024), delimiters=";,")
+                    except Exception:
+                        dialect = None
+                        delimiter = ";"
+
+                    log.debug(
+                        f"dialect.delimiter = {dialect.delimiter if dialect is not None else delimiter}, "
+                        f"encoding = {encoding}, confidence = {confidence}"
+                    )
+                    kwargs["engine"] = "python"
+                    if dialect:
+                        kwargs["dialect"] = dialect
+                    else:
+                        kwargs["delimiter"] = delimiter
+                    kwargs["encoding"] = encoding
+                    data_frame = reader(data_file, **kwargs)
+
+            else:
+                if "encoding" in kwargs and source_format == "stata":
+                    kwargs.pop("encoding")
+                if source_format == "stata":
+                    try:
+                        data_frame = reader(data_file, **kwargs)
+                    except ValueError as e:
+                        if "not unique" in str(e) or "Categorical categories must be unique" in str(e):
+                            log.info(
+                                f"Non-unique value labels detected in {data_file}, "
+                                f"using strategy '{categorical_strategy}'"
+                            )
+
+                            kwargs_no_cat = kwargs.copy()
+                            kwargs_no_cat["convert_categoricals"] = False
+                            data_frame = reader(data_file, **kwargs_no_cat)
+
+                            if categorical_strategy == "unique_labels":
+                                from pandas.io.stata import StataReader
+
+                                stata_reader = StataReader(data_file)
+                                value_labels = stata_reader.value_labels()
+
+                                for col_name, labels in value_labels.items():
+                                    if col_name in data_frame.columns:
+                                        unique_labels = {}
+                                        seen_labels = {}
+
+                                        for code, label in labels.items():
+                                            if pandas.isna(code):
+                                                unique_labels[code] = label
+                                            elif label in seen_labels:
+                                                unique_labels[code] = f"{label} ({code})"
+                                            else:
+                                                unique_labels[code] = label
+                                                seen_labels[label] = code
+
+                                        code_to_label = {code: unique_labels[code] for code in sorted(labels.keys())}
+
+                                        data_frame[col_name] = data_frame[col_name].map(code_to_label)
+                                        data_frame[col_name] = pandas.Categorical(
+                                            data_frame[col_name],
+                                            categories=list(code_to_label.values()),
+                                            ordered=False,
+                                        )
+
+                            elif categorical_strategy == "codes":
+                                from pandas.io.stata import StataReader
+
+                                stata_reader = StataReader(data_file)
+                                value_labels = stata_reader.value_labels()
+
+                                for col_name, labels in value_labels.items():
+                                    if col_name in data_frame.columns:
+                                        codes = sorted([c for c in labels if pandas.notna(c)])
+                                        if codes:
+                                            data_frame[col_name] = pandas.Categorical(
+                                                data_frame[col_name], categories=codes, ordered=False
+                                            )
+
+                            elif categorical_strategy == "skip":
+                                pass
+                            else:
+                                log.warning(f"Unknown categorical_strategy '{categorical_strategy}', using 'skip'")
+                        else:
+                            raise
+                else:
+                    data_frame = reader(data_file, **kwargs)
+
+        except Exception as e:
+            log.error(f"Error while reading {data_file}")
+            raise e
+
+        gc.collect()
+        return data_frame
+
+    def save_data_frame_to_hdf5(self, data_frame, **kwargs):
+        hdf5_file_path = self.survey.hdf5_file_path
+        log.info(f"Inserting table {self.name} in HDF file {hdf5_file_path}")
+        store_path = self.name
+        write_table_to_hdf5(
+            data_frame,
+            hdf5_file_path=hdf5_file_path,
+            store_path=store_path,
+            **kwargs,
+        )
+
+        self.variables = list(data_frame.columns)
+
+    def save_data_frame_to_parquet(self, data_frame):
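+        """Write the data frame to the survey's Parquet store and record the resulting file path and variables on the table."""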
+        parquet_file_path = self.survey.parquet_file_path
+        self.parquet_file = write_table_to_parquet(
+            data_frame,
+            parquet_dir_path=parquet_file_path,
+            table_name=self.name,
+        )
+        self.variables = list(data_frame.columns)
+
+        self.survey.tables[self.name]["parquet_file"] = self.parquet_file
+        self.survey.tables[self.name]["variables"] = self.variables
diff --git a/openfisca_survey_manager/google_colab.py b/openfisca_survey_manager/google_colab.py
index b5cf98a7..817f4a72 100644
--- a/openfisca_survey_manager/google_colab.py
+++ b/openfisca_survey_manager/google_colab.py
@@ -1,8 +1,11 @@
 import configparser
+import logging
 from pathlib import Path
 
 from openfisca_survey_manager.paths import default_config_files_directory as config_files_directory
 
+log = logging.getLogger(__name__)
+
 
 def create_raw_data_ini(value_by_option_by_section=None):
     """Creates raw_data.ini configuration file
@@ -19,5 +22,7 @@ def create_raw_data_ini(value_by_option_by_section=None):
         for option, value in value_by_option.items():
             config_parser.set(section, option, value)
 
-    with (Path(config_files_directory) / "raw_data.ini").open("w") as raw_data_config_file:
+    path = Path(config_files_directory) / "raw_data.ini"
+    with path.open("w") as raw_data_config_file:
         config_parser.write(raw_data_config_file)
+    log.info("Wrote raw_data.ini to %s", path)
diff --git a/openfisca_survey_manager/io/hdf.py b/openfisca_survey_manager/io/hdf.py
new file mode 100644
index 00000000..fcf0eba3
--- /dev/null
+++ b/openfisca_survey_manager/io/hdf.py
@@ -0,0 +1,45 @@
+"""HDF5 write support for survey tables."""
+
+from __future__ import annotations
+
+import logging
+from typing import Any
+
+import pandas as pd
+
+log = logging.getLogger(__name__)
+
+
+def write_table_to_hdf5(
+    data_frame: pd.DataFrame,
+    *,
+    hdf5_file_path: str,
+    store_path: str,
+    **kwargs: Any,
+) -> None:
+    """Write a DataFrame to HDF5.
+
+    Mirrors historical behavior from `tables.Table.save_data_frame_to_hdf5`.
+    May mutate `data_frame` (type conversions) to ensure it can be written.
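+    If the first write raises TypeError or NotImplementedError, mixed/unicode
+    columns are converted to str and the write is retried with the "table" format.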
+    """
+    try:
+        data_frame.to_hdf(hdf5_file_path, store_path, append=False, **kwargs)
+    except (TypeError, NotImplementedError):
+        log.info("Type problem(s) when creating %s in %s", store_path, hdf5_file_path)
+        dtypes = data_frame.dtypes
+        # Checking for strings
+        converted_dtypes = dtypes.isin(["mixed", "unicode"])
+        if converted_dtypes.any():
+            log.info("The following types are converted to strings %s", dtypes[converted_dtypes])
+            for column in dtypes[converted_dtypes].index:
+                data_frame[column] = data_frame[column].copy().astype(str)
+
+        # Checking for remaining categories
+        dtypes = data_frame.dtypes
+        converted_dtypes = dtypes.isin(["category"])
+        if not converted_dtypes.empty:  # With categories, the table format is needed
+            log.info(
+                "The following types are added as category using the table format %s",
+                dtypes[converted_dtypes],
+            )
+            data_frame.to_hdf(hdf5_file_path, store_path, append=False, format="table", **kwargs)
diff --git a/openfisca_survey_manager/io/writers.py b/openfisca_survey_manager/io/writers.py
index 8481fe47..25f7063f 100644
--- a/openfisca_survey_manager/io/writers.py
+++ b/openfisca_survey_manager/io/writers.py
@@ -4,46 +4,14 @@
 
 import logging
 from pathlib import Path
-from typing import Any
 
 import pandas as pd
 
-log = logging.getLogger(__name__)
-
+from openfisca_survey_manager.io.hdf import write_table_to_hdf5
 
-def write_table_to_hdf5(
-    data_frame: pd.DataFrame,
-    *,
-    hdf5_file_path: str,
-    store_path: str,
-    **kwargs: Any,
-) -> None:
-    """Write a DataFrame to HDF5.
-
-    Mirrors historical behavior from `tables.Table.save_data_frame_to_hdf5`.
-    May mutate `data_frame` (type conversions) to ensure it can be written.
-    """
-    try:
-        data_frame.to_hdf(hdf5_file_path, store_path, append=False, **kwargs)
-    except (TypeError, NotImplementedError):
-        log.info("Type problem(s) when creating %s in %s", store_path, hdf5_file_path)
-        dtypes = data_frame.dtypes
-        # Checking for strings
-        converted_dtypes = dtypes.isin(["mixed", "unicode"])
-        if converted_dtypes.any():
-            log.info("The following types are converted to strings %s", dtypes[converted_dtypes])
-            for column in dtypes[converted_dtypes].index:
-                data_frame[column] = data_frame[column].copy().astype(str)
+log = logging.getLogger(__name__)
 
-        # Checking for remaining categories
-        dtypes = data_frame.dtypes
-        converted_dtypes = dtypes.isin(["category"])
-        if not converted_dtypes.empty:  # With category table format is needed
-            log.info(
-                "The following types are added as category using the table format %s",
-                dtypes[converted_dtypes],
-            )
-            data_frame.to_hdf(hdf5_file_path, store_path, append=False, format="table", **kwargs)
+__all__ = ["write_table_to_hdf5", "write_table_to_parquet"]
 
 
 def write_table_to_parquet(
diff --git a/openfisca_survey_manager/processing/__init__.py b/openfisca_survey_manager/processing/__init__.py
index ce2da410..1b9b5249 100644
--- a/openfisca_survey_manager/processing/__init__.py
+++ b/openfisca_survey_manager/processing/__init__.py
@@ -2,6 +2,13 @@
 # See docs/REFACTORING_PLAN.md for migration steps.
 from openfisca_survey_manager.processing.cleaning import clean_data_frame
+from openfisca_survey_manager.processing.harmonization import harmonize_data_frame_columns
 from openfisca_survey_manager.processing.weights import Calibration, calmar, check_calmar
 
-__all__ = ["Calibration", "calmar", "check_calmar", "clean_data_frame"]
+__all__ = [
+    "Calibration",
+    "calmar",
+    "check_calmar",
+    "clean_data_frame",
+    "harmonize_data_frame_columns",
+]
diff --git a/openfisca_survey_manager/processing/harmonization.py b/openfisca_survey_manager/processing/harmonization.py
new file mode 100644
index 00000000..1edac60a
--- /dev/null
+++ b/openfisca_survey_manager/processing/harmonization.py
@@ -0,0 +1,34 @@
+"""Column harmonization for survey data (lowercase, ident renaming)."""
+
+import logging
+import re
+
+import pandas as pd
+
+log = logging.getLogger(__name__)
+
+# Column names matching this pattern (e.g. ident01, ident2019) are renamed to "ident"
+IDENT_COLUMN_PATTERN = re.compile(r"(?i)ident\d{2,4}$")
+
+
+def harmonize_data_frame_columns(
+    data_frame: pd.DataFrame,
+    *,
+    lowercase: bool = False,
+    rename_ident: bool = True,
+) -> None:
+    """Harmonize column names in place.
+
+    - If lowercase: rename all columns to lowercase.
+    - If rename_ident: rename the first column matching the ident pattern (e.g. ident01, ident2019) to "ident".
+    """
+    if lowercase:
+        columns = {col: col.lower() for col in data_frame.columns}
+        data_frame.rename(columns=columns, inplace=True)
+
+    if rename_ident:
+        for column_name in data_frame.columns:
+            if IDENT_COLUMN_PATTERN.match(str(column_name)) is not None:
+                data_frame.rename(columns={column_name: "ident"}, inplace=True)
+                log.info("%s column has been renamed to ident", column_name)
+                break
diff --git a/openfisca_survey_manager/statshelpers.py b/openfisca_survey_manager/statshelpers.py
index ade97abb..3960afcf 100644
--- a/openfisca_survey_manager/statshelpers.py
+++ b/openfisca_survey_manager/statshelpers.py
@@ -1,3 +1,5 @@
+import logging
+
 import numpy as np
 import pandas as pd
 import weightedcalcs as wc
@@ -5,6 +7,8 @@
 from numpy import argsort, asarray, cumsum, linspace, ones, repeat, zeros
 from numpy import logical_and as and_
 
+log = logging.getLogger(__name__)
+
 
 def gini(values, weights=None):
     """Computes Gini coefficient (normalized to 1).
diff --git a/openfisca_survey_manager/survey_collections.py b/openfisca_survey_manager/survey_collections.py
index ba25168b..1a676ee2 100644
--- a/openfisca_survey_manager/survey_collections.py
+++ b/openfisca_survey_manager/survey_collections.py
@@ -1,154 +1,8 @@
-import codecs
-import collections
-import configparser
-import json
-import logging
-from pathlib import Path
+"""Re-export for backward compatibility.
-from openfisca_survey_manager.config import Config -from openfisca_survey_manager.exceptions import SurveyConfigError -from openfisca_survey_manager.paths import default_config_files_directory -from openfisca_survey_manager.surveys import Survey - -log = logging.getLogger(__name__) - - -class SurveyCollection: - """A collection of Surveys""" - - def __init__( - self, config_files_directory=default_config_files_directory, label=None, name=None, json_file_path=None - ): - self.name = name - self.label = label - self.json_file_path = json_file_path - self.surveys = [] - log.debug(f"Initializing SurveyCollection from config file found in {config_files_directory} ..") - config = Config(config_files_directory=config_files_directory) - if label is not None: - self.label = label - if name is not None: - self.name = name - if json_file_path is not None: - self.json_file_path = json_file_path - if "collections" not in config.sections(): - config["collections"] = {} - config.set("collections", self.name, str(self.json_file_path)) - config.save() - elif config is not None: - if config.has_option("collections", self.name): - self.json_file_path = config.get("collections", self.name) - elif config.get("collections", "collections_directory") is not None: - self.json_file_path = str(Path(config.get("collections", "collections_directory")) / (name + ".json")) - - self.config = config - - def __repr__(self): - header = f"""{self.name} -Survey collection of {self.label} -Contains the following surveys : +Prefer: from openfisca_survey_manager.core.dataset import SurveyCollection. """ - surveys = [f" {survey.name} : {survey.label} \n" for survey in self.surveys] - return header + "".join(surveys) - - def dump(self, config_files_directory=None, json_file_path=None): - """ - Dump the survey collection to a json file - And set the json file path in the config file - """ - if self.config is not None: - config = self.config - else: - if config_files_directory is not None: - pass - else: - config_files_directory = default_config_files_directory - self.config = Config(config_files_directory=config_files_directory) - - if json_file_path is None: - assert self.json_file_path is not None, "A json_file_path should be provided" - else: - self.json_file_path = json_file_path - - config.set("collections", self.name, str(self.json_file_path)) - config.save() - with codecs.open(str(self.json_file_path), "w", encoding="utf-8") as _file: - json.dump(self.to_json(), _file, ensure_ascii=False, indent=2) - - def fill_store( - self, - source_format=None, - surveys=None, - tables=None, - overwrite=False, - keep_original_parquet_file=False, - encoding=None, - store_format="hdf5", - categorical_strategy="unique_labels", - ): - if surveys is None: - surveys = self.surveys - for survey in surveys: - survey.fill_store( - source_format=source_format, - tables=tables, - overwrite=overwrite, - keep_original_parquet_file=keep_original_parquet_file, - encoding=encoding, - store_format=store_format, - categorical_strategy=categorical_strategy, - ) - self.dump() - - def get_survey(self, survey_name): - available_surveys_names = [survey.name for survey in self.surveys] - assert survey_name in available_surveys_names, ( - f"Survey {survey_name} cannot be found for survey collection {self.name}.\n" - f"Available surveys are :{available_surveys_names}" - ) - return [survey for survey in self.surveys if survey.name == survey_name].pop() - - @classmethod - def load(cls, json_file_path=None, collection=None, 
config_files_directory=default_config_files_directory): - assert Path(config_files_directory).exists() - config = Config(config_files_directory=config_files_directory) - if json_file_path is None: - assert collection is not None, "A collection is needed" - try: - json_file_path = config.get("collections", collection) - except (configparser.NoOptionError, configparser.NoSectionError) as error: - msg = f"Looking for config file in {config_files_directory}" - log.debug(msg) - log.error(error) - raise error - except Exception as error: - msg = f"Looking for config file in {config_files_directory}" - log.debug(msg) - log.error(error) - raise SurveyConfigError(msg) from error - - with Path(json_file_path).open("r") as _file: - self_json = json.load(_file) - name = self_json["name"] - - self = cls(config_files_directory=config_files_directory, name=name) - self.config = config - with Path(json_file_path).open("r") as _file: - self_json = json.load(_file) - self.json_file_path = json_file_path - self.label = self_json.get("label") - self.name = self_json.get("name") - surveys = self_json["surveys"] - for survey_name, survey_json in surveys.items(): - survey = Survey(name=survey_name) - self.surveys.append(survey.create_from_json(survey_json)) - return self +from openfisca_survey_manager.core.dataset import SurveyCollection - def to_json(self): - self_json = collections.OrderedDict(()) - self_json["name"] = self.name - self_json["surveys"] = collections.OrderedDict(()) - for survey in self.surveys: - self_json["surveys"][survey.name] = survey.to_json() - return self_json +__all__ = ["SurveyCollection"] diff --git a/openfisca_survey_manager/surveys.py b/openfisca_survey_manager/surveys.py index 77dd566e..4283544f 100644 --- a/openfisca_survey_manager/surveys.py +++ b/openfisca_survey_manager/surveys.py @@ -1,388 +1,8 @@ -#! /usr/bin/env python +"""Re-export for backward compatibility. +Prefer: from openfisca_survey_manager.core.survey import Survey, NoMoreDataError. 
+""" -import collections -import logging -import re -from pathlib import Path +from openfisca_survey_manager.core.survey import NoMoreDataError, Survey -import pandas -import pyarrow as pa -import pyarrow.parquet as pq -import yaml - -from openfisca_survey_manager.exceptions import SurveyIOError, SurveyManagerError - -from .tables import Table - -ident_re = re.compile(r"(?i)ident\d{2,4}$") - -log = logging.getLogger(__name__) - - -source_format_by_extension = { - "csv": "csv", - "sas7bdat": "sas", - "dta": "stata", - "Rdata": "Rdata", - "spss": "sav", - "parquet": "parquet", -} - -admissible_source_formats = list(source_format_by_extension.values()) - - -class NoMoreDataError(Exception): - # Exception when the user ask for more data than available in file - pass - - -class Survey: - """An object to describe survey data""" - - hdf5_file_path = None - parquet_file_path = None - label = None - name = None - survey_collection = None - - def __init__( - self, name=None, label=None, hdf5_file_path=None, parquet_file_path=None, survey_collection=None, **kwargs - ): - assert name is not None, "A survey should have a name" - self.name = name - self.tables = collections.OrderedDict() - self.informations = {} - self.tables_index = {} - - if label is not None: - self.label = label - - if hdf5_file_path is not None: - self.hdf5_file_path = hdf5_file_path - - if parquet_file_path is not None: - self.parquet_file_path = parquet_file_path - - if survey_collection is not None: - self.survey_collection = survey_collection - - self.informations = kwargs - - def __repr__(self): - header = f"""{self.name} : survey data {self.label} -Contains the following tables : \n""" - tables = yaml.safe_dump(list(self.tables.keys()), default_flow_style=False) - informations = yaml.safe_dump(self.informations, default_flow_style=False) - return header + tables + informations - - @classmethod - def create_from_json(cls, survey_json): - self = cls( - name=survey_json.get("name"), - label=survey_json.get("label"), - hdf5_file_path=survey_json.get("hdf5_file_path"), - parquet_file_path=survey_json.get("parquet_file_path"), - **survey_json.get("informations", {}), - ) - self.tables = survey_json.get("tables") - return self - - def dump(self): - assert self.survey_collection is not None - self.survey_collection.dump() - - def fill_store( - self, - source_format=None, - tables=None, - overwrite=True, - keep_original_parquet_file=False, - encoding=None, - store_format="hdf5", - categorical_strategy="unique_labels", - ): - """ - Convert data from the source files to store format either hdf5 or parquet. - If the source is in parquet, the data is not converted. 
- """ - assert self.survey_collection is not None - assert isinstance(overwrite, (bool, list)) - survey = self - # Create folder if it does not exist - config = survey.survey_collection.config - directory_path = config.get("data", "output_directory") - if not Path(directory_path).is_dir(): - log.warn( - f"{directory_path} who should be the store data directory does not exist: we create the directory" - ) - Path(directory_path).mkdir(parents=True) - - if source_format == "parquet": - store_format = "parquet" - - if store_format == "hdf5" and survey.hdf5_file_path is None: - survey.hdf5_file_path = str(Path(directory_path) / (survey.name + ".h5")) - - if store_format == "parquet" and survey.parquet_file_path is None: - survey.parquet_file_path = str(Path(directory_path) / survey.name) - - self.store_format = store_format - - if source_format is not None: - assert source_format in admissible_source_formats, f"Data source format {source_format} is unknown" - source_formats = [source_format] - else: - source_formats = admissible_source_formats - - for source_format in source_formats: - files = f"{source_format}_files" - for data_file in survey.informations.get(files, []): - name = Path(data_file).stem - extension = Path(data_file).suffix - if tables is None or name in tables: - if keep_original_parquet_file: - # Use folder instead of files if numeric at end of file - if re.match(r".*-\d$", name): - name = name.split("-")[0] - parquet_file = str(Path(data_file).parent) - # Get the parent folder - survey.parquet_file_path = str(Path(data_file).parent.parent) - else: - parquet_file = data_file - survey.parquet_file_path = str(Path(data_file).parent) - table = Table( - label=name, - name=name, - source_format=source_format_by_extension[extension[1:]], - survey=survey, - parquet_file=parquet_file, - ) - table.read_parquet_columns(data_file) - - else: - table = Table( - label=name, - name=name, - source_format=source_format_by_extension[extension[1:]], - survey=survey, - ) - table.fill_store( - data_file, - clean=True, - overwrite=overwrite if isinstance(overwrite, bool) else table.name in overwrite, - encoding=encoding, - categorical_strategy=categorical_strategy, - ) - self.dump() - - def get_value(self, variable, table, lowercase=False, ignorecase=False): - """Get variable value from a survey table. - - Args: - variable: variable to retrieve - table(str): name of the table - lowercase(bool, optional, optional): lowercase variable names, defaults to False - ignorecase: ignore case of table name, defaults to False - - Returns: - pd.DataFrame: dataframe containing the variable - - """ - return self.get_values([variable], table) - - def get_values( - self, - variables=None, - table=None, - lowercase=False, - ignorecase=False, - rename_ident=True, - batch_size=None, - batch_index=0, - filter_by=None, - ) -> pandas.DataFrame: - """Get variables values from a survey table. - - Args: - variables(list, optional, optional): variables to retrieve, defaults to None (retrieve all variables) - table(str, optional, optional): name of the table, defaults to None - ignorecase: ignore case of table name, defaults to False - lowercase(bool, optional, optional): lowercase variable names, defaults to False - rename_ident(bool, optional, optional): rename ident+yr (e.g. 
ident08) into ident, defaults to True - batch_size(int, optional, optional): batch size for parquet file, defaults to None - batch_index(int, optional, optional): batch index for parquet file, defaults to 0 - - Returns: - pd.DataFrame: dataframe containing the variables - - Raises: - Exception: - - """ - if self.parquet_file_path is None and self.hdf5_file_path is None: - raise SurveyIOError(f"No data file found for survey {self.name}") - if self.hdf5_file_path is not None: - assert Path(self.hdf5_file_path).exists(), ( - f"{self.hdf5_file_path} is not a valid path. This could happen because " - "your data were not builded yet. Please consider using a rebuild option in your code." - ) - store = pandas.HDFStore(self.hdf5_file_path, "r") - if ignorecase: - keys = store.keys() - eligible_tables = [] - for string in keys: - match = re.findall(table, string, re.IGNORECASE) - if match: - eligible_tables.append(match[0]) - if len(eligible_tables) > 1: - raise SurveyManagerError( - f"{table} is ambiguous since the following tables are available: {eligible_tables}" - ) - elif len(eligible_tables) == 0: - raise SurveyIOError(f"No eligible available table in {keys}") - else: - table = eligible_tables[0] - try: - df = store.select(table) - except KeyError: - log.error(f"No table {table} in the file {self.hdf5_file_path}") - log.error( - f"This could happen because your data were not builded yet. Available tables are: {store.keys()}" - ) - store.close() - raise - - store.close() - - elif self.parquet_file_path is not None: - if table is None: - raise SurveyIOError("A table name is needed to retrieve data from a parquet file") - for table_name, table_content in self.tables.items(): - if table == table_name: - parquet_file = table_content.get("parquet_file") - # Is parquet_file a folder or a file? - if Path(parquet_file).is_dir(): - # find first parquet file in folder - for file in Path(parquet_file).iterdir(): - if file.suffix == ".parquet": - one_parquet_file = str(Path(parquet_file) / file) - break - else: - raise SurveyIOError(f"No parquet file found in {parquet_file}") - else: - one_parquet_file = parquet_file - parquet_schema = pq.read_schema(one_parquet_file) - assert len(parquet_schema.names) >= 1, ( - f"The parquet file {table_content.get('parquet_file')} is empty" - ) - if variables is None: - variables = table_content.get("variables") - if filter_by: - df = pq.ParquetDataset(parquet_file, filters=filter_by).read(columns=variables).to_pandas() - elif batch_size: - if Path(parquet_file).is_dir(): - parquet_file = [str(p) for p in Path(parquet_file).glob("*.parquet")] - else: - parquet_file = [parquet_file] - # Initialize an empty list to store the Parquet tables - tables = [] - # Loop through the file paths and read each Parquet file - for file_path in parquet_file: - table = pq.read_table(file_path, columns=variables) - tables.append(table) - - # Concatenate the tables if needed - final_table = pa.concat_tables(tables) if len(tables) > 1 else tables[0] - record_batches = final_table.to_batches(max_chunksize=batch_size) - if len(record_batches) <= batch_index: - raise NoMoreDataError( - f"Batch {batch_index} not found in {table_name}. Max index is {len(record_batches)}" - ) - df = record_batches[batch_index].to_pandas() - # iter_parquet = parquet_file.iter_batches(batch_size=batch_size, columns=variables) - # index = 0 - # while True: - # try: - # batch = next(iter_parquet) - # except StopIteration: - # raise NoMoreDataError( - # f"Batch {batch_index} not found in {table_name}. 
Max index is {index}" - # ) - # break - # if batch_index == index: - # df = batch.to_pandas() - # break - # index += 1 - else: - df = pq.ParquetDataset(parquet_file).read(columns=variables).to_pandas() - break - else: - raise SurveyIOError(f"No table {table} found in {self.parquet_file_path}") - - if lowercase: - columns = {column_name: column_name.lower() for column_name in df} - df.rename(columns=columns, inplace=True) - - if rename_ident is True: - for column_name in df: - if ident_re.match(str(column_name)) is not None: - df.rename(columns={column_name: "ident"}, inplace=True) - log.info(f"{column_name} column have been replaced by ident") - break - - if variables is None: - return df - else: - diff = set(variables) - set(df.columns) - if diff: - raise SurveyIOError(f"The following variable(s) {diff} are missing") - variables = list(set(variables).intersection(df.columns)) - df = df[variables] - return df - - def insert_table(self, label=None, name=None, **kwargs): - """Insert a table in the Survey object. - - If a pandas dataframe is provided, it is saved in the store file - """ - parquet_file = kwargs.pop("parquet_file", None) - data_frame = kwargs.pop("data_frame", None) - if data_frame is None: - # Try without underscore - data_frame = kwargs.pop("dataframe", None) - - if data_frame is not None: - assert isinstance(data_frame, pandas.DataFrame) - variables = kwargs.pop("variables", None) - if variables is not None: - assert set(variables) < set(data_frame.columns) - else: - variables = list(data_frame.columns) - if label is None: - label = name - table = Table(label=label, name=name, survey=self, variables=variables, parquet_file=parquet_file) - assert (table.survey.hdf5_file_path is not None) or (table.survey.parquet_file_path is not None) - if parquet_file is not None: - log.debug(f"Saving table {name} in {table.survey.parquet_file_path}") - data_frame.to_parquet(parquet_file) - else: - log.debug(f"Saving table {name} in {table.survey.hdf5_file_path}") - to_hdf_kwargs = kwargs.pop("to_hdf_kwargs", {}) - table.save_data_frame_to_hdf5(data_frame, **to_hdf_kwargs) - - if name not in self.tables: - self.tables[name] = {} - for key, val in kwargs.items(): - self.tables[name][key] = val - - def to_json(self): - """Convert the survey to a JSON object.""" - self_json = collections.OrderedDict(()) - self_json["hdf5_file_path"] = str(self.hdf5_file_path) if self.hdf5_file_path else None - self_json["parquet_file_path"] = str(self.parquet_file_path) if self.parquet_file_path else None - self_json["label"] = self.label - self_json["name"] = self.name - self_json["tables"] = self.tables - self_json["informations"] = collections.OrderedDict(sorted(self.informations.items())) - return self_json +__all__ = ["NoMoreDataError", "Survey"] diff --git a/openfisca_survey_manager/tables.py b/openfisca_survey_manager/tables.py index 89951176..57d6ef66 100644 --- a/openfisca_survey_manager/tables.py +++ b/openfisca_survey_manager/tables.py @@ -1,365 +1,8 @@ -"""Tables.""" +"""Re-export for backward compatibility. -import collections -import csv -import datetime -import errno -import gc -import logging -import os -from pathlib import Path +Prefer: from openfisca_survey_manager.core.table import Table. 
+""" -import pandas -from chardet.universaldetector import UniversalDetector -from pyarrow import parquet as pq +from openfisca_survey_manager.core.table import Table -from openfisca_survey_manager import read_sas -from openfisca_survey_manager.exceptions import SurveyIOError -from openfisca_survey_manager.io.writers import write_table_to_hdf5, write_table_to_parquet -from openfisca_survey_manager.processing.cleaning import clean_data_frame - -try: - from openfisca_survey_manager.read_spss import read_spss -except ImportError: - read_spss = None - - -log = logging.getLogger(__name__) - - -reader_by_source_format = { - # Rdata = pandas.rpy.common.load_data, - "csv": pandas.read_csv, - "sas": read_sas.read_sas, - "spss": read_spss, - "stata": pandas.read_stata, - "parquet": pandas.read_parquet, -} - - -class Table: - """A table of a survey.""" - - label = None - name = None - source_format = None - survey = None - variables = None - parquet_file = None - - def __init__( - self, survey=None, name=None, label=None, source_format=None, variables=None, parquet_file=None, **kwargs - ): - assert name is not None, "A table should have a name" - self.name = name - self.label = label - self.source_format = source_format - self.variables = variables - self.parquet_file = parquet_file - self.informations = kwargs - - from .surveys import Survey # Keep it here to avoid infinite recursion - - assert isinstance(survey, Survey), f"survey is of type {type(survey)} and not {Survey}" - self.survey = survey - if not survey.tables: - survey.tables = collections.OrderedDict() - - survey.tables[name] = collections.OrderedDict( - source_format=source_format, - variables=variables, - parquet_file=parquet_file, - ) - - def _check_and_log(self, data_file_path, store_file_path): - """ - Check if the file exists and log the insertion. - - Args: - data_file_path: Data file path - store_file_path: Store file or dir path - - Raises: - Exception: File not found - """ - assert store_file_path is not None, "Store file path cannot be None" - if not Path(data_file_path).is_file(): - raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), data_file_path) - - log.info( - f"Inserting table {self.name} from file {data_file_path} in store file {store_file_path} " - f"at point {self.name}" - ) - - def _is_stored(self): - if self.survey.hdf5_file_path is not None: - store = pandas.HDFStore(self.survey.hdf5_file_path) - if self.name in store: - log.info(f"Exiting without overwriting {self.name} in {self.survey.hdf5_file_path}") - store.close() - return True - - store.close() - return False - else: - return False - - def _save(self, data_frame: pandas.DataFrame = None, store_format="hdf5"): - """ - Save a data frame in the store according to is format (HDF5 or Parque). - """ - assert data_frame is not None - variables = self.variables - - if variables: - stored_variables = list(set(variables).intersection(set(data_frame.columns))) - log.info(f"The folloging variables are stored: {stored_variables}") - if set(stored_variables) != set(variables): - log.info( - "variables wanted by the user that were not available: " - f"{list(set(variables) - set(stored_variables))}" - ) - data_frame = data_frame[stored_variables].copy() - - assert store_format in ["hdf5", "parquet"], f"invalid store_format: {store_format}" - if store_format == "hdf5": - import warnings - - warnings.warn( - "HDF5 will no longer be the default format in a future version. 
Please use parquet format instead.", - DeprecationWarning, - stacklevel=3, - ) - log.warning( - "HDF5 will no longer be the default format in a future version. Please use parquet format instead." - ) - self.save_data_frame_to_hdf5(data_frame) - else: - parquet_file_path = self.survey.parquet_file_path - log.info(f"Inserting table {self.name} in Parquet file {parquet_file_path}") - self.save_data_frame_to_parquet(data_frame) - gc.collect() - - def fill_store(self, data_file, overwrite: bool = False, clean: bool = False, **kwargs): - """ - Fill the store (HDF5 or parquet file) with the table. - Read the `data_file` in parameter and save it to the store. - - Args: - data_file (_type_, optional): The data file path. Defaults to None. - overwrite (bool, optional): Overwrite the data. Defaults to False. - clean (bool, optional): Clean the raw data befoe saving. Defaults to False. - store_format (str, optional): _description_. Defaults to "hdf5". - - Raises: - e: Skip file if error - """ - if not overwrite and self._is_stored(): - log.info(f"Exiting without overwriting {self.name} in {self.survey.hdf5_file_path}") - return - - start_table_time = datetime.datetime.now() - if self.source_format in ["sas", "parquet"] and "encoding" in kwargs: - del kwargs["encoding"] - data_frame = self.read_source(data_file, **kwargs) - try: - if clean: - clean_data_frame(data_frame) - self._save(data_frame=data_frame, store_format=self.survey.store_format) - log.info(f"File {data_file} has been processed in {datetime.datetime.now() - start_table_time}") - except Exception as e: - log.info(f"Skipping file {data_file} because of following error \n {e}") - raise e - - def read_parquet_columns(self, parquet_file=None) -> list: - """ - Initialize the table from a parquet file. - """ - if parquet_file is None: - parquet_file = self.parquet_file - log.info(f"Initializing table {self.name} from parquet file {parquet_file}") - self.source_format = "parquet" - parquet_schema = pq.read_schema(parquet_file) - self.variables = parquet_schema.names - self.survey.tables[self.name]["variables"] = self.variables - return self.variables - - def read_source(self, data_file, **kwargs): - source_format = self.source_format - store_file_path = ( - self.survey.hdf5_file_path if self.survey.store_format == "hdf5" else self.survey.parquet_file_path - ) - - self._check_and_log(data_file, store_file_path=store_file_path) - reader = reader_by_source_format[source_format] - # Extract categorical_strategy early - only stata format uses it - # Other formats (parquet, csv, etc.) don't support it and will error if passed - categorical_strategy = ( - kwargs.pop("categorical_strategy", "unique_labels") - if source_format == "stata" - else kwargs.pop("categorical_strategy", None) - ) - try: - if source_format == "csv": - try: - data_frame = reader(data_file, **kwargs) - - if len(data_frame.columns) == 1 and ";" in data_frame.columns[0]: - raise SurveyIOError( - "A ';' is present in the unique column name. Looks like we got the wrong separator." 
- ) - - except Exception: - log.debug(f"Failing to read {data_file}, Trying to infer encoding and dialect/separator") - - # Detect encoding - detector = UniversalDetector() - with Path(data_file).open("rb") as csvfile: - for line in csvfile: - detector.feed(line) - if detector.done: - break - detector.close() - - encoding = detector.result["encoding"] - confidence = detector.result["confidence"] - - # Sniff dialect - try: - with Path(data_file).open("r", newline="", encoding=encoding) as csvfile: - dialect = csv.Sniffer().sniff(csvfile.read(1024), delimiters=";,") - except Exception: - # Sometimes the sniffer fails, we switch back to the default ... of french statistical data - dialect = None - delimiter = ";" - - log.debug( - f"dialect.delimiter = {dialect.delimiter if dialect is not None else delimiter}, " - f"encoding = {encoding}, confidence = {confidence}" - ) - kwargs["engine"] = "python" - if dialect: - kwargs["dialect"] = dialect - else: - kwargs["delimiter"] = delimiter - kwargs["encoding"] = encoding - data_frame = reader(data_file, **kwargs) - - else: - # Remove encoding parameter for pandas 2.0+ compatibility (not supported in read_stata) - if "encoding" in kwargs and source_format == "stata": - kwargs.pop("encoding") - # Try to read with categoricals, handle non-unique labels with configurable strategy - if source_format == "stata": - # categorical_strategy already extracted above - - try: - # Try reading with default convert_categoricals (True) if not specified - if "convert_categoricals" not in kwargs: - data_frame = reader(data_file, **kwargs) - else: - data_frame = reader(data_file, **kwargs) - except ValueError as e: - if "not unique" in str(e) or "Categorical categories must be unique" in str(e): - log.info( - f"Non-unique value labels detected in {data_file}, " - f"using strategy '{categorical_strategy}'" - ) - - # Read without categoricals first - kwargs_no_cat = kwargs.copy() - kwargs_no_cat["convert_categoricals"] = False - data_frame = reader(data_file, **kwargs_no_cat) - - # Apply categorical strategy - if categorical_strategy == "unique_labels": - # Solution 2: Make labels unique by adding code suffix - from pandas.io.stata import StataReader - - stata_reader = StataReader(data_file) - value_labels = stata_reader.value_labels() - - for col_name, labels in value_labels.items(): - if col_name in data_frame.columns: - unique_labels = {} - seen_labels = {} - - for code, label in labels.items(): - if pandas.isna(code): - unique_labels[code] = label - elif label in seen_labels: - # Duplicate label: add code as suffix - unique_labels[code] = f"{label} ({code})" - else: - unique_labels[code] = label - seen_labels[label] = code - - # Create mapping code -> unique label - code_to_label = {code: unique_labels[code] for code in sorted(labels.keys())} - - # Map codes to unique labels and create categories - data_frame[col_name] = data_frame[col_name].map(code_to_label) - data_frame[col_name] = pandas.Categorical( - data_frame[col_name], - categories=list(code_to_label.values()), - ordered=False, - ) - - elif categorical_strategy == "codes": - # Solution 1: Use codes as categories - from pandas.io.stata import StataReader - - stata_reader = StataReader(data_file) - value_labels = stata_reader.value_labels() - - for col_name, labels in value_labels.items(): - if col_name in data_frame.columns: - codes = sorted([c for c in labels if pandas.notna(c)]) - if codes: - data_frame[col_name] = pandas.Categorical( - data_frame[col_name], categories=codes, ordered=False - ) - - elif 
categorical_strategy == "skip": - # Keep as-is (no categories) - pass - else: - log.warning(f"Unknown categorical_strategy '{categorical_strategy}', using 'skip'") - else: - raise - else: - data_frame = reader(data_file, **kwargs) - - except Exception as e: - log.info(f"Error while reading {data_file}") - raise e - - gc.collect() - return data_frame - - def save_data_frame_to_hdf5(self, data_frame, **kwargs): - """Save a data frame in the HDF5 file format.""" - hdf5_file_path = self.survey.hdf5_file_path - log.info(f"Inserting table {self.name} in HDF file {hdf5_file_path}") - store_path = self.name - write_table_to_hdf5( - data_frame, - hdf5_file_path=hdf5_file_path, - store_path=store_path, - **kwargs, - ) - - self.variables = list(data_frame.columns) - - def save_data_frame_to_parquet(self, data_frame): - """Save a data frame in the Parquet file format.""" - parquet_file_path = self.survey.parquet_file_path - self.parquet_file = write_table_to_parquet( - data_frame, - parquet_dir_path=parquet_file_path, - table_name=self.name, - ) - self.variables = list(data_frame.columns) - - self.survey.tables[self.name]["parquet_file"] = self.parquet_file - self.survey.tables[self.name]["variables"] = self.variables +__all__ = ["Table"] diff --git a/pyproject.toml b/pyproject.toml index a25a5be4..852c0dca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "OpenFisca-Survey-Manager" -version = "6.3.1" +version = "6.4.0" description = "A tool for managing survey/administrative data and import them in OpenFisca" readme = "README.md" keywords = ["microsimulation", "tax", "benefit", "rac", "rules-as-code", "survey", "data"] From 657e2fb65ffa98f49ce3649ccb4815dca2c3b277 Mon Sep 17 00:00:00 2001 From: Mahdi Ben Jelloul Date: Fri, 27 Feb 2026 22:43:06 +0100 Subject: [PATCH 02/18] chore: typing for core, io, processing; CHANGELOG 6.5.0 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - core: type hints Table, Survey, SurveyCollection (TYPE_CHECKING for circular refs) - io: type hints readers (read_dbf cols), writers/hdf already typed - processing: type hints cleaning, harmonization, calmar, Calibration - REFACTORING_PLAN §3.3 updated; version 6.5.0 Made-with: Cursor --- CHANGELOG.md | 8 ++ docs/REFACTORING_PLAN.md | 4 +- openfisca_survey_manager/core/dataset.py | 54 ++++++--- openfisca_survey_manager/core/survey.py | 82 ++++++++----- openfisca_survey_manager/core/table.py | 55 ++++++--- openfisca_survey_manager/io/readers.py | 4 +- .../processing/cleaning.py | 2 + .../processing/harmonization.py | 2 + .../processing/weights/calibration.py | 89 +++++++------- .../processing/weights/calmar.py | 110 +++++------------- pyproject.toml | 2 +- 11 files changed, 223 insertions(+), 189 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 14371e39..96b02f25 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +# 6.5.0 + +* Typing (no breaking API changes) + - **core**: Type hints on `core.table` (Table), `core.survey` (Survey, NoMoreDataError), `core.dataset` (SurveyCollection); `TYPE_CHECKING` for circular refs; class attributes with defaults where needed + - **io**: Type hints on `io.readers` (read_sas, read_spss, read_dbf with `Optional[list[str]]` for cols); `io.writers` and `io.hdf` already typed + - **processing**: Type hints on `processing.cleaning`, `processing.harmonization`, `processing.weights.calmar` (linear, logit, calmar, check_calmar, etc.), `processing.weights.calibration` (Calibration class and methods) 
+ - **Docs**: Update `REFACTORING_PLAN.md` §3.3 (typing core, io, processing done) + # 6.4.0 * Refactor (no breaking API changes) diff --git a/docs/REFACTORING_PLAN.md b/docs/REFACTORING_PLAN.md index d5d8d63e..91deea24 100644 --- a/docs/REFACTORING_PLAN.md +++ b/docs/REFACTORING_PLAN.md @@ -94,8 +94,8 @@ Aujourd’hui ces couches sont entremêlées (ex. lecture + nettoyage dans `tabl ### 3.3 Typage Python -- Ajouter progressivement des type hints sur les signatures publiques (arguments et retours). -- Priorité : `core/`, `io/`, puis `processing/`. +- **Entamé** : type hints sur les signatures publiques de `core/`, `io/` et `processing/` (cleaning, harmonization, weights/calmar, weights/calibration). +- À poursuivre : reste du package (scenarios, simulations, etc.). ### 3.4 Logging diff --git a/openfisca_survey_manager/core/dataset.py b/openfisca_survey_manager/core/dataset.py index f56072f0..3477a394 100644 --- a/openfisca_survey_manager/core/dataset.py +++ b/openfisca_survey_manager/core/dataset.py @@ -1,11 +1,14 @@ """SurveyCollection: collection of surveys (dataset orchestration).""" +from __future__ import annotations + import codecs import collections import configparser import json import logging from pathlib import Path +from typing import List, Optional, Union from openfisca_survey_manager.configuration.models import Config from openfisca_survey_manager.configuration.paths import default_config_files_directory @@ -18,9 +21,19 @@ class SurveyCollection: """A collection of Surveys.""" + name: Optional[str] = None + label: Optional[str] = None + json_file_path: Optional[str] = None + surveys: List[Survey] # set in __init__ + config: Optional[Config] = None + def __init__( - self, config_files_directory=default_config_files_directory, label=None, name=None, json_file_path=None - ): + self, + config_files_directory: Optional[Union[Path, str]] = default_config_files_directory, + label: Optional[str] = None, + name: Optional[str] = None, + json_file_path: Optional[str] = None, + ) -> None: self.name = name self.label = label self.json_file_path = json_file_path @@ -45,7 +58,7 @@ def __init__( self.config = config - def __repr__(self): + def __repr__(self) -> str: header = f"""{self.name} Survey collection of {self.label} Contains the following surveys : @@ -53,7 +66,11 @@ def __repr__(self): surveys = [f" {survey.name} : {survey.label} \n" for survey in self.surveys] return header + "".join(surveys) - def dump(self, config_files_directory=None, json_file_path=None): + def dump( + self, + config_files_directory: Optional[Union[Path, str]] = None, + json_file_path: Optional[str] = None, + ) -> None: if self.config is not None: config = self.config else: @@ -75,15 +92,15 @@ def dump(self, config_files_directory=None, json_file_path=None): def fill_store( self, - source_format=None, - surveys=None, - tables=None, - overwrite=False, - keep_original_parquet_file=False, - encoding=None, - store_format="hdf5", - categorical_strategy="unique_labels", - ): + source_format: Optional[str] = None, + surveys: Optional[List[Survey]] = None, + tables: Optional[List[str]] = None, + overwrite: bool = False, + keep_original_parquet_file: bool = False, + encoding: Optional[str] = None, + store_format: str = "hdf5", + categorical_strategy: str = "unique_labels", + ) -> None: if surveys is None: surveys = self.surveys for survey in surveys: @@ -98,7 +115,7 @@ def fill_store( ) self.dump() - def get_survey(self, survey_name): + def get_survey(self, survey_name: str) -> Survey: available_surveys_names = 
[survey.name for survey in self.surveys] assert survey_name in available_surveys_names, ( f"Survey {survey_name} cannot be found for survey collection {self.name}.\n" @@ -107,7 +124,12 @@ def get_survey(self, survey_name): return [survey for survey in self.surveys if survey.name == survey_name].pop() @classmethod - def load(cls, json_file_path=None, collection=None, config_files_directory=default_config_files_directory): + def load( + cls, + json_file_path: Optional[str] = None, + collection: Optional[str] = None, + config_files_directory: Optional[Union[Path, str]] = default_config_files_directory, + ) -> SurveyCollection: assert Path(config_files_directory).exists() config = Config(config_files_directory=config_files_directory) if json_file_path is None: @@ -143,7 +165,7 @@ def load(cls, json_file_path=None, collection=None, config_files_directory=defau self.surveys.append(survey.create_from_json(survey_json)) return self - def to_json(self): + def to_json(self) -> dict: self_json = collections.OrderedDict(()) self_json["name"] = self.name self_json["surveys"] = collections.OrderedDict(()) diff --git a/openfisca_survey_manager/core/survey.py b/openfisca_survey_manager/core/survey.py index 8cba10b6..59b8bd4e 100644 --- a/openfisca_survey_manager/core/survey.py +++ b/openfisca_survey_manager/core/survey.py @@ -1,9 +1,12 @@ """Survey: describes survey data and tables.""" +from __future__ import annotations + import collections import logging import re from pathlib import Path +from typing import TYPE_CHECKING, Any, List, Optional, Union import pandas import pyarrow as pa @@ -14,6 +17,9 @@ from openfisca_survey_manager.exceptions import SurveyIOError, SurveyManagerError from openfisca_survey_manager.processing.harmonization import harmonize_data_frame_columns +if TYPE_CHECKING: + from openfisca_survey_manager.core.dataset import SurveyCollection + log = logging.getLogger(__name__) source_format_by_extension = { @@ -37,15 +43,22 @@ class NoMoreDataError(Exception): class Survey: """An object to describe survey data.""" - hdf5_file_path = None - parquet_file_path = None - label = None - name = None - survey_collection = None + hdf5_file_path: Optional[str] = None + parquet_file_path: Optional[str] = None + label: Optional[str] = None + name: Optional[str] = None + survey_collection: Optional[SurveyCollection] = None + store_format: Optional[str] = None def __init__( - self, name=None, label=None, hdf5_file_path=None, parquet_file_path=None, survey_collection=None, **kwargs - ): + self, + name: Optional[str] = None, + label: Optional[str] = None, + hdf5_file_path: Optional[str] = None, + parquet_file_path: Optional[str] = None, + survey_collection: Optional[SurveyCollection] = None, + **kwargs: Any, + ) -> None: assert name is not None, "A survey should have a name" self.name = name self.tables = collections.OrderedDict() @@ -66,7 +79,7 @@ def __init__( self.informations = kwargs - def __repr__(self): + def __repr__(self) -> str: header = f"""{self.name} : survey data {self.label} Contains the following tables : \n""" tables = yaml.safe_dump(list(self.tables.keys()), default_flow_style=False) @@ -74,7 +87,7 @@ def __repr__(self): return header + tables + informations @classmethod - def create_from_json(cls, survey_json): + def create_from_json(cls, survey_json: dict) -> Survey: self = cls( name=survey_json.get("name"), label=survey_json.get("label"), @@ -85,20 +98,20 @@ def create_from_json(cls, survey_json): self.tables = survey_json.get("tables") return self - def dump(self): + def dump(self) 
-> None: assert self.survey_collection is not None self.survey_collection.dump() def fill_store( self, - source_format=None, - tables=None, - overwrite=True, - keep_original_parquet_file=False, - encoding=None, - store_format="hdf5", - categorical_strategy="unique_labels", - ): + source_format: Optional[str] = None, + tables: Optional[List[str]] = None, + overwrite: Union[bool, List[str]] = True, + keep_original_parquet_file: bool = False, + encoding: Optional[str] = None, + store_format: str = "hdf5", + categorical_strategy: str = "unique_labels", + ) -> None: assert self.survey_collection is not None assert isinstance(overwrite, (bool, list)) survey = self @@ -166,19 +179,25 @@ def fill_store( ) self.dump() - def get_value(self, variable, table, lowercase=False, ignorecase=False): + def get_value( + self, + variable: str, + table: Optional[str] = None, + lowercase: bool = False, + ignorecase: bool = False, + ) -> pandas.DataFrame: return self.get_values([variable], table) def get_values( self, - variables=None, - table=None, - lowercase=False, - ignorecase=False, - rename_ident=True, - batch_size=None, - batch_index=0, - filter_by=None, + variables: Optional[List[str]] = None, + table: Optional[str] = None, + lowercase: bool = False, + ignorecase: bool = False, + rename_ident: bool = True, + batch_size: Optional[int] = None, + batch_index: int = 0, + filter_by: Optional[List[tuple]] = None, ) -> pandas.DataFrame: if self.parquet_file_path is None and self.hdf5_file_path is None: raise SurveyIOError(f"No data file found for survey {self.name}") @@ -273,7 +292,12 @@ def get_values( df = df[variables] return df - def insert_table(self, label=None, name=None, **kwargs): + def insert_table( + self, + label: Optional[str] = None, + name: Optional[str] = None, + **kwargs: Any, + ) -> None: parquet_file = kwargs.pop("parquet_file", None) data_frame = kwargs.pop("data_frame", None) if data_frame is None: @@ -303,7 +327,7 @@ def insert_table(self, label=None, name=None, **kwargs): for key, val in kwargs.items(): self.tables[name][key] = val - def to_json(self): + def to_json(self) -> dict: self_json = collections.OrderedDict(()) self_json["hdf5_file_path"] = str(self.hdf5_file_path) if self.hdf5_file_path else None self_json["parquet_file_path"] = str(self.parquet_file_path) if self.parquet_file_path else None diff --git a/openfisca_survey_manager/core/table.py b/openfisca_survey_manager/core/table.py index 3b98c222..71104335 100644 --- a/openfisca_survey_manager/core/table.py +++ b/openfisca_survey_manager/core/table.py @@ -1,5 +1,7 @@ """Table: a table of a survey (core I/O and storage).""" +from __future__ import annotations + import collections import csv import datetime @@ -8,6 +10,7 @@ import logging import os from pathlib import Path +from typing import TYPE_CHECKING, Any, Optional import pandas from chardet.universaldetector import UniversalDetector @@ -23,6 +26,9 @@ except ImportError: read_spss = None +if TYPE_CHECKING: + from openfisca_survey_manager.core.survey import Survey + log = logging.getLogger(__name__) reader_by_source_format = { @@ -37,16 +43,23 @@ class Table: """A table of a survey.""" - label = None - name = None - source_format = None - survey = None - variables = None - parquet_file = None + label: Optional[str] = None + name: Optional[str] = None + source_format: Optional[str] = None + survey: Optional[Survey] = None + variables: Optional[list[str]] = None + parquet_file: Optional[str] = None def __init__( - self, survey=None, name=None, label=None, source_format=None, 
variables=None, parquet_file=None, **kwargs - ): + self, + survey: Optional[Survey] = None, + name: Optional[str] = None, + label: Optional[str] = None, + source_format: Optional[str] = None, + variables: Optional[list[str]] = None, + parquet_file: Optional[str] = None, + **kwargs: Any, + ) -> None: assert name is not None, "A table should have a name" self.name = name self.label = label @@ -68,7 +81,7 @@ def __init__( parquet_file=parquet_file, ) - def _check_and_log(self, data_file_path, store_file_path): + def _check_and_log(self, data_file_path: str, store_file_path: Optional[str]) -> None: assert store_file_path is not None, "Store file path cannot be None" if not Path(data_file_path).is_file(): raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), data_file_path) @@ -78,7 +91,7 @@ def _check_and_log(self, data_file_path, store_file_path): f"at point {self.name}" ) - def _is_stored(self): + def _is_stored(self) -> bool: if self.survey.hdf5_file_path is not None: store = pandas.HDFStore(self.survey.hdf5_file_path) if self.name in store: @@ -91,7 +104,11 @@ def _is_stored(self): else: return False - def _save(self, data_frame: pandas.DataFrame = None, store_format="hdf5"): + def _save( + self, + data_frame: Optional[pandas.DataFrame] = None, + store_format: str = "hdf5", + ) -> None: assert data_frame is not None variables = self.variables @@ -124,7 +141,13 @@ def _save(self, data_frame: pandas.DataFrame = None, store_format="hdf5"): self.save_data_frame_to_parquet(data_frame) gc.collect() - def fill_store(self, data_file, overwrite: bool = False, clean: bool = False, **kwargs): + def fill_store( + self, + data_file: str, + overwrite: bool = False, + clean: bool = False, + **kwargs: Any, + ) -> None: if not overwrite and self._is_stored(): log.info(f"Exiting without overwriting {self.name} in {self.survey.hdf5_file_path}") return @@ -142,7 +165,7 @@ def fill_store(self, data_file, overwrite: bool = False, clean: bool = False, ** log.info(f"Skipping file {data_file} because of following error \n {e}") raise e - def read_parquet_columns(self, parquet_file=None) -> list: + def read_parquet_columns(self, parquet_file: Optional[str] = None) -> list[str]: if parquet_file is None: parquet_file = self.parquet_file log.info(f"Initializing table {self.name} from parquet file {parquet_file}") @@ -152,7 +175,7 @@ def read_parquet_columns(self, parquet_file=None) -> list: self.survey.tables[self.name]["variables"] = self.variables return self.variables - def read_source(self, data_file, **kwargs): + def read_source(self, data_file: str, **kwargs: Any) -> pandas.DataFrame: source_format = self.source_format store_file_path = ( self.survey.hdf5_file_path if self.survey.store_format == "hdf5" else self.survey.parquet_file_path @@ -287,7 +310,7 @@ def read_source(self, data_file, **kwargs): gc.collect() return data_frame - def save_data_frame_to_hdf5(self, data_frame, **kwargs): + def save_data_frame_to_hdf5(self, data_frame: pandas.DataFrame, **kwargs: Any) -> None: hdf5_file_path = self.survey.hdf5_file_path log.info(f"Inserting table {self.name} in HDF file {hdf5_file_path}") store_path = self.name @@ -300,7 +323,7 @@ def save_data_frame_to_hdf5(self, data_frame, **kwargs): self.variables = list(data_frame.columns) - def save_data_frame_to_parquet(self, data_frame): + def save_data_frame_to_parquet(self, data_frame: pandas.DataFrame) -> None: parquet_file_path = self.survey.parquet_file_path self.parquet_file = write_table_to_parquet( data_frame, diff --git 
a/openfisca_survey_manager/io/readers.py b/openfisca_survey_manager/io/readers.py index b372c957..5feccab4 100644 --- a/openfisca_survey_manager/io/readers.py +++ b/openfisca_survey_manager/io/readers.py @@ -1,5 +1,7 @@ """Readers for survey data (SAS, SPSS, DBF, etc.).""" +from __future__ import annotations + import logging from typing import Optional @@ -45,7 +47,7 @@ def read_spss(spss_file_path: str) -> DataFrame: def read_dbf( dbf_path: str, index: Optional[str] = None, - cols: Optional[list] = None, + cols: Optional[list[str]] = None, incl_index: bool = False, ) -> DataFrame: """ diff --git a/openfisca_survey_manager/processing/cleaning.py b/openfisca_survey_manager/processing/cleaning.py index 6a0b81fa..5aa88b7e 100644 --- a/openfisca_survey_manager/processing/cleaning.py +++ b/openfisca_survey_manager/processing/cleaning.py @@ -1,5 +1,7 @@ """Data frame cleaning (column normalization, empty handling).""" +from __future__ import annotations + import logging import pandas as pd diff --git a/openfisca_survey_manager/processing/harmonization.py b/openfisca_survey_manager/processing/harmonization.py index 1edac60a..4d30c6b4 100644 --- a/openfisca_survey_manager/processing/harmonization.py +++ b/openfisca_survey_manager/processing/harmonization.py @@ -1,5 +1,7 @@ """Column harmonization for survey data (lowercase, ident renaming).""" +from __future__ import annotations + import logging import re diff --git a/openfisca_survey_manager/processing/weights/calibration.py b/openfisca_survey_manager/processing/weights/calibration.py index acf6b212..b28337ce 100644 --- a/openfisca_survey_manager/processing/weights/calibration.py +++ b/openfisca_survey_manager/processing/weights/calibration.py @@ -1,9 +1,13 @@ """Calibration of survey weights (SurveyScenario).""" +from __future__ import annotations + import logging import re +from typing import Any, Optional import numpy +import numpy as np import pandas as pd from numpy import logical_not from openfisca_core.model_api import Enum @@ -16,32 +20,32 @@ class Calibration: """An object to calibrate survey data of a SurveyScenario.""" - filter_by = None - initial_entity_count = None - _initial_weight_name = None - initial_weight_by_entity = None - target_margins = None - margins_by_variable = None - parameters = None - period = None - simulation = None - target_entity_count = None - other_entity_count = None - target_entity = None - weight_name = None - entities = None + filter_by: Any = None + initial_entity_count: Optional[float] = None + _initial_weight_name: Optional[str] = None + initial_weight_by_entity: dict[str, Any] # set in __init__ + target_margins: Optional[dict[str, Any]] = None + margins_by_variable: Optional[dict[str, Any]] = None + parameters: Optional[dict[str, Any]] = None + period: Any = None + simulation: Any = None + target_entity_count: Optional[float] = None + other_entity_count: Optional[float] = None + target_entity: Optional[str] = None + weight_name: Optional[str] = None + entities: Optional[list[str]] = None def __init__( self, - simulation, - target_margins, - period, - target_entity_count=None, - other_entity_count=None, - parameters=None, - filter_by=None, - entity=None, - ): + simulation: Any, + target_margins: dict[str, Any], + period: Any, + target_entity_count: Optional[float] = None, + other_entity_count: Optional[float] = None, + parameters: Optional[dict[str, Any]] = None, + filter_by: Any = None, + entity: Optional[str] = None, + ) -> None: target_entity = entity self.parameters = parameters or { 
"use_proportions": True, @@ -197,14 +201,14 @@ def _build_calmar_data(self) -> dict: return data - def calibrate(self, inplace=False): + def calibrate(self, inplace: bool = False) -> Optional[np.ndarray]: """Apply the calibrations by updating weights and margins. Args: - inplace (bool, optional): Whether to return the calibrated or to setthem inplace. Defaults to False. + inplace: Whether to apply in place (no return) or return calibrated weights. Returns: - numpy.array: calibrated weights + Calibrated weights array, or None if inplace=True. """ assert self.margins_by_variable is not None, "Margins by variable should be set" margins_by_variable = self.margins_by_variable @@ -230,7 +234,7 @@ def calibrate(self, inplace=False): return self.weight - def get_parameters(self) -> dict: + def get_parameters(self) -> dict[str, Any]: """Get the parameters. Returns: @@ -252,12 +256,12 @@ def get_parameters(self) -> dict: p["initial_weight"] = self.weight_name + "" return p - def set_target_margin(self, variable, target): + def set_target_margin(self, variable: str, target: Any) -> None: """Set variable target margin. Args: - variable: Target variable - target: Target value + variable: Target variable name or expression. + target: Target value (scalar or dict of category -> value). """ simulation = self.simulation period = self.period @@ -293,13 +297,13 @@ def set_target_margin(self, variable, target): self.margins_by_variable[variable]["target"] = target_by_category or target self._update_margins() - def reset(self): + def reset(self) -> None: """Reset the calibration to its initial state.""" simulation = self.simulation simulation.delete_arrays(self.weight_name, self.period) simulation.set_input(self.weight_name, self.period, numpy.array(self.initial_weight)) - def set_calibrated_weights(self): + def set_calibrated_weights(self) -> None: """Modify the weights to use the calibrated weights.""" period = self.period simulation = self.simulation @@ -314,7 +318,7 @@ def set_calibrated_weights(self): elif weight_variable.formulas: simulation.delete_arrays(weight_variable.name, period) - def summary(self): + def summary(self) -> pd.DataFrame: """Summarize margins.""" margins_df = pd.DataFrame.from_dict(self.margins_by_variable).T margins_df.loc["entity_count", "actual"] = (self.weight * self.filter_by).sum() @@ -322,7 +326,7 @@ def summary(self): margins_df.loc["entity_count", "target"] = self.target_entity_count return margins_df - def _update_margins(self): + def _update_margins(self) -> None: """Update margins.""" for variable in self.margins_by_variable: simulation = self.simulation @@ -381,17 +385,12 @@ def _update_margins(self): } self.margins_by_variable[variable].update(margin_by_type) - def _update_weights(self, margins, parameters=None): - """Run calmar, stores new weights and returns adjusted margins. 
- - Args: - margins: margins - parameters: Parameters (Default value = {}) - - Returns: - dict: Updated margins - - """ + def _update_weights( + self, + margins: dict[str, Any], + parameters: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + """Run calmar, store new weights and return adjusted margins.""" if parameters is None: parameters = {} diff --git a/openfisca_survey_manager/processing/weights/calmar.py b/openfisca_survey_manager/processing/weights/calmar.py index dd4986d4..4133727d 100644 --- a/openfisca_survey_manager/processing/weights/calmar.py +++ b/openfisca_survey_manager/processing/weights/calmar.py @@ -1,8 +1,12 @@ """CALMAR: Calibrates weights to satisfy margins constraints.""" +from __future__ import annotations + import logging import operator +from typing import Any, Optional +import numpy as np import pandas as pd from numpy import array, dot, exp, float64, ones, sqrt, unique, zeros from numpy import log as ln @@ -12,7 +16,7 @@ log = logging.getLogger(__name__) -def linear(u): +def linear(u: np.ndarray) -> np.ndarray: """ Args: @@ -24,68 +28,27 @@ def linear(u): return 1 + u -def linear_prime(u): - """ - - Args: - u: - - Returns: - - """ +def linear_prime(u: np.ndarray) -> np.ndarray: + """Derivative of linear (constant 1).""" return ones(u.shape, dtype=float) -def raking_ratio(u): - """ - - Args: - u: - - Returns: - - """ +def raking_ratio(u: np.ndarray) -> np.ndarray: + """Raking ratio (exponential) calibration function.""" return exp(u) -def raking_ratio_prime(u): - """ - - Args: - u: - - Returns: - - """ +def raking_ratio_prime(u: np.ndarray) -> np.ndarray: + """Derivative of raking_ratio.""" return exp(u) -def logit(u, low, up): - """ - - Args: - u: - low: - up: - - Returns: - - """ +def logit(u: np.ndarray, low: float, up: float) -> np.ndarray: a = (up - low) / ((1 - low) * (up - 1)) return (low * (up - 1) + up * (1 - low) * exp(a * u)) / (up - 1 + (1 - low) * exp(a * u)) -def logit_prime(u, low, up): - """ - - Args: - u: - low: - up: - - Returns: - - """ +def logit_prime(u: np.ndarray, low: float, up: float) -> np.ndarray: a = (up - low) / ((1 - low) * (up - 1)) return ( (a * up * (1 - low) * exp(a * u)) * (up - 1 + (1 - low) * exp(a * u)) @@ -93,12 +56,12 @@ def logit_prime(u, low, up): ) / (up - 1 + (1 - low) * exp(a * u)) ** 2 -def hyperbolic_sinus(u, alpha): +def hyperbolic_sinus(u: np.ndarray, alpha: float) -> np.ndarray: logarithm = ln(2 * alpha * u + sqrt(4 * (alpha**2) * (u**2) + 1)) return 0.5 * (logarithm / alpha + sqrt((logarithm / alpha) ** 2 + 4)) -def hyperbolic_sinus_prime(u, alpha): +def hyperbolic_sinus_prime(u: np.ndarray, alpha: float) -> np.ndarray: square = sqrt(4 * (alpha**2) * (u**2) + 1) return 0.5 * ( ((4 * (alpha**2) * u) / square + 2 * alpha) / (alpha * (square + 2 * alpha * u)) @@ -107,16 +70,8 @@ def hyperbolic_sinus_prime(u, alpha): ) -def build_dummies_dict(data): - """ - - Args: - data: - - Returns: - - - """ +def build_dummies_dict(data: np.ndarray | pd.Series) -> dict[Any, np.ndarray | pd.Series]: + """Build a dict mapping each unique value to a boolean mask (data == value).""" unique_val_list = unique(data) output = {} for val in unique_val_list: @@ -125,17 +80,17 @@ def build_dummies_dict(data): def calmar( - data_in, - margins: dict, + data_in: dict[str, Any], + margins: dict[str, Any], initial_weight: str, - method="linear", - lo=None, - up=None, - alpha=None, + method: str = "linear", + lo: Optional[float] = None, + up: Optional[float] = None, + alpha: Optional[float] = None, use_proportions: bool = False, xtol: 
float = 1.49012e-08,
     maxfev: int = 256,
-):
+) -> tuple[np.ndarray, np.ndarray, dict[str, Any]]:
     """Calibrates weights to satisfy margins constraints.
 
     Args:
@@ -388,16 +343,13 @@ def constraint_prime(lambda_):
     return pondfin_out, lambdasol, margins_new_dict
 
 
-def check_calmar(margins, margins_new_dict=None):
-    """
-
-    Args:
-        margins:
-        margins_new_dict: (Default value = None)
-
-    Returns:
-
-    """
+def check_calmar(
+    margins: dict[str, Any],
+    margins_new_dict: Optional[dict[str, Any]] = None,
+) -> None:
+    """Log relative difference between initial margins and calibrated margins."""
+    if margins_new_dict is None:
+        return
     for variable, margin in margins.items():
         if variable != "total_population":
             rel_diff = abs(margin - margins_new_dict[variable]) / abs(margin)
diff --git a/pyproject.toml b/pyproject.toml
index 852c0dca..133e5de4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "OpenFisca-Survey-Manager"
-version = "6.4.0"
+version = "6.5.0"
 description = "A tool for managing survey/administrative data and import them in OpenFisca"
 readme = "README.md"
 keywords = ["microsimulation", "tax", "benefit", "rac", "rules-as-code", "survey", "data"]

From 13994651f485665529bc6a6d0666c1336822a449 Mon Sep 17 00:00:00 2001
From: Mahdi Ben Jelloul
Date: Mon, 2 Mar 2026 18:32:17 +0100
Subject: [PATCH 03/18] matching: NND hot deck in pure Python (pandas + numpy)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Implement _nnd_hotdeck_python and _create_fused_python (Manhattan/Euclidean, donor_classes, random tie-breaking)
- Unify the API: nnd_hotdeck() defaults to pure Python, use_r=True for StatMatch
- Keep nnd_hotdeck_using_rpy2 as a compatibility alias
- Lint fixes (variable names, lambda->def, duplicate removed, print->log)

Made-with: Cursor
---
 openfisca_survey_manager/matching.py | 360 +++++++++++++++++++++------
 1 file changed, 289 insertions(+), 71 deletions(-)

diff --git a/openfisca_survey_manager/matching.py b/openfisca_survey_manager/matching.py
index 8fa93756..14ec18ab 100644
--- a/openfisca_survey_manager/matching.py
+++ b/openfisca_survey_manager/matching.py
@@ -1,16 +1,293 @@
+"""Nearest-neighbor donor (NND) hot deck matching — pure Python or R (StatMatch)."""
+
 import logging
 from pathlib import Path
 
+import numpy as np
 import pandas as pd
 
-from openfisca_survey_manager.paths import openfisca_survey_manager_location
+from openfisca_survey_manager.configuration.paths import openfisca_survey_manager_location
 
 log = logging.getLogger(__name__)
 
-
 config_files_directory = Path(openfisca_survey_manager_location)
 
 
+def _normalize_list(x, name="variables"):
+    """Return a list of variable names from str or list."""
+    if x is None:
+        return None
+    if isinstance(x, str):
+        return [x]
+    return list(x)
+
+
+def _nnd_hotdeck_python(
+    receiver: pd.DataFrame,
+    donor: pd.DataFrame,
+    matching_variables: list[str],
+    donor_classes: list[str] | str | None = None,
+    dist_fun: str = "Manhattan",
+    random_state: int | None = None,
+) -> np.ndarray:
+    """
+    Nearest-neighbor donor matching in pure Python (pandas + numpy).
+
+    For each receiver row, finds the donor row that minimizes distance on
+    `matching_variables`. Optionally restricts to donors in the same
+    `donor_classes`. Ties are broken at random.
+
+    Parameters
+    ----------
+    receiver, donor : DataFrame
+        Recipient and donor datasets; must contain `matching_variables`
+        (and `donor_classes` if provided). Matching variables must be numeric
+        for Manhattan/Euclidean.
+ matching_variables : list of str + Column names used to compute distance. + donor_classes : str or list of str, optional + Columns defining donation classes; matching is done only within + the same class. Must not contain missing values. + dist_fun : str + "Manhattan" (default) or "Euclidean". + random_state : int, optional + Seed for breaking ties. + + Returns + ------- + mtc_ids : ndarray of int + Shape (len(receiver), 2): (receiver_index, donor_index) for each row. + Receiver index is 0..n_rec-1, donor index is 0..n_don-1. + """ + rng = np.random.default_rng(random_state) + match_vars = _normalize_list(matching_variables) + don_class = _normalize_list(donor_classes) if donor_classes is not None else None + + for col in match_vars: + if col not in receiver.columns or col not in donor.columns: + raise ValueError(f"Matching variable {col!r} missing in receiver or donor") + if don_class: + for col in don_class: + if col not in receiver.columns or col not in donor.columns: + raise ValueError(f"Donor class variable {col!r} missing in receiver or donor") + + x_rec = receiver[match_vars].astype(float).values + x_don = donor[match_vars].astype(float).values + n_rec, n_don = len(receiver), len(donor) + if n_don == 0: + raise ValueError("Donor dataframe is empty") + + if dist_fun == "Manhattan": + + def dist_fn(donors: np.ndarray, rec_row: np.ndarray) -> np.ndarray: + return np.sum(np.abs(donors - rec_row), axis=1) + + elif dist_fun == "Euclidean": + + def dist_fn(donors: np.ndarray, rec_row: np.ndarray) -> np.ndarray: + return np.sqrt(np.sum((donors - rec_row) ** 2, axis=1)) + + else: + raise ValueError(f"dist_fun must be 'Manhattan' or 'Euclidean', got {dist_fun!r}") + + if don_class is None: + # Global matching: for each receiver row, min distance over all donors + donor_ix = np.zeros(n_rec, dtype=np.intp) + for i in range(n_rec): + d = dist_fn(x_don, x_rec[i]) + min_d = np.min(d) + candidates = np.where(d == min_d)[0] + donor_ix[i] = rng.choice(candidates) + mtc_ids = np.column_stack([np.arange(n_rec), donor_ix]) + return mtc_ids + + # Within-class matching: for each group, match receiver rows to donors in same group + rec_groups = receiver.groupby(don_class, sort=False) + don_groups = donor.groupby(don_class, sort=False) + donor_iloc = np.zeros(n_rec, dtype=np.intp) + + for key, rec_grp in rec_groups: + try: + don_grp = don_groups.get_group(key) + except KeyError: + log.warning("No donors for class %s; receiver rows get donor 0", key) + continue + x_r = rec_grp[match_vars].astype(float).values + x_d = don_grp[match_vars].astype(float).values + n_r, n_d = len(rec_grp), len(don_grp) + if n_d == 0: + continue + # Receiver global ilocs for this group + rec_global_ilocs = receiver.index.get_indexer(rec_grp.index) + for j in range(n_r): + d = dist_fn(x_d, x_r[j]) + min_d = np.min(d) + candidates = np.where(d == min_d)[0] + don_local = rng.choice(candidates) + don_global_iloc = donor.index.get_loc(don_grp.index[don_local]) + donor_iloc[rec_global_ilocs[j]] = don_global_iloc + mtc_ids = np.column_stack([np.arange(n_rec), donor_iloc]) + return mtc_ids + + +def _create_fused_python( + receiver: pd.DataFrame, + donor: pd.DataFrame, + mtc_ids: np.ndarray, + z_variables: list[str], + dup_x: bool = False, + matching_variables: list[str] | None = None, +) -> pd.DataFrame: + """ + Build fused dataset: receiver plus z_variables from matched donors. + + mtc_ids : shape (n_receiver, 2), second column is donor position (integer). 
+ """ + z_vars = _normalize_list(z_variables) + for col in z_vars: + if col not in donor.columns: + raise ValueError(f"z_variable {col!r} not in donor") + fused = receiver.copy() + don_pos = mtc_ids[:, 1] + for col in z_vars: + fused[col] = donor[col].iloc[don_pos].values + if dup_x and matching_variables: + match_vars = _normalize_list(matching_variables) + for col in match_vars: + if col in donor.columns: + fused[col + "_donor"] = donor[col].iloc[don_pos].values + return fused + + +def nnd_hotdeck( + receiver: pd.DataFrame | None = None, + donor: pd.DataFrame | None = None, + matching_variables: str | list[str] | None = None, + z_variables: str | list[str] | None = None, + donor_classes: str | list[str] | None = None, + dist_fun: str = "Manhattan", + use_r: bool = False, + random_state: int | None = None, +): + """ + Nearest-neighbor donor (NND) hot deck: match each receiver row to a donor, + then fuse z_variables from donor into receiver. + + By default uses a **pure Python** implementation (pandas + numpy). + Set `use_r=True` to use R's StatMatch via rpy2 (same API as before). + + Parameters + ---------- + receiver, donor : DataFrame + Recipient and donor datasets. + matching_variables : str or list of str + Columns used to compute distance (must be numeric for Manhattan/Euclidean). + z_variables : str or list of str + Donor columns to copy into the fused dataset. + donor_classes : str or list of str, optional + Match only within the same class (e.g. same sex). + dist_fun : str + "Manhattan" (default) or "Euclidean" (pure Python); R supports more. + use_r : bool + If True, use R StatMatch via rpy2; otherwise use pure Python. + random_state : int, optional + Seed for tie-breaking (pure Python only). + + Returns + ------- + fused_0, fused_1 : DataFrame + fused_0: receiver + z_variables from donor (no duplicate match vars). + fused_1: same with matching variables duplicated as _donor (if applicable). 
+ """ + assert receiver is not None and donor is not None + assert matching_variables is not None and z_variables is not None + match_vars = _normalize_list(matching_variables) + z_vars = _normalize_list(z_variables) + + if use_r: + return _nnd_hotdeck_rpy2( + receiver=receiver, + donor=donor, + matching_variables=match_vars, + z_variables=z_vars, + donor_classes=donor_classes, + ) + + mtc_ids = _nnd_hotdeck_python( + receiver, + donor, + match_vars, + donor_classes=donor_classes, + dist_fun=dist_fun, + random_state=random_state, + ) + fused_0 = _create_fused_python(receiver, donor, mtc_ids, z_vars, dup_x=False) + fused_1 = _create_fused_python(receiver, donor, mtc_ids, z_vars, dup_x=True, matching_variables=match_vars) + return fused_0, fused_1 + + +def _nnd_hotdeck_rpy2(receiver, donor, matching_variables, z_variables, donor_classes=None): + """R (StatMatch) implementation via rpy2; same return as nnd_hotdeck.""" + from rpy2.robjects import pandas2ri + from rpy2.robjects.packages import importr + + pandas2ri.activate() + stat_match = importr("StatMatch") + + if donor_classes is not None: + don_class = _normalize_list(donor_classes) + for col in don_class: + if col not in receiver.columns or col not in donor.columns: + raise ValueError(f"Donor class variable {col!r} missing") + out_nnd = stat_match.NND_hotdeck( + data_rec=receiver, + data_don=donor, + match_vars=pd.Series(matching_variables), + don_class=pd.Series(don_class), + ) + else: + out_nnd = stat_match.NND_hotdeck( + data_rec=receiver, + data_don=donor, + match_vars=pd.Series(matching_variables), + ) + + fused_0 = pandas2ri.ri2py( + stat_match.create_fused(data_rec=receiver, data_don=donor, mtc_ids=out_nnd[0], z_vars=pd.Series(z_variables)) + ) + fused_1 = pandas2ri.ri2py( + stat_match.create_fused( + data_rec=receiver, + data_don=donor, + mtc_ids=out_nnd[0], + z_vars=pd.Series(z_variables), + dup_x=True, + match_vars=pd.Series(matching_variables), + ) + ) + return fused_0, fused_1 + + +def nnd_hotdeck_using_rpy2( + receiver=None, + donor=None, + matching_variables=None, + z_variables=None, + donor_classes=None, +): + """ + NND hot deck via R (StatMatch). Prefer `nnd_hotdeck(..., use_r=True)`. 
+ """ + return nnd_hotdeck( + receiver=receiver, + donor=donor, + matching_variables=matching_variables, + z_variables=z_variables, + donor_classes=donor_classes, + use_r=True, + ) + + def nnd_hotdeck_using_feather(receiver=None, donor=None, matching_variables=None, z_variables=None): """ Not working @@ -69,77 +346,18 @@ def nnd_hotdeck_using_feather(receiver=None, donor=None, matching_variables=None log.debug("%s", r_script) -def nnd_hotdeck_using_rpy2(receiver=None, donor=None, matching_variables=None, z_variables=None, donor_classes=None): - from rpy2.robjects import pandas2ri - from rpy2.robjects.packages import importr - - assert receiver is not None and donor is not None - assert matching_variables is not None - - pandas2ri.activate() - stat_match = importr("StatMatch") - - if isinstance(donor_classes, str): - assert donor_classes in receiver, "Donor class not present in receiver" - assert donor_classes in donor, "Donor class not present in donor" - - try: - if donor_classes: - out_nnd = stat_match.NND_hotdeck( - data_rec=receiver, - data_don=donor, - match_vars=pd.Series(matching_variables), - don_class=pd.Series(donor_classes), - ) - else: - out_nnd = stat_match.NND_hotdeck( - data_rec=receiver, - data_don=donor, - match_vars=pd.Series(matching_variables), - # don_class = pd.Series(donor_classes) - ) - except Exception as e: - log.debug("receiver: %s", receiver) - log.debug("donor: %s", donor) - log.debug("matching_variables: %s", pd.Series(matching_variables)) - log.exception("NND hotdeck failed: %s", e) - raise - - # create synthetic data.set, without the - # duplication of the matching variables - - fused_0 = pandas2ri.ri2py( - stat_match.create_fused(data_rec=receiver, data_don=donor, mtc_ids=out_nnd[0], z_vars=pd.Series(z_variables)) - ) - - # create synthetic data.set, with the "duplication" - # of the matching variables - - fused_1 = pandas2ri.ri2py( - stat_match.create_fused( - data_rec=receiver, - data_don=donor, - mtc_ids=out_nnd[0], - z_vars=pd.Series(z_variables), - dup_x=True, - match_vars=pd.Series(matching_variables), - ) - ) - - return fused_0, fused_1 - - if __name__ == "__main__": log.setLevel(logging.INFO) - - receiver = pd.DataFrame() - donor = pd.DataFrame() - matching_variables = "sexe" - z_variables = "ident" - - nnd_hotdeck_using_feather( + # Minimal example: pure Python NND hot deck (no R required) + np.random.seed(42) + receiver = pd.DataFrame({"x": [1.0, 2.0, 3.0], "y": [10.0, 20.0, 30.0]}) + donor = pd.DataFrame({"x": [1.1, 2.2, 2.9], "y": [10.5, 19.0, 31.0], "ident": [100, 200, 300]}) + fused_0, fused_1 = nnd_hotdeck( receiver=receiver, donor=donor, - matching_variables=matching_variables, - z_variables=z_variables, + matching_variables=["x", "y"], + z_variables="ident", + random_state=42, ) + log.info("fused_0 (receiver + z from donor):\n%s", fused_0) + log.info("fused_1 (with _donor dup):\n%s", fused_1) From 7b78df1689440cdf37be1fe0954e6aca019fc91a Mon Sep 17 00:00:00 2001 From: Mahdi Ben Jelloul Date: Mon, 2 Mar 2026 18:33:29 +0100 Subject: [PATCH 04/18] Retrait de stata_files_to_data_frames MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Suppression de la fonction dans common.misc - Retrait des exports dans common et utils - Mise à jour MIGRATION_IMPORTS et REFACTORING_PLAN Made-with: Cursor --- docs/MIGRATION_IMPORTS.md | 2 +- docs/REFACTORING_PLAN.md | 2 +- openfisca_survey_manager/common/__init__.py | 2 -- openfisca_survey_manager/common/misc.py | 22 --------------------- 
openfisca_survey_manager/utils.py | 2 -- 5 files changed, 2 insertions(+), 28 deletions(-) diff --git a/docs/MIGRATION_IMPORTS.md b/docs/MIGRATION_IMPORTS.md index d7202ee5..8951e919 100644 --- a/docs/MIGRATION_IMPORTS.md +++ b/docs/MIGRATION_IMPORTS.md @@ -34,7 +34,7 @@ Le retrait des ré-exports est une **breaking change** : tout code (interne ou e `config_ini`, `default_config_files_directory`, `is_in_ci`, `openfisca_survey_manager_location`, `private_run_with_data`, `test_config_files_directory`. **Symboles exportés par `utils`** : -- Depuis `common.misc` : `asof`, `do_nothing`, `inflate_parameter_leaf`, `inflate_parameters`, `parameters_asof`, `stata_files_to_data_frames`, `variables_asof`. +- Depuis `common.misc` : `asof`, `do_nothing`, `inflate_parameter_leaf`, `inflate_parameters`, `parameters_asof`, `variables_asof`. - Définis dans `utils.py` : `load_table` (à déplacer vers un module adapté, ex. `core` ou `io`, avant suppression de `utils.py`). --- diff --git a/docs/REFACTORING_PLAN.md b/docs/REFACTORING_PLAN.md index 91deea24..46948fbd 100644 --- a/docs/REFACTORING_PLAN.md +++ b/docs/REFACTORING_PLAN.md @@ -49,7 +49,7 @@ Le déplacement effectif des modules se fera par étapes pour garder la compatib **Réalisé** : - `io/readers.py` : `read_sas`, `read_spss`, `read_dbf` (anciens modules en ré-export). -- `common/misc.py` : helpers sans dépendance survey (`do_nothing`, `inflate_parameters`, `asof`, `parameters_asof`, `variables_asof`, `stata_files_to_data_frames`) ; `utils.py` importe depuis `common.misc` et garde `load_table`. +- `common/misc.py` : helpers sans dépendance survey (`do_nothing`, `inflate_parameters`, `asof`, `parameters_asof`, `variables_asof`) ; `utils.py` importe depuis `common.misc` et garde `load_table`. - **Nettoyage** : `print()` remplacés par `logging` (matching, calmar, scenarios, scripts/build_collection, simulations). Exceptions génériques remplacées par `SurveyManagerError` / `SurveyConfigError` / `SurveyIOError` (survey_collections, tables, simulations, simulation_builder, surveys, scenarios, calmar). - **processing/weights** : `calmar` et `Calibration` déplacés dans `processing/weights/calmar.py` et `processing/weights/calibration.py` ; `calibration.py` et `calmar.py` à la racine sont des ré-exports pour compatibilité. - **processing/cleaning** : `clean_data_frame` déplacé dans `processing/cleaning.py` ; `tables.py` importe depuis `processing.cleaning` (compatibilité conservée). 
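Callers that relied on `stata_files_to_data_frames` can inline the equivalent with
pandas. A minimal sketch (the entity-to-file mapping and the period are hypothetical;
`pd.read_stata` is the same call the removed helper used):

    import pandas as pd
    from openfisca_core import periods

    # Hypothetical inputs, mirroring what stata_files_to_data_frames expected.
    stata_file_by_entity = {"individu": "individu.dta", "menage": "menage.dta"}
    period = periods.period("2024")

    # Equivalent of the removed helper: one data frame per entity, keyed by period.
    input_data_frame_by_entity = {
        entity: pd.read_stata(file_path) for entity, file_path in stata_file_by_entity.items()
    }
    data = {"input_data_frame_by_entity_by_period": {period: input_data_frame_by_entity}}
    variables_from_stata_files = [
        column for df in input_data_frame_by_entity.values() for column in df.columns
    ]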
diff --git a/openfisca_survey_manager/common/__init__.py b/openfisca_survey_manager/common/__init__.py index 5ca887c8..8f6c40c6 100644 --- a/openfisca_survey_manager/common/__init__.py +++ b/openfisca_survey_manager/common/__init__.py @@ -7,7 +7,6 @@ inflate_parameter_leaf, inflate_parameters, parameters_asof, - stata_files_to_data_frames, variables_asof, ) @@ -17,6 +16,5 @@ "inflate_parameter_leaf", "inflate_parameters", "parameters_asof", - "stata_files_to_data_frames", "variables_asof", ] diff --git a/openfisca_survey_manager/common/misc.py b/openfisca_survey_manager/common/misc.py index 69ce1400..b5f48c9d 100644 --- a/openfisca_survey_manager/common/misc.py +++ b/openfisca_survey_manager/common/misc.py @@ -1,9 +1,7 @@ """Shared helpers (no survey collection dependency) to avoid circular imports.""" import logging -from pathlib import Path -import pandas as pd from openfisca_core import periods from openfisca_core.parameters import ParameterNode, Scale @@ -229,23 +227,3 @@ def variables_asof(tax_benefit_system, instant, variables_list=None): if variable.end is not None and periods.instant(variable.end) >= instant: variable.end = None - - -def stata_files_to_data_frames(data, period=None): - assert period is not None - period = periods.period(period) - - stata_file_by_entity = data.get("stata_file_by_entity") - if stata_file_by_entity is None: - return - - variables_from_stata_files = [] - input_data_frame_by_entity_by_period = {} - input_data_frame_by_entity_by_period[periods.period(period)] = input_data_frame_by_entity = {} - for entity, file_path in stata_file_by_entity.items(): - assert Path(file_path).exists(), f"Invalid file path: {file_path}" - entity_data_frame = input_data_frame_by_entity[entity] = pd.read_stata(file_path) - variables_from_stata_files += list(entity_data_frame.columns) - data["input_data_frame_by_entity_by_period"] = input_data_frame_by_entity_by_period - - return variables_from_stata_files diff --git a/openfisca_survey_manager/utils.py b/openfisca_survey_manager/utils.py index 24ba9c9e..78510a7a 100644 --- a/openfisca_survey_manager/utils.py +++ b/openfisca_survey_manager/utils.py @@ -11,7 +11,6 @@ inflate_parameter_leaf, inflate_parameters, parameters_asof, - stata_files_to_data_frames, variables_asof, ) from openfisca_survey_manager.survey_collections import SurveyCollection @@ -25,7 +24,6 @@ "inflate_parameters", "load_table", "parameters_asof", - "stata_files_to_data_frames", "variables_asof", ] From ff37edeba581994d3a90d01be8986f3c1677536c Mon Sep 17 00:00:00 2001 From: Mahdi Ben Jelloul Date: Mon, 2 Mar 2026 19:39:48 +0100 Subject: [PATCH 05/18] refactor: add policy package + deprecate legacy modules - Add openfisca_survey_manager.policy (simulations, simulation_builder, aggregates) - Keep root modules as DeprecationWarning placeholders re-exporting from policy - Move policy-related tests to policy/tests and update imports - Add policy.legislation_asof and deprecate common.misc / utils helpers Made-with: Cursor --- docs/REFACTORING_PLAN.md | 5 +- openfisca_survey_manager/aggregates.py | 565 +----- openfisca_survey_manager/common/misc.py | 262 +-- openfisca_survey_manager/policy/__init__.py | 53 + openfisca_survey_manager/policy/aggregates.py | 554 ++++++ .../policy/legislation_asof.py | 229 +++ .../policy/simulation_builder.py | 291 ++++ .../policy/simulations.py | 1468 ++++++++++++++++ .../policy/tests/__init__.py | 1 + .../{ => policy}/tests/test_aggregates.py | 2 +- .../tests/test_compute_aggregate.py | 0 .../tests/test_compute_pivot_table.py | 0 
.../tests/test_compute_winners_losers.py | 2 +- .../tests/test_create_data_frame_by_entity.py | 0 .../tests/test_marginal_tax_rate.py | 0 .../tests/test_summarize_variables.py | 6 +- .../simulation_builder.py | 316 +--- openfisca_survey_manager/simulations.py | 1516 +---------------- openfisca_survey_manager/tests/conftest.py | 26 + .../tests/test_coverage_boost.py | 14 +- .../tests/test_legislation_inflator.py | 2 +- .../tests/test_tax_benefit_system_asof.py | 2 +- openfisca_survey_manager/utils.py | 12 +- uv.lock | 2 +- 24 files changed, 2773 insertions(+), 2555 deletions(-) create mode 100644 openfisca_survey_manager/policy/__init__.py create mode 100644 openfisca_survey_manager/policy/aggregates.py create mode 100644 openfisca_survey_manager/policy/legislation_asof.py create mode 100644 openfisca_survey_manager/policy/simulation_builder.py create mode 100644 openfisca_survey_manager/policy/simulations.py create mode 100644 openfisca_survey_manager/policy/tests/__init__.py rename openfisca_survey_manager/{ => policy}/tests/test_aggregates.py (96%) rename openfisca_survey_manager/{ => policy}/tests/test_compute_aggregate.py (100%) rename openfisca_survey_manager/{ => policy}/tests/test_compute_pivot_table.py (100%) rename openfisca_survey_manager/{ => policy}/tests/test_compute_winners_losers.py (98%) rename openfisca_survey_manager/{ => policy}/tests/test_create_data_frame_by_entity.py (100%) rename openfisca_survey_manager/{ => policy}/tests/test_marginal_tax_rate.py (100%) rename openfisca_survey_manager/{ => policy}/tests/test_summarize_variables.py (95%) diff --git a/docs/REFACTORING_PLAN.md b/docs/REFACTORING_PLAN.md index 46948fbd..bde3cc61 100644 --- a/docs/REFACTORING_PLAN.md +++ b/docs/REFACTORING_PLAN.md @@ -35,9 +35,10 @@ openfisca_survey_manager/ │ └── misc.py # helpers partagés (éviter imports circulaires) │ ├── scenarios/ # inchangé pour l’instant +├── policy/ # simulations, simulation_builder, aggregates (à terme autre paquet) ├── scripts/ ├── tests/ -└── ... (simulations, aggregates, etc. à placer selon responsabilité) +└── ... ``` **État actuel** : les dossiers suivants existent avec des `__init__.py` de préparation (pas de code déplacé encore) : @@ -53,6 +54,8 @@ Le déplacement effectif des modules se fera par étapes pour garder la compatib - **Nettoyage** : `print()` remplacés par `logging` (matching, calmar, scenarios, scripts/build_collection, simulations). Exceptions génériques remplacées par `SurveyManagerError` / `SurveyConfigError` / `SurveyIOError` (survey_collections, tables, simulations, simulation_builder, surveys, scenarios, calmar). - **processing/weights** : `calmar` et `Calibration` déplacés dans `processing/weights/calmar.py` et `processing/weights/calibration.py` ; `calibration.py` et `calmar.py` à la racine sont des ré-exports pour compatibilité. - **processing/cleaning** : `clean_data_frame` déplacé dans `processing/cleaning.py` ; `tables.py` importe depuis `processing.cleaning` (compatibilité conservée). +- **policy/** : répertoire créé pour `simulations`, `simulation_builder`, `aggregates` (à terme déplacés dans un paquet dédié). Les modules à la racine (`simulations.py`, `simulation_builder.py`, `aggregates.py`) sont des placeholders avec `DeprecationWarning` qui ré-exportent depuis `policy`. +- **policy/tests/** : tests concernant le paquet policy (test_aggregates, test_compute_aggregate, test_compute_pivot_table, test_compute_winners_losers, test_create_data_frame_by_entity, test_marginal_tax_rate, test_summarize_variables). 
Ils importent depuis `openfisca_survey_manager.policy` et utilisent `create_randomly_initialized_survey_scenario` depuis `openfisca_survey_manager.tests.test_scenario`. --- diff --git a/openfisca_survey_manager/aggregates.py b/openfisca_survey_manager/aggregates.py index 7d8059c4..07629736 100644 --- a/openfisca_survey_manager/aggregates.py +++ b/openfisca_survey_manager/aggregates.py @@ -1,554 +1,21 @@ -import collections -import logging -from datetime import datetime -from pathlib import Path -from typing import Optional +"""Re-export for backward compatibility. -import numpy as np -import pandas as pd +Prefer: from openfisca_survey_manager.policy import ... +or from openfisca_survey_manager.policy.aggregates import ... +""" -log = logging.getLogger(__name__) +import warnings +from openfisca_survey_manager.policy.aggregates import AbstractAggregates -# TODO: -# * Localisation +warnings.warn( + "openfisca_survey_manager.aggregates is deprecated and will be removed in a future version. " + "Prefer: from openfisca_survey_manager.policy import ... " + "or from openfisca_survey_manager.policy.aggregates import ...", + DeprecationWarning, + stacklevel=2, +) - -class AbstractAggregates: - aggregate_variables = None - amount_unit = 1e6 - currency = None - base_data_frame = None - baseline_simulation = None - beneficiaries_unit = 1e3 - filter_by = None - labels = None - simulation = None - survey_scenario = None - totals_df = None - - def __init__( - self, - survey_scenario=None, - absolute_minimal_detected_variation=0, - relative_minimal_detected_variation=0, - observations_threshold=0, - ): - assert survey_scenario is not None - - self.period = survey_scenario.period - self.survey_scenario = survey_scenario - assert len(survey_scenario.simulations) >= 1 - - self.simulations = survey_scenario.simulations - self.absolute_minimal_detected_variation = absolute_minimal_detected_variation - self.relative_minimal_detected_variation = relative_minimal_detected_variation - self.observations_threshold = observations_threshold - - for name in survey_scenario.tax_benefit_systems: - assert survey_scenario.simulations[name] is not None - - self.weight_variable_by_entity = survey_scenario.weight_variable_by_entity - if self.labels is None: - amount_unit_str = f"({self.amount_unit} {self.currency})" - beneficiaries_unit_str = f"({self.beneficiaries_unit})" - self.labels = collections.OrderedDict( - ( - ("label", "Mesure"), - ("entity", "Entité"), - ("reform_amount", "Dépenses\n" + amount_unit_str), - ("reform_beneficiaries", "Bénéficiaires\n(milliers)"), - ("baseline_amount", "Dépenses initiales\n" + amount_unit_str), - ("baseline_beneficiaries", "Bénéficiaires\ninitiaux\n" + beneficiaries_unit_str), - ("actual_amount", "Dépenses\nréelles\n" + amount_unit_str), - ("actual_beneficiaries", "Bénéficiaires\nréels\n" + beneficiaries_unit_str), - ("absolute_difference_amount", "Diff. absolue\nDépenses\n" + amount_unit_str), - ("absolute_difference_beneficiaries", "Diff absolue\nBénéficiaires\n" + beneficiaries_unit_str), - ("relative_difference_amount", "Diff. relative\nDépenses"), - ("relative_difference_beneficiaries", "Diff. relative\nBénéficiaires"), - ("winners", "Gagnants"), - ("losers", "Perdants"), - ("neutral", "Neutres"), - ) - ) - - def compute_aggregates(self, use_baseline: bool = True, reform: bool = True, actual: bool = True) -> pd.DataFrame: - """ - Compute aggregate amounts. - - Args: - use_baseline (bool, optional): _description_. Defaults to True. - reform (bool, optional): _description_. 
Defaults to True. - actual (bool, optional): _description_. Defaults to True. - - Returns: - pd.DataFrame: The aggregates - """ - filter_by = self.filter_by - if actual: - self.totals_df = self.load_actual_data(period=self.period) - - simulation_types = [] - if use_baseline: - assert self.simulations["baseline"] is not None - simulation_types.append("baseline") - if reform: - simulation_types.append("reform") - if actual: - simulation_types.append("actual") - - data_frame_by_simulation_type = {} - - for simulation_type in simulation_types: - if simulation_type == "actual": - data_frame_by_simulation_type["actual"] = self.totals_df.copy() if self.totals_df is not None else None - else: - use_baseline = simulation_type != "reform" - data_frame = pd.DataFrame() - assert self.aggregate_variables is not None - for variable in self.aggregate_variables: - variable_data_frame = self.compute_variable_aggregates( - variable, use_baseline=use_baseline, filter_by=filter_by - ) - data_frame = pd.concat((data_frame, variable_data_frame)) - - data_frame.rename( - columns={ - "amount": f"{simulation_type}_amount", - "beneficiaries": f"{simulation_type}_beneficiaries", - }, - inplace=True, - ) - data_frame_by_simulation_type[simulation_type] = data_frame - - if use_baseline and reform: - del data_frame_by_simulation_type["reform"]["entity"] - del data_frame_by_simulation_type["reform"]["label"] - - self.base_data_frame = pd.concat( - list(data_frame_by_simulation_type.values()), - axis=1, - sort=True, - ).loc[list(self.aggregate_variables)] - return self.base_data_frame - - def compute_difference( - self, - target: str = "baseline", - default: str = "actual", - amount: bool = True, - beneficiaries: bool = True, - absolute: bool = True, - relative: bool = True, - ) -> pd.DataFrame: - """ - Compute and add relative and/or absolute differences to the data_frame. - - Args: - target (str, optional): Target simulation. Defaults to "baseline". - default (str, optional): Default simulation. Defaults to 'actual'. - amount (bool, optional): Provide amounts. Defaults to True. - beneficiaries (bool, optional): Provide beneficiaries. Defaults to True. - absolute (bool, optional): Return absolute values. Defaults to True. - relative (bool, optional): Return relative values. Defaults to True. - - Returns: - pd.DataFrame: The differences - """ - assert relative or absolute - assert amount or beneficiaries - base_data_frame = self.base_data_frame if self.base_data_frame is not None else self.compute_aggregates() - - difference_data_frame = base_data_frame[["label", "entity"]].copy() - # Remove duplicates - difference_data_frame = difference_data_frame.loc[:, ~difference_data_frame.columns.duplicated()].copy() - - quantities = [] - quantities += ["amount"] if amount else None - quantities += ["beneficiaries"] if beneficiaries else None - - for quantity in quantities: - difference_data_frame[f"absolute_difference_{quantity}"] = ( - abs(base_data_frame[f"{target}_{quantity}"]) - base_data_frame[f"{default}_{quantity}"] - ) - difference_data_frame[f"relative_difference_{quantity}"] = ( - abs(base_data_frame[f"{target}_{quantity}"]) - base_data_frame[f"{default}_{quantity}"] - ) / abs(base_data_frame[f"{default}_{quantity}"]) - - return difference_data_frame - - def compute_variable_aggregates( - self, variable: str, use_baseline: bool = False, filter_by: Optional[str] = None - ) -> pd.DataFrame: - """ - Return aggregate spending, and number of beneficiaries for the relevant entity level. 
- - Args: - variable (str): Name of the variable aggregated according to its entity - use_baseline (bool, optional): Use the baseline or the reform or the only avalilable - simulation when no reform (default). Defaults to False. - filter_by (str, optional): The variable to filter by. Defaults to None. - - Returns: - pd.DataFrame: The amount and beneficiaries for the variable - """ - if len(self.simulations) == 1: - simulation = next(iter(self.simulations.values())) - elif use_baseline: - simulation = self.simulations["baseline"] - else: - simulation = self.simulations["reform"] - - variables = simulation.tax_benefit_system.variables - variable_instance = variables.get(variable) - - if variable_instance is None: - msg = f"Variable {variable} is not available" - if use_baseline: - msg += " in baseline simulation" - log.info(msg) - return pd.DataFrame( - data={ - "label": variable, - "entity": "Unknown entity", - "amount": 0, - "beneficiaries": 0, - }, - index=[variable], - ) - - entity_key = variable_instance.entity.key - - if self.weight_variable_by_entity is not None: - weight = self.weight_variable_by_entity[entity_key] - assert weight in variables, f"{weight} not a variable of the tax_benefit_system" - weight_array = simulation.calculate(weight, period=self.period).astype("float") - assert not np.isnan(np.sum(weight_array)), f"The are some NaN in weights {weight} for entity {entity_key}" - # amounts and beneficiaries from current data and default data if exists - # Build weights for each entity - else: - log.debug(f"No weight variable defined for entity {entity_key}, using 1 as weight.") - weight = "weight" - weight_array = 1 - - variable_array = simulation.calculate_add(variable, period=self.period).astype("float") - assert np.isfinite(variable_array).all(), ( - f"The are non finite values in variable {variable} for entity {entity_key}" - ) - data = pd.DataFrame( - { - variable: variable_array, - weight: weight_array, - } - ) - if filter_by: - filter_dummy_variable = ( - filter_by if filter_by in variables else self.survey_scenario.filtering_variable_by_entity[entity_key] - ) - filter_dummy_array = simulation.calculate(filter_dummy_variable, period=self.period) - - else: - filter_dummy_array = 1 - - assert np.isfinite(filter_dummy_array).all(), ( - f"The are non finite values in variable {filter_dummy_variable} for entity {entity_key}" - ) - - amount = int((data[variable] * data[weight] * filter_dummy_array / self.amount_unit).sum()) - beneficiaries = int( - ((data[variable] != 0) * data[weight] * filter_dummy_array / self.beneficiaries_unit).sum() - ) - variable_data_frame = pd.DataFrame( - data={ - "label": variables[variable].label, - "entity": variables[variable].entity.key, - "amount": amount, - "beneficiaries": beneficiaries, - }, - index=[variable], - ) - - return variable_data_frame - - def create_description(self): - """Create a description dataframe.""" - now = datetime.now() - return pd.DataFrame( - [ - "OpenFisca", - "Calculé le {} à {}".format(now.strftime("%d-%m-%Y"), now.strftime("%H:%M")), - f"Système socio-fiscal au {self.simulation.period.start.year}", - f"Données d'enquêtes de l'année {self.data_year!s}", - ] - ) - - def to_csv( - self, - path=None, - absolute=True, - amount=True, - beneficiaries=True, - default="actual", - relative=True, - target="reform", - ): - """Saves the table to csv.""" - assert path is not None - - if Path(path).is_dir(): - now = datetime.now() - file_path = Path(path) / "Aggregates_%s.%s" % (now.strftime("%d-%m-%Y"), ".csv") - else: - 
file_path = path - - df = self.get_data_frame( - absolute=absolute, - amount=amount, - beneficiaries=beneficiaries, - default=default, - relative=relative, - target=target, - ) - df.to_csv(file_path, index=False, header=True) - - def to_excel( - self, - path=None, - absolute=True, - amount=True, - beneficiaries=True, - default="actual", - relative=True, - target="reform", - ): - """Save the table to excel.""" - assert path is not None - - if Path(path).is_dir(): - now = datetime.now() - file_path = Path(path) / "Aggregates_%s.%s" % (now.strftime("%d-%m-%Y"), ".xlsx") - else: - file_path = path - - df = self.get_data_frame( - absolute=absolute, - amount=amount, - beneficiaries=beneficiaries, - default=default, - relative=relative, - target=target, - ) - writer = pd.ExcelWriter(file_path) - df.to_excel(writer, "aggregates", index=False, header=True) - descr = self.create_description() - descr.to_excel(writer, "description", index=False, header=False) - writer.save() - - def to_html( - self, - path=None, - absolute=True, - amount=True, - beneficiaries=True, - default="actual", - relative=True, - target="reform", - ): - """Get or saves the table to html format.""" - df = self.get_data_frame( - absolute=absolute, - amount=amount, - beneficiaries=beneficiaries, - default=default, - relative=relative, - target=target, - ) - - if path is not None and Path(path).is_dir(): - now = datetime.now() - file_path = Path(path) / "Aggregates_%s.%s" % (now.strftime("%d-%m-%Y"), ".html") - else: - file_path = path - - if file_path is not None: - with Path(file_path).open("w") as html_file: - df.to_html(html_file) - return df.to_html() - - def to_markdown( - self, - path=None, - absolute=True, - amount=True, - beneficiaries=True, - default="actual", - relative=True, - target="reform", - ): - """Get or saves the table to markdown format.""" - df = self.get_data_frame( - absolute=absolute, - amount=amount, - beneficiaries=beneficiaries, - default=default, - relative=relative, - target=target, - ) - - if path is not None and Path(path).is_dir(): - now = datetime.now() - file_path = Path(path) / "Aggregates_%s.%s" % (now.strftime("%d-%m-%Y"), ".md") - else: - file_path = path - - if file_path is not None: - with Path(file_path).open("w") as markdown_file: - df.to_markdown(markdown_file) - - return df.to_markdown() - - def get_calibration_coeffcient(self, target: str = "reform") -> pd.DataFrame: - df = self.compute_aggregates( - actual=True, - use_baseline=target == "baseline", - reform=target == "reform", - ) - return df[f"{target}_amount"] / df["actual_amount"] - - def get_data_frame( - self, - absolute: bool = True, - amount: bool = True, - beneficiaries: bool = True, - default: str = "actual", - formatting: bool = True, - relative: bool = True, - target: str = "reform", - ignore_labels: bool = False, - ): - assert target is None or target in ["reform", "baseline"] - - columns = self.labels.keys() - if (absolute or relative) and (target != default): - difference_data_frame = self.compute_difference( - absolute=absolute, - amount=amount, - beneficiaries=beneficiaries, - default=default, - relative=relative, - target=target, - ) - else: - difference_data_frame = None - - # Removing unwanted columns - if amount is False: - columns = [column for column in columns if "amount" not in columns] - - if beneficiaries is False: - columns = [column for column in columns if "beneficiaries" not in column] - - if absolute is False: - columns = [column for column in columns if "absolute" not in column] - - if relative is False: 
- columns = [column for column in columns if "relative" not in column] - - for simulation_type in ["reform", "baseline", "actual"]: - if simulation_type not in [target, default]: - columns = [column for column in columns if simulation_type not in column] - - aggregates_data_frame = self.compute_aggregates( - actual="actual" in [target, default], - use_baseline="baseline" in [target, default], - reform="reform" in [target, default], - ) - - if "reform_amount" in aggregates_data_frame.columns and "baseline_amount" in aggregates_data_frame.columns: - winners_losers_df = self.compute_all_winners_losers(filter_by=self.filter_by) - aggregates_data_frame = aggregates_data_frame.join(winners_losers_df) - - ordered_columns = [ - "label", - "entity", - "reform_amount", - "baseline_amount", - "actual_amount", - "absolute_difference_amount", - "relative_difference_amount", - "reform_beneficiaries", - "baseline_beneficiaries", - "actual_beneficiaries", - "absolute_difference_beneficiaries", - "relative_difference_beneficiaries", - "winners", - "losers", - "neutral", - ] - if difference_data_frame is not None: - # Remove eventual duplication - difference_data_frame = difference_data_frame.loc[:, ~difference_data_frame.columns.duplicated()].copy() - aggregates_data_frame = aggregates_data_frame.loc[:, ~aggregates_data_frame.columns.duplicated()].copy() - merged_df = aggregates_data_frame.merge(difference_data_frame, how="left") - df = merged_df[[c for c in columns if c in merged_df.columns]] - else: - columns = [column for column in columns if column in aggregates_data_frame.columns] - df = aggregates_data_frame[columns] - - df = df.reindex(columns=ordered_columns).dropna(axis=1, how="all") - - if formatting: - relative_columns = [column for column in df.columns if "relative" in column] - df[relative_columns] = df[relative_columns].map(lambda x: f"{x:.2%}" if str(x) != "nan" else "nan") - for column in df.columns: - if issubclass(np.dtype(df[column]).type, np.number): - df[column] = df[column].apply(lambda x: f"{round(x):d}" if str(x) != "nan" else "nan") - - if not ignore_labels: - df = df.rename(columns=self.labels) - - return df - - def load_actual_data(self, period=None): - pass - - def compute_winners_losers(self, variable: str, filter_by: Optional[str] = None): - if "reform" not in self.simulations or "baseline" not in self.simulations: - log.warning("Cannot compute winners and losers without a reform and a baseline simulation.") - return pd.DataFrame() - - reform_simulation = self.simulations["reform"] - baseline_simulation = self.simulations["baseline"] - - variable_instance = reform_simulation.tax_benefit_system.variables.get(variable) - if variable_instance is None: - log.warning(f"Variable {variable} not found in reform simulation.") - return pd.DataFrame() - - stats = reform_simulation.compute_winners_losers( - baseline_simulation=baseline_simulation, - variable=variable, - period=self.period, - filter_by=filter_by, - filtering_variable_by_entity=self.survey_scenario.filtering_variable_by_entity, - absolute_minimal_detected_variation=self.absolute_minimal_detected_variation, - relative_minimal_detected_variation=self.relative_minimal_detected_variation, - observations_threshold=self.observations_threshold, - ) - - winners_losers_df = pd.DataFrame( - { - "winners": [stats["above_after"]], - "losers": [stats["lower_after"]], - "neutral": [stats["neutral"]], - }, - index=[variable], - ) - return winners_losers_df - - def compute_all_winners_losers(self, filter_by: Optional[str] = None): - 
all_winners_losers = pd.DataFrame() - for variable in self.aggregate_variables: - winners_losers = self.compute_winners_losers(variable, filter_by=filter_by) - all_winners_losers = pd.concat([all_winners_losers, winners_losers]) - return all_winners_losers +__all__ = [ + "AbstractAggregates", +] diff --git a/openfisca_survey_manager/common/misc.py b/openfisca_survey_manager/common/misc.py index b5f48c9d..d14d95cf 100644 --- a/openfisca_survey_manager/common/misc.py +++ b/openfisca_survey_manager/common/misc.py @@ -1,229 +1,33 @@ -"""Shared helpers (no survey collection dependency) to avoid circular imports.""" - -import logging - -from openfisca_core import periods -from openfisca_core.parameters import ParameterNode, Scale - -log = logging.getLogger(__name__) - - -def do_nothing(*args, **kwargs): - return None - - -def inflate_parameters( - parameters, - inflator, - base_year, - last_year=None, - ignore_missing_units=False, - start_instant=None, - round_ndigits=2, -): - """ - Inflate a Parameter node or a Parameter leaf for the years between base_year and last_year. - - ::parameters:: a Parameter node or a Parameter leaf - ::inflator:: rate used to inflate the parameter. The rate is unique for all the years - ::base_year:: base year of the parameter - ::last_year:: last year of inflation - ::ignore_missing_units:: if True, a parameter leaf without unit in metadata will not be inflated - ::start_instant:: Instant of the year when the update should start, if None will be January 1st - ::round_ndigits:: Number of digits to keep in the rounded result - """ - if (last_year is not None) and (last_year > base_year + 1): - for year in range(base_year + 1, last_year + 1): - inflate_parameters( - parameters, - inflator, - year - 1, - last_year=year, - ignore_missing_units=ignore_missing_units, - start_instant=start_instant, - round_ndigits=round_ndigits, - ) - else: - if last_year is None: - last_year = base_year + 1 - - assert last_year == base_year + 1 - - if isinstance(parameters, ParameterNode): - for sub_parameter in parameters.children.values(): - inflate_parameters( - sub_parameter, - inflator, - base_year, - last_year, - ignore_missing_units=ignore_missing_units, - start_instant=start_instant, - round_ndigits=round_ndigits, - ) - else: - acceptable_units = [ - "rate_unit", - "threshold_unit", - "unit", - ] - if ignore_missing_units: - if not hasattr(parameters, "metadata"): - return - if not bool(set(parameters.metadata.keys()) & set(acceptable_units)): - return - assert hasattr(parameters, "metadata"), f"{parameters.name} doesn't have metadata" - unit_types = set(parameters.metadata.keys()).intersection(set(acceptable_units)) - assert unit_types, ( - f"No admissible unit in metadata for parameter {parameters.name}. You may consider using " - "the option 'ignore_missing_units' from the inflate_parameters() function." 
- ) - if len(unit_types) > 1: - assert unit_types == {"threshold_unit", "rate_unit"}, ( - f"Too much admissible units in metadata for parameter {parameters.name}" - ) - unit_by_type = {unit_type: parameters.metadata[unit_type] for unit_type in unit_types} - for unit_type in unit_by_type: - if parameters.metadata[unit_type].startswith("currency"): - inflate_parameter_leaf( - parameters, - base_year, - inflator, - unit_type=unit_type, - start_instant=start_instant, - round_ndigits=round_ndigits, - ) - - -def inflate_parameter_leaf(sub_parameter, base_year, inflator, unit_type="unit", start_instant=None, round_ndigits=2): - """ - Inflate a Parameter leaf according to unit type for the year after base_year. - - ::sub_parameter:: a Parameter leaf - ::base_year:: base year of the parameter - ::inflator:: rate used to inflate the parameter - ::unit_type:: unit supposed by default. Other admissible unit types are threshold_unit and rate_unit - ::start_instant:: Instant of the year when the update should start, if None will be January 1st - ::round_ndigits:: Number of digits to keep in the rounded result - """ - if isinstance(sub_parameter, Scale): - if unit_type == "threshold_unit": - for bracket in sub_parameter.brackets: - threshold = bracket.children["threshold"] - inflate_parameter_leaf( - threshold, base_year, inflator, start_instant=start_instant, round_ndigits=round_ndigits - ) - return - else: - kept_instants_str = [ - parameter_at_instant.instant_str - for parameter_at_instant in sub_parameter.values_list - if periods.instant(parameter_at_instant.instant_str).year <= base_year - ] - if not kept_instants_str: - return - - last_admissible_instant_str = max(kept_instants_str) - sub_parameter.update(start=last_admissible_instant_str, value=sub_parameter(last_admissible_instant_str)) - if start_instant is not None: - assert periods.instant(start_instant).year == (base_year + 1), ( - "Year of start_instant should be base_year + 1" - ) - value = ( - round(sub_parameter(f"{base_year}-12-31") * (1 + inflator), round_ndigits) - if sub_parameter(f"{base_year}-12-31") is not None - else None - ) - sub_parameter.update( - start=start_instant, - value=value, - ) - else: - restricted_to_base_year_value_list = [ - parameter_at_instant - for parameter_at_instant in sub_parameter.values_list - if periods.instant(parameter_at_instant.instant_str).year == base_year - ] - if restricted_to_base_year_value_list: - for parameter_at_instant in reversed(restricted_to_base_year_value_list): - if parameter_at_instant.instant_str.startswith(str(base_year)): - value = ( - round(parameter_at_instant.value * (1 + inflator), round_ndigits) - if parameter_at_instant.value is not None - else None - ) - sub_parameter.update( - start=parameter_at_instant.instant_str.replace(str(base_year), str(base_year + 1)), - value=value, - ) - else: - value = ( - round(sub_parameter(f"{base_year}-12-31") * (1 + inflator), round_ndigits) - if sub_parameter(f"{base_year}-12-31") is not None - else None - ) - sub_parameter.update( - start=f"{base_year + 1}-01-01", - value=value, - ) - - -def asof(tax_benefit_system, instant): - parameters = tax_benefit_system.parameters - parameters_asof(parameters, instant) - variables_asof(tax_benefit_system, instant) - - -def leaf_asof(sub_parameter, instant): - kept_instants_str = [ - parameter_at_instant.instant_str - for parameter_at_instant in sub_parameter.values_list - if periods.instant(parameter_at_instant.instant_str) <= instant - ] - if not kept_instants_str: - sub_parameter.values_list = [] - 
return - - last_admissible_instant_str = max(kept_instants_str) - sub_parameter.update(start=last_admissible_instant_str, value=sub_parameter(last_admissible_instant_str)) - - -def parameters_asof(parameters, instant): - if isinstance(instant, str): - instant = periods.instant(instant) - assert isinstance(instant, periods.Instant) - - for sub_parameter in parameters.children.values(): - if isinstance(sub_parameter, ParameterNode): - parameters_asof(sub_parameter, instant) - else: - if isinstance(sub_parameter, Scale): - for bracket in sub_parameter.brackets: - threshold = bracket.children["threshold"] - rate = bracket.children.get("rate") - amount = bracket.children.get("amount") - leaf_asof(threshold, instant) - if rate: - leaf_asof(rate, instant) - if amount: - leaf_asof(amount, instant) - else: - leaf_asof(sub_parameter, instant) - - -def variables_asof(tax_benefit_system, instant, variables_list=None): - if isinstance(instant, str): - instant = periods.instant(instant) - assert isinstance(instant, periods.Instant) - - if variables_list is None: - variables_list = tax_benefit_system.variables.keys() - - for variable_name, variable in tax_benefit_system.variables.items(): - if variable_name in variables_list: - formulas = variable.formulas - for instant_str in list(formulas.keys()): - if periods.instant(instant_str) > instant: - del formulas[instant_str] - - if variable.end is not None and periods.instant(variable.end) >= instant: - variable.end = None +"""Backward-compatibility wrapper for legislation helpers. + +Deprecated: use ``openfisca_survey_manager.policy.legislation_asof`` instead. +""" + +import warnings + +warnings.warn( + "openfisca_survey_manager.common.misc is deprecated and will be removed in a future version. " + "Prefer: from openfisca_survey_manager.policy.legislation_asof import ...", + DeprecationWarning, + stacklevel=2, +) + +from openfisca_survey_manager.policy.legislation_asof import ( # noqa: E402 + asof, + do_nothing, + inflate_parameter_leaf, + inflate_parameters, + leaf_asof, + parameters_asof, + variables_asof, +) + +__all__ = [ + "asof", + "do_nothing", + "inflate_parameter_leaf", + "inflate_parameters", + "leaf_asof", + "parameters_asof", + "variables_asof", +] diff --git a/openfisca_survey_manager/policy/__init__.py b/openfisca_survey_manager/policy/__init__.py new file mode 100644 index 00000000..60888238 --- /dev/null +++ b/openfisca_survey_manager/policy/__init__.py @@ -0,0 +1,53 @@ +# Policy-related modules (simulations, simulation_builder, aggregates). +# À terme ces briques pourront être déplacées dans un paquet dédié. +# Voir docs/REFACTORING_PLAN.md. 
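+#
+# A typical import once callers target this package directly (both names are
+# re-exported just below), for example:
+#
+#     from openfisca_survey_manager.policy import Simulation, compute_aggregate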
+ +from openfisca_survey_manager.policy.aggregates import AbstractAggregates +from openfisca_survey_manager.policy.simulation_builder import ( + SimulationBuilder, + diagnose_variable_mismatch, +) +from openfisca_survey_manager.policy.simulations import ( + SecretViolationError, + Simulation, + adaptative_calculate_variable, + assert_variables_in_same_entity, + compute_aggregate, + compute_pivot_table, + compute_quantiles, + compute_winners_losers, + create_data_frame_by_entity, + get_words, + inflate, + init_entity_data, + init_simulation, + init_variable_in_entity, + new_from_tax_benefit_system, + print_memory_usage, + set_weight_variable_by_entity, + summarize_variable, +) + +__all__ = [ + "AbstractAggregates", + "Simulation", + "SimulationBuilder", + "SecretViolationError", + "adaptative_calculate_variable", + "assert_variables_in_same_entity", + "compute_aggregate", + "compute_pivot_table", + "compute_quantiles", + "compute_winners_losers", + "create_data_frame_by_entity", + "diagnose_variable_mismatch", + "get_words", + "inflate", + "init_entity_data", + "init_simulation", + "init_variable_in_entity", + "new_from_tax_benefit_system", + "print_memory_usage", + "set_weight_variable_by_entity", + "summarize_variable", +] diff --git a/openfisca_survey_manager/policy/aggregates.py b/openfisca_survey_manager/policy/aggregates.py new file mode 100644 index 00000000..7d8059c4 --- /dev/null +++ b/openfisca_survey_manager/policy/aggregates.py @@ -0,0 +1,554 @@ +import collections +import logging +from datetime import datetime +from pathlib import Path +from typing import Optional + +import numpy as np +import pandas as pd + +log = logging.getLogger(__name__) + + +# TODO: +# * Localisation + + +class AbstractAggregates: + aggregate_variables = None + amount_unit = 1e6 + currency = None + base_data_frame = None + baseline_simulation = None + beneficiaries_unit = 1e3 + filter_by = None + labels = None + simulation = None + survey_scenario = None + totals_df = None + + def __init__( + self, + survey_scenario=None, + absolute_minimal_detected_variation=0, + relative_minimal_detected_variation=0, + observations_threshold=0, + ): + assert survey_scenario is not None + + self.period = survey_scenario.period + self.survey_scenario = survey_scenario + assert len(survey_scenario.simulations) >= 1 + + self.simulations = survey_scenario.simulations + self.absolute_minimal_detected_variation = absolute_minimal_detected_variation + self.relative_minimal_detected_variation = relative_minimal_detected_variation + self.observations_threshold = observations_threshold + + for name in survey_scenario.tax_benefit_systems: + assert survey_scenario.simulations[name] is not None + + self.weight_variable_by_entity = survey_scenario.weight_variable_by_entity + if self.labels is None: + amount_unit_str = f"({self.amount_unit} {self.currency})" + beneficiaries_unit_str = f"({self.beneficiaries_unit})" + self.labels = collections.OrderedDict( + ( + ("label", "Mesure"), + ("entity", "Entité"), + ("reform_amount", "Dépenses\n" + amount_unit_str), + ("reform_beneficiaries", "Bénéficiaires\n(milliers)"), + ("baseline_amount", "Dépenses initiales\n" + amount_unit_str), + ("baseline_beneficiaries", "Bénéficiaires\ninitiaux\n" + beneficiaries_unit_str), + ("actual_amount", "Dépenses\nréelles\n" + amount_unit_str), + ("actual_beneficiaries", "Bénéficiaires\nréels\n" + beneficiaries_unit_str), + ("absolute_difference_amount", "Diff. 
absolue\nDépenses\n" + amount_unit_str), + ("absolute_difference_beneficiaries", "Diff absolue\nBénéficiaires\n" + beneficiaries_unit_str), + ("relative_difference_amount", "Diff. relative\nDépenses"), + ("relative_difference_beneficiaries", "Diff. relative\nBénéficiaires"), + ("winners", "Gagnants"), + ("losers", "Perdants"), + ("neutral", "Neutres"), + ) + ) + + def compute_aggregates(self, use_baseline: bool = True, reform: bool = True, actual: bool = True) -> pd.DataFrame: + """ + Compute aggregate amounts. + + Args: + use_baseline (bool, optional): _description_. Defaults to True. + reform (bool, optional): _description_. Defaults to True. + actual (bool, optional): _description_. Defaults to True. + + Returns: + pd.DataFrame: The aggregates + """ + filter_by = self.filter_by + if actual: + self.totals_df = self.load_actual_data(period=self.period) + + simulation_types = [] + if use_baseline: + assert self.simulations["baseline"] is not None + simulation_types.append("baseline") + if reform: + simulation_types.append("reform") + if actual: + simulation_types.append("actual") + + data_frame_by_simulation_type = {} + + for simulation_type in simulation_types: + if simulation_type == "actual": + data_frame_by_simulation_type["actual"] = self.totals_df.copy() if self.totals_df is not None else None + else: + use_baseline = simulation_type != "reform" + data_frame = pd.DataFrame() + assert self.aggregate_variables is not None + for variable in self.aggregate_variables: + variable_data_frame = self.compute_variable_aggregates( + variable, use_baseline=use_baseline, filter_by=filter_by + ) + data_frame = pd.concat((data_frame, variable_data_frame)) + + data_frame.rename( + columns={ + "amount": f"{simulation_type}_amount", + "beneficiaries": f"{simulation_type}_beneficiaries", + }, + inplace=True, + ) + data_frame_by_simulation_type[simulation_type] = data_frame + + if use_baseline and reform: + del data_frame_by_simulation_type["reform"]["entity"] + del data_frame_by_simulation_type["reform"]["label"] + + self.base_data_frame = pd.concat( + list(data_frame_by_simulation_type.values()), + axis=1, + sort=True, + ).loc[list(self.aggregate_variables)] + return self.base_data_frame + + def compute_difference( + self, + target: str = "baseline", + default: str = "actual", + amount: bool = True, + beneficiaries: bool = True, + absolute: bool = True, + relative: bool = True, + ) -> pd.DataFrame: + """ + Compute and add relative and/or absolute differences to the data_frame. + + Args: + target (str, optional): Target simulation. Defaults to "baseline". + default (str, optional): Default simulation. Defaults to 'actual'. + amount (bool, optional): Provide amounts. Defaults to True. + beneficiaries (bool, optional): Provide beneficiaries. Defaults to True. + absolute (bool, optional): Return absolute values. Defaults to True. + relative (bool, optional): Return relative values. Defaults to True. 
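+
+        For instance, with target="reform" and default="baseline", the column
+        "absolute_difference_amount" holds abs(reform_amount) - baseline_amount
+        for each aggregate variable.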
+
+        Returns:
+            pd.DataFrame: The differences
+        """
+        assert relative or absolute
+        assert amount or beneficiaries
+        base_data_frame = self.base_data_frame if self.base_data_frame is not None else self.compute_aggregates()
+
+        difference_data_frame = base_data_frame[["label", "entity"]].copy()
+        # Remove duplicates
+        difference_data_frame = difference_data_frame.loc[:, ~difference_data_frame.columns.duplicated()].copy()
+
+        quantities = []
+        quantities += ["amount"] if amount else []
+        quantities += ["beneficiaries"] if beneficiaries else []
+
+        for quantity in quantities:
+            difference_data_frame[f"absolute_difference_{quantity}"] = (
+                abs(base_data_frame[f"{target}_{quantity}"]) - base_data_frame[f"{default}_{quantity}"]
+            )
+            difference_data_frame[f"relative_difference_{quantity}"] = (
+                abs(base_data_frame[f"{target}_{quantity}"]) - base_data_frame[f"{default}_{quantity}"]
+            ) / abs(base_data_frame[f"{default}_{quantity}"])
+
+        return difference_data_frame
+
+    def compute_variable_aggregates(
+        self, variable: str, use_baseline: bool = False, filter_by: Optional[str] = None
+    ) -> pd.DataFrame:
+        """
+        Return aggregate spending, and number of beneficiaries for the relevant entity level.
+
+        Args:
+            variable (str): Name of the variable aggregated according to its entity
+            use_baseline (bool, optional): Use the baseline or the reform or the only available
+                simulation when no reform (default). Defaults to False.
+            filter_by (str, optional): The variable to filter by. Defaults to None.
+
+        Returns:
+            pd.DataFrame: The amount and beneficiaries for the variable
+        """
+        if len(self.simulations) == 1:
+            simulation = next(iter(self.simulations.values()))
+        elif use_baseline:
+            simulation = self.simulations["baseline"]
+        else:
+            simulation = self.simulations["reform"]
+
+        variables = simulation.tax_benefit_system.variables
+        variable_instance = variables.get(variable)
+
+        if variable_instance is None:
+            msg = f"Variable {variable} is not available"
+            if use_baseline:
+                msg += " in baseline simulation"
+            log.info(msg)
+            return pd.DataFrame(
+                data={
+                    "label": variable,
+                    "entity": "Unknown entity",
+                    "amount": 0,
+                    "beneficiaries": 0,
+                },
+                index=[variable],
+            )
+
+        entity_key = variable_instance.entity.key
+
+        if self.weight_variable_by_entity is not None:
+            weight = self.weight_variable_by_entity[entity_key]
+            assert weight in variables, f"{weight} not a variable of the tax_benefit_system"
+            weight_array = simulation.calculate(weight, period=self.period).astype("float")
+            assert not np.isnan(np.sum(weight_array)), f"There are some NaN in weights {weight} for entity {entity_key}"
+            # amounts and beneficiaries from current data and default data if exists
+            # Build weights for each entity
+        else:
+            log.debug(f"No weight variable defined for entity {entity_key}, using 1 as weight.")
+            weight = "weight"
+            weight_array = 1
+
+        variable_array = simulation.calculate_add(variable, period=self.period).astype("float")
+        assert np.isfinite(variable_array).all(), (
+            f"There are non-finite values in variable {variable} for entity {entity_key}"
+        )
+        data = pd.DataFrame(
+            {
+                variable: variable_array,
+                weight: weight_array,
+            }
+        )
+        if filter_by:
+            filter_dummy_variable = (
+                filter_by if filter_by in variables else self.survey_scenario.filtering_variable_by_entity[entity_key]
+            )
+            filter_dummy_array = simulation.calculate(filter_dummy_variable, period=self.period)
+
+        else:
+            filter_dummy_array = 1
+
+        assert np.isfinite(filter_dummy_array).all(), (
+            f"There are non-finite values in variable 
{filter_dummy_variable} for entity {entity_key}" + ) + + amount = int((data[variable] * data[weight] * filter_dummy_array / self.amount_unit).sum()) + beneficiaries = int( + ((data[variable] != 0) * data[weight] * filter_dummy_array / self.beneficiaries_unit).sum() + ) + variable_data_frame = pd.DataFrame( + data={ + "label": variables[variable].label, + "entity": variables[variable].entity.key, + "amount": amount, + "beneficiaries": beneficiaries, + }, + index=[variable], + ) + + return variable_data_frame + + def create_description(self): + """Create a description dataframe.""" + now = datetime.now() + return pd.DataFrame( + [ + "OpenFisca", + "Calculé le {} à {}".format(now.strftime("%d-%m-%Y"), now.strftime("%H:%M")), + f"Système socio-fiscal au {self.simulation.period.start.year}", + f"Données d'enquêtes de l'année {self.data_year!s}", + ] + ) + + def to_csv( + self, + path=None, + absolute=True, + amount=True, + beneficiaries=True, + default="actual", + relative=True, + target="reform", + ): + """Saves the table to csv.""" + assert path is not None + + if Path(path).is_dir(): + now = datetime.now() + file_path = Path(path) / "Aggregates_%s.%s" % (now.strftime("%d-%m-%Y"), ".csv") + else: + file_path = path + + df = self.get_data_frame( + absolute=absolute, + amount=amount, + beneficiaries=beneficiaries, + default=default, + relative=relative, + target=target, + ) + df.to_csv(file_path, index=False, header=True) + + def to_excel( + self, + path=None, + absolute=True, + amount=True, + beneficiaries=True, + default="actual", + relative=True, + target="reform", + ): + """Save the table to excel.""" + assert path is not None + + if Path(path).is_dir(): + now = datetime.now() + file_path = Path(path) / "Aggregates_%s.%s" % (now.strftime("%d-%m-%Y"), ".xlsx") + else: + file_path = path + + df = self.get_data_frame( + absolute=absolute, + amount=amount, + beneficiaries=beneficiaries, + default=default, + relative=relative, + target=target, + ) + writer = pd.ExcelWriter(file_path) + df.to_excel(writer, "aggregates", index=False, header=True) + descr = self.create_description() + descr.to_excel(writer, "description", index=False, header=False) + writer.save() + + def to_html( + self, + path=None, + absolute=True, + amount=True, + beneficiaries=True, + default="actual", + relative=True, + target="reform", + ): + """Get or saves the table to html format.""" + df = self.get_data_frame( + absolute=absolute, + amount=amount, + beneficiaries=beneficiaries, + default=default, + relative=relative, + target=target, + ) + + if path is not None and Path(path).is_dir(): + now = datetime.now() + file_path = Path(path) / "Aggregates_%s.%s" % (now.strftime("%d-%m-%Y"), ".html") + else: + file_path = path + + if file_path is not None: + with Path(file_path).open("w") as html_file: + df.to_html(html_file) + return df.to_html() + + def to_markdown( + self, + path=None, + absolute=True, + amount=True, + beneficiaries=True, + default="actual", + relative=True, + target="reform", + ): + """Get or saves the table to markdown format.""" + df = self.get_data_frame( + absolute=absolute, + amount=amount, + beneficiaries=beneficiaries, + default=default, + relative=relative, + target=target, + ) + + if path is not None and Path(path).is_dir(): + now = datetime.now() + file_path = Path(path) / "Aggregates_%s.%s" % (now.strftime("%d-%m-%Y"), ".md") + else: + file_path = path + + if file_path is not None: + with Path(file_path).open("w") as markdown_file: + df.to_markdown(markdown_file) + + return df.to_markdown() + + def 
get_calibration_coeffcient(self, target: str = "reform") -> pd.DataFrame: + df = self.compute_aggregates( + actual=True, + use_baseline=target == "baseline", + reform=target == "reform", + ) + return df[f"{target}_amount"] / df["actual_amount"] + + def get_data_frame( + self, + absolute: bool = True, + amount: bool = True, + beneficiaries: bool = True, + default: str = "actual", + formatting: bool = True, + relative: bool = True, + target: str = "reform", + ignore_labels: bool = False, + ): + assert target is None or target in ["reform", "baseline"] + + columns = self.labels.keys() + if (absolute or relative) and (target != default): + difference_data_frame = self.compute_difference( + absolute=absolute, + amount=amount, + beneficiaries=beneficiaries, + default=default, + relative=relative, + target=target, + ) + else: + difference_data_frame = None + + # Removing unwanted columns + if amount is False: + columns = [column for column in columns if "amount" not in columns] + + if beneficiaries is False: + columns = [column for column in columns if "beneficiaries" not in column] + + if absolute is False: + columns = [column for column in columns if "absolute" not in column] + + if relative is False: + columns = [column for column in columns if "relative" not in column] + + for simulation_type in ["reform", "baseline", "actual"]: + if simulation_type not in [target, default]: + columns = [column for column in columns if simulation_type not in column] + + aggregates_data_frame = self.compute_aggregates( + actual="actual" in [target, default], + use_baseline="baseline" in [target, default], + reform="reform" in [target, default], + ) + + if "reform_amount" in aggregates_data_frame.columns and "baseline_amount" in aggregates_data_frame.columns: + winners_losers_df = self.compute_all_winners_losers(filter_by=self.filter_by) + aggregates_data_frame = aggregates_data_frame.join(winners_losers_df) + + ordered_columns = [ + "label", + "entity", + "reform_amount", + "baseline_amount", + "actual_amount", + "absolute_difference_amount", + "relative_difference_amount", + "reform_beneficiaries", + "baseline_beneficiaries", + "actual_beneficiaries", + "absolute_difference_beneficiaries", + "relative_difference_beneficiaries", + "winners", + "losers", + "neutral", + ] + if difference_data_frame is not None: + # Remove eventual duplication + difference_data_frame = difference_data_frame.loc[:, ~difference_data_frame.columns.duplicated()].copy() + aggregates_data_frame = aggregates_data_frame.loc[:, ~aggregates_data_frame.columns.duplicated()].copy() + merged_df = aggregates_data_frame.merge(difference_data_frame, how="left") + df = merged_df[[c for c in columns if c in merged_df.columns]] + else: + columns = [column for column in columns if column in aggregates_data_frame.columns] + df = aggregates_data_frame[columns] + + df = df.reindex(columns=ordered_columns).dropna(axis=1, how="all") + + if formatting: + relative_columns = [column for column in df.columns if "relative" in column] + df[relative_columns] = df[relative_columns].map(lambda x: f"{x:.2%}" if str(x) != "nan" else "nan") + for column in df.columns: + if issubclass(np.dtype(df[column]).type, np.number): + df[column] = df[column].apply(lambda x: f"{round(x):d}" if str(x) != "nan" else "nan") + + if not ignore_labels: + df = df.rename(columns=self.labels) + + return df + + def load_actual_data(self, period=None): + pass + + def compute_winners_losers(self, variable: str, filter_by: Optional[str] = None): + if "reform" not in self.simulations or 
"baseline" not in self.simulations: + log.warning("Cannot compute winners and losers without a reform and a baseline simulation.") + return pd.DataFrame() + + reform_simulation = self.simulations["reform"] + baseline_simulation = self.simulations["baseline"] + + variable_instance = reform_simulation.tax_benefit_system.variables.get(variable) + if variable_instance is None: + log.warning(f"Variable {variable} not found in reform simulation.") + return pd.DataFrame() + + stats = reform_simulation.compute_winners_losers( + baseline_simulation=baseline_simulation, + variable=variable, + period=self.period, + filter_by=filter_by, + filtering_variable_by_entity=self.survey_scenario.filtering_variable_by_entity, + absolute_minimal_detected_variation=self.absolute_minimal_detected_variation, + relative_minimal_detected_variation=self.relative_minimal_detected_variation, + observations_threshold=self.observations_threshold, + ) + + winners_losers_df = pd.DataFrame( + { + "winners": [stats["above_after"]], + "losers": [stats["lower_after"]], + "neutral": [stats["neutral"]], + }, + index=[variable], + ) + return winners_losers_df + + def compute_all_winners_losers(self, filter_by: Optional[str] = None): + all_winners_losers = pd.DataFrame() + for variable in self.aggregate_variables: + winners_losers = self.compute_winners_losers(variable, filter_by=filter_by) + all_winners_losers = pd.concat([all_winners_losers, winners_losers]) + return all_winners_losers diff --git a/openfisca_survey_manager/policy/legislation_asof.py b/openfisca_survey_manager/policy/legislation_asof.py new file mode 100644 index 00000000..b5f48c9d --- /dev/null +++ b/openfisca_survey_manager/policy/legislation_asof.py @@ -0,0 +1,229 @@ +"""Shared helpers (no survey collection dependency) to avoid circular imports.""" + +import logging + +from openfisca_core import periods +from openfisca_core.parameters import ParameterNode, Scale + +log = logging.getLogger(__name__) + + +def do_nothing(*args, **kwargs): + return None + + +def inflate_parameters( + parameters, + inflator, + base_year, + last_year=None, + ignore_missing_units=False, + start_instant=None, + round_ndigits=2, +): + """ + Inflate a Parameter node or a Parameter leaf for the years between base_year and last_year. + + ::parameters:: a Parameter node or a Parameter leaf + ::inflator:: rate used to inflate the parameter. 
The rate is unique for all the years + ::base_year:: base year of the parameter + ::last_year:: last year of inflation + ::ignore_missing_units:: if True, a parameter leaf without unit in metadata will not be inflated + ::start_instant:: Instant of the year when the update should start, if None will be January 1st + ::round_ndigits:: Number of digits to keep in the rounded result + """ + if (last_year is not None) and (last_year > base_year + 1): + for year in range(base_year + 1, last_year + 1): + inflate_parameters( + parameters, + inflator, + year - 1, + last_year=year, + ignore_missing_units=ignore_missing_units, + start_instant=start_instant, + round_ndigits=round_ndigits, + ) + else: + if last_year is None: + last_year = base_year + 1 + + assert last_year == base_year + 1 + + if isinstance(parameters, ParameterNode): + for sub_parameter in parameters.children.values(): + inflate_parameters( + sub_parameter, + inflator, + base_year, + last_year, + ignore_missing_units=ignore_missing_units, + start_instant=start_instant, + round_ndigits=round_ndigits, + ) + else: + acceptable_units = [ + "rate_unit", + "threshold_unit", + "unit", + ] + if ignore_missing_units: + if not hasattr(parameters, "metadata"): + return + if not bool(set(parameters.metadata.keys()) & set(acceptable_units)): + return + assert hasattr(parameters, "metadata"), f"{parameters.name} doesn't have metadata" + unit_types = set(parameters.metadata.keys()).intersection(set(acceptable_units)) + assert unit_types, ( + f"No admissible unit in metadata for parameter {parameters.name}. You may consider using " + "the option 'ignore_missing_units' from the inflate_parameters() function." + ) + if len(unit_types) > 1: + assert unit_types == {"threshold_unit", "rate_unit"}, ( + f"Too much admissible units in metadata for parameter {parameters.name}" + ) + unit_by_type = {unit_type: parameters.metadata[unit_type] for unit_type in unit_types} + for unit_type in unit_by_type: + if parameters.metadata[unit_type].startswith("currency"): + inflate_parameter_leaf( + parameters, + base_year, + inflator, + unit_type=unit_type, + start_instant=start_instant, + round_ndigits=round_ndigits, + ) + + +def inflate_parameter_leaf(sub_parameter, base_year, inflator, unit_type="unit", start_instant=None, round_ndigits=2): + """ + Inflate a Parameter leaf according to unit type for the year after base_year. + + ::sub_parameter:: a Parameter leaf + ::base_year:: base year of the parameter + ::inflator:: rate used to inflate the parameter + ::unit_type:: unit supposed by default. 
Other admissible unit types are threshold_unit and rate_unit
+    ::start_instant:: Instant of the year when the update should start; if None, the update starts on January 1st
+    ::round_ndigits:: Number of digits to keep in the rounded result
+    """
+    if isinstance(sub_parameter, Scale):
+        if unit_type == "threshold_unit":
+            for bracket in sub_parameter.brackets:
+                threshold = bracket.children["threshold"]
+                inflate_parameter_leaf(
+                    threshold, base_year, inflator, start_instant=start_instant, round_ndigits=round_ndigits
+                )
+            return
+    else:
+        kept_instants_str = [
+            parameter_at_instant.instant_str
+            for parameter_at_instant in sub_parameter.values_list
+            if periods.instant(parameter_at_instant.instant_str).year <= base_year
+        ]
+        if not kept_instants_str:
+            return
+
+        last_admissible_instant_str = max(kept_instants_str)
+        sub_parameter.update(start=last_admissible_instant_str, value=sub_parameter(last_admissible_instant_str))
+        if start_instant is not None:
+            assert periods.instant(start_instant).year == (base_year + 1), (
+                "Year of start_instant should be base_year + 1"
+            )
+            value = (
+                round(sub_parameter(f"{base_year}-12-31") * (1 + inflator), round_ndigits)
+                if sub_parameter(f"{base_year}-12-31") is not None
+                else None
+            )
+            sub_parameter.update(
+                start=start_instant,
+                value=value,
+            )
+        else:
+            restricted_to_base_year_value_list = [
+                parameter_at_instant
+                for parameter_at_instant in sub_parameter.values_list
+                if periods.instant(parameter_at_instant.instant_str).year == base_year
+            ]
+            if restricted_to_base_year_value_list:
+                for parameter_at_instant in reversed(restricted_to_base_year_value_list):
+                    if parameter_at_instant.instant_str.startswith(str(base_year)):
+                        value = (
+                            round(parameter_at_instant.value * (1 + inflator), round_ndigits)
+                            if parameter_at_instant.value is not None
+                            else None
+                        )
+                        sub_parameter.update(
+                            start=parameter_at_instant.instant_str.replace(str(base_year), str(base_year + 1)),
+                            value=value,
+                        )
+            else:
+                value = (
+                    round(sub_parameter(f"{base_year}-12-31") * (1 + inflator), round_ndigits)
+                    if sub_parameter(f"{base_year}-12-31") is not None
+                    else None
+                )
+                sub_parameter.update(
+                    start=f"{base_year + 1}-01-01",
+                    value=value,
+                )
+
+
+def asof(tax_benefit_system, instant):
+    parameters = tax_benefit_system.parameters
+    parameters_asof(parameters, instant)
+    variables_asof(tax_benefit_system, instant)
+
+
+def leaf_asof(sub_parameter, instant):
+    kept_instants_str = [
+        parameter_at_instant.instant_str
+        for parameter_at_instant in sub_parameter.values_list
+        if periods.instant(parameter_at_instant.instant_str) <= instant
+    ]
+    if not kept_instants_str:
+        sub_parameter.values_list = []
+        return
+
+    last_admissible_instant_str = max(kept_instants_str)
+    sub_parameter.update(start=last_admissible_instant_str, value=sub_parameter(last_admissible_instant_str))
+
+
+def parameters_asof(parameters, instant):
+    if isinstance(instant, str):
+        instant = periods.instant(instant)
+    assert isinstance(instant, periods.Instant)
+
+    for sub_parameter in parameters.children.values():
+        if isinstance(sub_parameter, ParameterNode):
+            parameters_asof(sub_parameter, instant)
+        else:
+            if isinstance(sub_parameter, Scale):
+                for bracket in sub_parameter.brackets:
+                    threshold = bracket.children["threshold"]
+                    rate = bracket.children.get("rate")
+                    amount = bracket.children.get("amount")
+                    leaf_asof(threshold, instant)
+                    if rate:
+                        leaf_asof(rate, instant)
+                    if amount:
+                        leaf_asof(amount, instant)
+            else:
+                leaf_asof(sub_parameter, instant)
+
+
+def 
variables_asof(tax_benefit_system, instant, variables_list=None):
+    if isinstance(instant, str):
+        instant = periods.instant(instant)
+    assert isinstance(instant, periods.Instant)
+
+    if variables_list is None:
+        variables_list = tax_benefit_system.variables.keys()
+
+    for variable_name, variable in tax_benefit_system.variables.items():
+        if variable_name in variables_list:
+            formulas = variable.formulas
+            for instant_str in list(formulas.keys()):
+                if periods.instant(instant_str) > instant:
+                    del formulas[instant_str]
+
+            if variable.end is not None and periods.instant(variable.end) >= instant:
+                variable.end = None
diff --git a/openfisca_survey_manager/policy/simulation_builder.py b/openfisca_survey_manager/policy/simulation_builder.py
new file mode 100644
index 00000000..1feaad1f
--- /dev/null
+++ b/openfisca_survey_manager/policy/simulation_builder.py
@@ -0,0 +1,291 @@
+import logging
+
+from openfisca_core.model_api import MONTH, YEAR
+from openfisca_core.simulations.simulation_builder import SimulationBuilder
+
+from openfisca_survey_manager.exceptions import SurveyManagerError
+
+SimulationBuilder.id_variable_by_entity_key = None
+SimulationBuilder.role_variable_by_entity_key = None
+SimulationBuilder.used_as_input_variables = None
+SimulationBuilder.used_as_input_variables_by_entity = None
+
+
+log = logging.getLogger(__name__)
+
+
+# Helpers
+
+
+def diagnose_variable_mismatch(used_as_input_variables, input_data_frame):
+    """Diagnose variables mismatch.
+
+    Args:
+        used_as_input_variables(list): List of variables whose presence is tested
+        input_data_frame: DataFrame in which to test variables presence
+
+    """
+    variables_mismatch = (
+        set(used_as_input_variables).difference(set(input_data_frame.columns)) if used_as_input_variables else None
+    )
+    if variables_mismatch:
+        log.info(
+            "The following variables used as input variables are not present in the input data frame: "
+            f"\n {sorted(variables_mismatch)}"
+        )
+    if variables_mismatch:
+        log.debug(f"The following variables are used as input variables: \n {sorted(used_as_input_variables)}")
+        log.debug(f"The input_data_frame contains the following variables: \n {sorted(input_data_frame.columns)}")
+
+
+# SimulationBuilder monkey-patched methods
+
+
+def _set_id_variable_by_entity_key(builder) -> dict[str, str]:
+    """Identify and set the correct ids for the different entities."""
+    if builder.id_variable_by_entity_key is None:
+        log.debug("Use default id_variable names")
+        builder.id_variable_by_entity_key = {
+            entity.key: entity.key + "_id" for entity in builder.tax_benefit_system.entities
+        }
+
+    return builder.id_variable_by_entity_key
+
+
+def _set_role_variable_by_entity_key(builder) -> dict[str, str]:
+    """Identify and set the correct roles for the different entities."""
+    if builder.role_variable_by_entity_key is None:
+        builder.role_variable_by_entity_key = {
+            entity.key: entity.key + "_role_index" for entity in builder.tax_benefit_system.entities
+        }
+
+    return builder.role_variable_by_entity_key
+
+
+def _set_used_as_input_variables_by_entity(builder) -> dict[str, list[str]]:
+    """Identify and set the correct input variables for the different entities."""
+    if builder.used_as_input_variables_by_entity is not None:
+        return
+
+    tax_benefit_system = builder.tax_benefit_system
+
+    tax_variables = set(tax_benefit_system.variables.keys())
+    input_variables = set(builder.used_as_input_variables)
+    assert input_variables <= tax_variables, (
+        "Some variables used as input variables are not part of the tax benefit system:\n "
+        
f"{input_variables.difference(tax_variables)}"
+    )
+
+    builder.used_as_input_variables_by_entity = {}
+
+    for entity in tax_benefit_system.entities:
+        builder.used_as_input_variables_by_entity[entity.key] = [
+            variable
+            for variable in builder.used_as_input_variables
+            if tax_benefit_system.get_variable(variable).entity.key == entity.key
+        ]
+
+    return builder.used_as_input_variables_by_entity
+
+
+def filter_input_variables(builder, input_data_frame, tax_benefit_system):
+    """Filter the input data frame from variables that won't be used or are set to be computed.
+
+    Args:
+        input_data_frame: Input dataframe
+        tax_benefit_system: The tax and benefit system
+
+    Returns:
+        pd.DataFrame: filtered dataframe
+
+    """
+    assert input_data_frame is not None
+    id_variable_by_entity_key = builder.id_variable_by_entity_key
+    role_variable_by_entity_key = builder.role_variable_by_entity_key
+    used_as_input_variables = builder.used_as_input_variables
+
+    variables = tax_benefit_system.variables
+
+    id_variables = [id_variable_by_entity_key[_entity.key] for _entity in tax_benefit_system.group_entities]
+    role_variables = [role_variable_by_entity_key[_entity.key] for _entity in tax_benefit_system.group_entities]
+
+    log.debug(f"Variable used_as_input_variables in filter: \n {used_as_input_variables}")
+
+    unknown_columns = []
+    for column_name in input_data_frame:
+        if column_name in id_variables + role_variables:
+            continue
+        if column_name not in variables:
+            unknown_columns.append(column_name)
+
+    input_data_frame.drop(unknown_columns, axis=1, inplace=True)
+
+    if unknown_columns:
+        log.debug(f"The following unknown columns are dropped from the input table: {sorted(unknown_columns)}")
+
+    used_columns = []
+    dropped_columns = []
+    for column_name in input_data_frame:
+        if column_name in id_variables + role_variables:
+            continue
+        variable = variables[column_name]
+        # Keeping the calculated variables that are initialized by the input data
+        if variable.formulas:
+            if column_name in used_as_input_variables:
+                used_columns.append(column_name)
+                continue
+
+            dropped_columns.append(column_name)
+
+    input_data_frame.drop(dropped_columns, axis=1, inplace=True)
+
+    if used_columns:
+        log.debug(
+            f"These columns are not dropped because present in used_as_input_variables:\n {sorted(used_columns)}"
+        )
+    if dropped_columns:
+        log.debug(
+            "These columns in survey are set to be calculated, we drop them from the input table:\n "
+            f"{sorted(dropped_columns)}"
+        )
+
+    log.debug(f"Keeping the following variables in the input_data_frame:\n {sorted(input_data_frame.columns)}")
+    return input_data_frame
+
+
+def init_all_entities(builder, input_data_frame, period=None):
+    assert period is not None
+    log.debug(f"Initializing simulation using input_data_frame for period {period}")
+    builder._set_id_variable_by_entity_key()
+    builder._set_role_variable_by_entity_key()
+
+    if period.unit == YEAR:  # 1. year
+        simulation = builder.init_simulation_with_data_frame(
+            input_data_frame=input_data_frame,
+            period=period,
+        )
+    elif period.unit == MONTH and period.size == 3:  # 2. quarter
+        for offset in range(period.size):
+            period_item = period.first_month.offset(offset, MONTH)
+            simulation = builder.init_simulation_with_data_frame(
+                input_data_frame=input_data_frame,
+                period=period_item,
+            )
+    elif period.unit == MONTH and period.size == 1:  # 3. 
month
+        simulation = builder.init_simulation_with_data_frame(
+            input_data_frame=input_data_frame,
+            period=period,
+        )
+    else:
+        raise SurveyManagerError(f"Invalid period {period}")
+
+    simulation.id_variable_by_entity_key = builder.id_variable_by_entity_key
+    return simulation
+
+
+def init_entity_structure(builder, entity, input_data_frame):
+    """Initialize the simulation structure with tax_benefit_system entities and input_data_frame.
+
+    Args:
+        builder(SimulationBuilder): The builder, carrying the TaxBenefitSystem to get the structure from
+        entity(Entity): The entity whose structure to initialize
+        input_data_frame(pd.DataFrame): The input data
+
+    """
+    tax_benefit_system = builder.tax_benefit_system
+    builder._set_id_variable_by_entity_key()
+    builder._set_role_variable_by_entity_key()
+    builder._set_used_as_input_variables_by_entity()
+
+    input_data_frame = builder.filter_input_variables(input_data_frame, tax_benefit_system)
+
+    id_variables = [builder.id_variable_by_entity_key[_entity.key] for _entity in tax_benefit_system.group_entities]
+    role_variables = [
+        builder.role_variable_by_entity_key[_entity.key] for _entity in tax_benefit_system.group_entities
+    ]
+
+    if entity.is_person:
+        for id_variable in id_variables + role_variables:
+            assert id_variable in input_data_frame.columns, f"Variable {id_variable} is not present in input dataframe"
+
+    ids = range(len(input_data_frame))
+    if entity.is_person:
+        builder.declare_person_entity(entity.key, ids)
+        for group_entity in tax_benefit_system.group_entities:
+            _key = group_entity.key
+            _id_variable = builder.id_variable_by_entity_key[_key]
+            _role_variable = builder.role_variable_by_entity_key[_key]
+            group_population = builder.declare_entity(
+                _key, input_data_frame[_id_variable].drop_duplicates().sort_values().values
+            )
+            builder.join_with_persons(
+                group_population,
+                input_data_frame[_id_variable].astype("int").values,
+                input_data_frame[_role_variable].astype("int").values,
+            )
+
+
+def init_simulation_with_data_frame(builder, input_data_frame, period):
+    """Initialize the simulation period with current input_data_frame for an entity if specified."""
+    used_as_input_variables = builder.used_as_input_variables
+    id_variable_by_entity_key = builder.id_variable_by_entity_key
+    role_variable_by_entity_key = builder.role_variable_by_entity_key
+    tax_benefit_system = builder.tax_benefit_system
+    assert tax_benefit_system is not None
+
+    diagnose_variable_mismatch(used_as_input_variables, input_data_frame)
+
+    id_variables = [id_variable_by_entity_key[_entity.key] for _entity in tax_benefit_system.group_entities]
+    role_variables = [role_variable_by_entity_key[_entity.key] for _entity in tax_benefit_system.group_entities]
+
+    for id_variable in id_variables + role_variables:
+        assert id_variable in input_data_frame.columns, f"Variable {id_variable} is not present in input dataframe"
+
+    input_data_frame = builder.filter_input_variables(input_data_frame, tax_benefit_system)
+
+    index_by_entity_key = {}
+
+    for entity in tax_benefit_system.entities:
+        builder.init_entity_structure(entity, input_data_frame)
+
+        if entity.is_person:
+            continue
+
+        else:
+            index_by_entity_key[entity.key] = (
+                input_data_frame.loc[
+                    input_data_frame[role_variable_by_entity_key[entity.key]] == 0,
+                    id_variable_by_entity_key[entity.key],
+                ]
+                .sort_values()
+                .index
+            )
+
+    # Build the simulation once; the per-column loop below only initializes input values on it
+    simulation = builder.build(tax_benefit_system)
+
+    for column_name, column_serie in input_data_frame.items():
+        if role_variable_by_entity_key is not None and column_name in role_variable_by_entity_key.values():
+            continue
+
+        
if id_variable_by_entity_key is not None and column_name in id_variable_by_entity_key.values():
+            continue
+
+        entity = tax_benefit_system.variables[column_name].entity
+        if entity.is_person:
+            simulation.init_variable_in_entity(entity.key, column_name, column_serie, period)
+        else:
+            simulation.init_variable_in_entity(
+                entity.key, column_name, column_serie[index_by_entity_key[entity.key]], period
+            )
+
+    assert builder.id_variable_by_entity_key is not None
+    simulation.id_variable_by_entity_key = builder.id_variable_by_entity_key
+    return simulation
+
+
+SimulationBuilder._set_id_variable_by_entity_key = _set_id_variable_by_entity_key
+SimulationBuilder._set_role_variable_by_entity_key = _set_role_variable_by_entity_key
+SimulationBuilder._set_used_as_input_variables_by_entity = _set_used_as_input_variables_by_entity
+SimulationBuilder.filter_input_variables = filter_input_variables
+SimulationBuilder.init_all_entities = init_all_entities
+SimulationBuilder.init_entity_structure = init_entity_structure
+SimulationBuilder.init_simulation_with_data_frame = init_simulation_with_data_frame
diff --git a/openfisca_survey_manager/policy/simulations.py b/openfisca_survey_manager/policy/simulations.py
new file mode 100644
index 00000000..942ee3a4
--- /dev/null
+++ b/openfisca_survey_manager/policy/simulations.py
@@ -0,0 +1,1468 @@
+"""Monkey-patch openfisca_core.simulations.Simulation to work with pandas."""
+
+import logging
+import re
+import warnings
+from typing import Any, Callable, Optional, Union
+
+import humanize
+import numpy as np
+import pandas as pd
+from numpy import logical_or as or_
+from openfisca_core import periods
+from openfisca_core.indexed_enums import Enum, EnumArray
+from openfisca_core.memory_config import MemoryConfig
+from openfisca_core.periods import ETERNITY, MONTH, YEAR
+from openfisca_core.simulations import Simulation
+from openfisca_core.types import Array, Period, TaxBenefitSystem
+from openfisca_core.types import CoreEntity as Entity
+
+from openfisca_survey_manager.exceptions import SurveyManagerError
+from openfisca_survey_manager.policy.simulation_builder import (
+    SimulationBuilder,
+    diagnose_variable_mismatch,
+)
+from openfisca_survey_manager.statshelpers import mark_weighted_percentiles
+from openfisca_survey_manager.survey_collections import SurveyCollection
+from openfisca_survey_manager.utils import do_nothing, load_table
+
+log = logging.getLogger(__name__)
+
+
+# Helpers
+
+
+def assert_variables_in_same_entity(tax_benefit_system: TaxBenefitSystem, variables: list):
+    """
+    Assert that variables are in the same entity.
+
+    Args:
+        tax_benefit_system (TaxBenefitSystem): Host tax benefit system
+        variables (List[str]): Variables supposed to belong to the same entity
+
+    Returns:
+        str: Common entity key of the variables
+    """
+    entity = None
+    for variable_name in variables:
+        variable = tax_benefit_system.variables.get(variable_name)
+        assert variable
+        if entity is None:
+            entity = variable.entity
+        assert variable.entity == entity, (
+            f"{variables} are not from the same entity: {variable_name} doesn't belong to {entity.key}"
+        )
+    return entity.key
+
+
+def get_words(text: str):
+    return re.compile("[A-Za-z_]+[A-Za-z0-9_]*").findall(text)
+
+
+# Main functions
+
+
+def adaptative_calculate_variable(
+    simulation: Simulation, variable: str, period: Optional[Union[int, str, Period]]
+) -> Array:
+    """
+    Calculate a variable, adapting its definition period to the target period. 
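+
+    For instance, a period-size-dependent (summable) variable requested over a year is computed
+    with ``calculate_add``, while a period-size-independent variable defined monthly and requested
+    over a multi-month period is read from the period's first month.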
+
+    Args:
+        simulation (Simulation): Simulation to use
+        variable (str): Variable to be computed
+        period (Optional[Union[int, str, Period]]): Target period
+
+    Returns:
+        Array: Values of the variable over the target period
+    """
+    if not isinstance(period, periods.Period):
+        period = periods.period(str(period))
+
+    tax_benefit_system = simulation.tax_benefit_system
+    assert tax_benefit_system is not None
+
+    assert variable in tax_benefit_system.variables, f"{variable} is not a valid variable"
+    period_size_independent = tax_benefit_system.get_variable(variable).is_period_size_independent
+    definition_period = tax_benefit_system.get_variable(variable).definition_period
+
+    if period_size_independent is False and definition_period != "eternity":
+        values = simulation.calculate_add(variable, period=period)
+    elif period_size_independent is True and definition_period == "month" and period.size_in_months > 1:
+        values = simulation.calculate(variable, period=period.first_month)
+    elif period_size_independent is True and definition_period == "month" and period.size_in_months == 1:
+        values = simulation.calculate(variable, period=period)
+    elif period_size_independent is True and definition_period == "year" and period.size_in_months > 12:
+        values = simulation.calculate(variable, period=period.start.offset("first-of", "year").period("year"))
+    elif period_size_independent is True and definition_period == "year" and period.size_in_months == 12:
+        values = simulation.calculate(variable, period=period)
+    elif period_size_independent is True and definition_period == "year":
+        values = simulation.calculate(variable, period=period.this_year)
+    elif definition_period == "eternity":
+        values = simulation.calculate(variable, period=period)
+    else:
+        values = None
+    assert values is not None, f"Unspecified calculation period for variable {variable}"
+
+    return values
+
+
+def compute_aggregate(
+    simulation: Simulation,
+    variable: Optional[str] = None,
+    aggfunc: str = "sum",
+    filter_by: Optional[str] = None,
+    period: Optional[Union[int, str, Period]] = None,
+    missing_variable_default_value: Any = np.nan,
+    weighted: bool = True,
+    alternative_weights: Optional[Union[str, int, float, Array]] = None,
+    filtering_variable_by_entity: Optional[dict] = None,
+) -> Optional[Union[float, int]]:
+    """
+    Compute aggregate of a variable.
+
+    Args:
+        simulation (Simulation): Simulation to use for the computation
+        variable (str, optional): Variable to aggregate. Defaults to None.
+        aggfunc (str, optional): Aggregation function. Defaults to 'sum'.
+        filter_by (str, optional): Filter variable or expression to use. Defaults to None.
+        period (Optional[Union[int, str, Period]], optional): Period. Defaults to None.
+        missing_variable_default_value (optional): Value to use for missing variables. Defaults to np.nan.
+        weighted (bool, optional): Whether to weight the variable or not. Defaults to True.
+        alternative_weights (Optional[Union[str, int, float, Array]], optional): Alternative weight to use.
+            Defaults to None.
+        filtering_variable_by_entity (Dict, optional): Filtering variable by entity. Defaults to None. 
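+
+    Example (illustrative sketch; assumes this function has been patched onto ``Simulation``
+    and that ``salary`` is a hypothetical variable of an initialized ``simulation``):
+
+        >>> simulation.compute_aggregate("salary", aggfunc="sum", period="2023")
+        >>> simulation.compute_aggregate("salary", aggfunc="mean", period="2023", filter_by="salary > 0")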
+
+    Returns:
+        float: Aggregate
+    """
+    weight_variable_by_entity = simulation.weight_variable_by_entity
+    tax_benefit_system = simulation.tax_benefit_system
+
+    if period is None:
+        period = simulation.period
+
+    assert variable in tax_benefit_system.variables, f"{variable} is not a variable of the tax benefit system"
+    entity_key = tax_benefit_system.variables[variable].entity.key
+
+    if filter_by is None and filtering_variable_by_entity is not None:
+        filter_by_variable = filtering_variable_by_entity.get(entity_key)
+
+    if filter_by:
+        filter_by_variable = get_words(filter_by)[0]
+        assert filter_by_variable in tax_benefit_system.variables, (
+            f"{filter_by_variable} is not a variable of the tax benefit system"
+        )
+        entity_key = tax_benefit_system.variables[variable].entity.key
+        filter_by_entity_key = tax_benefit_system.variables[filter_by_variable].entity.key
+        assert filter_by_entity_key == entity_key, (
+            f"You tried to compute aggregates for variable '{variable}', of entity {entity_key}"
+            f" filtering by variable '{filter_by_variable}', of entity {filter_by_entity_key}. This is not possible."
+            f" Please choose a filter-by variable of same entity as '{variable}'."
+        )
+
+    expressions = []
+    if filter_by is not None:
+        if filter_by in tax_benefit_system.variables:
+            filter_entity_key = tax_benefit_system.variables.get(filter_by).entity.key
+            assert filter_entity_key == entity_key, (
+                f"You tried to compute aggregates for variable '{variable}', of entity {entity_key} "
+                f"filtering by variable '{filter_by_variable}', of entity {filter_by_entity_key}. "
+                "This is not possible. "
+                f"Please choose a filter-by variable of same entity as '{variable}'."
+            )
+        else:
+            filter_entity_key = assert_variables_in_same_entity(tax_benefit_system, get_words(filter_by))
+            expressions.extend([filter_by])
+            assert filter_entity_key == entity_key
+    else:
+        filter_dummy = np.array(1.0)
+
+    uniform_weight = np.array(1.0)
+    weight_variable = None
+    if weighted:
+        assert or_(alternative_weights, weight_variable_by_entity), (
+            f"The weighted option is set at True but there is no weight variable for entity {entity_key} "
+            "nor alternative weights. Either define a weight variable or switch to unweighted"
+        )
+        if alternative_weights:
+            if isinstance(alternative_weights, str):
+                assert alternative_weights in tax_benefit_system.variables, (
+                    f"{alternative_weights} is not a valid variable of the tax benefit system"
+                )
+                weight_variable = alternative_weights
+
+            elif (type(alternative_weights) is int) or (type(alternative_weights) is float):
+                weight_variable = None
+                uniform_weight = float(alternative_weights)
+        elif weight_variable_by_entity:
+            weight_variable = weight_variable_by_entity[entity_key]
+
+    if variable in simulation.tax_benefit_system.variables:
+        value = simulation.adaptative_calculate_variable(variable=variable, period=period)
+    else:
+        log.debug(f"Variable {variable} not found. 
Assigning {missing_variable_default_value}")
+        return missing_variable_default_value
+
+    weight = (
+        simulation.adaptative_calculate_variable(weight_variable, period=period).astype(float)
+        if weight_variable
+        else uniform_weight
+    )
+    if weight_variable:
+        assert any(weight != 0), "Weights shall not be all zeroes"
+    else:
+        assert uniform_weight != 0
+
+    if filter_by is not None:
+        expression_data_frame = simulation.create_data_frame_by_entity(
+            variables=get_words(filter_by), period=period, index=False
+        )[entity_key]
+        for expression in expressions:
+            expression_data_frame[expression] = expression_data_frame.eval(expression)
+
+        filter_dummy = expression_data_frame[filter_by]
+    else:
+        filter_dummy = 1.0
+
+    if aggfunc == "sum":
+        aggregate = (value * weight * filter_dummy).sum()
+    elif aggfunc == "mean":
+        aggregate = (value * weight * filter_dummy).sum() / (weight * filter_dummy).sum()
+    elif aggfunc == "count":
+        aggregate = (weight * filter_dummy).sum()
+    elif aggfunc == "count_non_zero":
+        aggregate = (weight * (value != 0) * filter_dummy).sum()
+    else:
+        aggregate = None
+
+    return aggregate
+
+
+def compute_quantiles(
+    simulation: Simulation,
+    variable: str,
+    nquantiles: Optional[int] = None,
+    period: Optional[Union[int, str, Period]] = None,
+    filter_by: Optional[str] = None,
+    weighted: bool = True,
+    alternative_weights: Optional[Union[str, int, float, Array]] = None,
+    filtering_variable_by_entity: Optional[dict] = None,
+) -> list[float]:
+    """
+    Compute quantiles of a variable.
+
+    Args:
+        simulation (Simulation): Simulation to be used
+        variable (str): Variable whose quantiles are computed
+        nquantiles (int, optional): Number of quantiles. Defaults to None.
+        period (Optional[Union[int, str, Period]], optional): Period. Defaults to None.
+        filter_by (str, optional): Filter variable or expression to use. Defaults to None.
+        weighted (bool, optional): Whether to weight the variable or not. Defaults to True.
+        alternative_weights (Optional[Union[str, int, float, Array]], optional): Alternative weight to use.
+            Defaults to None.
+        filtering_variable_by_entity (Dict, optional): Filtering variable by entity. Defaults to None. 
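+
+    Example (illustrative; ``salary`` is a hypothetical variable of an initialized ``simulation``):
+
+        >>> deciles = simulation.compute_quantiles("salary", nquantiles=10, period="2023")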
+
+    Returns:
+        List(float): The quantiles values
+    """
+    weight_variable_by_entity = simulation.weight_variable_by_entity
+    weight_variable = None
+    entity_key = simulation.tax_benefit_system.variables[variable].entity.key
+    if weight_variable_by_entity:
+        weight_variable = weight_variable_by_entity[entity_key]
+
+    variable_values = simulation.adaptative_calculate_variable(variable, period)
+    if weighted:
+        assert (alternative_weights is not None) or (weight_variable is not None)
+        weight = (
+            alternative_weights if alternative_weights is not None else simulation.calculate(weight_variable, period)
+        )
+    else:
+        weight = np.ones(len(variable_values))
+
+    if filtering_variable_by_entity is not None and filter_by is None:
+        filter_by = filtering_variable_by_entity.get(entity_key)
+
+    if filter_by is not None:
+        filter_entity_key = simulation.tax_benefit_system.variables.get(filter_by).entity.key
+        assert filter_entity_key == entity_key
+        filter_dummy = simulation.calculate(filter_by, period=period).astype(bool)
+
+        variable_values = variable_values[filter_dummy].copy()
+        weight = weight[filter_dummy].copy()
+
+    labels = np.arange(1, nquantiles + 1)
+    method = 2
+    _, values = mark_weighted_percentiles(variable_values, labels, weight, method, return_quantiles=True)
+    return values
+
+
+def compute_pivot_table(
+    simulation: Simulation = None,
+    baseline_simulation: Simulation = None,
+    aggfunc: str = "mean",
+    columns: Optional[list[str]] = None,
+    difference: bool = False,
+    filter_by: Optional[str] = None,
+    index: Optional[list[str]] = None,
+    period: Optional[Union[int, str, Period]] = None,
+    use_baseline_for_columns: Optional[bool] = None,
+    values: Optional[list[str]] = None,
+    missing_variable_default_value: Any = np.nan,
+    concat_axis: Optional[int] = None,
+    weighted: bool = True,
+    alternative_weights: Optional[Union[str, int, float, Array]] = None,
+    filtering_variable_by_entity: Optional[dict] = None,
+):
+    """
+    Compute pivot table.
+
+    Args:
+        simulation (Simulation, optional): Main simulation. Defaults to None.
+        baseline_simulation (Simulation, optional): Baseline simulation. Defaults to None.
+        aggfunc (str, optional): Aggregation function. Defaults to 'mean'.
+        columns (List[str], optional): Variables to use in columns. Defaults to None.
+        difference (bool, optional): Whether to compute the difference with baseline. Defaults to False.
+        filter_by (str, optional): Filter variable or expression to use. Defaults to None.
+        index (List[str], optional): Variables to use in rows. Defaults to None.
+        period (Optional[Union[int, str, Period]], optional): Period. Defaults to None.
+        use_baseline_for_columns (bool, optional): Whether to compute the index and columns variables
+            on the baseline simulation. Defaults to None.
+        values (List[str], optional): Variables to aggregate. Defaults to None.
+        missing_variable_default_value (optional): Value to use for missing variables. Defaults to np.nan.
+        concat_axis (int, optional): Axis along which to concatenate the results when several values
+            are requested (0 or 1); None returns a dict keyed by value. Defaults to None.
+        weighted (bool, optional): Whether to weight the variable or not. Defaults to True.
+        alternative_weights (Optional[Union[str, int, float, Array]], optional): Alternative weight to use.
+            Defaults to None.
+        filtering_variable_by_entity (Dict, optional): Filtering variable by entity. Defaults to None. 
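+
+    Example (illustrative; ``salary`` and ``age_band`` are hypothetical variables, which must
+    belong to the same entity):
+
+        >>> simulation.compute_pivot_table(
+        ...     values=["salary"], columns=["age_band"], aggfunc="mean", period="2023"
+        ... )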
+
+    Returns:
+        pd.DataFrame or Dict[str, pd.DataFrame]: The pivot table(s)
+    """
+    weight_variable_by_entity = simulation.weight_variable_by_entity
+
+    admissible_aggfuncs = ["max", "mean", "min", "sum", "count", "sum_abs"]
+    assert aggfunc in admissible_aggfuncs
+    assert columns or index or values
+
+    if baseline_simulation is not None:
+        tax_benefit_system = baseline_simulation.tax_benefit_system
+    else:
+        tax_benefit_system = simulation.tax_benefit_system
+
+    assert period is not None
+
+    if isinstance(columns, str):
+        columns = [columns]
+    elif columns is None:
+        columns = []
+    assert isinstance(columns, list)
+
+    if isinstance(index, str):
+        index = [index]
+    elif index is None:
+        index = []
+    assert isinstance(index, list)
+
+    if isinstance(values, str):
+        values = [values]
+    elif values is None:
+        values = []
+    assert isinstance(values, list)
+
+    entity_key = None
+    for axe in [columns, index, values]:
+        if (len(axe) != 0) and (entity_key is None):
+            entity_key = tax_benefit_system.variables[axe[0]].entity.key
+            continue
+
+    if filter_by is None and filtering_variable_by_entity is not None:
+        filter_by = filtering_variable_by_entity.get(entity_key)
+
+    variables = set(index + columns)
+
+    # Select the entity weight corresponding to the variables that will provide values
+    uniform_weight = 1.0
+    weight_variable = None
+    if weighted:
+        if alternative_weights:
+            if isinstance(alternative_weights, str):
+                assert alternative_weights in tax_benefit_system.variables, (
+                    f"{alternative_weights} is not a valid variable of the tax benefit system"
+                )
+                weight_variable = alternative_weights
+
+            elif (type(alternative_weights) is int) or (type(alternative_weights) is float):
+                weight_variable = None
+                uniform_weight = float(alternative_weights)
+
+        else:
+            if weight_variable_by_entity:
+                weight_variable = weight_variable_by_entity[entity_key]
+                variables.add(weight_variable)
+
+            else:
+                log.warning(
+                    f"There is no weight variable for entity {entity_key} nor alternative weights. 
" + "Switch to unweighted" + ) + + expressions = [] + if filter_by is not None: + if filter_by in tax_benefit_system.variables: + variables.add(filter_by) + filter_entity_key = tax_benefit_system.variables.get(filter_by).entity.key + assert filter_entity_key == entity_key + else: + filter_entity_key = assert_variables_in_same_entity(tax_benefit_system, get_words(filter_by)) + expressions.extend([filter_by]) + assert filter_entity_key == entity_key + else: + filter_dummy = np.array(1.0) + + for expression in expressions: + expression_variables = get_words(expression) + expression_entity_key = assert_variables_in_same_entity(tax_benefit_system, expression_variables) + assert expression_entity_key == entity_key + for variable in expression_variables: + variables.add(variable) + + for variable in variables | set(values): + if variable in tax_benefit_system.variables: + assert tax_benefit_system.variables[variable].entity.key == entity_key, ( + f"The variable {variable} does not belong to entity {entity_key}" + ) + + if difference: + assert simulation is not None and baseline_simulation is not None + reform_data_frame = simulation.create_data_frame_by_entity(values, period=period, index=False)[ + entity_key + ].fillna(missing_variable_default_value) + baseline_data_frame = baseline_simulation.create_data_frame_by_entity(values, period=period, index=False)[ + entity_key + ].fillna(missing_variable_default_value) + for value_variable in values: + if value_variable not in baseline_data_frame: + baseline_data_frame[value_variable] = missing_variable_default_value + if value_variable not in reform_data_frame: + reform_data_frame[value_variable] = missing_variable_default_value + + data_frame = reform_data_frame - baseline_data_frame + + else: + if values: + data_frame = simulation.create_data_frame_by_entity(values, period=period, index=False)[entity_key] + for value_variable in values: + if value_variable not in data_frame: + data_frame[value_variable] = missing_variable_default_value + else: + data_frame = None + + use_baseline_data = difference or use_baseline_for_columns + + # use baseline if explicited or when computing difference + if use_baseline_data: + baseline_vars_data_frame = baseline_simulation.create_data_frame_by_entity( + variables=variables, period=period, index=False + )[entity_key] + else: + baseline_vars_data_frame = simulation.create_data_frame_by_entity( + variables=variables, period=period, index=False + )[entity_key] + + for expression in expressions: + baseline_vars_data_frame[expression] = baseline_vars_data_frame.eval(expression) + if filter_by is not None: + filter_dummy = baseline_vars_data_frame[filter_by] + if weight_variable is None: + weight_variable = "weight" + baseline_vars_data_frame[weight_variable] = uniform_weight + baseline_vars_data_frame[weight_variable] = baseline_vars_data_frame[weight_variable] * filter_dummy + # We drop variables that are in values from baseline_vars_data_frame + dropped_columns = [column for column in baseline_vars_data_frame.columns if column in values] + baseline_vars_data_frame.drop(columns=dropped_columns, inplace=True) + + data_frame = pd.concat( + [baseline_vars_data_frame, data_frame], + axis=1, + ) + + if values: + data_frame_by_value = {} + for value in values: + if aggfunc in ["mean", "sum", "sum_abs", "count"]: + data_frame[value] = ( + data_frame[value] * data_frame[weight_variable] + if aggfunc != "sum_abs" + else data_frame[value].abs() * data_frame[weight_variable] + ) + data_frame[value] = 
data_frame[value].fillna(missing_variable_default_value)
+                pivot_sum = data_frame.pivot_table(index=index, columns=columns, values=value, aggfunc="sum")
+                pivot_mass = data_frame.pivot_table(
+                    index=index, columns=columns, values=weight_variable, aggfunc="sum"
+                )
+                if aggfunc == "mean":
+                    try:  # Deal with a pivot_table pandas bug https://github.com/pandas-dev/pandas/issues/17038
+                        result = pivot_sum / pivot_mass.loc[weight_variable]
+                    except KeyError:
+                        result = pivot_sum / pivot_mass
+                elif aggfunc in ["sum", "sum_abs"]:
+                    result = pivot_sum
+                elif aggfunc == "count":
+                    result = pivot_mass.rename(columns={weight_variable: value}, index={weight_variable: value})
+
+            elif aggfunc in ["min", "max"]:
+                data_frame[value].fillna(missing_variable_default_value, inplace=True)
+                result = data_frame.pivot_table(index=index, columns=columns, values=value, aggfunc=aggfunc)
+
+            data_frame_by_value[value] = result
+
+        if len(list(data_frame_by_value.keys())) > 1:
+            if concat_axis is None:
+                return data_frame_by_value
+            else:
+                assert concat_axis in [0, 1]
+                return pd.concat(data_frame_by_value.values(), axis=concat_axis)
+        else:
+            return next(iter(data_frame_by_value.values()))
+
+    else:
+        assert aggfunc == "count", "Can only use count for aggfunc if no values"
+        return data_frame.pivot_table(index=index, columns=columns, values=weight_variable, aggfunc="sum")
+
+
+def create_data_frame_by_entity(
+    simulation: Simulation,
+    variables: Optional[list[str]] = None,
+    expressions: Optional[list[str]] = None,
+    filter_by: Optional[str] = None,
+    index: bool = False,
+    period: Optional[Union[int, str, Period]] = None,
+    merge: bool = False,
+) -> Union[pd.DataFrame, dict[str, pd.DataFrame]]:
+    """
+    Create dataframe(s) of variables for the whole selected population.
+
+    Args:
+        simulation (Simulation): Simulation to use.
+        variables (Optional[List], optional): Variables to retrieve, None means all. Defaults to None.
+        expressions (Optional[List[str]], optional): Expressions to evaluate and add as columns. Defaults to None.
+        filter_by (str, optional): Filter variable or expression to use. Defaults to None.
+        index (bool, optional): Whether to use index (id) variables. Defaults to False.
+        period (Optional[Union[int, str, Period]], optional): Period of the computation. Defaults to None.
+        merge (bool, optional): Whether to merge the dataframes into one. Defaults to False. 
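+
+    Example (illustrative; ``salary`` and ``rent`` are hypothetical variables):
+
+        >>> df_by_entity = simulation.create_data_frame_by_entity(["salary", "rent"], period="2023")
+        >>> merged_df = simulation.create_data_frame_by_entity(["salary", "rent"], period="2023", merge=True)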
+
+    Returns:
+        pd.DataFrame or Dict: DataFrame(s) with the variables values
+    """
+    assert simulation is not None
+    id_variable_by_entity_key = simulation.id_variable_by_entity_key
+    tax_benefit_system = simulation.tax_benefit_system
+    assert tax_benefit_system is not None
+
+    if period is None:
+        period = simulation.period
+
+    assert variables or index or expressions or filter_by
+
+    if merge:
+        index = True
+    if expressions is None:
+        expressions = []
+
+    if filter_by is not None:
+        if filter_by in tax_benefit_system.variables:
+            variables.append(filter_by)
+            filter_entity_key = tax_benefit_system.variables.get(filter_by).entity.key
+        else:
+            filter_entity_key = assert_variables_in_same_entity(tax_benefit_system, get_words(filter_by))
+            expressions.append(filter_by)
+
+    expressions_by_entity_key = {}
+    for expression in expressions:
+        expression_variables = get_words(expression)
+        entity_key = assert_variables_in_same_entity(tax_benefit_system, expression_variables)
+        if entity_key in expressions_by_entity_key:
+            expressions_by_entity_key[entity_key].append(expression)
+        else:
+            expressions_by_entity_key[entity_key] = [expression]
+        variables += expression_variables
+
+    variables = set(variables)
+
+    missing_variables = set(variables).difference(set(tax_benefit_system.variables.keys()))
+    if missing_variables:
+        log.info(f"These variables aren't part of the tax-benefit system and are thus ignored: {missing_variables}")
+
+    columns_to_fetch = [
+        tax_benefit_system.variables.get(variable_name)
+        for variable_name in variables
+        if tax_benefit_system.variables.get(variable_name) is not None
+    ]
+
+    tax_vars = list(tax_benefit_system.variables.keys())
+    assert len(columns_to_fetch) >= 1, (
+        f"None of the requested variables {variables} are in the tax-benefit-system {tax_vars}"
+    )
+
+    assert simulation is not None
+
+    openfisca_data_frame_by_entity_key = {}
+    non_person_entities = []
+
+    for entity in tax_benefit_system.entities:
+        entity_key = entity.key
+        column_names = [column.name for column in columns_to_fetch if column.entity.key == entity_key]
+        openfisca_data_frame_by_entity_key[entity_key] = pd.DataFrame(
+            {
+                column_name: simulation.adaptative_calculate_variable(column_name, period=period)
+                for column_name in column_names
+            }
+        )
+        if entity.is_person:
+            person_entity = entity
+        else:
+            non_person_entities.append(entity)
+
+    if index:
+        person_data_frame = openfisca_data_frame_by_entity_key.get(person_entity.key)
+        if person_data_frame is None:
+            person_data_frame = pd.DataFrame()
+        person_data_frame.index.name = id_variable_by_entity_key.get("person", "person_id")
+        for entity in non_person_entities:
+            entity_key_id = id_variable_by_entity_key[entity.key]
+            person_data_frame[entity_key_id] = simulation.populations[entity.key].members_entity_id
+            flattened_roles = entity.flattened_roles
+            index_by_role = {flattened_roles[index]: index for index in range(len(flattened_roles))}
+            person_data_frame["{}_{}".format(entity.key, "role")] = pd.Series(
+                simulation.populations[entity.key].members_role
+            ).map(index_by_role)
+            person_data_frame["{}_{}".format(entity.key, "position")] = simulation.populations[
+                entity.key
+            ].members_position
+
+            # Set index names as entity_id
+            openfisca_data_frame_by_entity_key[entity.key].index.name = entity_key_id
+            openfisca_data_frame_by_entity_key[entity.key].reset_index(inplace=True)
+        person_data_frame.reset_index(inplace=True)
+
+    for entity_key, expressions in expressions_by_entity_key.items():
+        data_frame = 
openfisca_data_frame_by_entity_key[entity_key]
+        for expression in expressions:
+            data_frame[expression] = data_frame.eval(expression)
+
+    if filter_by is not None:
+        openfisca_data_frame_by_entity_key[filter_entity_key] = (
+            openfisca_data_frame_by_entity_key[filter_entity_key]
+            .loc[openfisca_data_frame_by_entity_key[filter_entity_key][filter_by]]
+            .copy()
+        )
+
+    if not merge:
+        return openfisca_data_frame_by_entity_key
+    else:
+        for entity_key, openfisca_data_frame in openfisca_data_frame_by_entity_key.items():
+            if entity_key != person_entity.key:
+                entity_key_id = id_variable_by_entity_key[entity_key]
+                if len(openfisca_data_frame) > 0:
+                    person_data_frame = person_data_frame.merge(
+                        openfisca_data_frame.reset_index(),
+                        left_on=entity_key_id,
+                        right_on=entity_key_id,
+                    )
+        return person_data_frame
+
+
+class SecretViolationError(Exception):
+    """Raised if the results of the simulation do not conform with the regulator's rules."""
+
+    pass
+
+
+def compute_winners_losers(
+    simulation: Simulation,
+    baseline_simulation: Simulation,
+    variable: str,
+    filter_by: Optional[str] = None,
+    period: Optional[Union[int, str, Period]] = None,
+    absolute_minimal_detected_variation: float = 0,
+    relative_minimal_detected_variation: float = 0.01,
+    observations_threshold: Optional[int] = None,
+    weighted: bool = True,
+    alternative_weights: Optional[Union[str, int, float, Array]] = None,
+    filtering_variable_by_entity: Optional[dict] = None,
+) -> dict[str, Union[int, float]]:
+    """
+    Compute the number of winners and losers for a given variable.
+
+    Args:
+        simulation (Simulation): The main simulation.
+        baseline_simulation (Simulation): The baseline simulation.
+        variable (str): The variable to use.
+        filter_by (str, optional): The variable or expression to be used as a filter. Defaults to None.
+        period (Optional[Union[int, str, Period]], optional): The period of the simulation. Defaults to None.
+        absolute_minimal_detected_variation (float, optional): Absolute minimal variation to be detected;
+            smaller absolute variations are treated as zero. Defaults to 0.
+        relative_minimal_detected_variation (float, optional): Relative minimal variation to be detected,
+            as a ratio (e.g. 0.05 means variations below 5% are not counted). Defaults to .01.
+        observations_threshold (int, optional): Number of observations needed to avoid a statistical secret violation.
+            Defaults to None.
+        weighted (bool, optional): Whether to use weights. Defaults to True.
+        alternative_weights (Optional[Union[str, int, float, Array]], optional): Alternative weight to use.
+            Defaults to None.
+        filtering_variable_by_entity (Dict, optional): The variable to be used as a filter for each entity.
+            Defaults to None.
+
+    Raises:
+        SecretViolationError: Raised when statistical secret is violated.
+
+    Returns:
+        Dict[str, Union[int, float]]: Statistics about winners and losers between the main simulation and the baseline. 
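+
+    Example (illustrative; ``disposable_income`` is a hypothetical variable shared by the two
+    simulations):
+
+        >>> stats = reform_simulation.compute_winners_losers(
+        ...     baseline_simulation=baseline_simulation,
+        ...     variable="disposable_income",
+        ...     period="2023",
+        ... )
+        >>> stats["above_after"], stats["lower_after"], stats["neutral"]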
+ """ + weight_variable_by_entity = simulation.weight_variable_by_entity + entity_key = baseline_simulation.tax_benefit_system.variables[variable].entity.key + + # Get the results of the simulation + after = simulation.adaptative_calculate_variable(variable, period=period) + before = baseline_simulation.adaptative_calculate_variable(variable, period=period) + + # Filter if needed + if filtering_variable_by_entity is not None and filter_by is None: + filter_by = filtering_variable_by_entity.get(entity_key) + + if filter_by is not None: + filter_entity_key = baseline_simulation.tax_benefit_system.variables.get(filter_by).entity.key + assert filter_entity_key == entity_key + filter_dummy = baseline_simulation.calculate(filter_by, period=period).astype(bool) + + after = after[filter_dummy].copy() + before = before[filter_dummy].copy() + + # Define weights + weight = np.ones(len(after)) + if weighted: + if alternative_weights is not None: + weight = alternative_weights + elif weight_variable_by_entity is not None: + weight_variable = weight_variable_by_entity[entity_key] + weight = baseline_simulation.calculate(weight_variable, period=period) + else: + log.warn( + f"There is no weight variable for entity {entity_key} nor alternative weights. Switch to unweighted" + ) + + # Compute the weigthed number of zeros or non zeros + value_by_simulation = {"after": after, "before": before} + stats_by_simulation = {} + for simulation_prefix, value in value_by_simulation.items(): + stats = {} + stats["count_zero"] = (weight.astype("float64") * (absolute_minimal_detected_variation > np.abs(value))).sum() + stats["count_non_zero"] = sum(weight.astype("float64")) - stats["count_zero"] + stats_by_simulation[simulation_prefix] = stats + del stats + + # Compute the number of entity above or below after + after_value = after + before_value = before + with np.errstate(divide="ignore", invalid="ignore"): + above_after = ((after_value - before_value) / np.abs(before_value)) > relative_minimal_detected_variation + almost_zero_before = np.abs(before_value) < absolute_minimal_detected_variation + above_after[almost_zero_before * (after_value >= 0)] = (after_value >= absolute_minimal_detected_variation)[ + almost_zero_before * (after_value >= 0) + ] + with np.errstate(divide="ignore", invalid="ignore"): + below_after = ((after_value - before_value) / np.abs(before_value)) < -relative_minimal_detected_variation + below_after[almost_zero_before * (after_value < 0)] = (after_value < -absolute_minimal_detected_variation)[ + almost_zero_before * (after_value < 0) + ] + + # Check if there is a secret violation, without weights + if observations_threshold is not None: + not_legit_below = (below_after.sum() < observations_threshold) & (below_after.sum() > 0) + not_legit_above = (above_after.sum() < observations_threshold) & (above_after.sum() > 0) + if not_legit_below | not_legit_above: + raise SecretViolationError("Not enough observations involved") + + # Apply weights + above_after_count = (above_after.astype("float64") * weight.astype("float64")).sum() + below_after_count = (below_after.astype("float64") * weight.astype("float64")).sum() + total = sum(weight) + neutral = total - above_after_count - below_after_count + + return { + "total": total, + "non_zero_before": stats_by_simulation["before"]["count_non_zero"], + "non_zero_after": stats_by_simulation["after"]["count_non_zero"], + "above_after": above_after_count, + "lower_after": below_after_count, + "neutral": neutral, + "tolerance_factor_used": 
relative_minimal_detected_variation,
+        "weight_factor": 1,
+    }
+
+
+def init_entity_data(
+    simulation: Simulation,
+    entity: Entity,
+    filtered_input_data_frame: pd.DataFrame,
+    period: Period,
+    used_as_input_variables_by_entity: dict,
+):
+    """
+    Initialize an entity in the simulation at some period with input provided by a dataframe.
+
+    Args:
+        simulation (Simulation): The simulation to initialize.
+        entity (Entity): The entity whose variables to initialize.
+        filtered_input_data_frame (pd.DataFrame): The dataframe with the variables values.
+        period (Period): The period to initialize.
+        used_as_input_variables_by_entity (Dict): The variables to be used to initialize each entity.
+    """
+    used_as_input_variables = used_as_input_variables_by_entity[entity.key]
+    input_data_frame = filtered_input_data_frame
+    # input_data_frame = self.filter_input_variables(input_data_frame = input_data_frame)
+    diagnose_variable_mismatch(used_as_input_variables, input_data_frame)
+
+    for column_name, column_serie in input_data_frame.items():
+        variable_instance = simulation.tax_benefit_system.variables.get(column_name)
+        if variable_instance is None:
+            log.debug(f"Ignoring {column_name} in input data")
+            continue
+
+        if variable_instance.entity.key != entity.key:
+            log.debug(
+                f"Ignoring variable {column_name} which is not part of entity {entity.key} "
+                f"but {variable_instance.entity.key}"
+            )
+            continue
+        init_variable_in_entity(simulation, entity.key, column_name, column_serie, period)
+
+
+def inflate(
+    simulation: Simulation,
+    inflator_by_variable: Optional[dict] = None,
+    period: Optional[Union[int, str, Period]] = None,
+    target_by_variable: Optional[dict] = None,
+) -> None:
+    # Guard against the None defaults before iterating over the dicts' keys
+    inflator_by_variable = dict() if inflator_by_variable is None else inflator_by_variable
+    target_by_variable = dict() if target_by_variable is None else target_by_variable
+    tax_benefit_system = simulation.tax_benefit_system
+    for variable_name in set(inflator_by_variable.keys()).union(set(target_by_variable.keys())):
+        assert variable_name in tax_benefit_system.variables, (
+            f"Variable {variable_name} is not a valid variable of the tax-benefit system"
+        )
+        if variable_name in target_by_variable:
+            inflator = inflator_by_variable[variable_name] = target_by_variable[
+                variable_name
+            ] / simulation.compute_aggregate(variable=variable_name, period=period)
+            log.debug(
+                f"Using {inflator} as inflator for {variable_name} to reach the target "
+                f"{target_by_variable[variable_name]} "
+            )
+        else:
+            assert variable_name in inflator_by_variable, "variable_name is not in inflator_by_variable"
+            target = inflator_by_variable[variable_name] * simulation.compute_aggregate(
+                variable=variable_name, period=period
+            )
+            log.debug(
+                f"Using inflator {inflator_by_variable[variable_name]} for {variable_name}. 
" + f"The target is thus {target}" + ) + inflator = inflator_by_variable[variable_name] + + array = simulation.calculate_add(variable_name, period=period) + assert array is not None + simulation.delete_arrays(variable_name, period=period) # delete existing arrays + simulation.set_input(variable_name, period, inflator * array) # insert inflated array + + +def _load_table_for_survey( + config_files_directory: str, + collection: str, + survey: str, + table: str, + batch_size: Optional[int] = None, + batch_index: Optional[int] = None, + filter_by: Optional[str] = None, +): + if survey is not None: + input_data_frame = load_table( + config_files_directory=config_files_directory, + collection=collection, + survey=survey, + table=table, + batch_size=batch_size, + batch_index=batch_index, + filter_by=filter_by, + ) + else: + input_data_frame = load_table( + config_files_directory=config_files_directory, + collection=collection, + survey="input", + table=table, + batch_size=batch_size, + batch_index=batch_index, + filter_by=filter_by, + ) + return input_data_frame + + +def _input_data_table_by_entity_by_period_monolithic( + tax_benefit_system: TaxBenefitSystem, + simulation: Simulation, + period: Period, + input_data_table_by_entity: dict, + builder: SimulationBuilder, + custom_input_data_frame: Callable, + config_files_directory: str, + collection: str, + survey: Optional[str] = None, +): + """ + Initialize simulation with input data from a table for each entity and period. + """ + period = periods.period(period) + simulation_datasets = {} + entities = tax_benefit_system.entities + for entity in entities: + # Read all tables for the entity + log.debug(f"init_simulation - {period=} {entity.key=}") + table = input_data_table_by_entity.get(entity.key) + filter_by = input_data_table_by_entity.get("filter_by") + if table is None: + continue + input_data_frame = _load_table_for_survey( + config_files_directory, collection, survey, table, filter_by=filter_by + ) + simulation_datasets[entity.key] = input_data_frame + + if simulation is None: + # Instantiate simulation only for the fist period + # Next period will reuse the same simulation + for entity in entities: + table = input_data_table_by_entity.get(entity.key) + if table is None: + continue + custom_input_data_frame(input_data_frame, period=period, entity=entity.key) + builder.init_entity_structure(entity, simulation_datasets[entity.key]) # TODO complete args + simulation = builder.build(tax_benefit_system) + simulation.id_variable_by_entity_key = ( + builder.id_variable_by_entity_key + ) # Should be propagated to enhanced build + + for entity in entities: + # Load data in the simulation + table = input_data_table_by_entity.get(entity.key) + if table is None: + continue + log.debug(f"init_simulation - {entity.key=} {len(input_data_frame)=}") + simulation.init_entity_data( + entity, simulation_datasets[entity.key], period, builder.used_as_input_variables_by_entity + ) + del simulation_datasets[entity.key] + return simulation + + +def _input_data_table_by_entity_by_period_batch( + tax_benefit_system: TaxBenefitSystem, + simulation: Simulation, + period: Period, + input_data_table_by_entity: dict, + builder: SimulationBuilder, + custom_input_data_frame: Callable, + config_files_directory: str, + collection: str, + survey: Optional[str] = None, +): + """ + Initialize simulation with input data from a table for each entity and period. 
+ """ + period = periods.period(period) + batch_size = input_data_table_by_entity.get("batch_size") + batch_index = input_data_table_by_entity.get("batch_index", 0) + batch_entity = input_data_table_by_entity.get("batch_entity") + batch_entity_key = input_data_table_by_entity.get("batch_entity_key") + filtered_entity = input_data_table_by_entity.get("filtered_entity") + filtered_entity_on_key = input_data_table_by_entity.get("filtered_entity_on_key") + if not batch_entity or not batch_entity_key or not filtered_entity or not filtered_entity_on_key: + raise SurveyManagerError( + "batch_entity, batch_entity_key, filtered_entity and filtered_entity_on_key are required" + ) + simulation_datasets = { + batch_entity: { + "table_key": batch_entity_key, + "input_data_frame": None, + "entity": None, + }, + filtered_entity: { + "table_key": filtered_entity_on_key, + "input_data_frame": None, + "entity": None, + }, + } + batch_entity_ids = None + entities = tax_benefit_system.entities + + if len(entities) > 2: + # Batch mode could work only with batch_entity and filtered_entity, and no others + warnings.warn( + "survey-manager.simulation._input_data_table_by_entity_by_period_batch : " + f"Your TaxBenefitSystem has {len(entities)} entities but we will only load " + f"{batch_entity} and {filtered_entity}.", + stacklevel=2, + ) + + for entity_name, entity_data in simulation_datasets.items(): + # Find Identity object from TaxBenefitSystem + for entity in entities: + if entity.key == entity_name: + entity_data["entity"] = entity + break + + # Load the batch entity + table = input_data_table_by_entity[batch_entity] + input_data_frame = _load_table_for_survey( + config_files_directory, collection, survey, table, batch_size, batch_index + ) + batch_entity_ids = input_data_frame[batch_entity_key].to_list() + simulation_datasets[batch_entity]["input_data_frame"] = input_data_frame + + # Load the filtered entity + table = input_data_table_by_entity[filtered_entity] + filter_by = [(filtered_entity_on_key, "in", batch_entity_ids)] + input_data_frame = _load_table_for_survey(config_files_directory, collection, survey, table, filter_by=filter_by) + simulation_datasets[filtered_entity]["input_data_frame"] = input_data_frame + + if simulation is None: + for entity_name, entity_data in simulation_datasets.items(): + custom_input_data_frame(entity_data["input_data_frame"], period=period, entity=entity_name) + builder.init_entity_structure(entity_data["entity"], entity_data["input_data_frame"]) + simulation = builder.build(tax_benefit_system) + simulation.id_variable_by_entity_key = ( + builder.id_variable_by_entity_key + ) # Should be propagated to enhanced build + for _entity_name, entity_data in simulation_datasets.items(): + simulation.init_entity_data( + entity_data["entity"], entity_data["input_data_frame"], period, builder.used_as_input_variables_by_entity + ) + return simulation + + +def init_simulation( + tax_benefit_system: TaxBenefitSystem, + period: Union[str, int, Period], + data: dict, +): + builder = SimulationBuilder() + builder.create_entities(tax_benefit_system) + + collection = data.get("collection") + custom_input_data_frame = data.get("custom_input_data_frame", do_nothing) + data_year = data.get("data_year") + survey = data.get("survey") + config_files_directory = data.get("config_files_directory") + builder.used_as_input_variables = data.get("used_as_input_variables") + builder.id_variable_by_entity_key = data.get("id_variable_by_entity_key") + builder.role_variable_by_entity_key = 
data.get("role_variable_by_entity_key")
+    builder.tax_benefit_system = tax_benefit_system
+
+    default_source_types = [
+        "input_data_frame",
+        "input_data_table",
+        "input_data_frame_by_entity",
+        "input_data_frame_by_entity_by_period",
+        "input_data_table_by_entity_by_period",
+        "input_data_table_by_period",
+    ]
+    source_types = [source_type_ for source_type_ in default_source_types if data.get(source_type_) is not None]
+    assert len(source_types) < 2, "There are too many data source types"
+    assert len(source_types) >= 1, f"There should be one data source type included in {default_source_types}"
+    source_type = source_types[0]
+    source = data[source_type]
+
+    if source_type == "input_data_frame_by_entity":
+        assert data_year is not None
+        source_type = "input_data_frame_by_entity_by_period"
+        source = {periods.period(data_year): source}
+
+    input_data_survey_prefix = data.get("input_data_survey_prefix") if data is not None else None
+
+    if source_type == "input_data_frame":
+        simulation = builder.init_all_entities(source, period)
+
+    if source_type == "input_data_table":
+        # Case 1: fill simulation with a unique input_data_frame given by the attribute
+        if input_data_survey_prefix is not None:
+            openfisca_survey_collection = SurveyCollection.load(collection=collection)
+            openfisca_survey = openfisca_survey_collection.get_survey(f"{input_data_survey_prefix}_{data_year}")
+            input_data_frame = openfisca_survey.get_values(table="input").reset_index(drop=True)
+        else:
+            raise NotImplementedError
+
+        custom_input_data_frame(input_data_frame, period=period)
+        simulation = builder.init_all_entities(input_data_frame, period)  # monolithic dataframes
+
+    elif source_type == "input_data_table_by_period":
+        # Case 2: fill simulation with input_data_frame by period containing all entity variables
+        input_data_table_by_period = data.get("input_data_table_by_period")
+        for period, table in input_data_table_by_period.items():
+            period = periods.period(period)
+            log.debug(f"From survey {survey} loading table {table}")
+            input_data_frame = load_table(
+                config_files_directory=config_files_directory,
+                collection=collection,
+                survey=survey,
+                input_data_survey_prefix=input_data_survey_prefix,
+                table=table,
+            )
+            custom_input_data_frame(input_data_frame, period=period)
+            simulation = builder.init_all_entities(input_data_frame, period)  # monolithic dataframes
+
+    elif source_type == "input_data_frame_by_entity_by_period":
+        for period, input_data_frame_by_entity in source.items():
+            period = periods.period(period)
+            for entity in tax_benefit_system.entities:
+                input_data_frame = input_data_frame_by_entity.get(entity.key)
+                if input_data_frame is None:
+                    continue
+                custom_input_data_frame(input_data_frame, period=period, entity=entity.key)
+                builder.init_entity_structure(entity, input_data_frame)  # TODO complete args
+
+        simulation = builder.build(tax_benefit_system)
+        simulation.id_variable_by_entity_key = (
+            builder.id_variable_by_entity_key
+        )  # Should be propagated to enhanced build
+
+        for period, input_data_frame_by_entity in source.items():
+            for entity in tax_benefit_system.entities:
+                input_data_frame = input_data_frame_by_entity.get(entity.key)
+                if input_data_frame is None:
+                    log.debug(f"No input_data_frame found for entity {entity} at period {period}")
+                    continue
+                custom_input_data_frame(input_data_frame, period=period, entity=entity.key)
+                simulation.init_entity_data(
+                    entity, input_data_frame, period, builder.used_as_input_variables_by_entity
+                )
+
+    elif source_type == 
"input_data_table_by_entity_by_period": + # Case 3: fill simulation with input_data_table by entity_by_period containing a dictionnary + # of all periods containing a dictionnary of entity variables + input_data_table_by_entity_by_period = source + simulation = None + for period, input_data_table_by_entity in input_data_table_by_entity_by_period.items(): + if input_data_table_by_entity.get("batch_size"): + simulation = _input_data_table_by_entity_by_period_batch( + tax_benefit_system, + simulation, + period, + input_data_table_by_entity, + builder, + custom_input_data_frame, + config_files_directory, + collection, + survey, + ) + else: + simulation = _input_data_table_by_entity_by_period_monolithic( + tax_benefit_system, + simulation, + period, + input_data_table_by_entity, + builder, + custom_input_data_frame, + config_files_directory, + collection, + survey, + ) + + else: + pass + + if data_year is not None: + simulation.period = periods.period(data_year) + + return simulation + + +def init_variable_in_entity( + simulation: Simulation, + entity: str, + variable_name: str, + series: pd.Series, + period: Period, +): + variable = simulation.tax_benefit_system.variables[variable_name] + + # np.issubdtype cannot handles categorical variables + if (not isinstance(series.dtype, pd.CategoricalDtype)) and pd.api.types.is_float_dtype(series.values.dtype): + if series.isnull().any(): + log.debug( + f"There are {series.isnull().sum()} NaN values for " + f"{series.notnull().sum()} non NaN values in variable {variable_name}" + ) + log.debug( + f"We convert NaN values of variable {variable_name} to {variable.default_value} its default value" + ) + series = pd.to_numeric(series).fillna(variable.default_value).astype(variable.value_type) + assert series.notnull().all(), ( + f"There are {series.isnull().sum()} NaN values for " + f"{series.notnull().sum()} non NaN values in variable {variable_name}" + ) + + enum_variable_imputed_as_enum = variable.value_type == Enum and ( + isinstance(series.dtype, pd.CategoricalDtype) + or not (np.issubdtype(series.values.dtype, np.integer) or np.issubdtype(series.values.dtype, float)) + ) + + if enum_variable_imputed_as_enum: + if series.isnull().any(): + log.debug( + f"There are {series.isnull().sum()} NaN values " + f"({series.isnull().mean() * 100}% of the array) in variable {variable_name}" + ) + log.debug( + f"We convert these NaN values of variable {variable_name} to " + f"{variable.default_value._name_} its default value" + ) + series.fillna(variable.default_value._name_, inplace=True) + possible_values = variable.possible_values + if isinstance(series.dtype, pd.CategoricalDtype): + series = series.cat.codes + else: + msg = f"There are errors with {variable_name}" + assert series.isin(list(possible_values._member_names_)).all(), msg + series = series.apply(lambda v: variable.possible_values[v].index) + + if series.values.dtype != variable.dtype: + log.debug(f"Converting {variable_name} from dtype {series.values.dtype} to {variable.dtype}") + + array = series.values.astype(variable.dtype) + np_array = np.array(array, dtype=variable.dtype) + + if (variable.value_type == Enum) and ( + np.issubdtype(series.values.dtype, np.integer) or np.issubdtype(series.values.dtype, float) + ): + np_array = EnumArray(np_array, variable.possible_values) + + if variable.definition_period == YEAR and period.unit == MONTH: + # Some variables defined for a year are present in month/quarter dataframes + # Cleaning the dataframe would probably be better in the long run + log.warn( + 
f"Trying to set a monthly value for variable {variable_name}, which is defined on a year. " + "The montly values you provided will be summed." + ) + + if simulation.get_array(variable_name, period.this_year) is not None: + array_sum = simulation.get_array(variable_name, period.this_year) + np_array + simulation.set_input(variable_name, period.this_year, array_sum) + else: + simulation.set_input(variable_name, period.this_year, np_array) + + else: + simulation.set_input(variable_name, period, np_array) + + +def new_from_tax_benefit_system( + tax_benefit_system: TaxBenefitSystem, + debug: bool = False, + trace: bool = False, + data: Optional[dict] = None, + memory_config: MemoryConfig = None, + period: Optional[Union[int, str, Period]] = None, + custom_initialize: Optional[Callable] = None, +) -> Simulation: + """ + Create and initialize a simulation from a tax and benefit system and data. + + Args: + tax_benefit_system (TaxBenefitSystem): The tax and benefit system + debug (bool, optional): Whether to activate debugging. Defaults to False. + trace (bool, optional): Whether to activate tracing. Defaults to False. + data (Dict, optional): The information about data. Defaults to None. + memory_config (MemoryConfig, optional): The memory handling config. Defaults to None. + period (Optional[Union[int, str, Period]], optional): The period of the simulation. Defaults to None. + custom_initialize (Callable, optional): The post-processing initialization function. Defaults to None. + + Returns: + Simulation: The completely initialized function + """ + + simulation = Simulation.init_simulation(tax_benefit_system, period, data) + simulation.debug = debug + simulation.trace = trace + simulation.opt_out_cache = simulation.tax_benefit_system.cache_blacklist is not None + simulation.memory_config = memory_config + + if custom_initialize: + custom_initialize(simulation) + + return simulation + + +def print_memory_usage(simulation: Simulation): + """ + Print memory usage. + + Args: + simulation (Simulation): The simulation which memory usage is to be printed + """ + memory_usage_by_variable = simulation.get_memory_usage()["by_variable"] + try: + usage_stats = simulation.tracer.usage_stats + except AttributeError: + log.warning( + "The simulation trace mode is not activated. You need to activate it to get stats " + "about variable usage (hits)." + ) + usage_stats = None + infos_lines = [] + + for variable, infos in memory_usage_by_variable.items(): + hits = usage_stats[variable]["nb_requests"] if usage_stats else None + infos_lines.append( + ( + infos["total_nb_bytes"], + variable, + "{}: {} periods * {} cells * item size {} ({}) = {} with {} hits".format( + variable, + infos["nb_arrays"], + infos["nb_cells_by_array"], + infos["cell_size"], + infos["dtype"], + humanize.naturalsize(infos["total_nb_bytes"], gnu=True), + hits, + ), + ) + ) + infos_lines.sort() + for _, _, line in infos_lines: + log.info("%s", line.rjust(100)) + + +def set_weight_variable_by_entity( + simulation: Simulation, + weight_variable_by_entity: dict, +) -> None: + """ + Set weight variable for each entity. + + Args: + simulation (Simulation): The simulation concerned. + weight_variable_by_entity (Dict): The weight variable for each entity. + """ + simulation.weight_variable_by_entity = weight_variable_by_entity + + +def summarize_variable( + simulation: Simulation, + variable: Optional[str] = None, + weighted: bool = False, + force_compute: bool = False, +): + """Print a summary of a variable including its memory usage. 
+ + Args: + variable(string): The variable being summarized + use_baseline(bool): The tax-benefit-system considered + weighted(bool): Whether the produced statistics should be weigthted or not + force_compute(bool): Whether the computation of the variable should be forced + """ + tax_benefit_system = simulation.tax_benefit_system + variable_instance = tax_benefit_system.variables.get(variable) + assert variable_instance is not None, f"{variable} is not a valid variable" + + default_value = variable_instance.default_value + value_type = variable_instance.value_type + + if variable_instance.is_neutralized: + log.info( + "%s: neutralized variable (%s, default = %s)", + variable, + str(np.dtype(value_type)), + default_value, + ) + return + + if weighted: + weight_variable = simulation.weight_variable_by_entity[variable_instance.entity.key] + weights = simulation.calculate(weight_variable, simulation.period) + + infos = simulation.get_memory_usage(variables=[variable])["by_variable"].get(variable) + if not infos: + if force_compute: + simulation.adaptative_calculate_variable(variable=variable, period=simulation.period) + simulation.summarize_variable(variable=variable, weighted=weighted) + return + else: + log.info( + "%s is not computed yet. Use keyword argument force_compute = True", + variable, + ) + return + + header_line = "{}: {} periods * {} cells * item size {} ({}, default = {}) = {}".format( + variable, + infos["nb_arrays"], + infos["nb_cells_by_array"], + infos["cell_size"], + str(np.dtype(infos["dtype"])), + default_value, + humanize.naturalsize(infos["total_nb_bytes"], gnu=True), + ) + log.info("%s", header_line) + log.info("Details:") + holder = simulation.get_holder(variable) + if holder is not None: + if holder.variable.definition_period == ETERNITY: + array = holder.get_array(ETERNITY) + else: + for period in sorted(simulation.get_known_periods(variable)): + array = holder.get_array(period) + if array.shape == (): + log.info("%s: always = %s", period, array) + continue + + if value_type == Enum: + possible_values = variable_instance.possible_values + categories_by_index = dict( + zip(range(len(possible_values._member_names_)), possible_values._member_names_) + ) + categories_type = pd.api.types.CategoricalDtype( + categories=possible_values._member_names_, ordered=True + ) + df = pd.DataFrame({variable: array}).replace(categories_by_index).astype(categories_type) + df["weights"] = weights if weighted else 1 + groupby = df.groupby(variable)["weights"].sum() + total = groupby.sum() + expr = [f" {index} = {row:.2e} ({row / total:.1%})" for index, row in groupby.items()] + log.info("%s: %s.", period, ",".join(expr)) + continue + + # Handle numeric types + default_array = np.array(default_value, dtype=infos["dtype"]) + total = np.sum(array * weights) if weighted else np.sum(array) + mean_val = (array * weights).sum() / weights.sum() if weighted else array.mean() + total_val = f"mass = {total:.2e}" + stats_val = ( + f"min = {array.min()}, max = {array.max()}, {total_val}, " + f"default = {(array == default_array).sum() / len(array):.1%}, " + f"median = {np.median(array)}" + ) + log.info("%s: mean = %s, %s", period, mean_val, stats_val) + + +# Monkey patching + +Simulation.adaptative_calculate_variable = adaptative_calculate_variable +Simulation.compute_aggregate = compute_aggregate +Simulation.compute_pivot_table = compute_pivot_table +Simulation.create_data_frame_by_entity = create_data_frame_by_entity +Simulation.compute_quantiles = compute_quantiles 
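+# Usage sketch (illustrative only, not part of this patch's API changes):
+# once these helpers are attached, a simulation can be built and queried
+# directly from a pandas DataFrame. `tax_benefit_system`, `input_data_frame`
+# and the "salary" variable are assumptions here (e.g. coming from
+# openfisca_country_template, with the id/role columns the builder expects):
+#
+#     from openfisca_core import periods
+#
+#     simulation = Simulation.new_from_tax_benefit_system(
+#         tax_benefit_system,
+#         period=periods.period("2024-01"),
+#         data={"input_data_frame": input_data_frame},
+#     )
+#     simulation.set_weight_variable_by_entity({"person": None, "household": None})
+#     total = simulation.compute_aggregate("salary", period="2024-01", weighted=False)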
+Simulation.compute_winners_losers = compute_winners_losers +Simulation.new_from_tax_benefit_system = new_from_tax_benefit_system +Simulation.inflate = inflate +Simulation.init_entity_data = init_entity_data +Simulation.init_simulation = init_simulation +Simulation.init_variable_in_entity = init_variable_in_entity +Simulation.print_memory_usage = print_memory_usage +Simulation.set_weight_variable_by_entity = set_weight_variable_by_entity +Simulation.summarize_variable = summarize_variable diff --git a/openfisca_survey_manager/policy/tests/__init__.py b/openfisca_survey_manager/policy/tests/__init__.py new file mode 100644 index 00000000..f9d9e925 --- /dev/null +++ b/openfisca_survey_manager/policy/tests/__init__.py @@ -0,0 +1 @@ +# Tests for policy package (simulations, simulation_builder, aggregates). diff --git a/openfisca_survey_manager/tests/test_aggregates.py b/openfisca_survey_manager/policy/tests/test_aggregates.py similarity index 96% rename from openfisca_survey_manager/tests/test_aggregates.py rename to openfisca_survey_manager/policy/tests/test_aggregates.py index 38c461d4..7ec83109 100644 --- a/openfisca_survey_manager/tests/test_aggregates.py +++ b/openfisca_survey_manager/policy/tests/test_aggregates.py @@ -1,7 +1,7 @@ import pytest from openfisca_country_template.reforms.modify_social_security_taxation import modify_social_security_taxation -from openfisca_survey_manager.aggregates import AbstractAggregates +from openfisca_survey_manager.policy import AbstractAggregates from openfisca_survey_manager.tests.test_scenario import create_randomly_initialized_survey_scenario diff --git a/openfisca_survey_manager/tests/test_compute_aggregate.py b/openfisca_survey_manager/policy/tests/test_compute_aggregate.py similarity index 100% rename from openfisca_survey_manager/tests/test_compute_aggregate.py rename to openfisca_survey_manager/policy/tests/test_compute_aggregate.py diff --git a/openfisca_survey_manager/tests/test_compute_pivot_table.py b/openfisca_survey_manager/policy/tests/test_compute_pivot_table.py similarity index 100% rename from openfisca_survey_manager/tests/test_compute_pivot_table.py rename to openfisca_survey_manager/policy/tests/test_compute_pivot_table.py diff --git a/openfisca_survey_manager/tests/test_compute_winners_losers.py b/openfisca_survey_manager/policy/tests/test_compute_winners_losers.py similarity index 98% rename from openfisca_survey_manager/tests/test_compute_winners_losers.py rename to openfisca_survey_manager/policy/tests/test_compute_winners_losers.py index 3e50d93a..63fd7867 100644 --- a/openfisca_survey_manager/tests/test_compute_winners_losers.py +++ b/openfisca_survey_manager/policy/tests/test_compute_winners_losers.py @@ -1,7 +1,7 @@ import pytest from openfisca_country_template.reforms.modify_social_security_taxation import modify_social_security_taxation -from openfisca_survey_manager.simulations import SecretViolationError +from openfisca_survey_manager.policy import SecretViolationError from openfisca_survey_manager.tests.test_scenario import create_randomly_initialized_survey_scenario diff --git a/openfisca_survey_manager/tests/test_create_data_frame_by_entity.py b/openfisca_survey_manager/policy/tests/test_create_data_frame_by_entity.py similarity index 100% rename from openfisca_survey_manager/tests/test_create_data_frame_by_entity.py rename to openfisca_survey_manager/policy/tests/test_create_data_frame_by_entity.py diff --git a/openfisca_survey_manager/tests/test_marginal_tax_rate.py 
b/openfisca_survey_manager/policy/tests/test_marginal_tax_rate.py similarity index 100% rename from openfisca_survey_manager/tests/test_marginal_tax_rate.py rename to openfisca_survey_manager/policy/tests/test_marginal_tax_rate.py diff --git a/openfisca_survey_manager/tests/test_summarize_variables.py b/openfisca_survey_manager/policy/tests/test_summarize_variables.py similarity index 95% rename from openfisca_survey_manager/tests/test_summarize_variables.py rename to openfisca_survey_manager/policy/tests/test_summarize_variables.py index 5e7c2ebb..4b62f548 100644 --- a/openfisca_survey_manager/tests/test_summarize_variables.py +++ b/openfisca_survey_manager/policy/tests/test_summarize_variables.py @@ -19,7 +19,7 @@ def test_summarize_variable_log_output(caplog): The doctest used to check stdout; we now send that output to the logging system. This test captures logs and verifies the expected content is present. """ - with caplog.at_level(logging.INFO, logger="openfisca_survey_manager.simulations"): + with caplog.at_level(logging.INFO, logger="openfisca_survey_manager.policy.simulations"): survey_scenario = create_randomly_initialized_survey_scenario(collection=None) survey_scenario.summarize_variable(variable="housing_occupancy_status", force_compute=True) @@ -32,7 +32,7 @@ def test_summarize_variable_log_output(caplog): assert "owner" in text or "tenant" in text or "free_lodger" in text or "homeless" in text caplog.clear() - with caplog.at_level(logging.INFO, logger="openfisca_survey_manager.simulations"): + with caplog.at_level(logging.INFO, logger="openfisca_survey_manager.policy.simulations"): survey_scenario.summarize_variable(variable="rent", force_compute=True) messages = [r.message for r in caplog.records] @@ -43,7 +43,7 @@ def test_summarize_variable_log_output(caplog): survey_scenario.tax_benefit_systems["baseline"].neutralize_variable("age") caplog.clear() - with caplog.at_level(logging.INFO, logger="openfisca_survey_manager.simulations"): + with caplog.at_level(logging.INFO, logger="openfisca_survey_manager.policy.simulations"): survey_scenario.summarize_variable(variable="age") messages = [r.message for r in caplog.records] diff --git a/openfisca_survey_manager/simulation_builder.py b/openfisca_survey_manager/simulation_builder.py index 1feaad1f..baf45057 100644 --- a/openfisca_survey_manager/simulation_builder.py +++ b/openfisca_survey_manager/simulation_builder.py @@ -1,291 +1,25 @@ -import logging - -from openfisca_core.model_api import MONTH, YEAR -from openfisca_core.simulations.simulation_builder import SimulationBuilder - -from openfisca_survey_manager.exceptions import SurveyManagerError - -SimulationBuilder.id_variable_by_entity_key = None -SimulationBuilder.role_variable_by_entity_key = None -SimulationBuilder.used_as_input_variables = None -SimulationBuilder.used_as_input_variables_by_entity = None - - -log = logging.getLogger(__name__) - - -# Helpers - - -def diagnose_variable_mismatch(used_as_input_variables, input_data_frame): - """Diagnose variables mismatch. 
- - Args: - used_as_input_variables(lsit): List of variable to test presence - input_data_frame: DataFrame in which to test variables presence - - """ - variables_mismatch = ( - set(used_as_input_variables).difference(set(input_data_frame.columns)) if used_as_input_variables else None - ) - if variables_mismatch: - log.info( - "The following variables used as input variables are not present in the input data frame: " - f"\n {sorted(variables_mismatch)}" - ) - if variables_mismatch: - log.debug(f"The following variables are used as input variables: \n {sorted(used_as_input_variables)}") - log.debug(f"The input_data_frame contains the following variables: \n {sorted(input_data_frame.columns)}") - - -# SimulationBuilder monkey-patched methods - - -def _set_id_variable_by_entity_key(builder) -> dict[str, str]: - """Identify and sets the correct ids for the different entities.""" - if builder.id_variable_by_entity_key is None: - log.debug("Use default id_variable names") - builder.id_variable_by_entity_key = { - entity.key: entity.key + "_id" for entity in builder.tax_benefit_system.entities - } - - return builder.id_variable_by_entity_key - - -def _set_role_variable_by_entity_key(builder) -> dict[str, str]: - """Identify and sets the correct roles for the different entities.""" - if builder.role_variable_by_entity_key is None: - builder.role_variable_by_entity_key = { - entity.key: entity.key + "_role_index" for entity in builder.tax_benefit_system.entities - } - - return builder.role_variable_by_entity_key - - -def _set_used_as_input_variables_by_entity(builder) -> dict[str, list[str]]: - """Identify and sets the correct input variables for the different entities.""" - if builder.used_as_input_variables_by_entity is not None: - return - - tax_benefit_system = builder.tax_benefit_system - - tax_variables = set(tax_benefit_system.variables.keys()) - input_variables = set(builder.used_as_input_variables) - assert input_variables <= tax_variables, ( - "Some variables used as input variables are not part of the tax benefit system:\n " - f"{input_variables.difference(tax_variables)}" - ) - - builder.used_as_input_variables_by_entity = {} - - for entity in tax_benefit_system.entities: - builder.used_as_input_variables_by_entity[entity.key] = [ - variable - for variable in builder.used_as_input_variables - if tax_benefit_system.get_variable(variable).entity.key == entity.key - ] - - return builder.used_as_input_variables_by_entity - - -def filter_input_variables(builder, input_data_frame, tax_benefit_system): - """Filter the input data frame from variables that won't be used or are set to be computed. 
- - Args: - input_data_frame: Input dataframe (Default value = None) - - Returns: - pd.DataFrame: filtered dataframe - - """ - assert input_data_frame is not None - id_variable_by_entity_key = builder.id_variable_by_entity_key - role_variable_by_entity_key = builder.role_variable_by_entity_key - used_as_input_variables = builder.used_as_input_variables - - variables = tax_benefit_system.variables - - id_variables = [id_variable_by_entity_key[_entity.key] for _entity in tax_benefit_system.group_entities] - role_variables = [role_variable_by_entity_key[_entity.key] for _entity in tax_benefit_system.group_entities] - - log.debug(f"Variable used_as_input_variables in filter: \n {used_as_input_variables}") - - unknown_columns = [] - for column_name in input_data_frame: - if column_name in id_variables + role_variables: - continue - if column_name not in variables: - unknown_columns.append(column_name) - - input_data_frame.drop(unknown_columns, axis=1, inplace=True) - - if unknown_columns: - log.debug(f"The following unknown columns {sorted(unknown_columns)}, are dropped from input table") - - used_columns = [] - dropped_columns = [] - for column_name in input_data_frame: - if column_name in id_variables + role_variables: - continue - variable = variables[column_name] - # Keeping the calculated variables that are initialized by the input data - if variable.formulas: - if column_name in used_as_input_variables: - used_columns.append(column_name) - continue - - dropped_columns.append(column_name) - - input_data_frame.drop(dropped_columns, axis=1, inplace=True) - - if used_columns: - log.debug( - f"These columns are not dropped because present in used_as_input_variables:\n {sorted(used_columns)}" - ) - if dropped_columns: - log.debug( - "These columns in survey are set to be calculated, we drop them from the input table:\n " - f"{sorted(dropped_columns)}" - ) - - log.debug(f"Keeping the following variables in the input_data_frame:\n {sorted(input_data_frame.columns)}") - return input_data_frame - - -def init_all_entities(builder, input_data_frame, period=None): - assert period is not None - log.debug(f"Initialasing simulation using input_data_frame for period {period}") - builder._set_id_variable_by_entity_key() - builder._set_role_variable_by_entity_key() - - if period.unit == YEAR: # 1. year - simulation = builder.init_simulation_with_data_frame( - input_data_frame=input_data_frame, - period=period, - ) - elif period.unit == MONTH and period.size == 3: # 2. quarter - for offset in range(period.size): - period_item = period.first_month.offset(offset, MONTH) - simulation = builder.init_simulation_with_data_frame( - input_data_frame=input_data_frame, - period=period_item, - ) - elif period.unit == MONTH and period.size == 1: # 3. months - simulation = builder.init_simulation_with_data_frame( - input_data_frame=input_data_frame, - period=period, - ) - else: - raise SurveyManagerError(f"Invalid period {period}") - - simulation.id_variable_by_entity_key = builder.id_variable_by_entity_key - return simulation - - -def init_entity_structure(builder, entity, input_data_frame): - """Initialize sthe simulation with tax_benefit_system entities and input_data_frame. 
- - Args: - tax_benefit_system(TaxBenfitSystem): The TaxBenefitSystem to get the structure from - entity(Entity): The entity to initialize structure - input_data_frame(pd.DataFrame): The input - builder(Builder): The builder - - """ - tax_benefit_system = builder.tax_benefit_system - builder._set_id_variable_by_entity_key() - builder._set_role_variable_by_entity_key() - builder._set_used_as_input_variables_by_entity() - - input_data_frame = builder.filter_input_variables(input_data_frame, tax_benefit_system) - - id_variables = [builder.id_variable_by_entity_key[_entity.key] for _entity in tax_benefit_system.group_entities] - role_variables = [ - builder.role_variable_by_entity_key[_entity.key] for _entity in tax_benefit_system.group_entities - ] - - if entity.is_person: - for id_variable in id_variables + role_variables: - assert id_variable in input_data_frame.columns, f"Variable {id_variable} is not present in input dataframe" - - ids = range(len(input_data_frame)) - if entity.is_person: - builder.declare_person_entity(entity.key, ids) - for group_entity in tax_benefit_system.group_entities: - _key = group_entity.key - _id_variable = builder.id_variable_by_entity_key[_key] - _role_variable = builder.role_variable_by_entity_key[_key] - group_population = builder.declare_entity( - _key, input_data_frame[_id_variable].drop_duplicates().sort_values().values - ) - builder.join_with_persons( - group_population, - input_data_frame[_id_variable].astype("int").values, - input_data_frame[_role_variable].astype("int").values, - ) - - -def init_simulation_with_data_frame(builder, input_data_frame, period): - """Initialize the simulation period with current input_data_frame for an entity if specified.""" - used_as_input_variables = builder.used_as_input_variables - id_variable_by_entity_key = builder.id_variable_by_entity_key - role_variable_by_entity_key = builder.role_variable_by_entity_key - tax_benefit_system = builder.tax_benefit_system - assert tax_benefit_system is not None - - diagnose_variable_mismatch(used_as_input_variables, input_data_frame) - - id_variables = [id_variable_by_entity_key[_entity.key] for _entity in tax_benefit_system.group_entities] - role_variables = [role_variable_by_entity_key[_entity.key] for _entity in tax_benefit_system.group_entities] - - for id_variable in id_variables + role_variables: - assert id_variable in input_data_frame.columns, f"Variable {id_variable} is not present in input dataframe" - - input_data_frame = builder.filter_input_variables(input_data_frame, tax_benefit_system) - - index_by_entity_key = {} - - for entity in tax_benefit_system.entities: - builder.init_entity_structure(entity, input_data_frame) - - if entity.is_person: - continue - - else: - index_by_entity_key[entity.key] = ( - input_data_frame.loc[ - input_data_frame[role_variable_by_entity_key[entity.key]] == 0, - id_variable_by_entity_key[entity.key], - ] - .sort_values() - .index - ) - - for column_name, column_serie in input_data_frame.items(): - if role_variable_by_entity_key is not None and column_name in role_variable_by_entity_key.values(): - continue - - if id_variable_by_entity_key is not None and column_name in id_variable_by_entity_key.values(): - continue - - simulation = builder.build(tax_benefit_system) - entity = tax_benefit_system.variables[column_name].entity - if entity.is_person: - simulation.init_variable_in_entity(entity.key, column_name, column_serie, period) - else: - simulation.init_variable_in_entity( - entity.key, column_name, 
column_serie[index_by_entity_key[entity.key]], period - ) - - assert builder.id_variable_by_entity_key is not None - simulation.id_variable_by_entity_key = builder.id_variable_by_entity_key - return simulation - - -SimulationBuilder._set_id_variable_by_entity_key = _set_id_variable_by_entity_key -SimulationBuilder._set_role_variable_by_entity_key = _set_role_variable_by_entity_key -SimulationBuilder._set_used_as_input_variables_by_entity = _set_used_as_input_variables_by_entity -SimulationBuilder.filter_input_variables = filter_input_variables -SimulationBuilder.init_all_entities = init_all_entities -SimulationBuilder.init_entity_structure = init_entity_structure -SimulationBuilder.init_simulation_with_data_frame = init_simulation_with_data_frame +"""Re-export for backward compatibility. + +Prefer: from openfisca_survey_manager.policy import ... +or from openfisca_survey_manager.policy.simulation_builder import ... +""" + +import warnings + +from openfisca_survey_manager.policy.simulation_builder import ( + SimulationBuilder, + diagnose_variable_mismatch, +) + +warnings.warn( + "openfisca_survey_manager.simulation_builder is deprecated and will be removed in a future version. " + "Prefer: from openfisca_survey_manager.policy import ... " + "or from openfisca_survey_manager.policy.simulation_builder import ...", + DeprecationWarning, + stacklevel=2, +) + +__all__ = [ + "SimulationBuilder", + "diagnose_variable_mismatch", +] diff --git a/openfisca_survey_manager/simulations.py b/openfisca_survey_manager/simulations.py index 090ce1b8..5d6479f1 100644 --- a/openfisca_survey_manager/simulations.py +++ b/openfisca_survey_manager/simulations.py @@ -1,1465 +1,57 @@ -"""Monkey-patch openfisca_core.simulations.Simulation to work with pandas.""" +"""Re-export for backward compatibility. -import logging -import re -import warnings -from typing import Any, Callable, Optional, Union - -import humanize -import numpy as np -import pandas as pd -from numpy import logical_or as or_ -from openfisca_core import periods -from openfisca_core.indexed_enums import Enum, EnumArray -from openfisca_core.memory_config import MemoryConfig -from openfisca_core.periods import ETERNITY, MONTH, YEAR -from openfisca_core.simulations import Simulation -from openfisca_core.types import Array, Period, TaxBenefitSystem -from openfisca_core.types import CoreEntity as Entity - -from openfisca_survey_manager.exceptions import SurveyManagerError -from openfisca_survey_manager.simulation_builder import SimulationBuilder, diagnose_variable_mismatch -from openfisca_survey_manager.statshelpers import mark_weighted_percentiles -from openfisca_survey_manager.survey_collections import SurveyCollection -from openfisca_survey_manager.utils import do_nothing, load_table - -log = logging.getLogger(__name__) - - -# Helpers - - -def assert_variables_in_same_entity(tax_benefit_system: TaxBenefitSystem, variables: list): - """ - Assert that variables are in the same entity. 
- - Args: - tax_benefit_system (TaxBenefitSystem): Host tax benefit system - variables (List[str]): Variables supposed to belong to the same entity - - Returns: - str: Common entity of the variables - """ - entity = None - for variable_name in variables: - variable = tax_benefit_system.variables.get(variable_name) - assert variable - if entity is None: - entity = variable.entity - assert variable.entity == entity, ( - f"{variables} are not from the same entity: {variable_name} doesn't belong to {entity.key}" - ) - return entity.key - - -def get_words(text: str): - return re.compile("[A-Za-z_]+[A-Za-z0-9_]*").findall(text) - - -# Main functions - - -def adaptative_calculate_variable( - simulation: Simulation, variable: str, period: Optional[Union[int, str, Period]] -) -> Array: - """ - Calculate variable by adpating it definition period to the target period. - - Args: - simulation (Simulation): Simulation to suse - variable (str): Variable to be computed - period (Optional[Union[int, str, Period]]): Target period - - Returns: - Array: Values of the variable on the target period - """ - if not isinstance(period, periods.Period): - period = periods.period(str(period)) - - tax_benefit_system = simulation.tax_benefit_system - assert tax_benefit_system is not None - - assert variable in tax_benefit_system.variables, f"{variable} is not a valid variable" - period_size_independent = tax_benefit_system.get_variable(variable).is_period_size_independent - definition_period = tax_benefit_system.get_variable(variable).definition_period - - if period_size_independent is False and definition_period != "eternity": - values = simulation.calculate_add(variable, period=period) - elif period_size_independent is True and definition_period == "month" and period.size_in_months > 1: - values = simulation.calculate(variable, period=period.first_month) - elif period_size_independent is True and definition_period == "month" and period.size_in_months == 1: - values = simulation.calculate(variable, period=period) - elif period_size_independent is True and definition_period == "year" and period.size_in_months > 12: - values = simulation.calculate(variable, period=period.start.offset("first-of", "year").period("year")) - elif period_size_independent is True and definition_period == "year" and period.size_in_months == 12: - values = simulation.calculate(variable, period=period) - elif period_size_independent is True and definition_period == "year": - values = simulation.calculate(variable, period=period.this_year) - elif definition_period == "eternity": - values = simulation.calculate(variable, period=period) - else: - values = None - assert values is not None, f"Unspecified calculation period for variable {variable}" - - return values - - -def compute_aggregate( - simulation: Simulation, - variable: Optional[str] = None, - aggfunc: str = "sum", - filter_by: Optional[str] = None, - period: Optional[Union[int, str, Period]] = None, - missing_variable_default_value: Any = np.nan, - weighted: bool = True, - alternative_weights: Optional[Union[str, int, float, Array]] = None, - filtering_variable_by_entity: Optional[dict] = None, -) -> Optional[Union[float, int]]: - """ - Compute aggregate of a variable. - - Args: - simulation (Simulation): Simulation to use for the computation - variable (str, optional): Variable to aggregate. Defaults to None. - aggfunc (str, optional): Aggregation function. Defaults to 'sum'. - filter_by (str, optional): Filter variable or expression to use. Defaults to None. 
- period (Optional[Union[int, str, Period]], optional): Period. Defaults to None. - missing_variable_default_value (optional): Value to use for missing values. Defaults to np.nan. - weighted (bool, optional): Whether to weight the variable or not. Defaults to True. - alternative_weights (Optional[Union[str, int, float, Array]], optional): Alternative weigh to use. - Defaults to None. - filtering_variable_by_entity (Dict, optional): Filtering variable by entity. Defaults to None. - - Returns: - float: Aggregate - """ - weight_variable_by_entity = simulation.weight_variable_by_entity - tax_benefit_system = simulation.tax_benefit_system - - if period is None: - period = simulation.period - - assert variable in tax_benefit_system.variables, f"{variable} is not a variable of the tax benefit system" - entity_key = tax_benefit_system.variables[variable].entity.key - - if filter_by is None and filtering_variable_by_entity is not None: - filter_by_variable = filtering_variable_by_entity.get(entity_key) - - if filter_by: - filter_by_variable = get_words(filter_by)[0] - assert filter_by_variable in tax_benefit_system.variables, ( - f"{filter_by_variable} is not a variable of the tax benefit system" - ) - entity_key = tax_benefit_system.variables[variable].entity.key - filter_by_entity_key = tax_benefit_system.variables[filter_by_variable].entity.key - assert filter_by_entity_key == entity_key, ( - f"You tried to compute agregates for variable '{variable}', of entity {entity_key}" - f" filtering by variable '{filter_by_variable}', of entity {filter_by_entity_key}. This is not possible." - f" Please choose a filter-by variable of same entity as '{variable}'." - ) - - expressions = [] - if filter_by is not None: - if filter_by in tax_benefit_system.variables: - filter_entity_key = tax_benefit_system.variables.get(filter_by).entity.key - assert filter_entity_key == entity_key, ( - f"You tried to compute agregates for variable '{variable}', of entity {entity_key} " - f"filtering by variable '{filter_by_variable}', of entity {filter_by_entity_key}. " - "This is not possible. " - f"Please choose a filter-by variable of same entity as '{variable}'." - ) - else: - filter_entity_key = assert_variables_in_same_entity(tax_benefit_system, get_words(filter_by)) - expressions.extend([filter_by]) - assert filter_entity_key == entity_key - else: - filter_dummy = np.array(1.0) - - uniform_weight = np.array(1.0) - weight_variable = None - if weighted: - assert or_(alternative_weights, weight_variable_by_entity), ( - f"The weighted option is set at True but there is no weight variable for entity {entity_key} " - "nor alternative weights. Either define a weight variable or switch to unweighted" - ) - if alternative_weights: - if isinstance(alternative_weights, str): - assert alternative_weights in tax_benefit_system.variables, ( - f"{alternative_weights} is not a valid variable of the tax benefit system" - ) - weight_variable = alternative_weights - - elif (type(alternative_weights) is int) or (type(alternative_weights) is float): - weight_variable = None - uniform_weight = float(alternative_weights) - elif weight_variable_by_entity: - weight_variable = weight_variable_by_entity[entity_key] - - if variable in simulation.tax_benefit_system.variables: - value = simulation.adaptative_calculate_variable(variable=variable, period=period) - else: - log.debug(f"Variable {variable} not found. 
Assigning {missing_variable_default_value}") - return missing_variable_default_value - - weight = ( - simulation.adaptative_calculate_variable(weight_variable, period=period).astype(float) - if weight_variable - else uniform_weight - ) - if weight_variable: - assert any(weight != 0), "Weights shall not be all zeroes" - else: - assert uniform_weight != 0 - - if filter_by is not None: - expression_data_frame = simulation.create_data_frame_by_entity( - variables=get_words(filter_by), period=period, index=False - )[entity_key] - for expression in expressions: - expression_data_frame[expression] = expression_data_frame.eval(expression) - - filter_dummy = expression_data_frame[filter_by] - else: - filter_dummy = 1.0 - - if aggfunc == "sum": - aggregate = (value * weight * filter_dummy).sum() - elif aggfunc == "mean": - aggregate = (value * weight * filter_dummy).sum() / (weight * filter_dummy).sum() - elif aggfunc == "count": - aggregate = (weight * filter_dummy).sum() - elif aggfunc == "count_non_zero": - aggregate = (weight * (value != 0) * filter_dummy).sum() - else: - aggregate = None - - return aggregate - - -def compute_quantiles( - simulation: Simulation, - variable: str, - nquantiles: Optional[int] = None, - period: Optional[Union[int, str, Period]] = None, - filter_by: Optional[str] = None, - weighted: bool = True, - alternative_weights: Optional[Union[str, int, float, Array]] = None, - filtering_variable_by_entity: Optional[dict] = None, -) -> list[float]: - """ - Compute quantiles of a variable. - - Args: - simulation (Simulation, optional): Simulation to be used. Defaults to None. - variable (str, optional): Variable which quantiles are computed. Defaults to None. - nquantiles (int, optional): Number of quantiles. Defaults to None. - period (Optional[Union[int, str, Period]], optional): Period. Defaults to None. - missing_variable_default_value (optional): Value to use for missing values. Defaults to np.nan. - weighted (bool, optional): Whether to weight the variable or not. Defaults to True. - alternative_weights (Optional[Union[str, int, float, Array]], optional): Alternative weigh to use. - Defaults to None. - filtering_variable_by_entity (Dict, optional): Filtering variable by entity. Defaults to None. 
- - Returns: - List(float): The quantiles values - """ - weight_variable_by_entity = simulation.weight_variable_by_entity - weight_variable = None - entity_key = simulation.tax_benefit_system.variables[variable].entity.key - if weight_variable_by_entity: - weight_variable = weight_variable_by_entity[entity_key] - - variable_values = simulation.adaptative_calculate_variable(variable, period) - if weighted: - assert (alternative_weights is not None) or (weight_variable is not None) - weight = ( - alternative_weights if alternative_weights is not None else simulation.calculate(weight_variable, period) - ) - else: - weight = np.ones(len(variable_values)) - - if filtering_variable_by_entity is not None and filter_by is None: - filter_by = filtering_variable_by_entity.get(entity_key) - - if filter_by is not None: - filter_entity_key = simulation.tax_benefit_system.variables.get(filter_by).entity.key - assert filter_entity_key == entity_key - filter_dummy = simulation.calculate(filter_by, period=period).astype(bool) - - variable_values = variable_values[filter_dummy].copy() - weight = weight[filter_dummy].copy() - - labels = np.arange(1, nquantiles + 1) - method = 2 - _, values = mark_weighted_percentiles(variable_values, labels, weight, method, return_quantiles=True) - return values - - -def compute_pivot_table( - simulation: Simulation = None, - baseline_simulation: Simulation = None, - aggfunc: str = "mean", - columns: Optional[list[str]] = None, - difference: bool = False, - filter_by: Optional[str] = None, - index: Optional[list[str]] = None, - period: Optional[Union[int, str, Period]] = None, - use_baseline_for_columns: Optional[bool] = None, - values: Optional[list[str]] = None, - missing_variable_default_value: Any = np.nan, - concat_axis: Optional[int] = None, - weighted: bool = True, - alternative_weights: Optional[Union[str, int, float, Array]] = None, - filtering_variable_by_entity: Optional[dict] = None, -): - """ - Compute pivot table. - - Args: - simulation (Simulation, optional): Main simulation. Defaults to None. - baseline_simulation (Simulation, optional): Baseline simulation. Defaults to None. - aggfunc (str, optional): Aggregation function. Defaults to 'mean'. - columns (List[str], optional): Variables to use in columns. Defaults to None. - difference (bool, optional): Whether to compute the difference with baseline. Defaults to False. - filter_by (str, optional): Filter variable or expression to use. Defaults to None. - index (List[str], optional): _description_. Defaults to None. - period (Optional[Union[int, str, Period]], optional): Period. Defaults to None. - use_baseline_for_columns (bool, optional): _description_. Defaults to None. - values (List[str], optional): _description_. Defaults to None. - missing_variable_default_value (optional): _description_. Defaults to np.nan. - concat_axis (int, optional): _description_. Defaults to None. - weighted (bool, optional): Whether to weight the variable or not. Defaults to True. - alternative_weights (Optional[Union[str, int, float, Array]], optional): Alternative weigh to use. - Defaults to None. - filtering_variable_by_entity (Dict, optional): Filtering variable by entity. Defaults to None. 
- - Returns: - _type_: _description_ - """ - weight_variable_by_entity = simulation.weight_variable_by_entity - - admissible_aggfuncs = ["max", "mean", "min", "sum", "count", "sum_abs"] - assert aggfunc in admissible_aggfuncs - assert columns or index or values - - if baseline_simulation is not None: - tax_benefit_system = baseline_simulation.tax_benefit_system - else: - tax_benefit_system = simulation.tax_benefit_system - - assert period is not None - - if isinstance(columns, str): - columns = [columns] - elif columns is None: - columns = [] - assert isinstance(columns, list) - - if isinstance(index, str): - index = [index] - elif index is None: - index = [] - assert isinstance(index, list) - - if isinstance(values, str): - values = [values] - elif values is None: - values = [] - assert isinstance(values, list) - - entity_key = None - for axe in [columns, index, values]: - if (len(axe) != 0) and (entity_key is None): - entity_key = tax_benefit_system.variables[axe[0]].entity.key - continue - - if filter_by is None and filtering_variable_by_entity is not None: - filter_by = filtering_variable_by_entity.get(entity_key) - - variables = set(index + columns) - - # Select the entity weight corresponding to the variables that will provide values - uniform_weight = 1.0 - weight_variable = None - if weighted: - if alternative_weights: - if isinstance(alternative_weights, str): - assert alternative_weights in tax_benefit_system.variables, ( - f"{alternative_weights} is not a valid variable of the tax benefit system" - ) - weight_variable = alternative_weights - - elif (type(alternative_weights) is int) or (type(alternative_weights) is float): - weight_variable = None - uniform_weight = float(alternative_weights) - - else: - if weight_variable_by_entity: - weight_variable = weight_variable_by_entity[entity_key] - variables.add(weight_variable) - - else: - log.warn( - f"There is no weight variable for entity {entity_key} nor alternative weights. 
" - "Switch to unweighted" - ) - - expressions = [] - if filter_by is not None: - if filter_by in tax_benefit_system.variables: - variables.add(filter_by) - filter_entity_key = tax_benefit_system.variables.get(filter_by).entity.key - assert filter_entity_key == entity_key - else: - filter_entity_key = assert_variables_in_same_entity(tax_benefit_system, get_words(filter_by)) - expressions.extend([filter_by]) - assert filter_entity_key == entity_key - else: - filter_dummy = np.array(1.0) - - for expression in expressions: - expression_variables = get_words(expression) - expression_entity_key = assert_variables_in_same_entity(tax_benefit_system, expression_variables) - assert expression_entity_key == entity_key - for variable in expression_variables: - variables.add(variable) - - for variable in variables | set(values): - if variable in tax_benefit_system.variables: - assert tax_benefit_system.variables[variable].entity.key == entity_key, ( - f"The variable {variable} does not belong to entity {entity_key}" - ) - - if difference: - assert simulation is not None and baseline_simulation is not None - reform_data_frame = simulation.create_data_frame_by_entity(values, period=period, index=False)[ - entity_key - ].fillna(missing_variable_default_value) - baseline_data_frame = baseline_simulation.create_data_frame_by_entity(values, period=period, index=False)[ - entity_key - ].fillna(missing_variable_default_value) - for value_variable in values: - if value_variable not in baseline_data_frame: - baseline_data_frame[value_variable] = missing_variable_default_value - if value_variable not in reform_data_frame: - reform_data_frame[value_variable] = missing_variable_default_value - - data_frame = reform_data_frame - baseline_data_frame - - else: - if values: - data_frame = simulation.create_data_frame_by_entity(values, period=period, index=False)[entity_key] - for value_variable in values: - if value_variable not in data_frame: - data_frame[value_variable] = missing_variable_default_value - else: - data_frame = None - - use_baseline_data = difference or use_baseline_for_columns - - # use baseline if explicited or when computing difference - if use_baseline_data: - baseline_vars_data_frame = baseline_simulation.create_data_frame_by_entity( - variables=variables, period=period, index=False - )[entity_key] - else: - baseline_vars_data_frame = simulation.create_data_frame_by_entity( - variables=variables, period=period, index=False - )[entity_key] - - for expression in expressions: - baseline_vars_data_frame[expression] = baseline_vars_data_frame.eval(expression) - if filter_by is not None: - filter_dummy = baseline_vars_data_frame[filter_by] - if weight_variable is None: - weight_variable = "weight" - baseline_vars_data_frame[weight_variable] = uniform_weight - baseline_vars_data_frame[weight_variable] = baseline_vars_data_frame[weight_variable] * filter_dummy - # We drop variables that are in values from baseline_vars_data_frame - dropped_columns = [column for column in baseline_vars_data_frame.columns if column in values] - baseline_vars_data_frame.drop(columns=dropped_columns, inplace=True) - - data_frame = pd.concat( - [baseline_vars_data_frame, data_frame], - axis=1, - ) - - if values: - data_frame_by_value = {} - for value in values: - if aggfunc in ["mean", "sum", "sum_abs", "count"]: - data_frame[value] = ( - data_frame[value] * data_frame[weight_variable] - if aggfunc != "sum_abs" - else data_frame[value].abs() * data_frame[weight_variable] - ) - data_frame[value] = 
data_frame[value].fillna(missing_variable_default_value) - pivot_sum = data_frame.pivot_table(index=index, columns=columns, values=value, aggfunc="sum") - pivot_mass = data_frame.pivot_table( - index=index, columns=columns, values=weight_variable, aggfunc="sum" - ) - if aggfunc == "mean": - try: # Deal with a pivot_table pandas bug https://github.com/pandas-dev/pandas/issues/17038 - result = pivot_sum / pivot_mass.loc[weight_variable] - except KeyError: - result = pivot_sum / pivot_mass - elif aggfunc in ["sum", "sum_abs"]: - result = pivot_sum - elif aggfunc == "count": - result = pivot_mass.rename(columns={weight_variable: value}, index={weight_variable: value}) - - elif aggfunc in ["min", "max"]: - data_frame[value].fillna(missing_variable_default_value, inplace=True) - result = data_frame.pivot_table(index=index, columns=columns, values=value, aggfunc=aggfunc) - - data_frame_by_value[value] = result - - if len(list(data_frame_by_value.keys())) > 1: - if concat_axis is None: - return data_frame_by_value - else: - assert concat_axis in [0, 1] - return pd.concat(data_frame_by_value.values(), axis=concat_axis) - else: - return next(iter(data_frame_by_value.values())) - - else: - assert aggfunc == "count", "Can only use count for aggfunc if no values" - return data_frame.pivot_table(index=index, columns=columns, values=weight_variable, aggfunc="sum") - - -def create_data_frame_by_entity( - simulation: Simulation, - variables: Optional[list[str]] = None, - expressions: Optional[list[str]] = None, - filter_by: Optional[str] = None, - index: bool = False, - period: Optional[Union[int, str, Period]] = None, - merge: bool = False, -) -> Union[pd.DataFrame, dict[str, pd.DataFrame]]: - """ - Create dataframe(s) of variables for the whole selected population. - - Args: - simulation (Simulation): Simulation to use. - variables (Optional[List], optional): Variables to retrieve, None means all. Defaults to None. - expressions (Optional[List[str]], optional): _description_. Defaults to None. - filter_by (str, optional): Filter variable or expression to use. Defaults to None. - index (bool, optional): Whether to use index (id) variables. Defaults to False. - period (Optional[Union[int, str, Period]], optional): Period of the computation. Defaults to None. - merge (bool, optional): Wheter to merge the datafrales into one. Defaults to False. 
- - Returns: - pd.DataFrame of Dict: Dataframe(s) with the variables values - """ - assert simulation is not None - id_variable_by_entity_key = simulation.id_variable_by_entity_key - tax_benefit_system = simulation.tax_benefit_system - assert tax_benefit_system is not None - - if period is None: - period = simulation.period - - assert variables or index or expressions or filter_by - - if merge: - index = True - if expressions is None: - expressions = [] - - if filter_by is not None: - if filter_by in tax_benefit_system.variables: - variables.append(filter_by) - filter_entity_key = tax_benefit_system.variables.get(filter_by).entity.key - else: - filter_entity_key = assert_variables_in_same_entity(tax_benefit_system, get_words(filter_by)) - expressions.append(filter_by) - - expressions_by_entity_key = {} - for expression in expressions: - expression_variables = get_words(expression) - entity_key = assert_variables_in_same_entity(tax_benefit_system, expression_variables) - if entity_key in expressions_by_entity_key: - expressions_by_entity_key[entity_key].append(expression) - else: - expressions_by_entity_key[entity_key] = [expression] - variables += expression_variables - - variables = set(variables) - - missing_variables = set(variables).difference(set(tax_benefit_system.variables.keys())) - if missing_variables: - log.info(f"These variables aren't part of the tax-benefit system: {missing_variables} and thus ignored") - - columns_to_fetch = [ - tax_benefit_system.variables.get(variable_name) - for variable_name in variables - if tax_benefit_system.variables.get(variable_name) is not None - ] - - tax_vars = list(tax_benefit_system.variables.keys()) - assert len(columns_to_fetch) >= 1, ( - f"None of the requested variables {variables} are in the tax-benefit-system {tax_vars}" - ) - - assert simulation is not None - - openfisca_data_frame_by_entity_key = {} - non_person_entities = [] - - for entity in tax_benefit_system.entities: - entity_key = entity.key - column_names = [column.name for column in columns_to_fetch if column.entity.key == entity_key] - openfisca_data_frame_by_entity_key[entity_key] = pd.DataFrame( - { - column_name: simulation.adaptative_calculate_variable(column_name, period=period) - for column_name in column_names - } - ) - if entity.is_person: - person_entity = entity - else: - non_person_entities.append(entity) - - if index: - person_data_frame = openfisca_data_frame_by_entity_key.get(person_entity.key) - person_data_frame.index.name = id_variable_by_entity_key.get("person", "person_id") - if person_data_frame is None: - person_data_frame = pd.DataFrame() - for entity in non_person_entities: - entity_key_id = id_variable_by_entity_key[entity.key] - person_data_frame[entity_key_id] = simulation.populations[entity.key].members_entity_id - flattened_roles = entity.flattened_roles - index_by_role = {flattened_roles[index]: index for index in range(len(flattened_roles))} - person_data_frame["{}_{}".format(entity.key, "role")] = pd.Series( - simulation.populations[entity.key].members_role - ).map(index_by_role) - person_data_frame["{}_{}".format(entity.key, "position")] = simulation.populations[ - entity.key - ].members_position +Prefer: from openfisca_survey_manager.policy import ... +or from openfisca_survey_manager.policy.simulations import ... 
+""" - # Set index names as entity_id - openfisca_data_frame_by_entity_key[entity.key].index.name = entity_key_id - openfisca_data_frame_by_entity_key[entity.key].reset_index(inplace=True) - person_data_frame.reset_index(inplace=True) - - for entity_key, expressions in expressions_by_entity_key.items(): - data_frame = openfisca_data_frame_by_entity_key[entity_key] - for expression in expressions: - data_frame[expression] = data_frame.eval(expression) - - if filter_by is not None: - openfisca_data_frame_by_entity_key[filter_entity_key] = ( - openfisca_data_frame_by_entity_key[filter_entity_key] - .loc[openfisca_data_frame_by_entity_key[filter_entity_key][filter_by]] - .copy() - ) - - if not merge: - return openfisca_data_frame_by_entity_key - else: - for entity_key, openfisca_data_frame in openfisca_data_frame_by_entity_key.items(): - if entity_key != person_entity.key: - entity_key_id = id_variable_by_entity_key[entity_key] - if len(openfisca_data_frame) > 0: - person_data_frame = person_data_frame.merge( - openfisca_data_frame.reset_index(), - left_on=entity_key_id, - right_on=entity_key_id, - ) - return person_data_frame - - -class SecretViolationError(Exception): - """Raised if the result of the simulation do not comform with regulators rules.""" - - pass - - -def compute_winners_losers( - simulation: Simulation, - baseline_simulation: Simulation, - variable: str, - filter_by: Optional[str] = None, - period: Optional[Union[int, str, Period]] = None, - absolute_minimal_detected_variation: float = 0, - relative_minimal_detected_variation: float = 0.01, - observations_threshold: Optional[int] = None, - weighted: bool = True, - alternative_weights: Optional[Union[str, int, float, Array]] = None, - filtering_variable_by_entity: Optional[dict] = None, -) -> dict[str, Union[int, float]]: - """ - Compute the number of winners and losers for a given variable. - - Args: - simulation (_type_): The main simulation. - baseline_simulation (_type_): The baseline simulation - variable (str): The variable to use. - filter_by (str, optional): The variable or expression to be used as a filter. Defaults to None. - period (Optional[Union[int, str, Period]], optional): The period of the simulation. Defaults to None. - absolute_minimal_detected_variation (float, optional): Absolute minimal variation to be detected, in ratio. - Ie 0.5 means 5% of variation wont be counted.. - relative_minimal_detected_variation (float, optional): Relative minimal variation to be detected, in ratio. - Defaults to .01. - observations_threshold (int, optional): Number of observations needed to avoid a statistical secret violation. - Defaults to None. - weighted (bool, optional): Whether to use weights. Defaults to True. - alternative_weights (Optional[Union[str, int, float, Array]], optional): Alternative weigh to use. - Defaults to None. - filtering_variable_by_entity (_type_, optional): The variable to be used as a filter for each entity. - Defaults to None. - - Raises: - SecretViolationError: Raised when statistical secret is violated. - - Returns: - Dict[str, int]: Statistics about winners and losers between the main simulation and the baseline. 
- """ - weight_variable_by_entity = simulation.weight_variable_by_entity - entity_key = baseline_simulation.tax_benefit_system.variables[variable].entity.key - - # Get the results of the simulation - after = simulation.adaptative_calculate_variable(variable, period=period) - before = baseline_simulation.adaptative_calculate_variable(variable, period=period) - - # Filter if needed - if filtering_variable_by_entity is not None and filter_by is None: - filter_by = filtering_variable_by_entity.get(entity_key) - - if filter_by is not None: - filter_entity_key = baseline_simulation.tax_benefit_system.variables.get(filter_by).entity.key - assert filter_entity_key == entity_key - filter_dummy = baseline_simulation.calculate(filter_by, period=period).astype(bool) - - after = after[filter_dummy].copy() - before = before[filter_dummy].copy() - - # Define weights - weight = np.ones(len(after)) - if weighted: - if alternative_weights is not None: - weight = alternative_weights - elif weight_variable_by_entity is not None: - weight_variable = weight_variable_by_entity[entity_key] - weight = baseline_simulation.calculate(weight_variable, period=period) - else: - log.warn( - f"There is no weight variable for entity {entity_key} nor alternative weights. Switch to unweighted" - ) - - # Compute the weigthed number of zeros or non zeros - value_by_simulation = {"after": after, "before": before} - stats_by_simulation = {} - for simulation_prefix, value in value_by_simulation.items(): - stats = {} - stats["count_zero"] = (weight.astype("float64") * (absolute_minimal_detected_variation > np.abs(value))).sum() - stats["count_non_zero"] = sum(weight.astype("float64")) - stats["count_zero"] - stats_by_simulation[simulation_prefix] = stats - del stats - - # Compute the number of entity above or below after - after_value = after - before_value = before - with np.errstate(divide="ignore", invalid="ignore"): - above_after = ((after_value - before_value) / np.abs(before_value)) > relative_minimal_detected_variation - almost_zero_before = np.abs(before_value) < absolute_minimal_detected_variation - above_after[almost_zero_before * (after_value >= 0)] = (after_value >= absolute_minimal_detected_variation)[ - almost_zero_before * (after_value >= 0) - ] - with np.errstate(divide="ignore", invalid="ignore"): - below_after = ((after_value - before_value) / np.abs(before_value)) < -relative_minimal_detected_variation - below_after[almost_zero_before * (after_value < 0)] = (after_value < -absolute_minimal_detected_variation)[ - almost_zero_before * (after_value < 0) - ] - - # Check if there is a secret violation, without weights - if observations_threshold is not None: - not_legit_below = (below_after.sum() < observations_threshold) & (below_after.sum() > 0) - not_legit_above = (above_after.sum() < observations_threshold) & (above_after.sum() > 0) - if not_legit_below | not_legit_above: - raise SecretViolationError("Not enough observations involved") - - # Apply weights - above_after_count = (above_after.astype("float64") * weight.astype("float64")).sum() - below_after_count = (below_after.astype("float64") * weight.astype("float64")).sum() - total = sum(weight) - neutral = total - above_after_count - below_after_count - - return { - "total": total, - "non_zero_before": stats_by_simulation["before"]["count_non_zero"], - "non_zero_after": stats_by_simulation["after"]["count_non_zero"], - "above_after": above_after_count, - "lower_after": below_after_count, - "neutral": neutral, - "tolerance_factor_used": 
relative_minimal_detected_variation, - "weight_factor": 1, - } - - -def init_entity_data( - simulation: Simulation, - entity: Entity, - filtered_input_data_frame: pd.DataFrame, - period: Period, - used_as_input_variables_by_entity: dict, -): - """ - Initialize entity in simulation at some period with input provided by a dataframe. - - Args: - simulation (Simulation): The simulation to initialize. - entity (Entity): The entity which variables to initialize. - filtered_input_data_frame (pd.DataFrame): The dataframe with the variables values. - period (Period): The period to initialize. - used_as_input_variables_by_entity (Dict): The variable to be used to initialize each entity. - """ - used_as_input_variables = used_as_input_variables_by_entity[entity.key] - input_data_frame = filtered_input_data_frame - # input_data_frame = self.filter_input_variables(input_data_frame = input_data_frame) - diagnose_variable_mismatch(used_as_input_variables, input_data_frame) - - for column_name, column_serie in input_data_frame.items(): - variable_instance = simulation.tax_benefit_system.variables.get(column_name) - if variable_instance is None: - log.debug(f"Ignoring {column_name} in input data") - continue - - if variable_instance.entity.key != entity.key: - log.debug( - f"Ignoring variable {column_name} which is not part of entity {entity.key} " - f"but {variable_instance.entity.key}" - ) - continue - init_variable_in_entity(simulation, entity.key, column_name, column_serie, period) - - -def inflate( - simulation: Simulation, - inflator_by_variable: Optional[dict] = None, - period: Optional[Union[int, str, Period]] = None, - target_by_variable: Optional[dict] = None, -) -> None: - tax_benefit_system = simulation.tax_benefit_system - for variable_name in set(inflator_by_variable.keys()).union(set(target_by_variable.keys())): - assert variable_name in tax_benefit_system.variables, ( - f"Variable {variable_name} is not a valid variable of the tax-benefit system" - ) - if variable_name in target_by_variable: - inflator = inflator_by_variable[variable_name] = target_by_variable[ - variable_name - ] / simulation.compute_aggregate(variable=variable_name, period=period) - log.debug( - f"Using {inflator} as inflator for {variable_name} to reach the target " - f"{target_by_variable[variable_name]} " - ) - else: - assert variable_name in inflator_by_variable, "variable_name is not in inflator_by_variable" - target = inflator_by_variable[variable_name] * simulation.compute_aggregate( - variable=variable_name, period=period - ) - log.debug( - f"Using inflator {inflator_by_variable[variable_name]} for {variable_name}. 
" - f"The target is thus {target}" - ) - inflator = inflator_by_variable[variable_name] - - array = simulation.calculate_add(variable_name, period=period) - assert array is not None - simulation.delete_arrays(variable_name, period=period) # delete existing arrays - simulation.set_input(variable_name, period, inflator * array) # insert inflated array - - -def _load_table_for_survey( - config_files_directory: str, - collection: str, - survey: str, - table: str, - batch_size: Optional[int] = None, - batch_index: Optional[int] = None, - filter_by: Optional[str] = None, -): - if survey is not None: - input_data_frame = load_table( - config_files_directory=config_files_directory, - collection=collection, - survey=survey, - table=table, - batch_size=batch_size, - batch_index=batch_index, - filter_by=filter_by, - ) - else: - input_data_frame = load_table( - config_files_directory=config_files_directory, - collection=collection, - survey="input", - table=table, - batch_size=batch_size, - batch_index=batch_index, - filter_by=filter_by, - ) - return input_data_frame - - -def _input_data_table_by_entity_by_period_monolithic( - tax_benefit_system: TaxBenefitSystem, - simulation: Simulation, - period: Period, - input_data_table_by_entity: dict, - builder: SimulationBuilder, - custom_input_data_frame: Callable, - config_files_directory: str, - collection: str, - survey: Optional[str] = None, -): - """ - Initialize simulation with input data from a table for each entity and period. - """ - period = periods.period(period) - simulation_datasets = {} - entities = tax_benefit_system.entities - for entity in entities: - # Read all tables for the entity - log.debug(f"init_simulation - {period=} {entity.key=}") - table = input_data_table_by_entity.get(entity.key) - filter_by = input_data_table_by_entity.get("filter_by") - if table is None: - continue - input_data_frame = _load_table_for_survey( - config_files_directory, collection, survey, table, filter_by=filter_by - ) - simulation_datasets[entity.key] = input_data_frame - - if simulation is None: - # Instantiate simulation only for the fist period - # Next period will reuse the same simulation - for entity in entities: - table = input_data_table_by_entity.get(entity.key) - if table is None: - continue - custom_input_data_frame(input_data_frame, period=period, entity=entity.key) - builder.init_entity_structure(entity, simulation_datasets[entity.key]) # TODO complete args - simulation = builder.build(tax_benefit_system) - simulation.id_variable_by_entity_key = ( - builder.id_variable_by_entity_key - ) # Should be propagated to enhanced build - - for entity in entities: - # Load data in the simulation - table = input_data_table_by_entity.get(entity.key) - if table is None: - continue - log.debug(f"init_simulation - {entity.key=} {len(input_data_frame)=}") - simulation.init_entity_data( - entity, simulation_datasets[entity.key], period, builder.used_as_input_variables_by_entity - ) - del simulation_datasets[entity.key] - return simulation - - -def _input_data_table_by_entity_by_period_batch( - tax_benefit_system: TaxBenefitSystem, - simulation: Simulation, - period: Period, - input_data_table_by_entity: dict, - builder: SimulationBuilder, - custom_input_data_frame: Callable, - config_files_directory: str, - collection: str, - survey: Optional[str] = None, -): - """ - Initialize simulation with input data from a table for each entity and period. 
- """ - period = periods.period(period) - batch_size = input_data_table_by_entity.get("batch_size") - batch_index = input_data_table_by_entity.get("batch_index", 0) - batch_entity = input_data_table_by_entity.get("batch_entity") - batch_entity_key = input_data_table_by_entity.get("batch_entity_key") - filtered_entity = input_data_table_by_entity.get("filtered_entity") - filtered_entity_on_key = input_data_table_by_entity.get("filtered_entity_on_key") - if not batch_entity or not batch_entity_key or not filtered_entity or not filtered_entity_on_key: - raise SurveyManagerError( - "batch_entity, batch_entity_key, filtered_entity and filtered_entity_on_key are required" - ) - simulation_datasets = { - batch_entity: { - "table_key": batch_entity_key, - "input_data_frame": None, - "entity": None, - }, - filtered_entity: { - "table_key": filtered_entity_on_key, - "input_data_frame": None, - "entity": None, - }, - } - batch_entity_ids = None - entities = tax_benefit_system.entities - - if len(entities) > 2: - # Batch mode could work only with batch_entity and filtered_entity, and no others - warnings.warn( - "survey-manager.simulation._input_data_table_by_entity_by_period_batch : " - f"Your TaxBenefitSystem has {len(entities)} entities but we will only load " - f"{batch_entity} and {filtered_entity}.", - stacklevel=2, - ) - - for entity_name, entity_data in simulation_datasets.items(): - # Find Identity object from TaxBenefitSystem - for entity in entities: - if entity.key == entity_name: - entity_data["entity"] = entity - break - - # Load the batch entity - table = input_data_table_by_entity[batch_entity] - input_data_frame = _load_table_for_survey( - config_files_directory, collection, survey, table, batch_size, batch_index - ) - batch_entity_ids = input_data_frame[batch_entity_key].to_list() - simulation_datasets[batch_entity]["input_data_frame"] = input_data_frame - - # Load the filtered entity - table = input_data_table_by_entity[filtered_entity] - filter_by = [(filtered_entity_on_key, "in", batch_entity_ids)] - input_data_frame = _load_table_for_survey(config_files_directory, collection, survey, table, filter_by=filter_by) - simulation_datasets[filtered_entity]["input_data_frame"] = input_data_frame - - if simulation is None: - for entity_name, entity_data in simulation_datasets.items(): - custom_input_data_frame(entity_data["input_data_frame"], period=period, entity=entity_name) - builder.init_entity_structure(entity_data["entity"], entity_data["input_data_frame"]) - simulation = builder.build(tax_benefit_system) - simulation.id_variable_by_entity_key = ( - builder.id_variable_by_entity_key - ) # Should be propagated to enhanced build - for _entity_name, entity_data in simulation_datasets.items(): - simulation.init_entity_data( - entity_data["entity"], entity_data["input_data_frame"], period, builder.used_as_input_variables_by_entity - ) - return simulation - - -def init_simulation( - tax_benefit_system: TaxBenefitSystem, - period: Union[str, int, Period], - data: dict, -): - builder = SimulationBuilder() - builder.create_entities(tax_benefit_system) - - collection = data.get("collection") - custom_input_data_frame = data.get("custom_input_data_frame", do_nothing) - data_year = data.get("data_year") - survey = data.get("survey") - config_files_directory = data.get("config_files_directory") - builder.used_as_input_variables = data.get("used_as_input_variables") - builder.id_variable_by_entity_key = data.get("id_variable_by_entity_key") - builder.role_variable_by_entity_key = 
data.get("role_variable_by_entity_key") - builder.tax_benefit_system = tax_benefit_system - - default_source_types = [ - "input_data_frame", - "input_data_table", - "input_data_frame_by_entity", - "input_data_frame_by_entity_by_period", - "input_data_table_by_entity_by_period", - "input_data_table_by_period", - ] - source_types = [source_type_ for source_type_ in default_source_types if data.get(source_type_) is not None] - assert len(source_types) < 2, "There are too many data source types" - assert len(source_types) >= 1, f"There should be one data source type included in {default_source_types}" - source_type = source_types[0] - source = data[source_type] - - if source_type == "input_data_frame_by_entity": - assert data_year is not None - source_type = "input_data_frame_by_entity_by_period" - source = {periods.period(data_year): source} - - input_data_survey_prefix = data.get("input_data_survey_prefix") if data is not None else None - - if source_type == "input_data_frame": - simulation = builder.init_all_entities(source, period) - - if source_type == "input_data_table": - # Case 1: fill simulation with a unique input_data_frame given by the attribute - if input_data_survey_prefix is not None: - openfisca_survey_collection = SurveyCollection.load(collection=collection) - openfisca_survey = openfisca_survey_collection.get_survey(f"{input_data_survey_prefix}_{data_year}") - input_data_frame = openfisca_survey.get_values(table="input").reset_index(drop=True) - else: - raise NotImplementedError - - custom_input_data_frame(input_data_frame, period=period) - simulation = builder.init_all_entities(input_data_frame, builder, period) # monolithic dataframes - - elif source_type == "input_data_table_by_period": - # Case 2: fill simulation with input_data_frame by period containing all entity variables - input_data_table_by_period = data.get("input_data_table_by_period") - for period, table in input_data_table_by_period.items(): - period = periods.period(period) - log.debug(f"From survey {survey} loading table {table}") - input_data_frame = load_table( - config_files_directory=config_files_directory, - collection=collection, - survey=survey, - input_data_survey_prefix=input_data_survey_prefix, - table=table, - ) - custom_input_data_frame(input_data_frame, period=period) - simulation = builder.init_all_entities(input_data_frame, builder, period) # monolithic dataframes - - elif source_type == "input_data_frame_by_entity_by_period": - for period, input_data_frame_by_entity in source.items(): - period = periods.period(period) - for entity in tax_benefit_system.entities: - input_data_frame = input_data_frame_by_entity.get(entity.key) - if input_data_frame is None: - continue - custom_input_data_frame(input_data_frame, period=period, entity=entity.key) - builder.init_entity_structure(entity, input_data_frame) # TODO complete args - - simulation = builder.build(tax_benefit_system) - simulation.id_variable_by_entity_key = ( - builder.id_variable_by_entity_key - ) # Should be propagated to enhanced build - - for period, input_data_frame_by_entity in source.items(): - for entity in tax_benefit_system.entities: - input_data_frame = input_data_frame_by_entity.get(entity.key) - if input_data_frame is None: - log.debug(f"No input_data_frame found for entity {entity} at period {period}") - continue - custom_input_data_frame(input_data_frame, period=period, entity=entity.key) - simulation.init_entity_data( - entity, input_data_frame, period, builder.used_as_input_variables_by_entity - ) - - elif source_type == 
"input_data_table_by_entity_by_period": - # Case 3: fill simulation with input_data_table by entity_by_period containing a dictionnary - # of all periods containing a dictionnary of entity variables - input_data_table_by_entity_by_period = source - simulation = None - for period, input_data_table_by_entity in input_data_table_by_entity_by_period.items(): - if input_data_table_by_entity.get("batch_size"): - simulation = _input_data_table_by_entity_by_period_batch( - tax_benefit_system, - simulation, - period, - input_data_table_by_entity, - builder, - custom_input_data_frame, - config_files_directory, - collection, - survey, - ) - else: - simulation = _input_data_table_by_entity_by_period_monolithic( - tax_benefit_system, - simulation, - period, - input_data_table_by_entity, - builder, - custom_input_data_frame, - config_files_directory, - collection, - survey, - ) - - else: - pass - - if data_year is not None: - simulation.period = periods.period(data_year) - - return simulation - - -def init_variable_in_entity( - simulation: Simulation, - entity: str, - variable_name: str, - series: pd.Series, - period: Period, -): - variable = simulation.tax_benefit_system.variables[variable_name] - - # np.issubdtype cannot handles categorical variables - if (not isinstance(series.dtype, pd.CategoricalDtype)) and pd.api.types.is_float_dtype(series.values.dtype): - if series.isnull().any(): - log.debug( - f"There are {series.isnull().sum()} NaN values for " - f"{series.notnull().sum()} non NaN values in variable {variable_name}" - ) - log.debug( - f"We convert NaN values of variable {variable_name} to {variable.default_value} its default value" - ) - series = pd.to_numeric(series).fillna(variable.default_value).astype(variable.value_type) - assert series.notnull().all(), ( - f"There are {series.isnull().sum()} NaN values for " - f"{series.notnull().sum()} non NaN values in variable {variable_name}" - ) - - enum_variable_imputed_as_enum = variable.value_type == Enum and ( - isinstance(series.dtype, pd.CategoricalDtype) - or not (np.issubdtype(series.values.dtype, np.integer) or np.issubdtype(series.values.dtype, float)) - ) - - if enum_variable_imputed_as_enum: - if series.isnull().any(): - log.debug( - f"There are {series.isnull().sum()} NaN values " - f"({series.isnull().mean() * 100}% of the array) in variable {variable_name}" - ) - log.debug( - f"We convert these NaN values of variable {variable_name} to " - f"{variable.default_value._name_} its default value" - ) - series.fillna(variable.default_value._name_, inplace=True) - possible_values = variable.possible_values - if isinstance(series.dtype, pd.CategoricalDtype): - series = series.cat.codes - else: - msg = f"There are errors with {variable_name}" - assert series.isin(list(possible_values._member_names_)).all(), msg - series = series.apply(lambda v: variable.possible_values[v].index) - - if series.values.dtype != variable.dtype: - log.debug(f"Converting {variable_name} from dtype {series.values.dtype} to {variable.dtype}") - - array = series.values.astype(variable.dtype) - np_array = np.array(array, dtype=variable.dtype) - - if (variable.value_type == Enum) and ( - np.issubdtype(series.values.dtype, np.integer) or np.issubdtype(series.values.dtype, float) - ): - np_array = EnumArray(np_array, variable.possible_values) - - if variable.definition_period == YEAR and period.unit == MONTH: - # Some variables defined for a year are present in month/quarter dataframes - # Cleaning the dataframe would probably be better in the long run - log.warn( - 
f"Trying to set a monthly value for variable {variable_name}, which is defined on a year. " - "The montly values you provided will be summed." - ) - - if simulation.get_array(variable_name, period.this_year) is not None: - array_sum = simulation.get_array(variable_name, period.this_year) + np_array - simulation.set_input(variable_name, period.this_year, array_sum) - else: - simulation.set_input(variable_name, period.this_year, np_array) - - else: - simulation.set_input(variable_name, period, np_array) - - -def new_from_tax_benefit_system( - tax_benefit_system: TaxBenefitSystem, - debug: bool = False, - trace: bool = False, - data: Optional[dict] = None, - memory_config: MemoryConfig = None, - period: Optional[Union[int, str, Period]] = None, - custom_initialize: Optional[Callable] = None, -) -> Simulation: - """ - Create and initialize a simulation from a tax and benefit system and data. - - Args: - tax_benefit_system (TaxBenefitSystem): The tax and benefit system - debug (bool, optional): Whether to activate debugging. Defaults to False. - trace (bool, optional): Whether to activate tracing. Defaults to False. - data (Dict, optional): The information about data. Defaults to None. - memory_config (MemoryConfig, optional): The memory handling config. Defaults to None. - period (Optional[Union[int, str, Period]], optional): The period of the simulation. Defaults to None. - custom_initialize (Callable, optional): The post-processing initialization function. Defaults to None. - - Returns: - Simulation: The completely initialized function - """ - - simulation = Simulation.init_simulation(tax_benefit_system, period, data) - simulation.debug = debug - simulation.trace = trace - simulation.opt_out_cache = simulation.tax_benefit_system.cache_blacklist is not None - simulation.memory_config = memory_config - - if custom_initialize: - custom_initialize(simulation) - - return simulation - - -def print_memory_usage(simulation: Simulation): - """ - Print memory usage. - - Args: - simulation (Simulation): The simulation which memory usage is to be printed - """ - memory_usage_by_variable = simulation.get_memory_usage()["by_variable"] - try: - usage_stats = simulation.tracer.usage_stats - except AttributeError: - log.warning( - "The simulation trace mode is not activated. You need to activate it to get stats " - "about variable usage (hits)." - ) - usage_stats = None - infos_lines = [] - - for variable, infos in memory_usage_by_variable.items(): - hits = usage_stats[variable]["nb_requests"] if usage_stats else None - infos_lines.append( - ( - infos["total_nb_bytes"], - variable, - "{}: {} periods * {} cells * item size {} ({}) = {} with {} hits".format( - variable, - infos["nb_arrays"], - infos["nb_cells_by_array"], - infos["cell_size"], - infos["dtype"], - humanize.naturalsize(infos["total_nb_bytes"], gnu=True), - hits, - ), - ) - ) - infos_lines.sort() - for _, _, line in infos_lines: - log.info("%s", line.rjust(100)) - - -def set_weight_variable_by_entity( - simulation: Simulation, - weight_variable_by_entity: dict, -) -> None: - """ - Set weight variable for each entity. - - Args: - simulation (Simulation): The simulation concerned. - weight_variable_by_entity (Dict): The weight variable for each entity. - """ - simulation.weight_variable_by_entity = weight_variable_by_entity - - -def summarize_variable( - simulation: Simulation, - variable: Optional[str] = None, - weighted: bool = False, - force_compute: bool = False, -): - """Print a summary of a variable including its memory usage. 
- - Args: - variable(string): The variable being summarized - use_baseline(bool): The tax-benefit-system considered - weighted(bool): Whether the produced statistics should be weigthted or not - force_compute(bool): Whether the computation of the variable should be forced - """ - tax_benefit_system = simulation.tax_benefit_system - variable_instance = tax_benefit_system.variables.get(variable) - assert variable_instance is not None, f"{variable} is not a valid variable" - - default_value = variable_instance.default_value - value_type = variable_instance.value_type - - if variable_instance.is_neutralized: - log.info( - "%s: neutralized variable (%s, default = %s)", - variable, - str(np.dtype(value_type)), - default_value, - ) - return - - if weighted: - weight_variable = simulation.weight_variable_by_entity[variable_instance.entity.key] - weights = simulation.calculate(weight_variable, simulation.period) - - infos = simulation.get_memory_usage(variables=[variable])["by_variable"].get(variable) - if not infos: - if force_compute: - simulation.adaptative_calculate_variable(variable=variable, period=simulation.period) - simulation.summarize_variable(variable=variable, weighted=weighted) - return - else: - log.info( - "%s is not computed yet. Use keyword argument force_compute = True", - variable, - ) - return - - header_line = "{}: {} periods * {} cells * item size {} ({}, default = {}) = {}".format( - variable, - infos["nb_arrays"], - infos["nb_cells_by_array"], - infos["cell_size"], - str(np.dtype(infos["dtype"])), - default_value, - humanize.naturalsize(infos["total_nb_bytes"], gnu=True), - ) - log.info("%s", header_line) - log.info("Details:") - holder = simulation.get_holder(variable) - if holder is not None: - if holder.variable.definition_period == ETERNITY: - array = holder.get_array(ETERNITY) - else: - for period in sorted(simulation.get_known_periods(variable)): - array = holder.get_array(period) - if array.shape == (): - log.info("%s: always = %s", period, array) - continue - - if value_type == Enum: - possible_values = variable_instance.possible_values - categories_by_index = dict( - zip(range(len(possible_values._member_names_)), possible_values._member_names_) - ) - categories_type = pd.api.types.CategoricalDtype( - categories=possible_values._member_names_, ordered=True - ) - df = pd.DataFrame({variable: array}).replace(categories_by_index).astype(categories_type) - df["weights"] = weights if weighted else 1 - groupby = df.groupby(variable)["weights"].sum() - total = groupby.sum() - expr = [f" {index} = {row:.2e} ({row / total:.1%})" for index, row in groupby.items()] - log.info("%s: %s.", period, ",".join(expr)) - continue - - # Handle numeric types - default_array = np.array(default_value, dtype=infos["dtype"]) - total = np.sum(array * weights) if weighted else np.sum(array) - mean_val = (array * weights).sum() / weights.sum() if weighted else array.mean() - total_val = f"mass = {total:.2e}" - stats_val = ( - f"min = {array.min()}, max = {array.max()}, {total_val}, " - f"default = {(array == default_array).sum() / len(array):.1%}, " - f"median = {np.median(array)}" - ) - log.info("%s: mean = %s, %s", period, mean_val, stats_val) - - -# Monkey patching +import warnings -Simulation.adaptative_calculate_variable = adaptative_calculate_variable -Simulation.compute_aggregate = compute_aggregate -Simulation.compute_pivot_table = compute_pivot_table -Simulation.create_data_frame_by_entity = create_data_frame_by_entity -Simulation.compute_quantiles = compute_quantiles 
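This hunk replaces the import-time monkey patching of `Simulation` (the remaining assignments are removed just below) with a compatibility shim that re-exports the same names from `policy.simulations` and warns at import time. A minimal sketch of what a caller observes during the transition, assuming a fresh interpreter in which neither module has been imported yet:

```python
import warnings

# Old path: still works through the compatibility shim, but importing the
# module now emits a DeprecationWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    from openfisca_survey_manager.simulations import compute_aggregate  # noqa: F401

assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# New, preferred path: the same function, no warning.
from openfisca_survey_manager.policy.simulations import compute_aggregate  # noqa: F401
```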
-Simulation.compute_winners_losers = compute_winners_losers
-Simulation.new_from_tax_benefit_system = new_from_tax_benefit_system
-Simulation.inflate = inflate
-Simulation.init_entity_data = init_entity_data
-Simulation.init_simulation = init_simulation
-Simulation.init_variable_in_entity = init_variable_in_entity
-Simulation.print_memory_usage = print_memory_usage
-Simulation.set_weight_variable_by_entity = set_weight_variable_by_entity
-Simulation.summarize_variable = summarize_variable
+from openfisca_survey_manager.policy.simulations import (
+    SecretViolationError,
+    Simulation,
+    adaptative_calculate_variable,
+    assert_variables_in_same_entity,
+    compute_aggregate,
+    compute_pivot_table,
+    compute_quantiles,
+    compute_winners_losers,
+    create_data_frame_by_entity,
+    get_words,
+    inflate,
+    init_entity_data,
+    init_simulation,
+    init_variable_in_entity,
+    new_from_tax_benefit_system,
+    print_memory_usage,
+    set_weight_variable_by_entity,
+    summarize_variable,
+)
+
+warnings.warn(
+    "openfisca_survey_manager.simulations is deprecated and will be removed in a future version. "
+    "Prefer: from openfisca_survey_manager.policy import ... "
+    "or from openfisca_survey_manager.policy.simulations import ...",
+    DeprecationWarning,
+    stacklevel=2,
+)
+
+__all__ = [
+    "SecretViolationError",
+    "Simulation",
+    "adaptative_calculate_variable",
+    "assert_variables_in_same_entity",
+    "compute_aggregate",
+    "compute_pivot_table",
+    "compute_quantiles",
+    "compute_winners_losers",
+    "create_data_frame_by_entity",
+    "get_words",
+    "inflate",
+    "init_entity_data",
+    "init_simulation",
+    "init_variable_in_entity",
+    "new_from_tax_benefit_system",
+    "print_memory_usage",
+    "set_weight_variable_by_entity",
+    "summarize_variable",
+]
diff --git a/openfisca_survey_manager/tests/conftest.py b/openfisca_survey_manager/tests/conftest.py
index 825a5cd1..5d639e77 100644
--- a/openfisca_survey_manager/tests/conftest.py
+++ b/openfisca_survey_manager/tests/conftest.py
@@ -1,8 +1,34 @@
+import json
 from pathlib import Path
 
 import pandas as pd
 import pytest
 
+# Test data directory (next to conftest.py)
+TESTS_DATA_FILES = Path(__file__).resolve().parent / "data_files"
+TEST_RANDOM_GENERATOR_JSON = TESTS_DATA_FILES / "test_random_generator.json"
+
+# Minimal content so that SurveyCollection.load() succeeds
+TEST_RANDOM_GENERATOR_MINIMAL_JSON = {"name": "test_random_generator", "surveys": {}}
+
+
+@pytest.fixture(scope="session", autouse=True)
+def ensure_test_random_generator_json():
+    """Create test_random_generator.json with minimal content if it is missing or empty."""
+    TESTS_DATA_FILES.mkdir(parents=True, exist_ok=True)
+    if not TEST_RANDOM_GENERATOR_JSON.exists():
+        TEST_RANDOM_GENERATOR_JSON.write_text(
+            json.dumps(TEST_RANDOM_GENERATOR_MINIMAL_JSON, indent=2),
+            encoding="utf-8",
+        )
+    else:
+        raw = TEST_RANDOM_GENERATOR_JSON.read_text(encoding="utf-8").strip()
+        if not raw:
+            TEST_RANDOM_GENERATOR_JSON.write_text(
+                json.dumps(TEST_RANDOM_GENERATOR_MINIMAL_JSON, indent=2),
+                encoding="utf-8",
+            )
+
 
 @pytest.fixture
 def parquet_data(tmp_path: Path):
diff --git a/openfisca_survey_manager/tests/test_coverage_boost.py b/openfisca_survey_manager/tests/test_coverage_boost.py
index 3e1e6930..2da67554 100644
--- a/openfisca_survey_manager/tests/test_coverage_boost.py
+++ b/openfisca_survey_manager/tests/test_coverage_boost.py
@@ -5,12 +5,12 @@
 import pytest
 from openfisca_core import periods
-from openfisca_survey_manager.aggregates import AbstractAggregates
 from
openfisca_survey_manager.input_dataframe_generator import ( make_input_dataframe_by_entity, randomly_init_variable, set_table_in_survey, ) +from openfisca_survey_manager.policy import AbstractAggregates from openfisca_survey_manager.scenarios.abstract_scenario import AbstractSurveyScenario from openfisca_survey_manager.scripts.build_collection import ( check_template_config_files, @@ -19,7 +19,6 @@ from openfisca_survey_manager.survey_collections import SurveyCollection from openfisca_survey_manager.surveys import Survey from openfisca_survey_manager.tests import tax_benefit_system -from openfisca_survey_manager.utils import do_nothing from openfisca_survey_manager.variables import quantile @@ -135,17 +134,6 @@ def test_build_collection_helpers(tmp_path): create_data_file_by_format(str(tmp_path)) -def test_google_colab_boost(): - from openfisca_survey_manager.google_colab import create_raw_data_ini - - with suppress(Exception): - create_raw_data_ini({"test": {"opt": "val"}}) - - -def test_utils_do_nothing(): - assert do_nothing(1, a=2) is None - - def test_matching_mock_extended(monkeypatch): import sys diff --git a/openfisca_survey_manager/tests/test_legislation_inflator.py b/openfisca_survey_manager/tests/test_legislation_inflator.py index 7344a7b7..986174fb 100644 --- a/openfisca_survey_manager/tests/test_legislation_inflator.py +++ b/openfisca_survey_manager/tests/test_legislation_inflator.py @@ -1,7 +1,7 @@ from openfisca_core import periods from openfisca_country_template import CountryTaxBenefitSystem -from openfisca_survey_manager.utils import inflate_parameters, parameters_asof +from openfisca_survey_manager.policy.legislation_asof import inflate_parameters, parameters_asof def test_asof_simple_annual_parameter(): diff --git a/openfisca_survey_manager/tests/test_tax_benefit_system_asof.py b/openfisca_survey_manager/tests/test_tax_benefit_system_asof.py index 994f7f44..3318df92 100644 --- a/openfisca_survey_manager/tests/test_tax_benefit_system_asof.py +++ b/openfisca_survey_manager/tests/test_tax_benefit_system_asof.py @@ -2,7 +2,7 @@ from openfisca_core.parameters import ParameterNode, Scale from openfisca_country_template import CountryTaxBenefitSystem -from openfisca_survey_manager.utils import parameters_asof, variables_asof +from openfisca_survey_manager.policy.legislation_asof import parameters_asof, variables_asof def check_max_instant_leaf(sub_parameter, instant): diff --git a/openfisca_survey_manager/utils.py b/openfisca_survey_manager/utils.py index 78510a7a..3b3d6fa4 100644 --- a/openfisca_survey_manager/utils.py +++ b/openfisca_survey_manager/utils.py @@ -1,11 +1,12 @@ -"""Utilities: re-exports from common.misc + load_table (survey-dependent).""" +"""Utilities: re-exports from policy.legislation_asof + load_table (survey-dependent).""" import logging +import warnings from typing import Optional import pandas as pd -from openfisca_survey_manager.common.misc import ( +from openfisca_survey_manager.policy.legislation_asof import ( asof, do_nothing, inflate_parameter_leaf, @@ -17,6 +18,13 @@ log = logging.getLogger(__name__) +warnings.warn( + "openfisca_survey_manager.utils is deprecated for legislation helpers and will be removed in a future " + "version. 
Prefer: from openfisca_survey_manager.policy.legislation_asof import ...", + DeprecationWarning, + stacklevel=2, +) + __all__ = [ "asof", "do_nothing", diff --git a/uv.lock b/uv.lock index 966c17f3..ae2d22e0 100644 --- a/uv.lock +++ b/uv.lock @@ -1272,7 +1272,7 @@ wheels = [ [[package]] name = "openfisca-survey-manager" -version = "6.3.0" +version = "6.5.0" source = { editable = "." } dependencies = [ { name = "chardet" }, From f19fb11bce451f30e308c42d72877825b5a89448 Mon Sep 17 00:00:00 2001 From: Mahdi Ben Jelloul Date: Mon, 2 Mar 2026 21:42:44 +0100 Subject: [PATCH 06/18] refactor: move input_dataframe_generator to tests, coicop/matching/statshelpers/variables to policy - input_dataframe_generator: moved to tests/ (used only by tests) - coicop, matching, statshelpers, variables: moved to policy/ - update all imports accordingly - coicop: paths -> configuration.paths Made-with: Cursor --- docs/MIGRATION_IMPORTS.md | 4 ++-- .../{ => policy}/coicop.py | 2 +- .../{ => policy}/matching.py | 0 openfisca_survey_manager/policy/simulations.py | 2 +- .../{ => policy}/statshelpers.py | 0 .../{ => policy}/variables.py | 2 +- .../{ => tests}/input_dataframe_generator.py | 15 ++++++++++----- .../tests/test_add_survey_to_collection.py | 4 ++-- .../tests/test_coverage_boost.py | 18 +++++++++--------- .../tests/test_matching.py | 2 +- .../tests/test_quantile.py | 4 ++-- .../tests/test_scenario.py | 12 ++++++------ openfisca_survey_manager/tests/test_surveys.py | 6 +++--- .../tests/test_top_bottom_share.py | 2 +- 14 files changed, 39 insertions(+), 34 deletions(-) rename openfisca_survey_manager/{ => policy}/coicop.py (97%) rename openfisca_survey_manager/{ => policy}/matching.py (100%) rename openfisca_survey_manager/{ => policy}/statshelpers.py (100%) rename openfisca_survey_manager/{ => policy}/variables.py (96%) rename openfisca_survey_manager/{ => tests}/input_dataframe_generator.py (96%) diff --git a/docs/MIGRATION_IMPORTS.md b/docs/MIGRATION_IMPORTS.md index 8951e919..f6a9da10 100644 --- a/docs/MIGRATION_IMPORTS.md +++ b/docs/MIGRATION_IMPORTS.md @@ -47,7 +47,7 @@ Avant (ou en même temps que) la suppression des fichiers listés en section 3, | Fichier | Remplacer | |---------|-----------| -| `input_dataframe_generator.py` | `paths` → `configuration.paths` | +| `tests/input_dataframe_generator.py` | `paths` → `configuration.paths` (module déplacé dans `tests/`) | | `scripts/build_collection.py` | `paths` → `configuration.paths` | | `temporary.py` | `paths` → `configuration.paths` | | `google_colab.py` | `paths` → `configuration.paths` | @@ -61,7 +61,7 @@ Avant (ou en même temps que) la suppression des fichiers listés en section 3, | Fichier | Remplacer | |---------|-----------| -| `input_dataframe_generator.py` | `survey_collections`, `surveys` → `core.dataset`, `core.survey` | +| `tests/input_dataframe_generator.py` | `survey_collections`, `surveys` → `core.dataset`, `core.survey` | | `simulations.py` | `survey_collections`, `utils` → `core.dataset` ; utils → `common.misc` + module de `load_table` | | `utils.py` | `survey_collections` → `core.dataset` (pour `load_table`) | | `scripts/build_collection.py` | `survey_collections`, `surveys` → `core.dataset`, `core.survey` | diff --git a/openfisca_survey_manager/coicop.py b/openfisca_survey_manager/policy/coicop.py similarity index 97% rename from openfisca_survey_manager/coicop.py rename to openfisca_survey_manager/policy/coicop.py index b30223d5..22367356 100644 --- a/openfisca_survey_manager/coicop.py +++ b/openfisca_survey_manager/policy/coicop.py 
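Each file move in this patch pairs the rename with a one-line import rewrite, as the `coicop.py` hunk below shows. Callers follow the same mechanical rule; a minimal before/after sketch (paths taken from the mapping in `docs/MIGRATION_IMPORTS.md`; `mark_weighted_percentiles` is used purely as an illustrative symbol):

```python
# Before the moves (old flat layout):
# from openfisca_survey_manager.paths import openfisca_survey_manager_location
# from openfisca_survey_manager.statshelpers import mark_weighted_percentiles

# After the moves (configuration/ and policy/ layout):
from openfisca_survey_manager.configuration.paths import openfisca_survey_manager_location
from openfisca_survey_manager.policy.statshelpers import mark_weighted_percentiles
```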
@@ -3,7 +3,7 @@ import pandas as pd -from openfisca_survey_manager.paths import openfisca_survey_manager_location +from openfisca_survey_manager.configuration.paths import openfisca_survey_manager_location log = logging.getLogger(__name__) diff --git a/openfisca_survey_manager/matching.py b/openfisca_survey_manager/policy/matching.py similarity index 100% rename from openfisca_survey_manager/matching.py rename to openfisca_survey_manager/policy/matching.py diff --git a/openfisca_survey_manager/policy/simulations.py b/openfisca_survey_manager/policy/simulations.py index 942ee3a4..2d55f20f 100644 --- a/openfisca_survey_manager/policy/simulations.py +++ b/openfisca_survey_manager/policy/simulations.py @@ -22,7 +22,7 @@ SimulationBuilder, diagnose_variable_mismatch, ) -from openfisca_survey_manager.statshelpers import mark_weighted_percentiles +from openfisca_survey_manager.policy.statshelpers import mark_weighted_percentiles from openfisca_survey_manager.survey_collections import SurveyCollection from openfisca_survey_manager.utils import do_nothing, load_table diff --git a/openfisca_survey_manager/statshelpers.py b/openfisca_survey_manager/policy/statshelpers.py similarity index 100% rename from openfisca_survey_manager/statshelpers.py rename to openfisca_survey_manager/policy/statshelpers.py diff --git a/openfisca_survey_manager/variables.py b/openfisca_survey_manager/policy/variables.py similarity index 96% rename from openfisca_survey_manager/variables.py rename to openfisca_survey_manager/policy/variables.py index f91cbcc3..0e506c14 100644 --- a/openfisca_survey_manager/variables.py +++ b/openfisca_survey_manager/policy/variables.py @@ -3,7 +3,7 @@ from numpy import arange from openfisca_core.model_api import ADD, YEAR, Variable, where -from openfisca_survey_manager.statshelpers import mark_weighted_percentiles, weightedcalcs_quantiles +from openfisca_survey_manager.policy.statshelpers import mark_weighted_percentiles, weightedcalcs_quantiles log = logging.getLogger(__name__) diff --git a/openfisca_survey_manager/input_dataframe_generator.py b/openfisca_survey_manager/tests/input_dataframe_generator.py similarity index 96% rename from openfisca_survey_manager/input_dataframe_generator.py rename to openfisca_survey_manager/tests/input_dataframe_generator.py index 4fa18510..3316ac5c 100644 --- a/openfisca_survey_manager/input_dataframe_generator.py +++ b/openfisca_survey_manager/tests/input_dataframe_generator.py @@ -1,3 +1,5 @@ +"""Helpers to build input dataframes and fill surveys for tests.""" + import configparser import logging import random @@ -7,9 +9,12 @@ import pandas as pd from openfisca_core import periods -from openfisca_survey_manager.paths import default_config_files_directory, openfisca_survey_manager_location -from openfisca_survey_manager.survey_collections import SurveyCollection -from openfisca_survey_manager.surveys import Survey +from openfisca_survey_manager.configuration.paths import ( + default_config_files_directory, + openfisca_survey_manager_location, +) +from openfisca_survey_manager.core.dataset import SurveyCollection +from openfisca_survey_manager.core.survey import Survey log = logging.getLogger(__name__) @@ -27,7 +32,7 @@ def make_input_dataframe_by_entity(tax_benefit_system, nb_persons, nb_groups): Example: - >>> from openfisca_survey_manager.input_dataframe_generator import make_input_dataframe_by_entity + >>> from openfisca_survey_manager.tests.input_dataframe_generator import make_input_dataframe_by_entity >>> from openfisca_country_template import 
CountryTaxBenefitSystem
     >>> tbs = CountryTaxBenefitSystem()
     >>> input_dataframe_by_entity = make_input_dataframe_by_entity(tbs, 400, 100)
@@ -150,7 +155,7 @@ def randomly_init_variable(
         seed: Random seed used when drawing the values (Default value = None)
     Examples
-    >>> from openfisca_survey_manager.input_dataframe_generator import make_input_dataframe_by_entity
+    >>> from openfisca_survey_manager.tests.input_dataframe_generator import make_input_dataframe_by_entity
     >>> from openfisca_country_template import CountryTaxBenefitSystem
     >>> tbs = CountryTaxBenefitSystem()
     >>> input_dataframe_by_entity = make_input_dataframe_by_entity(tbs, 400, 100)
diff --git a/openfisca_survey_manager/tests/test_add_survey_to_collection.py b/openfisca_survey_manager/tests/test_add_survey_to_collection.py
index 3d2c968d..a417f29d 100644
--- a/openfisca_survey_manager/tests/test_add_survey_to_collection.py
+++ b/openfisca_survey_manager/tests/test_add_survey_to_collection.py
@@ -1,11 +1,11 @@
 import pandas as pd
 import pytest
 
-from openfisca_survey_manager.input_dataframe_generator import set_table_in_survey
+from openfisca_survey_manager.core.dataset import SurveyCollection
 from openfisca_survey_manager.scripts.build_collection import (
     add_survey_to_collection,
 )
-from openfisca_survey_manager.survey_collections import SurveyCollection
+from openfisca_survey_manager.tests.input_dataframe_generator import set_table_in_survey
 
 
 @pytest.mark.order(after="test_write_parquet.py::test_write_parquet_one_file_per_entity")
diff --git a/openfisca_survey_manager/tests/test_coverage_boost.py b/openfisca_survey_manager/tests/test_coverage_boost.py
index 2da67554..2838ab5f 100644
--- a/openfisca_survey_manager/tests/test_coverage_boost.py
+++ b/openfisca_survey_manager/tests/test_coverage_boost.py
@@ -5,21 +5,21 @@
 import pytest
 from openfisca_core import periods
-from openfisca_survey_manager.input_dataframe_generator import (
-    make_input_dataframe_by_entity,
-    randomly_init_variable,
-    set_table_in_survey,
-)
+from openfisca_survey_manager.core.dataset import SurveyCollection
+from openfisca_survey_manager.core.survey import Survey
 from openfisca_survey_manager.policy import AbstractAggregates
+from openfisca_survey_manager.policy.variables import quantile
 from openfisca_survey_manager.scenarios.abstract_scenario import AbstractSurveyScenario
 from openfisca_survey_manager.scripts.build_collection import (
     check_template_config_files,
     create_data_file_by_format,
 )
-from openfisca_survey_manager.survey_collections import SurveyCollection
-from openfisca_survey_manager.surveys import Survey
 from openfisca_survey_manager.tests import tax_benefit_system
-from openfisca_survey_manager.variables import quantile
+from openfisca_survey_manager.tests.input_dataframe_generator import (
+    make_input_dataframe_by_entity,
+    randomly_init_variable,
+    set_table_in_survey,
+)
 
 
 def setup_test_config(config_files_directory: Path):
@@ -156,7 +156,7 @@ def test_matching_mock_extended(monkeypatch):
     )
     monkeypatch.setitem(sys.modules, "rpy2", fake_rpy2)
     monkeypatch.setitem(sys.modules, "rpy2.robjects", fake_rpy2.robjects)
-    from openfisca_survey_manager.matching import nnd_hotdeck_using_feather, nnd_hotdeck_using_rpy2
+    from openfisca_survey_manager.policy.matching import nnd_hotdeck_using_feather, nnd_hotdeck_using_rpy2
     receiver = pd.DataFrame({"a": [1], "c": [1]})
     donor = pd.DataFrame({"a": [1], "b": [2], "c": [1]})
diff --git a/openfisca_survey_manager/tests/test_matching.py b/openfisca_survey_manager/tests/test_matching.py
index
10d91255..ecc1b39e 100644 --- a/openfisca_survey_manager/tests/test_matching.py +++ b/openfisca_survey_manager/tests/test_matching.py @@ -2,7 +2,7 @@ import pandas as pd -from openfisca_survey_manager.matching import nnd_hotdeck_using_rpy2 +from openfisca_survey_manager.policy.matching import nnd_hotdeck_using_rpy2 try: import rpy2 diff --git a/openfisca_survey_manager/tests/test_quantile.py b/openfisca_survey_manager/tests/test_quantile.py index 6c33c69b..0038b7a7 100644 --- a/openfisca_survey_manager/tests/test_quantile.py +++ b/openfisca_survey_manager/tests/test_quantile.py @@ -5,10 +5,10 @@ from openfisca_core.taxbenefitsystems import TaxBenefitSystem from openfisca_survey_manager.paths import default_config_files_directory +from openfisca_survey_manager.policy.statshelpers import mark_weighted_percentiles +from openfisca_survey_manager.policy.variables import quantile from openfisca_survey_manager.scenarios.abstract_scenario import AbstractSurveyScenario -from openfisca_survey_manager.statshelpers import mark_weighted_percentiles from openfisca_survey_manager.tests.test_scenario import setup_test_config -from openfisca_survey_manager.variables import quantile Individu = build_entity( key="individu", diff --git a/openfisca_survey_manager/tests/test_scenario.py b/openfisca_survey_manager/tests/test_scenario.py index 13fa5c0a..9657639d 100644 --- a/openfisca_survey_manager/tests/test_scenario.py +++ b/openfisca_survey_manager/tests/test_scenario.py @@ -7,17 +7,17 @@ from openfisca_core import periods from openfisca_core.tools import assert_near -from openfisca_survey_manager.input_dataframe_generator import ( - make_input_dataframe_by_entity, - random_data_generator, - randomly_init_variable, -) -from openfisca_survey_manager.paths import ( +from openfisca_survey_manager.configuration.paths import ( default_config_files_directory, ) from openfisca_survey_manager.scenarios.abstract_scenario import AbstractSurveyScenario from openfisca_survey_manager.scenarios.reform_scenario import ReformScenario from openfisca_survey_manager.tests import tax_benefit_system +from openfisca_survey_manager.tests.input_dataframe_generator import ( + make_input_dataframe_by_entity, + random_data_generator, + randomly_init_variable, +) log = logging.getLogger(__name__) diff --git a/openfisca_survey_manager/tests/test_surveys.py b/openfisca_survey_manager/tests/test_surveys.py index de2556be..d3891ca7 100644 --- a/openfisca_survey_manager/tests/test_surveys.py +++ b/openfisca_survey_manager/tests/test_surveys.py @@ -1,9 +1,9 @@ import pandas as pd import pytest -from openfisca_survey_manager.input_dataframe_generator import set_table_in_survey -from openfisca_survey_manager.survey_collections import SurveyCollection -from openfisca_survey_manager.surveys import Survey +from openfisca_survey_manager.core.dataset import SurveyCollection +from openfisca_survey_manager.core.survey import Survey +from openfisca_survey_manager.tests.input_dataframe_generator import set_table_in_survey @pytest.fixture diff --git a/openfisca_survey_manager/tests/test_top_bottom_share.py b/openfisca_survey_manager/tests/test_top_bottom_share.py index 05497b56..acf6b362 100644 --- a/openfisca_survey_manager/tests/test_top_bottom_share.py +++ b/openfisca_survey_manager/tests/test_top_bottom_share.py @@ -1,6 +1,6 @@ import numpy as np -from openfisca_survey_manager.statshelpers import bottom_share, top_share +from openfisca_survey_manager.policy.statshelpers import bottom_share, top_share size = 1000 x = np.ones(size) + 
np.random.uniform(0, 0.00000001, size) From 3da51d73e2963191db6d5bebdce80008d502020f Mon Sep 17 00:00:00 2001 From: Mahdi Ben Jelloul Date: Mon, 2 Mar 2026 21:46:18 +0100 Subject: [PATCH 07/18] refactor: move calmar/calibration to policy, add placeholders with DeprecationWarnings - calmar, calibration: processing/weights -> policy/ - processing/weights: re-export from policy + DeprecationWarning - processing/__init__: lazy import for weights (avoid circular import) - placeholders at root: coicop, input_dataframe_generator, matching, statshelpers, variables (with DeprecationWarnings) - calmar, calibration placeholders: add DeprecationWarnings Made-with: Cursor --- openfisca_survey_manager/calibration.py | 13 ++++++- openfisca_survey_manager/calmar.py | 16 +++++++- openfisca_survey_manager/coicop.py | 29 +++++++++++++++ .../input_dataframe_generator.py | 29 +++++++++++++++ openfisca_survey_manager/matching.py | 25 +++++++++++++ .../weights => policy}/calibration.py | 2 +- .../{processing/weights => policy}/calmar.py | 0 .../processing/__init__.py | 11 +++++- .../processing/weights/__init__.py | 16 ++++++-- .../scenarios/abstract_scenario.py | 2 +- openfisca_survey_manager/statshelpers.py | 37 +++++++++++++++++++ .../tests/test_calibration.py | 2 +- openfisca_survey_manager/tests/test_calmar.py | 2 +- openfisca_survey_manager/variables.py | 25 +++++++++++++ 14 files changed, 197 insertions(+), 12 deletions(-) create mode 100644 openfisca_survey_manager/coicop.py create mode 100644 openfisca_survey_manager/input_dataframe_generator.py create mode 100644 openfisca_survey_manager/matching.py rename openfisca_survey_manager/{processing/weights => policy}/calibration.py (99%) rename openfisca_survey_manager/{processing/weights => policy}/calmar.py (100%) create mode 100644 openfisca_survey_manager/statshelpers.py create mode 100644 openfisca_survey_manager/variables.py diff --git a/openfisca_survey_manager/calibration.py b/openfisca_survey_manager/calibration.py index 1670538c..23124894 100644 --- a/openfisca_survey_manager/calibration.py +++ b/openfisca_survey_manager/calibration.py @@ -1,8 +1,17 @@ """Re-export for backward compatibility. -Prefer: from openfisca_survey_manager.processing.weights import Calibration. +Prefer: from openfisca_survey_manager.policy.calibration import ... """ -from openfisca_survey_manager.processing.weights import Calibration +import warnings + +from openfisca_survey_manager.policy.calibration import Calibration + +warnings.warn( + "openfisca_survey_manager.calibration is deprecated and will be removed in a future version. " + "Prefer: from openfisca_survey_manager.policy.calibration import ...", + DeprecationWarning, + stacklevel=2, +) __all__ = ["Calibration"] diff --git a/openfisca_survey_manager/calmar.py b/openfisca_survey_manager/calmar.py index 77c2ba53..6734f8d7 100644 --- a/openfisca_survey_manager/calmar.py +++ b/openfisca_survey_manager/calmar.py @@ -1,5 +1,17 @@ -"""Re-export for backward compatibility. Prefer: from openfisca_survey_manager.processing.weights import calmar.""" +"""Re-export for backward compatibility. -from openfisca_survey_manager.processing.weights import calmar, check_calmar +Prefer: from openfisca_survey_manager.policy.calmar import ... +""" + +import warnings + +from openfisca_survey_manager.policy.calmar import calmar, check_calmar + +warnings.warn( + "openfisca_survey_manager.calmar is deprecated and will be removed in a future version. 
" + "Prefer: from openfisca_survey_manager.policy.calmar import ...", + DeprecationWarning, + stacklevel=2, +) __all__ = ["calmar", "check_calmar"] diff --git a/openfisca_survey_manager/coicop.py b/openfisca_survey_manager/coicop.py new file mode 100644 index 00000000..dc991a89 --- /dev/null +++ b/openfisca_survey_manager/coicop.py @@ -0,0 +1,29 @@ +"""Re-export for backward compatibility. + +Prefer: from openfisca_survey_manager.policy.coicop import ... +""" + +import warnings + +from openfisca_survey_manager.policy.coicop import ( + build_coicop_level_nomenclature, + build_raw_coicop_nomenclature, + divisions, + legislation_directory, + sub_levels, +) + +warnings.warn( + "openfisca_survey_manager.coicop is deprecated and will be removed in a future version. " + "Prefer: from openfisca_survey_manager.policy.coicop import ...", + DeprecationWarning, + stacklevel=2, +) + +__all__ = [ + "build_coicop_level_nomenclature", + "build_raw_coicop_nomenclature", + "divisions", + "legislation_directory", + "sub_levels", +] diff --git a/openfisca_survey_manager/input_dataframe_generator.py b/openfisca_survey_manager/input_dataframe_generator.py new file mode 100644 index 00000000..d32cebc0 --- /dev/null +++ b/openfisca_survey_manager/input_dataframe_generator.py @@ -0,0 +1,29 @@ +"""Re-export for backward compatibility. + +Prefer: from openfisca_survey_manager.tests.input_dataframe_generator import ... +""" + +import warnings + +from openfisca_survey_manager.tests.input_dataframe_generator import ( + build_input_dataframe_from_test_case, + make_input_dataframe_by_entity, + random_data_generator, + randomly_init_variable, + set_table_in_survey, +) + +warnings.warn( + "openfisca_survey_manager.input_dataframe_generator is deprecated and will be removed in a future version. " + "Prefer: from openfisca_survey_manager.tests.input_dataframe_generator import ...", + DeprecationWarning, + stacklevel=2, +) + +__all__ = [ + "build_input_dataframe_from_test_case", + "make_input_dataframe_by_entity", + "random_data_generator", + "randomly_init_variable", + "set_table_in_survey", +] diff --git a/openfisca_survey_manager/matching.py b/openfisca_survey_manager/matching.py new file mode 100644 index 00000000..57ae0b36 --- /dev/null +++ b/openfisca_survey_manager/matching.py @@ -0,0 +1,25 @@ +"""Re-export for backward compatibility. + +Prefer: from openfisca_survey_manager.policy.matching import ... +""" + +import warnings + +from openfisca_survey_manager.policy.matching import ( + nnd_hotdeck, + nnd_hotdeck_using_feather, + nnd_hotdeck_using_rpy2, +) + +warnings.warn( + "openfisca_survey_manager.matching is deprecated and will be removed in a future version. 
" + "Prefer: from openfisca_survey_manager.policy.matching import ...", + DeprecationWarning, + stacklevel=2, +) + +__all__ = [ + "nnd_hotdeck", + "nnd_hotdeck_using_feather", + "nnd_hotdeck_using_rpy2", +] diff --git a/openfisca_survey_manager/processing/weights/calibration.py b/openfisca_survey_manager/policy/calibration.py similarity index 99% rename from openfisca_survey_manager/processing/weights/calibration.py rename to openfisca_survey_manager/policy/calibration.py index b28337ce..a4ee84aa 100644 --- a/openfisca_survey_manager/processing/weights/calibration.py +++ b/openfisca_survey_manager/policy/calibration.py @@ -12,7 +12,7 @@ from numpy import logical_not from openfisca_core.model_api import Enum -from openfisca_survey_manager.processing.weights.calmar import calmar +from openfisca_survey_manager.policy.calmar import calmar log = logging.getLogger(__name__) diff --git a/openfisca_survey_manager/processing/weights/calmar.py b/openfisca_survey_manager/policy/calmar.py similarity index 100% rename from openfisca_survey_manager/processing/weights/calmar.py rename to openfisca_survey_manager/policy/calmar.py diff --git a/openfisca_survey_manager/processing/__init__.py b/openfisca_survey_manager/processing/__init__.py index 1b9b5249..b0097a35 100644 --- a/openfisca_survey_manager/processing/__init__.py +++ b/openfisca_survey_manager/processing/__init__.py @@ -3,7 +3,16 @@ from openfisca_survey_manager.processing.cleaning import clean_data_frame from openfisca_survey_manager.processing.harmonization import harmonize_data_frame_columns -from openfisca_survey_manager.processing.weights import Calibration, calmar, check_calmar + + +# Lazy import to avoid circular dependency (processing -> policy -> survey_collections -> core) +def __getattr__(name: str): + if name in ("Calibration", "calmar", "check_calmar"): + from openfisca_survey_manager.processing.weights import Calibration, calmar, check_calmar + + return {"Calibration": Calibration, "calmar": calmar, "check_calmar": check_calmar}[name] + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + __all__ = [ "Calibration", diff --git a/openfisca_survey_manager/processing/weights/__init__.py b/openfisca_survey_manager/processing/weights/__init__.py index 66b4484b..cfa0ba8f 100644 --- a/openfisca_survey_manager/processing/weights/__init__.py +++ b/openfisca_survey_manager/processing/weights/__init__.py @@ -1,6 +1,16 @@ -# Calibration and CALMAR weight calibration. See docs/REFACTORING_PLAN.md. +# Calibration and CALMAR weight calibration. Re-exports from policy. -from openfisca_survey_manager.processing.weights.calibration import Calibration -from openfisca_survey_manager.processing.weights.calmar import calmar, check_calmar +import warnings + +from openfisca_survey_manager.policy.calibration import Calibration +from openfisca_survey_manager.policy.calmar import calmar, check_calmar + +warnings.warn( + "openfisca_survey_manager.processing.weights is deprecated for Calibration/calmar. 
" + "Prefer: from openfisca_survey_manager.policy.calibration import Calibration, " + "from openfisca_survey_manager.policy.calmar import calmar, check_calmar", + DeprecationWarning, + stacklevel=3, +) __all__ = ["Calibration", "calmar", "check_calmar"] diff --git a/openfisca_survey_manager/scenarios/abstract_scenario.py b/openfisca_survey_manager/scenarios/abstract_scenario.py index 57fe9493..f2765dd7 100644 --- a/openfisca_survey_manager/scenarios/abstract_scenario.py +++ b/openfisca_survey_manager/scenarios/abstract_scenario.py @@ -11,8 +11,8 @@ from openfisca_core.tools.simulation_dumper import dump_simulation, restore_simulation from openfisca_core.types import Array, Period, TaxBenefitSystem -from openfisca_survey_manager.calibration import Calibration from openfisca_survey_manager.exceptions import SurveyManagerError +from openfisca_survey_manager.policy.calibration import Calibration from openfisca_survey_manager.simulations import Simulation from openfisca_survey_manager.surveys import Survey diff --git a/openfisca_survey_manager/statshelpers.py b/openfisca_survey_manager/statshelpers.py new file mode 100644 index 00000000..c41f2d0b --- /dev/null +++ b/openfisca_survey_manager/statshelpers.py @@ -0,0 +1,37 @@ +"""Re-export for backward compatibility. + +Prefer: from openfisca_survey_manager.policy.statshelpers import ... +""" + +import warnings + +from openfisca_survey_manager.policy.statshelpers import ( + bottom_share, + gini, + kakwani, + lorenz, + mark_weighted_percentiles, + pseudo_lorenz, + top_share, + weighted_quantiles, + weightedcalcs_quantiles, +) + +warnings.warn( + "openfisca_survey_manager.statshelpers is deprecated and will be removed in a future version. " + "Prefer: from openfisca_survey_manager.policy.statshelpers import ...", + DeprecationWarning, + stacklevel=2, +) + +__all__ = [ + "bottom_share", + "gini", + "kakwani", + "lorenz", + "mark_weighted_percentiles", + "pseudo_lorenz", + "top_share", + "weighted_quantiles", + "weightedcalcs_quantiles", +] diff --git a/openfisca_survey_manager/tests/test_calibration.py b/openfisca_survey_manager/tests/test_calibration.py index 9e2bc37b..5d8074fd 100644 --- a/openfisca_survey_manager/tests/test_calibration.py +++ b/openfisca_survey_manager/tests/test_calibration.py @@ -1,7 +1,7 @@ from openfisca_core import periods from openfisca_core.tools import assert_near -from openfisca_survey_manager.calibration import Calibration +from openfisca_survey_manager.policy.calibration import Calibration from openfisca_survey_manager.scenarios.abstract_scenario import AbstractSurveyScenario from openfisca_survey_manager.tests import tax_benefit_system from openfisca_survey_manager.tests.test_scenario import ( diff --git a/openfisca_survey_manager/tests/test_calmar.py b/openfisca_survey_manager/tests/test_calmar.py index 93a914d4..884f3997 100644 --- a/openfisca_survey_manager/tests/test_calmar.py +++ b/openfisca_survey_manager/tests/test_calmar.py @@ -6,7 +6,7 @@ import numpy as np import pandas as pd -from openfisca_survey_manager.calmar import calmar +from openfisca_survey_manager.policy.calmar import calmar def create_input_dataframe(entities=1): diff --git a/openfisca_survey_manager/variables.py b/openfisca_survey_manager/variables.py new file mode 100644 index 00000000..afc512ba --- /dev/null +++ b/openfisca_survey_manager/variables.py @@ -0,0 +1,25 @@ +"""Re-export for backward compatibility. + +Prefer: from openfisca_survey_manager.policy.variables import ... 
+""" + +import warnings + +from openfisca_survey_manager.policy.variables import ( + create_quantile, + old_quantile, + quantile, +) + +warnings.warn( + "openfisca_survey_manager.variables is deprecated and will be removed in a future version. " + "Prefer: from openfisca_survey_manager.policy.variables import ...", + DeprecationWarning, + stacklevel=2, +) + +__all__ = [ + "create_quantile", + "old_quantile", + "quantile", +] From 7062454a93dc7b2b0776437d30c582c7e11cc821 Mon Sep 17 00:00:00 2001 From: Mahdi Ben Jelloul Date: Mon, 2 Mar 2026 21:58:45 +0100 Subject: [PATCH 08/18] refactor: move scenarios to policy, add placeholders with DeprecationWarning - scenarios/ -> policy/scenarios/ (abstract_scenario, reform_scenario) - Placeholders at openfisca_survey_manager/scenarios/ with DeprecationWarning - Remove common, processing/weights, root placeholders (coicop, matching, etc.) - Update all imports to policy.scenarios - Add missing Survey import in abstract_scenario Made-with: Cursor --- openfisca_survey_manager/coicop.py | 29 - openfisca_survey_manager/common/__init__.py | 20 - openfisca_survey_manager/common/misc.py | 33 - .../input_dataframe_generator.py | 29 - openfisca_survey_manager/matching.py | 25 - .../policy/scenarios/__init__.py | 4 + .../policy/scenarios/abstract_scenario.py | 763 +++++++++++++++++ .../policy/scenarios/reform_scenario.py | 293 +++++++ .../processing/__init__.py | 3 +- .../processing/weights/__init__.py | 16 - .../scenarios/__init__.py | 17 + .../scenarios/abstract_scenario.py | 770 +----------------- .../scenarios/reform_scenario.py | 300 +------ openfisca_survey_manager/statshelpers.py | 37 - .../tests/test_calibration.py | 2 +- .../tests/test_coverage_boost.py | 2 +- openfisca_survey_manager/tests/test_enum.py | 2 +- .../tests/test_parquet.py | 6 +- .../tests/test_quantile.py | 2 +- .../tests/test_scenario.py | 4 +- openfisca_survey_manager/variables.py | 25 - 21 files changed, 1112 insertions(+), 1270 deletions(-) delete mode 100644 openfisca_survey_manager/coicop.py delete mode 100644 openfisca_survey_manager/common/__init__.py delete mode 100644 openfisca_survey_manager/common/misc.py delete mode 100644 openfisca_survey_manager/input_dataframe_generator.py delete mode 100644 openfisca_survey_manager/matching.py create mode 100644 openfisca_survey_manager/policy/scenarios/__init__.py create mode 100644 openfisca_survey_manager/policy/scenarios/abstract_scenario.py create mode 100644 openfisca_survey_manager/policy/scenarios/reform_scenario.py delete mode 100644 openfisca_survey_manager/processing/weights/__init__.py delete mode 100644 openfisca_survey_manager/statshelpers.py delete mode 100644 openfisca_survey_manager/variables.py diff --git a/openfisca_survey_manager/coicop.py b/openfisca_survey_manager/coicop.py deleted file mode 100644 index dc991a89..00000000 --- a/openfisca_survey_manager/coicop.py +++ /dev/null @@ -1,29 +0,0 @@ -"""Re-export for backward compatibility. - -Prefer: from openfisca_survey_manager.policy.coicop import ... -""" - -import warnings - -from openfisca_survey_manager.policy.coicop import ( - build_coicop_level_nomenclature, - build_raw_coicop_nomenclature, - divisions, - legislation_directory, - sub_levels, -) - -warnings.warn( - "openfisca_survey_manager.coicop is deprecated and will be removed in a future version. 
" - "Prefer: from openfisca_survey_manager.policy.coicop import ...", - DeprecationWarning, - stacklevel=2, -) - -__all__ = [ - "build_coicop_level_nomenclature", - "build_raw_coicop_nomenclature", - "divisions", - "legislation_directory", - "sub_levels", -] diff --git a/openfisca_survey_manager/common/__init__.py b/openfisca_survey_manager/common/__init__.py deleted file mode 100644 index 8f6c40c6..00000000 --- a/openfisca_survey_manager/common/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Target: shared helpers to avoid circular imports (from utils.py, paths, etc.). -# Final name will be utils/ once utils.py is migrated. See docs/REFACTORING_PLAN.md. - -from openfisca_survey_manager.common.misc import ( - asof, - do_nothing, - inflate_parameter_leaf, - inflate_parameters, - parameters_asof, - variables_asof, -) - -__all__ = [ - "asof", - "do_nothing", - "inflate_parameter_leaf", - "inflate_parameters", - "parameters_asof", - "variables_asof", -] diff --git a/openfisca_survey_manager/common/misc.py b/openfisca_survey_manager/common/misc.py deleted file mode 100644 index d14d95cf..00000000 --- a/openfisca_survey_manager/common/misc.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Backward-compatibility wrapper for legislation helpers. - -Deprecated: use ``openfisca_survey_manager.policy.legislation_asof`` instead. -""" - -import warnings - -warnings.warn( - "openfisca_survey_manager.common.misc is deprecated and will be removed in a future version. " - "Prefer: from openfisca_survey_manager.policy.legislation_asof import ...", - DeprecationWarning, - stacklevel=2, -) - -from openfisca_survey_manager.policy.legislation_asof import ( # noqa: E402 - asof, - do_nothing, - inflate_parameter_leaf, - inflate_parameters, - leaf_asof, - parameters_asof, - variables_asof, -) - -__all__ = [ - "asof", - "do_nothing", - "inflate_parameter_leaf", - "inflate_parameters", - "leaf_asof", - "parameters_asof", - "variables_asof", -] diff --git a/openfisca_survey_manager/input_dataframe_generator.py b/openfisca_survey_manager/input_dataframe_generator.py deleted file mode 100644 index d32cebc0..00000000 --- a/openfisca_survey_manager/input_dataframe_generator.py +++ /dev/null @@ -1,29 +0,0 @@ -"""Re-export for backward compatibility. - -Prefer: from openfisca_survey_manager.tests.input_dataframe_generator import ... -""" - -import warnings - -from openfisca_survey_manager.tests.input_dataframe_generator import ( - build_input_dataframe_from_test_case, - make_input_dataframe_by_entity, - random_data_generator, - randomly_init_variable, - set_table_in_survey, -) - -warnings.warn( - "openfisca_survey_manager.input_dataframe_generator is deprecated and will be removed in a future version. " - "Prefer: from openfisca_survey_manager.tests.input_dataframe_generator import ...", - DeprecationWarning, - stacklevel=2, -) - -__all__ = [ - "build_input_dataframe_from_test_case", - "make_input_dataframe_by_entity", - "random_data_generator", - "randomly_init_variable", - "set_table_in_survey", -] diff --git a/openfisca_survey_manager/matching.py b/openfisca_survey_manager/matching.py deleted file mode 100644 index 57ae0b36..00000000 --- a/openfisca_survey_manager/matching.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Re-export for backward compatibility. - -Prefer: from openfisca_survey_manager.policy.matching import ... 
-""" - -import warnings - -from openfisca_survey_manager.policy.matching import ( - nnd_hotdeck, - nnd_hotdeck_using_feather, - nnd_hotdeck_using_rpy2, -) - -warnings.warn( - "openfisca_survey_manager.matching is deprecated and will be removed in a future version. " - "Prefer: from openfisca_survey_manager.policy.matching import ...", - DeprecationWarning, - stacklevel=2, -) - -__all__ = [ - "nnd_hotdeck", - "nnd_hotdeck_using_feather", - "nnd_hotdeck_using_rpy2", -] diff --git a/openfisca_survey_manager/policy/scenarios/__init__.py b/openfisca_survey_manager/policy/scenarios/__init__.py new file mode 100644 index 00000000..aa248f89 --- /dev/null +++ b/openfisca_survey_manager/policy/scenarios/__init__.py @@ -0,0 +1,4 @@ +from openfisca_survey_manager.policy.scenarios.abstract_scenario import AbstractSurveyScenario +from openfisca_survey_manager.policy.scenarios.reform_scenario import ReformScenario + +__all__ = ["AbstractSurveyScenario", "ReformScenario"] diff --git a/openfisca_survey_manager/policy/scenarios/abstract_scenario.py b/openfisca_survey_manager/policy/scenarios/abstract_scenario.py new file mode 100644 index 00000000..62e1a5ba --- /dev/null +++ b/openfisca_survey_manager/policy/scenarios/abstract_scenario.py @@ -0,0 +1,763 @@ +"Abstract survey scenario definition." + +import logging +from pathlib import Path +from typing import Optional, Union + +import numpy as np +import pandas as pd +from openfisca_core import periods +from openfisca_core.periods import MONTH, YEAR +from openfisca_core.tools.simulation_dumper import dump_simulation, restore_simulation +from openfisca_core.types import Array, Period, TaxBenefitSystem + +from openfisca_survey_manager.core.survey import Survey +from openfisca_survey_manager.exceptions import SurveyManagerError +from openfisca_survey_manager.policy import Simulation +from openfisca_survey_manager.policy.calibration import Calibration + +log = logging.getLogger(__name__) + + +class AbstractSurveyScenario: + """Abstract survey scenario.""" + + cache_blacklist = None + collection = None + debug = False + filtering_variable_by_entity = None + id_variable_by_entity_key = None + inflator_by_variable = None # factor used to inflate variable total + input_data_frame = None + input_data_table_by_entity_by_period = None + input_data_table_by_period = None + non_neutralizable_variables = None + period = None + role_variable_by_entity_key = None + simulations = None + target_by_variable = None # variable total target to inflate to + tax_benefit_systems = None + trace = False + used_as_input_variables = None + used_as_input_variables_by_entity = None + variation_factor = 0.03 # factor used to compute variation when estimating marginal tax rate + varying_variable = None + weight_variable_by_entity = None + + def build_input_data(self, **kwargs): + """Build input data.""" + raise NotImplementedError + + def calculate_series(self, variable, period=None, simulation=None): + """Compute variable values for period for a given simulation. + + Args: + variable(str, optional): Variable to compute + period(Period, optional): Period, defaults to None + simulation(str, optional): Simulation to use + + Returns: + pandas.Series: Variable values + + """ + return pd.Series( + data=self.calculate_variable(variable, period, simulation=simulation), + name=variable, + ) + + def calculate_variable(self, variable, period=None, simulation=None): + """Compute variable values for period for a given simulation. 
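+
+        Doctest-style sketch (assumes a scenario holding a single, already
+        initialised simulation; "rent" is a hypothetical variable):
+
+        >>> rents = survey_scenario.calculate_variable("rent", period="2017-01")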
+
+        Args:
+            variable(str, optional): Variable to compute
+            period(Period, optional): Period, defaults to None
+            simulation(str, optional): Simulation to use
+
+        Returns:
+            numpy.ndarray: Variable values
+
+        """
+        if simulation is None:
+            assert len(self.simulations.keys()) == 1
+            simulation = next(iter(self.simulations.values()))
+        else:
+            simulation = self.simulations[simulation]
+        assert simulation is not None
+        return simulation.adaptative_calculate_variable(variable, period=period)
+
+    def calibrate(
+        self,
+        period: Optional[int] = None,
+        target_margins_by_variable: Optional[dict] = None,
+        parameters: Optional[dict] = None,
+        target_entity_count: Optional[float] = None,
+        other_entity_count: Optional[float] = None,
+        entity: Optional[str] = None,
+    ):
+        """Calibrate the scenario data.
+
+        Args:
+            period (int, optional): Period of calibration. Defaults to scenario.period.
+            target_margins_by_variable (dict, optional): Variable target margins. Defaults to None.
+            parameters (dict, optional): Calibration parameters. Defaults to None.
+            target_entity_count (float, optional): Total population target. Defaults to None.
+            other_entity_count (float, optional): Total population target of the second entity. Defaults to None.
+            entity (str): Entity specified when no variable comes with target margins
+                but `target_entity_count` is not None.
+        """
+        survey_scenario = self
+
+        if period is None:
+            assert survey_scenario.period is not None
+            period = survey_scenario.period
+
+        if parameters is not None:
+            assert parameters["method"] in ["linear", "raking ratio", "logit", "hyperbolic sinus"], (
+                "Incorrect parameter value: method should be 'linear', 'raking ratio', 'logit' or 'hyperbolic sinus'"
+            )
+            if parameters["method"] == "logit":
+                assert parameters["invlo"] is not None
+                assert parameters["up"] is not None
+            elif parameters["method"] == "hyperbolic sinus":
+                assert parameters["alpha"] is not None
+        else:
+            parameters = {"method": "logit", "up": 3, "invlo": 3}
+
+        # TODO: filtering using filtering_variable_by_entity
+        for simulation in self.simulations.values():
+            if simulation is None:
+                continue
+            calibration = Calibration(
+                simulation,
+                target_margins_by_variable,
+                period,
+                target_entity_count=target_entity_count,
+                other_entity_count=other_entity_count,
+                entity=entity,
+                parameters=parameters,
+                # filter_by = self.filter_by,
+            )
+            calibration.calibrate(inplace=True)
+            simulation.calibration = calibration
+
+    def compute_aggregate(
+        self,
+        variable: Optional[str] = None,
+        aggfunc: str = "sum",
+        filter_by: Optional[str] = None,
+        period: Optional[Union[int, str, Period]] = None,
+        simulation: Optional[str] = None,
+        baseline_simulation: Optional[str] = None,
+        missing_variable_default_value=np.nan,
+        weighted: bool = True,
+        alternative_weights: Optional[Union[str, int, float, Array]] = None,
+    ):
+        """Compute variable aggregate.
+
+        Args:
+            variable (str, optional): Variable to aggregate. Defaults to None.
+            aggfunc (str, optional): Aggregation function. Defaults to 'sum'.
+            filter_by (str, optional): Filter variable or expression to use. Defaults to None.
+            period (Optional[Union[int, str, Period]], optional): Period. Defaults to None.
+            simulation(str, optional): Simulation to use
+            baseline_simulation(str, optional): Baseline simulation to use when computing a difference
+            missing_variable_default_value (optional): Value to use for missing values. Defaults to np.nan.
+            weighted (bool, optional): Whether to weight the variable or not. Defaults to True.
+            alternative_weights (Optional[Union[str, int, float, Array]], optional): Alternative weight to use.
+                Defaults to None.
+
+        Returns:
+            float: Aggregate
+        """
+        assert aggfunc in ["count", "mean", "sum", "count_non_zero"]
+        assert period is not None
+        assert variable is not None
+        if simulation is None:
+            assert len(self.simulations.keys()) == 1
+            simulation = next(iter(self.simulations.values()))
+        else:
+            simulation = self.simulations[simulation]
+
+        assert simulation is not None, f"Missing {simulation} simulation"
+
+        if baseline_simulation:
+            baseline_simulation = self.simulations[baseline_simulation]
+            return simulation.compute_aggregate(
+                variable=variable,
+                aggfunc=aggfunc,
+                filter_by=filter_by,
+                period=period,
+                missing_variable_default_value=missing_variable_default_value,
+                weighted=weighted,
+                alternative_weights=alternative_weights,
+                filtering_variable_by_entity=self.filtering_variable_by_entity,
+            ) - baseline_simulation.compute_aggregate(
+                variable=variable,
+                aggfunc=aggfunc,
+                filter_by=filter_by,
+                period=period,
+                missing_variable_default_value=missing_variable_default_value,
+                weighted=weighted,
+                alternative_weights=alternative_weights,
+                filtering_variable_by_entity=self.filtering_variable_by_entity,
+            )
+
+        return simulation.compute_aggregate(
+            variable=variable,
+            aggfunc=aggfunc,
+            filter_by=filter_by,
+            period=period,
+            missing_variable_default_value=missing_variable_default_value,
+            weighted=weighted,
+            alternative_weights=alternative_weights,
+            filtering_variable_by_entity=self.filtering_variable_by_entity,
+        )
+
+    def compute_quantiles(
+        self,
+        simulation: str,
+        variable: str,
+        nquantiles: Optional[int] = None,
+        period: Optional[Union[int, str, Period]] = None,
+        filter_by=None,
+        weighted: bool = True,
+        alternative_weights=None,
+        filtering_variable_by_entity=None,
+    ) -> list[float]:
+        """
+        Compute quantiles of a variable.
+
+        Args:
+            simulation (str): Name of the simulation to use.
+            variable (str): Variable whose quantiles are computed.
+            nquantiles (int, optional): Number of quantiles. Defaults to None.
+            period (Optional[Union[int, str, Period]], optional): Period. Defaults to None.
+            filter_by (str, optional): Boolean variable or expression to filter by. Defaults to None.
+            weighted (bool, optional): Whether to weight the variable or not. Defaults to True.
+            alternative_weights (Optional[Union[str, int, float, Array]], optional): Alternative weight to use.
+                Defaults to None.
+            filtering_variable_by_entity (Dict, optional): Filtering variable by entity. Defaults to None.
+
+        Returns:
+            list[float]: The quantile values
+        """
+        assert variable is not None
+        assert nquantiles is not None
+        simulation = self.simulations[simulation]
+        assert simulation is not None, f"Missing {simulation} simulation"
+
+        return simulation.compute_quantiles(
+            variable=variable,
+            period=period,
+            nquantiles=nquantiles,
+            filter_by=filter_by,
+            weighted=weighted,
+            alternative_weights=alternative_weights,
+        )
+
+    def compute_marginal_tax_rate(
+        self,
+        target_variable: str,
+        period: Optional[Union[int, str, Period]],
+        simulation: Optional[str] = None,
+        value_for_zero_varying_variable: float = 0.0,
+    ) -> Array:
+        """
+        Compute the marginal tax rate (MTR) of a target variable with respect to a varying variable.
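+
+        The rate is obtained from the shadow "_modified_" simulation as
+        1 - (modified_target - target) / (modified_varying - varying),
+        with `value_for_zero_varying_variable` used where the denominator is zero.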
+
+        Args:
+            target_variable (str): The variable whose marginal tax rate is computed
+            period (Optional[Union[int, str, Period]], optional): Period. Defaults to None.
+            simulation(str, optional): Simulation to use
+            value_for_zero_varying_variable (float, optional): Value of the MTR when the varying variable is zero.
+                Defaults to 0.
+
+        Returns:
+            numpy.array: Vector of marginal rates
+        """
+        varying_variable = self.varying_variable
+        if simulation is None:
+            assert len(self.simulations.keys()) == 2
+            simulation_name = next(name for name in self.simulations if not name.startswith("_modified_"))
+            simulation = self.simulations[simulation_name]
+        else:
+            simulation_name = simulation
+            simulation = self.simulations[simulation_name]
+
+        modified_simulation = self.simulations[f"_modified_{simulation_name}"]
+
+        variables = simulation.tax_benefit_system.variables
+        assert target_variable in variables
+
+        variables_belong_to_same_entity = (
+            variables[varying_variable].entity.key == variables[target_variable].entity.key
+        )
+        varying_variable_belongs_to_person_entity = variables[varying_variable].entity.is_person
+
+        assert variables_belong_to_same_entity or varying_variable_belongs_to_person_entity
+
+        if variables_belong_to_same_entity:
+            modified_varying = modified_simulation.calculate_add(varying_variable, period=period)
+            varying = simulation.calculate_add(varying_variable, period=period)
+        else:
+            target_variable_entity_key = variables[target_variable].entity.key
+
+            def cast_to_target_entity(simulation: Simulation):
+                population = simulation.populations[target_variable_entity_key]
+                df = (
+                    pd.DataFrame(
+                        {
+                            "members_entity_id": population._members_entity_id,
+                            varying_variable: simulation.calculate_add(varying_variable, period=period),
+                        }
+                    )
+                    .groupby("members_entity_id")
+                    .sum()
+                )
+                varying_variable_for_target_entity = df.loc[population.ids, varying_variable].values
+                return varying_variable_for_target_entity
+
+            modified_varying = cast_to_target_entity(modified_simulation)
+            varying = cast_to_target_entity(simulation)
+
+        modified_target = modified_simulation.calculate_add(target_variable, period=period)
+        target = simulation.calculate_add(target_variable, period=period)
+
+        numerator = modified_target - target
+        denominator = modified_varying - varying
+        marginal_rate = 1 - np.divide(
+            numerator,
+            denominator,
+            out=np.full_like(numerator, value_for_zero_varying_variable, dtype=float),
+            where=(denominator != 0),
+        )
+
+        return marginal_rate
+
+    def compute_pivot_table(
+        self,
+        aggfunc="mean",
+        columns=None,
+        baseline_simulation=None,
+        filter_by=None,
+        index=None,
+        period=None,
+        simulation=None,
+        difference=False,
+        use_baseline_for_columns=None,
+        values=None,
+        missing_variable_default_value=np.nan,
+        concat_axis=None,
+        weighted=True,
+        alternative_weights=None,
+    ):
+        """Compute a pivot table of aggregated values cast along specified index and columns.
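+
+        Doctest-style sketch (variable names are hypothetical):
+
+        >>> survey_scenario.compute_pivot_table(
+        ...     values=["rent"], index=["housing_occupancy_status"], period="2017-01", simulation="baseline")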
+
+        Args:
+            aggfunc(str, optional): Aggregation function, defaults to 'mean'
+            columns(list, optional): Variable(s) in columns, defaults to None
+            difference(bool, optional): Compute difference, defaults to False
+            filter_by(str, optional): Boolean variable to filter by, defaults to None
+            index(list, optional): Variable(s) in index (lines), defaults to None
+            period(Period, optional): Period, defaults to None
+            simulation(str, optional): Simulation to use
+            baseline_simulation(str, optional): Baseline simulation to use when computing a difference
+            use_baseline_for_columns(bool, optional): Use columns from baseline columns values, defaults to None
+            values(list, optional): Aggregated variable(s) within cells, defaults to None
+            missing_variable_default_value(float, optional): Default value for missing variables, defaults to np.nan
+            concat_axis(int, optional): Axis to concatenate along (index = 0, columns = 1), defaults to None
+            weighted(bool, optional): Whether to weight the aggregates (Default value = True)
+            alternative_weights(str or int or float, optional): Weight variable name or numerical value.
+                Uses Simulation's weight_variable_by_entity if None, and if the latter is None, uses 1
+                (Default value = None)
+
+        Returns:
+            pd.DataFrame: Pivot table
+
+        """
+        assert (not difference) or (baseline_simulation is not None), (
+            "Can't have difference without a baseline simulation"
+        )
+
+        simulation = self.simulations[simulation]
+        if baseline_simulation:
+            baseline_simulation = self.simulations[baseline_simulation]
+
+        filtering_variable_by_entity = self.filtering_variable_by_entity
+
+        return simulation.compute_pivot_table(
+            baseline_simulation=baseline_simulation,
+            aggfunc=aggfunc,
+            columns=columns,
+            difference=difference,
+            filter_by=filter_by,
+            index=index,
+            period=period,
+            use_baseline_for_columns=use_baseline_for_columns,
+            values=values,
+            missing_variable_default_value=missing_variable_default_value,
+            concat_axis=concat_axis,
+            weighted=weighted,
+            alternative_weights=alternative_weights,
+            filtering_variable_by_entity=filtering_variable_by_entity,
+        )
+
+    def compute_winners_losers(
+        self,
+        variable,
+        simulation,
+        baseline_simulation=None,
+        filter_by=None,
+        period=None,
+        absolute_minimal_detected_variation=0,
+        relative_minimal_detected_variation=0.01,
+        observations_threshold=None,
+        weighted=True,
+        alternative_weights=None,
+    ):
+        simulation = self.simulations[simulation]
+        if baseline_simulation:
+            baseline_simulation = self.simulations[baseline_simulation]
+
+        return simulation.compute_winners_losers(
+            baseline_simulation,
+            variable=variable,
+            filter_by=filter_by,
+            period=period,
+            absolute_minimal_detected_variation=absolute_minimal_detected_variation,
+            relative_minimal_detected_variation=relative_minimal_detected_variation,
+            observations_threshold=observations_threshold,
+            weighted=weighted,
+            alternative_weights=alternative_weights,
+            filtering_variable_by_entity=self.filtering_variable_by_entity,
+        )
+
+    def create_data_frame_by_entity(
+        self, variables=None, expressions=None, filter_by=None, index=False, period=None, simulation=None, merge=False
+    ):
+        """Create dataframe(s) of computed variables for every entity (optionally merged into a single dataframe).
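+
+        Doctest-style sketch ("rent" is a hypothetical variable; one data frame
+        per entity is returned unless merge=True):
+
+        >>> df_by_entity = survey_scenario.create_data_frame_by_entity(variables=["rent"], period="2017-01")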
+
+        Args:
+            variables(list, optional): Variables to compute, defaults to None
+            expressions(str, optional): Expressions to compute, defaults to None
+            filter_by(str, optional): Boolean variable or expression, defaults to None
+            index(bool, optional): Index by entity id, defaults to False
+            period(Period, optional): Period, defaults to None
+            simulation(str, optional): Simulation to use
+            merge(bool, optional): Merge all the entities in one data frame, defaults to False
+
+        Returns:
+            dict or pandas.DataFrame: Dictionary of dataframes by entity, or a dataframe with all the computed variables
+
+        """
+        if simulation is None:
+            assert len(self.simulations.keys()) == 1
+            simulation = next(iter(self.simulations.values()))
+        else:
+            simulation = self.simulations[simulation]
+
+        return simulation.create_data_frame_by_entity(
+            variables=variables,
+            expressions=expressions,
+            filter_by=filter_by,
+            index=index,
+            period=period,
+            merge=merge,
+        )
+
+    def custom_input_data_frame(self, input_data_frame, **kwargs):
+        """Customize input data frame.
+
+        Args:
+            input_data_frame: original input data frame.
+            kwargs: keyword arguments.
+        """
+        pass
+
+    def dump_data_frame_by_entity(self, variables=None, survey_collection=None, survey_name=None):
+        assert survey_collection is not None
+        assert survey_name is not None
+        assert variables is not None
+        openfisca_data_frame_by_entity = self.create_data_frame_by_entity(variables=variables)
+        for entity_key, data_frame in openfisca_data_frame_by_entity.items():
+            survey = Survey(name=survey_name)
+            survey.insert_table(name=entity_key, data_frame=data_frame)
+            survey_collection.surveys.append(survey)
+        survey_collection.dump(collection="openfisca")
+
+    def dump_simulations(self, directory: str):
+        """
+        Dump simulations.
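+
+        When the scenario holds two or more simulations, each one is dumped in
+        its own sub-directory, named after the simulation.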
+ + Args: + directory (str, optional): Dump directory + """ + assert directory is not None + use_sub_directories = len(self.simulations) >= 2 + + if use_sub_directories: + for simulation_name, simulation in self.simulations.items(): + dump_simulation(simulation, directory=str(Path(directory) / simulation_name)) + else: + assert len(self.simulations.keys()) == 1 + simulation = next(iter(self.simulations.values())) + dump_simulation(simulation, directory) + + def generate_performance_data(self, output_dir: str): + if not self.trace: + raise SurveyManagerError("Method generate_performance_data cannot be used if trace hasn't been activated.") + + for simulation_name, simulation in self.simulations.items(): + simulation_dir = Path(output_dir) / f"{simulation_name}_perf_log" + if not Path(output_dir).exists(): + Path(output_dir).mkdir() + if not Path(simulation_dir).exists(): + Path(simulation_dir).mkdir() + simulation.tracer.generate_performance_graph(simulation_dir) + simulation.tracer.generate_performance_tables(simulation_dir) + + def inflate(self, inflator_by_variable=None, period=None, target_by_variable=None): + assert inflator_by_variable or target_by_variable + assert period is not None + inflator_by_variable = {} if inflator_by_variable is None else inflator_by_variable + target_by_variable = {} if target_by_variable is None else target_by_variable + self.inflator_by_variable = inflator_by_variable + self.target_by_variable = target_by_variable + + for _, simulation in self.simulations.items(): + simulation.inflate(inflator_by_variable, period, target_by_variable) + + def init_from_data( + self, + calibration_kwargs=None, + inflation_kwargs=None, + rebuild_input_data=False, + rebuild_kwargs=None, + data=None, + memory_config=None, + use_marginal_tax_rate=False, + ): + """Initialise a survey scenario from data. + + Args: + rebuild_input_data(bool): Whether or not to clean, format and save data. + Take a look at :func:`build_input_data` + data(dict): Contains the data, or metadata needed to know where to find it. + use_marginal_tax_rate(bool): True to go into marginal effective tax rate computation mode. + calibration_kwargs(dict): Calibration options (Default value = None) + inflation_kwargs(dict): Inflations options (Default value = None) + rebuild_input_data(bool): Whether to rebuild the data (Default value = False) + rebuild_kwargs: Rebuild options (Default value = None) + """ + # When not ``None``, it'll try to get the data for *period*. + if data is not None: + data_year = data.get("data_year", self.period) + + # When ``True`` it'll assume it is raw data and do all that described supra. + # When ``False``, it'll assume data is ready for consumption. 
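+        # `rebuild_kwargs` is forwarded verbatim to `build_input_data`, together with the year.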
+ if rebuild_input_data: + if rebuild_kwargs is not None: + self.build_input_data(year=data_year, **rebuild_kwargs) + else: + self.build_input_data(year=data_year) + + debug = self.debug + trace = self.trace + + if use_marginal_tax_rate: + for name, tax_benefit_system in self.tax_benefit_systems.items(): + assert self.varying_variable in tax_benefit_system.variables, ( + f"Variable {self.varying_variable} is not present tax benefit system named {name}" + ) + + # Inverting reform and baseline because we are more likely + # to use baseline input in reform than the other way around + self.simulations = {} + for simulation_name, _ in self.tax_benefit_systems.items(): + self.new_simulation(simulation_name, debug=debug, data=data, trace=trace, memory_config=memory_config) + if use_marginal_tax_rate: + self.new_simulation( + simulation_name, + debug=debug, + data=data, + trace=trace, + memory_config=memory_config, + marginal_tax_rate_only=True, + ) + + if calibration_kwargs is not None: + assert set(calibration_kwargs.keys()).issubset( + {"target_margins_by_variable", "parameters", "target_entity_count", "other_entity_count", "entity"} + ) + + if inflation_kwargs is not None: + assert set(inflation_kwargs.keys()).issubset({"inflator_by_variable", "target_by_variable", "period"}) + + if calibration_kwargs: + self.calibrate(**calibration_kwargs) + + if inflation_kwargs: + self.inflate(**inflation_kwargs) + + def new_simulation( + self, simulation_name, debug=False, trace=False, data=None, memory_config=None, marginal_tax_rate_only=False + ): + tax_benefit_system = self.tax_benefit_systems[simulation_name] + assert tax_benefit_system is not None + + period = periods.period(self.period) + + if "custom_initialize" in dir(self): + custom_initialize = None if marginal_tax_rate_only else self.custom_initialize + else: + custom_initialize = None + + data["collection"] = self.collection + data["id_variable_by_entity_key"] = self.id_variable_by_entity_key + data["role_variable_by_entity_key"] = self.role_variable_by_entity_key + data["used_as_input_variables"] = self.used_as_input_variables + + simulation = Simulation.new_from_tax_benefit_system( + tax_benefit_system=tax_benefit_system, + debug=debug, + trace=trace, + data=data, + memory_config=memory_config, + period=period, + custom_initialize=custom_initialize, + ) + + if marginal_tax_rate_only: + self._apply_modification(simulation, period) + if custom_initialize: + custom_initialize(simulation) + self.simulations[f"_modified_{simulation_name}"] = simulation + else: + self.simulations[simulation_name] = simulation + + simulation.weight_variable_by_entity = self.weight_variable_by_entity + + if self.period is not None: + simulation.period = periods.period(self.period) + + return simulation + + def memory_usage(self): + """Log memory usage.""" + for simulation_name, simulation in self.simulations.items(): + log.info("simulation: %s", simulation_name) + simulation.print_memory_usage() + + def neutralize_variables(self, tax_benefit_system): + """Neutralizes input variables not in input dataframe and keep some crucial variables. 
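+
+        Only formula-less variables are neutralized; variables listed in
+        `used_as_input_variables` or `non_neutralizable_variables`, and entity
+        weight variables, are kept as-is.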
+ + Args: + tax_benefit_system: The TaxBenefitSystem variables belongs to + + """ + for variable_name, variable in tax_benefit_system.variables.items(): + if variable.formulas: + continue + if self.used_as_input_variables and (variable_name in self.used_as_input_variables): + continue + if self.non_neutralizable_variables and (variable_name in self.non_neutralizable_variables): + continue + if self.weight_variable_by_entity and variable_name in list(self.weight_variable_by_entity.values()): + continue + + tax_benefit_system.neutralize_variable(variable_name) + + def restore_simulations(self, directory, **kwargs): + """Restores SurveyScenario's simulations. + + Args: + directory: Directory to restore simulations from + kwargs: Restoration options + + """ + assert Path(directory).exists(), "Cannot restore simulations from non existent directory" + use_sub_directories = len(self.tax_benefit_systems) >= 2 + + self.simulations = {} + if use_sub_directories: + for simulation_name, tax_benefit_system in self.tax_benefit_systems.items(): + simulation = restore_simulation(str(Path(directory) / simulation_name), tax_benefit_system, **kwargs) + simulation.id_variable_by_entity_key = self.id_variable_by_entity_key + self.simulations[simulation_name] = simulation + else: + simulation = restore_simulation(directory, next(iter(self.tax_benefit_systems.values())), **kwargs) + simulation.id_variable_by_entity_key = self.id_variable_by_entity_key + self.simulations["unique_simulation"] = simulation + + def set_input_data_frame(self, input_data_frame): + """Set the input dataframe. + + Args: + input_data_frame (pd.DataFrame): Input data frame + + """ + self.input_data_frame = input_data_frame + + def set_tax_benefit_systems(self, tax_benefit_systems: dict[str, TaxBenefitSystem]): + """ + Set the tax and benefit systems of the scenario. + + Args: + tax_benefit_systems (Dict[str, TaxBenefitSystem]): The tax benefit systems + """ + for tax_benefit_system in tax_benefit_systems.values(): + assert tax_benefit_system is not None + if self.cache_blacklist is not None: + tax_benefit_system.cache_blacklist = self.cache_blacklist + # + self.tax_benefit_systems = tax_benefit_systems + + def set_weight_variable_by_entity(self, weight_variable_by_entity=None): + if weight_variable_by_entity is not None: + self.weight_variable_by_entity = weight_variable_by_entity + + if self.simulations is not None: + for simulation in self.simulations.values(): + simulation.set_weight_variable_by_entity(self.weight_variable_by_entity) + + def summarize_variable(self, variable=None, weighted=False, force_compute=False): + """Log a summary of a variable including its memory usage for all the simulations. 
+
+        Args:
+            variable(string): The variable being summarized
+            weighted(bool): Whether the produced statistics should be weighted or not
+            force_compute(bool): Whether the computation of the variable should be forced
+
+        Example (output is sent to the logging system at INFO level):
+            >>> from openfisca_survey_manager.tests.test_scenario import create_randomly_initialized_survey_scenario
+            >>> survey_scenario = create_randomly_initialized_survey_scenario(collection = None)
+            >>> survey_scenario.summarize_variable(variable = "housing_occupancy_status", force_compute = True)
+            >>> survey_scenario.summarize_variable(variable = "rent", force_compute = True)
+            >>> survey_scenario.tax_benefit_systems["baseline"].neutralize_variable('age')
+            >>> survey_scenario.summarize_variable(variable = "age")
+        """
+        for _simulation_name, simulation in self.simulations.items():
+            simulation.summarize_variable(variable, weighted, force_compute)
+
+    def _apply_modification(self, simulation, period):
+        period = periods.period(period)
+        varying_variable = self.varying_variable
+        definition_period = simulation.tax_benefit_system.variables[varying_variable].definition_period
+
+        def set_variable(varying_variable, varying_variable_value, period_):
+            delta = self.variation_factor * varying_variable_value
+            new_variable_value = varying_variable_value + delta
+            simulation.delete_arrays(varying_variable, period_)
+            simulation.set_input(varying_variable, period_, new_variable_value)
+
+        if period.unit == definition_period:
+            varying_variable_value = simulation.calculate(varying_variable, period=period)
+            set_variable(varying_variable, varying_variable_value, period)
+
+        elif (definition_period == MONTH) and (period.unit == YEAR and period.size_in_months == 12):
+            varying_variable_value = simulation.calculate_add(varying_variable, period=period)
+            for period_ in [periods.Period(("month", period.start.offset(month, "month"), 1)) for month in range(12)]:
+                set_variable(varying_variable, varying_variable_value / 12, period_)
+        else:
+            raise ValueError(
+                f"Unsupported period {period} for varying variable with definition period {definition_period}"
+            )
diff --git a/openfisca_survey_manager/policy/scenarios/reform_scenario.py b/openfisca_survey_manager/policy/scenarios/reform_scenario.py
new file mode 100644
index 00000000..1859f636
--- /dev/null
+++ b/openfisca_survey_manager/policy/scenarios/reform_scenario.py
@@ -0,0 +1,293 @@
+"""Reform survey scenario definition."""
+
+import logging
+from typing import Optional, Union
+
+import numpy as np
+import pandas as pd
+from openfisca_core.types import Array, Period
+
+from openfisca_survey_manager.policy import Simulation
+from openfisca_survey_manager.policy.scenarios.abstract_scenario import AbstractSurveyScenario
+
+log = logging.getLogger(__name__)
+
+
+class ReformScenario(AbstractSurveyScenario):
+    """Reform survey scenario."""
+
+    def _get_simulation(self, use_baseline: bool = False):
+        """
+        Get the relevant simulation.
+
+        Args:
+            use_baseline (bool, optional): Whether to get the baseline or the reform simulation. Defaults to False.
+        """
+
+        if len(self.simulations) == 1:
+            return next(iter(self.simulations.values()))
+
+        simulation_name = "baseline" if use_baseline else "reform"
+        simulation = self.simulations[simulation_name]
+        assert simulation is not None, f"{simulation_name} does not exist"
+        return simulation
+
+    def build_input_data(self, **kwargs):
+        """Build input data."""
+        raise NotImplementedError
+
+    def calculate_series(self, variable, period=None, use_baseline=False):
+        """Compute variable values for period and baseline or reform tax and benefit system.
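+
+        Doctest-style sketch ("rent" is a hypothetical variable):
+
+        >>> rents = reform_scenario.calculate_series("rent", period="2017-01", use_baseline=True)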
+
+        Args:
+            variable(str, optional): Variable to compute
+            period(Period, optional): Period, defaults to None
+            use_baseline(bool, optional): Use baseline tax and benefit system, defaults to False
+
+        Returns:
+            pandas.Series: Variable values
+
+        """
+        return pd.Series(
+            data=self.calculate_variable(variable, period, use_baseline),
+            name=variable,
+        )
+
+    def calculate_variable(self, variable, period=None, use_baseline=False):
+        """Compute variable values for period and baseline or reform tax and benefit system.
+
+        Args:
+            variable(str, optional): Variable to compute
+            period(Period, optional): Period, defaults to None
+            use_baseline(bool, optional): Use baseline tax and benefit system, defaults to False
+
+        Returns:
+            numpy.ndarray: Variable values
+
+        """
+        simulation = self._get_simulation(use_baseline)
+        return simulation.adaptative_calculate_variable(variable, period=period)
+
+    def compute_aggregate(
+        self,
+        variable: Optional[str] = None,
+        aggfunc: str = "sum",
+        filter_by: Optional[str] = None,
+        period: Optional[Union[int, str, Period]] = None,
+        use_baseline: bool = False,
+        difference: bool = False,
+        missing_variable_default_value=np.nan,
+        weighted: bool = True,
+        alternative_weights: Optional[Union[str, int, float, Array]] = None,
+    ):
+        """Compute variable aggregate.
+
+        Args:
+            variable (str, optional): Variable to aggregate. Defaults to None.
+            aggfunc (str, optional): Aggregation function. Defaults to 'sum'.
+            filter_by (str, optional): Filter variable or expression to use. Defaults to None.
+            period (Optional[Union[int, str, Period]], optional): Period. Defaults to None.
+            use_baseline: Use baseline simulation. Defaults to False.
+            missing_variable_default_value (optional): Value to use for missing values. Defaults to np.nan.
+            weighted (bool, optional): Whether to weight the variable or not. Defaults to True.
+            alternative_weights (Optional[Union[str, int, float, Array]], optional): Alternative weight to use.
+                Defaults to None.
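+            difference (bool, optional): Compute the reform minus baseline
+                difference. Defaults to False.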
+
+        Returns:
+            float: Aggregate
+        """
+        assert aggfunc in ["count", "mean", "sum", "count_non_zero"]
+        assert period is not None
+        assert not (difference and use_baseline), "Can't have difference and use_baseline both set to True"
+
+        if difference:
+            return self.compute_aggregate(
+                variable=variable,
+                aggfunc=aggfunc,
+                filter_by=filter_by,
+                period=period,
+                use_baseline=False,
+                missing_variable_default_value=missing_variable_default_value,
+                weighted=weighted,
+                alternative_weights=alternative_weights,
+            ) - self.compute_aggregate(
+                variable=variable,
+                aggfunc=aggfunc,
+                filter_by=filter_by,
+                period=period,
+                use_baseline=True,
+                missing_variable_default_value=missing_variable_default_value,
+                weighted=weighted,
+                alternative_weights=alternative_weights,
+            )
+
+        assert variable is not None
+        simulation = self._get_simulation(use_baseline)
+        return simulation.compute_aggregate(
+            variable=variable,
+            aggfunc=aggfunc,
+            filter_by=filter_by,
+            period=period,
+            missing_variable_default_value=missing_variable_default_value,
+            weighted=weighted,
+            alternative_weights=alternative_weights,
+            filtering_variable_by_entity=self.filtering_variable_by_entity,
+        )
+
+    def compute_quantiles(
+        self,
+        variable: Optional[str] = None,
+        nquantiles=None,
+        period=None,
+        use_baseline=False,
+        filter_by=None,
+        weighted=True,
+        alternative_weights=None,
+    ):
+        assert variable is not None
+        assert nquantiles is not None
+        simulation = self._get_simulation(use_baseline)
+
+        return simulation.compute_quantiles(
+            variable=variable,
+            period=period,
+            nquantiles=nquantiles,
+            filter_by=filter_by,
+            weighted=weighted,
+            alternative_weights=alternative_weights,
+        )
+
+    def compute_marginal_tax_rate(
+        self,
+        target_variable: str,
+        period: Optional[Union[int, str, Period]],
+        use_baseline: bool = False,
+        value_for_zero_varying_variable: float = 0.0,
+    ) -> Array:
+        """
+        Compute the marginal tax rate (MTR) of a target variable with respect to a varying variable.
+
+        Args:
+            target_variable (str): The variable whose marginal tax rate is computed
+            period (Optional[Union[int, str, Period]], optional): Period. Defaults to None.
+            use_baseline: Use baseline simulation. Defaults to False.
+            value_for_zero_varying_variable (float, optional): Value of the MTR when the varying variable is zero.
+                Defaults to 0.
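+
+        Delegates to AbstractSurveyScenario.compute_marginal_tax_rate with the
+        "baseline" or "reform" simulation, depending on `use_baseline`.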
+ + Returns: + numpy.array: Vector of marginal rates + """ + if use_baseline: + return super().compute_marginal_tax_rate( + target_variable=target_variable, + period=period, + simulation="baseline", + value_for_zero_varying_variable=value_for_zero_varying_variable, + ) + else: + return super().compute_marginal_tax_rate( + target_variable=target_variable, + period=period, + simulation="reform", + value_for_zero_varying_variable=value_for_zero_varying_variable, + ) + + def compute_pivot_table( + self, + aggfunc="mean", + columns=None, + difference=False, + filter_by=None, + index=None, + period=None, + use_baseline=False, + use_baseline_for_columns=None, + values=None, + missing_variable_default_value=np.nan, + concat_axis=None, + weighted=True, + alternative_weights=None, + ): + filtering_variable_by_entity = self.filtering_variable_by_entity + + return Simulation.compute_pivot_table( + aggfunc=aggfunc, + columns=columns, + baseline_simulation=self._get_simulation(use_baseline=True), + filter_by=filter_by, + index=index, + period=period, + simulation=self._get_simulation(use_baseline), + difference=difference, + use_baseline_for_columns=use_baseline_for_columns, + values=values, + missing_variable_default_value=missing_variable_default_value, + concat_axis=concat_axis, + weighted=weighted, + alternative_weights=alternative_weights, + filtering_variable_by_entity=filtering_variable_by_entity, + ) + + def compute_winners_losers( + self, + variable=None, + filter_by=None, + period=None, + absolute_minimal_detected_variation=0, + relative_minimal_detected_variation=0.01, + observations_threshold=None, + weighted=True, + alternative_weights=None, + ): + return super().compute_winners_losers( + simulation="reform", + baseline_simulation="baseline", + variable=variable, + filter_by=filter_by, + period=period, + absolute_minimal_detected_variation=getattr( + self, "absolute_minimal_detected_variation", absolute_minimal_detected_variation + ), + relative_minimal_detected_variation=getattr( + self, "relative_minimal_detected_variation", relative_minimal_detected_variation + ), + observations_threshold=getattr(self, "observations_threshold", observations_threshold), + weighted=weighted, + alternative_weights=alternative_weights, + ) + + def create_data_frame_by_entity( + self, + variables=None, + expressions=None, + filter_by=None, + index=False, + period=None, + use_baseline=False, + merge=False, + ): + """Create dataframe(s) of computed variable for every entity (eventually merged in a unique dataframe). 
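+
+        Doctest-style sketch ("rent" is a hypothetical variable):
+
+        >>> dfs = reform_scenario.create_data_frame_by_entity(variables=["rent"], period="2017-01", use_baseline=True)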
+ + Args: + variables(list, optional): Variable to compute, defaults to None + expressions(str, optional): Expressions to compute, defaults to None + filter_by(str, optional): Boolean variable or expression, defaults to None + index(bool, optional): Index by entity id, defaults to False + period(Period, optional): Period, defaults to None + use_baseline(bool, optional): Use baseline tax and benefit system, defaults to False + merge(bool, optional): Merge all the entities in one data frame, defaults to False + + Returns: + dict or pandas.DataFrame: Dictionnary of dataframes by entities or dataframe with all the computed variables + + """ + simulation = self._get_simulation(use_baseline) + return simulation.create_data_frame_by_entity( + variables=variables, + expressions=expressions, + filter_by=filter_by, + index=index, + period=period, + merge=merge, + ) diff --git a/openfisca_survey_manager/processing/__init__.py b/openfisca_survey_manager/processing/__init__.py index b0097a35..fa8106e9 100644 --- a/openfisca_survey_manager/processing/__init__.py +++ b/openfisca_survey_manager/processing/__init__.py @@ -8,7 +8,8 @@ # Lazy import to avoid circular dependency (processing -> policy -> survey_collections -> core) def __getattr__(name: str): if name in ("Calibration", "calmar", "check_calmar"): - from openfisca_survey_manager.processing.weights import Calibration, calmar, check_calmar + from openfisca_survey_manager.policy.calibration import Calibration + from openfisca_survey_manager.policy.calmar import calmar, check_calmar return {"Calibration": Calibration, "calmar": calmar, "check_calmar": check_calmar}[name] raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/openfisca_survey_manager/processing/weights/__init__.py b/openfisca_survey_manager/processing/weights/__init__.py deleted file mode 100644 index cfa0ba8f..00000000 --- a/openfisca_survey_manager/processing/weights/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Calibration and CALMAR weight calibration. Re-exports from policy. - -import warnings - -from openfisca_survey_manager.policy.calibration import Calibration -from openfisca_survey_manager.policy.calmar import calmar, check_calmar - -warnings.warn( - "openfisca_survey_manager.processing.weights is deprecated for Calibration/calmar. " - "Prefer: from openfisca_survey_manager.policy.calibration import Calibration, " - "from openfisca_survey_manager.policy.calmar import calmar, check_calmar", - DeprecationWarning, - stacklevel=3, -) - -__all__ = ["Calibration", "calmar", "check_calmar"] diff --git a/openfisca_survey_manager/scenarios/__init__.py b/openfisca_survey_manager/scenarios/__init__.py index e69de29b..2c41e560 100644 --- a/openfisca_survey_manager/scenarios/__init__.py +++ b/openfisca_survey_manager/scenarios/__init__.py @@ -0,0 +1,17 @@ +"""Re-export for backward compatibility. + +Prefer: from openfisca_survey_manager.policy.scenarios import ... +""" + +import warnings + +from openfisca_survey_manager.policy.scenarios import AbstractSurveyScenario, ReformScenario + +warnings.warn( + "openfisca_survey_manager.scenarios is deprecated and will be removed in a future version. 
" + "Prefer: from openfisca_survey_manager.policy.scenarios import ...", + DeprecationWarning, + stacklevel=2, +) + +__all__ = ["AbstractSurveyScenario", "ReformScenario"] diff --git a/openfisca_survey_manager/scenarios/abstract_scenario.py b/openfisca_survey_manager/scenarios/abstract_scenario.py index f2765dd7..1e70c7dd 100644 --- a/openfisca_survey_manager/scenarios/abstract_scenario.py +++ b/openfisca_survey_manager/scenarios/abstract_scenario.py @@ -1,763 +1,17 @@ -"Abstract survey scenario definition." +"""Re-export for backward compatibility. -import logging -from pathlib import Path -from typing import Optional, Union +Prefer: from openfisca_survey_manager.policy.scenarios.abstract_scenario import ... +""" -import numpy as np -import pandas as pd -from openfisca_core import periods -from openfisca_core.periods import MONTH, YEAR -from openfisca_core.tools.simulation_dumper import dump_simulation, restore_simulation -from openfisca_core.types import Array, Period, TaxBenefitSystem +import warnings -from openfisca_survey_manager.exceptions import SurveyManagerError -from openfisca_survey_manager.policy.calibration import Calibration -from openfisca_survey_manager.simulations import Simulation -from openfisca_survey_manager.surveys import Survey +from openfisca_survey_manager.policy.scenarios.abstract_scenario import AbstractSurveyScenario -log = logging.getLogger(__name__) +warnings.warn( + "openfisca_survey_manager.scenarios.abstract_scenario is deprecated and will be removed in a future version. " + "Prefer: from openfisca_survey_manager.policy.scenarios.abstract_scenario import ...", + DeprecationWarning, + stacklevel=2, +) - -class AbstractSurveyScenario: - """Abstract survey scenario.""" - - cache_blacklist = None - collection = None - debug = False - filtering_variable_by_entity = None - id_variable_by_entity_key = None - inflator_by_variable = None # factor used to inflate variable total - input_data_frame = None - input_data_table_by_entity_by_period = None - input_data_table_by_period = None - non_neutralizable_variables = None - period = None - role_variable_by_entity_key = None - simulations = None - target_by_variable = None # variable total target to inflate to - tax_benefit_systems = None - trace = False - used_as_input_variables = None - used_as_input_variables_by_entity = None - variation_factor = 0.03 # factor used to compute variation when estimating marginal tax rate - varying_variable = None - weight_variable_by_entity = None - - def build_input_data(self, **kwargs): - """Build input data.""" - raise NotImplementedError - - def calculate_series(self, variable, period=None, simulation=None): - """Compute variable values for period for a given simulation. - - Args: - variable(str, optional): Variable to compute - period(Period, optional): Period, defaults to None - simulation(str, optional): Simulation to use - - Returns: - pandas.Series: Variable values - - """ - return pd.Series( - data=self.calculate_variable(variable, period, simulation=simulation), - name=variable, - ) - - def calculate_variable(self, variable, period=None, simulation=None): - """Compute variable values for period for a given simulation. 
- - Args: - variable(str, optional): Variable to compute - period(Period, optional): Period, defaults to None - simulation(str, optional): Simulation to use - - Returns: - numpy.ndarray: Variable values - - """ - if simulation is None: - assert len(self.simulations.keys()) == 1 - simulation = next(iter(self.simulations.values())) - else: - simulation = self.simulations[simulation] - assert simulation is not None - return simulation.adaptative_calculate_variable(variable, period=period) - - def calibrate( - self, - period: Optional[int] = None, - target_margins_by_variable: Optional[dict] = None, - parameters: Optional[dict] = None, - target_entity_count: Optional[float] = None, - other_entity_count: Optional[float] = None, - entity: Optional[str] = None, - ): - """Calibrate the scenario data. - - Args: - period (int, optionnal): Period of calibration. Defaults to scenario.year - target_margins_by_variable (dict, optional): Variable targets margins. Defaults to None. - parameters (dict, optional): Calibration parameters. Defaults to None. - target_entity_count (float, optional): Total population target. Defaults to None. - other_entity_count (float, optional): Total population target of the second entity. Defaults to None. - entity (str): Entity specified when no variable comes with a target margins - but `target_entity_count` is not None. - """ - survey_scenario = self - - if period is None: - assert survey_scenario.period is not None - period = survey_scenario.period - - if parameters is not None: - assert parameters["method"] in ["linear", "raking ratio", "logit", "hyperbolic sinus"], ( - "Incorrect parameter value: method should be 'linear', 'raking ratio', 'logit' or 'hyperbolic sinus'" - ) - if parameters["method"] == "logit": - assert parameters["invlo"] is not None - assert parameters["up"] is not None - elif parameters["method"] == "hyperbolic sinus": - assert parameters["alpha"] is not None - else: - parameters = {"method": "logit", "up": 3, "invlo": 3} - - # TODO: filtering using filtering_variable_by_entity - for simulation in self.simulations.values(): - if simulation is None: - continue - calibration = Calibration( - simulation, - target_margins_by_variable, - period, - target_entity_count=target_entity_count, - other_entity_count=other_entity_count, - entity=entity, - parameters=parameters, - # filter_by = self.filter_by, - ) - calibration.calibrate(inplace=True) - simulation.calibration = calibration - - def compute_aggregate( - self, - variable: Optional[str] = None, - aggfunc: str = "sum", - filter_by: Optional[str] = None, - period: Optional[Union[int, str, Period]] = None, - simulation: Optional[str] = None, - baseline_simulation: Optional[str] = None, - missing_variable_default_value=np.nan, - weighted: bool = True, - alternative_weights: Optional[Union[str, int, float, Array]] = None, - ): - """Compute variable aggregate. - - Args: - variable (str, optional): Variable to aggregate. Defaults to None. - aggfunc (str, optional): Aggregation function. Defaults to 'sum'. - filter_by (str, optional): Filter variable or expression to use. Defaults to None. - period (Optional[Union[int, str, Period]], optional): Period. Defaults to None. - simulation(str, optional): Simulation to use - baseline_simulation(str, optional): Baseline simulation to use when computing a difference - missing_variable_default_value (optional): Value to use for missing values. Defaults to np.nan. - weighted (bool, optional): Whether to weight the variable or not. Defaults to True. 
- alternative_weights (Optional[Union[str, int, float, Array]], optional): Alternative weigh to use. - Defaults to None. - filtering_variable_by_entity (Dict, optional): Filtering variable by entity. Defaults to None. - - Returns: - float: Aggregate - """ - assert aggfunc in ["count", "mean", "sum", "count_non_zero"] - assert period is not None - assert variable is not None - if simulation is None: - assert len(self.simulations.keys()) == 1 - simulation = next(iter(self.simulations.values())) - else: - simulation = self.simulations[simulation] - - assert simulation is not None, f"Missing {simulation} simulation" - - if baseline_simulation: - baseline_simulation = self.simulations[baseline_simulation] - return simulation.compute_aggregate( - variable=variable, - aggfunc=aggfunc, - filter_by=filter_by, - period=period, - missing_variable_default_value=missing_variable_default_value, - weighted=weighted, - alternative_weights=alternative_weights, - filtering_variable_by_entity=self.filtering_variable_by_entity, - ) - baseline_simulation.compute_aggregate( - variable=variable, - aggfunc=aggfunc, - filter_by=filter_by, - period=period, - missing_variable_default_value=missing_variable_default_value, - weighted=weighted, - alternative_weights=alternative_weights, - filtering_variable_by_entity=self.filtering_variable_by_entity, - ) - - return simulation.compute_aggregate( - variable=variable, - aggfunc=aggfunc, - filter_by=filter_by, - period=period, - missing_variable_default_value=missing_variable_default_value, - weighted=weighted, - alternative_weights=alternative_weights, - filtering_variable_by_entity=self.filtering_variable_by_entity, - ) - - def compute_quantiles( - self, - simulation: Simulation, - variable: str, - nquantiles: Optional[int] = None, - period: Optional[Union[int, str, Period]] = None, - filter_by=None, - weighted: bool = True, - alternative_weights=None, - filtering_variable_by_entity=None, - ) -> list[float]: - """ - Compute quantiles of a variable. - - Args: - simulation (Simulation, optional): Simulation to be used. Defaults to None. - variable (str, optional): Variable which quantiles are computed. Defaults to None. - nquantiles (int, optional): Number of quantiles. Defaults to None. - period (Optional[Union[int, str, Period]], optional): Period. Defaults to None. - missing_variable_default_value (optional): Value to use for missing values. Defaults to np.nan. - weighted (bool, optional): Whether to weight the variable or not. Defaults to True. - alternative_weights (Optional[Union[str, int, float, Array]], optional): Alternative weigh to use. - Defaults to None. - filtering_variable_by_entity (Dict, optional): Filtering variable by entity. Defaults to None. - - Returns: - List(float): The quantiles values - """ - assert variable is not None - assert nquantiles is not None - simulation = self.simulations[simulation] - assert simulation is not None, f"Missing {simulation} simulation" - - return simulation.compute_quantiles( - variable=variable, - period=period, - nquantiles=nquantiles, - filter_by=filter_by, - weighted=weighted, - alternative_weights=alternative_weights, - ) - - def compute_marginal_tax_rate( - self, - target_variable: str, - period: Optional[Union[int, str, Period]], - simulation: Optional[str] = None, - value_for_zero_varying_variable: float = 0.0, - ) -> Array: - """ - Compute marginal a rate of a target (MTR) with respect to a varying variable. 
-        """
-        assert variable is not None
-        assert nquantiles is not None
-        simulation = self.simulations[simulation]
-        assert simulation is not None, f"Missing {simulation} simulation"
-
-        return simulation.compute_quantiles(
-            variable=variable,
-            period=period,
-            nquantiles=nquantiles,
-            filter_by=filter_by,
-            weighted=weighted,
-            alternative_weights=alternative_weights,
-        )
-
-    def compute_marginal_tax_rate(
-        self,
-        target_variable: str,
-        period: Optional[Union[int, str, Period]],
-        simulation: Optional[str] = None,
-        value_for_zero_varying_variable: float = 0.0,
-    ) -> Array:
-        """
-        Compute the marginal tax rate (MTR) of a target variable with respect to a varying variable.
-
-        Args:
-            target_variable (str): the variable whose marginal tax rate is computed
-            period (Optional[Union[int, str, Period]], optional): Period. Defaults to None.
-            simulation(str, optional): Simulation to use
-            value_for_zero_varying_variable (float, optional): value of MTR when the varying variable is zero.
-                Defaults to 0.
-
-        Returns:
-            numpy.array: Vector of marginal rates
-        """
-        varying_variable = self.varying_variable
-        if simulation is None:
-            assert len(self.simulations.keys()) == 2
-            simulation_name = next(name for name in self.simulations if not name.startswith("_modified_"))
-            simulation = self.simulations[simulation_name]
-        else:
-            simulation_name = simulation
-            simulation = self.simulations[simulation_name]
-
-        modified_simulation = self.simulations[f"_modified_{simulation_name}"]
-
-        variables = simulation.tax_benefit_system.variables
-        assert target_variable in variables
-
-        variables_belong_to_same_entity = (
-            variables[varying_variable].entity.key == variables[target_variable].entity.key
-        )
-        varying_variable_belongs_to_person_entity = variables[varying_variable].entity.is_person
-
-        assert variables_belong_to_same_entity or varying_variable_belongs_to_person_entity
-
-        if variables_belong_to_same_entity:
-            modified_varying = modified_simulation.calculate_add(varying_variable, period=period)
-            varying = simulation.calculate_add(varying_variable, period=period)
-        else:
-            target_variable_entity_key = variables[target_variable].entity.key
-
-            def cast_to_target_entity(simulation: Simulation):
-                population = simulation.populations[target_variable_entity_key]
-                df = (
-                    pd.DataFrame(
-                        {
-                            "members_entity_id": population._members_entity_id,
-                            varying_variable: simulation.calculate_add(varying_variable, period=period),
-                        }
-                    )
-                    .groupby("members_entity_id")
-                    .sum()
-                )
-                varying_variable_for_target_entity = df.loc[population.ids, varying_variable].values
-                return varying_variable_for_target_entity
-
-            modified_varying = cast_to_target_entity(modified_simulation)
-            varying = cast_to_target_entity(simulation)
-
-        modified_target = modified_simulation.calculate_add(target_variable, period=period)
-        target = simulation.calculate_add(target_variable, period=period)
-
-        numerator = modified_target - target
-        denominator = modified_varying - varying
-        marginal_rate = 1 - np.divide(
-            numerator,
-            denominator,
-            out=np.full_like(numerator, value_for_zero_varying_variable, dtype=float),
-            where=(denominator != 0),
-        )
-
-        return marginal_rate
-
-    def compute_pivot_table(
-        self,
-        aggfunc="mean",
-        columns=None,
-        baseline_simulation=None,
-        filter_by=None,
-        index=None,
-        period=None,
-        simulation=None,
-        difference=False,
-        use_baseline_for_columns=None,
-        values=None,
-        missing_variable_default_value=np.nan,
-        concat_axis=None,
-        weighted=True,
-        alternative_weights=None,
-    ):
-        """Compute a pivot table of aggregated values cast along specified index and columns.
-
-        Args:
-            aggfunc(str, optional): Aggregation function, defaults to 'mean'
-            columns(list, optional): Variable(s) in columns, defaults to None
-            difference(bool, optional): Compute difference, defaults to False
-            filter_by(str, optional): Boolean variable to filter by, defaults to None
-            index(list, optional): Variable(s) in index (lines), defaults to None
-            period(Period, optional): Period, defaults to None
-            simulation(str, optional): Simulation to use
-            baseline_simulation(str, optional): Baseline simulation to use when computing a difference
-            use_baseline_for_columns(bool, optional): Use columns from baseline columns values, defaults to None
-            values(list, optional): Aggregated variable(s) within cells, defaults to None
-            missing_variable_default_value(float, optional): Default value for missing variables, defaults to np.nan
-            concat_axis(int, optional): Axis to concatenate along (index = 0, columns = 1), defaults to None
-            weighted(bool, optional): Whether to weight the aggregates (Default value = True)
-            alternative_weights(str or int or float, optional): Weight variable name or numerical value.
-                Use Simulation's weight_variable_by_entity if None, and if the latter is None uses 1
-                (Default value = None)
-
-        Returns:
-            pd.DataFrame: Pivot table
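-
-        Example (illustrative sketch only; simulation and variable names are
-        hypothetical):
-            >>> survey_scenario.compute_pivot_table(
-            ...     simulation="baseline", values=["salary"], index=["household_type"],
-            ...     aggfunc="mean", period=2017,
-            ... )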
-
-        """
-        assert (not difference) or (baseline_simulation is not None), (
-            "Can't compute a difference without a baseline simulation"
-        )
-
-        simulation = self.simulations[simulation]
-        if baseline_simulation:
-            baseline_simulation = self.simulations[baseline_simulation]
-
-        filtering_variable_by_entity = self.filtering_variable_by_entity
-
-        return simulation.compute_pivot_table(
-            baseline_simulation=baseline_simulation,
-            aggfunc=aggfunc,
-            columns=columns,
-            difference=difference,
-            filter_by=filter_by,
-            index=index,
-            period=period,
-            use_baseline_for_columns=use_baseline_for_columns,
-            values=values,
-            missing_variable_default_value=missing_variable_default_value,
-            concat_axis=concat_axis,
-            weighted=weighted,
-            alternative_weights=alternative_weights,
-            filtering_variable_by_entity=filtering_variable_by_entity,
-        )
-
-    def compute_winners_losers(
-        self,
-        variable,
-        simulation,
-        baseline_simulation=None,
-        filter_by=None,
-        period=None,
-        absolute_minimal_detected_variation=0,
-        relative_minimal_detected_variation=0.01,
-        observations_threshold=None,
-        weighted=True,
-        alternative_weights=None,
-    ):
-        simulation = self.simulations[simulation]
-        if baseline_simulation:
-            baseline_simulation = self.simulations[baseline_simulation]
-
-        return simulation.compute_winners_losers(
-            baseline_simulation,
-            variable=variable,
-            filter_by=filter_by,
-            period=period,
-            absolute_minimal_detected_variation=absolute_minimal_detected_variation,
-            relative_minimal_detected_variation=relative_minimal_detected_variation,
-            observations_threshold=observations_threshold,
-            weighted=weighted,
-            alternative_weights=alternative_weights,
-            filtering_variable_by_entity=self.filtering_variable_by_entity,
-        )
-
-    def create_data_frame_by_entity(
-        self, variables=None, expressions=None, filter_by=None, index=False, period=None, simulation=None, merge=False
-    ):
-        """Create dataframe(s) of computed variables for every entity (optionally merged into a unique dataframe).
-
-        Args:
-            variables(list, optional): Variables to compute, defaults to None
-            expressions(str, optional): Expressions to compute, defaults to None
-            filter_by(str, optional): Boolean variable or expression, defaults to None
-            index(bool, optional): Index by entity id, defaults to False
-            period(Period, optional): Period, defaults to None
-            simulation(str, optional): Simulation to use
-            merge(bool, optional): Merge all the entities in one data frame, defaults to False
-
-        Returns:
-            dict or pandas.DataFrame: Dictionary of dataframes by entity, or a single dataframe with all the computed variables
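-
-        Example (illustrative sketch only; assumes a single-simulation scenario and
-        a hypothetical ``salary`` variable):
-            >>> survey_scenario.create_data_frame_by_entity(
-            ...     variables=["salary"], period=2017, merge=True,
-            ... )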
-
-        """
-        if simulation is None:
-            assert len(self.simulations.keys()) == 1
-            simulation = next(iter(self.simulations.values()))
-        else:
-            simulation = self.simulations[simulation]
-
-        return simulation.create_data_frame_by_entity(
-            variables=variables,
-            expressions=expressions,
-            filter_by=filter_by,
-            index=index,
-            period=period,
-            merge=merge,
-        )
-
-    def custom_input_data_frame(self, input_data_frame, **kwargs):
-        """Customize input data frame.
-
-        Args:
-            input_data_frame: original input data frame.
-            kwargs: keyword arguments.
-        """
-        pass
-
-    def dump_data_frame_by_entity(self, variables=None, survey_collection=None, survey_name=None):
-        assert survey_collection is not None
-        assert survey_name is not None
-        assert variables is not None
-        openfisca_data_frame_by_entity = self.create_data_frame_by_entity(variables=variables)
-        for entity_key, data_frame in openfisca_data_frame_by_entity.items():
-            survey = Survey(name=survey_name)
-            survey.insert_table(name=entity_key, data_frame=data_frame)
-            survey_collection.surveys.append(survey)
-            survey_collection.dump(collection="openfisca")
-
-    def dump_simulations(self, directory: str):
-        """
-        Dump simulations.
-
-        Args:
-            directory (str): Dump directory
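-
-        Example (illustrative sketch only; the target directory is hypothetical):
-            >>> survey_scenario.dump_simulations(directory="/tmp/dumped_simulations")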
-        """
-        assert directory is not None
-        use_sub_directories = len(self.simulations) >= 2
-
-        if use_sub_directories:
-            for simulation_name, simulation in self.simulations.items():
-                dump_simulation(simulation, directory=str(Path(directory) / simulation_name))
-        else:
-            assert len(self.simulations.keys()) == 1
-            simulation = next(iter(self.simulations.values()))
-            dump_simulation(simulation, directory)
-
-    def generate_performance_data(self, output_dir: str):
-        if not self.trace:
-            raise SurveyManagerError("Method generate_performance_data cannot be used if trace hasn't been activated.")
-
-        for simulation_name, simulation in self.simulations.items():
-            simulation_dir = Path(output_dir) / f"{simulation_name}_perf_log"
-            if not Path(output_dir).exists():
-                Path(output_dir).mkdir()
-            if not Path(simulation_dir).exists():
-                Path(simulation_dir).mkdir()
-            simulation.tracer.generate_performance_graph(simulation_dir)
-            simulation.tracer.generate_performance_tables(simulation_dir)
-
-    def inflate(self, inflator_by_variable=None, period=None, target_by_variable=None):
-        assert inflator_by_variable or target_by_variable
-        assert period is not None
-        inflator_by_variable = {} if inflator_by_variable is None else inflator_by_variable
-        target_by_variable = {} if target_by_variable is None else target_by_variable
-        self.inflator_by_variable = inflator_by_variable
-        self.target_by_variable = target_by_variable
-
-        for _, simulation in self.simulations.items():
-            simulation.inflate(inflator_by_variable, period, target_by_variable)
-
-    def init_from_data(
-        self,
-        calibration_kwargs=None,
-        inflation_kwargs=None,
-        rebuild_input_data=False,
-        rebuild_kwargs=None,
-        data=None,
-        memory_config=None,
-        use_marginal_tax_rate=False,
-    ):
-        """Initialise a survey scenario from data.
-
-        Args:
-            rebuild_input_data(bool): Whether to clean, format and save data (Default value = False).
-                Take a look at :func:`build_input_data`
-            data(dict): Contains the data, or metadata needed to know where to find it.
-            use_marginal_tax_rate(bool): True to go into marginal effective tax rate computation mode.
-            calibration_kwargs(dict): Calibration options (Default value = None)
-            inflation_kwargs(dict): Inflation options (Default value = None)
-            rebuild_kwargs: Rebuild options (Default value = None)
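-
-        Example (illustrative sketch only; the ``data`` key shown is an assumption,
-        following the input-dataframe convention used elsewhere in the package):
-            >>> survey_scenario.init_from_data(
-            ...     data={"input_data_frame": input_data_frame},
-            ... )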
-        """
-        # When not ``None``, it'll try to get the data for *period*.
-        if data is not None:
-            data_year = data.get("data_year", self.period)
-
-        # When ``True`` it'll assume it is raw data and do all that described supra.
-        # When ``False``, it'll assume data is ready for consumption.
-        if rebuild_input_data:
-            if rebuild_kwargs is not None:
-                self.build_input_data(year=data_year, **rebuild_kwargs)
-            else:
-                self.build_input_data(year=data_year)
-
-        debug = self.debug
-        trace = self.trace
-
-        if use_marginal_tax_rate:
-            for name, tax_benefit_system in self.tax_benefit_systems.items():
-                assert self.varying_variable in tax_benefit_system.variables, (
-                    f"Variable {self.varying_variable} is not present in the tax benefit system named {name}"
-                )
-
-        # Inverting reform and baseline because we are more likely
-        # to use baseline input in reform than the other way around
-        self.simulations = {}
-        for simulation_name, _ in self.tax_benefit_systems.items():
-            self.new_simulation(simulation_name, debug=debug, data=data, trace=trace, memory_config=memory_config)
-            if use_marginal_tax_rate:
-                self.new_simulation(
-                    simulation_name,
-                    debug=debug,
-                    data=data,
-                    trace=trace,
-                    memory_config=memory_config,
-                    marginal_tax_rate_only=True,
-                )
-
-        if calibration_kwargs is not None:
-            assert set(calibration_kwargs.keys()).issubset(
-                {"target_margins_by_variable", "parameters", "target_entity_count", "other_entity_count", "entity"}
-            )
-
-        if inflation_kwargs is not None:
-            assert set(inflation_kwargs.keys()).issubset({"inflator_by_variable", "target_by_variable", "period"})
-
-        if calibration_kwargs:
-            self.calibrate(**calibration_kwargs)
-
-        if inflation_kwargs:
-            self.inflate(**inflation_kwargs)
-
-    def new_simulation(
-        self, simulation_name, debug=False, trace=False, data=None, memory_config=None, marginal_tax_rate_only=False
-    ):
-        tax_benefit_system = self.tax_benefit_systems[simulation_name]
-        assert tax_benefit_system is not None
-
-        period = periods.period(self.period)
-
-        if "custom_initialize" in dir(self):
-            custom_initialize = None if marginal_tax_rate_only else self.custom_initialize
-        else:
-            custom_initialize = None
-
-        data["collection"] = self.collection
-        data["id_variable_by_entity_key"] = self.id_variable_by_entity_key
-        data["role_variable_by_entity_key"] = self.role_variable_by_entity_key
-        data["used_as_input_variables"] = self.used_as_input_variables
-
-        simulation = Simulation.new_from_tax_benefit_system(
-            tax_benefit_system=tax_benefit_system,
-            debug=debug,
-            trace=trace,
-            data=data,
-            memory_config=memory_config,
-            period=period,
-            custom_initialize=custom_initialize,
-        )
-
-        if marginal_tax_rate_only:
-            self._apply_modification(simulation, period)
-            if custom_initialize:
-                custom_initialize(simulation)
-            self.simulations[f"_modified_{simulation_name}"] = simulation
-        else:
-            self.simulations[simulation_name] = simulation
-
-        simulation.weight_variable_by_entity = self.weight_variable_by_entity
-
-        if self.period is not None:
-            simulation.period = periods.period(self.period)
-
-        return simulation
-
-    def memory_usage(self):
-        """Log memory usage."""
-        for simulation_name, simulation in self.simulations.items():
-            log.info("simulation: %s", simulation_name)
-            simulation.print_memory_usage()
-
-    def neutralize_variables(self, tax_benefit_system):
-        """Neutralizes input variables not in the input dataframe and keeps some crucial variables.
-
-        Args:
-            tax_benefit_system: The TaxBenefitSystem the variables belong to
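-
-        Example (illustrative sketch only; assumes a ``tax_benefit_system`` in scope):
-            >>> survey_scenario.neutralize_variables(tax_benefit_system)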
-
-        """
-        for variable_name, variable in tax_benefit_system.variables.items():
-            if variable.formulas:
-                continue
-            if self.used_as_input_variables and (variable_name in self.used_as_input_variables):
-                continue
-            if self.non_neutralizable_variables and (variable_name in self.non_neutralizable_variables):
-                continue
-            if self.weight_variable_by_entity and variable_name in list(self.weight_variable_by_entity.values()):
-                continue
-
-            tax_benefit_system.neutralize_variable(variable_name)
-
-    def restore_simulations(self, directory, **kwargs):
-        """Restores SurveyScenario's simulations.
-
-        Args:
-            directory: Directory to restore simulations from
-            kwargs: Restoration options
-
-        """
-        assert Path(directory).exists(), "Cannot restore simulations from a non-existent directory"
-        use_sub_directories = len(self.tax_benefit_systems) >= 2
-
-        self.simulations = {}
-        if use_sub_directories:
-            for simulation_name, tax_benefit_system in self.tax_benefit_systems.items():
-                simulation = restore_simulation(str(Path(directory) / simulation_name), tax_benefit_system, **kwargs)
-                simulation.id_variable_by_entity_key = self.id_variable_by_entity_key
-                self.simulations[simulation_name] = simulation
-        else:
-            simulation = restore_simulation(directory, next(iter(self.tax_benefit_systems.values())), **kwargs)
-            simulation.id_variable_by_entity_key = self.id_variable_by_entity_key
-            self.simulations["unique_simulation"] = simulation
-
-    def set_input_data_frame(self, input_data_frame):
-        """Set the input dataframe.
-
-        Args:
-            input_data_frame (pd.DataFrame): Input data frame
-
-        """
-        self.input_data_frame = input_data_frame
-
-    def set_tax_benefit_systems(self, tax_benefit_systems: dict[str, TaxBenefitSystem]):
-        """
-        Set the tax and benefit systems of the scenario.
-
-        Args:
-            tax_benefit_systems (Dict[str, TaxBenefitSystem]): The tax benefit systems
-        """
-        for tax_benefit_system in tax_benefit_systems.values():
-            assert tax_benefit_system is not None
-            if self.cache_blacklist is not None:
-                tax_benefit_system.cache_blacklist = self.cache_blacklist
-        #
-        self.tax_benefit_systems = tax_benefit_systems
-
-    def set_weight_variable_by_entity(self, weight_variable_by_entity=None):
-        if weight_variable_by_entity is not None:
-            self.weight_variable_by_entity = weight_variable_by_entity
-
-        if self.simulations is not None:
-            for simulation in self.simulations.values():
-                simulation.set_weight_variable_by_entity(self.weight_variable_by_entity)
-
-    def summarize_variable(self, variable=None, weighted=False, force_compute=False):
-        """Log a summary of a variable including its memory usage for all the simulations.
-
-        Args:
-            variable(string): The variable being summarized
-            weighted(bool): Whether the produced statistics should be weighted or not
-            force_compute(bool): Whether the computation of the variable should be forced
-
-        Example (output is sent to the logging system at INFO level):
-            >>> from openfisca_survey_manager.tests.test_scenario import create_randomly_initialized_survey_scenario
-            >>> survey_scenario = create_randomly_initialized_survey_scenario(collection = None)
-            >>> survey_scenario.summarize_variable(variable = "housing_occupancy_status", force_compute = True)
-            >>> survey_scenario.summarize_variable(variable = "rent", force_compute = True)
-            >>> survey_scenario.tax_benefit_systems["baseline"].neutralize_variable('age')
-            >>> survey_scenario.summarize_variable(variable = "age")
-        """
-        for _simulation_name, simulation in self.simulations.items():
-            simulation.summarize_variable(variable, weighted, force_compute)
-
-    def _apply_modification(self, simulation, period):
-        period = periods.period(period)
-        varying_variable = self.varying_variable
-        definition_period = simulation.tax_benefit_system.variables[varying_variable].definition_period
-
-        def set_variable(varying_variable, varying_variable_value, period_):
-            delta = self.variation_factor * varying_variable_value
-            new_variable_value = varying_variable_value + delta
-            simulation.delete_arrays(varying_variable, period_)
-            simulation.set_input(varying_variable, period_, new_variable_value)
-
-        if period.unit == definition_period:
-            varying_variable_value = simulation.calculate(varying_variable, period=period)
-            set_variable(varying_variable, varying_variable_value, period)
-
-        elif (definition_period == MONTH) and (period.unit == YEAR and period.size_in_months == 12):
-            varying_variable_value = simulation.calculate_add(varying_variable, period=period)
-            for period_ in [periods.Period(("month", period.start.offset(month, "month"), 1)) for month in range(12)]:
-                set_variable(varying_variable, varying_variable_value / 12, period_)
-        else:
-            raise ValueError()
+__all__ = ["AbstractSurveyScenario"]
diff --git a/openfisca_survey_manager/scenarios/reform_scenario.py b/openfisca_survey_manager/scenarios/reform_scenario.py
index 5c4673b2..029abb84 100644
--- a/openfisca_survey_manager/scenarios/reform_scenario.py
+++ b/openfisca_survey_manager/scenarios/reform_scenario.py
@@ -1,293 +1,17 @@
-"""Abstract survey scenario definition."""
+"""Re-export for backward compatibility.
-
-import logging
-from typing import Optional, Union
+Prefer: from openfisca_survey_manager.policy.scenarios.reform_scenario import ...
+"""
-
-import numpy as np
-import pandas as pd
-from openfisca_core.types import Array, Period
+import warnings
-
-from openfisca_survey_manager.scenarios.abstract_scenario import AbstractSurveyScenario
-from openfisca_survey_manager.simulations import Simulation
+from openfisca_survey_manager.policy.scenarios.reform_scenario import ReformScenario
-
-log = logging.getLogger(__name__)
+warnings.warn(
+    "openfisca_survey_manager.scenarios.reform_scenario is deprecated and will be removed in a future version. "
+    "Prefer: from openfisca_survey_manager.policy.scenarios.reform_scenario import ...",
+    DeprecationWarning,
+    stacklevel=2,
+)
-
-
-class ReformScenario(AbstractSurveyScenario):
-    """Reform survey scenario."""
-
-    def _get_simulation(self, use_baseline: bool = False):
-        """
-        Get the relevant simulation.
-
-        Args:
-            use_baseline (bool, optional): Whether to get baseline or reform simulation. Defaults to False.
- """ - - if len(self.simulations) == 1: - return next(iter(self.simulations.values())) - - simulation_name = "baseline" if use_baseline else "reform" - simulation = self.simulations[simulation_name] - assert simulation is not None, f"{simulation_name} does not exist" - return simulation - - def build_input_data(self, **kwargs): - """Build input data.""" - raise NotImplementedError - - def calculate_series(self, variable, period=None, use_baseline=False): - """Compute variable values for period and baseline or reform tax benefit and system. - - Args: - variable(str, optional): Variable to compute - period(Period, optional): Period, defaults to None - use_baseline(bool, optional): Use baseline tax and benefit system, defaults to False - - Returns: - pandas.Series: Variable values - - """ - return pd.Series( - data=self.calculate_variable(variable, period, use_baseline), - name=variable, - ) - - def calculate_variable(self, variable, period=None, use_baseline=False): - """Compute variable values for period and baseline or reform tax benefit and system. - - Args: - variable(str, optional): Variable to compute - period(Period, optional): Period, defaults to None - use_baseline(bool, optional): Use baseline tax and benefit system, defaults to False - - Returns: - numpy.ndarray: Variable values - - """ - simulation = self._get_simulation(use_baseline) - return simulation.adaptative_calculate_variable(variable, period=period) - - def compute_aggregate( - self, - variable: Optional[str] = None, - aggfunc: str = "sum", - filter_by: Optional[str] = None, - period: Optional[Union[int, str, Period]] = None, - use_baseline: bool = False, - difference: bool = False, - missing_variable_default_value=np.nan, - weighted: bool = True, - alternative_weights: Optional[Union[str, int, float, Array]] = None, - ): - """Compute variable aggregate. - - Args: - variable (str, optional): Variable to aggregate. Defaults to None. - aggfunc (str, optional): Aggregation function. Defaults to 'sum'. - filter_by (str, optional): Filter variable or expression to use. Defaults to None. - period (Optional[Union[int, str, Period]], optional): Period. Defaults to None. - use_baseline: Use baseline simulation. Defaults to False. - missing_variable_default_value (optional): Value to use for missing values. Defaults to np.nan. - weighted (bool, optional): Whether to weight the variable or not. Defaults to True. - alternative_weights (Optional[Union[str, int, float, Array]], optional): Alternative weigh to use. - Defaults to None. - filtering_variable_by_entity (Dict, optional): Filtering variable by entity. Defaults to None. 
-        """
-        assert aggfunc in ["count", "mean", "sum", "count_non_zero"]
-        assert period is not None
-        assert not (difference and use_baseline), "Can't have difference and use_baseline both set to True"
-
-        if difference:
-            return self.compute_aggregate(
-                variable=variable,
-                aggfunc=aggfunc,
-                filter_by=filter_by,
-                period=period,
-                use_baseline=False,
-                missing_variable_default_value=missing_variable_default_value,
-                weighted=weighted,
-                alternative_weights=alternative_weights,
-            ) - self.compute_aggregate(
-                variable=variable,
-                aggfunc=aggfunc,
-                filter_by=filter_by,
-                period=period,
-                use_baseline=True,
-                missing_variable_default_value=missing_variable_default_value,
-                weighted=weighted,
-                alternative_weights=alternative_weights,
-            )
-
-        assert variable is not None
-        simulation = self._get_simulation(use_baseline)
-        return simulation.compute_aggregate(
-            variable=variable,
-            aggfunc=aggfunc,
-            filter_by=filter_by,
-            period=period,
-            missing_variable_default_value=missing_variable_default_value,
-            weighted=weighted,
-            alternative_weights=alternative_weights,
-            filtering_variable_by_entity=self.filtering_variable_by_entity,
-        )
-
-    def compute_quantiles(
-        self,
-        variable: Optional[str] = None,
-        nquantiles=None,
-        period=None,
-        use_baseline=False,
-        filter_by=None,
-        weighted=True,
-        alternative_weights=None,
-    ):
-        assert variable is not None
-        assert nquantiles is not None
-        simulation = self._get_simulation(use_baseline)
-
-        return simulation.compute_quantiles(
-            variable=variable,
-            period=period,
-            nquantiles=nquantiles,
-            filter_by=filter_by,
-            weighted=weighted,
-            alternative_weights=alternative_weights,
-        )
-
-    def compute_marginal_tax_rate(
-        self,
-        target_variable: str,
-        period: Optional[Union[int, str, Period]],
-        use_baseline: bool = False,
-        value_for_zero_varying_variable: float = 0.0,
-    ) -> Array:
-        """
-        Compute the marginal tax rate (MTR) of a target variable with respect to a varying variable.
-
-        Args:
-            target_variable (str): the variable whose marginal tax rate is computed
-            period (Optional[Union[int, str, Period]], optional): Period. Defaults to None.
-            use_baseline: Use baseline simulation. Defaults to False.
-            value_for_zero_varying_variable (float, optional): value of MTR when the varying variable is zero.
-                Defaults to 0.
-
-        Returns:
-            numpy.array: Vector of marginal rates
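-
-        Example (illustrative sketch only; assumes the scenario was initialised with
-        ``use_marginal_tax_rate=True`` and a hypothetical ``disposable_income`` target):
-            >>> survey_scenario.compute_marginal_tax_rate(
-            ...     target_variable="disposable_income", period=2017,
-            ... )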
-        """
-        if use_baseline:
-            return super().compute_marginal_tax_rate(
-                target_variable=target_variable,
-                period=period,
-                simulation="baseline",
-                value_for_zero_varying_variable=value_for_zero_varying_variable,
-            )
-        else:
-            return super().compute_marginal_tax_rate(
-                target_variable=target_variable,
-                period=period,
-                simulation="reform",
-                value_for_zero_varying_variable=value_for_zero_varying_variable,
-            )
-
-    def compute_pivot_table(
-        self,
-        aggfunc="mean",
-        columns=None,
-        difference=False,
-        filter_by=None,
-        index=None,
-        period=None,
-        use_baseline=False,
-        use_baseline_for_columns=None,
-        values=None,
-        missing_variable_default_value=np.nan,
-        concat_axis=None,
-        weighted=True,
-        alternative_weights=None,
-    ):
-        filtering_variable_by_entity = self.filtering_variable_by_entity
-
-        return Simulation.compute_pivot_table(
-            aggfunc=aggfunc,
-            columns=columns,
-            baseline_simulation=self._get_simulation(use_baseline=True),
-            filter_by=filter_by,
-            index=index,
-            period=period,
-            simulation=self._get_simulation(use_baseline),
-            difference=difference,
-            use_baseline_for_columns=use_baseline_for_columns,
-            values=values,
-            missing_variable_default_value=missing_variable_default_value,
-            concat_axis=concat_axis,
-            weighted=weighted,
-            alternative_weights=alternative_weights,
-            filtering_variable_by_entity=filtering_variable_by_entity,
-        )
-
-    def compute_winners_losers(
-        self,
-        variable=None,
-        filter_by=None,
-        period=None,
-        absolute_minimal_detected_variation=0,
-        relative_minimal_detected_variation=0.01,
-        observations_threshold=None,
-        weighted=True,
-        alternative_weights=None,
-    ):
-        return super().compute_winners_losers(
-            simulation="reform",
-            baseline_simulation="baseline",
-            variable=variable,
-            filter_by=filter_by,
-            period=period,
-            absolute_minimal_detected_variation=getattr(
-                self, "absolute_minimal_detected_variation", absolute_minimal_detected_variation
-            ),
-            relative_minimal_detected_variation=getattr(
-                self, "relative_minimal_detected_variation", relative_minimal_detected_variation
-            ),
-            observations_threshold=getattr(self, "observations_threshold", observations_threshold),
-            weighted=weighted,
-            alternative_weights=alternative_weights,
-        )
-
-    def create_data_frame_by_entity(
-        self,
-        variables=None,
-        expressions=None,
-        filter_by=None,
-        index=False,
-        period=None,
-        use_baseline=False,
-        merge=False,
-    ):
-        """Create dataframe(s) of computed variables for every entity (optionally merged into a unique dataframe).
-
-        Args:
-            variables(list, optional): Variables to compute, defaults to None
-            expressions(str, optional): Expressions to compute, defaults to None
-            filter_by(str, optional): Boolean variable or expression, defaults to None
-            index(bool, optional): Index by entity id, defaults to False
-            period(Period, optional): Period, defaults to None
-            use_baseline(bool, optional): Use baseline tax and benefit system, defaults to False
-            merge(bool, optional): Merge all the entities in one data frame, defaults to False
-
-        Returns:
-            dict or pandas.DataFrame: Dictionary of dataframes by entity, or a single dataframe with all the computed variables
-
-        """
-        simulation = self._get_simulation(use_baseline)
-        return simulation.create_data_frame_by_entity(
-            variables=variables,
-            expressions=expressions,
-            filter_by=filter_by,
-            index=index,
-            period=period,
-            merge=merge,
-        )
+__all__ = ["ReformScenario"]
diff --git a/openfisca_survey_manager/statshelpers.py b/openfisca_survey_manager/statshelpers.py
deleted file mode 100644
index c41f2d0b..00000000
--- a/openfisca_survey_manager/statshelpers.py
+++ /dev/null
@@ -1,37 +0,0 @@
-"""Re-export for backward compatibility.
-
-Prefer: from openfisca_survey_manager.policy.statshelpers import ...
-"""
-
-import warnings
-
-from openfisca_survey_manager.policy.statshelpers import (
-    bottom_share,
-    gini,
-    kakwani,
-    lorenz,
-    mark_weighted_percentiles,
-    pseudo_lorenz,
-    top_share,
-    weighted_quantiles,
-    weightedcalcs_quantiles,
-)
-
-warnings.warn(
-    "openfisca_survey_manager.statshelpers is deprecated and will be removed in a future version. "
-    "Prefer: from openfisca_survey_manager.policy.statshelpers import ...",
-    DeprecationWarning,
-    stacklevel=2,
-)
-
-__all__ = [
-    "bottom_share",
-    "gini",
-    "kakwani",
-    "lorenz",
-    "mark_weighted_percentiles",
-    "pseudo_lorenz",
-    "top_share",
-    "weighted_quantiles",
-    "weightedcalcs_quantiles",
-]
diff --git a/openfisca_survey_manager/tests/test_calibration.py b/openfisca_survey_manager/tests/test_calibration.py
index 5d8074fd..d9f827a8 100644
--- a/openfisca_survey_manager/tests/test_calibration.py
+++ b/openfisca_survey_manager/tests/test_calibration.py
@@ -2,7 +2,7 @@
 from openfisca_core.tools import assert_near
 from openfisca_survey_manager.policy.calibration import Calibration
-from openfisca_survey_manager.scenarios.abstract_scenario import AbstractSurveyScenario
+from openfisca_survey_manager.policy.scenarios.abstract_scenario import AbstractSurveyScenario
 from openfisca_survey_manager.tests import tax_benefit_system
 from openfisca_survey_manager.tests.test_scenario import (
     create_randomly_initialized_survey_scenario,
diff --git a/openfisca_survey_manager/tests/test_coverage_boost.py b/openfisca_survey_manager/tests/test_coverage_boost.py
index 2838ab5f..845843e1 100644
--- a/openfisca_survey_manager/tests/test_coverage_boost.py
+++ b/openfisca_survey_manager/tests/test_coverage_boost.py
@@ -8,8 +8,8 @@
 from openfisca_survey_manager.core.dataset import SurveyCollection
 from openfisca_survey_manager.core.survey import Survey
 from openfisca_survey_manager.policy import AbstractAggregates
+from openfisca_survey_manager.policy.scenarios.abstract_scenario import AbstractSurveyScenario
 from openfisca_survey_manager.policy.variables import quantile
-from openfisca_survey_manager.scenarios.abstract_scenario import AbstractSurveyScenario
 from openfisca_survey_manager.scripts.build_collection import (
     check_template_config_files,
     create_data_file_by_format,
diff --git a/openfisca_survey_manager/tests/test_enum.py b/openfisca_survey_manager/tests/test_enum.py
index 8189acb8..8d6277c4 100644
--- a/openfisca_survey_manager/tests/test_enum.py
+++ b/openfisca_survey_manager/tests/test_enum.py
@@ -1,7 +1,7 @@
 import pandas as pd
 from openfisca_country_template.variables.housing import HousingOccupancyStatus
-from openfisca_survey_manager.scenarios.abstract_scenario import AbstractSurveyScenario
+from openfisca_survey_manager.policy.scenarios.abstract_scenario import AbstractSurveyScenario
 from openfisca_survey_manager.tests import tax_benefit_system
diff --git a/openfisca_survey_manager/tests/test_parquet.py b/openfisca_survey_manager/tests/test_parquet.py
index 6612874f..d1f4b51b 100644
--- a/openfisca_survey_manager/tests/test_parquet.py
+++ b/openfisca_survey_manager/tests/test_parquet.py
@@ -4,13 +4,13 @@
 import pytest
 from openfisca_core import periods
-from openfisca_survey_manager.scenarios.abstract_scenario import AbstractSurveyScenario
+from openfisca_survey_manager.core.dataset import SurveyCollection
+from openfisca_survey_manager.core.survey import NoMoreDataError
+from openfisca_survey_manager.policy.scenarios.abstract_scenario import AbstractSurveyScenario
 from openfisca_survey_manager.scripts.build_collection import (
     add_survey_to_collection,
     build_survey_collection,
 )
-from openfisca_survey_manager.survey_collections import SurveyCollection
-from openfisca_survey_manager.surveys import NoMoreDataError
 from openfisca_survey_manager.tests import tax_benefit_system
 logger = logging.getLogger(__name__)
diff --git a/openfisca_survey_manager/tests/test_quantile.py b/openfisca_survey_manager/tests/test_quantile.py
index 0038b7a7..a2ef4174 100644
--- a/openfisca_survey_manager/tests/test_quantile.py
+++ b/openfisca_survey_manager/tests/test_quantile.py
@@ -5,9 +5,9 @@
 from openfisca_core.taxbenefitsystems import TaxBenefitSystem
 from openfisca_survey_manager.paths import default_config_files_directory
+from openfisca_survey_manager.policy.scenarios.abstract_scenario import AbstractSurveyScenario
 from openfisca_survey_manager.policy.statshelpers import mark_weighted_percentiles
 from openfisca_survey_manager.policy.variables import quantile
-from openfisca_survey_manager.scenarios.abstract_scenario import AbstractSurveyScenario
 from openfisca_survey_manager.tests.test_scenario import setup_test_config
 Individu = build_entity(
diff --git a/openfisca_survey_manager/tests/test_scenario.py b/openfisca_survey_manager/tests/test_scenario.py
index 9657639d..941e5626 100644
--- a/openfisca_survey_manager/tests/test_scenario.py
+++ b/openfisca_survey_manager/tests/test_scenario.py
@@ -10,8 +10,8 @@
 from openfisca_survey_manager.configuration.paths import (
     default_config_files_directory,
 )
-from openfisca_survey_manager.scenarios.abstract_scenario import AbstractSurveyScenario
-from openfisca_survey_manager.scenarios.reform_scenario import ReformScenario
+from openfisca_survey_manager.policy.scenarios.abstract_scenario import AbstractSurveyScenario
+from openfisca_survey_manager.policy.scenarios.reform_scenario import ReformScenario
 from openfisca_survey_manager.tests import tax_benefit_system
 from openfisca_survey_manager.tests.input_dataframe_generator import (
     make_input_dataframe_by_entity,
diff --git a/openfisca_survey_manager/variables.py b/openfisca_survey_manager/variables.py
deleted file mode 100644
index afc512ba..00000000
--- a/openfisca_survey_manager/variables.py
+++ /dev/null
@@ -1,25 +0,0 @@
-"""Re-export for backward compatibility.
-
-Prefer: from openfisca_survey_manager.policy.variables import ...
-"""
-
-import warnings
-
-from openfisca_survey_manager.policy.variables import (
-    create_quantile,
-    old_quantile,
-    quantile,
-)
-
-warnings.warn(
-    "openfisca_survey_manager.variables is deprecated and will be removed in a future version. "
-    "Prefer: from openfisca_survey_manager.policy.variables import ...",
-    DeprecationWarning,
-    stacklevel=2,
-)
-
-__all__ = [
-    "create_quantile",
-    "old_quantile",
-    "quantile",
-]

From e429ff43b89a978173383ac6bfc41efec2028b36 Mon Sep 17 00:00:00 2001
From: Mahdi Ben Jelloul
Date: Tue, 3 Mar 2026 09:28:06 +0100
Subject: [PATCH 09/18] Add typing to policy folder and related modules

- policy: add py.typed marker; type legislation_asof, variables, coicop,
  matching, statshelpers, calmar, calibration, simulation_builder, aggregates,
  simulations, scenarios (abstract + reform)
- configuration/models: type Config.__init__ and save
- processing/__init__: type __getattr__ return

Made-with: Cursor
---
 .../configuration/models.py                  |  12 +-
 openfisca_survey_manager/policy/aggregates.py |  97 +++++-----
 openfisca_survey_manager/policy/calmar.py     |  12 +-
 openfisca_survey_manager/policy/coicop.py     |  18 +-
 .../policy/legislation_asof.py                |  42 +++--
 openfisca_survey_manager/policy/matching.py   |  37 ++--
 openfisca_survey_manager/policy/py.typed      |   0
 .../policy/scenarios/abstract_scenario.py     | 165 ++++++++++++------
 .../policy/scenarios/reform_scenario.py       |  82 +++++----
 .../policy/simulation_builder.py              |  42 ++++-
 .../policy/simulations.py                     |  65 +++----
 .../policy/statshelpers.py                    |  61 ++++++-
 openfisca_survey_manager/policy/variables.py  |  26 ++-
 .../processing/__init__.py                    |   2 +-
 14 files changed, 438 insertions(+), 223 deletions(-)
 create mode 100644 openfisca_survey_manager/policy/py.typed

diff --git a/openfisca_survey_manager/configuration/models.py b/openfisca_survey_manager/configuration/models.py
index 9ab3841a..7a52d399 100644
--- a/openfisca_survey_manager/configuration/models.py
+++ b/openfisca_survey_manager/configuration/models.py
@@ -1,8 +1,11 @@
 """Configuration model (Config class from config.ini)."""
 
+from __future__ import annotations
+
 import configparser
 import logging
 from pathlib import Path
+from typing import Optional, Union
 
 log = logging.getLogger(__name__)
@@ -10,9 +13,12 @@
 class Config(configparser.ConfigParser):
     """Parser for config.ini; used by SurveyCollection and build scripts."""
 
-    config_ini = None
+    config_ini: Optional[Path] = None
 
-    def __init__(self, config_files_directory=None):
+    def __init__(
+        self,
+        config_files_directory: Optional[Union[Path, str]] = None,
+    ) -> None:
         configparser.ConfigParser.__init__(self)
         if config_files_directory is not None:
             config_ini = Path(config_files_directory) / "config.ini"
@@ -21,7 +27,7 @@ def __init__(self, config_files_directory=None):
         self.read([config_ini])
         log.debug("Loaded config from %s", config_ini)
 
-    def save(self):
+    def save(self) -> None:
         assert self.config_ini, "configuration file path is not defined"
         assert self.config_ini.exists()
         config_file = self.config_ini.open("w")
diff --git a/openfisca_survey_manager/policy/aggregates.py b/openfisca_survey_manager/policy/aggregates.py
index 7d8059c4..7ed440fc 100644
--- a/openfisca_survey_manager/policy/aggregates.py
+++ b/openfisca_survey_manager/policy/aggregates.py
@@ -1,8 +1,12 @@
+"""Aggregates computation for survey scenarios."""
+
+from __future__ import annotations
+
 import collections
 import logging
 from datetime import datetime
 from pathlib import Path
-from typing import Optional
+from typing import Any, Optional, Union
 
 import numpy as np
 import pandas as pd
@@ -29,11 +33,11 @@ class AbstractAggregates:
 
     def __init__(
         self,
-        survey_scenario=None,
-        absolute_minimal_detected_variation=0,
-        relative_minimal_detected_variation=0,
-        observations_threshold=0,
-    ):
+        survey_scenario: Any = None,
+        absolute_minimal_detected_variation: float = 0,
+        relative_minimal_detected_variation: float = 0,
+        observations_threshold: float = 0,
+    ) -> None:
         assert survey_scenario is not None
 
         self.period = survey_scenario.period
@@ -270,7 +274,7 @@ def compute_variable_aggregates(
 
         return variable_data_frame
 
-    def create_description(self):
+    def create_description(self) -> pd.DataFrame:
         """Create a description dataframe."""
         now = datetime.now()
         return pd.DataFrame(
@@ -284,14 +288,14 @@ def create_description(self):
 
     def to_csv(
         self,
-        path=None,
-        absolute=True,
-        amount=True,
-        beneficiaries=True,
-        default="actual",
-        relative=True,
-        target="reform",
-    ):
+        path: Optional[Union[Path, str]] = None,
+        absolute: bool = True,
+        amount: bool = True,
+        beneficiaries: bool = True,
+        default: str = "actual",
+        relative: bool = True,
+        target: str = "reform",
+    ) -> None:
         """Saves the table to csv."""
         assert path is not None
 
@@ -313,14 +317,14 @@ def to_csv(
 
     def to_excel(
         self,
-        path=None,
-        absolute=True,
-        amount=True,
-        beneficiaries=True,
-        default="actual",
-        relative=True,
-        target="reform",
-    ):
+        path: Optional[Union[Path, str]] = None,
+        absolute: bool = True,
+        amount: bool = True,
+        beneficiaries: bool = True,
+        default: str = "actual",
+        relative: bool = True,
+        target: str = "reform",
+    ) -> None:
         """Save the table to excel."""
         assert path is not None
 
@@ -346,14 +350,14 @@ def to_excel(
 
     def to_html(
         self,
-        path=None,
-        absolute=True,
-        amount=True,
-        beneficiaries=True,
-        default="actual",
-        relative=True,
-        target="reform",
-    ):
+        path: Optional[Union[Path, str]] = None,
+        absolute: bool = True,
+        amount: bool = True,
+        beneficiaries: bool = True,
+        default: str = "actual",
+        relative: bool = True,
+        target: str = "reform",
+    ) -> str:
         """Get or saves the table to html format."""
         df = self.get_data_frame(
             absolute=absolute,
@@ -377,14 +381,14 @@ def to_html(
 
     def to_markdown(
         self,
-        path=None,
-        absolute=True,
-        amount=True,
-        beneficiaries=True,
-        default="actual",
-        relative=True,
-        target="reform",
-    ):
+        path: Optional[Union[Path, str]] = None,
+        absolute: bool = True,
+        amount: bool = True,
+        beneficiaries: bool = True,
+        default: str = "actual",
+        relative: bool = True,
+        target: str = "reform",
+    ) -> str:
         """Get or saves the table to markdown format."""
         df = self.get_data_frame(
             absolute=absolute,
@@ -425,7 +429,7 @@ def get_data_frame(
         relative: bool = True,
         target: str = "reform",
         ignore_labels: bool = False,
-    ):
+    ) -> pd.DataFrame:
         assert target is None or target in ["reform", "baseline"]
 
         columns = self.labels.keys()
@@ -509,10 +513,14 @@ def get_data_frame(
 
         return df
 
-    def load_actual_data(self, period=None):
+    def load_actual_data(self, period: Any = None) -> None:
         pass
 
-    def compute_winners_losers(self, variable: str, filter_by: Optional[str] = None):
+    def compute_winners_losers(
+        self,
+        variable: str,
+        filter_by: Optional[str] = None,
+    ) -> pd.DataFrame:
         if "reform" not in self.simulations or "baseline" not in self.simulations:
             log.warning("Cannot compute winners and losers without a reform and a baseline simulation.")
             return pd.DataFrame()
@@ -546,7 +554,10 @@ def compute_winners_losers(
         )
         return winners_losers_df
 
-    def compute_all_winners_losers(self, filter_by: Optional[str] = None):
+    def compute_all_winners_losers(
+        self,
+        filter_by: Optional[str] = None,
+    ) -> pd.DataFrame:
         all_winners_losers = pd.DataFrame()
         for variable in self.aggregate_variables:
             winners_losers = self.compute_winners_losers(variable, filter_by=filter_by)
diff --git a/openfisca_survey_manager/policy/calmar.py b/openfisca_survey_manager/policy/calmar.py
index 4133727d..c54d0497 100644
--- a/openfisca_survey_manager/policy/calmar.py
+++ b/openfisca_survey_manager/policy/calmar.py
@@ -190,20 +190,20 @@ def calmar(
         assert lo is not None, "When method == 'logit', a value < 1 for lo is mandatory"
         assert lo < 1, "lo should be < 1"
 
-        def f(x):
+        def f(x: np.ndarray) -> np.ndarray:
             return logit(x, lo, up)
 
-        def f_prime(x):
+        def f_prime(x: np.ndarray) -> np.ndarray:
             return logit_prime(x, lo, up)
 
     elif method == "hyperbolic sinus":
         assert alpha is not None, "When method == 'hyperbolic sinus', a value > 0 for alpha is mandatory"
         assert alpha > 0, "alpha should be > 0"
 
-        def f(x):
+        def f(x: np.ndarray) -> np.ndarray:
             return hyperbolic_sinus(x, alpha)
 
-        def f_prime(x):
+        def f_prime(x: np.ndarray) -> np.ndarray:
             return hyperbolic_sinus_prime(x, alpha)
 
     margins = margins.copy()
@@ -300,10 +300,10 @@ def f_prime(x):
         margins_dict[var] = val
 
     # Solve the first-order conditions
-    def constraint(lambda_):
+    def constraint(lambda_: np.ndarray) -> np.ndarray:
         return dot(d * f(dot(x, lambda_)), x) - xmargins
 
-    def constraint_prime(lambda_):
+    def constraint_prime(lambda_: np.ndarray) -> np.ndarray:
         return dot(d * (x.T * f_prime(dot(x, lambda_))), x)
 
     # the Jacobian above is constraintprime = @(lambda) x*(d.*Fprime(x'*lambda)*x');
diff --git a/openfisca_survey_manager/policy/coicop.py b/openfisca_survey_manager/policy/coicop.py
index 22367356..f6785ed5 100644
--- a/openfisca_survey_manager/policy/coicop.py
+++ b/openfisca_survey_manager/policy/coicop.py
@@ -1,5 +1,10 @@
+"""COICOP nomenclature helpers."""
+
+from __future__ import annotations
+
 import logging
 from pathlib import Path
+from typing import Literal
 
 import pandas as pd
@@ -11,11 +16,18 @@
 legislation_directory = Path(openfisca_survey_manager_location) / "openfisca_survey_manager" / "assets"
 
-sub_levels = ["divisions", "groupes", "classes", "sous_classes", "postes"]
+sub_levels: tuple[str, ...] = ("divisions", "groupes", "classes", "sous_classes", "postes")
 divisions = [f"0{i}" for i in range(1, 10)] + ["11", "12"]
 
+CoicopLevel = Literal["divisions", "groupes", "classes", "sous_classes", "postes"]
+
 
-def build_coicop_level_nomenclature(level, year=2016, keep_code=False, to_csv=False):
+def build_coicop_level_nomenclature(
+    level: CoicopLevel,
+    year: int = 2016,
+    keep_code: bool = False,
+    to_csv: bool = False,
+) -> pd.DataFrame:
     assert level in sub_levels
     log.debug(f"Reading nomenclature coicop {year} source data for level {level}")
     try:
@@ -69,7 +81,7 @@ def build_coicop_level_nomenclature(level, year=2016, keep_code=False, to_csv=Fa
     return data_frame
 
-def build_raw_coicop_nomenclature(year=2016):
+def build_raw_coicop_nomenclature(year: int = 2016) -> pd.DataFrame:
     """Builds raw COICOP nomenclature from ecoicop levels"""
     coicop_nomenclature = None
diff --git a/openfisca_survey_manager/policy/legislation_asof.py b/openfisca_survey_manager/policy/legislation_asof.py
index b5f48c9d..362fa6e4 100644
--- a/openfisca_survey_manager/policy/legislation_asof.py
+++ b/openfisca_survey_manager/policy/legislation_asof.py
@@ -1,6 +1,9 @@
 """Shared helpers (no survey collection dependency) to avoid circular imports."""
 
+from __future__ import annotations
+
 import logging
+from typing import Any, Optional
 
 from openfisca_core import periods
 from openfisca_core.parameters import ParameterNode, Scale
@@ -8,19 +11,19 @@
 log = logging.getLogger(__name__)
 
-def do_nothing(*args, **kwargs):
+def do_nothing(*args: Any, **kwargs: Any) -> None:
     return None
 
 def inflate_parameters(
-    parameters,
-    inflator,
-    base_year,
-    last_year=None,
-    ignore_missing_units=False,
-    start_instant=None,
-    round_ndigits=2,
-):
+    parameters: ParameterNode | Scale | Any,
+    inflator: float,
+    base_year: int,
+    last_year: Optional[int] = None,
+    ignore_missing_units: bool = False,
+    start_instant: Optional[str] = None,
+    round_ndigits: int = 2,
+) -> None:
     """
     Inflate a Parameter node or a Parameter leaf for the years between base_year and last_year.
@@ -94,7 +97,14 @@
     )
 
-def inflate_parameter_leaf(sub_parameter, base_year, inflator, unit_type="unit", start_instant=None, round_ndigits=2):
+def inflate_parameter_leaf(
+    sub_parameter: Any,
+    base_year: int,
+    inflator: float,
+    unit_type: str = "unit",
+    start_instant: Optional[str] = None,
+    round_ndigits: int = 2,
+) -> None:
     """
     Inflate a Parameter leaf according to unit type for the year after base_year.
@@ -167,13 +177,13 @@ def inflate_parameter_leaf(sub_parameter, base_year, inflator, unit_type="unit",
     )
 
-def asof(tax_benefit_system, instant):
+def asof(tax_benefit_system: Any, instant: str | periods.Instant) -> None:
     parameters = tax_benefit_system.parameters
     parameters_asof(parameters, instant)
     variables_asof(tax_benefit_system, instant)
 
-def leaf_asof(sub_parameter, instant):
+def leaf_asof(sub_parameter: Any, instant: periods.Instant) -> None:
     kept_instants_str = [
         parameter_at_instant.instant_str
         for parameter_at_instant in sub_parameter.values_list
@@ -187,7 +197,7 @@ def leaf_asof(sub_parameter, instant):
     sub_parameter.update(start=last_admissible_instant_str, value=sub_parameter(last_admissible_instant_str))
 
-def parameters_asof(parameters, instant):
+def parameters_asof(parameters: ParameterNode | Any, instant: str | periods.Instant) -> None:
     if isinstance(instant, str):
         instant = periods.instant(instant)
     assert isinstance(instant, periods.Instant)
@@ -210,7 +220,11 @@ def parameters_asof(parameters, instant):
         leaf_asof(sub_parameter, instant)
 
-def variables_asof(tax_benefit_system, instant, variables_list=None):
+def variables_asof(
+    tax_benefit_system: Any,
+    instant: str | periods.Instant,
+    variables_list: Optional[list[str]] = None,
+) -> None:
     if isinstance(instant, str):
         instant = periods.instant(instant)
     assert isinstance(instant, periods.Instant)
diff --git a/openfisca_survey_manager/policy/matching.py b/openfisca_survey_manager/policy/matching.py
index 14ec18ab..06f8e996 100644
--- a/openfisca_survey_manager/policy/matching.py
+++ b/openfisca_survey_manager/policy/matching.py
@@ -1,7 +1,10 @@
 """Nearest-neighbor donor (NND) hot deck matching — pure Python or R (StatMatch)."""
 
+from __future__ import annotations
+
 import logging
 from pathlib import Path
+from typing import List, Optional, Union
 
 import numpy as np
 import pandas as pd
@@ -13,7 +16,10 @@
 config_files_directory = Path(openfisca_survey_manager_location)
 
-def _normalize_list(x, name="variables"):
+def _normalize_list(
+    x: Optional[Union[str, List[str]]],
+    name: str = "variables",
+) -> Optional[list[str]]:
     """Return a list of variable names from str or list."""
     if x is None:
         return None
@@ -168,7 +174,7 @@ def nnd_hotdeck(
     dist_fun: str = "Manhattan",
     use_r: bool = False,
     random_state: int | None = None,
-):
+) -> tuple[pd.DataFrame, pd.DataFrame]:
     """
     Nearest-neighbor donor (NND) hot deck: match each receiver row to a donor,
     then fuse z_variables from donor into receiver.
@@ -226,7 +232,13 @@ def nnd_hotdeck(
     return fused_0, fused_1
 
-def _nnd_hotdeck_rpy2(receiver, donor, matching_variables, z_variables, donor_classes=None):
+def _nnd_hotdeck_rpy2(
+    receiver: pd.DataFrame,
+    donor: pd.DataFrame,
+    matching_variables: list[str],
+    z_variables: list[str],
+    donor_classes: str | list[str] | None = None,
+) -> tuple[pd.DataFrame, pd.DataFrame]:
     """R (StatMatch) implementation via rpy2; same return as nnd_hotdeck."""
     from rpy2.robjects import pandas2ri
     from rpy2.robjects.packages import importr
@@ -269,12 +281,12 @@ def _nnd_hotdeck_rpy2(
 
 def nnd_hotdeck_using_rpy2(
-    receiver=None,
-    donor=None,
-    matching_variables=None,
-    z_variables=None,
-    donor_classes=None,
-):
+    receiver: pd.DataFrame | None = None,
+    donor: pd.DataFrame | None = None,
+    matching_variables: str | list[str] | None = None,
+    z_variables: str | list[str] | None = None,
+    donor_classes: str | list[str] | None = None,
+) -> tuple[pd.DataFrame, pd.DataFrame]:
     """
     NND hot deck via R (StatMatch). Prefer `nnd_hotdeck(..., use_r=True)`.
     """
@@ -288,7 +300,12 @@ def nnd_hotdeck_using_rpy2(
     )
 
-def nnd_hotdeck_using_feather(receiver=None, donor=None, matching_variables=None, z_variables=None):
+def nnd_hotdeck_using_feather(
+    receiver: pd.DataFrame | None = None,
+    donor: pd.DataFrame | None = None,
+    matching_variables: str | list[str] | None = None,
+    z_variables: str | list[str] | None = None,
+) -> None:
     """
     Not working
     """
diff --git a/openfisca_survey_manager/policy/py.typed b/openfisca_survey_manager/policy/py.typed
new file mode 100644
index 00000000..e69de29b
diff --git a/openfisca_survey_manager/policy/scenarios/abstract_scenario.py b/openfisca_survey_manager/policy/scenarios/abstract_scenario.py
index 62e1a5ba..6c5ef32e 100644
--- a/openfisca_survey_manager/policy/scenarios/abstract_scenario.py
+++ b/openfisca_survey_manager/policy/scenarios/abstract_scenario.py
@@ -2,7 +2,7 @@
 import logging
 from pathlib import Path
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 import numpy as np
 import pandas as pd
@@ -44,11 +44,16 @@ class AbstractSurveyScenario:
     varying_variable = None
     weight_variable_by_entity = None
 
-    def build_input_data(self, **kwargs):
+    def build_input_data(self, **kwargs: Any) -> None:
         """Build input data."""
         raise NotImplementedError
 
-    def calculate_series(self, variable, period=None, simulation=None):
+    def calculate_series(
+        self,
+        variable: str,
+        period: Optional[Union[int, str, Period]] = None,
+        simulation: Optional[str] = None,
+    ) -> pd.Series:
         """Compute variable values for period for a given simulation.
 
         Args:
@@ -65,7 +70,12 @@ def calculate_series(
             name=variable,
         )
 
-    def calculate_variable(self, variable, period=None, simulation=None):
+    def calculate_variable(
+        self,
+        variable: str,
+        period: Optional[Union[int, str, Period]] = None,
+        simulation: Optional[str] = None,
+    ) -> Array:
         """Compute variable values for period for a given simulation.
         Args:
@@ -303,7 +313,7 @@ def compute_marginal_tax_rate(
         else:
             target_variable_entity_key = variables[target_variable].entity.key
 
-            def cast_to_target_entity(simulation: Simulation):
+            def cast_to_target_entity(simulation: Simulation) -> np.ndarray:
                 population = simulation.populations[target_variable_entity_key]
                 df = (
                     pd.DataFrame(
@@ -337,21 +347,21 @@ def cast_to_target_entity(simulation: Simulation):
 
     def compute_pivot_table(
         self,
-        aggfunc="mean",
-        columns=None,
-        baseline_simulation=None,
-        filter_by=None,
-        index=None,
-        period=None,
-        simulation=None,
-        difference=False,
-        use_baseline_for_columns=None,
-        values=None,
-        missing_variable_default_value=np.nan,
-        concat_axis=None,
-        weighted=True,
-        alternative_weights=None,
-    ):
+        aggfunc: str = "mean",
+        columns: Optional[list[str]] = None,
+        baseline_simulation: Optional[str] = None,
+        filter_by: Optional[str] = None,
+        index: Optional[list[str]] = None,
+        period: Optional[Union[int, str, Period]] = None,
+        simulation: Optional[str] = None,
+        difference: bool = False,
+        use_baseline_for_columns: Optional[bool] = None,
+        values: Optional[list[str]] = None,
+        missing_variable_default_value: Any = np.nan,
+        concat_axis: Optional[int] = None,
+        weighted: bool = True,
+        alternative_weights: Optional[Union[str, int, float, Array]] = None,
+    ) -> pd.DataFrame:
         """Compute a pivot table of aggregated values cast along specified index and columns.
 
         Args:
@@ -405,17 +415,17 @@ def compute_pivot_table(
 
     def compute_winners_losers(
         self,
-        variable,
-        simulation,
-        baseline_simulation=None,
-        filter_by=None,
-        period=None,
-        absolute_minimal_detected_variation=0,
-        relative_minimal_detected_variation=0.01,
-        observations_threshold=None,
-        weighted=True,
-        alternative_weights=None,
-    ):
+        variable: str,
+        simulation: str,
+        baseline_simulation: Optional[str] = None,
+        filter_by: Optional[str] = None,
+        period: Optional[Union[int, str, Period]] = None,
+        absolute_minimal_detected_variation: float = 0,
+        relative_minimal_detected_variation: float = 0.01,
+        observations_threshold: Optional[int] = None,
+        weighted: bool = True,
+        alternative_weights: Optional[Union[str, int, float, Array]] = None,
+    ) -> dict[str, Union[int, float]]:
         simulation = self.simulations[simulation]
         if baseline_simulation:
             baseline_simulation = self.simulations[baseline_simulation]
@@ -434,8 +444,15 @@ def compute_winners_losers(
         )
 
     def create_data_frame_by_entity(
-        self, variables=None, expressions=None, filter_by=None, index=False, period=None, simulation=None, merge=False
-    ):
+        self,
+        variables: Optional[list[str]] = None,
+        expressions: Optional[list[str]] = None,
+        filter_by: Optional[str] = None,
+        index: bool = False,
+        period: Optional[Union[int, str, Period]] = None,
+        simulation: Optional[str] = None,
+        merge: bool = False,
+    ) -> Union[pd.DataFrame, dict[str, pd.DataFrame]]:
         """Create dataframe(s) of computed variables for every entity (optionally merged into a unique dataframe).
 
         Args:
@@ -466,7 +483,11 @@ def create_data_frame_by_entity(
             merge=merge,
         )
 
-    def custom_input_data_frame(self, input_data_frame, **kwargs):
+    def custom_input_data_frame(
+        self,
+        input_data_frame: pd.DataFrame,
+        **kwargs: Any,
+    ) -> None:
         """Customize input data frame.
         Args:
@@ -475,7 +496,12 @@ def custom_input_data_frame(self, input_data_frame, **kwargs):
         """
         pass
 
-    def dump_data_frame_by_entity(self, variables=None, survey_collection=None, survey_name=None):
+    def dump_data_frame_by_entity(
+        self,
+        variables: Optional[list[str]] = None,
+        survey_collection: Optional[Any] = None,
+        survey_name: Optional[str] = None,
+    ) -> None:
         assert survey_collection is not None
         assert survey_name is not None
         assert variables is not None
@@ -486,7 +512,7 @@ def dump_data_frame_by_entity(
             survey_collection.surveys.append(survey)
             survey_collection.dump(collection="openfisca")
 
-    def dump_simulations(self, directory: str):
+    def dump_simulations(self, directory: str) -> None:
         """
         Dump simulations.
 
@@ -504,7 +530,7 @@
             simulation = next(iter(self.simulations.values()))
             dump_simulation(simulation, directory)
 
-    def generate_performance_data(self, output_dir: str):
+    def generate_performance_data(self, output_dir: str) -> None:
         if not self.trace:
             raise SurveyManagerError("Method generate_performance_data cannot be used if trace hasn't been activated.")
 
@@ -517,7 +543,12 @@
             simulation.tracer.generate_performance_graph(simulation_dir)
             simulation.tracer.generate_performance_tables(simulation_dir)
 
-    def inflate(self, inflator_by_variable=None, period=None, target_by_variable=None):
+    def inflate(
+        self,
+        inflator_by_variable: Optional[dict[str, float]] = None,
+        period: Optional[Union[int, str, Period]] = None,
+        target_by_variable: Optional[dict[str, float]] = None,
+    ) -> None:
         assert inflator_by_variable or target_by_variable
         assert period is not None
         inflator_by_variable = {} if inflator_by_variable is None else inflator_by_variable
@@ -530,14 +561,14 @@ def inflate(
 
     def init_from_data(
         self,
-        calibration_kwargs=None,
-        inflation_kwargs=None,
-        rebuild_input_data=False,
-        rebuild_kwargs=None,
-        data=None,
-        memory_config=None,
-        use_marginal_tax_rate=False,
-    ):
+        calibration_kwargs: Optional[dict[str, Any]] = None,
+        inflation_kwargs: Optional[dict[str, Any]] = None,
+        rebuild_input_data: bool = False,
+        rebuild_kwargs: Optional[dict[str, Any]] = None,
+        data: Optional[dict[str, Any]] = None,
+        memory_config: Optional[Any] = None,
+        use_marginal_tax_rate: bool = False,
+    ) -> None:
         """Initialise a survey scenario from data.
 
         Args:
@@ -601,8 +632,14 @@ def init_from_data(
             self.inflate(**inflation_kwargs)
 
     def new_simulation(
-        self, simulation_name, debug=False, trace=False, data=None, memory_config=None, marginal_tax_rate_only=False
-    ):
+        self,
+        simulation_name: str,
+        debug: bool = False,
+        trace: bool = False,
+        data: Optional[dict[str, Any]] = None,
+        memory_config: Optional[Any] = None,
+        marginal_tax_rate_only: bool = False,
+    ) -> Simulation:
         tax_benefit_system = self.tax_benefit_systems[simulation_name]
         assert tax_benefit_system is not None
 
@@ -643,13 +680,13 @@
         return simulation
 
-    def memory_usage(self):
+    def memory_usage(self) -> None:
         """Log memory usage."""
         for simulation_name, simulation in self.simulations.items():
             log.info("simulation: %s", simulation_name)
             simulation.print_memory_usage()
 
-    def neutralize_variables(self, tax_benefit_system):
+    def neutralize_variables(self, tax_benefit_system: TaxBenefitSystem) -> None:
         """Neutralizes input variables not in the input dataframe and keeps some crucial variables.
Args: @@ -668,7 +705,7 @@ def neutralize_variables(self, tax_benefit_system): tax_benefit_system.neutralize_variable(variable_name) - def restore_simulations(self, directory, **kwargs): + def restore_simulations(self, directory: Union[str, Path], **kwargs: Any) -> None: """Restores SurveyScenario's simulations. Args: @@ -690,7 +727,7 @@ def restore_simulations(self, directory, **kwargs): simulation.id_variable_by_entity_key = self.id_variable_by_entity_key self.simulations["unique_simulation"] = simulation - def set_input_data_frame(self, input_data_frame): + def set_input_data_frame(self, input_data_frame: pd.DataFrame) -> None: """Set the input dataframe. Args: @@ -713,7 +750,10 @@ def set_tax_benefit_systems(self, tax_benefit_systems: dict[str, TaxBenefitSyste # self.tax_benefit_systems = tax_benefit_systems - def set_weight_variable_by_entity(self, weight_variable_by_entity=None): + def set_weight_variable_by_entity( + self, + weight_variable_by_entity: Optional[dict[str, str]] = None, + ) -> None: if weight_variable_by_entity is not None: self.weight_variable_by_entity = weight_variable_by_entity @@ -721,7 +761,12 @@ def set_weight_variable_by_entity(self, weight_variable_by_entity=None): for simulation in self.simulations.values(): simulation.set_weight_variable_by_entity(self.weight_variable_by_entity) - def summarize_variable(self, variable=None, weighted=False, force_compute=False): + def summarize_variable( + self, + variable: Optional[str] = None, + weighted: bool = False, + force_compute: bool = False, + ) -> None: """Log a summary of a variable including its memory usage for all the simulations. Args: @@ -740,12 +785,20 @@ def summarize_variable(self, variable=None, weighted=False, force_compute=False) for _simulation_name, simulation in self.simulations.items(): simulation.summarize_variable(variable, weighted, force_compute) - def _apply_modification(self, simulation, period): + def _apply_modification( + self, + simulation: Simulation, + period: Union[int, str, Period], + ) -> None: period = periods.period(period) varying_variable = self.varying_variable definition_period = simulation.tax_benefit_system.variables[varying_variable].definition_period - def set_variable(varying_variable, varying_variable_value, period_): + def set_variable( + varying_variable: str, + varying_variable_value: np.ndarray, + period_: Period, + ) -> None: delta = self.variation_factor * varying_variable_value new_variable_value = varying_variable_value + delta simulation.delete_arrays(varying_variable, period_) diff --git a/openfisca_survey_manager/policy/scenarios/reform_scenario.py b/openfisca_survey_manager/policy/scenarios/reform_scenario.py index 1859f636..cd056a56 100644 --- a/openfisca_survey_manager/policy/scenarios/reform_scenario.py +++ b/openfisca_survey_manager/policy/scenarios/reform_scenario.py @@ -1,7 +1,7 @@ """Abstract survey scenario definition.""" import logging -from typing import Optional, Union +from typing import Any, Optional, Union import numpy as np import pandas as pd @@ -16,7 +16,7 @@ class ReformScenario(AbstractSurveyScenario): """Reform survey scenario.""" - def _get_simulation(self, use_baseline: bool = False): + def _get_simulation(self, use_baseline: bool = False) -> Simulation: """ Get relevant simulation @@ -32,11 +32,16 @@ def _get_simulation(self, use_baseline: bool = False): assert simulation is not None, f"{simulation_name} does not exist" return simulation - def build_input_data(self, **kwargs): + def build_input_data(self, **kwargs: Any) -> None: """Build 
input data."""
        raise NotImplementedError

-    def calculate_series(self, variable, period=None, use_baseline=False):
+    def calculate_series(
+        self,
+        variable: str,
+        period: Optional[Union[int, str, Period]] = None,
+        use_baseline: bool = False,
+    ) -> pd.Series:
         """Compute variable values for a period with either the baseline or the reform tax and benefit system.
 
         Args:
@@ -53,7 +58,12 @@ def calculate_series(self, variable, period=None, use_baseline=False):
             name=variable,
         )
 
-    def calculate_variable(self, variable, period=None, use_baseline=False):
+    def calculate_variable(
+        self,
+        variable: str,
+        period: Optional[Union[int, str, Period]] = None,
+        use_baseline: bool = False,
+    ) -> Array:
         """Compute variable values for a period with either the baseline or the reform tax and benefit system.
 
         Args:
@@ -195,20 +205,20 @@ def compute_marginal_tax_rate(
 
     def compute_pivot_table(
         self,
-        aggfunc="mean",
-        columns=None,
-        difference=False,
-        filter_by=None,
-        index=None,
-        period=None,
-        use_baseline=False,
-        use_baseline_for_columns=None,
-        values=None,
-        missing_variable_default_value=np.nan,
-        concat_axis=None,
-        weighted=True,
-        alternative_weights=None,
-    ):
+        aggfunc: str = "mean",
+        columns: Optional[list[str]] = None,
+        difference: bool = False,
+        filter_by: Optional[str] = None,
+        index: Optional[list[str]] = None,
+        period: Optional[Union[int, str, Period]] = None,
+        use_baseline: bool = False,
+        use_baseline_for_columns: Optional[bool] = None,
+        values: Optional[list[str]] = None,
+        missing_variable_default_value: Any = np.nan,
+        concat_axis: Optional[int] = None,
+        weighted: bool = True,
+        alternative_weights: Optional[Union[str, int, float, Array]] = None,
+    ) -> pd.DataFrame:
         filtering_variable_by_entity = self.filtering_variable_by_entity
 
         return Simulation.compute_pivot_table(
@@ -231,15 +241,15 @@ def compute_pivot_table(
 
     def compute_winners_losers(
         self,
-        variable=None,
-        filter_by=None,
-        period=None,
-        absolute_minimal_detected_variation=0,
-        relative_minimal_detected_variation=0.01,
-        observations_threshold=None,
-        weighted=True,
-        alternative_weights=None,
-    ):
+        variable: Optional[str] = None,
+        filter_by: Optional[str] = None,
+        period: Optional[Union[int, str, Period]] = None,
+        absolute_minimal_detected_variation: float = 0,
+        relative_minimal_detected_variation: float = 0.01,
+        observations_threshold: Optional[int] = None,
+        weighted: bool = True,
+        alternative_weights: Optional[Union[str, int, float, Array]] = None,
+    ) -> dict[str, Union[int, float]]:
         return super().compute_winners_losers(
             simulation="reform",
             baseline_simulation="baseline",
@@ -259,14 +269,14 @@ def compute_winners_losers(
 
     def create_data_frame_by_entity(
         self,
-        variables=None,
-        expressions=None,
-        filter_by=None,
-        index=False,
-        period=None,
-        use_baseline=False,
-        merge=False,
-    ):
+        variables: Optional[list[str]] = None,
+        expressions: Optional[list[str]] = None,
+        filter_by: Optional[str] = None,
+        index: bool = False,
+        period: Optional[Union[int, str, Period]] = None,
+        use_baseline: bool = False,
+        merge: bool = False,
+    ) -> Union[pd.DataFrame, dict[str, pd.DataFrame]]:
        """Create dataframe(s) of computed variables for every entity (optionally merged into a single dataframe).
        Args:
diff --git a/openfisca_survey_manager/policy/simulation_builder.py b/openfisca_survey_manager/policy/simulation_builder.py
index 1feaad1f..9b9594fb 100644
--- a/openfisca_survey_manager/policy/simulation_builder.py
+++ b/openfisca_survey_manager/policy/simulation_builder.py
@@ -1,7 +1,14 @@
+"""Simulation builder extensions for survey manager."""
+
+from __future__ import annotations
+
 import logging
+from typing import Any, Optional
 
+import pandas as pd
 from openfisca_core.model_api import MONTH, YEAR
 from openfisca_core.simulations.simulation_builder import SimulationBuilder
+from openfisca_core.types import TaxBenefitSystem
 
 from openfisca_survey_manager.exceptions import SurveyManagerError
 
@@ -17,7 +24,10 @@
 # Helpers
 
 
-def diagnose_variable_mismatch(used_as_input_variables, input_data_frame):
+def diagnose_variable_mismatch(
+    used_as_input_variables: Optional[list[str]],
+    input_data_frame: pd.DataFrame,
+) -> None:
     """Diagnose variable mismatches.
 
     Args:
@@ -41,7 +51,7 @@ def diagnose_variable_mismatch(used_as_input_variables, input_data_frame):
 # SimulationBuilder monkey-patched methods
 
 
-def _set_id_variable_by_entity_key(builder) -> dict[str, str]:
+def _set_id_variable_by_entity_key(builder: SimulationBuilder) -> dict[str, str]:
     """Identifies and sets the correct ids for the different entities."""
     if builder.id_variable_by_entity_key is None:
         log.debug("Use default id_variable names")
@@ -52,7 +62,7 @@ def _set_id_variable_by_entity_key(builder) -> dict[str, str]:
     return builder.id_variable_by_entity_key
 
 
-def _set_role_variable_by_entity_key(builder) -> dict[str, str]:
+def _set_role_variable_by_entity_key(builder: SimulationBuilder) -> dict[str, str]:
     """Identifies and sets the correct roles for the different entities."""
     if builder.role_variable_by_entity_key is None:
         builder.role_variable_by_entity_key = {
@@ -62,7 +72,7 @@ def _set_role_variable_by_entity_key(builder) -> dict[str, str]:
     return builder.role_variable_by_entity_key
 
 
-def _set_used_as_input_variables_by_entity(builder) -> dict[str, list[str]]:
+def _set_used_as_input_variables_by_entity(builder: SimulationBuilder) -> Optional[dict[str, list[str]]]:
     """Identifies and sets the correct input variables for the different entities."""
     if builder.used_as_input_variables_by_entity is not None:
         return
@@ -88,7 +98,11 @@ def _set_used_as_input_variables_by_entity(builder) -> dict[str, list[str]]:
     return builder.used_as_input_variables_by_entity
 
 
-def filter_input_variables(builder, input_data_frame, tax_benefit_system):
+def filter_input_variables(
+    builder: SimulationBuilder,
+    input_data_frame: pd.DataFrame,
+    tax_benefit_system: TaxBenefitSystem,
+) -> pd.DataFrame:
     """Filter out of the input data frame the variables that won't be used or that are set to be computed.
    Args:
@@ -152,7 +166,11 @@ def filter_input_variables(builder, input_data_frame, tax_benefit_system):
     return input_data_frame
 
 
-def init_all_entities(builder, input_data_frame, period=None):
+def init_all_entities(
+    builder: SimulationBuilder,
+    input_data_frame: pd.DataFrame,
+    period: Any = None,
+) -> Any:
     assert period is not None
     log.debug(f"Initialising simulation using input_data_frame for period {period}")
     builder._set_id_variable_by_entity_key()
@@ -182,7 +200,11 @@ def init_all_entities(builder, input_data_frame, period=None):
     return simulation
 
 
-def init_entity_structure(builder, entity, input_data_frame):
+def init_entity_structure(
+    builder: SimulationBuilder,
+    entity: Any,
+    input_data_frame: pd.DataFrame,
+) -> None:
     """Initialize the simulation with tax_benefit_system entities and input_data_frame.
 
     Args:
@@ -225,7 +247,11 @@ def init_entity_structure(builder, entity, input_data_frame):
     )
 
 
-def init_simulation_with_data_frame(builder, input_data_frame, period):
+def init_simulation_with_data_frame(
+    builder: SimulationBuilder,
+    input_data_frame: pd.DataFrame,
+    period: Any,
+) -> Any:
     """Initialize the simulation period with current input_data_frame for an entity if specified."""
     used_as_input_variables = builder.used_as_input_variables
     id_variable_by_entity_key = builder.id_variable_by_entity_key
diff --git a/openfisca_survey_manager/policy/simulations.py b/openfisca_survey_manager/policy/simulations.py
index 2d55f20f..f699e562 100644
--- a/openfisca_survey_manager/policy/simulations.py
+++ b/openfisca_survey_manager/policy/simulations.py
@@ -32,7 +32,10 @@
 # Helpers
 
 
-def assert_variables_in_same_entity(tax_benefit_system: TaxBenefitSystem, variables: list):
+def assert_variables_in_same_entity(
+    tax_benefit_system: TaxBenefitSystem,
+    variables: list[str],
+) -> str:
     """
     Assert that variables are in the same entity.
 
@@ -55,7 +58,7 @@ def assert_variables_in_same_entity(tax_benefit_system: TaxBenefitSystem, variab
     return entity.key
 
 
-def get_words(text: str):
+def get_words(text: str) -> list[str]:
     return re.compile("[A-Za-z_]+[A-Za-z0-9_]*").findall(text)
 
 
@@ -116,7 +119,7 @@ def compute_aggregate(
     missing_variable_default_value: Any = np.nan,
     weighted: bool = True,
     alternative_weights: Optional[Union[str, int, float, Array]] = None,
-    filtering_variable_by_entity: Optional[dict] = None,
+    filtering_variable_by_entity: Optional[dict[str, str]] = None,
 ) -> Optional[Union[float, int]]:
     """
     Compute aggregate of a variable.
@@ -247,7 +250,7 @@ def compute_quantiles(
     filter_by: Optional[str] = None,
     weighted: bool = True,
     alternative_weights: Optional[Union[str, int, float, Array]] = None,
-    filtering_variable_by_entity: Optional[dict] = None,
+    filtering_variable_by_entity: Optional[dict[str, str]] = None,
 ) -> list[float]:
     """
     Compute quantiles of a variable.
@@ -299,8 +302,8 @@ def compute_quantiles(
 
 
 def compute_pivot_table(
-    simulation: Simulation = None,
-    baseline_simulation: Simulation = None,
+    simulation: Optional[Simulation] = None,
+    baseline_simulation: Optional[Simulation] = None,
     aggfunc: str = "mean",
     columns: Optional[list[str]] = None,
     difference: bool = False,
@@ -313,8 +316,8 @@ def compute_pivot_table(
     concat_axis: Optional[int] = None,
     weighted: bool = True,
     alternative_weights: Optional[Union[str, int, float, Array]] = None,
-    filtering_variable_by_entity: Optional[dict] = None,
-):
+    filtering_variable_by_entity: Optional[dict[str, str]] = None,
+) -> pd.DataFrame:
     """
     Compute pivot table.
@@ -690,7 +693,7 @@ def compute_winners_losers( observations_threshold: Optional[int] = None, weighted: bool = True, alternative_weights: Optional[Union[str, int, float, Array]] = None, - filtering_variable_by_entity: Optional[dict] = None, + filtering_variable_by_entity: Optional[dict[str, str]] = None, ) -> dict[str, Union[int, float]]: """ Compute the number of winners and losers for a given variable. @@ -806,8 +809,8 @@ def init_entity_data( entity: Entity, filtered_input_data_frame: pd.DataFrame, period: Period, - used_as_input_variables_by_entity: dict, -): + used_as_input_variables_by_entity: dict[str, list[str]], +) -> None: """ Initialize entity in simulation at some period with input provided by a dataframe. @@ -840,9 +843,9 @@ def init_entity_data( def inflate( simulation: Simulation, - inflator_by_variable: Optional[dict] = None, + inflator_by_variable: Optional[dict[str, float]] = None, period: Optional[Union[int, str, Period]] = None, - target_by_variable: Optional[dict] = None, + target_by_variable: Optional[dict[str, float]] = None, ) -> None: tax_benefit_system = simulation.tax_benefit_system for variable_name in set(inflator_by_variable.keys()).union(set(target_by_variable.keys())): @@ -877,12 +880,12 @@ def inflate( def _load_table_for_survey( config_files_directory: str, collection: str, - survey: str, - table: str, + survey: Optional[str] = None, + table: Optional[str] = None, batch_size: Optional[int] = None, batch_index: Optional[int] = None, filter_by: Optional[str] = None, -): +) -> pd.DataFrame: if survey is not None: input_data_frame = load_table( config_files_directory=config_files_directory, @@ -908,15 +911,15 @@ def _load_table_for_survey( def _input_data_table_by_entity_by_period_monolithic( tax_benefit_system: TaxBenefitSystem, - simulation: Simulation, + simulation: Optional[Simulation], period: Period, - input_data_table_by_entity: dict, + input_data_table_by_entity: dict[str, Any], builder: SimulationBuilder, - custom_input_data_frame: Callable, + custom_input_data_frame: Callable[..., Any], config_files_directory: str, collection: str, survey: Optional[str] = None, -): +) -> Simulation: """ Initialize simulation with input data from a table for each entity and period. """ @@ -964,15 +967,15 @@ def _input_data_table_by_entity_by_period_monolithic( def _input_data_table_by_entity_by_period_batch( tax_benefit_system: TaxBenefitSystem, - simulation: Simulation, + simulation: Optional[Simulation], period: Period, - input_data_table_by_entity: dict, + input_data_table_by_entity: dict[str, Any], builder: SimulationBuilder, - custom_input_data_frame: Callable, + custom_input_data_frame: Callable[..., Any], config_files_directory: str, collection: str, survey: Optional[str] = None, -): +) -> Simulation: """ Initialize simulation with input data from a table for each entity and period. 
""" @@ -1050,8 +1053,8 @@ def _input_data_table_by_entity_by_period_batch( def init_simulation( tax_benefit_system: TaxBenefitSystem, period: Union[str, int, Period], - data: dict, -): + data: dict[str, Any], +) -> Simulation: builder = SimulationBuilder() builder.create_entities(tax_benefit_system) @@ -1189,7 +1192,7 @@ def init_variable_in_entity( variable_name: str, series: pd.Series, period: Period, -): +) -> None: variable = simulation.tax_benefit_system.variables[variable_name] # np.issubdtype cannot handles categorical variables @@ -1266,7 +1269,7 @@ def new_from_tax_benefit_system( debug: bool = False, trace: bool = False, data: Optional[dict] = None, - memory_config: MemoryConfig = None, + memory_config: Optional[MemoryConfig] = None, period: Optional[Union[int, str, Period]] = None, custom_initialize: Optional[Callable] = None, ) -> Simulation: @@ -1298,7 +1301,7 @@ def new_from_tax_benefit_system( return simulation -def print_memory_usage(simulation: Simulation): +def print_memory_usage(simulation: Simulation) -> None: """ Print memory usage. @@ -1340,7 +1343,7 @@ def print_memory_usage(simulation: Simulation): def set_weight_variable_by_entity( simulation: Simulation, - weight_variable_by_entity: dict, + weight_variable_by_entity: Optional[dict[str, str]], ) -> None: """ Set weight variable for each entity. @@ -1357,7 +1360,7 @@ def summarize_variable( variable: Optional[str] = None, weighted: bool = False, force_compute: bool = False, -): +) -> None: """Print a summary of a variable including its memory usage. Args: diff --git a/openfisca_survey_manager/policy/statshelpers.py b/openfisca_survey_manager/policy/statshelpers.py index 3960afcf..a55e2464 100644 --- a/openfisca_survey_manager/policy/statshelpers.py +++ b/openfisca_survey_manager/policy/statshelpers.py @@ -1,4 +1,9 @@ +"""Statistical helpers (Gini, Lorenz, weighted percentiles, etc.).""" + +from __future__ import annotations + import logging +from typing import Optional import numpy as np import pandas as pd @@ -10,7 +15,10 @@ log = logging.getLogger(__name__) -def gini(values, weights=None): +def gini( + values: np.ndarray | pd.Series, + weights: Optional[np.ndarray | pd.Series] = None, +) -> float: """Computes Gini coefficient (normalized to 1). 
     # Using fastgini formula:
     #        i=N      j=i
@@ -47,7 +55,11 @@ def kakwani(values, ineq_axis, weights=None):
     return gini
 
 
-def kakwani(values, ineq_axis, weights=None):
+def kakwani(
+    values: np.ndarray | pd.Series,
+    ineq_axis: np.ndarray | pd.Series,
+    weights: Optional[np.ndarray | pd.Series] = None,
+) -> float:
     """Computes the Kakwani index.
 
     Args:
@@ -71,7 +83,10 @@ def kakwani(values, ineq_axis, weights=None):
     return simps((lcy - plcy), lcx)
 
 
-def lorenz(values, weights=None):
+def lorenz(
+    values: np.ndarray | pd.Series,
+    weights: Optional[np.ndarray | pd.Series] = None,
+) -> tuple[np.ndarray, np.ndarray]:
     """Computes Lorenz curve coordinates (x, y).
 
     Args:
@@ -94,7 +109,13 @@ def lorenz(values, weights=None):
     return x, y
 
 
-def mark_weighted_percentiles(a, labels, weights, method, return_quantiles=False):
+def mark_weighted_percentiles(
+    a: np.ndarray | pd.Series,
+    labels: np.ndarray | list,
+    weights: np.ndarray | pd.Series,
+    method: int,
+    return_quantiles: bool = False,
+) -> np.ndarray | tuple[np.ndarray, list[float]]:
     """
 
     Args:
@@ -257,7 +278,11 @@ def mark_weighted_percentiles(a, labels, weights, method, return_quantiles=False
     return ret
 
 
-def pseudo_lorenz(values, ineq_axis, weights=None):
+def pseudo_lorenz(
+    values: np.ndarray | pd.Series,
+    ineq_axis: np.ndarray | pd.Series,
+    weights: Optional[np.ndarray | pd.Series] = None,
+) -> tuple[np.ndarray, np.ndarray]:
     """Computes the pseudo-Lorenz curve coordinates.
 
     Args:
@@ -280,7 +305,11 @@ def pseudo_lorenz(values, ineq_axis, weights=None):
     return x, y
 
 
-def bottom_share(values, rank_from_bottom, weights=None):
+def bottom_share(
+    values: np.ndarray | pd.Series,
+    rank_from_bottom: float,
+    weights: Optional[np.ndarray | pd.Series] = None,
+) -> float:
     """
 
     Args:
@@ -308,7 +337,11 @@ def bottom_share(values, rank_from_bottom, weights=None):
     ).sum()
 
 
-def top_share(values, rank_from_top, weights=None):
+def top_share(
+    values: np.ndarray | pd.Series,
+    rank_from_top: float,
+    weights: Optional[np.ndarray | pd.Series] = None,
+) -> float:
     """
 
     Args:
@@ -335,7 +368,12 @@ def top_share(values, rank_from_top, weights=None):
     ).sum()
 
 
-def weighted_quantiles(data, labels, weights, return_quantiles=False):
+def weighted_quantiles(
+    data: np.ndarray | pd.Series,
+    labels: np.ndarray | list,
+    weights: np.ndarray | pd.Series,
+    return_quantiles: bool = False,
+) -> np.ndarray | tuple[np.ndarray, list[float]]:
     num_categories = len(labels)
     breaks = linspace(0, 1, num_categories + 1)
     quantiles = [wquantiles.quantile_1D(data, weights, mybreak) for mybreak in breaks[1:]]
@@ -351,7 +389,12 @@ def weighted_quantiles(data, labels, weights, return_quantiles=False):
     return ret + 1
 
 
-def weightedcalcs_quantiles(data, labels, weights, return_quantiles=False):
+def weightedcalcs_quantiles(
+    data: np.ndarray | pd.Series,
+    labels: np.ndarray | list,
+    weights: np.ndarray | pd.Series,
+    return_quantiles: bool = False,
+) -> np.ndarray | tuple[np.ndarray, list[float]]:
     calc = wc.Calculator("weights")
     num_categories = len(labels)
     breaks = linspace(0, 1, num_categories + 1)
diff --git a/openfisca_survey_manager/policy/variables.py b/openfisca_survey_manager/policy/variables.py
index 0e506c14..fc6f8fc1 100644
--- a/openfisca_survey_manager/policy/variables.py
+++ b/openfisca_survey_manager/policy/variables.py
@@ -1,4 +1,9 @@
+"""Policy variables helpers (quantile formulas)."""
+
+from __future__ import annotations
+
 import logging
+from typing import Any, Callable, Optional
 
 from numpy import arange
 from openfisca_core.model_api import ADD, YEAR, Variable, where
@@ -8,7 +13,12 @@
 log = logging.getLogger(__name__)
 
 
-def create_quantile(x, nquantiles, weight_variable, entity_name):
+def create_quantile(
+    x: str,
+    nquantiles: int,
+    weight_variable: str,
+    entity_name: Any,
+) -> type[Variable]:
     class quantile(Variable):
         value_type = int
         entity = entity_name
@@ -35,7 +45,12 @@ def formula(entity, period):
     return quantile
 
 
-def quantile(q, variable, weight_variable=None, filter_variable=None):
+def quantile(
+    q: int,
+    variable: str,
+    weight_variable: Optional[str] = None,
+    filter_variable: Optional[str] = None,
+) -> Callable[..., Any]:
     """
     Return the quantile of a variable, weighted by a specific weight variable and potentially filtered.
     """
@@ -63,7 +78,12 @@ def formula(entity, period):
     return formula
 
 
-def old_quantile(q, variable, weight_variable=None, filter_variable=None):
+def old_quantile(
+    q: int,
+    variable: str,
+    weight_variable: Optional[str] = None,
+    filter_variable: Optional[str] = None,
+) -> Callable[..., Any]:
     def formula(entity, period):
         value = entity(variable, period)
         if weight_variable is not None:
diff --git a/openfisca_survey_manager/processing/__init__.py b/openfisca_survey_manager/processing/__init__.py
index fa8106e9..b6e95546 100644
--- a/openfisca_survey_manager/processing/__init__.py
+++ b/openfisca_survey_manager/processing/__init__.py
@@ -6,7 +6,7 @@
 # Lazy import to avoid circular dependency (processing -> policy -> survey_collections -> core)
 
 
-def __getattr__(name: str):
+def __getattr__(name: str) -> object:
     if name in ("Calibration", "calmar", "check_calmar"):
         from openfisca_survey_manager.policy.calibration import Calibration
         from openfisca_survey_manager.policy.calmar import calmar, check_calmar

From 3a71facde2a062a2af160817cc4a8593e94c9fbe Mon Sep 17 00:00:00 2001
From: Mahdi Ben Jelloul
Date: Mon, 2 Mar 2026 18:30:49 +0100
Subject: [PATCH 10/18] Release 1.0: remove re-exports and DeprecationWarning
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Remove the compatibility modules (config, paths, tables, surveys, survey_collections, read_sas, read_spss, read_dbf, calibration, calmar, utils)
- Move load_table into core.dataset; migrate the imports as documented in MIGRATION_IMPORTS.md
- Version 1.0.0, CHANGELOG and docs (REFACTORING_PLAN, RFC-001)
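
A typical before/after pair (Survey shown here; the full old → new mapping is in
docs/MIGRATION_IMPORTS.md):

    # before (module removed in 1.0.0)
    from openfisca_survey_manager.surveys import Survey
    # after
    from openfisca_survey_manager.core.survey import Survey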

Made-with: Cursor
---
 CHANGELOG.md                               |   7 +
 docs/REFACTORING_PLAN.md                   |   8 +-
 docs/RFC-001-OPENFISCA-DATA-STACK.md       | 134 ++++++++++++
 docs/TICKET_OPENFISCA_DATA_STACK.md        |   4 +-
 openfisca_survey_manager/calibration.py    |  17 --
 openfisca_survey_manager/calmar.py         |  17 --
 openfisca_survey_manager/config.py         |   8 -
 openfisca_survey_manager/core/__init__.py  |   7 +-
 openfisca_survey_manager/core/dataset.py   |  30 +++
 openfisca_survey_manager/core/survey.py    | 173 +++++++--------
 openfisca_survey_manager/core/table.py     | 205 ++++++++----------
 openfisca_survey_manager/google_colab.py   |  28 ---
 openfisca_survey_manager/paths.py          |  22 --
 openfisca_survey_manager/read_dbf.py       |   5 -
 openfisca_survey_manager/read_sas.py       |   5 -
 openfisca_survey_manager/read_spss.py      |   5 -
 .../scripts/build_collection.py            |   9 +-
 .../survey_collections.py                  |   8 -
 openfisca_survey_manager/surveys.py        |   8 -
 openfisca_survey_manager/tables.py         |   8 -
 openfisca_survey_manager/temporary.py      |   2 +-
 .../tests/test_quantile.py                 |   2 +-
 .../tests/test_read_sas.py                 |   4 +-
 pyproject.toml                             |   2 +-
 24 files changed, 374 insertions(+), 344 deletions(-)
 create mode 100644 docs/RFC-001-OPENFISCA-DATA-STACK.md
 delete mode 100644 openfisca_survey_manager/calibration.py
 delete mode 100644 openfisca_survey_manager/calmar.py
 delete mode 100644 openfisca_survey_manager/config.py
 delete mode 100644 openfisca_survey_manager/google_colab.py
 delete mode 100644 openfisca_survey_manager/paths.py
 delete mode 100644 openfisca_survey_manager/read_dbf.py
 delete mode 100644 openfisca_survey_manager/read_sas.py
 delete mode 100644 openfisca_survey_manager/read_spss.py
 delete mode 100644 openfisca_survey_manager/survey_collections.py
 delete mode 100644 openfisca_survey_manager/surveys.py
 delete mode 100644 openfisca_survey_manager/tables.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 96b02f25..79014119 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,12 @@
 # Changelog
 
+# 1.0.0
+
+* **Breaking**: Version 1.0 — removal of the re-exports and of the DeprecationWarnings
+  - **Compatibility modules removed**: `config`, `paths`, `tables`, `surveys`, `survey_collections`, `read_sas`, `read_spss`, `read_dbf`, `calibration`, `calmar`, `utils` are deleted. Use the canonical imports (see `docs/MIGRATION_IMPORTS.md`).
+  - **`load_table`**: moved from `utils` to `openfisca_survey_manager.core.dataset` (and exported from `core`).
+  - All internal imports have been migrated to `configuration.paths`, `configuration.models`, `core.dataset`, `core.survey`, `core.table`, `io.readers`, `processing.weights`, `common.misc`.
+
 # 6.5.0
 
 * Typing (no breaking API changes)
diff --git a/docs/REFACTORING_PLAN.md b/docs/REFACTORING_PLAN.md
index bde3cc61..bd56bf6d 100644
--- a/docs/REFACTORING_PLAN.md
+++ b/docs/REFACTORING_PLAN.md
@@ -86,12 +86,14 @@ Today these layers are intertwined (e.g. reading + cleaning in `tabl
 
 ### 3.1 Long functions (> 100 lines)
 
-- Split the big functions into named steps, for example:
-  - `load_survey()` → `_parse_config()`, `_load_raw_data()`, `_transform()`, `_store()`.
-- Target: readability and testability, without changing the behaviour.
+- **Started**: splitting into named steps without changing the behaviour.
+  - `core.table.Table.read_source` → `_read_csv_with_inferred_encoding()`, `_apply_stata_categorical_strategy()`; `read_source()` orchestrates.
+  - `core.survey.Survey.get_values` → `_get_values_from_hdf5()`, `_get_values_from_parquet()`; `get_values()` orchestrates and applies the harmonization.
+- To be continued: other modules (simulations, scenarios, scripts, processing/weights/calmar, etc.).
 
 ### 3.2 Circular dependencies
 
+- **Verified** (cold imports): no circular dependency. Coherent chain: `exceptions` → `configuration` → `io`/`processing` → `core.table` → `core.survey` → `core.dataset`; `utils` → `common.misc`, `survey_collections`; `core.table` imports `Survey` only lazily, inside `Table.__init__`. If cycles appear: extract the shared logic into `common/` or `configuration/`.
 - If modules import each other, extract the shared logic into `utils/` (or `config/`) and make both sides depend on that shared module.
 - Verify with cold imports (start the app and import the submodules), as sketched below.
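+
+A minimal cold-import check (the module list below is illustrative, not an exhaustive tooling proposal):
+
+```python
+import importlib
+
+for name in [
+    "openfisca_survey_manager.core.table",
+    "openfisca_survey_manager.core.survey",
+    "openfisca_survey_manager.core.dataset",
+]:
+    importlib.import_module(name)  # an import cycle surfaces here as ImportError
+```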
diff --git a/docs/RFC-001-OPENFISCA-DATA-STACK.md b/docs/RFC-001-OPENFISCA-DATA-STACK.md
new file mode 100644
index 00000000..168b603f
--- /dev/null
+++ b/docs/RFC-001-OPENFISCA-DATA-STACK.md
@@ -0,0 +1,134 @@
+# RFC-001: OpenFisca Data Stack
+
+**Status**: Draft
+**Issue**: [#381](https://github.com/openfisca/openfisca-survey-manager/issues/381)
+**Author(s)**: OpenFisca team
+**Date**: 2025-01
+
+---
+
+## Summary
+
+This RFC formalizes an **OpenFisca data stack** with explicit roles and boundaries. It defines the target evolution of the current `openfisca-survey-manager` towards a reusable data core (`openfisca-data-manager`), the place of the country repositories (`openfisca-<country>-data`) and that of the analysis layer (`openfisca-policy-analysis`). It serves as the reference for the refactoring PRs and for upcoming changes.
+
+---
+
+## 1. Context and motivation
+
+### 1.1 Problem
+
+Today, survey data management and policy analysis are tightly coupled inside `openfisca-survey-manager`. This results in:
+
+- a blurry boundary between "microdata access" and "analysis (scenarios, reforms, aggregates)";
+- limited reuse of the data core outside policy use cases;
+- difficult evolution (backend, schemas, reproducibility) without impacting the whole stack.
+
+### 1.2 Goal
+
+Define a clear **OpenFisca Data Stack**: building blocks, responsibilities, target APIs and dependency rules, in order to guide the refactoring and future changes.
+
+---
+
+## 2. Goals and non-goals
+
+### 2.1 Goals
+
+- Separate conceptually (and eventually in code): **data** (access, storage, schemas) vs **analysis** (scenarios, reforms, aggregates).
+- Propose a minimal stable data API (v1.0) for microdata access.
+- Clarify the role of each block (data-manager, country-data, policy-analysis) and their dependencies.
+- Align the refactoring PRs (survey-manager) and the long-term decisions on this vision.
+
+### 2.2 Non-goals
+
+- This RFC sets neither an implementation timeline nor a precise migration order.
+- It does not detail the technical implementation (library choices, internal formats) beyond the principles and the target APIs.
+
+---
+
+## 3. Specification: OpenFisca Data Stack (target)
+
+### 3.1 Overview
+
+```
+OpenFisca Data Stack
+├── openfisca-data-manager (data core, country-agnostic)
+├── openfisca-<country>-data (per-country microdata → OpenFisca adaptation)
+├── openfisca-policy-analysis (scenarios, reforms, aggregates, indicators)
+└── OpenFisca Core (computation engine)
+```
+
+### 3.2 Block 1: openfisca-data-manager
+
+**Target evolution of the current openfisca-survey-manager (data core).**
+
+- **Role**: a **universal**, **country-independent** building block.
+- **Responsibilities**:
+  - backend abstraction (parquet by default, HDF during the transition);
+  - versioned dataset management;
+  - data pipeline (read, clean, write);
+  - schema validation;
+  - reproducible metadata;
+  - a **stable microdata access API**.
+- **What it must not do**:
+  - depend on a tax and benefit system;
+  - know about OpenFisca variables;
+  - contain policy analysis.
+
+**Minimal stable target API (v1.0)**:
+
+```python
+dataset = DataManager.load("lfs", year=2019)
+df = dataset.to_pandas(columns=["income", "weight"])
+```
+
+And expose:
+
+- `dataset.metadata`
+- `dataset.schema`
+- `dataset.hash`
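+
+For example (a sketch only: `DataManager` and the `metadata`/`schema`/`hash` attributes are the v1.0 target named above, not an existing implementation):
+
+```python
+dataset = DataManager.load("lfs", year=2019)
+
+# Reproducibility: record the dataset fingerprint alongside any published result
+print(dataset.metadata)  # provenance and build metadata
+print(dataset.schema)    # column names and types
+print(dataset.hash)      # content hash identifying this exact dataset version
+```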
+
+### 3.3 Block 2: openfisca-<country>-data
+
+Examples: `openfisca-france-data`, `openfisca-tunisia-data`.
+
+- **Role**: prepare the **microdata for OpenFisca ingestion**.
+- **Dependencies**: `openfisca-data-manager`, `openfisca-<country>` (Core).
+- **Responsibilities**:
+  - mapping survey variables → OpenFisca variables;
+  - creation of entities and periods;
+  - consistency validation against the TBS.
+
+**Possible API**:
+
+```python
+adapter = CountryDataAdapter(dataset)
+of_input = adapter.to_openfisca_entities()
+```
+
+### 3.4 Block 3: openfisca-policy-analysis
+
+Current survey-manager content to **migrate or extract** into this block (or a dedicated module):
+
+- survey scenarios (baseline vs reform);
+- weighted aggregations;
+- inequality indicators;
+- diagnostics.
+
+This block relies on the microdata (via data-manager or country-data) and on OpenFisca Core for the computations.
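+
+For example (a sketch only: `compute_aggregate` is today's survey-manager scenario API, and the variable, period and simulation names are illustrative):
+
+```python
+aggregate = scenario.compute_aggregate("income_tax", period=2022, simulation="reform")
+```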
+
+---
+
+## 4. Compatibility and link with the refactoring
+
+- The **openfisca-survey-manager refactoring** PRs (reorganisation, cleanup, processing/weights, core/io, typing, etc.) remain compatible with this RFC: they prepare the separation of the layers without imposing a big bang.
+- Later changes (decoupling data-manager / policy-analysis, exposing the v1.0 API) may reference this RFC (and issue #381) as the long-term objective.
+- No public API break is required in the short term; the RFC describes a target and a heading.
+
+---
+
+## 5. References
+
+- [Issue #381](https://github.com/openfisca/openfisca-survey-manager/issues/381) (Data Stack vision).
+- `docs/REFACTORING_PLAN.md` (internal reorganisation of the survey-manager).
+- `docs/MIGRATION_IMPORTS.md` (import migration after the removal of the re-exports).
+- `docs/TICKET_OPENFISCA_DATA_STACK.md` (original ticket version, to be considered superseded by this RFC).
diff --git a/docs/TICKET_OPENFISCA_DATA_STACK.md b/docs/TICKET_OPENFISCA_DATA_STACK.md
index fac3c70b..9939c187 100644
--- a/docs/TICKET_OPENFISCA_DATA_STACK.md
+++ b/docs/TICKET_OPENFISCA_DATA_STACK.md
@@ -1,4 +1,6 @@
-# 🏗️ Vision: an official OpenFisca Data Stack
+# 🏗️ Vision: an official OpenFisca Data Stack ([issue #381](https://github.com/openfisca/openfisca-survey-manager/issues/381))
+
+> **See the RFC version**: [RFC-001: OpenFisca Data Stack](RFC-001-OPENFISCA-DATA-STACK.md). This document keeps the initial wording of the ticket; the RFC gives it its normative form.
 
 **Goal**: Formalize a clear OpenFisca data stack, with well-defined roles and boundaries. This ticket serves as the reference for the refactoring PRs (survey-manager → data-manager, decoupling, etc.) and for upcoming changes.
diff --git a/openfisca_survey_manager/calibration.py b/openfisca_survey_manager/calibration.py
deleted file mode 100644
index 23124894..00000000
--- a/openfisca_survey_manager/calibration.py
+++ /dev/null
@@ -1,17 +0,0 @@
-"""Re-export for backward compatibility.
-
-Prefer: from openfisca_survey_manager.policy.calibration import ...
-"""
-
-import warnings
-
-from openfisca_survey_manager.policy.calibration import Calibration
-
-warnings.warn(
-    "openfisca_survey_manager.calibration is deprecated and will be removed in a future version. "
-    "Prefer: from openfisca_survey_manager.policy.calibration import ...",
-    DeprecationWarning,
-    stacklevel=2,
-)
-
-__all__ = ["Calibration"]
diff --git a/openfisca_survey_manager/calmar.py b/openfisca_survey_manager/calmar.py
deleted file mode 100644
index 6734f8d7..00000000
--- a/openfisca_survey_manager/calmar.py
+++ /dev/null
@@ -1,17 +0,0 @@
-"""Re-export for backward compatibility.
-
-Prefer: from openfisca_survey_manager.policy.calmar import ...
-"""
-
-import warnings
-
-from openfisca_survey_manager.policy.calmar import calmar, check_calmar
-
-warnings.warn(
-    "openfisca_survey_manager.calmar is deprecated and will be removed in a future version. "
-    "Prefer: from openfisca_survey_manager.policy.calmar import ...",
-    DeprecationWarning,
-    stacklevel=2,
-)
-
-__all__ = ["calmar", "check_calmar"]
diff --git a/openfisca_survey_manager/config.py b/openfisca_survey_manager/config.py
deleted file mode 100644
index 2e09a43c..00000000
--- a/openfisca_survey_manager/config.py
+++ /dev/null
@@ -1,8 +0,0 @@
-"""Re-export for backward compatibility.
-
-Prefer: from openfisca_survey_manager.configuration.models import Config.
-"""
-
-from openfisca_survey_manager.configuration.models import Config
-
-__all__ = ["Config"]
diff --git a/openfisca_survey_manager/core/__init__.py b/openfisca_survey_manager/core/__init__.py
index af07d7cf..4ea97a4d 100644
--- a/openfisca_survey_manager/core/__init__.py
+++ b/openfisca_survey_manager/core/__init__.py
@@ -1,8 +1,7 @@
-# Survey, Table, SurveyCollection. Legacy modules re-export for compatibility.
-# See docs/REFACTORING_PLAN.md.
+# Survey, Table, SurveyCollection, load_table.
 
-from openfisca_survey_manager.core.dataset import SurveyCollection
+from openfisca_survey_manager.core.dataset import SurveyCollection, load_table
 from openfisca_survey_manager.core.survey import NoMoreDataError, Survey
 from openfisca_survey_manager.core.table import Table
 
-__all__ = ["NoMoreDataError", "Survey", "SurveyCollection", "Table"]
+__all__ = ["NoMoreDataError", "Survey", "SurveyCollection", "Table", "load_table"]
diff --git a/openfisca_survey_manager/core/dataset.py b/openfisca_survey_manager/core/dataset.py
index 3477a394..99519d80 100644
--- a/openfisca_survey_manager/core/dataset.py
+++ b/openfisca_survey_manager/core/dataset.py
@@ -10,6 +10,8 @@
 from pathlib import Path
 from typing import List, Optional, Union
 
+import pandas as pd
+
 from openfisca_survey_manager.configuration.models import Config
 from openfisca_survey_manager.configuration.paths import default_config_files_directory
 from openfisca_survey_manager.core.survey import Survey
@@ -172,3 +174,31 @@ def to_json(self) -> dict:
         for survey in self.surveys:
             self_json["surveys"][survey.name] = survey.to_json()
         return self_json
+
+
+def load_table(
+    config_files_directory,
+    variables: Optional[list] = None,
+    collection: Optional[str] = None,
+    survey: Optional[str] = None,
+    input_data_survey_prefix: Optional[str] = None,
+    data_year=None,
+    table: Optional[str] = None,
+    batch_size=None,
+    batch_index=0,
+    filter_by=None,
+) -> pd.DataFrame:
+    """Load a table from a survey in a collection."""
+    survey_collection = SurveyCollection.load(collection=collection, config_files_directory=config_files_directory)
+    survey_name = survey if survey is not None else f"{input_data_survey_prefix}_{data_year}"
+    survey_ = survey_collection.get_survey(survey_name)
+    log.debug("Loading table %s in survey %s from collection %s", table, survey_name, collection)
+    if batch_size:
+        return survey_.get_values(
+            table=table,
+            variables=variables,
+            batch_size=batch_size,
+            batch_index=batch_index,
+            filter_by=filter_by,
+        )
+    return survey_.get_values(table=table, variables=variables, filter_by=filter_by)
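+
+
+# Usage sketch (collection, survey and table names are illustrative):
+#
+#     df = load_table(
+#         config_files_directory,
+#         collection="my_collection",
+#         survey="my_survey_2019",
+#         table="individuals",
+#         variables=["income", "weight"],
+#     )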
diff --git a/openfisca_survey_manager/core/survey.py b/openfisca_survey_manager/core/survey.py
index 59b8bd4e..e4d3b723 100644
--- a/openfisca_survey_manager/core/survey.py
+++ b/openfisca_survey_manager/core/survey.py
@@ -188,6 +188,85 @@ def get_value(
     ) -> pandas.DataFrame:
         return self.get_values([variable], table)
 
+    def _get_values_from_hdf5(self, table: str, ignorecase: bool = False) -> tuple[pandas.DataFrame, str]:
+        """Read table from HDF5 store. Returns (df, resolved_table_name)."""
+        assert Path(self.hdf5_file_path).exists(), (
+            f"{self.hdf5_file_path} is not a valid path. This could happen because "
+            "your data were not built yet. Please consider using a rebuild option in your code."
+        )
+        store = pandas.HDFStore(self.hdf5_file_path, "r")
+        try:
+            if ignorecase:
+                keys = store.keys()
+                eligible_tables = [
+                    match[0] for string in keys for match in [re.findall(table, string, re.IGNORECASE)] if match
+                ]
+                if len(eligible_tables) > 1:
+                    raise SurveyManagerError(
+                        f"{table} is ambiguous since the following tables are available: {eligible_tables}"
+                    )
+                if len(eligible_tables) == 0:
+                    raise SurveyIOError(f"No eligible available table in {keys}")
+                table = eligible_tables[0]
+            df = store.select(table)
+            return df, table
+        except KeyError:
+            log.error("No table %s in the file %s", table, self.hdf5_file_path)
+            log.error(
+                "This could happen because your data were not built yet. Available tables are: %s",
+                store.keys(),
+            )
+            raise
+        finally:
+            store.close()
+
+    def _get_values_from_parquet(
+        self,
+        table: str,
+        variables: Optional[List[str]],
+        filter_by: Optional[List[tuple]],
+        batch_size: Optional[int],
+        batch_index: int,
+    ) -> pandas.DataFrame:
+        """Read table from parquet. Resolves variables from table content if None."""
+        if table is None:
+            raise SurveyIOError("A table name is needed to retrieve data from a parquet file")
+        for table_name, table_content in self.tables.items():
+            if table != table_name:
+                continue
+            parquet_file = table_content.get("parquet_file")
+            if Path(parquet_file).is_dir():
+                for file in Path(parquet_file).iterdir():
+                    if file.suffix == ".parquet":
+                        one_parquet_file = str(Path(parquet_file) / file)
+                        break
+                else:
+                    raise SurveyIOError(f"No parquet file found in {parquet_file}")
+            else:
+                one_parquet_file = parquet_file
+            parquet_schema = pq.read_schema(one_parquet_file)
+            assert len(parquet_schema.names) >= 1, f"The parquet file {table_content.get('parquet_file')} is empty"
+            if variables is None:
+                variables = table_content.get("variables")
+            if filter_by:
+                return pq.ParquetDataset(parquet_file, filters=filter_by).read(columns=variables).to_pandas()
+            if batch_size is not None:
+                paths = (
+                    [str(p) for p in Path(parquet_file).glob("*.parquet")]
+                    if Path(parquet_file).is_dir()
+                    else [parquet_file]
+                )
+                tables_list = [pq.read_table(fp, columns=variables) for fp in paths]
+                final_table = pa.concat_tables(tables_list) if len(tables_list) > 1 else tables_list[0]
+                record_batches = final_table.to_batches(max_chunksize=batch_size)
+                if len(record_batches) <= batch_index:
+                    raise NoMoreDataError(
+                        f"Batch {batch_index} not found in {table_name}. 
Max index is {len(record_batches)}" + ) + return record_batches[batch_index].to_pandas() + return pq.ParquetDataset(parquet_file).read(columns=variables).to_pandas() + raise SurveyIOError(f"No table {table} found in {self.parquet_file_path}") + def get_values( self, variables: Optional[List[str]] = None, @@ -202,95 +281,17 @@ def get_values( if self.parquet_file_path is None and self.hdf5_file_path is None: raise SurveyIOError(f"No data file found for survey {self.name}") if self.hdf5_file_path is not None: - assert Path(self.hdf5_file_path).exists(), ( - f"{self.hdf5_file_path} is not a valid path. This could happen because " - "your data were not builded yet. Please consider using a rebuild option in your code." - ) - store = pandas.HDFStore(self.hdf5_file_path, "r") - if ignorecase: - keys = store.keys() - eligible_tables = [] - for string in keys: - match = re.findall(table, string, re.IGNORECASE) - if match: - eligible_tables.append(match[0]) - if len(eligible_tables) > 1: - raise SurveyManagerError( - f"{table} is ambiguous since the following tables are available: {eligible_tables}" - ) - elif len(eligible_tables) == 0: - raise SurveyIOError(f"No eligible available table in {keys}") - else: - table = eligible_tables[0] - try: - df = store.select(table) - except KeyError: - log.error(f"No table {table} in the file {self.hdf5_file_path}") - log.error( - f"This could happen because your data were not builded yet. Available tables are: {store.keys()}" - ) - store.close() - raise - - store.close() - - elif self.parquet_file_path is not None: - if table is None: - raise SurveyIOError("A table name is needed to retrieve data from a parquet file") - for table_name, table_content in self.tables.items(): - if table == table_name: - parquet_file = table_content.get("parquet_file") - if Path(parquet_file).is_dir(): - for file in Path(parquet_file).iterdir(): - if file.suffix == ".parquet": - one_parquet_file = str(Path(parquet_file) / file) - break - else: - raise SurveyIOError(f"No parquet file found in {parquet_file}") - else: - one_parquet_file = parquet_file - parquet_schema = pq.read_schema(one_parquet_file) - assert len(parquet_schema.names) >= 1, ( - f"The parquet file {table_content.get('parquet_file')} is empty" - ) - if variables is None: - variables = table_content.get("variables") - if filter_by: - df = pq.ParquetDataset(parquet_file, filters=filter_by).read(columns=variables).to_pandas() - elif batch_size: - if Path(parquet_file).is_dir(): - parquet_file = [str(p) for p in Path(parquet_file).glob("*.parquet")] - else: - parquet_file = [parquet_file] - tables = [] - for file_path in parquet_file: - table = pq.read_table(file_path, columns=variables) - tables.append(table) - - final_table = pa.concat_tables(tables) if len(tables) > 1 else tables[0] - record_batches = final_table.to_batches(max_chunksize=batch_size) - if len(record_batches) <= batch_index: - raise NoMoreDataError( - f"Batch {batch_index} not found in {table_name}. 
Max index is {len(record_batches)}" - ) - df = record_batches[batch_index].to_pandas() - else: - df = pq.ParquetDataset(parquet_file).read(columns=variables).to_pandas() - break - else: - raise SurveyIOError(f"No table {table} found in {self.parquet_file_path}") - + df, _ = self._get_values_from_hdf5(table or "", ignorecase=ignorecase) + else: + df = self._get_values_from_parquet(table, variables, filter_by, batch_size, batch_index) harmonize_data_frame_columns(df, lowercase=lowercase, rename_ident=rename_ident) - if variables is None: return df - else: - diff = set(variables) - set(df.columns) - if diff: - raise SurveyIOError(f"The following variable(s) {diff} are missing") - variables = list(set(variables).intersection(df.columns)) - df = df[variables] - return df + diff = set(variables) - set(df.columns) + if diff: + raise SurveyIOError(f"The following variable(s) {diff} are missing") + variables = list(set(variables).intersection(df.columns)) + return df[variables] def insert_table( self, diff --git a/openfisca_survey_manager/core/table.py b/openfisca_survey_manager/core/table.py index 71104335..683b14ec 100644 --- a/openfisca_survey_manager/core/table.py +++ b/openfisca_survey_manager/core/table.py @@ -16,15 +16,15 @@ from chardet.universaldetector import UniversalDetector from pyarrow import parquet as pq -from openfisca_survey_manager import read_sas from openfisca_survey_manager.exceptions import SurveyIOError +from openfisca_survey_manager.io.readers import read_sas from openfisca_survey_manager.io.writers import write_table_to_hdf5, write_table_to_parquet from openfisca_survey_manager.processing.cleaning import clean_data_frame try: - from openfisca_survey_manager.read_spss import read_spss + from openfisca_survey_manager.io.readers import read_spss except ImportError: - read_spss = None + read_spss = None # optional dependency (savReaderWriter) if TYPE_CHECKING: from openfisca_survey_manager.core.survey import Survey @@ -33,7 +33,7 @@ reader_by_source_format = { "csv": pandas.read_csv, - "sas": read_sas.read_sas, + "sas": read_sas, "spss": read_spss, "stata": pandas.read_stata, "parquet": pandas.read_parquet, @@ -175,12 +175,83 @@ def read_parquet_columns(self, parquet_file: Optional[str] = None) -> list[str]: self.survey.tables[self.name]["variables"] = self.variables return self.variables + def _read_csv_with_inferred_encoding( + self, data_file: str, reader: Any, kwargs: dict[str, Any] + ) -> pandas.DataFrame: + """Read CSV, inferring encoding and dialect if default read fails.""" + log.debug("Failing to read %s, trying to infer encoding and dialect/separator", data_file) + detector = UniversalDetector() + with Path(data_file).open("rb") as csvfile: + for line in csvfile: + detector.feed(line) + if detector.done: + break + detector.close() + encoding = detector.result["encoding"] + confidence = detector.result["confidence"] + try: + with Path(data_file).open("r", newline="", encoding=encoding) as csvfile: + dialect = csv.Sniffer().sniff(csvfile.read(1024), delimiters=";,") + except Exception: + dialect = None + delimiter = ";" + log.debug( + "dialect.delimiter = %s, encoding = %s, confidence = %s", + dialect.delimiter if dialect is not None else delimiter, + encoding, + confidence, + ) + kwargs = {**kwargs, "engine": "python", "encoding": encoding} + if dialect: + kwargs["dialect"] = dialect + else: + kwargs["delimiter"] = delimiter + return reader(data_file, **kwargs) + + def _apply_stata_categorical_strategy( + self, + data_frame: pandas.DataFrame, + data_file: str, + 
categorical_strategy: str, + ) -> None: + """Apply categorical_strategy (unique_labels, codes, skip) to Stata value labels in place.""" + from pandas.io.stata import StataReader + + stata_reader = StataReader(data_file) + value_labels = stata_reader.value_labels() + for col_name, labels in value_labels.items(): + if col_name not in data_frame.columns: + continue + if categorical_strategy == "unique_labels": + unique_labels = {} + seen_labels = {} + for code, label in labels.items(): + if pandas.isna(code): + unique_labels[code] = label + elif label in seen_labels: + unique_labels[code] = f"{label} ({code})" + else: + unique_labels[code] = label + seen_labels[label] = code + code_to_label = {code: unique_labels[code] for code in sorted(labels.keys())} + data_frame[col_name] = data_frame[col_name].map(code_to_label) + data_frame[col_name] = pandas.Categorical( + data_frame[col_name], + categories=list(code_to_label.values()), + ordered=False, + ) + elif categorical_strategy == "codes": + codes = sorted([c for c in labels if pandas.notna(c)]) + if codes: + data_frame[col_name] = pandas.Categorical(data_frame[col_name], categories=codes, ordered=False) + elif categorical_strategy != "skip": + log.warning("Unknown categorical_strategy %r, using 'skip'", categorical_strategy) + def read_source(self, data_file: str, **kwargs: Any) -> pandas.DataFrame: source_format = self.source_format store_file_path = ( self.survey.hdf5_file_path if self.survey.store_format == "hdf5" else self.survey.parquet_file_path ) - self._check_and_log(data_file, store_file_path=store_file_path) reader = reader_by_source_format[source_format] categorical_strategy = ( @@ -192,121 +263,33 @@ def read_source(self, data_file: str, **kwargs: Any) -> pandas.DataFrame: if source_format == "csv": try: data_frame = reader(data_file, **kwargs) - if len(data_frame.columns) == 1 and ";" in data_frame.columns[0]: raise SurveyIOError( "A ';' is present in the unique column name. Looks like we got the wrong separator." 
) - except Exception: - log.debug(f"Failing to read {data_file}, Trying to infer encoding and dialect/separator") - - detector = UniversalDetector() - with Path(data_file).open("rb") as csvfile: - for line in csvfile: - detector.feed(line) - if detector.done: - break - detector.close() - - encoding = detector.result["encoding"] - confidence = detector.result["confidence"] - - try: - with Path(data_file).open("r", newline="", encoding=encoding) as csvfile: - dialect = csv.Sniffer().sniff(csvfile.read(1024), delimiters=";,") - except Exception: - dialect = None - delimiter = ";" - - log.debug( - f"dialect.delimiter = {dialect.delimiter if dialect is not None else delimiter}, " - f"encoding = {encoding}, confidence = {confidence}" - ) - kwargs["engine"] = "python" - if dialect: - kwargs["dialect"] = dialect - else: - kwargs["delimiter"] = delimiter - kwargs["encoding"] = encoding - data_frame = reader(data_file, **kwargs) - - else: - if "encoding" in kwargs and source_format == "stata": + data_frame = self._read_csv_with_inferred_encoding(data_file, reader, kwargs) + elif source_format == "stata": + if "encoding" in kwargs: kwargs.pop("encoding") - if source_format == "stata": - try: - if "convert_categoricals" not in kwargs: - data_frame = reader(data_file, **kwargs) - else: - data_frame = reader(data_file, **kwargs) - except ValueError as e: - if "not unique" in str(e) or "Categorical categories must be unique" in str(e): - log.info( - f"Non-unique value labels detected in {data_file}, " - f"using strategy '{categorical_strategy}'" - ) - - kwargs_no_cat = kwargs.copy() - kwargs_no_cat["convert_categoricals"] = False - data_frame = reader(data_file, **kwargs_no_cat) - - if categorical_strategy == "unique_labels": - from pandas.io.stata import StataReader - - stata_reader = StataReader(data_file) - value_labels = stata_reader.value_labels() - - for col_name, labels in value_labels.items(): - if col_name in data_frame.columns: - unique_labels = {} - seen_labels = {} - - for code, label in labels.items(): - if pandas.isna(code): - unique_labels[code] = label - elif label in seen_labels: - unique_labels[code] = f"{label} ({code})" - else: - unique_labels[code] = label - seen_labels[label] = code - - code_to_label = {code: unique_labels[code] for code in sorted(labels.keys())} - - data_frame[col_name] = data_frame[col_name].map(code_to_label) - data_frame[col_name] = pandas.Categorical( - data_frame[col_name], - categories=list(code_to_label.values()), - ordered=False, - ) - - elif categorical_strategy == "codes": - from pandas.io.stata import StataReader - - stata_reader = StataReader(data_file) - value_labels = stata_reader.value_labels() - - for col_name, labels in value_labels.items(): - if col_name in data_frame.columns: - codes = sorted([c for c in labels if pandas.notna(c)]) - if codes: - data_frame[col_name] = pandas.Categorical( - data_frame[col_name], categories=codes, ordered=False - ) - - elif categorical_strategy == "skip": - pass - else: - log.warning(f"Unknown categorical_strategy '{categorical_strategy}', using 'skip'") - else: - raise - else: + try: data_frame = reader(data_file, **kwargs) - + except ValueError as e: + if "not unique" not in str(e) and "Categorical categories must be unique" not in str(e): + raise + log.info( + "Non-unique value labels detected in %s, using strategy %r", + data_file, + categorical_strategy, + ) + kwargs_no_cat = {**kwargs, "convert_categoricals": False} + data_frame = reader(data_file, **kwargs_no_cat) + 
self._apply_stata_categorical_strategy(data_frame, data_file, categorical_strategy) + else: + data_frame = reader(data_file, **kwargs) except Exception as e: - log.info(f"Error while reading {data_file}") + log.info("Error while reading %s", data_file) raise e - gc.collect() return data_frame diff --git a/openfisca_survey_manager/google_colab.py b/openfisca_survey_manager/google_colab.py deleted file mode 100644 index 817f4a72..00000000 --- a/openfisca_survey_manager/google_colab.py +++ /dev/null @@ -1,28 +0,0 @@ -import configparser -import logging -from pathlib import Path - -from openfisca_survey_manager.paths import default_config_files_directory as config_files_directory - -log = logging.getLogger(__name__) - - -def create_raw_data_ini(value_by_option_by_section=None): - """Creates raw_data.ini configureation file - - Args: - value_by_option_by_section(dict(dict)): Options value by section (Default value = None) - - """ - config_parser = configparser.ConfigParser() - - if value_by_option_by_section is not None: - for section, value_by_option in value_by_option_by_section.items(): - config_parser.add_section(section) - for option, value in value_by_option.items(): - config_parser.set(section, option, value) - - path = Path(config_files_directory) / "raw_data.ini" - with path.open("w") as raw_data_config_file: - config_parser.write(raw_data_config_file) - log.info("Wrote raw_data.ini to %s", path) diff --git a/openfisca_survey_manager/paths.py b/openfisca_survey_manager/paths.py deleted file mode 100644 index 288c0a4a..00000000 --- a/openfisca_survey_manager/paths.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Re-export for backward compatibility. - -Prefer: from openfisca_survey_manager.configuration.paths import ... -""" - -from openfisca_survey_manager.configuration.paths import ( - config_ini, - default_config_files_directory, - is_in_ci, - openfisca_survey_manager_location, - private_run_with_data, - test_config_files_directory, -) - -__all__ = [ - "config_ini", - "default_config_files_directory", - "is_in_ci", - "openfisca_survey_manager_location", - "private_run_with_data", - "test_config_files_directory", -] diff --git a/openfisca_survey_manager/read_dbf.py b/openfisca_survey_manager/read_dbf.py deleted file mode 100644 index 3ff17ae2..00000000 --- a/openfisca_survey_manager/read_dbf.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Re-export for backward compatibility. Prefer: from openfisca_survey_manager.io.readers import read_dbf.""" - -from openfisca_survey_manager.io.readers import read_dbf - -__all__ = ["read_dbf"] diff --git a/openfisca_survey_manager/read_sas.py b/openfisca_survey_manager/read_sas.py deleted file mode 100644 index 58168894..00000000 --- a/openfisca_survey_manager/read_sas.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Re-export for backward compatibility. Prefer: from openfisca_survey_manager.io.readers import read_sas.""" - -from openfisca_survey_manager.io.readers import read_sas - -__all__ = ["read_sas"] diff --git a/openfisca_survey_manager/read_spss.py b/openfisca_survey_manager/read_spss.py deleted file mode 100644 index ee4da5cb..00000000 --- a/openfisca_survey_manager/read_spss.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Re-export for backward compatibility. 
Prefer: from openfisca_survey_manager.io.readers import read_spss.""" - -from openfisca_survey_manager.io.readers import read_spss - -__all__ = ["read_spss"] diff --git a/openfisca_survey_manager/scripts/build_collection.py b/openfisca_survey_manager/scripts/build_collection.py index 4aa6472b..3e94da4c 100755 --- a/openfisca_survey_manager/scripts/build_collection.py +++ b/openfisca_survey_manager/scripts/build_collection.py @@ -13,9 +13,12 @@ import sys from pathlib import Path -from openfisca_survey_manager.paths import default_config_files_directory, openfisca_survey_manager_location -from openfisca_survey_manager.survey_collections import SurveyCollection -from openfisca_survey_manager.surveys import Survey +from openfisca_survey_manager.configuration.paths import ( + default_config_files_directory, + openfisca_survey_manager_location, +) +from openfisca_survey_manager.core.dataset import SurveyCollection +from openfisca_survey_manager.core.survey import Survey app_name = Path(__file__).stem log = logging.getLogger(app_name) diff --git a/openfisca_survey_manager/survey_collections.py b/openfisca_survey_manager/survey_collections.py deleted file mode 100644 index 1a676ee2..00000000 --- a/openfisca_survey_manager/survey_collections.py +++ /dev/null @@ -1,8 +0,0 @@ -"""Re-export for backward compatibility. - -Prefer: from openfisca_survey_manager.core.dataset import SurveyCollection. -""" - -from openfisca_survey_manager.core.dataset import SurveyCollection - -__all__ = ["SurveyCollection"] diff --git a/openfisca_survey_manager/surveys.py b/openfisca_survey_manager/surveys.py deleted file mode 100644 index 4283544f..00000000 --- a/openfisca_survey_manager/surveys.py +++ /dev/null @@ -1,8 +0,0 @@ -"""Re-export for backward compatibility. - -Prefer: from openfisca_survey_manager.core.survey import Survey, NoMoreDataError. -""" - -from openfisca_survey_manager.core.survey import NoMoreDataError, Survey - -__all__ = ["NoMoreDataError", "Survey"] diff --git a/openfisca_survey_manager/tables.py b/openfisca_survey_manager/tables.py deleted file mode 100644 index 57d6ef66..00000000 --- a/openfisca_survey_manager/tables.py +++ /dev/null @@ -1,8 +0,0 @@ -"""Re-export for backward compatibility. - -Prefer: from openfisca_survey_manager.core.table import Table. 
-""" - -from openfisca_survey_manager.core.table import Table - -__all__ = ["Table"] diff --git a/openfisca_survey_manager/temporary.py b/openfisca_survey_manager/temporary.py index 4b185fb5..59d84fb9 100644 --- a/openfisca_survey_manager/temporary.py +++ b/openfisca_survey_manager/temporary.py @@ -5,7 +5,7 @@ from pandas import HDFStore -from openfisca_survey_manager.paths import default_config_files_directory +from openfisca_survey_manager.configuration.paths import default_config_files_directory log = logging.getLogger(__name__) diff --git a/openfisca_survey_manager/tests/test_quantile.py b/openfisca_survey_manager/tests/test_quantile.py index a2ef4174..87b40a1c 100644 --- a/openfisca_survey_manager/tests/test_quantile.py +++ b/openfisca_survey_manager/tests/test_quantile.py @@ -4,7 +4,7 @@ from openfisca_core.model_api import YEAR, Variable from openfisca_core.taxbenefitsystems import TaxBenefitSystem -from openfisca_survey_manager.paths import default_config_files_directory +from openfisca_survey_manager.configuration.paths import default_config_files_directory from openfisca_survey_manager.policy.scenarios.abstract_scenario import AbstractSurveyScenario from openfisca_survey_manager.policy.statshelpers import mark_weighted_percentiles from openfisca_survey_manager.policy.variables import quantile diff --git a/openfisca_survey_manager/tests/test_read_sas.py b/openfisca_survey_manager/tests/test_read_sas.py index 455c87b6..27a77ef8 100644 --- a/openfisca_survey_manager/tests/test_read_sas.py +++ b/openfisca_survey_manager/tests/test_read_sas.py @@ -6,8 +6,8 @@ from pandas.testing import assert_frame_equal -from openfisca_survey_manager.paths import openfisca_survey_manager_location -from openfisca_survey_manager.read_sas import read_sas +from openfisca_survey_manager.configuration.paths import openfisca_survey_manager_location +from openfisca_survey_manager.io.readers import read_sas def test(): diff --git a/pyproject.toml b/pyproject.toml index 133e5de4..f9888837 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "OpenFisca-Survey-Manager" -version = "6.5.0" +version = "1.0.0" description = "A tool for managing survey/administrative data and import them in OpenFisca" readme = "README.md" keywords = ["microsimulation", "tax", "benefit", "rac", "rules-as-code", "survey", "data"] From b79bc4a715e2c1f0015f6896722c15c2f2f5573c Mon Sep 17 00:00:00 2001 From: Mahdi Ben Jelloul Date: Mon, 2 Mar 2026 22:14:17 +0100 Subject: [PATCH 11/18] breaking: remove survey-manager placeholders and deprecation shims - Delete root simulations/simulation_builder/aggregates/utils modules in favor of policy.* - Keep common.* as thin aliases to policy.legislation_asof (no DeprecationWarning) - Update scenarios and tests to import from policy and core.dataset/core.survey - Wire policy.simulations to SurveyCollection/load_table from core.dataset; tests still all pass Made-with: Cursor --- openfisca_survey_manager/aggregates.py | 21 ----- openfisca_survey_manager/common/__init__.py | 17 +++++ openfisca_survey_manager/common/misc.py | 24 ++++++ .../policy/simulations.py | 4 +- .../scenarios/__init__.py | 17 ----- .../simulation_builder.py | 25 ------ openfisca_survey_manager/simulations.py | 57 -------------- openfisca_survey_manager/utils.py | 76 ------------------- uv.lock | 2 +- 9 files changed, 44 insertions(+), 199 deletions(-) delete mode 100644 openfisca_survey_manager/aggregates.py create mode 100644 openfisca_survey_manager/common/__init__.py create mode 100644 
openfisca_survey_manager/common/misc.py delete mode 100644 openfisca_survey_manager/scenarios/__init__.py delete mode 100644 openfisca_survey_manager/simulation_builder.py delete mode 100644 openfisca_survey_manager/simulations.py delete mode 100644 openfisca_survey_manager/utils.py diff --git a/openfisca_survey_manager/aggregates.py b/openfisca_survey_manager/aggregates.py deleted file mode 100644 index 07629736..00000000 --- a/openfisca_survey_manager/aggregates.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Re-export for backward compatibility. - -Prefer: from openfisca_survey_manager.policy import ... -or from openfisca_survey_manager.policy.aggregates import ... -""" - -import warnings - -from openfisca_survey_manager.policy.aggregates import AbstractAggregates - -warnings.warn( - "openfisca_survey_manager.aggregates is deprecated and will be removed in a future version. " - "Prefer: from openfisca_survey_manager.policy import ... " - "or from openfisca_survey_manager.policy.aggregates import ...", - DeprecationWarning, - stacklevel=2, -) - -__all__ = [ - "AbstractAggregates", -] diff --git a/openfisca_survey_manager/common/__init__.py b/openfisca_survey_manager/common/__init__.py new file mode 100644 index 00000000..cc8c0061 --- /dev/null +++ b/openfisca_survey_manager/common/__init__.py @@ -0,0 +1,17 @@ +from openfisca_survey_manager.policy.legislation_asof import ( # Backward-compat imports + asof, + do_nothing, + inflate_parameter_leaf, + inflate_parameters, + parameters_asof, + variables_asof, +) + +__all__ = [ + "asof", + "do_nothing", + "inflate_parameter_leaf", + "inflate_parameters", + "parameters_asof", + "variables_asof", +] diff --git a/openfisca_survey_manager/common/misc.py b/openfisca_survey_manager/common/misc.py new file mode 100644 index 00000000..fe95bf70 --- /dev/null +++ b/openfisca_survey_manager/common/misc.py @@ -0,0 +1,24 @@ +"""Backward-compatibility wrapper for legislation helpers. + +Use ``openfisca_survey_manager.policy.legislation_asof`` as canonical import path. 
+""" + +from openfisca_survey_manager.policy.legislation_asof import ( + asof, + do_nothing, + inflate_parameter_leaf, + inflate_parameters, + leaf_asof, + parameters_asof, + variables_asof, +) + +__all__ = [ + "asof", + "do_nothing", + "inflate_parameter_leaf", + "inflate_parameters", + "leaf_asof", + "parameters_asof", + "variables_asof", +] diff --git a/openfisca_survey_manager/policy/simulations.py b/openfisca_survey_manager/policy/simulations.py index f699e562..39c30cdb 100644 --- a/openfisca_survey_manager/policy/simulations.py +++ b/openfisca_survey_manager/policy/simulations.py @@ -17,14 +17,14 @@ from openfisca_core.types import Array, Period, TaxBenefitSystem from openfisca_core.types import CoreEntity as Entity +from openfisca_survey_manager.core.dataset import SurveyCollection, load_table from openfisca_survey_manager.exceptions import SurveyManagerError +from openfisca_survey_manager.policy.legislation_asof import do_nothing from openfisca_survey_manager.policy.simulation_builder import ( SimulationBuilder, diagnose_variable_mismatch, ) from openfisca_survey_manager.policy.statshelpers import mark_weighted_percentiles -from openfisca_survey_manager.survey_collections import SurveyCollection -from openfisca_survey_manager.utils import do_nothing, load_table log = logging.getLogger(__name__) diff --git a/openfisca_survey_manager/scenarios/__init__.py b/openfisca_survey_manager/scenarios/__init__.py deleted file mode 100644 index 2c41e560..00000000 --- a/openfisca_survey_manager/scenarios/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Re-export for backward compatibility. - -Prefer: from openfisca_survey_manager.policy.scenarios import ... -""" - -import warnings - -from openfisca_survey_manager.policy.scenarios import AbstractSurveyScenario, ReformScenario - -warnings.warn( - "openfisca_survey_manager.scenarios is deprecated and will be removed in a future version. " - "Prefer: from openfisca_survey_manager.policy.scenarios import ...", - DeprecationWarning, - stacklevel=2, -) - -__all__ = ["AbstractSurveyScenario", "ReformScenario"] diff --git a/openfisca_survey_manager/simulation_builder.py b/openfisca_survey_manager/simulation_builder.py deleted file mode 100644 index baf45057..00000000 --- a/openfisca_survey_manager/simulation_builder.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Re-export for backward compatibility. - -Prefer: from openfisca_survey_manager.policy import ... -or from openfisca_survey_manager.policy.simulation_builder import ... -""" - -import warnings - -from openfisca_survey_manager.policy.simulation_builder import ( - SimulationBuilder, - diagnose_variable_mismatch, -) - -warnings.warn( - "openfisca_survey_manager.simulation_builder is deprecated and will be removed in a future version. " - "Prefer: from openfisca_survey_manager.policy import ... " - "or from openfisca_survey_manager.policy.simulation_builder import ...", - DeprecationWarning, - stacklevel=2, -) - -__all__ = [ - "SimulationBuilder", - "diagnose_variable_mismatch", -] diff --git a/openfisca_survey_manager/simulations.py b/openfisca_survey_manager/simulations.py deleted file mode 100644 index 5d6479f1..00000000 --- a/openfisca_survey_manager/simulations.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Re-export for backward compatibility. - -Prefer: from openfisca_survey_manager.policy import ... -or from openfisca_survey_manager.policy.simulations import ... 
-""" - -import warnings - -from openfisca_survey_manager.policy.simulations import ( - SecretViolationError, - Simulation, - adaptative_calculate_variable, - assert_variables_in_same_entity, - compute_aggregate, - compute_pivot_table, - compute_quantiles, - compute_winners_losers, - create_data_frame_by_entity, - get_words, - inflate, - init_entity_data, - init_simulation, - init_variable_in_entity, - new_from_tax_benefit_system, - print_memory_usage, - set_weight_variable_by_entity, - summarize_variable, -) - -warnings.warn( - "openfisca_survey_manager.simulations is deprecated and will be removed in a future version. " - "Prefer: from openfisca_survey_manager.policy import ... " - "or from openfisca_survey_manager.policy.simulations import ...", - DeprecationWarning, - stacklevel=2, -) - -__all__ = [ - "SecretViolationError", - "Simulation", - "adaptative_calculate_variable", - "assert_variables_in_same_entity", - "compute_aggregate", - "compute_pivot_table", - "compute_quantiles", - "compute_winners_losers", - "create_data_frame_by_entity", - "get_words", - "inflate", - "init_entity_data", - "init_simulation", - "init_variable_in_entity", - "new_from_tax_benefit_system", - "print_memory_usage", - "set_weight_variable_by_entity", - "summarize_variable", -] diff --git a/openfisca_survey_manager/utils.py b/openfisca_survey_manager/utils.py deleted file mode 100644 index 3b3d6fa4..00000000 --- a/openfisca_survey_manager/utils.py +++ /dev/null @@ -1,76 +0,0 @@ -"""Utilities: re-exports from policy.legislation_asof + load_table (survey-dependent).""" - -import logging -import warnings -from typing import Optional - -import pandas as pd - -from openfisca_survey_manager.policy.legislation_asof import ( - asof, - do_nothing, - inflate_parameter_leaf, - inflate_parameters, - parameters_asof, - variables_asof, -) -from openfisca_survey_manager.survey_collections import SurveyCollection - -log = logging.getLogger(__name__) - -warnings.warn( - "openfisca_survey_manager.utils is deprecated for legislation helpers and will be removed in a future " - "version. Prefer: from openfisca_survey_manager.policy.legislation_asof import ...", - DeprecationWarning, - stacklevel=2, -) - -__all__ = [ - "asof", - "do_nothing", - "inflate_parameter_leaf", - "inflate_parameters", - "load_table", - "parameters_asof", - "variables_asof", -] - - -def load_table( - config_files_directory, - variables: Optional[list] = None, - collection: Optional[str] = None, - survey: Optional[str] = None, - input_data_survey_prefix: Optional[str] = None, - data_year=None, - table: Optional[str] = None, - batch_size=None, - batch_index=0, - filter_by=None, -) -> pd.DataFrame: - """ - Load values from table from a survey in a collection. - - Args: - config_files_directory : _description_. - variables (List, optional): List of the variables to retrieve in the table. - Defaults to None to get all the variables. - collection (str, optional): Collection. Defaults to None. - survey (str, optional): Survey. Defaults to None. - input_data_survey_prefix (str, optional): Prefix of the survey to be combined with data year. Defaults to None. - data_year (_type_, optional): Year of the survey data. Defaults to None. - table (str, optional): Table. Defaults to None. 
- - Returns: - pandas.DataFrame: A table with the retrieved variables - """ - survey_collection = SurveyCollection.load(collection=collection, config_files_directory=config_files_directory) - survey = survey if survey is not None else f"{input_data_survey_prefix}_{data_year}" - survey_ = survey_collection.get_survey(survey) - log.debug(f"Loading table {table} in survey {survey} from collection {collection}") - if batch_size: - return survey_.get_values( - table=table, variables=variables, batch_size=batch_size, batch_index=batch_index, filter_by=filter_by - ) - else: - return survey_.get_values(table=table, variables=variables, filter_by=filter_by) diff --git a/uv.lock b/uv.lock index ae2d22e0..5b3118ad 100644 --- a/uv.lock +++ b/uv.lock @@ -1272,7 +1272,7 @@ wheels = [ [[package]] name = "openfisca-survey-manager" -version = "6.5.0" +version = "1.0.0" source = { editable = "." } dependencies = [ { name = "chardet" }, From ef83dff583eec0d7c7ad540667f9bfb79d606626 Mon Sep 17 00:00:00 2001 From: Mahdi Ben Jelloul Date: Mon, 2 Mar 2026 22:30:33 +0100 Subject: [PATCH 12/18] chore: remove scenarios placeholders and DeprecationWarnings - Delete scenarios/abstract_scenario.py, scenarios/reform_scenario.py - Remove DeprecationWarning in core/table.py (HDF5), keep log.warning - build_collection: keep log.warning only (no DeprecationWarning) Made-with: Cursor --- openfisca_survey_manager/core/table.py | 7 ------- .../scenarios/abstract_scenario.py | 17 ----------------- .../scenarios/reform_scenario.py | 17 ----------------- .../scripts/build_collection.py | 9 --------- 4 files changed, 50 deletions(-) delete mode 100644 openfisca_survey_manager/scenarios/abstract_scenario.py delete mode 100644 openfisca_survey_manager/scenarios/reform_scenario.py diff --git a/openfisca_survey_manager/core/table.py b/openfisca_survey_manager/core/table.py index 683b14ec..256b3625 100644 --- a/openfisca_survey_manager/core/table.py +++ b/openfisca_survey_manager/core/table.py @@ -124,13 +124,6 @@ def _save( assert store_format in ["hdf5", "parquet"], f"invalid store_format: {store_format}" if store_format == "hdf5": - import warnings - - warnings.warn( - "HDF5 will no longer be the default format in a future version. Please use parquet format instead.", - DeprecationWarning, - stacklevel=3, - ) log.warning( "HDF5 will no longer be the default format in a future version. Please use parquet format instead." ) diff --git a/openfisca_survey_manager/scenarios/abstract_scenario.py b/openfisca_survey_manager/scenarios/abstract_scenario.py deleted file mode 100644 index 1e70c7dd..00000000 --- a/openfisca_survey_manager/scenarios/abstract_scenario.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Re-export for backward compatibility. - -Prefer: from openfisca_survey_manager.policy.scenarios.abstract_scenario import ... -""" - -import warnings - -from openfisca_survey_manager.policy.scenarios.abstract_scenario import AbstractSurveyScenario - -warnings.warn( - "openfisca_survey_manager.scenarios.abstract_scenario is deprecated and will be removed in a future version. " - "Prefer: from openfisca_survey_manager.policy.scenarios.abstract_scenario import ...", - DeprecationWarning, - stacklevel=2, -) - -__all__ = ["AbstractSurveyScenario"] diff --git a/openfisca_survey_manager/scenarios/reform_scenario.py b/openfisca_survey_manager/scenarios/reform_scenario.py deleted file mode 100644 index 029abb84..00000000 --- a/openfisca_survey_manager/scenarios/reform_scenario.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Re-export for backward compatibility. 
- -Prefer: from openfisca_survey_manager.policy.scenarios.reform_scenario import ... -""" - -import warnings - -from openfisca_survey_manager.policy.scenarios.reform_scenario import ReformScenario - -warnings.warn( - "openfisca_survey_manager.scenarios.reform_scenario is deprecated and will be removed in a future version. " - "Prefer: from openfisca_survey_manager.policy.scenarios.reform_scenario import ...", - DeprecationWarning, - stacklevel=2, -) - -__all__ = ["ReformScenario"] diff --git a/openfisca_survey_manager/scripts/build_collection.py b/openfisca_survey_manager/scripts/build_collection.py index 3e94da4c..6a2b4ca9 100755 --- a/openfisca_survey_manager/scripts/build_collection.py +++ b/openfisca_survey_manager/scripts/build_collection.py @@ -286,16 +286,7 @@ def main(): # Determine store format based on argument store_format = "parquet" if args.parquet else "hdf5" - # Deprecation warning for HDF5 format if not args.parquet: - import warnings - - warnings.warn( - "HDF5 will no longer be the default format in a future version. " - "Please use --parquet option to save data in parquet format.", - DeprecationWarning, - stacklevel=2, - ) log.warning( "HDF5 will no longer be the default format in a future version. " "Please use --parquet option to save data in parquet format." From 2d94af5c287d5d3498359b9bb686647672e4c3f5 Mon Sep 17 00:00:00 2001 From: Mahdi Ben Jelloul Date: Mon, 2 Mar 2026 23:00:32 +0100 Subject: [PATCH 13/18] fix: reduce pytest warnings (logging, HDF5, tests) - policy/simulations: log.warn -> log.warning (3), groupby(..., observed=False) - io/hdf: hdf5_safe_key() for PyTables NaturalNameWarning; to_hdf key= keyword - core/survey: use hdf5_safe_key when reading HDF5, backward compat fallback - tests: fix PytestReturnNotNoneWarning (assert instead of return) Made-with: Cursor --- openfisca_survey_manager/core/survey.py | 15 ++++++++++----- openfisca_survey_manager/io/hdf.py | 15 +++++++++++++-- openfisca_survey_manager/policy/simulations.py | 8 ++++---- .../policy/tests/test_compute_pivot_table.py | 3 ++- .../tests/test_calibration.py | 2 +- 5 files changed, 30 insertions(+), 13 deletions(-) diff --git a/openfisca_survey_manager/core/survey.py b/openfisca_survey_manager/core/survey.py index e4d3b723..186ae9ff 100644 --- a/openfisca_survey_manager/core/survey.py +++ b/openfisca_survey_manager/core/survey.py @@ -15,6 +15,7 @@ from openfisca_survey_manager.core.table import Table from openfisca_survey_manager.exceptions import SurveyIOError, SurveyManagerError +from openfisca_survey_manager.io.hdf import hdf5_safe_key from openfisca_survey_manager.processing.harmonization import harmonize_data_frame_columns if TYPE_CHECKING: @@ -196,19 +197,23 @@ def _get_values_from_hdf5(self, table: str, ignorecase: bool = False) -> tuple[p ) store = pandas.HDFStore(self.hdf5_file_path, "r") try: + # Use same key normalization as at write time (PyTables NaturalNameWarning) + hdf5_key = hdf5_safe_key(table) if ignorecase: keys = store.keys() - eligible_tables = [ - match[0] for string in keys for match in [re.findall(table, string, re.IGNORECASE)] if match - ] + eligible_tables = [k for k in keys if hdf5_safe_key(k.lstrip("/")).lower() == hdf5_key.lower()] if len(eligible_tables) > 1: raise SurveyManagerError( f"{table} is ambiguous since the following tables are available: {eligible_tables}" ) if len(eligible_tables) == 0: raise SurveyIOError(f"No eligible available table in {keys}") - table = eligible_tables[0] - df = store.select(table) + hdf5_key = eligible_tables[0].lstrip("/") + 
try: + df = store.select(hdf5_key) + except KeyError: + # Backward compat: try raw table name (old files may have keys with hyphens) + df = store.select(table) return df, table except KeyError: log.error("No table %s in the file %s", table, self.hdf5_file_path) diff --git a/openfisca_survey_manager/io/hdf.py b/openfisca_survey_manager/io/hdf.py index fcf0eba3..33d74ee7 100644 --- a/openfisca_survey_manager/io/hdf.py +++ b/openfisca_survey_manager/io/hdf.py @@ -3,12 +3,22 @@ from __future__ import annotations import logging +import re from typing import Any import pandas as pd log = logging.getLogger(__name__) +# PyTables / pandas-HDF5 require node names to match ^[a-zA-Z_][a-zA-Z0-9_]*$ +# to avoid NaturalNameWarning. We normalize table names (e.g. person_2017-01 -> person_2017_01). +_HDF5_SAFE_PATTERN = re.compile(r"[^a-zA-Z0-9_]") + + +def hdf5_safe_key(name: str) -> str: + """Return an HDF5 node name safe for PyTables (valid Python identifier).""" + return _HDF5_SAFE_PATTERN.sub("_", name) + def write_table_to_hdf5( data_frame: pd.DataFrame, @@ -22,8 +32,9 @@ def write_table_to_hdf5( Mirrors historical behavior from `tables.Table.save_data_frame_to_hdf5`. May mutate `data_frame` (type conversions) to ensure it can be written. """ + key = hdf5_safe_key(store_path) try: - data_frame.to_hdf(hdf5_file_path, store_path, append=False, **kwargs) + data_frame.to_hdf(hdf5_file_path, key=key, append=False, **kwargs) except (TypeError, NotImplementedError): log.info("Type problem(s) when creating %s in %s", store_path, hdf5_file_path) dtypes = data_frame.dtypes @@ -42,4 +53,4 @@ def write_table_to_hdf5( "The following types are added as category using the table format %s", dtypes[converted_dtypes], ) - data_frame.to_hdf(hdf5_file_path, store_path, append=False, format="table", **kwargs) + data_frame.to_hdf(hdf5_file_path, key=key, append=False, format="table", **kwargs) diff --git a/openfisca_survey_manager/policy/simulations.py b/openfisca_survey_manager/policy/simulations.py index 39c30cdb..fb9f3ffc 100644 --- a/openfisca_survey_manager/policy/simulations.py +++ b/openfisca_survey_manager/policy/simulations.py @@ -405,7 +405,7 @@ def compute_pivot_table( variables.add(weight_variable) else: - log.warn( + log.warning( f"There is no weight variable for entity {entity_key} nor alternative weights. " "Switch to unweighted" ) @@ -750,7 +750,7 @@ def compute_winners_losers( weight_variable = weight_variable_by_entity[entity_key] weight = baseline_simulation.calculate(weight_variable, period=period) else: - log.warn( + log.warning( f"There is no weight variable for entity {entity_key} nor alternative weights. Switch to unweighted" ) @@ -1249,7 +1249,7 @@ def init_variable_in_entity( if variable.definition_period == YEAR and period.unit == MONTH: # Some variables defined for a year are present in month/quarter dataframes # Cleaning the dataframe would probably be better in the long run - log.warn( + log.warning( f"Trying to set a monthly value for variable {variable_name}, which is defined on a year. " "The montly values you provided will be summed." 
) @@ -1434,7 +1434,7 @@ def summarize_variable( ) df = pd.DataFrame({variable: array}).replace(categories_by_index).astype(categories_type) df["weights"] = weights if weighted else 1 - groupby = df.groupby(variable)["weights"].sum() + groupby = df.groupby(variable, observed=False)["weights"].sum() total = groupby.sum() expr = [f" {index} = {row:.2e} ({row / total:.1%})" for index, row in groupby.items()] log.info("%s: %s.", period, ",".join(expr)) diff --git a/openfisca_survey_manager/policy/tests/test_compute_pivot_table.py b/openfisca_survey_manager/policy/tests/test_compute_pivot_table.py index 75706e21..4c20ff94 100644 --- a/openfisca_survey_manager/policy/tests/test_compute_pivot_table.py +++ b/openfisca_survey_manager/policy/tests/test_compute_pivot_table.py @@ -7,7 +7,7 @@ def test_compute_pivot_table(): survey_scenario = create_randomly_initialized_survey_scenario(reform=modify_social_security_taxation) period = "2017-01" - return survey_scenario.compute_pivot_table( + pivot_table = survey_scenario.compute_pivot_table( aggfunc="mean", columns=["age"], difference=False, @@ -22,3 +22,4 @@ def test_compute_pivot_table(): weighted=True, alternative_weights=None, ) + assert pivot_table is not None diff --git a/openfisca_survey_manager/tests/test_calibration.py b/openfisca_survey_manager/tests/test_calibration.py index d9f827a8..05f23fcb 100644 --- a/openfisca_survey_manager/tests/test_calibration.py +++ b/openfisca_survey_manager/tests/test_calibration.py @@ -164,4 +164,4 @@ def test_simulation_calibration_input_from_data(tmp_path): f"{simulation_name} weight_variable_by_entity does not match {weight_variable_by_entity}" ) assert (survey_scenario.calculate_series("household_weight", period, simulation=simulation_name) != 0).all() - return survey_scenario + assert survey_scenario is not None From 57442232f9c274865279ff1ffd7340181059568e Mon Sep 17 00:00:00 2001 From: Mahdi Ben Jelloul Date: Tue, 3 Mar 2026 11:37:24 +0100 Subject: [PATCH 14/18] Changelog: add typing section for 1.0.0 Made-with: Cursor --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 79014119..5520742b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ - **`load_table`** : déplacé de `utils` vers `openfisca_survey_manager.core.dataset` (et exporté depuis `core`). - Tous les imports internes ont été migrés vers `configuration.paths`, `configuration.models`, `core.dataset`, `core.survey`, `core.table`, `io.readers`, `processing.weights`, `common.misc`. +* **Typing** (no breaking API changes) + - **policy**: Add `py.typed` marker; type hints on `legislation_asof`, `variables`, `coicop`, `matching`, `statshelpers`, `calmar`, `calibration`, `simulation_builder`, `aggregates`, `simulations`, and scenarios (`abstract_scenario`, `reform_scenario`). + - **configuration**: Type hints on `Config.__init__` and `save` in `configuration.models`. + - **processing**: Type return of `__getattr__` in `processing/__init__.py`. 
+ # 6.5.0 * Typing (no breaking API changes) From daa597845079aee88ac5367eae8f44a397780a29 Mon Sep 17 00:00:00 2001 From: Mahdi Ben Jelloul Date: Tue, 3 Mar 2026 12:30:04 +0100 Subject: [PATCH 15/18] RFC-002: config.yaml + manifests, migration script, DeprecationWarning - Add config_loader (get_config_dir, load_config, load_manifest, manifest_survey_to_json) - SurveyCollection.load() tries config.yaml + manifest first, else legacy config.ini + JSON - Add migrate_config_to_rfc002 script (config.ini/raw_data.ini/JSON -> config.yaml + manifests) - Emit DeprecationWarning when loading via legacy config.ini + JSON - Add tests for RFC-002 and migration; legacy load test expects deprecation warning - Add docs/RFC-002-METADATA-AND-CONFIG.md Made-with: Cursor --- docs/RFC-002-METADATA-AND-CONFIG.md | 195 +++++++++++++++ .../configuration/__init__.py | 10 +- .../configuration/config_loader.py | 96 +++++++ openfisca_survey_manager/core/dataset.py | 59 ++++- openfisca_survey_manager/core/survey.py | 8 +- .../scripts/migrate_config_to_rfc002.py | 234 ++++++++++++++++++ .../tests/test_config_manifest_rfc002.py | 226 +++++++++++++++++ 7 files changed, 813 insertions(+), 15 deletions(-) create mode 100644 docs/RFC-002-METADATA-AND-CONFIG.md create mode 100644 openfisca_survey_manager/configuration/config_loader.py create mode 100644 openfisca_survey_manager/scripts/migrate_config_to_rfc002.py create mode 100644 openfisca_survey_manager/tests/test_config_manifest_rfc002.py diff --git a/docs/RFC-002-METADATA-AND-CONFIG.md b/docs/RFC-002-METADATA-AND-CONFIG.md new file mode 100644 index 00000000..12c671d9 --- /dev/null +++ b/docs/RFC-002-METADATA-AND-CONFIG.md @@ -0,0 +1,195 @@ +# RFC-002 : Architecture des métadonnées et de la configuration + +**Statut** : Implémenté (chargement config.yaml + manifest, compat legacy) +**Branche** : feature/backend +**Date** : 2025-01 + +--- + +## 1. Résumé + +Cette RFC propose une architecture **plus simple et plus standard** pour la gestion des métadonnées et des chemins dans openfisca-survey-manager, en s’appuyant sur les conventions XDG, un seul format de configuration par répertoire, et une structure de répertoires prévisible. Elle prévoit une **migration progressive** de l’existant. + +--- + +## 2. État actuel (à migrer) + +### 2.1 Où est la config ? + +Le répertoire de configuration (« config_files_directory ») est résolu dans `configuration/paths.py` par une **cascade de hacks** : + +| Priorité | Condition | Répertoire | +|----------|-----------|------------| +| 1 | Package `taxipp` importé et répertoire existe | `taxipp_install/.config/openfisca-survey-manager` | +| 2 | Package `openfisca_france_data` importé et répertoire existe | `BaseDirectory.save_config_path("openfisca-survey-manager")` → **~/.config/openfisca-survey-manager** | +| 3 | CI ou pytest | `openfisca_survey_manager/tests/data_files` | +| 4 | Fallback | `~/.config/openfisca-survey-manager` (XDG) | + +Problèmes : ordre dépendant des imports, écriture possible de `config.ini` dans les tests à l’import, assertion si le répertoire n’existe pas. + +### 2.2 Fichiers dans le répertoire de config + +Aujourd’hui, **deux INI** + des **JSON** externes : + +- **config.ini** (obligatoire dans le répertoire) + - `[collections]` : `collections_directory` + paires `nom_collection` = chemin vers un fichier JSON. + - `[data]` : `output_directory`, `tmp_directory` (et en tests `input_directory`). 
+- **raw_data.ini** (utilisé uniquement par le script `build-collection`)
+  - Une section par collection : `[nom_collection]`.
+  - Clés = noms d’enquêtes, valeurs = chemins vers répertoire/fichier de données brutes.
+- **Fichiers JSON** (un par collection, chemin dans config.ini ou sous `collections_directory`)
+  - Contenu : `name`, `label`, `surveys` : { `survey_name` → métadonnées du survey (tables, hdf5_file_path, parquet_file_path, **informations** dont `csv_files`, `sas_files`, etc.) }.
+
+Les métadonnées sont donc réparties entre : config.ini (où trouver les JSON), raw_data.ini (où sont les données brutes, seulement pour build), et les JSON (description des surveys, chemins de stockage, listes de fichiers sources). D’où redondance et deux formats INI différents.
+
+### 2.3 Utilisation dans le code
+
+- **SurveyCollection** : lit `Config(config_files_directory)` → `config.ini` ; get/set `collections` (nom → json_path) ; `config.get("data", "output_directory")` pour `fill_store`.
+- **Survey** : `informations` (dict) contient p.ex. `csv_files`, `sas_files` ; utilisé dans `fill_store` pour savoir quels fichiers lire.
+- **build_collection** : lit `raw_data.ini` pour savoir quels répertoires associer à quelles enquêtes, puis crée/met à jour la collection JSON et les données.
+
+---
+
+## 3. Proposition : architecture cible
+
+### 3.1 Principes
+
+1. **Un seul répertoire de configuration** : XDG uniquement par défaut, ou chemin explicite (variable d’environnement ou argument). Plus de résolution selon `taxipp` / `openfisca_france_data`.
+2. **Un seul fichier de config par répertoire** : tout ce qui est « config globale » (chemins de base, options) dans un seul fichier (voir 3.2).
+3. **Métadonnées des datasets au plus près des données** : un « dataset » (ex-collection) = un répertoire dédié avec un manifeste (metadata) à l’intérieur, plutôt qu’un JSON éclaté référencé par un INI.
+4. **Standard et lisible** : YAML ou INI clair pour la config ; YAML ou JSON pour les manifests (alignement possible avec RFC-001 Data Stack).
+
+### 3.2 Répertoire de configuration (XDG)
+
+**Emplacement par défaut** : `$XDG_CONFIG_HOME/openfisca-survey-manager/` (sinon `~/.config/openfisca-survey-manager/`).
+
+Contenu proposé :
+
+```
+~/.config/openfisca-survey-manager/
+├── config.yaml      # unique fichier de config (remplace config.ini + raw_data.ini pour la partie “où sont les choses”)
+```
+
+**config.yaml** (exemple) :
+
+```yaml
+# Répertoire où sont stockées les collections/datasets (manifests + données dérivées)
+collections_dir: ~/.local/share/openfisca-survey-manager/collections
+
+# Répertoire de sortie par défaut pour build / fill_store (optionnel, peut être overridé par dataset)
+default_output_dir: ~/.local/share/openfisca-survey-manager/output
+
+# Répertoire temporaire (optionnel)
+tmp_dir: /tmp/openfisca-survey-manager
+```
+
+Alternative si on garde l’INI : un seul **config.ini** avec des sections claires, p.ex. :
+
+```ini
+[paths]
+collections_dir = ~/.local/share/openfisca-survey-manager/collections
+default_output_dir = ~/.local/share/openfisca-survey-manager/output
+tmp_dir = /tmp/openfisca-survey-manager
+```
+
+On supprime : `[collections]` avec une entrée par collection (les manifests seront dans chaque dataset, voir 3.3). On supprime **raw_data.ini** : les sources brutes seront décrites dans le manifest du dataset.
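+
+Pour fixer les idées, voici une esquisse minimale (non normative) du chargement de ce `config.yaml` : résolution du répertoire de configuration (argument explicite, puis variable d’environnement, puis XDG, cf. 3.4), puis lecture YAML avec expansion des chemins. Les noms `resolve_config_dir` et `read_config_yaml` sont purement illustratifs ; l’implémentation de référence est `get_config_dir` / `load_config` dans `configuration/config_loader.py`.
+
+```python
+import os
+from pathlib import Path
+
+import yaml
+
+
+def resolve_config_dir(explicit: str | None = None) -> Path:
+    # Ordre de résolution proposé : argument explicite,
+    # puis $OPENFISCA_SURVEY_CONFIG_DIR, puis XDG.
+    if explicit:
+        return Path(explicit).expanduser().resolve()
+    env = os.environ.get("OPENFISCA_SURVEY_CONFIG_DIR")
+    if env:
+        return Path(env).expanduser().resolve()
+    xdg = os.environ.get("XDG_CONFIG_HOME", "~/.config")
+    return (Path(xdg).expanduser() / "openfisca-survey-manager").resolve()
+
+
+def read_config_yaml(config_dir: Path) -> dict | None:
+    # None si config.yaml est absent : l'appelant retombe alors sur le legacy.
+    config_path = config_dir / "config.yaml"
+    if not config_path.is_file():
+        return None
+    data = yaml.safe_load(config_path.read_text()) or {}
+    return {key: Path(str(value)).expanduser().resolve() for key, value in data.items()}
+```
+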
+ +### 3.3 Structure d’un dataset (ex-collection) + +Un dataset = un répertoire sous `collections_dir` (ou chemin absolu configuré), avec un **manifeste** à l’intérieur : + +``` +collections_dir/ +└── erfs/ + ├── manifest.yaml # métadonnées du dataset + liste des surveys + ├── erfs_2019/ # (optionnel) données dérivées par survey + │ ├── data.parquet + │ └── ... + └── erfs_2020/ + └── ... +``` + +**manifest.yaml** (exemple) : + +```yaml +name: erfs +label: "Enquête Revenus Fiscaux et Sociaux" + +# Par survey : sources brutes (remplace raw_data.ini + informations) +surveys: + erfs_2019: + label: "ERFS 2019" + source: + format: sas # ou csv, stata, parquet + path: /data/erfs/2019 # répertoire ou fichier + # optionnel : chemins de sortie relatifs au dataset + output_subdir: erfs_2019 + + erfs_2020: + label: "ERFS 2020" + source: + format: parquet + path: /data/erfs/2020 + output_subdir: erfs_2020 +``` + +Cela remplace : la section `[erfs]` de raw_data.ini + la partie « informations » (csv_files, sas_files, …) dans le JSON de collection. Un seul endroit pour « où sont les données brutes » et « où écrire les sorties ». + +Pour la rétrocompatibilité, on peut prévoir un **adaptateur** qui lit l’ancien JSON + raw_data.ini et produit (ou expose) un équivalent manifest. + +### 3.4 Résolution du répertoire de config (simplifiée) + +- **Valeur explicite** : toujours possible de passer `config_dir` (ou `config_files_directory`) en argument aux APIs et au CLI. +- **Par défaut** : `os.environ.get("OPENFISCA_SURVEY_CONFIG_DIR")` ou `xdg_config_home() / "openfisca-survey-manager"`. +- **Tests** : répertoire dédié (ex. `tests/data_files`) fourni explicitement par les tests ; plus d’effet de bord à l’import (plus d’écriture de config.ini au chargement de `paths`). + +On **ne** résout plus le répertoire en fonction de la présence de `taxipp` ou `openfisca_france_data`. Les projets (france-data, taxipp) peuvent : +- soit définir `OPENFISCA_SURVEY_CONFIG_DIR` vers leur répertoire, +- soit passer le chemin de config à chaque appel. + +### 3.5 API cible (alignement RFC-001) + +- Charger un dataset par nom : `DataManager.load("erfs", config_dir=...)` → lit `collections_dir/erfs/manifest.yaml` et les données associées. +- Accès aux métadonnées : `dataset.metadata` (provenant du manifest), `dataset.schema` (si on l’expose), chemins dérivés déterministes à partir de `collections_dir` + `name` + `output_subdir`. + +On garde une compatibilité avec l’API actuelle « SurveyCollection.load(collection=...) » pendant la transition, en faisant que cette API s’appuie en interne sur la nouvelle config + manifests (éventuellement via un bridge depuis l’ancien JSON). + +--- + +## 4. Migration de l’existant + +### 4.1 Conserver l’existant en parallèle + +- Garder la lecture de **config.ini** et **raw_data.ini** tant que la nouvelle config n’est pas présente. +- Si `config.yaml` (ou le nouveau config.ini [paths]) existe dans le répertoire de config : utiliser la nouvelle structure (manifests sous `collections_dir`). +- Sinon : comportement actuel (config.ini [collections] + [data], raw_data.ini, JSON externes). 
+ +### 4.2 Script de migration + +Un script permet de migrer l’existant vers la nouvelle structure : + +- **Emplacement** : `openfisca_survey_manager.scripts.migrate_config_to_rfc002` +- **Usage** : + ```bash + python -m openfisca_survey_manager.scripts.migrate_config_to_rfc002 [--config-dir PATH] [--dry-run] [-v] + ``` +- **Comportement** : lit `config.ini` ([collections] + [data]) et, si présent, `raw_data.ini` ; pour chaque collection, charge le JSON, déduit `source.format` et `source.path` à partir de `informations` (csv_files, sas_files, etc.) ou de la section correspondante de raw_data.ini ; crée `config.yaml` et `collections_dir//manifest.yaml` pour chaque collection. Avec `--dry-run`, n’écrit aucun fichier. +- **Répertoire de config par défaut** : celui retourné par `get_config_dir()` (env `OPENFISCA_SURVEY_CONFIG_DIR` ou XDG). On peut imposer un répertoire avec `--config-dir`. + +### 4.3 Dépréciation + +- À terme : annoncer comme dépréciés `config.ini` [collections] (mapping nom → JSON), `raw_data.ini`, et les JSON de collection « à l’ancienne ». Documenter la migration dans MIGRATION_IMPORTS.md ou un nouveau MIGRATION_CONFIG.md. + +--- + +## 5. Résumé des changements proposés + +| Actuel | Cible | +|--------|--------| +| Résolution config par taxipp / france_data / CI / XDG | XDG ou env `OPENFISCA_SURVEY_CONFIG_DIR` ou argument explicite | +| config.ini + raw_data.ini | Un seul fichier (config.yaml ou config.ini [paths]) | +| JSON de collection hors répertoire, référencé par config | Manifest (YAML/JSON) par dataset dans `collections_dir//manifest.yaml` | +| Sources brutes dans raw_data.ini + informations (JSON) | Sources dans le manifest du dataset (`surveys.*.source`) | +| Écriture config.ini au chargement des paths (tests) | Plus d’écriture à l’import ; tests passent un config_dir explicite | + +Cela donne une architecture **plus simple** (un format, un lieu par dataset), **plus standard** (XDG, chemins explicites), et **migrable** en gardant l’ancien comportement tant que la nouvelle config n’est pas en place. diff --git a/openfisca_survey_manager/configuration/__init__.py b/openfisca_survey_manager/configuration/__init__.py index b8cab758..137908e1 100644 --- a/openfisca_survey_manager/configuration/__init__.py +++ b/openfisca_survey_manager/configuration/__init__.py @@ -1,6 +1,11 @@ # Config and paths; config.py and paths.py re-export for compatibility. -# See docs/REFACTORING_PLAN.md. +# See docs/REFACTORING_PLAN.md. RFC-002: config_loader for config.yaml + manifest. 
+from openfisca_survey_manager.configuration.config_loader import ( + get_config_dir, + load_config, + load_manifest, +) from openfisca_survey_manager.configuration.models import Config from openfisca_survey_manager.configuration.paths import ( config_ini, @@ -15,7 +20,10 @@ "Config", "config_ini", "default_config_files_directory", + "get_config_dir", "is_in_ci", + "load_config", + "load_manifest", "openfisca_survey_manager_location", "private_run_with_data", "test_config_files_directory", diff --git a/openfisca_survey_manager/configuration/config_loader.py b/openfisca_survey_manager/configuration/config_loader.py new file mode 100644 index 00000000..75688214 --- /dev/null +++ b/openfisca_survey_manager/configuration/config_loader.py @@ -0,0 +1,96 @@ +"""RFC-002: New config and manifest loading (YAML-based).""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any, Optional + +import yaml +from xdg import BaseDirectory + +log = logging.getLogger(__name__) + +CONFIG_FILENAME = "config.yaml" +MANIFEST_FILENAME = "manifest.yaml" +ENV_CONFIG_DIR = "OPENFISCA_SURVEY_CONFIG_DIR" + + +def get_config_dir(explicit: Optional[Path | str] = None) -> Path: + """Return config directory: explicit path, or env OPENFISCA_SURVEY_CONFIG_DIR, or XDG.""" + if explicit is not None: + return Path(explicit).expanduser().resolve() + import os + + env_path = os.environ.get(ENV_CONFIG_DIR) + if env_path: + return Path(env_path).expanduser().resolve() + return Path(BaseDirectory.save_config_path("openfisca-survey-manager")) + + +def load_config(config_dir: Path) -> Optional[dict[str, Any]]: + """ + Load new-style config from config_dir/config.yaml. + Returns dict with collections_dir, default_output_dir, tmp_dir (paths expanded), + or None if config.yaml is missing or invalid. + """ + config_path = config_dir / CONFIG_FILENAME + if not config_path.is_file(): + return None + try: + with config_path.open() as f: + data = yaml.safe_load(f) + except Exception as e: + log.warning("Failed to load %s: %s", config_path, e) + return None + if not data or not isinstance(data, dict): + return None + collections_dir = data.get("collections_dir") + if not collections_dir: + return None + out = { + "collections_dir": Path(collections_dir).expanduser().resolve(), + "default_output_dir": Path(data.get("default_output_dir", ".")).expanduser().resolve(), + "tmp_dir": Path(data.get("tmp_dir", "/tmp")).expanduser().resolve(), + } + return out + + +def load_manifest(collections_dir: Path, name: str) -> Optional[dict[str, Any]]: + """ + Load dataset manifest from collections_dir/name/manifest.yaml. + Returns manifest dict (name, label, surveys) or None if missing. + """ + manifest_path = collections_dir / name / MANIFEST_FILENAME + if not manifest_path.is_file(): + return None + try: + with manifest_path.open() as f: + data = yaml.safe_load(f) + except Exception as e: + log.warning("Failed to load manifest %s: %s", manifest_path, e) + return None + if not data or not isinstance(data, dict) or "surveys" not in data: + return None + return data + + +def manifest_survey_to_json(survey_name: str, entry: dict[str, Any]) -> dict[str, Any]: + """ + Convert a manifest survey entry to the dict shape expected by Survey.create_from_json. + entry: { label?, source: { format, path }, output_subdir? } + """ + source = entry.get("source") or {} + fmt = source.get("format", "csv") + path = source.get("path", "") + # Survey expects e.g. 
csv_files, sas_files list in informations + files_key = f"{fmt}_files" + informations = {files_key: [path] if path else []} + return { + "name": survey_name, + "label": entry.get("label", survey_name), + "hdf5_file_path": None, + "parquet_file_path": None, + "tables": entry.get("tables"), + "informations": informations, + } diff --git a/openfisca_survey_manager/core/dataset.py b/openfisca_survey_manager/core/dataset.py index 99519d80..322ccdd2 100644 --- a/openfisca_survey_manager/core/dataset.py +++ b/openfisca_survey_manager/core/dataset.py @@ -7,11 +7,17 @@ import configparser import json import logging +import warnings from pathlib import Path from typing import List, Optional, Union import pandas as pd +from openfisca_survey_manager.configuration.config_loader import ( + load_config, + load_manifest, + manifest_survey_to_json, +) from openfisca_survey_manager.configuration.models import Config from openfisca_survey_manager.configuration.paths import default_config_files_directory from openfisca_survey_manager.core.survey import Survey @@ -28,6 +34,7 @@ class SurveyCollection: json_file_path: Optional[str] = None surveys: List[Survey] # set in __init__ config: Optional[Config] = None + output_directory: Optional[str] = None # RFC-002: used when config is None (manifest-based) def __init__( self, @@ -73,19 +80,16 @@ def dump( config_files_directory: Optional[Union[Path, str]] = None, json_file_path: Optional[str] = None, ) -> None: - if self.config is not None: - config = self.config - else: - if config_files_directory is not None: - pass - else: - config_files_directory = default_config_files_directory - self.config = Config(config_files_directory=config_files_directory) + if json_file_path is not None: + self.json_file_path = json_file_path - if json_file_path is None: + if self.config is None: + # RFC-002: manifest-based collection; no config.ini to update + return + + config = self.config + if self.json_file_path is None: assert self.json_file_path is not None, "A json_file_path should be provided" - else: - self.json_file_path = json_file_path config.set("collections", self.name, str(self.json_file_path)) config.save() @@ -132,7 +136,38 @@ def load( collection: Optional[str] = None, config_files_directory: Optional[Union[Path, str]] = default_config_files_directory, ) -> SurveyCollection: - assert Path(config_files_directory).exists() + config_dir = Path(config_files_directory).expanduser().resolve() + assert config_dir.exists(), f"Config directory does not exist: {config_dir}" + + # RFC-002: try new config.yaml + manifest first + new_cfg = load_config(config_dir) + if json_file_path is None and collection is not None and new_cfg is not None: + manifest = load_manifest(new_cfg["collections_dir"], collection) + if manifest is not None: + self = cls.__new__(cls) + self.name = manifest.get("name", collection) + self.label = manifest.get("label", self.name) + self.json_file_path = str(new_cfg["collections_dir"] / collection / "manifest.yaml") + self.config = None + self.output_directory = str(new_cfg["default_output_dir"]) + self.surveys = [] + for survey_name, entry in manifest.get("surveys", {}).items(): + survey_json = manifest_survey_to_json(survey_name, entry) + survey = Survey(name=survey_name) + survey = survey.create_from_json(survey_json) + survey.survey_collection = self + self.surveys.append(survey) + return self + + # Legacy: config.ini + JSON + warnings.warn( + "Loading collections from config.ini and JSON files is deprecated. 
" + "Migrate to config.yaml and manifest.yaml using: " + "python -m openfisca_survey_manager.scripts.migrate_config_to_rfc002 --config-dir " + "See docs/RFC-002-METADATA-AND-CONFIG.md.", + DeprecationWarning, + stacklevel=2, + ) config = Config(config_files_directory=config_files_directory) if json_file_path is None: assert collection is not None, "A collection is needed" diff --git a/openfisca_survey_manager/core/survey.py b/openfisca_survey_manager/core/survey.py index 186ae9ff..18062309 100644 --- a/openfisca_survey_manager/core/survey.py +++ b/openfisca_survey_manager/core/survey.py @@ -116,8 +116,12 @@ def fill_store( assert self.survey_collection is not None assert isinstance(overwrite, (bool, list)) survey = self - config = survey.survey_collection.config - directory_path = config.get("data", "output_directory") + sc = survey.survey_collection + if sc.config is not None: + directory_path = sc.config.get("data", "output_directory") + else: + directory_path = getattr(sc, "output_directory", None) + assert directory_path is not None, "SurveyCollection has no config and no output_directory; cannot fill_store" if not Path(directory_path).is_dir(): log.warning( f"{directory_path} who should be the store data directory does not exist: we create the directory" diff --git a/openfisca_survey_manager/scripts/migrate_config_to_rfc002.py b/openfisca_survey_manager/scripts/migrate_config_to_rfc002.py new file mode 100644 index 00000000..426e5ae7 --- /dev/null +++ b/openfisca_survey_manager/scripts/migrate_config_to_rfc002.py @@ -0,0 +1,234 @@ +#!/usr/bin/env python +""" +Migrate existing config (config.ini + raw_data.ini + JSON collections) to RFC-002 layout. + +Produces: + - config.yaml (collections_dir, default_output_dir, tmp_dir) + - collections_dir//manifest.yaml per collection + +Usage: + python -m openfisca_survey_manager.scripts.migrate_config_to_rfc002 [--config-dir PATH] [--dry-run] +""" + +from __future__ import annotations + +import argparse +import configparser +import json +import logging +import sys +from pathlib import Path + +import yaml + +# Allow running as __main__ or as script +try: + from openfisca_survey_manager.configuration.config_loader import ( + CONFIG_FILENAME, + MANIFEST_FILENAME, + ) +except ImportError: + CONFIG_FILENAME = "config.yaml" + MANIFEST_FILENAME = "manifest.yaml" + +log = logging.getLogger(__name__) + +SOURCE_FORMAT_KEYS = ("csv_files", "sas_files", "stata_files", "parquet_files") + + +def _informations_to_source(informations: dict) -> tuple[str, str]: + """From Survey.informations (e.g. csv_files, sas_files), return (format, path).""" + if not informations: + return "csv", "" + for key in SOURCE_FORMAT_KEYS: + paths = informations.get(key) + if paths and isinstance(paths, list) and len(paths) > 0: + fmt = key.replace("_files", "") + path = paths[0] if isinstance(paths[0], str) else str(paths[0]) + return fmt, path + return "csv", "" + + +def build_manifest_from_json( + json_path: Path, + raw_data_section: dict[str, str] | None = None, +) -> dict: + """ + Build RFC-002 manifest dict from a legacy collection JSON file. + raw_data_section: optional dict survey_name -> path from raw_data.ini [collection_name]. 
+ """ + with json_path.open(encoding="utf-8") as f: + data = json.load(f) + name = data.get("name", json_path.stem) + label = data.get("label", name) + surveys_data = data.get("surveys", {}) + if not isinstance(surveys_data, dict): + surveys_data = {} + surveys = {} + for survey_name, survey_obj in surveys_data.items(): + if not isinstance(survey_obj, dict): + continue + infos = survey_obj.get("informations", {}) or {} + if raw_data_section and survey_name in raw_data_section: + path = raw_data_section[survey_name] + fmt = "csv" + for k in SOURCE_FORMAT_KEYS: + if infos.get(k): + fmt = k.replace("_files", "") + break + else: + fmt, path = _informations_to_source(infos) + surveys[survey_name] = { + "label": survey_obj.get("label", survey_name), + "source": {"format": fmt, "path": path}, + } + if survey_obj.get("output_subdir"): + surveys[survey_name]["output_subdir"] = survey_obj["output_subdir"] + return {"name": name, "label": label, "surveys": surveys} + + +def load_raw_data_ini(config_dir: Path) -> configparser.ConfigParser | None: + """Load raw_data.ini if present.""" + path = config_dir / "raw_data.ini" + if not path.is_file(): + return None + parser = configparser.ConfigParser() + parser.read(path, encoding="utf-8") + return parser + + +def migrate( + config_dir: Path, + *, + dry_run: bool = False, +) -> bool: + """ + Migrate config_dir from config.ini (+ raw_data.ini + JSON) to config.yaml + manifests. + Returns True if migration was done (or dry_run succeeded). + """ + config_ini = config_dir / "config.ini" + if not config_ini.is_file(): + log.error("No config.ini found in %s", config_dir) + return False + + parser = configparser.ConfigParser() + parser.read(config_ini, encoding="utf-8") + if "collections" not in parser.sections(): + log.error("config.ini has no [collections] section") + return False + + collections_dir_str = parser.get("collections", "collections_directory", fallback=None) + if not collections_dir_str: + collections_dir_str = str(config_dir / "collections") + collections_dir = Path(collections_dir_str).expanduser().resolve() + + output_dir = parser.get("data", "output_directory", fallback=str(config_dir / "output")) + tmp_dir = parser.get("data", "tmp_directory", fallback="/tmp") + if "data" not in parser.sections(): + output_dir = str(config_dir / "output") + tmp_dir = "/tmp" + + raw_data = load_raw_data_ini(config_dir) + collection_names: list[str] = [] + for key in parser.options("collections"): + if key == "collections_directory": + continue + collection_names.append(key) + + if not collection_names: + log.warning("No collection entries in config.ini (only collections_directory)") + # Still write config.yaml so the dir is ready for new-style use + else: + if not dry_run: + collections_dir.mkdir(parents=True, exist_ok=True) + for name in collection_names: + try: + json_path_str = parser.get("collections", name) + except configparser.NoOptionError: + continue + json_path = Path(json_path_str).expanduser().resolve() + if not json_path.is_file(): + log.warning("Collection %s: JSON file not found %s", name, json_path) + continue + raw_section = None + if raw_data and raw_data.has_section(name): + raw_section = dict(raw_data.items(name)) + manifest = build_manifest_from_json(json_path, raw_section) + manifest_path = collections_dir / name / MANIFEST_FILENAME + if dry_run: + log.info("[dry-run] Would write %s", manifest_path) + continue + manifest_path.parent.mkdir(parents=True, exist_ok=True) + with manifest_path.open("w", encoding="utf-8") as f: + yaml.safe_dump( + 
manifest, + f, + default_flow_style=False, + allow_unicode=True, + sort_keys=False, + ) + log.info("Wrote %s", manifest_path) + + config_yaml_path = config_dir / CONFIG_FILENAME + new_config = { + "collections_dir": str(collections_dir), + "default_output_dir": str(Path(output_dir).expanduser().resolve()), + "tmp_dir": str(Path(tmp_dir).expanduser().resolve()), + } + if dry_run: + log.info("[dry-run] Would write %s with %s", config_yaml_path, new_config) + return True + with config_yaml_path.open("w", encoding="utf-8") as f: + yaml.safe_dump(new_config, f, default_flow_style=False, sort_keys=False) + log.info("Wrote %s", config_yaml_path) + return True + + +def main() -> int: + ap = argparse.ArgumentParser( + description="Migrate config.ini + raw_data.ini + JSON to RFC-002 (config.yaml + manifests)", + ) + ap.add_argument( + "--config-dir", + type=Path, + default=None, + help="Directory containing config.ini (default: XDG or OPENFISCA_SURVEY_CONFIG_DIR)", + ) + ap.add_argument( + "--dry-run", + action="store_true", + help="Only log what would be done", + ) + ap.add_argument( + "-v", + "--verbose", + action="store_true", + help="Verbose logging", + ) + args = ap.parse_args() + logging.basicConfig( + level=logging.DEBUG if args.verbose else logging.INFO, + format="%(message)s", + stream=sys.stdout, + ) + if args.config_dir is None: + try: + from openfisca_survey_manager.configuration.config_loader import ( + get_config_dir, + ) + + config_dir = get_config_dir() + except Exception: + log.error("Provide --config-dir or set OPENFISCA_SURVEY_CONFIG_DIR") + return 1 + else: + config_dir = args.config_dir.expanduser().resolve() + if not config_dir.is_dir(): + log.error("Config directory does not exist: %s", config_dir) + return 1 + ok = migrate(config_dir, dry_run=args.dry_run) + return 0 if ok else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/openfisca_survey_manager/tests/test_config_manifest_rfc002.py b/openfisca_survey_manager/tests/test_config_manifest_rfc002.py new file mode 100644 index 00000000..ae46f0ee --- /dev/null +++ b/openfisca_survey_manager/tests/test_config_manifest_rfc002.py @@ -0,0 +1,226 @@ +"""Tests for RFC-002: config.yaml and manifest.yaml (new metadata architecture).""" + +from pathlib import Path + +import pytest + +from openfisca_survey_manager.configuration.config_loader import ( + get_config_dir, + load_config, + load_manifest, + manifest_survey_to_json, +) +from openfisca_survey_manager.configuration.paths import openfisca_survey_manager_location +from openfisca_survey_manager.core.dataset import SurveyCollection + + +@pytest.fixture +def rfc002_config_dir(tmp_path): + """Create a config dir with config.yaml and a dataset with manifest.yaml.""" + config_dir = tmp_path / "config" + config_dir.mkdir() + (config_dir / "config.yaml").write_text( + """ +collections_dir: {collections} +default_output_dir: {output} +tmp_dir: {tmp} +""".format( + collections=tmp_path / "collections", + output=tmp_path / "output", + tmp=tmp_path / "tmp", + ) + ) + collections_dir = tmp_path / "collections" + collections_dir.mkdir() + dataset_dir = collections_dir / "test_dataset" + dataset_dir.mkdir() + (dataset_dir / "manifest.yaml").write_text( + """ +name: test_dataset +label: "Test dataset (RFC-002)" + +surveys: + survey_a: + label: "Survey A" + source: + format: csv + path: /data/survey_a + survey_b: + label: "Survey B" + source: + format: sas + path: /data/survey_b +""" + ) + return config_dir + + +def test_get_config_dir_explicit(tmp_path): + assert 
get_config_dir(tmp_path) == tmp_path.resolve() + + +def test_get_config_dir_env(monkeypatch, tmp_path): + monkeypatch.setenv("OPENFISCA_SURVEY_CONFIG_DIR", str(tmp_path)) + assert get_config_dir() == tmp_path.resolve() + + +def test_load_config_missing(tmp_path): + assert load_config(tmp_path) is None + + +def test_load_config_present(rfc002_config_dir): + cfg = load_config(rfc002_config_dir) + assert cfg is not None + assert "collections_dir" in cfg + assert "default_output_dir" in cfg + assert cfg["collections_dir"].is_dir() + assert (cfg["collections_dir"] / "test_dataset").is_dir() + + +def test_load_manifest_missing(tmp_path): + assert load_manifest(tmp_path, "nonexistent") is None + + +def test_load_manifest_present(rfc002_config_dir): + cfg = load_config(rfc002_config_dir) + assert cfg is not None + manifest = load_manifest(cfg["collections_dir"], "test_dataset") + assert manifest is not None + assert manifest["name"] == "test_dataset" + assert manifest["label"] == "Test dataset (RFC-002)" + assert "survey_a" in manifest["surveys"] + assert manifest["surveys"]["survey_a"]["source"]["format"] == "csv" + assert manifest["surveys"]["survey_a"]["source"]["path"] == "/data/survey_a" + + +def test_manifest_survey_to_json(): + entry = { + "label": "My survey", + "source": {"format": "sas", "path": "/path/to/data"}, + } + out = manifest_survey_to_json("my_survey", entry) + assert out["name"] == "my_survey" + assert out["label"] == "My survey" + assert out["informations"]["sas_files"] == ["/path/to/data"] + + +def test_survey_collection_load_from_manifest(rfc002_config_dir): + """SurveyCollection.load(collection=..., config_files_directory=...) uses manifest when config.yaml exists.""" + col = SurveyCollection.load( + collection="test_dataset", + config_files_directory=rfc002_config_dir, + ) + assert col.name == "test_dataset" + assert col.label == "Test dataset (RFC-002)" + assert col.config is None + assert col.output_directory is not None + assert len(col.surveys) == 2 + names = {s.name for s in col.surveys} + assert names == {"survey_a", "survey_b"} + survey_a = col.get_survey("survey_a") + assert survey_a.label == "Survey A" + assert survey_a.informations.get("csv_files") == ["/data/survey_a"] + + +def test_survey_collection_load_legacy_unchanged(tmp_path): + """Legacy config.ini + JSON still works when config.yaml is absent (emits DeprecationWarning).""" + # Use the package test data dir which has config.ini and fake.json + tests_data = Path(openfisca_survey_manager_location) / "openfisca_survey_manager" / "tests" / "data_files" + if not (tests_data / "config.ini").exists(): + pytest.skip("config.ini not present in tests/data_files") + if not (tests_data / "fake.json").exists(): + pytest.skip("fake.json not present in tests/data_files") + with pytest.warns(DeprecationWarning, match="config.ini and JSON files is deprecated"): + col = SurveyCollection.load( + collection="fake", + config_files_directory=tests_data, + ) + assert col.config is not None + assert col.name == "fake" + assert len(col.surveys) >= 0 + + +# --- Migration script tests --- + + +@pytest.fixture +def legacy_config_dir(tmp_path): + """Create a minimal legacy config dir: config.ini + one collection JSON.""" + config_dir = tmp_path / "legacy_config" + config_dir.mkdir() + collections_dir = tmp_path / "legacy_collections" + collections_dir.mkdir() + json_path = collections_dir / "my_collection.json" + json_path.write_text( + """ +{ + "name": "my_collection", + "label": "My collection", + "surveys": { + "survey_1": { + 
"label": "Survey 1", + "informations": { + "csv_files": ["/data/s1/file1.csv"] + } + }, + "survey_2": { + "label": "Survey 2", + "informations": { + "sas_files": ["/data/s2/file.sas7bdat"] + } + } + } +} +""", + encoding="utf-8", + ) + config_ini = config_dir / "config.ini" + config_ini.write_text( + f"""[collections] +collections_directory = {collections_dir} +my_collection = {json_path} + +[data] +output_directory = {tmp_path / "output"} +tmp_directory = /tmp +""", + encoding="utf-8", + ) + return config_dir + + +def test_migrate_produces_config_yaml_and_manifests(legacy_config_dir): + """Migration script creates config.yaml and collection manifest.""" + from openfisca_survey_manager.scripts.migrate_config_to_rfc002 import ( + CONFIG_FILENAME, + MANIFEST_FILENAME, + migrate, + ) + + ok = migrate(legacy_config_dir, dry_run=False) + assert ok is True + config_yaml = legacy_config_dir / CONFIG_FILENAME + assert config_yaml.is_file() + cfg = load_config(legacy_config_dir) + assert cfg is not None + assert "collections_dir" in cfg + assert (Path(cfg["collections_dir"]) / "my_collection" / MANIFEST_FILENAME).is_file() + manifest = load_manifest(cfg["collections_dir"], "my_collection") + assert manifest is not None + assert manifest["name"] == "my_collection" + assert manifest["surveys"]["survey_1"]["source"]["format"] == "csv" + assert manifest["surveys"]["survey_1"]["source"]["path"] == "/data/s1/file1.csv" + assert manifest["surveys"]["survey_2"]["source"]["format"] == "sas" + assert manifest["surveys"]["survey_2"]["source"]["path"] == "/data/s2/file.sas7bdat" + + +def test_migrate_dry_run_does_not_write(legacy_config_dir): + """Migration with --dry-run does not create files.""" + from openfisca_survey_manager.scripts.migrate_config_to_rfc002 import ( + CONFIG_FILENAME, + migrate, + ) + + ok = migrate(legacy_config_dir, dry_run=True) + assert ok is True + assert not (legacy_config_dir / CONFIG_FILENAME).is_file() From 9d494cdb35325e18d03ea42f118bc17e5fed38a3 Mon Sep 17 00:00:00 2001 From: Mahdi Ben Jelloul Date: Tue, 3 Mar 2026 14:31:08 +0100 Subject: [PATCH 16/18] feat: store backends (HDF5/Parquet/Zarr), store_format in manifest, migration (v1.1.0) - io/backends: backend registry (hdf5, parquet, zarr), get_backend, register_backend - Survey: zarr_file_path, fill_store/get_values for zarr; build-collection --zarr - Table: delegate write/read to backends via _get_store_path_and_format - Manifest: store_format (hdf5|parquet|zarr) at dataset level; load applies it and sets store paths - Migration script: infer store_format from legacy JSON and write in manifest - Docs: ZARR-BACKEND.md, RFC-002 store_format example and migration note - Changelog 1.1.0, pyproject 1.1.0 Made-with: Cursor --- CHANGELOG.md | 16 ++ docs/RFC-002-METADATA-AND-CONFIG.md | 22 +- docs/ZARR-BACKEND.md | 126 +++++++++ .../configuration/config_loader.py | 1 + openfisca_survey_manager/core/dataset.py | 9 + openfisca_survey_manager/core/survey.py | 34 ++- openfisca_survey_manager/core/table.py | 69 +++-- openfisca_survey_manager/io/backends.py | 227 +++++++++++++++ .../scripts/build_collection.py | 22 +- .../scripts/migrate_config_to_rfc002.py | 20 +- .../tests/test_config_manifest_rfc002.py | 76 +++++ .../tests/test_store_backends.py | 50 ++++ pyproject.toml | 5 +- uv.lock | 263 +++++++++++++++++- 14 files changed, 905 insertions(+), 35 deletions(-) create mode 100644 docs/ZARR-BACKEND.md create mode 100644 openfisca_survey_manager/io/backends.py create mode 100644 openfisca_survey_manager/tests/test_store_backends.py diff 
--git a/CHANGELOG.md b/CHANGELOG.md
index 5520742b..b70d7fd4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,21 @@
 # Changelog
 
+# 1.1.0
+
+* **Store backends** (choice of the table storage format)
+  - **io/backends**: HDF5, Parquet and Zarr backends (`StoreBackend` abstraction); `get_backend(name)`, `get_available_backend_names()`, `register_backend()` to extend.
+  - **Zarr**: optional backend (`pip install openfisca-survey-manager[zarr]`); one table = one zarr group inside one `.zarr` directory per survey.
+  - **Survey**: `zarr_file_path` attribute; `fill_store(store_format="zarr")` and reads through `get_values` for zarr.
+  - **Table**: write/read and `_is_stored` delegated to the backends; `_get_store_path_and_format()` unifies path resolution.
+  - **build-collection**: `--zarr` option in addition to `--parquet`; HDF5 stays the default, with a warning.
+  - **Docs**: `docs/ZARR-BACKEND.md` (Zarr usage, compression, parallelism).
+
+* **Manifest (RFC-002): store_format**
+  - **manifest.yaml**: optional `store_format` key (hdf5, parquet, zarr) at the dataset level; defaults to `parquet` at load time.
+  - **SurveyCollection.load**: when loading from a manifest, applies `store_format` and derives the store paths (`hdf5_file_path`, `parquet_file_path`, `zarr_file_path`) from `default_output_dir`.
+  - **Migration script**: infers `store_format` from the legacy JSON (`parquet_file_path` / `zarr_file_path` / `hdf5_file_path`) and writes it into the generated manifest.
+  - **RFC-002**: manifest example with `store_format`; sections 3.5 and 4.2 updated.
+
 # 1.0.0
 
 * **Breaking**: Version 1.0 — removal of the re-exports and of the DeprecationWarnings
diff --git a/docs/RFC-002-METADATA-AND-CONFIG.md b/docs/RFC-002-METADATA-AND-CONFIG.md
index 12c671d9..e803f9de 100644
--- a/docs/RFC-002-METADATA-AND-CONFIG.md
+++ b/docs/RFC-002-METADATA-AND-CONFIG.md
@@ -115,6 +115,9 @@ collections_dir/
 name: erfs
 label: "Enquête Revenus Fiscaux et Sociaux"
 
+# Storage backend for the tables (hdf5, parquet, zarr); defaults to parquet
+store_format: parquet
+
 # Per survey: raw sources (replaces raw_data.ini + informations)
 surveys:
   erfs_2019:
@@ -147,7 +150,22 @@ The directory is **no longer** resolved based on the presence of `taxipp` or
   - either set `OPENFISCA_SURVEY_CONFIG_DIR` to their directory,
   - or pass the config path on each call.
 
-### 3.5 Target API (RFC-001 alignment)
+### 3.5 Store backends
+
+Survey tables can be stored through different **backends** (chosen at build time / `fill_store`):
+
+| Backend | Format | Usage |
+|----------|---------------------|--------------------------------------------|
+| **hdf5** | One .h5 file | Historical (to be deprecated eventually) |
+| **parquet** | Directory, one .parquet per table | Recommended (interop, columnar) |
+| **zarr** | .zarr directory, one group per table | Optional (`[zarr]` dependency) |
+
+- **API**: `io.backends.get_backend(name)`, `get_available_backend_names()`, `register_backend(name, backend)` to extend.
+- **CLI**: `build-collection --parquet` or `build-collection --zarr`; HDF5 by default (with a warning).
+- **Survey**: `store_format`, and `hdf5_file_path` / `parquet_file_path` / `zarr_file_path` depending on the backend.
+- **Zarr (compression, parallelism)**: see [docs/ZARR-BACKEND.md](ZARR-BACKEND.md).
+
+### 3.6 Target API (RFC-001 alignment)
 
 - Load a dataset by name: `DataManager.load("erfs", config_dir=...)` → reads `collections_dir/erfs/manifest.yaml` and the associated data.
 - Metadata access: `dataset.metadata` (coming from the manifest), `dataset.schema` (if exposed), deterministic derived paths built from `collections_dir` + `name` + `output_subdir`.
@@ -173,7 +191,7 @@ A script migrates the existing setup to the new structure:
 ```bash
 python -m openfisca_survey_manager.scripts.migrate_config_to_rfc002 [--config-dir PATH] [--dry-run] [-v]
 ```
-- **Behaviour**: reads `config.ini` ([collections] + [data]) and, if present, `raw_data.ini`; for each collection, loads the JSON, derives `source.format` and `source.path` from `informations` (csv_files, sas_files, etc.) or from the matching raw_data.ini section; creates `config.yaml` and `collections_dir/<name>/manifest.yaml` for each collection. With `--dry-run`, writes no file.
+- **Behaviour**: reads `config.ini` ([collections] + [data]) and, if present, `raw_data.ini`; for each collection, loads the JSON, derives `source.format` and `source.path` from `informations` (csv_files, sas_files, etc.) or from the matching raw_data.ini section; **infers `store_format`** (parquet, hdf5 or zarr) from the `parquet_file_path` / `zarr_file_path` / `hdf5_file_path` fields of the legacy JSON surveys and writes it into the manifest; creates `config.yaml` and `collections_dir/<name>/manifest.yaml` for each collection. With `--dry-run`, writes no file.
 - **Default config directory**: the one returned by `get_config_dir()` (env `OPENFISCA_SURVEY_CONFIG_DIR` or XDG). A directory can be forced with `--config-dir`.
 
 ### 4.3 Deprecation
diff --git a/docs/ZARR-BACKEND.md b/docs/ZARR-BACKEND.md
new file mode 100644
index 00000000..f375f609
--- /dev/null
+++ b/docs/ZARR-BACKEND.md
@@ -0,0 +1,126 @@
+# Using Zarr with OpenFisca Survey Manager
+
+This document explains **whether and how** to use the Zarr backend to store surveys, and what the situation is regarding **compression** and **parallel** reads/writes.
+
+---
+
+## 1. Using Zarr with OpenFisca
+
+### Yes, it is possible
+
+The **zarr** backend is available in `openfisca-survey-manager` provided the optional dependency is installed:
+
+```bash
+pip install openfisca-survey-manager[zarr]
+# or
+pip install openfisca-survey-manager zarr numcodecs
+```
+
+(The backend writes and reads through the **zarr** package directly; pandas itself has no native Zarr I/O.)
+
+### From the command line (build-collection)
+
+To build a collection with the tables stored in Zarr format:
+
+```bash
+build-collection -c my_collection --zarr
+```
+
+Without `--zarr`, the default format remains HDF5 (with a warning), or you can use `--parquet`.
+
+### In Python (fill_store)
+
+```python
+from openfisca_survey_manager.core.dataset import SurveyCollection
+
+col = SurveyCollection.load(collection="my_collection", config_files_directory="...")
+col.fill_store(
+    source_format="sas",  # or csv, parquet, etc.
+    store_format="zarr",
+)
+```
+
+After that, each survey has a `{output}/{survey.name}.zarr` directory, and each table is a **zarr group** (a subdirectory) inside that store. Reading works as usual with `survey.get_values(table=..., variables=...)`; the code automatically uses the zarr backend when `store_format == "zarr"`, as in the sketch below.
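+
+For example, to read one table back after a zarr build; a minimal sketch, where the survey name (`my_survey`), the table (`household`) and the columns (`ident`, `weight`) are hypothetical placeholders:
+
+```python
+from openfisca_survey_manager.io.backends import get_backend
+
+# Normal entry point: get_values picks the zarr backend from store_format
+survey = col.get_survey("my_survey")
+df = survey.get_values(table="household", variables=["ident", "weight"])
+
+# Equivalent direct read through the backend, making the layout explicit:
+# one zarr group per table inside {output}/my_survey.zarr
+backend = get_backend("zarr")
+df_direct = backend.read_table(survey.zarr_file_path, "household", variables=["ident", "weight"])
+```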
+
+### Checking that Zarr is available
+
+```python
+from openfisca_survey_manager.io.backends import get_available_backend_names, get_backend
+
+print(get_available_backend_names())  # should contain "zarr" if the package is installed
+backend = get_backend("zarr")  # raises ValueError if zarr is absent
+```
+
+---
+
+## 2. Compression
+
+### Current behaviour
+
+In the current implementation, the Zarr write path opens a group (`zarr.open_group(path, mode="w")`) and writes one array per column **without explicit compression options**, so zarr's default behaviour applies (typically a light Blosc compression in zarr 2.x).
+
+### What Zarr allows in general
+
+Zarr handles compression **per block (chunk)** through **numcodecs**. One can use, for example:
+
+- **Blosc** (LZ4, Zstd, Zlib): good speed/ratio trade-off, widely used
+- **Zstd**: good ratio, fast decompression
+- **LZ4**: very fast, lower ratio
+- **Gzip**: standard, slower
+
+These options are configured when the zarr array is **created** (compressor, chunks):
+
+- `group.array(name, data, compressor=..., chunks=...)` sets them per array.
+- For fine-grained control (compression, chunking), you can create the zarr store yourself with the right `compressor` and write the columns into it, or extend the backend (see below).
+
+### Possible evolution in the survey-manager
+
+The Zarr backend could evolve to accept options (compression, chunks) either:
+
+- through **kwargs** in `fill_store(..., store_format="zarr", **zarr_options)` forwarded to the backend,
+- or through the **config** (manifest or config.yaml) defining a default compressor for the zarr format.
+
+Today, if you need specific compression, you can:
+
+1. **Register a custom backend** (`register_backend`) that writes with the `compressor` (and possibly the chunks) of your choice; see the sketch at the end of this document.
+2. Or **post-process** the generated `.zarr` directories (rewriting them with other zarr options) outside the survey-manager.
+
+---
+
+## 3. Parallel reads / writes
+
+### Zarr in general
+
+- **Block-level parallelism**: Zarr is designed so that **different chunks** can be read or written in parallel without a global lock (each chunk is independent).
+- **In Python**: the **GIL** limits the gain from threads for the compression/decompression part; effective parallelism usually goes through **multiple processes** or runtimes that release the GIL (Cython, the C extensions used by numcodecs/blosc).
+- **Bottleneck**: in practice, **compression/decompression** can saturate the CPU (~1 GB/s) while the disk or the network can go faster; improvements (batch encode/decode, GPU) are in progress in the zarr ecosystem.
+
+### In the survey-manager today
+
+- **Writing**: `fill_store(store_format="zarr")` calls the zarr backend for each table, **sequentially** (one table after another, no internal parallelism exposed).
+- **Reading**: `get_values()` goes through the zarr backend for a given table, also **sequentially** per call.
+
+So **by default**: no multi-table or multi-chunk parallelism is exposed in the current API.
+
+### How to parallelize anyway
+
+1. **Multiple tables / multiple surveys**
+   You can parallelize at the application level yourself: launch several processes or threads calling `fill_store` (or `get_values`) on different collections/surveys/tables; each process writes/reads its own files or zarr groups without conflict.
+
+2. **Dask**
+   For zarr arrays, **Dask** (dask.array, or loading the zarr stores into Dask) handles parallel, chunk-wise loading. This does not go through the current Survey/SurveyCollection API: you would either export the `.zarr` paths and open them with Dask, or add an integration layer (e.g. a function returning a Dask DataFrame from a zarr survey).
+
+3. **Backend evolution**
+   A "parallel write per table" mode (threads/processes) or a read option returning a Dask object, to exploit zarr's chunk-level parallelism, could be added later.
+
+---
+
+## 4. Practical summary
+
+| Question | Answer |
+|----------|--------|
+| **Use Zarr with OpenFisca?** | Yes: `pip install openfisca-survey-manager[zarr]`, then `build-collection --zarr` or `fill_store(store_format="zarr")`. |
+| **Compression?** | By default: zarr's default behaviour (light compression). For more control: a custom backend setting `compressor=...` at array creation, or post-processing of the zarr stores. |
+| **Parallel reads/writes?** | Not exposed in the current API (one table at a time). Possible parallelism: your own, over several tables/surveys, or with Dask on the generated zarr paths. |
+
+A natural follow-up is an API for passing compression (and possibly chunking) options to the Zarr backend through `fill_store` or the config; until then, the custom-backend sketch below is the available route.
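+
+A minimal sketch of such a custom backend; the `Blosc`/Zstd settings are illustrative choices, and only `register_backend` and the three-method backend protocol come from the current API:
+
+```python
+from pathlib import Path
+
+import pandas as pd
+import zarr
+from numcodecs import Blosc
+
+from openfisca_survey_manager.io.backends import register_backend
+
+
+class ZstdZarrBackend:
+    """Zarr backend writing every column with a Blosc/Zstd compressor."""
+
+    compressor = Blosc(cname="zstd", clevel=5, shuffle=Blosc.BITSHUFFLE)
+
+    def write_table(self, store_path, table_name, data_frame, **kwargs):
+        group = zarr.open_group(str(Path(store_path) / table_name), mode="w")
+        group.attrs["columns"] = list(data_frame.columns)
+        for column in data_frame.columns:
+            values = data_frame[column].to_numpy()
+            if values.dtype == object:  # zarr cannot store arbitrary Python objects
+                values = data_frame[column].astype(str).to_numpy()
+            group.array(column, values, compressor=self.compressor)
+        return str(Path(store_path) / table_name)
+
+    def read_table(self, store_path, table_name, variables=None, **kwargs):
+        group = zarr.open_group(str(Path(store_path) / table_name), mode="r")
+        columns = [c for c in group.attrs["columns"] if variables is None or c in variables]
+        return pd.DataFrame({c: group[c][:] for c in columns})
+
+    def table_exists(self, store_path, table_name):
+        return (Path(store_path) / table_name).is_dir()
+
+
+# Replace the default zarr backend so store_format="zarr" uses compression
+register_backend("zarr", ZstdZarrBackend())
+```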
diff --git a/openfisca_survey_manager/configuration/config_loader.py b/openfisca_survey_manager/configuration/config_loader.py index 75688214..44c982cf 100644 --- a/openfisca_survey_manager/configuration/config_loader.py +++ b/openfisca_survey_manager/configuration/config_loader.py @@ -91,6 +91,7 @@ def manifest_survey_to_json(survey_name: str, entry: dict[str, Any]) -> dict[str "label": entry.get("label", survey_name), "hdf5_file_path": None, "parquet_file_path": None, + "zarr_file_path": None, "tables": entry.get("tables"), "informations": informations, } diff --git a/openfisca_survey_manager/core/dataset.py b/openfisca_survey_manager/core/dataset.py index 322ccdd2..4eb008fe 100644 --- a/openfisca_survey_manager/core/dataset.py +++ b/openfisca_survey_manager/core/dataset.py @@ -151,11 +151,20 @@ def load( self.config = None self.output_directory = str(new_cfg["default_output_dir"]) self.surveys = [] + store_format = manifest.get("store_format", "parquet") + output_dir = Path(self.output_directory) for survey_name, entry in manifest.get("surveys", {}).items(): survey_json = manifest_survey_to_json(survey_name, entry) survey = Survey(name=survey_name) survey = survey.create_from_json(survey_json) survey.survey_collection = self + survey.store_format = store_format + if store_format == "hdf5": + survey.hdf5_file_path = str(output_dir / (survey.name + ".h5")) + elif store_format == "parquet": + survey.parquet_file_path = str(output_dir / survey.name) + elif store_format == "zarr": + survey.zarr_file_path = str(output_dir / (survey.name + ".zarr")) self.surveys.append(survey) return self diff --git a/openfisca_survey_manager/core/survey.py b/openfisca_survey_manager/core/survey.py index 18062309..05375ad9 100644 --- a/openfisca_survey_manager/core/survey.py +++ b/openfisca_survey_manager/core/survey.py @@ -15,6 +15,7 @@ from openfisca_survey_manager.core.table import Table from openfisca_survey_manager.exceptions import SurveyIOError, SurveyManagerError +from openfisca_survey_manager.io.backends import get_backend from openfisca_survey_manager.io.hdf import hdf5_safe_key from openfisca_survey_manager.processing.harmonization import harmonize_data_frame_columns @@ -46,6 +47,7 @@ class Survey: hdf5_file_path: Optional[str] = None parquet_file_path: Optional[str] = None + zarr_file_path: Optional[str] = None label: Optional[str] = None name: Optional[str] = None survey_collection: Optional[SurveyCollection] = None @@ -89,12 +91,16 @@ def __repr__(self) -> str: @classmethod def create_from_json(cls, survey_json: dict) -> Survey: + # Top-level store paths; exclude from informations to avoid duplicate kwargs + store_path_keys = {"hdf5_file_path", "parquet_file_path", "zarr_file_path"} + infos = {k: v for k, v in survey_json.get("informations", {}).items() if k not in store_path_keys} self = cls( name=survey_json.get("name"), label=survey_json.get("label"), hdf5_file_path=survey_json.get("hdf5_file_path"), parquet_file_path=survey_json.get("parquet_file_path"), - **survey_json.get("informations", {}), + zarr_file_path=survey_json.get("zarr_file_path"), + **infos, ) self.tables = survey_json.get("tables") return self @@ -137,6 +143,9 @@ def fill_store( if store_format == "parquet" and survey.parquet_file_path is None: survey.parquet_file_path = str(Path(directory_path) / survey.name) + if store_format == "zarr" and survey.zarr_file_path is None: + survey.zarr_file_path = str(Path(directory_path) / (survey.name + ".zarr")) + self.store_format = store_format if source_format is not None: @@ -276,6 
+285,23 @@ def _get_values_from_parquet( return pq.ParquetDataset(parquet_file).read(columns=variables).to_pandas() raise SurveyIOError(f"No table {table} found in {self.parquet_file_path}") + def _get_values_from_zarr( + self, + table: str, + variables: Optional[List[str]] = None, + **kwargs: Any, + ) -> pandas.DataFrame: + """Read table from zarr store.""" + if self.zarr_file_path is None: + raise SurveyIOError("No zarr store path for survey") + backend = get_backend("zarr") + return backend.read_table( + self.zarr_file_path, + table, + variables=variables, + **kwargs, + ) + def get_values( self, variables: Optional[List[str]] = None, @@ -287,9 +313,11 @@ def get_values( batch_index: int = 0, filter_by: Optional[List[tuple]] = None, ) -> pandas.DataFrame: - if self.parquet_file_path is None and self.hdf5_file_path is None: + if self.parquet_file_path is None and self.hdf5_file_path is None and self.zarr_file_path is None: raise SurveyIOError(f"No data file found for survey {self.name}") - if self.hdf5_file_path is not None: + if self.store_format == "zarr" and self.zarr_file_path is not None: + df = self._get_values_from_zarr(table or "", variables=variables) + elif self.hdf5_file_path is not None: df, _ = self._get_values_from_hdf5(table or "", ignorecase=ignorecase) else: df = self._get_values_from_parquet(table, variables, filter_by, batch_size, batch_index) diff --git a/openfisca_survey_manager/core/table.py b/openfisca_survey_manager/core/table.py index 256b3625..4fe3fa76 100644 --- a/openfisca_survey_manager/core/table.py +++ b/openfisca_survey_manager/core/table.py @@ -17,6 +17,7 @@ from pyarrow import parquet as pq from openfisca_survey_manager.exceptions import SurveyIOError +from openfisca_survey_manager.io.backends import get_backend from openfisca_survey_manager.io.readers import read_sas from openfisca_survey_manager.io.writers import write_table_to_hdf5, write_table_to_parquet from openfisca_survey_manager.processing.cleaning import clean_data_frame @@ -91,18 +92,36 @@ def _check_and_log(self, data_file_path: str, store_file_path: Optional[str]) -> f"at point {self.name}" ) - def _is_stored(self) -> bool: + def _get_store_path_and_format(self) -> Optional[tuple[str, str]]: + """Return (store_path, store_format) for the survey's current backend, or None.""" + fmt = getattr(self.survey, "store_format", None) or "hdf5" + if fmt == "hdf5" and self.survey.hdf5_file_path is not None: + return (self.survey.hdf5_file_path, "hdf5") + if fmt == "parquet" and self.survey.parquet_file_path is not None: + return (self.survey.parquet_file_path, "parquet") + if fmt == "zarr" and getattr(self.survey, "zarr_file_path", None) is not None: + return (self.survey.zarr_file_path, "zarr") if self.survey.hdf5_file_path is not None: - store = pandas.HDFStore(self.survey.hdf5_file_path) - if self.name in store: - log.info(f"Exiting without overwriting {self.name} in {self.survey.hdf5_file_path}") - store.close() - return True + return (self.survey.hdf5_file_path, "hdf5") + if self.survey.parquet_file_path is not None: + return (self.survey.parquet_file_path, "parquet") + return None - store.close() - return False - else: + def _is_stored(self) -> bool: + path_fmt = self._get_store_path_and_format() + if path_fmt is None: return False + store_path, store_format = path_fmt + backend = get_backend(store_format) + if backend.table_exists(store_path, self.name): + log.info( + "Exiting without overwriting %s in %s (%s)", + self.name, + store_path, + store_format, + ) + return True + return False def 
_save( self, @@ -122,16 +141,28 @@ def _save( ) data_frame = data_frame[stored_variables].copy() - assert store_format in ["hdf5", "parquet"], f"invalid store_format: {store_format}" + path_fmt = self._get_store_path_and_format() + if path_fmt is None: + raise SurveyIOError( + f"No store path set for survey (store_format={store_format}). " + "Set hdf5_file_path, parquet_file_path, or zarr_file_path." + ) + store_path, resolved_format = path_fmt + if store_format != resolved_format: + store_format = resolved_format + backend = get_backend(store_format) if store_format == "hdf5": log.warning( - "HDF5 will no longer be the default format in a future version. Please use parquet format instead." + "HDF5 will no longer be the default format in a future version. " + "Please use parquet or zarr format instead." ) - self.save_data_frame_to_hdf5(data_frame) - else: - parquet_file_path = self.survey.parquet_file_path - log.info(f"Inserting table {self.name} in Parquet file {parquet_file_path}") - self.save_data_frame_to_parquet(data_frame) + log.info("Inserting table %s in %s store %s", self.name, store_format, store_path) + backend.write_table(store_path, self.name, data_frame) + self.variables = list(data_frame.columns) + self.survey.tables[self.name]["variables"] = self.variables + if store_format == "parquet": + self.parquet_file = str(Path(store_path) / f"{self.name}.parquet") + self.survey.tables[self.name]["parquet_file"] = self.parquet_file gc.collect() def fill_store( @@ -142,7 +173,6 @@ def fill_store( **kwargs: Any, ) -> None: if not overwrite and self._is_stored(): - log.info(f"Exiting without overwriting {self.name} in {self.survey.hdf5_file_path}") return start_table_time = datetime.datetime.now() @@ -242,9 +272,8 @@ def _apply_stata_categorical_strategy( def read_source(self, data_file: str, **kwargs: Any) -> pandas.DataFrame: source_format = self.source_format - store_file_path = ( - self.survey.hdf5_file_path if self.survey.store_format == "hdf5" else self.survey.parquet_file_path - ) + path_fmt = self._get_store_path_and_format() + store_file_path = path_fmt[0] if path_fmt else None self._check_and_log(data_file, store_file_path=store_file_path) reader = reader_by_source_format[source_format] categorical_strategy = ( diff --git a/openfisca_survey_manager/io/backends.py b/openfisca_survey_manager/io/backends.py new file mode 100644 index 00000000..a6f8273e --- /dev/null +++ b/openfisca_survey_manager/io/backends.py @@ -0,0 +1,227 @@ +"""Store backends for survey tables: HDF5, Parquet, Zarr. + +Allows choosing the storage format (backend) when building or filling the store. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any, Optional, Protocol + +import pandas as pd + +from openfisca_survey_manager.io.hdf import hdf5_safe_key, write_table_to_hdf5 +from openfisca_survey_manager.io.writers import write_table_to_parquet + +log = logging.getLogger(__name__) + + +# Supported store format names (zarr only if zarr package is installed) +def get_available_backend_names() -> tuple[str, ...]: + return tuple(_backends.keys()) + + +class StoreBackend(Protocol): + """Protocol for a store backend (write/read tables).""" + + def write_table( + self, + store_path: str, + table_name: str, + data_frame: pd.DataFrame, + **kwargs: Any, + ) -> Optional[str]: + """Write a table. Returns path used for the table (e.g. file path) or None.""" + ... 
+
+    def read_table(
+        self,
+        store_path: str,
+        table_name: str,
+        variables: Optional[list[str]] = None,
+        **kwargs: Any,
+    ) -> pd.DataFrame:
+        """Read a table as DataFrame."""
+        ...
+
+    def table_exists(self, store_path: str, table_name: str) -> bool:
+        """Return True if the table exists in the store."""
+        ...
+
+
+class HDF5Backend:
+    """Store tables in a single HDF5 file."""
+
+    def write_table(
+        self,
+        store_path: str,
+        table_name: str,
+        data_frame: pd.DataFrame,
+        **kwargs: Any,
+    ) -> Optional[str]:
+        write_table_to_hdf5(
+            data_frame,
+            hdf5_file_path=store_path,
+            store_path=table_name,
+            **kwargs,
+        )
+        return None
+
+    def read_table(
+        self,
+        store_path: str,
+        table_name: str,
+        variables: Optional[list[str]] = None,
+        **kwargs: Any,
+    ) -> pd.DataFrame:
+        key = hdf5_safe_key(table_name)
+        store = pd.HDFStore(store_path, "r")
+        try:
+            df = store.select(key)
+        finally:
+            store.close()
+        if variables is not None:
+            df = df[[c for c in variables if c in df.columns]]
+        return df
+
+    def table_exists(self, store_path: str, table_name: str) -> bool:
+        if not Path(store_path).is_file():
+            return False
+        key = hdf5_safe_key(table_name)
+        store = pd.HDFStore(store_path, "r")
+        try:
+            keys = store.keys()
+            return key in keys or any(k.lstrip("/") == key for k in keys)
+        finally:
+            store.close()
+
+
+class ParquetBackend:
+    """Store each table as a parquet file in a directory (store_path/table_name.parquet)."""
+
+    def write_table(
+        self,
+        store_path: str,
+        table_name: str,
+        data_frame: pd.DataFrame,
+        **kwargs: Any,
+    ) -> Optional[str]:
+        return write_table_to_parquet(
+            data_frame,
+            parquet_dir_path=store_path,
+            table_name=table_name,
+        )
+
+    def read_table(
+        self,
+        store_path: str,
+        table_name: str,
+        variables: Optional[list[str]] = None,
+        **kwargs: Any,
+    ) -> pd.DataFrame:
+        path = Path(store_path) / f"{table_name}.parquet"
+        if not path.is_file():
+            raise FileNotFoundError(f"No table {table_name} at {path}")
+        return pd.read_parquet(path, columns=variables)
+
+    def table_exists(self, store_path: str, table_name: str) -> bool:
+        return (Path(store_path) / f"{table_name}.parquet").is_file()
+
+
+def _write_table_to_zarr(
+    data_frame: pd.DataFrame,
+    zarr_dir_path: str,
+    table_name: str,
+) -> str:
+    """Write a DataFrame to a zarr group (store_path/table_name), one array per column."""
+    import zarr  # optional dependency, imported lazily
+
+    zarr_path = str(Path(zarr_dir_path) / table_name)
+    Path(zarr_path).parent.mkdir(parents=True, exist_ok=True)
+    group = zarr.open_group(zarr_path, mode="w")
+    # Keep the column order so reads can rebuild the DataFrame faithfully
+    group.attrs["columns"] = list(data_frame.columns)
+    for col in data_frame.columns:
+        values = data_frame[col].to_numpy()
+        # Object columns cannot be stored as-is; coerce to string like the parquet backend
+        if values.dtype == object:
+            values = data_frame[col].astype(str).to_numpy()
+        group.array(col, values)
+    return zarr_path
+
+
+def _read_table_from_zarr(
+    zarr_dir_path: str,
+    table_name: str,
+    variables: Optional[list[str]] = None,
+) -> pd.DataFrame:
+    """Read a table from a zarr group."""
+    import zarr  # optional dependency, imported lazily
+
+    zarr_path = str(Path(zarr_dir_path) / table_name)
+    group = zarr.open_group(zarr_path, mode="r")
+    columns = list(group.attrs.get("columns", group.array_keys()))
+    if variables is not None:
+        columns = [c for c in columns if c in variables]
+    return pd.DataFrame({col: group[col][:] for col in columns})
+
+
+class ZarrBackend:
+    """Store each table as a zarr group in a directory (store_path/table_name)."""
+
+    def write_table(
+        self,
+        store_path: str,
+        table_name: str,
+        data_frame: pd.DataFrame,
+        **kwargs: Any,
+    ) -> Optional[str]:
+        return _write_table_to_zarr(data_frame, store_path, table_name)
+
+    def read_table(
+        self,
+        store_path: str,
+        table_name: str,
+ variables: Optional[list[str]] = None, + **kwargs: Any, + ) -> pd.DataFrame: + return _read_table_from_zarr(store_path, table_name, variables) + + def table_exists(self, store_path: str, table_name: str) -> bool: + return (Path(store_path) / table_name).is_dir() + + +def _build_backends() -> dict[str, StoreBackend]: + backends: dict[str, StoreBackend] = { + "hdf5": HDF5Backend(), + "parquet": ParquetBackend(), + } + try: + import zarr # noqa: F401 + + backends["zarr"] = ZarrBackend() + except ImportError: + log.debug("zarr not installed; zarr store backend unavailable") + return backends + + +_backends = _build_backends() + +STORE_BACKEND_NAMES = get_available_backend_names() + +__all__ = [ + "get_backend", + "get_available_backend_names", + "register_backend", + "StoreBackend", + "STORE_BACKEND_NAMES", +] + + +def get_backend(name: str) -> StoreBackend: + """Return the store backend for the given format name.""" + if name not in _backends: + raise ValueError(f"Unknown store backend: {name}. Choose from {list(_backends.keys())}") + return _backends[name] + + +def register_backend(name: str, backend: StoreBackend) -> None: + """Register a custom store backend (e.g. for testing or extensions).""" + _backends[name] = backend diff --git a/openfisca_survey_manager/scripts/build_collection.py b/openfisca_survey_manager/scripts/build_collection.py index 6a2b4ca9..51b368e0 100755 --- a/openfisca_survey_manager/scripts/build_collection.py +++ b/openfisca_survey_manager/scripts/build_collection.py @@ -243,9 +243,14 @@ def main(): "--parquet", action="store_true", default=False, + help="save data in parquet format (directory with one .parquet file per table)", + ) + parser.add_argument( + "--zarr", + action="store_true", + default=False, help=( - "save data in parquet format instead of HDF5 " - "(HDF5 will no longer be the default format in a future version)" + "save data in zarr format (one zarr group per table); requires: pip install openfisca-survey-manager[zarr]" ), ) parser.add_argument( @@ -283,13 +288,16 @@ def main(): start_time = datetime.datetime.now() - # Determine store format based on argument - store_format = "parquet" if args.parquet else "hdf5" - - if not args.parquet: + # Determine store format based on argument (--zarr > --parquet > default hdf5) + if args.zarr: + store_format = "zarr" + elif args.parquet: + store_format = "parquet" + else: + store_format = "hdf5" log.warning( "HDF5 will no longer be the default format in a future version. " - "Please use --parquet option to save data in parquet format." + "Please use --parquet or --zarr to save data in parquet or zarr format." 
) try: diff --git a/openfisca_survey_manager/scripts/migrate_config_to_rfc002.py b/openfisca_survey_manager/scripts/migrate_config_to_rfc002.py index 426e5ae7..acd42d1f 100644 --- a/openfisca_survey_manager/scripts/migrate_config_to_rfc002.py +++ b/openfisca_survey_manager/scripts/migrate_config_to_rfc002.py @@ -84,7 +84,25 @@ def build_manifest_from_json( } if survey_obj.get("output_subdir"): surveys[survey_name]["output_subdir"] = survey_obj["output_subdir"] - return {"name": name, "label": label, "surveys": surveys} + + store_format = _infer_store_format_from_legacy(surveys_data) + return {"name": name, "label": label, "store_format": store_format, "surveys": surveys} + + +def _infer_store_format_from_legacy(surveys_data: dict) -> str: + """Infer store_format from legacy JSON surveys (parquet_file_path, zarr_file_path, hdf5_file_path).""" + if not isinstance(surveys_data, dict): + return "parquet" + for survey_obj in surveys_data.values(): + if not isinstance(survey_obj, dict): + continue + if survey_obj.get("zarr_file_path"): + return "zarr" + if survey_obj.get("parquet_file_path"): + return "parquet" + if survey_obj.get("hdf5_file_path"): + return "hdf5" + return "parquet" def load_raw_data_ini(config_dir: Path) -> configparser.ConfigParser | None: diff --git a/openfisca_survey_manager/tests/test_config_manifest_rfc002.py b/openfisca_survey_manager/tests/test_config_manifest_rfc002.py index ae46f0ee..9f5da00c 100644 --- a/openfisca_survey_manager/tests/test_config_manifest_rfc002.py +++ b/openfisca_survey_manager/tests/test_config_manifest_rfc002.py @@ -120,6 +120,51 @@ def test_survey_collection_load_from_manifest(rfc002_config_dir): survey_a = col.get_survey("survey_a") assert survey_a.label == "Survey A" assert survey_a.informations.get("csv_files") == ["/data/survey_a"] + # Default store_format when missing in manifest is parquet + assert survey_a.store_format == "parquet" + assert survey_a.parquet_file_path is not None + assert "survey_a" in survey_a.parquet_file_path + + +def test_survey_collection_load_from_manifest_store_format_zarr(tmp_path): + """When manifest has store_format: zarr, surveys get zarr_file_path set.""" + config_dir = tmp_path / "config" + config_dir.mkdir() + (config_dir / "config.yaml").write_text( + f""" +collections_dir: {tmp_path / "collections"} +default_output_dir: {tmp_path / "output"} +tmp_dir: {tmp_path / "tmp"} +""" + ) + collections_dir = tmp_path / "collections" + collections_dir.mkdir() + dataset_dir = collections_dir / "zarr_dataset" + dataset_dir.mkdir() + (dataset_dir / "manifest.yaml").write_text( + """ +name: zarr_dataset +label: "Zarr dataset" +store_format: zarr +surveys: + s1: + label: "Survey 1" + source: + format: csv + path: /data/s1 +""" + ) + col = SurveyCollection.load( + collection="zarr_dataset", + config_files_directory=config_dir, + ) + assert col.output_directory is not None + survey_s1 = col.get_survey("s1") + assert survey_s1.store_format == "zarr" + assert survey_s1.zarr_file_path is not None + assert survey_s1.zarr_file_path.endswith(".zarr") + assert survey_s1.hdf5_file_path is None + assert survey_s1.parquet_file_path is None def test_survey_collection_load_legacy_unchanged(tmp_path): @@ -208,12 +253,43 @@ def test_migrate_produces_config_yaml_and_manifests(legacy_config_dir): manifest = load_manifest(cfg["collections_dir"], "my_collection") assert manifest is not None assert manifest["name"] == "my_collection" + assert manifest.get("store_format") == "parquet" assert manifest["surveys"]["survey_1"]["source"]["format"] == 
"csv" assert manifest["surveys"]["survey_1"]["source"]["path"] == "/data/s1/file1.csv" assert manifest["surveys"]["survey_2"]["source"]["format"] == "sas" assert manifest["surveys"]["survey_2"]["source"]["path"] == "/data/s2/file.sas7bdat" +def test_migrate_infers_store_format_from_legacy(tmp_path): + """Migration infers store_format from legacy JSON (hdf5_file_path -> hdf5, etc.).""" + from openfisca_survey_manager.scripts.migrate_config_to_rfc002 import ( + _infer_store_format_from_legacy, + build_manifest_from_json, + ) + + # Legacy with parquet_file_path + json_parquet = tmp_path / "p.json" + json_parquet.write_text( + '{"name":"p","label":"P","surveys":{"s1":{"label":"S1","parquet_file_path":"/out/s1"}}}', + encoding="utf-8", + ) + manifest_parquet = build_manifest_from_json(json_parquet, None) + assert manifest_parquet["store_format"] == "parquet" + + # Legacy with hdf5_file_path only + json_hdf5 = tmp_path / "h.json" + json_hdf5.write_text( + '{"name":"h","label":"H","surveys":{"s1":{"label":"S1","hdf5_file_path":"/out/s1.h5"}}}', + encoding="utf-8", + ) + manifest_hdf5 = build_manifest_from_json(json_hdf5, None) + assert manifest_hdf5["store_format"] == "hdf5" + + # Infer function directly + assert _infer_store_format_from_legacy({}) == "parquet" + assert _infer_store_format_from_legacy({"s": {"zarr_file_path": "/z"}}) == "zarr" + + def test_migrate_dry_run_does_not_write(legacy_config_dir): """Migration with --dry-run does not create files.""" from openfisca_survey_manager.scripts.migrate_config_to_rfc002 import ( diff --git a/openfisca_survey_manager/tests/test_store_backends.py b/openfisca_survey_manager/tests/test_store_backends.py new file mode 100644 index 00000000..6e08e56b --- /dev/null +++ b/openfisca_survey_manager/tests/test_store_backends.py @@ -0,0 +1,50 @@ +"""Tests for store backends (HDF5, Parquet, Zarr).""" + +from pathlib import Path + +import pandas as pd +import pytest + +from openfisca_survey_manager.io.backends import ( + get_available_backend_names, + get_backend, +) + + +def test_get_backend_hdf5(): + backend = get_backend("hdf5") + assert backend is not None + assert hasattr(backend, "write_table") and hasattr(backend, "read_table") + assert hasattr(backend, "table_exists") + + +def test_get_backend_parquet(): + backend = get_backend("parquet") + assert backend is not None + + +def test_get_backend_invalid_raises(): + with pytest.raises(ValueError, match="Unknown store backend"): + get_backend("invalid_format") + + +def test_parquet_backend_roundtrip(tmp_path): + backend = get_backend("parquet") + store_path = str(tmp_path / "survey") + store_path_path = Path(store_path) + store_path_path.mkdir(parents=True) + df = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + backend.write_table(store_path, "mytable", df) + assert backend.table_exists(store_path, "mytable") + df2 = backend.read_table(store_path, "mytable") + pd.testing.assert_frame_equal(df, df2) + df3 = backend.read_table(store_path, "mytable", variables=["a"]) + assert list(df3.columns) == ["a"] + + +def test_available_backends_include_hdf5_parquet(): + names = get_available_backend_names() + assert "hdf5" in names + assert "parquet" in names + # zarr only if zarr package installed + assert len(names) >= 2 diff --git a/pyproject.toml b/pyproject.toml index f9888837..55480ebb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "OpenFisca-Survey-Manager" -version = "1.0.0" +version = "1.1.0" description = "A tool for managing survey/administrative data and import them in 
OpenFisca" readme = "README.md" keywords = ["microsimulation", "tax", "benefit", "rac", "rules-as-code", "survey", "data"] @@ -52,6 +52,9 @@ build-backend = "setuptools.build_meta" build-collection = "openfisca_survey_manager.scripts.build_collection:main" [project.optional-dependencies] +zarr = [ + 'zarr >=2.18.0, < 3.0', +] matching = [ # 'feather', 'rpy2 >=3.5.10, < 4.0' diff --git a/uv.lock b/uv.lock index 5b3118ad..a5c605b4 100644 --- a/uv.lock +++ b/uv.lock @@ -9,6 +9,12 @@ resolution-markers = [ "python_full_version < '3.10'", ] +[[package]] +name = "asciitree" +version = "0.3.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2d/6a/885bc91484e1aa8f618f6f0228d76d0e67000b0fdd6090673b777e311913/asciitree-0.3.3.tar.gz", hash = "sha256:4aa4b9b649f85e3fcb343363d97564aa1fb62e249677f2e18a96765145cc0f6e", size = 3951 } + [[package]] name = "asttokens" version = "3.0.0" @@ -439,6 +445,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a", size = 9190 }, ] +[[package]] +name = "deprecated" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wrapt", marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/49/85/12f0a49a7c4ffb70572b6c2ef13c90c88fd190debda93b23f026b25f9634/deprecated-1.3.1.tar.gz", hash = "sha256:b1b50e0ff0c1fddaa5708a2c6b0a6588bb09b892825ab2b214ac9ea9d92a5223", size = 2932523 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/84/d0/205d54408c08b13550c733c4b85429e7ead111c7f0014309637425520a9a/deprecated-1.3.1-py2.py3-none-any.whl", hash = "sha256:597bfef186b6f60181535a29fbe44865ce137a5079f295b479886c82729d5f3f", size = 11298 }, +] + [[package]] name = "distlib" version = "0.4.0" @@ -497,6 +515,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/30/c3/6f0e3896f193528bbd2b4d2122d4be8108a37efab0b8475855556a8c4afa/fancycompleter-0.11.1-py3-none-any.whl", hash = "sha256:44243d7fab37087208ca5acacf8f74c0aa4d733d04d593857873af7513cdf8a6", size = 11207 }, ] +[[package]] +name = "fasteners" +version = "0.20" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2d/18/7881a99ba5244bfc82f06017316ffe93217dbbbcfa52b887caa1d4f2a6d3/fasteners-0.20.tar.gz", hash = "sha256:55dce8792a41b56f727ba6e123fcaee77fd87e638a6863cec00007bfea84c8d8", size = 25087 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/51/ac/e5d886f892666d2d1e5cb8c1a41146e1d79ae8896477b1153a21711d3b44/fasteners-0.20-py3-none-any.whl", hash = "sha256:9422c40d1e350e4259f509fb2e608d6bc43c0136f79a00db1b49046029d0b3b7", size = 18702 }, +] + [[package]] name = "filelock" version = "3.19.1" @@ -999,6 +1026,96 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314 }, ] +[[package]] +name = "numcodecs" +version = "0.12.1" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.10'", +] +dependencies = [ + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, +] +sdist = { 
url = "https://files.pythonhosted.org/packages/b7/1b/1f1d880e29e719c7c6205065d1afbc91114c0d91935ac419faa43e5e08b0/numcodecs-0.12.1.tar.gz", hash = "sha256:05d91a433733e7eef268d7e80ec226a0232da244289614a8f3826901aec1098e", size = 4091415 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9f/66/08744c9007f1d02476dd97f3c23032f3555dbb8e9a32b0f0ea4724e6b2a2/numcodecs-0.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d37f628fe92b3699e65831d5733feca74d2e33b50ef29118ffd41c13c677210e", size = 1696843 }, + { url = "https://files.pythonhosted.org/packages/b8/6f/a04a33c5edb8fa9ba63783d34ff5768ba6b562ebe11078c07848e283f4ad/numcodecs-0.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:941b7446b68cf79f089bcfe92edaa3b154533dcbcd82474f994b28f2eedb1c60", size = 1422578 }, + { url = "https://files.pythonhosted.org/packages/1e/b8/1040f299803eacc9c522fdc69a4dafc42ad0e8722bb48aa43d2310cf195b/numcodecs-0.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e79bf9d1d37199ac00a60ff3adb64757523291d19d03116832e600cac391c51", size = 7709402 }, + { url = "https://files.pythonhosted.org/packages/8c/fa/da0637e1a6db74361a2875425021957859749166c0174ddedbb629518970/numcodecs-0.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:82d7107f80f9307235cb7e74719292d101c7ea1e393fe628817f0d635b7384f5", size = 790204 }, + { url = "https://files.pythonhosted.org/packages/10/63/a50f4113a2bb1decfaedeffc448c5f8b26ded1c583247c893120fcd25e3e/numcodecs-0.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:eeaf42768910f1c6eebf6c1bb00160728e62c9343df9e2e315dc9fe12e3f6071", size = 1696786 }, + { url = "https://files.pythonhosted.org/packages/92/77/0fde34bf3a8402d696218a565230097d904c9eebb62cd952923b1155b7f7/numcodecs-0.12.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:135b2d47563f7b9dc5ee6ce3d1b81b0f1397f69309e909f1a35bb0f7c553d45e", size = 1422330 }, + { url = "https://files.pythonhosted.org/packages/14/e6/8f9d4a498a06f11a06297f0b02af9968844d2e40ee79d372ccee33595285/numcodecs-0.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a191a8e347ecd016e5c357f2bf41fbcb026f6ffe78fff50c77ab12e96701d155", size = 7949787 }, + { url = "https://files.pythonhosted.org/packages/08/f3/44597198c2cfb0d808d68583445b60b0d0ae057f20f0caf2a1200405655e/numcodecs-0.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:21d8267bd4313f4d16f5b6287731d4c8ebdab236038f29ad1b0e93c9b2ca64ee", size = 790313 }, + { url = "https://files.pythonhosted.org/packages/d7/b2/7842675a798e79686d14a20baa554b165aab86feac28f32695266ab42b7e/numcodecs-0.12.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:2f84df6b8693206365a5b37c005bfa9d1be486122bde683a7b6446af4b75d862", size = 1697725 }, + { url = "https://files.pythonhosted.org/packages/fc/1f/e3b033181a28ce153fd0c9acd3ed978ee9c424de7cc3d8e97fc60647eddf/numcodecs-0.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:760627780a8b6afdb7f942f2a0ddaf4e31d3d7eea1d8498cf0fd3204a33c4618", size = 1423927 }, + { url = "https://files.pythonhosted.org/packages/3b/88/fb3186f944b9586e9c4c54bd1d1899947b88465ad3ab1ff1111066871644/numcodecs-0.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c258bd1d3dfa75a9b708540d23b2da43d63607f9df76dfa0309a7597d1de3b73", size = 7944856 }, + { url = "https://files.pythonhosted.org/packages/f4/03/54e22e273d584e83100ffa60c47c29cae905015ecb1f693918072c3595b9/numcodecs-0.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:e04649ea504aff858dbe294631f098fbfd671baf58bfc04fc48d746554c05d67", size = 787000 }, 
+ { url = "https://files.pythonhosted.org/packages/dd/3c/950f816b837fc7714102b45491e2612b10757106f9a8e3785d7b3806acd4/numcodecs-0.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2fbb12a6a1abe95926f25c65e283762d63a9bf9e43c0de2c6a1a798347dfcb40", size = 1700073 }, + { url = "https://files.pythonhosted.org/packages/76/2f/19f4f012f253ff33948a024e0a814c758ea137e3ba86118daac83a8d9123/numcodecs-0.12.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f2207871868b2464dc11c513965fd99b958a9d7cde2629be7b2dc84fdaab013b", size = 1425835 }, + { url = "https://files.pythonhosted.org/packages/6d/0f/0442e80d707b5dd2e177a9490c25b89aa6a6c44579de8ec223e78a8884da/numcodecs-0.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:abff3554a6892a89aacf7b642a044e4535499edf07aeae2f2e6e8fc08c9ba07f", size = 7722207 }, + { url = "https://files.pythonhosted.org/packages/77/b6/345f8648874a81232bc1a87e55a771430488a832c68f873aa6ed23a1dedf/numcodecs-0.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:ef964d4860d3e6b38df0633caf3e51dc850a6293fd8e93240473642681d95136", size = 792870 }, +] + +[[package]] +name = "numcodecs" +version = "0.13.1" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version == '3.10.*'", +] +dependencies = [ + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.10.*'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/85/56/8895a76abe4ec94ebd01eeb6d74f587bc4cddd46569670e1402852a5da13/numcodecs-0.13.1.tar.gz", hash = "sha256:a3cf37881df0898f3a9c0d4477df88133fe85185bffe57ba31bcc2fa207709bc", size = 5955215 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/c0/6d72cde772bcec196b7188731d41282993b2958440f77fdf0db216f722da/numcodecs-0.13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:96add4f783c5ce57cc7e650b6cac79dd101daf887c479a00a29bc1487ced180b", size = 1580012 }, + { url = "https://files.pythonhosted.org/packages/94/1d/f81fc1fa9210bbea97258242393a1f9feab4f6d8fb201f81f76003005e4b/numcodecs-0.13.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:237b7171609e868a20fd313748494444458ccd696062f67e198f7f8f52000c15", size = 1176919 }, + { url = "https://files.pythonhosted.org/packages/16/e4/b9ec2f4dfc34ecf724bc1beb96a9f6fa9b91801645688ffadacd485089da/numcodecs-0.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96e42f73c31b8c24259c5fac6adba0c3ebf95536e37749dc6c62ade2989dca28", size = 8625842 }, + { url = "https://files.pythonhosted.org/packages/fe/90/299952e1477954ec4f92813fa03e743945e3ff711bb4f6c9aace431cb3da/numcodecs-0.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:eda7d7823c9282e65234731fd6bd3986b1f9e035755f7fed248d7d366bb291ab", size = 828638 }, + { url = "https://files.pythonhosted.org/packages/f0/78/34b8e869ef143e88d62e8231f4dbfcad85e5c41302a11fc5bd2228a13df5/numcodecs-0.13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2eda97dd2f90add98df6d295f2c6ae846043396e3d51a739ca5db6c03b5eb666", size = 1580199 }, + { url = "https://files.pythonhosted.org/packages/3b/cf/f70797d86bb585d258d1e6993dced30396f2044725b96ce8bcf87a02be9c/numcodecs-0.13.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2a86f5367af9168e30f99727ff03b27d849c31ad4522060dde0bce2923b3a8bc", size = 1177203 }, + { url = "https://files.pythonhosted.org/packages/a8/b5/d14ad69b63fde041153dfd05d7181a49c0d4864de31a7a1093c8370da957/numcodecs-0.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:233bc7f26abce24d57e44ea8ebeb5cd17084690b4e7409dd470fdb75528d615f", size = 8868743 }, + { url = "https://files.pythonhosted.org/packages/13/d4/27a7b5af0b33f6d61e198faf177fbbf3cb83ff10d9d1a6857b7efc525ad5/numcodecs-0.13.1-cp311-cp311-win_amd64.whl", hash = "sha256:796b3e6740107e4fa624cc636248a1580138b3f1c579160f260f76ff13a4261b", size = 829603 }, + { url = "https://files.pythonhosted.org/packages/37/3a/bc09808425e7d3df41e5fc73fc7a802c429ba8c6b05e55f133654ade019d/numcodecs-0.13.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5195bea384a6428f8afcece793860b1ab0ae28143c853f0b2b20d55a8947c917", size = 1575806 }, + { url = "https://files.pythonhosted.org/packages/3a/cc/dc74d0bfdf9ec192332a089d199f1e543e747c556b5659118db7a437dcca/numcodecs-0.13.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3501a848adaddce98a71a262fee15cd3618312692aa419da77acd18af4a6a3f6", size = 1178233 }, + { url = "https://files.pythonhosted.org/packages/d4/ce/434e8e3970b8e92ae9ab6d9db16cb9bc7aa1cd02e17c11de6848224100a1/numcodecs-0.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da2230484e6102e5fa3cc1a5dd37ca1f92dfbd183d91662074d6f7574e3e8f53", size = 8857827 }, + { url = "https://files.pythonhosted.org/packages/83/e7/1d8b1b266a92f9013c755b1c146c5ad71a2bff147ecbc67f86546a2e4d6a/numcodecs-0.13.1-cp312-cp312-win_amd64.whl", hash = "sha256:e5db4824ebd5389ea30e54bc8aeccb82d514d28b6b68da6c536b8fa4596f4bca", size = 826539 }, + { url = "https://files.pythonhosted.org/packages/83/8b/06771dead2cc4a8ae1ea9907737cf1c8d37a323392fa28f938a586373468/numcodecs-0.13.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7a60d75179fd6692e301ddfb3b266d51eb598606dcae7b9fc57f986e8d65cb43", size = 1571660 }, + { url = "https://files.pythonhosted.org/packages/f9/ea/d925bf85f92dfe4635356018da9fe4bfecb07b1c72f62b01c1bc47f936b1/numcodecs-0.13.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3f593c7506b0ab248961a3b13cb148cc6e8355662ff124ac591822310bc55ecf", size = 1169925 }, + { url = "https://files.pythonhosted.org/packages/0f/d6/643a3839d571d8e439a2c77dc4b0b8cab18d96ac808e4a81dbe88e959ab6/numcodecs-0.13.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80d3071465f03522e776a31045ddf2cfee7f52df468b977ed3afdd7fe5869701", size = 8814257 }, + { url = "https://files.pythonhosted.org/packages/a6/c5/f3e56bc9b4e438a287fff738993d6d11abef368c0328a612ac2842ba9fca/numcodecs-0.13.1-cp313-cp313-win_amd64.whl", hash = "sha256:90d3065ae74c9342048ae0046006f99dcb1388b7288da5a19b3bddf9c30c3176", size = 821887 }, +] + +[[package]] +name = "numcodecs" +version = "0.15.1" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13'", + "python_full_version == '3.12.*'", + "python_full_version == '3.11.*'", +] +dependencies = [ + { name = "deprecated", marker = "python_full_version >= '3.11'" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and python_full_version < '3.13'" }, + { name = "numpy", version = "2.4.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/63/fc/bb532969eb8236984ba65e4f0079a7da885b8ac0ce1f0835decbb3938a62/numcodecs-0.15.1.tar.gz", hash = "sha256:eeed77e4d6636641a2cc605fbc6078c7a8f2cc40f3dfa2b3f61e52e6091b04ff", size = 6267275 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/e4/fc/410f1cacaef0931f5daf06813b1b8a2442f7418ee284ec73fe5e830dca48/numcodecs-0.15.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:698f1d59511488b8fe215fadc1e679a4c70d894de2cca6d8bf2ab770eed34dfd", size = 1649501 }, + { url = "https://files.pythonhosted.org/packages/85/29/dff62fae04323035912c419a82dc9624fad7d08541dbfcd9ab78a3a40074/numcodecs-0.15.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bef8c8e64fab76677324a07672b10c31861775d03fc63ed5012ca384144e4bb9", size = 1187306 }, + { url = "https://files.pythonhosted.org/packages/a6/a8/908a226632ffabf19caf8c99f1b2898f2f22aac02795a6fe9d018fd6d9dd/numcodecs-0.15.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cdfaef9f5f2ed8f65858db801f1953f1007c9613ee490a1c56233cd78b505ed5", size = 8891971 }, + { url = "https://files.pythonhosted.org/packages/2b/e8/058aac43e1300d588e99b2d0d5b771c8a43fa92ce9c9517da596869fc146/numcodecs-0.15.1-cp311-cp311-win_amd64.whl", hash = "sha256:e2547fa3a7ffc9399cfd2936aecb620a3db285f2630c86c8a678e477741a4b3c", size = 840035 }, + { url = "https://files.pythonhosted.org/packages/e7/7e/f12fc32d3beedc6a8f1ec69ea0ba72e93cb99c0350feed2cff5d04679bc3/numcodecs-0.15.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b0a9d9cd29a0088220682dda4a9898321f7813ff7802be2bbb545f6e3d2f10ff", size = 1691889 }, + { url = "https://files.pythonhosted.org/packages/81/38/88e40d40288b73c3b3a390ed5614a34b0661d00255bdd4cfb91c32101364/numcodecs-0.15.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a34f0fe5e5f3b837bbedbeb98794a6d4a12eeeef8d4697b523905837900b5e1c", size = 1189149 }, + { url = "https://files.pythonhosted.org/packages/28/7d/7527d9180bc76011d6163c848c9cf02cd28a623c2c66cf543e1e86de7c5e/numcodecs-0.15.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3a09e22140f2c691f7df26303ff8fa2dadcf26d7d0828398c0bc09b69e5efa3", size = 8879163 }, + { url = "https://files.pythonhosted.org/packages/ab/bc/b6c3cde91c754860a3467a8c058dcf0b1a5ca14d82b1c5397c700cf8b1eb/numcodecs-0.15.1-cp312-cp312-win_amd64.whl", hash = "sha256:daed6066ffcf40082da847d318b5ab6123d69ceb433ba603cb87c323a541a8bc", size = 836785 }, + { url = "https://files.pythonhosted.org/packages/78/57/acbc54b3419e5be65015e47177c76c0a73e037fd3ae2cde5808169194d4d/numcodecs-0.15.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e3d82b70500cf61e8d115faa0d0a76be6ecdc24a16477ee3279d711699ad85f3", size = 1688220 }, + { url = "https://files.pythonhosted.org/packages/b6/56/9863fa6dc679f40a31bea5e9713ee5507a31dcd3ee82ea4b1a9268ce52e8/numcodecs-0.15.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1d471a1829ce52d3f365053a2bd1379e32e369517557c4027ddf5ac0d99c591e", size = 1180294 }, + { url = "https://files.pythonhosted.org/packages/fa/91/d96999b41e3146b6c0ce6bddc5ad85803cb4d743c95394562c2a4bb8cded/numcodecs-0.15.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1dfdea4a67108205edfce99c1cb6cd621343bc7abb7e16a041c966776920e7de", size = 8834323 }, + { url = "https://files.pythonhosted.org/packages/c3/32/233e5ede6568bdb044e6f99aaa9fa39827ff3109c6487fc137315f733586/numcodecs-0.15.1-cp313-cp313-win_amd64.whl", hash = "sha256:a4f7bdb26f1b34423cb56d48e75821223be38040907c9b5954eeb7463e7eb03c", size = 831955 }, +] + [[package]] name = "numexpr" version = "2.10.2" @@ -1327,6 +1444,11 @@ sas = [ { name = "pyreadstat" }, { name = "sas7bdat" }, ] +zarr = [ + { name = "zarr", version = "2.18.2", source = { registry = "https://pypi.org/simple" }, marker = 
"python_full_version < '3.10'" }, + { name = "zarr", version = "2.18.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.10.*'" }, + { name = "zarr", version = "2.18.7", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, +] [package.metadata] requires-dist = [ @@ -1364,8 +1486,9 @@ requires-dist = [ { name = "tabulate", specifier = ">=0.9.0,<0.10.0" }, { name = "weightedcalcs", specifier = ">=0.1.2,<0.2.0" }, { name = "wquantiles", specifier = ">=0.6,<0.7" }, + { name = "zarr", marker = "extra == 'zarr'", specifier = ">=2.18.0,<3.0" }, ] -provides-extras = ["matching", "dev", "casd", "sas"] +provides-extras = ["zarr", "matching", "dev", "casd", "sas"] [[package]] name = "packaging" @@ -2430,6 +2553,144 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f7/75/3cce30508bf46121b7cabce57b9cacbf8d935fa555cb3c5fca43f8dd0414/wquantiles-0.6-py3-none-any.whl", hash = "sha256:1b90d68fa05251bb96f8806a346e8d7dec9a9bb99f381ad5094707b46ab85218", size = 3291 }, ] +[[package]] +name = "wrapt" +version = "2.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f7/37/ae31f40bec90de2f88d9597d0b5281e23ffe85b893a47ca5d9c05c63a4f6/wrapt-2.1.1.tar.gz", hash = "sha256:5fdcb09bf6db023d88f312bd0767594b414655d58090fc1c46b3414415f67fac", size = 81329 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ca/21/293b657a27accfbbbb6007ebd78af0efa2083dac83e8f523272ea09b4638/wrapt-2.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7e927375e43fd5a985b27a8992327c22541b6dede1362fc79df337d26e23604f", size = 60554 }, + { url = "https://files.pythonhosted.org/packages/25/e9/96dd77728b54a899d4ce2798d7b1296989ce687ed3c0cb917d6b3154bf5d/wrapt-2.1.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e1c99544b6a7d40ca22195563b6d8bc3986ee8bb82f272f31f0670fe9440c869", size = 61496 }, + { url = "https://files.pythonhosted.org/packages/44/79/4c755b45df6ef30c0dd628ecfaa0c808854be147ca438429da70a162833c/wrapt-2.1.1-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b2be3fa5f4efaf16ee7c77d0556abca35f5a18ad4ac06f0ef3904c3399010ce9", size = 113528 }, + { url = "https://files.pythonhosted.org/packages/9f/63/23ce28f7b841217d9a6337a340fbb8d4a7fbd67a89d47f377c8550fa34aa/wrapt-2.1.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:67c90c1ae6489a6cb1a82058902caa8006706f7b4e8ff766f943e9d2c8e608d0", size = 115536 }, + { url = "https://files.pythonhosted.org/packages/23/7b/5ca8d3b12768670d16c8329e29960eedd56212770365a02a8de8bf73dc01/wrapt-2.1.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:05c0db35ccffd7480143e62df1e829d101c7b86944ae3be7e4869a7efa621f53", size = 114716 }, + { url = "https://files.pythonhosted.org/packages/c7/3a/9789ccb14a096d30bb847bf3ee137bf682cc9750c2ce155f4c5ae1962abf/wrapt-2.1.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:0c2ec9f616755b2e1e0bf4d0961f59bb5c2e7a77407e7e2c38ef4f7d2fdde12c", size = 113200 }, + { url = "https://files.pythonhosted.org/packages/cf/e5/4ec3526ce6ce920b267c8d35d2c2f0874d3fad2744c8b7259353f1132baa/wrapt-2.1.1-cp310-cp310-win32.whl", hash = "sha256:203ba6b3f89e410e27dbd30ff7dccaf54dcf30fda0b22aa1b82d560c7f9fe9a1", size = 57876 }, + { url = "https://files.pythonhosted.org/packages/d1/4e/661c7c76ecd85375b2bc03488941a3a1078642af481db24949e2b9de01f4/wrapt-2.1.1-cp310-cp310-win_amd64.whl", hash = 
"sha256:6f9426d9cfc2f8732922fc96198052e55c09bb9db3ddaa4323a18e055807410e", size = 60224 }, + { url = "https://files.pythonhosted.org/packages/5f/b7/53c7252d371efada4cb119e72e774fa2c6b3011fc33e3e552cdf48fb9488/wrapt-2.1.1-cp310-cp310-win_arm64.whl", hash = "sha256:69c26f51b67076b40714cff81bdd5826c0b10c077fb6b0678393a6a2f952a5fc", size = 58645 }, + { url = "https://files.pythonhosted.org/packages/b8/a8/9254e4da74b30a105935197015b18b31b7a298bf046e67d8952ef74967bd/wrapt-2.1.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6c366434a7fb914c7a5de508ed735ef9c133367114e1a7cb91dfb5cd806a1549", size = 60554 }, + { url = "https://files.pythonhosted.org/packages/9e/a1/378579880cc7af226354054a2c255f69615b379d8adad482bfe2f22a0dc2/wrapt-2.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5d6a2068bd2e1e19e5a317c8c0b288267eec4e7347c36bc68a6e378a39f19ee7", size = 61491 }, + { url = "https://files.pythonhosted.org/packages/dc/72/957b51c56acca35701665878ad31626182199fc4afecfe67dea072210f95/wrapt-2.1.1-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:891ab4713419217b2aed7dd106c9200f64e6a82226775a0d2ebd6bef2ebd1747", size = 113949 }, + { url = "https://files.pythonhosted.org/packages/cd/74/36bbebb4a3d2ae9c3e6929639721f8606cd0710a82a777c371aa69e36504/wrapt-2.1.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c8ef36a0df38d2dc9d907f6617f89e113c5892e0a35f58f45f75901af0ce7d81", size = 115989 }, + { url = "https://files.pythonhosted.org/packages/ae/0d/f1177245a083c7be284bc90bddfe5aece32cdd5b858049cb69ce001a0e8d/wrapt-2.1.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:76e9af3ebd86f19973143d4d592cbf3e970cf3f66ddee30b16278c26ae34b8ab", size = 115242 }, + { url = "https://files.pythonhosted.org/packages/62/3e/3b7cf5da27e59df61b1eae2d07dd03ff5d6f75b5408d694873cca7a8e33c/wrapt-2.1.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ff562067485ebdeaef2fa3fe9b1876bc4e7b73762e0a01406ad81e2076edcebf", size = 113676 }, + { url = "https://files.pythonhosted.org/packages/f7/65/8248d3912c705f2c66f81cb97c77436f37abcbedb16d633b5ab0d795d8cd/wrapt-2.1.1-cp311-cp311-win32.whl", hash = "sha256:9e60a30aa0909435ec4ea2a3c53e8e1b50ac9f640c0e9fe3f21fd248a22f06c5", size = 57863 }, + { url = "https://files.pythonhosted.org/packages/6b/31/d29310ab335f71f00c50466153b3dc985aaf4a9fc03263e543e136859541/wrapt-2.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:7d79954f51fcf84e5ec4878ab4aea32610d70145c5bbc84b3370eabfb1e096c2", size = 60224 }, + { url = "https://files.pythonhosted.org/packages/0c/90/a6ec319affa6e2894962a0cb9d73c67f88af1a726d15314bfb5c88b8a08d/wrapt-2.1.1-cp311-cp311-win_arm64.whl", hash = "sha256:d3ffc6b0efe79e08fd947605fd598515aebefe45e50432dc3b5cd437df8b1ada", size = 58643 }, + { url = "https://files.pythonhosted.org/packages/df/cb/4d5255d19bbd12be7f8ee2c1fb4269dddec9cef777ef17174d357468efaa/wrapt-2.1.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:ab8e3793b239db021a18782a5823fcdea63b9fe75d0e340957f5828ef55fcc02", size = 61143 }, + { url = "https://files.pythonhosted.org/packages/6f/07/7ed02daa35542023464e3c8b7cb937fa61f6c61c0361ecf8f5fecf8ad8da/wrapt-2.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7c0300007836373d1c2df105b40777986accb738053a92fe09b615a7a4547e9f", size = 61740 }, + { url = "https://files.pythonhosted.org/packages/c4/60/a237a4e4a36f6d966061ccc9b017627d448161b19e0a3ab80a7c7c97f859/wrapt-2.1.1-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = 
"sha256:2b27c070fd1132ab23957bcd4ee3ba707a91e653a9268dc1afbd39b77b2799f7", size = 121327 }, + { url = "https://files.pythonhosted.org/packages/ae/fe/9139058a3daa8818fc67e6460a2340e8bbcf3aef8b15d0301338bbe181ca/wrapt-2.1.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8b0e36d845e8b6f50949b6b65fc6cd279f47a1944582ed4ec8258cd136d89a64", size = 122903 }, + { url = "https://files.pythonhosted.org/packages/91/10/b8479202b4164649675846a531763531f0a6608339558b5a0a718fc49a8d/wrapt-2.1.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4aeea04a9889370fcfb1ef828c4cc583f36a875061505cd6cd9ba24d8b43cc36", size = 121333 }, + { url = "https://files.pythonhosted.org/packages/5f/75/75fc793b791d79444aca2c03ccde64e8b99eda321b003f267d570b7b0985/wrapt-2.1.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d88b46bb0dce9f74b6817bc1758ff2125e1ca9e1377d62ea35b6896142ab6825", size = 120458 }, + { url = "https://files.pythonhosted.org/packages/d7/8f/c3f30d511082ca6d947c405f9d8f6c8eaf83cfde527c439ec2c9a30eb5ea/wrapt-2.1.1-cp312-cp312-win32.whl", hash = "sha256:63decff76ca685b5c557082dfbea865f3f5f6d45766a89bff8dc61d336348833", size = 58086 }, + { url = "https://files.pythonhosted.org/packages/0a/c8/37625b643eea2849f10c3b90f69c7462faa4134448d4443234adaf122ae5/wrapt-2.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:b828235d26c1e35aca4107039802ae4b1411be0fe0367dd5b7e4d90e562fcbcd", size = 60328 }, + { url = "https://files.pythonhosted.org/packages/ce/79/56242f07572d5682ba8065a9d4d9c2218313f576e3c3471873c2a5355ffd/wrapt-2.1.1-cp312-cp312-win_arm64.whl", hash = "sha256:75128507413a9f1bcbe2db88fd18fbdbf80f264b82fa33a6996cdeaf01c52352", size = 58722 }, + { url = "https://files.pythonhosted.org/packages/f7/ca/3cf290212855b19af9fcc41b725b5620b32f470d6aad970c2593500817eb/wrapt-2.1.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ce9646e17fa7c3e2e7a87e696c7de66512c2b4f789a8db95c613588985a2e139", size = 61150 }, + { url = "https://files.pythonhosted.org/packages/9d/33/5b8f89a82a9859ce82da4870c799ad11ce15648b6e1c820fec3e23f4a19f/wrapt-2.1.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:428cfc801925454395aa468ba7ddb3ed63dc0d881df7b81626cdd433b4e2b11b", size = 61743 }, + { url = "https://files.pythonhosted.org/packages/1e/2f/60c51304fbdf47ce992d9eefa61fbd2c0e64feee60aaa439baf42ea6f40b/wrapt-2.1.1-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:5797f65e4d58065a49088c3b32af5410751cd485e83ba89e5a45e2aa8905af98", size = 121341 }, + { url = "https://files.pythonhosted.org/packages/ad/03/ce5256e66dd94e521ad5e753c78185c01b6eddbed3147be541f4d38c0cb7/wrapt-2.1.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5a2db44a71202c5ae4bb5f27c6d3afbc5b23053f2e7e78aa29704541b5dad789", size = 122947 }, + { url = "https://files.pythonhosted.org/packages/eb/ae/50ca8854b81b946a11a36fcd6ead32336e6db2c14b6e4a8b092b80741178/wrapt-2.1.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:8d5350c3590af09c1703dd60ec78a7370c0186e11eaafb9dda025a30eee6492d", size = 121370 }, + { url = "https://files.pythonhosted.org/packages/fb/d9/d6a7c654e0043319b4cc137a4caaf7aa16b46b51ee8df98d1060254705b7/wrapt-2.1.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:2d9b076411bed964e752c01b49fd224cc385f3a96f520c797d38412d70d08359", size = 120465 }, + { url = "https://files.pythonhosted.org/packages/55/90/65be41e40845d951f714b5a77e84f377a3787b1e8eee6555a680da6d0db5/wrapt-2.1.1-cp313-cp313-win32.whl", 
hash = "sha256:0bb7207130ce6486727baa85373503bf3334cc28016f6928a0fa7e19d7ecdc06", size = 58090 }, + { url = "https://files.pythonhosted.org/packages/5f/66/6a09e0294c4fc8c26028a03a15191721c9271672467cc33e6617ee0d91d2/wrapt-2.1.1-cp313-cp313-win_amd64.whl", hash = "sha256:cbfee35c711046b15147b0ae7db9b976f01c9520e6636d992cd9e69e5e2b03b1", size = 60341 }, + { url = "https://files.pythonhosted.org/packages/7a/f0/20ceb8b701e9a71555c87a5ddecbed76ec16742cf1e4b87bbaf26735f998/wrapt-2.1.1-cp313-cp313-win_arm64.whl", hash = "sha256:7d2756061022aebbf57ba14af9c16e8044e055c22d38de7bf40d92b565ecd2b0", size = 58731 }, + { url = "https://files.pythonhosted.org/packages/80/b4/fe95beb8946700b3db371f6ce25115217e7075ca063663b8cca2888ba55c/wrapt-2.1.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:4814a3e58bc6971e46baa910ecee69699110a2bf06c201e24277c65115a20c20", size = 62969 }, + { url = "https://files.pythonhosted.org/packages/b8/89/477b0bdc784e3299edf69c279697372b8bd4c31d9c6966eae405442899df/wrapt-2.1.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:106c5123232ab9b9f4903692e1fa0bdc231510098f04c13c3081f8ad71c3d612", size = 63606 }, + { url = "https://files.pythonhosted.org/packages/ed/55/9d0c1269ab76de87715b3b905df54dd25d55bbffd0b98696893eb613469f/wrapt-2.1.1-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:1a40b83ff2535e6e56f190aff123821eea89a24c589f7af33413b9c19eb2c738", size = 152536 }, + { url = "https://files.pythonhosted.org/packages/44/18/2004766030462f79ad86efaa62000b5e39b1ff001dcce86650e1625f40ae/wrapt-2.1.1-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:789cea26e740d71cf1882e3a42bb29052bc4ada15770c90072cb47bf73fb3dbf", size = 158697 }, + { url = "https://files.pythonhosted.org/packages/e1/bb/0a880fa0f35e94ee843df4ee4dd52a699c9263f36881311cfb412c09c3e5/wrapt-2.1.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:ba49c14222d5e5c0ee394495a8655e991dc06cbca5398153aefa5ac08cd6ccd7", size = 155563 }, + { url = "https://files.pythonhosted.org/packages/42/ff/cd1b7c4846c8678fac359a6eb975dc7ab5bd606030adb22acc8b4a9f53f1/wrapt-2.1.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:ac8cda531fe55be838a17c62c806824472bb962b3afa47ecbd59b27b78496f4e", size = 150161 }, + { url = "https://files.pythonhosted.org/packages/38/ec/67c90a7082f452964b4621e4890e9a490f1add23cdeb7483cc1706743291/wrapt-2.1.1-cp313-cp313t-win32.whl", hash = "sha256:b8af75fe20d381dd5bcc9db2e86a86d7fcfbf615383a7147b85da97c1182225b", size = 59783 }, + { url = "https://files.pythonhosted.org/packages/ec/08/466afe4855847d8febdfa2c57c87e991fc5820afbdef01a273683dfd15a0/wrapt-2.1.1-cp313-cp313t-win_amd64.whl", hash = "sha256:45c5631c9b6c792b78be2d7352129f776dd72c605be2c3a4e9be346be8376d83", size = 63082 }, + { url = "https://files.pythonhosted.org/packages/9a/62/60b629463c28b15b1eeadb3a0691e17568622b12aa5bfa7ebe9b514bfbeb/wrapt-2.1.1-cp313-cp313t-win_arm64.whl", hash = "sha256:da815b9263947ac98d088b6414ac83507809a1d385e4632d9489867228d6d81c", size = 60251 }, + { url = "https://files.pythonhosted.org/packages/95/a0/1c2396e272f91efe6b16a6a8bce7ad53856c8f9ae4f34ceaa711d63ec9e1/wrapt-2.1.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:9aa1765054245bb01a37f615503290d4e207e3fd59226e78341afb587e9c1236", size = 61311 }, + { url = "https://files.pythonhosted.org/packages/b0/9a/d2faba7e61072a7507b5722db63562fdb22f5a24e237d460d18755627f15/wrapt-2.1.1-cp314-cp314-macosx_11_0_arm64.whl", hash = 
"sha256:feff14b63a6d86c1eee33a57f77573649f2550935981625be7ff3cb7342efe05", size = 61805 }, + { url = "https://files.pythonhosted.org/packages/db/56/073989deb4b5d7d6e7ea424476a4ae4bda02140f2dbeaafb14ba4864dd60/wrapt-2.1.1-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:81fc5f22d5fcfdbabde96bb3f5379b9f4476d05c6d524d7259dc5dfb501d3281", size = 120308 }, + { url = "https://files.pythonhosted.org/packages/d1/b6/84f37261295e38167a29eb82affaf1dc15948dc416925fe2091beee8e4ac/wrapt-2.1.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:951b228ecf66def855d22e006ab9a1fc12535111ae7db2ec576c728f8ddb39e8", size = 122688 }, + { url = "https://files.pythonhosted.org/packages/ea/80/32db2eec6671f80c65b7ff175be61bc73d7f5223f6910b0c921bbc4bd11c/wrapt-2.1.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:0ddf582a95641b9a8c8bd643e83f34ecbbfe1b68bc3850093605e469ab680ae3", size = 121115 }, + { url = "https://files.pythonhosted.org/packages/49/ef/dcd00383df0cd696614127902153bf067971a5aabcd3c9dcb2d8ef354b2a/wrapt-2.1.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:fc5c500966bf48913f795f1984704e6d452ba2414207b15e1f8c339a059d5b16", size = 119484 }, + { url = "https://files.pythonhosted.org/packages/76/29/0630280cdd2bd8f86f35cb6854abee1c9d6d1a28a0c6b6417cd15d378325/wrapt-2.1.1-cp314-cp314-win32.whl", hash = "sha256:4aa4baadb1f94b71151b8e44a0c044f6af37396c3b8bcd474b78b49e2130a23b", size = 58514 }, + { url = "https://files.pythonhosted.org/packages/db/19/5bed84f9089ed2065f6aeda5dfc4f043743f642bc871454b261c3d7d322b/wrapt-2.1.1-cp314-cp314-win_amd64.whl", hash = "sha256:860e9d3fd81816a9f4e40812f28be4439ab01f260603c749d14be3c0a1170d19", size = 60763 }, + { url = "https://files.pythonhosted.org/packages/e4/cb/b967f2f9669e4249b4fe82e630d2a01bc6b9e362b9b12ed91bbe23ae8df4/wrapt-2.1.1-cp314-cp314-win_arm64.whl", hash = "sha256:3c59e103017a2c1ea0ddf589cbefd63f91081d7ce9d491d69ff2512bb1157e23", size = 59051 }, + { url = "https://files.pythonhosted.org/packages/eb/19/6fed62be29f97eb8a56aff236c3f960a4b4a86e8379dc7046a8005901a97/wrapt-2.1.1-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:9fa7c7e1bee9278fc4f5dd8275bc8d25493281a8ec6c61959e37cc46acf02007", size = 63059 }, + { url = "https://files.pythonhosted.org/packages/0a/1c/b757fd0adb53d91547ed8fad76ba14a5932d83dde4c994846a2804596378/wrapt-2.1.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:39c35e12e8215628984248bd9c8897ce0a474be2a773db207eb93414219d8469", size = 63618 }, + { url = "https://files.pythonhosted.org/packages/10/fe/e5ae17b1480957c7988d991b93df9f2425fc51f128cf88144d6a18d0eb12/wrapt-2.1.1-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:94ded4540cac9125eaa8ddf5f651a7ec0da6f5b9f248fe0347b597098f8ec14c", size = 152544 }, + { url = "https://files.pythonhosted.org/packages/3e/cc/99aed210c6b547b8a6e4cb9d1425e4466727158a6aeb833aa7997e9e08dd/wrapt-2.1.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:da0af328373f97ed9bdfea24549ac1b944096a5a71b30e41c9b8b53ab3eec04a", size = 158700 }, + { url = "https://files.pythonhosted.org/packages/81/0e/d442f745f4957944d5f8ad38bc3a96620bfff3562533b87e486e979f3d99/wrapt-2.1.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:4ad839b55f0bf235f8e337ce060572d7a06592592f600f3a3029168e838469d3", size = 155561 }, + { url = 
"https://files.pythonhosted.org/packages/51/ac/9891816280e0018c48f8dfd61b136af7b0dcb4a088895db2531acde5631b/wrapt-2.1.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:0d89c49356e5e2a50fa86b40e0510082abcd0530f926cbd71cf25bee6b9d82d7", size = 150188 }, + { url = "https://files.pythonhosted.org/packages/24/98/e2f273b6d70d41f98d0739aa9a269d0b633684a5fb17b9229709375748d4/wrapt-2.1.1-cp314-cp314t-win32.whl", hash = "sha256:f4c7dd22cf7f36aafe772f3d88656559205c3af1b7900adfccb70edeb0d2abc4", size = 60425 }, + { url = "https://files.pythonhosted.org/packages/1e/06/b500bfc38a4f82d89f34a13069e748c82c5430d365d9e6b75afb3ab74457/wrapt-2.1.1-cp314-cp314t-win_amd64.whl", hash = "sha256:f76bc12c583ab01e73ba0ea585465a41e48d968f6d1311b4daec4f8654e356e3", size = 63855 }, + { url = "https://files.pythonhosted.org/packages/d9/cc/5f6193c32166faee1d2a613f278608e6f3b95b96589d020f0088459c46c9/wrapt-2.1.1-cp314-cp314t-win_arm64.whl", hash = "sha256:7ea74fc0bec172f1ae5f3505b6655c541786a5cabe4bbc0d9723a56ac32eb9b9", size = 60443 }, + { url = "https://files.pythonhosted.org/packages/08/3e/144e085a4a237b60a1b41f56e8a173e5e4f21f42a201e43f8d38272b4772/wrapt-2.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9e03b3d486eb39f5d3f562839f59094dcee30c4039359ea15768dc2214d9e07c", size = 60552 }, + { url = "https://files.pythonhosted.org/packages/69/25/576fa5d1e8c0b2657ed411b947bb50c7cc56a0a882fbd1b04574803e668a/wrapt-2.1.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0fdf3073f488ce4d929929b7799e3b8c52b220c9eb3f4a5a51e2dc0e8ff07881", size = 61498 }, + { url = "https://files.pythonhosted.org/packages/48/01/37def21f806dee9db8c12f99b872b3cdf15215bafe3919c982968134b804/wrapt-2.1.1-cp39-cp39-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:0cb4f59238c6625fae2eeb72278da31c9cfba0ff4d9cbe37446b73caa0e9bcf7", size = 113232 }, + { url = "https://files.pythonhosted.org/packages/bf/ee/31dfda37ae75db11cc46634aa950c3497f7a8f987d811388bf1b11fe2f80/wrapt-2.1.1-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7f794a1c148871b714cb566f5466ec8288e0148a1c417550983864b3981737cd", size = 115198 }, + { url = "https://files.pythonhosted.org/packages/93/d5/43cb27a2d7142bdbe9700099e7261fdc264f63c6b60a8025dd5f8af157da/wrapt-2.1.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:95ef3866631c6da9ce1fc0f1e17b90c4c0aa6d041fc70a11bc90733aee122e1a", size = 114400 }, + { url = "https://files.pythonhosted.org/packages/61/91/8429803605df5540b918fe6fc9ffc4f167270f4b7ca1f82eaf7d7b1204b6/wrapt-2.1.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:66bc1b2446f01cbbd3c56b79a3a8435bcd4178ac4e06b091913f7751a7f528b8", size = 112998 }, + { url = "https://files.pythonhosted.org/packages/7e/6a/25cb316f3e8262a1626da71b2c299ae2be02fb0547028eac9aa21daeedda/wrapt-2.1.1-cp39-cp39-win32.whl", hash = "sha256:1b9e08e57cabc32972f7c956d10e85093c5da9019faa24faf411e7dd258e528c", size = 57871 }, + { url = "https://files.pythonhosted.org/packages/09/69/ffd41e6149ac4bd9700552659842383f44eb96f00e03c2db433bc856bf2f/wrapt-2.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:e75ad48c3cca739f580b5e14c052993eb644c7fa5b4c90aa51193280b30875ae", size = 60222 }, + { url = "https://files.pythonhosted.org/packages/59/f0/1889e68a0d389d2552b9e014ed6471addcfab98f09611bac61a8d8fab223/wrapt-2.1.1-cp39-cp39-win_arm64.whl", hash = "sha256:9ccd657873b7f964711447d004563a2bc08d1476d7a1afcad310f3713e6f50f4", size = 58647 }, + { url = 
"https://files.pythonhosted.org/packages/c4/da/5a086bf4c22a41995312db104ec2ffeee2cf6accca9faaee5315c790377d/wrapt-2.1.1-py3-none-any.whl", hash = "sha256:3b0f4629eb954394a3d7c7a1c8cca25f0b07cefe6aa8545e862e9778152de5b7", size = 43886 }, +] + +[[package]] +name = "zarr" +version = "2.18.2" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.10'", +] +dependencies = [ + { name = "asciitree", marker = "python_full_version < '3.10'" }, + { name = "fasteners", marker = "python_full_version < '3.10' and sys_platform != 'emscripten'" }, + { name = "numcodecs", version = "0.12.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b3/00/ac5c518ff1c1b1cc87a62f86ad9d19c647c19d969a91faa40d3b6342ccaa/zarr-2.18.2.tar.gz", hash = "sha256:9bb393b8a0a38fb121dbb913b047d75db28de9890f6d644a217a73cf4ae74f47", size = 3603055 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5d/bd/8d881d8ca6d80fcb8da2b2f94f8855384daf649499ddfba78ffd1ee2caa3/zarr-2.18.2-py3-none-any.whl", hash = "sha256:a638754902f97efa99b406083fdc807a0e2ccf12a949117389d2a4ba9b05df38", size = 210228 }, +] + +[[package]] +name = "zarr" +version = "2.18.3" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version == '3.10.*'", +] +dependencies = [ + { name = "asciitree", marker = "python_full_version == '3.10.*'" }, + { name = "fasteners", marker = "python_full_version == '3.10.*' and sys_platform != 'emscripten'" }, + { name = "numcodecs", version = "0.13.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.10.*'" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.10.*'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/23/c4/187a21ce7cf7c8f00c060dd0e04c2a81139bb7b1ab178bba83f2e1134ce2/zarr-2.18.3.tar.gz", hash = "sha256:2580d8cb6dd84621771a10d31c4d777dca8a27706a1a89b29f42d2d37e2df5ce", size = 3603224 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ed/c9/142095e654c2b97133ff71df60979422717b29738b08bc8a1709a5d5e0d0/zarr-2.18.3-py3-none-any.whl", hash = "sha256:b1f7dfd2496f436745cdd4c7bcf8d3b4bc1dceef5fdd0d589c87130d842496dd", size = 210723 }, +] + +[[package]] +name = "zarr" +version = "2.18.7" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13'", + "python_full_version == '3.12.*'", + "python_full_version == '3.11.*'", +] +dependencies = [ + { name = "asciitree", marker = "python_full_version >= '3.11'" }, + { name = "fasteners", marker = "python_full_version >= '3.11' and sys_platform != 'emscripten'" }, + { name = "numcodecs", version = "0.15.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and python_full_version < '3.13'" }, + { name = "numpy", version = "2.4.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/da/1d/01cf9e3ab2d85190278efc3fca9f68563de35ae30ee59e7640e3af98abe3/zarr-2.18.7.tar.gz", hash = 
"sha256:b2b8f66f14dac4af66b180d2338819981b981f70e196c9a66e6bfaa9e59572f5", size = 3604558 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5e/d8/9ffd8c237b3559945bb52103cf0eed64ea098f7b7f573f8d2962ef27b4b2/zarr-2.18.7-py3-none-any.whl", hash = "sha256:ac3dc4033e9ae4e9d7b5e27c97ea3eaf1003cc0a07f010bd83d5134bf8c4b223", size = 211273 }, +] + [[package]] name = "zipp" version = "3.23.0" From 37bb0d9ed7f1b1767f2086362e2f9dd76d97a4cc Mon Sep 17 00:00:00 2001 From: Mahdi Ben Jelloul Date: Tue, 10 Mar 2026 21:27:48 +0100 Subject: [PATCH 17/18] Bump version --- .gitignore | 1 + CHANGELOG.md | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index b188e6f0..b619b74b 100644 --- a/.gitignore +++ b/.gitignore @@ -102,3 +102,4 @@ tags # Tests files *.parquet test_*.json +.claude/settings.json diff --git a/CHANGELOG.md b/CHANGELOG.md index b70d7fd4..4cbe1dd6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ -# Changelog +# Changelog -# 1.1.0 +# 8.0.0 * **Store backends** (choix du format de stockage des tables) - **io/backends**: Backends HDF5, Parquet et Zarr (abstraction `StoreBackend`) ; `get_backend(name)`, `get_available_backend_names()`, `register_backend()` pour étendre. @@ -16,9 +16,9 @@ - **Script de migration** : infère `store_format` depuis le JSON legacy (`parquet_file_path` / `zarr_file_path` / `hdf5_file_path`) et l’écrit dans le manifest généré. - **RFC-002** : exemple de manifest avec `store_format` ; section 3.5 et 4.2 mises à jour. -# 1.0.0 +# 7.0.0 -* **Breaking**: Version 1.0 — retrait des ré-exports et des DeprecationWarning +* **Breaking**: Version 7.0 — retrait des ré-exports et des DeprecationWarning - **Suppression des modules de compatibilité** : `config`, `paths`, `tables`, `surveys`, `survey_collections`, `read_sas`, `read_spss`, `read_dbf`, `calibration`, `calmar`, `utils` sont supprimés. Utiliser les imports canoniques (voir `docs/MIGRATION_IMPORTS.md`). - **`load_table`** : déplacé de `utils` vers `openfisca_survey_manager.core.dataset` (et exporté depuis `core`). - Tous les imports internes ont été migrés vers `configuration.paths`, `configuration.models`, `core.dataset`, `core.survey`, `core.table`, `io.readers`, `processing.weights`, `common.misc`. From 3430a1204cba4b704230a6d23c0c3da53cbd68ea Mon Sep 17 00:00:00 2001 From: Mahdi Ben Jelloul Date: Tue, 10 Mar 2026 21:27:56 +0100 Subject: [PATCH 18/18] Again --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 55480ebb..0e6c7805 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "OpenFisca-Survey-Manager" -version = "1.1.0" +version = "8.0.0" description = "A tool for managing survey/administrative data and import them in OpenFisca" readme = "README.md" keywords = ["microsimulation", "tax", "benefit", "rac", "rules-as-code", "survey", "data"]