Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 83 additions & 3 deletions src/webapp/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@

# The name of the deployed pipeline in Databricks. Must match directly.
PDP_INFERENCE_JOB_NAME = "edvise_github_sourced_pdp_inference_pipeline"
LEGACY_INFERENCE_JOB_NAME = "edvise_github_sourced_legacy_inference_pipeline"


class DatabricksInferenceRunRequest(BaseModel):
"""Databricks parameters for an inference run."""
class DatabricksPDPInferenceRunRequest(BaseModel):
"""Databricks parameters for a PDP inference run."""

inst_name: str
# Note that the following should be the filepath.
Expand All @@ -50,6 +51,18 @@ class DatabricksInferenceRunRequest(BaseModel):
gcp_external_bucket_name: str


class DatabricksLegacyInferenceRunRequest(BaseModel):
    """Databricks parameters for a legacy schools inference run.

    Unlike the PDP request model, legacy runs are driven by a config file
    and a pre-built features table rather than a file batch.
    """

    # Human-readable institution name; converted to the Databricks-safe
    # form by the caller (databricksify_inst_name) before job submission.
    inst_name: str
    # Name of the registered model to run inference with.
    model_name: str
    # Name of the pipeline config file for this run.
    config_file_name: str
    # Name of the table holding precomputed features for this run.
    features_table_name: str
    # The email where notifications will get sent.
    email: str
    # GCS bucket name used for the institution's external data exchange.
    gcp_external_bucket_name: str


class DatabricksInferenceRunResponse(BaseModel):
"""Databricks parameters for an inference run."""

Expand Down Expand Up @@ -186,7 +199,7 @@ def setup_new_inst(self, inst_name: str) -> None:
# E.g. there is one PDP inference pipeline, so one PDP inference function here.

def run_pdp_inference(
self, req: DatabricksInferenceRunRequest
self, req: DatabricksPDPInferenceRunRequest
) -> DatabricksInferenceRunResponse:
"""Triggers PDP inference Databricks run."""
LOGGER.info(f"Running PDP inference for institution: {req.inst_name}")
Expand Down Expand Up @@ -264,6 +277,73 @@ def run_pdp_inference(

return DatabricksInferenceRunResponse(job_run_id=run_id)

def run_legacy_inference(
    self, req: DatabricksLegacyInferenceRunRequest
) -> DatabricksInferenceRunResponse:
    """Triggers a legacy schools inference Databricks run.

    Args:
        req: Parameters identifying the institution, model, config file,
            features table, notification email, and GCP bucket.

    Returns:
        DatabricksInferenceRunResponse carrying the triggered run's id.

    Raises:
        ValueError: If the workspace client cannot be created, the job
            cannot be found, the run cannot be started, or the run does
            not return a valid run_id.
    """
    LOGGER.info(f"Running legacy inference for institution: {req.inst_name}")
    try:
        w = WorkspaceClient(
            host=databricks_vars["DATABRICKS_HOST_URL"],
            google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
        )
        LOGGER.info("Successfully created Databricks WorkspaceClient.")
    except Exception as e:
        LOGGER.exception(
            "Failed to create Databricks WorkspaceClient with host: %s and service account: %s",
            databricks_vars["DATABRICKS_HOST_URL"],
            gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
        )
        raise ValueError(
            f"run_legacy_inference(): Workspace client initialization failed: {e}"
        )

    db_inst_name = databricksify_inst_name(req.inst_name)
    pipeline_type = LEGACY_INFERENCE_JOB_NAME

    try:
        # The deployed job is looked up by its exact name; None when absent.
        job = next(w.jobs.list(name=pipeline_type), None)
        if not job or job.job_id is None:
            raise ValueError(
                f"run_legacy_inference(): Job '{pipeline_type}' was not found or has no job_id for '{gcs_vars['GCP_SERVICE_ACCOUNT_EMAIL']}' and '{databricks_vars['DATABRICKS_HOST_URL']}'."
            )
        job_id = job.job_id
        LOGGER.info(f"Resolved job ID for '{pipeline_type}': {job_id}")
    except Exception as e:
        LOGGER.exception(f"Job lookup failed for '{pipeline_type}'.")
        raise ValueError(f"run_legacy_inference(): Failed to find job: {e}")

    try:
        run_job: Any = w.jobs.run_now(
            job_id,
            job_parameters={
                "databricks_institution_name": db_inst_name,
                "DB_workspace": databricks_vars["DATABRICKS_WORKSPACE"],
                "model_name": req.model_name,
                "config_file_name": req.config_file_name,
                "features_table_name": req.features_table_name,
                "gcp_bucket_name": req.gcp_external_bucket_name,
                "datakind_notification_email": req.email,
                "DK_CC_EMAIL": req.email,
            },
        )
    except Exception as e:
        LOGGER.exception("Failed to run the legacy inference job.")
        raise ValueError(f"run_legacy_inference(): Job could not be run: {e}")

    # Validate the response BEFORE touching run_id: the previous version read
    # run_job.response.run_id inside the try block above, so a None response
    # surfaced as a misleading "Job could not be run" AttributeError. It also
    # logged the success message twice; log it exactly once here.
    if not run_job.response or run_job.response.run_id is None:
        raise ValueError("run_legacy_inference(): Job did not return a valid run_id.")

    run_id = run_job.response.run_id
    LOGGER.info(f"Successfully triggered job run. Run ID: {run_id}")

    return DatabricksInferenceRunResponse(job_run_id=run_id)

def delete_inst(self, inst_name: str) -> None:
"""Cleanup tasks required on the Databricks side to delete an institution."""
db_inst_name = databricksify_inst_name(inst_name)
Expand Down
105 changes: 100 additions & 5 deletions src/webapp/routers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
from sqlalchemy import and_, update, or_
from sqlalchemy.orm import Session
from sqlalchemy.future import select
from ..databricks import DatabricksControl, DatabricksInferenceRunRequest
from ..databricks import (
DatabricksControl,
DatabricksPDPInferenceRunRequest,
DatabricksLegacyInferenceRunRequest,
)
from ..utilities import (
has_access_to_inst_or_err,
has_full_data_access_or_err,
Expand Down Expand Up @@ -138,6 +142,9 @@ class InferenceRunRequest(BaseModel):
# Note: is_pdp is kept for backward compatibility but is ignored.
# PDP status is derived from the institution's pdp_id field.
is_pdp: bool = False
# Legacy schools inference parameters (required for legacy schools, ignored for PDP)
config_file_name: str | None = None
features_table_name: str | None = None


# Model related operations. Or model specific data.
Expand Down Expand Up @@ -524,11 +531,99 @@ def trigger_inference_run(
+ str(len(inst_result)),
)
inst = inst_result[0][0]
# Check PDP status from institution's pdp_id (ignore req.is_pdp for backward compat)
if not inst.pdp_id:
# Determine institution type: PDP, Edvise, or Legacy
# There are only three options: PDP (pdp_id), Edvise (edvise_id), or Legacy (legacy_id or none)
# Follows the same pattern as validation_helper in data.py
pdp_id = getattr(inst, "pdp_id", None)
edvise_id = getattr(inst, "edvise_id", None)
legacy_id = getattr(inst, "legacy_id", None)
# Defensive check: ensure mutual exclusivity (should not happen if validation works correctly)
if sum(bool(x) for x in (pdp_id, edvise_id, legacy_id)) > 1:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Institution configuration error: cannot have more than one of pdp_id, edvise_id, or legacy_id set",
)
is_pdp = bool(pdp_id)
is_edvise = bool(edvise_id)
is_legacy = not is_pdp and not is_edvise

# Legacy schools inference
if is_legacy:
if not req.config_file_name or not req.features_table_name:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Legacy schools inference requires config_file_name and features_table_name.",
)
legacy_model_result = (
local_session.get()
.execute(
select(ModelTable).where(
and_(
ModelTable.name == model_name,
ModelTable.inst_id == str_to_uuid(inst_id),
)
)
)
.all()
)
if len(legacy_model_result) != 1:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Unexpected number of models found: Expected 1, got "
+ str(len(legacy_model_result)),
)
# For legacy schools, we don't need batch validation (config and features table are used instead)
db_req = DatabricksLegacyInferenceRunRequest(
inst_name=inst_result[0][0].name,
model_name=model_name,
config_file_name=req.config_file_name,
features_table_name=req.features_table_name,
gcp_external_bucket_name=get_external_bucket_name(inst_id),
email=cast(str, current_user.email),
)
try:
res = databricks_control.run_legacy_inference(db_req)
except Exception as e:
tb = traceback.format_exc()
logging.error(f"Databricks run failure:\n{tb}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Databricks run_legacy_inference error. Error = {str(e)}",
) from e
triggered_timestamp = datetime.now()
latest_model_version = databricks_control.fetch_model_version(
catalog_name=str(env_vars["CATALOG_NAME"]),
inst_name=inst_result[0][0].name,
model_name=model_name,
)
job = JobTable(
id=res.job_run_id,
triggered_at=triggered_timestamp,
created_by=str_to_uuid(current_user.user_id),
batch_name=f"{model_name}_{triggered_timestamp}", # Legacy schools don't use batches
model_id=legacy_model_result[0][0].id,
output_valid=False,
model_version=latest_model_version.version,
model_run_id=latest_model_version.run_id,
)
local_session.get().add(job)
return {
"inst_id": inst_id,
"m_name": model_name,
"run_id": res.job_run_id,
"created_by": current_user.user_id,
"triggered_at": triggered_timestamp,
"batch_name": f"{model_name}_{triggered_timestamp}",
"output_valid": False,
"model_version": latest_model_version.version,
"model_run_id": latest_model_version.run_id,
}

# PDP inference (existing logic)
if not is_pdp:
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Currently, only PDP inference is supported.",
detail="Currently, only PDP and Legacy schools inference are supported.",
)
query_result = (
local_session.get()
Expand Down Expand Up @@ -589,7 +684,7 @@ def trigger_inference_run(
detail=f"The files in this batch don't conform to the schema configs allowed by this model. For debugging reference - file_schema={inst_file_schemas} and model_schema={schema_configs}",
)
# Note to Datakind: In the long-term, this is where you would have a case block or something that would call different types of pipelines.
db_req = DatabricksInferenceRunRequest(
db_req = DatabricksPDPInferenceRunRequest(
inst_name=inst_result[0][0].name,
filepath_to_type=convert_files_to_dict(batch_result[0][0].files),
model_name=model_name,
Expand Down
Loading