diff --git a/.gitignore b/.gitignore index 737887f7f..848ed6ceb 100644 --- a/.gitignore +++ b/.gitignore @@ -212,3 +212,5 @@ __marimo__/ # Claude .claude/ +*notebooks/nsc_sftp_automated_data_ingestion/tmp/ +*notebooks/nsc_sftp_automated_data_ingestion/gcp_config.yaml \ No newline at end of file diff --git a/notebooks/nsc_sftp_automated_data_ingestion/01_sftp_receive_scan.ipynb b/notebooks/nsc_sftp_automated_data_ingestion/01_sftp_receive_scan.ipynb new file mode 100644 index 000000000..6a9e361ec --- /dev/null +++ b/notebooks/nsc_sftp_automated_data_ingestion/01_sftp_receive_scan.ipynb @@ -0,0 +1,377 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "7dc0a9a7-1db8-42b9-b0c4-07946f392d5e", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "# 1. Connect to SFTP and scan the receive folder for files.\n", + "# 2. Upsert unseen files into `ingestion_manifest` with status=NEW.\n", + "# 3. Download and stage NEW + unqueued files locally and upsert them into `pending_ingest_queue`.\n", + "\n", + "# Recent refactor:\n", + "# - SFTP helpers moved to `helper.py` (`connect_sftp`, `list_receive_files`, `download_sftp_atomic`).\n", + "# - `list_receive_files` now takes `source_system` explicitly (no hidden notebook globals).\n", + "\n", + "# Constraints:\n", + "# - SFTP connection required\n", + "# - NO API calls\n", + "# - Stages files to UC volume (CATALOG.default.tmp) + writes to Delta tables only\n", + "\n", + "# Inputs:\n", + "# - SFTP folder: `./receive`\n", + "# - Required workflow parameters (exact SFTP file names):\n", + "# - `cohort_file_name`\n", + "# - `course_file_name`\n", + "# - Both file names must end with the same 14-digit file stamp (e.g. 
`..._YYYYMMDDHHMMSS.csv`).\n", + "\n", + "# Outputs:\n", + "# - `CATALOG.default.ingestion_manifest`\n", + "# - `CATALOG.default.pending_ingest_queue`\n", + "# - Staged files written to UC Volume: `CATALOG.default.tmp` (path `/Volumes//default/tmp`)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "cbd7694b-4b30-41bf-9371-259479726010", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "%pip install paramiko python-box pyyaml\n", + "%pip install git+https://github.com/datakind/edvise.git@Automated_Ingestion_Workflow" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "b9ae88af-ade1-4df0-86a0-34d6d492383a", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "%restart_python" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "5888f9b8-bda7-4586-9f9f-ed1243d878de", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "import logging\n", + "import os\n", + "import re\n", + "from databricks.connect import DatabricksSession\n", + "from pyspark.sql import functions as F\n", + "\n", + "from edvise.utils.sftp import connect_sftp, list_receive_files\n", + "from edvise.ingestion.constants import (\n", + " MANIFEST_TABLE_PATH,\n", + " QUEUE_TABLE_PATH,\n", + " SFTP_REMOTE_FOLDER,\n", + " SFTP_SOURCE_SYSTEM,\n", + " SFTP_TMP_DIR,\n", + ")\n", + "from edvise.ingestion.nsc_sftp_helpers import (\n", + " 
build_listing_df,\n", + " download_new_files_and_queue,\n", + " ensure_manifest_and_queue_tables,\n", + " get_files_to_queue,\n", + " upsert_new_to_manifest,\n", + ")\n", + "from edvise import utils\n", + "\n", + "try:\n", + " dbutils # noqa: F821\n", + "except NameError:\n", + " from unittest.mock import MagicMock\n", + "\n", + " dbutils = MagicMock()\n", + "spark = DatabricksSession.builder.getOrCreate()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "61b348b8-aa62-4b5a-9442-d48d52e1a862", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "logging.basicConfig(\n", + " level=logging.INFO,\n", + " format=\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\",\n", + ")\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "asset_scope = \"nsc-sftp-asset\"\n", + "\n", + "host = dbutils.secrets.get(scope=asset_scope, key=\"nsc-sftp-host\")\n", + "user = dbutils.secrets.get(scope=asset_scope, key=\"nsc-sftp-user\")\n", + "password = dbutils.secrets.get(scope=asset_scope, key=\"nsc-sftp-password\")\n", + "\n", + "cohort_file_name = utils.databricks.get_db_widget_param(\"cohort_file_name\", default=\"\")\n", + "course_file_name = utils.databricks.get_db_widget_param(\"course_file_name\", default=\"\")\n", + "cohort_file_name = str(cohort_file_name).strip()\n", + "course_file_name = str(course_file_name).strip()\n", + "if not cohort_file_name or not course_file_name:\n", + " raise ValueError(\n", + " \"Missing required workflow parameters: cohort_file_name and course_file_name. 
\"\n", + " \"Pass them as Databricks job base parameters.\"\n", + " )\n", + "\n", + "\n", + "def _extract_file_stamp(file_name: str) -> str:\n", + " base = os.path.basename(file_name)\n", + " m = re.search(r\"_(\\d{14})(?:\\.[^.]+)?$\", base)\n", + " if not m:\n", + " raise ValueError(\n", + " \"Expected file name to end with a 14-digit file stamp, e.g. \"\n", + " \"'..._YYYYMMDDHHMMSS.csv'. Got: \"\n", + " f\"{file_name}\"\n", + " )\n", + " return m.group(1)\n", + "\n", + "\n", + "cohort_stamp = _extract_file_stamp(cohort_file_name)\n", + "course_stamp = _extract_file_stamp(course_file_name)\n", + "if cohort_stamp != course_stamp:\n", + " raise ValueError(\n", + " \"cohort_file_name and course_file_name must end with the same file stamp. \"\n", + " f\"Got cohort stamp={cohort_stamp}, course stamp={course_stamp}.\"\n", + " )\n", + "logger.info(f\"Validated file stamp: {cohort_stamp}\")\n", + "logger.info(f\"Staging to UC volume path: {SFTP_TMP_DIR}\")\n", + "logger.info(\n", + " \"Manual file selection enabled: \"\n", + " f\"cohort_file_name={cohort_file_name}, course_file_name={course_file_name}\"\n", + ")\n", + "\n", + "logger.info(\"SFTP secured assets loaded successfully.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "80968f66-5082-49ca-b03f-b3a1ef0bb908", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "transport = None\n", + "sftp = None\n", + "\n", + "try:\n", + " ensure_manifest_and_queue_tables(spark)\n", + "\n", + " transport, sftp = connect_sftp(host, user, password)\n", + " logger.info(\n", + " f\"Connected to SFTP host={host} and scanning folder={SFTP_REMOTE_FOLDER}\"\n", + " )\n", + "\n", + " file_rows_all = list_receive_files(sftp, SFTP_REMOTE_FOLDER, SFTP_SOURCE_SYSTEM)\n", + " if not file_rows_all:\n", + " 
logger.info(\n", + " f\"No files found in SFTP folder: {SFTP_REMOTE_FOLDER}. Exiting (no-op).\"\n", + " )\n", + " dbutils.notebook.exit(\"NO_FILES\")\n", + "\n", + " requested_names = {cohort_file_name, course_file_name}\n", + " logger.info(\n", + " f\"Found {len(file_rows_all)} file(s) on SFTP in folder={SFTP_REMOTE_FOLDER}; \"\n", + " f\"requested={sorted(requested_names)}\"\n", + " )\n", + " file_rows = [r for r in file_rows_all if r.get(\"file_name\") in requested_names]\n", + "\n", + " found_names = {r.get(\"file_name\") for r in file_rows}\n", + " missing_names = sorted(requested_names - found_names)\n", + " if missing_names:\n", + " available = sorted({r.get(\"file_name\") for r in file_rows_all})\n", + " preview = available[:25]\n", + " raise FileNotFoundError(\n", + " f\"Requested file(s) not found on SFTP in folder '{SFTP_REMOTE_FOLDER}': {missing_names}. \"\n", + " f\"Available file count={len(available)}; first 25={preview}\"\n", + " )\n", + "\n", + " for r in file_rows:\n", + " logger.info(\n", + " f\"Selected SFTP file: name={r.get('file_name')} size={r.get('file_size')} \"\n", + " f\"modified={r.get('file_modified_time')}\"\n", + " )\n", + "\n", + " df_listing = build_listing_df(spark, file_rows)\n", + " fingerprints = [\n", + " r[\"file_fingerprint\"] for r in df_listing.select(\"file_fingerprint\").collect()\n", + " ]\n", + "\n", + " logger.info(\"SFTP listing (selected files):\")\n", + " df_listing.select(\n", + " \"file_name\", \"file_size\", \"file_modified_time\", \"file_fingerprint\"\n", + " ).show(truncate=False)\n", + "\n", + " # 1) Ensure everything on SFTP is at least represented in manifest as NEW\n", + " upsert_new_to_manifest(spark, df_listing)\n", + "\n", + " logger.info(\"Manifest rows (selected files):\")\n", + " spark.table(MANIFEST_TABLE_PATH).where(\n", + " F.col(\"file_fingerprint\").isin(fingerprints)\n", + " ).select(\n", + " \"file_name\",\n", + " \"file_fingerprint\",\n", + " \"status\",\n", + " \"processed_at\",\n", + " 
\"error_message\",\n", + " ).show(truncate=False)\n", + "\n", + " # 2) Queue anything that is still NEW and not already queued\n", + " df_to_queue = get_files_to_queue(spark, df_listing)\n", + "\n", + " to_queue_count = df_to_queue.count()\n", + " if to_queue_count == 0:\n", + " logger.info(\n", + " \"No files to queue: either nothing is NEW, or NEW files are already queued. Exiting (no-op).\"\n", + " )\n", + " dbutils.notebook.exit(\"QUEUED_FILES=0\")\n", + "\n", + " logger.info(\"Files eligible to queue:\")\n", + " df_to_queue.select(\n", + " \"file_name\", \"file_size\", \"file_modified_time\", \"file_fingerprint\"\n", + " ).show(truncate=False)\n", + "\n", + " logger.info(\n", + " f\"Queuing {to_queue_count} NEW-unqueued file(s) to {QUEUE_TABLE_PATH} and staging to UC volume.\"\n", + " )\n", + " queued_count = download_new_files_and_queue(spark, sftp, df_to_queue, logger)\n", + "\n", + " logger.info(\"Queue rows (selected files):\")\n", + " spark.table(QUEUE_TABLE_PATH).where(\n", + " F.col(\"file_fingerprint\").isin(fingerprints)\n", + " ).select(\"file_name\", \"file_fingerprint\", \"local_tmp_path\", \"queued_at\").show(\n", + " truncate=False\n", + " )\n", + "\n", + " logger.info(\n", + " f\"Queued {queued_count} file(s) for downstream processing in {QUEUE_TABLE_PATH}.\"\n", + " )\n", + " dbutils.notebook.exit(f\"QUEUED_FILES={queued_count}\")\n", + "\n", + "finally:\n", + " try:\n", + " if sftp is not None:\n", + " sftp.close()\n", + " except Exception:\n", + " pass\n", + " try:\n", + " if transport is not None:\n", + " transport.close()\n", + " except Exception:\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "edff98e1-0862-4e41-8c35-bd5fb6647136", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + 
"application/vnd.databricks.v1+notebook": { + "computePreferences": null, + "dashboards": [], + "environmentMetadata": { + "base_environment": "", + "environment_version": "4" + }, + "inputWidgetPreferences": null, + "language": "python", + "notebookMetadata": { + "pythonIndentUnit": 4 + }, + "notebookName": "01_sftp_receive_scan", + "widgets": {} + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/notebooks/nsc_sftp_automated_data_ingestion/02_file_institution_expand.ipynb b/notebooks/nsc_sftp_automated_data_ingestion/02_file_institution_expand.ipynb new file mode 100644 index 000000000..9e3c409c0 --- /dev/null +++ b/notebooks/nsc_sftp_automated_data_ingestion/02_file_institution_expand.ipynb @@ -0,0 +1,386 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 1. Read each *staged* local file (from `pending_ingest_queue`), detect the institution id column,\n", + "# 2. 
extract unique institution IDs, and emit per-institution work items.\n", + "\n", + "# Constraints:\n", + "# - NO SFTP connection\n", + "# - NO API calls\n", + "# - NO volume writes\n", + "\n", + "# Input table:\n", + "# - `staging_sst_01.default.pending_ingest_queue`\n", + "\n", + "# Output table:\n", + "# - `staging_sst_01.default.institution_ingest_plan`\n", + "# - Columns: `file_fingerprint`, `file_name`, `local_path`, `institution_id`, `inst_col`, `file_size`, `file_modified_time`, `planned_at`\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "679b2064-2a15-4d89-abda-5e9c0148ff61", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "%pip install pandas python-box pyyaml paramiko\n", + "%pip install git+https://github.com/datakind/edvise.git@Automated_Ingestion_Workflow" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%restart_python" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "62608829-5027-4075-a4fc-1e4afc36ef3a", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "import logging\n", + "import os\n", + "import re\n", + "from datetime import datetime, timezone\n", + "\n", + "from pyspark.sql import functions as F\n", + "from pyspark.sql import types as T\n", + "from databricks.connect import DatabricksSession\n", + "\n", + "from edvise.ingestion.nsc_sftp_helpers import ensure_plan_table, extract_institution_ids\n", + "from edvise.ingestion.constants import (\n", + " QUEUE_TABLE_PATH,\n", + " PLAN_TABLE_PATH,\n", + " COLUMN_RENAMES,\n", + 
" INSTITUTION_COLUMN_PATTERN,\n", + ")\n", + "\n", + "try:\n", + " dbutils # noqa: F821\n", + "except NameError:\n", + " from unittest.mock import MagicMock\n", + "\n", + " dbutils = MagicMock()\n", + "spark = DatabricksSession.builder.getOrCreate()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "64156fce-07a6-4eb6-8612-6b29bc06edfe", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "logging.basicConfig(\n", + " level=logging.INFO,\n", + " format=\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\",\n", + ")\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "INST_COL_PATTERN = re.compile(INSTITUTION_COLUMN_PATTERN, re.IGNORECASE)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "87047914-fec0-4f35-b33f-d1b927605d11", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "ensure_plan_table(spark, PLAN_TABLE_PATH)\n", + "\n", + "# Pull queued staged files (Script 1 output)\n", + "if not spark.catalog.tableExists(QUEUE_TABLE_PATH):\n", + " logger.info(f\"Queue table {QUEUE_TABLE_PATH} not found. Exiting (no-op).\")\n", + " dbutils.notebook.exit(\"NO_QUEUE_TABLE\")\n", + "\n", + "queue_df = spark.read.table(QUEUE_TABLE_PATH)\n", + "\n", + "if queue_df.limit(1).count() == 0:\n", + " logger.info(\"pending_ingest_queue is empty. 
Exiting (no-op).\")\n", + " dbutils.notebook.exit(\"NO_QUEUED_FILES\")" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "21683394-0bec-42b8-82dd-1a4590519de5", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "# Avoid regenerating plans for files already expanded\n", + "existing_fp = (\n", + " spark.table(PLAN_TABLE_PATH).select(\"file_fingerprint\").distinct()\n", + " if spark.catalog.tableExists(PLAN_TABLE_PATH)\n", + " else None\n", + ")\n", + "if existing_fp is not None:\n", + " queue_df = queue_df.join(existing_fp, on=\"file_fingerprint\", how=\"left_anti\")\n", + "\n", + "if queue_df.limit(1).count() == 0:\n", + " logger.info(\n", + " \"All queued files have already been expanded into institution work items. Exiting (no-op).\"\n", + " )\n", + " dbutils.notebook.exit(\"NO_NEW_EXPANSION_WORK\")\n", + "\n", + "logger.info(\"Queued files to expand preview (after excluding already-expanded):\")\n", + "queue_df.select(\"file_fingerprint\", \"file_name\", \"local_tmp_path\", \"queued_at\").show(\n", + " 25, truncate=False\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "540c7880-f14a-4607-979a-856f17066c50", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "queued_files = queue_df.select(\n", + " \"file_fingerprint\",\n", + " \"file_name\",\n", + " F.col(\"local_tmp_path\").alias(\"local_path\"),\n", + " \"file_size\",\n", + " \"file_modified_time\",\n", + ").collect()\n", + "\n", + "logger.info(\n", + " f\"Expanding {len(queued_files)} staged file(s) into per-institution work items...\"\n", + ")\n", 
+ "\n", + "work_items = []\n", + "missing_files = []\n", + "\n", + "for r in queued_files:\n", + " fp = r[\"file_fingerprint\"]\n", + " file_name = r[\"file_name\"]\n", + " local_path = r[\"local_path\"]\n", + "\n", + " if not local_path or not os.path.exists(local_path):\n", + " missing_files.append((fp, file_name, local_path))\n", + " continue\n", + "\n", + " try:\n", + " inst_col, inst_ids = extract_institution_ids(\n", + " local_path, renames=COLUMN_RENAMES, inst_col_pattern=INST_COL_PATTERN\n", + " )\n", + " if inst_col is None:\n", + " logger.warning(\n", + " f\"No institution id column found for file={file_name} fp={fp}. Skipping this file.\"\n", + " )\n", + " continue\n", + "\n", + " if not inst_ids:\n", + " logger.warning(\n", + " f\"Institution column found but no IDs present for file={file_name} fp={fp}. Skipping.\"\n", + " )\n", + " continue\n", + "\n", + " now_ts = datetime.now(timezone.utc)\n", + " for inst_id in inst_ids:\n", + " work_items.append(\n", + " {\n", + " \"file_fingerprint\": fp,\n", + " \"file_name\": file_name,\n", + " \"local_path\": local_path,\n", + " \"institution_id\": inst_id,\n", + " \"inst_col\": inst_col,\n", + " \"file_size\": r[\"file_size\"],\n", + " \"file_modified_time\": r[\"file_modified_time\"],\n", + " \"planned_at\": now_ts,\n", + " }\n", + " )\n", + "\n", + " preview_ids = inst_ids[:10]\n", + " logger.info(\n", + " f\"file={file_name} fp={fp}: found {len(inst_ids)} institution id(s) using column '{inst_col}'. 
\"\n", + " f\"Preview first 10 IDs={preview_ids}\"\n", + " )\n", + "\n", + " except Exception as e:\n", + " logger.exception(f\"Failed expanding file={file_name} fp={fp}: {e}\")\n", + " # We don't write manifests here per your division; fail fast so workflow can surface issue.\n", + " raise" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "32d5bc9c-16a1-42b4-adef-f1a442e5d447", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "if missing_files:\n", + " # This usually indicates the staged files were cleaned up or the staging path\n", + " # is not accessible from this cluster.\n", + " # Fail fast so the workflow stops (downstream cannot proceed without the staged files).\n", + " msg = (\n", + " \"Some staged files are missing on disk (staging path missing/inaccessible). \"\n", + " + \"; \".join([f\"fp={fp} file={fn} path={lp}\" for fp, fn, lp in missing_files])\n", + " )\n", + " logger.error(msg)\n", + " raise FileNotFoundError(msg)\n", + "\n", + "if not work_items:\n", + " logger.info(\"No work items generated from staged files. 
Exiting (no-op).\")\n", + " dbutils.notebook.exit(\"NO_WORK_ITEMS\")\n", + "\n", + "schema = T.StructType(\n", + " [\n", + " T.StructField(\"file_fingerprint\", T.StringType(), False),\n", + " T.StructField(\"file_name\", T.StringType(), False),\n", + " T.StructField(\"local_path\", T.StringType(), False),\n", + " T.StructField(\"institution_id\", T.StringType(), False),\n", + " T.StructField(\"inst_col\", T.StringType(), False),\n", + " T.StructField(\"file_size\", T.LongType(), True),\n", + " T.StructField(\"file_modified_time\", T.TimestampType(), True),\n", + " T.StructField(\"planned_at\", T.TimestampType(), False),\n", + " ]\n", + ")\n", + "\n", + "df_plan = spark.createDataFrame(work_items, schema=schema)\n", + "\n", + "logger.info(\"Work items summary by file (distinct institutions):\")\n", + "df_plan.groupBy(\"file_name\").agg(\n", + " F.countDistinct(\"institution_id\").alias(\"institution_count\")\n", + ").orderBy(\"file_name\").show(truncate=False)\n", + "\n", + "df_plan.createOrReplaceTempView(\"incoming_plan_rows\")\n", + "\n", + "# Idempotent upsert: unique per (file_fingerprint, institution_id)\n", + "spark.sql(\n", + " f\"\"\"\n", + " MERGE INTO {PLAN_TABLE_PATH} AS t\n", + " USING incoming_plan_rows AS s\n", + " ON t.file_fingerprint = s.file_fingerprint\n", + " AND t.institution_id = s.institution_id\n", + " WHEN MATCHED THEN UPDATE SET\n", + " t.file_name = s.file_name,\n", + " t.local_path = s.local_path,\n", + " t.inst_col = s.inst_col,\n", + " t.file_size = s.file_size,\n", + " t.file_modified_time = s.file_modified_time,\n", + " t.planned_at = s.planned_at\n", + " WHEN NOT MATCHED THEN INSERT *\n", + " \"\"\"\n", + ")\n", + "\n", + "count_out = df_plan.count()\n", + "logger.info(\n", + " f\"Wrote/updated {count_out} institution work item(s) into {PLAN_TABLE_PATH}.\"\n", + ")\n", + "dbutils.notebook.exit(f\"WORK_ITEMS={count_out}\")" + ] + } + ], + "metadata": { + "application/vnd.databricks.v1+notebook": { + "computePreferences": null, + 
"dashboards": [], + "environmentMetadata": { + "base_environment": "", + "environment_version": "4" + }, + "inputWidgetPreferences": null, + "language": "python", + "notebookMetadata": { + "pythonIndentUnit": 4 + }, + "notebookName": "02_file_institution_expand", + "widgets": {} + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/notebooks/nsc_sftp_automated_data_ingestion/03_per_institution_bronze_ingest.ipynb b/notebooks/nsc_sftp_automated_data_ingestion/03_per_institution_bronze_ingest.ipynb new file mode 100644 index 000000000..58c25716d --- /dev/null +++ b/notebooks/nsc_sftp_automated_data_ingestion/03_per_institution_bronze_ingest.ipynb @@ -0,0 +1,459 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "0ed056e5-420d-4b47-8812-cf63f1f895c3", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "# Databricks notebook source\n", + "# Script 4 — 04_per_institution_bronze_ingest\n", + "#\n", + "# Purpose:\n", + "# Consume institution_ingest_plan (created by Script 3), and for each (file × institution):\n", + "# - get bearer token from SST staging using X-API-KEY (from Databricks secrets)\n", + "# - call /api/v1/institutions/pdp-id/{pdp_id} to resolve institution name\n", + "# - map name -> schema prefix via databricksify_inst_name()\n", + "# - locate _bronze schema in staging_sst_01\n", + "# - choose a volume in that schema containing \"bronze\"\n", + "# - filter rows by institution id (exactly like current script)\n", + "# - write to bronze volume using helper.process_and_save_file (exact same ingestion method)\n", + "# After all institutions for a file are processed, update ingestion_manifest:\n", + "# - BRONZE_WRITTEN if all institution ingests succeeded (or were already 
present)\n", + "# - FAILED if any error occurred for that file (store error_message)\n", + "#\n", + "# Constraints:\n", + "# - NO SFTP connection (uses staged local files from Script 1/3)\n", + "# - Uses existing ingestion function + behavior from current script\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "de7936c9-a18c-4a87-858a-2c15045481d0", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "%pip install pandas python-box pyyaml requests paramiko\n", + "%pip install git+https://github.com/datakind/edvise.git@Automated_Ingestion_Workflow" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%restart_python" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "83538ecc-3986-46a8-a755-fb037fee8039", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "import logging\n", + "import os\n", + "\n", + "import pandas as pd\n", + "from databricks.connect import DatabricksSession\n", + "\n", + "from pyspark.sql import functions as F\n", + "\n", + "from edvise.utils.api_requests import (\n", + " EdviseAPIClient,\n", + " fetch_institution_by_pdp_id,\n", + ")\n", + "from edvise.utils.data_cleaning import convert_to_snake_case\n", + "from edvise.utils.databricks import (\n", + " find_bronze_schema,\n", + " find_bronze_volume_name,\n", + " databricksify_inst_name,\n", + ")\n", + "from edvise.utils.sftp import output_file_name_from_sftp\n", + "from edvise.ingestion.nsc_sftp_helpers import (\n", + " process_and_save_file,\n", + " update_manifest,\n", + ")\n", + 
"from edvise.ingestion.constants import (\n", + " CATALOG,\n", + " PLAN_TABLE_PATH,\n", + " MANIFEST_TABLE_PATH,\n", + " SST_BASE_URL,\n", + " SST_TOKEN_ENDPOINT,\n", + " INSTITUTION_LOOKUP_PATH,\n", + " SST_API_KEY_SECRET_KEY,\n", + " COLUMN_RENAMES,\n", + ")\n", + "\n", + "try:\n", + " dbutils # noqa: F821\n", + "except NameError:\n", + " from unittest.mock import MagicMock\n", + "\n", + " dbutils = MagicMock()\n", + "\n", + "try:\n", + " display # noqa: F821\n", + "except NameError:\n", + "\n", + " def display(x):\n", + " return x\n", + "\n", + "\n", + "spark = DatabricksSession.builder.getOrCreate()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "7aea7d3e-2734-40ed-ae5c-a32e67ce3541", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "logging.basicConfig(\n", + " level=logging.INFO,\n", + " format=\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\",\n", + ")\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "asset_scope = \"nsc-sftp-asset\"\n", + "SST_API_KEY = dbutils.secrets.get(scope=asset_scope, key=SST_API_KEY_SECRET_KEY).strip()\n", + "if not SST_API_KEY:\n", + " raise RuntimeError(\n", + " f\"Empty SST API key from secrets: scope={asset_scope} key={SST_API_KEY_SECRET_KEY}\"\n", + " )\n", + "\n", + "api_client = EdviseAPIClient(\n", + " api_key=SST_API_KEY,\n", + " base_url=SST_BASE_URL,\n", + " token_endpoint=SST_TOKEN_ENDPOINT,\n", + " institution_lookup_path=INSTITUTION_LOOKUP_PATH,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "1a0c7f38-ab8f-4a54-a778-6c2e79b5044d", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": 
"" + } + }, + "outputs": [], + "source": [ + "if not spark.catalog.tableExists(PLAN_TABLE_PATH):\n", + " logger.info(f\"Plan table not found: {PLAN_TABLE_PATH}. Exiting (no-op).\")\n", + " dbutils.notebook.exit(\"NO_PLAN_TABLE\")\n", + "\n", + "if not spark.catalog.tableExists(MANIFEST_TABLE_PATH):\n", + " raise RuntimeError(f\"Manifest table missing: {MANIFEST_TABLE_PATH}\")\n", + "\n", + "plan_df = spark.table(PLAN_TABLE_PATH)\n", + "if plan_df.limit(1).count() == 0:\n", + " logger.info(\"institution_ingest_plan is empty. Exiting (no-op).\")\n", + " dbutils.notebook.exit(\"NO_WORK_ITEMS\")\n", + "\n", + "manifest_df = spark.table(MANIFEST_TABLE_PATH).select(\"file_fingerprint\", \"status\")\n", + "plan_new_df = plan_df.join(manifest_df, on=\"file_fingerprint\", how=\"inner\").where(\n", + " F.col(\"status\") == F.lit(\"NEW\")\n", + ")\n", + "if plan_new_df.limit(1).count() == 0:\n", + " logger.info(\"No planned work items where manifest status=NEW. Exiting (no-op).\")\n", + " dbutils.notebook.exit(\"NO_NEW_TO_INGEST\")\n", + "\n", + "plan_summary_df = (\n", + " plan_new_df.groupBy(\"file_name\", \"inst_col\", \"local_path\")\n", + " .agg(F.countDistinct(\"institution_id\").alias(\"institution_count\"))\n", + " .orderBy(\"file_name\")\n", + ")\n", + "logger.info(\"Planned work summary (manifest status=NEW):\")\n", + "display(plan_summary_df)\n", + "\n", + "# Collect file groups\n", + "file_groups = (\n", + " plan_new_df.select(\n", + " \"file_fingerprint\",\n", + " \"file_name\",\n", + " \"local_path\",\n", + " \"inst_col\",\n", + " \"file_size\",\n", + " \"file_modified_time\",\n", + " )\n", + " .distinct()\n", + " .collect()\n", + ")\n", + "\n", + "logger.info(f\"Preparing to ingest {len(file_groups)} NEW file(s).\")" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": 
"cf0729e1-7a4f-402a-85b6-1bca3696f878", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, + "outputs": [], + "source": [ + "# ---------------------------\n", + "# Main per-file ingest loop\n", + "# ---------------------------\n", + "processed_files = 0\n", + "failed_files = 0\n", + "skipped_files = 0\n", + "\n", + "for fg in file_groups:\n", + " fp = fg[\"file_fingerprint\"]\n", + " sftp_file_name = fg[\"file_name\"]\n", + " local_path = fg[\"local_path\"]\n", + " inst_col = fg[\"inst_col\"]\n", + "\n", + " if not local_path or not os.path.exists(local_path):\n", + " err = f\"Staged local file missing for fp={fp}: {local_path}\"\n", + " logger.error(err)\n", + " update_manifest(\n", + " spark, MANIFEST_TABLE_PATH, fp, status=\"FAILED\", error_message=err[:8000]\n", + " )\n", + " failed_files += 1\n", + " continue\n", + "\n", + " try:\n", + " # Read only the institution-id column as string at load time to avoid float promotion\n", + " header_cols = pd.read_csv(local_path, nrows=0).columns.tolist()\n", + " raw_inst_col = next(\n", + " (\n", + " c\n", + " for c in header_cols\n", + " if COLUMN_RENAMES.get(\n", + " convert_to_snake_case(c), convert_to_snake_case(c)\n", + " )\n", + " == inst_col\n", + " ),\n", + " None,\n", + " )\n", + " dtype = {raw_inst_col: str} if raw_inst_col else None\n", + " df_full = pd.read_csv(local_path, on_bad_lines=\"warn\", dtype=dtype)\n", + " df_full = df_full.rename(\n", + " columns={c: convert_to_snake_case(c) for c in df_full.columns}\n", + " )\n", + " df_full = df_full.rename(columns=COLUMN_RENAMES)\n", + "\n", + " if inst_col not in df_full.columns:\n", + " err = f\"Expected institution column '{inst_col}' not found after normalization/renames for file={sftp_file_name} fp={fp}\"\n", + " logger.error(err)\n", + " update_manifest(\n", + " spark,\n", + " MANIFEST_TABLE_PATH,\n", + " fp,\n", + " status=\"FAILED\",\n", + " error_message=err[:8000],\n", + " )\n", + " failed_files += 1\n", + " continue\n", + 
"\n", + " inst_ids = (\n", + " plan_new_df.where(F.col(\"file_fingerprint\") == fp)\n", + " .select(\"institution_id\")\n", + " .distinct()\n", + " .collect()\n", + " )\n", + " inst_ids = [r[\"institution_id\"] for r in inst_ids]\n", + "\n", + " if not inst_ids:\n", + " logger.info(\n", + " f\"No institution_ids in plan for file={sftp_file_name} fp={fp}. Marking BRONZE_WRITTEN (no-op).\"\n", + " )\n", + " update_manifest(\n", + " spark,\n", + " MANIFEST_TABLE_PATH,\n", + " fp,\n", + " status=\"BRONZE_WRITTEN\",\n", + " error_message=None,\n", + " )\n", + " skipped_files += 1\n", + " continue\n", + "\n", + " preview_inst_ids = inst_ids[:10]\n", + " logger.info(\n", + " f\"file={sftp_file_name} fp={fp}: ingesting {len(inst_ids)} institution(s) \"\n", + " f\"using inst_col='{inst_col}'. Preview first 10 IDs={preview_inst_ids}\"\n", + " )\n", + "\n", + " # Aggregate errors at file-level\n", + " file_errors = []\n", + "\n", + " for inst_id in inst_ids:\n", + " try:\n", + " target_inst_id = str(inst_id)\n", + " filtered_df = df_full[df_full[inst_col] == target_inst_id].reset_index(\n", + " drop=True\n", + " )\n", + "\n", + " if filtered_df.empty:\n", + " logger.info(\n", + " f\"file={sftp_file_name} fp={fp}: institution {inst_id} has 0 rows; skipping.\"\n", + " )\n", + " continue\n", + "\n", + " # Resolve institution -> name\n", + " inst_info = fetch_institution_by_pdp_id(api_client, inst_id)\n", + " inst_name = inst_info.get(\"name\")\n", + " if not inst_name:\n", + " raise ValueError(\n", + " f\"SST API returned no 'name' for pdp_id={inst_id}. 
Response={inst_info}\"\n", + " )\n", + "\n", + " inst_prefix = databricksify_inst_name(inst_name)\n", + "\n", + " # Find bronze schema + volume\n", + " bronze_schema = find_bronze_schema(spark, CATALOG, inst_prefix)\n", + " bronze_volume_name = find_bronze_volume_name(\n", + " spark, CATALOG, bronze_schema\n", + " )\n", + " volume_dir = f\"/Volumes/{CATALOG}/{bronze_schema}/{bronze_volume_name}\"\n", + "\n", + " # Output naming rule (same as current script)\n", + " out_file_name = output_file_name_from_sftp(sftp_file_name)\n", + " full_path = os.path.join(volume_dir, out_file_name)\n", + "\n", + " # Idempotency check\n", + " if os.path.exists(full_path):\n", + " logger.info(\n", + " f\"file={sftp_file_name} inst={inst_id}: already exists in {volume_dir}; skipping write.\"\n", + " )\n", + " continue\n", + "\n", + " logger.info(\n", + " f\"file={sftp_file_name} inst={inst_id}: writing to {volume_dir} as {out_file_name}\"\n", + " )\n", + " process_and_save_file(\n", + " volume_dir=volume_dir, file_name=out_file_name, df=filtered_df\n", + " )\n", + " logger.info(f\"file={sftp_file_name} inst={inst_id}: write complete.\")\n", + "\n", + " except Exception as e:\n", + " msg = f\"inst_ingest_failed file={sftp_file_name} fp={fp} inst={inst_id}: {e}\"\n", + " logger.exception(msg)\n", + " file_errors.append(msg)\n", + "\n", + " if file_errors:\n", + " err = \" | \".join(file_errors)[:8000]\n", + " update_manifest(\n", + " spark, MANIFEST_TABLE_PATH, fp, status=\"FAILED\", error_message=err\n", + " )\n", + " failed_files += 1\n", + " else:\n", + " update_manifest(\n", + " spark,\n", + " MANIFEST_TABLE_PATH,\n", + " fp,\n", + " status=\"BRONZE_WRITTEN\",\n", + " error_message=None,\n", + " )\n", + " processed_files += 1\n", + "\n", + " except Exception as e:\n", + " msg = f\"fatal_file_error file={sftp_file_name} fp={fp}: {e}\"\n", + " logger.exception(msg)\n", + " update_manifest(\n", + " spark, MANIFEST_TABLE_PATH, fp, status=\"FAILED\", error_message=msg[:8000]\n", + " 
)\n", + " failed_files += 1\n", + "\n", + "logger.info(\n", + " f\"Done. processed_files={processed_files}, failed_files={failed_files}, skipped_files={skipped_files}\"\n", + ")\n", + "dbutils.notebook.exit(\n", + " f\"PROCESSED={processed_files};FAILED={failed_files};SKIPPED={skipped_files}\"\n", + ")" + ] + } + ], + "metadata": { + "application/vnd.databricks.v1+notebook": { + "computePreferences": null, + "dashboards": [], + "environmentMetadata": { + "base_environment": "", + "environment_version": "4" + }, + "inputWidgetPreferences": null, + "language": "python", + "notebookMetadata": { + "pythonIndentUnit": 4 + }, + "notebookName": "03_per_institution_bronze_ingest", + "widgets": {} + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/pyproject.toml b/pyproject.toml index cf7e01088..4cbbef48e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,3 +101,7 @@ ignore_missing_imports = true follow_imports = "silent" # in case of irreconcilable differences, consider telling mypy to ignore all errors # ignore_errors = true + +[[tool.mypy.overrides]] +module = "paramiko" +ignore_missing_imports = true diff --git a/src/edvise/ingestion/__init__.py b/src/edvise/ingestion/__init__.py new file mode 100644 index 000000000..8df7508bf --- /dev/null +++ b/src/edvise/ingestion/__init__.py @@ -0,0 +1 @@ +"""Data ingestion utilities for various data sources.""" diff --git a/src/edvise/ingestion/constants.py b/src/edvise/ingestion/constants.py new file mode 100644 index 000000000..8eef55f54 --- /dev/null +++ b/src/edvise/ingestion/constants.py @@ -0,0 +1,91 @@ +""" +Constants for NSC SFTP ingestion pipeline. + +These values are fixed and don't vary between runs or environments. +For environment-specific values (like secret scope names), see gcp_config.yaml. 
+""" + +from typing import Any +from unittest.mock import MagicMock + +dbutils: Any + +# Databricks catalog and schema +try: + from databricks.sdk.runtime import dbutils as _dbutils +except Exception: + # Local/offline context: allow imports/tests to run without Databricks. + dbutils = MagicMock() + CATALOG = "dev_sst_02" +else: + dbutils = _dbutils + try: + workspace_id = str( + dbutils.notebook.entry_point.getDbutils() + .notebook() + .getContext() + .workspaceId() + .get() + ) + except Exception: + # Databricks SDK is importable, but we're not running in a notebook/runtime + # context where workspace ID is available. + dbutils = MagicMock() + CATALOG = "dev_sst_02" + else: + if workspace_id == "4437281602191762": + CATALOG = "dev_sst_02" + elif workspace_id == "2052166062819251": + CATALOG = "staging_sst_01" + else: + raise RuntimeError( + f"Unsupported Databricks workspace_id={workspace_id!r} for NSC ingestion. " + "Add a mapping in src/edvise/ingestion/constants.py." + ) +DEFAULT_SCHEMA = "default" + +# Table names (without catalog.schema prefix) +MANIFEST_TABLE = "ingestion_manifest" +QUEUE_TABLE = "pending_ingest_queue" +PLAN_TABLE = "institution_ingest_plan" + +# Full table paths +MANIFEST_TABLE_PATH = f"{CATALOG}.{DEFAULT_SCHEMA}.{MANIFEST_TABLE}" +QUEUE_TABLE_PATH = f"{CATALOG}.{DEFAULT_SCHEMA}.{QUEUE_TABLE}" +PLAN_TABLE_PATH = f"{CATALOG}.{DEFAULT_SCHEMA}.{PLAN_TABLE}" + +# SFTP settings +SFTP_REMOTE_FOLDER = "./receive" +SFTP_SOURCE_SYSTEM = "NSC" +SFTP_PORT = 22 +SFTP_TMP_VOLUME_NAME = "tmp" +SFTP_TMP_VOLUME_FQN = f"{CATALOG}.{DEFAULT_SCHEMA}.{SFTP_TMP_VOLUME_NAME}" +SFTP_TMP_DIR = f"/Volumes/{CATALOG}/{DEFAULT_SCHEMA}/{SFTP_TMP_VOLUME_NAME}" +SFTP_DOWNLOAD_CHUNK_MB = 150 +SFTP_VERIFY_DOWNLOAD = "size" # Options: "size", "sha256", "md5", "none" + +# Edvise API settings +SST_BASE_URL = "https://staging-sst.datakind.org" +SST_TOKEN_ENDPOINT = f"{SST_BASE_URL}/api/v1/token-from-api-key" +INSTITUTION_LOOKUP_PATH = "/api/v1/institutions/pdp-id/{pdp_id}" 
+SST_API_KEY_SECRET_KEY = "sst_staging_api_key" # Key name in Databricks secrets + +# File processing settings +INSTITUTION_COLUMN_PATTERN = r"(?=.*institution)(?=.*id)" + +# Column name mappings (mangled -> normalized) +# Applied after snake_case conversion +COLUMN_RENAMES = { + # NOTE: convert_to_snake_case splits trailing digit groups with an underscore, + # e.g. "attemptedgatewaymathyear1" -> "attemptedgatewaymathyear_1". + "attemptedgatewaymathyear_1": "attempted_gateway_math_year_1", + "attemptedgatewayenglishyear_1": "attempted_gateway_english_year_1", + "completedgatewaymathyear_1": "completed_gateway_math_year_1", + "completedgatewayenglishyear_1": "completed_gateway_english_year_1", + "gatewaymathgradey_1": "gateway_math_grade_y_1", + "gatewayenglishgradey_1": "gateway_english_grade_y_1", + "attempteddevmathy_1": "attempted_dev_math_y_1", + "attempteddevenglishy_1": "attempted_dev_english_y_1", + "completeddevmathy_1": "completed_dev_math_y_1", + "completeddevenglishy_1": "completed_dev_english_y_1", +} diff --git a/src/edvise/ingestion/nsc_sftp_helpers.py b/src/edvise/ingestion/nsc_sftp_helpers.py new file mode 100644 index 000000000..5fff15b61 --- /dev/null +++ b/src/edvise/ingestion/nsc_sftp_helpers.py @@ -0,0 +1,541 @@ +""" +NSC SFTP ingestion helpers. + +NSC-specific utilities for processing SFTP files, extracting institution IDs, +managing ingestion manifests, and working with Databricks schemas/volumes. 
+""" + +from __future__ import annotations + +import logging +import math +import os +import re +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + import paramiko + +import pandas as pd +import pyspark.sql +from pyspark.sql import functions as F +from pyspark.sql import types as T + +from edvise.ingestion.constants import ( + CATALOG, + DEFAULT_SCHEMA, + MANIFEST_TABLE_PATH, + QUEUE_TABLE_PATH, + SFTP_DOWNLOAD_CHUNK_MB, + SFTP_TMP_DIR, + SFTP_TMP_VOLUME_FQN, + SFTP_TMP_VOLUME_NAME, + SFTP_VERIFY_DOWNLOAD, +) +from edvise.utils.data_cleaning import convert_to_snake_case, detect_institution_column +from edvise.utils.sftp import download_sftp_atomic + +LOGGER = logging.getLogger(__name__) + + +def _ensure_sftp_staging_volume_exists(spark: pyspark.sql.SparkSession) -> None: + """ + Ensure the configured UC volume used for SFTP staging exists and is accessible. + + We stage files to a Unity Catalog volume (CATALOG.default.tmp) so paths remain + valid across workflow tasks/clusters. + """ + try: + rows = spark.sql(f"SHOW VOLUMES IN {CATALOG}.{DEFAULT_SCHEMA}").collect() + except Exception as e: + raise RuntimeError( + f"Failed to verify staging volume exists. Expected UC volume: {SFTP_TMP_VOLUME_FQN}. " + f"Could not list volumes in {CATALOG}.{DEFAULT_SCHEMA}: {e}" + ) from e + + def _volume_name(row: pyspark.sql.Row) -> str: + d = row.asDict() + for k in ["volume_name", "volumeName", "name"]: + v = d.get(k) + if v: + return str(v) + return str(list(d.values())[0]) + + volume_names = {_volume_name(r) for r in rows} + if SFTP_TMP_VOLUME_NAME not in volume_names: + raise RuntimeError( + f"Required staging UC volume not found: {SFTP_TMP_VOLUME_FQN}. " + "Create it before running NSC ingestion." + ) + + if not os.path.isdir(SFTP_TMP_DIR): + raise RuntimeError( + f"UC volume exists but filesystem path is not accessible: {SFTP_TMP_DIR}. " + f"Expected UC volume: {SFTP_TMP_VOLUME_FQN}." 
+ ) + + +def ensure_manifest_and_queue_tables(spark: pyspark.sql.SparkSession) -> None: + """ + Create required delta tables if missing. + - ingestion_manifest: includes file_fingerprint for idempotency + - pending_ingest_queue: holds local tmp path so downstream doesn't connect to SFTP again + + Args: + spark: Spark session + """ + spark.sql( + f""" + CREATE TABLE IF NOT EXISTS {MANIFEST_TABLE_PATH} ( + file_fingerprint STRING, + source_system STRING, + sftp_path STRING, + file_name STRING, + file_size BIGINT, + file_modified_time TIMESTAMP, + ingested_at TIMESTAMP, + processed_at TIMESTAMP, + status STRING, + error_message STRING + ) + USING DELTA + """ + ) + + spark.sql( + f""" + CREATE TABLE IF NOT EXISTS {QUEUE_TABLE_PATH} ( + file_fingerprint STRING, + source_system STRING, + sftp_path STRING, + file_name STRING, + file_size BIGINT, + file_modified_time TIMESTAMP, + local_tmp_path STRING, + queued_at TIMESTAMP + ) + USING DELTA + """ + ) + + +def build_listing_df( + spark: pyspark.sql.SparkSession, file_rows: list[dict] +) -> pyspark.sql.DataFrame: + """ + Build DataFrame from file listing rows with file fingerprints. + + Creates a DataFrame with file metadata and computes a stable fingerprint + from metadata (file version identity). + + Args: + spark: Spark session + file_rows: List of dicts with keys: source_system, sftp_path, file_name, + file_size, file_modified_time + + Returns: + DataFrame with file_fingerprint column added + """ + schema = T.StructType( + [ + T.StructField("source_system", T.StringType(), False), + T.StructField("sftp_path", T.StringType(), False), + T.StructField("file_name", T.StringType(), False), + T.StructField("file_size", T.LongType(), True), + T.StructField("file_modified_time", T.TimestampType(), True), + ] + ) + + df = spark.createDataFrame(file_rows, schema=schema) + + # Stable fingerprint from metadata (file version identity) + # Note: cast mtime to string in a consistent format to avoid subtle timestamp formatting diffs. 
+ df = df.withColumn( + "file_fingerprint", + F.sha2( + F.concat_ws( + "||", + F.col("source_system"), + F.col("sftp_path"), + F.col("file_name"), + F.coalesce(F.col("file_size").cast("string"), F.lit("")), + F.coalesce( + F.date_format( + F.col("file_modified_time"), "yyyy-MM-dd'T'HH:mm:ss.SSSXXX" + ), + F.lit(""), + ), + ), + 256, + ), + ) + + return df + + +def upsert_new_to_manifest( + spark: pyspark.sql.SparkSession, df_listing: pyspark.sql.DataFrame +) -> None: + """ + Insert NEW rows for unseen fingerprints only. + + Args: + spark: Spark session + df_listing: DataFrame with file listing (must have file_fingerprint column) + """ + df_manifest_insert = ( + df_listing.select( + "file_fingerprint", + "source_system", + "sftp_path", + "file_name", + "file_size", + "file_modified_time", + ) + .withColumn("ingested_at", F.lit(None).cast("timestamp")) + .withColumn("processed_at", F.lit(None).cast("timestamp")) + .withColumn("status", F.lit("NEW")) + .withColumn("error_message", F.lit(None).cast("string")) + ) + + df_manifest_insert.createOrReplaceTempView("incoming_manifest_rows") + + spark.sql( + f""" + MERGE INTO {MANIFEST_TABLE_PATH} AS t + USING incoming_manifest_rows AS s + ON t.file_fingerprint = s.file_fingerprint + WHEN NOT MATCHED THEN INSERT * + """ + ) + + +def get_files_to_queue( + spark: pyspark.sql.SparkSession, df_listing: pyspark.sql.DataFrame +) -> pyspark.sql.DataFrame: + """ + Return files that should be queued for downstream processing. 
+ + Criteria: + - present in current SFTP listing (df_listing) + - exist in manifest with status = 'NEW' + - NOT already present in pending_ingest_queue + + Args: + spark: Spark session + df_listing: DataFrame with file listing (must have file_fingerprint column) + + Returns: + DataFrame of files to queue + """ + manifest_new = ( + spark.table(MANIFEST_TABLE_PATH) + .select("file_fingerprint", "status") + .where(F.col("status") == F.lit("NEW")) + .select("file_fingerprint") + ) + + already_queued = spark.table(QUEUE_TABLE_PATH).select("file_fingerprint").distinct() + + # Only queue files that are: + # in current listing AND in manifest NEW AND not in queue + to_queue = df_listing.join(manifest_new, on="file_fingerprint", how="inner").join( + already_queued, on="file_fingerprint", how="left_anti" + ) + return to_queue + + +def download_new_files_and_queue( + spark: pyspark.sql.SparkSession, + sftp: paramiko.SFTPClient, + df_new: pyspark.sql.DataFrame, + logger: Optional[logging.Logger] = None, +) -> int: + """ + Download each new file to the UC staging volume (SFTP_TMP_DIR) and upsert into pending_ingest_queue. 
+ + Args: + spark: Spark session + sftp: SFTP client connection + df_new: DataFrame of files to download and queue + logger: Optional logger instance (defaults to module logger) + + Returns: + Number of files queued + """ + if logger is None: + logger = LOGGER + _ensure_sftp_staging_volume_exists(spark) + + rows = df_new.select( + "file_fingerprint", + "source_system", + "sftp_path", + "file_name", + "file_size", + "file_modified_time", + ).collect() + + queued = [] + for r in rows: + fp = r["file_fingerprint"] + sftp_path = r["sftp_path"] + file_name = r["file_name"] + + remote_path = f"{sftp_path.rstrip('/')}/{file_name}" + local_path = os.path.abspath(os.path.join(SFTP_TMP_DIR, f"{fp}__{file_name}")) + + # If local already exists (e.g., rerun), skip re-download + if not os.path.exists(local_path): + logger.info( + f"Downloading new file from SFTP: {remote_path} -> {local_path}" + ) + download_sftp_atomic( + sftp, + remote_path, + local_path, + chunk=SFTP_DOWNLOAD_CHUNK_MB, + verify=SFTP_VERIFY_DOWNLOAD, + ) + else: + logger.info(f"Local file already staged, skipping download: {local_path}") + + queued.append( + { + "file_fingerprint": fp, + "source_system": r["source_system"], + "sftp_path": sftp_path, + "file_name": file_name, + "file_size": r["file_size"], + "file_modified_time": r["file_modified_time"], + "local_tmp_path": local_path, + "queued_at": datetime.now(timezone.utc), + } + ) + + if not queued: + return 0 + + qschema = T.StructType( + [ + T.StructField("file_fingerprint", T.StringType(), False), + T.StructField("source_system", T.StringType(), False), + T.StructField("sftp_path", T.StringType(), False), + T.StructField("file_name", T.StringType(), False), + T.StructField("file_size", T.LongType(), True), + T.StructField("file_modified_time", T.TimestampType(), True), + T.StructField("local_tmp_path", T.StringType(), False), + T.StructField("queued_at", T.TimestampType(), False), + ] + ) + + df_queue = spark.createDataFrame(queued, schema=qschema) + 
df_queue.createOrReplaceTempView("incoming_queue_rows") + + # Upsert into queue (idempotent by fingerprint) + spark.sql( + f""" + MERGE INTO {QUEUE_TABLE_PATH} AS t + USING incoming_queue_rows AS s + ON t.file_fingerprint = s.file_fingerprint + WHEN MATCHED THEN UPDATE SET + t.local_tmp_path = s.local_tmp_path, + t.queued_at = s.queued_at + WHEN NOT MATCHED THEN INSERT * + """ + ) + + return len(queued) + + +def ensure_plan_table(spark: pyspark.sql.SparkSession, plan_table: str) -> None: + """ + Create institution_ingest_plan table if it doesn't exist. + + Args: + spark: Spark session + plan_table: Full table path (e.g., "catalog.schema.table") + """ + spark.sql( + f""" + CREATE TABLE IF NOT EXISTS {plan_table} ( + file_fingerprint STRING, + file_name STRING, + local_path STRING, + institution_id STRING, + inst_col STRING, + file_size BIGINT, + file_modified_time TIMESTAMP, + planned_at TIMESTAMP + ) + USING DELTA + """ + ) + + +def extract_institution_ids( + local_path: str, + *, + renames: dict[str, str], + inst_col_pattern: re.Pattern, +) -> tuple[Optional[str], list[str]]: + """ + Extract unique institution IDs from a staged CSV file. + + Reads file, normalizes/renames columns, detects institution column, + and returns unique institution IDs. + + Args: + local_path: Path to local CSV file + renames: Dictionary mapping old column names to new names + inst_col_pattern: Compiled regex pattern to match institution column + + Returns: + Tuple of (institution_column_name, sorted_list_of_unique_ids). + Returns (None, []) if no institution column found. + + Example: + >>> pattern = re.compile(r"(?=.*institution)(?=.*id)", re.IGNORECASE) + >>> renames = {"inst_id": "institution_id"} + >>> col, ids = extract_institution_ids( + ... "/tmp/file.csv", renames=renames, inst_col_pattern=pattern + ... 
) + >>> print(col, ids) + institution_id ['12345', '67890'] + """ + df = pd.read_csv(local_path, on_bad_lines="warn") + # Normalize column headers to snake_case before applying explicit renames + df = df.rename(columns={c: convert_to_snake_case(c) for c in df.columns}) + df = df.rename(columns=renames) + + inst_col = detect_institution_column(df.columns.tolist(), inst_col_pattern) + if inst_col is None: + return None, [] + + # Make IDs robust: drop nulls, strip whitespace, keep as string + series = df[inst_col].dropna() + + # Some files store as numeric; normalize to integer-like strings when possible + ids = set() + for v in series.tolist(): + # Handle pandas/numpy numeric types + try: + if isinstance(v, int): + ids.add(str(v)) + continue + if isinstance(v, float): + # Treat +/-inf as invalid IDs + if not math.isfinite(v): + continue + # If 323100.0 -> "323100" + if v.is_integer(): + ids.add(str(int(v))) + else: + ids.add(str(v).strip()) + continue + except Exception: + pass + + s = str(v).strip() + if s == "" or s.lower() in { + "nan", + "inf", + "+inf", + "-inf", + "infinity", + "+infinity", + "-infinity", + }: + continue + # If it's "323100.0" as string, coerce safely + if re.fullmatch(r"\d+\.0+", s): + s = s.split(".")[0] + ids.add(s) + + return inst_col, sorted(ids) + + +def update_manifest( + spark: pyspark.sql.SparkSession, + manifest_table: str, + file_fingerprint: str, + *, + status: str, + error_message: Optional[str], +) -> None: + """ + Update ingestion_manifest for a file_fingerprint. + + Assumes upstream inserted status=NEW already. Updates status, error_message, + and timestamps. 
+ + Args: + spark: Spark session + manifest_table: Full table path (e.g., "catalog.schema.table") + file_fingerprint: File fingerprint identifier + status: New status (e.g., "BRONZE_WRITTEN", "FAILED") + error_message: Error message if status is FAILED, None otherwise + """ + from pyspark.sql import types as T + + now_ts = datetime.now(timezone.utc) + + # ingested_at only set when we finish BRONZE_WRITTEN + row = { + "file_fingerprint": file_fingerprint, + "status": status, + "error_message": error_message, + "ingested_at": now_ts if status == "BRONZE_WRITTEN" else None, + "processed_at": now_ts, + } + + schema = T.StructType( + [ + T.StructField("file_fingerprint", T.StringType(), False), + T.StructField("status", T.StringType(), False), + T.StructField("error_message", T.StringType(), True), + T.StructField("ingested_at", T.TimestampType(), True), + T.StructField("processed_at", T.TimestampType(), False), + ] + ) + df = spark.createDataFrame([row], schema=schema) + df.createOrReplaceTempView("manifest_updates") + + spark.sql( + f""" + MERGE INTO {manifest_table} AS t + USING manifest_updates AS s + ON t.file_fingerprint = s.file_fingerprint + WHEN MATCHED THEN UPDATE SET + t.status = s.status, + t.error_message = s.error_message, + t.ingested_at = COALESCE(s.ingested_at, t.ingested_at), + t.processed_at = s.processed_at + """ + ) + + +def process_and_save_file(volume_dir: str, file_name: str, df: pd.DataFrame) -> str: + """ + Process DataFrame and save to Databricks volume. + + Normalizes column names and saves as CSV. 
+ + Args: + volume_dir: Volume directory path + file_name: Output filename + df: DataFrame to save + + Returns: + Full path to saved file + """ + local_file_path = os.path.join(volume_dir, file_name) + + LOGGER.info(f"Saving to Volumes {local_file_path}") + # Normalize column names for Databricks compatibility + df.columns = [re.sub(r"[^a-zA-Z0-9_]", "_", col) for col in df.columns] + df.to_csv(local_file_path, index=False) + LOGGER.info(f"Saved {file_name} to {local_file_path}") + + return local_file_path diff --git a/src/edvise/utils/api_requests.py b/src/edvise/utils/api_requests.py index 5b2654f7d..88891488e 100644 --- a/src/edvise/utils/api_requests.py +++ b/src/edvise/utils/api_requests.py @@ -1,9 +1,9 @@ # Standard library imports import logging -import re import typing as t -from typing import cast -from urllib.parse import quote +from dataclasses import dataclass, field +from typing import Any, cast +from urllib.parse import quote, urljoin # Third-party imports import requests @@ -184,120 +184,6 @@ def validate_custom_model_exist(inst_id: str, model_name: str, api_key: str) -> return resp.text -# Compiled regex patterns for reverse transformation (performance optimization) -_REVERSE_REPLACEMENTS = { - "ctc": "community technical college", - "cc": "community college", - "st": "of science and technology", - "uni": "university", - "col": "college", -} - -# Pre-compile regex patterns for word boundary matching -_COMPILED_REVERSE_PATTERNS = { - abbrev: re.compile(r"\b" + re.escape(abbrev) + r"\b") - for abbrev in _REVERSE_REPLACEMENTS.keys() -} - - -def _validate_databricks_name_format(databricks_name: str) -> None: - """ - Validate that databricks name matches expected format. 
- - Args: - databricks_name: Name to validate - - Raises: - ValueError: If name is empty or contains invalid characters - """ - if not isinstance(databricks_name, str) or not databricks_name.strip(): - raise ValueError("databricks_name must be a non-empty string") - - pattern = "^[a-z0-9_]*$" - if not re.match(pattern, databricks_name): - raise ValueError( - f"Invalid databricks name format '{databricks_name}'. " - "Must contain only lowercase letters, numbers, and underscores." - ) - - -def _reverse_abbreviation_replacements(name: str) -> str: - """ - Reverse abbreviation replacements in the name. - - Handles the ambiguous "st" abbreviation: - - If "st" appears as the first word, it's kept as "st" (abbreviation for Saint) - and will be capitalized to "St" by title() case - - Otherwise, "st" is treated as "of science and technology" - - Args: - name: Name with underscores replaced by spaces - - Returns: - Name with abbreviations expanded to full forms - """ - # Split into words to handle "st" at the beginning specially - words = name.split() - - # Keep "st" at the beginning as-is (will be capitalized to "St" by title() case) - # Don't expand it to "saint" - preserve the abbreviation - - # Replace "st" in remaining positions with "of science and technology" - for i in range(len(words)): - if words[i] == "st" and i > 0: # Only replace if not the first word - words[i] = "of science and technology" - - # Rejoin and apply other abbreviation replacements - name = " ".join(words) - - # Apply other abbreviation replacements (excluding "st" which we handled above) - for abbrev, full_form in _REVERSE_REPLACEMENTS.items(): - if abbrev != "st": # Skip "st" as we handled it above - pattern = _COMPILED_REVERSE_PATTERNS[abbrev] - name = pattern.sub(full_form, name) - - return name - - -def reverse_databricksify_inst_name(databricks_name: str) -> str: - """ - Reverse the databricksify transformation to get back the original institution name. 
- - This function attempts to reverse the transformation done by databricksify_inst_name. - Since the transformation is lossy (multiple original names can map to the same - databricks name), this function produces the most likely original name. - - Args: - databricks_name: The databricks-transformed institution name (e.g., "motlow_state_cc") - Case inconsistencies are normalized (input is lowercased before processing). - - Returns: - The reversed institution name with proper capitalization (e.g., "Motlow State Community College") - - Raises: - ValueError: If the databricks name contains invalid characters - """ - # Normalize to lowercase to handle case inconsistencies - # (databricksify_inst_name always produces lowercase output) - databricks_name = databricks_name.lower() - _validate_databricks_name_format(databricks_name) - - # Step 1: Replace underscores with spaces - name = databricks_name.replace("_", " ") - - # Step 2: Reverse the abbreviation replacements - # The original replacements were done in this order (most specific first): - # 1. "community technical college" → "ctc" - # 2. "community college" → "cc" - # 3. "of science and technology" → "st" - # 4. "university" → "uni" - # 5. "college" → "col" - name = _reverse_abbreviation_replacements(name) - - # Step 3: Capitalize appropriately (title case) - return name.title() - - def _fetch_institution_by_name(normalized_name: str, access_token: str) -> t.Any: """ Fetch institution data from API by normalized name. 
@@ -373,6 +259,8 @@ def _validate_and_transform_institution_name( # Validate and transform databricks name if needed if is_databricks_name: try: + from edvise.utils.databricks import reverse_databricksify_inst_name + institution_name = reverse_databricksify_inst_name(institution_name.strip()) except ValueError as e: LOGGER.error( @@ -515,3 +403,152 @@ def log_custom_job( return resp.json() except ValueError: return resp.text + + +# --------------------------- +# Edvise API Client (with caching and auto-refresh) +# --------------------------- + + +@dataclass +class EdviseAPIClient: + """ + API client for Edvise API with bearer token management. + + Features: + - Automatic bearer token fetching and refresh + - Token caching within a session + - Institution lookup caching + - Automatic retry on 401 (unauthorized) errors + + Example: + >>> client = EdviseAPIClient( + ... api_key="your-api-key", + ... base_url="https://staging-sst.datakind.org", + ... token_endpoint="/api/v1/token-from-api-key", + ... institution_lookup_path="/api/v1/institutions/pdp-id/{pdp_id}" + ... ) + >>> institution = fetch_institution_by_pdp_id(client, "12345") + """ + + api_key: str + base_url: str + token_endpoint: str + institution_lookup_path: str + session: requests.Session = field(default_factory=requests.Session) + bearer_token: str | None = None + institution_cache: dict[str, dict[str, Any]] = field(default_factory=dict) + + def __post_init__(self) -> None: + """Validate and normalize API client configuration.""" + self.api_key = self.api_key.strip() + if not self.api_key: + raise ValueError("Empty Edvise API key.") + + self.base_url = self.base_url.rstrip("/") + self.token_endpoint = self.token_endpoint.strip() + self.institution_lookup_path = self.institution_lookup_path.strip() + + self.session.headers.update({"accept": "application/json"}) + + +def _fetch_bearer_token_for_client(client: EdviseAPIClient) -> str: + """ + Fetch bearer token from API key using X-API-KEY header. 
+ + Assumes token endpoint returns JSON containing one of: access_token, token, bearer_token, jwt. + + Args: + client: EdviseAPIClient instance + + Returns: + Bearer token string + + Raises: + PermissionError: If API key is invalid (401 response) + ValueError: If token response is missing expected token field + requests.HTTPError: For other HTTP errors + """ + token_url = ( + client.token_endpoint + if client.token_endpoint.startswith(("http://", "https://")) + else urljoin(f"{client.base_url}/", client.token_endpoint) + ) + resp = client.session.post( + token_url, + headers={"accept": "application/json", "X-API-KEY": client.api_key}, + timeout=30, + ) + if resp.status_code == 401: + raise PermissionError( + "Unauthorized calling token endpoint (check X-API-KEY secret)." + ) + resp.raise_for_status() + + data = resp.json() + for k in ["access_token", "token", "bearer_token", "jwt"]: + v = data.get(k) + if isinstance(v, str) and v.strip(): + return v.strip() + + raise ValueError( + "Token endpoint response missing expected token field. " + f"Keys={list(data.keys())}" + ) + + +def _ensure_auth(client: EdviseAPIClient) -> None: + """Ensure client has a valid bearer token, fetching if needed.""" + if client.bearer_token is None: + _refresh_auth(client) + + +def _refresh_auth(client: EdviseAPIClient) -> None: + """Refresh bearer token and update session headers.""" + client.bearer_token = _fetch_bearer_token_for_client(client) + client.session.headers.update({"Authorization": f"Bearer {client.bearer_token}"}) + + +def fetch_institution_by_pdp_id(client: EdviseAPIClient, pdp_id: str) -> dict[str, Any]: + """ + Resolve institution for PDP id using Edvise API. + + Cached within run. Automatically refreshes token on 401 errors. 
+ + Args: + client: EdviseAPIClient instance + pdp_id: Institution PDP ID to look up + + Returns: + Institution data dictionary from API + + Raises: + ValueError: If institution PDP ID not found (404) or other API errors + requests.HTTPError: For HTTP errors other than 401/404 + + Example: + >>> client = EdviseAPIClient(...) + >>> inst = fetch_institution_by_pdp_id(client, "12345") + >>> print(inst["name"]) + 'Example University' + """ + pid = str(pdp_id).strip() + if pid in client.institution_cache: + return client.institution_cache[pid] + + _ensure_auth(client) + + url = client.base_url + client.institution_lookup_path.format(pdp_id=pid) + resp = client.session.get(url, timeout=30) + + if resp.status_code == 401: + _refresh_auth(client) + resp = client.session.get(url, timeout=30) + + if resp.status_code == 404: + raise ValueError(f"Institution PDP ID not found in SST staging: {pid}") + + resp.raise_for_status() + data = cast(dict[str, Any], resp.json()) + client.institution_cache[pid] = data + return data diff --git a/src/edvise/utils/data_cleaning.py b/src/edvise/utils/data_cleaning.py index d834985a0..d15201cff 100644 --- a/src/edvise/utils/data_cleaning.py +++ b/src/edvise/utils/data_cleaning.py @@ -36,6 +36,27 @@ def convert_to_snake_case(col: str) -> str: return "_".join(words).lower() +def detect_institution_column( + cols: list[str], inst_col_pattern: re.Pattern +) -> t.Optional[str]: + """ + Detect institution ID column using regex pattern. 
+ + Args: + cols: List of column names + inst_col_pattern: Compiled regex pattern to match institution column + + Returns: + Matched column name or None if not found + + Example: + >>> pattern = re.compile(r"(?=.*institution)(?=.*id)", re.IGNORECASE) + >>> detect_institution_column(["student_id", "institution_id"], pattern) + 'institution_id' + """ + return next((c for c in cols if inst_col_pattern.search(c)), None) + + def convert_intensity_time_limits( unit: t.Literal["term", "year"], intensity_time_limits: types.IntensityTimeLimitsType, @@ -137,9 +158,10 @@ def drop_course_rows_missing_identifiers(df_course: pd.DataFrame) -> pd.DataFram # Log dropped rows if num_dropped_rows > 0: LOGGER.warning( - " ⚠️ Dropped %s rows (%.1f%%) from course dataset due to missing course_prefix or course_number.", + " ⚠️ Dropped %s rows (%.1f%%) from course dataset due to missing course_prefix or course_number (%s students affected).", num_dropped_rows, pct_dropped_rows, + dropped_students, ) # Warn if any full academic term was completely removed @@ -418,10 +440,11 @@ def log_pre_cohort_courses(df_course: pd.DataFrame, student_id_col: str) -> None LOGGER.info( "log_pre_cohort_courses: %d pre-cohort course records found (%.1f%% of data) and will be kept " - "across %d students.", + "across %d/%d students.", n_pre, pct_pre, students_pre, + students_total, ) # Students with only pre-cohort records diff --git a/src/edvise/utils/databricks.py b/src/edvise/utils/databricks.py index a50f7c78d..b0c094274 100644 --- a/src/edvise/utils/databricks.py +++ b/src/edvise/utils/databricks.py @@ -1,7 +1,9 @@ import logging import mlflow import typing as t +from typing import Any import pydantic as pyd +import re LOGGER = logging.getLogger(__name__) @@ -117,3 +119,275 @@ class Series(t.Generic[GenericDtype]): ... 
sys.modules[m1.__name__] = m1 sys.modules[m2.__name__] = m2 + + +# Schema and volume caches for Databricks catalog operations +_schema_cache: dict[str, set[str]] = {} +_bronze_volume_cache: dict[str, str] = {} # key: f"{catalog}.{schema}" -> volume_name + + +def list_schemas_in_catalog(spark: SparkSession, catalog: str) -> set[str]: + """ + List all schemas in a catalog (with caching). + + Args: + spark: Spark session + catalog: Catalog name + + Returns: + Set of schema names + """ + if catalog in _schema_cache: + return _schema_cache[catalog] + + rows = spark.sql(f"SHOW SCHEMAS IN {catalog}").collect() + + schema_names: set[str] = set() + for row in rows: + d = row.asDict() + for k in ["databaseName", "database_name", "schemaName", "schema_name", "name"]: + v = d.get(k) + if v: + schema_names.add(v) + break + else: + schema_names.add(list(d.values())[0]) + + _schema_cache[catalog] = schema_names + return schema_names + + +def find_bronze_schema(spark: SparkSession, catalog: str, inst_prefix: str) -> str: + """ + Find bronze schema for institution prefix. + + Args: + spark: Spark session + catalog: Catalog name + inst_prefix: Institution prefix (e.g., "motlow_state_cc") + + Returns: + Bronze schema name (e.g., "motlow_state_cc_bronze") + + Raises: + ValueError: If bronze schema not found + """ + target = f"{inst_prefix}_bronze" + schemas = list_schemas_in_catalog(spark, catalog) + if target not in schemas: + raise ValueError(f"Bronze schema not found: {catalog}.{target}") + return target + + +def find_bronze_volume_name(spark: SparkSession, catalog: str, schema: str) -> str: + """ + Find bronze volume name in schema (with caching). 
+ + Args: + spark: Spark session + catalog: Catalog name + schema: Schema name + + Returns: + Volume name containing "bronze" + + Raises: + ValueError: If no bronze volume found + """ + key = f"{catalog}.{schema}" + if key in _bronze_volume_cache: + return _bronze_volume_cache[key] + + vols = spark.sql(f"SHOW VOLUMES IN {catalog}.{schema}").collect() + if not vols: + raise ValueError(f"No volumes found in {catalog}.{schema}") + + # Usually "volume_name", but be defensive + def _get_vol_name(row: Any) -> str: + d = row.asDict() + for k in ["volume_name", "volumeName", "name"]: + if k in d: + return str(d[k]) + return str(list(d.values())[0]) + + vol_names = [_get_vol_name(v) for v in vols] + bronze_like = [v for v in vol_names if "bronze" in str(v).lower()] + if bronze_like: + result = bronze_like[0] + _bronze_volume_cache[key] = result + return result + + raise ValueError( + f"No volume containing 'bronze' found in {catalog}.{schema}. Volumes={vol_names}" + ) + + +# Compiled regex patterns for reverse transformation (performance optimization) +_REVERSE_REPLACEMENTS = { + "ctc": "community technical college", + "cc": "community college", + "st": "of science and technology", + "uni": "university", + "col": "college", +} + +# Pre-compile regex patterns for word boundary matching +_COMPILED_REVERSE_PATTERNS = { + abbrev: re.compile(r"\b" + re.escape(abbrev) + r"\b") + for abbrev in _REVERSE_REPLACEMENTS.keys() +} + + +def _validate_databricks_name_format(databricks_name: str) -> None: + """ + Validate that databricks name matches expected format. + + Args: + databricks_name: Name to validate + + Raises: + ValueError: If name is empty or contains invalid characters + """ + if not isinstance(databricks_name, str) or not databricks_name.strip(): + raise ValueError("databricks_name must be a non-empty string") + + pattern = "^[a-z0-9_]*$" + if not re.match(pattern, databricks_name): + raise ValueError( + f"Invalid databricks name format '{databricks_name}'. 
" + "Must contain only lowercase letters, numbers, and underscores." + ) + + +def _reverse_abbreviation_replacements(name: str) -> str: + """ + Reverse abbreviation replacements in the name. + + Handles the ambiguous "st" abbreviation: + - If "st" appears as the first word, it's kept as "st" (abbreviation for Saint) + and will be capitalized to "St" by title() case + - Otherwise, "st" is treated as "of science and technology" + + Args: + name: Name with underscores replaced by spaces + + Returns: + Name with abbreviations expanded to full forms + """ + # Split into words to handle "st" at the beginning specially + words = name.split() + + # Keep "st" at the beginning as-is (will be capitalized to "St" by title() case) + # Don't expand it to "saint" - preserve the abbreviation + + # Replace "st" in remaining positions with "of science and technology" + for i in range(len(words)): + if words[i] == "st" and i > 0: # Only replace if not the first word + words[i] = "of science and technology" + + # Rejoin and apply other abbreviation replacements + name = " ".join(words) + + # Apply other abbreviation replacements (excluding "st" which we handled above) + for abbrev, full_form in _REVERSE_REPLACEMENTS.items(): + if abbrev != "st": # Skip "st" as we handled it above + pattern = _COMPILED_REVERSE_PATTERNS[abbrev] + name = pattern.sub(full_form, name) + + return name + + +def databricksify_inst_name(inst_name: str) -> str: + """ + Transform institution name to Databricks-compatible format. 
+ + Follows DK standardized rules for naming conventions used in Databricks: + - Lowercases the name + - Replaces common phrases with abbreviations (e.g., "community college" → "cc") + - Replaces special characters and spaces with underscores + - Validates final format contains only lowercase letters, numbers, and underscores + + Args: + inst_name: Original institution name (e.g., "Motlow State Community College") + + Returns: + Databricks-compatible name (e.g., "motlow_state_cc") + + Raises: + ValueError: If the resulting name contains invalid characters + + Example: + >>> databricksify_inst_name("Motlow State Community College") + 'motlow_state_cc' + >>> databricksify_inst_name("University of Science & Technology") + 'uni_of_st_technology' + """ + name = inst_name.lower() + + # Apply abbreviation replacements (most specific first) + dk_replacements = { + "community technical college": "ctc", + "community college": "cc", + "of science and technology": "st", + "university": "uni", + "college": "col", + } + + for old, new in dk_replacements.items(): + name = name.replace(old, new) + + # Replace special characters + special_char_replacements = {" & ": " ", "&": " ", "-": " "} + for old, new in special_char_replacements.items(): + name = name.replace(old, new) + + # Replace spaces with underscores + final_name = name.replace(" ", "_") + + # Validate format + pattern = "^[a-z0-9_]*$" + if not re.match(pattern, final_name): + raise ValueError( + f"Unexpected character found in Databricks compatible name: '{final_name}'" + ) + + return final_name + + +def reverse_databricksify_inst_name(databricks_name: str) -> str: + """ + Reverse the databricksify transformation to get back the original institution name. + + This function attempts to reverse the transformation done by databricksify_inst_name. + Since the transformation is lossy (multiple original names can map to the same + databricks name), this function produces the most likely original name. 
+ + Args: + databricks_name: The databricks-transformed institution name (e.g., "motlow_state_cc") + Case inconsistencies are normalized (input is lowercased before processing). + + Returns: + The reversed institution name with proper capitalization (e.g., "Motlow State Community College") + + Raises: + ValueError: If the databricks name contains invalid characters + """ + # Normalize to lowercase to handle case inconsistencies + # (databricksify_inst_name always produces lowercase output) + databricks_name = databricks_name.lower() + _validate_databricks_name_format(databricks_name) + + # Step 1: Replace underscores with spaces + name = databricks_name.replace("_", " ") + + # Step 2: Reverse the abbreviation replacements + # The original replacements were done in this order (most specific first): + # 1. "community technical college" → "ctc" + # 2. "community college" → "cc" + # 3. "of science and technology" → "st" + # 4. "university" → "uni" + # 5. "college" → "col" + name = _reverse_abbreviation_replacements(name) + + # Step 3: Capitalize appropriately (title case) + return name.title() diff --git a/src/edvise/utils/sftp.py b/src/edvise/utils/sftp.py new file mode 100644 index 000000000..c321ee416 --- /dev/null +++ b/src/edvise/utils/sftp.py @@ -0,0 +1,300 @@ +""" +SFTP utilities for file transfer operations. + +Provides functions for connecting to SFTP servers, listing files, and downloading +files with atomic operations and verification. +""" + +from __future__ import annotations + +import hashlib +import logging +import os +import shlex +import stat +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any, Optional + +if TYPE_CHECKING: + import paramiko + +LOGGER = logging.getLogger(__name__) + + +def connect_sftp( + host: str, username: str, password: str, port: int = 22 +) -> tuple[paramiko.Transport, paramiko.SFTPClient]: + """ + Connect to an SFTP server. 
+ + Args: + host: SFTP server hostname + username: SFTP username + password: SFTP password + port: SFTP port (default: 22) + + Returns: + Tuple of (transport, sftp_client). Caller must close both. + + Example: + >>> transport, sftp = connect_sftp("example.com", "user", "pass") + >>> try: + ... files = list_receive_files(sftp, "/remote/path", "NSC") + ... finally: + ... sftp.close() + ... transport.close() + """ + import paramiko + + transport = paramiko.Transport((host, port)) + transport.connect(username=username, password=password) + sftp = paramiko.SFTPClient.from_transport(transport) + LOGGER.info(f"Connected successfully to {host}:{port}") + return transport, sftp + + +def list_receive_files( + sftp: paramiko.SFTPClient, remote_dir: str, source_system: str +) -> list[dict[str, Any]]: + """ + List non-directory files in remote directory with metadata. + + Args: + sftp: Paramiko SFTPClient instance + remote_dir: Remote directory path to list + source_system: Source system identifier (e.g., "NSC") + + Returns: + List of dictionaries with keys: source_system, sftp_path, file_name, + file_size, file_modified_time + + Example: + >>> files = list_receive_files(sftp, "/receive", "NSC") + >>> for f in files: + ... print(f["file_name"], f["file_size"]) + """ + results = [] + for attr in sftp.listdir_attr(remote_dir): + if stat.S_ISDIR(attr.st_mode): + continue + + file_name = attr.filename + file_size = int(attr.st_size) if attr.st_size is not None else None + mtime = ( + datetime.fromtimestamp(int(attr.st_mtime), tz=timezone.utc) + if attr.st_mtime + else None + ) + + results.append( + { + "source_system": source_system, + "sftp_path": remote_dir, + "file_name": file_name, + "file_size": file_size, + "file_modified_time": mtime, + } + ) + return results + + +def _hash_file( + path: str, algo: str = "sha256", chunk_size: int = 8 * 1024 * 1024 +) -> str: + """ + Compute hash of a file. 
+ + Args: + path: File path + algo: Hash algorithm ("sha256" or "md5") + chunk_size: Chunk size for reading file + + Returns: + Hexadecimal hash string + """ + h = hashlib.new(algo) + with open(path, "rb") as f: + while True: + b = f.read(chunk_size) + if not b: + break + h.update(b) + return h.hexdigest() + + +def _remote_hash( + ssh: paramiko.SSHClient, remote_path: str, algo: str = "sha256" +) -> Optional[str]: + """ + Compute hash of a remote file using SSH command. + + Args: + ssh: Paramiko SSHClient instance + remote_path: Remote file path + algo: Hash algorithm ("sha256" or "md5") + + Returns: + Hexadecimal hash string, or None if computation fails + """ + cmd = None + if algo.lower() == "sha256": + cmd = f"sha256sum -- {shlex.quote(remote_path)}" + elif algo.lower() == "md5": + cmd = f"md5sum -- {shlex.quote(remote_path)}" + else: + return None + + try: + _, stdout, stderr = ssh.exec_command(cmd, timeout=300) + out = stdout.read().decode("utf-8", "replace").strip() + err = stderr.read().decode("utf-8", "replace").strip() + if err: + return None + # Format: " " + return str(out.split()[0]) + except Exception: + return None + + +def download_sftp_atomic( + sftp: paramiko.SFTPClient, + remote_path: str, + local_path: str, + *, + chunk: int = 150, + verify: str = "size", # "size" | "sha256" | "md5" | None + ssh_for_remote_hash: Optional[paramiko.SSHClient] = None, + progress: bool = True, +) -> None: + """ + Atomic and resumable SFTP download with verification. + + Writes to local_path + '.part' and moves into place after verification. + Supports resuming interrupted downloads. 
+ + Args: + sftp: Paramiko SFTPClient instance + remote_path: Remote file path + local_path: Local destination path + chunk: Chunk size in MB (default: 150) + verify: Verification method: "size", "sha256", "md5", or None + ssh_for_remote_hash: SSHClient for remote hash verification (optional) + progress: Whether to print progress (default: True) + + Raises: + IOError: If download fails, size mismatch, or hash mismatch + + Example: + >>> download_sftp_atomic(sftp, "/remote/file.csv", "/local/file.csv") + >>> # With hash verification: + >>> download_sftp_atomic( + ... sftp, "/remote/file.csv", "/local/file.csv", + ... verify="sha256", ssh_for_remote_hash=ssh + ... ) + """ + remote_size = sftp.stat(remote_path).st_size + tmp_path = f"{local_path}.part" + chunk_size = chunk * 1024 * 1024 + offset = 0 + + # Check for existing partial download + if os.path.exists(tmp_path): + part_size = os.path.getsize(tmp_path) + # If local .part is larger than remote, start fresh + if part_size <= remote_size: + offset = part_size + if progress: + LOGGER.info(f"Resuming download from {offset:,} bytes") + else: + os.remove(tmp_path) + if progress: + LOGGER.warning("Partial file larger than remote, starting fresh") + + # Open remote and local + with sftp.file(remote_path, "rb") as rf: + try: + try: + rf.set_pipelined(True) + except Exception: + pass + + if offset: + rf.seek(offset) + + # Append if resuming, write if fresh + with open(tmp_path, "ab" if offset else "wb") as lf: + transferred = offset + + while transferred < remote_size: + to_read = min(chunk_size, remote_size - transferred) + data = rf.read(to_read) + if not data: + # don't accept short-read silently + raise IOError( + f"Short read at {transferred:,} of {remote_size:,} bytes" + ) + lf.write(data) + transferred += len(data) + if progress and remote_size: + pct = transferred / remote_size + if ( + pct % 0.1 < 0.01 or transferred == remote_size + ): # Print every 10% + LOGGER.info( + f"{pct:.1%} transferred 
({transferred:,}/{remote_size:,} bytes)" + ) + lf.flush() + os.fsync(lf.fileno()) + + finally: + # SFTPFile closed by context manager + pass + + # Mandatory size verification + local_size = os.path.getsize(tmp_path) + if local_size != remote_size: + raise IOError( + f"Post-download size mismatch (local {local_size:,}, remote {remote_size:,})" + ) + + # Optional hash verification + if verify in {"sha256", "md5"}: + algo = verify + local_hash = _hash_file(tmp_path, algo=algo) + remote_hash = None + if ssh_for_remote_hash is not None: + remote_hash = _remote_hash(ssh_for_remote_hash, remote_path, algo=algo) + + if remote_hash and (remote_hash != local_hash): + # Clean up .part so next run starts fresh + try: + os.remove(tmp_path) + except Exception: + pass + raise IOError( + f"{algo.upper()} mismatch: local={local_hash} remote={remote_hash}" + ) + + # Move atomically into place + os.replace(tmp_path, local_path) + if progress: + LOGGER.info(f"Download complete (atomic & verified): {local_path}") + + +def output_file_name_from_sftp(file_name: str) -> str: + """ + Generate output filename from SFTP filename. + + Removes extension and adds .csv extension. 
+ + Args: + file_name: Original SFTP filename + + Returns: + Output filename with .csv extension + + Example: + >>> output_file_name_from_sftp("data_2024.xlsx") + 'data_2024.csv' + """ + return f"{os.path.basename(file_name).split('.')[0]}.csv" diff --git a/tests/ingestion/test_nsc_sftp_helper.py b/tests/ingestion/test_nsc_sftp_helper.py new file mode 100644 index 000000000..461eb173c --- /dev/null +++ b/tests/ingestion/test_nsc_sftp_helper.py @@ -0,0 +1,171 @@ +import re + +from edvise.ingestion.nsc_sftp_helpers import ( + detect_institution_column, + extract_institution_ids, +) +from edvise.utils.databricks import databricksify_inst_name +from edvise.utils.data_cleaning import convert_to_snake_case +from edvise.utils.sftp import download_sftp_atomic + + +def test_normalize_col(): + """Test column normalization (now using convert_to_snake_case).""" + assert convert_to_snake_case(" Institution ID ") == "institution_id" + assert convert_to_snake_case("Student-ID#") == "student_id_#" + assert convert_to_snake_case("__Already__Ok__") == "already_ok" + + +def test_detect_institution_column(): + pattern = re.compile(r"(?=.*institution)(?=.*id)", re.IGNORECASE) + assert ( + detect_institution_column(["foo", "institutionid", "bar"], pattern) + == "institutionid" + ) + assert detect_institution_column(["foo", "bar"], pattern) is None + + +def test_extract_institution_ids_handles_numeric(tmp_path): + csv_path = tmp_path / "staged.csv" + csv_path.write_text( + "InstitutionID,other\n323100,1\n323101.0,2\n,3\n323102.0,4\n 323103 ,5\ninf,6\n-inf,7\n" + ) + + inst_col_pattern = re.compile(r"(?=.*institution)(?=.*id)", re.IGNORECASE) + inst_col, inst_ids = extract_institution_ids( + str(csv_path), renames={}, inst_col_pattern=inst_col_pattern + ) + + assert inst_col == "institution_id" + assert inst_ids == ["323100", "323101", "323102", "323103"] + + +def test_databricksify_inst_name(): + assert databricksify_inst_name("Big State University") == "big_state_uni" + + +def 
test_hash_file_sha256(tmp_path): + """Test file hashing (internal function, tested via download_sftp_atomic).""" + # The _hash_file function is internal to sftp.py, so we test it indirectly + # through download_sftp_atomic which uses it for verification + pass + + +def test_download_sftp_atomic_downloads_and_cleans_part(tmp_path): + class _Stat: + def __init__(self, size: int): + self.st_size = size + + class _RemoteFile: + def __init__(self, data: bytes): + self._data = data + self._pos = 0 + + def set_pipelined(self, _): + return None + + def seek(self, offset: int): + self._pos = offset + + def read(self, n: int) -> bytes: + if self._pos >= len(self._data): + return b"" + b = self._data[self._pos : self._pos + n] + self._pos += len(b) + return b + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + class _Sftp: + def __init__(self, by_path: dict[str, bytes]): + self._by_path = by_path + + def stat(self, path: str): + return _Stat(len(self._by_path[path])) + + def file(self, path: str, mode: str): + assert mode == "rb" + return _RemoteFile(self._by_path[path]) + + remote_path = "/receive/file1.csv" + remote_bytes = b"hello world\n" * 100 + sftp = _Sftp({remote_path: remote_bytes}) + + local_path = tmp_path / "file1.csv" + download_sftp_atomic( + sftp, + remote_path, + str(local_path), + chunk=1, + verify="size", + progress=False, + ) + + assert local_path.read_bytes() == remote_bytes + assert not (tmp_path / "file1.csv.part").exists() + + +def test_download_sftp_atomic_resumes_existing_part(tmp_path): + class _Stat: + def __init__(self, size: int): + self.st_size = size + + class _RemoteFile: + def __init__(self, data: bytes): + self._data = data + self._pos = 0 + + def set_pipelined(self, _): + return None + + def seek(self, offset: int): + self._pos = offset + + def read(self, n: int) -> bytes: + if self._pos >= len(self._data): + return b"" + b = self._data[self._pos : self._pos + n] + self._pos += len(b) + return 
b + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + class _Sftp: + def __init__(self, by_path: dict[str, bytes]): + self._by_path = by_path + + def stat(self, path: str): + return _Stat(len(self._by_path[path])) + + def file(self, path: str, mode: str): + assert mode == "rb" + return _RemoteFile(self._by_path[path]) + + remote_path = "/receive/file2.csv" + remote_bytes = b"0123456789" * 200 + sftp = _Sftp({remote_path: remote_bytes}) + + local_path = tmp_path / "file2.csv" + part_path = tmp_path / "file2.csv.part" + + part_path.write_bytes(remote_bytes[:123]) + + download_sftp_atomic( + sftp, + remote_path, + str(local_path), + chunk=1, + verify="size", + progress=False, + ) + + assert local_path.read_bytes() == remote_bytes + assert not part_path.exists() diff --git a/tests/utils/test_api_requests.py b/tests/utils/test_api_requests.py index d123c6f6c..d074b517e 100644 --- a/tests/utils/test_api_requests.py +++ b/tests/utils/test_api_requests.py @@ -521,115 +521,3 @@ def test_error_message_includes_institution_name_for_missing_inst_id( # Name is normalized to lowercase in error messages assert "my test university" in error_msg.lower() assert "inst_id" in error_msg - - -class TestReverseDatabricksifyInstName: - """Test cases for reverse_databricksify_inst_name function.""" - - def test_reverse_community_college(self): - """Test reversing community college abbreviation.""" - result = api_requests.reverse_databricksify_inst_name("motlow_state_cc") - assert result == "Motlow State Community College" - - def test_reverse_university(self): - """Test reversing university abbreviation.""" - result = api_requests.reverse_databricksify_inst_name("kentucky_state_uni") - assert result == "Kentucky State University" - - def test_reverse_college(self): - """Test reversing college abbreviation.""" - result = api_requests.reverse_databricksify_inst_name("central_arizona_col") - assert result == "Central Arizona College" - - def 
test_reverse_community_technical_college(self): - """Test reversing community technical college abbreviation.""" - result = api_requests.reverse_databricksify_inst_name("southeast_kentucky_ctc") - assert result == "Southeast Kentucky Community Technical College" - - def test_reverse_science_and_technology(self): - """Test reversing 'of science and technology' abbreviation.""" - result = api_requests.reverse_databricksify_inst_name("harrisburg_uni_st") - assert result == "Harrisburg University Of Science And Technology" - - def test_reverse_saint_at_beginning(self): - """Test that 'st' at the beginning is kept as abbreviation 'St'.""" - result = api_requests.reverse_databricksify_inst_name("st_johns_uni") - assert result == "St Johns University" - - def test_reverse_saint_vs_science_technology(self): - """Test that 'st' at beginning is St (abbreviation), but in middle is 'of science and technology'.""" - # "st" at beginning should be "St" (abbreviation) - result1 = api_requests.reverse_databricksify_inst_name("st_marys_col") - assert result1 == "St Marys College" - - # "st" in middle should be "of science and technology" - result2 = api_requests.reverse_databricksify_inst_name("harrisburg_uni_st") - assert result2 == "Harrisburg University Of Science And Technology" - - # Both in same name (edge case) - result3 = api_requests.reverse_databricksify_inst_name("st_paul_uni_st") - assert result3 == "St Paul University Of Science And Technology" - - def test_reverse_multiple_words(self): - """Test reversing name with multiple words.""" - result = api_requests.reverse_databricksify_inst_name("metro_state_uni_denver") - assert result == "Metro State University Denver" - - def test_reverse_simple_name(self): - """Test reversing name without abbreviations.""" - result = api_requests.reverse_databricksify_inst_name("test_institution") - assert result == "Test Institution" - - def test_reverse_with_numbers(self): - """Test reversing name with numbers.""" - result = 
api_requests.reverse_databricksify_inst_name("college_123") - assert result == "College 123" - - def test_reverse_empty_string(self): - """Test that empty string raises ValueError.""" - with pytest.raises(ValueError) as exc_info: - api_requests.reverse_databricksify_inst_name("") - assert "non-empty string" in str(exc_info.value).lower() - - def test_reverse_invalid_characters(self): - """Test that invalid characters raise ValueError.""" - with pytest.raises(ValueError) as exc_info: - api_requests.reverse_databricksify_inst_name("invalid-name!") - assert "invalid" in str(exc_info.value).lower() - - def test_reverse_uppercase_normalized(self): - """Test that uppercase characters are normalized to lowercase.""" - # Uppercase input should be normalized to lowercase and processed - result = api_requests.reverse_databricksify_inst_name("MOTLOW_STATE_CC") - assert result == "Motlow State Community College" - - # Mixed case should also be normalized - result2 = api_requests.reverse_databricksify_inst_name("St_Paul_Uni") - assert result2 == "St Paul University" - - # Invalid characters (even after normalization) should still raise error - with pytest.raises(ValueError) as exc_info: - api_requests.reverse_databricksify_inst_name("Invalid-Name!") - assert "invalid" in str(exc_info.value).lower() - # Verify error message includes the problematic value (normalized) - assert "invalid-name!" 
in str(exc_info.value).lower() - - def test_reverse_whitespace_stripping(self): - """Test that whitespace is handled correctly in databricks names.""" - # Databricks names shouldn't have spaces, but test edge case - with pytest.raises(ValueError): - api_requests.reverse_databricksify_inst_name(" test_name ") - - def test_reverse_multiple_abbreviations(self): - """Test reversing name with multiple abbreviations.""" - # Test case: name with both "uni" and "col" - result = api_requests.reverse_databricksify_inst_name("test_uni_col") - assert result == "Test University College" - - def test_reverse_error_message_includes_value(self): - """Test that error messages include the problematic value.""" - with pytest.raises(ValueError) as exc_info: - api_requests.reverse_databricksify_inst_name("bad-name!") - error_msg = str(exc_info.value) - assert "bad-name!" in error_msg - assert "Invalid databricks name format" in error_msg diff --git a/tests/utils/test_databricks.py b/tests/utils/test_databricks.py new file mode 100644 index 000000000..e097c605d --- /dev/null +++ b/tests/utils/test_databricks.py @@ -0,0 +1,183 @@ +"""Tests for edvise.utils.databricks module.""" + +import pytest + +from edvise.utils.databricks import ( + databricksify_inst_name, + reverse_databricksify_inst_name, +) + + +class TestDatabricksifyInstName: + """Test cases for databricksify_inst_name function.""" + + def test_community_college(self): + """Test community college abbreviation.""" + assert ( + databricksify_inst_name("Motlow State Community College") + == "motlow_state_cc" + ) + assert ( + databricksify_inst_name("Northwest State Community College") + == "northwest_state_cc" + ) + + def test_university(self): + """Test university abbreviation.""" + assert ( + databricksify_inst_name("Kentucky State University") == "kentucky_state_uni" + ) + assert ( + databricksify_inst_name("Metro State University Denver") + == "metro_state_uni_denver" + ) + + def test_college(self): + """Test college 
abbreviation.""" + assert ( + databricksify_inst_name("Central Arizona College") == "central_arizona_col" + ) + + def test_community_technical_college(self): + """Test community technical college abbreviation.""" + assert ( + databricksify_inst_name("Southeast Kentucky community technical college") + == "southeast_kentucky_ctc" + ) + + def test_science_and_technology(self): + """Test 'of science and technology' abbreviation.""" + assert ( + databricksify_inst_name("Harrisburg University of Science and Technology") + == "harrisburg_uni_st" + ) + + def test_special_characters(self): + """Test handling of special characters like & and -.""" + assert databricksify_inst_name("State-Community College") == "state_cc" + + def test_invalid_characters(self): + """Test that invalid characters raise ValueError.""" + with pytest.raises(ValueError) as exc_info: + databricksify_inst_name("Northwest (invalid)") + error_msg = str(exc_info.value) + assert "Unexpected character found in Databricks compatible name" in error_msg + assert ( + "northwest" in error_msg.lower() + ) # Error message includes the problematic name + + def test_simple_name(self): + """Test simple name without abbreviations.""" + assert databricksify_inst_name("Big State University") == "big_state_uni" + + +class TestReverseDatabricksifyInstName: + """Test cases for reverse_databricksify_inst_name function.""" + + def test_reverse_community_college(self): + """Test reversing community college abbreviation.""" + result = reverse_databricksify_inst_name("motlow_state_cc") + assert result == "Motlow State Community College" + + def test_reverse_university(self): + """Test reversing university abbreviation.""" + result = reverse_databricksify_inst_name("kentucky_state_uni") + assert result == "Kentucky State University" + + def test_reverse_college(self): + """Test reversing college abbreviation.""" + result = reverse_databricksify_inst_name("central_arizona_col") + assert result == "Central Arizona College" + + def 
test_reverse_community_technical_college(self): + """Test reversing community technical college abbreviation.""" + result = reverse_databricksify_inst_name("southeast_kentucky_ctc") + assert result == "Southeast Kentucky Community Technical College" + + def test_reverse_science_and_technology(self): + """Test reversing 'of science and technology' abbreviation.""" + result = reverse_databricksify_inst_name("harrisburg_uni_st") + assert result == "Harrisburg University Of Science And Technology" + + def test_reverse_saint_at_beginning(self): + """Test that 'st' at the beginning is kept as abbreviation 'St'.""" + result = reverse_databricksify_inst_name("st_johns_uni") + assert result == "St Johns University" + + def test_reverse_saint_vs_science_technology(self): + """Test that 'st' at beginning is St (abbreviation), but in middle is 'of science and technology'.""" + # "st" at beginning should be "St" (abbreviation) + result1 = reverse_databricksify_inst_name("st_marys_col") + assert result1 == "St Marys College" + + # "st" in middle should be "of science and technology" + result2 = reverse_databricksify_inst_name("harrisburg_uni_st") + assert result2 == "Harrisburg University Of Science And Technology" + + # Both in same name (edge case) + result3 = reverse_databricksify_inst_name("st_paul_uni_st") + assert result3 == "St Paul University Of Science And Technology" + + def test_reverse_multiple_words(self): + """Test reversing name with multiple words.""" + result = reverse_databricksify_inst_name("metro_state_uni_denver") + assert result == "Metro State University Denver" + + def test_reverse_simple_name(self): + """Test reversing name without abbreviations.""" + result = reverse_databricksify_inst_name("test_institution") + assert result == "Test Institution" + + def test_reverse_with_numbers(self): + """Test reversing name with numbers.""" + result = reverse_databricksify_inst_name("college_123") + assert result == "College 123" + + def 
test_reverse_empty_string(self): + """Test that empty string raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + reverse_databricksify_inst_name("") + assert "non-empty string" in str(exc_info.value).lower() + + def test_reverse_invalid_characters(self): + """Test that invalid characters raise ValueError.""" + with pytest.raises(ValueError) as exc_info: + reverse_databricksify_inst_name("invalid-name!") + assert "invalid" in str(exc_info.value).lower() + + def test_reverse_uppercase_normalized(self): + """Test that uppercase characters are normalized to lowercase.""" + # Uppercase input should be normalized to lowercase and processed + result = reverse_databricksify_inst_name("MOTLOW_STATE_CC") + assert result == "Motlow State Community College" + + # Mixed case should also be normalized + result2 = reverse_databricksify_inst_name("St_Paul_Uni") + assert result2 == "St Paul University" + + # Invalid characters (even after normalization) should still raise error + with pytest.raises(ValueError) as exc_info: + reverse_databricksify_inst_name("Invalid-Name!") + assert "invalid" in str(exc_info.value).lower() + # Verify error message includes the problematic value (normalized) + assert "invalid-name!" 
in str(exc_info.value).lower() + + def test_reverse_whitespace_stripping(self): + """Test that whitespace is handled correctly in databricks names.""" + # Databricks names shouldn't have spaces, but test edge case + with pytest.raises(ValueError): + reverse_databricksify_inst_name(" test_name ") + + def test_reverse_multiple_abbreviations(self): + """Test reversing name with multiple abbreviations.""" + # Test case: name with both "uni" and "col" + result = reverse_databricksify_inst_name("test_uni_col") + assert result == "Test University College" + + def test_reverse_error_message_includes_value(self): + """Test that error messages include the problematic value.""" + with pytest.raises(ValueError) as exc_info: + reverse_databricksify_inst_name("bad-name!") + error_msg = str(exc_info.value) + assert "bad-name!" in error_msg + assert "Invalid databricks name format" in error_msg